use crate::{Error, QoS, Result, MAGIC_BYTE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Size in bytes of the fixed frame header: magic (1) + flags (1) + payload length (2).
pub const HEADER_SIZE: usize = 4;
/// Header size when the optional `u64` timestamp follows the fixed header.
pub const HEADER_SIZE_WITH_TS: usize = HEADER_SIZE + std::mem::size_of::<u64>();
/// Largest payload the 16-bit length field can describe.
pub const MAX_PAYLOAD_SIZE: usize = u16::MAX as usize;
/// Decoded form of the single flags byte in a frame header.
///
/// Bit layout, most significant bit first, as packed by [`FrameFlags::to_byte`]:
///
/// ```text
///  bit:  7 6 |  5  |  4  |  3   | 2 1 0
///        QoS | TS  | ENC | COMP | version
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct FrameFlags {
    /// Quality-of-service level, stored in the top two bits.
    pub qos: QoS,
    /// Whether a `u64` timestamp follows the fixed header (bit 5).
    pub has_timestamp: bool,
    /// Encrypted-payload marker (bit 4); this type only carries the bit.
    pub encrypted: bool,
    /// Compressed-payload marker (bit 3); this type only carries the bit.
    pub compressed: bool,
    /// Protocol version; only the low three bits survive [`FrameFlags::to_byte`].
    pub version: u8,
}

impl FrameFlags {
    /// Packs the flags into one byte using the layout documented on [`FrameFlags`].
    ///
    /// `version` is masked to its low three bits, so values above 7 wrap.
    pub fn to_byte(&self) -> u8 {
        let qos_bits = (self.qos as u8) << 6;
        let ts_bit = u8::from(self.has_timestamp) << 5;
        let enc_bit = u8::from(self.encrypted) << 4;
        let comp_bit = u8::from(self.compressed) << 3;
        qos_bits | ts_bit | enc_bit | comp_bit | (self.version & 0x07)
    }

    /// Unpacks a flags byte produced by [`FrameFlags::to_byte`].
    ///
    /// A QoS field that `QoS::from_u8` rejects falls back to `QoS::Fire`
    /// rather than failing.
    pub fn from_byte(byte: u8) -> Self {
        let bit_set = |shift: u8| ((byte >> shift) & 1) == 1;
        Self {
            qos: QoS::from_u8((byte >> 6) & 0x03).unwrap_or(QoS::Fire),
            has_timestamp: bit_set(5),
            encrypted: bit_set(4),
            compressed: bit_set(3),
            version: byte & 0b0000_0111,
        }
    }

    /// Returns `true` for any non-zero version (version 1 and above).
    ///
    /// The module's tests label version 0 as "v2" and version 1 as "v3".
    pub fn is_binary_encoding(&self) -> bool {
        self.version != 0
    }
}
#[derive(Debug, Clone)]
pub struct Frame {
pub flags: FrameFlags,
pub timestamp: Option<u64>,
pub payload: Bytes,
}
impl Frame {
pub fn new(payload: impl Into<Bytes>) -> Self {
Self {
flags: FrameFlags::default(),
timestamp: None,
payload: payload.into(),
}
}
pub fn with_qos(mut self, qos: QoS) -> Self {
self.flags.qos = qos;
self
}
pub fn with_timestamp(mut self, timestamp: u64) -> Self {
self.timestamp = Some(timestamp);
self.flags.has_timestamp = true;
self
}
pub fn with_encrypted(mut self, encrypted: bool) -> Self {
self.flags.encrypted = encrypted;
self
}
pub fn with_compressed(mut self, compressed: bool) -> Self {
self.flags.compressed = compressed;
self
}
pub fn size(&self) -> usize {
let header = if self.flags.has_timestamp {
HEADER_SIZE_WITH_TS
} else {
HEADER_SIZE
};
header + self.payload.len()
}
pub fn encode(&self) -> Result<Bytes> {
if self.payload.len() > MAX_PAYLOAD_SIZE {
return Err(Error::PayloadTooLarge(self.payload.len()));
}
let mut buf = BytesMut::with_capacity(self.size());
buf.put_u8(MAGIC_BYTE);
buf.put_u8(self.flags.to_byte());
buf.put_u16(self.payload.len() as u16);
if let Some(ts) = self.timestamp {
buf.put_u64(ts);
}
buf.extend_from_slice(&self.payload);
Ok(buf.freeze())
}
pub fn decode(mut buf: impl Buf) -> Result<Self> {
if buf.remaining() < HEADER_SIZE {
return Err(Error::BufferTooSmall {
needed: HEADER_SIZE,
have: buf.remaining(),
});
}
let magic = buf.get_u8();
if magic != MAGIC_BYTE {
return Err(Error::InvalidMagic(magic));
}
let flags = FrameFlags::from_byte(buf.get_u8());
let payload_len = buf.get_u16() as usize;
let header_size = if flags.has_timestamp {
HEADER_SIZE_WITH_TS
} else {
HEADER_SIZE
};
let total_remaining = if flags.has_timestamp { 8 } else { 0 } + payload_len;
if buf.remaining() < total_remaining {
return Err(Error::BufferTooSmall {
needed: header_size + payload_len,
have: HEADER_SIZE + buf.remaining(),
});
}
let timestamp = if flags.has_timestamp {
Some(buf.get_u64())
} else {
None
};
let payload = buf.copy_to_bytes(payload_len);
Ok(Self {
flags,
timestamp,
payload,
})
}
pub fn check_complete(buf: &[u8]) -> Option<usize> {
if buf.len() < HEADER_SIZE {
return None;
}
if buf[0] != MAGIC_BYTE {
return None;
}
let flags = FrameFlags::from_byte(buf[1]);
let payload_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
let header_size = if flags.has_timestamp {
HEADER_SIZE_WITH_TS
} else {
HEADER_SIZE
};
let total_size = header_size + payload_len;
if buf.len() >= total_size {
Some(total_size)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A frame with QoS and a timestamp survives an encode/decode round trip.
    #[test]
    fn test_frame_encode_decode() {
        let body: &[u8] = b"hello world";
        let wire = Frame::new(body)
            .with_qos(QoS::Confirm)
            .with_timestamp(1234567890)
            .encode()
            .unwrap();
        let back = Frame::decode(&wire[..]).unwrap();
        assert_eq!(back.flags.qos, QoS::Confirm);
        assert_eq!(back.timestamp, Some(1234567890));
        assert_eq!(&back.payload[..], body);
    }

    /// Every field of `FrameFlags` survives `to_byte` / `from_byte`.
    #[test]
    fn test_flags_roundtrip() {
        let original = FrameFlags {
            qos: QoS::Commit,
            has_timestamp: true,
            encrypted: true,
            compressed: false,
            version: 1,
        };
        let parsed = FrameFlags::from_byte(original.to_byte());
        assert_eq!(parsed.qos, QoS::Commit);
        assert_eq!(
            (parsed.has_timestamp, parsed.encrypted, parsed.compressed),
            (true, true, false)
        );
        assert_eq!(parsed.version, 1);
        assert!(parsed.is_binary_encoding());
    }

    /// Version 0 ("v2") is not binary-encoded; version 1 ("v3") is.
    #[test]
    fn test_flags_version_bits() {
        for (version, expect_binary) in [(0u8, false), (1u8, true)] {
            let flags = FrameFlags {
                version,
                ..FrameFlags::default()
            };
            assert_eq!(
                flags.is_binary_encoding(),
                expect_binary,
                "version {}",
                version
            );
        }
    }

    /// A full frame is reported complete; truncated prefixes are not.
    #[test]
    fn test_check_complete() {
        let wire = Frame::new(&b"test"[..]).encode().unwrap();
        assert_eq!(Frame::check_complete(&wire), Some(wire.len()));
        for cut in [2usize, 5usize] {
            assert_eq!(Frame::check_complete(&wire[..cut]), None);
        }
    }

    /// The largest allowed payload round-trips with its flags intact.
    #[test]
    fn test_frame_max_payload_size() {
        let wire = Frame::new(vec![0u8; MAX_PAYLOAD_SIZE])
            .with_qos(QoS::Fire)
            .with_encrypted(true)
            .encode()
            .expect("encode max payload");
        let back = Frame::decode(&wire[..]).expect("decode max payload");
        assert_eq!(back.payload.len(), MAX_PAYLOAD_SIZE);
        assert_eq!(back.flags.qos, QoS::Fire);
        assert!(back.flags.encrypted);
        assert!(!back.flags.has_timestamp);
    }

    /// One byte over the limit is rejected with the offending length.
    #[test]
    fn test_frame_payload_too_large() {
        let oversized = MAX_PAYLOAD_SIZE + 1;
        match Frame::new(vec![0u8; oversized]).encode() {
            Err(Error::PayloadTooLarge(len)) => assert_eq!(len, oversized),
            Err(other) => panic!("unexpected error: {:?}", other),
            Ok(_) => panic!("expected PayloadTooLarge error"),
        }
    }

    /// A corrupted first byte is reported as `InvalidMagic` carrying that byte.
    #[test]
    fn test_decode_invalid_magic() {
        let mut wire = Frame::new(&b"magic"[..]).encode().unwrap().to_vec();
        wire[0] = 0x00;
        match Frame::decode(wire.as_slice()) {
            Err(Error::InvalidMagic(byte)) => assert_eq!(byte, 0x00),
            Err(other) => panic!("unexpected error: {:?}", other),
            Ok(_) => panic!("expected InvalidMagic error"),
        }
    }

    /// Completeness accounts for the 8 timestamp bytes in the header.
    #[test]
    fn test_check_complete_with_timestamp() {
        let wire = Frame::new(&b"ts"[..]).with_timestamp(42).encode().unwrap();
        let full_len = wire.len();
        assert_eq!(Frame::check_complete(&wire), Some(full_len));
        assert_eq!(Frame::check_complete(&wire[..full_len - 1]), None);
    }
}