use crate::{Key, KeyPurpose};
use serde_derive::{Deserialize, Serialize};
use async_std::{
io::{prelude::WriteExt, ReadExt},
net::{TcpListener, TcpStream},
};
#[allow(unused_imports)]
use futures::{future::TryFutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use log::*;
use std::{
collections::HashSet,
net::{IpAddr, SocketAddr, ToSocketAddrs},
sync::Arc,
};
// Transport encryption used on transit connections (secretbox and noise initialisers).
mod crypto;
/// Address of the default transit relay server.
pub const DEFAULT_RELAY_SERVER: &str = "tcp://transit.magic-wormhole.io:4001";
/// STUN server queried by `get_external_ip` to learn our externally visible address.
const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478";
/// Key-purpose marker for the transit key from which transit subkeys are derived.
#[derive(Debug)]
pub struct TransitKey;
impl KeyPurpose for TransitKey {}
/// Key-purpose marker for a key used on the receiving direction of a transit connection.
#[derive(Debug)]
pub struct TransitRxKey;
impl KeyPurpose for TransitRxKey {}
/// Key-purpose marker for a key used on the sending direction of a transit connection.
#[derive(Debug)]
pub struct TransitTxKey;
impl KeyPurpose for TransitTxKey {}
/// Errors returned while establishing a transit connection
/// (`TransitConnector::leader_connect` / `follower_connect`).
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitConnectError {
    /// A protocol error; the boxed string is the message that is displayed.
    #[error("{}", _0)]
    Protocol(Box<str>),
    /// No connection attempt completed its handshake in time, or finalizing it failed.
    #[error("All (relay) handshakes failed or timed out; could not establish a connection with the peer")]
    Handshake,
    /// An underlying I/O error.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
}
/// Errors returned while sending or receiving records over an established [`Transit`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TransitError {
    /// A cryptographic operation failed.
    #[error("Cryptography error. This is probably an implementation bug, but may also be caused by an attack.")]
    Crypto,
    /// A record carried an unexpected nonce: `(received, expected)`.
    #[error("Wrong nonce received, got {:x?} but expected {:x?}. This is probably an implementation bug, but may also be caused by an attack.", _0, _1)]
    Nonce(Box<[u8]>, Box<[u8]>),
    /// An underlying I/O error.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
}
impl From<()> for TransitError {
fn from(_: ()) -> Self {
Self::Crypto
}
}
/// Transit connection features one side supports or is willing to use.
///
/// `#[cfg(all())]` is an always-true predicate, so the `noise_v1` field and its
/// accessor are always compiled in.
#[derive(Copy, Clone, Debug, Default)]
pub struct Abilities {
    /// Direct TCP connections between the peers (`direct-tcp-v1`).
    pub direct_tcp_v1: bool,
    /// Connections through a transit relay (`relay-v1`).
    pub relay_v1: bool,
    /// Noise-protocol transport encryption (`noise-crypto-v1`).
    #[cfg(all())]
    pub noise_v1: bool,
}

impl Abilities {
    /// Direct and relay connections enabled; noise encryption disabled.
    pub const ALL_ABILITIES: Self = Self {
        direct_tcp_v1: true,
        relay_v1: true,
        #[cfg(all())]
        noise_v1: false,
    };
    /// Only direct connections.
    pub const FORCE_DIRECT: Self = Self {
        direct_tcp_v1: true,
        relay_v1: false,
        #[cfg(all())]
        noise_v1: false,
    };
    /// Only relayed connections.
    pub const FORCE_RELAY: Self = Self {
        direct_tcp_v1: false,
        relay_v1: true,
        #[cfg(all())]
        noise_v1: false,
    };

    /// Whether direct TCP connections are enabled.
    pub fn can_direct(&self) -> bool {
        let Self { direct_tcp_v1, .. } = *self;
        direct_tcp_v1
    }

    /// Whether relayed connections are enabled.
    pub fn can_relay(&self) -> bool {
        let Self { relay_v1, .. } = *self;
        relay_v1
    }

    /// Whether noise-protocol encryption is enabled.
    #[cfg(all())]
    pub fn can_noise_crypto(&self) -> bool {
        let Self { noise_v1, .. } = *self;
        noise_v1
    }

    /// Keeps only the features enabled in both `self` and `other`.
    pub fn intersect(self, other: &Self) -> Self {
        Self {
            direct_tcp_v1: self.direct_tcp_v1 && other.direct_tcp_v1,
            relay_v1: self.relay_v1 && other.relay_v1,
            #[cfg(all())]
            noise_v1: self.noise_v1 && other.noise_v1,
        }
    }
}
impl serde::Serialize for Abilities {
    /// Encodes the enabled features as a JSON list of `{"type": "<name>"}` objects,
    /// in the order direct-tcp, relay, noise-crypto.
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut advertised = Vec::new();
        let mut advertise = |kind: &str| advertised.push(serde_json::json!({ "type": kind }));
        if self.direct_tcp_v1 {
            advertise("direct-tcp-v1");
        }
        if self.relay_v1 {
            advertise("relay-v1");
        }
        #[cfg(all())]
        if self.noise_v1 {
            advertise("noise-crypto-v1");
        }
        serde_json::Value::Array(advertised).serialize(ser)
    }
}
impl<'de> serde::Deserialize<'de> for Abilities {
    /// Parses a JSON list of `{"type": "<name>"}` objects.
    ///
    /// Each recognised type turns on its flag. `relay-v2` and unknown types are
    /// accepted but ignored. Flags that are not listed stay `false`.
    fn deserialize<D>(de: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(rename_all = "kebab-case", tag = "type")]
        enum Ability {
            DirectTcpV1,
            RelayV1,
            RelayV2,
            #[cfg(all())]
            NoiseCryptoV1,
            #[serde(other)]
            Other,
        }

