1use rand_chacha::ChaCha8Rng;
2use rand_core::{RngCore, SeedableRng};
3
4use super::frame::Role;
5
/// Errors that can occur while encoding an outgoing frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncodeError {
    /// A control frame (ping, pong or close) would carry more than 125 payload
    /// bytes; holds the offending payload length.
    ControlPayloadTooLarge(usize),
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            EncodeError::ControlPayloadTooLarge(len) => write!(
                f,
                "control frame payload too large: {} bytes (max 125)",
                len
            ),
        }
    }
}

impl std::error::Error for EncodeError {}
24
/// An encoded frame header held inline.
///
/// The backing array has room for the largest possible header: 2 fixed
/// bytes, up to 8 bytes of extended payload length and a 4-byte masking key.
/// Only the first `len` bytes are meaningful.
#[derive(Debug, Clone, Copy)]
pub struct FrameHeader {
    bytes: [u8; 14],
    len: u8,
}

impl FrameHeader {
    /// Returns the meaningful header bytes, ready to be written before the payload.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes[..usize::from(self.len)]
    }

    /// Number of meaningful header bytes.
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// Returns `true` when the header holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
47
48pub struct FrameWriter {
64 role: Role,
65 mask_rng: Option<ChaCha8Rng>,
69}
70
71impl FrameWriter {
72 #[must_use]
74 pub fn new(role: Role) -> Self {
75 Self {
76 role,
77 mask_rng: None,
78 }
79 }
80
81 pub fn encode_text(&mut self, payload: &[u8], dst: &mut [u8]) -> usize {
86 self.encode(0x81, payload, dst) }
88
89 pub fn encode_binary(&mut self, payload: &[u8], dst: &mut [u8]) -> usize {
91 self.encode(0x82, payload, dst) }
93
94 pub fn encode_ping(&mut self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
98 if payload.len() > 125 {
99 return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
100 }
101 Ok(self.encode(0x89, payload, dst)) }
103
104 pub fn encode_pong(&mut self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
108 if payload.len() > 125 {
109 return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
110 }
111 Ok(self.encode(0x8A, payload, dst)) }
113
114 pub fn encode_close(
118 &mut self,
119 code: u16,
120 reason: &[u8],
121 dst: &mut [u8],
122 ) -> Result<usize, EncodeError> {
123 let payload_len = 2 + reason.len();
124 if payload_len > 125 {
125 return Err(EncodeError::ControlPayloadTooLarge(payload_len));
126 }
127
128 let mut close_payload = [0u8; 125];
129 close_payload[..2].copy_from_slice(&code.to_be_bytes());
130 close_payload[2..payload_len].copy_from_slice(reason);
131
132 Ok(self.encode(0x88, &close_payload[..payload_len], dst))
133 }
134
135 #[must_use]
138 pub fn max_encoded_len(&self, payload_len: usize) -> usize {
139 let header = if payload_len <= 125 {
140 2
141 } else if payload_len <= 65535 {
142 4
143 } else {
144 10
145 };
146 let mask = if self.role == Role::Client { 4 } else { 0 };
147 header + mask + payload_len
148 }
149
150 pub fn encode_empty_close(&mut self, dst: &mut [u8]) -> usize {
155 self.encode(0x88, &[], dst) }
157
158 pub fn encode_close_code(
165 &mut self,
166 code: super::message::CloseCode,
167 reason: &str,
168 dst: &mut [u8],
169 ) -> Result<usize, EncodeError> {
170 assert!(
171 code != super::message::CloseCode::NoStatus,
172 "CloseCode::NoStatus cannot be sent on the wire — use encode_empty_close()"
173 );
174 self.encode_close(code.as_u16(), reason.as_bytes(), dst)
175 }
176
177 pub fn build_header(
181 &mut self,
182 byte0: u8,
183 payload_len: usize,
184 ) -> (FrameHeader, Option<[u8; 4]>) {
185 let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
186 let mut hdr = FrameHeader {
187 bytes: [0; 14],
188 len: 0,
189 };
190
191 hdr.bytes[0] = byte0;
192 hdr.len = 1;
193
194 if payload_len <= 125 {
195 hdr.bytes[1] = mask_bit | (payload_len as u8);
196 hdr.len = 2;
197 } else if payload_len <= 65535 {
198 hdr.bytes[1] = mask_bit | 0x7E;
199 hdr.bytes[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
200 hdr.len = 4;
201 } else {
202 hdr.bytes[1] = mask_bit | 0x7F;
203 hdr.bytes[2..10].copy_from_slice(&(payload_len as u64).to_be_bytes());
204 hdr.len = 10;
205 }
206
207 let mask_key = if self.role == Role::Client {
208 let mask = self.generate_mask();
209 hdr.bytes[hdr.len as usize..hdr.len as usize + 4].copy_from_slice(&mask);
210 hdr.len += 4;
211 Some(mask)
212 } else {
213 None
214 };
215
216 (hdr, mask_key)
217 }
218
219 pub fn encode_text_into(&mut self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
224 self.encode_into(0x81, payload, dst);
225 }
226
227 pub fn encode_binary_into(&mut self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
229 self.encode_into(0x82, payload, dst);
230 }
231
232 pub fn encode_ping_into(
234 &mut self,
235 payload: &[u8],
236 dst: &mut crate::buf::WriteBuf,
237 ) -> Result<(), EncodeError> {
238 if payload.len() > 125 {
239 return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
240 }
241 self.encode_into(0x89, payload, dst);
242 Ok(())
243 }
244
245 pub fn encode_pong_into(
247 &mut self,
248 payload: &[u8],
249 dst: &mut crate::buf::WriteBuf,
250 ) -> Result<(), EncodeError> {
251 if payload.len() > 125 {
252 return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
253 }
254 self.encode_into(0x8A, payload, dst);
255 Ok(())
256 }
257
258 pub fn encode_close_into(
260 &mut self,
261 code: u16,
262 reason: &[u8],
263 dst: &mut crate::buf::WriteBuf,
264 ) -> Result<(), EncodeError> {
265 let payload_len = 2 + reason.len();
266 if payload_len > 125 {
267 return Err(EncodeError::ControlPayloadTooLarge(payload_len));
268 }
269 dst.clear();
270 dst.append(&code.to_be_bytes());
271 dst.append(reason);
272 let (hdr, mask_key) = self.build_header(0x88, payload_len);
273 if let Some(mask) = mask_key {
274 super::mask::apply_mask(dst.data_mut(), mask);
275 }
276 dst.prepend(hdr.as_bytes());
277 Ok(())
278 }
279
280 pub fn encode_text_writer<F, E>(
293 &mut self,
294 dst: &mut crate::buf::WriteBuf,
295 f: F,
296 ) -> Result<(), E>
297 where
298 F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
299 {
300 self.encode_writer_into(0x81, dst, f)
301 }
302
303 pub fn encode_binary_writer<F, E>(
305 &mut self,
306 dst: &mut crate::buf::WriteBuf,
307 f: F,
308 ) -> Result<(), E>
309 where
310 F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
311 {
312 self.encode_writer_into(0x82, dst, f)
313 }
314
315 pub fn encode_text_fixed(
319 &mut self,
320 dst: &mut crate::buf::WriteBuf,
321 len: usize,
322 f: impl FnOnce(&mut [u8]),
323 ) {
324 self.encode_fixed_into(0x81, dst, len, f);
325 }
326
327 pub fn encode_binary_fixed(
329 &mut self,
330 dst: &mut crate::buf::WriteBuf,
331 len: usize,
332 f: impl FnOnce(&mut [u8]),
333 ) {
334 self.encode_fixed_into(0x82, dst, len, f);
335 }
336
337 fn encode_into(&mut self, byte0: u8, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
338 dst.clear();
339 dst.append(payload);
340 let (hdr, mask_key) = self.build_header(byte0, payload.len());
341 if let Some(mask) = mask_key {
342 super::mask::apply_mask(dst.data_mut(), mask);
343 }
344 dst.prepend(hdr.as_bytes());
345 }
346
347 fn encode_writer_into<F, E>(
348 &mut self,
349 byte0: u8,
350 dst: &mut crate::buf::WriteBuf,
351 f: F,
352 ) -> Result<(), E>
353 where
354 F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
355 {
356 dst.clear();
357 let payload_len = {
358 let mut bw = crate::buf::WriteBufWriter::new(dst);
359 f(&mut bw)?;
360 bw.written()
361 };
362 let (hdr, mask_key) = self.build_header(byte0, payload_len);
363 if let Some(mask) = mask_key {
364 super::mask::apply_mask(dst.data_mut(), mask);
365 }
366 dst.prepend(hdr.as_bytes());
367 Ok(())
368 }
369
370 fn encode_fixed_into(
371 &mut self,
372 byte0: u8,
373 dst: &mut crate::buf::WriteBuf,
374 len: usize,
375 f: impl FnOnce(&mut [u8]),
376 ) {
377 dst.clear();
378 dst.extend_zeroed(len);
379 f(dst.data_mut());
380 let (hdr, mask_key) = self.build_header(byte0, len);
381 if let Some(mask) = mask_key {
382 super::mask::apply_mask(dst.data_mut(), mask);
383 }
384 dst.prepend(hdr.as_bytes());
385 }
386
387 fn generate_mask(&mut self) -> [u8; 4] {
397 let rng = self.mask_rng.get_or_insert_with(|| {
398 let mut seed = [0u8; 32];
399 getrandom::fill(&mut seed).expect("OS randomness unavailable");
400 ChaCha8Rng::from_seed(seed)
401 });
402 let mut mask = [0u8; 4];
403 rng.fill_bytes(&mut mask);
404 mask
405 }
406
407 fn encode(&mut self, byte0: u8, payload: &[u8], dst: &mut [u8]) -> usize {
408 let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
409 let payload_len = payload.len();
410
411 let mut offset = 0;
412
413 dst[offset] = byte0;
415 offset += 1;
416
417 if payload_len <= 125 {
419 dst[offset] = mask_bit | (payload_len as u8);
420 offset += 1;
421 } else if payload_len <= 65535 {
422 dst[offset] = mask_bit | 0x7E;
423 offset += 1;
424 dst[offset..offset + 2].copy_from_slice(&(payload_len as u16).to_be_bytes());
425 offset += 2;
426 } else {
427 dst[offset] = mask_bit | 0x7F;
428 offset += 1;
429 dst[offset..offset + 8].copy_from_slice(&(payload_len as u64).to_be_bytes());
430 offset += 8;
431 }
432
433 if self.role == Role::Client {
435 let mask = self.generate_mask();
436 dst[offset..offset + 4].copy_from_slice(&mask);
437 offset += 4;
438
439 dst[offset..offset + payload_len].copy_from_slice(payload);
441 super::mask::apply_mask(&mut dst[offset..offset + payload_len], mask);
442 } else {
443 dst[offset..offset + payload_len].copy_from_slice(payload);
444 }
445
446 offset + payload_len
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_text_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        assert_eq!(n, 7);
        assert_eq!(dst[0], 0x81);
        assert_eq!(dst[1], 0x05);
        assert_eq!(&dst[2..7], b"Hello");
    }

    #[test]
    fn encode_binary_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_binary(&[0xDE, 0xAD, 0xBE, 0xEF], &mut dst);
        assert_eq!(n, 6);
        assert_eq!(dst[0], 0x82);
        assert_eq!(&dst[2..6], &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn encode_close_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(9)];
        let n = writer.encode_close(1000, b"goodbye", &mut dst).unwrap();
        assert_eq!(dst[0], 0x88);
        assert_eq!(&dst[2..4], &1000u16.to_be_bytes());
        assert_eq!(&dst[4..n], b"goodbye");
    }

    #[test]
    fn encode_ping_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_ping(b"ping", &mut dst).unwrap();
        assert_eq!(dst[0], 0x89);
        assert_eq!(&dst[2..n], b"ping");
    }

    #[test]
    fn encode_pong_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_pong(b"pong", &mut dst).unwrap();
        assert_eq!(dst[0], 0x8A);
        assert_eq!(&dst[2..n], b"pong");
    }

    #[test]
    fn encode_client_is_masked() {
        let mut writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        assert_eq!(n, 11);
        assert_eq!(dst[0], 0x81);
        assert_eq!(dst[1] & 0x80, 0x80);
        assert_eq!(dst[1] & 0x7F, 5);
        // Unmask with the key carried in the header rather than asserting the
        // ciphertext differs from the plaintext: an all-zero random key would
        // make that check fail spuriously.
        let key = [dst[2], dst[3], dst[4], dst[5]];
        let unmasked: Vec<u8> = dst[6..11]
            .iter()
            .enumerate()
            .map(|(i, b)| b ^ key[i % 4])
            .collect();
        assert_eq!(unmasked, b"Hello");
    }

    #[test]
    fn encode_16bit_length() {
        let mut writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42u8; 256];
        let mut dst = vec![0u8; writer.max_encoded_len(256)];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 4 + 256);
        assert_eq!(dst[1] & 0x7F, 126);
        let len = u16::from_be_bytes([dst[2], dst[3]]);
        assert_eq!(len, 256);
    }

    #[test]
    fn encode_64bit_length() {
        let mut writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42u8; 65_536];
        let mut dst = vec![0u8; writer.max_encoded_len(payload.len())];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 10 + 65_536);
        assert_eq!(dst[1] & 0x7F, 127);
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&dst[2..10]);
        assert_eq!(u64::from_be_bytes(len_bytes), 65_536);
        assert_eq!(&dst[10..n], &payload[..]);
    }

    #[test]
    fn max_encoded_len_small() {
        let server = FrameWriter::new(Role::Server);
        assert_eq!(server.max_encoded_len(0), 2);
        assert_eq!(server.max_encoded_len(125), 2 + 125);
        assert_eq!(server.max_encoded_len(126), 4 + 126);

        let client = FrameWriter::new(Role::Client);
        assert_eq!(client.max_encoded_len(0), 2 + 4);
        assert_eq!(client.max_encoded_len(125), 2 + 4 + 125);
    }

    #[test]
    fn round_trip_server() {
        use crate::ws::{FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn round_trip_client() {
        use crate::ws::{FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);

        let mut reader = FrameReader::builder().role(Role::Server).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn encode_close_code_round_trip() {
        use crate::ws::{CloseCode, FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 64];
        let n = writer
            .encode_close_code(CloseCode::Normal, "goodbye", &mut dst)
            .unwrap();

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        match reader.next().unwrap().unwrap() {
            Message::Close(cf) => {
                assert_eq!(cf.code, CloseCode::Normal);
                assert_eq!(cf.reason, "goodbye");
            }
            other => panic!("expected Close, got {other:?}"),
        }
    }

    #[test]
    #[should_panic(expected = "NoStatus")]
    fn close_code_no_status_panics() {
        use crate::ws::CloseCode;
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 64];
        let _ = writer.encode_close_code(CloseCode::NoStatus, "", &mut dst);
    }

    #[test]
    fn ping_too_large_returns_err() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert!(matches!(
            writer.encode_ping(&[0; 126], &mut dst),
            Err(super::EncodeError::ControlPayloadTooLarge(126))
        ));
    }

    #[test]
    fn ping_at_limit_is_accepted() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(125)];
        let n = writer.encode_ping(&[0xABu8; 125], &mut dst).unwrap();
        assert_eq!(n, 2 + 125);
        assert_eq!(dst[1], 125);
    }

    #[test]
    fn close_reason_too_large_returns_err() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert!(matches!(
            writer.encode_close(1000, &[b'x'; 124], &mut dst),
            Err(EncodeError::ControlPayloadTooLarge(126))
        ));
    }

    #[test]
    fn encode_text_writer_matches_into() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let payload = b"Hello, world!";

        let mut wbuf1 = WriteBuf::new(128, 14);
        writer.encode_text_into(payload, &mut wbuf1);

        let mut wbuf2 = WriteBuf::new(128, 14);
        writer
            .encode_text_writer(&mut wbuf2, |w| {
                use std::io::Write;
                w.write_all(payload)
            })
            .unwrap();

        assert_eq!(wbuf1.data(), wbuf2.data());
    }

    #[test]
    fn encode_binary_fixed_matches_into() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let payload = [0xDE, 0xAD, 0xBE, 0xEF];

        let mut wbuf1 = WriteBuf::new(128, 14);
        writer.encode_binary_into(&payload, &mut wbuf1);

        let mut wbuf2 = WriteBuf::new(128, 14);
        writer.encode_binary_fixed(&mut wbuf2, payload.len(), |buf| {
            buf.copy_from_slice(&payload);
        });

        assert_eq!(wbuf1.data(), wbuf2.data());
    }

    #[test]
    fn encode_text_writer_round_trip() {
        use crate::buf::WriteBuf;
        use crate::ws::{FrameReader, Message};

        let mut writer = FrameWriter::new(Role::Server);
        let mut wbuf = WriteBuf::new(128, 14);
        writer
            .encode_text_writer(&mut wbuf, |w| {
                use std::io::Write;
                w.write_all(b"test message")
            })
            .unwrap();

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(wbuf.data()).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("test message")
        ));
    }
}