mod ack_common;
pub mod auth;
pub mod connack;
pub mod connect;
pub mod disconnect;
pub mod pingreq;
pub mod pingresp;
pub mod puback;
pub mod pubcomp;
pub mod publish;
pub mod pubrec;
pub mod pubrel;
pub mod suback;
pub mod subscribe;
pub mod subscribe_options;
pub mod unsuback;
pub mod unsubscribe;
pub use ack_common::{is_valid_publish_ack_reason_code, is_valid_pubrel_reason_code};
#[cfg(test)]
mod property_tests;
// Property-based round-trip tests for the `BeBytes`-derived `MqttTypeAndFlags`:
// every value built here is serialized with `to_be_bytes` and parsed back with
// `try_from_be_bytes`, and the result must equal the original.
#[cfg(test)]
mod bebytes_tests {
use super::*;
use proptest::prelude::*;
proptest! {
// Raw field round-trip. Each strategy covers the full range of its bit field:
// 4-bit type limited to 1..=15, 1-bit DUP/RETAIN, and a 2-bit QoS that spans
// 0..=3 (3 is representable in the field even though it is not a valid MQTT QoS).
#[test]
fn prop_mqtt_type_and_flags_round_trip(
message_type in 1u8..=15,
dup in 0u8..=1,
qos in 0u8..=3,
retain in 0u8..=1
) {
let original = MqttTypeAndFlags {
message_type,
dup,
qos,
retain,
};
let bytes = original.to_be_bytes();
let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();
prop_assert_eq!(original, decoded);
}
// Headers built by `for_packet_type` (all flag bits zero) round-trip and still
// report the same `PacketType`. Every value in 1..=15 maps to a `PacketType`
// variant, so the `if let` body runs for each generated input.
#[test]
fn prop_packet_type_round_trip(packet_type in 1u8..=15) {
if let Some(pt) = PacketType::from_u8(packet_type) {
let type_and_flags = MqttTypeAndFlags::for_packet_type(pt);
let bytes = type_and_flags.to_be_bytes();
let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();
prop_assert_eq!(type_and_flags, decoded);
prop_assert_eq!(decoded.packet_type(), Some(pt));
}
}
// PUBLISH headers built by `for_publish` round-trip and preserve the QoS value
// and the DUP/RETAIN booleans as seen through `is_dup` / `is_retain`.
#[test]
fn prop_publish_flags_round_trip(
qos in 0u8..=3,
dup: bool,
retain: bool
) {
let type_and_flags = MqttTypeAndFlags::for_publish(qos, dup, retain);
let bytes = type_and_flags.to_be_bytes();
let (decoded, _) = MqttTypeAndFlags::try_from_be_bytes(&bytes).unwrap();
prop_assert_eq!(type_and_flags, decoded);
prop_assert_eq!(decoded.packet_type(), Some(PacketType::Publish));
prop_assert_eq!(decoded.qos, qos);
prop_assert_eq!(decoded.is_dup(), dup);
prop_assert_eq!(decoded.is_retain(), retain);
}
}
}
use crate::encoding::{decode_variable_int, encode_variable_int};
use crate::error::{MqttError, Result};
use crate::prelude::{format, Box, ToString, Vec};
use bebytes::BeBytes;
use bytes::{Buf, BufMut};
/// Packet identifier plus a raw reason-code byte, serialized through the
/// `BeBytes` derive.
///
/// Field declaration order is the serialized order, so do not reorder fields.
/// (Presumably the common prefix of the acknowledgement packets, going by the
/// name — confirm with callers.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct AckPacketHeader {
    /// Packet identifier.
    pub packet_id: u16,
    /// Raw reason-code byte; interpret it with `get_reason_code`.
    pub reason_code: u8,
}
impl AckPacketHeader {
#[must_use]
pub fn create(packet_id: u16, reason_code: crate::types::ReasonCode) -> Self {
Self {
packet_id,
reason_code: u8::from(reason_code),
}
}
#[must_use]
pub fn get_reason_code(&self) -> Option<crate::types::ReasonCode> {
crate::types::ReasonCode::from_u8(self.reason_code)
}
}
/// Bit-packed view of a single header byte: 4 bits of message type, then 1 bit
/// DUP, 2 bits QoS and 1 bit RETAIN (4 + 1 + 2 + 1 = 8 bits).
///
/// The layout comes from the `BeBytes` derive and the `#[bits(..)]` attributes,
/// so field order matters — do not reorder. Presumably this mirrors the first
/// fixed-header byte (type in the high nibble, as `FixedHeader::encode` writes
/// it); the exact MSB-first packing is assumed from bebytes — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub struct MqttTypeAndFlags {
    /// Control packet type value (see `PacketType`).
    #[bits(4)]
    pub message_type: u8,
    /// DUP flag, 0 or 1.
    #[bits(1)]
    pub dup: u8,
    /// QoS level; the 2-bit field can hold 0..=3.
    #[bits(2)]
    pub qos: u8,
    /// RETAIN flag, 0 or 1.
    #[bits(1)]
    pub retain: u8,
}
impl MqttTypeAndFlags {
    /// Header for `packet_type` with every flag bit (DUP, QoS, RETAIN) cleared.
    #[must_use]
    pub fn for_packet_type(packet_type: PacketType) -> Self {
        let message_type = u8::from(packet_type);
        Self {
            message_type,
            dup: 0,
            qos: 0,
            retain: 0,
        }
    }