        let listed = <Vec<Ability> as serde::Deserialize>::deserialize(de)?;
        let abilities = listed
            .into_iter()
            .fold(Self::default(), |mut acc, ability| {
                match ability {
                    Ability::DirectTcpV1 => acc.direct_tcp_v1 = true,
                    Ability::RelayV1 => acc.relay_v1 = true,
                    #[cfg(all())]
                    Ability::NoiseCryptoV1 => acc.noise_v1 = true,
                    Ability::RelayV2 | Ability::Other => {},
                }
                acc
            });
        Ok(abilities)
    }
}
/// Wire form of a single transit hint, internally tagged by a kebab-case `type` field.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum HintSerde {
    /// `"type": "direct-tcp-v1"` — a direct TCP endpoint.
    DirectTcpV1(DirectHint),
    /// `"type": "relay-v1"` — a relay server.
    RelayV1(RelayHint),
    /// Any other `type`; accepted on input so that unknown hint kinds are skipped.
    #[serde(other)]
    Unknown,
}
/// Connection hints for one side: direct TCP endpoints and relay servers.
#[derive(Clone, Debug, Default)]
pub struct Hints {
    /// Direct TCP endpoints (deduplicated).
    pub direct_tcp: HashSet<DirectHint>,
    /// Relay servers, in the order they were supplied.
    pub relay: Vec<RelayHint>,
}

impl Hints {
    /// Gathers the given direct and relay hints into a new `Hints`.
    pub fn new(
        direct_tcp: impl IntoIterator<Item = DirectHint>,
        relay: impl IntoIterator<Item = RelayHint>,
    ) -> Self {
        let mut hints = Self::default();
        hints.direct_tcp.extend(direct_tcp);
        hints.relay.extend(relay);
        hints
    }
}
impl<'de> serde::Deserialize<'de> for Hints {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let hints: Vec<HintSerde> = serde::Deserialize::deserialize(de)?;
let mut direct_tcp = HashSet::new();
let mut relay = Vec::<RelayHint>::new();
let mut relay_v2 = Vec::<RelayHint>::new();
for hint in hints {
match hint {
HintSerde::DirectTcpV1(hint) => {
direct_tcp.insert(hint);
},
HintSerde::RelayV1(hint) => {
relay_v2.push(hint);
},
_ => {},
}
}
if !relay_v2.is_empty() {
relay.clear();
}
relay.extend(relay_v2.into_iter().map(Into::into));
Ok(Hints { direct_tcp, relay })
}
}
impl serde::Serialize for Hints {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let direct = self.direct_tcp.iter().cloned().map(HintSerde::DirectTcpV1);
let relay = self.relay.iter().cloned().map(HintSerde::RelayV1);
ser.collect_seq(direct.chain(relay))
}
}
/// A TCP endpoint (`hostname`, `port`) at which a peer may be reachable.
/// Displays as `tcp://<hostname>:<port>`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)]
#[display(fmt = "tcp://{}:{}", hostname, port)]
pub struct DirectHint {
    /// Host name or textual IP address.
    pub hostname: String,
    /// TCP port.
    pub port: u16,
}

impl DirectHint {
    /// Builds a hint from anything convertible into a hostname string plus a port.
    pub fn new(hostname: impl Into<String>, port: u16) -> Self {
        let hostname: String = hostname.into();
        DirectHint { hostname, port }
    }
}
/// Wire form of a relay hint: an optional name and its endpoints (JSON key `hints`).
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
struct RelayHintSerde {
    /// Optional human-readable relay name.
    name: Option<String>,
    /// Endpoints the relay can be reached at.
    #[serde(rename = "hints")]
    endpoints: Vec<RelayHintSerdeInner>,
}
/// Wire form of one relay endpoint, internally tagged by a `type` field.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "kebab-case", tag = "type")]
#[non_exhaustive]
enum RelayHintSerdeInner {
    /// `"type": "direct-tcp-v1"` — a plain TCP endpoint.
    #[serde(rename = "direct-tcp-v1")]
    Tcp(DirectHint),
    /// `"type": "websocket"` — a WebSocket URL.
    Websocket {
        url: url::Url,
    },
    /// Any other endpoint type; accepted on input and ignored by `RelayHint`'s deserializer.
    #[serde(other)]
    Unknown,
}
/// Errors returned by `RelayHint::from_urls`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RelayHintParseError {
    /// A `tcp://` URL lacks a host or an explicit port.
    #[error(
        "Invalid TCP hint endpoint: '{}' (Does it have hostname and port?)",
        _0
    )]
    InvalidTcp(url::Url),
    /// The URL scheme is not one of `tcp`, `ws` or `wss`.
    #[error(
        "Unknown schema: '{}'. Currently known values are 'tcp', 'ws' and 'wss'.",
        _0
    )]
    UnknownSchema(Box<str>),
    /// The URL cannot be a base URL (`Url::cannot_be_a_base` returned true).
    #[error("'{}' is not an absolute URL (must start with a '/')", _0)]
    UrlNotAbsolute(url::Url),
}
/// A transit relay server together with every endpoint it can be reached at.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct RelayHint {
    /// Optional human-readable relay name (used when logging relayed connections).
    pub name: Option<String>,
    /// Plain TCP endpoints of the relay.
    pub tcp: HashSet<DirectHint>,
    /// WebSocket (`ws://` / `wss://`) endpoints of the relay.
    pub ws: HashSet<url::Url>,
}
impl RelayHint {
    /// Creates a relay hint from an optional name and its TCP and WebSocket endpoints.
    pub fn new(
        name: Option<String>,
        tcp: impl IntoIterator<Item = DirectHint>,
        ws: impl IntoIterator<Item = url::Url>,
    ) -> Self {
        Self {
            name,
            tcp: tcp.into_iter().collect(),
            ws: ws.into_iter().collect(),
        }
    }

