Skip to main content

po_session/
framer.rs

//! Frame reader/writer for PO connections.
//!
//! The `Framer` handles buffering, fragmentation, and reassembly of PO frames
//! over any `AsyncFrameTransport`. It handles both reading and writing, unlike
//! the original implementation, which only handled reads.
6
7use bytes::{Bytes, BytesMut};
8use po_transport::traits::{AsyncFrameTransport, TransportError};
9use po_wire::{FrameHeader, WireError};
10
/// Default upper bound on an incoming frame's payload: 10 MiB
/// (10 × 1024 × 1024 bytes).
const DEFAULT_MAX_FRAME_SIZE: u64 = 10 << 20;
13
14/// Frame reader/writer that handles buffering and reassembly.
15pub struct Framer {
16    /// Read buffer for accumulating incoming bytes.
17    read_buf: BytesMut,
18    /// Maximum allowed payload size.
19    max_frame_size: u64,
20}
21
22impl Default for Framer {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl Framer {
29    /// Create a new `Framer` with default settings.
30    pub fn new() -> Self {
31        Self {
32            read_buf: BytesMut::with_capacity(65536),
33            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
34        }
35    }
36
37    /// Set the maximum allowed payload size for incoming frames.
38    pub fn with_max_frame_size(mut self, max: u64) -> Self {
39        self.max_frame_size = max;
40        self
41    }
42
43    // ─── Writing ────────────────────────────────────────────────────────
44
45    /// Write a complete frame (header + payload) to the transport.
46    ///
47    /// Coalesces header + payload into a single `write_all` call to minimize
48    /// syscalls and QUIC packet fragmentation.
49    pub async fn write_frame(
50        &self,
51        transport: &mut dyn AsyncFrameTransport,
52        header: &FrameHeader,
53        payload: &[u8],
54    ) -> Result<(), FramerError> {
55        let header_len = header.encoded_len();
56        let total_len = header_len + payload.len();
57
58        // Coalesce header + payload into a single buffer for one write_all call.
59        // Stack-allocated for small frames (≤ 25-byte header + no payload),
60        // heap-allocated otherwise.
61        let mut combined = Vec::with_capacity(total_len);
62        combined.resize(header_len, 0u8);
63        header
64            .encode(&mut combined[..header_len])
65            .map_err(FramerError::Wire)?;
66        combined.extend_from_slice(payload);
67
68        transport
69            .write_all(&combined)
70            .await
71            .map_err(FramerError::Transport)?;
72
73        Ok(())
74    }
75
76    // ─── Reading ────────────────────────────────────────────────────────
77
78    /// Read the next complete frame from the transport.
79    ///
80    /// Returns `None` if the connection was cleanly closed.
81    pub async fn read_frame(
82        &mut self,
83        transport: &mut dyn AsyncFrameTransport,
84    ) -> Result<Option<(FrameHeader, Bytes)>, FramerError> {
85        loop {
86            // 1. Try to parse a header from the current buffer
87            if let Some((header, header_len)) = self.try_parse_header()? {
88                // Validate payload size
89                if header.payload_len > self.max_frame_size {
90                    return Err(FramerError::Wire(WireError::PayloadTooLarge {
91                        declared: header.payload_len,
92                        max_allowed: self.max_frame_size,
93                    }));
94                }
95
96                let total_needed = header_len + header.payload_len as usize;
97
98                // 2. Do we have the full frame?
99                if self.read_buf.len() >= total_needed {
100                    // Consume header bytes
101                    let _ = self.read_buf.split_to(header_len);
102                    // Consume payload bytes
103                    let payload = self.read_buf.split_to(header.payload_len as usize).freeze();
104                    return Ok(Some((header, payload)));
105                }
106
107                // 3. Need more bytes — read at least enough for the payload
108                let still_needed = total_needed - self.read_buf.len();
109                if !self.fill_buffer(transport, still_needed).await? {
110                    return Ok(None); // Connection closed
111                }
112                continue;
113            }
114
115            // 4. Not enough data for a header — read more
116            if !self.fill_buffer(transport, 1).await? {
117                if self.read_buf.is_empty() {
118                    return Ok(None); // Clean EOF
119                }
120                return Err(FramerError::Wire(WireError::Incomplete {
121                    needed_min: 4,
122                    available: self.read_buf.len(),
123                }));
124            }
125        }
126    }
127
128    /// Try to parse a `FrameHeader` from the current read buffer.
129    fn try_parse_header(&self) -> Result<Option<(FrameHeader, usize)>, FramerError> {
130        if self.read_buf.is_empty() {
131            return Ok(None);
132        }
133        match FrameHeader::decode(&self.read_buf) {
134            Ok((header, len)) => Ok(Some((header, len))),
135            Err(WireError::Incomplete { .. }) => Ok(None), // Need more data
136            Err(e) => Err(FramerError::Wire(e)),
137        }
138    }
139
140    /// Read at least `min_bytes` additional bytes into the buffer.
141    /// Returns `false` if the connection was closed before any bytes were read.
142    async fn fill_buffer(
143        &mut self,
144        transport: &mut dyn AsyncFrameTransport,
145        min_bytes: usize,
146    ) -> Result<bool, FramerError> {
147        let mut total = 0;
148        let mut tmp = [0u8; 65536];
149
150        while total < min_bytes {
151            match transport.read(&mut tmp).await {
152                Ok(n) => {
153                    self.read_buf.extend_from_slice(&tmp[..n]);
154                    total += n;
155                }
156                Err(TransportError::ConnectionClosed) => {
157                    return Ok(false);
158                }
159                Err(e) => return Err(FramerError::Transport(e)),
160            }
161        }
162
163        Ok(true)
164    }
165
166    /// Number of buffered bytes not yet consumed.
167    pub fn buffered(&self) -> usize {
168        self.read_buf.len()
169    }
170}
171
172/// Errors from the framing layer.
173#[derive(Debug)]
174pub enum FramerError {
175    Wire(WireError),
176    Transport(TransportError),
177}
178
179impl std::fmt::Display for FramerError {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        match self {
182            Self::Wire(e) => write!(f, "wire: {e}"),
183            Self::Transport(e) => write!(f, "transport: {e}"),
184        }
185    }
186}
187
188impl std::error::Error for FramerError {}
189
#[cfg(test)]
mod tests {
    use super::*;
    use po_transport::MemoryTransport;
    use po_wire::FrameType;

