use crate::{Key, KeyPurpose, core::key::GenericKey};
use serde_derive::{Deserialize, Serialize};
#[cfg(not(target_family = "wasm"))]
use async_net::{TcpListener, TcpStream};
#[allow(unused_imports)]
use futures::{
Sink, SinkExt, Stream, StreamExt, TryStreamExt,
future::FutureExt,
future::TryFutureExt,
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
};
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Instant,
};
mod crypto;
mod transport;
use crypto::TransitHandshakeError;
use transport::{TransitTransport, TransitTransportRx, TransitTransportTx};
/// URL of the default public transit relay server (TCP endpoint).
pub const DEFAULT_RELAY_SERVER: &str = "tcp://transit.magic-wormhole.io:4001";
/// Public STUN server (`host:port`).
///
/// NOTE(review): not referenced in this file; presumably used by the `transport`
/// module to discover our external address — confirm there.
#[cfg(not(target_family = "wasm"))]
const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478";
/// Key-purpose marker for the transit key shared by both sides of a transfer.
#[derive(Debug)]
pub struct TransitKey;
impl KeyPurpose for TransitKey {}
/// Key-purpose marker for the receive-direction key.
///
/// NOTE(review): presumably derived from [`TransitKey`] inside `crypto`; confirm there.
#[derive(Debug)]
pub(crate) struct TransitRxKey;
impl KeyPurpose for TransitRxKey {}
/// Key-purpose marker for the send-direction key.
///
/// NOTE(review): presumably derived from [`TransitKey`] inside `crypto`; confirm there.
#[derive(Debug)]
pub(crate) struct TransitTxKey;
impl KeyPurpose for TransitTxKey {}
/// Errors that can occur while establishing a transit connection.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitConnectError {
    /// A protocol error; the contained message is used as the error text.
    #[error("{}", _0)]
    Protocol(Box<str>),
    /// No candidate connection completed its handshake (or all timed out).
    #[error(
        "All (relay) handshakes failed or timed out; could not establish a connection with the peer"
    )]
    Handshake,
    /// An underlying I/O operation failed.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// A WebSocket error (wasm targets only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
/// Errors that can occur while sending or receiving records over an established transit.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitError {
    /// Encryption or decryption of a record failed.
    #[error(
        "Cryptography error. This is probably an implementation bug, but may also be caused by an attack."
    )]
    Crypto,
    /// A record carried an unexpected nonce: `(received, expected)`.
    #[error(
        "Wrong nonce received, got {:x?} but expected {:x?}. This is probably an implementation bug, but may also be caused by an attack.",
        _0,
        _1
    )]
    Nonce(Box<[u8]>, Box<[u8]>),
    /// An underlying I/O operation failed.
    #[error("I/O error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// A WebSocket error (wasm targets only).
    #[cfg(target_family = "wasm")]
    #[error("WASM error")]
    WASM(
        #[from]
        #[source]
        ws_stream_wasm::WsErr,
    ),
}
impl From<()> for TransitError {
fn from(_: ()) -> Self {
Self::Crypto
}
}
/// The set of transit connection methods a side is willing to use.
///
/// The default value has every ability disabled.
#[derive(Copy, Clone, Debug, Default)]
pub struct Abilities {
    /// Direct TCP connections between the peers (`direct-tcp-v1`).
    pub direct_tcp_v1: bool,
    /// Connections through a transit relay server (`relay-v1`).
    pub relay_v1: bool,
    #[cfg(any())]
    pub noise_v1: bool,
}

impl Abilities {
    /// Every ability this implementation supports.
    pub const ALL: Self = Self {
        direct_tcp_v1: true,
        relay_v1: true,
        #[cfg(any())]
        noise_v1: false,
    };
    /// Only direct TCP connections; relays are disabled.
    pub const FORCE_DIRECT: Self = Self {
        relay_v1: false,
        ..Self::ALL
    };
    /// Only relayed connections; direct TCP is disabled.
    pub const FORCE_RELAY: Self = Self {
        direct_tcp_v1: false,
        ..Self::ALL
    };

    /// Whether direct TCP connections are enabled.
    pub fn can_direct(&self) -> bool {
        let Self { direct_tcp_v1, .. } = *self;
        direct_tcp_v1
    }

    /// Whether relayed connections are enabled.
    pub fn can_relay(&self) -> bool {
        let Self { relay_v1, .. } = *self;
        relay_v1
    }

    #[cfg(any())]
    pub(crate) fn can_noise_crypto(&self) -> bool {
        self.noise_v1
    }

    /// Noise encryption is currently compiled out, so this is always `false`.
    #[cfg(not(any()))]
    pub(crate) fn can_noise_crypto(&self) -> bool {
        false
    }

    /// Returns the abilities enabled in both `self` and `other`.
    pub fn intersect(self, other: &Self) -> Self {
        Self {
            direct_tcp_v1: self.direct_tcp_v1 && other.direct_tcp_v1,
            relay_v1: self.relay_v1 && other.relay_v1,
            #[cfg(any())]
            noise_v1: self.noise_v1 && other.noise_v1,
        }
    }
}
impl serde::Serialize for Abilities {
    /// Encodes the enabled abilities as a list of `{"type": …}` objects.
    ///
    /// The order is fixed: direct first, then relay.
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut enabled: Vec<&'static str> = Vec::new();
        if self.direct_tcp_v1 {
            enabled.push("direct-tcp-v1");
        }
        if self.relay_v1 {
            enabled.push("relay-v1");
        }
        #[cfg(any())]
        if self.noise_v1 {
            enabled.push("noise-crypto-v1");
        }
        let hints = enabled
            .into_iter()
            .map(|kind| serde_json::json!({ "type": kind }))
            .collect();
        serde_json::Value::Array(hints).serialize(ser)
    }
}
impl<'de> serde::Deserialize<'de> for Abilities {
    /// Parses a list of `{"type": …}` objects.
    ///
    /// Each recognised ability is enabled. Unknown types, and the not-yet-supported
    /// `relay-v2`, are accepted and ignored.
    fn deserialize<D>(de: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(rename_all = "kebab-case", tag = "type")]
        enum Ability {
            DirectTcpV1,
            RelayV1,
            RelayV2,
            #[cfg(any())]
            NoiseCryptoV1,
            #[serde(other)]
            Other,
        }

