use super::frame::MAX_FRAME_SIZE;
use std::io::{ErrorKind, Read};
use zamsync_core::{ZamError, ZamResult};
use zstd;
/// Accumulates bytes read from a stream and splits them into length-prefixed frames.
///
/// Bytes that arrive beyond the end of the current frame are retained so that the next
/// call can return the following frame without reading again.
#[derive(Debug)]
pub struct FrameBuffer {
    /// Bytes received from the stream that have not yet been consumed as a frame.
    buf: Vec<u8>,
}
impl Default for FrameBuffer {
fn default() -> Self {
Self::new()
}
}
impl FrameBuffer {
    /// Creates an empty frame buffer.
    pub fn new() -> Self {
        Self { buf: Vec::new() }
    }

    /// Tries to produce the next complete frame payload, reading from `stream` if needed.
    ///
    /// Wire format as parsed here:
    /// - a 4-byte big-endian length `L`;
    /// - then `L` bytes: a 1-byte flag followed by `L - 1` bytes of body.
    ///
    /// A length of 0 yields an empty payload and carries no flag byte.
    ///
    /// If a complete frame is already buffered it is returned without touching `stream`.
    /// Otherwise `stream` is read in chunks until one of these happens:
    /// - a complete frame is buffered;
    /// - the stream reports `WouldBlock` or `TimedOut`;
    /// - the stream reaches end-of-file.
    ///
    /// Bytes that belong to later frames are kept for subsequent calls.
    ///
    /// Returns `Ok(Some(payload))` for a complete frame, or `Ok(None)` when more bytes are
    /// needed.
    ///
    /// # Errors
    /// - `ZamError::Io` with `UnexpectedEof` if the stream is at end-of-file and this call
    ///   received no new bytes.
    /// - `ZamError::Io` for any other read error. `Interrupted` is not reported; it is
    ///   retried instead.
    /// - `ZamError::Protocol` for an oversized frame, an unknown flag, or a zstd failure.
    pub fn try_read_frame(&mut self, stream: &mut impl Read) -> ZamResult<Option<Vec<u8>>> {
        if let Some(frame) = self.try_consume_frame()? {
            return Ok(Some(frame));
        }
        let mut chunk = [0u8; 8192];
        let mut got_new_bytes = false;
        loop {
            match stream.read(&mut chunk) {
                Ok(0) => {
                    if !got_new_bytes {
                        return Err(ZamError::Io(std::io::Error::new(
                            ErrorKind::UnexpectedEof,
                            "connection closed by peer",
                        )));
                    }
                    // Every chunk read so far was already checked for a complete frame, so
                    // whatever remains is partial. The next call observes EOF and errors.
                    return Ok(None);
                }
                Ok(n) => {
                    self.buf.extend_from_slice(&chunk[..n]);
                    got_new_bytes = true;
                    // Return as soon as a whole frame is available instead of reading on
                    // until the stream blocks or times out. That extra read would delay
                    // delivery by a full read timeout on timed sockets. It would also let a
                    // chatty peer keep this call looping while `buf` grows.
                    if let Some(frame) = self.try_consume_frame()? {
                        return Ok(Some(frame));
                    }
                }
                // `Read::read` documents `Interrupted` as retryable, so try again.
                Err(e) if e.kind() == ErrorKind::Interrupted => {}
                // Nothing more available right now, and the buffered bytes (if any) are
                // not yet a complete frame.
                Err(e) if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::TimedOut) => {
                    return Ok(None);
                }
                Err(e) => return Err(ZamError::Io(e)),
            }
        }
    }

    /// Removes and decodes the first complete frame in `buf`, if one is present.
    ///
    /// Returns `Ok(None)`, leaving `buf` untouched, while fewer bytes than a full frame are
    /// buffered. A frame's bytes are removed before its flag is interpreted. As a result, a
    /// frame with an unknown flag or an undecodable zstd body is dropped, and any following
    /// frames stay aligned.
    ///
    /// # Errors
    /// `ZamError::Protocol` in these cases:
    /// - the declared length exceeds `MAX_FRAME_SIZE` (the buffer is left as-is);
    /// - the flag byte is unknown;
    /// - zstd decompression fails.
    fn try_consume_frame(&mut self) -> ZamResult<Option<Vec<u8>>> {
        const FLAG_RAW: u8 = 0x00;
        const FLAG_ZSTD: u8 = 0x01;

        if self.buf.len() < 4 {
            return Ok(None);
        }
        // The length prefix counts the flag byte plus the body, but not itself.
        let total_len =
            u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]) as usize;
        if total_len == 0 {
            self.buf.drain(..4);
            return Ok(Some(vec![]));
        }
        // Reject before waiting for the body so a bogus length cannot make us buffer it.
        if total_len as u64 > MAX_FRAME_SIZE as u64 {
            return Err(ZamError::Protocol(format!(
                "received frame too large: {} bytes (max {})",
                total_len, MAX_FRAME_SIZE
            )));
        }
        let frame_end = 4 + total_len;
        if self.buf.len() < frame_end {
            return Ok(None);
        }
        let flag = self.buf[4];
        // Remove the whole frame in one pass and keep only the body (skip length + flag).
        let body: Vec<u8> = self.buf.drain(..frame_end).skip(5).collect();
        let payload = match flag {
            FLAG_RAW => body,
            // NOTE(review): `decode_all` does not cap the decompressed size. A small
            // compressed body could expand to a very large payload. Consider bounding the
            // decoder output if peers are untrusted.
            FLAG_ZSTD => zstd::decode_all(body.as_slice())
                .map_err(|e| ZamError::Protocol(format!("zstd decompress: {e}")))?,
            other => {
                return Err(ZamError::Protocol(format!(
                    "unknown frame flag: 0x{other:02x}"
                )))
            }
        };
        Ok(Some(payload))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::frame::write_frame;
    use std::io::Cursor;

    /// Encodes `payload` with the crate's own frame writer.
    fn make_frame(payload: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).unwrap();
        buf
    }

    #[test]
    fn test_complete_frame_at_once() {
        let payload = b"hello from bhutan";
        let wire = make_frame(payload);
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire)).unwrap();
        assert_eq!(result, Some(payload.to_vec()));
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_two_frames_back_to_back() {
        let wire1 = make_frame(b"frame-one");
        let wire2 = make_frame(b"frame-two");
        let mut combined = wire1.clone();
        combined.extend_from_slice(&wire2);
        let mut fb = FrameBuffer::new();
        let r1 = fb.try_read_frame(&mut Cursor::new(&combined)).unwrap();
        assert_eq!(r1, Some(b"frame-one".to_vec()));
        let r2 = fb.try_read_frame(&mut Cursor::new(&[])).unwrap();
        assert_eq!(r2, Some(b"frame-two".to_vec()));
    }

    #[test]
    fn test_partial_header_returns_none() {
        let wire = make_frame(b"some data");
        let partial = &wire[..2];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(partial)).unwrap();
        assert!(result.is_none());
        assert_eq!(fb.buf.len(), 2);
    }

    #[test]
    fn test_partial_body_returns_none() {
        let wire = make_frame(b"some longer payload that has many bytes");
        let partial = &wire[..wire.len() - 5];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(partial)).unwrap();
        assert!(result.is_none());
        assert_eq!(fb.buf.len(), partial.len());
    }

    #[test]
    fn test_split_delivery_reassembles_frame() {
        let payload = b"patient-record-from-rural-bhutan";
        let wire = make_frame(payload);
        let mid = wire.len() / 2;
        let mut fb = FrameBuffer::new();
        let r1 = fb.try_read_frame(&mut Cursor::new(&wire[..mid])).unwrap();
        assert!(r1.is_none());
        let r2 = fb.try_read_frame(&mut Cursor::new(&wire[mid..])).unwrap();
        assert_eq!(r2, Some(payload.to_vec()));
    }

    #[test]
    fn test_empty_reader_on_empty_buffer_is_eof() {
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&[]));
        assert!(matches!(result, Err(ZamError::Io(_))));
    }

    #[test]
    fn test_zero_length_frame_yields_empty_payload() {
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new([0u8, 0, 0, 0])).unwrap();
        assert_eq!(result, Some(vec![]));
        assert!(fb.buf.is_empty());
    }

    #[test]
    fn test_unknown_flag_is_protocol_error_and_frame_is_dropped() {
        // The first frame has length 2: flag 0x7f plus one body byte. It is followed by a
        // valid raw (flag 0x00) frame, which must still be readable afterwards.
        let wire = [0u8, 0, 0, 2, 0x7f, b'x', 0, 0, 0, 2, 0x00, b'y'];
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(&wire));
        assert!(matches!(result, Err(ZamError::Protocol(_))));
        let next = fb.try_read_frame(&mut Cursor::new(&[])).unwrap();
        assert_eq!(next, Some(b"y".to_vec()));
    }

    #[test]
    fn test_oversized_frame_is_protocol_error() {
        // Assumes MAX_FRAME_SIZE < u32::MAX, so MAX_FRAME_SIZE + 1 fits in the length prefix.
        let too_big = (MAX_FRAME_SIZE as u64 + 1) as u32;
        let mut fb = FrameBuffer::new();
        let result = fb.try_read_frame(&mut Cursor::new(too_big.to_be_bytes()));
        assert!(matches!(result, Err(ZamError::Protocol(_))));
    }
}