1pub mod messages;
2pub mod types;
3
4pub use messages::{ClientInput, Request, Response, ServerPush};
5pub use types::{
6 CaptureFilter, CaptureResult, CursorPos, LayoutInfo, LayoutListEntry, LayoutSource, MarkInfo,
7 PaneEvent, PaneId, PaneInfo, PaneLayout, ProcessInfo, ScreenCell, ScreenCellAttrs, ScreenColor,
8 ScrollbackRange, ServerStatusInfo, SessionId, SessionInfo, WindowBarInfo, WindowId, WindowInfo,
9};
10
11use anyhow::Result;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13
14pub const PROTOCOL_VERSION: u8 = 3;
15const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024; pub async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, payload: &[u8]) -> Result<()> {
18 writer.write_all(&[PROTOCOL_VERSION]).await?;
19 let len = payload.len() as u32;
20 writer.write_all(&len.to_be_bytes()).await?;
21 writer.write_all(payload).await?;
22 writer.flush().await?;
23 Ok(())
24}
25
26pub async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<Vec<u8>> {
27 let mut version_buf = [0u8; 1];
28 reader.read_exact(&mut version_buf).await?;
29 let version = version_buf[0];
30 if version != PROTOCOL_VERSION {
31 anyhow::bail!(
32 "protocol version mismatch: remote={} local={}, please update aimux",
33 version, PROTOCOL_VERSION
34 );
35 }
36 let mut len_buf = [0u8; 4];
37 reader.read_exact(&mut len_buf).await?;
38 let len = u32::from_be_bytes(len_buf) as usize;
39 if len > MAX_FRAME_SIZE {
40 anyhow::bail!("frame too large: {} bytes (max {})", len, MAX_FRAME_SIZE);
41 }
42 let mut payload = vec![0u8; len];
43 reader.read_exact(&mut payload).await?;
44 Ok(payload)
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    // A payload written by `write_frame` comes back unchanged from `read_frame`.
    #[tokio::test]
    async fn frame_roundtrip() {
        let payload = b"hello world";
        let mut buf = Vec::new();
        write_frame(&mut buf, payload).await.unwrap();
        let mut cursor = std::io::Cursor::new(buf);
        let result = read_frame(&mut cursor).await.unwrap();
        assert_eq!(result, payload);
    }

    // A zero-length payload is a valid frame consisting of just the 5-byte header.
    #[tokio::test]
    async fn empty_frame_roundtrip() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"").await.unwrap();
        assert_eq!(buf.len(), 5);
        let mut cursor = std::io::Cursor::new(buf);
        let result = read_frame(&mut cursor).await.unwrap();
        assert!(result.is_empty());
    }

    // Reading one frame must consume exactly that frame and no more.
    #[tokio::test]
    async fn consecutive_frames_stay_aligned() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"first").await.unwrap();
        write_frame(&mut buf, b"second").await.unwrap();
        let mut cursor = std::io::Cursor::new(buf);
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"first");
        assert_eq!(read_frame(&mut cursor).await.unwrap(), b"second");
    }

    // Header layout: version byte followed by a big-endian u32 length.
    #[tokio::test]
    async fn version_byte_present() {
        let mut buf = Vec::new();
        write_frame(&mut buf, b"test").await.unwrap();
        assert_eq!(buf[0], PROTOCOL_VERSION);
        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
        assert_eq!(len, 4);
    }

    // Expected values are derived from PROTOCOL_VERSION so the test survives
    // future version bumps.
    #[tokio::test]
    async fn version_mismatch_error() {
        let remote = PROTOCOL_VERSION.wrapping_add(1);
        let mut buf = vec![remote];
        buf.extend_from_slice(&4u32.to_be_bytes());
        buf.extend_from_slice(b"test");

        let mut cursor = std::io::Cursor::new(buf);
        let msg = read_frame(&mut cursor).await.unwrap_err().to_string();
        assert!(msg.contains("version mismatch"));
        assert!(msg.contains(&format!("remote={}", remote)));
        assert!(msg.contains(&format!("local={}", PROTOCOL_VERSION)));
    }

    // A declared length just above the limit is rejected before any payload is read.
    #[tokio::test]
    async fn frame_too_large_error() {
        let mut buf = Vec::new();
        buf.push(PROTOCOL_VERSION);
        buf.extend_from_slice(&(MAX_FRAME_SIZE as u32 + 1).to_be_bytes());

        let mut cursor = std::io::Cursor::new(buf);
        let err = read_frame(&mut cursor).await.unwrap_err();
        assert!(err.to_string().contains("frame too large"));
    }

    // A stream that ends mid-payload surfaces the underlying UnexpectedEof I/O error.
    #[tokio::test]
    async fn truncated_payload_error() {
        let mut buf = vec![PROTOCOL_VERSION];
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(b"short");

        let mut cursor = std::io::Cursor::new(buf);
        let err = read_frame(&mut cursor).await.unwrap_err();
        let io_err = err
            .downcast_ref::<std::io::Error>()
            .expect("truncated input should surface as an I/O error");
        assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
    }
}