1use crate::retriever_engine::{Document, EmbeddingModel, RetrievalResult};
13use crate::types::Layer3Result;
14use crate::vector_store::{MetadataFilter, VectorStore};
15use async_trait::async_trait;
16use parking_lot::RwLock;
17use sh_layer2::generate_short_id;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use tracing::instrument;
21
22pub struct BM25Index {
30 documents: Arc<RwLock<HashMap<String, String>>>,
32 term_frequencies: Arc<RwLock<HashMap<String, HashMap<String, usize>>>>,
34 idf_cache: Arc<RwLock<HashMap<String, f64>>>,
36 avg_doc_length: Arc<RwLock<f64>>,
38 doc_count: Arc<RwLock<usize>>,
40 k1: f64,
42 b: f64,
44}
45
46impl BM25Index {
47 pub fn new() -> Self {
49 Self {
50 documents: Arc::new(RwLock::new(HashMap::new())),
51 term_frequencies: Arc::new(RwLock::new(HashMap::new())),
52 idf_cache: Arc::new(RwLock::new(HashMap::new())),
53 avg_doc_length: Arc::new(RwLock::new(0.0)),
54 doc_count: Arc::new(RwLock::new(0)),
55 k1: 1.2,
56 b: 0.75,
57 }
58 }
59
60 pub fn with_params(k1: f64, b: f64) -> Self {
62 Self {
63 documents: Arc::new(RwLock::new(HashMap::new())),
64 term_frequencies: Arc::new(RwLock::new(HashMap::new())),
65 idf_cache: Arc::new(RwLock::new(HashMap::new())),
66 avg_doc_length: Arc::new(RwLock::new(0.0)),
67 doc_count: Arc::new(RwLock::new(0)),
68 k1,
69 b,
70 }
71 }
72
73 pub fn add_document(&self, doc_id: String, content: &str) {
75 let tokens = self.tokenize(content);
76 let mut tf: HashMap<String, usize> = HashMap::new();
77
78 for token in tokens {
79 *tf.entry(token).or_insert(0) += 1;
80 }
81
82 let doc_length = content.split_whitespace().count();
83
84 {
85 let mut documents = self.documents.write();
86 documents.insert(doc_id.clone(), content.to_lowercase());
87 }
88
89 {
90 let mut term_frequencies = self.term_frequencies.write();
91 term_frequencies.insert(doc_id, tf);
92 }
93
94 {
96 let mut avg_len = self.avg_doc_length.write();
97 let mut count = self.doc_count.write();
98
99 let old_count = *count;
100 let old_avg = *avg_len;
101 let new_count = old_count + 1;
102 *avg_len = (old_avg * old_count as f64 + doc_length as f64) / new_count as f64;
103 *count = new_count;
104 }
105
106 self.idf_cache.write().clear();
108 }
109
110 pub fn add_documents(&self, docs: Vec<(String, String)>) {
112 for (doc_id, content) in docs {
113 self.add_document(doc_id, &content);
114 }
115 }
116
117 pub fn remove_document(&self, doc_id: &str) -> bool {
119 let removed = {
120 let mut documents = self.documents.write();
121 documents.remove(doc_id).is_some()
122 };
123
124 if removed {
125 let mut term_frequencies = self.term_frequencies.write();
126 term_frequencies.remove(doc_id);
127
128 {
130 let mut count = self.doc_count.write();
131 if *count > 0 {
132 *count -= 1;
133 }
134 }
135
136 self.idf_cache.write().clear();
138 }
139
140 removed
141 }
142
143 pub fn clear(&self) {
145 self.documents.write().clear();
146 self.term_frequencies.write().clear();
147 self.idf_cache.write().clear();
148 *self.avg_doc_length.write() = 0.0;
149 *self.doc_count.write() = 0;
150 }
151
152 pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
154 let query_tokens = self.tokenize(query);
155
156 if query_tokens.is_empty() {
157 return Vec::new();
158 }
159
160 let documents = self.documents.read();
161 let term_frequencies = self.term_frequencies.read();
162 let avg_doc_length = *self.avg_doc_length.read();
163 let doc_count = *self.doc_count.read();
164
165 if doc_count == 0 {
166 return Vec::new();
167 }
168
169 let mut scores: Vec<(String, f64)> = documents
170 .keys()
171 .filter_map(|doc_id| {
172 let score = self.compute_bm25_score(
173 doc_id,
174 &query_tokens,
175 &term_frequencies,
176 avg_doc_length,
177 doc_count,
178 );
179 if score > 0.0 {
180 Some((doc_id.clone(), score))
181 } else {
182 None
183 }
184 })
185 .collect();
186
187 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
188 scores.truncate(top_k);
189
190 scores
191 }
192
193 fn compute_bm25_score(
195 &self,
196 doc_id: &str,
197 query_tokens: &[String],
198 term_frequencies: &HashMap<String, HashMap<String, usize>>,
199 avg_doc_length: f64,
200 doc_count: usize,
201 ) -> f64 {
202 let doc_tf = match term_frequencies.get(doc_id) {
203 Some(tf) => tf,
204 None => return 0.0,
205 };
206
207 let documents = self.documents.read();
208 let doc_content = match documents.get(doc_id) {
209 Some(content) => content,
210 None => return 0.0,
211 };
212
213 let doc_length = doc_content.split_whitespace().count() as f64;
214 let mut idf_cache = self.idf_cache.write();
215
216 let mut score = 0.0;
217
218 for token in query_tokens {
219 let tf = *doc_tf.get(token).unwrap_or(&0) as f64;
220
221 if tf == 0.0 {
222 continue;
223 }
224
225 let idf = *idf_cache.entry(token.clone()).or_insert_with(|| {
227 let df = self.compute_document_frequency(token);
228 let n = doc_count as f64;
229 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
230 });
231
232 let numerator = tf * (self.k1 + 1.0);
234 let denominator =
235 tf + self.k1 * (1.0 - self.b + self.b * (doc_length / avg_doc_length));
236
237 score += idf * (numerator / denominator);
238 }
239
240 score
241 }
242
243 fn compute_document_frequency(&self, term: &str) -> f64 {
245 let term_frequencies = self.term_frequencies.read();
246 term_frequencies
247 .values()
248 .filter(|tf| tf.contains_key(term))
249 .count() as f64
250 }
251
252 fn tokenize(&self, text: &str) -> Vec<String> {
254 let stop_words: HashSet<&str> = [
255 "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
256 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
257 "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
258 "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
259 "below", "between", "under", "again", "further", "then", "once", "here", "there",
260 "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
261 "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s",
262 "t", "just", "and", "but", "if", "or", "because", "until", "while", "although",
263 ]
264 .iter()
265 .cloned()
266 .collect();
267
268 text.to_lowercase()
269 .split_whitespace()
270 .filter(|w| !stop_words.contains(*w) && w.len() > 1)
271 .map(|s| s.to_string())
272 .collect()
273 }
274
275 pub fn doc_count(&self) -> usize {
277 *self.doc_count.read()
278 }
279}
280
281impl Default for BM25Index {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
/// Merges several ranked result lists with Reciprocal Rank Fusion (RRF).
pub struct ReciprocalRankFusion {
    /// Constant added to each 1-based rank before taking the reciprocal.
    k: f64,
}
298
299impl ReciprocalRankFusion {
300 pub fn new(k: f64) -> Self {
302 Self { k }
303 }
304
305 pub fn default_fusion() -> Self {
307 Self::new(60.0)
308 }
309
310 pub fn fuse(&self, result_lists: &[Vec<(String, f64)>], top_k: usize) -> Vec<(String, f64)> {
319 let mut rrf_scores: HashMap<String, f64> = HashMap::new();
320
321 for results in result_lists {
322 for (rank, (doc_id, _original_score)) in results.iter().enumerate() {
323 let rrf_score = 1.0 / (self.k + (rank + 1) as f64);
324 *rrf_scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
325 }
326 }
327
328 let mut fused: Vec<(String, f64)> = rrf_scores.into_iter().collect();
329 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
330 fused.truncate(top_k);
331
332 fused
333 }
334
335 pub fn fuse_with_weights(
342 &self,
343 result_lists: &[Vec<(String, f64)>],
344 weights: &[f64],
345 top_k: usize,
346 ) -> Vec<(String, f64)> {
347 if result_lists.len() != weights.len() {
348 panic!("Result lists and weights must have the same length");
349 }
350
351 let mut combined_scores: HashMap<String, f64> = HashMap::new();
352
353 for (results, weight) in result_lists.iter().zip(weights.iter()) {
354 for (rank, (doc_id, original_score)) in results.iter().enumerate() {
355 let rrf_score = 1.0 / (self.k + (rank + 1) as f64);
356 let weighted_score = (rrf_score + original_score * 0.1) * weight;
357 *combined_scores.entry(doc_id.clone()).or_insert(0.0) += weighted_score;
358 }
359 }
360
361 let mut fused: Vec<(String, f64)> = combined_scores.into_iter().collect();
362 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
363 fused.truncate(top_k);
364
365 fused
366 }
367}
368
369impl Default for ReciprocalRankFusion {
370 fn default() -> Self {
371 Self::default_fusion()
372 }
373}
374
/// Tuning knobs for a hybrid (vector + BM25) retrieval call.
#[derive(Debug, Clone)]
pub struct HybridRetrieverConfig {
    /// Weight applied to the vector-search result list.
    pub vector_weight: f64,
    /// Weight applied to the BM25 result list.
    pub bm25_weight: f64,
    /// Rank constant handed to `ReciprocalRankFusion` when `use_rrf` is set.
    pub rrf_k: f64,
    /// Whether result lists are merged with reciprocal rank fusion.
    pub use_rrf: bool,
    /// Factor by which `top_k` is multiplied to size the candidate pool.
    pub candidate_multiplier: usize,
    /// Minimum score a result must reach; `0.0` means no threshold.
    pub min_score_threshold: f64,
}
395
396impl HybridRetrieverConfig {
397 pub fn new() -> Self {
399 Self {
400 vector_weight: 0.7,
401 bm25_weight: 0.3,
402 rrf_k: 60.0,
403 use_rrf: true,
404 candidate_multiplier: 2,
405 min_score_threshold: 0.0,
406 }
407 }
408
409 pub fn vector_only() -> Self {
411 Self {
412 vector_weight: 1.0,
413 bm25_weight: 0.0,
414 ..Self::new()
415 }
416 }
417
418 pub fn bm25_only() -> Self {
420 Self {
421 vector_weight: 0.0,
422 bm25_weight: 1.0,
423 ..Self::new()
424 }
425 }
426
427 pub fn balanced() -> Self {
429 Self {
430 vector_weight: 0.5,
431 bm25_weight: 0.5,
432 ..Self::new()
433 }
434 }
435
436 pub fn with_weights(mut self, vector: f64, bm25: f64) -> Self {
438 let total = vector + bm25;
439 self.vector_weight = vector / total;
440 self.bm25_weight = bm25 / total;
441 self
442 }
443
444 pub fn with_rrf(mut self, enabled: bool, k: f64) -> Self {
446 self.use_rrf = enabled;
447 self.rrf_k = k;
448 self
449 }
450
451 pub fn with_candidate_multiplier(mut self, multiplier: usize) -> Self {
453 self.candidate_multiplier = multiplier;
454 self
455 }
456
457 pub fn with_min_score(mut self, threshold: f64) -> Self {
459 self.min_score_threshold = threshold;
460 self
461 }
462
463 pub fn normalize_weights(&mut self) {
465 let total = self.vector_weight + self.bm25_weight;
466 if total > 0.0 {
467 self.vector_weight /= total;
468 self.bm25_weight /= total;
469 }
470 }
471}
472
473impl Default for HybridRetrieverConfig {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
479#[async_trait]
485pub trait HybridRetriever: Send + Sync {
486 async fn index_documents(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>>;
488
489 async fn retrieve(
491 &self,
492 query: &str,
493 top_k: usize,
494 config: Option<&HybridRetrieverConfig>,
495 ) -> Layer3Result<Vec<RetrievalResult>>;
496
497 async fn retrieve_with_filter(
499 &self,
500 query: &str,
501 top_k: usize,
502 filter: Option<MetadataFilter>,
503 config: Option<&HybridRetrieverConfig>,
504 ) -> Layer3Result<Vec<RetrievalResult>>;
505
506 async fn delete_documents(&self, doc_ids: &[String]) -> Layer3Result<bool>;
508
509 async fn clear(&self) -> Layer3Result<bool>;
511
512 async fn count(&self) -> Layer3Result<usize>;
514}
515
516type DocCacheEntry = (String, HashMap<String, serde_json::Value>);
522
523pub struct DefaultHybridRetriever<VS, EM>
527where
528 VS: VectorStore,
529 EM: EmbeddingModel,
530{
531 vector_store: VS,
533 embedding_model: EM,
535 bm25_index: BM25Index,
537 doc_cache: Arc<RwLock<HashMap<String, DocCacheEntry>>>,
539 default_config: HybridRetrieverConfig,
541}
542
543impl<VS, EM> DefaultHybridRetriever<VS, EM>
544where
545 VS: VectorStore,
546 EM: EmbeddingModel,
547{
548 pub fn new(vector_store: VS, embedding_model: EM) -> Self {
550 Self {
551 vector_store,
552 embedding_model,
553 bm25_index: BM25Index::new(),
554 doc_cache: Arc::new(RwLock::new(HashMap::new())),
555 default_config: HybridRetrieverConfig::new(),
556 }
557 }
558
559 pub fn with_config(
561 vector_store: VS,
562 embedding_model: EM,
563 config: HybridRetrieverConfig,
564 ) -> Self {
565 Self {
566 vector_store,
567 embedding_model,
568 bm25_index: BM25Index::new(),
569 doc_cache: Arc::new(RwLock::new(HashMap::new())),
570 default_config: config,
571 }
572 }
573
574 #[instrument(skip(self))]
576 async fn vector_search(&self, query: &str, top_k: usize) -> Layer3Result<Vec<(String, f64)>> {
577 let query_embedding = self.embedding_model.embed(query).await?;
578 let results = self.vector_store.query(query_embedding, top_k).await?;
579
580 Ok(results
581 .into_iter()
582 .map(|r| (r.doc_id, r.score as f64))
583 .collect())
584 }
585
586 #[instrument(skip(self))]
588 fn bm25_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
589 self.bm25_index.search(query, top_k)
590 }
591
592 fn get_document_content(
594 &self,
595 doc_id: &str,
596 ) -> Option<(String, HashMap<String, serde_json::Value>)> {
597 self.doc_cache.read().get(doc_id).cloned()
598 }
599
600 #[allow(dead_code)]
602 fn apply_threshold(&self, results: Vec<(String, f64)>, threshold: f64) -> Vec<(String, f64)> {
603 results
604 .into_iter()
605 .filter(|(_, score)| *score >= threshold)
606 .collect()
607 }
608}
609
610#[async_trait]
611impl<VS, EM> HybridRetriever for DefaultHybridRetriever<VS, EM>
612where
613 VS: VectorStore,
614 EM: EmbeddingModel,
615{
616 #[instrument(skip(self, documents))]
617 async fn index_documents(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>> {
618 use crate::vector_store::VectorItem;
619
620 let mut doc_ids = Vec::new();
621 let mut vector_items = Vec::new();
622 let mut bm25_docs = Vec::new();
623
624 for doc in documents {
625 let doc_id = doc.id.unwrap_or_else(generate_short_id);
626
627 {
629 let mut cache = self.doc_cache.write();
630 cache.insert(doc_id.clone(), (doc.content.clone(), doc.metadata.clone()));
631 }
632
633 bm25_docs.push((doc_id.clone(), doc.content.clone()));
635
636 let embedding = self.embedding_model.embed(&doc.content).await?;
638
639 let mut metadata = doc.metadata.clone();
640 if let Some(source) = doc.source {
641 metadata.insert("source".to_string(), serde_json::json!(source));
642 }
643
644 vector_items.push(VectorItem {
645 id: doc_id.clone(),
646 vector: embedding,
647 metadata,
648 content: Some(doc.content),
649 });
650
651 doc_ids.push(doc_id);
652 }
653
654 self.bm25_index.add_documents(bm25_docs);
656
657 self.vector_store.add_batch(vector_items).await?;
659
660 Ok(doc_ids)
661 }
662
663 #[instrument(skip(self))]
664 async fn retrieve(
665 &self,
666 query: &str,
667 top_k: usize,
668 config: Option<&HybridRetrieverConfig>,
669 ) -> Layer3Result<Vec<RetrievalResult>> {
670 let config = config.unwrap_or(&self.default_config);
671
672 let candidates = top_k * config.candidate_multiplier;
674
675 let mut result_lists: Vec<Vec<(String, f64)>> = Vec::new();
677 let mut weights: Vec<f64> = Vec::new();
678
679 if config.vector_weight > 0.0 {
681 let vector_results = self.vector_search(query, candidates).await?;
682 result_lists.push(vector_results);
683 weights.push(config.vector_weight);
684 }
685
686 if config.bm25_weight > 0.0 {
688 let bm25_results = self.bm25_search(query, candidates);
689 result_lists.push(bm25_results);
690 weights.push(config.bm25_weight);
691 }
692
693 if result_lists.len() == 1 {
695 let results = result_lists.remove(0);
696 let final_results: Vec<RetrievalResult> = results
697 .into_iter()
698 .take(top_k)
699 .filter_map(|(doc_id, score)| {
700 let (content, metadata) = self.get_document_content(&doc_id)?;
701 let source = metadata
702 .get("source")
703 .and_then(|v| v.as_str())
704 .map(String::from);
705 Some(RetrievalResult {
706 doc_id,
707 content,
708 score: score as f32,
709 metadata,
710 source,
711 })
712 })
713 .collect();
714
715 return Ok(final_results);
716 }
717
718 let fused_results = if config.use_rrf {
720 let rrf = ReciprocalRankFusion::new(config.rrf_k);
721 rrf.fuse_with_weights(&result_lists, &weights, top_k)
722 } else {
723 let mut combined: HashMap<String, f64> = HashMap::new();
725 for (results, weight) in result_lists.iter().zip(weights.iter()) {
726 for (doc_id, score) in results {
727 *combined.entry(doc_id.clone()).or_insert(0.0) += score * weight;
728 }
729 }
730 let mut fused: Vec<(String, f64)> = combined.into_iter().collect();
731 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
732 fused.truncate(top_k);
733 fused
734 };
735
736 let final_results: Vec<RetrievalResult> = fused_results
738 .into_iter()
739 .filter_map(|(doc_id, score)| {
740 let (content, metadata) = self.get_document_content(&doc_id)?;
741 let source = metadata
742 .get("source")
743 .and_then(|v| v.as_str())
744 .map(String::from);
745 Some(RetrievalResult {
746 doc_id,
747 content,
748 score: score as f32,
749 metadata,
750 source,
751 })
752 })
753 .collect();
754
755 Ok(final_results)
756 }
757
758 async fn retrieve_with_filter(
759 &self,
760 query: &str,
761 top_k: usize,
762 filter: Option<MetadataFilter>,
763 config: Option<&HybridRetrieverConfig>,
764 ) -> Layer3Result<Vec<RetrievalResult>> {
765 let config = config.unwrap_or(&self.default_config);
766 let candidates = top_k * config.candidate_multiplier * 2;
767
768 let mut results = self.retrieve(query, candidates, Some(config)).await?;
770
771 if let Some(f) = filter {
773 results.retain(|r| {
774 f.must
776 .iter()
777 .all(|(key, value)| r.metadata.get(key) == Some(value))
778 });
779 }
780
781 if config.min_score_threshold > 0.0 {
783 results.retain(|r| r.score >= config.min_score_threshold as f32);
784 }
785
786 results.truncate(top_k);
787 Ok(results)
788 }
789
790 async fn delete_documents(&self, doc_ids: &[String]) -> Layer3Result<bool> {
791 self.vector_store.delete_batch(doc_ids).await?;
793
794 for doc_id in doc_ids {
796 self.bm25_index.remove_document(doc_id);
797 }
798
799 {
801 let mut cache = self.doc_cache.write();
802 for doc_id in doc_ids {
803 cache.remove(doc_id);
804 }
805 }
806
807 Ok(true)
808 }
809
810 async fn clear(&self) -> Layer3Result<bool> {
811 self.vector_store.clear().await?;
812 self.bm25_index.clear();
813 self.doc_cache.write().clear();
814 Ok(true)
815 }
816
817 async fn count(&self) -> Layer3Result<usize> {
818 Ok(self.bm25_index.doc_count())
819 }
820}
821
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retriever_engine::Layer1EmbeddingAdapter;
    use crate::vector_store::InMemoryVectorStore;

    /// Embedding dimension used by every retriever test.
    const DIMENSION: usize = 128;

    /// Wraps the layer-1 mock embedding model in the layer-3 adapter.
    fn create_mock_embedding_model(dimension: usize) -> Layer1EmbeddingAdapter {
        let mock = sh_layer1::MockEmbeddingModel::new(dimension);
        Layer1EmbeddingAdapter::new(Box::new(mock))
    }

    /// Builds a retriever over an in-memory store and the mock embedder.
    fn build_retriever() -> DefaultHybridRetriever<impl VectorStore, impl EmbeddingModel> {
        DefaultHybridRetriever::new(
            InMemoryVectorStore::in_memory(),
            create_mock_embedding_model(DIMENSION),
        )
    }

    #[test]
    fn test_bm25_index_basic() {
        let index = BM25Index::new();
        let corpus = [
            ("doc1", "Rust is a systems programming language"),
            ("doc2", "Python is used for data science"),
            ("doc3", "JavaScript runs in the browser"),
        ];
        for (id, text) in corpus.iter() {
            index.add_document(id.to_string(), text);
        }

        let hits = index.search("Rust programming", 5);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "doc1");
    }

    #[test]
    fn test_bm25_index_scoring() {
        let index = BM25Index::new();
        let corpus = [
            ("doc1", "machine learning algorithms"),
            ("doc2", "deep learning neural networks"),
            ("doc3", "database systems"),
        ];
        for (id, text) in corpus.iter() {
            index.add_document(id.to_string(), text);
        }

        let hits = index.search("machine learning", 3);
        assert!(!hits.is_empty());

        let found_doc1 = hits.iter().any(|(id, _)| id == "doc1");
        assert!(found_doc1);
    }

    #[test]
    fn test_bm25_remove_document() {
        let index = BM25Index::new();

        index.add_document("doc1".to_string(), "test document");
        assert_eq!(index.doc_count(), 1);

        assert!(index.remove_document("doc1"));
        assert_eq!(index.doc_count(), 0);

        assert!(!index.remove_document("nonexistent"));
    }

    #[test]
    fn test_rrf_fusion() {
        let rrf = ReciprocalRankFusion::default_fusion();

        let first = vec![
            ("doc1".to_string(), 0.9),
            ("doc2".to_string(), 0.8),
            ("doc3".to_string(), 0.7),
        ];
        let second = vec![
            ("doc3".to_string(), 0.95),
            ("doc1".to_string(), 0.85),
            ("doc4".to_string(), 0.75),
        ];

        let fused = rrf.fuse(&[first, second], 5);

        assert!(!fused.is_empty());
        let top_two_has_shared_doc = fused
            .iter()
            .take(2)
            .any(|(id, _)| id == "doc1" || id == "doc3");
        assert!(top_two_has_shared_doc);
    }

    #[test]
    fn test_rrf_with_weights() {
        let rrf = ReciprocalRankFusion::new(60.0);

        let first = vec![("doc1".to_string(), 0.9)];
        let second = vec![("doc2".to_string(), 0.9)];

        let fused = rrf.fuse_with_weights(&[first, second], &[0.7, 0.3], 5);
        assert!(!fused.is_empty());
    }

    #[test]
    fn test_hybrid_retriever_config() {
        let defaults = HybridRetrieverConfig::new();
        assert_eq!(defaults.vector_weight, 0.7);
        assert_eq!(defaults.bm25_weight, 0.3);
        assert!(defaults.use_rrf);

        let vector_only = HybridRetrieverConfig::vector_only();
        assert_eq!(vector_only.vector_weight, 1.0);
        assert_eq!(vector_only.bm25_weight, 0.0);

        let balanced = HybridRetrieverConfig::balanced();
        assert_eq!(balanced.vector_weight, 0.5);
        assert_eq!(balanced.bm25_weight, 0.5);

        let custom = HybridRetrieverConfig::new().with_weights(0.8, 0.2);
        assert!((custom.vector_weight - 0.8).abs() < 0.001);
        assert!((custom.bm25_weight - 0.2).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_hybrid_retriever_index_and_search() {
        let retriever = build_retriever();

        let docs = vec![
            Document::new("Rust is a systems programming language"),
            Document::new("Python is widely used for data science"),
            Document::new("JavaScript runs in the browser"),
        ];

        let doc_ids = retriever.index_documents(docs).await.unwrap();
        assert_eq!(doc_ids.len(), 3);

        let results = retriever
            .retrieve("Rust programming", 5, None)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_hybrid_retriever_with_config() {
        let retriever = build_retriever();

        retriever
            .index_documents(vec![
                Document::new("Machine learning algorithms use neural networks"),
                Document::new("Database stores data for applications"),
            ])
            .await
            .unwrap();

        // Vector search only.
        let vector_only = HybridRetrieverConfig::vector_only();
        let results = retriever
            .retrieve("neural networks", 5, Some(&vector_only))
            .await
            .unwrap();
        assert!(!results.is_empty());

        // BM25 search only.
        let bm25_only = HybridRetrieverConfig::bm25_only();
        let results = retriever
            .retrieve("machine learning", 5, Some(&bm25_only))
            .await
            .unwrap();
        assert!(!results.is_empty());

        // Both searches, fused with RRF.
        let fused = HybridRetrieverConfig::balanced().with_rrf(true, 60.0);
        let results = retriever
            .retrieve("database", 5, Some(&fused))
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_hybrid_retriever_delete_and_count() {
        let retriever = build_retriever();

        let doc_ids = retriever
            .index_documents(vec![Document::new("Test document")])
            .await
            .unwrap();

        assert_eq!(retriever.count().await.unwrap(), 1);

        retriever.delete_documents(&doc_ids).await.unwrap();
        assert_eq!(retriever.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_hybrid_retriever_clear() {
        let retriever = build_retriever();

        retriever
            .index_documents(vec![Document::new("Doc 1"), Document::new("Doc 2")])
            .await
            .unwrap();

        assert_eq!(retriever.count().await.unwrap(), 2);

        retriever.clear().await.unwrap();
        assert_eq!(retriever.count().await.unwrap(), 0);
    }
}