1use crate::transit::TransitRole;
17
18use super::*;
19use async_net::TcpListener;
20use futures::{AsyncReadExt, AsyncWriteExt, Future, SinkExt, StreamExt, TryStreamExt};
21use serde_derive::{Deserialize, Serialize};
22use std::{
23 borrow::Cow,
24 collections::{HashMap, HashSet},
25 rc::Rc,
26 sync::Arc,
27};
28use transit::{TransitConnectError, TransitError};
29
30const APPID_RAW: &str = "piegames.de/wormhole/port-forwarding";
31
32pub const APPID: AppID = AppID(Cow::Borrowed(APPID_RAW));
34
35pub const APP_CONFIG: crate::AppConfig<AppVersion> = crate::AppConfig::<AppVersion> {
40 id: AppID(Cow::Borrowed(APPID_RAW)),
41 rendezvous_url: Cow::Borrowed(crate::rendezvous::DEFAULT_RENDEZVOUS_SERVER),
42 app_version: AppVersion {
43 transit_abilities: transit::Abilities::ALL,
44 other: serde_json::Value::Null,
45 },
46};
47
48#[derive(Clone, Debug, Default, Serialize, Deserialize)]
52pub struct AppVersion {
53 pub transit_abilities: transit::Abilities,
55 #[serde(flatten)]
56 other: serde_json::Value,
57}
58
59#[derive(Debug, thiserror::Error)]
60#[non_exhaustive]
61pub enum ForwardingError {
63 #[error("Transfer was not acknowledged by peer")]
65 AckError,
66 #[error("Something went wrong on the other side: {}", _0)]
68 PeerError(String),
69 #[error("Corrupt JSON message received")]
71 ProtocolJson(
72 #[from]
73 #[source]
74 serde_json::Error,
75 ),
76 #[error("Corrupt Msgpack message received")]
78 ProtocolMsgpack(
79 #[from]
80 #[source]
81 rmp_serde::decode::Error,
82 ),
83 #[error("Protocol error: {}", _0)]
86 Protocol(Box<str>),
87 #[error(
89 "Unexpected message (protocol error): Expected '{}', but got: {:?}",
90 _0,
91 _1
92 )]
93 ProtocolUnexpectedMessage(Box<str>, Box<dyn std::fmt::Debug + Send + Sync>),
94 #[error("Wormhole connection error")]
96 Wormhole(
97 #[from]
98 #[source]
99 WormholeError,
100 ),
101 #[error("Error while establishing transit connection")]
103 TransitConnect(
104 #[from]
105 #[source]
106 TransitConnectError,
107 ),
108 #[error("Transit error")]
110 Transit(
111 #[from]
112 #[source]
113 TransitError,
114 ),
115 #[error("I/O error")]
117 IO(
118 #[from]
119 #[source]
120 std::io::Error,
121 ),
122}
123
124impl ForwardingError {
125 fn protocol(message: impl Into<Box<str>>) -> Self {
126 Self::Protocol(message.into())
127 }
128
129 pub(self) fn unexpected_message(
130 expected: impl Into<Box<str>>,
131 got: impl std::fmt::Debug + Send + Sync + 'static,
132 ) -> Self {
133 Self::ProtocolUnexpectedMessage(expected.into(), Box::new(got))
134 }
135}
136
137pub async fn serve(
148 mut wormhole: Wormhole,
149 transit_handler: impl FnOnce(transit::TransitInfo),
150 relay_hints: Vec<transit::RelayHint>,
151 targets: Vec<(Option<url::Host>, u16)>,
152 cancel: impl Future<Output = ()>,
153) -> Result<(), ForwardingError> {
154 assert!(
155 !targets.is_empty(),
156 "The list of target ports must not be empty"
157 );
158
159 let our_version: &AppVersion = wormhole
160 .our_version()
161 .downcast_ref()
162 .expect("You may only use a Wormhole instance with the correct AppVersion type!");
163 let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
164 let connector = transit::init(
165 our_version.transit_abilities,
166 Some(peer_version.transit_abilities),
167 relay_hints,
168 )
169 .await?;
170
171 wormhole
173 .send_json(&PeerMessage::Transit {
174 hints: (**connector.our_hints()).clone(),
175 })
176 .await?;
177
178 let targets: HashMap<String, (Option<url::Host>, u16)> = targets
179 .into_iter()
180 .map(|(host, port)| match host {
181 Some(host) => {
182 if port == 80 || port == 443 || port == 8000 || port == 8080 {
183 tracing::warn!("It seems like you are trying to forward a remote HTTP target ('{}'). Due to HTTP being host-aware this will very likely fail!", host);
184 }
185 (format!("{host}:{port}"), (Some(host), port))
186 },
187 None => (port.to_string(), (host, port)),
188 })
189 .collect();
190
191 let their_hints: transit::Hints = match wormhole.receive_json().await?? {
193 PeerMessage::Transit { hints } => {
194 tracing::debug!("Received transit message: {:?}", hints);
195 hints
196 },
197 PeerMessage::Error(err) => {
198 bail!(ForwardingError::PeerError(err));
199 },
200 other => {
201 let error = ForwardingError::unexpected_message("transit", other);
202 let _ = wormhole
203 .send_json(&PeerMessage::Error(format!("{error}")))
204 .await;
205 bail!(error)
206 },
207 };
208
209 let (mut transit, info) = match connector
210 .connect(
211 TransitRole::Leader,
212 wormhole.key().derive_transit_key(wormhole.appid()),
213 peer_version.transit_abilities,
214 Arc::new(their_hints),
215 )
216 .await
217 {
218 Ok(transit) => transit,
219 Err(error) => {
220 let error = ForwardingError::TransitConnect(error);
221 let _ = wormhole
222 .send_json(&PeerMessage::Error(format!("{error}")))
223 .await;
224 return Err(error);
225 },
226 };
227 transit_handler(info);
228
229 wormhole.close().await?;
231
232 transit
233 .send_record(
234 &PeerMessage::Offer {
235 addresses: targets.keys().cloned().collect(),
236 }
237 .ser_msgpack(),
238 )
239 .await?;
240
241 let (backchannel_tx, backchannel_rx) =
242 futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
243
244 let (transit_tx, transit_rx) = transit.split();
245 let transit_rx = transit_rx.fuse();
246 use futures::future::FutureExt;
247 let cancel = cancel.fuse();
248 futures::pin_mut!(transit_tx);
249 futures::pin_mut!(transit_rx);
250 futures::pin_mut!(cancel);
251
252 let result = ForwardingServe {
254 targets,
255 connections: HashMap::new(),
256 historic_connections: HashSet::new(),
257 backchannel_tx,
258 backchannel_rx,
259 }
260 .run(&mut transit_tx, &mut transit_rx, &mut cancel)
261 .await;
262 match result {
264 Ok(()) => Ok(()),
265 Err(error @ ForwardingError::PeerError(_)) => Err(error),
266 Err(error) => {
267 let _ = transit_tx
268 .send(
269 PeerMessage::Error(format!("{error}"))
270 .ser_msgpack()
271 .into_boxed_slice(),
272 )
273 .await;
274 Err(error)
275 },
276 }
277}
278
279struct ForwardingServe {
280 targets: HashMap<String, (Option<url::Host>, u16)>,
281 connections: HashMap<
283 u64,
284 (
285 async_task::Task<()>,
286 futures_lite::io::WriteHalf<async_net::TcpStream>,
287 ),
288 >,
289 historic_connections: HashSet<u64>,
294 backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
296 backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
297}
298
299impl ForwardingServe {
301 async fn forward(
302 &mut self,
303 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
304 connection_id: u64,
305 payload: &[u8],
306 ) -> Result<(), ForwardingError> {
307 tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
308 match self.connections.get_mut(&connection_id) {
309 Some((_worker, connection)) => {
310 if let Err(e) = connection.write_all(payload).await {
312 tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
313 self.remove_connection(transit_tx, connection_id, true)
314 .await?;
315 }
316 },
317 None if !self.historic_connections.contains(&connection_id) => {
318 bail!(ForwardingError::protocol(format!(
319 "Connection '{connection_id}' not found"
320 )));
321 },
322 None => { },
323 }
324 Ok(())
325 }
326
327 async fn remove_connection(
328 &mut self,
329 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
330 connection_id: u64,
331 tell_peer: bool,
332 ) -> Result<(), ForwardingError> {
333 tracing::debug!("Removing connection: #{}", connection_id);
334 if tell_peer {
335 transit_tx
336 .send(
337 PeerMessage::Disconnect { connection_id }
338 .ser_msgpack()
339 .into_boxed_slice(),
340 )
341 .await?;
342 }
343 match self.connections.remove(&connection_id) {
344 Some((worker, _connection)) => {
345 worker.cancel().await;
346 },
347 None if !self.historic_connections.contains(&connection_id) => {
348 bail!(ForwardingError::protocol(format!(
349 "Connection '{connection_id}' not found"
350 )));
351 },
352 None => { },
353 }
354 Ok(())
355 }
356
357 async fn spawn_connection(
358 &mut self,
359 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
360 mut target: String,
361 connection_id: u64,
362 ) -> Result<(), ForwardingError> {
363 tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
364
365 use std::collections::hash_map::Entry;
366 let entry = match self.connections.entry(connection_id) {
367 Entry::Vacant(entry) => entry,
368 Entry::Occupied(_) => {
369 bail!(ForwardingError::protocol(format!(
370 "Connection '{connection_id}' already exists"
371 )));
372 },
373 };
374
375 let (host, port) = self.targets.get(&target).unwrap();
376 if host.is_none() {
377 target = format!("[::1]:{port}");
378 }
379 let stream = match async_net::TcpStream::connect(&target).await {
380 Ok(stream) => stream,
381 Err(err) => {
382 tracing::warn!(
383 "Cannot open connection to {}: {}. The forwarded service might be down.",
384 target,
385 err
386 );
387 transit_tx
388 .send(
389 PeerMessage::Disconnect { connection_id }
390 .ser_msgpack()
391 .into_boxed_slice(),
392 )
393 .await?;
394 return Ok(());
395 },
396 };
397 let (mut connection_rd, connection_wr) = futures_lite::io::split(stream);
398 let mut backchannel_tx = self.backchannel_tx.clone();
399 let worker = crate::util::spawn(async move {
400 let mut buffer = vec![0; 4096];
401 macro_rules! break_on_err {
403 ($expr:expr_2021) => {
404 match $expr {
405 Ok(val) => val,
406 Err(_) => break,
407 }
408 };
409 }
410 #[expect(clippy::while_let_loop)]
411 loop {
412 let read = break_on_err!(connection_rd.read(&mut buffer).await);
413 if read == 0 {
414 break;
415 }
416 let buffer = &buffer[..read];
417 break_on_err!(
418 backchannel_tx
419 .send((connection_id, Some(buffer.to_vec())))
420 .await
421 );
422 }
423 let _ = backchannel_tx.send((connection_id, None)).await;
425 backchannel_tx.disconnect();
426 });
427 entry.insert((worker, connection_wr));
428 Ok(())
429 }
430
431 async fn shutdown(self) {
432 tracing::debug!("Shutting down everything");
433 for (worker, _connection) in self.connections.into_values() {
434 worker.cancel().await;
435 }
436 }
437
438 async fn run(
439 mut self,
440 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
441 transit_rx: &mut (
442 impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>> + Unpin
443 ),
444 cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
445 ) -> Result<(), ForwardingError> {
446 tracing::debug!("Entered processing loop");
448 let ret = loop {
449 futures::select! {
450 message = transit_rx.next() => {
451 match PeerMessage::de_msgpack(&message.unwrap()?)? {
452 PeerMessage::Forward { connection_id, payload } => {
453 self.forward(transit_tx, connection_id, &payload).await?
454 },
455 PeerMessage::Connect { target, connection_id } => {
456 self.historic_connections.insert(connection_id);
458 ensure!(
459 self.targets.contains_key(&target),
460 ForwardingError::protocol(format!("We don't know forwarding target '{target}'")),
461 );
462
463 self.spawn_connection(transit_tx, target, connection_id).await?;
464 },
465 PeerMessage::Disconnect { connection_id } => {
466 self.remove_connection(transit_tx, connection_id, false).await?;
467 },
468 PeerMessage::Close => {
469 tracing::info!("Peer gracefully closed connection");
470 self.shutdown().await;
471 break Ok(());
472 },
473 PeerMessage::Error(err) => {
474 self.shutdown().await;
475 bail!(ForwardingError::PeerError(err));
476 },
477 other => {
478 self.shutdown().await;
479 bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
480 },
481 }
482 },
483 message = self.backchannel_rx.next() => {
484 match message.unwrap() {
486 (connection_id, Some(payload)) => {
487 transit_tx.send(
488 PeerMessage::Forward {
489 connection_id,
490 payload
491 }
492 .ser_msgpack()
493 .into_boxed_slice()
494 ).await?;
495 },
496 (connection_id, None) => {
497 self.remove_connection(transit_tx, connection_id, true).await?;
498 },
499 }
500 },
501 () = &mut *cancel => {
503 tracing::info!("Closing connection");
504 transit_tx.send(
505 PeerMessage::Close.ser_msgpack()
506 .into_boxed_slice()
507 )
508 .await?;
509 transit_tx.close().await?;
510 self.shutdown().await;
511 break Ok(());
512 },
513 }
514 };
515 tracing::debug!("Exited processing loop");
516 ret
517 }
518}
519
520pub async fn connect(
535 mut wormhole: Wormhole,
536 transit_handler: impl FnOnce(transit::TransitInfo),
537 relay_hints: Vec<transit::RelayHint>,
538 bind_address: Option<std::net::IpAddr>,
539 custom_ports: &[u16],
540) -> Result<ConnectOffer, ForwardingError> {
541 let our_version: &AppVersion = wormhole
542 .our_version()
543 .downcast_ref()
544 .expect("You may only use a Wormhole instance with the correct AppVersion type!");
545 let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
546 let connector = transit::init(
547 our_version.transit_abilities,
548 Some(peer_version.transit_abilities),
549 relay_hints,
550 )
551 .await?;
552 let bind_address = bind_address.unwrap_or_else(|| std::net::IpAddr::V6("::".parse().unwrap()));
553
554 wormhole
556 .send_json(&PeerMessage::Transit {
557 hints: (**connector.our_hints()).clone(),
558 })
559 .await?;
560
561 let their_hints: transit::Hints = match wormhole.receive_json().await?? {
563 PeerMessage::Transit { hints } => {
564 tracing::debug!("Received transit message: {:?}", hints);
565 hints
566 },
567 PeerMessage::Error(err) => {
568 bail!(ForwardingError::PeerError(err));
569 },
570 other => {
571 let error = ForwardingError::unexpected_message("transit", other);
572 let _ = wormhole
573 .send_json(&PeerMessage::Error(format!("{error}")))
574 .await;
575 bail!(error)
576 },
577 };
578
579 let (mut transit, info) = match connector
580 .connect(
581 TransitRole::Follower,
582 wormhole.key().derive_transit_key(wormhole.appid()),
583 peer_version.transit_abilities,
584 Arc::new(their_hints),
585 )
586 .await
587 {
588 Ok(transit) => transit,
589 Err(error) => {
590 let error = ForwardingError::TransitConnect(error);
591 let _ = wormhole
592 .send_json(&PeerMessage::Error(format!("{error}")))
593 .await;
594 return Err(error);
595 },
596 };
597 transit_handler(info);
598
599 wormhole.close().await?;
601
602 let run = async {
603 let addresses = match PeerMessage::de_msgpack(&transit.receive_record().await?)? {
606 PeerMessage::Offer { addresses } => addresses,
607 PeerMessage::Error(err) => {
608 bail!(ForwardingError::PeerError(err));
609 },
610 other => {
611 bail!(ForwardingError::unexpected_message("offer", other))
612 },
613 };
614
615 if addresses.len() > 1024 {
617 return Err(ForwardingError::protocol("Too many forwarded ports"));
618 }
619
620 let listeners: Vec<(
625 async_net::TcpListener,
626 u16,
627 std::rc::Rc<std::string::String>,
628 )> = futures::stream::iter(
629 addresses
630 .into_iter()
631 .map(Rc::new)
632 .zip(custom_ports.iter().copied().chain(std::iter::repeat(0))),
633 )
634 .then(|(address, port)| async move {
635 let connection = TcpListener::bind((bind_address, port)).await?;
636 let port = connection.local_addr()?.port();
637 Result::<_, std::io::Error>::Ok((connection, port, address))
638 })
639 .try_collect()
640 .await?;
641 Ok(listeners)
642 };
643
644 match run.await {
645 Ok(listeners) => Ok(ConnectOffer {
646 transit,
647 mapping: listeners.iter().map(|(_, b, c)| (*b, c.clone())).collect(),
648 listeners,
649 }),
650 Err(error @ ForwardingError::PeerError(_)) => Err(error),
651 Err(error) => {
652 let _ = transit
653 .send_record(&PeerMessage::Error(format!("{error}")).ser_msgpack())
654 .await;
655 Err(error)
656 },
657 }
658}
659
660#[must_use]
664pub struct ConnectOffer {
665 pub mapping: Vec<(u16, Rc<String>)>,
667 transit: transit::Transit,
668 listeners: Vec<(
669 async_net::TcpListener,
670 u16,
671 std::rc::Rc<std::string::String>,
672 )>,
673}
674
675impl ConnectOffer {
676 pub async fn accept(self, cancel: impl Future<Output = ()>) -> Result<(), ForwardingError> {
683 let (transit_tx, transit_rx) = self.transit.split();
684 let transit_rx = transit_rx.fuse();
685 use futures::FutureExt;
686 let cancel = cancel.fuse();
687 futures::pin_mut!(transit_tx);
688 futures::pin_mut!(transit_rx);
689 futures::pin_mut!(cancel);
690
691 let run = async {
693 let (backchannel_tx, backchannel_rx) =
694 futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
695
696 let incoming_listeners = self.listeners.into_iter().map(|(connection, _, address)| {
697 Box::pin(
698 futures_lite::stream::unfold(connection, |listener| async move {
699 let res = listener.accept().await.map(|(stream, _)| stream);
700 Some((res, listener))
701 })
702 .map_ok(move |stream| (address.clone(), stream)),
703 )
704 });
705
706 ForwardConnect {
707 incoming: futures::stream::select_all(incoming_listeners),
708 connection_counter: 0,
709 connections: HashMap::new(),
710 backchannel_tx,
711 backchannel_rx,
712 }
713 .run(&mut transit_tx, &mut transit_rx, &mut cancel)
714 .await
715 };
716
717 match run.await {
718 Ok(()) => Ok(()),
719 Err(error @ ForwardingError::PeerError(_)) => Err(error),
720 Err(error) => {
721 let _ = transit_tx
722 .send(
723 PeerMessage::Error(format!("{error}"))
724 .ser_msgpack()
725 .into_boxed_slice(),
726 )
727 .await;
728 Err(error)
729 },
730 }
731 }
732
733 pub async fn reject(mut self) -> Result<(), ForwardingError> {
737 self.transit
738 .send_record(&PeerMessage::Error("transfer rejected".into()).ser_msgpack())
739 .await?;
740
741 Ok(())
742 }
743}
744
745struct ForwardConnect<I> {
746 incoming: I,
749 connection_counter: u64,
751 connections: HashMap<
752 u64,
753 (
754 async_task::Task<()>,
755 futures_lite::io::WriteHalf<async_net::TcpStream>,
756 ),
757 >,
758 backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
760 backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
761}
762
763impl<I> ForwardConnect<I>
764where
765 I: Unpin
766 + futures::stream::FusedStream<
767 Item = Result<(Rc<String>, async_net::TcpStream), std::io::Error>,
768 >,
769{
770 async fn forward(
771 &mut self,
772 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
773 connection_id: u64,
774 payload: &[u8],
775 ) -> Result<(), ForwardingError> {
776 tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
777 match self.connections.get_mut(&connection_id) {
778 Some((_worker, connection)) => {
779 if let Err(e) = connection.write_all(payload).await {
781 tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
782 self.remove_connection(transit_tx, connection_id, true)
783 .await?;
784 }
785 },
786 None if self.connection_counter <= connection_id => {
787 bail!(ForwardingError::protocol(format!(
788 "Connection '{connection_id}' not found"
789 )));
790 },
791 None => { },
792 }
793 Ok(())
794 }
795
796 async fn remove_connection(
797 &mut self,
798 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
799 connection_id: u64,
800 tell_peer: bool,
801 ) -> Result<(), ForwardingError> {
802 tracing::debug!("Removing connection: #{}", connection_id);
803 if tell_peer {
804 transit_tx
805 .send(
806 PeerMessage::Disconnect { connection_id }
807 .ser_msgpack()
808 .into_boxed_slice(),
809 )
810 .await?;
811 }
812 match self.connections.remove(&connection_id) {
813 Some((worker, _connection)) => {
814 worker.cancel().await;
815 },
816 None if connection_id >= self.connection_counter => {
817 bail!(ForwardingError::protocol(format!(
818 "Connection '{connection_id}' not found"
819 )));
820 },
821 None => { },
822 }
823 Ok(())
824 }
825
826 async fn spawn_connection(
827 &mut self,
828 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
829 target: Rc<String>,
830 connection: async_net::TcpStream,
831 ) -> Result<(), ForwardingError> {
832 let connection_id = self.connection_counter;
833 self.connection_counter += 1;
834 let (mut connection_rd, connection_wr) = futures_lite::io::split(connection);
835 let mut backchannel_tx = self.backchannel_tx.clone();
836 tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
837
838 transit_tx
839 .send(
840 PeerMessage::Connect {
841 target: (*target).clone(),
842 connection_id,
843 }
844 .ser_msgpack()
845 .into_boxed_slice(),
846 )
847 .await?;
848
849 let worker = crate::util::spawn(async move {
850 let mut buffer = vec![0; 4096];
851 macro_rules! break_on_err {
853 ($expr:expr_2021) => {
854 match $expr {
855 Ok(val) => val,
856 Err(_) => break,
857 }
858 };
859 }
860 #[expect(clippy::while_let_loop)]
861 loop {
862 let read = break_on_err!(connection_rd.read(&mut buffer).await);
863 if read == 0 {
864 break;
865 }
866 let buffer = &buffer[..read];
867 break_on_err!(
868 backchannel_tx
869 .send((connection_id, Some(buffer.to_vec())))
870 .await
871 );
872 }
873 let _ = backchannel_tx.send((connection_id, None)).await;
875 backchannel_tx.disconnect();
876 });
877
878 self.connections
879 .insert(connection_id, (worker, connection_wr));
880 Ok(())
881 }
882
883 async fn shutdown(self) {
884 tracing::debug!("Shutting down everything");
885 for (worker, _connection) in self.connections.into_values() {
886 worker.cancel().await;
887 }
888 }
889
890 async fn run(
891 mut self,
892 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
893 transit_rx: &mut (
894 impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>> + Unpin
895 ),
896 cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
897 ) -> Result<(), ForwardingError> {
898 tracing::debug!("Entered processing loop");
900 let ret = loop {
901 futures::select! {
902 message = transit_rx.next() => {
903 match PeerMessage::de_msgpack(&message.unwrap()?)? {
904 PeerMessage::Forward { connection_id, payload } => {
905 self.forward(transit_tx, connection_id, &payload).await?;
906 },
907 PeerMessage::Disconnect { connection_id } => {
908 self.remove_connection(transit_tx, connection_id, false).await?;
909 },
910 PeerMessage::Close => {
911 tracing::info!("Peer gracefully closed connection");
912 self.shutdown().await;
913 break Ok(())
914 },
915 PeerMessage::Error(err) => {
916 for (worker, _connection) in self.connections.into_values() {
917 worker.cancel().await;
918 }
919 bail!(ForwardingError::PeerError(err));
920 },
921 other => {
922 self.shutdown().await;
923 bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
924 },
925 }
926 },
927 message = self.backchannel_rx.next() => {
928 match message.unwrap() {
930 (connection_id, Some(payload)) => {
931 transit_tx.send(
932 PeerMessage::Forward {
933 connection_id,
934 payload
935 }.ser_msgpack()
936 .into_boxed_slice()
937 )
938 .await?;
939 },
940 (connection_id, None) => {
941 self.remove_connection(transit_tx, connection_id, true).await?;
942 },
943 }
944 },
945 connection = self.incoming.next() => {
946 let (target, connection): (Rc<String>, async_net::TcpStream) = connection.unwrap()?;
947 self.spawn_connection(transit_tx, target, connection).await?;
948 },
949 () = &mut *cancel => {
951 tracing::info!("Closing connection");
952 transit_tx.send(
953 PeerMessage::Close.ser_msgpack()
954 .into_boxed_slice()
955 )
956 .await?;
957 transit_tx.close().await?;
958 self.shutdown().await;
959 break Ok(());
960 },
961 }
962 };
963 tracing::debug!("Exited processing loop");
964 ret
965 }
966}
967
968#[derive(Deserialize, Serialize, Debug)]
970#[serde(rename_all = "kebab-case")]
971#[non_exhaustive]
972enum PeerMessage {
973 Offer { addresses: Vec<String> },
977 Connect { target: String, connection_id: u64 },
981 Disconnect { connection_id: u64 },
986 Forward {
988 connection_id: u64,
989 payload: Vec<u8>,
990 },
991 Close,
993 Error(String),
995 Transit { hints: transit::Hints },
997 #[serde(other)]
998 Unknown,
999}
1000
1001impl PeerMessage {
1002 pub fn ser_msgpack(&self) -> Vec<u8> {
1003 let mut writer = Vec::with_capacity(128);
1004 let mut ser = rmp_serde::encode::Serializer::new(&mut writer)
1005 .with_struct_map()
1006 .with_human_readable();
1007 serde::Serialize::serialize(self, &mut ser).unwrap();
1008 writer
1009 }
1010
1011 pub fn de_msgpack(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
1012 rmp_serde::from_read(&mut &*data)
1013 }
1014}