    /// Builds a relay hint from endpoint URLs.
    ///
    /// `tcp://host:port` URLs become TCP endpoints; `ws://` and `wss://` URLs become
    /// WebSocket endpoints.
    ///
    /// # Errors
    /// - [`RelayHintParseError::UrlNotAbsolute`] if a URL cannot be a base URL.
    /// - [`RelayHintParseError::InvalidTcp`] if a `tcp` URL has no host or no explicit port.
    /// - [`RelayHintParseError::UnknownSchema`] for any other scheme.
    ///
    /// # Panics
    /// Panics if `urls` is empty. Every URL either becomes an endpoint or causes an error,
    /// so only empty input can leave the hint without endpoints.
    pub fn from_urls(
        name: Option<String>,
        urls: impl IntoIterator<Item = url::Url>,
    ) -> Result<Self, RelayHintParseError> {
        let mut this = Self {
            name,
            ..Self::default()
        };
        for url in urls.into_iter() {
            ensure!(
                !url.cannot_be_a_base(),
                RelayHintParseError::UrlNotAbsolute(url)
            );
            match url.scheme() {
                "tcp" => {
                    let (hostname, port) = match (url.host_str(), url.port()) {
                        (Some(hostname), Some(port)) => (hostname.into(), port),
                        _ => bail!(RelayHintParseError::InvalidTcp(url)),
                    };
                    this.tcp.insert(DirectHint { hostname, port });
                },
                "ws" | "wss" => {
                    this.ws.insert(url);
                },
                other => bail!(RelayHintParseError::UnknownSchema(other.into())),
            }
        }
        assert!(
            !this.tcp.is_empty() || !this.ws.is_empty(),
            "No URLs provided"
        );
        Ok(this)
    }

    /// Returns `true` if the two hints share at least one TCP or WebSocket endpoint,
    /// i.e. they presumably describe the same relay.
    pub fn can_merge(&self, other: &Self) -> bool {
        !self.tcp.is_disjoint(&other.tcp) || !self.ws.is_disjoint(&other.ws)
    }

    /// Consuming variant of [`RelayHint::merge_mut`].
    pub fn merge(mut self, other: Self) -> Self {
        self.merge_mut(other);
        self
    }

    /// Adds all of `other`'s endpoints to `self`.
    ///
    /// If `self` has no name, it adopts `other`'s name, so merging an unnamed hint with a
    /// named hint for the same relay keeps the known name. An existing name is never
    /// overwritten.
    pub fn merge_mut(&mut self, other: Self) {
        if self.name.is_none() {
            self.name = other.name;
        }
        self.tcp.extend(other.tcp);
        self.ws.extend(other.ws);
    }

    /// Merges `self` into the first entry of `collection` that shares an endpoint with it
    /// (see [`RelayHint::can_merge`]), or appends it if there is none.
    pub fn merge_into(self, collection: &mut Vec<RelayHint>) {
        for item in collection.iter_mut() {
            if item.can_merge(&self) {
                item.merge_mut(self);
                return;
            }
        }
        collection.push(self);
    }
}
impl serde::Serialize for RelayHint {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut hints = Vec::new();
hints.extend(self.tcp.iter().cloned().map(RelayHintSerdeInner::Tcp));
hints.extend(
self.ws
.iter()
.cloned()
.map(|h| RelayHintSerdeInner::Websocket { url: h }),
);
serde_json::json!({
"name": self.name,
"hints": hints,
})
.serialize(ser)
}
}
impl<'de> serde::Deserialize<'de> for RelayHint {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RelayHintSerde::deserialize(de)?;
let mut hint = RelayHint {
name: raw.name,
tcp: HashSet::new(),
ws: HashSet::new(),
};
for e in raw.endpoints {
match e {
RelayHintSerdeInner::Tcp(tcp) => {
hint.tcp.insert(tcp);
},
RelayHintSerdeInner::Websocket { url } => {
hint.ws.insert(url);
},
_ => {},
}
}
Ok(hint)
}
}
use std::convert::{TryFrom, TryInto};
/// Interprets a direct hint's hostname as a literal IP address. Names are not resolved.
impl TryFrom<&DirectHint> for IpAddr {
    type Error = std::net::AddrParseError;

