// rexis_rag/retrieval_core.rs

1//! # RRAG Retrieval System
2//!
3//! High-performance, async-first retrieval system with pluggable similarity search,
4//! advanced ranking algorithms, and comprehensive filtering capabilities. Built for
5//! production workloads with sub-millisecond response times and horizontal scaling.
6//!
7//! ## Features
8//!
9//! - **Multiple Search Algorithms**: Cosine similarity, dot product, Euclidean distance
10//! - **Advanced Filtering**: Metadata-based filtering with complex queries
11//! - **Ranking & Scoring**: Configurable scoring and re-ranking strategies
12//! - **Async Operations**: Full async/await support for high concurrency
13//! - **Memory Efficient**: Optimized data structures and minimal allocations
14//! - **Pluggable Backends**: Support for multiple storage backends
15//! - **Real-time Updates**: Live index updates without downtime
16//!
17//! ## Quick Start
18//!
19//! ### Basic Similarity Search
20//!
21//! ```rust
22//! use rrag::prelude::*;
23//! use std::sync::Arc;
24//!
25//! # #[tokio::main]
26//! # async fn main() -> RragResult<()> {
27//! // Create a retrieval service
28//! let storage = Arc::new(InMemoryStorage::new());
29//! let retriever = InMemoryRetriever::new()
30//!     .with_storage(storage)
31//!     .with_similarity_threshold(0.8);
32//!
33//! // Add documents to the index
34//! let documents = vec![
35//!     Document::new("Rust is a systems programming language"),
36//!     Document::new("Python is great for data science"),
37//!     Document::new("JavaScript runs in web browsers"),
38//! ];
39//!
40//! for doc in documents {
41//!     retriever.index_document(&doc).await?;
42//! }
43//!
44//! // Search for similar content
45//! let query = SearchQuery::new("programming languages")
46//!     .with_limit(5)
47//!     .with_min_score(0.7);
48//!
49//! let results = retriever.search(query).await?;
50//! for result in results {
51//!     tracing::debug!("Score: {:.3} - {}", result.score, result.content);
52//! }
53//! # Ok(())
54//! # }
55//! ```
56//!
57//! ### Advanced Search with Filters
58//!
59//! ```rust
60//! use rrag::prelude::*;
61//!
62//! # #[tokio::main]
63//! # async fn main() -> RragResult<()> {
64//! # let retriever = InMemoryRetriever::new();
65//! // Search with metadata filters
66//! let query = SearchQuery::new("machine learning")
67//!     .with_filter("category", "technical".into())
68//!     .with_filter("language", "english".into())
69//!     .with_date_range("created_after", "2023-01-01")
70//!     .with_config(SearchConfig {
71//!         algorithm: SearchAlgorithm::CosineSimilarity,
72//!         enable_reranking: true,
73//!         include_embeddings: false,
74//!         ..Default::default()
75//!     });
76//!
77//! let results = retriever.search(query).await?;
78//! tracing::debug!("Found {} filtered results", results.len());
79//! # Ok(())
80//! # }
81//! ```
82//!
83//! ### Custom Retrieval Implementation
84//!
85//! ```rust
86//! use rrag::prelude::*;
87//! use async_trait::async_trait;
88//!
89//! struct CustomRetriever {
90//!     // Your custom fields
91//! }
92//!
93//! #[async_trait]
94//! impl Retriever for CustomRetriever {
95//!     async fn search(&self, query: SearchQuery) -> RragResult<Vec<SearchResult>> {
96//!         // Your custom search logic
97//!         # Ok(Vec::new())
98//!     }
99//!
100//!     async fn index_document(&self, document: &Document) -> RragResult<()> {
101//!         // Your custom indexing logic
102//!         Ok(())
103//!     }
104//!
105//!     async fn delete_document(&self, id: &str) -> RragResult<bool> {
106//!         // Your custom deletion logic
107//!         Ok(true)
108//!     }
109//! }
110//! ```
111//!
112//! ## Search Algorithms
113//!
114//! RRAG supports multiple similarity algorithms:
115//!
//! ```rust
//! use rrag::prelude::*;
//!
//! // Cosine similarity (default, best for most use cases)
//! let config = SearchConfig {
//!     algorithm: SearchAlgorithm::Cosine,
//!     ..Default::default()
//! };
//!
//! // Dot product (faster, good for normalized embeddings)
//! let config = SearchConfig {
//!     algorithm: SearchAlgorithm::DotProduct,
//!     ..Default::default()
//! };
//!
//! // Euclidean distance (good for spatial data)
//! let config = SearchConfig {
//!     algorithm: SearchAlgorithm::Euclidean,
//!     ..Default::default()
//! };
//! ```
137//!
138//! ## Performance Optimization
139//!
140//! - **Batch Operations**: Index multiple documents at once
141//! - **Parallel Search**: Concurrent query processing
142//! - **Memory Optimization**: Efficient vector storage and computation
143//! - **Caching**: Optional result caching for repeated queries
144//! - **Lazy Loading**: Load embeddings on demand
145//!
146//! ## Error Handling
147//!
148//! ```rust
149//! use rrag::prelude::*;
150//!
151//! # #[tokio::main]
152//! # async fn main() {
153//! match retriever.search(query).await {
154//!     Ok(results) => {
155//!         tracing::debug!("Found {} results", results.len());
156//!         for result in results {
157//!             tracing::debug!("  {}: {:.3}", result.content, result.score);
158//!         }
159//!     }
160//!     Err(RragError::Retrieval { query, .. }) => {
161//!         tracing::debug!("Search failed for query: {}", query);
162//!     }
163//!     Err(e) => {
164//!         tracing::debug!("Retrieval error: {}", e);
165//!     }
166//! }
167//! # }
168//! ```
169
170use crate::{Document, DocumentChunk, Embedding, RragError, RragResult};
171use async_trait::async_trait;
172use serde::{Deserialize, Serialize};
173use std::collections::HashMap;
174use std::sync::Arc;
175
/// A search result containing content, similarity score, and metadata
///
/// Represents a single result from a similarity search operation, including
/// the matched content, relevance score, ranking information, and associated
/// metadata. Results are typically returned in descending order of relevance.
///
/// # Example
///
/// ```rust,ignore
/// use rrag::prelude::*;
///
/// let result = SearchResult::new(
///     "doc-123",
///     "This document discusses machine learning algorithms",
///     0.87, // 87% similarity
///     0     // First result
/// )
/// .with_metadata("category", "technical".into())
/// .with_metadata("author", "Dr. Smith".into())
/// .with_embedding(embedding); // Optional: attach the matched embedding
///
/// tracing::debug!("Result: {} (score: {:.3})", result.content, result.score);
/// ```
///
/// # Scoring
///
/// Scores are normalized to 0.0-1.0 range where:
/// - 1.0 = Perfect match (identical content)
/// - 0.8+ = Very relevant
/// - 0.6-0.8 = Somewhat relevant
/// - <0.6 = Low relevance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Document or chunk ID
    pub id: String,

    /// Content that matched the query
    pub content: String,

    /// Similarity score (0.0 to 1.0, higher is more similar)
    pub score: f32,

    /// Ranking position in results (0-indexed, assigned after sorting)
    pub rank: usize,

    /// Associated metadata (copied from the matched document/chunk, plus
    /// retriever-added keys such as "type")
    pub metadata: HashMap<String, serde_json::Value>,

    /// Embedding used for the match (populated only when the search
    /// config requests `include_embeddings`)
    pub embedding: Option<Embedding>,
}
227
228impl SearchResult {
229    /// Create a new search result with the specified parameters
230    pub fn new(id: impl Into<String>, content: impl Into<String>, score: f32, rank: usize) -> Self {
231        Self {
232            id: id.into(),
233            content: content.into(),
234            score,
235            rank,
236            metadata: HashMap::new(),
237            embedding: None,
238        }
239    }
240
241    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
242        self.metadata.insert(key.into(), value);
243        self
244    }
245
246    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
247        self.embedding = Some(embedding);
248        self
249    }
250}
251
252/// A search query with comprehensive configuration options
253///
254/// Encapsulates all parameters for a search operation including the query itself,
255/// result limits, filtering criteria, and algorithm configuration. Supports both
256/// text queries (that will be embedded) and pre-computed embedding queries.
257///
258/// # Example
259///
/// ```rust
/// use rrag::prelude::*;
///
/// // Simple text query
/// let query = SearchQuery::text("machine learning algorithms")
///     .with_limit(10)
///     .with_min_score(0.7);
///
/// // Advanced query with filters
/// let advanced_query = SearchQuery::text("neural networks")
///     .with_limit(20)
///     .with_min_score(0.6)
///     .with_filter("category", "research".into())
///     .with_filter("year", 2023.into())
///     .with_config(SearchConfig {
///         algorithm: SearchAlgorithm::Cosine,
///         enable_reranking: true,
///         include_embeddings: true,
///         ..Default::default()
///     });
///
/// // Query with pre-computed embedding
/// let embedding_query = SearchQuery::embedding(embedding)
///     .with_limit(5);
/// ```
285///
286/// # Filter Types
287///
288/// Filters support various data types:
289/// - **Strings**: Exact match or pattern matching
290/// - **Numbers**: Range queries and exact values
291/// - **Dates**: Date range filtering
292/// - **Arrays**: "Contains" operations
293/// - **Booleans**: Exact boolean matching
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// Query text or pre-computed embedding
    pub query: QueryType,

    /// Maximum number of results to return (constructors default to 10)
    pub limit: usize,

    /// Minimum similarity threshold; results scoring below this are
    /// dropped (constructors default to 0.0, i.e. keep everything)
    pub min_score: f32,

    /// Metadata filters; every entry must match exactly for a result to pass
    pub filters: HashMap<String, serde_json::Value>,

    /// Search configuration (algorithm, re-ranking, embedding inclusion)
    pub config: SearchConfig,
}
311
/// Query type - text or pre-computed embedding
///
/// Note: `InMemoryRetriever` only accepts `Embedding` queries; `Text`
/// queries are rejected because it has no embedding service to call.
#[derive(Debug, Clone)]
pub enum QueryType {
    /// Text query that needs to be embedded
    Text(String),

