elasticsearch_dsl/search/queries/specialized/knn_query.rs
1use crate::search::*;
2use crate::util::*;
3use serde::Serialize;
4
5/// Finds the _k_ nearest vectors to a query vector, as measured by a similarity metric. _knn_ query finds nearest
6/// vectors through approximate search on indexed dense_vectors. The preferred way to do approximate kNN search is
7/// through the
8/// [top level knn section](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) of a
9/// search request. _knn_ query is reserved for expert cases, where there is a need to combine this query with other queries.
10///
11/// > `knn` query doesn’t have a separate `k` parameter. `k` is defined by `size` parameter of a search request
12/// > similar to other queries. `knn` query collects `num_candidates` results from each shard, then merges them to get
13/// > the top `size` results.
14///
15/// To create a knn query:
16/// ```
17/// # use elasticsearch_dsl::queries::*;
18/// # let query =
19/// Query::knn("test", vec![1.0, 2.0, 3.0]);
20/// ```
21/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html>
22#[derive(Debug, Clone, PartialEq, Serialize)]
23#[serde(remote = "Self")]
24pub struct KnnQuery {
25 field: String,
26
27 query_vector: Vec<f32>,
28
29 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
30 num_candidates: Option<u32>,
31
32 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
33 filter: Option<Box<Query>>,
34
35 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
36 similarity: Option<f32>,
37
38 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
39 boost: Option<f32>,
40
41 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
42 _name: Option<String>,
43}
44
45impl KnnQuery {
46 /// The number of nearest neighbor candidates to consider per shard. Cannot exceed 10,000. Elasticsearch collects
47 /// `num_candidates` results from each shard, then merges them to find the top results. Increasing `num_candidates`
48 /// tends to improve the accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
49 pub fn num_candidates(mut self, num_candidates: u32) -> Self {
50 self.num_candidates = Some(num_candidates);
51 self
52 }
53
54 /// Query to filter the documents that can match. The kNN search will return the top documents that also match
55 /// this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents
56 /// are allowed to match.
57 ///
58 /// The filter is a pre-filter, meaning that it is applied **during** the approximate kNN search to ensure that
59 /// `num_candidates` matching documents are returned.
60 pub fn filter<T>(mut self, filter: T) -> Self
61 where
62 T: Into<Query>,
63 {
64 self.filter = Some(Box::new(filter.into()));
65 self
66 }
67
68 /// The minimum similarity required for a document to be considered a match. The similarity value calculated
69 /// relates to the raw similarity used. Not the document score. The matched documents are then scored according
70 /// to similarity and the provided boost is applied.
71 pub fn similarity(mut self, similarity: f32) -> Self {
72 self.similarity = Some(similarity);
73 self
74 }
75
76 add_boost_and_name!();
77}
78
79impl ShouldSkip for KnnQuery {}
80
81serialize_with_root!("knn": KnnQuery);
82
83impl Query {
84 /// Creates an instance of [`KnnQuery`]
85 ///
86 /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
87 /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are searching
88 /// against.
89 pub fn knn<T>(field: T, query_vector: Vec<f32>) -> KnnQuery
90 where
91 T: ToString,
92 {
93 KnnQuery {
94 field: field.to_string(),
95 query_vector,
96 num_candidates: None,
97 filter: None,
98 similarity: None,
99 boost: None,
100 _name: None,
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the JSON produced by a bare query and by a query with every optional parameter set.
    #[test]
    fn serialization() {
        // Mandatory arguments only: none of the optional keys should appear in the output.
        let minimal = Query::knn("test", vec![1.0, 2.0, 3.0]);
        let minimal_json = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0]
            }
        });
        assert_serialize_query(minimal, minimal_json);

        // Every builder method applied: each optional key should be present.
        let complete = Query::knn("test", vec![1.0, 2.0, 3.0])
            .num_candidates(100)
            .filter(Query::term("field", "value"))
            .similarity(0.5)
            .boost(2.0)
            .name("test");
        let complete_json = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0],
                "num_candidates": 100,
                "filter": {
                    "term": {
                        "field": {
                            "value": "value"
                        }
                    }
                },
                "similarity": 0.5,
                "boost": 2.0,
                "_name": "test"
            }
        });
        assert_serialize_query(complete, complete_json);
    }
}