ntex-mqtt 0.2.0-beta.2

MQTT Client/Server framework for v5 and v3.1.1 protocols
Documentation
use std::convert::TryInto;
use std::io;

use bytes::BytesMut;
use ntex_codec::{Decoder, Encoder};

use crate::error::{DecodeError, EncodeError};
use crate::types::{packet_type, MQTT, MQTT_LEVEL_3, MQTT_LEVEL_5};
use crate::utils;

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ProtocolVersion {
    MQTT3,
    MQTT5,
}

#[derive(Debug)]
pub struct VersionCodec;

impl Decoder for VersionCodec {
    type Item = ProtocolVersion;
    type Error = DecodeError;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
        let len = src.len();
        if len < 2 {
            return Ok(None);
        }

        let src_slice = src.as_ref();
        let first_byte = src_slice[0];
        match utils::decode_variable_length(&src_slice[1..])? {
            Some((_, mut consumed)) => {
                consumed += 1;

                if first_byte == packet_type::CONNECT {
                    if len <= consumed + 5 {
                        return Ok(None);
                    }

                    let len =
                        u16::from_be_bytes(src[consumed..consumed + 2].try_into().unwrap());
                    ensure!(
                        len == 4 && &src[consumed + 2..consumed + 6] == MQTT,
                        DecodeError::InvalidProtocol
                    );

                    match src[consumed + 6] {
                        MQTT_LEVEL_3 => Ok(Some(ProtocolVersion::MQTT3)),
                        MQTT_LEVEL_5 => Ok(Some(ProtocolVersion::MQTT5)),
                        _ => Err(DecodeError::InvalidProtocol),
                    }
                } else {
                    Err(DecodeError::UnsupportedPacketType)
                }
            }
            None => Ok(None),
        }
    }
}

impl Encoder for VersionCodec {
    type Item = ProtocolVersion;
    type Error = EncodeError;

    fn encode(&mut self, _: Self::Item, _: &mut BytesMut) -> Result<(), EncodeError> {
        Err(EncodeError::Io(io::Error::new(
            io::ErrorKind::Other,
            "Unsupported",
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_decode_connect_packets() {
        let mut buf = BytesMut::from(
            b"\x10\x7f\x7f\x00\x04MQTT\x06\xC0\x00\x3C\x00\x0512345\x00\x04user\x00\x04pass"
                .as_ref(),
        );
        assert_eq!(
            Err(DecodeError::InvalidProtocol),
            VersionCodec.decode(&mut buf)
        );

        let mut buf =
            BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x04\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(
            ProtocolVersion::MQTT3,
            VersionCodec.decode(&mut buf).unwrap().unwrap()
        );

        let mut buf =
            BytesMut::from(b"\x10\x98\x02\0\x04MQTT\x05\xc0\0\x0f\0\x02d1\0|testhub.".as_ref());
        assert_eq!(
            ProtocolVersion::MQTT5,
            VersionCodec.decode(&mut buf).unwrap().unwrap()
        );
    }
}