//! Retry helpers with exponential backoff (`mermaid_cli/utils/retry.rs`).

use anyhow::Result;
use std::time::Duration;

/// Exponential-backoff policy used by [`retry_async`] and [`retry_sync`].
///
/// Construct with [`RetryConfig::default`] or override selected fields with
/// struct-update syntax, e.g. `RetryConfig { max_attempts: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first try.
    pub max_attempts: usize,
    /// Delay, in milliseconds, before the first retry.
    pub initial_delay_ms: u64,
    /// Upper bound, in milliseconds, that the growing backoff delay is clamped to.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts; the delay starts at 100 ms and doubles, capped at 10 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
        }
    }
}

23pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
25where
26 F: Fn() -> Fut,
27 Fut: std::future::Future<Output = Result<T>>,
28{
29 let mut attempt = 0;
30 let mut delay_ms = config.initial_delay_ms;
31
32 loop {
33 attempt += 1;
34
35 match operation().await {
36 Ok(result) => return Ok(result),
37 Err(e) if attempt >= config.max_attempts => {
38 return Err(anyhow::anyhow!(
39 "Operation failed after {} attempts: {}",
40 config.max_attempts,
41 e
42 ));
43 },
44 Err(e) => {
45 eprintln!(
46 "[RETRY] Attempt {}/{} failed: {}. Retrying in {}ms...",
47 attempt, config.max_attempts, e, delay_ms
48 );
49
50 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
52
53 delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
55 delay_ms = delay_ms.min(config.max_delay_ms);
56 },
57 }
58 }
59}
60
61pub fn retry_sync<F, T>(operation: F, config: &RetryConfig) -> Result<T>
63where
64 F: Fn() -> Result<T>,
65{
66 let mut attempt = 0;
67 let mut delay_ms = config.initial_delay_ms;
68
69 loop {
70 attempt += 1;
71
72 match operation() {
73 Ok(result) => return Ok(result),
74 Err(e) if attempt >= config.max_attempts => {
75 return Err(anyhow::anyhow!(
76 "Operation failed after {} attempts: {}",
77 config.max_attempts,
78 e
79 ));
80 },
81 Err(e) => {
82 eprintln!(
83 "[RETRY] Attempt {}/{} failed: {}. Retrying in {}ms...",
84 attempt, config.max_attempts, e, delay_ms
85 );
86
87 std::thread::sleep(Duration::from_millis(delay_ms));
89
90 delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
92 delay_ms = delay_ms.min(config.max_delay_ms);
93 },
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three attempts with a 10 ms initial delay, so failing-path tests stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let config = fast_config();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        let err = result.expect_err("every attempt fails, so the retry must fail");
        let message = err.to_string();
        assert!(message.contains("3 attempts"), "unexpected message: {message}");
        assert!(message.contains("Persistent error"), "unexpected message: {message}");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_sync_success_on_first_try() {
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Ok::<_, anyhow::Error>(42)
            },
            &RetryConfig::default(),
        );

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_sync_success_on_second_try() {
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                let current = call_count.fetch_add(1, Ordering::SeqCst) + 1;
                if current < 2 {
                    Err(anyhow::anyhow!("Temporary error"))
                } else {
                    Ok(42)
                }
            },
            &fast_config(),
        );

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retry_sync_fails_after_max_attempts() {
        let call_count = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(anyhow::anyhow!("Persistent error"))
            },
            &fast_config(),
        );

        let err = result.expect_err("every attempt fails, so the retry must fail");
        let message = err.to_string();
        assert!(message.contains("3 attempts"), "unexpected message: {message}");
        assert!(message.contains("Persistent error"), "unexpected message: {message}");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}