mqtt5_protocol/
packet.rs

1pub mod auth;
2pub mod connack;
3pub mod connect;
4pub mod disconnect;
5pub mod pingreq;
6pub mod pingresp;
7pub mod puback;
8pub mod pubcomp;
9pub mod publish;
10pub mod pubrec;
11pub mod pubrel;
12pub mod suback;
13pub mod subscribe;
14pub mod unsuback;
15pub mod unsubscribe;
16
17#[cfg(test)]
18mod property_tests;
19
// Property-based tests for the `BeBytes`-derived (de)serialization of
// `MqttTypeAndFlags`, the single type/flags byte of the fixed header.
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Builds the struct directly from in-range field values (4-bit type
        // 1..=15, 1-bit dup, 2-bit qos, 1-bit retain) and checks that
        // `to_be_bytes` followed by `try_from_be_bytes` yields an equal value.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let original = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            let bytes = original.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(original, decoded);
        }

        // For every raw value that maps to a `PacketType`, the byte produced by
        // `for_packet_type` must decode back to the same struct and the same
        // packet type. Values without a variant are skipped by the `if let`.
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            if let Some(pt) = PacketType::from_u8(packet_type) {
                let type_and_flags = MqttTypeAndFlags::for_packet_type(pt);
                let bytes = type_and_flags.to_be_bytes();
                let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

                prop_assert_eq!(type_and_flags, decoded);
                prop_assert_eq!(decoded.packet_type(), Some(pt));
            }
        }

        // PUBLISH headers built with `for_publish` must round-trip and expose the
        // same qos / dup / retain values through the accessors. Note that qos is
        // drawn from the full 2-bit range 0..=3; this checks the bit packing only.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let type_and_flags = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
