Skip to main content

pulsedb/vector/
hnsw.rs

1//! HNSW vector index implementation using hnsw_rs.
2//!
3//! Wraps `hnsw_rs::Hnsw<f32, DistCosine>` with:
4//! - Bidirectional `ExperienceId` ↔ `usize` ID mapping
5//! - Soft-delete via `HashSet` + filtered search
6//! - JSON metadata persistence (`.hnsw.meta`)
7//!
8//! # Thread Safety
9//!
10//! The `hnsw_rs::Hnsw` graph uses `parking_lot::RwLock` internally,
11//! so `insert()` takes `&self`. Our metadata (`IndexState`) is
12//! protected by `std::sync::RwLock`.
13
14use std::collections::{HashMap, HashSet};
15use std::fs;
16use std::path::Path;
17use std::sync::RwLock;
18
19use hnsw_rs::prelude::*;
20
21use crate::config::HnswConfig;
22use crate::error::{PulseDBError, Result};
23use crate::types::ExperienceId;
24
25use super::VectorIndex;
26
/// Active-vector count at or below which `search_experiences` uses a
/// brute-force linear scan instead of HNSW graph traversal (the comparison
/// is `active_count <= BRUTE_FORCE_THRESHOLD`).
///
/// hnsw_rs stores each point only in its assigned layer, so points placed in
/// upper layers are unreachable during layer-0 search. For small collections
/// this causes missed results. Linear scan is both more reliable (100%
/// recall) and faster (no graph overhead) at this scale.
const BRUTE_FORCE_THRESHOLD: usize = 128;
33
34/// Newtype wrapper that bridges `&dyn Fn(&usize) -> bool` to `FilterT`.
35///
36/// Rust's blanket impl `impl<F: Fn(&DataId) -> bool> FilterT for F` only
37/// works for concrete types. When we have a `&dyn Fn` trait object (from the
38/// `VectorIndex` trait's `search_filtered` method), we can't coerce it to
39/// `&dyn FilterT` directly. This wrapper implements `FilterT` by delegating
40/// to the wrapped closure trait object.
41struct FilterBridge<'a>(&'a (dyn Fn(&usize) -> bool + Sync));
42
43impl FilterT for FilterBridge<'_> {
44    fn hnsw_filter(&self, id: &DataId) -> bool {
45        (self.0)(id)
46    }
47}
48
/// HNSW vector index backed by `hnsw_rs`.
///
/// Each collective gets its own `HnswIndex` instance, providing
/// complete data isolation between collectives.
///
/// The type combines the hnsw_rs graph with an `ExperienceId` ↔ internal
/// `usize` ID mapping and a soft-delete set; both the mapping and the
/// soft-delete set live in [`IndexState`] behind a `std::sync::RwLock`.
///
/// # Persistence Strategy
///
/// Metadata (ID mappings, deleted set) is persisted to a JSON `.hnsw.meta`
/// file. The graph itself is rebuilt from redb embeddings on open, because
/// `hnsw_rs::HnswIo::load_hnsw` has lifetime constraints that create
/// self-referential struct issues. The graph dump files (via `file_dump`)
/// are saved for future optimization but not currently loaded.
pub struct HnswIndex {
    /// The underlying HNSW graph. Uses `'static` lifetime because
    /// all data is heap-owned (not memory-mapped).
    hnsw: Hnsw<'static, f32, DistCosine>,

    /// Mutable metadata (ID mappings, deleted set, next ID) protected by RwLock.
    state: RwLock<IndexState>,

    /// Immutable configuration, cloned from the caller's config in `new`.
    /// None of the methods visible in this file read it back, which is why
    /// `dead_code` is allowed here; it is kept for the save/rebuild lifecycle.
    #[allow(dead_code)]
    config: HnswConfig,

    /// Embedding dimension (must match all inserted vectors and queries).
    dimension: usize,
}
76
/// Internal mutable state for ID mapping and soft-deletion.
///
/// `insert_experience` and `rebuild_from_embeddings` keep the two maps in
/// step: every assigned internal ID `i` is both a value in `id_to_internal`
/// and the index of the matching entry in `internal_to_id`.
#[derive(Debug)]
struct IndexState {
    /// Forward map: ExperienceId → internal usize ID.
    id_to_internal: HashMap<ExperienceId, usize>,

