use bytes::{Buf, BytesMut};
/// A single decoded event-stream message.
#[derive(Clone, Debug, PartialEq)]
pub(super) struct Frame {
    /// Value of the `:event-type` string header; empty when the frame carried none.
    pub event_type: String,
    /// Value of the `:exception-type` string header, kept only when the frame's
    /// `:message-type` is `"exception"`.
    pub exception_type: Option<String>,
    /// Raw payload bytes between the header section and the trailing 4-byte checksum.
    pub payload: Vec<u8>,
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub(super) enum FrameError {
#[error("frame too large: {0} bytes")]
TooLarge(u32),
#[error("malformed frame: total={total} headers={headers}")]
Malformed {
total: u32,
headers: u32,
},
}
/// Largest total frame length (1 MiB) the decoder accepts; a prelude declaring more
/// is rejected before waiting for the rest of the frame.
const MAX_FRAME_BYTES: u32 = 1 << 20;
/// Incremental decoder for length-prefixed event-stream frames.
///
/// Bytes are pushed in with `feed`; complete frames are pulled out with `next_frame`.
#[derive(Default, Debug)]
pub(super) struct EventStreamDecoder {
    /// Bytes received but not yet consumed as part of a complete frame.
    buf: BytesMut,
}
impl EventStreamDecoder {
    /// Length of the fixed prelude: total length, headers length and prelude CRC (u32 each).
    const PRELUDE_LEN: usize = 12;
    /// Length of the trailing message CRC.
    const MESSAGE_CRC_LEN: usize = 4;

    /// Creates a decoder with an empty buffer.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Appends bytes received from the transport; they are decoded lazily by
    /// [`Self::next_frame`].
    pub(super) fn feed(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes);
    }

    /// Decodes the next complete frame from the buffered bytes.
    ///
    /// Wire layout (integers big-endian):
    /// `total_len: u32 | headers_len: u32 | prelude_crc: u32 | headers | payload | message_crc: u32`.
    /// Neither CRC field is validated.
    ///
    /// Returns `Ok(None)` while the buffer does not yet hold a whole frame; feed more
    /// bytes and call again. On success the frame's bytes are removed from the buffer.
    ///
    /// Only the string-valued `:event-type`, `:exception-type` and `:message-type`
    /// headers are extracted; every other header, of any value type, is skipped.
    /// `exception_type` is reported only when `:message-type` is `"exception"`.
    ///
    /// # Errors
    ///
    /// - [`FrameError::TooLarge`] if the declared total length exceeds `MAX_FRAME_BYTES`.
    /// - [`FrameError::Malformed`] if the declared lengths are inconsistent, or a header
    ///   runs past the header section or uses an unknown value type.
    ///
    /// On error nothing is consumed, so repeated calls keep returning the same error.
    pub(super) fn next_frame(&mut self) -> Result<Option<Frame>, FrameError> {
        if self.buf.len() < Self::PRELUDE_LEN {
            return Ok(None);
        }
        let total_len = u32::from_be_bytes([self.buf[0], self.buf[1], self.buf[2], self.buf[3]]);
        let headers_len = u32::from_be_bytes([self.buf[4], self.buf[5], self.buf[6], self.buf[7]]);
        if total_len > MAX_FRAME_BYTES {
            return Err(FrameError::TooLarge(total_len));
        }
        let malformed = || FrameError::Malformed {
            total: total_len,
            headers: headers_len,
        };
        // Smallest legal frame: prelude + message CRC, with no headers and no payload.
        let min_len = (Self::PRELUDE_LEN + Self::MESSAGE_CRC_LEN) as u32;
        if total_len < min_len || headers_len > total_len - min_len {
            return Err(malformed());
        }
        // Compare in usize: narrowing the buffer length to u32 could truncate it.
        let total = total_len as usize;
        if self.buf.len() < total {
            return Ok(None);
        }

        // Parse from a borrowed view so a malformed header leaves the buffer untouched
        // instead of stranding the decoder mid-frame.
        let headers_end = Self::PRELUDE_LEN + headers_len as usize;
        let frame = &self.buf[..total];
        let (event_type, exception_type, message_type) =
            Self::parse_headers(&frame[Self::PRELUDE_LEN..headers_end]).ok_or_else(malformed)?;
        let payload = frame[headers_end..total - Self::MESSAGE_CRC_LEN].to_vec();
        self.buf.advance(total);

        let exception_type = if message_type.as_deref() == Some("exception") {
            exception_type
        } else {
            None
        };
        Ok(Some(Frame {
            event_type,
            exception_type,
            payload,
        }))
    }

