use bytes::BytesMut;
use std::io::Write;
use crate::{
ByteArray, DecodeError, DecodePacket, EncodeError, EncodePacket, FixedHeader, Packet, PacketId,
PacketType, PubTopic, QoS, VarIntError,
};
/// An MQTT PUBLISH packet: fixed-header flags (DUP, QoS, RETAIN), the topic
/// name, an optional packet identifier, and the application payload.
// NOTE(review): the lint is presumably allowed because the type name repeats
// its module's name (e.g. a `publish` module) — confirm against the module path.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PublishPacket {
    /// DUP flag; `set_dup` refuses to set it while `qos` is `QoS::AtMostOnce`.
    dup: bool,
    /// Quality-of-service level carried in the fixed header.
    qos: QoS,
    /// RETAIN flag.
    retain: bool,
    /// Topic name the payload is published to.
    topic: PubTopic,
    /// Packet identifier; only written to / read from the wire when `qos` is
    /// not `QoS::AtMostOnce`, and held at 0 otherwise.
    packet_id: PacketId,
    /// Application payload bytes.
    msg: BytesMut,
}
impl PublishPacket {
    /// Creates a PUBLISH packet for `topic` with the given `qos` and payload `msg`.
    ///
    /// The DUP and RETAIN flags start cleared and the packet id starts at 0. For
    /// QoS 1 and QoS 2 the caller is expected to assign a non-zero id with
    /// [`Self::set_packet_id`]; the decoder rejects a zero id at those levels.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `PubTopic::new` when `topic` is not a valid
    /// publish topic.
    pub fn new(topic: &str, qos: QoS, msg: &[u8]) -> Result<Self, EncodeError> {
        let topic = PubTopic::new(topic)?;
        Ok(Self {
            qos,
            dup: false,
            retain: false,
            topic,
            packet_id: PacketId::new(0),
            msg: BytesMut::from(msg),
        })
    }

    /// Appends `msg_parts` to the end of the payload.
    pub fn append(&mut self, msg_parts: &[u8]) {
        self.msg.extend_from_slice(msg_parts);
    }

    /// Sets the RETAIN flag.
    pub fn set_retain(&mut self, retain: bool) -> &mut Self {
        self.retain = retain;
        self
    }

    /// Returns the RETAIN flag.
    #[must_use]
    pub const fn retain(&self) -> bool {
        self.retain
    }

    /// Sets the DUP flag.
    ///
    /// # Errors
    ///
    /// Returns `EncodeError::InvalidPacketType` when `dup` is `true` and the QoS
    /// is `QoS::AtMostOnce`: the DUP flag must be 0 for QoS 0 messages.
    pub fn set_dup(&mut self, dup: bool) -> Result<&mut Self, EncodeError> {
        if dup && self.qos == QoS::AtMostOnce {
            return Err(EncodeError::InvalidPacketType);
        }
        self.dup = dup;
        Ok(self)
    }

    /// Returns the DUP flag.
    #[must_use]
    pub const fn dup(&self) -> bool {
        self.dup
    }

    /// Sets the QoS level.
    ///
    /// Switching to `QoS::AtMostOnce` also resets the packet id to 0 and clears
    /// the DUP flag. QoS 0 packets carry no packet identifier, and DUP must be 0
    /// for them. Without this, the packet could be left in a state that `set_dup`
    /// forbids and the decoder rejects.
    pub fn set_qos(&mut self, qos: QoS) -> &mut Self {
        if qos == QoS::AtMostOnce {
            self.packet_id = PacketId::new(0);
            self.dup = false;
        }
        self.qos = qos;
        self
    }

    /// Returns the QoS level.
    #[must_use]
    pub const fn qos(&self) -> QoS {
        self.qos
    }

    /// Sets the packet identifier. It is only encoded when the QoS is above 0.
    pub fn set_packet_id(&mut self, packet_id: PacketId) -> &mut Self {
        self.packet_id = packet_id;
        self
    }

    /// Returns the packet identifier (0 for QoS 0 packets).
    #[must_use]
    pub const fn packet_id(&self) -> PacketId {
        self.packet_id
    }

    /// Replaces the topic name.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `PubTopic::new` when `topic` is not a valid
    /// publish topic; the current topic is left unchanged in that case.
    pub fn set_topic(&mut self, topic: &str) -> Result<&mut Self, EncodeError> {
        self.topic = PubTopic::new(topic)?;
        Ok(self)
    }

    /// Returns the topic name.
    #[must_use]
    pub fn topic(&self) -> &str {
        self.topic.as_ref()
    }

    /// Returns the payload bytes.
    #[must_use]
    pub fn message(&self) -> &[u8] {
        &self.msg
    }

