// sqlrite — src/sql/hnsw.rs
//! HNSW (Hierarchical Navigable Small World) approximate-nearest-neighbor
//! index. Pure algorithm; no SQL integration in this module.
//!
//! HNSW is the industry-standard ANN algorithm for in-memory vector search:
//! a multi-layer graph where each node lives at some randomly-assigned max
//! layer; higher layers are sparser, layer 0 contains every node. Search
//! starts at the entry point (the node at the current top layer), greedily
//! descends layer-by-layer, then does a beam search at layer 0.
//!
//! ```text
//!     layer 2:   [A] -- [E]                    sparse
//!                 |       |
//!     layer 1:   [A] -- [E] -- [G] -- [J]      mid
//!                 |  /  |  \   |  \   |
//!     layer 0:   [A,B,C,D,E,F,G,H,I,J,...]     dense (every node)
//! ```
//!
//! ## What this module is responsible for
//!
//! - The graph (per-node, per-layer neighbor lists)
//! - Layer assignment for new nodes (geometric distribution)
//! - Insertion: greedy descent + beam search + neighbor pruning
//! - Query: greedy descent + beam search at layer 0, return top-k
//!
//! ## What it is NOT responsible for (yet)
//!
//! - **Storing vectors.** The algorithm calls a `get_vec(node_id) -> Vec<f32>`
//!   closure to fetch the vector for any node it touches. In Phase 7d.2
//!   that closure will read from the SQL table holding the indexed
//!   column; in tests it reads from an in-memory `Vec<Vec<f32>>`.
//! - **Persistence.** The graph lives in `HashMap<i64, Node>` for now.
//!   Phase 7d.3 wires it into the cell-encoded page format.
//! - **DELETE / UPDATE.** Pre-existing nodes can't be removed today.
//!   Soft-delete + lazy rebuild is the planned approach for 7d.2/7d.3.
//!
//! ## Parameters (per Phase 7 plan Q2 — fixed defaults)
//!
//! - `M = 16`              — max neighbors per node at layers > 0
//! - `m_max0 = 32` (= 2·M) — max neighbors at layer 0
//! - `ef_construction = 200` — beam width during INSERT
//! - `ef_search = 50`      — default beam width during query
//! - `m_l = 1/ln(M) ≈ 0.36`  — layer-assignment scale
//!
//! ## Invariants
//!
//! - Every `node.layers` Vec has length `node_max_layer + 1` for that node.
//! - `node.layers[i]` contains node_ids of neighbors at layer i. Each
//!   neighbor is itself a node in `nodes`; symmetrical (if A → B at layer i
//!   then B → A at layer i, modulo pruning).
//! - `entry_point` is `Some(id)` iff `nodes` is non-empty. The entry node
//!   has the highest max-layer of any node currently in the graph.
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

/// Distance metric used by the HNSW index. Must match what the
/// surrounding `vec_distance_*` SQL function would compute on the same
/// pair of vectors — otherwise the index probe and the brute-force
/// fallback would disagree on which rows are "nearest". See
/// `src/sql/executor.rs`'s `vec_distance_l2` / `_cosine` / `_dot` for
/// the canonical implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance: sqrt(Σ (aᵢ − bᵢ)²).
    L2,
    /// 1 − cos(a, b); `INFINITY` when either vector has zero magnitude.
    Cosine,
    /// Negated dot product, so "smaller = more similar" holds for all
    /// three metrics.
    Dot,
}

impl DistanceMetric {
    /// Computes the configured distance between two equal-dimension
    /// vectors. Returns `f32::INFINITY` for the cosine/zero-magnitude
    /// edge case; HNSW treats infinity as "worst possible candidate" and
    /// will prefer any finite alternative, which matches the SQL-level
    /// behaviour where `vec_distance_cosine` errors but the optimizer's
    /// fallback path simply skips the offending row.
    ///
    /// Implemented with `zip` iterators rather than index loops so the
    /// compiler can elide per-access bounds checks; accumulation order
    /// is still strictly left-to-right, so results are bit-identical to
    /// the equivalent `for i in 0..len` loop.
    pub fn compute(self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "vector dim mismatch in HNSW distance");
        match self {
            DistanceMetric::L2 => a
                .iter()
                .zip(b)
                .map(|(x, y)| {
                    let d = x - y;
                    d * d
                })
                .sum::<f32>()
                .sqrt(),
            DistanceMetric::Cosine => {
                let mut dot = 0.0f32;
                let mut na = 0.0f32;
                let mut nb = 0.0f32;
                for (x, y) in a.iter().zip(b) {
                    dot += x * y;
                    na += x * x;
                    nb += y * y;
                }
                let denom = (na * nb).sqrt();
                // Zero-magnitude vector: cosine is undefined; report the
                // worst possible distance instead of producing NaN.
                if denom == 0.0 {
                    f32::INFINITY
                } else {
                    1.0 - dot / denom
                }
            }
            // Negated so ascending distance == descending similarity.
            DistanceMetric::Dot => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
        }
    }
}

