use core::marker::PhantomData;
use heapless::Vec;
use crate::{
buffer::BufferProvider,
eio::Read,
fmt::{trace, verbose},
header::{FixedHeader, PacketType},
io::{
read::{BodyReader, Readable},
write::{Writable, wlen},
},
packet::{Packet, RxError, RxPacket},
types::{PacketIdentifier, ReasonCode, VarByteInt},
v5::{
packet::subacks::types::{Suback, SubackPacketType, Unsuback},
property::PropertyType,
},
};
mod types;
/// A decoded SUBACK packet: the [`Suback`] flavour of [`GenericSubackPacket`],
/// holding at most `MAX_TOPIC_FILTERS` reason codes.
pub type SubackPacket<'p, const MAX_TOPIC_FILTERS: usize> =
    GenericSubackPacket<'p, Suback, MAX_TOPIC_FILTERS>;

/// A decoded UNSUBACK packet: the [`Unsuback`] flavour of [`GenericSubackPacket`],
/// holding at most `MAX_TOPIC_FILTERS` reason codes.
pub type UnsubackPacket<'p, const MAX_TOPIC_FILTERS: usize> =
    GenericSubackPacket<'p, Unsuback, MAX_TOPIC_FILTERS>;
/// A SUBACK or UNSUBACK packet, distinguished at the type level by the marker `T`.
///
/// Use the [`SubackPacket`] and [`UnsubackPacket`] aliases rather than naming this type
/// directly. At most `MAX_TOPIC_FILTERS` reason codes can be held.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct GenericSubackPacket<'p, T, const MAX_TOPIC_FILTERS: usize>
where
    T: SubackPacketType,
{
    /// Packet identifier read from the variable header.
    pub packet_identifier: PacketIdentifier,
    /// Reason codes from the payload, in the order they were decoded.
    pub reason_codes: Vec<ReasonCode, MAX_TOPIC_FILTERS>,
    /// Carries the `'p` lifetime and the `T` marker without storing any data.
    _phantom_data: PhantomData<&'p T>,
}
/// The MQTT packet type is supplied by the marker `T`, so SUBACK and UNSUBACK share one impl.
impl<T, const MAX_TOPIC_FILTERS: usize> Packet for GenericSubackPacket<'_, T, MAX_TOPIC_FILTERS>
where
    T: SubackPacketType,
{
    const PACKET_TYPE: PacketType = T::PACKET_TYPE;
}
impl<'p, T: SubackPacketType, const MAX_TOPIC_FILTERS: usize> RxPacket<'p>
for GenericSubackPacket<'static, T, MAX_TOPIC_FILTERS>
{
async fn receive<R: Read, B: BufferProvider<'p>>(
header: &FixedHeader,
mut reader: BodyReader<'_, 'p, R, B>,
) -> Result<Self, RxError<R::Error, B::ProvisionError>> {
trace!("decoding {:?} packet", T::PACKET_TYPE);
if header.flags() != 0 {
trace!(
"invalid {:?} fixed header flags: {}",
T::PACKET_TYPE,
header.flags()
);
return Err(RxError::MalformedPacket);
}
let r = &mut reader;
verbose!("reading packet identifier field");
let packet_identifier = PacketIdentifier::read(r).await?;
verbose!("reading property length field");
let mut properties_length = VarByteInt::read(r).await?.size();
verbose!("property length: {} bytes", properties_length);
if properties_length > r.remaining_len() {
trace!(
"invalid {:?} property length for remaining packet length",
T::PACKET_TYPE
);
return Err(RxError::MalformedPacket);
}
if properties_length == r.remaining_len() {
trace!("{:?} packet does not contain a reason code", T::PACKET_TYPE);
return Err(RxError::ProtocolError);
}
while properties_length > 0 {
verbose!(
"reading property identifier (remaining length: {} bytes)",
r.remaining_len()
);
let property_type = PropertyType::read(r).await?;
properties_length -= property_type.written_len();
verbose!(
"reading {:?} property body (remaining length: {} bytes)",
property_type,
r.remaining_len()
);
let mut seen_reason_string = false;
#[rustfmt::skip]
match property_type {
PropertyType::ReasonString if seen_reason_string => return Err(RxError::ProtocolError),
PropertyType::ReasonString => {
seen_reason_string = true;
let len = u16::read(r).await? as usize;
verbose!("skipping reason string ({} bytes)", len);
r.skip(len).await?;
properties_length = properties_length.checked_sub(wlen!(u16) + len).ok_or(RxError::MalformedPacket)?;
},
PropertyType::UserProperty => {
let len = u16::read(r).await? as usize;
verbose!("skipping user property name ({} bytes)", len);
r.skip(len).await?;
properties_length = properties_length.checked_sub(wlen!(u16) + len).ok_or(RxError::MalformedPacket)?;
let len = u16::read(r).await? as usize;
verbose!("skipping user property value ({} bytes)", len);
r.skip(len).await?;
properties_length = properties_length.checked_sub(wlen!(u16) + len).ok_or(RxError::MalformedPacket)?;
},
p => {
trace!("invalid {:?} property: {:?}", T::PACKET_TYPE, p);
return Err(RxError::MalformedPacket)
},
};
let _ = seen_reason_string;
}
let mut reason_codes = Vec::new();
while r.remaining_len() > 0 {
verbose!("reading reason code field");
let reason_code = ReasonCode::read(r).await?;
if !T::reason_code_allowed(reason_code) {
trace!(
"invalid {:?} reason code: {:?}",
T::PACKET_TYPE,
reason_code
);
return Err(RxError::ProtocolError);
}
reason_codes
.push(reason_code)
.map_err(|_| RxError::ProtocolError)?;
}
let packet = Self {
packet_identifier,
reason_codes,
_phantom_data: PhantomData,
};
Ok(packet)
}
}
#[cfg(test)]
mod unit {
    // SUBACK decoding: one packet with payload only, one packet carrying properties.
    mod suback {
        use core::num::NonZero;

