use alloc::vec::Vec;
use core::fmt;
use core::mem;
use derive_builder::Builder;
#[cfg(feature = "std")]
use std::io::IoSlice;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use getset::{CopyGetters, Getters};
use crate::mqtt::packet::packet_type::{FixedHeader, PacketType};
use crate::mqtt::packet::property::PropertiesToContinuousBuffer;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
use crate::mqtt::packet::IsPacketId;
#[cfg(feature = "std")]
use crate::mqtt::packet::PropertiesToBuffers;
use crate::mqtt::packet::{Properties, PropertiesParse, PropertiesSize, Property};
use crate::mqtt::result_code::MqttError;
use crate::mqtt::result_code::SubackReasonCode;
/// MQTT SUBACK packet, generic over the packet identifier type.
///
/// Every field except `props` is kept in its encoded (wire) form so the packet
/// can be written out by `to_buffers` / `to_continuous_buffer` without
/// re-encoding. Instances are produced by `parse` or by the builder returned
/// from `builder()`. The derived builder's `build` is skipped (`build_fn(skip)`)
/// in favour of a hand-written `build` that validates the fields and computes
/// the length fields.
#[derive(PartialEq, Eq, Builder, Clone, Getters, CopyGetters)]
#[builder(no_std, derive(Debug), pattern = "owned", setter(into), build_fn(skip))]
pub struct GenericSuback<PacketIdType>
where
PacketIdType: IsPacketId,
{
/// Fixed header byte; set to `FixedHeader::Suback as u8` by `parse` and `build`.
#[builder(private)]
fixed_header: [u8; 1],
/// Remaining Length: byte count of packet id + property length field +
/// properties + reason codes.
#[builder(private)]
remaining_length: VariableByteInteger,
/// Packet identifier in encoded form; decode it with `packet_id()`.
#[builder(private)]
packet_id_buf: PacketIdType::Buffer,
/// Encoded length of `props` (the value of `props.size()`).
#[builder(private)]
property_length: VariableByteInteger,
/// SUBACK properties. `parse` and `build` accept only Reason String (at most
/// once) and User Property.
#[builder(setter(into, strip_option))]
#[getset(get = "pub")]
pub props: Properties,
/// Raw reason-code bytes in payload order; `parse` and `build` reject an empty list.
#[builder(private)]
reason_codes_buf: Vec<u8>,
}
/// SUBACK packet with a 16-bit (`u16`) packet identifier.
pub type Suback = GenericSuback<u16>;
impl<PacketIdType> GenericSuback<PacketIdType>
where
    PacketIdType: IsPacketId,
{
    /// Returns a new, empty builder for a SUBACK packet.
    pub fn builder() -> GenericSubackBuilder<PacketIdType> {
        GenericSubackBuilder::<PacketIdType>::default()
    }

    /// Returns the packet type of this packet, which is always [`PacketType::Suback`].
    pub fn packet_type() -> PacketType {
        PacketType::Suback
    }

    /// Returns the packet identifier decoded from the stored buffer.
    pub fn packet_id(&self) -> PacketIdType {
        PacketIdType::from_buffer(self.packet_id_buf.as_ref())
    }

    /// Returns the reason codes in payload order.
    ///
    /// A stored byte that does not map to a [`SubackReasonCode`] is reported as
    /// [`SubackReasonCode::UnspecifiedError`]. `parse` rejects unknown bytes and
    /// the builder stores only `SubackReasonCode` values, so this fallback is
    /// expected to be purely defensive.
    pub fn reason_codes(&self) -> Vec<SubackReasonCode> {
        self.reason_codes_buf
            .iter()
            .map(|&byte| {
                SubackReasonCode::try_from(byte).unwrap_or(SubackReasonCode::UnspecifiedError)
            })
            .collect()
    }

    /// Parses the variable header and payload of a SUBACK packet.
    ///
    /// `data` starts with the packet identifier. It continues with the
    /// properties (length-prefixed) and then the reason codes. Every byte up to
    /// the end of `data` is treated as a reason code.
    ///
    /// On success, returns the packet and the number of bytes consumed, which is
    /// `data.len()`.
    ///
    /// # Errors
    ///
    /// Returns `MqttError::MalformedPacket` when:
    /// - `data` is shorter than a packet identifier;
    /// - a reason-code byte is not a valid [`SubackReasonCode`];
    /// - the total length cannot be encoded in the Remaining Length field
    ///   (a Variable Byte Integer).
    ///
    /// Returns `MqttError::ProtocolError` when:
    /// - a property other than Reason String or User Property is present;
    /// - Reason String appears more than once;
    /// - the payload contains no reason codes.
    ///
    /// Errors from `Properties::parse` are propagated as-is.
    pub fn parse(data: &[u8]) -> Result<(Self, usize), MqttError> {
        // Largest value representable by a 4-byte MQTT Variable Byte Integer.
        const MAX_VARIABLE_BYTE_INTEGER: usize = 268_435_455;

        let mut cursor = 0;
        let buffer_size = mem::size_of::<<PacketIdType as IsPacketId>::Buffer>();
        if data.len() < buffer_size {
            return Err(MqttError::MalformedPacket);
        }
        let packet_id = PacketIdType::from_buffer(&data[..buffer_size]);
        let packet_id_buf = packet_id.to_buffer();
        cursor += buffer_size;

        let (props, property_length) = Properties::parse(&data[cursor..])?;
        cursor += property_length;
        validate_suback_properties(&props)?;

        // Everything after the properties is the reason-code payload.
        let payload = &data[cursor..];
        for &byte in payload {
            SubackReasonCode::try_from(byte).map_err(|_| MqttError::MalformedPacket)?;
        }
        if payload.is_empty() {
            return Err(MqttError::ProtocolError);
        }
        let reason_codes_buf = payload.to_vec();
        cursor += payload.len();

        let remaining_size = buffer_size + property_length + reason_codes_buf.len();
        // `data` is untrusted: reject lengths the Remaining Length field cannot
        // represent instead of panicking in `from_u32(..).unwrap()` or silently
        // truncating via `as u32`.
        if remaining_size > MAX_VARIABLE_BYTE_INTEGER {
            return Err(MqttError::MalformedPacket);
        }
        let remaining_length = VariableByteInteger::from_u32(remaining_size as u32).unwrap();
        // The property content is a sub-range of `remaining_size`, which was
        // bounded above, so this conversion is also in range.
        let prop_len = VariableByteInteger::from_u32(props.size() as u32).unwrap();

        let suback = GenericSuback {
            fixed_header: [FixedHeader::Suback as u8],
            remaining_length,
            packet_id_buf,
            property_length: prop_len,
            props,
            reason_codes_buf,
        };
        Ok((suback, cursor))
    }

    /// Total encoded size in bytes: fixed header byte + Remaining Length field +
    /// remaining bytes.
    pub fn size(&self) -> usize {
        1 + self.remaining_length.size() + self.remaining_length.to_u32() as usize
    }

    /// Returns the encoded packet as a list of borrowed slices, for vectored writes.
    ///
    /// Slices are in wire order: fixed header, remaining length, packet id,
    /// property length, properties, reason codes.
    #[cfg(feature = "std")]
    pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
        let mut bufs = Vec::new();
        bufs.push(IoSlice::new(&self.fixed_header));
        bufs.push(IoSlice::new(self.remaining_length.as_bytes()));
        bufs.push(IoSlice::new(self.packet_id_buf.as_ref()));
        bufs.push(IoSlice::new(self.property_length.as_bytes()));
        bufs.extend(self.props.to_buffers());
        if !self.reason_codes_buf.is_empty() {
            bufs.push(IoSlice::new(&self.reason_codes_buf));
        }
        bufs
    }

    /// Returns the encoded packet as a single contiguous byte vector.
    ///
    /// Byte order is the same as `to_buffers`.
    pub fn to_continuous_buffer(&self) -> Vec<u8> {
        // `size()` is the exact encoded length, so this avoids regrowth.
        let mut buf = Vec::with_capacity(self.size());
        buf.extend_from_slice(&self.fixed_header);
        buf.extend_from_slice(self.remaining_length.as_bytes());
        buf.extend_from_slice(self.packet_id_buf.as_ref());
        buf.extend_from_slice(self.property_length.as_bytes());
        buf.append(&mut self.props.to_continuous_buffer());
        if !self.reason_codes_buf.is_empty() {
            buf.extend_from_slice(&self.reason_codes_buf);
        }
        buf
    }
}
impl<PacketIdType> GenericSubackBuilder<PacketIdType>
where
    PacketIdType: IsPacketId,
{
    /// Sets the packet identifier. An all-zero identifier is rejected by `build`.
    pub fn packet_id(mut self, id: PacketIdType) -> Self {
        self.packet_id_buf = Some(id.to_buffer());
        self
    }

    /// Sets the reason codes, in payload order. An empty list is rejected by `build`.
    pub fn reason_codes(mut self, codes: Vec<SubackReasonCode>) -> Self {
        self.reason_codes_buf = Some(codes.into_iter().map(|rc| rc as u8).collect());
        self
    }

    /// Checks the builder state before `build` assembles the packet.
    ///
    /// Returns `MalformedPacket` when the packet id is missing or all zero.
    /// Returns `ProtocolError` when the reason codes are missing or empty, or
    /// when the properties fail `validate_suback_properties`.
    fn validate(&self) -> Result<(), MqttError> {
        let packet_id_buf = match self.packet_id_buf.as_ref() {
            Some(buf) => buf,
            None => return Err(MqttError::MalformedPacket),
        };
        if packet_id_buf.as_ref().iter().all(|&b| b == 0) {
            return Err(MqttError::MalformedPacket);
        }
        if self
            .reason_codes_buf
            .as_ref()
            .map_or(true, |r| r.is_empty())
        {
            return Err(MqttError::ProtocolError);
        }
        if let Some(ref props) = self.props {
            validate_suback_properties(props)?;
        }
        Ok(())
    }

    /// Validates the builder state and assembles the packet, computing the
    /// property length and remaining length fields.
    ///
    /// # Errors
    ///
    /// - Any error from `validate`.
    /// - `MqttError::MalformedPacket` if the properties or the total remaining
    ///   length exceed what an MQTT Variable Byte Integer can encode.
    pub fn build(self) -> Result<GenericSuback<PacketIdType>, MqttError> {
        // Largest value representable by a 4-byte MQTT Variable Byte Integer.
        const MAX_VARIABLE_BYTE_INTEGER: usize = 268_435_455;

        self.validate()?;
        let packet_id_buf = self
            .packet_id_buf
            .expect("packet id presence is checked by validate()");
        let reason_codes_buf = self.reason_codes_buf.unwrap_or_default();
        let props = self.props.unwrap_or_else(Properties::new);

        let props_size = props.size();
        // Return an error instead of panicking in `from_u32(..).unwrap()`
        // (or truncating via `as u32`) on oversized input.
        if props_size > MAX_VARIABLE_BYTE_INTEGER {
            return Err(MqttError::MalformedPacket);
        }
        let property_length = VariableByteInteger::from_u32(props_size as u32).unwrap();

        let packet_id_size = mem::size_of::<<PacketIdType as IsPacketId>::Buffer>();
        let remaining =
            packet_id_size + property_length.size() + props_size + reason_codes_buf.len();
        if remaining > MAX_VARIABLE_BYTE_INTEGER {
            return Err(MqttError::MalformedPacket);
        }
        let remaining_length = VariableByteInteger::from_u32(remaining as u32).unwrap();

        Ok(GenericSuback {
            fixed_header: [FixedHeader::Suback as u8],
            remaining_length,
            packet_id_buf,
            property_length,
            props,
            reason_codes_buf,
        })
    }
}
impl<PacketIdType> fmt::Display for GenericSuback<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Writes the packet as compact JSON produced by its `Serialize` impl.
    ///
    /// If serialization fails, writes a JSON object with an `error` key
    /// holding the serializer's message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let serialized = serde_json::to_string(self);
        match serialized {
            Err(e) => write!(f, "{{\"error\": \"{e}\"}}"),
            Ok(json) => write!(f, "{json}"),
        }
    }
}
impl<PacketIdType> fmt::Debug for GenericSuback<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Debug output is the same JSON text as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl<PacketIdType> Serialize for GenericSuback<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Serializes the packet as a struct named `Suback`.
    ///
    /// `type` (always `"suback"`) and `packet_id` are always written. `props`
    /// and `reason_codes` are written only when non-empty.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let has_props = !self.props.is_empty();
        let has_reason_codes = !self.reason_codes_buf.is_empty();
        let field_count = 2 + usize::from(has_props) + usize::from(has_reason_codes);

        let mut state = serializer.serialize_struct("Suback", field_count)?;
        state.serialize_field("type", "suback")?;
        state.serialize_field("packet_id", &self.packet_id())?;
        if has_props {
            state.serialize_field("props", &self.props)?;
        }
        if has_reason_codes {
            state.serialize_field("reason_codes", &self.reason_codes())?;
        }
        state.end()
    }
}
/// Generic packet interface for SUBACK. Each method forwards to the inherent
/// method of the same name on `GenericSuback`.
impl<PacketIdType> GenericPacketTrait for GenericSuback<PacketIdType>
where
PacketIdType: IsPacketId,
{
/// Total encoded packet size in bytes.
// `self.size()` resolves to the inherent method, because inherent methods take
// precedence over trait methods, so this does not recurse.
fn size(&self) -> usize {
self.size()
}
/// Encoded packet as borrowed slices, in wire order (inherent `to_buffers`).
#[cfg(feature = "std")]
fn to_buffers(&self) -> Vec<IoSlice<'_>> {
self.to_buffers()
}
/// Encoded packet as one contiguous byte vector (inherent `to_continuous_buffer`).
fn to_continuous_buffer(&self) -> Vec<u8> {
self.to_continuous_buffer()
}
}
impl<PacketIdType> GenericPacketDisplay for GenericSuback<PacketIdType>
where
PacketIdType: IsPacketId + Serialize,
{
fn fmt_debug(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Debug::fmt(self, f)
}
fn fmt_display(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Display::fmt(self, f)
}
}
/// Checks that `props` holds only properties allowed in a SUBACK.
///
/// User Property may appear any number of times, and Reason String at most
/// once. Any other property, or a second Reason String, yields
/// `MqttError::ProtocolError`.
fn validate_suback_properties(props: &Properties) -> Result<(), MqttError> {
    let mut seen_reason_string = false;
    for prop in props {
        match prop {
            Property::UserProperty(_) => {}
            Property::ReasonString(_) if seen_reason_string => {
                return Err(MqttError::ProtocolError);
            }
            Property::ReasonString(_) => seen_reason_string = true,
            _ => return Err(MqttError::ProtocolError),
        }
    }
    Ok(())
}