76
77use crate::encoding::{decode_variable_int, encode_variable_int};
78use crate::error::{MqttError, Result};
79use bebytes::BeBytes;
80use bytes::{Buf, BufMut};
81
/// MQTT acknowledgment packet variable header using bebytes
/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
///
/// Serialization and parsing come from the derived `BeBytes` implementation.
/// The reason code is kept as its raw byte; use
/// [`AckPacketHeader::get_reason_code`] to map it to a `ReasonCode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct AckPacketHeader {
    /// Packet identifier
    pub packet_id: u16,
    /// Reason code (single byte), stored unvalidated
    pub reason_code: u8,
}
91
92impl AckPacketHeader {
93    /// Creates a new acknowledgment packet header
94    #[must_use]
95    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
96        Self {
97            packet_id,
98            reason_code: u8::from(reason_code),
99        }
100    }
101
102    /// Gets the reason code as a `ReasonCode` enum
103    #[must_use]
104    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
105        crate::types::ReasonCode::from_u8(self.reason_code)
106    }
107}
108
/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
///
/// The four bit fields below add up to 8 bits (4 + 1 + 2 + 1), so the whole
/// struct occupies a single byte on the wire. Each field is stored in a `u8`
/// but only the declared number of bits is serialized; the struct itself does
/// not range-check values assigned to its public fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct MqttTypeAndFlags {
    /// Message type (bits 7-4)
    #[bits(4)]
    pub message_type: u8,
    /// DUP flag (bit 3) - for PUBLISH packets
    #[bits(1)]
    pub dup: u8,
    /// `QoS` level (bits 2-1) - for PUBLISH packets
    #[bits(2)]
    pub qos: u8,
    /// RETAIN flag (bit 0) - for PUBLISH packets
    #[bits(1)]
    pub retain: u8,
}
125
126impl MqttTypeAndFlags {
127    /// Creates a new `MqttTypeAndFlags` for a given packet type
128    #[must_use]
129    pub fn for_packet_type(packet_type: PacketType) -> Self {
130        Self {
131            message_type: packet_type as u8,
132            dup: 0,
133            qos: 0,
134            retain: 0,
135        }
136    }
137
138    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
139    #[must_use]
140    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
141        Self {
142            message_type: PacketType::Publish as u8,
143            dup: u8::from(dup),
144            qos,
145            retain: u8::from(retain),
146        }
147    }
148
149    /// Returns the packet type
150    #[must_use]
151    pub fn packet_type(&self) -> Option<PacketType> {
152        PacketType::from_u8(self.message_type)
153    }
154
155    /// Returns true if the DUP flag is set
156    #[must_use]
157    pub fn is_dup(&self) -> bool {
158        self.dup != 0
159    }
160
161    /// Returns true if the RETAIN flag is set
162    #[must_use]
163    pub fn is_retain(&self) -> bool {
164        self.retain != 0
165    }
166}
167
/// MQTT control packet type.
///
/// Each discriminant is the numeric type value that `FixedHeader::encode`
/// shifts into the upper four bits of the first header byte. There is no
/// variant for `0` or for values above `15`, so `PacketType::from_u8` yields
/// `None` for those.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub enum PacketType {
    /// `CONNECT` packet (type value 1)
    Connect = 1,
    /// `CONNACK` packet (type value 2)
    ConnAck = 2,
    /// `PUBLISH` packet (type value 3)
    Publish = 3,
    /// `PUBACK` packet (type value 4)
    PubAck = 4,
    /// `PUBREC` packet (type value 5)
    PubRec = 5,
    /// `PUBREL` packet (type value 6)
    PubRel = 6,
    /// `PUBCOMP` packet (type value 7)
    PubComp = 7,
    /// `SUBSCRIBE` packet (type value 8)
    Subscribe = 8,
    /// `SUBACK` packet (type value 9)
    SubAck = 9,
    /// `UNSUBSCRIBE` packet (type value 10)
    Unsubscribe = 10,
    /// `UNSUBACK` packet (type value 11)
    UnsubAck = 11,
    /// `PINGREQ` packet (type value 12)
    PingReq = 12,
    /// `PINGRESP` packet (type value 13)
    PingResp = 13,
    /// `DISCONNECT` packet (type value 14)
    Disconnect = 14,
    /// `AUTH` packet (type value 15)
    Auth = 15,
}
186
187impl PacketType {
188    /// Converts a u8 to `PacketType`
189    #[must_use]
190    pub fn from_u8(value: u8) -> Option<Self> {
191        // Use the TryFrom implementation generated by BeBytes
192        Self::try_from(value).ok()
193    }
194}
195
196impl From<PacketType> for u8 {
197    fn from(packet_type: PacketType) -> Self {
198        packet_type as u8
199    }
200}
201
/// MQTT packet fixed header
///
/// Holds the decoded form of the header that starts every packet: the first
/// byte split into type and flags, followed by the remaining length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedHeader {
    /// Packet type taken from the upper four bits of the first byte
    pub packet_type: PacketType,
    /// Flag bits from the first byte; masked with `constants::masks::FLAGS`
    /// on both encode and decode
    pub flags: u8,
    /// Remaining length, written and read as a variable-length integer
    pub remaining_length: u32,
}
209
210impl FixedHeader {
211    /// Creates a new fixed header
212    #[must_use]
213    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
214        Self {
215            packet_type,
216            flags,
217            remaining_length,
218        }
219    }
220
221    /// Encodes the fixed header
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if the remaining length is too large
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the operation fails
230    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
231        let byte1 =
232            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
233        buf.put_u8(byte1);
234        encode_variable_int(buf, self.remaining_length)?;
235        Ok(())
236    }
237
238    /// Decodes a fixed header from the buffer
239    ///
240    /// # Errors
241    ///
242    /// Returns an error if:
243    /// - Insufficient bytes in buffer
244    /// - Invalid packet type
245    /// - Invalid remaining length
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if the operation fails
250    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
251        if !buf.has_remaining() {
252            return Err(MqttError::MalformedPacket(
253                "No data for fixed header".to_string(),
254            ));
255        }
256
257        let byte1 = buf.get_u8();
258        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
259        let flags = byte1 & crate::constants::masks::FLAGS;
260
261        let packet_type = PacketType::from_u8(packet_type_val)
262            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
263
264        let remaining_length = decode_variable_int(buf)?;
265
266        Ok(Self {
267            packet_type,
268            flags,
269            remaining_length,
270        })
271    }
272
273    /// Validates the flags for the packet type
274    #[must_use]
275    pub fn validate_flags(&self) -> bool {
276        match self.packet_type {
277            PacketType::Publish => true, // Publish has variable flags
278            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
279                self.flags == 0x02 // Required flags for these packet types
280            }
281            _ => self.flags == 0,
282        }
283    }
284
285    /// Returns the encoded length of the fixed header
286    #[must_use]
287    pub fn encoded_len(&self) -> usize {
288        // 1 byte for packet type + flags, plus variable length encoding of remaining length
289        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
290    }
291}
292
/// Enum representing all MQTT packet types
///
/// One variant per [`PacketType`]. Each variant wraps the decoded packet from
/// the matching submodule, except `PingReq` and `PingResp`, which carry no
/// data.
#[derive(Debug, Clone)]
pub enum Packet {
    /// Decoded `CONNECT` packet. Boxed, unlike the other variants
    /// (presumably to keep `Packet` itself small; confirm against
    /// `ConnectPacket`'s size).
    Connect(Box<connect::ConnectPacket>),
    /// Decoded `CONNACK` packet
    ConnAck(connack::ConnAckPacket),
    /// Decoded `PUBLISH` packet
    Publish(publish::PublishPacket),
    /// Decoded `PUBACK` packet
    PubAck(puback::PubAckPacket),
    /// Decoded `PUBREC` packet
    PubRec(pubrec::PubRecPacket),
    /// Decoded `PUBREL` packet
    PubRel(pubrel::PubRelPacket),
    /// Decoded `PUBCOMP` packet
    PubComp(pubcomp::PubCompPacket),
    /// Decoded `SUBSCRIBE` packet
    Subscribe(subscribe::SubscribePacket),
    /// Decoded `SUBACK` packet
    SubAck(suback::SubAckPacket),
    /// Decoded `UNSUBSCRIBE` packet
    Unsubscribe(unsubscribe::UnsubscribePacket),
    /// Decoded `UNSUBACK` packet
    UnsubAck(unsuback::UnsubAckPacket),
    /// `PINGREQ` packet (no body)
    PingReq,
    /// `PINGRESP` packet (no body)
    PingResp,
    /// Decoded `DISCONNECT` packet
    Disconnect(disconnect::DisconnectPacket),
    /// Decoded `AUTH` packet
    Auth(auth::AuthPacket),
}
312
313impl Packet {
314    /// Decode a packet body based on the packet type
315    ///
316    /// # Errors
317    ///
318    /// Returns an error if decoding fails
319    pub fn decode_from_body<B: Buf>(
320        packet_type: PacketType,
321        fixed_header: &FixedHeader,
322        buf: &mut B,
323    ) -> Result<Self> {
324        match packet_type {
325            PacketType::Connect => {
326                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
327                Ok(Packet::Connect(Box::new(packet)))
328            }
329            PacketType::ConnAck => {
330                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
331                Ok(Packet::ConnAck(packet))
332            }
333            PacketType::Publish => {
334                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
335                Ok(Packet::Publish(packet))
336            }
337            PacketType::PubAck => {
338                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
339                Ok(Packet::PubAck(packet))
340            }
341            PacketType::PubRec => {
342                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
343                Ok(Packet::PubRec(packet))
344            }
345            PacketType::PubRel => {
346                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
347                Ok(Packet::PubRel(packet))
348            }
349            PacketType::PubComp => {
350                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
351                Ok(Packet::PubComp(packet))
352            }
353            PacketType::Subscribe => {
354                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
355                Ok(Packet::Subscribe(packet))
356            }
357            PacketType::SubAck => {
358                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
359                Ok(Packet::SubAck(packet))
360            }
361            PacketType::Unsubscribe => {
362                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
363                Ok(Packet::Unsubscribe(packet))
364            }
365            PacketType::UnsubAck => {
366                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
367                Ok(Packet::UnsubAck(packet))
368            }
369            PacketType::PingReq => Ok(Packet::PingReq),
370            PacketType::PingResp => Ok(Packet::PingResp),
371            PacketType::Disconnect => {
372                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
373                Ok(Packet::Disconnect(packet))
374            }
375            PacketType::Auth => {
376                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
377                Ok(Packet::Auth(packet))
378            }
379        }
380    }
381}
382
383/// Trait for MQTT packets
384pub trait MqttPacket: Sized {
385    /// Returns the packet type
386    fn packet_type(&self) -> PacketType;
387
388    /// Returns the fixed header flags
389    fn flags(&self) -> u8 {
390        0
391    }
392
393    /// Encodes the packet body (without fixed header)
394    ///
395    /// # Errors
396    ///
397    /// Returns an error if encoding fails
398    ///
399    /// # Errors
400    ///
401    /// Returns an error if the operation fails
402    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
403
404    /// Decodes the packet body (without fixed header)
405    ///
406    /// # Errors
407    ///
408    /// Returns an error if decoding fails
409    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
410
411    /// Encodes the complete packet (with fixed header)
412    ///
413    /// # Errors
414    ///
415    /// Returns an error if encoding fails
416    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
417        // First encode to temporary buffer to get remaining length
418        let mut body = Vec::new();
419        self.encode_body(&mut body)?;
420
421        let fixed_header = FixedHeader::new(
422            self.packet_type(),
423            self.flags(),
424            body.len().try_into().unwrap_or(u32::MAX),
425        );
426
427        fixed_header.encode(buf)?;
428        buf.put_slice(&body);
429        Ok(())
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Known discriminants map to their variant; 0 and 16 map to `None`.
    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1u8, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];

        for (raw, expected) in cases {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    /// A header without flags survives an encode/decode round trip.
    #[test]
    fn test_fixed_header_encode_decode() {
        let mut wire = BytesMut::new();
        FixedHeader::new(PacketType::Connect, 0, 100)
            .encode(&mut wire)
            .unwrap();

        let parsed = FixedHeader::decode(&mut wire).unwrap();
        assert_eq!(parsed.packet_type, PacketType::Connect);
        assert_eq!(parsed.flags, 0);
        assert_eq!(parsed.remaining_length, 100);
    }

    /// Flag bits of a PUBLISH header survive an encode/decode round trip.
    #[test]
    fn test_fixed_header_with_flags() {
        let mut wire = BytesMut::new();
        FixedHeader::new(PacketType::Publish, 0x0D, 50)
            .encode(&mut wire)
            .unwrap();

        let parsed = FixedHeader::decode(&mut wire).unwrap();
        assert_eq!(parsed.packet_type, PacketType::Publish);
        assert_eq!(parsed.flags, 0x0D);
        assert_eq!(parsed.remaining_length, 50);
    }

    /// `validate_flags` accepts exactly the flag values allowed per packet type.
    #[test]
    fn test_validate_flags() {
        let cases = [
            (PacketType::Connect, 0x00u8, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];

        for (packet_type, flags, expected) in cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(header.validate_flags(), expected);
        }
    }

    /// Decoding from an empty buffer is an error.
    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    /// A first byte whose type nibble is 0 is rejected.
    #[test]
    fn test_decode_invalid_packet_type() {
        let mut wire = BytesMut::new();
        // Byte 0: invalid packet type 0; byte 1: remaining length.
        wire.put_slice(&[0x00, 0x00]);

        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    /// `BeBytes` serializes a `PacketType` to its single discriminant byte and back.
    #[test]
    fn test_packet_type_bebytes_serialization() {
        let cases = [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)];

        for (packet_type, wire_value) in cases {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire_value]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}