rag_plusplus_core/index/
hnsw.rs

1//! HNSW Index (Hierarchical Navigable Small World)
2//!
3//! Approximate nearest neighbor search with logarithmic query time.
4//! Production default for RAG++.
5//!
6//! # Performance Characteristics
7//!
8//! - Build: O(n log n)
9//! - Query: O(log n) with ef_search
10//! - Memory: O(n * M) where M is connectivity
11//!
12//! # Configuration
13//!
14//! - `m`: Connections per layer (higher = better recall, more memory)
15//! - `ef_construction`: Build-time search depth (higher = better quality)
16//! - `ef_search`: Query-time search depth (higher = better recall, slower)
17
18use ahash::{AHashMap, AHashSet};
19use ordered_float::OrderedFloat;
20use rand::Rng;
21use std::cmp::Reverse;
22use std::collections::BinaryHeap;
23
24use crate::error::{Error, Result};
25use super::traits::{DistanceType, IndexConfig, SearchResult, VectorIndex};
26
/// HNSW configuration parameters.
///
/// Construct with [`HNSWConfig::new`] and adjust through the `with_*` builder
/// methods. All fields are public, so they can also be set directly; note that
/// `m_max0` and `ml` are only kept in sync with `m` when `with_m` is used.
#[derive(Debug, Clone)]
pub struct HNSWConfig {
    /// Base index configuration (supplies the vector dimension and distance type).
    pub base: IndexConfig,
    /// Number of connections per node on layers above 0 (default: 16).
    pub m: usize,
    /// Maximum connections per node on layer 0 (default: 2 * m).
    pub m_max0: usize,
    /// Construction-time search depth: size of the candidate list explored
    /// per layer while inserting a vector (default: 200).
    pub ef_construction: usize,
    /// Default query-time search depth used for the layer-0 search (default: 128).
    pub ef_search: usize,
    /// Level multiplier controlling how node levels are drawn at insertion
    /// (default: 1 / ln(m)).
    pub ml: f64,
}
43
44impl HNSWConfig {
45    /// Create default HNSW config for given dimension.
46    #[must_use]
47    pub fn new(dimension: usize) -> Self {
48        let m = 16;
49        Self {
50            base: IndexConfig::new(dimension),
51            m,
52            m_max0: 2 * m,
53            ef_construction: 200,
54            ef_search: 128,
55            ml: 1.0 / (m as f64).ln(),
56        }
57    }
58
59    /// Set M parameter.
60    #[must_use]
61    pub fn with_m(mut self, m: usize) -> Self {
62        self.m = m;
63        self.m_max0 = 2 * m;
64        self.ml = 1.0 / (m as f64).ln();
65        self
66    }
67
68    /// Set ef_construction.
69    #[must_use]
70    pub const fn with_ef_construction(mut self, ef: usize) -> Self {
71        self.ef_construction = ef;
72        self
73    }
74
75    /// Set ef_search.
76    #[must_use]
77    pub const fn with_ef_search(mut self, ef: usize) -> Self {
78        self.ef_search = ef;
79        self
80    }
81
82    /// Set distance type.
83    #[must_use]
84    pub fn with_distance(mut self, distance_type: DistanceType) -> Self {
85        self.base.distance_type = distance_type;
86        self
87    }
88}
89
/// Node in the HNSW graph.
///
/// Nodes are stored in `HNSWIndex::nodes` and refer to each other by their
/// position in that vector.
#[derive(Debug, Clone)]
struct HNSWNode {
    /// External ID supplied by the caller; reset to an empty string when the
    /// node is removed.
    id: String,
    /// Vector data; cleared when the node is removed.
    vector: Vec<f32>,
    /// Highest layer this node was assigned at insertion.
    /// NOTE(review): the field is written at insertion but not read by the
    /// code visible in this file, hence the `dead_code` allowance — the
    /// effective level is `neighbors.len() - 1`.
    #[allow(dead_code)]
    level: usize,
    /// Neighbors per layer: `neighbors[l]` holds the indices (into the
    /// index's node vector) of this node's neighbors on layer `l`.
    neighbors: Vec<AHashSet<usize>>,
}
103
/// HNSW Index implementation.
///
/// Approximate nearest neighbor index over a layered proximity graph.
/// Mutating operations take `&mut self`; the only interior mutability is the
/// level-generation RNG, which sits behind a mutex.
#[derive(Debug)]
pub struct HNSWIndex {
    /// Configuration
    config: HNSWConfig,
    /// All nodes, addressed by position. Neighbor sets and `id_to_idx` store
    /// positions in this vector.
    nodes: Vec<HNSWNode>,
    /// ID to index mapping for the vectors currently in the index
    id_to_idx: AHashMap<String, usize>,
    /// Entry point (index of the node searches start from); `None` while the
    /// index is empty
    entry_point: Option<usize>,
    /// Maximum level in the graph
    max_level: usize,
    /// RNG for level generation
    rng: parking_lot::Mutex<rand::rngs::SmallRng>,
}
122
123impl HNSWIndex {
124    /// Create a new HNSW index.
125    #[must_use]
126    pub fn new(config: HNSWConfig) -> Self {
127        use rand::SeedableRng;
128        Self {
129            config,
130            nodes: Vec::new(),
131            id_to_idx: AHashMap::new(),
132            entry_point: None,
133            max_level: 0,
134            rng: parking_lot::Mutex::new(rand::rngs::SmallRng::from_entropy()),
135        }
136    }
137
138    /// Generate random level for new node.
139    fn random_level(&self) -> usize {
140        let mut rng = self.rng.lock();
141        let mut level = 0;
142        while rng.gen::<f64>() < self.config.ml && level < 16 {
143            level += 1;
144        }
145        level
146    }
147
148    /// Compute distance between two vectors.
149    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
150        match self.config.base.distance_type {
151            DistanceType::L2 => {
152                a.iter()
153                    .zip(b.iter())
154                    .map(|(x, y)| (x - y).powi(2))
155                    .sum::<f32>()
156                    .sqrt()
157            }
158            DistanceType::InnerProduct => {
159                // Negative for min-heap compatibility
160                -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
161            }
162            DistanceType::Cosine => {
163                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
164                let norm_a: f32 = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
165                let norm_b: f32 = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
166                if norm_a == 0.0 || norm_b == 0.0 {
167                    1.0
168                } else {
169                    1.0 - (dot / (norm_a * norm_b))
170                }
171            }
172        }
173    }
174
175    /// Search layer for nearest neighbors.
176    fn search_layer(
177        &self,
178        query: &[f32],
179        entry_points: Vec<usize>,
180        ef: usize,
181        level: usize,
182    ) -> Vec<(f32, usize)> {
183        let mut visited: AHashSet<usize> = entry_points.iter().copied().collect();
184        
185        // Min-heap for candidates (distance, idx)
186        let mut candidates: BinaryHeap<Reverse<(OrderedFloat<f32>, usize)>> = BinaryHeap::new();
187        
188        // Max-heap for results (distance, idx)  
189        let mut results: BinaryHeap<(OrderedFloat<f32>, usize)> = BinaryHeap::new();
190
191        // Initialize with entry points
192        for &ep in &entry_points {
193            let dist = self.distance(query, &self.nodes[ep].vector);
194            candidates.push(Reverse((OrderedFloat(dist), ep)));
195            results.push((OrderedFloat(dist), ep));
196        }
197
198        while let Some(Reverse((OrderedFloat(c_dist), c_idx))) = candidates.pop() {
199            // Get furthest in results
200            let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
201            
202            if c_dist > f_dist && results.len() >= ef {
203                break;
204            }
205
206            // Explore neighbors
207            if level < self.nodes[c_idx].neighbors.len() {
208                for &neighbor_idx in &self.nodes[c_idx].neighbors[level] {
209                    if visited.insert(neighbor_idx) {
210                        let dist = self.distance(query, &self.nodes[neighbor_idx].vector);
211                        let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f32::INFINITY);
212                        
213                        if dist < f_dist || results.len() < ef {
214                            candidates.push(Reverse((OrderedFloat(dist), neighbor_idx)));
215                            results.push((OrderedFloat(dist), neighbor_idx));
216                            
217                            if results.len() > ef {
218                                results.pop();
219                            }
220                        }
221                    }
222                }
223            }
224        }
225
226        // Return sorted results
227        let mut result_vec: Vec<_> = results.into_iter().map(|(d, idx)| (d.0, idx)).collect();
228        result_vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
229        result_vec
230    }
231
232    /// Select neighbors using simple heuristic.
233    fn select_neighbors(&self, candidates: &[(f32, usize)], m: usize) -> Vec<usize> {
234        candidates.iter().take(m).map(|(_, idx)| *idx).collect()
235    }
236
237    /// Get max connections for a level.
238    fn get_max_connections(&self, level: usize) -> usize {
239        if level == 0 {
240            self.config.m_max0
241        } else {
242            self.config.m
243        }
244    }
245}
246
247impl VectorIndex for HNSWIndex {
248    fn add(&mut self, id: String, vector: &[f32]) -> Result<()> {
249        if vector.len() != self.config.base.dimension {
250            return Err(Error::InvalidQuery {
251                reason: format!(
252                    "Dimension mismatch: expected {}, got {}",
253                    self.config.base.dimension,
254                    vector.len()
255                ),
256            });
257        }
258
259        if self.id_to_idx.contains_key(&id) {
260            return Err(Error::DuplicateRecord { record_id: id });
261        }
262
263        let level = self.random_level();
264        let new_idx = self.nodes.len();
265
266        // Create new node
267        let mut node = HNSWNode {
268            id: id.clone(),
269            vector: vector.to_vec(),
270            level,
271            neighbors: vec![AHashSet::new(); level + 1],
272        };
273
274        // First node
275        if self.entry_point.is_none() {
276            self.nodes.push(node);
277            self.id_to_idx.insert(id, new_idx);
278            self.entry_point = Some(new_idx);
279            self.max_level = level;
280            return Ok(());
281        }
282
283        let entry_point = self.entry_point.unwrap();
284        let mut curr_ep = vec![entry_point];
285
286        // Traverse from top to insertion level + 1
287        for lc in (level + 1..=self.max_level).rev() {
288            let nearest = self.search_layer(vector, curr_ep.clone(), 1, lc);
289            if !nearest.is_empty() {
290                curr_ep = vec![nearest[0].1];
291            }
292        }
293
294        // Insert at each level from level down to 0
295        for lc in (0..=level.min(self.max_level)).rev() {
296            let candidates = self.search_layer(
297                vector,
298                curr_ep.clone(),
299                self.config.ef_construction,
300                lc,
301            );
302
303            let m = self.get_max_connections(lc);
304            let neighbors = self.select_neighbors(&candidates, m);
305
306            // Add bidirectional connections
307            node.neighbors[lc] = neighbors.iter().copied().collect();
308
309            for &neighbor_idx in &neighbors {
310                if lc < self.nodes[neighbor_idx].neighbors.len() {
311                    self.nodes[neighbor_idx].neighbors[lc].insert(new_idx);
312
313                    // Prune if too many connections
314                    if self.nodes[neighbor_idx].neighbors[lc].len() > m {
315                        let neighbor_vec = &self.nodes[neighbor_idx].vector;
316                        // Exclude new_idx from distance calc since node not pushed yet
317                        let new_node_vec = vector;
318                        let mut scored: Vec<_> = self.nodes[neighbor_idx].neighbors[lc]
319                            .iter()
320                            .map(|&idx| {
321                                let dist = if idx == new_idx {
322                                    self.distance(neighbor_vec, new_node_vec)
323                                } else {
324                                    self.distance(neighbor_vec, &self.nodes[idx].vector)
325                                };
326                                (dist, idx)
327                            })
328                            .collect();
329                        scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
330                        self.nodes[neighbor_idx].neighbors[lc] =
331                            scored.into_iter().take(m).map(|(_, idx)| idx).collect();
332                    }
333                }
334            }
335
336            if !candidates.is_empty() {
337                curr_ep = vec![candidates[0].1];
338            }
339        }
340
341        self.nodes.push(node);
342        self.id_to_idx.insert(id, new_idx);
343
344        // Update entry point if new node has higher level
345        if level > self.max_level {
346            self.entry_point = Some(new_idx);
347            self.max_level = level;
348        }
349
350        Ok(())
351    }
352
353    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
354        if query.len() != self.config.base.dimension {
355            return Err(Error::InvalidQuery {
356                reason: format!(
357                    "Query dimension mismatch: expected {}, got {}",
358                    self.config.base.dimension,
359                    query.len()
360                ),
361            });
362        }
363
364        if self.nodes.is_empty() {
365            return Ok(vec![]);
366        }
367
368        let entry_point = self.entry_point.unwrap();
369        let mut curr_ep = vec![entry_point];
370
371        // Traverse from top to level 1
372        for lc in (1..=self.max_level).rev() {
373            let nearest = self.search_layer(query, curr_ep.clone(), 1, lc);
374            if !nearest.is_empty() {
375                curr_ep = vec![nearest[0].1];
376            }
377        }
378
379        // Search at level 0 with ef_search
380        let results = self.search_layer(query, curr_ep, self.config.ef_search, 0);
381
382        // Convert to SearchResult
383        let k = k.min(results.len());
384        Ok(results
385            .into_iter()
386            .take(k)
387            .map(|(dist, idx)| {
388                let actual_dist = match self.config.base.distance_type {
389                    DistanceType::InnerProduct => -dist,
390                    DistanceType::Cosine => 1.0 - dist,
391                    DistanceType::L2 => dist,
392                };
393                SearchResult::new(
394                    self.nodes[idx].id.clone(),
395                    actual_dist,
396                    self.config.base.distance_type,
397                )
398            })
399            .collect())
400    }
401
402    fn remove(&mut self, id: &str) -> Result<bool> {
403        // Note: Full removal in HNSW is complex. 
404        // For production, use soft-delete + periodic rebuild.
405        if let Some(&idx) = self.id_to_idx.get(id) {
406            // Remove from neighbor lists
407            for node in &mut self.nodes {
408                for neighbors in &mut node.neighbors {
409                    neighbors.remove(&idx);
410                }
411            }
412            self.id_to_idx.remove(id);
413            // Mark as deleted (don't actually remove to preserve indices)
414            self.nodes[idx].id = String::new();
415            self.nodes[idx].vector.clear();
416            Ok(true)
417        } else {
418            Ok(false)
419        }
420    }
421
422    fn contains(&self, id: &str) -> bool {
423        self.id_to_idx.contains_key(id)
424    }
425
426    fn len(&self) -> usize {
427        self.id_to_idx.len()
428    }
429
430    fn dimension(&self) -> usize {
431        self.config.base.dimension
432    }
433
434    fn distance_type(&self) -> DistanceType {
435        self.config.base.distance_type
436    }
437
438    fn clear(&mut self) {
439        self.nodes.clear();
440        self.id_to_idx.clear();
441        self.entry_point = None;
442        self.max_level = 0;
443    }
444
445    fn memory_usage(&self) -> usize {
446        let node_size = self.config.base.dimension * 4 + 64; // vector + overhead
447        let neighbor_size = self.config.m * 8 * 2; // avg neighbors * pointer size * 2
448        self.nodes.len() * (node_size + neighbor_size)
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small 4-dimensional index (m = 4, ef = 16 for both build and
    /// query) holding five fixed vectors named "a" through "e".
    fn create_test_index() -> HNSWIndex {
        let fixtures: [(&str, [f32; 4]); 5] = [
            ("a", [1.0, 0.0, 0.0, 0.0]),
            ("b", [0.0, 1.0, 0.0, 0.0]),
            ("c", [0.0, 0.0, 1.0, 0.0]),
            ("d", [0.5, 0.5, 0.0, 0.0]),
            ("e", [0.9, 0.1, 0.0, 0.0]),
        ];

        let mut index = HNSWIndex::new(
            HNSWConfig::new(4)
                .with_m(4)
                .with_ef_construction(16)
                .with_ef_search(16),
        );

        for (id, vector) in &fixtures {
            index.add((*id).to_string(), vector).unwrap();
        }

        index
    }

    /// Querying with the exact vector of "a" must rank "a" or its close
    /// neighbor "e" first.
    #[test]
    fn test_add_and_search() {
        let index = create_test_index();
        let query = [1.0, 0.0, 0.0, 0.0];

        let results = index.search(&query, 3).unwrap();

        assert!(!results.is_empty());
        let top_is_expected = ["a", "e"].iter().any(|&want| results[0].id == want);
        assert!(top_is_expected);
    }

    /// With 100 vectors indexed, a top-10 query must return exactly 10 hits.
    #[test]
    fn test_recall() {
        let mut index = HNSWIndex::new(HNSWConfig::new(8).with_m(8).with_ef_search(32));

        // 100 deterministic vectors: component j of vector i is ((i * j) % 100) / 100.
        for i in 0..100 {
            let vec: Vec<f32> = (0..8).map(|j| ((i * j) % 100) as f32 / 100.0).collect();
            index.add(format!("v{}", i), &vec).unwrap();
        }

        let results = index.search(&[0.5; 8], 10).unwrap();
        assert_eq!(results.len(), 10);
    }

    /// Re-adding an ID that is already present must fail.
    #[test]
    fn test_duplicate_id() {
        let mut index = create_test_index();

        let outcome = index.add("a".to_string(), &[0.0, 0.0, 0.0, 1.0]);

        assert!(outcome.is_err());
    }
}