1pub mod error;
24pub mod framed;
25pub mod tls;
26
27use error::Error;
28use framed::Connection;
29use futures::{future::BoxFuture, prelude::*, stream::BoxStream, ready};
30use libp2p_core::{
31 ConnectedPoint,
32 Transport,
33 multiaddr::Multiaddr,
34 transport::{map::{MapFuture, MapStream}, ListenerEvent, TransportError}
35};
36use rw_stream_sink::RwStreamSink;
37use std::{io, pin::Pin, task::{Context, Poll}};
38
39#[derive(Debug, Clone)]
41pub struct WsConfig<T> {
42 transport: framed::WsConfig<T>
43}
44
45impl<T> WsConfig<T> {
46 pub fn new(transport: T) -> Self {
48 framed::WsConfig::new(transport).into()
49 }
50
51 pub fn max_redirects(&self) -> u8 {
53 self.transport.max_redirects()
54 }
55
56 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
58 self.transport.set_max_redirects(max);
59 self
60 }
61
62 pub fn max_data_size(&self) -> usize {
64 self.transport.max_data_size()
65 }
66
67 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
69 self.transport.set_max_data_size(size);
70 self
71 }
72
73 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
75 self.transport.set_tls_config(c);
76 self
77 }
78
79 pub fn use_deflate(&mut self, flag: bool) -> &mut Self {
81 self.transport.use_deflate(flag);
82 self
83 }
84}
85
86impl<T> From<framed::WsConfig<T>> for WsConfig<T> {
87 fn from(framed: framed::WsConfig<T>) -> Self {
88 WsConfig {
89 transport: framed
90 }
91 }
92}
93
94impl<T> Transport for WsConfig<T>
95where
96 T: Transport + Send + Clone + 'static,
97 T::Error: Send + 'static,
98 T::Dial: Send + 'static,
99 T::Listener: Send + 'static,
100 T::ListenerUpgrade: Send + 'static,
101 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
102{
103 type Output = RwStreamSink<BytesConnection<T::Output>>;
104 type Error = Error<T::Error>;
105 type Listener = MapStream<InnerStream<T::Output, T::Error>, WrapperFn<T::Output>>;
106 type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
107 type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
108
109 fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
110 self.transport.map(wrap_connection as WrapperFn<T::Output>).listen_on(addr)
111 }
112
113 fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
114 self.transport.map(wrap_connection as WrapperFn<T::Output>).dial(addr)
115 }
116
117 fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
118 self.transport.address_translation(server, observed)
119 }
120}
121
122pub type InnerStream<T, E> = BoxStream<'static, Result<ListenerEvent<InnerFuture<T, E>, Error<E>>, Error<E>>>;
124
125pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
127
128pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
130
131fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
134where
135 T: AsyncRead + AsyncWrite + Send + Unpin + 'static
136{
137 RwStreamSink::new(BytesConnection(c))
138}
139
140#[derive(Debug)]
142pub struct BytesConnection<T>(Connection<T>);
143
144impl<T> Stream for BytesConnection<T>
145where
146 T: AsyncRead + AsyncWrite + Send + Unpin + 'static
147{
148 type Item = io::Result<Vec<u8>>;
149
150 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151 loop {
152 if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
153 if item.is_data() {
154 return Poll::Ready(Some(Ok(item.into_bytes())))
155 }
156 } else {
157 return Poll::Ready(None)
158 }
159 }
160 }
161}
162
163impl<T> Sink<Vec<u8>> for BytesConnection<T>
164where
165 T: AsyncRead + AsyncWrite + Send + Unpin + 'static
166{
167 type Error = io::Error;
168
169 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
170 Pin::new(&mut self.0).poll_ready(cx)
171 }
172
173 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
174 Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
175 }
176
177 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 Pin::new(&mut self.0).poll_flush(cx)
179 }
180
181 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182 Pin::new(&mut self.0).poll_close(cx)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::prelude::*;
    use libp2p_core::{multiaddr::Protocol, Multiaddr, Transport};
    use libp2p_tcp as tcp;

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        let listen_addr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        let listen_addr = "/ip6/::1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    /// Listens on `listen_addr`, dials the address the listener reports, and requires both the
    /// accepted upgrade and the dial to succeed.
    async fn connect(listen_addr: Multiaddr) {
        let config = WsConfig::new(tcp::TcpConfig::new());

        let mut listener = config.clone().listen_on(listen_addr).expect("listener");

        // The first listener event must carry the concrete address being listened on.
        let dial_addr = listener
            .try_next()
            .await
            .expect("some event")
            .expect("no error")
            .into_new_address()
            .expect("listen address");

        // `/ws` is the third component, and port 0 has been replaced by a real port.
        assert_eq!(Some(Protocol::Ws("/".into())), dial_addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), dial_addr.iter().nth(1));

        let accept = async move {
            let mut upgrades =
                listener.try_filter_map(|event| future::ready(Ok(event.into_upgrade())));
            let (upgrade, _remote) = upgrades.try_next().await.unwrap().unwrap();
            upgrade.await
        };

        let dial = config.dial(dial_addr).unwrap();

        // Both sides must make progress concurrently for the handshake to complete.
        let (accepted, dialed) = futures::join!(accept, dial);
        accepted.and(dialed).unwrap();
    }
}