Skip to main content

astrid_core/
retry.rs

//! Retry utilities with exponential backoff.
//!
//! This module provides a configurable retry policy (an attempt limit plus
//! exponentially growing, optionally jittered delays) for transient failures,
//! commonly used for network operations and external service calls.

6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10/// Configuration for retry behavior with exponential backoff.
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct RetryConfig {
13    /// Maximum number of retry attempts (0 = no retries, just the initial attempt).
14    pub max_attempts: u32,
15    /// Initial delay before the first retry.
16    pub initial_delay: Duration,
17    /// Maximum delay between retries (caps the exponential growth).
18    pub max_delay: Duration,
19    /// Base for exponential backoff (typically 2.0).
20    pub exponential_base: f64,
21    /// Optional jitter factor (0.0 to 1.0) to randomize delays.
22    #[serde(default)]
23    pub jitter_factor: f64,
24}
25
26impl RetryConfig {
27    /// Creates a new retry configuration.
28    #[must_use]
29    pub fn new(
30        max_attempts: u32,
31        initial_delay: Duration,
32        max_delay: Duration,
33        exponential_base: f64,
34    ) -> Self {
35        Self {
36            max_attempts,
37            initial_delay,
38            max_delay,
39            exponential_base,
40            jitter_factor: 0.0,
41        }
42    }
43
44    /// Creates a configuration with no retries.
45    #[must_use]
46    #[allow(dead_code)]
47    pub(crate) const fn no_retry() -> Self {
48        Self {
49            max_attempts: 0,
50            initial_delay: Duration::ZERO,
51            max_delay: Duration::ZERO,
52            exponential_base: 2.0,
53            jitter_factor: 0.0,
54        }
55    }
56
57    /// Creates a configuration suitable for quick local operations.
58    #[must_use]
59    pub fn fast() -> Self {
60        Self::new(
61            3,
62            Duration::from_millis(10),
63            Duration::from_millis(100),
64            2.0,
65        )
66    }
67
68    /// Creates a configuration suitable for network operations.
69    #[must_use]
70    pub fn network() -> Self {
71        Self {
72            max_attempts: 5,
73            initial_delay: Duration::from_millis(100),
74            max_delay: Duration::from_secs(10),
75            exponential_base: 2.0,
76            jitter_factor: 0.1,
77        }
78    }
79
80    /// Creates a configuration suitable for external API calls.
81    #[must_use]
82    pub fn api() -> Self {
83        Self {
84            max_attempts: 3,
85            initial_delay: Duration::from_secs(1),
86            max_delay: Duration::from_secs(30),
87            exponential_base: 2.0,
88            jitter_factor: 0.2,
89        }
90    }
91
92    /// Calculates the delay for a given attempt number (0-indexed).
93    ///
94    /// Returns `Duration::ZERO` for attempt 0, then exponentially increasing
95    /// delays for subsequent attempts, capped at `max_delay`.
96    #[must_use]
97    #[expect(
98        clippy::cast_precision_loss,
99        clippy::cast_possible_truncation,
100        clippy::cast_sign_loss
101    )]
102    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
103        if attempt == 0 {
104            return Duration::ZERO;
105        }
106
107        // Calculate base delay with exponential backoff.
108        // Precision loss is acceptable for delay calculations.
109        let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);
110        let base_delay_ms =
111            self.initial_delay.as_millis() as f64 * self.exponential_base.powi(exponent);
112
113        let capped_delay_ms = base_delay_ms.min(self.max_delay.as_millis() as f64);
114
115        // Safe: delays are always positive and within reasonable bounds
116        Duration::from_millis(capped_delay_ms.max(0.0) as u64)
117    }
118
119    /// Calculates the delay for a given attempt with jitter applied.
120    ///
121    /// Jitter helps prevent thundering herd problems when many clients
122    /// retry simultaneously.
123    #[must_use]
124    #[expect(
125        clippy::cast_precision_loss,
126        clippy::cast_possible_truncation,
127        clippy::cast_sign_loss
128    )]
129    #[allow(dead_code)]
130    pub(crate) fn delay_for_attempt_with_jitter(
131        &self,
132        attempt: u32,
133        random_factor: f64,
134    ) -> Duration {
135        let base_delay = self.delay_for_attempt(attempt);
136
137        if self.jitter_factor <= 0.0 {
138            return base_delay;
139        }
140
141        // random_factor should be between 0.0 and 1.0
142        let random_factor = random_factor.clamp(0.0, 1.0);
143
144        // Apply jitter: delay * (1 - jitter_factor + 2 * jitter_factor * random)
145        // This gives a range of [delay * (1 - jitter), delay * (1 + jitter)]
146        let jitter_multiplier =
147            1.0 - self.jitter_factor + (2.0 * self.jitter_factor * random_factor);
148
149        let jittered_ms = base_delay.as_millis() as f64 * jitter_multiplier;
150
151        // Safe: jittered delays are always positive
152        Duration::from_millis(jittered_ms.max(0.0) as u64)
153    }
154
155    /// Returns true if more attempts are allowed given the current attempt count.
156    #[must_use]
157    #[allow(dead_code)]
158    pub(crate) fn should_retry(&self, current_attempt: u32) -> bool {
159        current_attempt < self.max_attempts
160    }
161
162    /// Returns an iterator over the delays for all retry attempts.
163    #[allow(dead_code)]
164    pub(crate) fn delays(&self) -> impl Iterator<Item = Duration> + '_ {
165        (1..=self.max_attempts).map(|attempt| self.delay_for_attempt(attempt))
166    }
167}
168
169impl Default for RetryConfig {
170    fn default() -> Self {
171        Self::network()
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two durations agree to within one microsecond, tolerating
    /// floating-point rounding in the jitter arithmetic.
    fn assert_close(actual: Duration, expected: Duration) {
        let diff = actual.max(expected) - actual.min(expected);
        assert!(
            diff <= Duration::from_micros(1),
            "expected {expected:?}, got {actual:?}"
        );
    }

    #[test]
    fn delay_calculation() {
        let config = RetryConfig::new(5, Duration::from_millis(100), Duration::from_secs(10), 2.0);

        assert_eq!(config.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(400));
        assert_eq!(config.delay_for_attempt(4), Duration::from_millis(800));
    }

    #[test]
    fn delay_caps_at_max() {
        let config = RetryConfig::new(
            10,
            Duration::from_millis(100),
            Duration::from_millis(500),
            2.0,
        );

        // Should cap at 500ms
        assert_eq!(config.delay_for_attempt(5), Duration::from_millis(500));
        assert_eq!(config.delay_for_attempt(10), Duration::from_millis(500));
    }

    #[test]
    fn should_retry_logic() {
        let config = RetryConfig::new(3, Duration::from_millis(100), Duration::from_secs(1), 2.0);

        assert!(config.should_retry(0));
        assert!(config.should_retry(1));
        assert!(config.should_retry(2));
        assert!(!config.should_retry(3));
        assert!(!config.should_retry(4));
    }

    #[test]
    fn no_retry_config() {
        let config = RetryConfig::no_retry();

        assert!(!config.should_retry(0));
        assert_eq!(config.delays().count(), 0);
    }

    #[test]
    fn delays_follow_per_attempt_schedule() {
        let config = RetryConfig::new(3, Duration::from_millis(100), Duration::from_secs(1), 2.0);

        let delays: Vec<Duration> = config.delays().collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_millis(100),
                Duration::from_millis(200),
                Duration::from_millis(400),
            ]
        );
    }

    #[test]
    fn jitter_application() {
        let config = RetryConfig::network();

        let base_delay = config.delay_for_attempt(1);
        assert_eq!(base_delay, Duration::from_millis(100));

        // With jitter_factor of 0.1 the extremes are base * 0.9 and base * 1.1.
        let jittered_low = config.delay_for_attempt_with_jitter(1, 0.0);
        let jittered_high = config.delay_for_attempt_with_jitter(1, 1.0);
        assert!(jittered_low < base_delay);
        assert!(jittered_high > base_delay);
        assert_close(jittered_low, Duration::from_millis(90));
        assert_close(jittered_high, Duration::from_millis(110));
    }

    #[test]
    fn zero_jitter_leaves_delay_unchanged() {
        // `fast()` is built via `new`, which disables jitter.
        let config = RetryConfig::fast();

        assert_eq!(
            config.delay_for_attempt_with_jitter(2, 0.0),
            config.delay_for_attempt(2)
        );
        assert_eq!(
            config.delay_for_attempt_with_jitter(2, 1.0),
            config.delay_for_attempt(2)
        );
    }
}