1use crate::prelude::*;
7use std::fmt;
8use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
12pub type RetryCallback = Arc<dyn Fn(usize, &LimitlessError, Duration) + Send + Sync>;
14
15#[derive(Clone)]
20pub struct RetryConfig {
21 pub status_codes: Vec<u16>,
23 pub max_retries: usize,
25 pub exponential_base: f64,
27 pub max_delay: Duration,
29 pub initial_delay: Duration,
31 pub on_retry: Option<RetryCallback>,
33}
34
35impl fmt::Debug for RetryConfig {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.debug_struct("RetryConfig")
38 .field("status_codes", &self.status_codes)
39 .field("max_retries", &self.max_retries)
40 .field("exponential_base", &self.exponential_base)
41 .field("max_delay", &self.max_delay)
42 .field("initial_delay", &self.initial_delay)
43 .field("has_on_retry", &self.on_retry.is_some())
44 .finish()
45 }
46}
47
48impl Default for RetryConfig {
49 fn default() -> Self {
50 Self {
51 status_codes: vec![429, 500, 502, 503, 504],
52 max_retries: 3,
53 exponential_base: 2.0,
54 max_delay: Duration::from_secs(60),
55 initial_delay: Duration::from_secs(1),
56 on_retry: None,
57 }
58 }
59}
60
61impl RetryConfig {
62 pub fn none() -> Self {
64 Self {
65 max_retries: 0,
66 ..Default::default()
67 }
68 }
69
70 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
72 let base = if self.exponential_base.is_finite() && self.exponential_base > 0.0 {
73 self.exponential_base
74 } else {
75 2.0
76 };
77
78 let exponent = attempt.min(63) as u32;
79 let seconds = self.initial_delay.as_secs_f64() * base.powi(exponent as i32);
80 let capped = seconds.min(self.max_delay.as_secs_f64());
81 Duration::from_secs_f64(if capped <= 0.0 { 0.001 } else { capped })
82 }
83
84 pub fn should_retry(&self, error: &LimitlessError) -> bool {
86 match error {
87 LimitlessError::RateLimited => self.status_codes.contains(&429),
88 LimitlessError::InternalServerError => self.status_codes.contains(&500),
89 LimitlessError::ServiceUnavailable => self.status_codes.contains(&503),
90 LimitlessError::StatusCode(code) => self.status_codes.contains(code),
91 LimitlessError::ReqError(err) => err.is_connect() || err.is_timeout(),
92 _ => false,
93 }
94 }
95
96 #[must_use]
98 pub fn with_on_retry<F>(mut self, callback: F) -> Self
99 where
100 F: Fn(usize, &LimitlessError, Duration) + Send + Sync + 'static,
101 {
102 self.on_retry = Some(Arc::new(callback));
103 self
104 }
105}
106
107pub async fn with_retry<T, F, Fut>(
120 config: RetryConfig,
121 mut operation: F,
122) -> Result<T, LimitlessError>
123where
124 F: FnMut() -> Fut,
125 Fut: Future<Output = Result<T, LimitlessError>>,
126{
127 let mut last_error = None;
128
129 for attempt in 0..=config.max_retries {
130 match operation().await {
131 Ok(value) => return Ok(value),
132 Err(err) => {
133 let retryable = config.should_retry(&err);
134 last_error = Some(err);
135
136 if !retryable || attempt == config.max_retries {
137 break;
138 }
139
140 let delay = config.delay_for_attempt(attempt);
141 if let Some(callback) = &config.on_retry {
142 if let Some(ref err) = last_error {
143 callback(attempt, err, delay);
144 }
145 }
146
147 log::warn!(
148 "Retrying request after failure (attempt {} of {})",
149 attempt + 1,
150 config.max_retries
151 );
152 tokio::time::sleep(delay).await;
153 }
154 }
155 }
156
157 Err(last_error.expect("retry loop always stores the last error"))
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn delay_grows_exponentially() {
        let config = RetryConfig::default();
        let d0 = config.delay_for_attempt(0);
        let d2 = config.delay_for_attempt(2);
        assert!(d2 > d0);
    }

    #[test]
    fn delay_clamps_to_max() {
        let config = RetryConfig {
            max_delay: Duration::from_secs(5),
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(100), Duration::from_secs(5));
    }

    #[test]
    fn none_disables_retries_but_keeps_defaults() {
        let config = RetryConfig::none();
        assert_eq!(config.max_retries, 0);
        assert_eq!(config.status_codes, RetryConfig::default().status_codes);
    }

    #[test]
    fn should_retry_follows_configured_status_codes() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&LimitlessError::RateLimited));
        assert!(config.should_retry(&LimitlessError::StatusCode(502)));
        assert!(!config.should_retry(&LimitlessError::StatusCode(404)));
        assert!(!config.should_retry(&LimitlessError::ValidationError("boom".into())));

        // Removing 429 from the list must stop rate-limit retries.
        let only_500 = RetryConfig {
            status_codes: vec![500],
            ..Default::default()
        };
        assert!(!only_500.should_retry(&LimitlessError::RateLimited));
    }

    #[tokio::test]
    async fn retries_and_eventually_succeeds() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let a = attempts.clone();

        let result = with_retry(
            RetryConfig {
                max_retries: 3,
                initial_delay: Duration::from_millis(1),
                ..Default::default()
            },
            move || {
                let a = a.clone();
                async move {
                    let attempt = a.fetch_add(1, Ordering::SeqCst);
                    if attempt < 2 {
                        Err(LimitlessError::RateLimited)
                    } else {
                        Ok("ok")
                    }
                }
            },
        )
        .await
        .unwrap();

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn does_not_retry_non_retryable_errors() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let a = attempts.clone();

        let err = with_retry(RetryConfig::default(), move || {
            let a = a.clone();
            async move {
                a.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(LimitlessError::ValidationError("boom".into()))
            }
        })
        .await
        .unwrap_err();

        assert!(matches!(err, LimitlessError::ValidationError(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn gives_up_after_max_retries_and_reports_each_retry() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let a = Arc::clone(&attempts);
        let seen = Arc::new(std::sync::Mutex::new(Vec::<usize>::new()));
        let recorder = Arc::clone(&seen);

        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(1),
            ..Default::default()
        }
        .with_on_retry(move |attempt, _err, _delay| {
            recorder.lock().unwrap().push(attempt);
        });

        let err = with_retry(config, move || {
            let a = Arc::clone(&a);
            async move {
                a.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(LimitlessError::RateLimited)
            }
        })
        .await
        .unwrap_err();

        // One initial attempt plus two retries; the hook fires only before retries.
        assert!(matches!(err, LimitlessError::RateLimited));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
        assert_eq!(*seen.lock().unwrap(), vec![0, 1]);
    }
}