use bytes::{Bytes, BytesMut};
use po_transport::traits::{AsyncFrameTransport, TransportError};
use po_wire::{FrameHeader, WireError};
/// Default upper bound, in bytes, on the payload length a frame header may
/// declare before `Framer::read_frame` rejects it (10 MiB).
const DEFAULT_MAX_FRAME_SIZE: u64 = 10 << 20;
/// Reads and writes frames (an encoded `FrameHeader` followed by its payload)
/// over an `AsyncFrameTransport`.
///
/// Writing is stateless. Reading keeps the bytes received from the transport
/// that have not yet been returned as part of a frame, so bytes read past the
/// end of one frame are kept for the next call to `read_frame`.
#[derive(Debug)]
pub struct Framer {
    /// Bytes received from the transport but not yet returned as a frame.
    read_buf: BytesMut,
    /// Largest payload length, in bytes, accepted from a decoded frame header.
    max_frame_size: u64,
}
impl Default for Framer {
fn default() -> Self {
Self::new()
}
}
impl Framer {
    /// Creates a framer with a 64 KiB initial read-buffer capacity and the
    /// default maximum payload size of 10 MiB.
    pub fn new() -> Self {
        Self {
            read_buf: BytesMut::with_capacity(65536),
            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
        }
    }

    /// Sets the largest payload length, in bytes, that [`Framer::read_frame`]
    /// accepts from a frame header.
    ///
    /// A header declaring a larger payload is rejected with
    /// [`WireError::PayloadTooLarge`] as soon as it is decoded.
    #[must_use]
    pub fn with_max_frame_size(mut self, max: u64) -> Self {
        self.max_frame_size = max;
        self
    }

    /// Encodes `header`, appends `payload`, and sends the result to `transport`
    /// in a single `write_all` call.
    ///
    /// The caller is responsible for `header.payload_len` matching
    /// `payload.len()`; this is not checked here.
    ///
    /// # Errors
    ///
    /// Returns [`FramerError::Wire`] if the header fails to encode and
    /// [`FramerError::Transport`] if the transport write fails.
    pub async fn write_frame(
        &self,
        transport: &mut dyn AsyncFrameTransport,
        header: &FrameHeader,
        payload: &[u8],
    ) -> Result<(), FramerError> {
        let header_len = header.encoded_len();
        // Header and payload share one buffer so the transport receives the
        // whole frame from a single `write_all`.
        let mut combined = Vec::with_capacity(header_len + payload.len());
        combined.resize(header_len, 0u8);
        header
            .encode(&mut combined[..header_len])
            .map_err(FramerError::Wire)?;
        combined.extend_from_slice(payload);
        transport
            .write_all(&combined)
            .await
            .map_err(FramerError::Transport)?;
        Ok(())
    }

    /// Reads the next complete frame from `transport`.
    ///
    /// Bytes buffered by earlier calls are consumed first. Any bytes received
    /// beyond the end of the returned frame stay buffered for the next call.
    ///
    /// Returns `Ok(None)` when the transport closes while no bytes are
    /// buffered, i.e. at a clean frame boundary.
    ///
    /// # Errors
    ///
    /// - [`FramerError::Wire`] with [`WireError::PayloadTooLarge`] if a header
    ///   declares a payload above the configured maximum, or one too large to
    ///   address on this platform.
    /// - [`FramerError::Wire`] with [`WireError::Incomplete`] if the transport
    ///   closes partway through a header or a payload.
    /// - [`FramerError::Wire`] for any other header decode error.
    /// - [`FramerError::Transport`] if reading from the transport fails.
    pub async fn read_frame(
        &mut self,
        transport: &mut dyn AsyncFrameTransport,
    ) -> Result<Option<(FrameHeader, Bytes)>, FramerError> {
        loop {
            if let Some((header, header_len)) = self.try_parse_header()? {
                let (payload_len, frame_len) = self.frame_lengths(header_len, header.payload_len)?;
                if self.read_buf.len() >= frame_len {
                    // The header was decoded from a borrowed view, so drop its
                    // bytes here, then split the payload off the buffer.
                    let _ = self.read_buf.split_to(header_len);
                    let payload = self.read_buf.split_to(payload_len).freeze();
                    return Ok(Some((header, payload)));
                }
                let still_needed = frame_len - self.read_buf.len();
                if !self.fill_buffer(transport, still_needed).await? {
                    // The peer closed after a full header but before the whole
                    // payload arrived. Report the truncated frame rather than
                    // presenting it as a clean end of stream.
                    return Err(FramerError::Wire(WireError::Incomplete {
                        needed_min: frame_len,
                        available: self.read_buf.len(),
                    }));
                }
                // More bytes are buffered now; re-decode the header and retry.
                continue;
            }
            // No complete header buffered yet: wait for at least one more byte.
            if !self.fill_buffer(transport, 1).await? {
                if self.read_buf.is_empty() {
                    return Ok(None);
                }
                // NOTE(review): 4 presumably mirrors the minimum encoded header
                // size in `po_wire`; confirm against `FrameHeader::decode`.
                return Err(FramerError::Wire(WireError::Incomplete {
                    needed_min: 4,
                    available: self.read_buf.len(),
                }));
            }
        }
    }

