1use crate::{Document, DocumentChunk, Embedding, RragError, RragResult};
171use async_trait::async_trait;
172use serde::{Deserialize, Serialize};
173use std::collections::HashMap;
174use std::sync::Arc;
175
/// A single hit returned by a [`Retriever`] search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched item (document id or chunk id).
    pub id: String,

    /// Text content of the matched document or chunk.
    pub content: String,

    /// Similarity/relevance score assigned by the search algorithm.
    pub score: f32,

    /// Zero-based position of this result within the ranked result list.
    pub rank: usize,

    /// Metadata copied from the source item plus retrieval extras
    /// (e.g. "type", "document_id", "chunk_index").
    pub metadata: HashMap<String, serde_json::Value>,

    /// The item's embedding; populated only when
    /// `SearchConfig::include_embeddings` is enabled.
    pub embedding: Option<Embedding>,
}
227
228impl SearchResult {
229 pub fn new(id: impl Into<String>, content: impl Into<String>, score: f32, rank: usize) -> Self {
231 Self {
232 id: id.into(),
233 content: content.into(),
234 score,
235 rank,
236 metadata: HashMap::new(),
237 embedding: None,
238 }
239 }
240
241 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
242 self.metadata.insert(key.into(), value);
243 self
244 }
245
246 pub fn with_embedding(mut self, embedding: Embedding) -> Self {
247 self.embedding = Some(embedding);
248 self
249 }
250}
251
/// A search request: the query payload plus limits, filters and tuning.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// The query itself (raw text or a precomputed embedding).
    pub query: QueryType,

    /// Maximum number of results to return.
    pub limit: usize,

    /// Results scoring below this threshold are dropped.
    pub min_score: f32,

    /// Metadata key/value pairs a result must match exactly (AND semantics).
    pub filters: HashMap<String, serde_json::Value>,

    /// Algorithm and scoring configuration for this search.
    pub config: SearchConfig,
}
311
/// The two supported query payloads.
#[derive(Debug, Clone)]
pub enum QueryType {
    /// Raw text; requires a retriever that can embed it.
    Text(String),

    /// A precomputed embedding vector.
    Embedding(Embedding),
}
321
/// Per-search tuning options.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Attach each result's embedding to the returned `SearchResult`.
    pub include_embeddings: bool,

    /// Apply the heuristic reranking pass after the initial similarity sort.
    pub enable_reranking: bool,

    /// Similarity algorithm used to score candidates.
    pub algorithm: SearchAlgorithm,

    /// Weights used by the reranking pass.
    pub scoring_weights: ScoringWeights,
}
337
/// Supported similarity measures.
#[derive(Debug, Clone)]
pub enum SearchAlgorithm {
    /// Cosine similarity between vectors.
    Cosine,

    /// Euclidean distance, converted to a similarity score.
    Euclidean,

    /// Raw dot product (clamped to [0, 1] by the in-memory retriever).
    DotProduct,

    /// Weighted combination of several algorithms.
    Hybrid {
        /// Algorithms to combine, paired positionally with `weights`.
        methods: Vec<SearchAlgorithm>,
        /// One weight per entry in `methods`.
        weights: Vec<f32>,
    },
}
356
/// Weights for the heuristic reranking pass.
#[derive(Debug, Clone)]
pub struct ScoringWeights {
    /// Weight applied to the raw similarity score.
    pub semantic: f32,

    /// Weight of the metadata-presence bonus.
    pub metadata: f32,

    /// Weight of the recency bonus derived from a `created_at` timestamp.
    pub recency: f32,