    /// PUBLISH header carrying the given QoS and DUP/RETAIN flags.
    ///
    /// `qos` is stored as given; it is not range-checked here.
    #[must_use]
    pub fn for_publish(qos: u8, dup: bool, retain: bool) -> Self {
        Self {
            message_type: u8::from(PacketType::Publish),
            dup: dup.into(),
            qos,
            retain: retain.into(),
        }
    }

    /// Maps the stored type value back to a `PacketType`, or `None` if it is not
    /// one of the defined variants.
    #[must_use]
    pub fn packet_type(&self) -> Option<PacketType> {
        PacketType::from_u8(self.message_type)
    }

    /// `true` when the DUP bit is set.
    #[must_use]
    pub fn is_dup(&self) -> bool {
        self.dup > 0
    }

    /// `true` when the RETAIN bit is set.
    #[must_use]
    pub fn is_retain(&self) -> bool {
        self.retain > 0
    }
}
/// MQTT control packet types.
///
/// Each discriminant is the value written into the high nibble of the fixed
/// header's first byte (`FixedHeader::encode` shifts it left by 4). Value 0 has
/// no variant and is rejected on decode. Via the `BeBytes` derive a variant
/// serializes to a single byte equal to its discriminant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, BeBytes)]
pub enum PacketType {
Connect = 1,
ConnAck = 2,
Publish = 3,
PubAck = 4,
PubRec = 5,
PubRel = 6,
PubComp = 7,
Subscribe = 8,
SubAck = 9,
Unsubscribe = 10,
UnsubAck = 11,
PingReq = 12,
PingResp = 13,
Disconnect = 14,
Auth = 15,
}
impl PacketType {
    /// Converts a raw type value (1..=15) into a `PacketType`.
    ///
    /// Any value without a matching variant yields `None` instead of an error.
    #[must_use]
    pub fn from_u8(value: u8) -> Option<Self> {
        value.try_into().ok()
    }
}
/// Yields the packet type's wire discriminant (1..=15).
impl From<PacketType> for u8 {
    fn from(packet_type: PacketType) -> Self {
        packet_type as Self
    }
}
/// Decoded MQTT fixed header: packet type, the flag bits that share its first
/// byte, and the remaining length that follows as a variable byte integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FixedHeader {
    /// Packet type taken from the high nibble of the first byte.
    pub packet_type: PacketType,
    /// Flag bits from the first byte, masked with `masks::FLAGS`.
    pub flags: u8,
    /// Byte count of the rest of the packet (variable header plus payload).
    pub remaining_length: u32,
}
impl FixedHeader {
#[must_use]
pub fn new(packet_type: PacketType, flags: u8, remaining_length: u32) -> Self {
Self {
packet_type,
flags,
remaining_length,
}
}
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
let byte1 =
(u8::from(self.packet_type) << 4) | (self.flags & crate::constants::masks::FLAGS);
buf.put_u8(byte1);
encode_variable_int(buf, self.remaining_length)?;
Ok(())
}
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
if !buf.has_remaining() {
return Err(MqttError::MalformedPacket(
"No data for fixed header".to_string(),
));
}
let byte1 = buf.get_u8();
let packet_type_val = (byte1 >> 4) & crate::constants::masks::FLAGS;
let flags = byte1 & crate::constants::masks::FLAGS;
let packet_type = PacketType::from_u8(packet_type_val)
.ok_or(MqttError::InvalidPacketType(packet_type_val))?;
let remaining_length = decode_variable_int(buf)?;
Ok(Self {
packet_type,
flags,
remaining_length,
})
}
#[must_use]
pub fn validate_flags(&self) -> bool {
match self.packet_type {
PacketType::Publish => true, PacketType::PubRel | PacketType::Subscribe | PacketType::Unsubscribe => {
self.flags == 0x02 }
_ => self.flags == 0,
}
}
#[must_use]
pub fn encoded_len(&self) -> usize {
1 + crate::encoding::encoded_variable_int_len(self.remaining_length)
}
}
/// A decoded MQTT control packet, one variant per `PacketType`.
///
/// `Connect` is boxed (presumably to keep the enum small, since the connect
/// packet is likely the largest variant — confirm with `size_of`).
/// `PingReq` and `PingResp` carry no data.
#[derive(Debug, Clone)]
pub enum Packet {
Connect(Box<connect::ConnectPacket>),
ConnAck(connack::ConnAckPacket),
Publish(publish::PublishPacket),
PubAck(puback::PubAckPacket),
PubRec(pubrec::PubRecPacket),
PubRel(pubrel::PubRelPacket),
PubComp(pubcomp::PubCompPacket),
Subscribe(subscribe::SubscribePacket),
SubAck(suback::SubAckPacket),
Unsubscribe(unsubscribe::UnsubscribePacket),
UnsubAck(unsuback::UnsubAckPacket),
PingReq,
PingResp,
Disconnect(disconnect::DisconnectPacket),
Auth(auth::AuthPacket),
}
impl Packet {
    /// Returns the upper-case MQTT name of this packet's control type
    /// (for example `"PUBLISH"`), suitable for logs and error messages.
    #[must_use]
    pub fn packet_type_name(&self) -> &'static str {
        match self {
            Self::Connect(_) => "CONNECT",
            Self::ConnAck(_) => "CONNACK",
            Self::Publish(_) => "PUBLISH",
            Self::PubAck(_) => "PUBACK",
            Self::PubRec(_) => "PUBREC",
            Self::PubRel(_) => "PUBREL",
            Self::PubComp(_) => "PUBCOMP",
            Self::Subscribe(_) => "SUBSCRIBE",
            Self::SubAck(_) => "SUBACK",
            Self::Unsubscribe(_) => "UNSUBSCRIBE",
            Self::UnsubAck(_) => "UNSUBACK",
            Self::PingReq => "PINGREQ",
            Self::PingResp => "PINGRESP",
            Self::Disconnect(_) => "DISCONNECT",
            Self::Auth(_) => "AUTH",
        }
    }

    /// Decodes a packet body whose fixed header has already been parsed.
    ///
    /// `buf` should hold the bytes that follow the fixed header. Before the
    /// per-type decoder runs, `fixed_header.flags` is checked with
    /// [`FixedHeader::validate_flags`].
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] when the fixed-header flags are not
    /// valid for the packet type. Any error from the per-type `decode_body` is
    /// propagated.
    pub fn decode_from_body<B: Buf>(
        packet_type: PacketType,
        fixed_header: &FixedHeader,
        buf: &mut B,
    ) -> Result<Self> {
        Self::ensure_valid_flags(packet_type, fixed_header)?;
        Self::decode_validated_body(packet_type, fixed_header, buf)
    }

    /// Version-aware variant of [`Packet::decode_from_body`].
    ///
    /// PUBLISH, SUBSCRIBE, SUBACK and UNSUBSCRIBE are decoded by their
    /// `decode_body_with_version`, which receives `protocol_version`. Every other
    /// type uses the version-agnostic decoder.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] when the fixed-header flags are not
    /// valid for the packet type. Any error from the per-type decoder is
    /// propagated.
    pub fn decode_from_body_with_version<B: Buf>(
        packet_type: PacketType,
        fixed_header: &FixedHeader,
        buf: &mut B,
        protocol_version: u8,
    ) -> Result<Self> {
        // Validate before dispatching. The version-specific decoders below would
        // otherwise accept reserved-flag violations (e.g. SUBSCRIBE with flags
        // 0x00, or SUBACK with non-zero flags) that `decode_from_body` rejects.
        Self::ensure_valid_flags(packet_type, fixed_header)?;
        match packet_type {
            PacketType::Publish => {
                let packet = publish::PublishPacket::decode_body_with_version(
                    buf,
                    fixed_header,
                    protocol_version,
                )?;
                Ok(Packet::Publish(packet))
            }
            PacketType::Subscribe => {
                let packet = subscribe::SubscribePacket::decode_body_with_version(
                    buf,
                    fixed_header,
                    protocol_version,
                )?;
                Ok(Packet::Subscribe(packet))
            }
            PacketType::SubAck => {
                let packet = suback::SubAckPacket::decode_body_with_version(
                    buf,
                    fixed_header,
                    protocol_version,
                )?;
                Ok(Packet::SubAck(packet))
            }
            PacketType::Unsubscribe => {
                let packet = unsubscribe::UnsubscribePacket::decode_body_with_version(
                    buf,
                    fixed_header,
                    protocol_version,
                )?;
                Ok(Packet::Unsubscribe(packet))
            }
            _ => Self::decode_validated_body(packet_type, fixed_header, buf),
        }
    }

    /// Rejects fixed-header flags that `FixedHeader::validate_flags` deems
    /// invalid, reporting them as a malformed packet.
    fn ensure_valid_flags(packet_type: PacketType, fixed_header: &FixedHeader) -> Result<()> {
        if fixed_header.validate_flags() {
            Ok(())
        } else {
            Err(MqttError::MalformedPacket(format!(
                "Invalid fixed header flags 0x{:02X} for {:?}",
                fixed_header.flags, packet_type
            )))
        }
    }

    /// Dispatches to the version-agnostic per-type `decode_body`.
    ///
    /// Callers must run `ensure_valid_flags` first.
    fn decode_validated_body<B: Buf>(
        packet_type: PacketType,
        fixed_header: &FixedHeader,
        buf: &mut B,
    ) -> Result<Self> {
        match packet_type {
            PacketType::Connect => {
                let packet = connect::ConnectPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Connect(Box::new(packet)))
            }
            PacketType::ConnAck => {
                let packet = connack::ConnAckPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::ConnAck(packet))
            }
            PacketType::Publish => {
                let packet = publish::PublishPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Publish(packet))
            }
            PacketType::PubAck => {
                let packet = puback::PubAckPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::PubAck(packet))
            }
            PacketType::PubRec => {
                let packet = pubrec::PubRecPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::PubRec(packet))
            }
            PacketType::PubRel => {
                let packet = pubrel::PubRelPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::PubRel(packet))
            }
            PacketType::PubComp => {
                let packet = pubcomp::PubCompPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::PubComp(packet))
            }
            PacketType::Subscribe => {
                let packet = subscribe::SubscribePacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Subscribe(packet))
            }
            PacketType::SubAck => {
                let packet = suback::SubAckPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::SubAck(packet))
            }
            PacketType::Unsubscribe => {
                let packet = unsubscribe::UnsubscribePacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Unsubscribe(packet))
            }
            PacketType::UnsubAck => {
                let packet = unsuback::UnsubAckPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::UnsubAck(packet))
            }
            PacketType::PingReq => Ok(Packet::PingReq),
            PacketType::PingResp => Ok(Packet::PingResp),
            PacketType::Disconnect => {
                let packet = disconnect::DisconnectPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Disconnect(packet))
            }
            PacketType::Auth => {
                let packet = auth::AuthPacket::decode_body(buf, fixed_header)?;
                Ok(Packet::Auth(packet))
            }
        }
    }
}
/// Common encode/decode interface implemented by the individual MQTT packet types.
pub trait MqttPacket: Sized {
    /// Control packet type written into the fixed header.
    fn packet_type(&self) -> PacketType;