    /// Validates a header's declared payload length and returns
    /// `(payload_len, frame_len)` as `usize`, where `frame_len` also counts the
    /// `header_len` header bytes.
    ///
    /// Rejects payloads above `max_frame_size`, and lengths that do not fit in
    /// `usize` (possible on 32-bit targets when a large maximum is configured),
    /// with [`WireError::PayloadTooLarge`].
    fn frame_lengths(
        &self,
        header_len: usize,
        declared: u64,
    ) -> Result<(usize, usize), FramerError> {
        let too_large = || {
            FramerError::Wire(WireError::PayloadTooLarge {
                declared,
                max_allowed: self.max_frame_size,
            })
        };
        if declared > self.max_frame_size {
            return Err(too_large());
        }
        let payload_len = usize::try_from(declared).map_err(|_| too_large())?;
        let frame_len = header_len.checked_add(payload_len).ok_or_else(too_large)?;
        Ok((payload_len, frame_len))
    }

    /// Tries to decode a frame header from the start of the read buffer.
    ///
    /// Returns `Ok(Some((header, header_len)))` on success. Returns `Ok(None)`
    /// if the buffer is empty or does not yet hold a whole header. Any other
    /// decode error is returned as [`FramerError::Wire`].
    fn try_parse_header(&self) -> Result<Option<(FrameHeader, usize)>, FramerError> {
        if self.read_buf.is_empty() {
            return Ok(None);
        }
        match FrameHeader::decode(&self.read_buf) {
            Ok(parsed) => Ok(Some(parsed)),
            // Not enough bytes yet; the caller reads more and retries.
            Err(WireError::Incomplete { .. }) => Ok(None),
            Err(e) => Err(FramerError::Wire(e)),
        }
    }

    /// Reads from `transport` into the read buffer until at least `min_bytes`
    /// new bytes have been appended.
    ///
    /// Returns `Ok(true)` once that many bytes have arrived. Returns `Ok(false)`
    /// if the transport reports `ConnectionClosed` or returns a zero-length
    /// read first. Bytes received before that point stay buffered.
    async fn fill_buffer(
        &mut self,
        transport: &mut dyn AsyncFrameTransport,
        min_bytes: usize,
    ) -> Result<bool, FramerError> {
        const READ_CHUNK_SIZE: usize = 64 * 1024;
        // Heap-allocated: a 64 KiB array local would live across the `.await`
        // below and be stored inline in this future and in every future that
        // awaits it (clippy `large_futures`).
        let mut scratch = vec![0u8; READ_CHUNK_SIZE];
        let mut received = 0;
        while received < min_bytes {
            match transport.read(&mut scratch).await {
                // A zero-length read into a non-empty buffer would otherwise
                // spin this loop forever; by the usual read convention it means
                // end of stream.
                Ok(0) => return Ok(false),
                Ok(n) => {
                    self.read_buf.extend_from_slice(&scratch[..n]);
                    received += n;
                }
                Err(TransportError::ConnectionClosed) => return Ok(false),
                Err(e) => return Err(FramerError::Transport(e)),
            }
        }
        Ok(true)
    }

