swiftide_integrations/duckdb/
retrieve.rs

1use anyhow::{Context as _, Result};
2use async_trait::async_trait;
3use swiftide_core::{
4    Retrieve,
5    indexing::Chunk,
6    querying::{
7        Document, Query,
8        search_strategies::{CustomStrategy, HybridSearch, SimilaritySingleEmbedding},
9        states,
10    },
11};
12
13use super::Duckdb;
14
15#[async_trait]
16impl<T: Chunk> Retrieve<SimilaritySingleEmbedding> for Duckdb<T> {
17    async fn retrieve(
18        &self,
19        search_strategy: &SimilaritySingleEmbedding,
20        query: Query<states::Pending>,
21    ) -> Result<Query<states::Retrieved>> {
22        let Some(embedding) = query.embedding.as_ref() else {
23            return Err(anyhow::Error::msg("Missing embedding in query state"));
24        };
25
26        let table_name = &self.table_name;
27
28        // Silently ignores multiple vector fields
29        let (field_name, embedding_size) = self
30            .vectors
31            .iter()
32            .next()
33            .context("No vectors configured")?;
34
35        let limit = search_strategy.top_k();
36
37        // Ideally it should be a prepared statement, where only the new parameters lead to extra
38        // allocations. This is possible in 1.2.1, but that version is still broken for VSS via
39        // Rust.
40        let sql = format!(
41            "SELECT uuid, chunk, path FROM {table_name}\n
42            ORDER BY array_distance({field_name}, ARRAY[{}]::FLOAT[{embedding_size}])\n
43            LIMIT {limit}",
44            embedding
45                .iter()
46                .map(ToString::to_string)
47                .collect::<Vec<_>>()
48                .join(",")
49        );
50
51        tracing::trace!("[duckdb] Executing query: {}", sql);
52
53        let conn = self.connection().lock().unwrap();
54
55        let mut stmt = conn
56            .prepare(&sql)
57            .context("Failed to prepare duckdb statement for persist")?;
58
59        tracing::trace!("[duckdb] Retrieving documents");
60
61        let documents = stmt
62            .query_map([], |row| {
63                Ok(Document::builder()
64                    .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
65                    .content(row.get::<_, String>(1)?)
66                    .build()
67                    .expect("Failed to build document; should never happen"))
68            })
69            .context("failed to query for documents")?
70            .collect::<Result<Vec<Document>, _>>()
71            .context("failed to build documents")?;
72
73        tracing::debug!("[duckdb] Retrieved documents");
74        Ok(query.retrieved_documents(documents))
75    }
76}
77
78#[async_trait]
79impl<T: Chunk> Retrieve<CustomStrategy<String>> for Duckdb<T> {
80    async fn retrieve(
81        &self,
82        search_strategy: &CustomStrategy<String>,
83        query: Query<states::Pending>,
84    ) -> Result<Query<states::Retrieved>> {
85        let sql = search_strategy
86            .build_query(&query)
87            .await
88            .context("Failed to build query")?;
89
90        tracing::debug!("[duckdb] Executing query: {}", sql);
91
92        let conn = self.connection().lock().unwrap();
93        let mut stmt = conn
94            .prepare(&sql)
95            .context("Failed to prepare duckdb statement for persist")?;
96
97        tracing::debug!("[duckdb] Prepared statement");
98
99        let documents = stmt
100            .query_map([], |row| {
101                Ok(Document::builder()
102                    .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
103                    .content(row.get::<_, String>(1)?)
104                    .build()
105                    .expect("Failed to build document; should never happen"))
106            })
107            .context("failed to query for documents")?
108            .collect::<Result<Vec<Document>, _>>()
109            .context("failed to build documents")?;
110
111        tracing::debug!("[duckdb] Retrieved documents");
112
113        Ok(query.retrieved_documents(documents))
114    }
115}
116
117#[async_trait]
118impl<T: Chunk> Retrieve<HybridSearch> for Duckdb<T> {
119    async fn retrieve(
120        &self,
121        search_strategy: &HybridSearch,
122        query: Query<states::Pending>,
123    ) -> Result<Query<states::Retrieved>> {
124        let Some(embedding) = query.embedding.as_ref() else {
125            return Err(anyhow::Error::msg("Missing embedding in query state"));
126        };
127
128        let sql = self
129            .hybrid_query_sql(search_strategy, query.current(), embedding)
130            .context("Failed to build query")?;
131
132        tracing::debug!("[duckdb] Executing query: {}", sql);
133
134        let conn = self.connection().lock().unwrap();
135        let mut stmt = conn
136            .prepare(&sql)
137            .context("Failed to prepare duckdb statement for persist")?;
138
139        tracing::debug!("[duckdb] Prepared statement");
140
141        let documents = stmt
142            // DuckDB has issues with using `params!` :(
143            .query_map([], |row| {
144                Ok(Document::builder()
145                    .metadata([("id", row.get::<_, String>(0)?), ("path", row.get(2)?)])
146                    .content(row.get::<_, String>(1)?)
147                    .build()
148                    .expect("Failed to build document; should never happen"))
149            })
150            .context("failed to query for documents")?
151            .collect::<Result<Vec<Document>, _>>()
152            .context("failed to build documents")?;
153
154        tracing::debug!("[duckdb] Retrieved documents");
155
156        Ok(query.retrieved_documents(documents))
157    }
158}
159
#[cfg(test)]
mod tests {
    use swiftide_core::{
        Persist as _,
        indexing::{EmbeddedField, TextNode},
    };

    use super::*;

    /// Builds a pending query whose embedding equals the vector of the stored test node.
    fn pending_query() -> Query<states::Pending> {
        Query::<states::Pending>::builder()
            .embedding(vec![1.0, 2.0, 3.0])
            .original("Some query")
            .build()
            .unwrap()
    }

    /// Asserts that exactly one document came back, carrying the test content and the given id.
    fn assert_single_hello_document(result: &Query<states::Retrieved>, expected_id: &str) {
        let documents = result.documents();
        assert_eq!(documents.len(), 1);

        let document = documents.first().unwrap();
        assert_eq!(document.content(), "Hello duckdb!");
        assert_eq!(
            document.metadata().get("id").unwrap().as_str(),
            Some(expected_id)
        );
    }

    /// Stores a single node and retrieves it again through a plain similarity search.
    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents() {
        let node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        client.setup().await.unwrap();
        client.store(node.clone()).await.unwrap();

        tracing::info!("Stored node");

        let strategy = SimilaritySingleEmbedding::default();
        let result = client.retrieve(&strategy, pending_query()).await.unwrap();

        let expected_id = node.id().to_string();
        assert_single_hello_document(&result, expected_id.as_str());
    }

    /// Stores a single node and retrieves it again through the hybrid search strategy.
    #[test_log::test(tokio::test)]
    async fn test_duckdb_retrieving_documents_hybrid() {
        let node = TextNode::new("Hello duckdb!")
            .with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
            .to_owned();

        let client = Duckdb::builder()
            .connection(duckdb::Connection::open_in_memory().unwrap())
            .table_name("test".to_string())
            .with_vector(EmbeddedField::Combined, 3)
            .build()
            .unwrap();

        client.setup().await.unwrap();
        client.store(node.clone()).await.unwrap();

        tracing::info!("Stored node");

        let strategy = HybridSearch::default();
        let result = client.retrieve(&strategy, pending_query()).await.unwrap();

        let expected_id = node.id().to_string();
        assert_single_hello_document(&result, expected_id.as_str());
    }
}