use hnsw_rs::prelude::*;
use std::collections::HashMap;

/// Tuning knobs for the approximate-nearest-neighbour (ANN) index.
///
/// Build a value with [`AnnConfig::default`] and adjust it with the chainable
/// `with_*` methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AnnConfig {
    /// HNSW connectivity parameter (handed to the graph constructor).
    pub m: usize,
    /// Candidate-list size used while the graph is being built.
    pub ef_construction: usize,
    /// Candidate-list size used when the graph is queried.
    pub ef_search: usize,
    /// Upper bound applied to the `k` requested by a search.
    pub max_results: usize,
    /// Master switch: when `false`, [`AnnConfig::should_use_ann`] is always `false`.
    pub enabled: bool,
    /// Minimum number of stored vectors before ANN is considered worthwhile.
    pub min_vectors_for_ann: usize,
}

impl Default for AnnConfig {
    /// Returns the stock configuration: `m = 16`, `ef_construction = 200`,
    /// `ef_search = 50`, `max_results = 100`, ANN enabled from 1000 vectors up.
    fn default() -> Self {
        Self {
            m: 16,
            ef_construction: 200,
            ef_search: 50,
            max_results: 100,
            enabled: true,
            min_vectors_for_ann: 1000,
        }
    }
}

impl AnnConfig {
    /// Returns the configuration with `m` replaced.
    #[must_use]
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = m;
        self
    }

    /// Returns the configuration with `ef_construction` replaced.
    #[must_use]
    pub fn with_ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }

    /// Returns the configuration with `ef_search` replaced.
    #[must_use]
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }

    /// Returns the configuration with `max_results` replaced.
    #[must_use]
    pub fn with_max_results(mut self, max: usize) -> Self {
        self.max_results = max;
        self
    }

    /// Returns the configuration with the `enabled` switch replaced.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the configuration with `min_vectors_for_ann` replaced.
    #[must_use]
    pub fn with_min_vectors_for_ann(mut self, min: usize) -> Self {
        self.min_vectors_for_ann = min;
        self
    }

    /// Reports whether ANN should be used for a collection of `num_vectors`
    /// vectors: ANN must be enabled and the collection must have reached
    /// `min_vectors_for_ann` (inclusive).
    pub fn should_use_ann(&self, num_vectors: usize) -> bool {
        self.enabled && num_vectors >= self.min_vectors_for_ann
    }
}

/// A single search hit.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AnnResult {
    /// Position of the matched vector inside the index.
    pub index: usize,
    /// Distance between the query and the matched vector (smaller is closer).
    pub distance: f32,
}

