zerodds_websocket_bridge/
message.rs1use alloc::vec::Vec;
11
12use crate::frame::{Frame, Opcode};
13use crate::utf8::{StreamingValidator, Utf8Error};
14
/// Reasons [`fragment_message`] can refuse to produce any frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SendError {
    /// The requested maximum frame payload size was `0`, so no payload
    /// byte could ever be placed into a frame.
    InvalidFrameLimit,
    /// The message was flagged as text but its payload failed UTF-8
    /// validation.
    InvalidUtf8,
}
27
28impl core::fmt::Display for SendError {
29 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30 match self {
31 Self::InvalidFrameLimit => write!(f, "InvalidFrameLimit"),
32 Self::InvalidUtf8 => write!(f, "InvalidUtf8"),
33 }
34 }
35}
36
// Only compiled with the `std` feature, because `std::error::Error` is not
// reachable otherwise. The impl is intentionally empty: every trait method
// keeps its default implementation.
#[cfg(feature = "std")]
impl std::error::Error for SendError {}
39
40pub fn fragment_message(
61 is_text: bool,
62 payload: &[u8],
63 max_frame_payload: usize,
64 mask: [u8; 4],
65) -> Result<Vec<Frame>, SendError> {
66 if max_frame_payload == 0 {
67 return Err(SendError::InvalidFrameLimit);
68 }
69 if is_text {
70 crate::utf8::validate(payload).map_err(|_| SendError::InvalidUtf8)?;
71 }
72
73 if payload.is_empty() {
74 return Ok(alloc::vec![Frame {
77 fin: true,
78 rsv1: false,
79 rsv2: false,
80 rsv3: false,
81 opcode: if is_text {
82 Opcode::Text
83 } else {
84 Opcode::Binary
85 },
86 masking_key: if mask == [0; 4] { None } else { Some(mask) },
87 payload: alloc::vec![],
88 }]);
89 }
90
91 let mut frames = Vec::new();
92 let mut offset = 0;
93 let mut first = true;
94
95 while offset < payload.len() {
96 let chunk_end = (offset + max_frame_payload).min(payload.len());
97 let chunk = &payload[offset..chunk_end];
98 let is_last = chunk_end == payload.len();
99
100 let opcode = if first {
101 if is_text {
102 Opcode::Text
103 } else {
104 Opcode::Binary
105 }
106 } else {
107 Opcode::Continuation
108 };
109
110 frames.push(Frame {
111 fin: is_last,
112 rsv1: false,
113 rsv2: false,
114 rsv3: false,
115 opcode,
116 masking_key: if mask == [0; 4] { None } else { Some(mask) },
117 payload: chunk.to_vec(),
118 });
119
120 offset = chunk_end;
121 first = false;
122 }
123
124 Ok(frames)
125}
126
127#[derive(Debug, Clone, PartialEq, Eq)]
133pub enum ReceiveError {
134 UnexpectedContinuation,
136 InterleavedDataFrame,
138 InvalidUtf8(Utf8Error),
140 MessageTooLarge,
142}
143
144impl core::fmt::Display for ReceiveError {
145 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
146 match self {
147 Self::UnexpectedContinuation => write!(f, "UnexpectedContinuation"),
148 Self::InterleavedDataFrame => write!(f, "InterleavedDataFrame"),
149 Self::InvalidUtf8(e) => write!(f, "InvalidUtf8({e})"),
150 Self::MessageTooLarge => write!(f, "MessageTooLarge"),
151 }
152 }
153}
154
// Only compiled with the `std` feature. The impl is empty, so every trait
// method keeps its default; in particular the `Utf8Error` wrapped by
// `InvalidUtf8` is not exposed through `source()`.
#[cfg(feature = "std")]
impl std::error::Error for ReceiveError {}
157
/// A payload handed back by [`Reassembler::feed`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// `true` when the message was opened by a text frame. `false` for
    /// binary messages and for the control frames that `feed` passes
    /// through.
    pub is_text: bool,
    /// The message bytes; for a fragmented message, the fragments'
    /// payloads concatenated in arrival order.
    pub payload: Vec<u8>,
}
166
167pub struct Reassembler {
169 pending: Option<PendingMessage>,
171 pub max_message_size: usize,
174}
175
176struct PendingMessage {
177 is_text: bool,
178 buffer: Vec<u8>,
179 utf8: StreamingValidator,
180}
181
182impl core::fmt::Debug for Reassembler {
183 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
184 f.debug_struct("Reassembler")
185 .field("pending", &self.pending.is_some())
186 .field("max_message_size", &self.max_message_size)
187 .finish()
188 }
189}
190
191impl Default for Reassembler {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl Reassembler {
198 #[must_use]
201 pub fn new() -> Self {
202 Self {
203 pending: None,
204 max_message_size: usize::MAX,
205 }
206 }
207
208 #[must_use]
210 pub fn with_limit(max_message_size: usize) -> Self {
211 Self {
212 pending: None,
213 max_message_size,
214 }
215 }
216
217 pub fn feed(&mut self, frame: &Frame) -> Result<Option<Message>, ReceiveError> {
228 match frame.opcode {
229 Opcode::Text | Opcode::Binary => {
230 if self.pending.is_some() {
231 return Err(ReceiveError::InterleavedDataFrame);
232 }
233 let is_text = frame.opcode == Opcode::Text;
234 if frame.fin {
235 if is_text {
237 crate::utf8::validate(&frame.payload).map_err(ReceiveError::InvalidUtf8)?;
238 }
239 if frame.payload.len() > self.max_message_size {
240 return Err(ReceiveError::MessageTooLarge);
241 }
242 Ok(Some(Message {
243 is_text,
244 payload: frame.payload.clone(),
245 }))
246 } else {
247 let mut utf8 = StreamingValidator::new();
248 if is_text {
249 utf8.feed(&frame.payload)
250 .map_err(ReceiveError::InvalidUtf8)?;
251 }
252 if frame.payload.len() > self.max_message_size {
253 return Err(ReceiveError::MessageTooLarge);
254 }
255 self.pending = Some(PendingMessage {
256 is_text,
257 buffer: frame.payload.clone(),
258 utf8,
259 });
260 Ok(None)
261 }
262 }
263 Opcode::Continuation => {
264 let mut p = self
265 .pending
266 .take()
267 .ok_or(ReceiveError::UnexpectedContinuation)?;
268 if p.is_text {
269 p.utf8
270 .feed(&frame.payload)
271 .map_err(ReceiveError::InvalidUtf8)?;
272 }
273 if p.buffer.len().saturating_add(frame.payload.len()) > self.max_message_size {
274 return Err(ReceiveError::MessageTooLarge);
275 }
276 p.buffer.extend_from_slice(&frame.payload);
277 if frame.fin {
278 if p.is_text {
279 p.utf8.finalize().map_err(ReceiveError::InvalidUtf8)?;
280 }
281 Ok(Some(Message {
282 is_text: p.is_text,
283 payload: p.buffer,
284 }))
285 } else {
286 self.pending = Some(p);
287 Ok(None)
288 }
289 }
290 Opcode::Close | Opcode::Ping | Opcode::Pong => Ok(Some(Message {
292 is_text: false,
293 payload: frame.payload.clone(),
294 })),
295 Opcode::Reserved(_) => Err(ReceiveError::UnexpectedContinuation),
299 }
300 }
301
302 #[must_use]
305 pub fn has_pending(&self) -> bool {
306 self.pending.is_some()
307 }
308}
309
#[cfg(test)]
// `expect` is deliberate in tests: a failed expectation should abort the
// test with its message.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds an unmasked frame of any opcode with all reserved bits
    /// cleared.
    fn make_frame(fin: bool, opcode: Opcode, payload: &[u8]) -> Frame {
        Frame {
            fin,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode,
            masking_key: None,
            payload: payload.to_vec(),
        }
    }

    // ---- fragment_message -------------------------------------------------

    #[test]
    fn fragment_empty_message_yields_one_frame() {
        let f = fragment_message(true, b"", 100, [0; 4]).expect("ok");
        assert_eq!(f.len(), 1);
        assert!(f[0].fin);
        assert_eq!(f[0].opcode, Opcode::Text);
        assert!(f[0].payload.is_empty());
    }

    #[test]
    fn fragment_message_within_limit_single_frame() {
        let f = fragment_message(false, b"hello", 100, [0; 4]).expect("ok");
        assert_eq!(f.len(), 1);
        assert!(f[0].fin);
        assert_eq!(f[0].opcode, Opcode::Binary);
    }

    #[test]
    fn fragment_message_splits_into_text_plus_continuations() {
        let f = fragment_message(true, b"abcdefghij", 3, [0; 4]).expect("ok");
        assert_eq!(f.len(), 4);
        assert_eq!(f[0].opcode, Opcode::Text);
        assert!(!f[0].fin);
        assert_eq!(f[1].opcode, Opcode::Continuation);
        assert_eq!(f[2].opcode, Opcode::Continuation);
        assert_eq!(f[3].opcode, Opcode::Continuation);
        assert!(f[3].fin);
    }

    #[test]
    fn fragment_message_exact_multiple_has_no_empty_tail() {
        // 6 bytes at 3 bytes per frame: exactly two full frames, the second
        // one final, and no trailing empty frame.
        let f = fragment_message(false, b"abcdef", 3, [0; 4]).expect("ok");
        assert_eq!(f.len(), 2);
        assert!(!f[0].fin);
        assert_eq!(f[0].payload, b"abc");
        assert!(f[1].fin);
        assert_eq!(f[1].opcode, Opcode::Continuation);
        assert_eq!(f[1].payload, b"def");
    }

    #[test]
    fn fragment_text_rejects_invalid_utf8() {
        let bad = [0xff, 0xfe];
        assert_eq!(
            fragment_message(true, &bad, 100, [0; 4]),
            Err(SendError::InvalidUtf8)
        );
    }

    #[test]
    fn fragment_zero_limit_rejected() {
        assert_eq!(
            fragment_message(false, b"x", 0, [0; 4]),
            Err(SendError::InvalidFrameLimit)
        );
    }

    #[test]
    fn fragment_with_mask_sets_mask_field() {
        let f = fragment_message(false, b"x", 100, [1, 2, 3, 4]).expect("ok");
        assert_eq!(f[0].masking_key, Some([1, 2, 3, 4]));
    }

    #[test]
    fn fragment_with_zero_mask_leaves_frames_unmasked() {
        let f = fragment_message(false, b"abcd", 2, [0; 4]).expect("ok");
        assert!(f.iter().all(|frame| frame.masking_key.is_none()));
    }

    // ---- Reassembler ------------------------------------------------------

    #[test]
    fn reassembler_single_frame_message_complete() {
        let mut r = Reassembler::new();
        let msg = r
            .feed(&make_frame(true, Opcode::Text, b"hello"))
            .expect("ok")
            .expect("complete");
        assert!(msg.is_text);
        assert_eq!(msg.payload, b"hello");
    }

    #[test]
    fn reassembler_continuation_sequence_reassembles() {
        let mut r = Reassembler::new();
        let p1 = r
            .feed(&make_frame(false, Opcode::Text, b"hel"))
            .expect("ok");
        assert!(p1.is_none());
        let p2 = r
            .feed(&make_frame(false, Opcode::Continuation, b"lo "))
            .expect("ok");
        assert!(p2.is_none());
        let msg = r
            .feed(&make_frame(true, Opcode::Continuation, b"world"))
            .expect("ok")
            .expect("complete");
        assert_eq!(msg.payload, b"hello world");
    }

    #[test]
    fn reassembler_continuation_without_preceding_text_rejected() {
        let mut r = Reassembler::new();
        assert_eq!(
            r.feed(&make_frame(true, Opcode::Continuation, b"x")),
            Err(ReceiveError::UnexpectedContinuation)
        );
    }

    #[test]
    fn reassembler_interleaved_text_during_pending_rejected() {
        let mut r = Reassembler::new();
        let _ = r
            .feed(&make_frame(false, Opcode::Text, b"hel"))
            .expect("ok");
        assert_eq!(
            r.feed(&make_frame(false, Opcode::Text, b"new")),
            Err(ReceiveError::InterleavedDataFrame)
        );
    }

    #[test]
    fn reassembler_rejects_invalid_utf8_in_text() {
        let mut r = Reassembler::new();
        let result = r.feed(&make_frame(true, Opcode::Text, &[0xffu8]));
        assert!(matches!(result, Err(ReceiveError::InvalidUtf8(_))));
    }

    #[test]
    fn reassembler_rejects_message_above_limit() {
        let mut r = Reassembler::with_limit(5);
        let result = r.feed(&make_frame(true, Opcode::Binary, &[0u8; 10]));
        assert_eq!(result, Err(ReceiveError::MessageTooLarge));
    }

    #[test]
    fn reassembler_accepts_message_exactly_at_limit() {
        let mut r = Reassembler::with_limit(5);
        let msg = r
            .feed(&make_frame(true, Opcode::Binary, &[0u8; 5]))
            .expect("ok")
            .expect("complete");
        assert_eq!(msg.payload.len(), 5);
    }

    #[test]
    fn reassembler_rejects_fragmented_message_above_limit() {
        // Each fragment fits on its own; only their sum exceeds the limit.
        let mut r = Reassembler::with_limit(5);
        let first = r
            .feed(&make_frame(false, Opcode::Binary, &[0u8; 3]))
            .expect("ok");
        assert!(first.is_none());
        assert_eq!(
            r.feed(&make_frame(true, Opcode::Continuation, &[0u8; 3])),
            Err(ReceiveError::MessageTooLarge)
        );
        // The partial message is dropped together with the rejected frame.
        assert!(!r.has_pending());
    }

    #[test]
    fn reassembler_passes_through_control_frames() {
        let mut r = Reassembler::new();
        let msg = r
            .feed(&make_frame(true, Opcode::Ping, b"abc"))
            .expect("ok")
            .expect("ping");
        assert!(!msg.is_text);
        assert_eq!(msg.payload, b"abc");
    }

    #[test]
    fn reassembler_control_frame_between_fragments_keeps_pending() {
        let mut r = Reassembler::new();
        let _ = r
            .feed(&make_frame(false, Opcode::Text, b"hel"))
            .expect("ok");
        let ping = r
            .feed(&make_frame(true, Opcode::Ping, b"p"))
            .expect("ok")
            .expect("ping");
        assert_eq!(ping.payload, b"p");
        assert!(r.has_pending());
        let msg = r
            .feed(&make_frame(true, Opcode::Continuation, b"lo"))
            .expect("ok")
            .expect("complete");
        assert!(msg.is_text);
        assert_eq!(msg.payload, b"hello");
    }

    #[test]
    fn reassembler_has_pending_during_continuation() {
        let mut r = Reassembler::new();
        let _ = r
            .feed(&make_frame(false, Opcode::Binary, b"x"))
            .expect("ok");
        assert!(r.has_pending());
    }

    // ---- round trip -------------------------------------------------------

    #[test]
    fn fragment_send_then_reassemble_round_trip() {
        let original = b"the quick brown fox jumps";
        let frames = fragment_message(true, original, 4, [0; 4]).expect("ok");
        let mut r = Reassembler::new();
        let mut completed: Option<Message> = None;
        for f in &frames {
            if let Some(m) = r.feed(f).expect("ok") {
                completed = Some(m);
            }
        }
        let msg = completed.expect("completed");
        assert!(msg.is_text);
        assert_eq!(msg.payload, original);
        assert!(!r.has_pending());
    }
}