Skip to main content

ntex_mqtt/
error.rs

1use std::{fmt, io, num::NonZeroU16};
2
3use ntex_util::future::Either;
4
5use crate::v5::codec::DisconnectReasonCode;
6
/// Crate-internal message text stating that Publish control messages are not supported.
pub(crate) const ERR_PUB_NOT_SUP: &str = "Publish control message is not supported";
/// Crate-internal message text stating that Auth control messages are not supported.
pub(crate) const ERR_AUTH_NOT_SUP: &str = "Auth control message is not supported";
9
10/// Errors which can occur when attempting to handle mqtt connection.
11#[derive(Debug, thiserror::Error)]
12pub enum MqttError<E> {
13    /// Publish handler service error
14    #[error("Service error")]
15    Service(E),
16    /// Handshake error
17    #[error("Mqtt handshake error: {}", _0)]
18    Handshake(#[from] HandshakeError<E>),
19}
20
21/// Errors which can occur during mqtt connection handshake.
22#[derive(Debug, thiserror::Error)]
23pub enum HandshakeError<E> {
24    /// Handshake service error
25    #[error("Handshake service error")]
26    Service(E),
27    /// Protocol error
28    #[error("Mqtt protocol error: {}", _0)]
29    Protocol(#[from] ProtocolError),
30    /// Handshake timeout
31    #[error("Handshake timeout")]
32    Timeout,
33    /// Peer disconnect
34    #[error("Peer is disconnected, error: {:?}", _0)]
35    Disconnected(Option<io::Error>),
36}
37
38/// Errors related to protocol dispatcher
39#[derive(Debug, thiserror::Error)]
40pub enum DispatcherError<E> {
41    /// Publish handler service error
42    #[error("Service error")]
43    Service(E),
44    /// Protocol violations error
45    #[error("Protocol violations error: {}", _0)]
46    Protocol(#[from] ProtocolError),
47}
48
49impl<E> From<SpecViolation> for DispatcherError<E> {
50    fn from(spec: SpecViolation) -> Self {
51        DispatcherError::Protocol(ProtocolError::spec(spec))
52    }
53}
54
55/// Errors related to payload processing
56#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
57pub enum PayloadError {
58    /// Protocol error
59    #[error("{0}")]
60    Protocol(#[from] ProtocolError),
61    /// Service error
62    #[error("Service error")]
63    Service,
64    /// Payload is consumed
65    #[error("Payload is consumed")]
66    Consumed,
67    /// Peer is disconnected
68    #[error("Peer is disconnected")]
69    Disconnected,
70}
71
72/// Protocol level errors
73#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
74pub enum ProtocolError {
75    /// MQTT decoding error
76    #[error("Decoding error: {0:?}")]
77    Decode(#[from] DecodeError),
78    /// MQTT encoding error
79    #[error("Encoding error: {0:?}")]
80    Encode(#[from] EncodeError),
81    /// Peer violated MQTT protocol specification
82    #[error("Protocol violation: {0}")]
83    ProtocolViolation(#[from] ProtocolViolationError),
84    /// Keep alive timeout
85    #[error("Keep Alive timeout")]
86    KeepAliveTimeout,
87    /// Read frame timeout
88    #[error("Read frame timeout")]
89    ReadTimeout,
90}
91
92#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
93#[error(transparent)]
94pub struct ProtocolViolationError {
95    pub(crate) inner: ViolationInner,
96}
97
98#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
99pub(crate) enum ViolationInner {
100    #[error("{0}")]
101    Spec(SpecViolation),
102    #[error("{message}")]
103    Common { reason: DisconnectReasonCode, message: &'static str },
104    #[error("{message}; received packet with type `{packet_type:08b}`")]
105    UnexpectedPacket { packet_type: u8, message: &'static str },
106}
107
108#[allow(non_camel_case_types)]
109#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
110pub enum SpecViolation {
111    #[error("[MQTT-2.2.1-3] PUBLISH received with packet id that is already in use")]
112    PacketId_2_2_1_3_Pub,
113    #[error("[MQTT-2.2.1-3] SUBSCRIBE received with packet id that is already in use")]
114    PacketId_2_2_1_3_Sub,
115    #[error("[MQTT-2.2.1-3] UNSUBSCRIBE received with packet id that is already in use")]
116    PacketId_2_2_1_3_Unsub,
117    #[error("[MQTT-3.1.2-26] Topic alias is greater than max allowed")]
118    Connect_3_1_2_26,
119    #[error(
120        "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
121    )]
122    Connack_3_2_2_11,
123    #[error("[MQTT-3.2.2-14] RETAIN is not supported")]
124    Connack_3_2_2_14,
125    #[error("[MQTT-3.2.2-17] Topic alias is greater than max allowed")]
126    Connack_3_2_2_17,
127    #[error("[MQTT-3.2.2-3.12] Subscription Identifiers are not supported")]
128    Connack_3_2_2_3_12,
129    #[error("[MQTT-3.3.2-2] PUBLISH packet's topic name contains wildcard character")]
130    Pub_3_3_2_2,
131    #[error("[MQTT-3.3.4-7] Number of in-flight messages exceeds set maximum")]
132    Pub_3_3_4_7,
133    #[error("[MQTT-3.3.4-9] Number of in-flight messages exceeds set maximum")]
134    Pub_3_3_4_9,
135    #[error("[MQTT-4.7.1-*] Topic filter is malformed")]
136    Subs_4_7_1,
137    #[error(
138        "[MQTT-3.14.2-*] The Session Expiry Interval must not be set on DISCONNECT by Server"
139    )]
140    Disconnect_3_14_2_21,
141    #[error("[MQTT-3.14.2-*] Non-Zero Session Expiry Interval is set on DISCONNECT")]
142    Disconnect_3_14_2_22,
143}
144
145impl SpecViolation {
146    const fn reason(self) -> DisconnectReasonCode {
147        match self {
148            SpecViolation::Pub_3_3_4_7 | SpecViolation::Pub_3_3_4_9 => {
149                DisconnectReasonCode::ReceiveMaximumExceeded
150            }
151            SpecViolation::Connack_3_2_2_11 => DisconnectReasonCode::QosNotSupported,
152            SpecViolation::Connack_3_2_2_14 => DisconnectReasonCode::RetainNotSupported,
153            SpecViolation::Connack_3_2_2_3_12 => {
154                DisconnectReasonCode::SubscriptionIdentifiersNotSupported
155            }
156            SpecViolation::PacketId_2_2_1_3_Pub
157            | SpecViolation::PacketId_2_2_1_3_Sub
158            | SpecViolation::PacketId_2_2_1_3_Unsub
159            | SpecViolation::Connect_3_1_2_26
160            | SpecViolation::Pub_3_3_2_2
161            | SpecViolation::Subs_4_7_1
162            | SpecViolation::Connack_3_2_2_17
163            | SpecViolation::Disconnect_3_14_2_21
164            | SpecViolation::Disconnect_3_14_2_22 => DisconnectReasonCode::ProtocolError,
165        }
166    }
167
168    const fn as_str(self) -> &'static str {
169        match self {
170            SpecViolation::PacketId_2_2_1_3_Pub => {
171                "[MQTT-2.2.1-3] PUBLISH received with packet id that is already in use"
172            }
173            SpecViolation::PacketId_2_2_1_3_Sub => {
174                "[MQTT-2.2.1-3] SUBSCRIBE received with packet id that is already in use"
175            }
176            SpecViolation::PacketId_2_2_1_3_Unsub => {
177                "[MQTT-2.2.1-3] UNSUBSCRIBE received with packet id that is already in use"
178            }
179            SpecViolation::Connect_3_1_2_26 => {
180                "[MQTT-3.1.2-26] Topic alias is greater than max allowed"
181            }
182            SpecViolation::Connack_3_2_2_11 => {
183                "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
184            }
185            SpecViolation::Connack_3_2_2_14 => "[MQTT-3.2.2-14] RETAIN is not supported",
186            SpecViolation::Connack_3_2_2_17 => {
187                "[MQTT-3.2.2-17] Topic alias is greater than max allowed"
188            }
189            SpecViolation::Connack_3_2_2_3_12 => {
190                "[MQTT-3.2.2-3.12] Subscription Identifiers are not supported"
191            }
192            SpecViolation::Pub_3_3_2_2 => {
193                "[MQTT-3.3.2-2] PUBLISH packet's topic name contains wildcard character"
194            }
195            SpecViolation::Pub_3_3_4_7 => {
196                "[MQTT-3.3.4-7] Number of in-flight messages exceeds set maximum"
197            }
198            SpecViolation::Pub_3_3_4_9 => {
199                "[MQTT-3.3.4-9] Number of in-flight messages exceeds set maximum"
200            }
201            SpecViolation::Subs_4_7_1 => "[MQTT-4.7.1-*] Topic filter is malformed",
202            SpecViolation::Disconnect_3_14_2_21 => {
203                "[MQTT-3.14.2-*] The Session Expiry Interval must not be set on DISCONNECT by Server"
204            }
205            SpecViolation::Disconnect_3_14_2_22 => {
206                "[MQTT-3.14.2-*] Non-Zero Session Expiry Interval is set on DISCONNECT"
207            }
208        }
209    }
210}
211
212impl ProtocolViolationError {
213    /// Protocol violation reason code
214    pub const fn reason(&self) -> DisconnectReasonCode {
215        match self.inner {
216            ViolationInner::Spec(err) => err.reason(),
217            ViolationInner::Common { reason, .. } => reason,
218            ViolationInner::UnexpectedPacket { .. } => DisconnectReasonCode::ProtocolError,
219        }
220    }
221
222    /// Protocol violation reason message
223    pub const fn message(&self) -> &'static str {
224        match self.inner {
225            ViolationInner::Common { message, .. }
226            | ViolationInner::UnexpectedPacket { message, .. } => message,
227            ViolationInner::Spec(err) => err.as_str(),
228        }
229    }
230}
231
232impl ProtocolError {
233    pub(crate) fn violation(reason: DisconnectReasonCode, message: &'static str) -> Self {
234        Self::ProtocolViolation(ProtocolViolationError {
235            inner: ViolationInner::Common { reason, message },
236        })
237    }
238
239    pub fn spec(err: SpecViolation) -> Self {
240        Self::ProtocolViolation(ProtocolViolationError { inner: ViolationInner::Spec(err) })
241    }
242
243    pub fn generic_violation(message: &'static str) -> Self {
244        Self::violation(DisconnectReasonCode::ProtocolError, message)
245    }
246
247    pub(crate) fn unexpected_packet(packet_type: u8, message: &'static str) -> ProtocolError {
248        Self::ProtocolViolation(ProtocolViolationError {
249            inner: ViolationInner::UnexpectedPacket { packet_type, message },
250        })
251    }
252    pub(crate) fn packet_id_mismatch() -> Self {
253        Self::generic_violation(
254            "Packet id of PUBACK packet does not match expected next value according to sending order of PUBLISH packets [MQTT-4.6.0-2]",
255        )
256    }
257}
258
259impl<E> From<io::Error> for MqttError<E> {
260    fn from(err: io::Error) -> Self {
261        MqttError::Handshake(HandshakeError::Disconnected(Some(err)))
262    }
263}
264
265impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
266    fn from(err: Either<io::Error, io::Error>) -> Self {
267        MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))
268    }
269}
270
271impl<E> From<EncodeError> for MqttError<E> {
272    fn from(err: EncodeError) -> Self {
273        MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(err)))
274    }
275}
276
277impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
278    fn from(err: Either<DecodeError, io::Error>) -> Self {
279        match err {
280            Either::Left(err) => HandshakeError::Protocol(ProtocolError::Decode(err)),
281            Either::Right(err) => HandshakeError::Disconnected(Some(err)),
282        }
283    }
284}
285
286#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)]
287pub enum DecodeError {
288    #[error("Invalid protocol")]
289    InvalidProtocol,
290    #[error("Invalid length")]
291    InvalidLength,
292    #[error("Malformed packet")]
293    MalformedPacket,
294    #[error("Unsupported protocol level")]
295    UnsupportedProtocolLevel,
296    #[error("Connect frame's reserved flag is set")]
297    ConnectReservedFlagSet,
298    #[error("ConnectAck frame's reserved flag is set")]
299    ConnAckReservedFlagSet,
300    #[error("Invalid client id")]
301    InvalidClientId,
302    #[error("Unsupported packet type")]
303    UnsupportedPacketType,
304    // MQTT v3 only
305    #[error("Packet id is required")]
306    PacketIdRequired,
307    #[error("Max size exceeded size:{size} max-size:{max_size}")]
308    MaxSizeExceeded { size: u32, max_size: u32 },
309    #[error("utf8 error")]
310    Utf8Error,
311    #[error("Unexpected payload")]
312    UnexpectedPayload,
313}
314
315#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, thiserror::Error)]
316pub enum EncodeError {
317    #[error("Packet is bigger than peer's Maximum Packet Size")]
318    OverMaxPacketSize,
319    #[error("Streaming payload is bigger than Publish packet definition")]
320    OverPublishSize,
321    #[error("Streaming payload is incomplete")]
322    PublishIncomplete,
323    #[error("Invalid length")]
324    InvalidLength,
325    #[error("Malformed packet")]
326    MalformedPacket,
327    #[error("Packet id is required")]
328    PacketIdRequired,
329    #[error("Unexpected payload")]
330    UnexpectedPayload,
331    #[error("Publish packet is not completed, expect payload")]
332    ExpectPayload,
333    #[error("Unsupported version")]
334    UnsupportedVersion,
335}
336
337#[derive(Debug, PartialEq, Eq, Copy, Clone, thiserror::Error)]
338pub enum SendPacketError {
339    /// Encoder error
340    #[error("Encoding error {:?}", _0)]
341    Encode(#[from] EncodeError),
342    /// Provided packet id is in use
343    #[error("Provided packet id is in use")]
344    PacketIdInUse(NonZeroU16),
345    /// Unexpected release publish
346    #[error("Unexpected publish release")]
347    UnexpectedRelease,
348    /// Streaming has been cancelled
349    #[error("Streaming has been cancelled")]
350    StreamingCancelled,
351    /// Peer disconnected
352    #[error("Peer is disconnected")]
353    Disconnected,
354}
355
356/// Errors which can occur when attempting to handle mqtt client connection.
357#[derive(Debug, thiserror::Error)]
358pub enum ClientError<T: fmt::Debug> {
359    /// Connect negotiation failed
360    #[error("Connect ack failed: {:?}", _0)]
361    Ack(T),
362    /// Protocol error
363    #[error("Protocol error: {:?}", _0)]
364    Protocol(#[from] ProtocolError),
365    /// Handshake timeout
366    #[error("Handshake timeout")]
367    HandshakeTimeout,
368    /// Peer disconnected
369    #[error("Peer disconnected")]
370    Disconnected(Option<std::io::Error>),
371    /// Connect error
372    #[error("Connect error: {}", _0)]
373    Connect(#[from] ntex_net::connect::ConnectError),
374}
375
376impl<T: fmt::Debug> From<EncodeError> for ClientError<T> {
377    fn from(err: EncodeError) -> Self {
378        ClientError::Protocol(ProtocolError::Encode(err))
379    }
380}
381
382impl<T: fmt::Debug> From<Either<DecodeError, std::io::Error>> for ClientError<T> {
383    fn from(err: Either<DecodeError, std::io::Error>) -> Self {
384        match err {
385            Either::Left(err) => ClientError::Protocol(ProtocolError::Decode(err)),
386            Either::Right(err) => ClientError::Disconnected(Some(err)),
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use std::io;

    use super::*;

    /// Returns the violation payload of `err`, failing the test for any other variant.
    fn expect_violation(err: ProtocolError) -> ProtocolViolationError {
        match err {
            ProtocolError::ProtocolViolation(violation) => violation,
            _ => panic!("expected protocol violation"),
        }
    }

    /// A spec violation exposes its clause-specific reason code and message.
    #[test]
    fn test_spec_violation_reason_and_message() {
        let violation = expect_violation(ProtocolError::spec(SpecViolation::Connack_3_2_2_11));

        assert_eq!(violation.reason(), DisconnectReasonCode::QosNotSupported);
        assert_eq!(
            violation.message(),
            "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
        );
    }

    /// A generic violation uses the `ProtocolError` reason code and the given message.
    #[test]
    fn test_generic_violation_reason_and_message() {
        let violation = expect_violation(ProtocolError::generic_violation("broken"));

        assert_eq!(violation.reason(), DisconnectReasonCode::ProtocolError);
        assert_eq!(violation.message(), "broken");
    }

    /// An unexpected-packet violation reports the bare message via `message()` and the
    /// binary packet type via `Display`.
    #[test]
    fn test_unexpected_packet_reason_and_message() {
        let err = ProtocolError::unexpected_packet(0b0011_0000, "unexpected");
        // `ProtocolError` is `Copy`, so `err` stays usable after this call.
        let violation = expect_violation(err);

        assert_eq!(violation.reason(), DisconnectReasonCode::ProtocolError);
        assert_eq!(violation.message(), "unexpected");
        assert_eq!(
            err.to_string(),
            "Protocol violation: unexpected; received packet with type `00110000`"
        );
    }

    /// I/O and encode errors convert into the matching handshake variants of `MqttError`.
    #[test]
    fn test_mqtt_error_from_io_and_encode() {
        let from_io = MqttError::<()>::from(io::Error::other("io"));
        let MqttError::Handshake(HandshakeError::Disconnected(Some(io_err))) = from_io else {
            panic!("expected disconnected handshake error");
        };
        assert_eq!(io_err.kind(), io::ErrorKind::Other);

        let from_encode = MqttError::<()>::from(EncodeError::MalformedPacket);
        assert!(matches!(
            from_encode,
            MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(
                EncodeError::MalformedPacket
            )))
        ));
    }

    /// Decode failures become protocol errors; I/O failures become disconnects.
    #[test]
    fn test_handshake_error_from_decode_or_io() {
        let decode_side: Either<DecodeError, io::Error> =
            Either::Left(DecodeError::MalformedPacket);
        let from_decode = HandshakeError::<()>::from(decode_side);
        assert!(matches!(
            from_decode,
            HandshakeError::Protocol(ProtocolError::Decode(DecodeError::MalformedPacket))
        ));

        let io_side: Either<DecodeError, io::Error> = Either::Right(io::Error::other("peer"));
        let from_io = HandshakeError::<()>::from(io_side);
        let HandshakeError::Disconnected(Some(io_err)) = from_io else {
            panic!("expected disconnected handshake error");
        };
        assert_eq!(io_err.kind(), io::ErrorKind::Other);
    }

    /// Decode, I/O and encode errors convert into the matching `ClientError` variants.
    #[test]
    fn test_client_error_from_decode_or_io_and_encode() {
        let decode_side: Either<DecodeError, io::Error> =
            Either::Left(DecodeError::InvalidLength);
        let from_decode = ClientError::<()>::from(decode_side);
        assert!(matches!(
            from_decode,
            ClientError::Protocol(ProtocolError::Decode(DecodeError::InvalidLength))
        ));

        let io_side: Either<DecodeError, io::Error> = Either::Right(io::Error::other("peer"));
        let from_io = ClientError::<()>::from(io_side);
        let ClientError::Disconnected(Some(io_err)) = from_io else {
            panic!("expected disconnected client error");
        };
        assert_eq!(io_err.kind(), io::ErrorKind::Other);

        let from_encode = ClientError::<()>::from(EncodeError::UnexpectedPayload);
        assert!(matches!(
            from_encode,
            ClientError::Protocol(ProtocolError::Encode(EncodeError::UnexpectedPayload))
        ));
    }
}