        use heapless::Vec;

        use crate::{
            test::rx::decode,
            types::{PacketIdentifier, ReasonCode},
            v5::packet::SubackPacket,
        };

        // Empty property section followed by twelve distinct reason codes.
        #[tokio::test]
        #[test_log::test]
        async fn decode_payload() {
            #[rustfmt::skip]
            let packet: SubackPacket<'_, 12> = decode!(
                SubackPacket<12>,
                15,
                [
                    0x90,
                    0x0F,
                    0x17, 0x89,
                    0x00,
                    0x00, 0xA2, 0x01, 0xA1, 0x02, 0x9E, 0x80, 0x97, 0x83, 0x91, 0x87, 0x8F,
                ]
            );

            let expected_id = PacketIdentifier::new(NonZero::new(6025).unwrap());
            let expected_codes: Vec<ReasonCode, 12> = Vec::from_slice(&[
                ReasonCode::Success,
                ReasonCode::WildcardSubscriptionsNotSupported,
                ReasonCode::GrantedQoS1,
                ReasonCode::SubscriptionIdentifiersNotSupported,
                ReasonCode::GrantedQoS2,
                ReasonCode::SharedSubscriptionsNotSupported,
                ReasonCode::UnspecifiedError,
                ReasonCode::QuotaExceeded,
                ReasonCode::ImplementationSpecificError,
                ReasonCode::PacketIdentifierInUse,
                ReasonCode::NotAuthorized,
                ReasonCode::TopicFilterInvalid,
            ])
            .unwrap();

            assert_eq!(packet.packet_identifier, expected_id);
            assert_eq!(packet.reason_codes, expected_codes);
        }

