sqlrite/sql/hnsw.rs
1//! HNSW (Hierarchical Navigable Small World) approximate-nearest-neighbor
2//! index. Pure algorithm; no SQL integration in this module.
3//!
4//! HNSW is the industry-standard ANN algorithm for in-memory vector search:
5//! a multi-layer graph where each node lives at some randomly-assigned max
6//! layer; higher layers are sparser, layer 0 contains every node. Search
7//! starts at the entry point (the node at the current top layer), greedily
8//! descends layer-by-layer, then does a beam search at layer 0.
9//!
10//! ```text
11//! layer 2: [A] -- [E] sparse
12//! | |
13//! layer 1: [A] -- [E] -- [G] -- [J] mid
14//! | / | \ | \ |
15//! layer 0: [A,B,C,D,E,F,G,H,I,J,...] dense (every node)
16//! ```
17//!
18//! ## What this module is responsible for
19//!
20//! - The graph (per-node, per-layer neighbor lists)
21//! - Layer assignment for new nodes (geometric distribution)
22//! - Insertion: greedy descent + beam search + neighbor pruning
23//! - Query: greedy descent + beam search at layer 0, return top-k
24//!
25//! ## What it is NOT responsible for (yet)
26//!
27//! - **Storing vectors.** The algorithm calls a `get_vec(node_id) -> &[f32]`
28//! closure to fetch the vector for any node it touches. In Phase 7d.2
29//! that closure will read from the SQL table holding the indexed
30//! column; in tests it reads from an in-memory `Vec<Vec<f32>>`.
31//! - **Persistence.** The graph lives in `HashMap<i64, Node>` for now.
32//! Phase 7d.3 wires it into the cell-encoded page format.
33//! - **DELETE / UPDATE.** Pre-existing nodes can't be removed today.
34//! Soft-delete + lazy rebuild is the planned approach for 7d.2/7d.3.
35//!
36//! ## Parameters (per Phase 7 plan Q2 — fixed defaults)
37//!
38//! - `M = 16` — max neighbors per node at layers > 0
39//! - `m_max0 = 32` (= 2·M) — max neighbors at layer 0
40//! - `ef_construction = 200` — beam width during INSERT
41//! - `ef_search = 50` — default beam width during query
42//! - `m_l = 1/ln(M) ≈ 0.36` — layer-assignment scale
43//!
44//! ## Invariants
45//!
46//! - Every `node.layers` Vec has length `node_max_layer + 1` for that node.
47//! - `node.layers[i]` contains node_ids of neighbors at layer i. Each
48//! neighbor is itself a node in `nodes`; symmetrical (if A → B at layer i
49//! then B → A at layer i, modulo pruning).
50//! - `entry_point` is `Some(id)` iff `nodes` is non-empty. The entry node
51//! has the highest max-layer of any node currently in the graph.
52
53use std::cmp::Ordering;
54use std::collections::{BinaryHeap, HashMap, HashSet};
55
/// Distance metric used by the HNSW index. Must match what the
/// surrounding `vec_distance_*` SQL function would compute on the same
/// pair of vectors — otherwise the index probe and the brute-force
/// fallback would disagree on which rows are "nearest". See
/// `src/sql/executor.rs`'s `vec_distance_l2` / `_cosine` / `_dot` for
/// the canonical implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean (straight-line) distance.
    L2,
    /// `1 - cosine_similarity`; infinite when either vector has zero magnitude.
    Cosine,
    /// Negated dot product, so "smaller is closer" holds like the others.
    Dot,
}

