1use std::time::Duration;
29use tokio::time::sleep;
30use tracing::{warn, info};
31use std::future::Future;
32
/// Exponential-backoff policy consumed by [`retry`].
///
/// After the first failure `retry` waits `initial_delay`; every later wait is the
/// previous one multiplied by `factor` and capped at `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound applied to the delay each time it is multiplied by `factor`.
    pub max_delay: Duration,
    /// Multiplier applied to the delay after every failed attempt.
    pub factor: f64,
    /// Maximum number of calls to the operation, *including* the first one
    /// (`Some(3)` means at most three calls, i.e. two retries).
    /// `None` keeps retrying until the operation succeeds.
    pub max_retries: Option<usize>,
}

impl Default for RetryConfig {
    /// Same as [`RetryConfig::daemon`]: unbounded attempts, backoff capped at 300 s.
    fn default() -> Self {
        Self::daemon()
    }
}

impl RetryConfig {
    /// Bounded preset: at most 5 attempts, 200 ms initial delay doubling up to 2 s.
    ///
    /// Judging by the name, intended for start-up initialization where a
    /// persistent failure should surface rather than block forever.
    #[must_use]
    pub fn startup() -> Self {
        Self {
            max_retries: Some(5),
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(2),
            factor: 2.0,
        }
    }

    /// Unbounded preset: retries forever, 1 s initial delay doubling up to 300 s.
    ///
    /// This is also the [`Default`] configuration.
    #[must_use]
    pub fn daemon() -> Self {
        Self {
            max_retries: None,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(300),
            factor: 2.0,
        }
    }

    /// Short bounded preset: at most 3 attempts, 100 ms initial delay doubling up to 2 s.
    #[must_use]
    pub fn query() -> Self {
        Self {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            factor: 2.0,
        }
    }

    /// Test-only preset: at most 3 attempts with millisecond-scale delays (1 ms up to 10 ms)
    /// so retry tests stay fast.
    #[cfg(test)]
    #[must_use]
    pub fn test() -> Self {
        Self {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            factor: 2.0,
        }
    }
}
103
104pub async fn retry<F, Fut, T, E>(
105 operation_name: &str,
106 config: &RetryConfig,
107 mut operation: F,
108) -> Result<T, E>
109where
110 F: FnMut() -> Fut,
111 Fut: Future<Output = Result<T, E>>,
112 E: std::fmt::Display,
113{
114 let mut delay = config.initial_delay;
115 let mut attempts = 0;
116
117 loop {
118 match operation().await {
119 Ok(val) => {
120 if attempts > 0 {
121 info!("Operation '{}' succeeded after {} retries", operation_name, attempts);
122 }
123 return Ok(val);
124 }
125 Err(err) => {
126 attempts += 1;
127
128 if let Some(max) = config.max_retries {
129 if attempts >= max {
130 return Err(err);
131 }
132 }
133
134 if config.max_retries.is_none() {
135 warn!(
137 "Operation '{}' failed (attempt {}, will retry forever): {}. Next retry in {:?}...",
138 operation_name, attempts, err, delay
139 );
140 } else {
141 warn!(
142 "Operation '{}' failed (attempt {}/{}): {}. Retrying in {:?}...",
143 operation_name, attempts, config.max_retries.unwrap(), err, delay
144 );
145 }
146
147 sleep(delay).await;
148 delay = (delay.mul_f64(config.factor)).min(config.max_delay);
149 }
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    /// Runs `retry` with an operation that fails its first `failures` calls
    /// (with error `"fail <n>"`) and then returns `Ok(42)`.
    ///
    /// Returns the result together with the total number of calls made.
    async fn run_failing(config: &RetryConfig, failures: usize) -> (Result<i32, TestError>, usize) {
        let calls = Arc::new(AtomicUsize::new(0));
        let result = retry("test_op", config, || {
            let calls = Arc::clone(&calls);
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
                if n <= failures {
                    Err(TestError(format!("fail {}", n)))
                } else {
                    Ok(42)
                }
            }
        })
        .await;
        (result, calls.load(Ordering::SeqCst))
    }

    /// Mirrors the backoff step used by `retry`: multiply by `factor`, cap at `max_delay`.
    fn backoff_step(delay: Duration, config: &RetryConfig) -> Duration {
        delay.mul_f64(config.factor).min(config.max_delay)
    }

    #[tokio::test]
    async fn test_retry_succeeds_first_try() {
        let (result, calls) = run_failing(&RetryConfig::test(), 0).await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let (result, calls) = run_failing(&RetryConfig::test(), 2).await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_exhausts_retries() {
        let config = RetryConfig {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            factor: 2.0,
        };

        let (result, calls) = run_failing(&config, usize::MAX).await;

        // `max_retries` bounds total calls, and the error returned is the last one.
        assert_eq!(result.unwrap_err().0, "fail 3");
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_single_attempt_budget() {
        let config = RetryConfig {
            max_retries: Some(1),
            ..RetryConfig::test()
        };

        let (result, calls) = run_failing(&config, usize::MAX).await;

        assert_eq!(result.unwrap_err().0, "fail 1");
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_config_presets() {
        assert_eq!(RetryConfig::startup().max_retries, Some(5));
        assert_eq!(RetryConfig::daemon().max_retries, None);
        assert_eq!(RetryConfig::query().max_retries, Some(3));

        let default = RetryConfig::default();
        let daemon = RetryConfig::daemon();
        assert_eq!(default.max_retries, daemon.max_retries);
        assert_eq!(default.initial_delay, daemon.initial_delay);
        assert_eq!(default.max_delay, daemon.max_delay);
    }

    #[test]
    fn test_delay_exponential_backoff() {
        let config = RetryConfig {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            factor: 2.0,
            max_retries: Some(5),
        };

        let mut delay = config.initial_delay;
        assert_eq!(delay, Duration::from_millis(100));

        delay = backoff_step(delay, &config);
        assert_eq!(delay, Duration::from_millis(200));

        delay = backoff_step(delay, &config);
        assert_eq!(delay, Duration::from_millis(400));
    }

    #[test]
    fn test_delay_caps_at_max() {
        let config = RetryConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            factor: 10.0,
            max_retries: Some(5),
        };

        let delay = backoff_step(config.initial_delay, &config);

        assert_eq!(delay, Duration::from_secs(5));
    }
}