        let listed = <Vec<Ability> as serde::Deserialize>::deserialize(de)?;
        let abilities = listed
            .into_iter()
            .fold(Self::default(), |mut acc, ability| {
                match ability {
                    Ability::DirectTcpV1 => acc.direct_tcp_v1 = true,
                    Ability::RelayV1 => acc.relay_v1 = true,
                    #[cfg(any())]
                    Ability::NoiseCryptoV1 => acc.noise_v1 = true,
                    // Recognised on the wire but not acted upon.
                    Ability::RelayV2 | Ability::Other => {},
                }
                acc
            });
        Ok(abilities)
    }
}
/// Wire representation of a single connection hint, tagged by its `type` field.
///
/// Hint types this implementation does not know deserialize to `Unknown`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum HintSerde {
    /// `direct-tcp-v1`: a host/port the peer may be reachable at directly.
    DirectTcpV1(DirectHint),
    /// `relay-v1`: a relay server with one or more endpoints.
    RelayV1(RelayHint),
    /// Any other hint type.
    #[serde(other)]
    Unknown,
}
/// Connection hints advertised by one side: direct TCP endpoints and relay servers.
#[derive(Clone, Debug, Default)]
pub struct Hints {
    /// Direct TCP endpoints (deduplicated).
    pub direct_tcp: HashSet<DirectHint>,
    /// Relay servers, in the order they were given.
    pub relay: Vec<RelayHint>,
}

impl Hints {
    /// Builds hints from any iterables of direct and relay hints.
    ///
    /// Duplicate direct hints collapse into one; relay hints keep their order.
    pub fn new(
        direct_tcp: impl IntoIterator<Item = DirectHint>,
        relay: impl IntoIterator<Item = RelayHint>,
    ) -> Self {
        let direct_tcp = HashSet::from_iter(direct_tcp);
        let relay = Vec::from_iter(relay);
        Self { direct_tcp, relay }
    }
}
impl<'de> serde::Deserialize<'de> for Hints {
    /// Deserializes a flat list of tagged hints into direct and relay hints.
    ///
    /// Direct hints are collected into a set. Relay hints keep their wire order.
    /// Hint types we do not understand are skipped rather than rejected, so newer
    /// peers can advertise additional hint kinds.
    fn deserialize<D>(de: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let hints: Vec<HintSerde> = serde::Deserialize::deserialize(de)?;
        let mut direct_tcp = HashSet::new();
        let mut relay = Vec::new();
        for hint in hints {
            match hint {
                HintSerde::DirectTcpV1(hint) => {
                    direct_tcp.insert(hint);
                },
                HintSerde::RelayV1(hint) => relay.push(hint),
                HintSerde::Unknown => {},
            }
        }
        Ok(Hints { direct_tcp, relay })
    }
}
impl serde::Serialize for Hints {
    /// Emits all hints as one flat, `type`-tagged sequence.
    ///
    /// Direct TCP hints come first, followed by relay hints.
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let tagged_direct = self
            .direct_tcp
            .iter()
            .map(|hint| HintSerde::DirectTcpV1(hint.clone()));
        let tagged_relay = self
            .relay
            .iter()
            .map(|hint| HintSerde::RelayV1(hint.clone()));
        ser.collect_seq(tagged_direct.chain(tagged_relay))
    }
}
/// A `hostname:port` pair at which a peer (or relay) may be reachable over TCP.
///
/// Displays as `tcp://{hostname}:{port}`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)]
#[display("tcp://{}:{}", hostname, port)]
pub struct DirectHint {
    /// Host name or textual IP address.
    pub hostname: String,
    /// TCP port.
    pub port: u16,
}

impl DirectHint {
    /// Creates a hint from any string-like hostname and a port.
    pub fn new(hostname: impl Into<String>, port: u16) -> Self {
        let hostname: String = hostname.into();
        Self { hostname, port }
    }
}
/// Wire representation of a `relay-v1` hint.
///
/// It carries an optional relay name and its endpoints, which are stored under the
/// key `hints`. In this file it is only used for deserialization; serialization of
/// [`RelayHint`] builds its JSON by hand.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
struct RelayHintSerde {
    name: Option<String>,
    #[serde(rename = "hints")]
    endpoints: Vec<RelayHintSerdeInner>,
}
/// One endpoint of a relay hint, tagged by `type`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum RelayHintSerdeInner {
    /// Plain TCP endpoint (`direct-tcp-v1`).
    #[serde(rename = "direct-tcp-v1")]
    Tcp(DirectHint),
    /// WebSocket endpoint (`websocket`).
    Websocket {
        url: url::Url,
    },
    /// Any endpoint type we do not understand.
    #[serde(other)]
    Unknown,
}
/// Errors returned by [`RelayHint::from_urls`] for malformed relay URLs.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RelayHintParseError {
    /// A `tcp://` URL lacks a host or an explicit port.
    #[error(
        "Invalid TCP hint endpoint: '{}' (Does it have hostname and port?)",
        _0
    )]
    InvalidTcp(url::Url),
    /// The URL scheme is not one of `tcp`, `ws` or `wss`.
    #[error(
        "Unknown schema: '{}'. Currently known values are 'tcp', 'ws' and 'wss'.",
        _0
    )]
    UnknownSchema(Box<str>),
    /// The URL cannot be a base URL (its path does not start with `/`).
    #[error("'{}' is not an absolute URL (must start with a '/')", _0)]
    UrlNotAbsolute(url::Url),
}
/// A transit relay server, reachable via one or more endpoints.
///
/// Two hints sharing at least one endpoint are considered the same relay and can be
/// merged.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct RelayHint {
    /// Optional human-readable name of the relay.
    pub name: Option<String>,
    /// Plain TCP endpoints.
    pub tcp: HashSet<DirectHint>,
    /// WebSocket (`ws://` / `wss://`) endpoints.
    pub ws: HashSet<url::Url>,
}

