Skip to main content

lean_ctx/core/
hnsw.rs

1//! Lightweight HNSW (Hierarchical Navigable Small World) index for approximate nearest neighbors.
2//!
3//! Scientific basis: Malkov & Yashunin, "Efficient and Robust Approximate Nearest Neighbor
4//! using Hierarchical Navigable Small World Graphs" (IEEE TPAMI 2018).
5//!
6//! This is a minimal implementation optimized for lean-ctx's embedding dimensions (384-d).
7//! For indices under BRUTE_FORCE_THRESHOLD chunks, falls back to exact linear scan
8//! with binary-heap top-k selection (O(n log k) instead of O(n log n)).
9
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// Indices holding fewer vectors than this skip graph construction and are searched
/// with an exact linear scan instead.
const BRUTE_FORCE_THRESHOLD: usize = 1000;
/// Number of links a node is given on each layer when it is inserted. A neighbor's
/// list may temporarily grow to `2 * M` through back-links before it is pruned to `M`.
const M: usize = 16; // max connections per node per layer
/// Candidate-list width used by the layer search while inserting nodes.
const EF_CONSTRUCTION: usize = 200; // search width during build
/// Minimum candidate-list width used by the layer-0 search when answering a query.
const EF_SEARCH: usize = 64; // search width during query
/// Level multiplier applied to `-ln(r)` when drawing a node's top layer.
// ML = 1/ln(M) = 1/ln(16) ≈ 0.3607
const ML: f64 = 0.360_674_0;
19
/// A scored item for the min-heap (lowest similarity first for top-k pruning).
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the candidate with the
/// lowest similarity compares as the greatest and is the first one popped.
///
/// The ordering is a genuine total order, as `Ord` and `Eq` require: a NaN similarity
/// ranks as the worst possible score, ties are broken by index, and `==` is defined
/// through `cmp` so the two can never disagree.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    idx: usize,
    sim: f32,
}

impl Candidate {
    /// Similarity used for ordering, with NaN treated as the lowest possible score.
    fn rank(self) -> f32 {
        if self.sim.is_nan() {
            f32::NEG_INFINITY
        } else {
            self.sim
        }
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower similarity should be popped first, hence the swapped operands.
        // Among equal similarities the higher index is popped (evicted) first.
        other
            .rank()
            .total_cmp(&self.rank())
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
41
/// Max-heap variant for HNSW traversal (highest similarity popped first).
///
/// The ordering is a genuine total order, as `Ord` and `Eq` require: a NaN similarity
/// ranks as the worst possible score, ties are broken by index, and `==` is defined
/// through `cmp` so the two can never disagree.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    idx: usize,
    sim: f32,
}

impl MaxCandidate {
    /// Similarity used for ordering, with NaN treated as the lowest possible score.
    fn rank(self) -> f32 {
        if self.sim.is_nan() {
            f32::NEG_INFINITY
        } else {
            self.sim
        }
    }
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Higher similarity is greater; among equal similarities the lower index wins.
        self.rank()
            .total_cmp(&other.rank())
            .then_with(|| other.idx.cmp(&self.idx))
    }
}
62
/// HNSW index node.
#[derive(Debug)]
struct Node {
    /// `connections[layer]` is the list of neighbor indices on that layer.
    connections: Vec<Vec<usize>>,
}

