// engine/hnsw.rs

1//! HNSW (Hierarchical Navigable Small World) Index Implementation
2//!
3//! HNSW is a graph-based approximate nearest neighbor search algorithm that provides
4//! excellent query performance with high recall. It builds a multi-layer graph where:
5//! - Each layer contains a subset of nodes from the layer below
6//! - Top layers enable fast coarse navigation
7//! - Bottom layers provide fine-grained search
8//!
9//! Key parameters:
10//! - M: Maximum number of connections per node (default: 16)
11//! - ef_construction: Search width during index building (default: 200)
12//! - ef_search: Search width during query (default: 50)
13
14use common::types::{DistanceMetric, VectorId};
15use parking_lot::RwLock;
16use rand::Rng;
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, HashMap, HashSet};
19
20use crate::distance::calculate_distance;
21
22/// Convert similarity score to distance (lower = closer)
23/// calculate_distance returns similarity: higher = more similar
24/// We need distance: lower = closer for HNSW graph traversal
25#[inline]
26fn similarity_to_distance(similarity: f32, metric: DistanceMetric) -> f32 {
27    match metric {
28        // Cosine: similarity in [-1, 1], distance = 1 - similarity, so 0 = identical
29        DistanceMetric::Cosine => 1.0 - similarity,
30        // Euclidean: returns negative distance, negate to get positive distance
31        DistanceMetric::Euclidean => -similarity,
32        // Dot product: higher = more similar, negate for distance
33        DistanceMetric::DotProduct => -similarity,
34    }
35}
36
37/// HNSW index configuration
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
39pub struct HnswConfig {
40    /// Maximum number of connections per node at each layer
41    pub m: usize,
42    /// Maximum connections at layer 0 (typically 2 * M)
43    pub m_max0: usize,
44    /// Search width during construction
45    pub ef_construction: usize,
46    /// Default search width during queries
47    pub ef_search: usize,
48    /// Level generation multiplier (1/ln(M))
49    pub level_multiplier: f64,
50    /// Distance metric to use
51    pub distance_metric: DistanceMetric,
52}
53
54impl Default for HnswConfig {
55    fn default() -> Self {
56        let m = 16;
57        Self {
58            m,
59            m_max0: m * 2,
60            ef_construction: 200,
61            ef_search: 50,
62            level_multiplier: 1.0 / (m as f64).ln(),
63            distance_metric: DistanceMetric::Cosine,
64        }
65    }
66}
67
68impl HnswConfig {
69    pub fn new(m: usize, ef_construction: usize, ef_search: usize) -> Self {
70        Self {
71            m,
72            m_max0: m * 2,
73            ef_construction,
74            ef_search,
75            level_multiplier: 1.0 / (m as f64).ln(),
76            distance_metric: DistanceMetric::Cosine,
77        }
78    }
79
80    pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
81        self.distance_metric = metric;
82        self
83    }
84}
85
/// A node in the HNSW graph.
///
/// Nodes are addressed by their position in the index's node vector; the
/// `usize` values stored in `connections` are such positions.
#[derive(Debug)]
struct HnswNode {
    /// The vector ID
    id: VectorId,
    /// The vector data. `delete` clears this, and an empty vector is what the
    /// rest of the index treats as "deleted".
    vector: Vec<f32>,
    /// Connections at each layer (layer -> neighbor IDs); index 0 is the
    /// bottom layer. Created with `max_layer + 1` entries on insert.
    connections: Vec<Vec<usize>>,
    /// Maximum layer this node exists in
    max_layer: usize,
}
98
/// Candidate for search with distance ordering.
///
/// Ordered so that a [`BinaryHeap`] of candidates pops the *closest* one
/// first. NaN distances are ranked as the furthest possible value, which
/// keeps the ordering total (a requirement of `Ord`) even when a distance
/// computation produced NaN.
#[derive(Debug, Clone)]
struct Candidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree
        // (plain `==` on floats is not reflexive for NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: smaller distance = higher priority, hence the reversed
        // operands. When either side is NaN, fall back to comparing the
        // "is NaN" flags so NaN sorts below every real distance.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or_else(|| other.distance.is_nan().cmp(&self.distance.is_nan()))
    }
}
129
/// Candidate for max-heap (furthest first).
///
/// Ordered so that a [`BinaryHeap`] pops the *furthest* entry first, which is
/// what a bounded result set needs when it evicts its worst member. NaN
/// distances rank as the furthest of all, so they are evicted first and the
/// ordering stays total.
#[derive(Debug, Clone)]
struct FurthestCandidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for FurthestCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree
        // (plain `==` on floats is not reflexive for NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for FurthestCandidate {}

