1use std::time::Duration;
49
/// Outcome of a retry decision.
///
/// NOTE(review): this enum is not consumed anywhere in the visible part of the file; the
/// variant descriptions below are inferred from their names — confirm with callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryAction {
    /// Retry, carrying the `Duration` to wait (presumably before the next attempt).
    ShouldRetry(Duration),
    /// Presumably: give up without retrying.
    NoRetry,
    /// Presumably: the failure is one that retrying cannot fix.
    PermanentFailure,
}
60
/// Strategy deciding whether a failed operation is retried, and after what delay.
///
/// Implementors must be `Send + Sync`.
pub trait RetryPolicy: Send + Sync {
    /// Returns the delay to wait for retry number `attempt` (the first retry is attempt `0`),
    /// or `None` when no further retry is granted.
    fn should_retry(&self, attempt: u32) -> Option<Duration>;

    /// Error-aware variant of [`RetryPolicy::should_retry`].
    ///
    /// The default implementation ignores the error text entirely and returns exactly what
    /// `should_retry(attempt)` returns. Override it to make decisions based on the error.
    fn should_retry_with_error(&self, attempt: u32, _error: &str) -> Option<Duration> {
        Self::should_retry(self, attempt)
    }

    /// Maximum number of retries this policy reports.
    ///
    /// The default implementation returns `u32::MAX`.
    fn max_retries(&self) -> u32 {
        u32::MAX
    }
}
90
/// Retry policy whose delay doubles on every attempt, capped at a maximum.
///
/// The delay for attempt `n` (the first retry is `n = 0`) is `initial_delay * 2^n`, limited to
/// `max_delay`. The cap is 30 seconds unless changed with
/// [`ExponentialBackoff::with_max_delay`]. [`ExponentialBackoff::with_jitter`] additionally
/// spreads each delay by ±10%.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Number of retries granted; the `RetryPolicy` impl refuses attempts `>= max_retries`.
    max_retries: u32,
    /// Delay for attempt 0.
    initial_delay: Duration,
    /// Cap applied to the exponentially grown delay, before jitter.
    max_delay: Duration,
    /// Whether each delay is randomized by ±10%.
    jitter: bool,
}

impl ExponentialBackoff {
    /// Creates a policy granting `max_retries` retries, starting at `initial_delay`.
    ///
    /// The maximum delay defaults to 30 seconds and jitter is disabled.
    pub fn new(max_retries: u32, initial_delay: Duration) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay: Duration::from_secs(30),
            jitter: false,
        }
    }

    /// Sets the cap applied to the exponentially growing delay.
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Enables ±10% random jitter on every computed delay.
    pub fn with_jitter(mut self) -> Self {
        self.jitter = true;
        self
    }

    /// Computes the delay, in milliseconds, for zero-based `attempt`.
    ///
    /// Durations are read at millisecond precision, so sub-millisecond parts are discarded.
    ///
    /// With jitter enabled, the result lies in `[0.9 * d, 1.1 * d]`, where `d` is the capped
    /// delay. The result may therefore exceed `max_delay` by up to 10%. The result is never
    /// NaN, and it is `0.0` whenever `initial_delay` is zero.
    fn calculate_delay(&self, attempt: u32) -> f64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};

        // 2^1023 is the largest power of two that is still a finite f64. Clamping the exponent:
        // - stops `attempt as i32` from wrapping negative for attempts >= 2^31, which would
        //   shrink the delay instead of capping it;
        // - keeps `0 ms * 2^n` at 0 instead of `0 * inf = NaN`.
        const MAX_EXPONENT: u32 = 1023;

        let base_delay_ms = self.initial_delay.as_millis() as f64;
        let max_delay_ms = self.max_delay.as_millis() as f64;
        // Lossless: the value is at most 1023.
        let exponent = attempt.min(MAX_EXPONENT) as i32;
        // Both factors are finite and non-negative, so the product is never NaN. An infinite
        // product is capped by `min` like any other oversized delay.
        let delay = (base_delay_ms * 2f64.powi(exponent)).min(max_delay_ms);

        if !self.jitter {
            return delay;
        }

        // Each `RandomState::new()` gets different SipHash keys, so hashing nothing yields a
        // cheap std-only pseudo-random u64. This is not cryptographic, which is fine for jitter.
        let random_bits = RandomState::new().build_hasher().finish();
        // Keep the top 53 bits (an f64 mantissa) for a uniform fraction in [0, 1).
        let unit = (random_bits >> 11) as f64 / (1u64 << 53) as f64;
        let jitter_range = delay * 0.2;
        delay - jitter_range / 2.0 + unit * jitter_range
    }
}
153
154impl RetryPolicy for ExponentialBackoff {
155 fn should_retry(&self, attempt: u32) -> Option<Duration> {
156 if attempt >= self.max_retries {
157 return None;
158 }
159 let delay_ms = self.calculate_delay(attempt);
160 Some(Duration::from_millis(delay_ms as u64))
161 }
162
163 fn max_retries(&self) -> u32 {
164 self.max_retries
165 }
166}
167
/// Retry policy that waits the same `delay` before each of up to `max_retries` retries.
#[derive(Debug, Clone)]
pub struct FixedDelay {
    /// Number of retries granted; the `RetryPolicy` impl refuses attempts `>= max_retries`.
    max_retries: u32,
    /// Wait returned for every granted attempt.
    delay: Duration,
}

