1use bytes::{Bytes, BytesMut};
8use po_transport::traits::{AsyncFrameTransport, TransportError};
9use po_wire::{FrameHeader, WireError};
10
/// Default cap on a frame's declared payload length (10 MiB), used by `Framer::new`.
const DEFAULT_MAX_FRAME_SIZE: u64 = 10 << 20;
13
14pub struct Framer {
16 read_buf: BytesMut,
18 max_frame_size: u64,
20}
21
22impl Default for Framer {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl Framer {
29 pub fn new() -> Self {
31 Self {
32 read_buf: BytesMut::with_capacity(65536),
33 max_frame_size: DEFAULT_MAX_FRAME_SIZE,
34 }
35 }
36
37 pub fn with_max_frame_size(mut self, max: u64) -> Self {
39 self.max_frame_size = max;
40 self
41 }
42
43 pub async fn write_frame(
50 &self,
51 transport: &mut dyn AsyncFrameTransport,
52 header: &FrameHeader,
53 payload: &[u8],
54 ) -> Result<(), FramerError> {
55 let header_len = header.encoded_len();
56 let total_len = header_len + payload.len();
57
58 let mut combined = Vec::with_capacity(total_len);
62 combined.resize(header_len, 0u8);
63 header
64 .encode(&mut combined[..header_len])
65 .map_err(FramerError::Wire)?;
66 combined.extend_from_slice(payload);
67
68 transport
69 .write_all(&combined)
70 .await
71 .map_err(FramerError::Transport)?;
72
73 Ok(())
74 }
75
76 pub async fn read_frame(
82 &mut self,
83 transport: &mut dyn AsyncFrameTransport,
84 ) -> Result<Option<(FrameHeader, Bytes)>, FramerError> {
85 loop {
86 if let Some((header, header_len)) = self.try_parse_header()? {
88 if header.payload_len > self.max_frame_size {
90 return Err(FramerError::Wire(WireError::PayloadTooLarge {
91 declared: header.payload_len,
92 max_allowed: self.max_frame_size,
93 }));
94 }
95
96 let total_needed = header_len + header.payload_len as usize;
97
98 if self.read_buf.len() >= total_needed {
100 let _ = self.read_buf.split_to(header_len);
102 let payload = self.read_buf.split_to(header.payload_len as usize).freeze();
104 return Ok(Some((header, payload)));
105 }
106
107 let still_needed = total_needed - self.read_buf.len();
109 if !self.fill_buffer(transport, still_needed).await? {
110 return Ok(None); }
112 continue;
113 }
114
115 if !self.fill_buffer(transport, 1).await? {
117 if self.read_buf.is_empty() {
118 return Ok(None); }
120 return Err(FramerError::Wire(WireError::Incomplete {
121 needed_min: 4,
122 available: self.read_buf.len(),
123 }));
124 }
125 }
126 }
127
128 fn try_parse_header(&self) -> Result<Option<(FrameHeader, usize)>, FramerError> {
130 if self.read_buf.is_empty() {
131 return Ok(None);
132 }
133 match FrameHeader::decode(&self.read_buf) {
134 Ok((header, len)) => Ok(Some((header, len))),
135 Err(WireError::Incomplete { .. }) => Ok(None), Err(e) => Err(FramerError::Wire(e)),
137 }
138 }
139
140 async fn fill_buffer(
143 &mut self,
144 transport: &mut dyn AsyncFrameTransport,
145 min_bytes: usize,
146 ) -> Result<bool, FramerError> {
147 let mut total = 0;
148 let mut tmp = [0u8; 65536];
149
150 while total < min_bytes {
151 match transport.read(&mut tmp).await {
152 Ok(n) => {
153 self.read_buf.extend_from_slice(&tmp[..n]);
154 total += n;
155 }
156 Err(TransportError::ConnectionClosed) => {
157 return Ok(false);
158 }
159 Err(e) => return Err(FramerError::Transport(e)),
160 }
161 }
162
163 Ok(true)
164 }
165
166 pub fn buffered(&self) -> usize {
168 self.read_buf.len()
169 }
170}
171
172#[derive(Debug)]
174pub enum FramerError {
175 Wire(WireError),
176 Transport(TransportError),
177}
178
179impl std::fmt::Display for FramerError {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 match self {
182 Self::Wire(e) => write!(f, "wire: {e}"),
183 Self::Transport(e) => write!(f, "transport: {e}"),
184 }
185 }
186}
187
188impl std::error::Error for FramerError {}
189
#[cfg(test)]
mod tests {
    use super::*;
    use po_transport::MemoryTransport;
    use po_wire::FrameType;

