1use super::frame::{Flags, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_FRAME_SIZE};
6
/// Reasons a byte buffer could not be decoded into a [`Frame`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameError {
    /// Fewer bytes than a full frame header were supplied.
    Truncated,
    /// The header's length field lies outside `[FRAME_HEADER_SIZE, MAX_FRAME_SIZE]`;
    /// carries the raw value that was read.
    InvalidLength(u32),
    /// Fewer bytes were available than the frame declared. `decode_frame` also
    /// returns this when a compressed payload fails to decompress.
    PayloadTruncated {
        /// Number of bytes the frame declared.
        expected: u32,
        /// Number of bytes actually supplied.
        available: u32,
    },
    /// The kind byte is not recognised by `MessageKind::from_u8`.
    UnknownKind(u8),
    /// The flag byte has bits set outside the known flag mask.
    UnknownFlags(u8),
    /// The flags are individually valid but not permitted for this message kind.
    FlagsNotAllowedForKind {
        /// Raw kind byte from the header.
        kind: u8,
        /// Raw flag byte from the header.
        flags: u8,
    },
}

impl std::fmt::Display for FrameError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payloads are `Copy`, so match by value.
        match *self {
            FrameError::Truncated => f.write_str("frame header truncated (< 16 bytes)"),
            FrameError::InvalidLength(len) => write!(f, "frame length field invalid: {len}"),
            FrameError::PayloadTruncated {
                expected: want,
                available: got,
            } => write!(f, "frame payload truncated: expected {want} bytes, got {got}"),
            FrameError::UnknownKind(raw) => write!(f, "unknown message kind 0x{raw:02x}"),
            FrameError::UnknownFlags(raw) => write!(f, "unknown flag bits 0x{raw:02x}"),
            FrameError::FlagsNotAllowedForKind {
                kind: raw_kind,
                flags: raw_flags,
            } => write!(
                f,
                "flag bits 0x{raw_flags:02x} not allowed on kind 0x{raw_kind:02x}"
            ),
        }
    }
}

