Skip to main content

mermaid_cli/utils/
retry.rs

1use anyhow::Result;
2use std::time::Duration;
3use tracing::debug;
4
/// Retry configuration for [`retry_async`].
///
/// All delays are expressed in milliseconds. Start from `RetryConfig::default()`
/// and override individual fields with struct-update syntax when only a few
/// values need to change.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of attempts, including the first one (at least one attempt
    /// is always made).
    pub max_attempts: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Ceiling the delay is clamped to after each backoff step, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f64,
}
12
13impl Default for RetryConfig {
14    fn default() -> Self {
15        Self {
16            max_attempts: 3,
17            initial_delay_ms: 100,
18            max_delay_ms: 10_000,
19            backoff_multiplier: 2.0,
20        }
21    }
22}
23
24/// Retry an async operation with exponential backoff
25pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
26where
27    F: Fn() -> Fut,
28    Fut: std::future::Future<Output = Result<T>>,
29{
30    let mut attempt = 0;
31    let mut delay_ms = config.initial_delay_ms;
32
33    loop {
34        attempt += 1;
35
36        match operation().await {
37            Ok(result) => return Ok(result),
38            Err(e) if attempt >= config.max_attempts => {
39                return Err(anyhow::anyhow!(
40                    "Operation failed after {} attempts: {}",
41                    config.max_attempts,
42                    e
43                ));
44            },
45            Err(e) => {
46                debug!(
47                    attempt = attempt,
48                    max_attempts = config.max_attempts,
49                    delay_ms = delay_ms,
50                    "Retry attempt failed: {}",
51                    e
52                );
53
54                // Sleep with exponential backoff
55                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
56
57                // Calculate next delay
58                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
59                delay_ms = delay_ms.min(config.max_delay_ms);
60            },
61        }
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three attempts with a short initial delay so retrying tests stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        let err = result.expect_err("operation never succeeds, so retry_async must fail");
        let message = err.to_string();
        // The final error must report the attempt count and carry the last failure.
        assert!(
            message.starts_with("Operation failed after 3 attempts"),
            "unexpected error message: {}",
            message
        );
        assert!(
            message.contains("Persistent error"),
            "unexpected error message: {}",
            message
        );
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_async_single_attempt_does_not_retry() {
        let config = RetryConfig {
            max_attempts: 1,
            ..fast_config()
        };
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Single failure"))
                }
            },
            &config,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}