112pub struct AnnIndex {
114 config: AnnConfig,
115 dimension: usize,
116 hnsw: Option<Hnsw<'static, f32, DistCosine>>,
117 id_to_index: HashMap<String, usize>,
118 index_to_id: HashMap<usize, String>,
119 vectors: Vec<Vec<f32>>,
120 built: bool,
121}
122
123impl AnnIndex {
124 pub fn new(dimension: usize, config: AnnConfig) -> Self {
126 Self {
127 config,
128 dimension,
129 hnsw: None,
130 id_to_index: HashMap::new(),
131 index_to_id: HashMap::new(),
132 vectors: Vec::new(),
133 built: false,
134 }
135 }
136
137 pub fn insert(&mut self, id: String, vector: Vec<f32>) -> Result<(), AnnError> {
139 if vector.len() != self.dimension {
140 return Err(AnnError::DimensionMismatch {
141 expected: self.dimension,
142 got: vector.len(),
143 });
144 }
145
146 let index = self.vectors.len();
147 self.vectors.push(vector);
148 self.id_to_index.insert(id.clone(), index);
149 self.index_to_id.insert(index, id);
150
151 self.built = false;
153
154 Ok(())
155 }
156
157 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
159 if query.len() != self.dimension {
160 return Err(AnnError::DimensionMismatch {
161 expected: self.dimension,
162 got: query.len(),
163 });
164 }
165
166 let k = k.min(self.config.max_results);
167
168 if self.built && self.config.should_use_ann(self.vectors.len()) && self.hnsw.is_some() {
170 self.hnsw_search(query, k)
172 } else {
173 self.linear_search(query, k)
175 }
176 }
177
178 fn hnsw_search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
180 if let Some(ref hnsw) = self.hnsw {
181 let ef = self.config.ef_search;
182 let results: Vec<Neighbour> = hnsw.search(query, k, ef);
183
184 Ok(results
185 .into_iter()
186 .map(|neighbour| AnnResult {
187 index: neighbour.get_origin_id(),
188 distance: neighbour.distance,
189 })
190 .collect())
191 } else {
192 Err(AnnError::NotBuilt)
193 }
194 }
195
196 fn linear_search(&self, query: &[f32], k: usize) -> Result<Vec<AnnResult>, AnnError> {
198 if self.vectors.is_empty() {
199 return Ok(Vec::new());
200 }
201
202 let mut distances: Vec<(usize, f32)> = self
204 .vectors
205 .iter()
206 .enumerate()
207 .map(|(idx, vec)| (idx, cosine_distance(query, vec)))
208 .collect();
209
210 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
212
213 let results = distances
215 .into_iter()
216 .take(k)
217 .map(|(idx, dist)| AnnResult {
218 index: idx,
219 distance: dist,
220 })
221 .collect();
222
223 Ok(results)
224 }
225
226 pub fn get_id(&self, index: usize) -> Option<&String> {
228 self.index_to_id.get(&index)
229 }
230
231 pub fn get_index(&self, id: &str) -> Option<usize> {
233 self.id_to_index.get(id).copied()
234 }
235
236 pub fn len(&self) -> usize {
238 self.vectors.len()
239 }
240
241 pub fn is_empty(&self) -> bool {
242 self.vectors.is_empty()
243 }
244
245 pub fn is_built(&self) -> bool {
247 self.built
248 }
249
250 pub fn build(&mut self) {
253 if self.vectors.is_empty() {
254 return;
255 }
256
257 let nb_elem = self.vectors.len();
260 if nb_elem < 10 {
261 self.built = true;
263 return;
264 }
265
266 let nb_layer = 16.min((nb_elem as f32).ln().trunc() as usize);
268
269 let hnsw = Hnsw::<f32, DistCosine>::new(
271 self.config.m,
272 nb_elem,
273 nb_layer,
274 self.config.ef_construction,
275 DistCosine {},
276 );
277
278 let data_for_insertion: Vec<(&Vec<f32>, usize)> = self
281 .vectors
282 .iter()
283 .enumerate()
284 .map(|(idx, vec)| (vec, idx))
285 .collect();
286 hnsw.parallel_insert(&data_for_insertion);
287
288 self.hnsw = Some(hnsw);
289 self.built = true;
290 }
291
292 pub fn rebuild(&mut self) {
294 self.built = false;
295 self.build();
296 }
297
298 pub fn config(&self) -> &AnnConfig {
300 &self.config
301 }
302
303 pub fn update_config(&mut self, config: AnnConfig) {
305 let needs_rebuild =
306 config.m != self.config.m || config.ef_construction != self.config.ef_construction;
307
308 self.config = config;
309
310 if needs_rebuild {
311 self.built = false;
312 }
313 }
314}
315
316#[derive(Debug, thiserror::Error)]
318pub enum AnnError {
319 #[error("Dimension mismatch: expected {expected}, got {got}")]
320 DimensionMismatch { expected: usize, got: usize },
321 #[error("Index not built")]
322 NotBuilt,
323 #[error("HNSW error: {0}")]
324 HnswError(String),
325}
326
/// Cosine distance `1 - cos(a, b)`, in `[0.0, 2.0]`.
///
/// Returns `1.0` when either vector has zero norm. The dot product pairs
/// components up to the shorter slice; each norm covers its whole slice.
/// A NaN component yields a NaN distance.
///
/// Sums are accumulated in `f64`: with `f32` intermediates, squaring large
/// components overflows to infinity (giving NaN) and squaring tiny ones
/// underflows to zero (making a non-zero vector look like a zero vector).
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_norm =
        |v: &[f32]| -> f64 { v.iter().map(|&x| f64::from(x) * f64::from(x)).sum() };

    let dot: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let sq_a = squared_norm(a);
    let sq_b = squared_norm(b);

    // A zero vector has no direction; treat it as orthogonal to everything.
    if sq_a == 0.0 || sq_b == 0.0 {
        return 1.0;
    }

    let similarity = dot / (sq_a * sq_b).sqrt();
    // Clamp away rounding excursions just outside [-1, 1].
    (1.0 - similarity.clamp(-1.0, 1.0)) as f32
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented stock values.
    #[test]
    fn test_ann_config_defaults() {
        let config = AnnConfig::default();
        assert_eq!(config.m, 16);
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef_search, 50);
        assert!(config.enabled);
        assert_eq!(config.min_vectors_for_ann, 1000);
    }

    /// Each builder method overrides exactly its own field.
    #[test]
    fn test_ann_config_builder() {
        let config = AnnConfig::default()
            .with_m(32)
            .with_ef_construction(400)
            .with_ef_search(100)
            .with_enabled(false)
            .with_min_vectors_for_ann(500);

        assert_eq!(config.m, 32);
        assert_eq!(config.ef_construction, 400);
        assert_eq!(config.ef_search, 100);
        assert!(!config.enabled);
        assert_eq!(config.min_vectors_for_ann, 500);
    }

    /// The ANN threshold is inclusive and the `enabled` switch overrides it.
    #[test]
    fn test_should_use_ann() {
        let config = AnnConfig::default();

        // At or above the threshold.
        assert!(config.should_use_ann(1000));
        assert!(config.should_use_ann(10000));

        // Below the threshold.
        assert!(!config.should_use_ann(999));
        assert!(!config.should_use_ann(100));

        // Disabled regardless of size.
        let disabled_config = AnnConfig::default().with_enabled(false);
        assert!(!disabled_config.should_use_ann(10000));
    }

    /// An unbuilt index answers through the linear scan; the exact match ranks first.
    #[test]
    fn test_ann_index_insert_and_linear_search() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .insert("doc3".to_string(), vec![0.0, 0.0, 1.0])
            .unwrap();

        let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].index, 0);
    }

    /// Wrong-length vectors and queries are rejected.
    #[test]
    fn test_ann_index_dimension_mismatch() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        let result = index.insert("doc1".to_string(), vec![1.0, 0.0]);
        assert!(matches!(result, Err(AnnError::DimensionMismatch { .. })));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        let result = index.search(&[1.0, 0.0], 1);
        assert!(matches!(result, Err(AnnError::DimensionMismatch { .. })));
    }

    /// Searching an empty index succeeds with no hits.
    #[test]
    fn test_ann_index_empty_search() {
        let index = AnnIndex::new(3, AnnConfig::default());
        let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap();
        assert!(results.is_empty());
    }

    /// Ids and slots map to each other in insertion order.
    #[test]
    fn test_id_index_mapping() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        index
            .insert("doc-a".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc-b".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();

        assert_eq!(index.get_index("doc-a"), Some(0));
        assert_eq!(index.get_index("doc-b"), Some(1));
        assert_eq!(index.get_id(0), Some(&"doc-a".to_string()));
        assert_eq!(index.get_id(1), Some(&"doc-b".to_string()));
    }

    /// Identical, orthogonal and opposite vectors give distances 0, 1 and 2.
    #[test]
    fn test_cosine_distance() {
        let d = cosine_distance(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!(d.abs() < 0.001);

        let d = cosine_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((d - 1.0).abs() < 0.001);

        let d = cosine_distance(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!((d - 2.0).abs() < 0.001);
    }

    /// `k` limits the result count, and cannot exceed the number of vectors.
    #[test]
    fn test_ann_index_search_respects_k() {
        let mut index = AnnIndex::new(3, AnnConfig::default());

        for i in 0..5 {
            index
                .insert(format!("doc{i}"), vec![i as f32, 0.0, 0.0])
                .unwrap();
        }

        let results = index.search(&[0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);

        // Only five vectors are stored, so asking for ten yields five.
        let results = index.search(&[0.0, 0.0, 0.0], 10).unwrap();
        assert_eq!(results.len(), 5);
    }

    /// `build` flips the built flag and searching still works afterwards.
    ///
    /// With only three vectors `build` takes its small-index path (fewer than
    /// 10 elements), so this exercises the linear scan, not the HNSW graph.
    #[test]
    fn test_ann_index_build_and_search() {
        let mut index = AnnIndex::new(3, AnnConfig::default().with_min_vectors_for_ann(1));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        index
            .insert("doc3".to_string(), vec![0.0, 0.0, 1.0])
            .unwrap();

        assert!(!index.is_built());

        index.build();
        assert!(index.is_built());

        let results = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
    }

    /// Inserting after a build invalidates it; `rebuild` restores it.
    #[test]
    fn test_ann_index_rebuild() {
        let mut index = AnnIndex::new(3, AnnConfig::default().with_min_vectors_for_ann(1));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index.build();
        assert!(index.is_built());

        index
            .insert("doc2".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();
        assert!(!index.is_built());

        index.rebuild();
        assert!(index.is_built());
    }

    /// Changing a graph-construction parameter marks the index as unbuilt.
    #[test]
    fn test_update_config_triggers_rebuild() {
        let mut index = AnnIndex::new(3, AnnConfig::default().with_min_vectors_for_ann(1));

        index
            .insert("doc1".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        index.build();
        assert!(index.is_built());

        let new_config = AnnConfig::default().with_min_vectors_for_ann(1).with_m(32);
        index.update_config(new_config);

        assert!(!index.is_built());
    }
}