Skip to main content

mockforge_mqtt/
protocol.rs

1//! MQTT 3.1.1 Protocol packet parsing and serialization
2//!
3//! This module implements the MQTT 3.1.1 protocol specification for packet encoding
4//! and decoding. It supports all control packet types including CONNECT, PUBLISH,
5//! SUBSCRIBE, and their acknowledgments.
6
7use std::io::{self, Cursor, Read};
8use thiserror::Error;
9
10/// MQTT Protocol Error types
11#[derive(Debug, Error)]
12pub enum ProtocolError {
13    #[error("Invalid packet type: {0}")]
14    InvalidPacketType(u8),
15
16    #[error("Invalid remaining length encoding")]
17    InvalidRemainingLength,
18
19    #[error("Invalid protocol name: {0}")]
20    InvalidProtocolName(String),
21
22    #[error("Invalid protocol version: {0}")]
23    InvalidProtocolVersion(u8),
24
25    #[error("Invalid QoS level: {0}")]
26    InvalidQoS(u8),
27
28    #[error("Invalid UTF-8 string")]
29    InvalidUtf8,
30
31    #[error("Packet too large: {0} bytes")]
32    PacketTooLarge(usize),
33
34    #[error("Incomplete packet: expected {expected} bytes, got {got}")]
35    IncompletePacket { expected: usize, got: usize },
36
37    #[error("IO error: {0}")]
38    Io(#[from] io::Error),
39
40    #[error("Invalid connect flags")]
41    InvalidConnectFlags,
42
43    #[error("Malformed packet")]
44    MalformedPacket,
45}
46
47/// Result type for protocol operations
48pub type ProtocolResult<T> = Result<T, ProtocolError>;
49
50/// MQTT Control Packet Types (4-bit identifier in first byte)
51#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52#[repr(u8)]
53pub enum PacketType {
54    Connect = 1,
55    Connack = 2,
56    Publish = 3,
57    Puback = 4,
58    Pubrec = 5,
59    Pubrel = 6,
60    Pubcomp = 7,
61    Subscribe = 8,
62    Suback = 9,
63    Unsubscribe = 10,
64    Unsuback = 11,
65    Pingreq = 12,
66    Pingresp = 13,
67    Disconnect = 14,
68}
69
70impl TryFrom<u8> for PacketType {
71    type Error = ProtocolError;
72
73    fn try_from(value: u8) -> ProtocolResult<Self> {
74        match value {
75            1 => Ok(PacketType::Connect),
76            2 => Ok(PacketType::Connack),
77            3 => Ok(PacketType::Publish),
78            4 => Ok(PacketType::Puback),
79            5 => Ok(PacketType::Pubrec),
80            6 => Ok(PacketType::Pubrel),
81            7 => Ok(PacketType::Pubcomp),
82            8 => Ok(PacketType::Subscribe),
83            9 => Ok(PacketType::Suback),
84            10 => Ok(PacketType::Unsubscribe),
85            11 => Ok(PacketType::Unsuback),
86            12 => Ok(PacketType::Pingreq),
87            13 => Ok(PacketType::Pingresp),
88            14 => Ok(PacketType::Disconnect),
89            _ => Err(ProtocolError::InvalidPacketType(value)),
90        }
91    }
92}
93
94/// Quality of Service levels
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
96#[repr(u8)]
97pub enum QoS {
98    #[default]
99    AtMostOnce = 0,
100    AtLeastOnce = 1,
101    ExactlyOnce = 2,
102}
103
104impl TryFrom<u8> for QoS {
105    type Error = ProtocolError;
106
107    fn try_from(value: u8) -> ProtocolResult<Self> {
108        match value {
109            0 => Ok(QoS::AtMostOnce),
110            1 => Ok(QoS::AtLeastOnce),
111            2 => Ok(QoS::ExactlyOnce),
112            _ => Err(ProtocolError::InvalidQoS(value)),
113        }
114    }
115}
116
/// Return codes carried in the second byte of a CONNACK packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnackCode {
    /// Connection accepted (0).
    Accepted = 0,
    /// Protocol level not supported (1).
    UnacceptableProtocolVersion = 1,
    /// Client identifier rejected (2).
    IdentifierRejected = 2,
    /// Server unavailable (3).
    ServerUnavailable = 3,
    /// Bad user name or password (4).
    BadUsernamePassword = 4,
    /// Client not authorized (5).
    NotAuthorized = 5,
}

