libp2prs_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2// Copyright 2020 Netwarps Ltd.
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22// use crate::connection::{Connection, TlsOrPlain};
23use crate::connection::{Connection, TlsClientStream, TlsOrPlain, TlsServerStream};
24use crate::{error::WsError, tls};
25use async_trait::async_trait;
26use either::Either;
27use futures::prelude::*;
28use libp2prs_core::transport::{ConnectionInfo, ListenerEvent};
29use libp2prs_core::transport::{IListener, ITransport};
30use libp2prs_core::{
31    either::EitherOutput,
32    multiaddr::{protocol, protocol::Protocol, Multiaddr},
33    transport::{TransportError, TransportListener},
34    Transport,
35};
36use libp2prs_tcp::TcpTransStream;
37use log::{debug, info, trace};
38use soketto::{connection, extension::deflate::Deflate, handshake};
39use std::fmt;
40use url::Url;
41
/// Default upper bound on the payload of a single frame, in bytes (256 MiB).
const MAX_DATA_SIZE: usize = 256 << 20;
44
45/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
46/// frame payloads which does not implement [`AsyncRead`] or
47/// [`AsyncWrite`]. See [`crate::WsConfig`] if you require the latter.
48
49#[derive(Clone)]
50pub struct WsConfig {
51    transport: ITransport<TcpTransStream>,
52    pub(crate) inner_config: InnerConfig,
53}
54
55impl fmt::Debug for WsConfig {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("WsConfig").field("Config", &self.inner_config).finish()
58    }
59}
60
61impl WsConfig {
62    /// Create a new websocket transport based on another transport.
63    pub fn new(transport: ITransport<TcpTransStream>) -> Self {
64        WsConfig {
65            transport,
66            inner_config: InnerConfig::new(),
67        }
68    }
69}
70
71#[derive(Debug, Clone)]
72pub(crate) struct InnerConfig {
73    max_data_size: usize,
74    tls_config: tls::Config,
75    max_redirects: u8,
76    use_deflate: bool,
77}
78
79impl InnerConfig {
80    /// Create a new websocket transport based on another transport.
81    pub fn new() -> Self {
82        InnerConfig {
83            max_data_size: MAX_DATA_SIZE,
84            tls_config: tls::Config::client(),
85            max_redirects: 0,
86            use_deflate: false,
87        }
88    }
89
90    /// Return the configured maximum number of redirects.
91    pub fn max_redirects(&self) -> u8 {
92        self.max_redirects
93    }
94
95    /// Set max. number of redirects to follow.
96    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
97        self.max_redirects = max;
98        self
99    }
100
101    /// Get the max. frame data size we support.
102    pub fn max_data_size(&self) -> usize {
103        self.max_data_size
104    }
105
106    /// Set the max. frame data size we support.
107    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
108        self.max_data_size = size;
109        self
110    }
111
112    /// Set the TLS configuration if TLS support is desired.
113    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
114        self.tls_config = c;
115        self
116    }
117
118    /// Should the deflate extension (RFC 7692) be used if supported?
119    pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
120        self.use_deflate = flag;
121        self
122    }
123}
124
125pub struct WsTransListener {
126    inner: IListener<TcpTransStream>,
127    inner_config: InnerConfig,
128    use_tls: bool,
129}
130
131impl fmt::Debug for WsTransListener {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_struct("WsTransListener")
134            .field("Config", &self.inner_config)
135            .field("tls", &self.use_tls)
136            .finish()
137    }
138}
139
140impl WsTransListener {
141    pub(crate) fn new(inner: IListener<TcpTransStream>, inner_config: InnerConfig, use_tls: bool) -> Self {
142        Self {
143            inner,
144            inner_config,
145            use_tls,
146        }
147    }
148}
149
150#[async_trait]
151impl TransportListener for WsTransListener {
152    type Output = Connection<TlsOrPlain<TcpTransStream>>;
153    async fn accept(&mut self) -> Result<ListenerEvent<Self::Output>, TransportError> {
154        let raw_stream = match self.inner.accept().await? {
155            ListenerEvent::Accepted(stream) => stream,
156            ListenerEvent::AddressAdded(a) => return Ok(ListenerEvent::AddressAdded(a)),
157            ListenerEvent::AddressDeleted(a) => return Ok(ListenerEvent::AddressDeleted(a)),
158        };
159        let local_addr = raw_stream.local_multiaddr();
160        let remote_addr = raw_stream.remote_multiaddr();
161        let remote1 = remote_addr.clone(); // used for logging
162        let remote2 = remote_addr.clone(); // used for logging
163        let tls_config = self.inner_config.tls_config.clone();
164        trace!("[Server] incoming connection from {}", remote1);
165        let stream = if self.use_tls {
166            // begin TLS session
167            let server = tls_config.server.expect("for use_tls we checked server is not none");
168            trace!("[Server] awaiting TLS handshake with {}", remote1);
169            let stream = server.accept(raw_stream).await.map_err(move |e| {
170                debug!("[Server] TLS handshake with {} failed: {}", remote1, e);
171                WsError::Tls(tls::Error::from(e))
172            })?;
173
174            let stream: TlsOrPlain<_> = EitherOutput::A(EitherOutput::B(TlsServerStream(stream)));
175            stream
176        } else {
177            // continue with plain stream
178            EitherOutput::B(raw_stream)
179        };
180
181        trace!("[Server] receiving websocket handshake request from {}", remote2);
182        let mut server = handshake::Server::new(stream);
183
184        if self.inner_config.use_deflate {
185            server.add_extension(Box::new(Deflate::new(connection::Mode::Server)));
186        }
187
188        let ws_key = {
189            let request = server.receive_request().await.map_err(|e| WsError::Handshake(Box::new(e)))?;
190            request.into_key()
191        };
192
193        debug!("[Server] accepting websocket handshake request from {}", remote2);
194
195        let response = handshake::server::Response::Accept {
196            key: &ws_key,
197            protocol: None,
198        };
199
200        server.send_response(&response).await.map_err(|e| WsError::Handshake(Box::new(e)))?;
201
202        let conn = {
203            let mut builder = server.into_builder();
204            builder.set_max_message_size(self.inner_config.max_data_size);
205            builder.set_max_frame_size(self.inner_config.max_data_size);
206            Connection::new(builder, local_addr, remote_addr)
207        };
208        Ok(ListenerEvent::Accepted(conn))
209    }
210
211    fn multi_addr(&self) -> Option<&Multiaddr> {
212        self.inner.multi_addr()
213    }
214}
215
216#[async_trait]
217impl Transport for WsConfig {
218    type Output = Connection<TlsOrPlain<TcpTransStream>>;
219    fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
220        log::debug!("WebSocket listen on addr: {}", addr);
221        let mut inner_addr = addr.clone();
222
223        let (use_tls, _proto) = match inner_addr.pop() {
224            Some(p @ Protocol::Wss(_)) => {
225                if self.inner_config.tls_config.server.is_some() {
226                    (true, p)
227                } else {
228                    debug!("/wss address but TLS server support is not configured");
229                    return Err(TransportError::MultiaddrNotSupported(addr));
230                }
231            }
232            Some(p @ Protocol::Ws(_)) => (false, p),
233            _ => {
234                debug!("{} is not a websocket multiaddr", addr);
235                return Err(TransportError::MultiaddrNotSupported(addr));
236            }
237        };
238        let inner_listener = self.transport.listen_on(addr)?;
239        let listener = WsTransListener::new(inner_listener, self.inner_config.clone(), use_tls);
240        Ok(Box::new(listener))
241    }
242
243    async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
244        // Quick sanity check of the provided Multiaddr.
245        if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
246            // ok
247        } else {
248            debug!("{} is not a websocket multiaddr", addr);
249            return Err(TransportError::MultiaddrNotSupported(addr));
250        }
251
252        // We are looping here in order to follow redirects (if any):
253        let mut remaining_redirects = self.inner_config.max_redirects;
254        let mut addr = addr;
255        loop {
256            match self.dial_once(addr).await {
257                Ok(Either::Left(redirect)) => {
258                    if remaining_redirects == 0 {
259                        debug!("too many redirects");
260                        return Err(WsError::TooManyRedirects.into());
261                    }
262                    remaining_redirects -= 1;
263                    addr = location_to_multiaddr(&redirect)?;
264                }
265                Ok(Either::Right(conn)) => return Ok(conn),
266                Err(e) => {
267                    debug!("websocket transport dial error:{}", e);
268                    return Err(e.into());
269                }
270            }
271        }
272    }
273
274    fn box_clone(&self) -> ITransport<Self::Output> {
275        Box::new(self.clone())
276    }
277
278    fn protocols(&self) -> Vec<u32> {
279        vec![protocol::WS, protocol::WSS]
280    }
281}
282
283impl WsConfig {
284    /// Attempty to dial the given address and perform a websocket handshake.
285    async fn dial_once(&mut self, address: Multiaddr) -> Result<Either<String, Connection<TlsOrPlain<TcpTransStream>>>, WsError> {
286        // trace!("[Client] dial address: {}", address);
287        debug!("[Client] dial address: {}", address);
288        let (host_port, dns_name) = host_and_dnsname(&address)?;
289        if dns_name.is_some() {
290            trace!("[Client] host_port: {:?}  dns_name:{:?}", host_port, dns_name.clone().unwrap());
291        }
292        let mut inner_addr = address.clone();
293
294        let (use_tls, path) = match inner_addr.pop() {
295            Some(Protocol::Ws(path)) => (false, path),
296            Some(Protocol::Wss(path)) => {
297                if dns_name.is_none() {
298                    debug!("[Client] no DNS name in {}", address);
299                    return Err(WsError::InvalidMultiaddr(address));
300                };
301                (true, path)
302            }
303            _ => {
304                debug!("[Client] {} is not a websocket multiaddr", address);
305                return Err(WsError::InvalidMultiaddr(address));
306            }
307        };
308
309        let raw_stream = self.transport.dial(inner_addr).await.map_err(WsError::Transport)?;
310        // trace!("[Client] connected to {}", address);
311        debug!("[Client] connected to {}", address);
312        let local_addr = raw_stream.local_multiaddr();
313        let remote_addr = raw_stream.remote_multiaddr();
314        let stream = if use_tls {
315            // begin TLS session
316            let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
317            trace!("[Client] starting TLS handshake with {}", address);
318            let stream = self
319                .inner_config
320                .tls_config
321                .client
322                .connect(&dns_name, raw_stream)
323                .await
324                .map_err(|e| {
325                    debug!("[Client] TLS handshake with {} failed: {}", address, e);
326                    WsError::Tls(tls::Error::from(e))
327                })?;
328
329            let stream = TlsClientStream(stream);
330
331            let stream: TlsOrPlain<_> = EitherOutput::A(EitherOutput::A(stream));
332            stream
333        } else {
334            // continue with plain stream
335            EitherOutput::B(raw_stream)
336        };
337
338        // trace!("[Client] sending websocket handshake request to {}", address);
339        debug!("[Client] sending websocket handshake request to {}", address);
340
341        let mut client = handshake::Client::new(stream, &host_port, path.as_ref());
342
343        if self.inner_config.use_deflate {
344            client.add_extension(Box::new(Deflate::new(connection::Mode::Client)));
345        }
346
347        match client
348            .handshake()
349            .map_err(|e| {
350                info!("[Client] {:?}", e);
351                WsError::Handshake(Box::new(e))
352            })
353            .await?
354        {
355            handshake::ServerResponse::Redirect { status_code, location } => {
356                debug!("[Client] received redirect ({}); location: {}", status_code, location);
357                Ok(Either::Left(location))
358            }
359            handshake::ServerResponse::Rejected { status_code } => {
360                let msg = format!("[Client] server rejected handshake; status code = {}", status_code);
361                Err(WsError::Handshake(msg.into()))
362            }
363            handshake::ServerResponse::Accepted { .. } => {
364                debug!("[Client] websocket handshake with {} successful", address);
365                Ok(Either::Right(Connection::new(client.into_builder(), local_addr, remote_addr)))
366            }
367        }
368    }
369}
370
371impl From<WsError> for TransportError {
372    fn from(e: WsError) -> Self {
373        match e {
374            WsError::InvalidMultiaddr(a) => TransportError::MultiaddrNotSupported(a),
375            _ => TransportError::WsError(Box::new(e)),
376        }
377    }
378}
379
380// Extract host, port and optionally the DNS name from the given [`Multiaddr`].
381fn host_and_dnsname(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), WsError> {
382    let mut iter = addr.iter();
383    match (iter.next(), iter.next()) {
384        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => Ok((format!("{}:{}", ip, port), None)),
385        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => Ok((format!("{}:{}", ip, port), None)),
386        (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) => {
387            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
388        }
389        (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) => {
390            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
391        }
392        (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
393            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
394        }
395        _ => {
396            debug!("multi-address format not supported: {}", addr);
397            Err(WsError::InvalidMultiaddr(addr.clone()))
398        }
399    }
400}
401
402// Given a location URL, build a new websocket [`Multiaddr`].
403fn location_to_multiaddr(location: &str) -> Result<Multiaddr, WsError> {
404    match Url::parse(location) {
405        Ok(url) => {
406            let mut a = Multiaddr::empty();
407            match url.host() {
408                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
409                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
410                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
411                None => return Err(WsError::InvalidRedirectLocation),
412            }
413            if let Some(p) = url.port() {
414                a.push(Protocol::Tcp(p))
415            }
416            let s = url.scheme();
417            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
418                a.push(Protocol::Wss(url.path().into()))
419            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
420                a.push(Protocol::Ws(url.path().into()))
421            } else {
422                debug!("unsupported scheme: {}", s);
423                return Err(WsError::InvalidRedirectLocation);
424            }
425            Ok(a)
426        }
427        Err(e) => {
428            debug!("failed to parse url as multi-address: {:?}", e);
429            Err(WsError::InvalidRedirectLocation)
430        }
431    }
432}