1use super::frame::{Flags, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_FRAME_SIZE};
6
/// Reasons a byte buffer can be rejected while decoding a frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameError {
    /// Fewer bytes were supplied than a complete frame header needs.
    Truncated,
    /// A frame length is out of range (e.g. the header's length field is below
    /// the header size or above `MAX_FRAME_SIZE`); carries the offending length.
    InvalidLength(u32),
    /// The header declares `expected` bytes but fewer were available.
    ///
    /// NOTE(review): the decoder also returns this when a compressed payload
    /// fails to decompress; in that case `available` holds the length of the
    /// decoder's error message rather than a byte count.
    PayloadTruncated { expected: u32, available: u32 },
    /// The message-kind byte is not a recognised kind; carries the raw byte.
    UnknownKind(u8),
    /// The flag byte has bits set outside the known set; carries the raw byte.
    UnknownFlags(u8),
    /// The flag bits are individually known, but this message kind does not
    /// permit them; carries the raw kind and flag bytes.
    FlagsNotAllowedForKind { kind: u8, flags: u8 },
}

/// Human-readable one-line description of each rejection reason.
impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds by value.
        match *self {
            FrameError::Truncated => f.write_str("frame header truncated (< 16 bytes)"),
            FrameError::InvalidLength(n) => write!(f, "frame length field invalid: {n}"),
            FrameError::PayloadTruncated {
                expected,
                available,
            } => write!(
                f,
                "frame payload truncated: expected {expected} bytes, got {available}"
            ),
            FrameError::UnknownKind(byte) => write!(f, "unknown message kind 0x{byte:02x}"),
            FrameError::UnknownFlags(byte) => write!(f, "unknown flag bits 0x{byte:02x}"),
            FrameError::FlagsNotAllowedForKind { kind, flags } => write!(
                f,
                "flag bits 0x{flags:02x} not allowed on kind 0x{kind:02x}"
            ),
        }
    }
}

