Skip to main content

reddb_wire/redwire/
codec.rs

1//! Hand-rolled binary codec for v2 frames. No serde — the on-wire
2//! shape is fixed by ADR 0001, kept simple so a hex-dump is
3//! readable.
4
5use super::frame::{Flags, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_FRAME_SIZE};
6
/// Reasons `decode_frame` rejects a byte buffer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FrameError {
    /// Fewer bytes than the fixed frame header were supplied, so the
    /// header itself cannot be read.
    Truncated,
    /// The header's length field is below the header size or above the
    /// maximum frame size; carries the offending value.
    InvalidLength(u32),
    /// The frame could not be read in full. The decoder also reuses this
    /// variant when a COMPRESSED body cannot be inflated, so the fields are
    /// not always plain byte counts of the input buffer.
    PayloadTruncated {
        /// Length the frame called for (the whole frame when the buffer is
        /// short; the compressed body when inflation fails).
        expected: u32,
        /// What was actually available to the decoder.
        available: u32,
    },
    /// The kind byte does not map to a known `MessageKind`.
    UnknownKind(u8),
    /// The flag byte has bits set outside the set this codec recognises.
    UnknownFlags(u8),
    /// Catalog cross-check failed: the flag bits set on the frame are
    /// not in `MessageKind::allowed_flags()` for this kind. The wire
    /// catalog is the single source of truth for which bits a kind
    /// may carry — see `frame.rs::MessageKind::allowed_flags`.
    FlagsNotAllowedForKind {
        /// Raw kind byte from the header.
        kind: u8,
        /// Raw flag byte from the header.
        flags: u8,
    },
}

impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds values
        // directly instead of references.
        match *self {
            FrameError::Truncated => f.write_str("frame header truncated (< 16 bytes)"),
            FrameError::InvalidLength(n) => write!(f, "frame length field invalid: {n}"),
            FrameError::PayloadTruncated {
                expected,
                available,
            } => write!(
                f,
                "frame payload truncated: expected {expected} bytes, got {available}"
            ),
            FrameError::UnknownKind(byte) => write!(f, "unknown message kind 0x{byte:02x}"),
            FrameError::UnknownFlags(byte) => write!(f, "unknown flag bits 0x{byte:02x}"),
            FrameError::FlagsNotAllowedForKind { kind, flags } => {
                write!(f, "flag bits 0x{flags:02x} not allowed on kind 0x{kind:02x}")
            }
        }
    }
}

