use crate::error::TaskError;
use chrono::NaiveDateTime;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Configuration describing whether a failed task is retried and how long to
/// wait between attempts.
///
/// Construct one with [`RetryPolicy::builder`] or start from
/// [`RetryPolicy::default`] (3 attempts, exponential backoff with base 2
/// starting at 1s, capped at 60s, jitter on, every error retryable).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RetryPolicy {
    /// Attempt number at which retrying stops: `should_retry` returns `false`
    /// once `attempt >= max_attempts`.
    pub max_attempts: i32,
    /// How the delay grows from one attempt to the next.
    pub backoff_strategy: BackoffStrategy,
    /// Base delay that every backoff strategy scales from.
    pub initial_delay: Duration,
    /// Cap applied to the computed backoff delay (see `calculate_delay` for
    /// how it interacts with jitter).
    pub max_delay: Duration,
    /// When `true`, the delay is randomly scaled by a factor in `[0.75, 1.25]`.
    pub jitter: bool,
    /// Conditions under which an error is retried; a retry happens if any one
    /// of them matches.
    pub retry_conditions: Vec<RetryCondition>,
}
/// Growth pattern for the delay between retries, expressed relative to
/// `RetryPolicy::initial_delay`.
///
/// Serialized as an internally tagged object whose `"type"` field names the
/// variant, e.g. `{"type": "Linear", "multiplier": 1.5}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum BackoffStrategy {
    /// Always wait exactly `initial_delay`.
    Fixed,
    /// Wait `initial_delay * attempt * multiplier`.
    Linear {
        /// Factor applied on top of the attempt number.
        multiplier: f64,
    },
    /// Wait `initial_delay * multiplier * base^(attempt - 1)`.
    Exponential {
        /// Growth factor per attempt.
        base: f64,
        /// Constant factor applied to every delay.
        multiplier: f64,
    },
    /// Named custom strategy. `RetryPolicy::calculate_delay` does not resolve
    /// `function_name`; it falls back to doubling
    /// (`initial_delay * 2^(attempt - 1)`).
    Custom {
        /// Identifier of the intended custom backoff function.
        function_name: String,
    },
}
/// Predicate deciding whether a particular error is worth retrying.
///
/// A policy's conditions are combined with logical OR by
/// `RetryPolicy::should_retry`. Serialized with a `"type"` tag naming the
/// variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum RetryCondition {
    /// Matches every error.
    AllErrors,
    /// Never matches. Because conditions are OR-ed, this does not veto other
    /// conditions in the same list.
    Never,
    /// Matches errors classified as transient (timeouts, and execution/unknown
    /// errors whose message contains a connection/network/rate-limit style
    /// keyword).
    TransientOnly,
    /// Matches when the error's display text contains any of `patterns`
    /// (case-insensitive substring match; an empty pattern matches everything).
    ErrorPattern { patterns: Vec<String> },
}
impl Default for RetryPolicy {
    /// Three attempts with exponential backoff (base 2, multiplier 1) starting
    /// at one second and capped at one minute, with jitter enabled and every
    /// error considered retryable.
    fn default() -> Self {
        let doubling = BackoffStrategy::Exponential {
            base: 2.0,
            multiplier: 1.0,
        };
        let one_second = Duration::from_secs(1);
        let one_minute = Duration::from_secs(60);
        Self {
            retry_conditions: vec![RetryCondition::AllErrors],
            jitter: true,
            max_delay: one_minute,
            initial_delay: one_second,
            backoff_strategy: doubling,
            max_attempts: 3,
        }
    }
}
impl RetryPolicy {
pub fn builder() -> RetryPolicyBuilder {
RetryPolicyBuilder::new()
}
pub fn calculate_delay(&self, attempt: i32) -> Duration {
let base_delay = match &self.backoff_strategy {
BackoffStrategy::Fixed => self.initial_delay,
BackoffStrategy::Linear { multiplier } => {
let millis = self.initial_delay.as_millis() as f64 * attempt as f64 * multiplier;
Duration::from_millis(millis as u64)
}
BackoffStrategy::Exponential { base, multiplier } => {
let millis =
self.initial_delay.as_millis() as f64 * multiplier * base.powi(attempt - 1);
Duration::from_millis(millis as u64)
}
BackoffStrategy::Custom { .. } => {
let millis = self.initial_delay.as_millis() as f64 * 2.0_f64.powi(attempt - 1);
Duration::from_millis(millis as u64)
}
};
let capped_delay = std::cmp::min(base_delay, self.max_delay);
if self.jitter {
self.add_jitter(capped_delay)
} else {
capped_delay
}
}
pub fn should_retry(&self, error: &TaskError, attempt: i32) -> bool {
if attempt >= self.max_attempts {
return false;
}
self.retry_conditions
.iter()
.any(|condition| match condition {
RetryCondition::AllErrors => true,
RetryCondition::Never => false,
RetryCondition::TransientOnly => self.is_transient_error(error),
RetryCondition::ErrorPattern { patterns } => {
let error_msg = error.to_string().to_lowercase();
patterns
.iter()
.any(|pattern| error_msg.contains(&pattern.to_lowercase()))
}
})
}
pub fn calculate_retry_at(&self, attempt: i32, now: NaiveDateTime) -> NaiveDateTime {
let delay = self.calculate_delay(attempt);
now + chrono::Duration::from_std(delay).unwrap_or_default()
}
fn add_jitter(&self, delay: Duration) -> Duration {
let mut rng = rand::thread_rng();
let jitter_factor = rng.gen_range(0.75..=1.25); let jittered_millis = (delay.as_millis() as f64 * jitter_factor) as u64;
Duration::from_millis(jittered_millis)
}
fn is_transient_error(&self, error: &TaskError) -> bool {
match error {
TaskError::Timeout { .. } => true,
TaskError::ExecutionFailed { message, .. } | TaskError::Unknown { message, .. } => {
Self::message_matches_transient_patterns(message)
}
_ => false,
}
}
fn message_matches_transient_patterns(message: &str) -> bool {
const TRANSIENT_PATTERNS: &[&str] = &[
"connection",
"network",
"timeout",
"temporary",
"unavailable",
"busy",
"overloaded",
"rate limit",
];
let error_msg = message.to_lowercase();
TRANSIENT_PATTERNS
.iter()
.any(|pattern| error_msg.contains(pattern))
}
}
/// Fluent builder for [`RetryPolicy`]. It starts from [`RetryPolicy::default`]
/// and each setter overrides one field.
#[derive(Debug)]
pub struct RetryPolicyBuilder {
    /// Policy under construction; returned by `build`.
    policy: RetryPolicy,
}
impl RetryPolicyBuilder {
pub fn new() -> Self {
Self {
policy: RetryPolicy::default(),
}
}
pub fn max_attempts(mut self, max_attempts: i32) -> Self {
self.policy.max_attempts = max_attempts;
self
}
pub fn backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
self.policy.backoff_strategy = strategy;
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.policy.initial_delay = delay;
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.policy.max_delay = delay;
self
}
pub fn with_jitter(mut self, jitter: bool) -> Self {
self.policy.jitter = jitter;
self
}
pub fn retry_condition(mut self, condition: RetryCondition) -> Self {
self.policy.retry_conditions = vec![condition];
self
}
pub fn retry_conditions(mut self, conditions: Vec<RetryCondition>) -> Self {
self.policy.retry_conditions = conditions;
self
}
pub fn build(self) -> RetryPolicy {
self.policy
}
}
impl Default for RetryPolicyBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_retry_policy() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_secs(1));
        assert_eq!(policy.max_delay, Duration::from_secs(60));
        assert!(policy.jitter);
        assert!(matches!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential { .. }
        ));
        assert_eq!(policy.retry_conditions, vec![RetryCondition::AllErrors]);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::builder()
            .max_attempts(5)
            .initial_delay(Duration::from_millis(500))
            .max_delay(Duration::from_secs(30))
            .with_jitter(false)
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.5 })
            .retry_condition(RetryCondition::TransientOnly)
            .build();
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(30));
        assert!(!policy.jitter);
        assert_eq!(
            policy.backoff_strategy,
            BackoffStrategy::Linear { multiplier: 1.5 }
        );
        assert_eq!(policy.retry_conditions, vec![RetryCondition::TransientOnly]);
    }

    #[test]
    fn test_fixed_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(2))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(2));
    }

    #[test]
    fn test_linear_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.0 })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(3));
    }

    #[test]
    fn test_linear_backoff_with_fractional_multiplier() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.5 })
            .initial_delay(Duration::from_millis(500))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_millis(750));
        assert_eq!(policy.calculate_delay(2), Duration::from_millis(1500));
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
        assert_eq!(policy.calculate_delay(4), Duration::from_secs(8));
    }

    #[test]
    fn test_max_delay_capping() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(10))
            .max_delay(Duration::from_secs(15))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(10));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(15));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(15));
    }

    #[test]
    fn test_huge_exponent_saturates_at_max_delay() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(1))
            .max_delay(Duration::from_secs(60))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(i32::MAX), Duration::from_secs(60));
    }

    #[test]
    fn test_calculate_retry_at_adds_delay() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(5))
            .with_jitter(false)
            .build();
        let now = chrono::NaiveDate::from_ymd_opt(2024, 1, 1)
            .unwrap()
            .and_hms_opt(12, 0, 0)
            .unwrap();
        let expected = now + chrono::Duration::from_std(Duration::from_secs(5)).unwrap();
        assert_eq!(policy.calculate_retry_at(1, now), expected);
    }

    fn make_execution_error(msg: &str) -> TaskError {
        TaskError::ExecutionFailed {
            message: msg.to_string(),
            task_id: "test".to_string(),
            timestamp: chrono::Utc::now(),
        }
    }

    fn make_unknown_error(msg: &str) -> TaskError {
        TaskError::Unknown {
            task_id: "test".to_string(),
            message: msg.to_string(),
        }
    }

    #[test]
    fn test_timeout_is_transient() {
        let policy = RetryPolicy::default();
        let error = TaskError::Timeout {
            task_id: "test".to_string(),
            timeout_seconds: 30,
        };
        assert!(policy.is_transient_error(&error));
    }

    #[test]
    fn test_connection_error_is_transient() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_execution_error("Connection refused")));
        assert!(policy.is_transient_error(&make_execution_error("network unreachable")));
        assert!(policy.is_transient_error(&make_execution_error(
            "service temporarily unavailable"
        )));
        assert!(policy.is_transient_error(&make_execution_error("server busy")));
        assert!(policy.is_transient_error(&make_execution_error("overloaded")));
        assert!(policy.is_transient_error(&make_execution_error("rate limit exceeded")));
    }

    #[test]
    fn test_unknown_error_with_transient_message_is_transient() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_unknown_error("Connection reset by peer")));
        assert!(policy.is_transient_error(&make_unknown_error("TIMEOUT waiting for response")));
    }

    #[test]
    fn test_permanent_errors_are_not_transient() {
        let policy = RetryPolicy::default();
        assert!(!policy.is_transient_error(&make_execution_error("invalid input format")));
        assert!(!policy.is_transient_error(&make_execution_error("permission denied")));
        assert!(!policy.is_transient_error(&make_unknown_error("null pointer")));
    }

    #[test]
    fn test_non_retryable_error_variants_are_not_transient() {
        let policy = RetryPolicy::default();
        assert!(!policy.is_transient_error(&TaskError::ContextError {
            task_id: "t".to_string(),
            error: crate::error::ContextError::KeyNotFound("k".to_string()),
        }));
        assert!(
            !policy.is_transient_error(&TaskError::DependencyNotSatisfied {
                dependency: "dep".to_string(),
                task_id: "t".to_string(),
            })
        );
        assert!(!policy.is_transient_error(&TaskError::ValidationFailed {
            message: "bad".to_string(),
        }));
        assert!(
            !policy.is_transient_error(&TaskError::ReadinessCheckFailed {
                task_id: "t".to_string(),
            })
        );
        assert!(!policy.is_transient_error(&TaskError::TriggerRuleFailed {
            task_id: "t".to_string(),
        }));
    }

    #[test]
    fn test_transient_pattern_matching_is_case_insensitive() {
        let policy = RetryPolicy::default();
        assert!(policy.is_transient_error(&make_execution_error("CONNECTION REFUSED")));
        assert!(policy.is_transient_error(&make_execution_error("Network Error")));
        assert!(policy.is_transient_error(&make_execution_error("TIMEOUT")));
    }

    #[test]
    fn test_should_retry_all_errors_within_limit() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::AllErrors)
            .build();
        let error = make_execution_error("anything");
        assert!(policy.should_retry(&error, 1));
        assert!(policy.should_retry(&error, 2));
        assert!(!policy.should_retry(&error, 3));
        assert!(!policy.should_retry(&error, 4));
    }

    #[test]
    fn test_should_retry_never_condition() {
        let policy = RetryPolicy::builder()
            .max_attempts(10)
            .retry_condition(RetryCondition::Never)
            .build();
        assert!(!policy.should_retry(&make_execution_error("anything"), 1));
    }

    #[test]
    fn test_should_retry_with_no_conditions_never_retries() {
        let policy = RetryPolicy::builder()
            .max_attempts(10)
            .retry_conditions(Vec::new())
            .build();
        assert!(!policy.should_retry(&make_execution_error("connection refused"), 1));
    }

    #[test]
    fn test_should_retry_transient_only() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::TransientOnly)
            .build();
        assert!(policy.should_retry(&make_execution_error("connection refused"), 1));
        assert!(!policy.should_retry(&make_execution_error("invalid input"), 1));
    }

    #[test]
    fn test_should_retry_error_pattern() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::ErrorPattern {
                patterns: vec!["deadlock".to_string(), "lock timeout".to_string()],
            })
            .build();
        assert!(policy.should_retry(&make_execution_error("deadlock detected"), 1));
        assert!(policy.should_retry(&make_execution_error("Lock Timeout on table"), 1));
        assert!(!policy.should_retry(&make_execution_error("invalid input"), 1));
    }

    #[test]
    fn test_should_retry_zero_max_attempts() {
        let policy = RetryPolicy::builder()
            .max_attempts(0)
            .retry_condition(RetryCondition::AllErrors)
            .build();
        assert!(!policy.should_retry(&make_execution_error("anything"), 0));
    }

    #[test]
    fn test_custom_backoff_falls_back_to_exponential() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Custom {
                function_name: "my_func".to_string(),
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(10))
            .with_jitter(true)
            .build();
        for _ in 0..100 {
            let delay = policy.calculate_delay(1);
            let millis = delay.as_millis();
            assert!(millis >= 7500, "jitter too low: {}ms", millis);
            assert!(millis <= 12500, "jitter too high: {}ms", millis);
        }
    }

    #[test]
    fn test_message_matches_transient_patterns_directly() {
        assert!(RetryPolicy::message_matches_transient_patterns(
            "connection reset"
        ));
        assert!(RetryPolicy::message_matches_transient_patterns(
            "NETWORK error"
        ));
        assert!(!RetryPolicy::message_matches_transient_patterns(
            "invalid input"
        ));
        assert!(!RetryPolicy::message_matches_transient_patterns(""));
    }
}