    /// Weight of the content-length quality bonus.
    pub quality: f32,
}
372
373impl Default for SearchConfig {
374 fn default() -> Self {
375 Self {
376 include_embeddings: false,
377 enable_reranking: true,
378 algorithm: SearchAlgorithm::Cosine,
379 scoring_weights: ScoringWeights::default(),
380 }
381 }
382}
383
384impl Default for ScoringWeights {
385 fn default() -> Self {
386 Self {
387 semantic: 1.0,
388 metadata: 0.1,
389 recency: 0.05,
390 quality: 0.1,
391 }
392 }
393}
394
395impl SearchQuery {
396 pub fn text(query: impl Into<String>) -> Self {
398 Self {
399 query: QueryType::Text(query.into()),
400 limit: 10,
401 min_score: 0.0,
402 filters: HashMap::new(),
403 config: SearchConfig::default(),
404 }
405 }
406
407 pub fn embedding(embedding: Embedding) -> Self {
409 Self {
410 query: QueryType::Embedding(embedding),
411 limit: 10,
412 min_score: 0.0,
413 filters: HashMap::new(),
414 config: SearchConfig::default(),
415 }
416 }
417
418 pub fn with_limit(mut self, limit: usize) -> Self {
420 self.limit = limit;
421 self
422 }
423
424 pub fn with_min_score(mut self, min_score: f32) -> Self {
426 self.min_score = min_score;
427 self
428 }
429
430 pub fn with_filter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
432 self.filters.insert(key.into(), value);
433 self
434 }
435
436 pub fn with_config(mut self, config: SearchConfig) -> Self {
438 self.config = config;
439 self
440 }
441}
442
/// Abstraction over a vector search backend.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Short identifier for this retriever implementation.
    fn name(&self) -> &str;

    /// Executes a search and returns scored, ranked results.
    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>>;

    /// Indexes documents paired with their precomputed embeddings.
    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()>;

    /// Indexes chunks paired with their precomputed embeddings.
    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()>;

    /// Removes documents by id (implementations may also drop related chunks).
    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()>;

    /// Removes everything from the index.
    async fn clear(&self) -> RragResult<()>;

    /// Returns statistics about the current index contents.
    async fn stats(&self) -> RragResult<IndexStats>;

    /// Returns `Ok(true)` when the backend is operational.
    async fn health_check(&self) -> RragResult<bool>;
}
470
/// Snapshot of an index's contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of indexed items (documents plus chunks).
    pub total_items: usize,

    /// Estimated size of the index in bytes.
    pub size_bytes: usize,

    /// Dimensionality of the stored embeddings (0 when the index is empty).
    pub dimensions: usize,

    /// Backend identifier, e.g. "in_memory".
    pub index_type: String,

    /// Last-update timestamp; the in-memory backend reports the time the
    /// statistics were taken.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
489
/// Simple retriever backed by in-process `HashMap`s behind async `RwLock`s.
pub struct InMemoryRetriever {
    /// Documents with their embeddings, keyed by document id.
    documents: Arc<tokio::sync::RwLock<HashMap<String, (Document, Embedding)>>>,

    /// Chunks with their embeddings, keyed by `{document_id}_{chunk_index}`.
    chunks: Arc<tokio::sync::RwLock<HashMap<String, (DocumentChunk, Embedding)>>>,

    /// Storage-mode and result-limit configuration.
    config: RetrieverConfig,
}
501
/// Configuration for [`InMemoryRetriever`].
#[derive(Debug, Clone)]
pub struct RetrieverConfig {
    /// Which stores (documents, chunks, or both) searches should scan.
    pub storage_mode: StorageMode,

    /// Default minimum-score threshold.
    /// NOTE(review): not referenced by the search path visible in this file —
    /// confirm where (or whether) it is consumed.
    pub default_threshold: f32,