    /// A data frame written by one framer is read back intact by another.
    #[tokio::test]
    async fn write_and_read_data_frame() {
        let (mut tx, mut rx) = MemoryTransport::pair(64);
        let writer = Framer::new();
        let mut reader = Framer::new();

        let body = b"Hello Protocol Orzatty!";
        let sent = FrameHeader::data(0, body.len() as u64);
        writer.write_frame(&mut tx, &sent, body).await.unwrap();

        let (header, payload) = reader.read_frame(&mut rx).await.unwrap().unwrap();
        assert_eq!(header.frame_type, FrameType::Data);
        assert_eq!(payload.as_ref(), body);
    }

    /// A control frame with an empty payload round-trips with its control flag set.
    #[tokio::test]
    async fn write_and_read_control_frame() {
        let (mut tx, mut rx) = MemoryTransport::pair(64);
        let writer = Framer::new();
        let mut reader = Framer::new();

        let sent = FrameHeader::control(FrameType::Ping);
        writer.write_frame(&mut tx, &sent, &[]).await.unwrap();

        let (header, payload) = reader.read_frame(&mut rx).await.unwrap().unwrap();
        assert_eq!(header.frame_type, FrameType::Ping);
        assert!(header.flags.control);
        assert!(payload.is_empty());
    }

    /// Back-to-back frames of varying sizes are split apart correctly and in order.
    #[tokio::test]
    async fn multiple_frames_sequential() {
        let (mut tx, mut rx) = MemoryTransport::pair(64);
        let writer = Framer::new();
        let mut reader = Framer::new();

        for fill in 0u8..10 {
            let body = vec![fill; usize::from(fill + 1) * 10];
            let sent = FrameHeader::data(u32::from(fill), body.len() as u64);
            writer.write_frame(&mut tx, &sent, &body).await.unwrap();
        }

        for fill in 0u8..10 {
            let (header, payload) = reader.read_frame(&mut rx).await.unwrap().unwrap();
            assert_eq!(header.channel_id, u32::from(fill));
            assert_eq!(payload.len(), usize::from(fill + 1) * 10);
            assert!(payload.iter().all(|&byte| byte == fill));
        }
    }

    /// Closing the peer before anything is sent is reported as a clean end of stream.
    #[tokio::test]
    async fn eof_returns_none() {
        let (tx, mut rx) = MemoryTransport::pair(64);
        let mut reader = Framer::new();

        drop(tx);
        let outcome = reader.read_frame(&mut rx).await.unwrap();
        assert!(outcome.is_none());
    }

    /// A payload larger than one 64 KiB read chunk is reassembled correctly.
    #[tokio::test]
    async fn large_payload() {
        let (mut tx, mut rx) = MemoryTransport::pair(256);
        let writer = Framer::new();
        let mut reader = Framer::new();

        let body = vec![0xAB; 100_000];
        let sent = FrameHeader::data(1, body.len() as u64);
        writer.write_frame(&mut tx, &sent, &body).await.unwrap();

        let (header, payload) = reader.read_frame(&mut rx).await.unwrap().unwrap();
        assert_eq!(header.payload_len, 100_000);
        assert_eq!(payload.as_ref(), body.as_slice());
    }
}