impl FixedDelay {
    /// Builds a policy granting `max_retries` retries, each preceded by `delay`.
    pub fn new(max_retries: u32, delay: Duration) -> Self {
        FixedDelay {
            delay,
            max_retries,
        }
    }
}
181
182impl RetryPolicy for FixedDelay {
183 fn should_retry(&self, attempt: u32) -> Option<Duration> {
184 if attempt >= self.max_retries {
185 return None;
186 }
187 Some(self.delay)
188 }
189
190 fn max_retries(&self) -> u32 {
191 self.max_retries
192 }
193}
194
195#[derive(Debug, Clone, Copy, Default)]
197pub struct NoRetry;
198
199impl RetryPolicy for NoRetry {
200 fn should_retry(&self, _attempt: u32) -> Option<Duration> {
201 None
202 }
203
204 fn max_retries(&self) -> u32 {
205 0
206 }
207}
208
/// Wraps a policy so that error-aware retries are only granted for errors accepted by a
/// predicate.
///
/// `should_retry_with_error` consults the predicate first. The plain `should_retry` has no
/// error text, so it delegates straight to the wrapped policy.
pub struct TransientFilter<P> {
    /// Policy consulted for errors the predicate accepts.
    inner: P,
    /// Returns `true` for error messages that are worth retrying.
    predicate: Box<dyn Fn(&str) -> bool + Send + Sync>,
}

impl<P> TransientFilter<P> {
    /// Wraps `policy` so that only errors for which `predicate` returns `true` are retried.
    ///
    /// No `RetryPolicy` bound is needed just to construct the wrapper. The bound applies where
    /// the wrapper is used as a policy.
    pub fn new(policy: P, predicate: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
        Self {
            inner: policy,
            predicate: Box::new(predicate),
        }
    }
}