impl DistanceMetric {
    /// Computes the configured distance between two equal-dimension
    /// vectors. Returns `f32::INFINITY` for the cosine/zero-magnitude
    /// edge case; HNSW treats infinity as "worst possible candidate" and
    /// will prefer any finite alternative, which matches the SQL-level
    /// behaviour where `vec_distance_cosine` errors but the optimizer's
    /// fallback path simply skips the offending row.
    pub fn compute(self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "vector dim mismatch in HNSW distance");
        match self {
            // sqrt of the sum of squared component differences. The
            // zip/map/sum chain accumulates left-to-right, exactly like
            // the equivalent index loop, so results are bit-identical.
            DistanceMetric::L2 => a
                .iter()
                .zip(b)
                .map(|(x, y)| {
                    let d = x - y;
                    d * d
                })
                .sum::<f32>()
                .sqrt(),
            DistanceMetric::Cosine => {
                // One pass accumulating dot product and both squared norms.
                let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
                    (0.0f32, 0.0f32, 0.0f32),
                    |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
                );
                let denom = (norm_a * norm_b).sqrt();
                if denom == 0.0 {
                    // Zero-magnitude vector: cosine is undefined; report
                    // the worst possible distance instead of NaN.
                    f32::INFINITY
                } else {
                    1.0 - dot / denom
                }
            }
            // Negate so that larger dot products (more similar) sort as
            // smaller "distances".
            DistanceMetric::Dot => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
        }
    }
}
113
/// Per-node metadata: a list of neighbor IDs for each layer this node
/// lives in. `layers[0]` is layer 0 (densest); `layers[layers.len() - 1]`
/// is the highest layer this node reaches.
#[derive(Debug, Clone, Default)]
pub struct Node {
    /// Indexed by layer (0 = dense). `layers[i]` is the neighbor list
    /// for this node at layer i. Always sorted-by-distance is *not* a
    /// guaranteed invariant — pruning maintains it after each
    /// modification, but during insert we may briefly hold an
    /// unsorted set.
    pub layers: Vec<Vec<i64>>,
}

impl Node {
    /// Maximum layer this node reaches: `layers.len() - 1` for any
    /// properly-initialized node (the index only ever builds nodes with
    /// at least one layer).
    ///
    /// Uses `saturating_sub` so that a `Node::default()` — the derived
    /// `Default` yields an empty `layers` Vec — reports layer 0 instead
    /// of underflowing `usize` (panic in debug builds, wrap to
    /// `usize::MAX` in release builds).
    pub fn max_layer(&self) -> usize {
        self.layers.len().saturating_sub(1)
    }
}
133
/// HNSW algorithm parameters. Phase 7 ships fixed defaults (Q2 in the
/// plan); this struct is `Clone + Copy` so callers wanting to fork an
/// experimental tuning can do so without touching the index itself.
#[derive(Debug, Clone, Copy)]
pub struct HnswParams {
    /// Max neighbors per node at layers > 0.
    pub m: usize,
    /// Max neighbors per node at layer 0 (conventionally 2·M).
    pub m_max0: usize,
    /// Beam width used during insertion.
    pub ef_construction: usize,
    /// Default beam width used during queries.
    pub ef_search: usize,
    /// Scale of the geometric layer-assignment distribution (1/ln M).
    pub m_l: f32,
}

