mqtt_v5_fork/
decoder.rs

1use crate::{
2    topic::{Topic, TopicParseError},
3    types::{
4        properties::*, AuthenticatePacket, AuthenticateReason, ConnectAckPacket, ConnectPacket,
5        ConnectReason, DecodeError, DisconnectPacket, DisconnectReason, FinalWill, Packet,
6        PacketType, ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket,
7        PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason,
8        PublishReleasePacket, PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket,
9        SubscribeAckReason, SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket,
10        UnsubscribeAckReason, UnsubscribePacket, VariableByteInt,
11    },
12};
13use bytes::{Buf, Bytes, BytesMut};
14use std::{convert::TryFrom, io::Cursor, str::FromStr};
15
/// Unwraps an `Option` expression, or returns `Ok(None)` from the enclosing function when it
/// is `None`.
///
/// The decoders in this file return `Ok(None)` when the buffer does not yet hold enough bytes,
/// so this macro forwards that "incomplete" signal upward without treating it as an error.
macro_rules! return_if_none {
    ($x: expr) => {{
        match $x {
            Some(value) => value,
            None => return Ok(None),
        }
    }};
}
26
/// Returns `Ok(None)` from the enclosing function when fewer than `$len` unread bytes remain in
/// `$bytes`; otherwise it does nothing and parsing continues.
macro_rules! require_length {
    ($bytes: expr, $len: expr) => {{
        let available = $bytes.remaining();
        if $len > available {
            return Ok(None);
        }
    }};
}
34
/// Reads a single byte from `$bytes`, or returns `Ok(None)` from the enclosing function when no
/// unread bytes are left.
macro_rules! read_u8 {
    ($bytes: expr) => {{
        if $bytes.has_remaining() {
            $bytes.get_u8()
        } else {
            return Ok(None);
        }
    }};
}
44
/// Reads a two-byte integer via `get_u16`, or returns `Ok(None)` from the enclosing function
/// (consuming nothing) when fewer than two bytes remain.
macro_rules! read_u16 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 2 {
            $bytes.get_u16()
        } else {
            return Ok(None);
        }
    }};
}
54
/// Reads a four-byte integer via `get_u32`, or returns `Ok(None)` from the enclosing function
/// (consuming nothing) when fewer than four bytes remain.
macro_rules! read_u32 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 4 {
            $bytes.get_u32()
        } else {
            return Ok(None);
        }
    }};
}
64
/// Decodes a variable byte integer with `decode_variable_int`, propagating its error with `?`
/// and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_variable_int {
    ($bytes: expr) => {{
        let decoded = decode_variable_int($bytes)?;
        return_if_none!(decoded)
    }};
}
70
/// Decodes a length-prefixed UTF-8 string with `decode_string`, propagating its error with `?`
/// and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_string {
    ($bytes: expr) => {{
        let decoded = decode_string($bytes)?;
        return_if_none!(decoded)
    }};
}
76
/// Decodes length-prefixed binary data with `decode_binary_data`, propagating its error with
/// `?` and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_binary_data {
    ($bytes: expr) => {{
        let decoded = decode_binary_data($bytes)?;
        return_if_none!(decoded)
    }};
}
82
/// Reads two consecutive length-prefixed strings and yields them as a `(key, value)` tuple.
/// The key is always read first.
macro_rules! read_string_pair {
    ($bytes: expr) => {{
        let pair_key = read_string!($bytes);
        (pair_key, read_string!($bytes))
    }};
}
91
/// Reads a property identifier (a variable byte integer) and then the property value that
/// identifier selects, returning `Ok(None)` from the enclosing function if more bytes are
/// needed.
macro_rules! read_property {
    ($bytes: expr) => {{
        let id = read_variable_int!($bytes);
        let maybe_property = decode_property(id, $bytes)?;
        return_if_none!(maybe_property)
    }};
}
98
99fn decode_variable_int(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<u32>, DecodeError> {
100    let mut multiplier: u32 = 1;
101    let mut value: u32 = 0;
102
103    for _ in 0..4 {
104        let encoded_byte = read_u8!(bytes);
105
106        value += ((encoded_byte & 0b0111_1111) as u32) * multiplier;
107
108        multiplier *= 128;
109
110        if encoded_byte & 0b1000_0000 == 0b0000_0000 {
111            return Ok(Some(value));
112        }
113    }
114
115    Err(DecodeError::InvalidRemainingLength)
116}
117
118fn decode_string(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<String>, DecodeError> {
119    let str_size_bytes = read_u16!(bytes) as usize;
120
121    require_length!(bytes, str_size_bytes);
122
123    let position = bytes.position() as usize;
124
125    // TODO - Use Cow<str> and from_utf8_lossy later for less copying
126    match String::from_utf8(bytes.get_ref()[position..(position + str_size_bytes)].into()) {
127        Ok(string) => {
128            bytes.advance(str_size_bytes);
129            Ok(Some(string))
130        },
131        Err(_) => Err(DecodeError::InvalidUtf8),
132    }
133}
134
135fn decode_binary_data(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Bytes>, DecodeError> {
136    let data_size_bytes = read_u16!(bytes) as usize;
137    require_length!(bytes, data_size_bytes);
138
139    let position = bytes.position() as usize;
140
141    let payload_bytes =
142        BytesMut::from(&bytes.get_ref()[position..(position + data_size_bytes)]).freeze();
143    let result = Ok(Some(payload_bytes));
144    bytes.advance(data_size_bytes);
145
146    result
147}
148
149fn decode_binary_data_with_size(
150    bytes: &mut Cursor<&mut BytesMut>,
151    size: usize,
152) -> Result<Option<Bytes>, DecodeError> {
153    require_length!(bytes, size);
154
155    let position = bytes.position() as usize;
156    let payload_bytes = BytesMut::from(&bytes.get_ref()[position..(position + size)]).freeze();
157    let result = Ok(Some(payload_bytes));
158    bytes.advance(size);
159
160    result
161}
162
163fn decode_property(
164    property_id: u32,
165    bytes: &mut Cursor<&mut BytesMut>,
166) -> Result<Option<Property>, DecodeError> {
167    let property_type =
168        PropertyType::try_from(property_id).map_err(|_| DecodeError::InvalidPropertyId)?;
169
170    match property_type {
171        PropertyType::PayloadFormatIndicator => {
172            let format_indicator = read_u8!(bytes);
173            Ok(Some(Property::PayloadFormatIndicator(PayloadFormatIndicator(format_indicator))))
174        },
175        PropertyType::MessageExpiryInterval => {
176            let message_expiry_interval = read_u32!(bytes);
177            Ok(Some(Property::MessageExpiryInterval(MessageExpiryInterval(
178                message_expiry_interval,
179            ))))
180        },
181        PropertyType::ContentType => {
182            let content_type = read_string!(bytes);
183            Ok(Some(Property::ContentType(ContentType(content_type))))
184        },
185        PropertyType::ResponseTopic => {
186            let response_topic = read_string!(bytes);
187            Ok(Some(Property::ResponseTopic(ResponseTopic(response_topic))))
188        },
189        PropertyType::CorrelationData => {
190            let correlation_data = read_binary_data!(bytes);
191            Ok(Some(Property::CorrelationData(CorrelationData(correlation_data))))
192        },
193        PropertyType::SubscriptionIdentifier => {
194            let subscription_identifier = read_variable_int!(bytes);
195            Ok(Some(Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(
196                subscription_identifier,
197            )))))
198        },
199        PropertyType::SessionExpiryInterval => {
200            let session_expiry_interval = read_u32!(bytes);
201            Ok(Some(Property::SessionExpiryInterval(SessionExpiryInterval(
202                session_expiry_interval,
203            ))))
204        },
205        PropertyType::AssignedClientIdentifier => {
206            let assigned_client_identifier = read_string!(bytes);
207            Ok(Some(Property::AssignedClientIdentifier(AssignedClientIdentifier(
208                assigned_client_identifier,
209            ))))
210        },
211        PropertyType::ServerKeepAlive => {
212            let server_keep_alive = read_u16!(bytes);
213            Ok(Some(Property::ServerKeepAlive(ServerKeepAlive(server_keep_alive))))
214        },
215        PropertyType::AuthenticationMethod => {
216            let authentication_method = read_string!(bytes);
217            Ok(Some(Property::AuthenticationMethod(AuthenticationMethod(authentication_method))))
218        },
219        PropertyType::AuthenticationData => {
220            let authentication_data = read_binary_data!(bytes);
221            Ok(Some(Property::AuthenticationData(AuthenticationData(authentication_data))))
222        },
223        PropertyType::RequestProblemInformation => {
224            let request_problem_information = read_u8!(bytes);
225            Ok(Some(Property::RequestProblemInformation(RequestProblemInformation(
226                request_problem_information,
227            ))))
228        },
229        PropertyType::WillDelayInterval => {
230            let will_delay_interval = read_u32!(bytes);
231            Ok(Some(Property::WillDelayInterval(WillDelayInterval(will_delay_interval))))
232        },
233        PropertyType::RequestResponseInformation => {
234            let request_response_information = read_u8!(bytes);
235            Ok(Some(Property::RequestResponseInformation(RequestResponseInformation(
236                request_response_information,
237            ))))
238        },
239        PropertyType::ResponseInformation => {
240            let response_information = read_string!(bytes);
241            Ok(Some(Property::ResponseInformation(ResponseInformation(response_information))))
242        },
243        PropertyType::ServerReference => {
244            let server_reference = read_string!(bytes);
245            Ok(Some(Property::ServerReference(ServerReference(server_reference))))
246        },
247        PropertyType::ReasonString => {
248            let reason_string = read_string!(bytes);
249            Ok(Some(Property::ReasonString(ReasonString(reason_string))))
250        },
251        PropertyType::ReceiveMaximum => {
252            let receive_maximum = read_u16!(bytes);
253            Ok(Some(Property::ReceiveMaximum(ReceiveMaximum(receive_maximum))))
254        },
255        PropertyType::TopicAliasMaximum => {
256            let topic_alias_maximum = read_u16!(bytes);
257            Ok(Some(Property::TopicAliasMaximum(TopicAliasMaximum(topic_alias_maximum))))
258        },
259        PropertyType::TopicAlias => {
260            let topic_alias = read_u16!(bytes);
261            Ok(Some(Property::TopicAlias(TopicAlias(topic_alias))))
262        },
263        PropertyType::MaximumQos => {
264            let qos_byte = read_u8!(bytes);
265            let qos = QoS::try_from(qos_byte).map_err(|_| DecodeError::InvalidQoS)?;
266
267            Ok(Some(Property::MaximumQos(MaximumQos(qos))))
268        },
269        PropertyType::RetainAvailable => {
270            let retain_available = read_u8!(bytes);
271            Ok(Some(Property::RetainAvailable(RetainAvailable(retain_available))))
272        },
273        PropertyType::UserProperty => {
274            let (key, value) = read_string_pair!(bytes);
275            Ok(Some(Property::UserProperty(UserProperty(key, value))))
276        },
277        PropertyType::MaximumPacketSize => {
278            let maximum_packet_size = read_u32!(bytes);
279            Ok(Some(Property::MaximumPacketSize(MaximumPacketSize(maximum_packet_size))))
280        },
281        PropertyType::WildcardSubscriptionAvailable => {
282            let wildcard_subscription_available = read_u8!(bytes);
283            Ok(Some(Property::WildcardSubscriptionAvailable(WildcardSubscriptionAvailable(
284                wildcard_subscription_available,
285            ))))
286        },
287        PropertyType::SubscriptionIdentifierAvailable => {
288            let subscription_identifier_available = read_u8!(bytes);
289            Ok(Some(Property::SubscriptionIdentifierAvailable(SubscriptionIdentifierAvailable(
290                subscription_identifier_available,
291            ))))
292        },
293        PropertyType::SharedSubscriptionAvailable => {
294            let shared_subscription_available = read_u8!(bytes);
295            Ok(Some(Property::SharedSubscriptionAvailable(SharedSubscriptionAvailable(
296                shared_subscription_available,
297            ))))
298        },
299    }
300}
301
302fn decode_properties<F: FnMut(Property)>(
303    bytes: &mut Cursor<&mut BytesMut>,
304    mut closure: F,
305) -> Result<Option<()>, DecodeError> {
306    try_decode_properties(bytes, |property| {
307        closure(property);
308        Ok(())
309    })
310}
311
312fn try_decode_properties<F: FnMut(Property) -> Result<(), DecodeError>>(
313    bytes: &mut Cursor<&mut BytesMut>,
314    mut closure: F,
315) -> Result<Option<()>, DecodeError> {
316    let property_length = read_variable_int!(bytes);
317
318    if property_length == 0 {
319        return Ok(Some(()));
320    }
321
322    require_length!(bytes, property_length as usize);
323
324    let start_cursor_pos = bytes.position();
325
326    loop {
327        let cursor_pos = bytes.position();
328
329        if cursor_pos - start_cursor_pos >= property_length as u64 {
330            break;
331        }
332
333        let property = read_property!(bytes);
334        closure(property)?;
335    }
336
337    Ok(Some(()))
338}
339
340fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, DecodeError> {
341    let protocol_name = read_string!(bytes);
342    let protocol_level = read_u8!(bytes);
343    let connect_flags = read_u8!(bytes);
344    let keep_alive = read_u16!(bytes);
345
346    let protocol_version = ProtocolVersion::try_from(protocol_level)
347        .map_err(|_| DecodeError::InvalidProtocolVersion)?;
348
349    let mut session_expiry_interval = None;
350    let mut receive_maximum = None;
351    let mut maximum_packet_size = None;
352    let mut topic_alias_maximum = None;
353    let mut request_response_information = None;
354    let mut request_problem_information = None;
355    let mut user_properties = vec![];
356    let mut authentication_method = None;
357    let mut authentication_data = None;
358
359    if protocol_version == ProtocolVersion::V500 {
360        return_if_none!(decode_properties(bytes, |property| {
361            match property {
362                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
363                Property::ReceiveMaximum(p) => receive_maximum = Some(p),
364                Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
365                Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
366                Property::RequestResponseInformation(p) => request_response_information = Some(p),
367                Property::RequestProblemInformation(p) => request_problem_information = Some(p),
368                Property::UserProperty(p) => user_properties.push(p),
369                Property::AuthenticationMethod(p) => authentication_method = Some(p),
370                Property::AuthenticationData(p) => authentication_data = Some(p),
371                _ => {}, // Invalid property for packet
372            }
373        })?);
374    }
375
376    // Start payload
377    let clean_start = connect_flags & 0b0000_0010 == 0b0000_0010;
378    let has_will = connect_flags & 0b0000_0100 == 0b0000_0100;
379    let will_qos_val = (connect_flags & 0b0001_1000) >> 3;
380    let will_qos = QoS::try_from(will_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
381    let retain_will = connect_flags & 0b0010_0000 == 0b0010_0000;
382    let has_password = connect_flags & 0b0100_0000 == 0b0100_0000;
383    let has_user_name = connect_flags & 0b1000_0000 == 0b1000_0000;
384
385    let client_id = read_string!(bytes);
386
387    let will = if has_will {
388        let mut will_delay_interval = None;
389        let mut payload_format_indicator = None;
390        let mut message_expiry_interval = None;
391        let mut content_type = None;
392        let mut response_topic = None;
393        let mut correlation_data = None;
394        let mut user_properties = vec![];
395
396        if protocol_version == ProtocolVersion::V500 {
397            return_if_none!(decode_properties(bytes, |property| {
398                match property {
399                    Property::WillDelayInterval(p) => will_delay_interval = Some(p),
400                    Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
401                    Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
402                    Property::ContentType(p) => content_type = Some(p),
403                    Property::ResponseTopic(p) => response_topic = Some(p),
404                    Property::CorrelationData(p) => correlation_data = Some(p),
405                    Property::UserProperty(p) => user_properties.push(p),
406                    _ => {}, // Invalid property for packet
407                }
408            })?);
409        }
410
411        let topic = Topic::from_str(read_string!(bytes).as_str())
412            .map_err(|_| DecodeError::InvalidTopicFilter(TopicParseError::TopicTooLong))?;
413        let payload = read_binary_data!(bytes);
414
415        Some(FinalWill {
416            topic,
417            payload,
418            qos: will_qos,
419            should_retain: retain_will,
420            will_delay_interval,
421            payload_format_indicator,
422            message_expiry_interval,
423            content_type,
424            response_topic,
425            correlation_data,
426            user_properties,
427        })
428    } else {
429        None
430    };
431
432    let mut user_name = None;
433    let mut password = None;
434
435    if has_user_name {
436        user_name = Some(read_string!(bytes));
437    }
438
439    if has_password {
440        password = Some(read_string!(bytes));
441    }
442
443    let packet = ConnectPacket {
444        protocol_name,
445        protocol_version,
446        clean_start,
447        keep_alive,
448        session_expiry_interval,
449        receive_maximum,
450        maximum_packet_size,
451        topic_alias_maximum,
452        request_response_information,
453        request_problem_information,
454        user_properties,
455        authentication_method,
456        authentication_data,
457        client_id,
458        will,
459        user_name,
460        password,
461    };
462
463    Ok(Some(Packet::Connect(packet)))
464}
465
466fn decode_connect_ack(
467    bytes: &mut Cursor<&mut BytesMut>,
468    protocol_version: ProtocolVersion,
469) -> Result<Option<Packet>, DecodeError> {
470    let flags = read_u8!(bytes);
471    let session_present = (flags & 0b0000_0001) == 0b0000_0001;
472
473    let reason_code_byte = read_u8!(bytes);
474    let reason_code =
475        ConnectReason::try_from(reason_code_byte).map_err(|_| DecodeError::InvalidConnectReason)?;
476
477    let mut session_expiry_interval = None;
478    let mut receive_maximum = None;
479    let mut maximum_qos = None;
480    let mut retain_available = None;
481    let mut maximum_packet_size = None;
482    let mut assigned_client_identifier = None;
483    let mut topic_alias_maximum = None;
484    let mut reason_string = None;
485    let mut user_properties = vec![];
486    let mut wildcard_subscription_available = None;
487    let mut subscription_identifiers_available = None;
488    let mut shared_subscription_available = None;
489    let mut server_keep_alive = None;
490    let mut response_information = None;
491    let mut server_reference = None;
492    let mut authentication_method = None;
493    let mut authentication_data = None;
494
495    if protocol_version == ProtocolVersion::V500 {
496        return_if_none!(decode_properties(bytes, |property| {
497            match property {
498                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
499                Property::ReceiveMaximum(p) => receive_maximum = Some(p),
500                Property::MaximumQos(p) => maximum_qos = Some(p),
501                Property::RetainAvailable(p) => retain_available = Some(p),
502                Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
503                Property::AssignedClientIdentifier(p) => assigned_client_identifier = Some(p),
504                Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
505                Property::ReasonString(p) => reason_string = Some(p),
506                Property::UserProperty(p) => user_properties.push(p),
507                Property::WildcardSubscriptionAvailable(p) => {
508                    wildcard_subscription_available = Some(p)
509                },
510                Property::SubscriptionIdentifierAvailable(p) => {
511                    subscription_identifiers_available = Some(p)
512                },
513                Property::SharedSubscriptionAvailable(p) => shared_subscription_available = Some(p),
514                Property::ServerKeepAlive(p) => server_keep_alive = Some(p),
515                Property::ResponseInformation(p) => response_information = Some(p),
516                Property::ServerReference(p) => server_reference = Some(p),
517                Property::AuthenticationMethod(p) => authentication_method = Some(p),
518                Property::AuthenticationData(p) => authentication_data = Some(p),
519                _ => {}, // Invalid property for packet
520            }
521        })?);
522    }
523
524    let packet = ConnectAckPacket {
525        session_present,
526        reason_code,
527        session_expiry_interval,
528        receive_maximum,
529        maximum_qos,
530        retain_available,
531        maximum_packet_size,
532        assigned_client_identifier,
533        topic_alias_maximum,
534        reason_string,
535        user_properties,
536        wildcard_subscription_available,
537        subscription_identifiers_available,
538        shared_subscription_available,
539        server_keep_alive,
540        response_information,
541        server_reference,
542        authentication_method,
543        authentication_data,
544    };
545
546    Ok(Some(Packet::ConnectAck(packet)))
547}
548
549fn decode_publish(
550    bytes: &mut Cursor<&mut BytesMut>,
551    first_byte: u8,
552    remaining_packet_length: u32,
553    protocol_version: ProtocolVersion,
554) -> Result<Option<Packet>, DecodeError> {
555    let is_duplicate = (first_byte & 0b0000_1000) == 0b0000_1000;
556    let qos_val = (first_byte & 0b0000_0110) >> 1;
557    let qos = QoS::try_from(qos_val).map_err(|_| DecodeError::InvalidQoS)?;
558    let retain = (first_byte & 0b0000_0001) == 0b0000_0001;
559
560    // Variable header start
561    let start_cursor_pos = bytes.position();
562
563    let topic_str = read_string!(bytes);
564    let topic = topic_str.parse().map_err(DecodeError::InvalidTopic)?;
565
566    let packet_id = match qos {
567        QoS::AtMostOnce => None,
568        QoS::AtLeastOnce | QoS::ExactlyOnce => Some(read_u16!(bytes)),
569    };
570
571    let mut payload_format_indicator = None;
572    let mut message_expiry_interval = None;
573    let mut topic_alias = None;
574    let mut response_topic = None;
575    let mut correlation_data = None;
576    let mut user_properties = vec![];
577    let mut subscription_identifiers = None;
578    let mut content_type = None;
579
580    if protocol_version == ProtocolVersion::V500 {
581        try_decode_properties(bytes, |property| match property {
582            Property::PayloadFormatIndicator(p) => {
583                payload_format_indicator = Some(p);
584                Ok(())
585            },
586            Property::MessageExpiryInterval(p) => {
587                message_expiry_interval = Some(p);
588                Ok(())
589            },
590            Property::TopicAlias(p) => {
591                topic_alias = Some(p);
592                Ok(())
593            },
594            Property::ResponseTopic(p) => {
595                response_topic = Some(p);
596                Ok(())
597            },
598            Property::CorrelationData(p) => {
599                correlation_data = Some(p);
600                Ok(())
601            },
602            Property::UserProperty(p) => {
603                user_properties.push(p);
604                Ok(())
605            },
606            Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => {
607                Err(DecodeError::InvalidSubscriptionIdentifier)
608            },
609            Property::SubscriptionIdentifier(p) => {
610                subscription_identifiers.get_or_insert(Vec::new()).push(p);
611                Ok(())
612            },
613            Property::ContentType(p) => {
614                content_type = Some(p);
615                Ok(())
616            },
617            _ => Err(DecodeError::InvalidPropertyForPacket),
618        })?;
619    }
620
621    let end_cursor_pos = bytes.position();
622    let variable_header_size = (end_cursor_pos - start_cursor_pos) as u32;
623    // Variable header end
624
625    if remaining_packet_length < variable_header_size {
626        return Err(DecodeError::InvalidRemainingLength);
627    }
628    let payload_size = remaining_packet_length - variable_header_size;
629    let payload = return_if_none!(decode_binary_data_with_size(bytes, payload_size as usize)?);
630    let subscription_identifiers =
631        subscription_identifiers.unwrap_or_else(|| Vec::with_capacity(0));
632
633    let packet = PublishPacket {
634        is_duplicate,
635        qos,
636        retain,
637
638        topic,
639        packet_id,
640
641        payload_format_indicator,
642        message_expiry_interval,
643        topic_alias,
644        response_topic,
645        correlation_data,
646        user_properties,
647        subscription_identifiers,
648        content_type,
649
650        payload,
651    };
652
653    Ok(Some(Packet::Publish(packet)))
654}
655
656fn decode_publish_ack(
657    bytes: &mut Cursor<&mut BytesMut>,
658    remaining_packet_length: u32,
659    protocol_version: ProtocolVersion,
660) -> Result<Option<Packet>, DecodeError> {
661    let packet_id = read_u16!(bytes);
662
663    if remaining_packet_length == 2 {
664        return Ok(Some(Packet::PublishAck(PublishAckPacket {
665            packet_id,
666            reason_code: PublishAckReason::Success,
667            reason_string: None,
668            user_properties: vec![],
669        })));
670    }
671
672    let reason_code_byte = read_u8!(bytes);
673    let reason_code = PublishAckReason::try_from(reason_code_byte)
674        .map_err(|_| DecodeError::InvalidPublishAckReason)?;
675
676    let mut reason_string = None;
677    let mut user_properties = vec![];
678
679    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
680        return_if_none!(decode_properties(bytes, |property| {
681            match property {
682                Property::ReasonString(p) => reason_string = Some(p),
683                Property::UserProperty(p) => user_properties.push(p),
684                _ => {}, // Invalid property for packet
685            }
686        })?);
687    }
688
689    let packet = PublishAckPacket { packet_id, reason_code, reason_string, user_properties };
690
691    Ok(Some(Packet::PublishAck(packet)))
692}
693
694fn decode_publish_received(
695    bytes: &mut Cursor<&mut BytesMut>,
696    remaining_packet_length: u32,
697    protocol_version: ProtocolVersion,
698) -> Result<Option<Packet>, DecodeError> {
699    let packet_id = read_u16!(bytes);
700
701    if remaining_packet_length == 2 {
702        return Ok(Some(Packet::PublishReceived(PublishReceivedPacket {
703            packet_id,
704            reason_code: PublishReceivedReason::Success,
705            reason_string: None,
706            user_properties: vec![],
707        })));
708    }
709
710    let reason_code_byte = read_u8!(bytes);
711    let reason_code = PublishReceivedReason::try_from(reason_code_byte)
712        .map_err(|_| DecodeError::InvalidPublishReceivedReason)?;
713
714    let mut reason_string = None;
715    let mut user_properties = vec![];
716
717    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
718        return_if_none!(decode_properties(bytes, |property| {
719            match property {
720                Property::ReasonString(p) => reason_string = Some(p),
721                Property::UserProperty(p) => user_properties.push(p),
722                _ => {}, // Invalid property for packet
723            }
724        })?);
725    }
726
727    let packet = PublishReceivedPacket { packet_id, reason_code, reason_string, user_properties };
728
729    Ok(Some(Packet::PublishReceived(packet)))
730}
731
732fn decode_publish_release(
733    bytes: &mut Cursor<&mut BytesMut>,
734    remaining_packet_length: u32,
735    protocol_version: ProtocolVersion,
736) -> Result<Option<Packet>, DecodeError> {
737    let packet_id = read_u16!(bytes);
738
739    if remaining_packet_length == 2 {
740        return Ok(Some(Packet::PublishRelease(PublishReleasePacket {
741            packet_id,
742            reason_code: PublishReleaseReason::Success,
743            reason_string: None,
744            user_properties: vec![],
745        })));
746    }
747
748    let reason_code_byte = read_u8!(bytes);
749    let reason_code = PublishReleaseReason::try_from(reason_code_byte)
750        .map_err(|_| DecodeError::InvalidPublishReleaseReason)?;
751
752    let mut reason_string = None;
753    let mut user_properties = vec![];
754
755    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
756        return_if_none!(decode_properties(bytes, |property| {
757            match property {
758                Property::ReasonString(p) => reason_string = Some(p),
759                Property::UserProperty(p) => user_properties.push(p),
760                _ => {}, // Invalid property for packet
761            }
762        })?);
763    }
764
765    let packet = PublishReleasePacket { packet_id, reason_code, reason_string, user_properties };
766
767    Ok(Some(Packet::PublishRelease(packet)))
768}
769
770fn decode_publish_complete(
771    bytes: &mut Cursor<&mut BytesMut>,
772    remaining_packet_length: u32,
773    protocol_version: ProtocolVersion,
774) -> Result<Option<Packet>, DecodeError> {
775    let packet_id = read_u16!(bytes);
776
777    if remaining_packet_length == 2 {
778        return Ok(Some(Packet::PublishComplete(PublishCompletePacket {
779            packet_id,
780            reason_code: PublishCompleteReason::Success,
781            reason_string: None,
782            user_properties: vec![],
783        })));
784    }
785
786    let reason_code_byte = read_u8!(bytes);
787    let reason_code = PublishCompleteReason::try_from(reason_code_byte)
788        .map_err(|_| DecodeError::InvalidPublishCompleteReason)?;
789
790    let mut reason_string = None;
791    let mut user_properties = vec![];
792
793    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
794        return_if_none!(decode_properties(bytes, |property| {
795            match property {
796                Property::ReasonString(p) => reason_string = Some(p),
797                Property::UserProperty(p) => user_properties.push(p),
798                _ => {}, // Invalid property for packet
799            }
800        })?);
801    }
802
803    let packet = PublishCompletePacket { packet_id, reason_code, reason_string, user_properties };
804
805    Ok(Some(Packet::PublishComplete(packet)))
806}
807
808fn decode_subscribe(
809    bytes: &mut Cursor<&mut BytesMut>,
810    remaining_packet_length: u32,
811    protocol_version: ProtocolVersion,
812) -> Result<Option<Packet>, DecodeError> {
813    let start_cursor_pos = bytes.position();
814
815    let packet_id = read_u16!(bytes);
816
817    let mut subscription_identifier = None;
818    let mut user_properties = vec![];
819
820    if protocol_version == ProtocolVersion::V500 {
821        try_decode_properties(bytes, |property| {
822            match property {
823                // [MQTT-3.8.2.1.2] The subscription identifier is allowed exactly once
824                Property::SubscriptionIdentifier(_) if subscription_identifier.is_some() => {
825                    Err(DecodeError::InvalidSubscriptionIdentifier)
826                },
827                // [MQTT-3.8.2.1.2] The subscription identifier must not be 0
828                Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => {
829                    Err(DecodeError::InvalidSubscriptionIdentifier)
830                },
831                Property::SubscriptionIdentifier(p) => {
832                    subscription_identifier = Some(p);
833                    Ok(())
834                },
835                Property::UserProperty(p) => {
836                    user_properties.push(p);
837                    Ok(())
838                },
839                _ => Err(DecodeError::InvalidPropertyForPacket),
840            }
841        })?;
842    }
843
844    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
845    if remaining_packet_length < variable_header_size {
846        return Err(DecodeError::InvalidRemainingLength);
847    }
848    let payload_size = remaining_packet_length - variable_header_size;
849
850    let mut subscription_topics = vec![];
851    let mut bytes_read: usize = 0;
852
853    loop {
854        if bytes_read >= payload_size as usize {
855            break;
856        }
857
858        let start_cursor_pos = bytes.position();
859
860        let topic_filter_str = read_string!(bytes);
861        let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
862
863        let options_byte = read_u8!(bytes);
864
865        let maximum_qos_val = options_byte & 0b0000_0011;
866        let maximum_qos = QoS::try_from(maximum_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
867
868        let retain_handling_val = (options_byte & 0b0011_0000) >> 4;
869        let retain_handling = RetainHandling::try_from(retain_handling_val)
870            .map_err(|_| DecodeError::InvalidRetainHandling)?;
871
872        let retain_as_published = (options_byte & 0b0000_1000) == 0b0000_1000;
873        let no_local = (options_byte & 0b0000_0100) == 0b0000_0100;
874
875        let subscription_topic = SubscriptionTopic {
876            topic_filter,
877            maximum_qos,
878            no_local,
879            retain_as_published,
880            retain_handling,
881        };
882
883        subscription_topics.push(subscription_topic);
884
885        let end_cursor_pos = bytes.position();
886        bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
887    }
888
889    let packet = SubscribePacket {
890        packet_id,
891        subscription_identifier,
892        user_properties,
893        subscription_topics,
894    };
895
896    Ok(Some(Packet::Subscribe(packet)))
897}
898
899fn decode_subscribe_ack(
900    bytes: &mut Cursor<&mut BytesMut>,
901    remaining_packet_length: u32,
902    protocol_version: ProtocolVersion,
903) -> Result<Option<Packet>, DecodeError> {
904    let start_cursor_pos = bytes.position();
905
906    let packet_id = read_u16!(bytes);
907
908    let mut reason_string = None;
909    let mut user_properties = vec![];
910
911    if protocol_version == ProtocolVersion::V500 {
912        return_if_none!(decode_properties(bytes, |property| {
913            match property {
914                Property::ReasonString(p) => reason_string = Some(p),
915                Property::UserProperty(p) => user_properties.push(p),
916                _ => {}, // Invalid property for packet
917            }
918        })?);
919    }
920
921    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
922    if remaining_packet_length < variable_header_size {
923        return Err(DecodeError::InvalidRemainingLength);
924    }
925    let payload_size = remaining_packet_length - variable_header_size;
926
927    let mut reason_codes = vec![];
928    for _ in 0..payload_size {
929        let next_byte = read_u8!(bytes);
930        let reason_code = SubscribeAckReason::try_from(next_byte)
931            .map_err(|_| DecodeError::InvalidSubscribeAckReason)?;
932        reason_codes.push(reason_code);
933    }
934
935    let packet = SubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
936
937    Ok(Some(Packet::SubscribeAck(packet)))
938}
939
940fn decode_unsubscribe(
941    bytes: &mut Cursor<&mut BytesMut>,
942    remaining_packet_length: u32,
943    protocol_version: ProtocolVersion,
944) -> Result<Option<Packet>, DecodeError> {
945    let start_cursor_pos = bytes.position();
946
947    let packet_id = read_u16!(bytes);
948
949    let mut user_properties = vec![];
950
951    if protocol_version == ProtocolVersion::V500 {
952        return_if_none!(decode_properties(bytes, |property| {
953            if let Property::UserProperty(p) = property {
954                user_properties.push(p);
955            }
956        })?);
957    }
958
959    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
960    if remaining_packet_length < variable_header_size {
961        return Err(DecodeError::InvalidRemainingLength);
962    }
963    let payload_size = remaining_packet_length - variable_header_size;
964
965    let mut topic_filters = vec![];
966    let mut bytes_read: usize = 0;
967
968    loop {
969        if bytes_read >= payload_size as usize {
970            break;
971        }
972
973        let start_cursor_pos = bytes.position();
974
975        let topic_filter_str = read_string!(bytes);
976        let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
977        topic_filters.push(topic_filter);
978
979        let end_cursor_pos = bytes.position();
980        bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
981    }
982
983    let packet = UnsubscribePacket { packet_id, user_properties, topic_filters };
984
985    Ok(Some(Packet::Unsubscribe(packet)))
986}
987
988fn decode_unsubscribe_ack(
989    bytes: &mut Cursor<&mut BytesMut>,
990    remaining_packet_length: u32,
991    protocol_version: ProtocolVersion,
992) -> Result<Option<Packet>, DecodeError> {
993    let start_cursor_pos = bytes.position();
994
995    let packet_id = read_u16!(bytes);
996
997    let mut reason_string = None;
998    let mut user_properties = vec![];
999
1000    if protocol_version == ProtocolVersion::V500 {
1001        return_if_none!(decode_properties(bytes, |property| {
1002            match property {
1003                Property::ReasonString(p) => reason_string = Some(p),
1004                Property::UserProperty(p) => user_properties.push(p),
1005                _ => {}, // Invalid property for packet
1006            }
1007        })?);
1008    }
1009
1010    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
1011    if remaining_packet_length < variable_header_size {
1012        return Err(DecodeError::InvalidRemainingLength);
1013    }
1014    let payload_size = remaining_packet_length - variable_header_size;
1015
1016    let mut reason_codes = vec![];
1017    for _ in 0..payload_size {
1018        let next_byte = read_u8!(bytes);
1019        let reason_code = UnsubscribeAckReason::try_from(next_byte)
1020            .map_err(|_| DecodeError::InvalidUnsubscribeAckReason)?;
1021        reason_codes.push(reason_code);
1022    }
1023
1024    let packet = UnsubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
1025
1026    Ok(Some(Packet::UnsubscribeAck(packet)))
1027}
1028
1029fn decode_disconnect(
1030    bytes: &mut Cursor<&mut BytesMut>,
1031    remaining_packet_length: u32,
1032    protocol_version: ProtocolVersion,
1033) -> Result<Option<Packet>, DecodeError> {
1034    if remaining_packet_length == 0 {
1035        return Ok(Some(Packet::Disconnect(DisconnectPacket {
1036            reason_code: DisconnectReason::NormalDisconnection,
1037            session_expiry_interval: None,
1038            reason_string: None,
1039            user_properties: vec![],
1040            server_reference: None,
1041        })));
1042    }
1043
1044    let reason_code_byte = read_u8!(bytes);
1045    let reason_code = DisconnectReason::try_from(reason_code_byte)
1046        .map_err(|_| DecodeError::InvalidDisconnectReason)?;
1047
1048    let mut session_expiry_interval = None;
1049    let mut reason_string = None;
1050    let mut user_properties = vec![];
1051    let mut server_reference = None;
1052
1053    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1054        return_if_none!(decode_properties(bytes, |property| {
1055            match property {
1056                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
1057                Property::ReasonString(p) => reason_string = Some(p),
1058                Property::UserProperty(p) => user_properties.push(p),
1059                Property::ServerReference(p) => server_reference = Some(p),
1060                _ => {}, // Invalid property for packet
1061            }
1062        })?);
1063    }
1064
1065    let packet = DisconnectPacket {
1066        reason_code,
1067        session_expiry_interval,
1068        reason_string,
1069        user_properties,
1070        server_reference,
1071    };
1072
1073    Ok(Some(Packet::Disconnect(packet)))
1074}
1075
1076fn decode_authenticate(
1077    bytes: &mut Cursor<&mut BytesMut>,
1078    remaining_packet_length: u32,
1079    protocol_version: ProtocolVersion,
1080) -> Result<Option<Packet>, DecodeError> {
1081    if remaining_packet_length == 0 {
1082        return Ok(Some(Packet::Authenticate(AuthenticatePacket {
1083            reason_code: AuthenticateReason::Success,
1084            authentication_method: None,
1085            authentication_data: None,
1086            reason_string: None,
1087            user_properties: vec![],
1088        })));
1089    }
1090
1091    let reason_code_byte = read_u8!(bytes);
1092    let reason_code = AuthenticateReason::try_from(reason_code_byte)
1093        .map_err(|_| DecodeError::InvalidAuthenticateReason)?;
1094
1095    let mut authentication_method = None;
1096    let mut authentication_data = None;
1097    let mut reason_string = None;
1098    let mut user_properties = vec![];
1099
1100    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1101        return_if_none!(decode_properties(bytes, |property| {
1102            match property {
1103                Property::AuthenticationMethod(p) => authentication_method = Some(p),
1104                Property::AuthenticationData(p) => authentication_data = Some(p),
1105                Property::ReasonString(p) => reason_string = Some(p),
1106                Property::UserProperty(p) => user_properties.push(p),
1107                _ => {}, // Invalid property for packet
1108            }
1109        })?);
1110    }
1111
1112    let packet = AuthenticatePacket {
1113        reason_code,
1114        authentication_method,
1115        authentication_data,
1116        reason_string,
1117        user_properties,
1118    };
1119
1120    Ok(Some(Packet::Authenticate(packet)))
1121}
1122
1123fn decode_packet(
1124    protocol_version: ProtocolVersion,
1125    packet_type: &PacketType,
1126    bytes: &mut Cursor<&mut BytesMut>,
1127    remaining_packet_length: u32,
1128    first_byte: u8,
1129) -> Result<Option<Packet>, DecodeError> {
1130    match packet_type {
1131        PacketType::Connect => decode_connect(bytes),
1132        PacketType::ConnectAck => decode_connect_ack(bytes, protocol_version),
1133        PacketType::Publish => {
1134            decode_publish(bytes, first_byte, remaining_packet_length, protocol_version)
1135        },
1136        PacketType::PublishAck => {
1137            decode_publish_ack(bytes, remaining_packet_length, protocol_version)
1138        },
1139        PacketType::PublishReceived => {
1140            decode_publish_received(bytes, remaining_packet_length, protocol_version)
1141        },
1142        PacketType::PublishRelease => {
1143            decode_publish_release(bytes, remaining_packet_length, protocol_version)
1144        },
1145        PacketType::PublishComplete => {
1146            decode_publish_complete(bytes, remaining_packet_length, protocol_version)
1147        },
1148        PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length, protocol_version),
1149        PacketType::SubscribeAck => {
1150            decode_subscribe_ack(bytes, remaining_packet_length, protocol_version)
1151        },
1152        PacketType::Unsubscribe => {
1153            decode_unsubscribe(bytes, remaining_packet_length, protocol_version)
1154        },
1155        PacketType::UnsubscribeAck => {
1156            decode_unsubscribe_ack(bytes, remaining_packet_length, protocol_version)
1157        },
1158        PacketType::PingRequest => Ok(Some(Packet::PingRequest)),
1159        PacketType::PingResponse => Ok(Some(Packet::PingResponse)),
1160        PacketType::Disconnect => {
1161            decode_disconnect(bytes, remaining_packet_length, protocol_version)
1162        },
1163        PacketType::Authenticate => {
1164            decode_authenticate(bytes, remaining_packet_length, protocol_version)
1165        },
1166    }
1167}
1168
1169pub fn decode_mqtt(
1170    bytes: &mut BytesMut,
1171    protocol_version: ProtocolVersion,
1172) -> Result<Option<Packet>, DecodeError> {
1173    let mut bytes = Cursor::new(bytes);
1174    let first_byte = read_u8!(bytes);
1175
1176    let first_byte_val = (first_byte & 0b1111_0000) >> 4;
1177    let packet_type =
1178        PacketType::try_from(first_byte_val).map_err(|_| DecodeError::InvalidPacketType)?;
1179    let remaining_packet_length = read_variable_int!(&mut bytes);
1180
1181    let cursor_pos = bytes.position() as usize;
1182    let remaining_buffer_amount = bytes.get_ref().len() - cursor_pos;
1183
1184    if remaining_buffer_amount < remaining_packet_length as usize {
1185        // If we don't have the full payload, just bail
1186        return Ok(None);
1187    }
1188
1189    let packet = return_if_none!(decode_packet(
1190        protocol_version,
1191        &packet_type,
1192        &mut bytes,
1193        remaining_packet_length,
1194        first_byte
1195    )?);
1196
1197    let cursor_pos = bytes.position() as usize;
1198    let bytes = bytes.into_inner();
1199
1200    let _rest = bytes.split_to(cursor_pos);
1201
1202    Ok(Some(packet))
1203}
1204
#[cfg(test)]
mod tests {
    use crate::{decoder::*, topic::TopicFilter, types::*};
    use bytes::BytesMut;

