use std::convert::{TryFrom, TryInto};
use crate::protocol::{len_len, property, PropertyType};
use bytes::{BufMut, Bytes, BytesMut};
use super::*;
/// An MQTT DISCONNECT control packet.
///
/// The MQTT 3.1.1 form (`V4`) is encoded as the fixed header only (`0xE0 0x00`);
/// the MQTT 5 form (`V5`) can additionally carry a reason code and properties.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Disconnect {
/// MQTT 3.1.1 DISCONNECT: no variable header and no payload.
V4,
/// MQTT 5 DISCONNECT.
V5 {
/// Reason code sent on the wire as a single byte.
reason_code: DisconnectReasonCode,
/// Optional DISCONNECT properties; `None` when no properties are present.
properties: Option<DisconnectProperties>,
},
}
impl Disconnect {
    /// Creates a DISCONNECT packet for `protocol`.
    ///
    /// For MQTT 5 the packet uses `NormalDisconnection` with no properties, which
    /// encodes to the two-byte short form `0xE0 0x00`.
    pub fn new(protocol: Protocol) -> Self {
        match protocol {
            Protocol::V4 => Self::V4,
            Protocol::V5 => Self::V5 {
                reason_code: DisconnectReasonCode::NormalDisconnection,
                properties: None,
            },
        }
    }

    /// Remaining length of the encoded packet: the number of bytes that follow the
    /// fixed header.
    ///
    /// * MQTT 3.1.1, or MQTT 5 with `NormalDisconnection` and no properties: `0`
    ///   (reason code and property length are both omitted).
    /// * MQTT 5 with another reason code and no properties: reason code byte plus
    ///   an encoded zero property length.
    /// * MQTT 5 with properties: reason code byte + property-length prefix +
    ///   encoded properties.
    fn len(&self) -> usize {
        match self {
            Disconnect::V4
            | Disconnect::V5 {
                reason_code: DisconnectReasonCode::NormalDisconnection,
                properties: None,
            } => 0,
            // An explicit zero property length is written (see `write`), so it
            // must be counted here; otherwise the header under-reports the body.
            Disconnect::V5 {
                properties: None, ..
            } => 1 + len_len(0),
            Disconnect::V5 {
                properties: Some(properties),
                ..
            } => {
                let properties_len = properties.len();
                1 + len_len(properties_len) + properties_len
            }
        }
    }

    /// Parses a DISCONNECT packet.
    ///
    /// `bytes` must begin with the fixed header described by `fixed_header`; the
    /// fixed header is skipped before the variable header is decoded.
    ///
    /// # Errors
    /// * `PacketParseError::InvalidPacketType` if the fixed header is not a DISCONNECT.
    /// * `PacketParseError::MalformedPacket` if any of the low four flag bits is set.
    /// * Any error from reading the reason code byte, converting it to a
    ///   [`DisconnectReasonCode`], or decoding the properties (MQTT 5 only).
    pub fn read(
        fixed_header: FixedHeader,
        mut bytes: Bytes,
        protocol: Protocol,
    ) -> Result<Self, PacketParseError> {
        let packet_type = fixed_header.byte1 >> 4;
        let flags = fixed_header.byte1 & 0b0000_1111;
        if packet_type != PacketType::Disconnect as u8 {
            return Err(PacketParseError::InvalidPacketType(packet_type));
        }
        if flags != 0x00 {
            return Err(PacketParseError::MalformedPacket);
        }
        bytes.advance(fixed_header.fixed_header_len);

        let disconnect = match protocol {
            Protocol::V4 => Self::V4,
            Protocol::V5 => match fixed_header.remaining_len {
                // Short form: reason code and properties are both omitted.
                0 => Self::V5 {
                    reason_code: DisconnectReasonCode::NormalDisconnection,
                    properties: None,
                },
                remaining_len => {
                    let reason_code = read_u8(&mut bytes)?;
                    Self::V5 {
                        reason_code: reason_code.try_into()?,
                        // MQTT 5 allows the property length to be omitted when the
                        // remaining length is below 2; treat that as "no properties"
                        // instead of failing on an empty buffer.
                        properties: if remaining_len < 2 {
                            None
                        } else {
                            DisconnectProperties::extract(&mut bytes)?
                        },
                    }
                }
            },
        };
        Ok(disconnect)
    }

    /// Encodes the packet into a freshly allocated, immutable buffer.
    pub fn data(&self) -> Bytes {
        let mut buffer = BytesMut::new();
        self.write(&mut buffer);
        buffer.freeze()
    }

