libp2p_combined_transport/
lib.rs

1#![allow(clippy::type_complexity)]
2use std::{fmt, marker::PhantomData, sync::Arc};
3
4use futures::{
5    channel::mpsc,
6    future::{self, BoxFuture},
7    stream::{self, BoxStream},
8    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
9};
10use libp2p::{
11    core::{either::EitherOutput, transport::ListenerEvent},
12    Multiaddr, Transport, TransportError,
13};
14use parking_lot::Mutex;
15
16/// Transport combining two transports. One of which is the base transport (like TCP), and another
17/// one is a higher-level transport (like WebSocket). Similar to [`OrTransport`], this tries to
18/// dial first with the outer connection, and if that fails, with the base one. The main difference
19/// is that incoming connections can be accepted on either one of them. For this to work, a
20/// switch must be provided when handling incoming connections. For TCP, this can be achieved with
21/// the [`peek`] method on the underlying [`TcpStream`].
22/// [`ListenerEvent`]s from the base transport are cloned and routed to the outer transport via the
23/// [`ProxyTransport`], with the exception of upgrades.
24///
25/// [`peek`]: https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.peek
26///
27/// For a usage example, have a look at the [TCP-Websocket example](https://github.com/wngr/libp2p-combined-transport/tree/master/examples/tcp-websocket.rs).
28pub struct CombinedTransport<TBase, TOuter>
29where
30    TBase: Transport + Clone,
31    TBase::Error: Send + 'static,
32    TBase::Output: 'static,
33{
34    /// The base transport
35    base: TBase,
36    /// The outer transport, wrapping the base transport
37    outer: TOuter,
38    /// Function pointer to construct the outer transport, given the base transport
39    construct_outer: fn(ProxyTransport<TBase>) -> TOuter,
40    proxy: ProxyTransport<TBase>,
41    /// Function pointer to try upgrading the base transport to the outer transport
42    try_upgrade: MaybeUpgrade<TBase>,
43    map_base_addr_to_outer: fn(Multiaddr) -> Multiaddr,
44}
45
46impl<TBase, TOuter> CombinedTransport<TBase, TOuter>
47where
48    TBase: Transport + Clone,
49    TBase::Error: Send + 'static,
50    TBase::Output: 'static,
51{
52    /// Construct a new combined transport, given a base transport, a function to construct the
53    /// outer transport given the base transport, a function to try the upgrade to the outer
54    /// transport given incoming base connections, and a function to map base addresses to outer
55    /// addresses (if necessary).
56    pub fn new(
57        base: TBase,
58        construct_outer: fn(ProxyTransport<TBase>) -> TOuter,
59        try_upgrade: MaybeUpgrade<TBase>,
60        map_base_addr_to_outer: fn(Multiaddr) -> Multiaddr,
61    ) -> Self {
62        let proxy = ProxyTransport::<TBase>::new(base.clone());
63        let mut proxy_clone = proxy.clone();
64        proxy_clone.pending = proxy.pending.clone();
65        let outer = construct_outer(proxy_clone);
66        Self {
67            base,
68            proxy,
69            outer,
70            construct_outer,
71            try_upgrade,
72            map_base_addr_to_outer,
73        }
74    }
75}
76impl<TBase, TOuter> Clone for CombinedTransport<TBase, TOuter>
77where
78    TBase: Transport + Clone,
79    TBase::Error: Send + 'static,
80    TBase::Output: 'static,
81{
82    fn clone(&self) -> Self {
83        Self::new(
84            self.base.clone(),
85            self.construct_outer,
86            self.try_upgrade,
87            self.map_base_addr_to_outer,
88        )
89    }
90}
91
92type MaybeUpgrade<TBase> =
93    fn(
94        <TBase as Transport>::Output,
95    )
96        -> BoxFuture<'static, Result<<TBase as Transport>::Output, <TBase as Transport>::Output>>;
97
/// Error of a [`CombinedTransport`], wrapping the errors of its two transports.
#[derive(Debug, Copy, Clone)]
pub enum CombinedError<Base, Outer> {
    /// An incoming base connection was claimed by the switch and handed over to the outer
    /// transport; the base-side upgrade resolves to this error.
    UpgradedToOuterTransport,
    /// Error of the base transport.
    Base(Base),
    /// Error of the outer transport.
    Outer(Outer),
}

impl<A, B> fmt::Display for CombinedError<A, B>
where
    A: fmt::Display,
    B: fmt::Display,
{
    /// Forwards to the wrapped error (keeping the caller's formatting options); the hand-over
    /// marker prints a fixed message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UpgradedToOuterTransport => f.write_str("Upgraded to outer transport"),
            Self::Base(inner) => fmt::Display::fmt(inner, f),
            Self::Outer(inner) => fmt::Display::fmt(inner, f),
        }
    }
}

