1#![allow(clippy::type_complexity)]
2use std::{fmt, marker::PhantomData, sync::Arc};
3
4use futures::{
5 channel::mpsc,
6 future::{self, BoxFuture},
7 stream::{self, BoxStream},
8 FutureExt, StreamExt, TryFutureExt, TryStreamExt,
9};
10use libp2p::{
11 core::{either::EitherOutput, transport::ListenerEvent},
12 Multiaddr, Transport, TransportError,
13};
14use parking_lot::Mutex;
15
16pub struct CombinedTransport<TBase, TOuter>
29where
30 TBase: Transport + Clone,
31 TBase::Error: Send + 'static,
32 TBase::Output: 'static,
33{
34 base: TBase,
36 outer: TOuter,
38 construct_outer: fn(ProxyTransport<TBase>) -> TOuter,
40 proxy: ProxyTransport<TBase>,
41 try_upgrade: MaybeUpgrade<TBase>,
43 map_base_addr_to_outer: fn(Multiaddr) -> Multiaddr,
44}
45
46impl<TBase, TOuter> CombinedTransport<TBase, TOuter>
47where
48 TBase: Transport + Clone,
49 TBase::Error: Send + 'static,
50 TBase::Output: 'static,
51{
52 pub fn new(
57 base: TBase,
58 construct_outer: fn(ProxyTransport<TBase>) -> TOuter,
59 try_upgrade: MaybeUpgrade<TBase>,
60 map_base_addr_to_outer: fn(Multiaddr) -> Multiaddr,
61 ) -> Self {
62 let proxy = ProxyTransport::<TBase>::new(base.clone());
63 let mut proxy_clone = proxy.clone();
64 proxy_clone.pending = proxy.pending.clone();
65 let outer = construct_outer(proxy_clone);
66 Self {
67 base,
68 proxy,
69 outer,
70 construct_outer,
71 try_upgrade,
72 map_base_addr_to_outer,
73 }
74 }
75}
76impl<TBase, TOuter> Clone for CombinedTransport<TBase, TOuter>
77where
78 TBase: Transport + Clone,
79 TBase::Error: Send + 'static,
80 TBase::Output: 'static,
81{
82 fn clone(&self) -> Self {
83 Self::new(
84 self.base.clone(),
85 self.construct_outer,
86 self.try_upgrade,
87 self.map_base_addr_to_outer,
88 )
89 }
90}
91
92type MaybeUpgrade<TBase> =
93 fn(
94 <TBase as Transport>::Output,
95 )
96 -> BoxFuture<'static, Result<<TBase as Transport>::Output, <TBase as Transport>::Output>>;
97
/// Error type of `CombinedTransport`.
#[derive(Debug, Copy, Clone)]
pub enum CombinedError<Base, Outer> {
    /// The inbound base connection was handed over to the outer transport, so the
    /// base side of that connection does not produce an output.
    UpgradedToOuterTransport,
    /// Error reported by the base transport.
    Base(Base),
    /// Error reported by the outer transport.
    Outer(Outer),
}

/// Transparent formatting: wrapped errors are displayed exactly like the inner error.
impl<A, B> fmt::Display for CombinedError<A, B>
where
    A: fmt::Display,
    B: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UpgradedToOuterTransport => f.write_str("Upgraded to outer transport"),
            Self::Base(inner) => fmt::Display::fmt(inner, f),
            Self::Outer(inner) => fmt::Display::fmt(inner, f),
        }
    }
}

