1use crate::filter::{Filter, Metadata};
49use crate::simd;
50use crate::types::{DistanceMetric, SearchResult};
51use anyhow::{anyhow, Result};
52use rand::Rng;
53use serde::{Deserialize, Serialize};
54use std::cmp::Ordering;
55use std::collections::{BinaryHeap, HashMap, HashSet};
56use tracing::{debug, info};
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct HnswConfig {
61 pub metric: DistanceMetric,
63 pub m: usize,
65 pub m0: usize,
67 pub ef_construction: usize,
69 pub ef_search: usize,
71 pub ml: f64,
73 pub normalize: bool,
75}
76
77impl Default for HnswConfig {
78 fn default() -> Self {
79 let m = 16;
80 Self {
81 metric: DistanceMetric::Cosine,
82 m,
83 m0: m * 2,
84 ef_construction: 200,
85 ef_search: 50,
86 ml: 1.0 / (m as f64).ln(),
87 normalize: true,
88 }
89 }
90}
91
92impl HnswConfig {
93 pub fn high_recall() -> Self {
95 let m = 32;
96 Self {
97 metric: DistanceMetric::Cosine,
98 m,
99 m0: m * 2,
100 ef_construction: 400,
101 ef_search: 100,
102 ml: 1.0 / (m as f64).ln(),
103 normalize: true,
104 }
105 }
106
107 pub fn fast() -> Self {
109 let m = 12;
110 Self {
111 metric: DistanceMetric::Cosine,
112 m,
113 m0: m * 2,
114 ef_construction: 100,
115 ef_search: 30,
116 ml: 1.0 / (m as f64).ln(),
117 normalize: true,
118 }
119 }
120}
121
122#[allow(dead_code)]
124#[derive(Debug, Clone, Serialize, Deserialize)]
125struct HnswNode {
126 id: usize,
128 level: usize,
130 neighbors: Vec<Vec<usize>>,
132}
133
134impl HnswNode {
135 fn new(id: usize, level: usize) -> Self {
136 Self {
137 id,
138 level,
139 neighbors: vec![Vec::new(); level + 1],
140 }
141 }
142}
143
/// Frontier entry for graph traversal.
///
/// The ordering is reversed on distance, so a `BinaryHeap<Candidate>` (a
/// max-heap) pops the *closest* candidate first.
#[derive(Debug, Clone, Copy)]
struct Candidate {
    id: usize,
    distance: f32,
}

