use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::error::ProtocolError;
use crate::opcodes::OpCode;
/// Protocol version stamped into the low bits of the version byte by `Frame::new`.
pub const PROTOCOL_VERSION: u8 = 1;
/// Version-byte flag (top bit) reported by `Frame::is_json_mode`.
pub const JSON_MODE_FLAG: u8 = 1 << 7;
/// Version-byte flag (second-highest bit) set by `Frame::with_compression`.
pub const ZSTD_FLAG: u8 = 1 << 6;
/// Mask selecting the version bits of the version byte: every bit except the two flag bits.
pub const VERSION_MASK: u8 = !(JSON_MODE_FLAG | ZSTD_FLAG);
/// Bytes preceding the payload on the wire: 4-byte length prefix, version byte,
/// opcode byte and 4-byte stream id.
pub const HEADER_SIZE: usize = 4 + 1 + 1 + 4;
/// Largest payload (16 MiB) accepted by `Frame::decode`; longer frames are rejected.
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;
/// A single protocol frame: header fields plus an opaque payload.
///
/// On the wire (see `Frame::encode`) a frame is a big-endian `u32` body length followed by
/// the `version` byte, the opcode byte, the big-endian `stream_id` and the payload bytes.
#[derive(Debug, Clone)]
pub struct Frame {
    /// Version byte. The low bits (`VERSION_MASK`) hold the protocol version; the two high
    /// bits are the `JSON_MODE_FLAG` and `ZSTD_FLAG` flags.
    pub version: u8,
    /// Operation carried by this frame, written to the wire as a single byte.
    pub opcode: OpCode,
    /// Stream identifier written verbatim into the header. Presumably used to multiplex
    /// several exchanges over one connection — TODO confirm against callers.
    pub stream_id: u32,
    /// Raw bytes following the header; this type never inspects or transforms them.
    pub payload: Bytes,
}
impl Frame {
    /// Size of the big-endian `u32` length prefix that precedes every frame body.
    const LEN_PREFIX_SIZE: usize = 4;
    /// Fixed part of the body in front of the payload: version byte, opcode byte and
    /// big-endian `u32` stream id. The length prefix counts these bytes plus the payload.
    const BODY_HEADER_SIZE: usize = 1 + 1 + 4;

    /// Creates a frame stamped with [`PROTOCOL_VERSION`] and no flags set.
    pub fn new(opcode: OpCode, stream_id: u32, payload: Bytes) -> Self {
        Self {
            version: PROTOCOL_VERSION,
            opcode,
            stream_id,
            payload,
        }
    }

    /// Creates a frame with an empty payload.
    pub fn empty(opcode: OpCode, stream_id: u32) -> Self {
        Self::new(opcode, stream_id, Bytes::new())
    }

    /// Returns `true` if [`JSON_MODE_FLAG`] is set in the version byte.
    pub fn is_json_mode(&self) -> bool {
        self.version & JSON_MODE_FLAG != 0
    }

    /// Returns `true` if [`ZSTD_FLAG`] is set in the version byte.
    ///
    /// Only the flag is inspected; the payload is never decompressed here.
    pub fn is_compressed(&self) -> bool {
        self.version & ZSTD_FLAG != 0
    }

    /// Sets [`ZSTD_FLAG`] in the version byte and returns the frame.
    ///
    /// The payload is left untouched, so presumably the caller must already have compressed
    /// it — confirm against callers.
    pub fn with_compression(mut self) -> Self {
        self.version |= ZSTD_FLAG;
        self
    }

    /// Number of bytes [`Frame::encode`] appends: length prefix, body header and payload.
    pub fn wire_size(&self) -> usize {
        Self::LEN_PREFIX_SIZE + Self::BODY_HEADER_SIZE + self.payload.len()
    }

    /// Appends the wire encoding of this frame to `dst`.
    ///
    /// Layout: body length `u32` (big-endian) | version `u8` | opcode `u8` |
    /// stream id `u32` (big-endian) | payload. The body length covers everything after
    /// the prefix.
    ///
    /// Payloads larger than [`MAX_PAYLOAD_SIZE`] are still encoded, but [`Frame::decode`]
    /// rejects them with [`ProtocolError::FrameTooLarge`].
    ///
    /// # Panics
    ///
    /// Panics if the body length does not fit in the `u32` length prefix.
    pub fn encode(&self, dst: &mut BytesMut) {
        let body_len = Self::BODY_HEADER_SIZE + self.payload.len();
        // A silent `as u32` truncation would write a length prefix that disagrees with the
        // bytes that follow, desynchronising the peer's decoder for the rest of the stream.
        assert!(
            body_len <= u32::MAX as usize,
            "frame body of {} bytes does not fit the u32 length prefix",
            body_len
        );
        dst.reserve(Self::LEN_PREFIX_SIZE + body_len);
        dst.put_u32(body_len as u32);
        dst.put_u8(self.version);
        dst.put_u8(self.opcode as u8);
        dst.put_u32(self.stream_id);
        dst.put_slice(&self.payload);
    }

