1use crate::{
4 embed::Embedder,
5 fusion::FusionStrategy,
6 index::{BM25Index, SparseIndex, VectorStore},
7 Chunk, ChunkId, Result,
8};
9use serde::{Deserialize, Serialize};
10
/// A retrieved [`Chunk`] together with the scores attached to it so far.
///
/// Every score starts out as `None` (see [`RetrievalResult::new`]) and is filled in by
/// whichever retrieval stage produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// The chunk this result refers to.
    pub chunk: Chunk,
    /// Score from dense vector search, if that source returned this chunk.
    pub dense_score: Option<f32>,
    /// Score from the sparse (BM25) index, if that source returned this chunk.
    pub sparse_score: Option<f32>,
    /// Score from multi-vector search; only present with the `multivector` feature.
    #[cfg(feature = "multivector")]
    pub multivector_score: Option<f32>,
    /// Score produced by fusing the dense and sparse result lists.
    pub fused_score: Option<f32>,
    /// Score set through [`RetrievalResult::with_rerank_score`]; nothing in this module
    /// assigns it on its own.
    pub rerank_score: Option<f32>,
}
28
29impl RetrievalResult {
30 #[must_use]
32 pub fn new(chunk: Chunk) -> Self {
33 Self {
34 chunk,
35 dense_score: None,
36 sparse_score: None,
37 #[cfg(feature = "multivector")]
38 multivector_score: None,
39 fused_score: None,
40 rerank_score: None,
41 }
42 }
43
44 #[must_use]
46 pub fn with_dense_score(mut self, score: f32) -> Self {
47 self.dense_score = Some(score);
48 self
49 }
50
51 #[must_use]
53 pub fn with_sparse_score(mut self, score: f32) -> Self {
54 self.sparse_score = Some(score);
55 self
56 }
57
58 #[must_use]
60 pub fn with_fused_score(mut self, score: f32) -> Self {
61 self.fused_score = Some(score);
62 self
63 }
64
65 #[must_use]
67 pub fn with_rerank_score(mut self, score: f32) -> Self {
68 self.rerank_score = Some(score);
69 self
70 }
71
72 #[cfg(feature = "multivector")]
74 #[must_use]
75 pub fn with_multivector_score(mut self, score: f32) -> Self {
76 self.multivector_score = Some(score);
77 self
78 }
79
80 #[must_use]
82 pub fn best_score(&self) -> f32 {
83 self.rerank_score
84 .or(self.fused_score)
85 .or(self.dense_score)
86 .or(self.sparse_score)
87 .unwrap_or(0.0)
88 }
89
90 #[cfg(feature = "multivector")]
92 #[must_use]
93 pub fn best_score_with_multivector(&self) -> f32 {
94 self.rerank_score
95 .or(self.fused_score)
96 .or(self.multivector_score)
97 .or(self.dense_score)
98 .or(self.sparse_score)
99 .unwrap_or(0.0)
100 }
101}
102
/// Settings that control how [`HybridRetriever::retrieve`] gathers and combines candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRetrieverConfig {
    /// Number of candidates requested from each enabled source before fusion.
    pub candidates_per_source: usize,
    /// Strategy used to fuse the dense and sparse candidate lists.
    pub fusion: FusionStrategy,
    /// Whether the dense vector store is queried during hybrid retrieval.
    pub use_dense: bool,
    /// Whether the sparse (BM25) index is queried during hybrid retrieval.
    pub use_sparse: bool,
}
115
116impl Default for HybridRetrieverConfig {
117 fn default() -> Self {
118 Self {
119 candidates_per_source: 50,
120 fusion: FusionStrategy::default(),
121 use_dense: true,
122 use_sparse: true,
123 }
124 }
125}
126
127pub struct HybridRetriever<E: Embedder> {
129 dense: VectorStore,
131 sparse: BM25Index,
133 embedder: E,
135 config: HybridRetrieverConfig,
137}
138
139impl<E: Embedder> HybridRetriever<E> {
140 #[must_use]
142 pub fn new(dense: VectorStore, sparse: BM25Index, embedder: E) -> Self {
143 Self { dense, sparse, embedder, config: HybridRetrieverConfig::default() }
144 }
145
146 #[must_use]
148 pub fn with_config(mut self, config: HybridRetrieverConfig) -> Self {
149 self.config = config;
150 self
151 }
152
153 #[must_use]
155 pub fn dense_store(&self) -> &VectorStore {
156 &self.dense
157 }
158
159 pub fn dense_store_mut(&mut self) -> &mut VectorStore {
161 &mut self.dense
162 }
163
164 #[must_use]
166 pub fn sparse_index(&self) -> &BM25Index {
167 &self.sparse
168 }
169
170 pub fn sparse_index_mut(&mut self) -> &mut BM25Index {
172 &mut self.sparse
173 }
174
175 pub fn index(&mut self, chunk: Chunk) -> Result<()> {
177 self.sparse.add(&chunk);
179
180 self.dense.insert(chunk)?;
182
183 Ok(())
184 }
185
186 pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
188 for chunk in chunks {
189 self.index(chunk)?;
190 }
191 Ok(())
192 }
193
194 pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
196 let candidates = self.config.candidates_per_source;
197
198 let dense_results = if self.config.use_dense {
200 let query_embedding = self.embedder.embed_query(query)?;
201 self.dense.search(&query_embedding, candidates)?
202 } else {
203 Vec::new()
204 };
205
206 let sparse_results =
208 if self.config.use_sparse { self.sparse.search(query, candidates) } else { Vec::new() };
209
210 let fused = self.config.fusion.fuse(&dense_results, &sparse_results);
212
213 let dense_scores: std::collections::HashMap<ChunkId, f32> =
215 dense_results.into_iter().collect();
216 let sparse_scores: std::collections::HashMap<ChunkId, f32> =
217 sparse_results.into_iter().collect();
218
219 let mut results = Vec::with_capacity(k.min(fused.len()));
221 for (chunk_id, fused_score) in fused.into_iter().take(k) {
222 if let Some(chunk) = self.dense.get(chunk_id) {
223 let mut result = RetrievalResult::new(chunk.clone()).with_fused_score(fused_score);
224
225 if let Some(&score) = dense_scores.get(&chunk_id) {
226 result = result.with_dense_score(score);
227 }
228 if let Some(&score) = sparse_scores.get(&chunk_id) {
229 result = result.with_sparse_score(score);
230 }
231
232 results.push(result);
233 }
234 }
235
236 Ok(results)
237 }
238
239 pub fn retrieve_dense(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
241 let query_embedding = self.embedder.embed_query(query)?;
242 let results = self.dense.search(&query_embedding, k)?;
243
244 let mut retrieval_results = Vec::with_capacity(results.len());
245 for (chunk_id, score) in results {
246 if let Some(chunk) = self.dense.get(chunk_id) {
247 retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
248 }
249 }
250
251 Ok(retrieval_results)
252 }
253
254 pub fn retrieve_sparse(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
256 let results = self.sparse.search(query, k);
257
258 let mut retrieval_results = Vec::with_capacity(results.len());
259 for (chunk_id, score) in results {
260 if let Some(chunk) = self.dense.get(chunk_id) {
261 retrieval_results
262 .push(RetrievalResult::new(chunk.clone()).with_sparse_score(score));
263 }
264 }
265
266 Ok(retrieval_results)
267 }
268
269 #[must_use]
271 pub fn len(&self) -> usize {
272 self.dense.len()
273 }
274
275 #[must_use]
277 pub fn is_empty(&self) -> bool {
278 self.dense.is_empty()
279 }
280}
281
282pub struct DenseRetriever<E: Embedder> {
284 store: VectorStore,
285 embedder: E,
286}
287
288impl<E: Embedder> DenseRetriever<E> {
289 #[must_use]
291 pub fn new(store: VectorStore, embedder: E) -> Self {
292 Self { store, embedder }
293 }
294
295 pub fn index(&mut self, chunk: Chunk) -> Result<()> {
297 self.store.insert(chunk)
298 }
299
300 pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
302 let query_embedding = self.embedder.embed_query(query)?;
303 let results = self.store.search(&query_embedding, k)?;
304
305 let mut retrieval_results = Vec::with_capacity(results.len());
306 for (chunk_id, score) in results {
307 if let Some(chunk) = self.store.get(chunk_id) {
308 retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
309 }
310 }
311
312 Ok(retrieval_results)
313 }
314}
315
/// Retriever backed by a BM25 index alone, keeping its own copy of every indexed chunk.
pub struct SparseRetriever {
    /// BM25 index searched with the raw query text.
    index: BM25Index,
    /// Indexed chunks keyed by id, used to turn search hits back into chunks.
    chunks: std::collections::HashMap<ChunkId, Chunk>,
}
321
322impl SparseRetriever {
323 #[must_use]
325 pub fn new() -> Self {
326 Self { index: BM25Index::new(), chunks: std::collections::HashMap::new() }
327 }
328
329 pub fn index(&mut self, chunk: Chunk) {
331 self.index.add(&chunk);
332 self.chunks.insert(chunk.id, chunk);
333 }
334
335 #[must_use]
337 pub fn retrieve(&self, query: &str, k: usize) -> Vec<RetrievalResult> {
338 let results = self.index.search(query, k);
339
340 results
341 .into_iter()
342 .filter_map(|(chunk_id, score)| {
343 self.chunks
344 .get(&chunk_id)
345 .map(|chunk| RetrievalResult::new(chunk.clone()).with_sparse_score(score))
346 })
347 .collect()
348 }
349}
350
impl Default for SparseRetriever {
    /// Equivalent to [`SparseRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
356
/// Retriever backed by a multi-vector `WarpIndex`; only built with the `multivector`
/// feature.
#[cfg(feature = "multivector")]
pub struct MultiVectorRetriever<E>
where
    E: crate::multivector::MultiVectorEmbedder,
{
    /// Index that is trained, filled, built and searched by this retriever.
    index: crate::multivector::WarpIndex,
    /// Embedder that produces token-level embeddings for chunks and queries.
    embedder: E,
    /// Search settings whose `nprobe`, `bound` and `centroid_score_threshold` are reused
    /// for every query.
    search_config: crate::multivector::WarpSearchConfig,
}
398
#[cfg(feature = "multivector")]
impl<E: crate::multivector::MultiVectorEmbedder> MultiVectorRetriever<E> {
    /// Creates a retriever with a new index built from `config` and the default search
    /// configuration.
    #[must_use]
    pub fn new(config: crate::multivector::WarpIndexConfig, embedder: E) -> Self {
        let index = crate::multivector::WarpIndex::new(config);
        let search_config = crate::multivector::WarpSearchConfig::default();
        Self { index, embedder, search_config }
    }

