// distant_net/client/reconnect.rs
use std::io;
use std::time::Duration;

use log::*;
use strum::Display;
use tokio::sync::watch;
use tokio::task::JoinHandle;

use super::Reconnectable;

11#[derive(Clone)]
13pub struct ConnectionWatcher(pub(super) watch::Receiver<ConnectionState>);
14
15impl ConnectionWatcher {
16 pub async fn next(&mut self) -> Option<ConnectionState> {
19 self.0.changed().await.ok()?;
20 Some(self.last())
21 }
22
23 pub fn has_changed(&self) -> bool {
25 self.0.has_changed().ok().unwrap_or(false)
26 }
27
28 pub fn last(&self) -> ConnectionState {
30 *self.0.borrow()
31 }
32
33 pub fn on_change<F>(&self, mut f: F) -> JoinHandle<()>
36 where
37 F: FnMut(ConnectionState) + Send + 'static,
38 {
39 let rx = self.0.clone();
40 tokio::spawn(async move {
41 let mut watcher = Self(rx);
42 while let Some(state) = watcher.next().await {
43 f(state);
44 }
45 })
46 }
47}
48
49#[derive(Copy, Clone, Debug, Display, PartialEq, Eq)]
51#[strum(serialize_all = "snake_case")]
52pub enum ConnectionState {
53 Reconnecting,
55
56 Connected,
58
59 Disconnected,
61}
62
63impl ConnectionState {
64 pub fn is_reconnecting(&self) -> bool {
66 matches!(self, Self::Reconnecting)
67 }
68
69 pub fn is_connected(&self) -> bool {
71 matches!(self, Self::Connected)
72 }
73
74 pub fn is_disconnected(&self) -> bool {
76 matches!(self, Self::Disconnected)
77 }
78}
79
/// Policy describing whether and how reconnect attempts are made.
#[derive(Clone, Debug)]
pub enum ReconnectStrategy {
    /// Never attempt to reconnect; reconnecting fails immediately.
    Fail,

    /// Sleep between attempts grows by a multiplicative factor after every failure.
    ExponentialBackoff {
        /// Sleep duration used after the first failed attempt.
        base: Duration,

        /// Multiplier applied to the current sleep duration (in milliseconds) to obtain the next.
        factor: f64,

        /// Upper bound for the sleep duration computed after each failed attempt.
        max_duration: Option<Duration>,

        /// Maximum number of reconnect attempts; `None` means keep trying indefinitely.
        max_retries: Option<usize>,

        /// Time limit for a single reconnect attempt; `None` means no limit.
        timeout: Option<Duration>,
    },

    /// Next sleep is the sum of the previous two sleeps, starting from `base`.
    FibonacciBackoff {
        /// Sleep duration used after the first failed attempt.
        base: Duration,

        /// Upper bound for the sleep duration computed after each failed attempt.
        max_duration: Option<Duration>,

        /// Maximum number of reconnect attempts; `None` means keep trying indefinitely.
        max_retries: Option<usize>,

        /// Time limit for a single reconnect attempt; `None` means no limit.
        timeout: Option<Duration>,
    },

    /// Sleep between attempts is always the same.
    FixedInterval {
        /// Sleep duration used between attempts.
        interval: Duration,

        /// Maximum number of reconnect attempts; `None` means keep trying indefinitely.
        max_retries: Option<usize>,

        /// Time limit for a single reconnect attempt; `None` means no limit.
        timeout: Option<Duration>,
    },
}

