use alloc::vec::Vec;
use core::fmt;
use core::mem;
use derive_builder::Builder;
#[cfg(feature = "std")]
use std::io::IoSlice;
use serde::ser::{SerializeStruct, Serializer};
use serde::Serialize;
use getset::{CopyGetters, Getters};
use crate::mqtt::packet::json_bin_encode::escape_binary_json_string;
use crate::mqtt::packet::mqtt_string::MqttString;
use crate::mqtt::packet::packet_type::{FixedHeader, PacketType};
use crate::mqtt::packet::qos::Qos;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
use crate::mqtt::packet::{IntoPacketId, IsPacketId};
use crate::mqtt::result_code::MqttError;
use crate::mqtt::{Arc, ArcPayload, IntoPayload};
/// MQTT PUBLISH packet, generic over the packet identifier type.
///
/// All fields are private. Instances are created with
/// [`GenericPublish::builder`] (whose `build` is hand-written and validates the
/// packet) or decoded with [`GenericPublish::parse`].
#[derive(PartialEq, Eq, Builder, Clone, Getters, CopyGetters)]
#[builder(no_std, derive(Debug), pattern = "owned", setter(into), build_fn(skip))]
pub struct GenericPublish<PacketIdType>
where
PacketIdType: IsPacketId,
{
/// First fixed-header byte: packet type in the high nibble, then the DUP
/// (bit 3), QoS (bits 2-1) and RETAIN (bit 0) flags.
#[builder(private)]
fixed_header: [u8; 1],
/// Remaining Length: byte count of topic name + packet identifier + payload.
#[builder(private)]
remaining_length: VariableByteInteger,
/// Topic name the message is published to.
#[builder(private)]
topic_name_buf: MqttString,
/// Raw packet identifier bytes; `Some` only for QoS 1 and QoS 2 packets.
#[builder(private)]
packet_id_buf: Option<PacketIdType::Buffer>,
/// Application message payload (may be empty).
#[builder(private)]
payload_buf: ArcPayload,
}
/// PUBLISH packet using 16-bit (`u16`) packet identifiers.
pub type Publish = GenericPublish<u16>;
impl<PacketIdType> GenericPublish<PacketIdType>
where
    PacketIdType: IsPacketId,
{
    /// Returns a new, empty [`GenericPublishBuilder`].
    pub fn builder() -> GenericPublishBuilder<PacketIdType> {
        GenericPublishBuilder::<PacketIdType>::default()
    }

    /// Returns the control packet type of this packet, [`PacketType::Publish`].
    pub const fn packet_type() -> PacketType {
        PacketType::Publish
    }

    /// Returns the packet identifier, or `None` when the packet carries none.
    ///
    /// Packets produced by `parse` or the builder carry an identifier exactly
    /// when their QoS is 1 or 2.
    pub fn packet_id(&self) -> Option<PacketIdType> {
        self.packet_id_buf
            .as_ref()
            .map(|buf| PacketIdType::from_buffer(buf.as_ref()))
    }

    /// Returns the QoS level stored in bits 2-1 of the fixed header.
    ///
    /// # Panics
    ///
    /// Panics if those bits hold the reserved value 3. Both `parse` and the
    /// builder reject that value, so reaching it means an invariant was broken.
    pub fn qos(&self) -> Qos {
        match (self.fixed_header[0] >> 1) & 0b0000_0011 {
            0 => Qos::AtMostOnce,
            1 => Qos::AtLeastOnce,
            2 => Qos::ExactlyOnce,
            _ => unreachable!("Invalid QoS value"),
        }
    }

    /// Returns `true` if the DUP flag (bit 3 of the fixed header) is set.
    pub fn dup(&self) -> bool {
        (self.fixed_header[0] & 0b0000_1000) != 0
    }

    /// Returns `true` if the RETAIN flag (bit 0 of the fixed header) is set.
    pub fn retain(&self) -> bool {
        (self.fixed_header[0] & 0b0000_0001) != 0
    }

    /// Consumes the packet and returns it with the DUP flag set to `dup`.
    ///
    /// Only the fixed-header byte changes, so the Remaining Length stays valid.
    pub fn set_dup(mut self, dup: bool) -> Self {
        if dup {
            self.fixed_header[0] |= 0b0000_1000;
        } else {
            self.fixed_header[0] &= !0b0000_1000;
        }
        self
    }

    /// Returns the topic name.
    pub fn topic_name(&self) -> &str {
        self.topic_name_buf.as_str()
    }

    /// Returns the application payload (possibly empty).
    pub fn payload(&self) -> &ArcPayload {
        &self.payload_buf
    }

    /// Returns the total encoded size in bytes: the fixed-header byte, the
    /// Remaining Length field, and the Remaining Length bytes that follow.
    pub fn size(&self) -> usize {
        1 + self.remaining_length.size() + self.remaining_length.to_u32() as usize
    }

    /// Returns the encoded packet as borrowed slices, suitable for vectored
    /// writes. An empty payload contributes no slice.
    #[cfg(feature = "std")]
    pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
        let mut bufs = Vec::new();
        bufs.push(IoSlice::new(&self.fixed_header));
        bufs.push(IoSlice::new(self.remaining_length.as_bytes()));
        bufs.append(&mut self.topic_name_buf.to_buffers());
        if let Some(buf) = &self.packet_id_buf {
            bufs.push(IoSlice::new(buf.as_ref()));
        }
        if self.payload_buf.len() > 0 {
            bufs.push(IoSlice::new(self.payload_buf.as_slice()));
        }
        bufs
    }

    /// Returns the encoded packet as one contiguous byte vector.
    pub fn to_continuous_buffer(&self) -> Vec<u8> {
        // The exact encoded length is known, so allocate once.
        let mut buf = Vec::with_capacity(self.size());
        buf.extend_from_slice(&self.fixed_header);
        buf.extend_from_slice(self.remaining_length.as_bytes());
        buf.append(&mut self.topic_name_buf.to_continuous_buffer());
        if let Some(packet_id_buf) = &self.packet_id_buf {
            buf.extend_from_slice(packet_id_buf.as_ref());
        }
        if self.payload_buf.len() > 0 {
            buf.extend_from_slice(self.payload_buf.as_slice());
        }
        buf
    }

    /// Parses a PUBLISH packet from the bytes following its fixed header.
    ///
    /// `flags` supplies the low nibble of the fixed-header byte (DUP, QoS,
    /// RETAIN). Parsing proceeds in three steps:
    /// - `data_arc` is read as the topic name.
    /// - For QoS 1/2, a packet identifier follows the topic name.
    /// - Everything after that is the payload, which shares `data_arc` instead
    ///   of being copied.
    ///
    /// Returns the packet and the number of bytes consumed, which is always
    /// `data_arc.len()`.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] in any of these cases:
    /// - the QoS bits hold 3;
    /// - `data_arc` is longer than the maximum Remaining Length (268 435 455);
    /// - the topic name is empty or contains a wildcard (`#` or `+`);
    /// - a QoS 1/2 packet is too short for its packet identifier;
    /// - a QoS 1/2 packet carries packet identifier 0.
    ///
    /// Errors from decoding the topic name are propagated unchanged.
    pub fn parse(flags: u8, data_arc: Arc<[u8]>) -> Result<(Self, usize), MqttError> {
        // Largest value an MQTT Variable Byte Integer can encode.
        const MAX_REMAINING_LENGTH: usize = 268_435_455;

        let fixed_header_byte = FixedHeader::Publish as u8 | (flags & 0b0000_1111);
        let qos = match (flags >> 1) & 0b0000_0011 {
            0 => Qos::AtMostOnce,
            1 => Qos::AtLeastOnce,
            2 => Qos::ExactlyOnce,
            _ => return Err(MqttError::MalformedPacket),
        };

        // All of `data_arc` belongs to this packet, so its length is the
        // Remaining Length and must be encodable; reject rather than panic.
        let total_len = data_arc.len();
        if total_len > MAX_REMAINING_LENGTH {
            return Err(MqttError::MalformedPacket);
        }

        let (topic_name_buf, mut cursor) = MqttString::decode(&data_arc[..])?;
        // A PUBLISH topic name must be non-empty and free of wildcards; the
        // builder enforces the same rules.
        let topic = topic_name_buf.as_str();
        if topic.is_empty() || topic.contains('#') || topic.contains('+') {
            return Err(MqttError::MalformedPacket);
        }

        let packet_id_buf = if qos == Qos::AtMostOnce {
            None
        } else {
            let id_len = mem::size_of::<PacketIdType::Buffer>();
            let id_bytes = data_arc
                .get(cursor..cursor + id_len)
                .ok_or(MqttError::MalformedPacket)?;
            let mut buf = PacketIdType::Buffer::default();
            buf.as_mut().copy_from_slice(id_bytes);
            // QoS 1/2 packets require a non-zero packet identifier.
            if PacketIdType::from_buffer(buf.as_ref()).is_zero() {
                return Err(MqttError::MalformedPacket);
            }
            cursor += id_len;
            Some(buf)
        };

        let payload_len = total_len - cursor;
        let payload_buf = if payload_len > 0 {
            ArcPayload::new(data_arc, cursor, payload_len)
        } else {
            ArcPayload::default()
        };

        // `total_len <= MAX_REMAINING_LENGTH` was checked above, so the cast is
        // lossless and the value is encodable.
        let remaining_length = VariableByteInteger::from_u32(total_len as u32)
            .expect("remaining length is within the Variable Byte Integer range");

        let publish = GenericPublish {
            fixed_header: [fixed_header_byte],
            remaining_length,
            topic_name_buf,
            packet_id_buf,
            payload_buf,
        };
        Ok((publish, total_len))
    }
}
impl<PacketIdType> GenericPublishBuilder<PacketIdType>
where
    PacketIdType: IsPacketId,
{
    /// Sets the topic name.
    ///
    /// # Errors
    ///
    /// Propagates the conversion error from `T`. Returns
    /// [`MqttError::MalformedPacket`] if the topic contains the wildcard
    /// characters `#` or `+`, which a PUBLISH topic name must not contain.
    pub fn topic_name<T>(mut self, topic: T) -> Result<Self, MqttError>
    where
        T: TryInto<MqttString, Error = MqttError>,
    {
        let mqtt_str = topic.try_into()?;
        if mqtt_str.as_str().contains('#') || mqtt_str.as_str().contains('+') {
            return Err(MqttError::MalformedPacket);
        }
        self.topic_name_buf = Some(mqtt_str);
        Ok(self)
    }

    /// Sets the QoS level (bits 2-1 of the fixed header).
    pub fn qos(self, qos: Qos) -> Self {
        let qos_bits = (qos as u8) << 1;
        self.with_header_byte(|byte| (byte & !0b0000_0110) | qos_bits)
    }

    /// Sets or clears the DUP flag (bit 3 of the fixed header).
    pub fn dup(self, dup: bool) -> Self {
        self.with_header_byte(|byte| {
            if dup {
                byte | 0b0000_1000
            } else {
                byte & !0b0000_1000
            }
        })
    }

    /// Sets or clears the RETAIN flag (bit 0 of the fixed header).
    pub fn retain(self, retain: bool) -> Self {
        self.with_header_byte(|byte| {
            if retain {
                byte | 0b0000_0001
            } else {
                byte & !0b0000_0001
            }
        })
    }

    /// Sets the packet identifier.
    ///
    /// `id.into_packet_id()` yields an `Option`; a value converting to `None`
    /// records an explicitly absent identifier. An identifier is required for
    /// QoS 1/2 and forbidden for QoS 0 (checked by `build`).
    pub fn packet_id<T>(mut self, id: T) -> Self
    where
        T: IntoPacketId<PacketIdType>,
    {
        self.packet_id_buf = Some(id.into_packet_id().map(|i| i.to_buffer()));
        self
    }

    /// Sets the application payload.
    pub fn payload<T>(mut self, data: T) -> Self
    where
        T: IntoPayload,
    {
        self.payload_buf = Some(data.into_payload());
        self
    }

    /// Rewrites the fixed-header byte with `update`.
    ///
    /// If no header has been set yet, the starting value is the bare PUBLISH
    /// type byte.
    fn with_header_byte(mut self, update: impl FnOnce(u8) -> u8) -> Self {
        let [byte] = self.fixed_header.unwrap_or([FixedHeader::Publish as u8]);
        self.fixed_header = Some([update(byte)]);
        self
    }

    /// Checks the builder state before `build` assembles the packet.
    ///
    /// Requires:
    /// - a non-empty topic name;
    /// - no packet identifier for QoS 0 (also the default when no header was set);
    /// - a non-zero packet identifier for QoS 1/2.
    fn validate(&self) -> Result<(), MqttError> {
        match &self.topic_name_buf {
            Some(topic) if !topic.as_str().is_empty() => {}
            _ => return Err(MqttError::MalformedPacket),
        }

        let qos_bits = self
            .fixed_header
            .map_or(0, |[byte]| (byte >> 1) & 0b0000_0011);
        match (qos_bits, &self.packet_id_buf) {
            // QoS 0 must not carry a packet identifier.
            (0, Some(Some(_))) => Err(MqttError::MalformedPacket),
            (0, _) => Ok(()),
            // QoS 1/2 require a non-zero packet identifier.
            (1 | 2, Some(Some(buf))) if !PacketIdType::from_buffer(buf.as_ref()).is_zero() => {
                Ok(())
            }
            // Missing/zero identifier for QoS 1/2, or the reserved QoS value 3.
            _ => Err(MqttError::MalformedPacket),
        }
    }

    /// Validates the builder and assembles the packet.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] if validation fails or if the
    /// encoded Remaining Length (topic + packet identifier + payload) exceeds
    /// the MQTT maximum of 268 435 455 bytes.
    pub fn build(self) -> Result<GenericPublish<PacketIdType>, MqttError> {
        // Largest value an MQTT Variable Byte Integer can encode.
        const MAX_REMAINING_LENGTH: usize = 268_435_455;

        self.validate()?;
        let topic_name_buf = self
            .topic_name_buf
            .expect("validate() guarantees a topic name");
        let fixed_header = self.fixed_header.unwrap_or([FixedHeader::Publish as u8]);
        // After validate() an identifier is present exactly when QoS > 0.
        let packet_id_buf = self.packet_id_buf.flatten();
        let payload_buf = self.payload_buf.unwrap_or_else(ArcPayload::default);

        // Count the bytes actually emitted for the identifier (its buffer).
        let packet_id_len = if packet_id_buf.is_some() {
            mem::size_of::<PacketIdType::Buffer>()
        } else {
            0
        };
        let remaining = topic_name_buf.size() + packet_id_len + payload_buf.len();
        if remaining > MAX_REMAINING_LENGTH {
            return Err(MqttError::MalformedPacket);
        }
        // `remaining <= MAX_REMAINING_LENGTH`, so the cast is lossless.
        let remaining_length = VariableByteInteger::from_u32(remaining as u32)
            .expect("remaining length is within the Variable Byte Integer range");

        Ok(GenericPublish {
            fixed_header,
            remaining_length,
            topic_name_buf,
            packet_id_buf,
            payload_buf,
        })
    }
}
impl<PacketIdType> Serialize for GenericPublish<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Serializes the packet as a struct named `"publish"`.
    ///
    /// The payload is emitted as an escaped string when
    /// `escape_binary_json_string` produces one, otherwise as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // type, topic_name, qos, retain, dup, packet_id, payload
        const FIELD_COUNT: usize = 7;

        let mut out = serializer.serialize_struct("publish", FIELD_COUNT)?;
        out.serialize_field("type", PacketType::Publish.as_str())?;
        out.serialize_field("topic_name", &self.topic_name_buf)?;
        out.serialize_field("qos", &self.qos())?;
        out.serialize_field("retain", &self.retain())?;
        out.serialize_field("dup", &self.dup())?;
        out.serialize_field("packet_id", &self.packet_id())?;

        let raw = self.payload_buf.as_slice();
        if let Some(text) = escape_binary_json_string(raw) {
            out.serialize_field("payload", &text)?;
        } else {
            out.serialize_field("payload", &raw)?;
        }
        out.end()
    }
}
impl<PacketIdType> fmt::Display for GenericPublish<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Writes the packet as its JSON serialization.
    ///
    /// If serialization fails, writes an `{"error": ...}` object instead.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = serde_json::to_string(self);
        if let Ok(json) = rendered {
            f.write_str(&json)
        } else if let Err(e) = rendered {
            write!(f, "{{\"error\": \"{e}\"}}")
        } else {
            unreachable!()
        }
    }
}
impl<PacketIdType> fmt::Debug for GenericPublish<PacketIdType>
where
    PacketIdType: IsPacketId + Serialize,
{
    /// Debug output is identical to the JSON `Display` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl<PacketIdType> GenericPacketTrait for GenericPublish<PacketIdType>
where
    PacketIdType: IsPacketId,
{
    // Each method forwards to the inherent method of the same name. A
    // `Self::name` path resolves inherent methods before trait methods, so
    // these calls do not recurse into this impl.

    fn size(&self) -> usize {
        Self::size(self)
    }

    #[cfg(feature = "std")]
    fn to_buffers(&self) -> Vec<IoSlice<'_>> {
        Self::to_buffers(self)
    }

    fn to_continuous_buffer(&self) -> Vec<u8> {
        Self::to_continuous_buffer(self)
    }
}
impl<PacketIdType> GenericPacketDisplay for GenericPublish<PacketIdType>
where
PacketIdType: IsPacketId + Serialize,
{
fn fmt_debug(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Debug::fmt(self, f)
}
fn fmt_display(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
core::fmt::Display::fmt(self, f)
}
}