mnemo_core/index/
usearch.rs1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::RwLock;
4
5use crate::error::{Error, Result};
6use crate::index::VectorIndex;
7use uuid::Uuid;
8
9pub struct UsearchIndex {
10 index: RwLock<usearch::Index>,
11 uuid_to_key: RwLock<HashMap<Uuid, u64>>,
12 key_to_uuid: RwLock<HashMap<u64, Uuid>>,
13 next_key: RwLock<u64>,
14 dimensions: usize,
15}
16
17impl UsearchIndex {
18 pub fn new(dimensions: usize) -> Result<Self> {
19 let opts = usearch::IndexOptions {
20 dimensions,
21 metric: usearch::MetricKind::Cos,
22 quantization: usearch::ScalarKind::F32,
23 ..Default::default()
24 };
25 let index = usearch::Index::new(&opts).map_err(|e| Error::Index(e.to_string()))?;
26 index
27 .reserve(10_000)
28 .map_err(|e| Error::Index(e.to_string()))?;
29
30 Ok(Self {
31 index: RwLock::new(index),
32 uuid_to_key: RwLock::new(HashMap::new()),
33 key_to_uuid: RwLock::new(HashMap::new()),
34 next_key: RwLock::new(0),
35 dimensions,
36 })
37 }
38
39 fn allocate_key(&self, id: Uuid) -> u64 {
40 let mut next = self.next_key.write().unwrap_or_else(|e| e.into_inner());
41 let key = *next;
42 *next += 1;
43 self.uuid_to_key
44 .write()
45 .unwrap_or_else(|e| e.into_inner())
46 .insert(id, key);
47 self.key_to_uuid
48 .write()
49 .unwrap_or_else(|e| e.into_inner())
50 .insert(key, id);
51 key
52 }
53
54 fn rollback_key(&self, id: Uuid, key: u64) {
55 self.uuid_to_key
56 .write()
57 .unwrap_or_else(|e| e.into_inner())
58 .remove(&id);
59 self.key_to_uuid
60 .write()
61 .unwrap_or_else(|e| e.into_inner())
62 .remove(&key);
63 }
64}
65
66impl VectorIndex for UsearchIndex {
67 fn add(&self, id: Uuid, vector: &[f32]) -> Result<()> {
68 if vector.len() != self.dimensions {
69 return Err(Error::Validation(format!(
70 "expected {} dimensions, got {}",
71 self.dimensions,
72 vector.len()
73 )));
74 }
75
76 if self
78 .uuid_to_key
79 .read()
80 .unwrap_or_else(|e| e.into_inner())
81 .contains_key(&id)
82 {
83 self.remove(id)?;
84 }
85
86 let key = self.allocate_key(id);
87 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
88
89 if index.size() >= index.capacity() {
91 index
92 .reserve(index.capacity() + 10_000)
93 .map_err(|e| Error::Index(e.to_string()))?;
94 }
95
96 if let Err(e) = index.add(key, vector) {
97 drop(index);
99 self.rollback_key(id, key);
100 return Err(Error::Index(e.to_string()));
101 }
102 Ok(())
103 }
104
105 fn remove(&self, id: Uuid) -> Result<()> {
106 let key = {
107 let map = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
108 match map.get(&id) {
109 Some(&k) => k,
110 None => return Ok(()),
111 }
112 };
113
114 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
115 index.remove(key).map_err(|e| Error::Index(e.to_string()))?;
116
117 self.uuid_to_key
118 .write()
119 .unwrap_or_else(|e| e.into_inner())
120 .remove(&id);
121 self.key_to_uuid
122 .write()
123 .unwrap_or_else(|e| e.into_inner())
124 .remove(&key);
125 Ok(())
126 }
127
128 fn search(&self, query: &[f32], limit: usize) -> Result<Vec<(Uuid, f32)>> {
129 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
130 let results = index
131 .search(query, limit)
132 .map_err(|e| Error::Index(e.to_string()))?;
133
134 let key_map = self.key_to_uuid.read().unwrap_or_else(|e| e.into_inner());
135 let mut output = Vec::new();
136 for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
137 if let Some(&uuid) = key_map.get(key) {
138 output.push((uuid, *distance));
139 }
140 }
141 Ok(output)
142 }
143
144 fn filtered_search(
145 &self,
146 query: &[f32],
147 limit: usize,
148 filter: &dyn Fn(Uuid) -> bool,
149 ) -> Result<Vec<(Uuid, f32)>> {
150 let index_size = self.len();
151 if index_size == 0 {
152 return Ok(Vec::new());
153 }
154 let mut oversample = (limit * 3).max(1);
156 loop {
157 let results = self.search(query, oversample.min(index_size))?;
158 let filtered: Vec<(Uuid, f32)> = results
159 .into_iter()
160 .filter(|(uuid, _)| filter(*uuid))
161 .take(limit)
162 .collect();
163 if filtered.len() >= limit || oversample >= index_size {
164 return Ok(filtered);
165 }
166 oversample = (oversample * 2).min(index_size);
167 }
168 }
169
170 fn save(&self, path: &Path) -> Result<()> {
171 let path_str = path
172 .to_str()
173 .ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
174 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
175 index
176 .save(path_str)
177 .map_err(|e| Error::Index(e.to_string()))?;
178
179 let mappings_path = path.with_extension("mappings.json");
181 let uuid_to_key = self.uuid_to_key.read().unwrap_or_else(|e| e.into_inner());
182 let next_key = *self.next_key.read().unwrap_or_else(|e| e.into_inner());
183 let data = serde_json::json!({
184 "uuid_to_key": uuid_to_key.iter().map(|(k, v)| (k.to_string(), v)).collect::<HashMap<String, &u64>>(),
185 "next_key": next_key,
186 });
187 let json_str = serde_json::to_string(&data).map_err(|e| Error::Index(e.to_string()))?;
188 std::fs::write(&mappings_path, json_str).map_err(|e| Error::Index(e.to_string()))?;
189 Ok(())
190 }
191
192 fn load(&self, path: &Path) -> Result<()> {
193 let path_str = path
194 .to_str()
195 .ok_or_else(|| Error::Index("non-UTF-8 index path".to_string()))?;
196 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
197 index
198 .load(path_str)
199 .map_err(|e| Error::Index(e.to_string()))?;
200
201 let mappings_path = path.with_extension("mappings.json");
203 if mappings_path.exists() {
204 let data =
205 std::fs::read_to_string(&mappings_path).map_err(|e| Error::Index(e.to_string()))?;
206 let parsed: serde_json::Value =
207 serde_json::from_str(&data).map_err(|e| Error::Index(e.to_string()))?;
208
209 let mut uuid_to_key = self.uuid_to_key.write().unwrap_or_else(|e| e.into_inner());
210 let mut key_to_uuid = self.key_to_uuid.write().unwrap_or_else(|e| e.into_inner());
211 let mut next_key = self.next_key.write().unwrap_or_else(|e| e.into_inner());
212
213 uuid_to_key.clear();
214 key_to_uuid.clear();
215
216 if let Some(map) = parsed["uuid_to_key"].as_object() {
217 for (uuid_str, key_val) in map {
218 let uuid =
219 Uuid::parse_str(uuid_str).map_err(|e| Error::Index(e.to_string()))?;
220 let key = key_val.as_u64().ok_or_else(|| {
221 Error::Index(format!("invalid key value for UUID {uuid_str}"))
222 })?;
223 uuid_to_key.insert(uuid, key);
224 key_to_uuid.insert(key, uuid);
225 }
226 }
227
228 if let Some(nk) = parsed["next_key"].as_u64() {
229 *next_key = nk;
230 }
231 }
232 Ok(())
233 }
234
235 fn len(&self) -> usize {
236 let index = self.index.read().unwrap_or_else(|e| e.into_inner());
237 index.size()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::path::PathBuf;

    /// Deterministic pseudo-random unit-length vector derived from `seed` with a
    /// 64-bit LCG, so tests are reproducible without a `rand` dependency.
    fn random_vector(dims: usize, seed: u64) -> Vec<f32> {
        let mut v = Vec::with_capacity(dims);
        let mut x = seed;
        for _ in 0..dims {
            x = x.wrapping_mul(6364136223846793005).wrapping_add(1);
            v.push((x as f32) / (u64::MAX as f32));
        }
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        v
    }

    /// Uniquely named scratch directory that is deleted on drop, so it is cleaned
    /// up even when an assertion fails partway through a test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new() -> Self {
            let dir = std::env::temp_dir().join(format!("usearch_test_{}", Uuid::now_v7()));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn test_add_and_search() {
        let index = UsearchIndex::new(128).unwrap();

        let mut ids = Vec::new();
        let mut vectors = Vec::new();
        for i in 0..100 {
            let id = Uuid::now_v7();
            let vec = random_vector(128, i);
            index.add(id, &vec).unwrap();
            ids.push(id);
            vectors.push(vec);
        }

        assert_eq!(index.len(), 100);

        let results = index.search(&vectors[0], 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].0, ids[0]);
    }

    #[test]
    fn test_readd_replaces_vector() {
        let index = UsearchIndex::new(128).unwrap();
        let id = Uuid::now_v7();

        index.add(id, &random_vector(128, 1)).unwrap();
        index.add(id, &random_vector(128, 2)).unwrap();
        assert_eq!(index.len(), 1);

        let results = index.search(&random_vector(128, 2), 1).unwrap();
        assert_eq!(results.first().map(|(uuid, _)| *uuid), Some(id));
    }

    #[test]
    fn test_remove() {
        let index = UsearchIndex::new(128).unwrap();
        let id = Uuid::now_v7();
        let vec = random_vector(128, 42);

        index.add(id, &vec).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);

        // Removing an id that was never added is a no-op.
        index.remove(Uuid::now_v7()).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_filtered_search() {
        let index = UsearchIndex::new(128).unwrap();

        let mut ids = Vec::new();
        for i in 0..50 {
            let id = Uuid::now_v7();
            let vec = random_vector(128, i);
            index.add(id, &vec).unwrap();
            ids.push(id);
        }

        let excluded: HashSet<Uuid> = ids.iter().step_by(2).copied().collect();
        let query = random_vector(128, 0);
        let results = index
            .filtered_search(&query, 10, &|id| !excluded.contains(&id))
            .unwrap();

        // 25 of the 50 ids pass the filter, so a full page of 10 must come back;
        // an empty result must not pass this test.
        assert_eq!(results.len(), 10);
        for (id, _) in &results {
            assert!(!excluded.contains(id));
        }
    }

    #[test]
    fn test_save_and_load() {
        let dir = ScratchDir::new();
        let index_path = dir.0.join("test.usearch");

        let index = UsearchIndex::new(128).unwrap();
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();
        index.add(id1, &random_vector(128, 1)).unwrap();
        index.add(id2, &random_vector(128, 2)).unwrap();

        index.save(&index_path).unwrap();

        let index2 = UsearchIndex::new(128).unwrap();
        index2.load(&index_path).unwrap();
        assert_eq!(index2.len(), 2);

        let results = index2.search(&random_vector(128, 1), 1).unwrap();
        assert_eq!(results[0].0, id1);

        // Keys allocated after a load must not collide with the restored ones.
        let id3 = Uuid::now_v7();
        index2.add(id3, &random_vector(128, 3)).unwrap();
        assert_eq!(index2.len(), 3);

        let results = index2.search(&random_vector(128, 3), 1).unwrap();
        assert_eq!(results[0].0, id3);
        let results = index2.search(&random_vector(128, 2), 1).unwrap();
        assert_eq!(results[0].0, id2);
    }

    #[test]
    fn test_dimension_mismatch() {
        let index = UsearchIndex::new(128).unwrap();
        let result = index.add(Uuid::now_v7(), &[0.1; 64]);
        assert!(result.is_err());
        assert_eq!(index.len(), 0);
    }
}