1use crate::error::TaskError;
50use chrono::NaiveDateTime;
51use rand::Rng;
52use serde::{Deserialize, Serialize};
53use std::time::Duration;
54
/// Configuration describing whether and how a failed task is retried.
///
/// Build one with [`RetryPolicy::default`] or [`RetryPolicy::builder`]. It is consulted through
/// `should_retry` (whether to try again) and `calculate_delay` / `calculate_retry_at` (when to
/// try again).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RetryPolicy {
    /// Attempt limit: `should_retry` returns `false` once the attempt number reaches this value.
    pub max_attempts: i32,

    /// Formula used to grow the delay from one attempt to the next.
    pub backoff_strategy: BackoffStrategy,

    /// Base delay that every backoff strategy scales from.
    pub initial_delay: Duration,

    /// Cap applied to the delay produced by the backoff strategy (see `calculate_delay` for how
    /// it combines with jitter).
    pub max_delay: Duration,

    /// When `true`, the delay is multiplied by a random factor in `[0.75, 1.25]`.
    pub jitter: bool,

    /// Conditions evaluated by `should_retry`; a retry happens if *any* of them matches, so an
    /// empty list means errors are never retried.
    pub retry_conditions: Vec<RetryCondition>,
}
80
/// Formula used by `RetryPolicy::calculate_delay` to derive the wait before a given attempt.
///
/// In the formulas below, `attempt` is the number passed to `calculate_delay` (the module's
/// tests treat the first retry as attempt 1). Serialized as an internally tagged object whose
/// `"type"` field carries the variant name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum BackoffStrategy {
    /// Always wait exactly `initial_delay`.
    Fixed,

    /// Wait `initial_delay * attempt * multiplier`.
    Linear {
        /// Scale factor applied on top of the linear growth.
        multiplier: f64,
    },

    /// Wait `initial_delay * multiplier * base^(attempt - 1)`.
    Exponential {
        /// Growth factor applied once per additional attempt.
        base: f64,
        /// Constant scale factor.
        multiplier: f64,
    },

    /// Placeholder for a user-supplied backoff function.
    ///
    /// NOTE(review): `calculate_delay` never looks `function_name` up; it currently behaves like
    /// `Exponential { base: 2.0, multiplier: 1.0 }`.
    Custom {
        /// Name of the intended custom function (not consulted by `calculate_delay`).
        function_name: String,
    },
}
113
/// A predicate that `RetryPolicy::should_retry` evaluates against a failed task's error.
///
/// A policy holds a list of these and retries when any one of them matches. Serialized as an
/// internally tagged object whose `"type"` field carries the variant name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum RetryCondition {
    /// Matches every error.
    AllErrors,

    /// Matches no error. Because conditions are combined with "any", this only prevents retries
    /// when no other configured condition matches.
    Never,

    /// Matches errors classified as transient: timeouts, plus execution-failed and unknown errors
    /// whose message contains a transient keyword such as "connection" or "timeout".
    TransientOnly,

    /// Matches when the error's display text contains any of `patterns`, compared
    /// case-insensitively.
    ErrorPattern { patterns: Vec<String> },
}
133
134impl Default for RetryPolicy {
135 fn default() -> Self {
145 Self {
146 max_attempts: 3,
147 backoff_strategy: BackoffStrategy::Exponential {
148 base: 2.0,
149 multiplier: 1.0,
150 },
151 initial_delay: Duration::from_secs(1),
152 max_delay: Duration::from_secs(60),
153 jitter: true,
154 retry_conditions: vec![RetryCondition::AllErrors],
155 }
156 }
157}
158
159impl RetryPolicy {
160 pub fn builder() -> RetryPolicyBuilder {
162 RetryPolicyBuilder::new()
163 }
164
165 pub fn calculate_delay(&self, attempt: i32) -> Duration {
175 let base_delay = match &self.backoff_strategy {
176 BackoffStrategy::Fixed => self.initial_delay,
177
178 BackoffStrategy::Linear { multiplier } => {
179 let millis = self.initial_delay.as_millis() as f64 * attempt as f64 * multiplier;
180 Duration::from_millis(millis as u64)
181 }
182
183 BackoffStrategy::Exponential { base, multiplier } => {
184 let millis =
185 self.initial_delay.as_millis() as f64 * multiplier * base.powi(attempt - 1);
186 Duration::from_millis(millis as u64)
187 }
188
189 BackoffStrategy::Custom { .. } => {
190 let millis = self.initial_delay.as_millis() as f64 * 2.0_f64.powi(attempt - 1);
192 Duration::from_millis(millis as u64)
193 }
194 };
195
196 let capped_delay = std::cmp::min(base_delay, self.max_delay);
198
199 if self.jitter {
201 self.add_jitter(capped_delay)
202 } else {
203 capped_delay
204 }
205 }
206
207 pub fn should_retry(&self, error: &TaskError, attempt: i32) -> bool {
218 if attempt >= self.max_attempts {
220 return false;
221 }
222
223 self.retry_conditions
225 .iter()
226 .any(|condition| match condition {
227 RetryCondition::AllErrors => true,
228 RetryCondition::Never => false,
229 RetryCondition::TransientOnly => self.is_transient_error(error),
230 RetryCondition::ErrorPattern { patterns } => {
231 let error_msg = error.to_string().to_lowercase();
232 patterns
233 .iter()
234 .any(|pattern| error_msg.contains(&pattern.to_lowercase()))
235 }
236 })
237 }
238
239 pub fn calculate_retry_at(&self, attempt: i32, now: NaiveDateTime) -> NaiveDateTime {
250 let delay = self.calculate_delay(attempt);
251 now + chrono::Duration::from_std(delay).unwrap_or_default()
252 }
253
254 fn add_jitter(&self, delay: Duration) -> Duration {
258 let mut rng = rand::thread_rng();
259 let jitter_factor = rng.gen_range(0.75..=1.25); let jittered_millis = (delay.as_millis() as f64 * jitter_factor) as u64;
261 Duration::from_millis(jittered_millis)
262 }
263
264 fn is_transient_error(&self, error: &TaskError) -> bool {
266 match error {
267 TaskError::Timeout { .. } => true,
268 TaskError::ExecutionFailed { message, .. } | TaskError::Unknown { message, .. } => {
269 Self::message_matches_transient_patterns(message)
270 }
271 _ => false,
272 }
273 }
274
275 fn message_matches_transient_patterns(message: &str) -> bool {
277 const TRANSIENT_PATTERNS: &[&str] = &[
278 "connection",
279 "network",
280 "timeout",
281 "temporary",
282 "unavailable",
283 "busy",
284 "overloaded",
285 "rate limit",
286 ];
287 let error_msg = message.to_lowercase();
288 TRANSIENT_PATTERNS
289 .iter()
290 .any(|pattern| error_msg.contains(pattern))
291 }
292}
293
294#[derive(Debug)]
296pub struct RetryPolicyBuilder {
297 policy: RetryPolicy,
298}
299
300impl RetryPolicyBuilder {
301 pub fn new() -> Self {
303 Self {
304 policy: RetryPolicy::default(),
305 }
306 }
307
308 pub fn max_attempts(mut self, max_attempts: i32) -> Self {
310 self.policy.max_attempts = max_attempts;
311 self
312 }
313
314 pub fn backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
316 self.policy.backoff_strategy = strategy;
317 self
318 }
319
320 pub fn initial_delay(mut self, delay: Duration) -> Self {
322 self.policy.initial_delay = delay;
323 self
324 }
325
326 pub fn max_delay(mut self, delay: Duration) -> Self {
328 self.policy.max_delay = delay;
329 self
330 }
331
332 pub fn with_jitter(mut self, jitter: bool) -> Self {
334 self.policy.jitter = jitter;
335 self
336 }
337
338 pub fn retry_condition(mut self, condition: RetryCondition) -> Self {
340 self.policy.retry_conditions = vec![condition];
341 self
342 }
343
344 pub fn retry_conditions(mut self, conditions: Vec<RetryCondition>) -> Self {
346 self.policy.retry_conditions = conditions;
347 self
348 }
349
350 pub fn build(self) -> RetryPolicy {
352 self.policy
353 }
354}
355
356impl Default for RetryPolicyBuilder {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_retry_policy() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_secs(1));
        assert_eq!(policy.max_delay, Duration::from_secs(60));
        assert!(policy.jitter);
        assert_eq!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            }
        );
        assert_eq!(policy.retry_conditions, vec![RetryCondition::AllErrors]);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::builder()
            .max_attempts(5)
            .initial_delay(Duration::from_millis(500))
            .max_delay(Duration::from_secs(30))
            .with_jitter(false)
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.5 })
            .retry_condition(RetryCondition::TransientOnly)
            .build();

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(30));
        assert!(!policy.jitter);
        assert_eq!(policy.retry_conditions, vec![RetryCondition::TransientOnly]);
    }

    #[test]
    fn test_fixed_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(2))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(2));
    }

    #[test]
    fn test_linear_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.0 })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(3));
    }

    #[test]
    fn test_linear_backoff_non_positive_attempt_yields_zero() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.0 })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        // Negative millisecond results saturate to zero instead of wrapping.
        assert_eq!(policy.calculate_delay(0), Duration::from_secs(0));
        assert_eq!(policy.calculate_delay(-1), Duration::from_secs(0));
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
        assert_eq!(policy.calculate_delay(4), Duration::from_secs(8));
    }

    #[test]
    fn test_max_delay_capping() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(10))
            .max_delay(Duration::from_secs(15))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(10));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(15));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(15));
    }

    #[test]
    fn test_calculate_retry_at_adds_delay() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(2))
            .with_jitter(false)
            .build();
        let now = chrono::NaiveDate::from_ymd_opt(2024, 1, 1)
            .unwrap()
            .and_hms_opt(12, 0, 0)
            .unwrap();

        let expected = now + chrono::Duration::from_std(Duration::from_secs(2)).unwrap();
        assert_eq!(policy.calculate_retry_at(1, now), expected);
    }

    /// Builds an `ExecutionFailed` error carrying `msg`.
    fn make_execution_error(msg: &str) -> TaskError {
        TaskError::ExecutionFailed {
            message: msg.to_string(),
            task_id: "test".to_string(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Builds an `Unknown` error carrying `msg`.
    fn make_unknown_error(msg: &str) -> TaskError {
        TaskError::Unknown {
            task_id: "test".to_string(),
            message: msg.to_string(),
        }
    }

    #[test]
    fn test_timeout_is_transient() {
        let policy = RetryPolicy::default();
        let error = TaskError::Timeout {
            task_id: "test".to_string(),
            timeout_seconds: 30,
        };
        assert!(policy.is_transient_error(&error));
    }

    #[test]
    fn test_connection_error_is_transient() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_execution_error("Connection refused")));
        assert!(policy.is_transient_error(&make_execution_error("network unreachable")));
        assert!(policy.is_transient_error(&make_execution_error(
            "service temporarily unavailable"
        )));
        assert!(policy.is_transient_error(&make_execution_error("server busy")));
        assert!(policy.is_transient_error(&make_execution_error("overloaded")));
        assert!(policy.is_transient_error(&make_execution_error("rate limit exceeded")));
    }

    #[test]
    fn test_unknown_error_with_transient_message_is_transient() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_unknown_error("Connection reset by peer")));
        assert!(policy.is_transient_error(&make_unknown_error("TIMEOUT waiting for response")));
    }

    #[test]
    fn test_permanent_errors_are_not_transient() {
        let policy = RetryPolicy::default();
        assert!(!policy.is_transient_error(&make_execution_error("invalid input format")));
        assert!(!policy.is_transient_error(&make_execution_error("permission denied")));
        assert!(!policy.is_transient_error(&make_unknown_error("null pointer")));
    }

    #[test]
    fn test_non_retryable_error_variants_are_not_transient() {
        let policy = RetryPolicy::default();
        assert!(!policy.is_transient_error(&TaskError::ContextError {
            task_id: "t".to_string(),
            error: crate::error::ContextError::KeyNotFound("k".to_string()),
        }));
        assert!(
            !policy.is_transient_error(&TaskError::DependencyNotSatisfied {
                dependency: "dep".to_string(),
                task_id: "t".to_string(),
            })
        );
        assert!(!policy.is_transient_error(&TaskError::ValidationFailed {
            message: "bad".to_string(),
        }));
        assert!(
            !policy.is_transient_error(&TaskError::ReadinessCheckFailed {
                task_id: "t".to_string(),
            })
        );
        assert!(!policy.is_transient_error(&TaskError::TriggerRuleFailed {
            task_id: "t".to_string(),
        }));
    }

    #[test]
    fn test_transient_pattern_matching_is_case_insensitive() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_execution_error("CONNECTION REFUSED")));
        assert!(policy.is_transient_error(&make_execution_error("Network Error")));
        assert!(policy.is_transient_error(&make_execution_error("TIMEOUT")));
    }

    #[test]
    fn test_should_retry_all_errors_within_limit() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::AllErrors)
            .build();

        let error = make_execution_error("anything");
        assert!(policy.should_retry(&error, 1));
        assert!(policy.should_retry(&error, 2));
        assert!(!policy.should_retry(&error, 3));
        assert!(!policy.should_retry(&error, 4));
    }

    #[test]
    fn test_should_retry_never_condition() {
        let policy = RetryPolicy::builder()
            .max_attempts(10)
            .retry_condition(RetryCondition::Never)
            .build();

        assert!(!policy.should_retry(&make_execution_error("anything"), 1));
    }

    #[test]
    fn test_should_retry_transient_only() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::TransientOnly)
            .build();

        assert!(policy.should_retry(&make_execution_error("connection refused"), 1));
        assert!(!policy.should_retry(&make_execution_error("invalid input"), 1));
    }

    #[test]
    fn test_should_retry_error_pattern() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::ErrorPattern {
                patterns: vec!["deadlock".to_string(), "lock timeout".to_string()],
            })
            .build();

        assert!(policy.should_retry(&make_execution_error("deadlock detected"), 1));
        assert!(policy.should_retry(&make_execution_error("Lock Timeout on table"), 1));
        assert!(!policy.should_retry(&make_execution_error("invalid input"), 1));
    }

    #[test]
    fn test_should_retry_if_any_condition_matches() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_conditions(vec![
                RetryCondition::Never,
                RetryCondition::ErrorPattern {
                    patterns: vec!["deadlock".to_string()],
                },
            ])
            .build();

        // `Never` does not veto other conditions; conditions are OR-ed together.
        assert!(policy.should_retry(&make_execution_error("deadlock detected"), 1));
        assert!(!policy.should_retry(&make_execution_error("invalid input"), 1));
    }

    #[test]
    fn test_should_retry_with_no_conditions_never_retries() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_conditions(Vec::new())
            .build();

        assert!(!policy.should_retry(&make_execution_error("connection refused"), 1));
    }

    #[test]
    fn test_should_retry_zero_max_attempts() {
        let policy = RetryPolicy::builder()
            .max_attempts(0)
            .retry_condition(RetryCondition::AllErrors)
            .build();

        assert!(!policy.should_retry(&make_execution_error("anything"), 0));
    }

    #[test]
    fn test_custom_backoff_falls_back_to_exponential() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Custom {
                function_name: "my_func".to_string(),
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(10))
            .with_jitter(true)
            .build();

        // Jitter is random; sample repeatedly to exercise the range.
        for _ in 0..100 {
            let delay = policy.calculate_delay(1);
            let millis = delay.as_millis();
            assert!(millis >= 7500, "jitter too low: {}ms", millis);
            assert!(millis <= 12500, "jitter too high: {}ms", millis);
        }
    }

    #[test]
    fn test_message_matches_transient_patterns_directly() {
        assert!(RetryPolicy::message_matches_transient_patterns(
            "connection reset"
        ));
        assert!(RetryPolicy::message_matches_transient_patterns(
            "NETWORK error"
        ));
        assert!(!RetryPolicy::message_matches_transient_patterns(
            "invalid input"
        ));
        assert!(!RetryPolicy::message_matches_transient_patterns(""));
    }
}