1use crate::filter::{Filter, Metadata};
9use crate::simd;
10use crate::types::{DistanceMetric, IndexStats, SearchConfig, SearchResult};
11use anyhow::{anyhow, Result};
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use tracing::{debug, info};
16
/// In-memory vector index that answers nearest-neighbour queries by scoring
/// the query against every stored embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchIndex {
    /// Search settings supplied at construction (`metric`, `normalize`, `parallel`).
    config: SearchConfig,
    /// Embedding stored per entity id.
    embeddings: HashMap<String, Vec<f32>>,
    /// Entity ids in row order: position `i` names row `i` of `embedding_matrix`.
    entity_ids: Vec<String>,
    /// One row per entity, used for scoring; `None` until the index is populated.
    embedding_matrix: Option<Vec<Vec<f32>>>,
    /// Length every query must have; reset to `0` when the index is emptied.
    dimensions: usize,
    /// Whether the index currently holds searchable data.
    is_built: bool,
    /// Optional filterable metadata per entity id.
    metadata: HashMap<String, Metadata>,
}
29
30impl VectorSearchIndex {
31 pub fn new(config: SearchConfig) -> Self {
33 info!(
34 "Initialized vector search index: metric={:?}, parallel={}",
35 config.metric, config.parallel
36 );
37
38 Self {
39 config,
40 embeddings: HashMap::new(),
41 entity_ids: Vec::new(),
42 embedding_matrix: None,
43 dimensions: 0,
44 is_built: false,
45 metadata: HashMap::new(),
46 }
47 }
48
49 pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
51 if embeddings.is_empty() {
52 return Err(anyhow!("Cannot build index from empty embeddings"));
53 }
54
55 info!(
56 "Building vector search index for {} entities",
57 embeddings.len()
58 );
59
60 self.embeddings = embeddings.clone();
62 self.entity_ids = embeddings.keys().cloned().collect();
63 self.dimensions = embeddings.values().next().unwrap().len();
64
65 let mut matrix = Vec::new();
67 for entity_id in &self.entity_ids {
68 let mut emb = self.embeddings[entity_id].clone();
69
70 if self.config.normalize {
72 Self::normalize_vector(&mut emb);
73 }
74
75 matrix.push(emb);
76 }
77 self.embedding_matrix = Some(matrix);
78
79 self.is_built = true;
80
81 info!("Vector search index built successfully");
82 Ok(())
83 }
84
85 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
87 if !self.is_built {
88 return Err(anyhow!("Index not built. Call build() first"));
89 }
90
91 if query.len() != self.dimensions {
92 return Err(anyhow!(
93 "Query dimension {} doesn't match index dimension {}",
94 query.len(),
95 self.dimensions
96 ));
97 }
98
99 let mut normalized_query = query.to_vec();
101 if self.config.normalize {
102 Self::normalize_vector(&mut normalized_query);
103 }
104
105 debug!("Searching for {} nearest neighbors", k);
106 self.exact_search(&normalized_query, k)
107 }
108
109 fn exact_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
111 let matrix = self.embedding_matrix.as_ref().unwrap();
112
113 let scores: Vec<(usize, f32)> = if self.config.parallel {
115 (0..self.entity_ids.len())
116 .into_par_iter()
117 .map(|i| {
118 let score = self.compute_similarity(query, &matrix[i]);
119 (i, score)
120 })
121 .collect()
122 } else {
123 (0..self.entity_ids.len())
124 .map(|i| {
125 let score = self.compute_similarity(query, &matrix[i]);
126 (i, score)
127 })
128 .collect()
129 };
130
131 let mut sorted_scores = scores;
133 sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
134
135 let results: Vec<SearchResult> = sorted_scores
137 .iter()
138 .take(k.min(self.entity_ids.len()))
139 .enumerate()
140 .map(|(rank, &(idx, score))| SearchResult {
141 entity_id: self.entity_ids[idx].clone(),
142 score,
143 distance: self.score_to_distance(score),
144 rank: rank + 1,
145 })
146 .collect();
147
148 debug!("Found {} results", results.len());
149 Ok(results)
150 }
151
152 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
154 if !self.is_built {
155 return Err(anyhow!("Index not built. Call build() first"));
156 }
157
158 info!("Batch searching for {} queries", queries.len());
159
160 let results: Vec<Vec<SearchResult>> = if self.config.parallel {
161 queries
162 .par_iter()
163 .map(|query| self.search(query, k).unwrap_or_default())
164 .collect()
165 } else {
166 queries
167 .iter()
168 .map(|query| self.search(query, k).unwrap_or_default())
169 .collect()
170 };
171
172 Ok(results)
173 }
174
175 #[inline]
179 fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
180 simd::compute_distance_simd(self.config.metric, a, b)
182 }
183
184 #[inline]
186 fn score_to_distance(&self, score: f32) -> f32 {
187 match self.config.metric {
188 DistanceMetric::Cosine => 1.0 - score, DistanceMetric::Euclidean | DistanceMetric::Manhattan => -score, DistanceMetric::DotProduct => -score,
191 }
192 }
193
/// Normalises `vec` in place by delegating to `simd::normalize_vector_simd`.
#[inline]
fn normalize_vector(vec: &mut [f32]) {
    simd::normalize_vector_simd(vec);
}
199
200 pub fn get_stats(&self) -> IndexStats {
202 IndexStats {
203 num_entities: self.entity_ids.len(),
204 dimensions: self.dimensions,
205 is_built: self.is_built,
206 metric: self.config.metric,
207 }
208 }
209
210 pub fn radius_search(&self, query: &[f32], radius: f32) -> Result<Vec<SearchResult>> {
212 if !self.is_built {
213 return Err(anyhow!("Index not built. Call build() first"));
214 }
215
216 let all_results = self.search(query, self.entity_ids.len())?;
217
218 Ok(all_results
219 .into_iter()
220 .filter(|r| r.distance <= radius)
221 .collect())
222 }
223
224 pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
226 self.metadata.insert(entity_id.to_string(), metadata);
227 }
228
229 pub fn set_metadata_batch(&mut self, metadata_map: HashMap<String, Metadata>) {
231 self.metadata.extend(metadata_map);
232 }
233
/// Returns the metadata recorded for `entity_id`, or `None` if none is set.
pub fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
    self.metadata.get(entity_id)
}
238
239 pub fn filtered_search(
244 &self,
245 query: &[f32],
246 k: usize,
247 filter: &Filter,
248 ) -> Result<Vec<SearchResult>> {
249 if !self.is_built {
250 return Err(anyhow!("Index not built. Call build() first"));
251 }
252
253 if filter.is_empty() {
254 return self.search(query, k);
255 }
256
257 debug!(
258 "Filtered search: k={}, filter conditions={}",
259 k,
260 filter.conditions().len()
261 );
262
263 let all_results = self.search(query, self.entity_ids.len())?;
266
267 let filtered: Vec<SearchResult> = all_results
268 .into_iter()
269 .filter(|r| {
270 self.metadata
271 .get(&r.entity_id)
272 .is_some_and(|m| filter.matches(m))
273 })
274 .take(k)
275 .enumerate()
276 .map(|(i, mut r)| {
277 r.rank = i + 1; r
279 })
280 .collect();
281
282 debug!("Filtered search returned {} results", filtered.len());
283 Ok(filtered)
284 }
285
286 pub fn prefiltered_search(
291 &self,
292 query: &[f32],
293 k: usize,
294 filter: &Filter,
295 ) -> Result<Vec<SearchResult>> {
296 if !self.is_built {
297 return Err(anyhow!("Index not built. Call build() first"));
298 }
299
300 if query.len() != self.dimensions {
301 return Err(anyhow!(
302 "Query dimension {} doesn't match index dimension {}",
303 query.len(),
304 self.dimensions
305 ));
306 }
307
308 if filter.is_empty() {
309 return self.search(query, k);
310 }
311
312 debug!("Pre-filtered search: k={}", k);
313
314 let mut normalized_query = query.to_vec();
316 if self.config.normalize {
317 Self::normalize_vector(&mut normalized_query);
318 }
319
320 let matrix = self.embedding_matrix.as_ref().unwrap();
321
322 let matching_indices: Vec<usize> = (0..self.entity_ids.len())
324 .filter(|&i| {
325 self.metadata
326 .get(&self.entity_ids[i])
327 .is_some_and(|m| filter.matches(m))
328 })
329 .collect();
330
331 if matching_indices.is_empty() {
332 return Ok(Vec::new());
333 }
334
335 let scores: Vec<(usize, f32)> = if self.config.parallel {
337 matching_indices
338 .par_iter()
339 .map(|&i| {
340 let score = self.compute_similarity(&normalized_query, &matrix[i]);
341 (i, score)
342 })
343 .collect()
344 } else {
345 matching_indices
346 .iter()
347 .map(|&i| {
348 let score = self.compute_similarity(&normalized_query, &matrix[i]);
349 (i, score)
350 })
351 .collect()
352 };
353
354 let mut sorted_scores = scores;
356 sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
357
358 let results: Vec<SearchResult> = sorted_scores
360 .iter()
361 .take(k)
362 .enumerate()
363 .map(|(rank, &(idx, score))| SearchResult {
364 entity_id: self.entity_ids[idx].clone(),
365 score,
366 distance: self.score_to_distance(score),
367 rank: rank + 1,
368 })
369 .collect();
370
371 debug!("Pre-filtered search returned {} results", results.len());
372 Ok(results)
373 }
374
375 pub fn add_vector(&mut self, entity_id: String, mut embedding: Vec<f32>) -> Result<()> {
380 if self.is_built && embedding.len() != self.dimensions {
381 return Err(anyhow!(
382 "Vector dimension {} doesn't match index dimension {}",
383 embedding.len(),
384 self.dimensions
385 ));
386 }
387
388 if self.embeddings.contains_key(&entity_id) {
390 return Err(anyhow!("Entity '{}' already exists in index", entity_id));
391 }
392
393 if !self.is_built {
395 self.dimensions = embedding.len();
396 }
397
398 if self.config.normalize {
400 Self::normalize_vector(&mut embedding);
401 }
402
403 self.embeddings.insert(entity_id.clone(), embedding.clone());
405
406 self.entity_ids.push(entity_id);
408
409 if let Some(ref mut matrix) = self.embedding_matrix {
411 matrix.push(embedding);
412 } else {
413 self.embedding_matrix = Some(vec![embedding]);
414 }
415
416 self.is_built = true;
417 debug!("Added vector to index (total: {})", self.entity_ids.len());
418 Ok(())
419 }
420
421 pub fn add_vectors(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
425 if embeddings.is_empty() {
426 return Ok(());
427 }
428
429 info!("Adding {} vectors to index", embeddings.len());
430
431 for (entity_id, embedding) in embeddings {
433 if self.is_built && embedding.len() != self.dimensions {
434 return Err(anyhow!(
435 "Vector dimension {} doesn't match index dimension {}",
436 embedding.len(),
437 self.dimensions
438 ));
439 }
440 if self.embeddings.contains_key(entity_id) {
441 return Err(anyhow!("Entity '{}' already exists in index", entity_id));
442 }
443 }
444
445 if !self.is_built {
447 self.dimensions = embeddings.values().next().unwrap().len();
448 }
449
450 for (entity_id, embedding) in embeddings {
452 let mut emb = embedding.clone();
453 if self.config.normalize {
454 Self::normalize_vector(&mut emb);
455 }
456
457 self.embeddings.insert(entity_id.clone(), emb.clone());
458 self.entity_ids.push(entity_id.clone());
459
460 if let Some(ref mut matrix) = self.embedding_matrix {
461 matrix.push(emb);
462 } else {
463 self.embedding_matrix = Some(vec![emb]);
464 }
465 }
466
467 self.is_built = true;
468 info!(
469 "Added vectors successfully (total: {})",
470 self.entity_ids.len()
471 );
472 Ok(())
473 }
474
475 pub fn remove_vector(&mut self, entity_id: &str) -> Result<()> {
477 if !self.embeddings.contains_key(entity_id) {
478 return Err(anyhow!("Entity '{}' not found in index", entity_id));
479 }
480
481 let idx = self
483 .entity_ids
484 .iter()
485 .position(|id| id == entity_id)
486 .ok_or_else(|| anyhow!("Entity '{}' not found in entity_ids", entity_id))?;
487
488 self.embeddings.remove(entity_id);
490
491 self.entity_ids.remove(idx);
493
494 if let Some(ref mut matrix) = self.embedding_matrix {
496 matrix.remove(idx);
497 }
498
499 self.metadata.remove(entity_id);
501
502 if self.embeddings.is_empty() {
504 self.is_built = false;
505 self.dimensions = 0;
506 }
507
508 debug!(
509 "Removed vector from index (remaining: {})",
510 self.entity_ids.len()
511 );
512 Ok(())
513 }
514
515 pub fn remove_vectors(&mut self, entity_ids: &[&str]) -> Result<()> {
517 info!("Removing {} vectors from index", entity_ids.len());
518
519 for entity_id in entity_ids {
520 self.remove_vector(entity_id)?;
521 }
522
523 info!(
524 "Removed vectors successfully (remaining: {})",
525 self.entity_ids.len()
526 );
527 Ok(())
528 }
529
530 pub fn update_vector(&mut self, entity_id: &str, mut new_embedding: Vec<f32>) -> Result<()> {
532 if !self.embeddings.contains_key(entity_id) {
533 return Err(anyhow!("Entity '{}' not found in index", entity_id));
534 }
535
536 if new_embedding.len() != self.dimensions {
537 return Err(anyhow!(
538 "Vector dimension {} doesn't match index dimension {}",
539 new_embedding.len(),
540 self.dimensions
541 ));
542 }
543
544 if self.config.normalize {
546 Self::normalize_vector(&mut new_embedding);
547 }
548
549 let idx = self
551 .entity_ids
552 .iter()
553 .position(|id| id == entity_id)
554 .ok_or_else(|| anyhow!("Entity '{}' not found in entity_ids", entity_id))?;
555
556 self.embeddings
558 .insert(entity_id.to_string(), new_embedding.clone());
559
560 if let Some(ref mut matrix) = self.embedding_matrix {
562 matrix[idx] = new_embedding;
563 }
564
565 debug!("Updated vector in index: {}", entity_id);
566 Ok(())
567 }
568
569 pub fn clear(&mut self) {
571 self.embeddings.clear();
572 self.entity_ids.clear();
573 self.embedding_matrix = None;
574 self.metadata.clear();
575 self.dimensions = 0;
576 self.is_built = false;
577 info!("Index cleared");
578 }
579
/// Returns the number of entities currently indexed.
#[inline]
pub fn len(&self) -> usize {
    self.entity_ids.len()
}
585
/// Returns `true` when no entities are indexed.
#[inline]
pub fn is_empty(&self) -> bool {
    self.entity_ids.is_empty()
}
591
/// Returns `true` when an embedding is stored under `entity_id`.
#[inline]
pub fn contains(&self, entity_id: &str) -> bool {
    self.embeddings.contains_key(entity_id)
}
597
/// Returns the embedding stored under `entity_id`, if any.
///
/// Vectors inserted through `add_vector`, `add_vectors` or `update_vector`
/// are stored after normalisation when `config.normalize` is set, whereas
/// `build` stores the caller's vectors unchanged.
#[inline]
pub fn get_vector(&self, entity_id: &str) -> Option<&Vec<f32>> {
    self.embeddings.get(entity_id)
}
603
604 pub fn merge(&mut self, other: &VectorSearchIndex, overwrite_duplicates: bool) -> Result<()> {
609 if !other.is_built {
610 return Err(anyhow!("Cannot merge from an unbuilt index"));
611 }
612
613 if !self.is_built && other.is_built {
615 self.dimensions = other.dimensions;
616 }
617
618 if self.is_built && other.is_built && self.dimensions != other.dimensions {
620 return Err(anyhow!(
621 "Cannot merge indexes with different dimensions: {} vs {}",
622 self.dimensions,
623 other.dimensions
624 ));
625 }
626
627 info!("Merging index with {} vectors", other.entity_ids.len());
628
629 let mut added = 0;
630 let mut updated = 0;
631 let mut skipped = 0;
632
633 for entity_id in &other.entity_ids {
634 let embedding = &other.embeddings[entity_id];
635
636 if self.embeddings.contains_key(entity_id) {
637 if overwrite_duplicates {
638 self.update_vector(entity_id, embedding.clone())?;
640 updated += 1;
641 } else {
642 skipped += 1;
644 }
645 } else {
646 self.add_vector(entity_id.clone(), embedding.clone())?;
648 added += 1;
649 }
650
651 if let Some(metadata) = other.metadata.get(entity_id) {
653 self.metadata.insert(entity_id.clone(), metadata.clone());
654 }
655 }
656
657 info!(
658 "Merge complete: added={}, updated={}, skipped={}",
659 added, updated, skipped
660 );
661 Ok(())
662 }
663
664 pub fn merge_multiple(indexes: &[&VectorSearchIndex]) -> Result<VectorSearchIndex> {
668 if indexes.is_empty() {
669 return Err(anyhow!("Cannot merge zero indexes"));
670 }
671
672 let first_built = indexes
674 .iter()
675 .find(|idx| idx.is_built)
676 .ok_or_else(|| anyhow!("At least one index must be built"))?;
677
678 let dimensions = first_built.dimensions;
679 let config = first_built.config.clone();
680
681 for (i, index) in indexes.iter().enumerate() {
683 if index.is_built && index.dimensions != dimensions {
684 return Err(anyhow!(
685 "Index {} has incompatible dimensions: {} vs {}",
686 i,
687 index.dimensions,
688 dimensions
689 ));
690 }
691 }
692
693 info!("Merging {} indexes into one", indexes.len());
694
695 let mut all_embeddings = HashMap::new();
697 let mut all_metadata = HashMap::new();
698
699 for index in indexes {
700 for (entity_id, embedding) in &index.embeddings {
701 all_embeddings.insert(entity_id.clone(), embedding.clone());
703 }
704
705 for (entity_id, metadata) in &index.metadata {
706 all_metadata.insert(entity_id.clone(), metadata.clone());
707 }
708 }
709
710 let mut merged = VectorSearchIndex::new(config);
712 merged.build(&all_embeddings)?;
713 merged.metadata = all_metadata;
714
715 info!("Merged index contains {} vectors", merged.len());
716 Ok(merged)
717 }
718}
719
720#[cfg(test)]
721mod tests {
722 use super::*;
723 use crate::filter::FilterValue;
724
/// Five three-dimensional fixture vectors keyed `entity1`..`entity5`.
fn create_test_embeddings() -> HashMap<String, Vec<f32>> {
    let fixtures: [(&str, [f32; 3]); 5] = [
        ("entity1", [1.0, 0.0, 0.0]),
        ("entity2", [0.9, 0.1, 0.0]),
        ("entity3", [0.0, 1.0, 0.0]),
        ("entity4", [0.0, 0.0, 1.0]),
        ("entity5", [0.7, 0.7, 0.0]),
    ];

    fixtures
        .iter()
        .map(|(entity_id, values)| (entity_id.to_string(), values.to_vec()))
        .collect()
}
737
738 fn create_test_metadata() -> HashMap<String, Metadata> {
739 let mut metadata = HashMap::new();
740
741 let mut m1 = HashMap::new();
742 m1.insert(
743 "type".to_string(),
744 FilterValue::String("article".to_string()),
745 );
746 m1.insert("year".to_string(), FilterValue::Int(2023));
747 metadata.insert("entity1".to_string(), m1);
748
749 let mut m2 = HashMap::new();
750 m2.insert(
751 "type".to_string(),
752 FilterValue::String("article".to_string()),
753 );
754 m2.insert("year".to_string(), FilterValue::Int(2022));
755 metadata.insert("entity2".to_string(), m2);
756
757 let mut m3 = HashMap::new();
758 m3.insert("type".to_string(), FilterValue::String("book".to_string()));
759 m3.insert("year".to_string(), FilterValue::Int(2023));
760 metadata.insert("entity3".to_string(), m3);
761
762 let mut m4 = HashMap::new();
763 m4.insert("type".to_string(), FilterValue::String("book".to_string()));
764 m4.insert("year".to_string(), FilterValue::Int(2021));
765 metadata.insert("entity4".to_string(), m4);
766
767 let mut m5 = HashMap::new();
768 m5.insert(
769 "type".to_string(),
770 FilterValue::String("article".to_string()),
771 );
772 m5.insert("year".to_string(), FilterValue::Int(2024));
773 metadata.insert("entity5".to_string(), m5);
774
775 metadata
776 }
777
/// A freshly constructed index is unbuilt and has no dimension.
#[test]
fn test_index_creation() {
    let fresh = VectorSearchIndex::new(SearchConfig::default());

    assert!(!fresh.is_built);
    assert_eq!(fresh.dimensions, 0);
}
786
/// Building from the fixtures records the dimension and all five ids.
#[test]
fn test_index_building() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    let fixtures = create_test_embeddings();

    let outcome = index.build(&fixtures);

    assert!(outcome.is_ok());
    assert!(index.is_built);
    assert_eq!(index.dimensions, 3);
    assert_eq!(index.entity_ids.len(), 5);
}
797
/// A query along the first axis returns three hits led by one of the two
/// vectors closest to that axis.
#[test]
fn test_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let hits = index.search(&[1.0, 0.0, 0.0], 3).unwrap();

    assert_eq!(hits.len(), 3);
    let top = &hits[0].entity_id;
    assert!(top == "entity1" || top == "entity2");
}
812
/// Two queries yield two result lists of `k` entries each.
#[test]
fn test_batch_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let batch = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
    let per_query = index.batch_search(&batch, 2).unwrap();

    assert_eq!(per_query.len(), 2);
    assert_eq!(per_query[0].len(), 2);
    assert_eq!(per_query[1].len(), 2);
}
826
/// Stats mirror the built index: five entities, three dimensions, cosine.
#[test]
fn test_get_stats() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let snapshot = index.get_stats();

    assert_eq!(snapshot.num_entities, 5);
    assert_eq!(snapshot.dimensions, 3);
    assert!(snapshot.is_built);
    assert_eq!(snapshot.metric, DistanceMetric::Cosine);
}
839
/// Metadata stored for an entity can be read back unchanged.
#[test]
fn test_set_and_get_metadata() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let mut fields = HashMap::new();
    fields.insert(
        "type".to_string(),
        FilterValue::String("article".to_string()),
    );
    index.set_metadata("entity1", fields);

    let stored = index.get_metadata("entity1");
    assert!(stored.is_some());
    let expected = FilterValue::String("article".to_string());
    assert_eq!(stored.unwrap().get("type"), Some(&expected));
}
861
/// Filtering on `type == article` returns exactly the three articles.
#[test]
fn test_filtered_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let only_articles = Filter::new().eq("type", "article");
    let hits = index
        .filtered_search(&[1.0, 0.0, 0.0], 5, &only_articles)
        .unwrap();

    assert_eq!(hits.len(), 3);
    let expected = FilterValue::String("article".to_string());
    for hit in hits.iter() {
        let meta = index.get_metadata(&hit.entity_id).unwrap();
        assert_eq!(meta.get("type"), Some(&expected));
    }
}
885
/// Filtering on `year >= 2023` keeps the three fixtures from 2023 onwards.
#[test]
fn test_filtered_search_with_year() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let recent = Filter::new().gte("year", 2023i64);
    let hits = index.filtered_search(&[1.0, 0.0, 0.0], 5, &recent).unwrap();

    assert_eq!(hits.len(), 3);
}
902
/// Pre-filtering on `type == book` scores and returns only the two books.
#[test]
fn test_prefiltered_search() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let only_books = Filter::new().eq("type", "book");
    let hits = index
        .prefiltered_search(&[0.0, 1.0, 0.0], 5, &only_books)
        .unwrap();

    assert_eq!(hits.len(), 2);
    let expected = FilterValue::String("book".to_string());
    for hit in hits.iter() {
        let meta = index.get_metadata(&hit.entity_id).unwrap();
        assert_eq!(meta.get("type"), Some(&expected));
    }
}
926
/// An empty filter behaves like a plain search.
#[test]
fn test_filtered_search_empty_filter() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let no_conditions = Filter::new();
    let hits = index
        .filtered_search(&[1.0, 0.0, 0.0], 3, &no_conditions)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