impl Candidate {
    /// Total order on distances: ordinary numeric order, with NaN ranked
    /// after every number and equal to other NaNs.
    ///
    /// `partial_cmp` alone is not a total order; mapping "incomparable" to
    /// `Equal` makes NaN equal to everything, which violates the transitivity
    /// that `Ord` (and therefore `BinaryHeap`) requires.
    fn distance_order(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}

impl PartialEq for Candidate {
    /// Equality is derived from `cmp` so that `Eq` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Smaller distance compares as *greater*, turning the std max-heap into
    /// a min-heap on distance. A NaN distance is treated as farthest.
    fn cmp(&self, other: &Self) -> Ordering {
        Self::distance_order(other.distance, self.distance)
    }
}
174
/// Result-set entry for graph traversal.
///
/// Ordered by ascending distance, so a `BinaryHeap<MaxCandidate>` keeps the
/// *farthest* kept result on top, ready to be evicted.
#[derive(Debug, Clone, Copy)]
struct MaxCandidate {
    id: usize,
    distance: f32,
}

impl MaxCandidate {
    /// Total order on distances: ordinary numeric order, with NaN ranked
    /// after every number and equal to other NaNs.
    ///
    /// `partial_cmp` alone is not a total order; mapping "incomparable" to
    /// `Equal` makes NaN equal to everything, which violates the transitivity
    /// that `Ord` (and therefore `BinaryHeap`) requires.
    fn distance_order(a: f32, b: f32) -> Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}

impl PartialEq for MaxCandidate {
    /// Equality is derived from `cmp` so that `Eq` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    /// Larger distance compares as greater; a NaN distance is the greatest,
    /// so it is the first entry evicted from a bounded result heap.
    fn cmp(&self, other: &Self) -> Ordering {
        Self::distance_order(self.distance, other.distance)
    }
}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct HnswIndex {
208 config: HnswConfig,
209 vectors: Vec<Vec<f32>>,
211 entity_ids: Vec<String>,
213 nodes: Vec<HnswNode>,
215 entry_point: Option<usize>,
217 max_level: usize,
219 dimensions: usize,
221 is_built: bool,
223 metadata: HashMap<String, Metadata>,
225 deleted: HashSet<String>,
227}
228
229impl HnswIndex {
230 pub fn new(config: HnswConfig) -> Self {
232 info!(
233 "Initialized HNSW index: m={}, ef_construction={}, ef_search={}",
234 config.m, config.ef_construction, config.ef_search
235 );
236
237 Self {
238 config,
239 vectors: Vec::new(),
240 entity_ids: Vec::new(),
241 nodes: Vec::new(),
242 entry_point: None,
243 max_level: 0,
244 dimensions: 0,
245 is_built: false,
246 metadata: HashMap::new(),
247 deleted: HashSet::new(),
248 }
249 }
250
251 pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
253 if embeddings.is_empty() {
254 return Err(anyhow!("Cannot build index from empty embeddings"));
255 }
256
257 info!(
258 "Building HNSW index for {} entities (m={}, ef_construction={})",
259 embeddings.len(),
260 self.config.m,
261 self.config.ef_construction
262 );
263
264 self.vectors.clear();
266 self.entity_ids.clear();
267 self.nodes.clear();
268 self.entry_point = None;
269 self.max_level = 0;
270
271 self.dimensions = embeddings.values().next().unwrap().len();
273
274 for (entity_id, vec) in embeddings {
275 let mut v = vec.clone();
276 if self.config.normalize {
277 Self::normalize_vector(&mut v);
278 }
279 self.vectors.push(v);
280 self.entity_ids.push(entity_id.clone());
281 }
282
283 for i in 0..self.vectors.len() {
285 self.insert_node(i)?;
286 }
287
288 self.is_built = true;
289 info!(
290 "HNSW index built: {} vectors, max_level={}",
291 self.vectors.len(),
292 self.max_level
293 );
294
295 Ok(())
296 }
297
298 fn insert_node(&mut self, id: usize) -> Result<()> {
300 let level = self.random_level();
301 let node = HnswNode::new(id, level);
302 self.nodes.push(node);
303
304 if self.entry_point.is_none() {
306 self.entry_point = Some(id);
307 self.max_level = level;
308 return Ok(());
309 }
310
311 let entry_point = self.entry_point.unwrap();
312
313 let mut current_nearest = entry_point;
315
316 for layer in (level + 1..=self.max_level).rev() {
317 current_nearest = self.greedy_search(id, current_nearest, layer);
318 }
319
320 for layer in (0..=level.min(self.max_level)).rev() {
322 let neighbors =
324 self.search_layer(id, current_nearest, self.config.ef_construction, layer);
325
326 let m = if layer == 0 {
328 self.config.m0
329 } else {
330 self.config.m
331 };
332
333 let selected = self.select_neighbors(&neighbors, m);
334
335 self.nodes[id].neighbors[layer] = selected.clone();
337
338 for &neighbor_id in &selected {
340 self.nodes[neighbor_id].neighbors[layer].push(id);
341
342 let max_connections = if layer == 0 {
344 self.config.m0
345 } else {
346 self.config.m
347 };
348
349 if self.nodes[neighbor_id].neighbors[layer].len() > max_connections {
350 self.prune_connections(neighbor_id, layer, max_connections);
351 }
352 }
353
354 if !selected.is_empty() {
355 current_nearest = selected[0];
356 }
357 }
358
359 if level > self.max_level {
361 self.entry_point = Some(id);
362 self.max_level = level;
363 }
364
365 Ok(())
366 }
367
368 fn random_level(&self) -> usize {
370 let mut rng = rand::rng();
371 let mut level = 0;
372 let uniform: f64 = rng.random();
373
374 while uniform < (-((level + 1) as f64) * self.config.ml).exp() && level < 32 {
376 level += 1;
377 }
378
379 level
380 }
381
382 fn greedy_search(&self, query_id: usize, start: usize, layer: usize) -> usize {
384 let query = &self.vectors[query_id];
385 let mut current = start;
386 let mut current_dist = self.compute_distance(query, &self.vectors[current]);
387
388 loop {
389 let mut changed = false;
390
391 for &neighbor in &self.nodes[current].neighbors[layer] {
392 let dist = self.compute_distance(query, &self.vectors[neighbor]);
393 if dist < current_dist {
394 current = neighbor;
395 current_dist = dist;
396 changed = true;
397 }
398 }
399
400 if !changed {
401 break;
402 }
403 }
404
405 current
406 }
407
408 fn search_layer(
410 &self,
411 query_id: usize,
412 entry_point: usize,
413 ef: usize,
414 layer: usize,
415 ) -> Vec<(usize, f32)> {
416 let query = &self.vectors[query_id];
417 self.search_layer_by_vector(query, entry_point, ef, layer)
418 }
419
420 fn search_layer_by_vector(
422 &self,
423 query: &[f32],
424 entry_point: usize,
425 ef: usize,
426 layer: usize,
427 ) -> Vec<(usize, f32)> {
428 let mut visited = HashSet::new();
429 let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
430 let mut results: BinaryHeap<MaxCandidate> = BinaryHeap::new();
431
432 let entry_dist = self.compute_distance(query, &self.vectors[entry_point]);
433
434 visited.insert(entry_point);
435 candidates.push(Candidate {
436 id: entry_point,
437 distance: entry_dist,
438 });
439 results.push(MaxCandidate {
440 id: entry_point,
441 distance: entry_dist,
442 });
443
444 while let Some(Candidate { id: current, .. }) = candidates.pop() {
445 let furthest_result = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
446
447 if self.compute_distance(query, &self.vectors[current]) > furthest_result {
449 break;
450 }
451
452 if layer < self.nodes[current].neighbors.len() {
454 for &neighbor in &self.nodes[current].neighbors[layer] {
455 if visited.contains(&neighbor) {
456 continue;
457 }
458 visited.insert(neighbor);
459
460 let dist = self.compute_distance(query, &self.vectors[neighbor]);
461 let furthest = results.peek().map(|c| c.distance).unwrap_or(f32::MAX);
462
463 if dist < furthest || results.len() < ef {
464 candidates.push(Candidate {
465 id: neighbor,
466 distance: dist,
467 });
468 results.push(MaxCandidate {
469 id: neighbor,
470 distance: dist,
471 });
472
473 while results.len() > ef {
475 results.pop();
476 }
477 }
478 }
479 }
480 }
481
482 let mut result_vec: Vec<(usize, f32)> =
484 results.into_iter().map(|c| (c.id, c.distance)).collect();
485 result_vec.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
486 result_vec
487 }
488
489 fn select_neighbors(&self, candidates: &[(usize, f32)], m: usize) -> Vec<usize> {
491 candidates.iter().take(m).map(|(id, _)| *id).collect()
492 }
493
494 fn prune_connections(&mut self, node_id: usize, layer: usize, max_connections: usize) {
496 let node_vec = self.vectors[node_id].clone();
497
498 let mut neighbor_dists: Vec<(usize, f32)> = self.nodes[node_id].neighbors[layer]
500 .iter()
501 .map(|&n| (n, self.compute_distance(&node_vec, &self.vectors[n])))
502 .collect();
503
504 neighbor_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
506
507 self.nodes[node_id].neighbors[layer] = neighbor_dists
509 .into_iter()
510 .take(max_connections)
511 .map(|(id, _)| id)
512 .collect();
513 }
514
515 #[inline]
519 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
520 simd::compute_distance_lower_is_better_simd(self.config.metric, a, b)
522 }
523
524 #[inline]
526 fn normalize_vector(vec: &mut [f32]) {
527 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
528 if norm > 1e-10 {
529 for x in vec.iter_mut() {
530 *x /= norm;
531 }
532 }
533 }
534
535 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
537 if !self.is_built {
538 return Err(anyhow!("Index not built. Call build() first"));
539 }
540
541 if query.len() != self.dimensions {
542 return Err(anyhow!(
543 "Query dimension {} doesn't match index dimension {}",
544 query.len(),
545 self.dimensions
546 ));
547 }
548
549 let mut normalized_query = query.to_vec();
551 if self.config.normalize {
552 Self::normalize_vector(&mut normalized_query);
553 }
554
555 debug!("HNSW search: k={}, ef_search={}", k, self.config.ef_search);
556
557 let entry_point = self.entry_point.ok_or_else(|| anyhow!("Empty index"))?;
558
559 let mut current = entry_point;
561 for layer in (1..=self.max_level).rev() {
562 current = self.greedy_search_by_vector(&normalized_query, current, layer);
563 }
564
565 let candidates =
567 self.search_layer_by_vector(&normalized_query, current, self.config.ef_search, 0);
568
569 let results: Vec<SearchResult> = candidates
571 .into_iter()
572 .filter(|(id, _)| !self.deleted.contains(&self.entity_ids[*id]))
573 .take(k)
574 .enumerate()
575 .map(|(rank, (id, distance))| SearchResult {
576 entity_id: self.entity_ids[id].clone(),
577 score: self.distance_to_score(distance),
578 distance,
579 rank: rank + 1,
580 })
581 .collect();
582
583 debug!("Found {} results", results.len());
584 Ok(results)
585 }
586
587 fn greedy_search_by_vector(&self, query: &[f32], start: usize, layer: usize) -> usize {
589 let mut current = start;
590 let mut current_dist = self.compute_distance(query, &self.vectors[current]);
591
592 loop {
593 let mut changed = false;
594
595 if layer < self.nodes[current].neighbors.len() {
596 for &neighbor in &self.nodes[current].neighbors[layer] {
597 let dist = self.compute_distance(query, &self.vectors[neighbor]);
598 if dist < current_dist {
599 current = neighbor;
600 current_dist = dist;
601 changed = true;
602 }
603 }
604 }
605
606 if !changed {
607 break;
608 }
609 }
610
611 current
612 }
613
614 fn distance_to_score(&self, distance: f32) -> f32 {
616 match self.config.metric {
617 DistanceMetric::Cosine => 1.0 - distance,
618 DistanceMetric::Euclidean | DistanceMetric::Manhattan => -distance,
619 DistanceMetric::DotProduct => -distance,
620 }
621 }
622
623 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
625 if !self.is_built {
626 return Err(anyhow!("Index not built. Call build() first"));
627 }
628
629 info!("HNSW batch search: {} queries", queries.len());
630
631 let results: Vec<Vec<SearchResult>> = queries
632 .iter()
633 .map(|query| self.search(query, k).unwrap_or_default())
634 .collect();
635
636 Ok(results)
637 }
638
639 pub fn add(&mut self, entity_id: &str, vector: &[f32]) -> Result<()> {
641 if !self.is_built {
642 return Err(anyhow!(
643 "Index not built. Call build() first or use build() with initial data"
644 ));
645 }
646
647 if vector.len() != self.dimensions {
648 return Err(anyhow!(
649 "Vector dimension {} doesn't match index dimension {}",
650 vector.len(),
651 self.dimensions
652 ));
653 }
654
655 let mut v = vector.to_vec();
657 if self.config.normalize {
658 Self::normalize_vector(&mut v);
659 }
660
661 let id = self.vectors.len();
662 self.vectors.push(v);
663 self.entity_ids.push(entity_id.to_string());
664
665 self.insert_node(id)?;
667
668 debug!("Added vector '{}' to HNSW index", entity_id);
669 Ok(())
670 }
671
672 pub fn get_stats(&self) -> HnswStats {
674 let total_connections: usize = self
675 .nodes
676 .iter()
677 .flat_map(|n| n.neighbors.iter())
678 .map(|neighbors| neighbors.len())
679 .sum();
680
681 let avg_connections = if !self.nodes.is_empty() {
682 total_connections as f64 / self.nodes.len() as f64
683 } else {
684 0.0
685 };
686
687 HnswStats {
688 num_vectors: self.vectors.len(),
689 active_vectors: self.active_count(),
690 deleted_vectors: self.deleted_count(),
691 dimensions: self.dimensions,
692 max_level: self.max_level,
693 avg_connections,
694 m: self.config.m,
695 ef_construction: self.config.ef_construction,
696 ef_search: self.config.ef_search,
697 is_built: self.is_built,
698 }
699 }
700
701 pub fn set_ef_search(&mut self, ef: usize) {
703 self.config.ef_search = ef;
704 }
705
706 pub fn remove(&mut self, entity_id: &str) -> bool {
711 if self.entity_ids.iter().any(|e| e == entity_id) {
712 self.deleted.insert(entity_id.to_string());
713 self.metadata.remove(entity_id);
714 debug!("Marked '{}' as deleted (tombstone)", entity_id);
715 true
716 } else {
717 false
718 }
719 }
720
721 pub fn is_deleted(&self, entity_id: &str) -> bool {
723 self.deleted.contains(entity_id)
724 }
725
726 pub fn deleted_count(&self) -> usize {
728 self.deleted.len()
729 }
730
731 pub fn active_count(&self) -> usize {
733 self.vectors.len() - self.deleted.len()
734 }
735
736 pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
738 self.metadata.insert(entity_id.to_string(), metadata);
739 }
740
741 pub fn set_metadata_batch(&mut self, metadata_map: HashMap<String, Metadata>) {
743 self.metadata.extend(metadata_map);
744 }
745
746 #[inline]
748 pub fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
749 self.metadata.get(entity_id)
750 }
751
752 pub fn filtered_search(
757 &self,
758 query: &[f32],
759 k: usize,
760 filter: &Filter,
761 ) -> Result<Vec<SearchResult>> {
762 if !self.is_built {
763 return Err(anyhow!("Index not built. Call build() first"));
764 }
765
766 if filter.is_empty() {
767 return self.search(query, k);
768 }
769
770 let expanded_k = (k * 10).min(self.vectors.len());
772
773 debug!(
774 "HNSW filtered search: k={}, expanded_k={}, filter conditions={}",
775 k,
776 expanded_k,
777 filter.conditions().len()
778 );
779
780 let all_results = self.search(query, expanded_k)?;
782
783 let filtered: Vec<SearchResult> = all_results
785 .into_iter()
786 .filter(|r| {
787 self.metadata
788 .get(&r.entity_id)
789 .is_some_and(|m| filter.matches(m))
790 })
791 .take(k)
792 .enumerate()
793 .map(|(i, mut r)| {
794 r.rank = i + 1; r
796 })
797 .collect();
798
799 debug!("HNSW filtered search returned {} results", filtered.len());
800 Ok(filtered)
801 }
802
803 pub fn prefiltered_search(
808 &self,
809 query: &[f32],
810 k: usize,
811 filter: &Filter,
812 ) -> Result<Vec<SearchResult>> {
813 if !self.is_built {
814 return Err(anyhow!("Index not built. Call build() first"));
815 }
816
817 if query.len() != self.dimensions {
818 return Err(anyhow!(
819 "Query dimension {} doesn't match index dimension {}",
820 query.len(),
821 self.dimensions
822 ));
823 }
824
825 if filter.is_empty() {
826 return self.search(query, k);
827 }
828
829 debug!("HNSW pre-filtered search: k={}", k);
830
831 let mut normalized_query = query.to_vec();
833 if self.config.normalize {
834 Self::normalize_vector(&mut normalized_query);
835 }
836
837 let matching_indices: Vec<usize> = (0..self.entity_ids.len())
839 .filter(|&i| {
840 self.metadata
841 .get(&self.entity_ids[i])
842 .is_some_and(|m| filter.matches(m))
843 })
844 .collect();
845
846 if matching_indices.is_empty() {
847 return Ok(Vec::new());
848 }
849
850 let mut scores: Vec<(usize, f32)> = matching_indices
852 .iter()
853 .map(|&i| {
854 let dist = self.compute_distance(&normalized_query, &self.vectors[i]);
855 (i, dist)
856 })
857 .collect();
858
859 scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
861
862 let results: Vec<SearchResult> = scores
864 .iter()
865 .take(k)
866 .enumerate()
867 .map(|(rank, &(idx, distance))| SearchResult {
868 entity_id: self.entity_ids[idx].clone(),
869 score: self.distance_to_score(distance),
870 distance,
871 rank: rank + 1,
872 })
873 .collect();
874
875 debug!(
876 "HNSW pre-filtered search returned {} results",
877 results.len()
878 );
879 Ok(results)
880 }
881
882 pub fn optimize_graph(&mut self) -> Result<()> {
892 if !self.is_built {
893 return Err(anyhow!("Index not built. Call build() first"));
894 }
895
896 info!("Optimizing HNSW graph structure...");
897
898 let deleted_indices: HashSet<usize> = self
899 .entity_ids
900 .iter()
901 .enumerate()
902 .filter(|(_, id)| self.deleted.contains(*id))
903 .map(|(idx, _)| idx)
904 .collect();
905
906 let mut optimized_count = 0;
907
908 for node_idx in 0..self.nodes.len() {
910 let node_level = self.nodes[node_idx].level;
911
912 for layer in 0..=node_level {
913 let original_len = self.nodes[node_idx].neighbors[layer].len();
914
915 self.nodes[node_idx].neighbors[layer]
917 .retain(|&neighbor_id| !deleted_indices.contains(&neighbor_id));
918
919 let max_connections = if layer == 0 {
921 self.config.m0
922 } else {
923 self.config.m
924 };
925
926 if self.nodes[node_idx].neighbors[layer].len() > max_connections {
927 let node_vec = self.vectors[node_idx].clone();
929 let mut neighbor_distances: Vec<(usize, f32)> = self.nodes[node_idx].neighbors
930 [layer]
931 .iter()
932 .map(|&neighbor_id| {
933 let dist = self.compute_distance(&node_vec, &self.vectors[neighbor_id]);
934 (neighbor_id, dist)
935 })
936 .collect();
937
938 neighbor_distances
940 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
941
942 self.nodes[node_idx].neighbors[layer] = neighbor_distances
944 .iter()
945 .take(max_connections)
946 .map(|(id, _)| *id)
947 .collect();
948 }
949
950 if self.nodes[node_idx].neighbors[layer].len() != original_len {
951 optimized_count += 1;
952 }
953 }
954 }
955
956 info!(
957 "HNSW graph optimization complete. {} node connections updated.",
958 optimized_count
959 );
960
961 Ok(())
962 }
963
964 pub fn compact(&mut self) -> Result<()> {
971 if !self.is_built {
972 return Err(anyhow!("Index not built. Call build() first"));
973 }
974
975 if self.deleted.is_empty() {
976 info!("No deleted vectors to compact");
977 return Ok(());
978 }
979
980 info!(
981 "Compacting HNSW index: removing {} deleted vectors out of {}",
982 self.deleted.len(),
983 self.vectors.len()
984 );
985
986 let mut new_embeddings = HashMap::new();
988 let mut new_metadata = HashMap::new();
989
990 for (i, entity_id) in self.entity_ids.iter().enumerate() {
991 if !self.deleted.contains(entity_id) {
992 new_embeddings.insert(entity_id.clone(), self.vectors[i].clone());
993
994 if let Some(metadata) = self.metadata.get(entity_id) {
995 new_metadata.insert(entity_id.clone(), metadata.clone());
996 }
997 }
998 }
999
1000 self.build(&new_embeddings)?;
1002
1003 self.set_metadata_batch(new_metadata);
1005
1006 info!("HNSW index compaction complete");
1007
1008 Ok(())
1009 }
1010}
1011
1012#[derive(Debug, Clone, Serialize, Deserialize)]
1014pub struct HnswStats {
1015 pub num_vectors: usize,
1017 pub active_vectors: usize,
1019 pub deleted_vectors: usize,
1021 pub dimensions: usize,
1023 pub max_level: usize,
1025 pub avg_connections: f64,
1027 pub m: usize,
1029 pub ef_construction: usize,
1031 pub ef_search: usize,
1033 pub is_built: bool,
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filter::FilterValue;

    /// Five 3-dimensional fixture vectors, `doc1` … `doc5`.
    fn create_test_embeddings() -> HashMap<String, Vec<f32>> {
        let mut embeddings = HashMap::new();
        for (id, vector) in [
            ("doc1", vec![1.0, 0.0, 0.0]),
            ("doc2", vec![0.9, 0.1, 0.0]),
            ("doc3", vec![0.0, 1.0, 0.0]),
            ("doc4", vec![0.0, 0.0, 1.0]),
            ("doc5", vec![0.7, 0.7, 0.0]),
        ] {
            embeddings.insert(id.to_string(), vector);
        }
        embeddings
    }

    /// `type` / `year` metadata for the five fixture documents.
    fn create_test_metadata() -> HashMap<String, Metadata> {
        let mut metadata = HashMap::new();
        for (id, kind, year) in [
            ("doc1", "article", 2023),
            ("doc2", "article", 2022),
            ("doc3", "book", 2023),
            ("doc4", "book", 2021),
            ("doc5", "article", 2024),
        ] {
            let mut entry = HashMap::new();
            entry.insert("type".to_string(), FilterValue::String(kind.to_string()));
            entry.insert("year".to_string(), FilterValue::Int(year));
            metadata.insert(id.to_string(), entry);
        }
        metadata
    }

    /// Default-configured index built from the fixture embeddings.
    fn built_index() -> HnswIndex {
        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&create_test_embeddings()).unwrap();
        index
    }

    /// Like `built_index`, with the fixture metadata attached.
    fn built_index_with_metadata() -> HnswIndex {
        let mut index = built_index();
        index.set_metadata_batch(create_test_metadata());
        index
    }

    #[test]
    fn test_hnsw_config_default() {
        let config = HnswConfig::default();
        assert_eq!(config.m, 16);
        assert_eq!(config.m0, 32);
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef_search, 50);
    }

    #[test]
    fn test_hnsw_build() {
        let embeddings = create_test_embeddings();
        let mut index = HnswIndex::new(HnswConfig::default());

        assert!(index.build(&embeddings).is_ok());
        assert!(index.is_built);

        let stats = index.get_stats();
        assert_eq!(stats.num_vectors, 5);
        assert_eq!(stats.dimensions, 3);
    }

    #[test]
    fn test_hnsw_search() {
        let index = built_index();

        let results = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

        assert_eq!(results.len(), 3);
        // doc1 is the exact match; doc2 is nearly collinear with it.
        assert!(results[0].entity_id == "doc1" || results[0].entity_id == "doc2");
    }

    #[test]
    fn test_hnsw_batch_search() {
        let index = built_index();

        let queries = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let results = index.batch_search(&queries, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert_eq!(results[1].len(), 2);
    }

    #[test]
    fn test_hnsw_incremental_add() {
        let mut index = built_index();

        index.add("doc6", &[0.5, 0.5, 0.5]).unwrap();

        let stats = index.get_stats();
        assert_eq!(stats.num_vectors, 6);

        let results = index.search(&[0.5, 0.5, 0.5], 1).unwrap();
        assert_eq!(results[0].entity_id, "doc6");
    }

    #[test]
    fn test_hnsw_search_accuracy() {
        // 100 unit vectors evenly spaced on a circle in the xy-plane.
        let embeddings: HashMap<String, Vec<f32>> = (0..100)
            .map(|i| {
                let angle = (i as f32) * 2.0 * std::f32::consts::PI / 100.0;
                (format!("doc{}", i), vec![angle.cos(), angle.sin(), 0.0])
            })
            .collect();

        let mut index = HnswIndex::new(HnswConfig::default());
        index.build(&embeddings).unwrap();

        let query_angle = 0.5_f32;
        let query = vec![query_angle.cos(), query_angle.sin(), 0.0];
        let results = index.search(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert!(results[0].score > 0.95);
    }

    #[test]
    fn test_hnsw_empty_error() {
        let embeddings: HashMap<String, Vec<f32>> = HashMap::new();
        let mut index = HnswIndex::new(HnswConfig::default());

        assert!(index.build(&embeddings).is_err());
    }

    #[test]
    fn test_hnsw_dimension_mismatch() {
        let index = built_index();

        // Two components against a three-dimensional index.
        let query = vec![1.0, 0.0];
        assert!(index.search(&query, 1).is_err());
    }

    #[test]
    fn test_hnsw_stats() {
        let stats = built_index().get_stats();

        assert_eq!(stats.num_vectors, 5);
        assert_eq!(stats.dimensions, 3);
        assert_eq!(stats.m, 16);
        assert_eq!(stats.ef_construction, 200);
        assert!(stats.is_built);
    }

    #[test]
    fn test_ef_search_adjustment() {
        let mut index = built_index();

        index.set_ef_search(100);

        assert_eq!(index.get_stats().ef_search, 100);
    }

    #[test]
    fn test_hnsw_set_and_get_metadata() {
        let mut index = built_index();

        let mut metadata = HashMap::new();
        metadata.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );

        index.set_metadata("doc1", metadata.clone());

        let retrieved = index.get_metadata("doc1");
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.unwrap().get("type"),
            Some(&FilterValue::String("article".to_string()))
        );
    }

    #[test]
    fn test_hnsw_filtered_search() {
        let index = built_index_with_metadata();

        let filter = Filter::new().eq("type", "article");
        let results = index
            .filtered_search(&[1.0, 0.0, 0.0], 5, &filter)
            .unwrap();

        assert_eq!(results.len(), 3);
        for result in &results {
            let meta = index.get_metadata(&result.entity_id).unwrap();
            assert_eq!(
                meta.get("type"),
                Some(&FilterValue::String("article".to_string()))
            );
        }
    }

    #[test]
    fn test_hnsw_filtered_search_with_year() {
        let index = built_index_with_metadata();

        let filter = Filter::new().gte("year", 2023i64);
        let results = index
            .filtered_search(&[1.0, 0.0, 0.0], 5, &filter)
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hnsw_prefiltered_search() {
        let index = built_index_with_metadata();

        let filter = Filter::new().eq("type", "book");
        let results = index
            .prefiltered_search(&[0.0, 1.0, 0.0], 5, &filter)
            .unwrap();

        assert_eq!(results.len(), 2);
        for result in &results {
            let meta = index.get_metadata(&result.entity_id).unwrap();
            assert_eq!(
                meta.get("type"),
                Some(&FilterValue::String("book".to_string()))
            );
        }
    }

    #[test]
    fn test_hnsw_filtered_search_empty_filter() {
        let index = built_index();

        let filter = Filter::new();
        let results = index
            .filtered_search(&[1.0, 0.0, 0.0], 3, &filter)
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hnsw_filtered_search_no_matches() {
        let index = built_index_with_metadata();

        let filter = Filter::new().eq("type", "journal");
        let results = index
            .filtered_search(&[1.0, 0.0, 0.0], 5, &filter)
            .unwrap();

        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_hnsw_lazy_delete() {
        let mut index = built_index();

        let stats_before = index.get_stats();
        assert_eq!(stats_before.num_vectors, 5);
        assert_eq!(stats_before.active_vectors, 5);
        assert_eq!(stats_before.deleted_vectors, 0);

        assert!(index.remove("doc1"));
        assert!(index.is_deleted("doc1"));

        // The vector stays stored; only the live/deleted counters move.
        let stats_after = index.get_stats();
        assert_eq!(stats_after.num_vectors, 5);
        assert_eq!(stats_after.active_vectors, 4);
        assert_eq!(stats_after.deleted_vectors, 1);

        let results = index.search(&[1.0, 0.0, 0.0], 5).unwrap();

        for result in &results {
            assert_ne!(result.entity_id, "doc1");
        }
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_hnsw_delete_nonexistent() {
        let mut index = built_index();

        assert!(!index.remove("nonexistent"));
        assert!(!index.is_deleted("nonexistent"));
    }

    #[test]
    fn test_hnsw_delete_multiple() {
        let mut index = built_index();

        index.remove("doc1");
        index.remove("doc2");
        index.remove("doc3");

        let stats = index.get_stats();
        assert_eq!(stats.active_vectors, 2);
        assert_eq!(stats.deleted_vectors, 3);

        let results = index.search(&[0.5, 0.5, 0.5], 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_hnsw_delete_and_active_count() {
        let mut index = built_index();

        assert_eq!(index.active_count(), 5);
        assert_eq!(index.deleted_count(), 0);

        index.remove("doc1");
        assert_eq!(index.active_count(), 4);
        assert_eq!(index.deleted_count(), 1);

        index.remove("doc2");
        assert_eq!(index.active_count(), 3);
        assert_eq!(index.deleted_count(), 2);
    }
}