    /// Regression input found by fuzzing. It is a SUBSCRIBE (0x88) whose remaining length
    /// (1) is shorter than its own variable header (2-byte packet id + 1-byte property
    /// length). It must be rejected with `InvalidRemainingLength`, not accepted or panicked on.
    #[test]
    fn test_invalid_remaining_length() {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[136, 1, 0, 36, 0, 0]); // Discovered from fuzz test

        let result = decode_mqtt(&mut bytes, ProtocolVersion::V500);

        assert!(matches!(result, Err(DecodeError::InvalidRemainingLength)));
    }

    /// Boundary values for each encoded width (1-4 bytes) of a variable byte integer.
    #[test]
    fn test_decode_variable_int() {
        // TODO - Maybe it would be better to add an abnormal system test.

        fn normal_test(encoded_variable_int: &[u8], expected_variable_int: u32) {
            let bytes = &mut BytesMut::new();
            bytes.extend_from_slice(encoded_variable_int);
            match decode_variable_int(&mut Cursor::new(bytes)) {
                Ok(val) => match val {
                    Some(get_variable_int) => assert_eq!(get_variable_int, expected_variable_int),
                    None => panic!("variable_int is None"),
                },
                Err(err) => panic!("Error decoding variable int: {:?}", err),
            }
        }

        // Digits 1
        normal_test(&[0x00], 0);
        normal_test(&[0x7F], 127);

        // Digits 2
        normal_test(&[0x80, 0x01], 128);
        normal_test(&[0xFF, 0x7F], 16383);

        // Digits 3
        normal_test(&[0x80, 0x80, 0x01], 16384);
        normal_test(&[0xFF, 0xFF, 0x7F], 2097151);

        // Digits 4
        normal_test(&[0x80, 0x80, 0x80, 0x01], 2097152);
        normal_test(&[0xFF, 0xFF, 0xFF, 0x7F], 268435455);
    }

    /// SUBSCRIBE decoding with and without a Subscription Identifier property.
    #[test]
    fn test_decode_subscribe() {
        // Subscribe packet *without* Subscription Identifier
        let mut without_subscription_identifier = BytesMut::from(
            [0x82, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00].as_slice(),
        );
        let without_subscription_identifier_expected = Packet::Subscribe(SubscribePacket {
            packet_id: 1,
            subscription_identifier: None,
            user_properties: vec![],
            subscription_topics: vec![SubscriptionTopic {
                topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 },
                maximum_qos: QoS::AtMostOnce,
                no_local: false,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeTime,
            }],
        });
        let decoded = decode_mqtt(&mut without_subscription_identifier, ProtocolVersion::V500)
            .unwrap()
            .unwrap();
        assert_eq!(without_subscription_identifier_expected, decoded);

        // Subscribe packet with Subscription Identifier
        let mut packet = BytesMut::from(
            [0x82, 0x0c, 0xff, 0xf6, 0x02, 0x0b, 0x01, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x02]
                .as_slice(),
        );
        let decoded = decode_mqtt(&mut packet, ProtocolVersion::V500).unwrap().unwrap();
        let with_subscription_identifier_expected = Packet::Subscribe(SubscribePacket {
            packet_id: 65526,
            subscription_identifier: Some(SubscriptionIdentifier(VariableByteInt(1))),
            user_properties: vec![],
            subscription_topics: vec![SubscriptionTopic {
                topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 },
                maximum_qos: QoS::ExactlyOnce,
                no_local: false,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeTime,
            }],
        });
        assert_eq!(with_subscription_identifier_expected, decoded);
    }

    /// Four continuation bytes (0xFF each) exceed the maximum variable-int width and must
    /// be reported as an error.
    #[test]
    fn test_decode_variable_int_crash() {
        let number: u32 = u32::MAX;
        let result = decode_variable_int(&mut Cursor::new(&mut BytesMut::from(
            number.to_be_bytes().as_slice(),
        )));

        assert!(matches!(result, Err(DecodeError::InvalidRemainingLength)));
    }
}