ant_libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use ant_libp2p_core as libp2p_core;
22
23use std::{
24    borrow::Cow,
25    collections::HashMap,
26    fmt, io, mem,
27    net::IpAddr,
28    ops::DerefMut,
29    pin::Pin,
30    sync::Arc,
31    task::{Context, Poll},
32};
33
34use either::Either;
35use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
36use futures_rustls::{client, rustls::pki_types::ServerName, server};
37use libp2p_core::{
38    multiaddr::{Multiaddr, Protocol},
39    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
40    Transport,
41};
42use parking_lot::Mutex;
43use soketto::{
44    connection::{self, CloseReason},
45    handshake,
46};
47use url::Url;
48
49use crate::{error::Error, quicksink, tls};
50
/// Max. number of payload bytes of a single frame (256 MiB).
///
/// Used as the initial value of [`WsConfig::max_data_size`]; override it with
/// [`WsConfig::set_max_data_size`].
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
53
/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
/// frame payloads which does not implement [`AsyncRead`] or
/// [`AsyncWrite`]. See [`crate::WsConfig`] if you require the latter.
#[derive(Debug)]
pub struct WsConfig<T> {
    /// The inner transport. It is shared behind a mutex so that dial futures, which
    /// must be `'static` and therefore cannot borrow `self`, can still use it.
    transport: Arc<Mutex<T>>,
    /// Maximum payload size, in bytes, configured for websocket messages and frames.
    max_data_size: usize,
    /// TLS client configuration plus an optional server configuration, used for
    /// `/wss` and `/tls/ws` addresses.
    tls_config: tls::Config,
    /// Maximum number of HTTP redirects followed while dialing.
    max_redirects: u8,
    /// Websocket protocol of the inner listener, keyed by listener id, so it can be
    /// re-appended to the addresses the inner transport reports.
    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
}
66
67impl<T> WsConfig<T>
68where
69    T: Send,
70{
71    /// Create a new websocket transport based on another transport.
72    pub fn new(transport: T) -> Self {
73        WsConfig {
74            transport: Arc::new(Mutex::new(transport)),
75            max_data_size: MAX_DATA_SIZE,
76            tls_config: tls::Config::client(),
77            max_redirects: 0,
78            listener_protos: HashMap::new(),
79        }
80    }
81
82    /// Return the configured maximum number of redirects.
83    pub fn max_redirects(&self) -> u8 {
84        self.max_redirects
85    }
86
87    /// Set max. number of redirects to follow.
88    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
89        self.max_redirects = max;
90        self
91    }
92
93    /// Get the max. frame data size we support.
94    pub fn max_data_size(&self) -> usize {
95        self.max_data_size
96    }
97
98    /// Set the max. frame data size we support.
99    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
100        self.max_data_size = size;
101        self
102    }
103
104    /// Set the TLS configuration if TLS support is desired.
105    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
106        self.tls_config = c;
107        self
108    }
109}
110
/// A stream that is either TLS-wrapped or plain.
///
/// Dialed TLS connections are `Left(Left(client stream))`, accepted TLS connections
/// are `Left(Right(server stream))`, and connections without TLS are `Right(stream)`.
type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
112
impl<T> Transport for WsConfig<T>
where
    T: Transport + Send + Unpin + 'static,
    T::Error: Send + 'static,
    T::Dial: Send + 'static,
    T::ListenerUpgrade: Send + 'static,
    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = Connection<T::Output>;
    type Error = Error<T::Error>;
    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    /// Starts listening on a `/ws`, `/wss` or `/tls/ws` address.
    ///
    /// The websocket suffix is stripped and the remaining address is passed to the
    /// inner transport. The stripped protocol is remembered under `id` so that `poll`
    /// can re-append it to the addresses the inner transport reports.
    ///
    /// Returns [`TransportError::MultiaddrNotSupported`] if `addr` has no websocket
    /// suffix, or if it requires TLS while no TLS server configuration is set.
    fn listen_on(
        &mut self,
        id: ListenerId,
        addr: Multiaddr,
    ) -> Result<(), TransportError<Self::Error>> {
        let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
            tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
            TransportError::MultiaddrNotSupported(addr.clone())
        })?;

        // `map_upgrade` relies on this check when it unwraps the server config.
        if proto.use_tls() && self.tls_config.server.is_none() {
            tracing::debug!(
                "{} address but TLS server support is not configured",
                proto.prefix()
            );
            return Err(TransportError::MultiaddrNotSupported(addr));
        }

        // Only remember the protocol once the inner transport accepted the listener.
        match self.transport.lock().listen_on(id, inner_addr) {
            Ok(()) => {
                self.listener_protos.insert(id, proto);
                Ok(())
            }
            Err(e) => Err(e.map(Error::Transport)),
        }
    }

    /// Removes the listener from the inner transport.
    ///
    /// `listener_protos` is not touched here. Its entry for `id` is only removed when
    /// `poll` sees a `ListenerClosed` event for `id`. This presumably relies on the inner
    /// transport emitting that event after removal — confirm for custom transports.
    fn remove_listener(&mut self, id: ListenerId) -> bool {
        self.transport.lock().remove_listener(id)
    }

    /// Dials a websocket address; see `WsConfig::do_dial`.
    fn dial(
        &mut self,
        addr: Multiaddr,
        dial_opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        self.do_dial(addr, dial_opts)
    }

    /// Polls the inner transport and translates its events.
    ///
    /// Addresses get their websocket protocol re-appended. Errors are wrapped in
    /// [`Error::Transport`]. Incoming connections are upgraded via `map_upgrade`.
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        // The inner transport's lock is confined to this block, so it is released
        // before `self.listener_protos` is accessed below.
        let inner_event = {
            let mut transport = self.transport.lock();
            match Transport::poll(Pin::new(transport.deref_mut()), cx) {
                Poll::Ready(ev) => ev,
                Poll::Pending => return Poll::Pending,
            }
        };
        let event = match inner_event {
            TransportEvent::NewAddress {
                listener_id,
                mut listen_addr,
            } => {
                // Append the ws / wss protocol back to the inner address.
                self.listener_protos
                    .get(&listener_id)
                    .expect("Protocol was inserted in Transport::listen_on.")
                    .append_on_addr(&mut listen_addr);
                tracing::debug!(address=%listen_addr, "Listening on address");
                TransportEvent::NewAddress {
                    listener_id,
                    listen_addr,
                }
            }
            TransportEvent::AddressExpired {
                listener_id,
                mut listen_addr,
            } => {
                self.listener_protos
                    .get(&listener_id)
                    .expect("Protocol was inserted in Transport::listen_on.")
                    .append_on_addr(&mut listen_addr);
                TransportEvent::AddressExpired {
                    listener_id,
                    listen_addr,
                }
            }
            TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
                listener_id,
                error: Error::Transport(error),
            },
            TransportEvent::ListenerClosed {
                listener_id,
                reason,
            } => {
                // The listener is gone, so its stored protocol is no longer needed.
                self.listener_protos
                    .remove(&listener_id)
                    .expect("Protocol was inserted in Transport::listen_on.");
                TransportEvent::ListenerClosed {
                    listener_id,
                    reason: reason.map_err(Error::Transport),
                }
            }
            TransportEvent::Incoming {
                listener_id,
                upgrade,
                mut local_addr,
                mut send_back_addr,
            } => {
                let proto = self
                    .listener_protos
                    .get(&listener_id)
                    .expect("Protocol was inserted in Transport::listen_on.");
                let use_tls = proto.use_tls();
                proto.append_on_addr(&mut local_addr);
                proto.append_on_addr(&mut send_back_addr);
                let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
                TransportEvent::Incoming {
                    listener_id,
                    upgrade,
                    local_addr,
                    send_back_addr,
                }
            }
        };
        Poll::Ready(event)
    }
}
246
247impl<T> WsConfig<T>
248where
249    T: Transport + Send + Unpin + 'static,
250    T::Error: Send + 'static,
251    T::Dial: Send + 'static,
252    T::ListenerUpgrade: Send + 'static,
253    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
254{
255    fn do_dial(
256        &mut self,
257        addr: Multiaddr,
258        dial_opts: DialOpts,
259    ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
260        let mut addr = match parse_ws_dial_addr(addr) {
261            Ok(addr) => addr,
262            Err(Error::InvalidMultiaddr(a)) => {
263                return Err(TransportError::MultiaddrNotSupported(a))
264            }
265            Err(e) => return Err(TransportError::Other(e)),
266        };
267
268        // We are looping here in order to follow redirects (if any):
269        let mut remaining_redirects = self.max_redirects;
270
271        let transport = self.transport.clone();
272        let tls_config = self.tls_config.clone();
273        let max_redirects = self.max_redirects;
274
275        let future = async move {
276            loop {
277                match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
278                {
279                    Ok(Either::Left(redirect)) => {
280                        if remaining_redirects == 0 {
281                            tracing::debug!(%max_redirects, "Too many redirects");
282                            return Err(Error::TooManyRedirects);
283                        }
284                        remaining_redirects -= 1;
285                        addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
286                    }
287                    Ok(Either::Right(conn)) => return Ok(conn),
288                    Err(e) => return Err(e),
289                }
290            }
291        };
292
293        Ok(Box::pin(future))
294    }
295
296    /// Attempts to dial the given address and perform a websocket handshake.
297    async fn dial_once(
298        transport: Arc<Mutex<T>>,
299        addr: WsAddress,
300        tls_config: tls::Config,
301        dial_opts: DialOpts,
302    ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
303        tracing::trace!(address=?addr, "Dialing websocket address");
304
305        let dial = transport
306            .lock()
307            .dial(addr.tcp_addr, dial_opts)
308            .map_err(|e| match e {
309                TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
310                TransportError::Other(e) => Error::Transport(e),
311            })?;
312
313        let stream = dial.map_err(Error::Transport).await?;
314        tracing::trace!(port=%addr.host_port, "TCP connection established");
315
316        let stream = if addr.use_tls {
317            // begin TLS session
318            tracing::trace!(?addr.server_name, "Starting TLS handshake");
319            let stream = tls_config
320                .client
321                .connect(addr.server_name.clone(), stream)
322                .map_err(|e| {
323                    tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
324                    Error::Tls(tls::Error::from(e))
325                })
326                .await?;
327
328            let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
329            stream
330        } else {
331            // continue with plain stream
332            future::Either::Right(stream)
333        };
334
335        tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
336
337        let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
338
339        match client
340            .handshake()
341            .map_err(|e| Error::Handshake(Box::new(e)))
342            .await?
343        {
344            handshake::ServerResponse::Redirect {
345                status_code,
346                location,
347            } => {
348                tracing::debug!(
349                    %status_code,
350                    %location,
351                    "received redirect"
352                );
353                Ok(Either::Left(location))
354            }
355            handshake::ServerResponse::Rejected { status_code } => {
356                let msg = format!("server rejected handshake; status code = {status_code}");
357                Err(Error::Handshake(msg.into()))
358            }
359            handshake::ServerResponse::Accepted { .. } => {
360                tracing::trace!(port=%addr.host_port, "websocket handshake successful");
361                Ok(Either::Right(Connection::new(client.into_builder())))
362            }
363        }
364    }
365
366    fn map_upgrade(
367        &self,
368        upgrade: T::ListenerUpgrade,
369        remote_addr: Multiaddr,
370        use_tls: bool,
371    ) -> <Self as Transport>::ListenerUpgrade {
372        let remote_addr2 = remote_addr.clone(); // used for logging
373        let tls_config = self.tls_config.clone();
374        let max_size = self.max_data_size;
375
376        async move {
377            let stream = upgrade.map_err(Error::Transport).await?;
378            tracing::trace!(address=%remote_addr, "incoming connection from address");
379
380            let stream = if use_tls {
381                // begin TLS session
382                let server = tls_config
383                    .server
384                    .expect("for use_tls we checked server is not none");
385
386                tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
387
388                let stream = server
389                    .accept(stream)
390                    .map_err(move |e| {
391                        tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
392                        Error::Tls(tls::Error::from(e))
393                    })
394                    .await?;
395
396                let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
397
398                stream
399            } else {
400                // continue with plain stream
401                future::Either::Right(stream)
402            };
403
404            tracing::trace!(
405                address=%remote_addr2,
406                "receiving websocket handshake request from address"
407            );
408
409            let mut server = handshake::Server::new(stream);
410
411            let ws_key = {
412                let request = server
413                    .receive_request()
414                    .map_err(|e| Error::Handshake(Box::new(e)))
415                    .await?;
416                request.key()
417            };
418
419            tracing::trace!(
420                address=%remote_addr2,
421                "accepting websocket handshake request from address"
422            );
423
424            let response = handshake::server::Response::Accept {
425                key: ws_key,
426                protocol: None,
427            };
428
429            server
430                .send_response(&response)
431                .map_err(|e| Error::Handshake(Box::new(e)))
432                .await?;
433
434            let conn = {
435                let mut builder = server.into_builder();
436                builder.set_max_message_size(max_size);
437                builder.set_max_frame_size(max_size);
438                Connection::new(builder)
439            };
440
441            Ok(conn)
442        }
443        .boxed()
444    }
445}
446
/// The websocket protocol suffix of a listen address, together with its path.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// Plain `/ws/<path>`.
    Ws(Cow<'a, str>),
    /// `/wss/<path>`: websocket over TLS.
    Wss(Cow<'a, str>),
    /// `/tls/ws/<path>`: equivalent to `/wss`, but remembered in this spelling.
    TlsWs(Cow<'a, str>),
}
453
454impl WsListenProto<'_> {
455    pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
456        match self {
457            WsListenProto::Ws(path) => {
458                addr.push(Protocol::Ws(path.clone()));
459            }
460            // `/tls/ws` and `/wss` are equivalend, however we regenerate
461            // the one that user passed at `listen_on` for backward compatibility.
462            WsListenProto::Wss(path) => {
463                addr.push(Protocol::Wss(path.clone()));
464            }
465            WsListenProto::TlsWs(path) => {
466                addr.push(Protocol::Tls);
467                addr.push(Protocol::Ws(path.clone()));
468            }
469        }
470    }
471
472    pub(crate) fn use_tls(&self) -> bool {
473        match self {
474            WsListenProto::Ws(_) => false,
475            WsListenProto::Wss(_) => true,
476            WsListenProto::TlsWs(_) => true,
477        }
478    }
479
480    pub(crate) fn prefix(&self) -> &'static str {
481        match self {
482            WsListenProto::Ws(_) => "/ws",
483            WsListenProto::Wss(_) => "/wss",
484            WsListenProto::TlsWs(_) => "/tls/ws",
485        }
486    }
487}
488
/// A websocket dial target, split into the parts needed by each layer.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` (IPv6 hosts in brackets) sent in the websocket handshake.
    host_port: String,
    /// Resource path requested in the websocket handshake.
    path: String,
    /// Name passed to the TLS client when `use_tls` is set.
    server_name: ServerName<'static>,
    /// Whether a TLS session is established before the websocket handshake.
    use_tls: bool,
    /// Address dialed on the inner transport: the original address without its
    /// `/ws`, `/wss` or `/tls/ws` part (a trailing `/p2p` is preserved).
    tcp_addr: Multiaddr,
}
497
/// Tries to parse the given `Multiaddr` into a `WsAddress` used
/// for dialing.
///
/// Fails if the given `Multiaddr` does not represent a TCP/IP-based
/// websocket protocol stack.
///
/// # Errors
///
/// Returns `Error::InvalidMultiaddr(addr)` if no `ip4`/`ip6`/`dns*` component is
/// directly followed by `tcp`, or if the address does not end in `/ws`, `/wss` or
/// `/tls/ws` (optionally followed by `/p2p`). Errors from `tls::dns_name_ref` are
/// propagated for DNS hosts.
fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
    // The encapsulating protocol must be based on TCP/IP, possibly via DNS.
    // We peek at it in order to learn the hostname and port to use for
    // the websocket handshake.
    let mut protocols = addr.iter();
    let mut ip = protocols.next();
    let mut tcp = protocols.next();
    // Slide a two-component window `(ip, tcp)` over the address until a host
    // component directly followed by `tcp` is found.
    let (host_port, server_name) = loop {
        match (ip, tcp) {
            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
                break (format!("{ip}:{port}"), server_name);
            }
            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
                break (format!("[{ip}]:{port}"), server_name);
            }
            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
            }
            // No match yet: shift the window one component to the right.
            (Some(_), Some(p)) => {
                ip = Some(p);
                tcp = protocols.next();
            }
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // Now consume the `Ws` / `Wss` protocol from the end of the address,
    // preserving the trailing `P2p` protocol that identifies the remote,
    // if any.
    let mut protocols = addr.clone();
    let mut p2p = None;
    let (use_tls, path) = loop {
        match protocols.pop() {
            p @ Some(Protocol::P2p(_)) => p2p = p,
            // `/ws` preceded by `/tls` means websocket over TLS; otherwise the
            // preceding component belongs to the inner address and is put back.
            Some(Protocol::Ws(path)) => match protocols.pop() {
                Some(Protocol::Tls) => break (true, path.into_owned()),
                Some(p) => {
                    protocols.push(p);
                    break (false, path.into_owned());
                }
                None => return Err(Error::InvalidMultiaddr(addr)),
            },
            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // The original address, stripped of the `/ws` and `/wss` protocols,
    // makes up the address for the inner TCP-based transport.
    let tcp_addr = match p2p {
        Some(p) => protocols.with(p),
        None => protocols,
    };

    Ok(WsAddress {
        host_port,
        server_name,
        path,
        use_tls,
        tcp_addr,
    })
}
569
570fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
571    let mut inner_addr = addr.clone();
572
573    match inner_addr.pop()? {
574        Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
575        Protocol::Ws(path) => match inner_addr.pop()? {
576            Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
577            p => {
578                inner_addr.push(p);
579                Some((inner_addr, WsListenProto::Ws(path)))
580            }
581        },
582        _ => None,
583    }
584}
585
586// Given a location URL, build a new websocket [`Multiaddr`].
587fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
588    match Url::parse(location) {
589        Ok(url) => {
590            let mut a = Multiaddr::empty();
591            match url.host() {
592                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
593                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
594                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
595                None => return Err(Error::InvalidRedirectLocation),
596            }
597            if let Some(p) = url.port() {
598                a.push(Protocol::Tcp(p))
599            }
600            let s = url.scheme();
601            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
602                a.push(Protocol::Tls);
603                a.push(Protocol::Ws(url.path().into()));
604            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
605                a.push(Protocol::Ws(url.path().into()))
606            } else {
607                tracing::debug!(scheme=%s, "unsupported scheme");
608                return Err(Error::InvalidRedirectLocation);
609            }
610            Ok(a)
611        }
612        Err(e) => {
613            tracing::debug!("failed to parse url as multi-address: {:?}", e);
614            Err(Error::InvalidRedirectLocation)
615        }
616    }
617}
618
/// The websocket connection.
///
/// Received frames are exposed through the [`Stream`] impl and outgoing data is
/// written through the [`Sink`] impl. Both halves are type-erased, so the underlying
/// stream type `T` is only recorded via `PhantomData`.
pub struct Connection<T> {
    /// Boxed stream of received items built in `Connection::new`.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Boxed sink that forwards [`OutgoingData`] to the websocket sender.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    _marker: std::marker::PhantomData<T>,
}
625
/// Data or control information received over the websocket connection.
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data (text or binary).
    Data(Data),
    /// PONG control frame data.
    Pong(Vec<u8>),
    /// The remote closed the connection, with the given reason.
    Closed(CloseReason),
}
636
/// Application data received over the websocket connection.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// UTF-8 encoded textual data.
    Text(Vec<u8>),
    /// Binary data.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the payload and returns its raw bytes, whether it arrived as text or
    /// as binary.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the raw payload bytes, regardless of the frame type.
    fn as_ref(&self) -> &[u8] {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}
