elasticsearch_dsl/search/knn/
mod.rs

//! A k-nearest neighbor (kNN) search finds the k nearest vectors to a query vector, as measured by a similarity metric.
//!
//! Common use cases for kNN include:
//! - Relevance ranking based on natural language processing (NLP) algorithms
//! - Product recommendations and recommendation engines
//! - Similarity search for images or videos
//!
//! <https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#approximate-knn>
9
10use crate::search::*;
11use crate::util::*;
12use serde::Serialize;
13
14/// Performs a k-nearest neighbor (kNN) search and returns the matching documents.
15///
16/// The kNN search API performs a k-nearest neighbor (kNN) search on a `dense_vector` field. Given a query vector, it
17/// finds the _k_ closest vectors and returns those documents as search hits.
18///
19/// Elasticsearch uses the HNSW algorithm to support efficient kNN search. Like most kNN algorithms, HNSW is an
20/// approximate method that sacrifices result accuracy for improved search speed. This means the results returned are
21/// not always the true _k_ closest neighbors.
22///
23/// The kNN search API supports restricting the search using a filter. The search will return the top `k` documents
24/// that also match the filter query.
25///
26/// To create a knn search with a query vector or query vector builder:
27/// ```
28/// # use elasticsearch_dsl::*;
29/// # let search =
30/// Search::new()
31///     .knn(Knn::query_vector("test1", vec![1.0, 2.0, 3.0]))
32///     .knn(Knn::query_vector_builder("test3", TextEmbedding::new("my-text-embedding-model", "The opposite of pink")));
33/// ```
34/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html>
35#[derive(Debug, Clone, PartialEq, Serialize)]
36pub struct Knn {
37    field: String,
38
39    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
40    query_vector: Option<Vec<f32>>,
41
42    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
43    query_vector_builder: Option<QueryVectorBuilder>,
44
45    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
46    k: Option<u32>,
47
48    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
49    num_candidates: Option<u32>,
50
51    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
52    filter: Option<Box<Query>>,
53
54    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
55    similarity: Option<f32>,
56
57    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
58    boost: Option<f32>,
59
60    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
61    _name: Option<String>,
62}
63
64impl Knn {
65    /// Creates an instance of [`Knn`] search with query vector
66    ///
67    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
68    /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are searching
69    ///   against.
70    pub fn query_vector<T>(field: T, query_vector: Vec<f32>) -> Self
71    where
72        T: ToString,
73    {
74        Self {
75            field: field.to_string(),
76            query_vector: Some(query_vector),
77            query_vector_builder: None,
78            k: None,
79            num_candidates: None,
80            filter: None,
81            similarity: None,
82            boost: None,
83            _name: None,
84        }
85    }
86    /// Creates an instance of [`Knn`] search with query vector builder
87    ///
88    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
89    /// - `query_vector_builder` - A configuration object indicating how to build a query_vector before executing the request.
90    pub fn query_vector_builder<T, U>(field: T, query_vector_builder: U) -> Self
91    where
92        T: ToString,
93        U: Into<QueryVectorBuilder>,
94    {
95        Self {
96            field: field.to_string(),
97            query_vector: None,
98            query_vector_builder: Some(query_vector_builder.into()),
99            k: None,
100            num_candidates: None,
101            filter: None,
102            similarity: None,
103            boost: None,
104            _name: None,
105        }
106    }
107
108    /// Number of nearest neighbors to return as top hits. This value must be less than `num_candidates`.
109    ///
110    /// Defaults to `size`.
111    pub fn k(mut self, k: u32) -> Self {
112        self.k = Some(k);
113        self
114    }
115
116    /// The number of nearest neighbor candidates to consider per shard. Cannot exceed 10,000. Elasticsearch collects
117    /// `num_candidates` results from each shard, then merges them to find the top results. Increasing `num_candidates`
118    /// tends to improve the accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
119    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
120        self.num_candidates = Some(num_candidates);
121        self
122    }
123
124    /// Query to filter the documents that can match. The kNN search will return the top documents that also match
125    /// this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents
126    /// are allowed to match.
127    ///
128    /// The filter is a pre-filter, meaning that it is applied **during** the approximate kNN search to ensure that
129    /// `num_candidates` matching documents are returned.
130    pub fn filter<T>(mut self, filter: T) -> Self
131    where
132        T: Into<Query>,
133    {
134        self.filter = Some(Box::new(filter.into()));
135        self
136    }
137
138    /// The minimum similarity required for a document to be considered a match. The similarity value calculated
139    /// relates to the raw similarity used. Not the document score. The matched documents are then scored according
140    /// to similarity and the provided boost is applied.
141    pub fn similarity(mut self, similarity: f32) -> Self {
142        self.similarity = Some(similarity);
143        self
144    }
145
146    add_boost_and_name!();
147}
148
149/// A configuration object indicating how to build a query_vector before executing the request.
150///
151/// Currently, the only supported builder is [`TextEmbedding`].
152///
153/// <https://www.elastic.co/guide/en/elasticsearch/reference/8.13/knn-search.html#knn-semantic-search>
154#[derive(Debug, Clone, PartialEq, Serialize)]
155#[serde(rename_all = "snake_case")]
156pub enum QueryVectorBuilder {
157    /// The natural language processing task to perform.
158    TextEmbedding(TextEmbedding),
159}
160
161/// The natural language processing task to perform.
162#[derive(Debug, Clone, PartialEq, Serialize)]
163pub struct TextEmbedding {
164    model_id: String,
165    model_text: String,
166}
167
168impl From<TextEmbedding> for QueryVectorBuilder {
169    fn from(embedding: TextEmbedding) -> Self {
170        Self::TextEmbedding(embedding)
171    }
172}
173
174impl TextEmbedding {
175    /// Creates an instance of [`TextEmbedding`]
176    /// - `model_id` - The ID of the text embedding model to use to generate the dense vectors from the query string.
177    ///   Use the same model that generated the embeddings from the input text in the index you search against. You can
178    ///   use the value of the deployment_id instead in the model_id argument.
179    /// - `model_text` - The query string from which the model generates the dense vector representation.
180    pub fn new<T, U>(model_id: T, model_text: U) -> Self
181    where
182        T: ToString,
183        U: ToString,
184    {
185        Self {
186            model_id: model_id.to_string(),
187            model_text: model_text.to_string(),
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the JSON produced for four `knn` clauses: a bare and a fully configured clause for each of the
    /// two vector sources (explicit vector and text-embedding builder). Unset options must not appear.
    #[test]
    fn serialization() {
        assert_serialize(
            Search::new()
                .knn(Knn::query_vector("test1", vec![1.0, 2.0, 3.0]))
                .knn(
                    Knn::query_vector("test2", vec![4.0, 5.0, 6.0])
                        .k(3)
                        .num_candidates(100)
                        .filter(Query::term("field", "value"))
                        .similarity(0.5)
                        .boost(2.0)
                        .name("test2"),
                )
                .knn(Knn::query_vector_builder(
                    "test3",
                    TextEmbedding::new("my-text-embedding-model", "The opposite of pink"),
                ))
                .knn(
                    Knn::query_vector_builder(
                        "test4",
                        TextEmbedding::new("my-text-embedding-model", "The opposite of blue"),
                    )
                    .k(5)
                    .num_candidates(200)
                    .filter(Query::term("field", "value"))
                    .similarity(0.7)
                    .boost(2.1)
                    .name("test4"),
                ),
            json!({
                "knn": [
                    {
                        "field": "test1",
                        "query_vector": [1.0, 2.0, 3.0]
                    },
                    {
                        "field": "test2",
                        "query_vector": [4.0, 5.0, 6.0],
                        "k": 3,
                        "num_candidates": 100,
                        "filter": {
                            "term": {
                                "field": {
                                    "value": "value"
                                }
                            }
                        },
                        "similarity": 0.5,
                        "boost": 2.0,
                        "_name": "test2"
                    },
                    {
                        "field": "test3",
                        "query_vector_builder": {
                            "text_embedding": {
                                "model_id": "my-text-embedding-model",
                                "model_text": "The opposite of pink"
                            }
                        }
                    },
                    {
                        "field": "test4",
                        "query_vector_builder": {
                            "text_embedding": {
                                "model_id": "my-text-embedding-model",
                                "model_text": "The opposite of blue"
                            }
                        },
                        "k": 5,
                        "num_candidates": 200,
                        "filter": {
                            "term": {
                                "field": {
                                    "value": "value"
                                }
                            }
                        },
                        "similarity": 0.7,
                        "boost": 2.1,
                        "_name": "test4"
                    }
                ]
            }),
        );
    }
}