impl RelayHint {
    /// Creates a relay hint from explicit TCP and WebSocket endpoints.
    pub fn new(
        name: Option<String>,
        tcp: impl IntoIterator<Item = DirectHint>,
        ws: impl IntoIterator<Item = url::Url>,
    ) -> Self {
        Self {
            name,
            tcp: tcp.into_iter().collect(),
            ws: ws.into_iter().collect(),
        }
    }

    /// Builds a relay hint from a list of endpoint URLs.
    ///
    /// `tcp://host:port` URLs become TCP endpoints; `ws://` and `wss://` URLs become
    /// WebSocket endpoints.
    ///
    /// # Errors
    ///
    /// - [`RelayHintParseError::UrlNotAbsolute`] if a URL cannot be a base.
    /// - [`RelayHintParseError::InvalidTcp`] if a `tcp` URL lacks a host or an explicit port.
    /// - [`RelayHintParseError::UnknownSchema`] for any other scheme.
    ///
    /// # Panics
    ///
    /// Panics if `urls` yields no URLs at all.
    pub fn from_urls(
        name: Option<String>,
        urls: impl IntoIterator<Item = url::Url>,
    ) -> Result<Self, RelayHintParseError> {
        let mut this = Self {
            name,
            ..Self::default()
        };
        for url in urls.into_iter() {
            ensure!(
                !url.cannot_be_a_base(),
                RelayHintParseError::UrlNotAbsolute(url)
            );
            match url.scheme() {
                "tcp" => {
                    // `tcp` has no default port, so `port()` is only `Some` if one was given.
                    let (hostname, port) = match (url.host_str(), url.port()) {
                        (Some(hostname), Some(port)) => (hostname.into(), port),
                        _ => bail!(RelayHintParseError::InvalidTcp(url)),
                    };
                    this.tcp.insert(DirectHint { hostname, port });
                },
                "ws" | "wss" => {
                    this.ws.insert(url);
                },
                other => bail!(RelayHintParseError::UnknownSchema(other.into())),
            }
        }
        assert!(
            !this.tcp.is_empty() || !this.ws.is_empty(),
            "No URLs provided"
        );
        Ok(this)
    }

    /// Whether `self` and `other` share at least one TCP or WebSocket endpoint.
    pub(crate) fn can_merge(&self, other: &Self) -> bool {
        !self.tcp.is_disjoint(&other.tcp) || !self.ws.is_disjoint(&other.ws)
    }

    /// Adds all endpoints of `other` to `self`.
    ///
    /// If `self` has no name, it adopts `other`'s name. This way merging an unnamed
    /// hint with a named one for the same relay does not lose the name.
    pub(crate) fn merge_mut(&mut self, other: Self) {
        if self.name.is_none() {
            self.name = other.name;
        }
        self.tcp.extend(other.tcp);
        self.ws.extend(other.ws);
    }

    /// Merges `self` into the first entry of `collection` that shares an endpoint with
    /// it, or appends it if there is none.
    pub(crate) fn merge_into(self, collection: &mut Vec<RelayHint>) {
        for item in collection.iter_mut() {
            if item.can_merge(&self) {
                item.merge_mut(self);
                return;
            }
        }
        collection.push(self);
    }
}
impl serde::Serialize for RelayHint {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut hints = Vec::new();
hints.extend(self.tcp.iter().cloned().map(RelayHintSerdeInner::Tcp));
hints.extend(
self.ws
.iter()
.cloned()
.map(|h| RelayHintSerdeInner::Websocket { url: h }),
);
serde_json::json!({
"name": self.name,
"hints": hints,
})
.serialize(ser)
}
}
impl<'de> serde::Deserialize<'de> for RelayHint {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RelayHintSerde::deserialize(de)?;
let mut hint = RelayHint {
name: raw.name,
tcp: HashSet::new(),
ws: HashSet::new(),
};
for e in raw.endpoints {
match e {
RelayHintSerdeInner::Tcp(tcp) => {
hint.tcp.insert(tcp);
},
RelayHintSerdeInner::Websocket { url } => {
hint.ws.insert(url);
},
_ => {},
}
}
Ok(hint)
}
}
impl TryFrom<&DirectHint> for IpAddr {
type Error = std::net::AddrParseError;
fn try_from(hint: &DirectHint) -> Result<IpAddr, std::net::AddrParseError> {
hint.hostname.parse()
}
}
impl TryFrom<&DirectHint> for SocketAddr {
type Error = std::net::AddrParseError;
fn try_from(hint: &DirectHint) -> Result<SocketAddr, std::net::AddrParseError> {
let addr = hint.try_into()?;
let addr = match addr {
IpAddr::V4(v4) => IpAddr::V6(v4.to_ipv6_mapped()),
IpAddr::V6(_) => addr,
};
Ok(SocketAddr::new(addr, hint.port))
}
}
/// How a transit connection reached the peer.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ConnectionType {
    /// Directly over TCP.
    Direct,
    /// Through a relay server, optionally identified by name.
    Relay {
        name: Option<String>,
    },
}

