1use crate::types::{
2 properties::*, AuthenticatePacket, AuthenticateReason, ConnectAckPacket, ConnectPacket,
3 ConnectReason, DecodeError, DisconnectPacket, DisconnectReason, FinalWill, Packet, PacketType,
4 ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket,
5 PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason,
6 PublishReleasePacket, PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket,
7 SubscribeAckReason, SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket,
8 UnsubscribeAckReason, UnsubscribePacket, VariableByteInt,
9};
10use bytes::{Buf, Bytes, BytesMut};
11use std::{convert::TryFrom, io::Cursor};
12
// Unwraps an `Option`, or makes the enclosing function return `Ok(None)` when it is `None`.
//
// The decoders in this file return `Result<Option<T>, DecodeError>`, where `Ok(None)` means
// "not enough bytes buffered yet"; this macro propagates that signal upwards.
macro_rules! return_if_none {
    ($x: expr) => {{
        match $x {
            Some(value) => value,
            None => return Ok(None),
        }
    }};
}
23
// Makes the enclosing decoder return `Ok(None)` unless at least `$len` bytes remain unread.
macro_rules! require_length {
    ($bytes: expr, $len: expr) => {{
        let available = $bytes.remaining();
        if available < $len {
            return Ok(None);
        }
    }};
}
31
// Reads one byte, or makes the enclosing decoder return `Ok(None)` if the buffer is empty.
macro_rules! read_u8 {
    ($bytes: expr) => {{
        if $bytes.has_remaining() {
            $bytes.get_u8()
        } else {
            return Ok(None);
        }
    }};
}
41
// Reads a two-byte integer via `get_u16`, or makes the enclosing decoder return `Ok(None)`
// when fewer than two bytes remain.
macro_rules! read_u16 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 2 {
            $bytes.get_u16()
        } else {
            return Ok(None);
        }
    }};
}
51
// Reads a four-byte integer via `get_u32`, or makes the enclosing decoder return `Ok(None)`
// when fewer than four bytes remain.
macro_rules! read_u32 {
    ($bytes: expr) => {{
        if $bytes.remaining() >= 4 {
            $bytes.get_u32()
        } else {
            return Ok(None);
        }
    }};
}
61
// Decodes a Variable Byte Integer with `decode_variable_int`, propagating errors with `?`
// and `Ok(None)` (incomplete input) through `return_if_none!`.
macro_rules! read_variable_int {
    ($bytes: expr) => {{
        let decoded = decode_variable_int($bytes)?;
        return_if_none!(decoded)
    }};
}
67
// Decodes a length-prefixed UTF-8 string with `decode_string`, propagating errors with `?`
// and `Ok(None)` (incomplete input) through `return_if_none!`.
macro_rules! read_string {
    ($bytes: expr) => {{
        let decoded = decode_string($bytes)?;
        return_if_none!(decoded)
    }};
}
73
// Decodes length-prefixed binary data with `decode_binary_data`, propagating errors with `?`
// and `Ok(None)` (incomplete input) through `return_if_none!`.
macro_rules! read_binary_data {
    ($bytes: expr) => {{
        let decoded = decode_binary_data($bytes)?;
        return_if_none!(decoded)
    }};
}
79
// Reads two consecutive strings (key first, then value) and yields them as a tuple.
// Tuple operands are evaluated left to right, so the key is read before the value.
macro_rules! read_string_pair {
    ($bytes: expr) => {{
        (read_string!($bytes), read_string!($bytes))
    }};
}
88
// Reads a property identifier (a Variable Byte Integer) and then that property's value.
macro_rules! read_property {
    ($bytes: expr) => {{
        let property_id = read_variable_int!($bytes);
        let decoded = decode_property(property_id, $bytes)?;
        return_if_none!(decoded)
    }};
}
95
96fn decode_variable_int(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<u32>, DecodeError> {
97 let mut multiplier = 1;
98 let mut value: u32 = 0;
99
100 loop {
101 let encoded_byte = read_u8!(bytes);
102
103 value += ((encoded_byte & 0b0111_1111) as u32) * multiplier;
104
105 if multiplier > (128 * 128 * 128) {
106 return Err(DecodeError::InvalidRemainingLength);
107 }
108
109 multiplier *= 128;
110
111 if encoded_byte & 0b1000_0000 == 0b0000_0000 {
112 break;
113 }
114 }
115
116 Ok(Some(value))
117}
118
119fn decode_string(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<String>, DecodeError> {
120 let str_size_bytes = read_u16!(bytes) as usize;
121
122 require_length!(bytes, str_size_bytes);
123
124 let position = bytes.position() as usize;
125
126 match String::from_utf8(bytes.get_ref()[position..(position + str_size_bytes)].into()) {
128 Ok(string) => {
129 bytes.advance(str_size_bytes);
130 Ok(Some(string))
131 },
132 Err(_) => Err(DecodeError::InvalidUtf8),
133 }
134}
135
136fn decode_binary_data(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Bytes>, DecodeError> {
137 let data_size_bytes = read_u16!(bytes) as usize;
138 require_length!(bytes, data_size_bytes);
139
140 let position = bytes.position() as usize;
141
142 let payload_bytes =
143 BytesMut::from(&bytes.get_ref()[position..(position + data_size_bytes)]).freeze();
144 let result = Ok(Some(payload_bytes));
145 bytes.advance(data_size_bytes);
146
147 result
148}
149
150fn decode_binary_data_with_size(
151 bytes: &mut Cursor<&mut BytesMut>,
152 size: usize,
153) -> Result<Option<Bytes>, DecodeError> {
154 require_length!(bytes, size);
155
156 let position = bytes.position() as usize;
157 let payload_bytes = BytesMut::from(&bytes.get_ref()[position..(position + size)]).freeze();
158 let result = Ok(Some(payload_bytes));
159 bytes.advance(size);
160
161 result
162}
163
164fn decode_property(
165 property_id: u32,
166 bytes: &mut Cursor<&mut BytesMut>,
167) -> Result<Option<Property>, DecodeError> {
168 let property_type =
169 PropertyType::try_from(property_id).map_err(|_| DecodeError::InvalidPropertyId)?;
170
171 match property_type {
172 PropertyType::PayloadFormatIndicator => {
173 let format_indicator = read_u8!(bytes);
174 Ok(Some(Property::PayloadFormatIndicator(PayloadFormatIndicator(format_indicator))))
175 },
176 PropertyType::MessageExpiryInterval => {
177 let message_expiry_interval = read_u32!(bytes);
178 Ok(Some(Property::MessageExpiryInterval(MessageExpiryInterval(
179 message_expiry_interval,
180 ))))
181 },
182 PropertyType::ContentType => {
183 let content_type = read_string!(bytes);
184 Ok(Some(Property::ContentType(ContentType(content_type))))
185 },
186 PropertyType::ResponseTopic => {
187 let response_topic = read_string!(bytes);
188 Ok(Some(Property::ResponseTopic(ResponseTopic(response_topic))))
189 },
190 PropertyType::CorrelationData => {
191 let correlation_data = read_binary_data!(bytes);
192 Ok(Some(Property::CorrelationData(CorrelationData(correlation_data))))
193 },
194 PropertyType::SubscriptionIdentifier => {
195 let subscription_identifier = read_u32!(bytes);
196 Ok(Some(Property::SubscriptionIdentifier(SubscriptionIdentifier(VariableByteInt(
197 subscription_identifier,
198 )))))
199 },
200 PropertyType::SessionExpiryInterval => {
201 let session_expiry_interval = read_u32!(bytes);
202 Ok(Some(Property::SessionExpiryInterval(SessionExpiryInterval(
203 session_expiry_interval,
204 ))))
205 },
206 PropertyType::AssignedClientIdentifier => {
207 let assigned_client_identifier = read_string!(bytes);
208 Ok(Some(Property::AssignedClientIdentifier(AssignedClientIdentifier(
209 assigned_client_identifier,
210 ))))
211 },
212 PropertyType::ServerKeepAlive => {
213 let server_keep_alive = read_u16!(bytes);
214 Ok(Some(Property::ServerKeepAlive(ServerKeepAlive(server_keep_alive))))
215 },
216 PropertyType::AuthenticationMethod => {
217 let authentication_method = read_string!(bytes);
218 Ok(Some(Property::AuthenticationMethod(AuthenticationMethod(authentication_method))))
219 },
220 PropertyType::AuthenticationData => {
221 let authentication_data = read_binary_data!(bytes);
222 Ok(Some(Property::AuthenticationData(AuthenticationData(authentication_data))))
223 },
224 PropertyType::RequestProblemInformation => {
225 let request_problem_information = read_u8!(bytes);
226 Ok(Some(Property::RequestProblemInformation(RequestProblemInformation(
227 request_problem_information,
228 ))))
229 },
230 PropertyType::WillDelayInterval => {
231 let will_delay_interval = read_u32!(bytes);
232 Ok(Some(Property::WillDelayInterval(WillDelayInterval(will_delay_interval))))
233 },
234 PropertyType::RequestResponseInformation => {
235 let request_response_information = read_u8!(bytes);
236 Ok(Some(Property::RequestResponseInformation(RequestResponseInformation(
237 request_response_information,
238 ))))
239 },
240 PropertyType::ResponseInformation => {
241 let response_information = read_string!(bytes);
242 Ok(Some(Property::ResponseInformation(ResponseInformation(response_information))))
243 },
244 PropertyType::ServerReference => {
245 let server_reference = read_string!(bytes);
246 Ok(Some(Property::ServerReference(ServerReference(server_reference))))
247 },
248 PropertyType::ReasonString => {
249 let reason_string = read_string!(bytes);
250 Ok(Some(Property::ReasonString(ReasonString(reason_string))))
251 },
252 PropertyType::ReceiveMaximum => {
253 let receive_maximum = read_u16!(bytes);
254 Ok(Some(Property::ReceiveMaximum(ReceiveMaximum(receive_maximum))))
255 },
256 PropertyType::TopicAliasMaximum => {
257 let topic_alias_maximum = read_u16!(bytes);
258 Ok(Some(Property::TopicAliasMaximum(TopicAliasMaximum(topic_alias_maximum))))
259 },
260 PropertyType::TopicAlias => {
261 let topic_alias = read_u16!(bytes);
262 Ok(Some(Property::TopicAlias(TopicAlias(topic_alias))))
263 },
264 PropertyType::MaximumQos => {
265 let qos_byte = read_u8!(bytes);
266 let qos = QoS::try_from(qos_byte).map_err(|_| DecodeError::InvalidQoS)?;
267
268 Ok(Some(Property::MaximumQos(MaximumQos(qos))))
269 },
270 PropertyType::RetainAvailable => {
271 let retain_available = read_u8!(bytes);
272 Ok(Some(Property::RetainAvailable(RetainAvailable(retain_available))))
273 },
274 PropertyType::UserProperty => {
275 let (key, value) = read_string_pair!(bytes);
276 Ok(Some(Property::UserProperty(UserProperty(key, value))))
277 },
278 PropertyType::MaximumPacketSize => {
279 let maximum_packet_size = read_u32!(bytes);
280 Ok(Some(Property::MaximumPacketSize(MaximumPacketSize(maximum_packet_size))))
281 },
282 PropertyType::WildcardSubscriptionAvailable => {
283 let wildcard_subscription_available = read_u8!(bytes);
284 Ok(Some(Property::WildcardSubscriptionAvailable(WildcardSubscriptionAvailable(
285 wildcard_subscription_available,
286 ))))
287 },
288 PropertyType::SubscriptionIdentifierAvailable => {
289 let subscription_identifier_available = read_u8!(bytes);
290 Ok(Some(Property::SubscriptionIdentifierAvailable(SubscriptionIdentifierAvailable(
291 subscription_identifier_available,
292 ))))
293 },
294 PropertyType::SharedSubscriptionAvailable => {
295 let shared_subscription_available = read_u8!(bytes);
296 Ok(Some(Property::SharedSubscriptionAvailable(SharedSubscriptionAvailable(
297 shared_subscription_available,
298 ))))
299 },
300 }
301}
302
303fn decode_properties<F: FnMut(Property)>(
304 bytes: &mut Cursor<&mut BytesMut>,
305 mut closure: F,
306) -> Result<Option<()>, DecodeError> {
307 let property_length = read_variable_int!(bytes);
308
309 if property_length == 0 {
310 return Ok(Some(()));
311 }
312
313 require_length!(bytes, property_length as usize);
314
315 let start_cursor_pos = bytes.position();
316
317 loop {
318 let cursor_pos = bytes.position();
319
320 if cursor_pos - start_cursor_pos >= property_length as u64 {
321 break;
322 }
323
324 let property = read_property!(bytes);
325 closure(property);
326 }
327
328 Ok(Some(()))
329}
330
331fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, DecodeError> {
332 let protocol_name = read_string!(bytes);
333 let protocol_level = read_u8!(bytes);
334 let connect_flags = read_u8!(bytes);
335 let keep_alive = read_u16!(bytes);
336
337 let protocol_version = ProtocolVersion::try_from(protocol_level)
338 .map_err(|_| DecodeError::InvalidProtocolVersion)?;
339
340 let mut session_expiry_interval = None;
341 let mut receive_maximum = None;
342 let mut maximum_packet_size = None;
343 let mut topic_alias_maximum = None;
344 let mut request_response_information = None;
345 let mut request_problem_information = None;
346 let mut user_properties = vec![];
347 let mut authentication_method = None;
348 let mut authentication_data = None;
349
350 if protocol_version == ProtocolVersion::V500 {
351 return_if_none!(decode_properties(bytes, |property| {
352 match property {
353 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
354 Property::ReceiveMaximum(p) => receive_maximum = Some(p),
355 Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
356 Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
357 Property::RequestResponseInformation(p) => request_response_information = Some(p),
358 Property::RequestProblemInformation(p) => request_problem_information = Some(p),
359 Property::UserProperty(p) => user_properties.push(p),
360 Property::AuthenticationMethod(p) => authentication_method = Some(p),
361 Property::AuthenticationData(p) => authentication_data = Some(p),
362 _ => {}, }
364 })?);
365 }
366
367 let clean_start = connect_flags & 0b0000_0010 == 0b0000_0010;
369 let has_will = connect_flags & 0b0000_0100 == 0b0000_0100;
370 let will_qos_val = (connect_flags & 0b0001_1000) >> 3;
371 let will_qos = QoS::try_from(will_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
372 let retain_will = connect_flags & 0b0010_0000 == 0b0010_0000;
373 let has_password = connect_flags & 0b0100_0000 == 0b0100_0000;
374 let has_user_name = connect_flags & 0b1000_0000 == 0b1000_0000;
375
376 let client_id = read_string!(bytes);
377
378 let will = if has_will {
379 let mut will_delay_interval = None;
380 let mut payload_format_indicator = None;
381 let mut message_expiry_interval = None;
382 let mut content_type = None;
383 let mut response_topic = None;
384 let mut correlation_data = None;
385 let mut user_properties = vec![];
386
387 if protocol_version == ProtocolVersion::V500 {
388 return_if_none!(decode_properties(bytes, |property| {
389 match property {
390 Property::WillDelayInterval(p) => will_delay_interval = Some(p),
391 Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
392 Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
393 Property::ContentType(p) => content_type = Some(p),
394 Property::ResponseTopic(p) => response_topic = Some(p),
395 Property::CorrelationData(p) => correlation_data = Some(p),
396 Property::UserProperty(p) => user_properties.push(p),
397 _ => {}, }
399 })?);
400 }
401
402 let topic = read_string!(bytes);
403 let payload = read_binary_data!(bytes);
404
405 Some(FinalWill {
406 topic,
407 payload,
408 qos: will_qos,
409 should_retain: retain_will,
410 will_delay_interval,
411 payload_format_indicator,
412 message_expiry_interval,
413 content_type,
414 response_topic,
415 correlation_data,
416 user_properties,
417 })
418 } else {
419 None
420 };
421
422 let mut user_name = None;
423 let mut password = None;
424
425 if has_user_name {
426 user_name = Some(read_string!(bytes));
427 }
428
429 if has_password {
430 password = Some(read_string!(bytes));
431 }
432
433 let packet = ConnectPacket {
434 protocol_name,
435 protocol_version,
436 clean_start,
437 keep_alive,
438 session_expiry_interval,
439 receive_maximum,
440 maximum_packet_size,
441 topic_alias_maximum,
442 request_response_information,
443 request_problem_information,
444 user_properties,
445 authentication_method,
446 authentication_data,
447 client_id,
448 will,
449 user_name,
450 password,
451 };
452
453 Ok(Some(Packet::Connect(packet)))
454}
455
456fn decode_connect_ack(
457 bytes: &mut Cursor<&mut BytesMut>,
458 protocol_version: ProtocolVersion,
459) -> Result<Option<Packet>, DecodeError> {
460 let flags = read_u8!(bytes);
461 let session_present = (flags & 0b0000_0001) == 0b0000_0001;
462
463 let reason_code_byte = read_u8!(bytes);
464 let reason_code =
465 ConnectReason::try_from(reason_code_byte).map_err(|_| DecodeError::InvalidConnectReason)?;
466
467 let mut session_expiry_interval = None;
468 let mut receive_maximum = None;
469 let mut maximum_qos = None;
470 let mut retain_available = None;
471 let mut maximum_packet_size = None;
472 let mut assigned_client_identifier = None;
473 let mut topic_alias_maximum = None;
474 let mut reason_string = None;
475 let mut user_properties = vec![];
476 let mut wildcard_subscription_available = None;
477 let mut subscription_identifiers_available = None;
478 let mut shared_subscription_available = None;
479 let mut server_keep_alive = None;
480 let mut response_information = None;
481 let mut server_reference = None;
482 let mut authentication_method = None;
483 let mut authentication_data = None;
484
485 if protocol_version == ProtocolVersion::V500 {
486 return_if_none!(decode_properties(bytes, |property| {
487 match property {
488 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
489 Property::ReceiveMaximum(p) => receive_maximum = Some(p),
490 Property::MaximumQos(p) => maximum_qos = Some(p),
491 Property::RetainAvailable(p) => retain_available = Some(p),
492 Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
493 Property::AssignedClientIdentifier(p) => assigned_client_identifier = Some(p),
494 Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
495 Property::ReasonString(p) => reason_string = Some(p),
496 Property::UserProperty(p) => user_properties.push(p),
497 Property::WildcardSubscriptionAvailable(p) => {
498 wildcard_subscription_available = Some(p)
499 },
500 Property::SubscriptionIdentifierAvailable(p) => {
501 subscription_identifiers_available = Some(p)
502 },
503 Property::SharedSubscriptionAvailable(p) => shared_subscription_available = Some(p),
504 Property::ServerKeepAlive(p) => server_keep_alive = Some(p),
505 Property::ResponseInformation(p) => response_information = Some(p),
506 Property::ServerReference(p) => server_reference = Some(p),
507 Property::AuthenticationMethod(p) => authentication_method = Some(p),
508 Property::AuthenticationData(p) => authentication_data = Some(p),
509 _ => {}, }
511 })?);
512 }
513
514 let packet = ConnectAckPacket {
515 session_present,
516 reason_code,
517 session_expiry_interval,
518 receive_maximum,
519 maximum_qos,
520 retain_available,
521 maximum_packet_size,
522 assigned_client_identifier,
523 topic_alias_maximum,
524 reason_string,
525 user_properties,
526 wildcard_subscription_available,
527 subscription_identifiers_available,
528 shared_subscription_available,
529 server_keep_alive,
530 response_information,
531 server_reference,
532 authentication_method,
533 authentication_data,
534 };
535
536 Ok(Some(Packet::ConnectAck(packet)))
537}
538
539fn decode_publish(
540 bytes: &mut Cursor<&mut BytesMut>,
541 first_byte: u8,
542 remaining_packet_length: u32,
543 protocol_version: ProtocolVersion,
544) -> Result<Option<Packet>, DecodeError> {
545 let is_duplicate = (first_byte & 0b0000_1000) == 0b0000_1000;
546 let qos_val = (first_byte & 0b0000_0110) >> 1;
547 let qos = QoS::try_from(qos_val).map_err(|_| DecodeError::InvalidQoS)?;
548 let retain = (first_byte & 0b0000_0001) == 0b0000_0001;
549
550 let start_cursor_pos = bytes.position();
552
553 let topic_str = read_string!(bytes);
554 let topic = topic_str.parse().map_err(DecodeError::InvalidTopic)?;
555
556 let packet_id = match qos {
557 QoS::AtMostOnce => None,
558 QoS::AtLeastOnce | QoS::ExactlyOnce => Some(read_u16!(bytes)),
559 };
560
561 let mut payload_format_indicator = None;
562 let mut message_expiry_interval = None;
563 let mut topic_alias = None;
564 let mut response_topic = None;
565 let mut correlation_data = None;
566 let mut user_properties = vec![];
567 let mut subscription_identifier = None;
568 let mut content_type = None;
569
570 if protocol_version == ProtocolVersion::V500 {
571 return_if_none!(decode_properties(bytes, |property| {
572 match property {
573 Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
574 Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
575 Property::TopicAlias(p) => topic_alias = Some(p),
576 Property::ResponseTopic(p) => response_topic = Some(p),
577 Property::CorrelationData(p) => correlation_data = Some(p),
578 Property::UserProperty(p) => user_properties.push(p),
579 Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
580 Property::ContentType(p) => content_type = Some(p),
581 _ => {}, }
583 })?);
584 }
585
586 let end_cursor_pos = bytes.position();
587 let variable_header_size = (end_cursor_pos - start_cursor_pos) as u32;
588 if remaining_packet_length < variable_header_size {
591 return Err(DecodeError::InvalidRemainingLength);
592 }
593 let payload_size = remaining_packet_length - variable_header_size;
594 let payload = return_if_none!(decode_binary_data_with_size(bytes, payload_size as usize)?);
595
596 let packet = PublishPacket {
597 is_duplicate,
598 qos,
599 retain,
600
601 topic,
602 packet_id,
603
604 payload_format_indicator,
605 message_expiry_interval,
606 topic_alias,
607 response_topic,
608 correlation_data,
609 user_properties,
610 subscription_identifier,
611 content_type,
612
613 payload,
614 };
615
616 Ok(Some(Packet::Publish(packet)))
617}
618
619fn decode_publish_ack(
620 bytes: &mut Cursor<&mut BytesMut>,
621 remaining_packet_length: u32,
622 protocol_version: ProtocolVersion,
623) -> Result<Option<Packet>, DecodeError> {
624 let packet_id = read_u16!(bytes);
625
626 if remaining_packet_length == 2 {
627 return Ok(Some(Packet::PublishAck(PublishAckPacket {
628 packet_id,
629 reason_code: PublishAckReason::Success,
630 reason_string: None,
631 user_properties: vec![],
632 })));
633 }
634
635 let reason_code_byte = read_u8!(bytes);
636 let reason_code = PublishAckReason::try_from(reason_code_byte)
637 .map_err(|_| DecodeError::InvalidPublishAckReason)?;
638
639 let mut reason_string = None;
640 let mut user_properties = vec![];
641
642 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
643 return_if_none!(decode_properties(bytes, |property| {
644 match property {
645 Property::ReasonString(p) => reason_string = Some(p),
646 Property::UserProperty(p) => user_properties.push(p),
647 _ => {}, }
649 })?);
650 }
651
652 let packet = PublishAckPacket { packet_id, reason_code, reason_string, user_properties };
653
654 Ok(Some(Packet::PublishAck(packet)))
655}
656
657fn decode_publish_received(
658 bytes: &mut Cursor<&mut BytesMut>,
659 remaining_packet_length: u32,
660 protocol_version: ProtocolVersion,
661) -> Result<Option<Packet>, DecodeError> {
662 let packet_id = read_u16!(bytes);
663
664 if remaining_packet_length == 2 {
665 return Ok(Some(Packet::PublishReceived(PublishReceivedPacket {
666 packet_id,
667 reason_code: PublishReceivedReason::Success,
668 reason_string: None,
669 user_properties: vec![],
670 })));
671 }
672
673 let reason_code_byte = read_u8!(bytes);
674 let reason_code = PublishReceivedReason::try_from(reason_code_byte)
675 .map_err(|_| DecodeError::InvalidPublishReceivedReason)?;
676
677 let mut reason_string = None;
678 let mut user_properties = vec![];
679
680 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
681 return_if_none!(decode_properties(bytes, |property| {
682 match property {
683 Property::ReasonString(p) => reason_string = Some(p),
684 Property::UserProperty(p) => user_properties.push(p),
685 _ => {}, }
687 })?);
688 }
689
690 let packet = PublishReceivedPacket { packet_id, reason_code, reason_string, user_properties };
691
692 Ok(Some(Packet::PublishReceived(packet)))
693}
694
695fn decode_publish_release(
696 bytes: &mut Cursor<&mut BytesMut>,
697 remaining_packet_length: u32,
698 protocol_version: ProtocolVersion,
699) -> Result<Option<Packet>, DecodeError> {
700 let packet_id = read_u16!(bytes);
701
702 if remaining_packet_length == 2 {
703 return Ok(Some(Packet::PublishRelease(PublishReleasePacket {
704 packet_id,
705 reason_code: PublishReleaseReason::Success,
706 reason_string: None,
707 user_properties: vec![],
708 })));
709 }
710
711 let reason_code_byte = read_u8!(bytes);
712 let reason_code = PublishReleaseReason::try_from(reason_code_byte)
713 .map_err(|_| DecodeError::InvalidPublishReleaseReason)?;
714
715 let mut reason_string = None;
716 let mut user_properties = vec![];
717
718 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
719 return_if_none!(decode_properties(bytes, |property| {
720 match property {
721 Property::ReasonString(p) => reason_string = Some(p),
722 Property::UserProperty(p) => user_properties.push(p),
723 _ => {}, }
725 })?);
726 }
727
728 let packet = PublishReleasePacket { packet_id, reason_code, reason_string, user_properties };
729
730 Ok(Some(Packet::PublishRelease(packet)))
731}
732
733fn decode_publish_complete(
734 bytes: &mut Cursor<&mut BytesMut>,
735 remaining_packet_length: u32,
736 protocol_version: ProtocolVersion,
737) -> Result<Option<Packet>, DecodeError> {
738 let packet_id = read_u16!(bytes);
739
740 if remaining_packet_length == 2 {
741 return Ok(Some(Packet::PublishComplete(PublishCompletePacket {
742 packet_id,
743 reason_code: PublishCompleteReason::Success,
744 reason_string: None,
745 user_properties: vec![],
746 })));
747 }
748
749 let reason_code_byte = read_u8!(bytes);
750 let reason_code = PublishCompleteReason::try_from(reason_code_byte)
751 .map_err(|_| DecodeError::InvalidPublishCompleteReason)?;
752
753 let mut reason_string = None;
754 let mut user_properties = vec![];
755
756 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 4 {
757 return_if_none!(decode_properties(bytes, |property| {
758 match property {
759 Property::ReasonString(p) => reason_string = Some(p),
760 Property::UserProperty(p) => user_properties.push(p),
761 _ => {}, }
763 })?);
764 }
765
766 let packet = PublishCompletePacket { packet_id, reason_code, reason_string, user_properties };
767
768 Ok(Some(Packet::PublishComplete(packet)))
769}
770
771fn decode_subscribe(
772 bytes: &mut Cursor<&mut BytesMut>,
773 remaining_packet_length: u32,
774 protocol_version: ProtocolVersion,
775) -> Result<Option<Packet>, DecodeError> {
776 let start_cursor_pos = bytes.position();
777
778 let packet_id = read_u16!(bytes);
779
780 let mut subscription_identifier = None;
781 let mut user_properties = vec![];
782
783 if protocol_version == ProtocolVersion::V500 {
784 return_if_none!(decode_properties(bytes, |property| {
785 match property {
786 Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
787 Property::UserProperty(p) => user_properties.push(p),
788 _ => {}, }
790 })?);
791 }
792
793 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
794 if remaining_packet_length < variable_header_size {
795 return Err(DecodeError::InvalidRemainingLength);
796 }
797 let payload_size = remaining_packet_length - variable_header_size;
798
799 let mut subscription_topics = vec![];
800 let mut bytes_read: usize = 0;
801
802 loop {
803 if bytes_read >= payload_size as usize {
804 break;
805 }
806
807 let start_cursor_pos = bytes.position();
808
809 let topic_filter_str = read_string!(bytes);
810 let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
811
812 let options_byte = read_u8!(bytes);
813
814 let maximum_qos_val = options_byte & 0b0000_0011;
815 let maximum_qos = QoS::try_from(maximum_qos_val).map_err(|_| DecodeError::InvalidQoS)?;
816
817 let retain_handling_val = (options_byte & 0b0011_0000) >> 4;
818 let retain_handling = RetainHandling::try_from(retain_handling_val)
819 .map_err(|_| DecodeError::InvalidRetainHandling)?;
820
821 let retain_as_published = (options_byte & 0b0000_1000) == 0b0000_1000;
822 let no_local = (options_byte & 0b0000_0100) == 0b0000_0100;
823
824 let subscription_topic = SubscriptionTopic {
825 topic_filter,
826 maximum_qos,
827 no_local,
828 retain_as_published,
829 retain_handling,
830 };
831
832 subscription_topics.push(subscription_topic);
833
834 let end_cursor_pos = bytes.position();
835 bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
836 }
837
838 let packet = SubscribePacket {
839 packet_id,
840 subscription_identifier,
841 user_properties,
842 subscription_topics,
843 };
844
845 Ok(Some(Packet::Subscribe(packet)))
846}
847
848fn decode_subscribe_ack(
849 bytes: &mut Cursor<&mut BytesMut>,
850 remaining_packet_length: u32,
851 protocol_version: ProtocolVersion,
852) -> Result<Option<Packet>, DecodeError> {
853 let start_cursor_pos = bytes.position();
854
855 let packet_id = read_u16!(bytes);
856
857 let mut reason_string = None;
858 let mut user_properties = vec![];
859
860 if protocol_version == ProtocolVersion::V500 {
861 return_if_none!(decode_properties(bytes, |property| {
862 match property {
863 Property::ReasonString(p) => reason_string = Some(p),
864 Property::UserProperty(p) => user_properties.push(p),
865 _ => {}, }
867 })?);
868 }
869
870 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
871 if remaining_packet_length < variable_header_size {
872 return Err(DecodeError::InvalidRemainingLength);
873 }
874 let payload_size = remaining_packet_length - variable_header_size;
875
876 let mut reason_codes = vec![];
877 for _ in 0..payload_size {
878 let next_byte = read_u8!(bytes);
879 let reason_code = SubscribeAckReason::try_from(next_byte)
880 .map_err(|_| DecodeError::InvalidSubscribeAckReason)?;
881 reason_codes.push(reason_code);
882 }
883
884 let packet = SubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
885
886 Ok(Some(Packet::SubscribeAck(packet)))
887}
888
889fn decode_unsubscribe(
890 bytes: &mut Cursor<&mut BytesMut>,
891 remaining_packet_length: u32,
892 protocol_version: ProtocolVersion,
893) -> Result<Option<Packet>, DecodeError> {
894 let start_cursor_pos = bytes.position();
895
896 let packet_id = read_u16!(bytes);
897
898 let mut user_properties = vec![];
899
900 if protocol_version == ProtocolVersion::V500 {
901 return_if_none!(decode_properties(bytes, |property| {
902 if let Property::UserProperty(p) = property {
903 user_properties.push(p);
904 }
905 })?);
906 }
907
908 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
909 if remaining_packet_length < variable_header_size {
910 return Err(DecodeError::InvalidRemainingLength);
911 }
912 let payload_size = remaining_packet_length - variable_header_size;
913
914 let mut topic_filters = vec![];
915 let mut bytes_read: usize = 0;
916
917 loop {
918 if bytes_read >= payload_size as usize {
919 break;
920 }
921
922 let start_cursor_pos = bytes.position();
923
924 let topic_filter_str = read_string!(bytes);
925 let topic_filter = topic_filter_str.parse().map_err(DecodeError::InvalidTopicFilter)?;
926 topic_filters.push(topic_filter);
927
928 let end_cursor_pos = bytes.position();
929 bytes_read += (end_cursor_pos - start_cursor_pos) as usize;
930 }
931
932 let packet = UnsubscribePacket { packet_id, user_properties, topic_filters };
933
934 Ok(Some(Packet::Unsubscribe(packet)))
935}
936
937fn decode_unsubscribe_ack(
938 bytes: &mut Cursor<&mut BytesMut>,
939 remaining_packet_length: u32,
940 protocol_version: ProtocolVersion,
941) -> Result<Option<Packet>, DecodeError> {
942 let start_cursor_pos = bytes.position();
943
944 let packet_id = read_u16!(bytes);
945
946 let mut reason_string = None;
947 let mut user_properties = vec![];
948
949 if protocol_version == ProtocolVersion::V500 {
950 return_if_none!(decode_properties(bytes, |property| {
951 match property {
952 Property::ReasonString(p) => reason_string = Some(p),
953 Property::UserProperty(p) => user_properties.push(p),
954 _ => {}, }
956 })?);
957 }
958
959 let variable_header_size = (bytes.position() - start_cursor_pos) as u32;
960 if remaining_packet_length < variable_header_size {
961 return Err(DecodeError::InvalidRemainingLength);
962 }
963 let payload_size = remaining_packet_length - variable_header_size;
964
965 let mut reason_codes = vec![];
966 for _ in 0..payload_size {
967 let next_byte = read_u8!(bytes);
968 let reason_code = UnsubscribeAckReason::try_from(next_byte)
969 .map_err(|_| DecodeError::InvalidUnsubscribeAckReason)?;
970 reason_codes.push(reason_code);
971 }
972
973 let packet = UnsubscribeAckPacket { packet_id, reason_string, user_properties, reason_codes };
974
975 Ok(Some(Packet::UnsubscribeAck(packet)))
976}
977
978fn decode_disconnect(
979 bytes: &mut Cursor<&mut BytesMut>,
980 remaining_packet_length: u32,
981 protocol_version: ProtocolVersion,
982) -> Result<Option<Packet>, DecodeError> {
983 if remaining_packet_length == 0 {
984 return Ok(Some(Packet::Disconnect(DisconnectPacket {
985 reason_code: DisconnectReason::NormalDisconnection,
986 session_expiry_interval: None,
987 reason_string: None,
988 user_properties: vec![],
989 server_reference: None,
990 })));
991 }
992
993 let reason_code_byte = read_u8!(bytes);
994 let reason_code = DisconnectReason::try_from(reason_code_byte)
995 .map_err(|_| DecodeError::InvalidDisconnectReason)?;
996
997 let mut session_expiry_interval = None;
998 let mut reason_string = None;
999 let mut user_properties = vec![];
1000 let mut server_reference = None;
1001
1002 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1003 return_if_none!(decode_properties(bytes, |property| {
1004 match property {
1005 Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
1006 Property::ReasonString(p) => reason_string = Some(p),
1007 Property::UserProperty(p) => user_properties.push(p),
1008 Property::ServerReference(p) => server_reference = Some(p),
1009 _ => {}, }
1011 })?);
1012 }
1013
1014 let packet = DisconnectPacket {
1015 reason_code,
1016 session_expiry_interval,
1017 reason_string,
1018 user_properties,
1019 server_reference,
1020 };
1021
1022 Ok(Some(Packet::Disconnect(packet)))
1023}
1024
1025fn decode_authenticate(
1026 bytes: &mut Cursor<&mut BytesMut>,
1027 remaining_packet_length: u32,
1028 protocol_version: ProtocolVersion,
1029) -> Result<Option<Packet>, DecodeError> {
1030 if remaining_packet_length == 0 {
1031 return Ok(Some(Packet::Authenticate(AuthenticatePacket {
1032 reason_code: AuthenticateReason::Success,
1033 authentication_method: None,
1034 authentication_data: None,
1035 reason_string: None,
1036 user_properties: vec![],
1037 })));
1038 }
1039
1040 let reason_code_byte = read_u8!(bytes);
1041 let reason_code = AuthenticateReason::try_from(reason_code_byte)
1042 .map_err(|_| DecodeError::InvalidAuthenticateReason)?;
1043
1044 let mut authentication_method = None;
1045 let mut authentication_data = None;
1046 let mut reason_string = None;
1047 let mut user_properties = vec![];
1048
1049 if protocol_version == ProtocolVersion::V500 && remaining_packet_length >= 2 {
1050 return_if_none!(decode_properties(bytes, |property| {
1051 match property {
1052 Property::AuthenticationMethod(p) => authentication_method = Some(p),
1053 Property::AuthenticationData(p) => authentication_data = Some(p),
1054 Property::ReasonString(p) => reason_string = Some(p),
1055 Property::UserProperty(p) => user_properties.push(p),
1056 _ => {}, }
1058 })?);
1059 }
1060
1061 let packet = AuthenticatePacket {
1062 reason_code,
1063 authentication_method,
1064 authentication_data,
1065 reason_string,
1066 user_properties,
1067 };
1068
1069 Ok(Some(Packet::Authenticate(packet)))
1070}
1071
1072fn decode_packet(
1073 protocol_version: ProtocolVersion,
1074 packet_type: &PacketType,
1075 bytes: &mut Cursor<&mut BytesMut>,
1076 remaining_packet_length: u32,
1077 first_byte: u8,
1078) -> Result<Option<Packet>, DecodeError> {
1079 match packet_type {
1080 PacketType::Connect => decode_connect(bytes),
1081 PacketType::ConnectAck => decode_connect_ack(bytes, protocol_version),
1082 PacketType::Publish => {
1083 decode_publish(bytes, first_byte, remaining_packet_length, protocol_version)
1084 },
1085 PacketType::PublishAck => {
1086 decode_publish_ack(bytes, remaining_packet_length, protocol_version)
1087 },
1088 PacketType::PublishReceived => {
1089 decode_publish_received(bytes, remaining_packet_length, protocol_version)
1090 },
1091 PacketType::PublishRelease => {
1092 decode_publish_release(bytes, remaining_packet_length, protocol_version)
1093 },
1094 PacketType::PublishComplete => {
1095 decode_publish_complete(bytes, remaining_packet_length, protocol_version)
1096 },
1097 PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length, protocol_version),
1098 PacketType::SubscribeAck => {
1099 decode_subscribe_ack(bytes, remaining_packet_length, protocol_version)
1100 },
1101 PacketType::Unsubscribe => {
1102 decode_unsubscribe(bytes, remaining_packet_length, protocol_version)
1103 },
1104 PacketType::UnsubscribeAck => {
1105 decode_unsubscribe_ack(bytes, remaining_packet_length, protocol_version)
1106 },
1107 PacketType::PingRequest => Ok(Some(Packet::PingRequest)),
1108 PacketType::PingResponse => Ok(Some(Packet::PingResponse)),
1109 PacketType::Disconnect => {
1110 decode_disconnect(bytes, remaining_packet_length, protocol_version)
1111 },
1112 PacketType::Authenticate => {
1113 decode_authenticate(bytes, remaining_packet_length, protocol_version)
1114 },
1115 }
1116}
1117
1118pub fn decode_mqtt(
1119 bytes: &mut BytesMut,
1120 protocol_version: ProtocolVersion,
1121) -> Result<Option<Packet>, DecodeError> {
1122 let mut bytes = Cursor::new(bytes);
1123 let first_byte = read_u8!(bytes);
1124
1125 let first_byte_val = (first_byte & 0b1111_0000) >> 4;
1126 let packet_type =
1127 PacketType::try_from(first_byte_val).map_err(|_| DecodeError::InvalidPacketType)?;
1128 let remaining_packet_length = read_variable_int!(&mut bytes);
1129
1130 let cursor_pos = bytes.position() as usize;
1131 let remaining_buffer_amount = bytes.get_ref().len() - cursor_pos;
1132
1133 if remaining_buffer_amount < remaining_packet_length as usize {
1134 return Ok(None);
1136 }
1137
1138 let packet = return_if_none!(decode_packet(
1139 protocol_version,
1140 &packet_type,
1141 &mut bytes,
1142 remaining_packet_length,
1143 first_byte
1144 )?);
1145
1146 let cursor_pos = bytes.position() as usize;
1147 let bytes = bytes.into_inner();
1148
1149 let _rest = bytes.split_to(cursor_pos);
1150
1151 Ok(Some(packet))
1152}
1153
#[cfg(test)]
mod tests {
    use crate::{decoder::*, types::*};
    use bytes::BytesMut;

