msg_socket/connection/backoff.rs
1use futures::{FutureExt, Stream};
2use std::{
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7use tokio::time::sleep;
8
9use crate::ConnOptions;
10
11/// Helper trait alias for backoff streams.
12/// We define any stream that yields `Duration`s as a backoff
13pub trait Backoff: Stream<Item = Duration> + Unpin {}
14
15// Blanket implementation of `Backoff` for any stream that yields `Duration`s.
16impl<T> Backoff for T where T: Stream<Item = Duration> + Unpin {}
17
18/// A stream that yields exponentially increasing backoff durations.
19pub struct ExponentialBackoff {
20 /// Current number of retries.
21 retry_count: usize,
22 /// Maximum number of retries before closing the stream.
23 /// If `None`, the stream will retry indefinitely.
24 max_retries: Option<usize>,
25 /// The current backoff duration.
26 backoff: Duration,
27 /// The current backoff timeout, if any.
28 /// We need the timeout to be pinned (`Sleep` is not `Unpin`)
29 timeout: Option<Pin<Box<tokio::time::Sleep>>>,
30}
31
32impl ExponentialBackoff {
33 /// Creates a new exponential backoff stream with the given initial duration and max retries.
34 pub fn new(initial: Duration, max_retries: Option<usize>) -> Self {
35 Self { retry_count: 0, max_retries, backoff: initial, timeout: None }
36 }
37
38 /// (Re)-set the timeout to the current backoff duration.
39 fn reset_timeout(&mut self) {
40 self.timeout = Some(Box::pin(sleep(self.backoff)));
41 }
42}
43
44impl From<&ConnOptions> for ExponentialBackoff {
45 fn from(options: &ConnOptions) -> Self {
46 Self::new(options.backoff_duration, options.retry_attempts)
47 }
48}
49
50impl Stream for ExponentialBackoff {
51 type Item = Duration;
52
53 /// Polls the exponential backoff stream. Returns `Poll::Ready` with the current backoff
54 /// duration if the backoff timeout has elapsed, otherwise returns `Poll::Pending`.
55 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 let this = self.get_mut();
57
58 loop {
59 let Some(ref mut timeout) = this.timeout else {
60 // Set the initial timeout
61 this.reset_timeout();
62 continue;
63 };
64
65 if timeout.poll_unpin(cx).is_ready() {
66 // Timeout has elapsed, so reset the timeout and double the backoff
67 this.backoff *= 2;
68 this.retry_count += 1;
69
70 // Close the stream
71 if let Some(max_retries) = this.max_retries &&
72 this.retry_count >= max_retries
73 {
74 return Poll::Ready(None);
75 }
76
77 this.reset_timeout();
78
79 // Wake up the task to poll the timeout again
80 cx.waker().wake_by_ref();
81
82 // Return the current backoff duration
83 return Poll::Ready(Some(this.backoff));
84 } else {
85 // Timeout has not elapsed, so return pending
86 return Poll::Pending;
87 }
88 }
89 }
90}