impl std::fmt::Display for ConnectionType {
    /// Renders an adverbial phrase such as `directly` or `via relay (name)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Direct => f.write_str("directly"),
            Self::Relay { name } => match name {
                Some(name) => write!(f, "via relay ({name})"),
                None => f.write_str("via relay"),
            },
        }
    }
}
/// Information about an established transit connection.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct TransitInfo {
    /// Whether the connection is direct or relayed.
    pub conn_type: ConnectionType,
    /// Address of the remote socket endpoint.
    ///
    /// NOTE(review): for relayed connections this is presumably the relay server, not the peer.
    #[cfg(not(target_family = "wasm"))]
    pub peer_addr: SocketAddr,
}
/// A connected (not yet handshaken) transport together with how it was reached.
type TransitConnection = (Box<dyn TransitTransport>, TransitInfo);
/// Errors while discovering our external address via STUN.
#[cfg(not(target_family = "wasm"))]
#[derive(Debug, thiserror::Error)]
enum StunError {
    /// The STUN server has no IPv4 address.
    #[error("No IPv4 addresses were found for the selected STUN server")]
    ServerIsV6Only,
    /// The server's reply did not contain our address.
    #[error("Server did not tell us our IP address")]
    ServerNoResponse,
    /// The STUN exchange did not finish in time.
    #[error("Connection timed out")]
    Timeout,
    /// An underlying I/O operation failed.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// A STUN packet could not be encoded or decoded.
    #[error("Malformed STUN packet")]
    Codec(
        #[from]
        #[source]
        bytecodec::Error,
    ),
}
/// Human-readable summary of the connection, including the remote socket address.
#[cfg(not(target_family = "wasm"))]
impl std::fmt::Display for TransitInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let peer = &self.peer_addr;
        match self.conn_type {
            ConnectionType::Direct => {
                write!(f, "Established direct transit connection to '{}'", peer)
            },
            ConnectionType::Relay {
                name: Some(ref relay_name),
            } => write!(
                f,
                "Established transit connection via relay '{}' ({})",
                relay_name, peer
            ),
            ConnectionType::Relay { name: None } => {
                write!(f, "Established transit connection via relay ({})", peer)
            },
        }
    }
}
/// Human-readable summary of the connection (wasm has no socket address to show).
#[cfg(target_family = "wasm")]
impl std::fmt::Display for TransitInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let ConnectionType::Relay { name } = &self.conn_type {
            match name {
                Some(relay_name) => {
                    write!(f, "Established transit connection via relay '{}'", relay_name)
                },
                None => write!(f, "Established transit connection via relay"),
            }
        } else {
            write!(f, "Established direct transit connection")
        }
    }
}
/// Prepares a [`TransitConnector`] by gathering our own connection hints.
///
/// `abilities` is first intersected with `peer_abilities` when those are already known.
///
/// If direct connections remain enabled (non-wasm only), this function:
/// - asks a STUN server for our external address, with a 4 s timeout;
/// - reserves a local socket for outgoing connections and a listener for incoming ones;
/// - advertises a direct hint for every non-loopback interface address on both ports,
///   plus the STUN-reported address if available.
///
/// `relay_hints` are advertised unchanged when relaying is enabled.
///
/// Failing to set up the direct-connection sockets is logged and tolerated; the
/// connector then simply has no sockets.
///
/// # Errors
///
/// Although the signature allows `std::io::Error`, this body currently never returns
/// one: socket-setup errors are logged and swallowed.
pub async fn init(
    mut abilities: Abilities,
    peer_abilities: Option<Abilities>,
    relay_hints: Vec<RelayHint>,
) -> Result<TransitConnector, std::io::Error> {
    let mut our_hints = Hints::default();
    #[cfg(not(target_family = "wasm"))]
    let mut sockets = None;
    // Only advertise what both sides support, if we already know the peer's abilities.
    if let Some(peer_abilities) = peer_abilities {
        abilities = abilities.intersect(&peer_abilities);
    }
    #[cfg(not(target_family = "wasm"))]
    if abilities.can_direct() {
        let create_sockets = async {
            // Prefer the stream used for the STUN query (its local port is what the
            // outside world saw); otherwise fall back to a fresh IPv6 socket bound to
            // an ephemeral port.
            let socket: MaybeConnectedSocket = match crate::util::timeout(
                std::time::Duration::from_secs(4),
                transport::tcp_get_external_ip(),
            )
            .await
            .map_err(|_| StunError::Timeout)
            {
                Ok(Ok((external_ip, stream))) => {
                    tracing::debug!("Our external IP address is {}", external_ip);
                    our_hints.direct_tcp.insert(DirectHint {
                        hostname: external_ip.ip().to_string(),
                        port: external_ip.port(),
                    });
                    tracing::debug!(
                        "Our socket for connecting is bound to {} and connected to {}",
                        stream.local_addr()?,
                        stream.peer_addr()?,
                    );
                    stream.into()
                },
                Err(err) | Ok(Err(err)) => {
                    tracing::warn!("Failed to get external address via STUN, {}", err);
                    let socket =
                        socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
                    transport::set_socket_opts(&socket)?;
                    socket.bind(&"[::]:0".parse::<SocketAddr>().unwrap().into())?;
                    tracing::debug!(
                        "Our socket for connecting is bound to {}",
                        socket.local_addr()?.as_socket().unwrap(),
                    );
                    socket.into()
                },
            };
            let listener = TcpListener::bind("[::]:0").await?;
            let port = socket.local_addr()?.as_socket().unwrap().port();
            let port2 = listener.local_addr()?.port();
            // Advertise every non-loopback interface address on both the connecting
            // and the listening port.
            our_hints.direct_tcp.extend(
                if_addrs::get_if_addrs()?
                    .iter()
                    .filter(|iface| !iface.is_loopback())
                    .flat_map(|ip| {
                        [
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port,
                            },
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port: port2,
                            },
                        ]
                        .into_iter()
                    }),
            );
            tracing::debug!("Our socket for listening is {}", listener.local_addr()?);
            Ok::<_, std::io::Error>((socket, listener))
        };
        // Best effort: without sockets we can still use relays.
        sockets = create_sockets
            .await
            .map_err(|err| {
                tracing::error!("Failed to create direct hints for our side: {}", err);
                err
            })
            .ok();
    }
    if abilities.can_relay() {
        our_hints.relay.extend(relay_hints);
    }
    Ok(TransitConnector {
        #[cfg(not(target_family = "wasm"))]
        sockets,
        our_abilities: abilities,
        our_hints: Arc::new(our_hints),
    })
}
/// The local socket reserved for outgoing direct connections.
///
/// It is either a bare bound socket (STUN failed) or the TCP stream that was used for
/// the STUN query.
#[cfg(not(target_family = "wasm"))]
#[derive(derive_more::From)]
enum MaybeConnectedSocket {
    #[from]
    Socket(socket2::Socket),
    #[from]
    Stream(TcpStream),
}
#[cfg(not(target_family = "wasm"))]
impl MaybeConnectedSocket {
    /// Returns the local address the underlying socket is bound to.
    fn local_addr(&self) -> std::io::Result<socket2::SockAddr> {
        match self {
            Self::Socket(raw) => raw.local_addr(),
            Self::Stream(tcp) => tcp.local_addr().map(socket2::SockAddr::from),
        }
    }
}
/// Which side of the transit protocol we play.
///
/// The leader may briefly wait for a direct connection to replace a relayed one before
/// finalizing. The follower uses the first connection whose handshake succeeds.
#[derive(Clone, Debug)]
pub enum TransitRole {
    Leader,
    Follower,
}
/// Our advertised abilities and hints, plus (non-wasm) the pre-created sockets needed to
/// establish a transit connection. Created by [`init`].
pub struct TransitConnector {
    /// Socket whose local address is reused for outgoing direct connections, and the
    /// listener for incoming ones. `None` if direct connections are disabled or could
    /// not be set up.
    #[cfg(not(target_family = "wasm"))]
    sockets: Option<(MaybeConnectedSocket, TcpListener)>,
    /// Abilities we advertise (already intersected with the peer's, if known at init).
    our_abilities: Abilities,
    /// Hints we advertise.
    our_hints: Arc<Hints>,
}
impl TransitConnector {
    /// Returns the abilities we advertise to the peer.
    ///
    /// If the peer's abilities were passed to [`init`], this is already the intersection
    /// of both sides.
    pub fn our_abilities(&self) -> &Abilities {
        &self.our_abilities
    }

