oxibonsai_runtime/embedding_index.rs
1//! Navigable Small World (NSW) approximate nearest-neighbor index.
2//!
3//! Implements a single-layer NSW graph — a simplified HNSW variant that is
4//! fast enough for caches up to ~100k entries while keeping the implementation
5//! self-contained and free of external dependencies.
6//!
7//! # Algorithm sketch
8//!
//! - **Insert**: greedily traverse the graph from a deterministically rotating
//!   entry point (no random-number generator is used), collecting up to
//!   `ef_construct` nearby nodes. Connect the new node to at most
//!   `max_connections` of them, then prune any neighbour whose connection list
//!   now exceeds `max_connections`.
//! - **Search**: run the same greedy traversal starting from the first inserted
//!   node, with a candidate beam sized by `ef_search`, and return the top-k by
//!   cosine similarity.
15//!
16//! # Example
17//!
18//! ```rust
19//! use oxibonsai_runtime::embedding_index::{EmbeddingIndex, NswConfig};
20//!
21//! let mut index: EmbeddingIndex<&str> = EmbeddingIndex::new(4);
22//! let id = index.insert(vec![1.0, 0.0, 0.0, 0.0], "doc-a");
23//! let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1);
24//! assert_eq!(results[0].1, &"doc-a");
25//! ```
26
27// ─────────────────────────────────────────────────────────────────────────────
28// Math helpers
29// ─────────────────────────────────────────────────────────────────────────────
30
/// Cosine similarity of two slices, computed as their dot product.
///
/// The caller is expected to pass L2-normalised vectors of equal length, in
/// which case the dot product is exactly the cosine of the angle between them.
/// The result is clamped to `[-1.0, 1.0]` to absorb rounding error. If either
/// slice is empty or the lengths differ, `0.0` is returned.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    dot.clamp(-1.0, 1.0)
}
46
/// Scale `v` in place so that its Euclidean length becomes 1.
///
/// Vectors whose norm is not greater than `1e-10` (including the all-zero
/// vector) are left untouched, so we never divide by (nearly) zero.
#[inline]
fn l2_normalize(v: &mut [f32]) {
    let squared_len: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_len.sqrt();
    if norm > 1e-10 {
        v.iter_mut().for_each(|component| *component /= norm);
    }
}
57
58// ─────────────────────────────────────────────────────────────────────────────
59// NswNode (internal)
60// ─────────────────────────────────────────────────────────────────────────────
61
/// A single node stored in the NSW graph.
struct NswNode {
    /// Caller-supplied identifier, returned verbatim in search results.
    ///
    /// This is *not* necessarily the node's position in `NswIndex::nodes`;
    /// graph edges use positions, while `id` is only payload.
    id: usize,
    /// L2-normalised embedding vector (padded/truncated to the index dimension).
    vector: Vec<f32>,
    /// Positions (in `NswIndex::nodes`) of the nodes this node links to.
    ///
    /// Links are added in both directions on insertion, but pruning trims one
    /// node's list at a time, so the graph is not guaranteed to be symmetric.
    neighbors: Vec<usize>,
}
71
72// ─────────────────────────────────────────────────────────────────────────────
73// NswConfig
74// ─────────────────────────────────────────────────────────────────────────────
75
/// Configuration for the NSW approximate nearest-neighbor graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NswConfig {
    /// Maximum number of neighbour links kept per node (default: 16).
    ///
    /// A new node links to at most this many candidates, and each of them links
    /// back. A node whose list then exceeds the limit keeps only its closest
    /// neighbours. That pruning is one-sided, so links are not guaranteed to stay
    /// bidirectional. Higher values improve recall at the cost of memory and
    /// insertion time.
    pub max_connections: usize,
    /// Beam width during search (default: 64): how many of the best candidates
    /// found so far are retained while traversing the graph. Higher values
    /// improve recall at the cost of query latency.
    pub ef_search: usize,
    /// Beam width during insertion (default: 32). Higher values improve graph
    /// quality at the cost of insertion latency.
    pub ef_construct: usize,
}