    fn try_from(hint: &DirectHint) -> Result<IpAddr, std::net::AddrParseError> {
        str::parse::<IpAddr>(hint.hostname.as_str())
    }
}
impl TryFrom<&DirectHint> for SocketAddr {
type Error = std::net::AddrParseError;
fn try_from(hint: &DirectHint) -> Result<SocketAddr, std::net::AddrParseError> {
let addr = hint.try_into()?;
let addr = match addr {
IpAddr::V4(v4) => IpAddr::V6(v4.to_ipv6_mapped()),
IpAddr::V6(_) => addr,
};
Ok(SocketAddr::new(addr, hint.port))
}
}
/// How a transit connection was established.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransitInfo {
    /// A direct TCP connection between the peers.
    Direct,
    /// A connection through a relay; `name` is the relay's name if one is known.
    Relay { name: Option<String> },
}
/// A freshly connected stream and how it was reached, before any handshake has run.
type TransitConnection = (TcpStream, TransitInfo);
/// Prepares a socket for direct connections. The socket is made non-blocking and gets
/// SO_REUSEADDR. On Unix systems other than Solaris/illumos it also gets SO_REUSEPORT.
/// Address/port reuse lets several sockets bind the same local address (`connect`
/// binds each direct attempt to the prepared socket's local address).
///
/// Supported platforms are Unix (except Solaris/illumos) and Windows. Every other
/// target, including Solaris/illumos, fails to compile via `compile_error!`.
fn set_socket_opts(socket: &socket2::Socket) -> std::io::Result<()> {
    socket.set_nonblocking(true)?;
    socket.set_reuse_address(true)?;
    #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))]
    {
        socket.set_reuse_port(true)?;
    }
    #[cfg(not(any(
        all(unix, not(any(target_os = "solaris", target_os = "illumos"))),
        target_os = "windows"
    )))]
    {
        compile_error!("Your system is not supported yet, please raise an error");
    }
    Ok(())
}
/// Opens a TCP connection to `dest_addr` from a socket explicitly bound to `local_addr`.
///
/// The socket is created in the IPv6 domain, so callers pass IPv4 destinations as
/// IPv4-mapped IPv6 addresses. The debug log line unwraps `local_addr.as_socket()`,
/// so `local_addr` is expected to be an IP socket address.
async fn connect_custom(
    local_addr: &socket2::SockAddr,
    dest_addr: &socket2::SockAddr,
) -> std::io::Result<async_std::net::TcpStream> {
    log::debug!("Binding to {}", local_addr.as_socket().unwrap());
    let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
    set_socket_opts(&socket)?;
    socket.bind(local_addr)?;
    // The socket is non-blocking, so `connect` normally reports "in progress"
    // (EINPROGRESS on Unix, WouldBlock elsewhere) rather than completing immediately.
    match socket.connect(dest_addr) {
        Ok(_) => {},
        #[cfg(unix)]
        Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {},
        Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {},
        Err(err) => return Err(err),
    }
    let stream = async_io::Async::new(std::net::TcpStream::from(socket))?;
    // Writability signals that the pending connect finished. Whether it succeeded is
    // read from the socket's pending error (`take_error`).
    stream.writable().await?;
    stream
        .get_ref()
        .take_error()
        .and_then(|maybe_err| maybe_err.map_or(Ok(()), Result::Err))?;
    Ok(stream.into_inner()?.into())
}
/// Errors from `get_external_ip`.
#[derive(Debug, thiserror::Error)]
enum StunError {
    /// Resolving the STUN server name yielded no IPv4 address.
    #[error("No IPv4 addresses were found for the selected STUN server")]
    ServerIsV6Only,
    /// The response carried neither an XOR-MAPPED-ADDRESS nor a MAPPED-ADDRESS attribute.
    #[error("Server did not tell us our IP address")]
    ServerNoResponse,
    /// The STUN lookup did not finish within the caller's timeout.
    #[error("Connection timed out")]
    Timeout,
    /// An underlying I/O error.
    #[error("IO error")]
    IO(
        #[from]
        #[source]
        std::io::Error,
    ),
    /// Encoding the request or decoding the response failed.
    #[error("Malformed STUN packet")]
    Codec(
        #[from]
        #[source]
        bytecodec::Error,
    ),
}
async fn get_external_ip() -> Result<(SocketAddr, TcpStream), StunError> {
let mut socket = connect_custom(
&"[::]:0".parse::<SocketAddr>().unwrap().into(),
&PUBLIC_STUN_SERVER
.to_socket_addrs()?
.find(|x| x.is_ipv4())
.map(|addr| match addr {
SocketAddr::V4(v4) => {
SocketAddr::new(IpAddr::V6(v4.ip().to_ipv6_mapped()), v4.port())
},
SocketAddr::V6(_) => unreachable!(),
})
.ok_or(StunError::ServerIsV6Only)?
.into(),
)
.await?;
use bytecodec::{DecodeExt, EncodeExt};
use stun_codec::{
rfc5389::{
self,
attributes::{MappedAddress, Software, XorMappedAddress},
Attribute,
},
Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId,
};
fn get_binding_request() -> Result<Vec<u8>, bytecodec::Error> {
use rand::Rng;
let random_bytes = rand::thread_rng().gen::<[u8; 12]>();
let mut message = Message::<Attribute>::new(
MessageClass::Request,
rfc5389::methods::BINDING,
TransactionId::new(random_bytes),
);
message.add_attribute(Attribute::Software(Software::new(
"magic-wormhole-rust".to_owned(),
)?));
let mut encoder = MessageEncoder::new();
let bytes = encoder.encode_into_bytes(message.clone())?;
Ok(bytes)
}
fn decode_address(buf: &[u8]) -> Result<Option<SocketAddr>, bytecodec::Error> {
let mut decoder = MessageDecoder::<Attribute>::new();
let decoded = decoder.decode_from_bytes(buf)??;
let external_addr1 = decoded
.get_attribute::<XorMappedAddress>()
.map(|x| x.address());
let external_addr3 = decoded
.get_attribute::<MappedAddress>()
.map(|x| x.address());
let external_addr = external_addr1
.or(external_addr3);
Ok(external_addr)
}
socket.write_all(get_binding_request()?.as_ref()).await?;
let mut buf = [0u8; 256];
socket.read_exact(&mut buf[..20]).await?;
let len: u16 = u16::from_be_bytes([buf[2], buf[3]]);
socket.read_exact(&mut buf[20..][..len as usize]).await?;
let external_addr =
decode_address(&buf[..20 + len as usize])?.ok_or(StunError::ServerNoResponse)?;
Ok((external_addr, socket))
}
/// Logs (at info level) how a transit connection to `peer_addr` was established.
pub fn log_transit_connection(info: TransitInfo, peer_addr: SocketAddr) {
    if let TransitInfo::Relay { name } = info {
        if let Some(relay_name) = name {
            log::info!(
                "Established transit connection via relay '{}' ({})",
                relay_name,
                peer_addr
            );
        } else {
            log::info!("Established transit connection via relay ({})", peer_addr);
        }
    } else {
        log::info!("Established direct transit connection to '{}'", peer_addr);
    }
}
/// Prepares this side of a transit connection and returns a [`TransitConnector`].
///
/// - `abilities` are intersected with `peer_abilities` when the latter are known.
/// - If direct connections remain enabled, prepares the sockets for them:
///   - queries STUN (at most 4 seconds) for our external address and advertises it as a
///     direct hint, keeping the STUN-connected stream as the connecting socket;
///   - on STUN failure, uses a fresh socket bound to `[::]:0` instead;
///   - opens a listener on `[::]:0`;
///   - advertises every non-loopback interface address with both the connecting
///     socket's port and the listener's port.
/// - If relaying remains enabled, `relay_hints` are added to our hints.
///
/// Failures while preparing the direct sockets are logged and leave the connector
/// without prepared sockets. Direct hints added before the failure (e.g. the STUN
/// address) stay in our hints. As written, the function always returns `Ok`.
pub async fn init(
    mut abilities: Abilities,
    peer_abilities: Option<Abilities>,
    relay_hints: Vec<RelayHint>,
) -> Result<TransitConnector, std::io::Error> {
    let mut our_hints = Hints::default();
    let mut listener = None;
    if let Some(peer_abilities) = peer_abilities {
        abilities = abilities.intersect(&peer_abilities);
    }
    if abilities.can_direct() {
        let create_sockets = async {
            // Connecting socket: the STUN stream if STUN worked, otherwise a bare bound socket.
            let socket: MaybeConnectedSocket = match async_std::future::timeout(
                std::time::Duration::from_secs(4),
                get_external_ip(),
            )
            .await
            .map_err(|_| StunError::Timeout)
            {
                Ok(Ok((external_ip, stream))) => {
                    log::debug!("Our external IP address is {}", external_ip);
                    our_hints.direct_tcp.insert(DirectHint {
                        hostname: external_ip.ip().to_string(),
                        port: external_ip.port(),
                    });
                    log::debug!(
                        "Our socket for connecting is bound to {} and connected to {}",
                        stream.local_addr()?,
                        stream.peer_addr()?,
                    );
                    stream.into()
                },
                Err(err) | Ok(Err(err)) => {
                    log::warn!("Failed to get external address via STUN, {}", err);
                    let socket =
                        socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
                    set_socket_opts(&socket)?;
                    socket.bind(&"[::]:0".parse::<SocketAddr>().unwrap().into())?;
                    log::debug!(
                        "Our socket for connecting is bound to {}",
                        socket.local_addr()?.as_socket().unwrap(),
                    );
                    socket.into()
                },
            };
            // Listening socket for incoming direct connections.
            let socket2 = TcpListener::bind("[::]:0").await?;
            let port = socket.local_addr()?.as_socket().unwrap().port();
            let port2 = socket2.local_addr()?.port();
            // Advertise each non-loopback interface address with both local ports.
            our_hints.direct_tcp.extend(
                if_addrs::get_if_addrs()?
                    .iter()
                    .filter(|iface| !iface.is_loopback())
                    .flat_map(|ip| {
                        [
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port,
                            },
                            DirectHint {
                                hostname: ip.ip().to_string(),
                                port: port2,
                            },
                        ]
                        .into_iter()
                    }),
            );
            log::debug!("Our socket for listening is {}", socket2.local_addr()?);
            Ok::<_, std::io::Error>((socket, socket2))
        };
        // Socket setup errors are logged and degrade to "no prepared sockets".
        listener = create_sockets
            .await
            .map_err(|err| {
                log::error!("Failed to create direct hints for our side: {}", err);
                err
            })
            .ok();
    }
    if abilities.can_relay() {
        our_hints.relay.extend(relay_hints);
    }
    Ok(TransitConnector {
        sockets: listener,
        our_abilities: abilities,
        our_hints: Arc::new(our_hints),
    })
}
/// The local socket prepared for direct connections by `init`.
#[derive(derive_more::From)]
enum MaybeConnectedSocket {
    /// A bare socket bound to an ephemeral port (used when STUN failed).
    #[from]
    Socket(socket2::Socket),
    /// The stream that is still connected to the STUN server (used when STUN succeeded).
    #[from]
    Stream(TcpStream),
}
impl MaybeConnectedSocket {
    /// Returns the local address the underlying socket or stream is bound to.
    fn local_addr(&self) -> std::io::Result<socket2::SockAddr> {
        match self {
            Self::Stream(connected) => connected.local_addr().map(Into::into),
            Self::Socket(raw) => raw.local_addr(),
        }
    }
}
/// Result of `init`: this side's abilities and hints, plus the sockets prepared for
/// direct connections. Consumed by `leader_connect` / `follower_connect`.
pub struct TransitConnector {
    /// Prepared connecting socket and listener; `None` if direct connections are
    /// disabled or setting them up failed.
    sockets: Option<(MaybeConnectedSocket, TcpListener)>,
    /// Abilities this side will use.
    our_abilities: Abilities,
    /// Hints to send to the peer.
    our_hints: Arc<Hints>,
}
impl TransitConnector {
    /// Returns the abilities this side will use: those passed to `init`, intersected
    /// with the peer's abilities if they were known at that point.
    pub fn our_abilities(&self) -> &Abilities {
        &self.our_abilities
    }

