fastwebsockets_monoio/
lib.rs

1// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! _fastwebsockets_ is a minimal, fast WebSocket server implementation.
16//!
//! [https://github.com/denoland/fastwebsockets](https://github.com/denoland/fastwebsockets)
19//!
20//! Passes the _Autobahn|TestSuite_ and fuzzed with LLVM's _libfuzzer_.
21//!
22//! You can use it as a raw websocket frame parser and deal with spec compliance yourself, or you can use it as a full-fledged websocket server.
23//!
24//! # Example
25//!
26//! ```
27//! use tokio::net::TcpStream;
28//! use fastwebsockets_monoio::{WebSocket, OpCode, Role};
29//! use anyhow::Result;
30//!
31//! async fn handle(
32//!   socket: TcpStream,
33//! ) -> Result<()> {
34//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
35//!   ws.set_writev(false);
36//!   ws.set_auto_close(true);
37//!   ws.set_auto_pong(true);
38//!
39//!   loop {
40//!     let frame = ws.read_frame().await?;
41//!     match frame.opcode {
42//!       OpCode::Close => break,
43//!       OpCode::Text | OpCode::Binary => {
44//!         ws.write_frame(frame).await?;
45//!       }
46//!       _ => {}
47//!     }
48//!   }
49//!   Ok(())
50//! }
51//! ```
52//!
53//! ## Fragmentation
54//!
//! By default, fastwebsockets will give the application raw frames with FIN set. Other
//! crates like tungstenite will give you a single message with all the frames
//! concatenated.
58//!
//! For concatenated frames, use `FragmentCollector`:
60//! ```
61//! use fastwebsockets_monoio::{FragmentCollector, WebSocket, Role};
62//! use tokio::net::TcpStream;
63//! use anyhow::Result;
64//!
65//! async fn handle(
66//!   socket: TcpStream,
67//! ) -> Result<()> {
68//!   let mut ws = WebSocket::after_handshake(socket, Role::Server);
69//!   let mut ws = FragmentCollector::new(ws);
70//!   let incoming = ws.read_frame().await?;
71//!   // Always returns full messages
72//!   assert!(incoming.fin);
73//!   Ok(())
74//! }
75//! ```
76//!
77//! _permessage-deflate is not supported yet._
78//!
79//! ## HTTP Upgrades
80//!
81//! Enable the `upgrade` feature to do server-side upgrades and client-side
82//! handshakes.
83//!
84//! This feature is powered by [hyper](https://docs.rs/hyper).
85//!
86//! ```
87//! use fastwebsockets_monoio::upgrade::upgrade;
88//! use http_body_util::Empty;
89//! use hyper::{Request, body::{Incoming, Bytes}, Response};
90//! use anyhow::Result;
91//!
92//! async fn server_upgrade(
//!   mut req: Request<Incoming>,
//! ) -> Result<Response<Empty<Bytes>>> {
97//!   let (response, fut) = upgrade(&mut req)?;
98//!
99//!   tokio::spawn(async move {
100//!     let ws = fut.await;
101//!     // Do something with the websocket
102//!   });
103//!
104//!   Ok(response)
105//! }
106//! ```
107//!
108//! Use the `handshake` module for client-side handshakes.
109//!
110//! ```
111//! use fastwebsockets_monoio::handshake;
112//! use fastwebsockets_monoio::FragmentCollector;
113//! use hyper::{Request, body::Bytes, upgrade::Upgraded, header::{UPGRADE, CONNECTION}};
114//! use http_body_util::Empty;
115//! use hyper_util::rt::TokioIo;
116//! use tokio::net::TcpStream;
117//! use std::future::Future;
118//! use anyhow::Result;
119//!
//! async fn connect() -> Result<FragmentCollector<TokioIo<Upgraded>>> {
122//!   let stream = TcpStream::connect("localhost:9001").await?;
123//!
124//!   let req = Request::builder()
125//!     .method("GET")
126//!     .uri("http://localhost:9001/")
127//!     .header("Host", "localhost:9001")
128//!     .header(UPGRADE, "websocket")
129//!     .header(CONNECTION, "upgrade")
130//!     .header(
131//!       "Sec-WebSocket-Key",
132//!       fastwebsockets_monoio::handshake::generate_key(),
133//!     )
134//!     .header("Sec-WebSocket-Version", "13")
//!     .body(Empty::<Bytes>::new())?;
137//!
138//!   let (ws, _) = handshake::client(&SpawnExecutor, req, stream).await?;
139//!   Ok(FragmentCollector::new(ws))
140//! }
141//!
142//! // Tie hyper's executor to tokio runtime
143//! struct SpawnExecutor;
144//!
145//! impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
146//! where
147//!   Fut: Future + Send + 'static,
148//!   Fut::Output: Send + 'static,
149//! {
150//!   fn execute(&self, fut: Fut) {
151//!     tokio::task::spawn(fut);
152//!   }
153//! }
154//! ```
155
156#![cfg_attr(docsrs, feature(doc_cfg))]
157
158mod close;
159mod error;
160mod fragment;
161mod frame;
162/// Client handshake.
163#[cfg(feature = "upgrade")]
164#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
165pub mod handshake;
166mod mask;
167/// HTTP upgrades.
168#[cfg(feature = "upgrade")]
169#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
170pub mod upgrade;
171
use bytes::Buf;
use bytes::BytesMut;

