Skip to main content

veyron_wire/
framing.rs

1use crate::error::WireError;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5
/// Largest payload, in bytes, that may be written or accepted (1 MiB). Also the
/// cap on the size of a decompressed payload.
pub const MAX_PAYLOAD_SIZE: usize = 1 << 20;
/// Header magic: the ASCII bytes `"VR"` read big-endian (0x5652).
const MAGIC: u16 = u16::from_be_bytes([b'V', b'R']);
/// Fixed header length: magic(2) + flags(2) + length(4) + target(32) + crc32(4).
const HEADER_SIZE: usize = 2 + 2 + 4 + 32 + 4;
9
/// `flags` bit: a 32-byte HMAC tag follows the payload on the wire.
pub const FLAG_MAC_PRESENT: u16 = 1 << 0;

/// `flags` bit: the payload on the wire is zstd-compressed.
///
/// The header CRC32 covers the compressed bytes, i.e. exactly what is
/// transmitted. The framing layer decompresses before delivery and clears this
/// bit on the frame it hands up.
pub const FLAG_COMPRESSED: u16 = 1 << 1;

/// `flags` bit: this frame is one fragment of a larger message. Its payload
/// starts with [`FRAG_HEADER_SIZE`] bytes of fragment metadata, followed by the
/// chunk itself.
pub const FLAG_FRAGMENTED: u16 = 1 << 2;

/// `flags` bit: the payload is raw binary (PCM/Opus audio), so the router skips
/// Protobuf decoding. Payloads carrying this bit are never auto-compressed.
pub const FLAG_RAW_BINARY: u16 = 1 << 4;

/// Payloads of at least this many bytes (64 KiB) are candidates for automatic
/// zstd compression on write.
pub const COMPRESS_THRESHOLD: usize = 64 * 1024;
26
/// Length in bytes of the fragment metadata that prefixes the payload of a
/// frame carrying [`FLAG_FRAGMENTED`].
///
/// Layout (all big-endian):
///   [fragment_id: u16][sequence: u16][total: u16][stream_id: u32]
pub const FRAG_HEADER_SIZE: usize = 2 + 2 + 2 + 4;

/// Decoded form of the [`FRAG_HEADER_SIZE`]-byte fragment metadata.
#[derive(Clone, Copy, Debug)]
pub struct FragmentHeader {
    /// Opaque id of the fragmented message within its stream. Parsed from the
    /// wire and exposed to callers; not interpreted by the kernel.
    pub fragment_id: u16,
    /// Zero-based index of this fragment within the message.
    pub sequence: u16,
    /// Number of fragments the original message was split into.
    pub total: u16,
    /// Stream identifier; used as the key for reassembly buffers.
    pub stream_id: u32,
}

