blitz_ws/protocol/frame/
frame.rs

1//! WebSocket Frame module
2
3use std::{
4    fmt::Display,
5    io::{Cursor, ErrorKind, Read, Write},
6    mem,
7    result::Result as StdResult,
8    str::Utf8Error,
9};
10
11use bytes::{Bytes, BytesMut};
12
13use super::{
14    codec::{CloseCode, Control, Data, OpCode},
15    mask::{apply_mask, generate},
16};
17use crate::{
18    error::{Error, ProtocolError, Result},
19    protocol::frame::Utf8Bytes,
20};
21
22/// A struct representing the close command.
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct CloseFrame {
25    /// The reason as a code.
26    pub code: CloseCode,
27    /// The reason as text string.
28    pub reason: Utf8Bytes,
29}
30
31impl Display for CloseFrame {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "{} ({})", self.reason, self.code)
34    }
35}
36
37/// A struct representing a WebSocket frame header.
38#[allow(missing_copy_implementations)]
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct FrameHeader {
41    /// Indicates is the frame is the last one of a possibly fragmented message
42    pub fin: bool,
43    /// Reserved for protocol extensions.
44    pub rsv1: bool,
45    /// Reserved for protocol extensions.
46    pub rsv2: bool,
47    /// Reserved for protocol extensions.
48    pub rsv3: bool,
49    /// WebSocket protocol opcode.
50    pub opcode: OpCode,
51    /// A frame mask (if any)
52    pub mask: Option<[u8; 4]>,
53}
54
55impl Default for FrameHeader {
56    fn default() -> Self {
57        FrameHeader {
58            fin: false,
59            rsv1: false,
60            rsv2: false,
61            rsv3: false,
62            opcode: OpCode::Control(Control::Close),
63            mask: None,
64        }
65    }
66}
67
68impl FrameHeader {
69    /// > The longest possible header is 14 bytes, which would represent a message sent from
70    /// > the client to the server with a payload greater than 64KB.
71    pub(crate) const MAX_HEADER_SIZE: usize = 14;
72
73    /// Parse a header from an input stream.
74    /// Returns `None` if insufficient data and does not consume anything in this case.
75    /// Payload size is returned along with the header.
76    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77        let init = cursor.position();
78
79        match Self::parse_internal(cursor) {
80            i @ Ok(None) => {
81                cursor.set_position(init);
82                i
83            }
84            other => other,
85        }
86    }
87
88    /// Get the size of the header formatted with given payload length.
89    #[allow(clippy::len_without_is_empty)]
90    pub fn len(&self, length: u64) -> usize {
91        2 + Length::for_len(length).additional() + (if self.mask.is_some() { 4 } else { 0 })
92    }
93
94    /// Format a header for given payload size.
95    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
96        let code: u8 = self.opcode.into();
97
98        let first_byte = {
99            code | if self.fin { 0x80 } else { 0 }
100                | if self.rsv1 { 0x40 } else { 0 }
101                | if self.rsv2 { 0x20 } else { 0 }
102                | if self.rsv3 { 0x10 } else { 0 }
103        };
104
105        let len = Length::for_len(length);
106
107        let second_byte = len.len_byte() | if self.mask.is_some() { 0x80 } else { 0 };
108
109        output.write_all(&[first_byte, second_byte])?;
110
111        match len {
112            Length::U8(_) => (),
113            Length::U16 => {
114                output.write_all(&(length as u16).to_be_bytes())?;
115            }
116            Length::U64 => {
117                output.write_all(&length.to_be_bytes())?;
118            }
119        }
120
121        if let Some(ref mask) = self.mask {
122            output.write_all(mask)?;
123        }
124
125        Ok(())
126    }
127
128    /// Generate a random frame mask and store this in the header.
129    ///
130    /// Of course this does not change frame contents. It just generates a mask.
131    pub(crate) fn set_random_mask(&mut self) {
132        self.mask = Some(generate());
133    }
134
135    /// Internal parse engine.
136    /// Returns `None` if insufficient data.
137    /// Payload size is returned along with the header.
138    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139        let (a, b) = {
140            let mut head = [0u8; 2];
141            if cursor.read(&mut head)? != 2 {
142                return Ok(None);
143            }
144
145            (head[0], head[1])
146        };
147
148        let fin = a & 0x80 != 0;
149        let rsv1 = a & 0x40 != 0;
150        let rsv2 = a & 0x20 != 0;
151        let rsv3 = a & 0x10 != 0;
152
153        let opcode = OpCode::from(a & 0x0F);
154
155        let masked = b & 0x80 != 0;
156
157        let len = {
158            let len_byte = b & 0x7F;
159            let particular_len = Length::for_byte(len_byte).additional();
160
161            if particular_len > 0 {
162                const SIZE: usize = mem::size_of::<u64>();
163                assert!(
164                    particular_len < SIZE,
165                    "Length exceeded max size of unsigned 64-bit integer"
166                );
167
168                let start = SIZE - particular_len;
169                let mut buf = [0u8; SIZE];
170
171                match cursor.read_exact(&mut buf[start..]) {
172                    Err(ref e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
173                    Err(e) => return Err(e.into()),
174                    Ok(()) => u64::from_be_bytes(buf),
175                }
176            } else {
177                u64::from(len_byte)
178            }
179        };
180
181        let mask = if masked {
182            let mut mask_bytes = [0u8; 4];
183            if cursor.read(&mut mask_bytes)? != 4 {
184                return Ok(None);
185            } else {
186                Some(mask_bytes)
187            }
188        } else {
189            None
190        };
191
192        match opcode {
193            OpCode::Control(Control::Reserved(_)) => {
194                return Err(Error::Protocol(ProtocolError::UnknownControlOpCode(a & 0x0F)));
195            }
196            OpCode::Data(Data::Reserved(_)) => {
197                return Err(Error::Protocol(ProtocolError::UnknownDataOpCode(a & 0x0F)));
198            }
199            _ => (),
200        };
201
202        let header = FrameHeader { fin, rsv1, rsv2, rsv3, opcode, mask };
203
204        Ok(Some((header, len)))
205    }
206}
207
208impl Frame {}
209
210/// The WebSocket Frame
211#[derive(Debug, Clone, PartialEq, Eq)]
212pub struct Frame {
213    header: FrameHeader,
214    payload: Bytes,
215}
216
217impl Frame {
218    /// Get the length of the frame.
219    /// This is the length of the header + the length of the payload.
220    #[inline]
221    pub fn len(&self) -> usize {
222        let length = self.payload.len();
223        self.header.len(length as u64) + length
224    }
225
226    /// Check if the frame is empty.
227    #[inline]
228    pub fn is_empty(&self) -> bool {
229        self.len() == 0
230    }
231
232    /// Get a reference to the frame's header.
233    #[inline]
234    pub fn header(&self) -> &FrameHeader {
235        &self.header
236    }
237
238    /// Get a mutable reference to the frame's header.
239    #[inline]
240    pub fn header_mut(&mut self) -> &mut FrameHeader {
241        &mut self.header
242    }
243
244    /// Get a reference to the frame's payload.
245    #[inline]
246    pub fn payload(&self) -> &[u8] {
247        &self.payload
248    }
249
250    /// Test whether the frame is masked.
251    #[inline]
252    pub(crate) fn is_masked(&self) -> bool {
253        self.header.mask.is_some()
254    }
255
256    /// Generate a random mask for the frame.
257    ///
258    /// This just generates a mask, payload is not changed. The actual masking is performed
259    /// either on `format()` or on `apply_mask()` call.
260    #[inline]
261    pub(crate) fn set_random_mask(&mut self) {
262        self.header.set_random_mask();
263    }
264
265    /// Consume the frame into its payload as string.
266    #[inline]
267    pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
268        self.payload.try_into()
269    }
270
271    /// Consume the frame into its payload.
272    #[inline]
273    pub fn into_payload(self) -> Bytes {
274        self.payload
275    }
276
277    /// Get frame payload as `&str`.
278    #[inline]
279    pub fn to_text(&self) -> Result<&str, Utf8Error> {
280        std::str::from_utf8(&self.payload)
281    }
282
283    /// Consume the frame into a closing frame.
284    #[inline]
285    pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
286        match self.payload.len() {
287            0 => Ok(None),
288            1 => Err(Error::Protocol(ProtocolError::InvalidCloseFrame)),
289            _ => {
290                let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
291                let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
292
293                Ok(Some(CloseFrame { code, reason }))
294            }
295        }
296    }
297
298    /// Create a new data frame.
299    #[inline]
300    pub fn new_data(data: impl Into<Bytes>, opcode: OpCode, fin: bool) -> Frame {
301        debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame");
302
303        Frame { header: FrameHeader { fin, opcode, ..Default::default() }, payload: data.into() }
304    }
305
306    /// Create a new Ping control frame.
307    #[inline]
308    pub fn new_ping(data: impl Into<Bytes>) -> Frame {
309        Frame {
310            header: FrameHeader { opcode: OpCode::Control(Control::Ping), ..<_>::default() },
311            payload: data.into(),
312        }
313    }
314
315    /// Create a new Pong control frame.
316    #[inline]
317    pub fn new_pong(data: impl Into<Bytes>) -> Frame {
318        Frame {
319            header: FrameHeader { opcode: OpCode::Control(Control::Pong), ..<_>::default() },
320            payload: data.into(),
321        }
322    }
323
324    /// Create a new Close control frame.
325    #[inline]
326    pub fn new_close(msg: Option<CloseFrame>) -> Frame {
327        let payload = if let Some(CloseFrame { code, reason }) = msg {
328            let mut p = BytesMut::with_capacity(reason.len() + 2);
329            p.extend(u16::from(code).to_be_bytes());
330            p.extend_from_slice(reason.as_bytes());
331            p
332        } else {
333            <_>::default()
334        };
335
336        Frame { header: <_>::default(), payload: payload.into() }
337    }
338
339    /// Initializes a new frame
340    pub fn new(header: FrameHeader, payload: Bytes) -> Self {
341        Frame { header, payload }
342    }
343
344    /// Write a frame out to a buffer
345    pub fn format_to_buf(mut self, output: &mut impl Write) -> Result<()> {
346        self.header.format(self.payload.len() as u64, output)?;
347
348        if let Some(mask) = self.header.mask.take() {
349            let mut data = Vec::from(mem::take(&mut self.payload));
350            apply_mask(&mut data, mask);
351
352            output.write_all(&data)?;
353        } else {
354            output.write_all(&self.payload)?;
355        }
356
357        Ok(())
358    }
359
360    pub(crate) fn into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
361        self.header.format(self.payload.len() as u64, buf)?;
362
363        let len = buf.len();
364        buf.extend_from_slice(&self.payload);
365
366        if let Some(mask) = self.header.mask.take() {
367            apply_mask(&mut buf[len..], mask);
368        }
369
370        Ok(())
371    }
372}
373
374impl Display for Frame {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        use std::fmt::Write;
377
378        write!(
379            f,
380            "/
381            [FRAME]
382            final: {},
383            reserved: {} {} {},
384            opcode: {},
385            length: {},
386            payload-length: {},
387            payload: 0x{}
388            ",
389            self.header.fin,
390            self.header.rsv1,
391            self.header.rsv2,
392            self.header.rsv3,
393            self.header.opcode,
394            self.len(),
395            self.payload.len(),
396            self.payload.iter().fold(String::new(), |mut out, byte| {
397                _ = write!(out, "{byte:02x}");
398                out
399            })
400        )
401    }
402}
403
/// How a payload length is encoded in a frame header (RFC 6455 §5.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Length {
    /// Length 0..=125, stored directly in the 7-bit length field.
    U8(u8),
    /// Length field is 126; a 16-bit big-endian length follows.
    U16,
    /// Length field is 127; a 64-bit big-endian length follows.
    U64,
}

impl Length {
    /// Choose the smallest encoding that can represent a payload of `len` bytes.
    #[inline]
    fn for_len(len: u64) -> Self {
        match len {
            0..=125 => Length::U8(len as u8),
            126..=65535 => Length::U16,
            _ => Length::U64,
        }
    }

    /// Number of extended-length bytes that follow the 2-byte base header.
    #[inline]
    fn additional(&self) -> usize {
        match self {
            Length::U8(_) => 0,
            Length::U16 => 2,
            Length::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field of the second header byte.
    #[inline]
    fn len_byte(&self) -> u8 {
        match *self {
            Length::U8(direct) => direct,
            Length::U16 => 126,
            Length::U64 => 127,
        }
    }

    /// Decode the 7-bit length field; the MASK bit (0x80) is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        if field == 126 {
            Length::U16
        } else if field == 127 {
            Length::U64
        } else {
            Length::U8(field)
        }
    }
}