mqtt5_protocol/
error_classification.rs

1use crate::error::MqttError;
2use crate::protocol::v5::reason_codes::ReasonCode;
3
/// Error categories that [`MqttError::classify`] reports as recoverable.
///
/// Each category carries its own delay weighting via
/// [`RecoverableError::base_delay_multiplier`], and
/// [`RecoverableError::default_set`] lists the categories enabled by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoverableError {
    /// Transport-level failure: `MqttError::Timeout` or a `ConnectionError`
    /// message such as "Connection refused" / "timed out".
    NetworkError,
    /// `MqttError::ServerUnavailable` / `ServerBusy`, or a connection refused
    /// with the matching reason codes.
    ServerUnavailable,
    /// `MqttError::QuotaExceeded`, or a connection refused with `QuotaExceeded`.
    QuotaExceeded,
    /// `MqttError::PacketIdExhausted`.
    PacketIdExhausted,
    /// `MqttError::FlowControlExceeded`.
    FlowControlLimited,
    /// Connection refused with `SessionTakenOver` (also used for
    /// `MqttError::SessionExpired`).
    SessionTakenOver,
    /// `MqttError::ServerShuttingDown`, or a connection refused with
    /// `ServerShuttingDown`.
    ServerShuttingDown,
    /// Connection refused with one of the MQoQ flow reason codes
    /// (`MqoqIncompletePacket`, `MqoqFlowOpenIdle`, `MqoqFlowCancelled`).
    MqoqFlowRecoverable,
}

impl RecoverableError {
    /// Returns the factor by which the base delay is scaled for this category.
    ///
    /// Quota errors scale by 10, MQoQ flow errors by 3, flow-control limits by 2,
    /// and every other category by 1.
    #[must_use]
    pub const fn base_delay_multiplier(&self) -> u32 {
        // Deliberately exhaustive (no `_` arm): adding a variant must force an
        // explicit decision about its delay weighting here.
        match self {
            Self::QuotaExceeded => 10,
            Self::MqoqFlowRecoverable => 3,
            Self::FlowControlLimited => 2,
            Self::NetworkError
            | Self::ServerUnavailable
            | Self::PacketIdExhausted
            | Self::SessionTakenOver
            | Self::ServerShuttingDown => 1,
        }
    }

