1use crate::connection::{Connection, TlsClientStream, TlsOrPlain, TlsServerStream};
24use crate::{error::WsError, tls};
25use async_trait::async_trait;
26use either::Either;
27use futures::prelude::*;
28use libp2prs_core::transport::{ConnectionInfo, ListenerEvent};
29use libp2prs_core::transport::{IListener, ITransport};
30use libp2prs_core::{
31 either::EitherOutput,
32 multiaddr::{protocol, protocol::Protocol, Multiaddr},
33 transport::{TransportError, TransportListener},
34 Transport,
35};
36use libp2prs_tcp::TcpTransStream;
37use log::{debug, info, trace};
38use soketto::{connection, extension::deflate::Deflate, handshake};
39use std::fmt;
40use url::Url;
41
/// Default upper bound, in bytes, for a WebSocket frame and a WebSocket
/// message (256 MiB); used as the initial `InnerConfig::max_data_size`.
const MAX_DATA_SIZE: usize = 256 << 20;
44
45#[derive(Clone)]
50pub struct WsConfig {
51 transport: ITransport<TcpTransStream>,
52 pub(crate) inner_config: InnerConfig,
53}
54
55impl fmt::Debug for WsConfig {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("WsConfig").field("Config", &self.inner_config).finish()
58 }
59}
60
61impl WsConfig {
62 pub fn new(transport: ITransport<TcpTransStream>) -> Self {
64 WsConfig {
65 transport,
66 inner_config: InnerConfig::new(),
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
72pub(crate) struct InnerConfig {
73 max_data_size: usize,
74 tls_config: tls::Config,
75 max_redirects: u8,
76 use_deflate: bool,
77}
78
79impl InnerConfig {
80 pub fn new() -> Self {
82 InnerConfig {
83 max_data_size: MAX_DATA_SIZE,
84 tls_config: tls::Config::client(),
85 max_redirects: 0,
86 use_deflate: false,
87 }
88 }
89
90 pub fn max_redirects(&self) -> u8 {
92 self.max_redirects
93 }
94
95 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
97 self.max_redirects = max;
98 self
99 }
100
101 pub fn max_data_size(&self) -> usize {
103 self.max_data_size
104 }
105
106 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
108 self.max_data_size = size;
109 self
110 }
111
112 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
114 self.tls_config = c;
115 self
116 }
117
118 pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
120 self.use_deflate = flag;
121 self
122 }
123}
124
125pub struct WsTransListener {
126 inner: IListener<TcpTransStream>,
127 inner_config: InnerConfig,
128 use_tls: bool,
129}
130
131impl fmt::Debug for WsTransListener {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 f.debug_struct("WsTransListener")
134 .field("Config", &self.inner_config)
135 .field("tls", &self.use_tls)
136 .finish()
137 }
138}
139
140impl WsTransListener {
141 pub(crate) fn new(inner: IListener<TcpTransStream>, inner_config: InnerConfig, use_tls: bool) -> Self {
142 Self {
143 inner,
144 inner_config,
145 use_tls,
146 }
147 }
148}
149
150#[async_trait]
151impl TransportListener for WsTransListener {
152 type Output = Connection<TlsOrPlain<TcpTransStream>>;
153 async fn accept(&mut self) -> Result<ListenerEvent<Self::Output>, TransportError> {
154 let raw_stream = match self.inner.accept().await? {
155 ListenerEvent::Accepted(stream) => stream,
156 ListenerEvent::AddressAdded(a) => return Ok(ListenerEvent::AddressAdded(a)),
157 ListenerEvent::AddressDeleted(a) => return Ok(ListenerEvent::AddressDeleted(a)),
158 };
159 let local_addr = raw_stream.local_multiaddr();
160 let remote_addr = raw_stream.remote_multiaddr();
161 let remote1 = remote_addr.clone(); let remote2 = remote_addr.clone(); let tls_config = self.inner_config.tls_config.clone();
164 trace!("[Server] incoming connection from {}", remote1);
165 let stream = if self.use_tls {
166 let server = tls_config.server.expect("for use_tls we checked server is not none");
168 trace!("[Server] awaiting TLS handshake with {}", remote1);
169 let stream = server.accept(raw_stream).await.map_err(move |e| {
170 debug!("[Server] TLS handshake with {} failed: {}", remote1, e);
171 WsError::Tls(tls::Error::from(e))
172 })?;
173
174 let stream: TlsOrPlain<_> = EitherOutput::A(EitherOutput::B(TlsServerStream(stream)));
175 stream
176 } else {
177 EitherOutput::B(raw_stream)
179 };
180
181 trace!("[Server] receiving websocket handshake request from {}", remote2);
182 let mut server = handshake::Server::new(stream);
183
184 if self.inner_config.use_deflate {
185 server.add_extension(Box::new(Deflate::new(connection::Mode::Server)));
186 }
187
188 let ws_key = {
189 let request = server.receive_request().await.map_err(|e| WsError::Handshake(Box::new(e)))?;
190 request.into_key()
191 };
192
193 debug!("[Server] accepting websocket handshake request from {}", remote2);
194
195 let response = handshake::server::Response::Accept {
196 key: &ws_key,
197 protocol: None,
198 };
199
200 server.send_response(&response).await.map_err(|e| WsError::Handshake(Box::new(e)))?;
201
202 let conn = {
203 let mut builder = server.into_builder();
204 builder.set_max_message_size(self.inner_config.max_data_size);
205 builder.set_max_frame_size(self.inner_config.max_data_size);
206 Connection::new(builder, local_addr, remote_addr)
207 };
208 Ok(ListenerEvent::Accepted(conn))
209 }
210
211 fn multi_addr(&self) -> Option<&Multiaddr> {
212 self.inner.multi_addr()
213 }
214}
215
216#[async_trait]
217impl Transport for WsConfig {
218 type Output = Connection<TlsOrPlain<TcpTransStream>>;
219 fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
220 log::debug!("WebSocket listen on addr: {}", addr);
221 let mut inner_addr = addr.clone();
222
223 let (use_tls, _proto) = match inner_addr.pop() {
224 Some(p @ Protocol::Wss(_)) => {
225 if self.inner_config.tls_config.server.is_some() {
226 (true, p)
227 } else {
228 debug!("/wss address but TLS server support is not configured");
229 return Err(TransportError::MultiaddrNotSupported(addr));
230 }
231 }
232 Some(p @ Protocol::Ws(_)) => (false, p),
233 _ => {
234 debug!("{} is not a websocket multiaddr", addr);
235 return Err(TransportError::MultiaddrNotSupported(addr));
236 }
237 };
238 let inner_listener = self.transport.listen_on(addr)?;
239 let listener = WsTransListener::new(inner_listener, self.inner_config.clone(), use_tls);
240 Ok(Box::new(listener))
241 }
242
243 async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
244 if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
246 } else {
248 debug!("{} is not a websocket multiaddr", addr);
249 return Err(TransportError::MultiaddrNotSupported(addr));
250 }
251
252 let mut remaining_redirects = self.inner_config.max_redirects;
254 let mut addr = addr;
255 loop {
256 match self.dial_once(addr).await {
257 Ok(Either::Left(redirect)) => {
258 if remaining_redirects == 0 {
259 debug!("too many redirects");
260 return Err(WsError::TooManyRedirects.into());
261 }
262 remaining_redirects -= 1;
263 addr = location_to_multiaddr(&redirect)?;
264 }
265 Ok(Either::Right(conn)) => return Ok(conn),
266 Err(e) => {
267 debug!("websocket transport dial error:{}", e);
268 return Err(e.into());
269 }
270 }
271 }
272 }
273
274 fn box_clone(&self) -> ITransport<Self::Output> {
275 Box::new(self.clone())
276 }
277
278 fn protocols(&self) -> Vec<u32> {
279 vec![protocol::WS, protocol::WSS]
280 }
281}
282
283impl WsConfig {
284 async fn dial_once(&mut self, address: Multiaddr) -> Result<Either<String, Connection<TlsOrPlain<TcpTransStream>>>, WsError> {
286 debug!("[Client] dial address: {}", address);
288 let (host_port, dns_name) = host_and_dnsname(&address)?;
289 if dns_name.is_some() {
290 trace!("[Client] host_port: {:?} dns_name:{:?}", host_port, dns_name.clone().unwrap());
291 }
292 let mut inner_addr = address.clone();
293
294 let (use_tls, path) = match inner_addr.pop() {
295 Some(Protocol::Ws(path)) => (false, path),
296 Some(Protocol::Wss(path)) => {
297 if dns_name.is_none() {
298 debug!("[Client] no DNS name in {}", address);
299 return Err(WsError::InvalidMultiaddr(address));
300 };
301 (true, path)
302 }
303 _ => {
304 debug!("[Client] {} is not a websocket multiaddr", address);
305 return Err(WsError::InvalidMultiaddr(address));
306 }
307 };
308
309 let raw_stream = self.transport.dial(inner_addr).await.map_err(WsError::Transport)?;
310 debug!("[Client] connected to {}", address);
312 let local_addr = raw_stream.local_multiaddr();
313 let remote_addr = raw_stream.remote_multiaddr();
314 let stream = if use_tls {
315 let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
317 trace!("[Client] starting TLS handshake with {}", address);
318 let stream = self
319 .inner_config
320 .tls_config
321 .client
322 .connect(&dns_name, raw_stream)
323 .await
324 .map_err(|e| {
325 debug!("[Client] TLS handshake with {} failed: {}", address, e);
326 WsError::Tls(tls::Error::from(e))
327 })?;
328
329 let stream = TlsClientStream(stream);
330
331 let stream: TlsOrPlain<_> = EitherOutput::A(EitherOutput::A(stream));
332 stream
333 } else {
334 EitherOutput::B(raw_stream)
336 };
337
338 debug!("[Client] sending websocket handshake request to {}", address);
340
341 let mut client = handshake::Client::new(stream, &host_port, path.as_ref());
342
343 if self.inner_config.use_deflate {
344 client.add_extension(Box::new(Deflate::new(connection::Mode::Client)));
345 }
346
347 match client
348 .handshake()
349 .map_err(|e| {
350 info!("[Client] {:?}", e);
351 WsError::Handshake(Box::new(e))
352 })
353 .await?
354 {
355 handshake::ServerResponse::Redirect { status_code, location } => {
356 debug!("[Client] received redirect ({}); location: {}", status_code, location);
357 Ok(Either::Left(location))
358 }
359 handshake::ServerResponse::Rejected { status_code } => {
360 let msg = format!("[Client] server rejected handshake; status code = {}", status_code);
361 Err(WsError::Handshake(msg.into()))
362 }
363 handshake::ServerResponse::Accepted { .. } => {
364 debug!("[Client] websocket handshake with {} successful", address);
365 Ok(Either::Right(Connection::new(client.into_builder(), local_addr, remote_addr)))
366 }
367 }
368 }
369}
370
371impl From<WsError> for TransportError {
372 fn from(e: WsError) -> Self {
373 match e {
374 WsError::InvalidMultiaddr(a) => TransportError::MultiaddrNotSupported(a),
375 _ => TransportError::WsError(Box::new(e)),
376 }
377 }
378}
379
380fn host_and_dnsname(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), WsError> {
382 let mut iter = addr.iter();
383 match (iter.next(), iter.next()) {
384 (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => Ok((format!("{}:{}", ip, port), None)),
385 (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => Ok((format!("{}:{}", ip, port), None)),
386 (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) => {
387 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
388 }
389 (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) => {
390 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
391 }
392 (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
393 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())))
394 }
395 _ => {
396 debug!("multi-address format not supported: {}", addr);
397 Err(WsError::InvalidMultiaddr(addr.clone()))
398 }
399 }
400}
401
402fn location_to_multiaddr(location: &str) -> Result<Multiaddr, WsError> {
404 match Url::parse(location) {
405 Ok(url) => {
406 let mut a = Multiaddr::empty();
407 match url.host() {
408 Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
409 Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
410 Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
411 None => return Err(WsError::InvalidRedirectLocation),
412 }
413 if let Some(p) = url.port() {
414 a.push(Protocol::Tcp(p))
415 }
416 let s = url.scheme();
417 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
418 a.push(Protocol::Wss(url.path().into()))
419 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
420 a.push(Protocol::Ws(url.path().into()))
421 } else {
422 debug!("unsupported scheme: {}", s);
423 return Err(WsError::InvalidRedirectLocation);
424 }
425 Ok(a)
426 }
427 Err(e) => {
428 debug!("failed to parse url as multi-address: {:?}", e);
429 Err(WsError::InvalidRedirectLocation)
430 }
431 }
432}