// oxibonsai_runtime/embedding_index.rs

//! Navigable Small World (NSW) approximate nearest-neighbor index.
//!
//! Implements a single-layer NSW graph — a simplified HNSW variant that is
//! fast enough for caches up to ~100k entries while keeping the implementation
//! self-contained and free of external dependencies.
//!
//! # Algorithm sketch
//!
//! - **Insert**: greedily traverse the graph from a deterministically rotating
//!   entry point, collecting up to `ef_construct` candidate nodes.  Connect the
//!   new node to at most `max_connections` of them, then prune the edge list
//!   of any neighbour that now exceeds `max_connections`.
//! - **Search**: repeat the greedy traversal starting from the first inserted
//!   node, keeping up to `ef_search` candidates, and return the top-k by
//!   cosine similarity.
//!
//! # Example
//!
//! ```rust
//! use oxibonsai_runtime::embedding_index::{EmbeddingIndex, NswConfig};
//!
//! let mut index: EmbeddingIndex<&str> = EmbeddingIndex::new(4);
//! let id = index.insert(vec![1.0, 0.0, 0.0, 0.0], "doc-a");
//! let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1);
//! assert_eq!(results[0].1, &"doc-a");
//! ```

// ─────────────────────────────────────────────────────────────────────────────
// Math helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Cosine similarity between two equal-length unit vectors.
///
/// Computes the plain dot product of `a` and `b`; no normalisation happens
/// here, so the value is a true cosine only when both inputs are already
/// L2-normalised.  The result is clamped to `[-1.0, 1.0]` so that rounding
/// error cannot push it outside the valid cosine range.
///
/// Returns `0.0` for empty or mismatched-length inputs.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot.clamp(-1.0, 1.0)
}

