async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
3//! This crate is based on [tungstenite](https://crates.io/crates/tungstenite)
4//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
9//!  * `async-tls`: Enables the `async_tls` module, which provides integration
10//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
11//!    be used independent of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
28//!    provides.
29//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
32//!    the [gio](https://www.gtk-rs.org) runtime.
33//!
34//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
35//! making the socket a stream of WebSocket messages coming in and going out.
36
37#![deny(
38    missing_docs,
39    unused_must_use,
40    unused_mut,
41    unused_imports,
42    unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51    feature = "async-tls",
52    feature = "async-native-tls",
53    feature = "tokio-native-tls",
54    feature = "tokio-rustls-manual-roots",
55    feature = "tokio-rustls-native-certs",
56    feature = "tokio-rustls-webpki-roots",
57    feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62    io::{Read, Write},
63    pin::Pin,
64    sync::{Arc, Mutex, MutexGuard},
65    task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75    client::IntoClientRequest,
76    handshake::{
77        client::{ClientHandshake, Response},
78        server::{Callback, NoCallback},
79        HandshakeError,
80    },
81};
82use tungstenite::{
83    error::Error as WsError,
84    protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98#[cfg(feature = "futures-03-sink")]
99pub use bytes::ByteWriter;
100
101use tungstenite::protocol::CloseFrame;
102
/// Performs the client side of a WebSocket handshake over an existing stream.
///
/// `request` can be anything implementing `IntoClientRequest`; for
/// convenience this may be a url string, a URL, or a `Request`. Passing a
/// `Request` allows the caller to add a WebSocket protocol or other custom
/// headers.
///
/// The returned future resolves to the connected `WebSocketStream<S>`
/// together with the handshake `Response`, or to an `Error` if the handshake
/// did not succeed.
///
/// This is typically used by clients that have already established, for
/// example, a TCP connection to the remote server.
///
/// Equivalent to calling `client_async_with_config()` with `None` as the
/// configuration.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration is supplied by this entry point.
    let config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, config).await
}
126
/// Like `client_async()`, but additionally accepts a WebSocket configuration.
/// Please refer to `client_async()` for more details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other
/// handshake error is reported as `WsError::Io` with `ErrorKind::Other`,
/// carrying the display text of the original error.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        // Build the request first so that conversion errors surface before
        // the handshake is started.
        let request = request.into_client_request()?;
        ClientHandshake::start(allow_std, request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
152
/// Accepts a new WebSocket connection on the provided stream.
///
/// This performs the server half of the WebSocket handshake by delegating to
/// `accept_hdr_async()` with `NoCallback`. The returned future resolves to
/// either a `WebSocketStream<S>` or an `Error`, depending on whether the
/// handshake succeeded.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`: that socket is passed to this function to accept the
/// client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async(stream, callback).await
}
171
/// Like `accept_async()`, but additionally accepts a WebSocket configuration.
/// Please refer to `accept_async()` for more details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
184
/// Accepts a new WebSocket connection on the provided stream, with a header
/// callback.
///
/// Behaves like `accept_async()`, but takes an extra callback for header
/// processing. The callback receives the headers of the incoming request and
/// is able to add extra headers to the reply.
///
/// Equivalent to calling `accept_hdr_async_with_config()` with `None` as the
/// configuration.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, config).await
}
198
/// Like `accept_hdr_async()`, but additionally accepts a WebSocket
/// configuration. Please refer to `accept_hdr_async()` for more details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other
/// handshake error is reported as `WsError::Io` with `ErrorKind::Other`,
/// carrying the display text of the original error.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match outcome {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
222
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink`. Check more information about
/// them in the `futures-rs` crate documentation or have a look at the examples
/// and unit tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite `WebSocket` operating on the `AllowStd`-wrapped stream.
    inner: WebSocket<AllowStd<S>>,
    /// Set by `poll_close` once a close or flush attempt hit `WouldBlock`;
    /// subsequent `poll_close` calls then flush instead of calling `close`
    /// again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set by `poll_next` after a read returned an error; from then on
    /// `poll_next` yields `None` and `is_terminated` reports `true`.
    ended: bool,
    /// Tungstenite is probably ready to receive more data.
    ///
    /// `false` once start_send hits `WouldBlock` errors.
    /// `true` initially and after `flush`ing.
    ready: bool,
}
244
245impl<S> WebSocketStream<S> {
246    /// Convert a raw socket into a WebSocketStream without performing a
247    /// handshake.
248    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
249    where
250        S: AsyncRead + AsyncWrite + Unpin,
251    {
252        handshake::without_handshake(stream, move |allow_std| {
253            WebSocket::from_raw_socket(allow_std, role, config)
254        })
255        .await
256    }
257
258    /// Convert a raw socket into a WebSocketStream without performing a
259    /// handshake.
260    pub async fn from_partially_read(
261        stream: S,
262        part: Vec<u8>,
263        role: Role,
264        config: Option<WebSocketConfig>,
265    ) -> Self
266    where
267        S: AsyncRead + AsyncWrite + Unpin,
268    {
269        handshake::without_handshake(stream, move |allow_std| {
270            WebSocket::from_partially_read(allow_std, part, role, config)
271        })
272        .await
273    }
274
275    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
276        Self {
277            inner: ws,
278            #[cfg(feature = "futures-03-sink")]
279            closing: false,
280            ended: false,
281            ready: true,
282        }
283    }
284
285    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
286    where
287        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
288        AllowStd<S>: Read + Write,
289    {
290        #[cfg(feature = "verbose-logging")]
291        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
292        if let Some((kind, ctx)) = ctx {
293            self.inner.get_mut().set_waker(kind, ctx.waker());
294        }
295        f(&mut self.inner)
296    }
297
298    /// Returns a shared reference to the inner stream.
299    pub fn get_ref(&self) -> &S
300    where
301        S: AsyncRead + AsyncWrite + Unpin,
302    {
303        self.inner.get_ref().get_ref()
304    }
305
306    /// Returns a mutable reference to the inner stream.
307    pub fn get_mut(&mut self) -> &mut S
308    where
309        S: AsyncRead + AsyncWrite + Unpin,
310    {
311        self.inner.get_mut().get_mut()
312    }
313
314    /// Returns a reference to the configuration of the tungstenite stream.
315    pub fn get_config(&self) -> &WebSocketConfig {
316        self.inner.get_config()
317    }
318
319    /// Close the underlying web socket
320    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
321    where
322        S: AsyncRead + AsyncWrite + Unpin,
323    {
324        self.send(Message::Close(msg)).await
325    }
326
327    /// Splits the websocket stream into separate
328    /// [sender](WebSocketSender) and [receiver](WebSocketReceiver) parts.
329    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
330        let shared = Arc::new(Shared(Mutex::new(self)));
331        let sender = WebSocketSender {
332            shared: shared.clone(),
333        };
334
335        let receiver = WebSocketReceiver { shared };
336        (sender, receiver)
337    }
338
339    /// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
340    /// parts back into a single stream. If both parts originate from the same
341    /// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
342    /// Otherwise, returns `Err` containing the provided parts.
343    pub fn reunite(
344        sender: WebSocketSender<S>,
345        receiver: WebSocketReceiver<S>,
346    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
347        if sender.is_pair_of(&receiver) {
348            drop(receiver);
349            let stream = Arc::try_unwrap(sender.shared)
350                .ok()
351                .expect("reunite the stream")
352                .into_inner();
353
354            Ok(stream)
355        } else {
356            Err((sender, receiver))
357        }
358    }
359}
360
361impl<T> WebSocketStream<T>
362where
363    T: AsyncRead + AsyncWrite + Unpin,
364{
365    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
366        #[cfg(feature = "verbose-logging")]
367        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
368
369        // The connection has been closed or a critical error has occurred.
370        // We have already returned the error to the user, the `Stream` is unusable,
371        // so we assume that the stream has been "fused".
372        if self.ended {
373            return Poll::Ready(None);
374        }
375
376        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
377            #[cfg(feature = "verbose-logging")]
378            trace!(
379                "{}:{} WebSocketStream.with_context poll_next -> read()",
380                file!(),
381                line!()
382            );
383            cvt(s.read())
384        })) {
385            Ok(v) => Poll::Ready(Some(Ok(v))),
386            Err(e) => {
387                self.ended = true;
388                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
389                    Poll::Ready(None)
390                } else {
391                    Poll::Ready(Some(Err(e)))
392                }
393            }
394        }
395    }
396
397    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
398        if self.ready {
399            return Poll::Ready(Ok(()));
400        }
401
402        // Currently blocked so try to flush the blockage away
403        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
404            .map(|r| {
405                self.ready = true;
406                r
407            })
408    }
409
410    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
411        match self.with_context(None, |s| s.write(item)) {
412            Ok(()) => {
413                self.ready = true;
414                Ok(())
415            }
416            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
417                // the message was accepted and queued so not an error
418                // but `poll_ready` will now start trying to flush the block
419                self.ready = false;
420                Ok(())
421            }
422            Err(e) => {
423                self.ready = true;
424                debug!("websocket start_send error: {}", e);
425                Err(e)
426            }
427        }
428    }
429
430    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
431        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
432            .map(|r| {
433                self.ready = true;
434                match r {
435                    // WebSocket connection has just been closed. Flushing completed, not an error.
436                    Err(WsError::ConnectionClosed) => Ok(()),
437                    other => other,
438                }
439            })
440    }
441
442    #[cfg(feature = "futures-03-sink")]
443    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
444        self.ready = true;
445        let res = if self.closing {
446            // After queueing it, we call `flush` to drive the close handshake to completion.
447            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
448        } else {
449            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
450        };
451
452        match res {
453            Ok(()) => Poll::Ready(Ok(())),
454            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
455            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
456                trace!("WouldBlock");
457                self.closing = true;
458                Poll::Pending
459            }
460            Err(err) => {
461                debug!("websocket close error: {}", err);
462                Poll::Ready(Err(err))
463            }
464        }
465    }
466}
467
468impl<T> Stream for WebSocketStream<T>
469where
470    T: AsyncRead + AsyncWrite + Unpin,
471{
472    type Item = Result<Message, WsError>;
473
474    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
475        self.get_mut().poll_next(cx)
476    }
477}
478
479impl<T> FusedStream for WebSocketStream<T>
480where
481    T: AsyncRead + AsyncWrite + Unpin,
482{
483    fn is_terminated(&self) -> bool {
484        self.ended
485    }
486}
487
/// `Sink` of outgoing WebSocket messages; every method forwards to the
/// inherent method of the same name on `WebSocketStream`.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = Pin::get_mut(self);
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_close(cx)
    }
}
511
512impl<S> WebSocketStream<S> {
513    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
514    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
515    where
516        S: AsyncRead + AsyncWrite + Unpin,
517    {
518        Send {
519            ws: self,
520            msg: Some(msg),
521        }
522        .await
523    }
524}
525
/// Future that sends a single message through a websocket handle `W`.
struct Send<W> {
    /// Handle to the websocket; the `Future` impls in this file use
    /// `&mut WebSocketStream<S>` and `&Shared<S>`.
    ws: W,
    /// The message still to be handed over; taken (left as `None`) by
    /// `send_helper` once it has been passed to `start_send`.
    msg: Option<Message>,
}
530
531/// Performs an asynchronous message send to the websocket.
532fn send_helper<S>(
533    ws: &mut WebSocketStream<S>,
534    msg: &mut Option<Message>,
535    cx: &mut Context<'_>,
536) -> Poll<Result<(), WsError>>
537where
538    S: AsyncRead + AsyncWrite + Unpin,
539{
540    if msg.is_some() {
541        ready!(ws.poll_ready(cx))?;
542        let msg = msg.take().expect("unreachable");
543        ws.start_send(msg)?;
544    }
545
546    ws.poll_flush(cx)
547}
548
549impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
550where
551    S: AsyncRead + AsyncWrite + Unpin,
552{
553    type Output = Result<(), WsError>;
554
555    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
556        let me = self.get_mut();
557        send_helper(me.ws, &mut me.msg, cx)
558    }
559}
560
561impl<S> std::future::Future for Send<&Shared<S>>
562where
563    S: AsyncRead + AsyncWrite + Unpin,
564{
565    type Output = Result<(), WsError>;
566
567    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
568        let me = self.get_mut();
569        let mut ws = me.ws.lock();
570        send_helper(&mut ws, &mut me.msg, cx)
571    }
572}
573
/// The sender part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`].
#[derive(Debug)]
pub struct WebSocketSender<S> {
    /// The mutex-guarded stream shared with the paired receiver.
    shared: Arc<Shared<S>>,
}
579
580impl<S> WebSocketSender<S> {
581    /// Send a message via [websocket](WebSocketStream).
582    pub async fn send(&self, msg: Message) -> Result<(), WsError>
583    where
584        S: AsyncRead + AsyncWrite + Unpin,
585    {
586        Send {
587            ws: &*self.shared,
588            msg: Some(msg),
589        }
590        .await
591    }
592
593    /// Close the underlying [websocket](WebSocketStream).
594    pub async fn close(&self, msg: Option<CloseFrame>) -> Result<(), WsError>
595    where
596        S: AsyncRead + AsyncWrite + Unpin,
597    {
598        self.send(Message::Close(msg)).await
599    }
600
601    /// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
602    /// were split from the same [websocket](WebSocketStream) stream.
603    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
604        Arc::ptr_eq(&self.shared, &other.shared)
605    }
606}
607
/// `Sink` of outgoing messages for the sender half; every method locks the
/// shared stream and forwards to the corresponding `WebSocketStream` method.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
631
/// The receiver part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`].
#[derive(Debug)]
pub struct WebSocketReceiver<S> {
    /// The mutex-guarded stream shared with the paired sender.
    shared: Arc<Shared<S>>,
}
637
638impl<S> WebSocketReceiver<S> {
639    /// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
640    /// were split from the same [websocket](WebSocketStream) stream.
641    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
642        Arc::ptr_eq(&self.shared, &other.shared)
643    }
644}
645
646impl<S> Stream for WebSocketReceiver<S>
647where
648    S: AsyncRead + AsyncWrite + Unpin,
649{
650    type Item = Result<Message, WsError>;
651
652    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
653        self.shared.lock().poll_next(cx)
654    }
655}
656
657impl<S> FusedStream for WebSocketReceiver<S>
658where
659    S: AsyncRead + AsyncWrite + Unpin,
660{
661    fn is_terminated(&self) -> bool {
662        self.shared.lock().ended
663    }
664}
665
/// The mutex-guarded `WebSocketStream` shared between a `WebSocketSender`
/// and a `WebSocketReceiver` after `WebSocketStream::split`.
#[derive(Debug)]
struct Shared<S>(Mutex<WebSocketStream<S>>);
668
669impl<S> Shared<S> {
670    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
671        self.0.lock().expect("lock shared stream")
672    }
673
674    fn into_inner(self) -> WebSocketStream<S> {
675        self.0.into_inner().expect("get shared stream")
676    }
677}
678
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get a domain from an URL.
///
/// Returns the host of the request URI as an owned string. If the host is
/// enclosed in brackets (as an IPv6 address might be), the brackets are
/// stripped because they are *not* part of a valid IP.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` if the request URI
/// has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Only strip when both brackets are present. Unlike slicing with
    // `host[1..host.len() - 1]`, this cannot panic and never drops a
    // character from a host that lacks the closing bracket.
    let bare = host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(host);

    Ok(bare.to_owned())
}
711
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from an URL.
///
/// An explicit port in the request URI takes precedence. Otherwise the port
/// is derived from the scheme: 443 for `wss` and 80 for `ws`.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::UnsupportedUrlScheme)` if the
/// URI has no explicit port and its scheme is neither `ws` nor `wss`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();

    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
734
#[cfg(test)]
mod tests {
    /// The brackets around an IPv6 host must not end up in the domain.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&request).unwrap();
        assert_eq!(host, "::1");
    }

    /// URIs with an unterminated bracketed host must be rejected when the
    /// request is built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for uri in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(uri.into_client_request().is_err());
        }
    }
}