1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25pub mod error;
26pub mod framed;
27mod quicksink;
28pub mod tls;
29
30use std::{
31 io,
32 pin::Pin,
33 task::{Context, Poll},
34};
35
36use error::Error;
37use framed::{Connection, Incoming};
38use futures::{future::BoxFuture, prelude::*, ready};
39use libp2p_core::{
40 connection::ConnectedPoint,
41 multiaddr::Multiaddr,
42 transport::{map::MapFuture, DialOpts, ListenerId, TransportError, TransportEvent},
43 Transport,
44};
45use rw_stream_sink::RwStreamSink;
46
47#[deprecated = "Use `Config` instead"]
137pub type WsConfig<Transport> = Config<Transport>;
138
139#[derive(Debug)]
140pub struct Config<T: Transport>
141where
142 T: Transport,
143 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
144{
145 transport: libp2p_core::transport::map::Map<framed::Config<T>, WrapperFn<T::Output>>,
146}
147
148impl<T: Transport> Config<T>
149where
150 T: Transport + Send + Unpin + 'static,
151 T::Error: Send + 'static,
152 T::Dial: Send + 'static,
153 T::ListenerUpgrade: Send + 'static,
154 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
155{
156 pub fn new(transport: T) -> Self {
165 Self {
166 transport: framed::Config::new(transport).map(wrap_connection as WrapperFn<T::Output>),
167 }
168 }
169
170 pub fn max_redirects(&self) -> u8 {
172 self.transport.inner().max_redirects()
173 }
174
175 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
177 self.transport.inner_mut().set_max_redirects(max);
178 self
179 }
180
181 pub fn max_data_size(&self) -> usize {
183 self.transport.inner().max_data_size()
184 }
185
186 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
188 self.transport.inner_mut().set_max_data_size(size);
189 self
190 }
191
192 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
194 self.transport.inner_mut().set_tls_config(c);
195 self
196 }
197}
198
199impl<T> Transport for Config<T>
200where
201 T: Transport + Send + Unpin + 'static,
202 T::Error: Send + 'static,
203 T::Dial: Send + 'static,
204 T::ListenerUpgrade: Send + 'static,
205 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
206{
207 type Output = RwStreamSink<BytesConnection<T::Output>>;
208 type Error = Error<T::Error>;
209 type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
210 type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
211
212 fn listen_on(
213 &mut self,
214 id: ListenerId,
215 addr: Multiaddr,
216 ) -> Result<(), TransportError<Self::Error>> {
217 self.transport.listen_on(id, addr)
218 }
219
220 fn remove_listener(&mut self, id: ListenerId) -> bool {
221 self.transport.remove_listener(id)
222 }
223
224 fn dial(
225 &mut self,
226 addr: Multiaddr,
227 opts: DialOpts,
228 ) -> Result<Self::Dial, TransportError<Self::Error>> {
229 self.transport.dial(addr, opts)
230 }
231
232 fn poll(
233 mut self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
236 Pin::new(&mut self.transport).poll(cx)
237 }
238}
239
240pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
242
243pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
245
246fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
249where
250 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
251{
252 RwStreamSink::new(BytesConnection(c))
253}
254
/// Byte-oriented view of a framed websocket [`Connection`].
///
/// As a [`Stream`] it yields only the payloads of incoming data frames and
/// skips every other incoming item. As a [`Sink`] it sends each `Vec<u8>` as
/// a binary frame.
#[derive(Debug)]
pub struct BytesConnection<T>(Connection<T>);
258
259impl<T> Stream for BytesConnection<T>
260where
261 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
262{
263 type Item = io::Result<Vec<u8>>;
264
265 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266 loop {
267 if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
268 if let Incoming::Data(payload) = item {
269 return Poll::Ready(Some(Ok(payload.into_bytes())));
270 }
271 } else {
272 return Poll::Ready(None);
273 }
274 }
275 }
276}
277
278impl<T> Sink<Vec<u8>> for BytesConnection<T>
279where
280 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
281{
282 type Error = io::Error;
283
284 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
285 Pin::new(&mut self.0).poll_ready(cx)
286 }
287
288 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
289 Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
290 }
291
292 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293 Pin::new(&mut self.0).poll_flush(cx)
294 }
295
296 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
297 Pin::new(&mut self.0).poll_close(cx)
298 }
299}
300
#[cfg(test)]
mod tests {
    use futures::prelude::*;
    use libp2p_core::{
        multiaddr::Protocol,
        transport::{DialOpts, ListenerId, PortUse},
        Endpoint, Multiaddr, Transport,
    };
    use libp2p_identity::PeerId;
    use libp2p_tcp as tcp;

    use super::Config;

    #[tokio::test]
    async fn dialer_connects_to_listener_ipv4() {
        connect("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()).await
    }

    #[tokio::test]
    async fn dialer_connects_to_listener_ipv6() {
        connect("/ip6/::1/tcp/0/ws".parse().unwrap()).await
    }

    /// Build a websocket transport over a default-configured tokio TCP transport.
    fn new_ws_config() -> Config<tcp::tokio::Transport> {
        let tcp_transport = tcp::tokio::Transport::new(tcp::Config::default());
        Config::new(tcp_transport)
    }

    /// Listen on `listen_addr`, dial the reported address from a second
    /// transport, and require that both sides complete their upgrade.
    async fn connect(listen_addr: Multiaddr) {
        let mut listener = new_ws_config().boxed();
        listener
            .listen_on(ListenerId::next(), listen_addr)
            .expect("listener");

        let first_event = listener.next().await.expect("no error");
        let addr = first_event.into_new_address().expect("listen address");

        // The third component must be `/ws`, and the `0` port must have been
        // replaced by a concrete one.
        assert_eq!(Some(Protocol::Ws("/".into())), addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1));

        let inbound = async move {
            let event = listener.select_next_some().await;
            let (upgrade, _send_back_addr) = event.into_incoming().unwrap();
            upgrade.await
        };

        let dial_opts = DialOpts {
            role: Endpoint::Dialer,
            port_use: PortUse::New,
        };
        let outbound = new_ws_config()
            .boxed()
            .dial(addr.with(Protocol::P2p(PeerId::random())), dial_opts)
            .unwrap();

        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}