Skip to main content

kora_vector/
hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) graph index.
2//!
3//! Implements the algorithm from Malkov & Yashunin (2018) for approximate
4//! nearest neighbor search with logarithmic query time and high recall.
5//!
6//! Key design decisions:
7//!
8//! - **Single-owner, no locking.** The index is meant to live inside one Kōra
9//!   shard worker and is accessed through `&mut self` / `&self`. No interior
10//!   mutability or atomics are needed.
11//! - **Lazy deletion.** [`HnswIndex::delete`] marks nodes as deleted without
12//!   removing them from the graph, keeping neighbour connectivity intact and
13//!   avoiding expensive edge repair.
14//! - **Deterministic level generation.** A simple xorshift64 PRNG seeded at
15//!   construction produces reproducible layer assignments, which simplifies
16//!   testing and benchmarking.
17
18use std::collections::{BinaryHeap, HashMap, HashSet};
19
20use ordered_float::OrderedFloat;
21
22use crate::distance::DistanceMetric;
23
/// An HNSW index for approximate nearest neighbor search.
///
/// Deleted vectors are tombstoned rather than removed (see `delete`), so
/// `nodes` can hold more entries than `len()` reports.
pub struct HnswIndex {
    /// Dimensionality every inserted and query vector must have.
    dim: usize,
    /// Distance function used for every comparison in the graph.
    metric: DistanceMetric,
    /// Neighbour-list capacity on layers above 0.
    m: usize,
    /// Neighbour-list capacity on layer 0 (`2 * m`, set in `new`).
    m_max0: usize,
    /// Candidate-list width used while searching for an inserted node's neighbours.
    ef_construction: usize,
    /// Level-generation factor `1 / ln(m)` used by `random_level`.
    ml: f64,
    /// All nodes keyed by ID, including lazily deleted ones.
    nodes: HashMap<u64, Node>,
    /// Node every search starts from; `None` once no live node remains.
    entry_point: Option<u64>,
    /// Top layer of the entry point (0 when there is no entry point).
    max_layer: usize,
    /// xorshift64 state driving level generation; seeded to 42 in `new`.
    rng_state: u64,
}
37
/// A graph vertex: one stored vector plus its per-layer adjacency lists.
struct Node {
    /// The ID this node is keyed under in `HnswIndex::nodes`.
    id: u64,
    /// Owned copy of the inserted vector.
    vector: Vec<f32>,
    /// Highest layer this node lives on; `neighbors` holds `layer + 1` lists.
    layer: usize,
    /// Per-layer neighbour lists, indexed by layer number.
    neighbors: Vec<Vec<u64>>,
    /// Tombstone set by `delete`. The node is still traversed during graph
    /// navigation; only `search` filters it out of results.
    deleted: bool,
}
46
/// A search result: (id, distance).
///
/// `HnswIndex::search` returns these sorted by ascending `distance`, so a
/// lower value means a closer match under the index's metric.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// The vector ID.
    pub id: u64,
    /// The distance to the query, as computed by the index's `DistanceMetric`.
    pub distance: f32,
}
55
56#[derive(PartialEq, Eq)]
57struct MinItem(OrderedFloat<f32>, u64);
58
59impl Ord for MinItem {
60    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
61        other.0.cmp(&self.0)
62    }
63}
64impl PartialOrd for MinItem {
65    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
66        Some(self.cmp(other))
67    }
68}
69
70#[derive(PartialEq, Eq)]
71struct MaxItem(OrderedFloat<f32>, u64);
72
73impl Ord for MaxItem {
74    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
75        self.0.cmp(&other.0)
76    }
77}
78impl PartialOrd for MaxItem {
79    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
80        Some(self.cmp(other))
81    }
82}
83
84impl HnswIndex {
85    /// Create a new empty HNSW index.
86    ///
87    /// - `dim`: vector dimensionality
88    /// - `metric`: distance metric to use
89    /// - `m`: max connections per node per layer (typical: 16)
90    /// - `ef_construction`: search width during construction (typical: 200)
91    pub fn new(dim: usize, metric: DistanceMetric, m: usize, ef_construction: usize) -> Self {
92        Self {
93            dim,
94            metric,
95            m,
96            m_max0: m * 2,
97            ef_construction,
98            ml: 1.0 / (m as f64).ln(),
99            nodes: HashMap::new(),
100            entry_point: None,
101            max_layer: 0,
102            rng_state: 42,
103        }
104    }
105
106    /// Get the number of vectors in the index.
107    pub fn len(&self) -> usize {
108        self.nodes.values().filter(|n| !n.deleted).count()
109    }
110
111    /// Check if the index is empty.
112    pub fn is_empty(&self) -> bool {
113        self.len() == 0
114    }
115
116    /// Get the dimensionality.
117    pub fn dim(&self) -> usize {
118        self.dim
119    }
120
121    /// Get the distance metric.
122    pub fn metric(&self) -> DistanceMetric {
123        self.metric
124    }
125
126    /// Insert a vector with the given ID.
127    ///
128    /// If a vector with this ID already exists, it is replaced.
129    pub fn insert(&mut self, id: u64, vector: &[f32]) {
130        assert_eq!(
131            vector.len(),
132            self.dim,
133            "vector dimension mismatch: expected {}, got {}",
134            self.dim,
135            vector.len()
136        );
137
138        if self.nodes.contains_key(&id) {
139            self.delete(id);
140        }
141
142        let level = self.random_level();
143        let vector = vector.to_vec();
144
145        let mut neighbors = Vec::with_capacity(level + 1);
146        for _ in 0..=level {
147            neighbors.push(Vec::new());
148        }
149
150        let node = Node {
151            id,
152            vector,
153            layer: level,
154            neighbors,
155            deleted: false,
156        };
157
158        self.nodes.insert(id, node);
159
160        if self.entry_point.is_none() {
161            self.entry_point = Some(id);
162            self.max_layer = level;
163            return;
164        }
165
166        let ep = match self.entry_point {
167            Some(ep) => ep,
168            None => return,
169        };
170
171        let mut current_ep = ep;
172        let query = &self.nodes[&id].vector.clone();
173
174        for lc in (level + 1..=self.max_layer).rev() {
175            current_ep = self.greedy_closest(query, current_ep, lc);
176        }
177
178        let insert_top = level.min(self.max_layer);
179        let mut ep_set = vec![current_ep];
180
181        for lc in (0..=insert_top).rev() {
182            let m_max = if lc == 0 { self.m_max0 } else { self.m };
183
184            let candidates = self.search_layer(query, &ep_set, self.ef_construction, lc);
185
186            let selected: Vec<u64> = candidates.iter().take(m_max).map(|&(_, nid)| nid).collect();
187
188            if let Some(node) = self.nodes.get_mut(&id) {
189                node.neighbors[lc] = selected.clone();
190            }
191
192            for &neighbor_id in &selected {
193                let needs_prune = {
194                    let Some(neighbor) = self.nodes.get_mut(&neighbor_id) else {
195                        continue;
196                    };
197                    if lc < neighbor.neighbors.len() {
198                        neighbor.neighbors[lc].push(id);
199                        neighbor.neighbors[lc].len() > m_max
200                    } else {
201                        false
202                    }
203                };
204
205                if needs_prune {
206                    let nv = self.nodes[&neighbor_id].vector.clone();
207                    let neighbor_ids: Vec<u64> = self.nodes[&neighbor_id].neighbors[lc].clone();
208                    let mut scored: Vec<(f32, u64)> = neighbor_ids
209                        .iter()
210                        .map(|&nid| {
211                            let dist = self.metric.distance(&nv, &self.nodes[&nid].vector);
212                            (dist, nid)
213                        })
214                        .collect();
215                    scored
216                        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
217                    scored.truncate(m_max);
218                    if let Some(neighbor) = self.nodes.get_mut(&neighbor_id) {
219                        neighbor.neighbors[lc] = scored.into_iter().map(|(_, nid)| nid).collect();
220                    }
221                }
222            }
223
224            ep_set = candidates.iter().map(|&(_, nid)| nid).collect();
225        }
226
227        if level > self.max_layer {
228            self.entry_point = Some(id);
229            self.max_layer = level;
230        }
231    }
232
233    /// Mark a vector as deleted (lazy deletion).
234    pub fn delete(&mut self, id: u64) {
235        if let Some(node) = self.nodes.get_mut(&id) {
236            node.deleted = true;
237        }
238
239        if self.entry_point == Some(id) {
240            self.entry_point = self
241                .nodes
242                .values()
243                .filter(|n| !n.deleted)
244                .max_by_key(|n| n.layer)
245                .map(|n| n.id);
246            if let Some(ep) = self.entry_point {
247                self.max_layer = self.nodes[&ep].layer;
248            } else {
249                self.max_layer = 0;
250            }
251        }
252    }
253
254    /// Search for the K nearest neighbors of the query vector.
255    ///
256    /// - `query`: the query vector
257    /// - `k`: number of results to return
258    /// - `ef`: search width (larger = better recall, slower; typical: 50-200)
259    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
260        assert_eq!(query.len(), self.dim);
261
262        let ep = match self.entry_point {
263            Some(ep) if !self.nodes[&ep].deleted || !self.is_empty() => ep,
264            _ => return vec![],
265        };
266
267        let mut current_ep = ep;
268        for lc in (1..=self.max_layer).rev() {
269            current_ep = self.greedy_closest(query, current_ep, lc);
270        }
271
272        let ef = ef.max(k);
273        let candidates = self.search_layer(query, &[current_ep], ef, 0);
274
275        candidates
276            .into_iter()
277            .filter(|&(_, id)| !self.nodes[&id].deleted)
278            .take(k)
279            .map(|(dist, id)| SearchResult { id, distance: dist })
280            .collect()
281    }
282
283    /// Check if a vector with the given ID exists (and is not deleted).
284    pub fn contains(&self, id: u64) -> bool {
285        self.nodes.get(&id).is_some_and(|n| !n.deleted)
286    }
287
288    fn random_level(&mut self) -> usize {
289        self.rng_state ^= self.rng_state << 13;
290        self.rng_state ^= self.rng_state >> 7;
291        self.rng_state ^= self.rng_state << 17;
292
293        let r = (self.rng_state as f64) / (u64::MAX as f64);
294        (-r.ln() * self.ml) as usize
295    }
296
297    fn greedy_closest(&self, query: &[f32], mut ep: u64, layer: usize) -> u64 {
298        let mut best_dist = self.metric.distance(query, &self.nodes[&ep].vector);
299
300        loop {
301            let mut changed = false;
302            let node = &self.nodes[&ep];
303            if layer < node.neighbors.len() {
304                for &neighbor_id in &node.neighbors[layer] {
305                    if let Some(neighbor) = self.nodes.get(&neighbor_id) {
306                        let dist = self.metric.distance(query, &neighbor.vector);
307                        if dist < best_dist {
308                            best_dist = dist;
309                            ep = neighbor_id;
310                            changed = true;
311                        }
312                    }
313                }
314            }
315            if !changed {
316                break;
317            }
318        }
319        ep
320    }
321
322    fn search_layer(
323        &self,
324        query: &[f32],
325        entry_points: &[u64],
326        ef: usize,
327        layer: usize,
328    ) -> Vec<(f32, u64)> {
329        let mut visited = HashSet::new();
330        let mut candidates: BinaryHeap<MinItem> = BinaryHeap::new();
331        let mut results: BinaryHeap<MaxItem> = BinaryHeap::new();
332
333        for &ep in entry_points {
334            if !self.nodes.contains_key(&ep) {
335                continue;
336            }
337            let dist = self.metric.distance(query, &self.nodes[&ep].vector);
338            visited.insert(ep);
339            candidates.push(MinItem(OrderedFloat(dist), ep));
340            results.push(MaxItem(OrderedFloat(dist), ep));
341        }
342
343        while let Some(MinItem(c_dist, c_id)) = candidates.pop() {
344            let f_dist = results
345                .peek()
346                .map(|r| r.0)
347                .unwrap_or(OrderedFloat(f32::MAX));
348            if c_dist > f_dist {
349                break;
350            }
351
352            let node = match self.nodes.get(&c_id) {
353                Some(n) => n,
354                None => continue,
355            };
356
357            if layer < node.neighbors.len() {
358                for &neighbor_id in &node.neighbors[layer] {
359                    if !visited.insert(neighbor_id) {
360                        continue;
361                    }
362                    if let Some(neighbor) = self.nodes.get(&neighbor_id) {
363                        let dist = self.metric.distance(query, &neighbor.vector);
364                        let f_dist = results
365                            .peek()
366                            .map(|r| r.0)
367                            .unwrap_or(OrderedFloat(f32::MAX));
368
369                        if dist < f_dist.0 || results.len() < ef {
370                            candidates.push(MinItem(OrderedFloat(dist), neighbor_id));
371                            results.push(MaxItem(OrderedFloat(dist), neighbor_id));
372                            if results.len() > ef {
373                                results.pop();
374                            }
375                        }
376                    }
377                }
378            }
379        }
380
381        let mut result: Vec<(f32, u64)> = results
382            .into_iter()
383            .map(|MaxItem(d, id)| (d.0, id))
384            .collect();
385        result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
386        result
387    }
388}
389
#[cfg(test)]
mod tests {
    //! Unit tests for `HnswIndex`.