impl Default for HnswParams {
    /// The fixed Phase 7 defaults: M=16, M_max0=32, ef_construction=200,
    /// ef_search=50, m_l = 1/ln(16) ≈ 0.36.
    fn default() -> Self {
        const M: usize = 16;
        Self {
            m: M,
            m_max0: M * 2,
            ef_construction: 200,
            ef_search: 50,
            // recip() is exactly 1.0 / x — same value as the hand-written
            // division.
            m_l: (M as f32).ln().recip(),
        }
    }
}
158
/// In-memory HNSW graph. See module docs for the model.
#[derive(Debug, Clone)]
pub struct HnswIndex {
    /// Algorithm parameters; `HnswParams::default()` unless a caller
    /// overrides them before inserting.
    pub params: HnswParams,
    /// Metric used for every distance computation in this index. Must
    /// match the metric of the SQL-level fallback path (see
    /// `DistanceMetric` docs).
    pub distance: DistanceMetric,
    /// Node id of the entry point. `None` iff the index is empty.
    /// At all times this is the id of the node with the highest
    /// max-layer; if multiple nodes tie for the top layer, the
    /// most-recently-promoted one wins.
    pub entry_point: Option<i64>,
    /// Highest layer currently populated. 0 when the index has at
    /// most one node, grows as new nodes get assigned higher layers.
    pub top_layer: usize,
    /// Node id → its per-layer neighbor lists.
    pub nodes: HashMap<i64, Node>,
    /// xorshift64 RNG state for layer assignment. Seeded explicitly via
    /// `new` so tests can pin a known sequence.
    rng_state: u64,
}
178
impl HnswIndex {
    /// Builds an empty HNSW index with default parameters and the given
    /// distance metric + RNG seed. A seed of 0 is mapped to a
    /// nonzero constant — xorshift64 has a fixed point at zero and would
    /// otherwise emit zeros forever.
    pub fn new(distance: DistanceMetric, seed: u64) -> Self {
        let seed = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
        Self {
            params: HnswParams::default(),
            distance,
            entry_point: None,
            top_layer: 0,
            nodes: HashMap::new(),
            rng_state: seed,
        }
    }

    /// True if no nodes have been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of nodes currently in the index.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Inserts a node into the graph. The node id must be unique;
    /// re-inserting an existing id is a no-op (returns without error).
    /// `vec` is the new node's vector; `get_vec` looks up the vector
    /// for any other node id the algorithm touches.
    pub fn insert<F>(&mut self, node_id: i64, vec: &[f32], get_vec: F)
    where
        F: Fn(i64) -> Vec<f32>,
    {
        if self.nodes.contains_key(&node_id) {
            return;
        }

        // First node: trivial case. Becomes entry point at layer 0.
        if self.is_empty() {
            self.nodes.insert(
                node_id,
                Node {
                    layers: vec![Vec::new()],
                },
            );
            self.entry_point = Some(node_id);
            self.top_layer = 0;
            return;
        }

        // Pick a layer for this new node (geometric distribution, capped
        // at top_layer + 1 — see pick_layer).
        let target_layer = self.pick_layer();

        // Pre-allocate the new node's layer lists (empty for now;
        // populated below). Inserting the shell up-front means pruning
        // below can safely look the node up.
        let new_node = Node {
            layers: vec![Vec::new(); target_layer + 1],
        };
        self.nodes.insert(node_id, new_node);

        // Greedy descent from top down to (target_layer + 1) — at each
        // layer above our target, advance the entry point to the
        // single closest node. We don't add edges at these layers
        // because the new node doesn't live there. (Loop is empty when
        // target_layer >= top_layer.)
        let mut entry = self.entry_point.expect("non-empty index has entry point");
        for layer in (target_layer + 1..=self.top_layer).rev() {
            let nearest = self.search_layer(vec, &[entry], 1, layer, &get_vec);
            if let Some((_, id)) = nearest.into_iter().next() {
                entry = id;
            }
        }

        // Beam search + connect at each layer the new node lives in.
        // We work top-down; the entry points for each layer are the
        // candidates found at the layer above.
        let mut entries = vec![entry];
        for layer in (0..=target_layer).rev() {
            let candidates =
                self.search_layer(vec, &entries, self.params.ef_construction, layer, &get_vec);

            // The new node always links to at most M of the candidates,
            // at every layer. m_max (M at layers > 0, M_max0 at layer 0)
            // is only the cap applied when pruning an *existing*
            // neighbor's list below — layer 0 tolerates up to 2·M edges
            // per node, but we don't hand them out eagerly here.
            let m_max = if layer == 0 {
                self.params.m_max0
            } else {
                self.params.m
            };
            let neighbors: Vec<i64> = candidates
                .iter()
                .take(self.params.m)
                .map(|(_, id)| *id)
                .collect();

            // Wire up the bidirectional edges. Forward edges first:
            self.nodes.get_mut(&node_id).expect("just inserted").layers[layer] = neighbors.clone();

            for &nb in &neighbors {
                let nb_layers = &mut self.nodes.get_mut(&nb).expect("neighbor must exist").layers;
                if layer >= nb_layers.len() {
                    // Neighbor doesn't actually live at this layer — can
                    // happen when the new node's target layer exceeds the
                    // old top layer (entries carried down from above may
                    // sit lower); skip the back-edge rather than index
                    // out of bounds.
                    continue;
                }
                nb_layers[layer].push(node_id);

                // Prune the neighbor's edge list if it's now over its M_max
                // budget. Pruning policy: keep the closest M_max nodes
                // by distance. (Distance recomputed; no precomputed values.
                // This is the simple policy, not the paper's diversity
                // heuristic.)
                if nb_layers[layer].len() > m_max {
                    let nb_vec = get_vec(nb);
                    let mut by_dist: Vec<(f32, i64)> = nb_layers[layer]
                        .iter()
                        .map(|id| (self.distance.compute(&nb_vec, &get_vec(*id)), *id))
                        .collect();
                    by_dist
                        .sort_by(|(da, _), (db, _)| da.partial_cmp(db).unwrap_or(Ordering::Equal));
                    by_dist.truncate(m_max);
                    nb_layers[layer] = by_dist.into_iter().map(|(_, id)| id).collect();
                }
            }

            // Carry the candidate set forward as entry points for the
            // next (lower) layer.
            entries = candidates.into_iter().map(|(_, id)| id).collect();
        }

        // If this new node lives higher than the current top, promote it.
        // (pick_layer caps at top_layer + 1, so the top grows by at most
        // one layer per insert.)
        if target_layer > self.top_layer {
            self.top_layer = target_layer;
            self.entry_point = Some(node_id);
        }
    }

