use crate::error::{Error, Result};
use bytes::{BufMut, Bytes, BytesMut};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum FrameType {
Data = 0,
Control = 1,
Heartbeat = 2,
FragmentStart = 3,
FragmentContinuation = 4,
FragmentEnd = 5,
}
impl TryFrom<u8> for FrameType {
type Error = Error;
fn try_from(value: u8) -> Result<Self> {
match value {
0 => Ok(FrameType::Data),
1 => Ok(FrameType::Control),
2 => Ok(FrameType::Heartbeat),
3 => Ok(FrameType::FragmentStart),
4 => Ok(FrameType::FragmentContinuation),
5 => Ok(FrameType::FragmentEnd),
_ => Err(Error::Protocol(format!("Invalid frame type: {}", value))),
}
}
}
#[derive(Debug, Clone)]
pub struct FrameHeader {
pub frame_type: FrameType,
pub version: u8,
pub compressed: bool,
pub fragmented: bool,
pub payload_length: u32,
}
impl FrameHeader {
pub const SIZE: usize = 8;
pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload_length: u32) -> Self {
Self {
frame_type,
version,
compressed,
fragmented: false,
payload_length,
}
}
pub fn encode(&self) -> [u8; Self::SIZE] {
let mut buf = [0u8; Self::SIZE];
buf[0] = ((self.frame_type as u8) << 4) | (self.version & 0x0F);
let mut flags = 0u8;
if self.compressed {
flags |= 0x80; }
if self.fragmented {
flags |= 0x40; }
buf[1] = flags;
buf[2..6].copy_from_slice(&self.payload_length.to_be_bytes());
buf
}
pub fn decode(data: &[u8]) -> Result<Self> {
if data.len() < Self::SIZE {
return Err(Error::Protocol(format!(
"Insufficient data for frame header: expected {}, got {}",
Self::SIZE,
data.len()
)));
}
let frame_type = FrameType::try_from(data[0] >> 4)?;
let version = data[0] & 0x0F;
let compressed = (data[1] & 0x80) != 0;
let fragmented = (data[1] & 0x40) != 0;
let payload_length = u32::from_be_bytes([data[2], data[3], data[4], data[5]]);
Ok(Self {
frame_type,
version,
compressed,
fragmented,
payload_length,
})
}
pub fn total_size(&self) -> usize {
Self::SIZE + self.payload_length as usize
}
}
#[derive(Debug, Clone)]
pub struct Frame {
pub header: FrameHeader,
pub payload: Bytes,
}
impl Frame {
pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload: Bytes) -> Self {
let header = FrameHeader::new(frame_type, version, compressed, payload.len() as u32);
Self { header, payload }
}
pub fn data(version: u8, compressed: bool, payload: Bytes) -> Self {
Self::new(FrameType::Data, version, compressed, payload)
}
pub fn control(version: u8, payload: Bytes) -> Self {
Self::new(FrameType::Control, version, false, payload)
}
pub fn heartbeat(version: u8) -> Self {
Self::new(FrameType::Heartbeat, version, false, Bytes::new())
}
pub fn size(&self) -> usize {
self.header.total_size()
}
}
pub struct FrameCodec {
max_payload_size: u32,
}
impl FrameCodec {
pub fn new() -> Self {
Self {
max_payload_size: 16 * 1024 * 1024, }
}
pub fn with_max_payload_size(max_payload_size: u32) -> Self {
Self { max_payload_size }
}
pub fn encode(&self, frame: &Frame) -> Result<Bytes> {
if frame.header.payload_length > self.max_payload_size {
return Err(Error::Protocol(format!(
"Payload size {} exceeds maximum {}",
frame.header.payload_length, self.max_payload_size
)));
}
let mut buf = BytesMut::with_capacity(frame.size());
buf.put_slice(&frame.header.encode());
buf.put_slice(&frame.payload);
Ok(buf.freeze())
}
pub fn decode(&self, data: &[u8]) -> Result<Frame> {
let header = FrameHeader::decode(data)?;
if header.payload_length > self.max_payload_size {
return Err(Error::Protocol(format!(
"Payload size {} exceeds maximum {}",
header.payload_length, self.max_payload_size
)));
}
let total_size = header.total_size();
if data.len() < total_size {
return Err(Error::Protocol(format!(
"Insufficient data for frame: expected {}, got {}",
total_size,
data.len()
)));
}
let payload = Bytes::copy_from_slice(&data[FrameHeader::SIZE..total_size]);
Ok(Frame { header, payload })
}
pub fn decode_all(&self, data: &[u8]) -> Result<Vec<Frame>> {
let mut frames = Vec::new();
let mut offset = 0;
while offset < data.len() {
if data.len() - offset < FrameHeader::SIZE {
break;
}
let header = FrameHeader::decode(&data[offset..])?;
let total_size = header.total_size();
if data.len() - offset < total_size {
break;
}
let frame = self.decode(&data[offset..])?;
frames.push(frame);
offset += total_size;
}
Ok(frames)
}
}
impl Default for FrameCodec {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_header_encode_decode() -> Result<()> {
        let header = FrameHeader::new(FrameType::Data, 1, true, 1024);
        let encoded = header.encode();
        let decoded = FrameHeader::decode(&encoded)?;
        assert_eq!(header.frame_type, decoded.frame_type);
        assert_eq!(header.version, decoded.version);
        assert_eq!(header.compressed, decoded.compressed);
        assert_eq!(header.fragmented, decoded.fragmented);
        assert_eq!(header.payload_length, decoded.payload_length);
        Ok(())
    }

