1use bytes::BytesMut;
2use tokio_util::codec::{Decoder, Encoder};
3
4use crate::error::{DecodeError, EncodeError};
5use crate::types::{packet_type, MQISDP, MQTT, MQTT_LEVEL_31, MQTT_LEVEL_311, MQTT_LEVEL_5};
6use crate::utils;
7
/// MQTT protocol generation announced by a client in its CONNECT packet.
///
/// Produced by [`VersionCodec`] when decoding the protocol level byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolVersion {
    /// Protocol levels `MQTT_LEVEL_31` and `MQTT_LEVEL_311` (MQTT 3.1 / 3.1.1).
    MQTT3,
    /// Protocol level `MQTT_LEVEL_5` (MQTT 5).
    MQTT5,
}
16
/// Stateless codec that inspects a CONNECT packet and reports which
/// [`ProtocolVersion`] the client announced.
///
/// Decoding only peeks at the input buffer (no bytes are consumed); encoding
/// is not supported and always fails.
#[derive(Debug)]
pub struct VersionCodec;
23
24impl Decoder for VersionCodec {
25 type Item = ProtocolVersion;
26 type Error = DecodeError;
27
28 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
37 let len = src.len();
38 if len < 2 {
39 return Ok(None);
40 }
41
42 let src_slice = src.as_ref();
43 let first_byte = src_slice[0];
44 match utils::decode_variable_length(&src_slice[1..])? {
45 Some((_, mut consumed)) => {
46 consumed += 1;
47
48 if first_byte == packet_type::CONNECT {
49 if len <= consumed + 6 {
50 return Ok(None);
51 }
52
53 let protocol_len = u16::from_be_bytes(
54 src[consumed..consumed + 2].try_into().map_err(|_| DecodeError::InvalidProtocol)?,
55 );
56
57 ensure!(
59 (protocol_len == 4 && &src[consumed + 2..consumed + 6] == MQTT)
60 || (protocol_len == 6 && &src[consumed + 2..consumed + 8] == MQISDP),
61 DecodeError::InvalidProtocol
62 );
63
64 match src[consumed + 2 + protocol_len as usize] {
66 MQTT_LEVEL_31 | MQTT_LEVEL_311 => Ok(Some(ProtocolVersion::MQTT3)),
67 MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
68 _ => Err(DecodeError::InvalidProtocol),
69 }
70 } else {
71 Err(DecodeError::UnsupportedPacketType)
72 }
73 }
74 None => Ok(None),
75 }
76 }
77}
78
79impl Encoder<ProtocolVersion> for VersionCodec {
80 type Error = EncodeError;
81
82 fn encode(&mut self, _: ProtocolVersion, _: &mut BytesMut) -> Result<(), Self::Error> {
87 Err(EncodeError::UnsupportedVersion)
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_invalid_protocol() {
        // The remaining length `\x7f` fits in one byte, so the two-byte protocol
        // name length is read from `\x7f\x00` (0x7f00), which is neither 4 nor 6.
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass".as_ref(),
        );
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    #[test]
    fn test_mqtt3_protocol_detection() {
        // Level byte 0x04 follows the `MQTT` name.
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT3));
    }

    #[test]
    fn test_mqtt5_protocol_detection() {
        // Level byte 0x05 follows the `MQTT` name.
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));
    }

    #[test]
    fn test_partial_packet_handling() {
        // The buffer ends right after the level byte: that is enough to decide.
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), Some(ProtocolVersion::MQTT5));

        // The protocol name has not arrived yet: ask for more data.
        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04".as_ref());
        assert_eq!(VersionCodec.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn test_unsupported_packet_type() {
        // 0x20 is not the CONNECT fixed-header byte (0x10) used by the fixtures above.
        let mut buf = BytesMut::from(b"\x20\x02\x00\x00".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::UnsupportedPacketType)));
    }

    #[test]
    fn test_unknown_protocol_name() {
        // Length prefix says 4, but the name is `MQTX` rather than `MQTT`.
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTX\x04\xc0\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }

    #[test]
    fn test_unknown_protocol_level() {
        // Valid `MQTT` name followed by an unrecognised level byte.
        let mut buf = BytesMut::from(b"\x10\x0a\x00\x04MQTT\xff\xc0\x00\x3c".as_ref());
        assert!(matches!(VersionCodec.decode(&mut buf), Err(DecodeError::InvalidProtocol)));
    }
}