Skip to main content

oxirs_vec/
hnsw_builder.rs

1//! HNSW (Hierarchical Navigable Small World) graph construction for ANN search.
2//!
3//! This module provides a pure-Rust implementation of the HNSW algorithm for
4//! approximate nearest-neighbour (ANN) search in high-dimensional vector spaces.
5//!
6//! Random number generation uses `scirs2_core::random` (never `rand` directly).
7
8use scirs2_core::random::{Random, Rng, StdRng};
9#[cfg(test)]
10use std::collections::HashMap;
11use std::collections::{BinaryHeap, HashSet};
12
13// ─────────────────────────────────────────────────
14// Distance helpers
15// ─────────────────────────────────────────────────
16
/// Euclidean (L2) distance between two slices.
///
/// The square root is applied, so the returned value is the true distance,
/// not the squared distance.
///
/// The slices are expected to have the same length. Elements are paired with
/// `zip`, so if the lengths differ the surplus elements of the longer slice
/// are silently ignored.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}
25
/// Cosine similarity of two slices, clamped to `[0, 1]`.
///
/// `1` means the vectors point in the same direction and `0` means they are
/// orthogonal. Because of the clamp, vectors pointing in opposing directions
/// (negative cosine) also yield `0`. If either vector has zero norm the
/// result is `0`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);

    // A zero-length vector has no direction; report "no similarity".
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let cosine = dot / (norm_a * norm_b);
    cosine.clamp(0.0, 1.0)
}
37
/// Map a uniform random sample to the top layer of a newly inserted node.
///
/// The level is `floor(-ln(rng_val) * m_l)`, where `m_l` is the level
/// multiplier (typically `1 / ln(M)`). `rng_val` is expected to be drawn from
/// `[0, 1)`; a non-positive sample yields level `0`.
pub fn random_level(m_l: f64, rng_val: f64) -> usize {
    match rng_val <= 0.0 {
        // `ln` is undefined here, so fall back to the bottom layer.
        true => 0,
        false => {
            let neg_log = -(rng_val.ln());
            (neg_log * m_l).floor() as usize
        }
    }
}
48
49// ─────────────────────────────────────────────────
50// HnswConfig
51// ─────────────────────────────────────────────────
52
/// Configuration parameters for the HNSW graph.
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Target number of connections per layer for new nodes.
    pub m: usize,
    /// Maximum number of connections per layer.
    pub m_max: usize,
    /// Size of the dynamic candidate list used during construction.
    pub ef_construction: usize,
    /// Level multiplier: `1 / ln(M)`.
    pub m_l: f64,
}

impl Default for HnswConfig {
    /// `m = 16`, `m_max = 32`, `ef_construction = 200`, `m_l = 1 / ln(16)`.
    fn default() -> Self {
        // Delegate to `new` so the derived fields cannot drift apart.
        Self::new(16, 200)
    }
}

impl HnswConfig {
    /// Build a configuration from `m` and `ef_construction`.
    ///
    /// `m_max` is set to `2 * m`, saturating at `usize::MAX` instead of
    /// overflowing. `m_l` is `1 / ln(max(m, 2))`; the floor of 2 keeps the
    /// logarithm strictly positive so `m_l` is always finite.
    pub fn new(m: usize, ef_construction: usize) -> Self {
        HnswConfig {
            m,
            m_max: m.saturating_mul(2),
            ef_construction,
            m_l: 1.0 / (m.max(2) as f64).ln(),
        }
    }
}
89
90// ─────────────────────────────────────────────────
91// HnswNode
92// ─────────────────────────────────────────────────
93
/// A single vertex of the HNSW graph.
///
/// `connections[layer]` lists the node's neighbours at that layer. As
/// populated by `HnswGraph::insert`, each entry is the neighbour's position in
/// the graph's `nodes` vector, which is not necessarily its `id`.
#[derive(Debug, Clone)]
pub struct HnswNode {
    /// Caller-supplied identifier.
    pub id: usize,
    /// The stored vector.
    pub vector: Vec<f32>,
    /// Per-layer adjacency lists; index 0 is the bottom layer and higher
    /// indices are upper layers.
    pub connections: Vec<Vec<usize>>,
}

impl HnswNode {
    /// Create a node with empty adjacency lists for layers `0..=max_layer`.
    fn new(id: usize, vector: Vec<f32>, max_layer: usize) -> Self {
        let layer_slots = max_layer + 1;
        let connections = (0..layer_slots).map(|_| Vec::new()).collect();
        Self {
            id,
            vector,
            connections,
        }
    }