    /// Hard cap on the number of results a search may return.
    pub max_results: usize,
}
514
/// Which stores a search should scan.
#[derive(Debug, Clone)]
pub enum StorageMode {
    /// Search whole documents only.
    DocumentsOnly,
    /// Search chunks only.
    ChunksOnly,
    /// Search both documents and chunks.
    Both,
}
521
522impl Default for RetrieverConfig {
523 fn default() -> Self {
524 Self {
525 storage_mode: StorageMode::Both,
526 default_threshold: 0.0,
527 max_results: 1000,
528 }
529 }
530}
531
532impl InMemoryRetriever {
533 pub fn new() -> Self {
535 Self {
536 documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
537 chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
538 config: RetrieverConfig::default(),
539 }
540 }
541
542 pub fn with_config(config: RetrieverConfig) -> Self {
544 Self {
545 documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
546 chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
547 config,
548 }
549 }
550
551 fn calculate_similarity(
553 &self,
554 embedding1: &Embedding,
555 embedding2: &Embedding,
556 algorithm: &SearchAlgorithm,
557 ) -> RragResult<f32> {
558 match algorithm {
559 SearchAlgorithm::Cosine => embedding1.cosine_similarity(embedding2),
560 SearchAlgorithm::Euclidean => {
561 let distance = embedding1.euclidean_distance(embedding2)?;
562 Ok(1.0 / (1.0 + distance))
564 }
565 SearchAlgorithm::DotProduct => {
566 if embedding1.dimensions != embedding2.dimensions {
567 return Err(RragError::retrieval(format!(
568 "Dimension mismatch: {} vs {}",
569 embedding1.dimensions, embedding2.dimensions
570 )));
571 }
572 let dot_product: f32 = embedding1
573 .vector
574 .iter()
575 .zip(embedding2.vector.iter())
576 .map(|(a, b)| a * b)
577 .sum();
578 Ok(dot_product.max(0.0).min(1.0)) }
580 SearchAlgorithm::Hybrid { methods, weights } => {
581 let mut total_score = 0.0;
582 let mut total_weight = 0.0;
583
584 for (method, weight) in methods.iter().zip(weights.iter()) {
585 let score = self.calculate_similarity(embedding1, embedding2, method)?;
586 total_score += score * weight;
587 total_weight += weight;
588 }
589
590 if total_weight > 0.0 {
591 Ok(total_score / total_weight)
592 } else {
593 Ok(0.0)
594 }
595 }
596 }
597 }
598
599 fn apply_filters(
601 &self,
602 metadata: &HashMap<String, serde_json::Value>,
603 filters: &HashMap<String, serde_json::Value>,
604 ) -> bool {
605 for (key, expected_value) in filters {
606 match metadata.get(key) {
607 Some(actual_value) if actual_value == expected_value => continue,
608 _ => return false,
609 }
610 }
611 true
612 }
613
614 fn rerank_results(
616 &self,
617 mut results: Vec<SearchResult>,
618 weights: &ScoringWeights,
619 ) -> Vec<SearchResult> {
620 for result in &mut results {
622 let mut enhanced_score = result.score * weights.semantic;
623
624 if !result.metadata.is_empty() {
626 enhanced_score += 0.1 * weights.metadata;
627 }
628
629 if let Some(timestamp_value) = result.metadata.get("created_at") {
631 if let Some(timestamp_str) = timestamp_value.as_str() {
632 if let Ok(timestamp) = chrono::DateTime::parse_from_rfc3339(timestamp_str) {
633 let age_days =
634 (chrono::Utc::now() - timestamp.with_timezone(&chrono::Utc)).num_days();
635 let recency_bonus = (-age_days as f32 / 30.0).exp() * weights.recency;
636 enhanced_score += recency_bonus;
637 }
638 }
639 }
640
641 let content_length = result.content.len();
643 if content_length > 100 && content_length < 2000 {
644 enhanced_score += 0.05 * weights.quality;
645 }
646
647 result.score = enhanced_score.min(1.0);
648 }
649
650 results.sort_by(|a, b| {
652 b.score
653 .partial_cmp(&a.score)
654 .unwrap_or(std::cmp::Ordering::Equal)
655 });
656
657 for (i, result) in results.iter_mut().enumerate() {
659 result.rank = i;
660 }
661
662 results
663 }
664}
665
666impl Default for InMemoryRetriever {
667 fn default() -> Self {
668 Self::new()
669 }
670}
671
#[async_trait]
impl Retriever for InMemoryRetriever {
    fn name(&self) -> &str {
        "in_memory"
    }