    /// Returns the connection hints we advertise to the peer.
    pub fn our_hints(&self) -> &Arc<Hints> {
        &self.our_hints
    }

    /// Establishes an encrypted transit connection with the peer, consuming the connector.
    ///
    /// Dispatches to the leader or follower strategy according to `role`.
    ///
    /// # Errors
    ///
    /// Returns [`TransitConnectError::Handshake`] in any of these cases:
    /// - no candidate connection completed its handshake within 60 seconds;
    /// - every candidate failed;
    /// - finalizing the crypto handshake on the chosen connection failed.
    pub async fn connect(
        self,
        role: TransitRole,
        transit_key: Key<TransitKey>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
    ) -> Result<(Transit, TransitInfo), TransitConnectError> {
        match role {
            TransitRole::Leader => {
                self.leader_connect(transit_key, their_abilities, their_hints)
                    .await
            },
            TransitRole::Follower => {
                self.follower_connect(transit_key, their_abilities, their_hints)
                    .await
            },
        }
    }

    /// Leader strategy: take the first successful handshake.
    ///
    /// If that connection is relayed and we can do direct connections, keep waiting a
    /// little for a direct connection to replace it. The extra wait is 30 % of the time
    /// spent so far, or 1 s once more than 5 s have passed. The chosen connection is
    /// then finalized.
    async fn leader_connect(
        self,
        transit_key: Key<TransitKey>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
    ) -> Result<(Transit, TransitInfo), TransitConnectError> {
        let Self {
            #[cfg(not(target_family = "wasm"))]
            sockets,
            our_abilities,
            our_hints,
        } = self;
        let transit_key = Arc::new(transit_key);
        let start = Instant::now();
        // Failed attempts are only logged; the stream yields successful handshakes only.
        let mut connection_stream = Box::pin(
            Self::connect_inner(
                true,
                transit_key,
                our_abilities,
                our_hints,
                their_abilities,
                their_hints,
                #[cfg(not(target_family = "wasm"))]
                sockets,
            )
            .filter_map(|result| async {
                match result {
                    Ok(val) => Some(val),
                    Err(err) => {
                        tracing::debug!("Some leader handshake failed: {:?}", err);
                        None
                    },
                }
            }),
        );

        let (mut transit, mut finalizer, mut conn_info) =
            crate::util::timeout(std::time::Duration::from_secs(60), connection_stream.next())
                .await
                .map_err(|_| {
                    tracing::debug!("`leader_connect` timed out");
                    TransitConnectError::Handshake
                })?
                .ok_or_else(|| {
                    tracing::debug!("`leader_connect`: all handshake attempts failed");
                    TransitConnectError::Handshake
                })?;

        if conn_info.conn_type == ConnectionType::Direct {
            tracing::debug!("Established direct transit connection");
        } else if our_abilities.can_direct() {
            tracing::debug!(
                "Established transit connection over relay. Trying to find a direct connection …"
            );
            let elapsed = start.elapsed();
            let to_wait = if elapsed.as_secs() > 5 {
                std::time::Duration::from_secs(1)
            } else {
                elapsed.mul_f32(0.3)
            };
            // Resolves to `true` only if a direct connection replaced the relayed one.
            let found_direct = crate::util::timeout(to_wait, async {
                while let Some((new_transit, new_finalizer, new_conn_info)) =
                    connection_stream.next().await
                {
                    if new_conn_info.conn_type == ConnectionType::Direct {
                        transit = new_transit;
                        finalizer = new_finalizer;
                        conn_info = new_conn_info;
                        return true;
                    }
                }
                false
            })
            .await
            .unwrap_or(false);
            if found_direct {
                tracing::debug!("Found direct connection; using that instead.");
            } else {
                tracing::debug!("Did not manage to establish a better connection in time.");
            }
        } else {
            tracing::debug!("Established transit connection via relay");
        }

        // Drop the remaining candidates so they stop connecting and release their sockets.
        std::mem::drop(connection_stream);
        let (tx, rx) = finalizer
            .handshake_finalize(&mut transit)
            .await
            .map_err(|e| {
                tracing::debug!("`handshake_finalize` failed: {e}");
                TransitConnectError::Handshake
            })?;
        Ok((
            Transit {
                socket: transit,
                tx,
                rx,
            },
            conn_info,
        ))
    }

