Skip to main content

compio_ws/
lib.rs

1//! WebSocket support based on [`tungstenite`].
2//!
//! This library implements WebSocket handshakes and streams for compio. It is
//! built on the tungstenite crate, which implements all of the required
//! WebSocket protocol logic; this crate integrates that logic with compio's
//! asynchronous I/O.
7//!
8//! Each WebSocket stream implements message reading and writing.
9//!
10//! [`tungstenite`]: https://docs.rs/tungstenite
11
12#![cfg_attr(docsrs, feature(doc_cfg))]
13#![warn(missing_docs)]
14#![deny(rustdoc::broken_intra_doc_links)]
15#![doc(
16    html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
17)]
18#![doc(
19    html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
20)]
21
22use std::io::ErrorKind;
23
24use compio_buf::IntoInner;
25use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
26use tungstenite::{
27    Error as WsError, HandshakeError, Message, WebSocket,
28    client::IntoClientRequest,
29    handshake::server::{Callback, NoCallback},
30    protocol::{CloseFrame, Role, WebSocketConfig},
31};
32
33mod tls;
34pub use tls::*;
35pub use tungstenite;
36
37/// Configuration for compio-ws.
38///
39/// ## API Interface
40///
41/// `_with_config` functions in this crate accept `impl Into<Config>`, so
42/// following are all valid:
43/// - [`Config`]
44/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining
45///   settings)
46/// - [`None`] (use default value)
47pub struct Config {
48    /// WebSocket configuration from tungstenite.
49    websocket: Option<WebSocketConfig>,
50
51    /// Base buffer size
52    buffer_size_base: usize,
53
54    /// Maximum buffer size
55    buffer_size_limit: usize,
56
57    /// Disable Nagle's algorithm. This only affects
58    /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`].
59    disable_nagle: bool,
60}
61
62impl Config {
63    // 128 KiB, see <https://github.com/compio-rs/compio/pull/532>.
64    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
65    // 64 MiB, the same as [`SyncStream`].
66    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
67
68    /// Creates a new `Config` with default settings.
69    pub fn new() -> Self {
70        Self {
71            websocket: None,
72            buffer_size_base: Self::DEFAULT_BUF_SIZE,
73            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
74            disable_nagle: false,
75        }
76    }
77
78    /// Get the WebSocket configuration.
79    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
80        self.websocket.as_ref()
81    }
82
83    /// Get the base buffer size.
84    pub fn buffer_size_base(&self) -> usize {
85        self.buffer_size_base
86    }
87
88    /// Get the maximum buffer size.
89    pub fn buffer_size_limit(&self) -> usize {
90        self.buffer_size_limit
91    }
92
93    /// Set custom base buffer size.
94    ///
95    /// Default to 128 KiB.
96    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
97        self.buffer_size_base = size;
98        self
99    }
100
101    /// Set custom maximum buffer size.
102    ///
103    /// Default to 64 MiB.
104    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
105        self.buffer_size_limit = size;
106        self
107    }
108
109    /// Set custom buffer sizes.
110    ///
111    /// Default to 128 KiB for base and 64 MiB for limit.
112    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
113        self.buffer_size_base = base;
114        self.buffer_size_limit = limit;
115        self
116    }
117
118    /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`.
119    ///
120    /// Default to `false`. If you don't know what the Nagle's algorithm is,
121    /// better leave it to `false`.
122    pub fn disable_nagle(mut self, disable: bool) -> Self {
123        self.disable_nagle = disable;
124        self
125    }
126}
127
128impl Default for Config {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl From<WebSocketConfig> for Config {
135    fn from(config: WebSocketConfig) -> Self {
136        Self {
137            websocket: Some(config),
138            ..Default::default()
139        }
140    }
141}
142
143impl From<Option<WebSocketConfig>> for Config {
144    fn from(config: Option<WebSocketConfig>) -> Self {
145        Self {
146            websocket: config,
147            ..Default::default()
148        }
149    }
150}
151
152/// A WebSocket stream that works with compio.
153#[derive(Debug)]
154pub struct WebSocketStream<S> {
155    inner: WebSocket<SyncStream<S>>,
156}
157
158impl<S> WebSocketStream<S>
159where
160    S: AsyncRead + AsyncWrite,
161{
162    /// Convert a raw socket into a [`WebSocketStream`] without performing a
163    /// handshake.
164    ///
165    /// `disable_nagle` will be ignored since the socket is already connected
166    /// and the user can set `nodelay` on the socket directly before calling
167    /// this function if needed.
168    pub async fn from_raw_socket(stream: S, role: Role, config: impl Into<Config>) -> Self {
169        let config = config.into();
170        let sync_stream =
171            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
172        WebSocketStream {
173            inner: WebSocket::from_raw_socket(sync_stream, role, config.websocket),
174        }
175    }
176
177    /// Convert a raw socket into a [`WebSocketStream`] without performing a
178    /// handshake.
179    ///
180    /// `disable_nagle` will be ignored since the socket is already connected
181    /// and the user can set `nodelay` on the socket directly before calling
182    /// this function if needed.
183    pub async fn from_partially_read(
184        stream: S,
185        part: Vec<u8>,
186        role: Role,
187        config: impl Into<Config>,
188    ) -> Self {
189        let config = config.into();
190        let sync_stream =
191            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
192        WebSocketStream {
193            inner: WebSocket::from_partially_read(sync_stream, part, role, config.websocket),
194        }
195    }
196
197    /// Send a message on the WebSocket stream.
198    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
199        // Send the message - this buffers it
200        // Since CompioStream::flush() now returns Ok, this should succeed on first try
201        self.inner.send(message)?;
202
203        // flush the buffer to the network
204        self.flush().await
205    }
206
207    /// Read a message from the WebSocket stream.
208    pub async fn read(&mut self) -> Result<Message, WsError> {
209        loop {
210            match self.inner.read() {
211                Ok(msg) => {
212                    self.flush().await?;
213                    return Ok(msg);
214                }
215                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
216                    // Need more data - fill the read buffer
217                    self.inner
218                        .get_mut()
219                        .fill_read_buf()
220                        .await
221                        .map_err(WsError::Io)?;
222                }
223                Err(e) => {
224                    let _ = self.flush().await;
225                    return Err(e);
226                }
227            }
228        }
229    }
230
231    /// Flush the WebSocket stream.
232    pub async fn flush(&mut self) -> Result<(), WsError> {
233        loop {
234            match self.inner.flush() {
235                Ok(()) => break,
236                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
237                    self.inner
238                        .get_mut()
239                        .flush_write_buf()
240                        .await
241                        .map_err(WsError::Io)?;
242                }
243                Err(WsError::ConnectionClosed) => break,
244                Err(e) => return Err(e),
245            }
246        }
247        self.inner
248            .get_mut()
249            .flush_write_buf()
250            .await
251            .map_err(WsError::Io)?;
252        Ok(())
253    }
254
255    /// Close the WebSocket connection.
256    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
257        loop {
258            match self.inner.close(close_frame.clone()) {
259                Ok(()) => break,
260                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
261                    let sync_stream = self.inner.get_mut();
262
263                    let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
264
265                    if flushed == 0 {
266                        sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
267                    }
268                }
269                Err(WsError::ConnectionClosed) => break,
270                Err(e) => return Err(e),
271            }
272        }
273        self.flush().await
274    }
275
276    /// Get a reference to the underlying stream.
277    pub fn get_ref(&self) -> &S {
278        self.inner.get_ref().get_ref()
279    }
280
281    /// Get a mutable reference to the underlying stream.
282    pub fn get_mut(&mut self) -> &mut S {
283        self.inner.get_mut().get_mut()
284    }
285}
286
287impl<S> IntoInner for WebSocketStream<S> {
288    type Inner = WebSocket<SyncStream<S>>;
289
290    fn into_inner(self) -> Self::Inner {
291        self.inner
292    }
293}
294
295/// Accepts a new WebSocket connection with the provided stream.
296///
297/// This function will internally create a handshake representation and returns
298/// a future representing the resolution of the WebSocket handshake. The
299/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
300/// depending on if it's successful or not.
301///
302/// This is typically used after a socket has been accepted from a
303/// `TcpListener`. That socket is then passed to this function to perform
304/// the server half of accepting a client's websocket connection.
305pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
306where
307    S: AsyncRead + AsyncWrite,
308{
309    accept_hdr_async(stream, NoCallback).await
310}
311
312/// Similar to [`accept_async()`] but user can specify a [`Config`].
313pub async fn accept_async_with_config<S>(
314    stream: S,
315    config: impl Into<Config>,
316) -> Result<WebSocketStream<S>, WsError>
317where
318    S: AsyncRead + AsyncWrite,
319{
320    accept_hdr_with_config_async(stream, NoCallback, config).await
321}
322/// Accepts a new WebSocket connection with the provided stream.
323///
324/// This function does the same as [`accept_async()`] but accepts an extra
325/// callback for header processing. The callback receives headers of the
326/// incoming requests and is able to add extra headers to the reply.
327pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
328where
329    S: AsyncRead + AsyncWrite,
330    C: Callback,
331{
332    accept_hdr_with_config_async(stream, callback, None).await
333}
334
335/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`].
336pub async fn accept_hdr_with_config_async<S, C>(
337    stream: S,
338    callback: C,
339    config: impl Into<Config>,
340) -> Result<WebSocketStream<S>, WsError>
341where
342    S: AsyncRead + AsyncWrite,
343    C: Callback,
344{
345    let config = config.into();
346    let sync_stream =
347        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
348    let mut handshake_result =
349        tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
350
351    loop {
352        match handshake_result {
353            Ok(mut websocket) => {
354                websocket
355                    .get_mut()
356                    .flush_write_buf()
357                    .await
358                    .map_err(WsError::Io)?;
359                return Ok(WebSocketStream { inner: websocket });
360            }
361            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
362                let sync_stream = mid_handshake.get_mut().get_mut();
363
364                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
365
366                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
367
368                handshake_result = mid_handshake.handshake();
369            }
370            Err(HandshakeError::Failure(error)) => {
371                return Err(error);
372            }
373        }
374    }
375}
376
377/// Creates a WebSocket handshake from a request and a stream.
378///
379/// For convenience, the user may call this with a url string, a URL,
380/// or a `Request`. Calling with `Request` allows the user to add
381/// a WebSocket protocol or other custom headers.
382///
383/// Internally, this creates a handshake representation and returns
384/// a future representing the resolution of the WebSocket handshake. The
385/// returned future will resolve to either [`WebSocketStream<S>`] or [`WsError`]
386/// depending on whether the handshake is successful.
387///
388/// This is typically used for clients who have already established, for
389/// example, a TCP connection to the remote server.
390pub async fn client_async<R, S>(
391    request: R,
392    stream: S,
393) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
394where
395    R: IntoClientRequest,
396    S: AsyncRead + AsyncWrite,
397{
398    client_async_with_config(request, stream, None).await
399}
400
401/// Similar to [`client_async()`] but user can specify a [`Config`].
402pub async fn client_async_with_config<R, S>(
403    request: R,
404    stream: S,
405    config: impl Into<Config>,
406) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
407where
408    R: IntoClientRequest,
409    S: AsyncRead + AsyncWrite,
410{
411    let config = config.into();
412    let sync_stream =
413        SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
414    let mut handshake_result =
415        tungstenite::client::client_with_config(request, sync_stream, config.websocket);
416
417    loop {
418        match handshake_result {
419            Ok((mut websocket, response)) => {
420                // Ensure any remaining data is flushed
421                websocket
422                    .get_mut()
423                    .flush_write_buf()
424                    .await
425                    .map_err(WsError::Io)?;
426                return Ok((WebSocketStream { inner: websocket }, response));
427            }
428            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
429                let sync_stream = mid_handshake.get_mut().get_mut();
430
431                // For handshake: always try both operations
432                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
433
434                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
435
436                handshake_result = mid_handshake.handshake();
437            }
438            Err(HandshakeError::Failure(error)) => {
439                return Err(error);
440            }
441        }
442    }
443}