impl<A, B> std::error::Error for CombinedError<A, B>
where
    A: std::error::Error,
    B: std::error::Error,
{
    /// Transparent wrapper: since `Display` forwards to the inner error, `source` does too.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &dyn std::error::Error = match self {
            Self::Base(base) => base,
            Self::Outer(outer) => outer,
            Self::UpgradedToOuterTransport => return None,
        };
        inner.source()
    }
}
131
132impl<TBase, TOuter> Transport for CombinedTransport<TBase, TOuter>
133where
134    TBase: Transport + Clone,
135    TBase::Listener: Send + 'static,
136    TBase::ListenerUpgrade: Send + 'static,
137    TBase::Error: Send + 'static,
138    TBase::Output: Send + 'static,
139    TBase::Dial: Send + 'static,
140    TOuter: Transport,
141    TOuter::Listener: Send + 'static,
142    TOuter::ListenerUpgrade: Send + 'static,
143    TOuter::Error: 'static,
144    TOuter::Output: 'static,
145    TOuter::Dial: Send + 'static,
146{
147    type Output = EitherOutput<TBase::Output, TOuter::Output>;
148
149    type Error = CombinedError<TBase::Error, TOuter::Error>;
150
151    type Listener =
152        BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
153    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
154    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
155
156    fn listen_on(
157        self,
158        addr: libp2p::Multiaddr,
159    ) -> Result<Self::Listener, libp2p::TransportError<Self::Error>>
160    where
161        Self: Sized,
162    {
163        // 1. User calls `listen_on`
164        // 2. Base transport `listen_on` -> returns `TBase::Listener`
165        let base_listener = self
166            .base
167            .listen_on(addr.clone())
168            .map_err(|e| e.map(CombinedError::Base))?;
169        // 3. Create new mpsc::channel, all events emitted by (2) will be cloned and piped into
170        //    this tx, with the exception of the Upgrade event
171        let (mut tx, rx) = mpsc::channel(256);
172        // 4. Move rx into the proxy
173        let x = self.proxy.pending.lock().replace(rx);
174        debug_assert!(x.is_none());
175        // 5. Call listen_on on `TOuter`, which will call listen_on on proxy. Proxy returns tx from
176        //    (4)
177        let outer_listener = self
178            .outer
179            .listen_on((self.map_base_addr_to_outer)(addr))
180            .map_err(|e| e.map(CombinedError::Outer))?;
181        debug_assert!(self.proxy.pending.lock().is_none());
182        // 6. Stream returned by (5) will be joined with the one from (2) and returned from the
183        //    function
184        let upgrader = self.try_upgrade;
185        let combined_listener = stream::select(
186            base_listener
187                .map_ok(move |ev| {
188                    let cloned = match &ev {
189                        ListenerEvent::NewAddress(a) => Some(ListenerEvent::NewAddress(a.clone())),
190                        ListenerEvent::AddressExpired(a) => {
191                            Some(ListenerEvent::AddressExpired(a.clone()))
192                        }
193                        ListenerEvent::Error(_) => None, // Error is only propagated once, namely for the base transport
194                        ListenerEvent::Upgrade { .. } => None,
195                    };
196                    if let Some(ev) = cloned {
197                        tx.start_send(ev).unwrap();
198                    }
199                    let ev = match ev {
200                        ListenerEvent::Upgrade {
201                            upgrade,
202                            local_addr,
203                            remote_addr,
204                        } => {
205                            let local_addr_c = local_addr.clone();
206                            let remote_addr_c = remote_addr.clone();
207                            let mut tx_c = tx.clone();
208                            let upgrade = async move {
209                                match upgrade.await {
210                                    Ok(u) => {
211                                        // We could try to upgrade here; if it works, we emit an
212                                        // error for the base transport, and send the whole event over to
213                                        // the outer transport via `tx`. If the upgrade fails, we just
214                                        // continue.
215                                        match upgrader(u).await {
216                                            Ok(u) => {
217                                                // yay to outer
218                                                tx_c.start_send(ListenerEvent::Upgrade {
219                                                    // FUCK!
220                                                    // TBase::ListenerUpgrade is generic
221                                                    // Maybe this can be TransportProxy::Output?
222                                                    // ok, so the type of `tx` needs to be modified to
223                                                    // accomodate the types of ProxyTransport
224                                                    upgrade: future::ok(u).boxed(),
225                                                    local_addr: local_addr_c,
226                                                    remote_addr: remote_addr_c,
227                                                })
228                                                .expect("Out of sync with proxy");
229                                                Err(CombinedError::UpgradedToOuterTransport)
230                                            }
231                                            Err(u) => {
232                                                // continue
233                                                Ok(EitherOutput::First(u))
234                                            }
235                                        }
236                                    }
237                                    Err(e) => Err(CombinedError::Base(e)),
238                                }
239                            }
240                            .boxed();
241
242                            ListenerEvent::Upgrade {
243                                local_addr,
244                                remote_addr,
245                                upgrade,
246                            }
247                        }
248                        ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a),
249                        ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a),
250                        ListenerEvent::Error(e) => ListenerEvent::Error(e),
251                    };
252
253                    ev.map_err(CombinedError::Base)
254                })
255                .map_err(CombinedError::Base)
256                .boxed(),
257            outer_listener
258                .map_ok(|ev| {
259                    ev.map(|upgrade_fut| {
260                        upgrade_fut
261                            .map_ok(EitherOutput::Second)
262                            .map_err(CombinedError::Outer)
263                            .boxed()
264                    })
265                    .map_err(CombinedError::Outer)
266                })
267                .map_err(CombinedError::Outer)
268                .boxed(),
269        )
270        .boxed();
271        // 7. On an upgrade, check the switch, and route it either via outer or directly out
272        Ok(combined_listener)
273    }
274
275    fn dial(
276        self,
277        addr: libp2p::Multiaddr,
278    ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>>
279    where
280        Self: Sized,
281    {
282        let addr = match self.outer.dial(addr) {
283            Ok(connec) => {
284                return Ok(connec
285                    .map_ok(EitherOutput::Second)
286                    .map_err(CombinedError::Outer)
287                    .boxed())
288            }
289            Err(TransportError::MultiaddrNotSupported(addr)) => addr,
290            Err(TransportError::Other(err)) => {
291                return Err(TransportError::Other(CombinedError::Outer(err)))
292            }
293        };
294
295        let addr = match self.base.dial(addr) {
296            Ok(connec) => {
297                return Ok(connec
298                    .map_ok(EitherOutput::First)
299                    .map_err(CombinedError::Base)
300                    .boxed())
301            }
302            Err(TransportError::MultiaddrNotSupported(addr)) => addr,
303            Err(TransportError::Other(err)) => {
304                return Err(TransportError::Other(CombinedError::Base(err)))
305            }
306        };
307
308        Err(TransportError::MultiaddrNotSupported(addr))
309    }
310
311    fn address_translation(
312        &self,
313        listen: &libp2p::Multiaddr,
314        observed: &libp2p::Multiaddr,
315    ) -> Option<libp2p::Multiaddr> {
316        // Outer probably will call proxy, which will proxy to base
317        self.outer
318            .address_translation(listen, observed)
319            .or_else(|| self.base.address_translation(listen, observed))
320    }
321}
322
323pub struct ProxyTransport<TBase>
324where
325    Self: Transport,
326{
327    _marker: PhantomData<TBase>,
328    // 1-1 relation between [`CombinedTransport`] and [`ProxyTransport`]
329    pub(crate) pending: Arc<
330        Mutex<
331            Option<
332                mpsc::Receiver<
333                    ListenerEvent<<Self as Transport>::ListenerUpgrade, <Self as Transport>::Error>,
334                >,
335            >,
336        >,
337    >,
338    // Clone of TBase for dialing
339    base: TBase,
340}
341
342// TODO: simplify all those trait bounds
343impl<TBase> Clone for ProxyTransport<TBase>
344where
345    TBase: Transport + Clone,
346    TBase::Output: 'static,
347    TBase::Error: Send + 'static,
348{
349    fn clone(&self) -> Self {
350        Self {
351            _marker: Default::default(),
352            pending: Default::default(),
353            base: self.base.clone(),
354        }
355    }
356}
357
358impl<TBase> ProxyTransport<TBase>
359where
360    TBase: Transport + Clone,
361    TBase::Output: 'static,
362    TBase::Error: Send + 'static,
363{
364    fn new(base: TBase) -> Self {
365        Self {
366            pending: Default::default(),
367            _marker: Default::default(),
368            base,
369        }
370    }
371}
372
373impl<TBase> Transport for ProxyTransport<TBase>
374where
375    TBase: Transport + Clone,
376    TBase::Output: 'static,
377    TBase::Error: Send + 'static,
378{
379    type Output = TBase::Output;
380
381    type Error = TBase::Error;
382
383    type Listener =
384        BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
385
386    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
387
388    type Dial = TBase::Dial;
389
390    fn listen_on(
391        self,
392        _addr: libp2p::Multiaddr,
393    ) -> Result<Self::Listener, libp2p::TransportError<Self::Error>>
394    where
395        Self: Sized,
396    {
397        let listener = self
398            .pending
399            .lock()
400            .take()
401            .expect("Only called after successful base listen");
402        Ok(listener.map(Ok).boxed())
403    }
404
405    fn dial(
406        self,
407        addr: libp2p::Multiaddr,
408    ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>>
409    where
410        Self: Sized,
411    {
412        self.base.dial(addr)
413    }
414
415    fn address_translation(
416        &self,
417        listen: &libp2p::Multiaddr,
418        observed: &libp2p::Multiaddr,
419    ) -> Option<libp2p::Multiaddr> {
420        self.base.address_translation(listen, observed)
421    }
422}