libp2prs_websocket/
lib.rs1pub mod connection;
26pub mod error;
27pub mod framed;
28pub mod tls;
29
30use async_trait::async_trait;
31
32use libp2prs_core::transport::{IListener, ITransport};
33use libp2prs_core::{multiaddr::Multiaddr, transport::TransportError, Transport};
34use libp2prs_dns::DnsConfig;
35use libp2prs_tcp::{TcpConfig, TcpTransStream};
36
37use crate::connection::TlsOrPlain;
38
39#[derive(Clone)]
41pub struct WsConfig {
42 inner: framed::WsConfig,
43}
44
45impl WsConfig {
46 pub fn new() -> Self {
48 framed::WsConfig::new(TcpConfig::default().box_clone()).into()
49 }
50
51 pub fn new_with_dns() -> Self {
53 framed::WsConfig::new(DnsConfig::new(TcpConfig::default()).box_clone()).into()
54 }
55
56 pub fn max_redirects(&self) -> u8 {
58 self.inner.inner_config.max_redirects()
59 }
60
61 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
63 self.inner.inner_config.set_max_redirects(max);
64 self
65 }
66
67 pub fn max_data_size(&self) -> usize {
69 self.inner.inner_config.max_data_size()
70 }
71
72 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
74 self.inner.inner_config.set_max_data_size(size);
75 self
76 }
77
78 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
80 self.inner.inner_config.set_tls_config(c);
81 self
82 }
83
84 pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
86 self.inner.inner_config.use_deflate(flag);
87 self
88 }
89}
90
91impl Default for WsConfig {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl From<framed::WsConfig> for WsConfig {
98 fn from(framed: framed::WsConfig) -> Self {
99 WsConfig { inner: framed }
100 }
101}
102
103#[async_trait]
104impl Transport for WsConfig {
105 type Output = connection::Connection<TlsOrPlain<TcpTransStream>>;
106 fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
107 self.inner.listen_on(addr)
108 }
109
110 async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
111 self.inner.dial(addr).await
112 }
113
114 fn box_clone(&self) -> ITransport<Self::Output> {
115 Box::new(self.clone())
116 }
117
118 fn protocols(&self) -> Vec<u32> {
119 self.inner.protocols()
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::channel::oneshot;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use libp2prs_core::transport::ListenerEvent;
    use libp2prs_core::Multiaddr;
    use libp2prs_core::Transport;
    use libp2prs_runtime::task;
    use std::time::Duration;

    /// Timeout applied to both the listening and the dialing transport.
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// Bytes the client sends and the server expects to receive.
    const PAYLOAD: [u8; 3] = [1, 23, 5];

    /// Installs a debug-level test logger.
    ///
    /// Uses `try_init` because several tests in the same process may call this;
    /// `init` panics when a logger is already installed.
    fn init_logger() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Debug)
            .try_init();
    }

    /// Runs a server listening on `listen_addr` and a client dialing
    /// `dial_addr` concurrently, and asserts that both report success.
    ///
    /// `dns` makes the client use [`WsConfig::new_with_dns`].
    fn run_round_trip(listen_addr: &str, dial_addr: &str, dns: bool) {
        init_logger();
        let listen_addr: Multiaddr = listen_addr.parse().expect("valid listen multiaddr");
        let dial_addr: Multiaddr = dial_addr.parse().expect("valid dial multiaddr");

        // The server signals once its listener is bound, so the client never
        // dials before there is something listening. This replaces a fixed
        // sleep, which made the tests timing-dependent.
        let (ready_tx, ready_rx) = oneshot::channel();
        let s = task::spawn(async move { server(listen_addr, ready_tx).await });
        let c = task::spawn(async move { client(dial_addr, dns, ready_rx).await });
        task::block_on(async {
            assert_eq!(futures::join!(s, c), (Some(true), Some(true)));
        });
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        run_round_trip("/ip4/127.0.0.1/tcp/38099/ws", "/ip4/127.0.0.1/tcp/38099/ws", false);
    }

    #[test]
    fn dialer_connects_to_listener_dns() {
        run_round_trip("/ip4/127.0.0.1/tcp/38100/ws", "/dns4/localhost/tcp/38100/ws", true);
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        run_round_trip("/ip6/::1/tcp/38101/ws", "/ip6/::1/tcp/38101/ws", false);
    }

    /// Listens on `listen_addr`, signals `ready`, accepts one connection, and
    /// returns whether the first bytes read equal [`PAYLOAD`].
    async fn server(listen_addr: Multiaddr, ready: oneshot::Sender<()>) -> bool {
        let mut listener = WsConfig::new()
            .timeout(TIMEOUT)
            .listen_on(listen_addr)
            .expect("listener");
        // A send error only means the client task has already gone away; that
        // failure is reported through the client's own result.
        let _ = ready.send(());

        let mut stream = match listener.accept().await.expect("no error") {
            ListenerEvent::Accepted(s) => s,
            _ => panic!("expected an accepted connection"),
        };
        let mut buf = vec![0_u8; PAYLOAD.len()];
        stream.read_exact(&mut buf).await.expect("read_exact");
        log::info!("{:?}", buf);
        buf == PAYLOAD
    }

    /// Waits for the server to be ready, dials `dial_addr`, writes [`PAYLOAD`],
    /// and returns whether closing the connection succeeded.
    async fn client(dial_addr: Multiaddr, dns: bool, ready: oneshot::Receiver<()>) -> bool {
        let ws_config = if dns { WsConfig::new_with_dns() } else { WsConfig::new() };
        ready.await.expect("server dropped before its listener was ready");

        let mut conn = ws_config
            .timeout(TIMEOUT)
            .dial(dial_addr)
            .await
            .expect("dial");
        log::debug!("[Client] write data {:?}", PAYLOAD);
        conn.write_all(&PAYLOAD).await.expect("write all");
        conn.close().await.is_ok()
    }
}