nodedb_vector/codec_index/
search.rs1use std::cmp::Reverse;
11use std::collections::{BinaryHeap, HashSet};
12
13use nodedb_codec::vector_quant::codec::VectorCodec;
14
15use super::graph::HnswCodecIndex;
16
/// One hit produced by `HnswCodecIndex::search`.
///
/// Hits are returned in ascending `distance` order, so the first element of
/// the result vector is the closest match found.
#[derive(Clone, Debug)]
pub struct CodecSearchResult {
    /// Identifier read from the matched node's `id` field.
    pub id: u32,
    /// Distance between the query and the matched node, as computed by the
    /// codec's `exact_asymmetric_distance` during the rerank step.
    pub distance: f32,
}
25
/// Heap entry used during graph traversal: a node position paired with its
/// approximate distance to the query.
///
/// Candidates are ordered by `dist` first and `idx` second. The comparison is
/// built on [`f32::total_cmp`], so it is a genuine total order even when a
/// distance is NaN (a positive NaN sorts above every finite value and above
/// `+inf`). `BinaryHeap` relies on `Ord` being total; a comparison that treats
/// NaN as "equal to everything" is not transitive and can leave the heap in an
/// inconsistent state.
///
/// Equality is defined through the same comparison so that `PartialEq`, `Eq`
/// and `Ord` always agree (as a consequence `0.0` and `-0.0` are distinct and
/// a NaN candidate is equal to itself).
#[derive(Debug, Clone, Copy)]
struct Cand {
    /// Approximate distance from the query to the node.
    dist: f32,
    /// Position of the node inside the index's `nodes` storage.
    idx: u32,
}

