distant_net/client/reconnect.rs

1use std::io;
2use std::time::Duration;
3
4use log::*;
5use strum::Display;
6use tokio::sync::watch;
7use tokio::task::JoinHandle;
8
9use super::Reconnectable;
10
11/// Represents a watcher over a [`ConnectionState`].
12#[derive(Clone)]
13pub struct ConnectionWatcher(pub(super) watch::Receiver<ConnectionState>);
14
15impl ConnectionWatcher {
16    /// Returns next [`ConnectionState`] after a change is detected, or `None` if no more changes
17    /// will be detected.
18    pub async fn next(&mut self) -> Option<ConnectionState> {
19        self.0.changed().await.ok()?;
20        Some(self.last())
21    }
22
23    /// Returns true if the connection state has changed.
24    pub fn has_changed(&self) -> bool {
25        self.0.has_changed().ok().unwrap_or(false)
26    }
27
28    /// Returns the last [`ConnectionState`] observed.
29    pub fn last(&self) -> ConnectionState {
30        *self.0.borrow()
31    }
32
33    /// Spawns a new task that continually monitors for connection state changes and invokes the
34    /// function `f` whenever a new change is detected.
35    pub fn on_change<F>(&self, mut f: F) -> JoinHandle<()>
36    where
37        F: FnMut(ConnectionState) + Send + 'static,
38    {
39        let rx = self.0.clone();
40        tokio::spawn(async move {
41            let mut watcher = Self(rx);
42            while let Some(state) = watcher.next().await {
43                f(state);
44            }
45        })
46    }
47}
48
49/// Represents the state of a connection.
50#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)]
51#[strum(serialize_all = "snake_case")]
52pub enum ConnectionState {
53    /// Connection is not active, but currently going through reconnection process.
54    Reconnecting,
55
56    /// Connection is active.
57    Connected,
58
59    /// Connection is not active.
60    Disconnected,
61}
62
63impl ConnectionState {
64    /// Returns true if reconnecting.
65    pub fn is_reconnecting(&self) -> bool {
66        matches!(self, Self::Reconnecting)
67    }
68
69    /// Returns true if connected.
70    pub fn is_connected(&self) -> bool {
71        matches!(self, Self::Connected)
72    }
73
74    /// Returns true if disconnected.
75    pub fn is_disconnected(&self) -> bool {
76        matches!(self, Self::Disconnected)
77    }
78}
79
/// Strategy applied when the client attempts to reconnect to the server.
///
/// Every variant other than [`ReconnectStrategy::Fail`] describes how long to wait between
/// attempts, how many attempts to make, and how long a single attempt may run.
#[derive(Debug, Clone)]
pub enum ReconnectStrategy {
    /// Gives up immediately whenever a reconnect is attempted.
    Fail,

    /// Waits between attempts using a delay that is multiplied by `factor` each time.
    ExponentialBackoff {
        /// Delay used before the first retry.
        base: Duration,

        /// Multiplier applied to the previous delay to produce the next one.
        factor: f64,

        /// Upper bound on the delay between attempts; `None` leaves it unbounded.
        max_duration: Option<Duration>,

        /// Maximum number of attempts before giving up; `None` retries without limit.
        max_retries: Option<usize>,

        /// Upper bound on a single reconnect attempt; `None` lets an attempt run indefinitely.
        timeout: Option<Duration>,
    },

    /// Waits between attempts using delays that follow the fibonacci series.
    FibonacciBackoff {
        /// Delay used before the first retry.
        base: Duration,

        /// Upper bound on the delay between attempts; `None` leaves it unbounded.
        max_duration: Option<Duration>,

        /// Maximum number of attempts before giving up; `None` retries without limit.
        max_retries: Option<usize>,

        /// Upper bound on a single reconnect attempt; `None` lets an attempt run indefinitely.
        timeout: Option<Duration>,
    },

