Skip to main content

ext_vector/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) graph index.
2//!
3//! Implements approximate nearest neighbor search with:
4//! - Multi-layer navigable small world graph
5//! - Greedy search from entry point through layers
6//! - Beam search at target layer for recall
7//! - Simple heuristic neighbor selection
8
9use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::distance::DistanceMetric;
13
/// Neighbor entry: (distance, vector_id).
///
/// Ordering is by `dist` first and `id` second. Distances are compared with
/// [`f32::total_cmp`], so the order is total even when a distance is NaN and
/// it agrees with `PartialEq`, which compares the raw bit pattern of `dist`.
#[derive(Clone, Copy)]
struct Neighbor {
    /// Distance from the query to the vector identified by `id`.
    dist: f32,
    /// Id of the indexed vector.
    id: usize,
}

impl PartialEq for Neighbor {
    /// Two entries are equal when both the distance bits and the id match.
    fn eq(&self, other: &Self) -> bool {
        self.dist.to_bits() == other.dist.to_bits() && self.id == other.id
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    /// Delegates to [`Ord::cmp`]; the ordering is total so this never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Compares by distance, breaking ties by id.
    ///
    /// `total_cmp` is used instead of `partial_cmp(..).unwrap_or(Equal)`:
    /// the latter treats NaN as equal to every value (which is not a valid
    /// total order for heaps and sorts) and reports `0.0 == -0.0` even though
    /// `eq` distinguishes them.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
43
/// HNSW index configuration.
///
/// Consumed by `HnswIndex::new`; the values are copied into the index.
pub struct HnswConfig {
    /// Max connections per layer (default 16).
    ///
    /// Applied to every layer above layer 0; also determines the level
    /// multiplier `1 / ln(m)` computed when the index is created.
    pub m: usize,
    /// Max connections at layer 0 (default 2*M = 32).
    pub m_max0: usize,
    /// Beam width during construction (default 200).
    ///
    /// Passed as the candidate-list size to the per-layer search on insert.
    pub ef_construction: usize,
    /// Distance metric.
    pub metric: DistanceMetric,
}
55
56impl Default for HnswConfig {
57    fn default() -> Self {
58        Self {
59            m: 16,
60            m_max0: 32,
61            ef_construction: 200,
62            metric: DistanceMetric::L2,
63        }
64    }
65}
66
/// HNSW index for approximate nearest neighbor search.
///
/// Vectors are identified by their insertion order: the id returned by
/// `insert` is the position of the vector in `vectors`.
pub struct HnswIndex {
    /// Stored vectors, indexed by vector id.
    vectors: Vec<Vec<f32>>,
    /// neighbors[vector_id][layer] = vec of (neighbor_id, distance)
    neighbors: Vec<Vec<Vec<(usize, f32)>>>,
    /// Node where every search starts; `None` until the first insert.
    entry_point: Option<usize>,
    /// Highest layer currently in use (the level of the entry point).
    max_layer: usize,
    /// Max connections per node on layers above 0.
    m: usize,
    /// Max connections per node on layer 0.
    m_max0: usize,
    /// Beam width used by the per-layer search during insertion.
    ef_construction: usize,
    /// Required length of every inserted vector.
    dim: usize,
    /// Metric used to compare a query with a stored vector.
    metric: DistanceMetric,
    /// Level multiplier used when drawing a random level.
    ml: f64, // 1.0 / ln(M)
    /// State of the xorshift64 generator used by `random_level`.
    rng_state: u64,
}
82
83impl HnswIndex {
84    /// Create a new HNSW index.
85    pub fn new(dim: usize, config: HnswConfig) -> Self {
86        let ml = 1.0 / (config.m as f64).ln();
87        Self {
88            vectors: Vec::new(),
89            neighbors: Vec::new(),
90            entry_point: None,
91            max_layer: 0,
92            m: config.m,
93            m_max0: config.m_max0,
94            ef_construction: config.ef_construction,
95            dim,
96            metric: config.metric,
97            ml,
98            rng_state: 0x5DEECE66D, // seed
99        }
100    }
101
102    /// Number of indexed vectors.
103    pub fn len(&self) -> usize {
104        self.vectors.len()
105    }
106
107    /// Whether the index is empty.
108    pub fn is_empty(&self) -> bool {
109        self.vectors.is_empty()
110    }
111
112    /// Insert a vector. Returns its assigned id.
113    pub fn insert(&mut self, vector: &[f32]) -> usize {
114        assert_eq!(vector.len(), self.dim, "dimension mismatch");
115
116        let id = self.vectors.len();
117        self.vectors.push(vector.to_vec());
118
119        let level = self.random_level();
120
121        // Initialize neighbor lists for all layers up to `level`.
122        let mut layers = Vec::with_capacity(level + 1);
123        for _ in 0..=level {
124            layers.push(Vec::new());
125        }
126        self.neighbors.push(layers);
127
128        if self.entry_point.is_none() {
129            // First vector — just set it as entry point.
130            self.entry_point = Some(id);
131            self.max_layer = level;
132            return id;
133        }
134
135        let ep = self.entry_point.unwrap();
136        let mut current_ep = ep;
137
138        // Phase 1: Greedy descent from top layer to insertion layer + 1.
139        for layer in (level + 1..=self.max_layer).rev() {
140            current_ep = self.greedy_closest(vector, current_ep, layer);
141        }
142
143        // Phase 2: At each layer from min(level, max_layer) down to 0,
144        // find neighbors and connect.
145        let start_layer = level.min(self.max_layer);
146        for layer in (0..=start_layer).rev() {
147            let m_for_layer = if layer == 0 { self.m_max0 } else { self.m };
148
149            // Search for ef_construction nearest at this layer.
150            let candidates = self.search_layer(vector, current_ep, self.ef_construction, layer);
151
152            // Select M best neighbors.
153            let selected: Vec<(usize, f32)> = candidates
154                .into_iter()
155                .take(m_for_layer)
156                .map(|n| (n.id, n.dist))
157                .collect();
158
159            // Connect: id -> selected neighbors.
160            self.neighbors[id][layer] = selected.clone();
161
162            // Reverse connections: selected neighbors -> id.
163            for &(neighbor_id, dist) in &selected {
164                if neighbor_id < self.neighbors.len() && layer < self.neighbors[neighbor_id].len() {
165                    self.neighbors[neighbor_id][layer].push((id, dist));
166                    // Prune if too many connections.
167                    if self.neighbors[neighbor_id][layer].len() > m_for_layer {
168                        self.neighbors[neighbor_id][layer].sort_by(|a, b| {
169                            a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
170                        });
171                        self.neighbors[neighbor_id][layer].truncate(m_for_layer);
172                    }
173                }
174            }
175
176            // Update entry point for next layer.
177            if !selected.is_empty() {
178                current_ep = selected[0].0;
179            }
180        }
181
182        // Update entry point if new vector has higher level.
183        if level > self.max_layer {
184            self.entry_point = Some(id);
185            self.max_layer = level;
186        }
187
188        id
189    }
190
191    /// Search for k approximate nearest neighbors.
192    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
193        if self.entry_point.is_none() {
194            return Vec::new();
195        }
196
197        let mut ep = self.entry_point.unwrap();
198
199        // Greedy descent through upper layers.
200        for layer in (1..=self.max_layer).rev() {
201            ep = self.greedy_closest(query, ep, layer);
202        }
203
204        // Beam search at layer 0.
205        let ef = ef.max(k);
206        let candidates = self.search_layer(query, ep, ef, 0);
207
208        candidates
209            .into_iter()
210            .take(k)
211            .map(|n| (n.id, n.dist))
212            .collect()
213    }
214
215    /// Greedy search: find the single closest node to `query` at `layer`.
216    fn greedy_closest(&self, query: &[f32], start: usize, layer: usize) -> usize {
217        let mut best = start;
218        let mut best_dist = self.distance(query, best);
219
220        loop {
221            let mut changed = false;
222            if layer < self.neighbors[best].len() {
223                for &(neighbor, _) in &self.neighbors[best][layer] {
224                    let d = self.distance(query, neighbor);
225                    if d < best_dist {
226                        best_dist = d;
227                        best = neighbor;
228                        changed = true;
229                    }
230                }
231            }
232            if !changed {
233                break;
234            }
235        }
236
237        best
238    }
239
240    /// Beam search at a specific layer. Returns neighbors sorted by distance (nearest first).
241    fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<Neighbor> {
242        let ep_dist = self.distance(query, ep);
243
244        // Candidates: min-heap (nearest first for expansion).
245        let mut candidates: BinaryHeap<Reverse<Neighbor>> = BinaryHeap::new();
246        // Result set: max-heap (furthest first for pruning).
247        let mut result: BinaryHeap<Neighbor> = BinaryHeap::new();
248        let mut visited = vec![false; self.vectors.len()];
249
250        let ep_neighbor = Neighbor {
251            dist: ep_dist,
252            id: ep,
253        };
254        candidates.push(Reverse(ep_neighbor));
255        result.push(ep_neighbor);
256        visited[ep] = true;
257
258        while let Some(Reverse(current)) = candidates.pop() {
259            // If nearest candidate is further than the furthest result, stop.
260            if result.peek().is_some_and(|f| current.dist > f.dist) {
261                break;
262            }
263
264            // Expand neighbors at this layer.
265            if layer < self.neighbors[current.id].len() {
266                for &(neighbor_id, _) in &self.neighbors[current.id][layer] {
267                    if visited[neighbor_id] {
268                        continue;
269                    }
270                    visited[neighbor_id] = true;
271
272                    let d = self.distance(query, neighbor_id);
273                    let n = Neighbor {
274                        dist: d,
275                        id: neighbor_id,
276                    };
277
278                    let should_add = result.len() < ef || result.peek().is_some_and(|f| d < f.dist);
279
280                    if should_add {
281                        candidates.push(Reverse(n));
282                        result.push(n);
283                        if result.len() > ef {
284                            result.pop(); // Remove furthest.
285                        }
286                    }
287                }
288            }
289        }
290
291        // Drain result heap into sorted vec (nearest first).
292        let mut sorted: Vec<Neighbor> = result.into_vec();
293        sorted.sort();
294        sorted
295    }
296
297    /// Compute distance between query and indexed vector.
298    #[inline]
299    fn distance(&self, query: &[f32], id: usize) -> f32 {
300        self.metric.distance(query, &self.vectors[id])
301    }
302
303    /// Assign a random level for a new vector.
304    fn random_level(&mut self) -> usize {
305        // xorshift64
306        let mut x = self.rng_state;
307        x ^= x << 13;
308        x ^= x >> 7;
309        x ^= x << 17;
310        self.rng_state = x;
311
312        let r = (x as f64) / (u64::MAX as f64);
313        let level = (-r.ln() * self.ml) as usize;
314        level.min(16) // Cap at 16 layers.
315    }
316}
317
318#[cfg(test)]
319mod tests {
320    use super::*;
321
322    fn make_index(dim: usize) -> HnswIndex {
323        HnswIndex::new(dim, HnswConfig::default())
324    }
325
326    #[test]
327    fn empty_search() {
328        let idx = make_index(3);
329        let results = idx.search(&[1.0, 0.0, 0.0], 5, 50);
330        assert!(results.is_empty());
331    }
332
333    #[test]
334    fn single_vector() {
335        let mut idx = make_index(3);
336        let id = idx.insert(&[1.0, 2.0, 3.0]);
337        assert_eq!(id, 0);
338
339        let results = idx.search(&[1.0, 2.0, 3.0], 1, 50);
340        assert_eq!(results.len(), 1);
341        assert_eq!(results[0].0, 0);
342        assert!(results[0].1 < 1e-6); // distance ~0
343    }
344
345    #[test]
346    fn exact_knn_small() {
347        let mut idx = make_index(2);
348
349        // Insert 10 known vectors.
350        let points: Vec<[f32; 2]> = (0..10).map(|i| [i as f32, 0.0]).collect();
351
352        for p in &points {
353            idx.insert(p);
354        }
355
356        // Query at [5.0, 0.0] — nearest should be point 5.
357        let results = idx.search(&[5.0, 0.0], 3, 50);
358        assert!(!results.is_empty());
359        assert_eq!(results[0].0, 5); // exact nearest
360    }
361
362    #[test]
363    fn recall_100_vectors() {
364        let dim = 16;
365        let n = 100;
366        let mut idx = HnswIndex::new(
367            dim,
368            HnswConfig {
369                m: 16,
370                m_max0: 32,
371                ef_construction: 100,
372                metric: DistanceMetric::L2,
373            },
374        );
375
376        // Generate deterministic vectors.
377        let vectors: Vec<Vec<f32>> = (0..n)
378            .map(|i| {
379                (0..dim)
380                    .map(|d| ((i * 7 + d * 13) % 100) as f32 / 100.0)
381                    .collect()
382            })
383            .collect();
384
385        for v in &vectors {
386            idx.insert(v);
387        }
388
389        // Query with first vector, k=10.
390        let results = idx.search(&vectors[0], 10, 100);
391        assert!(!results.is_empty());
392        assert_eq!(results[0].0, 0); // should find itself
393
394        // Compute brute-force top-10 for recall check.
395        let mut brute: Vec<(usize, f32)> = vectors
396            .iter()
397            .enumerate()
398            .map(|(i, v)| (i, DistanceMetric::L2.distance(&vectors[0], v)))
399            .collect();
400        brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
401        let brute_top10: Vec<usize> = brute.iter().take(10).map(|r| r.0).collect();
402
403        // Check recall@10.
404        let hnsw_top10: Vec<usize> = results.iter().take(10).map(|r| r.0).collect();
405        let hits: usize = hnsw_top10
406            .iter()
407            .filter(|id| brute_top10.contains(id))
408            .count();
409        let recall = hits as f64 / 10.0;
410        assert!(recall >= 0.7, "recall@10 = {recall}, expected >= 0.7");
411    }
412
413    #[test]
414    fn cosine_metric() {
415        let mut idx = HnswIndex::new(
416            3,
417            HnswConfig {
418                metric: DistanceMetric::Cosine,
419                ..HnswConfig::default()
420            },
421        );
422
423        idx.insert(&[1.0, 0.0, 0.0]);
424        idx.insert(&[0.0, 1.0, 0.0]);
425        idx.insert(&[0.9, 0.1, 0.0]); // close to first
426
427        let results = idx.search(&[1.0, 0.0, 0.0], 3, 50);
428        assert_eq!(results.len(), 3);
429        // First result should be vector 0 (identical).
430        assert_eq!(results[0].0, 0);
431        assert!(results[0].1 < 1e-5);
432    }
433
434    #[test]
435    fn insert_respects_dimension() {
436        let mut idx = make_index(4);
437        idx.insert(&[1.0, 2.0, 3.0, 4.0]);
438        assert_eq!(idx.len(), 1);
439    }
440
441    #[test]
442    #[should_panic(expected = "dimension mismatch")]
443    fn dimension_mismatch_panics() {
444        let mut idx = make_index(4);
445        idx.insert(&[1.0, 2.0]); // wrong dimension
446    }
447
448    #[test]
449    fn level_distribution() {
450        let mut idx = make_index(2);
451        let mut max_level = 0;
452        for i in 0..1000 {
453            idx.insert(&[i as f32, 0.0]);
454            max_level = max_level.max(idx.max_layer);
455        }
456        // With M=16, ml=1/ln(16)≈0.36, expected max level for 1000 vectors ≈ 2-3.
457        assert!(max_level <= 8, "max_level = {max_level}, unexpectedly high");
458    }
459}