compio_ws/
lib.rs

1//! Async WebSocket support for compio.
2//!
3//! This library is an implementation of WebSocket handshakes and streams for
4//! compio. It is based on the tungstenite crate which implements all required
5//! WebSocket protocol logic. This crate brings compio support / compio
6//! integration to it.
7//!
8//! Each WebSocket stream implements message reading and writing.
9
10#![cfg_attr(docsrs, feature(doc_cfg))]
11#![warn(missing_docs)]
12
13use std::io::ErrorKind;
14
15use compio_buf::IntoInner;
16use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
17use tungstenite::{
18    Error as WsError, HandshakeError, Message, WebSocket,
19    client::IntoClientRequest,
20    handshake::server::{Callback, NoCallback},
21    protocol::{CloseFrame, WebSocketConfig},
22};
23
24mod tls;
25pub use tls::*;
26pub use tungstenite;
27
28/// Configuration for compio-ws.
29///
30/// ## API Interface
31///
32/// `_with_config` functions in this crate accept `impl Into<Config>`, so
33/// following are all valid:
34/// - [`Config`]
35/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining
36///   settings)
37/// - [`None`] (use default value)
38pub struct Config {
39    /// WebSocket configuration from tungstenite.
40    websocket: Option<WebSocketConfig>,
41
42    /// Base buffer size
43    buffer_size_base: usize,
44
45    /// Maximum buffer size
46    buffer_size_limit: usize,
47
48    /// Disable Nagle's algorithm. This only affects
49    /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`].
50    disable_nagle: bool,
51}
52
53impl Config {
54    // 128 KiB, see <https://github.com/compio-rs/compio/pull/532>.
55    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
56    // 64 MiB, the same as [`SyncStream`].
57    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
58
59    /// Creates a new `Config` with default settings.
60    pub fn new() -> Self {
61        Self {
62            websocket: None,
63            buffer_size_base: Self::DEFAULT_BUF_SIZE,
64            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
65            disable_nagle: false,
66        }
67    }
68
69    /// Get the WebSocket configuration.
70    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
71        self.websocket.as_ref()
72    }
73
74    /// Get the base buffer size.
75    pub fn buffer_size_base(&self) -> usize {
76        self.buffer_size_base
77    }
78
79    /// Get the maximum buffer size.
80    pub fn buffer_size_limit(&self) -> usize {
81        self.buffer_size_limit
82    }
83
84    /// Set custom base buffer size.
85    ///
86    /// Default to 128 KiB.
87    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
88        self.buffer_size_base = size;
89        self
90    }
91
92    /// Set custom maximum buffer size.
93    ///
94    /// Default to 64 MiB.
95    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
96        self.buffer_size_limit = size;
97        self
98    }
99
100    /// Set custom buffer sizes.
101    ///
102    /// Default to 128 KiB for base and 64 MiB for limit.
103    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
104        self.buffer_size_base = base;
105        self.buffer_size_limit = limit;
106        self
107    }
108
109    /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`.
110    ///
111    /// Default to `false`. If you don't know what the Nagle's algorithm is,
112    /// better leave it to `false`.
113    pub fn disable_nagle(mut self, disable: bool) -> Self {
114        self.disable_nagle = disable;
115        self
116    }
117}
118
119impl Default for Config {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl From<WebSocketConfig> for Config {
126    fn from(config: WebSocketConfig) -> Self {
127        Self {
128            websocket: Some(config),
129            ..Default::default()
130        }
131    }
132}
133
134impl From<Option<WebSocketConfig>> for Config {
135    fn from(config: Option<WebSocketConfig>) -> Self {
136        Self {
137            websocket: config,
138            ..Default::default()
139        }
140    }
141}
142
/// A WebSocket stream that works with compio.
///
/// Wraps a tungstenite [`WebSocket`] that runs over a [`SyncStream`] adapter
/// around the async stream `S`.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The synchronous tungstenite socket layered on top of `S`.
    inner: WebSocket<SyncStream<S>>,
}
148
149impl<S> WebSocketStream<S>
150where
151    S: AsyncRead + AsyncWrite,
152{
153    /// Send a message on the WebSocket stream.
154    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
155        // Send the message - this buffers it
156        // Since CompioStream::flush() now returns Ok, this should succeed on first try
157        self.inner.send(message)?;
158
159        // flush the buffer to the network
160        self.flush().await
161    }
162
163    /// Read a message from the WebSocket stream.
164    pub async fn read(&mut self) -> Result<Message, WsError> {
165        loop {
166            match self.inner.read() {
167                Ok(msg) => {
168                    self.flush().await?;
169                    return Ok(msg);
170                }
171                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
172                    // Need more data - fill the read buffer
173                    self.inner
174                        .get_mut()
175                        .fill_read_buf()
176                        .await
177                        .map_err(WsError::Io)?;
178                }
179                Err(e) => {
180                    let _ = self.flush().await;
181                    return Err(e);
182                }
183            }
184        }
185    }
186
187    /// Flush the WebSocket stream.
188    pub async fn flush(&mut self) -> Result<(), WsError> {
189        loop {
190            match self.inner.flush() {
191                Ok(()) => break,
192                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
193                    self.inner
194                        .get_mut()
195                        .flush_write_buf()
196                        .await
197                        .map_err(WsError::Io)?;
198                }
199                Err(WsError::ConnectionClosed) => break,
200                Err(e) => return Err(e),
201            }
202        }
203        self.inner
204            .get_mut()
205            .flush_write_buf()
206            .await
207            .map_err(WsError::Io)?;
208        Ok(())
209    }
210
211    /// Close the WebSocket connection.
212    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
213        loop {
214            match self.inner.close(close_frame.clone()) {
215                Ok(()) => break,
216                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
217                    let sync_stream = self.inner.get_mut();
218
219                    let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
220
221                    if flushed == 0 {
222                        sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
223                    }
224                }
225                Err(WsError::ConnectionClosed) => break,
226                Err(e) => return Err(e),
227            }
228        }
229        self.flush().await
230    }
231
232    /// Get a reference to the underlying stream.
233    pub fn get_ref(&self) -> &S {
234        self.inner.get_ref().get_ref()
235    }
236
237    /// Get a mutable reference to the underlying stream.
238    pub fn get_mut(&mut self) -> &mut S {
239        self.inner.get_mut().get_mut()
240    }
241}
242
243impl<S> IntoInner for WebSocketStream<S> {
244    type Inner = WebSocket<SyncStream<S>>;
245
246    fn into_inner(self) -> Self::Inner {
247        self.inner
248    }
249}
250
251/// Accepts a new WebSocket connection with the provided stream.
252///
253/// This function will internally create a handshake representation and returns
254/// a future representing the resolution of the WebSocket handshake. The
255/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
256/// depending on if it's successful or not.
257///
258/// This is typically used after a socket has been accepted from a
259/// `TcpListener`. That socket is then passed to this function to perform
260/// the server half of accepting a client's websocket connection.
261pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
262where
263    S: AsyncRead + AsyncWrite,
264{
265    accept_hdr_async(stream, NoCallback).await
266}
267
268/// Similar to [`accept_async()`] but user can specify a [`Config`].
269pub async fn accept_async_with_config<S>(
270    stream: S,
271    config: impl Into<Config>,
272) -> Result<WebSocketStream<S>, WsError>
273where
274    S: AsyncRead + AsyncWrite,
275{
276    accept_hdr_with_config_async(stream, NoCallback, config).await
277}
278/// Accepts a new WebSocket connection with the provided stream.
279///
280/// This function does the same as [`accept_async()`] but accepts an extra
281/// callback for header processing. The callback receives headers of the
282/// incoming requests and is able to add extra headers to the reply.
283pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
284where
285    S: AsyncRead + AsyncWrite,
286    C: Callback,
287{
288    accept_hdr_with_config_async(stream, callback, None).await
289}
290
291/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`].
292pub async fn accept_hdr_with_config_async<S, C>(
293    stream: S,
294    callback: C,
295    config: impl Into<Config>,
296) -> Result<WebSocketStream<S>, WsError>
297where
298    S: AsyncRead + AsyncWrite,
299    C: Callback,
300{
301    let config = config.into();
302    let sync_stream =
303        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
304    let mut handshake_result =
305        tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
306
307    loop {
308        match handshake_result {
309            Ok(mut websocket) => {
310                websocket
311                    .get_mut()
312                    .flush_write_buf()
313                    .await
314                    .map_err(WsError::Io)?;
315                return Ok(WebSocketStream { inner: websocket });
316            }
317            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
318                let sync_stream = mid_handshake.get_mut().get_mut();
319
320                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
321
322                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
323
324                handshake_result = mid_handshake.handshake();
325            }
326            Err(HandshakeError::Failure(error)) => {
327                return Err(error);
328            }
329        }
330    }
331}
332
333/// Creates a WebSocket handshake from a request and a stream.
334///
335/// For convenience, the user may call this with a url string, a URL,
336/// or a `Request`. Calling with `Request` allows the user to add
337/// a WebSocket protocol or other custom headers.
338///
339/// Internally, this creates a handshake representation and returns
340/// a future representing the resolution of the WebSocket handshake. The
341/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
342/// depending on whether the handshake is successful.
343///
344/// This is typically used for clients who have already established, for
345/// example, a TCP connection to the remote server.
346pub async fn client_async<R, S>(
347    request: R,
348    stream: S,
349) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
350where
351    R: IntoClientRequest,
352    S: AsyncRead + AsyncWrite,
353{
354    client_async_with_config(request, stream, None).await
355}
356
357/// Similar to [`client_async()`] but user can specify a [`Config`].
358pub async fn client_async_with_config<R, S>(
359    request: R,
360    stream: S,
361    config: impl Into<Config>,
362) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
363where
364    R: IntoClientRequest,
365    S: AsyncRead + AsyncWrite,
366{
367    let config = config.into();
368    let sync_stream =
369        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
370    let mut handshake_result =
371        tungstenite::client::client_with_config(request, sync_stream, config.websocket);
372
373    loop {
374        match handshake_result {
375            Ok((mut websocket, response)) => {
376                // Ensure any remaining data is flushed
377                websocket
378                    .get_mut()
379                    .flush_write_buf()
380                    .await
381                    .map_err(WsError::Io)?;
382                return Ok((WebSocketStream { inner: websocket }, response));
383            }
384            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
385                let sync_stream = mid_handshake.get_mut().get_mut();
386
387                // For handshake: always try both operations
388                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
389
390                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
391
392                handshake_result = mid_handshake.handshake();
393            }
394            Err(HandshakeError::Failure(error)) => {
395                return Err(error);
396            }
397        }
398    }
399}