use super::{ConnectionType, TransitConnection, TransitHandshakeError, TransitInfo};
#[cfg(not(target_family = "wasm"))]
use super::{DirectHint, StunError};
#[cfg(not(target_family = "wasm"))]
use async_net::TcpStream;
use async_trait::async_trait;
use futures::{
future::TryFutureExt,
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
};
#[cfg(not(target_family = "wasm"))]
use std::{
net::{IpAddr, SocketAddr, ToSocketAddrs},
sync::Arc,
};
/// Receiving half of a transit transport.
///
/// Provides helpers on top of [`AsyncRead`] for the handshake and for reading
/// length-prefixed transit messages.
#[async_trait]
pub(super) trait TransitTransportRx: AsyncRead + std::any::Any + Unpin + Send {
    /// Reads exactly `expected.len()` bytes and checks that they equal `expected`.
    ///
    /// # Errors
    ///
    /// Fails if the bytes cannot be read, or with
    /// [`TransitHandshakeError::HandshakeFailed`] if they differ from `expected`.
    async fn read_expect(&mut self, expected: &[u8]) -> Result<(), TransitHandshakeError> {
        let mut buffer = vec![0u8; expected.len()];
        self.read_exact(&mut buffer).await?;
        ensure!(buffer == expected, TransitHandshakeError::HandshakeFailed);
        Ok(())
    }

    /// Reads one message: a 4-byte big-endian length followed by that many payload bytes.
    ///
    /// # Errors
    ///
    /// Fails with the underlying I/O error, or with [`std::io::ErrorKind::UnexpectedEof`]
    /// if the stream ends before the announced number of bytes was received.
    async fn read_transit_message(&mut self) -> Result<Vec<u8>, std::io::Error> {
        use std::io::{Error, ErrorKind};

        /// Upper bound for the up-front allocation. The length prefix comes from the
        /// remote side, so it must not be trusted to reserve up to 4 GiB before a
        /// single payload byte has arrived; the buffer grows on demand past this.
        const MAX_PREALLOCATION: usize = 64 * 1024;

        let mut length_arr = [0u8; 4];
        self.read_exact(&mut length_arr).await?;
        let length = u32::from_be_bytes(length_arr) as usize;

        let mut buffer = Vec::with_capacity(length.min(MAX_PREALLOCATION));
        // `take` stops at the announced length so we never read into the next message.
        let len = self.take(length as u64).read_to_end(&mut buffer).await?;
        ensure!(
            len == length,
            Error::new(ErrorKind::UnexpectedEof, "failed to read whole message")
        );
        Ok(buffer)
    }
}
/// Sending half of a transit transport.
///
/// Provides a helper on top of [`AsyncWrite`] for writing length-prefixed
/// transit messages.
#[async_trait]
pub(super) trait TransitTransportTx: AsyncWrite + std::any::Any + Unpin + Send {
    /// Writes `message` preceded by its length as a 4-byte big-endian integer.
    ///
    /// # Errors
    ///
    /// Fails with [`std::io::ErrorKind::InvalidInput`] if `message` is too long for
    /// its length to fit into 32 bits (nothing is written in that case), or with
    /// the underlying I/O error if writing fails.
    async fn write_transit_message(&mut self, message: &[u8]) -> Result<(), std::io::Error> {
        use std::io::{Error, ErrorKind};

        // A plain `as u32` cast would silently truncate the length and desynchronize
        // the stream for the receiver, so reject oversized messages explicitly.
        let length = u32::try_from(message.len()).map_err(|_| {
            Error::new(
                ErrorKind::InvalidInput,
                "message is too long for a 32-bit length prefix",
            )
        })?;
        self.write_all(&length.to_be_bytes()).await?;
        self.write_all(message).await
    }
}
/// A transport that can both receive and send transit messages.
pub(super) trait TransitTransport: TransitTransportRx + TransitTransportTx {}

