// hpx_fastwebsockets/frame.rs

// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use tokio::io::AsyncWriteExt;
16
17use bytes::BytesMut;
18use core::ops::Deref;
19
20use crate::WebSocketError;
21
/// Declares a fieldless enum and implements `TryFrom<u8>` for it.
///
/// `try_from` compares the input byte with each variant's `Variant as u8`
/// value, in declaration order, and returns the first variant that matches.
/// Any other byte yields `WebSocketError::InvalidValue`. Because of the
/// `as u8` cast, the macro only works for fieldless enums.
macro_rules! repr_u8 {
  ($(#[$meta:meta])* $vis:vis enum $name:ident {
    $($(#[$vmeta:meta])* $vname:ident $(= $val:expr)?,)*
  }) => {
    $(#[$meta])*
    $vis enum $name {
      $($(#[$vmeta])* $vname $(= $val)?,)*
    }

    impl core::convert::TryFrom<u8> for $name {
      type Error = WebSocketError;

      fn try_from(byte: u8) -> Result<Self, Self::Error> {
        // One early-return check per declared variant.
        $(
          if byte == $name::$vname as u8 {
            return Ok($name::$vname);
          }
        )*
        Err(WebSocketError::InvalidValue)
      }
    }
  };
}

44pub enum Payload<'a> {
45  BorrowedMut(&'a mut [u8]),
46  Borrowed(&'a [u8]),
47  Owned(Vec<u8>),
48  Bytes(BytesMut),
49}
50
51impl<'a> core::fmt::Debug for Payload<'a> {
52  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53    f.debug_struct("Payload").field("len", &self.len()).finish()
54  }
55}
56
57impl Deref for Payload<'_> {
58  type Target = [u8];
59
60  fn deref(&self) -> &Self::Target {
61    match self {
62      Payload::Borrowed(borrowed) => borrowed,
63      Payload::BorrowedMut(borrowed_mut) => borrowed_mut,
64      Payload::Owned(owned) => owned.as_ref(),
65      Payload::Bytes(b) => b.as_ref(),
66    }
67  }
68}
69
70impl<'a> From<&'a mut [u8]> for Payload<'a> {
71  fn from(borrowed: &'a mut [u8]) -> Payload<'a> {
72    Payload::BorrowedMut(borrowed)
73  }
74}
75
76impl<'a> From<&'a [u8]> for Payload<'a> {
77  fn from(borrowed: &'a [u8]) -> Payload<'a> {
78    Payload::Borrowed(borrowed)
79  }
80}
81
82impl From<Vec<u8>> for Payload<'_> {
83  fn from(owned: Vec<u8>) -> Self {
84    Payload::Owned(owned)
85  }
86}
87
88impl From<Payload<'_>> for Vec<u8> {
89  fn from(cow: Payload<'_>) -> Self {
90    match cow {
91      Payload::Borrowed(borrowed) => borrowed.to_vec(),
92      Payload::BorrowedMut(borrowed_mut) => borrowed_mut.to_vec(),
93      Payload::Owned(owned) => owned,
94      Payload::Bytes(b) => Vec::from(b),
95    }
96  }
97}
98
99impl Payload<'_> {
100  #[inline(always)]
101  pub fn to_mut(&mut self) -> &mut [u8] {
102    match self {
103      Payload::Borrowed(borrowed) => {
104        *self = Payload::Owned(borrowed.to_owned());
105        match self {
106          Payload::Owned(owned) => owned,
107          _ => unreachable!(),
108        }
109      }
110      Payload::BorrowedMut(borrowed) => borrowed,
111      Payload::Owned(owned) => owned,
112      Payload::Bytes(b) => b.as_mut(),
113    }
114  }
115}
116
117impl<'a> PartialEq<&'_ [u8]> for Payload<'a> {
118  fn eq(&self, other: &&'_ [u8]) -> bool {
119    self.deref() == *other
120  }
121}
122
123impl<'a, const N: usize> PartialEq<&'_ [u8; N]> for Payload<'a> {
124  fn eq(&self, other: &&'_ [u8; N]) -> bool {
125    self.deref() == *other
126  }
127}
128
129/// Represents a WebSocket frame.
130pub struct Frame<'f> {
131  /// Indicates if this is the final frame in a message.
132  pub fin: bool,
133  /// The opcode of the frame.
134  pub opcode: OpCode,
135  /// The masking key of the frame, if any.
136  mask: Option<[u8; 4]>,
137  /// The payload of the frame.
138  pub payload: Payload<'f>,
139}
140
141const MAX_HEAD_SIZE: usize = 16;
142
143impl<'f> Frame<'f> {
144  /// Creates a new WebSocket `Frame`.
145  pub fn new(
146    fin: bool,
147    opcode: OpCode,
148    mask: Option<[u8; 4]>,
149    payload: Payload<'f>,
150  ) -> Self {
151    Self {
152      fin,
153      opcode,
154      mask,
155      payload,
156    }
157  }
158
159  /// Create a new WebSocket text `Frame`.
160  ///
161  /// This is a convenience method for `Frame::new(true, OpCode::Text, None, payload)`.
162  ///
163  /// This method does not check if the payload is valid UTF-8.
164  pub fn text(payload: Payload<'f>) -> Self {
165    Self {
166      fin: true,
167      opcode: OpCode::Text,
168      mask: None,
169      payload,
170    }
171  }
172
173  /// Create a new WebSocket binary `Frame`.
174  ///
175  /// This is a convenience method for `Frame::new(true, OpCode::Binary, None, payload)`.
176  pub fn binary(payload: Payload<'f>) -> Self {
177    Self {
178      fin: true,
179      opcode: OpCode::Binary,
180      mask: None,
181      payload,
182    }
183  }
184
185  /// Create a new WebSocket close `Frame`.
186  ///
187  /// This is a convenience method for `Frame::new(true, OpCode::Close, None, payload)`.
188  ///
189  /// This method does not check if `code` is a valid close code and `reason` is valid UTF-8.
190  pub fn close(code: u16, reason: &[u8]) -> Self {
191    let mut payload = Vec::with_capacity(2 + reason.len());
192    payload.extend_from_slice(&code.to_be_bytes());
193    payload.extend_from_slice(reason);
194
195    Self {
196      fin: true,
197      opcode: OpCode::Close,
198      mask: None,
199      payload: payload.into(),
200    }
201  }
202
203  /// Create a new WebSocket close `Frame` with a raw payload.
204  ///
205  /// This is a convenience method for `Frame::new(true, OpCode::Close, None, payload)`.
206  ///
207  /// This method does not check if `payload` is valid Close frame payload.
208  pub fn close_raw(payload: Payload<'f>) -> Self {
209    Self {
210      fin: true,
211      opcode: OpCode::Close,
212      mask: None,
213      payload,
214    }
215  }
216
217  /// Create a new WebSocket pong `Frame`.
218  ///
219  /// This is a convenience method for `Frame::new(true, OpCode::Pong, None, payload)`.
220  pub fn pong(payload: Payload<'f>) -> Self {
221    Self {
222      fin: true,
223      opcode: OpCode::Pong,
224      mask: None,
225      payload,
226    }
227  }
228
229  /// Checks if the frame payload is valid UTF-8.
230  pub fn is_utf8(&self) -> bool {
231    #[cfg(feature = "simd")]
232    return simdutf8::basic::from_utf8(&self.payload).is_ok();
233
234    #[cfg(not(feature = "simd"))]
235    return std::str::from_utf8(&self.payload).is_ok();
236  }
237
238  pub fn mask(&mut self) {
239    if let Some(mask) = self.mask {
240      crate::mask::unmask(self.payload.to_mut(), mask);
241    } else {
242      let mask: [u8; 4] = rand::random();
243      crate::mask::unmask(self.payload.to_mut(), mask);
244      self.mask = Some(mask);
245    }
246  }
247
248  /// Unmasks the frame payload in-place. This method does nothing if the frame is not masked.
249  ///
250  /// Note: By default, the frame payload is unmasked by `WebSocket::read_frame`.
251  pub fn unmask(&mut self) {
252    if let Some(mask) = self.mask {
253      crate::mask::unmask(self.payload.to_mut(), mask);
254    }
255  }
256
257  /// Formats the frame header into the head buffer. Returns the size of the length field.
258  ///
259  /// # Panics
260  ///
261  /// This method panics if the head buffer is not at least n-bytes long, where n is the size of the length field (0, 2, 4, or 10)
262  pub fn fmt_head(&mut self, head: &mut [u8]) -> usize {
263    head[0] = (self.fin as u8) << 7 | (self.opcode as u8);
264
265    let len = self.payload.len();
266    let size = if len < 126 {
267      head[1] = len as u8;
268      2
269    } else if len < 65536 {
270      head[1] = 126;
271      head[2..4].copy_from_slice(&(len as u16).to_be_bytes());
272      4
273    } else {
274      head[1] = 127;
275      head[2..10].copy_from_slice(&(len as u64).to_be_bytes());
276      10
277    };
278
279    if let Some(mask) = self.mask {
280      head[1] |= 0x80;
281      head[size..size + 4].copy_from_slice(&mask);
282      size + 4
283    } else {
284      size
285    }
286  }
287
288  pub async fn writev<S>(
289    &mut self,
290    stream: &mut S,
291  ) -> Result<(), std::io::Error>
292  where
293    S: AsyncWriteExt + Unpin,
294  {
295    use std::io::IoSlice;
296
297    let mut head = [0; MAX_HEAD_SIZE];
298    let size = self.fmt_head(&mut head);
299
300    let total = size + self.payload.len();
301
302    let mut b = [IoSlice::new(&head[..size]), IoSlice::new(&self.payload)];
303
304    let mut n = stream.write_vectored(&b).await?;
305    if n == total {
306      return Ok(());
307    }
308
309    // Slightly more optimized than (unstable) write_all_vectored for 2 iovecs.
310    while n <= size {
311      b[0] = IoSlice::new(&head[n..size]);
312      n += stream.write_vectored(&b).await?;
313    }
314
315    // Header out of the way.
316    if n < total && n > size {
317      stream.write_all(&self.payload[n - size..]).await?;
318    }
319
320    Ok(())
321  }
322
323  /// Writes the frame to the buffer and returns a slice of the buffer containing the frame.
324  pub fn write<'a>(&mut self, buf: &'a mut Vec<u8>) -> &'a [u8] {
325    fn reserve_enough(buf: &mut Vec<u8>, len: usize) {
326      if buf.len() < len {
327        buf.resize(len, 0);
328      }
329    }
330    let len = self.payload.len();
331    reserve_enough(buf, len + MAX_HEAD_SIZE);
332
333    let size = self.fmt_head(buf);
334    buf[size..size + len].copy_from_slice(&self.payload);
335    &buf[..size + len]
336  }
337}
338
339repr_u8! {
340    #[repr(u8)]
341    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
342    pub enum OpCode {
343        Continuation = 0x0,
344        Text = 0x1,
345        Binary = 0x2,
346        Close = 0x8,
347        Ping = 0x9,
348        Pong = 0xA,
349    }
350}
351
352#[inline]
353pub fn is_control(opcode: OpCode) -> bool {
354  matches!(opcode, OpCode::Close | OpCode::Ping | OpCode::Pong)
355}