    /// Pre-computed embedding vector
    Embedding(Embedding),
}
321
/// Search configuration
///
/// Controls how the retriever scores and post-processes results. Use
/// `SearchConfig::default()` for cosine search with re-ranking enabled.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Whether to include embeddings in results (increases payload size)
    pub include_embeddings: bool,

    /// Whether to apply re-ranking with `scoring_weights` after the
    /// initial similarity sort
    pub enable_reranking: bool,

    /// Search algorithm to use
    pub algorithm: SearchAlgorithm,

    /// Custom scoring weights, consulted only when re-ranking is enabled
    pub scoring_weights: ScoringWeights,
}
337
/// Search algorithms available
///
/// All algorithms produce a similarity score where higher is better;
/// Euclidean distance is converted via `1 / (1 + distance)`.
#[derive(Debug, Clone)]
pub enum SearchAlgorithm {
    /// Cosine similarity search (the default)
    Cosine,

    /// Euclidean distance search, converted to a 0-1 similarity
    Euclidean,

    /// Dot product search; the raw product is clamped to [0, 1]
    DotProduct,

    /// Hybrid search: a weighted average of the component methods' scores
    Hybrid {
        /// Component algorithms to combine
        methods: Vec<SearchAlgorithm>,
        /// Per-method weights, paired positionally with `methods`
        weights: Vec<f32>,
    },
}
356
/// Scoring weights for different factors
///
/// Used during re-ranking: the semantic weight scales the base similarity
/// score, while the remaining weights scale small additive bonuses.
#[derive(Debug, Clone)]
pub struct ScoringWeights {
    /// Weight for semantic similarity (multiplies the raw score)
    pub semantic: f32,

