1use instant_distance::{Builder, HnswMap, Point, Search};
13use std::sync::Mutex;
14
/// Number of pending (inserted but not yet graph-indexed) entries at which
/// `VectorIndex::insert` rebuilds the HNSW graph from scratch.
const REBUILD_THRESHOLD: usize = 200;

/// Maximum number of entries kept by `VectorIndex::insert`; once an insert
/// pushes the count above this, the oldest entries (front of the list) are dropped.
const MAX_ENTRIES: usize = 100_000;
20
21#[derive(Clone, Debug)]
23pub struct EmbeddingPoint(pub Vec<f32>);
24
25impl instant_distance::Point for EmbeddingPoint {
26 fn distance(&self, other: &Self) -> f32 {
27 let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
30 1.0 - dot
31 }
32}
33
34pub struct VectorIndex {
36 inner: Mutex<IndexState>,
38}
39
40struct IndexState {
41 hnsw: Option<HnswMap<EmbeddingPoint, String>>,
42 overflow: Vec<(String, Vec<f32>)>,
44 all_entries: Vec<(String, Vec<f32>)>,
46}
47
/// One search result: the id of a stored entry and its distance to the query.
#[derive(Clone, Debug)]
pub struct VectorHit {
    /// Id the entry was stored under.
    pub id: String,
    /// Distance from the query as computed by `EmbeddingPoint::distance`
    /// (smaller means closer).
    pub distance: f32,
}
54
55impl VectorIndex {
56 pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
58 let hnsw = Self::build_hnsw(&entries);
59 VectorIndex {
60 inner: Mutex::new(IndexState {
61 hnsw,
62 overflow: Vec::new(),
63 all_entries: entries,
64 }),
65 }
66 }
67
68 pub fn empty() -> Self {
70 VectorIndex {
71 inner: Mutex::new(IndexState {
72 hnsw: None,
73 overflow: Vec::new(),
74 all_entries: Vec::new(),
75 }),
76 }
77 }
78
79 fn build_hnsw(entries: &[(String, Vec<f32>)]) -> Option<HnswMap<EmbeddingPoint, String>> {
80 if entries.is_empty() {
81 return None;
82 }
83 let points: Vec<EmbeddingPoint> = entries
84 .iter()
85 .map(|(_, emb)| EmbeddingPoint(emb.clone()))
86 .collect();
87 let values: Vec<String> = entries.iter().map(|(id, _)| id.clone()).collect();
88 Some(Builder::default().build(points, values))
89 }
90
91 pub fn insert(&self, id: String, embedding: Vec<f32>) {
93 let mut state = match self.inner.lock() {
94 Ok(s) => s,
95 Err(poisoned) => poisoned.into_inner(),
96 };
97 state.all_entries.push((id.clone(), embedding.clone()));
98 state.overflow.push((id, embedding));
99
100 if state.overflow.len() >= REBUILD_THRESHOLD {
102 state.hnsw = Self::build_hnsw(&state.all_entries);
103 state.overflow.clear();
104 }
105
106 if state.all_entries.len() > MAX_ENTRIES {
108 let excess = state.all_entries.len() - MAX_ENTRIES;
109 state.all_entries.drain(..excess);
110 state.hnsw = Self::build_hnsw(&state.all_entries);
111 state.overflow.clear();
112 }
113 }
114
115 pub fn remove(&self, id: &str) {
117 let mut state = match self.inner.lock() {
118 Ok(s) => s,
119 Err(poisoned) => poisoned.into_inner(),
120 };
121 state.all_entries.retain(|(eid, _)| eid != id);
122 state.overflow.retain(|(eid, _)| eid != id);
123 }
126
127 pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorHit> {
132 let state = match self.inner.lock() {
133 Ok(s) => s,
134 Err(poisoned) => poisoned.into_inner(),
135 };
136 let query_point = EmbeddingPoint(query.to_vec());
137
138 let mut results: Vec<VectorHit> = Vec::with_capacity(k * 2);
139
140 let valid_ids: std::collections::HashSet<&str> = state
142 .all_entries
143 .iter()
144 .map(|(id, _)| id.as_str())
145 .collect();
146
147 if let Some(ref hnsw) = state.hnsw {
149 let mut search = Search::default();
150 for item in hnsw.search(&query_point, &mut search) {
151 if !valid_ids.contains(item.value.as_str()) {
152 continue; }
154 results.push(VectorHit {
155 id: item.value.clone(),
156 distance: item.distance,
157 });
158 if results.len() >= k * 2 {
159 break;
160 }
161 }
162 }
163
164 let mut overflow_hits: Vec<VectorHit> = state
166 .overflow
167 .iter()
168 .map(|(id, emb)| {
169 let point = EmbeddingPoint(emb.clone());
170 VectorHit {
171 id: id.clone(),
172 distance: query_point.distance(&point),
173 }
174 })
175 .collect();
176 overflow_hits.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
177
178 results.extend(overflow_hits);
179
180 let mut seen = std::collections::HashSet::new();
182 results.retain(|hit| seen.insert(hit.id.clone()));
183
184 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
186 results.truncate(k);
187 results
188 }
189
190 pub fn len(&self) -> usize {
192 let state = match self.inner.lock() {
193 Ok(s) => s,
194 Err(poisoned) => poisoned.into_inner(),
195 };
196 state.all_entries.len()
197 }
198
199 #[allow(dead_code)]
201 pub fn rebuild(&self) {
202 let mut state = match self.inner.lock() {
203 Ok(s) => s,
204 Err(poisoned) => poisoned.into_inner(),
205 };
206 state.hnsw = Self::build_hnsw(&state.all_entries);
207 state.overflow.clear();
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Scales `values` to unit L2 norm.
    fn make_embedding(values: &[f32]) -> Vec<f32> {
        let squared_sum: f32 = values.iter().map(|v| v * v).sum();
        let norm = squared_sum.sqrt();
        values.iter().map(|v| v / norm).collect()
    }

    /// Builds an `(id, normalized embedding)` pair for index fixtures.
    fn entry(id: &str, values: &[f32]) -> (String, Vec<f32>) {
        (id.to_owned(), make_embedding(values))
    }

    /// Normalized `dim`-dimensional vector whose only non-zero component sits
    /// on axis `i % dim` with raw value `1.0 + i * step`.
    fn axis_embedding(dim: usize, i: usize, step: f32) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        #[allow(clippy::cast_precision_loss)]
        let f = i as f32;
        v[i % dim] = 1.0 + f * step;
        make_embedding(&v)
    }

    #[test]
    fn empty_index_returns_empty() {
        let hits = VectorIndex::empty().search(&[1.0, 0.0, 0.0], 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn basic_search() {
        let idx = VectorIndex::build(vec![
            entry("a", &[1.0, 0.0, 0.0]),
            entry("b", &[0.0, 1.0, 0.0]),
            entry("c", &[0.0, 0.0, 1.0]),
        ]);
        let hits = idx.search(&make_embedding(&[1.0, 0.1, 0.0]), 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn insert_and_search_overflow() {
        let idx = VectorIndex::build(vec![entry("a", &[1.0, 0.0, 0.0])]);
        idx.insert(String::from("b"), make_embedding(&[0.9, 0.1, 0.0]));
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
        assert_eq!(hits[1].id, "b");
    }

    #[test]
    fn remove_excludes_from_results() {
        let idx = VectorIndex::build(vec![
            entry("a", &[1.0, 0.0, 0.0]),
            entry("b", &[0.9, 0.1, 0.0]),
        ]);
        idx.remove("a");
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 5);
        assert!(!hits.iter().any(|hit| hit.id == "a"));
    }

    #[test]
    fn test_rebuild_preserves_all_entries() {
        let raw: Vec<(String, Vec<f32>)> = (0..12_usize)
            .map(|i| (format!("id-{i}"), axis_embedding(16, i, 0.01)))
            .collect();

        let idx = VectorIndex::build(raw.clone());
        idx.rebuild();
        assert_eq!(idx.len(), raw.len());

        // Ask for more hits than there are entries so every id can come back.
        let hits = idx.search(&make_embedding(&[1.0; 16]), raw.len() * 2);
        let found: HashSet<String> = hits.into_iter().map(|h| h.id).collect();
        for (id, _) in &raw {
            assert!(
                found.contains(id),
                "rebuild must preserve id {id}, found: {:?}",
                found
            );
        }
    }

    #[test]
    fn test_remove_then_search_excludes_id() {
        let idx = VectorIndex::build(vec![
            entry("alpha", &[1.0, 0.0, 0.0, 0.0]),
            entry("beta", &[0.9, 0.1, 0.0, 0.0]),
            entry("gamma", &[0.8, 0.2, 0.0, 0.0]),
        ]);
        let query = make_embedding(&[1.0, 0.0, 0.0, 0.0]);

        let before = idx.search(&query, 5);
        assert!(before.iter().any(|h| h.id == "alpha"));

        idx.remove("alpha");
        // The removed id must stay hidden regardless of how many hits are requested.
        for k in 1..=10 {
            let hits = idx.search(&query, k);
            assert!(
                !hits.iter().any(|h| h.id == "alpha"),
                "removed id `alpha` resurfaced with k={k}: {:?}",
                hits.iter().map(|h| &h.id).collect::<Vec<_>>()
            );
        }

        let after = idx.search(&query, 5);
        assert!(after.iter().any(|h| h.id == "beta"));
        assert!(after.iter().any(|h| h.id == "gamma"));
    }

    #[test]
    fn empty_index_len_is_zero() {
        assert_eq!(VectorIndex::empty().len(), 0);
    }

    #[test]
    fn build_with_empty_entries_search_empty() {
        let idx = VectorIndex::build(Vec::new());
        assert_eq!(idx.len(), 0);
        assert!(idx.search(&[1.0, 0.0, 0.0], 5).is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let idx = VectorIndex::build(vec![entry("a", &[1.0, 0.0, 0.0])]);
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 0);
        assert!(hits.is_empty());
    }

    #[test]
    fn rebuild_on_empty_does_not_crash() {
        let idx = VectorIndex::empty();
        idx.rebuild();
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_increases_len() {
        let idx = VectorIndex::empty();
        for (id, values) in [("a", [1.0, 0.0, 0.0]), ("b", [0.0, 1.0, 0.0])] {
            idx.insert(id.to_owned(), make_embedding(&values));
        }
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn embedding_point_distance_orthogonal() {
        let x_axis = EmbeddingPoint(vec![1.0, 0.0, 0.0]);
        let y_axis = EmbeddingPoint(vec![0.0, 1.0, 0.0]);
        let gap = x_axis.distance(&y_axis) - 1.0;
        assert!(gap.abs() < 1e-6);
    }

    #[test]
    fn embedding_point_distance_identical_is_zero() {
        let point = EmbeddingPoint(make_embedding(&[1.0, 1.0, 1.0]));
        assert!(point.distance(&point).abs() < 1e-6);
    }

    #[test]
    fn remove_on_empty_index_is_noop() {
        let idx = VectorIndex::empty();
        idx.remove("nonexistent");
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_triggers_auto_rebuild_at_threshold() {
        let idx = VectorIndex::empty();
        // 205 inserts cross the rebuild threshold and leave a few pending entries.
        for i in 0..205_usize {
            idx.insert(format!("id-{i}"), axis_embedding(8, i, 0.001));
        }
        assert_eq!(idx.len(), 205);

        let hits = idx.search(&make_embedding(&[1.0_f32; 8]), 5);
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn test_rebuild_after_batch_insert_settles() {
        let idx = VectorIndex::empty();
        let n = 25_usize;
        for i in 0..n {
            idx.insert(format!("id-{i}"), axis_embedding(8, i, 0.001));
        }
        idx.rebuild();
        assert_eq!(idx.len(), n);

        let k = 5;
        let hits = idx.search(&make_embedding(&[1.0; 8]), k);
        assert_eq!(
            hits.len(),
            k,
            "post-rebuild search top-{k} must return exactly {k} hits, got {:?}",
            hits.iter().map(|h| &h.id).collect::<Vec<_>>()
        );

        // Results must be ordered nearest-first.
        for pair in hits.windows(2) {
            let (near, far) = (pair[0].distance, pair[1].distance);
            assert!(
                near <= far,
                "search results must be ascending by distance: {} > {}",
                near,
                far
            );
        }

        // Each id may appear at most once.
        let mut seen = HashSet::new();
        for hit in &hits {
            assert!(
                seen.insert(hit.id.clone()),
                "duplicate id in search: {}",
                hit.id
            );
        }
    }
}