use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::ai::providers::{
ChatMessage, LlmProvider, LlmRequest, MessageRole, ProviderResult,
};
/// Retrieval and generation settings carried by a [`RagRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
/// Names of the vector stores to search; passed through to the retriever.
pub vector_stores: Vec<String>,
/// Number of documents requested from the retriever and the reranker.
/// Falls back to `default_top_k()` (5) when absent during deserialization.
#[serde(default = "default_top_k")]
pub top_k: usize,
/// Optional minimum score, passed through to the retriever.
pub min_score: Option<f32>,
/// When `true` and the pipeline has a reranker, retrieved sources are
/// reranked before the context is built. Deserializes to `false` when absent.
#[serde(default)]
pub rerank: bool,
/// NOTE(review): not read anywhere in this file; presumably a weighting
/// factor for hybrid search — confirm against the retriever implementations.
pub alpha: Option<f32>,
/// Whether the response carries the retrieved sources (an empty list is
/// returned otherwise). Falls back to `true` when absent during deserialization.
#[serde(default = "default_true")]
pub include_sources: bool,
/// Token budget for the assembled context. The pipeline multiplies it by 4
/// to get a size limit and uses 16000 when this is `None`.
pub max_context_tokens: Option<usize>,
/// Replaces the built-in system prompt when set.
pub system_prompt: Option<String>,
/// NOTE(review): not read anywhere in this file; presumably consumed by a
/// retriever — confirm.
pub chunk_strategy: Option<ChunkStrategy>,
}
/// Serde fallback for `RagConfig::top_k` when the field is missing.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 5;
    DEFAULT_TOP_K
}
/// Serde fallback for boolean fields that should be enabled when missing.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Chunk post-processing options selectable via `RagConfig::chunk_strategy`.
///
/// NOTE(review): no code in this file reads these variants, so the
/// descriptions below are inferred from the names only — confirm against the
/// component that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkStrategy {
/// Presumably: use retrieved chunks unchanged.
Exact,
/// Presumably: widen each chunk by `before`/`after` neighbouring units.
Expand { before: usize, after: usize },
/// Presumably: merge chunks up to `max_length`.
Merge { max_length: usize },
}
/// A single question submitted to [`RagPipeline::query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagRequest {
/// The user's question; used for retrieval, reranking and as the final user message.
pub query: String,
/// Retrieval/generation settings applied to this request.
pub config: RagConfig,
/// Earlier conversation turns, inserted between the system prompt and the query.
pub history: Option<Vec<ChatMessage>>,
/// NOTE(review): not read by the pipeline in this file; presumably extra
/// caller-supplied context — confirm.
pub context: Option<String>,
/// NOTE(review): not read by the pipeline in this file.
pub format: Option<OutputFormat>,
/// Model name forwarded to the LLM request.
pub model: Option<String>,
/// Sampling temperature forwarded to the LLM request.
pub temperature: Option<f32>,
/// Completion token limit forwarded to the LLM request.
pub max_tokens: Option<usize>,
}
/// Desired presentation of the answer.
///
/// NOTE(review): none of these fields are read in this file; the field
/// meanings are inferred from their names — confirm with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputFormat {
/// Presumably a free-form style hint.
pub style: Option<String>,
/// Presumably the language the answer should be written in.
pub language: Option<String>,
/// Falls back to `true` when absent during deserialization.
#[serde(default = "default_true")]
pub cite_sources: bool,
}
/// Result of a RAG query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagResponse {
/// Text content of the LLM's reply.
pub response: String,
/// Retrieved sources; empty when `RagConfig::include_sources` is `false`.
pub sources: Vec<RagSource>,
/// Optional analysis of the query. The pipeline in this file always sets it to `None`.
pub query_analysis: Option<QueryAnalysis>,
/// Token accounting, present only when the provider reported usage.
pub usage: Option<TokenUsage>,
/// Heuristic score derived from the sources' scores and their count.
pub confidence: f32,
/// Suggested follow-up questions, or `None` when there are none.
pub follow_ups: Option<Vec<String>>,
}
/// A retrieved document (or chunk) together with its relevance score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSource {
/// Identifier of the source.
pub id: String,
/// Text inserted into the LLM context as `[Source N]: <content>`.
pub content: String,
/// Relevance score; averaged by the pipeline when computing confidence.
pub score: f32,
/// Arbitrary JSON metadata attached to the source.
pub metadata: HashMap<String, serde_json::Value>,
/// NOTE(review): not read in this file; presumably a matched excerpt — confirm.
pub highlight: Option<String>,
/// NOTE(review): not read in this file; presumably where the source lives — confirm.
pub location: Option<String>,
}
/// Structured analysis of a user query.
///
/// NOTE(review): never constructed in this file (the pipeline returns `None`
/// for `RagResponse::query_analysis`); field meanings are taken from their names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
/// Detected intent label.
pub intent: String,
/// Entities extracted from the query.
pub entities: Vec<ExtractedEntity>,
/// Rewritten form of the query, if one was produced.
pub reformulated_query: Option<String>,
/// Keywords extracted from the query.
pub keywords: Vec<String>,
}
/// One entity listed in [`QueryAnalysis::entities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
/// The entity text.
pub entity: String,
/// Type label of the entity (free-form string).
pub entity_type: String,
/// Confidence of the extraction. NOTE(review): range is not established in this file.
pub confidence: f32,
}
/// Token accounting reported alongside a [`RagResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Prompt tokens, copied from the provider's usage report.
pub prompt_tokens: usize,
/// Completion tokens, copied from the provider's usage report.
pub completion_tokens: usize,
/// Total tokens, copied from the provider's usage report.
pub total_tokens: usize,
/// Estimated tokens taken by the retrieved context: the pipeline fills this
/// with the context string's length divided by 4.
pub context_tokens: usize,
}
/// Retrieval-augmented generation pipeline: retrieve, optionally rerank,
/// build a context string, then ask the LLM.
pub struct RagPipeline {
/// Provider used for the final chat completion.
llm: Arc<dyn LlmProvider>,
/// Configuration set by `new`/`with_config`.
/// NOTE(review): `query` uses the request's own config and never reads this field.
default_config: RagConfig,
/// Document retriever; when `None`, retrieval yields no sources.
retriever: Option<Arc<dyn Retriever>>,
/// Optional reranker, applied only when the request config enables `rerank`.
reranker: Option<Arc<dyn Reranker>>,
}
/// Source of candidate documents for a query.
#[async_trait::async_trait]
pub trait Retriever: Send + Sync {
/// Looks up documents for `query` in the named `stores`.
///
/// The pipeline passes `RagConfig::top_k` as `top_k` and
/// `RagConfig::min_score` as `min_score`; how an implementation applies
/// them is up to the implementation.
///
/// # Errors
/// Returns the provider error produced by the implementation.
async fn retrieve(
&self,
query: &str,
stores: &[String],
top_k: usize,
min_score: Option<f32>,
) -> ProviderResult<Vec<RagSource>>;
}
/// Second-stage ranking of already retrieved documents.
#[async_trait::async_trait]
pub trait Reranker: Send + Sync {
/// Takes ownership of `documents` and returns the reranked list for `query`.
///
/// The pipeline passes `RagConfig::top_k` as `top_k`; the returned list
/// replaces the retrieved sources.
///
/// # Errors
/// Returns the provider error produced by the implementation.
async fn rerank(
&self,
query: &str,
documents: Vec<RagSource>,
top_k: usize,
) -> ProviderResult<Vec<RagSource>>;
}
impl RagPipeline {
pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
Self {
llm,
default_config: RagConfig {
vector_stores: vec!["default".to_string()],
top_k: 5,
min_score: Some(0.7),
rerank: false,
alpha: Some(0.5),
include_sources: true,
max_context_tokens: Some(4000),
system_prompt: None,
chunk_strategy: None,
},
retriever: None,
reranker: None,
}
}
pub fn with_retriever(mut self, retriever: Arc<dyn Retriever>) -> Self {
self.retriever = Some(retriever);
self
}
pub fn with_reranker(mut self, reranker: Arc<dyn Reranker>) -> Self {
self.reranker = Some(reranker);
self
}
pub fn with_config(mut self, config: RagConfig) -> Self {
self.default_config = config;
self
}
pub async fn query(&self, request: RagRequest) -> ProviderResult<RagResponse> {
let config = request.config.clone();
let mut sources = self.retrieve(&request.query, &config).await?;
if config.rerank {
if let Some(ref reranker) = self.reranker {
sources = reranker.rerank(&request.query, sources, config.top_k).await?;
}
}
let context = self.build_context(&sources, config.max_context_tokens);
let response = self.generate(&request, &context, &sources).await?;
Ok(response)
}
async fn retrieve(&self, query: &str, config: &RagConfig) -> ProviderResult<Vec<RagSource>> {
if let Some(ref retriever) = self.retriever {
retriever.retrieve(
query,
&config.vector_stores,
config.top_k,
config.min_score,
).await
} else {
Ok(Vec::new())
}
}
fn build_context(&self, sources: &[RagSource], max_tokens: Option<usize>) -> String {
let max_chars = max_tokens.map(|t| t * 4).unwrap_or(16000);
let mut context = String::new();
let mut total_chars = 0;
for (i, source) in sources.iter().enumerate() {
let entry = format!(
"[Source {}]: {}\n\n",
i + 1,
source.content
);
if total_chars + entry.len() > max_chars {
break;
}
context.push_str(&entry);
total_chars += entry.len();
}
context
}
async fn generate(
&self,
request: &RagRequest,
context: &str,
sources: &[RagSource],
) -> ProviderResult<RagResponse> {
let system_prompt = request.config.system_prompt.clone().unwrap_or_else(|| {
format!(
r#"You are a helpful assistant that answers questions based on the provided context.
Guidelines:
1. Answer based primarily on the provided context
2. If the context doesn't contain relevant information, say so clearly
3. Cite sources using [Source N] format when making claims
4. Be concise but thorough
5. If asked for opinions, clarify that you're an AI and provide balanced perspectives
Context:
{}
Answer the user's question based on the above context."#,
context
)
});
let mut messages = vec![ChatMessage {
role: MessageRole::System,
content: system_prompt,
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
}];
if let Some(ref history) = request.history {
messages.extend(history.clone());
}
messages.push(ChatMessage {
role: MessageRole::User,
content: request.query.clone(),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
});
let llm_request = LlmRequest {
messages,
model: request.model.clone(),
max_tokens: request.max_tokens,
temperature: request.temperature,
..Default::default()
};
let response = self.llm.chat(llm_request).await?;
let response_content = response.message.content;
Ok(RagResponse {
follow_ups: self.generate_follow_ups(&request.query, &response_content),
response: response_content,
sources: if request.config.include_sources {
sources.to_vec()
} else {
Vec::new()
},
query_analysis: None, usage: response.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
context_tokens: context.len() / 4, }),
confidence: self.calculate_confidence(sources),
})
}
fn calculate_confidence(&self, sources: &[RagSource]) -> f32 {
if sources.is_empty() {
return 0.3;
}
let avg_score: f32 = sources.iter().map(|s| s.score).sum::<f32>() / sources.len() as f32;
let source_count_factor = (sources.len() as f32 / 5.0).min(1.0);
(avg_score * 0.7 + source_count_factor * 0.3).min(1.0)
}
fn generate_follow_ups(&self, query: &str, response: &str) -> Option<Vec<String>> {
let mut follow_ups = Vec::new();
if response.contains("however") || response.contains("but") {
follow_ups.push("Can you elaborate on the exceptions mentioned?".to_string());
}
if response.contains("example") || response.contains("such as") {
follow_ups.push("Can you provide more examples?".to_string());
}
if query.contains("how") {
follow_ups.push("What are the prerequisites for this?".to_string());
}
if query.contains("why") {
follow_ups.push("Are there any alternative explanations?".to_string());
}
if follow_ups.is_empty() {
None
} else {
Some(follow_ups)
}
}
}
/// Baseline configuration: the `"default"` vector store, five results with a
/// 0.7 minimum score, no reranking, sources included and a 4000-token
/// context budget.
impl Default for RagConfig {
    fn default() -> Self {
        let vector_stores = vec![String::from("default")];
        let min_score = Some(0.7);
        let alpha = Some(0.5);
        let max_context_tokens = Some(4000);

        Self {
            vector_stores,
            top_k: 5,
            min_score,
            rerank: false,
            alpha,
            include_sources: true,
            max_context_tokens,
            system_prompt: None,
            chunk_strategy: None,
        }
    }
}