940
/// A filter that no fixture satisfies produces an empty result list.
#[test]
fn test_filtered_search_no_matches() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    index.set_metadata_batch(create_test_metadata());

    let unmatched = Filter::new().eq("type", "journal");
    let hits = index
        .filtered_search(&[1.0, 0.0, 0.0], 5, &unmatched)
        .unwrap();

    assert_eq!(hits.len(), 0);
}
956
/// A newly added vector grows the index and is its own nearest neighbour.
#[test]
fn test_add_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let outcome = index.add_vector("entity6".to_string(), vec![0.5, 0.5, 0.5]);
    assert!(outcome.is_ok());
    assert_eq!(index.len(), before + 1);
    assert!(index.contains("entity6"));

    let nearest = index.search(&[0.5, 0.5, 0.5], 1).unwrap();
    assert_eq!(nearest[0].entity_id, "entity6");
}
976
/// Adding an id that already exists is rejected.
#[test]
fn test_add_vector_duplicate() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.add_vector("entity1".to_string(), vec![0.5, 0.5, 0.5]);

    assert!(outcome.is_err());
}
987
/// A two-component vector cannot be added to a three-dimensional index.
#[test]
fn test_add_vector_dimension_mismatch() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.add_vector("entity6".to_string(), vec![0.5, 0.5]);

    assert!(outcome.is_err());
}
998
/// Adding a batch of two new ids grows the index by two.
#[test]
fn test_add_vectors() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let extra: HashMap<String, Vec<f32>> = [
        ("entity6".to_string(), vec![0.5, 0.5, 0.5]),
        ("entity7".to_string(), vec![0.6, 0.6, 0.6]),
    ]
    .into_iter()
    .collect();

    let outcome = index.add_vectors(&extra);
    assert!(outcome.is_ok());
    assert_eq!(index.len(), before + 2);
    assert!(index.contains("entity6"));
    assert!(index.contains("entity7"));
}
1018
/// A removed entity disappears from the index and from search results.
#[test]
fn test_remove_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let outcome = index.remove_vector("entity1");
    assert!(outcome.is_ok());
    assert_eq!(index.len(), before - 1);
    assert!(!index.contains("entity1"));

    let hits = index.search(&[1.0, 0.0, 0.0], 5).unwrap();
    assert!(!hits.iter().any(|hit| hit.entity_id == "entity1"));
}
1038
/// Removing an unknown id is an error.
#[test]
fn test_remove_vector_not_found() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.remove_vector("nonexistent");

    assert!(outcome.is_err());
}
1049
/// Removing two ids in one call shrinks the index by two.
#[test]
fn test_remove_vectors() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();
    let before = index.len();

    let doomed = ["entity1", "entity2"];
    let outcome = index.remove_vectors(&doomed);

    assert!(outcome.is_ok());
    assert_eq!(index.len(), before - 2);
    assert!(!index.contains("entity1"));
    assert!(!index.contains("entity2"));
}
1065
/// After an update the stored vector equals the normalised replacement.
#[test]
fn test_update_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let replacement = vec![0.9, 0.9, 0.9];
    let outcome = index.update_vector("entity1", replacement.clone());
    assert!(outcome.is_ok());

    let stored = index.get_vector("entity1").unwrap();
    let mut expected = replacement;
    VectorSearchIndex::normalize_vector(&mut expected);

    assert_eq!(stored.len(), expected.len());
    for (got, want) in stored.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-6);
    }
}
1087
/// Updating an unknown id is an error.
#[test]
fn test_update_vector_not_found() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.update_vector("nonexistent", vec![0.5, 0.5, 0.5]);

    assert!(outcome.is_err());
}
1098
/// Updating with a vector of the wrong length is an error.
#[test]
fn test_update_vector_dimension_mismatch() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let outcome = index.update_vector("entity1", vec![0.5, 0.5]);

    assert!(outcome.is_err());
}
1109
/// `clear` returns a built index to the empty, unbuilt state.
#[test]
fn test_clear() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    assert!(!index.is_empty());
    assert!(index.is_built);

    index.clear();

    assert_eq!(index.len(), 0);
    assert!(index.is_empty());
    assert!(!index.is_built);
    assert_eq!(index.dimensions, 0);
}
1126
/// `get_vector` finds known ids and reports unknown ones as `None`.
#[test]
fn test_get_vector() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());
    index.build(&create_test_embeddings()).unwrap();

    let known = index.get_vector("entity1");
    assert!(known.is_some());

    let unknown = index.get_vector("nonexistent");
    assert!(unknown.is_none());
}
1141
/// An index can be populated purely through `add_vector` and then searched.
#[test]
fn test_incremental_build() {
    let mut index = VectorSearchIndex::new(SearchConfig::default());

    let axes = [
        ("entity1", vec![1.0, 0.0, 0.0]),
        ("entity2", vec![0.0, 1.0, 0.0]),
        ("entity3", vec![0.0, 0.0, 1.0]),
    ];
    for (entity_id, axis) in axes {
        index.add_vector(entity_id.to_string(), axis).unwrap();
    }

    assert_eq!(index.len(), 3);
    assert!(index.is_built);
    assert_eq!(index.dimensions, 3);

    let nearest = index.search(&[1.0, 0.0, 0.0], 1).unwrap();
    assert_eq!(nearest[0].entity_id, "entity1");
}
1166
/// Merging an index with two new ids adds both to the target.
#[test]
fn test_merge_indexes() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut incoming = HashMap::new();
    incoming.insert("entity6".to_string(), vec![0.6, 0.6, 0.0]);
    incoming.insert("entity7".to_string(), vec![0.7, 0.7, 0.0]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&incoming).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, false);

    assert!(outcome.is_ok());
    assert_eq!(target.len(), before + 2);
    assert!(target.contains("entity6"));
    assert!(target.contains("entity7"));
}
1190
/// With overwriting disabled, a duplicate id is skipped and only the new
/// id is added.
#[test]
fn test_merge_with_duplicates_skip() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut incoming = HashMap::new();
    incoming.insert("entity1".to_string(), vec![0.9, 0.9, 0.9]);
    incoming.insert("entity6".to_string(), vec![0.6, 0.6, 0.0]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&incoming).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, false);

    assert!(outcome.is_ok());
    assert_eq!(target.len(), before + 1);
}
1212
/// With overwriting enabled, a duplicate id keeps the index size but takes
/// the (normalised) incoming vector.
#[test]
fn test_merge_with_duplicates_overwrite() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut incoming = HashMap::new();
    incoming.insert("entity1".to_string(), vec![0.9, 0.9, 0.9]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&incoming).unwrap();

    let before = target.len();
    let outcome = target.merge(&source, true);

    assert!(outcome.is_ok());
    assert_eq!(target.len(), before);

    let stored = target.get_vector("entity1").unwrap();
    let mut expected = vec![0.9, 0.9, 0.9];
    VectorSearchIndex::normalize_vector(&mut expected);
    for (got, want) in stored.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-6);
    }
}
1242
/// Merging a two-dimensional index into a three-dimensional one fails.
#[test]
fn test_merge_dimension_mismatch() {
    let mut target = VectorSearchIndex::new(SearchConfig::default());
    target.build(&create_test_embeddings()).unwrap();

    let mut incoming = HashMap::new();
    incoming.insert("entity6".to_string(), vec![0.6, 0.6]);
    let mut source = VectorSearchIndex::new(SearchConfig::default());
    source.build(&incoming).unwrap();

    let outcome = target.merge(&source, false);

    assert!(outcome.is_err());
}
1260
/// Three disjoint indexes merge into one searchable index of five docs.
#[test]
fn test_merge_multiple_indexes() {
    let mut first = VectorSearchIndex::new(SearchConfig::default());
    let mut first_docs = HashMap::new();
    first_docs.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
    first_docs.insert("doc2".to_string(), vec![0.9, 0.1, 0.0]);
    first.build(&first_docs).unwrap();

    let mut second = VectorSearchIndex::new(SearchConfig::default());
    let mut second_docs = HashMap::new();
    second_docs.insert("doc3".to_string(), vec![0.0, 1.0, 0.0]);
    second_docs.insert("doc4".to_string(), vec![0.1, 0.9, 0.0]);
    second.build(&second_docs).unwrap();

    let mut third = VectorSearchIndex::new(SearchConfig::default());
    let mut third_docs = HashMap::new();
    third_docs.insert("doc5".to_string(), vec![0.0, 0.0, 1.0]);
    third.build(&third_docs).unwrap();

    let outcome = VectorSearchIndex::merge_multiple(&[&first, &second, &third]);
    assert!(outcome.is_ok());

    let combined = outcome.unwrap();
    assert_eq!(combined.len(), 5);
    assert!(combined.contains("doc1"));
    assert!(combined.contains("doc3"));
    assert!(combined.contains("doc5"));

    let hits = combined.search(&[1.0, 0.0, 0.0], 2).unwrap();
    assert_eq!(hits.len(), 2);
}
1296
/// Calling `merge_multiple` with no indexes at all must return an error.
#[test]
fn test_merge_multiple_empty() {
    // Explicitly typed empty array; it coerces to the empty slice argument.
    let no_indexes: [&VectorSearchIndex; 0] = [];
    let result = VectorSearchIndex::merge_multiple(&no_indexes);
    assert!(result.is_err());
}
1302
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing vectors of exactly `dim` components, each drawn
    /// from the half-open range `-1.0..1.0`.
    fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
        let component = -1.0f32..1.0f32;
        proptest::collection::vec(component, dim..=dim)
    }

    /// Strategy producing an id -> vector map from `count` generated pairs of
    /// `dim`-component vectors.
    ///
    /// Ids match `[a-z0-9]{5,10}`. Pairs are collected into a `HashMap`, so
    /// any repeated id collapses into a single entry.
    fn embeddings_strategy(
        count: usize,
        dim: usize,
    ) -> impl Strategy<Value = HashMap<String, Vec<f32>>> {
        let entity_id = proptest::string::string_regex("[a-z0-9]{5,10}").unwrap();
        let entry = (entity_id, vector_strategy(dim));
        proptest::collection::vec(entry, count..=count).prop_map(
            |entries: Vec<(String, Vec<f32>)>| entries.into_iter().collect::<HashMap<_, _>>(),
        )
    }

    /// Creates a default-configured index and builds it from `embeddings`,
    /// panicking if the build fails.
    fn built_index(embeddings: &HashMap<String, Vec<f32>>) -> VectorSearchIndex {
        let mut index = VectorSearchIndex::new(SearchConfig::default());
        index.build(embeddings).unwrap();
        index
    }

    proptest! {
        /// Searching a built index never panics, whatever the query and `k`.
        #[test]
        fn prop_search_never_panics(
            embeddings in embeddings_strategy(10, 128),
            query in vector_strategy(128),
            k in 1usize..10
        ) {
            let index = built_index(&embeddings);
            // Only the absence of a panic matters; the outcome is discarded.
            let _ = index.search(&query, k);
        }

        /// A search returns at most `k` hits and never more than the number
        /// of stored embeddings.
        #[test]
        fn prop_search_respects_k(
            embeddings in embeddings_strategy(20, 64),
            query in vector_strategy(64),
            k in 1usize..15
        ) {
            let index = built_index(&embeddings);
            let hits = index.search(&query, k).unwrap();
            let found = hits.len();
            prop_assert!(found <= k);
            prop_assert!(found <= embeddings.len());
        }

        /// Hits come back ordered by non-increasing score.
        #[test]
        fn prop_search_results_sorted(
            embeddings in embeddings_strategy(15, 32),
            query in vector_strategy(32),
            k in 2usize..10
        ) {
            let index = built_index(&embeddings);
            let hits = index.search(&query, k).unwrap();

            // Compare every adjacent pair of hits.
            for pair in hits.windows(2) {
                prop_assert!(pair[0].score >= pair[1].score,
                    "Results not sorted: {} < {}", pair[0].score, pair[1].score);
            }
        }

        /// Ranks are numbered 1, 2, 3, ... in result order.
        #[test]
        fn prop_search_ranks_consecutive(
            embeddings in embeddings_strategy(10, 16),
            query in vector_strategy(16),
            k in 1usize..8
        ) {
            let index = built_index(&embeddings);
            let hits = index.search(&query, k).unwrap();

            for (expected_rank, hit) in (1usize..).zip(hits.iter()) {
                prop_assert_eq!(hit.rank, expected_rank);
            }
        }

        /// Repeating the same search on the same index gives the same ids
        /// and (within 1e-6) the same scores.
        #[test]
        fn prop_search_deterministic(
            embeddings in embeddings_strategy(12, 48),
            query in vector_strategy(48),
            k in 1usize..10
        ) {
            let index = built_index(&embeddings);

            let first_run = index.search(&query, k).unwrap();
            let second_run = index.search(&query, k).unwrap();

            prop_assert_eq!(first_run.len(), second_run.len());
            for (lhs, rhs) in first_run.iter().zip(second_run.iter()) {
                prop_assert_eq!(&lhs.entity_id, &rhs.entity_id);
                prop_assert!((lhs.score - rhs.score).abs() < 1e-6);
            }
        }

        /// A batch search returns one result set per query, each holding at
        /// most `k` hits.
        #[test]
        fn prop_batch_search_count(
            embeddings in embeddings_strategy(10, 32),
            num_queries in 1usize..5,
            k in 1usize..8
        ) {
            let index = built_index(&embeddings);

            // Query `i` is the constant vector whose 32 components all equal `i`.
            let mut queries: Vec<Vec<f32>> = Vec::with_capacity(num_queries);
            for i in 0..num_queries {
                queries.push(vec![i as f32; 32]);
            }

            let result_sets = index.batch_search(&queries, k).unwrap();
            prop_assert_eq!(result_sets.len(), num_queries);

            for result_set in result_sets.iter() {
                prop_assert!(result_set.len() <= k);
            }
        }

        /// Every reported distance is non-negative.
        #[test]
        fn prop_distance_non_negative(
            embeddings in embeddings_strategy(8, 24),
            query in vector_strategy(24),
            k in 1usize..6
        ) {
            let index = built_index(&embeddings);
            let hits = index.search(&query, k).unwrap();

            for hit in hits.iter() {
                prop_assert!(hit.distance >= 0.0,
                    "Negative distance: {}", hit.distance);
            }
        }

        /// Building from an empty embedding map is always an error; the
        /// generated `_dim` is not used by the body.
        #[test]
        fn prop_empty_embeddings_fail(_dim in 1usize..128) {
            let no_embeddings = HashMap::<String, Vec<f32>>::new();
            let mut index = VectorSearchIndex::new(SearchConfig::default());
            let outcome = index.build(&no_embeddings);
            prop_assert!(outcome.is_err());
        }

        /// A query whose length differs from the index dimensionality (64)
        /// is rejected.
        #[test]
        fn prop_dimension_mismatch_fail(
            embeddings in embeddings_strategy(5, 64),
            wrong_dim in 1usize..128
        ) {
            // Discard the one generated length that would actually match.
            prop_assume!(wrong_dim != 64);
            let index = built_index(&embeddings);

            let mismatched_query = vec![0.0_f32; wrong_dim];
            let outcome = index.search(&mismatched_query, 3);
            prop_assert!(outcome.is_err());
        }

        /// After `normalize_vector`, the Euclidean norm is 1 within 1e-5.
        #[test]
        fn prop_normalize_unit_norm(mut vec in vector_strategy(128)) {
            VectorSearchIndex::normalize_vector(&mut vec);
            let squared_sum = vec.iter().map(|x| x * x).sum::<f32>();
            let norm: f32 = squared_sum.sqrt();
            prop_assert!((norm - 1.0).abs() < 1e-5, "Norm: {}", norm);
        }
    }
}
1481}