Skip to main content

nexus_net/ws/
frame_writer.rs

1use rand_chacha::ChaCha8Rng;
2use rand_core::{RngCore, SeedableRng};
3
4use super::frame::Role;
5
/// Failure reported while encoding a WebSocket frame.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EncodeError {
    /// A control frame (ping, pong or close) was handed more than 125 payload
    /// bytes, which RFC 6455 §5.5 forbids. Carries the rejected length.
    ControlPayloadTooLarge(usize),
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the pattern is irrefutable.
        let Self::ControlPayloadTooLarge(n) = self;
        write!(f, "control frame payload too large: {n} bytes (max 125)")
    }
}

/// Marker impl so the error composes with `?`, `Box<dyn Error>`, etc.
impl std::error::Error for EncodeError {}
24
/// Encoded WebSocket frame header, stored inline (no heap allocation).
///
/// Holds up to 14 bytes: 2 fixed bytes, up to 8 extended-length bytes and,
/// for client frames, a 4-byte masking key. Only the first
/// [`len`](Self::len) bytes are meaningful; read them via
/// [`as_bytes`](Self::as_bytes).
#[derive(Clone, Copy)]
pub struct FrameHeader {
    /// Backing storage; bytes at index `len` and beyond are unused.
    bytes: [u8; 14],
    /// Number of valid bytes at the front of `bytes`.
    len: u8,
}

impl FrameHeader {
    /// The valid header bytes, ready to be written ahead of the payload.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes[..usize::from(self.len)]
    }

    /// Header length in bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// Whether the header is empty (shouldn't happen in practice: every
    /// header built by `FrameWriter::build_header` is at least 2 bytes).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

