// elasticsearch_dsl/search/queries/specialized/knn_query.rs

1use crate::search::*;
2use crate::util::*;
3use serde::Serialize;
4
5/// Finds the _k_ nearest vectors to a query vector, as measured by a similarity metric. _knn_ query finds nearest
6/// vectors through approximate search on indexed dense_vectors. The preferred way to do approximate kNN search is
7/// through the
8/// [top level knn section](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) of a
9/// search request. _knn_ query is reserved for expert cases, where there is a need to combine this query with other queries.
10///
11/// > `knn` query doesn’t have a separate `k` parameter. `k` is defined by `size` parameter of a search request
12/// > similar to other queries. `knn` query collects `num_candidates` results from each shard, then merges them to get
13/// > the top `size` results.
14///
15/// To create a knn query:
16/// ```
17/// # use elasticsearch_dsl::queries::*;
18/// # let query =
19/// Query::knn("test", vec![1.0, 2.0, 3.0]);
20/// ```
21/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html>
22#[derive(Debug, Clone, PartialEq, Serialize)]
23#[serde(remote = "Self")]
24pub struct KnnQuery {
25    field: String,
26
27    query_vector: Vec<f32>,
28
29    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
30    num_candidates: Option<u32>,
31
32    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
33    filter: Option<Box<Query>>,
34
35    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
36    similarity: Option<f32>,
37
38    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
39    boost: Option<f32>,
40
41    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
42    _name: Option<String>,
43}
44
45impl KnnQuery {
46    /// The number of nearest neighbor candidates to consider per shard. Cannot exceed 10,000. Elasticsearch collects
47    /// `num_candidates` results from each shard, then merges them to find the top results. Increasing `num_candidates`
48    /// tends to improve the accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
49    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
50        self.num_candidates = Some(num_candidates);
51        self
52    }
53
54    /// Query to filter the documents that can match. The kNN search will return the top documents that also match
55    /// this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents
56    /// are allowed to match.
57    ///
58    /// The filter is a pre-filter, meaning that it is applied **during** the approximate kNN search to ensure that
59    /// `num_candidates` matching documents are returned.
60    pub fn filter<T>(mut self, filter: T) -> Self
61    where
62        T: Into<Query>,
63    {
64        self.filter = Some(Box::new(filter.into()));
65        self
66    }
67
68    ///  The minimum similarity required for a document to be considered a match. The similarity value calculated
69    /// relates to the raw similarity used. Not the document score. The matched documents are then scored according
70    /// to similarity and the provided boost is applied.
71    pub fn similarity(mut self, similarity: f32) -> Self {
72        self.similarity = Some(similarity);
73        self
74    }
75
76    add_boost_and_name!();
77}
78
79impl ShouldSkip for KnnQuery {}
80
81serialize_with_root!("knn": KnnQuery);
82
83impl Query {
84    /// Creates an instance of [`KnnQuery`]
85    ///
86    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
87    /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are searching
88    ///   against.
89    pub fn knn<T>(field: T, query_vector: Vec<f32>) -> KnnQuery
90    where
91        T: ToString,
92    {
93        KnnQuery {
94            field: field.to_string(),
95            query_vector,
96            num_candidates: None,
97            filter: None,
98            similarity: None,
99            boost: None,
100            _name: None,
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization() {
        // Only the required parameters: optional ones must be omitted from the JSON.
        assert_serialize_query(
            Query::knn("test", vec![1.0, 2.0, 3.0]),
            json!({
                "knn": {
                    "field": "test",
                    "query_vector": [1.0, 2.0, 3.0]
                }
            }),
        );

        // Every optional parameter set.
        assert_serialize_query(
            Query::knn("test", vec![1.0, 2.0, 3.0])
                .num_candidates(100)
                .filter(Query::term("field", "value"))
                .similarity(0.5)
                .boost(2.0)
                .name("test"),
            json!({
                "knn": {
                    "field": "test",
                    "query_vector": [1.0, 2.0, 3.0],
                    "num_candidates": 100,
                    "filter": {
                        "term": {
                            "field": {
                                "value": "value"
                            }
                        }
                    },
                    "similarity": 0.5,
                    "boost": 2.0,
                    "_name": "test"
                }
            }),
        );
    }

    #[test]
    fn filter_is_replaced_by_subsequent_calls() {
        // `filter` holds a single query, so the last call must win rather than accumulate.
        assert_serialize_query(
            Query::knn("test", vec![1.0, 2.0, 3.0])
                .filter(Query::term("field", "first"))
                .filter(Query::term("field", "second")),
            json!({
                "knn": {
                    "field": "test",
                    "query_vector": [1.0, 2.0, 3.0],
                    "filter": {
                        "term": {
                            "field": {
                                "value": "second"
                            }
                        }
                    }
                }
            }),
        );
    }
}