/// Per-node metadata: a list of neighbor IDs for each layer this node
/// lives in. `layers[0]` is layer 0 (densest); `layers[layers.len() - 1]`
/// is the highest layer this node reaches.
#[derive(Debug, Clone, Default)]
pub struct Node {
    /// Indexed by layer (0 = dense). `layers[i]` is the neighbor list
    /// for this node at layer i. Always sorted-by-distance is *not* a
    /// guaranteed invariant — pruning maintains it after each
    /// modification, but during insert we may briefly hold an
    /// unsorted set.
    pub layers: Vec<Vec<i64>>,
}

impl Node {
    /// Maximum layer this node reaches: `layers.len() - 1` for any node
    /// built by the index (which always allocates at least one layer).
    ///
    /// Uses `saturating_sub` so a `Node::default()` (empty `layers`,
    /// constructible via the `Default` derive) reports layer 0 instead
    /// of panicking in debug builds or wrapping to `usize::MAX` in
    /// release builds.
    pub fn max_layer(&self) -> usize {
        self.layers.len().saturating_sub(1)
    }
}

/// HNSW algorithm parameters. Phase 7 ships fixed defaults (Q2 in the
/// plan); the struct is `Clone + Copy`, so a caller wanting to try an
/// experimental tuning can fork `HnswParams::default()` without
/// touching the index itself.
#[derive(Debug, Clone, Copy)]
pub struct HnswParams {
    /// Max neighbors per node at layers > 0 (also the per-insert link budget).
    pub m: usize,
    /// Max neighbors per node at layer 0 (conventionally 2·M).
    pub m_max0: usize,
    /// Beam width used during INSERT.
    pub ef_construction: usize,
    /// Default beam width used during queries.
    pub ef_search: usize,
    /// Layer-assignment scale, 1/ln(M).
    pub m_l: f32,
}

impl Default for HnswParams {
    /// The fixed Phase 7 defaults: M = 16, M_max0 = 32,
    /// ef_construction = 200, ef_search = 50, m_l = 1/ln(16) ≈ 0.36.
    fn default() -> Self {
        const M: usize = 16;
        HnswParams {
            m: M,
            m_max0: M * 2,
            ef_construction: 200,
            ef_search: 50,
            m_l: (M as f32).ln().recip(),
        }
    }
}

/// In-memory HNSW graph. See module docs for the model.
#[derive(Debug, Clone)]
pub struct HnswIndex {
    /// Tuning knobs (M, ef widths, layer scale). Fixed defaults in Phase 7.
    pub params: HnswParams,
    /// Metric applied to every distance computation in this index; must
    /// match the SQL-level `vec_distance_*` function for the column.
    pub distance: DistanceMetric,
    /// Node id of the entry point. `None` iff the index is empty.
    /// At all times this is the id of the node with the highest
    /// max-layer; if multiple nodes tie for the top layer, the
    /// most-recently-promoted one wins.
    pub entry_point: Option<i64>,
    /// Highest layer currently populated. 0 when the index has at
    /// most one node, grows as new nodes get assigned higher layers.
    pub top_layer: usize,
    /// Node id → its per-layer neighbor lists.
    pub nodes: HashMap<i64, Node>,
    /// xorshift64 RNG state for layer assignment. Seeded explicitly via
    /// `new` so tests can pin a known sequence. Never zero (see `new`).
    rng_state: u64,
}

impl HnswIndex {
    /// Builds an empty HNSW index with default parameters and the given
    /// distance metric + RNG seed. A seed of 0 is remapped to a fixed
    /// nonzero constant — xorshift64 has an all-zero fixed point and
    /// would otherwise emit zeros forever.
    pub fn new(distance: DistanceMetric, seed: u64) -> Self {
        let seed = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
        Self {
            params: HnswParams::default(),
            distance,
            entry_point: None,
            top_layer: 0,
            nodes: HashMap::new(),
            rng_state: seed,
        }
    }