    /// Waits the same amount of time between every attempt.
    FixedInterval {
        /// Delay between consecutive attempts.
        interval: Duration,

        /// Maximum number of attempts before giving up; `None` retries without limit.
        max_retries: Option<usize>,

        /// Upper bound on a single reconnect attempt; `None` lets an attempt run indefinitely.
        timeout: Option<Duration>,
    },
}
131
132impl Default for ReconnectStrategy {
133    /// Creates a reconnect strategy that will immediately fail.
134    fn default() -> Self {
135        Self::Fail
136    }
137}
138
139impl ReconnectStrategy {
140    pub async fn reconnect<T: Reconnectable>(&mut self, reconnectable: &mut T) -> io::Result<()> {
141        // If our strategy is to immediately fail, do so
142        if self.is_fail() {
143            return Err(io::Error::from(io::ErrorKind::ConnectionAborted));
144        }
145
146        // Keep track of last sleep length for use in adjustment
147        let mut previous_sleep = None;
148        let mut current_sleep = self.initial_sleep_duration();
149
150        // Keep track of remaining retries
151        let mut retries_remaining = self.max_retries();
152
153        // Get timeout if strategy will employ one
154        let timeout = self.timeout();
155
156        // Get maximum allowed duration between attempts
157        let max_duration = self.max_duration();
158
159        // Continue trying to reconnect while we have more tries remaining, otherwise
160        // we will return the last error encountered
161        let mut result = Ok(());
162
163        while retries_remaining.is_none() || retries_remaining > Some(0) {
164            // Perform reconnect attempt
165            result = match timeout {
166                Some(timeout) => {
167                    match tokio::time::timeout(timeout, reconnectable.reconnect()).await {
168                        Ok(x) => x,
169                        Err(x) => Err(x.into()),
170                    }
171                }
172                None => reconnectable.reconnect().await,
173            };
174
175            // If reconnect was successful, we're done and we can exit
176            match &result {
177                Ok(()) => return Ok(()),
178                Err(x) => {
179                    error!("Failed to reconnect: {x}");
180                }
181            }
182
183            // Decrement remaining retries if we have a limit
184            if let Some(remaining) = retries_remaining.as_mut() {
185                if *remaining > 0 {
186                    *remaining -= 1;
187                }
188            }
189
190            // Sleep before making next attempt
191            tokio::time::sleep(current_sleep).await;
192
193            // Update our sleep duration
194            let next_sleep = self.adjust_sleep(previous_sleep, current_sleep);
195            previous_sleep = Some(current_sleep);
196            current_sleep = if let Some(duration) = max_duration {
197                std::cmp::min(next_sleep, duration)
198            } else {
199                next_sleep
200            };
201        }
202
203        result
204    }
205
206    /// Returns true if this strategy is the fail variant.
207    pub fn is_fail(&self) -> bool {
208        matches!(self, Self::Fail)
209    }
210
211    /// Returns true if this strategy is the exponential backoff variant.
212    pub fn is_exponential_backoff(&self) -> bool {
213        matches!(self, Self::ExponentialBackoff { .. })
214    }
215
216    /// Returns true if this strategy is the fibonacci backoff variant.
217    pub fn is_fibonacci_backoff(&self) -> bool {
218        matches!(self, Self::FibonacciBackoff { .. })
219    }
220
221    /// Returns true if this strategy is the fixed interval variant.
222    pub fn is_fixed_interval(&self) -> bool {
223        matches!(self, Self::FixedInterval { .. })
224    }
225
226    /// Returns the maximum duration between reconnect attempts, or None if there is no limit.
227    pub fn max_duration(&self) -> Option<Duration> {
228        match self {
229            ReconnectStrategy::Fail => None,
230            ReconnectStrategy::ExponentialBackoff { max_duration, .. } => *max_duration,
231            ReconnectStrategy::FibonacciBackoff { max_duration, .. } => *max_duration,
232            ReconnectStrategy::FixedInterval { .. } => None,
233        }
234    }
235
236    /// Returns the maximum reconnect attempts the strategy will perform, or None if will attempt
237    /// forever.
238    pub fn max_retries(&self) -> Option<usize> {
239        match self {
240            ReconnectStrategy::Fail => None,
241            ReconnectStrategy::ExponentialBackoff { max_retries, .. } => *max_retries,
242            ReconnectStrategy::FibonacciBackoff { max_retries, .. } => *max_retries,
243            ReconnectStrategy::FixedInterval { max_retries, .. } => *max_retries,
244        }
245    }
246
247    /// Returns the timeout per reconnect attempt that is associated with the strategy.
248    pub fn timeout(&self) -> Option<Duration> {
249        match self {
250            ReconnectStrategy::Fail => None,
251            ReconnectStrategy::ExponentialBackoff { timeout, .. } => *timeout,
252            ReconnectStrategy::FibonacciBackoff { timeout, .. } => *timeout,
253            ReconnectStrategy::FixedInterval { timeout, .. } => *timeout,
254        }
255    }
256
257    /// Returns the initial duration to sleep.
258    fn initial_sleep_duration(&self) -> Duration {
259        match self {
260            ReconnectStrategy::Fail => Duration::new(0, 0),
261            ReconnectStrategy::ExponentialBackoff { base, .. } => *base,
262            ReconnectStrategy::FibonacciBackoff { base, .. } => *base,
263            ReconnectStrategy::FixedInterval { interval, .. } => *interval,
264        }
265    }
266
267    /// Adjusts next sleep duration based on the strategy.
268    fn adjust_sleep(&self, prev: Option<Duration>, curr: Duration) -> Duration {
269        match self {
270            ReconnectStrategy::Fail => Duration::new(0, 0),
271            ReconnectStrategy::ExponentialBackoff { factor, .. } => {
272                let next_millis = (curr.as_millis() as f64) * factor;
273                Duration::from_millis(if next_millis > (std::u64::MAX as f64) {
274                    std::u64::MAX
275                } else {
276                    next_millis as u64
277                })
278            }
279            ReconnectStrategy::FibonacciBackoff { .. } => {
280                let prev = prev.unwrap_or_else(|| Duration::new(0, 0));
281                prev.checked_add(curr).unwrap_or(Duration::MAX)
282            }
283            ReconnectStrategy::FixedInterval { .. } => curr,
284        }
285    }
286}