1use bytes::BytesMut;
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::{DecodeError, EncodeError};
5use crate::types::{packet_type, MQISDP, MQTT, MQTT_LEVEL_31, MQTT_LEVEL_311, MQTT_LEVEL_5};
6use crate::utils;
7
/// Protocol family announced by a client in the level byte of its CONNECT packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolVersion {
    /// Reported for the `MQTT_LEVEL_31` and `MQTT_LEVEL_311` protocol levels.
    MQTT3,
    /// Reported for the `MQTT_LEVEL_5` protocol level.
    MQTT5,
}
16
/// Stateless codec that peeks at the first packet of a connection to find out
/// which MQTT protocol version the client speaks.
///
/// Decoding yields a [`ProtocolVersion`]; encoding is not supported.
#[derive(Debug)]
pub struct VersionCodec;
23
24impl Decoder for VersionCodec {
25 type Item = ProtocolVersion;
26 type Error = DecodeError;
27
28 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
37 let len = src.len();
38 if len < 2 {
39 return Ok(None);
40 }
41
42 let src_slice = src.as_ref();
43 let first_byte = src_slice[0];
44 match utils::decode_variable_length(&src_slice[1..])? {
45 Some((_, mut consumed)) => {
46 consumed += 1;
47
48 if first_byte == packet_type::CONNECT {
49 if len <= consumed + 6 {
50 return Ok(None);
51 }
52
53 let protocol_len = u16::from_be_bytes(
54 src[consumed..consumed + 2].try_into().map_err(|_| DecodeError::InvalidProtocol)?,
55 );
56
57 if protocol_len == 4 {
59 if &src[consumed + 2..consumed + 6] != MQTT {
61 return Err(DecodeError::InvalidProtocol);
62 }
63 } else if protocol_len == 6 {
64 if len <= consumed + 8 {
66 return Ok(None);
67 }
68 if &src[consumed + 2..consumed + 8] != MQISDP {
69 return Err(DecodeError::InvalidProtocol);
70 }
71 } else {
72 return Err(DecodeError::InvalidProtocol);
73 }
74
75 match src[consumed + 2 + protocol_len as usize] {
77 MQTT_LEVEL_31 | MQTT_LEVEL_311 => Ok(Some(ProtocolVersion::MQTT3)),
78 MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
79 _ => Err(DecodeError::InvalidProtocol),
80 }
81 } else {
82 Err(DecodeError::UnsupportedPacketType)
83 }
84 }
85 None => Ok(None),
86 }
87 }
88}
89
90impl Encoder<ProtocolVersion> for VersionCodec {
91 type Error = EncodeError;
92
93 fn encode(&mut self, _: ProtocolVersion, _: &mut BytesMut) -> Result<(), Self::Error> {
98 Err(EncodeError::UnsupportedVersion)
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// A protocol-length field that cannot name a known protocol is rejected.
    #[test]
    fn test_invalid_protocol() {
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass".as_ref(),
        );
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// Protocol names of the right length but wrong content, or of an unknown
    /// length, and unknown protocol levels are all rejected.
    #[test]
    fn test_unknown_protocol_name_and_level() {
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTX\x04\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));

        let mut buf = BytesMut::from(b"\x10\x0b\x00\x05MQTTX\x04\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));

        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTT\x07\x02\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    /// A first packet that is not CONNECT (here PINGREQ) is refused.
    #[test]
    fn test_non_connect_packet_rejected() {
        let mut buf = BytesMut::from(b"\xc0\x00".as_ref());
        assert!(matches!(
            VersionCodec.decode(&mut buf),
            Err(DecodeError::UnsupportedPacketType)
        ));
    }

    /// Protocol level 4 under the "MQTT" name maps to MQTT3.
    #[test]
    fn test_mqtt3_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// MQTT 3.1 clients announce the 6-byte "MQIsdp" name with level 3.
    #[test]
    fn test_mqtt31_mqisdp_detection() {
        let mut buf = BytesMut::from(b"\x10\x0c\x00\x06MQIsdp\x03\x02\x00\x3c".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// Protocol level 5 maps to MQTT5.
    #[test]
    fn test_mqtt5_protocol_detection() {
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
    }

    /// Every strict prefix of a header asks for more data; the full header
    /// (ending at the level byte) is enough to decide.
    #[test]
    fn test_partial_packet_handling() {
        let header: &[u8] = b"\x10\x98\x02\0\x04MQTT\x05";
        for end in 0..header.len() {
            let mut buf = BytesMut::from(&header[..end]);
            assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None, "prefix of {} bytes", end);
        }
        let mut buf = BytesMut::from(header);
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));

        let mqisdp: &[u8] = b"\x10\x0c\x00\x06MQIsdp\x03";
        for end in 0..mqisdp.len() {
            let mut buf = BytesMut::from(&mqisdp[..end]);
            assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None, "prefix of {} bytes", end);
        }
        let mut buf = BytesMut::from(mqisdp);
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    /// Detection only peeks: the buffer content is left untouched.
    #[test]
    fn test_buffer_is_not_consumed() {
        let packet: &[u8] = b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1";
        let mut buf = BytesMut::from(packet);
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
        assert_eq!(&buf[..], packet);
    }
}