Skip to main content

oxios_kernel/memory/
hnsw.rs

1//! HNSW-based approximate nearest neighbor index via `usearch`.
2//!
3//! Wraps the usearch library to provide a Rust-friendly HNSW index for
4//! high-dimensional dense vector search. Supports persistence (save/load),
5//! add/remove, and k-NN search.
6
7use std::path::Path;
8
9use anyhow::{Context, Result};
10use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
11use usearch::Index;
12
/// Vector width used by [`HnswIndex::with_default_dims`]; 1536 matches the
/// output size of OpenAI `text-embedding-3-small`.
const DEFAULT_DIMENSIONS: usize = 1536;

/// HNSW graph connectivity (edges kept per node), passed to usearch as
/// `connectivity`.
const DEFAULT_CONNECTIVITY: usize = 16;

/// Expansion factor applied while inserting, passed to usearch as
/// `expansion_add`.
const DEFAULT_EXPANSION_ADD: usize = 128;

/// Expansion factor applied while querying, passed to usearch as
/// `expansion_search`.
const DEFAULT_EXPANSION_SEARCH: usize = 128;
24
25// ---------------------------------------------------------------------------
26// HnswIndex
27// ---------------------------------------------------------------------------
28
29/// HNSW approximate nearest neighbor index.
30///
31/// Wraps `usearch::Index` and provides a type-safe, ergonomic interface
32/// for dense vector operations. The index is not thread-safe internally —
33/// callers must synchronize access (e.g., via `parking_lot::RwLock`).
34pub struct HnswIndex {
35    /// Underlying usearch index.
36    index: Index,
37    /// Vector dimensions.
38    dimensions: usize,
39}
40
41impl std::fmt::Debug for HnswIndex {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("HnswIndex")
44            .field("dimensions", &self.dimensions)
45            .field("size", &self.len())
46            .finish()
47    }
48}
49
50impl HnswIndex {
51    /// Create a new HNSW index.
52    ///
53    /// # Arguments
54    /// * `dimensions` — Dimensionality of vectors (e.g., 1536 for OpenAI).
55    /// * `capacity` — Initial capacity hint (pre-allocated slots).
56    pub fn new(dimensions: usize, capacity: usize) -> Result<Self> {
57        let options = IndexOptions {
58            dimensions,
59            metric: MetricKind::Cos,
60            quantization: ScalarKind::F32,
61            connectivity: DEFAULT_CONNECTIVITY,
62            expansion_add: DEFAULT_EXPANSION_ADD,
63            expansion_search: DEFAULT_EXPANSION_SEARCH,
64            multi: false,
65        };
66
67        let index = Index::new(&options).context("Failed to create HNSW index")?;
68        if capacity > 0 {
69            index
70                .reserve(capacity)
71                .map_err(|e| anyhow::anyhow!("Failed to reserve HNSW capacity: {}", e))?;
72        }
73
74        Ok(Self { index, dimensions })
75    }
76
77    /// Create with default dimensions (1536).
78    pub fn with_default_dims(capacity: usize) -> Result<Self> {
79        Self::new(DEFAULT_DIMENSIONS, capacity)
80    }
81
82    /// Add a vector to the index with the given key.
83    ///
84    /// The key is a u64 identifier. Callers should maintain a mapping
85    /// from u64 key to logical ID (e.g., via SQLite).
86    pub fn add(&self, key: u64, vector: &[f32]) -> Result<()> {
87        anyhow::ensure!(
88            vector.len() == self.dimensions,
89            "Vector dimension mismatch: expected {}, got {}",
90            self.dimensions,
91            vector.len()
92        );
93        self.index
94            .add(key, vector)
95            .map_err(|e| anyhow::anyhow!("HNSW add failed for key {}: {}", key, e))?;
96        Ok(())
97    }
98
99    /// Search for the k nearest neighbors of the query vector.
100    ///
101    /// Returns a sorted list of (key, distance) pairs.
102    /// Distance is cosine distance (0.0 = identical for normalized vectors).
103    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
104        anyhow::ensure!(
105            query.len() == self.dimensions,
106            "Query dimension mismatch: expected {}, got {}",
107            self.dimensions,
108            query.len()
109        );
110        if k == 0 {
111            return Ok(Vec::new());
112        }
113
114        let results = self
115            .index
116            .search(query, k)
117            .map_err(|e| anyhow::anyhow!("HNSW search failed: {}", e))?;
118
119        Ok(results
120            .keys
121            .into_iter()
122            .zip(results.distances)
123            .filter(|(k, _)| *k != 0)
124            .collect())
125    }
126
127    /// Remove a vector by key.
128    pub fn remove(&self, key: u64) -> Result<()> {
129        self.index
130            .remove(key)
131            .map(|_| ())
132            .map_err(|e| anyhow::anyhow!("HNSW remove failed for key {}: {}", key, e))
133    }
134
135    /// Check if a key exists in the index.
136    pub fn contains(&self, key: u64) -> bool {
137        self.index.contains(key)
138    }
139
140    /// Get the vector stored for a key.
141    pub fn get(&self, key: u64) -> Option<Vec<f32>> {
142        let mut buffer = vec![0.0f32; self.dimensions];
143        match self.index.get(key, &mut buffer) {
144            Ok(count) if count > 0 => Some(buffer),
145            _ => None,
146        }
147    }
148
149    /// Number of vectors currently in the index.
150    pub fn len(&self) -> usize {
151        self.index.size()
152    }
153
154    /// Whether the index is empty.
155    pub fn is_empty(&self) -> bool {
156        self.len() == 0
157    }
158
159    /// Vector dimensions.
160    pub fn dimensions(&self) -> usize {
161        self.dimensions
162    }
163
164    /// Save the index to a file.
165    pub fn save(&self, path: &Path) -> Result<()> {
166        let path_str = path.to_str().ok_or_else(|| {
167            anyhow::anyhow!("HNSW save path is not valid UTF-8: {}", path.display())
168        })?;
169        self.index
170            .save(path_str)
171            .map_err(|e| anyhow::anyhow!("HNSW save failed: {}", e))?;
172        Ok(())
173    }
174
175    /// Load (restore) an index from a file.
176    ///
177    /// Returns a new `HnswIndex` with the same dimensions as the saved index.
178    pub fn load(path: &Path) -> Result<Self> {
179        let path_str = path.to_str().ok_or_else(|| {
180            anyhow::anyhow!("HNSW load path is not valid UTF-8: {}", path.display())
181        })?;
182        let index =
183            Index::restore(path_str).map_err(|e| anyhow::anyhow!("HNSW load failed: {}", e))?;
184        let dimensions = index.dimensions();
185        Ok(Self { index, dimensions })
186    }
187
188    /// Reserve additional capacity.
189    pub fn reserve(&self, capacity: usize) -> Result<()> {
190        self.index
191            .reserve(capacity)
192            .map_err(|e| anyhow::anyhow!("HNSW reserve failed: {}", e))?;
193        Ok(())
194    }
195
196    /// Rename a key (reassign vector from old key to new key).
197    pub fn rename(&self, from: u64, to: u64) -> Result<()> {
198        self.index
199            .rename(from, to)
200            .map(|_| ())
201            .map_err(|e| anyhow::anyhow!("HNSW rename failed: {} -> {}: {}", from, to, e))
202    }
203}
204
205// ---------------------------------------------------------------------------
206// Tests
207// ---------------------------------------------------------------------------
208
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Unit basis vectors of R^3, reused across the 3-dimensional tests.
    const E1: [f32; 3] = [1.0, 0.0, 0.0];
    const E2: [f32; 3] = [0.0, 1.0, 0.0];
    const E3: [f32; 3] = [0.0, 0.0, 1.0];

    #[test]
    fn test_hnsw_add_and_search() {
        let index = HnswIndex::new(3, 100).unwrap();
        for (key, vector) in [(1u64, E1), (2, E2), (3, E3)] {
            index.add(key, &vector).unwrap();
        }
        assert_eq!(index.len(), 3);

        let hits = index.search(&E1, 1).unwrap();
        assert_eq!(hits.len(), 1);
        let (best_key, best_distance) = hits[0];
        assert_eq!(best_key, 1);
        // Querying with a stored vector should give a cosine distance of ~0.
        assert!(
            best_distance < 0.01,
            "Distance should be ~0, got {}",
            best_distance
        );
    }

    #[test]
    fn test_hnsw_search_multiple() {
        let index = HnswIndex::new(4, 100).unwrap();

        // Two clusters: {1, 2} near the first axis, {3, 4} near the second.
        let points: [(u64, [f32; 4]); 4] = [
            (1, [1.0, 0.0, 0.0, 0.0]),
            (2, [0.9, 0.1, 0.0, 0.0]),
            (3, [0.0, 1.0, 0.0, 0.0]),
            (4, [0.0, 0.9, 0.1, 0.0]),
        ];
        for (key, vector) in &points {
            index.add(*key, vector).unwrap();
        }

        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        let keys: Vec<u64> = hits.iter().map(|&(key, _)| key).collect();
        // The exact match comes first, then its cluster neighbour.
        assert_eq!(keys, vec![1, 2]);
    }

    #[test]
    fn test_hnsw_dimension_mismatch() {
        let index = HnswIndex::new(3, 10).unwrap();
        // Two components where three are expected.
        assert!(index.add(1, &[1.0, 0.0]).is_err());
    }

    #[test]
    fn test_hnsw_save_and_load() {
        let scratch = TempDir::new().unwrap();
        let file = scratch.path().join("test.usearch");

        let original = HnswIndex::new(3, 100).unwrap();
        original.add(1, &E1).unwrap();
        original.add(2, &E2).unwrap();
        original.save(&file).unwrap();
        drop(original);

        let restored = HnswIndex::load(&file).unwrap();
        assert_eq!(restored.len(), 2);
        assert_eq!(restored.dimensions(), 3);

        let hits = restored.search(&E1, 1).unwrap();
        assert_eq!(hits[0].0, 1);
    }

    #[test]
    fn test_hnsw_contains() {
        let index = HnswIndex::new(3, 10).unwrap();
        assert!(!index.contains(1));

        index.add(1, &E1).unwrap();
        assert!(index.contains(1));
        assert!(!index.contains(2));
    }

    #[test]
    fn test_hnsw_remove() {
        let index = HnswIndex::new(3, 100).unwrap();
        index.add(1, &E1).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(1).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_hnsw_empty_search() {
        let index = HnswIndex::new(3, 10).unwrap();
        assert!(index.search(&E1, 5).unwrap().is_empty());
    }

    #[test]
    fn test_hnsw_with_default_dims() {
        let dims = HnswIndex::with_default_dims(100).unwrap().dimensions();
        assert_eq!(dims, 1536);
    }
}