1use crate::buffer::{BufferKind, BufferMut, ZeroCopyBuffer};
20
/// Errors produced while reassembling binary WebSocket messages in `WsBinaryAccumulator::feed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsBinaryError {
    /// A continuation frame (opcode `0x0`) arrived while no binary message was open.
    OrphanContinuation,
    /// Accepting the frame would make the message larger than the configured limit.
    MessageTooLarge {
        /// The configured maximum message size in bytes.
        limit: usize,
    },
}
31
32impl std::fmt::Display for WsBinaryError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::OrphanContinuation => {
36 write!(f, "continuation frame without opener")
37 }
38 Self::MessageTooLarge { limit } => {
39 write!(f, "binary message exceeded {limit} bytes")
40 }
41 }
42 }
43}
44
45impl std::error::Error for WsBinaryError {}
46
/// Size limits enforced by `WsBinaryAccumulator`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsBinaryLimits {
    /// Maximum size, in bytes, of a reassembled binary message. A frame that would push the
    /// message past this size is rejected with `WsBinaryError::MessageTooLarge`.
    pub max_message_bytes: usize,
}
52
53impl Default for WsBinaryLimits {
54 fn default() -> Self {
55 WsBinaryLimits {
59 max_message_bytes: 128 * 1024 * 1024,
60 }
61 }
62}
63
64pub struct WsBinaryAccumulator {
68 buffer: Option<BufferMut>,
69 kind: BufferKind,
70 limits: WsBinaryLimits,
71 tenant_id: Option<String>,
72}
73
74impl WsBinaryAccumulator {
75 pub fn new(kind: BufferKind, limits: WsBinaryLimits) -> Self {
76 WsBinaryAccumulator {
77 buffer: None,
78 kind,
79 limits,
80 tenant_id: None,
81 }
82 }
83
84 pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
85 self.tenant_id = Some(tenant_id.into());
86 self
87 }
88
89 pub fn feed(
96 &mut self,
97 opcode: u8,
98 is_final: bool,
99 payload: &[u8],
100 ) -> Result<Option<ZeroCopyBuffer>, WsBinaryError> {
101 match opcode {
102 0x2 => {
103 let mut body = BufferMut::with_capacity(
108 payload.len().max(4 * 1024),
109 self.kind.clone(),
110 );
111 if let Some(tenant) = &self.tenant_id {
112 body = body.with_tenant(tenant.as_str());
113 }
114 if payload.len() > self.limits.max_message_bytes {
115 return Err(WsBinaryError::MessageTooLarge {
116 limit: self.limits.max_message_bytes,
117 });
118 }
119 body.extend_from_slice(payload);
120 self.buffer = Some(body);
121 }
122 0x0 => {
123 let Some(buf) = self.buffer.as_mut() else {
125 return Err(WsBinaryError::OrphanContinuation);
126 };
127 if buf.len() + payload.len() > self.limits.max_message_bytes {
128 return Err(WsBinaryError::MessageTooLarge {
129 limit: self.limits.max_message_bytes,
130 });
131 }
132 buf.extend_from_slice(payload);
133 }
134 _ => {
135 return Ok(None);
138 }
139 }
140
141 if is_final {
142 let body = self.buffer.take().expect("buffer present in final frame");
143 Ok(Some(body.freeze()))
144 } else {
145 Ok(None)
146 }
147 }
148
149 pub fn reset(&mut self) {
150 self.buffer = None;
151 }
152
153 pub fn pending_bytes(&self) -> usize {
155 self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_frame_message() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::pcm16(), WsBinaryLimits::default());
        let result = acc.feed(0x2, true, b"hello").unwrap().expect("buffer");
        assert_eq!(result.as_slice(), b"hello");
        assert_eq!(result.kind().slug(), "pcm16");
    }

    #[test]
    fn fragmented_message_is_stitched() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default());
        assert!(acc.feed(0x2, false, b"he").unwrap().is_none());
        assert!(acc.feed(0x0, false, b"ll").unwrap().is_none());
        let end = acc.feed(0x0, true, b"o").unwrap().expect("buffer on FIN");
        assert_eq!(end.as_slice(), b"hello");
    }

    #[test]
    fn orphan_continuation_errors() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default());
        let err = acc.feed(0x0, true, b"x").unwrap_err();
        // `matches!` alone only produces a bool; it must be asserted to check anything.
        assert!(matches!(err, WsBinaryError::OrphanContinuation));
    }

    #[test]
    fn message_too_large_errors() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits {
                max_message_bytes: 4,
            },
        );
        let err = acc.feed(0x2, true, b"too-big").unwrap_err();
        assert!(matches!(err, WsBinaryError::MessageTooLarge { limit: 4 }));
    }

    #[test]
    fn partial_then_oversize_errors_on_second_frame() {
        let mut acc = WsBinaryAccumulator::new(
            BufferKind::raw(),
            WsBinaryLimits {
                max_message_bytes: 4,
            },
        );
        assert!(acc.feed(0x2, false, b"ok").unwrap().is_none());
        let err = acc.feed(0x0, true, b"xxxx").unwrap_err();
        assert!(matches!(err, WsBinaryError::MessageTooLarge { limit: 4 }));
    }

    #[test]
    fn tenant_tag_propagates_into_buffer() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default())
            .with_tenant("alpha");
        let out = acc.feed(0x2, true, b"payload").unwrap().expect("buffer");
        assert_eq!(out.tenant_id(), Some("alpha"));
    }

    #[test]
    fn control_opcode_is_ignored() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default());
        assert!(acc.feed(0x9, true, b"ping").unwrap().is_none());
        assert_eq!(acc.pending_bytes(), 0);
    }

    #[test]
    fn control_frame_between_fragments_is_ignored() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default());
        assert!(acc.feed(0x2, false, b"he").unwrap().is_none());
        assert!(acc.feed(0x9, true, b"ping").unwrap().is_none());
        assert_eq!(acc.pending_bytes(), 2);
        let end = acc.feed(0x0, true, b"llo").unwrap().expect("buffer on FIN");
        assert_eq!(end.as_slice(), b"hello");
    }

    #[test]
    fn reset_discards_pending_message() {
        let mut acc = WsBinaryAccumulator::new(BufferKind::raw(), WsBinaryLimits::default());
        assert!(acc.feed(0x2, false, b"ab").unwrap().is_none());
        assert_eq!(acc.pending_bytes(), 2);
        acc.reset();
        assert_eq!(acc.pending_bytes(), 0);
        let err = acc.feed(0x0, true, b"cd").unwrap_err();
        assert!(matches!(err, WsBinaryError::OrphanContinuation));
    }
}