    /// Returns the categories treated as recoverable by default.
    ///
    /// `SessionTakenOver` and `ServerShuttingDown` are intentionally excluded.
    #[must_use]
    pub const fn default_set() -> [Self; 6] {
        [
            Self::NetworkError,
            Self::ServerUnavailable,
            Self::QuotaExceeded,
            Self::PacketIdExhausted,
            Self::FlowControlLimited,
            Self::MqoqFlowRecoverable,
        ]
    }
}
39
40impl MqttError {
41    #[must_use]
42    pub fn classify(&self) -> Option<RecoverableError> {
43        match self {
44            Self::ConnectionError(msg) => classify_connection_error(msg),
45            Self::ConnectionRefused(reason) => classify_connection_refused(*reason),
46            Self::PacketIdExhausted => Some(RecoverableError::PacketIdExhausted),
47            Self::FlowControlExceeded => Some(RecoverableError::FlowControlLimited),
48            Self::Timeout => Some(RecoverableError::NetworkError),
49            Self::ServerUnavailable | Self::ServerBusy => Some(RecoverableError::ServerUnavailable),
50            Self::QuotaExceeded => Some(RecoverableError::QuotaExceeded),
51            Self::ServerShuttingDown => Some(RecoverableError::ServerShuttingDown),
52            Self::SessionExpired => Some(RecoverableError::SessionTakenOver),
53            _ => None,
54        }
55    }
56
57    #[must_use]
58    pub fn is_aws_iot_connection_limit(&self) -> bool {
59        match self {
60            Self::ConnectionError(msg) => is_aws_iot_limit_error(msg),
61            _ => false,
62        }
63    }
64}
65
66fn classify_connection_error(msg: &str) -> Option<RecoverableError> {
67    if is_aws_iot_limit_error(msg) {
68        return None;
69    }
70
71    if msg.contains("temporarily unavailable")
72        || msg.contains("Connection refused")
73        || msg.contains("Network is unreachable")
74        || msg.contains("connection reset")
75        || msg.contains("broken pipe")
76        || msg.contains("timed out")
77    {
78        return Some(RecoverableError::NetworkError);
79    }
80
81    None
82}
83
/// Returns `true` when `msg` looks like the connection-limit rejection this
/// module attributes to AWS IoT.
///
/// It recognises two kinds of message:
/// - peer resets: the phrase "reset by peer", or a standalone `RST` token;
/// - explicit "connection limit" / "client limit" messages.
///
/// Phrases are matched ignoring ASCII case. `RST` must appear as a whole
/// token, so words that merely contain those letters (e.g. "FIRST", "BURST")
/// are not mistaken for a TCP reset.
fn is_aws_iot_limit_error(msg: &str) -> bool {
    // "reset by peer" also covers "Connection reset by peer".
    const LIMIT_PHRASES: [&str; 3] = ["reset by peer", "connection limit", "client limit"];

    // Tokenise on non-alphanumeric characters so "TCP RST received" and
    // "RST," match, but "FIRST" does not.
    let has_rst_token = msg
        .split(|c: char| !c.is_ascii_alphanumeric())
        .any(|token| token == "RST");
    if has_rst_token {
        return true;
    }

    let lower = msg.to_ascii_lowercase();
    LIMIT_PHRASES.iter().any(|&phrase| lower.contains(phrase))
}
92
93fn classify_connection_refused(reason: ReasonCode) -> Option<RecoverableError> {
94    match reason {
95        ReasonCode::ServerUnavailable | ReasonCode::ServerBusy => {
96            Some(RecoverableError::ServerUnavailable)
97        }
98        ReasonCode::QuotaExceeded => Some(RecoverableError::QuotaExceeded),
99        ReasonCode::SessionTakenOver => Some(RecoverableError::SessionTakenOver),
100        ReasonCode::ServerShuttingDown => Some(RecoverableError::ServerShuttingDown),
101        ReasonCode::MqoqIncompletePacket
102        | ReasonCode::MqoqFlowOpenIdle
103        | ReasonCode::MqoqFlowCancelled => Some(RecoverableError::MqoqFlowRecoverable),
104        _ => None,
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ConnectionError` carrying `msg`.
    fn connection_error(msg: &str) -> MqttError {
        MqttError::ConnectionError(msg.to_string())
    }

    #[test]
    fn test_connection_error_classification() {
        let transient = [
            "Connection refused",
            "Network is unreachable",
            "temporarily unavailable",
        ];
        for msg in transient {
            assert_eq!(
                connection_error(msg).classify(),
                Some(RecoverableError::NetworkError),
                "message: {msg}"
            );
        }
    }

    #[test]
    fn test_aws_iot_limit_not_recoverable() {
        let limit_messages = [
            "Connection reset by peer",
            "TCP RST received",
            "client limit exceeded",
        ];
        for msg in limit_messages {
            let error = connection_error(msg);
            assert_eq!(error.classify(), None, "message: {msg}");
            assert!(error.is_aws_iot_connection_limit(), "message: {msg}");
        }
    }

    #[test]
    fn test_connection_refused_classification() {
        let cases = [
            (
                ReasonCode::ServerUnavailable,
                Some(RecoverableError::ServerUnavailable),
            ),
            (
                ReasonCode::QuotaExceeded,
                Some(RecoverableError::QuotaExceeded),
            ),
            (
                ReasonCode::SessionTakenOver,
                Some(RecoverableError::SessionTakenOver),
            ),
            (ReasonCode::BadUsernameOrPassword, None),
        ];
        for (reason, expected) in cases {
            assert_eq!(MqttError::ConnectionRefused(reason).classify(), expected);
        }
    }

    #[test]
    fn test_mqoq_classification() {
        let recoverable_codes = [
            ReasonCode::MqoqIncompletePacket,
            ReasonCode::MqoqFlowOpenIdle,
            ReasonCode::MqoqFlowCancelled,
        ];
        for reason in recoverable_codes {
            assert_eq!(
                MqttError::ConnectionRefused(reason).classify(),
                Some(RecoverableError::MqoqFlowRecoverable)
            );
        }

        assert_eq!(
            MqttError::ConnectionRefused(ReasonCode::MqoqNoFlowState).classify(),
            None
        );
    }

    #[test]
    fn test_direct_error_classification() {
        let cases = [
            (MqttError::PacketIdExhausted, RecoverableError::PacketIdExhausted),
            (MqttError::FlowControlExceeded, RecoverableError::FlowControlLimited),
            (MqttError::Timeout, RecoverableError::NetworkError),
            (MqttError::ServerUnavailable, RecoverableError::ServerUnavailable),
            (MqttError::ServerBusy, RecoverableError::ServerUnavailable),
            (MqttError::QuotaExceeded, RecoverableError::QuotaExceeded),
            (MqttError::ServerShuttingDown, RecoverableError::ServerShuttingDown),
        ];
        for (error, expected) in cases {
            assert_eq!(error.classify(), Some(expected));
        }
    }

    #[test]
    fn test_non_recoverable_errors() {
        let errors = [
            MqttError::NotConnected,
            MqttError::AlreadyConnected,
            MqttError::AuthenticationFailed,
            MqttError::NotAuthorized,
            MqttError::BadUsernameOrPassword,
            MqttError::ProtocolError("test".to_string()),
        ];
        for error in errors {
            assert_eq!(error.classify(), None);
        }
    }

    #[test]
    fn test_base_delay_multiplier() {
        let cases = [
            (RecoverableError::NetworkError, 1),
            (RecoverableError::ServerUnavailable, 1),
            (RecoverableError::QuotaExceeded, 10),
            (RecoverableError::FlowControlLimited, 2),
            (RecoverableError::MqoqFlowRecoverable, 3),
        ];
        for (category, multiplier) in cases {
            assert_eq!(category.base_delay_multiplier(), multiplier, "{category:?}");
        }
    }

    #[test]
    fn test_default_set() {
        let defaults = RecoverableError::default_set();
        assert_eq!(defaults.len(), 6);

        let included = [
            RecoverableError::NetworkError,
            RecoverableError::ServerUnavailable,
            RecoverableError::QuotaExceeded,
            RecoverableError::PacketIdExhausted,
            RecoverableError::FlowControlLimited,
            RecoverableError::MqoqFlowRecoverable,
        ];
        for category in included {
            assert!(defaults.contains(&category), "{category:?} missing");
        }

        let excluded = [
            RecoverableError::SessionTakenOver,
            RecoverableError::ServerShuttingDown,
        ];
        for category in excluded {
            assert!(!defaults.contains(&category), "{category:?} unexpectedly present");
        }
    }
}