1use std::future::Future;
10use std::time::Duration;
11
/// Tunable parameters for retrying a fallible operation with exponential backoff.
///
/// The delay before a retry starts at `initial_delay`, is multiplied by
/// `backoff_multiplier` once per earlier retry, and is capped at `max_delay`.
///
/// `PartialEq` is derived so configurations can be compared directly (for example in
/// tests); `Eq` is not available because of the `f64` field.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt; `0` disables retrying.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the computed backoff delay (applied before any jitter is added).
    pub max_delay: Duration,
    /// Factor by which the delay grows for each successive retry.
    pub backoff_multiplier: f64,
    /// Whether a random extra amount is added on top of the computed delay.
    pub jitter: bool,
}
26
27impl Default for RetryConfig {
28 fn default() -> Self {
29 Self {
30 max_retries: 3,
31 initial_delay: Duration::from_millis(100),
32 max_delay: Duration::from_secs(10),
33 backoff_multiplier: 2.0,
34 jitter: true,
35 }
36 }
37}
38
39impl RetryConfig {
40 #[must_use]
42 pub fn new(max_retries: u32) -> Self {
43 Self {
44 max_retries,
45 ..Default::default()
46 }
47 }
48
49 #[must_use]
51 pub fn none() -> Self {
52 Self {
53 max_retries: 0,
54 ..Default::default()
55 }
56 }
57
58 #[must_use]
60 pub fn quick() -> Self {
61 Self {
62 max_retries: 2,
63 initial_delay: Duration::from_millis(50),
64 max_delay: Duration::from_secs(1),
65 backoff_multiplier: 2.0,
66 jitter: true,
67 }
68 }
69
70 #[must_use]
72 pub fn batch() -> Self {
73 Self {
74 max_retries: 5,
75 initial_delay: Duration::from_millis(200),
76 max_delay: Duration::from_secs(30),
77 backoff_multiplier: 2.0,
78 jitter: true,
79 }
80 }
81
82 #[must_use]
84 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
85 self.max_retries = max_retries;
86 self
87 }
88
89 #[must_use]
91 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
92 self.initial_delay = delay;
93 self
94 }
95
96 #[must_use]
98 pub fn with_max_delay(mut self, delay: Duration) -> Self {
99 self.max_delay = delay;
100 self
101 }
102
103 #[must_use]
105 pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
106 self.backoff_multiplier = multiplier;
107 self
108 }
109
110 #[must_use]
112 pub fn with_jitter(mut self, jitter: bool) -> Self {
113 self.jitter = jitter;
114 self
115 }
116
117 fn delay_for_attempt(&self, attempt: u32) -> Duration {
129 let base_delay =
132 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
133 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
134
135 let final_delay = if self.jitter {
136 let jitter_factor = 1.0 + (random_f64() * 0.25);
138 capped_delay * jitter_factor
139 } else {
140 capped_delay
141 };
142
143 Duration::from_millis(final_delay as u64)
144 }
145}
146
/// Returns a random `f64` by delegating to `fastrand::f64`.
///
/// Kept as a separate function so the random source is called from a single place.
// NOTE(review): `fastrand::f64` is documented as sampling from `[0, 1)`; the jitter
// calculation presumably relies on that range — confirm against the pinned crate version.
fn random_f64() -> f64 {
    fastrand::f64()
}
151
/// Lets an error type tell the retry helpers whether, and how soon, an operation may be
/// attempted again.
pub trait RetryableError {
    /// Returns `true` when the failed operation may be retried.
    fn is_retryable(&self) -> bool;