    /// Returns the hints (direct endpoints and relays) gathered for this side, to be
    /// sent to the peer.
    pub fn our_hints(&self) -> &Arc<Hints> {
        &self.our_hints
    }

    /// Establishes the transit connection as the leader.
    ///
    /// Races all candidate connections (see `connect`) and takes the first one whose
    /// handshake succeeds, waiting at most 60 seconds.
    ///
    /// If that connection is relayed and direct connections are enabled, it waits a
    /// little longer for a direct connection and switches to it if one arrives. The grace
    /// period is 30% of the time spent so far, or 1 second once more than 5 whole seconds
    /// have elapsed. All remaining attempts are then dropped, and the chosen connection's
    /// handshake is finalized.
    ///
    /// Returns the encrypted [`Transit`], how it was established and the peer's address.
    ///
    /// # Errors
    /// [`TransitConnectError::Handshake`] if no attempt succeeded in time or finalizing
    /// failed. [`TransitConnectError::IO`] if the peer address cannot be queried.
    pub async fn leader_connect(
        self,
        transit_key: Key<TransitKey>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
    ) -> Result<(Transit, TransitInfo, SocketAddr), TransitConnectError> {
        let Self {
            sockets,
            our_abilities,
            our_hints,
        } = self;
        let transit_key = Arc::new(transit_key);
        let start = std::time::Instant::now();
        let mut connection_stream = Box::pin(
            Self::connect(
                true,
                transit_key,
                our_abilities,
                our_hints,
                their_abilities,
                their_hints,
                sockets,
            )
            .filter_map(|result| async {
                match result {
                    Ok(val) => Some(val),
                    Err(err) => {
                        log::debug!("Some leader handshake failed: {:?}", err);
                        None
                    },
                }
            }),
        );

        // First attempt that completes its handshake wins (for now).
        let (mut transit, mut host_type) = async_std::future::timeout(
            std::time::Duration::from_secs(60),
            connection_stream.next(),
        )
        .await
        .map_err(|_| {
            log::debug!("`leader_connect` timed out");
            TransitConnectError::Handshake
        })?
        .ok_or(TransitConnectError::Handshake)?;

        if host_type == TransitInfo::Direct {
            log::debug!("Established direct transit connection");
        } else if our_abilities.can_direct() {
            log::debug!(
                "Established transit connection over relay. Trying to find a direct connection …"
            );
            let elapsed = start.elapsed();
            let to_wait = if elapsed.as_secs() > 5 {
                std::time::Duration::from_secs(1)
            } else {
                elapsed.mul_f32(0.3)
            };
            let _ = async_std::future::timeout(to_wait, async {
                while let Some((new_transit, new_host_type)) = connection_stream.next().await {
                    // Only a direct connection is worth replacing the relayed one.
                    if new_host_type == TransitInfo::Direct {
                        transit = new_transit;
                        host_type = new_host_type;
                        log::debug!("Found direct connection; using that instead.");
                        break;
                    }
                }
            })
            .await;
            if host_type != TransitInfo::Direct {
                log::debug!("Did not manage to establish a better connection in time.");
            }
        }
        // Dropping the stream drops (aborts) every remaining connection attempt.
        std::mem::drop(connection_stream);
        let (mut socket, finalizer) = transit;
        let (tx, rx) = finalizer
            .handshake_finalize(&mut socket)
            .await
            .map_err(|e| {
                log::debug!("`handshake_finalize` failed: {e}");
                TransitConnectError::Handshake
            })?;
        let addr = socket.peer_addr()?;
        Ok((Transit { socket, tx, rx }, host_type, addr))
    }

