// mqtt_v5/decoder.rs

1use crate::types::{
2    properties::*, AuthenticatePacket, AuthenticateReason, ConnectAckPacket, ConnectPacket,
3    ConnectReason, DecodeError, DisconnectPacket, DisconnectReason, FinalWill, Packet, PacketType,
4    ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket,
5    PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason,
6    PublishReleasePacket, PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket,
7    SubscribeAckReason, SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket,
8    UnsubscribeAckReason, UnsubscribePacket, VariableByteInt,
9};
10use bytes::{Buf, Bytes, BytesMut};
11use std::{convert::TryFrom, io::Cursor};
12
// Unwraps an `Option` expression, or makes the enclosing function return
// `Ok(None)` when it is `None`. The readers below use `Ok(None)` to mean
// "not enough bytes buffered yet", so this forwards that signal to the caller.
macro_rules! return_if_none {
    ($x: expr) => {{
        let maybe_value = $x;
        match maybe_value {
            Some(value) => value,
            None => return Ok(None),
        }
    }};
}
23
// Makes the enclosing function return `Ok(None)` unless at least `$len` unread
// bytes remain in `$bytes`; otherwise expands to `()` and decoding continues.
macro_rules! require_length {
    ($bytes: expr, $len: expr) => {{
        let enough_buffered = $bytes.remaining() >= $len;
        if !enough_buffered {
            return Ok(None);
        }
    }};
}
31
// Reads one byte from `$bytes`, or makes the enclosing function return
// `Ok(None)` when no bytes remain.
macro_rules! read_u8 {
    ($bytes: expr) => {{
        if $bytes.has_remaining() {
            $bytes.get_u8()
        } else {
            return Ok(None)
        }
    }};
}
41
// Reads a `u16` from `$bytes` via `get_u16`, or makes the enclosing function
// return `Ok(None)` when fewer than two bytes remain.
macro_rules! read_u16 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= ::std::mem::size_of::<u16>() {
            $bytes.get_u16()
        } else {
            return Ok(None)
        }
    }};
}
51
// Reads a `u32` from `$bytes` via `get_u32`, or makes the enclosing function
// return `Ok(None)` when fewer than four bytes remain.
macro_rules! read_u32 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= ::std::mem::size_of::<u32>() {
            $bytes.get_u32()
        } else {
            return Ok(None)
        }
    }};
}
61
// Decodes a variable byte integer with `decode_variable_int`, propagating its
// errors with `?` and returning `Ok(None)` when more bytes are needed.
macro_rules! read_variable_int {
    ($bytes: expr) => {{
        let decoded = decode_variable_int($bytes)?;
        return_if_none!(decoded)
    }};
}
67
// Decodes a length-prefixed UTF-8 string with `decode_string`, propagating its
// errors with `?` and returning `Ok(None)` when more bytes are needed.
macro_rules! read_string {
    ($bytes: expr) => {{
        let decoded = decode_string($bytes)?;
        return_if_none!(decoded)
    }};
}
73
// Decodes length-prefixed binary data with `decode_binary_data`, propagating
// its errors with `?` and returning `Ok(None)` when more bytes are needed.
macro_rules! read_binary_data {
    ($bytes: expr) => {{
        let decoded = decode_binary_data($bytes)?;
        return_if_none!(decoded)
    }};
}
79
// Reads two consecutive strings and yields them as a `(key, value)` tuple.
// Tuple operands evaluate left to right, so the key is read first; if either
// string is incomplete the enclosing function returns `Ok(None)`.
macro_rules! read_string_pair {
    ($bytes: expr) => {{
        (read_string!($bytes), read_string!($bytes))
    }};
}
88
// Reads a property identifier (as a variable byte integer) followed by that
// property's value, yielding the decoded `Property`.
macro_rules! read_property {
    ($bytes: expr) => {{
        let id = read_variable_int!($bytes);
        let decoded = decode_property(id, $bytes)?;
        return_if_none!(decoded)
    }};
}
95
96fn decode_variable_int(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<u32>, DecodeError> {
97    let mut multiplier = 1;
98    let mut value: u32 = 0;
99
100    loop {
101        let encoded_byte = read_u8!(bytes);
102
103        value += ((encoded_byte & 0b0111_1111) as u32) * multiplier;
104
105        if multiplier > (128 * 128 * 128) {
106            return Err(DecodeError::InvalidRemainingLength);
107        }
108
109        multiplier *= 128;
110
111        if encoded_byte & 0b1000_0000 == 0b0000_0000 {
112            break;
113        }
114    }
115
116    Ok(Some(value))
117}
118
119fn decode_string(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<String>, DecodeError> {
120    let str_size_bytes = read_u16!(bytes) as usize;
121
122    require_length!(bytes, str_size_bytes);
123
124    let position = bytes.position() as usize;
125
126    // TODO - Use Cow<str> and from_utf8_lossy later for less copying
127    match String::from_utf8(bytes.get_ref()[position..(position + str_size_bytes)].into()) {
128        Ok(string) => {
129            bytes.advance(str_size_bytes);
130            Ok(Some(string))
131        },
132        Err(_) => Err(DecodeError::InvalidUtf8),
133    }
134}
135
136fn decode_binary_data(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Bytes>, DecodeError> {
137    let data_size_bytes = read_u16!(bytes) as usize;
138    require_length!(bytes, data_size_bytes);
139
140    let position = bytes.position() as usize;
141
142    let payload_bytes =
143        BytesMut::from(&bytes.get_ref()[position..(position + data_size_bytes)]).freeze();
144    let result = Ok(Some(payload_bytes));
145    bytes.advance(data_size_bytes);
146
147    result
148}
149
150fn decode_binary_data_with_size(
151    bytes: &mut Cursor<&mut BytesMut>,
152    size: usize,
153) -> Result<Option<Bytes>, DecodeError> {
154    require_length!(bytes, size);
155
156    let position = bytes.position() as usize;
157    let payload_bytes = BytesMut::from(&bytes.get_ref()[position..(position + size)]).freeze();
158    let result = Ok(Some(payload_bytes));
159    bytes.advance(size);
160
161    result
162}
163
164fn decode_property(
165    property_id: u32,
166    bytes: &mut Cursor<&mut BytesMut>,
167) -> Result<Option<Property>, DecodeError> {
168    let property_type =
169        PropertyType::try_from(property_id).map_err(|_| DecodeError::InvalidPropertyId)?;
170
171    match property_type {
172        PropertyType::PayloadFormatIndicator => {
173            let format_indicator = read_u8!(bytes);
174            Ok(Some(Property::PayloadFormatIndicator(PayloadFormatIndicator(format_indicator))))
175        },
176        PropertyType::MessageExpiryInterval => {
177            let message_expiry_interval = read_u32!(bytes);
178            Ok(Some(Property::MessageExpiryInterval(MessageExpiryInterval(
179                message_expiry_interval,
180            ))))
181        },
182        PropertyType::ContentType => {
183            let content_type = read_string!(bytes);
184            Ok(Some(Property::ContentType(ContentType(content_type))))
185        },
186        PropertyType::ResponseTopic => {
187            let response_topic = read_string!(bytes);
188            Ok(Some(Property::ResponseTopic(ResponseTopic(response_topic))))
189        },
190        PropertyType::CorrelationData => {
191            let correlation_data = read_binary_data!(bytes);
192            Ok(Some(Property::CorrelationData(CorrelationData(correlation_data))))
193        },
194        PropertyType::SubscriptionIdentifier => {
195            let subscription_identifier = read_u32!(bytes);
196            Ok(Some(Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(
197                subscription_identifier,
198            )))))
199        },
200        PropertyType::SessionExpiryInterval => {
201            let session_expiry_interval = read_u32!(bytes);
202            Ok(Some(Property::SessionExpiryInterval(SessionExpiryInterval(
203                session_expiry_interval,
204            ))))
205        },
206        PropertyType::AssignedClientIdentifier => {
207            let assigned_client_identifier = read_string!(bytes);
208            Ok(Some(Property::AssignedClientIdentifier(AssignedClientIdentifier(
209                assigned_client_identifier,
210            ))))
211        },
212        PropertyType::ServerKeepAlive => {
213            let server_keep_alive = read_u16!(bytes);
214            Ok(Some(Property::ServerKeepAlive(ServerKeepAlive(server_keep_alive))))
215        },
216        PropertyType::AuthenticationMethod => {
217            let authentication_method = read_string!(bytes);
218            Ok(Some(Property::AuthenticationMethod(AuthenticationMethod(authentication_method))))
219        },
220        PropertyType::AuthenticationData => {
221            let authentication_data = read_binary_data!(bytes);
222            Ok(Some(Property::AuthenticationData(AuthenticationData(authentication_data))))
223        },
224        PropertyType::RequestProblemInformation => {
225            let request_problem_information = read_u8!(bytes);
226            Ok(Some(Property::RequestProblemInformation(RequestProblemInformation(
227                request_problem_information,
228            ))))
229        },
230        PropertyType::WillDelayInterval => {
231            let will_delay_interval = read_u32!(bytes);
232            Ok(Some(Property::WillDelayInterval(WillDelayInterval(will_delay_interval))))
233        },
234        PropertyType::RequestResponseInformation => {
235            let request_response_information = read_u8!(bytes);
236            Ok(Some(Property::RequestResponseInformation(RequestResponseInformation(
237                request_response_information,
238            ))))
239        },
240        PropertyType::ResponseInformation => {
241            let response_information = read_string!(bytes);
242            Ok(Some(Property::ResponseInformation(ResponseInformation(response_information))))
243        },
244        PropertyType::ServerReference => {
245            let server_reference = read_string!(bytes);
246            Ok(Some(Property::ServerReference(ServerReference(server_reference))))
247        },
248        PropertyType::ReasonString => {
249            let reason_string = read_string!(bytes);
250            Ok(Some(Property::ReasonString(ReasonString(reason_string))))
251        },
252        PropertyType::ReceiveMaximum => {
253            let receive_maximum = read_u16!(bytes);
254            Ok(Some(Property::ReceiveMaximum(ReceiveMaximum(receive_maximum))))
255        },
256        PropertyType::TopicAliasMaximum => {
257            let topic_alias_maximum = read_u16!(bytes);
258            Ok(Some(Property::TopicAliasMaximum(TopicAliasMaximum(topic_alias_maximum))))
259        },
260        PropertyType::TopicAlias => {
261            let topic_alias = read_u16!(bytes);
262            Ok(Some(Property::TopicAlias(TopicAlias(topic_alias))))
263        },
264        PropertyType::MaximumQos => {
265            let qos_byte = read_u8!(bytes);
266            let qos = QoS::try_from(qos_byte).map_err(|_| DecodeError::InvalidQoS)?;
267
268            Ok(Some(Property::MaximumQos(MaximumQos(qos))))
269        },
270        PropertyType::RetainAvailable => {
271            let retain_available = read_u8!(bytes);
272            Ok(Some(Property::RetainAvailable(RetainAvailable(retain_available))))
273        },
274        PropertyType::UserProperty => {
275            let (key, value) = read_string_pair!(bytes);
276            Ok(Some(Property::UserProperty(UserProperty(key, value))))
277        },
278        PropertyType::MaximumPacketSize => {
279            let maximum_packet_size = read_u32!(bytes);
280            Ok(Some(Property::MaximumPacketSize(MaximumPacketSize(maximum_packet_size))))
281        },
282        PropertyType::WildcardSubscriptionAvailable => {
283            let wildcard_subscription_available = read_u8!(bytes);
284            Ok(Some(Property::WildcardSubscriptionAvailable(WildcardSubscriptionAvailable(
285                wildcard_subscription_available,
286            ))))
287        },
288        PropertyType::SubscriptionIdentifierAvailable => {
289            let subscription_identifier_available = read_u8!(bytes);
290            Ok(Some(Property::SubscriptionIdentifierAvailable(SubscriptionIdentifierAvailable(
291                subscription_identifier_available,
292            ))))
293        },
294        PropertyType::SharedSubscriptionAvailable => {
295            let shared_subscription_available = read_u8!(bytes);
296            Ok(Some(Property::SharedSubscriptionAvailable(SharedSubscriptionAvailable(
297                shared_subscription_available,
298            ))))
299        },
300    }
301}
302
303fn decode_properties<F: FnMut(Property)>(
304    bytes: &mut Cursor<&mut BytesMut>,
305    mut closure: F,
306) -> Result<Option<()>, DecodeError> {
307    let property_length = read_variable_int!(bytes);
308
309    if property_length == 0 {
310        return Ok(Some(()));
311    }
312
313    require_length!(bytes, property_length as usize);
314
315    let start_cursor_pos = bytes.position();
316
317    loop {
318        let cursor_pos = bytes.position();
319
320        if cursor_pos - start_cursor_pos >= property_length as u64 {
321            break;
322        }
323
324        let property = read_property!(bytes);
325        closure(property);
326    }
327
328    Ok(Some(()))
329}
330
/// Decodes the variable header and payload of a CONNECT packet.
///
/// Fields are read in this order:
/// 1. Protocol name, protocol level, connect flags and keep alive.
/// 2. The CONNECT property list (MQTT v5 only).
/// 3. The client identifier.
/// 4. The will properties (v5 only), topic and payload, if the will flag is set.
/// 5. The user name and password, each only if its flag is set.
///
/// Returns `Ok(None)` when any field is not fully buffered.
/// Returns `Err(DecodeError::InvalidProtocolVersion)` for an unrecognised
/// protocol level and `Err(DecodeError::InvalidQoS)` for an invalid will QoS.
/// Errors from the field decoders are propagated.
fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, DecodeError> {
    let protocol_name = read_string!(bytes);
    let protocol_level = read_u8!(bytes);
    let connect_flags = read_u8!(bytes);
    let keep_alive = read_u16!(bytes);

    // The protocol level decides whether v5 property sections are present below.
    let protocol_version = ProtocolVersion::try_from(protocol_level)
        .map_err(|_| DecodeError::InvalidProtocolVersion)?;

    let mut session_expiry_interval = None;
    let mut receive_maximum = None;
    let mut maximum_packet_size = None;
    let mut topic_alias_maximum = None;
    let mut request_response_information = None;
    let mut request_problem_information = None;
    let mut user_properties = vec![];
    let mut authentication_method = None;
    let mut authentication_data = None;

    if protocol_version == ProtocolVersion::V500 {
        return_if_none!(decode_properties(bytes, |property| {
            match property {
                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
                Property::ReceiveMaximum(p) => receive_maximum = Some(p),
                Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
                Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
                Property::RequestResponseInformation(p) => request_response_information = Some(p),
                Property::RequestProblemInformation(p) => request_problem_information = Some(p),
                Property::UserProperty(p) => user_properties.push(p),
                Property::AuthenticationMethod(p) => authentication_method = Some(p),
                Property::AuthenticationData(p) => authentication_data = Some(p),
                _ => {}, // Not a CONNECT property; silently ignored.
            }
        })?);
    }

    // Unpack the connect flags byte read above: bit 1 = clean start, bit 2 = will,
    // bits 3-4 = will QoS, bit 5 = will retain, bit 6 = password, bit 7 = user name.
    // The will QoS is validated even when the will flag is clear.
    // NOTE(review): bit 0 is never inspected here; the MQTT spec reserves it and
    // requires it to be 0 — confirm whether that is validated elsewhere.
    let clean_start = connect_flags & 0b0000_0010 == 0b0000_0010;
    let has_will = connect_flags & 0b0000_0100 == 0b0000_0100;
    let will_qos_val = (connect_flags & 0b0001_1000) >> 3;
    let will_qos = QoS::try_from(will_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
    let retain_will = connect_flags & 0b0010_0000 == 0b0010_0000;
    let has_password = connect_flags & 0b0100_0000 == 0b0100_0000;
    let has_user_name = connect_flags & 0b1000_0000 == 0b1000_0000;

    // Start payload: the client identifier always comes first.
    let client_id = read_string!(bytes);

    let will = if has_will {
        let mut will_delay_interval = None;
        let mut payload_format_indicator = None;
        let mut message_expiry_interval = None;
        let mut content_type = None;
        let mut response_topic = None;
        let mut correlation_data = None;
        // Shadows the CONNECT-level `user_properties`; these belong to the will.
        let mut user_properties = vec![];

        if protocol_version == ProtocolVersion::V500 {
            return_if_none!(decode_properties(bytes, |property| {
                match property {
                    Property::WillDelayInterval(p) => will_delay_interval = Some(p),
                    Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
                    Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
                    Property::ContentType(p) => content_type = Some(p),
                    Property::ResponseTopic(p) => response_topic = Some(p),
                    Property::CorrelationData(p) => correlation_data = Some(p),
                    Property::UserProperty(p) => user_properties.push(p),
                    _ => {}, // Not a will property; silently ignored.
                }
            })?);
        }

        let topic = read_string!(bytes);
        let payload = read_binary_data!(bytes);

        Some(FinalWill {
            topic,
            payload,
            qos: will_qos,
            should_retain: retain_will,
            will_delay_interval,
            payload_format_indicator,
            message_expiry_interval,
            content_type,
            response_topic,
            correlation_data,
            user_properties,
        })
    } else {
        None
    };

    let mut user_name = None;
    let mut password = None;

    if has_user_name {
        user_name = Some(read_string!(bytes));
    }

    if has_password {
        // NOTE(review): the password is decoded as a UTF-8 string, so a non-UTF-8
        // password fails with `InvalidUtf8`. MQTT defines this field as binary
        // data — confirm this restriction is intended.
        password = Some(read_string!(bytes));
    }

    let packet = ConnectPacket {
        protocol_name,
        protocol_version,
        clean_start,
        keep_alive,
        session_expiry_interval,
        receive_maximum,
        maximum_packet_size,
        topic_alias_maximum,
        request_response_information,
        request_problem_information,
        user_properties,
        authentication_method,
        authentication_data,
        client_id,
        will,
        user_name,
        password,
    };

    Ok(Some(Packet::Connect(packet)))
}
455
456fn decode_connect_ack(
457    bytes: &mut Cursor<&mut BytesMut>,
458    protocol_version: ProtocolVersion,
459) -> Result<Option<Packet>, DecodeError> {
460    let flags = read_u8!(bytes);
461    let session_present = (flags & 0b0000_0001) == 0b0000_0001;
462
463    let reason_code_byte = read_u8!(bytes);
464    let reason_code =
465        ConnectReason::try_from(reason_code_byte).map_err(|_| DecodeError::InvalidConnectReason)?;
466
467    let mut session_expiry_interval = None;
468    let mut receive_maximum = None;
469    let mut maximum_qos = None;
470    let mut retain_available = None;
471    let mut maximum_packet_size = None;
472    let mut assigned_client_identifier = None;
473    let mut topic_alias_maximum = None;
474    let mut reason_string = None;
475    let mut user_properties = vec![];
476    let mut wildcard_subscription_available = None;
477    let mut subscription_identifiers_available = None;
478    let mut shared_subscription_available = None;
479    let mut server_keep_alive = None;
480    let mut response_information = None;
481    let mut server_reference = None;
482    let mut authentication_method = None;
483    let mut authentication_data = None;
484
485    if protocol_version == ProtocolVersion::V500 {
486        return_if_none!(decode_properties(bytes, |property| {
487            match property {
488                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
489                Property::ReceiveMaximum(p) => receive_maximum = Some(p),
490                Property::MaximumQos(p) => maximum_qos = Some(p),
491                Property::RetainAvailable(p) => retain_available = Some(p),
492                Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
493                Property::AssignedClientIdentifier(p) => assigned_client_identifier = Some(p),
494                Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
495                Property::ReasonString(p) => reason_string = Some(p),
496                Property::UserProperty(p) => user_properties.push(p),
497                Property::WildcardSubscriptionAvailable(p) => {
498                    wildcard_subscription_available = Some(p)
499                },
500                Property::SubscriptionIdentifierAvailable(p) => {
501                    subscription_identifiers_available = Some(p)
502                },
503                Property::SharedSubscriptionAvailable(p) => shared_subscription_available = Some(p),
504                Property::ServerKeepAlive(p) => server_keep_alive = Some(p),
505                Property::ResponseInformation(p) => response_information = Some(p),
506                Property::ServerReference(p) => server_reference = Some(p),
507                Property::AuthenticationMethod(p) => authentication_method = Some(p),
508                Property::AuthenticationData(p) => authentication_data = Some(p),
509                _ => {}, // Invalid property for packet
510            }
511        })?);
512    }
513
514    let packet = ConnectAckPacket {
515        session_present,
516        reason_code,
517        session_expiry_interval,
518        receive_maximum,
519        maximum_qos,
520        retain_available,
521        maximum_packet_size,
522        assigned_client_identifier,
523        topic_alias_maximum,
524        reason_string,
525        user_properties,
526        wildcard_subscription_available,
527        subscription_identifiers_available,
528        shared_subscription_available,
529        server_keep_alive,
530        response_information,
531        server_reference,
532        authentication_method,
533        authentication_data,
534    };
535
536    Ok(Some(Packet::ConnectAck(packet)))
537}
538
539fn decode_publish(
540    bytes: &mut Cursor<&mut BytesMut>,
541    first_byte: u8,
542    remaining_packet_length: u32,
543    protocol_version: ProtocolVersion,
544) -> Result<Option<Packet>, DecodeError> {
545    let is_duplicate = (first_byte & 0b0000_1000) == 0b0000_1000;
546    let qos_val = (first_byte & 0b0000_0110) >> 1;
547    let qos = QoS::try_from(qos_val).map_err(|_| DecodeError::InvalidQoS)?;
548    let retain = (first_byte & 0b0000_0001) == 0b0000_0001;
549
550    // Variable header start
551    let start_cursor_pos = bytes.position();
552
553    let topic_str = read_string!(bytes);
554    let topic = topic_str.parse().map_err(DecodeError::InvalidTopic)?;
555
556    let packet_id = match qos {
557        QoS::AtMostOnce => None,
558        QoS::AtLeastOnce | QoS::ExactlyOnce => Some(read_u16!(bytes)),
559    };
560
561    let mut payload_format_indicator = None;
562    let mut message_expiry_interval = None;
563    let mut topic_alias = None;
564    let mut response_topic = None;
565    let mut correlation_data = None;
566    let mut user_properties = vec![];
567    let mut subscription_identifier = None;
568    let mut content_type = None;
569
570    if protocol_version == ProtocolVersion::V500 {
571        return_if_none!(decode_properties(bytes, |property| {
572            match property {
573                Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
574                Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
575                Property::TopicAlias(p) => topic_alias = Some(p),
576                Property::ResponseTopic(p) => response_topic = Some(p),
577                Property::CorrelationData(p) => correlation_data = Some(p),
578                Property::UserProperty(p) => user_properties.push(p),
579                Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
580                Property::ContentType(p) => content_type = Some(p),
581                _ => {}, // Invalid property for packet
582            }
583        })?);
584    }
585
586    let end_cursor_pos = bytes.position();
587    let variable_header_size = (end_cursor_pos - start_cursor_pos) as u32;
588    // Variable header end
589
590    if remaining_packet_length < variable_header_size {
591        return Err(DecodeError::InvalidRemainingLength);
592    }
593    let payload_size = remaining_packet_length - variable_header_size;
594    let payload = return_if_none!(decode_binary_data_with_size(bytes, payload_size as usize)?);
595
596    let packet = PublishPacket {
597        is_duplicate,
598        qos,
599        retain,
600
601        topic,
602        packet_id,
603
604        payload_format_indicator,
605        message_expiry_interval,
606        topic_alias,
607        response_topic,
608        correlation_data,
609        user_properties,
610        subscription_identifier,
611        content_type,
612
613        payload,
614    };
615
616    Ok(Some(Packet::Publish(packet)))
617}
618
619fn decode_publish_ack(
620    bytes: &mut Cursor<&mut BytesMut>,
621    remaining_packet_length: u32,
622    protocol_version: ProtocolVersion,
623) -> Result<Option<Packet>, DecodeError> {
624    let packet_id = read_u16!(bytes);
625
626    if remaining_packet_length == 2 {
627        return Ok(Some(Packet::PublishAck(PublishAckPacket {
628            packet_id,
629            reason_code: PublishAckReason::Success,
630            reason_string: None,
631            user_properties: vec![],
632        })));
633    }
634
635    let reason_code_byte = read_u8!(bytes);
636    let reason_code = PublishAckReason::try_from(reason_code_byte)
637        .map_err(|_| DecodeError::InvalidPublishAckReason)?;
638
639    let mut reason_string = None;
640    let mut user_properties = vec![];
641
642    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
643        return_if_none!(decode_properties(bytes, |property| {
644            match property {
645                Property::ReasonString(p) => reason_string = Some(p),
646                Property::UserProperty(p) => user_properties.push(p),
647                _ => {}, // Invalid property for packet
648            }
649        })?);
650    }
651
652    let packet = PublishAckPacket { packet_id, reason_code, reason_string, user_properties };
653
654    Ok(Some(Packet::PublishAck(packet)))
655}
656
657fn decode_publish_received(
658    bytes: &mut Cursor<&mut BytesMut>,
659    remaining_packet_length: u32,
660    protocol_version: ProtocolVersion,
661) -> Result<Option<Packet>, DecodeError> {
662    let packet_id = read_u16!(bytes);
663
664    if remaining_packet_length == 2 {
665        return Ok(Some(Packet::PublishReceived(PublishReceivedPacket {
666            packet_id,
667            reason_code: PublishReceivedReason::Success,
668            reason_string: None,
669            user_properties: vec![],
670        })));
671    }
672
673    let reason_code_byte = read_u8!(bytes);
674    let reason_code = PublishReceivedReason::try_from(reason_code_byte)
675        .map_err(|_| DecodeError::InvalidPublishReceivedReason)?;
676
677    let mut reason_string = None;
678    let mut user_properties = vec![];
679
680    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
681        return_if_none!(decode_properties(bytes, |property| {
682            match property {
683                Property::ReasonString(p) => reason_string = Some(p),
684                Property::UserProperty(p) => user_properties.push(p),
685                _ => {}, // Invalid property for packet
686            }
687        })?);
688    }
689
690    let packet = PublishReceivedPacket { packet_id, reason_code, reason_string, user_properties };
691
692    Ok(Some(Packet::PublishReceived(packet)))
693}
694
695fn decode_publish_release(
696    bytes: &mut Cursor<&mut BytesMut>,
697    remaining_packet_length: u32,
698    protocol_version: ProtocolVersion,
699) -> Result<Option<Packet>, DecodeError> {
700    let packet_id = read_u16!(bytes);
701
702    if remaining_packet_length == 2 {
703        return Ok(Some(Packet::PublishRelease(PublishReleasePacket {
704            packet_id,
705            reason_code: PublishReleaseReason::Success,
706            reason_string: None,
707            user_properties: vec![],
708        })));
709    }
710
711    let reason_code_byte = read_u8!(bytes);
712    let reason_code = PublishReleaseReason::try_from(reason_code_byte)
713        .map_err(|_| DecodeError::InvalidPublishReleaseReason)?;
714
715    let mut reason_string = None;
716    let mut user_properties = vec![];
717
718    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
719        return_if_none!(decode_properties(bytes, |property| {
720            match property {
721                Property::ReasonString(p) => reason_string = Some(p),
722                Property::UserProperty(p) => user_properties.push(p),
723                _ => {}, // Invalid property for packet
724            }
725        })?);
726    }
727
728    let packet = PublishReleasePacket { packet_id, reason_code, reason_string, user_properties };
729
730    Ok(Some(Packet::PublishRelease(packet)))
731}
732
733fn decode_publish_complete(
734    bytes: &mut Cursor<&mut BytesMut>,
735    remaining_packet_length: u32,
736    protocol_version: ProtocolVersion,
737) -> Result<Option<Packet>, DecodeError> {
738    let packet_id = read_u16!(bytes);
739
740    if remaining_packet_length == 2 {
741        return Ok(Some(Packet::PublishComplete(PublishCompletePacket {
742            packet_id,
743            reason_code: PublishCompleteReason::Success,
744            reason_string: None,
745            user_properties: vec![],
746        })));
747    }
748
749    let reason_code_byte = read_u8!(bytes);
750    let reason_code = PublishCompleteReason::try_from(reason_code_byte)
751        .map_err(|_| DecodeError::InvalidPublishCompleteReason)?;
752
753    let mut reason_string = None;
754    let mut user_properties = vec![];
755
756    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
757        return_if_none!(decode_properties(bytes, |property| {
758            match property {
759                Property::ReasonString(p) => reason_string = Some(p),
760                Property::UserProperty(p) => user_properties.push(p),
761                _ => {}, // Invalid property for packet
762            }
763        })?);
764    }
765
766    let packet = PublishCompletePacket { packet_id, reason_code, reason_string, user_properties };
767
768    Ok(Some(Packet::PublishComplete(packet)))
769}
770
771fn decode_subscribe(
772    bytes: &mut Cursor<&mut BytesMut>,
773    remaining_packet_length: u32,
774    protocol_version: ProtocolVersion,
775) -> Result<Option<Packet>, DecodeError> {
776    let start_cursor_pos = bytes.position();
777
778    let packet_id = read_u16!(bytes);
779
780    let mut subscription_identifier = None;
781    let mut user_properties = vec![];
782
783    if protocol_version == ProtocolVersion::V500 {
784        return_if_none!(decode_properties(bytes, |property| {
785            match property {
786                Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
787                Property::UserProperty(p) => user_properties.push(p),
788                _ => {}, // Invalid property for packet
789            }
790        })?);
791    }
792
793    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
794    if remaining_packet_length < variable_header_size {
795        return Err(DecodeError::InvalidRemainingLength);
796    }
797    let payload_size = remaining_packet_length - variable_header_size;
798
799    let mut subscription_topics = vec![];
800    let mut bytes_read: usize = 0;
801
802    loop {
803        if bytes_read >= payload_size as usize {
804            break;
805        }
806
807        let start_cursor_pos = bytes.position();
808
809        let topic_filter_str = read_string!(bytes);
810        let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
811
812        let options_byte = read_u8!(bytes);
813
814        let maximum_qos_val = options_byte & 0b0000_0011;
815        let maximum_qos = QoS::try_from(maximum_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
816
817        let retain_handling_val = (options_byte & 0b0011_0000) >> 4;
818        let retain_handling = RetainHandling::try_from(retain_handling_val)
819            .map_err(|_| DecodeError::InvalidRetainHandling)?;
820
821        let retain_as_published = (options_byte & 0b0000_1000) == 0b0000_1000;
822        let no_local = (options_byte & 0b0000_0100) == 0b0000_0100;
823
824        let subscription_topic = SubscriptionTopic {
825            topic_filter,
826            maximum_qos,
827            no_local,
828            retain_as_published,
829            retain_handling,
830        };
831
832        subscription_topics.push(subscription_topic);
833
834        let end_cursor_pos = bytes.position();
835        bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
836    }
837
838    let packet = SubscribePacket {
839        packet_id,
840        subscription_identifier,
841        user_properties,
842        subscription_topics,
843    };
844
845    Ok(Some(Packet::Subscribe(packet)))
846}
847
848fn decode_subscribe_ack(
849    bytes: &mut Cursor<&mut BytesMut>,
850    remaining_packet_length: u32,
851    protocol_version: ProtocolVersion,
852) -> Result<Option<Packet>, DecodeError> {
853    let start_cursor_pos = bytes.position();
854
855    let packet_id = read_u16!(bytes);
856
857    let mut reason_string = None;
858    let mut user_properties = vec![];
859
860    if protocol_version == ProtocolVersion::V500 {
861        return_if_none!(decode_properties(bytes, |property| {
862            match property {
863                Property::ReasonString(p) => reason_string = Some(p),
864                Property::UserProperty(p) => user_properties.push(p),
865                _ => {}, // Invalid property for packet
866            }
867        })?);
868    }
869
870    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
871    if remaining_packet_length < variable_header_size {
872        return Err(DecodeError::InvalidRemainingLength);
873    }
874    let payload_size = remaining_packet_length - variable_header_size;
875
876    let mut reason_codes = vec![];
877    for _ in 0..payload_size {
878        let next_byte = read_u8!(bytes);
879        let reason_code = SubscribeAckReason::try_from(next_byte)
880            .map_err(|_| DecodeError::InvalidSubscribeAckReason)?;
881        reason_codes.push(reason_code);
882    }
883
884    let packet = SubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
885
886    Ok(Some(Packet::SubscribeAck(packet)))
887}
888
889fn decode_unsubscribe(
890    bytes: &mut Cursor<&mut BytesMut>,
891    remaining_packet_length: u32,
892    protocol_version: ProtocolVersion,
893) -> Result<Option<Packet>, DecodeError> {
894    let start_cursor_pos = bytes.position();
895
896    let packet_id = read_u16!(bytes);
897
898    let mut user_properties = vec![];
899
900    if protocol_version == ProtocolVersion::V500 {
901        return_if_none!(decode_properties(bytes, |property| {
902            if let Property::UserProperty(p) = property {
903                user_properties.push(p);
904            }
905        })?);
906    }
907
908    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
909    if remaining_packet_length < variable_header_size {
910        return Err(DecodeError::InvalidRemainingLength);
911    }
912    let payload_size = remaining_packet_length - variable_header_size;
913
914    let mut topic_filters = vec![];
915    let mut bytes_read: usize = 0;
916
917    loop {
918        if bytes_read >= payload_size as usize {
919            break;
920        }
921
922        let start_cursor_pos = bytes.position();
923
924        let topic_filter_str = read_string!(bytes);
925        let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
926        topic_filters.push(topic_filter);
927
928        let end_cursor_pos = bytes.position();
929        bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
930    }
931
932    let packet = UnsubscribePacket { packet_id, user_properties, topic_filters };
933
934    Ok(Some(Packet::Unsubscribe(packet)))
935}
936
937fn decode_unsubscribe_ack(
938    bytes: &mut Cursor<&mut BytesMut>,
939    remaining_packet_length: u32,
940    protocol_version: ProtocolVersion,
941) -> Result<Option<Packet>, DecodeError> {
942    let start_cursor_pos = bytes.position();
943
944    let packet_id = read_u16!(bytes);
945
946    let mut reason_string = None;
947    let mut user_properties = vec![];
948
949    if protocol_version == ProtocolVersion::V500 {
950        return_if_none!(decode_properties(bytes, |property| {
951            match property {
952                Property::ReasonString(p) => reason_string = Some(p),
953                Property::UserProperty(p) => user_properties.push(p),
954                _ => {}, // Invalid property for packet
955            }
956        })?);
957    }
958
959    let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
960    if remaining_packet_length < variable_header_size {
961        return Err(DecodeError::InvalidRemainingLength);
962    }
963    let payload_size = remaining_packet_length - variable_header_size;
964
965    let mut reason_codes = vec![];
966    for _ in 0..payload_size {
967        let next_byte = read_u8!(bytes);
968        let reason_code = UnsubscribeAckReason::try_from(next_byte)
969            .map_err(|_| DecodeError::InvalidUnsubscribeAckReason)?;
970        reason_codes.push(reason_code);
971    }
972
973    let packet = UnsubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
974
975    Ok(Some(Packet::UnsubscribeAck(packet)))
976}
977
978fn decode_disconnect(
979    bytes: &mut Cursor<&mut BytesMut>,
980    remaining_packet_length: u32,
981    protocol_version: ProtocolVersion,
982) -> Result<Option<Packet>, DecodeError> {
983    if remaining_packet_length == 0 {
984        return Ok(Some(Packet::Disconnect(DisconnectPacket {
985            reason_code: DisconnectReason::NormalDisconnection,
986            session_expiry_interval: None,
987            reason_string: None,
988            user_properties: vec![],
989            server_reference: None,
990        })));
991    }
992
993    let reason_code_byte = read_u8!(bytes);
994    let reason_code = DisconnectReason::try_from(reason_code_byte)
995        .map_err(|_| DecodeError::InvalidDisconnectReason)?;
996
997    let mut session_expiry_interval = None;
998    let mut reason_string = None;
999    let mut user_properties = vec![];
1000    let mut server_reference = None;
1001
1002    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1003        return_if_none!(decode_properties(bytes, |property| {
1004            match property {
1005                Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
1006                Property::ReasonString(p) => reason_string = Some(p),
1007                Property::UserProperty(p) => user_properties.push(p),
1008                Property::ServerReference(p) => server_reference = Some(p),
1009                _ => {}, // Invalid property for packet
1010            }
1011        })?);
1012    }
1013
1014    let packet = DisconnectPacket {
1015        reason_code,
1016        session_expiry_interval,
1017        reason_string,
1018        user_properties,
1019        server_reference,
1020    };
1021
1022    Ok(Some(Packet::Disconnect(packet)))
1023}
1024
1025fn decode_authenticate(
1026    bytes: &mut Cursor<&mut BytesMut>,
1027    remaining_packet_length: u32,
1028    protocol_version: ProtocolVersion,
1029) -> Result<Option<Packet>, DecodeError> {
1030    if remaining_packet_length == 0 {
1031        return Ok(Some(Packet::Authenticate(AuthenticatePacket {
1032            reason_code: AuthenticateReason::Success,
1033            authentication_method: None,
1034            authentication_data: None,
1035            reason_string: None,
1036            user_properties: vec![],
1037        })));
1038    }
1039
1040    let reason_code_byte = read_u8!(bytes);
1041    let reason_code = AuthenticateReason::try_from(reason_code_byte)
1042        .map_err(|_| DecodeError::InvalidAuthenticateReason)?;
1043
1044    let mut authentication_method = None;
1045    let mut authentication_data = None;
1046    let mut reason_string = None;
1047    let mut user_properties = vec![];
1048
1049    if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1050        return_if_none!(decode_properties(bytes, |property| {
1051            match property {
1052                Property::AuthenticationMethod(p) => authentication_method = Some(p),
1053                Property::AuthenticationData(p) => authentication_data = Some(p),
1054                Property::ReasonString(p) => reason_string = Some(p),
1055                Property::UserProperty(p) => user_properties.push(p),
1056                _ => {}, // Invalid property for packet
1057            }
1058        })?);
1059    }
1060
1061    let packet = AuthenticatePacket {
1062        reason_code,
1063        authentication_method,
1064        authentication_data,
1065        reason_string,
1066        user_properties,
1067    };
1068
1069    Ok(Some(Packet::Authenticate(packet)))
1070}
1071
1072fn decode_packet(
1073    protocol_version: ProtocolVersion,
1074    packet_type: &PacketType,
1075    bytes: &mut Cursor<&mut BytesMut>,
1076    remaining_packet_length: u32,
1077    first_byte: u8,
1078) -> Result<Option<Packet>, DecodeError> {
1079    match packet_type {
1080        PacketType::Connect => decode_connect(bytes),
1081        PacketType::ConnectAck => decode_connect_ack(bytes, protocol_version),
1082        PacketType::Publish => {
1083            decode_publish(bytes, first_byte, remaining_packet_length, protocol_version)
1084        },
1085        PacketType::PublishAck => {
1086            decode_publish_ack(bytes, remaining_packet_length, protocol_version)
1087        },
1088        PacketType::PublishReceived => {
1089            decode_publish_received(bytes, remaining_packet_length, protocol_version)
1090        },
1091        PacketType::PublishRelease => {
1092            decode_publish_release(bytes, remaining_packet_length, protocol_version)
1093        },
1094        PacketType::PublishComplete => {
1095            decode_publish_complete(bytes, remaining_packet_length, protocol_version)
1096        },
1097        PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length, protocol_version),
1098        PacketType::SubscribeAck => {
1099            decode_subscribe_ack(bytes, remaining_packet_length, protocol_version)
1100        },
1101        PacketType::Unsubscribe => {
1102            decode_unsubscribe(bytes, remaining_packet_length, protocol_version)
1103        },
1104        PacketType::UnsubscribeAck => {
1105            decode_unsubscribe_ack(bytes, remaining_packet_length, protocol_version)
1106        },
1107        PacketType::PingRequest => Ok(Some(Packet::PingRequest)),
1108        PacketType::PingResponse => Ok(Some(Packet::PingResponse)),
1109        PacketType::Disconnect => {
1110            decode_disconnect(bytes, remaining_packet_length, protocol_version)
1111        },
1112        PacketType::Authenticate => {
1113            decode_authenticate(bytes, remaining_packet_length, protocol_version)
1114        },
1115    }
1116}
1117
1118pub fn decode_mqtt(
1119    bytes: &mut BytesMut,
1120    protocol_version: ProtocolVersion,
1121) -> Result<Option<Packet>, DecodeError> {
1122    let mut bytes = Cursor::new(bytes);
1123    let first_byte = read_u8!(bytes);
1124
1125    let first_byte_val = (first_byte & 0b1111_0000) >> 4;
1126    let packet_type =
1127        PacketType::try_from(first_byte_val).map_err(|_| DecodeError::InvalidPacketType)?;
1128    let remaining_packet_length = read_variable_int!(&mut bytes);
1129
1130    let cursor_pos = bytes.position() as usize;
1131    let remaining_buffer_amount = bytes.get_ref().len() - cursor_pos;
1132
1133    if remaining_buffer_amount < remaining_packet_length as usize {
1134        // If we don't have the full payload, just bail
1135        return Ok(None);
1136    }
1137
1138    let packet = return_if_none!(decode_packet(
1139        protocol_version,
1140        &packet_type,
1141        &mut bytes,
1142        remaining_packet_length,
1143        first_byte
1144    )?);
1145
1146    let cursor_pos = bytes.position() as usize;
1147    let bytes = bytes.into_inner();
1148
1149    let _rest = bytes.split_to(cursor_pos);
1150
1151    Ok(Some(packet))
1152}
1153
#[cfg(test)]
mod tests {
    use crate::{decoder::*, types::*};
    use bytes::BytesMut;

