Skip to main content

mermaid_cli/utils/
retry.rs

1use anyhow::Result;
2use std::time::Duration;
3use tracing::debug;
4
/// Retry configuration shared by [`retry_async`] and [`retry_sync`].
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first call.
    pub max_attempts: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound on the delay between retries, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by after each retry sleep.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to a 10 s ceiling.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
        }
    }
}
23
24/// Retry an async operation with exponential backoff
25pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
26where
27    F: Fn() -> Fut,
28    Fut: std::future::Future<Output = Result<T>>,
29{
30    let mut attempt = 0;
31    let mut delay_ms = config.initial_delay_ms;
32
33    loop {
34        attempt += 1;
35
36        match operation().await {
37            Ok(result) => return Ok(result),
38            Err(e) if attempt >= config.max_attempts => {
39                return Err(anyhow::anyhow!(
40                    "Operation failed after {} attempts: {}",
41                    config.max_attempts,
42                    e
43                ));
44            },
45            Err(e) => {
46                debug!(
47                    attempt = attempt,
48                    max_attempts = config.max_attempts,
49                    delay_ms = delay_ms,
50                    "Retry attempt failed: {}",
51                    e
52                );
53
54                // Sleep with exponential backoff
55                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
56
57                // Calculate next delay
58                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
59                delay_ms = delay_ms.min(config.max_delay_ms);
60            },
61        }
62    }
63}
64
65/// Retry a synchronous operation with exponential backoff
66pub fn retry_sync<F, T>(operation: F, config: &RetryConfig) -> Result<T>
67where
68    F: Fn() -> Result<T>,
69{
70    let mut attempt = 0;
71    let mut delay_ms = config.initial_delay_ms;
72
73    loop {
74        attempt += 1;
75
76        match operation() {
77            Ok(result) => return Ok(result),
78            Err(e) if attempt >= config.max_attempts => {
79                return Err(anyhow::anyhow!(
80                    "Operation failed after {} attempts: {}",
81                    config.max_attempts,
82                    e
83                ));
84            },
85            Err(e) => {
86                debug!(
87                    attempt = attempt,
88                    max_attempts = config.max_attempts,
89                    delay_ms = delay_ms,
90                    "Retry attempt failed: {}",
91                    e
92                );
93
94                // Sleep with exponential backoff
95                std::thread::sleep(Duration::from_millis(delay_ms));
96
97                // Calculate next delay
98                delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
99                delay_ms = delay_ms.min(config.max_delay_ms);
100            },
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        };
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        };
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        assert!(result.is_err());
        // The final error must still carry the underlying failure message.
        assert!(result.unwrap_err().to_string().contains("Persistent error"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_sync_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Ok::<_, anyhow::Error>(42)
            },
            &config,
        );

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_sync_success_on_third_try() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 1,
            ..Default::default()
        };
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                let current = call_count.fetch_add(1, Ordering::SeqCst) + 1;
                if current < 3 {
                    Err(anyhow::anyhow!("Temporary error"))
                } else {
                    Ok(7)
                }
            },
            &config,
        );

        assert_eq!(result.unwrap(), 7);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_sync_fails_after_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            initial_delay_ms: 1,
            ..Default::default()
        };
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(anyhow::anyhow!("Persistent error"))
            },
            &config,
        );

        let err = result.unwrap_err();
        assert!(err.to_string().contains("Persistent error"));
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retry_sync_zero_max_attempts_still_runs_once() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_delay_ms: 1,
            ..Default::default()
        };
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(anyhow::anyhow!("boom"))
            },
            &config,
        );

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}