1use async_trait::async_trait;
2use futures_lite::FutureExt as lite_ext;
3use futures_util::FutureExt;
4use std::error::Error;
5use std::fmt::{Debug, Display, Formatter};
6use std::future::Future;
7use std::time::Duration;
8use tracing::warn;
9
10use crate::timer::sleep;
11pub use delay::ExponentialBackoff;
12pub use delay::FibonacciBackoff;
13pub use delay::FixedDelay;
14
15#[async_trait]
17pub trait RetryExt: Future {
18 async fn timeout(self, timeout: Duration) -> Result<Self::Output, TimeoutError>;
33}
34
35#[async_trait]
36impl<F: Future + Send> RetryExt for F {
37 async fn timeout(self, timeout: Duration) -> Result<Self::Output, TimeoutError> {
38 self.map(Ok)
39 .or(async move {
40 let _ = sleep(timeout).await;
41 Err(TimeoutError)
42 })
43 .await
44 }
45}
46
/// Error produced by `RetryExt::timeout` when the timer elapses before the
/// wrapped future completes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TimeoutError;

impl Display for TimeoutError {
    /// Writes the bare type name, `TimeoutError`; formatting flags such as
    /// width or alignment are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("TimeoutError")
    }
}

/// `TimeoutError` carries no underlying cause, so the default `source()` of
/// `None` is kept.
impl Error for TimeoutError {}
57
/// Awaits one attempt and decides whether it is worth retrying.
///
/// Expands to an expression that calls `$function()` and awaits the result:
/// - on `Ok(output)` it returns `Ok(output)` from the enclosing function;
/// - on `Err(err)` where `$condition(&err)` is `false` it returns `Err(err)`
///   from the enclosing function;
/// - otherwise it evaluates to the retryable `err`.
///
/// Because it uses `.await` and `return`, it must be expanded inside an async
/// function whose return type equals the awaited `Result`.
macro_rules! poll_err {
    ($function:ident, $condition:ident) => {{
        match $function().await {
            // Retryable failure: hand it back to the caller's retry loop.
            Err(err) if $condition(&err) => err,
            // Success or terminal failure: leave the enclosing function as-is.
            settled => return settled,
        }
    }};
}
69
70pub fn retry<I, O, F, E, A>(retries: I, factory: A) -> impl Future<Output = Result<O, E>>
72where
73 I: IntoIterator<Item = Duration>,
74 A: FnMut() -> F,
75 F: Future<Output = Result<O, E>>,
76 E: Debug,
77{
78 retry_if(retries, factory, |_| true)
79}
80
81pub async fn retry_if<I, O, F, E, A, P>(retries: I, mut factory: A, condition: P) -> Result<O, E>
110where
111 I: IntoIterator<Item = Duration>,
112 A: FnMut() -> F,
113 F: Future<Output = Result<O, E>>,
114 P: Fn(&E) -> bool,
115 E: Debug,
116{
117 let mut err = poll_err!(factory, condition);
118 for delay_duration in retries.into_iter() {
119 cfg_if::cfg_if! {
120 if #[cfg(target_arch = "wasm32")] {
121 sleep(delay_duration).await.unwrap();
122 } else {
123 sleep(delay_duration).await;
124 }
125 }
126 warn!(?err, "retrying");
127 err = poll_err!(factory, condition);
128 }
129 Err(err)
130}
131
mod delay {
    //! Iterators of [`Duration`]s describing how long to wait between retry attempts.
    //!
    //! Every schedule here is an infinite iterator; bound the number of retries
    //! with [`Iterator::take`].

    use std::time::Duration;

    /// A schedule that yields the same delay forever.
    ///
    /// `FixedDelay::default()` yields a zero delay on every call.
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct FixedDelay {
        delay: Duration,
    }

    impl FixedDelay {
        /// Creates a schedule that always yields `delay`.
        pub fn new(delay: Duration) -> Self {
            FixedDelay { delay }
        }

        /// Creates a schedule that always yields `millis` milliseconds.
        pub fn from_millis(millis: u64) -> Self {
            FixedDelay::new(Duration::from_millis(millis))
        }

        /// Creates a schedule that always yields `secs` seconds.
        pub fn from_secs(secs: u64) -> Self {
            FixedDelay::new(Duration::from_secs(secs))
        }
    }

    impl Iterator for FixedDelay {
        type Item = Duration;

        fn next(&mut self) -> Option<Self::Item> {
            // Never exhausted: the same delay is handed out on every call.
            Some(self.delay)
        }
    }

