Skip to main content

atd_protocol/
wire.rs

1use serde::{Serialize, de::DeserializeOwned};
2use std::time::Duration;
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4
/// Largest JSON body a single frame may carry: 10 MiB.
///
/// Both directions enforce it. The writer refuses to emit a bigger body, and
/// the reader rejects a length prefix above it before reading the body.
const MAX_FRAME_BYTES: usize = 10 << 20;
6
7/// Errors from the wire codec.
8///
9/// SP-concurrency-baseline §5.2: callers need to distinguish "peer sent garbage"
10/// (fatal; close connection) from "timed out before peer wrote" (retryable;
11/// reissue). The legacy `std::io::Error` collapsed both into `InvalidData`,
12/// leaving call sites to string-match the message — fragile and lossy.
13#[derive(thiserror::Error, Debug)]
14pub enum WireError {
15    #[error("io: {0}")]
16    Io(#[from] std::io::Error),
17    #[error("decode: {0}")]
18    Decode(#[from] serde_json::Error),
19    #[error("frame length overflow: {0} bytes (max {MAX_FRAME_BYTES})")]
20    LengthOverflow(u32),
21    #[error("timeout after {0:?}")]
22    Timeout(Duration),
23}
24
25impl From<WireError> for std::io::Error {
26    /// Compat shim for callers still on `std::io::Result`. Preserves the
27    /// underlying io error when present; wraps other variants as `InvalidData`.
28    fn from(e: WireError) -> Self {
29        match e {
30            WireError::Io(io) => io,
31            other => std::io::Error::new(std::io::ErrorKind::InvalidData, other.to_string()),
32        }
33    }
34}
35
36/// Write one length-prefixed JSON frame, with an optional total-operation deadline.
37///
38/// Deadline covers `write_all(len) + write_all(body) + flush` as one unit;
39/// a partial write that times out leaves the writer in an unspecified state
40/// (caller should close the connection).
41pub async fn write_frame_with_deadline<W, T>(
42    writer: &mut W,
43    msg: &T,
44    deadline: Option<Duration>,
45) -> Result<(), WireError>
46where
47    W: AsyncWriteExt + Unpin,
48    T: Serialize,
49{
50    let fut = async {
51        let body = serde_json::to_vec(msg)?;
52        let len = u32::try_from(body.len()).map_err(|_| WireError::LengthOverflow(u32::MAX))?;
53        if body.len() > MAX_FRAME_BYTES {
54            return Err(WireError::LengthOverflow(len));
55        }
56        writer.write_all(&len.to_be_bytes()).await?;
57        writer.write_all(&body).await?;
58        writer.flush().await?;
59        Ok::<(), WireError>(())
60    };
61    match deadline {
62        None => fut.await,
63        Some(d) => tokio::time::timeout(d, fut)
64            .await
65            .map_err(|_| WireError::Timeout(d))?,
66    }
67}
68
69/// Read one length-prefixed JSON frame, with an optional total-operation deadline.
70pub async fn read_frame_with_deadline<R, T>(
71    reader: &mut R,
72    deadline: Option<Duration>,
73) -> Result<T, WireError>
74where
75    R: AsyncReadExt + Unpin,
76    T: DeserializeOwned,
77{
78    let fut = async {
79        let mut len_buf = [0u8; 4];
80        reader.read_exact(&mut len_buf).await?;
81        let len = u32::from_be_bytes(len_buf);
82        if len as usize > MAX_FRAME_BYTES {
83            return Err(WireError::LengthOverflow(len));
84        }
85        let mut body = vec![0u8; len as usize];
86        reader.read_exact(&mut body).await?;
87        Ok::<T, WireError>(serde_json::from_slice(&body)?)
88    };
89    match deadline {
90        None => fut.await,
91        Some(d) => tokio::time::timeout(d, fut)
92            .await
93            .map_err(|_| WireError::Timeout(d))?,
94    }
95}
96
97/// Back-compat wrapper for callers still on `std::io::Result`.
98pub async fn write_frame<W, T>(writer: &mut W, msg: &T) -> std::io::Result<()>
99where
100    W: AsyncWriteExt + Unpin,
101    T: Serialize,
102{
103    write_frame_with_deadline(writer, msg, None)
104        .await
105        .map_err(Into::into)
106}
107
108/// Back-compat wrapper for callers still on `std::io::Result`.
109pub async fn read_frame<R, T>(reader: &mut R) -> std::io::Result<T>
110where
111    R: AsyncReadExt + Unpin,
112    T: DeserializeOwned,
113{
114    read_frame_with_deadline(reader, None)
115        .await
116        .map_err(Into::into)
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct M {
        kind: String,
        n: u32,
    }

    /// Build a raw frame: big-endian `u32` length prefix followed by `body`.
    fn raw_frame(claimed_len: u32, body: &[u8]) -> Vec<u8> {
        let mut frame = Vec::with_capacity(4 + body.len());
        frame.extend_from_slice(&claimed_len.to_be_bytes());
        frame.extend_from_slice(body);
        frame
    }

    #[tokio::test]
    async fn write_then_read_roundtrip() {
        let msg = M {
            kind: "ping".into(),
            n: 7,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();

        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame(&mut cursor).await.unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn frame_uses_big_endian_u32_prefix() {
        let msg = M {
            kind: "x".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let body_len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(body_len, buf.len() - 4);
    }

    #[tokio::test]
    async fn oversized_frame_errors() {
        // Craft a header claiming 20 MiB; reader should refuse before allocating.
        let bogus_len: u32 = 20 * 1024 * 1024;
        let mut cursor = std::io::Cursor::new(raw_frame(bogus_len, &[]));
        let err = read_frame::<_, M>(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn truncated_body_is_io_unexpected_eof() {
        // Header promises 10 body bytes but the stream ends after 3.
        let mut cursor = std::io::Cursor::new(raw_frame(10, b"{\"k"));
        let err = read_frame_with_deadline::<_, M>(&mut cursor, None)
            .await
            .unwrap_err();
        match err {
            WireError::Io(io) => assert_eq!(io.kind(), std::io::ErrorKind::UnexpectedEof),
            other => panic!("expected Io(UnexpectedEof), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn malformed_body_is_decode_error() {
        // Well-formed framing, garbage payload: must surface as Decode, not Io.
        let body = b"not json";
        let claimed = u32::try_from(body.len()).unwrap();
        let mut cursor = std::io::Cursor::new(raw_frame(claimed, body));
        let err = read_frame_with_deadline::<_, M>(&mut cursor, None)
            .await
            .unwrap_err();
        assert!(matches!(err, WireError::Decode(_)), "got {err:?}");
    }

    // ---- SP-concurrency-baseline Phase B Task 1 ----

    #[tokio::test]
    async fn read_frame_with_deadline_returns_timeout_on_no_data() {
        // Keep both halves alive but never write — read_exact will block
        // indefinitely, giving the deadline a real chance to fire.
        let (_producer, mut consumer) = tokio::io::duplex(64);
        let err = read_frame_with_deadline::<_, M>(&mut consumer, Some(Duration::from_millis(50)))
            .await
            .unwrap_err();
        assert!(matches!(err, WireError::Timeout(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn read_frame_with_deadline_returns_timeout_mid_frame() {
        // The peer sends a header promising a body but never sends the body;
        // the deadline must still fire while waiting for it.
        let (mut producer, mut consumer) = tokio::io::duplex(64);
        producer.write_all(&10u32.to_be_bytes()).await.unwrap();
        let err = read_frame_with_deadline::<_, M>(&mut consumer, Some(Duration::from_millis(50)))
            .await
            .unwrap_err();
        assert!(matches!(err, WireError::Timeout(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn read_frame_with_deadline_succeeds_within_deadline() {
        let msg = M {
            kind: "ok".into(),
            n: 42,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame_with_deadline(&mut cursor, Some(Duration::from_millis(500)))
            .await
            .unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn read_frame_with_deadline_none_means_unbounded() {
        // Regression: deadline=None must behave like the unbounded helper.
        let msg = M {
            kind: "k".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame(&mut buf, &msg).await.unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let back: M = read_frame_with_deadline(&mut cursor, None).await.unwrap();
        assert_eq!(back, msg);
    }

    #[tokio::test]
    async fn write_frame_with_deadline_succeeds_to_in_memory_buf() {
        let msg = M {
            kind: "k".into(),
            n: 1,
        };
        let mut buf: Vec<u8> = Vec::new();
        write_frame_with_deadline(&mut buf, &msg, Some(Duration::from_millis(500)))
            .await
            .unwrap();
        assert!(buf.len() > 4);
    }

    #[tokio::test]
    async fn write_frame_with_deadline_returns_timeout_on_blocked_writer() {
        // A DuplexStream with a 4-byte buffer. The frame (a ~64-byte JSON body
        // plus the prefix) cannot fit, so writing blocks until the reader
        // drains. The reader half is kept alive but never read, so the write
        // times out.
        let (mut producer, _consumer) = tokio::io::duplex(4);
        let big_msg = M {
            kind: "x".repeat(50),
            n: 1,
        };
        let err =
            write_frame_with_deadline(&mut producer, &big_msg, Some(Duration::from_millis(50)))
                .await
                .unwrap_err();
        assert!(matches!(err, WireError::Timeout(_)), "got {err:?}");
    }

    #[test]
    fn wire_error_into_io_error_preserves_io_underlying() {
        let io_err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "nope");
        let wire = WireError::Io(io_err);
        let back: std::io::Error = wire.into();
        assert_eq!(back.kind(), std::io::ErrorKind::PermissionDenied);
    }

    #[test]
    fn wire_error_into_io_error_wraps_non_io_as_invalid_data() {
        let wire = WireError::Timeout(Duration::from_millis(100));
        let back: std::io::Error = wire.into();
        assert_eq!(back.kind(), std::io::ErrorKind::InvalidData);
        assert!(back.to_string().contains("timeout"));
    }

    #[tokio::test]
    async fn length_overflow_distinct_from_io_error() {
        let bogus_len: u32 = 20 * 1024 * 1024;
        let mut cursor = std::io::Cursor::new(raw_frame(bogus_len, &[]));
        let err = read_frame_with_deadline::<_, M>(&mut cursor, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err, WireError::LengthOverflow(n) if n == bogus_len),
            "got {err:?}"
        );
    }
}