mermaid_cli/utils/
retry.rs

1use anyhow::Result;
2use std::time::Duration;
3
/// Retry configuration shared by [`retry_async`] and [`retry_sync`].
///
/// Use [`RetryConfig::default`] together with struct-update syntax to override
/// individual fields.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of failed attempts (the first try included) after which the
    /// retry helpers stop and return an error.
    pub max_attempts: usize,
    /// Delay, in milliseconds, the backoff starts from.
    pub initial_delay_ms: u64,
    /// Cap, in milliseconds, applied to the delay as it grows.
    pub max_delay_ms: u64,
    /// Factor the current delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}
11
12impl Default for RetryConfig {
13    fn default() -> Self {
14        Self {
15            max_attempts: 3,
16            initial_delay_ms: 100,
17            max_delay_ms: 10_000,
18            backoff_multiplier: 2.0,
19        }
20    }
21}
22
23/// Retry an async operation with exponential backoff
24pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
25where
26    F: Fn() -> Fut,
27    Fut: std::future::Future<Output = Result<T>>,
28{
29    let mut attempt = 0;
30    let mut delay_ms = config.initial_delay_ms;
31
32    loop {
33        attempt += 1;
34
35        match operation().await {
36            Ok(result) => return Ok(result),
37            Err(e) if attempt >= config.max_attempts => {
38                return Err(anyhow::anyhow!(
39                    "Operation failed after {} attempts: {}",
40                    config.max_attempts,
41                    e
42                ));
43            },
44            Err(e) => {
45                eprintln!(
46                    "[RETRY] Attempt {}/{} failed: {}. Retrying in {}ms...",
47                    attempt, config.max_attempts, e, delay_ms
48                );
49
50                // Sleep with exponential backoff
51                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
52
53                // Calculate next delay
54                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
55                delay_ms = delay_ms.min(config.max_delay_ms);
56            },
57        }
58    }
59}
60
61/// Retry a synchronous operation with exponential backoff
62pub fn retry_sync<F, T>(operation: F, config: &RetryConfig) -> Result<T>
63where
64    F: Fn() -> Result<T>,
65{
66    let mut attempt = 0;
67    let mut delay_ms = config.initial_delay_ms;
68
69    loop {
70        attempt += 1;
71
72        match operation() {
73            Ok(result) => return Ok(result),
74            Err(e) if attempt >= config.max_attempts => {
75                return Err(anyhow::anyhow!(
76                    "Operation failed after {} attempts: {}",
77                    config.max_attempts,
78                    e
79                ));
80            },
81            Err(e) => {
82                eprintln!(
83                    "[RETRY] Attempt {}/{} failed: {}. Retrying in {}ms...",
84                    attempt, config.max_attempts, e, delay_ms
85                );
86
87                // Sleep with exponential backoff
88                std::thread::sleep(Duration::from_millis(delay_ms));
89
90                // Calculate next delay
91                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
92                delay_ms = delay_ms.min(config.max_delay_ms);
93            },
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three attempts with a 10 ms starting delay so retrying tests stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        assert!(result.is_err());
        // The final error names the attempt count and the last failure.
        assert_eq!(
            result.unwrap_err().to_string(),
            "Operation failed after 3 attempts: Persistent error"
        );
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_sync_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Ok::<_, anyhow::Error>(42)
            },
            &config,
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_sync_success_on_second_try() {
        let config = fast_config();
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                let current = call_count.fetch_add(1, Ordering::SeqCst) + 1;
                if current < 2 {
                    Err(anyhow::anyhow!("Temporary error"))
                } else {
                    Ok(42)
                }
            },
            &config,
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retry_sync_fails_after_max_attempts() {
        let config = fast_config();
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(anyhow::anyhow!("Persistent error"))
            },
            &config,
        );

        assert!(result.is_err());
        // The final error names the attempt count and the last failure.
        assert_eq!(
            result.unwrap_err().to_string(),
            "Operation failed after 3 attempts: Persistent error"
        );
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}
183}