Skip to main content

mqtt5_protocol/
packet.rs

1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod subscribe_options;
16pub mod unsuback;
17pub mod unsubscribe;
18
19pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
20
21#[cfg(test)]
22mod property_tests;
23
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any combination of the four bit fields must survive encode -> decode unchanged.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let original = MqttTypeAndFlags { message_type, dup, qos, retain };

            let wire = original.to_be_bytes();
            let decoded = MqttTypeAndFlags::try_from_be_bytes(&wire).unwrap().0;

            prop_assert_eq!(original, decoded);
        }

        // A header built for a packet type decodes back to that same packet type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(expected) = PacketType::from_u8(packet_type) {
                let built = MqttTypeAndFlags::for_packet_type(expected);
                let wire = built.to_be_bytes();
                let decoded = MqttTypeAndFlags::try_from_be_bytes(&wire).unwrap().0;

                prop_assert_eq!(built, decoded);
                prop_assert_eq!(decoded.packet_type(), Some(expected));
            }
        }

        // PUBLISH headers keep their QoS, DUP and RETAIN settings across a round trip.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let built = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let wire = built.to_be_bytes();
            let decoded = MqttTypeAndFlags::try_from_be_bytes(&wire).unwrap().0;

            prop_assert_eq!(built, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
80
81use crate::encoding::{decode_variable_int, encode_variable_int};
82use crate::error::{MqttError, Result};
83use crate::prelude::{format, Box, ToString, Vec};
84use bebytes::BeBytes;
85use bytes::{Buf, BufMut};
86
87/// MQTT acknowledgment packet variable header using bebytes
88/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
89#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
90pub struct AckPacketHeader {
91    /// Packet identifier
92    pub packet_id: u16,
93    /// Reason code (single byte)
94    pub reason_code: u8,
95}
96
97impl AckPacketHeader {
98    /// Creates a new acknowledgment packet header
99    #[must_use]
100    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
101        Self {
102            packet_id,
103            reason_code: u8::from(reason_code),
104        }
105    }
106
107    /// Gets the reason code as a `ReasonCode` enum
108    #[must_use]
109    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
110        crate::types::ReasonCode::from_u8(self.reason_code)
111    }
112}
113
114/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
115#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
116pub struct MqttTypeAndFlags {
117    /// Message type (bits 7-4)
118    #[bits(4)]
119    pub message_type: u8,
120    /// DUP flag (bit 3) - for PUBLISH packets
121    #[bits(1)]
122    pub dup: u8,
123    /// `QoS` level (bits 2-1) - for PUBLISH packets
124    #[bits(2)]
125    pub qos: u8,
126    /// RETAIN flag (bit 0) - for PUBLISH packets  
127    #[bits(1)]
128    pub retain: u8,
129}
130
131impl MqttTypeAndFlags {
132    /// Creates a new `MqttTypeAndFlags` for a given packet type
133    #[must_use]
134    pub fn for_packet_type(packet_type: PacketType) -> Self {
135        Self {
136            message_type: packet_type as u8,
137            dup: 0,
138            qos: 0,
139            retain: 0,
140        }
141    }
142
143    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
144    #[must_use]
145    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
146        Self {
147            message_type: PacketType::Publish as u8,
148            dup: u8::from(dup),
149            qos,
150            retain: u8::from(retain),
151        }
152    }
153
154    /// Returns the packet type
155    #[must_use]
156    pub fn packet_type(&self) -> Option<PacketType> {
157        PacketType::from_u8(self.message_type)
158    }
159
160    /// Returns true if the DUP flag is set
161    #[must_use]
162    pub fn is_dup(&self) -> bool {
163        self.dup != 0
164    }
165
166    /// Returns true if the RETAIN flag is set
167    #[must_use]
168    pub fn is_retain(&self) -> bool {
169        self.retain != 0
170    }
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
174pub enum PacketType {
175    Connect = 1,
176    ConnAck = 2,
177    Publish = 3,
178    PubAck = 4,
179    PubRec = 5,
180    PubRel = 6,
181    PubComp = 7,
182    Subscribe = 8,
183    SubAck = 9,
184    Unsubscribe = 10,
185    UnsubAck = 11,
186    PingReq = 12,
187    PingResp = 13,
188    Disconnect = 14,
189    Auth = 15,
190}
191
192impl PacketType {
193    /// Converts a u8 to `PacketType`
194    #[must_use]
195    pub fn from_u8(value: u8) -> Option<Self> {
196        // Use the TryFrom implementation generated by BeBytes
197        Self::try_from(value).ok()
198    }
199}
200
201impl From<PacketType> for u8 {
202    fn from(packet_type: PacketType) -> Self {
203        packet_type as u8
204    }
205}
206
207/// MQTT packet fixed header
208#[derive(Debug, Clone, Copy, PartialEq, Eq)]
209pub struct FixedHeader {
210    pub packet_type: PacketType,
211    pub flags: u8,
212    pub remaining_length: u32,
213}
214
215impl FixedHeader {
216    /// Creates a new fixed header
217    #[must_use]
218    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
219        Self {
220            packet_type,
221            flags,
222            remaining_length,
223        }
224    }
225
226    /// Encodes the fixed header
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the remaining length is too large to encode
231    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
232        let byte1 =
233            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
234        buf.put_u8(byte1);
235        encode_variable_int(buf, self.remaining_length)?;
236        Ok(())
237    }
238
239    /// Decodes a fixed header from the buffer
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if:
244    /// - Insufficient bytes in buffer
245    /// - Invalid packet type
246    /// - Invalid remaining length encoding
247    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
248        if !buf.has_remaining() {
249            return Err(MqttError::MalformedPacket(
250                "No data for fixed header".to_string(),
251            ));
252        }
253
254        let byte1 = buf.get_u8();
255        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
256        let flags = byte1 & crate::constants::masks::FLAGS;
257
258        let packet_type = PacketType::from_u8(packet_type_val)
259            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
260
261        let remaining_length = decode_variable_int(buf)?;
262
263        Ok(Self {
264            packet_type,
265            flags,
266            remaining_length,
267        })
268    }
269
270    /// Validates the flags for the packet type
271    #[must_use]
272    pub fn validate_flags(&self) -> bool {
273        match self.packet_type {
274            PacketType::Publish => true, // Publish has variable flags
275            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
276                self.flags == 0x02 // Required flags for these packet types
277            }
278            _ => self.flags == 0,
279        }
280    }
281
282    /// Returns the encoded length of the fixed header
283    #[must_use]
284    pub fn encoded_len(&self) -> usize {
285        // 1 byte for packet type + flags, plus variable length encoding of remaining length
286        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
287    }
288}
289
290/// Enum representing all MQTT packet types
291#[derive(Debug, Clone)]
292pub enum Packet {
293    Connect(Box<connect::ConnectPacket>),
294    ConnAck(connack::ConnAckPacket),
295    Publish(publish::PublishPacket),
296    PubAck(puback::PubAckPacket),
297    PubRec(pubrec::PubRecPacket),
298    PubRel(pubrel::PubRelPacket),
299    PubComp(pubcomp::PubCompPacket),
300    Subscribe(subscribe::SubscribePacket),
301    SubAck(suback::SubAckPacket),
302    Unsubscribe(unsubscribe::UnsubscribePacket),
303    UnsubAck(unsuback::UnsubAckPacket),
304    PingReq,
305    PingResp,
306    Disconnect(disconnect::DisconnectPacket),
307    Auth(auth::AuthPacket),
308}
309
310impl Packet {
311    #[must_use]
312    pub fn packet_type_name(&self) -> &'static str {
313        match self {
314            Self::Connect(_) => "CONNECT",
315            Self::ConnAck(_) => "CONNACK",
316            Self::Publish(_) => "PUBLISH",
317            Self::PubAck(_) => "PUBACK",
318            Self::PubRec(_) => "PUBREC",
319            Self::PubRel(_) => "PUBREL",
320            Self::PubComp(_) => "PUBCOMP",
321            Self::Subscribe(_) => "SUBSCRIBE",
322            Self::SubAck(_) => "SUBACK",
323            Self::Unsubscribe(_) => "UNSUBSCRIBE",
324            Self::UnsubAck(_) => "UNSUBACK",
325            Self::PingReq => "PINGREQ",
326            Self::PingResp => "PINGRESP",
327            Self::Disconnect(_) => "DISCONNECT",
328            Self::Auth(_) => "AUTH",
329        }
330    }
331
332    /// # Errors
333    ///
334    /// Returns an error if decoding fails
335    pub fn decode_from_body<B: Buf>(
336        packet_type: PacketType,
337        fixed_header: &FixedHeader,
338        buf: &mut B,
339    ) -> Result<Self> {
340        if !fixed_header.validate_flags() {
341            return Err(MqttError::MalformedPacket(format!(
342                "Invalid fixed header flags 0x{:02X} for {:?}",
343                fixed_header.flags, packet_type
344            )));
345        }
346
347        match packet_type {
348            PacketType::Connect => {
349                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
350                Ok(Packet::Connect(Box::new(packet)))
351            }
352            PacketType::ConnAck => {
353                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
354                Ok(Packet::ConnAck(packet))
355            }
356            PacketType::Publish => {
357                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
358                Ok(Packet::Publish(packet))
359            }
360            PacketType::PubAck => {
361                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
362                Ok(Packet::PubAck(packet))
363            }
364            PacketType::PubRec => {
365                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
366                Ok(Packet::PubRec(packet))
367            }
368            PacketType::PubRel => {
369                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
370                Ok(Packet::PubRel(packet))
371            }
372            PacketType::PubComp => {
373                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
374                Ok(Packet::PubComp(packet))
375            }
376            PacketType::Subscribe => {
377                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
378                Ok(Packet::Subscribe(packet))
379            }
380            PacketType::SubAck => {
381                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
382                Ok(Packet::SubAck(packet))
383            }
384            PacketType::Unsubscribe => {
385                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
386                Ok(Packet::Unsubscribe(packet))
387            }
388            PacketType::UnsubAck => {
389                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
390                Ok(Packet::UnsubAck(packet))
391            }
392            PacketType::PingReq => Ok(Packet::PingReq),
393            PacketType::PingResp => Ok(Packet::PingResp),
394            PacketType::Disconnect => {
395                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
396                Ok(Packet::Disconnect(packet))
397            }
398            PacketType::Auth => {
399                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
400                Ok(Packet::Auth(packet))
401            }
402        }
403    }
404
405    /// Decode a packet body based on the packet type with protocol version
406    ///
407    /// # Errors
408    ///
409    /// Returns an error if decoding fails
410    pub fn decode_from_body_with_version<B: Buf>(
411        packet_type: PacketType,
412        fixed_header: &FixedHeader,
413        buf: &mut B,
414        protocol_version: u8,
415    ) -> Result<Self> {
416        match packet_type {
417            PacketType::Publish => {
418                let packet = publish::PublishPacket::decode_body_with_version(
419                    buf,
420                    fixed_header,
421                    protocol_version,
422                )?;
423                Ok(Packet::Publish(packet))
424            }
425            PacketType::Subscribe => {
426                let packet = subscribe::SubscribePacket::decode_body_with_version(
427                    buf,
428                    fixed_header,
429                    protocol_version,
430                )?;
431                Ok(Packet::Subscribe(packet))
432            }
433            PacketType::SubAck => {
434                let packet = suback::SubAckPacket::decode_body_with_version(
435                    buf,
436                    fixed_header,
437                    protocol_version,
438                )?;
439                Ok(Packet::SubAck(packet))
440            }
441            PacketType::Unsubscribe => {
442                let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
443                    buf,
444                    fixed_header,
445                    protocol_version,
446                )?;
447                Ok(Packet::Unsubscribe(packet))
448            }
449            _ => Self::decode_from_body(packet_type, fixed_header, buf),
450        }
451    }
452}
453
454/// Trait for MQTT packets
455pub trait MqttPacket: Sized {
456    /// Returns the packet type
457    fn packet_type(&self) -> PacketType;
458
459    /// Returns the fixed header flags
460    fn flags(&self) -> u8 {
461        0
462    }
463
464    /// Encodes the packet body (without fixed header)
465    ///
466    /// # Errors
467    ///
468    /// Returns an error if encoding fails
469    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
470
471    /// Decodes the packet body (without fixed header)
472    ///
473    /// # Errors
474    ///
475    /// Returns an error if decoding fails
476    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
477
478    /// Encodes the complete packet (with fixed header)
479    ///
480    /// # Errors
481    ///
482    /// Returns an error if encoding fails
483    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
484        // First encode to temporary buffer to get remaining length
485        let mut body = Vec::new();
486        self.encode_body(&mut body)?;
487
488        let fixed_header = FixedHeader::new(
489            self.packet_type(),
490            self.flags(),
491            body.len().try_into().unwrap_or(u32::MAX),
492        );
493
494        fixed_header.encode(buf)?;
495        buf.put_slice(&body);
496        Ok(())
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it back out.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut buf = BytesMut::new();
        header.encode(&mut buf).unwrap();
        FixedHeader::decode(&mut buf).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(PacketType::from_u8(raw), expected, "raw value {}", raw);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let decoded = round_trip(FixedHeader::new(PacketType::Connect, 0, 100));

        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let decoded = round_trip(FixedHeader::new(PacketType::Publish, 0x0D, 50));

        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0x00, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for &(packet_type, flags, expected) in &cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(
                header.validate_flags(),
                expected,
                "{:?} with flags 0x{:02X}",
                packet_type,
                flags
            );
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // Packet type 0 is reserved, followed by a zero remaining length.
        let mut buf = BytesMut::new();
        buf.put_slice(&[0x00, 0x00]);

        assert!(FixedHeader::decode(&mut buf).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        let cases = [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)];
        for &(packet_type, wire_value) in &cases {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire_value]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}