// libp2p_websocket/lib.rs

// Copyright 2017-2019 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
//! Implementation of the libp2p `Transport` trait for WebSockets.
22
23pub mod error;
24pub mod framed;
25pub mod tls;
26
27use error::Error;
28use framed::Connection;
29use futures::{future::BoxFuture, prelude::*, stream::BoxStream, ready};
30use libp2p_core::{
31    ConnectedPoint,
32    Transport,
33    multiaddr::Multiaddr,
34    transport::{map::{MapFuture, MapStream}, ListenerEvent, TransportError}
35};
36use rw_stream_sink::RwStreamSink;
37use std::{io, pin::Pin, task::{Context, Poll}};
38
39/// A Websocket transport.
40#[derive(Debug, Clone)]
41pub struct WsConfig<T> {
42    transport: framed::WsConfig<T>
43}
44
45impl<T> WsConfig<T> {
46    /// Create a new websocket transport based on the given transport.
47    pub fn new(transport: T) -> Self {
48        framed::WsConfig::new(transport).into()
49    }
50
51    /// Return the configured maximum number of redirects.
52    pub fn max_redirects(&self) -> u8 {
53        self.transport.max_redirects()
54    }
55
56    /// Set max. number of redirects to follow.
57    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
58        self.transport.set_max_redirects(max);
59        self
60    }
61
62    /// Get the max. frame data size we support.
63    pub fn max_data_size(&self) -> usize {
64        self.transport.max_data_size()
65    }
66
67    /// Set the max. frame data size we support.
68    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
69        self.transport.set_max_data_size(size);
70        self
71    }
72
73    /// Set the TLS configuration if TLS support is desired.
74    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
75        self.transport.set_tls_config(c);
76        self
77    }
78
79    /// Should the deflate extension (RFC 7692) be used if supported?
80    pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
81        self.transport.use_deflate(flag);
82        self
83    }
84}
85
86impl<T> From<framed::WsConfig<T>> for WsConfig<T> {
87    fn from(framed: framed::WsConfig<T>) -> Self {
88        WsConfig {
89            transport: framed
90        }
91    }
92}
93
94impl<T> Transport for WsConfig<T>
95where
96    T: Transport + Send + Clone + 'static,
97    T::Error: Send + 'static,
98    T::Dial: Send + 'static,
99    T::Listener: Send + 'static,
100    T::ListenerUpgrade: Send + 'static,
101    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
102{
103    type Output = RwStreamSink<BytesConnection<T::Output>>;
104    type Error = Error<T::Error>;
105    type Listener = MapStream<InnerStream<T::Output, T::Error>, WrapperFn<T::Output>>;
106    type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
107    type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
108
109    fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
110        self.transport.map(wrap_connection as WrapperFn<T::Output>).listen_on(addr)
111    }
112
113    fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
114        self.transport.map(wrap_connection as WrapperFn<T::Output>).dial(addr)
115    }
116
117    fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
118        self.transport.address_translation(server, observed)
119    }
120}
121
122/// Type alias corresponding to `framed::WsConfig::Listener`.
123pub type InnerStream<T, E> = BoxStream<'static, Result<ListenerEvent<InnerFuture<T, E>, Error<E>>, Error<E>>>;
124
125/// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`.
126pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
127
128/// Function type that wraps a websocket connection (see. `wrap_connection`).
129pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
130
131/// Wrap a websocket connection producing data frames into a `RwStreamSink`
132/// implementing `AsyncRead` + `AsyncWrite`.
133fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
134where
135    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
136{
137    RwStreamSink::new(BytesConnection(c))
138}
139
140/// The websocket connection.
141#[derive(Debug)]
142pub struct BytesConnection<T>(Connection<T>);
143
144impl<T> Stream for BytesConnection<T>
145where
146    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
147{
148    type Item = io::Result<Vec<u8>>;
149
150    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151        loop {
152            if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
153                if item.is_data() {
154                    return Poll::Ready(Some(Ok(item.into_bytes())))
155                }
156            } else {
157                return Poll::Ready(None)
158            }
159        }
160    }
161}
162
163impl<T> Sink<Vec<u8>> for BytesConnection<T>
164where
165    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
166{
167    type Error = io::Error;
168
169    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
170        Pin::new(&mut self.0).poll_ready(cx)
171    }
172
173    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
174        Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
175    }
176
177    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        Pin::new(&mut self.0).poll_flush(cx)
179    }
180
181    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182        Pin::new(&mut self.0).poll_close(cx)
183    }
184}
185
// Tests //////////////////////////////////////////////////////////////////////////////////////////
187
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::prelude::*;
    use libp2p_core::{multiaddr::Protocol, Multiaddr, Transport};
    use libp2p_tcp as tcp;

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        let listen_addr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        let listen_addr = "/ip6/::1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    /// Listens on `listen_addr`, dials the reported listen address and checks
    /// that both the inbound and the outbound side complete without error.
    async fn connect(listen_addr: Multiaddr) {
        let transport = WsConfig::new(tcp::TcpConfig::new());

        let mut listener = transport.clone().listen_on(listen_addr).expect("listener");

        // The first listener event must report the address actually bound.
        let addr = listener
            .try_next()
            .await
            .expect("some event")
            .expect("no error")
            .into_new_address()
            .expect("listen address");

        // The `/ws` component is kept, and port 0 has been replaced by a real port.
        assert_eq!(Some(Protocol::Ws("/".into())), addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1));

        let inbound = async move {
            let (upgrade, _remote_addr) = listener
                .try_filter_map(|event| future::ready(Ok(event.into_upgrade())))
                .try_next()
                .await
                .unwrap()
                .unwrap();
            upgrade.await
        };

        let outbound = transport.dial(addr).unwrap();

        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}