1use std::collections::{HashMap, HashSet};
15use std::fs;
16use std::path::Path;
17use std::sync::RwLock;
18
19use hnsw_rs::prelude::*;
20
21use crate::config::HnswConfig;
22use crate::error::{PulseDBError, Result};
23use crate::types::ExperienceId;
24
25use super::VectorIndex;
26
27const BRUTE_FORCE_THRESHOLD: usize = 128;
33
/// Newtype around a borrowed `Fn(&usize) -> bool + Sync` predicate.
///
/// It exists so that a trait-object closure can be handed to the HNSW search
/// API through the `FilterT` implementation provided for this type.
struct FilterBridge<'a>(&'a (dyn Fn(&usize) -> bool + Sync));
42
43impl FilterT for FilterBridge<'_> {
44 fn hnsw_filter(&self, id: &DataId) -> bool {
45 (self.0)(id)
46 }
47}
48
/// Vector index over `f32` embeddings backed by an `hnsw_rs` graph using
/// cosine distance, with bookkeeping that maps [`ExperienceId`]s to the
/// integer ids stored in the graph.
pub struct HnswIndex {
    /// The underlying HNSW graph (cosine distance over `f32` vectors).
    hnsw: Hnsw<'static, f32, DistCosine>,

    /// Id mappings, deleted set and id counter, guarded by a read/write lock.
    state: RwLock<IndexState>,

    /// Clone of the configuration passed to `new`; not read by the code in
    /// this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    config: HnswConfig,

    /// Expected length of every embedding and query vector.
    dimension: usize,
}
76
/// Mutable bookkeeping kept next to the HNSW graph.
#[derive(Debug)]
struct IndexState {
    /// Maps an experience id to the internal id handed to the graph.
    id_to_internal: HashMap<ExperienceId, usize>,

    /// Reverse mapping; entries are pushed in allocation order, so the
    /// position of an entry is its internal id.
    internal_to_id: Vec<ExperienceId>,

    /// Internal ids marked as deleted. Deletion only records the id here;
    /// searches consult this set to skip those entries.
    deleted: HashSet<usize>,

