1use crate::{Key, KeyPurpose, core::key::GenericKey};
17use serde_derive::{Deserialize, Serialize};
18
19#[cfg(not(target_family = "wasm"))]
20use async_net::{TcpListener, TcpStream};
21#[allow(unused_imports)] use futures::{
23 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
24 future::FutureExt,
25 future::TryFutureExt,
26 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
27};
28use std::{
29 collections::HashSet,
30 net::{IpAddr, SocketAddr},
31 sync::Arc,
32 time::Instant,
33};
34
35mod crypto;
36mod transport;
37use crypto::TransitHandshakeError;
38use transport::{TransitTransport, TransitTransportRx, TransitTransportTx};
39
/// Default transit relay server, as a `tcp://host:port` URL.
pub const DEFAULT_RELAY_SERVER: &str = "tcp://transit.magic-wormhole.io:4001";

/// Public STUN server (`host:port`). NOTE(review): presumably consumed by
/// `transport::tcp_get_external_ip` to discover our external address — confirm.
#[cfg(not(target_family = "wasm"))]
const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478";
47
/// Key-purpose marker for the transit key that both sides share.
#[derive(Debug)]
pub struct TransitKey;

impl KeyPurpose for TransitKey {}

/// Key-purpose marker for the receive-direction key (presumably derived from
/// [`TransitKey`] — confirm in `crypto`).
#[derive(Debug)]
pub(crate) struct TransitRxKey;

impl KeyPurpose for TransitRxKey {}

/// Key-purpose marker for the transmit-direction key (presumably derived from
/// [`TransitKey`] — confirm in `crypto`).
#[derive(Debug)]
pub(crate) struct TransitTxKey;

impl KeyPurpose for TransitTxKey {}
63
/// Errors that can occur while establishing a transit connection.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitConnectError {
    /// A protocol-level problem, described by the contained message.
    #[error("{}", _0)]
    Protocol(Box<str>),

    /// No connection could be established (every attempt failed or the
    /// overall timeout expired), or the chosen connection could not be finalized.
    #[error(
        "All (relay) handshakes failed or timed out; could not establish a connection with the peer"
    )]
    Handshake,

    /// An underlying I/O error.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),

    /// An error from the WebSocket layer (WASM targets only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
95
/// Errors that can occur on an established transit connection.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitError {
    /// A cryptographic operation failed.
    #[error(
        "Cryptography error. This is probably an implementation bug, but may also be caused by an attack."
    )]
    Crypto,

    /// A record carried an unexpected nonce: `(received, expected)`.
    #[error(
        "Wrong nonce received, got {:x?} but expected {:x?}. This is probably an implementation bug, but may also be caused by an attack.",
        _0,
        _1
    )]
    Nonce(Box<[u8]>, Box<[u8]>),

    /// An underlying I/O error.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),

    /// An error from the WebSocket layer (WASM targets only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
131
132impl From<()> for TransitError {
133 fn from(_: ()) -> Self {
134 Self::Crypto
135 }
136}
137
/// The set of transit abilities one side supports or is willing to use.
///
/// [`Default`] yields an empty set (nothing enabled).
#[derive(Copy, Clone, Debug, Default)]
pub struct Abilities {
    /// Direct TCP connections between the peers (`direct-tcp-v1`).
    pub direct_tcp_v1: bool,
    /// Connections through a relay server (`relay-v1`).
    pub relay_v1: bool,
    /// Noise-protocol encryption (`noise-crypto-v1`); currently compiled out.
    #[cfg(any())]
    pub noise_v1: bool,
}

impl Abilities {
    /// Direct and relay connections.
    pub const ALL: Self = Self {
        direct_tcp_v1: true,
        relay_v1: true,
        #[cfg(any())]
        noise_v1: false,
    };

    /// Only direct connections; never use a relay.
    pub const FORCE_DIRECT: Self = Self {
        direct_tcp_v1: true,
        relay_v1: false,
        #[cfg(any())]
        noise_v1: false,
    };

    /// Only relay connections; never connect directly.
    pub const FORCE_RELAY: Self = Self {
        direct_tcp_v1: false,
        relay_v1: true,
        #[cfg(any())]
        noise_v1: false,
    };

    /// Whether direct TCP connections are enabled.
    pub fn can_direct(&self) -> bool {
        self.direct_tcp_v1
    }

    /// Whether relay connections are enabled.
    pub fn can_relay(&self) -> bool {
        self.relay_v1
    }

    /// Whether noise-protocol encryption is enabled.
    #[cfg(any())]
    pub(crate) fn can_noise_crypto(&self) -> bool {
        self.noise_v1
    }

    /// Fallback while noise support is compiled out: never use noise crypto.
    ///
    /// Gated on the negation of the cfg above so that exactly one definition
    /// exists whichever way that cfg is set.
    #[cfg(not(any()))]
    pub(crate) fn can_noise_crypto(&self) -> bool {
        false
    }

