1#[allow(deprecated)]
17use crate::{Key, KeyPurpose};
18use serde_derive::{Deserialize, Serialize};
19
20#[cfg(not(target_family = "wasm"))]
21use async_std::net::{TcpListener, TcpStream};
22#[allow(unused_imports)] use futures::{
24 future::FutureExt,
25 future::TryFutureExt,
26 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
27 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
28};
29use std::{
30 collections::HashSet,
31 net::{IpAddr, SocketAddr},
32 sync::Arc,
33 time::Instant,
34};
35
36mod crypto;
37mod transport;
38use crypto::TransitHandshakeError;
39use transport::{TransitTransport, TransitTransportRx, TransitTransportTx};
40
/// URL of the public magic-wormhole transit relay server, in `tcp://host:port` form.
///
/// This is the shape [`RelayHint::from_urls`] accepts for TCP endpoints (scheme `tcp`, explicit
/// host and port) once parsed into a `url::Url`.
pub const DEFAULT_RELAY_SERVER: &str = "tcp://transit.magic-wormhole.io:4001";
// `host:port` of a public STUN server.
// NOTE(review): presumably used by `transport::tcp_get_external_ip` (not visible here) to learn
// our external address — confirm.
#[cfg(not(target_family = "wasm"))]
const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478";
48
/// Key purpose marker for the shared transit key (`Key<TransitKey>`).
///
/// The transit connector receives a `Key<TransitKey>`. The relay token
/// (`transit_relay_token` subkey) and the connection encryption are derived from it.
#[deprecated(
    since = "0.7.0",
    note = "This will be a private type in the future. Open an issue if you require access to protocol intrinsics in the future"
)]
#[derive(Debug)]
pub struct TransitKey;

// Registers `TransitKey` as a key purpose; `allow(deprecated)` because the type itself is
// deprecated.
#[allow(deprecated)]
impl KeyPurpose for TransitKey {}
59
/// Key purpose marker for the receive-direction transit key.
///
/// NOTE(review): not referenced in this file; presumably used by the `crypto` submodule when
/// deriving per-direction keys — confirm.
#[deprecated(
    since = "0.7.0",
    note = "This will be a private type in the future. Open an issue if you require access to protocol intrinsics in the future"
)]
#[derive(Debug)]
pub struct TransitRxKey;
// Registers `TransitRxKey` as a key purpose; the type is deprecated, hence the allow.
#[allow(deprecated)]
impl KeyPurpose for TransitRxKey {}
69
/// Key purpose marker for the send-direction transit key.
///
/// NOTE(review): not referenced in this file; presumably used by the `crypto` submodule when
/// deriving per-direction keys — confirm.
#[deprecated(
    since = "0.7.0",
    note = "This will be a private type in the future. Open an issue if you require access to protocol intrinsics in the future"
)]
#[derive(Debug)]
pub struct TransitTxKey;
// Registers `TransitTxKey` as a key purpose; the type is deprecated, hence the allow.
#[allow(deprecated)]
impl KeyPurpose for TransitTxKey {}
79
/// Errors that can occur while establishing a transit connection.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitConnectError {
    /// Protocol error; the contained message is displayed verbatim.
    #[error("{}", _0)]
    Protocol(Box<str>),

    /// No usable connection was established. `leader_connect`/`follower_connect` return this
    /// when no handshake succeeds within the timeout, when all attempts fail, or when the final
    /// handshake step fails.
    #[error("All (relay) handshakes failed or timed out; could not establish a connection with the peer")]
    Handshake,

    /// Underlying I/O error.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),

    /// Error from `ws_stream_wasm` (WASM builds only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
109
/// Errors that can occur while exchanging records over an established [`Transit`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitError {
    /// Encryption or decryption failed. Unit errors `()` convert into this variant.
    #[error("Cryptography error. This is probably an implementation bug, but may also be caused by an attack.")]
    Crypto,

    /// A received nonce did not match; fields are (received, expected).
    #[error("Wrong nonce received, got {:x?} but expected {:x?}. This is probably an implementation bug, but may also be caused by an attack.", _0, _1)]
    Nonce(Box<[u8]>, Box<[u8]>),

    /// Underlying I/O error.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),

    /// Error from `ws_stream_wasm` (WASM builds only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
139
140impl From<()> for TransitError {
141 fn from(_: ()) -> Self {
142 Self::Crypto
143 }
144}
145
/// Connection methods a side supports, exchanged with the peer before connecting.
///
/// The `Default` value has every ability disabled.
#[derive(Copy, Clone, Debug, Default)]
pub struct Abilities {
    /// Direct TCP connections between the peers (`direct-tcp-v1`).
    pub direct_tcp_v1: bool,
    /// Connections through a transit relay (`relay-v1`).
    pub relay_v1: bool,
    // `cfg(any())` is always false: the noise field is compiled out.
    #[cfg(any())]
    pub noise_v1: bool,
}

impl Abilities {
    /// Every supported ability enabled.
    pub const ALL: Self = Self {
        direct_tcp_v1: true,
        relay_v1: true,
        #[cfg(any())]
        noise_v1: false,
    };

    #[deprecated(since = "0.7.0", note = "use Abilities::ALL")]
    pub const ALL_ABILITIES: Self = Self::ALL;

    /// Only direct TCP connections; relays are never used.
    pub const FORCE_DIRECT: Self = Self {
        direct_tcp_v1: true,
        relay_v1: false,
        #[cfg(any())]
        noise_v1: false,
    };

