use heapless::Vec;
use crate::{
eio::Write,
header::{FixedHeader, PacketType},
io::write::Writable,
packet::{Packet, TxError, TxPacket},
types::{PacketIdentifier, SubscriptionFilter, TooLargeToEncode, VarByteInt},
v5::property::SubscriptionIdentifier,
};
/// An outgoing MQTT v5 SUBSCRIBE packet.
///
/// `MAX_TOPIC_FILTERS` is the compile-time capacity of the fixed-size
/// (`heapless`) vector holding the subscription filters; `'p` is the lifetime
/// of the data borrowed by those filters.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SubscribePacket<'p, const MAX_TOPIC_FILTERS: usize> {
/// Packet identifier, written directly after the fixed header.
packet_identifier: PacketIdentifier,
/// Optional subscription identifier property; when present it is the only
/// content of the properties section written by this packet.
subscription_identifier: Option<SubscriptionIdentifier>,
/// The subscription filters, written last as the packet body.
subscribe_filters: Vec<SubscriptionFilter<'p>, MAX_TOPIC_FILTERS>,
}
impl<const MAX_TOPIC_FILTERS: usize> Packet for SubscribePacket<'_, MAX_TOPIC_FILTERS> {
/// Packet type placed into the fixed header when this packet is sent.
const PACKET_TYPE: PacketType = PacketType::Subscribe;
}
impl<const MAX_TOPIC_FILTERS: usize> TxPacket for SubscribePacket<'_, MAX_TOPIC_FILTERS> {
    /// Returns the remaining length of the packet, i.e. the size of everything
    /// that follows the fixed header.
    fn remaining_len(&self) -> VarByteInt {
        let encoded = self.remaining_len_raw();
        // SAFETY: relies on the invariant established by `SubscribePacket::new`,
        // which either verifies that `remaining_len_raw` succeeds or skips the
        // check only for capacities it treats as always encodable
        // (`GUARANTEED_ENCODABLE_TOPIC_FILTERS`). The fields are private, so the
        // value is not expected to change after construction.
        unsafe { encoded.unwrap_unchecked() }
    }

    /// Serializes the packet into `write`.
    ///
    /// Write order: fixed header (flags `0x02`), packet identifier, properties
    /// length, subscription identifier property, subscription filters.
    ///
    /// # Errors
    ///
    /// Propagates the first error produced by any of the individual writes.
    async fn send<W: Write>(&self, write: &mut W) -> Result<(), TxError<W::Error>> {
        let fixed_header = FixedHeader::new(Self::PACKET_TYPE, 0x02, self.remaining_len());
        fixed_header.write(write).await?;

        // Variable header: packet identifier followed by the properties section.
        self.packet_identifier.write(write).await?;
        let properties_length = self.properties_length();
        properties_length.write(write).await?;
        self.subscription_identifier.write(write).await?;

        // Body: the subscription filters.
        self.subscribe_filters.write(write).await?;

        Ok(())
    }
}
impl<'p, const MAX_TOPIC_FILTERS: usize> SubscribePacket<'p, MAX_TOPIC_FILTERS> {
    /// Creates a new SUBSCRIBE packet.
    ///
    /// * `packet_identifier` - identifier written right after the fixed header.
    /// * `subscription_identifier` - optional value converted into the
    ///   subscription identifier property.
    /// * `subscribe_filters` - the subscription filters forming the packet body.
    ///
    /// # Errors
    ///
    /// Returns [`TooLargeToEncode`] if the packet's remaining length cannot be
    /// represented as a [`VarByteInt`]. This is only checked when
    /// `MAX_TOPIC_FILTERS` exceeds the number of filters that is treated as
    /// always encodable.
    pub fn new(
        packet_identifier: PacketIdentifier,
        subscription_identifier: Option<VarByteInt>,
        subscribe_filters: Vec<SubscriptionFilter<'p>, MAX_TOPIC_FILTERS>,
    ) -> Result<Self, TooLargeToEncode> {
        // Up to this many filters the remaining length is treated as always
        // fitting into a variable byte integer.
        // NOTE(review): presumably derived from the maximum encoded size of a
        // single filter and the variable byte integer limit - confirm.
        const GUARANTEED_ENCODABLE_TOPIC_FILTERS: usize = 4095;

        let p = Self {
            packet_identifier,
            subscription_identifier: subscription_identifier.map(Into::into),
            subscribe_filters,
        };

        // The comparison only involves constants, so the length check is only
        // needed for capacities above the guaranteed limit.
        if MAX_TOPIC_FILTERS > GUARANTEED_ENCODABLE_TOPIC_FILTERS {
            p.remaining_len_raw()?;
        }

        Ok(p)
    }

    /// Computes the remaining length: packet identifier, properties section
    /// (its length prefix plus its content) and the subscription filters.
    ///
    /// # Errors
    ///
    /// Returns [`TooLargeToEncode`] if the total does not fit into a
    /// [`VarByteInt`].
    fn remaining_len_raw(&self) -> Result<VarByteInt, TooLargeToEncode> {
        let variable_header_length = self.packet_identifier.written_len();

        let properties_length = self.properties_length();
        let total_properties_length = properties_length.size() + properties_length.written_len();

        let body_length = self.subscribe_filters.written_len();

        let total_length = variable_header_length + total_properties_length + body_length;

        // A plain `as u32` cast would silently wrap totals above `u32::MAX`
        // and could make an oversized packet look encodable. Saturate instead:
        // `u32::MAX` is expected to be rejected by `VarByteInt::try_from`, so
        // such totals are reported as `TooLargeToEncode`.
        let total_length = u32::try_from(total_length).unwrap_or(u32::MAX);

        VarByteInt::try_from(total_length)
    }