    /// Follower strategy: finalize the first connection whose handshake succeeds.
    async fn follower_connect(
        self,
        transit_key: Key<TransitKey>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
    ) -> Result<(Transit, TransitInfo), TransitConnectError> {
        let Self {
            #[cfg(not(target_family = "wasm"))]
            sockets,
            our_abilities,
            our_hints,
        } = self;
        let transit_key = Arc::new(transit_key);
        // Failed attempts are only logged; the stream yields successful handshakes only.
        let mut connection_stream = Box::pin(
            Self::connect_inner(
                false,
                transit_key,
                our_abilities,
                our_hints,
                their_abilities,
                their_hints,
                #[cfg(not(target_family = "wasm"))]
                sockets,
            )
            .filter_map(|result| async {
                match result {
                    Ok(val) => Some(val),
                    Err(err) => {
                        tracing::debug!("Some follower handshake failed: {:?}", err);
                        None
                    },
                }
            }),
        );

        let transit = match crate::util::timeout(
            std::time::Duration::from_secs(60),
            &mut connection_stream.next(),
        )
        .await
        {
            Ok(Some((mut socket, finalizer, conn_info))) => {
                let (tx, rx) = finalizer
                    .handshake_finalize(&mut socket)
                    .await
                    .map_err(|e| {
                        tracing::debug!("`handshake_finalize` failed: {e}");
                        TransitConnectError::Handshake
                    })?;
                Ok((Transit { socket, tx, rx }, conn_info))
            },
            Ok(None) => {
                tracing::debug!("`follower_connect`: all handshake attempts failed");
                Err(TransitConnectError::Handshake)
            },
            Err(_) => {
                tracing::debug!("`follower_connect` timed out");
                Err(TransitConnectError::Handshake)
            },
        };
        // Drop the remaining candidates so they stop connecting and release their sockets.
        std::mem::drop(connection_stream);
        transit
    }

    /// Builds a stream of handshake attempts over all candidate connections.
    ///
    /// Results are yielded in completion order. The candidates are:
    /// - direct TCP connections to up to 50 of the peer's direct hints;
    /// - relay connections to up to 3 endpoints of each relay. The relays come from
    ///   up to 2 of our relay hints and 2 of theirs, merged when they share an
    ///   endpoint. Endpoint `i` of a relay starts after `i * 5` seconds. Relays use
    ///   TCP natively and WebSocket on wasm;
    /// - (non-wasm) incoming connections on our listener.
    fn connect_inner(
        is_leader: bool,
        transit_key: Arc<Key<TransitKey>>,
        our_abilities: Abilities,
        our_hints: Arc<Hints>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
        #[cfg(not(target_family = "wasm"))] sockets: Option<(MaybeConnectedSocket, TcpListener)>,
    ) -> impl Stream<Item = Result<HandshakeResult, TransitHandshakeError>> + 'static {
        // `init` only creates sockets when direct connections are enabled.
        #[cfg(not(target_family = "wasm"))]
        assert!(sockets.is_none() || our_abilities.can_direct());

        // Noise is only used if both sides support it; otherwise fall back to secretbox.
        let cryptor = if our_abilities.can_noise_crypto() && their_abilities.can_noise_crypto() {
            tracing::debug!("Using noise protocol for encryption");
            Arc::new(crypto::NoiseInit {
                key: transit_key.clone(),
            }) as Arc<dyn crypto::TransitCryptoInit>
        } else {
            tracing::debug!("Using secretbox for encryption");
            Arc::new(crypto::SecretboxInit {
                key: transit_key.clone(),
            }) as Arc<dyn crypto::TransitCryptoInit>
        };

