Skip to main content

mqtt5_protocol/
connection.rs

1use crate::error::MqttError;
2use crate::numeric::u128_to_u64_saturating;
3use crate::prelude::*;
4use crate::time::Duration;
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
7pub enum ConnectionState {
8    #[default]
9    Disconnected,
10    Connecting,
11    Connected,
12    Reconnecting {
13        attempt: u32,
14    },
15}
16
17impl ConnectionState {
18    #[must_use]
19    pub fn is_connected(&self) -> bool {
20        matches!(self, Self::Connected)
21    }
22
23    #[must_use]
24    pub fn is_disconnected(&self) -> bool {
25        matches!(self, Self::Disconnected)
26    }
27
28    #[must_use]
29    pub fn is_reconnecting(&self) -> bool {
30        matches!(self, Self::Reconnecting { .. })
31    }
32
33    #[must_use]
34    pub fn reconnect_attempt(&self) -> Option<u32> {
35        match self {
36            Self::Reconnecting { attempt } => Some(*attempt),
37            _ => None,
38        }
39    }
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub enum DisconnectReason {
44    ClientInitiated,
45    ServerClosed,
46    NetworkError(String),
47    ProtocolError(String),
48    KeepAliveTimeout,
49    AuthFailure,
50}
51
52#[derive(Debug, Clone)]
53pub enum ConnectionEvent {
54    Connecting,
55    Connected {
56        session_present: bool,
57        /// Effective keep-alive interval after MQTT v5 `ServerKeepAlive` negotiation.
58        ///
59        /// Equals `Duration::ZERO` if keep-alive is disabled.
60        keep_alive: Duration,
61    },
62    Disconnected {
63        reason: DisconnectReason,
64    },
65    Reconnecting {
66        attempt: u32,
67    },
68    ReconnectFailed {
69        error: MqttError,
70    },
71}
72
73#[derive(Debug, Clone, Default)]
74pub struct ConnectionInfo {
75    pub session_present: bool,
76    pub assigned_client_id: Option<String>,
77    pub server_keep_alive: Option<u16>,
78}
79
80#[derive(Debug, Clone)]
81pub struct ReconnectConfig {
82    pub enabled: bool,
83    pub initial_delay: Duration,
84    pub max_delay: Duration,
85    pub backoff_factor_tenths: u32,
86    pub max_attempts: Option<u32>,
87}
88
89impl Default for ReconnectConfig {
90    fn default() -> Self {
91        Self {
92            enabled: true,
93            initial_delay: Duration::from_secs(1),
94            max_delay: Duration::from_secs(60),
95            backoff_factor_tenths: 20,
96            max_attempts: None,
97        }
98    }
99}
100
101impl ReconnectConfig {
102    #[must_use]
103    pub fn disabled() -> Self {
104        Self {
105            enabled: false,
106            ..Default::default()
107        }
108    }
109
110    #[must_use]
111    pub fn backoff_factor(&self) -> f64 {
112        f64::from(self.backoff_factor_tenths) / 10.0
113    }
114
115    pub fn set_backoff_factor(&mut self, factor: f64) {
116        self.backoff_factor_tenths = if factor < 0.0 {
117            0
118        } else if factor >= f64::from(u32::MAX) / 10.0 {
119            u32::MAX
120        } else {
121            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
122            let result = (factor * 10.0) as u32;
123            result
124        };
125    }
126
127    #[must_use]
128    pub fn calculate_delay(&self, attempt: u32) -> Duration {
129        if attempt == 0 {
130            return self.initial_delay;
131        }
132
133        let initial_ms = u128_to_u64_saturating(self.initial_delay.as_millis());
134        let max_ms = u128_to_u64_saturating(self.max_delay.as_millis());
135
136        let factor_tenths = u64::from(self.backoff_factor_tenths);
137        let mut delay_tenths = initial_ms.saturating_mul(10);
138
139        for _ in 0..attempt {
140            delay_tenths = delay_tenths.saturating_mul(factor_tenths) / 10;
141            if delay_tenths / 10 >= max_ms {
142                return self.max_delay;
143            }
144        }
145
146        Duration::from_millis((delay_tenths / 10).min(max_ms))
147    }
148
149    #[must_use]
150    pub fn should_retry(&self, attempt: u32) -> bool {
151        if !self.enabled {
152            return false;
153        }
154        match self.max_attempts {
155            Some(max) => attempt < max,
156            None => true,
157        }
158    }
159}
160
161#[derive(Debug, Clone)]
162pub struct ConnectionStateMachine {
163    state: ConnectionState,
164    info: ConnectionInfo,
165    reconnect_config: ReconnectConfig,
166}
167
168impl Default for ConnectionStateMachine {
169    fn default() -> Self {
170        Self {
171            state: ConnectionState::Disconnected,
172            info: ConnectionInfo::default(),
173            reconnect_config: ReconnectConfig::default(),
174        }
175    }
176}
177
178impl ConnectionStateMachine {
179    #[must_use]
180    pub fn new(reconnect_config: ReconnectConfig) -> Self {
181        Self {
182            state: ConnectionState::Disconnected,
183            info: ConnectionInfo::default(),
184            reconnect_config,
185        }
186    }
187
188    #[must_use]
189    pub fn state(&self) -> ConnectionState {
190        self.state
191    }
192
193    #[must_use]
194    pub fn info(&self) -> &ConnectionInfo {
195        &self.info
196    }
197
198    #[must_use]
199    pub fn reconnect_config(&self) -> &ReconnectConfig {
200        &self.reconnect_config
201    }
202
203    pub fn set_reconnect_config(&mut self, config: ReconnectConfig) {
204        self.reconnect_config = config;
205    }
206
207    pub fn transition(&mut self, event: &ConnectionEvent) -> ConnectionState {
208        match event {
209            ConnectionEvent::Connecting => {
210                self.state = ConnectionState::Connecting;
211            }
212            ConnectionEvent::Connected {
213                session_present,
214                keep_alive,
215            } => {
216                self.state = ConnectionState::Connected;
217                self.info.session_present = *session_present;
218                self.info.server_keep_alive = u16::try_from(keep_alive.as_secs()).ok();
219            }
220            ConnectionEvent::Disconnected { .. } | ConnectionEvent::ReconnectFailed { .. } => {
221                self.state = ConnectionState::Disconnected;
222                self.info = ConnectionInfo::default();
223            }
224            ConnectionEvent::Reconnecting { attempt } => {
225                self.state = ConnectionState::Reconnecting { attempt: *attempt };
226            }
227        }
228        self.state
229    }
230
231    pub fn set_connection_info(&mut self, info: ConnectionInfo) {
232        self.info = info;
233    }
234
235    #[must_use]
236    pub fn is_connected(&self) -> bool {
237        self.state.is_connected()
238    }
239
240    #[must_use]
241    pub fn should_reconnect(&self) -> bool {
242        match self.state {
243            ConnectionState::Disconnected => self.reconnect_config.enabled,
244            ConnectionState::Reconnecting { attempt } => {
245                self.reconnect_config.should_retry(attempt + 1)
246            }
247            _ => false,
248        }
249    }
250
251    #[must_use]
252    pub fn next_reconnect_delay(&self) -> Option<Duration> {
253        match self.state {
254            ConnectionState::Disconnected => {
255                if self.reconnect_config.enabled {
256                    Some(self.reconnect_config.calculate_delay(0))
257                } else {
258                    None
259                }
260            }
261            ConnectionState::Reconnecting { attempt } => {
262                if self.reconnect_config.should_retry(attempt + 1) {
263                    Some(self.reconnect_config.calculate_delay(attempt))
264                } else {
265                    None
266                }
267            }
268            _ => None,
269        }
270    }
271}
272
273#[cfg(test)]
274mod tests {
275    use super::*;
276
277    #[test]
278    fn test_connection_state_default() {
279        let state = ConnectionState::default();
280        assert!(state.is_disconnected());
281    }
282
283    #[test]
284    fn test_state_machine_transitions() {
285        let mut sm = ConnectionStateMachine::default();
286
287        assert!(sm.state().is_disconnected());
288
289        sm.transition(&ConnectionEvent::Connecting);
290        assert_eq!(sm.state(), ConnectionState::Connecting);
291
292        sm.transition(&ConnectionEvent::Connected {
293            session_present: true,
294            keep_alive: Duration::from_secs(60),
295        });
296        assert!(sm.is_connected());
297        assert!(sm.info().session_present);
298        assert_eq!(sm.info().server_keep_alive, Some(60));
299
300        sm.transition(&ConnectionEvent::Disconnected {
301            reason: DisconnectReason::NetworkError("timeout".into()),
302        });
303        assert!(sm.state().is_disconnected());
304        assert!(!sm.info().session_present);
305    }
306
307    #[test]
308    fn test_reconnect_delay_calculation() {
309        let config = ReconnectConfig {
310            enabled: true,
311            initial_delay: Duration::from_secs(1),
312            max_delay: Duration::from_secs(30),
313            backoff_factor_tenths: 20,
314            max_attempts: Some(5),
315        };
316
317        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
318        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
319        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));
320        assert_eq!(config.calculate_delay(3), Duration::from_secs(8));
321        assert_eq!(config.calculate_delay(4), Duration::from_secs(16));
322        assert_eq!(config.calculate_delay(5), Duration::from_secs(30));
323    }
324
325    #[test]
326    fn test_should_retry() {
327        let config = ReconnectConfig {
328            enabled: true,
329            max_attempts: Some(3),
330            ..Default::default()
331        };
332
333        assert!(config.should_retry(0));
334        assert!(config.should_retry(1));
335        assert!(config.should_retry(2));
336        assert!(!config.should_retry(3));
337        assert!(!config.should_retry(4));
338    }
339
340    #[test]
341    fn test_disabled_reconnect() {
342        let config = ReconnectConfig::disabled();
343        assert!(!config.should_retry(0));
344    }
345
346    #[test]
347    fn test_reconnect_flow() {
348        let mut sm = ConnectionStateMachine::new(ReconnectConfig {
349            enabled: true,
350            initial_delay: Duration::from_millis(100),
351            max_delay: Duration::from_secs(10),
352            backoff_factor_tenths: 20,
353            max_attempts: Some(3),
354        });
355
356        sm.transition(&ConnectionEvent::Connecting);
357        sm.transition(&ConnectionEvent::Connected {
358            session_present: false,
359            keep_alive: Duration::from_secs(60),
360        });
361        assert!(sm.is_connected());
362
363        sm.transition(&ConnectionEvent::Disconnected {
364            reason: DisconnectReason::NetworkError("connection lost".into()),
365        });
366        assert!(sm.should_reconnect());
367
368        sm.transition(&ConnectionEvent::Reconnecting { attempt: 0 });
369        assert!(sm.state().is_reconnecting());
370        assert_eq!(sm.state().reconnect_attempt(), Some(0));
371        assert!(sm.should_reconnect());
372
373        sm.transition(&ConnectionEvent::Reconnecting { attempt: 1 });
374        assert!(sm.should_reconnect());
375
376        sm.transition(&ConnectionEvent::Reconnecting { attempt: 2 });
377        assert!(!sm.should_reconnect());
378    }
379}