Skip to main content

a3s_common/transport/
codec.rs

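//! Async frame reader/writer for any `AsyncRead`/`AsyncWrite` stream.
//!
//! Wraps the [`Frame`] wire format with buffered async I/O: [`FrameReader`]
//! decodes frames from a byte stream, [`FrameWriter`] encodes them, and
//! [`FrameCodec`] pairs the two.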
4
5use bytes::{Buf, BytesMut};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use super::frame::Frame;
9use super::TransportError;
10
/// Number of bytes pre-allocated for a freshly created reader's receive
/// buffer (8 KiB).
const INITIAL_BUF_CAPACITY: usize = 8192;
12
13/// Async frame reader over any `AsyncRead` stream.
14///
15/// Buffers incoming bytes and yields complete [`Frame`]s.
16#[derive(Debug)]
17pub struct FrameReader<R> {
18    inner: R,
19    buf: BytesMut,
20}
21
22impl<R: AsyncRead + Unpin> FrameReader<R> {
23    /// Wrap a reader.
24    pub fn new(inner: R) -> Self {
25        Self {
26            inner,
27            buf: BytesMut::with_capacity(INITIAL_BUF_CAPACITY),
28        }
29    }
30
31    /// Read the next frame. Returns `None` on clean EOF.
32    pub async fn read_frame(&mut self) -> Result<Option<Frame>, TransportError> {
33        loop {
34            // Try to decode a frame from the buffer
35            if let Some((frame, consumed)) = Frame::decode(&self.buf)? {
36                self.buf.advance(consumed);
37                return Ok(Some(frame));
38            }
39
40            // Need more data — read from the stream
41            let n = self
42                .inner
43                .read_buf(&mut self.buf)
44                .await
45                .map_err(|e| TransportError::RecvFailed(e.to_string()))?;
46
47            if n == 0 {
48                // EOF
49                if self.buf.is_empty() {
50                    return Ok(None);
51                }
52                return Err(TransportError::RecvFailed(
53                    "Connection closed with incomplete frame".to_string(),
54                ));
55            }
56        }
57    }
58
59    /// Get a reference to the inner reader.
60    pub fn inner(&self) -> &R {
61        &self.inner
62    }
63
64    /// Consume the reader and return the inner stream.
65    pub fn into_inner(self) -> R {
66        self.inner
67    }
68}
69
70/// Async frame writer over any `AsyncWrite` stream.
71#[derive(Debug)]
72pub struct FrameWriter<W> {
73    inner: W,
74}
75
76impl<W: AsyncWrite + Unpin> FrameWriter<W> {
77    /// Wrap a writer.
78    pub fn new(inner: W) -> Self {
79        Self { inner }
80    }
81
82    /// Write a frame to the stream.
83    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), TransportError> {
84        let encoded = frame.encode()?;
85        self.inner
86            .write_all(&encoded)
87            .await
88            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
89        self.inner
90            .flush()
91            .await
92            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
93        Ok(())
94    }
95
96    /// Write a data frame with the given payload.
97    pub async fn write_data(&mut self, payload: &[u8]) -> Result<(), TransportError> {
98        self.write_frame(&Frame::data(payload.to_vec())).await
99    }
100
101    /// Write a control frame with the given payload.
102    pub async fn write_control(&mut self, payload: &[u8]) -> Result<(), TransportError> {
103        self.write_frame(&Frame::control(payload.to_vec())).await
104    }
105
106    /// Write a JSON-serializable value as a data frame.
107    pub async fn write_json<T: serde::Serialize>(
108        &mut self,
109        value: &T,
110    ) -> Result<(), TransportError> {
111        let payload = serde_json::to_vec(value)
112            .map_err(|e| TransportError::SendFailed(format!("JSON serialize: {}", e)))?;
113        self.write_data(&payload).await
114    }
115
116    /// Get a reference to the inner writer.
117    pub fn inner(&self) -> &W {
118        &self.inner
119    }
120
121    /// Consume the writer and return the inner stream.
122    pub fn into_inner(self) -> W {
123        self.inner
124    }
125}
126
127/// Combined async frame reader + writer over a split stream.
128#[derive(Debug)]
129pub struct FrameCodec<R, W> {
130    pub reader: FrameReader<R>,
131    pub writer: FrameWriter<W>,
132}
133
134impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> FrameCodec<R, W> {
135    /// Create from separate read and write halves.
136    pub fn new(reader: R, writer: W) -> Self {
137        Self {
138            reader: FrameReader::new(reader),
139            writer: FrameWriter::new(writer),
140        }
141    }
142
143    /// Read the next frame.
144    pub async fn read_frame(&mut self) -> Result<Option<Frame>, TransportError> {
145        self.reader.read_frame().await
146    }
147
148    /// Write a frame.
149    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), TransportError> {
150        self.writer.write_frame(frame).await
151    }
152}
153
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use super::super::frame::FrameType;

    /// Frames written by `FrameWriter` are read back in order, then clean EOF.
    #[tokio::test]
    async fn test_reader_writer_roundtrip() {
        let (client, server) = tokio::io::duplex(1024);
        let (cr, cw) = tokio::io::split(client);
        let (sr, _sw) = tokio::io::split(server);

        let mut writer = FrameWriter::new(cw);
        let mut reader = FrameReader::new(sr);

        writer.write_data(b"hello").await.unwrap();
        writer.write_frame(&Frame::heartbeat()).await.unwrap();
        // Drop both writer and unused read half so the DuplexStream is fully
        // released and the server read half sees EOF.
        drop(writer);
        drop(cr);

        let f1 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(f1.frame_type, FrameType::Data);
        assert_eq!(f1.payload, b"hello");

        let f2 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(f2.frame_type, FrameType::Heartbeat);
        assert!(f2.payload.is_empty());

        let f3 = reader.read_frame().await.unwrap();
        assert!(f3.is_none()); // EOF
    }

    /// Two codecs on opposite ends of a duplex can exchange frames both ways.
    #[tokio::test]
    async fn test_codec_bidirectional() {
        let (a, b) = tokio::io::duplex(1024);
        let (ar, aw) = tokio::io::split(a);
        let (br, bw) = tokio::io::split(b);

        let mut codec_a = FrameCodec::new(ar, aw);
        let mut codec_b = FrameCodec::new(br, bw);

        codec_a.writer.write_data(b"ping").await.unwrap();
        let frame = codec_b.reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.payload, b"ping");

        codec_b.writer.write_data(b"pong").await.unwrap();
        let frame = codec_a.reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.payload, b"pong");
    }

    /// The `FrameCodec::read_frame` / `write_frame` shorthands work end to end.
    #[tokio::test]
    async fn test_codec_frame_methods() {
        let (a, b) = tokio::io::duplex(1024);
        let (ar, aw) = tokio::io::split(a);
        let (br, bw) = tokio::io::split(b);

        let mut codec_a = FrameCodec::new(ar, aw);
        let mut codec_b = FrameCodec::new(br, bw);

        codec_a.write_frame(&Frame::heartbeat()).await.unwrap();
        let frame = codec_b.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.frame_type, FrameType::Heartbeat);
        assert!(frame.payload.is_empty());
    }

    /// `write_json` produces a data frame whose payload deserializes back.
    #[tokio::test]
    async fn test_write_json() {
        let (client, server) = tokio::io::duplex(1024);
        let (_, cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let mut writer = FrameWriter::new(cw);
        let mut reader = FrameReader::new(sr);

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Msg {
            text: String,
        }

        writer
            .write_json(&Msg {
                text: "hello".into(),
            })
            .await
            .unwrap();
        drop(writer);

        let frame = reader.read_frame().await.unwrap().unwrap();
        let msg: Msg = serde_json::from_slice(&frame.payload).unwrap();
        assert_eq!(msg.text, "hello");
    }

    /// A stream that closes with nothing buffered yields a clean `None`.
    #[tokio::test]
    async fn test_reader_clean_eof_on_empty_stream() {
        let (client, server) = tokio::io::duplex(64);
        drop(client);

        let mut reader = FrameReader::new(server);
        assert!(reader.read_frame().await.unwrap().is_none());
    }

    /// A stream that closes mid-frame is reported as an error, not as EOF.
    #[tokio::test]
    async fn test_reader_incomplete_frame_on_eof() {
        // Write only a partial header
        let (client, server) = tokio::io::duplex(1024);
        let (_, mut cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let mut reader = FrameReader::new(sr);

        // Write 3 bytes (incomplete header) then close
        cw.write_all(&[0x01, 0x00, 0x00]).await.unwrap();
        drop(cw);

        let result = reader.read_frame().await;
        assert!(result.is_err());
    }

    /// Several frames delivered in one read are all decoded from the buffer.
    #[tokio::test]
    async fn test_multiple_frames_in_one_read() {
        let (client, server) = tokio::io::duplex(4096);
        let (_, mut cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        // Encode two frames and write them in a single write
        let f1 = Frame::data(b"first".to_vec());
        let f2 = Frame::data(b"second".to_vec());
        let mut buf = f1.encode().unwrap();
        buf.extend_from_slice(&f2.encode().unwrap());
        cw.write_all(&buf).await.unwrap();
        drop(cw);

        let mut reader = FrameReader::new(sr);
        let r1 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(r1.payload, b"first");
        let r2 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(r2.payload, b"second");
        assert!(reader.read_frame().await.unwrap().is_none());
    }

    /// One frame delivered a single byte at a time is reassembled across
    /// many reads (the "need more data" path of `read_frame`).
    #[tokio::test]
    async fn test_frame_split_across_reads() {
        let (client, server) = tokio::io::duplex(1024);
        let (_, mut cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);
        let mut reader = FrameReader::new(sr);

        let encoded = Frame::data(b"fragmented".to_vec()).encode().unwrap();
        let write = async move {
            for byte in encoded.chunks(1) {
                cw.write_all(byte).await.unwrap();
                // Yield so the reader is polled between bytes and observes
                // each one as a separate read.
                tokio::task::yield_now().await;
            }
        };

        let (frame, _) = tokio::join!(reader.read_frame(), write);
        let frame = frame.unwrap().unwrap();
        assert_eq!(frame.frame_type, FrameType::Data);
        assert_eq!(frame.payload, b"fragmented");
    }
}