/// L2-normalise `v` in place.
///
/// The squared norm is accumulated in `f64` so that large-magnitude (but
/// finite) components do not overflow to infinity, which would otherwise
/// collapse the whole vector to zero.
///
/// Vectors whose norm is at most `1e-10` (treated as zero vectors) or is not
/// finite (the input contains NaN or infinity) are left unchanged.
#[inline]
fn l2_normalize(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm.is_finite() && norm > 1e-10 {
        for x in v.iter_mut() {
            // Divide in f64 as well: the norm itself may exceed the f32 range.
            *x = (f64::from(*x) / norm) as f32;
        }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// NswNode (internal)
// ─────────────────────────────────────────────────────────────────────────────

/// A single node stored in the NSW graph.
#[derive(Debug, Clone)]
struct NswNode {
    /// Caller-supplied identifier, reported back in search results.  It is
    /// independent of the node's position in `NswIndex::nodes`.
    id: usize,
    /// Embedding vector, L2-normalised at insertion time.
    vector: Vec<f32>,
    /// Positions (in `NswIndex::nodes`) of the nodes this node links to.
    neighbors: Vec<usize>,
}

// ─────────────────────────────────────────────────────────────────────────────
// NswConfig
// ─────────────────────────────────────────────────────────────────────────────

/// Configuration for the NSW approximate nearest-neighbor graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NswConfig {
    /// Maximum number of connections kept per node during construction
    /// (default: 16).  Higher values improve recall at the cost of memory and
    /// insertion time.
    pub max_connections: usize,
    /// Number of candidates to keep while searching (default: 64).  Higher
    /// values improve recall at the cost of query latency.
    pub ef_search: usize,
    /// Number of candidates to keep while inserting (default: 32).  Higher
    /// values improve graph quality at the cost of insertion latency.
    pub ef_construct: usize,
}

impl Default for NswConfig {
    /// Returns the defaults: `max_connections = 16`, `ef_search = 64`,
    /// `ef_construct = 32`.
    fn default() -> Self {
        Self {
            max_connections: 16,
            ef_search: 64,
            ef_construct: 32,
        }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// NswSearchResult
// ─────────────────────────────────────────────────────────────────────────────

/// A single result from an NSW nearest-neighbor search.
#[derive(Debug, Clone, PartialEq)]
pub struct NswSearchResult {
    /// The identifier that was supplied when the matching vector was inserted.
    pub id: usize,
    /// Cosine similarity score between the query and this node's vector.
    pub score: f32,
}

// ─────────────────────────────────────────────────────────────────────────────
// NswIndex
// ─────────────────────────────────────────────────────────────────────────────

118/// Navigable Small World graph index for approximate nearest-neighbor search.
119///
120/// This is a single-layer NSW — the multi-layer hierarchical variant (HNSW) is
121/// outside scope.  Performance is excellent for corpora up to ~100k entries.
122pub struct NswIndex {
123    nodes: Vec<NswNode>,
124    config: NswConfig,
125    dim: usize,
126    /// Simple deterministic counter used instead of a random entry point so
127    /// that behaviour is reproducible without the `rand` crate.
128    entry_counter: usize,
129}
130
131impl NswIndex {
132    /// Create an empty NSW index for `dim`-dimensional vectors.
133    pub fn new(dim: usize, config: NswConfig) -> Self {
134        Self {
135            nodes: Vec::new(),
136            config,
137            dim,
138            entry_counter: 0,
139        }
140    }
141
142    // ── Insertion ─────────────────────────────────────────────────────────────
143
144    /// Insert a normalised copy of `vector` with the given `id`.
145    ///
146    /// 1. Finds `ef_construct` nearest existing nodes via greedy search.
147    /// 2. Connects the new node to at most `max_connections` of them.
148    /// 3. Prunes the neighbours' connection lists if they exceed `max_connections`.
149    ///
150    /// Complexity: O(M × ef_construct) amortised where M = `max_connections`.
151    pub fn insert(&mut self, id: usize, vector: Vec<f32>) {
152        let mut v = vector;
153        // Pad or truncate to match declared dimensionality.
154        v.resize(self.dim, 0.0);
155        l2_normalize(&mut v);
156
157        let new_idx = self.nodes.len();
158
159        if self.nodes.is_empty() {
160            // First node — no edges to add yet.
161            self.nodes.push(NswNode {
162                id,
163                vector: v,
164                neighbors: Vec::new(),
165            });
166            self.entry_counter = 0;
167            return;
168        }
169
170        // Pick a deterministic entry point by rotating through existing nodes.
171        let entry = self.entry_counter % self.nodes.len();
172        self.entry_counter += 1;
173
174        // Find ef_construct nearest candidates.
175        let ef = self.config.ef_construct;
176        let candidates = self.greedy_search(&v, entry, ef);
177
178        // Keep at most max_connections neighbours.
179        let max_conn = self.config.max_connections;
180        let neighbor_indices: Vec<usize> = candidates
181            .iter()
182            .take(max_conn)
183            .map(|(node_idx, _)| *node_idx)
184            .collect();
185
186        // Add the new node.
187        self.nodes.push(NswNode {
188            id,
189            vector: v.clone(),
190            neighbors: neighbor_indices.clone(),
191        });
192
193        // Add back-edges and prune if needed.
194        for &nb_idx in &neighbor_indices {
195            self.nodes[nb_idx].neighbors.push(new_idx);
196            if self.nodes[nb_idx].neighbors.len() > max_conn {
197                self.prune_neighbors(nb_idx, max_conn);
198            }
199        }
200    }
201
202    // ── Search ────────────────────────────────────────────────────────────────
203
204    /// Return the top-`top_k` approximate nearest neighbors of `query`.
205    ///
206    /// Uses a greedy graph traversal starting from a deterministic entry point,
207    /// expanding at most `ef_search` candidates.  Results are sorted by cosine
208    /// similarity in descending order.
209    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<NswSearchResult> {
210        if self.nodes.is_empty() || top_k == 0 {
211            return Vec::new();
212        }
213
214        // Normalise the query locally.
215        let mut q = query.to_vec();
216        q.resize(self.dim, 0.0);
217        l2_normalize(&mut q);
218
219        // Use node 0 as a stable entry point for search (read-only, no mutation).
220        let entry = 0;
221        let ef = self.config.ef_search;
222        let mut candidates = self.greedy_search(&q, entry, ef);
223
224        // Sort descending by score.
225        candidates
226            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227        candidates.truncate(top_k);
228
229        candidates
230            .into_iter()
231            .map(|(node_idx, score)| NswSearchResult {
232                id: self.nodes[node_idx].id,
233                score,
234            })
235            .collect()
236    }
237
238    // ── Accessors ─────────────────────────────────────────────────────────────
239
240    /// Number of vectors stored in the index.
241    pub fn len(&self) -> usize {
242        self.nodes.len()
243    }
244
245    /// Returns `true` if the index contains no vectors.
246    pub fn is_empty(&self) -> bool {
247        self.nodes.is_empty()
248    }
249
250    /// The embedding dimensionality this index was constructed with.
251    pub fn dim(&self) -> usize {
252        self.dim
253    }
254
255    // ── Private helpers ───────────────────────────────────────────────────────
256
257    /// Greedy beam search from `entry` node, returning up to `ef` candidates
258    /// as `(node_index, cosine_similarity)` pairs.
259    ///
260    /// The implementation maintains two sets:
261    /// - `visited`: bit-set of already-explored node indices.
262    /// - `candidates`: max-heap of (score, node_idx) to explore next.
263    /// - `results`: the ef best nodes seen so far.
264    fn greedy_search(&self, query: &[f32], entry: usize, ef: usize) -> Vec<(usize, f32)> {
265        if self.nodes.is_empty() {
266            return Vec::new();
267        }
268
269        use std::cmp::Ordering;
270        use std::collections::{BinaryHeap, HashSet};
271
272        /// Wrapper to allow f32 in BinaryHeap (max-heap by score).
273        #[derive(PartialEq)]
274        struct Scored(f32, usize);
275
276        impl Eq for Scored {}
277
278        impl PartialOrd for Scored {
279            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
280                Some(self.cmp(other))
281            }
282        }
283
284        impl Ord for Scored {
285            fn cmp(&self, other: &Self) -> Ordering {
286                self.0
287                    .partial_cmp(&other.0)
288                    .unwrap_or(Ordering::Equal)
289                    .then(self.1.cmp(&other.1))
290            }
291        }
292
293        let mut visited: HashSet<usize> = HashSet::new();
294        let entry_score = cosine_sim(query, &self.nodes[entry].vector);
295        visited.insert(entry);
296
297        // `frontier` is a max-heap of nodes to expand (best first).
298        let mut frontier: BinaryHeap<Scored> = BinaryHeap::new();
299        frontier.push(Scored(entry_score, entry));
300
301        // `results` keeps the best `ef` nodes found so far.
302        let mut results: Vec<(usize, f32)> = vec![(entry, entry_score)];
303
304        while let Some(Scored(_, node_idx)) = frontier.pop() {
305            // If results already has ef entries and the worst result in results
306            // is better than anything remaining in the frontier, we can stop.
307            if results.len() >= ef {
308                let worst_result = results
309                    .iter()
310                    .map(|(_, s)| *s)
311                    .fold(f32::INFINITY, f32::min);
312                // All remaining frontier nodes are at most as good as `node_idx`
313                // (max-heap), so check against the worst we currently keep.
314                let node_score = results
315                    .iter()
316                    .find(|(i, _)| *i == node_idx)
317                    .map(|(_, s)| *s)
318                    .unwrap_or(f32::NEG_INFINITY);
319                if node_score < worst_result && frontier.is_empty() {
320                    break;
321                }
322            }
323
324            // Expand neighbours.
325            for &nb_idx in &self.nodes[node_idx].neighbors {
326                if visited.contains(&nb_idx) {
327                    continue;
328                }
329                visited.insert(nb_idx);
330
331                let nb_score = cosine_sim(query, &self.nodes[nb_idx].vector);
332                frontier.push(Scored(nb_score, nb_idx));
333                results.push((nb_idx, nb_score));
334
335                // Keep results bounded at ef (drop worst).
336                if results.len() > ef {
337                    let worst_idx = results
338                        .iter()
339                        .enumerate()
340                        .min_by(|a, b| {
341                            a.1 .1
342                                .partial_cmp(&b.1 .1)
343                                .unwrap_or(std::cmp::Ordering::Equal)
344                        })
345                        .map(|(i, _)| i)
346                        .expect("results is non-empty");
347                    results.swap_remove(worst_idx);
348                }
349            }
350        }
351
352        results
353    }
354
355    /// Prune the neighbor list of node at `node_idx` to at most `max_conn`
356    /// connections, keeping the `max_conn` closest by cosine similarity.
357    fn prune_neighbors(&mut self, node_idx: usize, max_conn: usize) {
358        let v = self.nodes[node_idx].vector.clone();
359        let neighbors = &self.nodes[node_idx].neighbors;
360
361        // Score each current neighbour.
362        let mut scored: Vec<(usize, f32)> = neighbors
363            .iter()
364            .map(|&nb| {
365                let score = cosine_sim(&v, &self.nodes[nb].vector);
366                (nb, score)
367            })
368            .collect();
369
370        // Keep highest-scoring connections.
371        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
372        scored.truncate(max_conn);
373
374        self.nodes[node_idx].neighbors = scored.into_iter().map(|(nb, _)| nb).collect();
375    }
376}
377
// ─────────────────────────────────────────────────────────────────────────────
// EmbeddingIndex<T>
// ─────────────────────────────────────────────────────────────────────────────

382/// Combined NSW graph index with per-entry metadata storage.
383///
384/// `T` is any cloneable metadata type — e.g. a `String` payload, a struct, or
385/// a raw identifier.
386///
387/// ```rust
388/// use oxibonsai_runtime::embedding_index::EmbeddingIndex;
389///
390/// let mut idx: EmbeddingIndex<String> = EmbeddingIndex::new(3);
391/// idx.insert(vec![1.0, 0.0, 0.0], "vec-a".to_string());
392/// idx.insert(vec![0.0, 1.0, 0.0], "vec-b".to_string());
393///
394/// let results = idx.search(&[1.0, 0.0, 0.0], 1);
395/// assert_eq!(results[0].1, &"vec-a".to_string());
396/// ```
397pub struct EmbeddingIndex<T: Clone> {
398    graph: NswIndex,
399    /// Parallel metadata store: `metadata[i] = (id, metadata_value)`.
400    metadata: Vec<(usize, T)>,
401    next_id: usize,
402}
403
404impl<T: Clone> EmbeddingIndex<T> {
405    /// Create a new index for `dim`-dimensional vectors with default NSW config.
406    pub fn new(dim: usize) -> Self {
407        Self::new_with_config(dim, NswConfig::default())
408    }
409
410    /// Create a new index with a custom [`NswConfig`].
411    pub fn new_with_config(dim: usize, config: NswConfig) -> Self {
412        Self {
413            graph: NswIndex::new(dim, config),
414            metadata: Vec::new(),
415            next_id: 0,
416        }
417    }
418
419    /// Insert a vector with associated metadata.
420    ///
421    /// Returns the stable numeric ID assigned to this entry.
422    pub fn insert(&mut self, vector: Vec<f32>, meta: T) -> usize {
423        let id = self.next_id;
424        self.next_id += 1;
425        self.graph.insert(id, vector);
426        self.metadata.push((id, meta));
427        id
428    }
429
430    /// Search for the top-`top_k` nearest neighbors of `query`.
431    ///
432    /// Returns a `Vec` of `(NswSearchResult, &T)` pairs sorted by descending
433    /// cosine similarity.
434    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(NswSearchResult, &T)> {
435        let results = self.graph.search(query, top_k);
436        results
437            .into_iter()
438            .filter_map(|r| {
439                // Look up metadata by id.
440                self.metadata
441                    .iter()
442                    .find(|(id, _)| *id == r.id)
443                    .map(|(_, meta)| (r, meta))
444            })
445            .collect()
446    }
447
448    /// Number of entries in the index.
449    pub fn len(&self) -> usize {
450        self.graph.len()
451    }
452
453    /// Returns `true` if the index contains no entries.
454    pub fn is_empty(&self) -> bool {
455        self.graph.is_empty()
456    }
457}
458
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an L2-normalised copy of `values`.
    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let mut v = values.to_vec();
        l2_normalize(&mut v);
        v
    }

    // ── NswIndex ──────────────────────────────────────────────────────────────

    #[test]
    fn test_nsw_index_empty() {
        let idx = NswIndex::new(4, NswConfig::default());
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.dim(), 4);
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_nsw_index_single_insert() {
        let mut idx = NswIndex::new(4, NswConfig::default());
        idx.insert(0, vec![1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
        assert!(!idx.is_empty());
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_search_exact() {
        let mut idx = NswIndex::new(3, NswConfig::default());
        let v = unit_vec(&[1.0, 2.0, 3.0]);
        // A non-zero id checks that results report the caller's id, not the
        // node's position in the graph.
        idx.insert(42, v.clone());
        let results = idx.search(&v, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 42);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_search_nearest() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        // Insert three vectors; query is closest to id=1.
        idx.insert(0, unit_vec(&[1.0, 0.0])); // along x-axis
        idx.insert(1, unit_vec(&[0.0, 1.0])); // along y-axis
        idx.insert(2, unit_vec(&[-1.0, 0.0])); // negative x-axis

        let query = unit_vec(&[0.1, 0.9]); // close to y-axis
        let results = idx.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].id, 1,
            "nearest should be y-axis vector, got id={}",
            results[0].id
        );
    }

    #[test]
    fn test_nsw_index_top_k_zero() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        idx.insert(0, unit_vec(&[1.0, 0.0]));
        assert!(idx.search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_nsw_index_pads_short_vectors() {
        // Vectors and queries shorter than `dim` are zero-padded.
        let mut idx = NswIndex::new(3, NswConfig::default());
        idx.insert(0, vec![1.0]);
        let results = idx.search(&[1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_many_vectors() {
        let dim = 8;
        let config = NswConfig {
            max_connections: 8,
            ef_search: 32,
            ef_construct: 16,
        };
        let mut idx = NswIndex::new(dim, config);

        // Insert 100 deterministic pseudo-random vectors.
        for i in 0..100usize {
            let mut v: Vec<f32> = (0..dim)
                .map(|d| {
                    // deterministic pseudo-random using wrapping arithmetic
                    let x = (i as u64)
                        .wrapping_mul(6364136223846793005u64)
                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407u64));
                    let x = x ^ (x >> 33);
                    let x = x.wrapping_mul(0xff51afd7ed558ccdu64);
                    let x = x ^ (x >> 33);
                    (x as i64) as f32 / i64::MAX as f32
                })
                .collect();
            l2_normalize(&mut v);
            idx.insert(i, v);
        }

        assert_eq!(idx.len(), 100);

        // A known query: a unit vector along the first dimension.
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let results = idx.search(&query, 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // Scores should be in descending order.
        for w in results.windows(2) {
            assert!(
                w[0].score >= w[1].score - 1e-5,
                "scores not sorted: {} < {}",
                w[0].score,
                w[1].score
            );
        }
    }

    // ── EmbeddingIndex ────────────────────────────────────────────────────────

    #[test]
    fn test_embedding_index_insert_and_search() {
        let mut idx: EmbeddingIndex<u32> = EmbeddingIndex::new(4);
        idx.insert(unit_vec(&[1.0, 0.0, 0.0, 0.0]), 100);
        idx.insert(unit_vec(&[0.0, 1.0, 0.0, 0.0]), 200);
        idx.insert(unit_vec(&[0.0, 0.0, 1.0, 0.0]), 300);

        let results = idx.search(&unit_vec(&[1.0, 0.0, 0.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(*results[0].1, 100u32);
    }

    #[test]
    fn test_embedding_index_metadata_returned() {
        let mut idx: EmbeddingIndex<String> = EmbeddingIndex::new(3);
        let id = idx.insert(unit_vec(&[1.0, 1.0, 0.0]), "hello world".to_string());
        assert_eq!(id, 0);
        let results = idx.search(&unit_vec(&[1.0, 1.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, &"hello world".to_string());
        assert!((results[0].0.score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_nsw_config_defaults() {
        let cfg = NswConfig::default();
        assert_eq!(cfg.max_connections, 16);
        assert_eq!(cfg.ef_search, 64);
        assert_eq!(cfg.ef_construct, 32);
    }
}