use bincode::Options;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tokio::io::AsyncReadExt;
/// Errors produced by the framing and (de)serialization helpers in this module.
#[derive(Error, Debug)]
pub enum ProtocolError {
    /// A frame exceeded the size limit `max`. `size` is the offending byte
    /// count, for example a length declared in a frame header, or the number
    /// of undecoded bytes a `FrameReader` has buffered.
    #[error("frame too large: {size} bytes (max {max})")]
    FrameTooLarge { size: usize, max: usize },
    /// The serialized payload length does not fit in the `u32` length prefix.
    #[error("message too large to encode: {size} bytes exceeds u32 max")]
    EncodeTooLarge { size: usize },
    /// Any bincode error converted with `?`. Despite the name, failures raised
    /// while serializing in `encode` can also end up here.
    #[error("deserialization failed: {0}")]
    Deserialize(#[from] bincode::Error),
    /// An I/O error from the underlying stream. `read_one_message` also uses
    /// this variant, with `ErrorKind::UnexpectedEof`, when the stream closes
    /// before a complete frame has arrived.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
/// Largest payload, excluding the 4-byte length prefix, that one frame may carry (16 MiB).
pub const MAX_FRAME_SIZE: usize = 16 << 20;
/// Capacity of the scratch buffer that each individual stream read lands in (64 KiB).
pub const READ_BUF_SIZE: usize = 64 * 1024;
/// Returns the bincode options shared by every encode/decode in this module.
///
/// Integers use fixed-width encoding rather than varint. Any single
/// (de)serialization is capped at [`MAX_FRAME_SIZE`] bytes.
pub fn bincode_config() -> impl Options + Copy {
    let byte_limit = MAX_FRAME_SIZE as u64;
    let fixed_width = bincode::DefaultOptions::new().with_fixint_encoding();
    fixed_width.with_limit(byte_limit)
}
pub fn encode(msg: &impl Serialize) -> Result<Vec<u8>, ProtocolError> {
let data = bincode_config().serialize(msg)?;
let len = u32::try_from(data.len())
.map_err(|_| ProtocolError::EncodeTooLarge { size: data.len() })?;
let mut buf = Vec::with_capacity(4 + data.len());
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(&data);
Ok(buf)
}
pub fn decode<T: DeserializeOwned>(data: &[u8]) -> Result<T, ProtocolError> {
Ok(bincode_config().deserialize(data)?)
}
/// Tries to split one length-prefixed frame off the front of `buf`.
///
/// A frame is a big-endian `u32` payload length followed by that many payload
/// bytes.
///
/// Returns:
/// - `Ok(None)` while `buf` does not yet hold the whole frame.
/// - `Ok(Some((payload, consumed)))` once it does. `consumed` counts the
///   header plus the payload.
///
/// # Errors
///
/// [`ProtocolError::FrameTooLarge`] if the header declares more than
/// [`MAX_FRAME_SIZE`] payload bytes.
pub fn decode_frame(buf: &[u8]) -> Result<Option<(&[u8], usize)>, ProtocolError> {
    const HEADER_LEN: usize = 4;

    if buf.len() < HEADER_LEN {
        return Ok(None);
    }
    let (header, body) = buf.split_at(HEADER_LEN);
    let header: [u8; HEADER_LEN] = header
        .try_into()
        .expect("split_at(HEADER_LEN) yields exactly HEADER_LEN bytes");
    let declared_len = u32::from_be_bytes(header) as usize;

    // Checked before the completeness test, so an oversized header is rejected
    // without waiting for its payload to arrive.
    if declared_len > MAX_FRAME_SIZE {
        return Err(ProtocolError::FrameTooLarge {
            size: declared_len,
            max: MAX_FRAME_SIZE,
        });
    }

    Ok(body
        .get(..declared_len)
        .map(|payload| (payload, HEADER_LEN + declared_len)))
}
pub async fn read_one_message<T: DeserializeOwned>(
reader: &mut (impl AsyncReadExt + Unpin),
) -> Result<T, ProtocolError> {
let mut frames = FrameReader::new();
loop {
if !frames.fill_from(reader).await? {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"connection closed",
)
.into());
}
if let Some(msg) = frames.decode_next()? {
debug_assert!(
frames.into_leftover().is_empty(),
"read_one_message dropped buffered bytes after the first frame; \
it must not be used on streams carrying multiple messages",
);
return Ok(msg);
}
}
}
/// Incremental decoder for a stream of length-prefixed frames.
///
/// `fill_from` appends stream bytes to `read_buf`. `decode_next` consumes whole
/// frames by advancing `offset`. Consumed bytes are physically removed from
/// `read_buf` only at the start of the next `fill_from`.
pub struct FrameReader {
    // Buffered stream bytes; `read_buf[offset..]` has not been decoded yet.
    read_buf: Vec<u8>,
    // Index of the first undecoded byte in `read_buf`.
    offset: usize,
    // Scratch buffer of `READ_BUF_SIZE` bytes that each read lands in before
    // being appended to `read_buf`.
    tmp_buf: Vec<u8>,
}
impl FrameReader {
    /// Creates a reader with an empty buffer.
    pub fn new() -> Self {
        Self::with_leftover(Vec::new())
    }

