1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use crate::search::*;
use crate::util::*;
use serde::Serialize;

/// Finds the _k_ nearest vectors to a query vector, as measured by a similarity metric. _knn_ query finds nearest
/// vectors through approximate search on indexed dense_vectors. The preferred way to do approximate kNN search is
/// through the
/// [top level knn section](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) of a
/// search request. _knn_ query is reserved for expert cases, where there is a need to combine this query with other queries.
///
/// > `knn` query doesn’t have a separate `k` parameter. `k` is defined by `size` parameter of a search request
/// > similar to other queries. `knn` query collects `num_candidates` results from each shard, then merges them to get
/// > the top `size` results.
///
/// To create a knn query:
/// ```
/// # use elasticsearch_dsl::queries::*;
/// # let query =
/// Query::knn("test", vec![1.0, 2.0, 3.0]);
/// ```
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html>
#[derive(Debug, Clone, PartialEq, Serialize)]
// NOTE(review): with `remote = "Self"` the derive emits an inherent serializer rather than a `Serialize` impl;
// presumably `serialize_with_root!` provides the trait impl that adds the `knn` wrapper — confirm against the macro.
#[serde(remote = "Self")]
pub struct KnnQuery {
    /// Name of the vector field to search against.
    field: String,

    /// Query vector that the indexed vectors are compared with.
    query_vector: Vec<f32>,

    /// Number of nearest neighbor candidates to consider per shard; left out of the output via `ShouldSkip`.
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    num_candidates: Option<u32>,

    /// Optional query restricting which documents can match; boxed, presumably because `Query` can itself
    /// contain a `KnnQuery` — confirm against the `Query` definition.
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Box<Query>>,

    /// Optional minimum similarity a document needs in order to be considered a match.
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    similarity: Option<f32>,

    /// Optional boost, serialized under the `boost` key.
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    boost: Option<f32>,

    /// Optional query name, serialized under the `_name` key.
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    _name: Option<String>,
}

impl KnnQuery {
    /// Sets how many nearest neighbor candidates are considered on each shard. The value cannot exceed 10,000.
    ///
    /// Elasticsearch gathers `num_candidates` results per shard and merges them to find the top results, so a
    /// larger value tends to improve the accuracy of the final results. Defaults to
    /// `Math.min(1.5 * size, 10_000)`.
    pub fn num_candidates(self, num_candidates: u32) -> Self {
        Self {
            num_candidates: Some(num_candidates),
            ..self
        }
    }

    /// Restricts the documents that can match. The kNN search returns the top documents that also match this
    /// filter. The value can be a single query or a list of queries. Without a `filter`, all documents are
    /// allowed to match.
    ///
    /// This is a pre-filter: it is applied **during** the approximate kNN search to ensure that
    /// `num_candidates` matching documents are returned.
    pub fn filter<T>(self, filter: T) -> Self
    where
        T: Into<Query>,
    {
        let query: Query = filter.into();

        Self {
            filter: Some(Box::new(query)),
            ..self
        }
    }

    /// Sets the minimum similarity required for a document to be considered a match. The value relates to the
    /// raw similarity used, not to the document score. Matched documents are then scored according to
    /// similarity, and the provided boost is applied.
    pub fn similarity(self, similarity: f32) -> Self {
        Self {
            similarity: Some(similarity),
            ..self
        }
    }

    add_boost_and_name!();
}

// Empty impl: `KnnQuery` relies entirely on the methods `ShouldSkip` provides by default.
// NOTE(review): presumably those defaults mean a knn query is never skipped — confirm against the trait.
impl ShouldSkip for KnnQuery {}

// NOTE(review): presumably generates the `Serialize` impl that nests the fields under a top-level `"knn"` key,
// which is the shape the serialization test in this file expects — confirm against the macro definition.
serialize_with_root!("knn": KnnQuery);

impl Query {
    /// Builds a [`KnnQuery`] with only its two required parts set; every optional setting starts out as `None`.
    ///
    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing
    ///   enabled.
    /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are
    ///   searching against.
    pub fn knn<T>(field: T, query_vector: Vec<f32>) -> KnnQuery
    where
        T: ToString,
    {
        // Convert the caller's field name into the owned string stored on the query.
        let field = field.to_string();

        KnnQuery {
            field,
            query_vector,
            // Optional settings are filled in later through the builder methods.
            num_candidates: None,
            filter: None,
            similarity: None,
            boost: None,
            _name: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the JSON for a bare knn query and for one with every optional setting applied.
    #[test]
    fn serialization() {
        // Required arguments only: the expected JSON carries just `field` and `query_vector`.
        let minimal = Query::knn("test", vec![1.0, 2.0, 3.0]);
        let minimal_json = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0]
            }
        });
        assert_serialize_query(minimal, minimal_json);

        // Every builder method applied: each optional setting shows up under its own key.
        let full = Query::knn("test", vec![1.0, 2.0, 3.0])
            .num_candidates(100)
            .filter(Query::term("field", "value"))
            .similarity(0.5)
            .boost(2.0)
            .name("test");
        let full_json = json!({
            "knn": {
                "field": "test",
                "query_vector": [1.0, 2.0, 3.0],
                "num_candidates": 100,
                "filter": {
                    "term": {
                        "field": {
                            "value": "value"
                        }
                    }
                },
                "similarity": 0.5,
                "boost": 2.0,
                "_name": "test"
            }
        });
        assert_serialize_query(full, full_json);
    }
}