mermaid_cli/utils/
retry.rs1use anyhow::Result;
2use std::time::Duration;
3use tracing::debug;
4
5pub struct RetryConfig {
7 pub max_attempts: usize,
8 pub initial_delay_ms: u64,
9 pub max_delay_ms: u64,
10 pub backoff_multiplier: f64,
11}
12
13impl Default for RetryConfig {
14 fn default() -> Self {
15 Self {
16 max_attempts: 3,
17 initial_delay_ms: 100,
18 max_delay_ms: 10_000,
19 backoff_multiplier: 2.0,
20 }
21 }
22}
23
24pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
26where
27 F: Fn() -> Fut,
28 Fut: std::future::Future<Output = Result<T>>,
29{
30 let mut attempt = 0;
31 let mut delay_ms = config.initial_delay_ms;
32
33 loop {
34 attempt += 1;
35
36 match operation().await {
37 Ok(result) => return Ok(result),
38 Err(e) if attempt >= config.max_attempts => {
39 return Err(anyhow::anyhow!(
40 "Operation failed after {} attempts: {}",
41 config.max_attempts,
42 e
43 ));
44 },
45 Err(e) => {
46 debug!(
47 attempt = attempt,
48 max_attempts = config.max_attempts,
49 delay_ms = delay_ms,
50 "Retry attempt failed: {}",
51 e
52 );
53
54 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
56
57 delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
59 delay_ms = delay_ms.min(config.max_delay_ms);
60 },
61 }
62 }
63}
64
65pub fn retry_sync<F, T>(operation: F, config: &RetryConfig) -> Result<T>
67where
68 F: Fn() -> Result<T>,
69{
70 let mut attempt = 0;
71 let mut delay_ms = config.initial_delay_ms;
72
73 loop {
74 attempt += 1;
75
76 match operation() {
77 Ok(result) => return Ok(result),
78 Err(e) if attempt >= config.max_attempts => {
79 return Err(anyhow::anyhow!(
80 "Operation failed after {} attempts: {}",
81 config.max_attempts,
82 e
83 ));
84 },
85 Err(e) => {
86 debug!(
87 attempt = attempt,
88 max_attempts = config.max_attempts,
89 delay_ms = delay_ms,
90 "Retry attempt failed: {}",
91 e
92 );
93
94 std::thread::sleep(Duration::from_millis(delay_ms));
96
97 delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
99 delay_ms = delay_ms.min(config.max_delay_ms);
100 },
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Config with a short initial delay so the retry paths run quickly.
    fn fast_config(max_attempts: usize) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay_ms: 10,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let config = RetryConfig::default();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let config = fast_config(3);
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 2 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let config = fast_config(3);
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_async(
            move || {
                let count = Arc::clone(&call_count_clone);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Operation failed after 3 attempts"));
        // The last underlying error must be carried into the final message.
        assert!(message.contains("Persistent error"));
    }

    #[test]
    fn test_retry_sync_success_on_first_try() {
        let config = RetryConfig::default();
        let calls = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, anyhow::Error>(42)
            },
            &config,
        );

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_sync_success_on_second_try() {
        let config = fast_config(3);
        let calls = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                let current = calls.fetch_add(1, Ordering::SeqCst) + 1;
                if current < 2 {
                    Err(anyhow::anyhow!("Temporary error"))
                } else {
                    Ok(42)
                }
            },
            &config,
        );

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_retry_sync_fails_after_max_attempts() {
        let config = fast_config(3);
        let calls = AtomicUsize::new(0);

        let result = retry_sync(
            || {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(anyhow::anyhow!("Persistent error"))
            },
            &config,
        );

        assert_eq!(calls.load(Ordering::SeqCst), 3);
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Operation failed after 3 attempts"));
        assert!(message.contains("Persistent error"));
    }
}