    /// Establishes the transit connection as the follower.
    ///
    /// Races all candidate connections (see `connect`) and uses the first one whose
    /// handshake succeeds within 60 seconds, then finalizes its handshake. All other
    /// attempts are dropped.
    ///
    /// # Errors
    /// [`TransitConnectError::Handshake`] if the 60 seconds elapse, every attempt fails,
    /// or finalizing fails. [`TransitConnectError::IO`] if the peer address cannot be
    /// queried.
    pub async fn follower_connect(
        self,
        transit_key: Key<TransitKey>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
    ) -> Result<(Transit, TransitInfo, SocketAddr), TransitConnectError> {
        let Self {
            sockets,
            our_abilities,
            our_hints,
        } = self;
        let transit_key = Arc::new(transit_key);
        let mut connection_stream = Box::pin(
            Self::connect(
                false,
                transit_key,
                our_abilities,
                our_hints,
                their_abilities,
                their_hints,
                sockets,
            )
            .filter_map(|result| async {
                match result {
                    Ok(val) => Some(val),
                    Err(err) => {
                        log::debug!("Some follower handshake failed: {:?}", err);
                        None
                    },
                }
            }),
        );
        let transit = match async_std::future::timeout(
            std::time::Duration::from_secs(60),
            &mut connection_stream.next(),
        )
        .await
        {
            Ok(Some(((mut socket, finalizer), host_type))) => {
                let addr = socket.peer_addr()?;
                let (tx, rx) = finalizer
                    .handshake_finalize(&mut socket)
                    .await
                    .map_err(|e| {
                        log::debug!("`handshake_finalize` failed: {e}");
                        TransitConnectError::Handshake
                    })?;
                Ok((Transit { socket, tx, rx }, host_type, addr))
            },
            Ok(None) => {
                log::debug!("`follower_connect` failed: no connection attempt succeeded");
                Err(TransitConnectError::Handshake)
            },
            Err(_) => {
                log::debug!("`follower_connect` timed out");
                Err(TransitConnectError::Handshake)
            },
        };
        std::mem::drop(connection_stream);
        transit
    }

