Skip to main content

hpx_fastwebsockets/
lib.rs

1// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! _hpx-fastwebsockets_ is a minimal, fast WebSocket server implementation.
16//!
17//! Fork of [fastwebsockets](https://github.com/denoland/fastwebsockets) by Divy Srivastava.
18//!
19//! Passes the _Autobahn|TestSuite_ and fuzzed with LLVM's _libfuzzer_.
20//!
21//! You can use it as a raw websocket frame parser and deal with spec compliance yourself, or you can use it as a full-fledged websocket server.
22//!
23//! # Example
24//!
25//! ```
26//! use tokio::net::TcpStream;
27//! use fastwebsockets::{WebSocket, OpCode, Role};
28//! use eyre::Result;
29//!
30//! async fn handle(
31//!   socket: TcpStream,
32//! ) -> Result<()> {
33//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
34//!   ws.set_writev(false);
35//!   ws.set_auto_close(true);
36//!   ws.set_auto_pong(true);
37//!
38//!   loop {
39//!     let frame = ws.read_frame().await?;
40//!     match frame.opcode {
41//!       OpCode::Close => break,
42//!       OpCode::Text | OpCode::Binary => {
43//!         ws.write_frame(frame).await?;
44//!       }
45//!       _ => {}
46//!     }
47//!   }
48//!   Ok(())
49//! }
50//! ```
51//!
52//! ## Fragmentation
53//!
//! By default, fastwebsockets gives the application raw frames as they arrive,
//! so a fragmented message is delivered as several frames, all but the last
//! with FIN unset.
//!
//! To get a single message with all frames concatenated, use `FragmentCollector`:
58//! ```
59//! use fastwebsockets::{FragmentCollector, WebSocket, Role};
60//! use tokio::net::TcpStream;
61//! use eyre::Result;
62//!
63//! async fn handle(
64//!   socket: TcpStream,
65//! ) -> Result<()> {
66//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
67//!   let mut ws = FragmentCollector::new(ws);
68//!   let incoming = ws.read_frame().await?;
69//!   // Always returns full messages
70//!   assert!(incoming.fin);
71//!   Ok(())
72//! }
73//! ```
74//!
75//! _permessage-deflate is not supported yet._
76//!
77//! ## HTTP Upgrades
78//!
79//! Enable the `upgrade` feature to do server-side upgrades and client-side
80//! handshakes.
81//!
82//! This feature is powered by [hyper](https://docs.rs/hyper).
83//!
84//! ```
85//! use fastwebsockets::upgrade::upgrade;
86//! use http_body_util::Empty;
87//! use hyper::{Request, body::{Incoming, Bytes}, Response};
88//! use eyre::Result;
89//!
90//! async fn server_upgrade(
91//!   mut req: Request<Incoming>,
92//! ) -> Result<Response<Empty<Bytes>>> {
93//!   let (response, fut) = upgrade(&mut req)?;
94//!
95//!   tokio::spawn(async move {
96//!     let ws = fut.await;
97//!     // Do something with the websocket
98//!   });
99//!
100//!   Ok(response)
101//! }
102//! ```
103//!
104//! Use the `handshake` module for client-side handshakes.
105//!
106//! ```
107//! use fastwebsockets::handshake;
108//! use fastwebsockets::FragmentCollector;
109//! use hyper::{Request, body::Bytes, upgrade::Upgraded, header::{UPGRADE, CONNECTION}};
110//! use http_body_util::Empty;
111//! use hyper_util::rt::TokioIo;
112//! use tokio::net::TcpStream;
113//! use std::future::Future;
114//! use eyre::Result;
115//!
116//! async fn connect() -> Result<FragmentCollector<TokioIo<Upgraded>>> {
117//!   let stream = TcpStream::connect("localhost:9001").await?;
118//!
119//!   let req = Request::builder()
120//!     .method("GET")
121//!     .uri("http://localhost:9001/")
122//!     .header("Host", "localhost:9001")
123//!     .header(UPGRADE, "websocket")
124//!     .header(CONNECTION, "upgrade")
125//!     .header(
126//!       "Sec-WebSocket-Key",
127//!       fastwebsockets::handshake::generate_key(),
128//!     )
129//!     .header("Sec-WebSocket-Version", "13")
130//!     .body(Empty::<Bytes>::new())?;
131//!
132//!   let (ws, _) = handshake::client(&SpawnExecutor, req, stream).await?;
133//!   Ok(FragmentCollector::new(ws))
134//! }
135//!
136//! // Tie hyper's executor to tokio runtime
137//! struct SpawnExecutor;
138//!
139//! impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
140//! where
141//!   Fut: Future + Send + 'static,
142//!   Fut::Output: Send + 'static,
143//! {
144//!   fn execute(&self, fut: Fut) {
145//!     tokio::task::spawn(fut);
146//!   }
147//! }
148//! ```
149
150#![cfg_attr(docsrs, feature(doc_cfg))]
151
152mod close;
153mod error;
154mod fragment;
155mod frame;
156/// Client handshake.
157#[cfg(feature = "upgrade")]
158#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
159pub mod handshake;
160mod mask;
161/// HTTP upgrades.
162#[cfg(feature = "upgrade")]
163#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
164pub mod upgrade;
165
166use bytes::Buf;
167
168use bytes::BytesMut;
169#[cfg(feature = "unstable-split")]
170use std::future::Future;
171
172use tokio::io::AsyncRead;
173use tokio::io::AsyncReadExt;
174use tokio::io::AsyncWrite;
175use tokio::io::AsyncWriteExt;
176
177pub use crate::close::CloseCode;
178pub use crate::error::WebSocketError;
179pub use crate::fragment::FragmentCollector;
180#[cfg(feature = "unstable-split")]
181pub use crate::fragment::FragmentCollectorRead;
182pub use crate::frame::Frame;
183pub use crate::frame::OpCode;
184pub use crate::frame::Payload;
185pub use crate::mask::unmask;
186
187#[derive(Copy, Clone, PartialEq)]
188pub enum Role {
189  Server,
190  Client,
191}
192
193pub(crate) struct WriteHalf {
194  role: Role,
195  closed: bool,
196  vectored: bool,
197  auto_apply_mask: bool,
198  writev_threshold: usize,
199  write_buffer: Vec<u8>,
200}
201
202pub(crate) struct ReadHalf {
203  role: Role,
204  auto_apply_mask: bool,
205  auto_close: bool,
206  auto_pong: bool,
207  writev_threshold: usize,
208  max_message_size: usize,
209  buffer: BytesMut,
210}
211
212#[cfg(feature = "unstable-split")]
213pub struct WebSocketRead<S> {
214  stream: S,
215  read_half: ReadHalf,
216}
217
218#[cfg(feature = "unstable-split")]
219pub struct WebSocketWrite<S> {
220  stream: S,
221  write_half: WriteHalf,
222}
223
224#[cfg(feature = "unstable-split")]
225/// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake.
226pub fn after_handshake_split<R, W>(
227  read: R,
228  write: W,
229  role: Role,
230) -> (WebSocketRead<R>, WebSocketWrite<W>)
231where
232  R: AsyncRead + Unpin,
233  W: AsyncWrite + Unpin,
234{
235  (
236    WebSocketRead {
237      stream: read,
238      read_half: ReadHalf::after_handshake(role),
239    },
240    WebSocketWrite {
241      stream: write,
242      write_half: WriteHalf::after_handshake(role),
243    },
244  )
245}
246
247#[cfg(feature = "unstable-split")]
248impl<'f, S> WebSocketRead<S> {
249  /// Consumes the `WebSocketRead` and returns the underlying stream.
250  #[inline]
251  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
252    (self.stream, self.read_half)
253  }
254
255  pub fn set_writev_threshold(&mut self, threshold: usize) {
256    self.read_half.writev_threshold = threshold;
257  }
258
259  /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
260  ///
261  /// Default: `true`
262  pub fn set_auto_close(&mut self, auto_close: bool) {
263    self.read_half.auto_close = auto_close;
264  }
265
266  /// Sets whether to automatically send a pong frame when a ping frame is received.
267  ///
268  /// Default: `true`
269  pub fn set_auto_pong(&mut self, auto_pong: bool) {
270    self.read_half.auto_pong = auto_pong;
271  }
272
273  /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
274  ///
275  /// Default: 64 MiB
276  pub fn set_max_message_size(&mut self, max_message_size: usize) {
277    self.read_half.max_message_size = max_message_size;
278  }
279
280  /// Sets whether to automatically apply the mask to the frame payload.
281  ///
282  /// Default: `true`
283  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
284    self.read_half.auto_apply_mask = auto_apply_mask;
285  }
286
287  /// Reads a frame from the stream.
288  pub async fn read_frame<R, E>(
289    &mut self,
290    send_fn: &mut impl FnMut(Frame<'f>) -> R,
291  ) -> Result<Frame<'_>, WebSocketError>
292  where
293    S: AsyncRead + Unpin,
294    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
295    R: Future<Output = Result<(), E>>,
296  {
297    loop {
298      let (res, obligated_send) =
299        self.read_half.read_frame_inner(&mut self.stream).await;
300      if let Some(frame) = obligated_send {
301        let res = send_fn(frame).await;
302        res.map_err(|e| WebSocketError::SendError(e.into()))?;
303      }
304      if let Some(frame) = res? {
305        break Ok(frame);
306      }
307    }
308  }
309}
310
311#[cfg(feature = "unstable-split")]
312impl<'f, S> WebSocketWrite<S> {
313  /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
314  ///
315  /// Default: `true`
316  pub fn set_writev(&mut self, vectored: bool) {
317    self.write_half.vectored = vectored;
318  }
319
320  pub fn set_writev_threshold(&mut self, threshold: usize) {
321    self.write_half.writev_threshold = threshold;
322  }
323
324  /// Sets whether to automatically apply the mask to the frame payload.
325  ///
326  /// Default: `true`
327  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
328    self.write_half.auto_apply_mask = auto_apply_mask;
329  }
330
331  pub fn is_closed(&self) -> bool {
332    self.write_half.closed
333  }
334
335  pub async fn write_frame(
336    &mut self,
337    frame: Frame<'f>,
338  ) -> Result<(), WebSocketError>
339  where
340    S: AsyncWrite + Unpin,
341  {
342    self.write_half.write_frame(&mut self.stream, frame).await
343  }
344
345  pub async fn flush(&mut self) -> Result<(), WebSocketError>
346  where
347    S: AsyncWrite + Unpin,
348  {
349    flush(&mut self.stream).await
350  }
351}
352
353#[inline]
354async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
355where
356  S: AsyncWrite + Unpin,
357{
358  stream.flush().await.map_err(WebSocketError::IoError)
359}
360
361/// WebSocket protocol implementation over an async stream.
362pub struct WebSocket<S> {
363  stream: S,
364  write_half: WriteHalf,
365  read_half: ReadHalf,
366}
367
368impl<'f, S> WebSocket<S> {
369  /// Creates a new `WebSocket` from a stream that has already completed the WebSocket handshake.
370  ///
371  /// Use the `upgrade` feature to handle server upgrades and client handshakes.
372  ///
373  /// # Example
374  ///
375  /// ```
376  /// use tokio::net::TcpStream;
377  /// use fastwebsockets::{WebSocket, OpCode, Role};
378  /// use eyre::Result;
379  ///
380  /// async fn handle_client(
381  ///   socket: TcpStream,
382  /// ) -> Result<()> {
383  ///   let mut ws = WebSocket::after_handshake(socket, Role::Server);
384  ///   // ...
385  ///   Ok(())
386  /// }
387  /// ```
388  pub fn after_handshake(stream: S, role: Role) -> Self
389  where
390    S: AsyncRead + AsyncWrite + Unpin,
391  {
392    Self {
393      stream,
394      write_half: WriteHalf::after_handshake(role),
395      read_half: ReadHalf::after_handshake(role),
396    }
397  }
398
399  /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not
400  /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that
401  /// is returned.
402  #[cfg(feature = "unstable-split")]
403  pub fn split<R, W>(
404    self,
405    split_fn: impl Fn(S) -> (R, W),
406  ) -> (WebSocketRead<R>, WebSocketWrite<W>)
407  where
408    S: AsyncRead + AsyncWrite + Unpin,
409    R: AsyncRead + Unpin,
410    W: AsyncWrite + Unpin,
411  {
412    let (stream, read, write) = self.into_parts_internal();
413    let (r, w) = split_fn(stream);
414    (
415      WebSocketRead {
416        stream: r,
417        read_half: read,
418      },
419      WebSocketWrite {
420        stream: w,
421        write_half: write,
422      },
423    )
424  }
425
426  /// Consumes the `WebSocket` and returns the underlying stream.
427  #[inline]
428  pub fn into_inner(self) -> S {
429    // self.write_half.into_inner().stream
430    self.stream
431  }
432
433  /// Consumes the `WebSocket` and returns the underlying stream.
434  #[inline]
435  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
436    (self.stream, self.read_half, self.write_half)
437  }
438
439  /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
440  ///
441  /// Default: `true`
442  pub fn set_writev(&mut self, vectored: bool) {
443    self.write_half.vectored = vectored;
444  }
445
446  pub fn set_writev_threshold(&mut self, threshold: usize) {
447    self.read_half.writev_threshold = threshold;
448    self.write_half.writev_threshold = threshold;
449  }
450
451  /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
452  ///
453  /// Default: `true`
454  pub fn set_auto_close(&mut self, auto_close: bool) {
455    self.read_half.auto_close = auto_close;
456  }
457
458  /// Sets whether to automatically send a pong frame when a ping frame is received.
459  ///
460  /// Default: `true`
461  pub fn set_auto_pong(&mut self, auto_pong: bool) {
462    self.read_half.auto_pong = auto_pong;
463  }
464
465  /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
466  ///
467  /// Default: 64 MiB
468  pub fn set_max_message_size(&mut self, max_message_size: usize) {
469    self.read_half.max_message_size = max_message_size;
470  }
471
472  /// Sets whether to automatically apply the mask to the frame payload.
473  ///
474  /// Default: `true`
475  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
476    self.read_half.auto_apply_mask = auto_apply_mask;
477    self.write_half.auto_apply_mask = auto_apply_mask;
478  }
479
480  pub fn is_closed(&self) -> bool {
481    self.write_half.closed
482  }
483
484  /// Writes a frame to the stream.
485  ///
486  /// # Example
487  ///
488  /// ```
489  /// use fastwebsockets::{WebSocket, Frame, OpCode};
490  /// use tokio::net::TcpStream;
491  /// use eyre::Result;
492  ///
493  /// async fn send(
494  ///   ws: &mut WebSocket<TcpStream>
495  /// ) -> Result<()> {
496  ///   let mut frame = Frame::binary(vec![0x01, 0x02, 0x03].into());
497  ///   ws.write_frame(frame).await?;
498  ///   Ok(())
499  /// }
500  /// ```
501  pub async fn write_frame(
502    &mut self,
503    frame: Frame<'f>,
504  ) -> Result<(), WebSocketError>
505  where
506    S: AsyncRead + AsyncWrite + Unpin,
507  {
508    self.write_half.write_frame(&mut self.stream, frame).await?;
509    Ok(())
510  }
511
512  /// Flushes the data from the underlying stream.
513  ///
514  /// if the underlying stream is buffered (i.e: `TlsStream<TcpStream>`), it is needed to call flush
515  /// to be sure that the written frame are correctly pushed down to the bottom stream/channel.
516  ///
517  pub async fn flush(&mut self) -> Result<(), WebSocketError>
518  where
519    S: AsyncWrite + Unpin,
520  {
521    flush(&mut self.stream).await
522  }
523
524  /// Reads a frame from the stream.
525  ///
526  /// This method will unmask the frame payload. For fragmented frames, use `FragmentCollector::read_frame`.
527  ///
528  /// Text frames payload is guaranteed to be valid UTF-8.
529  ///
530  /// # Example
531  ///
532  /// ```
533  /// use fastwebsockets::{OpCode, WebSocket, Frame};
534  /// use tokio::net::TcpStream;
535  /// use eyre::Result;
536  ///
537  /// async fn echo(
538  ///   ws: &mut WebSocket<TcpStream>
539  /// ) -> Result<()> {
540  ///   let frame = ws.read_frame().await?;
541  ///   match frame.opcode {
542  ///     OpCode::Text | OpCode::Binary => {
543  ///       ws.write_frame(frame).await?;
544  ///     }
545  ///     _ => {}
546  ///   }
547  ///   Ok(())
548  /// }
549  /// ```
550  pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
551  where
552    S: AsyncRead + AsyncWrite + Unpin,
553  {
554    loop {
555      let (res, obligated_send) =
556        self.read_half.read_frame_inner(&mut self.stream).await;
557      let is_closed = self.write_half.closed;
558      if let Some(frame) = obligated_send
559        && !is_closed
560      {
561        self.write_half.write_frame(&mut self.stream, frame).await?;
562      }
563      if let Some(frame) = res? {
564        if is_closed && frame.opcode != OpCode::Close {
565          return Err(WebSocketError::ConnectionClosed);
566        }
567        break Ok(frame);
568      }
569    }
570  }
571}
572
573const MAX_HEADER_SIZE: usize = 14;
574
575impl ReadHalf {
576  pub fn after_handshake(role: Role) -> Self {
577    let buffer = BytesMut::with_capacity(8192);
578
579    Self {
580      role,
581      auto_apply_mask: true,
582      auto_close: true,
583      auto_pong: true,
584      writev_threshold: 1024,
585      max_message_size: 64 << 20,
586      buffer,
587    }
588  }
589
590  /// Attempt to read a single frame from the incoming stream, returning any send obligations if
591  /// `auto_close` or `auto_pong` are enabled. Callers to this function are obligated to send the
592  /// frame in the latter half of the tuple if one is specified, unless the write half of this socket
593  /// has been closed.
594  ///
595  /// XXX: Do not expose this method to the public API.
596  pub(crate) async fn read_frame_inner<'f, S>(
597    &mut self,
598    stream: &mut S,
599  ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
600  where
601    S: AsyncRead + Unpin,
602  {
603    let mut frame = match self.parse_frame_header(stream).await {
604      Ok(frame) => frame,
605      Err(e) => return (Err(e), None),
606    };
607
608    if self.role == Role::Server && self.auto_apply_mask {
609      frame.unmask()
610    };
611
612    match frame.opcode {
613      OpCode::Close if self.auto_close => {
614        match frame.payload.len() {
615          0 => {}
616          1 => return (Err(WebSocketError::InvalidCloseFrame), None),
617          _ => {
618            let code = close::CloseCode::from(u16::from_be_bytes(
619              frame.payload[0..2].try_into().unwrap(),
620            ));
621
622            #[cfg(feature = "simd")]
623            if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
624              return (Err(WebSocketError::InvalidUTF8), None);
625            };
626
627            #[cfg(not(feature = "simd"))]
628            if std::str::from_utf8(&frame.payload[2..]).is_err() {
629              return (Err(WebSocketError::InvalidUTF8), None);
630            };
631
632            if !code.is_allowed() {
633              return (
634                Err(WebSocketError::InvalidCloseCode),
635                Some(Frame::close(1002, &frame.payload[2..])),
636              );
637            }
638          }
639        };
640
641        let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
642        (Ok(Some(frame)), Some(obligated_send))
643      }
644      OpCode::Ping if self.auto_pong => {
645        (Ok(None), Some(Frame::pong(frame.payload)))
646      }
647      OpCode::Text => {
648        if frame.fin && !frame.is_utf8() {
649          (Err(WebSocketError::InvalidUTF8), None)
650        } else {
651          (Ok(Some(frame)), None)
652        }
653      }
654      _ => (Ok(Some(frame)), None),
655    }
656  }
657
658  async fn parse_frame_header<'a, S>(
659    &mut self,
660    stream: &mut S,
661  ) -> Result<Frame<'a>, WebSocketError>
662  where
663    S: AsyncRead + Unpin,
664  {
665    macro_rules! eof {
666      ($n:expr) => {{
667        if $n == 0 {
668          return Err(WebSocketError::UnexpectedEOF);
669        }
670      }};
671    }
672
673    // Read the first two bytes
674    while self.buffer.remaining() < 2 {
675      eof!(stream.read_buf(&mut self.buffer).await?);
676    }
677
678    let fin = self.buffer[0] & 0b10000000 != 0;
679    let rsv1 = self.buffer[0] & 0b01000000 != 0;
680    let rsv2 = self.buffer[0] & 0b00100000 != 0;
681    let rsv3 = self.buffer[0] & 0b00010000 != 0;
682
683    if rsv1 || rsv2 || rsv3 {
684      return Err(WebSocketError::ReservedBitsNotZero);
685    }
686
687    let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
688    let masked = self.buffer[1] & 0b10000000 != 0;
689
690    let length_code = self.buffer[1] & 0x7F;
691    let extra = match length_code {
692      126 => 2,
693      127 => 8,
694      _ => 0,
695    };
696
697    self.buffer.advance(2);
698    while self.buffer.remaining() < extra + masked as usize * 4 {
699      eof!(stream.read_buf(&mut self.buffer).await?);
700    }
701
702    let payload_len: usize = match extra {
703      0 => usize::from(length_code),
704      2 => self.buffer.get_u16() as usize,
705      #[cfg(target_pointer_width = "64")]
706      8 => self.buffer.get_u64() as usize,
707      // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing
708      #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
709      8 => match usize::try_from(self.buffer.get_u64()) {
710        Ok(length) => length,
711        Err(_) => return Err(WebSocketError::FrameTooLarge),
712      },
713      _ => unreachable!(),
714    };
715
716    let mask = if masked {
717      Some(self.buffer.get_u32().to_be_bytes())
718    } else {
719      None
720    };
721
722    if frame::is_control(opcode) && !fin {
723      return Err(WebSocketError::ControlFrameFragmented);
724    }
725
726    if opcode == OpCode::Ping && payload_len > 125 {
727      return Err(WebSocketError::PingFrameTooLarge);
728    }
729
730    if payload_len >= self.max_message_size {
731      return Err(WebSocketError::FrameTooLarge);
732    }
733
734    // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time
735    self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
736    while payload_len > self.buffer.remaining() {
737      eof!(stream.read_buf(&mut self.buffer).await?);
738    }
739
740    // if we read too much it will stay in the buffer, for the next call to this method
741    let payload = self.buffer.split_to(payload_len);
742    let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
743    Ok(frame)
744  }
745}
746
747impl WriteHalf {
748  pub fn after_handshake(role: Role) -> Self {
749    Self {
750      role,
751      closed: false,
752      auto_apply_mask: true,
753      vectored: true,
754      writev_threshold: 1024,
755      write_buffer: Vec::with_capacity(2),
756    }
757  }
758
759  /// Writes a frame to the provided stream.
760  pub async fn write_frame<'a, S>(
761    &'a mut self,
762    stream: &mut S,
763    mut frame: Frame<'a>,
764  ) -> Result<(), WebSocketError>
765  where
766    S: AsyncWrite + Unpin,
767  {
768    if self.role == Role::Client && self.auto_apply_mask {
769      frame.mask();
770    }
771
772    if frame.opcode == OpCode::Close {
773      self.closed = true;
774    } else if self.closed {
775      return Err(WebSocketError::ConnectionClosed);
776    }
777
778    if self.vectored && frame.payload.len() > self.writev_threshold {
779      frame.writev(stream).await?;
780    } else {
781      let text = frame.write(&mut self.write_buffer);
782      stream.write_all(text).await?;
783    }
784
785    Ok(())
786  }
787}
788
789#[cfg(test)]
790mod tests {
791  use super::*;
792
793  const _: () = {
794    const fn assert_unsync<S>() {
795      // Generic trait with a blanket impl over `()` for all types.
796      trait AmbiguousIfImpl<A> {
797        // Required for actually being able to reference the trait.
798        fn some_item() {}
799      }
800
801      impl<T: ?Sized> AmbiguousIfImpl<()> for T {}
802
803      // Used for the specialized impl when *all* traits in
804      // `$($t)+` are implemented.
805      #[allow(dead_code)]
806      struct Invalid;
807
808      impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}
809
810      // If there is only one specialized trait impl, type inference with
811      // `_` can be resolved and this can compile. Fails to compile if
812      // `$x` implements `AmbiguousIfImpl<Invalid>`.
813      let _ = <S as AmbiguousIfImpl<_>>::some_item;
814    }
815    assert_unsync::<WebSocket<tokio::net::TcpStream>>();
816  };
817}