impl From<ConnackCode> for u8 {
    /// Wire byte for a CONNACK return code (same as its discriminant).
    fn from(code: ConnackCode) -> Self {
        match code {
            ConnackCode::Accepted => 0,
            ConnackCode::UnacceptableProtocolVersion => 1,
            ConnackCode::IdentifierRejected => 2,
            ConnackCode::ServerUnavailable => 3,
            ConnackCode::BadUsernamePassword => 4,
            ConnackCode::NotAuthorized => 5,
        }
    }
}
134
135/// MQTT Control Packet representation
136#[derive(Debug, Clone)]
137pub enum Packet {
138    Connect(ConnectPacket),
139    Connack(ConnackPacket),
140    Publish(PublishPacket),
141    Puback(PubackPacket),
142    Pubrec(PubrecPacket),
143    Pubrel(PubrelPacket),
144    Pubcomp(PubcompPacket),
145    Subscribe(SubscribePacket),
146    Suback(SubackPacket),
147    Unsubscribe(UnsubscribePacket),
148    Unsuback(UnsubackPacket),
149    Pingreq,
150    Pingresp,
151    Disconnect,
152}
153
154/// CONNECT packet from client
155#[derive(Debug, Clone)]
156pub struct ConnectPacket {
157    pub protocol_name: String,
158    pub protocol_level: u8,
159    pub clean_session: bool,
160    pub keep_alive: u16,
161    pub client_id: String,
162    pub will: Option<Will>,
163    pub username: Option<String>,
164    pub password: Option<Vec<u8>>,
165}
166
167/// Last Will and Testament
168#[derive(Debug, Clone)]
169pub struct Will {
170    pub topic: String,
171    pub message: Vec<u8>,
172    pub qos: QoS,
173    pub retain: bool,
174}
175
176/// CONNACK packet to client
177#[derive(Debug, Clone)]
178pub struct ConnackPacket {
179    pub session_present: bool,
180    pub return_code: ConnackCode,
181}
182
183/// PUBLISH packet
184#[derive(Debug, Clone)]
185pub struct PublishPacket {
186    pub dup: bool,
187    pub qos: QoS,
188    pub retain: bool,
189    pub topic: String,
190    pub packet_id: Option<u16>,
191    pub payload: Vec<u8>,
192}
193
/// PUBACK packet: acknowledges a QoS 1 PUBLISH.
#[derive(Debug, Clone)]
pub struct PubackPacket {
    /// Identifier of the PUBLISH being acknowledged.
    pub packet_id: u16,
}

/// PUBREC packet: first reply in the QoS 2 exchange.
#[derive(Debug, Clone)]
pub struct PubrecPacket {
    /// Identifier of the PUBLISH being received.
    pub packet_id: u16,
}

/// PUBREL packet: second step of the QoS 2 exchange.
#[derive(Debug, Clone)]
pub struct PubrelPacket {
    /// Identifier of the PUBLISH being released.
    pub packet_id: u16,
}

/// PUBCOMP packet: final step of the QoS 2 exchange.
#[derive(Debug, Clone)]
pub struct PubcompPacket {
    /// Identifier of the PUBLISH being completed.
    pub packet_id: u16,
}
217
218/// SUBSCRIBE packet
219#[derive(Debug, Clone)]
220pub struct SubscribePacket {
221    pub packet_id: u16,
222    pub subscriptions: Vec<(String, QoS)>,
223}
224
225/// SUBACK packet
226#[derive(Debug, Clone)]
227pub struct SubackPacket {
228    pub packet_id: u16,
229    pub return_codes: Vec<SubackReturnCode>,
230}
231
232/// SUBACK return codes
233#[derive(Debug, Clone, Copy, PartialEq, Eq)]
234pub enum SubackReturnCode {
235    SuccessQoS0,
236    SuccessQoS1,
237    SuccessQoS2,
238    Failure,
239}
240
241impl From<SubackReturnCode> for u8 {
242    fn from(code: SubackReturnCode) -> u8 {
243        match code {
244            SubackReturnCode::SuccessQoS0 => 0x00,
245            SubackReturnCode::SuccessQoS1 => 0x01,
246            SubackReturnCode::SuccessQoS2 => 0x02,
247            SubackReturnCode::Failure => 0x80,
248        }
249    }
250}
251
252impl SubackReturnCode {
253    /// Create a success return code for the given QoS
254    pub fn success(qos: QoS) -> Self {
255        match qos {
256            QoS::AtMostOnce => SubackReturnCode::SuccessQoS0,
257            QoS::AtLeastOnce => SubackReturnCode::SuccessQoS1,
258            QoS::ExactlyOnce => SubackReturnCode::SuccessQoS2,
259        }
260    }
261}
262
/// UNSUBSCRIBE packet: one or more topic filters to remove.
#[derive(Debug, Clone)]
pub struct UnsubscribePacket {
    /// Identifier echoed back in the matching UNSUBACK.
    pub packet_id: u16,
    /// Topic filters to unsubscribe from, in wire order.
    pub topics: Vec<String>,
}