/// Transparent source chain: reports the wrapped error's own source.
impl<A, B> std::error::Error for CombinedError<A, B>
where
    A: std::error::Error,
    B: std::error::Error,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::UpgradedToOuterTransport => None,
            Self::Base(inner) => inner.source(),
            Self::Outer(inner) => inner.source(),
        }
    }
}
131
132impl<TBase, TOuter> Transport for CombinedTransport<TBase, TOuter>
133where
134 TBase: Transport + Clone,
135 TBase::Listener: Send + 'static,
136 TBase::ListenerUpgrade: Send + 'static,
137 TBase::Error: Send + 'static,
138 TBase::Output: Send + 'static,
139 TBase::Dial: Send + 'static,
140 TOuter: Transport,
141 TOuter::Listener: Send + 'static,
142 TOuter::ListenerUpgrade: Send + 'static,
143 TOuter::Error: 'static,
144 TOuter::Output: 'static,
145 TOuter::Dial: Send + 'static,
146{
147 type Output = EitherOutput<TBase::Output, TOuter::Output>;
148
149 type Error = CombinedError<TBase::Error, TOuter::Error>;
150
151 type Listener =
152 BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
153 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
154 type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
155
156 fn listen_on(
157 self,
158 addr: libp2p::Multiaddr,
159 ) -> Result<Self::Listener, libp2p::TransportError<Self::Error>>
160 where
161 Self: Sized,
162 {
163 let base_listener = self
166 .base
167 .listen_on(addr.clone())
168 .map_err(|e| e.map(CombinedError::Base))?;
169 let (mut tx, rx) = mpsc::channel(256);
172 let x = self.proxy.pending.lock().replace(rx);
174 debug_assert!(x.is_none());
175 let outer_listener = self
178 .outer
179 .listen_on((self.map_base_addr_to_outer)(addr))
180 .map_err(|e| e.map(CombinedError::Outer))?;
181 debug_assert!(self.proxy.pending.lock().is_none());
182 let upgrader = self.try_upgrade;
185 let combined_listener = stream::select(
186 base_listener
187 .map_ok(move |ev| {
188 let cloned = match &ev {
189 ListenerEvent::NewAddress(a) => Some(ListenerEvent::NewAddress(a.clone())),
190 ListenerEvent::AddressExpired(a) => {
191 Some(ListenerEvent::AddressExpired(a.clone()))
192 }
193 ListenerEvent::Error(_) => None, ListenerEvent::Upgrade { .. } => None,
195 };
196 if let Some(ev) = cloned {
197 tx.start_send(ev).unwrap();
198 }
199 let ev = match ev {
200 ListenerEvent::Upgrade {
201 upgrade,
202 local_addr,
203 remote_addr,
204 } => {
205 let local_addr_c = local_addr.clone();
206 let remote_addr_c = remote_addr.clone();
207 let mut tx_c = tx.clone();
208 let upgrade = async move {
209 match upgrade.await {
210 Ok(u) => {
211 match upgrader(u).await {
216 Ok(u) => {
217 tx_c.start_send(ListenerEvent::Upgrade {
219 upgrade: future::ok(u).boxed(),
225 local_addr: local_addr_c,
226 remote_addr: remote_addr_c,
227 })
228 .expect("Out of sync with proxy");
229 Err(CombinedError::UpgradedToOuterTransport)
230 }
231 Err(u) => {
232 Ok(EitherOutput::First(u))
234 }
235 }
236 }
237 Err(e) => Err(CombinedError::Base(e)),
238 }
239 }
240 .boxed();
241
242 ListenerEvent::Upgrade {
243 local_addr,
244 remote_addr,
245 upgrade,
246 }
247 }
248 ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a),
249 ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a),
250 ListenerEvent::Error(e) => ListenerEvent::Error(e),
251 };
252
253 ev.map_err(CombinedError::Base)
254 })
255 .map_err(CombinedError::Base)
256 .boxed(),
257 outer_listener
258 .map_ok(|ev| {
259 ev.map(|upgrade_fut| {
260 upgrade_fut
261 .map_ok(EitherOutput::Second)
262 .map_err(CombinedError::Outer)
263 .boxed()
264 })
265 .map_err(CombinedError::Outer)
266 })
267 .map_err(CombinedError::Outer)
268 .boxed(),
269 )
270 .boxed();
271 Ok(combined_listener)
273 }
274
275 fn dial(
276 self,
277 addr: libp2p::Multiaddr,
278 ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>>
279 where
280 Self: Sized,
281 {
282 let addr = match self.outer.dial(addr) {
283 Ok(connec) => {
284 return Ok(connec
285 .map_ok(EitherOutput::Second)
286 .map_err(CombinedError::Outer)
287 .boxed())
288 }
289 Err(TransportError::MultiaddrNotSupported(addr)) => addr,
290 Err(TransportError::Other(err)) => {
291 return Err(TransportError::Other(CombinedError::Outer(err)))
292 }
293 };
294
295 let addr = match self.base.dial(addr) {
296 Ok(connec) => {
297 return Ok(connec
298 .map_ok(EitherOutput::First)
299 .map_err(CombinedError::Base)
300 .boxed())
301 }
302 Err(TransportError::MultiaddrNotSupported(addr)) => addr,
303 Err(TransportError::Other(err)) => {
304 return Err(TransportError::Other(CombinedError::Base(err)))
305 }
306 };
307
308 Err(TransportError::MultiaddrNotSupported(addr))
309 }
310
311 fn address_translation(
312 &self,
313 listen: &libp2p::Multiaddr,
314 observed: &libp2p::Multiaddr,
315 ) -> Option<libp2p::Multiaddr> {
316 self.outer
318 .address_translation(listen, observed)
319 .or_else(|| self.base.address_translation(listen, observed))
320 }
321}
322
323pub struct ProxyTransport<TBase>
324where
325 Self: Transport,
326{
327 _marker: PhantomData<TBase>,
328 pub(crate) pending: Arc<
330 Mutex<
331 Option<
332 mpsc::Receiver<
333 ListenerEvent<<Self as Transport>::ListenerUpgrade, <Self as Transport>::Error>,
334 >,
335 >,
336 >,
337 >,
338 base: TBase,
340}
341
342impl<TBase> Clone for ProxyTransport<TBase>
344where
345 TBase: Transport + Clone,
346 TBase::Output: 'static,
347 TBase::Error: Send + 'static,
348{
349 fn clone(&self) -> Self {
350 Self {
351 _marker: Default::default(),
352 pending: Default::default(),
353 base: self.base.clone(),
354 }
355 }
356}
357
358impl<TBase> ProxyTransport<TBase>
359where
360 TBase: Transport + Clone,
361 TBase::Output: 'static,
362 TBase::Error: Send + 'static,
363{
364 fn new(base: TBase) -> Self {
365 Self {
366 pending: Default::default(),
367 _marker: Default::default(),
368 base,
369 }
370 }
371}
372
373impl<TBase> Transport for ProxyTransport<TBase>
374where
375 TBase: Transport + Clone,
376 TBase::Output: 'static,
377 TBase::Error: Send + 'static,
378{
379 type Output = TBase::Output;
380
381 type Error = TBase::Error;
382
383 type Listener =
384 BoxStream<'static, Result<ListenerEvent<Self::ListenerUpgrade, Self::Error>, Self::Error>>;
385
386 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
387
388 type Dial = TBase::Dial;
389
390 fn listen_on(
391 self,
392 _addr: libp2p::Multiaddr,
393 ) -> Result<Self::Listener, libp2p::TransportError<Self::Error>>
394 where
395 Self: Sized,
396 {
397 let listener = self
398 .pending
399 .lock()
400 .take()
401 .expect("Only called after successful base listen");
402 Ok(listener.map(Ok).boxed())
403 }
404
405 fn dial(
406 self,
407 addr: libp2p::Multiaddr,
408 ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>>
409 where
410 Self: Sized,
411 {
412 self.base.dial(addr)
413 }
414
415 fn address_translation(
416 &self,
417 listen: &libp2p::Multiaddr,
418 observed: &libp2p::Multiaddr,
419 ) -> Option<libp2p::Multiaddr> {
420 self.base.address_translation(listen, observed)
421 }
422}