use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::Result;
/// A single reference backing part of a generated answer, pointing at the
/// source the supporting text came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Citation {
/// Display identifier, e.g. "[1]" (see `RAG::query`).
pub id: String,
/// Name or path of the source document.
pub source: String,
/// The supporting/quoted text.
pub text: String,
/// Page number within the source, when known.
pub page: Option<u32>,
/// Relevance score from retrieval, when available.
pub score: Option<f32>,
/// Arbitrary extra key/value annotations.
pub metadata: HashMap<String, String>,
}
impl Citation {
    /// Creates a citation from an identifier, source, and supporting text.
    /// `page`, `score`, and `metadata` start empty and can be filled in
    /// with the chaining methods below.
    pub fn new(id: impl Into<String>, source: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            metadata: HashMap::new(),
            score: None,
            page: None,
            text: text.into(),
            source: source.into(),
            id: id.into(),
        }
    }
    /// Sets the page number (builder style).
    pub fn page(self, page: u32) -> Self {
        Self { page: Some(page), ..self }
    }
    /// Sets the relevance score (builder style).
    pub fn score(self, score: f32) -> Self {
        Self { score: Some(score), ..self }
    }
    /// Inserts one metadata key/value pair (builder style).
    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}
/// The assembled set of retrieved chunks handed to the model for one query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextPack {
/// Retrieved chunks, in assembly order.
pub chunks: Vec<ContextChunk>,
/// Estimated token count of the whole pack.
pub total_tokens: usize,
/// The query this context was gathered for.
pub query: String,
}
/// One retrieved piece of text, together with where it came from and how
/// well it matched the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextChunk {
/// The chunk's text content.
pub content: String,
/// Identifier of the originating source.
pub source: String,
/// Retrieval relevance score.
pub score: f32,
/// Position of the chunk (defaults to 0 in `ContextChunk::new`).
pub index: usize,
/// Arbitrary extra key/value annotations.
pub metadata: HashMap<String, String>,
}
impl ContextChunk {
    /// Builds a chunk from its content, originating source, and relevance
    /// score. `index` starts at 0 and `metadata` starts empty; both may be
    /// set directly on the returned value.
    pub fn new(content: impl Into<String>, source: impl Into<String>, score: f32) -> Self {
        Self {
            index: 0,
            metadata: HashMap::default(),
            score,
            source: source.into(),
            content: content.into(),
        }
    }
}
/// The outcome of a RAG query: the generated answer plus the citations and
/// context that produced it, with basic usage accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RAGResult {
/// The generated answer text.
pub answer: String,
/// Citations supporting the answer.
pub citations: Vec<Citation>,
/// The context pack the answer was generated from.
pub context: ContextPack,
/// Tokens consumed producing the answer.
pub tokens_used: usize,
/// Wall-clock processing time in milliseconds.
pub processing_time_ms: u64,
}
impl RAGResult {
    /// Wraps a generated answer together with the context it was based on.
    /// Citations, token usage, and timing start at their zero values.
    pub fn new(answer: impl Into<String>, context: ContextPack) -> Self {
        Self {
            context,
            answer: answer.into(),
            citations: Vec::new(),
            processing_time_ms: 0,
            tokens_used: 0,
        }
    }
    /// Appends one citation to the result.
    pub fn add_citation(&mut self, citation: Citation) {
        self.citations.push(citation);
    }
    /// Number of citations attached so far.
    pub fn citation_count(&self) -> usize {
        self.citations.len()
    }
}
/// Retrieval strategy selector; variants serialize in snake_case
/// (e.g. `MultiQuery` -> "multi_query").
///
/// NOTE(review): nothing in this module branches on the variant — it only
/// tags configuration; strategy-specific behavior presumably lives in the
/// retrieval backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RetrievalStrategy {
/// Plain similarity search — the default.
#[default]
Similarity,
Hybrid,
MultiQuery,
Hierarchical,
Compression,
}
/// How citations should be rendered in an answer; variants serialize in
/// snake_case. This module only stores the choice — rendering happens
/// elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CitationsMode {
/// Citations embedded in the answer text — the default.
#[default]
Inline,
Footnote,
None,
}
/// Tuning knobs for retrieval-augmented generation. See `Default` for the
/// baseline values; every field can also be set via the builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RAGConfig {
/// Number of chunks to retrieve (default 5).
pub top_k: usize,
/// Minimum relevance score for a chunk to be kept (default 0.7).
pub score_threshold: f32,
/// Cap on tokens spent on retrieved context (default 4096).
pub max_context_tokens: usize,
/// Chunk-selection strategy.
pub strategy: RetrievalStrategy,
/// Citation rendering mode.
pub citations_mode: CitationsMode,
/// Whether to rerank retrieved chunks (default off).
pub rerank: bool,
/// Whether to compress context before use (default off).
pub compress: bool,
/// Overlap between consecutive chunks (default 50; unit not fixed here —
/// presumably tokens or characters, confirm against the chunker).
pub chunk_overlap: usize,
/// Size of each chunk (default 500; same unit caveat as `chunk_overlap`).
pub chunk_size: usize,
}
impl Default for RAGConfig {
fn default() -> Self {
Self {
top_k: 5,
score_threshold: 0.7,
max_context_tokens: 4096,
strategy: RetrievalStrategy::default(),
citations_mode: CitationsMode::default(),
rerank: false,
compress: false,
chunk_overlap: 50,
chunk_size: 500,
}
}
}
impl RAGConfig {
    /// Alias for [`RAGConfig::default`], kicking off builder-style chaining.
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets how many chunks to retrieve.
    pub fn top_k(self, k: usize) -> Self {
        Self { top_k: k, ..self }
    }
    /// Sets the minimum similarity score a chunk must reach.
    pub fn score_threshold(self, threshold: f32) -> Self {
        Self { score_threshold: threshold, ..self }
    }
    /// Caps the number of tokens spent on retrieved context.
    pub fn max_context_tokens(self, tokens: usize) -> Self {
        Self { max_context_tokens: tokens, ..self }
    }
    /// Chooses the retrieval strategy.
    pub fn strategy(self, strategy: RetrievalStrategy) -> Self {
        Self { strategy, ..self }
    }
    /// Chooses how citations are rendered.
    pub fn citations_mode(self, mode: CitationsMode) -> Self {
        Self { citations_mode: mode, ..self }
    }
    /// Toggles reranking of retrieved chunks.
    pub fn rerank(self, enable: bool) -> Self {
        Self { rerank: enable, ..self }
    }
    /// Toggles context compression.
    pub fn compress(self, enable: bool) -> Self {
        Self { compress: enable, ..self }
    }
}
/// Top-level retrieval settings: an on/off switch, the nested RAG tuning,
/// and the list of sources to search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalConfig {
/// Master switch; retrieval is disabled by default.
pub enabled: bool,
/// Nested RAG tuning parameters.
pub rag: RAGConfig,
/// Source paths/identifiers to retrieve from.
pub sources: Vec<String>,
/// Whether retrieval should run automatically (default true).
pub auto_retrieve: bool,
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
enabled: false,
rag: RAGConfig::default(),
sources: Vec::new(),
auto_retrieve: true,
}
}
}
impl RetrievalConfig {
    /// Starts from the disabled default configuration.
    pub fn new() -> Self {
        Self::default()
    }
    /// Switches retrieval on.
    pub fn enable(self) -> Self {
        Self { enabled: true, ..self }
    }
    /// Registers one additional source path/identifier.
    pub fn source(mut self, source: impl Into<String>) -> Self {
        self.sources.push(source.into());
        self
    }
    /// Replaces the nested RAG configuration wholesale.
    pub fn rag(self, config: RAGConfig) -> Self {
        Self { rag: config, ..self }
    }
}
/// A partition of a model's token window into system, context, response,
/// and reserve shares. `TokenBudget::new` derives the split from a total.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenBudget {
/// Total tokens available.
pub total: usize,
/// Share for the system prompt.
pub system: usize,
/// Share for retrieved context.
pub context: usize,
/// Share for the model's response.
pub response: usize,
/// Remainder held back from the other shares.
pub reserved: usize,
}
impl Default for TokenBudget {
fn default() -> Self {
Self {
total: 8192,
system: 500,
context: 4096,
response: 2048,
reserved: 500,
}
}
}
impl TokenBudget {
    /// Splits `total` into a budget: half for context, a quarter for the
    /// response, up to 500 tokens (but never more than a tenth of the
    /// total) for the system prompt, and whatever remains held in reserve.
    pub fn new(total: usize) -> Self {
        let context = total / 2;
        let response = total / 4;
        let system = (total / 10).min(500);
        Self {
            total,
            system,
            context,
            response,
            // The three shares above sum to at most 85% of `total`
            // (1/2 + 1/4 + min(500, 1/10)), so this cannot underflow.
            reserved: total - (context + response + system),
        }
    }
    /// Tokens available for retrieved context.
    pub fn available_context(&self) -> usize {
        self.context
    }
    /// Whether `tokens` of context would still fit within the context share.
    pub fn can_add_context(&self, tokens: usize) -> bool {
        tokens <= self.context
    }
}
/// Looks up the context-window size (in tokens) for a model name by
/// substring match, falling back to 8192 for unknown models.
///
/// Entry order matters: more specific patterns ("gpt-4-turbo") must be
/// checked before their prefixes ("gpt-4").
pub fn get_model_context_window(model: &str) -> usize {
    const WINDOWS: [(&str, usize); 8] = [
        ("gpt-4o", 128000),
        ("gpt-4-turbo", 128000),
        ("gpt-4", 8192),
        ("gpt-3.5", 16385),
        ("claude-3", 200000),
        ("claude-2", 100000),
        ("gemini-1.5", 1000000),
        ("gemini-pro", 32768),
    ];
    WINDOWS
        .iter()
        .find(|(pattern, _)| model.contains(pattern))
        .map_or(8192, |&(_, window)| window)
}
/// Rough token estimate using the common ~4-bytes-per-token heuristic,
/// rounding up so any non-empty string counts as at least one token.
pub fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    bytes / 4 + usize::from(bytes % 4 != 0)
}
/// Raw outcome of a retrieval pass, before chunks are packed into a
/// `ContextPack`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
/// Chunks found, in insertion order.
pub chunks: Vec<ContextChunk>,
/// The query that was searched.
pub query: String,
/// Strategy used for this pass.
pub strategy: RetrievalStrategy,
/// How many candidate items were examined in total.
pub total_searched: usize,
}
impl RetrievalResult {
    /// Creates an empty result for the given query and strategy.
    pub fn new(query: impl Into<String>, strategy: RetrievalStrategy) -> Self {
        Self {
            total_searched: 0,
            strategy,
            query: query.into(),
            chunks: Vec::new(),
        }
    }
    /// Records one retrieved chunk.
    pub fn add_chunk(&mut self, chunk: ContextChunk) {
        self.chunks.push(chunk);
    }
    /// Returns up to `n` chunks ordered by descending score. Uses a stable
    /// sort, so chunks with equal (or NaN) scores keep insertion order.
    pub fn top_chunks(&self, n: usize) -> Vec<&ContextChunk> {
        let mut by_score: Vec<&ContextChunk> = self.chunks.iter().collect();
        by_score.sort_by(|lhs, rhs| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        by_score.truncate(n);
        by_score
    }
}
/// A configured RAG pipeline. Construct via `RAG::new()` (which returns a
/// `RAGBuilder`) or `RAG::default()`.
#[derive(Debug, Clone)]
pub struct RAG {
/// Retrieval/generation tuning.
pub config: RAGConfig,
/// Model identifier (default "gpt-4o-mini").
pub model: String,
/// Registered retrieval sources.
pub sources: Vec<String>,
}
impl Default for RAG {
fn default() -> Self {
Self {
config: RAGConfig::default(),
model: "gpt-4o-mini".to_string(),
sources: Vec::new(),
}
}
}
impl RAG {
pub fn new() -> RAGBuilder {
RAGBuilder::default()
}
pub fn query(&self, question: &str) -> Result<RAGResult> {
let context = ContextPack {
chunks: vec![ContextChunk::new(
"Sample retrieved content for the query.",
"knowledge_base",
0.95,
)],
total_tokens: 50,
query: question.to_string(),
};
let mut result = RAGResult::new(
format!("Answer to: {} (based on retrieved context)", question),
context,
);
result.add_citation(Citation::new(
"[1]",
"knowledge_base",
"Sample retrieved content",
));
Ok(result)
}
pub fn add_source(&mut self, source: impl Into<String>) {
self.sources.push(source.into());
}
pub fn build_context(&self, chunks: &[ContextChunk]) -> String {
chunks
.iter()
.enumerate()
.map(|(i, chunk)| format!("[{}] {}", i + 1, chunk.content))
.collect::<Vec<_>>()
.join("\n\n")
}
pub fn truncate_context(&self, context: &str, max_tokens: usize) -> String {
let estimated = estimate_tokens(context);
if estimated <= max_tokens {
return context.to_string();
}
let char_limit = max_tokens * 4;
if context.len() <= char_limit {
return context.to_string();
}
format!("{}...", &context[..char_limit])
}
}
/// Builder for [`RAG`]; obtained from `RAG::new()` and finalized with
/// `build()`.
#[derive(Debug, Default)]
pub struct RAGBuilder {
config: RAGConfig,
model: Option<String>,
sources: Vec<String>,
}
impl RAGBuilder {
pub fn config(mut self, config: RAGConfig) -> Self {
self.config = config;
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn source(mut self, source: impl Into<String>) -> Self {
self.sources.push(source.into());
self
}
pub fn build(self) -> Result<RAG> {
Ok(RAG {
config: self.config,
model: self.model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
sources: self.sources,
})
}
}
/// Joins chunk contents into a single prompt-ready context string,
/// numbering each chunk "[1]", "[2]", ... with a blank line between them.
pub fn build_context(chunks: &[ContextChunk]) -> String {
    let mut numbered = Vec::with_capacity(chunks.len());
    for (position, chunk) in chunks.iter().enumerate() {
        numbered.push(format!("[{}] {}", position + 1, chunk.content));
    }
    numbered.join("\n\n")
}
/// Truncates `context` so it fits within roughly `max_tokens`, using the
/// same ~4-bytes-per-token heuristic as [`estimate_tokens`], appending
/// "..." when anything was cut. Strings already within budget are
/// returned unchanged.
///
/// Fixed: slicing at a raw byte index (`&context[..char_limit]`) panicked
/// whenever the limit fell inside a multi-byte UTF-8 character; the cut
/// point is now walked back to the nearest character boundary first.
/// (The old `estimate_tokens` pre-check was redundant — for the heuristic
/// `(len + 3) / 4 <= max_tokens` is exactly `len <= max_tokens * 4`,
/// which the length check below already covers.)
pub fn truncate_context(context: &str, max_tokens: usize) -> String {
    let char_limit = max_tokens.saturating_mul(4);
    if context.len() <= char_limit {
        return context.to_string();
    }
    let mut cut = char_limit;
    while !context.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &context[..cut])
}
/// Drops chunks whose `content` exactly matches an earlier chunk, keeping
/// the first occurrence. The `_threshold` parameter is accepted but
/// ignored — presumably reserved for future similarity-based dedup.
pub fn deduplicate_chunks(chunks: Vec<ContextChunk>, _threshold: f32) -> Vec<ContextChunk> {
    let mut unique: Vec<ContextChunk> = Vec::with_capacity(chunks.len());
    for candidate in chunks {
        if unique.iter().all(|kept| kept.content != candidate.content) {
            unique.push(candidate);
        }
    }
    unique
}
// Unit tests covering construction, builders, budgeting, and the free
// helper functions in this module.
#[cfg(test)]
mod tests {
use super::*;
// Builder methods should populate the optional citation fields.
#[test]
fn test_citation_creation() {
let citation = Citation::new("[1]", "document.pdf", "Sample text")
.page(5)
.score(0.95);
assert_eq!(citation.id, "[1]");
assert_eq!(citation.source, "document.pdf");
assert_eq!(citation.page, Some(5));
assert_eq!(citation.score, Some(0.95));
}
// `ContextChunk::new` stores content and score verbatim.
#[test]
fn test_context_chunk() {
let chunk = ContextChunk::new("Content here", "source.txt", 0.85);
assert_eq!(chunk.content, "Content here");
assert_eq!(chunk.score, 0.85);
}
// Pins the documented default configuration values.
#[test]
fn test_rag_config_defaults() {
let config = RAGConfig::default();
assert_eq!(config.top_k, 5);
assert_eq!(config.score_threshold, 0.7);
assert_eq!(config.strategy, RetrievalStrategy::Similarity);
}
// Chained builder calls should override the corresponding fields.
#[test]
fn test_rag_config_builder() {
let config = RAGConfig::new()
.top_k(10)
.score_threshold(0.8)
.strategy(RetrievalStrategy::Hybrid)
.rerank(true);
assert_eq!(config.top_k, 10);
assert_eq!(config.score_threshold, 0.8);
assert_eq!(config.strategy, RetrievalStrategy::Hybrid);
assert!(config.rerank);
}
// `enable` flips the switch and `source` accumulates entries.
#[test]
fn test_retrieval_config() {
let config = RetrievalConfig::new()
.enable()
.source("docs/")
.source("knowledge/");
assert!(config.enabled);
assert_eq!(config.sources.len(), 2);
}
// A derived budget keeps the requested total and admits context within it.
#[test]
fn test_token_budget() {
let budget = TokenBudget::new(16000);
assert_eq!(budget.total, 16000);
assert!(budget.can_add_context(4000));
}
// Substring lookup of known models, plus the 8192 fallback.
#[test]
fn test_model_context_window() {
assert_eq!(get_model_context_window("gpt-4o"), 128000);
assert_eq!(get_model_context_window("claude-3-opus"), 200000);
assert_eq!(get_model_context_window("unknown-model"), 8192);
}
// The ~4-bytes-per-token heuristic yields a nonzero count below the
// byte length for any multi-byte input.
#[test]
fn test_estimate_tokens() {
let text = "Hello world";
let tokens = estimate_tokens(text);
assert!(tokens > 0);
assert!(tokens < text.len());
}
// Full builder round-trip: model, source, and nested config are kept.
#[test]
fn test_rag_builder() {
let rag = RAG::new()
.model("gpt-4o")
.source("docs/")
.config(RAGConfig::new().top_k(10))
.build()
.unwrap();
assert_eq!(rag.model, "gpt-4o");
assert_eq!(rag.sources.len(), 1);
assert_eq!(rag.config.top_k, 10);
}
// The (stubbed) query produces a non-empty answer with citations.
#[test]
fn test_rag_query() {
let rag = RAG::new().build().unwrap();
let result = rag.query("What is the answer?").unwrap();
assert!(!result.answer.is_empty());
assert!(!result.citations.is_empty());
}
// Context formatting numbers chunks as "[1]", "[2]", ...
#[test]
fn test_build_context() {
let chunks = vec![
ContextChunk::new("First chunk", "doc1", 0.9),
ContextChunk::new("Second chunk", "doc2", 0.8),
];
let context = build_context(&chunks);
assert!(context.contains("[1]"));
assert!(context.contains("[2]"));
assert!(context.contains("First chunk"));
}
// Over-budget text is shortened and marked with an ellipsis.
#[test]
fn test_truncate_context() {
let long_text = "a".repeat(10000);
let truncated = truncate_context(&long_text, 100);
assert!(truncated.len() < long_text.len());
assert!(truncated.ends_with("..."));
}
// Exact-content duplicates are removed; the first occurrence survives.
#[test]
fn test_deduplicate_chunks() {
let chunks = vec![
ContextChunk::new("Same content", "doc1", 0.9),
ContextChunk::new("Same content", "doc2", 0.8),
ContextChunk::new("Different content", "doc3", 0.7),
];
let deduped = deduplicate_chunks(chunks, 0.9);
assert_eq!(deduped.len(), 2);
}
// `top_chunks` sorts by descending score and caps the count at `n`.
#[test]
fn test_retrieval_result() {
let mut result = RetrievalResult::new("test query", RetrievalStrategy::Similarity);
result.add_chunk(ContextChunk::new("High score", "doc1", 0.95));
result.add_chunk(ContextChunk::new("Low score", "doc2", 0.5));
let top = result.top_chunks(1);
assert_eq!(top.len(), 1);
assert_eq!(top[0].score, 0.95);
}
// Citations added after construction are counted.
#[test]
fn test_rag_result() {
let context = ContextPack {
chunks: vec![],
total_tokens: 0,
query: "test".to_string(),
};
let mut result = RAGResult::new("Answer", context);
result.add_citation(Citation::new("[1]", "source", "text"));
assert_eq!(result.citation_count(), 1);
}
}