1use futures_rustls::{webpki, client, server};
22use crate::{error::Error, tls};
23use either::Either;
24use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
25use libp2p_core::{
26 Transport,
27 either::EitherOutput,
28 multiaddr::{Protocol, Multiaddr},
29 transport::{ListenerEvent, TransportError}
30};
31use log::{debug, trace};
32use soketto::{connection, extension::deflate::Deflate, handshake};
33use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll};
34use url::Url;
35
/// Default size limit in bytes (256 MiB) used by `WsConfig::new` for `max_data_size`.
const MAX_DATA_SIZE: usize = 256 << 20;
38
/// A websocket transport layered on top of an inner transport `T`.
///
/// Settings are adjusted through the builder-style methods on `WsConfig`.
#[derive(Debug, Clone)]
pub struct WsConfig<T> {
    /// The inner transport used to establish the underlying connections.
    transport: T,
    /// Size limit in bytes handed to the websocket connection builder as the
    /// maximum message and frame size (defaults to `MAX_DATA_SIZE`).
    max_data_size: usize,
    /// TLS settings: the `server` part is required to listen on `/wss`, the
    /// `client` part is used when dialing `/wss`.
    tls_config: tls::Config,
    /// Number of redirects a dial follows before failing with `TooManyRedirects`.
    max_redirects: u8,
    /// Whether the deflate extension is added to websocket handshakes.
    use_deflate: bool
}
50
51impl<T> WsConfig<T> {
52 pub fn new(transport: T) -> Self {
54 WsConfig {
55 transport,
56 max_data_size: MAX_DATA_SIZE,
57 tls_config: tls::Config::client(),
58 max_redirects: 0,
59 use_deflate: false
60 }
61 }
62
63 pub fn max_redirects(&self) -> u8 {
65 self.max_redirects
66 }
67
68 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
70 self.max_redirects = max;
71 self
72 }
73
74 pub fn max_data_size(&self) -> usize {
76 self.max_data_size
77 }
78
79 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
81 self.max_data_size = size;
82 self
83 }
84
85 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
87 self.tls_config = c;
88 self
89 }
90
91 pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
93 self.use_deflate = flag;
94 self
95 }
96}
97
/// The byte stream the websocket protocol runs over: a client-side TLS stream
/// (dialing `/wss`), a server-side TLS stream (accepting on `/wss`), or the plain
/// inner stream (`/ws`).
type TlsOrPlain<T> = EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>;
99
/// Websocket transport on top of the inner transport `T`.
///
/// Listen and dial addresses must end in `/ws` or `/wss`; that trailing protocol
/// is removed before the address is handed to the inner transport.
impl<T> Transport for WsConfig<T>
where
    T: Transport + Send + Clone + 'static,
    T::Error: Send + 'static,
    T::Dial: Send + 'static,
    T::Listener: Send + 'static,
    T::ListenerUpgrade: Send + 'static,
    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
{
    type Output = Connection<T::Output>;
    type Error = Error<T::Error>;
    type Listener = BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    /// Listens on `addr`, which must end in `/ws` or `/wss`.
    ///
    /// `/wss` is rejected with `MultiaddrNotSupported` unless a TLS server config
    /// is set. Each inbound connection does a TLS handshake (for `/wss`), then the
    /// server side of the websocket handshake, accepting the request with no
    /// sub-protocol. The resulting connection uses `max_data_size` as its maximum
    /// message and frame size.
    fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
        let mut inner_addr = addr.clone();

        // Split off the websocket protocol; the remainder goes to the inner transport.
        let (use_tls, proto) = match inner_addr.pop() {
            Some(p@Protocol::Wss(_)) =>
                if self.tls_config.server.is_some() {
                    (true, p)
                } else {
                    debug!("/wss address but TLS server support is not configured");
                    return Err(TransportError::MultiaddrNotSupported(addr))
                }
            Some(p@Protocol::Ws(_)) => (false, p),
            _ => {
                debug!("{} is not a websocket multiaddr", addr);
                return Err(TransportError::MultiaddrNotSupported(addr))
            }
        };

        // Move the settings out of `self` so the `'static` listener stream owns them.
        let tls_config = self.tls_config;
        let max_size = self.max_data_size;
        let use_deflate = self.use_deflate;
        let transport = self.transport.listen_on(inner_addr).map_err(|e| e.map(Error::Transport))?;
        let listen = transport
            .map_err(Error::Transport)
            .map_ok(move |event| match event {
                // Addresses reported by the inner listener get the websocket
                // protocol appended again.
                ListenerEvent::NewAddress(mut a) => {
                    a = a.with(proto.clone());
                    debug!("Listening on {}", a);
                    ListenerEvent::NewAddress(a)
                }
                ListenerEvent::AddressExpired(mut a) => {
                    a = a.with(proto.clone());
                    ListenerEvent::AddressExpired(a)
                }
                ListenerEvent::Error(err) => {
                    ListenerEvent::Error(Error::Transport(err))
                }
                ListenerEvent::Upgrade { upgrade, mut local_addr, mut remote_addr } => {
                    local_addr = local_addr.with(proto.clone());
                    remote_addr = remote_addr.with(proto.clone());
                    // `remote1` ends up moved into the TLS error closure below, so
                    // `remote2` is kept for the log lines that follow it. The TLS
                    // config is cloned because this closure runs once per event.
                    let remote1 = remote_addr.clone();
                    let remote2 = remote_addr.clone();
                    let tls_config = tls_config.clone();

                    let upgrade = async move {
                        let stream = upgrade.map_err(Error::Transport).await?;
                        trace!("incoming connection from {}", remote1);

                        let stream =
                            if use_tls {
                                // `listen_on` only sets `use_tls` after checking
                                // that a server config is present.
                                let server = tls_config
                                    .server
                                    .expect("for use_tls we checked server is not none");

                                trace!("awaiting TLS handshake with {}", remote1);

                                let stream = server.accept(stream)
                                    .map_err(move |e| {
                                        debug!("TLS handshake with {} failed: {}", remote1, e);
                                        Error::Tls(tls::Error::from(e))
                                    })
                                    .await?;

                                let stream: TlsOrPlain<_> =
                                    EitherOutput::First(EitherOutput::Second(stream));

                                stream
                            } else {
                                // Plain `/ws`: continue with the inner stream as is.
                                EitherOutput::Second(stream)
                            };

                        trace!("receiving websocket handshake request from {}", remote2);

                        let mut server = handshake::Server::new(stream);

                        if use_deflate {
                            server.add_extension(Box::new(Deflate::new(connection::Mode::Server)));
                        }

                        let ws_key = {
                            let request = server.receive_request()
                                .map_err(|e| Error::Handshake(Box::new(e)))
                                .await?;
                            request.into_key()
                        };

                        trace!("accepting websocket handshake request from {}", remote2);

                        // Every request is accepted; no sub-protocol is selected.
                        let response =
                            handshake::server::Response::Accept {
                                key: &ws_key,
                                protocol: None
                            };

                        server.send_response(&response)
                            .map_err(|e| Error::Handshake(Box::new(e)))
                            .await?;

                        let conn = {
                            let mut builder = server.into_builder();
                            builder.set_max_message_size(max_size);
                            builder.set_max_frame_size(max_size);
                            Connection::new(builder)
                        };

                        Ok(conn)
                    };

                    ListenerEvent::Upgrade {
                        upgrade: Box::pin(upgrade) as BoxFuture<'static, _>,
                        local_addr,
                        remote_addr
                    }
                }
            });
        Ok(Box::pin(listen))
    }

    /// Dials `addr`, which must end in `/ws` or `/wss`, following up to
    /// `max_redirects` redirects before failing with `Error::TooManyRedirects`.
    fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
        // Cheap up-front check so unsupported addresses fail synchronously.
        if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
            // The address ends in a websocket protocol; proceed.
        } else {
            debug!("{} is not a websocket multiaddr", addr);
            return Err(TransportError::MultiaddrNotSupported(addr))
        }

        // Loop so that redirect responses can be followed.
        let mut remaining_redirects = self.max_redirects;
        let mut addr = addr;
        let future = async move {
            loop {
                // `dial_once` consumes its config, so each attempt gets a clone.
                let this = self.clone();
                match this.dial_once(addr).await {
                    Ok(Either::Left(redirect)) => {
                        if remaining_redirects == 0 {
                            debug!("too many redirects");
                            return Err(Error::TooManyRedirects)
                        }
                        remaining_redirects -= 1;
                        addr = location_to_multiaddr(&redirect)?
                    }
                    Ok(Either::Right(conn)) => return Ok(conn),
                    Err(e) => return Err(e)
                }
            }
        };

        Ok(Box::pin(future))
    }

    /// Delegates address translation to the inner transport.
    fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
        self.transport.address_translation(server, observed)
    }
}
270
271impl<T> WsConfig<T>
272where
273 T: Transport,
274 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static
275{
276 async fn dial_once(self, address: Multiaddr) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
278 trace!("dial address: {}", address);
279
280 let (host_port, dns_name) = host_and_dnsname(&address)?;
281
282 let mut inner_addr = address.clone();
283
284 let (use_tls, path) =
285 match inner_addr.pop() {
286 Some(Protocol::Ws(path)) => (false, path),
287 Some(Protocol::Wss(path)) => {
288 if dns_name.is_none() {
289 debug!("no DNS name in {}", address);
290 return Err(Error::InvalidMultiaddr(address))
291 }
292 (true, path)
293 }
294 _ => {
295 debug!("{} is not a websocket multiaddr", address);
296 return Err(Error::InvalidMultiaddr(address))
297 }
298 };
299
300 let dial = self.transport.dial(inner_addr)
301 .map_err(|e| match e {
302 TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
303 TransportError::Other(e) => Error::Transport(e)
304 })?;
305
306 let stream = dial.map_err(Error::Transport).await?;
307 trace!("connected to {}", address);
308
309 let stream =
310 if use_tls { let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
312 trace!("starting TLS handshake with {}", address);
313 let stream = self.tls_config.client.connect(dns_name.as_ref(), stream)
314 .map_err(|e| {
315 debug!("TLS handshake with {} failed: {}", address, e);
316 Error::Tls(tls::Error::from(e))
317 })
318 .await?;
319
320 let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream));
321 stream
322 } else { EitherOutput::Second(stream)
324 };
325
326 trace!("sending websocket handshake request to {}", address);
327
328 let mut client = handshake::Client::new(stream, &host_port, path.as_ref());
329
330 if self.use_deflate {
331 client.add_extension(Box::new(Deflate::new(connection::Mode::Client)));
332 }
333
334 match client.handshake().map_err(|e| Error::Handshake(Box::new(e))).await? {
335 handshake::ServerResponse::Redirect { status_code, location } => {
336 debug!("received redirect ({}); location: {}", status_code, location);
337 Ok(Either::Left(location))
338 }
339 handshake::ServerResponse::Rejected { status_code } => {
340 let msg = format!("server rejected handshake; status code = {}", status_code);
341 Err(Error::Handshake(msg.into()))
342 }
343 handshake::ServerResponse::Accepted { .. } => {
344 trace!("websocket handshake with {} successful", address);
345 Ok(Either::Right(Connection::new(client.into_builder())))
346 }
347 }
348 }
349}
350
351fn host_and_dnsname<T>(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), Error<T>> {
353 let mut iter = addr.iter();
354 match (iter.next(), iter.next()) {
355 (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) =>
356 Ok((format!("{}:{}", ip, port), None)),
357 (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) =>
358 Ok((format!("{}:{}", ip, port), None)),
359 (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) =>
360 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
361 (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) =>
362 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
363 (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) =>
364 Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
365 _ => {
366 debug!("multi-address format not supported: {}", addr);
367 Err(Error::InvalidMultiaddr(addr.clone()))
368 }
369 }
370}
371
372fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
374 match Url::parse(location) {
375 Ok(url) => {
376 let mut a = Multiaddr::empty();
377 match url.host() {
378 Some(url::Host::Domain(h)) => {
379 a.push(Protocol::Dns(h.into()))
380 }
381 Some(url::Host::Ipv4(ip)) => {
382 a.push(Protocol::Ip4(ip))
383 }
384 Some(url::Host::Ipv6(ip)) => {
385 a.push(Protocol::Ip6(ip))
386 }
387 None => return Err(Error::InvalidRedirectLocation)
388 }
389 if let Some(p) = url.port() {
390 a.push(Protocol::Tcp(p))
391 }
392 let s = url.scheme();
393 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
394 a.push(Protocol::Wss(url.path().into()))
395 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
396 a.push(Protocol::Ws(url.path().into()))
397 } else {
398 debug!("unsupported scheme: {}", s);
399 return Err(Error::InvalidRedirectLocation)
400 }
401 Ok(a)
402 }
403 Err(e) => {
404 debug!("failed to parse url as multi-address: {:?}", e);
405 Err(Error::InvalidRedirectLocation)
406 }
407 }
408}
409
/// A websocket connection produced by `WsConfig`, for an inner stream type `T`.
///
/// It is used as a `Stream` of `IncomingData` and a `Sink` of `OutgoingData`,
/// with websocket errors surfaced as `io::Error`.
pub struct Connection<T> {
    /// Received items, produced by the receive loop built in `Connection::new`.
    receiver: BoxStream<'static, Result<IncomingData, connection::Error>>,
    /// Sink that forwards outgoing data, pings and pongs to the websocket sender.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = connection::Error> + Send>>,
    /// Ties the type to `T` even though no `T` is stored directly.
    _marker: std::marker::PhantomData<T>
}
416
/// An item received over a websocket connection.
#[derive(Debug, Clone)]
pub enum IncomingData {
    /// Payload of a binary message.
    Binary(Vec<u8>),
    /// Payload of a text message, kept as raw bytes.
    Text(Vec<u8>),
    /// Payload of a PONG control frame.
    Pong(Vec<u8>)
}

