1use crate::error::MqttError;
2use crate::numeric::u128_to_u64_saturating;
3use crate::prelude::*;
4use crate::time::Duration;
5
/// Lifecycle state of a client connection.
///
/// The [`Default`] value is [`ConnectionState::Disconnected`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConnectionState {
    /// No connection is established.
    #[default]
    Disconnected,
    /// A connection attempt is underway.
    Connecting,
    /// The connection is established.
    Connected,
    /// A reconnection attempt is underway.
    Reconnecting {
        /// Attempt counter carried by the event that entered this state.
        attempt: u32,
    },
}
16
17impl ConnectionState {
18 #[must_use]
19 pub fn is_connected(&self) -> bool {
20 matches!(self, Self::Connected)
21 }
22
23 #[must_use]
24 pub fn is_disconnected(&self) -> bool {
25 matches!(self, Self::Disconnected)
26 }
27
28 #[must_use]
29 pub fn is_reconnecting(&self) -> bool {
30 matches!(self, Self::Reconnecting { .. })
31 }
32
33 #[must_use]
34 pub fn reconnect_attempt(&self) -> Option<u32> {
35 match self {
36 Self::Reconnecting { attempt } => Some(*attempt),
37 _ => None,
38 }
39 }
40}
41
/// Why a connection ended; carried by `ConnectionEvent::Disconnected`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisconnectReason {
    /// Disconnect initiated by the client.
    ClientInitiated,
    /// Connection closed by the server.
    ServerClosed,
    /// Network-level failure, with a free-form description (e.g. `"timeout"`).
    NetworkError(String),
    /// Protocol-level failure, with a free-form description.
    ProtocolError(String),
    /// Keep-alive timed out.
    KeepAliveTimeout,
    /// Authentication failed.
    AuthFailure,
}
51
52#[derive(Debug, Clone)]
53pub enum ConnectionEvent {
54 Connecting,
55 Connected { session_present: bool },
56 Disconnected { reason: DisconnectReason },
57 Reconnecting { attempt: u32 },
58 ReconnectFailed { error: MqttError },
59}
60
/// Details recorded about the current connection.
///
/// The [`Default`] value has `session_present == false` and no server-provided values.
#[derive(Debug, Clone, Default)]
pub struct ConnectionInfo {
    /// Session-present flag supplied with `ConnectionEvent::Connected`.
    pub session_present: bool,
    /// Client identifier assigned by the server, if any (presumably filled in via
    /// `set_connection_info` — confirm at the call site).
    pub assigned_client_id: Option<String>,
    /// Keep-alive value supplied by the server, if any.
    pub server_keep_alive: Option<u16>,
}
67
/// Reconnection policy: an on/off switch, an exponential backoff schedule, and an optional
/// attempt limit.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// When `false`, no reconnection attempt is permitted.
    pub enabled: bool,
    /// Delay used for attempt `0` of the backoff schedule.
    pub initial_delay: Duration,
    /// Upper bound for computed backoff delays.
    pub max_delay: Duration,
    /// Per-attempt backoff multiplier in fixed-point tenths (`20` means ×2.0).
    pub backoff_factor_tenths: u32,
    /// Attempts `0..max` are permitted; `None` means no limit.
    pub max_attempts: Option<u32>,
}
76
77impl Default for ReconnectConfig {
78 fn default() -> Self {
79 Self {
80 enabled: true,
81 initial_delay: Duration::from_secs(1),
82 max_delay: Duration::from_secs(60),
83 backoff_factor_tenths: 20,
84 max_attempts: None,
85 }
86 }
87}
88
89impl ReconnectConfig {
90 #[must_use]
91 pub fn disabled() -> Self {
92 Self {
93 enabled: false,
94 ..Default::default()
95 }
96 }
97
98 #[must_use]
99 pub fn backoff_factor(&self) -> f64 {
100 f64::from(self.backoff_factor_tenths) / 10.0
101 }
102
103 pub fn set_backoff_factor(&mut self, factor: f64) {
104 self.backoff_factor_tenths = if factor < 0.0 {
105 0
106 } else if factor >= f64::from(u32::MAX) / 10.0 {
107 u32::MAX
108 } else {
109 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
110 let result = (factor * 10.0) as u32;
111 result
112 };
113 }
114
115 #[must_use]
116 pub fn calculate_delay(&self, attempt: u32) -> Duration {
117 if attempt == 0 {
118 return self.initial_delay;
119 }
120
121 let initial_ms = u128_to_u64_saturating(self.initial_delay.as_millis());
122 let max_ms = u128_to_u64_saturating(self.max_delay.as_millis());
123
124 let factor_tenths = u64::from(self.backoff_factor_tenths);
125 let mut delay_tenths = initial_ms.saturating_mul(10);
126
127 for _ in 0..attempt {
128 delay_tenths = delay_tenths.saturating_mul(factor_tenths) / 10;
129 if delay_tenths / 10 >= max_ms {
130 return self.max_delay;
131 }
132 }
133
134 Duration::from_millis((delay_tenths / 10).min(max_ms))
135 }
136
137 #[must_use]
138 pub fn should_retry(&self, attempt: u32) -> bool {
139 if !self.enabled {
140 return false;
141 }
142 match self.max_attempts {
143 Some(max) => attempt < max,
144 None => true,
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
150pub struct ConnectionStateMachine {
151 state: ConnectionState,
152 info: ConnectionInfo,
153 reconnect_config: ReconnectConfig,
154}
155
156impl Default for ConnectionStateMachine {
157 fn default() -> Self {
158 Self {
159 state: ConnectionState::Disconnected,
160 info: ConnectionInfo::default(),
161 reconnect_config: ReconnectConfig::default(),
162 }
163 }
164}
165
166impl ConnectionStateMachine {
167 #[must_use]
168 pub fn new(reconnect_config: ReconnectConfig) -> Self {
169 Self {
170 state: ConnectionState::Disconnected,
171 info: ConnectionInfo::default(),
172 reconnect_config,
173 }
174 }
175
176 #[must_use]
177 pub fn state(&self) -> ConnectionState {
178 self.state
179 }
180
181 #[must_use]
182 pub fn info(&self) -> &ConnectionInfo {
183 &self.info
184 }
185
186 #[must_use]
187 pub fn reconnect_config(&self) -> &ReconnectConfig {
188 &self.reconnect_config
189 }
190
191 pub fn set_reconnect_config(&mut self, config: ReconnectConfig) {
192 self.reconnect_config = config;
193 }
194
195 pub fn transition(&mut self, event: &ConnectionEvent) -> ConnectionState {
196 match event {
197 ConnectionEvent::Connecting => {
198 self.state = ConnectionState::Connecting;
199 }
200 ConnectionEvent::Connected { session_present } => {
201 self.state = ConnectionState::Connected;
202 self.info.session_present = *session_present;
203 }
204 ConnectionEvent::Disconnected { .. } | ConnectionEvent::ReconnectFailed { .. } => {
205 self.state = ConnectionState::Disconnected;
206 self.info = ConnectionInfo::default();
207 }
208 ConnectionEvent::Reconnecting { attempt } => {
209 self.state = ConnectionState::Reconnecting { attempt: *attempt };
210 }
211 }
212 self.state
213 }
214
215 pub fn set_connection_info(&mut self, info: ConnectionInfo) {
216 self.info = info;
217 }
218
219 #[must_use]
220 pub fn is_connected(&self) -> bool {
221 self.state.is_connected()
222 }
223
224 #[must_use]
225 pub fn should_reconnect(&self) -> bool {
226 match self.state {
227 ConnectionState::Disconnected => self.reconnect_config.enabled,
228 ConnectionState::Reconnecting { attempt } => {
229 self.reconnect_config.should_retry(attempt + 1)
230 }
231 _ => false,
232 }
233 }
234
235 #[must_use]
236 pub fn next_reconnect_delay(&self) -> Option<Duration> {
237 match self.state {
238 ConnectionState::Disconnected => {
239 if self.reconnect_config.enabled {
240 Some(self.reconnect_config.calculate_delay(0))
241 } else {
242 None
243 }
244 }
245 ConnectionState::Reconnecting { attempt } => {
246 if self.reconnect_config.should_retry(attempt + 1) {
247 Some(self.reconnect_config.calculate_delay(attempt))
248 } else {
249 None
250 }
251 }
252 _ => None,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled policy with a 2.0 backoff factor and the given attempt limit.
    fn doubling_config(initial: Duration, max: Duration, max_attempts: u32) -> ReconnectConfig {
        ReconnectConfig {
            enabled: true,
            initial_delay: initial,
            max_delay: max,
            backoff_factor_tenths: 20,
            max_attempts: Some(max_attempts),
        }
    }

    #[test]
    fn test_connection_state_default() {
        assert!(ConnectionState::default().is_disconnected());
    }

    #[test]
    fn test_state_machine_transitions() {
        let mut sm = ConnectionStateMachine::default();
        assert!(sm.state().is_disconnected());

        sm.transition(&ConnectionEvent::Connecting);
        assert_eq!(sm.state(), ConnectionState::Connecting);

        sm.transition(&ConnectionEvent::Connected {
            session_present: true,
        });
        assert!(sm.is_connected());
        assert!(sm.info().session_present);

        // A network failure returns to Disconnected and wipes the session info.
        sm.transition(&ConnectionEvent::Disconnected {
            reason: DisconnectReason::NetworkError("timeout".into()),
        });
        assert!(sm.state().is_disconnected());
        assert!(!sm.info().session_present);
    }

    #[test]
    fn test_reconnect_delay_calculation() {
        let config = doubling_config(Duration::from_secs(1), Duration::from_secs(30), 5);

        // Starts at 1 s and doubles per attempt until clipped at the 30 s ceiling.
        let expected_secs: [u64; 6] = [1, 2, 4, 8, 16, 30];
        for (attempt, secs) in (0u32..).zip(expected_secs.iter()) {
            assert_eq!(
                config.calculate_delay(attempt),
                Duration::from_secs(*secs),
                "attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_should_retry() {
        let config = ReconnectConfig {
            enabled: true,
            max_attempts: Some(3),
            ..Default::default()
        };

        // Attempts 0, 1 and 2 are allowed; 3 and beyond are refused.
        assert!((0..3).all(|attempt| config.should_retry(attempt)));
        assert!((3..5).all(|attempt| !config.should_retry(attempt)));
    }

    #[test]
    fn test_disabled_reconnect() {
        assert!(!ReconnectConfig::disabled().should_retry(0));
    }

    #[test]
    fn test_reconnect_flow() {
        let mut sm = ConnectionStateMachine::new(doubling_config(
            Duration::from_millis(100),
            Duration::from_secs(10),
            3,
        ));

        sm.transition(&ConnectionEvent::Connecting);
        sm.transition(&ConnectionEvent::Connected {
            session_present: false,
        });
        assert!(sm.is_connected());

        sm.transition(&ConnectionEvent::Disconnected {
            reason: DisconnectReason::NetworkError("connection lost".into()),
        });
        assert!(sm.should_reconnect());

        sm.transition(&ConnectionEvent::Reconnecting { attempt: 0 });
        assert!(sm.state().is_reconnecting());
        assert_eq!(sm.state().reconnect_attempt(), Some(0));
        assert!(sm.should_reconnect());

        sm.transition(&ConnectionEvent::Reconnecting { attempt: 1 });
        assert!(sm.should_reconnect());

        // With a limit of 3, no attempt is permitted after attempt 2.
        sm.transition(&ConnectionEvent::Reconnecting { attempt: 2 });
        assert!(!sm.should_reconnect());
    }
}