use crate::encoding::{decode_string, encode_string};
use crate::error::{MqttError, Result};
use crate::flags::PublishFlags;
use crate::packet::{FixedHeader, MqttPacket, PacketType};
use crate::prelude::{String, ToString, Vec};
use crate::protocol::v5::properties::{Properties, PropertyId, PropertyValue};
use crate::types::ProtocolVersion;
use crate::QoS;
use bytes::{Buf, BufMut, Bytes};
/// An MQTT PUBLISH packet.
///
/// `protocol_version` selects the wire layout used by the encoder and the
/// decoder in this file: the properties block is only written and read when
/// it equals 5.
#[derive(Debug, Clone)]
pub struct PublishPacket {
/// Topic name; the first element written to and read from the packet body.
pub topic_name: String,
/// Packet identifier. Encoding fails when this is `None` and `qos` is not
/// `QoS::AtMostOnce`; decoding leaves it `None` for `QoS::AtMostOnce`.
pub packet_id: Option<u16>,
/// Application payload: the bytes that follow all other body fields.
pub payload: Bytes,
/// Quality-of-service level, carried in the fixed-header flags.
pub qos: QoS,
/// RETAIN flag, carried in the fixed-header flags.
pub retain: bool,
/// DUP flag, carried in the fixed-header flags. The decoder rejects a packet
/// that sets it together with `QoS::AtMostOnce`.
pub dup: bool,
/// Packet properties; only encoded and decoded when `protocol_version` is 5.
pub properties: Properties,
/// Protocol level as a raw byte: `new` sets 5 and `new_v311` sets 4.
pub protocol_version: u8,
/// Not part of the wire encoding in this file: the constructors and the
/// decoder always set it to `None`.
/// NOTE(review): presumably a transport-level stream identifier assigned by
/// callers — confirm against the code that sets it.
pub stream_id: Option<u64>,
}
impl PublishPacket {
    /// Creates a packet with `protocol_version` set to 5.
    ///
    /// The packet starts with no packet identifier, RETAIN and DUP cleared,
    /// default properties and no stream id.
    #[must_use]
    pub fn new(topic_name: impl Into<String>, payload: impl Into<Bytes>, qos: QoS) -> Self {
        let topic_name = topic_name.into();
        let payload = payload.into();
        Self {
            topic_name,
            packet_id: None,
            payload,
            qos,
            retain: false,
            dup: false,
            properties: Properties::default(),
            protocol_version: 5,
            stream_id: None,
        }
    }

    /// Creates a packet exactly like [`PublishPacket::new`] but with
    /// `protocol_version` set to 4.
    #[must_use]
    pub fn new_v311(topic_name: impl Into<String>, payload: impl Into<Bytes>, qos: QoS) -> Self {
        Self {
            protocol_version: 4,
            ..Self::new(topic_name, payload, qos)
        }
    }

    /// Stores `id` as the packet identifier.
    ///
    /// The identifier is ignored (the packet is returned unchanged) when the
    /// QoS is `QoS::AtMostOnce`.
    #[must_use]
    pub fn with_packet_id(self, id: u16) -> Self {
        if self.qos == QoS::AtMostOnce {
            return self;
        }
        Self {
            packet_id: Some(id),
            ..self
        }
    }

    /// Sets the RETAIN flag.
    #[must_use]
    pub fn with_retain(self, retain: bool) -> Self {
        Self { retain, ..self }
    }

    /// Sets the DUP flag.
    #[must_use]
    pub fn with_dup(self, dup: bool) -> Self {
        Self { dup, ..self }
    }

    /// Records the payload format indicator in the packet properties.
    #[must_use]
    pub fn with_payload_format_indicator(self, is_utf8: bool) -> Self {
        let mut packet = self;
        packet.properties.set_payload_format_indicator(is_utf8);
        packet
    }

    /// Records the message expiry interval (`seconds`) in the packet properties.
    #[must_use]
    pub fn with_message_expiry_interval(self, seconds: u32) -> Self {
        let mut packet = self;
        packet.properties.set_message_expiry_interval(seconds);
        packet
    }

    /// Records the topic alias in the packet properties.
    #[must_use]
    pub fn with_topic_alias(self, alias: u16) -> Self {
        let mut packet = self;
        packet.properties.set_topic_alias(alias);
        packet
    }

    /// Records the response topic in the packet properties.
    #[must_use]
    pub fn with_response_topic(self, topic: String) -> Self {
        let mut packet = self;
        packet.properties.set_response_topic(topic);
        packet
    }

    /// Records the correlation data in the packet properties.
    #[must_use]
    pub fn with_correlation_data(self, data: Vec<u8>) -> Self {
        let mut packet = self;
        packet.properties.set_correlation_data(data.into());
        packet
    }