#[cfg(feature = "unstable-split")]
use std::future::Future;

use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
185pub use crate::close::CloseCode;
186pub use crate::error::WebSocketError;
187pub use crate::fragment::FragmentCollector;
188#[cfg(feature = "unstable-split")]
189pub use crate::fragment::FragmentCollectorRead;
190pub use crate::frame::Frame;
191pub use crate::frame::OpCode;
192pub use crate::frame::Payload;
193pub use crate::mask::unmask;
194
/// Which side of the WebSocket connection this endpoint plays.
///
/// The role selects the masking direction when automatic masking is enabled:
/// a `Client` masks the frames it writes, while a `Server` unmasks the frames
/// it reads.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Role {
  /// The server side of the connection.
  Server,
  /// The client side of the connection.
  Client,
}
200
/// Write-side state of a WebSocket connection.
///
/// The stream itself is not stored here; it is passed to `write_frame` by the caller.
pub(crate) struct WriteHalf {
  /// Endpoint role; a `Client` masks outgoing frames when `auto_apply_mask` is set.
  role: Role,
  /// Set once a close frame has been written; later non-close writes are rejected.
  closed: bool,
  /// Whether payloads longer than `writev_threshold` are written with `Frame::writev`.
  vectored: bool,
  /// Whether outgoing frames are masked automatically (only takes effect for `Role::Client`).
  auto_apply_mask: bool,
  /// Payload length above which a vectored write is used, provided `vectored` is set.
  writev_threshold: usize,
  /// Scratch buffer handed to `Frame::write` for non-vectored writes.
  write_buffer: Vec<u8>,
}
209
/// Read-side state of a WebSocket connection.
///
/// The stream itself is not stored here; it is passed to `read_frame_inner` by the caller.
pub(crate) struct ReadHalf {
  /// Endpoint role; a `Server` unmasks incoming frames when `auto_apply_mask` is set.
  role: Role,
  /// Whether incoming frames are unmasked automatically (only takes effect for `Role::Server`).
  auto_apply_mask: bool,
  /// Whether a received close frame is validated and answered with a close frame.
  auto_close: bool,
  /// Whether a received ping frame is answered with a pong frame.
  auto_pong: bool,
  /// Stored by the `set_writev_threshold` setters; not read by the read path in this file.
  writev_threshold: usize,
  /// Frames whose payload length reaches this value are rejected as too large.
  max_message_size: usize,
  /// Bytes read from the stream that have not been consumed by a frame yet.
  buffer: BytesMut,
}
219
/// Read half of a split WebSocket: a readable stream plus the read-side protocol state.
#[cfg(feature = "unstable-split")]
pub struct WebSocketRead<S> {
  /// Stream that frames are read from.
  stream: S,
  /// Read-side protocol state and buffer.
  read_half: ReadHalf,
}
225
/// Write half of a split WebSocket: a writable stream plus the write-side protocol state.
#[cfg(feature = "unstable-split")]
pub struct WebSocketWrite<S> {
  /// Stream that frames are written to.
  stream: S,
  /// Write-side protocol state and scratch buffer.
  write_half: WriteHalf,
}
231
/// Builds a `WebSocketRead`/`WebSocketWrite` pair from separate read and write
/// streams that have already completed the WebSocket handshake.
///
/// Both halves are initialised with the default settings for `role`.
#[cfg(feature = "unstable-split")]
pub fn after_handshake_split<R, W>(
  read: R,
  write: W,
  role: Role,
) -> (WebSocketRead<R>, WebSocketWrite<W>)
where
  R: AsyncRead + Unpin,
  W: AsyncWrite + Unpin,
{
  let reader = WebSocketRead {
    stream: read,
    read_half: ReadHalf::after_handshake(role),
  };
  let writer = WebSocketWrite {
    stream: write,
    write_half: WriteHalf::after_handshake(role),
  };
  (reader, writer)
}
254
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketRead<S> {
  /// Consumes the `WebSocketRead`, yielding the underlying stream together with
  /// its read-side state.
  #[inline]
  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
    let Self { stream, read_half } = self;
    (stream, read_half)
  }

  /// Stores the vectored-write threshold on the read-side state.
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    let state = &mut self.read_half;
    state.writev_threshold = threshold;
  }

  /// Controls whether a received close frame is answered automatically. When
  /// disabled, the application has to send close frames itself.
  ///
  /// Default: `true`
  pub fn set_auto_close(&mut self, auto_close: bool) {
    let state = &mut self.read_half;
    state.auto_close = auto_close;
  }

  /// Controls whether a received ping frame is answered with a pong automatically.
  ///
  /// Default: `true`
  pub fn set_auto_pong(&mut self, auto_pong: bool) {
    let state = &mut self.read_half;
    state.auto_pong = auto_pong;
  }

  /// Sets the maximum message size in bytes. A frame that exceeds the limit
  /// makes the read fail with an error.
  ///
  /// Default: 64 MiB
  pub fn set_max_message_size(&mut self, max_message_size: usize) {
    let state = &mut self.read_half;
    state.max_message_size = max_message_size;
  }

  /// Controls whether the mask is applied to frame payloads automatically.
  ///
  /// Default: `true`
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    let state = &mut self.read_half;
    state.auto_apply_mask = auto_apply_mask;
  }

  /// Reads the next frame from the stream.
  ///
  /// Any frame the read side is obliged to send in response (for example an
  /// automatic pong or close reply) is handed to `send_fn` first; if `send_fn`
  /// fails, the error is returned as `WebSocketError::SendError`. Reads that
  /// produce no frame for the caller are skipped and reading continues.
  pub async fn read_frame<R, E>(
    &mut self,
    send_fn: &mut impl FnMut(Frame<'f>) -> R,
  ) -> Result<Frame, WebSocketError>
  where
    S: AsyncRead + Unpin,
    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    R: Future<Output = Result<(), E>>,
  {
    loop {
      let (outcome, reply) =
        self.read_half.read_frame_inner(&mut self.stream).await;

      // The reply is sent before the read result is inspected, so that an
      // obligated frame still goes out when the read itself reported an error.
      if let Some(reply) = reply {
        send_fn(reply)
          .await
          .map_err(|e| WebSocketError::SendError(e.into()))?;
      }

      match outcome? {
        Some(frame) => return Ok(frame),
        None => continue,
      }
    }
  }
}
318
#[cfg(feature = "unstable-split")]
impl<'f, S> WebSocketWrite<S> {
  /// Enables or disables vectored writes. Enabling them does not guarantee
  /// that every write is vectored.
  ///
  /// Default: `true`
  pub fn set_writev(&mut self, vectored: bool) {
    let state = &mut self.write_half;
    state.vectored = vectored;
  }

  /// Sets the payload length above which a vectored write is used.
  pub fn set_writev_threshold(&mut self, threshold: usize) {
    let state = &mut self.write_half;
    state.writev_threshold = threshold;
  }

  /// Controls whether the mask is applied to frame payloads automatically.
  ///
  /// Default: `true`
  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
    let state = &mut self.write_half;
    state.auto_apply_mask = auto_apply_mask;
  }

  /// Returns `true` once a close frame has been written through this half.
  pub fn is_closed(&self) -> bool {
    let state = &self.write_half;
    state.closed
  }

  /// Writes `frame` to the underlying stream.
  pub async fn write_frame(
    &mut self,
    frame: Frame<'f>,
  ) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let Self { stream, write_half } = self;
    write_half.write_frame(stream, frame).await
  }

  /// Flushes the underlying stream.
  pub async fn flush(&mut self) -> Result<(), WebSocketError>
  where
    S: AsyncWrite + Unpin,
  {
    let stream = &mut self.stream;
    flush(stream).await
  }
}
360
361#[inline]
362async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
363where
364  S: AsyncWrite + Unpin,
365{
366  stream.flush().await.map_err(WebSocketError::IoError)
367}
368
369#[cfg(feature = "unstable-split")]
370pub struct WebSocketRead<S> {
371  stream: S,
372  read_half: ReadHalf,
373}
374
375#[cfg(feature = "unstable-split")]
376pub struct WebSocketWrite<S> {
377  stream: S,
378  write_half: WriteHalf,
379}
380
381#[cfg(feature = "unstable-split")]
382/// Create a split `WebSocketRead`/`WebSocketWrite` pair from a stream that has already completed the WebSocket handshake.
383pub fn after_handshake_split<R, W>(
384  read: R,
385  write: W,
386  role: Role,
387) -> (WebSocketRead<R>, WebSocketWrite<W>)
388where
389  R: AsyncRead + Unpin,
390  W: AsyncWrite + Unpin,
391{
392  (
393    WebSocketRead {
394      stream: read,
395      read_half: ReadHalf::after_handshake(role),
396    },
397    WebSocketWrite {
398      stream: write,
399      write_half: WriteHalf::after_handshake(role),
400    },
401  )
402}
403
404#[cfg(feature = "unstable-split")]
405impl<'f, S> WebSocketRead<S> {
406  /// Consumes the `WebSocketRead` and returns the underlying stream.
407  #[inline]
408  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
409    (self.stream, self.read_half)
410  }
411
412  pub fn set_writev_threshold(&mut self, threshold: usize) {
413    self.read_half.writev_threshold = threshold;
414  }
415
416  /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
417  ///
418  /// Default: `true`
419  pub fn set_auto_close(&mut self, auto_close: bool) {
420    self.read_half.auto_close = auto_close;
421  }
422
423  /// Sets whether to automatically send a pong frame when a ping frame is received.
424  ///
425  /// Default: `true`
426  pub fn set_auto_pong(&mut self, auto_pong: bool) {
427    self.read_half.auto_pong = auto_pong;
428  }
429
430  /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
431  ///
432  /// Default: 64 MiB
433  pub fn set_max_message_size(&mut self, max_message_size: usize) {
434    self.read_half.max_message_size = max_message_size;
435  }
436
437  /// Sets whether to automatically apply the mask to the frame payload.
438  ///
439  /// Default: `true`
440  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
441    self.read_half.auto_apply_mask = auto_apply_mask;
442  }
443
444  /// Reads a frame from the stream.
445  pub async fn read_frame<R, E>(
446    &mut self,
447    send_fn: &mut impl FnMut(Frame<'f>) -> R,
448  ) -> Result<Frame, WebSocketError>
449  where
450    S: AsyncRead + Unpin,
451    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
452    R: Future<Output = Result<(), E>>,
453  {
454    loop {
455      let (res, obligated_send) =
456        self.read_half.read_frame_inner(&mut self.stream).await;
457      if let Some(frame) = obligated_send {
458        let res = send_fn(frame).await;
459        res.map_err(|e| WebSocketError::SendError(e.into()))?;
460      }
461      if let Some(frame) = res? {
462        break Ok(frame);
463      }
464    }
465  }
466}
467
468#[cfg(feature = "unstable-split")]
469impl<'f, S> WebSocketWrite<S> {
470  /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
471  ///
472  /// Default: `true`
473  pub fn set_writev(&mut self, vectored: bool) {
474    self.write_half.vectored = vectored;
475  }
476
477  pub fn set_writev_threshold(&mut self, threshold: usize) {
478    self.write_half.writev_threshold = threshold;
479  }
480
481  /// Sets whether to automatically apply the mask to the frame payload.
482  ///
483  /// Default: `true`
484  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
485    self.write_half.auto_apply_mask = auto_apply_mask;
486  }
487
488  pub fn is_closed(&self) -> bool {
489    self.write_half.closed
490  }
491
492  pub async fn write_frame(
493    &mut self,
494    frame: Frame<'f>,
495  ) -> Result<(), WebSocketError>
496  where
497    S: AsyncWrite + Unpin,
498  {
499    self.write_half.write_frame(&mut self.stream, frame).await
500  }
501
502  pub async fn flush(&mut self) -> Result<(), WebSocketError>
503  where
504    S: AsyncWrite + Unpin,
505  {
506    flush(&mut self.stream).await
507  }
508}
509
/// WebSocket protocol implementation over an async stream.
///
/// Bundles the stream with the read-side and write-side protocol state.
pub struct WebSocket<S> {
  /// Underlying stream used for both reading and writing frames.
  stream: S,
  /// Write-side protocol state.
  write_half: WriteHalf,
  /// Read-side protocol state and buffer.
  read_half: ReadHalf,
}
516
517impl<'f, S> WebSocket<S> {
518  /// Creates a new `WebSocket` from a stream that has already completed the WebSocket handshake.
519  ///
520  /// Use the `upgrade` feature to handle server upgrades and client handshakes.
521  ///
522  /// # Example
523  ///
524  /// ```
525  /// use tokio::net::TcpStream;
526  /// use fastwebsockets_monoio::{WebSocket, OpCode, Role};
527  /// use anyhow::Result;
528  ///
529  /// async fn handle_client(
530  ///   socket: TcpStream,
531  /// ) -> Result<()> {
532  ///   let mut ws = WebSocket::after_handshake(socket, Role::Server);
533  ///   // ...
534  ///   Ok(())
535  /// }
536  /// ```
537  pub fn after_handshake(stream: S, role: Role) -> Self
538  where
539    S: AsyncRead + AsyncWrite + Unpin,
540    S: AsyncRead + AsyncWrite + Unpin,
541  {
542    Self {
543      stream,
544      write_half: WriteHalf::after_handshake(role),
545      read_half: ReadHalf::after_handshake(role),
546    }
547  }
548
549  /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not
550  /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that
551  /// is returned.
552  #[cfg(feature = "unstable-split")]
553  pub fn split<R, W>(
554    self,
555    split_fn: impl Fn(S) -> (R, W),
556  ) -> (WebSocketRead<R>, WebSocketWrite<W>)
557  where
558    S: AsyncRead + AsyncWrite + Unpin,
559    R: AsyncRead + Unpin,
560    W: AsyncWrite + Unpin,
561  {
562    let (stream, read, write) = self.into_parts_internal();
563    let (r, w) = split_fn(stream);
564    (
565      WebSocketRead {
566        stream: r,
567        read_half: read,
568      write_half: WriteHalf::after_handshake(role),
569      read_half: ReadHalf::after_handshake(role),
570    }
571    )
572  }
573
574  /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not
575  /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that
576  /// is returned.
577  #[cfg(feature = "unstable-split")]
578  pub fn split<R, W>(
579    self,
580    split_fn: impl Fn(S) -> (R, W),
581  ) -> (WebSocketRead<R>, WebSocketWrite<W>)
582  where
583    S: AsyncRead + AsyncWrite + Unpin,
584    R: AsyncRead + Unpin,
585    W: AsyncWrite + Unpin,
586  {
587    let (stream, read, write) = self.into_parts_internal();
588    let (r, w) = split_fn(stream);
589    (
590      WebSocketRead {
591        stream: r,
592        read_half: read,
593      },
594      WebSocketWrite {
595        stream: w,
596        write_half: write,
597      },
598    )
599  }
600
601  /// Consumes the `WebSocket` and returns the underlying stream.
602  #[inline]
603  pub fn into_inner(self) -> S {
604    // self.write_half.into_inner().stream
605    self.stream
606  }
607
608  /// Consumes the `WebSocket` and returns the underlying stream.
609  #[inline]
610  pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
611    (self.stream, self.read_half, self.write_half)
612  }
613
614  /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used.
615  ///
616  /// Default: `true`
617  pub fn set_writev(&mut self, vectored: bool) {
618    self.write_half.vectored = vectored;
619  }
620
621  pub fn set_writev_threshold(&mut self, threshold: usize) {
622    self.read_half.writev_threshold = threshold;
623    self.write_half.writev_threshold = threshold;
624  }
625
626  /// Sets whether to automatically close the connection when a close frame is received. When set to `false`, the application will have to manually send close frames.
627  ///
628  /// Default: `true`
629  pub fn set_auto_close(&mut self, auto_close: bool) {
630    self.read_half.auto_close = auto_close;
631  }
632
633  /// Sets whether to automatically send a pong frame when a ping frame is received.
634  ///
635  /// Default: `true`
636  pub fn set_auto_pong(&mut self, auto_pong: bool) {
637    self.read_half.auto_pong = auto_pong;
638  }
639
640  /// Sets the maximum message size in bytes. If a message is received that is larger than this, the connection will be closed.
641  ///
642  /// Default: 64 MiB
643  pub fn set_max_message_size(&mut self, max_message_size: usize) {
644    self.read_half.max_message_size = max_message_size;
645  }
646
647  /// Sets whether to automatically apply the mask to the frame payload.
648  ///
649  /// Default: `true`
650  pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
651    self.read_half.auto_apply_mask = auto_apply_mask;
652    self.write_half.auto_apply_mask = auto_apply_mask;
653  }
654
655  pub fn is_closed(&self) -> bool {
656    self.write_half.closed
657  }
658
659
660  /// Writes a frame to the stream.
661  ///
662  /// # Example
663  ///
664  /// ```
665  /// use fastwebsockets_monoio::{WebSocket, Frame, OpCode};
666  /// use tokio::net::TcpStream;
667  /// use anyhow::Result;
668  ///
669  /// async fn send(
670  ///   ws: &mut WebSocket<TcpStream>
671  /// ) -> Result<()> {
672  ///   let mut frame = Frame::binary(vec![0x01, 0x02, 0x03].into());
673  ///   ws.write_frame(frame).await?;
674  ///   Ok(())
675  /// }
676  /// ```
677  pub async fn write_frame(
678    &mut self,
679    frame: Frame<'f>,
680  ) -> Result<(), WebSocketError>
681  where
682    S: AsyncRead + AsyncWrite + Unpin,
683    S: AsyncRead + AsyncWrite + Unpin,
684  {
685    self.write_half.write_frame(&mut self.stream, frame).await?;
686    Ok(())
687  }
688
689  /// Flushes the data from the underlying stream.
690  ///
691  /// if the underlying stream is buffered (i.e: TlsStream<TcpStream>), it is needed to call flush
692  /// to be sure that the written frame are correctly pushed down to the bottom stream/channel.
693  ///
694  pub async fn flush(&mut self) -> Result<(), WebSocketError>
695  where
696    S: AsyncWrite + Unpin,
697  {
698    flush(&mut self.stream).await
699  }
700
701  /// Flushes the data from the underlying stream.
702  ///
703  /// if the underlying stream is buffered (i.e: TlsStream<TcpStream>), it is needed to call flush
704  /// to be sure that the written frame are correctly pushed down to the bottom stream/channel.
705  ///
706  /// Reads a frame from the stream.
707  ///
708  /// This method will unmask the frame payload. For fragmented frames, use `FragmentCollector::read_frame`.
709  ///
710  /// Text frames payload is guaranteed to be valid UTF-8.
711  ///
712  /// # Example
713  ///
714  /// ```
715  /// use fastwebsockets_monoio::{OpCode, WebSocket, Frame};
716  /// use tokio::net::TcpStream;
717  /// use anyhow::Result;
718  ///
719  /// async fn echo(
720  ///   ws: &mut WebSocket<TcpStream>
721  /// ) -> Result<()> {
722  ///   let frame = ws.read_frame().await?;
723  ///   match frame.opcode {
724  ///     OpCode::Text | OpCode::Binary => {
725  ///       ws.write_frame(frame).await?;
726  ///     }
727  ///     _ => {}
728  ///   }
729  ///   Ok(())
730  /// }
731  /// ```
732  pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
733  where
734    S: AsyncRead + AsyncWrite + Unpin,
735    S: AsyncRead + AsyncWrite + Unpin,
736  {
737    loop {
738      let (res, obligated_send) =
739        self.read_half.read_frame_inner(&mut self.stream).await;
740      let is_closed = self.write_half.closed;
741      if let Some(frame) = obligated_send {
742        if !is_closed {
743          self.write_half.write_frame(&mut self.stream, frame).await?;
744        }
745      }
746      if let Some(frame) = res? {
747        if is_closed && frame.opcode != OpCode::Close {
748          return Err(WebSocketError::ConnectionClosed);
749        }
750        break Ok(frame);
751      }
752    }
753  }
754}
755
/// Largest possible frame header in bytes: 2 fixed bytes, up to 8 bytes of
/// extended payload length and 4 bytes of masking key.
const MAX_HEADER_SIZE: usize = 14;
757
impl ReadHalf {
  /// Creates the read-side state for a connection whose handshake is complete.
  ///
  /// Defaults: automatic masking, close and pong handling enabled, a 1024-byte
  /// writev threshold, a 64 MiB maximum message size and an 8 KiB read buffer.
  pub fn after_handshake(role: Role) -> Self {
    let buffer = BytesMut::with_capacity(8192);

    Self {
      role,
      auto_apply_mask: true,
      auto_close: true,
      auto_pong: true,
      writev_threshold: 1024,
      max_message_size: 64 << 20,
      buffer,
    }
  }

