//! MQTT packet types, fixed-header encoding/decoding, and the `MqttPacket` trait
//! (`mqtt5_protocol/packet.rs`).

1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod subscribe_options;
16pub mod unsuback;
17pub mod unsubscribe;
18
19pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
20
21#[cfg(test)]
22mod property_tests;
23
#[cfg(test)]
mod bebytes_tests {
    //! Property tests for the `BeBytes`-derived layout of [`MqttTypeAndFlags`].

    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Any in-range combination of fields survives a serialize/deserialize
        // round trip, occupies exactly one byte, and lands in the bit positions
        // documented on the struct (type in bits 7-4, DUP in bit 3, QoS in
        // bits 2-1, RETAIN in bit 0) — the same layout `FixedHeader::encode`
        // produces with `(type << 4) | flags`.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let original = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            let bytes = original.to_be_bytes();
            prop_assert_eq!(bytes.len(), 1);
            prop_assert_eq!(
                bytes[0],
                (message_type << 4) | (dup << 3) | (qos << 1) | retain
            );

            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(original, decoded);
        }

        // Every value in 1..=15 names a packet type; the flag-free header built
        // for it round-trips and reports the same packet type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            // Every value in this range is a defined MQTT packet type, so a
            // `None` here is a regression to report, not a case to skip.
            let pt = PacketType::from_u8(packet_type)
                .expect("every value in 1..=15 is a valid MQTT packet type");
            prop_assert_eq!(u8::from(pt), packet_type);

            let type_and_flags = MqttTypeAndFlags::for_packet_type(pt);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(pt));
        }

        // PUBLISH headers keep their QoS, DUP and RETAIN settings through a
        // round trip.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let type_and_flags = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
