1use std::io::{Read, Write};
26
27use crate::{DEFAULT_MAX_FRAME_PAYLOAD, FrameError, PROTOCOL_VERSION};
28
/// Wire type byte of a `W` (window) frame; decoded as `Frame::Window`.
pub const FRAME_TYPE_WINDOW: u8 = b'W';
/// Wire type byte of a `J` (JSON) frame; decoded as `Frame::Json`.
pub const FRAME_TYPE_JSON: u8 = b'J';
/// Wire type byte of a `C` (zlib-compressed) frame; decoded as `Frame::Compressed`.
pub const FRAME_TYPE_COMPRESSED: u8 = b'C';
/// Wire type byte of an `A` (ack) frame; decoded as `Frame::Ack`.
pub const FRAME_TYPE_ACK: u8 = b'A';
/// Wire type byte of the legacy `D` data frame. This module has no encoder for
/// it: the decoder only delimits such frames and yields them as `Frame::Unknown`.
pub const FRAME_TYPE_DATA_LEGACY: u8 = b'D';
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
46pub enum FrameType {
47 Window,
50 Json,
52 Compressed,
54 Ack,
57}
58
59impl FrameType {
60 #[must_use]
62 pub const fn wire_byte(self) -> u8 {
63 match self {
64 Self::Window => FRAME_TYPE_WINDOW,
65 Self::Json => FRAME_TYPE_JSON,
66 Self::Compressed => FRAME_TYPE_COMPRESSED,
67 Self::Ack => FRAME_TYPE_ACK,
68 }
69 }
70}
71
/// A single frame produced by `FrameDecoder::next_frame`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Frame {
    /// A `W` frame: one big-endian `u32` after the two-byte prefix.
    Window {
        /// The value carried by the frame (what `encode_window` takes as `count`).
        count: u32,
    },
    /// A `J` frame: `seq`, then a length-prefixed payload.
    Json {
        /// Big-endian `u32` immediately after the two-byte prefix.
        seq: u32,
        /// The payload bytes, copied verbatim; the decoder does not validate
        /// them as JSON.
        payload: Vec<u8>,
    },
    /// A `C` frame whose zlib body has been inflated.
    Compressed {
        /// The inflated body, at most the decoder's `max_frame_payload` bytes.
        /// The decoder does not parse it further; `encode_compressed` calls its
        /// input `inner_frames`, so presumably it holds concatenated encoded
        /// frames — TODO confirm with callers.
        decompressed: Vec<u8>,
    },
    /// An `A` frame: one big-endian `u32` after the two-byte prefix.
    Ack {
        /// The value carried by the frame (what `encode_ack` takes as `seq`).
        seq: u32,
    },
    /// A frame the decoder can delimit but does not interpret. In this module
    /// it is produced only for legacy `D` frames.
    Unknown {
        /// The frame's wire type byte.
        frame_type: u8,
        /// The complete frame bytes, including the version and type prefix.
        raw: Vec<u8>,
    },
}
129
130#[must_use]
136pub fn encode_window(count: u32) -> [u8; 6] {
137 let mut out = [0u8; 6];
138 out[0] = PROTOCOL_VERSION;
139 out[1] = FRAME_TYPE_WINDOW;
140 out[2..6].copy_from_slice(&count.to_be_bytes());
141 out
142}
143
144#[must_use]
146pub fn encode_ack(seq: u32) -> [u8; 6] {
147 let mut out = [0u8; 6];
148 out[0] = PROTOCOL_VERSION;
149 out[1] = FRAME_TYPE_ACK;
150 out[2..6].copy_from_slice(&seq.to_be_bytes());
151 out
152}
153
154#[must_use]
159pub fn encode_json_frame(seq: u32, payload: &[u8]) -> Vec<u8> {
160 let mut out = Vec::with_capacity(10 + payload.len());
161 out.push(PROTOCOL_VERSION);
162 out.push(FRAME_TYPE_JSON);
163 out.extend_from_slice(&seq.to_be_bytes());
164 let len = u32::try_from(payload.len()).unwrap_or(u32::MAX);
168 out.extend_from_slice(&len.to_be_bytes());
169 out.extend_from_slice(payload);
170 out
171}
172
173pub fn encode_compressed(level: u32, inner_frames: &[u8]) -> Result<Vec<u8>, FrameError> {
181 use flate2::Compression;
182 use flate2::write::ZlibEncoder;
183
184 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(level));
185 encoder
186 .write_all(inner_frames)
187 .map_err(|e| FrameError::Compression(e.to_string()))?;
188 let compressed = encoder
189 .finish()
190 .map_err(|e| FrameError::Compression(e.to_string()))?;
191
192 let len = u32::try_from(compressed.len()).map_err(|_| FrameError::PayloadTooLarge {
193 requested: compressed.len(),
194 limit: u32::MAX as usize,
195 })?;
196
197 let mut out = Vec::with_capacity(6 + compressed.len());
198 out.push(PROTOCOL_VERSION);
199 out.push(FRAME_TYPE_COMPRESSED);
200 out.extend_from_slice(&len.to_be_bytes());
201 out.extend_from_slice(&compressed);
202 Ok(out)
203}
204
/// Incremental, push-style decoder for the framed wire protocol.
///
/// Append received bytes with [`FrameDecoder::feed`] and pull complete frames
/// with [`FrameDecoder::next_frame`]. The latter returns `Ok(None)` until a whole
/// frame is buffered, so input may arrive in arbitrarily small pieces.
#[derive(Debug)]
pub struct FrameDecoder {
    // Received bytes; everything before `read_pos` has already been decoded.
    buf: Vec<u8>,
    // Offset into `buf` of the first byte not yet consumed by a decoded frame.
    read_pos: usize,
    // Upper bound in bytes that the decoder checks declared payload lengths and
    // decompressed output against.
    max_frame_payload: usize,
}
239
240impl Default for FrameDecoder {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl FrameDecoder {
247 #[must_use]
249 pub const fn new() -> Self {
250 Self::with_max_frame_payload(DEFAULT_MAX_FRAME_PAYLOAD)
251 }
252
253 #[must_use]
259 pub const fn with_max_frame_payload(max_frame_payload: usize) -> Self {
260 Self {
261 buf: Vec::new(),
262 read_pos: 0,
263 max_frame_payload,
264 }
265 }
266
267 pub fn feed(&mut self, bytes: &[u8]) {
270 if self.read_pos > 0 && self.read_pos >= self.buf.len() / 2 {
273 self.buf.drain(..self.read_pos);
274 self.read_pos = 0;
275 }
276 self.buf.extend_from_slice(bytes);
277 }
278
279 #[must_use]
281 pub const fn pending(&self) -> usize {
282 self.buf.len() - self.read_pos
283 }
284
285 pub fn next_frame(&mut self) -> Result<Option<Frame>, FrameError> {
295 let avail = self.pending();
296 if avail < 2 {
297 return Ok(None);
298 }
299 let header = &self.buf[self.read_pos..self.read_pos + 2];
300 if header[0] != PROTOCOL_VERSION {
301 return Err(FrameError::UnsupportedVersion(header[0]));
302 }
303 let frame_type = header[1];
304 match frame_type {
305 FRAME_TYPE_WINDOW => Ok(self.try_decode_window()),
306 FRAME_TYPE_ACK => Ok(self.try_decode_ack()),
307 FRAME_TYPE_JSON => self.try_decode_json(),
308 FRAME_TYPE_COMPRESSED => self.try_decode_compressed(),
309 FRAME_TYPE_DATA_LEGACY => self.try_decode_unknown_with_seq_count(b'D'),
310 other => Err(FrameError::UnknownFrameType(other)),
311 }
312 }
313
314 fn read_at<const M: usize>(&self, offset: usize) -> Option<[u8; M]> {
316 let start = self.read_pos + offset;
317 if self.buf.len() < start + M {
318 return None;
319 }
320 let mut out = [0u8; M];
321 out.copy_from_slice(&self.buf[start..start + M]);
322 Some(out)
323 }
324
325 fn try_decode_window(&mut self) -> Option<Frame> {
326 if self.pending() < 6 {
328 return None;
329 }
330 let count = u32::from_be_bytes(
331 self.read_at::<4>(2)
332 .expect("just verified ≥ 6 bytes pending"),
333 );
334 self.read_pos += 6;
335 Some(Frame::Window { count })
336 }
337
338 fn try_decode_ack(&mut self) -> Option<Frame> {
339 if self.pending() < 6 {
341 return None;
342 }
343 let seq = u32::from_be_bytes(
344 self.read_at::<4>(2)
345 .expect("just verified ≥ 6 bytes pending"),
346 );
347 self.read_pos += 6;
348 Some(Frame::Ack { seq })
349 }
350
351 fn try_decode_json(&mut self) -> Result<Option<Frame>, FrameError> {
352 if self.pending() < 10 {
354 return Ok(None);
355 }
356 let seq = u32::from_be_bytes(self.read_at::<4>(2).expect("≥ 10 pending"));
357 let len_raw = u32::from_be_bytes(self.read_at::<4>(6).expect("≥ 10 pending"));
358 let len = len_raw as usize;
359 if len > self.max_frame_payload {
360 return Err(FrameError::PayloadTooLarge {
361 requested: len,
362 limit: self.max_frame_payload,
363 });
364 }
365 if self.pending() < 10 + len {
366 return Ok(None);
367 }
368 let start = self.read_pos + 10;
369 let payload = self.buf[start..start + len].to_vec();
370 self.read_pos += 10 + len;
371 Ok(Some(Frame::Json { seq, payload }))
372 }
373
374 fn try_decode_compressed(&mut self) -> Result<Option<Frame>, FrameError> {
375 if self.pending() < 6 {
377 return Ok(None);
378 }
379 let len_raw = u32::from_be_bytes(self.read_at::<4>(2).expect("≥ 6 pending"));
380 let len = len_raw as usize;
381 if len > self.max_frame_payload {
382 return Err(FrameError::PayloadTooLarge {
383 requested: len,
384 limit: self.max_frame_payload,
385 });
386 }
387 if self.pending() < 6 + len {
388 return Ok(None);
389 }
390 let start = self.read_pos + 6;
391 let compressed = &self.buf[start..start + len];
392 let decompressed = decompress_capped(compressed, self.max_frame_payload)?;
393 self.read_pos += 6 + len;
394 Ok(Some(Frame::Compressed { decompressed }))
395 }
396
397 fn try_decode_unknown_with_seq_count(
404 &mut self,
405 type_byte: u8,
406 ) -> Result<Option<Frame>, FrameError> {
407 if self.pending() < 10 {
408 return Ok(None);
409 }
410 let pair_count = u32::from_be_bytes(self.read_at::<4>(6).expect("≥ 10 pending")) as usize;
411
412 let mut cursor = 10;
415 for _ in 0..pair_count {
416 if self.pending() < cursor + 4 {
418 return Ok(None);
419 }
420 let key_len = u32::from_be_bytes(
421 self.read_at::<4>(cursor)
422 .expect("just bounded by pending check"),
423 ) as usize;
424 if key_len > self.max_frame_payload {
425 return Err(FrameError::PayloadTooLarge {
426 requested: key_len,
427 limit: self.max_frame_payload,
428 });
429 }
430 cursor += 4 + key_len;
431
432 if self.pending() < cursor + 4 {
433 return Ok(None);
434 }
435 let val_len = u32::from_be_bytes(
436 self.read_at::<4>(cursor)
437 .expect("just bounded by pending check"),
438 ) as usize;
439 if val_len > self.max_frame_payload {
440 return Err(FrameError::PayloadTooLarge {
441 requested: val_len,
442 limit: self.max_frame_payload,
443 });
444 }
445 cursor += 4 + val_len;
446 }
447
448 if self.pending() < cursor {
449 return Ok(None);
450 }
451 let raw = self.buf[self.read_pos..self.read_pos + cursor].to_vec();
452 self.read_pos += cursor;
453 Ok(Some(Frame::Unknown {
454 frame_type: type_byte,
455 raw,
456 }))
457 }
458}
459
460fn decompress_capped(compressed: &[u8], limit: usize) -> Result<Vec<u8>, FrameError> {
465 use flate2::read::ZlibDecoder;
466
467 let mut out = Vec::new();
470 let take_limit = u64::try_from(limit).unwrap_or(u64::MAX);
471 let take_plus_one = take_limit.saturating_add(1);
472
473 let decoder = ZlibDecoder::new(compressed);
474 let mut bounded = decoder.take(take_plus_one);
475 bounded
476 .read_to_end(&mut out)
477 .map_err(|e| FrameError::Decompression(e.to_string()))?;
478
479 if out.len() > limit {
480 return Err(FrameError::DecompressedTooLarge { limit });
481 }
482
483 Ok(out)
484}
485
/// Unit and property tests for the frame encoders and `FrameDecoder`.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a legacy `D` frame: the two-byte prefix, `seq`, the pair count,
    /// then every key and value as a big-endian `u32` length plus its bytes.
    fn legacy_d_frame(seq: u32, pairs: &[(&[u8], &[u8])]) -> Vec<u8> {
        let mut frame = vec![b'2', FRAME_TYPE_DATA_LEGACY];
        frame.extend_from_slice(&seq.to_be_bytes());
        frame.extend_from_slice(&u32::try_from(pairs.len()).unwrap().to_be_bytes());
        for &(key, value) in pairs {
            for field in [key, value] {
                frame.extend_from_slice(&u32::try_from(field.len()).unwrap().to_be_bytes());
                frame.extend_from_slice(field);
            }
        }
        frame
    }

    #[test]
    fn encode_window_layout() {
        let bytes = encode_window(42);
        assert_eq!(bytes[0], b'2');
        assert_eq!(bytes[1], b'W');
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            42
        );
    }

    #[test]
    fn encode_ack_layout() {
        let bytes = encode_ack(7);
        assert_eq!(bytes[0], b'2');
        assert_eq!(bytes[1], b'A');
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            7
        );
    }

    #[test]
    fn encode_json_frame_layout() {
        let bytes = encode_json_frame(13, b"hello");
        assert_eq!(&bytes[..2], b"2J");
        assert_eq!(
            u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]),
            13
        );
        assert_eq!(
            u32::from_be_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]),
            5
        );
        assert_eq!(&bytes[10..], b"hello");
    }

    #[test]
    fn frame_type_wire_bytes_match_constants() {
        assert_eq!(FrameType::Window.wire_byte(), FRAME_TYPE_WINDOW);
        assert_eq!(FrameType::Json.wire_byte(), FRAME_TYPE_JSON);
        assert_eq!(FrameType::Compressed.wire_byte(), FRAME_TYPE_COMPRESSED);
        assert_eq!(FrameType::Ack.wire_byte(), FRAME_TYPE_ACK);
    }

    #[test]
    fn decode_window_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_window(123));
        let f = d.next_frame().unwrap().unwrap();
        assert_eq!(f, Frame::Window { count: 123 });
        assert!(d.next_frame().unwrap().is_none());
    }

    #[test]
    fn decode_ack_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_ack(987_654));
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 987_654 }));
    }

    #[test]
    fn decode_json_round_trip() {
        let mut d = FrameDecoder::new();
        d.feed(&encode_json_frame(1, br#"{"k":"v"}"#));
        let f = d.next_frame().unwrap().unwrap();
        let Frame::Json { seq, payload } = f else {
            panic!("expected Json")
        };
        assert_eq!(seq, 1);
        assert_eq!(payload, br#"{"k":"v"}"#);
    }

    #[test]
    fn decode_handles_concatenated_frames() {
        let mut d = FrameDecoder::new();
        let mut feed = Vec::new();
        feed.extend_from_slice(&encode_window(2));
        feed.extend_from_slice(&encode_json_frame(1, b"a"));
        feed.extend_from_slice(&encode_json_frame(2, b"bb"));
        feed.extend_from_slice(&encode_ack(2));
        d.feed(&feed);

        assert_eq!(d.next_frame().unwrap(), Some(Frame::Window { count: 2 }));
        let Some(Frame::Json { seq: 1, payload }) = d.next_frame().unwrap() else {
            panic!()
        };
        assert_eq!(payload, b"a");
        let Some(Frame::Json { seq: 2, payload }) = d.next_frame().unwrap() else {
            panic!()
        };
        assert_eq!(payload, b"bb");
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 2 }));
        assert!(d.next_frame().unwrap().is_none());
    }

    #[test]
    fn decode_handles_byte_at_a_time_feeds() {
        let mut d = FrameDecoder::new();
        let frame = encode_json_frame(5, b"abcdefgh");
        for byte in &frame {
            assert!(d.next_frame().unwrap().is_none());
            d.feed(std::slice::from_ref(byte));
        }
        let Frame::Json { seq, payload } = d.next_frame().unwrap().unwrap() else {
            panic!()
        };
        assert_eq!(seq, 5);
        assert_eq!(payload, b"abcdefgh");
    }

    #[test]
    fn pending_counts_unconsumed_bytes() {
        let mut d = FrameDecoder::new();
        assert_eq!(d.pending(), 0);
        d.feed(&encode_ack(1));
        d.feed(&encode_window(2)[..3]);
        assert_eq!(d.pending(), 9);
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 1 }));
        assert_eq!(d.pending(), 3);
        // Only half of the window frame is buffered.
        assert!(d.next_frame().unwrap().is_none());
    }

    #[test]
    fn decode_compressed_round_trip() {
        let inner = [
            encode_json_frame(1, b"hello").as_slice(),
            encode_json_frame(2, b"world").as_slice(),
        ]
        .concat();
        let outer = encode_compressed(6, &inner).unwrap();

        let mut d = FrameDecoder::new();
        d.feed(&outer);
        let Frame::Compressed { decompressed } = d.next_frame().unwrap().unwrap() else {
            panic!()
        };
        assert_eq!(decompressed, inner);
    }

    #[test]
    fn decode_compressed_waits_for_full_body() {
        let inner = encode_json_frame(9, b"payload");
        let outer = encode_compressed(6, &inner).unwrap();
        let (head, tail) = outer.split_at(outer.len() - 1);

        let mut d = FrameDecoder::new();
        d.feed(head);
        assert!(d.next_frame().unwrap().is_none());
        assert_eq!(d.pending(), head.len());

        d.feed(tail);
        let Some(Frame::Compressed { decompressed }) = d.next_frame().unwrap() else {
            panic!("expected Compressed")
        };
        assert_eq!(decompressed, inner);
        assert_eq!(d.pending(), 0);
    }

    #[test]
    fn decode_reports_corrupt_zlib_body() {
        let mut frame = vec![b'2', FRAME_TYPE_COMPRESSED];
        frame.extend_from_slice(&3u32.to_be_bytes());
        frame.extend_from_slice(&[0xff, 0xff, 0xff]);

        let mut d = FrameDecoder::new();
        d.feed(&frame);
        assert!(matches!(d.next_frame(), Err(FrameError::Decompression(_))));
    }

    #[test]
    fn decode_rejects_bad_version() {
        let mut d = FrameDecoder::new();
        d.feed(&[b'1', b'W', 0, 0, 0, 1]);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::UnsupportedVersion(b'1'))
        ));
    }

    #[test]
    fn decode_rejects_unknown_frame_type() {
        let mut d = FrameDecoder::new();
        d.feed(&[b'2', b'Z', 0, 0, 0, 1]);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::UnknownFrameType(b'Z'))
        ));
    }

    #[test]
    fn decode_caps_oversize_json_payload() {
        let mut d = FrameDecoder::with_max_frame_payload(16);
        let mut buf = vec![b'2', b'J', 0, 0, 0, 1];
        // Declared length 100 exceeds the 16-byte cap; no payload bytes needed.
        buf.extend_from_slice(&100u32.to_be_bytes());
        d.feed(&buf);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn decode_caps_oversize_legacy_field() {
        let mut d = FrameDecoder::with_max_frame_payload(16);
        let mut buf = vec![b'2', FRAME_TYPE_DATA_LEGACY];
        buf.extend_from_slice(&0u32.to_be_bytes());
        buf.extend_from_slice(&1u32.to_be_bytes());
        // Key length 100 exceeds the cap; its bytes never need to arrive.
        buf.extend_from_slice(&100u32.to_be_bytes());
        d.feed(&buf);
        assert!(matches!(
            d.next_frame(),
            Err(FrameError::PayloadTooLarge {
                requested: 100,
                limit: 16
            })
        ));
    }

    #[test]
    fn decode_caps_zlib_bomb() {
        let original = vec![0u8; 1024 * 64];
        let frame = encode_compressed(9, &original).unwrap();
        let mut d = FrameDecoder::with_max_frame_payload(1024);
        d.feed(&frame);
        // Either the compressed length or the inflated size trips the 1 KiB cap;
        // both are acceptable outcomes.
        match d.next_frame() {
            Err(FrameError::DecompressedTooLarge { .. } | FrameError::PayloadTooLarge { .. }) => {}
            other => panic!("expected size-related error, got {other:?}"),
        }
    }

    #[test]
    fn legacy_d_frame_is_decoded_as_unknown_and_advances() {
        let legacy = legacy_d_frame(5, &[(b"foo".as_slice(), b"bar".as_slice())]);
        // Follow the legacy frame with an ack: decoding the first must leave the
        // read position exactly at the second.
        let mut stream = legacy.clone();
        stream.extend_from_slice(&encode_ack(5));

        let mut d = FrameDecoder::new();
        d.feed(&stream);
        let f = d.next_frame().unwrap().unwrap();
        let Frame::Unknown { frame_type, raw } = f else {
            panic!()
        };
        assert_eq!(frame_type, b'D');
        assert_eq!(&raw[..2], b"2D");
        assert_eq!(raw, legacy);
        assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq: 5 }));
    }

    #[test]
    fn legacy_d_frame_byte_at_a_time() {
        let frame = legacy_d_frame(7, &[(b"k".as_slice(), b"value".as_slice())]);
        let mut d = FrameDecoder::new();
        for byte in &frame {
            assert!(d.next_frame().unwrap().is_none());
            d.feed(std::slice::from_ref(byte));
        }
        let Some(Frame::Unknown { frame_type, raw }) = d.next_frame().unwrap() else {
            panic!("expected Unknown")
        };
        assert_eq!(frame_type, FRAME_TYPE_DATA_LEGACY);
        assert_eq!(raw, frame);
        assert_eq!(d.pending(), 0);
    }

    #[test]
    fn decoder_compacts_after_consuming_half() {
        let mut d = FrameDecoder::new();
        for _ in 0..32 {
            d.feed(&encode_ack(1));
            let _ = d.next_frame().unwrap();
        }
        assert!(d.buf.capacity() < 1024, "buf cap = {}", d.buf.capacity());
    }

    proptest! {
        #[test]
        fn prop_json_frame_round_trip(seq: u32, payload: Vec<u8>) {
            let bytes = encode_json_frame(seq, &payload);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            let frame = d.next_frame().unwrap().unwrap();
            let Frame::Json { seq: got_seq, payload: got_payload } = frame else {
                panic!("expected Json")
            };
            prop_assert_eq!(got_seq, seq);
            prop_assert_eq!(got_payload, payload);
            prop_assert!(d.next_frame().unwrap().is_none());
        }

        #[test]
        fn prop_window_round_trip(count: u32) {
            let bytes = encode_window(count);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            prop_assert_eq!(d.next_frame().unwrap(), Some(Frame::Window { count }));
        }

        #[test]
        fn prop_ack_round_trip(seq: u32) {
            let bytes = encode_ack(seq);
            let mut d = FrameDecoder::new();
            d.feed(&bytes);
            prop_assert_eq!(d.next_frame().unwrap(), Some(Frame::Ack { seq }));
        }

        #[test]
        fn prop_compressed_round_trip(payloads in proptest::collection::vec(any::<Vec<u8>>(), 1..16)) {
            let mut inner = Vec::new();
            for (i, p) in payloads.iter().enumerate() {
                let seq = u32::try_from(i + 1).unwrap_or(u32::MAX);
                inner.extend_from_slice(&encode_json_frame(seq, p));
            }
            let outer = encode_compressed(3, &inner).unwrap();
            let mut d = FrameDecoder::new();
            d.feed(&outer);
            let Some(Frame::Compressed { decompressed }) = d.next_frame().unwrap() else {
                panic!()
            };
            prop_assert_eq!(decompressed, inner);
        }

        #[test]
        fn prop_decoder_does_not_panic(bytes in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let mut d = FrameDecoder::with_max_frame_payload(8 * 1024);
            d.feed(&bytes);
            // Bounded loop: stop at the first "need more bytes" or error.
            for _ in 0..1024 {
                match d.next_frame() {
                    Ok(Some(_)) => {}
                    Ok(None) | Err(_) => break,
                }
            }
        }
    }
}