    /// Runs an embedding-based similarity search over the in-memory store.
    ///
    /// Only `QueryType::Embedding` is supported; text queries are rejected
    /// because this retriever has no embedding service to encode them.
    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>> {
        let query_embedding = match &query.query {
            QueryType::Text(_) => {
                return Err(RragError::retrieval(
                    "Text queries require pre-computed embeddings for in-memory retriever"
                        .to_string(),
                ));
            }
            QueryType::Embedding(emb) => emb,
        };

        let mut results = Vec::new();

        // Scan whole documents when the storage mode includes them.
        if matches!(
            self.config.storage_mode,
            StorageMode::DocumentsOnly | StorageMode::Both
        ) {
            let documents = self.documents.read().await;
            for (doc_id, (document, embedding)) in documents.iter() {
                // Filters run before scoring so rejected items cost nothing.
                if !self.apply_filters(&document.metadata, &query.filters) {
                    continue;
                }

                let similarity =
                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;

                if similarity >= query.min_score {
                    // Rank 0 is a placeholder; real ranks are assigned after
                    // the final sort below.
                    let mut result = SearchResult::new(
                        doc_id,
                        document.content_str(),
                        similarity,
                        0, )
                    .with_metadata("type", serde_json::Value::String("document".to_string()));

                    // Copy the document's own metadata onto the result.
                    for (key, value) in &document.metadata {
                        result = result.with_metadata(key, value.clone());
                    }

                    if query.config.include_embeddings {
                        result = result.with_embedding(embedding.clone());
                    }

                    results.push(result);
                }
            }
        }

        // Scan chunks when the storage mode includes them.
        if matches!(
            self.config.storage_mode,
            StorageMode::ChunksOnly | StorageMode::Both
        ) {
            let chunks = self.chunks.read().await;
            for (chunk_id, (chunk, embedding)) in chunks.iter() {
                if !self.apply_filters(&chunk.metadata, &query.filters) {
                    continue;
                }

                let similarity =
                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;

                if similarity >= query.min_score {
                    // Chunk results also carry their parent document id and
                    // position so callers can trace provenance.
                    let mut result = SearchResult::new(
                        chunk_id,
                        &chunk.content,
                        similarity,
                        0, )
                    .with_metadata("type", serde_json::Value::String("chunk".to_string()))
                    .with_metadata(
                        "document_id",
                        serde_json::Value::String(chunk.document_id.clone()),
                    )
                    .with_metadata(
                        "chunk_index",
                        serde_json::Value::Number(chunk.chunk_index.into()),
                    );

                    for (key, value) in &chunk.metadata {
                        result = result.with_metadata(key, value.clone());
                    }

                    if query.config.include_embeddings {
                        result = result.with_embedding(embedding.clone());
                    }

                    results.push(result);
                }
            }
        }

        // Descending sort by raw similarity; NaN scores compare as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // NOTE(review): rerank_results sorts and assigns ranks itself, making
        // the sort above and the rank loop below partially redundant when
        // reranking is enabled. Harmless, but worth simplifying.
        if query.config.enable_reranking {
            results = self.rerank_results(results, &query.config.scoring_weights);
        }

        for (i, result) in results.iter_mut().enumerate() {
            result.rank = i;
        }

        // Truncate only after ranking, capped by the retriever-wide limit.
        results.truncate(query.limit.min(self.config.max_results));

        Ok(results)
    }

    /// Inserts (or overwrites) documents keyed by their id.
    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()> {
        let mut docs = self.documents.write().await;
        for (document, embedding) in documents {
            docs.insert(document.id.clone(), (document.clone(), embedding.clone()));
        }
        Ok(())
    }

    /// Inserts (or overwrites) chunks keyed by `{document_id}_{chunk_index}`.
    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()> {
        let mut chunk_store = self.chunks.write().await;
        for (chunk, embedding) in chunks {
            let chunk_id = format!("{}_{}", chunk.document_id, chunk.chunk_index);
            chunk_store.insert(chunk_id, (chunk.clone(), embedding.clone()));
        }
        Ok(())
    }

