1use instant_distance::{Builder, HnswMap, Point, Search};
13use std::sync::Mutex;
14use std::sync::atomic::{AtomicU64, Ordering};
15
/// Once this many entries are waiting in the brute-force overflow list, `insert` rebuilds the
/// HNSW graph from all live entries.
const REBUILD_THRESHOLD: usize = 200;

/// Hard cap on live entries; when `insert` pushes the index past it, the oldest entries are
/// evicted first.
const MAX_ENTRIES: usize = 100_000;

/// Running total of entries evicted because `MAX_ENTRIES` was exceeded.
static INDEX_EVICTIONS_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Wall-clock time of the most recent eviction, in nanoseconds since the Unix epoch.
/// `0` is the "no eviction recorded" sentinel checked by `evicted_recently`.
static LAST_EVICTION_AT_NANOS: AtomicU64 = AtomicU64::new(0);
43
44#[must_use]
51pub fn index_evictions_total() -> u64 {
52 INDEX_EVICTIONS_TOTAL.load(Ordering::Relaxed)
53}
54
55#[must_use]
61pub fn evicted_recently(window_secs: u64) -> bool {
62 let last = LAST_EVICTION_AT_NANOS.load(Ordering::Relaxed);
63 if last == 0 {
64 return false;
65 }
66 let now_nanos = std::time::SystemTime::now()
67 .duration_since(std::time::UNIX_EPOCH)
68 .map(|d| d.as_nanos())
69 .unwrap_or(0);
70 let elapsed_nanos = u128::from(u64::MAX).min(now_nanos.saturating_sub(u128::from(last)));
72 elapsed_nanos < u128::from(window_secs).saturating_mul(1_000_000_000)
73}
74
75#[doc(hidden)]
81pub fn reset_eviction_counters_for_test() {
82 INDEX_EVICTIONS_TOTAL.store(0, Ordering::Relaxed);
83 LAST_EVICTION_AT_NANOS.store(0, Ordering::Relaxed);
84}
85
86#[derive(Clone, Debug)]
88pub struct EmbeddingPoint(pub Vec<f32>);
89
90impl instant_distance::Point for EmbeddingPoint {
91 fn distance(&self, other: &Self) -> f32 {
92 let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
95 1.0 - dot
96 }
97}
98
/// Nearest-neighbour index over string-keyed embeddings.
///
/// Combines an HNSW graph, rebuilt periodically, with a brute-force "overflow" list of entries
/// inserted since the last rebuild. All state sits behind a single `Mutex`, so every method
/// takes `&self`.
pub struct VectorIndex {
    /// All index state; each method acquires this lock for the duration of the call.
    inner: Mutex<IndexState>,
}
104
/// Mutable state of a [`VectorIndex`], guarded by its mutex.
struct IndexState {
    /// HNSW graph over a snapshot of `all_entries` taken at the last rebuild (`None` if that
    /// snapshot was empty). It is only ever replaced wholesale, so ids removed since then keep
    /// their graph points, and `search` filters them out against `all_entries`.
    hnsw: Option<HnswMap<EmbeddingPoint, String>>,
    /// Entries inserted since the last rebuild (minus removed ids); searched by brute force.
    overflow: Vec<(String, Vec<f32>)>,
    /// Every live `(id, embedding)` pair in insertion order; eviction drains from the front.
    all_entries: Vec<(String, Vec<f32>)>,
}
112
/// One nearest-neighbour result produced by `VectorIndex::search`.
#[derive(Clone, Debug)]
pub struct VectorHit {
    /// Identifier of the matched entry.
    pub id: String,
    /// `1 - dot(query, embedding)`; smaller means closer.
    pub distance: f32,
}
119
120impl VectorIndex {
121 pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
123 let hnsw = Self::build_hnsw(&entries);
124 VectorIndex {
125 inner: Mutex::new(IndexState {
126 hnsw,
127 overflow: Vec::new(),
128 all_entries: entries,
129 }),
130 }
131 }
132
133 pub fn empty() -> Self {
135 VectorIndex {
136 inner: Mutex::new(IndexState {
137 hnsw: None,
138 overflow: Vec::new(),
139 all_entries: Vec::new(),
140 }),
141 }
142 }
143
144 fn build_hnsw(entries: &[(String, Vec<f32>)]) -> Option<HnswMap<EmbeddingPoint, String>> {
145 if entries.is_empty() {
146 return None;
147 }
148 let points: Vec<EmbeddingPoint> = entries
149 .iter()
150 .map(|(_, emb)| EmbeddingPoint(emb.clone()))
151 .collect();
152 let values: Vec<String> = entries.iter().map(|(id, _)| id.clone()).collect();
153 Some(Builder::default().build(points, values))
154 }
155
156 pub fn insert(&self, id: String, embedding: Vec<f32>) {
158 let mut state = match self.inner.lock() {
159 Ok(s) => s,
160 Err(poisoned) => poisoned.into_inner(),
161 };
162 state.all_entries.push((id.clone(), embedding.clone()));
163 state.overflow.push((id, embedding));
164
165 if state.overflow.len() >= REBUILD_THRESHOLD {
167 state.hnsw = Self::build_hnsw(&state.all_entries);
168 state.overflow.clear();
169 }
170
171 if state.all_entries.len() > MAX_ENTRIES {
173 let excess = state.all_entries.len() - MAX_ENTRIES;
174 for (evicted_id, _) in state.all_entries.iter().take(excess) {
181 tracing::warn!(
182 target: "hnsw.eviction",
183 evicted_id = %evicted_id,
184 reason = "max_entries_reached",
185 max_entries = MAX_ENTRIES,
186 "hnsw index evicting oldest entry: cap reached"
187 );
188 }
189 #[allow(clippy::cast_possible_truncation)]
190 let evicted = excess as u64;
191 INDEX_EVICTIONS_TOTAL.fetch_add(evicted, Ordering::Relaxed);
192
193 state.all_entries.drain(..excess);
194 state.hnsw = Self::build_hnsw(&state.all_entries);
195 state.overflow.clear();
196
197 let now_nanos = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .map(|d| d.as_nanos())
206 .unwrap_or(0);
207 let now_nanos_u64 = u64::try_from(now_nanos).unwrap_or(u64::MAX);
208 LAST_EVICTION_AT_NANOS.store(now_nanos_u64, Ordering::Relaxed);
209 }
210 }
211
212 pub fn remove(&self, id: &str) {
214 let mut state = match self.inner.lock() {
215 Ok(s) => s,
216 Err(poisoned) => poisoned.into_inner(),
217 };
218 state.all_entries.retain(|(eid, _)| eid != id);
219 state.overflow.retain(|(eid, _)| eid != id);
220 }
223
224 pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorHit> {
229 let state = match self.inner.lock() {
230 Ok(s) => s,
231 Err(poisoned) => poisoned.into_inner(),
232 };
233 let query_point = EmbeddingPoint(query.to_vec());
234
235 let mut results: Vec<VectorHit> = Vec::with_capacity(k * 2);
236
237 let valid_ids: std::collections::HashSet<&str> = state
239 .all_entries
240 .iter()
241 .map(|(id, _)| id.as_str())
242 .collect();
243
244 if let Some(ref hnsw) = state.hnsw {
246 let mut search = Search::default();
247 for item in hnsw.search(&query_point, &mut search) {
248 if !valid_ids.contains(item.value.as_str()) {
249 continue; }
251 results.push(VectorHit {
252 id: item.value.clone(),
253 distance: item.distance,
254 });
255 if results.len() >= k * 2 {
256 break;
257 }
258 }
259 }
260
261 let mut overflow_hits: Vec<VectorHit> = state
263 .overflow
264 .iter()
265 .map(|(id, emb)| {
266 let point = EmbeddingPoint(emb.clone());
267 VectorHit {
268 id: id.clone(),
269 distance: query_point.distance(&point),
270 }
271 })
272 .collect();
273 overflow_hits.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
274
275 results.extend(overflow_hits);
276
277 let mut seen = std::collections::HashSet::new();
279 results.retain(|hit| seen.insert(hit.id.clone()));
280
281 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
283 results.truncate(k);
284 results
285 }
286
287 pub fn len(&self) -> usize {
289 let state = match self.inner.lock() {
290 Ok(s) => s,
291 Err(poisoned) => poisoned.into_inner(),
292 };
293 state.all_entries.len()
294 }
295
296 #[allow(dead_code)]
298 pub fn rebuild(&self) {
299 let mut state = match self.inner.lock() {
300 Ok(s) => s,
301 Err(poisoned) => poisoned.into_inner(),
302 };
303 state.hnsw = Self::build_hnsw(&state.all_entries);
304 state.overflow.clear();
305 }
306}
307
// Unit tests for `VectorIndex` and `EmbeddingPoint`. Test embeddings are unit-normalized
// (via `make_embedding` or by construction), so `1 - dot` behaves as a cosine distance.
#[cfg(test)]
mod tests {
    use super::*;

