use crate::core::{
base_types::*,
collections::UserProperties,
error::{
CodecError, InvalidPacketHeader, InvalidPacketSize, InvalidPropertyLength,
MandatoryPropertyMissing, UnexpectedProperty,
},
properties::*,
utils::{ByteLen, Decoder, Encode, Encoder, PacketID, SizedPacket, TryDecode},
};
use bytes::{Bytes, BytesMut};
use core::mem;
use derive_builder::Builder;
/// Owned representation of a received PUBLISH packet, produced by the
/// `TryDecode` implementation for this type.
///
/// Instances are assembled through the generated `PublishRxBuilder`; its
/// `build` step runs `PublishRxBuilder::validate` and reports failures as
/// `CodecError`.
#[derive(Builder)]
#[builder(build_fn(error = "CodecError", validate = "Self::validate"))]
pub(crate) struct PublishRx {
/// DUP flag, decoded from bit 3 of the fixed header byte.
#[builder(default)]
pub(crate) dup: bool,
/// RETAIN flag, decoded from bit 0 of the fixed header byte.
#[builder(default)]
pub(crate) retain: bool,
/// Quality of service, decoded from bits 1-2 of the fixed header byte.
#[builder(default)]
pub(crate) qos: QoS,
/// Topic name; the only field without a builder default, so it must be set.
pub(crate) topic_name: UTF8String,
/// Packet identifier. The decoder reads it only for `QoS::AtLeastOnce` and
/// `QoS::ExactlyOnce`, and the builder refuses to build a packet with any
/// QoS other than `QoS::AtMostOnce` when it is missing.
#[builder(setter(strip_option), default)]
pub(crate) packet_identifier: Option<NonZero<u16>>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) payload_format_indicator: Option<PayloadFormatIndicator>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) topic_alias: Option<TopicAlias>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) message_expiry_interval: Option<MessageExpiryInterval>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) subscription_identifier: Option<SubscriptionIdentifier>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) correlation_data: Option<CorrelationData>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) response_topic: Option<ResponseTopic>,
/// Optional property; `None` unless it was found in the property section.
#[builder(setter(strip_option), default)]
pub(crate) content_type: Option<ContentType>,
/// User properties. Filled one entry at a time through the hand-written
/// `PublishRxBuilder::user_property` setter (hence `setter(custom)`).
#[builder(setter(custom), default)]
pub(crate) user_property: UserProperties,
/// Application payload, decoded after the property section.
#[builder(default)]
pub(crate) payload: Payload,
}
impl PublishRxBuilder {
    /// Build-time consistency check wired in through `build_fn(validate = ...)`.
    ///
    /// A packet whose QoS is anything other than `QoS::AtMostOnce` must have
    /// had its packet identifier set; otherwise `MandatoryPropertyMissing` is
    /// returned. An unset QoS falls back to `QoS::default()`.
    fn validate(&self) -> Result<(), CodecError> {
        let qos = self.qos.unwrap_or_default();
        if matches!(qos, QoS::AtMostOnce) || self.packet_identifier.is_some() {
            Ok(())
        } else {
            Err(MandatoryPropertyMissing.into())
        }
    }

    /// Appends one user property, creating the collection on first use.
    fn user_property(&mut self, value: UserProperty) {
        self.user_property
            .get_or_insert_with(UserProperties::new)
            .push(value);
    }
}
impl PacketID for PublishRx {
// Control packet type of PUBLISH; the decoder compares it against the upper
// nibble of the fixed header byte.
const PACKET_ID: u8 = 3;
}
impl TryDecode for PublishRx {
    type Error = CodecError;

