Skip to main content

aprender_rag/
retrieve.rs

1//! Retrieval module for RAG pipelines
2
3use crate::{
4    embed::Embedder,
5    fusion::FusionStrategy,
6    index::{BM25Index, SparseIndex, VectorStore},
7    Chunk, ChunkId, Result,
8};
9use serde::{Deserialize, Serialize};
10
/// Result of a retrieval operation: one chunk plus the scores attached to it.
///
/// Every score is optional. [`RetrievalResult::new`] leaves them all as `None`, and
/// each `with_*` builder method fills in exactly one of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// The retrieved chunk
    pub chunk: Chunk,
    /// Dense retrieval score (if applicable)
    pub dense_score: Option<f32>,
    /// Sparse retrieval score (if applicable)
    pub sparse_score: Option<f32>,
    /// Multi-vector retrieval score (if applicable, ColBERT-style MaxSim).
    /// Only present when the `multivector` feature is enabled.
    #[cfg(feature = "multivector")]
    pub multivector_score: Option<f32>,
    /// Fused score (if hybrid retrieval)
    pub fused_score: Option<f32>,
    /// Reranking score (if reranking applied)
    pub rerank_score: Option<f32>,
}
28
29impl RetrievalResult {
30    /// Create a new retrieval result from a chunk
31    #[must_use]
32    pub fn new(chunk: Chunk) -> Self {
33        Self {
34            chunk,
35            dense_score: None,
36            sparse_score: None,
37            #[cfg(feature = "multivector")]
38            multivector_score: None,
39            fused_score: None,
40            rerank_score: None,
41        }
42    }
43
44    /// Set the dense score
45    #[must_use]
46    pub fn with_dense_score(mut self, score: f32) -> Self {
47        self.dense_score = Some(score);
48        self
49    }
50
51    /// Set the sparse score
52    #[must_use]
53    pub fn with_sparse_score(mut self, score: f32) -> Self {
54        self.sparse_score = Some(score);
55        self
56    }
57
58    /// Set the fused score
59    #[must_use]
60    pub fn with_fused_score(mut self, score: f32) -> Self {
61        self.fused_score = Some(score);
62        self
63    }
64
65    /// Set the rerank score
66    #[must_use]
67    pub fn with_rerank_score(mut self, score: f32) -> Self {
68        self.rerank_score = Some(score);
69        self
70    }
71
72    /// Set the multi-vector (ColBERT-style) score
73    #[cfg(feature = "multivector")]
74    #[must_use]
75    pub fn with_multivector_score(mut self, score: f32) -> Self {
76        self.multivector_score = Some(score);
77        self
78    }
79
80    /// Get the best available score (rerank > fused > multivector > dense > sparse)
81    #[must_use]
82    pub fn best_score(&self) -> f32 {
83        self.rerank_score
84            .or(self.fused_score)
85            .or(self.dense_score)
86            .or(self.sparse_score)
87            .unwrap_or(0.0)
88    }
89
90    /// Get the best available score including multi-vector score
91    #[cfg(feature = "multivector")]
92    #[must_use]
93    pub fn best_score_with_multivector(&self) -> f32 {
94        self.rerank_score
95            .or(self.fused_score)
96            .or(self.multivector_score)
97            .or(self.dense_score)
98            .or(self.sparse_score)
99            .unwrap_or(0.0)
100    }
101}
102
/// Configuration for hybrid retrieval.
///
/// The [`Default`] implementation enables both sources, asks each for 50 candidates,
/// and uses the default fusion strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRetrieverConfig {
    /// Number of candidates to retrieve from each source
    pub candidates_per_source: usize,
    /// Fusion strategy
    pub fusion: FusionStrategy,
    /// Whether to use dense retrieval
    pub use_dense: bool,
    /// Whether to use sparse retrieval
    pub use_sparse: bool,
}
115
116impl Default for HybridRetrieverConfig {
117    fn default() -> Self {
118        Self {
119            candidates_per_source: 50,
120            fusion: FusionStrategy::default(),
121            use_dense: true,
122            use_sparse: true,
123        }
124    }
125}
126
/// Hybrid retriever combining dense and sparse retrieval.
///
/// The dense store also serves as the chunk store: every retrieval method resolves
/// chunk ids to chunks through it.
pub struct HybridRetriever<E: Embedder> {
    /// Dense vector store
    dense: VectorStore,
    /// Sparse BM25 index
    sparse: BM25Index,
    /// Embedder for query embedding
    embedder: E,
    /// Configuration
    config: HybridRetrieverConfig,
}
138
139impl<E: Embedder> HybridRetriever<E> {
140    /// Create a new hybrid retriever
141    #[must_use]
142    pub fn new(dense: VectorStore, sparse: BM25Index, embedder: E) -> Self {
143        Self { dense, sparse, embedder, config: HybridRetrieverConfig::default() }
144    }
145
146    /// Set the configuration
147    #[must_use]
148    pub fn with_config(mut self, config: HybridRetrieverConfig) -> Self {
149        self.config = config;
150        self
151    }
152
153    /// Get the dense store
154    #[must_use]
155    pub fn dense_store(&self) -> &VectorStore {
156        &self.dense
157    }
158
159    /// Get the dense store mutably
160    pub fn dense_store_mut(&mut self) -> &mut VectorStore {
161        &mut self.dense
162    }
163
164    /// Get the sparse index
165    #[must_use]
166    pub fn sparse_index(&self) -> &BM25Index {
167        &self.sparse
168    }
169
170    /// Get the sparse index mutably
171    pub fn sparse_index_mut(&mut self) -> &mut BM25Index {
172        &mut self.sparse
173    }
174
175    /// Index a chunk (adds to both dense and sparse indices)
176    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
177        // Add to sparse index
178        self.sparse.add(&chunk);
179
180        // Add to dense index (requires embedding)
181        self.dense.insert(chunk)?;
182
183        Ok(())
184    }
185
186    /// Index multiple chunks
187    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
188        for chunk in chunks {
189            self.index(chunk)?;
190        }
191        Ok(())
192    }
193
194    /// Retrieve relevant chunks for a query
195    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
196        // Contract: configuration-v1.yaml precondition (pv codegen)
197        contract_pre_configuration!(query.as_bytes());
198        let candidates = self.config.candidates_per_source;
199
200        // Dense retrieval
201        let dense_results = if self.config.use_dense {
202            let query_embedding = self.embedder.embed_query(query)?;
203            self.dense.search(&query_embedding, candidates)?
204        } else {
205            Vec::new()
206        };
207
208        // Sparse retrieval
209        let sparse_results =
210            if self.config.use_sparse { self.sparse.search(query, candidates) } else { Vec::new() };
211
212        // Fuse results
213        let fused = self.config.fusion.fuse(&dense_results, &sparse_results);
214
215        // Build score maps for lookup
216        let dense_scores: std::collections::HashMap<ChunkId, f32> =
217            dense_results.into_iter().collect();
218        let sparse_scores: std::collections::HashMap<ChunkId, f32> =
219            sparse_results.into_iter().collect();
220
221        // Build retrieval results
222        let mut results = Vec::with_capacity(k.min(fused.len()));
223        for (chunk_id, fused_score) in fused.into_iter().take(k) {
224            if let Some(chunk) = self.dense.get(chunk_id) {
225                let mut result = RetrievalResult::new(chunk.clone()).with_fused_score(fused_score);
226
227                if let Some(&score) = dense_scores.get(&chunk_id) {
228                    result = result.with_dense_score(score);
229                }
230                if let Some(&score) = sparse_scores.get(&chunk_id) {
231                    result = result.with_sparse_score(score);
232                }
233
234                results.push(result);
235            }
236        }
237
238        Ok(results)
239    }
240
241    /// Retrieve using only dense (vector) search
242    pub fn retrieve_dense(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
243        // Contract: configuration-v1.yaml precondition (pv codegen)
244        contract_pre_configuration!(query.as_bytes());
245        let query_embedding = self.embedder.embed_query(query)?;
246        let results = self.dense.search(&query_embedding, k)?;
247
248        let mut retrieval_results = Vec::with_capacity(results.len());
249        for (chunk_id, score) in results {
250            if let Some(chunk) = self.dense.get(chunk_id) {
251                retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
252            }
253        }
254
255        Ok(retrieval_results)
256    }
257
258    /// Retrieve using only sparse (BM25) search
259    pub fn retrieve_sparse(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
260        // Contract: configuration-v1.yaml precondition (pv codegen)
261        contract_pre_configuration!(query.as_bytes());
262        let results = self.sparse.search(query, k);
263
264        let mut retrieval_results = Vec::with_capacity(results.len());
265        for (chunk_id, score) in results {
266            if let Some(chunk) = self.dense.get(chunk_id) {
267                retrieval_results
268                    .push(RetrievalResult::new(chunk.clone()).with_sparse_score(score));
269            }
270        }
271
272        Ok(retrieval_results)
273    }
274
275    /// Get the number of indexed chunks
276    #[must_use]
277    pub fn len(&self) -> usize {
278        self.dense.len()
279    }
280
281    /// Check if the retriever is empty
282    #[must_use]
283    pub fn is_empty(&self) -> bool {
284        self.dense.is_empty()
285    }
286}
287
/// Dense-only retriever (simpler API for vector-only search)
pub struct DenseRetriever<E: Embedder> {
    /// Vector store that is searched and that resolves chunk ids to chunks
    store: VectorStore,
    /// Embedder used to embed queries
    embedder: E,
}
293
294impl<E: Embedder> DenseRetriever<E> {
295    /// Create a new dense retriever
296    #[must_use]
297    pub fn new(store: VectorStore, embedder: E) -> Self {
298        Self { store, embedder }
299    }
300
301    /// Index a chunk
302    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
303        self.store.insert(chunk)
304    }
305
306    /// Retrieve relevant chunks
307    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
308        let query_embedding = self.embedder.embed_query(query)?;
309        let results = self.store.search(&query_embedding, k)?;
310
311        let mut retrieval_results = Vec::with_capacity(results.len());
312        for (chunk_id, score) in results {
313            if let Some(chunk) = self.store.get(chunk_id) {
314                retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
315            }
316        }
317
318        Ok(retrieval_results)
319    }
320}
321
/// Sparse-only retriever (BM25)
pub struct SparseRetriever {
    /// BM25 index that is searched
    index: BM25Index,
    /// Indexed chunks keyed by id, used to turn search hits back into chunks
    chunks: std::collections::HashMap<ChunkId, Chunk>,
}
327
328impl SparseRetriever {
329    /// Create a new sparse retriever
330    #[must_use]
331    pub fn new() -> Self {
332        Self { index: BM25Index::new(), chunks: std::collections::HashMap::new() }
333    }
334
335    /// Index a chunk
336    pub fn index(&mut self, chunk: Chunk) {
337        self.index.add(&chunk);
338        self.chunks.insert(chunk.id, chunk);
339    }
340
341    /// Retrieve relevant chunks
342    #[must_use]
343    pub fn retrieve(&self, query: &str, k: usize) -> Vec<RetrievalResult> {
344        let results = self.index.search(query, k);
345
346        results
347            .into_iter()
348            .filter_map(|(chunk_id, score)| {
349                self.chunks
350                    .get(&chunk_id)
351                    .map(|chunk| RetrievalResult::new(chunk.clone()).with_sparse_score(score))
352            })
353            .collect()
354    }
355}
356
357impl Default for SparseRetriever {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
363// ============ Multi-Vector Retriever (WARP) ============
364
/// Multi-vector retriever using WARP index for ColBERT-style late interaction.
///
/// This retriever uses token-level embeddings and MaxSim scoring for fine-grained
/// semantic matching. Unlike single-vector dense retrieval, multi-vector approaches
/// represent documents and queries as multiple token embeddings.
///
/// Only available when the `multivector` feature is enabled.
///
/// # Example
///
/// ```ignore
/// use aprender_rag::multivector::{
///     WarpIndexConfig, WarpSearchConfig,
///     MockMultiVectorEmbedder, MultiVectorRetriever,
/// };
///
/// let config = WarpIndexConfig::new(2, 256, 128);
/// let embedder = MockMultiVectorEmbedder::new(128, 512);
/// let mut retriever = MultiVectorRetriever::new(config, embedder);
///
/// // Train on sample documents
/// retriever.train(&sample_chunks)?;
///
/// // Index documents
/// for chunk in chunks {
///     retriever.index(chunk)?;
/// }
/// retriever.build()?;
///
/// // Search
/// let results = retriever.retrieve("What is machine learning?", 10)?;
/// ```
#[cfg(feature = "multivector")]
pub struct MultiVectorRetriever<E: crate::multivector::MultiVectorEmbedder> {
    /// WARP index for compressed multi-vector storage and search
    index: crate::multivector::WarpIndex,
    /// Multi-vector embedder for token-level embeddings
    embedder: E,
    /// Search configuration; `retrieve` copies its tuning fields into a per-call
    /// configuration built around the caller's `k`
    search_config: crate::multivector::WarpSearchConfig,
}
404
#[cfg(feature = "multivector")]
impl<E: crate::multivector::MultiVectorEmbedder> MultiVectorRetriever<E> {
    /// Build a retriever around a fresh WARP index and the default search configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - WARP index configuration (nbits, num_centroids, token_dim)
    /// * `embedder` - Multi-vector embedder for generating token embeddings
    #[must_use]
    pub fn new(config: crate::multivector::WarpIndexConfig, embedder: E) -> Self {
        let index = crate::multivector::WarpIndex::new(config);
        let search_config = crate::multivector::WarpSearchConfig::default();
        Self { index, embedder, search_config }
    }

    /// Replace the stored search configuration (builder style).
    #[must_use]
    pub fn with_search_config(mut self, config: crate::multivector::WarpSearchConfig) -> Self {
        self.search_config = config;
        self
    }

    /// Train the WARP index on sample chunks.
    ///
    /// The contents of all sample chunks are embedded in one batch and the
    /// resulting embeddings are handed to the index's `train`. Intended to be
    /// called before indexing.
    ///
    /// # Arguments
    ///
    /// * `sample_chunks` - Representative chunks for training the codec
    ///
    /// # Errors
    ///
    /// Returns an error if batch embedding or index training fails.
    pub fn train(&mut self, sample_chunks: &[Chunk]) -> Result<()> {
        let mut texts: Vec<&str> = Vec::with_capacity(sample_chunks.len());
        for sample in sample_chunks {
            texts.push(sample.content.as_str());
        }

        let token_embeddings = self.embedder.embed_tokens_batch(&texts)?;
        self.index.train(&token_embeddings)?;
        Ok(())
    }

    /// Embed a single chunk's content and insert chunk plus embedding into the index.
    ///
    /// Call `train()` before indexing.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the chunk or inserting it into the index fails.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        let token_embedding = self.embedder.embed_tokens(&chunk.content)?;
        self.index.insert(chunk, token_embedding)?;
        Ok(())
    }

    /// Index multiple chunks, in order.
    ///
    /// # Errors
    ///
    /// Stops at the first chunk that fails and returns its error; earlier chunks
    /// stay indexed.
    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
        for item in chunks {
            self.index(item)?;
        }
        Ok(())
    }

    /// Build the index for efficient search.
    ///
    /// Delegates to the WARP index's `build`. Call after all chunks have been indexed.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the index's `build` reports.
    pub fn build(&mut self) -> Result<()> {
        self.index.build()
    }

    /// Retrieve relevant chunks for a query using multi-vector MaxSim scoring.
    ///
    /// # Arguments
    ///
    /// * `query` - Query text
    /// * `k` - Number of results to return
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the query or searching the index fails.
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        let query_tokens = self.embedder.embed_tokens(query)?;

        // Per-call configuration: the caller's `k` combined with the stored tuning values.
        let stored = &self.search_config;
        let per_call = crate::multivector::WarpSearchConfig::with_k(k)
            .nprobe(stored.nprobe)
            .bound(stored.bound)
            .centroid_score_threshold(stored.centroid_score_threshold);
        let hits = self.index.search(&query_tokens, &per_call)?;

        // Hits whose chunk the index cannot resolve are dropped.
        let found = hits
            .into_iter()
            .filter_map(|(id, score)| {
                self.index
                    .get_chunk(&id)
                    .map(|chunk| RetrievalResult::new(chunk.clone()).with_multivector_score(score))
            })
            .collect();

        Ok(found)
    }

    /// Get the number of indexed chunks, as reported by the WARP index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.index.num_chunks()
    }

    /// Check if the retriever is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.index.num_chunks() == 0
    }

    /// Get the underlying WARP index.
    #[must_use]
    pub fn warp_index(&self) -> &crate::multivector::WarpIndex {
        &self.index
    }

    /// Get the embedder.
    #[must_use]
    pub fn embedder(&self) -> &E {
        // Contract: embedding-algebra-v1.yaml precondition (pv codegen)
        contract_pre_embedding_lookup!();
        &self.embedder
    }

    /// Get memory usage statistics, as reported by the WARP index.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.index.memory_usage()
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{embed::MockEmbedder, DocumentId};
    use proptest::prelude::*;

    /// Build a chunk covering all of `content` and attach `embedding` to it.
    fn create_test_chunk(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        chunk.set_embedding(embedding);
        chunk
    }

    /// Build a chunk without an embedding, starting at offset 0 and ending at `end`.
    fn plain_chunk(content: &str, end: usize) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, end)
    }

    /// Build an empty hybrid retriever whose mock embedder and dense store share `dim`.
    fn empty_hybrid(dim: usize) -> HybridRetriever<MockEmbedder> {
        let embedder = MockEmbedder::new(dim);
        let dense = VectorStore::with_dimension(dim);
        let sparse = BM25Index::new();
        HybridRetriever::new(dense, sparse, embedder)
    }

    // ============ RetrievalResult Tests ============

    #[test]
    fn test_retrieval_result_new() {
        let result = RetrievalResult::new(plain_chunk("test", 4));

        assert!(result.dense_score.is_none());
        assert!(result.sparse_score.is_none());
        assert!(result.fused_score.is_none());
        assert!(result.rerank_score.is_none());
    }

    #[test]
    fn test_retrieval_result_with_scores() {
        let result = RetrievalResult::new(plain_chunk("test", 4))
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        assert_eq!(result.dense_score, Some(0.9));
        assert_eq!(result.sparse_score, Some(0.8));
        assert_eq!(result.fused_score, Some(0.85));
        assert_eq!(result.rerank_score, Some(0.95));
    }

    #[test]
    fn test_retrieval_result_best_score_priority() {
        let chunk = plain_chunk("test", 4);

        // Rerank takes priority
        let reranked =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_rerank_score(0.9);
        assert!((reranked.best_score() - 0.9).abs() < 0.001);

        // Fused takes priority over dense/sparse
        let fused = RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_fused_score(0.7);
        assert!((fused.best_score() - 0.7).abs() < 0.001);

        // Dense used when nothing else available
        let dense_only = RetrievalResult::new(chunk).with_dense_score(0.5);
        assert!((dense_only.best_score() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_retrieval_result_best_score_default() {
        let result = RetrievalResult::new(plain_chunk("test", 4));
        assert!((result.best_score() - 0.0).abs() < 0.001);
    }

    // ============ HybridRetrieverConfig Tests ============

    #[test]
    fn test_hybrid_config_default() {
        let config = HybridRetrieverConfig::default();
        assert_eq!(config.candidates_per_source, 50);
        assert!(config.use_dense);
        assert!(config.use_sparse);
    }

    // ============ HybridRetriever Tests ============

    #[test]
    fn test_hybrid_retriever_new() {
        let retriever = empty_hybrid(64);
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_hybrid_retriever_index() {
        let mut retriever = empty_hybrid(64);

        let chunk = create_test_chunk("machine learning is great", vec![0.0; 64]);
        retriever.index(chunk).unwrap();

        assert_eq!(retriever.len(), 1);
    }

    #[test]
    fn test_hybrid_retriever_index_batch() {
        let mut retriever = empty_hybrid(64);

        let chunks = vec![
            create_test_chunk("first document", vec![1.0; 64]),
            create_test_chunk("second document", vec![0.5; 64]),
        ];
        retriever.index_batch(chunks).unwrap();

        assert_eq!(retriever.len(), 2);
    }

    #[test]
    fn test_hybrid_retriever_retrieve() {
        let mut retriever = empty_hybrid(3);

        // Index some chunks
        retriever
            .index(create_test_chunk("machine learning algorithms", vec![1.0, 0.0, 0.0]))
            .unwrap();
        retriever
            .index(create_test_chunk("deep learning neural networks", vec![0.9, 0.1, 0.0]))
            .unwrap();
        retriever.index(create_test_chunk("cooking recipes", vec![0.0, 0.0, 1.0])).unwrap();

        let results = retriever.retrieve("machine learning", 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_hybrid_retriever_retrieve_dense_only() {
        let mut retriever = empty_hybrid(3);

        retriever.index(create_test_chunk("test doc", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve_dense("test", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].dense_score.is_some());
        assert!(results[0].sparse_score.is_none());
    }

    #[test]
    fn test_hybrid_retriever_retrieve_sparse_only() {
        let mut retriever = empty_hybrid(3);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve_sparse("machine", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].sparse_score.is_some());
        assert!(results[0].dense_score.is_none());
    }

    #[test]
    fn test_hybrid_retriever_config() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::Linear { dense_weight: 0.7 },
            use_dense: true,
            use_sparse: true,
        };

        let retriever = empty_hybrid(3).with_config(config);

        assert_eq!(retriever.config.candidates_per_source, 100);
    }

    // ============ DenseRetriever Tests ============

    #[test]
    fn test_dense_retriever() {
        let embedder = MockEmbedder::new(3);
        let store = VectorStore::with_dimension(3);
        let mut retriever = DenseRetriever::new(store, embedder);

        retriever.index(create_test_chunk("test document", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve("test", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].dense_score.is_some());
    }

    // ============ SparseRetriever Tests ============

    #[test]
    fn test_sparse_retriever_new() {
        let retriever = SparseRetriever::new();
        assert!(retriever.retrieve("test", 10).is_empty());
    }

    #[test]
    fn test_sparse_retriever_index() {
        let mut retriever = SparseRetriever::new();

        retriever.index(plain_chunk("machine learning test", 20));
        let results = retriever.retrieve("machine", 10);

        assert_eq!(results.len(), 1);
        assert!(results[0].sparse_score.is_some());
    }

    #[test]
    fn test_sparse_retriever_multiple() {
        let mut retriever = SparseRetriever::new();

        retriever.index(plain_chunk("rust programming language", 24));
        retriever.index(plain_chunk("python programming language", 26));

        let results = retriever.retrieve("programming", 10);
        assert_eq!(results.len(), 2);
    }

    // ============ Additional Coverage Tests ============

    #[test]
    fn test_hybrid_retriever_store_accessors() {
        let mut retriever = empty_hybrid(64);

        // Test immutable accessors
        let _dense_store = retriever.dense_store();
        let _sparse_index = retriever.sparse_index();

        // Test mutable accessors
        assert!(retriever.dense_store_mut().is_empty());

        // Just verify it compiles and works
        let _ = retriever.sparse_index_mut();
    }

    #[test]
    fn test_hybrid_retriever_is_empty() {
        let mut retriever = empty_hybrid(64);
        assert!(retriever.is_empty());

        retriever.index(create_test_chunk("test", vec![0.0; 64])).unwrap();
        assert!(!retriever.is_empty());
    }

    #[test]
    fn test_sparse_retriever_default() {
        let retriever = SparseRetriever::default();
        assert!(retriever.retrieve("test", 10).is_empty());
    }

    #[test]
    fn test_retrieval_result_best_score_sparse_fallback() {
        // Only sparse score available
        let result = RetrievalResult::new(plain_chunk("test", 4)).with_sparse_score(0.75);
        assert!((result.best_score() - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_hybrid_retriever_with_dense_disabled() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: false,
            use_sparse: true,
        };

        let mut retriever = empty_hybrid(3).with_config(config);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        // Should still work, using only sparse
        let results = retriever.retrieve("machine", 10).unwrap();
        // Results depend on sparse-only fusion
        assert!(results.len() <= 10);
    }

    #[test]
    fn test_hybrid_retriever_with_sparse_disabled() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: true,
            use_sparse: false,
        };

        let mut retriever = empty_hybrid(3).with_config(config);

        retriever.index(create_test_chunk("test content", vec![1.0, 0.0, 0.0])).unwrap();

        // Should still work, using only dense
        let results = retriever.retrieve("test", 10).unwrap();
        assert!(results.len() <= 10);
    }

    #[test]
    fn test_hybrid_retriever_config_serialization() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::RRF { k: 60.0 },
            use_dense: true,
            use_sparse: false,
        };

        // Round-trip through JSON and compare the plain fields
        let json = serde_json::to_string(&config).unwrap();
        let restored: HybridRetrieverConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.candidates_per_source, restored.candidates_per_source);
        assert_eq!(config.use_dense, restored.use_dense);
        assert_eq!(config.use_sparse, restored.use_sparse);
    }

    #[test]
    fn test_retrieval_result_serialization() {
        let result = RetrievalResult::new(plain_chunk("test content", 12))
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        // Round-trip through JSON and compare every score
        let json = serde_json::to_string(&result).unwrap();
        let restored: RetrievalResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.dense_score, restored.dense_score);
        assert_eq!(result.sparse_score, restored.sparse_score);
        assert_eq!(result.fused_score, restored.fused_score);
        assert_eq!(result.rerank_score, restored.rerank_score);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_retrieval_result_scores_preserved(
            dense in 0.0f32..1.0,
            sparse in 0.0f32..1.0,
            fused in 0.0f32..1.0
        ) {
            let result = RetrievalResult::new(plain_chunk("test", 4))
                .with_dense_score(dense)
                .with_sparse_score(sparse)
                .with_fused_score(fused);

            prop_assert!((result.dense_score.unwrap() - dense).abs() < 0.0001);
            prop_assert!((result.sparse_score.unwrap() - sparse).abs() < 0.0001);
            prop_assert!((result.fused_score.unwrap() - fused).abs() < 0.0001);
        }

        #[test]
        fn prop_hybrid_retriever_respects_k(k in 1usize..10) {
            let mut retriever = empty_hybrid(3);

            // Add more chunks than k
            for i in 0..20 {
                let mut emb = vec![0.0; 3];
                emb[i % 3] = 1.0;
                retriever.index(create_test_chunk(
                    &format!("document number {i} about testing"),
                    emb,
                )).unwrap();
            }

            let results = retriever.retrieve("testing", k).unwrap();
            prop_assert!(results.len() <= k);
        }
    }
}