    /// Scales `values` to unit L2 norm (an all-zero input yields NaN components).
    fn make_embedding(values: &[f32]) -> Vec<f32> {
        let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
        values.iter().map(|v| v / norm).collect()
    }

    #[test]
    fn empty_index_returns_empty() {
        let idx = VectorIndex::empty();
        let results = idx.search(&[1.0, 0.0, 0.0], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn basic_search() {
        let entries = vec![
            ("a".into(), make_embedding(&[1.0, 0.0, 0.0])),
            ("b".into(), make_embedding(&[0.0, 1.0, 0.0])),
            ("c".into(), make_embedding(&[0.0, 0.0, 1.0])),
        ];
        let idx = VectorIndex::build(entries);
        let results = idx.search(&make_embedding(&[1.0, 0.1, 0.0]), 2);
        assert_eq!(results.len(), 2);
        // The query points almost along the x-axis, so "a" must rank first.
        assert_eq!(results[0].id, "a");
    }

    // "b" is inserted after the build, so it is only reachable through the overflow list;
    // the merged results must still be ordered by distance.
    #[test]
    fn insert_and_search_overflow() {
        let entries = vec![("a".into(), make_embedding(&[1.0, 0.0, 0.0]))];
        let idx = VectorIndex::build(entries);
        idx.insert("b".into(), make_embedding(&[0.9, 0.1, 0.0]));
        let results = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert_eq!(results[1].id, "b");
    }

    // "a" stays in the HNSW graph after `remove`; search must filter it out.
    #[test]
    fn remove_excludes_from_results() {
        let entries = vec![
            ("a".into(), make_embedding(&[1.0, 0.0, 0.0])),
            ("b".into(), make_embedding(&[0.9, 0.1, 0.0])),
        ];
        let idx = VectorIndex::build(entries);
        idx.remove("a");
        let results = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 5);
        assert!(results.iter().all(|h| h.id != "a"));
    }

    #[test]
    fn test_rebuild_preserves_all_entries() {
        // Twelve entries, each a distinct one-hot axis after normalization, so all are
        // equidistant from the all-ones query below.
        let raw: Vec<(String, Vec<f32>)> = (0..12)
            .map(|i| {
                let mut v = vec![0.0_f32; 16];
                #[allow(clippy::cast_precision_loss)]
                let f = i as f32;
                v[i % 16] = 1.0 + f * 0.01;
                (format!("id-{i}"), make_embedding(&v))
            })
            .collect();

        let idx = VectorIndex::build(raw.clone());
        idx.rebuild();
        assert_eq!(idx.len(), raw.len());

        let query = make_embedding(&[1.0; 16]);
        // Ask for more hits than exist so every id must be returned.
        let hits = idx.search(&query, raw.len() * 2);
        let found: std::collections::HashSet<String> = hits.into_iter().map(|h| h.id).collect();
        for (id, _) in &raw {
            assert!(
                found.contains(id),
                "rebuild must preserve id {id}, found: {:?}",
                found
            );
        }
    }

    #[test]
    fn test_remove_then_search_excludes_id() {
        let entries = vec![
            ("alpha".into(), make_embedding(&[1.0, 0.0, 0.0, 0.0])),
            ("beta".into(), make_embedding(&[0.9, 0.1, 0.0, 0.0])),
            ("gamma".into(), make_embedding(&[0.8, 0.2, 0.0, 0.0])),
        ];
        let idx = VectorIndex::build(entries);
        // Sanity check: "alpha" is reachable before removal.
        let pre = idx.search(&make_embedding(&[1.0, 0.0, 0.0, 0.0]), 5);
        assert!(pre.iter().any(|h| h.id == "alpha"));

        idx.remove("alpha");
        // Sweep k so the removed id cannot resurface through the HNSW over-fetch budget.
        for k in 1..=10 {
            let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0, 0.0]), k);
            assert!(
                hits.iter().all(|h| h.id != "alpha"),
                "removed id `alpha` resurfaced with k={k}: {:?}",
                hits.iter().map(|h| &h.id).collect::<Vec<_>>()
            );
        }

        // The surviving entries must still be found.
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0, 0.0]), 5);
        let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
        assert!(ids.contains(&"beta"));
        assert!(ids.contains(&"gamma"));
    }

    #[test]
    fn empty_index_len_is_zero() {
        let idx = VectorIndex::empty();
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn build_with_empty_entries_search_empty() {
        let idx = VectorIndex::build(Vec::new());
        assert_eq!(idx.len(), 0);
        let results = idx.search(&[1.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let entries = vec![("a".into(), make_embedding(&[1.0, 0.0, 0.0]))];
        let idx = VectorIndex::build(entries);
        let results = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 0);
        assert!(results.is_empty());
    }

    #[test]
    fn rebuild_on_empty_does_not_crash() {
        let idx = VectorIndex::empty();
        idx.rebuild();
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_increases_len() {
        let idx = VectorIndex::empty();
        idx.insert("a".into(), make_embedding(&[1.0, 0.0, 0.0]));
        idx.insert("b".into(), make_embedding(&[0.0, 1.0, 0.0]));
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn embedding_point_distance_orthogonal() {
        let a = EmbeddingPoint(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingPoint(vec![0.0, 1.0, 0.0]);
        // Orthogonal unit vectors: dot = 0, so distance = 1.
        assert!((a.distance(&b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn embedding_point_distance_identical_is_zero() {
        let a = EmbeddingPoint(make_embedding(&[1.0, 1.0, 1.0]));
        // A unit vector against itself: dot = 1, so distance ~ 0 (up to rounding).
        assert!(a.distance(&a).abs() < 1e-6);
    }

    #[test]
    fn remove_on_empty_index_is_noop() {
        let idx = VectorIndex::empty();
        idx.remove("nonexistent");
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_triggers_auto_rebuild_at_threshold() {
        let idx = VectorIndex::empty();
        // 205 inserts crosses REBUILD_THRESHOLD (200): the graph is rebuilt once and the last
        // 5 entries remain in the overflow list, so search must merge both sources.
        for i in 0..205_usize {
            let mut v = vec![0.0_f32; 8];
            #[allow(clippy::cast_precision_loss)]
            let f = i as f32;
            v[i % 8] = 1.0 + f * 0.001;
            idx.insert(format!("id-{i}"), make_embedding(&v));
        }
        assert_eq!(idx.len(), 205);
        let q = make_embedding(&[1.0_f32; 8]);
        let hits = idx.search(&q, 5);
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn test_rebuild_after_batch_insert_settles() {
        let idx = VectorIndex::empty();
        // 25 inserts stay below REBUILD_THRESHOLD, so they live only in the overflow list
        // until the explicit rebuild moves them into the graph.
        let n = 25_usize;
        for i in 0..n {
            let mut v = vec![0.0_f32; 8];
            #[allow(clippy::cast_precision_loss)]
            let f = i as f32;
            v[i % 8] = 1.0 + f * 0.001;
            idx.insert(format!("id-{i}"), make_embedding(&v));
        }
        idx.rebuild();
        assert_eq!(idx.len(), n);

        let query = make_embedding(&[1.0; 8]);
        let k = 5;
        let hits = idx.search(&query, k);
        assert_eq!(
            hits.len(),
            k,
            "post-rebuild search top-{k} must return exactly {k} hits, got {:?}",
            hits.iter().map(|h| &h.id).collect::<Vec<_>>()
        );

        // Results must be sorted nearest first.
        for w in hits.windows(2) {
            assert!(
                w[0].distance <= w[1].distance,
                "search results must be ascending by distance: {} > {}",
                w[0].distance,
                w[1].distance
            );
        }

        // Each id may appear at most once.
        let mut seen = std::collections::HashSet::new();
        for h in &hits {
            assert!(
                seen.insert(h.id.clone()),
                "duplicate id in search: {}",
                h.id
            );
        }
    }
}