    /// Returns the k nearest node ids to `query`, in distance-ascending
    /// order (closest first). Empty index returns an empty Vec.
    pub fn search<F>(&self, query: &[f32], k: usize, get_vec: F) -> Vec<i64>
    where
        F: Fn(i64) -> Vec<f32>,
    {
        if self.is_empty() || k == 0 {
            return Vec::new();
        }

        // Greedy descent from the top down to layer 1 — width-1 beam,
        // just tracking the single closest node per layer.
        let mut entry = self.entry_point.expect("non-empty index has entry point");
        for layer in (1..=self.top_layer).rev() {
            let nearest = self.search_layer(query, &[entry], 1, layer, &get_vec);
            if let Some((_, id)) = nearest.into_iter().next() {
                entry = id;
            }
        }

        // Beam search at layer 0 with width = max(ef_search, k) — the
        // beam must be at least k wide or we couldn't return k results.
        let ef = self.params.ef_search.max(k);
        let candidates = self.search_layer(query, &[entry], ef, 0, &get_vec);

        candidates.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// Runs a beam search at one layer starting from `entries`, returning
    /// the top-`ef` nearest nodes to `query` found, sorted by distance
    /// ascending.
    ///
    /// This is the workhorse of both insert and search. The two priority
    /// queues — "candidates" (nodes still to expand) and "results"
    /// (current best ef found) — terminate when the closest unexpanded
    /// candidate is farther than the worst kept result.
    ///
    /// Entry nodes need not live at `layer`; they contribute themselves
    /// to the result set but expand to no neighbors there.
    fn search_layer<F>(
        &self,
        query: &[f32],
        entries: &[i64],
        ef: usize,
        layer: usize,
        get_vec: &F,
    ) -> Vec<(f32, i64)>
    where
        F: Fn(i64) -> Vec<f32>,
    {
        let mut visited: HashSet<i64> = HashSet::with_capacity(ef * 2);
        // candidates: min-heap of (distance, id) — pop closest first.
        let mut candidates: BinaryHeap<MinHeapItem> = BinaryHeap::with_capacity(ef * 2);
        // results: max-heap of (distance, id) — top is the worst kept.
        let mut results: BinaryHeap<MaxHeapItem> = BinaryHeap::with_capacity(ef);

        // Seed both queues with the (deduplicated) entry points.
        for &id in entries {
            if !visited.insert(id) {
                continue;
            }
            let d = self.distance.compute(query, &get_vec(id));
            candidates.push(MinHeapItem { dist: d, id });
            results.push(MaxHeapItem { dist: d, id });
        }

        while let Some(MinHeapItem {
            dist: c_dist,
            id: c_id,
        }) = candidates.pop()
        {
            // If the result set is full and the closest unexpanded
            // candidate is worse than the worst kept result, no further
            // expansion can improve the result set. Bail.
            if let Some(worst) = results.peek() {
                if results.len() >= ef && c_dist > worst.dist {
                    break;
                }
            }

            // Expand: visit each neighbor of c_id at this layer.
            // (Cloned so we don't hold a borrow of self.nodes across
            // the loop body.)
            let neighbors = self
                .nodes
                .get(&c_id)
                .and_then(|n| n.layers.get(layer))
                .cloned()
                .unwrap_or_default();
            for nb in neighbors {
                if !visited.insert(nb) {
                    continue;
                }
                let d = self.distance.compute(query, &get_vec(nb));
                // Admit the neighbor if the beam has room, or it beats
                // the current worst kept result.
                let admit = if results.len() < ef {
                    true
                } else {
                    d < results.peek().unwrap().dist
                };
                if admit {
                    candidates.push(MinHeapItem { dist: d, id: nb });
                    results.push(MaxHeapItem { dist: d, id: nb });
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }

        // Drain results into a sorted vec. results is a max-heap, so
        // popping gives descending order; reverse for ascending.
        let mut out: Vec<(f32, i64)> = Vec::with_capacity(results.len());
        while let Some(item) = results.pop() {
            out.push((item.dist, item.id));
        }
        out.reverse();
        out
    }

    /// Picks a layer for a new node using the standard HNSW geometric
    /// distribution: `L = floor(-ln(uniform) * m_l)`. With M=16, mL ≈ 0.36,
    /// so:
    /// - P(L=0) ≈ 1 - 1/M = 15/16
    /// - P(L=1) ≈ 1/16 - 1/256
    /// - P(L=2) ≈ 1/256 - …
    /// i.e., most new nodes live only at layer 0; a few percolate up.
    fn pick_layer(&mut self) -> usize {
        // next_uniform can return exactly 0.0; clamp away from it so
        // ln() stays finite.
        let u = self.next_uniform().max(1e-6); // guard log(0)
        let layer = (-u.ln() * self.params.m_l).floor() as usize;
        // Cap at top_layer + 1 to keep the graph from sprouting empty
        // layers above the current top — matches the original HNSW
        // paper's recommendation.
        layer.min(self.top_layer + 1)
    }

    /// Pulls a uniform f32 in [0, 1) from the internal xorshift64 state.
    /// Top 24 bits of the next u64, divided by 2^24 — 24-bit uniform
    /// precision, plenty for layer assignment. Note 0.0 is a possible
    /// output (1.0 is not); `pick_layer` clamps before taking the log.
    fn next_uniform(&mut self) -> f32 {
        let mut x = self.rng_state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state = x;
        ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
    }
}
453
454// -----------------------------------------------------------------
455// Heap items
456//
457// Rust's BinaryHeap is a max-heap that uses Ord. f32 doesn't impl Ord
458// (NaN), so we wrap (distance, id) pairs and provide custom Ord that
// uses partial_cmp, treating any NaN-involved comparison as Equal
// (the id tiebreak then decides the order deterministically).
460//
461// MinHeapItem inverts the comparison so BinaryHeap<MinHeapItem> behaves
462// as a min-heap — top is the smallest distance, popping gives ascending
463// order.
464//
465// MaxHeapItem uses the natural ordering — top is the largest distance.
466
/// Heap entry that makes `BinaryHeap` behave as a *min*-heap over
/// `(distance, id)`: both the distance and the id comparisons are
/// inverted, so the smallest distance (smallest id on ties) surfaces
/// at the top. NaN comparisons fall back to Equal.
#[derive(Debug, Clone, Copy)]
struct MinHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MinHeapItem {
    fn eq(&self, other: &Self) -> bool {
        (self.dist, self.id) == (other.dist, other.id)
    }
}
impl Eq for MinHeapItem {}
impl PartialOrd for MinHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MinHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural comparison, then flipped: the smaller distance ranks
        // "greater" so Rust's max-heap pops it first. NaN → Equal.
        self.dist
            .partial_cmp(&other.dist)
            .map_or(Ordering::Equal, Ordering::reverse)
            .then_with(|| self.id.cmp(&other.id).reverse())
    }
}
494
/// Heap entry using `BinaryHeap`'s natural *max*-heap behaviour over
/// `(distance, id)`: the largest distance (largest id on ties) is at
/// the top — i.e. the worst kept result. NaN comparisons fall back
/// to Equal.
#[derive(Debug, Clone, Copy)]
struct MaxHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MaxHeapItem {
    fn eq(&self, other: &Self) -> bool {
        (self.dist, self.id) == (other.dist, other.id)
    }
}
impl Eq for MaxHeapItem {}
impl PartialOrd for MaxHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural ordering so the largest distance bubbles to the top;
        // equal (or NaN) distances fall through to the id comparison.
        match self.dist.partial_cmp(&other.dist) {
            Some(Ordering::Equal) | None => self.id.cmp(&other.id),
            Some(ord) => ord,
        }
    }
}
521
522// -----------------------------------------------------------------
523// Tests
524// -----------------------------------------------------------------
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift64 to generate test vectors — same
    /// generator as the index's internal RNG, so sequences are
    /// reproducible across runs.
    fn random_vec(state: &mut u64, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|_| {
                let mut x = *state;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                *state = x;
                ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
            })
            .collect()
    }

    /// Brute-force nearest-neighbors baseline for recall comparison.
    /// O(n) scan + sort; ground truth the ANN results are scored against.
    fn brute_force_topk(
        vectors: &[Vec<f32>],
        query: &[f32],
        k: usize,
        metric: DistanceMetric,
    ) -> Vec<i64> {
        let mut by_dist: Vec<(f32, i64)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (metric.compute(query, v), i as i64))
            .collect();
        by_dist.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        by_dist.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// recall@k — fraction of the brute-force top-k that the HNSW
    /// search also returned (in any order).
    fn recall_at_k(hnsw_result: &[i64], baseline: &[i64]) -> f32 {
        let baseline_set: HashSet<i64> = baseline.iter().copied().collect();
        let hits = hnsw_result
            .iter()
            .filter(|id| baseline_set.contains(id))
            .count();
        hits as f32 / baseline.len() as f32
    }

    #[test]
    fn empty_index_returns_empty_search() {
        let idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors: Vec<Vec<f32>> = vec![];
        // The get_vec closure would panic if called — search on an empty
        // index must never invoke it.
        let result = idx.search(&[0.0; 4], 5, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn single_node_returns_only_itself() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0, 3.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        // k=5 requested but only one node exists — result is just that node.
        let result = idx.search(&[0.0; 3], 5, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);
    }

    #[test]
    fn duplicate_insert_is_noop() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn k_zero_returns_empty() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let result = idx.search(&[0.5, 0.5], 0, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn small_graph_finds_exact_nearest() {
        // 5 well-separated points in 2D — HNSW should find the exact
        // nearest with no recall loss for k=1 and k=3.
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![10.0, 0.0],
            vec![0.0, 10.0],
            vec![10.0, 10.0],
            vec![5.0, 5.0],
        ];
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        // Query at (1, 1): nearest is (0, 0).
        let result = idx.search(&[1.0, 1.0], 1, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);

        // Query at (5.5, 5.5): top-3 should be id=4 (5,5), then any
        // two of the corners at distance ~7.78 — so only the first
        // position is asserted exactly.
        let result = idx.search(&[5.5, 5.5], 3, |id| vectors[id as usize].clone());
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], 4, "closest to (5.5,5.5) should be id=4");
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_l2() {
        // Standard recall test: 1000 random vectors in 8D, query for
        // top-10 with HNSW, compare to brute-force ground truth.
        // Modern HNSW papers target recall@10 > 0.95; we should clear
        // that comfortably on this small benchmark.
        let mut state: u64 = 0xDEADBEEF;
        let dim = 8;
        let n = 1000;
        let queries = 20;
        let k = 10;

        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        // Average recall over several queries — a single query would be
        // too noisy to threshold.
        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::L2);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_cosine() {
        // Same shape as the L2 test but with cosine distance, to
        // exercise the alternative metric through the same pipeline.
        let mut state: u64 = 0xC0FFEE;
        let dim = 16;
        let n = 500;
        let queries = 20;
        let k = 10;

        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx = HnswIndex::new(DistanceMetric::Cosine, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::Cosine);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "cosine recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    #[test]
    fn entry_point_promotes_when_higher_layer_node_inserted() {
        // The graph's entry point should always be a node at the
        // current top layer. Insert two nodes; if the second lands at
        // a higher layer, it becomes the entry point.
        // We can't easily force a particular layer (it's randomized),
        // so check the invariant: after every insert, the entry node's
        // max_layer == top_layer.
        let mut state: u64 = 0xABCDEF;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..50 {
            vectors.push(random_vec(&mut state, dim));
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());

            // Check invariant.
            let entry = idx.entry_point.expect("non-empty");
            let entry_max = idx.nodes[&entry].max_layer();
            assert_eq!(
                entry_max, idx.top_layer,
                "entry-point invariant broken at step {i}: entry {entry} has max_layer {entry_max}, top_layer is {}",
                idx.top_layer
            );
        }
    }

    #[test]
    fn neighbor_lists_respect_m_max() {
        // After inserting 200 points with M=16 (so M_max0 = 32), no
        // node should have more than 32 neighbors at layer 0 or more
        // than 16 at any higher layer.
        let mut state: u64 = 0x123456;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..200 {
            vectors.push(random_vec(&mut state, dim));
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
        }

        for (id, node) in &idx.nodes {
            for (layer, neighbors) in node.layers.iter().enumerate() {
                let cap = if layer == 0 {
                    idx.params.m_max0
                } else {
                    idx.params.m
                };
                assert!(
                    neighbors.len() <= cap,
                    "node {id} layer {layer} has {} > cap {cap}",
                    neighbors.len()
                );
            }
        }
    }

    #[test]
    fn deterministic_with_fixed_seed() {
        // Same seed + same insert order → same graph topology.
        // Catches accidental sources of nondeterminism (HashMap
        // iteration order, etc.).
        let mut state: u64 = 0x999;
        let dim = 4;
        let n = 50;
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx_a = HnswIndex::new(DistanceMetric::L2, 42);
        let mut idx_b = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx_a.insert(i as i64, v, |id| vectors[id as usize].clone());
            idx_b.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        // Same top layer.
        assert_eq!(idx_a.top_layer, idx_b.top_layer);
        // Same entry point.
        assert_eq!(idx_a.entry_point, idx_b.entry_point);
        // Same node count and same per-node max-layer for every id.
        // (Neighbor list contents may differ trivially if HashMap
        // iteration sneaked in; if this fails, fix the source first.)
        assert_eq!(idx_a.nodes.len(), idx_b.nodes.len());
        for (id, node_a) in &idx_a.nodes {
            let node_b = idx_b.nodes.get(id).expect("missing id");
            assert_eq!(node_a.max_layer(), node_b.max_layer(), "id={id}");
        }
    }
}
790}