1#![allow(deprecated)]
17
18use super::*;
19use async_std::net::{TcpListener, TcpStream};
20use futures::{AsyncReadExt, AsyncWriteExt, Future, SinkExt, StreamExt, TryStreamExt};
21use serde::{Deserialize, Serialize};
22use std::{
23 borrow::Cow,
24 collections::{HashMap, HashSet},
25 rc::Rc,
26 sync::Arc,
27};
28use transit::{TransitConnectError, TransitError};
29
/// Raw application identifier string of the port-forwarding protocol.
///
/// Both peers must use the same value; it is also the basis for [`APPID`].
const APPID_RAW: &str = "piegames.de/wormhole/port-forwarding";
31
32pub const APPID: AppID = AppID(Cow::Borrowed(APPID_RAW));
34
35pub const APP_CONFIG: crate::AppConfig<AppVersion> = crate::AppConfig::<AppVersion> {
40 id: AppID(Cow::Borrowed(APPID_RAW)),
41 rendezvous_url: Cow::Borrowed(crate::rendezvous::DEFAULT_RENDEZVOUS_SERVER),
42 app_version: AppVersion {
43 transit_abilities: transit::Abilities::ALL_ABILITIES,
44 other: serde_json::Value::Null,
45 },
46};
47
48#[derive(Clone, Debug, Default, Serialize, Deserialize)]
52pub struct AppVersion {
53 pub transit_abilities: transit::Abilities,
55 #[serde(flatten)]
56 other: serde_json::Value,
57}
58
59#[derive(Debug, thiserror::Error)]
60#[non_exhaustive]
61pub enum ForwardingError {
63 #[error("Transfer was not acknowledged by peer")]
65 AckError,
66 #[error("Something went wrong on the other side: {}", _0)]
68 PeerError(String),
69 #[error("Corrupt JSON message received")]
71 ProtocolJson(
72 #[from]
73 #[source]
74 serde_json::Error,
75 ),
76 #[error("Corrupt Msgpack message received")]
78 ProtocolMsgpack(
79 #[from]
80 #[source]
81 rmp_serde::decode::Error,
82 ),
83 #[error("Protocol error: {}", _0)]
86 Protocol(Box<str>),
87 #[error(
89 "Unexpected message (protocol error): Expected '{}', but got: {:?}",
90 _0,
91 _1
92 )]
93 ProtocolUnexpectedMessage(Box<str>, Box<dyn std::fmt::Debug + Send + Sync>),
94 #[error("Wormhole connection error")]
96 Wormhole(
97 #[from]
98 #[source]
99 WormholeError,
100 ),
101 #[error("Error while establishing transit connection")]
103 TransitConnect(
104 #[from]
105 #[source]
106 TransitConnectError,
107 ),
108 #[error("Transit error")]
110 Transit(
111 #[from]
112 #[source]
113 TransitError,
114 ),
115 #[error("I/O error")]
117 IO(
118 #[from]
119 #[source]
120 std::io::Error,
121 ),
122}
123
124impl ForwardingError {
125 fn protocol(message: impl Into<Box<str>>) -> Self {
126 Self::Protocol(message.into())
127 }
128
129 pub(self) fn unexpected_message(
130 expected: impl Into<Box<str>>,
131 got: impl std::fmt::Debug + Send + Sync + 'static,
132 ) -> Self {
133 Self::ProtocolUnexpectedMessage(expected.into(), Box::new(got))
134 }
135}
136
137pub async fn serve(
148 mut wormhole: Wormhole,
149 transit_handler: impl FnOnce(transit::TransitInfo),
150 relay_hints: Vec<transit::RelayHint>,
151 targets: Vec<(Option<url::Host>, u16)>,
152 cancel: impl Future<Output = ()>,
153) -> Result<(), ForwardingError> {
154 assert!(
155 !targets.is_empty(),
156 "The list of target ports must not be empty"
157 );
158
159 let our_version: &AppVersion = wormhole
160 .our_version()
161 .downcast_ref()
162 .expect("You may only use a Wormhole instance with the correct AppVersion type!");
163 let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
164 let connector = transit::init(
165 our_version.transit_abilities,
166 Some(peer_version.transit_abilities),
167 relay_hints,
168 )
169 .await?;
170
171 wormhole
173 .send_json(&PeerMessage::Transit {
174 hints: (**connector.our_hints()).clone(),
175 })
176 .await?;
177
178 let targets: HashMap<String, (Option<url::Host>, u16)> = targets
179 .into_iter()
180 .map(|(host, port)| match host {
181 Some(host) => {
182 if port == 80 || port == 443 || port == 8000 || port == 8080 {
183 tracing::warn!("It seems like you are trying to forward a remote HTTP target ('{}'). Due to HTTP being host-aware this will very likely fail!", host);
184 }
185 (format!("{}:{}", host, port), (Some(host), port))
186 },
187 None => (port.to_string(), (host, port)),
188 })
189 .collect();
190
191 let their_hints: transit::Hints = match wormhole.receive_json().await?? {
193 PeerMessage::Transit { hints } => {
194 tracing::debug!("Received transit message: {:?}", hints);
195 hints
196 },
197 PeerMessage::Error(err) => {
198 bail!(ForwardingError::PeerError(err));
199 },
200 other => {
201 let error = ForwardingError::unexpected_message("transit", other);
202 let _ = wormhole
203 .send_json(&PeerMessage::Error(format!("{}", error)))
204 .await;
205 bail!(error)
206 },
207 };
208
209 let (mut transit, info) = match connector
210 .leader_connect(
211 wormhole.key().derive_transit_key(wormhole.appid()),
212 peer_version.transit_abilities,
213 Arc::new(their_hints),
214 )
215 .await
216 {
217 Ok(transit) => transit,
218 Err(error) => {
219 let error = ForwardingError::TransitConnect(error);
220 let _ = wormhole
221 .send_json(&PeerMessage::Error(format!("{}", error)))
222 .await;
223 return Err(error);
224 },
225 };
226 transit_handler(info);
227
228 wormhole.close().await?;
230
231 transit
232 .send_record(
233 &PeerMessage::Offer {
234 addresses: targets.keys().cloned().collect(),
235 }
236 .ser_msgpack(),
237 )
238 .await?;
239
240 let (backchannel_tx, backchannel_rx) =
241 futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
242
243 let (transit_tx, transit_rx) = transit.split();
244 let transit_rx = transit_rx.fuse();
245 use futures::future::FutureExt;
246 let cancel = cancel.fuse();
247 futures::pin_mut!(transit_tx);
248 futures::pin_mut!(transit_rx);
249 futures::pin_mut!(cancel);
250
251 let result = ForwardingServe {
253 targets,
254 connections: HashMap::new(),
255 historic_connections: HashSet::new(),
256 backchannel_tx,
257 backchannel_rx,
258 }
259 .run(&mut transit_tx, &mut transit_rx, &mut cancel)
260 .await;
261 match result {
263 Ok(()) => Ok(()),
264 Err(error @ ForwardingError::PeerError(_)) => Err(error),
265 Err(error) => {
266 let _ = transit_tx
267 .send(
268 PeerMessage::Error(format!("{}", error))
269 .ser_msgpack()
270 .into_boxed_slice(),
271 )
272 .await;
273 Err(error)
274 },
275 }
276}
277
278struct ForwardingServe {
279 targets: HashMap<String, (Option<url::Host>, u16)>,
280 connections: HashMap<
282 u64,
283 (
284 async_std::task::JoinHandle<()>,
285 futures::io::WriteHalf<TcpStream>,
286 ),
287 >,
288 historic_connections: HashSet<u64>,
293 backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
295 backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
296}
297
298impl ForwardingServe {
300 async fn forward(
301 &mut self,
302 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
303 connection_id: u64,
304 payload: &[u8],
305 ) -> Result<(), ForwardingError> {
306 tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
307 match self.connections.get_mut(&connection_id) {
308 Some((_worker, connection)) => {
309 if let Err(e) = connection.write_all(payload).await {
311 tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
312 self.remove_connection(transit_tx, connection_id, true)
313 .await?;
314 }
315 },
316 None if !self.historic_connections.contains(&connection_id) => {
317 bail!(ForwardingError::protocol(format!(
318 "Connection '{}' not found",
319 connection_id
320 )));
321 },
322 None => { },
323 }
324 Ok(())
325 }
326
327 async fn remove_connection(
328 &mut self,
329 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
330 connection_id: u64,
331 tell_peer: bool,
332 ) -> Result<(), ForwardingError> {
333 tracing::debug!("Removing connection: #{}", connection_id);
334 if tell_peer {
335 transit_tx
336 .send(
337 PeerMessage::Disconnect { connection_id }
338 .ser_msgpack()
339 .into_boxed_slice(),
340 )
341 .await?;
342 }
343 match self.connections.remove(&connection_id) {
344 Some((worker, _connection)) => {
345 worker.cancel().await;
346 },
347 None if !self.historic_connections.contains(&connection_id) => {
348 bail!(ForwardingError::protocol(format!(
349 "Connection '{}' not found",
350 connection_id
351 )));
352 },
353 None => { },
354 }
355 Ok(())
356 }
357
358 async fn spawn_connection(
359 &mut self,
360 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
361 mut target: String,
362 connection_id: u64,
363 ) -> Result<(), ForwardingError> {
364 tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
365
366 use std::collections::hash_map::Entry;
367 let entry = match self.connections.entry(connection_id) {
368 Entry::Vacant(entry) => entry,
369 Entry::Occupied(_) => {
370 bail!(ForwardingError::protocol(format!(
371 "Connection '{}' already exists",
372 connection_id
373 )));
374 },
375 };
376
377 let (host, port) = self.targets.get(&target).unwrap();
378 if host.is_none() {
379 target = format!("[::1]:{}", port);
380 }
381 let stream = match TcpStream::connect(&target).await {
382 Ok(stream) => stream,
383 Err(err) => {
384 tracing::warn!(
385 "Cannot open connection to {}: {}. The forwarded service might be down.",
386 target,
387 err
388 );
389 transit_tx
390 .send(
391 PeerMessage::Disconnect { connection_id }
392 .ser_msgpack()
393 .into_boxed_slice(),
394 )
395 .await?;
396 return Ok(());
397 },
398 };
399 let (mut connection_rd, connection_wr) = stream.split();
400 let mut backchannel_tx = self.backchannel_tx.clone();
401 let worker = async_std::task::spawn_local(async move {
402 let mut buffer = vec![0; 4096];
403 macro_rules! break_on_err {
405 ($expr:expr) => {
406 match $expr {
407 Ok(val) => val,
408 Err(_) => break,
409 }
410 };
411 }
412 #[allow(clippy::while_let_loop)]
413 loop {
414 let read = break_on_err!(connection_rd.read(&mut buffer).await);
415 if read == 0 {
416 break;
417 }
418 let buffer = &buffer[..read];
419 break_on_err!(
420 backchannel_tx
421 .send((connection_id, Some(buffer.to_vec())))
422 .await
423 );
424 }
425 let _ = backchannel_tx.send((connection_id, None)).await;
427 backchannel_tx.disconnect();
428 });
429 entry.insert((worker, connection_wr));
430 Ok(())
431 }
432
433 async fn shutdown(self) {
434 tracing::debug!("Shutting down everything");
435 for (worker, _connection) in self.connections.into_values() {
436 worker.cancel().await;
437 }
438 }
439
440 async fn run(
441 mut self,
442 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
443 transit_rx: &mut (impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>>
444 + Unpin),
445 cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
446 ) -> Result<(), ForwardingError> {
447 tracing::debug!("Entered processing loop");
449 let ret = loop {
450 futures::select! {
451 message = transit_rx.next() => {
452 match PeerMessage::de_msgpack(&message.unwrap()?)? {
453 PeerMessage::Forward { connection_id, payload } => {
454 self.forward(transit_tx, connection_id, &payload).await?
455 },
456 PeerMessage::Connect { target, connection_id } => {
457 self.historic_connections.insert(connection_id);
459 ensure!(
460 self.targets.contains_key(&target),
461 ForwardingError::protocol(format!("We don't know forwarding target '{}'", target)),
462 );
463
464 self.spawn_connection(transit_tx, target, connection_id).await?;
465 },
466 PeerMessage::Disconnect { connection_id } => {
467 self.remove_connection(transit_tx, connection_id, false).await?;
468 },
469 PeerMessage::Close => {
470 tracing::info!("Peer gracefully closed connection");
471 self.shutdown().await;
472 break Ok(());
473 },
474 PeerMessage::Error(err) => {
475 self.shutdown().await;
476 bail!(ForwardingError::PeerError(err));
477 },
478 other => {
479 self.shutdown().await;
480 bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
481 },
482 }
483 },
484 message = self.backchannel_rx.next() => {
485 match message.unwrap() {
487 (connection_id, Some(payload)) => {
488 transit_tx.send(
489 PeerMessage::Forward {
490 connection_id,
491 payload
492 }
493 .ser_msgpack()
494 .into_boxed_slice()
495 ).await?;
496 },
497 (connection_id, None) => {
498 self.remove_connection(transit_tx, connection_id, true).await?;
499 },
500 }
501 },
502 () = &mut *cancel => {
504 tracing::info!("Closing connection");
505 transit_tx.send(
506 PeerMessage::Close.ser_msgpack()
507 .into_boxed_slice()
508 )
509 .await?;
510 transit_tx.close().await?;
511 self.shutdown().await;
512 break Ok(());
513 },
514 }
515 };
516 tracing::debug!("Exited processing loop");
517 ret
518 }
519}
520
521pub async fn connect(
536 mut wormhole: Wormhole,
537 transit_handler: impl FnOnce(transit::TransitInfo),
538 relay_hints: Vec<transit::RelayHint>,
539 bind_address: Option<std::net::IpAddr>,
540 custom_ports: &[u16],
541) -> Result<ConnectOffer, ForwardingError> {
542 let our_version: &AppVersion = wormhole
543 .our_version()
544 .downcast_ref()
545 .expect("You may only use a Wormhole instance with the correct AppVersion type!");
546 let peer_version: AppVersion = serde_json::from_value(wormhole.peer_version().clone())?;
547 let connector = transit::init(
548 our_version.transit_abilities,
549 Some(peer_version.transit_abilities),
550 relay_hints,
551 )
552 .await?;
553 let bind_address = bind_address.unwrap_or_else(|| std::net::IpAddr::V6("::".parse().unwrap()));
554
555 wormhole
557 .send_json(&PeerMessage::Transit {
558 hints: (**connector.our_hints()).clone(),
559 })
560 .await?;
561
562 let their_hints: transit::Hints = match wormhole.receive_json().await?? {
564 PeerMessage::Transit { hints } => {
565 tracing::debug!("Received transit message: {:?}", hints);
566 hints
567 },
568 PeerMessage::Error(err) => {
569 bail!(ForwardingError::PeerError(err));
570 },
571 other => {
572 let error = ForwardingError::unexpected_message("transit", other);
573 let _ = wormhole
574 .send_json(&PeerMessage::Error(format!("{}", error)))
575 .await;
576 bail!(error)
577 },
578 };
579
580 let (mut transit, info) = match connector
581 .follower_connect(
582 wormhole.key().derive_transit_key(wormhole.appid()),
583 peer_version.transit_abilities,
584 Arc::new(their_hints),
585 )
586 .await
587 {
588 Ok(transit) => transit,
589 Err(error) => {
590 let error = ForwardingError::TransitConnect(error);
591 let _ = wormhole
592 .send_json(&PeerMessage::Error(format!("{}", error)))
593 .await;
594 return Err(error);
595 },
596 };
597 transit_handler(info);
598
599 wormhole.close().await?;
601
602 let run = async {
603 let addresses = match PeerMessage::de_msgpack(&transit.receive_record().await?)? {
606 PeerMessage::Offer { addresses } => addresses,
607 PeerMessage::Error(err) => {
608 bail!(ForwardingError::PeerError(err));
609 },
610 other => {
611 bail!(ForwardingError::unexpected_message("offer", other))
612 },
613 };
614
615 if addresses.len() > 1024 {
617 return Err(ForwardingError::protocol("Too many forwarded ports"));
618 }
619
620 let listeners: Vec<(
625 async_std::net::TcpListener,
626 u16,
627 std::rc::Rc<std::string::String>,
628 )> = futures::stream::iter(
629 addresses
630 .into_iter()
631 .map(Rc::new)
632 .zip(custom_ports.iter().copied().chain(std::iter::repeat(0))),
633 )
634 .then(|(address, port)| async move {
635 let connection = TcpListener::bind((bind_address, port)).await?;
636 let port = connection.local_addr()?.port();
637 Result::<_, std::io::Error>::Ok((connection, port, address))
638 })
639 .try_collect()
640 .await?;
641 Ok(listeners)
642 };
643
644 match run.await {
645 Ok(listeners) => Ok(ConnectOffer {
646 transit,
647 mapping: listeners.iter().map(|(_, b, c)| (*b, c.clone())).collect(),
648 listeners,
649 }),
650 Err(error @ ForwardingError::PeerError(_)) => Err(error),
651 Err(error) => {
652 let _ = transit
653 .send_record(&PeerMessage::Error(format!("{}", error)).ser_msgpack())
654 .await;
655 Err(error)
656 },
657 }
658}
659
660#[must_use]
664pub struct ConnectOffer {
665 pub mapping: Vec<(u16, Rc<String>)>,
667 transit: transit::Transit,
668 listeners: Vec<(
669 async_std::net::TcpListener,
670 u16,
671 std::rc::Rc<std::string::String>,
672 )>,
673}
674
675impl ConnectOffer {
676 pub async fn accept(self, cancel: impl Future<Output = ()>) -> Result<(), ForwardingError> {
683 let (transit_tx, transit_rx) = self.transit.split();
684 let transit_rx = transit_rx.fuse();
685 use futures::FutureExt;
686 let cancel = cancel.fuse();
687 futures::pin_mut!(transit_tx);
688 futures::pin_mut!(transit_rx);
689 futures::pin_mut!(cancel);
690
691 let run = async {
693 let (backchannel_tx, backchannel_rx) =
694 futures::channel::mpsc::channel::<(u64, Option<Vec<u8>>)>(20);
695
696 ForwardConnect {
697 incoming: futures::stream::select_all(self.listeners.into_iter().map(
698 |(connection, _, address)| {
699 connection
700 .into_incoming()
701 .map_ok(move |stream| (address.clone(), stream))
702 .boxed_local()
703 },
704 )),
705 connection_counter: 0,
706 connections: HashMap::new(),
707 backchannel_tx,
708 backchannel_rx,
709 }
710 .run(&mut transit_tx, &mut transit_rx, &mut cancel)
711 .await
712 };
713
714 match run.await {
715 Ok(()) => Ok(()),
716 Err(error @ ForwardingError::PeerError(_)) => Err(error),
717 Err(error) => {
718 let _ = transit_tx
719 .send(
720 PeerMessage::Error(format!("{}", error))
721 .ser_msgpack()
722 .into_boxed_slice(),
723 )
724 .await;
725 Err(error)
726 },
727 }
728 }
729
730 pub async fn reject(mut self) -> Result<(), ForwardingError> {
734 self.transit
735 .send_record(&PeerMessage::Error("transfer rejected".into()).ser_msgpack())
736 .await?;
737
738 Ok(())
739 }
740}
741
742#[allow(clippy::type_complexity)]
743struct ForwardConnect {
744 incoming: futures::stream::SelectAll<
747 futures::stream::LocalBoxStream<
748 'static,
749 Result<(Rc<String>, async_std::net::TcpStream), std::io::Error>,
750 >,
751 >,
752 connection_counter: u64,
754 connections: HashMap<
755 u64,
756 (
757 async_std::task::JoinHandle<()>,
758 futures::io::WriteHalf<TcpStream>,
759 ),
760 >,
761 backchannel_tx: futures::channel::mpsc::Sender<(u64, Option<Vec<u8>>)>,
763 backchannel_rx: futures::channel::mpsc::Receiver<(u64, Option<Vec<u8>>)>,
764}
765
766impl ForwardConnect {
767 async fn forward(
768 &mut self,
769 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
770 connection_id: u64,
771 payload: &[u8],
772 ) -> Result<(), ForwardingError> {
773 tracing::debug!("Forwarding {} bytes from #{}", payload.len(), connection_id);
774 match self.connections.get_mut(&connection_id) {
775 Some((_worker, connection)) => {
776 if let Err(e) = connection.write_all(payload).await {
778 tracing::warn!("Forwarding to #{} failed: {}", connection_id, e);
779 self.remove_connection(transit_tx, connection_id, true)
780 .await?;
781 }
782 },
783 None if self.connection_counter <= connection_id => {
784 bail!(ForwardingError::protocol(format!(
785 "Connection '{}' not found",
786 connection_id
787 )));
788 },
789 None => { },
790 }
791 Ok(())
792 }
793
794 async fn remove_connection(
795 &mut self,
796 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
797 connection_id: u64,
798 tell_peer: bool,
799 ) -> Result<(), ForwardingError> {
800 tracing::debug!("Removing connection: #{}", connection_id);
801 if tell_peer {
802 transit_tx
803 .send(
804 PeerMessage::Disconnect { connection_id }
805 .ser_msgpack()
806 .into_boxed_slice(),
807 )
808 .await?;
809 }
810 match self.connections.remove(&connection_id) {
811 Some((worker, _connection)) => {
812 worker.cancel().await;
813 },
814 None if connection_id >= self.connection_counter => {
815 bail!(ForwardingError::protocol(format!(
816 "Connection '{}' not found",
817 connection_id
818 )));
819 },
820 None => { },
821 }
822 Ok(())
823 }
824
825 async fn spawn_connection(
826 &mut self,
827 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
828 target: Rc<String>,
829 connection: TcpStream,
830 ) -> Result<(), ForwardingError> {
831 let connection_id = self.connection_counter;
832 self.connection_counter += 1;
833 let (mut connection_rd, connection_wr) = connection.split();
834 let mut backchannel_tx = self.backchannel_tx.clone();
835 tracing::debug!("Creating new connection: #{} -> {}", connection_id, target);
836
837 transit_tx
838 .send(
839 PeerMessage::Connect {
840 target: (*target).clone(),
841 connection_id,
842 }
843 .ser_msgpack()
844 .into_boxed_slice(),
845 )
846 .await?;
847
848 let worker = async_std::task::spawn_local(async move {
849 let mut buffer = vec![0; 4096];
850 macro_rules! break_on_err {
852 ($expr:expr) => {
853 match $expr {
854 Ok(val) => val,
855 Err(_) => break,
856 }
857 };
858 }
859 #[allow(clippy::while_let_loop)]
860 loop {
861 let read = break_on_err!(connection_rd.read(&mut buffer).await);
862 if read == 0 {
863 break;
864 }
865 let buffer = &buffer[..read];
866 break_on_err!(
867 backchannel_tx
868 .send((connection_id, Some(buffer.to_vec())))
869 .await
870 );
871 }
872 let _ = backchannel_tx.send((connection_id, None)).await;
874 backchannel_tx.disconnect();
875 });
876
877 self.connections
878 .insert(connection_id, (worker, connection_wr));
879 Ok(())
880 }
881
882 async fn shutdown(self) {
883 tracing::debug!("Shutting down everything");
884 for (worker, _connection) in self.connections.into_values() {
885 worker.cancel().await;
886 }
887 }
888
889 async fn run(
890 mut self,
891 transit_tx: &mut (impl futures::sink::Sink<Box<[u8]>, Error = TransitError> + Unpin),
892 transit_rx: &mut (impl futures::stream::FusedStream<Item = Result<Box<[u8]>, TransitError>>
893 + Unpin),
894 cancel: &mut (impl futures::future::FusedFuture<Output = ()> + Unpin),
895 ) -> Result<(), ForwardingError> {
896 tracing::debug!("Entered processing loop");
898 let ret = loop {
899 futures::select! {
900 message = transit_rx.next() => {
901 match PeerMessage::de_msgpack(&message.unwrap()?)? {
902 PeerMessage::Forward { connection_id, payload } => {
903 self.forward(transit_tx, connection_id, &payload).await?;
904 },
905 PeerMessage::Disconnect { connection_id } => {
906 self.remove_connection(transit_tx, connection_id, false).await?;
907 },
908 PeerMessage::Close => {
909 tracing::info!("Peer gracefully closed connection");
910 self.shutdown().await;
911 break Ok(())
912 },
913 PeerMessage::Error(err) => {
914 for (worker, _connection) in self.connections.into_values() {
915 worker.cancel().await;
916 }
917 bail!(ForwardingError::PeerError(err));
918 },
919 other => {
920 self.shutdown().await;
921 bail!(ForwardingError::unexpected_message("connect' or 'disconnect' or 'forward' or 'close", other));
922 },
923 }
924 },
925 message = self.backchannel_rx.next() => {
926 match message.unwrap() {
928 (connection_id, Some(payload)) => {
929 transit_tx.send(
930 PeerMessage::Forward {
931 connection_id,
932 payload
933 }.ser_msgpack()
934 .into_boxed_slice()
935 )
936 .await?;
937 },
938 (connection_id, None) => {
939 self.remove_connection(transit_tx, connection_id, true).await?;
940 },
941 }
942 },
943 connection = self.incoming.next() => {
944 let (target, connection): (Rc<String>, TcpStream) = connection.unwrap()?;
945 self.spawn_connection(transit_tx, target, connection).await?;
946 },
947 () = &mut *cancel => {
949 tracing::info!("Closing connection");
950 transit_tx.send(
951 PeerMessage::Close.ser_msgpack()
952 .into_boxed_slice()
953 )
954 .await?;
955 transit_tx.close().await?;
956 self.shutdown().await;
957 break Ok(());
958 },
959 }
960 };
961 tracing::debug!("Exited processing loop");
962 ret
963 }
964}
965
966#[derive(Deserialize, Serialize, Debug)]
968#[serde(rename_all = "kebab-case")]
969#[non_exhaustive]
970enum PeerMessage {
971 Offer { addresses: Vec<String> },
975 Connect { target: String, connection_id: u64 },
979 Disconnect { connection_id: u64 },
984 Forward {
986 connection_id: u64,
987 payload: Vec<u8>,
988 },
989 Close,
991 Error(String),
993 Transit { hints: transit::Hints },
995 #[serde(other)]
996 Unknown,
997}
998
999impl PeerMessage {
1000 #[allow(dead_code)]
1001 pub fn ser_msgpack(&self) -> Vec<u8> {
1002 let mut writer = Vec::with_capacity(128);
1003 let mut ser = rmp_serde::encode::Serializer::new(&mut writer)
1004 .with_struct_map()
1005 .with_human_readable();
1006 serde::Serialize::serialize(self, &mut ser).unwrap();
1007 writer
1008 }
1009
1010 #[allow(dead_code)]
1011 pub fn de_msgpack(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
1012 rmp_serde::from_read(&mut &*data)
1013 }
1014}