    /// Reverse map: internal usize ID → ExperienceId.
    /// Uses Vec for O(1) lookup by index.
    internal_to_id: Vec<ExperienceId>,

    /// Set of soft-deleted internal IDs (excluded from search).
    /// The `VectorIndex::delete` impl inserts whatever ID it is given, so
    /// this set may contain IDs that have no entry in the maps above.
    deleted: HashSet<usize>,

    /// Next internal ID to assign (monotonically increasing).
    next_id: usize,
}
93
/// Serializable metadata for persistence.
///
/// Written by `save_to_dir` as JSON to `{name}.hnsw.meta` and read back by
/// `load_metadata`.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct IndexMetadata {
    /// Embedding dimension of the index that produced this metadata.
    pub(crate) dimension: usize,
    /// Value of the index's internal ID counter at save time.
    pub(crate) next_id: usize,
    /// Vec of (ExperienceId UUID string, internal usize ID) pairs.
    pub(crate) id_map: Vec<(String, usize)>,
    /// Deleted ExperienceId UUID strings (not internal IDs).
    ///
    /// We store UUIDs instead of internal usize IDs because internal IDs
    /// are reassigned sequentially on rebuild. Using UUIDs ensures the
    /// correct experiences are marked as deleted after rebuild.
    pub(crate) deleted: Vec<String>,
}
108
109impl HnswIndex {
110    /// Creates a new empty HNSW index.
111    ///
112    /// # Arguments
113    ///
114    /// * `dimension` - Expected embedding dimension (validated on insert)
115    /// * `config` - HNSW tuning parameters
116    pub fn new(dimension: usize, config: &HnswConfig) -> Self {
117        let hnsw = Hnsw::new(
118            config.max_nb_connection,
119            config.max_elements,
120            config.max_layer,
121            config.ef_construction,
122            DistCosine,
123        );
124
125        Self {
126            hnsw,
127            state: RwLock::new(IndexState {
128                id_to_internal: HashMap::new(),
129                internal_to_id: Vec::new(),
130                deleted: HashSet::new(),
131                next_id: 0,
132            }),
133            config: config.clone(),
134            dimension,
135        }
136    }
137
138    /// Inserts an experience embedding into the index.
139    ///
140    /// Assigns a new internal usize ID and records the mapping.
141    /// If the ExperienceId is already present, this is a no-op.
142    pub fn insert_experience(&self, exp_id: ExperienceId, embedding: &[f32]) -> Result<()> {
143        if embedding.len() != self.dimension {
144            return Err(PulseDBError::vector(format!(
145                "Embedding dimension mismatch: expected {}, got {}",
146                self.dimension,
147                embedding.len()
148            )));
149        }
150
151        let mut state = self
152            .state
153            .write()
154            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
155
156        // Skip if already inserted (idempotent)
157        if state.id_to_internal.contains_key(&exp_id) {
158            return Ok(());
159        }
160
161        // Assign next sequential internal ID
162        let internal_id = state.next_id;
163        state.next_id += 1;
164
165        // Record bidirectional mapping
166        state.id_to_internal.insert(exp_id, internal_id);
167        state.internal_to_id.push(exp_id);
168
169        // Drop the lock before calling hnsw insert (which acquires its own lock)
170        drop(state);
171
172        // Insert into HNSW graph (uses interior mutability via parking_lot::RwLock)
173        self.hnsw.insert((embedding, internal_id));
174
175        Ok(())
176    }
177
178    /// Marks an experience as deleted in the index.
179    ///
180    /// The vector remains in the graph but is excluded from search
181    /// results via filtered search. Returns Ok even if the experience
182    /// is not in the index (idempotent).
183    pub fn delete_experience(&self, exp_id: ExperienceId) -> Result<()> {
184        let mut state = self
185            .state
186            .write()
187            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
188
189        if let Some(&internal_id) = state.id_to_internal.get(&exp_id) {
190            state.deleted.insert(internal_id);
191        }
192
193        Ok(())
194    }
195
196    /// Searches for the k nearest experiences, excluding deleted ones.
197    ///
198    /// Returns `(ExperienceId, distance)` pairs sorted by distance
199    /// ascending (closest first). Distance is cosine distance:
200    /// 0.0 = identical, 2.0 = opposite.
201    pub fn search_experiences(
202        &self,
203        query: &[f32],
204        k: usize,
205        ef_search: usize,
206    ) -> Result<Vec<(ExperienceId, f32)>> {
207        if query.len() != self.dimension {
208            return Err(PulseDBError::vector(format!(
209                "Query dimension mismatch: expected {}, got {}",
210                self.dimension,
211                query.len()
212            )));
213        }
214
215        let state = self
216            .state
217            .read()
218            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
219
220        let active_count = state.next_id - state.deleted.len();
221        if active_count == 0 {
222            return Ok(vec![]);
223        }
224        let effective_k = k.min(active_count);
225
226        if active_count <= BRUTE_FORCE_THRESHOLD {
227            // Linear scan: iterate all stored vectors and compute exact distances.
228            // Guarantees 100% recall for small collections where HNSW's layer
229            // fragmentation causes missed results.
230            let dist_fn = DistCosine;
231            let mut all_distances: Vec<(ExperienceId, f32)> = Vec::with_capacity(active_count);
232
233            for point in self.hnsw.get_point_indexation().into_iter() {
234                let origin_id = point.get_origin_id();
235                if state.deleted.contains(&origin_id) {
236                    continue;
237                }
238                let distance = dist_fn.eval(query, point.get_v());
239                if let Some(&exp_id) = state.internal_to_id.get(origin_id) {
240                    all_distances.push((exp_id, distance));
241                }
242            }
243
244            all_distances
245                .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
246            all_distances.truncate(effective_k);
247            return Ok(all_distances);
248        }
249
250        // HNSW graph search for larger collections
251        let effective_ef = ef_search.max(effective_k);
252        let deleted_ref = &state.deleted;
253        let filter_fn = |id: &usize| -> bool { !deleted_ref.contains(id) };
254        let results = if state.deleted.is_empty() {
255            self.hnsw.search(query, effective_k, effective_ef)
256        } else {
257            self.hnsw
258                .search_filter(query, effective_k, effective_ef, Some(&filter_fn))
259        };
260
261        // Map internal IDs back to ExperienceIds
262        let mapped: Vec<(ExperienceId, f32)> = results
263            .into_iter()
264            .filter_map(|n| {
265                state
266                    .internal_to_id
267                    .get(n.d_id)
268                    .map(|&exp_id| (exp_id, n.distance))
269            })
270            .collect();
271
272        Ok(mapped)
273    }
274
275    /// Returns true if the given experience is in the index (and not deleted).
276    pub fn contains(&self, exp_id: ExperienceId) -> bool {
277        let state = self.state.read().ok();
278        state.is_some_and(|s| {
279            s.id_to_internal
280                .get(&exp_id)
281                .is_some_and(|id| !s.deleted.contains(id))
282        })
283    }
284
285    /// Returns the number of active (non-deleted) vectors.
286    pub fn active_count(&self) -> usize {
287        let state = self.state.read().ok();
288        state.map_or(0, |s| s.id_to_internal.len() - s.deleted.len())
289    }
290
    /// Returns the total number of vectors (including deleted).
    ///
    /// This is the point count reported by the underlying graph
    /// (`get_nb_point`). Deleting through this type only records the ID in
    /// the deleted set, so soft-deleted vectors are still counted here.
    pub fn total_count(&self) -> usize {
        self.hnsw.get_nb_point()
    }
295
296    /// Restores the deleted set from persisted metadata.
297    ///
298    /// Called during `PulseDB::open()` after rebuilding the graph from redb.
299    /// Accepts ExperienceId UUID strings and maps them to the current
300    /// internal IDs (which may differ from the previous session's IDs
301    /// after a rebuild).
302    pub fn restore_deleted_set(&self, deleted_exp_ids: &[String]) -> Result<()> {
303        let mut state = self
304            .state
305            .write()
306            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
307        for exp_id_str in deleted_exp_ids {
308            // Parse UUID string back to ExperienceId
309            let uuid = uuid::Uuid::parse_str(exp_id_str)
310                .map_err(|e| PulseDBError::vector(format!("Invalid UUID in deleted set: {}", e)))?;
311            let exp_id = ExperienceId::from_bytes(*uuid.as_bytes());
312            // Map to current internal ID (skip if not found — experience
313            // may have been hard-deleted from redb since last save)
314            if let Some(&internal_id) = state.id_to_internal.get(&exp_id) {
315                state.deleted.insert(internal_id);
316            }
317        }
318        Ok(())
319    }
320
321    /// Saves index metadata to a JSON file.
322    ///
323    /// Creates `{dir}/{name}.hnsw.meta` with ID mappings and deleted set.
324    /// Also attempts to save the HNSW graph via `file_dump` for future
325    /// optimization (graph loading is not yet implemented due to lifetime
326    /// constraints in hnsw_rs).
327    pub fn save_to_dir(&self, dir: &Path, name: &str) -> Result<()> {
328        fs::create_dir_all(dir)
329            .map_err(|e| PulseDBError::vector(format!("Failed to create HNSW directory: {}", e)))?;
330
331        let state = self
332            .state
333            .read()
334            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
335
336        // Build metadata
337        let metadata = IndexMetadata {
338            dimension: self.dimension,
339            next_id: state.next_id,
340            id_map: state
341                .id_to_internal
342                .iter()
343                .map(|(exp_id, &internal_id)| (exp_id.to_string(), internal_id))
344                .collect(),
345            deleted: state
346                .deleted
347                .iter()
348                .filter_map(|&internal_id| {
349                    state
350                        .internal_to_id
351                        .get(internal_id)
352                        .map(|exp_id| exp_id.to_string())
353                })
354                .collect(),
355        };
356
357        // Write metadata as JSON
358        let meta_path = dir.join(format!("{}.hnsw.meta", name));
359        let json = serde_json::to_string_pretty(&metadata).map_err(|e| {
360            PulseDBError::vector(format!("Failed to serialize HNSW metadata: {}", e))
361        })?;
362        fs::write(&meta_path, json)
363            .map_err(|e| PulseDBError::vector(format!("Failed to write HNSW metadata: {}", e)))?;
364
365        // Also dump the HNSW graph (for future direct-load optimization)
366        if state.id_to_internal.is_empty() {
367            return Ok(());
368        }
369        drop(state);
370
371        if let Err(e) = self.hnsw.file_dump(dir, name) {
372            tracing::warn!(error = %e, "Failed to dump HNSW graph (non-fatal, will rebuild on next open)");
373        }
374
375        Ok(())
376    }
377
378    /// Loads index metadata from a JSON file.
379    ///
380    /// Returns the metadata needed to rebuild the graph. The caller must
381    /// create a new `HnswIndex` and re-insert embeddings using the
382    /// stored ID mappings.
383    #[allow(dead_code)] // Used in Step 4 (db.rs open/close lifecycle)
384    pub(crate) fn load_metadata(dir: &Path, name: &str) -> Result<Option<IndexMetadata>> {
385        let meta_path = dir.join(format!("{}.hnsw.meta", name));
386        if !meta_path.exists() {
387            return Ok(None);
388        }
389
390        let json = fs::read_to_string(&meta_path)
391            .map_err(|e| PulseDBError::vector(format!("Failed to read HNSW metadata: {}", e)))?;
392        let metadata: IndexMetadata = serde_json::from_str(&json)
393            .map_err(|e| PulseDBError::vector(format!("Failed to parse HNSW metadata: {}", e)))?;
394
395        Ok(Some(metadata))
396    }
397
398    /// Rebuilds an index from a set of embeddings.
399    ///
400    /// Used during `PulseDB::open()` to reconstruct the HNSW graph
401    /// from embeddings stored in redb (the source of truth).
402    pub fn rebuild_from_embeddings(
403        dimension: usize,
404        config: &HnswConfig,
405        embeddings: Vec<(ExperienceId, Vec<f32>)>,
406    ) -> Result<Self> {
407        let index = Self::new(dimension, config);
408
409        if embeddings.is_empty() {
410            return Ok(index);
411        }
412
413        // Prepare batch data for parallel insertion
414        let mut state = index
415            .state
416            .write()
417            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
418
419        let mut batch: Vec<(&Vec<f32>, usize)> = Vec::with_capacity(embeddings.len());
420
421        for (exp_id, embedding) in &embeddings {
422            let internal_id = state.next_id;
423            state.next_id += 1;
424            state.id_to_internal.insert(*exp_id, internal_id);
425            state.internal_to_id.push(*exp_id);
426            batch.push((embedding, internal_id));
427        }
428
429        drop(state);
430
431        // Parallel bulk insert (uses rayon internally)
432        index.hnsw.parallel_insert(&batch);
433
434        Ok(index)
435    }
436
437    /// Removes HNSW files for a collective from disk.
438    pub fn remove_files(dir: &Path, name: &str) -> Result<()> {
439        // Remove metadata file
440        let meta_path = dir.join(format!("{}.hnsw.meta", name));
441        if meta_path.exists() {
442            fs::remove_file(&meta_path).map_err(|e| {
443                PulseDBError::vector(format!("Failed to remove HNSW metadata: {}", e))
444            })?;
445        }
446
447        // Remove graph dump files (hnsw_rs creates files with the name as prefix)
448        if let Ok(entries) = fs::read_dir(dir) {
449            for entry in entries.flatten() {
450                let file_name = entry.file_name();
451                let file_str = file_name.to_string_lossy();
452                if file_str.starts_with(name) && file_str.contains("hnswdump") {
453                    let _ = fs::remove_file(entry.path());
454                }
455            }
456        }
457
458        Ok(())
459    }
460}
461
462// ==========================================================================
463// VectorIndex trait implementation
464// ==========================================================================
465
466impl VectorIndex for HnswIndex {
467    fn insert(&self, id: usize, embedding: &[f32]) -> Result<()> {
468        if embedding.len() != self.dimension {
469            return Err(PulseDBError::vector(format!(
470                "Embedding dimension mismatch: expected {}, got {}",
471                self.dimension,
472                embedding.len()
473            )));
474        }
475        self.hnsw.insert((embedding, id));
476        Ok(())
477    }
478
479    fn insert_batch(&self, items: &[(&Vec<f32>, usize)]) -> Result<()> {
480        self.hnsw.parallel_insert(items);
481        Ok(())
482    }
483
484    fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Result<Vec<(usize, f32)>> {
485        let results = self.hnsw.search(query, k, ef_search);
486        Ok(results.into_iter().map(|n| (n.d_id, n.distance)).collect())
487    }
488
489    fn search_filtered(
490        &self,
491        query: &[f32],
492        k: usize,
493        ef_search: usize,
494        filter: &(dyn Fn(&usize) -> bool + Sync),
495    ) -> Result<Vec<(usize, f32)>> {
496        // Wrap the dyn Fn trait object in FilterBridge to satisfy hnsw_rs's
497        // FilterT requirement (trait objects can't auto-coerce between traits)
498        let bridge = FilterBridge(filter);
499        let results = self.hnsw.search_filter(query, k, ef_search, Some(&bridge));
500        Ok(results.into_iter().map(|n| (n.d_id, n.distance)).collect())
501    }
502
503    fn delete(&self, id: usize) -> Result<()> {
504        let mut state = self
505            .state
506            .write()
507            .map_err(|_| PulseDBError::vector("Index state lock poisoned"))?;
508        state.deleted.insert(id);
509        Ok(())
510    }
511
512    fn is_deleted(&self, id: usize) -> bool {
513        self.state
514            .read()
515            .ok()
516            .is_some_and(|s| s.deleted.contains(&id))
517    }
518
519    fn len(&self) -> usize {
520        self.active_count()
521    }
522
523    fn save(&self, dir: &Path, name: &str) -> Result<()> {
524        self.save_to_dir(dir, name)
525    }
526}
527
528// ==========================================================================
529// Tests
530// ==========================================================================
531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::HnswConfig;

    /// Small HNSW configuration shared by every test.
    fn test_config() -> HnswConfig {
        HnswConfig {
            max_nb_connection: 16,
            ef_construction: 100,
            ef_search: 50,
            max_layer: 8,
            max_elements: 1000,
        }
    }

    /// Deterministic embedding derived from `seed`; nearby seeds give
    /// similar vectors.
    fn make_embedding(seed: u64, dim: usize) -> Vec<f32> {
        let mut components = Vec::with_capacity(dim);
        for i in 0..dim {
            components.push((seed as f32 * 0.1 + i as f32 * 0.01).sin());
        }
        components
    }

    /// Inserts embeddings for seeds `0..count` and returns the generated
    /// experience IDs in insertion order.
    fn populate(index: &HnswIndex, count: u64, dim: usize) -> Vec<ExperienceId> {
        (0..count)
            .map(|seed| {
                let exp_id = ExperienceId::new();
                index
                    .insert_experience(exp_id, &make_embedding(seed, dim))
                    .unwrap();
                exp_id
            })
            .collect()
    }

    #[test]
    fn test_new_index_is_empty() {
        let fresh = HnswIndex::new(384, &test_config());
        assert_eq!(fresh.active_count(), 0);
        assert_eq!(fresh.total_count(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_insert_and_search() {
        let dim = 8;
        let index = HnswIndex::new(dim, &test_config());

        populate(&index, 10, dim);
        assert_eq!(index.active_count(), 10);

        // Query with the same vector that was stored for seed 5.
        let hits = index
            .search_experiences(&make_embedding(5, dim), 3, 50)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 3);
        // Closest first.
        for pair in hits.windows(2) {
            assert!(pair[0].1 <= pair[1].1, "Results not sorted by distance");
        }
    }

    #[test]
    fn test_insert_idempotent() {
        let dim = 4;
        let index = HnswIndex::new(dim, &test_config());
        let exp_id = ExperienceId::new();
        let vector = make_embedding(1, dim);

        // Second call repeats the first and must not add a vector.
        for _ in 0..2 {
            index.insert_experience(exp_id, &vector).unwrap();
        }

        assert_eq!(index.active_count(), 1);
    }

    #[test]
    fn test_dimension_mismatch_rejected() {
        let index = HnswIndex::new(384, &test_config());

        // 128 components against an index expecting 384.
        let too_short = vec![1.0f32; 128];
        let outcome = index.insert_experience(ExperienceId::new(), &too_short);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().is_vector());
    }

    #[test]
    fn test_delete_excludes_from_search() {
        let dim = 8;
        let index = HnswIndex::new(dim, &test_config());
        let ids = populate(&index, 5, dim);
        assert_eq!(index.active_count(), 5);

        index.delete_experience(ids[0]).unwrap();
        assert_eq!(index.active_count(), 4);
        assert!(!index.contains(ids[0]));
        assert!(index.contains(ids[1]));

        // Query right on top of the deleted entry; it must not come back.
        let hits = index
            .search_experiences(&make_embedding(0, dim), 10, 50)
            .unwrap();
        let returned: Vec<ExperienceId> = hits.iter().map(|hit| hit.0).collect();
        assert!(!returned.contains(&ids[0]));
    }

    #[test]
    fn test_search_k_larger_than_index() {
        let dim = 4;
        let index = HnswIndex::new(dim, &test_config());
        populate_single(&index, 1, dim);

        // k = 100 against a single stored vector.
        let hits = index
            .search_experiences(&make_embedding(1, dim), 100, 50)
            .unwrap();
        assert_eq!(hits.len(), 1);
    }

    /// Inserts one embedding for `seed` and returns its experience ID.
    fn populate_single(index: &HnswIndex, seed: u64, dim: usize) -> ExperienceId {
        let exp_id = ExperienceId::new();
        index
            .insert_experience(exp_id, &make_embedding(seed, dim))
            .unwrap();
        exp_id
    }

    #[test]
    fn test_search_empty_index() {
        let dim = 4;
        let index = HnswIndex::new(dim, &test_config());

        let hits = index
            .search_experiences(&make_embedding(1, dim), 10, 50)
            .unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_rebuild_from_embeddings() {
        let dim = 8;
        let config = test_config();

        let mut source: Vec<(ExperienceId, Vec<f32>)> = Vec::new();
        for seed in 0..20u64 {
            source.push((ExperienceId::new(), make_embedding(seed, dim)));
        }

        let index = HnswIndex::rebuild_from_embeddings(dim, &config, source).unwrap();
        assert_eq!(index.active_count(), 20);

        // The rebuilt index must answer searches.
        let hits = index
            .search_experiences(&make_embedding(10, dim), 5, 50)
            .unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_rebuild_empty() {
        let index = HnswIndex::rebuild_from_embeddings(384, &test_config(), Vec::new()).unwrap();
        assert!(index.is_empty());
    }

    #[test]
    fn test_save_and_load_metadata_roundtrip() {
        let dim = 4;
        let index = HnswIndex::new(dim, &test_config());
        let exp_ids = populate(&index, 5, dim);
        index.delete_experience(exp_ids[2]).unwrap();

        let tmp = tempfile::tempdir().unwrap();
        index.save_to_dir(tmp.path(), "test_collective").unwrap();

        let loaded = HnswIndex::load_metadata(tmp.path(), "test_collective")
            .unwrap()
            .expect("Metadata should exist");

        assert_eq!(loaded.dimension, dim);
        assert_eq!(loaded.next_id, 5);
        assert_eq!(loaded.id_map.len(), 5);
        assert_eq!(loaded.deleted.len(), 1);
        // The deleted set is persisted as ExperienceId UUIDs, not internal IDs.
        assert_eq!(loaded.deleted[0], exp_ids[2].to_string());
    }

    #[test]
    fn test_remove_files() {
        let dim = 4;
        let index = HnswIndex::new(dim, &test_config());
        populate_single(&index, 1, dim);

        let tmp = tempfile::tempdir().unwrap();
        index.save_to_dir(tmp.path(), "test_coll").unwrap();

        let meta_file = tmp.path().join("test_coll.hnsw.meta");
        assert!(meta_file.exists());

        HnswIndex::remove_files(tmp.path(), "test_coll").unwrap();
        assert!(!meta_file.exists());
    }

    #[test]
    fn test_brute_force_search_returns_all_items() {
        let dim = 8;
        let index = HnswIndex::new(dim, &test_config());

        // 20 items: well below BRUTE_FORCE_THRESHOLD of 128.
        let ids = populate(&index, 20, dim);

        // Asking for all 20 must return every one on the brute-force path.
        let hits = index
            .search_experiences(&make_embedding(10, dim), 20, 50)
            .unwrap();
        assert_eq!(hits.len(), 20, "Brute-force must return all 20 items");

        for pair in hits.windows(2) {
            assert!(
                pair[0].1 <= pair[1].1,
                "Brute-force results not sorted: {} > {}",
                pair[0].1,
                pair[1].1
            );
        }

        // The stored vector for seed 10 equals the query: first, distance ≈ 0.
        let (best_id, best_distance) = hits[0];
        assert_eq!(best_id, ids[10]);
        assert!(
            best_distance < 0.001,
            "Expected near-zero distance for exact match, got {}",
            best_distance
        );
    }

    #[test]
    fn test_brute_force_excludes_deleted() {
        let dim = 8;
        let index = HnswIndex::new(dim, &test_config());
        let ids = populate(&index, 5, dim);

        index.delete_experience(ids[2]).unwrap();

        let hits = index
            .search_experiences(&make_embedding(2, dim), 10, 50)
            .unwrap();
        assert_eq!(hits.len(), 4, "Should return 4 after deleting 1 of 5");

        let returned: Vec<ExperienceId> = hits.iter().map(|hit| hit.0).collect();
        assert!(
            !returned.contains(&ids[2]),
            "Deleted item must be excluded"
        );
    }

    #[test]
    fn test_cosine_distance_identical_vectors() {
        let dim = 8;
        let index = HnswIndex::new(dim, &test_config());

        let vector = make_embedding(42, dim);
        let exp_id = ExperienceId::new();
        index.insert_experience(exp_id, &vector).unwrap();

        // Searching with the stored vector itself.
        let hits = index.search_experiences(&vector, 1, 50).unwrap();
        assert_eq!(hits.len(), 1);

        let (found_id, found_distance) = hits[0];
        assert_eq!(found_id, exp_id);
        assert!(
            found_distance < 0.001,
            "Expected near-zero distance for identical vectors, got {}",
            found_distance
        );
    }
}
841}