Skip to main content

ext_vector/
hnsw.rs

//! HNSW (Hierarchical Navigable Small World) graph index.
//!
//! Implements approximate nearest neighbor search with:
//! - Multi-layer navigable small world graph
//! - Greedy search from entry point through layers
//! - Beam search at target layer for recall
//! - Simple neighbor selection (keeps the nearest candidates; no diversity heuristic)
8
9use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::distance::DistanceMetric;
13
/// Neighbor entry: (distance, vector_id).
///
/// Used as the element type of the search heaps, so its ordering must be a
/// genuine total order that agrees with its equality:
///
/// - entries are ordered by distance first, then by id;
/// - a NaN distance (of either sign) always sorts after every real distance,
///   so an undefined distance is treated as "furthest away";
/// - distances are otherwise compared with `f32::total_cmp`, which means
///   `-0.0` sorts before `0.0` and `cmp` returns `Equal` exactly when `eq`
///   (a bit-wise comparison) returns `true`.
#[derive(Clone, Copy)]
struct Neighbor {
    /// Distance between the query and the vector identified by `id`.
    dist: f32,
    /// Identifier of the vector this entry refers to.
    id: usize,
}

impl PartialEq for Neighbor {
    /// Bit-wise equality on the distance plus equality of the id.
    fn eq(&self, other: &Self) -> bool {
        self.dist.to_bits() == other.dist.to_bits() && self.id == other.id
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    /// Always `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Total order: non-NaN before NaN, then by distance, then by id.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `false < true`, so any NaN distance lands after all real distances
        // regardless of the NaN's sign bit.
        self.dist
            .is_nan()
            .cmp(&other.dist.is_nan())
            // Within each group `total_cmp` is total and is `Equal` only for
            // identical bit patterns, which keeps `Ord` consistent with `Eq`.
            .then_with(|| self.dist.total_cmp(&other.dist))
            .then_with(|| self.id.cmp(&other.id))
    }
}
43
44/// HNSW index configuration.
45pub struct HnswConfig {
46    /// Max connections per layer (default 16).
47    pub m: usize,
48    /// Max connections at layer 0 (default 2*M = 32).
49    pub m_max0: usize,
50    /// Beam width during construction (default 200).
51    pub ef_construction: usize,
52    /// Distance metric.
53    pub metric: DistanceMetric,
54}
55
56impl Default for HnswConfig {
57    fn default() -> Self {
58        Self {
59            m: 16,
60            m_max0: 32,
61            ef_construction: 200,
62            metric: DistanceMetric::L2,
63        }
64    }
65}
66
67/// HNSW index for approximate nearest neighbor search.
68pub struct HnswIndex {
69    vectors: Vec<Vec<f32>>,
70    /// neighbors[vector_id][layer] = vec of (neighbor_id, distance)
71    neighbors: Vec<Vec<Vec<(usize, f32)>>>,
72    entry_point: Option<usize>,
73    max_layer: usize,
74    m: usize,
75    m_max0: usize,
76    ef_construction: usize,
77    dim: usize,
78    metric: DistanceMetric,
79    ml: f64, // 1.0 / ln(M)
80    rng_state: u64,
81}
82
83impl HnswIndex {
84    /// Create a new HNSW index.
85    pub fn new(dim: usize, config: HnswConfig) -> Self {
86        let ml = 1.0 / (config.m as f64).ln();
87        Self {
88            vectors: Vec::new(),
89            neighbors: Vec::new(),
90            entry_point: None,
91            max_layer: 0,
92            m: config.m,
93            m_max0: config.m_max0,
94            ef_construction: config.ef_construction,
95            dim,
96            metric: config.metric,
97            ml,
98            rng_state: 0x5DEECE66D, // seed
99        }
100    }
101
102    /// Number of indexed vectors.
103    pub fn len(&self) -> usize {
104        self.vectors.len()
105    }
106
107    /// Whether the index is empty.
108    pub fn is_empty(&self) -> bool {
109        self.vectors.is_empty()
110    }
111
112    /// Insert a vector. Returns its assigned id.
113    pub fn insert(&mut self, vector: &[f32]) -> usize {
114        assert_eq!(vector.len(), self.dim, "dimension mismatch");
115
116        let id = self.vectors.len();
117        self.vectors.push(vector.to_vec());
118
119        let level = self.random_level();
120
121        // Initialize neighbor lists for all layers up to `level`.
122        let mut layers = Vec::with_capacity(level + 1);
123        for _ in 0..=level {
124            layers.push(Vec::new());
125        }
126        self.neighbors.push(layers);
127
128        if self.entry_point.is_none() {
129            // First vector — just set it as entry point.
130            self.entry_point = Some(id);
131            self.max_layer = level;
132            return id;
133        }
134
135        let ep = self.entry_point.unwrap();
136        let mut current_ep = ep;
137
138        // Phase 1: Greedy descent from top layer to insertion layer + 1.
139        for layer in (level + 1..=self.max_layer).rev() {
140            current_ep = self.greedy_closest(vector, current_ep, layer);
141        }
142
143        // Phase 2: At each layer from min(level, max_layer) down to 0,
144        // find neighbors and connect.
145        let start_layer = level.min(self.max_layer);
146        for layer in (0..=start_layer).rev() {
147            let m_for_layer = if layer == 0 { self.m_max0 } else { self.m };
148
149            // Search for ef_construction nearest at this layer.
150            let candidates = self.search_layer(vector, current_ep, self.ef_construction, layer);
151
152            // Select M best neighbors.
153            let selected: Vec<(usize, f32)> = candidates
154                .into_iter()
155                .take(m_for_layer)
156                .map(|n| (n.id, n.dist))
157                .collect();
158
159            // Connect: id -> selected neighbors.
160            self.neighbors[id][layer] = selected.clone();
161
162            // Reverse connections: selected neighbors -> id.
163            for &(neighbor_id, dist) in &selected {
164                if neighbor_id < self.neighbors.len()
165                    && layer < self.neighbors[neighbor_id].len()
166                {
167                    self.neighbors[neighbor_id][layer].push((id, dist));
168                    // Prune if too many connections.
169                    if self.neighbors[neighbor_id][layer].len() > m_for_layer {
170                        self.neighbors[neighbor_id][layer]
171                            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
172                        self.neighbors[neighbor_id][layer].truncate(m_for_layer);
173                    }
174                }
175            }
176
177            // Update entry point for next layer.
178            if !selected.is_empty() {
179                current_ep = selected[0].0;
180            }
181        }
182
183        // Update entry point if new vector has higher level.
184        if level > self.max_layer {
185            self.entry_point = Some(id);
186            self.max_layer = level;
187        }
188
189        id
190    }
191
192    /// Search for k approximate nearest neighbors.
193    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
194        if self.entry_point.is_none() {
195            return Vec::new();
196        }
197
198        let mut ep = self.entry_point.unwrap();
199
200        // Greedy descent through upper layers.
201        for layer in (1..=self.max_layer).rev() {
202            ep = self.greedy_closest(query, ep, layer);
203        }
204
205        // Beam search at layer 0.
206        let ef = ef.max(k);
207        let candidates = self.search_layer(query, ep, ef, 0);
208
209        candidates
210            .into_iter()
211            .take(k)
212            .map(|n| (n.id, n.dist))
213            .collect()
214    }
215
216    /// Greedy search: find the single closest node to `query` at `layer`.
217    fn greedy_closest(&self, query: &[f32], start: usize, layer: usize) -> usize {
218        let mut best = start;
219        let mut best_dist = self.distance(query, best);
220
221        loop {
222            let mut changed = false;
223            if layer < self.neighbors[best].len() {
224                for &(neighbor, _) in &self.neighbors[best][layer] {
225                    let d = self.distance(query, neighbor);
226                    if d < best_dist {
227                        best_dist = d;
228                        best = neighbor;
229                        changed = true;
230                    }
231                }
232            }
233            if !changed {
234                break;
235            }
236        }
237
238        best
239    }
240
241    /// Beam search at a specific layer. Returns neighbors sorted by distance (nearest first).
242    fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<Neighbor> {
243        let ep_dist = self.distance(query, ep);
244
245        // Candidates: min-heap (nearest first for expansion).
246        let mut candidates: BinaryHeap<Reverse<Neighbor>> = BinaryHeap::new();
247        // Result set: max-heap (furthest first for pruning).
248        let mut result: BinaryHeap<Neighbor> = BinaryHeap::new();
249        let mut visited = vec![false; self.vectors.len()];
250
251        let ep_neighbor = Neighbor { dist: ep_dist, id: ep };
252        candidates.push(Reverse(ep_neighbor));
253        result.push(ep_neighbor);
254        visited[ep] = true;
255
256        while let Some(Reverse(current)) = candidates.pop() {
257            // If nearest candidate is further than the furthest result, stop.
258            if result.peek().is_some_and(|f| current.dist > f.dist) {
259                break;
260            }
261
262            // Expand neighbors at this layer.
263            if layer < self.neighbors[current.id].len() {
264                for &(neighbor_id, _) in &self.neighbors[current.id][layer] {
265                    if visited[neighbor_id] {
266                        continue;
267                    }
268                    visited[neighbor_id] = true;
269
270                    let d = self.distance(query, neighbor_id);
271                    let n = Neighbor { dist: d, id: neighbor_id };
272
273                    let should_add = result.len() < ef
274                        || result.peek().is_some_and(|f| d < f.dist);
275
276                    if should_add {
277                        candidates.push(Reverse(n));
278                        result.push(n);
279                        if result.len() > ef {
280                            result.pop(); // Remove furthest.
281                        }
282                    }
283                }
284            }
285        }
286
287        // Drain result heap into sorted vec (nearest first).
288        let mut sorted: Vec<Neighbor> = result.into_vec();
289        sorted.sort();
290        sorted
291    }
292
293    /// Compute distance between query and indexed vector.
294    #[inline]
295    fn distance(&self, query: &[f32], id: usize) -> f32 {
296        self.metric.distance(query, &self.vectors[id])
297    }
298
299    /// Assign a random level for a new vector.
300    fn random_level(&mut self) -> usize {
301        // xorshift64
302        let mut x = self.rng_state;
303        x ^= x << 13;
304        x ^= x >> 7;
305        x ^= x << 17;
306        self.rng_state = x;
307
308        let r = (x as f64) / (u64::MAX as f64);
309        let level = (-r.ln() * self.ml) as usize;
310        level.min(16) // Cap at 16 layers.
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty index of the given dimension with the default configuration.
    fn make_index(dim: usize) -> HnswIndex {
        let config = HnswConfig::default();
        HnswIndex::new(dim, config)
    }

    /// Searching an empty index yields no results.
    #[test]
    fn empty_search() {
        let index = make_index(3);
        assert!(index.search(&[1.0, 0.0, 0.0], 5, 50).is_empty());
    }

    /// A lone vector gets id 0 and is returned for a query equal to itself.
    #[test]
    fn single_vector() {
        let mut index = make_index(3);
        assert_eq!(index.insert(&[1.0, 2.0, 3.0]), 0);

        let hits = index.search(&[1.0, 2.0, 3.0], 1, 50);
        assert_eq!(hits.len(), 1);

        let (hit_id, hit_dist) = hits[0];
        assert_eq!(hit_id, 0);
        assert!(hit_dist < 1e-6); // distance ~0
    }

    /// Ten points on a line: the query that coincides with point 5 must find it first.
    #[test]
    fn exact_knn_small() {
        let mut index = make_index(2);

        for x in 0..10 {
            index.insert(&[x as f32, 0.0]);
        }

        let hits = index.search(&[5.0, 0.0], 3, 50);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, 5); // exact nearest
    }

    /// recall@10 against a brute-force scan over 100 deterministic vectors.
    #[test]
    fn recall_100_vectors() {
        let dim = 16;
        let n = 100;
        let config = HnswConfig {
            m: 16,
            m_max0: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
        };
        let mut index = HnswIndex::new(dim, config);

        // Deterministic pseudo-data: component d of vector i.
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(n);
        for i in 0..n {
            let row = (0..dim)
                .map(|d| ((i * 7 + d * 13) % 100) as f32 / 100.0)
                .collect();
            vectors.push(row);
        }

        for row in &vectors {
            index.insert(row);
        }

        // The first vector is its own nearest neighbor.
        let query = &vectors[0];
        let hits = index.search(query, 10, 100);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, 0); // should find itself

        // Ground truth: the ten closest ids by exhaustive comparison.
        let mut exact: Vec<(usize, f32)> = Vec::with_capacity(n);
        for (i, row) in vectors.iter().enumerate() {
            exact.push((i, DistanceMetric::L2.distance(query, row)));
        }
        exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        let truth: Vec<usize> = exact.iter().take(10).map(|entry| entry.0).collect();

        // Count how many returned ids are part of the ground truth.
        let hits_in_truth = hits
            .iter()
            .take(10)
            .filter(|hit| truth.contains(&hit.0))
            .count();
        let recall = hits_in_truth as f64 / 10.0;
        assert!(recall >= 0.7, "recall@10 = {recall}, expected >= 0.7");
    }

    /// The index also works with the cosine metric.
    #[test]
    fn cosine_metric() {
        let config = HnswConfig {
            metric: DistanceMetric::Cosine,
            ..HnswConfig::default()
        };
        let mut index = HnswIndex::new(3, config);

        index.insert(&[1.0, 0.0, 0.0]);
        index.insert(&[0.0, 1.0, 0.0]);
        index.insert(&[0.9, 0.1, 0.0]); // close to first

        let hits = index.search(&[1.0, 0.0, 0.0], 3, 50);
        assert_eq!(hits.len(), 3);

        // The identical vector comes first, at (almost) zero distance.
        let (best_id, best_dist) = hits[0];
        assert_eq!(best_id, 0);
        assert!(best_dist < 1e-5);
    }

    /// A vector of the right length is accepted and counted.
    #[test]
    fn insert_respects_dimension() {
        let mut index = make_index(4);
        index.insert(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(index.len(), 1);
    }

    /// A vector of the wrong length makes `insert` panic.
    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn dimension_mismatch_panics() {
        let mut index = make_index(4);
        index.insert(&[1.0, 2.0]); // wrong dimension
    }

    /// The top layer stays low when inserting 1000 vectors.
    #[test]
    fn level_distribution() {
        let mut index = make_index(2);
        let mut max_level = 0usize;
        for i in 0..1000 {
            index.insert(&[i as f32, 0.0]);
            if index.max_layer > max_level {
                max_level = index.max_layer;
            }
        }
        // With M=16, ml=1/ln(16)≈0.36, expected max level for 1000 vectors ≈ 2-3.
        assert!(max_level <= 8, "max_level = {max_level}, unexpectedly high");
    }
}