    /// Creates a reader whose buffer starts with `leftover`.
    ///
    /// Typical input is the bytes returned by [`FrameReader::into_leftover`]
    /// from a previous reader on the same stream.
    pub fn with_leftover(leftover: Vec<u8>) -> Self {
        Self {
            read_buf: leftover,
            offset: 0,
            tmp_buf: vec![0u8; READ_BUF_SIZE],
        }
    }

    /// Performs one read from `reader` and appends the bytes to the buffer.
    ///
    /// Bytes already consumed by [`FrameReader::decode_next`] are discarded
    /// first. Returns `Ok(false)` when the read yields 0 bytes (EOF), and
    /// `Ok(true)` otherwise.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::Io`] if the read fails.
    /// - [`ProtocolError::FrameTooLarge`] if the undecoded data exceeds
    ///   `MAX_FRAME_SIZE + 4 + READ_BUF_SIZE` bytes. Here `size` is the
    ///   buffered byte count, not a frame length.
    pub async fn fill_from<R: AsyncReadExt + Unpin>(
        &mut self,
        reader: &mut R,
    ) -> Result<bool, ProtocolError> {
        if self.offset > 0 {
            self.read_buf.drain(..self.offset);
            self.offset = 0;
        }
        let n = reader.read(&mut self.tmp_buf).await?;
        if n == 0 {
            return Ok(false);
        }
        self.read_buf.extend_from_slice(&self.tmp_buf[..n]);
        // A caller that runs `decode_next` until it returns `None` after every
        // fill holds at most one partial frame (< 4 + MAX_FRAME_SIZE bytes)
        // plus one read (<= READ_BUF_SIZE). This bound therefore trips only
        // when buffered frames are not being drained.
        if self.read_buf.len() > MAX_FRAME_SIZE + 4 + READ_BUF_SIZE {
            return Err(ProtocolError::FrameTooLarge {
                size: self.read_buf.len(),
                max: MAX_FRAME_SIZE,
            });
        }
        Ok(true)
    }

    /// Decodes the next complete frame from the buffer, if there is one.
    ///
    /// Returns `Ok(None)` when more bytes are needed. The read position only
    /// advances after the payload deserializes successfully, so a failing
    /// frame is left in place.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::FrameTooLarge`] for an oversized frame header.
    /// - [`ProtocolError::Deserialize`] if the payload is not a valid `T`.
    pub fn decode_next<T: DeserializeOwned>(&mut self) -> Result<Option<T>, ProtocolError> {
        match decode_frame(&self.read_buf[self.offset..])? {
            Some((data, consumed)) => {
                let msg: T = decode(data)?;
                self.offset += consumed;
                Ok(Some(msg))
            }
            None => Ok(None),
        }
    }

    /// Consumes the reader and returns the bytes that have not been decoded yet.
    pub fn into_leftover(mut self) -> Vec<u8> {
        // The reader is consumed, so reuse its allocation rather than copying
        // the tail. Drop the already-decoded prefix in place and hand back the
        // buffer.
        self.read_buf.drain(..self.offset);
        self.read_buf
    }
}

