use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32fast::Hasher as Crc32;
use serde::{Deserialize, Serialize};
use std::fmt;
use tokio_util::codec::{Decoder, Encoder};
/// Frame magic number: the ASCII bytes `WCLP` read as a big-endian `u32` (`0x5743_4C50`).
pub const MAGIC: u32 = u32::from_be_bytes(*b"WCLP");
/// Protocol version written into, and required of, every frame header.
pub const VERSION: u8 = 1;
/// Size of the encoded frame header in bytes:
/// magic (4) + version (1) + kind (1) + flags (2) + session_id (8)
/// + ttl_ms (4) + payload_len (4) + crc32 (4).
pub const HEADER_LEN: usize = 4 + 1 + 1 + 2 + 8 + 4 + 4 + 4;
/// Header flag bit 0.
pub const FLAG_SIZE_EXCEEDED: u16 = 1 << 0;
/// Header flag bit 1.
pub const FLAG_TTL_EXPIRED: u16 = 1 << 1;
/// Header flag bit 2.
pub const FLAG_TOKEN_INVALID: u16 = 1 << 2;
/// Message type carried in the one-byte `kind` field of every frame header.
///
/// The explicit discriminants are the on-wire values: the codec writes a kind with
/// `kind as u8` and maps the byte back to a variant when parsing a header.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[repr(u8)]
pub enum MsgKind {
    Hello = 0x01,
    HelloAck = 0x02,
    Data = 0x03,
    Ack = 0x04,
    Error = 0x05,
    KeepAlive = 0x06,
    Close = 0x07,
    /// Presumably accompanied by a `SessionInitPayload` body — confirm with callers.
    SessionInit = 0x08,
    /// Presumably accompanied by an `AckBitmapPayload` body — confirm with callers.
    AckBitmap = 0x09,
    SessionDone = 0x0A,
    SessionAbort = 0x0B,
}
/// Fixed-size header that precedes every frame payload on the wire.
///
/// Encoded big-endian, in declaration order, as `HEADER_LEN` bytes.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct FrameHeader {
    /// Must equal `MAGIC`; the decoder rejects anything else.
    pub magic: u32,
    /// Must equal `VERSION`; the decoder rejects anything else.
    pub version: u8,
    /// Message type of this frame.
    pub kind: MsgKind,
    /// Bit set built from the `FLAG_*` constants.
    pub flags: u16,
    /// Session this frame belongs to.
    pub session_id: u64,
    /// Time-to-live in milliseconds (per the field name; not interpreted by the codec).
    pub ttl_ms: u32,
    /// Number of payload bytes following the header; the decoder requires it to match.
    pub payload_len: u32,
    /// CRC-32 over the header (with this field zeroed) followed by the payload.
    pub crc32: u32,
}
impl FrameHeader {
pub fn new(kind: MsgKind, flags: u16, session_id: u64, ttl_ms: u32, payload_len: u32) -> Self {
Self {
magic: MAGIC,
version: VERSION,
kind,
flags,
session_id,
ttl_ms,
payload_len,
crc32: 0,
}
}
}
/// One complete protocol message: a header plus its raw payload bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// Frame metadata; `header.payload_len` should equal `payload.len()`.
    pub header: FrameHeader,
    /// Opaque payload bytes following the header on the wire.
    pub payload: Bytes,
}
/// Serde-serializable body describing a new session.
///
/// Field semantics below are inferred from names — confirm against the session logic.
///
/// NOTE(review): the derived `Default` sets `content_type` to an empty string, whereas a
/// field missing during deserialization becomes `"text/plain"` (see `default_content_type`).
/// Confirm which of the two callers expect.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
pub struct SessionInitPayload {
    /// Presumably the total number of bytes to be transferred in the session.
    pub total_len: u64,
    /// Presumably the size of each data chunk in bytes.
    pub chunk_bytes: u32,
    /// Presumably a session time-to-live in milliseconds.
    pub ttl_ms: u32,
    /// Presumably an idle timeout in milliseconds.
    pub idle_ms: u32,
    /// MIME-like content type; `"text/plain"` when absent from the serialized form.
    #[serde(default = "default_content_type")]
    pub content_type: String,
    /// Free-form metadata bytes; empty when absent from the serialized form.
    #[serde(default)]
    pub meta: Vec<u8>,
    /// Optional auth token; `None` when absent from the serialized form.
    #[serde(default)]
    pub token: Option<String>,
}
/// Serde default for `SessionInitPayload::content_type`: the string `"text/plain"`.
fn default_content_type() -> String {
    String::from("text/plain")
}
/// Serde-serializable body reporting which chunks have been received.
///
/// Interpretation inferred from field names — confirm bit order with the sender.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct AckBitmapPayload {
    /// Presumably the index of the chunk represented by the first bit of `bitmap`.
    pub base_chunk: u32,
    /// Presumably one bit per chunk starting at `base_chunk`.
    pub bitmap: Vec<u8>,
}
/// Errors produced while encoding or decoding bridge frames.
///
/// Every variant except `Io` is converted to an `std::io::Error` of kind `InvalidData`
/// when it crosses the codec boundary.
#[derive(Debug, thiserror::Error)]
pub enum BridgeError {
    /// Fewer than `HEADER_LEN` bytes were available when parsing a header.
    #[error("incomplete frame")]
    Incomplete,
    /// The header's magic number did not equal `MAGIC`.
    #[error("bad magic: {0:#x}")]
    BadMagic(u32),
    /// The header's version did not equal `VERSION`.
    #[error("unsupported version: {0}")]
    BadVersion(u8),
    /// A payload or frame exceeded the configured limit.
    #[error("payload too large: {got} > {max}")]
    TooLarge { max: u64, got: u64 },
    /// The CRC stored in the header did not match the recomputed one.
    #[error("crc mismatch: expected {expected:#x} got {got:#x}")]
    CrcMismatch { expected: u32, got: u32 },
    /// Underlying I/O failure.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    /// Structurally invalid frame content, with a static description.
    #[error("decode error: {0}")]
    Decode(&'static str),
}
/// Lets codec code use `?` on `BridgeError` where an `std::io::Error` is required.
///
/// A wrapped I/O error is unwrapped unchanged; every other variant becomes an
/// `InvalidData` error that carries the original `BridgeError` as its source.
impl From<BridgeError> for std::io::Error {
    fn from(err: BridgeError) -> Self {
        match err {
            BridgeError::Io(inner) => inner,
            protocol_err => Self::new(std::io::ErrorKind::InvalidData, protocol_err),
        }
    }
}
/// Length-prefixed framing codec for `Frame`s.
///
/// Wire layout per frame: `[u32 frame_len][HEADER_LEN-byte header][payload]`, big-endian.
pub struct BridgeCodec {
    /// Largest payload, in bytes, that the codec will encode or accept when decoding.
    pub max_payload_bytes: u64,
}
impl BridgeCodec {
pub fn new(max_payload_bytes: u64) -> Self {
Self { max_payload_bytes }
}
fn write_header(buf: &mut BytesMut, header: &FrameHeader) {
buf.put_u32(header.magic);
buf.put_u8(header.version);
buf.put_u8(header.kind as u8);
buf.put_u16(header.flags);
buf.put_u64(header.session_id);
buf.put_u32(header.ttl_ms);
buf.put_u32(header.payload_len);
buf.put_u32(header.crc32);
}
fn parse_header(buf: &mut BytesMut) -> Result<FrameHeader, BridgeError> {
if buf.len() < HEADER_LEN {
return Err(BridgeError::Incomplete);
}
let magic = buf.get_u32();
let version = buf.get_u8();
let kind_raw = buf.get_u8();
let flags = buf.get_u16();
let session_id = buf.get_u64();
let ttl_ms = buf.get_u32();
let payload_len = buf.get_u32();
let crc32 = buf.get_u32();
if magic != MAGIC {
return Err(BridgeError::BadMagic(magic));
}
if version != VERSION {
return Err(BridgeError::BadVersion(version));
}
let kind = match kind_raw {
0x01 => MsgKind::Hello,
0x02 => MsgKind::HelloAck,
0x03 => MsgKind::Data,
0x04 => MsgKind::Ack,
0x05 => MsgKind::Error,
0x06 => MsgKind::KeepAlive,
0x07 => MsgKind::Close,
0x08 => MsgKind::SessionInit,
0x09 => MsgKind::AckBitmap,
0x0A => MsgKind::SessionDone,
0x0B => MsgKind::SessionAbort,
_ => return Err(BridgeError::Decode("unknown message kind")),
};
Ok(FrameHeader {
magic,
version,
kind,
flags,
session_id,
ttl_ms,
payload_len,
crc32,
})
}
}
impl Encoder<Frame> for BridgeCodec {
    type Error = std::io::Error;