impl PartialEq for Cand {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `cmp` so equality can never disagree with the ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Cand {}

impl PartialOrd for Cand {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Cand {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` gives NaN a fixed place in the order instead of
        // collapsing it to `Equal`; ties on distance are broken by index.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
50
51impl<C: VectorCodec> HnswCodecIndex<C> {
52 pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Vec<CodecSearchResult> {
58 if self.is_empty() {
59 return Vec::new();
60 }
61
62 let Some(ep) = self.entry_point else {
63 return Vec::new();
64 };
65
66 let ef = ef_search.max(k);
67
68 let q_encoded = self.codec.encode(query);
70 let q_prepared = self.codec.prepare_query(query);
71
72 let mut cur_ep = ep;
74 for layer in (1..=self.max_layer).rev() {
75 cur_ep = self.greedy_nearest_search(&q_encoded, cur_ep, layer);
76 }
77
78 let candidates = self.search_layer_0(&q_encoded, cur_ep, ef);
80
81 let mut reranked: Vec<(f32, u32)> = candidates
83 .into_iter()
84 .take(ef)
85 .map(|c| {
86 let asym = self
87 .codec
88 .exact_asymmetric_distance(&q_prepared, &self.nodes[c.idx as usize].quantized);
89 (asym, self.nodes[c.idx as usize].id)
90 })
91 .collect();
92
93 reranked.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
94 reranked.truncate(k);
95
96 reranked
97 .into_iter()
98 .map(|(distance, id)| CodecSearchResult { id, distance })
99 .collect()
100 }
101
102 fn greedy_nearest_search(&self, q_enc: &C::Quantized, ep_idx: u32, layer: usize) -> u32 {
104 let mut best_idx = ep_idx;
105 let mut best_dist = self
106 .codec
107 .fast_symmetric_distance(q_enc, &self.nodes[ep_idx as usize].quantized);
108
109 loop {
110 let mut improved = false;
111 for &nb in self.neighbors_at(best_idx, layer) {
112 if self.nodes[nb as usize].deleted {
113 continue;
114 }
115 let d = self
116 .codec
117 .fast_symmetric_distance(q_enc, &self.nodes[nb as usize].quantized);
118 if d < best_dist {
119 best_dist = d;
120 best_idx = nb;
121 improved = true;
122 }
123 }
124 if !improved {
125 break;
126 }
127 }
128
129 best_idx
130 }
131
132 fn search_layer_0(&self, q_enc: &C::Quantized, ep_idx: u32, ef: usize) -> Vec<Cand> {
138 let mut visited: HashSet<u32> = HashSet::new();
139 visited.insert(ep_idx);
140
141 let ep_dist = self
142 .codec
143 .fast_symmetric_distance(q_enc, &self.nodes[ep_idx as usize].quantized);
144 let ep_cand = Cand {
145 dist: ep_dist,
146 idx: ep_idx,
147 };
148
149 let mut candidates: BinaryHeap<Reverse<Cand>> = BinaryHeap::new();
150 candidates.push(Reverse(ep_cand));
151
152 let mut results: BinaryHeap<Cand> = BinaryHeap::new();
153 if !self.nodes[ep_idx as usize].deleted {
154 results.push(ep_cand);
155 }
156
157 while let Some(Reverse(cur)) = candidates.pop() {
158 let worst = results.peek().map_or(f32::INFINITY, |w| w.dist);
159 if cur.dist > worst && results.len() >= ef {
160 break;
161 }
162
163 for &nb in self.neighbors_at(cur.idx, 0) {
164 if !visited.insert(nb) {
165 continue;
166 }
167 let d = self
168 .codec
169 .fast_symmetric_distance(q_enc, &self.nodes[nb as usize].quantized);
170 let worst_now = results.peek().map_or(f32::INFINITY, |w| w.dist);
171 if d < worst_now || results.len() < ef {
172 candidates.push(Reverse(Cand { dist: d, idx: nb }));
173 }
174 if !self.nodes[nb as usize].deleted {
175 results.push(Cand { dist: d, idx: nb });
176 if results.len() > ef {
177 results.pop();
178 }
179 }
180 }
181 }
182
183 let mut out: Vec<Cand> = results.into_vec();
184 out.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
185 out
186 }
187}
188
#[cfg(test)]
mod tests {
    use nodedb_codec::vector_quant::{bbq::BbqCodec, rabitq::RaBitQCodec};

    use crate::{codec_index::HnswCodecIndex, distance::l2_squared, quantize::Sq8Codec};

    /// Advances `state` by one xorshift-style step (shifts 13, 7, 17) and
    /// returns the new value. Gives the tests a seedable, dependency-free
    /// source of pseudo-random numbers.
    fn xorshift(state: &mut u64) -> u64 {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        x
    }

    /// Draws `dim` pseudo-random components, each mapped into `[-1, 1]`.
    fn rand_vec(state: &mut u64, dim: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(dim);
        for _ in 0..dim {
            let unit = xorshift(state) as f32 / u64::MAX as f32;
            out.push(unit * 2.0 - 1.0);
        }
        out
    }

    /// Brute-force reference: positions of the `k` vectors in `vecs` with the
    /// smallest `l2_squared` distance to `query`, closest first.
    fn ground_truth(vecs: &[Vec<f32>], query: &[f32], k: usize) -> Vec<u32> {
        let mut scored: Vec<(f32, u32)> = Vec::with_capacity(vecs.len());
        for (i, v) in vecs.iter().enumerate() {
            scored.push((l2_squared(query, v), i as u32));
        }
        scored.sort_unstable_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0));
        scored.truncate(k);
        scored.into_iter().map(|(_, id)| id).collect()
    }

    /// Mean recall@`k` over `n_queries` queries drawn from `state`.
    ///
    /// For each query the ids returned by `search` are intersected with the
    /// brute-force top-`k` from [`ground_truth`]; the result is total hits
    /// divided by `k * n_queries`.
    fn mean_recall(
        vecs: &[Vec<f32>],
        state: &mut u64,
        dim: usize,
        k: usize,
        n_queries: usize,
        search: impl Fn(&[f32]) -> Vec<u32>,
    ) -> f64 {
        let mut hits = 0usize;
        let mut expected = 0usize;
        for _ in 0..n_queries {
            let query = rand_vec(state, dim);
            let truth: std::collections::HashSet<u32> =
                ground_truth(vecs, &query, k).into_iter().collect();
            let found: std::collections::HashSet<u32> = search(&query).into_iter().collect();
            hits += found.intersection(&truth).count();
            expected += k;
        }
        hits as f64 / expected as f64
    }

    /// Querying with a vector that was itself inserted must return that
    /// vector as the top hit with a near-zero distance.
    #[test]
    fn sq8_top1_exact_match() {
        let dim = 8;
        let n = 50usize;
        let mut state = 0xDEAD_BEEF_u64;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, dim);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(dim, 8, 100, codec, 7);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }

        let results = idx.search(&vecs[17], 1, 50);
        assert_eq!(results.len(), 1, "expected 1 result");
        let top = &results[0];
        assert_eq!(
            top.id, 17,
            "top-1 should be the queried vector itself"
        );
        assert!(
            top.distance < 0.1,
            "distance to self should be near 0, got {}",
            top.distance
        );
    }

    /// Recall@5 smoke test for the RaBitQ codec on 100 random 16-d vectors.
    ///
    /// NOTE(review): the function name says "60 percent" but the assertion
    /// only requires recall >= 0.30 — confirm which threshold is intended.
    #[test]
    fn rabitq_recall_at_least_60_percent() {
        let dim = 16;
        let n = 100usize;
        let k = 5usize;
        let mut state = 0xCAFE_BABE_u64;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = RaBitQCodec::calibrate(&refs, dim, 0xABCD_1234);
        let mut idx: HnswCodecIndex<RaBitQCodec> = HnswCodecIndex::new(dim, 8, 150, codec, 99);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }

        let recall = mean_recall(&vecs, &mut state, dim, k, 10, |q| {
            idx.search(q, k, k * 4).into_iter().map(|r| r.id).collect()
        });
        assert!(
            recall >= 0.30,
            "RaBitQ recall@{k} = {recall:.2}, expected >= 0.30"
        );
    }

    /// Recall@5 smoke test for the BBQ codec on 100 random 16-d vectors.
    ///
    /// NOTE(review): the function name says "60 percent" but the assertion
    /// only requires recall >= 0.40 — confirm which threshold is intended.
    #[test]
    fn bbq_recall_at_least_60_percent() {
        let dim = 16;
        let n = 100usize;
        let k = 5usize;
        let mut state = 0x1234_5678_u64;
        let vecs: Vec<Vec<f32>> = (0..n).map(|_| rand_vec(&mut state, dim)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = BbqCodec::calibrate(&refs, dim, 3);
        let mut idx: HnswCodecIndex<BbqCodec> = HnswCodecIndex::new(dim, 8, 150, codec, 42);
        for (i, v) in vecs.iter().enumerate() {
            idx.insert(i as u32, v);
        }

        let recall = mean_recall(&vecs, &mut state, dim, k, 10, |q| {
            idx.search(q, k, k * 4).into_iter().map(|r| r.id).collect()
        });
        assert!(
            recall >= 0.40,
            "BBQ recall@{k} = {recall:.2}, expected >= 0.40"
        );
    }

    /// Searching an index with no inserted vectors yields no results.
    #[test]
    fn empty_index_returns_empty() {
        // The calibration vectors are only used to build the codec; nothing
        // is inserted into the index.
        let calibration: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32; 4]).collect();
        let refs: Vec<&[f32]> = calibration.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, 4);

        let idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(4, 8, 50, codec, 1);
        let results = idx.search(&[0.0, 0.0, 0.0, 0.0], 5, 20);
        assert!(results.is_empty(), "empty index must return no results");
    }

    /// With exactly one vector stored, any query must return that vector,
    /// however far away it is.
    #[test]
    fn single_vector_index_always_returns_it() {
        let dim = 4;
        let vecs = [vec![1.0f32, 2.0, 3.0, 4.0]];
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&refs, dim);
        let mut idx: HnswCodecIndex<Sq8Codec> = HnswCodecIndex::new(dim, 8, 50, codec, 5);
        idx.insert(0, &vecs[0]);

        let results = idx.search(&[10.0, 20.0, 30.0, 40.0], 1, 10);
        assert_eq!(results.len(), 1, "single-node index must return 1 result");
        assert_eq!(results[0].id, 0, "the only node must be returned");
    }
}