Skip to main content

luci/vector/
query.rs

1//! KNN query: approximate nearest neighbor search via the index-wide
2//! global HNSW per [[global-vector-indices]] Alternative B.
3//!
4//! The global graph is consulted once at bind time, producing a single
5//! `Vec<(SegmentId, DocId, distance)>`. Hits are bucketed by segment
6//! and exposed to the per-segment scorer dispatch the rest of the
7//! engine uses. Per-segment kNN fan-out and proportional-k oversampling
8//! are gone — there is one graph, one search, one ranking.
9
10use std::collections::HashMap;
11
12use crate::core::{
13    DocId, LuciError, NO_MORE_DOCS, Result, ScoreMode, Scorer, SegmentId, TwoPhaseIterator,
14};
15
16use crate::query::{BoundQuery, Query, ScorerSupplier};
17use crate::search::searcher::Searcher;
18use crate::segment::reader::SegmentReader;
19use crate::vector::DistanceMetric;
20
/// Approximate k-nearest-neighbor query on a dense_vector field.
///
/// See [[feature-knn-query-type]].
///
/// Derives `Debug`, `Clone` and `PartialEq` so queries can be logged,
/// duplicated and compared in tests. `Eq` is intentionally not derived
/// because the vector and threshold are `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct KnnQuery {
    /// Name of the dense_vector field to search.
    pub field: String,
    /// Query vector; its length must equal the field's mapped dimension
    /// count, otherwise binding fails with an invalid-query error.
    pub query_vector: Vec<f32>,
    /// Number of nearest neighbors requested from the graph search.
    pub k: usize,
    /// Candidate count handed to the graph search alongside `k`.
    pub num_candidates: usize,
    /// Minimum score threshold. Results with score < threshold are excluded.
    /// Score is metric-specific (see [`crate::vector::distance_to_score`]).
    pub threshold: Option<f32>,
}
33
34impl Query for KnnQuery {
35    fn bind(&self, searcher: &Searcher, _score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
36        // Resolve the field against the mapping. A knn query names a
37        // dense_vector field; if there is no mapping, the field is
38        // unknown, or the field is not a dense_vector, the query cannot
39        // run — return an explicit error rather than silently producing
40        // zero hits ([[code-must-not-lie]]). A *valid* dense_vector field
41        // with no vectors indexed yet is the only honest empty case, and
42        // is handled by the no-graph branches below.
43        let Some(mapping) = searcher.mapping() else {
44            return Err(LuciError::InvalidQuery(format!(
45                "knn query targets field '{}', but this index has no mapping",
46                self.field
47            )));
48        };
49        let Some(field_id) = mapping.field_id(&self.field) else {
50            return Err(LuciError::InvalidQuery(format!(
51                "knn query targets unknown field '{}'",
52                self.field
53            )));
54        };
55        let Some(expected_dims) = mapping.field(field_id).field_type.vector_dims() else {
56            return Err(LuciError::InvalidQuery(format!(
57                "knn query targets field '{}', which is not a dense_vector field",
58                self.field
59            )));
60        };
61        if self.query_vector.len() != expected_dims {
62            return Err(LuciError::InvalidQuery(format!(
63                "knn query_vector has {} dimensions, field '{}' expects {}",
64                self.query_vector.len(),
65                self.field,
66                expected_dims
67            )));
68        }
69
70        // Pull the index-wide global graph from the search-side
71        // snapshot. If the field has no vectors yet (or no graph is
72        // attached), short-circuit to an empty bound query — the rest
73        // of the search pipeline handles zero-hit branches uniformly.
74        let Some(global) = searcher.global_hnsw() else {
75            return Ok(Box::new(BoundKnnQuery {
76                results_by_segment: HashMap::new(),
77                metric: DistanceMetric::Cosine,
78            }));
79        };
80
81        let (hits, metric) =
82            match global.search(field_id, &self.query_vector, self.k, self.num_candidates)? {
83                Some(out) => out,
84                None => {
85                    return Ok(Box::new(BoundKnnQuery {
86                        results_by_segment: HashMap::new(),
87                        metric: DistanceMetric::Cosine,
88                    }));
89                }
90            };
91
92        // Apply the threshold against the metric-specific score
93        // mapping. Threshold semantics match the per-segment path
94        // (kept after this refactor).
95        let mut filtered: Vec<_> = hits
96            .into_iter()
97            .filter(|hit| match self.threshold {
98                Some(min_score) => {
99                    crate::vector::distance_to_score(hit.distance, metric) >= min_score
100                }
101                None => true,
102            })
103            .collect();
104
105        // Bucket by segment. Each bucket holds `(local_doc_id, distance)`
106        // and is sorted by ascending `local_doc_id` so the per-segment
107        // scorer can iterate in advance() order.
108        let mut results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>> = HashMap::new();
109        filtered.sort_by(|a, b| {
110            a.distance
111                .partial_cmp(&b.distance)
112                .unwrap_or(std::cmp::Ordering::Equal)
113        });
114        for hit in filtered {
115            results_by_segment
116                .entry(hit.segment_id)
117                .or_default()
118                .push((hit.doc_id.as_u32(), hit.distance));
119        }
120        for bucket in results_by_segment.values_mut() {
121            bucket.sort_by_key(|(doc_id, _)| *doc_id);
122        }
123
124        Ok(Box::new(BoundKnnQuery {
125            results_by_segment,
126            metric,
127        }))
128    }
129}
130
131struct BoundKnnQuery {
132    /// Pre-bucketed hits, keyed by the owning segment. Each bucket is
133    /// sorted by ascending local doc_id so the scorer can advance() in
134    /// order against bool conjunctions.
135    results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>>,
136    metric: DistanceMetric,
137}
138
139impl BoundQuery for BoundKnnQuery {
140    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
141        let Some(bucket) = self.results_by_segment.get(&reader.segment_id()) else {
142            return Ok(None);
143        };
144        if bucket.is_empty() {
145            return Ok(None);
146        }
147        Ok(Some(Box::new(KnnScorerSupplier {
148            results: bucket.clone(),
149            metric: self.metric,
150        })))
151    }
152}
153
154struct KnnScorerSupplier {
155    results: Vec<(u32, f32)>,
156    metric: DistanceMetric,
157}
158
159impl ScorerSupplier for KnnScorerSupplier {
160    fn cost(&self) -> u64 {
161        self.results.len() as u64
162    }
163
164    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
165        Ok(Box::new(KnnScorer {
166            results: self.results,
167            metric: self.metric,
168            pos: 0,
169        }))
170    }
171}
172
173/// Scorer that iterates over pre-materialized kNN results.
174///
175/// Results are sorted by doc_id ascending (re-sorted from HNSW's distance
176/// order). Scores are computed from distances using metric-specific formulas.
177struct KnnScorer {
178    results: Vec<(u32, f32)>, // (doc_id, distance) sorted by doc_id
179    metric: DistanceMetric,
180    pos: usize,
181}
182
183impl Scorer for KnnScorer {
184    fn doc_id(&self) -> DocId {
185        if self.pos < self.results.len() {
186            DocId::new(self.results[self.pos].0)
187        } else {
188            NO_MORE_DOCS
189        }
190    }
191
192    fn next(&mut self) -> DocId {
193        if self.pos < self.results.len() {
194            self.pos += 1;
195        }
196        self.doc_id()
197    }
198
199    fn advance(&mut self, target: DocId) -> DocId {
200        while self.pos < self.results.len() && self.results[self.pos].0 < target.as_u32() {
201            self.pos += 1;
202        }
203        self.doc_id()
204    }
205
206    fn score(&mut self) -> f32 {
207        if self.pos < self.results.len() {
208            crate::vector::distance_to_score(self.results[self.pos].1, self.metric)
209        } else {
210            0.0
211        }
212    }
213
214    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
215        None
216    }
217}