use bytes::Bytes;
use crate::error::{FrameError, Result, SrxError};
/// Size in bytes of the fixed wire header: `frame_id` (u64), `routing_mask` (u32) and the
/// payload length field (u16), written in that order.
pub const WIRE_HEADER_LEN: usize =
    std::mem::size_of::<u64>() + std::mem::size_of::<u32>() + std::mem::size_of::<u16>();
/// Size in bytes of the trailing `mac` field that follows the payload on the wire.
pub const WIRE_MAC_LEN: usize = 16;
/// Largest payload representable by the 16-bit length field in the header.
pub const MAX_WIRE_PAYLOAD: usize = u16::MAX as usize;
/// One frame, either built for encoding or produced by decoding.
///
/// `frame_id`, `routing_mask`, `payload` and `mac` are the fields carried on the wire by
/// `FrameCodec`; `fragment` is not serialized by `FrameCodec`.
#[derive(Clone, Debug)]
pub struct Frame {
    /// Frame identifier, written as the first 8 header bytes (big-endian).
    pub frame_id: u64,
    /// Routing mask, written as header bytes 8..12 (big-endian).
    pub routing_mask: u32,
    /// Frame body; must be at most `MAX_WIRE_PAYLOAD` bytes to be encodable.
    pub payload: Bytes,
    /// Trailing MAC bytes, copied verbatim to and from the wire.
    pub mac: [u8; WIRE_MAC_LEN],
    /// Optional fragmentation metadata; `FrameCodec` neither writes nor reads it.
    pub fragment: Option<FragmentInfo>,
}
/// Fragmentation metadata attached to a [`Frame`].
#[derive(Clone, Copy, Debug)]
pub struct FragmentInfo {
    /// Sequence identifier; presumably shared by all fragments of one message — confirm
    /// against the fragmenter.
    pub sequence: u64,
    /// Position of this fragment in the sequence (zero- vs one-based not established here).
    pub index: u16,
    /// Number of fragments in the sequence (presumably; verify against the producer).
    pub total: u16,
}
/// Stateless encoder/decoder for the fixed-layout wire frame format.
///
/// Construct it with `FrameCodec::new()` or `FrameCodec::default()`.
pub struct FrameCodec {
    // Private zero-sized field: code outside this module cannot build the struct with a
    // literal, so `new`/`default` are the only ways to obtain one there.
    _private: (),
}
impl FrameCodec {
pub fn new() -> Self {
Self { _private: () }
}
pub fn encode(&self, frame: &Frame) -> Result<Vec<u8>> {
let plen = frame.payload.len();
if plen > MAX_WIRE_PAYLOAD {
return Err(SrxError::Frame(FrameError::FrameTooLarge {
size: plen,
max: MAX_WIRE_PAYLOAD,
}));
}
let mut out = Vec::with_capacity(WIRE_HEADER_LEN + plen + WIRE_MAC_LEN);
out.extend_from_slice(&frame.frame_id.to_be_bytes());
out.extend_from_slice(&frame.routing_mask.to_be_bytes());
out.extend_from_slice(&(plen as u16).to_be_bytes());
out.extend_from_slice(&frame.payload);
out.extend_from_slice(&frame.mac);
Ok(out)
}
pub fn decode(&self, data: &[u8]) -> Result<Frame> {
if data.len() < WIRE_HEADER_LEN + WIRE_MAC_LEN {
return Err(SrxError::Frame(FrameError::Corrupted(
"truncated frame".into(),
)));
}
let frame_id = read_u64(data, 0)?;
let routing_mask = read_u32(data, 8)?;
let payload_len = u16::from_be_bytes([data[12], data[13]]) as usize;
let expected = WIRE_HEADER_LEN
.checked_add(payload_len)
.and_then(|n| n.checked_add(WIRE_MAC_LEN))
.ok_or_else(|| {
SrxError::Frame(FrameError::Corrupted("frame length overflow".into()))
})?;
if data.len() != expected {
return Err(SrxError::Frame(FrameError::Corrupted(format!(
"length mismatch: got {} bytes, expected {}",
data.len(),
expected
))));
}
let payload_start = WIRE_HEADER_LEN;
let payload_end = payload_start + payload_len;
let mac_start = payload_end;
let payload = Bytes::copy_from_slice(&data[payload_start..payload_end]);
let mac: [u8; WIRE_MAC_LEN] = data[mac_start..mac_start + WIRE_MAC_LEN]
.try_into()
.map_err(|_| SrxError::Frame(FrameError::Corrupted("mac field".into())))?;
Ok(Frame {
frame_id,
routing_mask,
payload,
mac,
fragment: None,
})
}
}
impl Default for FrameCodec {
fn default() -> Self {
Self::new()
}
}
fn read_u64(data: &[u8], offset: usize) -> Result<u64> {
let slice = data
.get(offset..offset + 8)
.ok_or_else(|| SrxError::Frame(FrameError::Corrupted("frame_id".into())))?;
let arr: [u8; 8] = slice
.try_into()
.map_err(|_| SrxError::Frame(FrameError::Corrupted("frame_id".into())))?;
Ok(u64::from_be_bytes(arr))
}
fn read_u32(data: &[u8], offset: usize) -> Result<u32> {
let slice = data
.get(offset..offset + 4)
.ok_or_else(|| SrxError::Frame(FrameError::Corrupted("routing_mask".into())))?;
let arr: [u8; 4] = slice
.try_into()
.map_err(|_| SrxError::Frame(FrameError::Corrupted("routing_mask".into())))?;
Ok(u32::from_be_bytes(arr))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame with fixed, easily recognizable header and MAC values.
    fn frame_with_payload(payload: Bytes) -> Frame {
        Frame {
            frame_id: 0x0102_0304_0506_0708,
            routing_mask: 0x0a0b_0c0d,
            payload,
            mac: [0xA5u8; WIRE_MAC_LEN],
            fragment: None,
        }
    }

    /// Asserts that `result` is a `Corrupted` frame error.
    fn assert_corrupted(result: Result<Frame>) {
        assert!(matches!(
            result,
            Err(SrxError::Frame(FrameError::Corrupted(_)))
        ));
    }

    #[test]
    fn roundtrip_empty_payload() {
        let codec = FrameCodec::new();
        let frame = Frame {
            frame_id: 0xdead_beef_cafe_babe,
            routing_mask: 0x1122_3344,
            payload: Bytes::new(),
            mac: [0x5Au8; WIRE_MAC_LEN],
            fragment: None,
        };
        let wire = codec.encode(&frame).unwrap();
        let out = codec.decode(&wire).unwrap();
        assert_eq!(out.frame_id, frame.frame_id);
        assert_eq!(out.routing_mask, frame.routing_mask);
        assert!(out.payload.is_empty());
        assert_eq!(out.mac, frame.mac);
    }

    #[test]
    fn roundtrip_with_payload() {
        let codec = FrameCodec::new();
        let frame = Frame {
            frame_id: 1,
            routing_mask: 2,
            payload: Bytes::from_static(b"cipher-bytes"),
            mac: [7u8; WIRE_MAC_LEN],
            fragment: None,
        };
        let wire = codec.encode(&frame).unwrap();
        let out = codec.decode(&wire).unwrap();
        assert_eq!(out.payload.as_ref(), b"cipher-bytes");
    }

    #[test]
    fn wire_layout_is_header_payload_mac_big_endian() {
        let codec = FrameCodec::new();
        let wire = codec
            .encode(&frame_with_payload(Bytes::from_static(b"xy")))
            .unwrap();
        assert_eq!(wire.len(), WIRE_HEADER_LEN + 2 + WIRE_MAC_LEN);
        assert_eq!(&wire[..8], &[1u8, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(&wire[8..12], &[0x0au8, 0x0b, 0x0c, 0x0d]);
        assert_eq!(&wire[12..14], &[0u8, 2]);
        assert_eq!(&wire[14..16], b"xy");
        assert_eq!(&wire[16..], &[0xA5u8; WIRE_MAC_LEN]);
    }

    #[test]
    fn roundtrip_max_payload() {
        let codec = FrameCodec::new();
        let frame = frame_with_payload(Bytes::from(vec![0x3Cu8; MAX_WIRE_PAYLOAD]));
        let wire = codec.encode(&frame).unwrap();
        assert_eq!(wire.len(), WIRE_HEADER_LEN + MAX_WIRE_PAYLOAD + WIRE_MAC_LEN);
        let out = codec.decode(&wire).unwrap();
        assert_eq!(out.payload, frame.payload);
        assert_eq!(out.mac, frame.mac);
    }

    #[test]
    fn rejects_oversized_payload() {
        let codec = FrameCodec::new();
        let big = vec![0u8; MAX_WIRE_PAYLOAD + 1];
        let frame = Frame {
            frame_id: 0,
            routing_mask: 0,
            payload: Bytes::from(big),
            mac: [0u8; WIRE_MAC_LEN],
            fragment: None,
        };
        assert!(matches!(
            codec.encode(&frame),
            Err(SrxError::Frame(FrameError::FrameTooLarge { size, max }))
                if size == MAX_WIRE_PAYLOAD + 1 && max == MAX_WIRE_PAYLOAD
        ));
    }

    #[test]
    fn rejects_truncated_input() {
        let codec = FrameCodec::new();
        for &len in &[0usize, 1, WIRE_HEADER_LEN, WIRE_HEADER_LEN + WIRE_MAC_LEN - 1] {
            assert_corrupted(codec.decode(&vec![0u8; len]));
        }
    }

    #[test]
    fn rejects_length_mismatch() {
        let codec = FrameCodec::new();
        let wire = codec
            .encode(&frame_with_payload(Bytes::from_static(b"abc")))
            .unwrap();

        // One trailing byte too many.
        let mut trailing = wire.clone();
        trailing.push(0);
        assert_corrupted(codec.decode(&trailing));

        // One byte short (still above the minimum frame size).
        assert_corrupted(codec.decode(&wire[..wire.len() - 1]));

        // Declared payload length larger than what is present.
        let mut lying = wire.clone();
        lying[12] = 0xFF;
        lying[13] = 0xFF;
        assert_corrupted(codec.decode(&lying));
    }

    #[test]
    fn fragment_is_not_serialized() {
        let codec = FrameCodec::new();
        let mut frame = frame_with_payload(Bytes::from_static(b"frag"));
        frame.fragment = Some(FragmentInfo {
            sequence: 9,
            index: 1,
            total: 2,
        });
        let with_fragment = codec.encode(&frame).unwrap();
        frame.fragment = None;
        let without_fragment = codec.encode(&frame).unwrap();
        assert_eq!(with_fragment, without_fragment);
        assert!(codec.decode(&with_fragment).unwrap().fragment.is_none());
    }
}