Skip to main content

po_wire/
header.rs

//! PO Frame Header — the compact binary header that precedes every frame.
//!
//! ## Wire Layout
//!
//! ```text
//! ┌─────────┬──────────────┬──────────────┬──────────────┐
//! │ Byte 0  │ VarInt       │ VarInt       │ VarInt       │
//! │ Type +  │ Channel ID   │ Stream ID    │ Payload Len  │
//! │ Flags   │ (1–8 bytes)  │ (1–8 bytes)  │ (1–8 bytes)  │
//! └─────────┴──────────────┴──────────────┴──────────────┘
//!
//! Byte 0 bit layout:
//! ┌───┬───┬───┬───┬───┬───┬───┬───┐
//! │ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │
//! │CTL│PRI│ENC│RSV│    FrameType   │
//! └───┴───┴───┴───┴───┴───┴───┴───┘
//! ```
//!
//! **Minimum header size: 4 bytes** (1 type/flags + 3×1-byte VarInts).
//! **Maximum header size: 25 bytes** (1 type/flags + 3×8-byte VarInts).
21
22use crate::error::WireError;
23use crate::frame_type::FrameType;
24use crate::varint;
25
26// --- Flag bit positions in byte 0 ---
27
/// Bit 7 of byte 0: marks a control frame (as opposed to application data).
const FLAG_CONTROL: u8 = 1 << 7;
/// Bit 6 of byte 0: high priority — process immediately, bypass queues.
const FLAG_PRIORITY: u8 = 1 << 6;
/// Bit 5 of byte 0: the payload is encrypted with the session cipher.
const FLAG_ENCRYPTED: u8 = 1 << 5;
/// Bit 4 of byte 0: reserved for future use.
const _FLAG_RESERVED: u8 = 1 << 4;
/// Selects the frame type, which lives in the low nibble (bits 3..=0) of byte 0.
const TYPE_MASK: u8 = 0b0000_1111;
38
/// Per-frame processing flags carried in the upper bits of header byte 0.
///
/// The `Default` value has every flag cleared.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct FrameFlags {
    /// Set for control frames (ping, pong, close).
    pub control: bool,
    /// Set for high-priority frames that should bypass normal processing queues.
    pub priority: bool,
    /// Set when the payload is encrypted with the session's ChaCha20-Poly1305 cipher.
    pub encrypted: bool,
}
49
50impl FrameFlags {
51    /// Encode flags into the upper 4 bits of byte 0.
52    #[inline]
53    const fn to_bits(self) -> u8 {
54        let mut bits = 0u8;
55        if self.control { bits |= FLAG_CONTROL; }
56        if self.priority { bits |= FLAG_PRIORITY; }
57        if self.encrypted { bits |= FLAG_ENCRYPTED; }
58        bits
59    }
60
61    /// Decode flags from byte 0.
62    #[inline]
63    const fn from_bits(byte: u8) -> Self {
64        Self {
65            control: byte & FLAG_CONTROL != 0,
66            priority: byte & FLAG_PRIORITY != 0,
67            encrypted: byte & FLAG_ENCRYPTED != 0,
68        }
69    }
70}
71
72/// A decoded PO frame header.
73///
74/// This is a pure value type — it does not own or reference the payload.
75/// After decoding the header, read `payload_len` bytes from the transport
76/// to get the payload.
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78pub struct FrameHeader {
79    /// The type of this frame (Data, Handshake, Ping, etc.).
80    pub frame_type: FrameType,
81    /// Processing flags (control, priority, encrypted).
82    pub flags: FrameFlags,
83    /// Logical channel for application-level multiplexing.
84    pub channel_id: u32,
85    /// Stream identifier for QUIC-like concurrent streams within a channel.
86    pub stream_id: u64,
87    /// Length of the payload that follows this header.
88    pub payload_len: u64,
89}
90
91impl FrameHeader {
92    /// Create a new header for a data frame with default flags.
93    #[inline]
94    pub const fn data(channel_id: u32, payload_len: u64) -> Self {
95        Self {
96            frame_type: FrameType::Data,
97            flags: FrameFlags { control: false, priority: false, encrypted: false },
98            channel_id,
99            stream_id: 0,
100            payload_len,
101        }
102    }
103
104    /// Create a new control frame header (e.g., Ping, Pong, Close).
105    #[inline]
106    pub const fn control(frame_type: FrameType) -> Self {
107        Self {
108            frame_type,
109            flags: FrameFlags { control: true, priority: false, encrypted: false },
110            channel_id: 0,
111            stream_id: 0,
112            payload_len: 0,
113        }
114    }
115
116    /// Set the encrypted flag on this header.
117    #[inline]
118    pub const fn with_encrypted(mut self) -> Self {
119        self.flags.encrypted = true;
120        self
121    }
122
123    /// Set the priority flag on this header.
124    #[inline]
125    pub const fn with_priority(mut self) -> Self {
126        self.flags.priority = true;
127        self
128    }
129
130    /// Set the stream ID.
131    #[inline]
132    pub const fn with_stream(mut self, stream_id: u64) -> Self {
133        self.stream_id = stream_id;
134        self
135    }
136
137    /// Calculate the exact number of bytes this header will occupy when encoded.
138    #[inline]
139    pub const fn encoded_len(&self) -> usize {
140        1 // byte 0 (type + flags)
141        + varint::encoded_len(self.channel_id as u64)
142        + varint::encoded_len(self.stream_id)
143        + varint::encoded_len(self.payload_len)
144    }
145
146    /// Encode this header into the provided buffer.
147    ///
148    /// Returns the number of bytes written.
149    pub fn encode(&self, buf: &mut [u8]) -> Result<usize, WireError> {
150        let needed = self.encoded_len();
151        if buf.len() < needed {
152            return Err(WireError::BufferTooSmall { needed, available: buf.len() });
153        }
154
155        let mut offset = 0;
156
157        // Byte 0: flags (upper 4 bits) + frame type (lower 4 bits)
158        buf[0] = self.flags.to_bits() | (self.frame_type as u8 & TYPE_MASK);
159        offset += 1;
160
161        // VarInt: channel_id
162        offset += varint::encode(self.channel_id as u64, &mut buf[offset..])?;
163
164        // VarInt: stream_id
165        offset += varint::encode(self.stream_id, &mut buf[offset..])?;
166
167        // VarInt: payload_len
168        offset += varint::encode(self.payload_len, &mut buf[offset..])?;
169
170        debug_assert_eq!(offset, needed);
171        Ok(offset)
172    }
173
174    /// Decode a header from the provided buffer.
175    ///
176    /// Returns `(header, bytes_consumed)`.
177    ///
178    /// # Errors
179    /// - `WireError::Incomplete` if the buffer doesn't contain a complete header.
180    /// - `WireError::UnknownFrameType` if the type nibble is invalid.
181    pub fn decode(buf: &[u8]) -> Result<(Self, usize), WireError> {
182        if buf.is_empty() {
183            return Err(WireError::Incomplete { needed_min: 4, available: 0 });
184        }
185
186        let byte0 = buf[0];
187        let mut offset = 1;
188
189        // Decode flags
190        let flags = FrameFlags::from_bits(byte0);
191
192        // Decode frame type from lower 4 bits
193        let frame_type = FrameType::from_u8(byte0 & TYPE_MASK)?;
194
195        // Decode channel_id
196        let (channel_raw, n) = varint::decode(&buf[offset..]).map_err(|e| match e {
197            WireError::Incomplete { needed_min, .. } => WireError::Incomplete {
198                needed_min: offset + needed_min,
199                available: buf.len(),
200            },
201            other => other,
202        })?;
203        offset += n;
204
205        // Decode stream_id
206        let (stream_id, n) = varint::decode(&buf[offset..]).map_err(|e| match e {
207            WireError::Incomplete { needed_min, .. } => WireError::Incomplete {
208                needed_min: offset + needed_min,
209                available: buf.len(),
210            },
211            other => other,
212        })?;
213        offset += n;
214
215        // Decode payload_len
216        let (payload_len, n) = varint::decode(&buf[offset..]).map_err(|e| match e {
217            WireError::Incomplete { needed_min, .. } => WireError::Incomplete {
218                needed_min: offset + needed_min,
219                available: buf.len(),
220            },
221            other => other,
222        })?;
223        offset += n;
224
225        Ok((
226            Self {
227                frame_type,
228                flags,
229                channel_id: channel_raw as u32,
230                stream_id,
231                payload_len,
232            },
233            offset,
234        ))
235    }
236}
237
238impl core::fmt::Display for FrameHeader {
239    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
240        write!(
241            f,
242            "[{}] ch={} stream={} len={} flags=[{}{}{}]",
243            self.frame_type,
244            self.channel_id,
245            self.stream_id,
246            self.payload_len,
247            if self.flags.control { "C" } else { "" },
248            if self.flags.priority { "P" } else { "" },
249            if self.flags.encrypted { "E" } else { "" },
250        )
251    }
252}
253
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::format;

    /// Size of the scratch buffer used by the tests.
    const SCRATCH: usize = 32;

    /// Encode `header` into a fresh scratch buffer; returns the buffer and the
    /// number of bytes written.
    fn encode_to_scratch(header: &FrameHeader) -> ([u8; SCRATCH], usize) {
        let mut buf = [0u8; SCRATCH];
        let written = header.encode(&mut buf).unwrap();
        (buf, written)
    }

    /// Decode `encoded` and check that it yields `header` and consumes every byte.
    fn assert_roundtrip(header: &FrameHeader, encoded: &[u8]) {
        let (decoded, consumed) = FrameHeader::decode(encoded).unwrap();
        assert_eq!(decoded, *header);
        assert_eq!(consumed, encoded.len());
    }

    /// Header with default flags, channel 0 and stream 0.
    fn plain_header(frame_type: FrameType, payload_len: u64) -> FrameHeader {
        FrameHeader {
            frame_type,
            flags: FrameFlags::default(),
            channel_id: 0,
            stream_id: 0,
            payload_len,
        }
    }

    #[test]
    fn minimal_data_header() {
        // Smallest possible header: Data type, channel 0, stream 0, payload 0.
        let header = FrameHeader::data(0, 0);
        let (buf, written) = encode_to_scratch(&header);

        assert_eq!(written, 4, "minimum header should be 4 bytes");
        assert_eq!(buf[0] & TYPE_MASK, 0x00); // Data type
        assert_eq!(buf[0] & 0xF0, 0x00); // No flags

        assert_roundtrip(&header, &buf[..written]);
    }

    #[test]
    fn data_with_flags() {
        let header = FrameHeader::data(1, 100)
            .with_encrypted()
            .with_priority()
            .with_stream(42);
        let (buf, written) = encode_to_scratch(&header);

        // Byte 0: PRI(0x40) + ENC(0x20) + Data(0x00) = 0x60
        assert_eq!(buf[0], 0x60);

        assert_roundtrip(&header, &buf[..written]);
    }

    #[test]
    fn control_ping() {
        let header = FrameHeader::control(FrameType::Ping);
        let (buf, written) = encode_to_scratch(&header);

        // Byte 0: CONTROL(0x80) + Ping(0x04) = 0x84
        assert_eq!(buf[0], 0x84);
        assert_eq!(written, 4); // 1 + three 1-byte varints (all 0)

        let (decoded, _) = FrameHeader::decode(&buf[..written]).unwrap();
        assert!(decoded.flags.control);
        assert_eq!(decoded.frame_type, FrameType::Ping);
    }

    #[test]
    fn handshake_type() {
        let handshakes = [
            FrameType::HandshakeInit,
            FrameType::HandshakeReply,
            FrameType::HandshakeComplete,
        ];
        for ft in handshakes {
            let (buf, written) = encode_to_scratch(&plain_header(ft, 128));
            let (decoded, _) = FrameHeader::decode(&buf[..written]).unwrap();
            assert_eq!(decoded.frame_type, ft);
            assert!(decoded.frame_type.is_handshake());
        }
    }

    #[test]
    fn large_values() {
        let header = FrameHeader {
            frame_type: FrameType::FileChunk,
            flags: FrameFlags { control: false, priority: true, encrypted: true },
            channel_id: 1_000_000,
            stream_id: 9_999_999_999,
            payload_len: 4_294_967_296, // 4GB
        };
        let (buf, written) = encode_to_scratch(&header);
        assert_roundtrip(&header, &buf[..written]);
    }

    #[test]
    fn encoded_len_accurate() {
        let header = FrameHeader::data(42, 100).with_stream(12345);
        let predicted = header.encoded_len();
        let written = header.encode(&mut [0u8; 32]).unwrap();
        assert_eq!(predicted, written);
    }

    #[test]
    fn incomplete_decode() {
        // Only 2 bytes — not enough for a full header.
        let buf = [0x00, 0x00];
        let other = FrameHeader::decode(&buf);
        if !matches!(other, Err(WireError::Incomplete { .. })) {
            panic!("expected Incomplete, got {other:?}");
        }
    }

    #[test]
    fn unknown_type_rejected() {
        // Byte 0 with type nibble = 0x0F (reserved).
        let buf = [0x0F, 0x00, 0x00, 0x00];
        let result = FrameHeader::decode(&buf);
        assert!(matches!(result, Err(WireError::UnknownFrameType(0x0F))));
    }

    #[test]
    fn display_format() {
        let header = FrameHeader::data(5, 42).with_encrypted();
        let rendered = format!("{header}");
        for needle in ["DATA", "ch=5", "len=42", "E"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn all_frame_types_encode_decode() {
        let types = [
            FrameType::Data,
            FrameType::HandshakeInit,
            FrameType::HandshakeReply,
            FrameType::HandshakeComplete,
            FrameType::Ping,
            FrameType::Pong,
            FrameType::Close,
            FrameType::FileHeader,
            FrameType::FileChunk,
            FrameType::Ack,
        ];
        for ft in types {
            let (buf, written) = encode_to_scratch(&plain_header(ft, 0));
            let (decoded, _) = FrameHeader::decode(&buf[..written]).unwrap();
            assert_eq!(decoded.frame_type, ft, "type {ft} failed roundtrip");
        }
    }
}