1use crate::{Chunk, ChunkId, Error, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6
/// Embedding dimensionality used by `VectorStoreConfig::default()`.
const DEFAULT_EMBEDDING_DIM: usize = 384;
9
10pub trait SparseIndex: Send + Sync {
12 fn add(&mut self, chunk: &Chunk);
14
15 fn add_batch(&mut self, chunks: &[Chunk]);
17
18 fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)>;
20
21 fn remove(&mut self, chunk_id: ChunkId);
23
24 fn len(&self) -> usize;
26
27 fn is_empty(&self) -> bool {
29 self.len() == 0
30 }
31}
32
/// Okapi BM25 keyword index over chunk text.
///
/// Postings, document frequencies and document lengths are kept in hash maps and are
/// updated incrementally as chunks are added and removed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// Term -> postings list of `(chunk id, occurrences of the term in that chunk)`.
    inverted_index: HashMap<String, Vec<(ChunkId, u32)>>,
    /// Term -> number of indexed chunks containing the term.
    doc_freqs: HashMap<String, u32>,
    /// Chunk id -> number of tokens the chunk produced when it was indexed.
    doc_lengths: HashMap<ChunkId, u32>,
    /// Mean token count over indexed chunks; recomputed after every add/remove.
    avg_doc_length: f32,
    /// Number of indexed chunks (the `N` in the IDF formula).
    doc_count: u32,
    /// Term-frequency saturation parameter (`k1` in the BM25 formula; `new()` uses 1.2).
    k1: f32,
    /// Document-length normalization weight (`b` in the BM25 formula; `new()` uses 0.75).
    b: f32,
    /// Whether the built-in tokenizer lowercases tokens.
    lowercase: bool,
    /// Tokens the built-in tokenizer discards.
    stopwords: HashSet<String>,
    /// Optional tokenizer that replaces the built-in one entirely: when set, lowercasing,
    /// stopword removal and the minimum-length filter are not applied.
    ///
    /// Skipped by serde, so a deserialized index comes back with `None` here and falls back
    /// to the built-in tokenizer unless one is re-attached with `with_tokenizer`.
    #[serde(skip, default)]
    custom_tokenizer: Option<std::sync::Arc<dyn crate::tokenizer::Tokenizer>>,
}
69
70impl Default for BM25Index {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl BM25Index {
77 #[must_use]
79 pub fn new() -> Self {
80 Self {
81 inverted_index: HashMap::new(),
82 doc_freqs: HashMap::new(),
83 doc_lengths: HashMap::new(),
84 avg_doc_length: 0.0,
85 doc_count: 0,
86 k1: 1.2,
87 b: 0.75,
88 lowercase: true,
89 stopwords: Self::default_stopwords(),
90 custom_tokenizer: None,
91 }
92 }
93
94 #[must_use]
96 pub fn with_params(k1: f32, b: f32) -> Self {
97 Self { k1, b, ..Self::new() }
98 }
99
100 #[must_use]
114 pub fn with_tokenizer(
115 mut self,
116 tokenizer: std::sync::Arc<dyn crate::tokenizer::Tokenizer>,
117 ) -> Self {
118 self.custom_tokenizer = Some(tokenizer);
119 self
120 }
121
122 #[must_use]
126 pub fn has_custom_tokenizer(&self) -> bool {
127 self.custom_tokenizer.is_some()
128 }
129
130 #[must_use]
136 pub fn indexed_terms(&self) -> Vec<&str> {
137 self.inverted_index.keys().map(String::as_str).collect()
138 }
139
140 #[must_use]
142 pub fn with_stopwords(mut self, stopwords: HashSet<String>) -> Self {
143 self.stopwords = stopwords;
144 self
145 }
146
147 fn default_stopwords() -> HashSet<String> {
148 [
149 "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
150 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
151 "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
152 "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
153 "below", "between", "under", "again", "further", "then", "once", "here", "there",
154 "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
155 "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
156 "and", "but", "if", "or", "because", "until", "while", "this", "that", "these",
157 "those", "it", "its",
158 ]
159 .iter()
160 .map(|s| (*s).to_string())
161 .collect()
162 }
163
164 pub fn tokenize(&self, text: &str) -> Vec<String> {
169 if let Some(tok) = self.custom_tokenizer.as_ref() {
170 return tok.tokenize(text);
171 }
172 text.split(|c: char| !c.is_alphanumeric())
173 .filter(|s| !s.is_empty())
174 .map(|s| if self.lowercase { s.to_lowercase() } else { s.to_string() })
175 .filter(|s| !self.stopwords.contains(s))
176 .filter(|s| s.len() >= 2) .collect()
178 }
179
180 fn term_frequency(&self, term: &str, chunk_id: ChunkId) -> u32 {
182 self.inverted_index
183 .get(term)
184 .and_then(|postings| postings.iter().find(|(id, _)| *id == chunk_id))
185 .map(|(_, freq)| *freq)
186 .unwrap_or(0)
187 }
188
189 fn score_term(&self, term: &str, chunk_id: ChunkId) -> f32 {
191 let tf = self.term_frequency(term, chunk_id) as f32;
192 if tf == 0.0 {
193 return 0.0;
194 }
195
196 let df = self.doc_freqs.get(term).copied().unwrap_or(0) as f32;
197 let n = self.doc_count as f32;
198 let doc_len = self.doc_lengths.get(&chunk_id).copied().unwrap_or(0) as f32;
199
200 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).max(f32::EPSILON).ln();
202
203 let tf_norm = (tf * (self.k1 + 1.0))
205 / (tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length));
206
207 idf * tf_norm
208 }
209
210 fn update_avg_doc_length(&mut self) {
212 if self.doc_count == 0 {
213 self.avg_doc_length = 0.0;
214 } else {
215 let total: u32 = self.doc_lengths.values().sum();
216 self.avg_doc_length = total as f32 / self.doc_count as f32;
217 }
218 }
219
220 fn get_chunks_for_term(&self, term: &str) -> Vec<ChunkId> {
222 self.inverted_index
223 .get(term)
224 .map(|postings| postings.iter().map(|(id, _)| *id).collect())
225 .unwrap_or_default()
226 }
227}
228
229impl SparseIndex for BM25Index {
230 fn add(&mut self, chunk: &Chunk) {
231 let tokens = self.tokenize(&chunk.content);
232 let doc_len = tokens.len() as u32;
233
234 self.doc_lengths.insert(chunk.id, doc_len);
236 self.doc_count += 1;
237
238 let mut term_freqs: HashMap<String, u32> = HashMap::new();
240 for token in &tokens {
241 *term_freqs.entry(token.clone()).or_insert(0) += 1;
242 }
243
244 let mut seen_terms: HashSet<String> = HashSet::new();
246 for (term, freq) in term_freqs {
247 self.inverted_index.entry(term.clone()).or_default().push((chunk.id, freq));
248
249 if seen_terms.insert(term.clone()) {
250 *self.doc_freqs.entry(term).or_insert(0) += 1;
251 }
252 }
253
254 self.update_avg_doc_length();
255 }
256
257 fn add_batch(&mut self, chunks: &[Chunk]) {
258 for chunk in chunks {
259 self.add(chunk);
260 }
261 }
262
263 fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)> {
264 let query_terms = self.tokenize(query);
265 if query_terms.is_empty() {
266 return Vec::new();
267 }
268
269 let mut candidates: HashSet<ChunkId> = HashSet::new();
271 for term in &query_terms {
272 for chunk_id in self.get_chunks_for_term(term) {
273 candidates.insert(chunk_id);
274 }
275 }
276
277 let mut scores: Vec<(ChunkId, f32)> = candidates
279 .into_iter()
280 .map(|chunk_id| {
281 let score: f32 =
282 query_terms.iter().map(|term| self.score_term(term, chunk_id)).sum();
283 (chunk_id, score)
284 })
285 .filter(|(_, score)| *score > 0.0)
286 .collect();
287
288 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
290 scores.truncate(k);
291 scores
292 }
293
294 fn remove(&mut self, chunk_id: ChunkId) {
295 if self.doc_lengths.remove(&chunk_id).is_some() {
297 self.doc_count = self.doc_count.saturating_sub(1);
298 }
299
300 let mut terms_to_remove: Vec<String> = Vec::new();
302 for (term, postings) in &mut self.inverted_index {
303 let original_len = postings.len();
304 postings.retain(|(id, _)| *id != chunk_id);
305
306 if postings.len() < original_len {
307 if let Some(df) = self.doc_freqs.get_mut(term) {
309 *df = df.saturating_sub(1);
310 if *df == 0 {
311 terms_to_remove.push(term.clone());
312 }
313 }
314 }
315 }
316
317 for term in terms_to_remove {
319 self.inverted_index.remove(&term);
320 self.doc_freqs.remove(&term);
321 }
322
323 self.update_avg_doc_length();
324 }
325
326 fn len(&self) -> usize {
327 self.doc_count as usize
328 }
329}
330
/// Configuration for a [`VectorStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Required embedding length; inserts and queries of any other length are rejected
    /// with `Error::DimensionMismatch`.
    pub dimension: usize,
    /// Scoring function used when searching.
    pub metric: DistanceMetric,
    /// Presumably the HNSW `M` parameter (links per node), judging by the name.
    ///
    /// NOTE(review): none of the `hnsw_*` fields are read by `VectorStore` in this file,
    /// which scores every stored vector; presumably reserved for an HNSW backend — confirm.
    pub hnsw_m: usize,
    /// Presumably HNSW `efConstruction` (candidate list size while building); see `hnsw_m`.
    pub hnsw_ef_construction: usize,
    /// Presumably HNSW `ef` (candidate list size at query time); see `hnsw_m`.
    pub hnsw_ef_search: usize,
}
345
346impl Default for VectorStoreConfig {
347 fn default() -> Self {
348 Self {
349 dimension: DEFAULT_EMBEDDING_DIM,
350 metric: DistanceMetric::Cosine,
351 hnsw_m: 16,
352 hnsw_ef_construction: 100,
353 hnsw_ef_search: 50,
354 }
355 }
356}
357
/// Scoring function used by [`VectorStore::search`]; higher scores rank first.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine similarity (the default). A zero-norm vector scores 0.
    #[default]
    Cosine,
    /// Euclidean distance, negated so that closer vectors receive higher scores.
    Euclidean,
    /// Raw dot product, so vector magnitude affects the score.
    DotProduct,
}
369
/// In-memory embedding store that answers queries by scoring every stored vector.
///
/// NOTE(review): each embedding is held twice, once in `vectors` and again inside the
/// `Chunk` kept in `chunks`, because insertion copies the embedding out of the chunk.
#[derive(Debug, Clone)]
pub struct VectorStore {
    /// Dimension (for validation) and metric (for scoring).
    config: VectorStoreConfig,
    /// Chunk id -> embedding vector used for scoring.
    vectors: HashMap<ChunkId, Vec<f32>>,
    /// Chunk id -> the full chunk as inserted, including its embedding.
    chunks: HashMap<ChunkId, Chunk>,
}
380
381impl VectorStore {
382 #[must_use]
384 pub fn new(config: VectorStoreConfig) -> Self {
385 Self { config, vectors: HashMap::new(), chunks: HashMap::new() }
386 }
387
388 #[must_use]
390 pub fn with_dimension(dimension: usize) -> Self {
391 Self::new(VectorStoreConfig { dimension, ..Default::default() })
392 }
393
394 #[must_use]
396 pub fn config(&self) -> &VectorStoreConfig {
397 &self.config
398 }
399
400 pub fn insert(&mut self, chunk: Chunk) -> Result<()> {
402 let embedding = chunk
403 .embedding
404 .as_ref()
405 .ok_or_else(|| Error::InvalidConfig("chunk must have embedding".to_string()))?;
406
407 if embedding.len() != self.config.dimension {
408 return Err(Error::DimensionMismatch {
409 expected: self.config.dimension,
410 actual: embedding.len(),
411 });
412 }
413
414 self.vectors.insert(chunk.id, embedding.clone());
415 self.chunks.insert(chunk.id, chunk);
416 Ok(())
417 }
418
419 pub fn insert_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
421 for chunk in chunks {
422 self.insert(chunk)?;
423 }
424 Ok(())
425 }
426
427 pub fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<(ChunkId, f32)>> {
429 if query_vector.len() != self.config.dimension {
430 return Err(Error::DimensionMismatch {
431 expected: self.config.dimension,
432 actual: query_vector.len(),
433 });
434 }
435
436 let mut scores: Vec<(ChunkId, f32)> = self
437 .vectors
438 .iter()
439 .map(|(id, vec)| {
440 let score = match self.config.metric {
441 DistanceMetric::Cosine => cosine_similarity(query_vector, vec),
442 DistanceMetric::Euclidean => -euclidean_distance(query_vector, vec),
443 DistanceMetric::DotProduct => dot_product(query_vector, vec),
444 };
445 (*id, score)
446 })
447 .collect();
448
449 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
451 scores.truncate(k);
452
453 Ok(scores)
454 }
455
456 #[must_use]
458 pub fn get(&self, chunk_id: ChunkId) -> Option<&Chunk> {
459 self.chunks.get(&chunk_id)
460 }
461
462 pub fn remove(&mut self, chunk_id: ChunkId) -> Option<Chunk> {
464 self.vectors.remove(&chunk_id);
465 self.chunks.remove(&chunk_id)
466 }
467
468 #[must_use]
470 pub fn len(&self) -> usize {
471 self.vectors.len()
472 }
473
474 #[must_use]
476 pub fn is_empty(&self) -> bool {
477 self.vectors.is_empty()
478 }
479}
480
/// Cosine of the angle between `a` and `b`, or `0.0` if either vector has zero norm.
///
/// The dot product covers only the common prefix of the two slices, while each norm
/// covers its whole slice. Callers in this file always pass equal-length vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (len_a, len_b) = (l2_norm(a), l2_norm(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (len_a * len_b)
}
493
/// Euclidean (L2) distance between `a` and `b`, computed over their common prefix.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| (x - y).powi(2))
        .sum();
    squared_sum.sqrt()
}
497
/// Dot product of `a` and `b`, computed over their common prefix.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let pairwise = a.iter().zip(b).map(|(&x, &y)| x * y);
    pairwise.sum()
}
501
#[cfg(test)]
mod tests {
    //! Unit and property tests for `BM25Index` and `VectorStore`.

    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    /// Builds a chunk that spans all of `content` in a fresh document.
    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Builds a chunk like `create_test_chunk` and attaches `embedding`.
    fn create_test_chunk_with_embedding(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = create_test_chunk(content);
        chunk.set_embedding(embedding);
        chunk
    }