// Blanket implementations: every suitable async reader/writer gets the
// transit helpers for free.
impl<T: AsyncRead + std::any::Any + Unpin + Send> TransitTransportRx for T {}
impl<T: AsyncWrite + std::any::Any + Unpin + Send> TransitTransportTx for T {}
impl<T: AsyncRead + AsyncWrite + std::any::Any + Unpin + Send> TransitTransport for T {}
/// Applies the socket options used for transit connections to `socket`.
///
/// The socket is made non-blocking, IPv6-only mode is turned off and address
/// reuse is turned on. On Unix systems other than Solaris and illumos, port
/// reuse is enabled as well. Targets that are neither such a Unix nor Windows
/// are rejected at compile time.
///
/// # Errors
///
/// Returns the first error reported while setting an option.
#[cfg(not(target_family = "wasm"))]
pub(super) fn set_socket_opts(socket: &socket2::Socket) -> std::io::Result<()> {
    socket
        .set_nonblocking(true)
        .and_then(|()| socket.set_only_v6(false))
        .and_then(|()| socket.set_reuse_address(true))?;

    #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))]
    socket.set_reuse_port(true)?;

    #[cfg(not(any(
        all(unix, not(any(target_os = "solaris", target_os = "illumos"))),
        target_os = "windows"
    )))]
    compile_error!("Your system is not supported yet, please raise an error");

    Ok(())
}
/// Asks the public STUN server for our external address over a new TCP connection.
///
/// The connection is made from the unspecified IPv6 address with port 0 to the
/// first IPv4 address the STUN host name resolves to (as an IPv4-mapped IPv6
/// address). A STUN binding request is sent and the response is decoded.
///
/// Returns the address reported by the server together with the still-open
/// socket the request was made on.
///
/// # Errors
///
/// * [`StunError::ServerIsV6Only`] if the server name resolves to no IPv4 address.
/// * [`StunError::ServerNoResponse`] if the response carries neither an
///   XOR-MAPPED-ADDRESS nor a MAPPED-ADDRESS attribute.
/// * Any I/O or codec error encountered on the way.
#[cfg(not(target_family = "wasm"))]
pub(super) async fn tcp_get_external_ip() -> Result<(SocketAddr, TcpStream), StunError> {
    use bytecodec::{DecodeExt, EncodeExt};
    use stun_codec::{
        Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId,
        rfc5389::{
            self, Attribute,
            attributes::{MappedAddress, Software, XorMappedAddress},
        },
    };

    /// Number of header bytes read before the attribute section; the big-endian
    /// `u16` at offsets 2..4 of the header is used as the attribute length.
    const STUN_HEADER_LEN: usize = 20;

    /// Encodes a binding request with a random transaction id.
    fn get_binding_request() -> Result<Vec<u8>, bytecodec::Error> {
        use rand::Rng;

        let random_bytes = rand::thread_rng().r#gen::<[u8; 12]>();
        let mut message: Message<Attribute> = Message::new(
            MessageClass::Request,
            rfc5389::methods::BINDING,
            TransactionId::new(random_bytes),
        );
        message.add_attribute(Attribute::Software(Software::new(
            "magic-wormhole-rust".to_owned(),
        )?));

        let mut encoder = MessageEncoder::new();
        encoder.encode_into_bytes(message)
    }

    /// Decodes a STUN message and extracts the reported address, preferring
    /// XOR-MAPPED-ADDRESS over MAPPED-ADDRESS.
    fn decode_address(buf: &[u8]) -> Result<Option<SocketAddr>, bytecodec::Error> {
        let mut decoder = MessageDecoder::<Attribute>::new();
        let decoded = decoder.decode_from_bytes(buf)??;

        let xor_mapped = decoded
            .get_attribute::<XorMappedAddress>()
            .map(|x| x.address());
        let mapped = decoded
            .get_attribute::<MappedAddress>()
            .map(|x| x.address());
        Ok(xor_mapped.or(mapped))
    }

    // Our socket is an IPv6 one, so the first IPv4 address of the server is
    // turned into its IPv4-mapped IPv6 form.
    let stun_server = super::PUBLIC_STUN_SERVER
        .to_socket_addrs()?
        .find_map(|addr| match addr {
            SocketAddr::V4(v4) => Some(SocketAddr::new(
                IpAddr::V6(v4.ip().to_ipv6_mapped()),
                v4.port(),
            )),
            SocketAddr::V6(_) => None,
        })
        .ok_or(StunError::ServerIsV6Only)?;

    let bind_addr = SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0));
    let mut socket = tcp_connect_custom(&bind_addr.into(), &stun_server.into()).await?;

    socket.write_all(get_binding_request()?.as_ref()).await?;

    // Read the fixed-size header first to learn how long the rest is.
    let mut header = [0u8; STUN_HEADER_LEN];
    socket.read_exact(&mut header).await?;
    let body_len = usize::from(u16::from_be_bytes([header[2], header[3]]));

    // The length is controlled by the server, so size the buffer from it instead
    // of slicing into a fixed-size array (which would panic on long responses).
    let mut response = vec![0u8; STUN_HEADER_LEN + body_len];
    response[..STUN_HEADER_LEN].copy_from_slice(&header);
    socket.read_exact(&mut response[STUN_HEADER_LEN..]).await?;

    let external_addr = decode_address(&response)?.ok_or(StunError::ServerNoResponse)?;
    Ok((external_addr, socket))
}
/// Opens a TCP connection to `dest_addr` from a socket bound to `local_addr`.
///
/// An IPv6 stream socket is created, configured through [`set_socket_opts`],
/// bound and connected. The function then waits until the socket becomes
/// writable and checks the socket's pending error before handing the stream
/// over.
///
/// # Panics
///
/// Panics if `local_addr` cannot be represented as an IP socket address.
///
/// # Errors
///
/// Returns any I/O error from creating, configuring, binding or connecting
/// the socket.
#[cfg(not(target_family = "wasm"))]
async fn tcp_connect_custom(
    local_addr: &socket2::SockAddr,
    dest_addr: &socket2::SockAddr,
) -> std::io::Result<async_net::TcpStream> {
    tracing::debug!("Binding to {}", local_addr.as_socket().unwrap());

    let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?;
    set_socket_opts(&socket)?;
    socket.bind(local_addr)?;

    // The socket is non-blocking, so "in progress" / "would block" are not
    // treated as failures here; every other error is returned.
    if let Err(err) = socket.connect(dest_addr) {
        #[cfg(unix)]
        let in_progress = err.raw_os_error() == Some(libc::EINPROGRESS);
        #[cfg(not(unix))]
        let in_progress = false;

        let would_block = err.kind() == std::io::ErrorKind::WouldBlock;
        if !(in_progress || would_block) {
            return Err(err);
        }
    }

    let stream = async_io::Async::new(std::net::TcpStream::from(socket))?;
    stream.writable().await?;

    // Surface an error that was recorded on the socket in the meantime.
    if let Some(err) = stream.get_ref().take_error()? {
        return Err(err);
    }

    let std_stream = stream.into_inner()?;
    std_stream.try_into()
}
/// Connects directly to the address described by `hint`.
///
/// When `local_addr` is given, the connection is made through
/// [`tcp_connect_custom`] from a socket bound to that address; otherwise a
/// plain [`async_net::TcpStream::connect`] is used. The resulting socket is
/// wrapped as a [`ConnectionType::Direct`] connection.
///
/// # Errors
///
/// Fails if `hint` cannot be converted into a socket address, if connecting
/// fails, or if wrapping the socket fails.
#[cfg(not(target_family = "wasm"))]
pub(super) async fn connect_tcp_direct(
    local_addr: Option<Arc<socket2::SockAddr>>,
    hint: DirectHint,
) -> Result<TransitConnection, TransitHandshakeError> {
    let dest_addr = SocketAddr::try_from(&hint)?;
    tracing::debug!("Connecting directly to {}", dest_addr);

    let socket = match local_addr {
        Some(local_addr) => tcp_connect_custom(&local_addr, &dest_addr.into()).await?,
        None => async_net::TcpStream::connect(&dest_addr).await?,
    };
    tracing::debug!("Connected to {}!", dest_addr);

    wrap_tcp_connection(socket, ConnectionType::Direct)
}
/// Connects to a relay server over TCP.
///
/// `host` supplies the relay's host name and port; `name` is stored in the
/// resulting [`ConnectionType::Relay`].
///
/// # Errors
///
/// Fails if the connection cannot be established or the socket cannot be
/// wrapped.
#[cfg(not(target_family = "wasm"))]
pub(super) async fn connect_tcp_relay(
    host: DirectHint,
    name: Option<String>,
) -> Result<TransitConnection, TransitHandshakeError> {
    tracing::debug!("Connecting to relay {}", host);

    let relay_addr = (host.hostname.as_str(), host.port);
    let connecting = TcpStream::connect(relay_addr).err_into::<TransitHandshakeError>();
    let socket = connecting.await?;
    tracing::debug!("Connected to {}!", host);

    let conn_type = ConnectionType::Relay { name };
    wrap_tcp_connection(socket, conn_type)
}
/// Connects to a relay server over a WebSocket (WASM targets only).
///
/// `url` is the relay's WebSocket URL; `name` is stored in the resulting
/// [`ConnectionType::Relay`].
///
/// # Errors
///
/// Fails if the WebSocket connection cannot be established.
#[cfg(target_family = "wasm")]
pub(super) async fn connect_ws_relay(
    url: url::Url,
    name: Option<String>,
) -> Result<TransitConnection, TransitHandshakeError> {
    tracing::debug!("Connecting to relay {}", url);

    let connecting =
        ws_stream_wasm::WsMeta::connect(&url, None).err_into::<TransitHandshakeError>();
    let (_meta, ws_stream) = connecting.await?;
    tracing::debug!("Connected to {}!", url);

    // Expose the WebSocket as a byte stream behind the transport trait object.
    let transit: Box<dyn TransitTransport> = Box::new(ws_stream.into_io());
    let info = TransitInfo {
        conn_type: ConnectionType::Relay { name },
    };
    Ok((transit, info))
}
/// Turns a connected TCP socket into a [`TransitConnection`].
///
/// The socket is temporarily unwrapped into a [`std::net::TcpStream`] to set
/// read and write timeouts of 120 seconds, converted back, and returned boxed
/// together with a [`TransitInfo`] holding `conn_type` and the peer address.
///
/// # Panics
///
/// Panics if other clones of `socket` exist, because the inner stream then
/// cannot be taken out of its `Arc`.
///
/// # Errors
///
/// Fails if unwrapping or re-wrapping the socket, setting a timeout, or
/// querying the peer address fails.
#[cfg(not(target_family = "wasm"))]
pub(super) fn wrap_tcp_connection(
    socket: TcpStream,
    conn_type: ConnectionType,
) -> Result<TransitConnection, TransitHandshakeError> {
    const SOCKET_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

    let async_stream: Arc<async_io::Async<std::net::TcpStream>> = socket.into();
    let stream = Arc::into_inner(async_stream)
        .expect("Internal error: this should not fail because we never cloned the socket")
        .into_inner()?;
    stream.set_write_timeout(Some(SOCKET_TIMEOUT))?;
    stream.set_read_timeout(Some(SOCKET_TIMEOUT))?;
    let socket: TcpStream = stream.try_into()?;

    // Querying the peer address is a fallible OS call, so report a failure as
    // an error like the other I/O operations here instead of panicking.
    let peer_addr = socket.peer_addr()?;
    let info = TransitInfo {
        conn_type,
        peer_addr,
    };
    Ok((Box::new(socket), info))
}