swiftide_integrations/qdrant/
retrieve.rs

1use qdrant_client::qdrant::{self, PrefetchQueryBuilder, ScoredPoint, SearchPointsBuilder};
2use swiftide_core::{
3    Retrieve,
4    document::Document,
5    indexing::{EmbeddedField, Metadata},
6    prelude::{Result, *},
7    querying::{
8        Query,
9        search_strategies::{HybridSearch, SimilaritySingleEmbedding},
10        states,
11    },
12};
13
14use super::Qdrant;
15
16/// Implement the `Retrieve` trait for `SimilaritySingleEmbedding` search strategy.
17///
18/// Can be used in the query pipeline to retrieve documents from Qdrant.
19///
20/// Supports filters via the `qdrant_client::qdrant::Filter` type.
21#[async_trait]
22impl Retrieve<SimilaritySingleEmbedding<qdrant::Filter>> for Qdrant {
23    #[tracing::instrument]
24    async fn retrieve(
25        &self,
26        search_strategy: &SimilaritySingleEmbedding<qdrant::Filter>,
27        query: Query<states::Pending>,
28    ) -> Result<Query<states::Retrieved>> {
29        let Some(embedding) = &query.embedding else {
30            anyhow::bail!("No embedding for query")
31        };
32        let mut query_builder = SearchPointsBuilder::new(
33            &self.collection_name,
34            embedding.to_owned(),
35            search_strategy.top_k(),
36        )
37        .with_payload(true);
38
39        if let Some(filter) = &search_strategy.filter() {
40            query_builder = query_builder.filter(filter.to_owned());
41        }
42
43        if self.vectors.len() > 1 || !self.sparse_vectors.is_empty() {
44            // TODO: Make this configurable
45            // It will break if there are multiple vectors and no combined vector
46            query_builder = query_builder.vector_name(EmbeddedField::Combined.field_name());
47        }
48
49        let result = self
50            .client
51            .search_points(query_builder.build())
52            .await
53            .context("Failed to retrieve from qdrant")?
54            .result;
55
56        let documents = result
57            .into_iter()
58            .map(scored_point_into_document)
59            .collect::<Result<Vec<_>>>()?;
60
61        Ok(query.retrieved_documents(documents))
62    }
63}
64
65/// Ensures that the `SimilaritySingleEmbedding` search strategy can be used when no filter is set.
66#[async_trait]
67impl Retrieve<SimilaritySingleEmbedding> for Qdrant {
68    async fn retrieve(
69        &self,
70        search_strategy: &SimilaritySingleEmbedding,
71        query: Query<states::Pending>,
72    ) -> Result<Query<states::Retrieved>> {
73        Retrieve::<SimilaritySingleEmbedding<qdrant::Filter>>::retrieve(
74            self,
75            &search_strategy.into_concrete_filter::<qdrant::Filter>(),
76            query,
77        )
78        .await
79    }
80}
81
82/// Implement the `Retrieve` trait for `HybridSearch` search strategy.
83///
84/// Can be used in the query pipeline to retrieve documents from Qdrant.
85///
86/// Expects both a dense and sparse embedding to be set on the query.
87#[async_trait]
88impl Retrieve<HybridSearch<qdrant::Filter>> for Qdrant {
89    #[tracing::instrument]
90    async fn retrieve(
91        &self,
92        search_strategy: &HybridSearch<qdrant::Filter>,
93        query: Query<states::Pending>,
94    ) -> Result<Query<states::Retrieved>> {
95        let Some(dense) = &query.embedding else {
96            anyhow::bail!("No embedding for query")
97        };
98
99        let Some(sparse) = &query.sparse_embedding else {
100            anyhow::bail!("No sparse embedding for query")
101        };
102
103        let mut sparse_prefetch = PrefetchQueryBuilder::default()
104            .query(qdrant::Query::new_nearest(qdrant::VectorInput::new_sparse(
105                sparse.indices.clone(),
106                sparse.values.clone(),
107            )))
108            .using(search_strategy.sparse_vector_field().sparse_field_name())
109            .limit(search_strategy.top_n());
110
111        let mut dense_prefetch = PrefetchQueryBuilder::default()
112            .query(qdrant::Query::new_nearest(dense.clone()))
113            .using(search_strategy.dense_vector_field().field_name())
114            .limit(search_strategy.top_n());
115
116        if let Some(filter) = search_strategy.filter() {
117            sparse_prefetch = sparse_prefetch.filter(filter.clone());
118            dense_prefetch = dense_prefetch.filter(filter.clone());
119        }
120
121        let query_points = qdrant::QueryPointsBuilder::new(&self.collection_name)
122            .with_payload(true)
123            .add_prefetch(sparse_prefetch)
124            .add_prefetch(dense_prefetch)
125            .query(qdrant::Query::new_fusion(qdrant::Fusion::Rrf))
126            .limit(search_strategy.top_k());
127
128        // NOTE: Potential improvement to consume the vectors instead of cloning
129        let result = self.client.query(query_points).await?.result;
130
131        let documents = result
132            .into_iter()
133            .map(scored_point_into_document)
134            .collect::<Result<Vec<_>>>()?;
135
136        Ok(query.retrieved_documents(documents))
137    }
138}
139
140fn scored_point_into_document(scored_point: ScoredPoint) -> Result<Document> {
141    let content = scored_point
142        .payload
143        .get("content")
144        .context("Expected document in qdrant payload")?
145        .to_string();
146
147    let metadata: Metadata = scored_point
148        .payload
149        .into_iter()
150        .filter(|(k, _)| *k != "content")
151        .collect::<Vec<(_, _)>>()
152        .into();
153
154    Ok(Document::new(content, Some(metadata)))
155}
156
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;
    use swiftide_core::{
        Persist as _,
        indexing::{self, EmbeddedField},
    };

    use super::*;

    /// Sparse embedding shared by the stored nodes and the hybrid queries.
    fn sparse_fixture() -> swiftide_core::SparseEmbedding {
        swiftide_core::SparseEmbedding {
            indices: vec![0, 1],
            values: vec![1.0, 1.0],
        }
    }

    /// Pending query that carries the same dense embedding as the stored nodes.
    fn dense_query() -> Query<states::Pending> {
        let mut pending = Query::<states::Pending>::new("test_query");
        pending.embedding = Some(vec![1.0; 384]);
        pending
    }

    /// Filter that requires the `filter` metadata field to match `value`.
    fn filter_equals(value: &str) -> qdrant::Filter {
        qdrant::Filter::must([qdrant::Condition::matches("filter", value.to_string())])
    }

    /// Starts a Qdrant container, sets up the storage, and stores three nodes: two tagged
    /// `filter = "true"` and one tagged `filter = "false"`.
    ///
    /// The container guard is returned so the container stays alive for the test's duration.
    async fn setup() -> (
        testcontainers::ContainerAsync<testcontainers::GenericImage>,
        Qdrant,
    ) {
        let (container, url) = swiftide_test_utils::start_qdrant().await;

        let storage = Qdrant::try_from_url(url)
            .unwrap()
            .vector_size(384)
            .with_vector(EmbeddedField::Combined)
            .with_sparse_vector(EmbeddedField::Combined)
            .build()
            .unwrap();

        storage.setup().await.unwrap();

        // Kept as a single expression: the nodes are built, given their vectors, and cloned
        // within one statement.
        let nodes = vec![
            indexing::TextNode::new("test_query1").with_metadata(("filter", "true")),
            indexing::TextNode::new("test_query2").with_metadata(("filter", "true")),
            indexing::TextNode::new("test_query3").with_metadata(("filter", "false")),
        ]
        .into_iter()
        .map(|text_node| {
            text_node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
            text_node.with_sparse_vectors([(EmbeddedField::Combined, sparse_fixture())]);
            text_node.to_owned()
        })
        .collect();

        storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        (container, storage)
    }

    #[test_log::test(tokio::test)]
    async fn test_retrieve_multiple_docs_and_filter() {
        let (_guard, storage) = setup().await;
        let query = dense_query();

        // Without a filter, all three stored documents come back.
        let unfiltered = SimilaritySingleEmbedding::<()>::default();
        let retrieved = storage.retrieve(&unfiltered, query.clone()).await.unwrap();
        assert_eq!(retrieved.documents().len(), 3);

        let found = retrieved
            .documents()
            .iter()
            .sorted()
            .map(Document::content)
            .collect_vec();
        // FIXME: The extra quotes should be removed by serde (via qdrant::Value), but they are
        // not
        let expected = vec!["\"test_query1\"", "\"test_query2\"", "\"test_query3\""];
        assert_eq!(found, expected);

        // Filtering on `filter = "true"` keeps only the two matching documents.
        let only_true = SimilaritySingleEmbedding::from_filter(filter_equals("true"));
        let retrieved = storage.retrieve(&only_true, query.clone()).await.unwrap();
        assert_eq!(retrieved.documents().len(), 2);

        let found = retrieved
            .documents()
            .iter()
            .sorted()
            .map(Document::content)
            .collect_vec();
        let expected = vec!["\"test_query1\"", "\"test_query2\""];
        assert_eq!(found, expected);

        // A filter value that no document has yields nothing.
        let no_match = SimilaritySingleEmbedding::from_filter(filter_equals("banana"));
        let retrieved = storage.retrieve(&no_match, query.clone()).await.unwrap();
        assert_eq!(retrieved.documents().len(), 0);
    }

    #[tokio::test]
    async fn test_hybrid_search() {
        let (_guard, storage) = setup().await;

        let mut query = dense_query();
        query.sparse_embedding = Some(sparse_fixture());

        let hybrid = HybridSearch::default();
        let retrieved = storage.retrieve(&hybrid, query.clone()).await.unwrap();

        assert_eq!(retrieved.documents().len(), 3);
    }

    #[tokio::test]
    async fn test_hybrid_search_with_filter() {
        let (_guard, storage) = setup().await;

        let mut query = dense_query();
        query.sparse_embedding = Some(sparse_fixture());

        let hybrid = HybridSearch::from_filter(filter_equals("true"));
        let retrieved = storage.retrieve(&hybrid, query.clone()).await.unwrap();

        assert_eq!(retrieved.documents().len(), 2);
    }
}