    /// Replaces the stored search configuration and returns the retriever, builder-style.
    #[must_use]
    pub fn with_search_config(mut self, config: crate::multivector::WarpSearchConfig) -> Self {
        self.search_config = config;
        self
    }

    /// Trains the index on token embeddings of the sample chunks' content.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the samples or training the index fails.
    pub fn train(&mut self, sample_chunks: &[Chunk]) -> Result<()> {
        let mut texts: Vec<&str> = Vec::with_capacity(sample_chunks.len());
        for sample in sample_chunks {
            texts.push(sample.content.as_str());
        }

        let embeddings = self.embedder.embed_tokens_batch(&texts)?;
        self.index.train(&embeddings)?;
        Ok(())
    }

    /// Embeds the chunk's content token-wise and inserts chunk and embedding into the index.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the content or inserting into the index fails.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        let token_embedding = self.embedder.embed_tokens(&chunk.content)?;
        self.index.insert(chunk, token_embedding)?;
        Ok(())
    }

    /// Indexes `chunks` one at a time, in order, via [`index`](Self::index).
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error; chunks indexed before it stay indexed.
    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
        chunks.into_iter().try_for_each(|chunk| self.index(chunk))
    }

    /// Delegates to the index's `build` step.
    ///
    /// # Errors
    ///
    /// Returns whatever error the index's `build` reports.
    pub fn build(&mut self) -> Result<()> {
        self.index.build()
    }

    /// Embeds `query` token-wise, searches the index for `k` results and returns them with
    /// their multi-vector score attached.
    ///
    /// The per-call search configuration is created with `k` and then takes `nprobe`,
    /// `bound` and `centroid_score_threshold` from the stored search configuration. Hits
    /// whose chunk cannot be resolved in the index are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the query or searching the index fails.
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        let query_embedding = self.embedder.embed_tokens(query)?;

        let stored = &self.search_config;
        let search_config = crate::multivector::WarpSearchConfig::with_k(k)
            .nprobe(stored.nprobe)
            .bound(stored.bound)
            .centroid_score_threshold(stored.centroid_score_threshold);

        let hits = self.index.search(&query_embedding, &search_config)?;

        let found = hits
            .into_iter()
            .filter_map(|(chunk_id, score)| {
                let chunk = self.index.get_chunk(&chunk_id)?;
                Some(RetrievalResult::new(chunk.clone()).with_multivector_score(score))
            })
            .collect();

        Ok(found)
    }

    /// Returns the number of chunks reported by the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.index.num_chunks()
    }

    /// Returns `true` when the index reports zero chunks.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.index.num_chunks() == 0
    }

    /// Returns the underlying index.
    #[must_use]
    pub fn warp_index(&self) -> &crate::multivector::WarpIndex {
        &self.index
    }

    /// Returns the embedder.
    #[must_use]
    pub fn embedder(&self) -> &E {
        &self.embedder
    }

    /// Returns the memory usage reported by the index.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.index.memory_usage()
    }
}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{embed::MockEmbedder, DocumentId};
    use proptest::prelude::*;

    /// Builds a chunk for `content` without an embedding.
    ///
    /// Passes `0` and `content.len()` as the trailing arguments of `Chunk::new`
    /// (presumably the start and end offsets), so the range always matches the text.
    fn create_plain_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Builds a chunk for `content` that carries `embedding`.
    fn create_test_chunk(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = create_plain_chunk(content);
        chunk.set_embedding(embedding);
        chunk
    }

    /// A freshly created result has no scores.
    #[test]
    fn test_retrieval_result_new() {
        let result = RetrievalResult::new(create_plain_chunk("test"));

        assert!(result.dense_score.is_none());
        assert!(result.sparse_score.is_none());
        assert!(result.fused_score.is_none());
        assert!(result.rerank_score.is_none());
        #[cfg(feature = "multivector")]
        assert!(result.multivector_score.is_none());
    }

    /// Each builder stores exactly the value it was given.
    #[test]
    fn test_retrieval_result_with_scores() {
        let result = RetrievalResult::new(create_plain_chunk("test"))
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        assert_eq!(result.dense_score, Some(0.9));
        assert_eq!(result.sparse_score, Some(0.8));
        assert_eq!(result.fused_score, Some(0.85));
        assert_eq!(result.rerank_score, Some(0.95));
    }

    /// `best_score` prefers rerank over fused over dense.
    #[test]
    fn test_retrieval_result_best_score_priority() {
        let chunk = create_plain_chunk("test");

        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_rerank_score(0.9);
        assert!((result.best_score() - 0.9).abs() < 0.001);

        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_fused_score(0.7);
        assert!((result.best_score() - 0.7).abs() < 0.001);

        let result = RetrievalResult::new(chunk).with_dense_score(0.5);
        assert!((result.best_score() - 0.5).abs() < 0.001);
    }

    /// `best_score` is `0.0` when no score is set.
    #[test]
    fn test_retrieval_result_best_score_default() {
        let result = RetrievalResult::new(create_plain_chunk("test"));
        assert!((result.best_score() - 0.0).abs() < 0.001);
    }

    /// The default configuration enables both sources with 50 candidates each.
    #[test]
    fn test_hybrid_config_default() {
        let config = HybridRetrieverConfig::default();
        assert_eq!(config.candidates_per_source, 50);
        assert!(config.use_dense);
        assert!(config.use_sparse);
    }

    /// A new hybrid retriever over empty stores is empty.
    #[test]
    fn test_hybrid_retriever_new() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());
    }

    /// Indexing one chunk makes the retriever report one entry.
    #[test]
    fn test_hybrid_retriever_index() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        let chunk = create_test_chunk("machine learning is great", vec![0.0; 64]);
        retriever.index(chunk).unwrap();

        assert_eq!(retriever.len(), 1);
    }

    /// Batch indexing adds every chunk of the batch.
    #[test]
    fn test_hybrid_retriever_index_batch() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        let chunks = vec![
            create_test_chunk("first document", vec![1.0; 64]),
            create_test_chunk("second document", vec![0.5; 64]),
        ];
        retriever.index_batch(chunks).unwrap();

        assert_eq!(retriever.len(), 2);
    }

    /// Hybrid retrieval returns between one and `k` results, each with a fused score.
    #[test]
    fn test_hybrid_retriever_retrieve() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        retriever
            .index(create_test_chunk("machine learning algorithms", vec![1.0, 0.0, 0.0]))
            .unwrap();
        retriever
            .index(create_test_chunk("deep learning neural networks", vec![0.9, 0.1, 0.0]))
            .unwrap();
        retriever.index(create_test_chunk("cooking recipes", vec![0.0, 0.0, 1.0])).unwrap();

        let results = retriever.retrieve("machine learning", 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        assert!(results.iter().all(|r| r.fused_score.is_some()));
    }

    /// `retrieve_dense` attaches only a dense score.
    #[test]
    fn test_hybrid_retriever_retrieve_dense_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        retriever.index(create_test_chunk("test doc", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve_dense("test", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].dense_score.is_some());
        assert!(results[0].sparse_score.is_none());
    }

    /// `retrieve_sparse` attaches only a sparse score.
    #[test]
    fn test_hybrid_retriever_retrieve_sparse_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve_sparse("machine", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].sparse_score.is_some());
        assert!(results[0].dense_score.is_none());
    }

    /// `with_config` stores the supplied configuration.
    #[test]
    fn test_hybrid_retriever_config() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::Linear { dense_weight: 0.7 },
            use_dense: true,
            use_sparse: true,
        };

        let retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        assert_eq!(retriever.config.candidates_per_source, 100);
    }

    /// The dense-only retriever returns the single indexed chunk with a dense score.
    #[test]
    fn test_dense_retriever() {
        let embedder = MockEmbedder::new(3);
        let store = VectorStore::with_dimension(3);
        let mut retriever = DenseRetriever::new(store, embedder);

        retriever.index(create_test_chunk("test document", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve("test", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].dense_score.is_some());
    }

    /// An empty sparse retriever returns nothing.
    #[test]
    fn test_sparse_retriever_new() {
        let retriever = SparseRetriever::new();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    /// The sparse retriever finds an indexed chunk by one of its terms.
    #[test]
    fn test_sparse_retriever_index() {
        let mut retriever = SparseRetriever::new();

        retriever.index(create_plain_chunk("machine learning test"));
        let results = retriever.retrieve("machine", 10);

        assert_eq!(results.len(), 1);
        assert!(results[0].sparse_score.is_some());
    }

    /// A term shared by two chunks retrieves both of them.
    #[test]
    fn test_sparse_retriever_multiple() {
        let mut retriever = SparseRetriever::new();

        retriever.index(create_plain_chunk("rust programming language"));
        retriever.index(create_plain_chunk("python programming language"));

        let results = retriever.retrieve("programming", 10);
        assert_eq!(results.len(), 2);
    }

    /// The store accessors compile and expose the underlying stores.
    #[test]
    fn test_hybrid_retriever_store_accessors() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        let _dense_store = retriever.dense_store();
        let _sparse_index = retriever.sparse_index();

        let dense_mut = retriever.dense_store_mut();
        assert!(dense_mut.is_empty());

        // Only checks that the mutable accessor can be called.
        let sparse_mut = retriever.sparse_index_mut();
        let _ = sparse_mut;
    }

    /// `is_empty` flips to `false` once a chunk is indexed.
    #[test]
    fn test_hybrid_retriever_is_empty() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());

        retriever.index(create_test_chunk("test", vec![0.0; 64])).unwrap();
        assert!(!retriever.is_empty());
    }

    /// `SparseRetriever::default` behaves like an empty retriever.
    #[test]
    fn test_sparse_retriever_default() {
        let retriever = SparseRetriever::default();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    /// `best_score` falls back to the sparse score when it is the only one set.
    #[test]
    fn test_retrieval_result_best_score_sparse_fallback() {
        let result = RetrievalResult::new(create_plain_chunk("test")).with_sparse_score(0.75);
        assert!((result.best_score() - 0.75).abs() < 0.001);
    }

    /// With the dense source disabled, no result may carry a dense score.
    #[test]
    fn test_hybrid_retriever_with_dense_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: false,
            use_sparse: true,
        };

        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve("machine", 10).unwrap();
        assert!(results.len() <= 10);
        assert!(results.iter().all(|r| r.dense_score.is_none()));
        assert!(results.iter().all(|r| r.fused_score.is_some()));
    }

    /// With the sparse source disabled, no result may carry a sparse score.
    #[test]
    fn test_hybrid_retriever_with_sparse_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: true,
            use_sparse: false,
        };

        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        retriever.index(create_test_chunk("test content", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve("test", 10).unwrap();
        assert!(results.len() <= 10);
        assert!(results.iter().all(|r| r.sparse_score.is_none()));
        assert!(results.iter().all(|r| r.fused_score.is_some()));
    }

    /// The configuration survives a JSON round trip.
    #[test]
    fn test_hybrid_retriever_config_serialization() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::RRF { k: 60.0 },
            use_dense: true,
            use_sparse: false,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: HybridRetrieverConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.candidates_per_source, deserialized.candidates_per_source);
        assert_eq!(config.use_dense, deserialized.use_dense);
        assert_eq!(config.use_sparse, deserialized.use_sparse);
    }

    /// A result's scores survive a JSON round trip.
    #[test]
    fn test_retrieval_result_serialization() {
        let result = RetrievalResult::new(create_plain_chunk("test content"))
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: RetrievalResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.dense_score, deserialized.dense_score);
        assert_eq!(result.sparse_score, deserialized.sparse_score);
        assert_eq!(result.fused_score, deserialized.fused_score);
        assert_eq!(result.rerank_score, deserialized.rerank_score);
    }

    proptest! {
        // Builder methods preserve arbitrary scores in [0, 1).
        #[test]
        fn prop_retrieval_result_scores_preserved(
            dense in 0.0f32..1.0,
            sparse in 0.0f32..1.0,
            fused in 0.0f32..1.0
        ) {
            let result = RetrievalResult::new(create_plain_chunk("test"))
                .with_dense_score(dense)
                .with_sparse_score(sparse)
                .with_fused_score(fused);

            prop_assert!((result.dense_score.unwrap() - dense).abs() < 0.0001);
            prop_assert!((result.sparse_score.unwrap() - sparse).abs() < 0.0001);
            prop_assert!((result.fused_score.unwrap() - fused).abs() < 0.0001);
        }

        // Hybrid retrieval never returns more than `k` results.
        #[test]
        fn prop_hybrid_retriever_respects_k(k in 1usize..10) {
            let embedder = MockEmbedder::new(3);
            let dense = VectorStore::with_dimension(3);
            let sparse = BM25Index::new();

            let mut retriever = HybridRetriever::new(dense, sparse, embedder);

            for i in 0..20 {
                // One-hot embedding that cycles through the three axes.
                let mut emb = vec![0.0; 3];
                emb[i % 3] = 1.0;
                retriever.index(create_test_chunk(
                    &format!("document number {i} about testing"),
                    emb,
                )).unwrap();
            }

            let results = retriever.retrieve("testing", k).unwrap();
            prop_assert!(results.len() <= k);
        }
    }
}