mqtt5/
packet.rs

1pub mod auth;
2pub mod connack;
3pub mod connect;
4pub mod disconnect;
5pub mod pingreq;
6pub mod pingresp;
7pub mod puback;
8pub mod pubcomp;
9pub mod publish;
10pub mod pubrec;
11pub mod pubrel;
12pub mod suback;
13pub mod subscribe;
14pub mod unsuback;
15pub mod unsubscribe;
16
17#[cfg(test)]
18mod property_tests;
19
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Every combination of the four bit fields must survive a
        // `to_be_bytes` / `try_from_be_bytes` round trip unchanged.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let original = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            let bytes = original.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(original, decoded);
        }

        // `for_packet_type` must produce a byte that decodes back to the same
        // packet type.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            // Every value in 1..=15 is a defined packet type, so `None` here is a
            // regression that must fail the test instead of being silently skipped.
            let pt = PacketType::from_u8(packet_type)
                .expect("every value in 1..=15 is a valid MQTT packet type");
            let type_and_flags = MqttTypeAndFlags::for_packet_type(pt);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(pt));
        }

        // PUBLISH flags built by `for_publish` must round-trip and be readable
        // back through the accessor helpers.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let type_and_flags = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
76
77use crate::encoding::{decode_variable_int, encode_variable_int};
78use crate::error::{MqttError, Result};
79use bebytes::BeBytes;
80use bytes::{Buf, BufMut};
81
/// MQTT acknowledgment packet variable header using bebytes
/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
///
/// Holds a 16-bit packet identifier followed by a one-byte reason code.
/// The field order below is significant for the `BeBytes` derive (presumably
/// serialized in declaration order), so do not reorder the fields.
///
/// NOTE(review): MQTT 5 allows these packets to omit the reason code
/// (Remaining Length 2, implying Success); this struct always carries one —
/// confirm callers handle the short form separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct AckPacketHeader {
    /// Packet identifier (big-endian u16)
    #[bebytes(big_endian)]
    pub packet_id: u16,
    /// Reason code (single byte); raw value, see `AckPacketHeader::get_reason_code`
    pub reason_code: u8,
}
92
93impl AckPacketHeader {
94    /// Creates a new acknowledgment packet header
95    #[must_use]
96    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
97        Self {
98            packet_id,
99            reason_code: u8::from(reason_code),
100        }
101    }
102
103    /// Gets the reason code as a `ReasonCode` enum
104    #[must_use]
105    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
106        crate::types::ReasonCode::from_u8(self.reason_code)
107    }
108}
109
/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
///
/// The four fields pack into a single byte, declared from the most-significant
/// bits (message type, bits 7-4) down to the least-significant bit (RETAIN,
/// bit 0). Declaration order defines the bit layout, so do not reorder fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct MqttTypeAndFlags {
    /// Message type (bits 7-4); see `PacketType` for the defined values
    #[bits(4)]
    pub message_type: u8,
    /// DUP flag (bit 3) - for PUBLISH packets
    #[bits(1)]
    pub dup: u8,
    /// `QoS` level (bits 2-1) - for PUBLISH packets
    /// (the field can hold 3, which MQTT does not permit as a `QoS` value)
    #[bits(2)]
    pub qos: u8,
    /// RETAIN flag (bit 0) - for PUBLISH packets  
    #[bits(1)]
    pub retain: u8,
}
126
127impl MqttTypeAndFlags {
128    /// Creates a new `MqttTypeAndFlags` for a given packet type
129    #[must_use]
130    pub fn for_packet_type(packet_type: PacketType) -> Self {
131        Self {
132            message_type: packet_type as u8,
133            dup: 0,
134            qos: 0,
135            retain: 0,
136        }
137    }
138
139    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
140    #[must_use]
141    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
142        Self {
143            message_type: PacketType::Publish as u8,
144            dup: u8::from(dup),
145            qos,
146            retain: u8::from(retain),
147        }
148    }
149
150    /// Returns the packet type
151    #[must_use]
152    pub fn packet_type(&self) -> Option<PacketType> {
153        PacketType::from_u8(self.message_type)
154    }
155
156    /// Returns true if the DUP flag is set
157    #[must_use]
158    pub fn is_dup(&self) -> bool {
159        self.dup != 0
160    }
161
162    /// Returns true if the RETAIN flag is set
163    #[must_use]
164    pub fn is_retain(&self) -> bool {
165        self.retain != 0
166    }
167}
168
/// MQTT control packet type.
///
/// Each discriminant is the 4-bit value stored in the high nibble of the first
/// fixed-header byte (see `FixedHeader::encode` / `FixedHeader::decode`).
/// The value 0 does not map to any variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub enum PacketType {
    /// CONNECT: client request to connect to a server.
    Connect = 1,
    /// CONNACK: connect acknowledgment.
    ConnAck = 2,
    /// PUBLISH: application message.
    Publish = 3,
    /// PUBACK: publish acknowledgment (`QoS` 1).
    PubAck = 4,
    /// PUBREC: publish received (`QoS` 2, part 1).
    PubRec = 5,
    /// PUBREL: publish release (`QoS` 2, part 2).
    PubRel = 6,
    /// PUBCOMP: publish complete (`QoS` 2, part 3).
    PubComp = 7,
    /// SUBSCRIBE: subscription request.
    Subscribe = 8,
    /// SUBACK: subscribe acknowledgment.
    SubAck = 9,
    /// UNSUBSCRIBE: unsubscribe request.
    Unsubscribe = 10,
    /// UNSUBACK: unsubscribe acknowledgment.
    UnsubAck = 11,
    /// PINGREQ: keep-alive ping request.
    PingReq = 12,
    /// PINGRESP: keep-alive ping response.
    PingResp = 13,
    /// DISCONNECT: disconnect notification.
    Disconnect = 14,
    /// AUTH: authentication exchange (MQTT 5).
    Auth = 15,
}
187
188impl PacketType {
189    /// Converts a u8 to `PacketType`
190    #[must_use]
191    pub fn from_u8(value: u8) -> Option<Self> {
192        // Use the TryFrom implementation generated by BeBytes
193        Self::try_from(value).ok()
194    }
195}
196
197impl From<PacketType> for u8 {
198    fn from(packet_type: PacketType) -> Self {
199        packet_type as u8
200    }
201}
202
/// MQTT packet fixed header
///
/// Serialized by `FixedHeader::encode` and parsed by `FixedHeader::decode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedHeader {
    /// Control packet type (high nibble of the first byte).
    pub packet_type: PacketType,
    /// Packet-type-specific flags (low nibble of the first byte); checked by
    /// `FixedHeader::validate_flags`.
    pub flags: u8,
    /// Number of bytes that follow the fixed header (variable header plus
    /// payload), written on the wire as a variable-length integer.
    pub remaining_length: u32,
}
210
211impl FixedHeader {
212    /// Creates a new fixed header
213    #[must_use]
214    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
215        Self {
216            packet_type,
217            flags,
218            remaining_length,
219        }
220    }
221
222    /// Encodes the fixed header
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if the remaining length is too large
227    ///
228    /// # Errors
229    ///
230    /// Returns an error if the operation fails
231    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
232        let byte1 =
233            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
234        buf.put_u8(byte1);
235        encode_variable_int(buf, self.remaining_length)?;
236        Ok(())
237    }
238
239    /// Decodes a fixed header from the buffer
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if:
244    /// - Insufficient bytes in buffer
245    /// - Invalid packet type
246    /// - Invalid remaining length
247    ///
248    /// # Errors
249    ///
250    /// Returns an error if the operation fails
251    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
252        if !buf.has_remaining() {
253            return Err(MqttError::MalformedPacket(
254                "No data for fixed header".to_string(),
255            ));
256        }
257
258        let byte1 = buf.get_u8();
259        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
260        let flags = byte1 & crate::constants::masks::FLAGS;
261
262        let packet_type = PacketType::from_u8(packet_type_val)
263            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
264
265        let remaining_length = decode_variable_int(buf)?;
266
267        Ok(Self {
268            packet_type,
269            flags,
270            remaining_length,
271        })
272    }
273
274    /// Validates the flags for the packet type
275    #[must_use]
276    pub fn validate_flags(&self) -> bool {
277        match self.packet_type {
278            PacketType::Publish => true, // Publish has variable flags
279            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
280                self.flags == 0x02 // Required flags for these packet types
281            }
282            _ => self.flags == 0,
283        }
284    }
285
286    /// Returns the encoded length of the fixed header
287    #[must_use]
288    pub fn encoded_len(&self) -> usize {
289        // 1 byte for packet type + flags, plus variable length encoding of remaining length
290        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
291    }
292}
293
/// Enum representing all MQTT packet types
///
/// One variant per [`PacketType`]. Data-carrying variants wrap the decoded
/// packet struct from the matching submodule; `PingReq` and `PingResp` carry no
/// data (`Packet::decode_from_body` builds them without reading the body).
#[derive(Debug, Clone)]
pub enum Packet {
    // NOTE(review): CONNECT is boxed, presumably to keep `Packet` small because
    // `ConnectPacket` is larger than the other variants — confirm.
    Connect(Box<connect::ConnectPacket>),
    ConnAck(connack::ConnAckPacket),
    Publish(publish::PublishPacket),
    PubAck(puback::PubAckPacket),
    PubRec(pubrec::PubRecPacket),
    PubRel(pubrel::PubRelPacket),
    PubComp(pubcomp::PubCompPacket),
    Subscribe(subscribe::SubscribePacket),
    SubAck(suback::SubAckPacket),
    Unsubscribe(unsubscribe::UnsubscribePacket),
    UnsubAck(unsuback::UnsubAckPacket),
    PingReq,
    PingResp,
    Disconnect(disconnect::DisconnectPacket),
    Auth(auth::AuthPacket),
}
313
314impl Packet {
315    /// Decode a packet body based on the packet type
316    ///
317    /// # Errors
318    ///
319    /// Returns an error if decoding fails
320    pub fn decode_from_body<B: Buf>(
321        packet_type: PacketType,
322        fixed_header: &FixedHeader,
323        buf: &mut B,
324    ) -> Result<Self> {
325        match packet_type {
326            PacketType::Connect => {
327                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
328                Ok(Packet::Connect(Box::new(packet)))
329            }
330            PacketType::ConnAck => {
331                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
332                Ok(Packet::ConnAck(packet))
333            }
334            PacketType::Publish => {
335                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
336                Ok(Packet::Publish(packet))
337            }
338            PacketType::PubAck => {
339                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
340                Ok(Packet::PubAck(packet))
341            }
342            PacketType::PubRec => {
343                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
344                Ok(Packet::PubRec(packet))
345            }
346            PacketType::PubRel => {
347                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
348                Ok(Packet::PubRel(packet))
349            }
350            PacketType::PubComp => {
351                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
352                Ok(Packet::PubComp(packet))
353            }
354            PacketType::Subscribe => {
355                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
356                Ok(Packet::Subscribe(packet))
357            }
358            PacketType::SubAck => {
359                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
360                Ok(Packet::SubAck(packet))
361            }
362            PacketType::Unsubscribe => {
363                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
364                Ok(Packet::Unsubscribe(packet))
365            }
366            PacketType::UnsubAck => {
367                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
368                Ok(Packet::UnsubAck(packet))
369            }
370            PacketType::PingReq => Ok(Packet::PingReq),
371            PacketType::PingResp => Ok(Packet::PingResp),
372            PacketType::Disconnect => {
373                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
374                Ok(Packet::Disconnect(packet))
375            }
376            PacketType::Auth => {
377                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
378                Ok(Packet::Auth(packet))
379            }
380        }
381    }
382}
383
384/// Trait for MQTT packets
385pub trait MqttPacket: Sized {
386    /// Returns the packet type
387    fn packet_type(&self) -> PacketType;
388
389    /// Returns the fixed header flags
390    fn flags(&self) -> u8 {
391        0
392    }
393
394    /// Encodes the packet body (without fixed header)
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if encoding fails
399    ///
400    /// # Errors
401    ///
402    /// Returns an error if the operation fails
403    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
404
405    /// Decodes the packet body (without fixed header)
406    ///
407    /// # Errors
408    ///
409    /// Returns an error if decoding fails
410    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
411
412    /// Encodes the complete packet (with fixed header)
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if encoding fails
417    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
418        // First encode to temporary buffer to get remaining length
419        let mut body = Vec::new();
420        self.encode_body(&mut body)?;
421
422        let fixed_header = FixedHeader::new(
423            self.packet_type(),
424            self.flags(),
425            body.len().try_into().unwrap_or(u32::MAX),
426        );
427
428        fixed_header.encode(buf)?;
429        buf.put_slice(&body);
430        Ok(())
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(1), Some(PacketType::Connect));
        assert_eq!(PacketType::from_u8(2), Some(PacketType::ConnAck));
        assert_eq!(PacketType::from_u8(15), Some(PacketType::Auth));
        assert_eq!(PacketType::from_u8(0), None);
        assert_eq!(PacketType::from_u8(16), None);

        // Every defined value converts back to the same discriminant.
        for value in 1u8..=15 {
            let packet_type = PacketType::from_u8(value).unwrap();
            assert_eq!(u8::from(packet_type), value);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let mut buf = BytesMut::new();

        let header = FixedHeader::new(PacketType::Connect, 0, 100);
        header.encode(&mut buf).unwrap();
        // CONNECT (1) in the high nibble, no flags; 100 fits in one length byte.
        assert_eq!(&buf[..], &[0x10u8, 100][..]);
        assert_eq!(header.encoded_len(), buf.len());

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
        // The whole header must have been consumed.
        assert!(buf.is_empty());
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let mut buf = BytesMut::new();

        let header = FixedHeader::new(PacketType::Publish, 0x0D, 50);
        header.encode(&mut buf).unwrap();
        // PUBLISH (3) in the high nibble, flags 0x0D in the low nibble.
        assert_eq!(&buf[..], &[0x3Du8, 50][..]);

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_fixed_header_multi_byte_remaining_length() {
        let mut buf = BytesMut::new();

        // 321 = 2 * 128 + 65 -> continuation byte 0xC1, then 0x02.
        let header = FixedHeader::new(PacketType::PubAck, 0, 321);
        header.encode(&mut buf).unwrap();
        assert_eq!(&buf[..], &[0x40u8, 0xC1, 0x02][..]);
        assert_eq!(header.encoded_len(), 3);

        let decoded = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(decoded.packet_type, PacketType::PubAck);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 321);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_validate_flags() {
        let header = FixedHeader::new(PacketType::Connect, 0, 0);
        assert!(header.validate_flags());

        let header = FixedHeader::new(PacketType::Connect, 1, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::Subscribe, 0x02, 0);
        assert!(header.validate_flags());

        let header = FixedHeader::new(PacketType::Subscribe, 0x00, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::PubRel, 0x02, 0);
        assert!(header.validate_flags());

        let header = FixedHeader::new(PacketType::PubRel, 0x00, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::PingReq, 0x02, 0);
        assert!(!header.validate_flags());

        let header = FixedHeader::new(PacketType::Publish, 0x0F, 0);
        assert!(header.validate_flags());
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut buf = BytesMut::new();
        let result = FixedHeader::decode(&mut buf);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        let mut buf = BytesMut::new();
        buf.put_u8(0x00); // Invalid packet type 0
        buf.put_u8(0x00); // Remaining length

        let result = FixedHeader::decode(&mut buf);
        assert!(matches!(result, Err(MqttError::InvalidPacketType(0))));
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        // Test BeBytes to_be_bytes and try_from_be_bytes
        let packet_type = PacketType::Publish;
        let bytes = packet_type.to_be_bytes();
        assert_eq!(bytes, vec![3]);

        let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(decoded, PacketType::Publish);
        assert_eq!(consumed, 1);

        // Test other packet types
        let packet_type = PacketType::Connect;
        let bytes = packet_type.to_be_bytes();
        assert_eq!(bytes, vec![1]);

        let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(decoded, PacketType::Connect);
        assert_eq!(consumed, 1);
    }
}
529}