// a3s_common/transport/codec.rs
use bytes::{Buf, BytesMut};
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use super::frame::Frame;
9use super::TransportError;
10
/// Number of bytes pre-allocated for a `FrameReader`'s read buffer (8 KiB).
/// The `BytesMut` buffer grows past this when more input must be held.
const INITIAL_BUF_CAPACITY: usize = 8192;
12
13#[derive(Debug)]
17pub struct FrameReader<R> {
18 inner: R,
19 buf: BytesMut,
20}
21
22impl<R: AsyncRead + Unpin> FrameReader<R> {
23 pub fn new(inner: R) -> Self {
25 Self {
26 inner,
27 buf: BytesMut::with_capacity(INITIAL_BUF_CAPACITY),
28 }
29 }
30
31 pub async fn read_frame(&mut self) -> Result<Option<Frame>, TransportError> {
33 loop {
34 if let Some((frame, consumed)) = Frame::decode(&self.buf)? {
36 self.buf.advance(consumed);
37 return Ok(Some(frame));
38 }
39
40 let n = self
42 .inner
43 .read_buf(&mut self.buf)
44 .await
45 .map_err(|e| TransportError::RecvFailed(e.to_string()))?;
46
47 if n == 0 {
48 if self.buf.is_empty() {
50 return Ok(None);
51 }
52 return Err(TransportError::RecvFailed(
53 "Connection closed with incomplete frame".to_string(),
54 ));
55 }
56 }
57 }
58
59 pub fn inner(&self) -> &R {
61 &self.inner
62 }
63
64 pub fn into_inner(self) -> R {
66 self.inner
67 }
68}
69
70#[derive(Debug)]
72pub struct FrameWriter<W> {
73 inner: W,
74}
75
76impl<W: AsyncWrite + Unpin> FrameWriter<W> {
77 pub fn new(inner: W) -> Self {
79 Self { inner }
80 }
81
82 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), TransportError> {
84 let encoded = frame.encode()?;
85 self.inner
86 .write_all(&encoded)
87 .await
88 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
89 self.inner
90 .flush()
91 .await
92 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
93 Ok(())
94 }
95
96 pub async fn write_data(&mut self, payload: &[u8]) -> Result<(), TransportError> {
98 self.write_frame(&Frame::data(payload.to_vec())).await
99 }
100
101 pub async fn write_control(&mut self, payload: &[u8]) -> Result<(), TransportError> {
103 self.write_frame(&Frame::control(payload.to_vec())).await
104 }
105
106 pub async fn write_json<T: serde::Serialize>(
108 &mut self,
109 value: &T,
110 ) -> Result<(), TransportError> {
111 let payload = serde_json::to_vec(value)
112 .map_err(|e| TransportError::SendFailed(format!("JSON serialize: {}", e)))?;
113 self.write_data(&payload).await
114 }
115
116 pub fn inner(&self) -> &W {
118 &self.inner
119 }
120
121 pub fn into_inner(self) -> W {
123 self.inner
124 }
125}
126
127#[derive(Debug)]
129pub struct FrameCodec<R, W> {
130 pub reader: FrameReader<R>,
131 pub writer: FrameWriter<W>,
132}
133
134impl<R: AsyncRead + Unpin, W: AsyncWrite + Unpin> FrameCodec<R, W> {
135 pub fn new(reader: R, writer: W) -> Self {
137 Self {
138 reader: FrameReader::new(reader),
139 writer: FrameWriter::new(writer),
140 }
141 }
142
143 pub async fn read_frame(&mut self) -> Result<Option<Frame>, TransportError> {
145 self.reader.read_frame().await
146 }
147
148 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), TransportError> {
150 self.writer.write_frame(frame).await
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::frame::FrameType;

    /// Frames written by a `FrameWriter` come back intact and in order, and a
    /// clean close after the last frame yields `Ok(None)`.
    #[tokio::test]
    async fn test_reader_writer_roundtrip() {
        let (client, server) = tokio::io::duplex(1024);
        let (cr, cw) = tokio::io::split(client);
        let (sr, _sw) = tokio::io::split(server);

        let mut writer = FrameWriter::new(cw);
        let mut reader = FrameReader::new(sr);

        writer.write_data(b"hello").await.unwrap();
        writer.write_frame(&Frame::heartbeat()).await.unwrap();
        // Dropping both halves of the client closes the stream, so the
        // server-side reader observes EOF once the buffered frames are consumed.
        drop(writer);
        drop(cr);

        let f1 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(f1.frame_type, FrameType::Data);
        assert_eq!(f1.payload, b"hello");

        let f2 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(f2.frame_type, FrameType::Heartbeat);
        assert!(f2.payload.is_empty());

        let f3 = reader.read_frame().await.unwrap();
        assert!(f3.is_none());
    }

    /// Two codecs over one duplex pipe can exchange frames in both directions.
    #[tokio::test]
    async fn test_codec_bidirectional() {
        let (a, b) = tokio::io::duplex(1024);
        let (ar, aw) = tokio::io::split(a);
        let (br, bw) = tokio::io::split(b);

        let mut codec_a = FrameCodec::new(ar, aw);
        let mut codec_b = FrameCodec::new(br, bw);

        codec_a.writer.write_data(b"ping").await.unwrap();
        let frame = codec_b.reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.payload, b"ping");

        codec_b.writer.write_data(b"pong").await.unwrap();
        let frame = codec_a.reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.payload, b"pong");
    }

    /// `write_json` produces a data frame whose payload deserializes back to the value.
    #[tokio::test]
    async fn test_write_json() {
        let (client, server) = tokio::io::duplex(1024);
        let (_, cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let mut writer = FrameWriter::new(cw);
        let mut reader = FrameReader::new(sr);

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Msg {
            text: String,
        }

        writer
            .write_json(&Msg {
                text: "hello".into(),
            })
            .await
            .unwrap();
        drop(writer);

        let frame = reader.read_frame().await.unwrap().unwrap();
        let msg: Msg = serde_json::from_slice(&frame.payload).unwrap();
        assert_eq!(msg.text, "hello");
    }

    /// `write_control` produces a frame of the same type as `Frame::control`.
    #[tokio::test]
    async fn test_write_control_roundtrip() {
        let (client, server) = tokio::io::duplex(1024);
        let (_, cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let mut writer = FrameWriter::new(cw);
        writer.write_control(b"ctl").await.unwrap();
        drop(writer);

        let mut reader = FrameReader::new(sr);
        let frame = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.frame_type, Frame::control(Vec::new()).frame_type);
        assert_eq!(frame.payload, b"ctl");
    }

    /// EOF in the middle of a frame is reported as `RecvFailed`, not as a clean close.
    #[tokio::test]
    async fn test_reader_incomplete_frame_on_eof() {
        let (client, server) = tokio::io::duplex(1024);
        let (_, mut cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let mut reader = FrameReader::new(sr);

        // A truncated frame prefix, then close the connection.
        cw.write_all(&[0x01, 0x00, 0x00]).await.unwrap();
        drop(cw);

        let result = reader.read_frame().await;
        // Pin the specific error: a bare `is_err()` would also pass if decoding
        // rejected these bytes for an unrelated reason.
        assert!(
            matches!(result, Err(TransportError::RecvFailed(_))),
            "expected RecvFailed for a frame truncated by EOF"
        );
    }

    /// Several frames delivered in a single read are returned one per call.
    #[tokio::test]
    async fn test_multiple_frames_in_one_read() {
        let (client, server) = tokio::io::duplex(4096);
        let (_, mut cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        let f1 = Frame::data(b"first".to_vec());
        let f2 = Frame::data(b"second".to_vec());
        let mut buf = f1.encode().unwrap();
        buf.extend_from_slice(&f2.encode().unwrap());
        cw.write_all(&buf).await.unwrap();
        drop(cw);

        let mut reader = FrameReader::new(sr);
        let r1 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(r1.payload, b"first");
        let r2 = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(r2.payload, b"second");
        assert!(reader.read_frame().await.unwrap().is_none());
    }

    /// A frame that arrives one byte at a time is reassembled by the read loop.
    #[tokio::test]
    async fn test_frame_reassembled_from_small_reads() {
        // A 1-byte duplex buffer forces every read to return at most one byte.
        let (client, server) = tokio::io::duplex(1);
        let (_, cw) = tokio::io::split(client);
        let (sr, _) = tokio::io::split(server);

        // The writer must run concurrently: with a 1-byte pipe, `write_all`
        // blocks until the reader drains each byte.
        let writer_task = tokio::spawn(async move {
            let mut writer = FrameWriter::new(cw);
            writer.write_data(b"fragmented").await.unwrap();
            drop(writer);
        });

        let mut reader = FrameReader::new(sr);
        let frame = reader.read_frame().await.unwrap().unwrap();
        assert_eq!(frame.payload, b"fragmented");

        writer_task.await.unwrap();
        assert!(reader.read_frame().await.unwrap().is_none());
    }
}