    /// Fixed-header flag bits for this packet; `0` unless overridden.
    fn flags(&self) -> u8 {
        0x00
    }

    /// Writes everything that follows the fixed header into `buf`.
    fn encode_body<B: BufMut>(&self, buf: &mut B) -> Result<()>;

    /// Reads the packet body from `buf`, given its already-parsed fixed header.
    fn decode_body<B: Buf>(buf: &mut B, fixed_header: &FixedHeader) -> Result<Self>;

    /// Encodes the complete packet: fixed header followed by the body.
    ///
    /// # Errors
    ///
    /// Propagates errors from `encode_body` and from `FixedHeader::encode`.
    fn encode<B: BufMut>(&self, buf: &mut B) -> Result<()> {
        // The remaining length is only known once the body is serialized, so the
        // body is staged in a temporary buffer first.
        let mut staged: Vec<u8> = Vec::new();
        self.encode_body(&mut staged)?;
        // Lengths beyond u32 saturate to u32::MAX rather than truncating.
        // Presumably `encode_variable_int` then rejects the value — verify.
        let remaining_length: u32 = staged.len().try_into().unwrap_or(u32::MAX);
        let header = FixedHeader::new(self.packet_type(), self.flags(), remaining_length);
        header.encode(buf)?;
        buf.put_slice(&staged);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Encodes `header` into a fresh buffer and immediately decodes it back.
    fn round_trip(header: FixedHeader) -> FixedHeader {
        let mut wire = BytesMut::new();
        header.encode(&mut wire).unwrap();
        FixedHeader::decode(&mut wire).unwrap()
    }

    #[test]
    fn test_packet_type_from_u8() {
        let cases = [
            (1u8, Some(PacketType::Connect)),
            (2, Some(PacketType::ConnAck)),
            (15, Some(PacketType::Auth)),
            (0, None),
            (16, None),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(PacketType::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_fixed_header_encode_decode() {
        let decoded = round_trip(FixedHeader::new(PacketType::Connect, 0, 100));
        assert_eq!(decoded.packet_type, PacketType::Connect);
        assert_eq!(decoded.flags, 0);
        assert_eq!(decoded.remaining_length, 100);
    }

    #[test]
    fn test_fixed_header_with_flags() {
        let decoded = round_trip(FixedHeader::new(PacketType::Publish, 0x0D, 50));
        assert_eq!(
            (decoded.packet_type, decoded.flags, decoded.remaining_length),
            (PacketType::Publish, 0x0D, 50)
        );
    }

    #[test]
    fn test_validate_flags() {
        let expectations = [
            (PacketType::Connect, 0x00, true),
            (PacketType::Connect, 0x01, false),
            (PacketType::Subscribe, 0x02, true),
            (PacketType::Subscribe, 0x00, false),
            (PacketType::Publish, 0x0F, true),
        ];
        for &(packet_type, flags, valid) in &expectations {
            assert_eq!(FixedHeader::new(packet_type, flags, 0).validate_flags(), valid);
        }
    }

    #[test]
    fn test_decode_insufficient_data() {
        let mut empty = BytesMut::new();
        assert!(FixedHeader::decode(&mut empty).is_err());
    }

    #[test]
    fn test_decode_invalid_packet_type() {
        // Type nibble 0 has no PacketType variant; the second byte is a zero
        // remaining length.
        let mut wire = BytesMut::new();
        wire.extend_from_slice(&[0x00, 0x00]);
        assert!(FixedHeader::decode(&mut wire).is_err());
    }

    #[test]
    fn test_packet_type_bebytes_serialization() {
        for &(packet_type, wire_byte) in &[(PacketType::Publish, 3u8), (PacketType::Connect, 1u8)] {
            let bytes = packet_type.to_be_bytes();
            assert_eq!(bytes, vec![wire_byte]);
            let (decoded, consumed) = PacketType::try_from_be_bytes(&bytes).unwrap();
            assert_eq!(decoded, packet_type);
            assert_eq!(consumed, 1);
        }
    }
}