    /// Builds the stream of all connection attempts. Results are yielded in whatever
    /// order the attempts finish.
    ///
    /// Candidates, in construction order:
    /// - Direct TCP connections to up to 50 of the peer's direct hints (only when direct
    ///   connections are enabled). With `socket` present, each is made from the prepared
    ///   socket's local address; otherwise plain connects are used.
    /// - When both sides can relay: up to 3 TCP endpoints per relay, taken from at most
    ///   2 of our relay hints merged with at most 2 of theirs. The n-th endpoint (0-based)
    ///   of a relay is tried after an `n * 5` second delay.
    /// - When `socket` is present: the first incoming connection on its listener that
    ///   completes the handshake.
    ///
    /// Every connection then runs [`handshake_exchange`]. The noise cryptor is used if
    /// both sides support it; secretbox is used otherwise.
    fn connect(
        is_leader: bool,
        transit_key: Arc<Key<TransitKey>>,
        our_abilities: Abilities,
        our_hints: Arc<Hints>,
        their_abilities: Abilities,
        their_hints: Arc<Hints>,
        socket: Option<(MaybeConnectedSocket, TcpListener)>,
    ) -> impl Stream<Item = Result<(HandshakeResult, TransitInfo), crypto::TransitHandshakeError>>
           + 'static {
        assert!(socket.is_none() || our_abilities.can_direct());
        let cryptor = if our_abilities.can_noise_crypto() && their_abilities.can_noise_crypto() {
            log::debug!("Using noise protocol for encryption");
            Arc::new(crypto::NoiseInit {
                key: transit_key.clone(),
            }) as Arc<dyn crypto::TransitCryptoInit>
        } else {
            log::debug!("Using secretbox for encryption");
            Arc::new(crypto::SecretboxInit {
                key: transit_key.clone(),
            }) as Arc<dyn crypto::TransitCryptoInit>
        };
        // Random side identifier sent to relays in the relay handshake.
        let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>()));

        use futures::future::BoxFuture;
        type BoxIterator<T> = Box<dyn Iterator<Item = T>>;
        type ConnectorFuture =
            BoxFuture<'static, Result<TransitConnection, crypto::TransitHandshakeError>>;
        let mut connectors: BoxIterator<ConnectorFuture> = Box::new(std::iter::empty());

        // Direct connections. With prepared sockets, connect from the shared local address.
        let socket2 = if let Some((socket, socket2)) = socket {
            let local_addr = Arc::new(socket.local_addr().unwrap());
            connectors = Box::new(
                connectors.chain(
                    their_hints
                        .direct_tcp
                        .clone()
                        .into_iter()
                        .take(50)
                        .map(move |hint| {
                            let local_addr = local_addr.clone();
                            async move {
                                let dest_addr = SocketAddr::try_from(&hint)?;
                                log::debug!("Connecting directly to {}", dest_addr);
                                let socket = connect_custom(&local_addr, &dest_addr.into()).await?;
                                log::debug!("Connected to {}!", dest_addr);
                                Ok((socket, TransitInfo::Direct))
                            }
                        })
                        .map(|fut| Box::pin(fut) as ConnectorFuture),
                ),
            ) as BoxIterator<ConnectorFuture>;
            Some(socket2)
        } else if our_abilities.can_direct() {
            connectors = Box::new(
                connectors.chain(
                    their_hints
                        .direct_tcp
                        .clone()
                        .into_iter()
                        .take(50)
                        .map(move |hint| async move {
                            let dest_addr = SocketAddr::try_from(&hint)?;
                            log::debug!("Connecting directly to {}", dest_addr);
                            let socket = async_std::net::TcpStream::connect(&dest_addr).await?;
                            log::debug!("Connected to {}!", dest_addr);
                            Ok((socket, TransitInfo::Direct))
                        })
                        .map(|fut| Box::pin(fut) as ConnectorFuture),
                ),
            ) as BoxIterator<ConnectorFuture>;
            None
        } else {
            None
        };

        // Relayed connections.
        if our_abilities.can_relay() && their_abilities.can_relay() {
            let mut relay_hints = Vec::<RelayHint>::new();
            relay_hints.extend(our_hints.relay.iter().take(2).cloned());
            for hint in their_hints.relay.iter().take(2).cloned() {
                hint.merge_into(&mut relay_hints);
            }

            async fn hint_connector(
                host: DirectHint,
                name: Option<String>,
            ) -> Result<TransitConnection, crypto::TransitHandshakeError> {
                log::debug!("Connecting to relay {}", host);
                let transit = TcpStream::connect((host.hostname.as_str(), host.port))
                    .err_into::<crypto::TransitHandshakeError>()
                    .await?;
                log::debug!("Connected to {}!", host);
                Ok((transit, TransitInfo::Relay { name }))
            }

            connectors = Box::new(
                connectors.chain(
                    relay_hints
                        .into_iter()
                        .flat_map(|hint| {
                            // Without an explicit name, label the relay by its first
                            // domain-name TCP endpoint (if any).
                            let name = hint.name.or_else(|| {
                                hint.tcp
                                    .iter()
                                    .filter_map(|hint| match url::Host::parse(&hint.hostname) {
                                        Ok(url::Host::Domain(_)) => Some(hint.hostname.clone()),
                                        _ => None,
                                    })
                                    .next()
                            });
                            hint.tcp
                                .into_iter()
                                .take(3)
                                .enumerate()
                                .map(move |(i, h)| (i, h, name.clone()))
                        })
                        .map(|(index, host, name)| async move {
                            // Stagger the endpoints of a relay five seconds apart.
                            async_std::task::sleep(std::time::Duration::from_secs(
                                index as u64 * 5,
                            ))
                            .await;
                            hint_connector(host, name).await
                        })
                        .map(|fut| Box::pin(fut) as ConnectorFuture),
                ),
            ) as BoxIterator<ConnectorFuture>;
        }

        // Run the handshake on every outgoing connection.
        let transit_key2 = transit_key.clone();
        let tside2 = tside.clone();
        let cryptor2 = cryptor.clone();
        let mut connectors = Box::new(
            connectors
                .map(move |fut| {
                    let transit_key = transit_key2.clone();
                    let tside = tside2.clone();
                    let cryptor = cryptor2.clone();
                    async move {
                        let (socket, host_type) = fut.await?;
                        let transit = handshake_exchange(
                            is_leader,
                            tside,
                            socket,
                            &host_type,
                            &*cryptor,
                            transit_key,
                        )
                        .await?;
                        Ok((transit, host_type))
                    }
                })
                .map(|fut| {
                    Box::pin(fut)
                        as BoxFuture<
                            Result<(HandshakeResult, TransitInfo), crypto::TransitHandshakeError>,
                        >
                }),
        )
            as BoxIterator<
                BoxFuture<Result<(HandshakeResult, TransitInfo), crypto::TransitHandshakeError>>,
            >;

        // Incoming direct connections on the listener.
        if let Some(socket2) = socket2 {
            connectors = Box::new(
                connectors.chain(
                    std::iter::once(async move {
                        // Keep accepting until one incoming connection completes the handshake.
                        loop {
                            let (stream, peer) = match socket2.accept().await {
                                Ok(accepted) => accepted,
                                Err(err) => {
                                    log::debug!(
                                        "Accepting on the listening port failed: {:?}",
                                        err
                                    );
                                    // Accept errors (e.g. out of file descriptors) can be
                                    // reported without ever pending. Back off so this future
                                    // yields instead of spinning and starving the executor
                                    // (which would also stop the callers' timeouts).
                                    async_std::task::sleep(std::time::Duration::from_millis(100))
                                        .await;
                                    continue;
                                },
                            };
                            log::debug!("Got connection from {}!", peer);
                            match handshake_exchange(
                                is_leader,
                                tside.clone(),
                                stream,
                                &TransitInfo::Direct,
                                &*cryptor,
                                transit_key.clone(),
                            )
                            .await
                            {
                                Ok(transit) => {
                                    break Ok::<_, crypto::TransitHandshakeError>((
                                        transit,
                                        TransitInfo::Direct,
                                    ));
                                },
                                Err(err) => {
                                    log::debug!(
                                        "Some handshake failed on the listening port: {:?}",
                                        err
                                    );
                                },
                            }
                        }
                    })
                    .map(|fut| {
                        Box::pin(fut)
                            as BoxFuture<
                                Result<
                                    (HandshakeResult, TransitInfo),
                                    crypto::TransitHandshakeError,
                                >,
                            >
                    }),
                ),
            )
                as BoxIterator<
                    BoxFuture<
                        Result<(HandshakeResult, TransitInfo), crypto::TransitHandshakeError>,
                    >,
                >;
        }

        connectors.collect::<futures::stream::futures_unordered::FuturesUnordered<_>>()
    }
}
/// An established, encrypted transit connection.
pub struct Transit {
    /// The underlying TCP stream.
    socket: TcpStream,
    /// Encrypts outgoing records onto `socket`.
    tx: Box<dyn crypto::TransitCryptoEncrypt>,
    /// Decrypts incoming records from `socket`.
    rx: Box<dyn crypto::TransitCryptoDecrypt>,
}
impl Transit {
    /// Reads one record from the connection and returns its decrypted contents.
    pub async fn receive_record(&mut self) -> Result<Box<[u8]>, TransitError> {
        self.rx.decrypt(&mut self.socket).await
    }
    /// Encrypts `plaintext` and writes it to the connection as one record.
    ///
    /// # Panics
    /// Panics if `plaintext` is empty.
    pub async fn send_record(&mut self, plaintext: &[u8]) -> Result<(), TransitError> {
        assert!(!plaintext.is_empty());
        self.tx.encrypt(&mut self.socket, plaintext).await
    }
    /// Flushes the underlying stream.
    pub async fn flush(&mut self) -> Result<(), TransitError> {
        log::debug!("Flush");
        self.socket.flush().await.map_err(Into::into)
    }
    /// Splits the connection into a record sink (encrypting) and a record stream
    /// (decrypting) that can be used independently.
    ///
    /// Unlike `send_record`, the sink does not reject empty records. Each successful
    /// decryption yields one record on the stream; the unfold closure never signals
    /// end-of-stream itself.
    pub fn split(
        self,
    ) -> (
        impl futures::sink::Sink<Box<[u8]>, Error = TransitError>,
        impl futures::stream::Stream<Item = Result<Box<[u8]>, TransitError>>,
    ) {
        use futures::io::AsyncReadExt;
        let (reader, writer) = self.socket.split();
        (
            futures::sink::unfold(
                (writer, self.tx),
                |(mut writer, mut tx), plaintext: Box<[u8]>| async move {
                    tx.encrypt(&mut writer, &plaintext)
                        .await
                        .map(|()| (writer, tx))
                },
            ),
            futures::stream::try_unfold((reader, self.rx), |(mut reader, mut rx)| async move {
                rx.decrypt(&mut reader)
                    .await
                    .map(|record| Some((record, (reader, rx))))
            }),
        )
    }
}
/// A connection whose initial handshake succeeded, plus the finalizer that completes
/// the handshake and yields the record encryptor/decryptor.
type HandshakeResult = (TcpStream, Box<dyn crypto::TransitCryptoInitFinalizer>);
async fn handshake_exchange(
is_leader: bool,
tside: Arc<String>,
socket: TcpStream,
host_type: &TransitInfo,
cryptor: &dyn crypto::TransitCryptoInit,
key: Arc<Key<TransitKey>>,
) -> Result<HandshakeResult, crypto::TransitHandshakeError> {
let socket = std::net::TcpStream::try_from(socket)
.expect("Internal error: this should not fail because we never cloned the socket");
socket.set_write_timeout(Some(std::time::Duration::from_secs(120)))?;
socket.set_read_timeout(Some(std::time::Duration::from_secs(120)))?;
let mut socket: TcpStream = socket.into();
if host_type != &TransitInfo::Direct {
log::trace!("initiating relay handshake");
let sub_key = key.derive_subkey_from_purpose::<crate::GenericKey>("transit_relay_token");
socket
.write_all(format!("please relay {} for side {}\n", sub_key.to_hex(), tside).as_bytes())
.await?;
let mut rx = [0u8; 3];
socket.read_exact(&mut rx).await?;
let ok_msg: [u8; 3] = *b"ok\n";
ensure!(
ok_msg == rx,
crypto::TransitHandshakeError::RelayHandshakeFailed
);
}
let finalizer = if is_leader {
cryptor.handshake_leader(&mut socket).await?
} else {
cryptor.handshake_follower(&mut socket).await?
};
Ok((socket, finalizer))
}
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// Abilities encode as a list of `{"type": ...}` objects, one per enabled feature.
    #[test]
    pub fn test_abilities_encoding() {
        let encode = |abilities: Abilities| serde_json::to_value(abilities).unwrap();
        assert_eq!(
            encode(Abilities::ALL_ABILITIES),
            json!([{"type": "direct-tcp-v1"}, {"type": "relay-v1"}])
        );
        assert_eq!(
            encode(Abilities::FORCE_DIRECT),
            json!([{"type": "direct-tcp-v1"}])
        );
    }

    /// Hints encode direct hints first, then relay hints with their nested endpoints.
    #[test]
    pub fn test_hints_encoding() {
        let relay = RelayHint::new(
            Some("default".into()),
            [DirectHint::new("transit.magic-wormhole.io", 4001)],
            ["ws://transit.magic-wormhole.io/relay".parse().unwrap()],
        );
        let hints = Hints::new([DirectHint::new("localhost", 1234)], [relay]);
        let expected = json!([
            {
                "type": "direct-tcp-v1",
                "hostname": "localhost",
                "port": 1234
            },
            {
                "type": "relay-v1",
                "name": "default",
                "hints": [
                    {
                        "type": "direct-tcp-v1",
                        "hostname": "transit.magic-wormhole.io",
                        "port": 4001,
                    },
                    {
                        "type": "websocket",
                        "url": "ws://transit.magic-wormhole.io/relay",
                    },
                ]
            }
        ]);
        assert_eq!(serde_json::to_value(hints).unwrap(), expected);
    }
}