Skip to main content

steamroom_cli/daemon/
framing.rs

1//! Async length-prefixed rkyv framing for daemon IPC.
2//!
3//! Wire format per frame:
4//! ```text
5//!   u16 LE   proto_version
6//!   u32 LE   payload_length (<= MAX_FRAME_BYTES)
7//!   [u8]     rkyv-archived Frame
8//! ```
9//! The version is checked before deserialization so mismatched daemons
10//! and clients fail with a clear error rather than rkyv validation noise.
11
12use rkyv::rancor;
13use rkyv::util::AlignedVec;
14use tokio::io::AsyncReadExt;
15use tokio::io::AsyncWriteExt;
16
17use crate::daemon::proto::Frame;
18use crate::daemon::proto::PROTO_VERSION;
19use crate::errors::CliError;
20
/// Largest rkyv payload, in bytes, that a single frame may carry (16 MiB).
///
/// `write_frame` refuses to send a payload above this limit and `read_frame`
/// rejects a larger declared length before allocating a buffer for it; both
/// report `CliError::FrameTooLarge`.
pub const MAX_FRAME_BYTES: u32 = 16 << 20;
22
23/// Map `UnexpectedEof` to `CliError::SocketClosed`; all other I/O errors
24/// become `CliError::Io`.  Used by `read_frame` for all three reads so that
25/// a peer disconnecting mid-frame always yields `SocketClosed` rather than a
26/// raw I/O error.
27async fn read_exact_or_closed<R>(r: &mut R, buf: &mut [u8]) -> Result<(), CliError>
28where
29    R: AsyncReadExt + Unpin,
30{
31    match r.read_exact(buf).await {
32        Ok(_) => Ok(()),
33        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Err(CliError::SocketClosed),
34        Err(e) => Err(CliError::Io(e)),
35    }
36}
37
38pub async fn write_frame<W>(w: &mut W, frame: &Frame) -> Result<(), CliError>
39where
40    W: AsyncWriteExt + Unpin,
41{
42    let bytes = rkyv::to_bytes::<rancor::Error>(frame)
43        .map_err(|e| CliError::MalformedFrame(e.to_string()))?;
44    let len_usize = bytes.len();
45    if len_usize > MAX_FRAME_BYTES as usize {
46        return Err(CliError::FrameTooLarge {
47            len_bytes: len_usize as u64,
48            limit_bytes: MAX_FRAME_BYTES as u64,
49        });
50    }
51    // len_usize <= MAX_FRAME_BYTES (u32-sized), so the cast is lossless.
52    let len: u32 = len_usize as u32;
53    w.write_all(&PROTO_VERSION.to_le_bytes())
54        .await
55        .map_err(CliError::Io)?;
56    w.write_all(&len.to_le_bytes())
57        .await
58        .map_err(CliError::Io)?;
59    w.write_all(&bytes).await.map_err(CliError::Io)?;
60    w.flush().await.map_err(CliError::Io)?;
61    Ok(())
62}
63
64/// Read one length-prefixed rkyv-archived `Frame` from `r`.
65///
66/// # Cancel safety
67///
68/// This function is NOT cancel-safe. `AsyncReadExt::read_exact` may
69/// partially consume bytes from the stream before the future is dropped,
70/// leaving the connection in an undefined framing state. Callers that
71/// race this against another future (e.g. via `tokio::select!` on a
72/// shutdown signal) MUST abort the connection on cancellation rather
73/// than re-entering `read_frame` on the same stream.
74pub async fn read_frame<R>(r: &mut R) -> Result<Frame, CliError>
75where
76    R: AsyncReadExt + Unpin,
77{
78    let mut ver_buf = [0u8; 2];
79    read_exact_or_closed(r, &mut ver_buf).await?;
80    let peer = u16::from_le_bytes(ver_buf);
81    if peer != PROTO_VERSION {
82        return Err(CliError::ProtocolVersionMismatch {
83            peer,
84            ours: PROTO_VERSION,
85        });
86    }
87
88    let mut len_buf = [0u8; 4];
89    read_exact_or_closed(r, &mut len_buf).await?;
90    let len = u32::from_le_bytes(len_buf);
91    if len > MAX_FRAME_BYTES {
92        return Err(CliError::FrameTooLarge {
93            len_bytes: len as u64,
94            limit_bytes: MAX_FRAME_BYTES as u64,
95        });
96    }
97
98    // rkyv's checked access demands 16-byte alignment.
99    let mut buf = AlignedVec::<16>::with_capacity(len as usize);
100    buf.resize(len as usize, 0);
101    read_exact_or_closed(r, &mut buf).await?;
102    rkyv::from_bytes::<Frame, rancor::Error>(&buf)
103        .map_err(|e| CliError::MalformedFrame(e.to_string()))
104}
105
#[cfg(test)]
mod tests {
    //! Framing tests run over an in-memory `tokio::io::duplex` pipe: `a` is
    //! the writing end and `b` the reading end.

