1use std::time::Duration;
42
43use crate::error::{AgentError, OperationError};
44
/// Delay before the first retry when no explicit initial backoff is configured.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);

/// Upper bound on the delay between retries when no explicit cap is configured.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);

/// Growth factor applied per attempt when no explicit multiplier is configured.
const DEFAULT_MULTIPLIER: f64 = 2.0;

/// Exponential-backoff retry configuration.
///
/// Build one with [`RetryPolicy::new`] and refine it with the chained
/// [`backoff`](RetryPolicy::backoff), [`max_backoff`](RetryPolicy::max_backoff)
/// and [`multiplier`](RetryPolicy::multiplier) setters. The delay for a given
/// attempt is `initial_backoff * multiplier^attempt`, never exceeding
/// `max_backoff`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Retry limit; validated to be greater than zero by [`RetryPolicy::new`].
    pub(crate) max_retries: u32,
    /// Delay used for attempt `0`, before any growth is applied.
    pub(crate) initial_backoff: Duration,
    /// Hard ceiling on the delay returned for any attempt.
    pub(crate) max_backoff: Duration,
    /// Factor by which the delay grows per attempt (finite and `>= 1.0` when
    /// set through [`RetryPolicy::multiplier`]).
    pub(crate) multiplier: f64,
}

impl RetryPolicy {
    /// Creates a policy with the given retry limit and the default backoff
    /// settings (200 ms initial delay, 30 s cap, multiplier 2.0).
    ///
    /// # Panics
    ///
    /// Panics if `max_retries` is `0`.
    pub fn new(max_retries: u32) -> Self {
        assert!(max_retries > 0, "max_retries must be greater than 0");
        Self {
            max_retries,
            initial_backoff: DEFAULT_INITIAL_BACKOFF,
            max_backoff: DEFAULT_MAX_BACKOFF,
            multiplier: DEFAULT_MULTIPLIER,
        }
    }

    /// Sets the initial backoff, i.e. the delay used for attempt `0`.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero.
    pub fn backoff(mut self, duration: Duration) -> Self {
        assert!(!duration.is_zero(), "initial backoff must not be zero");
        self.initial_backoff = duration;
        self
    }

    /// Sets the maximum delay any attempt may be assigned.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero.
    pub fn max_backoff(mut self, duration: Duration) -> Self {
        assert!(!duration.is_zero(), "max backoff must not be zero");
        self.max_backoff = duration;
        self
    }

    /// Sets the factor by which the delay grows per attempt.
    ///
    /// A multiplier of `1.0` yields a constant delay.
    ///
    /// # Panics
    ///
    /// Panics if `multiplier` is below `1.0`, NaN, or infinite.
    pub fn multiplier(mut self, multiplier: f64) -> Self {
        assert!(
            multiplier >= 1.0 && multiplier.is_finite(),
            "multiplier must be >= 1.0 and finite, got {multiplier}"
        );
        self.multiplier = multiplier;
        self
    }

    /// Returns the configured retry limit.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Computes the delay for the zero-based `attempt`:
    /// `initial_backoff * multiplier^attempt`, capped at `max_backoff`.
    ///
    /// The result saturates at `max_backoff` for arbitrarily large attempts
    /// and never panics because of the size of the cap.
    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `attempt as i32` would wrap to a negative exponent for attempts above
        // `i32::MAX`, making the delay shrink towards zero. Saturate instead.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let delay = self.initial_backoff.as_secs_f64() * self.multiplier.powi(exponent);
        let cap = self.max_backoff.as_secs_f64();