  /// Attempt to read a single frame from the incoming stream, returning any send obligations if
  /// `auto_close` or `auto_pong` are enabled. Callers to this function are obligated to send the
  /// frame in the latter half of the tuple if one is specified, unless the write half of this socket
  /// has been closed.
  ///
  /// The first tuple element is `Ok(None)` when the frame was consumed internally (a ping that is
  /// answered automatically) and there is nothing to hand to the application.
  ///
  /// XXX: Do not expose this method to the public API.
  pub(crate) async fn read_frame_inner<'f, S>(
    &mut self,
    stream: &mut S,
  ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
  where
    S: AsyncRead + Unpin,
  {
    let mut frame = match self.parse_frame_header(stream).await {
      Ok(frame) => frame,
      Err(e) => return (Err(e), None),
    };

    // Only the server side unmasks what it reads.
    if self.role == Role::Server && self.auto_apply_mask {
      frame.unmask()
    };

    match frame.opcode {
      OpCode::Close if self.auto_close => {
        // A close payload is either empty or a 2-byte code followed by a UTF-8 reason.
        match frame.payload.len() {
          0 => {}
          1 => return (Err(WebSocketError::InvalidCloseFrame), None),
          _ => {
            let code = close::CloseCode::from(u16::from_be_bytes(
              frame.payload[0..2].try_into().unwrap(),
            ));

            #[cfg(feature = "simd")]
            if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
              return (Err(WebSocketError::InvalidUTF8), None);
            };

            #[cfg(not(feature = "simd"))]
            if std::str::from_utf8(&frame.payload[2..]).is_err() {
              return (Err(WebSocketError::InvalidUTF8), None);
            };

            // A disallowed close code is reported as an error, and the caller is
            // still obligated to send a close frame with code 1002.
            if !code.is_allowed() {
              return (
                Err(WebSocketError::InvalidCloseCode),
                Some(Frame::close(1002, &frame.payload[2..])),
              );
            }
          }
        };

        // Reply with a close frame carrying the same payload.
        let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
        (Ok(Some(frame)), Some(obligated_send))
      }
      OpCode::Ping if self.auto_pong => {
        // The ping is not surfaced to the application; only the pong obligation is returned.
        (Ok(None), Some(Frame::pong(frame.payload)))
      }
      OpCode::Text => {
        // Only final text frames are validated here.
        if frame.fin && !frame.is_utf8() {
          (Err(WebSocketError::InvalidUTF8), None)
        } else {
          (Ok(Some(frame)), None)
        }
      }
      _ => (Ok(Some(frame)), None),
    }
  }

  /// Reads and decodes one frame (header and payload) from `stream`, pulling more bytes into the
  /// internal buffer as needed. Despite its name, the returned frame already owns its payload.
  ///
  /// Fails with `UnexpectedEOF` when the stream ends mid-frame, and with a protocol error for
  /// non-zero reserved bits, fragmented control frames, oversized pings or frames whose payload
  /// length reaches `max_message_size`.
  async fn parse_frame_header<'a, S>(
    &mut self,
    stream: &mut S,
  ) -> Result<Frame<'a>, WebSocketError>
  where
    S: AsyncRead + Unpin,
  {
    macro_rules! eof {
      ($n:expr) => {{
        if $n == 0 {
          return Err(WebSocketError::UnexpectedEOF);
        }
      }};
    }

    // Read the first two bytes
    while self.buffer.remaining() < 2 {
      eof!(stream.read_buf(&mut self.buffer).await?);
    }

    let fin = self.buffer[0] & 0b10000000 != 0;
    let rsv1 = self.buffer[0] & 0b01000000 != 0;
    let rsv2 = self.buffer[0] & 0b00100000 != 0;
    let rsv3 = self.buffer[0] & 0b00010000 != 0;

    if rsv1 || rsv2 || rsv3 {
      return Err(WebSocketError::ReservedBitsNotZero);
    }

    let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
    let masked = self.buffer[1] & 0b10000000 != 0;

    // Length codes 126 and 127 announce a 2-byte and an 8-byte extended length field.
    let length_code = self.buffer[1] & 0x7F;
    let extra = match length_code {
      126 => 2,
      127 => 8,
      _ => 0,
    };

    // Wait until the extended length and the optional 4-byte mask are buffered.
    self.buffer.advance(2);
    while self.buffer.remaining() < extra + masked as usize * 4 {
      eof!(stream.read_buf(&mut self.buffer).await?);
    }

    #[allow(unexpected_cfgs)]
    let payload_len: usize = match extra {
      0 => usize::from(length_code),
      2 => self.buffer.get_u16() as usize,
      #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
      8 => self.buffer.get_u64() as usize,
      // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing
      #[cfg(any(
        target_pointer_width = "8",
        target_pointer_width = "16",
        target_pointer_width = "32"
      ))]
      8 => match usize::try_from(self.buffer.get_u64()) {
        Ok(length) => length,
        Err(_) => return Err(WebSocketError::FrameTooLarge),
      },
      _ => unreachable!(),
    };

    let mask = if masked {
      Some(self.buffer.get_u32().to_be_bytes())
    } else {
      None
    };

    if frame::is_control(opcode) && !fin {
      return Err(WebSocketError::ControlFrameFragmented);
    }

    if opcode == OpCode::Ping && payload_len > 125 {
      return Err(WebSocketError::PingFrameTooLarge);
    }

    if payload_len >= self.max_message_size {
      return Err(WebSocketError::FrameTooLarge);
    }

    // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time
    self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
    while payload_len > self.buffer.remaining() {
      eof!(stream.read_buf(&mut self.buffer).await?);
    }

    // if we read too much it will stay in the buffer, for the next call to this method
    let payload = self.buffer.split_to(payload_len);
    let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
    Ok(frame)
  }
}
936
937impl WriteHalf {
938  pub fn after_handshake(role: Role) -> Self {
939    Self {
940      role,
941      closed: false,
942      auto_apply_mask: true,
943      vectored: true,
944      writev_threshold: 1024,
945      write_buffer: Vec::with_capacity(2),
946    }
947  }
948
949  /// Writes a frame to the provided stream.
950  pub async fn write_frame<'a, S>(
951    &'a mut self,
952    stream: &mut S,
953    mut frame: Frame<'a>,
954  ) -> Result<(), WebSocketError>
955  where
956    S: AsyncWrite + Unpin,
957  {
958    if self.role == Role::Client && self.auto_apply_mask {
959      frame.mask();
960    }
961
962    if frame.opcode == OpCode::Close {
963      self.closed = true;
964    } else if self.closed {
965      return Err(WebSocketError::ConnectionClosed);
966    } 
967
968    if self.vectored && frame.payload.len() > self.writev_threshold {
969      frame.writev(stream).await?;
970    } else {
971      let text = frame.write(&mut self.write_buffer);
972      stream.write_all(text).await?;
973    }
974
975    Ok(())
976  }
977}
978
#[cfg(test)]
mod tests {
  use super::*;

  // Compile-time check intended to assert that `WebSocket<TcpStream>` is not
  // `Sync`, using the "ambiguous impl" trick: if two impls of the probe trait
  // apply, inference of `_` is ambiguous and compilation fails.
  //
  // NOTE(review): the probe is evaluated inside the generic `assert_unsync<S>`,
  // where `S: Sync` is not known, so the second impl may never be considered
  // and the check may pass for every type — confirm.
  const _: () = {
    const fn assert_unsync<S>() {
      // Generic trait with a blanket impl over `()` for all types.
      trait AmbiguousIfImpl<A> {
        // Required for actually being able to reference the trait.
        fn some_item() {}
      }

      impl<T: ?Sized> AmbiguousIfImpl<()> for T {}

      // Marker used for the second impl, which only applies to types that
      // implement `Sync`.
      #[allow(dead_code)]
      struct Invalid;

      impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}

      // If only the blanket impl applies, type inference with `_` can be
      // resolved and this compiles. It is meant to fail to compile when `S`
      // also implements `AmbiguousIfImpl<Invalid>`, i.e. when `S` is `Sync`.
      let _ = <S as AmbiguousIfImpl<_>>::some_item;
    }
    assert_unsync::<WebSocket<tokio::net::TcpStream>>();
  };
}