1use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10use crate::context::{Context, ContextDomain, ContextQuery};
11use crate::error::ContextResult;
12use crate::storage::ContextStore;
13use crate::temporal::{TemporalQuery, TemporalStats};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RagConfig {
18 pub max_results: usize,
20 pub min_relevance: f64,
22 pub parallel: bool,
24 pub num_threads: usize,
26 pub temporal_decay: bool,
28 pub safe_only: bool,
30 pub chunk_size: usize,
32}
33
34impl Default for RagConfig {
35 fn default() -> Self {
36 Self {
37 max_results: 10,
38 min_relevance: 0.1,
39 parallel: true,
40 num_threads: 0, temporal_decay: true,
42 safe_only: true,
43 chunk_size: 1000,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ScoredContext {
51 pub context: Context,
53 pub score: f64,
55 pub score_breakdown: ScoreBreakdown,
57}
58
59#[derive(Debug, Clone, Default, Serialize, Deserialize)]
61pub struct ScoreBreakdown {
62 pub temporal: f64,
64 pub importance: f64,
66 pub domain_match: f64,
68 pub tag_match: f64,
70 pub similarity: Option<f64>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RetrievalResult {
77 pub contexts: Vec<ScoredContext>,
79 pub query_summary: String,
81 pub processing_time_ms: u64,
83 pub candidates_considered: usize,
85 pub temporal_stats: TemporalStats,
87}
88
89pub struct RagProcessor {
91 config: RagConfig,
92 store: Arc<ContextStore>,
93}
94
95impl RagProcessor {
96 pub fn new(store: Arc<ContextStore>, config: RagConfig) -> Self {
98 if config.num_threads > 0 {
100 rayon::ThreadPoolBuilder::new()
101 .num_threads(config.num_threads)
102 .build_global()
103 .ok();
104 }
105
106 Self { config, store }
107 }
108
109 pub fn with_defaults(store: Arc<ContextStore>) -> Self {
111 Self::new(store, RagConfig::default())
112 }
113
114 pub async fn retrieve(&self, query: &RetrievalQuery) -> ContextResult<RetrievalResult> {
116 let start = std::time::Instant::now();
117
118 let mut ctx_query = ContextQuery::new();
120
121 if let Some(domain) = &query.domain {
122 ctx_query = ctx_query.with_domain(domain.clone());
123 }
124
125 for tag in &query.tags {
126 ctx_query = ctx_query.with_tag(tag.clone());
127 }
128
129 if let Some(min_importance) = query.min_importance {
130 ctx_query = ctx_query.with_min_importance(min_importance);
131 }
132
133 let candidates: Vec<Context> = self.store.query(&ctx_query).await?;
135 let candidates_count = candidates.len();
136
137 let temporal_query = query.temporal.clone().unwrap_or_default();
139 let filtered: Vec<Context> = candidates
140 .into_iter()
141 .filter(|c| temporal_query.matches(c))
142 .filter(|c| !self.config.safe_only || c.is_safe())
143 .collect();
144
145 let scored = if self.config.parallel && filtered.len() > self.config.chunk_size {
147 self.score_parallel(&filtered, query, &temporal_query)
148 } else {
149 self.score_sequential(&filtered, query, &temporal_query)
150 };
151
152 let mut results: Vec<ScoredContext> = scored
154 .into_iter()
155 .filter(|s| s.score >= self.config.min_relevance)
156 .collect();
157
158 results.sort_by(|a, b| {
159 b.score
160 .partial_cmp(&a.score)
161 .unwrap_or(std::cmp::Ordering::Equal)
162 });
163 results.truncate(self.config.max_results);
164
165 let temporal_stats = TemporalStats::from_contexts(
166 &results
167 .iter()
168 .map(|s| s.context.clone())
169 .collect::<Vec<_>>(),
170 );
171
172 Ok(RetrievalResult {
173 contexts: results,
174 query_summary: query.to_string(),
175 processing_time_ms: start.elapsed().as_millis() as u64,
176 candidates_considered: candidates_count,
177 temporal_stats,
178 })
179 }
180
181 fn score_parallel(
183 &self,
184 contexts: &[Context],
185 query: &RetrievalQuery,
186 temporal: &TemporalQuery,
187 ) -> Vec<ScoredContext> {
188 contexts
189 .par_iter()
190 .map(|ctx| self.score_context(ctx, query, temporal))
191 .collect()
192 }
193
194 fn score_sequential(
196 &self,
197 contexts: &[Context],
198 query: &RetrievalQuery,
199 temporal: &TemporalQuery,
200 ) -> Vec<ScoredContext> {
201 contexts
202 .iter()
203 .map(|ctx| self.score_context(ctx, query, temporal))
204 .collect()
205 }
206
207 fn score_context(
209 &self,
210 ctx: &Context,
211 query: &RetrievalQuery,
212 temporal: &TemporalQuery,
213 ) -> ScoredContext {
214 let temporal_score = if self.config.temporal_decay {
215 temporal.relevance_score(ctx)
216 } else {
217 1.0
218 };
219
220 let importance_score = ctx.metadata.importance as f64;
221
222 let domain_match_score = if query.domain.as_ref() == Some(&ctx.domain) {
223 1.0
224 } else if query.domain.is_none() {
225 0.5 } else {
227 0.2 };
229
230 let tag_match_score = if !query.tags.is_empty() {
231 let matching_tags = query
232 .tags
233 .iter()
234 .filter(|t| ctx.metadata.tags.contains(*t))
235 .count();
236 matching_tags as f64 / query.tags.len() as f64
237 } else {
238 0.5 };
240
241 let breakdown = ScoreBreakdown {
242 temporal: temporal_score,
243 importance: importance_score,
244 domain_match: domain_match_score,
245 tag_match: tag_match_score,
246 similarity: None, };
248
249 let score = 0.25 * breakdown.temporal
251 + 0.25 * breakdown.importance
252 + 0.25 * breakdown.domain_match
253 + 0.25 * breakdown.tag_match;
254
255 ScoredContext {
256 context: ctx.clone(),
257 score,
258 score_breakdown: breakdown,
259 }
260 }
261
262 pub async fn retrieve_by_text(&self, text: &str) -> ContextResult<RetrievalResult> {
264 let query = RetrievalQuery::from_text(text);
265 self.retrieve(&query).await
266 }
267
268 pub fn config(&self) -> &RagConfig {
270 &self.config
271 }
272}
273
274#[derive(Debug, Clone, Default, Serialize, Deserialize)]
276pub struct RetrievalQuery {
277 pub text: Option<String>,
279 pub domain: Option<ContextDomain>,
281 pub tags: Vec<String>,
283 pub min_importance: Option<f32>,
285 pub temporal: Option<TemporalQuery>,
287 pub max_results: Option<usize>,
289}
290
291impl RetrievalQuery {
292 pub fn new() -> Self {
294 Self::default()
295 }
296
297 pub fn from_text(text: &str) -> Self {
299 Self {
300 text: Some(text.to_string()),
301 ..Default::default()
302 }
303 }
304
305 pub fn with_domain(mut self, domain: ContextDomain) -> Self {
307 self.domain = Some(domain);
308 self
309 }
310
311 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
313 self.tags.push(tag.into());
314 self
315 }
316
317 pub fn with_min_importance(mut self, importance: f32) -> Self {
319 self.min_importance = Some(importance);
320 self
321 }
322
323 pub fn with_temporal(mut self, temporal: TemporalQuery) -> Self {
325 self.temporal = Some(temporal);
326 self
327 }
328
329 pub fn recent(hours: i64) -> Self {
331 Self::new().with_temporal(TemporalQuery::recent(hours))
332 }
333}
334
335impl std::fmt::Display for RetrievalQuery {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 let mut parts = Vec::new();
338
339 if let Some(text) = &self.text {
340 parts.push(format!("text: '{}'", text));
341 }
342 if let Some(domain) = &self.domain {
343 parts.push(format!("domain: {:?}", domain));
344 }
345 if !self.tags.is_empty() {
346 parts.push(format!("tags: {:?}", self.tags));
347 }
348 if let Some(importance) = self.min_importance {
349 parts.push(format!("min_importance: {}", importance));
350 }
351
352 if parts.is_empty() {
353 write!(f, "all contexts")
354 } else {
355 write!(f, "{}", parts.join(", "))
356 }
357 }
358}
359
360pub struct BatchProcessor {
362 processor: Arc<RagProcessor>,
363}
364
365impl BatchProcessor {
366 pub fn new(processor: Arc<RagProcessor>) -> Self {
368 Self { processor }
369 }
370
371 pub async fn process_batch(
373 &self,
374 queries: Vec<RetrievalQuery>,
375 ) -> Vec<ContextResult<RetrievalResult>> {
376 let mut results = Vec::with_capacity(queries.len());
377 for query in queries {
378 results.push(self.processor.retrieve(&query).await);
379 }
380 results
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageConfig;
    use tempfile::TempDir;

    /// Builds a persistence-enabled store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the caller controls
    /// how long the directory handle lives.
    fn create_test_store() -> (Arc<ContextStore>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let persist_path = temp_dir.path().to_path_buf();
        let store = ContextStore::new(StorageConfig {
            persist_path: Some(persist_path),
            enable_persistence: true,
            ..Default::default()
        })
        .unwrap();
        (Arc::new(store), temp_dir)
    }

    /// The builder methods populate text, domain and tags.
    #[test]
    fn test_retrieval_query() {
        let query = RetrievalQuery::from_text("test query")
            .with_domain(ContextDomain::Code)
            .with_tag("rust");

        assert_eq!(query.text.as_deref(), Some("test query"));
        assert_eq!(query.domain, Some(ContextDomain::Code));
        assert!(query.tags.iter().any(|tag| tag == "rust"));
    }

    /// A single stored context is reported as one considered candidate.
    #[tokio::test]
    async fn test_rag_processor() {
        let (store, _temp) = create_test_store();
        let processor = RagProcessor::with_defaults(Arc::clone(&store));

        store
            .store(Context::new("Test content", ContextDomain::Code))
            .await
            .unwrap();

        let query = RetrievalQuery::new();
        let result = processor.retrieve(&query).await.unwrap();
        assert_eq!(result.candidates_considered, 1);
    }
}
425}