    /// Grow `connections` so that index `layers` is valid; never shrinks.
    fn ensure_layers(&mut self, layers: usize) {
        if self.connections.len() <= layers {
            self.connections.resize_with(layers + 1, Vec::new);
        }
    }
}
120
121// ─────────────────────────────────────────────────
122// HnswGraph
123// ─────────────────────────────────────────────────
124
125/// The HNSW graph: a collection of nodes with hierarchical connectivity.
126pub struct HnswGraph {
127    pub nodes: Vec<HnswNode>,
128    pub entry_point: Option<usize>,
129    pub max_layer: usize,
130    config: HnswConfig,
131    /// Seeded RNG for reproducible level assignment.
132    rng: StdRng,
133}
134
135impl HnswGraph {
136    /// Create a new, empty HNSW graph with the given configuration.
137    pub fn new(config: HnswConfig) -> Self {
138        HnswGraph {
139            nodes: Vec::new(),
140            entry_point: None,
141            max_layer: 0,
142            config,
143            rng: Random::seed(42),
144        }
145    }
146
147    /// Insert a vector with the given `id` into the graph.
148    ///
149    /// Uses the seeded RNG for level assignment, guaranteeing reproducibility.
150    pub fn insert(&mut self, id: usize, vector: Vec<f32>) {
151        let rng_val: f64 = self.rng.random::<f64>();
152        let node_layer = random_level(self.config.m_l, rng_val);
153
154        let mut node = HnswNode::new(id, vector.clone(), node_layer);
155        node.ensure_layers(node_layer);
156
157        let node_idx = self.nodes.len();
158
159        match self.entry_point {
160            None => {
161                // First node becomes the entry point
162                self.entry_point = Some(node_idx);
163                self.max_layer = node_layer;
164                self.nodes.push(node);
165                return;
166            }
167            Some(ep) => {
168                // Greedy descent from the current entry point down to node_layer+1
169                let mut ep_idx = ep;
170                let current_top = self.max_layer;
171
172                if current_top > node_layer {
173                    for lc in (node_layer + 1..=current_top).rev() {
174                        ep_idx = self.greedy_search_layer(ep_idx, &vector, lc);
175                    }
176                }
177
178                // Insert connections at each layer from node_layer down to 0
179                for lc in (0..=node_layer.min(current_top)).rev() {
180                    let candidates =
181                        self.search_layer_ef(ep_idx, &vector, self.config.ef_construction, lc);
182                    let neighbours = self.select_neighbours(&candidates, self.config.m);
183
184                    // Connect new node to its neighbours
185                    for &nb_idx in &neighbours {
186                        if nb_idx < self.nodes.len() {
187                            self.nodes[nb_idx].ensure_layers(lc);
188                            self.nodes[nb_idx].connections[lc].push(node_idx);
189                            // Prune if over m_max
190                            let nb_vec = self.nodes[nb_idx].vector.clone();
191                            self.shrink_connections(nb_idx, lc, &nb_vec);
192                        }
193                    }
194                    node.ensure_layers(lc);
195                    node.connections[lc] = neighbours.clone();
196
197                    // Update entry point for next layer
198                    if !candidates.is_empty() {
199                        ep_idx = candidates[0].0;
200                    }
201                }
202
203                // If this node has higher layers, extend the graph
204                if node_layer > current_top {
205                    self.max_layer = node_layer;
206                    self.entry_point = Some(node_idx);
207                }
208            }
209        }
210        self.nodes.push(node);
211    }
212
213    /// Search for the `k` nearest neighbours to `query` using the greedy beam search
214    /// with candidate list size `ef`.
215    ///
216    /// Returns `(id, distance)` pairs sorted by ascending distance.
217    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
218        if self.nodes.is_empty() {
219            return Vec::new();
220        }
221        let ep = match self.entry_point {
222            Some(e) => e,
223            None => return Vec::new(),
224        };
225
226        let mut ep_idx = ep;
227        // Greedy descent from max_layer down to layer 1
228        for lc in (1..=self.max_layer).rev() {
229            ep_idx = self.greedy_search_layer(ep_idx, query, lc);
230        }
231
232        // Full beam search at layer 0
233        let candidates = self.search_layer_ef(ep_idx, query, ef.max(k), 0);
234
235        // Return top-k by ascending distance
236        let mut results: Vec<(usize, f32)> = candidates
237            .iter()
238            .take(k)
239            .map(|&(idx, dist)| (self.nodes[idx].id, dist))
240            .collect();
241        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
242        results.truncate(k);
243        results
244    }
245
246    /// Total number of nodes in the graph.
247    pub fn node_count(&self) -> usize {
248        self.nodes.len()
249    }
250
251    /// Number of layers in the graph (= max_layer + 1).
252    pub fn layer_count(&self) -> usize {
253        if self.nodes.is_empty() {
254            0
255        } else {
256            self.max_layer + 1
257        }
258    }
259
260    /// Return the connections of node `id` at `layer`, if they exist.
261    pub fn connections_at(&self, id: usize, layer: usize) -> Option<&Vec<usize>> {
262        let node = self.nodes.iter().find(|n| n.id == id)?;
263        node.connections.get(layer)
264    }
265
266    // ── Private helpers ────────────────────────────────────────────────
267
268    /// Single-step greedy search at a given layer: starting from `ep_idx`,
269    /// repeatedly move to the neighbour closest to `query`.
270    fn greedy_search_layer(&self, mut ep_idx: usize, query: &[f32], layer: usize) -> usize {
271        let mut best_dist = euclidean_distance(&self.nodes[ep_idx].vector, query);
272        loop {
273            let mut improved = false;
274            let conns: Vec<usize> = if layer < self.nodes[ep_idx].connections.len() {
275                self.nodes[ep_idx].connections[layer].clone()
276            } else {
277                Vec::new()
278            };
279            for nb_idx in conns {
280                if nb_idx < self.nodes.len() {
281                    let d = euclidean_distance(&self.nodes[nb_idx].vector, query);
282                    if d < best_dist {
283                        best_dist = d;
284                        ep_idx = nb_idx;
285                        improved = true;
286                    }
287                }
288            }
289            if !improved {
290                break;
291            }
292        }
293        ep_idx
294    }
295
296    /// Beam search at `layer` with candidate list of size `ef`.
297    /// Returns (node_idx, distance) sorted by ascending distance (nearest first).
298    fn search_layer_ef(
299        &self,
300        ep_idx: usize,
301        query: &[f32],
302        ef: usize,
303        layer: usize,
304    ) -> Vec<(usize, f32)> {
305        // We use two priority queues: candidates (min-heap by dist) and
306        // dynamic_list (max-heap by dist to track the worst in the result set).
307        // For simplicity we use a BTreeMap-like structure via a sorted vec.
308
309        let ep_dist = euclidean_distance(&self.nodes[ep_idx].vector, query);
310
311        // candidates: (dist, idx) – we want to process closest first
312        // result: the ef nearest found so far
313        let mut candidates: BinaryHeap<OrdPair> = BinaryHeap::new();
314        let mut result: BinaryHeap<OrdPair> = BinaryHeap::new(); // max-heap (worst at top)
315        let mut visited: HashSet<usize> = HashSet::new();
316
317        candidates.push(OrdPair(ep_dist, ep_idx));
318        result.push(OrdPair(ep_dist, ep_idx));
319        visited.insert(ep_idx);
320
321        while let Some(OrdPair(dist, idx)) = pop_min(&mut candidates) {
322            // If the worst in result is better than the current candidate, stop
323            if let Some(OrdPair(worst_dist, _)) = result.peek() {
324                if dist > *worst_dist && result.len() >= ef {
325                    break;
326                }
327            }
328            // Expand neighbours at this layer
329            let conns: Vec<usize> = if layer < self.nodes[idx].connections.len() {
330                self.nodes[idx].connections[layer].clone()
331            } else {
332                Vec::new()
333            };
334            for nb_idx in conns {
335                if nb_idx >= self.nodes.len() || visited.contains(&nb_idx) {
336                    continue;
337                }
338                visited.insert(nb_idx);
339                let d = euclidean_distance(&self.nodes[nb_idx].vector, query);
340                // Add to candidates and result if better than worst
341                let add = result.len() < ef || d < result.peek().map_or(f32::MAX, |p| p.0);
342                if add {
343                    candidates.push(OrdPair(d, nb_idx));
344                    result.push(OrdPair(d, nb_idx));
345                    // Prune result to ef elements
346                    while result.len() > ef {
347                        result.pop();
348                    }
349                }
350            }
351        }
352
353        // Collect result into sorted vec (ascending distance)
354        let mut out: Vec<(usize, f32)> = result.into_iter().map(|p| (p.1, p.0)).collect();
355        out.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
356        out
357    }
358
359    /// Select the best `m` neighbours from a sorted candidate list.
360    fn select_neighbours(&self, candidates: &[(usize, f32)], m: usize) -> Vec<usize> {
361        candidates.iter().take(m).map(|&(idx, _)| idx).collect()
362    }
363
364    /// Prune node's connections at `layer` to at most `m_max` elements.
365    fn shrink_connections(&mut self, node_idx: usize, layer: usize, node_vec: &[f32]) {
366        if layer >= self.nodes[node_idx].connections.len() {
367            return;
368        }
369        let m_max = self.config.m_max;
370        if self.nodes[node_idx].connections[layer].len() <= m_max {
371            return;
372        }
373        // Keep the m_max nearest connections by distance to node_vec
374        let mut conn_dists: Vec<(usize, f32)> = self.nodes[node_idx].connections[layer]
375            .iter()
376            .filter_map(|&nb| {
377                if nb < self.nodes.len() {
378                    let d = euclidean_distance(&self.nodes[nb].vector, node_vec);
379                    Some((nb, d))
380                } else {
381                    None
382                }
383            })
384            .collect();
385        conn_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
386        conn_dists.truncate(m_max);
387        self.nodes[node_idx].connections[layer] =
388            conn_dists.into_iter().map(|(nb, _)| nb).collect();
389    }
390}
391
392// ─────────────────────────────────────────────────
393// Internal helpers
394// ─────────────────────────────────────────────────
395
/// `(distance, node index)` pair usable as a `BinaryHeap` element.
///
/// `BinaryHeap` is a max-heap, so the comparison on the distance is reversed:
/// a smaller distance compares as greater and is popped first, giving
/// min-heap behaviour. Ties on distance are broken by the index, the larger
/// index comparing as greater.
///
/// A NaN distance is ordered as farther than every other distance, and two
/// NaN distances tie. This keeps the ordering total, which `BinaryHeap`
/// relies on, and `==` is defined through the same comparison so `Eq` and
/// `Ord` always agree.
#[derive(Debug, Clone)]
struct OrdPair(f32, usize);