    /// Parses a complete PUBLISH packet from `bytes`.
    ///
    /// Decoding order: fixed header byte, remaining length, topic name, packet
    /// identifier (QoS 1 and 2 only), property length, properties, payload.
    ///
    /// # Errors
    ///
    /// * `InvalidPacketHeader` - the upper nibble of the first byte is not
    ///   `Self::PACKET_ID`.
    /// * `InvalidPacketSize` - the declared remaining length exceeds the bytes
    ///   left in the buffer.
    /// * `InvalidPropertyLength` - the declared property length exceeds the
    ///   bytes left in the buffer.
    /// * `UnexpectedProperty` - a property that this packet does not accept
    ///   was encountered.
    /// * Any error raised while decoding an individual field or property, and
    ///   any error produced by the builder's validation.
    fn try_decode(bytes: Bytes) -> Result<Self, Self::Error> {
        let mut builder = PublishRxBuilder::default();
        let mut decoder = Decoder::from(bytes);

        // Fixed header: packet type in the high nibble, flags in the low one.
        let header_byte = decoder.try_decode::<u8>()?;
        let packet_type = header_byte >> 4;
        if packet_type != Self::PACKET_ID {
            return Err(InvalidPacketHeader.into());
        }

        let qos = QoS::try_from((header_byte >> 1) & 0x03)?;
        let dup = (header_byte & 0x08) != 0;
        let retain = (header_byte & 0x01) != 0;
        builder.dup(dup).retain(retain).qos(qos);

        let remaining_len = decoder.try_decode::<VarSizeInt>()?;
        if remaining_len > decoder.remaining() {
            return Err(InvalidPacketSize.into());
        }

        builder.topic_name(decoder.try_decode::<UTF8String>()?);

        // The packet identifier is on the wire only for QoS 1 and 2.
        if matches!(qos, QoS::AtLeastOnce | QoS::ExactlyOnce) {
            builder.packet_identifier(decoder.try_decode::<NonZero<u16>>()?);
        }

        let property_len = decoder.try_decode::<VarSizeInt>()?;
        if property_len > decoder.remaining() {
            return Err(InvalidPropertyLength.into());
        }

        // Properties are parsed from their own sub-decoder; the main decoder is
        // moved past the property section explicitly after the loop.
        let properties =
            Decoder::from(decoder.get_buf().split_to(property_len.value() as usize))
                .iter::<Property>();
        for property in properties {
            match property {
                Err(err) => return Err(err.into()),
                Ok(Property::PayloadFormatIndicator(val)) => {
                    builder.payload_format_indicator(val);
                }
                Ok(Property::TopicAlias(val)) => {
                    builder.topic_alias(val);
                }
                Ok(Property::MessageExpiryInterval(val)) => {
                    builder.message_expiry_interval(val);
                }
                Ok(Property::SubscriptionIdentifier(val)) => {
                    builder.subscription_identifier(val);
                }
                Ok(Property::CorrelationData(val)) => {
                    builder.correlation_data(val);
                }
                Ok(Property::ResponseTopic(val)) => {
                    builder.response_topic(val);
                }
                Ok(Property::ContentType(val)) => {
                    builder.content_type(val);
                }
                Ok(Property::UserProperty(val)) => {
                    builder.user_property(val);
                }
                Ok(_) => return Err(UnexpectedProperty.into()),
            }
        }
        decoder.advance_by(usize::from(property_len));

        let payload = decoder.try_decode::<Payload>()?;
        builder.payload(payload);
        builder.build()
    }
}
/// Borrowed representation of an outgoing PUBLISH packet, serialized through
/// its `Encode` implementation.
///
/// Instances are assembled through the generated `PublishTxBuilder`; its
/// `build` step runs `PublishTxBuilder::validate` and reports failures as
/// `CodecError`.
#[derive(Builder)]
#[builder(build_fn(error = "CodecError", validate = "Self::validate"))]
pub(crate) struct PublishTx<'a> {
/// DUP flag, written to bit 3 of the fixed header byte.
#[builder(default)]
pub(crate) dup: bool,
/// RETAIN flag, written to bit 0 of the fixed header byte.
#[builder(default)]
pub(crate) retain: bool,
/// Quality of service, written to bits 1-2 of the fixed header byte.
#[builder(default)]
pub(crate) qos: QoS,
/// Topic name; the only field without a builder default, so it must be set.
pub(crate) topic_name: UTF8StringRef<'a>,
/// Packet identifier; the builder requires it whenever `qos` is not
/// `QoS::AtMostOnce`.
#[builder(setter(strip_option), default)]
pub(crate) packet_identifier: Option<NonZero<u16>>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) payload_format_indicator: Option<PayloadFormatIndicator>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) topic_alias: Option<TopicAlias>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) message_expiry_interval: Option<MessageExpiryInterval>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) correlation_data: Option<CorrelationDataRef<'a>>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) response_topic: Option<ResponseTopicRef<'a>>,
/// Optional property; encoded only when set.
#[builder(setter(strip_option), default)]
pub(crate) content_type: Option<ContentTypeRef<'a>>,
/// User properties, encoded in insertion order. Filled one entry at a time
/// through the hand-written `PublishTxBuilder::user_property` setter.
#[builder(setter(custom), default)]
pub(crate) user_property: Vec<UserPropertyRef<'a>>,
/// Application payload; nothing is written for it when `None`.
#[builder(setter(strip_option), default)]
pub(crate) payload: Option<PayloadRef<'a>>,
}
impl<'a> PublishTxBuilder<'a> {
    /// Build-time consistency check wired in through `build_fn(validate = ...)`.
    ///
    /// A packet whose QoS is anything other than `QoS::AtMostOnce` must have
    /// had its packet identifier set; otherwise `MandatoryPropertyMissing` is
    /// returned. An unset QoS falls back to `QoS::default()`.
    fn validate(&self) -> Result<(), CodecError> {
        let qos = self.qos.unwrap_or_default();
        if matches!(qos, QoS::AtMostOnce) || self.packet_identifier.is_some() {
            Ok(())
        } else {
            Err(MandatoryPropertyMissing.into())
        }
    }

