Skip to main content

aimux_protocol/
lib.rs

1pub mod messages;
2pub mod types;
3
4pub use messages::{ClientInput, Request, Response, ServerPush};
5pub use types::{
6    CaptureFilter, CaptureResult, CursorPos, LayoutInfo, LayoutListEntry, LayoutSource, MarkInfo,
7    PaneEvent, PaneId, PaneInfo, PaneLayout, ProcessInfo, ScreenCell, ScreenCellAttrs, ScreenColor,
8    ScrollbackRange, ServerStatusInfo, SessionId, SessionInfo, WindowBarInfo, WindowId, WindowInfo,
9};
10
11use anyhow::Result;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13
/// Wire-protocol version. [`write_frame`] emits it as the first byte of every frame and
/// [`read_frame`] rejects any frame whose first byte differs from it.
pub const PROTOCOL_VERSION: u8 = 3;

/// Largest payload length, in bytes, that [`read_frame`] accepts (16 MiB).
///
/// The check happens before the payload buffer is allocated, so this also bounds the
/// memory a single incoming frame can claim.
const MAX_FRAME_SIZE: usize = 16 << 20;
16
17pub async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, payload: &[u8]) -> Result<()> {
18    writer.write_all(&[PROTOCOL_VERSION]).await?;
19    let len = payload.len() as u32;
20    writer.write_all(&len.to_be_bytes()).await?;
21    writer.write_all(payload).await?;
22    writer.flush().await?;
23    Ok(())
24}
25
26pub async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<Vec<u8>> {
27    let mut version_buf = [0u8; 1];
28    reader.read_exact(&mut version_buf).await?;
29    let version = version_buf[0];
30    if version != PROTOCOL_VERSION {
31        anyhow::bail!(
32            "protocol version mismatch: remote={} local={}, please update aimux",
33            version, PROTOCOL_VERSION
34        );
35    }
36    let mut len_buf = [0u8; 4];
37    reader.read_exact(&mut len_buf).await?;
38    let len = u32::from_be_bytes(len_buf) as usize;
39    if len > MAX_FRAME_SIZE {
40        anyhow::bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
41    }
42    let mut payload = vec![0u8; len];
43    reader.read_exact(&mut payload).await?;
44    Ok(payload)
45}
46
// Round-trip and error-path tests for the frame codec (`write_frame` / `read_frame`).
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn frame_roundtrip() {
        let payload = b"hello world";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).await.unwrap();
        let mut cursor = std::io::Cursor::new(buf);
        let result = read_frame(&mut cursor).await.unwrap();
        assert_eq!(result, payload);
    }

    #[tokio::test]
    async fn empty_payload_roundtrip() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"").await.unwrap();
        // Header only: version byte + u32 length.
        assert_eq!(buf.len(), 5);
        let mut cursor = std::io::Cursor::new(buf);
        assert!(read_frame(&mut cursor).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn consecutive_frames_stay_aligned() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"first").await.unwrap();
        write_frame(&mut buf, b"second frame").await.unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"first");
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"second frame");
        // The stream is exhausted, so a third read must fail rather than fabricate a frame.
        assert!(read_frame(&mut cursor).await.is_err());
    }

    #[tokio::test]
    async fn version_byte_present() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"test").await.unwrap();
        assert_eq!(buf[0], PROTOCOL_VERSION);
        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
        assert_eq!(len, 4);
    }

    #[tokio::test]
    async fn version_mismatch_error() {
        // Derive both versions from PROTOCOL_VERSION. The test then keeps passing, and
        // keeps exercising a real mismatch, whenever the protocol version is bumped.
        let remote = PROTOCOL_VERSION.wrapping_add(1);
        let mut buf = vec![remote];
        buf.extend_from_slice(&4u32.to_be_bytes());
        buf.extend_from_slice(b"test");

        let mut cursor = std::io::Cursor::new(buf);
        let msg = read_frame(&mut cursor).await.unwrap_err().to_string();
        assert!(msg.contains("version mismatch"), "{}", msg);
        assert!(msg.contains(&format!("remote={}", remote)), "{}", msg);
        assert!(msg.contains(&format!("local={}", PROTOCOL_VERSION)), "{}", msg);
    }

    #[tokio::test]
    async fn frame_too_large_error() {
        let mut buf = Vec::new();
        buf.push(PROTOCOL_VERSION);
        buf.extend_from_slice(&(MAX_FRAME_SIZE as u32 + 1).to_be_bytes());

        let mut cursor = std::io::Cursor::new(buf);
        let err = read_frame(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    #[tokio::test]
    async fn truncated_payload_errors() {
        let mut buf = vec![PROTOCOL_VERSION];
        buf.extend_from_slice(&4u32.to_be_bytes());
        // Declares 4 payload bytes but delivers only 2.
        buf.extend_from_slice(b"te");

        let mut cursor = std::io::Cursor::new(buf);
        assert!(read_frame(&mut cursor).await.is_err());
    }
}
94}