    /// Walks one frame's header section and returns the string values of
    /// `(:event-type, :exception-type, :message-type)`.
    ///
    /// Each header is `name_len: u8 | name | value_type: u8 | value`. Returns `None`
    /// if a header runs past the end of the section or has an unknown value type.
    /// When a header repeats, its last string value wins.
    fn parse_headers(mut rest: &[u8]) -> Option<(String, Option<String>, Option<String>)> {
        let mut event_type = String::new();
        let mut exception_type = None;
        let mut message_type = None;
        while let Some((&name_len, tail)) = rest.split_first() {
            let name_len = usize::from(name_len);
            let name = tail.get(..name_len)?;
            let (&value_type, tail) = tail[name_len..].split_first()?;
            let (value, tail) = Self::split_header_value(value_type, tail)?;
            rest = tail;
            if let Some(value) = value {
                // Only allocate for the headers we actually keep.
                let text = || String::from_utf8_lossy(value).into_owned();
                match name {
                    b":event-type" => event_type = text(),
                    b":exception-type" => exception_type = Some(text()),
                    b":message-type" => message_type = Some(text()),
                    _ => {}
                }
            }
        }
        Some((event_type, exception_type, message_type))
    }

    /// Splits one header value of type `value_type` off the front of `rest`.
    ///
    /// Returns the value bytes when the value is a string (type 7), together with the
    /// remaining input. Returns `None` for an unknown value type or if `rest` is too short.
    fn split_header_value(value_type: u8, rest: &[u8]) -> Option<(Option<&[u8]>, &[u8])> {
        let fixed_len = match value_type {
            0 | 1 => 0, // bool true / bool false: encoded in the type byte itself
            2 => 1,     // byte
            3 => 2,     // i16
            4 => 4,     // i32
            5 | 8 => 8, // i64, timestamp
            9 => 16,    // uuid
            // byte array / string: u16 length prefix followed by that many bytes.
            6 | 7 => {
                let prefix = rest.get(..2)?;
                let len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
                let body = rest.get(2..2 + len)?;
                let value = if value_type == 7 { Some(body) } else { None };
                return Some((value, &rest[2 + len..]));
            }
            _ => return None,
        };
        rest.get(fixed_len..).map(|tail| (None, tail))
    }
}
/// JSON body of a `chunk` event, e.g. `{"bytes":"aGVsbG8="}`.
#[derive(serde::Deserialize, Debug)]
pub(super) struct ChunkPayload {
    /// Chunk contents as carried in the JSON string; the test fixtures use base64.
    /// No decoding happens here — that is left to the caller.
    pub bytes: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends one string-valued (type 7) header in event-stream wire format.
    fn push_string_header(headers: &mut Vec<u8>, name: &str, value: &str) {
        headers.push(name.len() as u8);
        headers.extend_from_slice(name.as_bytes());
        headers.push(7u8);
        headers.extend_from_slice(&(value.len() as u16).to_be_bytes());
        headers.extend_from_slice(value.as_bytes());
    }

    /// Wraps an encoded header section and a payload in a prelude and trailer.
    /// Both CRC fields are zero-filled.
    fn wrap_frame(headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let total = 12 + headers.len() + payload.len() + 4;
        let mut frame = Vec::with_capacity(total);
        frame.extend_from_slice(&(total as u32).to_be_bytes());
        frame.extend_from_slice(&(headers.len() as u32).to_be_bytes());
        frame.extend_from_slice(&[0u8; 4]);
        frame.extend_from_slice(headers);
        frame.extend_from_slice(payload);
        frame.extend_from_slice(&[0u8; 4]);
        frame
    }

    /// Encodes only a 12-byte prelude with the given declared lengths.
    fn prelude(total: u32, headers: u32) -> Vec<u8> {
        let mut out = Vec::with_capacity(12);
        out.extend_from_slice(&total.to_be_bytes());
        out.extend_from_slice(&headers.to_be_bytes());
        out.extend_from_slice(&[0u8; 4]);
        out
    }

    fn build_frame(event_type: &str, payload: &[u8]) -> Vec<u8> {
        let mut headers = Vec::new();
        push_string_header(&mut headers, ":event-type", event_type);
        wrap_frame(&headers, payload)
    }

    fn build_exception_frame(exception_type: &str, payload: &[u8]) -> Vec<u8> {
        let mut headers = Vec::new();
        push_string_header(&mut headers, ":event-type", "error");
        push_string_header(&mut headers, ":message-type", "exception");
        push_string_header(&mut headers, ":exception-type", exception_type);
        wrap_frame(&headers, payload)
    }