// Implemented by hand because the boxed predicate closure has no `Debug` impl; it is elided
// and shown as `..` in the output.
impl<P: std::fmt::Debug> std::fmt::Debug for TransientFilter<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TransientFilter")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
226
227impl<P: RetryPolicy> RetryPolicy for TransientFilter<P> {
228 fn should_retry(&self, attempt: u32) -> Option<Duration> {
229 self.inner.should_retry(attempt)
230 }
231
232 fn should_retry_with_error(&self, attempt: u32, error: &str) -> Option<Duration> {
233 if (self.predicate)(error) {
234 self.inner.should_retry_with_error(attempt, error)
235 } else {
236 None
237 }
238 }
239
240 fn max_retries(&self) -> u32 {
241 self.inner.max_retries()
242 }
243}
244
245pub trait RetryPolicyExt: RetryPolicy + Sized {
247 fn delays(&self) -> DelayIterator<'_, Self> {
249 DelayIterator {
250 policy: self,
251 attempt: 0,
252 }
253 }
254}
255
256impl<T: RetryPolicy + Sized> RetryPolicyExt for T {}
257
258#[derive(Debug)]
260pub struct DelayIterator<'a, P: RetryPolicy> {
261 policy: &'a P,
262 attempt: u32,
263}
264
265impl<P: RetryPolicy> Iterator for DelayIterator<'_, P> {
266 type Item = Duration;
267
268 fn next(&mut self) -> Option<Self::Item> {
269 let delay = self.policy.should_retry(self.attempt);
270 self.attempt += 1;
271 delay
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        let policy = ExponentialBackoff::new(3, Duration::from_millis(100));

        assert_eq!(policy.should_retry(0), Some(Duration::from_millis(100)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_millis(200)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_millis(400)));
        assert_eq!(policy.should_retry(3), None);
    }

    #[test]
    fn test_exponential_backoff_with_max_delay() {
        let policy = ExponentialBackoff::new(10, Duration::from_millis(100))
            .with_max_delay(Duration::from_millis(500));

        assert_eq!(policy.should_retry(0), Some(Duration::from_millis(100)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_millis(200)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_millis(400)));
        assert_eq!(policy.should_retry(3), Some(Duration::from_millis(500)));
        assert_eq!(policy.should_retry(4), Some(Duration::from_millis(500)));
    }

    #[test]
    fn test_exponential_backoff_jitter_within_bounds() {
        let policy = ExponentialBackoff::new(3, Duration::from_millis(1000)).with_jitter();

        for _ in 0..16 {
            let delay = policy.should_retry(0).expect("attempt 0 is within the retry budget");
            assert!(delay >= Duration::from_millis(900), "delay too short: {:?}", delay);
            assert!(delay <= Duration::from_millis(1100), "delay too long: {:?}", delay);
        }
        assert_eq!(policy.should_retry(3), None);
    }

    #[test]
    fn test_fixed_delay() {
        let policy = FixedDelay::new(3, Duration::from_secs(1));

        assert_eq!(policy.should_retry(0), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(1), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(2), Some(Duration::from_secs(1)));
        assert_eq!(policy.should_retry(3), None);
    }

    #[test]
    fn test_no_retry() {
        let policy = NoRetry;

        assert_eq!(policy.should_retry(0), None);
        assert_eq!(policy.should_retry(1), None);
        assert_eq!(policy.delays().next(), None);
    }

    #[test]
    fn test_max_retries() {
        assert_eq!(ExponentialBackoff::new(4, Duration::from_millis(1)).max_retries(), 4);
        assert_eq!(FixedDelay::new(2, Duration::from_millis(1)).max_retries(), 2);
        assert_eq!(NoRetry.max_retries(), 0);
    }

    #[test]
    fn test_default_should_retry_with_error_delegates() {
        let policy = FixedDelay::new(1, Duration::from_millis(50));

        assert_eq!(
            policy.should_retry_with_error(0, "anything"),
            Some(Duration::from_millis(50))
        );
        assert_eq!(policy.should_retry_with_error(1, "anything"), None);
    }

    #[test]
    fn test_transient_filter_only_retries_matching_errors() {
        let policy = TransientFilter::new(
            FixedDelay::new(2, Duration::from_secs(1)),
            |error: &str| error.contains("timeout"),
        );

        assert_eq!(
            policy.should_retry_with_error(0, "connection timeout"),
            Some(Duration::from_secs(1))
        );
        assert_eq!(policy.should_retry_with_error(0, "bad request"), None);
        // The predicate accepts the error, but the wrapped budget is exhausted.
        assert_eq!(policy.should_retry_with_error(2, "connection timeout"), None);
        // Without error text the predicate is bypassed.
        assert_eq!(policy.should_retry(1), Some(Duration::from_secs(1)));
        assert_eq!(policy.max_retries(), 2);
    }

    #[test]
    fn test_delay_iterator() {
        let policy = ExponentialBackoff::new(3, Duration::from_millis(100));

        let delays: Vec<_> = policy.delays().collect();
        assert_eq!(delays.len(), 3);
        assert_eq!(delays[0], Duration::from_millis(100));
        assert_eq!(delays[1], Duration::from_millis(200));
        assert_eq!(delays[2], Duration::from_millis(400));
    }
}
329}