    /// Attempts to decode one frame from the front of `src`.
    ///
    /// Returns `Ok(None)` without consuming anything when `src` does not yet hold a complete
    /// frame. Once the length prefix is known, capacity for the rest of the frame is reserved
    /// in `src`. On success the frame's bytes are removed from `src`, and the payload is split
    /// off without copying.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::FrameTooSmall`] if the body length is shorter than the fixed body
    ///   header. Nothing is consumed.
    /// - [`ProtocolError::FrameTooLarge`] if the body length exceeds [`MAX_PAYLOAD_SIZE`] plus
    ///   the body header. Nothing is consumed.
    /// - [`ProtocolError::UnknownOpCode`] if the opcode byte is not recognised. The whole
    ///   frame has already been removed from `src` by then.
    pub fn decode(src: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
        if src.len() < Self::LEN_PREFIX_SIZE {
            return Ok(None);
        }
        let body_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
        if body_len < Self::BODY_HEADER_SIZE {
            return Err(ProtocolError::FrameTooSmall(body_len));
        }
        if body_len > MAX_PAYLOAD_SIZE + Self::BODY_HEADER_SIZE {
            return Err(ProtocolError::FrameTooLarge(body_len));
        }
        let total = Self::LEN_PREFIX_SIZE + body_len;
        if src.len() < total {
            // Reserve up front so the remainder of the frame can be read in without
            // repeated reallocation.
            src.reserve(total - src.len());
            return Ok(None);
        }
        let mut frame_buf = src.split_to(total);
        frame_buf.advance(Self::LEN_PREFIX_SIZE);
        let version = frame_buf.get_u8();
        let opcode_byte = frame_buf.get_u8();
        let stream_id = frame_buf.get_u32();
        let payload = frame_buf.freeze();
        let opcode =
            OpCode::from_u8(opcode_byte).ok_or(ProtocolError::UnknownOpCode(opcode_byte))?;
        Ok(Some(Frame {
            version,
            opcode,
            stream_id,
            payload,
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let frame = Frame::new(OpCode::Remember, 42, Bytes::from_static(b"hello world"));
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        assert_eq!(buf.len(), frame.wire_size());
        // The header size advertised by the module must match what encode writes.
        assert_eq!(frame.wire_size(), HEADER_SIZE + b"hello world".len());
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.opcode, OpCode::Remember);
        assert_eq!(decoded.stream_id, 42);
        assert_eq!(decoded.payload, Bytes::from_static(b"hello world"));
        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_partial() {
        let frame = Frame::empty(OpCode::Ping, 0);
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        let half = buf.len() / 2;
        let mut partial = buf.split_to(half);
        assert!(Frame::decode(&mut partial).unwrap().is_none());
        // An incomplete frame must stay in the buffer for the next read.
        assert_eq!(partial.len(), half);
    }

    #[test]
    fn decode_incomplete_length_prefix() {
        let mut buf = BytesMut::new();
        buf.put_slice(&[0, 0, 0]);
        assert!(Frame::decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn decode_unknown_opcode() {
        let mut buf = BytesMut::new();
        buf.put_u32(6);
        buf.put_u8(PROTOCOL_VERSION);
        buf.put_u8(0xFF);
        buf.put_u32(0);
        let err = Frame::decode(&mut buf).unwrap_err();
        assert!(matches!(err, ProtocolError::UnknownOpCode(0xFF)));
        // The offending frame is consumed, so the stream stays aligned.
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_rejects_body_smaller_than_header() {
        let mut buf = BytesMut::new();
        buf.put_u32(5);
        let err = Frame::decode(&mut buf).unwrap_err();
        assert!(matches!(err, ProtocolError::FrameTooSmall(5)));
    }

    #[test]
    fn decode_rejects_oversized_body() {
        let body_len = MAX_PAYLOAD_SIZE + 7;
        let mut buf = BytesMut::new();
        buf.put_u32(body_len as u32);
        let err = Frame::decode(&mut buf).unwrap_err();
        assert!(matches!(err, ProtocolError::FrameTooLarge(n) if n == body_len));
    }

    #[test]
    fn empty_frame() {
        let frame = Frame::empty(OpCode::Pong, 7);
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.opcode, OpCode::Pong);
        assert_eq!(decoded.stream_id, 7);
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn compression_flag_roundtrip() {
        let frame = Frame::new(OpCode::Remember, 1, Bytes::from_static(b"z")).with_compression();
        assert!(frame.is_compressed());
        assert!(!frame.is_json_mode());
        let mut buf = BytesMut::new();
        frame.encode(&mut buf);
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();
        assert!(decoded.is_compressed());
        assert!(!decoded.is_json_mode());
        assert_eq!(decoded.version & VERSION_MASK, PROTOCOL_VERSION);
    }

    #[test]
    fn decode_back_to_back_frames() {
        let mut buf = BytesMut::new();
        Frame::empty(OpCode::Ping, 1).encode(&mut buf);
        Frame::new(OpCode::Remember, 2, Bytes::from_static(b"abc")).encode(&mut buf);

        let first = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(first.opcode, OpCode::Ping);
        assert_eq!(first.stream_id, 1);

        let second = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(second.opcode, OpCode::Remember);
        assert_eq!(second.stream_id, 2);
        assert_eq!(second.payload, Bytes::from_static(b"abc"));

        assert!(Frame::decode(&mut buf).unwrap().is_none());
        assert!(buf.is_empty());
    }
}