use bytes::{BufMut, Bytes, BytesMut};
use core::fmt;
use mqttbytes_core::primitives::{self as core_primitives, Error as PrimitiveError};
use std::slice::Iter;
pub mod v4;
pub use mqttbytes_core::{QoS, has_wildcards, matches, valid_filter, valid_topic};
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Expected Connect, received: {0:?}")]
NotConnect(PacketType),
#[error("Unexpected Connect")]
UnexpectedConnect,
#[error("Invalid Connect return code: {0}")]
InvalidConnectReturnCode(u8),
#[error("Invalid protocol")]
InvalidProtocol,
#[error("Invalid protocol level: {0}")]
InvalidProtocolLevel(u8),
#[error("Incorrect packet format")]
IncorrectPacketFormat,
#[error("Invalid packet type: {0}")]
InvalidPacketType(u8),
#[error("Invalid property type: {0}")]
InvalidPropertyType(u8),
#[error("Invalid QoS level: {0}")]
InvalidQoS(u8),
#[error("Invalid subscribe reason code: {0}")]
InvalidSubscribeReasonCode(u8),
#[error("Packet id Zero")]
PacketIdZero,
#[error("Payload size is incorrect")]
PayloadSizeIncorrect,
#[error("payload is too long")]
PayloadTooLong,
#[error("payload size limit exceeded: {0}")]
PayloadSizeLimitExceeded(usize),
#[error("Payload required")]
PayloadRequired,
#[error("Topic is not UTF-8")]
TopicNotUtf8,
#[error("Promised boundary crossed: {0}")]
BoundaryCrossed(usize),
#[error("Malformed packet")]
MalformedPacket,
#[error("Malformed remaining length")]
MalformedRemainingLength,
#[error("A Subscribe packet must contain atleast one filter")]
EmptySubscription,
#[error("At least {0} more bytes required to frame packet")]
InsufficientBytes(usize),
#[error("IO: {0}")]
Io(#[from] std::io::Error),
#[error(
"Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'"
)]
OutgoingPacketTooLarge { pkt_size: usize, max: usize },
}
impl From<PrimitiveError> for Error {
fn from(error: PrimitiveError) -> Self {
match error {
PrimitiveError::PayloadTooLong => Self::PayloadTooLong,
PrimitiveError::BoundaryCrossed(len) => Self::BoundaryCrossed(len),
PrimitiveError::MalformedPacket => Self::MalformedPacket,
PrimitiveError::MalformedRemainingLength => Self::MalformedRemainingLength,
PrimitiveError::TopicNotUtf8 => Self::TopicNotUtf8,
PrimitiveError::InsufficientBytes(required) => Self::InsufficientBytes(required),
}
}
}
/// MQTT control packet types.
///
/// Each discriminant equals the value stored in the upper four bits of the
/// fixed header's first byte, and `#[repr(u8)]` keeps the type one byte wide.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
/// MQTT protocol generations this crate distinguishes between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Protocol version 4.
    V4,
    /// Protocol version 5.
    V5,
}
/// Decoded MQTT fixed header.
///
/// Field order matters: the derived `PartialOrd` compares `byte1`, then
/// `header_len`, then `remaining_len`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// First header byte; its upper nibble encodes the packet type.
    byte1: u8,
    /// Bytes taken by the fixed header: one for `byte1` plus the encoded
    /// remaining-length field.
    header_len: usize,
    /// Number of bytes that follow the fixed header.
    remaining_len: usize,
}
impl FixedHeader {
#[must_use]
pub const fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> Self {
Self {
byte1,
header_len: remaining_len_len + 1,
remaining_len,
}
}
pub const fn packet_type(&self) -> Result<PacketType, Error> {
let num = self.byte1 >> 4;
match num {
1 => Ok(PacketType::Connect),
2 => Ok(PacketType::ConnAck),
3 => Ok(PacketType::Publish),
4 => Ok(PacketType::PubAck),
5 => Ok(PacketType::PubRec),
6 => Ok(PacketType::PubRel),
7 => Ok(PacketType::PubComp),
8 => Ok(PacketType::Subscribe),
9 => Ok(PacketType::SubAck),
10 => Ok(PacketType::Unsubscribe),
11 => Ok(PacketType::UnsubAck),
12 => Ok(PacketType::PingReq),
13 => Ok(PacketType::PingResp),
14 => Ok(PacketType::Disconnect),
_ => Err(Error::InvalidPacketType(num)),
}
}
#[must_use]
pub const fn frame_length(&self) -> usize {
self.header_len + self.remaining_len
}
}
pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
let stream_len = stream.len();
let fixed_header = parse_fixed_header(stream)?;
if fixed_header.remaining_len > max_packet_size {
return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
}
let frame_length = fixed_header.frame_length();
if stream_len < frame_length {
return Err(Error::InsufficientBytes(frame_length - stream_len));
}
Ok(fixed_header)
}
/// Decodes the fixed header with `core_primitives::parse_fixed_header` and
/// repackages it as this crate's [`FixedHeader`].
///
/// Primitive errors are converted into [`Error`] by the `?` operator.
fn parse_fixed_header(stream: Iter<u8>) -> Result<FixedHeader, Error> {
    let raw = core_primitives::parse_fixed_header(stream)?;
    Ok(FixedHeader::new(raw.byte1, raw.remaining_len_len, raw.remaining_len))
}
/// Delegates to `core_primitives::read_mqtt_bytes`, converting its error into
/// [`Error`].
fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
    core_primitives::read_mqtt_bytes(stream).map_err(Into::into)
}
/// Delegates to `core_primitives::read_mqtt_string`, converting its error into
/// [`Error`] (for example [`Error::TopicNotUtf8`]).
fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
    core_primitives::read_mqtt_string(stream).map_err(Into::into)
}
/// Writes `bytes` into `stream` through `core_primitives::write_mqtt_bytes`.
/// This wrapper is infallible.
fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    core_primitives::write_mqtt_bytes(stream, bytes);
}
/// Writes `string` into `stream` through `core_primitives::write_mqtt_string`.
/// This wrapper is infallible.
fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    core_primitives::write_mqtt_string(stream, string);
}
/// Encodes `len` into `stream` via `core_primitives::write_remaining_length`.
///
/// Passes through the primitive's `usize` result. Any primitive failure is
/// converted into [`Error`].
fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
    core_primitives::write_remaining_length(stream, len).map_err(Into::into)
}
pub fn qos(num: u8) -> Result<QoS, Error> {
mqttbytes_core::qos(num).ok_or(Error::InvalidQoS(num))
}
/// Delegates to `core_primitives::read_u16`, converting its error into
/// [`Error`].
fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
    core_primitives::read_u16(stream).map_err(Into::into)
}
/// Delegates to `core_primitives::read_u8`, converting its error into
/// [`Error`].
fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
    core_primitives::read_u8(stream).map_err(Into::into)
}
#[cfg(test)]
mod tests {
    use super::{Error, check};

    /// A header announcing more payload than the limit must be rejected even
    /// though the rest of the frame has not been received yet.
    #[test]
    fn check_rejects_oversized_packet_on_partial_frame() {
        // 0x30: PUBLISH in the upper nibble; 0x14: remaining length of 20.
        let partial_frame: [u8; 2] = [0x30, 0x14];
        let limit = 10;

        let outcome = check(partial_frame.iter(), limit);

        assert!(matches!(outcome, Err(Error::PayloadSizeLimitExceeded(20))));
    }
}