use alloc::vec::Vec;
use core::fmt;
use derive_builder::Builder;
#[cfg(feature = "std")]
use std::io::IoSlice;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use getset::{CopyGetters, Getters};
use crate::mqtt::packet::packet_type::{FixedHeader, PacketType};
use crate::mqtt::packet::property::PropertiesToContinuousBuffer;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
#[cfg(feature = "std")]
use crate::mqtt::packet::PropertiesToBuffers;
use crate::mqtt::packet::{Properties, PropertiesParse, PropertiesSize, Property};
use crate::mqtt::result_code::ConnectReasonCode;
use crate::mqtt::result_code::MqttError;
/// MQTT CONNACK packet, stored as the individual byte fields that make up its
/// encoded form plus the decoded property list.
///
/// Instances are created with [`Connack::builder`] or [`Connack::parse`].
#[derive(PartialEq, Eq, Builder, Clone, Getters, CopyGetters)]
#[builder(no_std, derive(Debug), pattern = "owned", setter(into), build_fn(skip))]
pub struct Connack {
/// First byte of the encoded packet; `parse` and `build` fill it from
/// `FixedHeader::Connack`.
#[builder(private)]
fixed_header: [u8; 1],
/// Encoded count of the bytes that follow this field (flags, reason code,
/// property length and properties).
#[builder(private)]
remaining_length: VariableByteInteger,
/// Acknowledge-flags byte; bit 0 is read back as the session-present flag.
#[builder(private)]
ack_flags: [u8; 1],
/// Raw reason-code byte, decoded on demand by `reason_code()`.
#[builder(private)]
reason_code_buf: [u8; 1],
/// Encoded byte size of `props`.
#[builder(private)]
property_length: VariableByteInteger,
/// Properties carried by the packet.
#[builder(setter(into, strip_option))]
#[getset(get = "pub")]
pub props: Properties,
}
impl Connack {
    /// Returns a default-initialised [`ConnackBuilder`].
    pub fn builder() -> ConnackBuilder {
        ConnackBuilder::default()
    }

    /// Returns the packet type of this packet, which is always
    /// [`PacketType::Connack`].
    pub fn packet_type() -> PacketType {
        PacketType::Connack
    }

    /// Returns `true` when bit 0 of the acknowledge-flags byte is set.
    pub fn session_present(&self) -> bool {
        let [flags] = self.ack_flags;
        (flags & 0x01) != 0
    }

    /// Decodes the stored reason-code byte.
    ///
    /// # Panics
    ///
    /// Panics if the stored byte is not a valid [`ConnectReasonCode`]. Both
    /// `parse` (which rejects unknown codes) and the builder (which only
    /// accepts a `ConnectReasonCode`) store bytes that decode successfully.
    pub fn reason_code(&self) -> ConnectReasonCode {
        let [raw] = self.reason_code_buf;
        ConnectReasonCode::try_from(raw).unwrap()
    }

    /// Total encoded size in bytes: the fixed-header byte, the encoded
    /// remaining-length field, and the number of bytes it announces.
    pub fn size(&self) -> usize {
        let length_field = self.remaining_length.size();
        let announced = self.remaining_length.to_u32() as usize;
        1 + length_field + announced
    }

