1use ant_libp2p_core as libp2p_core;
22
23use std::{
24 borrow::Cow,
25 collections::HashMap,
26 fmt, io, mem,
27 net::IpAddr,
28 ops::DerefMut,
29 pin::Pin,
30 sync::Arc,
31 task::{Context, Poll},
32};
33
34use either::Either;
35use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
36use futures_rustls::{client, rustls::pki_types::ServerName, server};
37use libp2p_core::{
38 multiaddr::{Multiaddr, Protocol},
39 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
40 Transport,
41};
42use parking_lot::Mutex;
43use soketto::{
44 connection::{self, CloseReason},
45 handshake,
46};
47use url::Url;
48
49use crate::{error::Error, quicksink, tls};
50
/// Default data size limit in bytes (256 MiB) that a new [`WsConfig`] starts with.
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
53
/// Websocket transport configuration wrapping an inner transport `T`.
///
/// The inner transport is kept behind an `Arc<Mutex<_>>` so that dial futures
/// can hold their own handle to it.
#[derive(Debug)]
pub struct WsConfig<T> {
    /// The wrapped transport, shared with in-flight dial futures.
    transport: Arc<Mutex<T>>,
    /// Size limit in bytes handed to the websocket connection builder as both
    /// the maximum message size and the maximum frame size.
    max_data_size: usize,
    /// TLS configuration used for secure websocket connections.
    tls_config: tls::Config,
    /// Number of redirects a single dial attempt may follow.
    max_redirects: u8,
    /// Websocket protocol flavour each active listener was started with.
    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
}
66
67impl<T> WsConfig<T>
68where
69 T: Send,
70{
71 pub fn new(transport: T) -> Self {
73 WsConfig {
74 transport: Arc::new(Mutex::new(transport)),
75 max_data_size: MAX_DATA_SIZE,
76 tls_config: tls::Config::client(),
77 max_redirects: 0,
78 listener_protos: HashMap::new(),
79 }
80 }
81
82 pub fn max_redirects(&self) -> u8 {
84 self.max_redirects
85 }
86
87 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
89 self.max_redirects = max;
90 self
91 }
92
93 pub fn max_data_size(&self) -> usize {
95 self.max_data_size
96 }
97
98 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
100 self.max_data_size = size;
101 self
102 }
103
104 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
106 self.tls_config = c;
107 self
108 }
109}
110
/// Byte stream underneath a websocket connection: a client-side TLS stream,
/// a server-side TLS stream, or the plain stream `T` of the inner transport.
type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
112
113impl<T> Transport for WsConfig<T>
114where
115 T: Transport + Send + Unpin + 'static,
116 T::Error: Send + 'static,
117 T::Dial: Send + 'static,
118 T::ListenerUpgrade: Send + 'static,
119 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
120{
121 type Output = Connection<T::Output>;
122 type Error = Error<T::Error>;
123 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
124 type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
125
126 fn listen_on(
127 &mut self,
128 id: ListenerId,
129 addr: Multiaddr,
130 ) -> Result<(), TransportError<Self::Error>> {
131 let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
132 tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
133 TransportError::MultiaddrNotSupported(addr.clone())
134 })?;
135
136 if proto.use_tls() && self.tls_config.server.is_none() {
137 tracing::debug!(
138 "{} address but TLS server support is not configured",
139 proto.prefix()
140 );
141 return Err(TransportError::MultiaddrNotSupported(addr));
142 }
143
144 match self.transport.lock().listen_on(id, inner_addr) {
145 Ok(()) => {
146 self.listener_protos.insert(id, proto);
147 Ok(())
148 }
149 Err(e) => Err(e.map(Error::Transport)),
150 }
151 }
152
153 fn remove_listener(&mut self, id: ListenerId) -> bool {
154 self.transport.lock().remove_listener(id)
155 }
156
157 fn dial(
158 &mut self,
159 addr: Multiaddr,
160 dial_opts: DialOpts,
161 ) -> Result<Self::Dial, TransportError<Self::Error>> {
162 self.do_dial(addr, dial_opts)
163 }
164
165 fn poll(
166 mut self: Pin<&mut Self>,
167 cx: &mut Context<'_>,
168 ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
169 let inner_event = {
170 let mut transport = self.transport.lock();
171 match Transport::poll(Pin::new(transport.deref_mut()), cx) {
172 Poll::Ready(ev) => ev,
173 Poll::Pending => return Poll::Pending,
174 }
175 };
176 let event = match inner_event {
177 TransportEvent::NewAddress {
178 listener_id,
179 mut listen_addr,
180 } => {
181 self.listener_protos
183 .get(&listener_id)
184 .expect("Protocol was inserted in Transport::listen_on.")
185 .append_on_addr(&mut listen_addr);
186 tracing::debug!(address=%listen_addr, "Listening on address");
187 TransportEvent::NewAddress {
188 listener_id,
189 listen_addr,
190 }
191 }
192 TransportEvent::AddressExpired {
193 listener_id,
194 mut listen_addr,
195 } => {
196 self.listener_protos
197 .get(&listener_id)
198 .expect("Protocol was inserted in Transport::listen_on.")
199 .append_on_addr(&mut listen_addr);
200 TransportEvent::AddressExpired {
201 listener_id,
202 listen_addr,
203 }
204 }
205 TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
206 listener_id,
207 error: Error::Transport(error),
208 },
209 TransportEvent::ListenerClosed {
210 listener_id,
211 reason,
212 } => {
213 self.listener_protos
214 .remove(&listener_id)
215 .expect("Protocol was inserted in Transport::listen_on.");
216 TransportEvent::ListenerClosed {
217 listener_id,
218 reason: reason.map_err(Error::Transport),
219 }
220 }
221 TransportEvent::Incoming {
222 listener_id,
223 upgrade,
224 mut local_addr,
225 mut send_back_addr,
226 } => {
227 let proto = self
228 .listener_protos
229 .get(&listener_id)
230 .expect("Protocol was inserted in Transport::listen_on.");
231 let use_tls = proto.use_tls();
232 proto.append_on_addr(&mut local_addr);
233 proto.append_on_addr(&mut send_back_addr);
234 let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
235 TransportEvent::Incoming {
236 listener_id,
237 upgrade,
238 local_addr,
239 send_back_addr,
240 }
241 }
242 };
243 Poll::Ready(event)
244 }
245}
246
247impl<T> WsConfig<T>
248where
249 T: Transport + Send + Unpin + 'static,
250 T::Error: Send + 'static,
251 T::Dial: Send + 'static,
252 T::ListenerUpgrade: Send + 'static,
253 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
254{
255 fn do_dial(
256 &mut self,
257 addr: Multiaddr,
258 dial_opts: DialOpts,
259 ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
260 let mut addr = match parse_ws_dial_addr(addr) {
261 Ok(addr) => addr,
262 Err(Error::InvalidMultiaddr(a)) => {
263 return Err(TransportError::MultiaddrNotSupported(a))
264 }
265 Err(e) => return Err(TransportError::Other(e)),
266 };
267
268 let mut remaining_redirects = self.max_redirects;
270
271 let transport = self.transport.clone();
272 let tls_config = self.tls_config.clone();
273 let max_redirects = self.max_redirects;
274
275 let future = async move {
276 loop {
277 match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
278 {
279 Ok(Either::Left(redirect)) => {
280 if remaining_redirects == 0 {
281 tracing::debug!(%max_redirects, "Too many redirects");
282 return Err(Error::TooManyRedirects);
283 }
284 remaining_redirects -= 1;
285 addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
286 }
287 Ok(Either::Right(conn)) => return Ok(conn),
288 Err(e) => return Err(e),
289 }
290 }
291 };
292
293 Ok(Box::pin(future))
294 }
295
296 async fn dial_once(
298 transport: Arc<Mutex<T>>,
299 addr: WsAddress,
300 tls_config: tls::Config,
301 dial_opts: DialOpts,
302 ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
303 tracing::trace!(address=?addr, "Dialing websocket address");
304
305 let dial = transport
306 .lock()
307 .dial(addr.tcp_addr, dial_opts)
308 .map_err(|e| match e {
309 TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
310 TransportError::Other(e) => Error::Transport(e),
311 })?;
312
313 let stream = dial.map_err(Error::Transport).await?;
314 tracing::trace!(port=%addr.host_port, "TCP connection established");
315
316 let stream = if addr.use_tls {
317 tracing::trace!(?addr.server_name, "Starting TLS handshake");
319 let stream = tls_config
320 .client
321 .connect(addr.server_name.clone(), stream)
322 .map_err(|e| {
323 tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
324 Error::Tls(tls::Error::from(e))
325 })
326 .await?;
327
328 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
329 stream
330 } else {
331 future::Either::Right(stream)
333 };
334
335 tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
336
337 let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
338
339 match client
340 .handshake()
341 .map_err(|e| Error::Handshake(Box::new(e)))
342 .await?
343 {
344 handshake::ServerResponse::Redirect {
345 status_code,
346 location,
347 } => {
348 tracing::debug!(
349 %status_code,
350 %location,
351 "received redirect"
352 );
353 Ok(Either::Left(location))
354 }
355 handshake::ServerResponse::Rejected { status_code } => {
356 let msg = format!("server rejected handshake; status code = {status_code}");
357 Err(Error::Handshake(msg.into()))
358 }
359 handshake::ServerResponse::Accepted { .. } => {
360 tracing::trace!(port=%addr.host_port, "websocket handshake successful");
361 Ok(Either::Right(Connection::new(client.into_builder())))
362 }
363 }
364 }
365
366 fn map_upgrade(
367 &self,
368 upgrade: T::ListenerUpgrade,
369 remote_addr: Multiaddr,
370 use_tls: bool,
371 ) -> <Self as Transport>::ListenerUpgrade {
372 let remote_addr2 = remote_addr.clone(); let tls_config = self.tls_config.clone();
374 let max_size = self.max_data_size;
375
376 async move {
377 let stream = upgrade.map_err(Error::Transport).await?;
378 tracing::trace!(address=%remote_addr, "incoming connection from address");
379
380 let stream = if use_tls {
381 let server = tls_config
383 .server
384 .expect("for use_tls we checked server is not none");
385
386 tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
387
388 let stream = server
389 .accept(stream)
390 .map_err(move |e| {
391 tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
392 Error::Tls(tls::Error::from(e))
393 })
394 .await?;
395
396 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
397
398 stream
399 } else {
400 future::Either::Right(stream)
402 };
403
404 tracing::trace!(
405 address=%remote_addr2,
406 "receiving websocket handshake request from address"
407 );
408
409 let mut server = handshake::Server::new(stream);
410
411 let ws_key = {
412 let request = server
413 .receive_request()
414 .map_err(|e| Error::Handshake(Box::new(e)))
415 .await?;
416 request.key()
417 };
418
419 tracing::trace!(
420 address=%remote_addr2,
421 "accepting websocket handshake request from address"
422 );
423
424 let response = handshake::server::Response::Accept {
425 key: ws_key,
426 protocol: None,
427 };
428
429 server
430 .send_response(&response)
431 .map_err(|e| Error::Handshake(Box::new(e)))
432 .await?;
433
434 let conn = {
435 let mut builder = server.into_builder();
436 builder.set_max_message_size(max_size);
437 builder.set_max_frame_size(max_size);
438 Connection::new(builder)
439 };
440
441 Ok(conn)
442 }
443 .boxed()
444 }
445}
446
/// The websocket protocol flavour found at the end of a listen address,
/// together with the path carried by that protocol.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// `/ws`: websocket without TLS.
    Ws(Cow<'a, str>),
    /// `/wss`: websocket over TLS.
    Wss(Cow<'a, str>),
    /// `/tls/ws`: websocket over TLS, spelled as two protocols.
    TlsWs(Cow<'a, str>),
}
453
454impl WsListenProto<'_> {
455 pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
456 match self {
457 WsListenProto::Ws(path) => {
458 addr.push(Protocol::Ws(path.clone()));
459 }
460 WsListenProto::Wss(path) => {
463 addr.push(Protocol::Wss(path.clone()));
464 }
465 WsListenProto::TlsWs(path) => {
466 addr.push(Protocol::Tls);
467 addr.push(Protocol::Ws(path.clone()));
468 }
469 }
470 }
471
472 pub(crate) fn use_tls(&self) -> bool {
473 match self {
474 WsListenProto::Ws(_) => false,
475 WsListenProto::Wss(_) => true,
476 WsListenProto::TlsWs(_) => true,
477 }
478 }
479
480 pub(crate) fn prefix(&self) -> &'static str {
481 match self {
482 WsListenProto::Ws(_) => "/ws",
483 WsListenProto::Wss(_) => "/wss",
484 WsListenProto::TlsWs(_) => "/tls/ws",
485 }
486 }
487}
488
/// The pieces of a websocket dial address, as produced by `parse_ws_dial_addr`.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` string passed to the websocket handshake; IPv6 hosts are bracketed.
    host_port: String,
    /// Path carried by the `/ws` or `/wss` protocol.
    path: String,
    /// Server name handed to the TLS client when `use_tls` is set.
    server_name: ServerName<'static>,
    /// Whether the address asks for TLS (`/wss` or `/tls/ws`).
    use_tls: bool,
    /// Address for the inner transport: the original address without its
    /// websocket part, keeping a trailing `/p2p` protocol if there was one.
    tcp_addr: Multiaddr,
}
497
/// Splits a websocket dial address into the parts needed for dialing.
///
/// # Errors
///
/// Returns `Error::InvalidMultiaddr` (carrying the address back) when no
/// IP/DNS protocol directly followed by `/tcp` is found, or when the address
/// does not end in `/ws`, `/wss` or `/tls/ws`, optionally followed by `/p2p`.
/// A DNS name rejected by `tls::dns_name_ref` is reported with that
/// function's error.
fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
    // Slide a two-protocol window over the address until it covers an
    // IP/DNS protocol directly followed by `/tcp`. That pair yields the
    // `host:port` string and the TLS server name.
    let mut protocols = addr.iter();
    let mut ip = protocols.next();
    let mut tcp = protocols.next();
    let (host_port, server_name) = loop {
        match (ip, tcp) {
            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
                break (format!("{ip}:{port}"), server_name);
            }
            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
                // IPv6 hosts are bracketed to keep the port unambiguous.
                break (format!("[{ip}]:{port}"), server_name);
            }
            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
            }
            (Some(_), Some(p)) => {
                // No match yet: advance the window by one protocol.
                ip = Some(p);
                tcp = protocols.next();
            }
            // Fewer than two protocols are left to look at.
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // Now take the websocket protocol off the end of the address, setting
    // aside a trailing `/p2p` protocol if there is one.
    let mut protocols = addr.clone();
    let mut p2p = None;
    let (use_tls, path) = loop {
        match protocols.pop() {
            p @ Some(Protocol::P2p(_)) => p2p = p,
            Some(Protocol::Ws(path)) => match protocols.pop() {
                // `/tls/ws` is the two-protocol spelling of secure websockets.
                Some(Protocol::Tls) => break (true, path.into_owned()),
                Some(p) => {
                    // Not `/tls`: put the protocol back, it belongs to the inner address.
                    protocols.push(p);
                    break (false, path.into_owned());
                }
                None => return Err(Error::InvalidMultiaddr(addr)),
            },
            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // What is left, plus the `/p2p` protocol that was set aside, is the
    // address for the inner transport.
    let tcp_addr = match p2p {
        Some(p) => protocols.with(p),
        None => protocols,
    };

    Ok(WsAddress {
        host_port,
        server_name,
        path,
        use_tls,
        tcp_addr,
    })
}
569
570fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
571 let mut inner_addr = addr.clone();
572
573 match inner_addr.pop()? {
574 Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
575 Protocol::Ws(path) => match inner_addr.pop()? {
576 Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
577 p => {
578 inner_addr.push(p);
579 Some((inner_addr, WsListenProto::Ws(path)))
580 }
581 },
582 _ => None,
583 }
584}
585
586fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
588 match Url::parse(location) {
589 Ok(url) => {
590 let mut a = Multiaddr::empty();
591 match url.host() {
592 Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
593 Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
594 Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
595 None => return Err(Error::InvalidRedirectLocation),
596 }
597 if let Some(p) = url.port() {
598 a.push(Protocol::Tcp(p))
599 }
600 let s = url.scheme();
601 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
602 a.push(Protocol::Tls);
603 a.push(Protocol::Ws(url.path().into()));
604 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
605 a.push(Protocol::Ws(url.path().into()))
606 } else {
607 tracing::debug!(scheme=%s, "unsupported scheme");
608 return Err(Error::InvalidRedirectLocation);
609 }
610 Ok(a)
611 }
612 Err(e) => {
613 tracing::debug!("failed to parse url as multi-address: {:?}", e);
614 Err(Error::InvalidRedirectLocation)
615 }
616 }
617}
618
/// A websocket connection, usable as a stream of [`Incoming`] items and as a
/// sink for [`OutgoingData`].
pub struct Connection<T> {
    /// Boxed stream of items received from the remote.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Boxed sink that writes outgoing items to the remote.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    /// Ties the connection to the stream type `T`, which is hidden behind the boxes above.
    _marker: std::marker::PhantomData<T>,
}
625
/// An item received over a [`Connection`].
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data, either text or binary.
    Data(Data),
    /// Payload of a received PONG.
    Pong(Vec<u8>),
    /// The connection was closed; carries the reason that was reported.
    Closed(CloseReason),
}
636
/// Application data received over a websocket connection.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Textual data, stored as raw bytes.
    Text(Vec<u8>),
    /// Binary data.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the value and returns its payload, whatever the kind.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(payload) | Data::Binary(payload)) = self;
        payload
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload, whatever the kind.
    fn as_ref(&self) -> &[u8] {
        let (Data::Text(payload) | Data::Binary(payload)) = self;
        payload.as_slice()
    }
}
663
664impl Incoming {
665 pub fn is_data(&self) -> bool {
666 self.is_binary() || self.is_text()
667 }
668
669 pub fn is_binary(&self) -> bool {
670 matches!(self, Incoming::Data(Data::Binary(_)))
671 }
672
673 pub fn is_text(&self) -> bool {
674 matches!(self, Incoming::Data(Data::Text(_)))
675 }
676
677 pub fn is_pong(&self) -> bool {
678 matches!(self, Incoming::Pong(_))
679 }
680
681 pub fn is_close(&self) -> bool {
682 matches!(self, Incoming::Closed(_))
683 }
684}
685
/// An item that can be sent over a [`Connection`].
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Binary application data.
    Binary(Vec<u8>),
    /// A PING with the given payload; the payload must be shorter than 126 bytes.
    Ping(Vec<u8>),
    /// A PONG with the given payload; the payload must be shorter than 126 bytes.
    Pong(Vec<u8>),
}
697
698impl<T> fmt::Debug for Connection<T> {
699 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
700 f.write_str("Connection")
701 }
702}
703
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Finishes `builder` and wraps the resulting sender and receiver halves
    /// into a boxed sink and a boxed stream.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        // Every sink action takes the sender, performs one operation on it and
        // hands it back for the next action.
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // The conversion fails for payloads that are too long.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    // The conversion fails for payloads that are too long.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?,
            }
            Ok(sender)
        });
        // The stream state is a receive buffer plus the receiver itself.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                // `mem::take` moves the filled buffer into the item and leaves
                // an empty one behind as the state for the next receive.
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
                }
                Ok(soketto::Incoming::Closed(reason)) => {
                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
                }
                // A closed connection ends the stream instead of yielding an error.
                Err(connection::Error::Closed) => None,
                Err(e) => Some((Err(e), (data, receiver))),
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns a future that sends `data` as binary application data.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Returns a future that sends a PING carrying `data`.
    ///
    /// Sending fails when `data` is 126 bytes or longer.
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Returns a future that sends a PONG carrying `data`.
    ///
    /// Sending fails when `data` is 126 bytes or longer.
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
774
775impl<T> Stream for Connection<T>
776where
777 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
778{
779 type Item = io::Result<Incoming>;
780
781 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
782 let item = ready!(self.receiver.poll_next_unpin(cx));
783 let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
784 Poll::Ready(item)
785 }
786}
787
788impl<T> Sink<OutgoingData> for Connection<T>
789where
790 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
791{
792 type Error = io::Error;
793
794 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
795 Pin::new(&mut self.sender)
796 .poll_ready(cx)
797 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
798 }
799
800 fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
801 Pin::new(&mut self.sender)
802 .start_send(item)
803 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
804 }
805
806 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
807 Pin::new(&mut self.sender)
808 .poll_flush(cx)
809 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
810 }
811
812 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
813 Pin::new(&mut self.sender)
814 .poll_close(cx)
815 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
816 }
817}
818
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a dial address and checks every field of the result.
    /// The websocket path is expected to be `/` in all cases.
    fn assert_dial_addr(
        addr: &str,
        host_port: &str,
        use_tls: bool,
        server_name: &'static str,
        tcp_addr: &str,
    ) {
        let addr = addr.parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, host_port);
        assert_eq!(info.path, "/");
        assert_eq!(info.use_tls, use_tls);
        assert_eq!(info.server_name, ServerName::try_from(server_name).unwrap());
        assert_eq!(info.tcp_addr, tcp_addr.parse::<Multiaddr>().unwrap());
    }

    /// Parses `addr` as a dial address and expects that to fail.
    fn assert_dial_addr_rejected(addr: &str) {
        let addr = addr.parse::<Multiaddr>().unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
    }

    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // Listen address paired with the protocol it must be parsed into:
        // `/tls/ws`, `/wss` and `/ws`, in that order.
        let cases = [
            (
                tcp_addr
                    .clone()
                    .with(Protocol::Tls)
                    .with(Protocol::Ws("/".into())),
                WsListenProto::TlsWs("/".into()),
            ),
            (
                tcp_addr.clone().with(Protocol::Wss("/".into())),
                WsListenProto::Wss("/".into()),
            ),
            (
                tcp_addr.clone().with(Protocol::Ws("/".into())),
                WsListenProto::Ws("/".into()),
            ),
        ];

        for (addr, expected_proto) in cases {
            let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
            assert_eq!(&inner_addr, &tcp_addr);
            assert_eq!(proto, expected_proto);

            // Appending the protocol again must reproduce the listen address.
            let mut listen_addr = tcp_addr.clone();
            proto.append_on_addr(&mut listen_addr);
            assert_eq!(listen_addr, addr);
        }
    }

    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();

        // Every websocket spelling is checked with a DNS host (with and
        // without a trailing `/p2p`), an IPv4 host and an IPv6 host.
        for (ws_suffix, use_tls) in [("/tls/ws", true), ("/wss", true), ("/ws", false)] {
            assert_dial_addr(
                &format!("/dns4/example.com/tcp/2222{ws_suffix}"),
                "example.com:2222",
                use_tls,
                "example.com",
                "/dns4/example.com/tcp/2222",
            );

            assert_dial_addr(
                &format!("/dns4/example.com/tcp/2222{ws_suffix}/p2p/{peer_id}"),
                "example.com:2222",
                use_tls,
                "example.com",
                &format!("/dns4/example.com/tcp/2222/p2p/{peer_id}"),
            );

            assert_dial_addr(
                &format!("/ip4/127.0.0.1/tcp/2222{ws_suffix}"),
                "127.0.0.1:2222",
                use_tls,
                "127.0.0.1",
                "/ip4/127.0.0.1/tcp/2222",
            );

            assert_dial_addr(
                &format!("/ip6/::1/tcp/2222{ws_suffix}"),
                "[::1]:2222",
                use_tls,
                "::1",
                "/ip6/::1/tcp/2222",
            );
        }

        // `/dnsaddr` is not accepted as the host protocol.
        assert_dial_addr_rejected("/dnsaddr/example.com/tcp/2222/ws");

        // An address without any websocket protocol is rejected.
        assert_dial_addr_rejected("/ip4/127.0.0.1/tcp/2222");
    }
}