    use super::*;
    use crate::daemon::proto::Event;
    use crate::daemon::proto::JobId;
    use crate::daemon::proto::LogLevel;
    use crate::daemon::proto::Response;
    use tokio::io::duplex;

    /// A frame written with `write_frame` is read back unchanged.
    #[tokio::test]
    async fn round_trip_through_duplex() {
        let (mut a, mut b) = duplex(64 * 1024);
        let frame = Frame::Response(Response::Stopping);

        write_frame(&mut a, &frame).await.unwrap();
        let back = read_frame(&mut b).await.unwrap();
        assert!(matches!(back, Frame::Response(Response::Stopping)));
    }

    /// A header carrying a foreign version is rejected and both versions
    /// are reported.
    #[tokio::test]
    async fn rejects_mismatched_version() {
        let (mut a, mut b) = duplex(64);
        // Derive the bogus version from ours so the two can never coincide,
        // whatever value PROTO_VERSION takes in the future.
        let bogus = PROTO_VERSION.wrapping_add(1);
        a.write_all(&bogus.to_le_bytes()).await.unwrap();
        a.write_all(&0u32.to_le_bytes()).await.unwrap();
        a.flush().await.unwrap();
        let err = read_frame(&mut b).await.unwrap_err();
        match err {
            CliError::ProtocolVersionMismatch { peer, ours } => {
                assert_eq!(peer, bogus);
                assert_eq!(ours, PROTO_VERSION);
            }
            other => panic!("wrong error: {other:?}"),
        }
    }

    /// A declared length one byte over the limit is rejected, and the error
    /// reports both the declared length and the limit.
    #[tokio::test]
    async fn rejects_oversized_length() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.write_all(&(MAX_FRAME_BYTES + 1).to_le_bytes())
            .await
            .unwrap();
        a.flush().await.unwrap();
        let err = read_frame(&mut b).await.unwrap_err();
        match err {
            CliError::FrameTooLarge {
                len_bytes,
                limit_bytes,
            } => {
                assert_eq!(len_bytes, u64::from(MAX_FRAME_BYTES) + 1);
                assert_eq!(limit_bytes, u64::from(MAX_FRAME_BYTES));
            }
            other => panic!("wrong error: {other:?}"),
        }
    }

    /// A frame with owned string fields survives a round trip.
    #[tokio::test]
    async fn event_with_log_round_trips() {
        let (mut a, mut b) = duplex(64 * 1024);
        let frame = Frame::Event(Event::Log {
            job_id: Some(JobId(1)),
            level: LogLevel::Info,
            target: "t".into(),
            message: "hello".into(),
        });
        write_frame(&mut a, &frame).await.unwrap();
        let back = read_frame(&mut b).await.unwrap();
        match back {
            Frame::Event(Event::Log { message, .. }) => assert_eq!(message, "hello"),
            other => panic!("wrong: {other:?}"),
        }
    }

    /// The peer hangs up before sending anything: the version read hits EOF.
    #[tokio::test]
    async fn closed_before_header_is_socket_closed() {
        let (a, mut b) = duplex(64);
        drop(a);
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::SocketClosed), "wrong error: {err:?}");
    }

    /// The peer hangs up after the version but before the length.
    #[tokio::test]
    async fn closed_mid_header_is_socket_closed() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.flush().await.unwrap();
        drop(a);
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::SocketClosed), "wrong error: {err:?}");
    }

    /// The peer announces 8 payload bytes but hangs up after sending 3.
    #[tokio::test]
    async fn closed_mid_payload_is_socket_closed() {
        let (mut a, mut b) = duplex(64);
        a.write_all(&PROTO_VERSION.to_le_bytes()).await.unwrap();
        a.write_all(&8u32.to_le_bytes()).await.unwrap();
        a.write_all(&[0u8; 3]).await.unwrap();
        a.flush().await.unwrap();
        drop(a);
        let err = read_frame(&mut b).await.unwrap_err();
        assert!(matches!(err, CliError::SocketClosed), "wrong error: {err:?}");
    }
}