impl IncomingData {
    /// `true` for application data, i.e. `Binary` or `Text`.
    pub fn is_data(&self) -> bool {
        self.is_binary() || self.is_text()
    }

    /// `true` if this is a binary message.
    pub fn is_binary(&self) -> bool {
        matches!(self, IncomingData::Binary(_))
    }

    /// `true` if this is a text message.
    pub fn is_text(&self) -> bool {
        matches!(self, IncomingData::Text(_))
    }

    /// `true` if this is a PONG payload.
    pub fn is_pong(&self) -> bool {
        matches!(self, IncomingData::Pong(_))
    }

    /// Consumes the item and returns its payload, whatever the variant.
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            IncomingData::Binary(d) | IncomingData::Text(d) | IncomingData::Pong(d) => d
        }
    }
}

impl AsRef<[u8]> for IncomingData {
    /// Borrows the payload, whatever the variant.
    fn as_ref(&self) -> &[u8] {
        match self {
            IncomingData::Binary(d) | IncomingData::Text(d) | IncomingData::Pong(d) => d
        }
    }
}
463
/// An item to send over a websocket connection.
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Payload sent as a binary message.
    Binary(Vec<u8>),
    /// Payload of a PING control frame; `Connection` rejects it with an
    /// `InvalidInput` error unless it is shorter than 126 bytes.
    Ping(Vec<u8>),
    /// Payload of a PONG control frame; `Connection` rejects it with an
    /// `InvalidInput` error unless it is shorter than 126 bytes.
    Pong(Vec<u8>)
}
475
476impl<T> fmt::Debug for Connection<T> {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 f.write_str("Connection")
479 }
480}
481
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static
{
    /// Builds a `Connection` from a soketto connection builder.
    ///
    /// The sender half is wrapped in a sink that maps each `OutgoingData` item,
    /// and each flush/close request, onto the corresponding sender call. The
    /// receiver half is turned into a stream of `IncomingData`. That stream ends
    /// when the receiver reports `connection::Error::Closed`; any other error is
    /// yielded as an item, and receiving continues afterwards.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // Converting to the sender's control-frame payload type fails
                    // for payloads that are too long; report that as `InvalidInput`.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    // Same payload-length check as for PING.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?
            }
            Ok(sender)
        });
        // The unfold state is (message buffer, receiver). `mem::take` moves the
        // received payload out and leaves an empty `Vec` for the next message.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => {
                    Some((Ok(IncomingData::Text(mem::take(&mut data))), (data, receiver)))
                }
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => {
                    Some((Ok(IncomingData::Binary(mem::take(&mut data))), (data, receiver)))
                }
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(IncomingData::Pong(Vec::from(pong))), (data, receiver)))
                }
                // A closed connection ends the stream.
                Err(connection::Error::Closed) => None,
                Err(e) => Some((Err(e), (data, receiver)))
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData
        }
    }

    /// Sends `data` as a binary message.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Sends a PING frame with `data` as payload.
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Sends a PONG frame with `data` as payload.
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
547
548impl<T> Stream for Connection<T>
549where
550 T: AsyncRead + AsyncWrite + Send + Unpin + 'static
551{
552 type Item = io::Result<IncomingData>;
553
554 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
555 let item = ready!(self.receiver.poll_next_unpin(cx));
556 let item = item.map(|result| {
557 result.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
558 });
559 Poll::Ready(item)
560 }
561}
562
563impl<T> Sink<OutgoingData> for Connection<T>
564where
565 T: AsyncRead + AsyncWrite + Send + Unpin + 'static
566{
567 type Error = io::Error;
568
569 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
570 Pin::new(&mut self.sender)
571 .poll_ready(cx)
572 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
573 }
574
575 fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
576 Pin::new(&mut self.sender)
577 .start_send(item)
578 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
579 }
580
581 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
582 Pin::new(&mut self.sender)
583 .poll_flush(cx)
584 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
585 }
586
587 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
588 Pin::new(&mut self.sender)
589 .poll_close(cx)
590 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
591 }
592}