use async_trait::async_trait;
use llm_toolkit::agent::retrieval::RetrievalAwareAgent;
use llm_toolkit::agent::{Agent, AgentError, Payload};
use llm_toolkit::retrieval::Document;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Toy in-memory "vector store" used as the retriever in this example.
/// `Clone` is cheap: clones share the same underlying document list via `Arc`,
/// so a cloned handle passed to `RetrievalAwareAgent` sees the seeded docs.
#[derive(Clone)]
struct VectorStoreAgent {
// Shared document list; `tokio::sync::Mutex` so acquisition can be awaited
// inside async `seed`/`execute`.
documents: Arc<Mutex<Vec<Document>>>,
}
impl VectorStoreAgent {
    /// Creates an agent backed by an empty, shared document store.
    fn new() -> Self {
        let documents = Arc::new(Mutex::new(Vec::new()));
        Self { documents }
    }

    /// Appends `docs` to the shared store.
    async fn seed(&self, docs: Vec<Document>) {
        self.documents.lock().await.extend(docs);
    }
}
#[async_trait]
impl Agent for VectorStoreAgent {
    type Output = Vec<Document>;
    type Expertise = &'static str;

    fn expertise(&self) -> &&'static str {
        const EXPERTISE: &str = "Semantic search over programming language documentation";
        &EXPERTISE
    }

    /// Case-insensitive substring search over the seeded documents.
    ///
    /// Returns at most the top 3 matching documents, ordered by descending
    /// score, and logs what was retrieved.
    async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
        let query = payload.to_text();
        // Hoisted out of the filter closure: the original re-allocated
        // `query.to_lowercase()` once per document.
        let needle = query.to_lowercase();

        let docs = self.documents.lock().await;
        let mut results: Vec<_> = docs
            .iter()
            .filter(|doc| doc.content.to_lowercase().contains(&needle))
            .cloned()
            .collect();

        // Highest score first; scores are optional, so treat incomparable
        // pairs (e.g. a missing score) as equal rather than panicking.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(3);

        println!(
            "\n[VectorStoreAgent] Retrieved {} documents for: \"{}\"",
            results.len(),
            query
        );
        for (i, doc) in results.iter().enumerate() {
            println!(
                " {}. {} (score: {:.2})",
                i + 1,
                // `as_deref` avoids the temporary `String` the original
                // allocated via `&"unknown".to_string()` on every iteration.
                doc.source.as_deref().unwrap_or("unknown"),
                doc.score.unwrap_or(0.0)
            );
        }

        Ok(results)
    }
}
/// Stand-in for an LLM-backed agent: fabricates a deterministic answer from
/// whatever documents are attached to the incoming payload (no model calls).
#[derive(Clone)]
struct MockLLMAgent;
#[async_trait]
impl Agent for MockLLMAgent {
    type Output = String;
    type Expertise = &'static str;

    fn expertise(&self) -> &&'static str {
        const EXPERTISE: &str = "Answering questions about programming languages";
        &EXPERTISE
    }

    /// Produces a canned answer summarizing the first sentence of each
    /// retrieved document, or a fallback message when no context was supplied.
    async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
        let query = payload.to_text();
        let documents = payload.documents();

        println!("\n[MockLLMAgent] Processing query: \"{}\"", query);
        println!(
            "[MockLLMAgent] Received {} documents in context",
            documents.len()
        );

        let response = if documents.is_empty() {
            "I don't have enough context to answer this question accurately.".to_string()
        } else {
            // One bullet per document, built from its first sentence.
            // BUG FIX: the original joined with "\n- ", which left the FIRST
            // bullet without its "- " prefix; prefix every line instead.
            let context_summary: Vec<_> = documents
                .iter()
                .map(|d| format!("- {}", d.content.split('.').next().unwrap_or(&d.content)))
                .collect();
            format!(
                "Based on the retrieved documentation:\n\n{}\n\nThese documents provide relevant information to answer your question.",
                context_summary.join("\n")
            )
        };

        println!("\n[MockLLMAgent] Generated response:");
        println!("{}", response);
        Ok(response)
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("\n{}", "=".repeat(70));
println!("RetrievalAwareAgent Example: Agent-based RAG");
println!("{}\n", "=".repeat(70));
println!("Setting up vector store with programming language docs...\n");
let vector_store = VectorStoreAgent::new();
vector_store
.seed(vec![
Document::new(
"Rust is a systems programming language focused on safety and performance.",
)
.with_source("rust_intro.md")
.with_score(0.95),
Document::new(
"Rust's ownership system ensures memory safety without garbage collection.",
)
.with_source("rust_memory.md")
.with_score(0.92),
Document::new(
"Rust supports async programming through futures and the async/await syntax.",
)
.with_source("rust_async.md")
.with_score(0.88),
Document::new("Python is a high-level interpreted language known for its simplicity.")
.with_source("python_intro.md")
.with_score(0.90),
Document::new("Python has dynamic typing which allows for rapid prototyping.")
.with_source("python_typing.md")
.with_score(0.85),
Document::new("Go is a compiled language designed for building scalable systems.")
.with_source("go_intro.md")
.with_score(0.87),
])
.await;
println!("{}", "=".repeat(70));
println!("Scenario 1: Direct Retrieval");
println!("{}\n", "=".repeat(70));
let query1 = Payload::text("async programming");
let docs = vector_store.execute(query1).await?;
println!("\nDirect retrieval returned {} documents\n", docs.len());
println!("{}", "=".repeat(70));
println!("Scenario 2: RAG with RetrievalAwareAgent");
println!("{}\n", "=".repeat(70));
let llm_agent = MockLLMAgent;
let rag_agent = RetrievalAwareAgent::new(vector_store.clone(), llm_agent);
let query2 = Payload::text("How does Rust handle memory safety?");
println!("Query: \"{}\"", query2.to_text());
let response = rag_agent.execute(query2).await?;
println!("\n{}", "─".repeat(70));
println!("Final Answer:");
println!("{}", "─".repeat(70));
println!("{}\n", response);
println!("{}", "=".repeat(70));
println!("Scenario 3: Query with No Matching Documents");
println!("{}\n", "=".repeat(70));
let query3 = Payload::text("What is Java?");
println!("Query: \"{}\"", query3.to_text());
let response = rag_agent.execute(query3).await?;
println!("\n{}", "─".repeat(70));
println!("Final Answer:");
println!("{}", "─".repeat(70));
println!("{}\n", response);
println!("{}", "=".repeat(70));
println!("Summary: Key Takeaways");
println!("{}\n", "=".repeat(70));
println!("1. Retriever = Regular Agent with Output = Vec<Document>");
println!(" - No special Retriever trait needed");
println!(" - Works with all existing Agent infrastructure\n");
println!("2. RetrievalAwareAgent composes retriever + inner agent");
println!(" - Automatically retrieves documents");
println!(" - Augments payload with documents");
println!(" - Passes augmented payload to inner agent\n");
println!("3. Clean separation of concerns");
println!(" - Retrieval logic in VectorStoreAgent");
println!(" - LLM logic in MockLLMAgent");
println!(" - Composition in RetrievalAwareAgent\n");
println!("4. Easy to test and extend");
println!(" - Mock retrievers are just mock agents");
println!(" - Can compose with other agent wrappers (history, retry, etc.)");
println!(" - Type-safe and composable\n");
println!("{}\n", "=".repeat(70));
Ok(())
}