use std::collections::HashMap;
use std::sync::Arc;
use crate::{
agent::UnifiedAgent,
chain::ChainError,
schemas::{messages::Message, Retriever},
tools::Tool,
};
use super::retriever_tool::RetrieverTool;
use crate::rag::RAGError;
/// Configuration for one retriever that `AgenticRAGBuilder` turns into a
/// `RetrieverTool`.
///
/// Bundles the retriever with the tool name and description it is registered
/// under, plus the maximum number of documents the tool is configured with.
pub struct RetrieverInfo {
    /// Retriever that backs the generated tool.
    pub retriever: Arc<dyn Retriever>,
    /// Name passed to the generated tool.
    pub name: String,
    /// Description passed to the generated tool.
    pub description: String,
    /// Value handed to `RetrieverTool::with_max_docs`; 5 unless overridden.
    pub max_docs: usize,
}

impl RetrieverInfo {
    /// `max_docs` value used by [`RetrieverInfo::new`].
    const DEFAULT_MAX_DOCS: usize = 5;

    /// Creates a configuration whose `max_docs` is the default of 5.
    pub fn new(retriever: Arc<dyn Retriever>, name: String, description: String) -> Self {
        let max_docs = Self::DEFAULT_MAX_DOCS;
        Self {
            retriever,
            name,
            description,
            max_docs,
        }
    }

    /// Consumes `self` and returns it with `max_docs` replaced.
    pub fn with_max_docs(self, max_docs: usize) -> Self {
        Self { max_docs, ..self }
    }
}
/// An agent paired with the retrieval tools built for it, addressable by name.
pub struct AgenticRAG {
    agent: Arc<UnifiedAgent>,
    /// Retrieval tools keyed by the value returned from `Tool::name()`.
    retriever_tools: HashMap<String, Arc<dyn Tool>>,
}

impl AgenticRAG {
    /// Wraps `agent` and indexes `retriever_tools` by each tool's `name()`.
    ///
    /// If several tools report the same name, the one appearing last in
    /// `retriever_tools` is kept.
    pub fn new(agent: Arc<UnifiedAgent>, retriever_tools: Vec<Arc<dyn Tool>>) -> Self {
        let retriever_tools: HashMap<String, Arc<dyn Tool>> = retriever_tools
            .into_iter()
            .map(|tool| (tool.name(), tool))
            .collect();
        Self {
            agent,
            retriever_tools,
        }
    }

    /// Runs the wrapped agent on `messages` and returns its output string.
    ///
    /// # Errors
    /// Returns [`RAGError::ChainError`] wrapping the error reported by the agent.
    pub async fn invoke_messages(&self, messages: Vec<Message>) -> Result<String, RAGError> {
        self.agent
            .invoke_messages(messages)
            .await
            .map_err(RAGError::ChainError)
    }

    /// Returns the wrapped agent.
    pub fn agent(&self) -> &Arc<UnifiedAgent> {
        &self.agent
    }

    /// Looks up a retrieval tool by name, or `None` if no tool with that name
    /// was registered.
    pub fn get_retriever_tool(&self, name: &str) -> Option<&Arc<dyn Tool>> {
        self.retriever_tools.get(name)
    }
}
/// Builder for [`AgenticRAG`].
///
/// Supply either a ready-made agent via [`with_agent`](Self::with_agent), or a
/// model string via [`with_model`](Self::with_model) from which `build` creates
/// the agent. When an agent is supplied, `model`, `system_prompt` and the
/// additional tools are not used, and the retrieval tools are not attached to
/// that agent. They are still reachable through
/// [`AgenticRAG::get_retriever_tool`].
pub struct AgenticRAGBuilder {
    agent: Option<Arc<UnifiedAgent>>,
    model: Option<String>,
    system_prompt: Option<String>,
    retrievers: Vec<RetrieverInfo>,
    additional_tools: Vec<Arc<dyn Tool>>,
}

impl AgenticRAGBuilder {
    /// System prompt used when `build` creates the agent and none was set.
    const DEFAULT_SYSTEM_PROMPT: &str = "You are a helpful assistant. Use the retrieval tools when you need to find information from external knowledge sources.";

    /// Creates an empty builder with nothing configured.
    pub fn new() -> Self {
        Self {
            agent: None,
            model: None,
            system_prompt: None,
            retrievers: Vec::new(),
            additional_tools: Vec::new(),
        }
    }

    /// Uses `agent` as-is instead of creating one in `build`.
    pub fn with_agent(mut self, agent: Arc<UnifiedAgent>) -> Self {
        self.agent = Some(agent);
        self
    }

    /// Sets the model passed to `create_agent` when no agent is supplied.
    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the system prompt used when `build` creates the agent.
    pub fn with_system_prompt<S: Into<String>>(mut self, system_prompt: S) -> Self {
        self.system_prompt = Some(system_prompt.into());
        self
    }

    /// Adds one retriever configuration.
    pub fn with_retriever(mut self, retriever: RetrieverInfo) -> Self {
        self.retrievers.push(retriever);
        self
    }

    /// Adds several retriever configurations, in order.
    pub fn with_retrievers(mut self, retrievers: Vec<RetrieverInfo>) -> Self {
        self.retrievers.extend(retrievers);
        self
    }

    /// Adds extra (non-retrieval) tools given to an agent created by `build`.
    pub fn with_tools(mut self, tools: &[Arc<dyn Tool>]) -> Self {
        self.additional_tools.extend_from_slice(tools);
        self
    }

    /// Builds the [`AgenticRAG`].
    ///
    /// Each configured [`RetrieverInfo`] becomes exactly one `RetrieverTool`
    /// instance. If no agent was supplied, one is created from the model, the
    /// system prompt (or a default), the additional tools, and those retrieval
    /// tools. The same tool instances are then indexed by name in the result.
    ///
    /// # Errors
    /// - [`RAGError::InvalidConfiguration`] if neither an agent nor a model was set.
    /// - [`RAGError::ChainError`] if creating the agent fails.
    pub fn build(self) -> Result<AgenticRAG, RAGError> {
        let Self {
            agent,
            model,
            system_prompt,
            retrievers,
            additional_tools,
        } = self;

        let retriever_tools: Vec<Arc<dyn Tool>> =
            retrievers.into_iter().map(Self::retriever_tool).collect();

        let agent = match agent {
            Some(agent) => agent,
            None => {
                let model = model.ok_or_else(|| {
                    RAGError::InvalidConfiguration("Either agent or model must be set".to_string())
                })?;
                let system_prompt =
                    system_prompt.unwrap_or_else(|| Self::DEFAULT_SYSTEM_PROMPT.to_string());
                // Additional tools first, then retrieval tools. Cloning the Arcs
                // shares these instances with the name map built below instead
                // of constructing a second, independent set of tools.
                let mut all_tools = additional_tools;
                all_tools.extend(retriever_tools.iter().cloned());
                let created =
                    crate::agent::create_agent(&model, &all_tools, Some(&system_prompt), None)
                        .map_err(|e| RAGError::ChainError(ChainError::AgentError(e.to_string())))?;
                Arc::new(created)
            }
        };

        Ok(AgenticRAG::new(agent, retriever_tools))
    }

    /// Wraps one retriever configuration in a `RetrieverTool`.
    fn retriever_tool(info: RetrieverInfo) -> Arc<dyn Tool> {
        Arc::new(
            RetrieverTool::new(info.retriever, info.name, info.description)
                .with_max_docs(info.max_docs),
        )
    }
}
impl Default for AgenticRAGBuilder {
fn default() -> Self {
Self::new()
}
}