    /// True if no nodes have been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of nodes currently in the index.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Inserts a node into the graph. The node id must be unique;
    /// re-inserting an existing id is a no-op (returns without error).
    /// `vec` is the new node's vector; `get_vec` looks up the vector
    /// for any other node id the algorithm touches.
    pub fn insert<F>(&mut self, node_id: i64, vec: &[f32], get_vec: F)
    where
        F: Fn(i64) -> Vec<f32>,
    {
        if self.nodes.contains_key(&node_id) {
            return;
        }

        // First node: trivial case. Becomes entry point at layer 0.
        if self.is_empty() {
            self.nodes.insert(
                node_id,
                Node {
                    layers: vec![Vec::new()],
                },
            );
            self.entry_point = Some(node_id);
            self.top_layer = 0;
            return;
        }

        // Pick a layer for this new node (geometric distribution).
        let target_layer = self.pick_layer();

        // Pre-allocate the new node's layer lists (empty for now;
        // populated below). Inserting before the searches is safe:
        // nothing links to node_id yet, so search can't reach it.
        let new_node = Node {
            layers: vec![Vec::new(); target_layer + 1],
        };
        self.nodes.insert(node_id, new_node);

        // Greedy descent from top down to (target_layer + 1) — at each
        // layer above our target, advance the entry point to the
        // single closest node. We don't add edges at these layers
        // because the new node doesn't live there.
        let mut entry = self.entry_point.expect("non-empty index has entry point");
        for layer in (target_layer + 1..=self.top_layer).rev() {
            let nearest = self.search_layer(vec, &[entry], 1, layer, &get_vec);
            if let Some((_, id)) = nearest.into_iter().next() {
                entry = id;
            }
        }

        // Beam search + connect at each layer the new node lives in.
        // We work top-down; the entry point for each layer is the best
        // candidate found at the layer above.
        let mut entries = vec![entry];
        for layer in (0..=target_layer).rev() {
            let candidates =
                self.search_layer(vec, &entries, self.params.ef_construction, layer, &get_vec);

            // The new node links to at most M candidates (next statement);
            // m_max is the *pruning* cap applied to existing neighbors'
            // lists below — M_max0 at the dense layer 0, M elsewhere.
            let m_max = if layer == 0 {
                self.params.m_max0
            } else {
                self.params.m
            };
            let neighbors: Vec<i64> = candidates
                .iter()
                .take(self.params.m)
                .map(|(_, id)| *id)
                .collect();

            // Wire up the bidirectional edges: node_id → neighbors here,
            // neighbor → node_id in the loop below.
            self.nodes.get_mut(&node_id).expect("just inserted").layers[layer] = neighbors.clone();

            for &nb in &neighbors {
                let nb_layers = &mut self.nodes.get_mut(&nb).expect("neighbor must exist").layers;
                if layer >= nb_layers.len() {
                    // Neighbor doesn't actually live at this layer — shouldn't
                    // happen because search_layer only returns nodes at this
                    // layer, but defend against it.
                    continue;
                }
                nb_layers[layer].push(node_id);

                // Prune the neighbor's edge list if it's now over its M_max
                // budget. Pruning policy: keep the closest M_max nodes
                // by distance. (Distance recomputed; no precomputed values.)
                if nb_layers[layer].len() > m_max {
                    let nb_vec = get_vec(nb);
                    let mut by_dist: Vec<(f32, i64)> = nb_layers[layer]
                        .iter()
                        .map(|id| (self.distance.compute(&nb_vec, &get_vec(*id)), *id))
                        .collect();
                    by_dist
                        .sort_by(|(da, _), (db, _)| da.partial_cmp(db).unwrap_or(Ordering::Equal));
                    by_dist.truncate(m_max);
                    nb_layers[layer] = by_dist.into_iter().map(|(_, id)| id).collect();
                }
            }

            // Carry the candidate set forward as entry points for the
            // next (lower) layer.
            entries = candidates.into_iter().map(|(_, id)| id).collect();
        }

        // If this new node lives higher than the current top, promote it —
        // maintains the "entry node has the highest max-layer" invariant.
        if target_layer > self.top_layer {
            self.top_layer = target_layer;
            self.entry_point = Some(node_id);
        }
    }

