pe_core/
retry_middleware.rs1use std::time::Duration;
17
18use async_trait::async_trait;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
/// Provider middleware that retries failed completions with exponential backoff.
///
/// Only errors whose `is_retryable()` returns `true` are retried. Any other error,
/// and the error from the last permitted attempt, is passed straight back to the
/// caller.
#[derive(Debug, Clone)]
pub struct RetryMiddleware {
    /// Total number of calls made to the wrapped provider, the first one included.
    /// Always at least 1 when built through [`RetryMiddleware::new`].
    max_attempts: u32,
    /// Base delay waited after the first failed attempt; doubled for each later retry.
    initial_interval: Duration,
}

impl RetryMiddleware {
    /// Creates a retry middleware.
    ///
    /// * `max_attempts` - total number of calls to the wrapped provider, including
    ///   the first. Values below 1 are clamped to 1 so the provider is always
    ///   called at least once.
    /// * `initial_interval` - delay used after the first failed attempt; see
    ///   [`RetryMiddleware::delay_for_attempt`] for how it grows.
    pub fn new(max_attempts: u32, initial_interval: Duration) -> Self {
        Self {
            max_attempts: max_attempts.max(1),
            initial_interval,
        }
    }

    /// Returns how long to wait after failed attempt `n` (0-based) before retrying.
    ///
    /// The delay is `initial_interval * 2^min(n, 32)` plus a fixed 25% margin. With
    /// a 100 ms interval, attempts 0, 1 and 2 wait 125, 250 and 500 ms.
    ///
    /// All arithmetic is done on [`Duration`] and saturates at [`Duration::MAX`].
    /// This means sub-millisecond intervals keep their precision instead of being
    /// truncated to zero, and large exponents cannot overflow.
    fn delay_for_attempt(&self, n: u32) -> Duration {
        // Cap the exponent so the multiplier never exceeds 2^32.
        let shift = n.min(32);
        // `Duration::saturating_mul` takes a `u32`, which cannot hold 2^32, so apply
        // 2^shift as two factors of at most 2^16 each.
        let low = shift / 2;
        let base = self
            .initial_interval
            .saturating_mul(1u32 << low)
            .saturating_mul(1u32 << (shift - low));
        // Fixed 25% margin on top of the exponential base (deterministic, not random jitter).
        base.saturating_add(base / 4)
    }
}
57
58#[async_trait]
59impl ProviderMiddleware for RetryMiddleware {
60 async fn wrap_complete(
61 &self,
62 messages: &[Message],
63 tools: &[ToolSchema],
64 next: &dyn LlmProvider,
65 ) -> Result<LlmResponse, PeError> {
66 let mut last_err = None;
67
68 for attempt in 0..self.max_attempts {
69 match next.complete(messages, tools).await {
70 Ok(resp) => return Ok(resp),
71 Err(e) if e.is_retryable() && attempt + 1 < self.max_attempts => {
72 let delay = self.delay_for_attempt(attempt);
73 tokio::time::sleep(delay).await;
74 last_err = Some(e);
75 }
76 Err(e) => return Err(e),
77 }
78 }
79
80 Err(last_err.unwrap_or(PeError::Internal {
81 details: "retry loop exited without result".into(),
82 }))
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_failure() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts_on_persistent_transient() {
        let retry = RetryMiddleware::new(2, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            });

        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
        // Both attempts must actually have reached the provider.
        assert_eq!(provider.remaining(), 0);
    }

    #[tokio::test]
    async fn test_retry_does_not_retry_permanent_errors() {
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::PermissionDenied {
                action: "write".into(),
            })
            .respond_with("should not reach");

        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
        // The queued success response must remain unused.
        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_retry_max_attempts_clamped_to_one() {
        let retry = RetryMiddleware::new(0, Duration::from_millis(1));
        let provider = MockProvider::new().respond_with("ok");

        let resp = retry.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
    }

    #[tokio::test]
    async fn test_retry_clamped_to_one_does_not_retry_transient() {
        // A clamped single attempt leaves no room for a retry, even for a
        // retryable error.
        let retry = RetryMiddleware::new(0, Duration::from_millis(1));
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("should not reach");

        let err = retry.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
        assert_eq!(provider.remaining(), 1);
    }

    #[test]
    fn test_delay_increases_exponentially() {
        let retry = RetryMiddleware::new(5, Duration::from_millis(100));

        // initial_interval * 2^n plus a fixed 25% margin.
        assert_eq!(retry.delay_for_attempt(0), Duration::from_millis(125));
        assert_eq!(retry.delay_for_attempt(1), Duration::from_millis(250));
        assert_eq!(retry.delay_for_attempt(2), Duration::from_millis(500));
    }
}