Skip to main content

nodedb_vector/codec_index/
search.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Search algorithm for `HnswCodecIndex<C>`.
4//!
5//! Two-stage HNSW search:
6//! - Phase 1 (layers max..1): greedy descent using `fast_symmetric_distance`.
7//! - Phase 2 (layer 0): ef-wide beam search using `fast_symmetric_distance`
8//!   for navigation, then a final rerank pass with `exact_asymmetric_distance`.
9
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, HashSet};
12
13use nodedb_codec::vector_quant::codec::VectorCodec;
14
15use super::graph::HnswCodecIndex;
16
/// One hit produced by `HnswCodecIndex::search`.
///
/// `search` returns these ordered by ascending `distance`.
#[derive(Clone, Debug)]
pub struct CodecSearchResult {
    /// The external id that was supplied to `HnswCodecIndex::insert` for this vector.
    pub id: u32,
    /// Query-to-vector distance as reported by the codec's `exact_asymmetric_distance`.
    pub distance: f32,
}
25
/// Internal candidate during beam search: a graph node paired with its routing
/// distance to the query (smaller is closer).
///
/// Ordering is by `dist` using `f32::total_cmp`, with `idx` as a tie-breaker,
/// so the order is total even if a codec ever yields NaN. Equality is defined
/// through the same ordering, keeping `PartialEq`, `Eq`, `PartialOrd` and `Ord`
/// mutually consistent as `BinaryHeap` requires.
#[derive(Debug, Clone, Copy)]
struct Cand {
    /// Routing distance from the query to this node.
    dist: f32,
    /// Dense index into `HnswCodecIndex::nodes`.
    idx: u32,
}