    /// Adds a user property (`key`, `value`) to the packet properties.
    #[must_use]
    pub fn with_user_property(self, key: String, value: String) -> Self {
        let mut packet = self;
        packet.properties.add_user_property(key, value);
        packet
    }

    /// Records the subscription identifier in the packet properties.
    #[must_use]
    pub fn with_subscription_identifier(self, id: u32) -> Self {
        let mut packet = self;
        packet.properties.set_subscription_identifier(id);
        packet
    }

    /// Records the content type in the packet properties.
    #[must_use]
    pub fn with_content_type(self, content_type: String) -> Self {
        let mut packet = self;
        packet.properties.set_content_type(content_type);
        packet
    }

    /// Returns the topic alias property, if one is stored as a two-byte integer.
    #[must_use]
    pub fn topic_alias(&self) -> Option<u16> {
        match self.properties.get(PropertyId::TopicAlias) {
            Some(PropertyValue::TwoByteInteger(alias)) => Some(*alias),
            _ => None,
        }
    }

    /// Returns the message expiry interval property, if one is stored as a
    /// four-byte integer.
    #[must_use]
    pub fn message_expiry_interval(&self) -> Option<u32> {
        match self.properties.get(PropertyId::MessageExpiryInterval) {
            Some(PropertyValue::FourByteInteger(interval)) => Some(*interval),
            _ => None,
        }
    }

    /// Computes the size in bytes of the packet body.
    ///
    /// The total is: 2 plus the topic length (presumably the two-byte length
    /// prefix written by `encode_string`), 2 more when the QoS is not
    /// `QoS::AtMostOnce`, the encoded length of the properties when
    /// `protocol_version` is 5, and the payload length.
    #[must_use]
    pub fn body_encoded_size(&self) -> usize {
        let topic_len = 2 + self.topic_name.len();
        let packet_id_len = if self.qos == QoS::AtMostOnce { 0 } else { 2 };
        let properties_len = if self.protocol_version == 5 {
            self.properties.encoded_len()
        } else {
            0
        };
        topic_len + packet_id_len + properties_len + self.payload.len()
    }

    /// Writes the packet body into `buf`.
    ///
    /// The order is: topic name, packet identifier (only when the QoS is not
    /// `QoS::AtMostOnce`), properties via `Properties::encode_direct` (only
    /// when `protocol_version` is 5), then the payload bytes.
    ///
    /// # Errors
    ///
    /// Returns `MqttError::MalformedPacket` when the QoS requires a packet
    /// identifier and `packet_id` is `None`; the topic name has already been
    /// written to `buf` at that point. Errors from `encode_string` and
    /// `Properties::encode_direct` are propagated.
    pub fn encode_body_direct<B: BufMut>(&self, buf: &mut B) -> Result<()> {
        encode_string(buf, &self.topic_name)?;

        if self.qos != QoS::AtMostOnce {
            match self.packet_id {
                Some(id) => buf.put_u16(id),
                None => {
                    return Err(MqttError::MalformedPacket(
                        "Packet ID required for QoS > 0".to_string(),
                    ));
                }
            }
        }

        if self.protocol_version == 5 {
            self.properties.encode_direct(buf)?;
        }

        buf.put_slice(&self.payload);
        Ok(())
    }
}
impl MqttPacket for PublishPacket {
    /// Always `PacketType::Publish`.
    fn packet_type(&self) -> PacketType {
        PacketType::Publish
    }

    /// Builds the fixed-header flag bits from `dup`, `qos` and `retain`.
    fn flags(&self) -> u8 {
        // Same order as the wire layout is assembled: DUP first, QoS merged in
        // by `PublishFlags::with_qos`, RETAIN last.
        let dup_bits = if self.dup { PublishFlags::Dup as u8 } else { 0u8 };
        let dup_and_qos = PublishFlags::with_qos(dup_bits, self.qos as u8);
        if self.retain {
            dup_and_qos | PublishFlags::Retain as u8
        } else {
            dup_and_qos
        }
    }

    /// Writes the packet body into `buf`.
    ///
    /// The order is: topic name, packet identifier (only when the QoS is not
    /// `QoS::AtMostOnce`), properties via `Properties::encode` (only when
    /// `protocol_version` is 5), then the payload bytes.
    ///
    /// # Errors
    ///
    /// Returns `MqttError::MalformedPacket` when the QoS requires a packet
    /// identifier and `packet_id` is `None`; the topic name has already been
    /// written to `buf` at that point. Errors from `encode_string` and
    /// `Properties::encode` are propagated.
    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()> {
        encode_string(buf, &self.topic_name)?;

        if self.qos != QoS::AtMostOnce {
            match self.packet_id {
                Some(id) => buf.put_u16(id),
                None => {
                    return Err(MqttError::MalformedPacket(
                        "Packet ID required for QoS > 0".to_string(),
                    ));
                }
            }
        }

        if self.protocol_version == 5 {
            self.properties.encode(buf)?;
        }

        buf.put_slice(&self.payload);
        Ok(())
    }