    /// Returns the k nearest node ids to `query`, in distance-ascending
    /// order (closest first). Empty index or `k == 0` returns an empty Vec.
    pub fn search<F>(&self, query: &[f32], k: usize, get_vec: F) -> Vec<i64>
    where
        F: Fn(i64) -> Vec<f32>,
    {
        if self.is_empty() || k == 0 {
            return Vec::new();
        }

        // Greedy descent from the top down to layer 1 (beam width 1).
        let mut entry = self.entry_point.expect("non-empty index has entry point");
        for layer in (1..=self.top_layer).rev() {
            let nearest = self.search_layer(query, &[entry], 1, layer, &get_vec);
            if let Some((_, id)) = nearest.into_iter().next() {
                entry = id;
            }
        }

        // Beam search at layer 0 with width = max(ef_search, k).
        let ef = self.params.ef_search.max(k);
        let candidates = self.search_layer(query, &[entry], ef, 0, &get_vec);

        candidates.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// Runs a beam search at one layer starting from `entries`, returning
    /// the top-`ef` nearest nodes to `query` found, sorted by distance
    /// ascending.
    ///
    /// This is the workhorse of both insert and search. The two priority
    /// queues — "candidates" (nodes still to expand) and "results"
    /// (current best ef found) — terminate when the closest unexpanded
    /// candidate is farther than the worst kept result.
    fn search_layer<F>(
        &self,
        query: &[f32],
        entries: &[i64],
        ef: usize,
        layer: usize,
        get_vec: &F,
    ) -> Vec<(f32, i64)>
    where
        F: Fn(i64) -> Vec<f32>,
    {
        let mut visited: HashSet<i64> = HashSet::with_capacity(ef * 2);
        // candidates: min-heap of (distance, id) — pop closest first.
        let mut candidates: BinaryHeap<MinHeapItem> = BinaryHeap::with_capacity(ef * 2);
        // results: max-heap of (distance, id) — top is the worst kept.
        let mut results: BinaryHeap<MaxHeapItem> = BinaryHeap::with_capacity(ef);

        // Seed both heaps with the (deduplicated) entry points.
        for &id in entries {
            if !visited.insert(id) {
                continue;
            }
            let d = self.distance.compute(query, &get_vec(id));
            candidates.push(MinHeapItem { dist: d, id });
            results.push(MaxHeapItem { dist: d, id });
        }

        while let Some(MinHeapItem {
            dist: c_dist,
            id: c_id,
        }) = candidates.pop()
        {
            // If the closest unexpanded candidate is worse than the
            // worst kept result, no further expansion can improve the
            // result set. Bail.
            if let Some(worst) = results.peek() {
                if results.len() >= ef && c_dist > worst.dist {
                    break;
                }
            }

            // Expand: visit each neighbor of c_id at this layer.
            // (Cloned to end the borrow of self.nodes before pushing.)
            let neighbors = self
                .nodes
                .get(&c_id)
                .and_then(|n| n.layers.get(layer))
                .cloned()
                .unwrap_or_default();
            for nb in neighbors {
                if !visited.insert(nb) {
                    continue;
                }
                let d = self.distance.compute(query, &get_vec(nb));
                // Admit if the result set isn't full yet, or this beats
                // the current worst kept result.
                let admit = if results.len() < ef {
                    true
                } else {
                    d < results.peek().unwrap().dist
                };
                if admit {
                    candidates.push(MinHeapItem { dist: d, id: nb });
                    results.push(MaxHeapItem { dist: d, id: nb });
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }

        // Drain results into a sorted vec. results is a max-heap, so
        // popping gives descending order; reverse for ascending.
        let mut out: Vec<(f32, i64)> = Vec::with_capacity(results.len());
        while let Some(item) = results.pop() {
            out.push((item.dist, item.id));
        }
        out.reverse();
        out
    }

    /// Picks a layer for a new node using the standard HNSW geometric
    /// distribution: `L = floor(-ln(uniform) * m_l)`. With M=16, mL ≈ 0.36,
    /// so:
    ///   - P(L=0) ≈ 1 - 1/M = 15/16
    ///   - P(L=1) ≈ 1/16 - 1/256
    ///   - P(L=2) ≈ 1/256 - …
    /// i.e., most new nodes live only at layer 0; a few percolate up.
    fn pick_layer(&mut self) -> usize {
        let u = self.next_uniform().max(1e-6); // guard log(0)
        let layer = (-u.ln() * self.params.m_l).floor() as usize;
        // Cap at top_layer + 1 to keep the graph from sprouting empty
        // layers above the current top — matches the original HNSW
        // paper's recommendation.
        layer.min(self.top_layer + 1)
    }

    /// Pulls a uniform-on-[0, 1) f32 from the internal xorshift state.
    /// Top 24 bits of the next u64, divided by 2^24 — gives 24-bit
    /// uniform precision, plenty for layer assignment. (A zero draw is
    /// possible; `pick_layer` clamps it before taking the log.)
    fn next_uniform(&mut self) -> f32 {
        let mut x = self.rng_state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state = x;
        ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
    }
}

// -----------------------------------------------------------------
// Heap items
//
// Rust's BinaryHeap is a max-heap that uses Ord. f32 doesn't impl Ord
// (NaN), so we wrap (distance, id) pairs and provide custom Ord that
// uses partial_cmp with an incomparable (NaN) pair treated as Equal,
// which then falls through to the id tie-break.
//
// MinHeapItem inverts the comparison so BinaryHeap<MinHeapItem> behaves
// as a min-heap — top is the smallest distance, popping gives ascending
// order.
//
// MaxHeapItem uses the natural ordering — top is the largest distance.

/// Wrapper giving `BinaryHeap` min-heap behaviour over `(dist, id)`
/// pairs: the heap's "largest" element is the *smallest* distance, so
/// `pop` yields candidates closest-first. Distance ties break toward
/// the smaller id; incomparable (NaN) distances compare as equal.
#[derive(Debug, Clone, Copy)]
struct MinHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MinHeapItem {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.dist == other.dist
    }
}
impl Eq for MinHeapItem {}
impl PartialOrd for MinHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MinHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural (distance, id) order, then flipped wholesale — so a
        // smaller distance ranks as "greater" and surfaces at the top
        // of a max-heap. NaN pairs fall back to Equal before the flip.
        let by_dist = self.dist.partial_cmp(&other.dist).unwrap_or(Ordering::Equal);
        by_dist.then_with(|| self.id.cmp(&other.id)).reverse()
    }
}