/// Decodes the [`FragmentHeader`] at the start of `payload`.
///
/// Only the first [`FRAG_HEADER_SIZE`] bytes are read; the fragment chunk that
/// follows is ignored. Decoded values are not range-checked (for example,
/// `sequence < total` is not enforced here).
///
/// Returns `None` when `payload` is shorter than [`FRAG_HEADER_SIZE`].
pub fn parse_frag_header(payload: &[u8]) -> Option<FragmentHeader> {
    let meta = payload.get(..FRAG_HEADER_SIZE)?;
    let be16 = |at: usize| u16::from_be_bytes([meta[at], meta[at + 1]]);
    Some(FragmentHeader {
        fragment_id: be16(0),
        sequence: be16(2),
        total: be16(4),
        stream_id: u32::from_be_bytes([meta[6], meta[7], meta[8], meta[9]]),
    })
}
61
/// Deadline (10 s) for finishing a frame once its first byte has arrived.
///
/// It covers the rest of the header, the payload and any MAC tag. This bounds
/// slow-loris stalls, where a peer announces a large payload and then trickles
/// it in or goes silent. Time spent idle, waiting for the first byte of the
/// next frame, does not count against it.
const FRAME_READ_TIMEOUT: Duration = Duration::from_millis(10_000);
67
/// One protocol frame: the 44-byte header fields, the payload and an optional
/// trailing MAC tag.
///
/// Frames produced by [`read_frame`] always hold a plaintext payload:
/// [`FLAG_COMPRESSED`] is cleared, and `length`/`crc32` describe the
/// decompressed bytes.
#[derive(Clone, Debug)]
pub struct Frame {
    /// Protocol magic; the reader rejects any value other than `MAGIC`.
    pub magic: u16,
    /// Bitwise OR of the `FLAG_*` constants.
    pub flags: u16,
    /// Payload length in bytes, as recorded in the header.
    pub length: u32,
    /// Fixed-width routing target, zero-padded by `write_frame`; decode it
    /// with [`target_as_str`].
    pub target: [u8; 32],
    /// CRC-32 of the payload (as computed by `crc32fast::hash`).
    pub crc32: u32,
    /// Shared, immutable payload bytes. Stored as `Arc<[u8]>` so that fan-out
    /// to many subscribers (broadcast, event bus) and per-write framing clone
    /// a reference rather than the bytes.
    pub payload: Arc<[u8]>,
    /// 32-byte HMAC tag; expected to be `Some` exactly when `flags` contains
    /// [`FLAG_MAC_PRESENT`].
    pub mac: Option<[u8; 32]>,
}
82
83/// Serialize the 44-byte frame header exactly as it goes on the wire. Used by
84/// both `write_frame_raw` and MAC computation so the tag covers the real bytes.
85pub fn serialize_header(frame: &Frame) -> [u8; HEADER_SIZE] {
86    let mut header = [0u8; HEADER_SIZE];
87    header[0..2].copy_from_slice(&frame.magic.to_be_bytes());
88    header[2..4].copy_from_slice(&frame.flags.to_be_bytes());
89    header[4..8].copy_from_slice(&frame.length.to_be_bytes());
90    header[8..40].copy_from_slice(&frame.target);
91    header[40..44].copy_from_slice(&frame.crc32.to_be_bytes());
92    header
93}
94
95pub async fn write_frame<W>(
96    stream: &mut W,
97    target: &str,
98    flags: u16,
99    payload: &[u8],
100) -> Result<(), WireError>
101where
102    W: AsyncWrite + Unpin,
103{
104    if payload.len() > MAX_PAYLOAD_SIZE {
105        return Err(WireError::PayloadTooLarge(payload.len()));
106    }
107
108    let mut header = [0u8; HEADER_SIZE];
109    header[0..2].copy_from_slice(&MAGIC.to_be_bytes());
110    header[2..4].copy_from_slice(&flags.to_be_bytes());
111    header[4..8].copy_from_slice(&(payload.len() as u32).to_be_bytes());
112
113    let target_bytes = target.as_bytes();
114    let copy_len = target_bytes.len().min(32);
115    header[8..8 + copy_len].copy_from_slice(&target_bytes[..copy_len]);
116
117    let checksum = crc32fast::hash(payload);
118    header[40..44].copy_from_slice(&checksum.to_be_bytes());
119
120    stream.write_all(&header).await?;
121    stream.write_all(payload).await?;
122    Ok(())
123}
124
125pub async fn write_frame_raw<W>(stream: &mut W, frame: &Frame) -> Result<(), WireError>
126where
127    W: AsyncWrite + Unpin,
128{
129    // Reject oversized payloads before compression: we don't accept inputs that
130    // exceed the protocol limit regardless of how well they might compress.
131    if frame.payload.len() > MAX_PAYLOAD_SIZE {
132        return Err(WireError::PayloadTooLarge(frame.payload.len()));
133    }
134
135    // Compress payloads at or above the threshold when FLAG_COMPRESSED is not
136    // already set and the payload is not raw binary (audio bypasses compression).
137    let (wire_payload, wire_flags): (Arc<[u8]>, u16) = if frame.payload.len() >= COMPRESS_THRESHOLD
138        && frame.flags & FLAG_COMPRESSED == 0
139        && frame.flags & FLAG_RAW_BINARY == 0
140    {
141        match zstd::bulk::compress(&frame.payload, 3) {
142            Ok(c) if c.len() < frame.payload.len() => (Arc::from(c), frame.flags | FLAG_COMPRESSED),
143            // Common path: no (re)compression needed, so no byte copy either —
144            // just bump the refcount on the shared payload.
145            _ => (frame.payload.clone(), frame.flags),
146        }
147    } else {
148        (frame.payload.clone(), frame.flags)
149    };
150
151    // CRC32 is over the compressed bytes — the bytes actually on the wire.
152    let wire_crc = crc32fast::hash(&wire_payload);
153    let wire_frame = Frame {
154        magic: frame.magic,
155        flags: wire_flags,
156        length: wire_payload.len() as u32,
157        target: frame.target,
158        crc32: wire_crc,
159        payload: wire_payload,
160        mac: frame.mac,
161    };
162    let header = serialize_header(&wire_frame);
163
164    stream.write_all(&header).await?;
165    stream.write_all(&wire_frame.payload).await?;
166    if let Some(tag) = &wire_frame.mac {
167        stream.write_all(tag).await?;
168    }
169    Ok(())
170}
171
172pub async fn read_frame<R>(stream: &mut R) -> Result<Frame, WireError>
173where
174    R: AsyncRead + Unpin,
175{
176    read_frame_with_timeout(stream, FRAME_READ_TIMEOUT).await
177}
178
179pub async fn read_frame_with_timeout<R>(
180    stream: &mut R,
181    frame_timeout: Duration,
182) -> Result<Frame, WireError>
183where
184    R: AsyncRead + Unpin,
185{
186    // Block indefinitely for the first byte — an idle connection between frames
187    // must not be torn down. Once a byte arrives, a frame is in progress and the
188    // remainder is bounded by frame_timeout.
189    let mut first = [0u8; 1];
190    stream.read_exact(&mut first).await?;
191
192    match tokio::time::timeout(frame_timeout, read_frame_body(stream, first[0])).await {
193        Ok(result) => result,
194        Err(_) => Err(WireError::FrameReadTimeout),
195    }
196}
197
198async fn read_frame_body<R>(stream: &mut R, first_byte: u8) -> Result<Frame, WireError>
199where
200    R: AsyncRead + Unpin,
201{
202    let mut header = [0u8; HEADER_SIZE];
203    header[0] = first_byte;
204    stream.read_exact(&mut header[1..]).await?;
205
206    let magic = u16::from_be_bytes([header[0], header[1]]);
207    if magic != MAGIC {
208        return Err(WireError::FrameMagicMismatch);
209    }
210
211    let flags = u16::from_be_bytes([header[2], header[3]]);
212    let length = u32::from_be_bytes([header[4], header[5], header[6], header[7]]);
213
214    if length as usize > MAX_PAYLOAD_SIZE {
215        return Err(WireError::PayloadTooLarge(length as usize));
216    }
217
218    let mut target = [0u8; 32];
219    target.copy_from_slice(&header[8..40]);
220
221    let crc32 = u32::from_be_bytes([header[40], header[41], header[42], header[43]]);
222
223    let mut payload = vec![0u8; length as usize];
224    if length > 0 {
225        stream.read_exact(&mut payload).await?;
226    }
227
228    // CRC is over the wire bytes (possibly compressed); verify before decompressing.
229    let computed = crc32fast::hash(&payload);
230    if computed != crc32 {
231        return Err(WireError::FrameCrcMismatch);
232    }
233
234    // Normalize the in-memory invariant: payload is always plaintext, and
235    // flags/length/crc32 describe the plaintext, regardless of what was on the
236    // wire. The MAC (if any) was computed by the sender over the pre-compression
237    // header+payload, so crc32 must be recomputed over the decompressed bytes —
238    // the wire crc32 (over compressed bytes) would fail verification.
239    let (payload, flags, length, crc32) = if flags & FLAG_COMPRESSED != 0 {
240        let decompressed = zstd::bulk::decompress(&payload, MAX_PAYLOAD_SIZE)
241            .map_err(|e| WireError::Internal(format!("decompress frame: {e}")))?;
242        let plain_len = decompressed.len() as u32;
243        let plain_crc = crc32fast::hash(&decompressed);
244        (decompressed, flags & !FLAG_COMPRESSED, plain_len, plain_crc)
245    } else {
246        (payload, flags, length, crc32)
247    };
248
249    let mac = if flags & FLAG_MAC_PRESENT != 0 {
250        let mut tag = [0u8; 32];
251        stream.read_exact(&mut tag).await?;
252        Some(tag)
253    } else {
254        None
255    };
256
257    Ok(Frame {
258        magic,
259        flags,
260        length,
261        target,
262        crc32,
263        payload: payload.into(),
264        mac,
265    })
266}
267
268/// Returns `None` if the target bytes are not valid UTF-8. Callers must log
269/// the raw hex and return an error frame in that case (VULN-022).
270pub fn target_as_str(frame: &Frame) -> Option<&str> {
271    let end = frame.target.iter().position(|&b| b == 0).unwrap_or(32);
272    std::str::from_utf8(&frame.target[..end]).ok()
273}