impl Default for FrameReader {
    /// Equivalent to [`FrameReader::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::messages::{
        ClientMsg, ConnectMode, ServerMsg, SessionInfo, PROTOCOL_VERSION,
    };
    // Needed by the async tests that write into a `tokio::io::duplex` pipe.
    use tokio::io::AsyncWriteExt;

    #[test]
    fn encode_decode_round_trip() {
        let msg = ClientMsg::Connect {
            version: PROTOCOL_VERSION,
            name: "test".into(),
            history: 1000,
            cols: 80,
            rows: 24,
            mode: ConnectMode::CreateOrAttach,
        };
        let encoded = encode(&msg).unwrap();
        let (data, consumed) = decode_frame(&encoded).unwrap().unwrap();
        assert_eq!(consumed, encoded.len());
        let decoded: ClientMsg = decode(data).unwrap();
        match decoded {
            ClientMsg::Connect {
                version,
                name,
                history,
                cols,
                rows,
                mode,
            } => {
                assert_eq!(version, PROTOCOL_VERSION);
                assert_eq!(name, "test");
                assert_eq!(history, 1000);
                assert_eq!(cols, 80);
                assert_eq!(rows, 24);
                assert_eq!(mode, ConnectMode::CreateOrAttach);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn encode_decode_server_msg() {
        let msg = ServerMsg::SessionList(vec![SessionInfo {
            name: "s1".into(),
            pid: 123,
            cols: 80,
            rows: 24,
        }]);
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ServerMsg = decode(data).unwrap();
        match decoded {
            ServerMsg::SessionList(list) => {
                assert_eq!(list.len(), 1);
                assert_eq!(list[0].name, "s1");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn decode_incomplete_frame() {
        let msg = ClientMsg::Detach;
        let encoded = encode(&msg).unwrap();
        // Only part of the 4-byte header.
        let result = decode_frame(&encoded[..3]).unwrap();
        assert!(result.is_none());
        // Full header, payload one byte short.
        let result = decode_frame(&encoded[..encoded.len() - 1]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn decode_rejects_oversized_frame() {
        let len_bytes = ((MAX_FRAME_SIZE + 1) as u32).to_be_bytes();
        let mut buf = Vec::new();
        buf.extend_from_slice(&len_bytes);
        buf.extend_from_slice(&[0u8; 100]);
        let result = decode_frame(&buf);
        assert!(result.is_err());
        match result.unwrap_err() {
            ProtocolError::FrameTooLarge { size, max } => {
                assert_eq!(size, MAX_FRAME_SIZE + 1);
                assert_eq!(max, MAX_FRAME_SIZE);
            }
            other => panic!("expected FrameTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn decode_accepts_max_size_frame() {
        // A header declaring exactly MAX_FRAME_SIZE must not be rejected.
        // With no payload yet, the frame is simply incomplete.
        let len_bytes = (MAX_FRAME_SIZE as u32).to_be_bytes();
        let mut buf = Vec::new();
        buf.extend_from_slice(&len_bytes);
        let result = decode_frame(&buf).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn encode_multiple_decode_sequential() {
        let msg1 = ClientMsg::Detach;
        let msg2 = ClientMsg::ListSessions {
            version: PROTOCOL_VERSION,
        };
        let mut buf = encode(&msg1).unwrap();
        buf.extend_from_slice(&encode(&msg2).unwrap());
        let (data1, consumed1) = decode_frame(&buf).unwrap().unwrap();
        let _: ClientMsg = decode(data1).unwrap();
        let (data2, _) = decode_frame(&buf[consumed1..]).unwrap().unwrap();
        let _: ClientMsg = decode(data2).unwrap();
    }

    #[tokio::test]
    async fn read_one_message_success() {
        let msg = ClientMsg::Detach;
        let encoded = encode(&msg).unwrap();
        let (mut write_half, mut read_half) = tokio::io::duplex(65536);
        write_half.write_all(&encoded).await.unwrap();
        drop(write_half);
        let result: ClientMsg = read_one_message(&mut read_half).await.unwrap();
        match result {
            ClientMsg::Detach => {}
            other => panic!("expected Detach, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn read_one_message_connection_closed() {
        let (write_half, mut read_half) = tokio::io::duplex(65536);
        drop(write_half);
        let result: Result<ClientMsg, _> = read_one_message(&mut read_half).await;
        assert!(result.is_err(), "expected error on empty stream");
        match result.unwrap_err() {
            ProtocolError::Io(e) => {
                assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
            }
            other => panic!("expected Io error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn read_one_message_server_msg() {
        let msg = ServerMsg::Connected {
            name: "my-session".into(),
            new_session: true,
        };
        let encoded = encode(&msg).unwrap();
        let (mut write_half, mut read_half) = tokio::io::duplex(65536);
        write_half.write_all(&encoded).await.unwrap();
        drop(write_half);
        let result: ServerMsg = read_one_message(&mut read_half).await.unwrap();
        match result {
            ServerMsg::Connected { name, new_session } => {
                assert_eq!(name, "my-session");
                assert!(new_session);
            }
            other => panic!("expected Connected, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn read_one_message_rejects_buffer_overflow() {
        let (mut write_half, mut read_half) = tokio::io::duplex(65536);
        let len_bytes = (MAX_FRAME_SIZE as u32).to_be_bytes();
        write_half.write_all(&len_bytes).await.unwrap();
        let junk = vec![0u8; MAX_FRAME_SIZE + 2 * READ_BUF_SIZE];
        tokio::spawn(async move {
            let _ = write_half.write_all(&junk).await;
        });
        // The max-size frame of zeros completes before the overflow guard can
        // trip, so either the guard or payload deserialization may fail first.
        let result: Result<ClientMsg, _> = read_one_message(&mut read_half).await;
        assert!(
            result.is_err(),
            "should reject oversized buffer accumulation"
        );
        match result.unwrap_err() {
            ProtocolError::FrameTooLarge { .. } | ProtocolError::Deserialize(_) => {}
            other => panic!("expected FrameTooLarge or Deserialize, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn fill_from_trips_tightened_overflow_bound() {
        let (mut write_half, mut read_half) = tokio::io::duplex(READ_BUF_SIZE);
        let len_bytes = (MAX_FRAME_SIZE as u32).to_be_bytes();
        tokio::spawn(async move {
            let _ = write_half.write_all(&len_bytes).await;
            let chunk = vec![0u8; READ_BUF_SIZE];
            for _ in 0..((MAX_FRAME_SIZE / READ_BUF_SIZE) + 4) {
                if write_half.write_all(&chunk).await.is_err() {
                    break;
                }
            }
        });
        // Never call `decode_next`, so nothing is drained and the buffer keeps
        // growing until the guard in `fill_from` fires.
        let mut reader = FrameReader::new();
        let mut tripped = false;
        loop {
            match reader.fill_from(&mut read_half).await {
                Ok(true) => {
                    if reader.read_buf.len() > MAX_FRAME_SIZE + 4 + READ_BUF_SIZE {
                        panic!("buffer exceeded bound without tripping the guard");
                    }
                }
                Ok(false) => break,
                Err(ProtocolError::FrameTooLarge { size, max }) => {
                    assert!(size > MAX_FRAME_SIZE + 4 + READ_BUF_SIZE);
                    assert_eq!(max, MAX_FRAME_SIZE);
                    tripped = true;
                    break;
                }
                Err(other) => panic!("unexpected error: {:?}", other),
            }
        }
        assert!(tripped, "overflow guard should have tripped");
    }

    #[tokio::test]
    async fn fill_from_accepts_up_to_tightened_bound() {
        let (mut write_half, mut read_half) = tokio::io::duplex(READ_BUF_SIZE);
        let payload = vec![0u8; MAX_FRAME_SIZE];
        let len_bytes = (MAX_FRAME_SIZE as u32).to_be_bytes();
        tokio::spawn(async move {
            let _ = write_half.write_all(&len_bytes).await;
            let _ = write_half.write_all(&payload).await;
        });
        let mut reader = FrameReader::new();
        loop {
            if !reader.fill_from(&mut read_half).await.unwrap() {
                break;
            }
            if reader.read_buf.len() - reader.offset >= 4 + MAX_FRAME_SIZE {
                break;
            }
        }
        let (data, consumed) = decode_frame(&reader.read_buf[reader.offset..])
            .unwrap()
            .unwrap();
        assert_eq!(data.len(), MAX_FRAME_SIZE);
        assert_eq!(consumed, 4 + MAX_FRAME_SIZE);
    }

    #[test]
    fn decode_frame_zero_length() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&0u32.to_be_bytes());
        let result = decode_frame(&buf).unwrap();
        assert!(result.is_some());
        let (data, consumed) = result.unwrap();
        assert_eq!(data.len(), 0);
        assert_eq!(consumed, 4);
    }

    #[test]
    fn decode_frame_empty_buffer() {
        let result = decode_frame(&[]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn decode_frame_1_byte_buffer() {
        let result = decode_frame(&[0x00]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn decode_frame_3_byte_buffer() {
        let result = decode_frame(&[0x00, 0x00, 0x00]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn frame_reader_with_leftover() {
        let msg = ClientMsg::Detach;
        let encoded = encode(&msg).unwrap();
        let mut reader = FrameReader::with_leftover(encoded);
        let result: Option<ClientMsg> = reader.decode_next().unwrap();
        assert!(result.is_some());
        match result.unwrap() {
            ClientMsg::Detach => {}
            other => panic!("expected Detach, got {:?}", other),
        }
    }

    #[test]
    fn frame_reader_multiple_messages_in_buffer() {
        let msg1 = ClientMsg::Detach;
        let msg2 = ClientMsg::ListSessions {
            version: PROTOCOL_VERSION,
        };
        let msg3 = ClientMsg::RefreshScreen;
        let mut buf = encode(&msg1).unwrap();
        buf.extend_from_slice(&encode(&msg2).unwrap());
        buf.extend_from_slice(&encode(&msg3).unwrap());
        let mut reader = FrameReader::with_leftover(buf);
        let r1: Option<ClientMsg> = reader.decode_next().unwrap();
        assert!(matches!(r1, Some(ClientMsg::Detach)));
        let r2: Option<ClientMsg> = reader.decode_next().unwrap();
        assert!(matches!(r2, Some(ClientMsg::ListSessions { .. })));
        let r3: Option<ClientMsg> = reader.decode_next().unwrap();
        assert!(matches!(r3, Some(ClientMsg::RefreshScreen)));
        let r4: Option<ClientMsg> = reader.decode_next().unwrap();
        assert!(r4.is_none());
    }

    #[test]
    fn frame_reader_leftover_after_decode() {
        let msg = ClientMsg::Detach;
        let mut buf = encode(&msg).unwrap();
        buf.extend_from_slice(&[0xDE, 0xAD]);
        let mut reader = FrameReader::with_leftover(buf);
        let _: ClientMsg = reader.decode_next().unwrap().unwrap();
        let leftover = reader.into_leftover();
        assert_eq!(leftover, &[0xDE, 0xAD]);
    }

    #[test]
    fn encode_all_client_msg_variants() {
        let messages: Vec<ClientMsg> = vec![
            ClientMsg::Input(vec![0x61, 0x62]),
            ClientMsg::Resize {
                cols: 120,
                rows: 40,
            },
            ClientMsg::Detach,
            ClientMsg::ListSessions {
                version: PROTOCOL_VERSION,
            },
            ClientMsg::Connect {
                version: PROTOCOL_VERSION,
                name: "test".into(),
                history: 500,
                cols: 80,
                rows: 24,
                mode: ConnectMode::CreateOrAttach,
            },
            ClientMsg::KillSession {
                version: PROTOCOL_VERSION,
                name: "kill-me".into(),
            },
            ClientMsg::RefreshScreen,
        ];
        for msg in &messages {
            let encoded = encode(msg).unwrap();
            let (data, _) = decode_frame(&encoded).unwrap().unwrap();
            let _decoded: ClientMsg = decode(data).unwrap();
        }
    }

    #[test]
    fn encode_all_server_msg_variants() {
        let messages: Vec<ServerMsg> = vec![
            ServerMsg::ScreenUpdate(vec![0x1b, 0x5b, 0x48]),
            ServerMsg::History(vec![vec![0x41], vec![0x42]]),
            ServerMsg::SessionList(vec![SessionInfo {
                name: "s1".into(),
                pid: 42,
                cols: 80,
                rows: 24,
            }]),
            ServerMsg::SessionEnded { exit_code: Some(0) },
            ServerMsg::SessionEnded { exit_code: None },
            ServerMsg::Error("test error".into()),
            ServerMsg::Connected {
                name: "test".into(),
                new_session: true,
            },
            ServerMsg::SessionKilled {
                name: "dead".into(),
            },
            ServerMsg::Passthrough(vec![0x07]),
        ];
        for msg in &messages {
            let encoded = encode(msg).unwrap();
            let (data, _) = decode_frame(&encoded).unwrap().unwrap();
            let _decoded: ServerMsg = decode(data).unwrap();
        }
    }

    #[test]
    fn encode_empty_collections() {
        let msg = ServerMsg::History(vec![]);
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ServerMsg = decode(data).unwrap();
        match decoded {
            ServerMsg::History(lines) => assert!(lines.is_empty()),
            other => panic!("expected History, got {:?}", other),
        }
        let msg = ServerMsg::SessionList(vec![]);
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ServerMsg = decode(data).unwrap();
        match decoded {
            ServerMsg::SessionList(list) => assert!(list.is_empty()),
            other => panic!("expected SessionList, got {:?}", other),
        }
    }

    #[test]
    fn encode_large_input() {
        let data = vec![0x41u8; 65536];
        let msg = ClientMsg::Input(data.clone());
        let encoded = encode(&msg).unwrap();
        let (frame_data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ClientMsg = decode(frame_data).unwrap();
        match decoded {
            ClientMsg::Input(d) => assert_eq!(d.len(), 65536),
            other => panic!("expected Input, got {:?}", other),
        }
    }

    #[test]
    fn encode_decode_connect_mode_create_only() {
        let msg = ClientMsg::Connect {
            version: PROTOCOL_VERSION,
            name: "new-session".into(),
            history: 500,
            cols: 120,
            rows: 40,
            mode: ConnectMode::CreateOnly,
        };
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ClientMsg = decode(data).unwrap();
        match decoded {
            ClientMsg::Connect { mode, .. } => assert_eq!(mode, ConnectMode::CreateOnly),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn encode_decode_connect_mode_attach_only() {
        let msg = ClientMsg::Connect {
            version: PROTOCOL_VERSION,
            name: "existing".into(),
            history: 0,
            cols: 80,
            rows: 24,
            mode: ConnectMode::AttachOnly,
        };
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ClientMsg = decode(data).unwrap();
        match decoded {
            ClientMsg::Connect { mode, .. } => assert_eq!(mode, ConnectMode::AttachOnly),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn encode_decode_session_ended_with_code() {
        for code in [Some(0), Some(42), Some(-1), None] {
            let msg = ServerMsg::SessionEnded { exit_code: code };
            let encoded = encode(&msg).unwrap();
            let (data, _) = decode_frame(&encoded).unwrap().unwrap();
            let decoded: ServerMsg = decode(data).unwrap();
            match decoded {
                ServerMsg::SessionEnded { exit_code } => assert_eq!(exit_code, code),
                other => panic!("expected SessionEnded, got {:?}", other),
            }
        }
    }

    #[test]
    fn encode_decode_carries_protocol_version() {
        let msg = ClientMsg::ListSessions {
            version: PROTOCOL_VERSION,
        };
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ClientMsg = decode(data).unwrap();
        match decoded {
            ClientMsg::ListSessions { version } => assert_eq!(version, PROTOCOL_VERSION),
            other => panic!("expected ListSessions, got {:?}", other),
        }
        // A mismatched version must survive the round trip unchanged, so the
        // receiver can detect the mismatch itself.
        let msg = ClientMsg::KillSession {
            version: PROTOCOL_VERSION + 1,
            name: "s".into(),
        };
        let encoded = encode(&msg).unwrap();
        let (data, _) = decode_frame(&encoded).unwrap().unwrap();
        let decoded: ClientMsg = decode(data).unwrap();
        match decoded {
            ClientMsg::KillSession { version, .. } => {
                assert_eq!(version, PROTOCOL_VERSION + 1);
            }
            other => panic!("expected KillSession, got {:?}", other),
        }
    }

    #[test]
    fn decode_rejects_corrupted_payload() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&[0xFF; 10]);
        let (data, _) = decode_frame(&buf).unwrap().unwrap();
        let result: Result<ClientMsg, _> = decode(data);
        assert!(
            result.is_err(),
            "corrupted payload should fail deserialization"
        );
        match result.unwrap_err() {
            ProtocolError::Deserialize(_) => {}
            other => panic!("expected Deserialize error, got {:?}", other),
        }
    }
}