Skip to main content

cloacina_workflow/
retry.rs

1/*
2 *  Copyright 2025-2026 Colliery Software
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *  Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17//! # Retry Policy System
18//!
19//! This module provides a comprehensive retry policy system for Cloacina tasks,
20//! including configurable backoff strategies, jitter, and conditional retry logic.
21//!
22//! ## Overview
23//!
24//! The retry system allows tasks to define sophisticated retry behavior:
25//! - **Configurable retry limits** with per-task policies
26//! - **Multiple backoff strategies** including exponential, linear, and custom
27//! - **Jitter support** to prevent thundering herd problems
28//! - **Conditional retries** based on error types and conditions
29//!
30//! ## Usage
31//!
32//! ```rust
33//! use cloacina_workflow::retry::{RetryPolicy, BackoffStrategy, RetryCondition};
34//! use std::time::Duration;
35//!
36//! let policy = RetryPolicy::builder()
37//!     .max_attempts(5)
38//!     .backoff_strategy(BackoffStrategy::Exponential {
39//!         base: 2.0,
40//!         multiplier: 1.0
41//!     })
42//!     .initial_delay(Duration::from_millis(100))
43//!     .max_delay(Duration::from_secs(30))
44//!     .with_jitter(true)
45//!     .retry_condition(RetryCondition::AllErrors)
46//!     .build();
47//! ```
48
49use crate::error::TaskError;
50use chrono::NaiveDateTime;
51use rand::Rng;
52use serde::{Deserialize, Serialize};
53use std::time::Duration;
54
55/// Comprehensive retry policy configuration for tasks.
56///
57/// This struct defines how a task should behave when it fails, including
58/// the number of retry attempts, backoff strategy, delays, and conditions
59/// under which retries should be attempted.
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
61pub struct RetryPolicy {
62    /// Maximum number of retry attempts (not including the initial attempt)
63    pub max_attempts: i32,
64
65    /// The backoff strategy to use for calculating delays between retries
66    pub backoff_strategy: BackoffStrategy,
67
68    /// Initial delay before the first retry attempt
69    pub initial_delay: Duration,
70
71    /// Maximum delay between retry attempts (caps exponential growth)
72    pub max_delay: Duration,
73
74    /// Whether to add random jitter to delays to prevent thundering herd
75    pub jitter: bool,
76
77    /// Conditions that determine whether a retry should be attempted
78    pub retry_conditions: Vec<RetryCondition>,
79}
80
81/// Different backoff strategies for calculating retry delays.
82///
83/// Each strategy defines how the delay between retry attempts should increase.
84/// The actual delay is calculated based on the attempt number and the strategy's parameters.
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86#[serde(tag = "type")]
87pub enum BackoffStrategy {
88    /// Fixed delay - same delay for every retry attempt
89    Fixed,
90
91    /// Linear backoff - delay increases linearly with each attempt
92    /// delay = initial_delay * attempt * multiplier
93    Linear {
94        /// Multiplier for linear growth (default: 1.0)
95        multiplier: f64,
96    },
97
98    /// Exponential backoff - delay increases exponentially with each attempt
99    /// delay = initial_delay * multiplier * (base ^ attempt)
100    Exponential {
101        /// Base for exponential growth (default: 2.0)
102        base: f64,
103        /// Multiplier for the exponential function (default: 1.0)
104        multiplier: f64,
105    },
106
107    /// Custom backoff function (reserved for future extensibility)
108    Custom {
109        /// Name of the custom function to use
110        function_name: String,
111    },
112}
113
114/// Conditions that determine whether a failed task should be retried.
115///
116/// These conditions are used to evaluate whether a task should be retried
117/// based on the type of error or specific error patterns.
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
119#[serde(tag = "type")]
120pub enum RetryCondition {
121    /// Retry on all errors (default behavior)
122    AllErrors,
123
124    /// Never retry (equivalent to max_attempts = 0)
125    Never,
126
127    /// Retry only for transient errors (network, timeout, etc.)
128    TransientOnly,
129
130    /// Retry only if error message contains any of the specified patterns
131    ErrorPattern { patterns: Vec<String> },
132}
133
134impl Default for RetryPolicy {
135    /// Creates a default retry policy with reasonable production settings.
136    ///
137    /// Default configuration:
138    /// - 3 retry attempts
139    /// - Exponential backoff (base 2.0, multiplier 1.0)
140    /// - 1 second initial delay
141    /// - 60 seconds maximum delay
142    /// - Jitter enabled
143    /// - Retry on all errors
144    fn default() -> Self {
145        Self {
146            max_attempts: 3,
147            backoff_strategy: BackoffStrategy::Exponential {
148                base: 2.0,
149                multiplier: 1.0,
150            },
151            initial_delay: Duration::from_secs(1),
152            max_delay: Duration::from_secs(60),
153            jitter: true,
154            retry_conditions: vec![RetryCondition::AllErrors],
155        }
156    }
157}
158
159impl RetryPolicy {
160    /// Creates a new RetryPolicyBuilder for fluent configuration.
161    pub fn builder() -> RetryPolicyBuilder {
162        RetryPolicyBuilder::new()
163    }
164
165    /// Calculates the delay before the next retry attempt.
166    ///
167    /// # Arguments
168    ///
169    /// * `attempt` - The current attempt number (1-based)
170    ///
171    /// # Returns
172    ///
173    /// The duration to wait before the next retry attempt.
174    pub fn calculate_delay(&self, attempt: i32) -> Duration {
175        let base_delay = match &self.backoff_strategy {
176            BackoffStrategy::Fixed => self.initial_delay,
177
178            BackoffStrategy::Linear { multiplier } => {
179                let millis = self.initial_delay.as_millis() as f64 * attempt as f64 * multiplier;
180                Duration::from_millis(millis as u64)
181            }
182
183            BackoffStrategy::Exponential { base, multiplier } => {
184                let millis =
185                    self.initial_delay.as_millis() as f64 * multiplier * base.powi(attempt - 1);
186                Duration::from_millis(millis as u64)
187            }
188
189            BackoffStrategy::Custom { .. } => {
190                // For now, fall back to exponential backoff for custom functions
191                let millis = self.initial_delay.as_millis() as f64 * 2.0_f64.powi(attempt - 1);
192                Duration::from_millis(millis as u64)
193            }
194        };
195
196        // Cap the delay at max_delay
197        let capped_delay = std::cmp::min(base_delay, self.max_delay);
198
199        // Add jitter if enabled
200        if self.jitter {
201            self.add_jitter(capped_delay)
202        } else {
203            capped_delay
204        }
205    }
206
207    /// Determines whether a retry should be attempted based on the error and retry conditions.
208    ///
209    /// # Arguments
210    ///
211    /// * `error` - The error that caused the task to fail
212    /// * `attempt` - The current attempt number
213    ///
214    /// # Returns
215    ///
216    /// `true` if the task should be retried, `false` otherwise.
217    pub fn should_retry(&self, error: &TaskError, attempt: i32) -> bool {
218        // Check if we've exceeded the maximum number of attempts
219        if attempt >= self.max_attempts {
220            return false;
221        }
222
223        // Check retry conditions
224        self.retry_conditions
225            .iter()
226            .any(|condition| match condition {
227                RetryCondition::AllErrors => true,
228                RetryCondition::Never => false,
229                RetryCondition::TransientOnly => self.is_transient_error(error),
230                RetryCondition::ErrorPattern { patterns } => {
231                    let error_msg = error.to_string().to_lowercase();
232                    patterns
233                        .iter()
234                        .any(|pattern| error_msg.contains(&pattern.to_lowercase()))
235                }
236            })
237    }
238
239    /// Calculates the absolute timestamp when the next retry should occur.
240    ///
241    /// # Arguments
242    ///
243    /// * `attempt` - The current attempt number
244    /// * `now` - The current timestamp
245    ///
246    /// # Returns
247    ///
248    /// A NaiveDateTime representing when the retry should be attempted.
249    pub fn calculate_retry_at(&self, attempt: i32, now: NaiveDateTime) -> NaiveDateTime {
250        let delay = self.calculate_delay(attempt);
251        now + chrono::Duration::from_std(delay).unwrap_or_default()
252    }
253
254    /// Adds random jitter to a delay to prevent thundering herd problems.
255    ///
256    /// Uses +/-25% jitter by default.
257    fn add_jitter(&self, delay: Duration) -> Duration {
258        let mut rng = rand::thread_rng();
259        let jitter_factor = rng.gen_range(0.75..=1.25); // +/-25% jitter
260        let jittered_millis = (delay.as_millis() as f64 * jitter_factor) as u64;
261        Duration::from_millis(jittered_millis)
262    }
263
264    /// Determines if an error is transient (network, timeout, temporary failures).
265    fn is_transient_error(&self, error: &TaskError) -> bool {
266        match error {
267            TaskError::Timeout { .. } => true,
268            TaskError::ExecutionFailed { message, .. } | TaskError::Unknown { message, .. } => {
269                Self::message_matches_transient_patterns(message)
270            }
271            _ => false,
272        }
273    }
274
275    /// Checks whether an error message contains any known transient error patterns.
276    fn message_matches_transient_patterns(message: &str) -> bool {
277        const TRANSIENT_PATTERNS: &[&str] = &[
278            "connection",
279            "network",
280            "timeout",
281            "temporary",
282            "unavailable",
283            "busy",
284            "overloaded",
285            "rate limit",
286        ];
287        let error_msg = message.to_lowercase();
288        TRANSIENT_PATTERNS
289            .iter()
290            .any(|pattern| error_msg.contains(pattern))
291    }
292}
293
294/// Builder for creating RetryPolicy instances with a fluent API.
295#[derive(Debug)]
296pub struct RetryPolicyBuilder {
297    policy: RetryPolicy,
298}
299
300impl RetryPolicyBuilder {
301    /// Creates a new RetryPolicyBuilder with default values.
302    pub fn new() -> Self {
303        Self {
304            policy: RetryPolicy::default(),
305        }
306    }
307
308    /// Sets the maximum number of retry attempts.
309    pub fn max_attempts(mut self, max_attempts: i32) -> Self {
310        self.policy.max_attempts = max_attempts;
311        self
312    }
313
314    /// Sets the backoff strategy.
315    pub fn backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
316        self.policy.backoff_strategy = strategy;
317        self
318    }
319
320    /// Sets the initial delay before the first retry.
321    pub fn initial_delay(mut self, delay: Duration) -> Self {
322        self.policy.initial_delay = delay;
323        self
324    }
325
326    /// Sets the maximum delay between retries.
327    pub fn max_delay(mut self, delay: Duration) -> Self {
328        self.policy.max_delay = delay;
329        self
330    }
331
332    /// Enables or disables jitter.
333    pub fn with_jitter(mut self, jitter: bool) -> Self {
334        self.policy.jitter = jitter;
335        self
336    }
337
338    /// Adds a retry condition.
339    pub fn retry_condition(mut self, condition: RetryCondition) -> Self {
340        self.policy.retry_conditions = vec![condition];
341        self
342    }
343
344    /// Adds multiple retry conditions.
345    pub fn retry_conditions(mut self, conditions: Vec<RetryCondition>) -> Self {
346        self.policy.retry_conditions = conditions;
347        self
348    }
349
350    /// Builds the RetryPolicy.
351    pub fn build(self) -> RetryPolicy {
352        self.policy
353    }
354}
355
356impl Default for RetryPolicyBuilder {
357    fn default() -> Self {
358        Self::new()
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    // --- shared fixtures ---

    /// Builds an `ExecutionFailed` error carrying `msg`.
    fn make_execution_error(msg: &str) -> TaskError {
        TaskError::ExecutionFailed {
            message: msg.to_string(),
            task_id: "test".to_string(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Builds an `Unknown` error carrying `msg`.
    fn make_unknown_error(msg: &str) -> TaskError {
        TaskError::Unknown {
            task_id: "test".to_string(),
            message: msg.to_string(),
        }
    }

    /// Builds a jitter-free policy so that computed delays are deterministic.
    fn jitter_free_policy(strategy: BackoffStrategy, initial: Duration) -> RetryPolicy {
        RetryPolicy::builder()
            .backoff_strategy(strategy)
            .initial_delay(initial)
            .with_jitter(false)
            .build()
    }

    /// Asserts the delay (in whole seconds) for attempts 1, 2, 3, ... in order.
    fn assert_delays_in_secs(policy: &RetryPolicy, expected_secs: &[u64]) {
        for (attempt, secs) in (1..).zip(expected_secs.iter().copied()) {
            assert_eq!(policy.calculate_delay(attempt), Duration::from_secs(secs));
        }
    }

    // --- construction tests ---

    #[test]
    fn test_default_retry_policy() {
        let policy = RetryPolicy::default();

        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_secs(1));
        assert_eq!(policy.max_delay, Duration::from_secs(60));
        assert!(policy.jitter);
        assert!(matches!(
            policy.backoff_strategy,
            BackoffStrategy::Exponential { .. }
        ));
    }

    #[test]
    fn test_retry_policy_builder() {
        let builder = RetryPolicy::builder()
            .max_attempts(5)
            .initial_delay(Duration::from_millis(500))
            .max_delay(Duration::from_secs(30))
            .with_jitter(false)
            .backoff_strategy(BackoffStrategy::Linear { multiplier: 1.5 })
            .retry_condition(RetryCondition::TransientOnly);
        let policy = builder.build();

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(30));
        assert!(!policy.jitter);
        assert_eq!(policy.retry_conditions, vec![RetryCondition::TransientOnly]);
    }

    // --- calculate_delay tests ---

    #[test]
    fn test_fixed_backoff_calculation() {
        let policy = jitter_free_policy(BackoffStrategy::Fixed, Duration::from_secs(2));
        assert_delays_in_secs(&policy, &[2, 2, 2]);
    }

    #[test]
    fn test_linear_backoff_calculation() {
        let policy = jitter_free_policy(
            BackoffStrategy::Linear { multiplier: 1.0 },
            Duration::from_secs(1),
        );
        assert_delays_in_secs(&policy, &[1, 2, 3]);
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let policy = jitter_free_policy(
            BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            },
            Duration::from_secs(1),
        );
        assert_delays_in_secs(&policy, &[1, 2, 4, 8]);
    }

    #[test]
    fn test_max_delay_capping() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Exponential {
                base: 2.0,
                multiplier: 1.0,
            })
            .initial_delay(Duration::from_secs(10))
            .max_delay(Duration::from_secs(15))
            .with_jitter(false)
            .build();

        // Attempts 2 and 3 would be 20s and 40s uncapped.
        assert_delays_in_secs(&policy, &[10, 15, 15]);
    }

    #[test]
    fn test_custom_backoff_falls_back_to_exponential() {
        let policy = jitter_free_policy(
            BackoffStrategy::Custom {
                function_name: "my_func".to_string(),
            },
            Duration::from_secs(1),
        );
        assert_delays_in_secs(&policy, &[1, 2, 4]);
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::builder()
            .backoff_strategy(BackoffStrategy::Fixed)
            .initial_delay(Duration::from_secs(10))
            .with_jitter(true)
            .build();

        // Sample repeatedly: every draw must stay inside the +/-25% window.
        let mut remaining = 100;
        while remaining > 0 {
            let millis = policy.calculate_delay(1).as_millis();
            assert!(millis >= 7500, "jitter too low: {}ms", millis);
            assert!(millis <= 12500, "jitter too high: {}ms", millis);
            remaining -= 1;
        }
    }

    // --- is_transient_error tests ---

    #[test]
    fn test_timeout_is_transient() {
        let policy = RetryPolicy::default();
        let timeout = TaskError::Timeout {
            task_id: "test".to_string(),
            timeout_seconds: 30,
        };
        assert!(policy.is_transient_error(&timeout));
    }

    #[test]
    fn test_connection_error_is_transient() {
        let policy = RetryPolicy::default();
        let transient_messages = [
            "Connection refused",
            "network unreachable",
            "service temporarily unavailable",
            "server busy",
            "overloaded",
            "rate limit exceeded",
        ];
        for message in transient_messages.iter() {
            assert!(policy.is_transient_error(&make_execution_error(message)));
        }
    }

    #[test]
    fn test_unknown_error_with_transient_message_is_transient() {
        let policy = RetryPolicy::default();
        for message in ["Connection reset by peer", "TIMEOUT waiting for response"].iter() {
            assert!(policy.is_transient_error(&make_unknown_error(message)));
        }
    }

    #[test]
    fn test_permanent_errors_are_not_transient() {
        let policy = RetryPolicy::default();
        let permanent = [
            make_execution_error("invalid input format"),
            make_execution_error("permission denied"),
            make_unknown_error("null pointer"),
        ];
        for error in permanent.iter() {
            assert!(!policy.is_transient_error(error));
        }
    }

    #[test]
    fn test_non_retryable_error_variants_are_not_transient() {
        let policy = RetryPolicy::default();
        let non_retryable = [
            TaskError::ContextError {
                task_id: "t".to_string(),
                error: crate::error::ContextError::KeyNotFound("k".to_string()),
            },
            TaskError::DependencyNotSatisfied {
                dependency: "dep".to_string(),
                task_id: "t".to_string(),
            },
            TaskError::ValidationFailed {
                message: "bad".to_string(),
            },
            TaskError::ReadinessCheckFailed {
                task_id: "t".to_string(),
            },
            TaskError::TriggerRuleFailed {
                task_id: "t".to_string(),
            },
        ];
        for error in non_retryable.iter() {
            assert!(!policy.is_transient_error(error));
        }
    }

    #[test]
    fn test_transient_pattern_matching_is_case_insensitive() {
        let policy = RetryPolicy::default();
        for message in ["CONNECTION REFUSED", "Network Error", "TIMEOUT"].iter() {
            assert!(policy.is_transient_error(&make_execution_error(message)));
        }
    }

    #[test]
    fn test_message_matches_transient_patterns_directly() {
        let expectations = [
            ("connection reset", true),
            ("NETWORK error", true),
            ("invalid input", false),
            ("", false),
        ];
        for (message, expected) in expectations.iter() {
            assert_eq!(
                RetryPolicy::message_matches_transient_patterns(message),
                *expected
            );
        }
    }

    // --- should_retry tests ---

    #[test]
    fn test_should_retry_all_errors_within_limit() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::AllErrors)
            .build();
        let error = make_execution_error("anything");

        // Attempts below the limit retry; at the limit (3) and beyond (4) they do not.
        let expectations = [(1, true), (2, true), (3, false), (4, false)];
        for (attempt, expected) in expectations.iter() {
            assert_eq!(policy.should_retry(&error, *attempt), *expected);
        }
    }

    #[test]
    fn test_should_retry_never_condition() {
        let policy = RetryPolicy::builder()
            .max_attempts(10)
            .retry_condition(RetryCondition::Never)
            .build();
        let error = make_execution_error("anything");

        assert!(!policy.should_retry(&error, 1));
    }

    #[test]
    fn test_should_retry_transient_only() {
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::TransientOnly)
            .build();

        let transient = make_execution_error("connection refused");
        let permanent = make_execution_error("invalid input");
        assert!(policy.should_retry(&transient, 1));
        assert!(!policy.should_retry(&permanent, 1));
    }

    #[test]
    fn test_should_retry_error_pattern() {
        let patterns = vec!["deadlock".to_string(), "lock timeout".to_string()];
        let policy = RetryPolicy::builder()
            .max_attempts(3)
            .retry_condition(RetryCondition::ErrorPattern { patterns })
            .build();

        let expectations = [
            ("deadlock detected", true),
            ("Lock Timeout on table", true),
            ("invalid input", false),
        ];
        for (message, expected) in expectations.iter() {
            let error = make_execution_error(message);
            assert_eq!(policy.should_retry(&error, 1), *expected);
        }
    }

    #[test]
    fn test_should_retry_zero_max_attempts() {
        let policy = RetryPolicy::builder()
            .max_attempts(0)
            .retry_condition(RetryCondition::AllErrors)
            .build();
        let error = make_execution_error("anything");

        assert!(!policy.should_retry(&error, 0));
    }
}