    /// A schedule following the Fibonacci sequence scaled by the initial delay:
    /// `d, d, 2d, 3d, 5d, 8d, ...`.
    ///
    /// Sums saturate at [`Duration::MAX`]. Once the upcoming delay exceeds the
    /// cap set via [`FibonacciBackoff::max_delay`], the cap is yielded instead
    /// and the sequence stops advancing, so every later call also yields the cap.
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct FibonacciBackoff {
        /// Delay yielded by the next call (before the cap is applied).
        current: Duration,
        /// Delay that follows `current` in the sequence.
        next: Duration,
        /// Optional upper bound on any yielded delay.
        max_delay: Option<Duration>,
    }

    impl FibonacciBackoff {
        /// Starts the sequence at `initial_delay`, with no cap.
        pub fn new(initial_delay: Duration) -> Self {
            FibonacciBackoff {
                current: initial_delay,
                next: initial_delay,
                max_delay: None,
            }
        }

        /// Starts the sequence at `millis` milliseconds, with no cap.
        pub fn from_millis(millis: u64) -> Self {
            FibonacciBackoff::new(Duration::from_millis(millis))
        }

        /// Starts the sequence at `secs` seconds, with no cap.
        pub fn from_secs(secs: u64) -> Self {
            FibonacciBackoff::new(Duration::from_secs(secs))
        }

        /// Caps every yielded delay at `max_delay`.
        pub fn max_delay(self, max_delay: Duration) -> Self {
            FibonacciBackoff {
                max_delay: Some(max_delay),
                ..self
            }
        }
    }

    impl Iterator for FibonacciBackoff {
        type Item = Duration;

        fn next(&mut self) -> Option<Self::Item> {
            let emitted = self.current;
            // Past the cap the state is left untouched, so the cap repeats forever.
            if let Some(cap) = self.max_delay.filter(|cap| emitted > *cap) {
                return Some(cap);
            }
            // Shift the window (a, b) -> (b, a + b), pinning the sum at Duration::MAX.
            let following = self.current.saturating_add(self.next);
            self.current = std::mem::replace(&mut self.next, following);
            Some(emitted)
        }
    }

    /// A schedule whose delays are successive powers of a millisecond base:
    /// `b, b^2, b^3, ...` milliseconds, where `b` is the value passed to
    /// [`ExponentialBackoff::from_millis`].
    ///
    /// Multiplication saturates at `u64::MAX` milliseconds, so a base of `0` or
    /// `1` yields a constant delay. Once the upcoming delay exceeds the cap set via
    /// [`ExponentialBackoff::max_delay`], the cap is yielded on every later call.
    #[derive(Default, Clone, Debug, Eq, PartialEq)]
    pub struct ExponentialBackoff {
        /// Multiplier applied after every yielded delay.
        base_millis: u64,
        /// Delay, in milliseconds, yielded by the next call (before the cap).
        current_millis: u64,
        /// Optional upper bound on any yielded delay.
        max_delay: Option<Duration>,
    }

    impl ExponentialBackoff {
        /// Uses `millis` both as the first delay and as the growth base.
        pub fn from_millis(millis: u64) -> Self {
            ExponentialBackoff {
                base_millis: millis,
                current_millis: millis,
                max_delay: None,
            }
        }

        /// Caps every yielded delay at `max_delay`.
        pub fn max_delay(self, max_delay: Duration) -> Self {
            ExponentialBackoff {
                max_delay: Some(max_delay),
                ..self
            }
        }
    }

    impl Iterator for ExponentialBackoff {
        type Item = Duration;