/// Wrapper giving `BinaryHeap` its natural max-heap behaviour over
/// `(dist, id)` pairs: the top is the *largest* distance, i.e. the
/// worst kept result. Distance ties break toward the larger id;
/// incomparable (NaN) distances compare as equal.
#[derive(Debug, Clone, Copy)]
struct MaxHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MaxHeapItem {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.dist == other.dist
    }
}
impl Eq for MaxHeapItem {}
impl PartialOrd for MaxHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Largest distance wins; a tie or NaN pair degrades to the id
        // comparison so Ord stays total.
        match self.dist.partial_cmp(&other.dist) {
            Some(Ordering::Equal) | None => self.id.cmp(&other.id),
            Some(ord) => ord,
        }
    }
}

// -----------------------------------------------------------------
// Tests
// -----------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift to generate test vectors (same generator
    /// as the index's internal RNG, but with a caller-owned state).
    fn random_vec(state: &mut u64, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|_| {
                let mut x = *state;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                *state = x;
                ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
            })
            .collect()
    }

    /// Brute-force nearest-neighbors baseline for recall comparison.
    /// Node id i corresponds to vectors[i].
    fn brute_force_topk(
        vectors: &[Vec<f32>],
        query: &[f32],
        k: usize,
        metric: DistanceMetric,
    ) -> Vec<i64> {
        let mut by_dist: Vec<(f32, i64)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (metric.compute(query, v), i as i64))
            .collect();
        by_dist.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        by_dist.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// recall@k — fraction of the brute-force top-k that the HNSW
    /// search also returned (in any order).
    fn recall_at_k(hnsw_result: &[i64], baseline: &[i64]) -> f32 {
        let baseline_set: HashSet<i64> = baseline.iter().copied().collect();
        let hits = hnsw_result
            .iter()
            .filter(|id| baseline_set.contains(id))
            .count();
        hits as f32 / baseline.len() as f32
    }

    #[test]
    fn empty_index_returns_empty_search() {
        let idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors: Vec<Vec<f32>> = vec![];
        // The closure is never invoked on an empty index, so indexing
        // the empty Vec here is safe.
        let result = idx.search(&[0.0; 4], 5, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn single_node_returns_only_itself() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0, 3.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        let result = idx.search(&[0.0; 3], 5, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);
    }

    #[test]
    fn duplicate_insert_is_noop() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn k_zero_returns_empty() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let result = idx.search(&[0.5, 0.5], 0, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn small_graph_finds_exact_nearest() {
        // 5 well-separated points in 2D — HNSW should find the exact
        // nearest with no recall loss for k=1 and k=3.
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![10.0, 0.0],
            vec![0.0, 10.0],
            vec![10.0, 10.0],
            vec![5.0, 5.0],
        ];
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        // Query at (1, 1): nearest is (0, 0).
        let result = idx.search(&[1.0, 1.0], 1, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);

        // Query at (5.5, 5.5): top-3 should be id=4 (5,5), then any
        // two of the corners at distance ~7.78.
        let result = idx.search(&[5.5, 5.5], 3, |id| vectors[id as usize].clone());
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], 4, "closest to (5.5,5.5) should be id=4");
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_l2() {
        // Standard recall test: 1000 random vectors in 8D, query for
        // top-10 with HNSW, compare to brute-force ground truth.
        // Modern HNSW papers target recall@10 > 0.95; we should clear
        // that comfortably on this small benchmark.
        let mut state: u64 = 0xDEADBEEF;
        let dim = 8;
        let n = 1000;
        let queries = 20;
        let k = 10;

        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::L2);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_cosine() {
        // Same shape as the L2 test but with cosine distance, to
        // exercise the alternative metric through the same pipeline.
        let mut state: u64 = 0xC0FFEE;
        let dim = 16;
        let n = 500;
        let queries = 20;
        let k = 10;

        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx = HnswIndex::new(DistanceMetric::Cosine, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::Cosine);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "cosine recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    #[test]
    fn entry_point_promotes_when_higher_layer_node_inserted() {
        // The graph's entry point should always be a node at the
        // current top layer. Insert two nodes; if the second lands at
        // a higher layer, it becomes the entry point.
        // We can't easily force a particular layer (it's randomized),
        // so check the invariant: after every insert, the entry node's
        // max_layer == top_layer.
        let mut state: u64 = 0xABCDEF;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..50 {
            vectors.push(random_vec(&mut state, dim));
            // Clone the new vector out first so the insert closure can
            // borrow `vectors` immutably.
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());

            // Check invariant.
            let entry = idx.entry_point.expect("non-empty");
            let entry_max = idx.nodes[&entry].max_layer();
            assert_eq!(
                entry_max, idx.top_layer,
                "entry-point invariant broken at step {i}: entry {entry} has max_layer {entry_max}, top_layer is {}",
                idx.top_layer
            );
        }
    }

    #[test]
    fn neighbor_lists_respect_m_max() {
        // After inserting 200 points with M=16 (so M_max0 = 32), no
        // node should have more than 32 neighbors at layer 0 or more
        // than 16 at any higher layer.
        let mut state: u64 = 0x123456;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..200 {
            vectors.push(random_vec(&mut state, dim));
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
        }

        for (id, node) in &idx.nodes {
            for (layer, neighbors) in node.layers.iter().enumerate() {
                let cap = if layer == 0 {
                    idx.params.m_max0
                } else {
                    idx.params.m
                };
                assert!(
                    neighbors.len() <= cap,
                    "node {id} layer {layer} has {} > cap {cap}",
                    neighbors.len()
                );
            }
        }
    }

    #[test]
    fn deterministic_with_fixed_seed() {
        // Same seed + same insert order → same graph topology.
        // Catches accidental sources of nondeterminism (HashMap
        // iteration order, etc.).
        let mut state: u64 = 0x999;
        let dim = 4;
        let n = 50;
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();

        let mut idx_a = HnswIndex::new(DistanceMetric::L2, 42);
        let mut idx_b = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx_a.insert(i as i64, v, |id| vectors[id as usize].clone());
            idx_b.insert(i as i64, v, |id| vectors[id as usize].clone());
        }

        // Same top layer.
        assert_eq!(idx_a.top_layer, idx_b.top_layer);
        // Same entry point.
        assert_eq!(idx_a.entry_point, idx_b.entry_point);
        // Same node count and same per-node max-layer for every id.
        // (Neighbor list contents may differ trivially if HashMap
        // iteration sneaked in; if this fails, fix the source first.)
        assert_eq!(idx_a.nodes.len(), idx_b.nodes.len());
        for (id, node_a) in &idx_a.nodes {
            let node_b = idx_b.nodes.get(id).expect("missing id");
            assert_eq!(node_a.max_layer(), node_b.max_layer(), "id={id}");
        }
    }
}