use bytes::{Buf, BufMut, Bytes, BytesMut};
use core::fmt;
use std::slice::Iter;
mod topic;
pub mod v4;
pub use topic::*;
/// Errors produced while framing, decoding or encoding MQTT packets.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    // --- Connection handshake -------------------------------------------
    /// A packet other than Connect arrived where a Connect was expected.
    #[error("Expected Connect, received: {0:?}")]
    NotConnect(PacketType),
    /// A Connect packet arrived where one was not expected.
    #[error("Unexpected Connect")]
    UnexpectedConnect,
    /// A ConnAck carried a return code that is not recognised.
    #[error("Invalid Connect return code: {0}")]
    InvalidConnectReturnCode(u8),
    /// The protocol name in a Connect packet is not valid.
    #[error("Invalid protocol")]
    InvalidProtocol,
    /// The protocol level in a Connect packet is not supported.
    #[error("Invalid protocol level: {0}")]
    InvalidProtocolLevel(u8),

    // --- Framing / fixed header -----------------------------------------
    /// The upper nibble of the first byte is not a known packet type.
    #[error("Invalid packet type: {0}")]
    InvalidPacketType(u8),
    /// The variable-length remaining-length field used more than four bytes.
    #[error("Malformed remaining length")]
    MalformedRemainingLength,
    /// The buffer does not yet hold a full packet; at least this many more
    /// bytes must arrive before framing can succeed.
    #[error("At least {0} more bytes required to frame packet")]
    InsufficientBytes(usize),
    /// The decoded remaining length exceeds the caller-supplied limit.
    #[error("payload size limit exceeded: {0}")]
    PayloadSizeLimitExceeded(usize),

    // --- Field-level decoding -------------------------------------------
    /// The packet body does not match the expected layout.
    #[error("Incorrect packet format")]
    IncorrectPacketFormat,
    /// A property identifier is not recognised.
    #[error("Invalid property type: {0}")]
    InvalidPropertyType(u8),
    /// A QoS value outside 0..=2 was read.
    #[error("Invalid QoS level: {0}")]
    InvalidQoS(u8),
    /// A SubAck contained an unknown reason code.
    #[error("Invalid subscribe reason code: {0}")]
    InvalidSubscribeReasonCode(u8),
    /// A packet identifier of zero was encountered.
    #[error("Packet id Zero")]
    PacketIdZero,
    /// The payload length does not match what the packet declares.
    #[error("Payload size is incorrect")]
    PayloadSizeIncorrect,
    /// A payload was expected but missing.
    #[error("Payload required")]
    PayloadRequired,
    /// A length-prefixed string was not valid UTF-8. Despite the name, this
    /// is returned for any string field decoded by `read_mqtt_string`.
    #[error("Topic is not UTF-8")]
    TopicNotUtf8,
    /// A length prefix promised more bytes than remain in the buffer; the
    /// payload is the promised length.
    #[error("Promised boundary crossed: {0}")]
    BoundaryCrossed(usize),
    /// The buffer ended before a fixed-width field could be read.
    #[error("Malformed packet")]
    MalformedPacket,
    /// A Subscribe packet carried no topic filters.
    #[error("A Subscribe packet must contain atleast one filter")]
    EmptySubscription,

    // --- Encoding -------------------------------------------------------
    /// A length is too large to encode as an MQTT remaining length
    /// (e.g. more than 268_435_455 in `write_remaining_length`).
    #[error("payload is too long")]
    PayloadTooLong,
    /// An outgoing packet is larger than the broker's advertised maximum.
    #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")]
    OutgoingPacketTooLarge { pkt_size: usize, max: usize },

    // --- I/O ------------------------------------------------------------
    /// Underlying I/O failure; `#[from]` allows `?` on `std::io::Error`.
    #[error("IO: {0}")]
    Io(#[from] std::io::Error),
}
/// MQTT control packet types, numbered as they appear in the upper nibble of
/// a packet's first byte (see `FixedHeader::packet_type`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
/// MQTT protocol version selector (`V4` presumably MQTT 3.1.1, `V5` MQTT 5 — confirm against callers).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    V4,
    V5,
}
/// MQTT quality-of-service level; the discriminant is the on-wire value.
///
/// The derived `PartialOrd` orders levels by strength of delivery guarantee.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum QoS {
    /// Level 0.
    AtMostOnce = 0,
    /// Level 1.
    AtLeastOnce = 1,
    /// Level 2.
    ExactlyOnce = 2,
}
/// Decoded MQTT fixed header: the packet's first byte plus its
/// remaining-length field.
///
/// Field order is significant: the derived `PartialOrd` compares fields in
/// declaration order.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the packet. Its upper four bits select the packet type
    /// (see `packet_type`); the lower four bits are not interpreted here.
    byte1: u8,
    /// Bytes occupied by the fixed header: one for `byte1` plus however many
    /// bytes the encoded remaining-length field used.
    fixed_header_len: usize,
    /// Decoded remaining length: the number of bytes that follow the fixed header.
    remaining_len: usize,
}
impl FixedHeader {
pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
FixedHeader {
byte1,
fixed_header_len: remaining_len_len + 1,
remaining_len,
}
}
pub fn packet_type(&self) -> Result<PacketType, Error> {
let num = self.byte1 >> 4;
match num {
1 => Ok(PacketType::Connect),
2 => Ok(PacketType::ConnAck),
3 => Ok(PacketType::Publish),
4 => Ok(PacketType::PubAck),
5 => Ok(PacketType::PubRec),
6 => Ok(PacketType::PubRel),
7 => Ok(PacketType::PubComp),
8 => Ok(PacketType::Subscribe),
9 => Ok(PacketType::SubAck),
10 => Ok(PacketType::Unsubscribe),
11 => Ok(PacketType::UnsubAck),
12 => Ok(PacketType::PingReq),
13 => Ok(PacketType::PingResp),
14 => Ok(PacketType::Disconnect),
_ => Err(Error::InvalidPacketType(num)),
}
}
pub fn frame_length(&self) -> usize {
self.fixed_header_len + self.remaining_len
}
}
pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
let stream_len = stream.len();
let fixed_header = parse_fixed_header(stream)?;
if fixed_header.remaining_len > max_packet_size {
return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
}
let frame_length = fixed_header.frame_length();
if stream_len < frame_length {
return Err(Error::InsufficientBytes(frame_length - stream_len));
}
Ok(fixed_header)
}
fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
let stream_len = stream.len();
if stream_len < 2 {
return Err(Error::InsufficientBytes(2 - stream_len));
}
let byte1 = stream.next().unwrap();
let (len_len, len) = length(stream)?;
Ok(FixedHeader::new(*byte1, len_len, len))
}
fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
let mut len: usize = 0;
let mut len_len = 0;
let mut done = false;
let mut shift = 0;
for byte in stream {
len_len += 1;
let byte = *byte as usize;
len += (byte & 0x7F) << shift;
done = (byte & 0x80) == 0;
if done {
break;
}
shift += 7;
if shift > 21 {
return Err(Error::MalformedRemainingLength);
}
}
if !done {
return Err(Error::InsufficientBytes(1));
}
Ok((len_len, len))
}
fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
let len = read_u16(stream)? as usize;
if len > stream.len() {
return Err(Error::BoundaryCrossed(len));
}
Ok(stream.split_to(len))
}
fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
let s = read_mqtt_bytes(stream)?;
match String::from_utf8(s.to_vec()) {
Ok(v) => Ok(v),
Err(_e) => Err(Error::TopicNotUtf8),
}
}
/// Writes `bytes` as a length-prefixed field: a `u16` length (via `put_u16`)
/// followed by the raw bytes.
///
/// The prefix can only represent lengths up to `u16::MAX`. Callers must not
/// pass longer slices. In release builds the prefix would be silently
/// truncated while the full data is still appended, producing a malformed
/// packet. That misuse is caught by a debug assertion.
fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    debug_assert!(
        bytes.len() <= usize::from(u16::MAX),
        "length-prefixed MQTT field of {} bytes exceeds u16::MAX",
        bytes.len()
    );
    // Truncating cast is guarded by the debug assertion above.
    stream.put_u16(bytes.len() as u16);
    stream.extend_from_slice(bytes);
}
/// Writes `string` as a length-prefixed field of its UTF-8 bytes, using the
/// same encoding as `write_mqtt_bytes`.
fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    let utf8 = string.as_bytes();
    write_mqtt_bytes(stream, utf8)
}
/// Encodes `len` as an MQTT variable-length remaining-length field and
/// appends it to `stream`.
///
/// Each output byte carries seven bits of `len`, least-significant group
/// first. The high bit is set on every byte except the last. Zero is encoded
/// as a single `0x00` byte.
///
/// Returns the number of bytes written (1 to 4).
///
/// # Errors
///
/// `Error::PayloadTooLong` if `len` exceeds 268_435_455, the largest value
/// four 7-bit groups can hold.
fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
    const MAX_REMAINING_LENGTH: usize = 268_435_455;
    if len > MAX_REMAINING_LENGTH {
        return Err(Error::PayloadTooLong);
    }

    let mut remaining = len;
    let mut written = 0;
    loop {
        let mut encoded = (remaining & 0x7F) as u8;
        remaining >>= 7;
        if remaining != 0 {
            // More groups follow: set the continuation bit.
            encoded |= 0x80;
        }
        stream.put_u8(encoded);
        written += 1;
        if remaining == 0 {
            break;
        }
    }

    Ok(written)
}
pub fn qos(num: u8) -> Result<QoS, Error> {
match num {
0 => Ok(QoS::AtMostOnce),
1 => Ok(QoS::AtLeastOnce),
2 => Ok(QoS::ExactlyOnce),
qos => Err(Error::InvalidQoS(qos)),
}
}
fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
if stream.len() < 2 {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u16())
}
fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
if stream.is_empty() {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u8())
}