rmqtt_codec/
version.rs

1use bytes::BytesMut;
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::{DecodeError, EncodeError};
5use crate::types::{packet_type, MQISDP, MQTT, MQTT_LEVEL_31, MQTT_LEVEL_311, MQTT_LEVEL_5};
6use crate::utils;
7
/// MQTT protocol generation detected for a connection.
///
/// MQTT 3.1 and MQTT 3.1.1 are not distinguished; both map to [`ProtocolVersion::MQTT3`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolVersion {
    /// MQTT 3.1 (`MQIsdp`) or MQTT 3.1.1 (`MQTT`, level 4).
    MQTT3,
    /// MQTT 5.0 (level 5).
    MQTT5,
}
16
/// Stateless codec that sniffs the MQTT protocol version of a new connection.
///
/// Decoding inspects the leading CONNECT packet without consuming any bytes,
/// so the buffer can afterwards be handed untouched to the codec for the
/// detected version. Encoding is not supported and always returns an error.
#[derive(Debug)]
pub struct VersionCodec;
23
24impl Decoder for VersionCodec {
25    type Item = ProtocolVersion;
26    type Error = DecodeError;
27
28    /// Decodes the protocol version from the initial CONNECT packet
29    ///
30    /// # Process
31    /// 1. Checks for minimum packet length
32    /// 2. Verifies CONNECT packet type
33    /// 3. Reads variable length header
34    /// 4. Validates protocol name (MQTT/MQIsdp)
35    /// 5. Extracts protocol level byte
36    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
37        let len = src.len();
38        if len < 2 {
39            return Ok(None);
40        }
41
42        let src_slice = src.as_ref();
43        let first_byte = src_slice[0];
44        match utils::decode_variable_length(&src_slice[1..])? {
45            Some((_, mut consumed)) => {
46                consumed += 1;
47
48                if first_byte == packet_type::CONNECT {
49                    if len <= consumed + 6 {
50                        return Ok(None);
51                    }
52
53                    let protocol_len = u16::from_be_bytes(
54                        src[consumed..consumed + 2].try_into().map_err(|_| DecodeError::InvalidProtocol)?,
55                    );
56
57                    // Validate protocol name matches MQTT spec
58                    if protocol_len == 4 {
59                        //for mqtt 3.1.1 or 5.0
60                        if &src[consumed + 2..consumed + 6] != MQTT {
61                            return Err(DecodeError::InvalidProtocol);
62                        }
63                    } else if protocol_len == 6 {
64                        //for mqtt 3.1
65                        if len <= consumed + 8 {
66                            return Ok(None);
67                        }
68                        if &src[consumed + 2..consumed + 8] != MQISDP {
69                            return Err(DecodeError::InvalidProtocol);
70                        }
71                    } else {
72                        return Err(DecodeError::InvalidProtocol);
73                    }
74
75                    // Extract protocol level byte (position after protocol name)
76                    match src[consumed + 2 + protocol_len as usize] {
77                        MQTT_LEVEL_31 | MQTT_LEVEL_311 => Ok(Some(ProtocolVersion::MQTT3)),
78                        MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
79                        _ => Err(DecodeError::InvalidProtocol),
80                    }
81                } else {
82                    Err(DecodeError::UnsupportedPacketType)
83                }
84            }
85            None => Ok(None),
86        }
87    }
88}
89
90impl Encoder<ProtocolVersion> for VersionCodec {
91    type Error = EncodeError;
92
93    /// Encoding not supported for version detection codec
94    ///
95    /// This codec is only used for initial protocol detection,
96    /// actual packet encoding should be handled by version-specific codecs
97    fn encode(&mut self, _: ProtocolVersion, _: &mut BytesMut) -> Result<(), Self::Error> {
98        Err(EncodeError::UnsupportedVersion)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Test invalid protocol format detection
    /// (the declared protocol-name length is neither 4 nor 6).
    #[test]
    fn test_invalid_protocol() {
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass".as_ref(),
        );
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// Test valid MQTT 3.1.1 protocol detection
    #[test]
    fn test_mqtt3_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// Test valid MQTT 5.0 protocol detection
    #[test]
    fn test_mqtt5_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
    }

    /// Test partial packet handling
    #[test]
    fn test_partial_packet_handling() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));

        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None);
    }

    /// MQTT 3.1 CONNECT packets carry the 6-byte "MQIsdp" name and level 3.
    #[test]
    fn test_mqtt31_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x0c\x00\x06MQIsdp\x03\x02\x00\x3c".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// The longer "MQIsdp" name has its own "need more bytes" guard.
    #[test]
    fn test_partial_mqisdp_packet() {
        // Name is complete but the level byte has not arrived yet.
        let mut buf = BytesMut::from(b"\x10\x0c\x00\x06MQIsdp".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None);

        // Name itself is truncated.
        let mut buf = BytesMut::from(b"\x10\x0c\x00\x06MQIsd".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None);
    }

    /// A 4-byte protocol name other than "MQTT" is rejected.
    #[test]
    fn test_wrong_protocol_name() {
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTX\x04\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// A protocol level outside 3.1 / 3.1.1 / 5.0 is rejected.
    #[test]
    fn test_unknown_protocol_level() {
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTT\x07\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// The first packet must be a CONNECT; e.g. a PUBLISH header is refused.
    #[test]
    fn test_non_connect_packet() {
        let mut buf = BytesMut::from(b"\x30\x02\x00\x00".as_ref());
        assert!(matches!(
            VersionCodec.decode(&mut buf),
            Err(DecodeError::UnsupportedPacketType)
        ));
    }

    /// Detection only peeks: the bytes must remain for the version-specific codec.
    #[test]
    fn test_decode_does_not_consume() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        let before = buf.len();
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
        assert_eq!(buf.len(), before);
    }

    /// The detection codec refuses to encode and writes nothing.
    #[test]
    fn test_encode_unsupported() {
        let mut dst = BytesMut::new();
        assert!(matches!(
            VersionCodec.encode(ProtocolVersion::MQTT3, &mut dst),
            Err(EncodeError::UnsupportedVersion)
        ));
        assert!(dst.is_empty());
    }
}