    #[test]
    fn test_invalid_remaining_length() {
        // Discovered from fuzz test: a SUBSCRIBE whose declared remaining length (1)
        // is shorter than its own variable header. It must be rejected, not panic.
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[136, 1, 0, 36, 0, 0]);

        assert!(decode_mqtt(&mut bytes, ProtocolVersion::V500).is_err());
    }

    #[test]
    fn test_decode_variable_int() {
        // TODO: add tests for abnormal input (truncated or over-long encodings).

        fn normal_test(encoded_variable_int: &[u8], expected_variable_int: u32) {
            let mut bytes = BytesMut::new();
            bytes.extend_from_slice(encoded_variable_int);

            match decode_variable_int(&mut Cursor::new(&mut bytes)) {
                Ok(Some(decoded)) => assert_eq!(decoded, expected_variable_int),
                Ok(None) => panic!("variable_int is None"),
                // A formatted panic works on every edition; `panic!(err)` is
                // deprecated in 2018 and rejected in 2021.
                Err(err) => panic!("{:?}", err),
            }
        }

        // 1-byte encodings
        normal_test(&[0x00], 0);
        normal_test(&[0x7F], 127);

        // 2-byte encodings
        normal_test(&[0x80, 0x01], 128);
        normal_test(&[0xFF, 0x7F], 16383);

        // 3-byte encodings
        normal_test(&[0x80, 0x80, 0x01], 16384);
        normal_test(&[0xFF, 0xFF, 0x7F], 2097151);

        // 4-byte encodings
        normal_test(&[0x80, 0x80, 0x80, 0x01], 2097152);
        normal_test(&[0xFF, 0xFF, 0xFF, 0x7F], 268435455);
    }
}