context_mcp/
rag.rs

1//! CPU-optimized RAG processing for context retrieval
2//!
3//! Provides parallel processing capabilities for efficient
4//! retrieval-augmented generation operations on screened safe inputs.
5
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10use crate::context::{Context, ContextDomain, ContextQuery};
11use crate::error::ContextResult;
12use crate::storage::ContextStore;
13use crate::temporal::{TemporalQuery, TemporalStats};
14
15/// RAG processor configuration
/// RAG processor configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Maximum results per query (results are truncated to this count)
    pub max_results: usize,
    /// Minimum relevance threshold (0.0 to 1.0); lower-scoring contexts are dropped
    pub min_relevance: f64,
    /// Enable parallel scoring (used only when candidates exceed `chunk_size`)
    pub parallel: bool,
    /// Number of rayon threads (0 = auto-detect)
    pub num_threads: usize,
    /// Apply temporal decay to scoring; when false the temporal component is a flat 1.0
    pub temporal_decay: bool,
    /// Only retrieve screened-safe contexts
    pub safe_only: bool,
    /// Candidate-count threshold above which scoring switches to the parallel path
    pub chunk_size: usize,
}
33
34impl Default for RagConfig {
35    fn default() -> Self {
36        Self {
37            max_results: 10,
38            min_relevance: 0.1,
39            parallel: true,
40            num_threads: 0, // Auto-detect
41            temporal_decay: true,
42            safe_only: true,
43            chunk_size: 1000,
44        }
45    }
46}
47
48/// Result from RAG retrieval with scoring
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ScoredContext {
51    /// The context
52    pub context: Context,
53    /// Relevance score (0.0 to 1.0)
54    pub score: f64,
55    /// Contributing score components
56    pub score_breakdown: ScoreBreakdown,
57}
58
59/// Breakdown of score components
60#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct ScoreBreakdown {
62    /// Temporal relevance
63    pub temporal: f64,
64    /// Importance score
65    pub importance: f64,
66    /// Domain match score
67    pub domain_match: f64,
68    /// Tag match score
69    pub tag_match: f64,
70    /// Content similarity (if embedding available)
71    pub similarity: Option<f64>,
72}
73
74/// RAG retrieval results
/// RAG retrieval results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// Scored contexts, sorted by descending score
    pub contexts: Vec<ScoredContext>,
    /// Human-readable summary of the query (from `RetrievalQuery`'s `Display`)
    pub query_summary: String,
    /// Processing time in ms
    pub processing_time_ms: u64,
    /// Total candidates returned by storage before filtering
    pub candidates_considered: usize,
    /// Temporal statistics computed over the final result set
    pub temporal_stats: TemporalStats,
}
88
/// CPU-optimized RAG processor
///
/// Holds a shared handle to the context store plus the scoring/filtering
/// configuration; cheap to share via `Arc`.
pub struct RagProcessor {
    // Scoring and filtering configuration (see `RagConfig`).
    config: RagConfig,
    // Shared handle to the backing context store.
    store: Arc<ContextStore>,
}
94
95impl RagProcessor {
96    /// Create a new RAG processor
97    pub fn new(store: Arc<ContextStore>, config: RagConfig) -> Self {
98        // Configure thread pool if specified
99        if config.num_threads > 0 {
100            rayon::ThreadPoolBuilder::new()
101                .num_threads(config.num_threads)
102                .build_global()
103                .ok();
104        }
105
106        Self { config, store }
107    }
108
109    /// Create with default configuration
110    pub fn with_defaults(store: Arc<ContextStore>) -> Self {
111        Self::new(store, RagConfig::default())
112    }
113
114    /// Retrieve contexts using a query
115    pub async fn retrieve(&self, query: &RetrievalQuery) -> ContextResult<RetrievalResult> {
116        let start = std::time::Instant::now();
117
118        // Build context query
119        let mut ctx_query = ContextQuery::new();
120        
121        if let Some(domain) = &query.domain {
122            ctx_query = ctx_query.with_domain(domain.clone());
123        }
124        
125        for tag in &query.tags {
126            ctx_query = ctx_query.with_tag(tag.clone());
127        }
128
129        if let Some(min_importance) = query.min_importance {
130            ctx_query = ctx_query.with_min_importance(min_importance);
131        }
132
133        // Get candidates from storage
134        let candidates: Vec<Context> = self.store.query(&ctx_query).await?;
135        let candidates_count = candidates.len();
136
137        // Apply temporal filtering
138        let temporal_query = query.temporal.clone().unwrap_or_default();
139        let filtered: Vec<Context> = candidates
140            .into_iter()
141            .filter(|c| temporal_query.matches(c))
142            .filter(|c| !self.config.safe_only || c.is_safe())
143            .collect();
144
145        // Score contexts (parallel or sequential)
146        let scored = if self.config.parallel && filtered.len() > self.config.chunk_size {
147            self.score_parallel(&filtered, query, &temporal_query)
148        } else {
149            self.score_sequential(&filtered, query, &temporal_query)
150        };
151
152        // Filter by minimum relevance and sort
153        let mut results: Vec<ScoredContext> = scored
154            .into_iter()
155            .filter(|s| s.score >= self.config.min_relevance)
156            .collect();
157
158        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
159        results.truncate(self.config.max_results);
160
161        let temporal_stats = TemporalStats::from_contexts(
162            &results.iter().map(|s| s.context.clone()).collect::<Vec<_>>()
163        );
164
165        Ok(RetrievalResult {
166            contexts: results,
167            query_summary: query.to_string(),
168            processing_time_ms: start.elapsed().as_millis() as u64,
169            candidates_considered: candidates_count,
170            temporal_stats,
171        })
172    }
173
174    /// Score contexts in parallel using rayon
175    fn score_parallel(
176        &self,
177        contexts: &[Context],
178        query: &RetrievalQuery,
179        temporal: &TemporalQuery,
180    ) -> Vec<ScoredContext> {
181        contexts
182            .par_iter()
183            .map(|ctx| self.score_context(ctx, query, temporal))
184            .collect()
185    }
186
187    /// Score contexts sequentially
188    fn score_sequential(
189        &self,
190        contexts: &[Context],
191        query: &RetrievalQuery,
192        temporal: &TemporalQuery,
193    ) -> Vec<ScoredContext> {
194        contexts
195            .iter()
196            .map(|ctx| self.score_context(ctx, query, temporal))
197            .collect()
198    }
199
200    /// Score a single context
201    fn score_context(
202        &self,
203        ctx: &Context,
204        query: &RetrievalQuery,
205        temporal: &TemporalQuery,
206    ) -> ScoredContext {
207        let mut breakdown = ScoreBreakdown::default();
208
209        // Temporal score
210        breakdown.temporal = if self.config.temporal_decay {
211            temporal.relevance_score(ctx)
212        } else {
213            1.0
214        };
215
216        // Importance score
217        breakdown.importance = ctx.metadata.importance as f64;
218
219        // Domain match score
220        breakdown.domain_match = if query.domain.as_ref() == Some(&ctx.domain) {
221            1.0
222        } else if query.domain.is_none() {
223            0.5 // Neutral if no domain specified
224        } else {
225            0.2 // Partial credit for different domains
226        };
227
228        // Tag match score
229        if !query.tags.is_empty() {
230            let matching_tags = query
231                .tags
232                .iter()
233                .filter(|t| ctx.metadata.tags.contains(*t))
234                .count();
235            breakdown.tag_match = matching_tags as f64 / query.tags.len() as f64;
236        } else {
237            breakdown.tag_match = 0.5; // Neutral
238        }
239
240        // Content similarity (placeholder for embedding-based scoring)
241        // In a full implementation, this would compute cosine similarity
242        // between query embedding and context embedding
243        breakdown.similarity = None;
244
245        // Weighted final score
246        let score = 0.25 * breakdown.temporal
247            + 0.25 * breakdown.importance
248            + 0.25 * breakdown.domain_match
249            + 0.25 * breakdown.tag_match;
250
251        ScoredContext {
252            context: ctx.clone(),
253            score,
254            score_breakdown: breakdown,
255        }
256    }
257
258    /// Retrieve by text query with simple keyword matching
259    pub async fn retrieve_by_text(&self, text: &str) -> ContextResult<RetrievalResult> {
260        let query = RetrievalQuery::from_text(text);
261        self.retrieve(&query).await
262    }
263
264    /// Get configuration
265    pub fn config(&self) -> &RagConfig {
266        &self.config
267    }
268}
269
270/// Query for RAG retrieval
/// Query for RAG retrieval
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalQuery {
    /// Text query (for keyword/semantic matching)
    pub text: Option<String>,
    /// Domain filter (exact domain gets full domain-match credit)
    pub domain: Option<ContextDomain>,
    /// Tag filters (all are forwarded to the storage query)
    pub tags: Vec<String>,
    /// Minimum importance (forwarded to the storage query when set)
    pub min_importance: Option<f32>,
    /// Temporal query parameters (defaults applied when absent)
    pub temporal: Option<TemporalQuery>,
    /// Per-query maximum results
    pub max_results: Option<usize>,
}
286
287impl RetrievalQuery {
288    /// Create a new retrieval query
289    pub fn new() -> Self {
290        Self::default()
291    }
292
293    /// Create from text
294    pub fn from_text(text: &str) -> Self {
295        Self {
296            text: Some(text.to_string()),
297            ..Default::default()
298        }
299    }
300
301    /// Set domain filter
302    pub fn with_domain(mut self, domain: ContextDomain) -> Self {
303        self.domain = Some(domain);
304        self
305    }
306
307    /// Add tag filter
308    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
309        self.tags.push(tag.into());
310        self
311    }
312
313    /// Set minimum importance
314    pub fn with_min_importance(mut self, importance: f32) -> Self {
315        self.min_importance = Some(importance);
316        self
317    }
318
319    /// Set temporal parameters
320    pub fn with_temporal(mut self, temporal: TemporalQuery) -> Self {
321        self.temporal = Some(temporal);
322        self
323    }
324
325    /// Query for recent contexts
326    pub fn recent(hours: i64) -> Self {
327        Self::new().with_temporal(TemporalQuery::recent(hours))
328    }
329}
330
331impl std::fmt::Display for RetrievalQuery {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        let mut parts = Vec::new();
334        
335        if let Some(text) = &self.text {
336            parts.push(format!("text: '{}'", text));
337        }
338        if let Some(domain) = &self.domain {
339            parts.push(format!("domain: {:?}", domain));
340        }
341        if !self.tags.is_empty() {
342            parts.push(format!("tags: {:?}", self.tags));
343        }
344        if let Some(importance) = self.min_importance {
345            parts.push(format!("min_importance: {}", importance));
346        }
347        
348        if parts.is_empty() {
349            write!(f, "all contexts")
350        } else {
351            write!(f, "{}", parts.join(", "))
352        }
353    }
354}
355
356/// Batch processing for multiple queries
/// Batch processing for multiple queries
pub struct BatchProcessor {
    // Shared processor used for every query in a batch.
    processor: Arc<RagProcessor>,
}
360
361impl BatchProcessor {
362    /// Create a new batch processor
363    pub fn new(processor: Arc<RagProcessor>) -> Self {
364        Self { processor }
365    }
366
367    /// Process multiple queries (sequential for async compatibility)
368    pub async fn process_batch(&self, queries: Vec<RetrievalQuery>) -> Vec<ContextResult<RetrievalResult>> {
369        let mut results = Vec::with_capacity(queries.len());
370        for query in queries {
371            results.push(self.processor.retrieve(&query).await);
372        }
373        results
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageConfig;
    use tempfile::TempDir;

    /// Build a persistence-enabled store backed by a throwaway directory.
    /// The `TempDir` guard must be kept alive for the store's lifetime.
    fn create_test_store() -> (Arc<ContextStore>, TempDir) {
        let dir = TempDir::new().unwrap();
        let cfg = StorageConfig {
            persist_path: Some(dir.path().to_path_buf()),
            enable_persistence: true,
            ..Default::default()
        };
        (Arc::new(ContextStore::new(cfg).unwrap()), dir)
    }

    #[test]
    fn test_retrieval_query() {
        let q = RetrievalQuery::from_text("test query")
            .with_domain(ContextDomain::Code)
            .with_tag("rust");

        assert_eq!(q.text.as_deref(), Some("test query"));
        assert_eq!(q.domain, Some(ContextDomain::Code));
        assert!(q.tags.iter().any(|t| t == "rust"));
    }

    #[tokio::test]
    async fn test_rag_processor() {
        let (store, _guard) = create_test_store();
        let processor = RagProcessor::with_defaults(Arc::clone(&store));

        // Seed the store with a single context, then verify an unfiltered
        // query sees exactly that one candidate.
        store
            .store(Context::new("Test content", ContextDomain::Code))
            .await
            .unwrap();

        let outcome = processor.retrieve(&RetrievalQuery::new()).await.unwrap();
        assert_eq!(outcome.candidates_considered, 1);
    }
}