use core::fmt;
/// Frame TYPE byte (byte index 5 of the 8-byte frame header).
///
/// `0x00` maps to [`FrameType::Amqp`] and `0x01` to [`FrameType::Sasl`]; every
/// other byte is carried through unchanged as [`FrameType::Reserved`].
///
/// Note: `Reserved(0x00)` and `Reserved(0x01)` can be constructed by hand but
/// are non-canonical — they encode to the same byte as `Amqp` / `Sasl` yet
/// compare (and hash) unequal to them. [`FrameType::from_u8`] never produces
/// them, so values obtained by decoding are always canonical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// AMQP frame (type byte `0x00`).
    Amqp,
    /// SASL frame (type byte `0x01`).
    Sasl,
    /// Any other type byte, preserved verbatim.
    Reserved(u8),
}

impl FrameType {
    /// Returns the wire byte for this frame type.
    #[must_use]
    pub const fn to_u8(self) -> u8 {
        match self {
            Self::Amqp => 0x00,
            Self::Sasl => 0x01,
            Self::Reserved(v) => v,
        }
    }

    /// Maps a wire byte to a frame type; never fails, unknown bytes become
    /// [`FrameType::Reserved`]. Kept as a `const fn` for use in const contexts;
    /// `FrameType::from(byte)` is the equivalent trait-based form.
    #[must_use]
    pub const fn from_u8(v: u8) -> Self {
        match v {
            0x00 => Self::Amqp,
            0x01 => Self::Sasl,
            other => Self::Reserved(other),
        }
    }
}

impl From<u8> for FrameType {
    /// Equivalent to [`FrameType::from_u8`].
    fn from(v: u8) -> Self {
        Self::from_u8(v)
    }
}

impl From<FrameType> for u8 {
    /// Equivalent to [`FrameType::to_u8`].
    fn from(t: FrameType) -> Self {
        t.to_u8()
    }
}
/// The fixed 8-byte header that starts a frame.
///
/// Wire layout: SIZE (u32, big-endian), DOFF (u8), TYPE (u8),
/// CHANNEL (u16, big-endian).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    /// SIZE field: total frame length in bytes. `decode_frame_header` rejects
    /// values below 8.
    pub size: u32,
    /// DOFF field: data offset measured in 4-byte words.
    pub doff: u8,
    /// TYPE field.
    pub frame_type: FrameType,
    /// CHANNEL field.
    pub channel: u16,
}

impl FrameHeader {
    /// Creates a header whose type is [`FrameType::Amqp`].
    ///
    /// The arguments are stored as given; nothing is validated here.
    #[must_use]
    pub const fn new_amqp(size_total: u32, doff_words: u8, channel: u16) -> Self {
        let frame_type = FrameType::Amqp;
        Self {
            channel,
            frame_type,
            doff: doff_words,
            size: size_total,
        }
    }

    /// Byte offset of the frame body from the start of the frame, i.e.
    /// `doff * 4`. Cannot overflow: the maximum is `255 * 4`.
    #[must_use]
    pub const fn body_offset(self) -> usize {
        const WORD_BYTES: usize = 4;
        WORD_BYTES * self.doff as usize
    }
}
/// Reasons a frame header fails to decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameError {
    /// Fewer than 8 bytes were available.
    HeaderTooShort,
    /// DOFF was below the minimum of 2; carries the offending DOFF.
    InvalidDataOffset(u8),
    /// SIZE was below the 8-byte minimum; carries the offending SIZE.
    SizeBelowMinimum(u32),
    /// `DOFF * 4` pointed past the frame length given by SIZE.
    BodyOffsetExceedsSize,
}

impl fmt::Display for FrameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            FrameError::InvalidDataOffset(d) => write!(f, "invalid DOFF {d} (< 2)"),
            FrameError::SizeBelowMinimum(s) => write!(f, "frame SIZE {s} < 8 minimum"),
            FrameError::HeaderTooShort => f.write_str("frame header < 8 bytes"),
            FrameError::BodyOffsetExceedsSize => f.write_str("body offset exceeds SIZE"),
        }
    }
}