    /// Serializes `item` as `[u32 frame_len][header][payload]` and fills in the CRC.
    ///
    /// The CRC-32 covers the header (with its `crc32` field zeroed) followed by the
    /// payload; any `crc32` value already present in `item.header` is ignored.
    ///
    /// # Errors
    ///
    /// * `InvalidData` (wrapping `BridgeError::TooLarge`) if the payload exceeds
    ///   `max_payload_bytes`, or the frame does not fit the `u32` length prefix.
    /// * `InvalidInput` if `item.header.payload_len` disagrees with `item.payload.len()`;
    ///   the decoder on the other side would reject such a frame anyway.
    ///
    /// Nothing is written to `dst` when an error is returned.
    fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let payload_len = item.payload.len();
        let actual = payload_len as u64;

        // Enforce the limit on the bytes actually being sent, not on the header's claim
        // about them: checking only `header.payload_len` let any payload size through.
        if actual > self.max_payload_bytes {
            return Err(BridgeError::TooLarge {
                max: self.max_payload_bytes,
                got: actual,
            }
            .into());
        }
        if u64::from(item.header.payload_len) != actual {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "header payload_len does not match payload size",
            ));
        }

        // The length prefix is a u32; refuse rather than silently truncate.
        let frame_len = HEADER_LEN + payload_len;
        let frame_len_prefix = u32::try_from(frame_len).map_err(|_| BridgeError::TooLarge {
            max: u64::from(u32::MAX) - HEADER_LEN as u64,
            got: actual,
        })?;

        // Checksum the header with a zeroed CRC field, then stamp the result in.
        let mut header = item.header;
        header.crc32 = 0;
        let mut zeroed_header = BytesMut::with_capacity(HEADER_LEN);
        Self::write_header(&mut zeroed_header, &header);
        let mut hasher = Crc32::new();
        hasher.update(&zeroed_header);
        hasher.update(&item.payload);
        header.crc32 = hasher.finalize();

        // Write the final header straight into `dst` (no second scratch buffer).
        dst.reserve(4 + frame_len);
        dst.put_u32(frame_len_prefix);
        Self::write_header(dst, &header);
        dst.put_slice(&item.payload);
        Ok(())
    }
}
impl Decoder for BridgeCodec {
    type Item = Frame;
    type Error = std::io::Error;