    /// Next internal id to allocate.
    next_id: usize,
}
93
/// Serializable snapshot of the index bookkeeping, written as JSON to
/// `<name>.hnsw.meta` by `save_to_dir` and read back by `load_metadata`.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct IndexMetadata {
    /// Embedding dimension of the index that produced the snapshot.
    pub(crate) dimension: usize,
    /// Value of the internal id counter at save time.
    pub(crate) next_id: usize,
    /// `(experience id as string, internal id)` pairs.
    pub(crate) id_map: Vec<(String, usize)>,
    /// String form of the experience ids that were marked deleted.
    pub(crate) deleted: Vec<String>,
}
108
109impl HnswIndex {
110 pub fn new(dimension: usize, config: &HnswConfig) -> Self {
117 let hnsw = Hnsw::new(
118 config.max_nb_connection,
119 config.max_elements,
120 config.max_layer,
121 config.ef_construction,
122 DistCosine,
123 );
124
125 Self {
126 hnsw,
127 state: RwLock::new(IndexState {
128 id_to_internal: HashMap::new(),
129 internal_to_id: Vec::new(),
130 deleted: HashSet::new(),
131 next_id: 0,
132 }),
133 config: config.clone(),
134 dimension,
135 }
136 }
137
138 pub fn insert_experience(&self, exp_id: ExperienceId, embedding: &[f32]) -> Result<()> {
143 if embedding.len() != self.dimension {
144 return Err(PulseDBError::vector(format!(
145 "Embedding dimension mismatch: expected {}, got {}",
146 self.dimension,
147 embedding.len()
148 )));
149 }
150
151 let mut state = self
152 .state
153 .write()
154 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
155
156 if state.id_to_internal.contains_key(&exp_id) {
158 return Ok(());
159 }
160
161 let internal_id = state.next_id;
163 state.next_id += 1;
164
165 state.id_to_internal.insert(exp_id, internal_id);
167 state.internal_to_id.push(exp_id);
168
169 drop(state);
171
172 self.hnsw.insert((embedding, internal_id));
174
175 Ok(())
176 }
177
178 pub fn delete_experience(&self, exp_id: ExperienceId) -> Result<()> {
184 let mut state = self
185 .state
186 .write()
187 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
188
189 if let Some(&internal_id) = state.id_to_internal.get(&exp_id) {
190 state.deleted.insert(internal_id);
191 }
192
193 Ok(())
194 }
195
196 pub fn search_experiences(
202 &self,
203 query: &[f32],
204 k: usize,
205 ef_search: usize,
206 ) -> Result<Vec<(ExperienceId, f32)>> {
207 if query.len() != self.dimension {
208 return Err(PulseDBError::vector(format!(
209 "Query dimension mismatch: expected {}, got {}",
210 self.dimension,
211 query.len()
212 )));
213 }
214
215 let state = self
216 .state
217 .read()
218 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
219
220 let active_count = state.next_id - state.deleted.len();
221 if active_count == 0 {
222 return Ok(vec![]);
223 }
224 let effective_k = k.min(active_count);
225
226 if active_count <= BRUTE_FORCE_THRESHOLD {
227 let dist_fn = DistCosine;
231 let mut all_distances: Vec<(ExperienceId, f32)> = Vec::with_capacity(active_count);
232
233 for point in self.hnsw.get_point_indexation().into_iter() {
234 let origin_id = point.get_origin_id();
235 if state.deleted.contains(&origin_id) {
236 continue;
237 }
238 let distance = dist_fn.eval(query, point.get_v());
239 if let Some(&exp_id) = state.internal_to_id.get(origin_id) {
240 all_distances.push((exp_id, distance));
241 }
242 }
243
244 all_distances
245 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
246 all_distances.truncate(effective_k);
247 return Ok(all_distances);
248 }
249
250 let effective_ef = ef_search.max(effective_k);
252 let deleted_ref = &state.deleted;
253 let filter_fn = |id: &usize| -> bool { !deleted_ref.contains(id) };
254 let results = if state.deleted.is_empty() {
255 self.hnsw.search(query, effective_k, effective_ef)
256 } else {
257 self.hnsw
258 .search_filter(query, effective_k, effective_ef, Some(&filter_fn))
259 };
260
261 let mapped: Vec<(ExperienceId, f32)> = results
263 .into_iter()
264 .filter_map(|n| {
265 state
266 .internal_to_id
267 .get(n.d_id)
268 .map(|&exp_id| (exp_id, n.distance))
269 })
270 .collect();
271
272 Ok(mapped)
273 }
274
275 pub fn contains(&self, exp_id: ExperienceId) -> bool {
277 let state = self.state.read().ok();
278 state.is_some_and(|s| {
279 s.id_to_internal
280 .get(&exp_id)
281 .is_some_and(|id| !s.deleted.contains(id))
282 })
283 }
284
285 pub fn active_count(&self) -> usize {
287 let state = self.state.read().ok();
288 state.map_or(0, |s| s.id_to_internal.len() - s.deleted.len())
289 }
290
291 pub fn total_count(&self) -> usize {
293 self.hnsw.get_nb_point()
294 }
295
296 pub fn restore_deleted_set(&self, deleted_exp_ids: &[String]) -> Result<()> {
303 let mut state = self
304 .state
305 .write()
306 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
307 for exp_id_str in deleted_exp_ids {
308 let uuid = uuid::Uuid::parse_str(exp_id_str)
310 .map_err(|e| PulseDBError::vector(format!("Invalid UUID in deleted set: {}", e)))?;
311 let exp_id = ExperienceId::from_bytes(*uuid.as_bytes());
312 if let Some(&internal_id) = state.id_to_internal.get(&exp_id) {
315 state.deleted.insert(internal_id);
316 }
317 }
318 Ok(())
319 }
320
321 pub fn save_to_dir(&self, dir: &Path, name: &str) -> Result<()> {
328 fs::create_dir_all(dir)
329 .map_err(|e| PulseDBError::vector(format!("Failed to create HNSW directory: {}", e)))?;
330
331 let state = self
332 .state
333 .read()
334 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
335
336 let metadata = IndexMetadata {
338 dimension: self.dimension,
339 next_id: state.next_id,
340 id_map: state
341 .id_to_internal
342 .iter()
343 .map(|(exp_id, &internal_id)| (exp_id.to_string(), internal_id))
344 .collect(),
345 deleted: state
346 .deleted
347 .iter()
348 .filter_map(|&internal_id| {
349 state
350 .internal_to_id
351 .get(internal_id)
352 .map(|exp_id| exp_id.to_string())
353 })
354 .collect(),
355 };
356
357 let meta_path = dir.join(format!("{}.hnsw.meta", name));
359 let json = serde_json::to_string_pretty(&metadata).map_err(|e| {
360 PulseDBError::vector(format!("Failed to serialize HNSW metadata: {}", e))
361 })?;
362 fs::write(&meta_path, json)
363 .map_err(|e| PulseDBError::vector(format!("Failed to write HNSW metadata: {}", e)))?;
364
365 if state.id_to_internal.is_empty() {
367 return Ok(());
368 }
369 drop(state);
370
371 if let Err(e) = self.hnsw.file_dump(dir, name) {
372 tracing::warn!(error = %e, "Failed to dump HNSW graph (non-fatal, will rebuild on next open)");
373 }
374
375 Ok(())
376 }
377
378 #[allow(dead_code)] pub(crate) fn load_metadata(dir: &Path, name: &str) -> Result<Option<IndexMetadata>> {
385 let meta_path = dir.join(format!("{}.hnsw.meta", name));
386 if !meta_path.exists() {
387 return Ok(None);
388 }
389
390 let json = fs::read_to_string(&meta_path)
391 .map_err(|e| PulseDBError::vector(format!("Failed to read HNSW metadata: {}", e)))?;
392 let metadata: IndexMetadata = serde_json::from_str(&json)
393 .map_err(|e| PulseDBError::vector(format!("Failed to parse HNSW metadata: {}", e)))?;
394
395 Ok(Some(metadata))
396 }
397
398 pub fn rebuild_from_embeddings(
403 dimension: usize,
404 config: &HnswConfig,
405 embeddings: Vec<(ExperienceId, Vec<f32>)>,
406 ) -> Result<Self> {
407 let index = Self::new(dimension, config);
408
409 if embeddings.is_empty() {
410 return Ok(index);
411 }
412
413 let mut state = index
415 .state
416 .write()
417 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
418
419 let mut batch: Vec<(&Vec<f32>, usize)> = Vec::with_capacity(embeddings.len());
420
421 for (exp_id, embedding) in &embeddings {
422 let internal_id = state.next_id;
423 state.next_id += 1;
424 state.id_to_internal.insert(*exp_id, internal_id);
425 state.internal_to_id.push(*exp_id);
426 batch.push((embedding, internal_id));
427 }
428
429 drop(state);
430
431 index.hnsw.parallel_insert(&batch);
433
434 Ok(index)
435 }
436
437 pub fn remove_files(dir: &Path, name: &str) -> Result<()> {
439 let meta_path = dir.join(format!("{}.hnsw.meta", name));
441 if meta_path.exists() {
442 fs::remove_file(&meta_path).map_err(|e| {
443 PulseDBError::vector(format!("Failed to remove HNSW metadata: {}", e))
444 })?;
445 }
446
447 if let Ok(entries) = fs::read_dir(dir) {
449 for entry in entries.flatten() {
450 let file_name = entry.file_name();
451 let file_str = file_name.to_string_lossy();
452 if file_str.starts_with(name) && file_str.contains("hnswdump") {
453 let _ = fs::remove_file(entry.path());
454 }
455 }
456 }
457
458 Ok(())
459 }
460}
461
462impl VectorIndex for HnswIndex {
467 fn insert(&self, id: usize, embedding: &[f32]) -> Result<()> {
468 if embedding.len() != self.dimension {
469 return Err(PulseDBError::vector(format!(
470 "Embedding dimension mismatch: expected {}, got {}",
471 self.dimension,
472 embedding.len()
473 )));
474 }
475 self.hnsw.insert((embedding, id));
476 Ok(())
477 }
478
479 fn insert_batch(&self, items: &[(&Vec<f32>, usize)]) -> Result<()> {
480 self.hnsw.parallel_insert(items);
481 Ok(())
482 }
483
484 fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Result<Vec<(usize, f32)>> {
485 let results = self.hnsw.search(query, k, ef_search);
486 Ok(results.into_iter().map(|n| (n.d_id, n.distance)).collect())
487 }
488
489 fn search_filtered(
490 &self,
491 query: &[f32],
492 k: usize,
493 ef_search: usize,
494 filter: &(dyn Fn(&usize) -> bool + Sync),
495 ) -> Result<Vec<(usize, f32)>> {
496 let bridge = FilterBridge(filter);
499 let results = self.hnsw.search_filter(query, k, ef_search, Some(&bridge));
500 Ok(results.into_iter().map(|n| (n.d_id, n.distance)).collect())
501 }
502
503 fn delete(&self, id: usize) -> Result<()> {
504 let mut state = self
505 .state
506 .write()
507 .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
508 state.deleted.insert(id);
509 Ok(())
510 }
511
512 fn is_deleted(&self, id: usize) -> bool {
513 self.state
514 .read()
515 .ok()
516 .is_some_and(|s| s.deleted.contains(&id))
517 }
518
519 fn len(&self) -> usize {
520 self.active_count()
521 }
522
523 fn save(&self, dir: &Path, name: &str) -> Result<()> {
524 self.save_to_dir(dir, name)
525 }
526}
527
528#[cfg(test)]
533mod tests {
534 use super::*;
535 use crate::config::HnswConfig;
536
537 fn test_config() -> HnswConfig {
538 HnswConfig {
539 max_nb_connection: 16,
540 ef_construction: 100,
541 ef_search: 50,
542 max_layer: 8,
543 max_elements: 1000,
544 }
545 }
546
547 fn make_embedding(seed: u64, dim: usize) -> Vec<f32> {
550 (0..dim)
551 .map(|i| (seed as f32 * 0.1 + i as f32 * 0.01).sin())
552 .collect()
553 }
554
555 #[test]
556 fn test_new_index_is_empty() {
557 let index = HnswIndex::new(384, &test_config());
558 assert_eq!(index.active_count(), 0);
559 assert_eq!(index.total_count(), 0);
560 assert!(index.is_empty());
561 }
562
563 #[test]
564 fn test_insert_and_search() {
565 let dim = 8;
566 let config = test_config();
567 let index = HnswIndex::new(dim, &config);
568
569 for i in 0..10u64 {
571 let exp_id = ExperienceId::new();
572 let embedding = make_embedding(i, dim);
573 index.insert_experience(exp_id, &embedding).unwrap();
574 }
575
576 assert_eq!(index.active_count(), 10);
577
578 let query = make_embedding(5, dim);
580 let results = index.search_experiences(&query, 3, 50).unwrap();
581
582 assert!(!results.is_empty());
583 assert!(results.len() <= 3);
584 for w in results.windows(2) {
586 assert!(w[0].1 <= w[1].1, "Results not sorted by distance");
587 }
588 }
589
590 #[test]
591 fn test_insert_idempotent() {
592 let dim = 4;
593 let index = HnswIndex::new(dim, &test_config());
594
595 let exp_id = ExperienceId::new();
596 let embedding = make_embedding(1, dim);
597
598 index.insert_experience(exp_id, &embedding).unwrap();
599 index.insert_experience(exp_id, &embedding).unwrap(); assert_eq!(index.active_count(), 1);
602 }
603
604 #[test]
605 fn test_dimension_mismatch_rejected() {
606 let index = HnswIndex::new(384, &test_config());
607
608 let exp_id = ExperienceId::new();
609 let wrong_dim = vec![1.0f32; 128]; let result = index.insert_experience(exp_id, &wrong_dim);
612 assert!(result.is_err());
613 assert!(result.unwrap_err().is_vector());
614 }
615
616 #[test]
617 fn test_delete_excludes_from_search() {
618 let dim = 8;
619 let index = HnswIndex::new(dim, &test_config());
620
621 let mut ids = Vec::new();
623 for i in 0..5u64 {
624 let exp_id = ExperienceId::new();
625 index
626 .insert_experience(exp_id, &make_embedding(i, dim))
627 .unwrap();
628 ids.push(exp_id);
629 }
630
631 assert_eq!(index.active_count(), 5);
632
633 index.delete_experience(ids[0]).unwrap();
635 assert_eq!(index.active_count(), 4);
636 assert!(!index.contains(ids[0]));
637 assert!(index.contains(ids[1]));
638
639 let query = make_embedding(0, dim); let results = index.search_experiences(&query, 10, 50).unwrap();
642 let result_ids: Vec<ExperienceId> = results.iter().map(|r| r.0).collect();
643 assert!(!result_ids.contains(&ids[0]));
644 }
645
646 #[test]
647 fn test_search_k_larger_than_index() {
648 let dim = 4;
649 let index = HnswIndex::new(dim, &test_config());
650
651 let exp_id = ExperienceId::new();
652 index
653 .insert_experience(exp_id, &make_embedding(1, dim))
654 .unwrap();
655
656 let results = index
658 .search_experiences(&make_embedding(1, dim), 100, 50)
659 .unwrap();
660 assert_eq!(results.len(), 1);
661 }
662
663 #[test]
664 fn test_search_empty_index() {
665 let dim = 4;
666 let index = HnswIndex::new(dim, &test_config());
667
668 let results = index
669 .search_experiences(&make_embedding(1, dim), 10, 50)
670 .unwrap();
671 assert!(results.is_empty());
672 }
673
674 #[test]
675 fn test_rebuild_from_embeddings() {
676 let dim = 8;
677 let config = test_config();
678
679 let embeddings: Vec<(ExperienceId, Vec<f32>)> = (0..20u64)
681 .map(|i| (ExperienceId::new(), make_embedding(i, dim)))
682 .collect();
683
684 let index = HnswIndex::rebuild_from_embeddings(dim, &config, embeddings.clone()).unwrap();
685
686 assert_eq!(index.active_count(), 20);
687
688 let query = make_embedding(10, dim);
690 let results = index.search_experiences(&query, 5, 50).unwrap();
691 assert!(!results.is_empty());
692 }
693
694 #[test]
695 fn test_rebuild_empty() {
696 let dim = 384;
697 let config = test_config();
698 let index = HnswIndex::rebuild_from_embeddings(dim, &config, vec![]).unwrap();
699 assert!(index.is_empty());
700 }
701
702 #[test]
703 fn test_save_and_load_metadata_roundtrip() {
704 let dim = 4;
705 let index = HnswIndex::new(dim, &test_config());
706
707 let mut exp_ids = Vec::new();
708 for i in 0..5u64 {
709 let exp_id = ExperienceId::new();
710 index
711 .insert_experience(exp_id, &make_embedding(i, dim))
712 .unwrap();
713 exp_ids.push(exp_id);
714 }
715 index.delete_experience(exp_ids[2]).unwrap();
716
717 let dir = tempfile::tempdir().unwrap();
719 index.save_to_dir(dir.path(), "test_collective").unwrap();
720
721 let metadata = HnswIndex::load_metadata(dir.path(), "test_collective")
723 .unwrap()
724 .expect("Metadata should exist");
725
726 assert_eq!(metadata.dimension, dim);
727 assert_eq!(metadata.next_id, 5);
728 assert_eq!(metadata.id_map.len(), 5);
729 assert_eq!(metadata.deleted.len(), 1);
730 assert_eq!(metadata.deleted[0], exp_ids[2].to_string());
732 }
733
734 #[test]
735 fn test_remove_files() {
736 let dim = 4;
737 let index = HnswIndex::new(dim, &test_config());
738 index
739 .insert_experience(ExperienceId::new(), &make_embedding(1, dim))
740 .unwrap();
741
742 let dir = tempfile::tempdir().unwrap();
743 index.save_to_dir(dir.path(), "test_coll").unwrap();
744
745 let meta_path = dir.path().join("test_coll.hnsw.meta");
747 assert!(meta_path.exists());
748
749 HnswIndex::remove_files(dir.path(), "test_coll").unwrap();
751 assert!(!meta_path.exists());
752 }
753
754 #[test]
755 fn test_brute_force_search_returns_all_items() {
756 let dim = 8;
757 let config = test_config();
758 let index = HnswIndex::new(dim, &config);
759
760 let mut ids = Vec::new();
762 for i in 0..20u64 {
763 let exp_id = ExperienceId::new();
764 index
765 .insert_experience(exp_id, &make_embedding(i, dim))
766 .unwrap();
767 ids.push(exp_id);
768 }
769
770 let query = make_embedding(10, dim);
772 let results = index.search_experiences(&query, 20, 50).unwrap();
773 assert_eq!(results.len(), 20, "Brute-force must return all 20 items");
774
775 for w in results.windows(2) {
777 assert!(
778 w[0].1 <= w[1].1,
779 "Brute-force results not sorted: {} > {}",
780 w[0].1,
781 w[1].1
782 );
783 }
784
785 assert_eq!(results[0].0, ids[10]);
787 assert!(
788 results[0].1 < 0.001,
789 "Expected near-zero distance for exact match, got {}",
790 results[0].1
791 );
792 }
793
794 #[test]
795 fn test_brute_force_excludes_deleted() {
796 let dim = 8;
797 let index = HnswIndex::new(dim, &test_config());
798
799 let mut ids = Vec::new();
800 for i in 0..5u64 {
801 let exp_id = ExperienceId::new();
802 index
803 .insert_experience(exp_id, &make_embedding(i, dim))
804 .unwrap();
805 ids.push(exp_id);
806 }
807
808 index.delete_experience(ids[2]).unwrap();
810
811 let query = make_embedding(2, dim);
812 let results = index.search_experiences(&query, 10, 50).unwrap();
813 assert_eq!(results.len(), 4, "Should return 4 after deleting 1 of 5");
814 let result_ids: Vec<ExperienceId> = results.iter().map(|r| r.0).collect();
815 assert!(
816 !result_ids.contains(&ids[2]),
817 "Deleted item must be excluded"
818 );
819 }
820
821 #[test]
822 fn test_cosine_distance_identical_vectors() {
823 let dim = 8;
824 let index = HnswIndex::new(dim, &test_config());
825
826 let embedding = make_embedding(42, dim);
827 let exp_id = ExperienceId::new();
828 index.insert_experience(exp_id, &embedding).unwrap();
829
830 let results = index.search_experiences(&embedding, 1, 50).unwrap();
832 assert_eq!(results.len(), 1);
833 assert_eq!(results[0].0, exp_id);
834 assert!(
836 results[0].1 < 0.001,
837 "Expected near-zero distance for identical vectors, got {}",
838 results[0].1
839 );
840 }
841}