use alloc::vec::Vec;
use core::fmt;
use derive_builder::Builder;
#[cfg(feature = "std")]
use std::io::IoSlice;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use getset::{CopyGetters, Getters};
use crate::mqtt::packet::packet_type::{FixedHeader, PacketType};
use crate::mqtt::packet::property::PropertiesToContinuousBuffer;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
#[cfg(feature = "std")]
use crate::mqtt::packet::PropertiesToBuffers;
use crate::mqtt::packet::{Properties, PropertiesParse, PropertiesSize, Property};
use crate::mqtt::result_code::DisconnectReasonCode;
use crate::mqtt::result_code::MqttError;
/// DISCONNECT packet with optional reason code and properties.
///
/// Besides the decoded properties, the struct keeps the already-encoded wire
/// pieces (fixed-header byte, Remaining Length, reason code byte, property
/// length) so `to_buffers` / `to_continuous_buffer` can emit them directly.
///
/// Construct with `Disconnect::builder()` or decode with `Disconnect::parse`.
#[derive(PartialEq, Eq, Builder, Clone, Getters, CopyGetters)]
#[builder(no_std, derive(Debug), pattern = "owned", setter(into), build_fn(skip))]
pub struct Disconnect {
    /// First byte of the fixed header; both constructors in this file set it to
    /// `FixedHeader::Disconnect.as_u8()`.
    #[builder(private)]
    fixed_header: [u8; 1],
    /// Remaining Length field: number of bytes that follow the fixed header.
    #[builder(private)]
    remaining_length: VariableByteInteger,
    /// Raw reason code byte; `None` when the packet omits the reason code.
    #[builder(private)]
    reason_code_buf: Option<[u8; 1]>,
    /// Encoded byte length of `props`; present exactly when `props` is present.
    #[builder(private)]
    property_length: Option<VariableByteInteger>,
    /// Optional DISCONNECT properties (checked by `validate_disconnect_properties`
    /// in both `parse` and the builder).
    #[builder(setter(into, strip_option))]
    #[getset(get = "pub")]
    pub props: Option<Properties>,
}
impl Disconnect {
pub fn builder() -> DisconnectBuilder {
DisconnectBuilder::default()
}
pub fn packet_type() -> PacketType {
PacketType::Disconnect
}
pub fn reason_code(&self) -> Option<DisconnectReasonCode> {
self.reason_code_buf
.as_ref()
.and_then(|buf| DisconnectReasonCode::try_from(buf[0]).ok())
}
pub fn size(&self) -> usize {
1 + self.remaining_length.size() + self.remaining_length.to_u32() as usize
}
#[cfg(feature = "std")]
pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
let mut bufs = Vec::new();
bufs.push(IoSlice::new(&self.fixed_header));
bufs.push(IoSlice::new(self.remaining_length.as_bytes()));
if let Some(buf) = &self.reason_code_buf {
bufs.push(IoSlice::new(buf));
}
if let Some(pl) = &self.property_length {
bufs.push(IoSlice::new(pl.as_bytes()));
}
if let Some(ref props) = self.props {
bufs.append(&mut props.to_buffers());
}
bufs
}
pub fn to_continuous_buffer(&self) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&self.fixed_header);
buf.extend_from_slice(self.remaining_length.as_bytes());
if let Some(rc_buf) = &self.reason_code_buf {
buf.extend_from_slice(rc_buf);
}
if let Some(pl) = &self.property_length {
buf.extend_from_slice(pl.as_bytes());
}
if let Some(ref props) = self.props {
buf.append(&mut props.to_continuous_buffer());
}
buf
}
pub fn parse(data: &[u8]) -> Result<(Self, usize), MqttError> {
let mut cursor = 0;
let reason_code_buf = if cursor < data.len() {
let rc = data[cursor];
let _ = DisconnectReasonCode::try_from(rc).map_err(|_| MqttError::MalformedPacket)?;
cursor += 1;
Some([rc])
} else {
None
};
let (property_length, props) = if reason_code_buf.is_some() && cursor < data.len() {
let (props, consumed) = Properties::parse(&data[cursor..])?;
cursor += consumed;
validate_disconnect_properties(&props)?;
let prop_len = VariableByteInteger::from_u32(props.size() as u32).unwrap();
(Some(prop_len), Some(props))
} else {
(None, None)
};
let remaining_size = reason_code_buf.as_ref().map_or(0, |_| 1)
+ property_length.as_ref().map_or(0, |pl| pl.size())
+ props.as_ref().map_or(0, |ps| ps.size());
let disconnect = Disconnect {
fixed_header: [FixedHeader::Disconnect.as_u8()],
remaining_length: VariableByteInteger::from_u32(remaining_size as u32).unwrap(),
reason_code_buf,
property_length,
props,
};
Ok((disconnect, cursor))
}
}
impl DisconnectBuilder {
pub fn reason_code(mut self, rc: DisconnectReasonCode) -> Self {
self.reason_code_buf = Some(Some([rc as u8]));
self
}
fn validate(&self) -> Result<(), MqttError> {
if self.reason_code_buf.is_none() && self.props.is_some() {
return Err(MqttError::MalformedPacket);
}
if let Some(inner_option) = &self.props {
let props = inner_option
.as_ref()
.expect("INTERNAL ERRORS: props was set with None value, this should never happen");
validate_disconnect_properties(props)?;
}
Ok(())
}
pub fn build(self) -> Result<Disconnect, MqttError> {
self.validate()?;
let reason_code_buf = self.reason_code_buf.flatten();
let props = self.props.flatten();
let props_size: usize = props.as_ref().map_or(0, |p| p.size());
let property_length = if props.is_some() {
Some(VariableByteInteger::from_u32(props_size as u32).unwrap())
} else {
None
};
let mut remaining = 0;
if reason_code_buf.is_some() {
remaining += 1;
}
if let Some(ref pl) = property_length {
remaining += pl.size() + props_size;
}
let remaining_length = VariableByteInteger::from_u32(remaining as u32).unwrap();
Ok(Disconnect {
fixed_header: [FixedHeader::Disconnect.as_u8()],
remaining_length,
reason_code_buf,
property_length,
props,
})
}
}
impl Serialize for Disconnect {
    /// Serializes the packet as a struct named `"disconnect"`.
    ///
    /// The struct always has a `type` field. It adds `reason_code` when a reason
    /// code byte is stored, and `props` when properties are present.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let has_reason_code = self.reason_code_buf.is_some();
        let field_count = 1 + usize::from(has_reason_code) + usize::from(self.props.is_some());

