1use crate::types::Layer3Result;
14use async_trait::async_trait;
15use parking_lot::RwLock;
16use sh_layer2::generate_short_id;
17use std::collections::HashMap;
18use std::sync::Arc;
19
20#[async_trait]
24pub trait RetrieverEngine: Send + Sync {
25 async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>>;
27
28 async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>>;
30
31 async fn hybrid_retrieve(
33 &self,
34 query: &str,
35 top_k: usize,
36 ) -> Layer3Result<Vec<RetrievalResult>>;
37
38 async fn hybrid_retrieve_with_config(
40 &self,
41 query: &str,
42 top_k: usize,
43 config: &HybridSearchConfig,
44 ) -> Layer3Result<Vec<RetrievalResult>> {
45 let _ = config;
46 self.hybrid_retrieve(query, top_k).await
47 }
48
49 async fn retrieve_with_filter(
51 &self,
52 query: &str,
53 top_k: usize,
54 filter: Option<crate::vector_store::MetadataFilter>,
55 ) -> Layer3Result<Vec<RetrievalResult>> {
56 let _ = filter;
57 self.retrieve(query, top_k).await
58 }
59
60 async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool>;
62
63 async fn clear(&self) -> Layer3Result<bool>;
65
66 async fn count(&self) -> Layer3Result<usize>;
68}
69
70#[derive(Debug, Clone)]
72pub struct Document {
73 pub id: Option<String>,
75 pub content: String,
77 pub metadata: HashMap<String, serde_json::Value>,
79 pub source: Option<String>,
81}
82
83impl Document {
84 pub fn new(content: impl Into<String>) -> Self {
85 Self {
86 id: None,
87 content: content.into(),
88 metadata: HashMap::new(),
89 source: None,
90 }
91 }
92
93 pub fn with_source(mut self, source: impl Into<String>) -> Self {
94 self.source = Some(source.into());
95 self
96 }
97
98 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
99 self.metadata.insert(key.into(), value);
100 self
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct RetrievalResult {
107 pub doc_id: String,
109 pub content: String,
111 pub score: f32,
113 pub metadata: HashMap<String, serde_json::Value>,
115 pub source: Option<String>,
117}
118
/// Relative contribution of vector similarity and keyword matching to a hybrid score.
#[derive(Debug, Clone, Copy)]
pub struct HybridWeights {
    /// Weight applied to the vector-similarity score.
    pub vector: f32,
    /// Weight applied to the keyword score.
    pub keyword: f32,
}

impl HybridWeights {
    /// Builds weights normalized so that `vector + keyword == 1`.
    ///
    /// Negative or non-finite inputs are treated as `0.0`. If no usable weight
    /// remains (for example both inputs are zero), normalization is impossible
    /// and [`HybridWeights::default_weights`] is returned instead of NaN values.
    pub fn new(vector: f32, keyword: f32) -> Self {
        // A weight only counts when it is a positive, finite number.
        let sanitize = |w: f32| if w.is_finite() && w > 0.0 { w } else { 0.0 };
        let vector = sanitize(vector);
        let keyword = sanitize(keyword);

        let total = vector + keyword;
        // `total` can still be zero (nothing usable) or infinite (overflow).
        if total > 0.0 && total.is_finite() {
            Self {
                vector: vector / total,
                keyword: keyword / total,
            }
        } else {
            Self::default_weights()
        }
    }

    /// Default split: 0.7 vector, 0.3 keyword.
    pub fn default_weights() -> Self {
        Self {
            vector: 0.7,
            keyword: 0.3,
        }
    }

    /// Only the vector score contributes.
    pub fn vector_only() -> Self {
        Self {
            vector: 1.0,
            keyword: 0.0,
        }
    }

    /// Only the keyword score contributes.
    pub fn keyword_only() -> Self {
        Self {
            vector: 0.0,
            keyword: 1.0,
        }
    }

    /// Both scores contribute equally.
    pub fn balanced() -> Self {
        Self {
            vector: 0.5,
            keyword: 0.5,
        }
    }
}

impl Default for HybridWeights {
    /// Same as [`HybridWeights::default_weights`].
    fn default() -> Self {
        Self::default_weights()
    }
}
180
181#[derive(Debug, Clone)]
183pub struct HybridSearchConfig {
184 pub weights: HybridWeights,
186 pub phrase_matching: bool,
188 pub use_rrif: bool,
190 pub rrif_k: f32,
192 pub candidates_multiplier: usize,
194}
195
196impl HybridSearchConfig {
197 pub fn new() -> Self {
198 Self {
199 weights: HybridWeights::default(),
200 phrase_matching: true,
201 use_rrif: true,
202 rrif_k: 60.0,
203 candidates_multiplier: 2,
204 }
205 }
206
207 pub fn with_weights(mut self, weights: HybridWeights) -> Self {
208 self.weights = weights;
209 self
210 }
211
212 pub fn with_phrase_matching(mut self, enabled: bool) -> Self {
213 self.phrase_matching = enabled;
214 self
215 }
216
217 pub fn with_rrif(mut self, enabled: bool, k: f32) -> Self {
218 self.use_rrif = enabled;
219 self.rrif_k = k;
220 self
221 }
222}
223
224impl Default for HybridSearchConfig {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230#[async_trait]
232pub trait EmbeddingModel: Send + Sync {
233 async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>>;
235
236 async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>>;
238
239 fn dimension(&self) -> usize;
241
242 fn model_name(&self) -> &str;
244}
245
246pub trait ChunkingStrategy: Send + Sync {
248 fn chunk(&self, document: &Document) -> Vec<Chunk>;
250}
251
252#[derive(Debug, Clone)]
254pub struct Chunk {
255 pub id: String,
257 pub doc_id: String,
259 pub content: String,
261 pub position: ChunkPosition,
263 pub metadata: HashMap<String, serde_json::Value>,
265}
266
/// Location of a chunk relative to the document it was cut from.
#[derive(Debug, Clone, Copy)]
pub struct ChunkPosition {
    /// Offset at which the chunk begins.
    pub start: usize,
    /// Offset at which the chunk ends.
    pub end: usize,
    /// Ordinal of the chunk within its document.
    pub index: usize,
    /// Number of chunks produced for the document.
    pub total: usize,
}
279
/// Configuration for chunking by fixed-size windows with overlap.
#[derive(Debug, Clone)]
pub struct FixedSizeChunker {
    /// Size of each window.
    pub chunk_size: usize,
    /// Amount by which consecutive windows overlap.
    pub overlap: usize,
}

impl FixedSizeChunker {
    /// Creates a chunker with the given window size and overlap.
    pub fn new(chunk_size: usize, overlap: usize) -> Self {
        FixedSizeChunker {
            overlap,
            chunk_size,
        }
    }
}

impl Default for FixedSizeChunker {
    /// Window size 500 with an overlap of 50.
    fn default() -> Self {
        Self::new(500, 50)
    }
}
306
307impl ChunkingStrategy for FixedSizeChunker {
308 fn chunk(&self, document: &Document) -> Vec<Chunk> {
309 let content = &document.content;
310 if content.len() <= self.chunk_size {
311 return vec![Chunk {
312 id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
313 doc_id: document.id.clone().unwrap_or_default(),
314 content: content.clone(),
315 position: ChunkPosition {
316 start: 0,
317 end: content.len(),
318 index: 0,
319 total: 1,
320 },
321 metadata: document.metadata.clone(),
322 }];
323 }
324
325 let mut chunks = Vec::new();
326 let mut start = 0;
327 let mut index = 0;
328
329 while start < content.len() {
330 let end = (start + self.chunk_size).min(content.len());
331 chunks.push(Chunk {
332 id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
333 doc_id: document.id.clone().unwrap_or_default(),
334 content: content[start..end].to_string(),
335 position: ChunkPosition {
336 start,
337 end,
338 index,
339 total: 0, },
341 metadata: document.metadata.clone(),
342 });
343 start = if end < content.len() {
345 end.saturating_sub(self.overlap)
346 } else {
347 end
348 };
349 index += 1;
350 }
351
352 let total = chunks.len();
353 for chunk in &mut chunks {
354 chunk.position.total = total;
355 }
356
357 chunks
358 }
359}
360
/// Configuration for chunking along paragraph boundaries.
#[derive(Debug, Clone)]
pub struct ParagraphChunker {
    /// Size limit used when deciding whether another paragraph still fits.
    max_chunk_size: usize,
    /// Size a chunk must reach before it is emitted.
    min_chunk_size: usize,
}

impl ParagraphChunker {
    /// Creates a chunker with the given upper and lower size limits.
    pub fn new(max_chunk_size: usize, min_chunk_size: usize) -> Self {
        ParagraphChunker {
            min_chunk_size,
            max_chunk_size,
        }
    }
}

impl Default for ParagraphChunker {
    /// Upper limit 1000, lower limit 100.
    fn default() -> Self {
        Self::new(1000, 100)
    }
}
391
392impl ChunkingStrategy for ParagraphChunker {
393 fn chunk(&self, document: &Document) -> Vec<Chunk> {
394 let content = &document.content;
395 let paragraphs: Vec<&str> = content
396 .split('\n')
397 .filter(|p| !p.trim().is_empty())
398 .collect();
399
400 if paragraphs.is_empty() {
401 return vec![Chunk {
402 id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
403 doc_id: document.id.clone().unwrap_or_default(),
404 content: content.clone(),
405 position: ChunkPosition {
406 start: 0,
407 end: content.len(),
408 index: 0,
409 total: 1,
410 },
411 metadata: document.metadata.clone(),
412 }];
413 }
414
415 let mut chunks = Vec::new();
416 let mut current_chunk = String::new();
417 let mut start = 0;
418 let mut index = 0;
419
420 for paragraph in paragraphs {
421 if current_chunk.len() + paragraph.len() < self.max_chunk_size {
422 if !current_chunk.is_empty() {
423 current_chunk.push('\n');
424 }
425 current_chunk.push_str(paragraph);
426 } else {
427 if current_chunk.len() >= self.min_chunk_size {
428 let end = start + current_chunk.len();
429 chunks.push(Chunk {
430 id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
431 doc_id: document.id.clone().unwrap_or_default(),
432 content: current_chunk.clone(),
433 position: ChunkPosition {
434 start,
435 end,
436 index,
437 total: 0,
438 },
439 metadata: document.metadata.clone(),
440 });
441 start = end;
442 index += 1;
443 }
444 current_chunk = paragraph.to_string();
445 }
446 }
447
448 if current_chunk.len() >= self.min_chunk_size {
449 chunks.push(Chunk {
450 id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
451 doc_id: document.id.clone().unwrap_or_default(),
452 content: current_chunk,
453 position: ChunkPosition {
454 start,
455 end: content.len(),
456 index,
457 total: 0,
458 },
459 metadata: document.metadata.clone(),
460 });
461 }
462
463 let total = chunks.len().max(1);
464 for chunk in &mut chunks {
465 chunk.position.total = total;
466 }
467
468 if chunks.is_empty() {
469 vec![Chunk {
470 id: format!("{}-0", document.id.as_deref().unwrap_or("doc")),
471 doc_id: document.id.clone().unwrap_or_default(),
472 content: content.clone(),
473 position: ChunkPosition {
474 start: 0,
475 end: content.len(),
476 index: 0,
477 total: 1,
478 },
479 metadata: document.metadata.clone(),
480 }]
481 } else {
482 chunks
483 }
484 }
485}
486
/// Configuration for chunking by progressively finer separators.
#[derive(Debug, Clone)]
pub struct RecursiveChunker {
    /// Size limit for a chunk.
    max_chunk_size: usize,
    /// Separators to try, in order; the empty string comes last.
    separators: Vec<String>,
}

impl RecursiveChunker {
    /// Creates a chunker with the given size limit and the built-in separator
    /// list: triple newline, double newline, newline, `". "`, space, and the
    /// empty string.
    pub fn new(max_chunk_size: usize) -> Self {
        const DEFAULT_SEPARATORS: [&str; 6] = ["\n\n\n", "\n\n", "\n", ". ", " ", ""];

        let separators = DEFAULT_SEPARATORS
            .iter()
            .map(|sep| sep.to_string())
            .collect();

        Self {
            max_chunk_size,
            separators,
        }
    }
}

impl Default for RecursiveChunker {
    /// Size limit of 1000.
    fn default() -> Self {
        RecursiveChunker::new(1000)
    }
}
521
522impl ChunkingStrategy for RecursiveChunker {
523 fn chunk(&self, document: &Document) -> Vec<Chunk> {
524 self._recursive_split(document, &document.content, 0, 0)
525 }
526}
527
528impl RecursiveChunker {
529 fn _recursive_split(
530 &self,
531 document: &Document,
532 text: &str,
533 start_offset: usize,
534 initial_index: usize,
535 ) -> Vec<Chunk> {
536 if text.len() <= self.max_chunk_size {
537 return vec![Chunk {
538 id: format!(
539 "{}-{}",
540 document.id.as_deref().unwrap_or("doc"),
541 initial_index
542 ),
543 doc_id: document.id.clone().unwrap_or_default(),
544 content: text.to_string(),
545 position: ChunkPosition {
546 start: start_offset,
547 end: start_offset + text.len(),
548 index: initial_index,
549 total: 1,
550 },
551 metadata: document.metadata.clone(),
552 }];
553 }
554
555 for separator in &self.separators {
556 if separator.is_empty() {
557 let mut chunks = Vec::new();
558 let mut start = 0;
559 let mut index = initial_index;
560
561 while start < text.len() {
562 let end = (start + self.max_chunk_size).min(text.len());
563 chunks.push(Chunk {
564 id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
565 doc_id: document.id.clone().unwrap_or_default(),
566 content: text[start..end].to_string(),
567 position: ChunkPosition {
568 start: start_offset + start,
569 end: start_offset + end,
570 index,
571 total: 0,
572 },
573 metadata: document.metadata.clone(),
574 });
575 start = end;
576 index += 1;
577 }
578
579 let total = chunks.len();
580 for chunk in &mut chunks {
581 chunk.position.total = total;
582 }
583 return chunks;
584 }
585
586 if text.contains(separator) {
587 let parts: Vec<&str> = text.split(separator).collect();
588 let mut chunks = Vec::new();
589 let mut current_chunk = String::new();
590 let mut current_start = start_offset;
591 let mut index = initial_index;
592
593 for (i, part) in parts.iter().enumerate() {
594 let part_with_sep = if i < parts.len() - 1 {
595 format!("{}{}", part, separator)
596 } else {
597 part.to_string()
598 };
599
600 if current_chunk.len() + part_with_sep.len() <= self.max_chunk_size {
601 current_chunk.push_str(&part_with_sep);
602 } else {
603 if !current_chunk.is_empty() {
604 chunks.push(Chunk {
605 id: format!(
606 "{}-{}",
607 document.id.as_deref().unwrap_or("doc"),
608 index
609 ),
610 doc_id: document.id.clone().unwrap_or_default(),
611 content: current_chunk.clone(),
612 position: ChunkPosition {
613 start: current_start,
614 end: current_start + current_chunk.len(),
615 index,
616 total: 0,
617 },
618 metadata: document.metadata.clone(),
619 });
620 current_start += current_chunk.len();
621 index += 1;
622 }
623
624 if part_with_sep.len() > self.max_chunk_size {
625 let sub_chunks = self._recursive_split(
626 document,
627 &part_with_sep,
628 current_start,
629 index,
630 );
631 for sub in sub_chunks {
632 current_start = sub.position.end;
633 index += 1;
634 chunks.push(sub);
635 }
636 } else {
637 current_chunk = part_with_sep;
638 }
639 }
640 }
641
642 if !current_chunk.is_empty() {
643 chunks.push(Chunk {
644 id: format!("{}-{}", document.id.as_deref().unwrap_or("doc"), index),
645 doc_id: document.id.clone().unwrap_or_default(),
646 content: current_chunk,
647 position: ChunkPosition {
648 start: current_start,
649 end: start_offset + text.len(),
650 index,
651 total: 0,
652 },
653 metadata: document.metadata.clone(),
654 });
655 }
656
657 let total = chunks.len().max(1);
658 for chunk in &mut chunks {
659 chunk.position.total = total;
660 }
661 return chunks;
662 }
663 }
664
665 vec![Chunk {
666 id: format!(
667 "{}-{}",
668 document.id.as_deref().unwrap_or("doc"),
669 initial_index
670 ),
671 doc_id: document.id.clone().unwrap_or_default(),
672 content: text.to_string(),
673 position: ChunkPosition {
674 start: start_offset,
675 end: start_offset + text.len(),
676 index: initial_index,
677 total: 1,
678 },
679 metadata: document.metadata.clone(),
680 }]
681 }
682}
683
684use crate::vector_store::{VectorItem, VectorStore};
689
690pub struct DefaultRetrieverEngine<VS, EM, CS>
694where
695 VS: VectorStore,
696 EM: EmbeddingModel,
697 CS: ChunkingStrategy,
698{
699 vector_store: VS,
701 embedding_model: EM,
703 chunking_strategy: CS,
705 doc_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
707 chunk_cache: Arc<RwLock<HashMap<String, String>>>,
709}
710
711impl<VS, EM, CS> DefaultRetrieverEngine<VS, EM, CS>
712where
713 VS: VectorStore,
714 EM: EmbeddingModel,
715 CS: ChunkingStrategy,
716{
717 pub fn new(vector_store: VS, embedding_model: EM, chunking_strategy: CS) -> Self {
719 Self {
720 vector_store,
721 embedding_model,
722 chunking_strategy,
723 doc_index: Arc::new(RwLock::new(HashMap::new())),
724 chunk_cache: Arc::new(RwLock::new(HashMap::new())),
725 }
726 }
727
728 fn extract_keywords(&self, query: &str) -> Vec<String> {
730 let words: Vec<String> = query
731 .to_lowercase()
732 .split_whitespace()
733 .map(|s| s.to_string())
734 .collect();
735
736 let stop_words = std::collections::HashSet::from([
737 "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
738 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
739 "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
740 "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
741 "below", "between", "under", "again", "further", "then", "once", "here", "there",
742 "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
743 "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s",
744 "t", "just", "and", "but", "if", "or", "because", "until", "while", "although",
745 ]);
746
747 words
748 .into_iter()
749 .filter(|w| !stop_words.contains(w.as_str()) && w.len() > 1)
750 .collect()
751 }
752
753 fn compute_keyword_score(
755 &self,
756 query_keywords: &[String],
757 content: &str,
758 config: &HybridSearchConfig,
759 ) -> f32 {
760 if query_keywords.is_empty() {
761 return 0.0;
762 }
763
764 let content_lower = content.to_lowercase();
765
766 let mut phrase_bonus: f32 = 0.0;
768 if config.phrase_matching {
769 for keyword in query_keywords {
770 if content_lower.contains(keyword) {
771 phrase_bonus += 0.1;
772 }
773 }
774 phrase_bonus = phrase_bonus.min(0.3);
775 }
776
777 let matched_keywords = query_keywords
779 .iter()
780 .filter(|kw| content_lower.contains(kw.as_str()))
781 .count();
782
783 let k1 = 1.2;
785 let content_len = content.len() as f32;
786 let avg_len = 500.0;
787 let len_norm = 1.0 - 0.75 + 0.75 * (content_len / avg_len);
788
789 let bm25_score =
790 (matched_keywords as f32 * (k1 + 1.0)) / (matched_keywords as f32 + k1 * len_norm);
791
792 let normalized_score = bm25_score / (query_keywords.len() as f32 + k1);
794 let normalized_score = normalized_score.min(1.0);
795
796 normalized_score + phrase_bonus
797 }
798
799 async fn keyword_only_search(
801 &self,
802 query: &str,
803 candidates: Vec<RetrievalResult>,
804 top_k: usize,
805 config: &HybridSearchConfig,
806 ) -> Layer3Result<Vec<RetrievalResult>> {
807 let query_keywords = self.extract_keywords(query);
808
809 let mut scored_results: Vec<RetrievalResult> = candidates
810 .into_iter()
811 .map(|r| {
812 let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
813 RetrievalResult {
814 doc_id: r.doc_id,
815 content: r.content,
816 score: keyword_score,
817 metadata: r.metadata,
818 source: r.source,
819 }
820 })
821 .collect();
822
823 scored_results.sort_by(|a, b| {
824 b.score
825 .partial_cmp(&a.score)
826 .unwrap_or(std::cmp::Ordering::Equal)
827 });
828
829 scored_results.truncate(top_k);
830 Ok(scored_results)
831 }
832
833 fn apply_rrif(&self, results: Vec<RetrievalResult>, k: f32) -> Vec<RetrievalResult> {
835 if results.is_empty() {
836 return results;
837 }
838
839 results
840 .into_iter()
841 .enumerate()
842 .map(|(idx, mut r)| {
843 let rank = (idx + 1) as f32;
844 let rrif_score = 1.0 / (k + rank);
845 r.score = r.score * 0.5 + rrif_score * 0.5;
846 r
847 })
848 .collect()
849 }
850}
851
852#[async_trait]
853impl<VS, EM, CS> RetrieverEngine for DefaultRetrieverEngine<VS, EM, CS>
854where
855 VS: VectorStore,
856 EM: EmbeddingModel,
857 CS: ChunkingStrategy,
858{
859 async fn index(&self, documents: Vec<Document>) -> Layer3Result<Vec<String>> {
860 let mut doc_ids = Vec::new();
861
862 for doc in documents {
863 let doc_id = doc.id.clone().unwrap_or_else(generate_short_id);
865 let chunks = self.chunking_strategy.chunk(&Document {
866 id: Some(doc_id.clone()),
867 content: doc.content.clone(),
868 metadata: doc.metadata.clone(),
869 source: doc.source.clone(),
870 });
871
872 let chunk_ids: Vec<String> = chunks.iter().map(|c| c.id.clone()).collect();
874
875 let chunk_contents: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
876
877 let embeddings = self.embedding_model.embed_batch(&chunk_contents).await?;
879
880 let vector_items: Vec<VectorItem> = chunks
882 .into_iter()
883 .zip(embeddings)
884 .map(|(chunk, embedding)| {
885 let mut metadata = chunk.metadata.clone();
886 metadata.insert("doc_id".to_string(), serde_json::json!(chunk.doc_id));
887 metadata.insert(
888 "chunk_index".to_string(),
889 serde_json::json!(chunk.position.index),
890 );
891 if let Some(source) = doc.source.clone() {
892 metadata.insert("source".to_string(), serde_json::json!(source));
893 }
894
895 VectorItem {
896 id: chunk.id.clone(),
897 vector: embedding,
898 metadata,
899 content: Some(chunk.content.clone()),
900 }
901 })
902 .collect();
903
904 {
906 let mut cache = self.chunk_cache.write();
907 for item in &vector_items {
908 cache.insert(item.id.clone(), item.content.clone().unwrap_or_default());
909 }
910 }
911
912 self.vector_store.add_batch(vector_items).await?;
914
915 {
917 let mut index = self.doc_index.write();
918 index.insert(doc_id.clone(), chunk_ids);
919 }
920
921 doc_ids.push(doc_id);
922 }
923
924 Ok(doc_ids)
925 }
926
927 async fn retrieve(&self, query: &str, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
928 let query_embedding = self.embedding_model.embed(query).await?;
930
931 let results = self.vector_store.query(query_embedding, top_k).await?;
933
934 let cache = self.chunk_cache.read();
936 let enriched_results: Vec<RetrievalResult> = results
937 .into_iter()
938 .map(|r| {
939 let content = cache.get(&r.doc_id).cloned().unwrap_or(r.content);
940 RetrievalResult {
941 doc_id: r.doc_id,
942 content,
943 score: r.score,
944 metadata: r.metadata,
945 source: r.source,
946 }
947 })
948 .collect();
949
950 Ok(enriched_results)
951 }
952
953 async fn hybrid_retrieve(
954 &self,
955 query: &str,
956 top_k: usize,
957 ) -> Layer3Result<Vec<RetrievalResult>> {
958 self.hybrid_retrieve_with_config(query, top_k, &HybridSearchConfig::default())
959 .await
960 }
961
962 async fn hybrid_retrieve_with_config(
963 &self,
964 query: &str,
965 top_k: usize,
966 config: &HybridSearchConfig,
967 ) -> Layer3Result<Vec<RetrievalResult>> {
968 if config.weights.keyword == 0.0 {
970 return self.retrieve(query, top_k).await;
971 }
972
973 let candidates_count = top_k * config.candidates_multiplier;
975 let vector_results = self.retrieve(query, candidates_count).await?;
976
977 if config.weights.vector == 0.0 {
979 return self
980 .keyword_only_search(query, vector_results, top_k, config)
981 .await;
982 }
983
984 let query_keywords = self.extract_keywords(query);
986
987 let mut scored_results: Vec<RetrievalResult> = vector_results
989 .into_iter()
990 .map(|r| {
991 let keyword_score = self.compute_keyword_score(&query_keywords, &r.content, config);
992
993 let final_score =
995 r.score * config.weights.vector + keyword_score * config.weights.keyword;
996
997 RetrievalResult {
998 doc_id: r.doc_id,
999 content: r.content,
1000 score: final_score,
1001 metadata: r.metadata,
1002 source: r.source,
1003 }
1004 })
1005 .collect();
1006
1007 scored_results.sort_by(|a, b| {
1009 b.score
1010 .partial_cmp(&a.score)
1011 .unwrap_or(std::cmp::Ordering::Equal)
1012 });
1013
1014 if config.use_rrif {
1016 scored_results = self.apply_rrif(scored_results, config.rrif_k);
1017 }
1018
1019 scored_results.truncate(top_k);
1021 Ok(scored_results)
1022 }
1023
1024 async fn delete(&self, doc_ids: &[String]) -> Layer3Result<bool> {
1025 let all_chunk_ids: Vec<String> = {
1027 let mut index = self.doc_index.write();
1028 let mut cache = self.chunk_cache.write();
1029
1030 let mut ids_to_delete: Vec<String> = Vec::new();
1031 for doc_id in doc_ids {
1032 if let Some(chunk_ids) = index.remove(doc_id) {
1033 for chunk_id in &chunk_ids {
1034 cache.remove(chunk_id);
1035 }
1036 ids_to_delete.extend(chunk_ids);
1037 }
1038 }
1039 ids_to_delete
1040 };
1041
1042 if all_chunk_ids.is_empty() {
1043 return Ok(false);
1044 }
1045
1046 self.vector_store.delete_batch(&all_chunk_ids).await?;
1047 Ok(true)
1048 }
1049
1050 async fn clear(&self) -> Layer3Result<bool> {
1051 self.vector_store.clear().await?;
1052 let mut index = self.doc_index.write();
1053 index.clear();
1054 let mut cache = self.chunk_cache.write();
1055 cache.clear();
1056 Ok(true)
1057 }
1058
1059 async fn count(&self) -> Layer3Result<usize> {
1060 let index = self.doc_index.read();
1061 Ok(index.len())
1062 }
1063}
1064
1065pub struct Layer1EmbeddingAdapter {
1069 inner: Box<dyn sh_layer1::EmbeddingModel>,
1070}
1071
1072impl Layer1EmbeddingAdapter {
1073 pub fn new(model: Box<dyn sh_layer1::EmbeddingModel>) -> Self {
1075 Self { inner: model }
1076 }
1077}
1078
1079#[async_trait]
1080impl EmbeddingModel for Layer1EmbeddingAdapter {
1081 async fn embed(&self, text: &str) -> Layer3Result<Vec<f32>> {
1082 self.inner.embed(text).await
1083 }
1084
1085 async fn embed_batch(&self, texts: &[String]) -> Layer3Result<Vec<Vec<f32>>> {
1086 self.inner.embed_batch(texts).await
1087 }
1088
1089 fn dimension(&self) -> usize {
1090 self.inner.dimension()
1091 }
1092
1093 fn model_name(&self) -> &str {
1094 self.inner.model_name()
1095 }
1096}
1097
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_store::{InMemoryVectorStore, VectorStore};

    /// Wraps the layer-1 mock model so it can be used as this module's `EmbeddingModel`.
    fn create_mock_embedding_model(dimension: usize) -> Layer1EmbeddingAdapter {
        let mock = sh_layer1::MockEmbeddingModel::new(dimension);
        Layer1EmbeddingAdapter::new(Box::new(mock))
    }

    /// Engine used by most tests: in-memory store, 128-dimension mock model,
    /// default fixed-size chunker.
    fn build_engine(
    ) -> DefaultRetrieverEngine<impl VectorStore, Layer1EmbeddingAdapter, FixedSizeChunker> {
        DefaultRetrieverEngine::new(
            InMemoryVectorStore::in_memory(),
            create_mock_embedding_model(128),
            FixedSizeChunker::default(),
        )
    }

    #[test]
    fn test_document_builder() {
        let built = Document::new("test content")
            .with_source("test.txt")
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(built.source, Some("test.txt".to_string()));
    }

    #[test]
    fn test_fixed_size_chunker() {
        let document = Document::new("a".repeat(250));
        let pieces = FixedSizeChunker::new(100, 20).chunk(&document);

        assert!(!pieces.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_index() {
        let engine = build_engine();
        let document = Document::new("This is a test document for RAG.").with_source("test.txt");

        let indexed = engine.index(vec![document]).await.unwrap();

        assert_eq!(indexed.len(), 1);
        assert_eq!(engine.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_default_retriever_engine_retrieve() {
        let engine = build_engine();
        engine
            .index(vec![
                Document::new("Rust is a systems programming language."),
                Document::new("Python is great for data science."),
            ])
            .await
            .unwrap();

        let hits = engine.retrieve("Rust programming", 5).await.unwrap();

        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn test_default_retriever_engine_delete() {
        let engine = build_engine();
        let indexed = engine
            .index(vec![Document::new("Test document")])
            .await
            .unwrap();

        let deleted = engine.delete(&indexed).await.unwrap();

        assert!(deleted);
        assert_eq!(engine.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_mock_embedding_model() {
        let model = create_mock_embedding_model(64);

        let single = model.embed("test").await.unwrap();
        assert_eq!(single.len(), 64);
        assert_eq!(model.dimension(), 64);
        assert_eq!(model.model_name(), "mock-embedding");

        let inputs = ["test1".to_string(), "test2".to_string()];
        let batch = model.embed_batch(&inputs).await.unwrap();
        assert_eq!(batch.len(), 2);
    }

    #[tokio::test]
    async fn test_hybrid_retrieve() {
        let engine = build_engine();
        engine
            .index(vec![
                Document::new("Rust is a systems programming language designed for performance."),
                Document::new("Python is widely used for data science and machine learning."),
                Document::new("JavaScript runs in the browser for web development."),
            ])
            .await
            .unwrap();

        let hits = engine
            .hybrid_retrieve("Rust programming language", 5)
            .await
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn test_hybrid_retrieve_with_config() {
        let engine = build_engine();
        engine
            .index(vec![
                Document::new("Machine learning algorithms use neural networks."),
                Document::new("The database stores data for the application."),
            ])
            .await
            .unwrap();

        // Vector similarity only.
        let vector_only = HybridSearchConfig::new().with_weights(HybridWeights::vector_only());
        let hits = engine
            .hybrid_retrieve_with_config("neural networks", 5, &vector_only)
            .await
            .unwrap();
        assert!(!hits.is_empty());

        // Keyword score only.
        let keyword_only = HybridSearchConfig::new().with_weights(HybridWeights::keyword_only());
        let hits = engine
            .hybrid_retrieve_with_config("machine learning", 5, &keyword_only)
            .await
            .unwrap();
        assert!(!hits.is_empty());

        // Equal weights with rank-based re-scoring.
        let balanced = HybridSearchConfig::new()
            .with_weights(HybridWeights::balanced())
            .with_rrif(true, 60.0);
        let hits = engine
            .hybrid_retrieve_with_config("database", 5, &balanced)
            .await
            .unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_hybrid_weights() {
        let defaults = HybridWeights::default_weights();
        assert_eq!(defaults.vector, 0.7);
        assert_eq!(defaults.keyword, 0.3);

        let vector_only = HybridWeights::vector_only();
        assert_eq!(vector_only.vector, 1.0);
        assert_eq!(vector_only.keyword, 0.0);

        let balanced = HybridWeights::balanced();
        assert_eq!(balanced.vector, 0.5);
        assert_eq!(balanced.keyword, 0.5);
    }

    #[test]
    fn test_extract_keywords() {
        let engine = build_engine();

        let keywords = engine.extract_keywords("The Rust programming language");

        for expected in ["rust", "programming", "language"] {
            assert!(keywords.contains(&expected.to_string()));
        }
        assert!(!keywords.contains(&"the".to_string()));
    }

    #[test]
    fn test_bm25_keyword_score() {
        let engine = build_engine();
        let config = HybridSearchConfig::new();
        let keywords = vec!["rust".to_string(), "programming".to_string()];

        let score_high = engine.compute_keyword_score(
            &keywords,
            "Rust programming language for systems",
            &config,
        );
        let score_low =
            engine.compute_keyword_score(&keywords, "Python data science frameworks", &config);

        assert!(score_high > score_low);
    }
}