context_mcp/rag.rs

//! CPU-optimized text-based context retrieval with scoring
//!
//! Provides parallel processing capabilities using rayon for efficient
//! text matching and relevance scoring of stored contexts.
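//!
//! # Example
//!
//! A minimal usage sketch, assuming the crate is published as `context_mcp`,
//! an async runtime such as tokio, and an already-configured `ContextStore`
//! (marked `ignore` because the crate name and module visibility are
//! assumptions):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! use context_mcp::error::ContextResult;
//! use context_mcp::rag::RagProcessor;
//! use context_mcp::storage::ContextStore;
//!
//! async fn example(store: Arc<ContextStore>) -> ContextResult<()> {
//!     let processor = RagProcessor::with_defaults(store);
//!
//!     // Text-based retrieval over the stored contexts.
//!     let result = processor.retrieve_by_text("rust async patterns").await?;
//!     println!("retrieved {} contexts", result.contexts.len());
//!     Ok(())
//! }
//! ```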

use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate::context::{Context, ContextDomain, ContextQuery};
use crate::error::ContextResult;
use crate::storage::ContextStore;
use crate::temporal::{TemporalQuery, TemporalStats};

/// RAG processor configuration
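///
/// All fields have defaults (see the `Default` impl below); a sketch of a
/// partial override using struct-update syntax:
///
/// ```ignore
/// let config = RagConfig {
///     max_results: 5,
///     min_relevance: 0.3,
///     ..RagConfig::default()
/// };
/// ```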
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Maximum results per query
    pub max_results: usize,
    /// Minimum relevance threshold (0.0 to 1.0)
    pub min_relevance: f64,
    /// Enable parallel processing
    pub parallel: bool,
    /// Number of threads (0 = auto)
    pub num_threads: usize,
    /// Apply temporal decay to scoring
    pub temporal_decay: bool,
    /// Only retrieve screened-safe contexts
    pub safe_only: bool,
    /// Chunk size threshold: scoring runs in parallel only when more than
    /// this many candidates remain after filtering
    pub chunk_size: usize,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            max_results: 10,
            min_relevance: 0.1,
            parallel: true,
            num_threads: 0, // Auto-detect
            temporal_decay: true,
            safe_only: true,
            chunk_size: 1000,
        }
    }
}

/// Result from RAG retrieval with scoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredContext {
    /// The context
    pub context: Context,
    /// Relevance score (0.0 to 1.0)
    pub score: f64,
    /// Contributing score components
    pub score_breakdown: ScoreBreakdown,
}

/// Breakdown of score components
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Temporal relevance
    pub temporal: f64,
    /// Importance score
    pub importance: f64,
    /// Domain match score
    pub domain_match: f64,
    /// Tag match score
    pub tag_match: f64,
    /// Content similarity (if embedding available)
    pub similarity: Option<f64>,
}

/// RAG retrieval results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// Scored contexts
    pub contexts: Vec<ScoredContext>,
    /// Query used
    pub query_summary: String,
    /// Processing time in ms
    pub processing_time_ms: u64,
    /// Total candidates considered
    pub candidates_considered: usize,
    /// Temporal statistics
    pub temporal_stats: TemporalStats,
}

/// CPU-optimized RAG processor
pub struct RagProcessor {
    config: RagConfig,
    store: Arc<ContextStore>,
}

impl RagProcessor {
    /// Create a new RAG processor
    pub fn new(store: Arc<ContextStore>, config: RagConfig) -> Self {
        // Configure the global rayon thread pool when an explicit thread
        // count is requested. `build_global` fails if a global pool already
        // exists; that error is deliberately ignored here.
        if config.num_threads > 0 {
            rayon::ThreadPoolBuilder::new()
                .num_threads(config.num_threads)
                .build_global()
                .ok();
        }

        Self { config, store }
    }

    /// Create with default configuration
    pub fn with_defaults(store: Arc<ContextStore>) -> Self {
        Self::new(store, RagConfig::default())
    }

    /// Retrieve contexts using a query
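    ///
    /// A sketch of a filtered retrieval (call from an async context;
    /// `processor` and the `"deployment"` tag are hypothetical):
    ///
    /// ```ignore
    /// let query = RetrievalQuery::recent(48).with_tag("deployment");
    /// let result = processor.retrieve(&query).await?;
    /// for scored in &result.contexts {
    ///     println!("{:.2} {:?}", scored.score, scored.context.domain);
    /// }
    /// ```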
    pub async fn retrieve(&self, query: &RetrievalQuery) -> ContextResult<RetrievalResult> {
        let start = std::time::Instant::now();

        // Build context query
        let mut ctx_query = ContextQuery::new();

        if let Some(domain) = &query.domain {
            ctx_query = ctx_query.with_domain(domain.clone());
        }

        for tag in &query.tags {
            ctx_query = ctx_query.with_tag(tag.clone());
        }

        if let Some(min_importance) = query.min_importance {
            ctx_query = ctx_query.with_min_importance(min_importance);
        }

        // Get candidates from storage
        let candidates: Vec<Context> = self.store.query(&ctx_query).await?;
        let candidates_count = candidates.len();

        // Apply temporal filtering and the safe-only screen (if enabled)
        let temporal_query = query.temporal.clone().unwrap_or_default();
        let filtered: Vec<Context> = candidates
            .into_iter()
            .filter(|c| temporal_query.matches(c))
            .filter(|c| !self.config.safe_only || c.is_safe())
            .collect();

        // Score candidates: use rayon only when the filtered set exceeds the
        // configured chunk_size threshold; smaller sets are scored
        // sequentially to avoid parallel dispatch overhead.
        let scored = if self.config.parallel && filtered.len() > self.config.chunk_size {
            self.score_parallel(&filtered, query, &temporal_query)
        } else {
            self.score_sequential(&filtered, query, &temporal_query)
        };

        // Filter by minimum relevance and sort
        let mut results: Vec<ScoredContext> = scored
            .into_iter()
            .filter(|s| s.score >= self.config.min_relevance)
            .collect();

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // A per-query limit takes precedence over the configured default.
        results.truncate(query.max_results.unwrap_or(self.config.max_results));

        let temporal_stats = TemporalStats::from_contexts(
            &results
                .iter()
                .map(|s| s.context.clone())
                .collect::<Vec<_>>(),
        );

        Ok(RetrievalResult {
            contexts: results,
            query_summary: query.to_string(),
            processing_time_ms: start.elapsed().as_millis() as u64,
            candidates_considered: candidates_count,
            temporal_stats,
        })
    }

    /// Score contexts in parallel using rayon
    fn score_parallel(
        &self,
        contexts: &[Context],
        query: &RetrievalQuery,
        temporal: &TemporalQuery,
    ) -> Vec<ScoredContext> {
        contexts
            .par_iter()
            .map(|ctx| self.score_context(ctx, query, temporal))
            .collect()
    }

    /// Score contexts sequentially
    fn score_sequential(
        &self,
        contexts: &[Context],
        query: &RetrievalQuery,
        temporal: &TemporalQuery,
    ) -> Vec<ScoredContext> {
        contexts
            .iter()
            .map(|ctx| self.score_context(ctx, query, temporal))
            .collect()
    }

    /// Score a single context
    fn score_context(
        &self,
        ctx: &Context,
        query: &RetrievalQuery,
        temporal: &TemporalQuery,
    ) -> ScoredContext {
        let temporal_score = if self.config.temporal_decay {
            temporal.relevance_score(ctx)
        } else {
            1.0
        };

        let importance_score = ctx.metadata.importance as f64;

        let domain_match_score = if query.domain.as_ref() == Some(&ctx.domain) {
            1.0
        } else if query.domain.is_none() {
            0.5 // Neutral if no domain specified
        } else {
            0.2 // Partial credit for different domains
        };

        let tag_match_score = if !query.tags.is_empty() {
            let matching_tags = query
                .tags
                .iter()
                .filter(|t| ctx.metadata.tags.contains(*t))
                .count();
            matching_tags as f64 / query.tags.len() as f64
        } else {
            0.5 // Neutral
        };

        let breakdown = ScoreBreakdown {
            temporal: temporal_score,
            importance: importance_score,
            domain_match: domain_match_score,
            tag_match: tag_match_score,
            similarity: None, // Placeholder for embedding-based scoring
        };

        // Weighted final score
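        // Equal weights sum to 1.0, so the final score stays in [0, 1]
        // provided each component is in [0, 1] (importance is assumed to be
        // stored already normalized to that range).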
        let score = 0.25 * breakdown.temporal
            + 0.25 * breakdown.importance
            + 0.25 * breakdown.domain_match
            + 0.25 * breakdown.tag_match;

        ScoredContext {
            context: ctx.clone(),
            score,
            score_breakdown: breakdown,
        }
    }

    /// Retrieve using a text-only query; content-similarity scoring against
    /// the text is not yet applied (see `ScoreBreakdown::similarity`)
    pub async fn retrieve_by_text(&self, text: &str) -> ContextResult<RetrievalResult> {
        let query = RetrievalQuery::from_text(text);
        self.retrieve(&query).await
    }

    /// Get configuration
    pub fn config(&self) -> &RagConfig {
        &self.config
    }
}

/// Query for RAG retrieval
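///
/// Filters compose via the builder methods; a sketch:
///
/// ```ignore
/// let query = RetrievalQuery::from_text("async runtime")
///     .with_domain(ContextDomain::Code)
///     .with_tag("rust")
///     .with_min_importance(0.5);
/// ```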
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalQuery {
    /// Text query (for keyword/semantic matching)
    pub text: Option<String>,
    /// Domain filter
    pub domain: Option<ContextDomain>,
    /// Tag filters
    pub tags: Vec<String>,
    /// Minimum importance
    pub min_importance: Option<f32>,
    /// Temporal query parameters
    pub temporal: Option<TemporalQuery>,
    /// Maximum results
    pub max_results: Option<usize>,
}

impl RetrievalQuery {
    /// Create a new retrieval query
    pub fn new() -> Self {
        Self::default()
    }

    /// Create from text
    pub fn from_text(text: &str) -> Self {
        Self {
            text: Some(text.to_string()),
            ..Default::default()
        }
    }

    /// Set domain filter
    pub fn with_domain(mut self, domain: ContextDomain) -> Self {
        self.domain = Some(domain);
        self
    }

    /// Add tag filter
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Set minimum importance
    pub fn with_min_importance(mut self, importance: f32) -> Self {
        self.min_importance = Some(importance);
        self
    }

    /// Set temporal parameters
    pub fn with_temporal(mut self, temporal: TemporalQuery) -> Self {
        self.temporal = Some(temporal);
        self
    }

    /// Query for recent contexts
    pub fn recent(hours: i64) -> Self {
        Self::new().with_temporal(TemporalQuery::recent(hours))
    }
}

impl std::fmt::Display for RetrievalQuery {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut parts = Vec::new();

        if let Some(text) = &self.text {
            parts.push(format!("text: '{}'", text));
        }
        if let Some(domain) = &self.domain {
            parts.push(format!("domain: {:?}", domain));
        }
        if !self.tags.is_empty() {
            parts.push(format!("tags: {:?}", self.tags));
        }
        if let Some(importance) = self.min_importance {
            parts.push(format!("min_importance: {}", importance));
        }

        if parts.is_empty() {
            write!(f, "all contexts")
        } else {
            write!(f, "{}", parts.join(", "))
        }
    }
}

/// Batch processing for multiple queries
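///
/// Queries are executed one after another so async store access stays simple;
/// a usage sketch (assuming an existing `RagProcessor` named `processor`):
///
/// ```ignore
/// let batch = BatchProcessor::new(Arc::new(processor));
/// let results = batch
///     .process_batch(vec![
///         RetrievalQuery::recent(24),
///         RetrievalQuery::from_text("error handling"),
///     ])
///     .await;
/// ```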
pub struct BatchProcessor {
    processor: Arc<RagProcessor>,
}

impl BatchProcessor {
    /// Create a new batch processor
    pub fn new(processor: Arc<RagProcessor>) -> Self {
        Self { processor }
    }

    /// Process multiple queries (sequential for async compatibility)
    pub async fn process_batch(
        &self,
        queries: Vec<RetrievalQuery>,
    ) -> Vec<ContextResult<RetrievalResult>> {
        let mut results = Vec::with_capacity(queries.len());
        for query in queries {
            results.push(self.processor.retrieve(&query).await);
        }
        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageConfig;
    use tempfile::TempDir;

    fn create_test_store() -> (Arc<ContextStore>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let config = StorageConfig {
            persist_path: Some(temp_dir.path().to_path_buf()),
            enable_persistence: true,
            ..Default::default()
        };
        let store = ContextStore::new(config).unwrap();
        (Arc::new(store), temp_dir)
    }

    #[test]
    fn test_retrieval_query() {
        let query = RetrievalQuery::from_text("test query")
            .with_domain(ContextDomain::Code)
            .with_tag("rust");

        assert_eq!(query.text, Some("test query".to_string()));
        assert_eq!(query.domain, Some(ContextDomain::Code));
        assert!(query.tags.contains(&"rust".to_string()));
    }

    #[tokio::test]
    async fn test_rag_processor() {
        let (store, _temp) = create_test_store();
        let processor = RagProcessor::with_defaults(store.clone());

        // Add test context
        let ctx = Context::new("Test content", ContextDomain::Code);
        store.store(ctx).await.unwrap();

        // Retrieve
        let result = processor.retrieve(&RetrievalQuery::new()).await.unwrap();
        assert_eq!(result.candidates_considered, 1);
    }
}