    /// Weight for metadata matches
    pub metadata: f32,

    /// Weight for recency, applied when a `created_at` RFC 3339 timestamp
    /// is present in the result metadata
    pub recency: f32,

    /// Weight for content length/quality
    pub quality: f32,
}
372
373impl Default for SearchConfig {
374    fn default() -> Self {
375        Self {
376            include_embeddings: false,
377            enable_reranking: true,
378            algorithm: SearchAlgorithm::Cosine,
379            scoring_weights: ScoringWeights::default(),
380        }
381    }
382}
383
384impl Default for ScoringWeights {
385    fn default() -> Self {
386        Self {
387            semantic: 1.0,
388            metadata: 0.1,
389            recency: 0.05,
390            quality: 0.1,
391        }
392    }
393}
394
395impl SearchQuery {
396    /// Create a text-based search query
397    pub fn text(query: impl Into<String>) -> Self {
398        Self {
399            query: QueryType::Text(query.into()),
400            limit: 10,
401            min_score: 0.0,
402            filters: HashMap::new(),
403            config: SearchConfig::default(),
404        }
405    }
406
407    /// Create an embedding-based search query
408    pub fn embedding(embedding: Embedding) -> Self {
409        Self {
410            query: QueryType::Embedding(embedding),
411            limit: 10,
412            min_score: 0.0,
413            filters: HashMap::new(),
414            config: SearchConfig::default(),
415        }
416    }
417
418    /// Set result limit
419    pub fn with_limit(mut self, limit: usize) -> Self {
420        self.limit = limit;
421        self
422    }
423
424    /// Set minimum score threshold
425    pub fn with_min_score(mut self, min_score: f32) -> Self {
426        self.min_score = min_score;
427        self
428    }
429
430    /// Add metadata filter
431    pub fn with_filter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
432        self.filters.insert(key.into(), value);
433        self
434    }
435
436    /// Set search configuration
437    pub fn with_config(mut self, config: SearchConfig) -> Self {
438        self.config = config;
439        self
440    }
441}
442
/// Core retrieval trait for different storage backends
///
/// Implementations must be `Send + Sync` so a single retriever can be
/// shared across async tasks (e.g. behind an `Arc<dyn Retriever>`).
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Retriever name/type (e.g. "in_memory"), used for identification
    fn name(&self) -> &str;

    /// Search for similar documents/chunks; results are expected in
    /// descending score order
    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>>;

    /// Add documents (paired with their embeddings) to the retrieval index
    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()>;

    /// Add document chunks (paired with their embeddings) to the retrieval index
    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()>;

    /// Remove documents from the index by ID
    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()>;

    /// Clear all documents from the index
    async fn clear(&self) -> RragResult<()>;

    /// Get index statistics
    async fn stats(&self) -> RragResult<IndexStats>;

    /// Health check; returns `Ok(true)` when the backend is usable
    async fn health_check(&self) -> RragResult<bool>;
}
470
/// Index statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of documents/chunks indexed
    pub total_items: usize,

    /// Index size in bytes (rough estimate: items * dimensions * 4)
    pub size_bytes: usize,

    /// Number of embedding dimensions (0 when the index is empty)
    pub dimensions: usize,

    /// Index type/implementation (e.g. "in_memory")
    pub index_type: String,

    /// Last update timestamp
    /// NOTE(review): `InMemoryRetriever::stats` fills this with the time
    /// the stats were generated, not the last write time — confirm intent.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