        if delay < cap {
            Duration::from_secs_f64(delay)
        } else {
            // Reached for `delay >= cap`, an infinite product, or NaN. Returning
            // the stored `Duration` keeps the cap exact and avoids converting a
            // huge f64 back into a `Duration`, which panics for caps close to
            // `Duration::MAX`.
            self.max_backoff
        }
    }
}
176
177pub fn is_retryable(error: &OperationError) -> bool {
193 match error {
194 OperationError::Http { status, .. } => match status {
195 None => true,
196 Some(code) => *code >= 500 || *code == 429,
197 },
198 OperationError::Agent(agent_err) => matches!(
199 agent_err,
200 AgentError::ProcessFailed { .. } | AgentError::Timeout { .. }
201 ),
202 OperationError::Timeout { .. } => true,
203 OperationError::Shell { .. } | OperationError::Deserialize { .. } => false,
204 }
205}
206
/// Reports whether an HTTP status code should trigger a retry.
///
/// Returns `true` for `429` and for every code of `500` or above,
/// `false` for everything else.
pub(crate) fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=u16::MAX)
}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // --- RetryPolicy construction and builder setters ---

    #[test]
    fn new_creates_policy_with_defaults() {
        let policy = RetryPolicy::new(3);
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_backoff, DEFAULT_INITIAL_BACKOFF);
        assert_eq!(policy.max_backoff, DEFAULT_MAX_BACKOFF);
        // Floats are compared with a tolerance rather than `==`.
        assert!((policy.multiplier - DEFAULT_MULTIPLIER).abs() < f64::EPSILON);
    }

    #[test]
    #[should_panic(expected = "max_retries must be greater than 0")]
    fn new_zero_retries_panics() {
        let _ = RetryPolicy::new(0);
    }

    #[test]
    fn backoff_sets_initial_backoff() {
        let policy = RetryPolicy::new(1).backoff(Duration::from_secs(1));
        assert_eq!(policy.initial_backoff, Duration::from_secs(1));
    }

    #[test]
    #[should_panic(expected = "initial backoff must not be zero")]
    fn backoff_zero_panics() {
        let _ = RetryPolicy::new(1).backoff(Duration::ZERO);
    }

    #[test]
    fn max_backoff_sets_cap() {
        let policy = RetryPolicy::new(1).max_backoff(Duration::from_secs(60));
        assert_eq!(policy.max_backoff, Duration::from_secs(60));
    }

    #[test]
    #[should_panic(expected = "max backoff must not be zero")]
    fn max_backoff_zero_panics() {
        let _ = RetryPolicy::new(1).max_backoff(Duration::ZERO);
    }

    #[test]
    fn multiplier_sets_value() {
        let policy = RetryPolicy::new(1).multiplier(3.0);
        assert!((policy.multiplier - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    #[should_panic(expected = "multiplier must be >= 1.0")]
    fn multiplier_below_one_panics() {
        let _ = RetryPolicy::new(1).multiplier(0.5);
    }

    // NaN fails the `>= 1.0` comparison; infinity fails the `is_finite` check.
    #[test]
    #[should_panic(expected = "multiplier must be >= 1.0 and finite")]
    fn multiplier_nan_panics() {
        let _ = RetryPolicy::new(1).multiplier(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "multiplier must be >= 1.0 and finite")]
    fn multiplier_infinity_panics() {
        let _ = RetryPolicy::new(1).multiplier(f64::INFINITY);
    }

    #[test]
    fn max_retries_accessor() {
        assert_eq!(RetryPolicy::new(5).max_retries(), 5);
    }

    // --- delay_for_attempt ---

    #[test]
    fn delay_for_attempt_zero_is_initial_backoff() {
        let policy = RetryPolicy::new(3).backoff(Duration::from_millis(100));
        let delay = policy.delay_for_attempt(0);
        assert_eq!(delay, Duration::from_millis(100));
    }

    #[test]
    fn delay_for_attempt_grows_exponentially() {
        let policy = RetryPolicy::new(5)
            .backoff(Duration::from_millis(100))
            .multiplier(2.0);

        // 100 ms * 2^attempt
        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn delay_for_attempt_capped_at_max_backoff() {
        let policy = RetryPolicy::new(10)
            .backoff(Duration::from_secs(1))
            .max_backoff(Duration::from_secs(5))
            .multiplier(10.0);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
        // 1 s * 10^1 = 10 s exceeds the 5 s cap, so the cap is returned.
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(5));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn delay_for_attempt_with_multiplier_one_is_constant() {
        let policy = RetryPolicy::new(3)
            .backoff(Duration::from_millis(500))
            .multiplier(1.0);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(500));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(500));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(500));
    }

    // --- is_retryable: HTTP errors ---

    // An HTTP error with `status: None` is treated as retryable.
    #[test]
    fn http_transport_error_is_retryable() {
        let err = OperationError::Http {
            status: None,
            message: "connection refused".to_string(),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn http_500_is_retryable() {
        let err = OperationError::Http {
            status: Some(500),
            message: "internal server error".to_string(),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn http_502_is_retryable() {
        let err = OperationError::Http {
            status: Some(502),
            message: "bad gateway".to_string(),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn http_503_is_retryable() {
        let err = OperationError::Http {
            status: Some(503),
            message: "service unavailable".to_string(),
        };
        assert!(is_retryable(&err));
    }

    // 429 is the only status below 500 that is retried.
    #[test]
    fn http_429_is_retryable() {
        let err = OperationError::Http {
            status: Some(429),
            message: "too many requests".to_string(),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn http_400_is_not_retryable() {
        let err = OperationError::Http {
            status: Some(400),
            message: "bad request".to_string(),
        };
        assert!(!is_retryable(&err));
    }

    #[test]
    fn http_404_is_not_retryable() {
        let err = OperationError::Http {
            status: Some(404),
            message: "not found".to_string(),
        };
        assert!(!is_retryable(&err));
    }

    // --- is_retryable: agent errors ---

    #[test]
    fn agent_process_failed_is_retryable() {
        let err = OperationError::Agent(AgentError::ProcessFailed {
            exit_code: 1,
            stderr: "crash".to_string(),
        });
        assert!(is_retryable(&err));
    }

    #[test]
    fn agent_timeout_is_retryable() {
        let err = OperationError::Agent(AgentError::Timeout {
            limit: Duration::from_secs(60),
        });
        assert!(is_retryable(&err));
    }

    // Agent errors other than `ProcessFailed` / `Timeout` are not retried.
    #[test]
    fn agent_prompt_too_large_is_not_retryable() {
        let err = OperationError::Agent(AgentError::PromptTooLarge {
            chars: 1_000_000,
            estimated_tokens: 250_000,
            model_limit: 200_000,
        });
        assert!(!is_retryable(&err));
    }

    #[test]
    fn agent_schema_validation_is_not_retryable() {
        let err = OperationError::Agent(AgentError::SchemaValidation {
            expected: "object".to_string(),
            got: "string".to_string(),
            debug_messages: Vec::new(),
            partial_usage: Box::default(),
            raw_response: None,
        });
        assert!(!is_retryable(&err));
    }

    // --- is_retryable: remaining OperationError variants ---

    #[test]
    fn operation_timeout_is_retryable() {
        let err = OperationError::Timeout {
            step: "fetch".to_string(),
            limit: Duration::from_secs(30),
        };
        assert!(is_retryable(&err));
    }

    #[test]
    fn shell_error_is_not_retryable() {
        let err = OperationError::Shell {
            exit_code: 1,
            stderr: "fail".to_string(),
        };
        assert!(!is_retryable(&err));
    }

    #[test]
    fn deserialize_error_is_not_retryable() {
        let err = OperationError::Deserialize {
            target_type: "MyStruct".to_string(),
            reason: "missing field".to_string(),
        };
        assert!(!is_retryable(&err));
    }

    // --- is_retryable_status ---

    #[test]
    fn retryable_status_codes() {
        assert!(is_retryable_status(500));
        assert!(is_retryable_status(502));
        assert!(is_retryable_status(503));
        assert!(is_retryable_status(504));
        assert!(is_retryable_status(429));
    }

    #[test]
    fn non_retryable_status_codes() {
        assert!(!is_retryable_status(200));
        assert!(!is_retryable_status(201));
        assert!(!is_retryable_status(301));
        assert!(!is_retryable_status(400));
        assert!(!is_retryable_status(401));
        assert!(!is_retryable_status(403));
        assert!(!is_retryable_status(404));
        assert!(!is_retryable_status(422));
        // 428 sits directly below the retryable 429.
        assert!(!is_retryable_status(428));
    }

    // --- Builder chaining and derived traits ---

    #[test]
    fn builder_chain_all_methods() {
        let policy = RetryPolicy::new(5)
            .backoff(Duration::from_millis(100))
            .max_backoff(Duration::from_secs(10))
            .multiplier(3.0);

        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(10));
        assert!((policy.multiplier - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn clone_produces_independent_copy() {
        let policy = RetryPolicy::new(3).backoff(Duration::from_millis(100));
        let cloned = policy.clone();
        assert_eq!(policy.max_retries, cloned.max_retries);
        assert_eq!(policy.initial_backoff, cloned.initial_backoff);
    }

    #[test]
    fn debug_does_not_panic() {
        let policy = RetryPolicy::new(1);
        let debug = format!("{:?}", policy);
        assert!(debug.contains("RetryPolicy"));
    }
}
517}