    // ---------------------------------------------------------------- BM25Index

    #[test]
    fn test_bm25_index_new() {
        let index = BM25Index::new();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!((index.k1 - 1.2).abs() < 0.01);
        assert!((index.b - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_bm25_index_with_params() {
        let index = BM25Index::with_params(1.5, 0.5);
        assert!((index.k1 - 1.5).abs() < 0.01);
        assert!((index.b - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_bm25_tokenize() {
        let index = BM25Index::new();
        let tokens = index.tokenize("Hello World! This is a test.");

        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // Stopwords are removed.
        assert!(!tokens.contains(&"this".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_lowercase() {
        let index = BM25Index::new();
        let tokens = index.tokenize("HELLO World");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_drops_short_tokens() {
        let index = BM25Index::new();
        // Single-character tokens are filtered out; order of the survivors is preserved.
        assert_eq!(index.tokenize("x yz 7 42"), vec!["yz".to_string(), "42".to_string()]);
    }

    #[test]
    fn test_bm25_add_chunk() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Machine learning is fascinating");

        index.add(&chunk);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.inverted_index.contains_key("machine"));
        assert!(index.inverted_index.contains_key("learning"));
    }

    #[test]
    fn test_bm25_add_batch() {
        let mut index = BM25Index::new();
        let chunks = vec![
            create_test_chunk("First document about AI"),
            create_test_chunk("Second document about ML"),
            create_test_chunk("Third document about deep learning"),
        ];

        index.add_batch(&chunks);

        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_bm25_search_basic() {
        let mut index = BM25Index::new();
        let chunk1 = create_test_chunk("Machine learning algorithms");
        let chunk2 = create_test_chunk("Deep learning neural networks");
        let chunk3 = create_test_chunk("Natural language processing");

        index.add(&chunk1);
        index.add(&chunk2);
        index.add(&chunk3);

        let results = index.search("machine learning", 10);

        // chunk1 matches both terms, chunk2 only "learning", chunk3 neither.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, chunk1.id);
        assert!(results.iter().any(|(id, _)| *id == chunk2.id));
    }

    #[test]
    fn test_bm25_search_empty_query() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_stopwords_only() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("the a an", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_no_match() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Cats and dogs"));

        let results = index.search("quantum physics", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_ranking() {
        let mut index = BM25Index::new();

        let chunk1 = create_test_chunk("python programming language");
        // Higher term frequency for "python".
        let chunk2 = create_test_chunk("python python python programming");

        index.add(&chunk1);
        index.add(&chunk2);

        let results = index.search("python programming", 10);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, chunk2.id);
    }

    #[test]
    fn test_bm25_search_top_k() {
        let mut index = BM25Index::new();
        for i in 0..10 {
            index.add(&create_test_chunk(&format!("document {i} about rust")));
        }

        let results = index.search("rust", 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_bm25_remove() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Test document");
        let chunk_id = chunk.id;

        index.add(&chunk);
        assert_eq!(index.len(), 1);

        index.remove(chunk_id);
        assert_eq!(index.len(), 0);

        let results = index.search("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_remove_nonexistent_is_noop() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Test document");
        index.add(&chunk);

        index.remove(ChunkId::new());

        assert_eq!(index.len(), 1);
        let results = index.search("test", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, chunk.id);
    }

    #[test]
    fn test_bm25_avg_doc_length() {
        let mut index = BM25Index::new();

        // 2 tokens: "short", "text".
        index.add(&create_test_chunk("short text"));
        // 5 tokens after stopword removal: "longer", "piece", "text", "about", "programming".
        index.add(&create_test_chunk("this is a longer piece of text about programming"));

        assert!((index.avg_doc_length - 3.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_bm25_idf_calculation() {
        let mut index = BM25Index::new();

        // "rare" appears in one document, "common" in all three.
        index.add(&create_test_chunk("common rare"));
        index.add(&create_test_chunk("common word"));
        index.add(&create_test_chunk("common term"));

        let rare_results = index.search("rare", 10);
        let common_results = index.search("common", 10);

        assert_eq!(rare_results.len(), 1);
        assert_eq!(common_results.len(), 3);
        // Equal tf and document length, so the higher IDF of the rare term must win.
        assert!(rare_results[0].1 > common_results[0].1);
    }

    // -------------------------------------------------------------- VectorStore

    #[test]
    fn test_vector_store_new() {
        let store = VectorStore::with_dimension(384);
        assert_eq!(store.config().dimension, 384);
        assert!(store.is_empty());
    }

    #[test]
    fn test_vector_store_config() {
        let config = VectorStoreConfig {
            dimension: 768,
            metric: DistanceMetric::DotProduct,
            hnsw_m: 32,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 100,
        };
        let store = VectorStore::new(config.clone());

        assert_eq!(store.config().dimension, 768);
        assert_eq!(store.config().metric, DistanceMetric::DotProduct);
    }

    #[test]
    fn test_vector_store_insert() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);

        store.insert(chunk.clone()).unwrap();

        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get(chunk.id).is_some());
    }

    #[test]
    fn test_vector_store_insert_no_embedding() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk("no embedding");

        let result = store.insert(chunk);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_insert_wrong_dimension() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0]);

        let result = store.insert(chunk);
        assert!(result.is_err());
        match result {
            Err(Error::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_vector_store_insert_batch() {
        let mut store = VectorStore::with_dimension(3);
        let chunks = vec![
            create_test_chunk_with_embedding("a", vec![1.0, 0.0, 0.0]),
            create_test_chunk_with_embedding("b", vec![0.0, 1.0, 0.0]),
            create_test_chunk_with_embedding("c", vec![0.0, 0.0, 1.0]),
        ];

        store.insert_batch(chunks).unwrap();
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_vector_store_search_cosine() {
        let mut store = VectorStore::with_dimension(3);

        let chunk1 = create_test_chunk_with_embedding("north", vec![1.0, 0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("east", vec![0.0, 1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding(
            "diagonal",
            vec![std::f32::consts::FRAC_1_SQRT_2, std::f32::consts::FRAC_1_SQRT_2, 0.0],
        );

        let id1 = chunk1.id;
        let id3 = chunk3.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Mostly "north", slightly "east".
        let query = vec![0.9, 0.1, 0.0];
        let results = store.search(&query, 10).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id3);
    }

    #[test]
    fn test_vector_store_search_top_k() {
        let mut store = VectorStore::with_dimension(3);

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0];
            store
                .insert(create_test_chunk_with_embedding(&format!("chunk {i}"), embedding))
                .unwrap();
        }

        let results = store.search(&[9.0, 0.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_store_search_wrong_dimension() {
        let store = VectorStore::with_dimension(3);
        let result = store.search(&[1.0, 0.0], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_remove() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);
        let chunk_id = chunk.id;

        store.insert(chunk).unwrap();
        assert_eq!(store.len(), 1);

        let removed = store.remove(chunk_id);
        assert!(removed.is_some());
        assert_eq!(store.len(), 0);
        assert!(store.get(chunk_id).is_none());
    }

    #[test]
    fn test_vector_store_remove_nonexistent() {
        let mut store = VectorStore::with_dimension(3);
        let removed = store.remove(ChunkId::new());
        assert!(removed.is_none());
    }

    #[test]
    fn test_distance_metric_euclidean() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("origin", vec![0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("near", vec![1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding("far", vec![10.0, 0.0]);

        let id2 = chunk2.id;
        let id1 = chunk1.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        let results = store.search(&[0.0, 0.0], 10).unwrap();

        // Closest first: the origin itself, then the near point.
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id2);
    }

    #[test]
    fn test_distance_metric_dot_product() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::DotProduct,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("small", vec![1.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("large", vec![10.0, 0.0]);

        let id2 = chunk2.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();

        let results = store.search(&[1.0, 0.0], 10).unwrap();

        // Dot product rewards magnitude.
        assert_eq!(results[0].0, id2);
    }

    // ------------------------------------------------------------ property tests

    proptest! {
        #[test]
        fn prop_bm25_add_increases_count(content in "[a-zA-Z ]{10,100}") {
            let mut index = BM25Index::new();
            let initial = index.len();
            index.add(&create_test_chunk(&content));
            prop_assert_eq!(index.len(), initial + 1);
        }

        #[test]
        fn prop_bm25_search_results_within_k(
            content in prop::collection::vec("[a-zA-Z]{3,10}", 5..20),
            k in 1usize..10
        ) {
            let mut index = BM25Index::new();
            for c in &content {
                index.add(&create_test_chunk(c));
            }

            let results = index.search("test", k);
            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_bm25_scores_non_negative(
            docs in prop::collection::vec("[a-zA-Z ]{5,50}", 3..10),
            query in "[a-zA-Z]{3,10}"
        ) {
            let mut index = BM25Index::new();
            for doc in &docs {
                index.add(&create_test_chunk(doc));
            }

            let results = index.search(&query, 100);
            for (_, score) in results {
                prop_assert!(score >= 0.0);
            }
        }

        #[test]
        fn prop_vector_store_search_returns_stored(
            dim in 2usize..10,
            n_chunks in 1usize..20
        ) {
            let mut store = VectorStore::with_dimension(dim);
            let mut ids = Vec::new();

            for i in 0..n_chunks {
                let mut embedding = vec![0.0f32; dim];
                embedding[i % dim] = 1.0;
                let chunk = create_test_chunk_with_embedding(&format!("chunk {i}"), embedding);
                ids.push(chunk.id);
                store.insert(chunk).unwrap();
            }

            let query = vec![1.0f32; dim];
            let results = store.search(&query, n_chunks).unwrap();

            // Every returned id must be one that was inserted.
            for (id, _) in results {
                prop_assert!(ids.contains(&id));
            }
        }
    }
}