Skip to main content

pe_core/
retry_middleware.rs

1//! Retry middleware — exponential backoff with jitter for transient LLM errors.
2//!
3//! Only retries errors where [`PeError::is_retryable()`] returns `true`.
4//! Permanent errors propagate immediately.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use std::time::Duration;
10//! use pe_core::retry_middleware::RetryMiddleware;
11//!
12//! let retry = RetryMiddleware::new(3, Duration::from_millis(100));
13//! let stack = MiddlewareStack::new(provider).with(retry);
14//! ```
15
16use std::time::Duration;
17
18use async_trait::async_trait;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
/// Retries transient LLM errors with exponential backoff plus a fixed spread.
///
/// - `max_attempts`: total attempts including the first call (minimum 1).
/// - `initial_interval`: base delay before the first retry; doubles each attempt.
/// - Spread: a deterministic 25% of the current interval is added to every
///   delay. No random jitter is applied.
#[derive(Debug, Clone)]
pub struct RetryMiddleware {
    /// Total number of calls allowed, counting the initial one.
    max_attempts: u32,
    /// Base delay; doubled for each subsequent retry.
    initial_interval: Duration,
}
34
35impl RetryMiddleware {
36    /// Create a retry middleware.
37    ///
38    /// `max_attempts` is clamped to at least 1. `initial_interval` is the
39    /// delay before the first retry; subsequent retries double the interval.
40    pub fn new(max_attempts: u32, initial_interval: Duration) -> Self {
41        Self {
42            max_attempts: max_attempts.max(1),
43            initial_interval,
44        }
45    }
46
47    /// Compute delay for attempt `n` (0-indexed retry number).
48    /// Exponential backoff capped at 32 doublings (~49 days) to prevent overflow.
49    /// Adds a fixed 25% to the base delay for spread.
50    fn delay_for_attempt(&self, n: u32) -> Duration {
51        let shift = n.min(32);
52        let base = (self.initial_interval.as_millis() as u64).saturating_mul(1u64 << shift);
53        let spread = base / 4;
54        Duration::from_millis(base.saturating_add(spread))
55    }
56}
57
58#[async_trait]
59impl ProviderMiddleware for RetryMiddleware {
60    async fn wrap_complete(
61        &self,
62        messages: &[Message],
63        tools: &[ToolSchema],
64        next: &dyn LlmProvider,
65    ) -> Result<LlmResponse, PeError> {
66        let mut last_err = None;
67
68        for attempt in 0..self.max_attempts {
69            match next.complete(messages, tools).await {
70                Ok(resp) => return Ok(resp),
71                Err(e) if e.is_retryable() && attempt + 1 < self.max_attempts => {
72                    let delay = self.delay_for_attempt(attempt);
73                    tokio::time::sleep(delay).await;
74                    last_err = Some(e);
75                }
76                Err(e) => return Err(e),
77            }
78        }
79
80        Err(last_err.unwrap_or(PeError::Internal {
81            details: "retry loop exited without result".into(),
82        }))
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// A successful first call is returned unchanged.
    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    /// One retryable failure followed by a success yields the success.
    #[tokio::test]
    async fn test_retry_succeeds_after_transient_failure() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }

    /// When every attempt fails, the last error is surfaced.
    #[tokio::test]
    async fn test_retry_exhausts_attempts_on_persistent_transient() {
        let retry = RetryMiddleware::new(2, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            });

        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
        // Both queued responses consumed — exactly `max_attempts` calls made
        assert_eq!(provider.remaining(), 0);
    }

    /// Non-retryable errors propagate after a single call.
    #[tokio::test]
    async fn test_retry_does_not_retry_permanent_errors() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::PermissionDenied {
                action: "write".into(),
            })
            .respond_with("should not reach");

        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
        // Second response untouched — no retry happened
        assert_eq!(provider.remaining(), 1);
    }

    /// `max_attempts = 0` still performs the initial call.
    #[tokio::test]
    async fn test_retry_max_attempts_clamped_to_one() {
        let retry = RetryMiddleware::new(0, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    /// Delay doubles per retry with a fixed 25% added on top.
    // Pure computation: no async runtime needed.
    #[test]
    fn test_delay_increases_exponentially() {
        let retry = RetryMiddleware::new(5, Duration::from_millis(100));

        let d0 = retry.delay_for_attempt(0); // 100 + 25 = 125
        let d1 = retry.delay_for_attempt(1); // 200 + 50 = 250
        let d2 = retry.delay_for_attempt(2); // 400 + 100 = 500

        assert_eq!(d0.as_millis(), 125);
        assert_eq!(d1.as_millis(), 250);
        assert_eq!(d2.as_millis(), 500);
    }

    /// The exponent stops growing after 32 doublings.
    #[test]
    fn test_delay_is_capped_at_32_doublings() {
        let retry = RetryMiddleware::new(5, Duration::from_millis(100));

        assert_eq!(
            retry.delay_for_attempt(32),
            retry.delay_for_attempt(u32::MAX)
        );
    }
}