489
/// In-memory retriever for small datasets and testing
///
/// Stores documents and chunks (with their embeddings) in `HashMap`s
/// behind Tokio `RwLock`s, so reads can proceed concurrently while
/// writes take exclusive access. Search is a brute-force scan over all
/// stored items, so this is intended for small datasets and tests.
pub struct InMemoryRetriever {
    /// Stored documents with embeddings, keyed by document ID
    documents: Arc<tokio::sync::RwLock<HashMap<String, (Document, Embedding)>>>,

    /// Stored chunks with embeddings, keyed by "{document_id}_{chunk_index}"
    chunks: Arc<tokio::sync::RwLock<HashMap<String, (DocumentChunk, Embedding)>>>,

    /// Retriever configuration
    config: RetrieverConfig,
}
501
/// Retriever configuration
#[derive(Debug, Clone)]
pub struct RetrieverConfig {
    /// Whether to store documents, chunks, or both (default: both)
    pub storage_mode: StorageMode,

    /// Default similarity threshold (default: 0.0)
    /// NOTE(review): not currently consulted by `InMemoryRetriever::search`,
    /// which uses `SearchQuery::min_score` instead — confirm whether this
    /// should be wired in as a fallback.
    pub default_threshold: f32,

    /// Hard cap on the number of results returned (default: 1000)
    pub max_results: usize,
}
514
/// Which collections the in-memory retriever consults during a search.
#[derive(Debug, Clone)]
pub enum StorageMode {
    /// Search whole documents only
    DocumentsOnly,
    /// Search document chunks only
    ChunksOnly,
    /// Search both documents and chunks (the default)
    Both,
}
521
522impl Default for RetrieverConfig {
523    fn default() -> Self {
524        Self {
525            storage_mode: StorageMode::Both,
526            default_threshold: 0.0,
527            max_results: 1000,
528        }
529    }
530}
531
532impl InMemoryRetriever {
533    /// Create new in-memory retriever
534    pub fn new() -> Self {
535        Self {
536            documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
537            chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
538            config: RetrieverConfig::default(),
539        }
540    }
541
542    /// Create with custom configuration
543    pub fn with_config(config: RetrieverConfig) -> Self {
544        Self {
545            documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
546            chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
547            config,
548        }
549    }
550
551    /// Calculate similarity between embeddings
552    fn calculate_similarity(
553        &self,
554        embedding1: &Embedding,
555        embedding2: &Embedding,
556        algorithm: &SearchAlgorithm,
557    ) -> RragResult<f32> {
558        match algorithm {
559            SearchAlgorithm::Cosine => embedding1.cosine_similarity(embedding2),
560            SearchAlgorithm::Euclidean => {
561                let distance = embedding1.euclidean_distance(embedding2)?;
562                // Convert distance to similarity score (0-1)
563                Ok(1.0 / (1.0 + distance))
564            }
565            SearchAlgorithm::DotProduct => {
566                if embedding1.dimensions != embedding2.dimensions {
567                    return Err(RragError::retrieval(format!(
568                        "Dimension mismatch: {} vs {}",
569                        embedding1.dimensions, embedding2.dimensions
570                    )));
571                }
572                let dot_product: f32 = embedding1
573                    .vector
574                    .iter()
575                    .zip(embedding2.vector.iter())
576                    .map(|(a, b)| a * b)
577                    .sum();
578                Ok(dot_product.max(0.0).min(1.0)) // Clamp to [0, 1]
579            }
580            SearchAlgorithm::Hybrid { methods, weights } => {
581                let mut total_score = 0.0;
582                let mut total_weight = 0.0;
583
584                for (method, weight) in methods.iter().zip(weights.iter()) {
585                    let score = self.calculate_similarity(embedding1, embedding2, method)?;
586                    total_score += score * weight;
587                    total_weight += weight;
588                }
589
590                if total_weight > 0.0 {
591                    Ok(total_score / total_weight)
592                } else {
593                    Ok(0.0)
594                }
595            }
596        }
597    }
598
599    /// Apply metadata filters to a result
600    fn apply_filters(
601        &self,
602        metadata: &HashMap<String, serde_json::Value>,
603        filters: &HashMap<String, serde_json::Value>,
604    ) -> bool {
605        for (key, expected_value) in filters {
606            match metadata.get(key) {
607                Some(actual_value) if actual_value == expected_value => continue,
608                _ => return false,
609            }
610        }
611        true
612    }
613
614    /// Apply re-ranking with custom scoring
615    fn rerank_results(
616        &self,
617        mut results: Vec<SearchResult>,
618        weights: &ScoringWeights,
619    ) -> Vec<SearchResult> {
620        // Calculate enhanced scores
621        for result in &mut results {
622            let mut enhanced_score = result.score * weights.semantic;
623
624            // Add metadata matching bonus
625            if !result.metadata.is_empty() {
626                enhanced_score += 0.1 * weights.metadata;
627            }
628
629            // Add recency bonus if timestamp is available
630            if let Some(timestamp_value) = result.metadata.get("created_at") {
631                if let Some(timestamp_str) = timestamp_value.as_str() {
632                    if let Ok(timestamp) = chrono::DateTime::parse_from_rfc3339(timestamp_str) {
633                        let age_days =
634                            (chrono::Utc::now() - timestamp.with_timezone(&chrono::Utc)).num_days();
635                        let recency_bonus = (-age_days as f32 / 30.0).exp() * weights.recency;
636                        enhanced_score += recency_bonus;
637                    }
638                }
639            }
640
641            // Add quality bonus based on content length
642            let content_length = result.content.len();
643            if content_length > 100 && content_length < 2000 {
644                enhanced_score += 0.05 * weights.quality;
645            }
646
647            result.score = enhanced_score.min(1.0);
648        }
649
650        // Re-sort by enhanced scores
651        results.sort_by(|a, b| {
652            b.score
653                .partial_cmp(&a.score)
654                .unwrap_or(std::cmp::Ordering::Equal)
655        });
656
657        // Update ranks
658        for (i, result) in results.iter_mut().enumerate() {
659            result.rank = i;
660        }
661
662        results
663    }
664}
665
impl Default for InMemoryRetriever {
    /// Equivalent to [`InMemoryRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
671
672#[async_trait]
673impl Retriever for InMemoryRetriever {
674    fn name(&self) -> &str {
675        "in_memory"
676    }
677
678    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>> {
679        let query_embedding = match &query.query {
680            QueryType::Text(_) => {
681                return Err(RragError::retrieval(
682                    "Text queries require pre-computed embeddings for in-memory retriever"
683                        .to_string(),
684                ));
685            }
686            QueryType::Embedding(emb) => emb,
687        };
688
689        let mut results = Vec::new();
690
691        // Search documents if enabled
692        if matches!(
693            self.config.storage_mode,
694            StorageMode::DocumentsOnly | StorageMode::Both
695        ) {
696            let documents = self.documents.read().await;
697            for (doc_id, (document, embedding)) in documents.iter() {
698                // Apply metadata filters
699                if !self.apply_filters(&document.metadata, &query.filters) {
700                    continue;
701                }
702
703                let similarity =
704                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;
705
706                if similarity >= query.min_score {
707                    let mut result = SearchResult::new(
708                        doc_id,
709                        document.content_str(),
710                        similarity,
711                        0, // Will be updated after sorting
712                    )
713                    .with_metadata("type", serde_json::Value::String("document".to_string()));
714
715                    // Add document metadata
716                    for (key, value) in &document.metadata {
717                        result = result.with_metadata(key, value.clone());
718                    }
719
720                    if query.config.include_embeddings {
721                        result = result.with_embedding(embedding.clone());
722                    }
723
724                    results.push(result);
725                }
726            }
727        }
728
729        // Search chunks if enabled
730        if matches!(
731            self.config.storage_mode,
732            StorageMode::ChunksOnly | StorageMode::Both
733        ) {
734            let chunks = self.chunks.read().await;
735            for (chunk_id, (chunk, embedding)) in chunks.iter() {
736                // Apply metadata filters
737                if !self.apply_filters(&chunk.metadata, &query.filters) {
738                    continue;
739                }
740
741                let similarity =
742                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;
743
744                if similarity >= query.min_score {
745                    let mut result = SearchResult::new(
746                        chunk_id,
747                        &chunk.content,
748                        similarity,
749                        0, // Will be updated after sorting
750                    )
751                    .with_metadata("type", serde_json::Value::String("chunk".to_string()))
752                    .with_metadata(
753                        "document_id",
754                        serde_json::Value::String(chunk.document_id.clone()),
755                    )
756                    .with_metadata(
757                        "chunk_index",
758                        serde_json::Value::Number(chunk.chunk_index.into()),
759                    );
760
761                    // Add chunk metadata
762                    for (key, value) in &chunk.metadata {
763                        result = result.with_metadata(key, value.clone());
764                    }
765
766                    if query.config.include_embeddings {
767                        result = result.with_embedding(embedding.clone());
768                    }
769
770                    results.push(result);
771                }
772            }
773        }
774
775        // Sort by similarity score (descending)
776        results.sort_by(|a, b| {
777            b.score
778                .partial_cmp(&a.score)
779                .unwrap_or(std::cmp::Ordering::Equal)
780        });
781
782        // Apply re-ranking if enabled
783        if query.config.enable_reranking {
784            results = self.rerank_results(results, &query.config.scoring_weights);
785        }
786
787        // Update ranks after sorting
788        for (i, result) in results.iter_mut().enumerate() {
789            result.rank = i;
790        }
791
792        // Limit results
793        results.truncate(query.limit.min(self.config.max_results));
794
795        Ok(results)
796    }
797
798    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()> {
799        let mut docs = self.documents.write().await;
800        for (document, embedding) in documents {
801            docs.insert(document.id.clone(), (document.clone(), embedding.clone()));
802        }
803        Ok(())
804    }
805
806    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()> {
807        let mut chunk_store = self.chunks.write().await;
808        for (chunk, embedding) in chunks {
809            let chunk_id = format!("{}_{}", chunk.document_id, chunk.chunk_index);
810            chunk_store.insert(chunk_id, (chunk.clone(), embedding.clone()));
811        }
812        Ok(())
813    }
814
815    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()> {
816        let mut docs = self.documents.write().await;
817        for doc_id in document_ids {
818            docs.remove(doc_id);
819        }
820
821        // Also remove associated chunks
822        let mut chunk_store = self.chunks.write().await;
823        let chunk_ids_to_remove: Vec<String> = chunk_store
824            .iter()
825            .filter(|(_, (chunk, _))| document_ids.contains(&chunk.document_id))
826            .map(|(id, _)| id.clone())
827            .collect();
828
829        for chunk_id in chunk_ids_to_remove {
830            chunk_store.remove(&chunk_id);
831        }
832
833        Ok(())
834    }
835
836    async fn clear(&self) -> RragResult<()> {
837        self.documents.write().await.clear();
838        self.chunks.write().await.clear();
839        Ok(())
840    }
841
842    async fn stats(&self) -> RragResult<IndexStats> {
843        let doc_count = self.documents.read().await.len();
844        let chunk_count = self.chunks.read().await.len();
845
846        // Get embedding dimensions from first item
847        let dimensions = if doc_count > 0 {
848            self.documents
849                .read()
850                .await
851                .values()
852                .next()
853                .map(|(_, emb)| emb.dimensions)
854                .unwrap_or(0)
855        } else if chunk_count > 0 {
856            self.chunks
857                .read()
858                .await
859                .values()
860                .next()
861                .map(|(_, emb)| emb.dimensions)
862                .unwrap_or(0)
863        } else {
864            0
865        };
866
867        Ok(IndexStats {
868            total_items: doc_count + chunk_count,
869            size_bytes: (doc_count + chunk_count) * dimensions * 4, // Rough estimate
870            dimensions,
871            index_type: "in_memory".to_string(),
872            last_updated: chrono::Utc::now(),
873        })
874    }
875
876    async fn health_check(&self) -> RragResult<bool> {
877        Ok(true)
878    }
879}
880
/// High-level retrieval service
///
/// Thin facade over an `Arc<dyn Retriever>` that applies service-level
/// defaults (see [`RetrievalServiceConfig`]) to searches.
pub struct RetrievalService {
    /// Active retriever backend
    retriever: Arc<dyn Retriever>,