        let mut state = serializer.serialize_struct("disconnect", field_count)?;
        state.serialize_field("type", PacketType::Disconnect.as_str())?;
        if has_reason_code {
            // An unrecognized byte decodes to `None` and is serialized as such.
            state.serialize_field("reason_code", &self.reason_code())?;
        }
        if let Some(props) = &self.props {
            state.serialize_field("props", props)?;
        }
        state.end()
    }
}
impl fmt::Display for Disconnect {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match serde_json::to_string(self) {
Ok(json) => write!(f, "{json}"),
Err(e) => write!(f, "{{\"error\": \"{e}\"}}"),
}
}
}
impl fmt::Debug for Disconnect {
    /// Debug output is the same JSON text as the `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl GenericPacketTrait for Disconnect {
fn size(&self) -> usize {
self.size()
}
#[cfg(feature = "std")]
fn to_buffers(&self) -> Vec<IoSlice<'_>> {
self.to_buffers()
}
fn to_continuous_buffer(&self) -> Vec<u8> {
self.to_continuous_buffer()
}
}
impl GenericPacketDisplay for Disconnect {
    /// Writes the packet using its `Debug` implementation.
    fn fmt_debug(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Debug>::fmt(self, f)
    }

    /// Writes the packet using its `Display` implementation.
    fn fmt_display(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Display>::fmt(self, f)
    }
}
/// Checks that `props` only contains properties allowed in a DISCONNECT packet.
///
/// Allowed properties:
/// * Session Expiry Interval, Reason String and Server Reference, each at most once.
/// * User Property, any number of times.
///
/// # Errors
///
/// Returns `MqttError::ProtocolError` when a disallowed property appears, or
/// when a single-occurrence property is repeated.
fn validate_disconnect_properties(props: &[Property]) -> Result<(), MqttError> {
    let mut seen_session_expiry_interval = false;
    let mut seen_reason_string = false;
    let mut seen_server_reference = false;

    for prop in props {
        // Pick the "already seen" flag for this property kind; User Properties
        // are unlimited and need no tracking.
        let seen = match prop {
            Property::UserProperty(_) => continue,
            Property::SessionExpiryInterval(_) => &mut seen_session_expiry_interval,
            Property::ReasonString(_) => &mut seen_reason_string,
            Property::ServerReference(_) => &mut seen_server_reference,
            _ => return Err(MqttError::ProtocolError),
        };
        if *seen {
            return Err(MqttError::ProtocolError);
        }
        *seen = true;
    }

    Ok(())
}