impl PartialEq for Cand {
    fn eq(&self, other: &Self) -> bool {
        // Derived `==` on f32 would disagree with `cmp` (NaN != NaN, -0.0 == 0.0).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Cand {}

impl PartialOrd for Cand {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Cand {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to every value,
        // which breaks transitivity and can corrupt heap order. Under
        // `total_cmp`, positive NaN sorts after +inf, so a NaN candidate is
        // treated as the farthest and is the first to be evicted.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
50
51impl<C: VectorCodec> HnswCodecIndex<C> {
52    /// K-NN search returning up to `k` results.
53    ///
54    /// `ef_search` controls the beam width at layer 0 (must be >= k).
55    ///
56    /// The returned results are sorted ascending by `exact_asymmetric_distance`.
57    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<CodecSearchResult> {
58        if self.is_empty() {
59            return Vec::new();
60        }
61
62        let Some(ep) = self.entry_point else {
63            return Vec::new();
64        };
65
66        let ef = ef_search.max(k);
67
68        // Precompute the query forms used by the two phases.
69        let q_encoded = self.codec.encode(query);
70        let q_prepared = self.codec.prepare_query(query);
71
72        // Phase 1: greedy descent through layers max_layer..1.
73        let mut cur_ep = ep;
74        for layer in (1..=self.max_layer).rev() {
75            cur_ep = self.greedy_nearest_search(&q_encoded, cur_ep, layer);
76        }
77
78        // Phase 2: ef-wide beam search at layer 0.
79        let candidates = self.search_layer_0(&q_encoded, cur_ep, ef);
80
81        // Rerank top ef_search candidates with exact asymmetric distance.
82        let mut reranked: Vec<(f32, u32)> = candidates
83            .into_iter()
84            .take(ef)
85            .map(|c| {
86                let asym = self
87                    .codec
88                    .exact_asymmetric_distance(&q_prepared, &self.nodes[c.idx as usize].quantized);
89                (asym, self.nodes[c.idx as usize].id)
90            })
91            .collect();
92
93        reranked.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
94        reranked.truncate(k);
95
96        reranked
97            .into_iter()
98            .map(|(distance, id)| CodecSearchResult { id, distance })
99            .collect()
100    }
101
102    /// Greedy single-nearest descent at `layer` using the pre-encoded query.
103    fn greedy_nearest_search(&self, q_enc: &C::Quantized, ep_idx: u32, layer: usize) -> u32 {
104        let mut best_idx = ep_idx;
105        let mut best_dist = self
106            .codec
107            .fast_symmetric_distance(q_enc, &self.nodes[ep_idx as usize].quantized);
108
109        loop {
110            let mut improved = false;
111            for &nb in self.neighbors_at(best_idx, layer) {
112                if self.nodes[nb as usize].deleted {
113                    continue;
114                }
115                let d = self
116                    .codec
117                    .fast_symmetric_distance(q_enc, &self.nodes[nb as usize].quantized);
118                if d < best_dist {
119                    best_dist = d;
120                    best_idx = nb;
121                    improved = true;
122                }
123            }
124            if !improved {
125                break;
126            }
127        }
128
129        best_idx
130    }
131
132    /// Beam search at layer 0 using `fast_symmetric_distance`.
133    ///
134    /// Returns candidates sorted ascending by symmetric distance (used only
135    /// for routing; final ranking is done by `exact_asymmetric_distance` in
136    /// the caller).
137    fn search_layer_0(&self, q_enc: &C::Quantized, ep_idx: u32, ef: usize) -> Vec<Cand> {
138        let mut visited: HashSet<u32> = HashSet::new();
139        visited.insert(ep_idx);
140
141        let ep_dist = self
142            .codec
143            .fast_symmetric_distance(q_enc, &self.nodes[ep_idx as usize].quantized);
144        let ep_cand = Cand {
145            dist: ep_dist,
146            idx: ep_idx,
147        };
148
149        let mut candidates: BinaryHeap<Reverse<Cand>> = BinaryHeap::new();
150        candidates.push(Reverse(ep_cand));
151
152        let mut results: BinaryHeap<Cand> = BinaryHeap::new();
153        if !self.nodes[ep_idx as usize].deleted {
154            results.push(ep_cand);
155        }
156
157        while let Some(Reverse(cur)) = candidates.pop() {
158            let worst = results.peek().map_or(f32::INFINITY, |w| w.dist);
159            if cur.dist > worst && results.len() >= ef {
160                break;
161            }
162
163            for &nb in self.neighbors_at(cur.idx, 0) {
164                if !visited.insert(nb) {
165                    continue;
166                }
167                let d = self
168                    .codec
169                    .fast_symmetric_distance(q_enc, &self.nodes[nb as usize].quantized);
170                let worst_now = results.peek().map_or(f32::INFINITY, |w| w.dist);
171                if d < worst_now || results.len() < ef {
172                    candidates.push(Reverse(Cand { dist: d, idx: nb }));
173                }
174                if !self.nodes[nb as usize].deleted {
175                    results.push(Cand { dist: d, idx: nb });
176                    if results.len() > ef {
177                        results.pop();
178                    }
179                }
180            }
181        }
182
183        let mut out: Vec<Cand> = results.into_vec();
184        out.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
185        out
186    }
187}
188
189// ── Tests ─────────────────────────────────────────────────────────────────────
190
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use nodedb_codec::vector_quant::{bbq::BbqCodec, codec::VectorCodec, rabitq::RaBitQCodec};

    use crate::{codec_index::HnswCodecIndex, distance::l2_squared, quantize::Sq8Codec};

    // ── helpers ───────────────────────────────────────────────────────────────

    /// One step of a xorshift64 PRNG; deterministic so test data is reproducible.
    fn xorshift(state: &mut u64) -> u64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        *state
    }

    /// Pseudo-random vector of length `dim` with components in [-1, 1].
    fn rand_vec(state: &mut u64, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|_| (xorshift(state) as f32 / u64::MAX as f32) * 2.0 - 1.0)
            .collect()
    }

    /// Brute-force top-k by L2-squared.
    fn ground_truth(vecs: &[Vec<f32>], query: &[f32], k: usize) -> Vec<u32> {
        let mut scored: Vec<(f32, u32)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (l2_squared(query, v), i as u32))
            .collect();
        scored.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
        scored.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// Mean recall@k of `idx` over `n_queries` random queries drawn from `state`.
    ///
    /// Searches with `ef_search = 4 * k` and scores against exact L2-squared
    /// ground truth over `vecs` (vector `i` must have been inserted as id `i`).
    fn mean_recall<C: VectorCodec>(
        idx: &HnswCodecIndex<C>,
        vecs: &[Vec<f32>],
        state: &mut u64,
        dim: usize,
        k: usize,
        n_queries: usize,
    ) -> f64 {
        let mut hits = 0usize;
        for _ in 0..n_queries {
            let query = rand_vec(state, dim);
            let truth: HashSet<u32> = ground_truth(vecs, &query, k).into_iter().collect();
            let found: HashSet<u32> = idx.search(&query, k, k * 4).iter().map(|r| r.id).collect();
            hits += found.intersection(&truth).count();
        }
        hits as f64 / (n_queries * k) as f64
    }

    /// Builds an Sq8 index over `n` random `dim`-dimensional vectors (id = position).
    fn sq8_index(
        n: usize,
        dim: usize,
        seed: u64,
    ) -> (HnswCodecIndex<Sq8Codec>, Vec<Vec<f32>>) {
        let mut state = seed;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, dim);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(dim, 8, 100, codec, 7);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }
        (idx, vecs)
    }

    // ── Sq8 round-trip ────────────────────────────────────────────────────────

    #[test]
    fn sq8_top1_exact_match() {
        let (idx, vecs) = sq8_index(50, 8, 0xDEAD_BEEF_u64);
        // Query with vector 17 — top-1 should return id 17.
        let query = vecs[17].clone();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1, "expected 1 result");
        assert_eq!(
            results[0].id, 17,
            "top-1 should be the queried vector itself"
        );
        assert!(
            results[0].distance < 0.1,
            "distance to self should be near 0, got {}",
            results[0].distance
        );
    }

    // ── RaBitQ recall ─────────────────────────────────────────────────────────

    #[test]
    fn rabitq_recall_at_least_30_percent() {
        let dim = 16;
        let n = 100usize;
        let k = 5usize;
        let mut state = 0xCAFE_BABE_u64;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 0xABCD_1234);
        let mut idx: HnswCodecIndex<RaBitQCodec> = HnswCodecIndex::new(dim, 8, 150, codec, 99);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }

        let recall = mean_recall(&idx, &vecs, &mut state, dim, k, 10);
        // 1-bit codecs at low dim (D=16) are inherently approximate; the
        // O(1/√D) bound only bites at higher dimensions. This is a sanity
        // test: at D=16 / n=100 / k=5, RaBitQ typically lands ≥ 30%.
        assert!(
            recall >= 0.30,
            "RaBitQ recall@{k} = {recall:.2}, expected >= 0.30"
        );
    }

    // ── BBQ recall ────────────────────────────────────────────────────────────

    #[test]
    fn bbq_recall_at_least_40_percent() {
        let dim = 16;
        let n = 100usize;
        let k = 5usize;
        let mut state = 0x1234_5678_u64;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let mut idx: HnswCodecIndex<BbqCodec> = HnswCodecIndex::new(dim, 8, 150, codec, 42);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }

        let recall = mean_recall(&idx, &vecs, &mut state, dim, k, 10);
        // BBQ at D=16 is approximate (corrective factors help vs raw binary
        // but the 1-bit code itself is information-bounded). Sanity test
        // threshold; production uses BBQ + oversample ×3 rerank pass.
        assert!(
            recall >= 0.40,
            "BBQ recall@{k} = {recall:.2}, expected >= 0.40"
        );
    }

    // ── Edge cases ────────────────────────────────────────────────────────────

    #[test]
    fn empty_index_returns_empty() {
        let codec = {
            let vecs: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32; 4]).collect();
            let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
            Sq8Codec::calibrate(&refs, 4)
        };
        let idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 1);
        let results = idx.search(&[0.0, 0.0, 0.0, 0.0], 5, 20);
        assert!(results.is_empty(), "empty index must return no results");
    }

    #[test]
    fn single_vector_index_always_returns_it() {
        let dim = 4;
        let vecs = [vec![1.0f32, 2.0, 3.0, 4.0]];
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = Sq8Codec::calibrate(&refs, dim);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(dim, 8, 50, codec, 5);
        idx.insert(0, &vecs[0]);
        // Query with a completely different vector.
        let results = idx.search(&[10.0, 20.0, 30.0, 40.0], 1, 10);
        assert_eq!(results.len(), 1, "single-node index must return 1 result");
        assert_eq!(results[0].id, 0, "the only node must be returned");
    }

    #[test]
    fn zero_k_returns_empty() {
        let (idx, vecs) = sq8_index(20, 8, 0x0BAD_F00D_u64);
        let results = idx.search(&vecs[3], 0, 10);
        assert!(results.is_empty(), "k = 0 must return no results");
    }

    #[test]
    fn ef_below_k_is_raised_to_k_and_results_are_sorted() {
        let (idx, vecs) = sq8_index(50, 8, 0xFEED_FACE_u64);
        // ef_search < k is clamped to k, so all k results must still come back.
        let results = idx.search(&vecs[0], 5, 1);
        assert_eq!(results.len(), 5, "effective beam width must be max(ef_search, k)");
        assert!(
            results.windows(2).all(|w| w[0].distance <= w[1].distance),
            "results must be sorted ascending by distance"
        );
    }
}