use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use futures::prelude::*;
use tracing::{debug_span, info, trace, warn};
use tracing_futures::Instrument;
use crate::runtime::{
AsyncTcpStream, AsyncTcpStreamRead, AsyncTcpStreamReadExt, AsyncTcpStreamWriteExt,
AsyncUdpSocket, AsyncUdpSocketExt, Runtime,
};
use crate::utils::DebugWrapper;
use rice_c::candidate::TransportType;
/// Data to transmit, tagged with its transport type and source/destination addresses.
///
/// `data` may be any byte container that can be viewed as `&[u8]`.
pub(crate) struct Transmit<T>
where
    T: AsRef<[u8]> + std::fmt::Debug,
{
    pub transport: TransportType,
    pub from: SocketAddr,
    pub to: SocketAddr,
    pub data: T,
}

impl<T> Transmit<T>
where
    T: AsRef<[u8]> + std::fmt::Debug,
{
    /// Bundle `data` with the transport and the `from` -> `to` address pair.
    pub fn new(data: T, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self {
        Self {
            transport,
            from,
            to,
            data,
        }
    }
}
/// Size of the buffer allocated for each TCP read performed by `TcpChannel`.
// NOTE(review): presumably chosen as two 1500-byte Ethernet MTUs — confirm the rationale.
const MAX_STUN_MESSAGE_SIZE: usize = 1500 * 2;
/// A STUN transport channel: either a connection-less UDP socket or a connected TCP stream.
#[derive(Debug, Clone)]
pub enum StunChannel {
    /// Connection-less UDP socket.
    Udp(UdpSocketChannel),
    /// Connected TCP stream.
    Tcp(TcpChannel),
}
/// A chunk of bytes paired with a socket address.
///
/// In this module it carries a chunk read from a TCP stream together with that
/// stream's remote (peer) address.
#[derive(Debug, Clone)]
struct DataAddress {
    /// The payload bytes.
    pub data: Vec<u8>,
    /// The address associated with the payload.
    pub address: SocketAddr,
}

impl DataAddress {
    /// Pair `data` with `address`.
    fn new(data: Vec<u8>, address: SocketAddr) -> Self {
        DataAddress { address, data }
    }
}
impl StunChannel {
    /// Close the channel.
    ///
    /// UDP channels are marked closed synchronously; TCP channels queue a
    /// shutdown request for their writer task.
    pub async fn close(&mut self) -> Result<(), std::io::Error> {
        match self {
            Self::Udp(socket) => socket.close(),
            Self::Tcp(conn) => conn.close().await,
        }
    }

    /// The transport type this channel uses.
    pub fn transport(&self) -> TransportType {
        match self {
            Self::Udp(_) => TransportType::Udp,
            Self::Tcp(_) => TransportType::Tcp,
        }
    }

    /// Send `data` to `to` over the underlying transport.
    ///
    /// For TCP, `to` must equal the connected peer address.
    pub async fn send_to(&mut self, data: &[u8], to: SocketAddr) -> Result<(), std::io::Error> {
        match self {
            Self::Udp(socket) => socket.send_to(data, to).await,
            Self::Tcp(conn) => conn.send_to(data, to).await,
        }
    }

    /// Stream of `(payload, sender address)` pairs received on this channel.
    pub fn recv(&mut self) -> impl Stream<Item = (Vec<u8>, SocketAddr)> + '_ {
        match self {
            Self::Udp(socket) => socket.recv().left_stream(),
            Self::Tcp(conn) => conn.recv().right_stream(),
        }
    }

    /// The local address the channel is bound to.
    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
        match self {
            Self::Udp(socket) => socket.local_addr(),
            Self::Tcp(conn) => conn.local_addr(),
        }
    }

    /// The connected peer address.
    ///
    /// # Errors
    /// Returns `ErrorKind::NotFound` for UDP channels, which have no fixed peer.
    pub fn remote_addr(&self) -> Result<SocketAddr, std::io::Error> {
        match self {
            Self::Tcp(conn) => conn.remote_addr(),
            Self::Udp(_) => Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "connection-less udp doesn't have a remote addr",
            )),
        }
    }
}
/// A `StunChannel` backend over a connection-less UDP socket.
///
/// Clones share the same socket and the same closed flag.
#[derive(Debug, Clone)]
pub struct UdpSocketChannel {
    /// The shared UDP socket used for both sending and receiving.
    socket: Arc<dyn AsyncUdpSocket>,
    // Shared mutable state. NOTE(review): `DebugWrapper` with "..." presumably makes
    // `Debug` print "..." instead of the mutex contents — confirm.
    inner: DebugWrapper<Arc<Mutex<UdpSocketChannelInner>>>,
}
/// Mutable state shared between clones of a `UdpSocketChannel`.
#[derive(Debug)]
struct UdpSocketChannelInner {
    /// Set by `close`; once set, `send_to` refuses to send.
    closed: bool,
}
impl UdpSocketChannel {
pub fn new(socket: Arc<dyn AsyncUdpSocket>) -> Self {
Self {
socket,
inner: DebugWrapper::wrap(
Arc::new(Mutex::new(UdpSocketChannelInner { closed: false })),
"...",
),
}
}
pub async fn send_to(&self, data: &[u8], to: SocketAddr) -> std::io::Result<()> {
{
let inner = self.inner.lock().unwrap();
if inner.closed {
return Err(std::io::Error::other("Connection closed"));
}
}
trace!(
"udp socket send_to {:?} bytes from {:?} to {:?}",
data.len(),
self.local_addr(),
to
);
self.socket.send_to(data, to).await?;
Ok(())
}
pub fn close(&self) -> Result<(), std::io::Error> {
{
let mut inner = self.inner.lock().unwrap();
inner.closed = true;
};
Ok(())
}
pub fn recv(&self) -> impl Stream<Item = (Vec<u8>, SocketAddr)> + '_ {
stream::unfold(self.clone(), |this| async move {
let mut buf = vec![0; 2048];
let (size, from) = this.socket.recv_from(&mut buf).await.unwrap();
let ret = buf.split_at(size).0.to_vec();
Some(((ret, from), this))
})
}
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.socket.local_addr()
}
}
/// A `StunChannel` backend over a connected TCP stream.
///
/// Reading and writing are done by background tasks spawned in `TcpChannel::new`;
/// this handle talks to them through bounded mpsc channels.
#[derive(Debug, Clone)]
pub struct TcpChannel {
    // Receiver fed by the reader task. Shared between clones; `recv` takes it out,
    // so only one receive stream can ever be obtained.
    read_channel: Arc<Mutex<Option<futures::channel::mpsc::Receiver<DataAddress>>>>,
    /// Local address captured when the channel was created.
    local_addr: SocketAddr,
    /// Peer address captured when the channel was created; `send_to` only accepts this destination.
    remote_addr: SocketAddr,
    /// Queue of requests for the writer task.
    sender_channel: futures::channel::mpsc::Sender<TcpData>,
}
/// Requests sent to the TCP writer task.
#[derive(Debug)]
enum TcpData {
    /// Bytes to write to the stream.
    Data(Vec<u8>),
    /// Shut down both directions of the stream and stop the writer task.
    Shutdown,
}
impl TcpChannel {
pub fn new(runtime: Arc<dyn Runtime>, stream: Box<dyn AsyncTcpStream>) -> Self {
let local_addr = stream.local_addr().unwrap();
let remote_addr = stream.remote_addr().unwrap();
let (send_tx, send_rx) = futures::channel::mpsc::channel::<TcpData>(1);
let (mut read, mut write) = stream.split();
runtime.spawn(Box::pin({
async move {
let mut send_rx = core::pin::pin!(send_rx);
while let Some(data) = send_rx.next().await {
match data {
TcpData::Data(data) => {
if let Err(e) = write.write_all(&data).await {
warn!("tcp write produced error {e:?}");
break;
}
}
TcpData::Shutdown => {
if let Err(e) = write.shutdown(std::net::Shutdown::Both).await {
warn!("tcp shutdown produced error {e:?}");
}
break;
}
}
}
}
}));
let (mut recv_tx, recv_rx) = futures::channel::mpsc::channel::<DataAddress>(1);
runtime.spawn(Box::pin(async move {
while let Ok(data_addr) = Self::inner_recv(&mut read).await {
if recv_tx.send(data_addr).await.is_err() {
break;
}
}
}));
Self {
local_addr,
remote_addr,
read_channel: Arc::new(Mutex::new(Some(recv_rx))),
sender_channel: send_tx,
}
}
#[tracing::instrument(
name = "tcp_single_recv",
skip(stream),
fields(
remote.addr = ?stream.remote_addr()
)
)]
async fn inner_recv(
stream: &mut Box<dyn AsyncTcpStreamRead>,
) -> Result<DataAddress, std::io::Error> {
let from = stream.remote_addr()?;
let mut data = vec![0; MAX_STUN_MESSAGE_SIZE];
match stream.read(&mut data).await {
Ok(size) => {
trace!("recved {} bytes", size);
if size == 0 {
info!("connection closed");
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"TCP connection closed",
));
}
trace!("return {} bytes", size);
return Ok(DataAddress::new(data[..size].to_vec(), from));
}
Err(e) => return Err(e),
}
}
pub async fn close(&mut self) -> Result<(), std::io::Error> {
self.sender_channel
.send(TcpData::Shutdown)
.await
.map_err(|e| {
if e.is_disconnected() {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Disconnected")
} else {
unreachable!();
}
})
}
pub async fn send_to(&mut self, data: &[u8], to: SocketAddr) -> Result<(), std::io::Error> {
if to != self.remote_addr()? {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Address to send to is different from connected address",
));
}
if data.len() > u16::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"data length too large for transport",
));
}
self.sender_channel
.send(TcpData::Data(data.to_vec()))
.await
.map_err(|_| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))
}
pub fn recv(&mut self) -> impl Stream<Item = (Vec<u8>, SocketAddr)> + '_ {
let span = debug_span!("tcp_recv");
let chan = self
.read_channel
.lock()
.unwrap()
.take()
.expect("Receiver already taken!");
chan.map(|v| (v.data, v.address))
.instrument(span.or_current())
}
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
Ok(self.local_addr)
}
pub fn remote_addr(&self) -> Result<SocketAddr, std::io::Error> {
Ok(self.remote_addr)
}
}