663
664impl Incoming {
665    pub fn is_data(&self) -> bool {
666        self.is_binary() || self.is_text()
667    }
668
669    pub fn is_binary(&self) -> bool {
670        matches!(self, Incoming::Data(Data::Binary(_)))
671    }
672
673    pub fn is_text(&self) -> bool {
674        matches!(self, Incoming::Data(Data::Text(_)))
675    }
676
677    pub fn is_pong(&self) -> bool {
678        matches!(self, Incoming::Pong(_))
679    }
680
681    pub fn is_close(&self) -> bool {
682        matches!(self, Incoming::Closed(_))
683    }
684}
685
/// Data sent over the websocket connection.
///
/// PING and PONG payloads are rejected with `io::ErrorKind::InvalidInput` when they
/// are 126 bytes or longer.
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Send some bytes as a binary message.
    Binary(Vec<u8>),
    /// Send a PING message.
    Ping(Vec<u8>),
    /// Send an unsolicited PONG message.
    /// (Incoming PINGs are answered automatically.)
    Pong(Vec<u8>),
}
697
698impl<T> fmt::Debug for Connection<T> {
699    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
700        f.write_str("Connection")
701    }
702}
703
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Builds a `Connection` from a soketto connection builder whose handshake has
    /// completed.
    ///
    /// The sender half becomes a `quicksink` sink that maps each [`OutgoingData`]
    /// (plus flush/close) onto the corresponding soketto call. The receiver half
    /// becomes a stream of [`Incoming`] items that ends once soketto reports
    /// `connection::Error::Closed`.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // The slice conversion fails for payloads too long for a control
                    // frame; report that as `InvalidInput`.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?,
            }
            Ok(sender)
        });
        // `data` is a receive buffer threaded through the unfold state; `mem::take`
        // hands the filled buffer to the caller and leaves an empty one behind.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
                }
                Ok(soketto::Incoming::Closed(reason)) => {
                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
                }
                // A closed connection terminates the stream; other errors are yielded
                // as items and the stream keeps going.
                Err(connection::Error::Closed) => None,
                Err(e) => Some((Err(e), (data, receiver))),
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData,
        }
    }

    /// Send binary application data to the remote.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Send a PING to the remote (payload must be shorter than 126 bytes).
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Send an unsolicited PONG to the remote (payload must be shorter than 126 bytes).
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
774
775impl<T> Stream for Connection<T>
776where
777    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
778{
779    type Item = io::Result<Incoming>;
780
781    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
782        let item = ready!(self.receiver.poll_next_unpin(cx));
783        let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
784        Poll::Ready(item)
785    }
786}
787
788impl<T> Sink<OutgoingData> for Connection<T>
789where
790    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
791{
792    type Error = io::Error;
793
794    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
795        Pin::new(&mut self.sender)
796            .poll_ready(cx)
797            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
798    }
799
800    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
801        Pin::new(&mut self.sender)
802            .start_send(item)
803            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
804    }
805
806    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
807        Pin::new(&mut self.sender)
808            .poll_flush(cx)
809            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
810    }
811
812    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
813        Pin::new(&mut self.sender)
814            .poll_close(cx)
815            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
816    }
817}
818
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Each websocket listen protocol must be split off the TCP address by
    /// `parse_ws_listen_addr`. Appending the extracted protocol back onto
    /// that TCP address must reproduce the original listen address.
    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // (full listen address, protocol expected to be extracted from it)
        let cases = [
            // `/tls/ws`
            (
                tcp_addr
                    .clone()
                    .with(Protocol::Tls)
                    .with(Protocol::Ws("/".into())),
                WsListenProto::TlsWs("/".into()),
            ),
            // `/wss`
            (
                tcp_addr.clone().with(Protocol::Wss("/".into())),
                WsListenProto::Wss("/".into()),
            ),
            // `/ws`
            (
                tcp_addr.clone().with(Protocol::Ws("/".into())),
                WsListenProto::Ws("/".into()),
            ),
        ];

        for (addr, expected_proto) in cases {
            let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
            assert_eq!(&inner_addr, &tcp_addr);
            assert_eq!(proto, expected_proto);

            // Round-trip: TCP address + parsed protocol == original address.
            let mut rebuilt = tcp_addr.clone();
            proto.append_on_addr(&mut rebuilt);
            assert_eq!(rebuilt, addr);
        }
    }

    /// `parse_ws_dial_addr` must extract the expected connection details for
    /// every websocket flavour (`/tls/ws`, `/wss`, `/ws`). Each flavour is
    /// combined with every host shape: DNS, DNS with a trailing `/p2p`,
    /// IPv4 and IPv6. The parser must also reject `/dnsaddr` and
    /// non-websocket addresses.
    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();
        let p2p = format!("/p2p/{peer_id}");

        // (websocket protocol suffix, whether TLS is expected)
        let ws_protocols = [("/tls/ws", true), ("/wss", true), ("/ws", false)];

        // (TCP part, optional trailing `/p2p` part, expected `host:port`,
        //  expected TLS server name)
        let hosts = [
            ("/dns4/example.com/tcp/2222", "", "example.com:2222", "example.com"),
            ("/dns4/example.com/tcp/2222", p2p.as_str(), "example.com:2222", "example.com"),
            ("/ip4/127.0.0.1/tcp/2222", "", "127.0.0.1:2222", "127.0.0.1"),
            ("/ip6/::1/tcp/2222", "", "[::1]:2222", "::1"),
        ];

        for (ws_proto, use_tls) in ws_protocols {
            for (tcp_part, p2p_part, host_port, server_name) in hosts {
                let addr = format!("{tcp_part}{ws_proto}{p2p_part}")
                    .parse::<Multiaddr>()
                    .unwrap();
                let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
                assert_eq!(info.host_port, host_port);
                assert_eq!(info.path, "/");
                assert_eq!(info.use_tls, use_tls);
                assert_eq!(info.server_name, server_name.try_into().unwrap());
                // The websocket part is stripped; any `/p2p` suffix is kept.
                assert_eq!(
                    info.tcp_addr,
                    format!("{tcp_part}{p2p_part}").parse().unwrap()
                );
            }
        }

        // A `/dnsaddr` address is rejected.
        let addr = "/dnsaddr/example.com/tcp/2222/ws"
            .parse::<Multiaddr>()
            .unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();

        // An address without any websocket protocol is rejected.
        let addr = "/ip4/127.0.0.1/tcp/2222".parse::<Multiaddr>().unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
    }
}