    /// Removes the given documents and every chunk belonging to them.
    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()> {
        let mut docs = self.documents.write().await;
        for doc_id in document_ids {
            docs.remove(doc_id);
        }

        // Collect matching chunk ids first to avoid mutating the map while
        // iterating over it.
        let mut chunk_store = self.chunks.write().await;
        let chunk_ids_to_remove: Vec<String> = chunk_store
            .iter()
            .filter(|(_, (chunk, _))| document_ids.contains(&chunk.document_id))
            .map(|(id, _)| id.clone())
            .collect();

        for chunk_id in chunk_ids_to_remove {
            chunk_store.remove(&chunk_id);
        }

        Ok(())
    }

    /// Empties both the document and chunk stores.
    async fn clear(&self) -> RragResult<()> {
        self.documents.write().await.clear();
        self.chunks.write().await.clear();
        Ok(())
    }

    /// Reports item counts, an estimated byte size, and the embedding
    /// dimensionality taken from the first stored item.
    ///
    /// NOTE(review): counts and dimensions are read under separate,
    /// short-lived lock acquisitions, so the snapshot can be mutually
    /// inconsistent under concurrent writes — confirm whether callers care.
    async fn stats(&self) -> RragResult<IndexStats> {
        let doc_count = self.documents.read().await.len();
        let chunk_count = self.chunks.read().await.len();

        // Dimensionality comes from whichever store has data first.
        let dimensions = if doc_count > 0 {
            self.documents
                .read()
                .await
                .values()
                .next()
                .map(|(_, emb)| emb.dimensions)
                .unwrap_or(0)
        } else if chunk_count > 0 {
            self.chunks
                .read()
                .await
                .values()
                .next()
                .map(|(_, emb)| emb.dimensions)
                .unwrap_or(0)
        } else {
            0
        };

        Ok(IndexStats {
            total_items: doc_count + chunk_count,
            // Rough estimate: 4 bytes per f32 vector component.
            size_bytes: (doc_count + chunk_count) * dimensions * 4, dimensions,
            index_type: "in_memory".to_string(),
            last_updated: chrono::Utc::now(),
        })
    }

    /// The in-memory store has no external dependencies, so it is always
    /// healthy.
    async fn health_check(&self) -> RragResult<bool> {
        Ok(true)
    }
}
880
/// High-level facade that pairs a [`Retriever`] with service-level settings.
pub struct RetrievalService {
    /// The underlying search backend.
    retriever: Arc<dyn Retriever>,

    /// Service-level configuration.
    config: RetrievalServiceConfig,
}
889
/// Configuration for [`RetrievalService`].
#[derive(Debug, Clone)]
pub struct RetrievalServiceConfig {
    /// Search configuration applied when callers do not supply one.
    pub default_search_config: SearchConfig,

    /// Whether result caching is enabled.
    /// NOTE(review): no caching logic is visible in this file — confirm where
    /// this flag and the TTL below are consumed.
    pub enable_caching: bool,