    /// Appends the encoded packet to `buffer`.
    ///
    /// Returns the total number of bytes written, fixed header included.
    pub fn write(&self, buffer: &mut BytesMut) -> usize {
        buffer.put_u8(0xE0);
        let remaining_len = self.len();
        match self {
            Disconnect::V5 {
                reason_code,
                properties,
            } if remaining_len > 0 => {
                let len_len = write_remaining_length(buffer, remaining_len);
                buffer.put_u8(*reason_code as u8);
                match properties {
                    Some(properties) => properties.write(buffer),
                    // Explicit zero property length; counted by `len()`.
                    None => {
                        write_remaining_length(buffer, 0);
                    }
                }
                1 + len_len + remaining_len
            }
            // MQTT 3.1.1 and the MQTT 5 short form both encode as `0xE0 0x00`.
            // The short form is chosen by pattern (remaining length 0), not by a
            // length sentinel, so packets with empty properties keep their reason code.
            Disconnect::V4 | Disconnect::V5 { .. } => {
                buffer.put_u8(0x00);
                2
            }
        }
    }
}
impl TryFrom<u8> for DisconnectReasonCode {
    type Error = PacketParseError;

    /// Decodes a DISCONNECT reason code from its wire byte.
    ///
    /// Every supported code is listed in `known`. A byte is accepted when it equals
    /// the `#[repr(u8)]` discriminant of one of them, which is the same value the
    /// encoder emits via `as u8`.
    ///
    /// # Errors
    /// Returns `PacketParseError::InvalidConnectReturnCode(value)` for any other byte.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let known = [
            Self::NormalDisconnection,
            Self::DisconnectWithWillMessage,
            Self::UnspecifiedError,
            Self::MalformedPacket,
            Self::ProtocolError,
            Self::ImplementationSpecificError,
            Self::NotAuthorized,
            Self::ServerBusy,
            Self::ServerShuttingDown,
            Self::KeepAliveTimeout,
            Self::SessionTakenOver,
            Self::TopicFilterInvalid,
            Self::TopicNameInvalid,
            Self::ReceiveMaximumExceeded,
            Self::TopicAliasInvalid,
            Self::PacketTooLarge,
            Self::MessageRateTooHigh,
            Self::QuotaExceeded,
            Self::AdministrativeAction,
            Self::PayloadFormatInvalid,
            Self::RetainNotSupported,
            Self::QoSNotSupported,
            Self::UseAnotherServer,
            Self::ServerMoved,
            Self::SharedSubscriptionNotSupported,
            Self::ConnectionRateExceeded,
            Self::MaximumConnectTime,
            Self::SubscriptionIdentifiersNotSupported,
            Self::WildcardSubscriptionsNotSupported,
        ];
        known
            .iter()
            .copied()
            .find(|code| *code as u8 == value)
            .ok_or_else(|| PacketParseError::InvalidConnectReturnCode(value))
    }
}
impl DisconnectProperties {
    /// Encoded size of the properties in bytes, excluding the variable-length
    /// property-length prefix that precedes them on the wire.
    fn len(&self) -> usize {
        // Every property starts with a one-byte identifier; each string carries a
        // two-byte length prefix.
        let string_property = |s: &String| 1 + 2 + s.len();

        let session_expiry = self.session_expiry_interval.map_or(0, |_| 1 + 4);
        let reason = self.reason_string.as_ref().map_or(0, string_property);
        let user: usize = self
            .user_properties
            .iter()
            .map(|(key, value)| 1 + 2 + key.len() + 2 + value.len())
            .sum();
        let reference = self.server_reference.as_ref().map_or(0, string_property);

        session_expiry + reason + user + reference
    }