impl std::error::Error for FrameError {}
50
51pub fn frame_len_from_header(header: &[u8; FRAME_HEADER_SIZE]) -> Result<usize, FrameError> {
52 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
53 if length < FRAME_HEADER_SIZE as u32 || length > MAX_FRAME_SIZE {
54 return Err(FrameError::InvalidLength(length));
55 }
56 Ok(length as usize)
57}
58
59pub fn decode_frame_parts(
60 header: &[u8; FRAME_HEADER_SIZE],
61 payload: &[u8],
62) -> Result<Frame, FrameError> {
63 let length = frame_len_from_header(header)?;
64 let expected_payload_len = length - FRAME_HEADER_SIZE;
65 if payload.len() < expected_payload_len {
66 return Err(FrameError::PayloadTruncated {
67 expected: expected_payload_len as u32,
68 available: payload.len() as u32,
69 });
70 }
71
72 let mut bytes = Vec::with_capacity(length);
73 bytes.extend_from_slice(header);
74 bytes.extend_from_slice(&payload[..expected_payload_len]);
75 decode_frame(&bytes).map(|(frame, _)| frame)
76}
77
78pub fn encode_frame(frame: &Frame) -> Vec<u8> {
79 if frame.flags.contains(Flags::COMPRESSED) {
84 return encode_compressed(frame);
85 }
86 let total = frame.encoded_len() as usize;
87 let mut buf = Vec::with_capacity(total);
88 buf.extend_from_slice(&frame.encoded_len().to_le_bytes());
89 buf.push(frame.kind as u8);
90 buf.push(frame.flags.bits());
91 buf.extend_from_slice(&frame.stream_id.to_le_bytes());
92 buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
93 buf.extend_from_slice(&frame.payload);
94 buf
95}
96
97fn encode_compressed(frame: &Frame) -> Vec<u8> {
98 let level = std::env::var("RED_REDWIRE_ZSTD_LEVEL")
102 .ok()
103 .and_then(|s| s.parse::<i32>().ok())
104 .unwrap_or(1);
105 let compressed = match zstd::stream::encode_all(frame.payload.as_slice(), level) {
106 Ok(buf) => buf,
107 Err(_) => {
108 let mut clone = frame.clone();
113 clone.flags = Flags::from_bits(clone.flags.bits() & !Flags::COMPRESSED.bits());
114 return encode_frame(&clone);
115 }
116 };
117 let total = (FRAME_HEADER_SIZE + compressed.len()) as u32;
118 let mut buf = Vec::with_capacity(total as usize);
119 buf.extend_from_slice(&total.to_le_bytes());
120 buf.push(frame.kind as u8);
121 buf.push(frame.flags.bits());
122 buf.extend_from_slice(&frame.stream_id.to_le_bytes());
123 buf.extend_from_slice(&frame.correlation_id.to_le_bytes());
124 buf.extend_from_slice(&compressed);
125 buf
126}
127
128pub fn decode_frame(bytes: &[u8]) -> Result<(Frame, usize), FrameError> {
129 if bytes.len() < FRAME_HEADER_SIZE {
130 return Err(FrameError::Truncated);
131 }
132 let mut header = [0u8; FRAME_HEADER_SIZE];
133 header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
134 let length = frame_len_from_header(&header)? as u32;
135 if (bytes.len() as u32) < length {
136 return Err(FrameError::PayloadTruncated {
137 expected: length,
138 available: bytes.len() as u32,
139 });
140 }
141 let kind = MessageKind::from_u8(bytes[4]).ok_or(FrameError::UnknownKind(bytes[4]))?;
142 let flag_bits = bytes[5];
143 const KNOWN_FLAGS: u8 = 0b0000_0011;
144 if flag_bits & !KNOWN_FLAGS != 0 {
145 return Err(FrameError::UnknownFlags(flag_bits));
146 }
147 let flags = Flags::from_bits(flag_bits);
148 if !kind.permits_flags(flags) {
153 return Err(FrameError::FlagsNotAllowedForKind {
154 kind: bytes[4],
155 flags: flag_bits,
156 });
157 }
158 let stream_id = u16::from_le_bytes([bytes[6], bytes[7]]);
159 let correlation_id = u64::from_le_bytes([
160 bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
161 ]);
162 let payload_len = (length as usize) - FRAME_HEADER_SIZE;
163 let on_wire = &bytes[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len];
164 let payload = if flags.contains(Flags::COMPRESSED) {
165 match zstd::stream::decode_all(on_wire) {
168 Ok(plain) => plain,
169 Err(e) => {
170 return Err(FrameError::PayloadTruncated {
171 expected: payload_len as u32,
176 available: e.to_string().len() as u32,
177 });
178 }
179 }
180 } else {
181 on_wire.to_vec()
182 };
183 Ok((
184 Frame {
185 kind,
186 flags,
189 stream_id,
190 correlation_id,
191 payload,
192 },
193 length as usize,
194 ))
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame`, decodes it back, and checks both the value and the
    /// consumed byte count.
    fn round_trip(frame: Frame) {
        let bytes = encode_frame(&frame);
        let (decoded, consumed) = decode_frame(&bytes).expect("decode");
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, frame);
    }

    #[test]
    fn round_trip_empty_payload() {
        round_trip(Frame::new(MessageKind::Ping, 1, vec![]));
    }

    #[test]
    fn frame_len_from_header_validates_bounds() {
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header[..4].copy_from_slice(&(FRAME_HEADER_SIZE as u32).to_le_bytes());
        assert_eq!(frame_len_from_header(&header).unwrap(), FRAME_HEADER_SIZE);

        header[..4].copy_from_slice(&15u32.to_le_bytes());
        assert_eq!(
            frame_len_from_header(&header),
            Err(FrameError::InvalidLength(15))
        );

        header[..4].copy_from_slice(&(MAX_FRAME_SIZE + 1).to_le_bytes());
        assert_eq!(
            frame_len_from_header(&header),
            Err(FrameError::InvalidLength(MAX_FRAME_SIZE + 1))
        );
    }

    #[test]
    fn round_trip_with_payload() {
        round_trip(Frame::new(MessageKind::Query, 42, b"SELECT 1".to_vec()));
    }

    #[test]
    fn decode_frame_parts_matches_full_buffer_decode() {
        let frame = Frame::new(MessageKind::Result, 42, br#"{"ok":true}"#.to_vec());
        let bytes = encode_frame(&frame);
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
        let payload = &bytes[FRAME_HEADER_SIZE..];

        let decoded = decode_frame_parts(&header, payload).expect("decode parts");
        assert_eq!(decoded, frame);
    }

    #[test]
    fn decode_frame_parts_rejects_short_payload() {
        let frame = Frame::new(MessageKind::Result, 5, b"0123456789".to_vec());
        let bytes = encode_frame(&frame);
        let mut header = [0u8; FRAME_HEADER_SIZE];
        header.copy_from_slice(&bytes[..FRAME_HEADER_SIZE]);
        // Drop the last three payload bytes.
        let payload = &bytes[FRAME_HEADER_SIZE..bytes.len() - 3];

        assert_eq!(
            decode_frame_parts(&header, payload),
            Err(FrameError::PayloadTruncated {
                expected: 10,
                available: 7,
            })
        );
    }

    #[test]
    fn round_trip_with_stream_and_flags() {
        let frame = Frame::new(MessageKind::Result, 999, vec![0xab; 256])
            .with_stream(7)
            .with_flags(Flags::COMPRESSED | Flags::MORE_FRAMES);
        round_trip(frame);
    }

    #[test]
    fn truncated_header_rejected() {
        assert_eq!(decode_frame(&[]), Err(FrameError::Truncated));
        assert_eq!(decode_frame(&[0; 15]), Err(FrameError::Truncated));
    }

    #[test]
    fn length_below_header_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&15u32.to_le_bytes());
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::InvalidLength(15))
        ));
    }

    #[test]
    fn payload_shorter_than_declared_length_rejected() {
        let frame = Frame::new(MessageKind::Query, 3, b"SELECT 1".to_vec());
        let mut bytes = encode_frame(&frame);
        let full_len = bytes.len() as u32;
        bytes.pop();

        assert_eq!(
            decode_frame(&bytes),
            Err(FrameError::PayloadTruncated {
                expected: full_len,
                available: full_len - 1,
            })
        );
    }

    #[test]
    fn unknown_kind_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = 0xff;
        assert_eq!(decode_frame(&bytes), Err(FrameError::UnknownKind(0xff)));
    }

    #[test]
    fn unknown_flag_bits_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = 0b1000_0000;
        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::UnknownFlags(_))
        ));
    }

    #[test]
    fn flags_not_allowed_for_kind_rejected() {
        let mut bytes = vec![0u8; 16];
        bytes[..4].copy_from_slice(&16u32.to_le_bytes());
        bytes[4] = MessageKind::Ping as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        match decode_frame(&bytes) {
            Err(FrameError::FlagsNotAllowedForKind { kind, flags }) => {
                assert_eq!(kind, MessageKind::Ping as u8);
                assert_eq!(flags, Flags::COMPRESSED.bits());
            }
            other => panic!("expected FlagsNotAllowedForKind, got {other:?}"),
        }
    }

    #[test]
    fn corrupt_compressed_payload_rejected() {
        // Not a zstd frame: the magic number does not match.
        let garbage = [0xde, 0xad, 0xbe, 0xef];
        let mut bytes = vec![0u8; FRAME_HEADER_SIZE];
        bytes[..4].copy_from_slice(&((FRAME_HEADER_SIZE + garbage.len()) as u32).to_le_bytes());
        bytes[4] = MessageKind::Result as u8;
        bytes[5] = Flags::COMPRESSED.bits();
        bytes.extend_from_slice(&garbage);

        assert!(matches!(
            decode_frame(&bytes),
            Err(FrameError::PayloadTruncated { .. })
        ));
    }

    #[test]
    fn streaming_decode_two_frames_back_to_back() {
        let f1 = Frame::new(MessageKind::Query, 1, b"a".to_vec());
        let f2 = Frame::new(MessageKind::Query, 2, b"b".to_vec());
        let mut buf = encode_frame(&f1);
        buf.extend(encode_frame(&f2));
        let (got1, n1) = decode_frame(&buf).unwrap();
        let (got2, _n2) = decode_frame(&buf[n1..]).unwrap();
        assert_eq!(got1, f1);
        assert_eq!(got2, f2);
    }

    #[test]
    fn compressed_round_trip_recovers_plaintext() {
        let payload = b"abcabcabcabc".repeat(100);
        let frame =
            Frame::new(MessageKind::Result, 7, payload.clone()).with_flags(Flags::COMPRESSED);
        let bytes = encode_frame(&frame);
        assert!(
            bytes.len() < FRAME_HEADER_SIZE + payload.len(),
            "compressed frame ({}) must be smaller than plaintext payload ({})",
            bytes.len(),
            payload.len(),
        );
        let (decoded, _) = decode_frame(&bytes).expect("decode compressed");
        assert_eq!(decoded.payload, payload);
        assert!(decoded.flags.contains(Flags::COMPRESSED));
    }

    #[test]
    fn output_stream_lifecycle_envelopes_round_trip() {
        let open = Frame::new(
            MessageKind::OpenStream,
            10,
            br#"{"sql":"SELECT 1","opts":{}}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(open.clone());
        assert_eq!(encode_frame(&open)[4], 0x29);

        let ack = Frame::new(
            MessageKind::OpenAck,
            10,
            br#"{"lease_handle":"42","resumable":false,"snapshot_lsn":1234}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(ack.clone());
        assert_eq!(encode_frame(&ack)[4], 0x2A);

        let chunk = Frame::new(
            MessageKind::StreamChunk,
            10,
            br#"{"seq":0,"rows":[{"a":1}],"terminal":false}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(chunk.clone());
        assert_eq!(encode_frame(&chunk)[4], 0x2B);

        let serr = Frame::new(
            MessageKind::StreamError,
            10,
            br#"{"code":"unknown_stream","message":"x"}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(serr.clone());
        assert_eq!(encode_frame(&serr)[4], 0x2C);

        let end = Frame::new(
            MessageKind::StreamEnd,
            10,
            br#"{"stats":{"row_count":1}}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(end.clone());
        assert_eq!(encode_frame(&end)[4], 0x25);

        let cancel = Frame::new(
            MessageKind::StreamCancel,
            10,
            br#"{"reason":"client-abort"}"#.to_vec(),
        )
        .with_stream(7);
        round_trip(cancel.clone());
        assert_eq!(encode_frame(&cancel)[4], 0x2D);
    }

    #[test]
    fn input_stream_envelopes_round_trip() {
        let open_in = Frame::new(
            MessageKind::OpenStream,
            20,
            br#"{"direction":"in","target":"t","columns":["id","name"]}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(open_in.clone());
        assert_eq!(encode_frame(&open_in)[4], 0x29);

        let chunk_in = Frame::new(
            MessageKind::StreamChunk,
            20,
            br#"{"seq":0,"rows":[{"id":1,"name":"a"}],"terminal":false}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(chunk_in.clone());
        assert_eq!(encode_frame(&chunk_in)[4], 0x2B);

        let chunk_terminal = Frame::new(
            MessageKind::StreamChunk,
            20,
            br#"{"seq":2,"rows":[],"terminal":true}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(chunk_terminal.clone());
        assert_eq!(encode_frame(&chunk_terminal)[4], 0x2B);

        let end = Frame::new(
            MessageKind::StreamEnd,
            20,
            br#"{"stats":{"row_count":3,"chunk_count":2,"committed_rid":42,"snapshot_lsn":40,"cancelled":false}}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(end.clone());
        assert_eq!(encode_frame(&end)[4], 0x25);

        let serr = Frame::new(
            MessageKind::StreamError,
            20,
            br#"{"code":"invalid_row","message":"x","chunk_seq":1,"recoverable_rid":41}"#.to_vec(),
        )
        .with_stream(5);
        round_trip(serr.clone());
        assert_eq!(encode_frame(&serr)[4], 0x2C);
    }

    #[test]
    fn queue_wait_envelopes_round_trip() {
        let open = Frame::new(
            MessageKind::QueueWaitOpen,
            10,
            br#"{"queue":"jobs","consumer":"w1","count":1,"wait_ms":5000}"#.to_vec(),
        )
        .with_stream(3);
        round_trip(open.clone());
        assert_eq!(encode_frame(&open)[4], 0x2E);

        let push = Frame::new(
            MessageKind::QueueEventPush,
            10,
            br#"{"message_id":"42","payload":{"hello":"world"},"consumer":"w1","delivery_count":1}"#
                .to_vec(),
        )
        .with_stream(3);
        round_trip(push.clone());
        assert_eq!(encode_frame(&push)[4], 0x2F);
    }

    #[test]
    fn uncompressed_frame_decodes_unchanged_when_flag_unset() {
        let payload = b"hello world".to_vec();
        let frame = Frame::new(MessageKind::Result, 1, payload.clone());
        let bytes = encode_frame(&frame);
        let (decoded, _) = decode_frame(&bytes).unwrap();
        assert_eq!(decoded.payload, payload);
        assert!(!decoded.flags.contains(Flags::COMPRESSED));
    }
}