use super::{Error, FixedHeader, len_len, read_u8, write_remaining_length};
use bytes::{Buf, BufMut, Bytes, BytesMut};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnectReturnCode {
Success = 0,
RefusedProtocolVersion,
BadClientId,
ServiceUnavailable,
BadUserNamePassword,
NotAuthorized,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnAck {
pub session_present: bool,
pub code: ConnectReturnCode,
}
impl ConnAck {
#[must_use]
pub const fn new(code: ConnectReturnCode, session_present: bool) -> Self {
Self {
session_present,
code,
}
}
const fn len() -> usize {
1 + 1
}
pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
if fixed_header.remaining_len != Self::len() {
return Err(Error::PayloadSizeIncorrect);
}
let variable_header_index = fixed_header.header_len;
bytes.advance(variable_header_index);
let flags = read_u8(&mut bytes)?;
let return_code = read_u8(&mut bytes)?;
if (flags & 0xFE) != 0 {
return Err(Error::IncorrectPacketFormat);
}
let session_present = (flags & 0x01) == 1;
let code = connect_return(return_code)?;
if code != ConnectReturnCode::Success && session_present {
return Err(Error::IncorrectPacketFormat);
}
let connack = Self {
session_present,
code,
};
Ok(connack)
}
pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
let len = Self::len();
buffer.put_u8(0x20);
let count = write_remaining_length(buffer, len)?;
buffer.put_u8(u8::from(self.session_present));
buffer.put_u8(self.code as u8);
Ok(1 + count + len)
}
#[must_use]
pub const fn size(&self) -> usize {
let len = Self::len();
let remaining_len_size = len_len(len);
1 + remaining_len_size + len
}
}
const fn connect_return(num: u8) -> Result<ConnectReturnCode, Error> {
match num {
0 => Ok(ConnectReturnCode::Success),
1 => Ok(ConnectReturnCode::RefusedProtocolVersion),
2 => Ok(ConnectReturnCode::BadClientId),
3 => Ok(ConnectReturnCode::ServiceUnavailable),
4 => Ok(ConnectReturnCode::BadUserNamePassword),
5 => Ok(ConnectReturnCode::NotAuthorized),
num => Err(Error::InvalidConnectReturnCode(num)),
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::mqttbytes::parse_fixed_header;
use bytes::BytesMut;
use pretty_assertions::assert_eq;
#[test]
fn connack_parsing_works() {
let mut stream = bytes::BytesMut::new();
let packetstream = &[
0b0010_0000,
0x02, 0x01,
0x00, 0xDE,
0xAD,
0xBE,
0xEF, ];
stream.extend_from_slice(&packetstream[..]);
let fixed_header = parse_fixed_header(stream.iter()).unwrap();
let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze();
let connack = ConnAck::read(fixed_header, connack_bytes).unwrap();
assert_eq!(
connack,
ConnAck {
session_present: true,
code: ConnectReturnCode::Success,
}
);
}
#[test]
fn connack_encoding_works() {
let connack = ConnAck {
session_present: true,
code: ConnectReturnCode::Success,
};
let mut buf = BytesMut::new();
connack.write(&mut buf).unwrap();
assert_eq!(buf, vec![0b0010_0000, 0x02, 0x01, 0x00]);
}
#[test]
fn connack_parsing_rejects_invalid_remaining_len() {
let mut stream = bytes::BytesMut::new();
let packetstream = &[0b0010_0000, 0x03, 0x00, 0x00, 0x00];
stream.extend_from_slice(packetstream);
let fixed_header = parse_fixed_header(stream.iter()).unwrap();
let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze();
let connack = ConnAck::read(fixed_header, connack_bytes);
assert!(matches!(connack, Err(Error::PayloadSizeIncorrect)));
}
#[test]
fn connack_parsing_rejects_reserved_flag_bits() {
let mut stream = bytes::BytesMut::new();
let packetstream = &[0b0010_0000, 0x02, 0x02, 0x00];
stream.extend_from_slice(packetstream);
let fixed_header = parse_fixed_header(stream.iter()).unwrap();
let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze();
let connack = ConnAck::read(fixed_header, connack_bytes);
assert!(matches!(connack, Err(Error::IncorrectPacketFormat)));
}
#[test]
fn connack_parsing_rejects_session_present_on_error() {
let mut stream = bytes::BytesMut::new();
let packetstream = &[0b0010_0000, 0x02, 0x01, 0x01];
stream.extend_from_slice(packetstream);
let fixed_header = parse_fixed_header(stream.iter()).unwrap();
let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze();
let connack = ConnAck::read(fixed_header, connack_bytes);
assert!(matches!(connack, Err(Error::IncorrectPacketFormat)));
}
}