        // Random side identifier that is sent to relay servers.
        let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>()));

        #[cfg(not(target_family = "wasm"))]
        use futures::future::BoxFuture;
        #[cfg(target_family = "wasm")]
        use futures::future::LocalBoxFuture as BoxFuture;
        type BoxIterator<T> = Box<dyn Iterator<Item = T>>;
        type ConnectorFuture = BoxFuture<'static, Result<TransitConnection, TransitHandshakeError>>;
        let mut connectors: BoxIterator<ConnectorFuture> = Box::new(std::iter::empty());

        #[cfg(not(target_family = "wasm"))]
        let (socket, listener) = sockets.unzip();

        // Outgoing direct TCP connections, reusing our reserved local address.
        #[cfg(not(target_family = "wasm"))]
        if our_abilities.can_direct() && their_abilities.can_direct() {
            let local_addr = socket.map(|socket| {
                Arc::new(
                    socket
                        .local_addr()
                        .expect("This is guaranteed to be an IP socket"),
                )
            });
            connectors = Box::new(
                connectors.chain(
                    their_hints
                        .direct_tcp
                        .clone()
                        .into_iter()
                        .take(50)
                        .map(move |hint| transport::connect_tcp_direct(local_addr.clone(), hint))
                        .map(|fut| Box::pin(fut) as ConnectorFuture),
                ),
            ) as BoxIterator<ConnectorFuture>;
        }

        if our_abilities.can_relay() && their_abilities.can_relay() {
            // At most two relay hints from each side; hints sharing an endpoint are merged.
            let mut relay_hints = Vec::<RelayHint>::new();
            relay_hints.extend(our_hints.relay.iter().take(2).cloned());
            for hint in their_hints.relay.iter().take(2).cloned() {
                hint.merge_into(&mut relay_hints);
            }

            #[cfg(not(target_family = "wasm"))]
            {
                connectors = Box::new(
                    connectors.chain(
                        relay_hints
                            .into_iter()
                            .flat_map(|hint| {
                                // Fall back to the first domain-name endpoint as the relay's name.
                                let name = hint.name.or_else(|| {
                                    hint.tcp
                                        .iter()
                                        .filter_map(|hint| match url::Host::parse(&hint.hostname) {
                                            Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
                                            _ => None,
                                        })
                                        .next()
                                });
                                hint.tcp
                                    .into_iter()
                                    .take(3)
                                    .enumerate()
                                    .map(move |(i, h)| (i, h, name.clone()))
                            })
                            .map(|(index, host, name)| async move {
                                // Stagger endpoints of the same relay by 5 s each.
                                async_io::Timer::after(std::time::Duration::from_secs(
                                    index as u64 * 5,
                                ))
                                .await;
                                transport::connect_tcp_relay(host, name).await
                            })
                            .map(|fut| Box::pin(fut) as ConnectorFuture),
                    ),
                ) as BoxIterator<ConnectorFuture>;
            }

            #[cfg(target_family = "wasm")]
            {
                connectors = Box::new(
                    connectors.chain(
                        relay_hints
                            .into_iter()
                            .flat_map(|hint| {
                                // Fall back to the first domain-name endpoint as the relay's name.
                                let name = hint.name.or_else(|| {
                                    hint.tcp
                                        .iter()
                                        .filter_map(|hint| match url::Host::parse(&hint.hostname) {
                                            Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
                                            _ => None,
                                        })
                                        .next()
                                });
                                hint.ws
                                    .into_iter()
                                    .take(3)
                                    .enumerate()
                                    .map(move |(i, u)| (i, u, name.clone()))
                            })
                            .map(|(index, url, name)| async move {
                                // Stagger endpoints of the same relay by 5 s each.
                                crate::util::sleep(std::time::Duration::from_secs(
                                    index as u64 * 5,
                                ))
                                .await;
                                transport::connect_ws_relay(url, name).await
                            })
                            .map(|fut| Box::pin(fut) as ConnectorFuture),
                    ),
                ) as BoxIterator<ConnectorFuture>;
            }
        }

        // Run the transit handshake on top of every outgoing connection.
        let transit_key2 = transit_key.clone();
        let tside2 = tside.clone();
        let cryptor2 = cryptor.clone();
        #[allow(unused_mut)] let mut connectors = Box::new(
            connectors
                .map(move |fut| {
                    let transit_key = transit_key2.clone();
                    let tside = tside2.clone();
                    let cryptor = cryptor2.clone();
                    async move {
                        let (socket, conn_info) = fut.await?;
                        let (transit, finalizer) = handshake_exchange(
                            is_leader,
                            tside,
                            socket,
                            &conn_info.conn_type,
                            &*cryptor,
                            transit_key,
                        )
                        .await?;
                        Ok((transit, finalizer, conn_info))
                    }
                })
                .map(|fut| {
                    Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
                }),
        )
            as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;

        // Accept incoming direct connections until one completes the handshake.
        #[cfg(not(target_family = "wasm"))]
        if let Some(listener) = listener {
            connectors = Box::new(
                connectors.chain(
                    std::iter::once(async move {
                        let transit_key = transit_key.clone();
                        let tside = tside.clone();
                        let cryptor = cryptor.clone();
                        let connect = || async {
                            let (socket, peer) = listener.accept().await?;
                            let (socket, info) =
                                transport::wrap_tcp_connection(socket, ConnectionType::Direct)?;
                            tracing::debug!("Got connection from {}!", peer);
                            let (transit, finalizer) = handshake_exchange(
                                is_leader,
                                tside.clone(),
                                socket,
                                &ConnectionType::Direct,
                                &*cryptor,
                                transit_key.clone(),
                            )
                            .await?;
                            Result::<_, TransitHandshakeError>::Ok((transit, finalizer, info))
                        };
                        // NOTE(review): an error from `accept()` itself is retried immediately
                        // with no backoff; a persistent accept error would spin here until the
                        // whole stream is dropped.
                        loop {
                            match connect().await {
                                Ok(success) => break Ok(success),
                                Err(err) => {
                                    tracing::debug!(
                                        "Some handshake failed on the listening port: {:?}",
                                        err
                                    );
                                    continue;
                                },
                            }
                        }
                    })
                    .map(|fut| {
                        Box::pin(fut) as BoxFuture<Result<HandshakeResult, TransitHandshakeError>>
                    }),
                ),
            )
                as BoxIterator<BoxFuture<Result<HandshakeResult, TransitHandshakeError>>>;
        }

        connectors.collect::<futures::stream::futures_unordered::FuturesUnordered<_>>()
    }
}
/// An established, encrypted transit connection that exchanges discrete records.
pub struct Transit {
    /// Underlying byte transport.
    ///
    /// In this file it is produced by direct TCP, relayed TCP, or (wasm) WebSocket
    /// connectors.
    socket: Box<dyn TransitTransport>,
    /// Encrypts outgoing records.
    tx: Box<dyn crypto::TransitCryptoEncrypt>,
    /// Decrypts incoming records.
    rx: Box<dyn crypto::TransitCryptoDecrypt>,
}
impl Transit {
pub async fn receive_record(&mut self) -> Result<Box<[u8]>, TransitError> {
self.rx.decrypt(&mut self.socket).await
}
pub async fn send_record(&mut self, plaintext: &[u8]) -> Result<(), TransitError> {
assert!(!plaintext.is_empty());
self.tx.encrypt(&mut self.socket, plaintext).await
}
pub async fn flush(&mut self) -> Result<(), TransitError> {
tracing::debug!("Flush");
self.socket.flush().await.map_err(Into::into)
}
#[cfg(not(target_family = "wasm"))]
#[expect(clippy::type_complexity)]
pub fn split(
self,
) -> (
impl futures::sink::Sink<Box<[u8]>, Error = TransitError>,
impl futures_lite::stream::Stream<Item = Result<Box<[u8]>, TransitError>>,
) {
let (reader, writer) = self.socket.split();
(
futures::sink::unfold(
(writer, self.tx),
|(mut writer, mut tx), plaintext: Box<[u8]>| async move {
tx.encrypt(&mut writer, &plaintext)
.await
.map(|()| (writer, tx))
},
),
futures::stream::try_unfold((reader, self.rx), |(mut reader, mut rx)| async move {
rx.decrypt(&mut reader)
.await
.map(|record| Some((record, (reader, rx))))
}),
)
}
}
/// A transport whose initial handshake succeeded, the finalizer that completes the
/// crypto handshake, and information about how the transport was reached.
type HandshakeResult = (
    Box<dyn TransitTransport>,
    Box<dyn crypto::TransitCryptoInitFinalizer>,
    TransitInfo,
);
/// Performs the initial handshake on a freshly connected transport.
///
/// For relayed connections, the relay is first asked to pair us with the peer: we send
/// `please relay <token> for side <tside>` and expect `ok\n` back. Then the
/// leader/follower crypto handshake runs. The transport and the finalizer that
/// completes the handshake are returned.
async fn handshake_exchange(
    is_leader: bool,
    tside: Arc<String>,
    mut socket: Box<dyn TransitTransport>,
    host_type: &ConnectionType,
    cryptor: &dyn crypto::TransitCryptoInit,
    key: Arc<Key<TransitKey>>,
) -> Result<
    (
        Box<dyn TransitTransport>,
        Box<dyn crypto::TransitCryptoInitFinalizer>,
    ),
    TransitHandshakeError,
