Skip to main content

reddb_wire/redwire/
codec.rs

1//! Hand-rolled binary codec for v2 frames. No serde — the on-wire
2//! shape is fixed by ADR 0001, kept simple so a hex-dump is
3//! readable.
4
5use super::frame::{Flags, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_FRAME_SIZE};
6
/// Everything that can go wrong while decoding a v2 frame off the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameError {
    /// Fewer bytes than a full frame header were supplied.
    Truncated,
    /// The frame length is outside `FRAME_HEADER_SIZE..=MAX_FRAME_SIZE`;
    /// carries the offending value.
    InvalidLength(u32),
    /// The frame body is shorter than announced. `decode_frame` also
    /// reuses this variant when a compressed body cannot be inflated.
    PayloadTruncated { expected: u32, available: u32 },
    /// The kind byte does not map to any `MessageKind`.
    UnknownKind(u8),
    /// The flags byte has bits set that the codec does not recognise.
    UnknownFlags(u8),
    /// Catalog cross-check failed: the flag bits set on the frame are
    /// not in `MessageKind::allowed_flags()` for this kind. The wire
    /// catalog is the single source of truth for which bits a kind
    /// may carry — see `frame.rs::MessageKind::allowed_flags`.
    FlagsNotAllowedForKind { kind: u8, flags: u8 },
}

impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            FrameError::Truncated => f.write_str("frame header truncated (< 16 bytes)"),
            FrameError::InvalidLength(len) => write!(f, "frame length field invalid: {}", len),
            FrameError::PayloadTruncated {
                expected,
                available,
            } => write!(
                f,
                "frame payload truncated: expected {} bytes, got {}",
                expected, available
            ),
            FrameError::UnknownKind(raw) => write!(f, "unknown message kind 0x{:02x}", raw),
            FrameError::UnknownFlags(raw) => write!(f, "unknown flag bits 0x{:02x}", raw),
            FrameError::FlagsNotAllowedForKind { kind, flags } => write!(
                f,
                "flag bits 0x{:02x} not allowed on kind 0x{:02x}",
                flags, kind
            ),
        }
    }
}

