mqtt5_protocol/
packet.rs

1mod ack_common;
2pub mod auth;
3pub mod connack;
4pub mod connect;
5pub mod disconnect;
6pub mod pingreq;
7pub mod pingresp;
8pub mod puback;
9pub mod pubcomp;
10pub mod publish;
11pub mod pubrec;
12pub mod pubrel;
13pub mod suback;
14pub mod subscribe;
15pub mod unsuback;
16pub mod unsubscribe;
17
18pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
19
20#[cfg(test)]
21mod property_tests;
22
#[cfg(test)]
mod bebytes_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Arbitrary in-range values for each bit field must survive a
        // `to_be_bytes` / `try_from_be_bytes` round trip unchanged.
        #[test]
        fn prop_mqtt_type_and_flags_round_trip(
            message_type in 1u8..=15,
            dup in 0u8..=1,
            qos in 0u8..=3,
            retain in 0u8..=1
        ) {
            let original = MqttTypeAndFlags {
                message_type,
                dup,
                qos,
                retain,
            };

            let bytes = original.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(original, decoded);
        }

        // Every value in 1..=15 names a packet type, so a failed `from_u8`
        // lookup is reported as a test failure instead of silently skipping
        // the case (which the previous `if let` did).
        #[test]
        fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
            let pt = PacketType::from_u8(packet_type);
            prop_assert!(pt.is_some(), "PacketType::from_u8({}) returned None", packet_type);
            let pt = pt.unwrap();

            let type_and_flags = MqttTypeAndFlags::for_packet_type(pt);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(pt));
        }

        // PUBLISH headers built from (qos, dup, retain) must decode back to
        // the same flag values and the PUBLISH packet type.
        #[test]
        fn prop_publish_flags_round_trip(
            qos in 0u8..=3,
            dup: bool,
            retain: bool
        ) {
            let type_and_flags = MqttTypeAndFlags::for_publish(qos, dup, retain);
            let bytes = type_and_flags.to_be_bytes();
            let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();

            prop_assert_eq!(type_and_flags, decoded);
            prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
            prop_assert_eq!(decoded.qos, qos);
            prop_assert_eq!(decoded.is_dup(), dup);
            prop_assert_eq!(decoded.is_retain(), retain);
        }
    }
}
79
80use crate::encoding::{decode_variable_int, encode_variable_int};
81use crate::error::{MqttError, Result};
82use crate::prelude::{Box, ToString, Vec};
83use bebytes::BeBytes;
84use bytes::{Buf, BufMut};
85
86/// MQTT acknowledgment packet variable header using bebytes
87/// Used by `PubAck`, `PubRec`, `PubRel`, and `PubComp` packets
88#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
89pub struct AckPacketHeader {
90    /// Packet identifier
91    pub packet_id: u16,
92    /// Reason code (single byte)
93    pub reason_code: u8,
94}
95
96impl AckPacketHeader {
97    /// Creates a new acknowledgment packet header
98    #[must_use]
99    pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
100        Self {
101            packet_id,
102            reason_code: u8::from(reason_code),
103        }
104    }
105
106    /// Gets the reason code as a `ReasonCode` enum
107    #[must_use]
108    pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
109        crate::types::ReasonCode::from_u8(self.reason_code)
110    }
111}
112
113/// MQTT Fixed Header Type and Flags byte using bebytes for bit field operations
114#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
115pub struct MqttTypeAndFlags {
116    /// Message type (bits 7-4)
117    #[bits(4)]
118    pub message_type: u8,
119    /// DUP flag (bit 3) - for PUBLISH packets
120    #[bits(1)]
121    pub dup: u8,
122    /// `QoS` level (bits 2-1) - for PUBLISH packets
123    #[bits(2)]
124    pub qos: u8,
125    /// RETAIN flag (bit 0) - for PUBLISH packets  
126    #[bits(1)]
127    pub retain: u8,
128}
129
130impl MqttTypeAndFlags {
131    /// Creates a new `MqttTypeAndFlags` for a given packet type
132    #[must_use]
133    pub fn for_packet_type(packet_type: PacketType) -> Self {
134        Self {
135            message_type: packet_type as u8,
136            dup: 0,
137            qos: 0,
138            retain: 0,
139        }
140    }
141
142    /// Creates a new `MqttTypeAndFlags` for PUBLISH packets with `QoS` and flags
143    #[must_use]
144    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
145        Self {
146            message_type: PacketType::Publish as u8,
147            dup: u8::from(dup),
148            qos,
149            retain: u8::from(retain),
150        }
151    }
152
153    /// Returns the packet type
154    #[must_use]
155    pub fn packet_type(&self) -> Option<PacketType> {
156        PacketType::from_u8(self.message_type)
157    }
158
159    /// Returns true if the DUP flag is set
160    #[must_use]
161    pub fn is_dup(&self) -> bool {
162        self.dup != 0
163    }
164
165    /// Returns true if the RETAIN flag is set
166    #[must_use]
167    pub fn is_retain(&self) -> bool {
168        self.retain != 0
169    }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
173pub enum PacketType {
174    Connect = 1,
175    ConnAck = 2,
176    Publish = 3,
177    PubAck = 4,
178    PubRec = 5,
179    PubRel = 6,
180    PubComp = 7,
181    Subscribe = 8,
182    SubAck = 9,
183    Unsubscribe = 10,
184    UnsubAck = 11,
185    PingReq = 12,
186    PingResp = 13,
187    Disconnect = 14,
188    Auth = 15,
189}
190
191impl PacketType {
192    /// Converts a u8 to `PacketType`
193    #[must_use]
194    pub fn from_u8(value: u8) -> Option<Self> {
195        // Use the TryFrom implementation generated by BeBytes
196        Self::try_from(value).ok()
197    }
198}
199
200impl From<PacketType> for u8 {
201    fn from(packet_type: PacketType) -> Self {
202        packet_type as u8
203    }
204}
205
206/// MQTT packet fixed header
207#[derive(Debug, Clone, Copy, PartialEq, Eq)]
208pub struct FixedHeader {
209    pub packet_type: PacketType,
210    pub flags: u8,
211    pub remaining_length: u32,
212}
213
214impl FixedHeader {
215    /// Creates a new fixed header
216    #[must_use]
217    pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
218        Self {
219            packet_type,
220            flags,
221            remaining_length,
222        }
223    }
224
225    /// Encodes the fixed header
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the remaining length is too large to encode
230    pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
231        let byte1 =
232            (u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
233        buf.put_u8(byte1);
234        encode_variable_int(buf, self.remaining_length)?;
235        Ok(())
236    }
237
238    /// Decodes a fixed header from the buffer
239    ///
240    /// # Errors
241    ///
242    /// Returns an error if:
243    /// - Insufficient bytes in buffer
244    /// - Invalid packet type
245    /// - Invalid remaining length encoding
246    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
247        if !buf.has_remaining() {
248            return Err(MqttError::MalformedPacket(
249                "No data for fixed header".to_string(),
250            ));
251        }
252
253        let byte1 = buf.get_u8();
254        let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
255        let flags = byte1 & crate::constants::masks::FLAGS;
256
257        let packet_type = PacketType::from_u8(packet_type_val)
258            .ok_or(MqttError::InvalidPacketType(packet_type_val))?;
259
260        let remaining_length = decode_variable_int(buf)?;
261
262        Ok(Self {
263            packet_type,
264            flags,
265            remaining_length,
266        })
267    }
268
269    /// Validates the flags for the packet type
270    #[must_use]
271    pub fn validate_flags(&self) -> bool {
272        match self.packet_type {
273            PacketType::Publish => true, // Publish has variable flags
274            PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
275                self.flags == 0x02 // Required flags for these packet types
276            }
277            _ => self.flags == 0,
278        }
279    }
280
281    /// Returns the encoded length of the fixed header
282    #[must_use]
283    pub fn encoded_len(&self) -> usize {
284        // 1 byte for packet type + flags, plus variable length encoding of remaining length
285        1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
286    }
287}
288
/// Enum representing all MQTT packet types
///
/// One variant per control packet type, each holding that packet's decoded
/// representation. `PingReq` and `PingResp` carry no data.
#[derive(Debug, Clone)]
pub enum Packet {
    /// CONNECT packet; boxed, presumably to keep `Packet` small — TODO confirm
    Connect(Box<connect::ConnectPacket>),
    /// CONNACK packet
    ConnAck(connack::ConnAckPacket),
    /// PUBLISH packet
    Publish(publish::PublishPacket),
    /// PUBACK packet
    PubAck(puback::PubAckPacket),
    /// PUBREC packet
    PubRec(pubrec::PubRecPacket),
    /// PUBREL packet
    PubRel(pubrel::PubRelPacket),
    /// PUBCOMP packet
    PubComp(pubcomp::PubCompPacket),
    /// SUBSCRIBE packet
    Subscribe(subscribe::SubscribePacket),
    /// SUBACK packet
    SubAck(suback::SubAckPacket),
    /// UNSUBSCRIBE packet
    Unsubscribe(unsubscribe::UnsubscribePacket),
    /// UNSUBACK packet
    UnsubAck(unsuback::UnsubAckPacket),
    /// PINGREQ packet (no body)
    PingReq,
    /// PINGRESP packet (no body)
    PingResp,
    /// DISCONNECT packet
    Disconnect(disconnect::DisconnectPacket),
    /// AUTH packet
    Auth(auth::AuthPacket),
}
308
309impl Packet {
310    /// Decode a packet body based on the packet type
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if decoding fails
315    pub fn decode_from_body<B: Buf>(
316        packet_type: PacketType,
317        fixed_header: &FixedHeader,
318        buf: &mut B,
319    ) -> Result<Self> {
320        match packet_type {
321            PacketType::Connect => {
322                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
323                Ok(Packet::Connect(Box::new(packet)))
324            }
325            PacketType::ConnAck => {
326                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
327                Ok(Packet::ConnAck(packet))
328            }
329            PacketType::Publish => {
330                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
331                Ok(Packet::Publish(packet))
332            }
333            PacketType::PubAck => {
334                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
335                Ok(Packet::PubAck(packet))
336            }
337            PacketType::PubRec => {
338                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
339                Ok(Packet::PubRec(packet))
340            }
341            PacketType::PubRel => {
342                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
343                Ok(Packet::PubRel(packet))
344            }
345            PacketType::PubComp => {
346                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
347                Ok(Packet::PubComp(packet))
348            }
349            PacketType::Subscribe => {
350                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
351                Ok(Packet::Subscribe(packet))
352            }
353            PacketType::SubAck => {
354                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
355                Ok(Packet::SubAck(packet))
356            }
357            PacketType::Unsubscribe => {
358                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
359                Ok(Packet::Unsubscribe(packet))
360            }
361            PacketType::UnsubAck => {
362                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
363                Ok(Packet::UnsubAck(packet))
364            }
365            PacketType::PingReq => Ok(Packet::PingReq),
366            PacketType::PingResp => Ok(Packet::PingResp),
367            PacketType::Disconnect => {
368                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
369                Ok(Packet::Disconnect(packet))
370            }
371            PacketType::Auth => {
372                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
373                Ok(Packet::Auth(packet))
374            }
375        }
376    }
377
378    /// Decode a packet body based on the packet type with protocol version
379    ///
380    /// # Errors
381    ///
382    /// Returns an error if decoding fails
383    pub fn decode_from_body_with_version<B: Buf>(
384        packet_type: PacketType,
385        fixed_header: &FixedHeader,
386        buf: &mut B,
387        protocol_version: u8,
388    ) -> Result<Self> {
389        match packet_type {
390            PacketType::Publish => {
391                let packet = publish::PublishPacket::decode_body_with_version(
392                    buf,
393                    fixed_header,
394                    protocol_version,
395                )?;
396                Ok(Packet::Publish(packet))
397            }
398            PacketType::Subscribe => {
399                let packet = subscribe::SubscribePacket::decode_body_with_version(
400                    buf,
401                    fixed_header,
402                    protocol_version,
403                )?;
404                Ok(Packet::Subscribe(packet))
405            }
406            PacketType::SubAck => {
407                let packet = suback::SubAckPacket::decode_body_with_version(
408                    buf,
409                    fixed_header,
410                    protocol_version,
411                )?;
412                Ok(Packet::SubAck(packet))
413            }
414            PacketType::Unsubscribe => {
415                let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
416                    buf,
417                    fixed_header,
418                    protocol_version,
419                )?;
420                Ok(Packet::Unsubscribe(packet))
421            }
422            _ => Self::decode_from_body(packet_type, fixed_header, buf),
423        }
424    }
425}
426
427/// Trait for MQTT packets
428pub trait MqttPacket: Sized {
429    /// Returns the packet type
430    fn packet_type(&self) -> PacketType;
431
432    /// Returns the fixed header flags
433    fn flags(&self) -> u8 {
434        0
435    }
436
437    /// Encodes the packet body (without fixed header)
438    ///
439    /// # Errors
440    ///
441    /// Returns an error if encoding fails
442    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;
443
444    /// Decodes the packet body (without fixed header)
445    ///
446    /// # Errors
447    ///
448    /// Returns an error if decoding fails
449    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;
450
451    /// Encodes the complete packet (with fixed header)
452    ///
453    /// # Errors
454    ///
455    /// Returns an error if encoding fails
456    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
457        // First encode to temporary buffer to get remaining length
458        let mut body = Vec::new();
459        self.encode_body(&mut body)?;
460
461        let fixed_header = FixedHeader::new(
462            self.packet_type(),
463            self.flags(),
464            body.len().try_into().unwrap_or(u32::MAX),
465        );
466
467        fixed_header.encode(buf)?;
468        buf.put_slice(&body);
469        Ok(())
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and decodes it back.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut buf = BytesMut::new();
        header.encode(&mut buf).unwrap();
        FixedHeader::decode(&mut buf).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];
        for (raw, expected) in cases {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let decoded = round_trip(FixedHeader::new(PacketType::Connect, 0, 100));
        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let decoded = round_trip(FixedHeader::new(PacketType::Publish, 0x0D, 50));
        assert_eq!(decoded.packet_type, PacketType::Publish);
        assert_eq!(decoded.flags, 0x0D);
        assert_eq!(decoded.remaining_length, 50);
    }

    #[test]
    fn test_validate_flags() {
        // (packet type, flags, expected validity)
        let cases = [
            (PacketType::Connect, 0, true),
            (PacketType::Connect, 1, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for (packet_type, flags, expected) in cases {
            let header = FixedHeader::new(packet_type, flags, 0);
            assert_eq!(header.validate_flags(), expected);
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // Packet type nibble 0 is undefined; the second byte is a zero remaining length.
        let mut buf = BytesMut::new();
        buf.put_slice(&[0x00, 0x00]);

        assert!(FixedHeader::decode(&mut buf).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        // Each packet type serializes to its single discriminant byte and parses back.
        for (packet_type, wire) in [(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)] {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire]);

            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}