impl std::error::Error for FrameError {
    /// Frame errors are leaf errors: none of them wraps an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
50
51pub fn encode_frame(frame: &Frame) -> Vec<u8> {
52    // The frame's `payload` is always the plaintext form. If the
53    // COMPRESSED flag is set we compress on the wire and rewrite
54    // the length header to match the compressed size — the
55    // receiver inflates before delivering to the dispatch loop.
56    if frame.flags.contains(Flags::COMPRESSED) {
57        return encode_compressed(frame);
58    }
59    let total = frame.encoded_len() as usize;
60    let mut buf = Vec::with_capacity(total);
61    buf.extend_from_slice(&frame.encoded_len().to_le_bytes());
62    buf.push(frame.kind as u8);
63    buf.push(frame.flags.bits());
64    buf.extend_from_slice(&frame.stream_id.to_le_bytes());
65    buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
66    buf.extend_from_slice(&frame.payload);
67    buf
68}
69
70fn encode_compressed(frame: &Frame) -> Vec<u8> {
71    // zstd level 1 — keeps CPU low while still cutting JSON +
72    // BulkInsertBinary by 60-80%. Operators that want max ratio
73    // can flip to level 3+ via `RED_REDWIRE_ZSTD_LEVEL` env.
74    let level = std::env::var("RED_REDWIRE_ZSTD_LEVEL")
75        .ok()
76        .and_then(|s| s.parse::<i32>().ok())
77        .unwrap_or(1);
78    let compressed = match zstd::stream::encode_all(frame.payload.as_slice(), level) {
79        Ok(buf) => buf,
80        Err(_) => {
81            // Fallback: drop the COMPRESSED flag and ship plaintext.
82            // Compression failures are rare (level 1 effectively
83            // never fails on bytes), but the fallback is safer
84            // than panicking inside the framing layer.
85            let mut clone = frame.clone();
86            clone.flags = Flags::from_bits(clone.flags.bits() & !Flags::COMPRESSED.bits());
87            return encode_frame(&clone);
88        }
89    };
90    let total = (FRAME_HEADER_SIZE + compressed.len()) as u32;
91    let mut buf = Vec::with_capacity(total as usize);
92    buf.extend_from_slice(&total.to_le_bytes());
93    buf.push(frame.kind as u8);
94    buf.push(frame.flags.bits());
95    buf.extend_from_slice(&frame.stream_id.to_le_bytes());
96    buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
97    buf.extend_from_slice(&compressed);
98    buf
99}
100
101pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), FrameError> {
102    if bytes.len() < FRAME_HEADER_SIZE {
103        return Err(FrameError::Truncated);
104    }
105    let length = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
106    if length < FRAME_HEADER_SIZE as u32 || length > MAX_FRAME_SIZE {
107        return Err(FrameError::InvalidLength(length));
108    }
109    if (bytes.len() as u32) < length {
110        return Err(FrameError::PayloadTruncated {
111            expected: length,
112            available: bytes.len() as u32,
113        });
114    }
115    let kind = MessageKind::from_u8(bytes[4]).ok_or(FrameError::UnknownKind(bytes[4]))?;
116    let flag_bits = bytes[5];
117    const KNOWN_FLAGS: u8 = 0b0000_0011;
118    if flag_bits & !KNOWN_FLAGS != 0 {
119        return Err(FrameError::UnknownFlags(flag_bits));
120    }
121    let flags = Flags::from_bits(flag_bits);
122    // Catalog cross-check: the kind's `allowed_flags()` is the single
123    // source of truth. Reject combinations the catalog forbids
124    // (e.g. COMPRESSED on tiny handshake payloads) so misframed
125    // frames fail at the boundary instead of reaching dispatch.
126    if !kind.permits_flags(flags) {
127        return Err(FrameError::FlagsNotAllowedForKind {
128            kind: bytes[4],
129            flags: flag_bits,
130        });
131    }
132    let stream_id = u16::from_le_bytes([bytes[6], bytes[7]]);
133    let correlation_id = u64::from_le_bytes([
134        bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
135    ]);
136    let payload_len = (length as usize) - FRAME_HEADER_SIZE;
137    let on_wire = &bytes[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len];
138    let payload = if flags.contains(Flags::COMPRESSED) {
139        // Decompress on read so the rest of the dispatch loop
140        // sees plaintext bytes regardless of how they arrived.
141        match zstd::stream::decode_all(on_wire) {
142            Ok(plain) => plain,
143            Err(e) => {
144                return Err(FrameError::PayloadTruncated {
145                    // Reuse PayloadTruncated for "decompression
146                    // failed" rather than introduce a new variant
147                    // — the wire-layer outcome is the same: the
148                    // body is unparseable, drop the connection.
149                    expected: payload_len as u32,
150                    available: e.to_string().len() as u32,
151                });
152            }
153        }
154    } else {
155        on_wire.to_vec()
156    };
157    Ok((
158        Frame {
159            kind,
160            // The flag stays on the decoded frame so dispatch can
161            // see it was compressed if it cares (audit, metrics).
162            flags,
163            stream_id,
164            correlation_id,
165            payload,
166        },
167        length as usize,
168    ))
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    fn round_trip(frame: Frame) {
        let bytes = encode_frame(&frame);
        let (decoded, consumed) = decode_frame(&bytes).expect("decode");
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, frame);
    }

    #[test]
    fn round_trip_empty_payload() {
        round_trip(Frame::new(MessageKind::Ping, 1, vec![]));
    }

    #[test]
    fn round_trip_with_payload() {
        round_trip(Frame::new(MessageKind::Query, 42, b"SELECT 1".to_vec()));
    }

    #[test]
    fn round_trip_with_stream_and_flags() {
        let frame = Frame::new(MessageKind::Result, 999, vec![0xab; 256])
            .with_stream(7)
            .with_flags(Flags::COMPRESSED | Flags::MORE_FRAMES);
        round_trip(frame);
    }

    #[test]
    fn truncated_header_rejected() {
        assert_eq!(decode_frame(&[]), Err(FrameError::Truncated));
        assert_eq!(decode_frame(&[0; 15]), Err(FrameError::Truncated));
    }

    #[test]
    fn length_below_header_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&15u32.to_le_bytes());
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::InvalidLength(15))
        ));
    }

    #[test]
    fn length_above_max_rejected() {
        // Only meaningful when MAX_FRAME_SIZE leaves headroom in a u32.
        if let Some(too_big) = MAX_FRAME_SIZE.checked_add(1) {
            let mut bytes = vec![0u8; 16];
            bytes[..4].copy_from_slice(&too_big.to_le_bytes());
            assert_eq!(
                decode_frame(&bytes),
                Err(FrameError::InvalidLength(too_big))
            );
        }
    }

    #[test]
    fn buffer_shorter_than_length_field_rejected() {
        // Header claims a 20-byte frame but only the 16 header bytes exist.
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&20u32.to_le_bytes());
        assert_eq!(
            decode_frame(&bytes),
            Err(FrameError::PayloadTruncated {
                expected: 20,
                available: 16,
            })
        );
    }

    #[test]
    fn unknown_kind_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = 0xff;
        assert_eq!(decode_frame(&bytes), Err(FrameError::UnknownKind(0xff)));
    }

    #[test]
    fn unknown_flag_bits_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = 0b1000_0000;
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::UnknownFlags(_))
        ));
    }

    #[test]
    fn flags_not_allowed_for_kind_rejected() {
        // Ping is a handshake kind — the catalog forbids COMPRESSED
        // on tiny handshake payloads. A frame with kind=Ping and the
        // COMPRESSED bit set must be rejected at the boundary.
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        match decode_frame(&bytes) {
            Err(FrameError::FlagsNotAllowedForKind { kind, flags }) => {
                assert_eq!(kind, MessageKind::Ping as u8);
                assert_eq!(flags, Flags::COMPRESSED.bits());
            }
            other => panic!("expected FlagsNotAllowedForKind, got {other:?}"),
        }
    }

    #[test]
    fn streaming_decode_two_frames_back_to_back() {
        let f1 = Frame::new(MessageKind::Query, 1, b"a".to_vec());
        let f2 = Frame::new(MessageKind::Query, 2, b"b".to_vec());
        let mut buf = encode_frame(&f1);
        buf.extend(encode_frame(&f2));
        let (got1, n1) = decode_frame(&buf).unwrap();
        let (got2, _n2) = decode_frame(&buf[n1..]).unwrap();
        assert_eq!(got1, f1);
        assert_eq!(got2, f2);
    }

    #[test]
    fn compressed_round_trip_recovers_plaintext() {
        // A compressible payload — a kilobyte of repeating text.
        let payload = b"abcabcabcabc".repeat(100);
        let frame =
            Frame::new(MessageKind::Result, 7, payload.clone()).with_flags(Flags::COMPRESSED);
        let bytes = encode_frame(&frame);
        // Wire form should be smaller than the plaintext frame.
        assert!(
            bytes.len() < FRAME_HEADER_SIZE + payload.len(),
            "compressed frame ({}) must be smaller than plaintext payload ({})",
            bytes.len(),
            payload.len(),
        );
        let (decoded, _) = decode_frame(&bytes).expect("decode compressed");
        assert_eq!(decoded.payload, payload);
        assert!(decoded.flags.contains(Flags::COMPRESSED));
    }

    #[test]
    fn corrupt_compressed_payload_rejected() {
        // Not a zstd frame: the magic number does not match, so inflating
        // must fail and the frame must be rejected rather than delivered.
        let body = [0xde, 0xad, 0xbe, 0xef];
        let mut bytes = vec![0u8; FRAME_HEADER_SIZE];
        bytes[..4].copy_from_slice(&((FRAME_HEADER_SIZE + body.len()) as u32).to_le_bytes());
        bytes[4] = MessageKind::Result as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        bytes.extend_from_slice(&body);
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::PayloadTruncated { .. })
        ));
    }

    #[test]
    fn uncompressed_frame_decodes_unchanged_when_flag_unset() {
        let payload = b"hello world".to_vec();
        let frame = Frame::new(MessageKind::Result, 1, payload.clone());
        let bytes = encode_frame(&frame);
        let (decoded, _) = decode_frame(&bytes).unwrap();
        assert_eq!(decoded.payload, payload);
        assert!(!decoded.flags.contains(Flags::COMPRESSED));
    }
}