ntex-mqtt 0.10.1

Client and Server framework for MQTT v5 and v3.1.1 protocols
Documentation
use std::convert::TryInto;

use ntex::codec::{Decoder, Encoder};
use ntex::util::BytesMut;

use crate::error::{DecodeError, EncodeError};
use crate::types::{packet_type, MQTT, MQTT_LEVEL_3, MQTT_LEVEL_5};
use crate::utils;

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum ProtocolVersion {
    MQTT3,
    MQTT5,
}

#[derive(Debug)]
pub(super) struct VersionCodec;

impl Decoder for VersionCodec {
    type Item = ProtocolVersion;
    type Error = DecodeError;

    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
        let len = src.len();
        if len < 2 {
            return Ok(None);
        }

        let src_slice = src.as_ref();
        let first_byte = src_slice[0];
        match utils::decode_variable_length(&src_slice[1..])? {
            Some((_, mut consumed)) => {
                consumed += 1;

                if first_byte == packet_type::CONNECT {
                    if len <= consumed + 6 {
                        return Ok(None);
                    }

                    let len =
                        u16::from_be_bytes(src[consumed..consumed + 2].try_into().unwrap());
                    ensure!(
                        len == 4 && &src[consumed + 2..consumed + 6] == MQTT,
                        DecodeError::InvalidProtocol
                    );

                    match src[consumed + 6] {
                        MQTT_LEVEL_3 => Ok(Some(ProtocolVersion::MQTT3)),
                        MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
                        _ => Err(DecodeError::InvalidProtocol),
                    }
                } else {
                    Err(DecodeError::UnsupportedPacketType)
                }
            }
            None => Ok(None),
        }
    }
}

impl Encoder for VersionCodec {
    type Item = ProtocolVersion;
    type Error = EncodeError;

    fn encode(&self, _: Self::Item, _: &mut BytesMut) -> Result<(), EncodeError> {
        Err(EncodeError::UnsupportedVersion)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_connect_packets() {
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass"
                .as_ref(),
        );
        assert_eq!(Err(DecodeError::InvalidProtocol), VersionCodec.decode(&mut buf));

        let mut buf =
            BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(ProtocolVersion::MQTT3, VersionCodec.decode(&mut buf).unwrap().unwrap());

        let mut buf =
            BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(ProtocolVersion::MQTT5, VersionCodec.decode(&mut buf).unwrap().unwrap());

        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05".as_ref());
        assert_eq!(ProtocolVersion::MQTT5, VersionCodec.decode(&mut buf).unwrap().unwrap());

        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04".as_ref());
        assert_eq!(None, VersionCodec.decode(&mut buf).unwrap());

        let mut buf = BytesMut::from(b"\x10\x98\x02\0\x04MQTT".as_ref());
        assert_eq!(None, VersionCodec.decode(&mut buf).unwrap());
    }
}