> {
    let is_relayed = *host_type != ConnectionType::Direct;
    if is_relayed {
        tracing::trace!("initiating relay handshake");
        let relay_token = key.derive_subkey_from_purpose::<GenericKey>("transit_relay_token");
        let request = format!("please relay {} for side {}\n", relay_token.to_hex(), tside);
        socket.write_all(request.as_bytes()).await?;
        let mut reply = [0u8; 3];
        socket.read_exact(&mut reply).await?;
        ensure!(
            reply == *b"ok\n",
            TransitHandshakeError::RelayHandshakeFailed
        );
    }
    let finalizer = if is_leader {
        cryptor.handshake_leader(&mut socket).await?
    } else {
        cryptor.handshake_follower(&mut socket).await?
    };
    Ok((socket, finalizer))
}
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// Every ability combination encodes to the expected list of `type` objects.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_abilities_encoding() {
        assert_eq!(
            serde_json::to_value(Abilities::ALL).unwrap(),
            json!([{"type": "direct-tcp-v1"}, {"type": "relay-v1"}])
        );
        assert_eq!(
            serde_json::to_value(Abilities::FORCE_DIRECT).unwrap(),
            json!([{"type": "direct-tcp-v1"}])
        );
        assert_eq!(
            serde_json::to_value(Abilities::FORCE_RELAY).unwrap(),
            json!([{"type": "relay-v1"}])
        );
        assert_eq!(serde_json::to_value(Abilities::default()).unwrap(), json!([]));
    }

    /// Known abilities are enabled, and unknown ones are ignored instead of rejected.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_abilities_decoding() {
        let abilities: Abilities = serde_json::from_value(json!([
            {"type": "relay-v1"},
            {"type": "some-future-ability"},
        ]))
        .unwrap();
        assert!(abilities.can_relay());
        assert!(!abilities.can_direct());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_hints_encoding() {
        assert_eq!(
            serde_json::to_value(Hints::new(
                [DirectHint {
                    hostname: "localhost".into(),
                    port: 1234
                }],
                [RelayHint::new(
                    Some("default".into()),
                    [DirectHint::new("transit.magic-wormhole.io", 4001)],
                    ["ws://transit.magic-wormhole.io/relay".parse().unwrap(),],
                )]
            ))
            .unwrap(),
            json!([
                {
                    "type": "direct-tcp-v1",
                    "hostname": "localhost",
                    "port": 1234
                },
                {
                    "type": "relay-v1",
                    "name": "default",
                    "hints": [
                        {
                            "type": "direct-tcp-v1",
                            "hostname": "transit.magic-wormhole.io",
                            "port": 4001,
                        },
                        {
                            "type": "websocket",
                            "url": "ws://transit.magic-wormhole.io/relay",
                        },
                    ]
                }
            ])
        )
    }

    /// Direct hints decode into the set, and unknown hint types are skipped.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    pub fn test_hints_decoding() {
        let hints: Hints = serde_json::from_value(json!([
            {"type": "direct-tcp-v1", "hostname": "localhost", "port": 1234},
            {"type": "some-future-hint"},
        ]))
        .unwrap();
        assert_eq!(
            hints.direct_tcp,
            HashSet::from([DirectHint::new("localhost", 1234)])
        );
        assert!(hints.relay.is_empty());
    }
}