1use crate::error::TaskError;
50use chrono::NaiveDateTime;
51use rand::Rng;
52use serde::{Deserialize, Serialize};
53use std::time::Duration;
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
61pub struct RetryPolicy {
62 pub max_attempts: i32,
64
65 pub backoff_strategy: BackoffStrategy,
67
68 pub initial_delay: Duration,
70
71 pub max_delay: Duration,
73
74 pub jitter: bool,
76
77 pub retry_conditions: Vec<RetryCondition>,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86#[serde(tag = "type")]
87pub enum BackoffStrategy {
88 Fixed,
90
91 Linear {
94 multiplier: f64,
96 },
97
98 Exponential {
101 base: f64,
103 multiplier: f64,
105 },
106
107 Custom {
109 function_name: String,
111 },
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
119#[serde(tag = "type")]
120pub enum RetryCondition {
121 AllErrors,
123
124 Never,
126
127 TransientOnly,
129
130 ErrorPattern { patterns: Vec<String> },
132}
133
134impl Default for RetryPolicy {
135 fn default() -> Self {
145 Self {
146 max_attempts: 3,
147 backoff_strategy: BackoffStrategy::Exponential {
148 base: 2.0,
149 multiplier: 1.0,
150 },
151 initial_delay: Duration::from_secs(1),
152 max_delay: Duration::from_secs(60),
153 jitter: true,
154 retry_conditions: vec![RetryCondition::AllErrors],
155 }
156 }
157}
158
159impl RetryPolicy {
160 pub fn builder() -> RetryPolicyBuilder {
162 RetryPolicyBuilder::new()
163 }
164
165 pub fn calculate_delay(&self, attempt: i32) -> Duration {
175 let base_delay = match &self.backoff_strategy {
176 BackoffStrategy::Fixed => self.initial_delay,
177
178 BackoffStrategy::Linear { multiplier } => {
179 let millis = self.initial_delay.as_millis() as f64 * attempt as f64 * multiplier;
180 Duration::from_millis(millis as u64)
181 }
182
183 BackoffStrategy::Exponential { base, multiplier } => {
184 let millis =
185 self.initial_delay.as_millis() as f64 * multiplier * base.powi(attempt - 1);
186 Duration::from_millis(millis as u64)
187 }
188
189 BackoffStrategy::Custom { .. } => {
190 let millis = self.initial_delay.as_millis() as f64 * 2.0_f64.powi(attempt - 1);
192 Duration::from_millis(millis as u64)
193 }
194 };
195
196 let capped_delay = std::cmp::min(base_delay, self.max_delay);
198
199 if self.jitter {
201 self.add_jitter(capped_delay)
202 } else {
203 capped_delay
204 }
205 }
206
207 pub fn should_retry(&self, error: &TaskError, attempt: i32) -> bool {
218 if attempt >= self.max_attempts {
220 return false;
221 }
222
223 self.retry_conditions
225 .iter()
226 .any(|condition| match condition {
227 RetryCondition::AllErrors => true,
228 RetryCondition::Never => false,
229 RetryCondition::TransientOnly => self.is_transient_error(error),
230 RetryCondition::ErrorPattern { patterns } => {
231 let error_msg = error.to_string().to_lowercase();
232 patterns
233 .iter()
234 .any(|pattern| error_msg.contains(&pattern.to_lowercase()))
235 }
236 })
237 }
238
239 pub fn calculate_retry_at(&self, attempt: i32, now: NaiveDateTime) -> NaiveDateTime {
250 let delay = self.calculate_delay(attempt);
251 now + chrono::Duration::from_std(delay).unwrap_or_default()
252 }
253
254 fn add_jitter(&self, delay: Duration) -> Duration {
258 let mut rng = rand::thread_rng();
259 let jitter_factor = rng.gen_range(0.75..=1.25); let jittered_millis = (delay.as_millis() as f64 * jitter_factor) as u64;
261 Duration::from_millis(jittered_millis)
262 }
263
264 fn is_transient_error(&self, error: &TaskError) -> bool {
266 match error {
267 TaskError::Timeout { .. } => true,
268 TaskError::ExecutionFailed { message, .. } => {
269 let error_msg = message.to_lowercase();
271 let transient_patterns = [
272 "connection",
273 "network",
274 "timeout",
275 "temporary",
276 "unavailable",
277 "busy",
278 "overloaded",
279 "rate limit",
280 ];
281 transient_patterns
282 .iter()
283 .any(|pattern| error_msg.contains(pattern))
284 }
285 TaskError::Unknown { message, .. } => {
286 let error_msg = message.to_lowercase();
288 let transient_patterns = [
289 "connection",
290 "network",
291 "timeout",
292 "temporary",
293 "unavailable",
294 "busy",
295 "overloaded",
296 "rate limit",
297 ];
298 transient_patterns
299 .iter()
300 .any(|pattern| error_msg.contains(pattern))
301 }
302 TaskError::ContextError { .. } => false,
303 TaskError::DependencyNotSatisfied { .. } => false,
304 TaskError::ValidationFailed { .. } => false,
305 TaskError::ReadinessCheckFailed { .. } => false,
306 TaskError::TriggerRuleFailed { .. } => false,
307 }
308 }
309}
310
311#[derive(Debug)]
313pub struct RetryPolicyBuilder {
314 policy: RetryPolicy,
315}
316
317impl RetryPolicyBuilder {
318 pub fn new() -> Self {
320 Self {
321 policy: RetryPolicy::default(),
322 }
323 }
324
325 pub fn max_attempts(mut self, max_attempts: i32) -> Self {
327 self.policy.max_attempts = max_attempts;
328 self
329 }
330
331 pub fn backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
333 self.policy.backoff_strategy = strategy;
334 self
335 }
336
337 pub fn initial_delay(mut self, delay: Duration) -> Self {
339 self.policy.initial_delay = delay;
340 self
341 }
342
343 pub fn max_delay(mut self, delay: Duration) -> Self {
345 self.policy.max_delay = delay;
346 self
347 }
348
349 pub fn with_jitter(mut self, jitter: bool) -> Self {
351 self.policy.jitter = jitter;
352 self
353 }
354
355 pub fn retry_condition(mut self, condition: RetryCondition) -> Self {
357 self.policy.retry_conditions = vec![condition];
358 self
359 }
360
361 pub fn retry_conditions(mut self, conditions: Vec<RetryCondition>) -> Self {
363 self.policy.retry_conditions = conditions;
364 self
365 }
366
367 pub fn build(self) -> RetryPolicy {
369 self.policy
370 }
371}
372
373impl Default for RetryPolicyBuilder {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy exposes the documented stock values.
    #[test]
    fn test_default_retry_policy() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_secs(1));
        assert_eq!(policy.max_delay, Duration::from_secs(60));
        assert!(policy.jitter);
        assert!(matches!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential { .. }
        ));
        assert_eq!(policy.retry_conditions, vec![RetryCondition::AllErrors]);
    }

    /// Every builder setter is reflected in the built policy.
    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::builder()
            .max_attempts(5)
            .initial_delay(Duration::from_millis(500))
            .max_delay(Duration::from_secs(30))
            .with_jitter(false)
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.5 })
            .retry_condition(RetryCondition::TransientOnly)
            .build();

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(30));
        assert!(!policy.jitter);
        assert_eq!(
            policy.backoff_strategy,
            BackoffStrategy::Linear { multiplier: 1.5 }
        );
        assert_eq!(policy.retry_conditions, vec![RetryCondition::TransientOnly]);
    }

    /// Fixed backoff ignores the attempt number.
    #[test]
    fn test_fixed_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(2))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(2));
    }

    /// Linear backoff grows proportionally to the attempt number.
    #[test]
    fn test_linear_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.0 })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(3));
    }

    /// Exponential backoff doubles with each attempt for base 2.
    #[test]
    fn test_exponential_backoff_calculation() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
        assert_eq!(policy.calculate_delay(4), Duration::from_secs(8));
    }

    /// The custom strategy currently behaves like exponential backoff with base 2.
    #[test]
    fn test_custom_backoff_falls_back_to_doubling() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Custom {
                function_name: "anything".to_string(),
            })
            .initial_delay(Duration::from_secs(1))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(4));
    }

    /// Delays never exceed `max_delay` when jitter is disabled.
    #[test]
    fn test_max_delay_capping() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(10))
            .max_delay(Duration::from_secs(15))
            .with_jitter(false)
            .build();

        assert_eq!(policy.calculate_delay(1), Duration::from_secs(10));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(15));
        assert_eq!(policy.calculate_delay(3), Duration::from_secs(15));
    }

    /// Jitter keeps the delay within 75%..=125% of the un-jittered value.
    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(1))
            .with_jitter(true)
            .build();

        for _ in 0..100 {
            let delay = policy.calculate_delay(1);
            assert!(delay >= Duration::from_millis(750), "delay too short: {delay:?}");
            assert!(delay <= Duration::from_millis(1250), "delay too long: {delay:?}");
        }
    }

    /// The retry timestamp is `now` plus the calculated delay.
    #[test]
    fn test_calculate_retry_at_adds_delay() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(2))
            .with_jitter(false)
            .build();

        let now = chrono::NaiveDate::from_ymd_opt(2024, 1, 1)
            .expect("valid calendar date")
            .and_hms_opt(12, 0, 0)
            .expect("valid time of day");

        assert_eq!(
            policy.calculate_retry_at(1, now),
            now + chrono::Duration::seconds(2)
        );
    }
}