impl PartialEq for OrdPair {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdPair {}

impl PartialOrd for OrdPair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdPair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The operands are swapped on purpose: this reverses the distance
        // order so the heap's maximum is the nearest node. `partial_cmp` only
        // fails when a NaN is involved; in that case the NaN side counts as
        // the larger distance.
        other
            .0
            .partial_cmp(&self.0)
            .unwrap_or_else(|| other.0.is_nan().cmp(&self.0.is_nan()))
            .then(self.1.cmp(&other.1))
    }
}
419
420/// Pop the element with the minimum distance from a max-heap of OrdPair.
421/// We store (dist, idx) where the Ord impl makes the max-heap behave as a min-heap.
422fn pop_min(heap: &mut BinaryHeap<OrdPair>) -> Option<OrdPair> {
423    heap.pop()
424}
425
426// ─────────────────────────────────────────────────
427// Tests
428// ─────────────────────────────────────────────────
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a two-dimensional point.
    fn vec2(x: f32, y: f32) -> Vec<f32> {
        vec![x, y]
    }

    /// Build a graph from `points`, using each point's position as its id.
    fn graph_with(cfg: HnswConfig, points: impl IntoIterator<Item = Vec<f32>>) -> HnswGraph {
        let mut graph = HnswGraph::new(cfg);
        for (id, point) in points.into_iter().enumerate() {
            graph.insert(id, point);
        }
        graph
    }

    // ── Distance functions ─────────────────────────────────────

    #[test]
    fn test_euclidean_distance_zero() {
        let p = vec![1.0_f32, 2.0, 3.0];
        assert_eq!(euclidean_distance(&p, &p), 0.0);
    }

    #[test]
    fn test_euclidean_distance_unit() {
        let origin = vec![0.0_f32, 0.0];
        let corner = vec![3.0_f32, 4.0];
        let gap = euclidean_distance(&origin, &corner) - 5.0;
        assert!(gap.abs() < 1e-5);
    }

    #[test]
    fn test_euclidean_distance_symmetric() {
        let p = vec![1.0_f32, 2.0, 3.0];
        let q = vec![4.0_f32, 5.0, 6.0];
        let forward = euclidean_distance(&p, &q);
        let backward = euclidean_distance(&q, &p);
        assert!((forward - backward).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let unit_x = vec![1.0_f32, 0.0, 0.0];
        let sim = cosine_similarity(&unit_x, &unit_x);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let unit_x = vec![1.0_f32, 0.0];
        let unit_y = vec![0.0_f32, 1.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_range() {
        let p = vec![0.6_f32, 0.8];
        let q = vec![0.8_f32, 0.6];
        let sim = cosine_similarity(&p, &q);
        assert!((0.0..=1.0).contains(&sim));
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = vec![0.0_f32, 0.0];
        let unit_x = vec![1.0_f32, 0.0];
        assert_eq!(cosine_similarity(&zero, &unit_x), 0.0);
    }

    // ── random_level ──────────────────────────────────────────

    #[test]
    fn test_random_level_near_zero_returns_zero() {
        // A sample just below 1 has a logarithm close to 0, hence level 0.
        let m_l = 1.0 / (16.0_f64).ln();
        assert_eq!(random_level(m_l, 0.999), 0);
    }

    #[test]
    fn test_random_level_small_value_high_level() {
        // A very small sample maps to a high level.
        let m_l = 1.0 / (16.0_f64).ln();
        assert!(random_level(m_l, 1e-10) > 0);
    }

    #[test]
    fn test_random_level_distribution() {
        // Over many samples, level 0 should dominate.
        let m_l = 1.0 / (16.0_f64).ln();
        let mut rng = Random::seed(0);
        let mut histogram: HashMap<usize, usize> = HashMap::new();
        for _ in 0..1000 {
            let sample: f64 = rng.random::<f64>();
            *histogram.entry(random_level(m_l, sample)).or_insert(0) += 1;
        }
        let count_0 = histogram.get(&0).copied().unwrap_or(0);
        assert!(count_0 > 500, "Level 0 should dominate; got {count_0}");
    }

    // ── HnswConfig ────────────────────────────────────────────

    #[test]
    fn test_config_default_values() {
        let HnswConfig {
            m,
            m_max,
            ef_construction,
            m_l,
        } = HnswConfig::default();
        assert_eq!(m, 16);
        assert_eq!(m_max, 32);
        assert_eq!(ef_construction, 200);
        assert!(m_l > 0.0);
    }

    #[test]
    fn test_config_new() {
        let built = HnswConfig::new(8, 100);
        assert_eq!(built.m, 8);
        assert_eq!(built.m_max, 16);
        assert_eq!(built.ef_construction, 100);
    }

    // ── Insert single node ────────────────────────────────────

    #[test]
    fn test_insert_single_node_entry_point_set() {
        let graph = graph_with(HnswConfig::default(), vec![vec2(1.0, 0.0)]);
        assert_eq!(graph.entry_point, Some(0));
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_insert_single_node_layer_count() {
        let graph = graph_with(HnswConfig::default(), vec![vec2(0.0, 0.0)]);
        assert!(graph.layer_count() >= 1);
    }

    // ── Insert multiple nodes ─────────────────────────────────

    #[test]
    fn test_insert_multiple_increases_node_count() {
        let graph = graph_with(
            HnswConfig::default(),
            (0..10_u32).map(|i| vec![i as f32, 0.0]),
        );
        assert_eq!(graph.node_count(), 10);
    }

    #[test]
    fn test_entry_point_set_after_first_insert() {
        let mut graph = HnswGraph::new(HnswConfig::default());
        graph.insert(42, vec![1.0, 2.0]);
        assert!(graph.entry_point.is_some());
    }

    // ── Search ────────────────────────────────────────────────

    #[test]
    fn test_search_empty_graph_returns_empty() {
        let graph = HnswGraph::new(HnswConfig::default());
        assert!(graph.search(&[0.0, 0.0], 3, 10).is_empty());
    }

    #[test]
    fn test_search_single_node_returns_it() {
        let graph = graph_with(HnswConfig::new(4, 50), vec![vec2(1.0, 0.0)]);
        let hits = graph.search(&[1.0, 0.0], 1, 10);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn test_search_returns_at_most_k_results() {
        let graph = graph_with(
            HnswConfig::new(4, 50),
            (0..20_u32).map(|i| vec![i as f32, 0.0]),
        );
        let hits = graph.search(&[5.0, 0.0], 5, 20);
        assert!(hits.len() <= 5);
    }

    #[test]
    fn test_search_results_ordered_by_distance() {
        let graph = graph_with(
            HnswConfig::new(4, 50),
            (0..10_u32).map(|i| vec![i as f32, 0.0]),
        );
        let query = vec![4.5, 0.0];
        let results = graph.search(&query, 5, 20);
        // Distances must be non-decreasing.
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-5, "Results not sorted: {:?}", results);
        }
    }

    #[test]
    fn test_search_nearest_is_closest() {
        // Well-separated points.
        let graph = graph_with(
            HnswConfig::new(4, 50),
            vec![vec2(0.0, 0.0), vec2(100.0, 0.0), vec2(0.0, 100.0)],
        );
        let hits = graph.search(&[1.0, 1.0], 1, 10);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, 0); // closest to (0,0)
    }

    // ── layer_count grows ─────────────────────────────────────

    #[test]
    fn test_layer_count_non_zero_after_insert() {
        let graph = graph_with(HnswConfig::default(), vec![vec![1.0]]);
        assert!(graph.layer_count() >= 1);
    }

    #[test]
    fn test_layer_count_zero_when_empty() {
        let graph = HnswGraph::new(HnswConfig::default());
        assert_eq!(graph.layer_count(), 0);
    }

    // ── connections_at ────────────────────────────────────────

    #[test]
    fn test_connections_at_returns_none_for_unknown_id() {
        let graph = graph_with(HnswConfig::new(4, 50), vec![vec2(1.0, 0.0)]);
        // ID 99 was never inserted.
        assert!(graph.connections_at(99, 0).is_none());
    }

    #[test]
    fn test_connections_at_returns_some_for_inserted_node() {
        let graph = graph_with(HnswConfig::new(4, 50), vec![vec2(0.0, 0.0)]);
        // Layer 0 connections exist (may be empty for the sole node).
        assert!(graph.connections_at(0, 0).is_some());
    }

    // ── Exact search on small graph ───────────────────────────

    #[test]
    fn test_exact_search_3_nodes() {
        let graph = graph_with(
            HnswConfig::new(2, 20),
            vec![vec2(0.0, 0.0), vec2(1.0, 0.0), vec2(10.0, 0.0)],
        );

        // Query at (0.1, 0) — should find node 0 first.
        let hits = graph.search(&[0.1, 0.0], 3, 10);
        assert!(!hits.is_empty());
        // Node 0 should be the nearest or very close.
        let nearest = hits[0].0;
        assert!(
            nearest == 0 || nearest == 1,
            "Expected 0 or 1, got {nearest}"
        );
    }

    // ── HnswNode ──────────────────────────────────────────────

    #[test]
    fn test_hnsw_node_new() {
        let node = HnswNode::new(5, vec![1.0, 2.0], 2);
        assert_eq!(node.id, 5);
        assert_eq!(node.vector, vec![1.0, 2.0]);
        assert_eq!(node.connections.len(), 3); // layers 0,1,2
    }

    #[test]
    fn test_hnsw_node_ensure_layers() {
        let mut node = HnswNode::new(0, vec![1.0], 0);
        node.ensure_layers(3);
        assert!(node.connections.len() >= 4);
    }

    // ── Multiple searches give consistent results ─────────────

    #[test]
    fn test_search_reproducible() {
        let graph = graph_with(
            HnswConfig::new(4, 50),
            (0..15_u32).map(|i| vec![(i as f32) * 0.1, 0.0]),
        );
        let first = graph.search(&[0.5, 0.0], 3, 10);
        let second = graph.search(&[0.5, 0.0], 3, 10);
        assert_eq!(first.len(), second.len());
        for (lhs, rhs) in first.iter().zip(second.iter()) {
            assert_eq!(lhs.0, rhs.0);
        }
    }

    #[test]
    fn test_search_returns_k_or_fewer() {
        let graph = graph_with(HnswConfig::new(4, 50), (0..5_u32).map(|i| vec![i as f32]));
        let hits = graph.search(&[2.0], 10, 10);
        // Can't return more than n_nodes results.
        assert!(hits.len() <= 5);
    }

    #[test]
    fn test_distances_non_negative() {
        let graph = graph_with(
            HnswConfig::new(4, 50),
            (0..8_u32).map(|i| vec![i as f32, (8 - i) as f32]),
        );
        let hits = graph.search(&[4.0, 4.0], 5, 20);
        for (_, dist) in &hits {
            assert!(*dist >= 0.0);
        }
    }
}