    /// Builds the fixed header for this packet.
    ///
    /// Remaining length = encoded topic bytes + payload length, plus the packet-id
    /// bytes when the QoS is above 0.
    ///
    /// Errors from `FixedHeader::new` are passed through; presumably these occur
    /// when the remaining length cannot be encoded as a variable byte integer —
    /// TODO confirm.
    fn get_fixed_header(&self) -> Result<FixedHeader, VarIntError> {
        let mut remaining_length = self.topic.bytes() + self.msg.len();
        if self.qos != QoS::AtMostOnce {
            remaining_length += PacketId::bytes();
        }
        let packet_type = PacketType::Publish {
            dup: self.dup,
            retain: self.retain,
            qos: self.qos,
        };
        FixedHeader::new(packet_type, remaining_length)
    }
}
impl DecodePacket for PublishPacket {
    /// Decodes a PUBLISH packet from `ba`. The packet consists of the fixed
    /// header, the topic name, the packet id (only when the QoS is above 0), and
    /// the payload. The payload length is whatever the fixed header's remaining
    /// length leaves after the topic and packet id.
    ///
    /// # Errors
    ///
    /// - `DecodeError::InvalidPacketType` if the fixed header is not a PUBLISH header.
    /// - `DecodeError::InvalidPacketFlags` if DUP is set on a QoS 0 packet.
    /// - `DecodeError::InvalidPacketId` if a QoS 1/2 packet carries packet id 0.
    /// - `DecodeError::InvalidRemainingLength` if the remaining length is too small
    ///   to hold the topic and packet id.
    /// - Any error from decoding the header, topic, packet id or payload bytes.
    fn decode(ba: &mut ByteArray) -> Result<Self, DecodeError> {
        let fixed_header = FixedHeader::decode(ba)?;
        let PacketType::Publish { dup, qos, retain } = fixed_header.packet_type() else {
            return Err(DecodeError::InvalidPacketType);
        };
        // MQTT-3.3.1-2: DUP must be 0 for QoS 0 messages. For QoS 1 and QoS 2,
        // DUP=1 marks a legitimate redelivery and must be accepted.
        if dup && qos == QoS::AtMostOnce {
            return Err(DecodeError::InvalidPacketFlags);
        }

        let topic = PubTopic::decode(ba)?;
        // Per-packet, peer-controlled data: keep it out of default log levels.
        log::trace!("topic: {:?}", &topic);

        let packet_id = if qos == QoS::AtMostOnce {
            PacketId::new(0)
        } else {
            let packet_id = PacketId::decode(ba)?;
            if packet_id.value() == 0 {
                return Err(DecodeError::InvalidPacketId);
            }
            packet_id
        };

        // Whatever the remaining length does not spend on the topic (and packet
        // id) is payload; an underflow means the header lied about the length.
        let Some(mut msg_len) = fixed_header.remaining_length().checked_sub(topic.bytes()) else {
            log::debug!(
                "remaining length: {}, topic bytes: {}",
                fixed_header.remaining_length(),
                topic.bytes()
            );
            return Err(DecodeError::InvalidRemainingLength);
        };
        if qos != QoS::AtMostOnce {
            msg_len = msg_len
                .checked_sub(PacketId::bytes())
                .ok_or(DecodeError::InvalidRemainingLength)?;
        }

        let msg = BytesMut::from(ba.read_bytes(msg_len)?);
        Ok(Self {
            dup,
            qos,
            retain,
            topic,
            packet_id,
            msg,
        })
    }
}
impl EncodePacket for PublishPacket {
    /// Appends the wire form of this packet to `v` and returns the number of
    /// bytes written. The order is fixed header, topic name, packet id (QoS 1/2
    /// only), then the raw payload.
    ///
    /// # Errors
    ///
    /// Propagates failures from building or encoding the fixed header, the
    /// topic, the packet id, or writing the payload.
    fn encode(&self, v: &mut Vec<u8>) -> Result<usize, EncodeError> {
        let start = v.len();

        let header = self.get_fixed_header()?;
        header.encode(v)?;
        self.topic.encode(v)?;

        // QoS 0 packets have no packet identifier on the wire.
        let carries_packet_id = self.qos != QoS::AtMostOnce;
        if carries_packet_id {
            self.packet_id.encode(v)?;
        }

        v.write_all(&self.msg)?;

        let written = v.len() - start;
        Ok(written)
    }
}
impl Packet for PublishPacket {
    /// Returns the PUBLISH packet type carrying this packet's DUP, QoS and
    /// RETAIN flags.
    fn packet_type(&self) -> PacketType {
        let (dup, qos, retain) = (self.dup, self.qos, self.retain);
        PacketType::Publish { dup, retain, qos }
    }

    /// Returns the total encoded size: the fixed header plus its remaining length.
    ///
    /// # Errors
    ///
    /// Propagates the error from building the fixed header.
    fn bytes(&self) -> Result<usize, VarIntError> {
        self.get_fixed_header()
            .map(|header| header.bytes() + header.remaining_length())
    }
}