80
81use crate::encoding::{decode_variable_int, encode_variable_int};
82use crate::error::{MqttError, Result};
83use crate::prelude::{Box, ToString, Vec};
84use bebytes::BeBytes;
85use bytes::{Buf, BufMut};
86
87/// MQTT acknowledgment packet variable header using bebytes
88/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
89#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
90pub struct AckPacketHeader {
91    /// Packet identifier
92    pub packet_id: u16,
93    /// Reason code (single byte)
94    pub reason_code: u8,
95}
96
97impl AckPacketHeader {
98    /// Creates a new acknowledgment packet header
99    #[must_use]
100    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
101        Self {
102            packet_id,
103            reason_code: u8::from(reason_code),
104        }
105    }
106
107    /// Gets the reason code as a `ReasonCode` enum
108    #[must_use]
109    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
110        crate::types::ReasonCode::from_u8(self.reason_code)
111    }
112}
113
114/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
115#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
116pub struct MqttTypeAndFlags {
117    /// Message type (bits 7-4)
118    #[bits(4)]
119    pub message_type: u8,
120    /// DUP flag (bit 3) - for PUBLISH packets
121    #[bits(1)]
122    pub dup: u8,
123    /// `QoS` level (bits 2-1) - for PUBLISH packets
124    #[bits(2)]
125    pub qos: u8,
126    /// RETAIN flag (bit 0) - for PUBLISH packets  
127    #[bits(1)]
128    pub retain: u8,
129}
130
131impl MqttTypeAndFlags {
132    /// Creates a new `MqttTypeAndFlags` for a given packet type
133    #[must_use]
134    pub fn for_packet_type(packet_type: PacketType) -> Self {
135        Self {
136            message_type: packet_type as u8,
137            dup: 0,
138            qos: 0,
139            retain: 0,
140        }
141    }
142
143    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
144    #[must_use]
145    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
146        Self {
147            message_type: PacketType::Publish as u8,
148            dup: u8::from(dup),
149            qos,
150            retain: u8::from(retain),
151        }
152    }
153
154    /// Returns the packet type
155    #[must_use]
156    pub fn packet_type(&self) -> Option<PacketType> {
157        PacketType::from_u8(self.message_type)
158    }
159
160    /// Returns true if the DUP flag is set
161    #[must_use]
162    pub fn is_dup(&self) -> bool {
163        self.dup != 0
164    }
165
166    /// Returns true if the RETAIN flag is set
167    #[must_use]
168    pub fn is_retain(&self) -> bool {
169        self.retain != 0
170    }
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
174pub enum PacketType {
175    Connect = 1,
176    ConnAck = 2,
177    Publish = 3,
178    PubAck = 4,
179    PubRec = 5,
180    PubRel = 6,
181    PubComp = 7,
182    Subscribe = 8,
183    SubAck = 9,
184    Unsubscribe = 10,
185    UnsubAck = 11,
186    PingReq = 12,
187    PingResp = 13,
188    Disconnect = 14,
189    Auth = 15,
190}
191
192impl PacketType {
193    /// Converts a u8 to `PacketType`
194    #[must_use]
195    pub fn from_u8(value: u8) -> Option<Self> {
196        // Use the TryFrom implementation generated by BeBytes
197        Self::try_from(value).ok()
198    }
199}
200
201impl From<PacketType> for u8 {
202    fn from(packet_type: PacketType) -> Self {
203        packet_type as u8
204    }
205}
206
207/// MQTT packet fixed header
208#[derive(Debug, Clone, Copy, PartialEq, Eq)]
209pub struct FixedHeader {
210    pub packet_type: PacketType,
211    pub flags: u8,
212    pub remaining_length: u32,
213}
214
215impl FixedHeader {
216    /// Creates a new fixed header
217    #[must_use]
218    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
219        Self {
220            packet_type,
221            flags,
222            remaining_length,
223        }
224    }
225
226    /// Encodes the fixed header
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the remaining length is too large to encode
231    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
232        let byte1 =
233            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
234        buf.put_u8(byte1);
235        encode_variable_int(buf, self.remaining_length)?;
236        Ok(())
237    }
238
239    /// Decodes a fixed header from the buffer
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if:
244    /// - Insufficient bytes in buffer
245    /// - Invalid packet type
246    /// - Invalid remaining length encoding
247    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
248        if !buf.has_remaining() {
249            return Err(MqttError::MalformedPacket(
250                "No data for fixed header".to_string(),
251            ));
252        }
253
254        let byte1 = buf.get_u8();
255        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
256        let flags = byte1 & crate::constants::masks::FLAGS;
257
258        let packet_type = PacketType::from_u8(packet_type_val)
259            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
260
261        let remaining_length = decode_variable_int(buf)?;
262
263        Ok(Self {
264            packet_type,
265            flags,
266            remaining_length,
267        })
268    }
269
270    /// Validates the flags for the packet type
271    #[must_use]
272    pub fn validate_flags(&self) -> bool {
273        match self.packet_type {
274            PacketType::Publish => true, // Publish has variable flags
275            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
276                self.flags == 0x02 // Required flags for these packet types
277            }
278            _ => self.flags == 0,
279        }
280    }
281
282    /// Returns the encoded length of the fixed header
283    #[must_use]
284    pub fn encoded_len(&self) -> usize {
285        // 1 byte for packet type + flags, plus variable length encoding of remaining length
286        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
287    }
288}
289
290/// Enum representing all MQTT packet types
291#[derive(Debug, Clone)]
292pub enum Packet {
293    Connect(Box<connect::ConnectPacket>),
294    ConnAck(connack::ConnAckPacket),
295    Publish(publish::PublishPacket),
296    PubAck(puback::PubAckPacket),
297    PubRec(pubrec::PubRecPacket),
298    PubRel(pubrel::PubRelPacket),
299    PubComp(pubcomp::PubCompPacket),
300    Subscribe(subscribe::SubscribePacket),
301    SubAck(suback::SubAckPacket),
302    Unsubscribe(unsubscribe::UnsubscribePacket),
303    UnsubAck(unsuback::UnsubAckPacket),
304    PingReq,
305    PingResp,
306    Disconnect(disconnect::DisconnectPacket),
307    Auth(auth::AuthPacket),
308}
309
310impl Packet {
311    /// Decode a packet body based on the packet type
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if decoding fails
316    pub fn decode_from_body<B: Buf>(
317        packet_type: PacketType,
318        fixed_header: &FixedHeader,
319        buf: &mut B,
320    ) -> Result<Self> {
321        match packet_type {
322            PacketType::Connect => {
323                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
324                Ok(Packet::Connect(Box::new(packet)))
325            }
326            PacketType::ConnAck => {
327                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
328                Ok(Packet::ConnAck(packet))
329            }
330            PacketType::Publish => {
331                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
332                Ok(Packet::Publish(packet))
333            }
334            PacketType::PubAck => {
335                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
336                Ok(Packet::PubAck(packet))
337            }
338            PacketType::PubRec => {
339                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
340                Ok(Packet::PubRec(packet))
341            }
342            PacketType::PubRel => {
343                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
344                Ok(Packet::PubRel(packet))
345            }
346            PacketType::PubComp => {
347                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
348                Ok(Packet::PubComp(packet))
349            }
350            PacketType::Subscribe => {
351                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
352                Ok(Packet::Subscribe(packet))
353            }
354            PacketType::SubAck => {
355                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
356                Ok(Packet::SubAck(packet))
357            }
358            PacketType::Unsubscribe => {
359                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
360                Ok(Packet::Unsubscribe(packet))
361            }
362            PacketType::UnsubAck => {
363                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
364                Ok(Packet::UnsubAck(packet))
365            }
366            PacketType::PingReq => Ok(Packet::PingReq),
367            PacketType::PingResp => Ok(Packet::PingResp),
368            PacketType::Disconnect => {
369                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
370                Ok(Packet::Disconnect(packet))
371            }
372            PacketType::Auth => {
373                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
374                Ok(Packet::Auth(packet))
375            }
376        }
377    }
378
379    /// Decode a packet body based on the packet type with protocol version
380    ///
381    /// # Errors
382    ///
383    /// Returns an error if decoding fails
384    pub fn decode_from_body_with_version<B: Buf>(
385        packet_type: PacketType,
386        fixed_header: &FixedHeader,
387        buf: &mut B,
388        protocol_version: u8,
389    ) -> Result<Self> {
390        match packet_type {
391            PacketType::Publish => {
392                let packet = publish::PublishPacket::decode_body_with_version(
393                    buf,
394                    fixed_header,
395                    protocol_version,
396                )?;
397                Ok(Packet::Publish(packet))
398            }
399            PacketType::Subscribe => {
400                let packet = subscribe::SubscribePacket::decode_body_with_version(
401                    buf,
402                    fixed_header,
403                    protocol_version,
404                )?;
405                Ok(Packet::Subscribe(packet))
406            }
407            PacketType::SubAck => {
408                let packet = suback::SubAckPacket::decode_body_with_version(
409                    buf,
410                    fixed_header,
411                    protocol_version,
412                )?;
413                Ok(Packet::SubAck(packet))
414            }
415            PacketType::Unsubscribe => {
416                let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
417                    buf,
418                    fixed_header,
419                    protocol_version,
420                )?;
421                Ok(Packet::Unsubscribe(packet))
422            }
423            _ => Self::decode_from_body(packet_type, fixed_header, buf),
424        }
425    }
426}
427
428/// Trait for MQTT packets
429pub trait MqttPacket: Sized {
430    /// Returns the packet type
431    fn packet_type(&self) -> PacketType;
432
433    /// Returns the fixed header flags
434    fn flags(&self) -> u8 {
435        0
436    }
437
438    /// Encodes the packet body (without fixed header)
439    ///
440    /// # Errors
441    ///
442    /// Returns an error if encoding fails
443    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
444
445    /// Decodes the packet body (without fixed header)
446    ///
447    /// # Errors
448    ///
449    /// Returns an error if decoding fails
450    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
451
452    /// Encodes the complete packet (with fixed header)
453    ///
454    /// # Errors
455    ///
456    /// Returns an error if encoding fails
457    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
458        // First encode to temporary buffer to get remaining length
459        let mut body = Vec::new();
460        self.encode_body(&mut body)?;
461
462        let fixed_header = FixedHeader::new(
463            self.packet_type(),
464            self.flags(),
465            body.len().try_into().unwrap_or(u32::MAX),
466        );
467
468        fixed_header.encode(buf)?;
469        buf.put_slice(&body);
470        Ok(())
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(1), Some(PacketType::Connect));
        assert_eq!(PacketType::from_u8(2), Some(PacketType::ConnAck));
        assert_eq!(PacketType::from_u8(15), Some(PacketType::Auth));
        assert_eq!(PacketType::from_u8(0), None);
        assert_eq!(PacketType::from_u8(16), None);
        assert_eq!(PacketType::from_u8(u8::MAX), None);

        // Every defined wire value converts to a packet type and back unchanged.
        for value in 1u8..=15 {
            let packet_type = PacketType::from_u8(value).expect("1..=15 are all defined");
            assert_eq!(u8::from(packet_type), value);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let mut buf = BytesMut::new();

        let header = FixedHeader::new(PacketType::Connect, 0, 100);
        header.encode(&mut buf).unwrap();
        // CONNECT (1) in the upper nibble with no flags, then 100 as a
        // single-byte variable byte integer.
        assert_eq!(&buf[..], &[0x10, 100]);
        assert_eq!(header.encoded_len(), buf.len());

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
        assert!(buf.is_empty(), "decode must consume the whole header");
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let mut buf = BytesMut::new();

        let header = FixedHeader::new(PacketType::Publish, 0x0D, 50);
        header.encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[0x3D, 50]);

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_fixed_header_multi_byte_remaining_length() {
        let mut buf = BytesMut::new();

        let header = FixedHeader::new(PacketType::Publish, 0, 321);
        header.encode(&mut buf).unwrap();
        // 321 = 65 + 2 * 128: least-significant group first with the
        // continuation bit set (65 | 0x80 = 0xC1), followed by 0x02.
        assert_eq!(&buf[..], &[0x30, 0xC1, 0x02]);
        assert_eq!(header.encoded_len(), 3);

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded, header);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_validate_flags() {
        let header = FixedHeader::new(PacketType::Connect, 0, 0);
        assert!(header.validate_flags());

        let header = FixedHeader::new(PacketType::Connect, 1, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::Subscribe, 0x02, 0);
        assert!(header.validate_flags());

        let header = FixedHeader::new(PacketType::Subscribe, 0x00, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::Publish, 0x0F, 0);
        assert!(header.validate_flags());

        // PUBREL and UNSUBSCRIBE require the same 0x02 flags as SUBSCRIBE.
        assert!(FixedHeader::new(PacketType::PubRel, 0x02, 0).validate_flags());
        assert!(!FixedHeader::new(PacketType::PubRel, 0x00, 0).validate_flags());
        assert!(FixedHeader::new(PacketType::Unsubscribe, 0x02, 0).validate_flags());
        assert!(!FixedHeader::new(PacketType::Unsubscribe, 0x00, 0).validate_flags());

        // Other types must have all flag bits clear.
        assert!(!FixedHeader::new(PacketType::PingReq, 0x02, 0).validate_flags());
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut buf = BytesMut::new();
        let result = FixedHeader::decode(&mut buf);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        let mut buf = BytesMut::new();
        buf.put_u8(0x00); // Invalid packet type 0
        buf.put_u8(0x00); // Remaining length

        let result = FixedHeader::decode(&mut buf);
        assert!(matches!(result, Err(MqttError::InvalidPacketType(0))));
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        // Test BeBytes to_be_bytes and try_from_be_bytes
        let packet_type = PacketType::Publish;
        let bytes = packet_type.to_be_bytes();
        assert_eq!(bytes, vec![3]);

        let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(decoded, PacketType::Publish);
        assert_eq!(consumed, 1);

        // Test other packet types
        let packet_type = PacketType::Connect;
        let bytes = packet_type.to_be_bytes();
        assert_eq!(bytes, vec![1]);

        let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(decoded, PacketType::Connect);
        assert_eq!(consumed, 1);
    }
}