        // One Reason String and two User Properties, then a single reason code.
        #[tokio::test]
        #[test_log::test]
        async fn decode_properties() {
            #[rustfmt::skip]
            let packet: SubackPacket<'_, 1> = decode!(
                SubackPacket<1>,
                61,
                [
                    0x90,
                    0x3D,
                    0x15, 0xF4,
                    0x39,
                    0x1F, 0x00, 0x0C, b'c', b'r', b'a', b'z', b'y', b' ', b't', b'h', b'i', b'n', b'g', b's',
                    0x26, 0x00, 0x09, b's', b'o', b'm', b'e', b' ', b'n', b'a', b'm', b'e',
                    0x00, 0x09, b'a', b'n', b'y', b' ', b'v', b'a', b'l', b'u', b'e',
                    0x26, 0x00, 0x07, b'a', b'n', b'y', b' ', b'k', b'e', b'y',
                    0x00, 0x07, b'a', b' ', b'v', b'a', b'l', b'u', b'e',
                    0x00,
                ]
            );

            let expected_id = PacketIdentifier::new(NonZero::new(5620).unwrap());
            let expected_codes: Vec<ReasonCode, 1> =
                Vec::from_slice(&[ReasonCode::Success]).unwrap();

            assert_eq!(packet.packet_identifier, expected_id);
            assert_eq!(packet.reason_codes, expected_codes);
        }
    }

    // UNSUBACK decoding: one packet with payload only, one packet carrying properties.
    mod unsuback {
        use core::num::NonZero;

        use heapless::Vec;

        use crate::{
            test::rx::decode,
            types::{PacketIdentifier, ReasonCode},
            v5::packet::UnsubackPacket,
        };

        // Empty property section followed by seven distinct reason codes.
        #[tokio::test]
        #[test_log::test]
        async fn decode_payload() {
            #[rustfmt::skip]
            let packet: UnsubackPacket<'_, 7> = decode!(
                UnsubackPacket<7>,
                10,
                [
                    0xB0,
                    0x0A,
                    0xA3, 0xF4, 0x00,
                    0x00, 0x91, 0x11, 0x8F, 0x80, 0x87, 0x83,
                ]
            );

            let expected_id = PacketIdentifier::new(NonZero::new(41972).unwrap());
            let expected_codes: Vec<ReasonCode, 7> = Vec::from_slice(&[
                ReasonCode::Success,
                ReasonCode::PacketIdentifierInUse,
                ReasonCode::NoSubscriptionExisted,
                ReasonCode::TopicFilterInvalid,
                ReasonCode::UnspecifiedError,
                ReasonCode::NotAuthorized,
                ReasonCode::ImplementationSpecificError,
            ])
            .unwrap();

            assert_eq!(packet.packet_identifier, expected_id);
            assert_eq!(packet.reason_codes, expected_codes);
        }

        // One Reason String and two User Properties, then a single reason code.
        #[tokio::test]
        #[test_log::test]
        async fn decode_properties() {
            #[rustfmt::skip]
            let packet: UnsubackPacket<'_, 1> = decode!(
                UnsubackPacket<1>,
                78,
                [
                    0xB0,
                    0x4E,
                    0x26, 0x1C,
                    0x4A,
                    0x1F, 0x00, 0x0E, b'g', b'e', b't', b' ', b'o', b'u', b't', b't', b'a', b' ', b'h', b'e', b'r', b'e',
                    0x26, 0x00, 0x07, b'i', b'm', b'a', b'g', b'i', b'n', b'e',
                    0x00, 0x0E, b'a', b'l', b'l', b' ', b't', b'h', b'e', b' ', b'p', b'e', b'o', b'p', b'l', b'e',
                    0x26, 0x00, 0x05, b'p', b'r', b'i', b'd', b'e',
                    0x00, 0x15, b'(', b'i', b'n', b' ', b't', b'h', b'e', b' ', b'n', b'a', b'm', b'e', b' ', b'o', b'f', b' ', b'l', b'o', b'v', b'e', b')',
                    0x00,
                ]
            );

            let expected_id = PacketIdentifier::new(NonZero::new(9756).unwrap());
            let expected_codes: Vec<ReasonCode, 1> =
                Vec::from_slice(&[ReasonCode::Success]).unwrap();

            assert_eq!(packet.packet_identifier, expected_id);
            assert_eq!(packet.reason_codes, expected_codes);
        }
    }
}