1#![deny(unused_must_use)]
2#![deny(clippy::all)]
3#![deny(clippy::pedantic)]
4#![deny(clippy::recursive_format_impl)]
5#![allow(clippy::missing_errors_doc)]
6#![allow(clippy::module_name_repetitions)]
7
8pub use self::{
9 connack::{ConnAck, ConnAckProperties, ConnectReturnCode},
10 connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login},
11 disconnect::{Disconnect, DisconnectProperties, DisconnectReasonCode},
12 ping::{PingReq, PingResp},
13 puback::{PubAck, PubAckProperties, PubAckReason},
14 pubcomp::{PubComp, PubCompProperties, PubCompReason},
15 publish::{Publish, PublishProperties},
16 pubrec::{PubRec, PubRecProperties, PubRecReason},
17 pubrel::{PubRel, PubRelProperties, PubRelReason},
18 suback::{SubAck, SubAckProperties, SubscribeReasonCode},
19 subscribe::{Filter, RetainForwardRule, Subscribe, SubscribeProperties},
20 unsuback::{UnsubAck, UnsubAckProperties, UnsubAckReason},
21 unsubscribe::{Unsubscribe, UnsubscribeProperties},
22};
23use bytes::{Buf, BufMut, Bytes, BytesMut};
24#[cfg(feature = "cow_string")]
25use std::borrow::Cow;
26use std::{fmt::Debug, slice::Iter};
27use std::{str::Utf8Error, vec};
28
29mod connack;
30mod connect;
31mod disconnect;
32mod ping;
33mod puback;
34mod pubcomp;
35mod publish;
36mod pubrec;
37mod pubrel;
38mod suback;
39mod subscribe;
40mod unsuback;
41mod unsubscribe;
42
// The `boxed_string`, `binary_string` and `cow_string` features each select a
// different definition of `MqttString` (see the cfg-gated `type MqttString`
// aliases that follow), so at most one of them may be enabled. Every pair is
// rejected explicitly so the build fails with a readable message instead of a
// duplicate-definition error.
#[cfg(all(feature = "boxed_string", feature = "binary_string"))]
compile_error!(
    "feature \"boxed_string\" and feature \"binary_string\" cannot be enabled at the same time"
);
#[cfg(all(feature = "boxed_string", feature = "cow_string"))]
compile_error!(
    "feature \"boxed_string\" and feature \"cow_string\" cannot be enabled at the same time"
);
#[cfg(all(feature = "binary_string", feature = "cow_string"))]
compile_error!(
    "feature \"binary_string\" and feature \"cow_string\" cannot be enabled at the same time"
);
55
// `MqttString` is the in-memory type of string fields read by
// `read_mqtt_string` and written by `write_mqtt_string`. Exactly one alias is
// active, selected by the string-representation cargo features.

/// `boxed_string`: owned `Box<str>` (no spare capacity).
#[cfg(feature = "boxed_string")]
type MqttString = Box<str>;

/// `binary_string`: raw `Bytes`. In this mode `read_mqtt_string` returns the
/// bytes without checking that they are valid UTF-8.
#[cfg(feature = "binary_string")]
type MqttString = Bytes;

/// `cow_string`: `Cow<'static, str>`. `mqtt_string_new` borrows the static
/// literal it is given; decoded strings are stored as `Cow::Owned`.
#[cfg(feature = "cow_string")]
type MqttString = Cow<'static, str>;

