mqtt5_protocol/
packet.rs

1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod unsuback;
16pub mod unsubscribe;
17
18pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
19
20#[cfg(test)]
21mod property_tests;
22
/// Property tests checking that the bit-packed first byte of the MQTT fixed
/// header survives a `BeBytes` encode/decode round trip.
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    /// Serializes `value` with `BeBytes` and parses it back.
    ///
    /// Panics if decoding fails, which proptest reports as a test failure.
    fn encode_then_decode(value: MqttTypeAndFlags) -> MqttTypeAndFlags {
        let encoded = value.to_be_bytes();
        let (parsed, _consumed) = MqttTypeAndFlags::try_from_be_bytes(&encoded).unwrap();
        parsed
    }

    proptest! {
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let input = MqttTypeAndFlags { message_type, dup, qos, retain };
            let output = encode_then_decode(input);

            prop_assert_eq!(input, output);
        }

        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(kind) = PacketType::from_u8(packet_type) {
                let expected = MqttTypeAndFlags::for_packet_type(kind);
                let actual = encode_then_decode(expected);

                prop_assert_eq!(expected, actual);
                prop_assert_eq!(actual.packet_type(), Some(kind));
            }
        }

        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let expected = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let actual = encode_then_decode(expected);

            prop_assert_eq!(expected, actual);
            prop_assert_eq!(actual.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(actual.qos, qos);
            prop_assert_eq!(actual.is_dup(), dup);
            prop_assert_eq!(actual.is_retain(), retain);
        }
    }
}
79
80use crate::encoding::{decode_variable_int, encode_variable_int};
81use crate::error::{MqttError, Result};
82use bebytes::BeBytes;
83use bytes::{Buf, BufMut};
84
85/// MQTT acknowledgment packet variable header using bebytes
86/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
87#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
88pub struct AckPacketHeader {
89    /// Packet identifier
90    pub packet_id: u16,
91    /// Reason code (single byte)
92    pub reason_code: u8,
93}
94
95impl AckPacketHeader {
96    /// Creates a new acknowledgment packet header
97    #[must_use]
98    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
99        Self {
100            packet_id,
101            reason_code: u8::from(reason_code),
102        }
103    }
104
105    /// Gets the reason code as a `ReasonCode` enum
106    #[must_use]
107    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
108        crate::types::ReasonCode::from_u8(self.reason_code)
109    }
110}
111
112/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
113#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
114pub struct MqttTypeAndFlags {
115    /// Message type (bits 7-4)
116    #[bits(4)]
117    pub message_type: u8,
118    /// DUP flag (bit 3) - for PUBLISH packets
119    #[bits(1)]
120    pub dup: u8,
121    /// `QoS` level (bits 2-1) - for PUBLISH packets
122    #[bits(2)]
123    pub qos: u8,
124    /// RETAIN flag (bit 0) - for PUBLISH packets  
125    #[bits(1)]
126    pub retain: u8,
127}
128
129impl MqttTypeAndFlags {
130    /// Creates a new `MqttTypeAndFlags` for a given packet type
131    #[must_use]
132    pub fn for_packet_type(packet_type: PacketType) -> Self {
133        Self {
134            message_type: packet_type as u8,
135            dup: 0,
136            qos: 0,
137            retain: 0,
138        }
139    }
140
141    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
142    #[must_use]
143    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
144        Self {
145            message_type: PacketType::Publish as u8,
146            dup: u8::from(dup),
147            qos,
148            retain: u8::from(retain),
149        }
150    }
151
152    /// Returns the packet type
153    #[must_use]
154    pub fn packet_type(&self) -> Option<PacketType> {
155        PacketType::from_u8(self.message_type)
156    }
157
158    /// Returns true if the DUP flag is set
159    #[must_use]
160    pub fn is_dup(&self) -> bool {
161        self.dup != 0
162    }
163
164    /// Returns true if the RETAIN flag is set
165    #[must_use]
166    pub fn is_retain(&self) -> bool {
167        self.retain != 0
168    }
169}
170
171#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
172pub enum PacketType {
173    Connect = 1,
174    ConnAck = 2,
175    Publish = 3,
176    PubAck = 4,
177    PubRec = 5,
178    PubRel = 6,
179    PubComp = 7,
180    Subscribe = 8,
181    SubAck = 9,
182    Unsubscribe = 10,
183    UnsubAck = 11,
184    PingReq = 12,
185    PingResp = 13,
186    Disconnect = 14,
187    Auth = 15,
188}
189
190impl PacketType {
191    /// Converts a u8 to `PacketType`
192    #[must_use]
193    pub fn from_u8(value: u8) -> Option<Self> {
194        // Use the TryFrom implementation generated by BeBytes
195        Self::try_from(value).ok()
196    }
197}
198
199impl From<PacketType> for u8 {
200    fn from(packet_type: PacketType) -> Self {
201        packet_type as u8
202    }
203}
204
205/// MQTT packet fixed header
206#[derive(Debug, Clone, Copy, PartialEq, Eq)]
207pub struct FixedHeader {
208    pub packet_type: PacketType,
209    pub flags: u8,
210    pub remaining_length: u32,
211}
212
213impl FixedHeader {
214    /// Creates a new fixed header
215    #[must_use]
216    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
217        Self {
218            packet_type,
219            flags,
220            remaining_length,
221        }
222    }
223
224    /// Encodes the fixed header
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if the remaining length is too large
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the operation fails
233    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
234        let byte1 =
235            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
236        buf.put_u8(byte1);
237        encode_variable_int(buf, self.remaining_length)?;
238        Ok(())
239    }
240
241    /// Decodes a fixed header from the buffer
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if:
246    /// - Insufficient bytes in buffer
247    /// - Invalid packet type
248    /// - Invalid remaining length
249    ///
250    /// # Errors
251    ///
252    /// Returns an error if the operation fails
253    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
254        if !buf.has_remaining() {
255            return Err(MqttError::MalformedPacket(
256                "No data for fixed header".to_string(),
257            ));
258        }
259
260        let byte1 = buf.get_u8();
261        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
262        let flags = byte1 & crate::constants::masks::FLAGS;
263
264        let packet_type = PacketType::from_u8(packet_type_val)
265            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
266
267        let remaining_length = decode_variable_int(buf)?;
268
269        Ok(Self {
270            packet_type,
271            flags,
272            remaining_length,
273        })
274    }
275
276    /// Validates the flags for the packet type
277    #[must_use]
278    pub fn validate_flags(&self) -> bool {
279        match self.packet_type {
280            PacketType::Publish => true, // Publish has variable flags
281            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
282                self.flags == 0x02 // Required flags for these packet types
283            }
284            _ => self.flags == 0,
285        }
286    }
287
288    /// Returns the encoded length of the fixed header
289    #[must_use]
290    pub fn encoded_len(&self) -> usize {
291        // 1 byte for packet type + flags, plus variable length encoding of remaining length
292        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
293    }
294}
295
296/// Enum representing all MQTT packet types
297#[derive(Debug, Clone)]
298pub enum Packet {
299    Connect(Box<connect::ConnectPacket>),
300    ConnAck(connack::ConnAckPacket),
301    Publish(publish::PublishPacket),
302    PubAck(puback::PubAckPacket),
303    PubRec(pubrec::PubRecPacket),
304    PubRel(pubrel::PubRelPacket),
305    PubComp(pubcomp::PubCompPacket),
306    Subscribe(subscribe::SubscribePacket),
307    SubAck(suback::SubAckPacket),
308    Unsubscribe(unsubscribe::UnsubscribePacket),
309    UnsubAck(unsuback::UnsubAckPacket),
310    PingReq,
311    PingResp,
312    Disconnect(disconnect::DisconnectPacket),
313    Auth(auth::AuthPacket),
314}
315
316impl Packet {
317    /// Decode a packet body based on the packet type
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if decoding fails
322    pub fn decode_from_body<B: Buf>(
323        packet_type: PacketType,
324        fixed_header: &FixedHeader,
325        buf: &mut B,
326    ) -> Result<Self> {
327        match packet_type {
328            PacketType::Connect => {
329                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
330                Ok(Packet::Connect(Box::new(packet)))
331            }
332            PacketType::ConnAck => {
333                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
334                Ok(Packet::ConnAck(packet))
335            }
336            PacketType::Publish => {
337                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
338                Ok(Packet::Publish(packet))
339            }
340            PacketType::PubAck => {
341                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
342                Ok(Packet::PubAck(packet))
343            }
344            PacketType::PubRec => {
345                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
346                Ok(Packet::PubRec(packet))
347            }
348            PacketType::PubRel => {
349                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
350                Ok(Packet::PubRel(packet))
351            }
352            PacketType::PubComp => {
353                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
354                Ok(Packet::PubComp(packet))
355            }
356            PacketType::Subscribe => {
357                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
358                Ok(Packet::Subscribe(packet))
359            }
360            PacketType::SubAck => {
361                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
362                Ok(Packet::SubAck(packet))
363            }
364            PacketType::Unsubscribe => {
365                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
366                Ok(Packet::Unsubscribe(packet))
367            }
368            PacketType::UnsubAck => {
369                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
370                Ok(Packet::UnsubAck(packet))
371            }
372            PacketType::PingReq => Ok(Packet::PingReq),
373            PacketType::PingResp => Ok(Packet::PingResp),
374            PacketType::Disconnect => {
375                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
376                Ok(Packet::Disconnect(packet))
377            }
378            PacketType::Auth => {
379                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
380                Ok(Packet::Auth(packet))
381            }
382        }
383    }
384
385    /// Decode a packet body based on the packet type with protocol version
386    ///
387    /// # Errors
388    ///
389    /// Returns an error if decoding fails
390    pub fn decode_from_body_with_version<B: Buf>(
391        packet_type: PacketType,
392        fixed_header: &FixedHeader,
393        buf: &mut B,
394        protocol_version: u8,
395    ) -> Result<Self> {
396        match packet_type {
397            PacketType::Publish => {
398                let packet = publish::PublishPacket::decode_body_with_version(
399                    buf,
400                    fixed_header,
401                    protocol_version,
402                )?;
403                Ok(Packet::Publish(packet))
404            }
405            PacketType::Subscribe => {
406                let packet = subscribe::SubscribePacket::decode_body_with_version(
407                    buf,
408                    fixed_header,
409                    protocol_version,
410                )?;
411                Ok(Packet::Subscribe(packet))
412            }
413            PacketType::SubAck => {
414                let packet = suback::SubAckPacket::decode_body_with_version(
415                    buf,
416                    fixed_header,
417                    protocol_version,
418                )?;
419                Ok(Packet::SubAck(packet))
420            }
421            PacketType::Unsubscribe => {
422                let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
423                    buf,
424                    fixed_header,
425                    protocol_version,
426                )?;
427                Ok(Packet::Unsubscribe(packet))
428            }
429            _ => Self::decode_from_body(packet_type, fixed_header, buf),
430        }
431    }
432}
433
434/// Trait for MQTT packets
435pub trait MqttPacket: Sized {
436    /// Returns the packet type
437    fn packet_type(&self) -> PacketType;
438
439    /// Returns the fixed header flags
440    fn flags(&self) -> u8 {
441        0
442    }
443
444    /// Encodes the packet body (without fixed header)
445    ///
446    /// # Errors
447    ///
448    /// Returns an error if encoding fails
449    ///
450    /// # Errors
451    ///
452    /// Returns an error if the operation fails
453    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
454
455    /// Decodes the packet body (without fixed header)
456    ///
457    /// # Errors
458    ///
459    /// Returns an error if decoding fails
460    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
461
462    /// Encodes the complete packet (with fixed header)
463    ///
464    /// # Errors
465    ///
466    /// Returns an error if encoding fails
467    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
468        // First encode to temporary buffer to get remaining length
469        let mut body = Vec::new();
470        self.encode_body(&mut body)?;
471
472        let fixed_header = FixedHeader::new(
473            self.packet_type(),
474            self.flags(),
475            body.len().try_into().unwrap_or(u32::MAX),
476        );
477
478        fixed_header.encode(buf)?;
479        buf.put_slice(&body);
480        Ok(())
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it back, panicking on failure.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut wire = BytesMut::new();
        header.encode(&mut wire).unwrap();
        FixedHeader::decode(&mut wire).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];
        for (raw, expected) in cases {
            assert_eq!(PacketType::from_u8(raw), expected, "raw value {raw}");
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let decoded = round_trip(FixedHeader::new(PacketType::Connect, 0, 100));

        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let decoded = round_trip(FixedHeader::new(PacketType::Publish, 0x0D, 50));

        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0, true),
            (PacketType::Connect, 1, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for (packet_type, flags, expected) in cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(
                header.validate_flags(),
                expected,
                "{packet_type:?} with flags {flags:#04x}"
            );
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // Invalid packet type 0, followed by a zero remaining length.
        let mut wire = BytesMut::from(&[0x00u8, 0x00][..]);
        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        // Each packet type serializes to its single discriminant byte and back.
        for (packet_type, wire_byte) in [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)] {
            let encoded = packet_type.to_be_bytes();
            assert_eq!(encoded, vec![wire_byte]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&encoded).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}