use crate::buffer::{BufferKind, BufferMut, ZeroCopyBuffer};
/// Errors reported while reassembling a binary message from individual frames.
#[derive(Debug)]
pub enum WsBinaryError {
/// A continuation frame (opcode `0x0`) arrived while no message was being assembled.
OrphanContinuation,
/// Accepting the frame would push the message past the configured size limit.
MessageTooLarge {
/// The configured maximum message size, in bytes.
limit: usize,
},
}
/// Human-readable, single-line description of each error variant.
impl std::fmt::Display for WsBinaryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            // `limit` is `Copy`, so it can be bound by value through the dereference.
            Self::MessageTooLarge { limit } => {
                write!(f, "binary message exceeded {limit} bytes")
            }
            // Constant text needs no formatting machinery.
            Self::OrphanContinuation => f.write_str("continuation frame without opener"),
        }
    }
}
// No methods are overridden: neither variant wraps an underlying error, so the
// trait's default `source()` is sufficient.
impl std::error::Error for WsBinaryError {}
/// Size limits enforced while a binary message is being accumulated.
#[derive(Debug, Clone)]
pub struct WsBinaryLimits {
/// Maximum total payload size, in bytes, accepted for a single message
/// (summed across all of its frames).
pub max_message_bytes: usize,
}
impl Default for WsBinaryLimits {
fn default() -> Self {
WsBinaryLimits {
max_message_bytes: 128 * 1024 * 1024,
}
}
}
/// Reassembles binary messages from frames that are fed in one at a time.
pub struct WsBinaryAccumulator {
/// Payload gathered so far for the message in progress; `None` when idle.
buffer: Option<BufferMut>,
/// Buffer kind handed to every `BufferMut` this accumulator allocates.
kind: BufferKind,
/// Size limits checked on every data frame.
limits: WsBinaryLimits,
/// Optional tenant tag applied to each newly allocated buffer.
tenant_id: Option<String>,
}
impl WsBinaryAccumulator {
    /// Opcode of a continuation frame, which extends the message in progress.
    const OPCODE_CONTINUATION: u8 = 0x0;
    /// Opcode of a binary frame, which opens a new message.
    const OPCODE_BINARY: u8 = 0x2;
    /// Smallest capacity reserved when a new message buffer is allocated.
    const MIN_INITIAL_CAPACITY: usize = 4 * 1024;

    /// Creates an idle accumulator that allocates buffers of `kind` and
    /// enforces `limits`. No tenant tag is set.
    pub fn new(kind: BufferKind, limits: WsBinaryLimits) -> Self {
        WsBinaryAccumulator {
            buffer: None,
            kind,
            limits,
            tenant_id: None,
        }
    }

    /// Builder-style setter: tags every buffer allocated from now on with
    /// `tenant_id`.
    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Feeds one frame into the accumulator.
    ///
    /// * `opcode` `0x2` (binary) opens a new message, replacing any message
    ///   that was still in progress.
    /// * `opcode` `0x0` (continuation) appends `payload` to the message in
    ///   progress.
    /// * Any other opcode is ignored: the call returns `Ok(None)` and the
    ///   accumulator state is left untouched, whatever `is_final` says.
    ///
    /// Returns `Ok(Some(buffer))` with the frozen message when `is_final` is
    /// set on a binary or continuation frame, and `Ok(None)` while the
    /// message is still incomplete.
    ///
    /// # Errors
    ///
    /// * [`WsBinaryError::OrphanContinuation`] if a continuation frame arrives
    ///   while no message is in progress.
    /// * [`WsBinaryError::MessageTooLarge`] if accepting `payload` would push
    ///   the message past `max_message_bytes`. The partially assembled message
    ///   is discarded, so later continuation frames are rejected as orphans
    ///   instead of being spliced into a truncated message.
    pub fn feed(
        &mut self,
        opcode: u8,
        is_final: bool,
        payload: &[u8],
    ) -> Result<Option<ZeroCopyBuffer>, WsBinaryError> {
        let limit = self.limits.max_message_bytes;
        match opcode {
            Self::OPCODE_BINARY => {
                // A new opener always supersedes the previous message, even if
                // the opener itself is rejected below.
                self.buffer = None;
                // Check the limit before allocating so an oversize frame never
                // triggers a payload-sized allocation.
                if payload.len() > limit {
                    return Err(WsBinaryError::MessageTooLarge { limit });
                }
                let mut body = BufferMut::with_capacity(
                    payload.len().max(Self::MIN_INITIAL_CAPACITY),
                    self.kind.clone(),
                );
                if let Some(tenant) = &self.tenant_id {
                    body = body.with_tenant(tenant.as_str());
                }
                body.extend_from_slice(payload);
                self.buffer = Some(body);
            }
            Self::OPCODE_CONTINUATION => {
                let Some(buf) = self.buffer.as_mut() else {
                    return Err(WsBinaryError::OrphanContinuation);
                };
                // `checked_add` treats an overflowing total as "too large"
                // instead of wrapping (release) or panicking (debug).
                let fits = matches!(
                    buf.len().checked_add(payload.len()),
                    Some(total) if total <= limit
                );
                if fits {
                    buf.extend_from_slice(payload);
                } else {
                    // Drop the partial message: keeping it would let a later
                    // fragment complete a message that is missing this one.
                    self.buffer = None;
                    return Err(WsBinaryError::MessageTooLarge { limit });
                }
            }
            _ => return Ok(None),
        }

        if !is_final {
            return Ok(None);
        }
        // Both data-frame arms above leave `self.buffer` populated on success,
        // so this cannot fail unless that invariant is broken.
        let body = self.buffer.take().expect("buffer present in final frame");
        Ok(Some(body.freeze()))
    }

