cognate_providers/
retry.rs1use std::time::Duration;
7use cognate_core::{Error, Result};
8use futures::Future;
9
/// Backoff settings consumed by `with_retry`.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many extra attempts are allowed after the first call (0 means "try once").
    pub max_retries: u32,
    /// Backoff delay before the first retry (an error's `retry_after` hint overrides it).
    pub min_delay: Duration,
    /// Ceiling applied to the growing backoff delay.
    pub max_delay: Duration,
    /// Multiplier applied to the backoff delay after every retry.
    pub factor: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and doubling up to a 30 s ceiling.
    fn default() -> Self {
        let min_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        RetryConfig {
            max_retries: 3,
            min_delay,
            max_delay,
            factor: 2.0,
        }
    }
}
33
34pub async fn with_retry<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
36where
37 F: FnMut() -> Fut,
38 Fut: Future<Output = Result<T>>,
39{
40 let mut last_error = None;
41 let mut delay = config.min_delay;
42
43 for i in 0..=config.max_retries {
44 match f().await {
45 Ok(res) => return Ok(res),
46 Err(e) if e.is_retryable() && i < config.max_retries => {
47 let actual_delay = e.retry_after()
48 .map(Duration::from_secs)
49 .unwrap_or(delay);
50
51 tokio::time::sleep(actual_delay).await;
52
53 delay = Duration::from_secs_f64(
55 (delay.as_secs_f64() * config.factor).min(config.max_delay.as_secs_f64())
56 );
57 last_error = Some(e);
58 }
59 Err(e) => return Err(e),
60 }
61 }
62
63 Err(last_error.unwrap_or_else(|| Error::RetryExhausted(config.max_retries)))
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Config with millisecond-scale backoff so the suite does not spend real seconds
    /// sleeping (the default config waits 500 ms, then 1 s, between attempts).
    ///
    /// NOTE(review): this assumes `Error::Timeout` reports no `retry_after()` hint; if it
    /// does, that hint overrides these delays — confirm in cognate_core.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            min_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_success() {
        let config = fast_config(3);
        let counter = Arc::new(AtomicU32::new(0));

        let result = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                let val = counter.fetch_add(1, Ordering::SeqCst);
                if val < 2 {
                    Err(Error::Timeout(1))
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_failure() {
        let config = fast_config(2);
        let counter = Arc::new(AtomicU32::new(0));

        let result: Result<()> = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(Error::Timeout(1))
            }
        })
        .await;

        assert!(result.is_err());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_zero_retries_attempts_once() {
        let config = fast_config(0);
        let counter = Arc::new(AtomicU32::new(0));

        let result: Result<()> = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(Error::Timeout(1))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}