mermaid_cli/utils/
retry.rs1use anyhow::Result;
2use std::time::Duration;
3use tracing::debug;
4
/// Tuning knobs for the exponential backoff used by `retry_async`.
///
/// `RetryConfig::default()` gives three attempts, starting at 100 ms and doubling up to a
/// 10 s cap; override individual fields with struct-update syntax
/// (`RetryConfig { max_attempts: 5, ..Default::default() }`).
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of times the operation is run, including the first try.
    /// A value of `0` still results in a single attempt.
    pub max_attempts: usize,
    /// Starting delay, in milliseconds, before the first retry.
    pub initial_delay_ms: u64,
    /// Upper bound, in milliseconds, that grown delays are clamped to.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f64,
}
12
13impl Default for RetryConfig {
14 fn default() -> Self {
15 Self {
16 max_attempts: 3,
17 initial_delay_ms: 100,
18 max_delay_ms: 10_000,
19 backoff_multiplier: 2.0,
20 }
21 }
22}
23
24pub async fn retry_async<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
26where
27 F: Fn() -> Fut,
28 Fut: std::future::Future<Output = Result<T>>,
29{
30 let mut attempt = 0;
31 let mut delay_ms = config.initial_delay_ms;
32
33 loop {
34 attempt += 1;
35
36 match operation().await {
37 Ok(result) => return Ok(result),
38 Err(e) if attempt >= config.max_attempts => {
39 return Err(anyhow::anyhow!(
40 "Operation failed after {} attempts: {}",
41 config.max_attempts,
42 e
43 ));
44 },
45 Err(e) => {
46 debug!(
47 attempt = attempt,
48 max_attempts = config.max_attempts,
49 delay_ms = delay_ms,
50 "Retry attempt failed: {}",
51 e
52 );
53
54 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
56
57 delay_ms = ((delay_ms as f64) * config.backoff_multiplier) as u64;
59 delay_ms = delay_ms.min(config.max_delay_ms);
60 },
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three attempts with a 10 ms starting delay, so failing-path tests stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            initial_delay_ms: 10,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_async_success_on_first_try() {
        let calls = Arc::new(AtomicUsize::new(0));
        let op_calls = Arc::clone(&calls);
        let config = RetryConfig::default();

        let outcome = retry_async(
            move || {
                let calls = Arc::clone(&op_calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, anyhow::Error>(42)
                }
            },
            &config,
        )
        .await;

        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_on_second_try() {
        let calls = Arc::new(AtomicUsize::new(0));
        let op_calls = Arc::clone(&calls);
        let config = fast_config();

        let outcome = retry_async(
            move || {
                let calls = Arc::clone(&op_calls);
                async move {
                    // Only the very first invocation fails.
                    let previous = calls.fetch_add(1, Ordering::SeqCst);
                    if previous == 0 {
                        Err(anyhow::anyhow!("Temporary error"))
                    } else {
                        Ok(42)
                    }
                }
            },
            &config,
        )
        .await;

        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_async_fails_after_max_attempts() {
        let calls = Arc::new(AtomicUsize::new(0));
        let op_calls = Arc::clone(&calls);
        let config = fast_config();

        let outcome = retry_async(
            move || {
                let calls = Arc::clone(&op_calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(anyhow::anyhow!("Persistent error"))
                }
            },
            &config,
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}