132impl Default for ReconnectStrategy {
133 fn default() -> Self {
135 Self::Fail
136 }
137}
138
139impl ReconnectStrategy {
140 pub async fn reconnect<T: Reconnectable>(&mut self, reconnectable: &mut T) -> io::Result<()> {
141 if self.is_fail() {
143 return Err(io::Error::from(io::ErrorKind::ConnectionAborted));
144 }
145
146 let mut previous_sleep = None;
148 let mut current_sleep = self.initial_sleep_duration();
149
150 let mut retries_remaining = self.max_retries();
152
153 let timeout = self.timeout();
155
156 let max_duration = self.max_duration();
158
159 let mut result = Ok(());
162
163 while retries_remaining.is_none() || retries_remaining > Some(0) {
164 result = match timeout {
166 Some(timeout) => {
167 match tokio::time::timeout(timeout, reconnectable.reconnect()).await {
168 Ok(x) => x,
169 Err(x) => Err(x.into()),
170 }
171 }
172 None => reconnectable.reconnect().await,
173 };
174
175 match &result {
177 Ok(()) => return Ok(()),
178 Err(x) => {
179 error!("Failed to reconnect: {x}");
180 }
181 }
182
183 if let Some(remaining) = retries_remaining.as_mut() {
185 if *remaining > 0 {
186 *remaining -= 1;
187 }
188 }
189
190 tokio::time::sleep(current_sleep).await;
192
193 let next_sleep = self.adjust_sleep(previous_sleep, current_sleep);
195 previous_sleep = Some(current_sleep);
196 current_sleep = if let Some(duration) = max_duration {
197 std::cmp::min(next_sleep, duration)
198 } else {
199 next_sleep
200 };
201 }
202
203 result
204 }
205
206 pub fn is_fail(&self) -> bool {
208 matches!(self, Self::Fail)
209 }
210
211 pub fn is_exponential_backoff(&self) -> bool {
213 matches!(self, Self::ExponentialBackoff { .. })
214 }
215
216 pub fn is_fibonacci_backoff(&self) -> bool {
218 matches!(self, Self::FibonacciBackoff { .. })
219 }
220
221 pub fn is_fixed_interval(&self) -> bool {
223 matches!(self, Self::FixedInterval { .. })
224 }
225
226 pub fn max_duration(&self) -> Option<Duration> {
228 match self {
229 ReconnectStrategy::Fail => None,
230 ReconnectStrategy::ExponentialBackoff { max_duration, .. } => *max_duration,
231 ReconnectStrategy::FibonacciBackoff { max_duration, .. } => *max_duration,
232 ReconnectStrategy::FixedInterval { .. } => None,
233 }
234 }
235
236 pub fn max_retries(&self) -> Option<usize> {
239 match self {
240 ReconnectStrategy::Fail => None,
241 ReconnectStrategy::ExponentialBackoff { max_retries, .. } => *max_retries,
242 ReconnectStrategy::FibonacciBackoff { max_retries, .. } => *max_retries,
243 ReconnectStrategy::FixedInterval { max_retries, .. } => *max_retries,
244 }
245 }
246
247 pub fn timeout(&self) -> Option<Duration> {
249 match self {
250 ReconnectStrategy::Fail => None,
251 ReconnectStrategy::ExponentialBackoff { timeout, .. } => *timeout,
252 ReconnectStrategy::FibonacciBackoff { timeout, .. } => *timeout,
253 ReconnectStrategy::FixedInterval { timeout, .. } => *timeout,
254 }
255 }
256
257 fn initial_sleep_duration(&self) -> Duration {
259 match self {
260 ReconnectStrategy::Fail => Duration::new(0, 0),
261 ReconnectStrategy::ExponentialBackoff { base, .. } => *base,
262 ReconnectStrategy::FibonacciBackoff { base, .. } => *base,
263 ReconnectStrategy::FixedInterval { interval, .. } => *interval,
264 }
265 }
266
267 fn adjust_sleep(&self, prev: Option<Duration>, curr: Duration) -> Duration {
269 match self {
270 ReconnectStrategy::Fail => Duration::new(0, 0),
271 ReconnectStrategy::ExponentialBackoff { factor, .. } => {
272 let next_millis = (curr.as_millis() as f64) * factor;
273 Duration::from_millis(if next_millis > (std::u64::MAX as f64) {
274 std::u64::MAX
275 } else {
276 next_millis as u64
277 })
278 }
279 ReconnectStrategy::FibonacciBackoff { .. } => {
280 let prev = prev.unwrap_or_else(|| Duration::new(0, 0));
281 prev.checked_add(curr).unwrap_or(Duration::MAX)
282 }
283 ReconnectStrategy::FixedInterval { .. } => curr,
284 }
285 }
286}