pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Retry middleware — exponential backoff with a fixed 25% spread for transient LLM errors.
//!
//! Only retries errors where [`PeError::is_retryable()`] returns `true`.
//! Permanent errors propagate immediately.
//!
//! # Example
//!
//! ```ignore
//! use std::time::Duration;
//! use pe_core::retry_middleware::RetryMiddleware;
//!
//! let retry = RetryMiddleware::new(3, Duration::from_millis(100));
//! let stack = MiddlewareStack::new(provider).with(retry);
//! ```

use std::time::Duration;

use async_trait::async_trait;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;

/// Retries transient LLM errors with exponential backoff.
///
/// - `max_attempts`: total attempts including the first call (minimum 1).
/// - `initial_interval`: base delay before the first retry; doubles each attempt.
/// - Spread: a fixed 25% of the current interval is added on top of the base
///   delay. The resulting delay is deterministic; no randomness is involved.
#[derive(Debug, Clone)]
pub struct RetryMiddleware {
    max_attempts: u32,
    initial_interval: Duration,
}

impl RetryMiddleware {
    /// Upper bound on the number of doublings applied to the initial interval.
    ///
    /// Keeps the shift in `1u64 << shift` well inside the 64-bit range.
    const MAX_DOUBLINGS: u32 = 32;

    /// Create a retry middleware.
    ///
    /// `max_attempts` is clamped to at least 1. `initial_interval` is the
    /// delay before the first retry; subsequent retries double the interval.
    pub fn new(max_attempts: u32, initial_interval: Duration) -> Self {
        Self {
            max_attempts: max_attempts.max(1),
            initial_interval,
        }
    }

    /// Compute delay for attempt `n` (0-indexed retry number).
    ///
    /// The base delay is `initial_interval` (in whole milliseconds) doubled
    /// `n` times, with the number of doublings capped at 32 (about 49.7 days
    /// of base delay for a 1 ms initial interval). A fixed 25% of the base is
    /// then added for spread. Every step saturates at `u64::MAX` milliseconds
    /// instead of wrapping or truncating.
    fn delay_for_attempt(&self, n: u32) -> Duration {
        let shift = n.min(Self::MAX_DOUBLINGS);
        // `as_millis()` yields a `u128`; a plain `as u64` cast would silently
        // drop the high bits and turn a huge interval into a tiny one.
        let initial_ms = u64::try_from(self.initial_interval.as_millis()).unwrap_or(u64::MAX);
        let base = initial_ms.saturating_mul(1u64 << shift);
        let spread = base / 4;
        Duration::from_millis(base.saturating_add(spread))
    }
}

#[async_trait]
impl ProviderMiddleware for RetryMiddleware {
    /// Forward the request to `next`, retrying failures that report
    /// `is_retryable()` until `max_attempts` calls have been made.
    ///
    /// Returns the first successful response. An error that is not retryable,
    /// or any error on the final attempt, is returned to the caller without
    /// further delay. Between attempts the task sleeps for
    /// `delay_for_attempt(attempt)`.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        // Most recent retryable failure; only consulted if the loop runs dry.
        let mut deferred: Option<PeError> = None;

        for attempt in 0..self.max_attempts {
            let failure = match next.complete(messages, tools).await {
                Ok(resp) => return Ok(resp),
                Err(failure) => failure,
            };

            // Give up right away on non-retryable errors or when no attempts remain.
            if !failure.is_retryable() || attempt + 1 >= self.max_attempts {
                return Err(failure);
            }

            let delay = self.delay_for_attempt(attempt);
            tokio::time::sleep(delay).await;
            deferred = Some(failure);
        }

        // Only reached when the loop body never returned (e.g. zero iterations).
        match deferred {
            Some(failure) => Err(failure),
            None => Err(PeError::Internal {
                details: "retry loop exited without result".into(),
            }),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Build a middleware with a 1 ms initial interval so retries stay fast.
    fn fast_retry(max_attempts: u32) -> RetryMiddleware {
        RetryMiddleware::new(max_attempts, Duration::from_millis(1))
    }

    /// The provider error these tests use as the retryable failure.
    fn unavailable() -> PeError {
        PeError::LlmProvider {
            details: "503".into(),
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let middleware = fast_retry(3);
        let mock = MockProvider::new().respond_with("ok");

        let outcome = middleware.wrap_complete(&[], &[], &mock).await;
        let resp = outcome.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_failure() {
        let middleware = fast_retry(3);
        let mock = MockProvider::new()
            .respond_with_error(unavailable())
            .respond_with("recovered");

        let outcome = middleware.wrap_complete(&[], &[], &mock).await;
        let resp = outcome.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts_on_persistent_transient() {
        let middleware = fast_retry(2);
        let mock = MockProvider::new()
            .respond_with_error(unavailable())
            .respond_with_error(unavailable());

        let outcome = middleware.wrap_complete(&[], &[], &mock).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    #[tokio::test]
    async fn test_retry_does_not_retry_permanent_errors() {
        let middleware = fast_retry(3);
        let denied = PeError::PermissionDenied {
            action: "write".into(),
        };
        let mock = MockProvider::new()
            .respond_with_error(denied)
            .respond_with("should not reach");

        let outcome = middleware.wrap_complete(&[], &[], &mock).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
        // The queued success is still there, so no second call was made.
        assert_eq!(mock.remaining(), 1);
    }

    #[tokio::test]
    async fn test_retry_max_attempts_clamped_to_one() {
        // Zero attempts is clamped up to one, so the single call still happens.
        let middleware = fast_retry(0);
        let mock = MockProvider::new().respond_with("ok");

        let outcome = middleware.wrap_complete(&[], &[], &mock).await;
        let resp = outcome.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    #[tokio::test]
    async fn test_delay_increases_exponentially() {
        let middleware = RetryMiddleware::new(5, Duration::from_millis(100));

        // Each delay is the doubled base plus a fixed quarter of that base.
        let first = middleware.delay_for_attempt(0); // 100 + 25
        let second = middleware.delay_for_attempt(1); // 200 + 50
        let third = middleware.delay_for_attempt(2); // 400 + 100

        assert_eq!(first.as_millis(), 125);
        assert_eq!(second.as_millis(), 250);
        assert_eq!(third.as_millis(), 500);
    }
}