oxify_vector/
hnsw.rs

//! HNSW (Hierarchical Navigable Small World) Index
//!
//! State-of-the-art approximate nearest neighbor (ANN) algorithm.
//! Provides sub-linear search time with high recall.
//!
//! ## Algorithm Overview
//!
//! HNSW builds a multi-layer graph where:
//! - Layer 0 contains all vectors (densest)
//! - Higher layers contain progressively fewer vectors (sparser)
//! - Search starts from the top layer and navigates down
//!
//! ## Parameters
//!
//! - `M`: Maximum connections per node on the upper layers (default: 16);
//!   layer 0 allows up to `m0` connections (`2 * M` in the presets)
//! - `ef_construction`: Build-time quality (default: 200)
//! - `ef_search`: Search-time quality vs speed trade-off (default: 50)
//!
//! ## Example
//!
//! ```rust
//! use oxify_vector::hnsw::{HnswIndex, HnswConfig};
//! use std::collections::HashMap;
//!
//! # fn example() -> anyhow::Result<()> {
//! // Create embeddings
//! let mut embeddings = HashMap::new();
//! embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
//! embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
//! embeddings.insert("doc3".to_string(), vec![0.9, 0.8, 0.7]);
//!
//! // Build HNSW index
//! let config = HnswConfig::default();
//! let mut index = HnswIndex::new(config);
//! index.build(&embeddings)?;
//!
//! // Search for similar documents
//! let query = vec![0.15, 0.25, 0.35];
//! let results = index.search(&query, 2)?;
//!
//! for result in results {
//!     println!("{}: score = {:.4}", result.entity_id, result.score);
//! }
//! # Ok(())
//! # }
//! ```
47
48use crate::filter::{Filter, Metadata};
49use crate::simd;
50use crate::types::{DistanceMetric, SearchResult};
51use anyhow::{anyhow, Result};
52use rand::Rng;
53use serde::{Deserialize, Serialize};
54use std::cmp::Ordering;
55use std::collections::{BinaryHeap, HashMap, HashSet};
56use tracing::{debug, info};
57
58/// HNSW configuration
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct HnswConfig {
61    /// Distance metric
62    pub metric: DistanceMetric,
63    /// Maximum number of connections per node
64    pub m: usize,
65    /// Maximum connections at layer 0 (typically 2 * m)
66    pub m0: usize,
67    /// Size of dynamic candidate list during construction
68    pub ef_construction: usize,
69    /// Size of dynamic candidate list during search
70    pub ef_search: usize,
71    /// Normalization factor for level generation (1/ln(m))
72    pub ml: f64,
73    /// Normalize vectors before indexing
74    pub normalize: bool,
75}
76
77impl Default for HnswConfig {
78    fn default() -> Self {
79        let m = 16;
80        Self {
81            metric: DistanceMetric::Cosine,
82            m,
83            m0: m * 2,
84            ef_construction: 200,
85            ef_search: 50,
86            ml: 1.0 / (m as f64).ln(),
87            normalize: true,
88        }
89    }
90}
91
92impl HnswConfig {
93    /// Create config optimized for high recall
94    pub fn high_recall() -> Self {
95        let m = 32;
96        Self {
97            metric: DistanceMetric::Cosine,
98            m,
99            m0: m * 2,
100            ef_construction: 400,
101            ef_search: 100,
102            ml: 1.0 / (m as f64).ln(),
103            normalize: true,
104        }
105    }
106
107    /// Create config optimized for speed
108    pub fn fast() -> Self {
109        let m = 12;
110        Self {
111            metric: DistanceMetric::Cosine,
112            m,
113            m0: m * 2,
114            ef_construction: 100,
115            ef_search: 30,
116            ml: 1.0 / (m as f64).ln(),
117            normalize: true,
118        }
119    }
120}
121
122/// Node in the HNSW graph
123#[allow(dead_code)]
124#[derive(Debug, Clone, Serialize, Deserialize)]
125struct HnswNode {
126    /// Node index
127    id: usize,
128    /// Maximum layer this node exists on
129    level: usize,
130    /// Neighbors at each layer (layer -> neighbor indices)
131    neighbors: Vec<Vec<usize>>,
132}
133
134impl HnswNode {
135    fn new(id: usize, level: usize) -> Self {
136        Self {
137            id,
138            level,
139            neighbors: vec![Vec::new(); level + 1],
140        }
141    }
142}
143
/// Candidate for nearest neighbor search (min-heap by distance).
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the candidate with
/// the *smallest* distance compares greatest and is popped first.
#[derive(Debug, Clone, Copy)]
struct Candidate {
    id: usize,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` is reflexive and agrees with `Ord`,
        // even for NaN distances (plain `==` would make NaN != NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed for min-heap behavior (smaller distance = higher priority).
        // `total_cmp` is a true total order: a NaN distance ranks as the worst
        // candidate instead of comparing "equal" to everything, which would
        // break the heap invariant.
        other.distance.total_cmp(&self.distance)
    }
}
174
/// Max-heap candidate (for maintaining the furthest of the current results).
///
/// The candidate with the *largest* distance compares greatest, so
/// `BinaryHeap::peek`/`pop` yield the current worst result.
#[derive(Debug, Clone, Copy)]
struct MaxCandidate {
    id: usize,
    distance: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` is reflexive and agrees with `Ord`,
        // even for NaN distances.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Normal ordering for max-heap (larger distance = higher priority).
        // `total_cmp` keeps the order total; a NaN distance ranks as the
        // furthest result and is evicted first.
        self.distance.total_cmp(&other.distance)
    }
}
204
205/// HNSW Index for approximate nearest neighbor search
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct HnswIndex {
208    config: HnswConfig,
209    /// Stored vectors (normalized if config.normalize)
210    vectors: Vec<Vec<f32>>,
211    /// Entity IDs corresponding to vectors
212    entity_ids: Vec<String>,
213    /// Graph nodes
214    nodes: Vec<HnswNode>,
215    /// Entry point (top-level node)
216    entry_point: Option<usize>,
217    /// Maximum level in the graph
218    max_level: usize,
219    /// Vector dimensions
220    dimensions: usize,
221    /// Whether index is built
222    is_built: bool,
223    /// Metadata storage for filtered search
224    metadata: HashMap<String, Metadata>,
225    /// Tombstones for lazy deletion (deleted entity IDs)
226    deleted: HashSet<String>,
227}
228
229impl HnswIndex {
230    /// Create a new HNSW index
231    pub fn new(config: HnswConfig) -> Self {
232        info!(
233            "Initialized HNSW index: m={}, ef_construction={}, ef_search={}",
234            config.m, config.ef_construction, config.ef_search
235        );
236
237        Self {
238            config,
239            vectors: Vec::new(),
240            entity_ids: Vec::new(),
241            nodes: Vec::new(),
242            entry_point: None,
243            max_level: 0,
244            dimensions: 0,
245            is_built: false,
246            metadata: HashMap::new(),
247            deleted: HashSet::new(),
248        }
249    }
250
251    /// Build HNSW index from embeddings
252    pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
253        if embeddings.is_empty() {
254            return Err(anyhow!("Cannot build index from empty embeddings"));
255        }
256
257        info!(
258            "Building HNSW index for {} entities (m={}, ef_construction={})",
259            embeddings.len(),
260            self.config.m,
261            self.config.ef_construction
262        );
263
264        // Reset state
265        self.vectors.clear();
266        self.entity_ids.clear();
267        self.nodes.clear();
268        self.entry_point = None;
269        self.max_level = 0;
270
271        // Store vectors
272        self.dimensions = embeddings.values().next().unwrap().len();
273
274        for (entity_id, vec) in embeddings {
275            let mut v = vec.clone();
276            if self.config.normalize {
277                Self::normalize_vector(&mut v);
278            }
279            self.vectors.push(v);
280            self.entity_ids.push(entity_id.clone());
281        }
282
283        // Insert vectors one by one
284        for i in 0..self.vectors.len() {
285            self.insert_node(i)?;
286        }
287
288        self.is_built = true;
289        info!(
290            "HNSW index built: {} vectors, max_level={}",
291            self.vectors.len(),
292            self.max_level
293        );
294
295        Ok(())
296    }
297
298    /// Insert a single node into the graph
299    fn insert_node(&mut self, id: usize) -> Result<()> {
300        let level = self.random_level();
301        let node = HnswNode::new(id, level);
302        self.nodes.push(node);
303
304        // If this is the first node, just set it as entry point
305        if self.entry_point.is_none() {
306            self.entry_point = Some(id);
307            self.max_level = level;
308            return Ok(());
309        }
310
311        let entry_point = self.entry_point.unwrap();
312
313        // Search from top layer down to node's level + 1
314        let mut current_nearest = entry_point;
315
316        for layer in (level + 1..=self.max_level).rev() {
317            current_nearest = self.greedy_search(id, current_nearest, layer);
318        }
319
320        // Insert at each layer from node's level down to 0
321        for layer in (0..=level.min(self.max_level)).rev() {
322            // Find ef_construction nearest neighbors at this layer
323            let neighbors =
324                self.search_layer(id, current_nearest, self.config.ef_construction, layer);
325
326            // Select M best neighbors
327            let m = if layer == 0 {
328                self.config.m0
329            } else {
330                self.config.m
331            };
332
333            let selected = self.select_neighbors(&neighbors, m);
334
335            // Connect the new node to selected neighbors
336            self.nodes[id].neighbors[layer] = selected.clone();
337
338            // Connect neighbors back to the new node (bidirectional)
339            for &neighbor_id in &selected {
340                self.nodes[neighbor_id].neighbors[layer].push(id);
341
342                // Prune if too many connections
343                let max_connections = if layer == 0 {
344                    self.config.m0
345                } else {
346                    self.config.m
347                };
348
349                if self.nodes[neighbor_id].neighbors[layer].len() > max_connections {
350                    self.prune_connections(neighbor_id, layer, max_connections);
351                }
352            }
353
354            if !selected.is_empty() {
355                current_nearest = selected[0];
356            }
357        }
358
359        // Update entry point if new node has higher level
360        if level > self.max_level {
361            self.entry_point = Some(id);
362            self.max_level = level;
363        }
364
365        Ok(())
366    }
367
368    /// Generate random level for a new node
369    fn random_level(&self) -> usize {
370        let mut rng = rand::rng();
371        let mut level = 0;
372        let uniform: f64 = rng.random();
373
374        // Level follows exponential distribution
375        while uniform < (-((level + 1) as f64) * self.config.ml).exp() && level < 32 {
376            level += 1;
377        }
378
379        level
380    }
381
382    /// Greedy search to find nearest node at a given layer
383    fn greedy_search(&self, query_id: usize, start: usize, layer: usize) -> usize {
384        let query = &self.vectors[query_id];
385        let mut current = start;
386        let mut current_dist = self.compute_distance(query, &self.vectors[current]);
387
388        loop {
389            let mut changed = false;
390
391            for &neighbor in &self.nodes[current].neighbors[layer] {
392                let dist = self.compute_distance(query, &self.vectors[neighbor]);
393                if dist < current_dist {
394                    current = neighbor;
395                    current_dist = dist;
396                    changed = true;
397                }
398            }
399
400            if !changed {
401                break;
402            }
403        }
404
405        current
406    }
407
408    /// Search layer for nearest neighbors
409    fn search_layer(
410        &self,
411        query_id: usize,
412        entry_point: usize,
413        ef: usize,
414        layer: usize,
415    ) -> Vec<(usize, f32)> {
416        let query = &self.vectors[query_id];
417        self.search_layer_by_vector(query, entry_point, ef, layer)
418    }
419
420    /// Search layer for nearest neighbors by vector
421    fn search_layer_by_vector(
422        &self,
423        query: &[f32],
424        entry_point: usize,
425        ef: usize,
426        layer: usize,
427    ) -> Vec<(usize, f32)> {
428        let mut visited = HashSet::new();
429        let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
430        let mut results: BinaryHeap<MaxCandidate> = BinaryHeap::new();
431
432        let entry_dist = self.compute_distance(query, &self.vectors[entry_point]);
433
434        visited.insert(entry_point);
435        candidates.push(Candidate {
436            id: entry_point,
437            distance: entry_dist,
438        });
439        results.push(MaxCandidate {
440            id: entry_point,
441            distance: entry_dist,
442        });
443
444        while let Some(Candidate { id: current, .. }) = candidates.pop() {
445            let furthest_result = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
446
447            // If current candidate is further than worst result, we're done
448            if self.compute_distance(query, &self.vectors[current]) > furthest_result {
449                break;
450            }
451
452            // Check neighbors
453            if layer < self.nodes[current].neighbors.len() {
454                for &neighbor in &self.nodes[current].neighbors[layer] {
455                    if visited.contains(&neighbor) {
456                        continue;
457                    }
458                    visited.insert(neighbor);
459
460                    let dist = self.compute_distance(query, &self.vectors[neighbor]);
461                    let furthest = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
462
463                    if dist < furthest || results.len() < ef {
464                        candidates.push(Candidate {
465                            id: neighbor,
466                            distance: dist,
467                        });
468                        results.push(MaxCandidate {
469                            id: neighbor,
470                            distance: dist,
471                        });
472
473                        // Keep only ef best results
474                        while results.len() > ef {
475                            results.pop();
476                        }
477                    }
478                }
479            }
480        }
481
482        // Convert to sorted vec
483        let mut result_vec: Vec<(usize, f32)> =
484            results.into_iter().map(|c| (c.id, c.distance)).collect();
485        result_vec.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
486        result_vec
487    }
488
489    /// Select best neighbors using simple heuristic
490    fn select_neighbors(&self, candidates: &[(usize, f32)], m: usize) -> Vec<usize> {
491        candidates.iter().take(m).map(|(id, _)| *id).collect()
492    }
493
494    /// Prune connections to maintain max_connections limit
495    fn prune_connections(&mut self, node_id: usize, layer: usize, max_connections: usize) {
496        let node_vec = self.vectors[node_id].clone();
497
498        // Calculate distances to all neighbors
499        let mut neighbor_dists: Vec<(usize, f32)> = self.nodes[node_id].neighbors[layer]
500            .iter()
501            .map(|&n| (n, self.compute_distance(&node_vec, &self.vectors[n])))
502            .collect();
503
504        // Sort by distance
505        neighbor_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
506
507        // Keep only the closest
508        self.nodes[node_id].neighbors[layer] = neighbor_dists
509            .into_iter()
510            .take(max_connections)
511            .map(|(id, _)| id)
512            .collect();
513    }
514
515    /// Compute distance between two vectors
516    ///
517    /// Uses SIMD-optimized distance calculations for better performance.
518    #[inline]
519    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
520        // Use SIMD-optimized implementations for hot path performance
521        simd::compute_distance_lower_is_better_simd(self.config.metric, a, b)
522    }
523
524    /// Normalize vector in-place
525    #[inline]
526    fn normalize_vector(vec: &mut [f32]) {
527        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
528        if norm > 1e-10 {
529            for x in vec.iter_mut() {
530                *x /= norm;
531            }
532        }
533    }
534
535    /// Search for K nearest neighbors
536    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
537        if !self.is_built {
538            return Err(anyhow!("Index not built. Call build() first"));
539        }
540
541        if query.len() != self.dimensions {
542            return Err(anyhow!(
543                "Query dimension {} doesn't match index dimension {}",
544                query.len(),
545                self.dimensions
546            ));
547        }
548
549        // Normalize query if needed
550        let mut normalized_query = query.to_vec();
551        if self.config.normalize {
552            Self::normalize_vector(&mut normalized_query);
553        }
554
555        debug!("HNSW search: k={}, ef_search={}", k, self.config.ef_search);
556
557        let entry_point = self.entry_point.ok_or_else(|| anyhow!("Empty index"))?;
558
559        // Navigate from top layer to layer 1
560        let mut current = entry_point;
561        for layer in (1..=self.max_level).rev() {
562            current = self.greedy_search_by_vector(&normalized_query, current, layer);
563        }
564
565        // Search at layer 0 with ef_search candidates
566        let candidates =
567            self.search_layer_by_vector(&normalized_query, current, self.config.ef_search, 0);
568
569        // Return top-k results, filtering out deleted vectors
570        let results: Vec<SearchResult> = candidates
571            .into_iter()
572            .filter(|(id, _)| !self.deleted.contains(&self.entity_ids[*id]))
573            .take(k)
574            .enumerate()
575            .map(|(rank, (id, distance))| SearchResult {
576                entity_id: self.entity_ids[id].clone(),
577                score: self.distance_to_score(distance),
578                distance,
579                rank: rank + 1,
580            })
581            .collect();
582
583        debug!("Found {} results", results.len());
584        Ok(results)
585    }
586
587    /// Greedy search by vector
588    fn greedy_search_by_vector(&self, query: &[f32], start: usize, layer: usize) -> usize {
589        let mut current = start;
590        let mut current_dist = self.compute_distance(query, &self.vectors[current]);
591
592        loop {
593            let mut changed = false;
594
595            if layer < self.nodes[current].neighbors.len() {
596                for &neighbor in &self.nodes[current].neighbors[layer] {
597                    let dist = self.compute_distance(query, &self.vectors[neighbor]);
598                    if dist < current_dist {
599                        current = neighbor;
600                        current_dist = dist;
601                        changed = true;
602                    }
603                }
604            }
605
606            if !changed {
607                break;
608            }
609        }
610
611        current
612    }
613
614    /// Convert distance to similarity score
615    fn distance_to_score(&self, distance: f32) -> f32 {
616        match self.config.metric {
617            DistanceMetric::Cosine => 1.0 - distance,
618            DistanceMetric::Euclidean | DistanceMetric::Manhattan => -distance,
619            DistanceMetric::DotProduct => -distance,
620        }
621    }
622
623    /// Batch search for multiple queries
624    pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
625        if !self.is_built {
626            return Err(anyhow!("Index not built. Call build() first"));
627        }
628
629        info!("HNSW batch search: {} queries", queries.len());
630
631        let results: Vec<Vec<SearchResult>> = queries
632            .iter()
633            .map(|query| self.search(query, k).unwrap_or_default())
634            .collect();
635
636        Ok(results)
637    }
638
639    /// Add a new vector to the index (incremental update)
640    pub fn add(&mut self, entity_id: &str, vector: &[f32]) -> Result<()> {
641        if !self.is_built {
642            return Err(anyhow!(
643                "Index not built. Call build() first or use build() with initial data"
644            ));
645        }
646
647        if vector.len() != self.dimensions {
648            return Err(anyhow!(
649                "Vector dimension {} doesn't match index dimension {}",
650                vector.len(),
651                self.dimensions
652            ));
653        }
654
655        // Add vector
656        let mut v = vector.to_vec();
657        if self.config.normalize {
658            Self::normalize_vector(&mut v);
659        }
660
661        let id = self.vectors.len();
662        self.vectors.push(v);
663        self.entity_ids.push(entity_id.to_string());
664
665        // Insert into graph
666        self.insert_node(id)?;
667
668        debug!("Added vector '{}' to HNSW index", entity_id);
669        Ok(())
670    }
671
672    /// Get index statistics
673    pub fn get_stats(&self) -> HnswStats {
674        let total_connections: usize = self
675            .nodes
676            .iter()
677            .flat_map(|n| n.neighbors.iter())
678            .map(|neighbors| neighbors.len())
679            .sum();
680
681        let avg_connections = if !self.nodes.is_empty() {
682            total_connections as f64 / self.nodes.len() as f64
683        } else {
684            0.0
685        };
686
687        HnswStats {
688            num_vectors: self.vectors.len(),
689            active_vectors: self.active_count(),
690            deleted_vectors: self.deleted_count(),
691            dimensions: self.dimensions,
692            max_level: self.max_level,
693            avg_connections,
694            m: self.config.m,
695            ef_construction: self.config.ef_construction,
696            ef_search: self.config.ef_search,
697            is_built: self.is_built,
698        }
699    }
700
701    /// Set ef_search parameter (trade-off between speed and recall)
702    pub fn set_ef_search(&mut self, ef: usize) {
703        self.config.ef_search = ef;
704    }
705
706    /// Remove a vector from the index (lazy deletion with tombstone)
707    ///
708    /// The vector is marked as deleted and will be excluded from search results.
709    /// The actual data is not removed until `compact()` is called.
710    pub fn remove(&mut self, entity_id: &str) -> bool {
711        if self.entity_ids.iter().any(|e| e == entity_id) {
712            self.deleted.insert(entity_id.to_string());
713            self.metadata.remove(entity_id);
714            debug!("Marked '{}' as deleted (tombstone)", entity_id);
715            true
716        } else {
717            false
718        }
719    }
720
721    /// Check if a vector is deleted
722    pub fn is_deleted(&self, entity_id: &str) -> bool {
723        self.deleted.contains(entity_id)
724    }
725
726    /// Get the number of deleted vectors (tombstones)
727    pub fn deleted_count(&self) -> usize {
728        self.deleted.len()
729    }
730
731    /// Get the number of active (non-deleted) vectors
732    pub fn active_count(&self) -> usize {
733        self.vectors.len() - self.deleted.len()
734    }
735
736    /// Set metadata for an entity
737    pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
738        self.metadata.insert(entity_id.to_string(), metadata);
739    }
740
741    /// Set metadata for multiple entities
742    pub fn set_metadata_batch(&mut self, metadata_map: HashMap<String, Metadata>) {
743        self.metadata.extend(metadata_map);
744    }
745
746    /// Get metadata for an entity
747    #[inline]
748    pub fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
749        self.metadata.get(entity_id)
750    }
751
752    /// Search with metadata filtering (post-filtering)
753    ///
754    /// Searches with HNSW, then filters results by metadata.
755    /// Efficient when most results pass the filter.
756    pub fn filtered_search(
757        &self,
758        query: &[f32],
759        k: usize,
760        filter: &Filter,
761    ) -> Result<Vec<SearchResult>> {
762        if !self.is_built {
763            return Err(anyhow!("Index not built. Call build() first"));
764        }
765
766        if filter.is_empty() {
767            return self.search(query, k);
768        }
769
770        // Increase ef_search temporarily to get more candidates for filtering
771        let expanded_k = (k * 10).min(self.vectors.len());
772
773        debug!(
774            "HNSW filtered search: k={}, expanded_k={}, filter conditions={}",
775            k,
776            expanded_k,
777            filter.conditions().len()
778        );
779
780        // Get more candidates than needed
781        let all_results = self.search(query, expanded_k)?;
782
783        // Filter and take top-k
784        let filtered: Vec<SearchResult> = all_results
785            .into_iter()
786            .filter(|r| {
787                self.metadata
788                    .get(&r.entity_id)
789                    .is_some_and(|m| filter.matches(m))
790            })
791            .take(k)
792            .enumerate()
793            .map(|(i, mut r)| {
794                r.rank = i + 1; // Re-rank after filtering
795                r
796            })
797            .collect();
798
799        debug!("HNSW filtered search returned {} results", filtered.len());
800        Ok(filtered)
801    }
802
803    /// Search with pre-filtering (filter candidates during search)
804    ///
805    /// More accurate when filters are very selective, but may be slower.
806    /// Uses brute-force search on filtered candidates.
807    pub fn prefiltered_search(
808        &self,
809        query: &[f32],
810        k: usize,
811        filter: &Filter,
812    ) -> Result<Vec<SearchResult>> {
813        if !self.is_built {
814            return Err(anyhow!("Index not built. Call build() first"));
815        }
816
817        if query.len() != self.dimensions {
818            return Err(anyhow!(
819                "Query dimension {} doesn't match index dimension {}",
820                query.len(),
821                self.dimensions
822            ));
823        }
824
825        if filter.is_empty() {
826            return self.search(query, k);
827        }
828
829        debug!("HNSW pre-filtered search: k={}", k);
830
831        // Normalize query if needed
832        let mut normalized_query = query.to_vec();
833        if self.config.normalize {
834            Self::normalize_vector(&mut normalized_query);
835        }
836
837        // Find all indices that match the filter
838        let matching_indices: Vec<usize> = (0..self.entity_ids.len())
839            .filter(|&i| {
840                self.metadata
841                    .get(&self.entity_ids[i])
842                    .is_some_and(|m| filter.matches(m))
843            })
844            .collect();
845
846        if matching_indices.is_empty() {
847            return Ok(Vec::new());
848        }
849
850        // Compute distances only for matching entities (brute force on filtered set)
851        let mut scores: Vec<(usize, f32)> = matching_indices
852            .iter()
853            .map(|&i| {
854                let dist = self.compute_distance(&normalized_query, &self.vectors[i]);
855                (i, dist)
856            })
857            .collect();
858
859        // Sort by distance ascending
860        scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
861
862        // Return top-K results
863        let results: Vec<SearchResult> = scores
864            .iter()
865            .take(k)
866            .enumerate()
867            .map(|(rank, &(idx, distance))| SearchResult {
868                entity_id: self.entity_ids[idx].clone(),
869                score: self.distance_to_score(distance),
870                distance,
871                rank: rank + 1,
872            })
873            .collect();
874
875        debug!(
876            "HNSW pre-filtered search returned {} results",
877            results.len()
878        );
879        Ok(results)
880    }
881
882    /// Optimize the HNSW graph structure
883    ///
884    /// This method performs periodic maintenance on the graph to improve search quality:
885    /// - Removes references to deleted nodes from neighbor lists
886    /// - Trims neighbor lists to respect the M parameter
887    /// - Rebuilds connections for isolated nodes
888    ///
889    /// Call this method periodically after many deletions or additions to maintain
890    /// optimal search performance.
891    pub fn optimize_graph(&mut self) -> Result<()> {
892        if !self.is_built {
893            return Err(anyhow!("Index not built. Call build() first"));
894        }
895
896        info!("Optimizing HNSW graph structure...");
897
898        let deleted_indices: HashSet<usize> = self
899            .entity_ids
900            .iter()
901            .enumerate()
902            .filter(|(_, id)| self.deleted.contains(*id))
903            .map(|(idx, _)| idx)
904            .collect();
905
906        let mut optimized_count = 0;
907
908        // Process each node
909        for node_idx in 0..self.nodes.len() {
910            let node_level = self.nodes[node_idx].level;
911
912            for layer in 0..=node_level {
913                let original_len = self.nodes[node_idx].neighbors[layer].len();
914
915                // Remove deleted neighbors
916                self.nodes[node_idx].neighbors[layer]
917                    .retain(|&neighbor_id| !deleted_indices.contains(&neighbor_id));
918
919                // Trim to max connections (M for layer > 0, M0 for layer 0)
920                let max_connections = if layer == 0 {
921                    self.config.m0
922                } else {
923                    self.config.m
924                };
925
926                if self.nodes[node_idx].neighbors[layer].len() > max_connections {
927                    // Collect neighbors and compute distances
928                    let node_vec = self.vectors[node_idx].clone();
929                    let mut neighbor_distances: Vec<(usize, f32)> = self.nodes[node_idx].neighbors
930                        [layer]
931                        .iter()
932                        .map(|&neighbor_id| {
933                            let dist = self.compute_distance(&node_vec, &self.vectors[neighbor_id]);
934                            (neighbor_id, dist)
935                        })
936                        .collect();
937
938                    // Sort by distance (ascending)
939                    neighbor_distances
940                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
941
942                    // Keep only top-M closest neighbors
943                    self.nodes[node_idx].neighbors[layer] = neighbor_distances
944                        .iter()
945                        .take(max_connections)
946                        .map(|(id, _)| *id)
947                        .collect();
948                }
949
950                if self.nodes[node_idx].neighbors[layer].len() != original_len {
951                    optimized_count += 1;
952                }
953            }
954        }
955
956        info!(
957            "HNSW graph optimization complete. {} node connections updated.",
958            optimized_count
959        );
960
961        Ok(())
962    }
963
964    /// Compact the index by removing tombstones
965    ///
966    /// This method rebuilds the index without deleted vectors, freeing memory
967    /// and improving cache efficiency. Use this after many deletions.
968    ///
969    /// WARNING: This operation is expensive as it rebuilds the entire index.
970    pub fn compact(&mut self) -> Result<()> {
971        if !self.is_built {
972            return Err(anyhow!("Index not built. Call build() first"));
973        }
974
975        if self.deleted.is_empty() {
976            info!("No deleted vectors to compact");
977            return Ok(());
978        }
979
980        info!(
981            "Compacting HNSW index: removing {} deleted vectors out of {}",
982            self.deleted.len(),
983            self.vectors.len()
984        );
985
986        // Collect non-deleted vectors and their entity IDs
987        let mut new_embeddings = HashMap::new();
988        let mut new_metadata = HashMap::new();
989
990        for (i, entity_id) in self.entity_ids.iter().enumerate() {
991            if !self.deleted.contains(entity_id) {
992                new_embeddings.insert(entity_id.clone(), self.vectors[i].clone());
993
994                if let Some(metadata) = self.metadata.get(entity_id) {
995                    new_metadata.insert(entity_id.clone(), metadata.clone());
996                }
997            }
998        }
999
1000        // Rebuild the index
1001        self.build(&new_embeddings)?;
1002
1003        // Restore metadata
1004        self.set_metadata_batch(new_metadata);
1005
1006        info!("HNSW index compaction complete");
1007
1008        Ok(())
1009    }
1010}
1011
1012/// HNSW index statistics
1013#[derive(Debug, Clone, Serialize, Deserialize)]
1014pub struct HnswStats {
1015    /// Total number of vectors in the index (including deleted)
1016    pub num_vectors: usize,
1017    /// Number of active (non-deleted) vectors
1018    pub active_vectors: usize,
1019    /// Number of deleted vectors (tombstones)
1020    pub deleted_vectors: usize,
1021    /// Vector dimensions
1022    pub dimensions: usize,
1023    /// Maximum level in the graph
1024    pub max_level: usize,
1025    /// Average connections per node
1026    pub avg_connections: f64,
1027    /// M parameter
1028    pub m: usize,
1029    /// ef_construction parameter
1030    pub ef_construction: usize,
1031    /// ef_search parameter
1032    pub ef_search: usize,
1033    /// Whether index is built
1034    pub is_built: bool,
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Five 3-D fixture vectors: doc1/doc2 near the x axis, doc3 on y, doc4 on z,
    /// and doc5 between x and y.
    fn create_test_embeddings() -> HashMap<String, Vec<f32>> {
        let mut embeddings = HashMap::new();

        embeddings.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings.insert("doc2".to_string(), vec![0.9, 0.1, 0.0]);
        embeddings.insert("doc3".to_string(), vec![0.0, 1.0, 0.0]);
        embeddings.insert("doc4".to_string(), vec![0.0, 0.0, 1.0]);
        embeddings.insert("doc5".to_string(), vec![0.7, 0.7, 0.0]);

        embeddings
    }

    #[test]
    fn test_hnsw_config_default() {
        let config = HnswConfig::default();
        assert_eq!(config.m, 16);
        assert_eq!(config.m0, 32);
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef_search, 50);
    }

    #[test]
    fn test_hnsw_build() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());

        assert!(index.build(&embeddings).is_ok());
        assert!(index.is_built);

        let stats = index.get_stats();
        assert_eq!(stats.num_vectors, 5);
        assert_eq!(stats.dimensions, 3);
    }

    #[test]
    fn test_hnsw_search() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Search for vector similar to doc1
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        // doc1 or doc2 should be the closest
        assert!(results[0].entity_id == "doc1" || results[0].entity_id == "doc2");
    }

    #[test]
    fn test_hnsw_batch_search() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let results = index.batch_search(&queries, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert_eq!(results[1].len(), 2);
    }

    #[test]
    fn test_hnsw_incremental_add() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Add a new vector
        index.add("doc6", &[0.5, 0.5, 0.5]).unwrap();

        let stats = index.get_stats();
        assert_eq!(stats.num_vectors, 6);

        // Search should find the new vector
        let query = vec![0.5, 0.5, 0.5];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results[0].entity_id, "doc6");
    }

    #[test]
    fn test_hnsw_search_accuracy() {
        // Create a larger test set
        let mut embeddings = HashMap::new();
        for i in 0..100 {
            let angle = (i as f32) * 2.0 * std::f32::consts::PI / 100.0;
            embeddings.insert(format!("doc{}", i), vec![angle.cos(), angle.sin(), 0.0]);
        }

        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Search for a specific angle
        let query_angle = 0.5_f32;
        let query = vec![query_angle.cos(), query_angle.sin(), 0.0];
        let results = index.search(&query, 5).unwrap();

        // Should find nearby angles
        assert_eq!(results.len(), 5);
        // Top result should have high similarity
        assert!(results[0].score > 0.95);
    }

    #[test]
    fn test_hnsw_empty_error() {
        let embeddings: HashMap<String, Vec<f32>> = HashMap::new();
        let mut index = HnswIndex::new(HnswConfig::default());

        assert!(index.build(&embeddings).is_err());
    }

    #[test]
    fn test_hnsw_dimension_mismatch() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Query with wrong dimensions
        let query = vec![1.0, 0.0]; // 2D instead of 3D
        assert!(index.search(&query, 1).is_err());
    }

    #[test]
    fn test_hnsw_stats() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        let stats = index.get_stats();
        assert_eq!(stats.num_vectors, 5);
        assert_eq!(stats.dimensions, 3);
        assert_eq!(stats.m, 16);
        assert_eq!(stats.ef_construction, 200);
        assert!(stats.is_built);
    }

    #[test]
    fn test_ef_search_adjustment() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        index.set_ef_search(100);
        let stats = index.get_stats();
        assert_eq!(stats.ef_search, 100);
    }

    /// Metadata fixture matching `create_test_embeddings`:
    /// articles are doc1 (2023), doc2 (2022) and doc5 (2024);
    /// books are doc3 (2023) and doc4 (2021).
    fn create_test_metadata() -> HashMap<String, Metadata> {
        use crate::filter::FilterValue;

        let mut metadata = HashMap::new();

        let mut m1 = HashMap::new();
        m1.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        m1.insert("year".to_string(), FilterValue::Int(2023));
        metadata.insert("doc1".to_string(), m1);

        let mut m2 = HashMap::new();
        m2.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        m2.insert("year".to_string(), FilterValue::Int(2022));
        metadata.insert("doc2".to_string(), m2);

        let mut m3 = HashMap::new();
        m3.insert("type".to_string(), FilterValue::String("book".to_string()));
        m3.insert("year".to_string(), FilterValue::Int(2023));
        metadata.insert("doc3".to_string(), m3);

        let mut m4 = HashMap::new();
        m4.insert("type".to_string(), FilterValue::String("book".to_string()));
        m4.insert("year".to_string(), FilterValue::Int(2021));
        metadata.insert("doc4".to_string(), m4);

        let mut m5 = HashMap::new();
        m5.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        m5.insert("year".to_string(), FilterValue::Int(2024));
        metadata.insert("doc5".to_string(), m5);

        metadata
    }

    #[test]
    fn test_hnsw_set_and_get_metadata() {
        use crate::filter::FilterValue;

        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        let mut metadata = HashMap::new();
        metadata.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );

        index.set_metadata("doc1", metadata.clone());

        let retrieved = index.get_metadata("doc1");
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.unwrap().get("type"),
            Some(&FilterValue::String("article".to_string()))
        );
    }

    #[test]
    fn test_hnsw_filtered_search() {
        use crate::filter::FilterValue;

        let embeddings = create_test_embeddings();
        let metadata = create_test_metadata();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();
        index.set_metadata_batch(metadata);

        // Filter for articles only
        let filter = Filter::new().eq("type", "article");
        let query = vec![1.0, 0.0, 0.0];
        let results = index.filtered_search(&query, 5, &filter).unwrap();

        // Should only return articles (doc1, doc2, doc5)
        assert_eq!(results.len(), 3);
        for result in &results {
            let meta = index.get_metadata(&result.entity_id).unwrap();
            assert_eq!(
                meta.get("type"),
                Some(&FilterValue::String("article".to_string()))
            );
        }
    }

    #[test]
    fn test_hnsw_filtered_search_with_year() {
        let embeddings = create_test_embeddings();
        let metadata = create_test_metadata();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();
        index.set_metadata_batch(metadata);

        // Filter for year >= 2023
        let filter = Filter::new().gte("year", 2023i64);
        let query = vec![1.0, 0.0, 0.0];
        let results = index.filtered_search(&query, 5, &filter).unwrap();

        // Should return exactly doc1 (2023), doc3 (2023), doc5 (2024)
        assert_eq!(results.len(), 3);
        let mut ids: Vec<&str> = results.iter().map(|r| r.entity_id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec!["doc1", "doc3", "doc5"]);
    }

    #[test]
    fn test_hnsw_prefiltered_search() {
        use crate::filter::FilterValue;

        let embeddings = create_test_embeddings();
        let metadata = create_test_metadata();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();
        index.set_metadata_batch(metadata);

        // Filter for books only
        let filter = Filter::new().eq("type", "book");
        let query = vec![0.0, 1.0, 0.0]; // Similar to doc3 (book)
        let results = index.prefiltered_search(&query, 5, &filter).unwrap();

        // Should only return books (doc3, doc4)
        assert_eq!(results.len(), 2);
        for result in &results {
            let meta = index.get_metadata(&result.entity_id).unwrap();
            assert_eq!(
                meta.get("type"),
                Some(&FilterValue::String("book".to_string()))
            );
        }
    }

    #[test]
    fn test_hnsw_filtered_search_empty_filter() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Empty filter should return all results
        let filter = Filter::new();
        let query = vec![1.0, 0.0, 0.0];
        let results = index.filtered_search(&query, 3, &filter).unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hnsw_filtered_search_no_matches() {
        let embeddings = create_test_embeddings();
        let metadata = create_test_metadata();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();
        index.set_metadata_batch(metadata);

        // Filter for non-existent type
        let filter = Filter::new().eq("type", "journal");
        let query = vec![1.0, 0.0, 0.0];
        let results = index.filtered_search(&query, 5, &filter).unwrap();

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hnsw_lazy_delete() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        let stats_before = index.get_stats();
        assert_eq!(stats_before.num_vectors, 5);
        assert_eq!(stats_before.active_vectors, 5);
        assert_eq!(stats_before.deleted_vectors, 0);

        // Delete a vector
        assert!(index.remove("doc1"));
        assert!(index.is_deleted("doc1"));

        let stats_after = index.get_stats();
        assert_eq!(stats_after.num_vectors, 5); // Still in index
        assert_eq!(stats_after.active_vectors, 4);
        assert_eq!(stats_after.deleted_vectors, 1);

        // Deleted vector should not appear in search results
        let query = vec![1.0, 0.0, 0.0]; // doc1's vector
        let results = index.search(&query, 5).unwrap();

        // doc1 should not be in results
        for result in &results {
            assert_ne!(result.entity_id, "doc1");
        }
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_hnsw_delete_nonexistent() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Try to delete non-existent vector
        assert!(!index.remove("nonexistent"));
        assert!(!index.is_deleted("nonexistent"));
    }

    #[test]
    fn test_hnsw_delete_multiple() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        // Delete multiple vectors
        index.remove("doc1");
        index.remove("doc2");
        index.remove("doc3");

        let stats = index.get_stats();
        assert_eq!(stats.active_vectors, 2);
        assert_eq!(stats.deleted_vectors, 3);

        // Search should only return non-deleted vectors
        let query = vec![0.5, 0.5, 0.5];
        let results = index.search(&query, 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_hnsw_delete_and_active_count() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        assert_eq!(index.active_count(), 5);
        assert_eq!(index.deleted_count(), 0);

        index.remove("doc1");
        assert_eq!(index.active_count(), 4);
        assert_eq!(index.deleted_count(), 1);

        index.remove("doc2");
        assert_eq!(index.active_count(), 3);
        assert_eq!(index.deleted_count(), 2);
    }

    #[test]
    fn test_hnsw_maintenance_requires_build() {
        let mut index = HnswIndex::new(HnswConfig::default());

        // Both maintenance operations refuse to run on an unbuilt index.
        assert!(index.optimize_graph().is_err());
        assert!(index.compact().is_err());
    }

    #[test]
    fn test_hnsw_optimize_graph_after_delete() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        index.remove("doc1");
        index.optimize_graph().unwrap();

        // Pruning edges to the tombstone must not hide any live vector.
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.entity_id != "doc1"));
    }

    #[test]
    fn test_hnsw_compact_removes_tombstones() {
        let embeddings = create_test_embeddings();
        let metadata = create_test_metadata();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();
        index.set_metadata_batch(metadata);

        index.remove("doc1");
        index.compact().unwrap();

        // The rebuilt index only stores the four surviving vectors.
        assert_eq!(index.get_stats().num_vectors, 4);

        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 4);
        assert!(results.iter().all(|r| r.entity_id != "doc1"));

        // Metadata of surviving vectors is restored after the rebuild.
        assert!(index.get_metadata("doc2").is_some());
    }

    #[test]
    fn test_hnsw_compact_without_deletions_is_noop() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        index.compact().unwrap();

        assert_eq!(index.get_stats().num_vectors, 5);
        assert_eq!(index.active_count(), 5);
    }
}