    #[test]
    fn parses_one_chunk_frame() {
        let frame_bytes = build_frame("chunk", br#"{"bytes":"aGVsbG8="}"#);
        let mut dec = EventStreamDecoder::new();
        dec.feed(&frame_bytes);
        let f = dec.next_frame().unwrap().unwrap();
        assert_eq!(f.event_type, "chunk");
        assert_eq!(f.payload, br#"{"bytes":"aGVsbG8="}"#);
        assert!(dec.next_frame().unwrap().is_none());
    }

    #[test]
    fn parses_two_back_to_back_frames() {
        let mut buf = build_frame("chunk", br#"{"bytes":"YQ=="}"#);
        buf.extend_from_slice(&build_frame("chunk", br#"{"bytes":"Yg=="}"#));
        let mut dec = EventStreamDecoder::new();
        dec.feed(&buf);
        let a = dec.next_frame().unwrap().unwrap();
        let b = dec.next_frame().unwrap().unwrap();
        assert_eq!(a.event_type, "chunk");
        assert_eq!(a.payload, br#"{"bytes":"YQ=="}"#);
        assert_eq!(b.event_type, "chunk");
        assert_eq!(b.payload, br#"{"bytes":"Yg=="}"#);
        assert!(dec.next_frame().unwrap().is_none());
    }

    #[test]
    fn handles_partial_feed() {
        let frame = build_frame("chunk", br#"{"bytes":"YQ=="}"#);
        let mut dec = EventStreamDecoder::new();
        for chunk in frame.chunks(3) {
            // Until the final chunk is fed the frame is incomplete, so nothing may decode.
            assert!(dec.next_frame().unwrap().is_none());
            dec.feed(chunk);
        }
        let f = dec.next_frame().unwrap().unwrap();
        assert_eq!(f.event_type, "chunk");
        assert!(dec.next_frame().unwrap().is_none());
    }

    #[test]
    fn waits_for_complete_prelude() {
        let frame = build_frame("chunk", b"{}");
        let mut dec = EventStreamDecoder::new();
        dec.feed(&frame[..11]);
        assert!(dec.next_frame().unwrap().is_none());
    }

    #[test]
    fn surfaces_exception_type() {
        let frame = build_exception_frame("ThrottlingException", br#"{"message":"slow down"}"#);
        let mut dec = EventStreamDecoder::new();
        dec.feed(&frame);
        let f = dec.next_frame().unwrap().unwrap();
        assert_eq!(f.event_type, "error");
        assert_eq!(f.exception_type.as_deref(), Some("ThrottlingException"));
        assert_eq!(f.payload, br#"{"message":"slow down"}"#);
    }

    #[test]
    fn drops_exception_type_without_exception_message_type() {
        let mut headers = Vec::new();
        push_string_header(&mut headers, ":event-type", "error");
        push_string_header(&mut headers, ":exception-type", "ThrottlingException");
        let mut dec = EventStreamDecoder::new();
        dec.feed(&wrap_frame(&headers, b"{}"));
        let f = dec.next_frame().unwrap().unwrap();
        assert_eq!(f.event_type, "error");
        assert_eq!(f.exception_type, None);
    }

    #[test]
    fn ignores_unrelated_string_headers() {
        let mut headers = Vec::new();
        push_string_header(&mut headers, ":content-type", "application/json");
        push_string_header(&mut headers, ":event-type", "chunk");
        let mut dec = EventStreamDecoder::new();
        dec.feed(&wrap_frame(&headers, b"{}"));
        let f = dec.next_frame().unwrap().unwrap();
        assert_eq!(f.event_type, "chunk");
        assert_eq!(f.payload, b"{}");
    }

    #[test]
    fn rejects_overlarge_frame() {
        let mut dec = EventStreamDecoder::new();
        dec.feed(&prelude(10_000_000, 0));
        let err = dec.next_frame().unwrap_err();
        assert_eq!(err, FrameError::TooLarge(10_000_000));
    }

    #[test]
    fn rejects_frame_shorter_than_minimum() {
        let mut dec = EventStreamDecoder::new();
        dec.feed(&prelude(8, 0));
        let err = dec.next_frame().unwrap_err();
        assert_eq!(err, FrameError::Malformed { total: 8, headers: 0 });
    }

    #[test]
    fn rejects_headers_longer_than_frame() {
        let mut dec = EventStreamDecoder::new();
        dec.feed(&prelude(20, 10));
        let err = dec.next_frame().unwrap_err();
        assert_eq!(err, FrameError::Malformed { total: 20, headers: 10 });
    }

    #[test]
    fn parses_chunk_payload_json() {
        let f = build_frame("chunk", br#"{"bytes":"aGVsbG8="}"#);
        let mut dec = EventStreamDecoder::new();
        dec.feed(&f);
        let frame = dec.next_frame().unwrap().unwrap();
        let chunk: ChunkPayload = serde_json::from_slice(&frame.payload).unwrap();
        assert_eq!(chunk.bytes, "aGVsbG8=");
    }
}