    /// Returns the encoded packet as a list of borrowed I/O slices, in wire
    /// order, without copying the header fields.
    #[cfg(feature = "std")]
    pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
        let mut slices = Vec::from([
            IoSlice::new(&self.fixed_header),
            IoSlice::new(self.remaining_length.as_bytes()),
            IoSlice::new(&self.ack_flags),
            IoSlice::new(&self.reason_code_buf),
            IoSlice::new(self.property_length.as_bytes()),
        ]);
        slices.extend(self.props.to_buffers());
        slices
    }

    /// Returns the encoded packet as one contiguous byte vector.
    pub fn to_continuous_buffer(&self) -> Vec<u8> {
        let mut out = Vec::new();
        // Header part: fixed header, remaining length, flags, reason code.
        out.extend_from_slice(&self.fixed_header);
        out.extend_from_slice(self.remaining_length.as_bytes());
        out.extend_from_slice(&self.ack_flags);
        out.extend_from_slice(&self.reason_code_buf);
        // Property part: length prefix followed by the encoded properties.
        out.extend_from_slice(self.property_length.as_bytes());
        out.extend(self.props.to_continuous_buffer());
        out
    }

    /// Parses a CONNACK from `data`, which must start at the
    /// acknowledge-flags byte.
    ///
    /// Returns the packet together with the number of bytes consumed from
    /// `data`.
    ///
    /// # Errors
    ///
    /// * [`MqttError::MalformedPacket`] when `data` holds fewer than three
    ///   bytes or the reason-code byte is not a known [`ConnectReasonCode`].
    /// * Any error reported by `Properties::parse` or by the CONNACK property
    ///   validation.
    ///
    /// # Panics
    ///
    /// Panics if a computed length cannot be represented as a
    /// [`VariableByteInteger`].
    pub fn parse(data: &[u8]) -> Result<(Self, usize), MqttError> {
        // Flags byte + reason-code byte.
        const HEADER_LEN: usize = 2;
        // The header plus at least one byte of property length.
        const MIN_LEN: usize = 3;

        if data.len() < MIN_LEN {
            return Err(MqttError::MalformedPacket);
        }
        let (header, rest) = data.split_at(HEADER_LEN);
        let flags = header[0];
        let code = header[1];
        if ConnectReasonCode::try_from(code).is_err() {
            return Err(MqttError::MalformedPacket);
        }

        let (props, props_bytes) = Properties::parse(rest)?;
        validate_connack_properties(&props)?;
        let consumed = HEADER_LEN + props_bytes;

        let property_length = VariableByteInteger::from_u32(props.size() as u32).unwrap();
        let remaining_length = VariableByteInteger::from_u32(consumed as u32).unwrap();
        let packet = Connack {
            fixed_header: [FixedHeader::Connack.as_u8()],
            remaining_length,
            ack_flags: [flags],
            reason_code_buf: [code],
            property_length,
            props,
        };
        Ok((packet, consumed))
    }
}
impl ConnackBuilder {
    /// Sets the session-present flag; the flags byte becomes `1` for `true`
    /// and `0` for `false`.
    pub fn session_present(mut self, v: bool) -> Self {
        self.ack_flags = Some([u8::from(v)]);
        self
    }

    /// Sets the reason code, stored as its single-byte numeric value.
    pub fn reason_code(self, rc: ConnectReasonCode) -> Self {
        Self {
            reason_code_buf: Some([rc as u8]),
            ..self
        }
    }

    /// Checks that the mandatory fields were provided and that any supplied
    /// properties are acceptable for a CONNACK.
    fn validate(&self) -> Result<(), MqttError> {
        // Both the flags byte and the reason code are mandatory.
        if self.ack_flags.is_none() || self.reason_code_buf.is_none() {
            return Err(MqttError::MalformedPacket);
        }
        match self.props {
            Some(ref props) => validate_connack_properties(props),
            None => Ok(()),
        }
    }