    /// Discards any partially assembled message and returns to the idle state.
    pub fn reset(&mut self) {
        self.buffer = None;
    }

    /// Number of payload bytes buffered for the message in progress, or `0`
    /// when idle.
    pub fn pending_bytes(&self) -> usize {
        self.buffer.as_ref().map_or(0, |b| b.len())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`WsBinaryAccumulator::feed`]: single-frame and
    //! fragmented messages, error paths, tenant tagging and ignored opcodes.
    use super::*;

    /// A lone final binary frame yields its payload with the configured kind.
    #[test]
    fn single_frame_message() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::pcm16(),
            WsBinaryLimits::default(),
        );
        let result = acc
            .feed(0x2, true, b"hello")
            .unwrap()
            .expect("buffer");
        assert_eq!(result.as_slice(), b"hello");
        assert_eq!(result.kind().slug(), "pcm16");
    }

    /// Opener plus continuations are concatenated in arrival order.
    #[test]
    fn fragmented_message_is_stitched() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits::default(),
        );
        assert!(acc.feed(0x2, false, b"he").unwrap().is_none());
        assert!(acc.feed(0x0, false, b"ll").unwrap().is_none());
        let end = acc
            .feed(0x0, true, b"o")
            .unwrap()
            .expect("buffer on FIN");
        assert_eq!(end.as_slice(), b"hello");
    }

    /// A continuation with no opener is rejected.
    #[test]
    fn orphan_continuation_errors() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits::default(),
        );
        let err = acc.feed(0x0, true, b"x").unwrap_err();
        // `matches!` only yields a bool; it must be wrapped in `assert!` to check anything.
        assert!(matches!(err, WsBinaryError::OrphanContinuation));
    }

    /// A single frame larger than the limit is rejected.
    #[test]
    fn message_too_large_errors() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits {
                max_message_bytes: 4,
            },
        );
        let err = acc.feed(0x2, true, b"too-big").unwrap_err();
        assert!(matches!(err, WsBinaryError::MessageTooLarge { limit: 4 }));
    }

    /// The limit applies to the running total, not to each frame alone.
    #[test]
    fn partial_then_oversize_errors_on_second_frame() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits {
                max_message_bytes: 4,
            },
        );
        assert!(acc.feed(0x2, false, b"ok").unwrap().is_none());
        let err = acc.feed(0x0, true, b"xxxx").unwrap_err();
        assert!(matches!(err, WsBinaryError::MessageTooLarge { limit: 4 }));
    }

    /// The tenant set via `with_tenant` is visible on the produced buffer.
    #[test]
    fn tenant_tag_propagates_into_buffer() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits::default(),
        )
        .with_tenant("alpha");
        let out = acc
            .feed(0x2, true, b"payload")
            .unwrap()
            .expect("buffer");
        assert_eq!(out.tenant_id(), Some("alpha"));
    }

    /// Opcodes other than binary and continuation produce no output.
    #[test]
    fn control_opcode_is_ignored() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits::default(),
        );
        assert!(acc.feed(0x9, true, b"ping").unwrap().is_none());
    }
}