    /// Only relayed connections; direct connections are never attempted.
    pub const FORCE_RELAY: Self = Self {
        direct_tcp_v1: false,
        relay_v1: true,
        #[cfg(any())]
        noise_v1: false,
    };

    /// Whether direct TCP connections are allowed.
    pub fn can_direct(&self) -> bool {
        self.direct_tcp_v1
    }

    /// Whether relayed connections are allowed.
    pub fn can_relay(&self) -> bool {
        self.relay_v1
    }

    #[cfg(any())]
    pub fn can_noise_crypto(&self) -> bool {
        self.noise_v1
    }

    /// Always `false`: noise cryptography is not supported.
    #[deprecated(since = "0.7.0", note = "Noise cryptography is not standardized")]
    pub fn can_noise_crypto(&self) -> bool {
        false
    }

    /// Returns the abilities enabled in both `self` and `other`.
    pub fn intersect(self, other: &Self) -> Self {
        Self {
            direct_tcp_v1: self.direct_tcp_v1 && other.direct_tcp_v1,
            relay_v1: self.relay_v1 && other.relay_v1,
            #[cfg(any())]
            noise_v1: self.noise_v1 && other.noise_v1,
        }
    }
}
235
236impl serde::Serialize for Abilities {
237 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
238 where
239 S: serde::Serializer,
240 {
241 let mut hints = Vec::new();
242 if self.direct_tcp_v1 {
243 hints.push(serde_json::json!({
244 "type": "direct-tcp-v1",
245 }));
246 }
247 if self.relay_v1 {
248 hints.push(serde_json::json!({
249 "type": "relay-v1",
250 }));
251 }
252 #[cfg(any())]
253 if self.noise_v1 {
254 hints.push(serde_json::json!({
255 "type": "noise-crypto-v1",
256 }));
257 }
258 serde_json::Value::Array(hints).serialize(ser)
259 }
260}
261
262impl<'de> serde::Deserialize<'de> for Abilities {
263 fn deserialize<D>(de: D) -> Result<Self, D::Error>
264 where
265 D: serde::Deserializer<'de>,
266 {
267 #[derive(Deserialize)]
268 #[serde(rename_all = "kebab-case", tag = "type")]
269 enum Ability {
270 DirectTcpV1,
271 RelayV1,
272 RelayV2,
273 #[cfg(any())]
274 NoiseCryptoV1,
275 #[serde(other)]
276 Other,
277 }
278
279 let mut abilities = Self::default();
280 for ability in <Vec<Ability> as serde::Deserialize>::deserialize(de)? {
282 match ability {
283 Ability::DirectTcpV1 => {
284 abilities.direct_tcp_v1 = true;
285 },
286 Ability::RelayV1 => {
287 abilities.relay_v1 = true;
288 },
289 #[cfg(any())]
290 Ability::NoiseCryptoV1 => {
291 abilities.noise_v1 = true;
292 },
293 _ => (),
294 }
295 }
296 Ok(abilities)
297 }
298}
299
/// Wire format of one entry in a hints list, tagged by its `"type"` field
/// (`direct-tcp-v1` or `relay-v1`).
///
/// Entries with any other `type` deserialize to `Unknown`; the `Hints` deserializer skips them.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum HintSerde {
    DirectTcpV1(DirectHint),
    RelayV1(RelayHint),
    #[serde(other)]
    Unknown,
}
310
311#[derive(Clone, Debug, Default)]
313pub struct Hints {
314 pub direct_tcp: HashSet<DirectHint>,
316 pub relay: Vec<RelayHint>,
318}
319
320impl Hints {
321 pub fn new(
323 direct_tcp: impl IntoIterator<Item = DirectHint>,
324 relay: impl IntoIterator<Item = RelayHint>,
325 ) -> Self {
326 Self {
327 direct_tcp: direct_tcp.into_iter().collect(),
328 relay: relay.into_iter().collect(),
329 }
330 }
331}
332
333impl<'de> serde::Deserialize<'de> for Hints {
334 fn deserialize<D>(de: D) -> Result<Self, D::Error>
335 where
336 D: serde::Deserializer<'de>,
337 {
338 let hints: Vec<HintSerde> = serde::Deserialize::deserialize(de)?;
339 let mut direct_tcp = HashSet::new();
340 let mut relay = Vec::<RelayHint>::new();
341 let mut relay_v2 = Vec::<RelayHint>::new();
342
343 for hint in hints {
344 match hint {
345 HintSerde::DirectTcpV1(hint) => {
346 direct_tcp.insert(hint);
347 },
348 HintSerde::RelayV1(hint) => {
349 relay_v2.push(hint);
350 },
351 _ => {},
353 }
354 }
355
356 if !relay_v2.is_empty() {
358 relay.clear();
359 }
360 relay.extend(relay_v2.into_iter().map(Into::into));
361
362 Ok(Hints { direct_tcp, relay })
363 }
364}
365
366impl serde::Serialize for Hints {
367 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
368 where
369 S: serde::Serializer,
370 {
371 let direct = self.direct_tcp.iter().cloned().map(HintSerde::DirectTcpV1);
372 let relay = self.relay.iter().cloned().map(HintSerde::RelayV1);
373 ser.collect_seq(direct.chain(relay))
374 }
375}
376
377#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)]
379#[display("tcp://{}:{}", hostname, port)]
380pub struct DirectHint {
381 pub hostname: String,
386 pub port: u16,
388}
389
390impl DirectHint {
391 pub fn new(hostname: impl Into<String>, port: u16) -> Self {
393 Self {
394 hostname: hostname.into(),
395 port,
396 }
397 }
398}
399
/// Wire format of a relay hint: an optional `name` and its endpoints under the `"hints"` key.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
struct RelayHintSerde {
    name: Option<String>,
    #[serde(rename = "hints")]
    endpoints: Vec<RelayHintSerdeInner>,
}
409
/// Wire format of one relay endpoint, tagged by `"type"`.
///
/// A `direct-tcp-v1` entry carries a hostname and port; a `websocket` entry carries a `url`.
/// Any other type deserializes to `Unknown`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum RelayHintSerdeInner {
    #[serde(rename = "direct-tcp-v1")]
    Tcp(DirectHint),
    Websocket {
        url: url::Url,
    },
    #[serde(other)]
    Unknown,
}
423
/// Errors returned by [`RelayHint::from_urls`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RelayHintParseError {
    /// A `tcp://` URL lacks a host or an explicit port.
    #[error(
        "Invalid TCP hint endpoint: '{}' (Does it have hostname and port?)",
        _0
    )]
    InvalidTcp(url::Url),
    /// The URL scheme is not `tcp`, `ws` or `wss`.
    #[error(
        "Unknown schema: '{}'. Currently known values are 'tcp', 'ws' and 'wss'.",
        _0
    )]
    UnknownSchema(Box<str>),
    /// The URL is "cannot-be-a-base" (`Url::cannot_be_a_base`), i.e. nothing after the scheme
    /// starts with a `/`, as in `tcp:host:4001`.
    #[error("'{}' is not an absolute URL (must start with a '/')", _0)]
    UrlNotAbsolute(url::Url),
}
444
/// A transit relay server, reachable through one or more TCP and/or WebSocket endpoints.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct RelayHint {
    /// Optional human-readable label for the relay.
    pub name: Option<String>,
    /// TCP endpoints of the relay.
    pub tcp: HashSet<DirectHint>,
    /// WebSocket endpoint URLs (`from_urls` only puts `ws`/`wss` URLs here).
    pub ws: HashSet<url::Url>,
}
466
467impl RelayHint {
468 pub fn new(
470 name: Option<String>,
471 tcp: impl IntoIterator<Item = DirectHint>,
472 ws: impl IntoIterator<Item = url::Url>,
473 ) -> Self {
474 Self {
475 name,
476 tcp: tcp.into_iter().collect(),
477 ws: ws.into_iter().collect(),
478 }
479 }
480
481 pub fn from_urls(
503 name: Option<String>,
504 urls: impl IntoIterator<Item = url::Url>,
505 ) -> Result<Self, RelayHintParseError> {
506 let mut this = Self {
507 name,
508 ..Self::default()
509 };
510 for url in urls.into_iter() {
511 ensure!(
512 !url.cannot_be_a_base(),
513 RelayHintParseError::UrlNotAbsolute(url)
514 );
515 match url.scheme() {
516 "tcp" => {
517 let (hostname, port) = match (url.host_str(), url.port()) {
519 (Some(hostname), Some(port)) => (hostname.into(), port),
520 _ => bail!(RelayHintParseError::InvalidTcp(url)),
521 };
522 this.tcp.insert(DirectHint { hostname, port });
523 },
524 "ws" | "wss" => {
525 this.ws.insert(url);
526 },
527 other => bail!(RelayHintParseError::UnknownSchema(other.into())),
528 }
529 }
530 assert!(
531 !this.tcp.is_empty() || !this.ws.is_empty(),
532 "No URLs provided"
533 );
534 Ok(this)
535 }
536
537 #[deprecated(
538 since = "0.7.0",
539 note = "This will be a private method in the future. Open an issue if you require access to protocol intrinsics in the future"
540 )]
541 pub fn can_merge(&self, other: &Self) -> bool {
543 !self.tcp.is_disjoint(&other.tcp) || !self.ws.is_disjoint(&other.ws)
544 }
545
546 #[deprecated(
547 since = "0.7.0",
548 note = "This will be a private method in the future. Open an issue if you require access to protocol intrinsics in the future"
549 )]
550 pub fn merge(mut self, other: Self) -> Self {
552 #[allow(deprecated)]
553 self.merge_mut(other);
554 self
555 }
556
557 #[deprecated(
558 since = "0.7.0",
559 note = "This will be a private method in the future. Open an issue if you require access to protocol intrinsics in the future"
560 )]
561 pub fn merge_mut(&mut self, other: Self) {
563 self.tcp.extend(other.tcp);
564 self.ws.extend(other.ws);
565 }
566
567 #[deprecated(
568 since = "0.7.0",
569 note = "This will be a private method in the future. Open an issue if you require access to protocol intrinsics in the future"
570 )]
571 #[allow(deprecated)]
572 pub fn merge_into(self, collection: &mut Vec<RelayHint>) {
574 for item in collection.iter_mut() {
575 if item.can_merge(&self) {
576 item.merge_mut(self);
577 return;
578 }
579 }
580 collection.push(self);
581 }
582}
583
584impl serde::Serialize for RelayHint {
585 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
586 where
587 S: serde::Serializer,
588 {
589 let mut hints = Vec::new();
590 hints.extend(self.tcp.iter().cloned().map(RelayHintSerdeInner::Tcp));
591 hints.extend(
592 self.ws
593 .iter()
594 .cloned()
595 .map(|h| RelayHintSerdeInner::Websocket { url: h }),
596 );
597
598 serde_json::json!({
599 "name": self.name,
600 "hints": hints,
601 })
602 .serialize(ser)
603 }
604}
605
606impl<'de> serde::Deserialize<'de> for RelayHint {
607 fn deserialize<D>(de: D) -> Result<Self, D::Error>
608 where
609 D: serde::Deserializer<'de>,
610 {
611 let raw = RelayHintSerde::deserialize(de)?;
612 let mut hint = RelayHint {
613 name: raw.name,
614 tcp: HashSet::new(),
615 ws: HashSet::new(),
616 };
617
618 for e in raw.endpoints {
619 match e {
620 RelayHintSerdeInner::Tcp(tcp) => {
621 hint.tcp.insert(tcp);
622 },
623 RelayHintSerdeInner::Websocket { url } => {
624 hint.ws.insert(url);
625 },
626 _ => {},
628 }
629 }
630
631 Ok(hint)
632 }
633}
634
635impl TryFrom<&DirectHint> for IpAddr {
636 type Error = std::net::AddrParseError;
637 fn try_from(hint: &DirectHint) -> Result<IpAddr, std::net::AddrParseError> {
638 hint.hostname.parse()
639 }
640}
641
642impl TryFrom<&DirectHint> for SocketAddr {
643 type Error = std::net::AddrParseError;
644 fn try_from(hint: &DirectHint) -> Result<SocketAddr, std::net::AddrParseError> {
646 let addr = hint.try_into()?;
647 let addr = match addr {
648 IpAddr::V4(v4) => IpAddr::V6(v4.to_ipv6_mapped()),
649 IpAddr::V6(_) => addr,
650 };
651 Ok(SocketAddr::new(addr, hint.port))
652 }
653}
654
/// How a transit connection reaches the peer.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ConnectionType {
    /// Direct TCP connection between the two sides.
    Direct,
    /// Connection through a transit relay.
    Relay {
        /// Human-readable name of the relay, if known.
        name: Option<String>,
    },
}
667
/// Information about an established transit connection.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct TransitInfo {
    /// Whether the connection is direct or relayed.
    pub conn_type: ConnectionType,
    /// Address of the remote end of the socket.
    /// NOTE(review): for relayed connections this is presumably the relay's address — confirm.
    #[cfg(not(target_family = "wasm"))]
    pub peer_addr: SocketAddr,
}

