use crate::agent::{Agent, AgentError, Payload, retry::retry_execution};
use async_trait::async_trait;
/// Decorator around another [`Agent`] that retries failed executions.
///
/// Only `execute` is wrapped (via `retry_execution`); `name`, `expertise` and
/// `is_available` are forwarded unchanged to the inner agent.
///
/// The `Agent` bound is stated on the `impl` blocks rather than on the struct
/// itself, so the type can be named (e.g. in other generic signatures) without
/// repeating the bound.
pub struct RetryAgent<T> {
    /// The wrapped agent that every call is delegated to.
    inner: T,
    /// Retry budget handed to `retry_execution` on every `execute` call.
    max_retries: u32,
}
impl<T: Agent> RetryAgent<T> {
pub fn new(inner: T, max_retries: u32) -> Self {
Self { inner, max_retries }
}
pub fn inner(&self) -> &T {
&self.inner
}
pub fn max_retries(&self) -> u32 {
self.max_retries
}
}
#[async_trait]
impl<T: Agent> Agent for RetryAgent<T>
where
    T::Output: Send,
{
    type Output = T::Output;
    type Expertise = T::Expertise;

    /// Reports the wrapped agent's name; the retry layer does not rename it.
    fn name(&self) -> String {
        self.inner.name()
    }

    /// Reports the wrapped agent's expertise unchanged.
    fn expertise(&self) -> &Self::Expertise {
        self.inner.expertise()
    }

    /// Runs the wrapped agent through `retry_execution` with this wrapper's
    /// `max_retries` budget.
    ///
    /// The inner `execute` takes its payload by value, so every attempt gets
    /// its own clone of `payload`.
    async fn execute(&self, payload: Payload) -> Result<Self::Output, AgentError> {
        let delegate = &self.inner;
        let budget = self.max_retries;
        retry_execution(budget, &payload, move |attempt_payload| {
            let owned_payload = attempt_payload.clone();
            async move { delegate.execute(owned_payload).await }
        })
        .await
    }

    /// Availability is decided entirely by the wrapped agent.
    async fn is_available(&self) -> Result<(), AgentError> {
        self.inner.is_available().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::error::ParseErrorReason;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::{Duration, Instant};

    /// Call bookkeeping shared by the scripted test agents: counts every call
    /// and hands out a fixed number of failures before letting calls succeed.
    struct FailureScript {
        remaining_failures: AtomicU32,
        total_calls: AtomicU32,
    }

    impl FailureScript {
        fn new(fail_count: u32) -> Self {
            Self {
                remaining_failures: AtomicU32::new(fail_count),
                total_calls: AtomicU32::new(0),
            }
        }

        /// Records one call and returns `true` if this call should fail.
        ///
        /// The check-and-decrement is a single atomic `fetch_update`, so the
        /// scripted failure budget can never be over-consumed.
        fn record_call(&self) -> bool {
            self.total_calls.fetch_add(1, Ordering::SeqCst);
            self.remaining_failures
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| {
                    left.checked_sub(1)
                })
                .is_ok()
        }

        fn total_calls(&self) -> u32 {
            self.total_calls.load(Ordering::SeqCst)
        }
    }

    /// Agent that returns a (retryable) `ParseError` for its first
    /// `fail_count` calls and `"success"` afterwards.
    struct FailingAgent {
        script: FailureScript,
    }

    impl FailingAgent {
        fn new(fail_count: u32) -> Self {
            Self {
                script: FailureScript::new(fail_count),
            }
        }

        fn total_calls(&self) -> u32 {
            self.script.total_calls()
        }
    }

    #[async_trait]
    impl Agent for FailingAgent {
        type Output = String;
        type Expertise = &'static str;

        fn expertise(&self) -> &&'static str {
            const EXPERTISE: &str = "Test agent that fails a configurable number of times";
            &EXPERTISE
        }

        async fn execute(&self, _payload: Payload) -> Result<String, AgentError> {
            if self.script.record_call() {
                return Err(AgentError::ParseError {
                    message: "Simulated failure".to_string(),
                    reason: ParseErrorReason::MarkdownExtractionFailed,
                });
            }
            Ok("success".to_string())
        }
    }

    /// Agent that answers its first `fail_count` calls with a retryable HTTP
    /// 429 `ProcessError`, optionally carrying a `retry_after` hint.
    struct RateLimitedAgent {
        script: FailureScript,
        retry_after: Option<Duration>,
    }

    impl RateLimitedAgent {
        /// 429 responses that carry a `retry_after` hint.
        fn new(fail_count: u32, retry_after: Duration) -> Self {
            Self {
                script: FailureScript::new(fail_count),
                retry_after: Some(retry_after),
            }
        }

        /// 429 responses with no `retry_after` hint, forcing the retry layer
        /// to fall back to its own backoff.
        fn without_retry_after(fail_count: u32) -> Self {
            Self {
                script: FailureScript::new(fail_count),
                retry_after: None,
            }
        }

        fn total_calls(&self) -> u32 {
            self.script.total_calls()
        }
    }

    #[async_trait]
    impl Agent for RateLimitedAgent {
        type Output = String;
        type Expertise = &'static str;

        fn expertise(&self) -> &&'static str {
            const WITH_HINT: &str = "Test agent that simulates rate limiting with retry_after";
            const WITHOUT_HINT: &str = "Test agent that simulates 429 without retry_after";
            if self.retry_after.is_some() {
                &WITH_HINT
            } else {
                &WITHOUT_HINT
            }
        }

        async fn execute(&self, _payload: Payload) -> Result<String, AgentError> {
            if self.script.record_call() {
                return Err(AgentError::ProcessError {
                    status_code: Some(429),
                    message: "Rate limited".to_string(),
                    is_retryable: true,
                    retry_after: self.retry_after,
                });
            }
            Ok("success".to_string())
        }
    }

    #[tokio::test]
    async fn test_retry_agent_success_first_try() {
        let retry_agent = RetryAgent::new(FailingAgent::new(0), 3);
        let output = retry_agent
            .execute(Payload::text("test"))
            .await
            .expect("an agent that never fails should succeed");
        assert_eq!(output, "success");
        assert_eq!(retry_agent.inner().total_calls(), 1);
    }

    #[tokio::test]
    async fn test_retry_agent_success_after_retries() {
        let retry_agent = RetryAgent::new(FailingAgent::new(2), 3);
        let output = retry_agent
            .execute(Payload::text("test"))
            .await
            .expect("two failures fit within a budget of three retries");
        assert_eq!(output, "success");
        assert_eq!(retry_agent.inner().total_calls(), 3);
    }

    #[tokio::test]
    async fn test_retry_agent_max_retries_exhausted() {
        let retry_agent = RetryAgent::new(FailingAgent::new(10), 2);
        retry_agent
            .execute(Payload::text("test"))
            .await
            .expect_err("ten failures exceed a budget of two retries");
        // One initial attempt plus two retries.
        assert_eq!(retry_agent.inner().total_calls(), 3);
    }

    #[test]
    fn test_retry_agent_name() {
        let retry_agent = RetryAgent::new(FailingAgent::new(0), 3);
        // The retry layer is transparent: it reports the inner agent's name.
        assert_eq!(retry_agent.name(), "FailingAgent");
    }

    #[test]
    fn test_retry_agent_expertise_delegation() {
        let retry_agent = RetryAgent::new(FailingAgent::new(0), 3);
        assert_eq!(
            retry_agent.description(),
            "Test agent that fails a configurable number of times"
        );
    }

    #[tokio::test]
    async fn test_retry_agent_with_429_retry_after() {
        let base = RateLimitedAgent::new(2, Duration::from_millis(100));
        let retry_agent = RetryAgent::new(base, 3);
        let start = Instant::now();
        let output = retry_agent
            .execute(Payload::text("test"))
            .await
            .expect("should succeed once the rate limit clears");
        let elapsed = start.elapsed();
        assert_eq!(output, "success");
        assert_eq!(retry_agent.inner().total_calls(), 3);
        assert!(
            elapsed < Duration::from_millis(1000),
            "Should complete within 1 second with 100ms retry_after"
        );
    }

    #[tokio::test]
    async fn test_retry_agent_with_429_without_retry_after() {
        let retry_agent = RetryAgent::new(RateLimitedAgent::without_retry_after(1), 3);
        let start = Instant::now();
        let output = retry_agent
            .execute(Payload::text("test"))
            .await
            .expect("should succeed once the rate limit clears");
        let elapsed = start.elapsed();
        assert_eq!(output, "success");
        assert_eq!(retry_agent.inner().total_calls(), 2);
        assert!(
            elapsed < Duration::from_secs(2),
            "Should complete within 2 seconds with exponential backoff"
        );
    }

    #[tokio::test]
    async fn test_retry_agent_respects_retry_after_duration() {
        let retry_after = Duration::from_secs(1);
        let retry_agent = RetryAgent::new(RateLimitedAgent::new(1, retry_after), 3);
        let start = Instant::now();
        retry_agent
            .execute(Payload::text("test"))
            .await
            .expect("should succeed once the rate limit clears");
        let elapsed = start.elapsed();
        assert_eq!(retry_agent.inner().total_calls(), 2);
        // Without this lower bound the test would pass even if the hint were ignored.
        assert!(
            elapsed >= retry_after,
            "Should wait at least the retry_after hint before retrying, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_millis(1500),
            "Should complete within 1.5 seconds"
        );
    }

    #[tokio::test]
    async fn test_retry_agent_429_exhausts_retries() {
        let base = RateLimitedAgent::new(10, Duration::from_millis(50));
        let retry_agent = RetryAgent::new(base, 2);
        let result = retry_agent.execute(Payload::text("test")).await;
        assert_eq!(retry_agent.inner().total_calls(), 3);
        match result {
            Err(AgentError::ProcessError {
                status_code,
                retry_after,
                ..
            }) => {
                assert_eq!(status_code, Some(429));
                assert!(retry_after.is_some());
            }
            other => panic!("Expected ProcessError with 429, got {:?}", other),
        }
    }
}