/// Default (no string feature enabled): a plain owned `String`.
#[cfg(all(
    not(feature = "boxed_string"),
    not(feature = "binary_string"),
    not(feature = "cow_string")
))]
type MqttString = String;
71
72#[cfg(all(
73 not(feature = "boxed_string"),
74 not(feature = "binary_string"),
75 not(feature = "cow_string")
76))]
77#[inline]
78fn mqtt_string_eq(m: &MqttString, str: &str) -> bool {
79 m == str
80}
81
82#[cfg(any(feature = "boxed_string", feature = "cow_string"))]
83#[inline]
84fn mqtt_string_eq(m: &MqttString, str: &str) -> bool {
85 m.as_ref().eq(str)
86}
87
88#[cfg(feature = "binary_string")]
89#[inline]
90fn mqtt_string_eq(m: &Bytes, str: &str) -> bool {
91 m.eq(str.as_bytes())
92}
93
94#[cfg(all(
95 not(feature = "boxed_string"),
96 not(feature = "binary_string"),
97 not(feature = "cow_string")
98))]
99#[inline]
100#[must_use]
101pub fn mqtt_string_new(str: &'static str) -> MqttString {
102 str.to_string()
103}
104
105#[cfg(feature = "boxed_string")]
106#[inline]
107#[must_use]
108pub fn mqtt_string_new(str: &'static str) -> MqttString {
109 str.into()
110}
111
112#[cfg(feature = "binary_string")]
113#[inline]
114#[must_use]
115pub fn mqtt_string_new(str: &str) -> MqttString {
116 Bytes::copy_from_slice(str.as_bytes())
117}
118
119#[cfg(feature = "cow_string")]
120#[inline]
121#[must_use]
122pub fn mqtt_string_new(str: &'static str) -> MqttString {
123 Cow::Borrowed(str)
124}
125
/// A decoded MQTT control packet, one variant per [`PacketType`].
///
/// Produced by [`Packet::read`] and encoded by [`Packet::write`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Packet {
    /// CONNECT, with its optional last-will and login sections carried alongside.
    Connect(Connect, Option<LastWill>, Option<Login>),
    ConnAck(ConnAck),
    Publish(Publish),
    PubAck(PubAck),
    PingReq(PingReq),
    PingResp(PingResp),
    Subscribe(Subscribe),
    SubAck(SubAck),
    PubRec(PubRec),
    PubRel(PubRel),
    PubComp(PubComp),
    Unsubscribe(Unsubscribe),
    UnsubAck(UnsubAck),
    Disconnect(Disconnect),
}
143
144impl Packet {
145 pub fn read(stream: &mut BytesMut, max_size: Option<usize>) -> Result<Packet, Error> {
147 let fixed_header = check(stream.iter(), max_size)?;
148
149 let packet = stream.split_to(fixed_header.frame_length());
151 let packet_type = fixed_header.packet_type()?;
152
153 if fixed_header.remaining_len == 0 && packet_type != PacketType::Disconnect {
154 return match packet_type {
155 PacketType::PingReq => Ok(Packet::PingReq(PingReq)),
156 PacketType::PingResp => Ok(Packet::PingResp(PingResp)),
157 _ => Err(Error::PayloadRequired),
158 };
159 }
160
161 let packet = packet.freeze();
162 let packet = match packet_type {
163 PacketType::Connect => {
164 let (connect, will, login) = Connect::read(fixed_header, packet)?;
165 Packet::Connect(connect, will, login)
166 }
167 PacketType::Publish => {
168 let publish = Publish::read(fixed_header, packet)?;
169 Packet::Publish(publish)
170 }
171 PacketType::Subscribe => {
172 let subscribe = Subscribe::read(fixed_header, packet)?;
173 Packet::Subscribe(subscribe)
174 }
175 PacketType::Unsubscribe => {
176 let unsubscribe = Unsubscribe::read(fixed_header, packet)?;
177 Packet::Unsubscribe(unsubscribe)
178 }
179 PacketType::ConnAck => {
180 let connack = ConnAck::read(fixed_header, packet)?;
181 Packet::ConnAck(connack)
182 }
183 PacketType::PubAck => {
184 let puback = PubAck::read(fixed_header, packet)?;
185 Packet::PubAck(puback)
186 }
187 PacketType::PubRec => {
188 let pubrec = PubRec::read(fixed_header, packet)?;
189 Packet::PubRec(pubrec)
190 }
191 PacketType::PubRel => {
192 let pubrel = PubRel::read(fixed_header, packet)?;
193 Packet::PubRel(pubrel)
194 }
195 PacketType::PubComp => {
196 let pubcomp = PubComp::read(fixed_header, packet)?;
197 Packet::PubComp(pubcomp)
198 }
199 PacketType::SubAck => {
200 let suback = SubAck::read(fixed_header, packet)?;
201 Packet::SubAck(suback)
202 }
203 PacketType::UnsubAck => {
204 let unsuback = UnsubAck::read(fixed_header, packet)?;
205 Packet::UnsubAck(unsuback)
206 }
207 PacketType::PingReq => Packet::PingReq(PingReq),
208 PacketType::PingResp => Packet::PingResp(PingResp),
209 PacketType::Disconnect => {
210 let disconnect = Disconnect::read(fixed_header, packet)?;
211 Packet::Disconnect(disconnect)
212 }
213 };
214
215 Ok(packet)
216 }
217
218 pub fn write(&self, write: &mut BytesMut) -> Result<usize, Error> {
219 match self {
220 Self::Publish(publish) => publish.write(write),
221 Self::Subscribe(subscription) => subscription.write(write),
222 Self::Unsubscribe(unsubscribe) => unsubscribe.write(write),
223 Self::ConnAck(ack) => ack.write(write),
224 Self::PubAck(ack) => ack.write(write),
225 Self::SubAck(ack) => ack.write(write),
226 Self::UnsubAck(unsuback) => unsuback.write(write),
227 Self::PubRec(pubrec) => pubrec.write(write),
228 Self::PubRel(pubrel) => pubrel.write(write),
229 Self::PubComp(pubcomp) => pubcomp.write(write),
230 Self::Connect(connect, will, login) => connect.write(will, login, write),
231 Self::PingReq(_) => PingReq::write(write),
232 Self::PingResp(_) => PingResp::write(write),
233 Self::Disconnect(disconnect) => disconnect.write(write),
234 }
235 }
236}
237
/// MQTT control packet types handled by this codec.
///
/// Each discriminant equals the 4-bit type code held in the high nibble of the
/// first fixed-header byte (see [`FixedHeader::packet_type`]). Codes 0 and 15
/// have no variant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
257
/// Property identifiers recognised in a packet's properties section.
///
/// Discriminants are the numeric identifiers mapped by [`property`]. Numbers
/// with no variant here (e.g. 4-7, 10, 12-16) are rejected there with
/// [`Error::InvalidPropertyType`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PropertyType {
    PayloadFormatIndicator = 1,
    MessageExpiryInterval = 2,
    ContentType = 3,
    ResponseTopic = 8,
    CorrelationData = 9,
    SubscriptionIdentifier = 11,
    SessionExpiryInterval = 17,
    AssignedClientIdentifier = 18,
    ServerKeepAlive = 19,
    AuthenticationMethod = 21,
    AuthenticationData = 22,
    RequestProblemInformation = 23,
    WillDelayInterval = 24,
    RequestResponseInformation = 25,
    ResponseInformation = 26,
    ServerReference = 28,
    ReasonString = 31,
    ReceiveMaximum = 33,
    TopicAliasMaximum = 34,
    TopicAlias = 35,
    MaximumQos = 36,
    RetainAvailable = 37,
    UserProperty = 38,
    MaximumPacketSize = 39,
    WildcardSubscriptionAvailable = 40,
    SubscriptionIdentifierAvailable = 41,
    SharedSubscriptionAvailable = 42,
}
289
290#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
303pub struct FixedHeader {
304 byte1: u8,
307 fixed_header_len: usize,
311 remaining_len: usize,
314}
315
316impl FixedHeader {
317 #[must_use]
318 pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
319 FixedHeader {
320 byte1,
321 fixed_header_len: remaining_len_len + 1,
322 remaining_len,
323 }
324 }
325
326 pub fn packet_type(&self) -> Result<PacketType, Error> {
327 let num = self.byte1 >> 4;
328 match num {
329 1 => Ok(PacketType::Connect),
330 2 => Ok(PacketType::ConnAck),
331 3 => Ok(PacketType::Publish),
332 4 => Ok(PacketType::PubAck),
333 5 => Ok(PacketType::PubRec),
334 6 => Ok(PacketType::PubRel),
335 7 => Ok(PacketType::PubComp),
336 8 => Ok(PacketType::Subscribe),
337 9 => Ok(PacketType::SubAck),
338 10 => Ok(PacketType::Unsubscribe),
339 11 => Ok(PacketType::UnsubAck),
340 12 => Ok(PacketType::PingReq),
341 13 => Ok(PacketType::PingResp),
342 14 => Ok(PacketType::Disconnect),
343 _ => Err(Error::InvalidPacketType(num)),
344 }
345 }
346
347 #[must_use]
350 pub fn frame_length(&self) -> usize {
351 self.fixed_header_len + self.remaining_len
352 }
353}
354
355fn property(num: u8) -> Result<PropertyType, Error> {
356 let property = match num {
357 1 => PropertyType::PayloadFormatIndicator,
358 2 => PropertyType::MessageExpiryInterval,
359 3 => PropertyType::ContentType,
360 8 => PropertyType::ResponseTopic,
361 9 => PropertyType::CorrelationData,
362 11 => PropertyType::SubscriptionIdentifier,
363 17 => PropertyType::SessionExpiryInterval,
364 18 => PropertyType::AssignedClientIdentifier,
365 19 => PropertyType::ServerKeepAlive,
366 21 => PropertyType::AuthenticationMethod,
367 22 => PropertyType::AuthenticationData,
368 23 => PropertyType::RequestProblemInformation,
369 24 => PropertyType::WillDelayInterval,
370 25 => PropertyType::RequestResponseInformation,
371 26 => PropertyType::ResponseInformation,
372 28 => PropertyType::ServerReference,
373 31 => PropertyType::ReasonString,
374 33 => PropertyType::ReceiveMaximum,
375 34 => PropertyType::TopicAliasMaximum,
376 35 => PropertyType::TopicAlias,
377 36 => PropertyType::MaximumQos,
378 37 => PropertyType::RetainAvailable,
379 38 => PropertyType::UserProperty,
380 39 => PropertyType::MaximumPacketSize,
381 40 => PropertyType::WildcardSubscriptionAvailable,
382 41 => PropertyType::SubscriptionIdentifierAvailable,
383 42 => PropertyType::SharedSubscriptionAvailable,
384 num => return Err(Error::InvalidPropertyType(num)),
385 };
386
387 Ok(property)
388}
389
390pub fn check(stream: Iter<u8>, max_packet_size: Option<usize>) -> Result<FixedHeader, Error> {
396 let stream_len = stream.len();
399 let fixed_header = parse_fixed_header(stream)?;
400
401 if let Some(max_size) = max_packet_size {
404 if fixed_header.remaining_len > max_size {
405 return Err(Error::PayloadSizeLimitExceeded {
406 pkt_size: fixed_header.remaining_len,
407 max: max_size,
408 });
409 }
410 }
411
412 let frame_length = fixed_header.frame_length();
415 if stream_len < frame_length {
416 return Err(Error::InsufficientBytes(frame_length - stream_len));
417 }
418
419 Ok(fixed_header)
420}
421
422pub(crate) fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
424 let stream_len = stream.len();
426 if stream_len < 2 {
427 return Err(Error::InsufficientBytes(2 - stream_len));
428 }
429
430 let byte1 = stream.next().unwrap();
431 let (len_len, len) = length(stream)?;
432
433 Ok(FixedHeader::new(*byte1, len_len, len))
434}
435
436fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
440 let mut len: usize = 0;
441 let mut len_len = 0;
442 let mut done = false;
443 let mut shift = 0;
444
445 for byte in stream {
450 len_len += 1;
451 let byte = *byte as usize;
452 len += (byte & 0x7F) << shift;
453
454 done = (byte & 0x80) == 0;
456 if done {
457 break;
458 }
459
460 shift += 7;
461
462 if shift > 21 {
465 return Err(Error::MalformedRemainingLength);
466 }
467 }
468
469 if !done {
472 return Err(Error::InsufficientBytes(1));
473 }
474
475 Ok((len_len, len))
476}
477
478#[inline]
480fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
481 let len = read_u16(stream)? as usize;
482
483 if len > stream.len() {
488 return Err(Error::BoundaryCrossed(len));
489 }
490
491 Ok(stream.split_to(len))
492}
493
494#[inline]
496#[cfg(all(not(feature = "binary_string"), not(feature = "cow_string"),))]
497fn read_mqtt_string(stream: &mut Bytes) -> Result<MqttString, Error> {
498 let bytes = read_mqtt_bytes(stream)?;
499 match std::str::from_utf8(&bytes) {
500 Ok(v) => Ok(v.into()),
501 Err(_) => Err(Error::TopicNotUtf8),
502 }
503}
504
505#[inline]
506#[cfg(feature = "cow_string")]
507fn read_mqtt_string(stream: &mut Bytes) -> Result<Cow<'static, str>, Error> {
508 let bytes = read_mqtt_bytes(stream)?;
509 match std::str::from_utf8(&bytes) {
510 Ok(v) => Ok(Cow::Owned(v.to_string())),
511 Err(_) => Err(Error::TopicNotUtf8),
512 }
513}
514
515#[inline]
516#[cfg(feature = "binary_string")]
517fn read_mqtt_string(stream: &mut Bytes) -> Result<MqttString, Error> {
518 read_mqtt_bytes(stream)
519}
520
521#[inline]
523fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) -> Result<(), Error> {
524 let Ok(len) = u16::try_from(bytes.len()) else {
525 return Err(Error::BinaryDataTooLong);
526 };
527 stream.put_u16(len);
528 stream.extend_from_slice(bytes);
529 Ok(())
530}
531
/// Writes a string field via `write_mqtt_bytes` (a `u16` length prefix followed
/// by the string's bytes).
///
/// Fails with `Error::BinaryDataTooLong` if the string exceeds `u16::MAX` bytes.
#[inline]
#[cfg(not(feature = "binary_string"))]
fn write_mqtt_string(stream: &mut BytesMut, string: &MqttString) -> Result<(), Error> {
    write_mqtt_bytes(stream, string.as_bytes())
}