        fn next(&mut self) -> Option<Self::Item> {
            let emitted = Duration::from_millis(self.current_millis);
            // Past the cap the state is left untouched, so the cap repeats forever.
            if let Some(cap) = self.max_delay.filter(|cap| emitted > *cap) {
                return Some(cap);
            }
            self.current_millis = self.current_millis.saturating_mul(self.base_millis);
            Some(emitted)
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;

        /// Collects the first `count` delays produced by `schedule`.
        fn first(schedule: impl Iterator<Item = Duration>, count: usize) -> Vec<Duration> {
            schedule.take(count).collect()
        }

        /// Converts a list of millisecond counts into durations.
        fn millis(values: &[u64]) -> Vec<Duration> {
            values.iter().copied().map(Duration::from_millis).collect()
        }

        /// Converts a list of second counts into durations.
        fn secs(values: &[u64]) -> Vec<Duration> {
            values.iter().copied().map(Duration::from_secs).collect()
        }

        #[test]
        fn test_fibonacci_series_starting_at_10() {
            let schedule = FibonacciBackoff::from_millis(10);
            assert_eq!(first(schedule, 6), millis(&[10, 10, 20, 30, 50, 80]));
        }

        #[test]
        fn test_fibonacci_saturates_at_maximum_value() {
            let schedule = FibonacciBackoff::from_millis(u64::MAX);
            assert_eq!(first(schedule, 2), millis(&[u64::MAX, u64::MAX]));
        }

        #[test]
        fn test_fibonacci_stops_increasing_at_max_delay() {
            let schedule = FibonacciBackoff::from_millis(10).max_delay(Duration::from_millis(50));
            assert_eq!(first(schedule, 6), millis(&[10, 10, 20, 30, 50, 50]));
        }

        #[test]
        fn test_fibonacci_returns_max_when_max_less_than_base() {
            let schedule = FibonacciBackoff::from_secs(20).max_delay(Duration::from_secs(10));
            assert_eq!(first(schedule, 2), secs(&[10, 10]));
        }

        #[test]
        fn test_exponential_some_exponential_base_10() {
            let schedule = ExponentialBackoff::from_millis(10);
            assert_eq!(first(schedule, 3), millis(&[10, 100, 1000]));
        }

        #[test]
        fn test_exponential_some_exponential_base_2() {
            let schedule = ExponentialBackoff::from_millis(2);
            assert_eq!(first(schedule, 3), millis(&[2, 4, 8]));
        }

        #[test]
        fn test_exponential_saturates_at_maximum_value() {
            let schedule = ExponentialBackoff::from_millis(u64::MAX - 1);
            assert_eq!(
                first(schedule, 3),
                millis(&[u64::MAX - 1, u64::MAX, u64::MAX])
            );
        }

        #[test]
        fn test_exponential_stops_increasing_at_max_delay() {
            let schedule = ExponentialBackoff::from_millis(2).max_delay(Duration::from_millis(4));
            assert_eq!(first(schedule, 3), millis(&[2, 4, 4]));
        }

        #[test]
        fn test_exponential_max_when_max_less_than_base() {
            let schedule = ExponentialBackoff::from_millis(20).max_delay(Duration::from_millis(10));
            assert_eq!(first(schedule, 2), millis(&[10, 10]));
        }
    }
}
370
#[cfg(test)]
mod test {
    use super::*;
    use std::io::ErrorKind;
    use std::time::Duration;
    use tracing::debug;

    // The error most scripted operations below fail with.
    fn not_found() -> std::io::Error {
        std::io::Error::from(ErrorKind::NotFound)
    }

    // Zero-length delays: one initial attempt plus one per delay, then the last error.
    #[fluvio_future::test]
    async fn test_fixed_retries_no_delay() {
        let mut calls = 0u8;
        let operation = || {
            let attempt = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", attempt);

                Err::<usize, std::io::Error>(not_found())
            }
        };
        let outcome = retry(FixedDelay::default().take(2), operation).await;
        assert!(matches!(outcome, Err(err) if err.kind() == ErrorKind::NotFound));
        assert_eq!(calls, 3);
    }

    // The outer timeout cuts the retry loop short before all ten delays elapse.
    #[fluvio_future::test]
    async fn test_fixed_retries_timeout() {
        let mut calls = 0u8;
        let operation = || {
            let attempt = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", attempt);
                Err::<usize, std::io::Error>(not_found())
            }
        };
        let outcome = retry(FixedDelay::from_millis(100).take(10), operation)
            .timeout(Duration::from_millis(300))
            .await;

        assert!(outcome.is_err());
        assert!(calls < 10);
    }

    // A condition that rejects every error stops after the first attempt.
    #[fluvio_future::test]
    async fn test_fixed_retries_not_retryable() {
        let mut calls = 0u8;
        let operation = || {
            let attempt = calls;
            calls += 1;
            async move {
                debug!("called retry#{}", attempt);
                Err::<usize, std::io::Error>(not_found())
            }
        };
        let outcome =
            retry_if(FixedDelay::from_millis(100).take(10), operation, |_| false).await;

        assert!(matches!(outcome, Err(err) if err.kind() == ErrorKind::NotFound));
        assert_eq!(calls, 1);
    }

    // Retries continue only while the error satisfies the condition.
    #[fluvio_future::test]
    async fn test_conditional_retry() {
        let mut calls = 0u8;
        let operation = || {
            calls += 1;
            let attempt = calls;
            async move {
                debug!("called retry#{}", attempt);
                let kind = if attempt < 2 {
                    ErrorKind::NotFound
                } else {
                    ErrorKind::AddrNotAvailable
                };
                Err::<usize, std::io::Error>(std::io::Error::from(kind))
            }
        };
        let condition = |err: &std::io::Error| err.kind() == ErrorKind::NotFound;
        let outcome = retry_if(FixedDelay::default().take(10), operation, condition).await;

        assert!(matches!(outcome, Err(err) if err.kind() == ErrorKind::AddrNotAvailable));
        assert_eq!(calls, 2);
    }
}