mwc_libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use futures_rustls::{webpki, client, server};
22use crate::{error::Error, tls};
23use either::Either;
24use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
25use mwc_libp2p_core::{
26    Transport,
27    either::EitherOutput,
28    multiaddr::{Protocol, Multiaddr},
29    transport::{ListenerEvent, TransportError}
30};
31use log::{debug, trace};
32use soketto::{connection, extension::deflate::Deflate, handshake};
33use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll};
34use url::Url;
35
/// Default upper bound on the payload bytes of a single frame: 256 MiB
/// (`256 * 1024 * 1024` bytes).
const MAX_DATA_SIZE: usize = 256 << 20;
38
39/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
40/// frame payloads which does not implement [`AsyncRead`] or
41/// [`AsyncWrite`]. See [`crate::WsConfig`] if you require the latter.
42#[derive(Debug, Clone)]
43pub struct WsConfig<T> {
44    transport: T,
45    max_data_size: usize,
46    tls_config: tls::Config,
47    max_redirects: u8,
48    use_deflate: bool
49}
50
51impl<T> WsConfig<T> {
52    /// Create a new websocket transport based on another transport.
53    pub fn new(transport: T) -> Self {
54        WsConfig {
55            transport,
56            max_data_size: MAX_DATA_SIZE,
57            tls_config: tls::Config::client(),
58            max_redirects: 0,
59            use_deflate: false
60        }
61    }
62
63    /// Return the configured maximum number of redirects.
64    pub fn max_redirects(&self) -> u8 {
65        self.max_redirects
66    }
67
68    /// Set max. number of redirects to follow.
69    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
70        self.max_redirects = max;
71        self
72    }
73
74    /// Get the max. frame data size we support.
75    pub fn max_data_size(&self) -> usize {
76        self.max_data_size
77    }
78
79    /// Set the max. frame data size we support.
80    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
81        self.max_data_size = size;
82        self
83    }
84
85    /// Set the TLS configuration if TLS support is desired.
86    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
87        self.tls_config = c;
88        self
89    }
90
91    /// Should the deflate extension (RFC 7692) be used if supported?
92    pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
93        self.use_deflate = flag;
94        self
95    }
96}
97
98type TlsOrPlain<T> = EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>;
99
100impl<T> Transport for WsConfig<T>
101where
102    T: Transport + Send + Clone + 'static,
103    T::Error: Send + 'static,
104    T::Dial: Send + 'static,
105    T::Listener: Send + 'static,
106    T::ListenerUpgrade: Send + 'static,
107    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
108{
109    type Output = Connection<T::Output>;
110    type Error = Error<T::Error>;
111    type Listener = BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
112    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
113    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
114
115    fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
116        let mut inner_addr = addr.clone();
117
118        let (use_tls, proto) = match inner_addr.pop() {
119            Some(p@Protocol::Wss(_)) =>
120                if self.tls_config.server.is_some() {
121                    (true, p)
122                } else {
123                    debug!("/wss address but TLS server support is not configured");
124                    return Err(TransportError::MultiaddrNotSupported(addr))
125                }
126            Some(p@Protocol::Ws(_)) => (false, p),
127            _ => {
128                debug!("{} is not a websocket multiaddr", addr);
129                return Err(TransportError::MultiaddrNotSupported(addr))
130            }
131        };
132
133        let tls_config = self.tls_config;
134        let max_size = self.max_data_size;
135        let use_deflate = self.use_deflate;
136        let transport = self.transport.listen_on(inner_addr).map_err(|e| e.map(Error::Transport))?;
137        let listen = transport
138            .map_err(Error::Transport)
139            .map_ok(move |event| match event {
140                ListenerEvent::NewAddress(mut a) => {
141                    a = a.with(proto.clone());
142                    debug!("Listening on {}", a);
143                    ListenerEvent::NewAddress(a)
144                }
145                ListenerEvent::AddressExpired(mut a) => {
146                    a = a.with(proto.clone());
147                    ListenerEvent::AddressExpired(a)
148                }
149                ListenerEvent::Error(err) => {
150                    ListenerEvent::Error(Error::Transport(err))
151                }
152                ListenerEvent::Upgrade { upgrade, mut local_addr, mut remote_addr } => {
153                    local_addr = local_addr.with(proto.clone());
154                    remote_addr = remote_addr.with(proto.clone());
155                    let remote1 = remote_addr.clone(); // used for logging
156                    let remote2 = remote_addr.clone(); // used for logging
157                    let tls_config = tls_config.clone();
158
159                    let upgrade = async move {
160                        let stream = upgrade.map_err(Error::Transport).await?;
161                        trace!("incoming connection from {}", remote1);
162
163                        let stream =
164                            if use_tls { // begin TLS session
165                                let server = tls_config
166                                    .server
167                                    .expect("for use_tls we checked server is not none");
168
169                                trace!("awaiting TLS handshake with {}", remote1);
170
171                                let stream = server.accept(stream)
172                                    .map_err(move |e| {
173                                        debug!("TLS handshake with {} failed: {}", remote1, e);
174                                        Error::Tls(tls::Error::from(e))
175                                    })
176                                    .await?;
177
178                                let stream: TlsOrPlain<_> =
179                                    EitherOutput::First(EitherOutput::Second(stream));
180
181                                stream
182                            } else { // continue with plain stream
183                                EitherOutput::Second(stream)
184                            };
185
186                        trace!("receiving websocket handshake request from {}", remote2);
187
188                        let mut server = handshake::Server::new(stream);
189
190                        if use_deflate {
191                            server.add_extension(Box::new(Deflate::new(connection::Mode::Server)));
192                        }
193
194                        let ws_key = {
195                            let request = server.receive_request()
196                                .map_err(|e| Error::Handshake(Box::new(e)))
197                                .await?;
198                            request.into_key()
199                        };
200
201                        trace!("accepting websocket handshake request from {}", remote2);
202
203                        let response =
204                            handshake::server::Response::Accept {
205                                key: &ws_key,
206                                protocol: None
207                            };
208
209                        server.send_response(&response)
210                            .map_err(|e| Error::Handshake(Box::new(e)))
211                            .await?;
212
213                        let conn = {
214                            let mut builder = server.into_builder();
215                            builder.set_max_message_size(max_size);
216                            builder.set_max_frame_size(max_size);
217                            Connection::new(builder)
218                        };
219
220                        Ok(conn)
221                    };
222
223                    ListenerEvent::Upgrade {
224                        upgrade: Box::pin(upgrade) as BoxFuture<'static, _>,
225                        local_addr,
226                        remote_addr
227                    }
228                }
229            });
230        Ok(Box::pin(listen))
231    }
232
233    fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
234        // Quick sanity check of the provided Multiaddr.
235        if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
236            // ok
237        } else {
238            debug!("{} is not a websocket multiaddr", addr);
239            return Err(TransportError::MultiaddrNotSupported(addr))
240        }
241
242        // We are looping here in order to follow redirects (if any):
243        let mut remaining_redirects = self.max_redirects;
244        let mut addr = addr;
245        let future = async move {
246            loop {
247                let this = self.clone();
248                match this.dial_once(addr).await {
249                    Ok(Either::Left(redirect)) => {
250                        if remaining_redirects == 0 {
251                            debug!("too many redirects");
252                            return Err(Error::TooManyRedirects)
253                        }
254                        remaining_redirects -= 1;
255                        addr = location_to_multiaddr(&redirect)?
256                    }
257                    Ok(Either::Right(conn)) => return Ok(conn),
258                    Err(e) => return Err(e)
259                }
260            }
261        };
262
263        Ok(Box::pin(future))
264    }
265
266    fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
267        self.transport.address_translation(server, observed)
268    }
269}
270
271impl<T> WsConfig<T>
272where
273    T: Transport,
274    T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static
275{
276    /// Attempty to dial the given address and perform a websocket handshake.
277    async fn dial_once(self, address: Multiaddr) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
278        trace!("dial address: {}", address);
279
280        let (host_port, dns_name) = host_and_dnsname(&address)?;
281
282        let mut inner_addr = address.clone();
283
284        let (use_tls, path) =
285            match inner_addr.pop() {
286                Some(Protocol::Ws(path)) => (false, path),
287                Some(Protocol::Wss(path)) => {
288                    if dns_name.is_none() {
289                        debug!("no DNS name in {}", address);
290                        return Err(Error::InvalidMultiaddr(address))
291                    }
292                    (true, path)
293                }
294                _ => {
295                    debug!("{} is not a websocket multiaddr", address);
296                    return Err(Error::InvalidMultiaddr(address))
297                }
298            };
299
300        let dial = self.transport.dial(inner_addr)
301            .map_err(|e| match e {
302                TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
303                TransportError::Other(e) => Error::Transport(e)
304            })?;
305
306        let stream = dial.map_err(Error::Transport).await?;
307        trace!("connected to {}", address);
308
309        let stream =
310            if use_tls { // begin TLS session
311                let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
312                trace!("starting TLS handshake with {}", address);
313                let stream = self.tls_config.client.connect(dns_name.as_ref(), stream)
314                    .map_err(|e| {
315                        debug!("TLS handshake with {} failed: {}", address, e);
316                        Error::Tls(tls::Error::from(e))
317                    })
318                    .await?;
319
320                let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream));
321                stream
322            } else { // continue with plain stream
323                EitherOutput::Second(stream)
324            };
325
326        trace!("sending websocket handshake request to {}", address);
327
328        let mut client = handshake::Client::new(stream, &host_port, path.as_ref());
329
330        if self.use_deflate {
331            client.add_extension(Box::new(Deflate::new(connection::Mode::Client)));
332        }
333
334        match client.handshake().map_err(|e| Error::Handshake(Box::new(e))).await? {
335            handshake::ServerResponse::Redirect { status_code, location } => {
336                debug!("received redirect ({}); location: {}", status_code, location);
337                Ok(Either::Left(location))
338            }
339            handshake::ServerResponse::Rejected { status_code } => {
340                let msg = format!("server rejected handshake; status code = {}", status_code);
341                Err(Error::Handshake(msg.into()))
342            }
343            handshake::ServerResponse::Accepted { .. } => {
344                trace!("websocket handshake with {} successful", address);
345                Ok(Either::Right(Connection::new(client.into_builder())))
346            }
347        }
348    }
349}
350
351// Extract host, port and optionally the DNS name from the given [`Multiaddr`].
352fn host_and_dnsname<T>(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), Error<T>> {
353    let mut iter = addr.iter();
354    match (iter.next(), iter.next()) {
355        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) =>
356            Ok((format!("{}:{}", ip, port), None)),
357        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) =>
358            Ok((format!("{}:{}", ip, port), None)),
359        (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) =>
360            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
361        (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) =>
362            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
363        (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) =>
364            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
365        _ => {
366            debug!("multi-address format not supported: {}", addr);
367            Err(Error::InvalidMultiaddr(addr.clone()))
368        }
369    }
370}
371
372// Given a location URL, build a new websocket [`Multiaddr`].
373fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
374    match Url::parse(location) {
375        Ok(url) => {
376            let mut a = Multiaddr::empty();
377            match url.host() {
378                Some(url::Host::Domain(h)) => {
379                    a.push(Protocol::Dns(h.into()))
380                }
381                Some(url::Host::Ipv4(ip)) => {
382                    a.push(Protocol::Ip4(ip))
383                }
384                Some(url::Host::Ipv6(ip)) => {
385                    a.push(Protocol::Ip6(ip))
386                }
387                None => return Err(Error::InvalidRedirectLocation)
388            }
389            if let Some(p) = url.port() {
390                a.push(Protocol::Tcp(p))
391            }
392            let s = url.scheme();
393            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
394                a.push(Protocol::Wss(url.path().into()))
395            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
396                a.push(Protocol::Ws(url.path().into()))
397            } else {
398                debug!("unsupported scheme: {}", s);
399                return Err(Error::InvalidRedirectLocation)
400            }
401            Ok(a)
402        }
403        Err(e) => {
404            debug!("failed to parse url as multi-address: {:?}", e);
405            Err(Error::InvalidRedirectLocation)
406        }
407    }
408}
409
410/// The websocket connection.
411pub struct Connection<T> {
412    receiver: BoxStream<'static, Result<IncomingData, connection::Error>>,
413    sender: Pin<Box<dyn Sink<OutgoingData, Error = connection::Error> + Send>>,
414    _marker: std::marker::PhantomData<T>
415}
416
/// Data received over the websocket connection.
///
/// Implements `PartialEq`/`Eq` (in addition to `Debug`/`Clone`) so received
/// payloads can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IncomingData {
    /// Binary application data.
    Binary(Vec<u8>),
    /// UTF-8 encoded application data.
    Text(Vec<u8>),
    /// PONG control frame data.
    Pong(Vec<u8>)
}
427
428impl IncomingData {
429    pub fn is_data(&self) -> bool {
430        self.is_binary() || self.is_text()
431    }
432
433    pub fn is_binary(&self) -> bool {
434        if let IncomingData::Binary(_) = self { true } else { false }
435    }
436
437    pub fn is_text(&self) -> bool {
438        if let IncomingData::Text(_) = self { true } else { false }
439    }
440
441    pub fn is_pong(&self) -> bool {
442        if let IncomingData::Pong(_) = self { true } else { false }
443    }
444
445    pub fn into_bytes(self) -> Vec<u8> {
446        match self {
447            IncomingData::Binary(d) => d,
448            IncomingData::Text(d) => d,
449            IncomingData::Pong(d) => d
450        }
451    }
452}
453
454impl AsRef<[u8]> for IncomingData {
455    fn as_ref(&self) -> &[u8] {
456        match self {
457            IncomingData::Binary(d) => d,
458            IncomingData::Text(d) => d,
459            IncomingData::Pong(d) => d
460        }
461    }
462}
463
/// Data sent over the websocket connection.
///
/// Implements `PartialEq`/`Eq` (in addition to `Debug`/`Clone`) so queued
/// items can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutgoingData {
    /// Send some bytes.
    Binary(Vec<u8>),
    /// Send a PING message. The payload must be at most 125 bytes, otherwise
    /// sending fails with an `InvalidInput` error.
    Ping(Vec<u8>),
    /// Send an unsolicited PONG message. The payload must be at most 125 bytes.
    /// (Incoming PINGs are answered automatically.)
    Pong(Vec<u8>)
}
475
476impl<T> fmt::Debug for Connection<T> {
477    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478        f.write_str("Connection")
479    }
480}
481
482impl<T> Connection<T>
483where
484    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
485{
486    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
487        let (sender, receiver) = builder.finish();
488        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
489            match action {
490                quicksink::Action::Send(OutgoingData::Binary(x)) => {
491                    sender.send_binary_mut(x).await?
492                }
493                quicksink::Action::Send(OutgoingData::Ping(x)) => {
494                    let data = x[..].try_into().map_err(|_| {
495                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
496                    })?;
497                    sender.send_ping(data).await?
498                }
499                quicksink::Action::Send(OutgoingData::Pong(x)) => {
500                    let data = x[..].try_into().map_err(|_| {
501                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
502                    })?;
503                    sender.send_pong(data).await?
504                }
505                quicksink::Action::Flush => sender.flush().await?,
506                quicksink::Action::Close => sender.close().await?
507            }
508            Ok(sender)
509        });
510        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
511            match receiver.receive(&mut data).await {
512                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => {
513                    Some((Ok(IncomingData::Text(mem::take(&mut data))), (data, receiver)))
514                }
515                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => {
516                    Some((Ok(IncomingData::Binary(mem::take(&mut data))), (data, receiver)))
517                }
518                Ok(soketto::Incoming::Pong(pong)) => {
519                    Some((Ok(IncomingData::Pong(Vec::from(pong))), (data, receiver)))
520                }
521                Err(connection::Error::Closed) => None,
522                Err(e) => Some((Err(e), (data, receiver)))
523            }
524        });
525        Connection {
526            receiver: stream.boxed(),
527            sender: Box::pin(sink),
528            _marker: std::marker::PhantomData
529        }
530    }
531
532    /// Send binary application data to the remote.
533    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
534        self.send(OutgoingData::Binary(data))
535    }
536
537    /// Send a PING to the remote.
538    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
539        self.send(OutgoingData::Ping(data))
540    }
541
542    /// Send an unsolicited PONG to the remote.
543    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
544        self.send(OutgoingData::Pong(data))
545    }
546}
547
548impl<T> Stream for Connection<T>
549where
550    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
551{
552    type Item = io::Result<IncomingData>;
553
554    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
555        let item = ready!(self.receiver.poll_next_unpin(cx));
556        let item = item.map(|result| {
557            result.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
558        });
559        Poll::Ready(item)
560    }
561}
562
563impl<T> Sink<OutgoingData> for Connection<T>
564where
565    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
566{
567    type Error = io::Error;
568
569    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
570        Pin::new(&mut self.sender)
571            .poll_ready(cx)
572            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
573    }
574
575    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
576        Pin::new(&mut self.sender)
577            .start_send(item)
578            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
579    }
580
581    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
582        Pin::new(&mut self.sender)
583            .poll_flush(cx)
584            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
585    }
586
587    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
588        Pin::new(&mut self.sender)
589            .poll_close(cx)
590            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
591    }
592}