use crate::client::MqttVersion;
use crate::error::{MqttError, ProtocolError};
use crate::transport;
use crate::util::{
self, read_utf8_string, read_variable_byte_integer, write_utf8_string,
};
use core::marker::PhantomData;
use heapless::Vec;
#[cfg(feature = "v5")]
use crate::util::{read_properties, write_properties};
/// MQTT Quality-of-Service level.
///
/// The explicit discriminants are the numeric QoS levels 0, 1 and 2, and with
/// `#[repr(u8)]` a `qos as u8` cast yields that number. Because `PartialOrd`
/// is derived, levels compare in declaration order:
/// `AtMostOnce < AtLeastOnce < ExactlyOnce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum QoS {
/// QoS 0.
AtMostOnce = 0,
/// QoS 1.
AtLeastOnce = 1,
/// QoS 2.
ExactlyOnce = 2,
}
/// A packet that can be serialized into a caller-provided byte buffer.
///
/// Implementations in this module write the packet at the start of `buf` and
/// return the number of bytes they wrote. Errors use the placeholder
/// transport type; callers that need a concrete transport error type convert
/// them (the module-level `decode` does the equivalent for decoding via
/// `MqttError::cast_transport_error`).
pub trait EncodePacket {
/// Serializes `self` into the front of `buf` using the wire format for
/// `version`, returning the encoded length in bytes.
fn encode(
&self,
buf: &mut [u8],
version: MqttVersion,
) -> Result<usize, MqttError<transport::ErrorPlaceHolder>>;
}
/// A packet that can be parsed from a byte buffer, borrowing from it for `'a`.
///
/// The module-level `decode` picks the implementation from the high nibble of
/// `buf[0]` and passes the whole buffer, so `buf` starts with the packet's
/// fixed header.
pub trait DecodePacket<'a>: Sized {
/// Parses one packet of this type from `buf` using the wire format for
/// `version`.
fn decode(
buf: &'a [u8],
version: MqttVersion,
) -> Result<Self, MqttError<transport::ErrorPlaceHolder>>;
}
/// One decoded MQTT control packet, as returned by [`decode`].
///
/// There is one variant per packet type that `decode` recognizes. `PingReq`
/// and `PingResp` carry no data. Packet types without a variant here are
/// rejected by `decode` with `ProtocolError::InvalidPacketType`.
#[derive(Debug)]
pub enum MqttPacket<'a> {
Connect(Connect<'a>),
ConnAck(ConnAck<'a>),
Publish(Publish<'a>),
PubAck(PubAck<'a>),
Subscribe(Subscribe<'a>),
SubAck(SubAck<'a>),
PingReq,
PingResp,
Disconnect(Disconnect<'a>),
}
/// Decodes one MQTT control packet from the front of `buf`.
///
/// Returns `Ok(None)` when `buf` is empty. Otherwise the packet type is the
/// high nibble of the first byte, and the matching [`DecodePacket`]
/// implementation parses the whole buffer. Any error it reports is re-typed
/// for the caller's transport with `MqttError::cast_transport_error`.
///
/// # Errors
/// `ProtocolError::InvalidPacketType(n)` for a type nibble other than
/// 1, 2, 3, 4, 8, 9, 12, 13 or 14, or whatever the per-packet decoder reports.
pub fn decode<'a, T>(
    buf: &'a [u8],
    version: MqttVersion,
) -> Result<Option<MqttPacket<'a>>, MqttError<T>>
where
    T: transport::TransportError,
{
    let first_byte = match buf.first() {
        Some(&byte) => byte,
        None => return Ok(None),
    };

    let kind = first_byte >> 4;
    // Each arm yields the decoder's placeholder-typed result; the error is
    // converted once, below, instead of per arm.
    let decoded = match kind {
        1 => Connect::decode(buf, version).map(MqttPacket::Connect),
        2 => ConnAck::decode(buf, version).map(MqttPacket::ConnAck),
        3 => Publish::decode(buf, version).map(MqttPacket::Publish),
        4 => PubAck::decode(buf, version).map(MqttPacket::PubAck),
        8 => Subscribe::decode(buf, version).map(MqttPacket::Subscribe),
        9 => SubAck::decode(buf, version).map(MqttPacket::SubAck),
        12 => Ok(MqttPacket::PingReq),
        13 => Ok(MqttPacket::PingResp),
        14 => Disconnect::decode(buf, version).map(MqttPacket::Disconnect),
        unknown => {
            return Err(MqttError::Protocol(ProtocolError::InvalidPacketType(unknown)))
        }
    };

    decoded.map(Some).map_err(MqttError::cast_transport_error)
}
/// A single MQTT v5 property, kept as an identifier byte plus borrowed bytes.
///
/// NOTE(review): `data` is presumably the property's still-encoded value as
/// produced by `util::read_properties`; confirm against that helper.
#[cfg(feature = "v5")]
#[derive(Debug)]
pub struct Property<'a> {
/// Property identifier byte.
pub id: u8,
/// Value bytes borrowed for `'a` (see the NOTE above on their encoding).
pub data: &'a [u8],
}
/// Fields of an MQTT CONNECT packet as encoded/decoded by this module.
///
/// `client_id` (and, with the `v5` feature, property data) is borrowed for
/// `'a`. Without `v5`, a `PhantomData` marker keeps `'a` in use so the type
/// has the same generic signature in both builds.
#[derive(Debug)]
pub struct Connect<'a> {
/// Clean-session flag; encoded as bit `0x02` of the connect-flags byte.
pub clean_session: bool,
/// Keep-alive value, written big-endian as 16 bits.
pub keep_alive: u16,
/// Client identifier string.
pub client_id: &'a str,
/// CONNECT properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
/// Lifetime marker used when the `v5` feature (and thus `properties`) is off.
#[cfg(not(feature = "v5"))]
_phantom: PhantomData<&'a ()>,
}
impl<'a> Connect<'a> {
pub fn new(client_id: &'a str, keep_alive: u16, clean_session: bool) -> Self {
Self { client_id, keep_alive, clean_session, #[cfg(feature = "v5")] properties: Vec::new(), #[cfg(not(feature = "v5"))] _phantom: PhantomData }
}
}
impl<'a> EncodePacket for Connect<'a> {
fn encode(&self, buf: &mut [u8], version: MqttVersion) -> Result<usize, MqttError<transport::ErrorPlaceHolder>> {
let mut cursor = 0;
buf[cursor] = 0x10; cursor += 1;
let remaining_len_pos = cursor;
cursor += 4;
let content_start = cursor;
let protocol_name = if version == MqttVersion::V5 { "MQTT" } else { "MQIsdp" };
cursor += write_utf8_string(&mut buf[cursor..], protocol_name)?;
buf[cursor] = if version == MqttVersion::V5 { 5 } else { 3 }; cursor += 1;
let mut flags = 0;
if self.clean_session { flags |= 0x02; }
buf[cursor] = flags; cursor += 1;
buf[cursor..cursor + 2].copy_from_slice(&self.keep_alive.to_be_bytes()); cursor += 2;
#[cfg(feature = "v5")]
if version == MqttVersion::V5 { write_properties(&mut cursor, buf, &self.properties)?; }
cursor += write_utf8_string(&mut buf[cursor..], self.client_id)?;
let remaining_len = cursor - content_start;
let len_bytes = util::write_variable_byte_integer_len(&mut buf[remaining_len_pos..], remaining_len)?;
let header_len = 1 + len_bytes;
buf.copy_within(content_start..cursor, header_len);
Ok(header_len + remaining_len)
}
}
impl<'a> DecodePacket<'a> for Connect<'a> {
fn decode(buf: &'a [u8], version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
let mut cursor = 2; cursor += 6;
let connect_flags = buf[cursor];
let clean_session = (connect_flags & 0x02) != 0; cursor += 1;
let keep_alive = u16::from_be_bytes([buf[cursor], buf[cursor + 1]]); cursor += 2;
#[cfg(feature = "v5")]
let properties = if version == MqttVersion::V5 { read_properties(&mut cursor, buf)? } else { Vec::new() };
let client_id = read_utf8_string(&mut cursor, buf)?;
Ok(Self { clean_session, keep_alive, client_id, #[cfg(feature = "v5")] properties, #[cfg(not(feature = "v5"))] _phantom: PhantomData })
}
}
/// Fields of an MQTT CONNACK packet as decoded by this module.
#[derive(Debug)]
pub struct ConnAck<'a> {
/// Bit `0x01` of the acknowledge-flags byte.
pub session_present: bool,
/// Raw byte that follows the acknowledge flags; not validated or mapped here.
pub reason_code: u8,
/// CONNACK properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
/// Lifetime marker used when the `v5` feature (and thus `properties`) is off.
#[cfg(not(feature = "v5"))]
_phantom: PhantomData<&'a ()>,
}
impl<'a> DecodePacket<'a> for ConnAck<'a> {
fn decode(buf: &'a [u8], version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
let mut cursor = 2;
let session_present = (buf[cursor] & 0x01) != 0; cursor += 1;
let reason_code = buf[cursor]; cursor += 1;
#[cfg(feature = "v5")]
let properties = if version == MqttVersion::V5 { read_properties(&mut cursor, buf)? } else { Vec::new() };
Ok(Self { session_present, reason_code, #[cfg(feature = "v5")] properties, #[cfg(not(feature = "v5"))] _phantom: PhantomData })
}
}
/// Fields of an MQTT PUBLISH packet.
///
/// `topic` and `payload` are borrowed for `'a`.
/// NOTE(review): there are no fields for the DUP or RETAIN header flags.
#[derive(Debug)]
pub struct Publish<'a> {
/// Topic name.
pub topic: &'a str,
/// Delivery QoS.
pub qos: QoS,
/// Application payload bytes.
pub payload: &'a [u8],
/// Packet identifier; `None` when there is none.
pub packet_id: Option<u16>,
/// PUBLISH properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
}
impl<'a> DecodePacket<'a> for Publish<'a> {
fn decode(_buf: &'a [u8], _version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
Ok(Publish { topic: "", qos: QoS::AtMostOnce, payload: &[], packet_id: None, #[cfg(feature = "v5")] properties: Vec::new() })
}
}
impl<'a> EncodePacket for Publish<'a> {
fn encode(&self, _buf: &mut [u8], _version: MqttVersion) -> Result<usize, MqttError<transport::ErrorPlaceHolder>> {
Ok(0) }
}
/// Fields of an MQTT PUBACK packet.
#[derive(Debug)]
pub struct PubAck<'a> {
/// Packet identifier carried by the PUBACK.
pub packet_id: u16,
/// PUBACK properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
/// Lifetime marker used when the `v5` feature (and thus `properties`) is off.
#[cfg(not(feature = "v5"))]
_phantom: PhantomData<&'a ()>,
}
impl<'a> DecodePacket<'a> for PubAck<'a> {
fn decode(_buf: &'a [u8], _version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
Ok(PubAck { packet_id: 0, #[cfg(feature = "v5")] properties: Vec::new(), #[cfg(not(feature = "v5"))] _phantom: PhantomData })
}
}
/// Fields of an MQTT SUBSCRIBE packet.
#[derive(Debug)]
pub struct Subscribe<'a> {
/// Packet identifier.
pub packet_id: u16,
/// Topic filters, each paired with a QoS; capacity 8.
pub topics: Vec<(&'a str, QoS), 8>,
/// SUBSCRIBE properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
}
impl<'a> DecodePacket<'a> for Subscribe<'a> {
fn decode(_buf: &'a [u8], _version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
Ok(Subscribe { packet_id: 0, topics: Vec::new(), #[cfg(feature = "v5")] properties: Vec::new() })
}
}
impl<'a> EncodePacket for Subscribe<'a> {
fn encode(&self, _buf: &mut [u8], _version: MqttVersion) -> Result<usize, MqttError<transport::ErrorPlaceHolder>> {
Ok(0) }
}
/// Fields of an MQTT SUBACK packet.
#[derive(Debug)]
pub struct SubAck<'a> {
/// Packet identifier.
pub packet_id: u16,
/// Raw status bytes from the payload; capacity 8.
pub reason_codes: Vec<u8, 8>,
/// SUBACK properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
/// Lifetime marker used when the `v5` feature (and thus `properties`) is off.
#[cfg(not(feature = "v5"))]
_phantom: PhantomData<&'a ()>,
}
impl<'a> DecodePacket<'a> for SubAck<'a> {
fn decode(_buf: &'a [u8], _version: MqttVersion) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
Ok(SubAck { packet_id: 0, reason_codes: Vec::new(), #[cfg(feature = "v5")] properties: Vec::new(), #[cfg(not(feature = "v5"))] _phantom: PhantomData })
}
}
/// PINGREQ packet; it carries no fields, so its encoding is always the same
/// two bytes.
#[derive(Debug)]
pub struct PingReq;
impl EncodePacket for PingReq {
    /// Writes the two-byte PINGREQ packet (`0xC0 0x00`) to the front of `buf`
    /// and returns 2.
    ///
    /// The bytes do not depend on `version`, so it is ignored.
    ///
    /// # Errors
    /// `MqttError::BufferTooSmall` if `buf` is shorter than two bytes.
    fn encode(
        &self,
        buf: &mut [u8],
        _version: MqttVersion,
    ) -> Result<usize, MqttError<transport::ErrorPlaceHolder>> {
        const PINGREQ_BYTES: [u8; 2] = [0xC0, 0x00];
        match buf.get_mut(..PINGREQ_BYTES.len()) {
            Some(dst) => {
                dst.copy_from_slice(&PINGREQ_BYTES);
                Ok(PINGREQ_BYTES.len())
            }
            None => Err(MqttError::BufferTooSmall),
        }
    }
}
/// PINGRESP marker type.
///
/// No encode/decode impl for it is defined in this file. The module-level
/// `decode` maps packet type 13 to the data-less `MqttPacket::PingResp`
/// variant without constructing this struct.
#[derive(Debug)]
pub struct PingResp;
/// Fields of an MQTT DISCONNECT packet.
///
/// Without the `v5` feature the packet has no fields beyond a lifetime marker.
#[derive(Debug)]
pub struct Disconnect<'a> {
/// Reason code byte (v5 only).
#[cfg(feature = "v5")]
pub reason_code: u8,
/// DISCONNECT properties (v5 only); at most 8 are held.
#[cfg(feature = "v5")]
pub properties: Vec<Property<'a>, 8>,
/// Lifetime marker used without the `v5` feature.
/// NOTE(review): unlike the other packet structs' markers this one is `pub`.
#[cfg(not(feature = "v5"))]
pub _phantom: PhantomData<&'a ()>,
}
impl<'a> DecodePacket<'a> for Disconnect<'a> {
    /// Decodes a DISCONNECT packet whose fixed header starts at `buf[0]`.
    ///
    /// For `MqttVersion::V5` with the `v5` feature:
    /// - A remaining length of 0 means reason code `0x00` with no properties.
    /// - 1 adds an explicit reason code.
    /// - 2 or more adds properties.
    ///
    /// Previously these were always reported as 0/empty. Other versions have
    /// no variable header, and `buf` is not inspected.
    ///
    /// # Errors
    /// `MqttError::BufferTooSmall` for a truncated v5 packet, plus anything
    /// `read_properties` reports.
    fn decode(
        buf: &'a [u8],
        version: MqttVersion,
    ) -> Result<Self, MqttError<transport::ErrorPlaceHolder>> {
        #[cfg(feature = "v5")]
        if version == MqttVersion::V5 {
            // Parses the fixed header (type byte + 1–4 byte remaining length).
            // Returns (fixed-header length, end index of the packet), with the
            // end checked to lie within `buf`.
            fn fixed_header(
                buf: &[u8],
            ) -> Result<(usize, usize), MqttError<transport::ErrorPlaceHolder>> {
                let mut remaining: u32 = 0;
                let mut pos = 1;
                loop {
                    let byte = match buf.get(pos) {
                        Some(&b) => b,
                        None => return Err(MqttError::BufferTooSmall),
                    };
                    remaining |= u32::from(byte & 0x7F) << (7 * (pos - 1));
                    pos += 1;
                    // NOTE(review): a continuation bit on the 4th byte is not rejected.
                    if byte & 0x80 == 0 || pos == 5 {
                        break;
                    }
                }
                let remaining =
                    match <usize as core::convert::TryFrom<u32>>::try_from(remaining) {
                        Ok(n) => n,
                        Err(_) => return Err(MqttError::BufferTooSmall),
                    };
                match pos.checked_add(remaining) {
                    Some(end) if end <= buf.len() => Ok((pos, end)),
                    _ => Err(MqttError::BufferTooSmall),
                }
            }

            let (header_len, end) = fixed_header(buf)?;
            let packet = &buf[..end];
            let reason_code = packet.get(header_len).copied().unwrap_or(0);
            let properties = if end - header_len >= 2 {
                let mut cursor = header_len + 1;
                read_properties(&mut cursor, packet)?
            } else {
                Vec::new()
            };
            return Ok(Disconnect { reason_code, properties });
        }
        #[cfg(not(feature = "v5"))]
        let _ = (buf, version);

        Ok(Disconnect {
            #[cfg(feature = "v5")]
            reason_code: 0,
            #[cfg(feature = "v5")]
            properties: Vec::new(),
            #[cfg(not(feature = "v5"))]
            _phantom: PhantomData,
        })
    }
}
impl<'a> EncodePacket for Disconnect<'a> {
    /// Serializes this DISCONNECT packet into `buf`, returning the bytes written.
    ///
    /// In the common case the output is the two bytes `0xE0 0x00`. For
    /// `MqttVersion::V5` with the `v5` feature, a non-zero `reason_code` or a
    /// non-empty property list is now encoded rather than silently dropped.
    /// The reason code is written, then the properties, but only when there
    /// are any.
    ///
    /// # Errors
    /// `MqttError::BufferTooSmall` if `buf` cannot hold the packet, plus
    /// anything the `util` property/length writers report.
    fn encode(
        &self,
        buf: &mut [u8],
        version: MqttVersion,
    ) -> Result<usize, MqttError<transport::ErrorPlaceHolder>> {
        #[cfg(feature = "v5")]
        if version == MqttVersion::V5 && (self.reason_code != 0 || !self.properties.is_empty()) {
            // One type byte plus the largest possible remaining-length field.
            const CONTENT_START: usize = 1 + 4;
            if buf.len() <= CONTENT_START {
                return Err(MqttError::BufferTooSmall);
            }
            buf[0] = 0xE0;
            buf[CONTENT_START] = self.reason_code;
            let mut cursor = CONTENT_START + 1;
            // With no properties the property-length field may be omitted,
            // leaving a remaining length of 1.
            if !self.properties.is_empty() {
                write_properties(&mut cursor, buf, &self.properties)?;
            }
            let remaining_len = cursor - CONTENT_START;
            let len_bytes =
                util::write_variable_byte_integer_len(&mut buf[1..CONTENT_START], remaining_len)?;
            let header_len = 1 + len_bytes;
            buf.copy_within(CONTENT_START..cursor, header_len);
            return Ok(header_len + remaining_len);
        }
        #[cfg(not(feature = "v5"))]
        let _ = version;

        if buf.len() < 2 {
            return Err(MqttError::BufferTooSmall);
        }
        buf[0] = 0xE0;
        buf[1] = 0x00;
        Ok(2)
    }
}