/// Shows only the valid header bytes, not the unused tail of the buffer.
impl std::fmt::Debug for FrameHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FrameHeader")
            .field("bytes", &self.as_bytes())
            .finish()
    }
}
47
48/// WebSocket frame encoder.
49///
50/// Encodes messages into RFC 6455 wire format. If the role is Client,
51/// frames are masked with a random 4-byte key. If Server, no masking.
52///
53/// # Usage
54///
55/// ```
56/// use nexus_net::ws::{FrameWriter, Role};
57///
58/// let mut writer = FrameWriter::new(Role::Server);
59/// let mut dst = vec![0u8; writer.max_encoded_len(5)];
60/// let n = writer.encode_text(b"Hello", &mut dst);
61/// assert_eq!(&dst[..n], &[0x81, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F]);
62/// ```
63pub struct FrameWriter {
64    role: Role,
65    /// PRNG for mask key generation (client only). Seeded lazily from
66    /// OS randomness on first use, then produces mask keys at ~1 cycle
67    /// instead of ~50-200 cycles per getrandom syscall.
68    mask_rng: Option<ChaCha8Rng>,
69}
70
71impl FrameWriter {
72    /// Create a writer for the given role.
73    #[must_use]
74    pub fn new(role: Role) -> Self {
75        Self {
76            role,
77            mask_rng: None,
78        }
79    }
80
81    /// Encode a text message frame. Returns bytes written.
82    ///
83    /// # Panics
84    /// Panics if `dst` is too small. Use [`max_encoded_len`](Self::max_encoded_len).
85    pub fn encode_text(&mut self, payload: &[u8], dst: &mut [u8]) -> usize {
86        self.encode(0x81, payload, dst) // FIN + Text
87    }
88
89    /// Encode a binary message frame. Returns bytes written.
90    pub fn encode_binary(&mut self, payload: &[u8], dst: &mut [u8]) -> usize {
91        self.encode(0x82, payload, dst) // FIN + Binary
92    }
93
94    /// Encode a ping control frame. Returns bytes written.
95    ///
96    /// Returns `Err` if payload exceeds 125 bytes (RFC 6455 §5.5).
97    pub fn encode_ping(&mut self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
98        if payload.len() > 125 {
99            return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
100        }
101        Ok(self.encode(0x89, payload, dst)) // FIN + Ping
102    }
103
104    /// Encode a pong control frame. Returns bytes written.
105    ///
106    /// Returns `Err` if payload exceeds 125 bytes (RFC 6455 §5.5).
107    pub fn encode_pong(&mut self, payload: &[u8], dst: &mut [u8]) -> Result<usize, EncodeError> {
108        if payload.len() > 125 {
109            return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
110        }
111        Ok(self.encode(0x8A, payload, dst)) // FIN + Pong
112    }
113
114    /// Encode a close frame. Returns bytes written.
115    ///
116    /// Returns `Err` if code + reason exceeds 125 bytes.
117    pub fn encode_close(
118        &mut self,
119        code: u16,
120        reason: &[u8],
121        dst: &mut [u8],
122    ) -> Result<usize, EncodeError> {
123        let payload_len = 2 + reason.len();
124        if payload_len > 125 {
125            return Err(EncodeError::ControlPayloadTooLarge(payload_len));
126        }
127
128        let mut close_payload = [0u8; 125];
129        close_payload[..2].copy_from_slice(&code.to_be_bytes());
130        close_payload[2..payload_len].copy_from_slice(reason);
131
132        Ok(self.encode(0x88, &close_payload[..payload_len], dst))
133    }
134
135    /// Maximum encoded size for a given payload length.
136    /// Accounts for header (2-10 bytes) + optional mask (4 bytes).
137    #[must_use]
138    pub fn max_encoded_len(&self, payload_len: usize) -> usize {
139        let header = if payload_len <= 125 {
140            2
141        } else if payload_len <= 65535 {
142            4
143        } else {
144            10
145        };
146        let mask = if self.role == Role::Client { 4 } else { 0 };
147        header + mask + payload_len
148    }
149
150    /// Encode an empty close frame (no status code on the wire).
151    ///
152    /// Used when `CloseCode::NoStatus` is intended — RFC 6455 §7.4.1
153    /// reserves code 1005 from appearing in close frame payloads.
154    pub fn encode_empty_close(&mut self, dst: &mut [u8]) -> usize {
155        self.encode(0x88, &[], dst) // FIN + Close, zero payload
156    }
157
158    /// Encode a close frame with structured [`CloseCode`](super::CloseCode) and UTF-8 reason.
159    ///
160    /// # Panics
161    /// Panics if `code` is `CloseCode::NoStatus` (RFC 6455 reserves 1005
162    /// from appearing on the wire — use [`encode_empty_close`](Self::encode_empty_close)).
163    /// Panics if 2 + reason.len() exceeds 125 bytes.
164    pub fn encode_close_code(
165        &mut self,
166        code: super::message::CloseCode,
167        reason: &str,
168        dst: &mut [u8],
169    ) -> Result<usize, EncodeError> {
170        assert!(
171            code != super::message::CloseCode::NoStatus,
172            "CloseCode::NoStatus cannot be sent on the wire — use encode_empty_close()"
173        );
174        self.encode_close(code.as_u16(), reason.as_bytes(), dst)
175    }
176
177    /// Build just the frame header. Returns (header_bytes, length, optional mask_key).
178    ///
179    /// For use with WriteBuf: append payload, apply mask if Some, prepend header.
180    pub fn build_header(
181        &mut self,
182        byte0: u8,
183        payload_len: usize,
184    ) -> (FrameHeader, Option<[u8; 4]>) {
185        let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
186        let mut hdr = FrameHeader {
187            bytes: [0; 14],
188            len: 0,
189        };
190
191        hdr.bytes[0] = byte0;
192        hdr.len = 1;
193
194        if payload_len <= 125 {
195            hdr.bytes[1] = mask_bit | (payload_len as u8);
196            hdr.len = 2;
197        } else if payload_len <= 65535 {
198            hdr.bytes[1] = mask_bit | 0x7E;
199            hdr.bytes[2..4].copy_from_slice(&(payload_len as u16).to_be_bytes());
200            hdr.len = 4;
201        } else {
202            hdr.bytes[1] = mask_bit | 0x7F;
203            hdr.bytes[2..10].copy_from_slice(&(payload_len as u64).to_be_bytes());
204            hdr.len = 10;
205        }
206
207        let mask_key = if self.role == Role::Client {
208            let mask = self.generate_mask();
209            hdr.bytes[hdr.len as usize..hdr.len as usize + 4].copy_from_slice(&mask);
210            hdr.len += 4;
211            Some(mask)
212        } else {
213            None
214        };
215
216        (hdr, mask_key)
217    }
218
219    /// Encode a complete frame into a WriteBuf.
220    ///
221    /// Clears the WriteBuf, appends payload, applies mask if client,
222    /// prepends header. Result: contiguous `[header | masked_payload]`.
223    pub fn encode_text_into(&mut self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
224        self.encode_into(0x81, payload, dst);
225    }
226
227    /// Encode a binary frame into a WriteBuf.
228    pub fn encode_binary_into(&mut self, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
229        self.encode_into(0x82, payload, dst);
230    }
231
232    /// Encode a ping frame into a WriteBuf.
233    pub fn encode_ping_into(
234        &mut self,
235        payload: &[u8],
236        dst: &mut crate::buf::WriteBuf,
237    ) -> Result<(), EncodeError> {
238        if payload.len() > 125 {
239            return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
240        }
241        self.encode_into(0x89, payload, dst);
242        Ok(())
243    }
244
245    /// Encode a pong frame into a WriteBuf.
246    pub fn encode_pong_into(
247        &mut self,
248        payload: &[u8],
249        dst: &mut crate::buf::WriteBuf,
250    ) -> Result<(), EncodeError> {
251        if payload.len() > 125 {
252            return Err(EncodeError::ControlPayloadTooLarge(payload.len()));
253        }
254        self.encode_into(0x8A, payload, dst);
255        Ok(())
256    }
257
258    /// Encode a close frame into a WriteBuf.
259    pub fn encode_close_into(
260        &mut self,
261        code: u16,
262        reason: &[u8],
263        dst: &mut crate::buf::WriteBuf,
264    ) -> Result<(), EncodeError> {
265        let payload_len = 2 + reason.len();
266        if payload_len > 125 {
267            return Err(EncodeError::ControlPayloadTooLarge(payload_len));
268        }
269        dst.clear();
270        dst.append(&code.to_be_bytes());
271        dst.append(reason);
272        let (hdr, mask_key) = self.build_header(0x88, payload_len);
273        if let Some(mask) = mask_key {
274            super::mask::apply_mask(dst.data_mut(), mask);
275        }
276        dst.prepend(hdr.as_bytes());
277        Ok(())
278    }
279
280    /// Encode a text frame, writing the payload via a closure.
281    ///
282    /// The closure writes directly into the WriteBuf — no intermediate
283    /// allocation. The WS frame header (including payload length) is
284    /// prepended after the closure returns.
285    ///
286    /// ```ignore
287    /// writer.encode_text_writer(&mut wbuf, |w| {
288    ///     use std::io::Write;
289    ///     serde_json::to_writer(w, &msg)
290    /// })?;
291    /// ```
292    pub fn encode_text_writer<F, E>(
293        &mut self,
294        dst: &mut crate::buf::WriteBuf,
295        f: F,
296    ) -> Result<(), E>
297    where
298        F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
299    {
300        self.encode_writer_into(0x81, dst, f)
301    }
302
303    /// Encode a binary frame, writing the payload via a closure.
304    pub fn encode_binary_writer<F, E>(
305        &mut self,
306        dst: &mut crate::buf::WriteBuf,
307        f: F,
308    ) -> Result<(), E>
309    where
310        F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
311    {
312        self.encode_writer_into(0x82, dst, f)
313    }
314
315    /// Encode a text frame with a fixed-size payload via closure.
316    ///
317    /// The closure receives `&mut [u8]` of exactly `len` bytes.
318    pub fn encode_text_fixed(
319        &mut self,
320        dst: &mut crate::buf::WriteBuf,
321        len: usize,
322        f: impl FnOnce(&mut [u8]),
323    ) {
324        self.encode_fixed_into(0x81, dst, len, f);
325    }
326
327    /// Encode a binary frame with a fixed-size payload via closure.
328    pub fn encode_binary_fixed(
329        &mut self,
330        dst: &mut crate::buf::WriteBuf,
331        len: usize,
332        f: impl FnOnce(&mut [u8]),
333    ) {
334        self.encode_fixed_into(0x82, dst, len, f);
335    }
336
337    fn encode_into(&mut self, byte0: u8, payload: &[u8], dst: &mut crate::buf::WriteBuf) {
338        dst.clear();
339        dst.append(payload);
340        let (hdr, mask_key) = self.build_header(byte0, payload.len());
341        if let Some(mask) = mask_key {
342            super::mask::apply_mask(dst.data_mut(), mask);
343        }
344        dst.prepend(hdr.as_bytes());
345    }
346
347    fn encode_writer_into<F, E>(
348        &mut self,
349        byte0: u8,
350        dst: &mut crate::buf::WriteBuf,
351        f: F,
352    ) -> Result<(), E>
353    where
354        F: FnOnce(&mut crate::buf::WriteBufWriter<'_>) -> Result<(), E>,
355    {
356        dst.clear();
357        let payload_len = {
358            let mut bw = crate::buf::WriteBufWriter::new(dst);
359            f(&mut bw)?;
360            bw.written()
361        };
362        let (hdr, mask_key) = self.build_header(byte0, payload_len);
363        if let Some(mask) = mask_key {
364            super::mask::apply_mask(dst.data_mut(), mask);
365        }
366        dst.prepend(hdr.as_bytes());
367        Ok(())
368    }
369
370    fn encode_fixed_into(
371        &mut self,
372        byte0: u8,
373        dst: &mut crate::buf::WriteBuf,
374        len: usize,
375        f: impl FnOnce(&mut [u8]),
376    ) {
377        dst.clear();
378        dst.extend_zeroed(len);
379        f(dst.data_mut());
380        let (hdr, mask_key) = self.build_header(byte0, len);
381        if let Some(mask) = mask_key {
382            super::mask::apply_mask(dst.data_mut(), mask);
383        }
384        dst.prepend(hdr.as_bytes());
385    }
386
387    // =========================================================================
388    // Internal
389    // =========================================================================
390
391    /// Generate a 4-byte mask key from the internal PRNG.
392    ///
393    /// The PRNG is seeded from OS randomness on first use, then produces
394    /// mask keys without syscalls. RFC 6455 §10.3 requires unpredictable
395    /// masking keys — ChaCha8 satisfies this.
396    fn generate_mask(&mut self) -> [u8; 4] {
397        let rng = self.mask_rng.get_or_insert_with(|| {
398            let mut seed = [0u8; 32];
399            getrandom::fill(&mut seed).expect("OS randomness unavailable");
400            ChaCha8Rng::from_seed(seed)
401        });
402        let mut mask = [0u8; 4];
403        rng.fill_bytes(&mut mask);
404        mask
405    }
406
407    fn encode(&mut self, byte0: u8, payload: &[u8], dst: &mut [u8]) -> usize {
408        let mask_bit: u8 = if self.role == Role::Client { 0x80 } else { 0 };
409        let payload_len = payload.len();
410
411        let mut offset = 0;
412
413        // Byte 0: FIN + opcode
414        dst[offset] = byte0;
415        offset += 1;
416
417        // Byte 1: MASK bit + payload length
418        if payload_len <= 125 {
419            dst[offset] = mask_bit | (payload_len as u8);
420            offset += 1;
421        } else if payload_len <= 65535 {
422            dst[offset] = mask_bit | 0x7E;
423            offset += 1;
424            dst[offset..offset + 2].copy_from_slice(&(payload_len as u16).to_be_bytes());
425            offset += 2;
426        } else {
427            dst[offset] = mask_bit | 0x7F;
428            offset += 1;
429            dst[offset..offset + 8].copy_from_slice(&(payload_len as u64).to_be_bytes());
430            offset += 8;
431        }
432
433        // Mask key (client only)
434        if self.role == Role::Client {
435            let mask = self.generate_mask();
436            dst[offset..offset + 4].copy_from_slice(&mask);
437            offset += 4;
438
439            // Copy and mask payload
440            dst[offset..offset + payload_len].copy_from_slice(payload);
441            super::mask::apply_mask(&mut dst[offset..offset + payload_len], mask);
442        } else {
443            dst[offset..offset + payload_len].copy_from_slice(payload);
444        }
445
446        offset + payload_len
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Undo client masking: XOR each byte with `key[i % 4]` (RFC 6455 §5.3).
    fn unmask(data: &[u8], key: [u8; 4]) -> Vec<u8> {
        data.iter()
            .enumerate()
            .map(|(i, b)| b ^ key[i % 4])
            .collect()
    }

    #[test]
    fn encode_text_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        assert_eq!(n, 7);
        assert_eq!(dst[0], 0x81); // FIN + Text
        assert_eq!(dst[1], 0x05); // no mask, len=5
        assert_eq!(&dst[2..7], b"Hello");
    }

    #[test]
    fn encode_binary_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_binary(&[0xDE, 0xAD, 0xBE, 0xEF], &mut dst);
        assert_eq!(n, 6);
        assert_eq!(dst[0], 0x82); // FIN + Binary
        assert_eq!(&dst[2..6], &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn encode_close_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(9)];
        let n = writer.encode_close(1000, b"goodbye", &mut dst).unwrap();
        assert_eq!(dst[0], 0x88); // FIN + Close
        assert_eq!(&dst[2..4], &1000u16.to_be_bytes());
        assert_eq!(&dst[4..n], b"goodbye");
    }

    #[test]
    fn encode_ping_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_ping(b"ping", &mut dst).unwrap();
        assert_eq!(dst[0], 0x89); // FIN + Ping
        assert_eq!(&dst[2..n], b"ping");
    }

    #[test]
    fn encode_pong_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_pong(b"pong", &mut dst).unwrap();
        assert_eq!(dst[0], 0x8A); // FIN + Pong
        assert_eq!(&dst[2..n], b"pong");
    }

    #[test]
    fn encode_client_is_masked() {
        let mut writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        assert_eq!(n, 11); // 2 header + 4 mask + 5 payload
        assert_eq!(dst[0], 0x81); // FIN + Text
        assert_eq!(dst[1] & 0x80, 0x80); // mask bit set
        assert_eq!(dst[1] & 0x7F, 5); // len=5
        // Payload is masked — shouldn't equal plaintext
        assert_ne!(&dst[6..11], b"Hello");
    }

    #[test]
    fn client_mask_key_unmasks_payload() {
        let mut writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);
        let key = [dst[2], dst[3], dst[4], dst[5]];
        assert_eq!(unmask(&dst[6..n], key), b"Hello");
    }

    #[test]
    fn encode_16bit_length() {
        let mut writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42; 256];
        let mut dst = vec![0u8; writer.max_encoded_len(256)];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 4 + 256); // 2 + 2 (16-bit len) + 256
        assert_eq!(dst[1] & 0x7F, 126); // extended 16-bit
        let len = u16::from_be_bytes([dst[2], dst[3]]);
        assert_eq!(len, 256);
    }

    #[test]
    fn encode_64bit_length() {
        let mut writer = FrameWriter::new(Role::Server);
        let payload = vec![0x42; 65536];
        let mut dst = vec![0u8; writer.max_encoded_len(payload.len())];
        let n = writer.encode_binary(&payload, &mut dst);
        assert_eq!(n, 10 + 65536); // 2 + 8 (64-bit len) + payload
        assert_eq!(dst[1] & 0x7F, 127); // extended 64-bit
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&dst[2..10]);
        assert_eq!(u64::from_be_bytes(len_bytes), 65536);
        assert!(dst[10..n].iter().all(|&b| b == 0x42));
    }

    #[test]
    fn max_encoded_len_small() {
        let server = FrameWriter::new(Role::Server);
        assert_eq!(server.max_encoded_len(0), 2);
        assert_eq!(server.max_encoded_len(125), 2 + 125);
        assert_eq!(server.max_encoded_len(126), 4 + 126);

        let client = FrameWriter::new(Role::Client);
        assert_eq!(client.max_encoded_len(0), 2 + 4);
        assert_eq!(client.max_encoded_len(125), 2 + 4 + 125);
    }

    #[test]
    fn max_encoded_len_boundaries() {
        let server = FrameWriter::new(Role::Server);
        assert_eq!(server.max_encoded_len(65535), 4 + 65535);
        assert_eq!(server.max_encoded_len(65536), 10 + 65536);

        let client = FrameWriter::new(Role::Client);
        assert_eq!(client.max_encoded_len(126), 4 + 4 + 126);
        assert_eq!(client.max_encoded_len(65536), 10 + 4 + 65536);
    }

    #[test]
    fn build_header_server_lengths() {
        let mut writer = FrameWriter::new(Role::Server);

        let (hdr, mask) = writer.build_header(0x82, 0);
        assert_eq!(hdr.as_bytes(), &[0x82, 0x00]);
        assert!(mask.is_none());

        let (hdr, mask) = writer.build_header(0x82, 126);
        assert_eq!(hdr.as_bytes(), &[0x82, 0x7E, 0x00, 0x7E]);
        assert!(mask.is_none());

        let (hdr, _) = writer.build_header(0x82, 65536);
        assert_eq!(hdr.len(), 10);
        assert_eq!(hdr.as_bytes()[1], 0x7F);
    }

    #[test]
    fn build_header_client_embeds_mask() {
        let mut writer = FrameWriter::new(Role::Client);
        let (hdr, mask) = writer.build_header(0x81, 5);
        let mask = mask.expect("client headers carry a mask key");
        assert_eq!(hdr.len(), 6); // 2 + 4-byte mask key
        assert_eq!(&hdr.as_bytes()[..2], &[0x81, 0x85]);
        assert_eq!(&hdr.as_bytes()[2..], &mask);
    }

    #[test]
    fn round_trip_server() {
        use crate::ws::{FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn round_trip_client() {
        use crate::ws::{FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Client);
        let mut dst = vec![0u8; writer.max_encoded_len(5)];
        let n = writer.encode_text(b"Hello", &mut dst);

        let mut reader = FrameReader::builder().role(Role::Server).build();
        reader.read(&dst[..n]).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn round_trip_client_write_buf() {
        use crate::buf::WriteBuf;
        use crate::ws::{FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Client);
        let mut wbuf = WriteBuf::new(128, 14);
        writer.encode_text_into(b"Hello", &mut wbuf);

        let mut reader = FrameReader::builder().role(Role::Server).build();
        reader.read(wbuf.data()).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("Hello")
        ));
    }

    #[test]
    fn encode_close_code_round_trip() {
        use crate::ws::{CloseCode, FrameReader, Message};
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 64];
        let n = writer
            .encode_close_code(CloseCode::Normal, "goodbye", &mut dst)
            .unwrap();

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(&dst[..n]).unwrap();
        match reader.next().unwrap().unwrap() {
            Message::Close(cf) => {
                assert_eq!(cf.code, CloseCode::Normal);
                assert_eq!(cf.reason, "goodbye");
            }
            other => panic!("expected Close, got {other:?}"),
        }
    }

    #[test]
    #[should_panic(expected = "NoStatus")]
    fn encode_close_code_no_status_panics() {
        use crate::ws::CloseCode;
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 16];
        let _ = writer.encode_close_code(CloseCode::NoStatus, "", &mut dst);
    }

    #[test]
    fn encode_empty_close_server() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(0)];
        let n = writer.encode_empty_close(&mut dst);
        assert_eq!(&dst[..n], &[0x88, 0x00]);
    }

    #[test]
    fn ping_too_large_returns_err() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert!(matches!(
            writer.encode_ping(&[0; 126], &mut dst),
            Err(super::EncodeError::ControlPayloadTooLarge(126))
        ));
    }

    #[test]
    fn pong_too_large_returns_err() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; 256];
        assert_eq!(
            writer.encode_pong(&[0; 126], &mut dst),
            Err(EncodeError::ControlPayloadTooLarge(126))
        );
    }

    #[test]
    fn close_reason_length_limit() {
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(125)];
        // 2-byte code + 123-byte reason is exactly the 125-byte limit.
        let n = writer.encode_close(1000, &[b'a'; 123], &mut dst).unwrap();
        assert_eq!(n, 2 + 125);
        assert_eq!(
            writer.encode_close(1000, &[b'a'; 124], &mut dst),
            Err(EncodeError::ControlPayloadTooLarge(126))
        );
    }

    #[test]
    fn encode_close_into_matches_encode_close() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(9)];
        let n = writer.encode_close(1001, b"goodbye", &mut dst).unwrap();

        let mut wbuf = WriteBuf::new(128, 14);
        writer.encode_close_into(1001, b"goodbye", &mut wbuf).unwrap();
        assert_eq!(wbuf.data(), &dst[..n]);
    }

    #[test]
    fn encode_ping_into_matches_encode_ping() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let mut dst = vec![0u8; writer.max_encoded_len(4)];
        let n = writer.encode_ping(b"ping", &mut dst).unwrap();

        let mut wbuf = WriteBuf::new(128, 14);
        writer.encode_ping_into(b"ping", &mut wbuf).unwrap();
        assert_eq!(wbuf.data(), &dst[..n]);

        assert_eq!(
            writer.encode_ping_into(&[0; 126], &mut wbuf),
            Err(EncodeError::ControlPayloadTooLarge(126))
        );
    }

    #[test]
    fn encode_text_writer_matches_into() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let payload = b"Hello, world!";

        let mut wbuf1 = WriteBuf::new(128, 14);
        writer.encode_text_into(payload, &mut wbuf1);

        let mut wbuf2 = WriteBuf::new(128, 14);
        writer
            .encode_text_writer(&mut wbuf2, |w| {
                use std::io::Write;
                w.write_all(payload)
            })
            .unwrap();

        assert_eq!(wbuf1.data(), wbuf2.data());
    }

    #[test]
    fn encode_binary_fixed_matches_into() {
        use crate::buf::WriteBuf;
        let mut writer = FrameWriter::new(Role::Server);
        let payload = [0xDE, 0xAD, 0xBE, 0xEF];

        let mut wbuf1 = WriteBuf::new(128, 14);
        writer.encode_binary_into(&payload, &mut wbuf1);

        let mut wbuf2 = WriteBuf::new(128, 14);
        writer.encode_binary_fixed(&mut wbuf2, payload.len(), |buf| {
            buf.copy_from_slice(&payload);
        });

        assert_eq!(wbuf1.data(), wbuf2.data());
    }

    #[test]
    fn encode_text_writer_round_trip() {
        use crate::buf::WriteBuf;
        use crate::ws::{FrameReader, Message};

        let mut writer = FrameWriter::new(Role::Server);
        let mut wbuf = WriteBuf::new(128, 14);
        writer
            .encode_text_writer(&mut wbuf, |w| {
                use std::io::Write;
                w.write_all(b"test message")
            })
            .unwrap();

        let mut reader = FrameReader::builder().role(Role::Client).build();
        reader.read(wbuf.data()).unwrap();
        assert!(matches!(
            reader.next().unwrap().unwrap(),
            Message::Text("test message")
        ));
    }
}