use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::context::{Context, ContextDomain, ContextQuery};
use crate::error::ContextResult;
use crate::storage::ContextStore;
use crate::temporal::{TemporalQuery, TemporalStats};
/// Tunable settings for [`RagProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Result limit the processor applies when truncating its sorted results.
    pub max_results: usize,
    /// Scored contexts whose score is below this value are dropped.
    pub min_relevance: f64,
    /// Allows the rayon-based scoring path (used only for candidate sets larger than
    /// `chunk_size`).
    pub parallel: bool,
    /// Thread count requested for rayon's global pool; `0` leaves the pool untouched.
    pub num_threads: usize,
    /// When `true`, the temporal score comes from the temporal query's relevance score;
    /// when `false`, every context gets a temporal score of `1.0`.
    pub temporal_decay: bool,
    /// When `true`, contexts for which `is_safe()` is false are filtered out.
    pub safe_only: bool,
    /// Candidate-count threshold: parallel scoring is used only when the filtered set is
    /// strictly larger than this.
    pub chunk_size: usize,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
max_results: 10,
min_relevance: 0.1,
parallel: true,
num_threads: 0, temporal_decay: true,
safe_only: true,
chunk_size: 1000,
}
}
}
/// A context paired with the relevance score it was assigned for a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredContext {
    /// The context that was scored (an owned copy).
    pub context: Context,
    /// Combined relevance score.
    pub score: f64,
    /// The individual components that were combined into `score`.
    pub score_breakdown: ScoreBreakdown,
}
/// Per-component scores behind a [`ScoredContext::score`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Temporal component.
    pub temporal: f64,
    /// Importance component.
    pub importance: f64,
    /// Domain-match component.
    pub domain_match: f64,
    /// Tag-match component.
    pub tag_match: f64,
    /// Optional similarity component; `None` when no similarity was computed.
    pub similarity: Option<f64>,
}
/// Outcome of a single retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// The retained contexts together with their scores.
    pub contexts: Vec<ScoredContext>,
    /// Human-readable description of the query that produced this result.
    pub query_summary: String,
    /// Elapsed time of the retrieval, in milliseconds.
    pub processing_time_ms: u64,
    /// Number of candidate contexts that were considered.
    pub candidates_considered: usize,
    /// Temporal statistics associated with the result.
    pub temporal_stats: TemporalStats,
}
/// Retrieves and scores contexts from a shared [`ContextStore`].
pub struct RagProcessor {
    /// Settings that drive filtering, scoring and truncation.
    config: RagConfig,
    /// Shared handle to the store that candidates are queried from.
    store: Arc<ContextStore>,
}
impl RagProcessor {
pub fn new(store: Arc<ContextStore>, config: RagConfig) -> Self {
if config.num_threads > 0 {
rayon::ThreadPoolBuilder::new()
.num_threads(config.num_threads)
.build_global()
.ok();
}
Self { config, store }
}
pub fn with_defaults(store: Arc<ContextStore>) -> Self {
Self::new(store, RagConfig::default())
}
pub async fn retrieve(&self, query: &RetrievalQuery) -> ContextResult<RetrievalResult> {
let start = std::time::Instant::now();
let mut ctx_query = ContextQuery::new();
if let Some(domain) = &query.domain {
ctx_query = ctx_query.with_domain(domain.clone());
}
for tag in &query.tags {
ctx_query = ctx_query.with_tag(tag.clone());
}
if let Some(min_importance) = query.min_importance {
ctx_query = ctx_query.with_min_importance(min_importance);
}
let candidates: Vec<Context> = self.store.query(&ctx_query).await?;
let candidates_count = candidates.len();
let temporal_query = query.temporal.clone().unwrap_or_default();
let filtered: Vec<Context> = candidates
.into_iter()
.filter(|c| temporal_query.matches(c))
.filter(|c| !self.config.safe_only || c.is_safe())
.collect();
let scored = if self.config.parallel && filtered.len() > self.config.chunk_size {
self.score_parallel(&filtered, query, &temporal_query)
} else {
self.score_sequential(&filtered, query, &temporal_query)
};
let mut results: Vec<ScoredContext> = scored
.into_iter()
.filter(|s| s.score >= self.config.min_relevance)
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(self.config.max_results);
let temporal_stats = TemporalStats::from_contexts(
&results
.iter()
.map(|s| s.context.clone())
.collect::<Vec<_>>(),
);
Ok(RetrievalResult {
contexts: results,
query_summary: query.to_string(),
processing_time_ms: start.elapsed().as_millis() as u64,
candidates_considered: candidates_count,
temporal_stats,
})
}
fn score_parallel(
&self,
contexts: &[Context],
query: &RetrievalQuery,
temporal: &TemporalQuery,
) -> Vec<ScoredContext> {
contexts
.par_iter()
.map(|ctx| self.score_context(ctx, query, temporal))
.collect()
}
fn score_sequential(
&self,
contexts: &[Context],
query: &RetrievalQuery,
temporal: &TemporalQuery,
) -> Vec<ScoredContext> {
contexts
.iter()
.map(|ctx| self.score_context(ctx, query, temporal))
.collect()
}
fn score_context(
&self,
ctx: &Context,
query: &RetrievalQuery,
temporal: &TemporalQuery,
) -> ScoredContext {
let temporal_score = if self.config.temporal_decay {
temporal.relevance_score(ctx)
} else {
1.0
};
let importance_score = ctx.metadata.importance as f64;
let domain_match_score = if query.domain.as_ref() == Some(&ctx.domain) {
1.0
} else if query.domain.is_none() {
0.5 } else {
0.2 };
let tag_match_score = if !query.tags.is_empty() {
let matching_tags = query
.tags
.iter()
.filter(|t| ctx.metadata.tags.contains(*t))
.count();
matching_tags as f64 / query.tags.len() as f64
} else {
0.5 };
let breakdown = ScoreBreakdown {
temporal: temporal_score,
importance: importance_score,
domain_match: domain_match_score,
tag_match: tag_match_score,
similarity: None, };
let score = 0.25 * breakdown.temporal
+ 0.25 * breakdown.importance
+ 0.25 * breakdown.domain_match
+ 0.25 * breakdown.tag_match;
ScoredContext {
context: ctx.clone(),
score,
score_breakdown: breakdown,
}
}
pub async fn retrieve_by_text(&self, text: &str) -> ContextResult<RetrievalResult> {
let query = RetrievalQuery::from_text(text);
self.retrieve(&query).await
}
pub fn config(&self) -> &RagConfig {
&self.config
}
}
/// Describes what a caller wants to retrieve; every constraint is optional.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalQuery {
    /// Free-form query text, if any.
    pub text: Option<String>,
    /// Domain to look for, if any.
    pub domain: Option<ContextDomain>,
    /// Tags to look for; empty means no tag constraint.
    pub tags: Vec<String>,
    /// Minimum importance requested, if any.
    pub min_importance: Option<f32>,
    /// Temporal constraint, if any.
    pub temporal: Option<TemporalQuery>,
    /// Result limit requested by the caller, if any.
    pub max_results: Option<usize>,
}
impl RetrievalQuery {
    /// Returns a query with no constraints set.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a query whose only populated field is `text`.
    pub fn from_text(text: &str) -> Self {
        Self {
            text: Some(text.to_owned()),
            ..Self::default()
        }
    }

