1use crate::{
2 errors::Error,
3 indexes::Index,
4 request::HttpClient,
5 search::{Filter, Selectors},
6};
7use either::Either;
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::{Map, Value};
10
11#[derive(Deserialize, Debug, Clone)]
12pub struct SimilarResult<T> {
13 #[serde(flatten)]
14 pub result: T,
15 #[serde(rename = "_rankingScore")]
16 pub ranking_score: Option<f64>,
17 #[serde(rename = "_rankingScoreDetails")]
18 pub ranking_score_details: Option<Map<String, Value>>,
19}
20
21#[derive(Deserialize, Debug, Clone)]
22#[serde(rename_all = "camelCase")]
23pub struct SimilarResults<T> {
24 pub hits: Vec<SimilarResult<T>>,
26 pub offset: Option<usize>,
28 pub limit: Option<usize>,
30 pub estimated_total_hits: Option<usize>,
32 pub processing_time_ms: usize,
34 pub id: String,
36}
37
38#[derive(Debug, Serialize, Clone)]
72#[serde(rename_all = "camelCase")]
73pub struct SimilarQuery<'a, Http: HttpClient> {
74 #[serde(skip_serializing)]
75 index: &'a Index<Http>,
76
77 pub id: &'a str,
79
80 pub embedder: &'a str,
82
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub offset: Option<usize>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub limit: Option<usize>,
90
91 #[serde(skip_serializing_if = "Option::is_none")]
95 pub filter: Option<Filter<'a>>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
103 pub attributes_to_retrieve: Option<Selectors<&'a [&'a str]>>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
109 pub show_ranking_score: Option<bool>,
110
111 #[serde(skip_serializing_if = "Option::is_none")]
115 pub show_ranking_score_details: Option<bool>,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
121 pub ranking_score_threshold: Option<f64>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
127 pub retrieve_vectors: Option<bool>,
128}
129
130#[allow(missing_docs)]
131impl<'a, Http: HttpClient> SimilarQuery<'a, Http> {
132 #[must_use]
133 pub fn new(index: &'a Index<Http>, id: &'a str, embedder: &'a str) -> SimilarQuery<'a, Http> {
134 SimilarQuery {
135 index,
136 id,
137 embedder,
138 offset: None,
139 limit: None,
140 filter: None,
141 attributes_to_retrieve: None,
142 show_ranking_score: None,
143 show_ranking_score_details: None,
144 ranking_score_threshold: None,
145 retrieve_vectors: None,
146 }
147 }
148
149 pub fn with_offset<'b>(&'b mut self, offset: usize) -> &'b mut SimilarQuery<'a, Http> {
150 self.offset = Some(offset);
151 self
152 }
153
154 pub fn with_limit<'b>(&'b mut self, limit: usize) -> &'b mut SimilarQuery<'a, Http> {
155 self.limit = Some(limit);
156 self
157 }
158
159 pub fn with_filter<'b>(&'b mut self, filter: &'a str) -> &'b mut SimilarQuery<'a, Http> {
160 self.filter = Some(Filter::new(Either::Left(filter)));
161 self
162 }
163
164 pub fn with_array_filter<'b>(
165 &'b mut self,
166 filter: Vec<&'a str>,
167 ) -> &'b mut SimilarQuery<'a, Http> {
168 self.filter = Some(Filter::new(Either::Right(filter)));
169 self
170 }
171
172 pub fn with_attributes_to_retrieve<'b>(
173 &'b mut self,
174 attributes_to_retrieve: Selectors<&'a [&'a str]>,
175 ) -> &'b mut SimilarQuery<'a, Http> {
176 self.attributes_to_retrieve = Some(attributes_to_retrieve);
177 self
178 }
179
180 pub fn with_show_ranking_score<'b>(
181 &'b mut self,
182 show_ranking_score: bool,
183 ) -> &'b mut SimilarQuery<'a, Http> {
184 self.show_ranking_score = Some(show_ranking_score);
185 self
186 }
187
188 pub fn with_show_ranking_score_details<'b>(
189 &'b mut self,
190 show_ranking_score_details: bool,
191 ) -> &'b mut SimilarQuery<'a, Http> {
192 self.show_ranking_score_details = Some(show_ranking_score_details);
193 self
194 }
195
196 pub fn with_ranking_score_threshold<'b>(
197 &'b mut self,
198 ranking_score_threshold: f64,
199 ) -> &'b mut SimilarQuery<'a, Http> {
200 self.ranking_score_threshold = Some(ranking_score_threshold);
201 self
202 }
203
204 pub fn with_retrieve_vectors<'b>(
205 &'b mut self,
206 retrieve_vectors: bool,
207 ) -> &'b mut SimilarQuery<'a, Http> {
208 self.retrieve_vectors = Some(retrieve_vectors);
209 self
210 }
211
212 pub async fn execute<T: 'static + DeserializeOwned + Send + Sync>(
214 &'a self,
215 ) -> Result<SimilarResults<T>, Error> {
216 self.index.execute_similar_query::<T>(self).await
217 }
218}
219
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::{
        client::*,
        search::{
            tests::{setup_embedder, setup_test_index, Document},
            *,
        },
    };
    use meilisearch_test_macro::meilisearch_test;

    #[meilisearch_test]
    async fn test_similar_results(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        // The nearest neighbour of document 0 is expected to be document 1.
        let mut query = SimilarQuery::new(&index, "0", "default");
        query.with_limit(1);
        let results: SimilarResults<Document> = query.execute().await?;
        let result = results.hits.first().expect("expected one similar hit for id 0");
        assert_eq!(result.result.id, 1);

        // The nearest neighbour of document 3 is expected to be document 4.
        let mut query = SimilarQuery::new(&index, "3", "default");
        query.with_limit(1);
        let results: SimilarResults<Document> = query.execute().await?;
        let result = results.hits.first().expect("expected one similar hit for id 3");
        assert_eq!(result.result.id, 4);

        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_limit(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_limit(3);

        let results: SimilarResults<Document> = query.execute().await?;
        assert_eq!(results.hits.len(), 3);
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_offset(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_offset(6);

        let results: SimilarResults<Document> = query.execute().await?;
        assert_eq!(results.hits.len(), 3);
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_filter(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");

        let results: SimilarResults<Document> =
            query.with_filter("kind = \"title\"").execute().await?;
        assert_eq!(results.hits.len(), 8);

        // `with_filter` replaces the previous filter rather than combining with it.
        let results: SimilarResults<Document> =
            query.with_filter("NOT kind = \"title\"").execute().await?;
        assert_eq!(results.hits.len(), 1);
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_filter_with_array(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        // Mutually exclusive conditions: no document can satisfy both.
        let mut query = SimilarQuery::new(&index, "1", "default");
        let results: SimilarResults<Document> = query
            .with_array_filter(vec!["kind = \"title\"", "kind = \"text\""])
            .execute()
            .await?;
        assert_eq!(results.hits.len(), 0);

        let mut query = SimilarQuery::new(&index, "1", "default");
        let results: SimilarResults<Document> = query
            .with_array_filter(vec!["kind = \"title\"", "number <= 50"])
            .execute()
            .await?;
        assert_eq!(results.hits.len(), 4);

        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_attributes_to_retrieve(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        let results: SimilarResults<Document> = query
            .with_attributes_to_retrieve(Selectors::All)
            .execute()
            .await?;
        assert_eq!(results.hits.len(), 9);

        // Retrieving only `title` and `id` leaves out fields `Document` requires
        // (presumably `value`), so deserializing the hits must fail.
        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_attributes_to_retrieve(Selectors::Some(&["title", "id"]));
        assert!(query.execute::<Document>().await.is_err());
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_show_ranking_score(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_show_ranking_score(true);
        let results: SimilarResults<Document> = query.execute().await?;
        let first = results.hits.first().expect("expected at least one hit");
        assert!(first.ranking_score.is_some());
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_show_ranking_score_details(
        client: Client,
        index: Index,
    ) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_show_ranking_score_details(true);
        let results: SimilarResults<Document> = query.execute().await?;
        let first = results.hits.first().expect("expected at least one hit");
        assert!(first.ranking_score_details.is_some());
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_show_ranking_score_threshold(
        client: Client,
        index: Index,
    ) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        // A threshold of 1.0 is expected to filter out every hit.
        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_ranking_score_threshold(1.0);
        let results: SimilarResults<Document> = query.execute().await?;
        assert!(results.hits.is_empty());
        Ok(())
    }

    #[meilisearch_test]
    async fn test_query_retrieve_vectors(client: Client, index: Index) -> Result<(), Error> {
        setup_embedder(&client, &index).await?;
        setup_test_index(&client, &index).await?;

        let mut query = SimilarQuery::new(&index, "1", "default");
        query.with_retrieve_vectors(true);
        let results: SimilarResults<Document> = query.execute().await?;
        let first = results.hits.first().expect("expected at least one hit");
        assert!(first.result._vectors.is_some());
        Ok(())
    }
}