use super::{Agent, AgentError, Payload};
use crate::retrieval::Document;
use async_trait::async_trait;
/// An [`Agent`] decorator that runs a retriever before delegating to an inner agent.
///
/// When executed, the retriever (`R`) receives a clone of the incoming [`Payload`] and returns
/// a `Vec<Document>`. The original payload is then augmented with those documents via
/// `Payload::with_documents` and handed to the inner agent (`I`). The inner agent's output is
/// returned as this agent's output.
///
/// Trait bounds are intentionally declared on the `impl` blocks rather than on the type itself
/// (Rust API guidelines, C-STRUCT-BOUNDS). The fields are private, so values can still only be
/// created through the constructor, which enforces the required bounds.
pub struct RetrievalAwareAgent<R, I> {
    /// Agent whose `Vec<Document>` output is attached to the payload.
    retriever: R,
    /// Agent that receives the augmented payload and produces the final output.
    inner_agent: I,
}
impl<R, I> RetrievalAwareAgent<R, I>
where
R: Agent<Output = Vec<Document>>,
I: Agent,
{
pub fn new(retriever: R, inner_agent: I) -> Self {
Self {
retriever,
inner_agent,
}
}
pub fn retriever(&self) -> &R {
&self.retriever
}
pub fn inner_agent(&self) -> &I {
&self.inner_agent
}
}
#[async_trait]
impl<R, I> Agent for RetrievalAwareAgent<R, I>
where
    R: Agent<Output = Vec<Document>> + Send + Sync,
    I: Agent + Send + Sync,
    I::Output: Send,
{
    // The wrapper is transparent: its output and expertise are exactly the inner agent's.
    type Output = I::Output;
    type Expertise = I::Expertise;

    /// Delegates to the inner agent's expertise; the retriever is not reflected here.
    fn expertise(&self) -> &I::Expertise {
        I::expertise(&self.inner_agent)
    }

    /// Delegates to the inner agent's capabilities.
    fn capabilities(&self) -> Option<Vec<super::Capability>> {
        I::capabilities(&self.inner_agent)
    }

    /// Runs the retriever, attaches its documents to the payload, then runs the inner agent.
    ///
    /// # Errors
    ///
    /// A retriever error is returned immediately (via `?`) and the inner agent is not called.
    /// Otherwise, the inner agent's result is returned as-is.
    #[crate::tracing::instrument(
        name = "retrieval_aware_agent.execute",
        skip(self, payload),
        fields(
            retriever.description = self.retriever.description(),
            inner_agent.description = self.inner_agent.description(),
        )
    )]
    async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
        // Stage 1: the retriever gets its own copy of the request, because the original
        // payload is still needed below for augmentation.
        crate::tracing::debug!(
            target: "llm_toolkit::agent::retrieval",
            "Executing retriever agent"
        );
        let retrieved: Vec<Document> = self.retriever.execute(payload.clone()).await?;
        crate::tracing::debug!(
            target: "llm_toolkit::agent::retrieval",
            document_count = retrieved.len(),
            "Retrieved documents from retriever agent"
        );

        // Stage 2: attach the retrieved documents to the payload.
        let enriched = payload.with_documents(retrieved);
        crate::tracing::trace!(
            target: "llm_toolkit::agent::retrieval",
            "Augmented payload with retrieved documents"
        );

        // Stage 3: the inner agent produces the final answer from the enriched payload.
        crate::tracing::debug!(
            target: "llm_toolkit::agent::retrieval",
            "Executing inner agent with augmented payload"
        );
        I::execute(&self.inner_agent, enriched).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentError, Payload};
    use crate::retrieval::Document;
    use async_trait::async_trait;
    use serde::Serialize;
    use serde::de::DeserializeOwned;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Retriever double that returns a fixed document list and records every payload it sees.
    #[derive(Clone)]
    struct MockRetriever {
        documents: Vec<Document>,
        calls: Arc<Mutex<Vec<Payload>>>,
    }

    impl MockRetriever {
        fn new(documents: Vec<Document>) -> Self {
            Self {
                documents,
                calls: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of the payloads received so far (shared across clones via `Arc`).
        async fn get_calls(&self) -> Vec<Payload> {
            self.calls.lock().await.clone()
        }
    }

    #[async_trait]
    impl Agent for MockRetriever {
        type Output = Vec<Document>;
        type Expertise = &'static str;

        fn expertise(&self) -> &&'static str {
            const EXPERTISE: &str = "Mock retriever for testing";
            &EXPERTISE
        }

        async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
            self.calls.lock().await.push(payload);
            Ok(self.documents.clone())
        }
    }

    /// Inner-agent double that returns a fixed response and records every payload it sees.
    #[derive(Clone)]
    struct MockInnerAgent<T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static> {
        response: T,
        calls: Arc<Mutex<Vec<Payload>>>,
    }

    impl<T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static> MockInnerAgent<T> {
        fn new(response: T) -> Self {
            Self {
                response,
                calls: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of the payloads received so far (shared across clones via `Arc`).
        async fn get_calls(&self) -> Vec<Payload> {
            self.calls.lock().await.clone()
        }
    }

    #[async_trait]
    impl<T> Agent for MockInnerAgent<T>
    where
        T: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
    {
        type Output = T;
        type Expertise = &'static str;

        fn expertise(&self) -> &&'static str {
            const EXPERTISE: &str = "Mock inner agent for testing";
            &EXPERTISE
        }

        async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
            self.calls.lock().await.push(payload);
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_retrieval_aware_agent_augments_payload() {
        let documents = vec![
            Document::new("Rust is a systems programming language.")
                .with_source("rust_intro.md")
                .with_score(0.92),
            Document::new("Rust has ownership and borrowing.")
                .with_source("rust_memory.md")
                .with_score(0.88),
        ];
        let retriever = MockRetriever::new(documents.clone());
        let inner_agent = MockInnerAgent::new("Response".to_string());
        let rag_agent = RetrievalAwareAgent::new(retriever.clone(), inner_agent.clone());

        let payload = Payload::text("What is Rust?");
        let result = rag_agent.execute(payload).await.unwrap();
        assert_eq!(result, "Response");

        let retriever_calls = retriever.get_calls().await;
        assert_eq!(retriever_calls.len(), 1);
        assert_eq!(retriever_calls[0].to_text(), "What is Rust?");
        // The retriever must see the caller's payload, not one that is already augmented.
        assert_eq!(retriever_calls[0].documents().len(), 0);

        let inner_calls = inner_agent.get_calls().await;
        assert_eq!(inner_calls.len(), 1);
        let received_docs = inner_calls[0].documents();
        assert_eq!(received_docs.len(), 2);
        assert_eq!(
            received_docs[0].content,
            "Rust is a systems programming language."
        );
        assert_eq!(
            received_docs[1].content,
            "Rust has ownership and borrowing."
        );
    }

    #[tokio::test]
    async fn test_retrieval_aware_agent_propagates_retriever_error() {
        #[derive(Clone)]
        struct FailingRetriever;

        #[async_trait]
        impl Agent for FailingRetriever {
            type Output = Vec<Document>;
            type Expertise = &'static str;

            fn expertise(&self) -> &&'static str {
                const EXPERTISE: &str = "Failing retriever";
                &EXPERTISE
            }

            async fn execute(&self, _payload: Payload) -> Result<Self::Output, AgentError> {
                Err(AgentError::ExecutionFailed("Retrieval failed".to_string()))
            }
        }

        let retriever = FailingRetriever;
        let inner_agent = MockInnerAgent::new("Should not be reached".to_string());
        let rag_agent = RetrievalAwareAgent::new(retriever, inner_agent.clone());

        let result = rag_agent.execute(Payload::text("Query")).await;
        // Pin the exact error: the retriever's failure must pass through unchanged.
        assert!(
            matches!(&result, Err(AgentError::ExecutionFailed(msg)) if msg == "Retrieval failed"),
            "retriever error should be propagated unchanged"
        );

        // Short-circuit: the inner agent is never invoked after a retrieval failure.
        let inner_calls = inner_agent.get_calls().await;
        assert_eq!(inner_calls.len(), 0);
    }

    #[tokio::test]
    async fn test_retrieval_aware_agent_with_empty_results() {
        let retriever = MockRetriever::new(vec![]);
        let inner_agent = MockInnerAgent::new("No context".to_string());
        let rag_agent = RetrievalAwareAgent::new(retriever, inner_agent.clone());

        let result = rag_agent.execute(Payload::text("Query")).await.unwrap();
        assert_eq!(result, "No context");

        let inner_calls = inner_agent.get_calls().await;
        assert_eq!(inner_calls.len(), 1);
        assert_eq!(inner_calls[0].documents().len(), 0);
    }

    // `description()` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_expertise_delegation() {
        let retriever = MockRetriever::new(vec![]);
        let inner_agent = MockInnerAgent::new("Response".to_string());
        let rag_agent = RetrievalAwareAgent::new(retriever, inner_agent);
        assert_eq!(rag_agent.description(), "Mock inner agent for testing");
    }

    #[tokio::test]
    async fn test_retrieval_aware_agent_preserves_attachments() {
        use crate::attachment::Attachment;

        let retriever = MockRetriever::new(vec![Document::new("Doc content")]);
        let inner_agent = MockInnerAgent::new("ok".to_string());
        let rag_agent = RetrievalAwareAgent::new(retriever, inner_agent.clone());

        let attachment = Attachment::in_memory(vec![1, 2, 3]);
        let payload = Payload::text("Query").with_attachment(attachment.clone());
        let _ = rag_agent.execute(payload).await.unwrap();

        let inner_calls = inner_agent.get_calls().await;
        // Check the call count before indexing, so a missing call fails with a clear message
        // instead of an index panic.
        assert_eq!(inner_calls.len(), 1);
        assert!(inner_calls[0].has_attachments());
        assert_eq!(inner_calls[0].attachments().len(), 1);
        assert_eq!(inner_calls[0].attachments()[0], &attachment);
        // Augmentation adds the documents alongside the attachments; it does not replace them.
        assert_eq!(inner_calls[0].documents().len(), 1);
    }
}