    #[test]
    fn test_frame_header_fragmented_flag_round_trip() -> Result<()> {
        let mut header = FrameHeader::new(FrameType::FragmentStart, 2, false, 7);
        header.fragmented = true;
        let decoded = FrameHeader::decode(&header.encode())?;
        assert!(decoded.fragmented);
        assert!(!decoded.compressed);
        assert_eq!(decoded.frame_type, FrameType::FragmentStart);
        assert_eq!(decoded.version, 2);
        assert_eq!(decoded.payload_length, 7);
        Ok(())
    }

    #[test]
    fn test_frame_header_rejects_short_input() {
        assert!(FrameHeader::decode(&[0u8; FrameHeader::SIZE - 1]).is_err());
    }

    #[test]
    fn test_frame_header_rejects_unknown_type() {
        let mut bytes = [0u8; FrameHeader::SIZE];
        bytes[0] = 0x61; // type nibble 6 has no FrameType variant
        assert!(FrameHeader::decode(&bytes).is_err());
    }

    #[test]
    fn test_frame_encode_decode() -> Result<()> {
        let codec = FrameCodec::new();
        let payload = Bytes::from(vec![1, 2, 3, 4, 5]);
        let frame = Frame::data(1, false, payload.clone());
        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;
        assert_eq!(frame.header.frame_type, decoded.header.frame_type);
        assert_eq!(frame.payload, decoded.payload);
        Ok(())
    }

    #[test]
    fn test_frame_decode_truncated_payload() -> Result<()> {
        let codec = FrameCodec::new();
        let encoded = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3])))?;
        assert!(codec.decode(&encoded[..encoded.len() - 1]).is_err());
        Ok(())
    }

    #[test]
    fn test_frame_decode_rejects_oversized_payload() -> Result<()> {
        let encoded =
            FrameCodec::new().encode(&Frame::data(1, false, Bytes::from(vec![0; 200])))?;
        assert!(FrameCodec::with_max_payload_size(100).decode(&encoded).is_err());
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_all() -> Result<()> {
        let codec = FrameCodec::new();
        let frame1 = Frame::data(1, false, Bytes::from(vec![1, 2, 3]));
        let frame2 = Frame::data(1, false, Bytes::from(vec![4, 5, 6]));
        let mut buf = BytesMut::new();
        buf.put_slice(&codec.encode(&frame1)?);
        buf.put_slice(&codec.encode(&frame2)?);
        let frames = codec.decode_all(&buf)?;
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        assert_eq!(frames[1].payload, Bytes::from(vec![4, 5, 6]));
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_all_stops_at_partial_frame() -> Result<()> {
        let codec = FrameCodec::new();
        let first = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3])))?;
        let second = codec.encode(&Frame::data(1, false, Bytes::from(vec![4, 5, 6])))?;
        // Cut the second frame inside its header, and inside its payload.
        for cut in [3, FrameHeader::SIZE + 1] {
            let mut buf = BytesMut::new();
            buf.put_slice(&first);
            buf.put_slice(&second[..cut]);
            let frames = codec.decode_all(&buf)?;
            assert_eq!(frames.len(), 1);
            assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        }
        Ok(())
    }

    #[test]
    fn test_frame_heartbeat() -> Result<()> {
        let codec = FrameCodec::new();
        let frame = Frame::heartbeat(1);
        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;
        assert_eq!(decoded.header.frame_type, FrameType::Heartbeat);
        assert!(decoded.payload.is_empty());
        Ok(())
    }

    #[test]
    fn test_frame_max_size() {
        let codec = FrameCodec::with_max_payload_size(100);
        let payload = Bytes::from(vec![0; 200]);
        let frame = Frame::data(1, false, payload);
        let result = codec.encode(&frame);
        assert!(result.is_err());
    }
}