/// Lets `FrameError` flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for FrameError {}
50
51pub fn encode_frame(frame: &Frame) -> Vec<u8> {
52 if frame.flags.contains(Flags::COMPRESSED) {
57 return encode_compressed(frame);
58 }
59 let total = frame.encoded_len() as usize;
60 let mut buf = Vec::with_capacity(total);
61 buf.extend_from_slice(&frame.encoded_len().to_le_bytes());
62 buf.push(frame.kind as u8);
63 buf.push(frame.flags.bits());
64 buf.extend_from_slice(&frame.stream_id.to_le_bytes());
65 buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
66 buf.extend_from_slice(&frame.payload);
67 buf
68}
69
70fn encode_compressed(frame: &Frame) -> Vec<u8> {
71 let level = std::env::var("RED_REDWIRE_ZSTD_LEVEL")
75 .ok()
76 .and_then(|s| s.parse::<i32>().ok())
77 .unwrap_or(1);
78 let compressed = match zstd::stream::encode_all(frame.payload.as_slice(), level) {
79 Ok(buf) => buf,
80 Err(_) => {
81 let mut clone = frame.clone();
86 clone.flags = Flags::from_bits(clone.flags.bits() & !Flags::COMPRESSED.bits());
87 return encode_frame(&clone);
88 }
89 };
90 let total = (FRAME_HEADER_SIZE + compressed.len()) as u32;
91 let mut buf = Vec::with_capacity(total as usize);
92 buf.extend_from_slice(&total.to_le_bytes());
93 buf.push(frame.kind as u8);
94 buf.push(frame.flags.bits());
95 buf.extend_from_slice(&frame.stream_id.to_le_bytes());
96 buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
97 buf.extend_from_slice(&compressed);
98 buf
99}
100
101pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), FrameError> {
102 if bytes.len() < FRAME_HEADER_SIZE {
103 return Err(FrameError::Truncated);
104 }
105 let length = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
106 if length < FRAME_HEADER_SIZE as u32 || length > MAX_FRAME_SIZE {
107 return Err(FrameError::InvalidLength(length));
108 }
109 if (bytes.len() as u32) < length {
110 return Err(FrameError::PayloadTruncated {
111 expected: length,
112 available: bytes.len() as u32,
113 });
114 }
115 let kind = MessageKind::from_u8(bytes[4]).ok_or(FrameError::UnknownKind(bytes[4]))?;
116 let flag_bits = bytes[5];
117 const KNOWN_FLAGS: u8 = 0b0000_0011;
118 if flag_bits & !KNOWN_FLAGS != 0 {
119 return Err(FrameError::UnknownFlags(flag_bits));
120 }
121 let flags = Flags::from_bits(flag_bits);
122 if !kind.permits_flags(flags) {
127 return Err(FrameError::FlagsNotAllowedForKind {
128 kind: bytes[4],
129 flags: flag_bits,
130 });
131 }
132 let stream_id = u16::from_le_bytes([bytes[6], bytes[7]]);
133 let correlation_id = u64::from_le_bytes([
134 bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
135 ]);
136 let payload_len = (length as usize) - FRAME_HEADER_SIZE;
137 let on_wire = &bytes[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len];
138 let payload = if flags.contains(Flags::COMPRESSED) {
139 match zstd::stream::decode_all(on_wire) {
142 Ok(plain) => plain,
143 Err(e) => {
144 return Err(FrameError::PayloadTruncated {
145 expected: payload_len as u32,
150 available: e.to_string().len() as u32,
151 });
152 }
153 }
154 } else {
155 on_wire.to_vec()
156 };
157 Ok((
158 Frame {
159 kind,
160 flags,
163 stream_id,
164 correlation_id,
165 payload,
166 },
167 length as usize,
168 ))
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame`, decodes it back, and checks that the decoder consumed
    /// exactly the encoded bytes and reproduced the frame unchanged.
    fn round_trip(frame: Frame) {
        let bytes = encode_frame(&frame);
        let (decoded, consumed) = decode_frame(&bytes).expect("decode");
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, frame);
    }

    /// Returns a zeroed header-sized buffer whose length field is `length`.
    fn header_with_length(length: u32) -> Vec<u8> {
        let mut bytes = vec![0u8; FRAME_HEADER_SIZE];
        bytes[..4].copy_from_slice(&length.to_le_bytes());
        bytes
    }

    #[test]
    fn round_trip_empty_payload() {
        round_trip(Frame::new(MessageKind::Ping, 1, vec![]));
    }

    #[test]
    fn round_trip_with_payload() {
        round_trip(Frame::new(MessageKind::Query, 42, b"SELECT 1".to_vec()));
    }

    #[test]
    fn round_trip_with_stream_and_flags() {
        let frame = Frame::new(MessageKind::Result, 999, vec![0xab; 256])
            .with_stream(7)
            .with_flags(Flags::COMPRESSED | Flags::MORE_FRAMES);
        round_trip(frame);
    }

    #[test]
    fn truncated_header_rejected() {
        assert_eq!(decode_frame(&[]), Err(FrameError::Truncated));
        assert_eq!(
            decode_frame(&[0u8; FRAME_HEADER_SIZE - 1]),
            Err(FrameError::Truncated)
        );
    }

    #[test]
    fn length_below_header_rejected() {
        let too_short = FRAME_HEADER_SIZE as u32 - 1;
        assert_eq!(
            decode_frame(&header_with_length(too_short)),
            Err(FrameError::InvalidLength(too_short))
        );
    }

    #[test]
    fn length_above_max_rejected() {
        // Only expressible when MAX_FRAME_SIZE leaves headroom in a u32.
        if let Some(too_long) = MAX_FRAME_SIZE.checked_add(1) {
            assert_eq!(
                decode_frame(&header_with_length(too_long)),
                Err(FrameError::InvalidLength(too_long))
            );
        }
    }

    #[test]
    fn short_buffer_reports_payload_truncated() {
        let declared = FRAME_HEADER_SIZE as u32 + 4;
        assert_eq!(
            decode_frame(&header_with_length(declared)),
            Err(FrameError::PayloadTruncated {
                expected: declared,
                available: FRAME_HEADER_SIZE as u32,
            })
        );
    }

    #[test]
    fn unknown_kind_rejected() {
        let mut bytes = header_with_length(FRAME_HEADER_SIZE as u32);
        bytes[4] = 0xff;
        assert_eq!(decode_frame(&bytes), Err(FrameError::UnknownKind(0xff)));
    }

    #[test]
    fn unknown_flag_bits_rejected() {
        let mut bytes = header_with_length(FRAME_HEADER_SIZE as u32);
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = 0b1000_0000;
        assert_eq!(
            decode_frame(&bytes),
            Err(FrameError::UnknownFlags(0b1000_0000))
        );
    }

    #[test]
    fn flags_not_allowed_for_kind_rejected() {
        let mut bytes = header_with_length(FRAME_HEADER_SIZE as u32);
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        assert_eq!(
            decode_frame(&bytes),
            Err(FrameError::FlagsNotAllowedForKind {
                kind: MessageKind::Ping as u8,
                flags: Flags::COMPRESSED.bits(),
            })
        );
    }

    #[test]
    fn streaming_decode_two_frames_back_to_back() {
        let f1 = Frame::new(MessageKind::Query, 1, b"a".to_vec());
        let f2 = Frame::new(MessageKind::Query, 2, b"b".to_vec());
        let mut buf = encode_frame(&f1);
        let first_len = buf.len();
        buf.extend(encode_frame(&f2));
        let (got1, n1) = decode_frame(&buf).unwrap();
        let (got2, n2) = decode_frame(&buf[n1..]).unwrap();
        assert_eq!(got1, f1);
        assert_eq!(got2, f2);
        // Each decode must consume exactly its own frame: no more, no less.
        assert_eq!(n1, first_len);
        assert_eq!(n1 + n2, buf.len());
    }

    #[test]
    fn compressed_round_trip_recovers_plaintext() {
        let payload = b"abcabcabcabc".repeat(100);
        let frame =
            Frame::new(MessageKind::Result, 7, payload.clone()).with_flags(Flags::COMPRESSED);
        let bytes = encode_frame(&frame);
        assert!(
            bytes.len() < FRAME_HEADER_SIZE + payload.len(),
            "compressed frame ({}) must be smaller than plaintext payload ({})",
            bytes.len(),
            payload.len(),
        );
        let (decoded, _) = decode_frame(&bytes).expect("decode compressed");
        assert_eq!(decoded.payload, payload);
        assert!(decoded.flags.contains(Flags::COMPRESSED));
    }

    #[test]
    fn uncompressed_frame_decodes_unchanged_when_flag_unset() {
        let payload = b"hello world".to_vec();
        let frame = Frame::new(MessageKind::Result, 1, payload.clone());
        let bytes = encode_frame(&frame);
        let (decoded, _) = decode_frame(&bytes).unwrap();
        assert_eq!(decoded.payload, payload);
        assert!(!decoded.flags.contains(Flags::COMPRESSED));
    }
}