    /// Appends one user property, creating the vector on first use.
    pub(crate) fn user_property(&mut self, value: UserPropertyRef<'a>) {
        self.user_property.get_or_insert_with(Vec::new).push(value);
    }
}
impl<'a> PublishTx<'a> {
fn fixed_hdr(&self) -> u8 {
(Self::PACKET_ID << 4)
| ((self.dup as u8) << 3)
| ((self.qos as u8) << 1)
| (self.retain as u8)
}
fn property_len(&self) -> VarSizeInt {
VarSizeInt::try_from(
self.payload_format_indicator
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.topic_alias
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.message_expiry_interval
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.correlation_data
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.response_topic
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.content_type
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ self
.user_property
.iter()
.map(|val| val.byte_len())
.sum::<usize>(),
)
.unwrap()
}
fn remaining_len(&self) -> VarSizeInt {
let property_len = self.property_len();
VarSizeInt::try_from(
self.topic_name.byte_len()
+ self
.packet_identifier
.as_ref()
.map(|val| val.byte_len())
.unwrap_or(0)
+ property_len.len()
+ property_len.value() as usize
+ self.payload.as_ref().map(|val| val.byte_len()).unwrap_or(0),
)
.unwrap()
}
}
impl<'a> PacketID for PublishTx<'a> {
const PACKET_ID: u8 = 3;
}
impl<'a> SizedPacket for PublishTx<'a> {
fn packet_len(&self) -> usize {
let remaining_len = self.remaining_len();
mem::size_of::<u8>() + remaining_len.len()
+ remaining_len.value() as usize
}
}
impl<'a> Encode for PublishTx<'a> {
fn encode(&self, buf: &mut BytesMut) {
let mut encoder = Encoder::from(buf);
encoder.encode(self.fixed_hdr());
let remaining_len = self.remaining_len();
encoder.encode(remaining_len);
encoder.encode(self.topic_name);
if let Some(val) = self.packet_identifier {
encoder.encode(val);
}
encoder.encode(self.property_len());
if let Some(val) = self.payload_format_indicator {
encoder.encode(val);
}
if let Some(val) = self.topic_alias {
encoder.encode(val);
}
if let Some(val) = self.message_expiry_interval {
encoder.encode(val);
}
if let Some(val) = self.correlation_data {
encoder.encode(val);
}
if let Some(val) = self.response_topic {
encoder.encode(val);
}
if let Some(val) = self.content_type {
encoder.encode(val);
}
for val in self.user_property.iter().copied() {
encoder.encode(val);
}
if let Some(payload) = self.payload {
encoder.encode(payload);
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    // Flag nibble 0b1011: DUP set, QoS 1, RETAIN set.
    const FIXED_HDR: u8 = (PublishRx::PACKET_ID << 4) | 0x0b;

    #[rustfmt::skip]
    const PACKET: [u8; 15] = [
        FIXED_HDR,
        13,                           // remaining length
        0, 4, b't', b'e', b's', b't', // topic name "test"
        0, 13,                        // packet identifier
        0,                            // property length
        b't', b'e', b's', b't',       // payload
    ];

    /// Decoding the reference packet yields the expected flags, identifier
    /// and payload.
    #[test]
    fn from_bytes_0() {
        let decoded = PublishRx::try_decode(Bytes::from_static(&PACKET)).unwrap();

        assert!(decoded.dup);
        assert!(decoded.retain);
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_identifier.unwrap(), 13);
        assert_eq!(decoded.payload, Payload(Bytes::from_static(b"test")));
    }

    /// Encoding the equivalent outgoing packet reproduces the reference bytes.
    #[test]
    fn to_bytes_0() {
        let mut builder = PublishTxBuilder::default();
        builder
            .dup(true)
            .qos(QoS::AtLeastOnce)
            .retain(true)
            .packet_identifier(NonZero::try_from(13).unwrap())
            .topic_name(UTF8StringRef("test"))
            .payload(PayloadRef(b"test"));
        let packet = builder.build().unwrap();

        let mut buf = BytesMut::new();
        packet.encode(&mut buf);

        let encoded = buf.split().freeze();
        assert_eq!(&encoded[..], &PACKET);
    }
}