Skip to main content

llm_core/
retry.rs

1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4
/// Fallback for `RetryConfig::max_retries`; shared by serde and `Default`.
fn default_max_retries() -> u32 {
    3
}

/// Fallback for `RetryConfig::base_delay_ms`: one second.
fn default_base_delay_ms() -> u64 {
    1_000
}

/// Fallback for `RetryConfig::max_delay_ms`: thirty seconds.
fn default_max_delay_ms() -> u64 {
    30 * 1_000
}

/// Fallback for `RetryConfig::jitter`: jitter is on unless disabled.
fn default_jitter() -> bool {
    true
}
17
18/// Configuration for retry with exponential backoff.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RetryConfig {
21    /// Maximum number of retries (default: 3).
22    #[serde(default = "default_max_retries")]
23    pub max_retries: u32,
24
25    /// Base delay in milliseconds (default: 1000).
26    #[serde(default = "default_base_delay_ms")]
27    pub base_delay_ms: u64,
28
29    /// Maximum delay in milliseconds (default: 30000).
30    #[serde(default = "default_max_delay_ms")]
31    pub max_delay_ms: u64,
32
33    /// Whether to add jitter to delays (default: true).
34    #[serde(default = "default_jitter")]
35    pub jitter: bool,
36}
37
38impl Default for RetryConfig {
39    fn default() -> Self {
40        Self {
41            max_retries: default_max_retries(),
42            base_delay_ms: default_base_delay_ms(),
43            max_delay_ms: default_max_delay_ms(),
44            jitter: default_jitter(),
45        }
46    }
47}
48
49impl RetryConfig {
50    /// Compute the delay for a given attempt (0-based).
51    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
52        let exp = self
53            .base_delay_ms
54            .saturating_mul(1u64.checked_shl(attempt).unwrap_or(u64::MAX));
55        let capped = exp.min(self.max_delay_ms);
56        if self.jitter {
57            Duration::from_millis(fastrand::u64(0..=capped))
58        } else {
59            Duration::from_millis(capped)
60        }
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.jitter);
    }

    #[test]
    fn delay_exponential_no_jitter() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(1000));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(2000));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(4000));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(8000));
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            jitter: false,
            max_delay_ms: 30_000,
            ..Default::default()
        };
        // 2^10 * 1000 = 1_024_000, but capped at 30_000
        assert_eq!(config.delay_for_attempt(10), Duration::from_millis(30_000));
    }

    #[test]
    fn delay_with_jitter_in_bounds() {
        let config = RetryConfig::default();
        for _ in 0..100 {
            let delay = config.delay_for_attempt(0);
            assert!(delay <= Duration::from_millis(1000));
        }
    }

    #[test]
    fn delay_with_jitter_respects_cap_for_large_attempts() {
        // Once the exponential term exceeds max_delay_ms, jittered delays must
        // still never exceed the cap — including for attempts that overflow.
        let config = RetryConfig::default();
        for attempt in [5, 10, 63, 64, u32::MAX] {
            for _ in 0..100 {
                assert!(config.delay_for_attempt(attempt) <= Duration::from_millis(30_000));
            }
        }
    }

    #[test]
    fn delay_attempt_zero() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(1000));
    }

    #[test]
    fn delay_saturates_instead_of_overflowing() {
        // With the cap lifted, the raw exponential value is observable.
        let config = RetryConfig {
            jitter: false,
            max_delay_ms: u64::MAX,
            ..Default::default()
        };
        // 1000 * 2^63 does not fit in u64: the product must saturate, not wrap or panic.
        assert_eq!(config.delay_for_attempt(63), Duration::from_millis(u64::MAX));
        // Exponents >= 64 exceed the width of u64 entirely.
        assert_eq!(config.delay_for_attempt(64), Duration::from_millis(u64::MAX));
        assert_eq!(
            config.delay_for_attempt(u32::MAX),
            Duration::from_millis(u64::MAX)
        );
    }

    #[test]
    fn huge_attempt_clamped_to_default_cap() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(
            config.delay_for_attempt(u32::MAX),
            Duration::from_millis(30_000)
        );
    }

    #[test]
    fn zero_base_delay_is_always_zero() {
        let config = RetryConfig {
            base_delay_ms: 0,
            jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(config.delay_for_attempt(100), Duration::ZERO);
    }

    #[test]
    fn serde_roundtrip() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 500,
            max_delay_ms: 10_000,
            jitter: false,
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: RetryConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.max_retries, 5);
        assert_eq!(parsed.base_delay_ms, 500);
        assert_eq!(parsed.max_delay_ms, 10_000);
        assert!(!parsed.jitter);
    }

    #[test]
    fn serde_defaults_from_empty() {
        let parsed: RetryConfig = toml::from_str("").unwrap();
        assert_eq!(parsed.max_retries, 3);
        assert_eq!(parsed.base_delay_ms, 1000);
        assert_eq!(parsed.max_delay_ms, 30_000);
        assert!(parsed.jitter);
    }

    #[test]
    fn serde_partial_fills_missing_fields() {
        // Only some keys present: the rest must fall back individually.
        let parsed: RetryConfig = toml::from_str("max_retries = 7\njitter = false").unwrap();
        assert_eq!(parsed.max_retries, 7);
        assert_eq!(parsed.base_delay_ms, 1000);
        assert_eq!(parsed.max_delay_ms, 30_000);
        assert!(!parsed.jitter);
    }
}