impl PartialOrd for FurthestCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FurthestCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap: larger distance = higher priority. When either side is
        // NaN, compare the "is NaN" flags so NaN sorts above every real
        // distance.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}
159
/// HNSW Index for approximate nearest neighbor search.
///
/// Every piece of mutable state sits behind its own `RwLock`, so all methods
/// take `&self`.
pub struct HnswIndex {
    /// Tuning parameters; fixed for the lifetime of the index.
    config: HnswConfig,
    /// All nodes in the graph. Deleted nodes stay in place (with an emptied
    /// vector) so that node indices remain valid.
    nodes: RwLock<Vec<HnswNode>>,
    /// Entry point (node index) for searches
    entry_point: RwLock<Option<usize>>,
    /// Current maximum layer in the graph
    max_level: RwLock<usize>,
    /// Map from vector ID to node index
    id_to_idx: RwLock<HashMap<VectorId, usize>>,
    /// Vector dimension (set on first insert)
    dimension: RwLock<Option<usize>>,
}
174
175impl HnswIndex {
176    /// Create a new HNSW index with default configuration
177    pub fn new() -> Self {
178        Self::with_config(HnswConfig::default())
179    }
180
181    /// Create a new HNSW index with custom configuration
182    pub fn with_config(config: HnswConfig) -> Self {
183        Self {
184            config,
185            nodes: RwLock::new(Vec::new()),
186            entry_point: RwLock::new(None),
187            max_level: RwLock::new(0),
188            id_to_idx: RwLock::new(HashMap::new()),
189            dimension: RwLock::new(None),
190        }
191    }
192
193    /// Generate a random level for a new node using exponential distribution
194    fn random_level(&self) -> usize {
195        let mut rng = rand::thread_rng();
196        let uniform: f64 = rng.gen();
197
198        (-uniform.ln() * self.config.level_multiplier).floor() as usize
199    }
200
201    /// Compute distance between a query vector and a node
202    /// Converts similarity scores to distance (lower = closer)
203    fn distance(&self, query: &[f32], node_idx: usize, nodes: &[HnswNode]) -> f32 {
204        similarity_to_distance(
205            calculate_distance(query, &nodes[node_idx].vector, self.config.distance_metric),
206            self.config.distance_metric,
207        )
208    }
209
210    /// Search for nearest neighbors at a specific layer
211    fn search_layer(
212        &self,
213        query: &[f32],
214        entry_points: Vec<usize>,
215        ef: usize,
216        layer: usize,
217        nodes: &[HnswNode],
218    ) -> Vec<Candidate> {
219        let mut visited: HashSet<usize> = HashSet::new();
220        let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
221        let mut results: BinaryHeap<FurthestCandidate> = BinaryHeap::new();
222
223        // Initialize with entry points
224        for &ep in &entry_points {
225            visited.insert(ep);
226            let dist = self.distance(query, ep, nodes);
227            candidates.push(Candidate {
228                node_idx: ep,
229                distance: dist,
230            });
231            results.push(FurthestCandidate {
232                node_idx: ep,
233                distance: dist,
234            });
235        }
236
237        while let Some(candidate) = candidates.pop() {
238            // Get furthest result
239            let furthest_dist = results.peek().map(|r| r.distance).unwrap_or(f32::MAX);
240
241            // Stop if current candidate is further than the furthest result
242            if candidate.distance > furthest_dist && results.len() >= ef {
243                break;
244            }
245
246            // Explore neighbors at this layer
247            let node = &nodes[candidate.node_idx];
248            if layer < node.connections.len() {
249                for &neighbor_idx in &node.connections[layer] {
250                    if visited.insert(neighbor_idx) {
251                        let dist = self.distance(query, neighbor_idx, nodes);
252
253                        let should_add = results.len() < ef
254                            || dist < results.peek().map(|r| r.distance).unwrap_or(f32::MAX);
255
256                        if should_add {
257                            candidates.push(Candidate {
258                                node_idx: neighbor_idx,
259                                distance: dist,
260                            });
261                            results.push(FurthestCandidate {
262                                node_idx: neighbor_idx,
263                                distance: dist,
264                            });
265
266                            // Keep only top ef results
267                            while results.len() > ef {
268                                results.pop();
269                            }
270                        }
271                    }
272                }
273            }
274        }
275
276        // Convert results to sorted candidates
277        let mut final_results: Vec<Candidate> = results
278            .into_iter()
279            .map(|fc| Candidate {
280                node_idx: fc.node_idx,
281                distance: fc.distance,
282            })
283            .collect();
284        final_results.sort_by(|a, b| {
285            a.distance
286                .partial_cmp(&b.distance)
287                .unwrap_or(Ordering::Equal)
288        });
289        final_results
290    }
291
292    /// Select neighbors using simple heuristic (keep M closest)
293    fn select_neighbors_simple(&self, candidates: &[Candidate], m: usize) -> Vec<usize> {
294        candidates.iter().take(m).map(|c| c.node_idx).collect()
295    }
296
297    /// Select neighbors using heuristic that considers diversity
298    fn select_neighbors_heuristic(
299        &self,
300        query: &[f32],
301        candidates: &[Candidate],
302        m: usize,
303        nodes: &[HnswNode],
304        extend_candidates: bool,
305    ) -> Vec<usize> {
306        let mut working_candidates = candidates.to_vec();
307
308        // Optionally extend with neighbors of candidates
309        if extend_candidates {
310            let mut extended: HashSet<usize> =
311                working_candidates.iter().map(|c| c.node_idx).collect();
312            for candidate in candidates.iter().take(m) {
313                let node = &nodes[candidate.node_idx];
314                for layer_connections in &node.connections {
315                    for &neighbor in layer_connections {
316                        if extended.insert(neighbor) {
317                            let dist = self.distance(query, neighbor, nodes);
318                            working_candidates.push(Candidate {
319                                node_idx: neighbor,
320                                distance: dist,
321                            });
322                        }
323                    }
324                }
325            }
326            working_candidates.sort_by(|a, b| {
327                a.distance
328                    .partial_cmp(&b.distance)
329                    .unwrap_or(Ordering::Equal)
330            });
331        }
332
333        // Use heuristic: prefer diverse neighbors
334        let mut selected: Vec<usize> = Vec::with_capacity(m);
335
336        for candidate in &working_candidates {
337            if selected.len() >= m {
338                break;
339            }
340
341            // Check if this candidate is closer to query than to any selected neighbor
342            let mut is_good = true;
343            for &sel_idx in &selected {
344                let dist_to_selected = calculate_distance(
345                    &nodes[candidate.node_idx].vector,
346                    &nodes[sel_idx].vector,
347                    self.config.distance_metric,
348                );
349                if dist_to_selected < candidate.distance {
350                    is_good = false;
351                    break;
352                }
353            }
354
355            if is_good {
356                selected.push(candidate.node_idx);
357            }
358        }
359
360        // Fill remaining slots with closest candidates if needed
361        if selected.len() < m {
362            for candidate in &working_candidates {
363                if selected.len() >= m {
364                    break;
365                }
366                if !selected.contains(&candidate.node_idx) {
367                    selected.push(candidate.node_idx);
368                }
369            }
370        }
371
372        selected
373    }
374
375    /// Add a bidirectional connection between two nodes at a specific layer
376    fn add_connection(&self, from_idx: usize, to_idx: usize, layer: usize, nodes: &mut [HnswNode]) {
377        let m_max = if layer == 0 {
378            self.config.m_max0
379        } else {
380            self.config.m
381        };
382
383        // Add connection from -> to
384        if layer < nodes[from_idx].connections.len()
385            && !nodes[from_idx].connections[layer].contains(&to_idx)
386        {
387            nodes[from_idx].connections[layer].push(to_idx);
388
389            // Prune if necessary
390            if nodes[from_idx].connections[layer].len() > m_max {
391                let conn_indices: Vec<usize> = nodes[from_idx].connections[layer].clone();
392                let mut sorted_candidates: Vec<Candidate> = conn_indices
393                    .iter()
394                    .map(|&idx| Candidate {
395                        node_idx: idx,
396                        distance: self.distance(&nodes[from_idx].vector, idx, nodes),
397                    })
398                    .collect();
399                sorted_candidates.sort_by(|a, b| {
400                    a.distance
401                        .partial_cmp(&b.distance)
402                        .unwrap_or(Ordering::Equal)
403                });
404                nodes[from_idx].connections[layer] =
405                    self.select_neighbors_simple(&sorted_candidates, m_max);
406            }
407        }
408
409        // Add connection to -> from
410        if layer < nodes[to_idx].connections.len()
411            && !nodes[to_idx].connections[layer].contains(&from_idx)
412        {
413            nodes[to_idx].connections[layer].push(from_idx);
414
415            // Prune if necessary
416            if nodes[to_idx].connections[layer].len() > m_max {
417                let conn_indices: Vec<usize> = nodes[to_idx].connections[layer].clone();
418                let mut sorted_candidates: Vec<Candidate> = conn_indices
419                    .iter()
420                    .map(|&idx| Candidate {
421                        node_idx: idx,
422                        distance: self.distance(&nodes[to_idx].vector, idx, nodes),
423                    })
424                    .collect();
425                sorted_candidates.sort_by(|a, b| {
426                    a.distance
427                        .partial_cmp(&b.distance)
428                        .unwrap_or(Ordering::Equal)
429                });
430                nodes[to_idx].connections[layer] =
431                    self.select_neighbors_simple(&sorted_candidates, m_max);
432            }
433        }
434    }
435
436    /// Insert a vector into the index
437    pub fn insert(&self, id: VectorId, vector: Vec<f32>) {
438        let vector_dim = vector.len();
439
440        // Check/set dimension
441        {
442            let mut dim = self.dimension.write();
443            if let Some(d) = *dim {
444                if d != vector_dim {
445                    tracing::error!("Dimension mismatch: expected {}, got {}", d, vector_dim);
446                    return;
447                }
448            } else {
449                *dim = Some(vector_dim);
450            }
451        }
452
453        let new_level = self.random_level();
454
455        // Create the new node
456        let new_node = HnswNode {
457            id: id.clone(),
458            vector: vector.clone(),
459            connections: (0..=new_level).map(|_| Vec::new()).collect(),
460            max_layer: new_level,
461        };
462
463        let mut nodes = self.nodes.write();
464        let new_idx = nodes.len();
465        nodes.push(new_node);
466
467        // Update ID mapping
468        self.id_to_idx.write().insert(id, new_idx);
469
470        // Handle first node
471        let entry = *self.entry_point.read();
472        let entry_idx = match entry {
473            None => {
474                *self.entry_point.write() = Some(new_idx);
475                *self.max_level.write() = new_level;
476                return;
477            }
478            Some(idx) => idx,
479        };
480        let current_max_level = *self.max_level.read();
481
482        // Find entry point at the top layer
483        let mut current_entry = vec![entry_idx];
484
485        // Descend from top layer to the layer above new node's max layer
486        for layer in (new_level + 1..=current_max_level).rev() {
487            let nearest = self.search_layer(&vector, current_entry.clone(), 1, layer, &nodes);
488            if !nearest.is_empty() {
489                current_entry = vec![nearest[0].node_idx];
490            }
491        }
492
493        // For each layer from new_level down to 0, find and connect to neighbors
494        for layer in (0..=new_level.min(current_max_level)).rev() {
495            let candidates = self.search_layer(
496                &vector,
497                current_entry.clone(),
498                self.config.ef_construction,
499                layer,
500                &nodes,
501            );
502
503            let m = if layer == 0 {
504                self.config.m_max0
505            } else {
506                self.config.m
507            };
508
509            let neighbors = self.select_neighbors_heuristic(&vector, &candidates, m, &nodes, false);
510
511            // Connect new node to selected neighbors
512            for &neighbor_idx in &neighbors {
513                self.add_connection(new_idx, neighbor_idx, layer, &mut nodes);
514            }
515
516            // Update entry points for next layer
517            if !candidates.is_empty() {
518                current_entry = candidates.iter().take(1).map(|c| c.node_idx).collect();
519            }
520        }
521
522        // Update entry point if new node has higher level
523        if new_level > current_max_level {
524            *self.entry_point.write() = Some(new_idx);
525            *self.max_level.write() = new_level;
526        }
527    }
528
529    /// Search for k nearest neighbors
530    pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
531        self.search_with_ef(query, k, self.config.ef_search)
532    }
533
534    /// Search with custom ef parameter
535    pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(VectorId, f32)> {
536        let nodes = self.nodes.read();
537
538        if nodes.is_empty() {
539            return Vec::new();
540        }
541
542        let entry = *self.entry_point.read();
543        let entry_idx = match entry {
544            None => return Vec::new(),
545            Some(idx) => idx,
546        };
547        let max_level = *self.max_level.read();
548
549        // Start at entry point
550        let mut current_entry = vec![entry_idx];
551
552        // Descend through layers greedily (ef=1) until layer 0
553        for layer in (1..=max_level).rev() {
554            let nearest = self.search_layer(query, current_entry.clone(), 1, layer, &nodes);
555            if !nearest.is_empty() {
556                current_entry = vec![nearest[0].node_idx];
557            }
558        }
559
560        // Search at layer 0 with full ef
561        let candidates = self.search_layer(query, current_entry, ef.max(k), 0, &nodes);
562
563        // Return top k results
564        candidates
565            .into_iter()
566            .take(k)
567            .map(|c| (nodes[c.node_idx].id.clone(), c.distance))
568            .collect()
569    }
570
571    /// Delete a vector from the index
572    pub fn delete(&self, id: &VectorId) -> bool {
573        let idx = {
574            let id_map = self.id_to_idx.read();
575            match id_map.get(id) {
576                Some(&idx) => idx,
577                None => return false,
578            }
579        };
580
581        let mut nodes = self.nodes.write();
582        let mut id_map = self.id_to_idx.write();
583
584        // Remove all connections to this node
585        for layer in 0..nodes[idx].connections.len() {
586            let neighbors: Vec<usize> = nodes[idx].connections[layer].clone();
587            for neighbor_idx in neighbors {
588                if neighbor_idx < nodes.len() && layer < nodes[neighbor_idx].connections.len() {
589                    nodes[neighbor_idx].connections[layer].retain(|&n| n != idx);
590                }
591            }
592        }
593
594        // Mark node as deleted (we don't actually remove to preserve indices)
595        nodes[idx].connections.clear();
596        nodes[idx].vector.clear();
597        id_map.remove(id);
598
599        // Update entry point if necessary
600        let entry = *self.entry_point.read();
601        if entry == Some(idx) {
602            // Find a new entry point
603            let new_entry = nodes
604                .iter()
605                .enumerate()
606                .filter(|(_, n)| !n.vector.is_empty())
607                .max_by_key(|(_, n)| n.max_layer)
608                .map(|(i, _)| i);
609            *self.entry_point.write() = new_entry;
610        }
611
612        true
613    }
614
615    /// Get the number of vectors in the index
616    pub fn len(&self) -> usize {
617        self.id_to_idx.read().len()
618    }
619
620    /// Check if the index is empty
621    pub fn is_empty(&self) -> bool {
622        self.len() == 0
623    }
624
625    /// Get index statistics
626    pub fn stats(&self) -> HnswStats {
627        let nodes = self.nodes.read();
628        let max_level = *self.max_level.read();
629
630        let mut level_counts = vec![0usize; max_level + 1];
631        let mut total_connections = 0usize;
632
633        for node in nodes.iter() {
634            if !node.vector.is_empty() {
635                for (layer, connections) in node.connections.iter().enumerate() {
636                    if layer <= max_level {
637                        level_counts[layer] += 1;
638                        total_connections += connections.len();
639                    }
640                }
641            }
642        }
643
644        HnswStats {
645            num_vectors: self.len(),
646            max_level,
647            level_counts,
648            total_connections,
649            avg_connections: if !self.is_empty() {
650                total_connections as f64 / self.len() as f64
651            } else {
652                0.0
653            },
654        }
655    }
656
657    /// Get configuration
658    pub fn config(&self) -> &HnswConfig {
659        &self.config
660    }
661
662    /// Get dimension
663    pub fn dimension(&self) -> Option<usize> {
664        *self.dimension.read()
665    }
666
667    /// Get entry point node index
668    pub fn entry_point(&self) -> Option<usize> {
669        *self.entry_point.read()
670    }
671
672    /// Get maximum level in the graph
673    pub fn max_level(&self) -> usize {
674        *self.max_level.read()
675    }
676
677    /// Get read access to nodes for persistence
678    /// Returns a vector of tuples: (id, vector, connections, max_layer)
679    pub(crate) fn nodes_read(&self) -> Vec<NodeSnapshot> {
680        self.nodes
681            .read()
682            .iter()
683            .map(|node| NodeSnapshot {
684                id: node.id.clone(),
685                vector: node.vector.clone(),
686                connections: node.connections.clone(),
687                max_layer: node.max_layer,
688            })
689            .collect()
690    }
691
692    /// Restore HNSW index from a full snapshot
693    pub fn from_snapshot(snapshot: crate::persistence::HnswFullSnapshot) -> Result<Self, String> {
694        use std::collections::HashMap;
695
696        let mut nodes = Vec::with_capacity(snapshot.nodes.len());
697        let mut id_to_idx = HashMap::with_capacity(snapshot.nodes.len());
698
699        for (idx, snode) in snapshot.nodes.into_iter().enumerate() {
700            id_to_idx.insert(snode.id.clone(), idx);
701            nodes.push(HnswNode {
702                id: snode.id,
703                vector: snode.vector,
704                connections: snode.connections,
705                max_layer: snode.max_layer,
706            });
707        }
708
709        let dimension = if nodes.is_empty() {
710            None
711        } else {
712            Some(snapshot.dimension)
713        };
714
715        Ok(Self {
716            config: snapshot.config,
717            nodes: RwLock::new(nodes),
718            entry_point: RwLock::new(snapshot.entry_point),
719            max_level: RwLock::new(snapshot.max_level),
720            id_to_idx: RwLock::new(id_to_idx),
721            dimension: RwLock::new(dimension),
722        })
723    }
724}
725
/// Snapshot of a node for persistence.
///
/// An owned copy of one graph node's fields, as produced by
/// `HnswIndex::nodes_read`.
#[derive(Debug, Clone)]
pub(crate) struct NodeSnapshot {
    /// The vector ID of the node.
    pub id: String,
    /// The vector data; `nodes_read` copies it as stored, so it is empty for
    /// a node that was deleted.
    pub vector: Vec<f32>,
    /// Neighbor node indices, one list per layer.
    pub connections: Vec<Vec<usize>>,
    /// Maximum layer the node exists in.
    pub max_layer: usize,
}
734
735impl Default for HnswIndex {
736    fn default() -> Self {
737        Self::new()
738    }
739}
740
/// Statistics about the HNSW index, as reported by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of vectors registered in the index.
    pub num_vectors: usize,
    /// Highest layer recorded for the graph.
    pub max_level: usize,
    /// Number of non-deleted nodes present on each layer, indexed by layer.
    pub level_counts: Vec<usize>,
    /// Sum of the connection-list lengths over all counted layers.
    pub total_connections: usize,
    /// `total_connections` divided by `num_vectors` (0.0 when empty).
    pub avg_connections: f64,
}
750
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a vector of `dim` uniformly random components.
    fn random_vector(dim: usize) -> Vec<f32> {
        let mut rng = rand::thread_rng();
        (0..dim).map(|_| rng.gen::<f32>()).collect()
    }

    /// Scale `v` to unit length in place; an all-zero vector is left untouched.
    fn normalize(v: &mut Vec<f32>) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// Inserting 100 vectors and searching returns k results in ascending distance order.
    #[test]
    fn test_hnsw_basic_operations() {
        let index = HnswIndex::new();

        // Insert vectors
        for i in 0..100 {
            let mut vec = random_vector(128);
            normalize(&mut vec);
            index.insert(format!("vec_{}", i), vec);
        }

        assert_eq!(index.len(), 100);
        assert!(!index.is_empty());

        // Search
        let mut query = random_vector(128);
        normalize(&mut query);
        let results = index.search(&query, 10);

        assert_eq!(results.len(), 10);

        // Results should be sorted by distance
        for i in 1..results.len() {
            assert!(results[i - 1].1 <= results[i].1);
        }
    }

    /// Deleting an existing ID shrinks the index; deleting an unknown ID reports `false`.
    #[test]
    fn test_hnsw_delete() {
        let index = HnswIndex::new();

        for i in 0..10 {
            let mut vec = random_vector(64);
            normalize(&mut vec);
            index.insert(format!("vec_{}", i), vec);
        }

        assert_eq!(index.len(), 10);

        // Delete a vector
        assert!(index.delete(&"vec_5".to_string()));
        assert_eq!(index.len(), 9);

        // Delete non-existent vector
        assert!(!index.delete(&"vec_999".to_string()));
    }

    /// Average recall@10 against a brute-force scan must reach 80%.
    #[test]
    fn test_hnsw_recall() {
        let dim = 128;
        let n_vectors = 1000;
        let index = HnswIndex::with_config(HnswConfig::new(16, 200, 100));

        // Insert vectors, keeping a copy for the exact comparison below
        let mut vectors: Vec<(VectorId, Vec<f32>)> = Vec::new();
        for i in 0..n_vectors {
            let mut vec = random_vector(dim);
            normalize(&mut vec);
            let id: VectorId = format!("vec_{}", i);
            vectors.push((id.clone(), vec.clone()));
            index.insert(id, vec);
        }

        // Test recall with random queries
        let n_queries = 10;
        let k = 10;
        let mut total_recall = 0.0;

        for _ in 0..n_queries {
            let mut query = random_vector(dim);
            normalize(&mut query);

            // Get HNSW results
            let hnsw_results: HashSet<String> = index
                .search(&query, k)
                .into_iter()
                .map(|(id, _)| id)
                .collect();

            // Compute exact nearest neighbors (using distance, lower = closer)
            let mut exact: Vec<(String, f32)> = vectors
                .iter()
                .map(|(id, vec)| {
                    let sim = calculate_distance(&query, vec, DistanceMetric::Cosine);
                    (
                        id.clone(),
                        similarity_to_distance(sim, DistanceMetric::Cosine),
                    )
                })
                .collect();
            exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            let exact_results: HashSet<String> =
                exact.into_iter().take(k).map(|(id, _)| id).collect();

            // Recall = fraction of the exact top-k that HNSW also returned
            let overlap = hnsw_results.intersection(&exact_results).count();
            total_recall += overlap as f64 / k as f64;
        }

        let avg_recall = total_recall / n_queries as f64;
        println!("Average recall@{}: {:.2}%", k, avg_recall * 100.0);

        // HNSW should achieve at least 80% recall with this configuration
        // (m = 16, ef_construction = 200, ef_search = 100)
        assert!(
            avg_recall >= 0.80,
            "Recall too low: {:.2}%",
            avg_recall * 100.0
        );
    }

    /// `stats` reports the vector count and a positive average connection count.
    #[test]
    fn test_hnsw_stats() {
        let index = HnswIndex::new();

        for i in 0..50 {
            let mut vec = random_vector(64);
            normalize(&mut vec);
            index.insert(format!("vec_{}", i), vec);
        }

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 50);
        // max_level is usize, always >= 0, just verify it's accessible
        let _ = stats.max_level;
        assert!(stats.avg_connections > 0.0);

        println!("HNSW Stats: {:?}", stats);
    }

    /// Searching with a low and a high `ef` both return the requested number of results.
    #[test]
    fn test_hnsw_custom_ef() {
        let index = HnswIndex::new();

        for i in 0..100 {
            let mut vec = random_vector(64);
            normalize(&mut vec);
            index.insert(format!("vec_{}", i), vec);
        }

        let mut query = random_vector(64);
        normalize(&mut query);

        // Search with different ef values
        let results_low_ef = index.search_with_ef(&query, 10, 10);
        let results_high_ef = index.search_with_ef(&query, 10, 200);

        assert_eq!(results_low_ef.len(), 10);
        assert_eq!(results_high_ef.len(), 10);

        // Higher ef might find better results (lower distances)
        // At minimum, both should return valid results
    }

    /// Searching an empty index yields no results.
    #[test]
    fn test_hnsw_empty_search() {
        let index = HnswIndex::new();
        let query = random_vector(64);
        let results = index.search(&query, 10);
        assert!(results.is_empty());
    }

    /// With a single stored vector, querying with that vector returns it at near-zero distance.
    #[test]
    fn test_hnsw_single_vector() {
        let index = HnswIndex::new();

        let mut vec = random_vector(64);
        normalize(&mut vec);
        index.insert("single".to_string(), vec.clone());

        let results = index.search(&vec, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "single".to_string());
        // Distance to self should be very small (cosine distance = 1 - similarity)
        assert!(
            results[0].1.abs() < 0.1,
            "Distance to self was {}",
            results[0].1
        );
    }
}