1use crate::{
2 topic::{Topic, TopicParseError},
3 types::{
4 properties::*, AuthenticatePacket, AuthenticateReason, ConnectAckPacket, ConnectPacket,
5 ConnectReason, DecodeError, DisconnectPacket, DisconnectReason, FinalWill, Packet,
6 PacketType, ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket,
7 PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason,
8 PublishReleasePacket, PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket,
9 SubscribeAckReason, SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket,
10 UnsubscribeAckReason, UnsubscribePacket, VariableByteInt,
11 },
12};
13use bytes::{Buf, Bytes, BytesMut};
14use std::{convert::TryFrom, io::Cursor, str::FromStr};
15
/// Unwraps an `Option`, or returns `Ok(None)` from the enclosing function when it is `None`.
///
/// The decoders use `Ok(None)` to mean "not enough bytes are buffered yet", so this macro
/// propagates that signal the way `?` propagates errors.
macro_rules! return_if_none {
    ($x: expr) => {{
        match $x {
            Some(value) => value,
            None => return Ok(None),
        }
    }};
}
26
/// Returns `Ok(None)` from the enclosing function unless at least `$len` bytes remain unread
/// in `$bytes`.
macro_rules! require_length {
    ($bytes: expr, $len: expr) => {{
        let available = $bytes.remaining();
        if available < $len {
            return Ok(None);
        }
    }};
}
34
/// Reads one byte from `$bytes`, or returns `Ok(None)` from the enclosing function if the
/// buffer is exhausted.
macro_rules! read_u8 {
    ($bytes: expr) => {{
        if $bytes.has_remaining() {
            $bytes.get_u8()
        } else {
            return Ok(None);
        }
    }};
}
44
/// Reads a two-byte integer via `get_u16`, or returns `Ok(None)` from the enclosing function
/// if fewer than two bytes remain.
macro_rules! read_u16 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 2 {
            $bytes.get_u16()
        } else {
            return Ok(None);
        }
    }};
}
54
/// Reads a four-byte integer via `get_u32`, or returns `Ok(None)` from the enclosing function
/// if fewer than four bytes remain.
macro_rules! read_u32 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 4 {
            $bytes.get_u32()
        } else {
            return Ok(None);
        }
    }};
}
64
/// Decodes a variable byte integer with `decode_variable_int`, propagating errors with `?`
/// and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_variable_int {
    ($bytes: expr) => {{
        match decode_variable_int($bytes)? {
            Some(decoded) => decoded,
            None => return Ok(None),
        }
    }};
}
70
/// Decodes a length-prefixed UTF-8 string with `decode_string`, propagating errors with `?`
/// and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_string {
    ($bytes: expr) => {{
        match decode_string($bytes)? {
            Some(decoded) => decoded,
            None => return Ok(None),
        }
    }};
}
76
/// Decodes length-prefixed binary data with `decode_binary_data`, propagating errors with `?`
/// and returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_binary_data {
    ($bytes: expr) => {{
        match decode_binary_data($bytes)? {
            Some(decoded) => decoded,
            None => return Ok(None),
        }
    }};
}
82
/// Reads two consecutive strings (key first, then value) and yields them as a tuple.
macro_rules! read_string_pair {
    ($bytes: expr) => {{
        // Tuple operands are evaluated left to right, so the key is read before the value.
        (read_string!($bytes), read_string!($bytes))
    }};
}
91
/// Reads a property identifier (a variable byte integer) and then the property value it
/// selects, returning `Ok(None)` from the enclosing function if more bytes are needed.
macro_rules! read_property {
    ($bytes: expr) => {{
        let id = read_variable_int!($bytes);
        let decoded = decode_property(id, $bytes)?;
        return_if_none!(decoded)
    }};
}
98
99fn decode_variable_int(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<u32>, DecodeError> {
100 let mut multiplier: u32 = 1;
101 let mut value: u32 = 0;
102
103 for _ in 0..4 {
104 let encoded_byte = read_u8!(bytes);
105
106 value += ((encoded_byte & 0b0111_1111) as u32) * multiplier;
107
108 multiplier *= 128;
109
110 if encoded_byte & 0b1000_0000 == 0b0000_0000 {
111 return Ok(Some(value));
112 }
113 }
114
115 Err(DecodeError::InvalidRemainingLength)
116}
117
118fn decode_string(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<String>, DecodeError> {
119 let str_size_bytes = read_u16!(bytes) as usize;
120
121 require_length!(bytes, str_size_bytes);
122
123 let position = bytes.position() as usize;
124
125 match String::from_utf8(bytes.get_ref()[position..(position + str_size_bytes)].into()) {
127 Ok(string) => {
128 bytes.advance(str_size_bytes);
129 Ok(Some(string))
130 },
131 Err(_) => Err(DecodeError::InvalidUtf8),
132 }
133}
134
135fn decode_binary_data(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Bytes>, DecodeError> {
136 let data_size_bytes = read_u16!(bytes) as usize;
137 require_length!(bytes, data_size_bytes);
138
139 let position = bytes.position() as usize;
140
141 let payload_bytes =
142 BytesMut::from(&bytes.get_ref()[position..(position + data_size_bytes)]).freeze();
143 let result = Ok(Some(payload_bytes));
144 bytes.advance(data_size_bytes);
145
146 result
147}
148
149fn decode_binary_data_with_size(
150 bytes: &mut Cursor<&mut BytesMut>,
151 size: usize,
152) -> Result<Option<Bytes>, DecodeError> {
153 require_length!(bytes, size);
154
155 let position = bytes.position() as usize;
156 let payload_bytes = BytesMut::from(&bytes.get_ref()[position..(position + size)]).freeze();
157 let result = Ok(Some(payload_bytes));
158 bytes.advance(size);
159
160 result
161}
162
163fn decode_property(
164 property_id: u32,
165 bytes: &mut Cursor<&mut BytesMut>,
166) -> Result<Option<Property>, DecodeError> {
167 let property_type =
168 PropertyType::try_from(property_id).map_err(|_| DecodeError::InvalidPropertyId)?;
169
170 match property_type {
171 PropertyType::PayloadFormatIndicator => {
172 let format_indicator = read_u8!(bytes);
173 Ok(Some(Property::PayloadFormatIndicator(PayloadFormatIndicator(format_indicator))))
174 },
175 PropertyType::MessageExpiryInterval => {
176 let message_expiry_interval = read_u32!(bytes);
177 Ok(Some(Property::MessageExpiryInterval(MessageExpiryInterval(
178 message_expiry_interval,
179 ))))
180 },
181 PropertyType::ContentType => {
182 let content_type = read_string!(bytes);
183 Ok(Some(Property::ContentType(ContentType(content_type))))
184 },
185 PropertyType::ResponseTopic => {
186 let response_topic = read_string!(bytes);
187 Ok(Some(Property::ResponseTopic(ResponseTopic(response_topic))))
188 },
189 PropertyType::CorrelationData => {
190 let correlation_data = read_binary_data!(bytes);
191 Ok(Some(Property::CorrelationData(CorrelationData(correlation_data))))
192 },
193 PropertyType::SubscriptionIdentifier => {
194 let subscription_identifier = read_variable_int!(bytes);
195 Ok(Some(Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(
196 subscription_identifier,
197 )))))
198 },
199 PropertyType::SessionExpiryInterval => {
200 let session_expiry_interval = read_u32!(bytes);
201 Ok(Some(Property::SessionExpiryInterval(SessionExpiryInterval(
202 session_expiry_interval,
203 ))))
204 },
205 PropertyType::AssignedClientIdentifier => {
206 let assigned_client_identifier = read_string!(bytes);
207 Ok(Some(Property::AssignedClientIdentifier(AssignedClientIdentifier(
208 assigned_client_identifier,
209 ))))
210 },
211 PropertyType::ServerKeepAlive => {
212 let server_keep_alive = read_u16!(bytes);
213 Ok(Some(Property::ServerKeepAlive(ServerKeepAlive(server_keep_alive))))
214 },
215 PropertyType::AuthenticationMethod => {
216 let authentication_method = read_string!(bytes);
217 Ok(Some(Property::AuthenticationMethod(AuthenticationMethod(authentication_method))))
218 },
219 PropertyType::AuthenticationData => {
220 let authentication_data = read_binary_data!(bytes);
221 Ok(Some(Property::AuthenticationData(AuthenticationData(authentication_data))))
222 },
223 PropertyType::RequestProblemInformation => {
224 let request_problem_information = read_u8!(bytes);
225 Ok(Some(Property::RequestProblemInformation(RequestProblemInformation(
226 request_problem_information,
227 ))))
228 },
229 PropertyType::WillDelayInterval => {
230 let will_delay_interval = read_u32!(bytes);
231 Ok(Some(Property::WillDelayInterval(WillDelayInterval(will_delay_interval))))
232 },
233 PropertyType::RequestResponseInformation => {
234 let request_response_information = read_u8!(bytes);
235 Ok(Some(Property::RequestResponseInformation(RequestResponseInformation(
236 request_response_information,
237 ))))
238 },
239 PropertyType::ResponseInformation => {
240 let response_information = read_string!(bytes);
241 Ok(Some(Property::ResponseInformation(ResponseInformation(response_information))))
242 },
243 PropertyType::ServerReference => {
244 let server_reference = read_string!(bytes);
245 Ok(Some(Property::ServerReference(ServerReference(server_reference))))
246 },
247 PropertyType::ReasonString => {
248 let reason_string = read_string!(bytes);
249 Ok(Some(Property::ReasonString(ReasonString(reason_string))))
250 },
251 PropertyType::ReceiveMaximum => {
252 let receive_maximum = read_u16!(bytes);
253 Ok(Some(Property::ReceiveMaximum(ReceiveMaximum(receive_maximum))))
254 },
255 PropertyType::TopicAliasMaximum => {
256 let topic_alias_maximum = read_u16!(bytes);
257 Ok(Some(Property::TopicAliasMaximum(TopicAliasMaximum(topic_alias_maximum))))
258 },
259 PropertyType::TopicAlias => {
260 let topic_alias = read_u16!(bytes);
261 Ok(Some(Property::TopicAlias(TopicAlias(topic_alias))))
262 },
263 PropertyType::MaximumQos => {
264 let qos_byte = read_u8!(bytes);
265 let qos = QoS::try_from(qos_byte).map_err(|_| DecodeError::InvalidQoS)?;
266
267 Ok(Some(Property::MaximumQos(MaximumQos(qos))))
268 },
269 PropertyType::RetainAvailable => {
270 let retain_available = read_u8!(bytes);
271 Ok(Some(Property::RetainAvailable(RetainAvailable(retain_available))))
272 },
273 PropertyType::UserProperty => {
274 let (key, value) = read_string_pair!(bytes);
275 Ok(Some(Property::UserProperty(UserProperty(key, value))))
276 },
277 PropertyType::MaximumPacketSize => {
278 let maximum_packet_size = read_u32!(bytes);
279 Ok(Some(Property::MaximumPacketSize(MaximumPacketSize(maximum_packet_size))))
280 },
281 PropertyType::WildcardSubscriptionAvailable => {
282 let wildcard_subscription_available = read_u8!(bytes);
283 Ok(Some(Property::WildcardSubscriptionAvailable(WildcardSubscriptionAvailable(
284 wildcard_subscription_available,
285 ))))
286 },
287 PropertyType::SubscriptionIdentifierAvailable => {
288 let subscription_identifier_available = read_u8!(bytes);
289 Ok(Some(Property::SubscriptionIdentifierAvailable(SubscriptionIdentifierAvailable(
290 subscription_identifier_available,
291 ))))
292 },
293 PropertyType::SharedSubscriptionAvailable => {
294 let shared_subscription_available = read_u8!(bytes);
295 Ok(Some(Property::SharedSubscriptionAvailable(SharedSubscriptionAvailable(
296 shared_subscription_available,
297 ))))
298 },
299 }
300}
301
302fn decode_properties<F: FnMut(Property)>(
303 bytes: &mut Cursor<&mut BytesMut>,
304 mut closure: F,
305) -> Result<Option<()>, DecodeError> {
306 try_decode_properties(bytes, |property| {
307 closure(property);
308 Ok(())
309 })
310}
311
312fn try_decode_properties<F: FnMut(Property) -> Result<(), DecodeError>>(
313 bytes: &mut Cursor<&mut BytesMut>,
314 mut closure: F,
315) -> Result<Option<()>, DecodeError> {
316 let property_length = read_variable_int!(bytes);
317
318 if property_length == 0 {
319 return Ok(Some(()));
320 }
321
322 require_length!(bytes, property_length as usize);
323
324 let start_cursor_pos = bytes.position();
325
326 loop {
327 let cursor_pos = bytes.position();
328
329 if cursor_pos - start_cursor_pos >= property_length as u64 {
330 break;
331 }
332
333 let property = read_property!(bytes);
334 closure(property)?;
335 }
336
337 Ok(Some(()))
338}
339
340fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, DecodeError> {
341 let protocol_name = read_string!(bytes);
342 let protocol_level = read_u8!(bytes);
343 let connect_flags = read_u8!(bytes);
344 let keep_alive = read_u16!(bytes);
345
346 let protocol_version = ProtocolVersion::try_from(protocol_level)
347 .map_err(|_| DecodeError::InvalidProtocolVersion)?;
348
349 let mut session_expiry_interval = None;
350 let mut receive_maximum = None;
351 let mut maximum_packet_size = None;
352 let mut topic_alias_maximum = None;
353 let mut request_response_information = None;
354 let mut request_problem_information = None;
355 let mut user_properties = vec![];
356 let mut authentication_method = None;
357 let mut authentication_data = None;
358
359 if protocol_version == ProtocolVersion::V500 {
360 return_if_none!(decode_properties(bytes, |property| {
361 match property {
362 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
363 Property::ReceiveMaximum(p) => receive_maximum = Some(p),
364 Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
365 Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
366 Property::RequestResponseInformation(p) => request_response_information = Some(p),
367 Property::RequestProblemInformation(p) => request_problem_information = Some(p),
368 Property::UserProperty(p) => user_properties.push(p),
369 Property::AuthenticationMethod(p) => authentication_method = Some(p),
370 Property::AuthenticationData(p) => authentication_data = Some(p),
371 _ => {}, }
373 })?);
374 }
375
376 let clean_start = connect_flags & 0b0000_0010 == 0b0000_0010;
378 let has_will = connect_flags & 0b0000_0100 == 0b0000_0100;
379 let will_qos_val = (connect_flags & 0b0001_1000) >> 3;
380 let will_qos = QoS::try_from(will_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
381 let retain_will = connect_flags & 0b0010_0000 == 0b0010_0000;
382 let has_password = connect_flags & 0b0100_0000 == 0b0100_0000;
383 let has_user_name = connect_flags & 0b1000_0000 == 0b1000_0000;
384
385 let client_id = read_string!(bytes);
386
387 let will = if has_will {
388 let mut will_delay_interval = None;
389 let mut payload_format_indicator = None;
390 let mut message_expiry_interval = None;
391 let mut content_type = None;
392 let mut response_topic = None;
393 let mut correlation_data = None;
394 let mut user_properties = vec![];
395
396 if protocol_version == ProtocolVersion::V500 {
397 return_if_none!(decode_properties(bytes, |property| {
398 match property {
399 Property::WillDelayInterval(p) => will_delay_interval = Some(p),
400 Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
401 Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
402 Property::ContentType(p) => content_type = Some(p),
403 Property::ResponseTopic(p) => response_topic = Some(p),
404 Property::CorrelationData(p) => correlation_data = Some(p),
405 Property::UserProperty(p) => user_properties.push(p),
406 _ => {}, }
408 })?);
409 }
410
411 let topic = Topic::from_str(read_string!(bytes).as_str())
412 .map_err(|_| DecodeError::InvalidTopicFilter(TopicParseError::TopicTooLong))?;
413 let payload = read_binary_data!(bytes);
414
415 Some(FinalWill {
416 topic,
417 payload,
418 qos: will_qos,
419 should_retain: retain_will,
420 will_delay_interval,
421 payload_format_indicator,
422 message_expiry_interval,
423 content_type,
424 response_topic,
425 correlation_data,
426 user_properties,
427 })
428 } else {
429 None
430 };
431
432 let mut user_name = None;
433 let mut password = None;
434
435 if has_user_name {
436 user_name = Some(read_string!(bytes));
437 }
438
439 if has_password {
440 password = Some(read_string!(bytes));
441 }
442
443 let packet = ConnectPacket {
444 protocol_name,
445 protocol_version,
446 clean_start,
447 keep_alive,
448 session_expiry_interval,
449 receive_maximum,
450 maximum_packet_size,
451 topic_alias_maximum,
452 request_response_information,
453 request_problem_information,
454 user_properties,
455 authentication_method,
456 authentication_data,
457 client_id,
458 will,
459 user_name,
460 password,
461 };
462
463 Ok(Some(Packet::Connect(packet)))
464}
465
466fn decode_connect_ack(
467 bytes: &mut Cursor<&mut BytesMut>,
468 protocol_version: ProtocolVersion,
469) -> Result<Option<Packet>, DecodeError> {
470 let flags = read_u8!(bytes);
471 let session_present = (flags & 0b0000_0001) == 0b0000_0001;
472
473 let reason_code_byte = read_u8!(bytes);
474 let reason_code =
475 ConnectReason::try_from(reason_code_byte).map_err(|_| DecodeError::InvalidConnectReason)?;
476
477 let mut session_expiry_interval = None;
478 let mut receive_maximum = None;
479 let mut maximum_qos = None;
480 let mut retain_available = None;
481 let mut maximum_packet_size = None;
482 let mut assigned_client_identifier = None;
483 let mut topic_alias_maximum = None;
484 let mut reason_string = None;
485 let mut user_properties = vec![];
486 let mut wildcard_subscription_available = None;
487 let mut subscription_identifiers_available = None;
488 let mut shared_subscription_available = None;
489 let mut server_keep_alive = None;
490 let mut response_information = None;
491 let mut server_reference = None;
492 let mut authentication_method = None;
493 let mut authentication_data = None;
494
495 if protocol_version == ProtocolVersion::V500 {
496 return_if_none!(decode_properties(bytes, |property| {
497 match property {
498 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
499 Property::ReceiveMaximum(p) => receive_maximum = Some(p),
500 Property::MaximumQos(p) => maximum_qos = Some(p),
501 Property::RetainAvailable(p) => retain_available = Some(p),
502 Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
503 Property::AssignedClientIdentifier(p) => assigned_client_identifier = Some(p),
504 Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
505 Property::ReasonString(p) => reason_string = Some(p),
506 Property::UserProperty(p) => user_properties.push(p),
507 Property::WildcardSubscriptionAvailable(p) => {
508 wildcard_subscription_available = Some(p)
509 },
510 Property::SubscriptionIdentifierAvailable(p) => {
511 subscription_identifiers_available = Some(p)
512 },
513 Property::SharedSubscriptionAvailable(p) => shared_subscription_available = Some(p),
514 Property::ServerKeepAlive(p) => server_keep_alive = Some(p),
515 Property::ResponseInformation(p) => response_information = Some(p),
516 Property::ServerReference(p) => server_reference = Some(p),
517 Property::AuthenticationMethod(p) => authentication_method = Some(p),
518 Property::AuthenticationData(p) => authentication_data = Some(p),
519 _ => {}, }
521 })?);
522 }
523
524 let packet = ConnectAckPacket {
525 session_present,
526 reason_code,
527 session_expiry_interval,
528 receive_maximum,
529 maximum_qos,
530 retain_available,
531 maximum_packet_size,
532 assigned_client_identifier,
533 topic_alias_maximum,
534 reason_string,
535 user_properties,
536 wildcard_subscription_available,
537 subscription_identifiers_available,
538 shared_subscription_available,
539 server_keep_alive,
540 response_information,
541 server_reference,
542 authentication_method,
543 authentication_data,
544 };
545
546 Ok(Some(Packet::ConnectAck(packet)))
547}
548
549fn decode_publish(
550 bytes: &mut Cursor<&mut BytesMut>,
551 first_byte: u8,
552 remaining_packet_length: u32,
553 protocol_version: ProtocolVersion,
554) -> Result<Option<Packet>, DecodeError> {
555 let is_duplicate = (first_byte & 0b0000_1000) == 0b0000_1000;
556 let qos_val = (first_byte & 0b0000_0110) >> 1;
557 let qos = QoS::try_from(qos_val).map_err(|_| DecodeError::InvalidQoS)?;
558 let retain = (first_byte & 0b0000_0001) == 0b0000_0001;
559
560 let start_cursor_pos = bytes.position();
562
563 let topic_str = read_string!(bytes);
564 let topic = topic_str.parse().map_err(DecodeError::InvalidTopic)?;
565
566 let packet_id = match qos {
567 QoS::AtMostOnce => None,
568 QoS::AtLeastOnce | QoS::ExactlyOnce => Some(read_u16!(bytes)),
569 };
570
571 let mut payload_format_indicator = None;
572 let mut message_expiry_interval = None;
573 let mut topic_alias = None;
574 let mut response_topic = None;
575 let mut correlation_data = None;
576 let mut user_properties = vec![];
577 let mut subscription_identifiers = None;
578 let mut content_type = None;
579
580 if protocol_version == ProtocolVersion::V500 {
581 try_decode_properties(bytes, |property| match property {
582 Property::PayloadFormatIndicator(p) => {
583 payload_format_indicator = Some(p);
584 Ok(())
585 },
586 Property::MessageExpiryInterval(p) => {
587 message_expiry_interval = Some(p);
588 Ok(())
589 },
590 Property::TopicAlias(p) => {
591 topic_alias = Some(p);
592 Ok(())
593 },
594 Property::ResponseTopic(p) => {
595 response_topic = Some(p);
596 Ok(())
597 },
598 Property::CorrelationData(p) => {
599 correlation_data = Some(p);
600 Ok(())
601 },
602 Property::UserProperty(p) => {
603 user_properties.push(p);
604 Ok(())
605 },
606 Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => {
607 Err(DecodeError::InvalidSubscriptionIdentifier)
608 },
609 Property::SubscriptionIdentifier(p) => {
610 subscription_identifiers.get_or_insert(Vec::new()).push(p);
611 Ok(())
612 },
613 Property::ContentType(p) => {
614 content_type = Some(p);
615 Ok(())
616 },
617 _ => Err(DecodeError::InvalidPropertyForPacket),
618 })?;
619 }
620
621 let end_cursor_pos = bytes.position();
622 let variable_header_size = (end_cursor_pos - start_cursor_pos) as u32;
623 if remaining_packet_length < variable_header_size {
626 return Err(DecodeError::InvalidRemainingLength);
627 }
628 let payload_size = remaining_packet_length - variable_header_size;
629 let payload = return_if_none!(decode_binary_data_with_size(bytes, payload_size as usize)?);
630 let subscription_identifiers =
631 subscription_identifiers.unwrap_or_else(|| Vec::with_capacity(0));
632
633 let packet = PublishPacket {
634 is_duplicate,
635 qos,
636 retain,
637
638 topic,
639 packet_id,
640
641 payload_format_indicator,
642 message_expiry_interval,
643 topic_alias,
644 response_topic,
645 correlation_data,
646 user_properties,
647 subscription_identifiers,
648 content_type,
649
650 payload,
651 };
652
653 Ok(Some(Packet::Publish(packet)))
654}
655
656fn decode_publish_ack(
657 bytes: &mut Cursor<&mut BytesMut>,
658 remaining_packet_length: u32,
659 protocol_version: ProtocolVersion,
660) -> Result<Option<Packet>, DecodeError> {
661 let packet_id = read_u16!(bytes);
662
663 if remaining_packet_length == 2 {
664 return Ok(Some(Packet::PublishAck(PublishAckPacket {
665 packet_id,
666 reason_code: PublishAckReason::Success,
667 reason_string: None,
668 user_properties: vec![],
669 })));
670 }
671
672 let reason_code_byte = read_u8!(bytes);
673 let reason_code = PublishAckReason::try_from(reason_code_byte)
674 .map_err(|_| DecodeError::InvalidPublishAckReason)?;
675
676 let mut reason_string = None;
677 let mut user_properties = vec![];
678
679 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
680 return_if_none!(decode_properties(bytes, |property| {
681 match property {
682 Property::ReasonString(p) => reason_string = Some(p),
683 Property::UserProperty(p) => user_properties.push(p),
684 _ => {}, }
686 })?);
687 }
688
689 let packet = PublishAckPacket { packet_id, reason_code, reason_string, user_properties };
690
691 Ok(Some(Packet::PublishAck(packet)))
692}
693
694fn decode_publish_received(
695 bytes: &mut Cursor<&mut BytesMut>,
696 remaining_packet_length: u32,
697 protocol_version: ProtocolVersion,
698) -> Result<Option<Packet>, DecodeError> {
699 let packet_id = read_u16!(bytes);
700
701 if remaining_packet_length == 2 {
702 return Ok(Some(Packet::PublishReceived(PublishReceivedPacket {
703 packet_id,
704 reason_code: PublishReceivedReason::Success,
705 reason_string: None,
706 user_properties: vec![],
707 })));
708 }
709
710 let reason_code_byte = read_u8!(bytes);
711 let reason_code = PublishReceivedReason::try_from(reason_code_byte)
712 .map_err(|_| DecodeError::InvalidPublishReceivedReason)?;
713
714 let mut reason_string = None;
715 let mut user_properties = vec![];
716
717 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
718 return_if_none!(decode_properties(bytes, |property| {
719 match property {
720 Property::ReasonString(p) => reason_string = Some(p),
721 Property::UserProperty(p) => user_properties.push(p),
722 _ => {}, }
724 })?);
725 }
726
727 let packet = PublishReceivedPacket { packet_id, reason_code, reason_string, user_properties };
728
729 Ok(Some(Packet::PublishReceived(packet)))
730}
731
732fn decode_publish_release(
733 bytes: &mut Cursor<&mut BytesMut>,
734 remaining_packet_length: u32,
735 protocol_version: ProtocolVersion,
736) -> Result<Option<Packet>, DecodeError> {
737 let packet_id = read_u16!(bytes);
738
739 if remaining_packet_length == 2 {
740 return Ok(Some(Packet::PublishRelease(PublishReleasePacket {
741 packet_id,
742 reason_code: PublishReleaseReason::Success,
743 reason_string: None,
744 user_properties: vec![],
745 })));
746 }
747
748 let reason_code_byte = read_u8!(bytes);
749 let reason_code = PublishReleaseReason::try_from(reason_code_byte)
750 .map_err(|_| DecodeError::InvalidPublishReleaseReason)?;
751
752 let mut reason_string = None;
753 let mut user_properties = vec![];
754
755 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
756 return_if_none!(decode_properties(bytes, |property| {
757 match property {
758 Property::ReasonString(p) => reason_string = Some(p),
759 Property::UserProperty(p) => user_properties.push(p),
760 _ => {}, }
762 })?);
763 }
764
765 let packet = PublishReleasePacket { packet_id, reason_code, reason_string, user_properties };
766
767 Ok(Some(Packet::PublishRelease(packet)))
768}
769
770fn decode_publish_complete(
771 bytes: &mut Cursor<&mut BytesMut>,
772 remaining_packet_length: u32,
773 protocol_version: ProtocolVersion,
774) -> Result<Option<Packet>, DecodeError> {
775 let packet_id = read_u16!(bytes);
776
777 if remaining_packet_length == 2 {
778 return Ok(Some(Packet::PublishComplete(PublishCompletePacket {
779 packet_id,
780 reason_code: PublishCompleteReason::Success,
781 reason_string: None,
782 user_properties: vec![],
783 })));
784 }
785
786 let reason_code_byte = read_u8!(bytes);
787 let reason_code = PublishCompleteReason::try_from(reason_code_byte)
788 .map_err(|_| DecodeError::InvalidPublishCompleteReason)?;
789
790 let mut reason_string = None;
791 let mut user_properties = vec![];
792
793 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
794 return_if_none!(decode_properties(bytes, |property| {
795 match property {
796 Property::ReasonString(p) => reason_string = Some(p),
797 Property::UserProperty(p) => user_properties.push(p),
798 _ => {}, }
800 })?);
801 }
802
803 let packet = PublishCompletePacket { packet_id, reason_code, reason_string, user_properties };
804
805 Ok(Some(Packet::PublishComplete(packet)))
806}
807
808fn decode_subscribe(
809 bytes: &mut Cursor<&mut BytesMut>,
810 remaining_packet_length: u32,
811 protocol_version: ProtocolVersion,
812) -> Result<Option<Packet>, DecodeError> {
813 let start_cursor_pos = bytes.position();
814
815 let packet_id = read_u16!(bytes);
816
817 let mut subscription_identifier = None;
818 let mut user_properties = vec![];
819
820 if protocol_version == ProtocolVersion::V500 {
821 try_decode_properties(bytes, |property| {
822 match property {
823 Property::SubscriptionIdentifier(_) if subscription_identifier.is_some() => {
825 Err(DecodeError::InvalidSubscriptionIdentifier)
826 },
827 Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(0))) => {
829 Err(DecodeError::InvalidSubscriptionIdentifier)
830 },
831 Property::SubscriptionIdentifier(p) => {
832 subscription_identifier = Some(p);
833 Ok(())
834 },
835 Property::UserProperty(p) => {
836 user_properties.push(p);
837 Ok(())
838 },
839 _ => Err(DecodeError::InvalidPropertyForPacket),
840 }
841 })?;
842 }
843
844 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
845 if remaining_packet_length < variable_header_size {
846 return Err(DecodeError::InvalidRemainingLength);
847 }
848 let payload_size = remaining_packet_length - variable_header_size;
849
850 let mut subscription_topics = vec![];
851 let mut bytes_read: usize = 0;
852
853 loop {
854 if bytes_read >= payload_size as usize {
855 break;
856 }
857
858 let start_cursor_pos = bytes.position();
859
860 let topic_filter_str = read_string!(bytes);
861 let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
862
863 let options_byte = read_u8!(bytes);
864
865 let maximum_qos_val = options_byte & 0b0000_0011;
866 let maximum_qos = QoS::try_from(maximum_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
867
868 let retain_handling_val = (options_byte & 0b0011_0000) >> 4;
869 let retain_handling = RetainHandling::try_from(retain_handling_val)
870 .map_err(|_| DecodeError::InvalidRetainHandling)?;
871
872 let retain_as_published = (options_byte & 0b0000_1000) == 0b0000_1000;
873 let no_local = (options_byte & 0b0000_0100) == 0b0000_0100;
874
875 let subscription_topic = SubscriptionTopic {
876 topic_filter,
877 maximum_qos,
878 no_local,
879 retain_as_published,
880 retain_handling,
881 };
882
883 subscription_topics.push(subscription_topic);
884
885 let end_cursor_pos = bytes.position();
886 bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
887 }
888
889 let packet = SubscribePacket {
890 packet_id,
891 subscription_identifier,
892 user_properties,
893 subscription_topics,
894 };
895
896 Ok(Some(Packet::Subscribe(packet)))
897}
898
899fn decode_subscribe_ack(
900 bytes: &mut Cursor<&mut BytesMut>,
901 remaining_packet_length: u32,
902 protocol_version: ProtocolVersion,
903) -> Result<Option<Packet>, DecodeError> {
904 let start_cursor_pos = bytes.position();
905
906 let packet_id = read_u16!(bytes);
907
908 let mut reason_string = None;
909 let mut user_properties = vec![];
910
911 if protocol_version == ProtocolVersion::V500 {
912 return_if_none!(decode_properties(bytes, |property| {
913 match property {
914 Property::ReasonString(p) => reason_string = Some(p),
915 Property::UserProperty(p) => user_properties.push(p),
916 _ => {}, }
918 })?);
919 }
920
921 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
922 if remaining_packet_length < variable_header_size {
923 return Err(DecodeError::InvalidRemainingLength);
924 }
925 let payload_size = remaining_packet_length - variable_header_size;
926
927 let mut reason_codes = vec![];
928 for _ in 0..payload_size {
929 let next_byte = read_u8!(bytes);
930 let reason_code = SubscribeAckReason::try_from(next_byte)
931 .map_err(|_| DecodeError::InvalidSubscribeAckReason)?;
932 reason_codes.push(reason_code);
933 }
934
935 let packet = SubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
936
937 Ok(Some(Packet::SubscribeAck(packet)))
938}
939
940fn decode_unsubscribe(
941 bytes: &mut Cursor<&mut BytesMut>,
942 remaining_packet_length: u32,
943 protocol_version: ProtocolVersion,
944) -> Result<Option<Packet>, DecodeError> {
945 let start_cursor_pos = bytes.position();
946
947 let packet_id = read_u16!(bytes);
948
949 let mut user_properties = vec![];
950
951 if protocol_version == ProtocolVersion::V500 {
952 return_if_none!(decode_properties(bytes, |property| {
953 if let Property::UserProperty(p) = property {
954 user_properties.push(p);
955 }
956 })?);
957 }
958
959 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
960 if remaining_packet_length < variable_header_size {
961 return Err(DecodeError::InvalidRemainingLength);
962 }
963 let payload_size = remaining_packet_length - variable_header_size;
964
965 let mut topic_filters = vec![];
966 let mut bytes_read: usize = 0;
967
968 loop {
969 if bytes_read >= payload_size as usize {
970 break;
971 }
972
973 let start_cursor_pos = bytes.position();
974
975 let topic_filter_str = read_string!(bytes);
976 let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
977 topic_filters.push(topic_filter);
978
979 let end_cursor_pos = bytes.position();
980 bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
981 }
982
983 let packet = UnsubscribePacket { packet_id, user_properties, topic_filters };
984
985 Ok(Some(Packet::Unsubscribe(packet)))
986}
987
988fn decode_unsubscribe_ack(
989 bytes: &mut Cursor<&mut BytesMut>,
990 remaining_packet_length: u32,
991 protocol_version: ProtocolVersion,
992) -> Result<Option<Packet>, DecodeError> {
993 let start_cursor_pos = bytes.position();
994
995 let packet_id = read_u16!(bytes);
996
997 let mut reason_string = None;
998 let mut user_properties = vec![];
999
1000 if protocol_version == ProtocolVersion::V500 {
1001 return_if_none!(decode_properties(bytes, |property| {
1002 match property {
1003 Property::ReasonString(p) => reason_string = Some(p),
1004 Property::UserProperty(p) => user_properties.push(p),
1005 _ => {}, }
1007 })?);
1008 }
1009
1010 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
1011 if remaining_packet_length < variable_header_size {
1012 return Err(DecodeError::InvalidRemainingLength);
1013 }
1014 let payload_size = remaining_packet_length - variable_header_size;
1015
1016 let mut reason_codes = vec![];
1017 for _ in 0..payload_size {
1018 let next_byte = read_u8!(bytes);
1019 let reason_code = UnsubscribeAckReason::try_from(next_byte)
1020 .map_err(|_| DecodeError::InvalidUnsubscribeAckReason)?;
1021 reason_codes.push(reason_code);
1022 }
1023
1024 let packet = UnsubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
1025
1026 Ok(Some(Packet::UnsubscribeAck(packet)))
1027}
1028
1029fn decode_disconnect(
1030 bytes: &mut Cursor<&mut BytesMut>,
1031 remaining_packet_length: u32,
1032 protocol_version: ProtocolVersion,
1033) -> Result<Option<Packet>, DecodeError> {
1034 if remaining_packet_length == 0 {
1035 return Ok(Some(Packet::Disconnect(DisconnectPacket {
1036 reason_code: DisconnectReason::NormalDisconnection,
1037 session_expiry_interval: None,
1038 reason_string: None,
1039 user_properties: vec![],
1040 server_reference: None,
1041 })));
1042 }
1043
1044 let reason_code_byte = read_u8!(bytes);
1045 let reason_code = DisconnectReason::try_from(reason_code_byte)
1046 .map_err(|_| DecodeError::InvalidDisconnectReason)?;
1047
1048 let mut session_expiry_interval = None;
1049 let mut reason_string = None;
1050 let mut user_properties = vec![];
1051 let mut server_reference = None;
1052
1053 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1054 return_if_none!(decode_properties(bytes, |property| {
1055 match property {
1056 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
1057 Property::ReasonString(p) => reason_string = Some(p),
1058 Property::UserProperty(p) => user_properties.push(p),
1059 Property::ServerReference(p) => server_reference = Some(p),
1060 _ => {}, }
1062 })?);
1063 }
1064
1065 let packet = DisconnectPacket {
1066 reason_code,
1067 session_expiry_interval,
1068 reason_string,
1069 user_properties,
1070 server_reference,
1071 };
1072
1073 Ok(Some(Packet::Disconnect(packet)))
1074}
1075
1076fn decode_authenticate(
1077 bytes: &mut Cursor<&mut BytesMut>,
1078 remaining_packet_length: u32,
1079 protocol_version: ProtocolVersion,
1080) -> Result<Option<Packet>, DecodeError> {
1081 if remaining_packet_length == 0 {
1082 return Ok(Some(Packet::Authenticate(AuthenticatePacket {
1083 reason_code: AuthenticateReason::Success,
1084 authentication_method: None,
1085 authentication_data: None,
1086 reason_string: None,
1087 user_properties: vec![],
1088 })));
1089 }
1090
1091 let reason_code_byte = read_u8!(bytes);
1092 let reason_code = AuthenticateReason::try_from(reason_code_byte)
1093 .map_err(|_| DecodeError::InvalidAuthenticateReason)?;
1094
1095 let mut authentication_method = None;
1096 let mut authentication_data = None;
1097 let mut reason_string = None;
1098 let mut user_properties = vec![];
1099
1100 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1101 return_if_none!(decode_properties(bytes, |property| {
1102 match property {
1103 Property::AuthenticationMethod(p) => authentication_method = Some(p),
1104 Property::AuthenticationData(p) => authentication_data = Some(p),
1105 Property::ReasonString(p) => reason_string = Some(p),
1106 Property::UserProperty(p) => user_properties.push(p),
1107 _ => {}, }
1109 })?);
1110 }
1111
1112 let packet = AuthenticatePacket {
1113 reason_code,
1114 authentication_method,
1115 authentication_data,
1116 reason_string,
1117 user_properties,
1118 };
1119
1120 Ok(Some(Packet::Authenticate(packet)))
1121}
1122
1123fn decode_packet(
1124 protocol_version: ProtocolVersion,
1125 packet_type: &PacketType,
1126 bytes: &mut Cursor<&mut BytesMut>,
1127 remaining_packet_length: u32,
1128 first_byte: u8,
1129) -> Result<Option<Packet>, DecodeError> {
1130 match packet_type {
1131 PacketType::Connect => decode_connect(bytes),
1132 PacketType::ConnectAck => decode_connect_ack(bytes, protocol_version),
1133 PacketType::Publish => {
1134 decode_publish(bytes, first_byte, remaining_packet_length, protocol_version)
1135 },
1136 PacketType::PublishAck => {
1137 decode_publish_ack(bytes, remaining_packet_length, protocol_version)
1138 },
1139 PacketType::PublishReceived => {
1140 decode_publish_received(bytes, remaining_packet_length, protocol_version)
1141 },
1142 PacketType::PublishRelease => {
1143 decode_publish_release(bytes, remaining_packet_length, protocol_version)
1144 },
1145 PacketType::PublishComplete => {
1146 decode_publish_complete(bytes, remaining_packet_length, protocol_version)
1147 },
1148 PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length, protocol_version),
1149 PacketType::SubscribeAck => {
1150 decode_subscribe_ack(bytes, remaining_packet_length, protocol_version)
1151 },
1152 PacketType::Unsubscribe => {
1153 decode_unsubscribe(bytes, remaining_packet_length, protocol_version)
1154 },
1155 PacketType::UnsubscribeAck => {
1156 decode_unsubscribe_ack(bytes, remaining_packet_length, protocol_version)
1157 },
1158 PacketType::PingRequest => Ok(Some(Packet::PingRequest)),
1159 PacketType::PingResponse => Ok(Some(Packet::PingResponse)),
1160 PacketType::Disconnect => {
1161 decode_disconnect(bytes, remaining_packet_length, protocol_version)
1162 },
1163 PacketType::Authenticate => {
1164 decode_authenticate(bytes, remaining_packet_length, protocol_version)
1165 },
1166 }
1167}
1168
1169pub fn decode_mqtt(
1170 bytes: &mut BytesMut,
1171 protocol_version: ProtocolVersion,
1172) -> Result<Option<Packet>, DecodeError> {
1173 let mut bytes = Cursor::new(bytes);
1174 let first_byte = read_u8!(bytes);
1175
1176 let first_byte_val = (first_byte & 0b1111_0000) >> 4;
1177 let packet_type =
1178 PacketType::try_from(first_byte_val).map_err(|_| DecodeError::InvalidPacketType)?;
1179 let remaining_packet_length = read_variable_int!(&mut bytes);
1180
1181 let cursor_pos = bytes.position() as usize;
1182 let remaining_buffer_amount = bytes.get_ref().len() - cursor_pos;
1183
1184 if remaining_buffer_amount < remaining_packet_length as usize {
1185 return Ok(None);
1187 }
1188
1189 let packet = return_if_none!(decode_packet(
1190 protocol_version,
1191 &packet_type,
1192 &mut bytes,
1193 remaining_packet_length,
1194 first_byte
1195 )?);
1196
1197 let cursor_pos = bytes.position() as usize;
1198 let bytes = bytes.into_inner();
1199
1200 let _rest = bytes.split_to(cursor_pos);
1201
1202 Ok(Some(packet))
1203}
1204
#[cfg(test)]
mod tests {
    use crate::{decoder::*, topic::TopicFilter, types::*};
    use bytes::BytesMut;

