use crate::mqtt::common::{Arc, Cursor};
use crate::mqtt::result_code::MqttError;
use alloc::vec::Vec;
/// Body of a decoded MQTT packet: every byte that follows the fixed header.
///
/// PUBLISH bodies are kept in a reference-counted `Arc<[u8]>`; every other
/// packet type keeps its body in an owned `Vec<u8>`.
#[derive(Clone, Debug)]
pub enum PacketData {
    /// Body of any packet other than PUBLISH.
    Normal(Vec<u8>),
    /// Body of a PUBLISH packet.
    Publish(Arc<[u8]>),
}
impl PacketData {
    /// Borrows the body bytes regardless of how they are stored.
    pub fn as_slice(&self) -> &[u8] {
        match *self {
            Self::Normal(ref owned) => &owned[..],
            Self::Publish(ref shared) => &shared[..],
        }
    }
    /// Body length in bytes.
    ///
    /// # Panics
    /// Panics if the length does not fit in a `u32`.
    pub fn len(&self) -> u32 {
        let byte_count: usize = self.as_slice().len();
        byte_count.try_into().unwrap()
    }
    /// Returns `true` when the body holds no bytes.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
#[derive(Debug, Clone)]
pub struct RawPacket {
fixed_header: u8,
pub data: PacketData,
}
impl RawPacket {
pub fn data_as_slice(&self) -> &[u8] {
self.data.as_slice()
}
pub fn packet_type(&self) -> u8 {
self.fixed_header >> 4
}
pub fn flags(&self) -> u8 {
self.fixed_header & 0x0F
}
pub fn is_publish(&self) -> bool {
self.packet_type() == 3
}
pub fn remaining_length(&self) -> u32 {
self.data.len()
}
}
/// Outcome of a single [`PacketBuilder::feed`] call.
#[derive(Debug)]
pub enum PacketBuildResult {
    /// A full packet was assembled; the builder has been reset for the next one.
    Complete(RawPacket),
    /// The supplied input ran out (or was empty) before a packet was finished;
    /// partial parsing state is kept for the next call.
    Incomplete,
    /// The input was rejected as malformed; the builder has been reset.
    Error(MqttError),
}
/// Incremental decoder that assembles [`RawPacket`]s from a byte stream
/// delivered in arbitrarily sized pieces.
pub struct PacketBuilder {
    /// Phase of the packet currently being decoded.
    state: ReadState,
    /// Fixed-header byte followed by the Remaining Length bytes seen so far.
    header_buf: Vec<u8>,
    /// Place value applied to the next Remaining Length byte (1, 128, 128^2, ...).
    multiplier: u32,
    /// Decoded Remaining Length; while reading the body, the count of body
    /// bytes still outstanding.
    remaining_length: usize,
    /// Body buffer; `Some` only while the body is being read.
    raw_buf: Option<Vec<u8>>,
    /// Number of body bytes already written into `raw_buf`.
    raw_buf_offset: usize,
}
/// Which part of the current packet the builder expects next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ReadState {
    /// Waiting for the first fixed-header byte (packet type and flags).
    FixedHeader,
    /// Decoding the variable-length Remaining Length field.
    RemainingLength,
    /// Copying body bytes until Remaining Length bytes have been collected.
    Payload,
}
impl PacketBuilder {
pub fn new() -> Self {
Self {
state: ReadState::FixedHeader,
header_buf: Vec::with_capacity(5),
remaining_length: 0,
multiplier: 1,
raw_buf: None,
raw_buf_offset: 0,
}
}
pub fn reset(&mut self) {
self.state = ReadState::FixedHeader;
self.header_buf.clear();
self.remaining_length = 0;
self.multiplier = 1;
self.raw_buf = None;
self.raw_buf_offset = 0;
}
fn get_packet_type(&self) -> u8 {
if !self.header_buf.is_empty() {
self.header_buf[0]
} else {
0
}
}
fn is_publish_packet(&self) -> bool {
(self.get_packet_type() & 0xF0) == 0x30
}
pub fn feed(&mut self, data: &mut Cursor<&[u8]>) -> PacketBuildResult {
let available = data.get_ref().len() as u64 - data.position();
if available == 0 {
return PacketBuildResult::Incomplete;
}
let mut byte = [0u8; 1];
loop {
match self.state {
ReadState::FixedHeader => {
if data.read_exact(&mut byte).is_err() {
return PacketBuildResult::Incomplete;
}
self.header_buf.push(byte[0]);
self.state = ReadState::RemainingLength;
}
ReadState::RemainingLength => {
if data.read_exact(&mut byte).is_err() {
return PacketBuildResult::Incomplete;
}
self.header_buf.push(byte[0]);
let encoded_byte = byte[0];
self.remaining_length +=
((encoded_byte & 0x7F) as usize) * (self.multiplier as usize);
self.multiplier *= 128;
if self.multiplier > 128 * 128 * 128 {
self.reset();
return PacketBuildResult::Error(MqttError::MalformedPacket);
}
if (encoded_byte & 0x80) == 0 {
if self.remaining_length == 0 {
let fixed_header = self.header_buf[0];
let packet_data = if self.is_publish_packet() {
PacketData::Publish(Arc::from([]))
} else {
PacketData::Normal(Vec::new())
};
let packet = RawPacket {
fixed_header,
data: packet_data,
};
self.reset();
return PacketBuildResult::Complete(packet);
} else {
self.raw_buf = Some(Vec::with_capacity(self.remaining_length));
self.raw_buf_offset = 0;
self.state = ReadState::Payload;
}
}
}
ReadState::Payload => {
let raw_buf = self.raw_buf.as_mut().unwrap();
let bytes_remaining = self.remaining_length;
let position = data.position();
let available = data.get_ref().len() as u64 - position;
let bytes_to_read = bytes_remaining.min(available as usize);
if bytes_to_read == 0 {
return PacketBuildResult::Incomplete;
}
raw_buf.resize(self.raw_buf_offset + bytes_to_read, 0);
let read_slice =
&mut raw_buf[self.raw_buf_offset..self.raw_buf_offset + bytes_to_read];
let bytes_read = data.read(read_slice).unwrap();
self.raw_buf_offset += bytes_read;
self.remaining_length -= bytes_read;
if self.remaining_length == 0 {
let raw_buf = self.raw_buf.take().unwrap();
let fixed_header = self.header_buf[0];
let packet_data = if self.is_publish_packet() {
PacketData::Publish(Arc::from(raw_buf.into_boxed_slice()))
} else {
PacketData::Normal(raw_buf)
};
let packet = RawPacket {
fixed_header,
data: packet_data,
};
self.reset();
return PacketBuildResult::Complete(packet);
}
return PacketBuildResult::Incomplete;
}
}
}
}
}