use std::time::Duration;
use async_trait::async_trait;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;
/// Provider middleware that re-issues a failed completion request, pausing
/// for an exponentially growing delay between attempts.
pub struct RetryMiddleware {
    /// Total number of calls allowed per request, counting the first one.
    /// `new` guarantees this is at least 1.
    max_attempts: u32,
    /// Base pause from which the backoff schedule is scaled.
    initial_interval: Duration,
}

impl RetryMiddleware {
    /// Builds a retry middleware.
    ///
    /// * `max_attempts` - total number of calls allowed, including the first
    ///   one. A value of `0` is raised to `1` so the request is always tried
    ///   at least once.
    /// * `initial_interval` - base pause used by the backoff schedule.
    pub fn new(max_attempts: u32, initial_interval: Duration) -> Self {
        Self {
            max_attempts: max_attempts.max(1),
            initial_interval,
        }
    }

    /// Returns the pause to observe after the failed attempt with zero-based
    /// index `n`.
    ///
    /// The pause is `base + base / 4`, where `base` is `initial_interval`
    /// (in whole milliseconds; any sub-millisecond part is dropped)
    /// multiplied by `2^min(n, 32)`. Every step saturates at `u64::MAX`
    /// milliseconds instead of overflowing.
    fn delay_for_attempt(&self, n: u32) -> Duration {
        // `as_millis` yields a `u128`. Clamp it before narrowing so that an
        // oversized interval saturates rather than wrapping around to a
        // small value; after the clamp the cast is lossless.
        let initial_ms = self
            .initial_interval
            .as_millis()
            .min(u128::from(u64::MAX)) as u64;
        // Capping the exponent keeps the shift well inside the 64-bit width.
        let shift = n.min(32);
        let base = initial_ms.saturating_mul(1u64 << shift);
        let spread = base / 4;
        Duration::from_millis(base.saturating_add(spread))
    }
}
#[async_trait]
impl ProviderMiddleware for RetryMiddleware {
    /// Forwards the completion request to `next`, retrying failed calls.
    ///
    /// A failed call is retried only when the error reports itself as
    /// retryable and fewer than `max_attempts` calls have been made so far;
    /// the task sleeps for `delay_for_attempt(attempt)` before each retry.
    ///
    /// # Errors
    ///
    /// Returns the error of the call that ended the sequence: the first
    /// non-retryable error, or the error from the final permitted attempt.
    /// `PeError::Internal` is returned only if the loop performs no
    /// iteration at all, i.e. when `max_attempts` is zero.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        for attempt in 0..self.max_attempts {
            match next.complete(messages, tools).await {
                Ok(resp) => return Ok(resp),
                // Retryable and attempts remain: discard this error and fall
                // through to the backoff sleep below.
                Err(e) if e.is_retryable() && attempt + 1 < self.max_attempts => {}
                // Permanent failure, or the last permitted attempt.
                Err(e) => return Err(e),
            }
            tokio::time::sleep(self.delay_for_attempt(attempt)).await;
        }
        // The final iteration always returns, so this point is reached only
        // when the loop body never ran.
        Err(PeError::Internal {
            details: "retry loop exited without result".into(),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RetryMiddleware`, driven by a scripted `MockProvider`.

    use super::*;
    use crate::mock_provider::MockProvider;

    /// A call that succeeds immediately is passed straight through.
    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");
        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    /// One transient failure followed by a success yields the success.
    #[tokio::test]
    async fn test_retry_succeeds_after_transient_failure() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");
        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }

    /// When every permitted attempt fails transiently, the error surfaces and
    /// all scripted responses have been consumed.
    #[tokio::test]
    async fn test_retry_exhausts_attempts_on_persistent_transient() {
        let retry = RetryMiddleware::new(2, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            });
        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
        // Both scripted failures must have been consumed; otherwise the
        // middleware gave up without using its second attempt.
        assert_eq!(provider.remaining(), 0);
    }

    /// A non-retryable error is returned at once, leaving the follow-up
    /// response unconsumed.
    #[tokio::test]
    async fn test_retry_does_not_retry_permanent_errors() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::PermissionDenied {
                action: "write".into(),
            })
            .respond_with("should not reach");
        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
        assert_eq!(provider.remaining(), 1);
    }

    /// `max_attempts == 0` is raised to one, so the call is still made.
    #[tokio::test]
    async fn test_retry_max_attempts_clamped_to_one() {
        let retry = RetryMiddleware::new(0, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");
        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    /// Each successive delay doubles; every value includes the extra quarter
    /// of the base. Pure arithmetic, so no async runtime is needed.
    #[test]
    fn test_delay_increases_exponentially() {
        let retry = RetryMiddleware::new(5, Duration::from_millis(100));
        let d0 = retry.delay_for_attempt(0);
        let d1 = retry.delay_for_attempt(1);
        let d2 = retry.delay_for_attempt(2);
        assert_eq!(d0.as_millis(), 125);
        assert_eq!(d1.as_millis(), 250);
        assert_eq!(d2.as_millis(), 500);
    }
}