    /// Service configuration
    config: RetrievalServiceConfig,
}
889
/// Configuration for retrieval service
#[derive(Debug, Clone)]
pub struct RetrievalServiceConfig {
    /// Default search configuration, applied by `search_embedding`
    pub default_search_config: SearchConfig,

    /// Cache query results
    /// NOTE(review): caching is not implemented; this flag and the TTL
    /// below are currently unused by `RetrievalService`.
    pub enable_caching: bool,

    /// Cache TTL in seconds (default: 300)
    pub cache_ttl_seconds: u64,
}
902
903impl Default for RetrievalServiceConfig {
904    fn default() -> Self {
905        Self {
906            default_search_config: SearchConfig::default(),
907            enable_caching: false,
908            cache_ttl_seconds: 300, // 5 minutes
909        }
910    }
911}
912
913impl RetrievalService {
914    /// Create retrieval service
915    pub fn new(retriever: Arc<dyn Retriever>) -> Self {
916        Self {
917            retriever,
918            config: RetrievalServiceConfig::default(),
919        }
920    }
921
922    /// Create with configuration
923    pub fn with_config(retriever: Arc<dyn Retriever>, config: RetrievalServiceConfig) -> Self {
924        Self { retriever, config }
925    }
926
927    /// Search with text query (requires embedding service)
928    pub async fn search_text(
929        &self,
930        _query: &str,
931        _limit: Option<usize>,
932    ) -> RragResult<Vec<SearchResult>> {
933        // This would typically involve embedding the query text first
934        // For now, return an error indicating the limitation
935        Err(RragError::retrieval(
936            "Text search requires embedding service integration".to_string(),
937        ))
938    }
939
940    /// Search with pre-computed embedding
941    pub async fn search_embedding(
942        &self,
943        embedding: Embedding,
944        limit: Option<usize>,
945    ) -> RragResult<Vec<SearchResult>> {
946        let query = SearchQuery::embedding(embedding)
947            .with_limit(limit.unwrap_or(10))
948            .with_config(self.config.default_search_config.clone());
949
950        self.retriever.search(&query).await
951    }
952
953    /// Advanced search with full query configuration
954    pub async fn search(&self, query: SearchQuery) -> RragResult<Vec<SearchResult>> {
955        self.retriever.search(&query).await
956    }
957
958    /// Add documents to the index
959    pub async fn index_documents(
960        &self,
961        documents_with_embeddings: &[(Document, Embedding)],
962    ) -> RragResult<()> {
963        self.retriever
964            .add_documents(documents_with_embeddings)
965            .await
966    }
967
968    /// Add chunks to the index
969    pub async fn index_chunks(
970        &self,
971        chunks_with_embeddings: &[(DocumentChunk, Embedding)],
972    ) -> RragResult<()> {
973        self.retriever.add_chunks(chunks_with_embeddings).await
974    }
975
976    /// Get retriever statistics
977    pub async fn get_stats(&self) -> RragResult<IndexStats> {
978        self.retriever.stats().await
979    }
980
981    /// Health check
982    pub async fn health_check(&self) -> RragResult<bool> {
983        self.retriever.health_check().await
984    }
985}
986
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// The nearest document should come back first for an embedding query.
    #[tokio::test]
    async fn test_in_memory_retriever() {
        let retriever = InMemoryRetriever::new();

        // Two orthogonal unit vectors make the expected ranking unambiguous.
        let first_doc = Document::new("First test document");
        let first_emb = Embedding::new(vec![1.0, 0.0, 0.0], "test-model", &first_doc.id);
        let second_doc = Document::new("Second test document");
        let second_emb = Embedding::new(vec![0.0, 1.0, 0.0], "test-model", &second_doc.id);

        retriever
            .add_documents(&[
                (first_doc.clone(), first_emb.clone()),
                (second_doc, second_emb),
            ])
            .await
            .unwrap();

        // The query vector leans heavily toward the first document.
        let query_embedding = Embedding::new(vec![0.8, 0.2, 0.0], "test-model", "query");
        let query = SearchQuery::embedding(query_embedding).with_limit(5);

        let results = retriever.search(&query).await.unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].id, first_doc.id); // Should be most similar
    }

    /// Metadata filters must exclude non-matching documents even when
    /// their embeddings are close to the query.
    #[tokio::test]
    async fn test_search_filters() {
        let retriever = InMemoryRetriever::new();

        let tech_doc = Document::new("Test document")
            .with_metadata("category", serde_json::Value::String("tech".to_string()));
        let tech_emb = Embedding::new(vec![1.0, 0.0], "test-model", &tech_doc.id);

        let science_doc = Document::new("Another document")
            .with_metadata("category", serde_json::Value::String("science".to_string()));
        let science_emb = Embedding::new(vec![0.9, 0.1], "test-model", &science_doc.id);

        retriever
            .add_documents(&[(tech_doc.clone(), tech_emb), (science_doc, science_emb)])
            .await
            .unwrap();

        // Only the "tech" document should survive the filter.
        let query_embedding = Embedding::new(vec![1.0, 0.0], "test-model", "query");
        let query = SearchQuery::embedding(query_embedding)
            .with_filter("category", serde_json::Value::String("tech".to_string()));

        let results = retriever.search(&query).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, tech_doc.id);
    }

    /// Builder methods should accumulate onto the query.
    #[test]
    fn test_search_query_builder() {
        let query = SearchQuery::text("test query")
            .with_limit(20)
            .with_min_score(0.5)
            .with_filter("type", serde_json::Value::String("article".to_string()));

        assert_eq!(query.limit, 20);
        assert_eq!(query.min_score, 0.5);
        assert_eq!(query.filters.len(), 1);
    }

    /// A fresh service reports an empty index and a healthy backend.
    #[tokio::test]
    async fn test_retrieval_service() {
        let service = RetrievalService::new(Arc::new(InMemoryRetriever::new()));

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_items, 0);

        assert!(service.health_check().await.unwrap());
    }
}