    /// Regression input for a malformed fixed header.
    ///
    /// The first byte 0x88 has upper nibble 8 (SUBSCRIBE in MQTT numbering). The
    /// declared remaining length is 1, too short to hold even a two-byte packet
    /// identifier.
    ///
    /// The test only checks that decoding returns instead of panicking; the
    /// exact result (error or `None`) is deliberately not pinned here.
    #[test]
    fn test_invalid_remaining_length() {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[136, 1, 0, 36, 0, 0]);
        let _ = decode_mqtt(&mut bytes, ProtocolVersion::V500);
    }

    /// Checks `decode_variable_int` against the smallest and largest value for
    /// each encoded length from one to four bytes.
    #[test]
    fn test_decode_variable_int() {
        let cases: &[(&[u8], u32)] = &[
            (&[0x00], 0),
            (&[0x7F], 127),
            (&[0x80, 0x01], 128),
            (&[0xFF, 0x7F], 16383),
            (&[0x80, 0x80, 0x01], 16384),
            (&[0xFF, 0xFF, 0x7F], 2097151),
            (&[0x80, 0x80, 0x80, 0x01], 2097152),
            (&[0xFF, 0xFF, 0xFF, 0x7F], 268435455),
        ];

        for &(encoded, expected) in cases {
            let mut buffer = BytesMut::new();
            buffer.extend_from_slice(encoded);

            // `panic!` takes a format string; passing the error value directly
            // is deprecated (`non_fmt_panics`) and rejected in Rust 2021.
            match decode_variable_int(&mut Cursor::new(&mut buffer)) {
                Ok(Some(decoded)) => {
                    assert_eq!(decoded, expected, "decoding {:02X?}", encoded)
                },
                Ok(None) => panic!("variable_int is None for {:02X?}", encoded),
                Err(err) => panic!("decoding {:02X?} failed: {:?}", encoded, err),
            }
        }
    }
}