    /// Validates the builder and assembles the [`Connack`], computing the
    /// property-length and remaining-length fields.
    ///
    /// # Errors
    ///
    /// * [`MqttError::MalformedPacket`] when the session-present flag or the
    ///   reason code was never set.
    /// * Any error reported by the CONNACK property validation.
    ///
    /// # Panics
    ///
    /// Panics if a computed length cannot be represented as a
    /// [`VariableByteInteger`].
    pub fn build(self) -> Result<Connack, MqttError> {
        self.validate()?;

        let props = self.props.unwrap_or_else(Properties::new);
        let props_len = props.size();
        let property_length = VariableByteInteger::from_u32(props_len as u32).unwrap();

        // Flags byte + reason-code byte + property-length field + properties.
        let remaining_total = 2 + property_length.size() + props_len;
        let remaining_length = VariableByteInteger::from_u32(remaining_total as u32).unwrap();

        Ok(Connack {
            fixed_header: [FixedHeader::Connack.as_u8()],
            remaining_length,
            ack_flags: self.ack_flags.unwrap_or_default(),
            reason_code_buf: self.reason_code_buf.unwrap_or_default(),
            property_length,
            props,
        })
    }
}
impl Serialize for Connack {
    /// Serializes the packet as a struct named `"Connack"` with the fields
    /// `type`, `session_present`, `reason_code` and, only when the property
    /// list is non-empty, `props`.
    ///
    /// The field count passed to `serialize_struct` always matches the number
    /// of fields actually written, which length-sensitive serializers rely on.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let has_props = !self.props.is_empty();
        let field_count = if has_props { 4 } else { 3 };

        let mut state = serializer.serialize_struct("Connack", field_count)?;
        state.serialize_field("type", PacketType::Connack.as_str())?;
        state.serialize_field("session_present", &self.session_present())?;
        state.serialize_field("reason_code", &self.reason_code())?;
        // Emit `props` only when it was counted above, so the declared length
        // and the written fields never disagree.
        if has_props {
            state.serialize_field("props", &self.props)?;
        }
        state.end()
    }
}
impl fmt::Display for Connack {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match serde_json::to_string(self) {
Ok(json) => write!(f, "{json}"),
Err(e) => write!(f, "{{\"error\": \"{e}\"}}"),
}
}
}
impl fmt::Debug for Connack {
    /// Debug output is identical to the `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl GenericPacketTrait for Connack {
fn size(&self) -> usize {
self.size()
}
#[cfg(feature = "std")]
fn to_buffers(&self) -> Vec<IoSlice<'_>> {
self.to_buffers()
}
fn to_continuous_buffer(&self) -> Vec<u8> {
self.to_continuous_buffer()
}
}
impl GenericPacketDisplay for Connack {
fn fmt_debug(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Debug::fmt(self, f)
}
fn fmt_display(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Display::fmt(self, f)
}
}
/// Checks that `props` is a valid property set for a CONNACK packet.
///
/// `UserProperty` may appear any number of times. Each of the other sixteen
/// accepted properties may appear at most once. Any other property kind is
/// rejected.
///
/// # Errors
///
/// Returns [`MqttError::ProtocolError`] when a property kind is not accepted
/// or when a single-occurrence property appears more than once.
fn validate_connack_properties(props: &[Property]) -> Result<(), MqttError> {
    // One bit per single-occurrence property; a set bit means "already seen".
    let mut seen: u32 = 0;

    for prop in props {
        let slot: u32 = match prop {
            // Repeatable, so it needs no bookkeeping.
            Property::UserProperty(_) => continue,
            Property::SessionExpiryInterval(_) => 0,
            Property::ReceiveMaximum(_) => 1,
            Property::MaximumQos(_) => 2,
            Property::RetainAvailable(_) => 3,
            Property::MaximumPacketSize(_) => 4,
            Property::AssignedClientIdentifier(_) => 5,
            Property::TopicAliasMaximum(_) => 6,
            Property::ReasonString(_) => 7,
            Property::WildcardSubscriptionAvailable(_) => 8,
            Property::SubscriptionIdentifierAvailable(_) => 9,
            Property::SharedSubscriptionAvailable(_) => 10,
            Property::ServerKeepAlive(_) => 11,
            Property::ResponseInformation(_) => 12,
            Property::ServerReference(_) => 13,
            Property::AuthenticationMethod(_) => 14,
            Property::AuthenticationData(_) => 15,
            _ => return Err(MqttError::ProtocolError),
        };

        let bit = 1u32 << slot;
        if seen & bit != 0 {
            // Second occurrence of a property that may appear only once.
            return Err(MqttError::ProtocolError);
        }
        seen |= bit;
    }

    Ok(())
}