use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::slice::Iter;
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum Error {
#[error("Payload is too long")]
PayloadTooLong,
#[error("Promised boundary crossed, contains {0} bytes")]
BoundaryCrossed(usize),
#[error("Packet is malformed")]
MalformedPacket,
#[error("Remaining length is malformed")]
MalformedRemainingLength,
#[error("Topic not utf-8")]
TopicNotUtf8,
#[error("Insufficient number of bytes to frame packet, {0} more bytes required")]
InsufficientBytes(usize),
}
/// Fixed header of an MQTT packet, as decoded by `parse_fixed_header`.
///
/// Field order is significant: the derived `PartialOrd` compares fields in
/// declaration order.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct ParsedFixedHeader {
    /// First byte of the packet, copied verbatim from the stream.
    pub byte1: u8,
    /// Number of bytes that encoded the remaining length (1 to 4 when produced by
    /// `parse_fixed_header`).
    pub remaining_len_len: usize,
    /// Decoded remaining length: the number of bytes that follow the fixed header.
    pub remaining_len: usize,
}

impl ParsedFixedHeader {
    /// Total size of the frame in bytes: the first byte, the encoded remaining
    /// length, and the `remaining_len` bytes after it.
    #[must_use]
    pub const fn frame_length(self) -> usize {
        let fixed_header_len = 1 + self.remaining_len_len;
        fixed_header_len + self.remaining_len
    }
}
pub fn check(stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
let stream_len = stream.len();
let fixed_header = parse_fixed_header(stream)?;
let frame_length = fixed_header.frame_length();
if stream_len < frame_length {
return Err(Error::InsufficientBytes(frame_length - stream_len));
}
Ok(fixed_header)
}
pub fn parse_fixed_header(mut stream: Iter<u8>) -> Result<ParsedFixedHeader, Error> {
let stream_len = stream.len();
if stream_len < 2 {
return Err(Error::InsufficientBytes(2 - stream_len));
}
let byte1 = *stream.next().unwrap();
let (remaining_len_len, remaining_len) = length(stream)?;
Ok(ParsedFixedHeader {
byte1,
remaining_len_len,
remaining_len,
})
}
pub fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
let mut len: usize = 0;
let mut len_len = 0;
let mut done = false;
let mut shift = 0;
for byte in stream {
len_len += 1;
let byte = *byte as usize;
len += (byte & 0x7F) << shift;
done = (byte & 0x80) == 0;
if done {
break;
}
shift += 7;
if shift > 21 {
return Err(Error::MalformedRemainingLength);
}
}
if !done {
return Err(Error::InsufficientBytes(1));
}
Ok((len_len, len))
}
/// Reads an MQTT length-prefixed field (a big-endian `u16` length followed by that
/// many bytes) and returns the field as a slice of `stream` without copying.
///
/// On success `stream` is advanced past both the prefix and the field. On failure
/// `stream` is left exactly as it was, matching [`read_u8`], [`read_u16`] and
/// [`read_u32`].
///
/// # Errors
///
/// * [`Error::MalformedPacket`] if fewer than two bytes are available for the prefix.
/// * [`Error::BoundaryCrossed`] carrying the declared length if fewer than that many
///   bytes follow the prefix.
pub fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
    // Read through a cursor and commit only once the whole field is known to be
    // present. `Bytes::clone` does not copy the underlying data.
    let mut cursor = stream.clone();
    let len = usize::from(read_u16(&mut cursor)?);
    if len > cursor.len() {
        return Err(Error::BoundaryCrossed(len));
    }
    let field = cursor.split_to(len);
    *stream = cursor;
    Ok(field)
}
pub fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
let s = read_mqtt_bytes(stream)?;
let s = std::str::from_utf8(&s).map_err(|_| Error::TopicNotUtf8)?;
Ok(s.to_owned())
}
/// Appends `bytes` to `stream` as an MQTT length-prefixed field: a big-endian `u16`
/// length followed by the raw bytes.
///
/// # Panics
///
/// Panics if `bytes` is longer than `u16::MAX` (65 535) bytes.
pub fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
    let prefix = u16::try_from(bytes.len()).expect("MQTT string/bytes length must fit in u16");
    stream.put_u16(prefix);
    stream.extend_from_slice(bytes);
}
/// Appends `string` to `stream` as an MQTT UTF-8 string: a big-endian `u16` byte
/// length followed by its UTF-8 bytes.
///
/// # Panics
///
/// Panics if the UTF-8 encoding is longer than `u16::MAX` bytes, because it
/// delegates to [`write_mqtt_bytes`].
pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
    let utf8 = string.as_bytes();
    write_mqtt_bytes(stream, utf8);
}
/// Appends `len` to `stream` using MQTT's variable-length encoding: 7 bits per byte,
/// least-significant group first, with the high bit set on every byte except the
/// last.
///
/// Returns the number of bytes written (1 to 4).
///
/// # Errors
///
/// [`Error::PayloadTooLong`] if `len` exceeds 268 435 455, the largest value four
/// bytes can encode. Nothing is written in that case.
///
/// # Panics
///
/// Does not panic in practice: the `expect` guards a `u8` conversion of a value
/// that is always below 128.
pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
    const MAX_REMAINING_LENGTH: usize = 268_435_455;
    if len > MAX_REMAINING_LENGTH {
        return Err(Error::PayloadTooLong);
    }
    let mut rest = len;
    let mut written = 0;
    loop {
        let mut encoded =
            u8::try_from(rest % 128).expect("remainder in 0..=127 always fits in u8");
        rest /= 128;
        if rest > 0 {
            // More groups follow: set the continuation bit.
            encoded |= 0x80;
        }
        stream.put_u8(encoded);
        written += 1;
        if rest == 0 {
            return Ok(written);
        }
    }
}
/// Number of bytes `write_remaining_length` uses to encode `len`.
///
/// Values above 268 435 455, which `write_remaining_length` rejects, also report 4.
#[must_use]
pub const fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
pub fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
if stream.len() < 2 {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u16())
}
pub fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
if stream.is_empty() {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u8())
}
pub fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
if stream.len() < 4 {
return Err(Error::MalformedPacket);
}
Ok(stream.get_u32())
}
#[cfg(test)]
mod tests {
    use bytes::{Bytes, BytesMut};

