Skip to main content

msg_socket/connection/
backoff.rs

1use futures::{FutureExt, Stream};
2use std::{
3    pin::Pin,
4    task::{Context, Poll},
5    time::Duration,
6};
7use tokio::time::sleep;
8
9use crate::ConnOptions;
10
11/// Helper trait alias for backoff streams.
12/// We define any stream that yields `Duration`s as a backoff
13pub trait Backoff: Stream<Item = Duration> + Unpin {}
14
15// Blanket implementation of `Backoff` for any stream that yields `Duration`s.
16impl<T> Backoff for T where T: Stream<Item = Duration> + Unpin {}
17
18/// A stream that yields exponentially increasing backoff durations.
19pub struct ExponentialBackoff {
20    /// Current number of retries.
21    retry_count: usize,
22    /// Maximum number of retries before closing the stream.
23    /// If `None`, the stream will retry indefinitely.
24    max_retries: Option<usize>,
25    /// The current backoff duration.
26    backoff: Duration,
27    /// The current backoff timeout, if any.
28    /// We need the timeout to be pinned (`Sleep` is not `Unpin`)
29    timeout: Option<Pin<Box<tokio::time::Sleep>>>,
30}
31
32impl ExponentialBackoff {
33    /// Creates a new exponential backoff stream with the given initial duration and max retries.
34    pub fn new(initial: Duration, max_retries: Option<usize>) -> Self {
35        Self { retry_count: 0, max_retries, backoff: initial, timeout: None }
36    }
37
38    /// (Re)-set the timeout to the current backoff duration.
39    fn reset_timeout(&mut self) {
40        self.timeout = Some(Box::pin(sleep(self.backoff)));
41    }
42}
43
44impl From<&ConnOptions> for ExponentialBackoff {
45    fn from(options: &ConnOptions) -> Self {
46        Self::new(options.backoff_duration, options.retry_attempts)
47    }
48}
49
50impl Stream for ExponentialBackoff {
51    type Item = Duration;
52
53    /// Polls the exponential backoff stream. Returns `Poll::Ready` with the current backoff
54    /// duration if the backoff timeout has elapsed, otherwise returns `Poll::Pending`.
55    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        let this = self.get_mut();
57
58        loop {
59            let Some(ref mut timeout) = this.timeout else {
60                // Set the initial timeout
61                this.reset_timeout();
62                continue;
63            };
64
65            if timeout.poll_unpin(cx).is_ready() {
66                // Timeout has elapsed, so reset the timeout and double the backoff
67                this.backoff *= 2;
68                this.retry_count += 1;
69
70                // Close the stream
71                if let Some(max_retries) = this.max_retries &&
72                    this.retry_count >= max_retries
73                {
74                    return Poll::Ready(None);
75                }
76
77                this.reset_timeout();
78
79                // Wake up the task to poll the timeout again
80                cx.waker().wake_by_ref();
81
82                // Return the current backoff duration
83                return Poll::Ready(Some(this.backoff));
84            } else {
85                // Timeout has not elapsed, so return pending
86                return Poll::Pending;
87            }
88        }
89    }
90}