Skip to main content

magic_wormhole/
forwarding.rs

//! Client-to-client protocol to forward TCP connections
//!
//! This is a new (and still slightly experimental) feature that allows you to forward TCP connections over a wormhole
//! `transit` connection.
//!
//! It is bound to an [`APPID`], which is distinct from the one used for file transfer. Therefore, the codes used
//! for port forwarding live in a namespace independent of those for sending files.
//!
//! At its core, "peer messages" are exchanged over an established wormhole connection with the other side.
//! They are used to set up a [`transit`] portal that will be used instead of the wormhole connection, which will be closed.
//! Connections are tracked via an identifier and multiplexed over the transit channel. The forwarding is
//! "logical" rather than "raw", because "TCP in TCP" tunneling is known to be problematic. Data is sent
//! and received as it comes in; no additional buffering is applied. (This assumes that applications
//! which need buffering already do it on their side, and those which don't, don't.)
15
16use crate::transit::TransitRole;
17
18use super::*;
19use async_net::TcpListener;
20use futures::{AsyncReadExt, AsyncWriteExt, Future, SinkExt, StreamExt, TryStreamExt};
21use serde_derive::{Deserialize, Serialize};
22use std::{
23    borrow::Cow,
24    collections::{HashMap, HashSet},
25    rc::Rc,
26    sync::Arc,
27};
28use transit::{TransitConnectError, TransitError};
29
30const APPID_RAW: &str = "piegames.de/wormhole/port-forwarding";
31
32/// The App ID associated with this protocol.
33pub const APPID: AppID = AppID(Cow::Borrowed(APPID_RAW));
34
35/// An [`crate::AppConfig`] with sane defaults for this protocol.
36///
37/// You **must not** change `id` and `rendezvous_url` to be interoperable.
38/// The `app_version` can be adjusted if you want to disable some features.
39pub const APP_CONFIG: crate::AppConfig<AppVersion> = crate::AppConfig::<AppVersion> {
40    id: AppID(Cow::Borrowed(APPID_RAW)),
41    rendezvous_url: Cow::Borrowed(crate::rendezvous::DEFAULT_RENDEZVOUS_SERVER),
42    app_version: AppVersion {
43        transit_abilities: transit::Abilities::ALL,
44        other: serde_json::Value::Null,
45    },
46};
47
/// The application specific version information for this protocol.
///
/// Both sides exchange this during the wormhole handshake; the peer's copy is read
/// back with `serde_json`, so field names and layout are part of the wire format.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct AppVersion {
    /// Our transit abilities
    pub transit_abilities: transit::Abilities,
    /// Any further version fields, flattened by serde into the same JSON object
    /// (on deserialization this catches fields not known to this struct).
    #[serde(flatten)]
    other: serde_json::Value,
}
58
59#[derive(Debug, thiserror::Error)]
60#[non_exhaustive]
61/// An error occurred when establishing a port forwarding session
62pub enum ForwardingError {
63    /// Transfer was not acknowledged by peer
64    #[error("Transfer was not acknowledged by peer")]
65    AckError,
66    /// Something went wrong on the other side
67    #[error("Something went wrong on the other side: {}", _0)]
68    PeerError(String),
69    /// Some deserialization went wrong, we probably got some garbage
70    #[error("Corrupt JSON message received")]
71    ProtocolJson(
72        #[from]
73        #[source]
74        serde_json::Error,
75    ),
76    /// Some deserialization went wrong, we probably got some garbage
77    #[error("Corrupt Msgpack message received")]
78    ProtocolMsgpack(
79        #[from]
80        #[source]
81        rmp_serde::decode::Error,
82    ),
83    /// A generic string message for "something went wrong", i.e.
84    /// the server sent some bullshit message order
85    #[error("Protocol error: {}", _0)]
86    Protocol(Box<str>),
87    /// Unexpected message (protocol error)
88    #[error(
89        "Unexpected message (protocol error): Expected '{}', but got: {:?}",
90        _0,
91        _1
92    )]
93    ProtocolUnexpectedMessage(Box<str>, Box<dyn std::fmt::Debug + Send + Sync>),
94    /// Wormhole connection error
95    #[error("Wormhole connection error")]
96    Wormhole(
97        #[from]
98        #[source]
99        WormholeError,
100    ),
101    /// Error while establishing transit connection
102    #[error("Error while establishing transit connection")]
103    TransitConnect(
104        #[from]
105        #[source]
106        TransitConnectError,
107    ),
108    /// Transit error
109    #[error("Transit error")]
110    Transit(
111        #[from]
112        #[source]
113        TransitError,
114    ),
115    /// I/O error
116    #[error("I/O error")]
117    IO(
118        #[from]
119        #[source]
120        std::io::Error,
121    ),
122}
123
124impl ForwardingError {
125    fn protocol(message: impl Into<Box<str>>) -> Self {
126        Self::Protocol(message.into())
127    }
128
129    pub(self) fn unexpected_message(
130        expected: impl Into<Box<str>>,
131        got: impl std::fmt::Debug + Send + Sync + 'static,
132    ) -> Self {
133        Self::ProtocolUnexpectedMessage(expected.into(), Box::new(got))
134    }
135}
136
137/// Offer to forward some ports
138///
139/// `targets` is a mapping of (host, port) pairs. If no target host is provided, then
140/// a local port will be forwarded (`localhost`). Forwarding remote ports only works well
141/// when the protocol being forwarded is not host-aware. HTTP, for example, is host aware.
142///
143/// The port forwarding will run until an error occurs, the peer terminates the connection
144/// or `cancel` resolves. The last one can be used to provide timeouts or to inject CTRL-C
145/// handling. If you want the forward to never (successfully) stop, pass [`futures::future::pending()`]
146/// as the value.
147pub async fn serve(
148    mut wormhole: Wormhole,
149    transit_handler: impl FnOnce(transit::TransitInfo),
150    relay_hints: Vec<transit::RelayHint>,
151    targets: Vec<(Option<url::Host>, u16)>,
152    cancel: impl Future<Output = ()>,
153) -> Result<(), ForwardingError> {
154    assert!(
155        !targets.is_empty(),
156        "The list of target ports must not be empty"
157    );
158
159    let our_version: &AppVersion = wormhole
160        .our_version()
161        .downcast_ref()
162        .expect("You may only use a Wormhole instance with the correct AppVersion type!");
163    let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
164    let connector = transit::init(
165        our_version.transit_abilities,
166        Some(peer_version.transit_abilities),
167        relay_hints,
168    )
169    .await?;
170
171    /* Send our transit hints */
172    wormhole
173        .send_json(&PeerMessage::Transit {
174            hints: (**connector.our_hints()).clone(),
175        })
176        .await?;
177
178    let targets: HashMap<String, (Option<url::Host>, u16)> = targets
179        .into_iter()
180        .map(|(host, port)| match host {
181            Some(host) => {
182                if port == 80 || port == 443 || port == 8000 || port == 8080 {
183                    tracing::warn!("It seems like you are trying to forward a remote HTTP target ('{}'). Due to HTTP being host-aware this will very likely fail!", host);
184                }
185                (format!("{host}:{port}"), (Some(host), port))
186            },
187            None => (port.to_string(), (host, port)),
188        })
189        .collect();
190
191    /* Receive their transit hints */
192    let their_hints: transit::Hints = match wormhole.receive_json().await?? {
193        PeerMessage::Transit { hints } => {
194            tracing::debug!("Received transit message: {:?}", hints);
195            hints
196        },
197        PeerMessage::Error(err) => {
198            bail!(ForwardingError::PeerError(err));
199        },
200        other => {
201            let error = ForwardingError::unexpected_message("transit", other);
202            let _ = wormhole
203                .send_json(&PeerMessage::Error(format!("{error}")))
204                .await;
205            bail!(error)
206        },
207    };
208
209    let (mut transit, info) = match connector
210        .connect(
211            TransitRole::Leader,
212            wormhole.key().derive_transit_key(wormhole.appid()),
213            peer_version.transit_abilities,
214            Arc::new(their_hints),
215        )
216        .await
217    {
218        Ok(transit) => transit,
219        Err(error) => {
220            let error = ForwardingError::TransitConnect(error);
221            let _ = wormhole
222                .send_json(&PeerMessage::Error(format!("{error}")))
223                .await;
224            return Err(error);
225        },
226    };
227    transit_handler(info);
228
229    /* We got a transit, now close the Wormhole */
230    wormhole.close().await?;
231
232    transit
233        .send_record(
234            &PeerMessage::Offer {
235                addresses: targets.keys().cloned().collect(),
236            }
237            .ser_msgpack(),
238        )
239        .await?;
240
241    let (backchannel_tx, backchannel_rx) =
242        futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
243
244    let (transit_tx, transit_rx) = transit.split();
245    let transit_rx = transit_rx.fuse();
246    use futures::future::FutureExt;
247    let cancel = cancel.fuse();
248    futures::pin_mut!(transit_tx);
249    futures::pin_mut!(transit_rx);
250    futures::pin_mut!(cancel);
251
252    /* Main processing loop. Catch errors */
253    let result = ForwardingServe {
254        targets,
255        connections: HashMap::new(),
256        historic_connections: HashSet::new(),
257        backchannel_tx,
258        backchannel_rx,
259    }
260    .run(&mut transit_tx, &mut transit_rx, &mut cancel)
261    .await;
262    /* If the error is not a PeerError (i.e. coming from the other side), try notifying the other side before quitting. */
263    match result {
264        Ok(()) => Ok(()),
265        Err(error @ ForwardingError::PeerError(_)) => Err(error),
266        Err(error) => {
267            let _ = transit_tx
268                .send(
269                    PeerMessage::Error(format!("{error}"))
270                        .ser_msgpack()
271                        .into_boxed_slice(),
272                )
273                .await;
274            Err(error)
275        },
276    }
277}
278
/// State of the serving side of a forwarding session: the targets we offered and the
/// local TCP connections opened on the peer's request.
struct ForwardingServe {
    /// Offered targets, keyed by the name under which they were announced to the peer
    targets: HashMap<String, (Option<url::Host>, u16)>,
    /// Live connections by ID: the worker task that reads from the local socket, and the
    /// socket's write half into which data received from the peer is written
    connections: HashMap<
        u64,
        (
            async_task::Task<()>,
            futures_lite::io::WriteHalf<async_net::TcpStream>,
        ),
    >,
    /* Track old connection IDs that won't be reused again. This is to distinguish race hazards where
     * one side closes a connection while the other one accesses it simultaneously. Despite the name, the
     * set also includes connections that are currently live.
     */
    historic_connections: HashSet<u64>,
    /// Channel from the connection workers to the main loop:
    /// `(connection_id, Some(payload))` for data read locally, `(connection_id, None)` on close
    backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
    /// Receiving end of the backchannel, polled by the main loop
    backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
}
298
299//futures::pin_mut!(backchannel_rx);
300impl ForwardingServe {
301    async fn forward(
302        &mut self,
303        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
304        connection_id: u64,
305        payload: &[u8],
306    ) -> Result<(), ForwardingError> {
307        tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
308        match self.connections.get_mut(&connection_id) {
309            Some((_worker, connection)) => {
310                /* On an error, log for the user and then terminate that connection */
311                if let Err(e) = connection.write_all(payload).await {
312                    tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
313                    self.remove_connection(transit_tx, connection_id, true)
314                        .await?;
315                }
316            },
317            None if !self.historic_connections.contains(&connection_id) => {
318                bail!(ForwardingError::protocol(format!(
319                    "Connection '{connection_id}' not found"
320                )));
321            },
322            None => { /* Race hazard. Do nothing. */ },
323        }
324        Ok(())
325    }
326
327    async fn remove_connection(
328        &mut self,
329        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
330        connection_id: u64,
331        tell_peer: bool,
332    ) -> Result<(), ForwardingError> {
333        tracing::debug!("Removing connection: #{}", connection_id);
334        if tell_peer {
335            transit_tx
336                .send(
337                    PeerMessage::Disconnect { connection_id }
338                        .ser_msgpack()
339                        .into_boxed_slice(),
340                )
341                .await?;
342        }
343        match self.connections.remove(&connection_id) {
344            Some((worker, _connection)) => {
345                worker.cancel().await;
346            },
347            None if !self.historic_connections.contains(&connection_id) => {
348                bail!(ForwardingError::protocol(format!(
349                    "Connection '{connection_id}' not found"
350                )));
351            },
352            None => { /* Race hazard. Do nothing. */ },
353        }
354        Ok(())
355    }
356
357    async fn spawn_connection(
358        &mut self,
359        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
360        mut target: String,
361        connection_id: u64,
362    ) -> Result<(), ForwardingError> {
363        tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
364
365        use std::collections::hash_map::Entry;
366        let entry = match self.connections.entry(connection_id) {
367            Entry::Vacant(entry) => entry,
368            Entry::Occupied(_) => {
369                bail!(ForwardingError::protocol(format!(
370                    "Connection '{connection_id}' already exists"
371                )));
372            },
373        };
374
375        let (host, port) = self.targets.get(&target).unwrap();
376        if host.is_none() {
377            target = format!("[::1]:{port}");
378        }
379        let stream = match async_net::TcpStream::connect(&target).await {
380            Ok(stream) => stream,
381            Err(err) => {
382                tracing::warn!(
383                    "Cannot open connection to {}: {}. The forwarded service might be down.",
384                    target,
385                    err
386                );
387                transit_tx
388                    .send(
389                        PeerMessage::Disconnect { connection_id }
390                            .ser_msgpack()
391                            .into_boxed_slice(),
392                    )
393                    .await?;
394                return Ok(());
395            },
396        };
397        let (mut connection_rd, connection_wr) = futures_lite::io::split(stream);
398        let mut backchannel_tx = self.backchannel_tx.clone();
399        let worker = crate::util::spawn(async move {
400            let mut buffer = vec![0; 4096];
401            /* Ignore errors */
402            macro_rules! break_on_err {
403                ($expr:expr_2021) => {
404                    match $expr {
405                        Ok(val) => val,
406                        Err(_) => break,
407                    }
408                };
409            }
410            #[expect(clippy::while_let_loop)]
411            loop {
412                let read = break_on_err!(connection_rd.read(&mut buffer).await);
413                if read == 0 {
414                    break;
415                }
416                let buffer = &buffer[..read];
417                break_on_err!(
418                    backchannel_tx
419                        .send((connection_id, Some(buffer.to_vec())))
420                        .await
421                );
422            }
423            /* Close connection (maybe or not because of error) */
424            let _ = backchannel_tx.send((connection_id, None)).await;
425            backchannel_tx.disconnect();
426        });
427        entry.insert((worker, connection_wr));
428        Ok(())
429    }
430
431    async fn shutdown(self) {
432        tracing::debug!("Shutting down everything");
433        for (worker, _connection) in self.connections.into_values() {
434            worker.cancel().await;
435        }
436    }
437
438    async fn run(
439        mut self,
440        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
441        transit_rx: &mut (
442                 impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>> + Unpin
443             ),
444        cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
445    ) -> Result<(), ForwardingError> {
446        /* Event processing loop */
447        tracing::debug!("Entered processing loop");
448        let ret = loop {
449            futures::select! {
450                message = transit_rx.next() => {
451                    match PeerMessage::de_msgpack(&message.unwrap()?)? {
452                        PeerMessage::Forward { connection_id, payload } => {
453                            self.forward(transit_tx, connection_id, &payload).await?
454                        },
455                        PeerMessage::Connect { target, connection_id } => {
456                            /* No matter what happens, as soon as we receive the "connect" command that ID is burned. */
457                            self.historic_connections.insert(connection_id);
458                            ensure!(
459                                self.targets.contains_key(&target),
460                                ForwardingError::protocol(format!("We don't know forwarding target '{target}'")),
461                            );
462
463                            self.spawn_connection(transit_tx, target, connection_id).await?;
464                        },
465                        PeerMessage::Disconnect { connection_id } => {
466                            self.remove_connection(transit_tx, connection_id, false).await?;
467                        },
468                        PeerMessage::Close => {
469                            tracing::info!("Peer gracefully closed connection");
470                            self.shutdown().await;
471                            break Ok(());
472                        },
473                        PeerMessage::Error(err) => {
474                            self.shutdown().await;
475                            bail!(ForwardingError::PeerError(err));
476                        },
477                        other => {
478                            self.shutdown().await;
479                            bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
480                        },
481                    }
482                },
483                message = self.backchannel_rx.next() => {
484                    /* This channel will never run dry, since we always have at least one sender active */
485                    match message.unwrap() {
486                        (connection_id, Some(payload)) => {
487                            transit_tx.send(
488                                PeerMessage::Forward {
489                                    connection_id,
490                                    payload
491                                }
492                                .ser_msgpack()
493                                .into_boxed_slice()
494                            ).await?;
495                        },
496                        (connection_id, None) => {
497                            self.remove_connection(transit_tx, connection_id, true).await?;
498                        },
499                    }
500                },
501                /* We are done */
502                () = &mut *cancel => {
503                    tracing::info!("Closing connection");
504                    transit_tx.send(
505                        PeerMessage::Close.ser_msgpack()
506                        .into_boxed_slice()
507                    )
508                    .await?;
509                    transit_tx.close().await?;
510                    self.shutdown().await;
511                    break Ok(());
512                },
513            }
514        };
515        tracing::debug!("Exited processing loop");
516        ret
517    }
518}
519
520/// Request a port forwarding offer from the other side
521///
522/// You can optionally specify a `bind_address` where the port forwarding
523/// will be made available. You can also specify a list of `custom_ports` that
524/// will be used for the forwarding. The mapping between custom ports and forwarded
525/// targets is 1:1 and order preserving. If more ports are forwarded than custom
526/// ports were specified, then the remaining ports will be arbitrary.
527///
528/// The method returns a [`ConnectOffer`] from which the resulting port mapping can
529/// be queried. That struct also has an `accept` and `reject` method, of which one
530/// must be used.
531///
532/// This method already binds to all the necessary ports up-front. To limit abuse potential
533/// no more than 1024 ports may be forwarded at once.
534pub async fn connect(
535    mut wormhole: Wormhole,
536    transit_handler: impl FnOnce(transit::TransitInfo),
537    relay_hints: Vec<transit::RelayHint>,
538    bind_address: Option<std::net::IpAddr>,
539    custom_ports: &[u16],
540) -> Result<ConnectOffer, ForwardingError> {
541    let our_version: &AppVersion = wormhole
542        .our_version()
543        .downcast_ref()
544        .expect("You may only use a Wormhole instance with the correct AppVersion type!");
545    let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
546    let connector = transit::init(
547        our_version.transit_abilities,
548        Some(peer_version.transit_abilities),
549        relay_hints,
550    )
551    .await?;
552    let bind_address = bind_address.unwrap_or_else(|| std::net::IpAddr::V6("::".parse().unwrap()));
553
554    /* Send our transit hints */
555    wormhole
556        .send_json(&PeerMessage::Transit {
557            hints: (**connector.our_hints()).clone(),
558        })
559        .await?;
560
561    /* Receive their transit hints */
562    let their_hints: transit::Hints = match wormhole.receive_json().await?? {
563        PeerMessage::Transit { hints } => {
564            tracing::debug!("Received transit message: {:?}", hints);
565            hints
566        },
567        PeerMessage::Error(err) => {
568            bail!(ForwardingError::PeerError(err));
569        },
570        other => {
571            let error = ForwardingError::unexpected_message("transit", other);
572            let _ = wormhole
573                .send_json(&PeerMessage::Error(format!("{error}")))
574                .await;
575            bail!(error)
576        },
577    };
578
579    let (mut transit, info) = match connector
580        .connect(
581            TransitRole::Follower,
582            wormhole.key().derive_transit_key(wormhole.appid()),
583            peer_version.transit_abilities,
584            Arc::new(their_hints),
585        )
586        .await
587    {
588        Ok(transit) => transit,
589        Err(error) => {
590            let error = ForwardingError::TransitConnect(error);
591            let _ = wormhole
592                .send_json(&PeerMessage::Error(format!("{error}")))
593                .await;
594            return Err(error);
595        },
596    };
597    transit_handler(info);
598
599    /* We got a transit, now close the Wormhole */
600    wormhole.close().await?;
601
602    let run = async {
603        /* Receive offer and ask user */
604
605        let addresses = match PeerMessage::de_msgpack(&transit.receive_record().await?)? {
606            PeerMessage::Offer { addresses } => addresses,
607            PeerMessage::Error(err) => {
608                bail!(ForwardingError::PeerError(err));
609            },
610            other => {
611                bail!(ForwardingError::unexpected_message("offer", other))
612            },
613        };
614
615        /* Sanity check on untrusted input */
616        if addresses.len() > 1024 {
617            return Err(ForwardingError::protocol("Too many forwarded ports"));
618        }
619
620        /* self => remote
621         *                  (address, connection)
622         * Vec<Stream<Item = (String, TcpStream)>>
623         */
624        let listeners: Vec<(
625            async_net::TcpListener,
626            u16,
627            std::rc::Rc<std::string::String>,
628        )> = futures::stream::iter(
629            addresses
630                .into_iter()
631                .map(Rc::new)
632                .zip(custom_ports.iter().copied().chain(std::iter::repeat(0))),
633        )
634        .then(|(address, port)| async move {
635            let connection = TcpListener::bind((bind_address, port)).await?;
636            let port = connection.local_addr()?.port();
637            Result::<_, std::io::Error>::Ok((connection, port, address))
638        })
639        .try_collect()
640        .await?;
641        Ok(listeners)
642    };
643
644    match run.await {
645        Ok(listeners) => Ok(ConnectOffer {
646            transit,
647            mapping: listeners.iter().map(|(_, b, c)| (*b, c.clone())).collect(),
648            listeners,
649        }),
650        Err(error @ ForwardingError::PeerError(_)) => Err(error),
651        Err(error) => {
652            let _ = transit
653                .send_record(&PeerMessage::Error(format!("{error}")).ser_msgpack())
654                .await;
655            Err(error)
656        },
657    }
658}
659
/// A pending forwarding offer from the other side
///
/// You *should* consume this object, either by calling [`accept`](ConnectOffer::accept) or [`reject`](ConnectOffer::reject).
/// The local listeners for all offered targets are already bound when this value is created.
#[must_use]
pub struct ConnectOffer {
    /// The offered port mapping: the local port we listen on, and the peer's target it forwards to
    pub mapping: Vec<(u16, Rc<String>)>,
    /// The established transit connection to the peer
    transit: transit::Transit,
    /// The bound listeners, each with its actual local port and the target it forwards to
    listeners: Vec<(
        async_net::TcpListener,
        u16,
        std::rc::Rc<std::string::String>,
    )>,
}
674
675impl ConnectOffer {
676    /// Accept the offer and start the forwarding
677    ///
678    /// The method will run until an error occurs, the peer terminates the connection
679    /// or `cancel` resolves. The last one can be used to provide timeouts or to inject CTRL-C
680    /// handling. If you want the forward to never (successfully) stop, pass [`futures::future::pending()`]
681    /// as the value.
682    pub async fn accept(self, cancel: impl Future<Output = ()>) -> Result<(), ForwardingError> {
683        let (transit_tx, transit_rx) = self.transit.split();
684        let transit_rx = transit_rx.fuse();
685        use futures::FutureExt;
686        let cancel = cancel.fuse();
687        futures::pin_mut!(transit_tx);
688        futures::pin_mut!(transit_rx);
689        futures::pin_mut!(cancel);
690
691        /* Error handling catcher (see below) */
692        let run = async {
693            let (backchannel_tx, backchannel_rx) =
694                futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
695
696            let incoming_listeners = self.listeners.into_iter().map(|(connection, _, address)| {
697                Box::pin(
698                    futures_lite::stream::unfold(connection, |listener| async move {
699                        let res = listener.accept().await.map(|(stream, _)| stream);
700                        Some((res, listener))
701                    })
702                    .map_ok(move |stream| (address.clone(), stream)),
703                )
704            });
705
706            ForwardConnect {
707                incoming: futures::stream::select_all(incoming_listeners),
708                connection_counter: 0,
709                connections: HashMap::new(),
710                backchannel_tx,
711                backchannel_rx,
712            }
713            .run(&mut transit_tx, &mut transit_rx, &mut cancel)
714            .await
715        };
716
717        match run.await {
718            Ok(()) => Ok(()),
719            Err(error @ ForwardingError::PeerError(_)) => Err(error),
720            Err(error) => {
721                let _ = transit_tx
722                    .send(
723                        PeerMessage::Error(format!("{error}"))
724                            .ser_msgpack()
725                            .into_boxed_slice(),
726                    )
727                    .await;
728                Err(error)
729            },
730        }
731    }
732
733    /// Reject the offer
734    ///
735    /// This will send an error message to the other side so that it knows the transfer failed.
736    pub async fn reject(mut self) -> Result<(), ForwardingError> {
737        self.transit
738            .send_record(&PeerMessage::Error("transfer rejected".into()).ser_msgpack())
739            .await?;
740
741        Ok(())
742    }
743}
744
/// State of the connecting side of a forwarding session: accepts local connections on
/// the bound listeners and announces each of them to the peer.
struct ForwardConnect<I> {
    /// Stream of accepted local sockets, each tagged with the target it forwards to.
    /// It is a generic parameter because the concrete stream type cannot be named here.
    incoming: I,
    /// Our next unique connection_id
    connection_counter: u64,
    /// Live connections by ID: the worker task plus the write half of the local socket
    connections: HashMap<
        u64,
        (
            async_task::Task<()>,
            futures_lite::io::WriteHalf<async_net::TcpStream>,
        ),
    >,
    /// Channel from the connection workers to the main loop:
    /// `(connection_id, Some(payload))` for data read locally, `(connection_id, None)` on close
    backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
    /// Receiving end of the backchannel
    backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
}
762
763impl<I> ForwardConnect<I>
764where
765    I: Unpin
766        + futures::stream::FusedStream<
767            Item = Result<(Rc<String>, async_net::TcpStream), std::io::Error>,
768        >,
769{
770    async fn forward(
771        &mut self,
772        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
773        connection_id: u64,
774        payload: &[u8],
775    ) -> Result<(), ForwardingError> {
776        tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
777        match self.connections.get_mut(&connection_id) {
778            Some((_worker, connection)) => {
779                /* On an error, log for the user and then terminate that connection */
780                if let Err(e) = connection.write_all(payload).await {
781                    tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
782                    self.remove_connection(transit_tx, connection_id, true)
783                        .await?;
784                }
785            },
786            None if self.connection_counter <= connection_id => {
787                bail!(ForwardingError::protocol(format!(
788                    "Connection '{connection_id}' not found"
789                )));
790            },
791            None => { /* Race hazard. Do nothing. */ },
792        }
793        Ok(())
794    }
795
796    async fn remove_connection(
797        &mut self,
798        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
799        connection_id: u64,
800        tell_peer: bool,
801    ) -> Result<(), ForwardingError> {
802        tracing::debug!("Removing connection: #{}", connection_id);
803        if tell_peer {
804            transit_tx
805                .send(
806                    PeerMessage::Disconnect { connection_id }
807                        .ser_msgpack()
808                        .into_boxed_slice(),
809                )
810                .await?;
811        }
812        match self.connections.remove(&connection_id) {
813            Some((worker, _connection)) => {
814                worker.cancel().await;
815            },
816            None if connection_id >= self.connection_counter => {
817                bail!(ForwardingError::protocol(format!(
818                    "Connection '{connection_id}' not found"
819                )));
820            },
821            None => { /* Race hazard. Do nothing. */ },
822        }
823        Ok(())
824    }
825
826    async fn spawn_connection(
827        &mut self,
828        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
829        target: Rc<String>,
830        connection: async_net::TcpStream,
831    ) -> Result<(), ForwardingError> {
832        let connection_id = self.connection_counter;
833        self.connection_counter += 1;
834        let (mut connection_rd, connection_wr) = futures_lite::io::split(connection);
835        let mut backchannel_tx = self.backchannel_tx.clone();
836        tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
837
838        transit_tx
839            .send(
840                PeerMessage::Connect {
841                    target: (*target).clone(),
842                    connection_id,
843                }
844                .ser_msgpack()
845                .into_boxed_slice(),
846            )
847            .await?;
848
849        let worker = crate::util::spawn(async move {
850            let mut buffer = vec![0; 4096];
851            /* Ignore errors */
852            macro_rules! break_on_err {
853                ($expr:expr_2021) => {
854                    match $expr {
855                        Ok(val) => val,
856                        Err(_) => break,
857                    }
858                };
859            }
860            #[expect(clippy::while_let_loop)]
861            loop {
862                let read = break_on_err!(connection_rd.read(&mut buffer).await);
863                if read == 0 {
864                    break;
865                }
866                let buffer = &buffer[..read];
867                break_on_err!(
868                    backchannel_tx
869                        .send((connection_id, Some(buffer.to_vec())))
870                        .await
871                );
872            }
873            /* Close connection (maybe or not because of error) */
874            let _ = backchannel_tx.send((connection_id, None)).await;
875            backchannel_tx.disconnect();
876        });
877
878        self.connections
879            .insert(connection_id, (worker, connection_wr));
880        Ok(())
881    }
882
883    async fn shutdown(self) {
884        tracing::debug!("Shutting down everything");
885        for (worker, _connection) in self.connections.into_values() {
886            worker.cancel().await;
887        }
888    }
889
890    async fn run(
891        mut self,
892        transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
893        transit_rx: &mut (
894                 impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>> + Unpin
895             ),
896        cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
897    ) -> Result<(), ForwardingError> {
898        /* Event processing loop */
899        tracing::debug!("Entered processing loop");
900        let ret = loop {
901            futures::select! {
902                message = transit_rx.next() => {
903                    match PeerMessage::de_msgpack(&message.unwrap()?)? {
904                        PeerMessage::Forward { connection_id, payload } => {
905                            self.forward(transit_tx, connection_id, &payload).await?;
906                        },
907                        PeerMessage::Disconnect { connection_id } => {
908                            self.remove_connection(transit_tx, connection_id, false).await?;
909                        },
910                        PeerMessage::Close => {
911                            tracing::info!("Peer gracefully closed connection");
912                            self.shutdown().await;
913                            break Ok(())
914                        },
915                        PeerMessage::Error(err) => {
916                            for (worker, _connection) in self.connections.into_values() {
917                                worker.cancel().await;
918                            }
919                            bail!(ForwardingError::PeerError(err));
920                        },
921                        other => {
922                            self.shutdown().await;
923                            bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
924                        },
925                    }
926                },
927                message = self.backchannel_rx.next() => {
928                    /* This channel will never run dry, since we always have at least one sender active */
929                    match message.unwrap() {
930                        (connection_id, Some(payload)) => {
931                            transit_tx.send(
932                                PeerMessage::Forward {
933                                    connection_id,
934                                    payload
935                                }.ser_msgpack()
936                                .into_boxed_slice()
937                            )
938                            .await?;
939                        },
940                        (connection_id, None) => {
941                            self.remove_connection(transit_tx, connection_id, true).await?;
942                        },
943                    }
944                },
945                connection = self.incoming.next() => {
946                    let (target, connection): (Rc<String>, async_net::TcpStream) = connection.unwrap()?;
947                    self.spawn_connection(transit_tx, target, connection).await?;
948                },
949                /* We are done */
950                () = &mut *cancel => {
951                    tracing::info!("Closing connection");
952                    transit_tx.send(
953                        PeerMessage::Close.ser_msgpack()
954                        .into_boxed_slice()
955                    )
956                    .await?;
957                    transit_tx.close().await?;
958                    self.shutdown().await;
959                    break Ok(());
960                },
961            }
962        };
963        tracing::debug!("Exited processing loop");
964        ret
965    }
966}
967
/** Wire messages exchanged between the two peers of a port-forwarding session.
 *
 * This is an enum (one variant per message kind), encoded as MessagePack by
 * `PeerMessage::ser_msgpack` and decoded by `PeerMessage::de_msgpack`.
 * Variant names are serialized in kebab-case (e.g. `Forward` as `"forward"`).
 * Do not reorder or rename variants without checking wire compatibility.
 */
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "kebab-case")]
// NOTE(review): `#[non_exhaustive]` only has an effect outside the defining crate;
// this enum is private, so the attribute is currently inert.
#[non_exhaustive]
enum PeerMessage {
    /** Offer some destinations to be forwarded to.
     * forwarder -> forwardee only
     */
    Offer { addresses: Vec<String> },
    /** Forward a new connection to `target`.
     * `connection_id` is allocated by the sender and identifies the connection
     * in subsequent `Forward` / `Disconnect` messages.
     * forwardee -> forwarder only
     */
    Connect { target: String, connection_id: u64 },
    /** End a forwarded connection.
     * Any direction. Errors or the reason why the connection is closed
     * are not forwarded.
     */
    Disconnect { connection_id: u64 },
    /** Carry some bytes for an existing connection. Any direction. */
    Forward {
        connection_id: u64,
        payload: Vec<u8>,
    },
    /** Gracefully close the whole session */
    Close,
    /** Tell the other side you got an error; the text is carried verbatim */
    Error(String),
    /** Used to set up a transit channel; carries transit hints */
    Transit { hints: transit::Hints },
    /** Fallback for message tags not listed above (via `#[serde(other)]`) */
    #[serde(other)]
    Unknown,
}
1000
1001impl PeerMessage {
1002    pub fn ser_msgpack(&self) -> Vec<u8> {
1003        let mut writer = Vec::with_capacity(128);
1004        let mut ser = rmp_serde::encode::Serializer::new(&mut writer)
1005            .with_struct_map()
1006            .with_human_readable();
1007        serde::Serialize::serialize(self, &mut ser).unwrap();
1008        writer
1009    }
1010
1011    pub fn de_msgpack(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
1012        rmp_serde::from_read(&mut &*data)
1013    }
1014}