1use std::{fmt, io, num::NonZeroU16};
2
3use ntex_util::future::Either;
4
5use crate::v5::codec::DisconnectReasonCode;
6
/// Message text stating that the Publish control message is not supported.
pub(crate) const ERR_PUB_NOT_SUP: &str = "Publish control message is not supported";
/// Message text stating that the Auth control message is not supported.
pub(crate) const ERR_AUTH_NOT_SUP: &str = "Auth control message is not supported";
9
10#[derive(Debug, thiserror::Error)]
12pub enum MqttError<E> {
13 #[error("Service error")]
15 Service(E),
16 #[error("Mqtt handshake error: {}", _0)]
18 Handshake(#[from] HandshakeError<E>),
19}
20
21#[derive(Debug, thiserror::Error)]
23pub enum HandshakeError<E> {
24 #[error("Handshake service error")]
26 Service(E),
27 #[error("Mqtt protocol error: {}", _0)]
29 Protocol(#[from] ProtocolError),
30 #[error("Handshake timeout")]
32 Timeout,
33 #[error("Peer is disconnected, error: {:?}", _0)]
35 Disconnected(Option<io::Error>),
36}
37
38#[derive(Debug, thiserror::Error)]
40pub enum DispatcherError<E> {
41 #[error("Service error")]
43 Service(E),
44 #[error("Protocol violations error: {}", _0)]
46 Protocol(#[from] ProtocolError),
47}
48
49impl<E> From<SpecViolation> for DispatcherError<E> {
50 fn from(spec: SpecViolation) -> Self {
51 DispatcherError::Protocol(ProtocolError::spec(spec))
52 }
53}
54
55#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
57pub enum PayloadError {
58 #[error("{0}")]
60 Protocol(#[from] ProtocolError),
61 #[error("Service error")]
63 Service,
64 #[error("Payload is consumed")]
66 Consumed,
67 #[error("Peer is disconnected")]
69 Disconnected,
70}
71
72#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
74pub enum ProtocolError {
75 #[error("Decoding error: {0:?}")]
77 Decode(#[from] DecodeError),
78 #[error("Encoding error: {0:?}")]
80 Encode(#[from] EncodeError),
81 #[error("Protocol violation: {0}")]
83 ProtocolViolation(#[from] ProtocolViolationError),
84 #[error("Keep Alive timeout")]
86 KeepAliveTimeout,
87 #[error("Read frame timeout")]
89 ReadTimeout,
90}
91
92#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
93#[error(transparent)]
94pub struct ProtocolViolationError {
95 pub(crate) inner: ViolationInner,
96}
97
98#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
99pub(crate) enum ViolationInner {
100 #[error("{0}")]
101 Spec(SpecViolation),
102 #[error("{message}")]
103 Common { reason: DisconnectReasonCode, message: &'static str },
104 #[error("{message}; received packet with type `{packet_type:08b}`")]
105 UnexpectedPacket { packet_type: u8, message: &'static str },
106}
107
108#[allow(non_camel_case_types)]
109#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
110pub enum SpecViolation {
111 #[error("[MQTT-2.2.1-3] PUBLISH received with packet id that is already in use")]
112 PacketId_2_2_1_3_Pub,
113 #[error("[MQTT-2.2.1-3] SUBSCRIBE received with packet id that is already in use")]
114 PacketId_2_2_1_3_Sub,
115 #[error("[MQTT-2.2.1-3] UNSUBSCRIBE received with packet id that is already in use")]
116 PacketId_2_2_1_3_Unsub,
117 #[error("[MQTT-3.1.2-26] Topic alias is greater than max allowed")]
118 Connect_3_1_2_26,
119 #[error(
120 "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
121 )]
122 Connack_3_2_2_11,
123 #[error("[MQTT-3.2.2-14] RETAIN is not supported")]
124 Connack_3_2_2_14,
125 #[error("[MQTT-3.2.2-17] Topic alias is greater than max allowed")]
126 Connack_3_2_2_17,
127 #[error("[MQTT-3.2.2-3.12] Subscription Identifiers are not supported")]
128 Connack_3_2_2_3_12,
129 #[error("[MQTT-3.3.2-2] PUBLISH packet's topic name contains wildcard character")]
130 Pub_3_3_2_2,
131 #[error("[MQTT-3.3.4-7] Number of in-flight messages exceeds set maximum")]
132 Pub_3_3_4_7,
133 #[error("[MQTT-3.3.4-9] Number of in-flight messages exceeds set maximum")]
134 Pub_3_3_4_9,
135 #[error("[MQTT-4.7.1-*] Topic filter is malformed")]
136 Subs_4_7_1,
137 #[error(
138 "[MQTT-3.14.2-*] The Session Expiry Interval must not be set on DISCONNECT by Server"
139 )]
140 Disconnect_3_14_2_21,
141 #[error("[MQTT-3.14.2-*] Non-Zero Session Expiry Interval is set on DISCONNECT")]
142 Disconnect_3_14_2_22,
143}
144
145impl SpecViolation {
146 const fn reason(self) -> DisconnectReasonCode {
147 match self {
148 SpecViolation::Pub_3_3_4_7 | SpecViolation::Pub_3_3_4_9 => {
149 DisconnectReasonCode::ReceiveMaximumExceeded
150 }
151 SpecViolation::Connack_3_2_2_11 => DisconnectReasonCode::QosNotSupported,
152 SpecViolation::Connack_3_2_2_14 => DisconnectReasonCode::RetainNotSupported,
153 SpecViolation::Connack_3_2_2_3_12 => {
154 DisconnectReasonCode::SubscriptionIdentifiersNotSupported
155 }
156 SpecViolation::PacketId_2_2_1_3_Pub
157 | SpecViolation::PacketId_2_2_1_3_Sub
158 | SpecViolation::PacketId_2_2_1_3_Unsub
159 | SpecViolation::Connect_3_1_2_26
160 | SpecViolation::Pub_3_3_2_2
161 | SpecViolation::Subs_4_7_1
162 | SpecViolation::Connack_3_2_2_17
163 | SpecViolation::Disconnect_3_14_2_21
164 | SpecViolation::Disconnect_3_14_2_22 => DisconnectReasonCode::ProtocolError,
165 }
166 }
167
168 const fn as_str(self) -> &'static str {
169 match self {
170 SpecViolation::PacketId_2_2_1_3_Pub => {
171 "[MQTT-2.2.1-3] PUBLISH received with packet id that is already in use"
172 }
173 SpecViolation::PacketId_2_2_1_3_Sub => {
174 "[MQTT-2.2.1-3] SUBSCRIBE received with packet id that is already in use"
175 }
176 SpecViolation::PacketId_2_2_1_3_Unsub => {
177 "[MQTT-2.2.1-3] UNSUBSCRIBE received with packet id that is already in use"
178 }
179 SpecViolation::Connect_3_1_2_26 => {
180 "[MQTT-3.1.2-26] Topic alias is greater than max allowed"
181 }
182 SpecViolation::Connack_3_2_2_11 => {
183 "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
184 }
185 SpecViolation::Connack_3_2_2_14 => "[MQTT-3.2.2-14] RETAIN is not supported",
186 SpecViolation::Connack_3_2_2_17 => {
187 "[MQTT-3.2.2-17] Topic alias is greater than max allowed"
188 }
189 SpecViolation::Connack_3_2_2_3_12 => {
190 "[MQTT-3.2.2-3.12] Subscription Identifiers are not supported"
191 }
192 SpecViolation::Pub_3_3_2_2 => {
193 "[MQTT-3.3.2-2] PUBLISH packet's topic name contains wildcard character"
194 }
195 SpecViolation::Pub_3_3_4_7 => {
196 "[MQTT-3.3.4-7] Number of in-flight messages exceeds set maximum"
197 }
198 SpecViolation::Pub_3_3_4_9 => {
199 "[MQTT-3.3.4-9] Number of in-flight messages exceeds set maximum"
200 }
201 SpecViolation::Subs_4_7_1 => "[MQTT-4.7.1-*] Topic filter is malformed",
202 SpecViolation::Disconnect_3_14_2_21 => {
203 "[MQTT-3.14.2-*] The Session Expiry Interval must not be set on DISCONNECT by Server"
204 }
205 SpecViolation::Disconnect_3_14_2_22 => {
206 "[MQTT-3.14.2-*] Non-Zero Session Expiry Interval is set on DISCONNECT"
207 }
208 }
209 }
210}
211
212impl ProtocolViolationError {
213 pub const fn reason(&self) -> DisconnectReasonCode {
215 match self.inner {
216 ViolationInner::Spec(err) => err.reason(),
217 ViolationInner::Common { reason, .. } => reason,
218 ViolationInner::UnexpectedPacket { .. } => DisconnectReasonCode::ProtocolError,
219 }
220 }
221
222 pub const fn message(&self) -> &'static str {
224 match self.inner {
225 ViolationInner::Common { message, .. }
226 | ViolationInner::UnexpectedPacket { message, .. } => message,
227 ViolationInner::Spec(err) => err.as_str(),
228 }
229 }
230}
231
232impl ProtocolError {
233 pub(crate) fn violation(reason: DisconnectReasonCode, message: &'static str) -> Self {
234 Self::ProtocolViolation(ProtocolViolationError {
235 inner: ViolationInner::Common { reason, message },
236 })
237 }
238
239 pub fn spec(err: SpecViolation) -> Self {
240 Self::ProtocolViolation(ProtocolViolationError { inner: ViolationInner::Spec(err) })
241 }
242
243 pub fn generic_violation(message: &'static str) -> Self {
244 Self::violation(DisconnectReasonCode::ProtocolError, message)
245 }
246
247 pub(crate) fn unexpected_packet(packet_type: u8, message: &'static str) -> ProtocolError {
248 Self::ProtocolViolation(ProtocolViolationError {
249 inner: ViolationInner::UnexpectedPacket { packet_type, message },
250 })
251 }
252 pub(crate) fn packet_id_mismatch() -> Self {
253 Self::generic_violation(
254 "Packet id of PUBACK packet does not match expected next value according to sending order of PUBLISH packets [MQTT-4.6.0-2]",
255 )
256 }
257}
258
259impl<E> From<io::Error> for MqttError<E> {
260 fn from(err: io::Error) -> Self {
261 MqttError::Handshake(HandshakeError::Disconnected(Some(err)))
262 }
263}
264
265impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
266 fn from(err: Either<io::Error, io::Error>) -> Self {
267 MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))
268 }
269}
270
271impl<E> From<EncodeError> for MqttError<E> {
272 fn from(err: EncodeError) -> Self {
273 MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(err)))
274 }
275}
276
277impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
278 fn from(err: Either<DecodeError, io::Error>) -> Self {
279 match err {
280 Either::Left(err) => HandshakeError::Protocol(ProtocolError::Decode(err)),
281 Either::Right(err) => HandshakeError::Disconnected(Some(err)),
282 }
283 }
284}
285
286#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)]
287pub enum DecodeError {
288 #[error("Invalid protocol")]
289 InvalidProtocol,
290 #[error("Invalid length")]
291 InvalidLength,
292 #[error("Malformed packet")]
293 MalformedPacket,
294 #[error("Unsupported protocol level")]
295 UnsupportedProtocolLevel,
296 #[error("Connect frame's reserved flag is set")]
297 ConnectReservedFlagSet,
298 #[error("ConnectAck frame's reserved flag is set")]
299 ConnAckReservedFlagSet,
300 #[error("Invalid client id")]
301 InvalidClientId,
302 #[error("Unsupported packet type")]
303 UnsupportedPacketType,
304 #[error("Packet id is required")]
306 PacketIdRequired,
307 #[error("Max size exceeded size:{size} max-size:{max_size}")]
308 MaxSizeExceeded { size: u32, max_size: u32 },
309 #[error("utf8 error")]
310 Utf8Error,
311 #[error("Unexpected payload")]
312 UnexpectedPayload,
313}
314
315#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, thiserror::Error)]
316pub enum EncodeError {
317 #[error("Packet is bigger than peer's Maximum Packet Size")]
318 OverMaxPacketSize,
319 #[error("Streaming payload is bigger than Publish packet definition")]
320 OverPublishSize,
321 #[error("Streaming payload is incomplete")]
322 PublishIncomplete,
323 #[error("Invalid length")]
324 InvalidLength,
325 #[error("Malformed packet")]
326 MalformedPacket,
327 #[error("Packet id is required")]
328 PacketIdRequired,
329 #[error("Unexpected payload")]
330 UnexpectedPayload,
331 #[error("Publish packet is not completed, expect payload")]
332 ExpectPayload,
333 #[error("Unsupported version")]
334 UnsupportedVersion,
335}
336
337#[derive(Debug, PartialEq, Eq, Copy, Clone, thiserror::Error)]
338pub enum SendPacketError {
339 #[error("Encoding error {:?}", _0)]
341 Encode(#[from] EncodeError),
342 #[error("Provided packet id is in use")]
344 PacketIdInUse(NonZeroU16),
345 #[error("Unexpected publish release")]
347 UnexpectedRelease,
348 #[error("Streaming has been cancelled")]
350 StreamingCancelled,
351 #[error("Peer is disconnected")]
353 Disconnected,
354}
355
356#[derive(Debug, thiserror::Error)]
358pub enum ClientError<T: fmt::Debug> {
359 #[error("Connect ack failed: {:?}", _0)]
361 Ack(T),
362 #[error("Protocol error: {:?}", _0)]
364 Protocol(#[from] ProtocolError),
365 #[error("Handshake timeout")]
367 HandshakeTimeout,
368 #[error("Peer disconnected")]
370 Disconnected(Option<std::io::Error>),
371 #[error("Connect error: {}", _0)]
373 Connect(#[from] ntex_net::connect::ConnectError),
374}
375
376impl<T: fmt::Debug> From<EncodeError> for ClientError<T> {
377 fn from(err: EncodeError) -> Self {
378 ClientError::Protocol(ProtocolError::Encode(err))
379 }
380}
381
382impl<T: fmt::Debug> From<Either<DecodeError, std::io::Error>> for ClientError<T> {
383 fn from(err: Either<DecodeError, std::io::Error>) -> Self {
384 match err {
385 Either::Left(err) => ClientError::Protocol(ProtocolError::Decode(err)),
386 Either::Right(err) => ClientError::Disconnected(Some(err)),
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use std::io;

    use super::*;

    /// Extracts the violation payload, failing the test for any other variant.
    fn expect_violation(err: ProtocolError) -> ProtocolViolationError {
        match err {
            ProtocolError::ProtocolViolation(violation) => violation,
            _ => panic!("expected protocol violation"),
        }
    }

    #[test]
    fn test_spec_violation_reason_and_message() {
        let violation = expect_violation(ProtocolError::spec(SpecViolation::Connack_3_2_2_11));

        assert_eq!(violation.reason(), DisconnectReasonCode::QosNotSupported);
        assert_eq!(
            violation.message(),
            "[MQTT-3.2.2-11] PUBLISH packet at a QoS level exceeding the Maximum QoS level specified in CONNACK"
        );
    }

    #[test]
    fn test_generic_violation_reason_and_message() {
        let violation = expect_violation(ProtocolError::generic_violation("broken"));

        assert_eq!(violation.reason(), DisconnectReasonCode::ProtocolError);
        assert_eq!(violation.message(), "broken");
    }

    #[test]
    fn test_unexpected_packet_reason_and_message() {
        // `ProtocolError` is `Copy`, so `err` stays usable after extraction.
        let err = ProtocolError::unexpected_packet(0b0011_0000, "unexpected");
        let violation = expect_violation(err);

        assert_eq!(violation.reason(), DisconnectReasonCode::ProtocolError);
        assert_eq!(violation.message(), "unexpected");
        assert_eq!(
            err.to_string(),
            "Protocol violation: unexpected; received packet with type `00110000`"
        );
    }

    #[test]
    fn test_mqtt_error_from_io_and_encode() {
        let from_io = MqttError::<()>::from(io::Error::other("io"));
        let MqttError::Handshake(HandshakeError::Disconnected(Some(source))) = from_io else {
            panic!("expected disconnected handshake error");
        };
        assert_eq!(source.kind(), io::ErrorKind::Other);

        let from_encode = MqttError::<()>::from(EncodeError::MalformedPacket);
        assert!(matches!(
            from_encode,
            MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(
                EncodeError::MalformedPacket
            )))
        ));
    }

    #[test]
    fn test_handshake_error_from_decode_or_io() {
        let from_decode = HandshakeError::<()>::from(Either::Left(DecodeError::MalformedPacket));
        assert!(matches!(
            from_decode,
            HandshakeError::Protocol(ProtocolError::Decode(DecodeError::MalformedPacket))
        ));

        let from_io = HandshakeError::<()>::from(Either::Right(io::Error::other("peer")));
        let HandshakeError::Disconnected(Some(source)) = from_io else {
            panic!("expected disconnected handshake error");
        };
        assert_eq!(source.kind(), io::ErrorKind::Other);
    }

    #[test]
    fn test_client_error_from_decode_or_io_and_encode() {
        let from_decode = ClientError::<()>::from(Either::Left(DecodeError::InvalidLength));
        assert!(matches!(
            from_decode,
            ClientError::Protocol(ProtocolError::Decode(DecodeError::InvalidLength))
        ));

        let from_io = ClientError::<()>::from(Either::Right(io::Error::other("peer")));
        let ClientError::Disconnected(Some(source)) = from_io else {
            panic!("expected disconnected client error");
        };
        assert_eq!(source.kind(), io::ErrorKind::Other);

        let from_encode = ClientError::<()>::from(EncodeError::UnexpectedPayload);
        assert!(matches!(
            from_encode,
            ClientError::Protocol(ProtocolError::Encode(EncodeError::UnexpectedPayload))
        ));
    }
}