// File: sqlrite/sql/hnsw.rs

//! HNSW (Hierarchical Navigable Small World) approximate-nearest-neighbor
//! index. Pure algorithm; no SQL integration in this module.
//!
//! HNSW is a widely used ANN algorithm for in-memory vector search:
//! a multi-layer graph where each node lives at some randomly-assigned max
//! layer; higher layers are sparser, and layer 0 contains every node. Search
//! starts at the entry point (a node on the current top layer), greedily
//! descends layer by layer, then runs a beam search at layer 0.
//!
//! ```text
//!     layer 2:   [A] -- [E]                    sparse
//!                 |       |
//!     layer 1:   [A] -- [E] -- [G] -- [J]      mid
//!                 |  /  |  \   |  \   |
//!     layer 0:   [A,B,C,D,E,F,G,H,I,J,...]     dense (every node)
//! ```
//!
//! ## What this module is responsible for
//!
//! - The graph (per-node, per-layer neighbor lists)
//! - Layer assignment for new nodes (geometric distribution)
//! - Insertion: greedy descent + beam search + neighbor pruning
//! - Query: greedy descent + beam search at layer 0, return top-k
//!
//! ## What it is NOT responsible for (yet)
//!
//! - **Storing vectors.** The algorithm calls a `get_vec(node_id) -> Vec<f32>`
//!   closure to fetch the vector for any node it touches. In Phase 7d.2
//!   that closure will read from the SQL table holding the indexed
//!   column; in tests it reads from an in-memory `Vec<Vec<f32>>`.
//! - **Page-level persistence.** In memory the graph lives in
//!   `HashMap<i64, Node>`. This module only converts it to and from
//!   `(node_id, layers)` pairs (`serialize_nodes` / `from_persisted_nodes`);
//!   encoding those pairs into the cell-encoded page format (`HnswNodeCell`,
//!   Phase 7d.3) happens elsewhere.
//! - **DELETE / UPDATE.** Existing nodes can't be removed today.
//!   Soft-delete + lazy rebuild is the planned approach for 7d.2/7d.3.
//!
//! ## Parameters (per Phase 7 plan Q2 — fixed defaults)
//!
//! - `M = 16`                — max neighbors per node at layers > 0
//! - `m_max0 = 32` (= 2·M)   — max neighbors at layer 0
//! - `ef_construction = 200` — beam width during INSERT
//! - `ef_search = 50`        — default beam width during query (raised to k
//!                             when k is larger)
//! - `m_l = 1/ln(M) ≈ 0.36`  — layer-assignment scale
//!
//! ## Invariants
//!
//! - Every `node.layers` Vec has length `node_max_layer + 1` for that node.
//! - `node.layers[i]` contains the node_ids of the node's neighbors at
//!   layer i. Each neighbor is itself a node in `nodes`, and edges are
//!   symmetric (if A → B at layer i then B → A at layer i), modulo pruning.
//! - `entry_point` is `Some(id)` iff `nodes` is non-empty. The entry node
//!   has the highest max-layer of any node currently in the graph.

53use std::cmp::Ordering;
54use std::collections::{BinaryHeap, HashMap, HashSet};
55
56use crate::error::{Result, SQLRiteError};
57
/// Distance function an HNSW index orders its candidates by.
///
/// It has to agree with the `vec_distance_*` SQL function of the same name
/// on every pair of vectors. If the index probe and the brute-force
/// fallback computed different distances, they could disagree about which
/// rows are "nearest". The canonical implementations are
/// `vec_distance_l2` / `_cosine` / `_dot` in `src/sql/executor.rs`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    L2,
    /// Cosine distance, `1 - cosine similarity`.
    Cosine,
    /// Negated dot product, so that a smaller value still means "closer".
    Dot,
}

