// aprender_rag/index.rs

//! Indexing for RAG pipelines: a BM25 sparse (lexical) index and a dense vector store.
2
3use crate::{Chunk, ChunkId, Error, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6
/// Default embedding dimension used by [`VectorStoreConfig::default`].
///
/// 384 is the output size of all-MiniLM-L6-v2 / BGE-small-en-v1.5.
const DEFAULT_EMBEDDING_DIM: usize = 384;
9
10/// Sparse index trait for lexical retrieval
11pub trait SparseIndex: Send + Sync {
12    /// Index a chunk
13    fn add(&mut self, chunk: &Chunk);
14
15    /// Index multiple chunks
16    fn add_batch(&mut self, chunks: &[Chunk]);
17
18    /// Search for matching chunks
19    fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)>;
20
21    /// Remove a chunk from the index
22    fn remove(&mut self, chunk_id: ChunkId);
23
24    /// Get the number of indexed documents
25    fn len(&self) -> usize;
26
27    /// Check if the index is empty
28    fn is_empty(&self) -> bool {
29        self.len() == 0
30    }
31}
32
/// BM25 (Okapi) index implementation for lexical retrieval.
///
/// All state except the optional tokenizer override is serde-serializable;
/// see `custom_tokenizer` for what must be restored after loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// Inverted index: term -> [(chunk_id, term_freq)], where `term_freq` is
    /// the number of times the term occurs in that chunk's tokenized content
    inverted_index: HashMap<String, Vec<(ChunkId, u32)>>,
    /// Document frequencies: term -> number of indexed chunks containing it
    doc_freqs: HashMap<String, u32>,
    /// Document lengths: chunk_id -> number of tokens after tokenization
    doc_lengths: HashMap<ChunkId, u32>,
    /// Average document length in tokens; a cached value refreshed by
    /// `add`/`add_batch`/`remove` rather than computed on every query
    avg_doc_length: f32,
    /// Total document count (incremented by `add`, decremented by `remove`)
    doc_count: u32,
    /// BM25 k1 parameter (term frequency saturation); 1.2 from `new()`
    k1: f32,
    /// BM25 b parameter (length normalization); 0.75 from `new()`
    b: f32,
    /// Tokenizer settings (used by the built-in tokenizer when no
    /// override is supplied via [`Self::with_tokenizer`]). `new()` sets
    /// this to `true`; no setter for it is visible in this file.
    lowercase: bool,
    /// Stopwords dropped by the built-in tokenizer. They are matched after
    /// lowercasing, so entries should be lowercase.
    stopwords: HashSet<String>,
    /// HELIX-IDEA-005 Phase 4 (FALSIFY-HYBRID-003): pluggable
    /// tokenizer override. When `Some`, the built-in
    /// `tokenize()` delegates to this trait object instead of the
    /// internal lowercase/stopword/min-length logic — letting
    /// callers (notably a future inference path) share a single
    /// tokenizer implementation with BM25.
    ///
    /// Skipped on serde round-trip: a saved `BM25Index` reloads
    /// with `None`, and the caller must re-attach the same
    /// tokenizer via [`Self::with_tokenizer`] before resuming
    /// indexing/search.
    #[serde(skip, default)]
    custom_tokenizer: Option<std::sync::Arc<dyn crate::tokenizer::Tokenizer>>,
}
69
70impl Default for BM25Index {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl BM25Index {
77    /// Create a new BM25 index with default parameters
78    #[must_use]
79    pub fn new() -> Self {
80        Self {
81            inverted_index: HashMap::new(),
82            doc_freqs: HashMap::new(),
83            doc_lengths: HashMap::new(),
84            avg_doc_length: 0.0,
85            doc_count: 0,
86            k1: 1.2,
87            b: 0.75,
88            lowercase: true,
89            stopwords: Self::default_stopwords(),
90            custom_tokenizer: None,
91        }
92    }
93
94    /// Create with custom BM25 parameters
95    #[must_use]
96    pub fn with_params(k1: f32, b: f32) -> Self {
97        Self { k1, b, ..Self::new() }
98    }
99
100    /// HELIX-IDEA-005 Phase 4 (FALSIFY-HYBRID-003): plug a custom
101    /// tokenizer into the index so BM25's notion of "term" can be
102    /// shared with other consumers (e.g., an inference path that
103    /// uses the same lexicon).
104    ///
105    /// When `tokenizer` is `Some`, the index's internal `tokenize()`
106    /// delegates to it; the built-in `lowercase` / `stopwords` /
107    /// min-length rules are bypassed entirely. To revert to the
108    /// built-in path, construct a fresh `BM25Index::new()`.
109    ///
110    /// `Arc<dyn Tokenizer>` because the index is `Clone` and may
111    /// be shared across threads — `Box<dyn Tokenizer>` would force
112    /// each clone to deep-copy the tokenizer state.
113    #[must_use]
114    pub fn with_tokenizer(
115        mut self,
116        tokenizer: std::sync::Arc<dyn crate::tokenizer::Tokenizer>,
117    ) -> Self {
118        self.custom_tokenizer = Some(tokenizer);
119        self
120    }
121
122    /// True iff a custom tokenizer is plugged in (used by tests
123    /// and FALSIFY-HYBRID-003 to confirm the override path is
124    /// active).
125    #[must_use]
126    pub fn has_custom_tokenizer(&self) -> bool {
127        self.custom_tokenizer.is_some()
128    }
129
130    /// All terms currently indexed (i.e., the keys of the
131    /// inverted index). Used by FALSIFY-HYBRID-003 to verify the
132    /// indexer consulted the injected tokenizer during `add()` —
133    /// the built-in tokenizer and an injected one produce
134    /// observably different key sets on the same content.
135    #[must_use]
136    pub fn indexed_terms(&self) -> Vec<&str> {
137        self.inverted_index.keys().map(String::as_str).collect()
138    }
139
140    /// Set stopwords
141    #[must_use]
142    pub fn with_stopwords(mut self, stopwords: HashSet<String>) -> Self {
143        self.stopwords = stopwords;
144        self
145    }
146
147    fn default_stopwords() -> HashSet<String> {
148        [
149            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
150            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
151            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
152            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
153            "below", "between", "under", "again", "further", "then", "once", "here", "there",
154            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
155            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
156            "and", "but", "if", "or", "because", "until", "while", "this", "that", "these",
157            "those", "it", "its",
158        ]
159        .iter()
160        .map(|s| (*s).to_string())
161        .collect()
162    }
163
164    /// Tokenize text. Consults the custom tokenizer override
165    /// (FALSIFY-HYBRID-003) if one is plugged in, otherwise falls
166    /// back to the built-in word-boundary + lowercase + stopwords
167    /// rule.
168    pub fn tokenize(&self, text: &str) -> Vec<String> {
169        if let Some(tok) = self.custom_tokenizer.as_ref() {
170            return tok.tokenize(text);
171        }
172        text.split(|c: char| !c.is_alphanumeric())
173            .filter(|s| !s.is_empty())
174            .map(|s| if self.lowercase { s.to_lowercase() } else { s.to_string() })
175            .filter(|s| !self.stopwords.contains(s))
176            .filter(|s| s.len() >= 2) // Filter very short tokens
177            .collect()
178    }
179
180    /// Compute term frequency in a document
181    fn term_frequency(&self, term: &str, chunk_id: ChunkId) -> u32 {
182        self.inverted_index
183            .get(term)
184            .and_then(|postings| postings.iter().find(|(id, _)| *id == chunk_id))
185            .map(|(_, freq)| *freq)
186            .unwrap_or(0)
187    }
188
189    /// Compute BM25 score for a single term
190    fn score_term(&self, term: &str, chunk_id: ChunkId) -> f32 {
191        let tf = self.term_frequency(term, chunk_id) as f32;
192        if tf == 0.0 {
193            return 0.0;
194        }
195
196        let df = self.doc_freqs.get(term).copied().unwrap_or(0) as f32;
197        let n = self.doc_count as f32;
198        let doc_len = self.doc_lengths.get(&chunk_id).copied().unwrap_or(0) as f32;
199
200        // IDF component: log((N - df + 0.5) / (df + 0.5) + 1)
201        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).max(f32::EPSILON).ln();
202
203        // TF component with length normalization
204        let tf_norm = (tf * (self.k1 + 1.0))
205            / (tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length));
206
207        idf * tf_norm
208    }
209
210    /// Update average document length
211    fn update_avg_doc_length(&mut self) {
212        if self.doc_count == 0 {
213            self.avg_doc_length = 0.0;
214        } else {
215            let total: u32 = self.doc_lengths.values().sum();
216            self.avg_doc_length = total as f32 / self.doc_count as f32;
217        }
218    }
219
220    /// Get chunks containing a term
221    fn get_chunks_for_term(&self, term: &str) -> Vec<ChunkId> {
222        self.inverted_index
223            .get(term)
224            .map(|postings| postings.iter().map(|(id, _)| *id).collect())
225            .unwrap_or_default()
226    }
227}
228
229impl SparseIndex for BM25Index {
230    fn add(&mut self, chunk: &Chunk) {
231        let tokens = self.tokenize(&chunk.content);
232        let doc_len = tokens.len() as u32;
233
234        // Update document length
235        self.doc_lengths.insert(chunk.id, doc_len);
236        self.doc_count += 1;
237
238        // Count term frequencies
239        let mut term_freqs: HashMap<String, u32> = HashMap::new();
240        for token in &tokens {
241            *term_freqs.entry(token.clone()).or_insert(0) += 1;
242        }
243
244        // Update inverted index and document frequencies
245        let mut seen_terms: HashSet<String> = HashSet::new();
246        for (term, freq) in term_freqs {
247            self.inverted_index.entry(term.clone()).or_default().push((chunk.id, freq));
248
249            if seen_terms.insert(term.clone()) {
250                *self.doc_freqs.entry(term).or_insert(0) += 1;
251            }
252        }
253
254        self.update_avg_doc_length();
255    }
256
257    fn add_batch(&mut self, chunks: &[Chunk]) {
258        for chunk in chunks {
259            self.add(chunk);
260        }
261    }
262
263    fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)> {
264        let query_terms = self.tokenize(query);
265        if query_terms.is_empty() {
266            return Vec::new();
267        }
268
269        // Collect candidate documents
270        let mut candidates: HashSet<ChunkId> = HashSet::new();
271        for term in &query_terms {
272            for chunk_id in self.get_chunks_for_term(term) {
273                candidates.insert(chunk_id);
274            }
275        }
276
277        // Score candidates
278        let mut scores: Vec<(ChunkId, f32)> = candidates
279            .into_iter()
280            .map(|chunk_id| {
281                let score: f32 =
282                    query_terms.iter().map(|term| self.score_term(term, chunk_id)).sum();
283                (chunk_id, score)
284            })
285            .filter(|(_, score)| *score > 0.0)
286            .collect();
287
288        // Sort by score descending
289        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
290        scores.truncate(k);
291        scores
292    }
293
294    fn remove(&mut self, chunk_id: ChunkId) {
295        // Remove from document lengths
296        if self.doc_lengths.remove(&chunk_id).is_some() {
297            self.doc_count = self.doc_count.saturating_sub(1);
298        }
299
300        // Remove from inverted index
301        let mut terms_to_remove: Vec<String> = Vec::new();
302        for (term, postings) in &mut self.inverted_index {
303            let original_len = postings.len();
304            postings.retain(|(id, _)| *id != chunk_id);
305
306            if postings.len() < original_len {
307                // Document contained this term
308                if let Some(df) = self.doc_freqs.get_mut(term) {
309                    *df = df.saturating_sub(1);
310                    if *df == 0 {
311                        terms_to_remove.push(term.clone());
312                    }
313                }
314            }
315        }
316
317        // Clean up empty terms
318        for term in terms_to_remove {
319            self.inverted_index.remove(&term);
320            self.doc_freqs.remove(&term);
321        }
322
323        self.update_avg_doc_length();
324    }
325
326    fn len(&self) -> usize {
327        self.doc_count as usize
328    }
329}
330
/// Vector store configuration.
///
/// NOTE(review): the `hnsw_*` fields are not read by `VectorStore` in this
/// file, whose `search` scores every stored vector (brute force). They look
/// reserved for an HNSW-backed implementation — confirm before relying on
/// them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding dimension; inserts and queries must match it exactly
    pub dimension: usize,
    /// Distance metric used to score query/vector pairs
    pub metric: DistanceMetric,
    /// HNSW M parameter (connections per node)
    pub hnsw_m: usize,
    /// HNSW ef_construction parameter
    pub hnsw_ef_construction: usize,
    /// HNSW ef_search parameter
    pub hnsw_ef_search: usize,
}
345
346impl Default for VectorStoreConfig {
347    fn default() -> Self {
348        Self {
349            dimension: DEFAULT_EMBEDDING_DIM,
350            metric: DistanceMetric::Cosine,
351            hnsw_m: 16,
352            hnsw_ef_construction: 100,
353            hnsw_ef_search: 50,
354        }
355    }
356}
357
/// Distance metric for vector search. `VectorStore::search` turns every
/// metric into a "higher is better" score.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine similarity (the default); scored as 0 when either vector has
    /// zero norm
    #[default]
    Cosine,
    /// Euclidean distance; `VectorStore::search` ranks by the negated
    /// distance so that nearer vectors score higher
    Euclidean,
    /// Raw (unnormalized) dot product, which favors larger-magnitude vectors
    DotProduct,
}
369
/// In-memory vector store for dense retrieval.
///
/// `search` compares the query against every stored vector (no ANN index).
#[derive(Debug, Clone)]
pub struct VectorStore {
    /// Configuration
    config: VectorStoreConfig,
    /// Stored vectors: chunk_id -> embedding (a copy of the chunk's own
    /// embedding, so each vector is held twice)
    vectors: HashMap<ChunkId, Vec<f32>>,
    /// Chunk cache: chunk_id -> the full inserted chunk (content, metadata
    /// and embedding), returned by `get`
    chunks: HashMap<ChunkId, Chunk>,
}
380
381impl VectorStore {
382    /// Create a new vector store
383    #[must_use]
384    pub fn new(config: VectorStoreConfig) -> Self {
385        Self { config, vectors: HashMap::new(), chunks: HashMap::new() }
386    }
387
388    /// Create with default configuration
389    #[must_use]
390    pub fn with_dimension(dimension: usize) -> Self {
391        Self::new(VectorStoreConfig { dimension, ..Default::default() })
392    }
393
394    /// Get the configuration
395    #[must_use]
396    pub fn config(&self) -> &VectorStoreConfig {
397        &self.config
398    }
399
400    /// Insert a chunk with its embedding
401    pub fn insert(&mut self, chunk: Chunk) -> Result<()> {
402        let embedding = chunk
403            .embedding
404            .as_ref()
405            .ok_or_else(|| Error::InvalidConfig("chunk must have embedding".to_string()))?;
406
407        if embedding.len() != self.config.dimension {
408            return Err(Error::DimensionMismatch {
409                expected: self.config.dimension,
410                actual: embedding.len(),
411            });
412        }
413
414        self.vectors.insert(chunk.id, embedding.clone());
415        self.chunks.insert(chunk.id, chunk);
416        Ok(())
417    }
418
419    /// Insert multiple chunks
420    pub fn insert_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
421        for chunk in chunks {
422            self.insert(chunk)?;
423        }
424        Ok(())
425    }
426
427    /// Search for similar vectors
428    pub fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<(ChunkId, f32)>> {
429        if query_vector.len() != self.config.dimension {
430            return Err(Error::DimensionMismatch {
431                expected: self.config.dimension,
432                actual: query_vector.len(),
433            });
434        }
435
436        let mut scores: Vec<(ChunkId, f32)> = self
437            .vectors
438            .iter()
439            .map(|(id, vec)| {
440                let score = match self.config.metric {
441                    DistanceMetric::Cosine => cosine_similarity(query_vector, vec),
442                    DistanceMetric::Euclidean => -euclidean_distance(query_vector, vec),
443                    DistanceMetric::DotProduct => dot_product(query_vector, vec),
444                };
445                (*id, score)
446            })
447            .collect();
448
449        // Sort by score descending (higher is better)
450        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
451        scores.truncate(k);
452
453        Ok(scores)
454    }
455
456    /// Get a chunk by ID
457    #[must_use]
458    pub fn get(&self, chunk_id: ChunkId) -> Option<&Chunk> {
459        self.chunks.get(&chunk_id)
460    }
461
462    /// Remove a chunk
463    pub fn remove(&mut self, chunk_id: ChunkId) -> Option<Chunk> {
464        self.vectors.remove(&chunk_id);
465        self.chunks.remove(&chunk_id)
466    }
467
468    /// Get the number of stored vectors
469    #[must_use]
470    pub fn len(&self) -> usize {
471        self.vectors.len()
472    }
473
474    /// Check if the store is empty
475    #[must_use]
476    pub fn is_empty(&self) -> bool {
477        self.vectors.is_empty()
478    }
479}
480
// Similarity helpers (from the embed module) used by `VectorStore::search`.

