qdrant_client/qdrant_client/
query.rs1use super::QdrantResult;
2use crate::qdrant::{
3 QueryBatchPoints, QueryBatchResponse, QueryGroupsResponse, QueryPointGroups, QueryPoints,
4 QueryResponse,
5};
6use crate::qdrant_client::Qdrant;
7
8impl Qdrant {
14 pub async fn query(&self, request: impl Into<QueryPoints>) -> QdrantResult<QueryResponse> {
37 let request = &request.into();
38
39 self.with_points_client(|mut points_api| async move {
40 let result = points_api.query(request.clone()).await?;
41 Ok(result.into_inner())
42 })
43 .await
44 }
45
46 pub async fn query_batch(
78 &self,
79 request: impl Into<QueryBatchPoints>,
80 ) -> QdrantResult<QueryBatchResponse> {
81 let request = &request.into();
82
83 self.with_points_client(|mut points_api| async move {
84 let result = points_api.query_batch(request.clone()).await?;
85 Ok(result.into_inner())
86 })
87 .await
88 }
89
90 pub async fn query_groups(
118 &self,
119 request: impl Into<QueryPointGroups>,
120 ) -> QdrantResult<QueryGroupsResponse> {
121 let request = &request.into();
122
123 self.with_points_client(|mut points_api| async move {
124 let result = points_api.query_groups(request.clone()).await?;
125 Ok(result.into_inner())
126 })
127 .await
128 }
129}
130
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::builders::CreateCollectionBuilder;
    use crate::qdrant::{
        ContextInputBuilder, CreateFieldIndexCollectionBuilder, Datatype, DiscoverInputBuilder,
        Distance, FieldType, Fusion, IntegerIndexParamsBuilder, Modifier, MultiVectorConfig,
        NamedVectors, PointId, PointStruct, PrefetchQueryBuilder, Query, QueryPointsBuilder,
        RecommendInputBuilder, ScalarQuantizationBuilder, SparseIndexConfigBuilder,
        SparseVectorParamsBuilder, SparseVectorsConfigBuilder, UpsertPointsBuilder, Vector,
        VectorInput, VectorParamsBuilder, VectorsConfigBuilder,
    };
    use crate::Payload;

    /// End-to-end check of `Qdrant::query` against the server at
    /// `http://localhost:6334`.
    ///
    /// Recreates `test_collection_query` with three dense named vectors and one
    /// sparse vector, upserts two points, indexes the `num` payload field, and
    /// then runs a fusion query over nested prefetches, a recommend query and a
    /// discover query. Each query uses `limit(1)` and must yield exactly one hit.
    #[tokio::test]
    async fn test_query() {
        let client = Qdrant::from_url("http://localhost:6334").build().unwrap();
        let collection = "test_collection_query";

        // Drop the collection first so the test begins without earlier data.
        client.delete_collection(collection).await.unwrap();

        // --- Collection schema -------------------------------------------------
        let colbert_params = VectorParamsBuilder::new(4, Distance::Dot)
            .multivector_config(MultiVectorConfig::default());

        let mut dense_vectors = VectorsConfigBuilder::default();
        dense_vectors.add_named_vector_params(
            "large_vector",
            VectorParamsBuilder::new(8, Distance::Cosine),
        );
        dense_vectors.add_named_vector_params(
            "small_vector",
            VectorParamsBuilder::new(4, Distance::Euclid),
        );
        dense_vectors.add_named_vector_params("colbert_vector", colbert_params);

        let idf_params = SparseVectorParamsBuilder::default()
            .modifier(Modifier::Idf)
            .index(SparseIndexConfigBuilder::default().datatype(Datatype::Float32));

        let mut sparse_vectors = SparseVectorsConfigBuilder::default();
        sparse_vectors.add_named_vector_params("sparse_idf_vector", idf_params);

        client
            .create_collection(
                CreateCollectionBuilder::new(collection)
                    .vectors_config(dense_vectors)
                    .sparse_vectors_config(sparse_vectors)
                    .quantization_config(ScalarQuantizationBuilder::default()),
            )
            .await
            .unwrap();

        // --- Test data ---------------------------------------------------------
        // Two points that carry a value for every named vector declared above.
        let first_point = PointStruct::new(
            0,
            NamedVectors::default()
                .add_vector("large_vector", vec![0.1; 8])
                .add_vector("small_vector", vec![0.1; 4])
                .add_vector(
                    "colbert_vector",
                    vec![vec![0.1, 0.2, 0.3, 0.4], vec![0.4, 0.2, 0.3, 0.1]],
                )
                .add_vector(
                    "sparse_idf_vector",
                    Vector::new_sparse(vec![1, 2, 3], vec![0.1, 0.2, 0.3]),
                ),
            Payload::try_from(json!({"foo": "bar", "num": 1})).unwrap(),
        );
        let second_point = PointStruct::new(
            1,
            NamedVectors::default()
                .add_vector("large_vector", vec![1.1; 8])
                .add_vector("small_vector", vec![1.1; 4])
                .add_vector(
                    "colbert_vector",
                    vec![vec![1.1, 1.2, 1.3, 1.4], vec![1.4, 1.2, 1.3, 1.1]],
                )
                .add_vector(
                    "sparse_idf_vector",
                    Vector::new_sparse(vec![1, 2, 3], vec![1.1, 1.2, 1.3]),
                ),
            Payload::try_from(json!({"foo": "bar", "num": 2})).unwrap(),
        );

        client
            .upsert_points(
                UpsertPointsBuilder::new(collection, vec![first_point, second_point]).wait(true),
            )
            .await
            .unwrap();

        // Integer index on `num`, the field the order-by prefetch below refers to.
        // NOTE(review): the two booleans are presumably (lookup, range) — confirm.
        let num_index =
            CreateFieldIndexCollectionBuilder::new(collection, "num", FieldType::Integer)
                .wait(true)
                .field_index_params(IntegerIndexParamsBuilder::new(false, true).build());
        client.create_field_index(num_index).await.unwrap();

        // --- Fusion query over a tree of prefetches ----------------------------
        // Built leaf-first; the resulting tree is:
        //   fusion(Rrf)
        //   ├── colbert_vector
        //   │   ├── sparse_idf_vector
        //   │   └── large_vector
        //   │       └── small_vector
        //   └── order_by("num")
        let small_prefetch = PrefetchQueryBuilder::default()
            .using("small_vector")
            .query(Query::new_nearest(vec![0.1; 4]));
        let large_prefetch = PrefetchQueryBuilder::default()
            .using("large_vector")
            .query(Query::new_nearest(vec![0.1; 8]))
            .add_prefetch(small_prefetch);
        let sparse_prefetch = PrefetchQueryBuilder::default()
            .using("sparse_idf_vector")
            .query(VectorInput::new_sparse(vec![1, 2, 3], vec![0.1, 0.2, 0.3]));
        let colbert_prefetch = PrefetchQueryBuilder::default()
            .using("colbert_vector")
            .query(Query::new_nearest(vec![
                vec![0.1, 0.2, 0.3, 0.4],
                vec![1.1, 1.2, 1.3, 1.4],
            ]))
            .add_prefetch(sparse_prefetch)
            .add_prefetch(large_prefetch);
        let order_by_prefetch = PrefetchQueryBuilder::default().query(Query::new_order_by("num"));

        let fusion_request = QueryPointsBuilder::new(collection)
            .limit(1)
            .query(Query::new_fusion(Fusion::Rrf))
            .add_prefetch(colbert_prefetch)
            .add_prefetch(order_by_prefetch);

        let fusion_response = client.query(fusion_request).await.unwrap();
        assert_eq!(fusion_response.result.len(), 1);

        // --- Recommend query: raw vector as positive, point id as negative -----
        let recommend_input = RecommendInputBuilder::default()
            .add_positive(vec![0.1; 8])
            .add_negative(PointId::from(0));
        let recommend_request = QueryPointsBuilder::new(collection)
            .limit(1)
            .using("large_vector")
            .query(Query::new_recommend(recommend_input));

        let recommend_response = client.query(recommend_request).await.unwrap();
        assert_eq!(recommend_response.result.len(), 1);

        // --- Discover query: target vector plus one (point id, vector) pair ----
        let discover_context =
            ContextInputBuilder::default().add_pair(PointId::from(0), vec![0.2; 8]);
        let discover_request = QueryPointsBuilder::new(collection)
            .limit(1)
            .using("large_vector")
            .query(Query::new_discover(DiscoverInputBuilder::new(
                vec![0.1; 8],
                discover_context,
            )));

        let discover_response = client.query(discover_request).await.unwrap();
        assert_eq!(discover_response.result.len(), 1);
    }
}