1use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10use crate::context::{Context, ContextDomain, ContextQuery};
11use crate::error::ContextResult;
12use crate::storage::ContextStore;
13use crate::temporal::{TemporalQuery, TemporalStats};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RagConfig {
18 pub max_results: usize,
20 pub min_relevance: f64,
22 pub parallel: bool,
24 pub num_threads: usize,
26 pub temporal_decay: bool,
28 pub safe_only: bool,
30 pub chunk_size: usize,
32}
33
34impl Default for RagConfig {
35 fn default() -> Self {
36 Self {
37 max_results: 10,
38 min_relevance: 0.1,
39 parallel: true,
40 num_threads: 0, temporal_decay: true,
42 safe_only: true,
43 chunk_size: 1000,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ScoredContext {
51 pub context: Context,
53 pub score: f64,
55 pub score_breakdown: ScoreBreakdown,
57}
58
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct ScoreBreakdown {
62 pub temporal: f64,
64 pub importance: f64,
66 pub domain_match: f64,
68 pub tag_match: f64,
70 pub similarity: Option<f64>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RetrievalResult {
77 pub contexts: Vec<ScoredContext>,
79 pub query_summary: String,
81 pub processing_time_ms: u64,
83 pub candidates_considered: usize,
85 pub temporal_stats: TemporalStats,
87}
88
89pub struct RagProcessor {
91 config: RagConfig,
92 store: Arc<ContextStore>,
93}
94
95impl RagProcessor {
96 pub fn new(store: Arc<ContextStore>, config: RagConfig) -> Self {
98 if config.num_threads > 0 {
100 rayon::ThreadPoolBuilder::new()
101 .num_threads(config.num_threads)
102 .build_global()
103 .ok();
104 }
105
106 Self { config, store }
107 }
108
109 pub fn with_defaults(store: Arc<ContextStore>) -> Self {
111 Self::new(store, RagConfig::default())
112 }
113
114 pub async fn retrieve(&self, query: &RetrievalQuery) -> ContextResult<RetrievalResult> {
116 let start = std::time::Instant::now();
117
118 let mut ctx_query = ContextQuery::new();
120
121 if let Some(domain) = &query.domain {
122 ctx_query = ctx_query.with_domain(domain.clone());
123 }
124
125 for tag in &query.tags {
126 ctx_query = ctx_query.with_tag(tag.clone());
127 }
128
129 if let Some(min_importance) = query.min_importance {
130 ctx_query = ctx_query.with_min_importance(min_importance);
131 }
132
133 let candidates: Vec<Context> = self.store.query(&ctx_query).await?;
135 let candidates_count = candidates.len();
136
137 let temporal_query = query.temporal.clone().unwrap_or_default();
139 let filtered: Vec<Context> = candidates
140 .into_iter()
141 .filter(|c| temporal_query.matches(c))
142 .filter(|c| !self.config.safe_only || c.is_safe())
143 .collect();
144
145 let scored = if self.config.parallel && filtered.len() > self.config.chunk_size {
147 self.score_parallel(&filtered, query, &temporal_query)
148 } else {
149 self.score_sequential(&filtered, query, &temporal_query)
150 };
151
152 let mut results: Vec<ScoredContext> = scored
154 .into_iter()
155 .filter(|s| s.score >= self.config.min_relevance)
156 .collect();
157
158 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
159 results.truncate(self.config.max_results);
160
161 let temporal_stats = TemporalStats::from_contexts(
162 &results.iter().map(|s| s.context.clone()).collect::<Vec<_>>()
163 );
164
165 Ok(RetrievalResult {
166 contexts: results,
167 query_summary: query.to_string(),
168 processing_time_ms: start.elapsed().as_millis() as u64,
169 candidates_considered: candidates_count,
170 temporal_stats,
171 })
172 }
173
174 fn score_parallel(
176 &self,
177 contexts: &[Context],
178 query: &RetrievalQuery,
179 temporal: &TemporalQuery,
180 ) -> Vec<ScoredContext> {
181 contexts
182 .par_iter()
183 .map(|ctx| self.score_context(ctx, query, temporal))
184 .collect()
185 }
186
187 fn score_sequential(
189 &self,
190 contexts: &[Context],
191 query: &RetrievalQuery,
192 temporal: &TemporalQuery,
193 ) -> Vec<ScoredContext> {
194 contexts
195 .iter()
196 .map(|ctx| self.score_context(ctx, query, temporal))
197 .collect()
198 }
199
200 fn score_context(
202 &self,
203 ctx: &Context,
204 query: &RetrievalQuery,
205 temporal: &TemporalQuery,
206 ) -> ScoredContext {
207 let mut breakdown = ScoreBreakdown::default();
208
209 breakdown.temporal = if self.config.temporal_decay {
211 temporal.relevance_score(ctx)
212 } else {
213 1.0
214 };
215
216 breakdown.importance = ctx.metadata.importance as f64;
218
219 breakdown.domain_match = if query.domain.as_ref() == Some(&ctx.domain) {
221 1.0
222 } else if query.domain.is_none() {
223 0.5 } else {
225 0.2 };
227
228 if !query.tags.is_empty() {
230 let matching_tags = query
231 .tags
232 .iter()
233 .filter(|t| ctx.metadata.tags.contains(*t))
234 .count();
235 breakdown.tag_match = matching_tags as f64 / query.tags.len() as f64;
236 } else {
237 breakdown.tag_match = 0.5; }
239
240 breakdown.similarity = None;
244
245 let score = 0.25 * breakdown.temporal
247 + 0.25 * breakdown.importance
248 + 0.25 * breakdown.domain_match
249 + 0.25 * breakdown.tag_match;
250
251 ScoredContext {
252 context: ctx.clone(),
253 score,
254 score_breakdown: breakdown,
255 }
256 }
257
258 pub async fn retrieve_by_text(&self, text: &str) -> ContextResult<RetrievalResult> {
260 let query = RetrievalQuery::from_text(text);
261 self.retrieve(&query).await
262 }
263
264 pub fn config(&self) -> &RagConfig {
266 &self.config
267 }
268}
269
270#[derive(Debug, Clone, Default, Serialize, Deserialize)]
272pub struct RetrievalQuery {
273 pub text: Option<String>,
275 pub domain: Option<ContextDomain>,
277 pub tags: Vec<String>,
279 pub min_importance: Option<f32>,
281 pub temporal: Option<TemporalQuery>,
283 pub max_results: Option<usize>,
285}
286
287impl RetrievalQuery {
288 pub fn new() -> Self {
290 Self::default()
291 }
292
293 pub fn from_text(text: &str) -> Self {
295 Self {
296 text: Some(text.to_string()),
297 ..Default::default()
298 }
299 }
300
301 pub fn with_domain(mut self, domain: ContextDomain) -> Self {
303 self.domain = Some(domain);
304 self
305 }
306
307 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
309 self.tags.push(tag.into());
310 self
311 }
312
313 pub fn with_min_importance(mut self, importance: f32) -> Self {
315 self.min_importance = Some(importance);
316 self
317 }
318
319 pub fn with_temporal(mut self, temporal: TemporalQuery) -> Self {
321 self.temporal = Some(temporal);
322 self
323 }
324
325 pub fn recent(hours: i64) -> Self {
327 Self::new().with_temporal(TemporalQuery::recent(hours))
328 }
329}
330
331impl std::fmt::Display for RetrievalQuery {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 let mut parts = Vec::new();
334
335 if let Some(text) = &self.text {
336 parts.push(format!("text: '{}'", text));
337 }
338 if let Some(domain) = &self.domain {
339 parts.push(format!("domain: {:?}", domain));
340 }
341 if !self.tags.is_empty() {
342 parts.push(format!("tags: {:?}", self.tags));
343 }
344 if let Some(importance) = self.min_importance {
345 parts.push(format!("min_importance: {}", importance));
346 }
347
348 if parts.is_empty() {
349 write!(f, "all contexts")
350 } else {
351 write!(f, "{}", parts.join(", "))
352 }
353 }
354}
355
356pub struct BatchProcessor {
358 processor: Arc<RagProcessor>,
359}
360
361impl BatchProcessor {
362 pub fn new(processor: Arc<RagProcessor>) -> Self {
364 Self { processor }
365 }
366
367 pub async fn process_batch(&self, queries: Vec<RetrievalQuery>) -> Vec<ContextResult<RetrievalResult>> {
369 let mut results = Vec::with_capacity(queries.len());
370 for query in queries {
371 results.push(self.processor.retrieve(&query).await);
372 }
373 results
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageConfig;
    use tempfile::TempDir;

    /// Builds a store persisted under a fresh temporary directory. The
    /// `TempDir` is returned so the caller controls how long the directory
    /// lives.
    fn create_test_store() -> (Arc<ContextStore>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let persist_path = temp_dir.path().to_path_buf();
        let store = ContextStore::new(StorageConfig {
            persist_path: Some(persist_path),
            enable_persistence: true,
            ..Default::default()
        })
        .unwrap();
        (Arc::new(store), temp_dir)
    }

    /// The builder methods populate text, domain and tags.
    #[test]
    fn test_retrieval_query() {
        let query = RetrievalQuery::from_text("test query")
            .with_domain(ContextDomain::Code)
            .with_tag("rust");

        assert_eq!(query.text.as_deref(), Some("test query"));
        assert_eq!(query.domain, Some(ContextDomain::Code));
        assert!(query.tags.iter().any(|tag| tag == "rust"));
    }

    /// A single stored context is reported as one candidate for an empty query.
    #[tokio::test]
    async fn test_rag_processor() {
        let (store, _temp) = create_test_store();
        let processor = RagProcessor::with_defaults(Arc::clone(&store));

        store
            .store(Context::new("Test content", ContextDomain::Code))
            .await
            .unwrap();

        let everything = RetrievalQuery::new();
        let result = processor.retrieve(&everything).await.unwrap();
        assert_eq!(result.candidates_considered, 1);
    }
}