/// Approximate nearest neighbor index using HNSW for large datasets,
/// with brute-force fallback for small ones.
#[derive(Debug)]
pub struct AnnIndex {
    /// Indexed vectors; a vector's position here is the index reported by `search`.
    vectors: Vec<Vec<f32>>,
    /// One graph node per vector; left empty when no graph was built.
    nodes: Vec<Node>,
    /// Node from which graph traversal starts.
    entry_point: usize,
    /// Highest layer currently present in the graph.
    max_level: usize,
}
76
77impl AnnIndex {
78    /// Build the index from a set of vectors.
79    pub fn build(vectors: Vec<Vec<f32>>) -> Self {
80        let n = vectors.len();
81        if n == 0 {
82            return Self {
83                vectors,
84                nodes: Vec::new(),
85                entry_point: 0,
86                max_level: 0,
87            };
88        }
89
90        if n < BRUTE_FORCE_THRESHOLD {
91            return Self {
92                vectors,
93                nodes: Vec::new(),
94                entry_point: 0,
95                max_level: 0,
96            };
97        }
98
99        let mut index = Self {
100            vectors: Vec::with_capacity(n),
101            nodes: Vec::with_capacity(n),
102            entry_point: 0,
103            max_level: 0,
104        };
105
106        for vec in vectors {
107            index.insert(vec);
108        }
109
110        index
111    }
112
113    fn insert(&mut self, vec: Vec<f32>) {
114        let level = Self::random_level();
115        let new_id = self.vectors.len();
116
117        self.vectors.push(vec);
118        self.nodes.push(Node {
119            connections: vec![Vec::new(); level + 1],
120        });
121
122        if self.nodes.len() == 1 {
123            self.entry_point = 0;
124            self.max_level = level;
125            return;
126        }
127
128        let mut ep = self.entry_point;
129
130        // Traverse from top layer down to level+1 (greedy)
131        for lc in (level + 1..=self.max_level).rev() {
132            ep = self.search_layer_single(&self.vectors[new_id], ep, lc);
133        }
134
135        // Insert into layers [min(level, max_level) .. 0]
136        let insert_levels = level.min(self.max_level);
137        for lc in (0..=insert_levels).rev() {
138            let neighbors = self.search_layer(&self.vectors[new_id], ep, EF_CONSTRUCTION, lc);
139            let selected = Self::select_neighbors(&neighbors, M);
140
141            if lc < self.nodes[new_id].connections.len() {
142                self.nodes[new_id].connections[lc].clone_from(&selected);
143            }
144
145            for &neighbor in &selected {
146                if lc < self.nodes[neighbor].connections.len() {
147                    self.nodes[neighbor].connections[lc].push(new_id);
148                    if self.nodes[neighbor].connections[lc].len() > M * 2 {
149                        let nv = &self.vectors[neighbor];
150                        let mut scored: Vec<(usize, f32)> = self.nodes[neighbor].connections[lc]
151                            .iter()
152                            .map(|&n| (n, cosine_sim(nv, &self.vectors[n])))
153                            .collect();
154                        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
155                        scored.truncate(M);
156                        self.nodes[neighbor].connections[lc] =
157                            scored.into_iter().map(|(id, _)| id).collect();
158                    }
159                }
160            }
161
162            if !neighbors.is_empty() {
163                ep = neighbors[0].0;
164            }
165        }
166
167        if level > self.max_level {
168            self.max_level = level;
169            self.entry_point = new_id;
170        }
171    }
172
173    fn search_layer_single(&self, query: &[f32], ep: usize, _layer: usize) -> usize {
174        let mut current = ep;
175        let mut best_sim = cosine_sim(query, &self.vectors[ep]);
176
177        loop {
178            let mut improved = false;
179            let conns = &self.nodes[current].connections;
180            let layer_conns = if _layer < conns.len() {
181                &conns[_layer]
182            } else {
183                break;
184            };
185
186            for &neighbor in layer_conns {
187                let sim = cosine_sim(query, &self.vectors[neighbor]);
188                if sim > best_sim {
189                    best_sim = sim;
190                    current = neighbor;
191                    improved = true;
192                }
193            }
194            if !improved {
195                break;
196            }
197        }
198        current
199    }
200
201    fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<(usize, f32)> {
202        let mut visited = vec![false; self.vectors.len()];
203        let mut candidates = BinaryHeap::<MaxCandidate>::new();
204        let mut results = BinaryHeap::<Candidate>::new();
205
206        let sim = cosine_sim(query, &self.vectors[ep]);
207        visited[ep] = true;
208        candidates.push(MaxCandidate { idx: ep, sim });
209        results.push(Candidate { idx: ep, sim });
210
211        while let Some(MaxCandidate { idx: c, sim: _ }) = candidates.pop() {
212            let worst_result = results.peek().map_or(f32::MIN, |r| r.sim);
213            if cosine_sim(query, &self.vectors[c]) < worst_result && results.len() >= ef {
214                break;
215            }
216
217            let conns = &self.nodes[c].connections;
218            let layer_conns = if layer < conns.len() {
219                &conns[layer]
220            } else {
221                continue;
222            };
223
224            for &neighbor in layer_conns {
225                if visited[neighbor] {
226                    continue;
227                }
228                visited[neighbor] = true;
229
230                let n_sim = cosine_sim(query, &self.vectors[neighbor]);
231                let worst = results.peek().map_or(f32::MIN, |r| r.sim);
232
233                if results.len() < ef || n_sim > worst {
234                    candidates.push(MaxCandidate {
235                        idx: neighbor,
236                        sim: n_sim,
237                    });
238                    results.push(Candidate {
239                        idx: neighbor,
240                        sim: n_sim,
241                    });
242                    if results.len() > ef {
243                        results.pop();
244                    }
245                }
246            }
247        }
248
249        let mut out: Vec<(usize, f32)> = results.into_iter().map(|c| (c.idx, c.sim)).collect();
250        out.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
251        out
252    }
253
254    fn select_neighbors(candidates: &[(usize, f32)], max_count: usize) -> Vec<usize> {
255        candidates
256            .iter()
257            .take(max_count)
258            .map(|&(idx, _)| idx)
259            .collect()
260    }
261
262    fn random_level() -> usize {
263        let mut buf = [0u8; 4];
264        let _ = getrandom::fill(&mut buf);
265        let r = f64::from(u32::from_le_bytes(buf)) / f64::from(u32::MAX);
266        (-r.ln() * ML).floor() as usize
267    }
268
269    /// Search for the top-k nearest neighbors of a query vector.
270    /// Returns (index, similarity) pairs sorted by descending similarity.
271    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
272        if self.vectors.is_empty() {
273            return Vec::new();
274        }
275
276        // Brute-force for small indices (faster due to no graph overhead)
277        if self.nodes.is_empty() || self.vectors.len() < BRUTE_FORCE_THRESHOLD {
278            return brute_force_topk(&self.vectors, query, top_k);
279        }
280
281        // HNSW search
282        let mut ep = self.entry_point;
283        for lc in (1..=self.max_level).rev() {
284            ep = self.search_layer_single(query, ep, lc);
285        }
286
287        let mut results = self.search_layer(query, ep, EF_SEARCH.max(top_k), 0);
288        results.truncate(top_k);
289        results
290    }
291}
292
293/// O(n log k) brute-force top-k selection using a min-heap.
294pub fn brute_force_topk(vectors: &[Vec<f32>], query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
295    let mut heap = BinaryHeap::<Candidate>::with_capacity(top_k + 1);
296
297    for (i, vec) in vectors.iter().enumerate() {
298        let sim = cosine_sim(query, vec);
299        if heap.len() < top_k {
300            heap.push(Candidate { idx: i, sim });
301        } else if let Some(worst) = heap.peek() {
302            if sim > worst.sim {
303                heap.pop();
304                heap.push(Candidate { idx: i, sim });
305            }
306        }
307    }
308
309    let mut results: Vec<(usize, f32)> = heap.into_iter().map(|c| (c.idx, c.sim)).collect();
310    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
311    results
312}
313
/// Cosine similarity of two vectors.
///
/// Returns 0.0 when the lengths differ or when the product of the norms is
/// effectively zero (below 1e-10), so degenerate inputs never divide by zero.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass
    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = (norm_a * norm_b).sqrt();
    if denom < 1e-10 {
        return 0.0;
    }
    dot / denom
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [-1, 1], driven by a
    /// simple linear congruential generator seeded with `seed`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            v.push((s as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    #[test]
    fn brute_force_topk_correctness() {
        let vectors: Vec<Vec<f32>> = (0..100).map(|i| random_vec(16, i)).collect();
        let query = random_vec(16, 999);

        let results = brute_force_topk(&vectors, &query, 5);
        assert_eq!(results.len(), 5);

        // Results should be in descending similarity order
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }

    #[test]
    fn brute_force_topk_matches_exhaustive() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(8, i + 42)).collect();
        let query = random_vec(8, 123);

        let top5 = brute_force_topk(&vectors, &query, 5);

        // Exhaustive comparison
        let mut all: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, cosine_sim(&query, v)))
            .collect();
        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(5);

        for (heap_r, exact_r) in top5.iter().zip(all.iter()) {
            assert_eq!(heap_r.0, exact_r.0);
            assert!((heap_r.1 - exact_r.1).abs() < 1e-6);
        }
    }

    #[test]
    fn brute_force_topk_handles_oversized_and_zero_k() {
        let vectors: Vec<Vec<f32>> = (0..3).map(|i| random_vec(4, i)).collect();
        let query = random_vec(4, 77);

        // Asking for more results than there are vectors returns every vector once
        let all = brute_force_topk(&vectors, &query, 10);
        assert_eq!(all.len(), 3);
        for w in all.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }

        assert!(brute_force_topk(&vectors, &query, 0).is_empty());
    }

    #[test]
    fn empty_index_returns_empty() {
        let index = AnnIndex::build(Vec::new());
        assert!(index.search(&[1.0, 0.0], 5).is_empty());
    }

    #[test]
    fn small_index_uses_brute_force() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(4, i)).collect();
        let index = AnnIndex::build(vectors);
        assert!(index.nodes.is_empty()); // no HNSW graph built
        let results = index.search(&random_vec(4, 999), 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn large_index_builds_graph_and_returns_consistent_results() {
        let n = BRUTE_FORCE_THRESHOLD + 200;
        let vectors: Vec<Vec<f32>> = (0..n as u64).map(|i| random_vec(8, i)).collect();
        let index = AnnIndex::build(vectors);
        assert_eq!(index.nodes.len(), n); // HNSW graph built, one node per vector

        let query = random_vec(8, 999_999);
        let results = index.search(&query, 5);
        assert_eq!(results.len(), 5);

        // Descending similarity order
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }

        // Every hit is a valid index whose reported score matches a fresh computation
        for &(idx, sim) in &results {
            assert!(idx < n);
            assert!((sim - cosine_sim(&query, &index.vectors[idx])).abs() < 1e-6);
        }

        // No vector is reported twice
        let mut ids: Vec<usize> = results.iter().map(|r| r.0).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), 5);
    }
}