/// A connected transport together with information on how it was established.
type TransitConnection = (Box<dyn TransitTransport>, TransitInfo);
681
/// Errors from discovering our external address via STUN. In `init` these are logged and
/// answered by falling back to a fresh socket.
#[cfg(not(target_family = "wasm"))]
#[derive(Debug, thiserror::Error)]
enum StunError {
    /// The STUN server resolved only to IPv6 addresses.
    #[error("No IPv4 addresses were found for the selected STUN server")]
    ServerIsV6Only,
    /// The STUN response did not contain our address.
    #[error("Server did not tell us our IP address")]
    ServerNoResponse,
    /// The lookup did not finish in time (`init` applies a 4 second limit).
    #[error("Connection timed out")]
    Timeout,
    /// Underlying I/O error.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// The STUN packet could not be encoded or decoded.
    #[error("Malformed STUN packet")]
    Codec(
        #[from]
        #[source]
        bytecodec::Error,
    ),
}
704
705#[cfg(not(target_family = "wasm"))]
706impl std::fmt::Display for TransitInfo {
707 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
708 match &self.conn_type {
709 ConnectionType::Direct => {
710 write!(
711 f,
712 "Established direct transit connection to '{}'",
713 self.peer_addr,
714 )
715 },
716 ConnectionType::Relay { name: Some(name) } => {
717 write!(
718 f,
719 "Established transit connection via relay '{}' ({})",
720 name, self.peer_addr,
721 )
722 },
723 ConnectionType::Relay { name: None } => {
724 write!(
725 f,
726 "Established transit connection via relay ({})",
727 self.peer_addr,
728 )
729 },
730 }
731 }
732}
733
#[cfg(target_family = "wasm")]
impl std::fmt::Display for TransitInfo {
    /// One-line, human-readable summary of the connection (no peer address on WASM).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.conn_type {
            ConnectionType::Direct => f.write_str("Established direct transit connection"),
            ConnectionType::Relay {
                name: Some(ref relay),
            } => write!(f, "Established transit connection via relay '{}'", relay),
            ConnectionType::Relay { name: None } => {
                f.write_str("Established transit connection via relay")
            },
        }
    }
}
750
751#[deprecated(
752 since = "0.7.0",
753 note = "use the `Display` implementation of `TransitInfo` instead"
754)]
755pub fn log_transit_connection(
757 conn_type: ConnectionType,
758 #[cfg(not(target_family = "wasm"))] peer_addr: SocketAddr,
759) {
760 let info = TransitInfo {
761 conn_type,
762 #[cfg(not(target_family = "wasm"))]
763 peer_addr,
764 };
765
766 tracing::info!("{info}");
767}
768
769#[deprecated(
775 since = "0.7.0",
776 note = "This will be a private type in the future. Open an issue if you require access to protocol intrinsics in the future"
777)]
778#[allow(deprecated)]
779pub async fn init(
780 mut abilities: Abilities,
781 peer_abilities: Option<Abilities>,
782 relay_hints: Vec<RelayHint>,
783) -> Result<TransitConnector, std::io::Error> {
784 let mut our_hints = Hints::default();
785 #[cfg(not(target_family = "wasm"))]
786 let mut sockets = None;
787
788 if let Some(peer_abilities) = peer_abilities {
789 abilities = abilities.intersect(&peer_abilities);
790 }
791
792 #[cfg(not(target_family = "wasm"))]
794 if abilities.can_direct() {
795 let create_sockets = async {
796 let socket: MaybeConnectedSocket = match async_std::future::timeout(
801 std::time::Duration::from_secs(4),
802 transport::tcp_get_external_ip(),
803 )
804 .await
805 .map_err(|_| StunError::Timeout)
806 {
807 Ok(Ok((external_ip, stream))) => {
808 tracing::debug!("Our external IP address is {}", external_ip);
809 our_hints.direct_tcp.insert(DirectHint {
810 hostname: external_ip.ip().to_string(),
811 port: external_ip.port(),
812 });
813 tracing::debug!(
814 "Our socket for connecting is bound to {} and connected to {}",
815 stream.local_addr()?,
816 stream.peer_addr()?,
817 );
818 stream.into()
819 },
820 Err(err) | Ok(Err(err)) => {
823 tracing::warn!("Failed to get external address via STUN, {}", err);
824 let socket =
825 socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
826 transport::set_socket_opts(&socket)?;
827
828 socket.bind(&"[::]:0".parse::<SocketAddr>().unwrap().into())?;
829 tracing::debug!(
830 "Our socket for connecting is bound to {}",
831 socket.local_addr()?.as_socket().unwrap(),
832 );
833
834 socket.into()
835 },
836 };
837
838 let listener = TcpListener::bind("[::]:0").await?;
845
846 let port = socket.local_addr()?.as_socket().unwrap().port();
848 let port2 = listener.local_addr()?.port();
849 our_hints.direct_tcp.extend(
850 if_addrs::get_if_addrs()?
851 .iter()
852 .filter(|iface| !iface.is_loopback())
853 .flat_map(|ip| {
854 [
855 DirectHint {
856 hostname: ip.ip().to_string(),
857 port,
858 },
859 DirectHint {
860 hostname: ip.ip().to_string(),
861 port: port2,
862 },
863 ]
864 .into_iter()
865 }),
866 );
867 tracing::debug!("Our socket for listening is {}", listener.local_addr()?);
868
869 Ok::<_, std::io::Error>((socket, listener))
870 };
871
872 sockets = create_sockets
873 .await
874 .map_err(|err| {
876 tracing::error!("Failed to create direct hints for our side: {}", err);
877 err
878 })
879 .ok();
880 }
881
882 if abilities.can_relay() {
883 our_hints.relay.extend(relay_hints);
884 }
885
886 Ok(TransitConnector {
887 #[cfg(not(target_family = "wasm"))]
888 sockets,
889 our_abilities: abilities,
890 our_hints: Arc::new(our_hints),
891 })
892}
893
894#[cfg(not(target_family = "wasm"))]
896#[derive(derive_more::From)]
897enum MaybeConnectedSocket {
898 #[from]
899 Socket(socket2::Socket),
900 #[from]
901 Stream(TcpStream),
902}
903
904#[cfg(not(target_family = "wasm"))]
905impl MaybeConnectedSocket {
906 fn local_addr(&self) -> std::io::Result<socket2::SockAddr> {
907 match &self {
908 Self::Socket(socket) => socket.local_addr(),
909 Self::Stream(stream) => Ok(stream.local_addr()?.into()),
910 }
911 }
912}
913
/// Prepared state for connecting to the peer, created by [`init`] and consumed by
/// [`TransitConnector::connect`].
#[deprecated(
    since = "0.7.0",
    note = "This will be a private type in the future. Open an issue if you require access to protocol intrinsics in the future"
)]
pub struct TransitConnector {
    /// Socket whose local address is reused for outgoing direct connections, plus the listener
    /// for incoming ones. `None` if direct connections are disabled or could not be set up.
    #[cfg(not(target_family = "wasm"))]
    sockets: Option<(MaybeConnectedSocket, TcpListener)>,
    /// Our abilities, already intersected with the peer's if those were known in `init`.
    our_abilities: Abilities,
    /// Hints we advertise to the peer.
    our_hints: Arc<Hints>,
}
935
936#[allow(deprecated)]
937impl TransitConnector {
938 pub fn our_abilities(&self) -> &Abilities {
940 &self.our_abilities
941 }
942
943 pub fn our_hints(&self) -> &Arc<Hints> {
945 &self.our_hints
946 }
947
948 pub async fn connect(
956 self,
957 is_leader: bool,
958 transit_key: Key<TransitKey>,
959 their_abilities: Abilities,
960 their_hints: Arc<Hints>,
961 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
962 if is_leader {
963 self.leader_connect(transit_key, their_abilities, their_hints)
964 .await
965 } else {
966 self.follower_connect(transit_key, their_abilities, their_hints)
967 .await
968 }
969 }
970
971 pub async fn leader_connect(
975 self,
976 transit_key: Key<TransitKey>,
977 their_abilities: Abilities,
978 their_hints: Arc<Hints>,
979 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
980 let Self {
981 #[cfg(not(target_family = "wasm"))]
982 sockets,
983 our_abilities,
984 our_hints,
985 } = self;
986 let transit_key = Arc::new(transit_key);
987
988 let start = Instant::now();
989 let mut connection_stream = Box::pin(
990 Self::connect_inner(
991 true,
992 transit_key,
993 our_abilities,
994 our_hints,
995 their_abilities,
996 their_hints,
997 #[cfg(not(target_family = "wasm"))]
998 sockets,
999 )
1000 .filter_map(|result| async {
1001 match result {
1002 Ok(val) => Some(val),
1003 Err(err) => {
1004 tracing::debug!("Some leader handshake failed: {:?}", err);
1005 None
1006 },
1007 }
1008 }),
1009 );
1010
1011 let (mut transit, mut finalizer, mut conn_info) = async_std::future::timeout(
1012 std::time::Duration::from_secs(60),
1013 connection_stream.next(),
1014 )
1015 .await
1016 .map_err(|_| {
1017 tracing::debug!("`leader_connect` timed out");
1018 TransitConnectError::Handshake
1019 })?
1020 .ok_or(TransitConnectError::Handshake)?;
1021
1022 if conn_info.conn_type != ConnectionType::Direct && our_abilities.can_direct() {
1023 tracing::debug!(
1024 "Established transit connection over relay. Trying to find a direct connection …"
1025 );
1026 let elapsed = start.elapsed();
1030 let to_wait = if elapsed.as_secs() > 5 {
1031 std::time::Duration::from_secs(1)
1033 } else {
1034 elapsed.mul_f32(0.3)
1035 };
1036 let _ = async_std::future::timeout(to_wait, async {
1037 while let Some((new_transit, new_finalizer, new_conn_info)) =
1038 connection_stream.next().await
1039 {
1040 if new_conn_info.conn_type == ConnectionType::Direct {
1042 transit = new_transit;
1043 finalizer = new_finalizer;
1044 conn_info = new_conn_info;
1045 tracing::debug!("Found direct connection; using that instead.");
1046 break;
1047 }
1048 }
1049 })
1050 .await;
1051 tracing::debug!("Did not manage to establish a better connection in time.");
1052 } else {
1053 tracing::debug!("Established direct transit connection");
1054 }
1055
1056 std::mem::drop(connection_stream);
1060
1061 let (tx, rx) = finalizer
1062 .handshake_finalize(&mut transit)
1063 .await
1064 .map_err(|e| {
1065 tracing::debug!("`handshake_finalize` failed: {e}");
1066 TransitConnectError::Handshake
1067 })?;
1068
1069 Ok((
1070 Transit {
1071 socket: transit,
1072 tx,
1073 rx,
1074 },
1075 conn_info,
1076 ))
1077 }
1078
1079 pub async fn follower_connect(
1083 self,
1084 transit_key: Key<TransitKey>,
1085 their_abilities: Abilities,
1086 their_hints: Arc<Hints>,
1087 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
1088 let Self {
1089 #[cfg(not(target_family = "wasm"))]
1090 sockets,
1091 our_abilities,
1092 our_hints,
1093 } = self;
1094 let transit_key = Arc::new(transit_key);
1095
1096 let mut connection_stream = Box::pin(
1097 Self::connect_inner(
1098 false,
1099 transit_key,
1100 our_abilities,
1101 our_hints,
1102 their_abilities,
1103 their_hints,
1104 #[cfg(not(target_family = "wasm"))]
1105 sockets,
1106 )
1107 .filter_map(|result| async {
1108 match result {
1109 Ok(val) => Some(val),
1110 Err(err) => {
1111 tracing::debug!("Some follower handshake failed: {:?}", err);
1112 None
1113 },
1114 }
1115 }),
1116 );
1117
1118 let transit = match async_std::future::timeout(
1119 std::time::Duration::from_secs(60),
1120 &mut connection_stream.next(),
1121 )
1122 .await
1123 {
1124 Ok(Some((mut socket, finalizer, conn_info))) => {
1125 let (tx, rx) = finalizer
1126 .handshake_finalize(&mut socket)
1127 .await
1128 .map_err(|e| {
1129 tracing::debug!("`handshake_finalize` failed: {e}");
1130 TransitConnectError::Handshake
1131 })?;
1132
1133 Ok((Transit { socket, tx, rx }, conn_info))
1134 },
1135 Ok(None) | Err(_) => {
1136 tracing::debug!("`follower_connect` timed out");
1137 Err(TransitConnectError::Handshake)
1138 },
1139 };
1140
1141 std::mem::drop(connection_stream);
1145
1146 transit
1147 }
1148
1149 fn connect_inner(
1159 is_leader: bool,
1160 transit_key: Arc<Key<TransitKey>>,
1161 our_abilities: Abilities,
1162 our_hints: Arc<Hints>,
1163 their_abilities: Abilities,
1164 their_hints: Arc<Hints>,
1165 #[cfg(not(target_family = "wasm"))] sockets: Option<(MaybeConnectedSocket, TcpListener)>,
1166 ) -> impl Stream<Item = Result<HandshakeResult, TransitHandshakeError>> + 'static {
1167 #[cfg(not(target_family = "wasm"))]
1169 assert!(sockets.is_none() || our_abilities.can_direct());
1170
1171 let cryptor = if our_abilities.can_noise_crypto() && their_abilities.can_noise_crypto() {
1172 tracing::debug!("Using noise protocol for encryption");
1173 Arc::new(crypto::NoiseInit {
1174 key: transit_key.clone(),
1175 }) as Arc<dyn crypto::TransitCryptoInit>
1176 } else {
1177 tracing::debug!("Using secretbox for encryption");
1178 Arc::new(crypto::SecretboxInit {
1179 key: transit_key.clone(),
1180 }) as Arc<dyn crypto::TransitCryptoInit>
1181 };
1182
1183 let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>()));
1185
1186 #[cfg(not(target_family = "wasm"))]
1190 use futures::future::BoxFuture;
1191 #[cfg(target_family = "wasm")]
1192 use futures::future::LocalBoxFuture as BoxFuture;
1193 type BoxIterator<T> = Box<dyn Iterator<Item = T>>;
1194 type ConnectorFuture = BoxFuture<'static, Result<TransitConnection, TransitHandshakeError>>;
1195 let mut connectors: BoxIterator<ConnectorFuture> = Box::new(std::iter::empty());
1196
1197 #[cfg(not(target_family = "wasm"))]
1198 let (socket, listener) = sockets.unzip();
1199 #[cfg(not(target_family = "wasm"))]
1200 if our_abilities.can_direct() && their_abilities.can_direct() {
1201 let local_addr = socket.map(|socket| {
1202 Arc::new(
1203 socket
1204 .local_addr()
1205 .expect("This is guaranteed to be an IP socket"),
1206 )
1207 });
1208 connectors = Box::new(
1210 connectors.chain(
1211 their_hints
1212 .direct_tcp
1213 .clone()
1214 .into_iter()
1215 .take(50)
1217 .map(move |hint| transport::connect_tcp_direct(local_addr.clone(), hint))
1218 .map(|fut| Box::pin(fut) as ConnectorFuture),
1219 ),
1220 ) as BoxIterator<ConnectorFuture>;
1221 }
1222
1223 if our_abilities.can_relay() && their_abilities.can_relay() {
1225 let mut relay_hints = Vec::<RelayHint>::new();
1227 relay_hints.extend(our_hints.relay.iter().take(2).cloned());
1228 for hint in their_hints.relay.iter().take(2).cloned() {
1229 hint.merge_into(&mut relay_hints);
1230 }
1231
1232 #[cfg(not(target_family = "wasm"))]
1233 {
1234 connectors = Box::new(
1235 connectors.chain(
1236 relay_hints
1237 .into_iter()
1238 .flat_map(|hint| {
1247 let name = hint.name
1249 .or_else(|| {
1250 hint.tcp.iter()
1252 .filter_map(|hint| match url::Host::parse(&hint.hostname) {
1253 Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
1254 _ => None,
1255 })
1256 .next()
1257 });
1258 hint.tcp
1259 .into_iter()
1260 .take(3)
1261 .enumerate()
1262 .map(move |(i, h)| (i, h, name.clone()))
1263 })
1264 .map(|(index, host, name)| async move {
1265 async_std::task::sleep(std::time::Duration::from_secs(
1266 index as u64 * 5,
1267 ))
1268 .await;
1269 transport::connect_tcp_relay(host, name).await
1270 })
1271 .map(|fut| Box::pin(fut) as ConnectorFuture),
1272 ),
1273 ) as BoxIterator<ConnectorFuture>;
1274 }
1275
1276 #[cfg(target_family = "wasm")]
1277 {
1278 connectors = Box::new(
1279 connectors.chain(
1280 relay_hints
1281 .into_iter()
1282 .flat_map(|hint| {
1291 let name = hint.name
1293 .or_else(|| {
1294 hint.tcp.iter()
1296 .filter_map(|hint| match url::Host::parse(&hint.hostname) {
1297 Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
1298 _ => None,
1299 })
1300 .next()
1301 });
1302 hint.ws
1303 .into_iter()
1304 .take(3)
1305 .enumerate()
1306 .map(move |(i, u)| (i, u, name.clone()))
1307 })
1308 .map(|(index, url, name)| async move {
1309 async_std::task::sleep(std::time::Duration::from_secs(
1310 index as u64 * 5,
1311 ))
1312 .await;
1313 transport::connect_ws_relay(url, name).await
1314 })
1315 .map(|fut| Box::pin(fut) as ConnectorFuture),
1316 ),
1317 ) as BoxIterator<ConnectorFuture>;
1318 }
1319 }
1320
1321 let transit_key2 = transit_key.clone();
1323 let tside2 = tside.clone();
1324 let cryptor2 = cryptor.clone();
1325 #[allow(unused_mut)] let mut connectors = Box::new(
1327 connectors
1328 .map(move |fut| {
1329 let transit_key = transit_key2.clone();
1330 let tside = tside2.clone();
1331 let cryptor = cryptor2.clone();
1332 async move {
1333 let (socket, conn_info) = fut.await?;
1334 let (transit, finalizer) = handshake_exchange(
1335 is_leader,
1336 tside,
1337 socket,
1338 &conn_info.conn_type,
1339 &*cryptor,
1340 transit_key,
1341 )
1342 .await?;
1343 Ok((transit, finalizer, conn_info))
1344 }
1345 })
1346 .map(|fut| {
1347 Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
1348 }),
1349 )
1350 as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;
1351
1352 #[cfg(not(target_family = "wasm"))]
1354 if let Some(listener) = listener {
1355 connectors = Box::new(
1356 connectors.chain(
1357 std::iter::once(async move {
1358 let transit_key = transit_key.clone();
1359 let tside = tside.clone();
1360 let cryptor = cryptor.clone();
1361 let connect = || async {
1362 let (socket, peer) = listener.accept().await?;
1363 let (socket, info) =
1364 transport::wrap_tcp_connection(socket, ConnectionType::Direct)?;
1365 tracing::debug!("Got connection from {}!", peer);
1366 let (transit, finalizer) = handshake_exchange(
1367 is_leader,
1368 tside.clone(),
1369 socket,
1370 &ConnectionType::Direct,
1371 &*cryptor,
1372 transit_key.clone(),
1373 )
1374 .await?;
1375 Result::<_, TransitHandshakeError>::Ok((transit, finalizer, info))
1376 };
1377 loop {
1378 match connect().await {
1379 Ok(success) => break Ok(success),
1380 Err(err) => {
1381 tracing::debug!(
1382 "Some handshake failed on the listening port: {:?}",
1383 err
1384 );
1385 continue;
1386 },
1387 }
1388 }
1389 })
1390 .map(|fut| {
1391 Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
1392 }),
1393 ),
1394 )
1395 as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;
1396 }
1397 connectors.collect::<futures::stream::futures_unordered::FuturesUnordered<_>>()
1398 }
1399}
1400
1401pub struct Transit {
1408 socket: Box<dyn TransitTransport>,
1410 tx: Box<dyn crypto::TransitCryptoEncrypt>,
1411 rx: Box<dyn crypto::TransitCryptoDecrypt>,
1412}
1413
1414impl Transit {
1415 pub async fn receive_record(&mut self) -> Result<Box<[u8]>, TransitError> {
1417 self.rx.decrypt(&mut self.socket).await
1418 }
1419
1420 pub async fn send_record(&mut self, plaintext: &[u8]) -> Result<(), TransitError> {
1422 assert!(!plaintext.is_empty());
1423 self.tx.encrypt(&mut self.socket, plaintext).await
1424 }
1425
1426 pub async fn flush(&mut self) -> Result<(), TransitError> {
1428 tracing::debug!("Flush");
1429 self.socket.flush().await.map_err(Into::into)
1430 }
1431
1432 #[cfg(not(target_family = "wasm"))]
1434 pub fn split(
1435 self,
1436 ) -> (
1437 impl futures::sink::Sink<Box<[u8]>, Error = TransitError>,
1438 impl futures::stream::Stream<Item = Result<Box<[u8]>, TransitError>>,
1439 ) {
1440 let (reader, writer) = self.socket.split();
1441 (
1442 futures::sink::unfold(
1443 (writer, self.tx),
1444 |(mut writer, mut tx), plaintext: Box<[u8]>| async move {
1445 tx.encrypt(&mut writer, &plaintext)
1446 .await
1447 .map(|()| (writer, tx))
1448 },
1449 ),
1450 futures::stream::try_unfold((reader, self.rx), |(mut reader, mut rx)| async move {
1451 rx.decrypt(&mut reader)
1452 .await
1453 .map(|record| Some((record, (reader, rx))))
1454 }),
1455 )
1456 }
1457}
1458
/// Outcome of a successful initial handshake, as a tuple:
/// - the transport;
/// - the finalizer whose `handshake_finalize` completes the handshake on that transport;
/// - information about how the connection was established.
type HandshakeResult = (
    Box<dyn TransitTransport>,
    Box<dyn crypto::TransitCryptoInitFinalizer>,
    TransitInfo,
);
1464
1465#[allow(deprecated)]
1475async fn handshake_exchange(
1476 is_leader: bool,
1477 tside: Arc<String>,
1478 mut socket: Box<dyn TransitTransport>,
1479 host_type: &ConnectionType,
1480 cryptor: &dyn crypto::TransitCryptoInit,
1481 key: Arc<Key<TransitKey>>,
1482) -> Result<
1483 (
1484 Box<dyn TransitTransport>,
1485 Box<dyn crypto::TransitCryptoInitFinalizer>,
1486 ),
1487 TransitHandshakeError,
1488> {
1489 if host_type != &ConnectionType::Direct {
1490 tracing::trace!("initiating relay handshake");
1491
1492 let sub_key = key.derive_subkey_from_purpose::<crate::GenericKey>("transit_relay_token");
1493 socket
1494 .write_all(format!("please relay {} for side {}\n", sub_key.to_hex(), tside).as_bytes())
1495 .await?;
1496 let mut rx = [0u8; 3];
1497 socket.read_exact(&mut rx).await?;
1498 let ok_msg: [u8; 3] = *b"ok\n";
1499 ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed);
1500 }
1501
1502 let finalizer = if is_leader {
1503 cryptor.handshake_leader(&mut socket).await?
1504 } else {
1505 cryptor.handshake_follower(&mut socket).await?
1506 };
1507
1508 Ok((socket, finalizer))
1509}
1510
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// Each enabled ability serializes to one `{"type": ...}` entry, in a fixed order.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_abilities_encoding() {
        let encode = |abilities: Abilities| serde_json::to_value(abilities).unwrap();
        assert_eq!(
            encode(Abilities::ALL),
            json!([{"type": "direct-tcp-v1"}, {"type": "relay-v1"}])
        );
        assert_eq!(
            encode(Abilities::FORCE_DIRECT),
            json!([{"type": "direct-tcp-v1"}])
        );
    }

    /// Direct hints come first, followed by relay hints; a relay lists its TCP endpoints
    /// before its WebSocket endpoints.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_hints_encoding() {
        let relay = RelayHint::new(
            Some("default".into()),
            [DirectHint::new("transit.magic-wormhole.io", 4001)],
            ["ws://transit.magic-wormhole.io/relay".parse().unwrap()],
        );
        let hints = Hints::new(
            [DirectHint {
                hostname: "localhost".into(),
                port: 1234,
            }],
            [relay],
        );

        let expected = json!([
            {
                "type": "direct-tcp-v1",
                "hostname": "localhost",
                "port": 1234
            },
            {
                "type": "relay-v1",
                "name": "default",
                "hints": [
                    {
                        "type": "direct-tcp-v1",
                        "hostname": "transit.magic-wormhole.io",
                        "port": 4001,
                    },
                    {
                        "type": "websocket",
                        "url": "ws://transit.magic-wormhole.io/relay",
                    },
                ]
            }
        ]);

        assert_eq!(serde_json::to_value(hints).unwrap(), expected);
    }
}