Skip to main content

content_index/
ann.rs

1//! Approximate Nearest Neighbor (ANN) search using HNSW algorithm.
2//!
3//! This module provides high-performance vector similarity search using
4//! Hierarchical Navigable Small World (HNSW) graphs. It offers sub-linear
5//! search time complexity (~O(log n)) compared to brute force O(n).
6//!
7//! ## Trade-offs
8//!
9//! - **Speed**: ~100-1000x faster than linear scan for large datasets
10//! - **Recall**: Typically 95-99% (some false negatives possible)
11//! - **Memory**: Higher memory usage than linear scan
12//! - **Build time**: Index construction takes longer than insertion
13//!
14//! ## When to Use
15//!
16//! - Dataset size > 10,000 vectors
17//! - Query latency requirements < 100ms
18//! - Acceptable to miss ~1-5% of true nearest neighbors
19//!
20//! ## When NOT to Use
21//!
22//! - Dataset size < 1,000 vectors (linear scan is fine)
23//! - Need 100% recall (use exact search)
24//! - Memory constrained environment
25
26use hnsw_rs::prelude::*;
27use std::collections::HashMap;
28
/// Configuration for ANN index construction and search.
///
/// All fields are plain values, so the struct is `Copy` and can be compared
/// with `==` (for example to detect whether a new configuration differs from
/// the one an index was built with).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AnnConfig {
    /// Number of neighbors per node (higher = better recall, slower build).
    /// Default: 16
    pub m: usize,
    /// Size of dynamic candidate list during construction (higher = better recall, slower build).
    /// Default: 200
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search (higher = better recall, slower search).
    /// Default: 50
    pub ef_search: usize,
    /// Maximum number of results to return from a search; larger `k` values are capped to this.
    /// Default: 100
    pub max_results: usize,
    /// Whether to use ANN or fall back to linear scan.
    /// Default: true (use ANN when beneficial)
    pub enabled: bool,
    /// Minimum number of vectors before ANN is used.
    /// Below this threshold, linear scan is used even if `enabled` is true.
    /// Default: 1000
    pub min_vectors_for_ann: usize,
}
52
53impl Default for AnnConfig {
54    fn default() -> Self {
55        Self {
56            m: 16,
57            ef_construction: 200,
58            ef_search: 50,
59            max_results: 100,
60            enabled: true,
61            min_vectors_for_ann: 1000,
62        }
63    }
64}
65
66impl AnnConfig {
67    pub fn with_m(mut self, m: usize) -> Self {
68        self.m = m;
69        self
70    }
71
72    pub fn with_ef_construction(mut self, ef: usize) -> Self {
73        self.ef_construction = ef;
74        self
75    }
76
77    pub fn with_ef_search(mut self, ef: usize) -> Self {
78        self.ef_search = ef;
79        self
80    }
81
82    pub fn with_max_results(mut self, max: usize) -> Self {
83        self.max_results = max;
84        self
85    }
86
87    pub fn with_enabled(mut self, enabled: bool) -> Self {
88        self.enabled = enabled;
89        self
90    }
91
92    pub fn with_min_vectors_for_ann(mut self, min: usize) -> Self {
93        self.min_vectors_for_ann = min;
94        self
95    }
96
97    /// Check if ANN should be used given the current dataset size.
98    pub fn should_use_ann(&self, num_vectors: usize) -> bool {
99        self.enabled && num_vectors >= self.min_vectors_for_ann
100    }
101}
102
/// Result from ANN search.
///
/// A small plain-value pair, so it is `Copy` and comparable with `==`
/// (note that `distance` is an `f32`, so a NaN distance never compares equal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AnnResult {
    /// Index of the vector in the original dataset (its insertion position).
    pub index: usize,
    /// Distance to query vector (lower = closer).
    pub distance: f32,
}
111
/// ANN index interface (HNSW implementation).
///
/// Vectors are kept in insertion order; a vector's position in that order is
/// the `index` reported in search results. The HNSW graph is optional: it is
/// only created by `build`, and searches fall back to an exact linear scan
/// whenever it is missing or out of date.
pub struct AnnIndex {
    /// Tuning parameters used for graph construction and search.
    config: AnnConfig,
    /// Required length of every inserted vector and every query.
    dimension: usize,
    /// HNSW graph over `vectors`; `None` until `build` has constructed one.
    hnsw: Option<Hnsw<'static, f32, DistCosine>>,
    /// Caller-supplied ID -> position in `vectors`.
    id_to_index: HashMap<String, usize>,
    /// Position in `vectors` -> caller-supplied ID.
    index_to_id: HashMap<usize, String>,
    /// All inserted vectors, in insertion order.
    vectors: Vec<Vec<f32>>,
    /// Set by `build`; cleared by `insert` and by config changes that affect the graph.
    built: bool,
}
122
123impl AnnIndex {
124    /// Create a new empty ANN index.
125    pub fn new(dimension: usize, config: AnnConfig) -> Self {
126        Self {
127            config,
128            dimension,
129            hnsw: None,
130            id_to_index: HashMap::new(),
131            index_to_id: HashMap::new(),
132            vectors: Vec::new(),
133            built: false,
134        }
135    }
136
137    /// Insert a vector with associated ID.
138    pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<(), AnnError> {
139        if vector.len() != self.dimension {
140            return Err(AnnError::DimensionMismatch {
141                expected: self.dimension,
142                got: vector.len(),
143            });
144        }
145
146        let index = self.vectors.len();
147        self.vectors.push(vector);
148        self.id_to_index.insert(id.clone(), index);
149        self.index_to_id.insert(index, id);
150
151        // Mark as needing rebuild
152        self.built = false;
153
154        Ok(())
155    }
156
157    /// Search for nearest neighbors.
158    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
159        if query.len() != self.dimension {
160            return Err(AnnError::DimensionMismatch {
161                expected: self.dimension,
162                got: query.len(),
163            });
164        }
165
166        let k = k.min(self.config.max_results);
167
168        // Decide whether to use ANN or linear scan
169        if self.built && self.config.should_use_ann(self.vectors.len()) && self.hnsw.is_some() {
170            // Use HNSW for approximate search
171            self.hnsw_search(query, k)
172        } else {
173            // Fall back to linear scan
174            self.linear_search(query, k)
175        }
176    }
177
178    /// HNSW-based approximate search.
179    fn hnsw_search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
180        if let Some(ref hnsw) = self.hnsw {
181            let ef = self.config.ef_search;
182            let results: Vec<Neighbour> = hnsw.search(query, k, ef);
183
184            Ok(results
185                .into_iter()
186                .map(|neighbour| AnnResult {
187                    index: neighbour.get_origin_id(),
188                    distance: neighbour.distance,
189                })
190                .collect())
191        } else {
192            Err(AnnError::NotBuilt)
193        }
194    }
195
196    /// Linear search (exact, slow but accurate).
197    fn linear_search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
198        if self.vectors.is_empty() {
199            return Ok(Vec::new());
200        }
201
202        // Calculate distances for all vectors
203        let mut distances: Vec<(usize, f32)> = self
204            .vectors
205            .iter()
206            .enumerate()
207            .map(|(idx, vec)| (idx, cosine_distance(query, vec)))
208            .collect();
209
210        // Sort by distance (ascending - lower is closer)
211        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
212
213        // Take top k
214        let results = distances
215            .into_iter()
216            .take(k)
217            .map(|(idx, dist)| AnnResult {
218                index: idx,
219                distance: dist,
220            })
221            .collect();
222
223        Ok(results)
224    }
225
226    /// Get ID by index.
227    pub fn get_id(&self, index: usize) -> Option<&String> {
228        self.index_to_id.get(&index)
229    }
230
231    /// Get index by ID.
232    pub fn get_index(&self, id: &str) -> Option<usize> {
233        self.id_to_index.get(id).copied()
234    }
235
236    /// Number of vectors in index.
237    pub fn len(&self) -> usize {
238        self.vectors.len()
239    }
240
241    pub fn is_empty(&self) -> bool {
242        self.vectors.is_empty()
243    }
244
245    /// Check if HNSW index is built.
246    pub fn is_built(&self) -> bool {
247        self.built
248    }
249
250    /// Build HNSW index (required before using ANN search).
251    /// Only builds if there are enough vectors for HNSW to work properly (minimum 10).
252    pub fn build(&mut self) {
253        if self.vectors.is_empty() {
254            return;
255        }
256
257        // HNSW requires a minimum number of vectors to work properly
258        // Below this threshold, we'll just use linear search
259        let nb_elem = self.vectors.len();
260        if nb_elem < 10 {
261            // Not enough vectors for HNSW, mark as built but use linear search
262            self.built = true;
263            return;
264        }
265
266        // Calculate parameters for HNSW
267        let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize);
268
269        // Create HNSW index with 5 parameters
270        let hnsw = Hnsw::<f32, DistCosine>::new(
271            self.config.m,
272            nb_elem,
273            nb_layer,
274            self.config.ef_construction,
275            DistCosine {},
276        );
277
278        // Insert all vectors using parallel_insert
279        // The API expects &[(&Vec<f32>, usize)] so we pass references to the stored vectors
280        let data_for_insertion: Vec<(&Vec<f32>, usize)> = self
281            .vectors
282            .iter()
283            .enumerate()
284            .map(|(idx, vec)| (vec, idx))
285            .collect();
286        hnsw.parallel_insert(&data_for_insertion);
287
288        self.hnsw = Some(hnsw);
289        self.built = true;
290    }
291
292    /// Rebuild the index (useful after batch insertions).
293    pub fn rebuild(&mut self) {
294        self.built = false;
295        self.build();
296    }
297
298    /// Get current configuration.
299    pub fn config(&self) -> &AnnConfig {
300        &self.config
301    }
302
303    /// Update configuration and mark as needing rebuild if needed.
304    pub fn update_config(&mut self, config: AnnConfig) {
305        let needs_rebuild =
306            config.m != self.config.m || config.ef_construction != self.config.ef_construction;
307
308        self.config = config;
309
310        if needs_rebuild {
311            self.built = false;
312        }
313    }
314}
315
/// Error type for ANN operations.
#[derive(Debug, thiserror::Error)]
pub enum AnnError {
    /// A vector passed to `insert` or a query passed to `search` had a length
    /// different from the dimension the index was created with.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },
    /// The HNSW search path was taken although no HNSW graph exists.
    #[error("Index not built")]
    NotBuilt,
    /// Carries a message describing an HNSW failure.
    /// NOTE(review): nothing in this file constructs this variant — confirm it is still needed.
    #[error("HNSW error: {0}")]
    HnswError(String),
}
326
/// Cosine distance between `a` and `b`, i.e. `1 - cosine similarity`.
///
/// The result lies in `[0.0, 2.0]`: `0` for vectors pointing the same way,
/// `1` for orthogonal vectors and `2` for opposite vectors. If either vector
/// has zero norm the similarity is undefined and `1.0` is returned.
///
/// When the slices differ in length, the dot product only covers the common
/// prefix while each norm is still taken over the whole slice.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();

    if norm_a == 0.0 || norm_b == 0.0 {
        // Degenerate input: report the same distance as for orthogonal vectors.
        return 1.0;
    }

    // Rounding can push the ratio slightly outside [-1, 1]; clamp before converting.
    let cosine = (inner / (norm_a * norm_b)).clamp(-1.0, 1.0);
    1.0 - cosine
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// `AnnConfig::default()` exposes the documented default values.
    #[test]
    fn test_ann_config_defaults() {
        let config = AnnConfig::default();
        assert_eq!(config.m, 16);
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef_search, 50);
        assert!(config.enabled);
        assert_eq!(config.min_vectors_for_ann, 1000);
    }

    /// Each `with_*` builder overrides exactly its own field.
    #[test]
    fn test_ann_config_builder() {
        let config = AnnConfig::default()
            .with_m(32)
            .with_ef_construction(400)
            .with_ef_search(100)
            .with_enabled(false)
            .with_min_vectors_for_ann(500);

        assert_eq!(config.m, 32);
        assert_eq!(config.ef_construction, 400);
        assert_eq!(config.ef_search, 100);
        assert!(!config.enabled);
        assert_eq!(config.min_vectors_for_ann, 500);
    }

    /// ANN is selected only when enabled and the size threshold is reached.
    #[test]
    fn test_should_use_ann() {
        let config = AnnConfig::default();

        // At or above threshold with enabled=true
        assert!(config.should_use_ann(1000));
        assert!(config.should_use_ann(10000));

        // Below threshold
        assert!(!config.should_use_ann(999));
        assert!(!config.should_use_ann(100));

        // When disabled
        let disabled_config = AnnConfig::default().with_enabled(false);
        assert!(!disabled_config.should_use_ann(10000));
    }

    /// An unbuilt index answers queries through the exact linear scan.
    #[test]
    fn test_ann_index_insert_and_linear_search() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        // Insert vectors
        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .insert("doc3".to_string(), vec![0.0, 0.0, 1.0])
            .unwrap();

        // Search (uses the linear scan: the index is not built and holds < 1000 vectors)
        let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0); // doc1 is closest
    }

    /// Vectors and queries of the wrong length are rejected.
    #[test]
    fn test_ann_index_dimension_mismatch() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        // Wrong dimension on insert
        let result = index.insert("doc1".to_string(), vec![1.0, 0.0]);
        assert!(matches!(result, Err(AnnError::DimensionMismatch { .. })));

        // Wrong dimension on search
        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        let result = index.search(&[1.0, 0.0], 1);
        assert!(matches!(result, Err(AnnError::DimensionMismatch { .. })));
    }

    /// Searching an empty index yields no results rather than an error.
    #[test]
    fn test_ann_index_empty_search() {
        let index = AnnIndex::new(3, AnnConfig::default());
        let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap();
        assert!(results.is_empty());
    }

    /// IDs and indices map to each other in insertion order.
    #[test]
    fn test_id_index_mapping() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        index
            .insert("doc-a".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc-b".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();

        assert_eq!(index.get_index("doc-a"), Some(0));
        assert_eq!(index.get_index("doc-b"), Some(1));
        assert_eq!(index.get_id(0), Some(&"doc-a".to_string()));
        assert_eq!(index.get_id(1), Some(&"doc-b".to_string()));
    }

    /// Reference values of the cosine distance across its [0, 2] range.
    #[test]
    fn test_cosine_distance() {
        // Same vector - distance should be 0
        let d = cosine_distance(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!(d.abs() < 0.001);

        // Orthogonal vectors - distance should be 1
        let d = cosine_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((d - 1.0).abs() < 0.001);

        // Opposite vectors - distance should be 2 (the maximum: similarity is -1)
        let d = cosine_distance(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!((d - 2.0).abs() < 0.001);
    }

    /// `k` limits the result count, and asking for more than exists returns everything.
    #[test]
    fn test_ann_index_search_respects_k() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        // Insert 5 vectors
        for i in 0..5 {
            index
                .insert(format!("doc{i}"), vec![i as f32, 0.0, 0.0])
                .unwrap();
        }

        // Search for k=2
        let results = index.search(&[0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);

        // Search for k=10 (more than available)
        let results = index.search(&[0.0, 0.0, 0.0], 10).unwrap();
        assert_eq!(results.len(), 5); // Returns all available
    }

    /// `build()` flips the built flag and searching keeps working afterwards.
    ///
    /// NOTE(review): with only 3 vectors `build()` takes its "fewer than 10
    /// vectors" path and constructs no HNSW graph, so this test does not
    /// exercise the HNSW search path.
    #[test]
    fn test_ann_index_build_and_search() {
        let mut index = AnnIndex::new(
            3,
            AnnConfig::default().with_min_vectors_for_ann(1), // Lower the ANN size threshold for a small test
        );

        // Insert vectors
        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .insert("doc3".to_string(), vec![0.0, 0.0, 1.0])
            .unwrap();

        // Not built yet
        assert!(!index.is_built());

        // Build (marks the index as built; no graph is created below 10 vectors)
        index.build();
        assert!(index.is_built());

        // Search still goes through the linear scan because no graph exists
        let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
    }

    /// Inserting after a build invalidates it until `rebuild()` runs.
    #[test]
    fn test_ann_index_rebuild() {
        let mut index = AnnIndex::new(3, AnnConfig::default().with_min_vectors_for_ann(1));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index.build();
        assert!(index.is_built());

        // Insert more without rebuilding
        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        assert!(!index.is_built()); // Should be marked as not built

        // Rebuild
        index.rebuild();
        assert!(index.is_built());
    }

    /// Changing a graph-shaping parameter marks the index as needing a rebuild.
    #[test]
    fn test_update_config_triggers_rebuild() {
        let mut index = AnnIndex::new(3, AnnConfig::default().with_min_vectors_for_ann(1));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index.build();
        assert!(index.is_built());

        // Update config with different M - should invalidate
        let new_config = AnnConfig::default().with_min_vectors_for_ann(1).with_m(32);
        index.update_config(new_config);

        // Should need rebuild since M changed
        assert!(!index.is_built());
    }
}