impl std::error::Error for FrameError {}
50
51pub fn frame_len_from_header(header: &[u8; FRAME_HEADER_SIZE]) -> Result<usize, FrameError> {
52    let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
53    if length < FRAME_HEADER_SIZE as u32 || length > MAX_FRAME_SIZE {
54        return Err(FrameError::InvalidLength(length));
55    }
56    Ok(length as usize)
57}
58
59pub fn decode_frame_parts(
60    header: &[u8; FRAME_HEADER_SIZE],
61    payload: &[u8],
62) -> Result<Frame, FrameError> {
63    let length = frame_len_from_header(header)?;
64    let expected_payload_len = length - FRAME_HEADER_SIZE;
65    if payload.len() < expected_payload_len {
66        return Err(FrameError::PayloadTruncated {
67            expected: expected_payload_len as u32,
68            available: payload.len() as u32,
69        });
70    }
71
72    let mut bytes = Vec::with_capacity(length);
73    bytes.extend_from_slice(header);
74    bytes.extend_from_slice(&payload[..expected_payload_len]);
75    decode_frame(&bytes).map(|(frame, _)| frame)
76}
77
78pub fn encode_frame(frame: &Frame) -> Vec<u8> {
79    // The frame's `payload` is always the plaintext form. If the
80    // COMPRESSED flag is set we compress on the wire and rewrite
81    // the length header to match the compressed size — the
82    // receiver inflates before delivering to the dispatch loop.
83    if frame.flags.contains(Flags::COMPRESSED) {
84        return encode_compressed(frame);
85    }
86    let total = frame.encoded_len() as usize;
87    let mut buf = Vec::with_capacity(total);
88    buf.extend_from_slice(&frame.encoded_len().to_le_bytes());
89    buf.push(frame.kind as u8);
90    buf.push(frame.flags.bits());
91    buf.extend_from_slice(&frame.stream_id.to_le_bytes());
92    buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
93    buf.extend_from_slice(&frame.payload);
94    buf
95}
96
97fn encode_compressed(frame: &Frame) -> Vec<u8> {
98    // zstd level 1 — keeps CPU low while still cutting JSON +
99    // BulkInsertBinary by 60-80%. Operators that want max ratio
100    // can flip to level 3+ via `RED_REDWIRE_ZSTD_LEVEL` env.
101    let level = std::env::var("RED_REDWIRE_ZSTD_LEVEL")
102        .ok()
103        .and_then(|s| s.parse::<i32>().ok())
104        .unwrap_or(1);
105    let compressed = match zstd::stream::encode_all(frame.payload.as_slice(), level) {
106        Ok(buf) => buf,
107        Err(_) => {
108            // Fallback: drop the COMPRESSED flag and ship plaintext.
109            // Compression failures are rare (level 1 effectively
110            // never fails on bytes), but the fallback is safer
111            // than panicking inside the framing layer.
112            let mut clone = frame.clone();
113            clone.flags = Flags::from_bits(clone.flags.bits() & !Flags::COMPRESSED.bits());
114            return encode_frame(&clone);
115        }
116    };
117    let total = (FRAME_HEADER_SIZE + compressed.len()) as u32;
118    let mut buf = Vec::with_capacity(total as usize);
119    buf.extend_from_slice(&total.to_le_bytes());
120    buf.push(frame.kind as u8);
121    buf.push(frame.flags.bits());
122    buf.extend_from_slice(&frame.stream_id.to_le_bytes());
123    buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
124    buf.extend_from_slice(&compressed);
125    buf
126}
127
128pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), FrameError> {
129    if bytes.len() < FRAME_HEADER_SIZE {
130        return Err(FrameError::Truncated);
131    }
132    let mut header = [0u8; FRAME_HEADER_SIZE];
133    header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
134    let length = frame_len_from_header(&header)? as u32;
135    if (bytes.len() as u32) < length {
136        return Err(FrameError::PayloadTruncated {
137            expected: length,
138            available: bytes.len() as u32,
139        });
140    }
141    let kind = MessageKind::from_u8(bytes[4]).ok_or(FrameError::UnknownKind(bytes[4]))?;
142    let flag_bits = bytes[5];
143    const KNOWN_FLAGS: u8 = 0b0000_0011;
144    if flag_bits & !KNOWN_FLAGS != 0 {
145        return Err(FrameError::UnknownFlags(flag_bits));
146    }
147    let flags = Flags::from_bits(flag_bits);
148    // Catalog cross-check: the kind's `allowed_flags()` is the single
149    // source of truth. Reject combinations the catalog forbids
150    // (e.g. COMPRESSED on tiny handshake payloads) so misframed
151    // frames fail at the boundary instead of reaching dispatch.
152    if !kind.permits_flags(flags) {
153        return Err(FrameError::FlagsNotAllowedForKind {
154            kind: bytes[4],
155            flags: flag_bits,
156        });
157    }
158    let stream_id = u16::from_le_bytes([bytes[6], bytes[7]]);
159    let correlation_id = u64::from_le_bytes([
160        bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
161    ]);
162    let payload_len = (length as usize) - FRAME_HEADER_SIZE;
163    let on_wire = &bytes[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len];
164    let payload = if flags.contains(Flags::COMPRESSED) {
165        // Decompress on read so the rest of the dispatch loop
166        // sees plaintext bytes regardless of how they arrived.
167        match zstd::stream::decode_all(on_wire) {
168            Ok(plain) => plain,
169            Err(e) => {
170                return Err(FrameError::PayloadTruncated {
171                    // Reuse PayloadTruncated for "decompression
172                    // failed" rather than introduce a new variant
173                    // — the wire-layer outcome is the same: the
174                    // body is unparseable, drop the connection.
175                    expected: payload_len as u32,
176                    available: e.to_string().len() as u32,
177                });
178            }
179        }
180    } else {
181        on_wire.to_vec()
182    };
183    Ok((
184        Frame {
185            kind,
186            // The flag stays on the decoded frame so dispatch can
187            // see it was compressed if it cares (audit, metrics).
188            flags,
189            stream_id,
190            correlation_id,
191            payload,
192        },
193        length as usize,
194    ))
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame`, decodes it back, and asserts it survives unchanged
    /// with every encoded byte consumed.
    fn round_trip(frame: Frame) {
        let bytes = encode_frame(&frame);
        let (decoded, consumed) = decode_frame(&bytes).expect("decode");
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, frame);
    }

    #[test]
    fn round_trip_empty_payload() {
        round_trip(Frame::new(MessageKind::Ping, 1, vec![]));
    }

    #[test]
    fn frame_len_from_header_validates_bounds() {
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header[..4].copy_from_slice(&(FRAME_HEADER_SIZE as u32).to_le_bytes());
        assert_eq!(frame_len_from_header(&header).unwrap(), FRAME_HEADER_SIZE);

        header[..4].copy_from_slice(&15u32.to_le_bytes());
        assert_eq!(
            frame_len_from_header(&header),
            Err(FrameError::InvalidLength(15))
        );

        header[..4].copy_from_slice(&(MAX_FRAME_SIZE + 1).to_le_bytes());
        assert_eq!(
            frame_len_from_header(&header),
            Err(FrameError::InvalidLength(MAX_FRAME_SIZE + 1))
        );
    }

    #[test]
    fn round_trip_with_payload() {
        round_trip(Frame::new(MessageKind::Query, 42, b"SELECT 1".to_vec()));
    }

    #[test]
    fn decode_frame_parts_matches_full_buffer_decode() {
        let frame = Frame::new(MessageKind::Result, 42, br#"{"ok":true}"#.to_vec());
        let bytes = encode_frame(&frame);
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
        let payload = &bytes[FRAME_HEADER_SIZE..];

        let decoded = decode_frame_parts(&header, payload).expect("decode parts");
        assert_eq!(decoded, frame);
    }

    #[test]
    fn decode_frame_parts_rejects_short_payload() {
        // Header announces an 8-byte payload ("SELECT 1"); only 4 supplied.
        let frame = Frame::new(MessageKind::Query, 42, b"SELECT 1".to_vec());
        let bytes = encode_frame(&frame);
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
        let short = &bytes[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + 4];
        assert_eq!(
            decode_frame_parts(&header, short),
            Err(FrameError::PayloadTruncated {
                expected: 8,
                available: 4,
            })
        );
    }

    #[test]
    fn round_trip_with_stream_and_flags() {
        let frame = Frame::new(MessageKind::Result, 999, vec![0xab; 256])
            .with_stream(7)
            .with_flags(Flags::COMPRESSED | Flags::MORE_FRAMES);
        round_trip(frame);
    }

    #[test]
    fn truncated_header_rejected() {
        assert_eq!(decode_frame(&[]), Err(FrameError::Truncated));
        assert_eq!(decode_frame(&[0; 15]), Err(FrameError::Truncated));
    }

    #[test]
    fn incomplete_frame_body_rejected() {
        let frame = Frame::new(MessageKind::Query, 3, b"SELECT 1".to_vec());
        let bytes = encode_frame(&frame);
        let cut = &bytes[..bytes.len() - 1];
        assert_eq!(
            decode_frame(cut),
            Err(FrameError::PayloadTruncated {
                expected: bytes.len() as u32,
                available: cut.len() as u32,
            })
        );
    }

    #[test]
    fn length_below_header_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&15u32.to_le_bytes());
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::InvalidLength(15))
        ));
    }

    #[test]
    fn unknown_kind_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = 0xff;
        assert_eq!(decode_frame(&bytes), Err(FrameError::UnknownKind(0xff)));
    }

    #[test]
    fn unknown_flag_bits_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = 0b1000_0000;
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::UnknownFlags(_))
        ));
    }

    #[test]
    fn flags_not_allowed_for_kind_rejected() {
        // Ping is a handshake kind — the catalog forbids COMPRESSED
        // on tiny handshake payloads. A frame with kind=Ping and the
        // COMPRESSED bit set must be rejected at the boundary.
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        match decode_frame(&bytes) {
            Err(FrameError::FlagsNotAllowedForKind { kind, flags }) => {
                assert_eq!(kind, MessageKind::Ping as u8);
                assert_eq!(flags, Flags::COMPRESSED.bits());
            }
            other => panic!("expected FlagsNotAllowedForKind, got {other:?}"),
        }
    }

    #[test]
    fn streaming_decode_two_frames_back_to_back() {
        let f1 = Frame::new(MessageKind::Query, 1, b"a".to_vec());
        let f2 = Frame::new(MessageKind::Query, 2, b"b".to_vec());
        let mut buf = encode_frame(&f1);
        buf.extend(encode_frame(&f2));
        let (got1, n1) = decode_frame(&buf).unwrap();
        let (got2, n2) = decode_frame(&buf[n1..]).unwrap();
        assert_eq!(got1, f1);
        assert_eq!(got2, f2);
        assert_eq!(n1 + n2, buf.len());
    }

    #[test]
    fn trailing_bytes_are_not_consumed() {
        let frame = Frame::new(MessageKind::Query, 5, b"x".to_vec());
        let encoded = encode_frame(&frame);
        let mut buf = encoded.clone();
        buf.extend_from_slice(&[0xde, 0xad]);
        let (decoded, consumed) = decode_frame(&buf).expect("decode");
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded, frame);
    }

    #[test]
    fn compressed_round_trip_recovers_plaintext() {
        // A compressible payload — a kilobyte of repeating text.
        let payload = b"abcabcabcabc".repeat(100);
        let frame =
            Frame::new(MessageKind::Result, 7, payload.clone()).with_flags(Flags::COMPRESSED);
        let bytes = encode_frame(&frame);
        // Wire form should be smaller than the plaintext frame.
        assert!(
            bytes.len() < FRAME_HEADER_SIZE + payload.len(),
            "compressed frame ({}) must be smaller than plaintext payload ({})",
            bytes.len(),
            payload.len(),
        );
        let (decoded, _) = decode_frame(&bytes).expect("decode compressed");
        assert_eq!(decoded.payload, payload);
        assert!(decoded.flags.contains(Flags::COMPRESSED));
    }

    #[test]
    fn corrupt_compressed_payload_rejected() {
        // COMPRESSED is permitted on Result frames, but this body is not
        // zstd, so decoding must fail instead of delivering garbage.
        let body = b"definitely not zstd";
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&((FRAME_HEADER_SIZE + body.len()) as u32).to_le_bytes());
        bytes.push(MessageKind::Result as u8);
        bytes.push(Flags::COMPRESSED.bits());
        // stream_id (2 bytes) + correlation_id (8 bytes), all zero.
        bytes.extend_from_slice(&[0u8; 10]);
        bytes.extend_from_slice(body);
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::PayloadTruncated { .. })
        ));
    }

    #[test]
    fn output_stream_lifecycle_envelopes_round_trip() {
        // Golden encode/decode for every new variant added in
        // issue #762 / PRD #759 S3. Pins the exact byte values
        // and confirms `stream_id` multiplex routing survives the
        // round trip (it is what `StreamCancel`'s per-stream
        // targeting relies on).
        let open = Frame::new(
            MessageKind::OpenStream,
            10,
            br#"{"sql":"SELECT 1","opts":{}}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(open.clone());
        assert_eq!(encode_frame(&open)[4], 0x29);

        let ack = Frame::new(
            MessageKind::OpenAck,
            10,
            br#"{"lease_handle":"42","resumable":false,"snapshot_lsn":1234}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(ack.clone());
        assert_eq!(encode_frame(&ack)[4], 0x2A);

        let chunk = Frame::new(
            MessageKind::StreamChunk,
            10,
            br#"{"seq":0,"rows":[{"a":1}],"terminal":false}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(chunk.clone());
        assert_eq!(encode_frame(&chunk)[4], 0x2B);

        let serr = Frame::new(
            MessageKind::StreamError,
            10,
            br#"{"code":"unknown_stream","message":"x"}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(serr.clone());
        assert_eq!(encode_frame(&serr)[4], 0x2C);

        let end = Frame::new(
            MessageKind::StreamEnd,
            10,
            br#"{"stats":{"row_count":1}}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(end.clone());
        assert_eq!(encode_frame(&end)[4], 0x25);

        let cancel = Frame::new(
            MessageKind::StreamCancel,
            10,
            br#"{"reason":"client-abort"}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(cancel.clone());
        assert_eq!(encode_frame(&cancel)[4], 0x2D);
    }

    #[test]
    fn input_stream_envelopes_round_trip() {
        // Golden encode/decode for the input-direction envelopes
        // added in issue #764 / PRD #759 S5. The envelope *vocabulary*
        // is reused from S3 — only the payload shapes and the
        // direction of `StreamChunk` differ — so the byte values are
        // pinned to the same kinds (no new MessageKind bytes).

        // OpenStream with `direction:"in"` + target/columns instead of
        // a `sql` field. Still kind 0x29, still multiplex via stream_id.
        let open_in = Frame::new(
            MessageKind::OpenStream,
            20,
            br#"{"direction":"in","target":"t","columns":["id","name"]}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(open_in.clone());
        assert_eq!(encode_frame(&open_in)[4], 0x29);

        // Client-originated chunk of rows on the input stream. Same
        // 0x2B kind the server uses on output streams; the rows are
        // JSON objects keyed by column.
        let chunk_in = Frame::new(
            MessageKind::StreamChunk,
            20,
            br#"{"seq":0,"rows":[{"id":1,"name":"a"}],"terminal":false}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(chunk_in.clone());
        assert_eq!(encode_frame(&chunk_in)[4], 0x2B);

        // Terminal chunk closes the input stream.
        let chunk_terminal = Frame::new(
            MessageKind::StreamChunk,
            20,
            br#"{"seq":2,"rows":[],"terminal":true}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(chunk_terminal.clone());
        assert_eq!(encode_frame(&chunk_terminal)[4], 0x2B);

        // Server StreamEnd carries the committed RID range + stats.
        let end = Frame::new(
            MessageKind::StreamEnd,
            20,
            br#"{"stats":{"row_count":3,"chunk_count":2,"committed_rid":42,"snapshot_lsn":40,"cancelled":false}}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(end.clone());
        assert_eq!(encode_frame(&end)[4], 0x25);

        // Server StreamError carries the recoverable_rid prefix.
        let serr = Frame::new(
            MessageKind::StreamError,
            20,
            br#"{"code":"invalid_row","message":"x","chunk_seq":1,"recoverable_rid":41}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(serr.clone());
        assert_eq!(encode_frame(&serr)[4], 0x2C);
    }

    #[test]
    fn queue_wait_envelopes_round_trip() {
        // Golden encode/decode for the live queue-wait envelopes added
        // in issue #917 / PRD #915. Pins the new byte values and
        // confirms the request/push pair round-trips equal through the
        // codec, multiplexed over `stream_id` like the other streamed
        // envelopes.
        let open = Frame::new(
            MessageKind::QueueWaitOpen,
            10,
            br#"{"queue":"jobs","consumer":"w1","count":1,"wait_ms":5000}"#.to_vec(),
        )
        .with_stream(3);
        round_trip(open.clone());
        assert_eq!(encode_frame(&open)[4], 0x2E);

        let push = Frame::new(
            MessageKind::QueueEventPush,
            10,
            br#"{"message_id":"42","payload":{"hello":"world"},"consumer":"w1","delivery_count":1}"#
                .to_vec(),
        )
        .with_stream(3);
        round_trip(push.clone());
        assert_eq!(encode_frame(&push)[4], 0x2F);
    }

    #[test]
    fn uncompressed_frame_decodes_unchanged_when_flag_unset() {
        let payload = b"hello world".to_vec();
        let frame = Frame::new(MessageKind::Result, 1, payload.clone());
        let bytes = encode_frame(&frame);
        let (decoded, _) = decode_frame(&bytes).unwrap();
        assert_eq!(decoded.payload, payload);
        assert!(!decoded.flags.contains(Flags::COMPRESSED));
    }
}