71impl DistanceMetric {
72    /// Parses the metric name from the `CREATE INDEX … WITH
73    /// (metric = '<name>')` clause. Case-insensitive. Returns `None`
74    /// for unknown values; the parser surfaces that as a user-visible
75    /// error so a typo doesn't silently fall back to L2.
76    pub fn from_sql_name(name: &str) -> Option<Self> {
77        match name.to_ascii_lowercase().as_str() {
78            "l2" | "euclidean" => Some(DistanceMetric::L2),
79            "cosine" => Some(DistanceMetric::Cosine),
80            "dot" | "inner_product" | "ip" => Some(DistanceMetric::Dot),
81            _ => None,
82        }
83    }
84
85    /// The canonical SQL-surface name for this metric. Used when
86    /// synthesizing CREATE INDEX SQL back into `sqlrite_master`.
87    pub fn sql_name(self) -> &'static str {
88        match self {
89            DistanceMetric::L2 => "l2",
90            DistanceMetric::Cosine => "cosine",
91            DistanceMetric::Dot => "dot",
92        }
93    }
94
95    /// The `vec_distance_*` SQL function whose result this metric
96    /// orders by. The optimizer's HNSW shortcut only fires when the
97    /// query's ORDER BY expression names this exact function.
98    pub fn matching_distance_fn(self) -> &'static str {
99        match self {
100            DistanceMetric::L2 => "vec_distance_l2",
101            DistanceMetric::Cosine => "vec_distance_cosine",
102            DistanceMetric::Dot => "vec_distance_dot",
103        }
104    }
105
106    /// Computes the configured distance between two equal-dimension
107    /// vectors. Returns `f32::INFINITY` for the cosine/zero-magnitude
108    /// edge case; HNSW treats infinity as "worst possible candidate" and
109    /// will prefer any finite alternative, which matches the SQL-level
110    /// behaviour where `vec_distance_cosine` errors but the optimizer's
111    /// fallback path simply skips the offending row.
112    pub fn compute(self, a: &[f32], b: &[f32]) -> Result<f32> {
113        if a.is_empty() || b.is_empty() || a.len() != b.len() {
114            return Err(SQLRiteError::General(format!(
115                "HNSW vector dimension mismatch: left has {}, right has {}",
116                a.len(),
117                b.len()
118            )));
119        }
120        match self {
121            DistanceMetric::L2 => {
122                let mut sum = 0.0f32;
123                for i in 0..a.len() {
124                    let d = a[i] - b[i];
125                    sum += d * d;
126                }
127                Ok(sum.sqrt())
128            }
129            DistanceMetric::Cosine => {
130                let mut dot = 0.0f32;
131                let mut na = 0.0f32;
132                let mut nb = 0.0f32;
133                for i in 0..a.len() {
134                    dot += a[i] * b[i];
135                    na += a[i] * a[i];
136                    nb += b[i] * b[i];
137                }
138                let denom = (na * nb).sqrt();
139                if denom == 0.0 {
140                    Ok(f32::INFINITY)
141                } else {
142                    Ok(1.0 - dot / denom)
143                }
144            }
145            DistanceMetric::Dot => {
146                let mut dot = 0.0f32;
147                for i in 0..a.len() {
148                    dot += a[i] * b[i];
149                }
150                Ok(-dot)
151            }
152        }
153    }
154}
155
/// Per-node adjacency: one neighbor-id list for every layer the node lives
/// in. `layers[0]` is layer 0 (the densest); the last entry is the highest
/// layer this node reaches.
#[derive(Debug, Clone, Default)]
pub struct Node {
    /// Indexed by layer (0 = dense). `layers[i]` holds this node's neighbor
    /// ids at layer i.
    ///
    /// The lists are *not* guaranteed to be sorted by distance. Inserting a
    /// node appends the back-edge to the end of each neighbor's list, and a
    /// list is only re-sorted (closest first) when pruning trims it back to
    /// its size budget.
    pub layers: Vec<Vec<i64>>,
}

impl Node {
    /// Highest layer this node reaches, i.e. `layers.len() - 1`.
    ///
    /// A node with no layer lists at all (such as `Node::default()`) is
    /// reported as layer 0 rather than underflowing the subtraction. That
    /// matches how an empty layer list is treated when an index is rebuilt
    /// from persisted nodes.
    pub fn max_layer(&self) -> usize {
        self.layers.len().saturating_sub(1)
    }
}

/// Tuning knobs for the HNSW algorithm.
///
/// Phase 7 ships fixed defaults (plan Q2). The struct is `Copy`, so a caller
/// experimenting with other values can copy it and tweak fields without
/// touching an index.
#[derive(Clone, Copy, Debug)]
pub struct HnswParams {
    /// Neighbors a new node links to per layer; also the neighbor-list cap
    /// at layers above 0.
    pub m: usize,
    /// Neighbor-list cap at layer 0.
    pub m_max0: usize,
    /// Beam width used while inserting.
    pub ef_construction: usize,
    /// Default beam width used while searching (raised to `k` if larger).
    pub ef_search: usize,
    /// Scale of the geometric layer-assignment distribution.
    pub m_l: f32,
}

