Skip to main content

mnemo_core/index/
usearch.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::RwLock;
4
5use crate::error::{Error, Result};
6use crate::index::VectorIndex;
7use uuid::Uuid;
8
9pub struct UsearchIndex {
10    index: RwLock<usearch::Index>,
11    uuid_to_key: RwLock<HashMap<Uuid, u64>>,
12    key_to_uuid: RwLock<HashMap<u64, Uuid>>,
13    next_key: RwLock<u64>,
14    dimensions: usize,
15}
16
17impl UsearchIndex {
18    pub fn new(dimensions: usize) -> Result<Self> {
19        let opts = usearch::IndexOptions {
20            dimensions,
21            metric: usearch::MetricKind::Cos,
22            quantization: usearch::ScalarKind::F32,
23            ..Default::default()
24        };
25        let index = usearch::Index::new(&opts).map_err(|e| Error::Index(e.to_string()))?;
26        index
27            .reserve(10_000)
28            .map_err(|e| Error::Index(e.to_string()))?;
29
30        Ok(Self {
31            index: RwLock::new(index),
32            uuid_to_key: RwLock::new(HashMap::new()),
33            key_to_uuid: RwLock::new(HashMap::new()),
34            next_key: RwLock::new(0),
35            dimensions,
36        })
37    }
38
39    fn allocate_key(&self, id: Uuid) -> u64 {
40        let mut next = self.next_key.write().unwrap_or_else(|e| e.into_inner());
41        let key = *next;
42        *next += 1;
43        self.uuid_to_key
44            .write()
45            .unwrap_or_else(|e| e.into_inner())
46            .insert(id, key);
47        self.key_to_uuid
48            .write()
49            .unwrap_or_else(|e| e.into_inner())
50            .insert(key, id);
51        key
52    }
53
54    fn rollback_key(&self, id: Uuid, key: u64) {
55        self.uuid_to_key
56            .write()
57            .unwrap_or_else(|e| e.into_inner())
58            .remove(&id);
59        self.key_to_uuid
60            .write()
61            .unwrap_or_else(|e| e.into_inner())
62            .remove(&key);
63    }
64}
65
66impl VectorIndex for UsearchIndex {
67    fn add(&self, id: Uuid, vector: &[f32]) -> Result<()> {
68        if vector.len() != self.dimensions {
69            return Err(Error::Validation(format!(
70                "expected {} dimensions, got {}",
71                self.dimensions,
72                vector.len()
73            )));
74        }
75
76        // If this UUID already exists, remove it first
77        if self
78            .uuid_to_key
79            .read()
80            .unwrap_or_else(|e| e.into_inner())
81            .contains_key(&id)
82        {
83            self.remove(id)?;
84        }
85
86        let key = self.allocate_key(id);
87        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
88
89        // Grow capacity if needed
90        if index.size() >= index.capacity() {
91            index
92                .reserve(index.capacity() + 10_000)
93                .map_err(|e| Error::Index(e.to_string()))?;
94        }
95
96        if let Err(e) = index.add(key, vector) {
97            // Rollback orphaned mappings on add failure
98            drop(index);
99            self.rollback_key(id, key);
100            return Err(Error::Index(e.to_string()));
101        }
102        Ok(())
103    }
104
105    fn remove(&self, id: Uuid) -> Result<()> {
106        let key = {
107            let map = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
108            match map.get(&id) {
109                Some(&k) => k,
110                None => return Ok(()),
111            }
112        };
113
114        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
115        index.remove(key).map_err(|e| Error::Index(e.to_string()))?;
116
117        self.uuid_to_key
118            .write()
119            .unwrap_or_else(|e| e.into_inner())
120            .remove(&id);
121        self.key_to_uuid
122            .write()
123            .unwrap_or_else(|e| e.into_inner())
124            .remove(&key);
125        Ok(())
126    }
127
128    fn search(&self, query: &[f32], limit: usize) -> Result<Vec<(Uuid, f32)>> {
129        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
130        let results = index
131            .search(query, limit)
132            .map_err(|e| Error::Index(e.to_string()))?;
133
134        let key_map = self.key_to_uuid.read().unwrap_or_else(|e| e.into_inner());
135        let mut output = Vec::new();
136        for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
137            if let Some(&uuid) = key_map.get(key) {
138                output.push((uuid, *distance));
139            }
140        }
141        Ok(output)
142    }
143
144    fn filtered_search(
145        &self,
146        query: &[f32],
147        limit: usize,
148        filter: &dyn Fn(Uuid) -> bool,
149    ) -> Result<Vec<(Uuid, f32)>> {
150        let index_size = self.len();
151        if index_size == 0 {
152            return Ok(Vec::new());
153        }
154        // Iterative oversample: start at 3x, double until we have enough or hit index size
155        let mut oversample = (limit * 3).max(1);
156        loop {
157            let results = self.search(query, oversample.min(index_size))?;
158            let filtered: Vec<(Uuid, f32)> = results
159                .into_iter()
160                .filter(|(uuid, _)| filter(*uuid))
161                .take(limit)
162                .collect();
163            if filtered.len() >= limit || oversample >= index_size {
164                return Ok(filtered);
165            }
166            oversample = (oversample * 2).min(index_size);
167        }
168    }
169
170    fn save(&self, path: &Path) -> Result<()> {
171        let path_str = path
172            .to_str()
173            .ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
174        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
175        index
176            .save(path_str)
177            .map_err(|e| Error::Index(e.to_string()))?;
178
179        // Save mappings alongside
180        let mappings_path = path.with_extension("mappings.json");
181        let uuid_to_key = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
182        let next_key = *self.next_key.read().unwrap_or_else(|e| e.into_inner());
183        let data = serde_json::json!({
184            "uuid_to_key": uuid_to_key.iter().map(|(k, v)| (k.to_string(), v)).collect::<HashMap<String, &u64>>(),
185            "next_key": next_key,
186        });
187        let json_str = serde_json::to_string(&data).map_err(|e| Error::Index(e.to_string()))?;
188        std::fs::write(&mappings_path, json_str).map_err(|e| Error::Index(e.to_string()))?;
189        Ok(())
190    }
191
192    fn load(&self, path: &Path) -> Result<()> {
193        let path_str = path
194            .to_str()
195            .ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
196        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
197        index
198            .load(path_str)
199            .map_err(|e| Error::Index(e.to_string()))?;
200
201        // Load mappings
202        let mappings_path = path.with_extension("mappings.json");
203        if mappings_path.exists() {
204            let data =
205                std::fs::read_to_string(&mappings_path).map_err(|e| Error::Index(e.to_string()))?;
206            let parsed: serde_json::Value =
207                serde_json::from_str(&data).map_err(|e| Error::Index(e.to_string()))?;
208
209            let mut uuid_to_key = self.uuid_to_key.write().unwrap_or_else(|e| e.into_inner());
210            let mut key_to_uuid = self.key_to_uuid.write().unwrap_or_else(|e| e.into_inner());
211            let mut next_key = self.next_key.write().unwrap_or_else(|e| e.into_inner());
212
213            uuid_to_key.clear();
214            key_to_uuid.clear();
215
216            if let Some(map) = parsed["uuid_to_key"].as_object() {
217                for (uuid_str, key_val) in map {
218                    let uuid =
219                        Uuid::parse_str(uuid_str).map_err(|e| Error::Index(e.to_string()))?;
220                    let key = key_val.as_u64().ok_or_else(|| {
221                        Error::Index(format!("invalid key value for UUID {uuid_str}"))
222                    })?;
223                    uuid_to_key.insert(uuid, key);
224                    key_to_uuid.insert(key, uuid);
225                }
226            }
227
228            if let Some(nk) = parsed["next_key"].as_u64() {
229                *next_key = nk;
230            }
231        }
232        Ok(())
233    }
234
235    fn len(&self) -> usize {
236        let index = self.index.read().unwrap_or_else(|e| e.into_inner());
237        index.size()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic, unit-length vector of `dims` components from
    /// `seed` using a simple multiplicative congruential sequence.
    fn random_vector(dims: usize, seed: u64) -> Vec<f32> {
        // Simple deterministic pseudo-random
        let mut v = Vec::with_capacity(dims);
        let mut x = seed;
        for _ in 0..dims {
            x = x.wrapping_mul(6364136223846793005).wrapping_add(1);
            v.push((x as f32) / (u64::MAX as f32));
        }
        // Normalize
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        v
    }

    /// Adding 100 vectors makes them all countable, and querying with a
    /// stored vector returns that vector's UUID first.
    #[test]
    fn test_add_and_search() {
        let index = UsearchIndex::new(128).unwrap();

        let mut ids = Vec::new();
        let mut vectors = Vec::new();
        for i in 0..100 {
            let id = Uuid::now_v7();
            let vec = random_vector(128, i);
            index.add(id, &vec).unwrap();
            ids.push(id);
            vectors.push(vec);
        }

        assert_eq!(index.len(), 100);

        // Search with the first vector should return itself as nearest
        let results = index.search(&vectors[0], 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, ids[0]);
    }

    /// Removing the only stored vector brings the length back to zero.
    #[test]
    fn test_remove() {
        let index = UsearchIndex::new(128).unwrap();
        let id = Uuid::now_v7();
        let vec = random_vector(128, 42);

        index.add(id, &vec).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    /// A filtered search returns only UUIDs accepted by the filter, returns
    /// at least one of them, and never exceeds the requested limit.
    #[test]
    fn test_filtered_search() {
        let index = UsearchIndex::new(128).unwrap();

        let mut ids = Vec::new();
        for i in 0..50 {
            let id = Uuid::now_v7();
            let vec = random_vector(128, i);
            index.add(id, &vec).unwrap();
            ids.push(id);
        }

        // Filter out all even-indexed IDs
        let excluded: std::collections::HashSet<Uuid> = ids.iter().step_by(2).copied().collect();
        let query = random_vector(128, 0);
        let results = index
            .filtered_search(&query, 10, &|id| !excluded.contains(&id))
            .unwrap();

        // Without these checks the loop below would pass vacuously on an
        // empty result set.
        assert!(!results.is_empty());
        assert!(results.len() <= 10);

        // All results should be odd-indexed
        for (id, _) in &results {
            assert!(!excluded.contains(id));
        }
    }

    /// An index saved to disk and loaded into a fresh instance keeps its
    /// size and still resolves search hits to the original UUIDs.
    #[test]
    fn test_save_and_load() {
        let dir = std::env::temp_dir().join(format!("usearch_test_{}", Uuid::now_v7()));
        std::fs::create_dir_all(&dir).unwrap();
        let index_path = dir.join("test.usearch");

        let index = UsearchIndex::new(128).unwrap();
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();
        index.add(id1, &random_vector(128, 1)).unwrap();
        index.add(id2, &random_vector(128, 2)).unwrap();

        index.save(&index_path).unwrap();

        // Load into a new index
        let index2 = UsearchIndex::new(128).unwrap();
        index2.load(&index_path).unwrap();
        assert_eq!(index2.len(), 2);

        // Search should still work
        let results = index2.search(&random_vector(128, 1), 1).unwrap();
        assert_eq!(results[0].0, id1);

        // Cleanup
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Adding a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let index = UsearchIndex::new(128).unwrap();
        let result = index.add(Uuid::now_v7(), &vec![0.1; 64]);
        assert!(result.is_err());
    }
}