    /// Sets the domain, replacing any previous one.
    pub fn with_domain(self, domain: ContextDomain) -> Self {
        Self {
            domain: Some(domain),
            ..self
        }
    }

    /// Appends one tag to the query's tag list.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        let owned: String = tag.into();
        self.tags.push(owned);
        self
    }

    /// Sets the minimum importance, replacing any previous one.
    pub fn with_min_importance(self, importance: f32) -> Self {
        Self {
            min_importance: Some(importance),
            ..self
        }
    }

    /// Sets the temporal constraint, replacing any previous one.
    pub fn with_temporal(self, temporal: TemporalQuery) -> Self {
        Self {
            temporal: Some(temporal),
            ..self
        }
    }

    /// Returns a query whose only constraint is `TemporalQuery::recent(hours)`.
    pub fn recent(hours: i64) -> Self {
        Self {
            temporal: Some(TemporalQuery::recent(hours)),
            ..Self::default()
        }
    }
}
impl std::fmt::Display for RetrievalQuery {
    /// Writes a comma-separated summary of every constraint set on the query, or
    /// `all contexts` when none is set.
    ///
    /// The temporal constraint and the result limit are part of the summary, so a query that
    /// is restricted only by those is not reported as matching "all contexts".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut parts: Vec<String> = Vec::new();

        if let Some(text) = &self.text {
            parts.push(format!("text: '{}'", text));
        }
        if let Some(domain) = &self.domain {
            parts.push(format!("domain: {:?}", domain));
        }
        if !self.tags.is_empty() {
            parts.push(format!("tags: {:?}", self.tags));
        }
        if let Some(importance) = self.min_importance {
            parts.push(format!("min_importance: {}", importance));
        }
        if let Some(temporal) = &self.temporal {
            parts.push(format!("temporal: {:?}", temporal));
        }
        if let Some(limit) = self.max_results {
            parts.push(format!("max_results: {}", limit));
        }

        if parts.is_empty() {
            write!(f, "all contexts")
        } else {
            write!(f, "{}", parts.join(", "))
        }
    }
}
/// Runs several retrieval queries through one shared [`RagProcessor`].
pub struct BatchProcessor {
    /// Processor that every query in a batch is sent to.
    processor: Arc<RagProcessor>,
}
impl BatchProcessor {
pub fn new(processor: Arc<RagProcessor>) -> Self {
Self { processor }
}
pub async fn process_batch(
&self,
queries: Vec<RetrievalQuery>,
) -> Vec<ContextResult<RetrievalResult>> {
let mut results = Vec::with_capacity(queries.len());
for query in queries {
results.push(self.processor.retrieve(&query).await);
}
results
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageConfig;
    use tempfile::TempDir;

    /// Builds a persistence-enabled store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the store so the caller decides how long
    /// the directory lives.
    fn create_test_store() -> (Arc<ContextStore>, TempDir) {
        let dir = TempDir::new().unwrap();
        let store = ContextStore::new(StorageConfig {
            persist_path: Some(dir.path().to_path_buf()),
            enable_persistence: true,
            ..Default::default()
        })
        .unwrap();
        (Arc::new(store), dir)
    }

    /// The builder methods populate the text, domain and tags fields.
    #[test]
    fn test_retrieval_query() {
        let built = RetrievalQuery::from_text("test query")
            .with_domain(ContextDomain::Code)
            .with_tag("rust");

        assert_eq!(built.text.as_deref(), Some("test query"));
        assert_eq!(built.domain, Some(ContextDomain::Code));
        assert!(built.tags.iter().any(|tag| tag == "rust"));
    }

    /// A single stored context is reported as one considered candidate.
    #[tokio::test]
    async fn test_rag_processor() {
        let (store, _temp) = create_test_store();
        let processor = RagProcessor::with_defaults(Arc::clone(&store));

        store
            .store(Context::new("Test content", ContextDomain::Code))
            .await
            .unwrap();

        let outcome = processor.retrieve(&RetrievalQuery::new()).await.unwrap();
        assert_eq!(outcome.candidates_considered, 1);
    }
}