// Only available when the `std` feature is enabled; the rest of this type
// relies on `core` alone.
#[cfg(feature = "std")]
impl std::error::Error for FrameError {}
/// Serializes `h` into its 8-byte wire form.
///
/// Layout:
/// - bytes 0..4: SIZE (big-endian)
/// - byte 4: DOFF
/// - byte 5: TYPE
/// - bytes 6..8: CHANNEL (big-endian)
///
/// Fields are written verbatim with no validation. A header with, say,
/// `doff < 2` encodes fine but will be rejected by `decode_frame_header`.
#[must_use]
pub fn encode_frame_header(h: FrameHeader) -> [u8; 8] {
    let [s0, s1, s2, s3] = h.size.to_be_bytes();
    let [c0, c1] = h.channel.to_be_bytes();
    [s0, s1, s2, s3, h.doff, h.frame_type.to_u8(), c0, c1]
}
/// Parses the first 8 bytes of `bytes` as a frame header.
///
/// Any bytes after the eighth are ignored. An unrecognized TYPE byte is not an
/// error; it decodes as [`FrameType::Reserved`].
///
/// # Errors
///
/// Checks run in this order, and the first failure is returned:
/// - [`FrameError::HeaderTooShort`] if `bytes` has fewer than 8 bytes.
/// - [`FrameError::SizeBelowMinimum`] if SIZE < 8.
/// - [`FrameError::InvalidDataOffset`] if DOFF < 2.
/// - [`FrameError::BodyOffsetExceedsSize`] if `DOFF * 4 > SIZE`.
pub fn decode_frame_header(bytes: &[u8]) -> Result<FrameHeader, FrameError> {
    // The slice pattern both enforces the 8-byte minimum and splits out the
    // fields, so no indexing (and no panic path) is needed afterwards.
    let (size_be, doff, type_byte, channel_be) = match *bytes {
        [s0, s1, s2, s3, doff, ty, c0, c1, ..] => ([s0, s1, s2, s3], doff, ty, [c0, c1]),
        _ => return Err(FrameError::HeaderTooShort),
    };

    let size = u32::from_be_bytes(size_be);
    if size < 8 {
        return Err(FrameError::SizeBelowMinimum(size));
    }
    if doff < 2 {
        return Err(FrameError::InvalidDataOffset(doff));
    }
    // DOFF counts 4-byte words. Widen to u32 first, because `doff * 4` in u8
    // would overflow for DOFF >= 64.
    if u32::from(doff) * 4 > size {
        return Err(FrameError::BodyOffsetExceedsSize);
    }

    Ok(FrameHeader {
        size,
        doff,
        frame_type: FrameType::from_u8(type_byte),
        channel: u16::from_be_bytes(channel_be),
    })
}
#[cfg(test)]
// `expect` on a decode result is the clearest way to fail a test on an
// unexpected `Err`; a panic here is a test failure, not a production path.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn frame_type_round_trip() {
        for ft in [FrameType::Amqp, FrameType::Sasl, FrameType::Reserved(0xFE)] {
            assert_eq!(FrameType::from_u8(ft.to_u8()), ft);
        }
    }

    #[test]
    fn every_type_byte_survives_from_u8_then_to_u8() {
        for v in 0..=u8::MAX {
            assert_eq!(FrameType::from_u8(v).to_u8(), v);
        }
    }

    #[test]
    fn header_round_trip_minimum_size() {
        let h = FrameHeader::new_amqp(8, 2, 0);
        let bytes = encode_frame_header(h);
        assert_eq!(bytes, [0, 0, 0, 8, 2, 0, 0, 0]);
        let parsed = decode_frame_header(&bytes).expect("decode");
        assert_eq!(parsed, h);
    }

    #[test]
    fn header_with_channel_42_and_size_1024() {
        let h = FrameHeader {
            size: 1024,
            doff: 2,
            frame_type: FrameType::Amqp,
            channel: 42,
        };
        let bytes = encode_frame_header(h);
        let parsed = decode_frame_header(&bytes).expect("decode");
        assert_eq!(parsed, h);
        assert_eq!(parsed.channel, 42);
        assert_eq!(parsed.size, 1024);
    }

    #[test]
    fn multi_byte_fields_are_big_endian() {
        let h = FrameHeader {
            size: 0x0102_0304,
            doff: 2,
            frame_type: FrameType::Amqp,
            channel: 0x0A0B,
        };
        assert_eq!(
            encode_frame_header(h),
            [0x01, 0x02, 0x03, 0x04, 2, 0x00, 0x0A, 0x0B]
        );
    }

    #[test]
    fn body_offset_is_doff_times_4() {
        let h = FrameHeader::new_amqp(20, 3, 0);
        assert_eq!(h.body_offset(), 12);
        let h2 = FrameHeader::new_amqp(8, 2, 0);
        assert_eq!(h2.body_offset(), 8);
    }

    #[test]
    fn header_too_short_decode_fails() {
        assert_eq!(decode_frame_header(&[]), Err(FrameError::HeaderTooShort));
        assert_eq!(
            decode_frame_header(&[0; 7]),
            Err(FrameError::HeaderTooShort)
        );
    }

    #[test]
    fn decode_ignores_trailing_bytes() {
        let bytes = [0u8, 0, 0, 8, 2, 1, 0, 5, 0xAA, 0xBB];
        let parsed = decode_frame_header(&bytes).expect("decode");
        assert_eq!(
            parsed,
            FrameHeader {
                size: 8,
                doff: 2,
                frame_type: FrameType::Sasl,
                channel: 5,
            }
        );
    }

    #[test]
    fn doff_below_2_rejected() {
        let bytes = [0u8, 0, 0, 8, 1, 0, 0, 0];
        assert_eq!(
            decode_frame_header(&bytes),
            Err(FrameError::InvalidDataOffset(1))
        );
    }

    #[test]
    fn size_below_8_rejected() {
        let bytes = [0u8, 0, 0, 4, 2, 0, 0, 0];
        assert_eq!(
            decode_frame_header(&bytes),
            Err(FrameError::SizeBelowMinimum(4))
        );
    }

    #[test]
    fn body_offset_exceeding_size_rejected() {
        let bytes = [0u8, 0, 0, 8, 4, 0, 0, 0];
        assert_eq!(
            decode_frame_header(&bytes),
            Err(FrameError::BodyOffsetExceedsSize)
        );
    }

    #[test]
    fn body_offset_equal_to_size_accepted() {
        // DOFF 3 => 12-byte offset, exactly SIZE: the boundary must be inclusive.
        let bytes = [0u8, 0, 0, 12, 3, 0, 0, 0];
        let parsed = decode_frame_header(&bytes).expect("decode");
        assert_eq!(parsed.body_offset(), 12);
    }

    #[test]
    fn unknown_type_byte_decodes_as_reserved() {
        let bytes = [0u8, 0, 0, 8, 2, 0x7F, 0, 0];
        let parsed = decode_frame_header(&bytes).expect("decode");
        assert_eq!(parsed.frame_type, FrameType::Reserved(0x7F));
    }

    #[test]
    fn sasl_frame_type_byte_is_1() {
        let h = FrameHeader {
            size: 8,
            doff: 2,
            frame_type: FrameType::Sasl,
            channel: 0,
        };
        let bytes = encode_frame_header(h);
        assert_eq!(bytes[5], 0x01);
    }
}