use std::io::{self, Read, Write};
pub use vmette_proto::{Action, ResponseHeader, ScrollDirection};
/// Largest header, in bytes, a frame may carry (1 MiB); `read_header` rejects longer ones.
const MAX_HEADER_LEN: u32 = 1 << 20;
/// Largest payload, in bytes, a frame may carry (64 MiB); `read_payload` rejects longer ones.
const MAX_PAYLOAD_LEN: u32 = 64 << 20;

/// Writes one frame to `w` and flushes it.
///
/// Wire layout: `req_id` (u32 little-endian), header length (u32 little-endian), the header
/// bytes, then the raw payload bytes. The frame itself carries no payload length; the reader
/// learns it from the header (`read_response` takes it from `ResponseHeader::payload_len`),
/// so the caller must keep the header and `payload` consistent.
///
/// # Errors
///
/// Returns `io::ErrorKind::InvalidInput` *before writing anything* if `header` is longer than
/// `MAX_HEADER_LEN` or `payload` is longer than `MAX_PAYLOAD_LEN`. The reading side would
/// reject such a frame anyway, and failing early keeps partial bytes off the stream. Errors
/// from `w` are propagated unchanged.
pub fn write_frame<W: Write>(
    w: &mut W,
    req_id: u32,
    header: &[u8],
    payload: &[u8],
) -> io::Result<()> {
    // Validate both lengths up front so a rejected frame leaves the stream untouched.
    let len = match u32::try_from(header.len()) {
        Ok(n) if n <= MAX_HEADER_LEN => n,
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("header length {} exceeds cap {MAX_HEADER_LEN}", header.len()),
            ));
        }
    };
    match u32::try_from(payload.len()) {
        Ok(n) if n <= MAX_PAYLOAD_LEN => {}
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("payload length {} exceeds cap {MAX_PAYLOAD_LEN}", payload.len()),
            ));
        }
    }

    w.write_all(&req_id.to_le_bytes())?;
    w.write_all(&len.to_le_bytes())?;
    w.write_all(header)?;
    // `write_all` with an empty slice writes nothing, so header-only frames need no special case.
    w.write_all(payload)?;
    w.flush()
}
/// Reads the next frame prefix and its header bytes from `r`.
///
/// The prefix is two little-endian `u32`s read back to back: the request id, then the header
/// length. Exactly that many header bytes are read next. Any payload that follows is left
/// unread in `r`.
///
/// Returns `(req_id, header_bytes)`.
///
/// # Errors
///
/// Returns `io::ErrorKind::InvalidData` if the declared header length exceeds
/// `MAX_HEADER_LEN`. In that case only the 8-byte prefix has been consumed. Errors from
/// `read_exact` (for example `UnexpectedEof` on a short stream) are propagated.
pub fn read_header<R: Read>(r: &mut R) -> io::Result<(u32, Vec<u8>)> {
    let mut next_u32 = || -> io::Result<u32> {
        let mut word = [0u8; 4];
        r.read_exact(&mut word)?;
        Ok(u32::from_le_bytes(word))
    };
    let req_id = next_u32()?;
    let len = next_u32()?;

    // Enforce the cap before allocating, so a corrupt or hostile length cannot force a huge buffer.
    if len > MAX_HEADER_LEN {
        let msg = format!("header length {len} exceeds cap {MAX_HEADER_LEN}");
        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
    }

    let mut header = vec![0u8; len as usize];
    r.read_exact(&mut header)?;
    Ok((req_id, header))
}
/// Reads exactly `len` payload bytes from `r`.
///
/// # Errors
///
/// Returns `io::ErrorKind::InvalidData` without reading or allocating anything if `len`
/// exceeds `MAX_PAYLOAD_LEN`. Errors from `read_exact` (for example `UnexpectedEof` when
/// fewer than `len` bytes remain) are propagated.
pub fn read_payload<R: Read>(r: &mut R, len: u32) -> io::Result<Vec<u8>> {
    if len <= MAX_PAYLOAD_LEN {
        let mut payload = vec![0u8; len as usize];
        r.read_exact(&mut payload).map(|()| payload)
    } else {
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("payload length {len} exceeds cap {MAX_PAYLOAD_LEN}"),
        ))
    }
}
/// Sends `action` as a header-only frame tagged with `req_id`.
///
/// The action is JSON-encoded and becomes the frame header. No payload bytes are written.
///
/// # Errors
///
/// JSON encoding failures are converted into `io::Error`. Any error from `write_frame` is
/// propagated.
pub fn send_action<W: Write>(w: &mut W, req_id: u32, action: &Action) -> io::Result<()> {
    serde_json::to_vec(action)
        .map_err(io::Error::from)
        .and_then(|encoded| write_frame(w, req_id, &encoded, &[]))
}
/// Reads one complete response frame from `r`.
///
/// Reads the frame header, decodes it as JSON into a `ResponseHeader`, and then reads
/// `payload_len` payload bytes when that field is non-zero. Frames are returned in wire
/// order; matching the returned `req_id` to an outstanding request is left to the caller.
///
/// Returns `(req_id, header, payload)`. `payload` is empty when `payload_len` is 0.
///
/// # Errors
///
/// Propagates errors from `read_header` and `read_payload`. A JSON decoding failure is
/// converted into `io::Error`. In that case any payload that followed is left unread in `r`,
/// because its length is only known from the header.
pub fn read_response<R: Read>(r: &mut R) -> io::Result<(u32, ResponseHeader, Vec<u8>)> {
    let (req_id, raw_header) = read_header(r)?;
    let header = serde_json::from_slice::<ResponseHeader>(&raw_header)?;
    let payload = match header.payload_len {
        0 => Vec::new(),
        n => read_payload(r, n)?,
    };
    Ok((req_id, header, payload))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_round_trip_header_only() {
        let mut buf = Vec::new();
        write_frame(&mut buf, 7, b"hello", &[]).unwrap();
        assert_eq!(&buf[..4], &7u32.to_le_bytes());
        assert_eq!(&buf[4..8], &5u32.to_le_bytes());
        assert_eq!(&buf[8..], b"hello");
        let mut cur = std::io::Cursor::new(buf);
        let (id, header) = read_header(&mut cur).unwrap();
        assert_eq!(id, 7);
        assert_eq!(header, b"hello");
    }

    #[test]
    fn frame_round_trip_with_payload() {
        let header = br#"{"ok":true,"payload_len":4}"#;
        let payload = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut buf = Vec::new();
        write_frame(&mut buf, 42, header, &payload).unwrap();
        let mut cur = std::io::Cursor::new(buf);
        let (id, h, p) = read_response(&mut cur).unwrap();
        assert_eq!(id, 42);
        assert!(h.ok);
        assert_eq!(h.payload_len, 4);
        assert_eq!(p, payload);
    }

    #[test]
    fn send_action_then_read_back_as_frame() {
        let mut buf = Vec::new();
        send_action(&mut buf, 3, &Action::LeftClick).unwrap();
        let mut cur = std::io::Cursor::new(buf);
        let (id, header) = read_header(&mut cur).unwrap();
        assert_eq!(id, 3);
        let a: Action = serde_json::from_slice(&header).unwrap();
        assert_eq!(a, Action::LeftClick);
    }

    #[test]
    fn oversized_header_length_is_rejected() {
        // req_id = 0, then a header length one past the cap. No header bytes are needed
        // because the length check fires before any are read.
        let mut buf = Vec::new();
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&(MAX_HEADER_LEN + 1).to_le_bytes());
        let mut cur = std::io::Cursor::new(buf);
        let err = read_header(&mut cur).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn oversized_payload_length_is_rejected() {
        // payload_len = u32::MAX is far above the payload cap, so the read must fail before
        // any payload bytes are requested. JSON decode errors also map to InvalidData, so
        // check the message as well to make sure the cap is what rejected it.
        let header = br#"{"ok":true,"payload_len":4294967295}"#;
        let mut buf = Vec::new();
        write_frame(&mut buf, 5, header, &[]).unwrap();
        let mut cur = std::io::Cursor::new(buf);
        let err = read_response(&mut cur).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("exceeds cap"));
    }

    #[test]
    fn truncated_header_is_unexpected_eof() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"abc");
        let mut cur = std::io::Cursor::new(buf);
        let err = read_header(&mut cur).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn truncated_payload_is_unexpected_eof() {
        // The header promises 4 payload bytes but only 2 are on the wire.
        let header = br#"{"ok":true,"payload_len":4}"#;
        let mut buf = Vec::new();
        write_frame(&mut buf, 6, header, &[0x01, 0x02]).unwrap();
        let mut cur = std::io::Cursor::new(buf);
        let err = read_response(&mut cur).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn responses_demultiplex_by_req_id_out_of_order() {
        let mut buf = Vec::new();
        let h20 = br#"{"ok":true,"x":2,"y":0,"payload_len":0}"#;
        let h10 = br#"{"ok":true,"x":1,"y":0,"payload_len":3}"#;
        write_frame(&mut buf, 20, h20, &[]).unwrap();
        write_frame(&mut buf, 10, h10, &[0xAA, 0xBB, 0xCC]).unwrap();
        let mut cur = std::io::Cursor::new(buf);
        let (id_a, ha, pa) = read_response(&mut cur).unwrap();
        assert_eq!(id_a, 20);
        assert_eq!(ha.x, Some(2));
        assert!(pa.is_empty());
        let (id_b, hb, pb) = read_response(&mut cur).unwrap();
        assert_eq!(id_b, 10);
        assert_eq!(hb.x, Some(1));
        assert_eq!(pb, [0xAA, 0xBB, 0xCC]);
    }
}