    /// Keep only the abilities enabled in both `self` and `other`.
    pub fn intersect(mut self, other: &Self) -> Self {
        self.direct_tcp_v1 &= other.direct_tcp_v1;
        self.relay_v1 &= other.relay_v1;
        #[cfg(any())]
        {
            self.noise_v1 &= other.noise_v1;
        }
        self
    }
}
222
223impl serde::Serialize for Abilities {
224 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
225 where
226 S: serde::Serializer,
227 {
228 let mut hints = Vec::new();
229 if self.direct_tcp_v1 {
230 hints.push(serde_json::json!({
231 "type": "direct-tcp-v1",
232 }));
233 }
234 if self.relay_v1 {
235 hints.push(serde_json::json!({
236 "type": "relay-v1",
237 }));
238 }
239 #[cfg(any())]
240 if self.noise_v1 {
241 hints.push(serde_json::json!({
242 "type": "noise-crypto-v1",
243 }));
244 }
245 serde_json::Value::Array(hints).serialize(ser)
246 }
247}
248
249impl<'de> serde::Deserialize<'de> for Abilities {
250 fn deserialize<D>(de: D) -> Result<Self, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 #[derive(Deserialize)]
255 #[serde(rename_all = "kebab-case", tag = "type")]
256 enum Ability {
257 DirectTcpV1,
258 RelayV1,
259 RelayV2,
260 #[cfg(any())]
261 NoiseCryptoV1,
262 #[serde(other)]
263 Other,
264 }
265
266 let mut abilities = Self::default();
267 for ability in <Vec<Ability> as serde::Deserialize>::deserialize(de)? {
269 match ability {
270 Ability::DirectTcpV1 => {
271 abilities.direct_tcp_v1 = true;
272 },
273 Ability::RelayV1 => {
274 abilities.relay_v1 = true;
275 },
276 #[cfg(any())]
277 Ability::NoiseCryptoV1 => {
278 abilities.noise_v1 = true;
279 },
280 _ => (),
281 }
282 }
283 Ok(abilities)
284 }
285}
286
/// Wire representation of a single hint, tagged by its `type` field.
/// Hint types this implementation does not know deserialize to `Unknown`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum HintSerde {
    DirectTcpV1(DirectHint),
    RelayV1(RelayHint),
    #[serde(other)]
    Unknown,
}
297
298#[derive(Clone, Debug, Default)]
300pub struct Hints {
301 pub direct_tcp: HashSet<DirectHint>,
303 pub relay: Vec<RelayHint>,
305}
306
307impl Hints {
308 pub fn new(
310 direct_tcp: impl IntoIterator<Item = DirectHint>,
311 relay: impl IntoIterator<Item = RelayHint>,
312 ) -> Self {
313 Self {
314 direct_tcp: direct_tcp.into_iter().collect(),
315 relay: relay.into_iter().collect(),
316 }
317 }
318}
319
320impl<'de> serde::Deserialize<'de> for Hints {
321 fn deserialize<D>(de: D) -> Result<Self, D::Error>
322 where
323 D: serde::Deserializer<'de>,
324 {
325 let hints: Vec<HintSerde> = serde::Deserialize::deserialize(de)?;
326 let mut direct_tcp = HashSet::new();
327 let mut relay = Vec::<RelayHint>::new();
328 let mut relay_v2 = Vec::<RelayHint>::new();
329
330 for hint in hints {
331 match hint {
332 HintSerde::DirectTcpV1(hint) => {
333 direct_tcp.insert(hint);
334 },
335 HintSerde::RelayV1(hint) => {
336 relay_v2.push(hint);
337 },
338 _ => {},
340 }
341 }
342
343 if !relay_v2.is_empty() {
345 relay.clear();
346 }
347 relay.extend(relay_v2);
348
349 Ok(Hints { direct_tcp, relay })
350 }
351}
352
353impl serde::Serialize for Hints {
354 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
355 where
356 S: serde::Serializer,
357 {
358 let direct = self.direct_tcp.iter().cloned().map(HintSerde::DirectTcpV1);
359 let relay = self.relay.iter().cloned().map(HintSerde::RelayV1);
360 ser.collect_seq(direct.chain(relay))
361 }
362}
363
364#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)]
366#[display("tcp://{}:{}", hostname, port)]
367pub struct DirectHint {
368 pub hostname: String,
373 pub port: u16,
375}
376
377impl DirectHint {
378 pub fn new(hostname: impl Into<String>, port: u16) -> Self {
380 Self {
381 hostname: hostname.into(),
382 port,
383 }
384 }
385}
386
/// Wire representation of a [`RelayHint`]: an optional name plus the list of
/// relay endpoints, stored under the JSON key `hints`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
struct RelayHintSerde {
    name: Option<String>,
    #[serde(rename = "hints")]
    endpoints: Vec<RelayHintSerdeInner>,
}
396
/// Wire representation of one relay endpoint, tagged by `type`:
/// `direct-tcp-v1` (TCP), `websocket` (URL), or anything else (`Unknown`).
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum RelayHintSerdeInner {
    #[serde(rename = "direct-tcp-v1")]
    Tcp(DirectHint),
    Websocket {
        url: url::Url,
    },
    #[serde(other)]
    Unknown,
}
410
411#[derive(Debug, thiserror::Error)]
412#[non_exhaustive]
413pub enum RelayHintParseError {
415 #[error(
416 "Invalid TCP hint endpoint: '{}' (Does it have hostname and port?)",
417 _0
418 )]
419 InvalidTcp(url::Url),
421 #[error(
422 "Unknown schema: '{}'. Currently known values are 'tcp', 'ws' and 'wss'.",
423 _0
424 )]
425 UnknownSchema(Box<str>),
427 #[error("'{}' is not an absolute URL (must start with a '/')", _0)]
428 UrlNotAbsolute(url::Url),
430}
431
432#[derive(Clone, Debug, Eq, PartialEq, Default)]
443pub struct RelayHint {
444 pub name: Option<String>,
448 pub tcp: HashSet<DirectHint>,
450 pub ws: HashSet<url::Url>,
452}
453
454impl RelayHint {
455 pub fn new(
457 name: Option<String>,
458 tcp: impl IntoIterator<Item = DirectHint>,
459 ws: impl IntoIterator<Item = url::Url>,
460 ) -> Self {
461 Self {
462 name,
463 tcp: tcp.into_iter().collect(),
464 ws: ws.into_iter().collect(),
465 }
466 }
467
468 pub fn from_urls(
490 name: Option<String>,
491 urls: impl IntoIterator<Item = url::Url>,
492 ) -> Result<Self, RelayHintParseError> {
493 let mut this = Self {
494 name,
495 ..Self::default()
496 };
497 for url in urls.into_iter() {
498 ensure!(
499 !url.cannot_be_a_base(),
500 RelayHintParseError::UrlNotAbsolute(url)
501 );
502 match url.scheme() {
503 "tcp" => {
504 let (hostname, port) = match (url.host_str(), url.port()) {
506 (Some(hostname), Some(port)) => (hostname.into(), port),
507 _ => bail!(RelayHintParseError::InvalidTcp(url)),
508 };
509 this.tcp.insert(DirectHint { hostname, port });
510 },
511 "ws" | "wss" => {
512 this.ws.insert(url);
513 },
514 other => bail!(RelayHintParseError::UnknownSchema(other.into())),
515 }
516 }
517 assert!(
518 !this.tcp.is_empty() || !this.ws.is_empty(),
519 "No URLs provided"
520 );
521 Ok(this)
522 }
523
524 pub(crate) fn can_merge(&self, other: &Self) -> bool {
526 !self.tcp.is_disjoint(&other.tcp) || !self.ws.is_disjoint(&other.ws)
527 }
528
529 pub(crate) fn merge_mut(&mut self, other: Self) {
531 self.tcp.extend(other.tcp);
532 self.ws.extend(other.ws);
533 }
534
535 pub(crate) fn merge_into(self, collection: &mut Vec<RelayHint>) {
537 for item in collection.iter_mut() {
538 if item.can_merge(&self) {
539 item.merge_mut(self);
540 return;
541 }
542 }
543 collection.push(self);
544 }
545}
546
547impl serde::Serialize for RelayHint {
548 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
549 where
550 S: serde::Serializer,
551 {
552 let mut hints = Vec::new();
553 hints.extend(self.tcp.iter().cloned().map(RelayHintSerdeInner::Tcp));
554 hints.extend(
555 self.ws
556 .iter()
557 .cloned()
558 .map(|h| RelayHintSerdeInner::Websocket { url: h }),
559 );
560
561 serde_json::json!({
562 "name": self.name,
563 "hints": hints,
564 })
565 .serialize(ser)
566 }
567}
568
569impl<'de> serde::Deserialize<'de> for RelayHint {
570 fn deserialize<D>(de: D) -> Result<Self, D::Error>
571 where
572 D: serde::Deserializer<'de>,
573 {
574 let raw = RelayHintSerde::deserialize(de)?;
575 let mut hint = RelayHint {
576 name: raw.name,
577 tcp: HashSet::new(),
578 ws: HashSet::new(),
579 };
580
581 for e in raw.endpoints {
582 match e {
583 RelayHintSerdeInner::Tcp(tcp) => {
584 hint.tcp.insert(tcp);
585 },
586 RelayHintSerdeInner::Websocket { url } => {
587 hint.ws.insert(url);
588 },
589 _ => {},
591 }
592 }
593
594 Ok(hint)
595 }
596}
597
598impl TryFrom<&DirectHint> for IpAddr {
599 type Error = std::net::AddrParseError;
600 fn try_from(hint: &DirectHint) -> Result<IpAddr, std::net::AddrParseError> {
601 hint.hostname.parse()
602 }
603}
604
605impl TryFrom<&DirectHint> for SocketAddr {
606 type Error = std::net::AddrParseError;
607 fn try_from(hint: &DirectHint) -> Result<SocketAddr, std::net::AddrParseError> {
609 let addr = hint.try_into()?;
610 let addr = match addr {
611 IpAddr::V4(v4) => IpAddr::V6(v4.to_ipv6_mapped()),
612 IpAddr::V6(_) => addr,
613 };
614 Ok(SocketAddr::new(addr, hint.port))
615 }
616}
617
/// How a transit connection was established.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ConnectionType {
    /// Directly between the peers.
    Direct,
    /// Through a relay server.
    Relay {
        /// The relay's name, if one is known.
        name: Option<String>,
    },
}