    #[tokio::test]
    async fn write_and_read_data_frame() {
        let (mut a, mut b) = MemoryTransport::pair(64);
        let framer_w = Framer::new();
        let mut framer_r = Framer::new();

        let payload = b"Hello Protocol Orzatty!";
        let header = FrameHeader::data(0, payload.len() as u64);

        framer_w
            .write_frame(&mut a, &header, payload)
            .await
            .unwrap();

        let (recv_header, recv_payload) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
        assert_eq!(recv_header.frame_type, FrameType::Data);
        assert_eq!(recv_payload.as_ref(), payload);
    }

    #[tokio::test]
    async fn write_and_read_control_frame() {
        let (mut a, mut b) = MemoryTransport::pair(64);
        let framer_w = Framer::new();
        let mut framer_r = Framer::new();

        let header = FrameHeader::control(FrameType::Ping);
        framer_w.write_frame(&mut a, &header, &[]).await.unwrap();

        let (recv_header, recv_payload) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
        assert_eq!(recv_header.frame_type, FrameType::Ping);
        assert!(recv_header.flags.control);
        assert!(recv_payload.is_empty());
    }

    #[tokio::test]
    async fn multiple_frames_sequential() {
        let (mut a, mut b) = MemoryTransport::pair(64);
        let framer_w = Framer::new();
        let mut framer_r = Framer::new();

        for i in 0u8..10 {
            let payload = vec![i; (i as usize + 1) * 10];
            let header = FrameHeader::data(i as u32, payload.len() as u64);
            framer_w
                .write_frame(&mut a, &header, &payload)
                .await
                .unwrap();
        }

        for i in 0u8..10 {
            let (h, p) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
            assert_eq!(h.channel_id, i as u32);
            assert_eq!(p.len(), (i as usize + 1) * 10);
            // `byte`, not `b`, to avoid shadowing the transport binding.
            assert!(p.iter().all(|&byte| byte == i));
        }
    }

    #[tokio::test]
    async fn eof_returns_none() {
        let (a, mut b) = MemoryTransport::pair(64);
        let mut framer_r = Framer::new();

        drop(a); // Close the writer

        let result = framer_r.read_frame(&mut b).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn large_payload() {
        let (mut a, mut b) = MemoryTransport::pair(256);
        let framer_w = Framer::new();
        let mut framer_r = Framer::new();

        let payload = vec![0xAB; 100_000]; // 100KB
        let header = FrameHeader::data(1, payload.len() as u64);
        framer_w
            .write_frame(&mut a, &header, &payload)
            .await
            .unwrap();

        let (h, p) = framer_r.read_frame(&mut b).await.unwrap().unwrap();
        assert_eq!(h.payload_len, 100_000);
        assert_eq!(p.as_ref(), payload.as_slice());
    }

    #[tokio::test]
    async fn oversized_payload_is_rejected() {
        let (mut a, mut b) = MemoryTransport::pair(64);
        let framer_w = Framer::new();
        let mut framer_r = Framer::new().with_max_frame_size(16);

        let payload = vec![0x5A; 32];
        let header = FrameHeader::data(0, payload.len() as u64);
        framer_w
            .write_frame(&mut a, &header, &payload)
            .await
            .unwrap();

        match framer_r.read_frame(&mut b).await {
            Err(FramerError::Wire(WireError::PayloadTooLarge {
                declared,
                max_allowed,
            })) => {
                assert_eq!(declared, 32);
                assert_eq!(max_allowed, 16);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("oversized frame was accepted"),
        }
    }
}