1use std::collections::HashMap;
11
12use crate::core::{
13 DocId, LuciError, NO_MORE_DOCS, Result, ScoreMode, Scorer, SegmentId, TwoPhaseIterator,
14};
15
16use crate::query::{BoundQuery, Query, ScorerSupplier};
17use crate::search::searcher::Searcher;
18use crate::segment::reader::SegmentReader;
19use crate::vector::DistanceMetric;
20
/// A k-nearest-neighbour query over a vector field.
pub struct KnnQuery {
    /// Name of the field to search. Binding fails unless the index mapping
    /// knows this field and reports vector dimensions for it.
    pub field: String,
    /// The query vector. Its length must equal the dimension count the
    /// mapping reports for `field`, otherwise binding fails.
    pub query_vector: Vec<f32>,
    /// Forwarded unchanged to the global HNSW search at bind time.
    /// NOTE(review): presumably the number of neighbours to return — confirm
    /// against the HNSW search implementation.
    pub k: usize,
    /// Forwarded unchanged to the global HNSW search at bind time.
    /// NOTE(review): presumably the size of the candidate list explored
    /// during the search — confirm against the HNSW search implementation.
    pub num_candidates: usize,
    /// Optional minimum score. When set, hits whose distance converts to a
    /// score below this value are discarded at bind time.
    pub threshold: Option<f32>,
}
33
34impl Query for KnnQuery {
35 fn bind(&self, searcher: &Searcher, _score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
36 let Some(mapping) = searcher.mapping() else {
44 return Err(LuciError::InvalidQuery(format!(
45 "knn query targets field '{}', but this index has no mapping",
46 self.field
47 )));
48 };
49 let Some(field_id) = mapping.field_id(&self.field) else {
50 return Err(LuciError::InvalidQuery(format!(
51 "knn query targets unknown field '{}'",
52 self.field
53 )));
54 };
55 let Some(expected_dims) = mapping.field(field_id).field_type.vector_dims() else {
56 return Err(LuciError::InvalidQuery(format!(
57 "knn query targets field '{}', which is not a dense_vector field",
58 self.field
59 )));
60 };
61 if self.query_vector.len() != expected_dims {
62 return Err(LuciError::InvalidQuery(format!(
63 "knn query_vector has {} dimensions, field '{}' expects {}",
64 self.query_vector.len(),
65 self.field,
66 expected_dims
67 )));
68 }
69
70 let Some(global) = searcher.global_hnsw() else {
75 return Ok(Box::new(BoundKnnQuery {
76 results_by_segment: HashMap::new(),
77 metric: DistanceMetric::Cosine,
78 }));
79 };
80
81 let (hits, metric) =
82 match global.search(field_id, &self.query_vector, self.k, self.num_candidates)? {
83 Some(out) => out,
84 None => {
85 return Ok(Box::new(BoundKnnQuery {
86 results_by_segment: HashMap::new(),
87 metric: DistanceMetric::Cosine,
88 }));
89 }
90 };
91
92 let mut filtered: Vec<_> = hits
96 .into_iter()
97 .filter(|hit| match self.threshold {
98 Some(min_score) => {
99 crate::vector::distance_to_score(hit.distance, metric) >= min_score
100 }
101 None => true,
102 })
103 .collect();
104
105 let mut results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>> = HashMap::new();
109 filtered.sort_by(|a, b| {
110 a.distance
111 .partial_cmp(&b.distance)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 for hit in filtered {
115 results_by_segment
116 .entry(hit.segment_id)
117 .or_default()
118 .push((hit.doc_id.as_u32(), hit.distance));
119 }
120 for bucket in results_by_segment.values_mut() {
121 bucket.sort_by_key(|(doc_id, _)| *doc_id);
122 }
123
124 Ok(Box::new(BoundKnnQuery {
125 results_by_segment,
126 metric,
127 }))
128 }
129}
130
131struct BoundKnnQuery {
132 results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>>,
136 metric: DistanceMetric,
137}
138
139impl BoundQuery for BoundKnnQuery {
140 fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
141 let Some(bucket) = self.results_by_segment.get(&reader.segment_id()) else {
142 return Ok(None);
143 };
144 if bucket.is_empty() {
145 return Ok(None);
146 }
147 Ok(Some(Box::new(KnnScorerSupplier {
148 results: bucket.clone(),
149 metric: self.metric,
150 })))
151 }
152}
153
154struct KnnScorerSupplier {
155 results: Vec<(u32, f32)>,
156 metric: DistanceMetric,
157}
158
159impl ScorerSupplier for KnnScorerSupplier {
160 fn cost(&self) -> u64 {
161 self.results.len() as u64
162 }
163
164 fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
165 Ok(Box::new(KnnScorer {
166 results: self.results,
167 metric: self.metric,
168 pos: 0,
169 }))
170 }
171}
172
173struct KnnScorer {
178 results: Vec<(u32, f32)>, metric: DistanceMetric,
180 pos: usize,
181}
182
183impl Scorer for KnnScorer {
184 fn doc_id(&self) -> DocId {
185 if self.pos < self.results.len() {
186 DocId::new(self.results[self.pos].0)
187 } else {
188 NO_MORE_DOCS
189 }
190 }
191
192 fn next(&mut self) -> DocId {
193 if self.pos < self.results.len() {
194 self.pos += 1;
195 }
196 self.doc_id()
197 }
198
199 fn advance(&mut self, target: DocId) -> DocId {
200 while self.pos < self.results.len() && self.results[self.pos].0 < target.as_u32() {
201 self.pos += 1;
202 }
203 self.doc_id()
204 }
205
206 fn score(&mut self) -> f32 {
207 if self.pos < self.results.len() {
208 crate::vector::distance_to_score(self.results[self.pos].1, self.metric)
209 } else {
210 0.0
211 }
212 }
213
214 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
215 None
216 }
217}