impl Default for HnswParams {
    /// `M = 16`, `m_max0 = 2·M`, `ef_construction = 200`, `ef_search = 50`,
    /// `m_l = 1 / ln(M)`.
    fn default() -> Self {
        const M: usize = 16;
        let m_l = 1.0 / (M as f32).ln();
        Self {
            m: M,
            m_max0: M * 2,
            ef_construction: 200,
            ef_search: 50,
            m_l,
        }
    }
}

201/// In-memory HNSW graph. See module docs for the model.
202#[derive(Debug, Clone)]
203pub struct HnswIndex {
204    pub params: HnswParams,
205    pub distance: DistanceMetric,
206    /// Node id of the entry point. `None` iff the index is empty.
207    /// At all times this is the id of the node with the highest
208    /// max-layer; if multiple nodes tie for the top layer, the
209    /// most-recently-promoted one wins.
210    pub entry_point: Option<i64>,
211    /// Highest layer currently populated. 0 when the index has at
212    /// most one node, grows as new nodes get assigned higher layers.
213    pub top_layer: usize,
214    /// Node id → its per-layer neighbor lists.
215    pub nodes: HashMap<i64, Node>,
216    /// xorshift64 RNG state for layer assignment. Seeded explicitly via
217    /// `new` so tests can pin a known sequence.
218    rng_state: u64,
219}
220
221impl HnswIndex {
222    /// Builds an empty HNSW index with default parameters and the given
223    /// distance metric + RNG seed. A seed of 0 is mapped to a small
224    /// nonzero constant — xorshift gets stuck at zero.
225    pub fn new(distance: DistanceMetric, seed: u64) -> Self {
226        let seed = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
227        Self {
228            params: HnswParams::default(),
229            distance,
230            entry_point: None,
231            top_layer: 0,
232            nodes: HashMap::new(),
233            rng_state: seed,
234        }
235    }
236
237    /// True if no nodes have been inserted yet.
238    pub fn is_empty(&self) -> bool {
239        self.nodes.is_empty()
240    }
241
242    /// Number of nodes currently in the index.
243    pub fn len(&self) -> usize {
244        self.nodes.len()
245    }
246
247    /// Phase 7d.3 — produces (node_id, layers) pairs in ascending node_id
248    /// order, suitable for serializing the graph to disk via the
249    /// `HnswNodeCell` wire format. The graph's metadata
250    /// (entry_point + top_layer) is recoverable from the nodes alone:
251    /// top_layer = max(max_layer); entry_point = any node at top_layer.
252    /// So we don't ship a separate metadata cell.
253    pub fn serialize_nodes(&self) -> Vec<(i64, Vec<Vec<i64>>)> {
254        let mut out: Vec<(i64, Vec<Vec<i64>>)> = self
255            .nodes
256            .iter()
257            .map(|(id, n)| (*id, n.layers.clone()))
258            .collect();
259        out.sort_by_key(|(id, _)| *id);
260        out
261    }
262
263    /// Phase 7d.3 — rebuilds an HnswIndex from a stream of (node_id, layers)
264    /// pairs as produced by `serialize_nodes` and round-tripped through
265    /// `HnswNodeCell` encode/decode. The rebuilt index has the same nodes,
266    /// same neighbor lists, same entry_point + top_layer as the source.
267    /// `seed` is fresh; the deserialized index is never inserted into via
268    /// the algorithmic `insert` path so the seed only matters if a caller
269    /// later calls `insert` after deserializing (then it controls layer
270    /// assignment for the appended node).
271    pub fn from_persisted_nodes<I>(distance: DistanceMetric, seed: u64, nodes: I) -> Self
272    where
273        I: IntoIterator<Item = (i64, Vec<Vec<i64>>)>,
274    {
275        let mut idx = Self::new(distance, seed);
276        let mut top_layer = 0usize;
277        let mut entry_point: Option<i64> = None;
278        for (id, layers) in nodes {
279            let max_layer = layers.len().saturating_sub(1);
280            if max_layer > top_layer || entry_point.is_none() {
281                top_layer = max_layer;
282                entry_point = Some(id);
283            }
284            idx.nodes.insert(id, Node { layers });
285        }
286        idx.top_layer = top_layer;
287        idx.entry_point = entry_point;
288        idx
289    }
290
291    /// Inserts a node into the graph. The node id must be unique;
292    /// re-inserting an existing id is a no-op (returns without error).
293    /// `vec` is the new node's vector; `get_vec` looks up the vector
294    /// for any other node id the algorithm touches.
295    pub fn insert<F>(&mut self, node_id: i64, vec: &[f32], get_vec: F) -> Result<()>
296    where
297        F: Fn(i64) -> Vec<f32>,
298    {
299        if self.nodes.contains_key(&node_id) {
300            return Ok(());
301        }
302
303        // First node: trivial case. Becomes entry point at layer 0.
304        if self.is_empty() {
305            self.nodes.insert(
306                node_id,
307                Node {
308                    layers: vec![Vec::new()],
309                },
310            );
311            self.entry_point = Some(node_id);
312            self.top_layer = 0;
313            return Ok(());
314        }
315
316        // Pick a layer for this new node.
317        let target_layer = self.pick_layer();
318
319        // Pre-allocate the new node's layer lists (empty for now;
320        // populated below).
321        let new_node = Node {
322            layers: vec![Vec::new(); target_layer + 1],
323        };
324        self.nodes.insert(node_id, new_node);
325
326        // Greedy descent from top down to (target_layer + 1) — at each
327        // layer above our target, advance the entry point to the
328        // single closest node. We don't add edges at these layers
329        // because the new node doesn't live there.
330        let mut entry = self.entry_point.expect("non-empty index has entry point");
331        for layer in (target_layer + 1..=self.top_layer).rev() {
332            let nearest = self.search_layer(vec, &[entry], 1, layer, &get_vec)?;
333            if let Some((_, id)) = nearest.into_iter().next() {
334                entry = id;
335            }
336        }
337
338        // Beam search + connect at each layer the new node lives in.
339        // We work top-down; the entry point for each layer is the best
340        // candidate found at the layer above.
341        let mut entries = vec![entry];
342        for layer in (0..=target_layer).rev() {
343            let candidates =
344                self.search_layer(vec, &entries, self.params.ef_construction, layer, &get_vec)?;
345
346            // Pick up to M neighbors from candidates (M_max0 at layer 0
347            // since we allow more connections at the dense layer).
348            let m_max = if layer == 0 {
349                self.params.m_max0
350            } else {
351                self.params.m
352            };
353            let neighbors: Vec<i64> = candidates
354                .iter()
355                .take(self.params.m)
356                .map(|(_, id)| *id)
357                .collect();
358
359            // Wire up the bidirectional edges.
360            self.nodes.get_mut(&node_id).expect("just inserted").layers[layer] = neighbors.clone();
361
362            for &nb in &neighbors {
363                let nb_layers = &mut self.nodes.get_mut(&nb).expect("neighbor must exist").layers;
364                if layer >= nb_layers.len() {
365                    // Neighbor doesn't actually live at this layer — shouldn't
366                    // happen because search_layer only returns nodes at this
367                    // layer, but defend against it.
368                    continue;
369                }
370                nb_layers[layer].push(node_id);
371
372                // Prune the neighbor's edge list if it's now over its M_max
373                // budget. Pruning policy: keep the closest M_max nodes
374                // by distance. (Distance recomputed; no precomputed values.)
375                if nb_layers[layer].len() > m_max {
376                    let nb_vec = get_vec(nb);
377                    let mut by_dist: Vec<(f32, i64)> = nb_layers[layer]
378                        .iter()
379                        .map(|id| {
380                            self.distance
381                                .compute(&nb_vec, &get_vec(*id))
382                                .map(|d| (d, *id))
383                        })
384                        .collect::<Result<Vec<_>>>()?;
385                    by_dist
386                        .sort_by(|(da, _), (db, _)| da.partial_cmp(db).unwrap_or(Ordering::Equal));
387                    by_dist.truncate(m_max);
388                    nb_layers[layer] = by_dist.into_iter().map(|(_, id)| id).collect();
389                }
390            }
391
392            // Carry the candidate set forward as entry points for the
393            // next (lower) layer.
394            entries = candidates.into_iter().map(|(_, id)| id).collect();
395        }
396
397        // If this new node lives higher than the current top, promote it.
398        if target_layer > self.top_layer {
399            self.top_layer = target_layer;
400            self.entry_point = Some(node_id);
401        }
402        Ok(())
403    }
404
405    /// Returns the k nearest node ids to `query`, in distance-ascending
406    /// order (closest first). Empty index returns an empty Vec.
407    pub fn search<F>(&self, query: &[f32], k: usize, get_vec: F) -> Result<Vec<i64>>
408    where
409        F: Fn(i64) -> Vec<f32>,
410    {
411        if self.is_empty() || k == 0 {
412            return Ok(Vec::new());
413        }
414
415        // Greedy descent from the top down to layer 1.
416        let mut entry = self.entry_point.expect("non-empty index has entry point");
417        for layer in (1..=self.top_layer).rev() {
418            let nearest = self.search_layer(query, &[entry], 1, layer, &get_vec)?;
419            if let Some((_, id)) = nearest.into_iter().next() {
420                entry = id;
421            }
422        }
423
424        // Beam search at layer 0 with width = max(ef_search, k).
425        let ef = self.params.ef_search.max(k);
426        let candidates = self.search_layer(query, &[entry], ef, 0, &get_vec)?;
427
428        Ok(candidates.into_iter().take(k).map(|(_, id)| id).collect())
429    }
430
431    /// Runs a beam search at one layer starting from `entries`, returning
432    /// the top-`ef` nearest nodes to `query` found, sorted by distance
433    /// ascending.
434    ///
435    /// This is the workhorse of both insert and search. The two priority
436    /// queues — "candidates" (nodes still to expand) and "results"
437    /// (current best ef found) — terminate when the closest unexpanded
438    /// candidate is farther than the worst kept result.
439    fn search_layer<F>(
440        &self,
441        query: &[f32],
442        entries: &[i64],
443        ef: usize,
444        layer: usize,
445        get_vec: &F,
446    ) -> Result<Vec<(f32, i64)>>
447    where
448        F: Fn(i64) -> Vec<f32>,
449    {
450        let mut visited: HashSet<i64> = HashSet::with_capacity(ef * 2);
451        // candidates: min-heap of (distance, id) — pop closest first.
452        let mut candidates: BinaryHeap<MinHeapItem> = BinaryHeap::with_capacity(ef * 2);
453        // results: max-heap of (distance, id) — top is the worst kept.
454        let mut results: BinaryHeap<MaxHeapItem> = BinaryHeap::with_capacity(ef);
455
456        for &id in entries {
457            if !visited.insert(id) {
458                continue;
459            }
460            let d = self.distance.compute(query, &get_vec(id))?;
461            candidates.push(MinHeapItem { dist: d, id });
462            results.push(MaxHeapItem { dist: d, id });
463        }
464
465        while let Some(MinHeapItem {
466            dist: c_dist,
467            id: c_id,
468        }) = candidates.pop()
469        {
470            // If the closest unexpanded candidate is worse than the
471            // worst kept result, no further expansion can improve the
472            // result set. Bail.
473            if let Some(worst) = results.peek() {
474                if results.len() >= ef && c_dist > worst.dist {
475                    break;
476                }
477            }
478
479            // Expand: visit each neighbor of c_id at this layer.
480            let neighbors = self
481                .nodes
482                .get(&c_id)
483                .and_then(|n| n.layers.get(layer))
484                .cloned()
485                .unwrap_or_default();
486            for nb in neighbors {
487                if !visited.insert(nb) {
488                    continue;
489                }
490                let d = self.distance.compute(query, &get_vec(nb))?;
491                let admit = if results.len() < ef {
492                    true
493                } else {
494                    d < results.peek().unwrap().dist
495                };
496                if admit {
497                    candidates.push(MinHeapItem { dist: d, id: nb });
498                    results.push(MaxHeapItem { dist: d, id: nb });
499                    if results.len() > ef {
500                        results.pop();
501                    }
502                }
503            }
504        }
505
506        // Drain results into a sorted vec. results is a max-heap, so
507        // popping gives descending order; reverse for ascending.
508        let mut out: Vec<(f32, i64)> = Vec::with_capacity(results.len());
509        while let Some(item) = results.pop() {
510            out.push((item.dist, item.id));
511        }
512        out.reverse();
513        Ok(out)
514    }
515
516    /// Picks a layer for a new node using the standard HNSW geometric
517    /// distribution: `L = floor(-ln(uniform) * m_l)`. With M=16, mL ≈ 0.36,
518    /// so:
519    ///   - P(L=0) ≈ 1 - 1/M = 15/16
520    ///   - P(L=1) ≈ 1/16 - 1/256
521    ///   - P(L=2) ≈ 1/256 - …
522    ///     i.e., most new nodes live only at layer 0; a few percolate up.
523    fn pick_layer(&mut self) -> usize {
524        let u = self.next_uniform().max(1e-6); // guard log(0)
525        let layer = (-u.ln() * self.params.m_l).floor() as usize;
526        // Cap at top_layer + 1 to keep the graph from sprouting empty
527        // layers above the current top — matches the original HNSW
528        // paper's recommendation.
529        layer.min(self.top_layer + 1)
530    }
531
532    /// Pulls a uniform-on-(0, 1] f32 from the internal xorshift state.
533    /// Top 24 bits of the next u64, divided by 2^24 — gives 24-bit
534    /// uniform precision, plenty for layer assignment.
535    fn next_uniform(&mut self) -> f32 {
536        let mut x = self.rng_state;
537        x ^= x << 13;
538        x ^= x >> 7;
539        x ^= x << 17;
540        self.rng_state = x;
541        ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
542    }
543}
544
// -----------------------------------------------------------------
// Heap items
//
// Rust's BinaryHeap is a max-heap ordered by `Ord`. f32 has no `Ord`
// (because of NaN), so (distance, id) pairs are wrapped in small structs
// with a hand-written *total* order:
//   - distances compare numerically;
//   - NaN (of either sign bit) compares greater than every non-NaN
//     distance, so it always ranks as the worst candidate;
//   - ties fall back to the id.
// `PartialEq` is defined through that same order, so `Eq` and `Ord` agree,
// as BinaryHeap requires.
//
// MinHeapItem inverts the comparison, so BinaryHeap<MinHeapItem> behaves
// as a min-heap: the top is the smallest distance, and popping yields
// ascending order.
//
// MaxHeapItem uses the natural ordering: the top is the largest distance.

