libp2prs_websocket/
lib.rs

1// Copyright 2017-2019 Parity Technologies (UK) Ltd.
2// Copyright 2020 Netwarps Ltd.
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Implementation of the libp2p `Transport` trait for Websockets.
23
24// pub mod connection;
25pub mod connection;
26pub mod error;
27pub mod framed;
28pub mod tls;
29
30use async_trait::async_trait;
31
32use libp2prs_core::transport::{IListener, ITransport};
33use libp2prs_core::{multiaddr::Multiaddr, transport::TransportError, Transport};
34use libp2prs_dns::DnsConfig;
35use libp2prs_tcp::{TcpConfig, TcpTransStream};
36
37use crate::connection::TlsOrPlain;
38
39/// A Websocket transport.
40#[derive(Clone)]
41pub struct WsConfig {
42    inner: framed::WsConfig,
43}
44
45impl WsConfig {
46    /// Create a new websocket transport based on the tcp transport.
47    pub fn new() -> Self {
48        framed::WsConfig::new(TcpConfig::default().box_clone()).into()
49    }
50
51    /// Create a new websocket transport based on the dns transport.
52    pub fn new_with_dns() -> Self {
53        framed::WsConfig::new(DnsConfig::new(TcpConfig::default()).box_clone()).into()
54    }
55
56    /// Return the configured maximum number of redirects.
57    pub fn max_redirects(&self) -> u8 {
58        self.inner.inner_config.max_redirects()
59    }
60
61    /// Set max. number of redirects to follow.
62    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
63        self.inner.inner_config.set_max_redirects(max);
64        self
65    }
66
67    /// Get the max. frame data size we support.
68    pub fn max_data_size(&self) -> usize {
69        self.inner.inner_config.max_data_size()
70    }
71
72    /// Set the max. frame data size we support.
73    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
74        self.inner.inner_config.set_max_data_size(size);
75        self
76    }
77
78    /// Set the TLS configuration if TLS support is desired.
79    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
80        self.inner.inner_config.set_tls_config(c);
81        self
82    }
83
84    /// Should the deflate extension (RFC 7692) be used if supported?
85    pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
86        self.inner.inner_config.use_deflate(flag);
87        self
88    }
89}
90
91impl Default for WsConfig {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl From<framed::WsConfig> for WsConfig {
98    fn from(framed: framed::WsConfig) -> Self {
99        WsConfig { inner: framed }
100    }
101}
102
103#[async_trait]
104impl Transport for WsConfig {
105    type Output = connection::Connection<TlsOrPlain<TcpTransStream>>;
106    fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
107        self.inner.listen_on(addr)
108    }
109
110    async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
111        self.inner.dial(addr).await
112    }
113
114    fn box_clone(&self) -> ITransport<Self::Output> {
115        Box::new(self.clone())
116    }
117
118    fn protocols(&self) -> Vec<u32> {
119        self.inner.protocols()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use libp2prs_core::transport::ListenerEvent;
    use libp2prs_core::Multiaddr;
    use libp2prs_core::Transport;
    use libp2prs_runtime::task;
    use std::time::Duration;

    /// Timeout applied to both the listening and the dialing transport.
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// Bytes the client writes and the server expects to read back.
    const PAYLOAD: [u8; 3] = [1, 23, 5];

    /// Installs the test logger.
    ///
    /// `try_init` is used rather than `init` because a global logger can only
    /// be installed once per process. `init` panics on a second call, and tests
    /// share a process, so every test can now safely request logging.
    fn init_logger() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Debug)
            .try_init();
    }

    /// Spawns a server listening on `listen_addr` and a client dialing
    /// `dial_addr`, then asserts that both tasks complete and report success.
    ///
    /// `dns` selects the DNS-capable client transport.
    fn run_pair(listen_addr: &str, dial_addr: &str, dns: bool) {
        init_logger();
        let listen_addr: Multiaddr = listen_addr.parse().expect("valid listen multiaddr");
        let dial_addr: Multiaddr = dial_addr.parse().expect("valid dial multiaddr");
        let s = task::spawn(server(listen_addr));
        let c = task::spawn(client(dial_addr, dns));
        task::block_on(async {
            assert_eq!(futures::join!(s, c), (Some(true), Some(true)));
        });
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        run_pair("/ip4/127.0.0.1/tcp/38099/ws", "/ip4/127.0.0.1/tcp/38099/ws", false);
    }

    #[test]
    fn dialer_connects_to_listener_dns() {
        run_pair("/ip4/127.0.0.1/tcp/38100/ws", "/dns4/localhost/tcp/38100/ws", true);
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        run_pair("/ip6/::1/tcp/38101/ws", "/ip6/::1/tcp/38101/ws", false);
    }

    /// Accepts one websocket connection on `listen_addr` and returns whether
    /// the first bytes read from it equal `PAYLOAD`.
    async fn server(listen_addr: Multiaddr) -> bool {
        let mut listener = WsConfig::new()
            .timeout(TIMEOUT)
            .listen_on(listen_addr)
            .expect("listener");

        let mut stream = match listener.accept().await.expect("no error") {
            ListenerEvent::Accepted(s) => s,
            _ => panic!("unreachable"),
        };

        let mut buf = vec![0_u8; PAYLOAD.len()];
        stream.read_exact(&mut buf).await.expect("read_exact");
        log::info!("{:?}", buf);
        buf == PAYLOAD
    }

    /// Dials `dial_addr`, writes `PAYLOAD`, and returns whether closing the
    /// connection succeeded. `dns` selects `WsConfig::new_with_dns`.
    async fn client(dial_addr: Multiaddr, dns: bool) -> bool {
        let ws_config = if dns { WsConfig::new_with_dns() } else { WsConfig::new() };
        // Best-effort wait for the server task to bind its listener. This is a
        // delay, not a synchronisation point, so a heavily loaded machine can
        // still race here.
        task::sleep(Duration::from_millis(200)).await;
        let mut conn = ws_config.timeout(TIMEOUT).dial(dial_addr).await.expect("dial");
        log::debug!("[Client] write data {:?}", PAYLOAD);
        conn.write_all(&PAYLOAD).await.expect("write all");
        conn.close().await.is_ok()
    }
}