impl Default for NswConfig {
    fn default() -> Self {
        Self {
            max_connections: 16,
            ef_search: 64,
            ef_construct: 32,
        }
    }
}
100
101// ─────────────────────────────────────────────────────────────────────────────
102// NswSearchResult
103// ─────────────────────────────────────────────────────────────────────────────
104
/// A single result from an NSW nearest-neighbor search.
#[derive(Debug, Clone, PartialEq)]
pub struct NswSearchResult {
    /// The identifier the matching vector was inserted with (caller-supplied,
    /// not the node's internal position).
    pub id: usize,
    /// Cosine similarity score between the query and this node's vector.
    pub score: f32,
}
113
114// ─────────────────────────────────────────────────────────────────────────────
115// NswIndex
116// ─────────────────────────────────────────────────────────────────────────────
117
118/// Navigable Small World graph index for approximate nearest-neighbor search.
119///
120/// This is a single-layer NSW — the multi-layer hierarchical variant (HNSW) is
121/// outside scope. Performance is excellent for corpora up to ~100k entries.
122pub struct NswIndex {
123 nodes: Vec<NswNode>,
124 config: NswConfig,
125 dim: usize,
126 /// Simple deterministic counter used instead of a random entry point so
127 /// that behaviour is reproducible without the `rand` crate.
128 entry_counter: usize,
129}
130
131impl NswIndex {
132 /// Create an empty NSW index for `dim`-dimensional vectors.
133 pub fn new(dim: usize, config: NswConfig) -> Self {
134 Self {
135 nodes: Vec::new(),
136 config,
137 dim,
138 entry_counter: 0,
139 }
140 }
141
142 // ── Insertion ─────────────────────────────────────────────────────────────
143
144 /// Insert a normalised copy of `vector` with the given `id`.
145 ///
146 /// 1. Finds `ef_construct` nearest existing nodes via greedy search.
147 /// 2. Connects the new node to at most `max_connections` of them.
148 /// 3. Prunes the neighbours' connection lists if they exceed `max_connections`.
149 ///
150 /// Complexity: O(M × ef_construct) amortised where M = `max_connections`.
151 pub fn insert(&mut self, id: usize, vector: Vec<f32>) {
152 let mut v = vector;
153 // Pad or truncate to match declared dimensionality.
154 v.resize(self.dim, 0.0);
155 l2_normalize(&mut v);
156
157 let new_idx = self.nodes.len();
158
159 if self.nodes.is_empty() {
160 // First node — no edges to add yet.
161 self.nodes.push(NswNode {
162 id,
163 vector: v,
164 neighbors: Vec::new(),
165 });
166 self.entry_counter = 0;
167 return;
168 }
169
170 // Pick a deterministic entry point by rotating through existing nodes.
171 let entry = self.entry_counter % self.nodes.len();
172 self.entry_counter += 1;
173
174 // Find ef_construct nearest candidates.
175 let ef = self.config.ef_construct;
176 let candidates = self.greedy_search(&v, entry, ef);
177
178 // Keep at most max_connections neighbours.
179 let max_conn = self.config.max_connections;
180 let neighbor_indices: Vec<usize> = candidates
181 .iter()
182 .take(max_conn)
183 .map(|(node_idx, _)| *node_idx)
184 .collect();
185
186 // Add the new node.
187 self.nodes.push(NswNode {
188 id,
189 vector: v.clone(),
190 neighbors: neighbor_indices.clone(),
191 });
192
193 // Add back-edges and prune if needed.
194 for &nb_idx in &neighbor_indices {
195 self.nodes[nb_idx].neighbors.push(new_idx);
196 if self.nodes[nb_idx].neighbors.len() > max_conn {
197 self.prune_neighbors(nb_idx, max_conn);
198 }
199 }
200 }
201
202 // ── Search ────────────────────────────────────────────────────────────────
203
204 /// Return the top-`top_k` approximate nearest neighbors of `query`.
205 ///
206 /// Uses a greedy graph traversal starting from a deterministic entry point,
207 /// expanding at most `ef_search` candidates. Results are sorted by cosine
208 /// similarity in descending order.
209 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<NswSearchResult> {
210 if self.nodes.is_empty() || top_k == 0 {
211 return Vec::new();
212 }
213
214 // Normalise the query locally.
215 let mut q = query.to_vec();
216 q.resize(self.dim, 0.0);
217 l2_normalize(&mut q);
218
219 // Use node 0 as a stable entry point for search (read-only, no mutation).
220 let entry = 0;
221 let ef = self.config.ef_search;
222 let mut candidates = self.greedy_search(&q, entry, ef);
223
224 // Sort descending by score.
225 candidates
226 .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227 candidates.truncate(top_k);
228
229 candidates
230 .into_iter()
231 .map(|(node_idx, score)| NswSearchResult {
232 id: self.nodes[node_idx].id,
233 score,
234 })
235 .collect()
236 }
237
238 // ── Accessors ─────────────────────────────────────────────────────────────
239
240 /// Number of vectors stored in the index.
241 pub fn len(&self) -> usize {
242 self.nodes.len()
243 }
244
245 /// Returns `true` if the index contains no vectors.
246 pub fn is_empty(&self) -> bool {
247 self.nodes.is_empty()
248 }
249
250 /// The embedding dimensionality this index was constructed with.
251 pub fn dim(&self) -> usize {
252 self.dim
253 }
254
255 // ── Private helpers ───────────────────────────────────────────────────────
256
257 /// Greedy beam search from `entry` node, returning up to `ef` candidates
258 /// as `(node_index, cosine_similarity)` pairs.
259 ///
260 /// The implementation maintains two sets:
261 /// - `visited`: bit-set of already-explored node indices.
262 /// - `candidates`: max-heap of (score, node_idx) to explore next.
263 /// - `results`: the ef best nodes seen so far.
264 fn greedy_search(&self, query: &[f32], entry: usize, ef: usize) -> Vec<(usize, f32)> {
265 if self.nodes.is_empty() {
266 return Vec::new();
267 }
268
269 use std::cmp::Ordering;
270 use std::collections::{BinaryHeap, HashSet};
271
272 /// Wrapper to allow f32 in BinaryHeap (max-heap by score).
273 #[derive(PartialEq)]
274 struct Scored(f32, usize);
275
276 impl Eq for Scored {}
277
278 impl PartialOrd for Scored {
279 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
280 Some(self.cmp(other))
281 }
282 }
283
284 impl Ord for Scored {
285 fn cmp(&self, other: &Self) -> Ordering {
286 self.0
287 .partial_cmp(&other.0)
288 .unwrap_or(Ordering::Equal)
289 .then(self.1.cmp(&other.1))
290 }
291 }
292
293 let mut visited: HashSet<usize> = HashSet::new();
294 let entry_score = cosine_sim(query, &self.nodes[entry].vector);
295 visited.insert(entry);
296
297 // `frontier` is a max-heap of nodes to expand (best first).
298 let mut frontier: BinaryHeap<Scored> = BinaryHeap::new();
299 frontier.push(Scored(entry_score, entry));
300
301 // `results` keeps the best `ef` nodes found so far.
302 let mut results: Vec<(usize, f32)> = vec![(entry, entry_score)];
303
304 while let Some(Scored(_, node_idx)) = frontier.pop() {
305 // If results already has ef entries and the worst result in results
306 // is better than anything remaining in the frontier, we can stop.
307 if results.len() >= ef {
308 let worst_result = results
309 .iter()
310 .map(|(_, s)| *s)
311 .fold(f32::INFINITY, f32::min);
312 // All remaining frontier nodes are at most as good as `node_idx`
313 // (max-heap), so check against the worst we currently keep.
314 let node_score = results
315 .iter()
316 .find(|(i, _)| *i == node_idx)
317 .map(|(_, s)| *s)
318 .unwrap_or(f32::NEG_INFINITY);
319 if node_score < worst_result && frontier.is_empty() {
320 break;
321 }
322 }
323
324 // Expand neighbours.
325 for &nb_idx in &self.nodes[node_idx].neighbors {
326 if visited.contains(&nb_idx) {
327 continue;
328 }
329 visited.insert(nb_idx);
330
331 let nb_score = cosine_sim(query, &self.nodes[nb_idx].vector);
332 frontier.push(Scored(nb_score, nb_idx));
333 results.push((nb_idx, nb_score));
334
335 // Keep results bounded at ef (drop worst).
336 if results.len() > ef {
337 let worst_idx = results
338 .iter()
339 .enumerate()
340 .min_by(|a, b| {
341 a.1 .1
342 .partial_cmp(&b.1 .1)
343 .unwrap_or(std::cmp::Ordering::Equal)
344 })
345 .map(|(i, _)| i)
346 .expect("results is non-empty");
347 results.swap_remove(worst_idx);
348 }
349 }
350 }
351
352 results
353 }
354
355 /// Prune the neighbor list of node at `node_idx` to at most `max_conn`
356 /// connections, keeping the `max_conn` closest by cosine similarity.
357 fn prune_neighbors(&mut self, node_idx: usize, max_conn: usize) {
358 let v = self.nodes[node_idx].vector.clone();
359 let neighbors = &self.nodes[node_idx].neighbors;
360
361 // Score each current neighbour.
362 let mut scored: Vec<(usize, f32)> = neighbors
363 .iter()
364 .map(|&nb| {
365 let score = cosine_sim(&v, &self.nodes[nb].vector);
366 (nb, score)
367 })
368 .collect();
369
370 // Keep highest-scoring connections.
371 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
372 scored.truncate(max_conn);
373
374 self.nodes[node_idx].neighbors = scored.into_iter().map(|(nb, _)| nb).collect();
375 }
376}
377
378// ─────────────────────────────────────────────────────────────────────────────
379// EmbeddingIndex<T>
380// ─────────────────────────────────────────────────────────────────────────────
381
382/// Combined NSW graph index with per-entry metadata storage.
383///
384/// `T` is any cloneable metadata type — e.g. a `String` payload, a struct, or
385/// a raw identifier.
386///
387/// ```rust
388/// use oxibonsai_runtime::embedding_index::EmbeddingIndex;
389///
390/// let mut idx: EmbeddingIndex<String> = EmbeddingIndex::new(3);
391/// idx.insert(vec![1.0, 0.0, 0.0], "vec-a".to_string());
392/// idx.insert(vec![0.0, 1.0, 0.0], "vec-b".to_string());
393///
394/// let results = idx.search(&[1.0, 0.0, 0.0], 1);
395/// assert_eq!(results[0].1, &"vec-a".to_string());
396/// ```
397pub struct EmbeddingIndex<T: Clone> {
398 graph: NswIndex,
399 /// Parallel metadata store: `metadata[i] = (id, metadata_value)`.
400 metadata: Vec<(usize, T)>,
401 next_id: usize,
402}
403
404impl<T: Clone> EmbeddingIndex<T> {
405 /// Create a new index for `dim`-dimensional vectors with default NSW config.
406 pub fn new(dim: usize) -> Self {
407 Self::new_with_config(dim, NswConfig::default())
408 }
409
410 /// Create a new index with a custom [`NswConfig`].
411 pub fn new_with_config(dim: usize, config: NswConfig) -> Self {
412 Self {
413 graph: NswIndex::new(dim, config),
414 metadata: Vec::new(),
415 next_id: 0,
416 }
417 }
418
419 /// Insert a vector with associated metadata.
420 ///
421 /// Returns the stable numeric ID assigned to this entry.
422 pub fn insert(&mut self, vector: Vec<f32>, meta: T) -> usize {
423 let id = self.next_id;
424 self.next_id += 1;
425 self.graph.insert(id, vector);
426 self.metadata.push((id, meta));
427 id
428 }
429
430 /// Search for the top-`top_k` nearest neighbors of `query`.
431 ///
432 /// Returns a `Vec` of `(NswSearchResult, &T)` pairs sorted by descending
433 /// cosine similarity.
434 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(NswSearchResult, &T)> {
435 let results = self.graph.search(query, top_k);
436 results
437 .into_iter()
438 .filter_map(|r| {
439 // Look up metadata by id.
440 self.metadata
441 .iter()
442 .find(|(id, _)| *id == r.id)
443 .map(|(_, meta)| (r, meta))
444 })
445 .collect()
446 }
447
448 /// Number of entries in the index.
449 pub fn len(&self) -> usize {
450 self.graph.len()
451 }
452
453 /// Returns `true` if the index contains no entries.
454 pub fn is_empty(&self) -> bool {
455 self.graph.is_empty()
456 }
457}
458
459// ─────────────────────────────────────────────────────────────────────────────
460// Tests
461// ─────────────────────────────────────────────────────────────────────────────
462
#[cfg(test)]
mod tests {
    use super::*;