    /// Cache entry time-to-live in seconds.
    pub cache_ttl_seconds: u64,
}
902
903impl Default for RetrievalServiceConfig {
904 fn default() -> Self {
905 Self {
906 default_search_config: SearchConfig::default(),
907 enable_caching: false,
908 cache_ttl_seconds: 300, }
910 }
911}
912
913impl RetrievalService {
914 pub fn new(retriever: Arc<dyn Retriever>) -> Self {
916 Self {
917 retriever,
918 config: RetrievalServiceConfig::default(),
919 }
920 }
921
922 pub fn with_config(retriever: Arc<dyn Retriever>, config: RetrievalServiceConfig) -> Self {
924 Self { retriever, config }
925 }
926
927 pub async fn search_text(
929 &self,
930 _query: &str,
931 _limit: Option<usize>,
932 ) -> RragResult<Vec<SearchResult>> {
933 Err(RragError::retrieval(
936 "Text search requires embedding service integration".to_string(),
937 ))
938 }
939
940 pub async fn search_embedding(
942 &self,
943 embedding: Embedding,
944 limit: Option<usize>,
945 ) -> RragResult<Vec<SearchResult>> {
946 let query = SearchQuery::embedding(embedding)
947 .with_limit(limit.unwrap_or(10))
948 .with_config(self.config.default_search_config.clone());
949
950 self.retriever.search(&query).await
951 }
952
953 pub async fn search(&self, query: SearchQuery) -> RragResult<Vec<SearchResult>> {
955 self.retriever.search(&query).await
956 }
957
958 pub async fn index_documents(
960 &self,
961 documents_with_embeddings: &[(Document, Embedding)],
962 ) -> RragResult<()> {
963 self.retriever
964 .add_documents(documents_with_embeddings)
965 .await
966 }
967
968 pub async fn index_chunks(
970 &self,
971 chunks_with_embeddings: &[(DocumentChunk, Embedding)],
972 ) -> RragResult<()> {
973 self.retriever.add_chunks(chunks_with_embeddings).await
974 }
975
976 pub async fn get_stats(&self) -> RragResult<IndexStats> {
978 self.retriever.stats().await
979 }
980
981 pub async fn health_check(&self) -> RragResult<bool> {
983 self.retriever.health_check().await
984 }
985}
986
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Indexes two orthogonal-embedding documents and verifies the one
    /// closest to the query vector ranks first.
    #[tokio::test]
    async fn test_in_memory_retriever() {
        let retriever = InMemoryRetriever::new();

        let doc1 = Document::new("First test document");
        let emb1 = Embedding::new(vec![1.0, 0.0, 0.0], "test-model", &doc1.id);

        let doc2 = Document::new("Second test document");
        let emb2 = Embedding::new(vec![0.0, 1.0, 0.0], "test-model", &doc2.id);

        retriever
            .add_documents(&[(doc1.clone(), emb1.clone()), (doc2, emb2)])
            .await
            .unwrap();

        // Query vector leans strongly toward doc1's embedding.
        let query_embedding = Embedding::new(vec![0.8, 0.2, 0.0], "test-model", "query");
        let query = SearchQuery::embedding(query_embedding).with_limit(5);

        let results = retriever.search(&query).await.unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].id, doc1.id); }

    /// Verifies metadata filters exclude non-matching documents even when
    /// they would otherwise score well.
    #[tokio::test]
    async fn test_search_filters() {
        let retriever = InMemoryRetriever::new();

        let doc1 = Document::new("Test document")
            .with_metadata("category", serde_json::Value::String("tech".to_string()));
        let emb1 = Embedding::new(vec![1.0, 0.0], "test-model", &doc1.id);

        let doc2 = Document::new("Another document")
            .with_metadata("category", serde_json::Value::String("science".to_string()));
        let emb2 = Embedding::new(vec![0.9, 0.1], "test-model", &doc2.id);

        retriever
            .add_documents(&[(doc1.clone(), emb1), (doc2, emb2)])
            .await
            .unwrap();

        // Only the "tech" document should survive the filter.
        let query_embedding = Embedding::new(vec![1.0, 0.0], "test-model", "query");
        let query = SearchQuery::embedding(query_embedding)
            .with_filter("category", serde_json::Value::String("tech".to_string()));

        let results = retriever.search(&query).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, doc1.id);
    }

    /// Checks that the builder methods record limit, threshold and filters.
    #[test]
    fn test_search_query_builder() {
        let query = SearchQuery::text("test query")
            .with_limit(20)
            .with_min_score(0.5)
            .with_filter("type", serde_json::Value::String("article".to_string()));

        assert_eq!(query.limit, 20);
        assert_eq!(query.min_score, 0.5);
        assert_eq!(query.filters.len(), 1);
    }

    /// Smoke-tests the service facade: empty stats and a passing health check.
    #[tokio::test]
    async fn test_retrieval_service() {
        let retriever = Arc::new(InMemoryRetriever::new());
        let service = RetrievalService::new(retriever);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_items, 0);

        assert!(service.health_check().await.unwrap());
    }
}