    /// Feeds a SUBSCRIBE whose remaining length (1) is shorter than its variable header.
    ///
    /// The result is deliberately ignored. This presumably only guards against a panic
    /// while decoding malformed input.
    #[test]
    fn test_invalid_remaining_length() {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[136, 1, 0, 36, 0, 0]);
        let _ = decode_mqtt(&mut bytes, ProtocolVersion::V500);
    }

    #[test]
    fn test_decode_variable_int() {
        fn normal_test(encoded_variable_int: &[u8], expected_variable_int: u32) {
            let bytes = &mut BytesMut::new();
            bytes.extend_from_slice(encoded_variable_int);
            match decode_variable_int(&mut Cursor::new(bytes)) {
                Ok(val) => match val {
                    Some(get_variable_int) => assert_eq!(get_variable_int, expected_variable_int),
                    None => panic!("variable_int is None"),
                },
                Err(err) => panic!("Error decoding variable int: {:?}", err),
            }
        }

        normal_test(&[0x00], 0);
        normal_test(&[0x7F], 127);

        normal_test(&[0x80, 0x01], 128);
        normal_test(&[0xFF, 0x7F], 16383);

        normal_test(&[0x80, 0x80, 0x01], 16384);
        normal_test(&[0xFF, 0xFF, 0x7F], 2097151);

        normal_test(&[0x80, 0x80, 0x80, 0x01], 2097152);
        normal_test(&[0xFF, 0xFF, 0xFF, 0x7F], 268435455);
    }

    #[test]
    fn test_decode_subscribe() {
        let mut without_subscription_identifier = BytesMut::from(
            [0x82, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00].as_slice(),
        );
        let without_subscription_identifier_expected = Packet::Subscribe(SubscribePacket {
            packet_id: 1,
            subscription_identifier: None,
            user_properties: vec![],
            subscription_topics: vec![SubscriptionTopic {
                topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 },
                maximum_qos: QoS::AtMostOnce,
                no_local: false,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeTime,
            }],
        });
        let decoded = decode_mqtt(&mut without_subscription_identifier, ProtocolVersion::V500)
            .unwrap()
            .unwrap();
        assert_eq!(without_subscription_identifier_expected, decoded);

        let mut packet = BytesMut::from(
            [0x82, 0x0c, 0xff, 0xf6, 0x02, 0x0b, 0x01, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x02]
                .as_slice(),
        );
        let decoded = decode_mqtt(&mut packet, ProtocolVersion::V500).unwrap().unwrap();
        let with_subscription_identifier_expected = Packet::Subscribe(SubscribePacket {
            packet_id: 65526,
            subscription_identifier: Some(SubscriptionIdentifier(VariableByteInt(1))),
            user_properties: vec![],
            subscription_topics: vec![SubscriptionTopic {
                topic_filter: TopicFilter::Concrete { filter: "test".into(), level_count: 1 },
                maximum_qos: QoS::ExactlyOnce,
                no_local: false,
                retain_as_published: false,
                retain_handling: RetainHandling::SendAtSubscribeTime,
            }],
        });
        assert_eq!(with_subscription_identifier_expected, decoded);
    }

    #[test]
    fn test_decode_variable_int_crash() {
        let number: u32 = u32::MAX;
        let result = decode_variable_int(&mut Cursor::new(&mut BytesMut::from(
            number.to_be_bytes().as_slice(),
        )));

        assert!(matches!(result, Err(DecodeError::InvalidRemainingLength)));
    }

    /// A frame that has not fully arrived must yield `None` and leave the buffer intact.
    #[test]
    fn test_incomplete_packet_is_not_consumed() {
        // SUBSCRIBE announcing 10 bytes of remaining length, but only 2 have arrived.
        let mut bytes = BytesMut::from([0x82, 0x0a, 0x00, 0x01].as_slice());
        let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap();
        assert!(decoded.is_none());
        assert_eq!(bytes.len(), 4);
    }

    /// A zero-length DISCONNECT decodes to a normal disconnection and is fully consumed.
    #[test]
    fn test_decode_empty_disconnect() {
        let mut bytes = BytesMut::from([0xe0, 0x00].as_slice());
        let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap();
        let expected = Packet::Disconnect(DisconnectPacket {
            reason_code: DisconnectReason::NormalDisconnection,
            session_expiry_interval: None,
            reason_string: None,
            user_properties: vec![],
            server_reference: None,
        });
        assert_eq!(expected, decoded);
        assert!(bytes.is_empty());
    }

    /// Packets sharing a buffer are decoded one at a time, in order.
    #[test]
    fn test_decode_consecutive_packets() {
        // PINGREQ followed by PINGRESP.
        let mut bytes = BytesMut::from([0xc0, 0x00, 0xd0, 0x00].as_slice());
        assert_eq!(
            decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap(),
            Some(Packet::PingRequest)
        );
        assert_eq!(
            decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap(),
            Some(Packet::PingResponse)
        );
        assert_eq!(decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap(), None);
    }
}