    use super::*;

    /// Build an L2 index with `m = 16`, `ef_construction = 200` over `n`
    /// vectors, where component `d` of vector `i` is `(i * dim + d) * 0.01`.
    /// These points are evenly spaced along one line, so the exact nearest
    /// neighbours are easy to predict. Returns the index and the vectors
    /// (vector `i` has id `i`).
    fn make_index(n: usize, dim: usize) -> (HnswIndex, Vec<Vec<f32>>) {
        let mut index = HnswIndex::new(dim, DistanceMetric::L2, 16, 200);
        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32) * 0.01).collect();
            index.insert(i as u64, &v);
            vectors.push(v);
        }
        (index, vectors)
    }

    #[test]
    fn test_insert_and_search() {
        let (index, vectors) = make_index(100, 8);
        assert_eq!(index.len(), 100);

        // Search for vector 42; it should be its own nearest neighbor
        let results = index.search(&vectors[42], 5, 50);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 42);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn test_search_empty_index() {
        let index = HnswIndex::new(4, DistanceMetric::L2, 16, 200);
        let results = index.search(&[1.0, 2.0, 3.0, 4.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn test_single_vector() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        index.insert(1, &[1.0, 2.0, 3.0]);

        let results = index.search(&[1.0, 2.0, 3.0], 5, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_delete() {
        let (mut index, vectors) = make_index(50, 4);
        assert_eq!(index.len(), 50);

        index.delete(25);
        assert_eq!(index.len(), 49);
        assert!(!index.contains(25));

        // Search should not return deleted vector
        let results = index.search(&vectors[25], 5, 50);
        assert!(results.iter().all(|r| r.id != 25));
    }

    #[test]
    fn test_delete_entry_point() {
        let (mut index, vectors) = make_index(50, 4);
        let old_entry = index.entry_point.expect("non-empty index has an entry point");
        index.delete(old_entry);

        // A live node must take over as entry point.
        let new_entry = index.entry_point.expect("live nodes remain");
        assert_ne!(new_entry, old_entry);
        assert!(index.contains(new_entry));

        let results = index.search(&vectors[old_entry as usize], 5, 50);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id != old_entry));
    }

    #[test]
    fn test_delete_all() {
        let (mut index, vectors) = make_index(10, 4);
        for id in 0..10 {
            index.delete(id);
        }
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!(index.entry_point.is_none());
        assert!(index.search(&vectors[0], 5, 50).is_empty());
    }

    #[test]
    fn test_delete_unknown_id_is_noop() {
        let (mut index, _) = make_index(5, 4);
        index.delete(999);
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn test_cosine_metric() {
        let mut index = HnswIndex::new(3, DistanceMetric::Cosine, 16, 200);
        // Two parallel vectors should have distance ~0
        index.insert(1, &[1.0, 0.0, 0.0]);
        index.insert(2, &[2.0, 0.0, 0.0]); // same direction, different magnitude
        index.insert(3, &[0.0, 1.0, 0.0]); // orthogonal

        let results = index.search(&[3.0, 0.0, 0.0], 3, 50);
        let ids: Vec<u64> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids.len(), 3);
        // Both parallel vectors rank ahead of the orthogonal one.
        assert!(ids[0] == 1 || ids[0] == 2);
        assert_eq!(ids[2], 3);
    }

    #[test]
    fn test_inner_product() {
        let mut index = HnswIndex::new(2, DistanceMetric::InnerProduct, 16, 200);
        index.insert(1, &[1.0, 0.0]);
        index.insert(2, &[0.0, 1.0]);
        index.insert(3, &[10.0, 0.0]); // highest inner product with [1,0]

        let results = index.search(&[1.0, 0.0], 3, 50);
        // id=3 has highest inner product (10), so lowest negative_ip distance
        assert_eq!(results[0].id, 3);
    }

    #[test]
    fn test_recall_quality() {
        // Insert 500 evenly spaced vectors and verify recall >= 80% for k=10
        let n = 500;
        let dim = 16;
        let (index, vectors) = make_index(n, dim);

        let query = &vectors[0];
        let k = 10;

        // Brute-force ground truth
        let mut dists: Vec<(f32, u64)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (DistanceMetric::L2.distance(query, v), i as u64))
            .collect();
        dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let ground_truth: HashSet<u64> = dists.iter().take(k).map(|&(_, id)| id).collect();

        let results = index.search(query, k, 100);
        let found: HashSet<u64> = results.iter().map(|r| r.id).collect();

        let recall = ground_truth.intersection(&found).count() as f32 / k as f32;
        assert!(
            recall >= 0.8,
            "Recall too low: {:.2} (expected >= 0.80)",
            recall
        );
    }

    #[test]
    fn test_duplicate_insert() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        index.insert(1, &[1.0, 2.0, 3.0]);
        index.insert(1, &[4.0, 5.0, 6.0]); // replace

        assert_eq!(index.len(), 1);
        let results = index.search(&[4.0, 5.0, 6.0], 1, 50);
        assert_eq!(results[0].id, 1);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    #[should_panic(expected = "vector dimension mismatch")]
    fn test_insert_dimension_mismatch_panics() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        index.insert(1, &[1.0, 2.0]);
    }

    #[test]
    fn test_k_zero_returns_nothing() {
        let (index, vectors) = make_index(5, 4);
        assert!(index.search(&vectors[0], 0, 50).is_empty());
    }

    #[test]
    fn test_k_larger_than_index() {
        let (index, _) = make_index(5, 4);
        let results = index.search(&[0.0; 4], 100, 200);
        assert_eq!(results.len(), 5); // can't return more than exist
    }

    #[test]
    fn test_contains() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        assert!(!index.contains(1));
        index.insert(1, &[1.0, 2.0, 3.0]);
        assert!(index.contains(1));
        index.delete(1);
        assert!(!index.contains(1));
    }
}