/// Total order on heap distances: numeric for ordinary values. Every NaN
/// is greater than every non-NaN value, and all NaNs are equal to each
/// other.
fn cmp_heap_dist(a: f32, b: f32) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        // Non-NaN floats are totally ordered, so partial_cmp is always Some.
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (true, true) => Ordering::Equal,
    }
}

#[derive(Debug, Clone, Copy)]
struct MinHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MinHeapItem {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MinHeapItem {}
impl PartialOrd for MinHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MinHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so the smallest distance bubbles to the top.
        cmp_heap_dist(other.dist, self.dist).then(other.id.cmp(&self.id))
    }
}

#[derive(Debug, Clone, Copy)]
struct MaxHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MaxHeapItem {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxHeapItem {}
impl PartialOrd for MaxHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxHeapItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural order so the largest distance bubbles to the top.
        cmp_heap_dist(self.dist, other.dist).then(self.id.cmp(&other.id))
    }
}

613// -----------------------------------------------------------------
614// Tests
615// -----------------------------------------------------------------
616
617#[cfg(test)]
618mod tests {
619    use super::*;
620
621    /// Deterministic xorshift to generate test vectors.
622    fn random_vec(state: &mut u64, dim: usize) -> Vec<f32> {
623        (0..dim)
624            .map(|_| {
625                let mut x = *state;
626                x ^= x << 13;
627                x ^= x >> 7;
628                x ^= x << 17;
629                *state = x;
630                ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
631            })
632            .collect()
633    }
634
635    /// Brute-force nearest-neighbors baseline for recall comparison.
636    fn brute_force_topk(
637        vectors: &[Vec<f32>],
638        query: &[f32],
639        k: usize,
640        metric: DistanceMetric,
641    ) -> Vec<i64> {
642        let mut by_dist: Vec<(f32, i64)> = vectors
643            .iter()
644            .enumerate()
645            .map(|(i, v)| (metric.compute(query, v).expect("matching dims"), i as i64))
646            .collect();
647        by_dist.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal));
648        by_dist.into_iter().take(k).map(|(_, id)| id).collect()
649    }
650
651    /// recall@k — fraction of the brute-force top-k that the HNSW
652    /// search also returned (in any order).
653    fn recall_at_k(hnsw_result: &[i64], baseline: &[i64]) -> f32 {
654        let baseline_set: HashSet<i64> = baseline.iter().copied().collect();
655        let hits = hnsw_result
656            .iter()
657            .filter(|id| baseline_set.contains(id))
658            .count();
659        hits as f32 / baseline.len() as f32
660    }
661
662    #[test]
663    fn empty_index_returns_empty_search() {
664        let idx = HnswIndex::new(DistanceMetric::L2, 42);
665        let vectors: Vec<Vec<f32>> = vec![];
666        let result = idx
667            .search(&[0.0; 4], 5, |id| vectors[id as usize].clone())
668            .unwrap();
669        assert!(result.is_empty());
670    }
671
672    #[test]
673    fn single_node_returns_only_itself() {
674        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
675        let v0 = vec![1.0, 2.0, 3.0];
676        let vectors = [v0.clone()];
677        idx.insert(0, &v0, |id| vectors[id as usize].clone())
678            .unwrap();
679        let result = idx
680            .search(&[0.0; 3], 5, |id| vectors[id as usize].clone())
681            .unwrap();
682        assert_eq!(result, vec![0]);
683    }
684
685    #[test]
686    fn duplicate_insert_is_noop() {
687        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
688        let v0 = vec![1.0, 2.0];
689        let vectors = [v0.clone()];
690        idx.insert(0, &v0, |id| vectors[id as usize].clone())
691            .unwrap();
692        idx.insert(0, &v0, |id| vectors[id as usize].clone())
693            .unwrap();
694        assert_eq!(idx.len(), 1);
695    }
696
697    #[test]
698    fn distance_rejects_empty_or_mismatched_vectors() {
699        let empty_err = DistanceMetric::L2.compute(&[1.0, 2.0], &[]).unwrap_err();
700        assert!(format!("{empty_err}").contains("dimension mismatch"));
701
702        let mismatch_err = DistanceMetric::Cosine
703            .compute(&[1.0, 2.0], &[1.0])
704            .unwrap_err();
705        assert!(format!("{mismatch_err}").contains("left has 2, right has 1"));
706    }
707
708    #[test]
709    fn k_zero_returns_empty() {
710        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
711        let vectors = [vec![1.0, 0.0], vec![0.0, 1.0]];
712        for (i, v) in vectors.iter().enumerate() {
713            idx.insert(i as i64, v, |id| vectors[id as usize].clone())
714                .unwrap();
715        }
716        let result = idx
717            .search(&[0.5, 0.5], 0, |id| vectors[id as usize].clone())
718            .unwrap();
719        assert!(result.is_empty());
720    }
721
722    #[test]
723    fn small_graph_finds_exact_nearest() {
724        // 5 well-separated points in 2D — HNSW should find the exact
725        // nearest with no recall loss for k=1 and k=3.
726        let vectors: Vec<Vec<f32>> = vec![
727            vec![0.0, 0.0],
728            vec![10.0, 0.0],
729            vec![0.0, 10.0],
730            vec![10.0, 10.0],
731            vec![5.0, 5.0],
732        ];
733        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
734        for (i, v) in vectors.iter().enumerate() {
735            idx.insert(i as i64, v, |id| vectors[id as usize].clone())
736                .unwrap();
737        }
738
739        // Query at (1, 1): nearest is (0, 0).
740        let result = idx
741            .search(&[1.0, 1.0], 1, |id| vectors[id as usize].clone())
742            .unwrap();
743        assert_eq!(result, vec![0]);
744
745        // Query at (5.5, 5.5): top-3 should be id=4 (5,5), then any
746        // two of the corners at distance ~7.78.
747        let result = idx
748            .search(&[5.5, 5.5], 3, |id| vectors[id as usize].clone())
749            .unwrap();
750        assert_eq!(result.len(), 3);
751        assert_eq!(result[0], 4, "closest to (5.5,5.5) should be id=4");
752    }
753
754    #[test]
755    fn recall_at_10_is_high_on_random_vectors_l2() {
756        // Standard recall test: 1000 random vectors in 8D, query for
757        // top-10 with HNSW, compare to brute-force ground truth.
758        // Modern HNSW papers target recall@10 > 0.95; we should clear
759        // that comfortably on this small benchmark.
760        let mut state: u64 = 0xDEADBEEF;
761        let dim = 8;
762        let n = 1000;
763        let queries = 20;
764        let k = 10;
765
766        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
767
768        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
769        for (i, v) in vectors.iter().enumerate() {
770            idx.insert(i as i64, v, |id| vectors[id as usize].clone())
771                .unwrap();
772        }
773
774        let mut total_recall = 0.0f32;
775        for _ in 0..queries {
776            let q = random_vec(&mut state, dim);
777            let hnsw_top = idx
778                .search(&q, k, |id| vectors[id as usize].clone())
779                .unwrap();
780            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::L2);
781            total_recall += recall_at_k(&hnsw_top, &baseline);
782        }
783        let avg_recall = total_recall / queries as f32;
784        assert!(
785            avg_recall >= 0.95,
786            "recall@{k} dropped below 0.95: avg={avg_recall:.3}"
787        );
788    }
789
790    #[test]
791    fn recall_at_10_is_high_on_random_vectors_cosine() {
792        // Same shape as the L2 test but with cosine distance, to
793        // exercise the alternative metric through the same pipeline.
794        let mut state: u64 = 0xC0FFEE;
795        let dim = 16;
796        let n = 500;
797        let queries = 20;
798        let k = 10;
799
800        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
801
802        let mut idx = HnswIndex::new(DistanceMetric::Cosine, 42);
803        for (i, v) in vectors.iter().enumerate() {
804            idx.insert(i as i64, v, |id| vectors[id as usize].clone())
805                .unwrap();
806        }
807
808        let mut total_recall = 0.0f32;
809        for _ in 0..queries {
810            let q = random_vec(&mut state, dim);
811            let hnsw_top = idx
812                .search(&q, k, |id| vectors[id as usize].clone())
813                .unwrap();
814            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::Cosine);
815            total_recall += recall_at_k(&hnsw_top, &baseline);
816        }
817        let avg_recall = total_recall / queries as f32;
818        assert!(
819            avg_recall >= 0.95,
820            "cosine recall@{k} dropped below 0.95: avg={avg_recall:.3}"
821        );
822    }
823
824    #[test]
825    fn entry_point_promotes_when_higher_layer_node_inserted() {
826        // The graph's entry point should always be a node at the
827        // current top layer. Insert two nodes; if the second lands at
828        // a higher layer, it becomes the entry point.
829        // We can't easily force a particular layer (it's randomized),
830        // so check the invariant: after every insert, the entry node's
831        // max_layer == top_layer.
832        let mut state: u64 = 0xABCDEF;
833        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
834        let dim = 4;
835        let mut vectors: Vec<Vec<f32>> = Vec::new();
836        for i in 0..50 {
837            vectors.push(random_vec(&mut state, dim));
838            let v = vectors[i].clone();
839            idx.insert(i as i64, &v, |id| vectors[id as usize].clone())
840                .unwrap();
841
842            // Check invariant.
843            let entry = idx.entry_point.expect("non-empty");
844            let entry_max = idx.nodes[&entry].max_layer();
845            assert_eq!(
846                entry_max, idx.top_layer,
847                "entry-point invariant broken at step {i}: entry {entry} has max_layer {entry_max}, top_layer is {}",
848                idx.top_layer
849            );
850        }
851    }
852
853    #[test]
854    fn neighbor_lists_respect_m_max() {
855        // After inserting 200 points with M=16 (so M_max0 = 32), no
856        // node should have more than 32 neighbors at layer 0 or more
857        // than 16 at any higher layer.
858        let mut state: u64 = 0x123456;
859        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
860        let dim = 4;
861        let mut vectors: Vec<Vec<f32>> = Vec::new();
862        for i in 0..200 {
863            vectors.push(random_vec(&mut state, dim));
864            let v = vectors[i].clone();
865            idx.insert(i as i64, &v, |id| vectors[id as usize].clone())
866                .unwrap();
867        }
868
869        for (id, node) in &idx.nodes {
870            for (layer, neighbors) in node.layers.iter().enumerate() {
871                let cap = if layer == 0 {
872                    idx.params.m_max0
873                } else {
874                    idx.params.m
875                };
876                assert!(
877                    neighbors.len() <= cap,
878                    "node {id} layer {layer} has {} > cap {cap}",
879                    neighbors.len()
880                );
881            }
882        }
883    }
884
885    #[test]
886    fn deterministic_with_fixed_seed() {
887        // Same seed + same insert order → same graph topology.
888        // Catches accidental sources of nondeterminism (HashMap
889        // iteration order, etc.).
890        let mut state: u64 = 0x999;
891        let dim = 4;
892        let n = 50;
893        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
894
895        let mut idx_a = HnswIndex::new(DistanceMetric::L2, 42);
896        let mut idx_b = HnswIndex::new(DistanceMetric::L2, 42);
897        for (i, v) in vectors.iter().enumerate() {
898            idx_a
899                .insert(i as i64, v, |id| vectors[id as usize].clone())
900                .unwrap();
901            idx_b
902                .insert(i as i64, v, |id| vectors[id as usize].clone())
903                .unwrap();
904        }
905
906        // Same top layer.
907        assert_eq!(idx_a.top_layer, idx_b.top_layer);
908        // Same entry point.
909        assert_eq!(idx_a.entry_point, idx_b.entry_point);
910        // Same node count and same per-node max-layer for every id.
911        // (Neighbor list contents may differ trivially if HashMap
912        // iteration sneaked in; if this fails, fix the source first.)
913        assert_eq!(idx_a.nodes.len(), idx_b.nodes.len());
914        for (id, node_a) in &idx_a.nodes {
915            let node_b = idx_b.nodes.get(id).expect("missing id");
916            assert_eq!(node_a.max_layer(), node_b.max_layer(), "id={id}");
917        }
918    }
919}