1use crate::error::RuntimeError;
2use rand::Rng;
3use std::time::Duration;
4
/// Settings that control how a failing operation is retried with exponential backoff.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Upper bound of the attempt index: `execute` iterates over `0..=max_attempts`,
    /// so the operation runs at most `max_attempts + 1` times.
    pub max_attempts: u32,

    /// Delay used for attempt index 0, before any backoff multiplication.
    pub initial_delay: Duration,

    /// Cap applied to the exponentially grown delay. Jitter is added after the
    /// cap, so the final delay can exceed this value by up to `jitter_factor`.
    pub max_delay: Duration,

    /// Factor raised to the power of the attempt index to grow the delay.
    pub backoff_multiplier: f64,

    /// Fraction of the capped delay that bounds the random extra delay.
    /// Values `<= 0.0` disable jitter.
    pub jitter_factor: f64,

    /// Optional wall-clock budget for the whole retry loop; once the elapsed
    /// time exceeds it, no further attempt is started. `None` disables the check.
    pub max_total_duration: Option<Duration>,
}
27
28impl Default for RetryPolicy {
29 fn default() -> Self {
30 Self {
31 max_attempts: 3,
32 initial_delay: Duration::from_millis(100),
33 max_delay: Duration::from_secs(30),
34 backoff_multiplier: 2.0,
35 jitter_factor: 0.1,
36 max_total_duration: Some(Duration::from_secs(60)),
37 }
38 }
39}
40
41impl RetryPolicy {
42 pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
44 Self {
45 max_attempts,
46 initial_delay,
47 ..Default::default()
48 }
49 }
50
51 pub fn no_retry() -> Self {
53 Self {
54 max_attempts: 0,
55 ..Default::default()
56 }
57 }
58
59 pub fn aggressive() -> Self {
61 Self {
62 max_attempts: 5,
63 initial_delay: Duration::from_millis(50),
64 max_delay: Duration::from_secs(10),
65 backoff_multiplier: 1.5,
66 jitter_factor: 0.2,
67 max_total_duration: Some(Duration::from_secs(30)),
68 }
69 }
70
71 pub fn conservative() -> Self {
73 Self {
74 max_attempts: 2,
75 initial_delay: Duration::from_secs(1),
76 max_delay: Duration::from_secs(60),
77 backoff_multiplier: 3.0,
78 jitter_factor: 0.1,
79 max_total_duration: Some(Duration::from_secs(120)),
80 }
81 }
82
83 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
85 let base_delay =
86 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
87
88 let clamped = base_delay.min(self.max_delay.as_millis() as f64);
89
90 let jittered = if self.jitter_factor > 0.0 {
92 let mut rng = rand::thread_rng();
93 let jitter = rng.gen::<f64>() * self.jitter_factor * clamped;
94 clamped + jitter
95 } else {
96 clamped
97 };
98
99 Duration::from_millis(jittered as u64)
100 }
101
102 pub async fn execute<F, Fut, T, E>(
122 &self,
123 operation_name: &str,
124 mut operation: F,
125 ) -> Result<T, RuntimeError>
126 where
127 F: FnMut() -> Fut,
128 Fut: std::future::Future<Output = Result<T, E>>,
129 E: Into<RuntimeError> + Clone,
130 {
131 let start = std::time::Instant::now();
132 let mut last_error = None;
133
134 for attempt in 0..=self.max_attempts {
135 if let Some(max_duration) = self.max_total_duration {
137 if start.elapsed() > max_duration {
138 break;
139 }
140 }
141
142 match operation().await {
144 Ok(result) => return Ok(result),
145 Err(e) => {
146 let runtime_error: RuntimeError = e.into();
147
148 let should_retry = match &runtime_error {
150 RuntimeError::Llm(llm_err) => llm_err.is_retryable(),
151 _ => false, };
153
154 last_error = Some(runtime_error.clone());
155
156 if attempt >= self.max_attempts || !should_retry {
160 break;
161 }
162
163 let delay = self.delay_for_attempt(attempt);
165 tokio::time::sleep(delay).await;
166 }
167 }
168 }
169
170 Err(RuntimeError::RetryExhausted {
172 operation: operation_name.to_string(),
173 attempts: self.max_attempts + 1,
174 last_error: Box::new(last_error.unwrap()),
175 })
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a doubling-backoff policy with jitter and the total-time budget
    /// disabled, so computed delays are exact.
    fn exact_policy(max_attempts: u32, initial_delay: Duration, max_delay: Duration) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
            max_total_duration: None,
        }
    }

    /// Without jitter the delay doubles with each attempt index.
    #[test]
    fn test_delay_calculation() {
        let policy = exact_policy(3, Duration::from_millis(100), Duration::from_secs(10));

        let expectations = [(0u32, 100u128), (1, 200), (2, 400)];
        for &(attempt, expected_millis) in expectations.iter() {
            assert_eq!(policy.delay_for_attempt(attempt).as_millis(), expected_millis);
        }
    }

    /// A delay that would grow past `max_delay` is capped at `max_delay`.
    #[test]
    fn test_max_delay_clamp() {
        let policy = exact_policy(10, Duration::from_secs(1), Duration::from_secs(5));

        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(5));
    }

    /// A retryable failure followed by a success yields `Ok` after two calls.
    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let policy = RetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let outcome: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let counter = Arc::clone(&counter);
                async move {
                    // `fetch_add` returns the previous value: 0 means first call.
                    match counter.fetch_add(1, Ordering::SeqCst) {
                        0 => Err(LlmError::network("Network error")),
                        _ => Ok("success"),
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// An operation that always fails is called `max_attempts + 1` times and
    /// reports that count in `RetryExhausted`.
    #[tokio::test]
    async fn test_retry_exhausted() {
        let policy = RetryPolicy::new(2, Duration::from_millis(10));
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let outcome: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(LlmError::network("Network error"))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        if let RuntimeError::RetryExhausted {
            attempts: reported, ..
        } = outcome.unwrap_err()
        {
            assert_eq!(reported, 3);
        } else {
            panic!("Expected RetryExhausted error");
        }
    }

    /// An error flagged as non-retryable stops the loop after a single call.
    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);

        let outcome: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(LlmError {
                        code: crate::error::LlmErrorCode::InvalidRequest,
                        message: "Bad request".to_string(),
                        provider: None,
                        model: None,
                        retryable: false,
                    })
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}