/// Cosine similarity of `a` and `b`: their dot product divided by the product
/// of their Euclidean norms. Returns `0.0` instead of dividing when either
/// norm is zero.
///
/// Pairs are formed with `zip`, so for slices of different lengths the dot
/// product covers only the shorter prefix while each norm covers its whole
/// slice.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = std::iter::zip(a, b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
493
/// Euclidean (L2) distance between `a` and `b`, over the pairs produced by
/// `zip` (so the shorter slice bounds the sum).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = std::iter::zip(a, b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
497
/// Dot product of `a` and `b` over the pairs produced by `zip` (extra
/// elements of the longer slice are ignored).
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum::<f32>()
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    /// Build a chunk whose content spans the whole of a fresh document.
    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Build a chunk and attach `embedding` to it.
    fn create_test_chunk_with_embedding(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = create_test_chunk(content);
        chunk.set_embedding(embedding);
        chunk
    }

    // ============ BM25Index Tests ============

    #[test]
    fn test_bm25_index_new() {
        let index = BM25Index::new();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!((index.k1 - 1.2).abs() < 0.01);
        assert!((index.b - 0.75).abs() < 0.01);
        assert!(!index.has_custom_tokenizer());
    }

    #[test]
    fn test_bm25_index_with_params() {
        let index = BM25Index::with_params(1.5, 0.5);
        assert!((index.k1 - 1.5).abs() < 0.01);
        assert!((index.b - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_bm25_tokenize() {
        let index = BM25Index::new();
        let tokens = index.tokenize("Hello World! This is a test.");

        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // Stopwords should be removed
        assert!(!tokens.contains(&"this".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_lowercase() {
        let index = BM25Index::new();
        let tokens = index.tokenize("HELLO World");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_splits_punctuation_and_drops_short_tokens() {
        let index = BM25Index::new();
        // Splits on every non-alphanumeric char; single-letter tokens dropped.
        assert_eq!(index.tokenize("x y-ab, rust!"), vec!["ab".to_string(), "rust".to_string()]);
    }

    #[test]
    fn test_bm25_with_stopwords_replaces_defaults() {
        let stopwords: std::collections::HashSet<String> =
            std::iter::once("hello".to_string()).collect();
        let index = BM25Index::new().with_stopwords(stopwords);
        // "the" is no longer a stopword; "hello" now is.
        assert_eq!(
            index.tokenize("Hello the world"),
            vec!["the".to_string(), "world".to_string()]
        );
    }

    #[test]
    fn test_bm25_add_chunk() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Machine learning is fascinating");

        index.add(&chunk);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.inverted_index.contains_key("machine"));
        assert!(index.inverted_index.contains_key("learning"));
    }

    #[test]
    fn test_bm25_add_batch() {
        let mut index = BM25Index::new();
        let chunks = vec![
            create_test_chunk("First document about AI"),
            create_test_chunk("Second document about ML"),
            create_test_chunk("Third document about deep learning"),
        ];

        index.add_batch(&chunks);

        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_bm25_search_basic() {
        let mut index = BM25Index::new();
        let chunk1 = create_test_chunk("Machine learning algorithms");
        let chunk2 = create_test_chunk("Deep learning neural networks");
        let chunk3 = create_test_chunk("Natural language processing");

        index.add(&chunk1);
        index.add(&chunk2);
        index.add(&chunk3);

        let results = index.search("machine learning", 10);

        // chunk1 matches both terms, chunk2 only "learning", chunk3 neither.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, chunk1.id);
        assert_eq!(results[1].0, chunk2.id);
        assert!(results.iter().all(|(id, _)| *id != chunk3.id));
    }

    #[test]
    fn test_bm25_search_empty_query() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_stopwords_only() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("the a an", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_no_match() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Cats and dogs"));

        let results = index.search("quantum physics", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_ranking() {
        let mut index = BM25Index::new();

        // Document with more term matches should rank higher
        let chunk1 = create_test_chunk("python programming language");
        let chunk2 = create_test_chunk("python python python programming");

        index.add(&chunk1);
        index.add(&chunk2);

        let results = index.search("python programming", 10);

        assert_eq!(results.len(), 2);
        // Chunk2 should rank higher due to more "python" occurrences
        assert_eq!(results[0].0, chunk2.id);
    }

    #[test]
    fn test_bm25_search_top_k() {
        let mut index = BM25Index::new();
        for i in 0..10 {
            index.add(&create_test_chunk(&format!("document {i} about rust")));
        }

        let results = index.search("rust", 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_bm25_remove() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Test document");
        let chunk_id = chunk.id;

        index.add(&chunk);
        assert_eq!(index.len(), 1);

        index.remove(chunk_id);
        assert_eq!(index.len(), 0);
        // Terms no longer present in any chunk are dropped entirely.
        assert!(index.indexed_terms().is_empty());

        let results = index.search("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_remove_unknown_is_noop() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        index.remove(ChunkId::new());

        assert_eq!(index.len(), 1);
        assert_eq!(index.search("test", 10).len(), 1);
    }

    #[test]
    fn test_bm25_avg_doc_length() {
        let mut index = BM25Index::new();

        // "short", "text" -> 2 tokens
        index.add(&create_test_chunk("short text"));
        // "longer", "piece", "text", "about", "programming" -> 5 tokens
        index.add(&create_test_chunk("this is a longer piece of text about programming"));

        assert!((index.avg_doc_length - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_bm25_idf_calculation() {
        let mut index = BM25Index::new();

        // Equal-length documents, so only IDF differs between the terms.
        index.add(&create_test_chunk("common rare"));
        index.add(&create_test_chunk("common word"));
        index.add(&create_test_chunk("common term"));

        let rare_results = index.search("rare", 10);
        let common_results = index.search("common", 10);

        assert_eq!(rare_results.len(), 1);
        assert_eq!(common_results.len(), 3);
        // "rare" appears in 1 doc, "common" in 3 docs: IDF of "rare" is higher.
        assert!(rare_results[0].1 > common_results[0].1);
    }

    // ============ VectorStore Tests ============

    #[test]
    fn test_vector_store_new() {
        let store = VectorStore::with_dimension(384);
        assert_eq!(store.config().dimension, 384);
        assert!(store.is_empty());
    }

    #[test]
    fn test_vector_store_config() {
        let config = VectorStoreConfig {
            dimension: 768,
            metric: DistanceMetric::DotProduct,
            hnsw_m: 32,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 100,
        };
        let store = VectorStore::new(config);

        assert_eq!(store.config().dimension, 768);
        assert_eq!(store.config().metric, DistanceMetric::DotProduct);
        assert_eq!(store.config().hnsw_m, 32);
        assert_eq!(store.config().hnsw_ef_construction, 200);
        assert_eq!(store.config().hnsw_ef_search, 100);
    }

    #[test]
    fn test_vector_store_insert() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);

        store.insert(chunk.clone()).unwrap();

        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get(chunk.id).is_some());
    }

    #[test]
    fn test_vector_store_insert_no_embedding() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk("no embedding");

        let result = store.insert(chunk);
        assert!(matches!(result, Err(Error::InvalidConfig(_))));
        assert!(store.is_empty());
    }

    #[test]
    fn test_vector_store_insert_wrong_dimension() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0]); // Wrong dimension

        let result = store.insert(chunk);
        assert!(result.is_err());
        match result {
            Err(Error::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_vector_store_insert_batch() {
        let mut store = VectorStore::with_dimension(3);
        let chunks = vec![
            create_test_chunk_with_embedding("a", vec![1.0, 0.0, 0.0]),
            create_test_chunk_with_embedding("b", vec![0.0, 1.0, 0.0]),
            create_test_chunk_with_embedding("c", vec![0.0, 0.0, 1.0]),
        ];

        store.insert_batch(chunks).unwrap();
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_vector_store_search_cosine() {
        let mut store = VectorStore::with_dimension(3);

        let chunk1 = create_test_chunk_with_embedding("north", vec![1.0, 0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("east", vec![0.0, 1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding(
            "diagonal",
            vec![std::f32::consts::FRAC_1_SQRT_2, std::f32::consts::FRAC_1_SQRT_2, 0.0],
        );

        let id1 = chunk1.id;
        let id3 = chunk3.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Search for vector pointing mostly north
        let query = vec![0.9, 0.1, 0.0];
        let results = store.search(&query, 10).unwrap();

        assert_eq!(results.len(), 3);
        // chunk1 (north) should be most similar
        assert_eq!(results[0].0, id1);
        // chunk3 (diagonal) should be second
        assert_eq!(results[1].0, id3);
    }

    #[test]
    fn test_vector_store_search_top_k() {
        let mut store = VectorStore::with_dimension(3);

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0];
            store
                .insert(create_test_chunk_with_embedding(&format!("chunk {i}"), embedding))
                .unwrap();
        }

        let results = store.search(&[9.0, 0.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
        // Every non-zero vector is parallel to the query (cosine 1.0); only
        // the zero vector scores 0.0, so the top 3 are all perfect matches.
        assert!(results.iter().all(|(_, score)| (*score - 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_vector_store_search_wrong_dimension() {
        let store = VectorStore::with_dimension(3);
        let result = store.search(&[1.0, 0.0], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_remove() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);
        let chunk_id = chunk.id;

        store.insert(chunk).unwrap();
        assert_eq!(store.len(), 1);

        let removed = store.remove(chunk_id);
        assert!(removed.is_some());
        assert_eq!(store.len(), 0);
        assert!(store.get(chunk_id).is_none());
    }

    #[test]
    fn test_vector_store_remove_nonexistent() {
        let mut store = VectorStore::with_dimension(3);
        let removed = store.remove(ChunkId::new());
        assert!(removed.is_none());
    }

    #[test]
    fn test_distance_metric_euclidean() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("origin", vec![0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("near", vec![1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding("far", vec![10.0, 0.0]);

        let id2 = chunk2.id;
        let id1 = chunk1.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Search from origin - near should be closest
        let results = store.search(&[0.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id1); // Exact match
        assert_eq!(results[1].0, id2); // Nearest neighbor
    }

    #[test]
    fn test_distance_metric_dot_product() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::DotProduct,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("small", vec![1.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("large", vec![10.0, 0.0]);

        let id2 = chunk2.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();

        // Dot product prefers larger magnitude vectors
        let results = store.search(&[1.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id2);
    }

    #[test]
    fn test_similarity_helpers() {
        assert!((cosine_similarity(&[3.0, 4.0], &[3.0, 4.0]) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]).abs() < f32::EPSILON);
        assert!((euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-6);
        assert!((dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-6);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_bm25_add_increases_count(content in "[a-zA-Z ]{10,100}") {
            let mut index = BM25Index::new();
            let initial = index.len();
            index.add(&create_test_chunk(&content));
            prop_assert_eq!(index.len(), initial + 1);
        }

        #[test]
        fn prop_bm25_search_results_within_k(
            content in prop::collection::vec("[a-zA-Z]{3,10}", 5..20),
            k in 1usize..10
        ) {
            let mut index = BM25Index::new();
            for c in &content {
                index.add(&create_test_chunk(c));
            }

            let results = index.search("test", k);
            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_bm25_scores_non_negative(
            docs in prop::collection::vec("[a-zA-Z ]{5,50}", 3..10),
            query in "[a-zA-Z]{3,10}"
        ) {
            let mut index = BM25Index::new();
            for doc in &docs {
                index.add(&create_test_chunk(doc));
            }

            let results = index.search(&query, 100);
            for (_, score) in results {
                prop_assert!(score >= 0.0);
            }
        }

        #[test]
        fn prop_vector_store_search_returns_stored(
            dim in 2usize..10,
            n_chunks in 1usize..20
        ) {
            let mut store = VectorStore::with_dimension(dim);
            let mut ids = Vec::new();

            for i in 0..n_chunks {
                let mut embedding = vec![0.0f32; dim];
                embedding[i % dim] = 1.0;
                let chunk = create_test_chunk_with_embedding(&format!("chunk {i}"), embedding);
                ids.push(chunk.id);
                store.insert(chunk).unwrap();
            }

            let query = vec![1.0f32; dim];
            let results = store.search(&query, n_chunks).unwrap();

            // All results should be from stored chunks
            for (id, _) in results {
                prop_assert!(ids.contains(&id));
            }
        }
    }
}