use crate::error::WireError;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Largest payload, in bytes, that a single frame may carry (1 MiB).
pub const MAX_PAYLOAD_SIZE: usize = 1 << 20;
/// Value of the first two header bytes of every frame (ASCII `"VR"`, big-endian).
const MAGIC: u16 = 0x5652;
/// Fixed header length: magic (2) + flags (2) + length (4) + target (32) + CRC-32 (4).
const HEADER_SIZE: usize = 2 + 2 + 4 + 32 + 4;
/// Set when a 32-byte MAC tag follows the payload on the wire.
pub const FLAG_MAC_PRESENT: u16 = 1 << 0;
/// Set when the on-wire payload is zstd-compressed.
pub const FLAG_COMPRESSED: u16 = 1 << 1;
/// Fragmentation marker. NOTE(review): not referenced in this part of the file;
/// presumably signals a payload that starts with a fragment header — confirm.
pub const FLAG_FRAGMENTED: u16 = 1 << 2;
/// Asks `write_frame_raw` to send the payload without attempting compression.
pub const FLAG_RAW_BINARY: u16 = 1 << 4;
/// Minimum payload size (64 KiB) at which `write_frame_raw` tries compression.
pub const COMPRESS_THRESHOLD: usize = 64 * 1024;
/// Number of bytes consumed by `parse_frag_header`.
pub const FRAG_HEADER_SIZE: usize = 10;
/// Fixed-size header found at the start of a frame payload, as decoded by
/// `parse_frag_header`. Every field is big-endian on the wire, and the header
/// occupies `FRAG_HEADER_SIZE` bytes.
#[derive(Debug, Clone, Copy)]
pub struct FragmentHeader {
/// Header bytes 0..2. NOTE(review): presumably identifies the fragmented
/// message — confirm against the sender.
pub fragment_id: u16,
/// Header bytes 2..4. NOTE(review): presumably this fragment's position within
/// the message; confirm whether it is 0- or 1-based.
pub sequence: u16,
/// Header bytes 4..6. NOTE(review): presumably the total number of fragments.
pub total: u16,
/// Header bytes 6..10. NOTE(review): meaning is not visible in this file — confirm.
pub stream_id: u32,
}
/// Decodes the fragment header at the start of `payload`.
///
/// Layout (all big-endian): `fragment_id: u16`, `sequence: u16`,
/// `total: u16`, `stream_id: u32`, for `FRAG_HEADER_SIZE` bytes in total.
/// Bytes after the header are ignored.
///
/// Returns `None` when `payload` is shorter than `FRAG_HEADER_SIZE`.
pub fn parse_frag_header(payload: &[u8]) -> Option<FragmentHeader> {
    let head = payload.get(..FRAG_HEADER_SIZE)?;
    let be16 = |at: usize| u16::from_be_bytes([head[at], head[at + 1]]);
    Some(FragmentHeader {
        fragment_id: be16(0),
        sequence: be16(2),
        total: be16(4),
        stream_id: u32::from_be_bytes([head[6], head[7], head[8], head[9]]),
    })
}
/// Default time allowed for the rest of a frame to arrive once its first byte
/// has been read; used by `read_frame`.
const FRAME_READ_TIMEOUT: Duration = Duration::from_millis(10_000);
/// One wire frame, either as decoded by `read_frame` or as input to
/// `write_frame_raw`.
///
/// On the wire a frame is a `HEADER_SIZE`-byte header (see `serialize_header`),
/// then the payload, then a 32-byte MAC tag when `FLAG_MAC_PRESENT` is set.
#[derive(Debug, Clone)]
pub struct Frame {
/// Header magic; `read_frame` rejects frames whose magic is not `MAGIC`.
pub magic: u16,
/// Bit set of `FLAG_*` values. Frames returned by `read_frame` never carry
/// `FLAG_COMPRESSED`, because their payload has already been decompressed.
pub flags: u16,
/// Payload length in bytes. For frames from `read_frame` this is the
/// decompressed length. `write_frame_raw` ignores it and recomputes it.
pub length: u32,
/// Target name, zero-padded to 32 bytes; decode it with `target_as_str`.
pub target: [u8; 32],
/// CRC-32 of the payload. For frames from `read_frame` it covers the payload as
/// returned (decompressed). `write_frame_raw` ignores it and recomputes it over
/// the bytes it sends.
pub crc32: u32,
/// Payload bytes, held in an `Arc` so cloning a frame does not copy them.
pub payload: Arc<[u8]>,
/// Optional 32-byte MAC tag, carried after the payload on the wire.
pub mac: Option<[u8; 32]>,
}
/// Encodes the fixed-size header of `frame` into a `HEADER_SIZE`-byte array.
///
/// Layout (integers big-endian): `magic` (bytes 0..2), `flags` (2..4),
/// `length` (4..8), `target` (8..40), `crc32` (40..44). The fields are copied
/// verbatim with no validation. The payload and MAC are not included.
pub fn serialize_header(frame: &Frame) -> [u8; HEADER_SIZE] {
    let mut out = [0u8; HEADER_SIZE];
    let (magic_field, rest) = out.split_at_mut(2);
    let (flags_field, rest) = rest.split_at_mut(2);
    let (length_field, rest) = rest.split_at_mut(4);
    let (target_field, crc_field) = rest.split_at_mut(32);
    magic_field.copy_from_slice(&frame.magic.to_be_bytes());
    flags_field.copy_from_slice(&frame.flags.to_be_bytes());
    length_field.copy_from_slice(&frame.length.to_be_bytes());
    target_field.copy_from_slice(&frame.target);
    crc_field.copy_from_slice(&frame.crc32.to_be_bytes());
    out
}
/// Writes one frame (header, then `payload`) to `stream`.
///
/// The header holds `MAGIC`, the caller's `flags`, the payload length, the
/// target name (zero-padded to 32 bytes) and the CRC-32 of `payload`. Unlike
/// `write_frame_raw`, this function never compresses the payload and never
/// writes a MAC trailer.
///
/// A `target` longer than 32 bytes is truncated to the longest prefix that ends
/// on a UTF-8 character boundary. This way the receiver can always decode it
/// with `target_as_str`.
///
/// NOTE(review): if `flags` contains `FLAG_MAC_PRESENT`, `read_frame` will
/// expect a 32-byte tag after the payload, and this function does not write
/// one. Confirm that callers setting that flag append the tag themselves.
///
/// # Errors
///
/// Returns `WireError::PayloadTooLarge` if `payload` exceeds
/// `MAX_PAYLOAD_SIZE`. Propagates any I/O error from `stream`.
pub async fn write_frame<W>(
    stream: &mut W,
    target: &str,
    flags: u16,
    payload: &[u8],
) -> Result<(), WireError>
where
    W: AsyncWrite + Unpin,
{
    if payload.len() > MAX_PAYLOAD_SIZE {
        return Err(WireError::PayloadTooLarge(payload.len()));
    }
    let mut header = [0u8; HEADER_SIZE];
    header[0..2].copy_from_slice(&MAGIC.to_be_bytes());
    header[2..4].copy_from_slice(&flags.to_be_bytes());
    // Cannot truncate: payload.len() <= MAX_PAYLOAD_SIZE, which fits in a u32.
    header[4..8].copy_from_slice(&(payload.len() as u32).to_be_bytes());
    // Cutting a multi-byte character in half would make the receiver's UTF-8
    // decode of the target field fail, so back off to a char boundary.
    // `is_char_boundary(0)` is always true, so this loop terminates.
    let mut copy_len = target.len().min(32);
    while !target.is_char_boundary(copy_len) {
        copy_len -= 1;
    }
    header[8..8 + copy_len].copy_from_slice(&target.as_bytes()[..copy_len]);
    let checksum = crc32fast::hash(payload);
    header[40..44].copy_from_slice(&checksum.to_be_bytes());
    stream.write_all(&header).await?;
    stream.write_all(payload).await?;
    Ok(())
}
pub async fn write_frame_raw<W>(stream: &mut W, frame: &Frame) -> Result<(), WireError>
where
W: AsyncWrite + Unpin,
{
if frame.payload.len() > MAX_PAYLOAD_SIZE {
return Err(WireError::PayloadTooLarge(frame.payload.len()));
}
let (wire_payload, wire_flags): (Arc<[u8]>, u16) = if frame.payload.len() >= COMPRESS_THRESHOLD
&& frame.flags & FLAG_COMPRESSED == 0
&& frame.flags & FLAG_RAW_BINARY == 0
{
match zstd::bulk::compress(&frame.payload, 3) {
Ok(c) if c.len() < frame.payload.len() => (Arc::from(c), frame.flags | FLAG_COMPRESSED),
_ => (frame.payload.clone(), frame.flags),
}
} else {
(frame.payload.clone(), frame.flags)
};
let wire_crc = crc32fast::hash(&wire_payload);
let wire_frame = Frame {
magic: frame.magic,
flags: wire_flags,
length: wire_payload.len() as u32,
target: frame.target,
crc32: wire_crc,
payload: wire_payload,
mac: frame.mac,
};
let header = serialize_header(&wire_frame);
stream.write_all(&header).await?;
stream.write_all(&wire_frame.payload).await?;
if let Some(tag) = &wire_frame.mac {
stream.write_all(tag).await?;
}
Ok(())
}
/// Reads one frame from `stream`, allowing `FRAME_READ_TIMEOUT` for the part
/// of the frame that follows its first byte.
///
/// See `read_frame_with_timeout` for the exact timeout and error behaviour.
pub async fn read_frame<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Frame, WireError> {
    let body_budget = FRAME_READ_TIMEOUT;
    read_frame_with_timeout(stream, body_budget).await
}
/// Reads one frame, bounding everything after its first byte by `frame_timeout`.
///
/// The first header byte is awaited without any deadline, so an idle stream
/// may wait indefinitely for the next frame. Once that byte has arrived, the
/// rest of the frame must be read within `frame_timeout`.
///
/// # Errors
///
/// - `WireError::FrameReadTimeout` if the remainder does not arrive in time.
/// - Any I/O error from `stream`, including EOF before the first byte.
/// - Any error produced while reading or validating the frame body.
pub async fn read_frame_with_timeout<R>(
    stream: &mut R,
    frame_timeout: Duration,
) -> Result<Frame, WireError>
where
    R: AsyncRead + Unpin,
{
    let mut lead = [0u8; 1];
    stream.read_exact(&mut lead).await?;
    let [lead_byte] = lead;
    tokio::time::timeout(frame_timeout, read_frame_body(stream, lead_byte))
        .await
        .unwrap_or_else(|_elapsed| Err(WireError::FrameReadTimeout))
}
/// Reads the rest of a frame whose first header byte (`first_byte`) has
/// already been consumed from `stream`.
///
/// The whole frame is read before the CRC is checked or the payload is
/// decompressed. That includes the rest of the header, the payload, and the
/// 32-byte MAC trailer when `FLAG_MAC_PRESENT` is set. A frame rejected for a
/// CRC mismatch or a decompression failure is therefore still fully consumed,
/// and the stream stays positioned at the next frame header.
///
/// For compressed frames (`FLAG_COMPRESSED`), the returned `Frame` describes
/// the decompressed payload: the flag is cleared, and `length` / `crc32` are
/// recomputed for the plaintext.
///
/// # Errors
///
/// - `WireError::FrameMagicMismatch` if the magic is not `MAGIC`.
/// - `WireError::PayloadTooLarge` if the declared length exceeds
///   `MAX_PAYLOAD_SIZE`. The payload is left unread in this case.
/// - `WireError::FrameCrcMismatch` if the CRC-32 of the received payload does
///   not match the header.
/// - `WireError::Internal` if decompression fails.
/// - Any I/O error from `stream`.
async fn read_frame_body<R>(stream: &mut R, first_byte: u8) -> Result<Frame, WireError>
where
    R: AsyncRead + Unpin,
{
    let mut header = [0u8; HEADER_SIZE];
    header[0] = first_byte;
    stream.read_exact(&mut header[1..]).await?;

    let magic = u16::from_be_bytes([header[0], header[1]]);
    if magic != MAGIC {
        return Err(WireError::FrameMagicMismatch);
    }
    let flags = u16::from_be_bytes([header[2], header[3]]);
    let length = u32::from_be_bytes([header[4], header[5], header[6], header[7]]);
    // Reject before allocating. The oversized payload is deliberately not
    // drained: doing so could mean reading up to 4 GiB.
    if length as usize > MAX_PAYLOAD_SIZE {
        return Err(WireError::PayloadTooLarge(length as usize));
    }
    let mut target = [0u8; 32];
    target.copy_from_slice(&header[8..40]);
    let crc32 = u32::from_be_bytes([header[40], header[41], header[42], header[43]]);

    // Consume the entire frame (payload + optional MAC trailer) before any
    // validation. Otherwise a frame rejected below would leave its trailer in
    // the stream to be misparsed as the next header.
    let mut payload = vec![0u8; length as usize];
    if length > 0 {
        stream.read_exact(&mut payload).await?;
    }
    let mac = if flags & FLAG_MAC_PRESENT != 0 {
        let mut tag = [0u8; 32];
        stream.read_exact(&mut tag).await?;
        Some(tag)
    } else {
        None
    };

    // The header CRC covers the bytes as sent, i.e. before decompression.
    if crc32fast::hash(&payload) != crc32 {
        return Err(WireError::FrameCrcMismatch);
    }

    // Hand callers the plaintext: clear FLAG_COMPRESSED and make `length` and
    // `crc32` describe the decompressed bytes.
    let (payload, flags, length, crc32) = if flags & FLAG_COMPRESSED != 0 {
        // The capacity argument bounds the decompressed size, so a tiny
        // compressed frame cannot expand beyond MAX_PAYLOAD_SIZE.
        let decompressed = zstd::bulk::decompress(&payload, MAX_PAYLOAD_SIZE)
            .map_err(|e| WireError::Internal(format!("decompress frame: {e}")))?;
        let plain_len = decompressed.len() as u32;
        let plain_crc = crc32fast::hash(&decompressed);
        (decompressed, flags & !FLAG_COMPRESSED, plain_len, plain_crc)
    } else {
        (payload, flags, length, crc32)
    };

    Ok(Frame {
        magic,
        flags,
        length,
        target,
        crc32,
        payload: payload.into(),
        mac,
    })
}
/// Returns the frame's target name as a string slice.
///
/// The 32-byte target field is treated as zero-padded. The bytes before the
/// first `0` are decoded, or all 32 bytes if there is no `0`. Returns `None`
/// if those bytes are not valid UTF-8.
pub fn target_as_str(frame: &Frame) -> Option<&str> {
    let name_bytes = frame
        .target
        .split(|&byte| byte == 0)
        .next()
        .unwrap_or_default();
    std::str::from_utf8(name_bytes).ok()
}