    use super::*;

    #[test]
    fn len_len_matches_expected_thresholds() {
        assert_eq!(len_len(0), 1);
        assert_eq!(len_len(127), 1);
        assert_eq!(len_len(128), 2);
        assert_eq!(len_len(16_383), 2);
        assert_eq!(len_len(16_384), 3);
        assert_eq!(len_len(2_097_151), 3);
        assert_eq!(len_len(2_097_152), 4);
    }

    #[test]
    fn len_len_agrees_with_write_remaining_length() {
        for len in [0usize, 127, 128, 16_383, 16_384, 2_097_151, 2_097_152, 268_435_455] {
            let mut buf = BytesMut::new();
            let count = write_remaining_length(&mut buf, len).unwrap();
            assert_eq!(count, len_len(len));
            assert_eq!(buf.len(), count);
        }
    }

    #[test]
    fn write_remaining_length_round_trip() {
        for len in [0usize, 127, 128, 321, 16_384, 268_435_455] {
            let mut buf = BytesMut::new();
            let count = write_remaining_length(&mut buf, len).unwrap();
            let (decoded_count, decoded) = length(buf.iter()).unwrap();
            assert_eq!(count, decoded_count);
            assert_eq!(decoded, len);
        }
    }

    #[test]
    fn write_remaining_length_rejects_oversized_values() {
        let mut buf = BytesMut::new();
        assert_eq!(write_remaining_length(&mut buf, 268_435_456), Err(Error::PayloadTooLong));
        assert!(buf.is_empty());
    }

    #[test]
    fn length_decodes_maximum_four_byte_value() {
        let encoded = [0xFFu8, 0xFF, 0xFF, 0x7F];
        assert_eq!(length(encoded.iter()), Ok((4, 268_435_455)));
    }

    #[test]
    fn length_rejects_five_byte_encoding() {
        let encoded = [0xFFu8, 0xFF, 0xFF, 0xFF, 0x01];
        assert_eq!(length(encoded.iter()), Err(Error::MalformedRemainingLength));
    }

    #[test]
    fn length_reports_unterminated_encoding() {
        assert_eq!(length([0u8; 0].iter()), Err(Error::InsufficientBytes(1)));
        assert_eq!(length([0x80u8].iter()), Err(Error::InsufficientBytes(1)));
    }

    #[test]
    fn parse_fixed_header_requires_two_bytes() {
        assert_eq!(parse_fixed_header([0u8; 0].iter()), Err(Error::InsufficientBytes(2)));
        assert_eq!(parse_fixed_header([0x30u8].iter()), Err(Error::InsufficientBytes(1)));
    }

    #[test]
    fn check_accepts_complete_frame() {
        let frame = [0x30u8, 0x02, 1, 2, 0xAA];
        let header = check(frame.iter()).unwrap();
        assert_eq!(
            header,
            ParsedFixedHeader {
                byte1: 0x30,
                remaining_len_len: 1,
                remaining_len: 2,
            }
        );
        assert_eq!(header.frame_length(), 4);
    }

    #[test]
    fn check_reports_missing_bytes() {
        let frame = [0x30u8, 0x05, 1, 2];
        let result = check(frame.iter());
        assert_eq!(result, Err(Error::InsufficientBytes(3)));
    }

    #[test]
    fn mqtt_string_round_trip() {
        let mut buf = BytesMut::new();
        write_mqtt_string(&mut buf, "hello");
        assert_eq!(&buf[..], &[0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
        let mut stream = buf.freeze();
        assert_eq!(read_mqtt_string(&mut stream), Ok(String::from("hello")));
        assert!(stream.is_empty());
    }

    #[test]
    fn read_mqtt_string_rejects_invalid_utf8() {
        let mut stream = Bytes::from_static(&[0x00, 0x01, 0xFF]);
        assert_eq!(read_mqtt_string(&mut stream), Err(Error::TopicNotUtf8));
    }

    #[test]
    fn read_mqtt_bytes_rejects_overlong_length() {
        let mut stream = Bytes::from_static(&[0x00, 0x05, 1, 2]);
        assert_eq!(read_mqtt_bytes(&mut stream), Err(Error::BoundaryCrossed(5)));
    }

    #[test]
    fn fixed_width_readers_are_big_endian_and_bounds_checked() {
        let mut stream = Bytes::from_static(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        assert_eq!(read_u8(&mut stream), Ok(0x01));
        assert_eq!(read_u16(&mut stream), Ok(0x0203));
        assert_eq!(read_u32(&mut stream), Ok(0x0405_0607));
        assert_eq!(read_u8(&mut stream), Err(Error::MalformedPacket));
        assert_eq!(read_u16(&mut stream), Err(Error::MalformedPacket));
        assert_eq!(read_u32(&mut stream), Err(Error::MalformedPacket));

        // A failed wide read must not consume the bytes that are present.
        let mut short = Bytes::from_static(&[0x09]);
        assert_eq!(read_u16(&mut short), Err(Error::MalformedPacket));
        assert_eq!(read_u8(&mut short), Ok(0x09));
    }
}