    fn unit_vec(values: &[f32]) -> Vec<f32> {
        let mut v = values.to_vec();
        l2_normalize(&mut v);
        v
    }

    // ── Math helpers ──────────────────────────────────────────────────────────

    #[test]
    fn test_cosine_sim_edge_cases() {
        assert_eq!(cosine_sim(&[], &[]), 0.0);
        assert_eq!(cosine_sim(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(cosine_sim(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        // Un-normalised inputs are clamped into [-1, 1].
        assert_eq!(cosine_sim(&[2.0], &[2.0]), 1.0);
        assert_eq!(cosine_sim(&[2.0], &[-2.0]), -1.0);
    }

    #[test]
    fn test_l2_normalize_leaves_zero_vector() {
        let mut v = vec![0.0f32; 3];
        l2_normalize(&mut v);
        assert_eq!(v, vec![0.0f32; 3]);
    }

    // ── NswIndex ──────────────────────────────────────────────────────────────

    #[test]
    fn test_nsw_index_empty() {
        let idx = NswIndex::new(4, NswConfig::default());
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert_eq!(idx.dim(), 4);
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_nsw_index_single_insert() {
        let mut idx = NswIndex::new(4, NswConfig::default());
        idx.insert(0, vec![1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
        assert!(!idx.is_empty());
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_top_k_zero() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        idx.insert(0, vec![1.0, 0.0]);
        assert!(idx.search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_nsw_index_pads_and_truncates() {
        let mut idx = NswIndex::new(4, NswConfig::default());
        // Padded to [2, 0, 0, 0] and normalised to [1, 0, 0, 0].
        idx.insert(7, vec![2.0]);
        // The trailing component is truncated away.
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0, 5.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 7);
        assert!((results[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_nsw_index_search_exact() {
        let mut idx = NswIndex::new(3, NswConfig::default());
        let v = unit_vec(&[1.0, 2.0, 3.0]);
        idx.insert(42, v.clone());
        let results = idx.search(&v, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 42);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "score={}",
            results[0].score
        );
    }

    #[test]
    fn test_nsw_index_search_nearest() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        // Insert three vectors; query is closest to id=1.
        idx.insert(0, unit_vec(&[1.0, 0.0])); // along x-axis
        idx.insert(1, unit_vec(&[0.0, 1.0])); // along y-axis
        idx.insert(2, unit_vec(&[-1.0, 0.0])); // negative x-axis

        let query = unit_vec(&[0.1, 0.9]); // close to y-axis
        let results = idx.search(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].id, 1,
            "nearest should be y-axis vector, got id={}",
            results[0].id
        );
    }

    #[test]
    fn test_nsw_index_small_graph_returns_all() {
        let mut idx = NswIndex::new(2, NswConfig::default());
        for (i, angle) in [0.0f32, 0.5, 1.0, 1.5, 2.0].iter().enumerate() {
            idx.insert(i, vec![angle.cos(), angle.sin()]);
        }
        let results = idx.search(&[1.0, 0.0], 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].id, 0);
        let mut ids: Vec<usize> = results.iter().map(|r| r.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_nsw_index_many_vectors() {
        let dim = 8;
        let config = NswConfig {
            max_connections: 8,
            ef_search: 32,
            ef_construct: 16,
        };
        let mut idx = NswIndex::new(dim, config);

        // Insert 100 random-ish deterministic vectors.
        for i in 0..100usize {
            let mut v: Vec<f32> = (0..dim)
                .map(|d| {
                    // deterministic pseudo-random using wrapping arithmetic
                    let x = (i as u64)
                        .wrapping_mul(6364136223846793005u64)
                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407u64));
                    let x = x ^ (x >> 33);
                    let x = x.wrapping_mul(0xff51afd7ed558ccdu64);
                    let x = x ^ (x >> 33);
                    (x as i64) as f32 / i64::MAX as f32
                })
                .collect();
            l2_normalize(&mut v);
            idx.insert(i, v);
        }

        assert_eq!(idx.len(), 100);

        // A known query: a unit vector along the first dimension.
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let results = idx.search(&query, 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // Scores should be in descending order.
        for w in results.windows(2) {
            assert!(
                w[0].score >= w[1].score - 1e-5,
                "scores not sorted: {} < {}",
                w[0].score,
                w[1].score
            );
        }
    }

    // ── EmbeddingIndex ────────────────────────────────────────────────────────

    #[test]
    fn test_embedding_index_insert_and_search() {
        let mut idx: EmbeddingIndex<u32> = EmbeddingIndex::new(4);
        idx.insert(unit_vec(&[1.0, 0.0, 0.0, 0.0]), 100);
        idx.insert(unit_vec(&[0.0, 1.0, 0.0, 0.0]), 200);
        idx.insert(unit_vec(&[0.0, 0.0, 1.0, 0.0]), 300);

        let results = idx.search(&unit_vec(&[1.0, 0.0, 0.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(*results[0].1, 100u32);
    }

    #[test]
    fn test_embedding_index_metadata_returned() {
        let mut idx: EmbeddingIndex<String> = EmbeddingIndex::new(3);
        let id = idx.insert(unit_vec(&[1.0, 1.0, 0.0]), "hello world".to_string());
        assert_eq!(id, 0);
        let results = idx.search(&unit_vec(&[1.0, 1.0, 0.0]), 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, &"hello world".to_string());
        assert!((results[0].0.score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_embedding_index_sequential_ids() {
        let mut idx: EmbeddingIndex<&str> = EmbeddingIndex::new(2);
        assert!(idx.is_empty());
        assert_eq!(idx.insert(vec![1.0, 0.0], "a"), 0);
        assert_eq!(idx.insert(vec![0.0, 1.0], "b"), 1);
        assert_eq!(idx.insert(vec![-1.0, 0.0], "c"), 2);
        assert_eq!(idx.len(), 3);

        let results = idx.search(&[0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.id, 1);
        assert_eq!(*results[0].1, "b");
    }

    #[test]
    fn test_nsw_config_defaults() {
        let cfg = NswConfig::default();
        assert_eq!(cfg.max_connections, 16);
        assert_eq!(cfg.ef_search, 64);
        assert_eq!(cfg.ef_construct, 32);
    }
}