rmqtt_codec/
version.rs

1use bytes::BytesMut;
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::{DecodeError, EncodeError};
5use crate::types::{packet_type, MQISDP, MQTT, MQTT_LEVEL_31, MQTT_LEVEL_311, MQTT_LEVEL_5};
6use crate::utils;
7
/// MQTT protocol generation detected from a client's CONNECT packet.
///
/// MQTT 3.1 and 3.1.1 are grouped together because they share one wire format
/// for the purposes of codec selection; MQTT 5.0 has its own variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolVersion {
    /// Either MQTT 3.1 or MQTT 3.1.1.
    MQTT3,
    /// MQTT 5.0.
    MQTT5,
}
16
/// Stateless codec that sniffs the MQTT protocol version of a new connection.
///
/// It looks only at the first (CONNECT) packet to decide which protocol version
/// the client speaks, so that the connection can then be switched over to the
/// matching version-specific codec.
#[derive(Debug)]
pub struct VersionCodec;
23
24impl Decoder for VersionCodec {
25    type Item = ProtocolVersion;
26    type Error = DecodeError;
27
28    /// Decodes the protocol version from the initial CONNECT packet
29    ///
30    /// # Process
31    /// 1. Checks for minimum packet length
32    /// 2. Verifies CONNECT packet type
33    /// 3. Reads variable length header
34    /// 4. Validates protocol name (MQTT/MQIsdp)
35    /// 5. Extracts protocol level byte
36    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
37        let len = src.len();
38        if len < 2 {
39            return Ok(None);
40        }
41
42        let src_slice = src.as_ref();
43        let first_byte = src_slice[0];
44        match utils::decode_variable_length(&src_slice[1..])? {
45            Some((_, mut consumed)) => {
46                consumed += 1;
47
48                if first_byte == packet_type::CONNECT {
49                    if len <= consumed + 6 {
50                        return Ok(None);
51                    }
52
53                    let protocol_len = u16::from_be_bytes(
54                        src[consumed..consumed + 2].try_into().map_err(|_| DecodeError::InvalidProtocol)?,
55                    );
56
57                    // Validate protocol name matches MQTT spec
58                    ensure!(
59                        (protocol_len == 4 && &src[consumed + 2..consumed + 6] == MQTT)
60                            || (protocol_len == 6 && &src[consumed + 2..consumed + 8] == MQISDP),
61                        DecodeError::InvalidProtocol
62                    );
63
64                    // Extract protocol level byte (position after protocol name)
65                    match src[consumed + 2 + protocol_len as usize] {
66                        MQTT_LEVEL_31 | MQTT_LEVEL_311 => Ok(Some(ProtocolVersion::MQTT3)),
67                        MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
68                        _ => Err(DecodeError::InvalidProtocol),
69                    }
70                } else {
71                    Err(DecodeError::UnsupportedPacketType)
72                }
73            }
74            None => Ok(None),
75        }
76    }
77}
78
79impl Encoder<ProtocolVersion> for VersionCodec {
80    type Error = EncodeError;
81
82    /// Encoding not supported for version detection codec
83    ///
84    /// This codec is only used for initial protocol detection,
85    /// actual packet encoding should be handled by version-specific codecs
86    fn encode(&mut self, _: ProtocolVersion, _: &mut BytesMut) -> Result<(), Self::Error> {
87        Err(EncodeError::UnsupportedVersion)
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// A CONNECT whose protocol-name length is neither 4 nor 6 is rejected.
    #[test]
    fn test_invalid_protocol() {
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass".as_ref(),
        );
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// Test valid MQTT 3.1.1 protocol detection
    #[test]
    fn test_mqtt3_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// Test valid MQTT 5.0 protocol detection
    #[test]
    fn test_mqtt5_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
    }

    /// MQTT 3.1 uses the `MQIsdp` protocol name and is reported as MQTT3.
    #[test]
    fn test_mqisdp_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x0c\x00\x06MQIsdp\x03\x02\x00\x3c".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// Anything other than a CONNECT as the first packet is refused.
    #[test]
    fn test_non_connect_packet_rejected() {
        let mut buf = BytesMut::from(b"\x20\x02\x00\x00".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::UnsupportedPacketType)));
    }

    /// A valid protocol name with an unknown protocol level is rejected.
    #[test]
    fn test_unknown_protocol_level() {
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTT\x07\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// Detection must not consume bytes: the next codec needs the full CONNECT.
    #[test]
    fn test_decode_does_not_consume_buffer() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f".as_ref());
        let before = buf.len();
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
        assert_eq!(buf.len(), before);
    }

    /// Test partial packet handling
    #[test]
    fn test_partial_packet_handling() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));

        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None);
    }
}