    /// Returns the length of the properties section, which consists solely of
    /// the optional subscription identifier property.
    pub fn properties_length(&self) -> VarByteInt {
        let len = self.subscription_identifier.written_len();
        // NOTE(review): assumes the encoded subscription identifier property is
        // only a few bytes long, so the cast and the unchecked constructor
        // cannot exceed the variable byte integer range - confirm.
        VarByteInt::new_unchecked(len as u32)
    }
}
// Encoding tests for `SubscribePacket`. Each test builds a packet and passes
// it, together with a hand-written byte array, to the project's `encode!`
// macro (presumably asserting that the packet serializes to exactly those
// bytes - see `test::tx::encode`).
#[cfg(test)]
mod unit {
use core::num::NonZero;
use heapless::Vec;
use crate::{
client::options::{RetainHandling, SubscriptionOptions},
test::tx::encode,
types::{MqttString, PacketIdentifier, QoS, SubscriptionFilter, TopicFilter, VarByteInt},
v5::packet::SubscribePacket,
};
// Two filters, no subscription identifier: exercises the body encoding with
// an empty properties section.
#[tokio::test]
#[test_log::test]
async fn encode_payload() {
let mut topics = Vec::new();
topics
.push(SubscriptionFilter::new(
TopicFilter::new(MqttString::try_from("test/hello").unwrap()).unwrap(),
&SubscriptionOptions {
retain_handling: RetainHandling::AlwaysSend,
retain_as_published: false,
no_local: true,
qos: QoS::AtMostOnce,
subscription_identifier: None,
},
))
.unwrap();
topics
.push(SubscriptionFilter::new(
TopicFilter::new(MqttString::try_from("asdfjklo/#").unwrap()).unwrap(),
&SubscriptionOptions {
retain_handling: RetainHandling::NeverSend,
retain_as_published: true,
no_local: false,
qos: QoS::ExactlyOnce,
subscription_identifier: None,
},
))
.unwrap();
let packet: SubscribePacket<'_, 2> = SubscribePacket::new(
PacketIdentifier::new(NonZero::new(23197).unwrap()),
None,
topics,
)
.unwrap();
// Expected layout (per MQTT v5 SUBSCRIBE framing):
//   0x82                fixed header byte (flags 0x02 as passed by `send`)
//   0x1D                remaining length = 29
//   0x5A 0x9D           packet identifier 23197
//   0x00                properties length 0
//   0x00 0x0A "test/hello"  first topic filter, then options byte 0x04
//   0x00 0x0A "asdfjklo/#"  second topic filter, then options byte 0x2A
#[rustfmt::skip]
encode!(packet, [
0x82, 0x1D, 0x5A, 0x9D, 0x00, 0x00, 0x0A, b't', b'e', b's', b't', b'/', b'h', b'e', b'l', b'l', b'o', 0x04, 0x00, 0x0A, b'a', b's', b'd', b'f', b'j', b'k', b'l', b'o', b'/', b'#', 0x2A, ]
);
}
// One filter plus a packet-level subscription identifier: exercises the
// properties section.
#[tokio::test]
#[test_log::test]
async fn encode_properties() {
let mut topics = Vec::new();
topics
.push(SubscriptionFilter::new(
TopicFilter::new(MqttString::try_from("abc/+/y").unwrap()).unwrap(),
&SubscriptionOptions {
retain_handling: RetainHandling::SendIfNotSubscribedBefore,
retain_as_published: true,
no_local: false,
qos: QoS::AtMostOnce,
subscription_identifier: Some(VarByteInt::from(23459u16)),
},
))
.unwrap();
let packet: SubscribePacket<'_, 10> = SubscribePacket::new(
PacketIdentifier::new(NonZero::new(23197).unwrap()),
Some(VarByteInt::new(87986078u32).unwrap()),
topics,
)
.unwrap();
// Expected layout (per MQTT v5 SUBSCRIBE framing):
//   0x82                     fixed header byte
//   0x12                     remaining length = 18
//   0x5A 0x9D                packet identifier 23197
//   0x05                     properties length 5
//   0x0B 0x9E 0x9F 0xFA 0x29 subscription identifier property, value 87986078
//   0x00 0x07 "abc/+/y"      topic filter, then options byte 0x18
// Only the identifier passed to `SubscribePacket::new` shows up in the
// expected bytes; the `subscription_identifier` set in the filter's
// `SubscriptionOptions` (23459) does not.
#[rustfmt::skip]
encode!(packet, [
0x82, 0x12, 0x5A, 0x9D, 0x05, 0x0B, 0x9E, 0x9F, 0xFA, 0x29, 0x00, 0x07, b'a', b'b', b'c', b'/', b'+', b'/', b'y', 0x18,
]
);
}
}