magic_wormhole/
forwarding.rs

//! Client-to-Client protocol to forward TCP connections
//!
//! This is a new (and still slightly experimental) feature that allows you to forward TCP connections over a wormhole
//! `transit` connection.
//!
//! It is bound to an [`APPID`], which is distinct from the one used for file transfer. Therefore, the codes used
//! for port forwarding are in a namespace independent of the one used for sending files.
//!
//! At its core, "peer messages" are exchanged over an established wormhole connection with the other side.
//! They are used to set up a [`transit`] portal that will be used instead of the wormhole connection, which will be closed.
//! Connections are tracked via an identifier, and multiplexed over the transit channel. The forwarding is
//! "logical" and not "raw", because "TCP in TCP" tunneling is known to be problematic. Packets are sent
//! and received as they come in; no additional buffering is applied. (Under the assumption that those applications
//! that need buffering already do it on their side, and those that don't, don't.)
15
16#![allow(deprecated)]
17
18use super::*;
19use async_std::net::{TcpListener, TcpStream};
20use futures::{AsyncReadExt, AsyncWriteExt, Future, SinkExt, StreamExt, TryStreamExt};
21use serde::{Deserialize, Serialize};
22use std::{
23    borrow::Cow,
24    collections::{HashMap, HashSet},
25    rc::Rc,
26    sync::Arc,
27};
28use transit::{TransitConnectError, TransitError};
29
30const APPID_RAW: &str = "piegames.de/wormhole/port-forwarding";
31
32/// The App ID associated with this protocol.
33pub const APPID: AppID = AppID(Cow::Borrowed(APPID_RAW));
34
35/// An [`crate::AppConfig`] with sane defaults for this protocol.
36///
37/// You **must not** change `id` and `rendezvous_url` to be interoperable.
38/// The `app_version` can be adjusted if you want to disable some features.
39pub const APP_CONFIG: crate::AppConfig<AppVersion> = crate::AppConfig::<AppVersion> {
40    id: AppID(Cow::Borrowed(APPID_RAW)),
41    rendezvous_url: Cow::Borrowed(crate::rendezvous::DEFAULT_RENDEZVOUS_SERVER),
42    app_version: AppVersion {
43        transit_abilities: transit::Abilities::ALL_ABILITIES,
44        other: serde_json::Value::Null,
45    },
46};
47
48/**
49 * The application specific version information for this protocol.
50 */
51#[derive(Clone, Debug, Default, Serialize, Deserialize)]
52pub struct AppVersion {
53    /// Our transit abilities
54    pub transit_abilities: transit::Abilities,
55    #[serde(flatten)]
56    other: serde_json::Value,
57}
58
59#[derive(Debug, thiserror::Error)]
60#[non_exhaustive]
61/// An error occurred when establishing a port forwarding session
62pub enum ForwardingError {
63    /// Transfer was not acknowledged by peer
64    #[error("Transfer was not acknowledged by peer")]
65    AckError,
66    /// Something went wrong on the other side
67    #[error("Something went wrong on the other side: {}", _0)]
68    PeerError(String),
69    /// Some deserialization went wrong, we probably got some garbage
70    #[error("Corrupt JSON message received")]
71    ProtocolJson(
72        #[from]
73        #[source]
74        serde_json::Error,
75    ),
76    /// Some deserialization went wrong, we probably got some garbage
77    #[error("Corrupt Msgpack message received")]
78    ProtocolMsgpack(
79        #[from]
80        #[source]
81        rmp_serde::decode::Error,
82    ),
83    /// A generic string message for "something went wrong", i.e.
84    /// the server sent some bullshit message order
85    #[error("Protocol error: {}", _0)]
86    Protocol(Box<str>),
87    /// Unexpected message (protocol error)
88    #[error(
89        "Unexpected message (protocol error): Expected '{}', but got: {:?}",
90        _0,
91        _1
92    )]
93    ProtocolUnexpectedMessage(Box<str>, Box<dyn std::fmt::Debug + Send + Sync>),
94    /// Wormhole connection error
95    #[error("Wormhole connection error")]
96    Wormhole(
97        #[from]
98        #[source]
99        WormholeError,
100    ),
101    /// Error while establishing transit connection
102    #[error("Error while establishing transit connection")]
103    TransitConnect(
104        #[from]
105        #[source]
106        TransitConnectError,
107    ),
108    /// Transit error
109    #[error("Transit error")]
110    Transit(
111        #[from]
112        #[source]
113        TransitError,
114    ),
115    /// I/O error
116    #[error("I/O error")]
117    IO(
118        #[from]
119        #[source]
120        std::io::Error,
121    ),
122}
123
124impl ForwardingError {
125    fn protocol(message: impl Into<Box<str>>) -> Self {
126        Self::Protocol(message.into())
127    }
128
129    pub(self) fn unexpected_message(
130        expected: impl Into<Box<str>>,
131        got: impl std::fmt::Debug + Send + Sync + 'static,
132    ) -> Self {
133        Self::ProtocolUnexpectedMessage(expected.into(), Box::new(got))
134    }
135}
136
137/// Offer to forward some ports
138///
139/// `targets` is a mapping of (host, port) pairs. If no target host is provided, then
140/// a local port will be forwarded (`localhost`). Forwarding remote ports only works well
141/// when the protocol being forwarded is not host-aware. HTTP, for example, is host aware.
142///
143/// The port forwarding will run until an error occurs, the peer terminates the connection
144/// or `cancel` resolves. The last one can be used to provide timeouts or to inject CTRL-C
145/// handling. If you want the forward to never (successfully) stop, pass [`futures::future::pending()`]
146/// as the value.
147pub async fn serve(
148    mut wormhole: Wormhole,
149    transit_handler: impl FnOnce(transit::TransitInfo),
150    relay_hints: Vec<transit::RelayHint>,
151    targets: Vec<(Option<url::Host>, u16)>,
152    cancel: impl Future<Output = ()>,
153) -> Result<(), ForwardingError> {
154    assert!(
155        !targets.is_empty(),
156        "The list of target ports must not be empty"
157    );
158
159    let our_version: &AppVersion = wormhole
160        .our_version()
161        .downcast_ref()
162        .expect("You may only use a Wormhole instance with the correct AppVersion type!");
163    let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
164    let connector = transit::init(
165        our_version.transit_abilities,
166        Some(peer_version.transit_abilities),
167        relay_hints,
168    )
169    .await?;
170
171    /* Send our transit hints */
172    wormhole
173        .send_json(&PeerMessage::Transit {
174            hints: (**connector.our_hints()).clone(),
175        })
176        .await?;
177
178    let targets: HashMap<String, (Option<url::Host>, u16)> = targets
179        .into_iter()
180        .map(|(host, port)| match host {
181            Some(host) => {
182                if port == 80 || port == 443 || port == 8000 || port == 8080 {
183                    tracing::warn!("It seems like you are trying to forward a remote HTTP target ('{}'). Due to HTTP being host-aware this will very likely fail!", host);
184                }
185                (format!("{}:{}", host, port), (Some(host), port))
186            },
187            None => (port.to_string(), (host, port)),
188        })
189        .collect();
190
191    /* Receive their transit hints */
192    let their_hints: transit::Hints = match wormhole.receive_json().await?? {
193        PeerMessage::Transit { hints } => {
194            tracing::debug!("Received transit message: {:?}", hints);
195            hints
196        },
197        PeerMessage::Error(err) => {
198            bail!(ForwardingError::PeerError(err));
199        },
200        other => {
201            let error = ForwardingError::unexpected_message("transit", other);
202            let _ = wormhole
203                .send_json(&PeerMessage::Error(format!("{}", error)))
204                .await;
205            bail!(error)
206        },
207    };
208
209    let (mut transit, info) = match connector
210        .leader_connect(
211            wormhole.key().derive_transit_key(wormhole.appid()),
212            peer_version.transit_abilities,
213            Arc::new(their_hints),
214        )
215        .await
216    {
217        Ok(transit) => transit,
218        Err(error) => {
219            let error = ForwardingError::TransitConnect(error);
220            let _ = wormhole
221                .send_json(&PeerMessage::Error(format!("{}", error)))
222                .await;
223            return Err(error);
224        },
225    };
226    transit_handler(info);
227
228    /* We got a transit, now close the Wormhole */
229    wormhole.close().await?;
230
231    transit
232        .send_record(
233            &PeerMessage::Offer {
234                addresses: targets.keys().cloned().collect(),
235            }
236            .ser_msgpack(),
237        )
238        .await?;
239
240    let (backchannel_tx, backchannel_rx) =
241        futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
242
243    let (transit_tx, transit_rx) = transit.split();
244    let transit_rx = transit_rx.fuse();
245    use futures::future::FutureExt;
246    let cancel = cancel.fuse();
247    futures::pin_mut!(transit_tx);
248    futures::pin_mut!(transit_rx);
249    futures::pin_mut!(cancel);
250
251    /* Main processing loop. Catch errors */
252    let result = ForwardingServe {
253        targets,
254        connections: HashMap::new(),
255        historic_connections: HashSet::new(),
256        backchannel_tx,
257        backchannel_rx,
258    }
259    .run(&mut transit_tx, &mut transit_rx, &mut cancel)
260    .await;
261    /* If the error is not a PeerError (i.e. coming from the other side), try notifying the other side before quitting. */
262    match result {
263        Ok(()) => Ok(()),
264        Err(error @ ForwardingError::PeerError(_)) => Err(error),
265        Err(error) => {
266            let _ = transit_tx
267                .send(
268                    PeerMessage::Error(format!("{}", error))
269                        .ser_msgpack()
270                        .into_boxed_slice(),
271                )
272                .await;
273            Err(error)
274        },
275    }
276}
277
278struct ForwardingServe {
279    targets: HashMap<String, (Option<url::Host>, u16)>,
280    /* self => remote */
281    connections: HashMap<
282        u64,
283        (
284            async_std::task::JoinHandle<()>,
285            futures::io::WriteHalf<TcpStream>,
286        ),
287    >,
288    /* Track old connection IDs that won't be reused again. This is to distinguish race hazards where
289     * one side closes a connection while the other one accesses it simultaneously. Despite the name, the
290     * set also includes connections that are currently live.
291     */
292    historic_connections: HashSet<u64>,
293    /* remote => self. (connection_id, Some=payload or None=close) */
294    backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
295    backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
296}
297
298//futures::pin_mut!(backchannel_rx);
299impl ForwardingServe {
300    async fn forward(
301        &mut self,
302        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
303        connection_id: u64,
304        payload: &[u8],
305    ) -> Result<(), ForwardingError> {
306        tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
307        match self.connections.get_mut(&connection_id) {
308            Some((_worker, connection)) => {
309                /* On an error, log for the user and then terminate that connection */
310                if let Err(e) = connection.write_all(payload).await {
311                    tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
312                    self.remove_connection(transit_tx, connection_id, true)
313                        .await?;
314                }
315            },
316            None if !self.historic_connections.contains(&connection_id) => {
317                bail!(ForwardingError::protocol(format!(
318                    "Connection '{}' not found",
319                    connection_id
320                )));
321            },
322            None => { /* Race hazard. Do nothing. */ },
323        }
324        Ok(())
325    }
326
327    async fn remove_connection(
328        &mut self,
329        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
330        connection_id: u64,
331        tell_peer: bool,
332    ) -> Result<(), ForwardingError> {
333        tracing::debug!("Removing connection: #{}", connection_id);
334        if tell_peer {
335            transit_tx
336                .send(
337                    PeerMessage::Disconnect { connection_id }
338                        .ser_msgpack()
339                        .into_boxed_slice(),
340                )
341                .await?;
342        }
343        match self.connections.remove(&connection_id) {
344            Some((worker, _connection)) => {
345                worker.cancel().await;
346            },
347            None if !self.historic_connections.contains(&connection_id) => {
348                bail!(ForwardingError::protocol(format!(
349                    "Connection '{}' not found",
350                    connection_id
351                )));
352            },
353            None => { /* Race hazard. Do nothing. */ },
354        }
355        Ok(())
356    }
357
358    async fn spawn_connection(
359        &mut self,
360        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
361        mut target: String,
362        connection_id: u64,
363    ) -> Result<(), ForwardingError> {
364        tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
365
366        use std::collections::hash_map::Entry;
367        let entry = match self.connections.entry(connection_id) {
368            Entry::Vacant(entry) => entry,
369            Entry::Occupied(_) => {
370                bail!(ForwardingError::protocol(format!(
371                    "Connection '{}' already exists",
372                    connection_id
373                )));
374            },
375        };
376
377        let (host, port) = self.targets.get(&target).unwrap();
378        if host.is_none() {
379            target = format!("[::1]:{}", port);
380        }
381        let stream = match TcpStream::connect(&target).await {
382            Ok(stream) => stream,
383            Err(err) => {
384                tracing::warn!(
385                    "Cannot open connection to {}: {}. The forwarded service might be down.",
386                    target,
387                    err
388                );
389                transit_tx
390                    .send(
391                        PeerMessage::Disconnect { connection_id }
392                            .ser_msgpack()
393                            .into_boxed_slice(),
394                    )
395                    .await?;
396                return Ok(());
397            },
398        };
399        let (mut connection_rd, connection_wr) = stream.split();
400        let mut backchannel_tx = self.backchannel_tx.clone();
401        let worker = async_std::task::spawn_local(async move {
402            let mut buffer = vec![0; 4096];
403            /* Ignore errors */
404            macro_rules! break_on_err {
405                ($expr:expr) => {
406                    match $expr {
407                        Ok(val) => val,
408                        Err(_) => break,
409                    }
410                };
411            }
412            #[allow(clippy::while_let_loop)]
413            loop {
414                let read = break_on_err!(connection_rd.read(&mut buffer).await);
415                if read == 0 {
416                    break;
417                }
418                let buffer = &buffer[..read];
419                break_on_err!(
420                    backchannel_tx
421                        .send((connection_id, Some(buffer.to_vec())))
422                        .await
423                );
424            }
425            /* Close connection (maybe or not because of error) */
426            let _ = backchannel_tx.send((connection_id, None)).await;
427            backchannel_tx.disconnect();
428        });
429        entry.insert((worker, connection_wr));
430        Ok(())
431    }
432
433    async fn shutdown(self) {
434        tracing::debug!("Shutting down everything");
435        for (worker, _connection) in self.connections.into_values() {
436            worker.cancel().await;
437        }
438    }
439
440    async fn run(
441        mut self,
442        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
443        transit_rx: &mut (impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>>
444                  + Unpin),
445        cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
446    ) -> Result<(), ForwardingError> {
447        /* Event processing loop */
448        tracing::debug!("Entered processing loop");
449        let ret = loop {
450            futures::select! {
451                message = transit_rx.next() => {
452                    match PeerMessage::de_msgpack(&message.unwrap()?)? {
453                        PeerMessage::Forward { connection_id, payload } => {
454                            self.forward(transit_tx, connection_id, &payload).await?
455                        },
456                        PeerMessage::Connect { target, connection_id } => {
457                            /* No matter what happens, as soon as we receive the "connect" command that ID is burned. */
458                            self.historic_connections.insert(connection_id);
459                            ensure!(
460                                self.targets.contains_key(&target),
461                                ForwardingError::protocol(format!("We don't know forwarding target '{}'", target)),
462                            );
463
464                            self.spawn_connection(transit_tx, target, connection_id).await?;
465                        },
466                        PeerMessage::Disconnect { connection_id } => {
467                            self.remove_connection(transit_tx, connection_id, false).await?;
468                        },
469                        PeerMessage::Close => {
470                            tracing::info!("Peer gracefully closed connection");
471                            self.shutdown().await;
472                            break Ok(());
473                        },
474                        PeerMessage::Error(err) => {
475                            self.shutdown().await;
476                            bail!(ForwardingError::PeerError(err));
477                        },
478                        other => {
479                            self.shutdown().await;
480                            bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
481                        },
482                    }
483                },
484                message = self.backchannel_rx.next() => {
485                    /* This channel will never run dry, since we always have at least one sender active */
486                    match message.unwrap() {
487                        (connection_id, Some(payload)) => {
488                            transit_tx.send(
489                                PeerMessage::Forward {
490                                    connection_id,
491                                    payload
492                                }
493                                .ser_msgpack()
494                                .into_boxed_slice()
495                            ).await?;
496                        },
497                        (connection_id, None) => {
498                            self.remove_connection(transit_tx, connection_id, true).await?;
499                        },
500                    }
501                },
502                /* We are done */
503                () = &mut *cancel => {
504                    tracing::info!("Closing connection");
505                    transit_tx.send(
506                        PeerMessage::Close.ser_msgpack()
507                        .into_boxed_slice()
508                    )
509                    .await?;
510                    transit_tx.close().await?;
511                    self.shutdown().await;
512                    break Ok(());
513                },
514            }
515        };
516        tracing::debug!("Exited processing loop");
517        ret
518    }
519}
520
521/// Request a port forwarding offer from the other side
522///
523/// You can optionally specify a `bind_address` where the port forwarding
524/// will be made available. You can also specify a list of `custom_ports` that
525/// will be used for the forwarding. The mapping between custom ports and forwarded
526/// targets is 1:1 and order preserving. If more ports are forwarded than custom
527/// ports were specified, then the remaining ports will be arbitrary.
528///
529/// The method returns a [`ConnectOffer`] from which the resulting port mapping can
530/// be queried. That struct also has an `accept` and `reject` method, of which one
531/// must be used.
532///
533/// This method already binds to all the necessary ports up-front. To limit abuse potential
534/// no more than 1024 ports may be forwarded at once.
535pub async fn connect(
536    mut wormhole: Wormhole,
537    transit_handler: impl FnOnce(transit::TransitInfo),
538    relay_hints: Vec<transit::RelayHint>,
539    bind_address: Option<std::net::IpAddr>,
540    custom_ports: &[u16],
541) -> Result<ConnectOffer, ForwardingError> {
542    let our_version: &AppVersion = wormhole
543        .our_version()
544        .downcast_ref()
545        .expect("You may only use a Wormhole instance with the correct AppVersion type!");
546    let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
547    let connector = transit::init(
548        our_version.transit_abilities,
549        Some(peer_version.transit_abilities),
550        relay_hints,
551    )
552    .await?;
553    let bind_address = bind_address.unwrap_or_else(|| std::net::IpAddr::V6("::".parse().unwrap()));
554
555    /* Send our transit hints */
556    wormhole
557        .send_json(&PeerMessage::Transit {
558            hints: (**connector.our_hints()).clone(),
559        })
560        .await?;
561
562    /* Receive their transit hints */
563    let their_hints: transit::Hints = match wormhole.receive_json().await?? {
564        PeerMessage::Transit { hints } => {
565            tracing::debug!("Received transit message: {:?}", hints);
566            hints
567        },
568        PeerMessage::Error(err) => {
569            bail!(ForwardingError::PeerError(err));
570        },
571        other => {
572            let error = ForwardingError::unexpected_message("transit", other);
573            let _ = wormhole
574                .send_json(&PeerMessage::Error(format!("{}", error)))
575                .await;
576            bail!(error)
577        },
578    };
579
580    let (mut transit, info) = match connector
581        .follower_connect(
582            wormhole.key().derive_transit_key(wormhole.appid()),
583            peer_version.transit_abilities,
584            Arc::new(their_hints),
585        )
586        .await
587    {
588        Ok(transit) => transit,
589        Err(error) => {
590            let error = ForwardingError::TransitConnect(error);
591            let _ = wormhole
592                .send_json(&PeerMessage::Error(format!("{}", error)))
593                .await;
594            return Err(error);
595        },
596    };
597    transit_handler(info);
598
599    /* We got a transit, now close the Wormhole */
600    wormhole.close().await?;
601
602    let run = async {
603        /* Receive offer and ask user */
604
605        let addresses = match PeerMessage::de_msgpack(&transit.receive_record().await?)? {
606            PeerMessage::Offer { addresses } => addresses,
607            PeerMessage::Error(err) => {
608                bail!(ForwardingError::PeerError(err));
609            },
610            other => {
611                bail!(ForwardingError::unexpected_message("offer", other))
612            },
613        };
614
615        /* Sanity check on untrusted input */
616        if addresses.len() > 1024 {
617            return Err(ForwardingError::protocol("Too many forwarded ports"));
618        }
619
620        /* self => remote
621         *                  (address, connection)
622         * Vec<Stream<Item = (String, TcpStream)>>
623         */
624        let listeners: Vec<(
625            async_std::net::TcpListener,
626            u16,
627            std::rc::Rc<std::string::String>,
628        )> = futures::stream::iter(
629            addresses
630                .into_iter()
631                .map(Rc::new)
632                .zip(custom_ports.iter().copied().chain(std::iter::repeat(0))),
633        )
634        .then(|(address, port)| async move {
635            let connection = TcpListener::bind((bind_address, port)).await?;
636            let port = connection.local_addr()?.port();
637            Result::<_, std::io::Error>::Ok((connection, port, address))
638        })
639        .try_collect()
640        .await?;
641        Ok(listeners)
642    };
643
644    match run.await {
645        Ok(listeners) => Ok(ConnectOffer {
646            transit,
647            mapping: listeners.iter().map(|(_, b, c)| (*b, c.clone())).collect(),
648            listeners,
649        }),
650        Err(error @ ForwardingError::PeerError(_)) => Err(error),
651        Err(error) => {
652            let _ = transit
653                .send_record(&PeerMessage::Error(format!("{}", error)).ser_msgpack())
654                .await;
655            Err(error)
656        },
657    }
658}
659
660/// A pending forwarding offer from the other side
661///
662/// You *should* consume this object, either by calling [`accept`](ConnectOffer::accept) or [`reject`](ConnectOffer::reject).
663#[must_use]
664pub struct ConnectOffer {
665    /// The offered port mapping
666    pub mapping: Vec<(u16, Rc<String>)>,
667    transit: transit::Transit,
668    listeners: Vec<(
669        async_std::net::TcpListener,
670        u16,
671        std::rc::Rc<std::string::String>,
672    )>,
673}
674
675impl ConnectOffer {
676    /// Accept the offer and start the forwarding
677    ///
678    /// The method will run until an error occurs, the peer terminates the connection
679    /// or `cancel` resolves. The last one can be used to provide timeouts or to inject CTRL-C
680    /// handling. If you want the forward to never (successfully) stop, pass [`futures::future::pending()`]
681    /// as the value.
682    pub async fn accept(self, cancel: impl Future<Output = ()>) -> Result<(), ForwardingError> {
683        let (transit_tx, transit_rx) = self.transit.split();
684        let transit_rx = transit_rx.fuse();
685        use futures::FutureExt;
686        let cancel = cancel.fuse();
687        futures::pin_mut!(transit_tx);
688        futures::pin_mut!(transit_rx);
689        futures::pin_mut!(cancel);
690
691        /* Error handling catcher (see below) */
692        let run = async {
693            let (backchannel_tx, backchannel_rx) =
694                futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
695
696            ForwardConnect {
697                incoming: futures::stream::select_all(self.listeners.into_iter().map(
698                    |(connection, _, address)| {
699                        connection
700                            .into_incoming()
701                            .map_ok(move |stream| (address.clone(), stream))
702                            .boxed_local()
703                    },
704                )),
705                connection_counter: 0,
706                connections: HashMap::new(),
707                backchannel_tx,
708                backchannel_rx,
709            }
710            .run(&mut transit_tx, &mut transit_rx, &mut cancel)
711            .await
712        };
713
714        match run.await {
715            Ok(()) => Ok(()),
716            Err(error @ ForwardingError::PeerError(_)) => Err(error),
717            Err(error) => {
718                let _ = transit_tx
719                    .send(
720                        PeerMessage::Error(format!("{}", error))
721                            .ser_msgpack()
722                            .into_boxed_slice(),
723                    )
724                    .await;
725                Err(error)
726            },
727        }
728    }
729
730    /// Reject the offer
731    ///
732    /// This will send an error message to the other side so that it knows the transfer failed.
733    pub async fn reject(mut self) -> Result<(), ForwardingError> {
734        self.transit
735            .send_record(&PeerMessage::Error("transfer rejected".into()).ser_msgpack())
736            .await?;
737
738        Ok(())
739    }
740}
741
742#[allow(clippy::type_complexity)]
743struct ForwardConnect {
744    //transit: &'a mut transit::Transit,
745    /* when can I finally store an `impl Trait` in a struct? */
746    incoming: futures::stream::SelectAll<
747        futures::stream::LocalBoxStream<
748            'static,
749            Result<(Rc<String>, async_std::net::TcpStream), std::io::Error>,
750        >,
751    >,
752    /* Our next unique connection_id */
753    connection_counter: u64,
754    connections: HashMap<
755        u64,
756        (
757            async_std::task::JoinHandle<()>,
758            futures::io::WriteHalf<TcpStream>,
759        ),
760    >,
761    /* application => self. (connection_id, Some=payload or None=close) */
762    backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
763    backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
764}
765
766impl ForwardConnect {
767    async fn forward(
768        &mut self,
769        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
770        connection_id: u64,
771        payload: &[u8],
772    ) -> Result<(), ForwardingError> {
773        tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
774        match self.connections.get_mut(&connection_id) {
775            Some((_worker, connection)) => {
776                /* On an error, log for the user and then terminate that connection */
777                if let Err(e) = connection.write_all(payload).await {
778                    tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
779                    self.remove_connection(transit_tx, connection_id, true)
780                        .await?;
781                }
782            },
783            None if self.connection_counter <= connection_id => {
784                bail!(ForwardingError::protocol(format!(
785                    "Connection '{}' not found",
786                    connection_id
787                )));
788            },
789            None => { /* Race hazard. Do nothing. */ },
790        }
791        Ok(())
792    }
793
794    async fn remove_connection(
795        &mut self,
796        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
797        connection_id: u64,
798        tell_peer: bool,
799    ) -> Result<(), ForwardingError> {
800        tracing::debug!("Removing connection: #{}", connection_id);
801        if tell_peer {
802            transit_tx
803                .send(
804                    PeerMessage::Disconnect { connection_id }
805                        .ser_msgpack()
806                        .into_boxed_slice(),
807                )
808                .await?;
809        }
810        match self.connections.remove(&connection_id) {
811            Some((worker, _connection)) => {
812                worker.cancel().await;
813            },
814            None if connection_id >= self.connection_counter => {
815                bail!(ForwardingError::protocol(format!(
816                    "Connection '{}' not found",
817                    connection_id
818                )));
819            },
820            None => { /* Race hazard. Do nothing. */ },
821        }
822        Ok(())
823    }
824
825    async fn spawn_connection(
826        &mut self,
827        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
828        target: Rc<String>,
829        connection: TcpStream,
830    ) -> Result<(), ForwardingError> {
831        let connection_id = self.connection_counter;
832        self.connection_counter += 1;
833        let (mut connection_rd, connection_wr) = connection.split();
834        let mut backchannel_tx = self.backchannel_tx.clone();
835        tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
836
837        transit_tx
838            .send(
839                PeerMessage::Connect {
840                    target: (*target).clone(),
841                    connection_id,
842                }
843                .ser_msgpack()
844                .into_boxed_slice(),
845            )
846            .await?;
847
848        let worker = async_std::task::spawn_local(async move {
849            let mut buffer = vec![0; 4096];
850            /* Ignore errors */
851            macro_rules! break_on_err {
852                ($expr:expr) => {
853                    match $expr {
854                        Ok(val) => val,
855                        Err(_) => break,
856                    }
857                };
858            }
859            #[allow(clippy::while_let_loop)]
860            loop {
861                let read = break_on_err!(connection_rd.read(&mut buffer).await);
862                if read == 0 {
863                    break;
864                }
865                let buffer = &buffer[..read];
866                break_on_err!(
867                    backchannel_tx
868                        .send((connection_id, Some(buffer.to_vec())))
869                        .await
870                );
871            }
872            /* Close connection (maybe or not because of error) */
873            let _ = backchannel_tx.send((connection_id, None)).await;
874            backchannel_tx.disconnect();
875        });
876
877        self.connections
878            .insert(connection_id, (worker, connection_wr));
879        Ok(())
880    }
881
882    async fn shutdown(self) {
883        tracing::debug!("Shutting down everything");
884        for (worker, _connection) in self.connections.into_values() {
885            worker.cancel().await;
886        }
887    }
888
889    async fn run(
890        mut self,
891        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
892        transit_rx: &mut (impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>>
893                  + Unpin),
894        cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
895    ) -> Result<(), ForwardingError> {
896        /* Event processing loop */
897        tracing::debug!("Entered processing loop");
898        let ret = loop {
899            futures::select! {
900                message = transit_rx.next() => {
901                    match PeerMessage::de_msgpack(&message.unwrap()?)? {
902                        PeerMessage::Forward { connection_id, payload } => {
903                            self.forward(transit_tx, connection_id, &payload).await?;
904                        },
905                        PeerMessage::Disconnect { connection_id } => {
906                            self.remove_connection(transit_tx, connection_id, false).await?;
907                        },
908                        PeerMessage::Close => {
909                            tracing::info!("Peer gracefully closed connection");
910                            self.shutdown().await;
911                            break Ok(())
912                        },
913                        PeerMessage::Error(err) => {
914                            for (worker, _connection) in self.connections.into_values() {
915                                worker.cancel().await;
916                            }
917                            bail!(ForwardingError::PeerError(err));
918                        },
919                        other => {
920                            self.shutdown().await;
921                            bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
922                        },
923                    }
924                },
925                message = self.backchannel_rx.next() => {
926                    /* This channel will never run dry, since we always have at least one sender active */
927                    match message.unwrap() {
928                        (connection_id, Some(payload)) => {
929                            transit_tx.send(
930                                PeerMessage::Forward {
931                                    connection_id,
932                                    payload
933                                }.ser_msgpack()
934                                .into_boxed_slice()
935                            )
936                            .await?;
937                        },
938                        (connection_id, None) => {
939                            self.remove_connection(transit_tx, connection_id, true).await?;
940                        },
941                    }
942                },
943                connection = self.incoming.next() => {
944                    let (target, connection): (Rc<String>, TcpStream) = connection.unwrap()?;
945                    self.spawn_connection(transit_tx, target, connection).await?;
946                },
947                /* We are done */
948                () = &mut *cancel => {
949                    tracing::info!("Closing connection");
950                    transit_tx.send(
951                        PeerMessage::Close.ser_msgpack()
952                        .into_boxed_slice()
953                    )
954                    .await?;
955                    transit_tx.close().await?;
956                    self.shutdown().await;
957                    break Ok(());
958                },
959            }
960        };
961        tracing::debug!("Exited processing loop");
962        ret
963    }
964}
965
/** Messages exchanged between the two peers of this protocol.
 *
 * This is the serialization type: values are encoded with
 * `PeerMessage::ser_msgpack` and decoded with `PeerMessage::de_msgpack`.
 * The declaration is part of the wire format, so do not rename or
 * restructure variants and fields without considering interoperability.
 * NOTE(review): `rename_all` on an enum should only rename the variants
 * (`offer`, `connect`, …), leaving field names such as `connection_id`
 * untouched — confirm against the serde documentation.
 */
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
enum PeerMessage {
    /** Offer some destinations to be forwarded to.
     * forwarder -> forwardee only
     */
    Offer { addresses: Vec<String> },
    /** Forward a new connection to `target`, to be tracked under `connection_id`.
     * forwardee -> forwarder only
     */
    Connect { target: String, connection_id: u64 },
    /** End a forwarded connection.
     * Any direction. Errors or the reason why the connection is closed
     * are not forwarded.
     */
    Disconnect { connection_id: u64 },
    /** Forward some bytes for a connection. */
    Forward {
        connection_id: u64,
        payload: Vec<u8>,
    },
    /** Close the whole session */
    Close,
    /** Tell the other side you got an error */
    Error(String),
    /** Used to set up a transit channel */
    Transit { hints: transit::Hints },
    /** Fallback variant; presumably what unrecognized message kinds
     * deserialize to (via `#[serde(other)]`) — verify for this encoding.
     */
    #[serde(other)]
    Unknown,
}
998
999impl PeerMessage {
1000    #[allow(dead_code)]
1001    pub fn ser_msgpack(&self) -> Vec<u8> {
1002        let mut writer = Vec::with_capacity(128);
1003        let mut ser = rmp_serde::encode::Serializer::new(&mut writer)
1004            .with_struct_map()
1005            .with_human_readable();
1006        serde::Serialize::serialize(self, &mut ser).unwrap();
1007        writer
1008    }
1009
1010    #[allow(dead_code)]
1011    pub fn de_msgpack(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
1012        rmp_serde::from_read(&mut &*data)
1013    }
1014}