    /// Number of bytes received from the transport but not yet returned as
    /// part of a frame.
    pub fn buffered(&self) -> usize {
        self.read_buf.len()
    }
}
/// Errors produced by `Framer` operations.
#[derive(Debug)]
pub enum FramerError {
/// Frame encoding or decoding failed, or a frame broke a framing rule
/// (e.g. declared payload too large, stream ended partway through a header).
Wire(WireError),
/// The underlying transport failed to read or write.
Transport(TransportError),
}
impl std::fmt::Display for FramerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Wire(e) => write!(f, "wire: {e}"),
Self::Transport(e) => write!(f, "transport: {e}"),
}
}
}
// `source()` keeps its default (`None`); the inner error's message is already
// included in this type's `Display` output.
impl std::error::Error for FramerError {}
// Round-trip tests: frames written by one `Framer` through one end of an
// in-memory transport pair are read back by another `Framer` on the other end.
#[cfg(test)]
mod tests {
use super::*;
use po_transport::MemoryTransport;
use po_wire::FrameType;
/// A data frame's type and payload survive a write/read round trip.
#[tokio::test]
async fn write_and_read_data_frame() {
let (mut a, mut b) = MemoryTransport::pair(64);
let framer_w = Framer::new();
let mut framer_r = Framer::new();
let payload = b"Hello Protocol Orzatty!";
let header = FrameHeader::data(0, payload.len() as u64);
framer_w
.write_frame(&mut a, &header, payload)
.await
.unwrap();
let (recv_header, recv_payload) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
assert_eq!(recv_header.frame_type, FrameType::Data);
assert_eq!(recv_payload.as_ref(), payload);
}
/// A control frame with an empty payload keeps its type and control flag.
#[tokio::test]
async fn write_and_read_control_frame() {
let (mut a, mut b) = MemoryTransport::pair(64);
let framer_w = Framer::new();
let mut framer_r = Framer::new();
let header = FrameHeader::control(FrameType::Ping);
framer_w.write_frame(&mut a, &header, &[]).await.unwrap();
let (recv_header, recv_payload) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
assert_eq!(recv_header.frame_type, FrameType::Ping);
assert!(recv_header.flags.control);
assert!(recv_payload.is_empty());
}
/// Ten frames of different sizes, all written before any read, come back in
/// order with the right channel ids and contents. This exercises the read
/// buffer keeping bytes that belong to later frames.
#[tokio::test]
async fn multiple_frames_sequential() {
let (mut a, mut b) = MemoryTransport::pair(64);
let framer_w = Framer::new();
let mut framer_r = Framer::new();
for i in 0u8..10 {
let payload = vec![i; (i as usize + 1) * 10];
let header = FrameHeader::data(i as u32, payload.len() as u64);
framer_w
.write_frame(&mut a, &header, &payload)
.await
.unwrap();
}
for i in 0u8..10 {
let (h, p) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
assert_eq!(h.channel_id, i as u32);
assert_eq!(p.len(), (i as usize + 1) * 10);
assert!(p.iter().all(|&b| b == i));
}
}
/// With the writer end dropped and nothing sent, `read_frame` reports a clean
/// end of stream (`Ok(None)`). This presumes that dropping `a` makes `b`'s
/// reads fail with `ConnectionClosed`.
#[tokio::test]
async fn eof_returns_none() {
let (a, mut b) = MemoryTransport::pair(64);
let mut framer_r = Framer::new();
drop(a);
let result = framer_r.read_frame(&mut b).await.unwrap();
assert!(result.is_none());
}
/// A 100,000-byte payload is larger than one 64 KiB read chunk, so reading it
/// takes several transport reads.
#[tokio::test]
async fn large_payload() {
let (mut a, mut b) = MemoryTransport::pair(256);
let framer_w = Framer::new();
let mut framer_r = Framer::new();
let payload = vec![0xAB; 100_000]; let header = FrameHeader::data(1, payload.len() as u64);
framer_w
.write_frame(&mut a, &header, &payload)
.await
.unwrap();
let (h, p) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
assert_eq!(h.payload_len, 100_000);
assert_eq!(p.as_ref(), payload.as_slice());
}
}