/// `binary_string` variant: the stored bytes are written as-is, with no UTF-8 check.
#[cfg(feature = "binary_string")]
fn write_mqtt_string(stream: &mut BytesMut, string: &MqttString) -> Result<(), Error> {
    write_mqtt_bytes(stream, string)
}
543
544fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
546 if len > 268_435_455 {
547 return Err(Error::PayloadTooLong);
548 }
549
550 let mut done = false;
551 let mut x = len;
552 let mut count = 0;
553
554 while !done {
555 #[allow(clippy::cast_possible_truncation)]
556 let mut byte = (x % 128) as u8;
557 x /= 128;
558 if x > 0 {
559 byte |= 128;
560 }
561
562 stream.put_u8(byte);
563 count += 1;
564 done = x == 0;
565 }
566
567 Ok(count)
568}
569
/// Number of bytes the variable-length remaining-length field needs for `len`.
///
/// Mirrors the 7-bits-per-byte encoding of `write_remaining_length`. Values
/// beyond the encodable maximum still report 4.
#[inline]
fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
583
584#[inline]
590fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
591 if stream.len() < 2 {
592 return Err(Error::MalformedPacket);
593 }
594
595 Ok(stream.get_u16())
596}
597
598#[inline]
599fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
600 if stream.is_empty() {
601 return Err(Error::MalformedPacket);
602 }
603
604 Ok(stream.get_u8())
605}
606
607#[inline]
608fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
609 if stream.len() < 4 {
610 return Err(Error::MalformedPacket);
611 }
612
613 Ok(stream.get_u32())
614}
615
/// MQTT Quality of Service level.
///
/// Discriminants match the numeric levels accepted by [`qos`].
/// [`QoS::AtMostOnce`] is the default.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Default)]
// Every variant names a delivery guarantee, so the shared "Once" suffix is intentional.
#[allow(clippy::enum_variant_names)]
pub enum QoS {
    /// Level 0.
    #[default]
    AtMostOnce = 0,
    /// Level 1.
    AtLeastOnce = 1,
    /// Level 2.
    ExactlyOnce = 2,
}
631
632#[must_use]
634pub fn qos(num: u8) -> Option<QoS> {
635 match num {
636 0 => Some(QoS::AtMostOnce),
637 1 => Some(QoS::AtLeastOnce),
638 2 => Some(QoS::ExactlyOnce),
639 _ => None,
640 }
641}
642
643#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
645pub enum Error {
646 #[error("Invalid return code received as response for connect = {0}")]
647 InvalidConnectReturnCode(u8),
648 #[error("Invalid reason = {0}")]
649 InvalidReason(u8),
650 #[error("Invalid remaining length = {0}")]
651 InvalidRemainingLength(usize),
652 #[error("Invalid protocol used")]
653 InvalidProtocol,
654 #[error("Invalid protocol level")]
655 InvalidProtocolLevel(u8),
656 #[error("Invalid packet format")]
657 IncorrectPacketFormat,
658 #[error("Invalid packet type = {0}")]
659 InvalidPacketType(u8),
660 #[error("Invalid retain forward rule = {0}")]
661 InvalidRetainForwardRule(u8),
662 #[error("Invalid QoS level = {0}")]
663 InvalidQoS(u8),
664 #[error("Invalid subscribe reason code = {0}")]
665 InvalidSubscribeReasonCode(u8),
666 #[error("Packet received has id Zero")]
667 PacketIdZero,
668 #[error("Empty Subscription")]
669 EmptySubscription,
670 #[error("Subscription had id Zero")]
671 SubscriptionIdZero,
672 #[error("Payload size is incorrect")]
673 PayloadSizeIncorrect,
674 #[error("Payload is too long")]
675 PayloadTooLong,
676 #[error("Binary data is too long")]
677 BinaryDataTooLong,
678 #[error("Max Payload size of {max:?} has been exceeded by packet of {pkt_size:?} bytes")]
679 PayloadSizeLimitExceeded { pkt_size: usize, max: usize },
680 #[error("Payload is required")]
681 PayloadRequired,
682 #[error("Payload is required = {0}")]
683 PayloadNotUtf8(#[from] Utf8Error),
684 #[error("Topic not utf-8")]
685 TopicNotUtf8,
686 #[error("Promised boundary crossed, contains {0} bytes")]
687 BoundaryCrossed(usize),
688 #[error("Packet is malformed")]
689 MalformedPacket,
690 #[error("Remaining length is malformed")]
691 MalformedRemainingLength,
692 #[error("Invalid property type = {0}")]
693 InvalidPropertyType(u8),
694 #[error("Insufficient number of bytes to frame packet, {0} more bytes required")]
698 InsufficientBytes(usize),
699}
700
/// Shared fixtures for the packet modules' unit tests.
///
/// Gated on `cfg(test)` so these helpers, which `unwrap` and `assert`, are not
/// compiled into the library itself.
#[cfg(test)]
mod test {
    use bytes::BytesMut;

    use crate::Packet;

    // Not every test build references every fixture, hence the dead_code allowances.
    /// User-property key used in round-trip fixtures.
    #[allow(dead_code)]
    pub const USER_PROP_KEY: &str = "property";
    /// Deliberately long user-property value used in round-trip fixtures.
    #[allow(dead_code)]
    pub const USER_PROP_VAL: &str = "a value thats really long............................................................................................................";

    /// Writes each packet, reads it back, and asserts that the decoded packet
    /// equals the original and that the buffer is fully consumed.
    #[allow(dead_code)]
    pub fn read_write_packets(packets: Vec<Packet>) {
        for out in packets {
            let mut buf = BytesMut::new();
            out.write(&mut buf).unwrap();
            let incoming = Packet::read(&mut buf, None).unwrap();
            assert_eq!(incoming, out);
            assert_eq!(buf.len(), 0);
        }
    }
}