impl std::fmt::Display for ConnectionType {
    /// A short phrase describing the route, e.g. `directly` or `via relay (name)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Relay { name: Some(relay) } => write!(f, "via relay ({relay})"),
            Self::Relay { name: None } => f.write_str("via relay"),
            Self::Direct => f.write_str("directly"),
        }
    }
}
640
/// Metadata about an established transit connection.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct TransitInfo {
    /// Whether the connection is direct or relayed.
    pub conn_type: ConnectionType,
    /// Address of the remote end of the socket (not available on WASM).
    /// NOTE(review): for relayed connections this is presumably the relay's
    /// address rather than the peer's — confirm.
    #[cfg(not(target_family = "wasm"))]
    pub peer_addr: SocketAddr,
}

/// A connected transport plus information about how it was established.
type TransitConnection = (Box<dyn TransitTransport>, TransitInfo);
654
/// Errors while discovering our external address via STUN.
#[cfg(not(target_family = "wasm"))]
#[derive(Debug, thiserror::Error)]
enum StunError {
    /// The STUN server has no IPv4 address.
    #[error("No IPv4 addresses were found for the selected STUN server")]
    ServerIsV6Only,
    /// The server's reply did not contain our address.
    #[error("Server did not tell us our IP address")]
    ServerNoResponse,
    /// The lookup did not finish in time (mapped from the timeout in `init`).
    #[error("Connection timed out")]
    Timeout,
    /// An underlying I/O error.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// The STUN packet could not be encoded or decoded.
    #[error("Malformed STUN packet")]
    Codec(
        #[from]
        #[source]
        bytecodec::Error,
    ),
}
677
678#[cfg(not(target_family = "wasm"))]
679impl std::fmt::Display for TransitInfo {
680 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681 match &self.conn_type {
682 ConnectionType::Direct => {
683 write!(
684 f,
685 "Established direct transit connection to '{}'",
686 self.peer_addr,
687 )
688 },
689 ConnectionType::Relay { name: Some(name) } => {
690 write!(
691 f,
692 "Established transit connection via relay '{}' ({})",
693 name, self.peer_addr,
694 )
695 },
696 ConnectionType::Relay { name: None } => {
697 write!(
698 f,
699 "Established transit connection via relay ({})",
700 self.peer_addr,
701 )
702 },
703 }
704 }
705}
706
#[cfg(target_family = "wasm")]
impl std::fmt::Display for TransitInfo {
    /// A one-line summary naming the route. There is no socket address on WASM.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.conn_type {
            ConnectionType::Relay { name: Some(relay_name) } => write!(
                f,
                "Established transit connection via relay '{}'",
                relay_name
            ),
            ConnectionType::Relay { name: None } => {
                f.write_str("Established transit connection via relay")
            },
            ConnectionType::Direct => f.write_str("Established direct transit connection"),
        }
    }
}
723
/// Prepare a [`TransitConnector`] for connecting to the peer.
///
/// When `peer_abilities` is known, `abilities` is first narrowed to the
/// intersection of both sides. If direct connections remain allowed
/// (non-WASM only), this:
/// - asks a STUN server for our external address (giving up after 4 seconds),
///   and advertises it as a direct hint on success;
/// - keeps an outgoing socket: the STUN-connected stream on success, otherwise
///   a fresh IPv6 socket bound to `[::]:0`;
/// - binds a listener on `[::]:0`;
/// - advertises direct hints for every non-loopback interface address, once
///   with the outgoing socket's port and once with the listener's port.
///
/// If relaying is allowed, `relay_hints` are advertised unchanged.
///
/// # Errors
///
/// In the code shown here this always returns `Ok`. I/O errors during socket
/// setup are logged, and the connector then proceeds without direct sockets.
pub async fn init(
    mut abilities: Abilities,
    peer_abilities: Option<Abilities>,
    relay_hints: Vec<RelayHint>,
) -> Result<TransitConnector, std::io::Error> {
    let mut our_hints = Hints::default();
    // Outgoing socket and listener for direct connections. Stays `None` when
    // direct connections are not allowed or their setup failed.
    #[cfg(not(target_family = "wasm"))]
    let mut sockets = None;

    // Only use abilities both sides support, if the peer's are already known.
    if let Some(peer_abilities) = peer_abilities {
        abilities = abilities.intersect(&peer_abilities);
    }

    #[cfg(not(target_family = "wasm"))]
    if abilities.can_direct() {
        let create_sockets = async {
            // Ask the STUN server for our external address; a timeout is
            // reported as `StunError::Timeout`.
            let socket: MaybeConnectedSocket = match crate::util::timeout(
                std::time::Duration::from_secs(4),
                transport::tcp_get_external_ip(),
            )
            .await
            .map_err(|_| StunError::Timeout)
            {
                Ok(Ok((external_ip, stream))) => {
                    tracing::debug!("Our external IP address is {}", external_ip);
                    // NOTE(review): this hint is inserted before the remaining
                    // fallible steps; if a later step fails it stays advertised
                    // even though `sockets` ends up `None`.
                    our_hints.direct_tcp.insert(DirectHint {
                        hostname: external_ip.ip().to_string(),
                        port: external_ip.port(),
                    });
                    tracing::debug!(
                        "Our socket for connecting is bound to {} and connected to {}",
                        stream.local_addr()?,
                        stream.peer_addr()?,
                    );
                    // Keep the STUN stream as our outgoing socket (presumably so
                    // outgoing connections reuse the externally-observed port — confirm).
                    stream.into()
                },
                // STUN timed out or failed: fall back to a plain, unconnected
                // IPv6 socket on an ephemeral port.
                Err(err) | Ok(Err(err)) => {
                    tracing::warn!("Failed to get external address via STUN, {}", err);
                    let socket =
                        socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
                    transport::set_socket_opts(&socket)?;

                    socket.bind(&"[::]:0".parse::<SocketAddr>().unwrap().into())?;
                    tracing::debug!(
                        "Our socket for connecting is bound to {}",
                        socket.local_addr()?.as_socket().unwrap(),
                    );

                    socket.into()
                },
            };

            // Listener for incoming direct connections, on an ephemeral port.
            let listener = TcpListener::bind("[::]:0").await?;

            // Advertise each non-loopback interface address with both ports.
            let port = socket.local_addr()?.as_socket().unwrap().port();
            let port2 = listener.local_addr()?.port();
            our_hints.direct_tcp.extend(
                if_addrs::get_if_addrs()?
                    .iter()
                    .filter(|iface| !iface.is_loopback())
                    .flat_map(|ip| {
                        [
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port,
                            },
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port: port2,
                            },
                        ]
                        .into_iter()
                    }),
            );
            tracing::debug!("Our socket for listening is {}", listener.local_addr()?);

            Ok::<_, std::io::Error>((socket, listener))
        };

        // Best effort: a failure here only disables our direct sockets.
        sockets = create_sockets
            .await
            .map_err(|err| {
                tracing::error!("Failed to create direct hints for our side: {}", err);
                err
            })
            .ok();
    }

    if abilities.can_relay() {
        our_hints.relay.extend(relay_hints);
    }

    Ok(TransitConnector {
        #[cfg(not(target_family = "wasm"))]
        sockets,
        our_abilities: abilities,
        our_hints: Arc::new(our_hints),
    })
}
844
/// Our socket for outgoing direct connections. In `init` this is either the
/// stream that talked to the STUN server or, if STUN failed, a bound but
/// unconnected socket.
#[cfg(not(target_family = "wasm"))]
#[derive(derive_more::From)]
enum MaybeConnectedSocket {
    #[from]
    Socket(socket2::Socket),
    #[from]
    Stream(TcpStream),
}
854
855#[cfg(not(target_family = "wasm"))]
856impl MaybeConnectedSocket {
857 fn local_addr(&self) -> std::io::Result<socket2::SockAddr> {
858 match &self {
859 Self::Socket(socket) => socket.local_addr(),
860 Self::Stream(stream) => Ok(stream.local_addr()?.into()),
861 }
862 }
863}
864
/// Which side of the transit protocol we play.
///
/// If its first connection is relayed, the leader waits briefly for a direct
/// one. The follower takes the first connection that succeeds.
#[derive(Clone, Debug)]
pub enum TransitRole {
    /// The leading side.
    Leader,
    /// The following side.
    Follower,
}
875
/// Prepared state for connecting to the peer; created by [`init`].
pub struct TransitConnector {
    /// Outgoing socket and listener for direct connections. `None` if direct
    /// connections are disabled or their setup failed.
    #[cfg(not(target_family = "wasm"))]
    sockets: Option<(MaybeConnectedSocket, TcpListener)>,
    /// Abilities to use, already intersected with the peer's if those were known.
    our_abilities: Abilities,
    /// Hints advertised to the peer.
    our_hints: Arc<Hints>,
}
893
894impl TransitConnector {
895 pub fn our_abilities(&self) -> &Abilities {
897 &self.our_abilities
898 }
899
900 pub fn our_hints(&self) -> &Arc<Hints> {
902 &self.our_hints
903 }
904
905 pub async fn connect(
910 self,
911 role: TransitRole,
912 transit_key: Key<TransitKey>,
913 their_abilities: Abilities,
914 their_hints: Arc<Hints>,
915 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
916 match role {
917 TransitRole::Leader => {
918 self.leader_connect(transit_key, their_abilities, their_hints)
919 .await
920 },
921 TransitRole::Follower => {
922 self.follower_connect(transit_key, their_abilities, their_hints)
923 .await
924 },
925 }
926 }
927
928 async fn leader_connect(
932 self,
933 transit_key: Key<TransitKey>,
934 their_abilities: Abilities,
935 their_hints: Arc<Hints>,
936 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
937 let Self {
938 #[cfg(not(target_family = "wasm"))]
939 sockets,
940 our_abilities,
941 our_hints,
942 } = self;
943 let transit_key = Arc::new(transit_key);
944
945 let start = Instant::now();
946 let mut connection_stream = Box::pin(
947 Self::connect_inner(
948 true,
949 transit_key,
950 our_abilities,
951 our_hints,
952 their_abilities,
953 their_hints,
954 #[cfg(not(target_family = "wasm"))]
955 sockets,
956 )
957 .filter_map(|result| async {
958 match result {
959 Ok(val) => Some(val),
960 Err(err) => {
961 tracing::debug!("Some leader handshake failed: {:?}", err);
962 None
963 },
964 }
965 }),
966 );
967
968 let (mut transit, mut finalizer, mut conn_info) =
969 crate::util::timeout(std::time::Duration::from_secs(60), connection_stream.next())
970 .await
971 .map_err(|_| {
972 tracing::debug!("`leader_connect` timed out");
973 TransitConnectError::Handshake
974 })?
975 .ok_or(TransitConnectError::Handshake)?;
976
977 if conn_info.conn_type != ConnectionType::Direct && our_abilities.can_direct() {
978 tracing::debug!(
979 "Established transit connection over relay. Trying to find a direct connection …"
980 );
981 let elapsed = start.elapsed();
985 let to_wait = if elapsed.as_secs() > 5 {
986 std::time::Duration::from_secs(1)
988 } else {
989 elapsed.mul_f32(0.3)
990 };
991 let _ = crate::util::timeout(to_wait, async {
992 while let Some((new_transit, new_finalizer, new_conn_info)) =
993 connection_stream.next().await
994 {
995 if new_conn_info.conn_type == ConnectionType::Direct {
997 transit = new_transit;
998 finalizer = new_finalizer;
999 conn_info = new_conn_info;
1000 tracing::debug!("Found direct connection; using that instead.");
1001 break;
1002 }
1003 }
1004 })
1005 .await;
1006 tracing::debug!("Did not manage to establish a better connection in time.");
1007 } else {
1008 tracing::debug!("Established direct transit connection");
1009 }
1010
1011 std::mem::drop(connection_stream);
1015
1016 let (tx, rx) = finalizer
1017 .handshake_finalize(&mut transit)
1018 .await
1019 .map_err(|e| {
1020 tracing::debug!("`handshake_finalize` failed: {e}");
1021 TransitConnectError::Handshake
1022 })?;
1023
1024 Ok((
1025 Transit {
1026 socket: transit,
1027 tx,
1028 rx,
1029 },
1030 conn_info,
1031 ))
1032 }
1033
1034 async fn follower_connect(
1038 self,
1039 transit_key: Key<TransitKey>,
1040 their_abilities: Abilities,
1041 their_hints: Arc<Hints>,
1042 ) -> Result<(Transit, TransitInfo), TransitConnectError> {
1043 let Self {
1044 #[cfg(not(target_family = "wasm"))]
1045 sockets,
1046 our_abilities,
1047 our_hints,
1048 } = self;
1049 let transit_key = Arc::new(transit_key);
1050
1051 let mut connection_stream = Box::pin(
1052 Self::connect_inner(
1053 false,
1054 transit_key,
1055 our_abilities,
1056 our_hints,
1057 their_abilities,
1058 their_hints,
1059 #[cfg(not(target_family = "wasm"))]
1060 sockets,
1061 )
1062 .filter_map(|result| async {
1063 match result {
1064 Ok(val) => Some(val),
1065 Err(err) => {
1066 tracing::debug!("Some follower handshake failed: {:?}", err);
1067 None
1068 },
1069 }
1070 }),
1071 );
1072
1073 let transit = match crate::util::timeout(
1074 std::time::Duration::from_secs(60),
1075 &mut connection_stream.next(),
1076 )
1077 .await
1078 {
1079 Ok(Some((mut socket, finalizer, conn_info))) => {
1080 let (tx, rx) = finalizer
1081 .handshake_finalize(&mut socket)
1082 .await
1083 .map_err(|e| {
1084 tracing::debug!("`handshake_finalize` failed: {e}");
1085 TransitConnectError::Handshake
1086 })?;
1087
1088 Ok((Transit { socket, tx, rx }, conn_info))
1089 },
1090 Ok(None) | Err(_) => {
1091 tracing::debug!("`follower_connect` timed out");
1092 Err(TransitConnectError::Handshake)
1093 },
1094 };
1095
1096 std::mem::drop(connection_stream);
1100
1101 transit
1102 }
1103
1104 fn connect_inner(
1114 is_leader: bool,
1115 transit_key: Arc<Key<TransitKey>>,
1116 our_abilities: Abilities,
1117 our_hints: Arc<Hints>,
1118 their_abilities: Abilities,
1119 their_hints: Arc<Hints>,
1120 #[cfg(not(target_family = "wasm"))] sockets: Option<(MaybeConnectedSocket, TcpListener)>,
1121 ) -> impl Stream<Item = Result<HandshakeResult, TransitHandshakeError>> + 'static {
1122 #[cfg(not(target_family = "wasm"))]
1124 assert!(sockets.is_none() || our_abilities.can_direct());
1125
1126 let cryptor = if our_abilities.can_noise_crypto() && their_abilities.can_noise_crypto() {
1127 tracing::debug!("Using noise protocol for encryption");
1128 Arc::new(crypto::NoiseInit {
1129 key: transit_key.clone(),
1130 }) as Arc<dyn crypto::TransitCryptoInit>
1131 } else {
1132 tracing::debug!("Using secretbox for encryption");
1133 Arc::new(crypto::SecretboxInit {
1134 key: transit_key.clone(),
1135 }) as Arc<dyn crypto::TransitCryptoInit>
1136 };
1137
1138 let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>()));
1140
1141 #[cfg(not(target_family = "wasm"))]
1145 use futures::future::BoxFuture;
1146 #[cfg(target_family = "wasm")]
1147 use futures::future::LocalBoxFuture as BoxFuture;
1148 type BoxIterator<T> = Box<dyn Iterator<Item = T>>;
1149 type ConnectorFuture = BoxFuture<'static, Result<TransitConnection, TransitHandshakeError>>;
1150 let mut connectors: BoxIterator<ConnectorFuture> = Box::new(std::iter::empty());
1151
1152 #[cfg(not(target_family = "wasm"))]
1153 let (socket, listener) = sockets.unzip();
1154 #[cfg(not(target_family = "wasm"))]
1155 if our_abilities.can_direct() && their_abilities.can_direct() {
1156 let local_addr = socket.map(|socket| {
1157 Arc::new(
1158 socket
1159 .local_addr()
1160 .expect("This is guaranteed to be an IP socket"),
1161 )
1162 });
1163 connectors = Box::new(
1165 connectors.chain(
1166 their_hints
1167 .direct_tcp
1168 .clone()
1169 .into_iter()
1170 .take(50)
1172 .map(move |hint| transport::connect_tcp_direct(local_addr.clone(), hint))
1173 .map(|fut| Box::pin(fut) as ConnectorFuture),
1174 ),
1175 ) as BoxIterator<ConnectorFuture>;
1176 }
1177
1178 if our_abilities.can_relay() && their_abilities.can_relay() {
1180 let mut relay_hints = Vec::<RelayHint>::new();
1182 relay_hints.extend(our_hints.relay.iter().take(2).cloned());
1183 for hint in their_hints.relay.iter().take(2).cloned() {
1184 hint.merge_into(&mut relay_hints);
1185 }
1186
1187 #[cfg(not(target_family = "wasm"))]
1188 {
1189 connectors = Box::new(
1190 connectors.chain(
1191 relay_hints
1192 .into_iter()
1193 .flat_map(|hint| {
1202 let name = hint.name
1204 .or_else(|| {
1205 hint.tcp.iter()
1207 .filter_map(|hint| match url::Host::parse(&hint.hostname) {
1208 Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
1209 _ => None,
1210 })
1211 .next()
1212 });
1213 hint.tcp
1214 .into_iter()
1215 .take(3)
1216 .enumerate()
1217 .map(move |(i, h)| (i, h, name.clone()))
1218 })
1219 .map(|(index, host, name)| async move {
1220 async_io::Timer::after(std::time::Duration::from_secs(
1221 index as u64 * 5,
1222 ))
1223 .await;
1224 transport::connect_tcp_relay(host, name).await
1225 })
1226 .map(|fut| Box::pin(fut) as ConnectorFuture),
1227 ),
1228 ) as BoxIterator<ConnectorFuture>;
1229 }
1230
1231 #[cfg(target_family = "wasm")]
1232 {
1233 connectors = Box::new(
1234 connectors.chain(
1235 relay_hints
1236 .into_iter()
1237 .flat_map(|hint| {
1246 let name = hint.name
1248 .or_else(|| {
1249 hint.tcp.iter()
1251 .filter_map(|hint| match url::Host::parse(&hint.hostname) {
1252 Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
1253 _ => None,
1254 })
1255 .next()
1256 });
1257 hint.ws
1258 .into_iter()
1259 .take(3)
1260 .enumerate()
1261 .map(move |(i, u)| (i, u, name.clone()))
1262 })
1263 .map(|(index, url, name)| async move {
1264 crate::util::sleep(std::time::Duration::from_secs(
1265 index as u64 * 5,
1266 ))
1267 .await;
1268 transport::connect_ws_relay(url, name).await
1269 })
1270 .map(|fut| Box::pin(fut) as ConnectorFuture),
1271 ),
1272 ) as BoxIterator<ConnectorFuture>;
1273 }
1274 }
1275
1276 let transit_key2 = transit_key.clone();
1278 let tside2 = tside.clone();
1279 let cryptor2 = cryptor.clone();
1280 #[allow(unused_mut)] let mut connectors = Box::new(
1282 connectors
1283 .map(move |fut| {
1284 let transit_key = transit_key2.clone();
1285 let tside = tside2.clone();
1286 let cryptor = cryptor2.clone();
1287 async move {
1288 let (socket, conn_info) = fut.await?;
1289 let (transit, finalizer) = handshake_exchange(
1290 is_leader,
1291 tside,
1292 socket,
1293 &conn_info.conn_type,
1294 &*cryptor,
1295 transit_key,
1296 )
1297 .await?;
1298 Ok((transit, finalizer, conn_info))
1299 }
1300 })
1301 .map(|fut| {
1302 Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
1303 }),
1304 )
1305 as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;
1306
1307 #[cfg(not(target_family = "wasm"))]
1309 if let Some(listener) = listener {
1310 connectors = Box::new(
1311 connectors.chain(
1312 std::iter::once(async move {
1313 let transit_key = transit_key.clone();
1314 let tside = tside.clone();
1315 let cryptor = cryptor.clone();
1316 let connect = || async {
1317 let (socket, peer) = listener.accept().await?;
1318 let (socket, info) =
1319 transport::wrap_tcp_connection(socket, ConnectionType::Direct)?;
1320 tracing::debug!("Got connection from {}!", peer);
1321 let (transit, finalizer) = handshake_exchange(
1322 is_leader,
1323 tside.clone(),
1324 socket,
1325 &ConnectionType::Direct,
1326 &*cryptor,
1327 transit_key.clone(),
1328 )
1329 .await?;
1330 Result::<_, TransitHandshakeError>::Ok((transit, finalizer, info))
1331 };
1332 loop {
1333 match connect().await {
1334 Ok(success) => break Ok(success),
1335 Err(err) => {
1336 tracing::debug!(
1337 "Some handshake failed on the listening port: {:?}",
1338 err
1339 );
1340 continue;
1341 },
1342 }
1343 }
1344 })
1345 .map(|fut| {
1346 Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
1347 }),
1348 ),
1349 )
1350 as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;
1351 }
1352 connectors.collect::<futures::stream::futures_unordered::FuturesUnordered<_>>()
1353 }
1354}
1355
/// An established, encrypted transit connection that exchanges records.
pub struct Transit {
    /// The underlying transport.
    socket: Box<dyn TransitTransport>,
    /// Encrypts outgoing records.
    tx: Box<dyn crypto::TransitCryptoEncrypt>,
    /// Decrypts incoming records.
    rx: Box<dyn crypto::TransitCryptoDecrypt>,
}
1368
1369impl Transit {
1370 pub async fn receive_record(&mut self) -> Result<Box<[u8]>, TransitError> {
1372 self.rx.decrypt(&mut self.socket).await
1373 }
1374
1375 pub async fn send_record(&mut self, plaintext: &[u8]) -> Result<(), TransitError> {
1377 assert!(!plaintext.is_empty());
1378 self.tx.encrypt(&mut self.socket, plaintext).await
1379 }
1380
1381 pub async fn flush(&mut self) -> Result<(), TransitError> {
1383 tracing::debug!("Flush");
1384 self.socket.flush().await.map_err(Into::into)
1385 }
1386
1387 #[cfg(not(target_family = "wasm"))]
1389 #[expect(clippy::type_complexity)]
1390 pub fn split(
1391 self,
1392 ) -> (
1393 impl futures::sink::Sink<Box<[u8]>, Error = TransitError>,
1394 impl futures_lite::stream::Stream<Item = Result<Box<[u8]>, TransitError>>,
1395 ) {
1396 let (reader, writer) = self.socket.split();
1397 (
1398 futures::sink::unfold(
1399 (writer, self.tx),
1400 |(mut writer, mut tx), plaintext: Box<[u8]>| async move {
1401 tx.encrypt(&mut writer, &plaintext)
1402 .await
1403 .map(|()| (writer, tx))
1404 },
1405 ),
1406 futures::stream::try_unfold((reader, self.rx), |(mut reader, mut rx)| async move {
1407 rx.decrypt(&mut reader)
1408 .await
1409 .map(|record| Some((record, (reader, rx))))
1410 }),
1411 )
1412 }
1413}
1414
/// Outcome of a successful handshake attempt: the transport, the finalizer
/// that completes the crypto handshake, and connection metadata.
type HandshakeResult = (
    Box<dyn TransitTransport>,
    Box<dyn crypto::TransitCryptoInitFinalizer>,
    TransitInfo,
);
1420
1421async fn handshake_exchange(
1431 is_leader: bool,
1432 tside: Arc<String>,
1433 mut socket: Box<dyn TransitTransport>,
1434 host_type: &ConnectionType,
1435 cryptor: &dyn crypto::TransitCryptoInit,
1436 key: Arc<Key<TransitKey>>,
1437) -> Result<
1438 (
1439 Box<dyn TransitTransport>,
1440 Box<dyn crypto::TransitCryptoInitFinalizer>,
1441 ),
1442 TransitHandshakeError,
1443> {
1444 if host_type != &ConnectionType::Direct {
1445 tracing::trace!("initiating relay handshake");
1446
1447 let sub_key = key.derive_subkey_from_purpose::<GenericKey>("transit_relay_token");
1448 socket
1449 .write_all(format!("please relay {} for side {}\n", sub_key.to_hex(), tside).as_bytes())
1450 .await?;
1451 let mut rx = [0u8; 3];
1452 socket.read_exact(&mut rx).await?;
1453 let ok_msg: [u8; 3] = *b"ok\n";
1454 ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed);
1455 }
1456
1457 let finalizer = if is_leader {
1458 cryptor.handshake_leader(&mut socket).await?
1459 } else {
1460 cryptor.handshake_follower(&mut socket).await?
1461 };
1462
1463 Ok((socket, finalizer))
1464}
1465
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// Ability sets encode as lists of `{"type": ...}` objects.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_abilities_encoding() {
        assert_eq!(
            serde_json::to_value(Abilities::ALL).unwrap(),
            json!([{"type": "direct-tcp-v1"}, {"type": "relay-v1"}])
        );
        assert_eq!(
            serde_json::to_value(Abilities::FORCE_DIRECT).unwrap(),
            json!([{"type": "direct-tcp-v1"}])
        );
        assert_eq!(
            serde_json::to_value(Abilities::FORCE_RELAY).unwrap(),
            json!([{"type": "relay-v1"}])
        );
    }

    /// Legacy and unknown ability types must be ignored, not rejected.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_abilities_decoding() {
        let abilities: Abilities = serde_json::from_value(json!([
            {"type": "direct-tcp-v1"},
            {"type": "relay-v2"},
            {"type": "some-future-ability"}
        ]))
        .unwrap();
        assert!(abilities.can_direct());
        assert!(!abilities.can_relay());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_hints_encoding() {
        assert_eq!(
            serde_json::to_value(Hints::new(
                [DirectHint {
                    hostname: "localhost".into(),
                    port: 1234
                }],
                [RelayHint::new(
                    Some("default".into()),
                    [DirectHint::new("transit.magic-wormhole.io", 4001)],
                    ["ws://transit.magic-wormhole.io/relay".parse().unwrap(),],
                )]
            ))
            .unwrap(),
            json!([
                {
                    "type": "direct-tcp-v1",
                    "hostname": "localhost",
                    "port": 1234
                },
                {
                    "type": "relay-v1",
                    "name": "default",
                    "hints": [
                        {
                            "type": "direct-tcp-v1",
                            "hostname": "transit.magic-wormhole.io",
                            "port": 4001,
                        },
                        {
                            "type": "websocket",
                            "url": "ws://transit.magic-wormhole.io/relay",
                        },
                    ]
                }
            ])
        )
    }

    /// URLs are sorted by scheme; malformed or unsupported ones are rejected.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_relay_hint_from_urls() {
        let hint = RelayHint::from_urls(
            Some("default".into()),
            [
                "tcp://transit.magic-wormhole.io:4001".parse().unwrap(),
                "wss://example.org/relay".parse().unwrap(),
            ],
        )
        .unwrap();
        assert_eq!(
            hint,
            RelayHint::new(
                Some("default".into()),
                [DirectHint::new("transit.magic-wormhole.io", 4001)],
                ["wss://example.org/relay".parse().unwrap()],
            )
        );
        assert!(matches!(
            RelayHint::from_urls(None, ["tcp://example.org".parse().unwrap()]),
            Err(RelayHintParseError::InvalidTcp(_))
        ));
        assert!(matches!(
            RelayHint::from_urls(None, ["http://example.org:80".parse().unwrap()]),
            Err(RelayHintParseError::UnknownSchema(_))
        ));
    }

    /// IPv4 direct hints become IPv4-mapped IPv6 socket addresses.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_direct_hint_to_socket_addr() {
        let addr = SocketAddr::try_from(&DirectHint::new("192.0.2.1", 4001)).unwrap();
        assert_eq!(addr, "[::ffff:192.0.2.1]:4001".parse::<SocketAddr>().unwrap());
        assert!(SocketAddr::try_from(&DirectHint::new("example.org", 4001)).is_err());
    }
}