swiftide_integrations/qdrant/
retrieve.rs1use qdrant_client::qdrant::{self, PrefetchQueryBuilder, ScoredPoint, SearchPointsBuilder};
2use swiftide_core::{
3 Retrieve,
4 document::Document,
5 indexing::{EmbeddedField, Metadata},
6 prelude::{Result, *},
7 querying::{
8 Query,
9 search_strategies::{HybridSearch, SimilaritySingleEmbedding},
10 states,
11 },
12};
13
14use super::Qdrant;
15
16#[async_trait]
22impl Retrieve<SimilaritySingleEmbedding<qdrant::Filter>> for Qdrant {
23 #[tracing::instrument]
24 async fn retrieve(
25 &self,
26 search_strategy: &SimilaritySingleEmbedding<qdrant::Filter>,
27 query: Query<states::Pending>,
28 ) -> Result<Query<states::Retrieved>> {
29 let Some(embedding) = &query.embedding else {
30 anyhow::bail!("No embedding for query")
31 };
32 let mut query_builder = SearchPointsBuilder::new(
33 &self.collection_name,
34 embedding.to_owned(),
35 search_strategy.top_k(),
36 )
37 .with_payload(true);
38
39 if let Some(filter) = &search_strategy.filter() {
40 query_builder = query_builder.filter(filter.to_owned());
41 }
42
43 if self.vectors.len() > 1 || !self.sparse_vectors.is_empty() {
44 query_builder = query_builder.vector_name(EmbeddedField::Combined.field_name());
47 }
48
49 let result = self
50 .client
51 .search_points(query_builder.build())
52 .await
53 .context("Failed to retrieve from qdrant")?
54 .result;
55
56 let documents = result
57 .into_iter()
58 .map(scored_point_into_document)
59 .collect::<Result<Vec<_>>>()?;
60
61 Ok(query.retrieved_documents(documents))
62 }
63}
64
65#[async_trait]
67impl Retrieve<SimilaritySingleEmbedding> for Qdrant {
68 async fn retrieve(
69 &self,
70 search_strategy: &SimilaritySingleEmbedding,
71 query: Query<states::Pending>,
72 ) -> Result<Query<states::Retrieved>> {
73 Retrieve::<SimilaritySingleEmbedding<qdrant::Filter>>::retrieve(
74 self,
75 &search_strategy.into_concrete_filter::<qdrant::Filter>(),
76 query,
77 )
78 .await
79 }
80}
81
82#[async_trait]
88impl Retrieve<HybridSearch<qdrant::Filter>> for Qdrant {
89 #[tracing::instrument]
90 async fn retrieve(
91 &self,
92 search_strategy: &HybridSearch<qdrant::Filter>,
93 query: Query<states::Pending>,
94 ) -> Result<Query<states::Retrieved>> {
95 let Some(dense) = &query.embedding else {
96 anyhow::bail!("No embedding for query")
97 };
98
99 let Some(sparse) = &query.sparse_embedding else {
100 anyhow::bail!("No sparse embedding for query")
101 };
102
103 let mut sparse_prefetch = PrefetchQueryBuilder::default()
104 .query(qdrant::Query::new_nearest(qdrant::VectorInput::new_sparse(
105 sparse.indices.clone(),
106 sparse.values.clone(),
107 )))
108 .using(search_strategy.sparse_vector_field().sparse_field_name())
109 .limit(search_strategy.top_n());
110
111 let mut dense_prefetch = PrefetchQueryBuilder::default()
112 .query(qdrant::Query::new_nearest(dense.clone()))
113 .using(search_strategy.dense_vector_field().field_name())
114 .limit(search_strategy.top_n());
115
116 if let Some(filter) = search_strategy.filter() {
117 sparse_prefetch = sparse_prefetch.filter(filter.clone());
118 dense_prefetch = dense_prefetch.filter(filter.clone());
119 }
120
121 let query_points = qdrant::QueryPointsBuilder::new(&self.collection_name)
122 .with_payload(true)
123 .add_prefetch(sparse_prefetch)
124 .add_prefetch(dense_prefetch)
125 .query(qdrant::Query::new_fusion(qdrant::Fusion::Rrf))
126 .limit(search_strategy.top_k());
127
128 let result = self.client.query(query_points).await?.result;
130
131 let documents = result
132 .into_iter()
133 .map(scored_point_into_document)
134 .collect::<Result<Vec<_>>>()?;
135
136 Ok(query.retrieved_documents(documents))
137 }
138}
139
140fn scored_point_into_document(scored_point: ScoredPoint) -> Result<Document> {
141 let content = scored_point
142 .payload
143 .get("content")
144 .context("Expected document in qdrant payload")?
145 .to_string();
146
147 let metadata: Metadata = scored_point
148 .payload
149 .into_iter()
150 .filter(|(k, _)| *k != "content")
151 .collect::<Vec<(_, _)>>()
152 .into();
153
154 Ok(Document::new(content, Some(metadata)))
155}
156
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;
    use swiftide_core::{
        Persist as _,
        indexing::{self, EmbeddedField},
    };

    use super::*;

    /// Sparse embedding used both for the stored nodes and for the hybrid queries.
    fn sparse_fixture() -> swiftide_core::SparseEmbedding {
        swiftide_core::SparseEmbedding {
            indices: vec![0, 1],
            values: vec![1.0, 1.0],
        }
    }

    /// Starts a qdrant test container, sets up the store and stores three nodes.
    ///
    /// Two nodes carry the metadata `filter = "true"` and one carries
    /// `filter = "false"`. The container guard is returned alongside the client so
    /// the caller decides how long it is held.
    async fn setup() -> (
        testcontainers::ContainerAsync<testcontainers::GenericImage>,
        Qdrant,
    ) {
        let (container, url) = swiftide_test_utils::start_qdrant().await;

        let store = Qdrant::try_from_url(url)
            .unwrap()
            .vector_size(384)
            .with_vector(EmbeddedField::Combined)
            .with_sparse_vector(EmbeddedField::Combined)
            .build()
            .unwrap();

        store.setup().await.unwrap();

        // Kept as one expression: each node gets its dense and sparse vectors
        // attached before an owned copy is collected.
        let nodes = vec![
            indexing::TextNode::new("test_query1").with_metadata(("filter", "true")),
            indexing::TextNode::new("test_query2").with_metadata(("filter", "true")),
            indexing::TextNode::new("test_query3").with_metadata(("filter", "false")),
        ]
        .into_iter()
        .map(|text_node| {
            text_node.with_vectors([(EmbeddedField::Combined, vec![1.0; 384])]);
            text_node.with_sparse_vectors([(EmbeddedField::Combined, sparse_fixture())]);
            text_node.to_owned()
        })
        .collect();

        store
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        (container, store)
    }

    #[test_log::test(tokio::test)]
    async fn test_retrieve_multiple_docs_and_filter() {
        let (_guard, store) = setup().await;

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        // Expected contents carry literal quotes; presumably a consequence of the
        // payload value being rendered with `to_string()` -- confirm before changing.
        let all_contents = ["\"test_query1\"", "\"test_query2\"", "\"test_query3\""];
        let matching_contents = ["\"test_query1\"", "\"test_query2\""];

        // Without a filter all three stored documents come back.
        let unfiltered = SimilaritySingleEmbedding::<()>::default();
        let everything = store.retrieve(&unfiltered, query.clone()).await.unwrap();
        assert_eq!(everything.documents().len(), 3);
        assert_eq!(
            everything
                .documents()
                .iter()
                .sorted()
                .map(Document::content)
                .collect_vec(),
            all_contents.into_iter().sorted().collect_vec()
        );

        // Filtering on `filter = "true"` keeps the two matching documents.
        let only_true = SimilaritySingleEmbedding::from_filter(qdrant::Filter::must([
            qdrant::Condition::matches("filter", "true".to_string()),
        ]));
        let matching = store.retrieve(&only_true, query.clone()).await.unwrap();
        assert_eq!(matching.documents().len(), 2);
        assert_eq!(
            matching
                .documents()
                .iter()
                .sorted()
                .map(Document::content)
                .collect_vec(),
            matching_contents.into_iter().sorted().collect_vec()
        );

        // A filter value no document has yields an empty result.
        let no_match = SimilaritySingleEmbedding::from_filter(qdrant::Filter::must([
            qdrant::Condition::matches("filter", "banana".to_string()),
        ]));
        let nothing = store.retrieve(&no_match, query.clone()).await.unwrap();
        assert_eq!(nothing.documents().len(), 0);
    }

    #[tokio::test]
    async fn test_hybrid_search() {
        let (_guard, store) = setup().await;

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);
        query.sparse_embedding = Some(sparse_fixture());

        let hybrid = HybridSearch::default();
        let retrieved = store.retrieve(&hybrid, query.clone()).await.unwrap();

        assert_eq!(retrieved.documents().len(), 3);
    }

    #[tokio::test]
    async fn test_hybrid_search_with_filter() {
        let (_guard, store) = setup().await;

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);
        query.sparse_embedding = Some(sparse_fixture());

        let only_true = qdrant::Filter::must([qdrant::Condition::matches(
            "filter",
            "true".to_string(),
        )]);
        let hybrid = HybridSearch::from_filter(only_true);
        let retrieved = store.retrieve(&hybrid, query.clone()).await.unwrap();

        assert_eq!(retrieved.documents().len(), 2);
    }
}