/// UNSUBACK packet: acknowledges an UNSUBSCRIBE.
#[derive(Debug, Clone)]
pub struct UnsubackPacket {
    /// Identifier of the UNSUBSCRIBE being acknowledged.
    pub packet_id: u16,
}
275
276/// Packet decoder for parsing MQTT packets from bytes
277pub struct PacketDecoder;
278
279impl PacketDecoder {
280    /// Decode a single MQTT packet from a byte buffer
281    ///
282    /// Returns the parsed packet and number of bytes consumed
283    pub fn decode(buffer: &[u8]) -> ProtocolResult<Option<(Packet, usize)>> {
284        if buffer.is_empty() {
285            return Ok(None);
286        }
287
288        // Read fixed header
289        let first_byte = buffer[0];
290        let packet_type = PacketType::try_from(first_byte >> 4)?;
291        let flags = first_byte & 0x0F;
292
293        // Read remaining length (variable length encoding)
294        let (remaining_length, header_len) = Self::decode_remaining_length(&buffer[1..])?;
295
296        let total_len = 1 + header_len + remaining_length;
297        if buffer.len() < total_len {
298            return Ok(None); // Need more data
299        }
300
301        let payload = &buffer[1 + header_len..total_len];
302
303        let packet = match packet_type {
304            PacketType::Connect => Packet::Connect(Self::decode_connect(payload)?),
305            PacketType::Connack => Packet::Connack(Self::decode_connack(payload)?),
306            PacketType::Publish => Packet::Publish(Self::decode_publish(flags, payload)?),
307            PacketType::Puback => Packet::Puback(Self::decode_puback(payload)?),
308            PacketType::Pubrec => Packet::Pubrec(Self::decode_pubrec(payload)?),
309            PacketType::Pubrel => Packet::Pubrel(Self::decode_pubrel(payload)?),
310            PacketType::Pubcomp => Packet::Pubcomp(Self::decode_pubcomp(payload)?),
311            PacketType::Subscribe => Packet::Subscribe(Self::decode_subscribe(payload)?),
312            PacketType::Suback => Packet::Suback(Self::decode_suback(payload)?),
313            PacketType::Unsubscribe => Packet::Unsubscribe(Self::decode_unsubscribe(payload)?),
314            PacketType::Unsuback => Packet::Unsuback(Self::decode_unsuback(payload)?),
315            PacketType::Pingreq => Packet::Pingreq,
316            PacketType::Pingresp => Packet::Pingresp,
317            PacketType::Disconnect => Packet::Disconnect,
318        };
319
320        Ok(Some((packet, total_len)))
321    }
322
323    /// Decode variable-length remaining length field
324    fn decode_remaining_length(buffer: &[u8]) -> ProtocolResult<(usize, usize)> {
325        let mut multiplier = 1usize;
326        let mut value = 0usize;
327        let mut pos = 0;
328
329        loop {
330            if pos >= buffer.len() {
331                return Err(ProtocolError::IncompletePacket {
332                    expected: pos + 1,
333                    got: buffer.len(),
334                });
335            }
336
337            let byte = buffer[pos];
338            value += (byte & 0x7F) as usize * multiplier;
339
340            if multiplier > 128 * 128 * 128 {
341                return Err(ProtocolError::InvalidRemainingLength);
342            }
343
344            multiplier *= 128;
345            pos += 1;
346
347            if byte & 0x80 == 0 {
348                break;
349            }
350        }
351
352        Ok((value, pos))
353    }
354
355    /// Decode a UTF-8 string with length prefix
356    fn decode_string(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<String> {
357        let mut len_buf = [0u8; 2];
358        cursor.read_exact(&mut len_buf)?;
359        let len = u16::from_be_bytes(len_buf) as usize;
360
361        let mut str_buf = vec![0u8; len];
362        cursor.read_exact(&mut str_buf)?;
363
364        String::from_utf8(str_buf).map_err(|_| ProtocolError::InvalidUtf8)
365    }
366
367    /// Decode binary data with length prefix
368    fn decode_binary(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<Vec<u8>> {
369        let mut len_buf = [0u8; 2];
370        cursor.read_exact(&mut len_buf)?;
371        let len = u16::from_be_bytes(len_buf) as usize;
372
373        let mut data = vec![0u8; len];
374        cursor.read_exact(&mut data)?;
375        Ok(data)
376    }
377
378    /// Decode u16 from cursor
379    fn decode_u16(cursor: &mut Cursor<&[u8]>) -> ProtocolResult<u16> {
380        let mut buf = [0u8; 2];
381        cursor.read_exact(&mut buf)?;
382        Ok(u16::from_be_bytes(buf))
383    }
384
385    fn decode_connect(payload: &[u8]) -> ProtocolResult<ConnectPacket> {
386        let mut cursor = Cursor::new(payload);
387
388        // Protocol name
389        let protocol_name = Self::decode_string(&mut cursor)?;
390        if protocol_name != "MQTT" && protocol_name != "MQIsdp" {
391            return Err(ProtocolError::InvalidProtocolName(protocol_name));
392        }
393
394        // Protocol level
395        let mut level_buf = [0u8; 1];
396        cursor.read_exact(&mut level_buf)?;
397        let protocol_level = level_buf[0];
398
399        // For MQTT 3.1.1, protocol level should be 4
400        // For MQTT 3.1, protocol level should be 3
401        if protocol_level != 4 && protocol_level != 3 {
402            return Err(ProtocolError::InvalidProtocolVersion(protocol_level));
403        }
404
405        // Connect flags
406        let mut flags_buf = [0u8; 1];
407        cursor.read_exact(&mut flags_buf)?;
408        let flags = flags_buf[0];
409
410        let clean_session = (flags & 0x02) != 0;
411        let will_flag = (flags & 0x04) != 0;
412        let will_qos = QoS::try_from((flags >> 3) & 0x03)?;
413        let will_retain = (flags & 0x20) != 0;
414        let password_flag = (flags & 0x40) != 0;
415        let username_flag = (flags & 0x80) != 0;
416
417        // Reserved bit must be 0
418        if flags & 0x01 != 0 {
419            return Err(ProtocolError::InvalidConnectFlags);
420        }
421
422        // Keep alive
423        let keep_alive = Self::decode_u16(&mut cursor)?;
424
425        // Client ID
426        let client_id = Self::decode_string(&mut cursor)?;
427
428        // Will
429        let will = if will_flag {
430            let topic = Self::decode_string(&mut cursor)?;
431            let message = Self::decode_binary(&mut cursor)?;
432            Some(Will {
433                topic,
434                message,
435                qos: will_qos,
436                retain: will_retain,
437            })
438        } else {
439            None
440        };
441
442        // Username
443        let username = if username_flag {
444            Some(Self::decode_string(&mut cursor)?)
445        } else {
446            None
447        };
448
449        // Password
450        let password = if password_flag {
451            Some(Self::decode_binary(&mut cursor)?)
452        } else {
453            None
454        };
455
456        Ok(ConnectPacket {
457            protocol_name,
458            protocol_level,
459            clean_session,
460            keep_alive,
461            client_id,
462            will,
463            username,
464            password,
465        })
466    }
467
468    fn decode_connack(payload: &[u8]) -> ProtocolResult<ConnackPacket> {
469        if payload.len() < 2 {
470            return Err(ProtocolError::MalformedPacket);
471        }
472
473        let session_present = (payload[0] & 0x01) != 0;
474        let return_code = match payload[1] {
475            0 => ConnackCode::Accepted,
476            1 => ConnackCode::UnacceptableProtocolVersion,
477            2 => ConnackCode::IdentifierRejected,
478            3 => ConnackCode::ServerUnavailable,
479            4 => ConnackCode::BadUsernamePassword,
480            5 => ConnackCode::NotAuthorized,
481            _ => return Err(ProtocolError::MalformedPacket),
482        };
483
484        Ok(ConnackPacket {
485            session_present,
486            return_code,
487        })
488    }
489
490    fn decode_publish(flags: u8, payload: &[u8]) -> ProtocolResult<PublishPacket> {
491        let dup = (flags & 0x08) != 0;
492        let qos = QoS::try_from((flags >> 1) & 0x03)?;
493        let retain = (flags & 0x01) != 0;
494
495        let mut cursor = Cursor::new(payload);
496
497        // Topic name
498        let topic = Self::decode_string(&mut cursor)?;
499
500        // Packet ID (only for QoS > 0)
501        let packet_id = if qos != QoS::AtMostOnce {
502            Some(Self::decode_u16(&mut cursor)?)
503        } else {
504            None
505        };
506
507        // Remaining bytes are the payload
508        let pos = cursor.position() as usize;
509        let message_payload = payload[pos..].to_vec();
510
511        Ok(PublishPacket {
512            dup,
513            qos,
514            retain,
515            topic,
516            packet_id,
517            payload: message_payload,
518        })
519    }
520
521    fn decode_puback(payload: &[u8]) -> ProtocolResult<PubackPacket> {
522        if payload.len() < 2 {
523            return Err(ProtocolError::MalformedPacket);
524        }
525        Ok(PubackPacket {
526            packet_id: u16::from_be_bytes([payload[0], payload[1]]),
527        })
528    }
529
530    fn decode_pubrec(payload: &[u8]) -> ProtocolResult<PubrecPacket> {
531        if payload.len() < 2 {
532            return Err(ProtocolError::MalformedPacket);
533        }
534        Ok(PubrecPacket {
535            packet_id: u16::from_be_bytes([payload[0], payload[1]]),
536        })
537    }
538
539    fn decode_pubrel(payload: &[u8]) -> ProtocolResult<PubrelPacket> {
540        if payload.len() < 2 {
541            return Err(ProtocolError::MalformedPacket);
542        }
543        Ok(PubrelPacket {
544            packet_id: u16::from_be_bytes([payload[0], payload[1]]),
545        })
546    }
547
548    fn decode_pubcomp(payload: &[u8]) -> ProtocolResult<PubcompPacket> {
549        if payload.len() < 2 {
550            return Err(ProtocolError::MalformedPacket);
551        }
552        Ok(PubcompPacket {
553            packet_id: u16::from_be_bytes([payload[0], payload[1]]),
554        })
555    }
556
557    fn decode_subscribe(payload: &[u8]) -> ProtocolResult<SubscribePacket> {
558        let mut cursor = Cursor::new(payload);
559
560        let packet_id = Self::decode_u16(&mut cursor)?;
561        let mut subscriptions = Vec::new();
562
563        while (cursor.position() as usize) < payload.len() {
564            let topic = Self::decode_string(&mut cursor)?;
565            let mut qos_buf = [0u8; 1];
566            cursor.read_exact(&mut qos_buf)?;
567            let qos = QoS::try_from(qos_buf[0] & 0x03)?;
568            subscriptions.push((topic, qos));
569        }
570
571        if subscriptions.is_empty() {
572            return Err(ProtocolError::MalformedPacket);
573        }
574
575        Ok(SubscribePacket {
576            packet_id,
577            subscriptions,
578        })
579    }
580
581    fn decode_suback(payload: &[u8]) -> ProtocolResult<SubackPacket> {
582        if payload.len() < 3 {
583            return Err(ProtocolError::MalformedPacket);
584        }
585
586        let packet_id = u16::from_be_bytes([payload[0], payload[1]]);
587        let mut return_codes = Vec::new();
588
589        for &byte in &payload[2..] {
590            let code = match byte {
591                0x00 => SubackReturnCode::SuccessQoS0,
592                0x01 => SubackReturnCode::SuccessQoS1,
593                0x02 => SubackReturnCode::SuccessQoS2,
594                0x80 => SubackReturnCode::Failure,
595                _ => return Err(ProtocolError::MalformedPacket),
596            };
597            return_codes.push(code);
598        }
599
600        Ok(SubackPacket {
601            packet_id,
602            return_codes,
603        })
604    }
605
606    fn decode_unsubscribe(payload: &[u8]) -> ProtocolResult<UnsubscribePacket> {
607        let mut cursor = Cursor::new(payload);
608
609        let packet_id = Self::decode_u16(&mut cursor)?;
610        let mut topics = Vec::new();
611
612        while (cursor.position() as usize) < payload.len() {
613            let topic = Self::decode_string(&mut cursor)?;
614            topics.push(topic);
615        }
616
617        if topics.is_empty() {
618            return Err(ProtocolError::MalformedPacket);
619        }
620
621        Ok(UnsubscribePacket { packet_id, topics })
622    }
623
624    fn decode_unsuback(payload: &[u8]) -> ProtocolResult<UnsubackPacket> {
625        if payload.len() < 2 {
626            return Err(ProtocolError::MalformedPacket);
627        }
628        Ok(UnsubackPacket {
629            packet_id: u16::from_be_bytes([payload[0], payload[1]]),
630        })
631    }
632}
633
634/// Packet encoder for serializing MQTT packets to bytes
635pub struct PacketEncoder;
636
637impl PacketEncoder {
638    /// Encode a packet to bytes
639    pub fn encode(packet: &Packet) -> ProtocolResult<Vec<u8>> {
640        match packet {
641            Packet::Connect(p) => Self::encode_connect(p),
642            Packet::Connack(p) => Self::encode_connack(p),
643            Packet::Publish(p) => Self::encode_publish(p),
644            Packet::Puback(p) => Self::encode_puback(p),
645            Packet::Pubrec(p) => Self::encode_pubrec(p),
646            Packet::Pubrel(p) => Self::encode_pubrel(p),
647            Packet::Pubcomp(p) => Self::encode_pubcomp(p),
648            Packet::Subscribe(p) => Self::encode_subscribe(p),
649            Packet::Suback(p) => Self::encode_suback(p),
650            Packet::Unsubscribe(p) => Self::encode_unsubscribe(p),
651            Packet::Unsuback(p) => Self::encode_unsuback(p),
652            Packet::Pingreq => Self::encode_pingreq(),
653            Packet::Pingresp => Self::encode_pingresp(),
654            Packet::Disconnect => Self::encode_disconnect(),
655        }
656    }
657
658    /// Encode remaining length as variable-length integer
659    fn encode_remaining_length(length: usize) -> Vec<u8> {
660        let mut result = Vec::new();
661        let mut x = length;
662
663        loop {
664            let mut byte = (x % 128) as u8;
665            x /= 128;
666            if x > 0 {
667                byte |= 0x80;
668            }
669            result.push(byte);
670            if x == 0 {
671                break;
672            }
673        }
674
675        result
676    }
677
678    /// Encode a UTF-8 string with length prefix
679    fn encode_string(s: &str) -> Vec<u8> {
680        let bytes = s.as_bytes();
681        let len = bytes.len() as u16;
682        let mut result = Vec::with_capacity(2 + bytes.len());
683        result.extend_from_slice(&len.to_be_bytes());
684        result.extend_from_slice(bytes);
685        result
686    }
687
688    /// Encode binary data with length prefix
689    fn encode_binary(data: &[u8]) -> Vec<u8> {
690        let len = data.len() as u16;
691        let mut result = Vec::with_capacity(2 + data.len());
692        result.extend_from_slice(&len.to_be_bytes());
693        result.extend_from_slice(data);
694        result
695    }
696
697    fn encode_connect(packet: &ConnectPacket) -> ProtocolResult<Vec<u8>> {
698        let mut payload = Vec::new();
699
700        // Protocol name
701        payload.extend(Self::encode_string(&packet.protocol_name));
702
703        // Protocol level
704        payload.push(packet.protocol_level);
705
706        // Connect flags
707        let mut flags = 0u8;
708        if packet.clean_session {
709            flags |= 0x02;
710        }
711        if let Some(ref will) = packet.will {
712            flags |= 0x04; // Will flag
713            flags |= (will.qos as u8) << 3;
714            if will.retain {
715                flags |= 0x20;
716            }
717        }
718        if packet.password.is_some() {
719            flags |= 0x40;
720        }
721        if packet.username.is_some() {
722            flags |= 0x80;
723        }
724        payload.push(flags);
725
726        // Keep alive
727        payload.extend_from_slice(&packet.keep_alive.to_be_bytes());
728
729        // Client ID
730        payload.extend(Self::encode_string(&packet.client_id));
731
732        // Will
733        if let Some(ref will) = packet.will {
734            payload.extend(Self::encode_string(&will.topic));
735            payload.extend(Self::encode_binary(&will.message));
736        }
737
738        // Username
739        if let Some(ref username) = packet.username {
740            payload.extend(Self::encode_string(username));
741        }
742
743        // Password
744        if let Some(ref password) = packet.password {
745            payload.extend(Self::encode_binary(password));
746        }
747
748        // Build final packet
749        let mut result = Vec::new();
750        result.push(0x10); // CONNECT packet type
751        result.extend(Self::encode_remaining_length(payload.len()));
752        result.extend(payload);
753
754        Ok(result)
755    }
756
757    fn encode_connack(packet: &ConnackPacket) -> ProtocolResult<Vec<u8>> {
758        let mut result = Vec::new();
759        result.push(0x20); // CONNACK packet type
760        result.push(0x02); // Remaining length
761
762        let ack_flags = if packet.session_present { 0x01 } else { 0x00 };
763        result.push(ack_flags);
764        result.push(packet.return_code as u8);
765
766        Ok(result)
767    }
768
769    fn encode_publish(packet: &PublishPacket) -> ProtocolResult<Vec<u8>> {
770        let mut payload = Vec::new();
771
772        // Topic
773        payload.extend(Self::encode_string(&packet.topic));
774
775        // Packet ID (only for QoS > 0)
776        if let Some(packet_id) = packet.packet_id {
777            payload.extend_from_slice(&packet_id.to_be_bytes());
778        }
779
780        // Payload
781        payload.extend_from_slice(&packet.payload);
782
783        // Build fixed header
784        let mut first_byte = 0x30u8; // PUBLISH packet type
785        if packet.dup {
786            first_byte |= 0x08;
787        }
788        first_byte |= (packet.qos as u8) << 1;
789        if packet.retain {
790            first_byte |= 0x01;
791        }
792
793        let mut result = Vec::new();
794        result.push(first_byte);
795        result.extend(Self::encode_remaining_length(payload.len()));
796        result.extend(payload);
797
798        Ok(result)
799    }
800
801    fn encode_puback(packet: &PubackPacket) -> ProtocolResult<Vec<u8>> {
802        let mut result = Vec::new();
803        result.push(0x40); // PUBACK packet type
804        result.push(0x02); // Remaining length
805        result.extend_from_slice(&packet.packet_id.to_be_bytes());
806        Ok(result)
807    }
808
809    fn encode_pubrec(packet: &PubrecPacket) -> ProtocolResult<Vec<u8>> {
810        let mut result = Vec::new();
811        result.push(0x50); // PUBREC packet type
812        result.push(0x02); // Remaining length
813        result.extend_from_slice(&packet.packet_id.to_be_bytes());
814        Ok(result)
815    }
816
817    fn encode_pubrel(packet: &PubrelPacket) -> ProtocolResult<Vec<u8>> {
818        let mut result = Vec::new();
819        result.push(0x62); // PUBREL packet type (with required flags)
820        result.push(0x02); // Remaining length
821        result.extend_from_slice(&packet.packet_id.to_be_bytes());
822        Ok(result)
823    }
824
825    fn encode_pubcomp(packet: &PubcompPacket) -> ProtocolResult<Vec<u8>> {
826        let mut result = Vec::new();
827        result.push(0x70); // PUBCOMP packet type
828        result.push(0x02); // Remaining length
829        result.extend_from_slice(&packet.packet_id.to_be_bytes());
830        Ok(result)
831    }
832
833    fn encode_subscribe(packet: &SubscribePacket) -> ProtocolResult<Vec<u8>> {
834        let mut payload = Vec::new();
835
836        // Packet ID
837        payload.extend_from_slice(&packet.packet_id.to_be_bytes());
838
839        // Topic filters
840        for (topic, qos) in &packet.subscriptions {
841            payload.extend(Self::encode_string(topic));
842            payload.push(*qos as u8);
843        }
844
845        let mut result = Vec::new();
846        result.push(0x82); // SUBSCRIBE packet type (with required flags)
847        result.extend(Self::encode_remaining_length(payload.len()));
848        result.extend(payload);
849
850        Ok(result)
851    }
852
853    fn encode_suback(packet: &SubackPacket) -> ProtocolResult<Vec<u8>> {
854        let mut payload = Vec::new();
855
856        // Packet ID
857        payload.extend_from_slice(&packet.packet_id.to_be_bytes());
858
859        // Return codes
860        for code in &packet.return_codes {
861            payload.push((*code).into());
862        }
863
864        let mut result = Vec::new();
865        result.push(0x90); // SUBACK packet type
866        result.extend(Self::encode_remaining_length(payload.len()));
867        result.extend(payload);
868
869        Ok(result)
870    }
871
872    fn encode_unsubscribe(packet: &UnsubscribePacket) -> ProtocolResult<Vec<u8>> {
873        let mut payload = Vec::new();
874
875        // Packet ID
876        payload.extend_from_slice(&packet.packet_id.to_be_bytes());
877
878        // Topic filters
879        for topic in &packet.topics {
880            payload.extend(Self::encode_string(topic));
881        }
882
883        let mut result = Vec::new();
884        result.push(0xA2); // UNSUBSCRIBE packet type (with required flags)
885        result.extend(Self::encode_remaining_length(payload.len()));
886        result.extend(payload);
887
888        Ok(result)
889    }
890
891    fn encode_unsuback(packet: &UnsubackPacket) -> ProtocolResult<Vec<u8>> {
892        let mut result = Vec::new();
893        result.push(0xB0); // UNSUBACK packet type
894        result.push(0x02); // Remaining length
895        result.extend_from_slice(&packet.packet_id.to_be_bytes());
896        Ok(result)
897    }
898
899    fn encode_pingreq() -> ProtocolResult<Vec<u8>> {
900        Ok(vec![0xC0, 0x00]) // PINGREQ packet
901    }
902
903    fn encode_pingresp() -> ProtocolResult<Vec<u8>> {
904        Ok(vec![0xD0, 0x00]) // PINGRESP packet
905    }
906
907    fn encode_disconnect() -> ProtocolResult<Vec<u8>> {
908        Ok(vec![0xE0, 0x00]) // DISCONNECT packet
909    }
910}
911
912#[cfg(test)]
913mod tests {
914    use super::*;
915
916    #[test]
917    fn test_packet_type_from_u8() {
918        assert_eq!(PacketType::try_from(1).unwrap(), PacketType::Connect);
919        assert_eq!(PacketType::try_from(2).unwrap(), PacketType::Connack);
920        assert_eq!(PacketType::try_from(3).unwrap(), PacketType::Publish);
921        assert!(PacketType::try_from(0).is_err());
922        assert!(PacketType::try_from(15).is_err());
923    }
924
925    #[test]
926    fn test_qos_from_u8() {
927        assert_eq!(QoS::try_from(0).unwrap(), QoS::AtMostOnce);
928        assert_eq!(QoS::try_from(1).unwrap(), QoS::AtLeastOnce);
929        assert_eq!(QoS::try_from(2).unwrap(), QoS::ExactlyOnce);
930        assert!(QoS::try_from(3).is_err());
931    }
932
933    #[test]
934    fn test_encode_decode_connect() {
935        let connect = ConnectPacket {
936            protocol_name: "MQTT".to_string(),
937            protocol_level: 4,
938            clean_session: true,
939            keep_alive: 60,
940            client_id: "test-client".to_string(),
941            will: None,
942            username: None,
943            password: None,
944        };
945
946        let encoded = PacketEncoder::encode(&Packet::Connect(connect.clone())).unwrap();
947        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
948
949        if let Packet::Connect(decoded_connect) = decoded {
950            assert_eq!(decoded_connect.protocol_name, connect.protocol_name);
951            assert_eq!(decoded_connect.protocol_level, connect.protocol_level);
952            assert_eq!(decoded_connect.clean_session, connect.clean_session);
953            assert_eq!(decoded_connect.keep_alive, connect.keep_alive);
954            assert_eq!(decoded_connect.client_id, connect.client_id);
955        } else {
956            panic!("Expected Connect packet");
957        }
958    }
959
960    #[test]
961    fn test_encode_decode_connack() {
962        let connack = ConnackPacket {
963            session_present: false,
964            return_code: ConnackCode::Accepted,
965        };
966
967        let encoded = PacketEncoder::encode(&Packet::Connack(connack.clone())).unwrap();
968        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
969
970        if let Packet::Connack(decoded_connack) = decoded {
971            assert_eq!(decoded_connack.session_present, connack.session_present);
972            assert_eq!(decoded_connack.return_code, connack.return_code);
973        } else {
974            panic!("Expected Connack packet");
975        }
976    }
977
978    #[test]
979    fn test_encode_decode_publish_qos0() {
980        let publish = PublishPacket {
981            dup: false,
982            qos: QoS::AtMostOnce,
983            retain: false,
984            topic: "test/topic".to_string(),
985            packet_id: None,
986            payload: b"Hello, MQTT!".to_vec(),
987        };
988
989        let encoded = PacketEncoder::encode(&Packet::Publish(publish.clone())).unwrap();
990        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
991
992        if let Packet::Publish(decoded_publish) = decoded {
993            assert_eq!(decoded_publish.topic, publish.topic);
994            assert_eq!(decoded_publish.payload, publish.payload);
995            assert_eq!(decoded_publish.qos, publish.qos);
996        } else {
997            panic!("Expected Publish packet");
998        }
999    }
1000
1001    #[test]
1002    fn test_encode_decode_publish_qos1() {
1003        let publish = PublishPacket {
1004            dup: false,
1005            qos: QoS::AtLeastOnce,
1006            retain: true,
1007            topic: "sensor/temp".to_string(),
1008            packet_id: Some(1234),
1009            payload: b"25.5".to_vec(),
1010        };
1011
1012        let encoded = PacketEncoder::encode(&Packet::Publish(publish.clone())).unwrap();
1013        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1014
1015        if let Packet::Publish(decoded_publish) = decoded {
1016            assert_eq!(decoded_publish.topic, publish.topic);
1017            assert_eq!(decoded_publish.packet_id, publish.packet_id);
1018            assert_eq!(decoded_publish.retain, publish.retain);
1019        } else {
1020            panic!("Expected Publish packet");
1021        }
1022    }
1023
1024    #[test]
1025    fn test_encode_decode_subscribe() {
1026        let subscribe = SubscribePacket {
1027            packet_id: 100,
1028            subscriptions: vec![
1029                ("topic/a".to_string(), QoS::AtMostOnce),
1030                ("topic/b/#".to_string(), QoS::AtLeastOnce),
1031            ],
1032        };
1033
1034        let encoded = PacketEncoder::encode(&Packet::Subscribe(subscribe.clone())).unwrap();
1035        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1036
1037        if let Packet::Subscribe(decoded_sub) = decoded {
1038            assert_eq!(decoded_sub.packet_id, subscribe.packet_id);
1039            assert_eq!(decoded_sub.subscriptions.len(), 2);
1040            assert_eq!(decoded_sub.subscriptions[0].0, "topic/a");
1041            assert_eq!(decoded_sub.subscriptions[1].1, QoS::AtLeastOnce);
1042        } else {
1043            panic!("Expected Subscribe packet");
1044        }
1045    }
1046
1047    #[test]
1048    fn test_encode_decode_suback() {
1049        let suback = SubackPacket {
1050            packet_id: 100,
1051            return_codes: vec![SubackReturnCode::SuccessQoS0, SubackReturnCode::SuccessQoS1],
1052        };
1053
1054        let encoded = PacketEncoder::encode(&Packet::Suback(suback.clone())).unwrap();
1055        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1056
1057        if let Packet::Suback(decoded_suback) = decoded {
1058            assert_eq!(decoded_suback.packet_id, suback.packet_id);
1059            assert_eq!(decoded_suback.return_codes.len(), 2);
1060        } else {
1061            panic!("Expected Suback packet");
1062        }
1063    }
1064
1065    #[test]
1066    fn test_encode_decode_unsubscribe() {
1067        let unsubscribe = UnsubscribePacket {
1068            packet_id: 200,
1069            topics: vec!["topic/a".to_string(), "topic/b".to_string()],
1070        };
1071
1072        let encoded = PacketEncoder::encode(&Packet::Unsubscribe(unsubscribe.clone())).unwrap();
1073        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1074
1075        if let Packet::Unsubscribe(decoded_unsub) = decoded {
1076            assert_eq!(decoded_unsub.packet_id, unsubscribe.packet_id);
1077            assert_eq!(decoded_unsub.topics.len(), 2);
1078        } else {
1079            panic!("Expected Unsubscribe packet");
1080        }
1081    }
1082
1083    #[test]
1084    fn test_encode_decode_pingreq() {
1085        let encoded = PacketEncoder::encode(&Packet::Pingreq).unwrap();
1086        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1087        assert!(matches!(decoded, Packet::Pingreq));
1088    }
1089
1090    #[test]
1091    fn test_encode_decode_pingresp() {
1092        let encoded = PacketEncoder::encode(&Packet::Pingresp).unwrap();
1093        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1094        assert!(matches!(decoded, Packet::Pingresp));
1095    }
1096
1097    #[test]
1098    fn test_encode_decode_disconnect() {
1099        let encoded = PacketEncoder::encode(&Packet::Disconnect).unwrap();
1100        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1101        assert!(matches!(decoded, Packet::Disconnect));
1102    }
1103
1104    #[test]
1105    fn test_incomplete_packet() {
1106        let partial = vec![0x10, 0x0A]; // CONNECT header, but no payload
1107        let result = PacketDecoder::decode(&partial);
1108        assert!(result.unwrap().is_none());
1109    }
1110
1111    #[test]
1112    fn test_remaining_length_encoding() {
1113        // Test various lengths
1114        assert_eq!(PacketEncoder::encode_remaining_length(0), vec![0x00]);
1115        assert_eq!(PacketEncoder::encode_remaining_length(127), vec![0x7F]);
1116        assert_eq!(PacketEncoder::encode_remaining_length(128), vec![0x80, 0x01]);
1117        assert_eq!(PacketEncoder::encode_remaining_length(16383), vec![0xFF, 0x7F]);
1118        assert_eq!(PacketEncoder::encode_remaining_length(16384), vec![0x80, 0x80, 0x01]);
1119    }
1120
1121    #[test]
1122    fn test_connect_with_credentials() {
1123        let connect = ConnectPacket {
1124            protocol_name: "MQTT".to_string(),
1125            protocol_level: 4,
1126            clean_session: false,
1127            keep_alive: 120,
1128            client_id: "secure-client".to_string(),
1129            will: None,
1130            username: Some("user".to_string()),
1131            password: Some(b"pass".to_vec()),
1132        };
1133
1134        let encoded = PacketEncoder::encode(&Packet::Connect(connect.clone())).unwrap();
1135        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1136
1137        if let Packet::Connect(decoded_connect) = decoded {
1138            assert_eq!(decoded_connect.username, Some("user".to_string()));
1139            assert_eq!(decoded_connect.password, Some(b"pass".to_vec()));
1140            assert!(!decoded_connect.clean_session);
1141        } else {
1142            panic!("Expected Connect packet");
1143        }
1144    }
1145
1146    #[test]
1147    fn test_connect_with_will() {
1148        let connect = ConnectPacket {
1149            protocol_name: "MQTT".to_string(),
1150            protocol_level: 4,
1151            clean_session: true,
1152            keep_alive: 60,
1153            client_id: "will-client".to_string(),
1154            will: Some(Will {
1155                topic: "last/will".to_string(),
1156                message: b"goodbye".to_vec(),
1157                qos: QoS::AtLeastOnce,
1158                retain: true,
1159            }),
1160            username: None,
1161            password: None,
1162        };
1163
1164        let encoded = PacketEncoder::encode(&Packet::Connect(connect.clone())).unwrap();
1165        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1166
1167        if let Packet::Connect(decoded_connect) = decoded {
1168            let will = decoded_connect.will.unwrap();
1169            assert_eq!(will.topic, "last/will");
1170            assert_eq!(will.message, b"goodbye");
1171            assert_eq!(will.qos, QoS::AtLeastOnce);
1172            assert!(will.retain);
1173        } else {
1174            panic!("Expected Connect packet");
1175        }
1176    }
1177
1178    #[test]
1179    fn test_puback_roundtrip() {
1180        let puback = PubackPacket { packet_id: 12345 };
1181        let encoded = PacketEncoder::encode(&Packet::Puback(puback.clone())).unwrap();
1182        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1183
1184        if let Packet::Puback(decoded_puback) = decoded {
1185            assert_eq!(decoded_puback.packet_id, puback.packet_id);
1186        } else {
1187            panic!("Expected Puback packet");
1188        }
1189    }
1190
1191    #[test]
1192    fn test_qos2_handshake_packets() {
1193        // PUBREC
1194        let pubrec = PubrecPacket { packet_id: 1000 };
1195        let encoded = PacketEncoder::encode(&Packet::Pubrec(pubrec.clone())).unwrap();
1196        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1197        if let Packet::Pubrec(d) = decoded {
1198            assert_eq!(d.packet_id, 1000);
1199        } else {
1200            panic!("Expected Pubrec");
1201        }
1202
1203        // PUBREL
1204        let pubrel = PubrelPacket { packet_id: 1000 };
1205        let encoded = PacketEncoder::encode(&Packet::Pubrel(pubrel.clone())).unwrap();
1206        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1207        if let Packet::Pubrel(d) = decoded {
1208            assert_eq!(d.packet_id, 1000);
1209        } else {
1210            panic!("Expected Pubrel");
1211        }
1212
1213        // PUBCOMP
1214        let pubcomp = PubcompPacket { packet_id: 1000 };
1215        let encoded = PacketEncoder::encode(&Packet::Pubcomp(pubcomp.clone())).unwrap();
1216        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1217        if let Packet::Pubcomp(d) = decoded {
1218            assert_eq!(d.packet_id, 1000);
1219        } else {
1220            panic!("Expected Pubcomp");
1221        }
1222    }
1223
1224    #[test]
1225    fn test_unsuback_roundtrip() {
1226        let unsuback = UnsubackPacket { packet_id: 999 };
1227        let encoded = PacketEncoder::encode(&Packet::Unsuback(unsuback.clone())).unwrap();
1228        let (decoded, _) = PacketDecoder::decode(&encoded).unwrap().unwrap();
1229
1230        if let Packet::Unsuback(decoded_unsuback) = decoded {
1231            assert_eq!(decoded_unsuback.packet_id, unsuback.packet_id);
1232        } else {
1233            panic!("Expected Unsuback packet");
1234        }
1235    }
1236}