    /// Decodes a packet body using protocol version 5.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`PublishPacket::decode_body_with_version`].
    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self> {
        let protocol_version = 5;
        Self::decode_body_with_version(buf, fixed_header, protocol_version)
    }
}
impl PublishPacket {
    /// Decodes a PUBLISH packet body from `buf` for the given protocol version.
    ///
    /// DUP, QoS and RETAIN come from `fixed_header.flags`. The body is read as:
    /// topic name, packet identifier (only when the QoS is not
    /// `QoS::AtMostOnce`), properties (only when `protocol_version` is 5,
    /// otherwise defaults are used), and finally every remaining byte as the
    /// payload. `stream_id` is always `None` on the returned packet.
    ///
    /// NOTE(review): a packet identifier of 0 is accepted as-is here — confirm
    /// whether a non-zero check is expected at this layer or is done by callers.
    ///
    /// # Errors
    ///
    /// * `MqttError::UnsupportedProtocolVersion` when `protocol_version` does
    ///   not convert into a `ProtocolVersion`.
    /// * `MqttError::InvalidQoS` when the QoS bits are not 0, 1 or 2.
    /// * `MqttError::MalformedPacket` when DUP is set with QoS 0, or when fewer
    ///   than two bytes remain where the packet identifier is expected.
    /// * Any error from `decode_string` or `Properties::decode`.
    pub fn decode_body_with_version<B: Buf>(
        buf: &mut B,
        fixed_header: &FixedHeader,
        protocol_version: u8,
    ) -> Result<Self> {
        if ProtocolVersion::try_from(protocol_version).is_err() {
            return Err(MqttError::UnsupportedProtocolVersion);
        }

        let raw_flags = fixed_header.flags;
        let flag_set = PublishFlags::decompose(raw_flags);
        let dup = flag_set.contains(&PublishFlags::Dup);
        let retain = flag_set.contains(&PublishFlags::Retain);

        let qos = match PublishFlags::extract_qos(raw_flags) {
            0 => QoS::AtMostOnce,
            1 => QoS::AtLeastOnce,
            2 => QoS::ExactlyOnce,
            invalid => return Err(MqttError::InvalidQoS(invalid)),
        };

        if qos == QoS::AtMostOnce && dup {
            return Err(MqttError::MalformedPacket(
                "DUP flag must be 0 when QoS is 0 [MQTT-3.3.1-2]".to_string(),
            ));
        }

        let topic_name = decode_string(buf)?;

        let mut packet_id = None;
        if qos != QoS::AtMostOnce {
            // Check the length first so `get_u16` is never called on a short buffer.
            if buf.remaining() < 2 {
                return Err(MqttError::MalformedPacket(
                    "Missing packet identifier".to_string(),
                ));
            }
            packet_id = Some(buf.get_u16());
        }

        let properties = match protocol_version {
            5 => Properties::decode(buf)?,
            _ => Properties::default(),
        };

        // Whatever is left after the header fields is the payload.
        let payload_len = buf.remaining();
        let payload = buf.copy_to_bytes(payload_len);

        Ok(Self {
            topic_name,
            packet_id,
            payload,
            qos,
            retain,
            dup,
            properties,
            protocol_version,
            stream_id: None,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for building, encoding and decoding `PublishPacket`.

    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_publish_packet_qos0() {
        let packet = PublishPacket::new("test/topic", &b"Hello, MQTT!"[..], QoS::AtMostOnce);
        assert_eq!(packet.topic_name, "test/topic");
        assert_eq!(&packet.payload[..], b"Hello, MQTT!");
        assert_eq!(packet.qos, QoS::AtMostOnce);
        assert!(packet.packet_id.is_none());
        assert!(!packet.retain);
        assert!(!packet.dup);
    }

    #[test]
    fn test_publish_packet_qos1() {
        let packet =
            PublishPacket::new("test/topic", &b"Hello"[..], QoS::AtLeastOnce).with_packet_id(123);
        assert_eq!(packet.qos, QoS::AtLeastOnce);
        assert_eq!(packet.packet_id, Some(123));
    }

    #[test]
    fn test_publish_packet_id_ignored_for_qos0() {
        // `with_packet_id` must not attach an identifier to a QoS 0 packet.
        let packet =
            PublishPacket::new("test/topic", &b"Hello"[..], QoS::AtMostOnce).with_packet_id(7);
        assert!(packet.packet_id.is_none());
    }

    #[test]
    fn test_publish_packet_with_properties() {
        let packet = PublishPacket::new("test/topic", &b"data"[..], QoS::AtMostOnce)
            .with_retain(true)
            .with_payload_format_indicator(true)
            .with_message_expiry_interval(3600)
            .with_response_topic("response/topic".to_string())
            .with_user_property("key".to_string(), "value".to_string());
        assert!(packet.retain);
        assert!(packet
            .properties
            .contains(PropertyId::PayloadFormatIndicator));
        assert!(packet
            .properties
            .contains(PropertyId::MessageExpiryInterval));
        assert!(packet.properties.contains(PropertyId::ResponseTopic));
        assert!(packet.properties.contains(PropertyId::UserProperty));
    }

    #[test]
    fn test_publish_flags() {
        let packet = PublishPacket::new("topic", &b"data"[..], QoS::AtMostOnce);
        assert_eq!(packet.flags(), 0x00);

        let packet = PublishPacket::new("topic", &b"data"[..], QoS::AtLeastOnce).with_retain(true);
        assert_eq!(packet.flags(), 0x03);

        let packet = PublishPacket::new("topic", &b"data"[..], QoS::ExactlyOnce).with_dup(true);
        assert_eq!(packet.flags(), 0x0C);

        let packet = PublishPacket::new("topic", &b"data"[..], QoS::ExactlyOnce)
            .with_dup(true)
            .with_retain(true);
        assert_eq!(packet.flags(), 0x0D);
    }

    #[test]
    fn test_publish_encode_decode_qos0() {
        let packet = PublishPacket::new("sensor/temperature", &b"23.5"[..], QoS::AtMostOnce)
            .with_retain(true);
        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        assert_eq!(fixed_header.packet_type, PacketType::Publish);
        assert_eq!(
            fixed_header.flags & crate::flags::PublishFlags::Retain as u8,
            crate::flags::PublishFlags::Retain as u8
        );

        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "sensor/temperature");
        assert_eq!(&decoded.payload[..], b"23.5");
        assert_eq!(decoded.qos, QoS::AtMostOnce);
        assert!(decoded.retain);
        assert!(decoded.packet_id.is_none());
    }

    #[test]
    fn test_publish_encode_decode_qos1() {
        let packet = PublishPacket::new("test/qos1", &b"QoS 1 message"[..], QoS::AtLeastOnce)
            .with_packet_id(456);
        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.topic_name, "test/qos1");
        assert_eq!(&decoded.payload[..], b"QoS 1 message");
        assert_eq!(decoded.qos, QoS::AtLeastOnce);
        assert_eq!(decoded.packet_id, Some(456));
    }

    #[test]
    fn test_publish_encode_decode_with_properties() {
        let packet = PublishPacket::new("test/props", &b"data"[..], QoS::ExactlyOnce)
            .with_packet_id(789)
            .with_message_expiry_interval(7200)
            .with_content_type("application/json".to_string());
        let mut buf = BytesMut::new();
        packet.encode(&mut buf).unwrap();

        let fixed_header = FixedHeader::decode(&mut buf).unwrap();
        let decoded = PublishPacket::decode_body(&mut buf, &fixed_header).unwrap();
        assert_eq!(decoded.qos, QoS::ExactlyOnce);
        assert_eq!(decoded.packet_id, Some(789));

        // A wrong variant must fail the test rather than silently skip the
        // value check, so every non-matching case panics.
        match decoded.properties.get(PropertyId::MessageExpiryInterval) {
            Some(PropertyValue::FourByteInteger(val)) => assert_eq!(*val, 7200),
            _ => panic!("message expiry interval missing or not a four-byte integer"),
        }
        assert_eq!(decoded.message_expiry_interval(), Some(7200));

        match decoded.properties.get(PropertyId::ContentType) {
            Some(PropertyValue::Utf8String(val)) => assert_eq!(val, "application/json"),
            _ => panic!("content type missing or not a UTF-8 string"),
        }
    }

    #[test]
    fn test_publish_missing_packet_id() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // Flags 0x02 select QoS 1, but no packet identifier follows the topic.
        let fixed_header =
            FixedHeader::new(PacketType::Publish, 0x02, u32::try_from(buf.len()).unwrap());
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }

    #[test]
    fn test_publish_invalid_qos() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // Flags 0x06 set both QoS bits, which is not a valid QoS level.
        let fixed_header = FixedHeader::new(PacketType::Publish, 0x06, 0);
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::InvalidQoS(_))));
    }

    #[test]
    fn test_publish_dup_with_qos0_rejected() {
        let mut buf = BytesMut::new();
        encode_string(&mut buf, "topic").unwrap();
        // DUP set while the QoS bits are 0 must be refused by the decoder.
        let fixed_header = FixedHeader::new(
            PacketType::Publish,
            PublishFlags::Dup as u8,
            u32::try_from(buf.len()).unwrap(),
        );
        let result = PublishPacket::decode_body(&mut buf, &fixed_header);
        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
    }
}