    /// Optional explicit wait before the next attempt.
    ///
    /// The default implementation returns `None`. When an implementation returns
    /// `Some(delay)`, [`with_retry`] sleeps for that duration instead of the computed
    /// backoff delay.
    fn retry_after(&self) -> Option<Duration> {
        None
    }
}
162
/// Error returned by [`with_retry`] once it stops retrying, pairing the final error with
/// the number of attempts that were made.
#[derive(Debug)]
pub struct RetryError<E> {
    /// The error produced by the last attempt.
    pub error: E,
    /// Total number of attempts made, including the first one.
    pub attempts: u32,
}
171
172impl<E: std::fmt::Display> std::fmt::Display for RetryError<E> {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 write!(
175 f,
176 "{} (after {} attempt{})",
177 self.error,
178 self.attempts,
179 if self.attempts == 1 { "" } else { "s" }
180 )
181 }
182}
183
184impl<E: std::error::Error + 'static> std::error::Error for RetryError<E> {
185 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
186 Some(&self.error)
187 }
188}
189
190impl<E> RetryError<E> {
191 pub fn into_inner(self) -> E {
193 self.error
194 }
195}
196
197pub async fn with_retry<T, E, F, Fut>(
232 config: &RetryConfig,
233 mut operation: F,
234) -> Result<T, RetryError<E>>
235where
236 F: FnMut() -> Fut,
237 Fut: Future<Output = Result<T, E>>,
238 E: RetryableError,
239{
240 let mut attempts = 0;
241 let max_attempts = config.max_retries + 1;
242
243 loop {
244 attempts += 1;
245
246 match operation().await {
247 Ok(result) => return Ok(result),
248 Err(e) => {
249 if attempts >= max_attempts || !e.is_retryable() {
250 return Err(RetryError { error: e, attempts });
251 }
252
253 let delay = e
255 .retry_after()
256 .unwrap_or_else(|| config.delay_for_attempt(attempts - 1));
257 tokio::time::sleep(delay).await;
258 }
259 }
260 }
261}
262
263pub async fn with_simple_retry<T, E, F, Fut>(max_retries: u32, mut operation: F) -> Result<T, E>
282where
283 F: FnMut() -> Fut,
284 Fut: Future<Output = Result<T, E>>,
285{
286 let config = RetryConfig::new(max_retries);
287 let mut attempts = 0;
288 let max_attempts = config.max_retries + 1;
289
290 loop {
291 attempts += 1;
292
293 match operation().await {
294 Ok(result) => return Ok(result),
295 Err(e) => {
296 if attempts >= max_attempts {
297 return Err(e);
298 }
299
300 let delay = config.delay_for_attempt(attempts - 1);
301 tokio::time::sleep(delay).await;
302 }
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Without jitter, each successive attempt doubles the previous delay.
    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig {
            jitter: false,
            backoff_multiplier: 2.0,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            max_retries: 5,
        };

        let expected_millis = [100_u64, 200, 400, 800];
        for (attempt, millis) in (0_u32..).zip(expected_millis.iter().copied()) {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(millis)
            );
        }
    }

    /// Once the exponential value passes `max_delay`, the delay stays at the cap.
    #[test]
    fn test_delay_cap() {
        let ceiling = Duration::from_secs(5);
        let config = RetryConfig {
            jitter: false,
            backoff_multiplier: 2.0,
            initial_delay: Duration::from_secs(1),
            max_delay: ceiling,
            max_retries: 10,
        };

        for attempt in [5_u32, 10].iter().copied() {
            assert_eq!(config.delay_for_attempt(attempt), ceiling);
        }
    }

    /// The named presets expose the expected retry counts and initial delays.
    #[test]
    fn test_presets() {
        let RetryConfig {
            max_retries: quick_retries,
            initial_delay: quick_initial,
            ..
        } = RetryConfig::quick();
        assert_eq!(quick_retries, 2);
        assert_eq!(quick_initial, Duration::from_millis(50));

        let RetryConfig {
            max_retries: batch_retries,
            initial_delay: batch_initial,
            ..
        } = RetryConfig::batch();
        assert_eq!(batch_retries, 5);
        assert_eq!(batch_initial, Duration::from_millis(200));

        assert_eq!(RetryConfig::none().max_retries, 0);
    }
}
355}