async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s from the async stack and
//! couple it together with other crates from that ecosystem. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
9//!  * `async-tls`: Enables the `async_tls` module, which provides integration
10//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
11//!    be used independent of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
28//!    provides.
29//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
32//!    the [gio](https://www.gtk-rs.org) runtime.
33//!
//! Each WebSocket stream implements the `Stream` trait (and, with the
//! `futures-03-sink` feature, the `Sink` trait), making the socket a stream of
//! incoming WebSocket messages and a sink for outgoing ones.
36
37#![deny(
38    missing_docs,
39    unused_must_use,
40    unused_mut,
41    unused_imports,
42    unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51    feature = "async-tls",
52    feature = "async-native-tls",
53    feature = "tokio-native-tls",
54    feature = "tokio-rustls-manual-roots",
55    feature = "tokio-rustls-native-certs",
56    feature = "tokio-rustls-webpki-roots",
57    feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62    io::{Read, Write},
63    pin::Pin,
64    sync::{Arc, Mutex, MutexGuard},
65    task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75    client::IntoClientRequest,
76    handshake::{
77        client::{ClientHandshake, Response},
78        server::{Callback, NoCallback},
79        HandshakeError,
80    },
81};
82use tungstenite::{
83    error::Error as WsError,
84    protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
/// Creates a WebSocket client handshake from a request and a stream.
///
/// `request` may be anything implementing `IntoClientRequest`: a URL string,
/// a parsed URL, or a full `Request`. Passing a `Request` lets the caller add
/// a WebSocket sub-protocol or any other custom header.
///
/// The returned future drives the handshake and resolves to the connected
/// `WebSocketStream<S>` together with the server's `Response`, or to an error
/// if the handshake does not succeed.
///
/// This is typically used by clients that have already established a
/// transport, for example a TCP connection to the remote server. It is
/// equivalent to `client_async_with_config` with no configuration.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
125
/// Like `client_async()`, but with an explicit WebSocket configuration.
///
/// `config` is forwarded unchanged to `ClientHandshake::start`. See
/// `client_async()` for the remaining details.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let req = request.into_client_request()?;
        ClientHandshake::start(allow_std, req, config)?.handshake()
    })
    .await;

    outcome.map_err(|handshake_err| match handshake_err {
        HandshakeError::Failure(inner) => inner,
        // Any other handshake outcome is reported as a generic I/O error.
        other => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        )),
    })
}
151
/// Accepts a new WebSocket connection on the provided stream.
///
/// Performs the server side of the WebSocket handshake with no header
/// callback (`NoCallback`) and no explicit configuration. The returned future
/// resolves to the established `WebSocketStream<S>` on success, or to an error
/// if the handshake does not succeed.
///
/// This is typically used on a socket freshly accepted from a `TcpListener`,
/// to complete the server half of a client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
170
/// Like `accept_async()`, but with an explicit WebSocket configuration.
///
/// See `accept_async()` for the remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
183
/// Accepts a new WebSocket connection, running `callback` on the request.
///
/// Behaves like `accept_async()`, except that the callback receives the
/// headers of the incoming request and can add extra headers to the reply.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
197
/// Like `accept_hdr_async()`, but with an explicit WebSocket configuration.
///
/// Both `callback` and `config` are handed to
/// `tungstenite::accept_hdr_with_config`. See `accept_hdr_async()` for the
/// remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await
    .map_err(|handshake_err| match handshake_err {
        HandshakeError::Failure(inner) => inner,
        // Any other handshake outcome is reported as a generic I/O error.
        other => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        )),
    })
}
221
222/// A wrapper around an underlying raw stream which implements the WebSocket
223/// protocol.
224///
225/// A `WebSocketStream<S>` represents a handshake that has been completed
226/// successfully and both the server and the client are ready for receiving
227/// and sending data. Message from a `WebSocketStream<S>` are accessible
228/// through the respective `Stream` and `Sink`. Check more information about
229/// them in `futures-rs` crate documentation or have a look on the examples
230/// and unit tests for this crate.
231#[derive(Debug)]
232pub struct WebSocketStream<S> {
233    inner: WebSocket<AllowStd<S>>,
234    #[cfg(feature = "futures-03-sink")]
235    closing: bool,
236    ended: bool,
237    /// Tungstenite is probably ready to receive more data.
238    ///
239    /// `false` once start_send hits `WouldBlock` errors.
240    /// `true` initially and after `flush`ing.
241    ready: bool,
242}
243
244impl<S> WebSocketStream<S> {
245    /// Convert a raw socket into a WebSocketStream without performing a
246    /// handshake.
247    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248    where
249        S: AsyncRead + AsyncWrite + Unpin,
250    {
251        handshake::without_handshake(stream, move |allow_std| {
252            WebSocket::from_raw_socket(allow_std, role, config)
253        })
254        .await
255    }
256
257    /// Convert a raw socket into a WebSocketStream without performing a
258    /// handshake.
259    pub async fn from_partially_read(
260        stream: S,
261        part: Vec<u8>,
262        role: Role,
263        config: Option<WebSocketConfig>,
264    ) -> Self
265    where
266        S: AsyncRead + AsyncWrite + Unpin,
267    {
268        handshake::without_handshake(stream, move |allow_std| {
269            WebSocket::from_partially_read(allow_std, part, role, config)
270        })
271        .await
272    }
273
274    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275        Self {
276            inner: ws,
277            #[cfg(feature = "futures-03-sink")]
278            closing: false,
279            ended: false,
280            ready: true,
281        }
282    }
283
284    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285    where
286        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
287        AllowStd<S>: Read + Write,
288    {
289        #[cfg(feature = "verbose-logging")]
290        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
291        if let Some((kind, ctx)) = ctx {
292            self.inner.get_mut().set_waker(kind, ctx.waker());
293        }
294        f(&mut self.inner)
295    }
296
297    /// Returns a shared reference to the inner stream.
298    pub fn get_ref(&self) -> &S
299    where
300        S: AsyncRead + AsyncWrite + Unpin,
301    {
302        self.inner.get_ref().get_ref()
303    }
304
305    /// Returns a mutable reference to the inner stream.
306    pub fn get_mut(&mut self) -> &mut S
307    where
308        S: AsyncRead + AsyncWrite + Unpin,
309    {
310        self.inner.get_mut().get_mut()
311    }
312
313    /// Returns a reference to the configuration of the tungstenite stream.
314    pub fn get_config(&self) -> &WebSocketConfig {
315        self.inner.get_config()
316    }
317
318    /// Close the underlying web socket
319    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
320    where
321        S: AsyncRead + AsyncWrite + Unpin,
322    {
323        self.send(Message::Close(msg)).await
324    }
325
326    /// Splits the websocket stream into separate
327    /// [sender](WebSocketSender) and [receiver](WebSocketReceiver) parts.
328    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
329        let shared = Arc::new(Shared(Mutex::new(self)));
330        let sender = WebSocketSender {
331            shared: shared.clone(),
332        };
333
334        let receiver = WebSocketReceiver { shared };
335        (sender, receiver)
336    }
337
338    /// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
339    /// parts back into a single stream. If both parts originate from the same
340    /// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
341    /// Otherwise, returns `Err` containing the provided parts.
342    pub fn reunite(
343        sender: WebSocketSender<S>,
344        receiver: WebSocketReceiver<S>,
345    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
346        if sender.is_pair_of(&receiver) {
347            drop(receiver);
348            let stream = Arc::try_unwrap(sender.shared)
349                .ok()
350                .expect("reunite the stream")
351                .into_inner();
352
353            Ok(stream)
354        } else {
355            Err((sender, receiver))
356        }
357    }
358}
359
360impl<S> WebSocketStream<S>
361where
362    S: AsyncRead + AsyncWrite + Unpin,
363{
364    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
365        #[cfg(feature = "verbose-logging")]
366        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
367
368        // The connection has been closed or a critical error has occurred.
369        // We have already returned the error to the user, the `Stream` is unusable,
370        // so we assume that the stream has been "fused".
371        if self.ended {
372            return Poll::Ready(None);
373        }
374
375        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
376            #[cfg(feature = "verbose-logging")]
377            trace!(
378                "{}:{} WebSocketStream.with_context poll_next -> read()",
379                file!(),
380                line!()
381            );
382            cvt(s.read())
383        })) {
384            Ok(v) => Poll::Ready(Some(Ok(v))),
385            Err(e) => {
386                self.ended = true;
387                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
388                    Poll::Ready(None)
389                } else {
390                    Poll::Ready(Some(Err(e)))
391                }
392            }
393        }
394    }
395
396    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
397        if self.ready {
398            return Poll::Ready(Ok(()));
399        }
400
401        // Currently blocked so try to flush the blockage away
402        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
403            .map(|r| {
404                self.ready = true;
405                r
406            })
407    }
408
409    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
410        match self.with_context(None, |s| s.write(item)) {
411            Ok(()) => {
412                self.ready = true;
413                Ok(())
414            }
415            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
416                // the message was accepted and queued so not an error
417                // but `poll_ready` will now start trying to flush the block
418                self.ready = false;
419                Ok(())
420            }
421            Err(e) => {
422                self.ready = true;
423                debug!("websocket start_send error: {}", e);
424                Err(e)
425            }
426        }
427    }
428
429    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
430        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
431            .map(|r| {
432                self.ready = true;
433                match r {
434                    // WebSocket connection has just been closed. Flushing completed, not an error.
435                    Err(WsError::ConnectionClosed) => Ok(()),
436                    other => other,
437                }
438            })
439    }
440
441    #[cfg(feature = "futures-03-sink")]
442    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
443        self.ready = true;
444        let res = if self.closing {
445            // After queueing it, we call `flush` to drive the close handshake to completion.
446            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
447        } else {
448            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
449        };
450
451        match res {
452            Ok(()) => Poll::Ready(Ok(())),
453            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
454            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
455                trace!("WouldBlock");
456                self.closing = true;
457                Poll::Pending
458            }
459            Err(err) => {
460                debug!("websocket close error: {}", err);
461                Poll::Ready(Err(err))
462            }
463        }
464    }
465}
466
467impl<S> Stream for WebSocketStream<S>
468where
469    S: AsyncRead + AsyncWrite + Unpin,
470{
471    type Item = Result<Message, WsError>;
472
473    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
474        self.get_mut().poll_next(cx)
475    }
476}
477
478impl<S> FusedStream for WebSocketStream<S>
479where
480    S: AsyncRead + AsyncWrite + Unpin,
481{
482    fn is_terminated(&self) -> bool {
483        self.ended
484    }
485}
486
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method unwraps the pin (`Self: Unpin`) and delegates to the
    // inherent implementation of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = Pin::into_inner(self);
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        this.poll_close(cx)
    }
}
510
511#[cfg(not(feature = "futures-03-sink"))]
512impl<S> bytes::private::SealedSender for WebSocketStream<S>
513where
514    S: AsyncRead + AsyncWrite + Unpin,
515{
516    fn poll_write(
517        self: Pin<&mut Self>,
518        cx: &mut Context<'_>,
519        buf: &[u8],
520    ) -> Poll<Result<usize, WsError>> {
521        let me = self.get_mut();
522        ready!(me.poll_ready(cx))?;
523        let len = buf.len();
524        me.start_send(Message::binary(buf.to_owned()))?;
525        Poll::Ready(Ok(len))
526    }
527
528    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
529        self.get_mut().poll_flush(cx)
530    }
531
532    fn poll_close(
533        self: Pin<&mut Self>,
534        cx: &mut Context<'_>,
535        msg: &mut Option<Message>,
536    ) -> Poll<Result<(), WsError>> {
537        let me = self.get_mut();
538        send_helper(me, msg, cx)
539    }
540}
541
542impl<S> WebSocketStream<S> {
543    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
544    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
545    where
546        S: AsyncRead + AsyncWrite + Unpin,
547    {
548        Send {
549            ws: self,
550            msg: Some(msg),
551        }
552        .await
553    }
554}
555
556struct Send<W> {
557    ws: W,
558    msg: Option<Message>,
559}
560
561/// Performs an asynchronous message send to the websocket.
562fn send_helper<S>(
563    ws: &mut WebSocketStream<S>,
564    msg: &mut Option<Message>,
565    cx: &mut Context<'_>,
566) -> Poll<Result<(), WsError>>
567where
568    S: AsyncRead + AsyncWrite + Unpin,
569{
570    if msg.is_some() {
571        ready!(ws.poll_ready(cx))?;
572        let msg = msg.take().expect("unreachable");
573        ws.start_send(msg)?;
574    }
575
576    ws.poll_flush(cx)
577}
578
579impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
580where
581    S: AsyncRead + AsyncWrite + Unpin,
582{
583    type Output = Result<(), WsError>;
584
585    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
586        let me = self.get_mut();
587        send_helper(me.ws, &mut me.msg, cx)
588    }
589}
590
591impl<S> std::future::Future for Send<&Shared<S>>
592where
593    S: AsyncRead + AsyncWrite + Unpin,
594{
595    type Output = Result<(), WsError>;
596
597    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598        let me = self.get_mut();
599        let mut ws = me.ws.lock();
600        send_helper(&mut ws, &mut me.msg, cx)
601    }
602}
603
604/// The sender part of a [websocket](WebSocketStream) stream.
605#[derive(Debug)]
606pub struct WebSocketSender<S> {
607    shared: Arc<Shared<S>>,
608}
609
610impl<S> WebSocketSender<S> {
611    /// Send a message via [websocket](WebSocketStream).
612    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
613    where
614        S: AsyncRead + AsyncWrite + Unpin,
615    {
616        Send {
617            ws: &*self.shared,
618            msg: Some(msg),
619        }
620        .await
621    }
622
623    /// Close the underlying [websocket](WebSocketStream).
624    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
625    where
626        S: AsyncRead + AsyncWrite + Unpin,
627    {
628        self.send(Message::Close(msg)).await
629    }
630
631    /// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
632    /// were split from the same [websocket](WebSocketStream) stream.
633    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
634        Arc::ptr_eq(&self.shared, &other.shared)
635    }
636}
637
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketSender<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Each method locks the shared stream for just this call and delegates
    // to the stream's inherent implementation.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut ws = self.shared.lock();
        ws.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_close(cx)
    }
}
661
662#[cfg(not(feature = "futures-03-sink"))]
663impl<S> bytes::private::SealedSender for WebSocketSender<S>
664where
665    S: AsyncRead + AsyncWrite + Unpin,
666{
667    fn poll_write(
668        self: Pin<&mut Self>,
669        cx: &mut Context<'_>,
670        buf: &[u8],
671    ) -> Poll<Result<usize, WsError>> {
672        let me = self.get_mut();
673        let mut ws = me.shared.lock();
674        ready!(ws.poll_ready(cx))?;
675        let len = buf.len();
676        ws.start_send(Message::binary(buf.to_owned()))?;
677        Poll::Ready(Ok(len))
678    }
679
680    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
681        self.shared.lock().poll_flush(cx)
682    }
683
684    fn poll_close(
685        self: Pin<&mut Self>,
686        cx: &mut Context<'_>,
687        msg: &mut Option<Message>,
688    ) -> Poll<Result<(), WsError>> {
689        let me = self.get_mut();
690        let mut ws = me.shared.lock();
691        send_helper(&mut ws, msg, cx)
692    }
693}
694
695/// The receiver part of a [websocket](WebSocketStream) stream.
696#[derive(Debug)]
697pub struct WebSocketReceiver<S> {
698    shared: Arc<Shared<S>>,
699}
700
701impl<S> WebSocketReceiver<S> {
702    /// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
703    /// were split from the same [websocket](WebSocketStream) stream.
704    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
705        Arc::ptr_eq(&self.shared, &other.shared)
706    }
707}
708
709impl<S> Stream for WebSocketReceiver<S>
710where
711    S: AsyncRead + AsyncWrite + Unpin,
712{
713    type Item = Result<Message, WsError>;
714
715    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
716        self.shared.lock().poll_next(cx)
717    }
718}
719
720impl<S> FusedStream for WebSocketReceiver<S>
721where
722    S: AsyncRead + AsyncWrite + Unpin,
723{
724    fn is_terminated(&self) -> bool {
725        self.shared.lock().ended
726    }
727}
728
729#[derive(Debug)]
730struct Shared<S>(Mutex<WebSocketStream<S>>);
731
732impl<S> Shared<S> {
733    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
734        self.0.lock().expect("lock shared stream")
735    }
736
737    fn into_inner(self) -> WebSocketStream<S> {
738        self.0.into_inner().expect("get shared stream")
739    }
740}
741
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the host of the request URI as an owned string.
///
/// IPv6 literals are returned without their surrounding brackets
/// (`ws://[::1]:80` yields `::1`). The brackets are URI syntax and are not
/// part of the address.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the brackets only when both are present. This avoids an unchecked
    // byte slice, which could panic if the last char were multi-byte, and does
    // not depend on the URI parser having validated the closing bracket.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
774
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port to connect to for a request.
///
/// An explicit port in the URI takes precedence. Otherwise the scheme's
/// default is used: 443 for `wss`, 80 for `ws`.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when no port is given and the
/// scheme is neither `ws` nor `wss`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
797
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_host_names() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://example.com:9001/path".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");
    }

    #[cfg(any(
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_prefers_explicit_then_scheme_default() {
        use tungstenite::client::IntoClientRequest;

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);

        let plain = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&plain).unwrap(), 80);

        let secure = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&secure).unwrap(), 443);
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}