use crate::{
error::{MemError, MemResult},
storage::Storage,
Document, RetrievalConfig, SearchResult,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
pub retrieval: RetrievalConfig,
pub max_context_tokens: usize,
pub include_citations: bool,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
retrieval: RetrievalConfig::default(),
max_context_tokens: 4096,
include_citations: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagContext {
pub context: String,
pub sources: Vec<SearchResult>,
pub token_count: usize,
}
pub struct RagPipeline {
config: RagConfig,
}
impl RagPipeline {
pub fn new(config: RagConfig) -> Self {
Self { config }
}
pub fn assemble_context(&self, results: Vec<SearchResult>) -> RagContext {
let mut context = String::new();
let mut token_count = 0;
for (i, result) in results.iter().enumerate() {
let chunk_text = &result.chunk.text;
let chunk_tokens = chunk_text.split_whitespace().count();
if token_count + chunk_tokens > self.config.max_context_tokens {
break;
}
if self.config.include_citations {
context.push_str(&format!("[{}] ", i + 1));
}
context.push_str(chunk_text);
context.push_str("\n\n");
token_count += chunk_tokens;
}
RagContext {
context,
sources: results,
token_count,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rag_config_default() {
let config = RagConfig::default();
assert_eq!(config.max_context_tokens, 4096);
assert!(config.include_citations);
}
#[test]
fn test_rag_pipeline_creation() {
let pipeline = RagPipeline::new(RagConfig::default());
assert_eq!(pipeline.config.max_context_tokens, 4096);
}
}