compio_ws/
lib.rs

1//! Async WebSocket support for compio.
2//!
3//! This library is an implementation of WebSocket handshakes and streams for
4//! compio. It is based on the tungstenite crate which implements all required
//! WebSocket protocol logic. This crate integrates it with the compio async
//! runtime.
7//!
8//! Each WebSocket stream implements message reading and writing.
9
10pub mod stream;
11
12#[cfg(feature = "rustls")]
13pub mod rustls;
14
15use std::io::ErrorKind;
16
17use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
18use tungstenite::{
19    Error as WsError, HandshakeError, Message, WebSocket,
20    client::IntoClientRequest,
21    handshake::server::{Callback, NoCallback},
22    protocol::CloseFrame,
23};
24pub use tungstenite::{
25    Message as WebSocketMessage, error::Error as TungsteniteError, handshake::client::Response,
26    protocol::WebSocketConfig,
27};
28
29#[cfg(feature = "rustls")]
30pub use crate::rustls::{
31    AutoStream, ConnectStream, Connector, client_async_tls, client_async_tls_with_config,
32    client_async_tls_with_connector, client_async_tls_with_connector_and_config, connect_async,
33    connect_async_with_config, connect_async_with_tls_connector,
34    connect_async_with_tls_connector_and_config,
35};
36
/// A WebSocket stream driven by compio's async I/O.
///
/// Wraps a synchronous tungstenite [`WebSocket`] that runs over a
/// [`SyncStream`] adapter around `S`. The async methods on this type move
/// bytes between the adapter's buffers and the underlying stream.
pub struct WebSocketStream<S> {
    // tungstenite protocol state layered on the buffered sync adapter around `S`.
    inner: WebSocket<SyncStream<S>>,
}
40
41impl<S> WebSocketStream<S>
42where
43    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
44{
45    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
46        // Send the message - this buffers it
47        // Since CompioStream::flush() now returns Ok, this should succeed on first try
48        self.inner.send(message)?;
49
50        // flush the buffer to the network
51        self.inner
52            .get_mut()
53            .flush_write_buf()
54            .await
55            .map_err(WsError::Io)?;
56
57        Ok(())
58    }
59
60    pub async fn read(&mut self) -> Result<Message, WsError> {
61        loop {
62            match self.inner.read() {
63                Ok(msg) => {
64                    // Always try to flush after read (close frames need this)
65                    let _ = self.inner.get_mut().flush_write_buf().await;
66                    return Ok(msg);
67                }
68                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
69                    // Need more data - fill the read buffer
70                    self.inner
71                        .get_mut()
72                        .fill_read_buf()
73                        .await
74                        .map_err(WsError::Io)?;
75                    // Retry the read
76                    continue;
77                }
78                Err(e) => {
79                    // Always try to flush on error (close frames)
80                    let _ = self.inner.get_mut().flush_write_buf().await;
81                    return Err(e);
82                }
83            }
84        }
85    }
86
87    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
88        loop {
89            match self.inner.close(close_frame.clone()) {
90                Ok(()) => return Ok(()),
91                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
92                    let sync_stream = self.inner.get_mut();
93
94                    let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
95
96                    if flushed == 0 {
97                        sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
98                    }
99                    continue;
100                }
101                Err(e) => return Err(e),
102            }
103        }
104    }
105
106    pub fn get_ref(&self) -> &S {
107        self.inner.get_ref().get_ref()
108    }
109
110    pub fn get_mut(&mut self) -> &mut S {
111        self.inner.get_mut().get_mut()
112    }
113
114    pub fn get_inner(self) -> WebSocket<SyncStream<S>> {
115        self.inner
116    }
117}
118
119/// Accepts a new WebSocket connection with the provided stream.
120///
121/// This function will internally call `server::accept` to create a
122/// handshake representation and returns a future representing the
123/// resolution of the WebSocket handshake. The returned future will resolve
124/// to either `WebSocketStream<S>` or `Error` depending if it's successful
125/// or not.
126///
127/// This is typically used after a socket has been accepted from a
128/// `TcpListener`. That socket is then passed to this function to perform
129/// the server half of accepting a client's websocket connection.
130pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
131where
132    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
133{
134    accept_hdr_async(stream, NoCallback).await
135}
136
137/// The same as `accept_async()` but the one can specify a websocket
138/// configuration. Please refer to `accept_async()` for more details.
139pub async fn accept_async_with_config<S>(
140    stream: S,
141    config: Option<WebSocketConfig>,
142) -> Result<WebSocketStream<S>, WsError>
143where
144    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
145{
146    accept_hdr_with_config_async(stream, NoCallback, config).await
147}
148/// Accepts a new WebSocket connection with the provided stream.
149///
150/// This function does the same as `accept_async()` but accepts an extra
151/// callback for header processing. The callback receives headers of the
152/// incoming requests and is able to add extra headers to the reply.
153pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
154where
155    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
156    C: Callback,
157{
158    accept_hdr_with_config_async(stream, callback, None).await
159}
160
161/// The same as `accept_hdr_async()` but the one can specify a websocket
162/// configuration. Please refer to `accept_hdr_async()` for more details.
163pub async fn accept_hdr_with_config_async<S, C>(
164    stream: S,
165    callback: C,
166    config: Option<WebSocketConfig>,
167) -> Result<WebSocketStream<S>, WsError>
168where
169    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
170    C: Callback,
171{
172    let sync_stream = SyncStream::new(stream);
173    let mut handshake_result = tungstenite::accept_hdr_with_config(sync_stream, callback, config);
174
175    loop {
176        match handshake_result {
177            Ok(mut websocket) => {
178                websocket
179                    .get_mut()
180                    .flush_write_buf()
181                    .await
182                    .map_err(WsError::Io)?;
183                return Ok(WebSocketStream { inner: websocket });
184            }
185            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
186                let sync_stream = mid_handshake.get_mut().get_mut();
187
188                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
189
190                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
191
192                handshake_result = mid_handshake.handshake();
193            }
194            Err(HandshakeError::Failure(error)) => {
195                return Err(error);
196            }
197        }
198    }
199}
200
201/// Creates a WebSocket handshake from a request and a stream.
202///
203/// For convenience, the user may call this with a url string, a URL,
204/// or a `Request`. Calling with `Request` allows the user to add
205/// a WebSocket protocol or other custom headers.
206///
207/// Internally, this creates a handshake representation and returns
208/// a future representing the resolution of the WebSocket handshake. The
209/// returned future will resolve to either `WebSocketStream<S>` or `Error`
210/// depending on whether the handshake is successful.
211///
212/// This is typically used for clients who have already established, for
213/// example, a TCP connection to the remote server.
214pub async fn client_async<R, S>(
215    request: R,
216    stream: S,
217) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
218where
219    R: IntoClientRequest,
220    S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
221{
222    client_async_with_config(request, stream, None).await
223}
224
225/// The same as `client_async()` but the one can specify a websocket
226/// configuration. Please refer to `client_async()` for more details.
227pub async fn client_async_with_config<R, S>(
228    request: R,
229    stream: S,
230    config: Option<WebSocketConfig>,
231) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
232where
233    R: IntoClientRequest,
234    S: AsyncRead + AsyncWrite + Unpin,
235{
236    let sync_stream = SyncStream::new(stream);
237    let mut handshake_result =
238        tungstenite::client::client_with_config(request, sync_stream, config);
239
240    loop {
241        match handshake_result {
242            Ok((mut websocket, response)) => {
243                // Ensure any remaining data is flushed
244                websocket
245                    .get_mut()
246                    .flush_write_buf()
247                    .await
248                    .map_err(WsError::Io)?;
249                return Ok((WebSocketStream { inner: websocket }, response));
250            }
251            Err(HandshakeError::Interrupted(mut mid_handshake)) => {
252                let sync_stream = mid_handshake.get_mut().get_mut();
253
254                // For handshake: always try both operations
255                sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
256
257                sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
258
259                handshake_result = mid_handshake.handshake();
260            }
261            Err(HandshakeError::Failure(error)) => {
262                return Err(error);
263            }
264        }
265    }
266}
267
/// Returns the host name from `request`'s URI as an owned `String`.
///
/// An IPv6 literal written with surrounding brackets (e.g. `[::1]`) is
/// returned without the brackets, since they are not part of the address.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` if the URI has no
/// host.
#[inline]
// tungstenite::Error is large; it is returned unboxed to match the error type
// used by the rest of this crate's API.
#[allow(clippy::result_large_err)]
#[cfg(feature = "rustls")]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or_else(|| {
        tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName)
    })?;

    // Strip the brackets around an IPv6 literal. Matching both ends explicitly,
    // instead of slicing by index, cannot panic even on a malformed host; for
    // valid URIs the result is identical.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}