    /// Decodes the property section from the front of `bytes`, advancing past it.
    ///
    /// Returns `Ok(None)` when the declared property length is zero.
    ///
    /// # Errors
    /// Propagates errors from the length/primitive readers and from `property`.
    /// Returns `PacketParseError::InvalidPropertyType` for a property identifier
    /// that is not allowed in DISCONNECT.
    pub fn extract(bytes: &mut Bytes) -> Result<Option<Self>, PacketParseError> {
        let (prefix_len, declared_len) = length(bytes.iter())?;
        bytes.advance(prefix_len);
        if declared_len == 0 {
            return Ok(None);
        }

        let mut extracted = Self {
            session_expiry_interval: None,
            reason_string: None,
            user_properties: Vec::new(),
            server_reference: None,
        };

        // Count consumed property bytes (identifier + body) so decoding stops at
        // the declared property length.
        let mut consumed = 0;
        while consumed < declared_len {
            let id = read_u8(bytes)?;
            let body_len = match property(id)? {
                PropertyType::SessionExpiryInterval => {
                    extracted.session_expiry_interval = Some(read_u32(bytes)?);
                    4
                }
                PropertyType::ReasonString => {
                    let reason = read_mqtt_string(bytes)?;
                    let size = 2 + reason.len();
                    extracted.reason_string = Some(reason);
                    size
                }
                PropertyType::UserProperty => {
                    let key = read_mqtt_string(bytes)?;
                    let value = read_mqtt_string(bytes)?;
                    let size = 2 + key.len() + 2 + value.len();
                    extracted.user_properties.push((key, value));
                    size
                }
                PropertyType::ServerReference => {
                    let reference = read_mqtt_string(bytes)?;
                    let size = 2 + reference.len();
                    extracted.server_reference = Some(reference);
                    size
                }
                _ => return Err(PacketParseError::InvalidPropertyType(id)),
            };
            consumed += 1 + body_len;
        }

        Ok(Some(extracted))
    }

    /// Appends the property-length prefix followed by every present property.
    ///
    /// Properties are emitted in this order: session expiry interval, reason string,
    /// user properties, then server reference.
    fn write(&self, buffer: &mut BytesMut) {
        write_remaining_length(buffer, self.len());

        if let Some(interval) = self.session_expiry_interval {
            buffer.put_u8(PropertyType::SessionExpiryInterval as u8);
            buffer.put_u32(interval);
        }

        if let Some(reason) = self.reason_string.as_ref() {
            buffer.put_u8(PropertyType::ReasonString as u8);
            write_mqtt_string(buffer, reason);
        }

        for (key, value) in &self.user_properties {
            buffer.put_u8(PropertyType::UserProperty as u8);
            write_mqtt_string(buffer, key);
            write_mqtt_string(buffer, value);
        }

        if let Some(reference) = self.server_reference.as_ref() {
            buffer.put_u8(PropertyType::ServerReference as u8);
            write_mqtt_string(buffer, reference);
        }
    }
}
/// Optional properties of an MQTT 5 DISCONNECT packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DisconnectProperties {
/// Session Expiry Interval property, encoded as a 4-byte integer
/// (presumably seconds, per the MQTT 5 spec — confirm with callers).
pub session_expiry_interval: Option<u32>,
/// Reason String property, encoded as a length-prefixed string.
pub reason_string: Option<String>,
/// User Property key/value pairs; may repeat and are kept in wire order.
pub user_properties: Vec<(String, String)>,
/// Server Reference property, encoded as a length-prefixed string.
pub server_reference: Option<String>,
}
/// Reason code carried by an MQTT 5 DISCONNECT packet.
///
/// Each explicit discriminant is the byte used on the wire: the encoder writes it
/// with `as u8`, and `TryFrom<u8>` maps the same byte values back to variants.
/// Keep the two in sync when adding codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReasonCode {
NormalDisconnection = 0x00,
DisconnectWithWillMessage = 0x04,
UnspecifiedError = 0x80,
MalformedPacket = 0x81,
ProtocolError = 0x82,
ImplementationSpecificError = 0x83,
NotAuthorized = 0x87,
ServerBusy = 0x89,
ServerShuttingDown = 0x8B,
KeepAliveTimeout = 0x8D,
SessionTakenOver = 0x8E,
TopicFilterInvalid = 0x8F,
TopicNameInvalid = 0x90,
ReceiveMaximumExceeded = 0x93,
TopicAliasInvalid = 0x94,
PacketTooLarge = 0x95,
MessageRateTooHigh = 0x96,
QuotaExceeded = 0x97,
AdministrativeAction = 0x98,
PayloadFormatInvalid = 0x99,
RetainNotSupported = 0x9A,
QoSNotSupported = 0x9B,
UseAnotherServer = 0x9C,
ServerMoved = 0x9D,
SharedSubscriptionNotSupported = 0x9E,
ConnectionRateExceeded = 0x9F,
MaximumConnectTime = 0xA0,
SubscriptionIdentifiersNotSupported = 0xA1,
WildcardSubscriptionsNotSupported = 0xA2,
}