    /// Extracts one `[u32 frame_len][header][payload]` frame from the front of `src`.
    ///
    /// Returns `Ok(None)`, consuming nothing, until the whole frame has arrived.
    ///
    /// # Errors
    ///
    /// `InvalidData` if:
    /// * the length prefix exceeds `max_payload_bytes + HEADER_LEN`, or is shorter than
    ///   a header;
    /// * the header is invalid (magic, version or kind);
    /// * the header's `payload_len` disagrees with the frame length;
    /// * the CRC does not match.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if src.len() < 4 {
            return Ok(None);
        }

        // Peek at the big-endian length prefix without consuming it. The previous
        // `src.clone()` deep-copied the whole buffer on every call.
        let frame_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]);
        let frame_len_u64 = u64::from(frame_len);

        // Saturate so a very large configured limit cannot overflow the bound.
        let max_frame = self.max_payload_bytes.saturating_add(HEADER_LEN as u64);
        if frame_len_u64 > max_frame {
            return Err(BridgeError::TooLarge {
                max: max_frame,
                got: frame_len_u64,
            }
            .into());
        }
        // A frame must contain at least a full header. Without this check, a malformed
        // length prefix made `split_to(HEADER_LEN)` below panic.
        if frame_len_u64 < HEADER_LEN as u64 {
            return Err(BridgeError::Decode("frame shorter than header").into());
        }

        // u32 -> usize is lossless on the 32/64-bit targets this codec runs on.
        let frame_len = frame_len as usize;
        // Written as a subtraction so `4 + frame_len` cannot overflow on 32-bit targets.
        if src.len() - 4 < frame_len {
            return Ok(None);
        }

        src.advance(4);
        let mut payload = src.split_to(frame_len);
        let mut header_buf = payload.split_to(HEADER_LEN);
        let header = Self::parse_header(&mut header_buf)?;

        if header.payload_len as usize != payload.len() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "payload length mismatch",
            ));
        }
        // No separate limit check on `payload_len` is needed: it equals
        // `frame_len - HEADER_LEN`, which the frame bound above already caps.

        // Recompute the CRC exactly as the encoder does: header with crc32 zeroed,
        // followed by the payload.
        let mut zeroed = header;
        zeroed.crc32 = 0;
        let mut header_bytes = BytesMut::with_capacity(HEADER_LEN);
        Self::write_header(&mut header_bytes, &zeroed);
        let mut hasher = Crc32::new();
        hasher.update(&header_bytes);
        hasher.update(&payload);
        let crc = hasher.finalize();
        if crc != header.crc32 {
            return Err(BridgeError::CrcMismatch {
                expected: header.crc32,
                got: crc,
            }
            .into());
        }

        Ok(Some(Frame {
            header,
            payload: payload.freeze(),
        }))
    }
}
/// Renders as `BridgeCodec { max_payload_bytes: N }`.
impl fmt::Debug for BridgeCodec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BridgeCodec");
        builder.field("max_payload_bytes", &self.max_payload_bytes);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `payload` as a `Data` frame for `session_id` into a fresh buffer.
    fn encoded(codec: &mut BridgeCodec, session_id: u64, payload: &'static [u8]) -> BytesMut {
        let payload = Bytes::from_static(payload);
        let header = FrameHeader::new(MsgKind::Data, 0, session_id, 0, payload.len() as u32);
        let mut buf = BytesMut::new();
        codec.encode(Frame { header, payload }, &mut buf).unwrap();
        buf
    }

    #[test]
    fn round_trip_frame() {
        let mut codec = BridgeCodec::new(1024);
        let mut buf = encoded(&mut codec, 42, b"hello");
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.header.kind, MsgKind::Data);
        assert_eq!(decoded.header.session_id, 42);
        assert_eq!(decoded.header.payload_len, 5);
        assert_eq!(decoded.payload, Bytes::from_static(b"hello"));
        assert!(buf.is_empty(), "decoder must consume exactly one frame");
    }

    #[test]
    fn waits_for_complete_frame() {
        let mut codec = BridgeCodec::new(1024);
        let full = encoded(&mut codec, 7, b"hello");
        // Every strict prefix, including a partial length prefix, is "not yet".
        for cut in 0..full.len() {
            let mut partial = BytesMut::from(&full[..cut]);
            assert!(codec.decode(&mut partial).unwrap().is_none());
            assert_eq!(partial.len(), cut, "incomplete input must not be consumed");
        }
    }

    #[test]
    fn decodes_back_to_back_frames() {
        let mut codec = BridgeCodec::new(1024);
        let mut buf = encoded(&mut codec, 1, b"first");
        let second = encoded(&mut codec, 2, b"second");
        buf.extend_from_slice(&second);

        let a = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(a.header.session_id, 1);
        assert_eq!(a.payload, Bytes::from_static(b"first"));

        let b = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(b.header.session_id, 2);
        assert_eq!(b.payload, Bytes::from_static(b"second"));

        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn rejects_crc_mismatch() {
        let mut codec = BridgeCodec::new(1024);
        let mut buf = encoded(&mut codec, 1, b"hello");
        let idx = buf.len() - 1;
        buf[idx] ^= 0xFF;
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn rejects_bad_magic() {
        let mut codec = BridgeCodec::new(1024);
        let payload = Bytes::from_static(b"hi");
        let mut header = FrameHeader::new(MsgKind::Data, 0, 1, 0, payload.len() as u32);
        header.magic = 0xDEAD_BEEF;
        let mut buf = BytesMut::new();
        codec.encode(Frame { header, payload }, &mut buf).unwrap();
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn rejects_oversized_length_prefix() {
        let mut codec = BridgeCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.put_u32(10_000);
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn enforces_max_payload() {
        let mut codec = BridgeCodec::new(4);
        let payload = Bytes::from_static(b"hello");
        let header = FrameHeader::new(MsgKind::Data, 0, 0, 0, payload.len() as u32);
        let frame = Frame { header, payload };
        let mut buf = BytesMut::new();
        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(buf.is_empty(), "nothing is written on error");
    }
}