use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::Arc,
};
use async_trait::async_trait;
use bytecodec::DecodeExt;
use bytes::BytesMut;
use futures::StreamExt;
use stun_codec::MessageDecoder;
use tokio::{
io::AsyncWriteExt as _,
net::{TcpListener, TcpStream},
sync::{mpsc, mpsc::error::TrySendError, oneshot, Mutex},
};
use tokio_util::codec::{Decoder, FramedRead};
use crate::{
attr::{Attribute, PROTO_TCP},
chandata::{nearest_padded_value_length, ChannelData},
};
use super::{Error, Request, Transport};
/// Shared map of the egress senders registered by TCP connection handlers,
/// keyed by the remote address of each connection.
///
/// Each message sent through a sender carries the bytes to write to the
/// connection, along with a [`oneshot::Sender`] through which the handler
/// reports the outcome of that write.
type TcpWritersMap = Arc<
    Mutex<
        HashMap<
            SocketAddr,
            mpsc::Sender<(Vec<u8>, oneshot::Sender<Result<(), Error>>)>,
        >,
    >,
>;
/// TCP [`Transport`] serving all the connections accepted on a single
/// [`TcpListener`].
#[derive(Debug)]
pub struct Server {
    /// Receiver of the [`Request`]s decoded from all the accepted connections,
    /// each paired with the remote address it came from.
    ingress_rx: Mutex<mpsc::Receiver<(Request, SocketAddr)>>,

    /// Local address of the underlying [`TcpListener`].
    local_addr: SocketAddr,

    /// Egress senders of the connection handlers, keyed by remote address.
    writers: TcpWritersMap,
}
#[async_trait]
impl Transport for Server {
async fn recv_from(&self) -> Result<(Request, SocketAddr), Error> {
let req_and_addr = self.ingress_rx.lock().await.recv().await;
if let Some((data, addr)) = req_and_addr {
Ok((data, addr))
} else {
Err(Error::TransportIsDead)
}
}
async fn send_to(
&self,
data: Vec<u8>,
target: SocketAddr,
) -> Result<(), Error> {
let mut writers = self.writers.lock().await;
match writers.entry(target) {
Entry::Occupied(mut e) => {
let (res_tx, res_rx) = oneshot::channel();
if e.get_mut().send((data, res_tx)).await.is_err() {
drop(e.remove_entry());
Err(Error::TransportIsDead)
} else {
#[expect( // intentional
clippy::map_err_ignore,
reason = "only errors on channel closing",
)]
res_rx.await.map_err(|_| Error::TransportIsDead)?
}
}
Entry::Vacant(_) => Err(Error::TransportIsDead),
}
}
fn local_addr(&self) -> SocketAddr {
self.local_addr
}
fn proto(&self) -> u8 {
PROTO_TCP
}
}
impl Server {
pub fn new(listener: TcpListener) -> Result<Self, Error> {
let local_addr = listener.local_addr()?;
let (ingress_tx, ingress_rx) = mpsc::channel(256);
let writers = Arc::new(Mutex::new(HashMap::new()));
drop(tokio::spawn({
let writers = Arc::clone(&writers);
async move {
loop {
tokio::select! {
stream = listener.accept() => {
match stream {
Ok((stream, remote)) => {
Self::spawn_stream_handler(
stream,
local_addr,
remote,
ingress_tx.clone(),
Arc::clone(&writers),
);
},
Err(_) => {
break;
}
}
}
() = ingress_tx.closed() => {
break;
}
}
}
log::debug!("Closing `TcpListener` at {local_addr}");
}
}));
Ok(Self { ingress_rx: Mutex::new(ingress_rx), local_addr, writers })
}
fn spawn_stream_handler(
mut stream: TcpStream,
local: SocketAddr,
remote: SocketAddr,
ingress_tx: mpsc::Sender<(Request, SocketAddr)>,
writers: TcpWritersMap,
) {
drop(tokio::spawn(async move {
let (egress_tx, mut egress_rx) = mpsc::channel::<(
Vec<u8>,
oneshot::Sender<Result<(), Error>>,
)>(256);
drop(writers.lock().await.insert(remote, egress_tx));
let (reader, mut writer) = stream.split();
let mut reader = FramedRead::new(reader, Codec::default());
loop {
tokio::select! {
msg = egress_rx.recv() => {
if let Some((msg, tx)) = msg {
let res =
writer.write_all(msg.as_slice()).await
.map_err(Error::from);
drop(tx.send(res));
} else {
log::debug!("Closing TCP {local} <=> {remote}");
break;
}
},
msg = reader.next() => {
match msg {
Some(Ok(msg)) => {
match ingress_tx.try_send((msg, remote)) {
Ok(()) => {},
Err(TrySendError::Full(_)) => {
log::debug!(
"Dropped ingress message from TCP \
{local} <=> {remote}",
);
}
Err(TrySendError::Closed(_)) =>
{
log::debug!(
"Closing TCP {local} <=> {remote}",
);
break;
}
}
}
Some(Err(_)) => {},
None => {
log::debug!("Closing TCP {local} <=> {remote}");
break;
}
}
},
}
}
}));
}
}
/// Kind of a frame received over TCP, along with its total length in bytes.
#[derive(Clone, Copy, Debug)]
enum RequestKind {
    /// STUN message of the given total length.
    Message(usize),

    /// `ChannelData` message of the given total length.
    ChannelData(usize),
}

impl RequestKind {
    /// Detects the [`RequestKind`] of a frame from its first 4 bytes.
    ///
    /// Bytes `2..4` are read as a big-endian body length. A fixed 20 (STUN
    /// message) or 4 (`ChannelData`) bytes of header are added to it, and the
    /// sum is passed through `nearest_padded_value_length()`.
    fn detect_kind(first_4_bytes: [u8; 4]) -> Self {
        let [lead, _, len_hi, len_lo] = first_4_bytes;
        let body_len = usize::from(u16::from_be_bytes([len_hi, len_lo]));

        // The two most significant bits are `00` for a STUN message. Any other
        // value is treated as `ChannelData`.
        let is_stun_message = (lead & 0b1100_0000) == 0;
        if is_stun_message {
            Self::Message(nearest_padded_value_length(body_len + 20))
        } else {
            Self::ChannelData(nearest_padded_value_length(body_len + 4))
        }
    }

    /// Returns the total length of the frame, in bytes.
    const fn length(&self) -> usize {
        match *self {
            Self::Message(len) | Self::ChannelData(len) => len,
        }
    }
}
/// [`Decoder`] splitting a TCP byte stream into [`Request`]s.
///
/// STUN messages and `ChannelData` frames arrive on the same stream, so the
/// first 4 bytes of each frame are inspected to learn its kind and length.
#[derive(Default)]
struct Codec {
    /// Kind (and total length) of the frame whose first 4 bytes have already
    /// been inspected, but whose remaining bytes haven't fully arrived yet.
    current: Option<RequestKind>,

    /// Decoder of STUN messages.
    msg_decoder: MessageDecoder<Attribute>,
}

impl Decoder for Codec {
    type Item = Request;
    type Error = Error;

    /// Decodes the next [`Request`] from `src` once its frame has fully
    /// arrived, returning `Ok(None)` while more bytes are needed.
    ///
    /// # Errors
    ///
    /// If a complete frame fails to decode as a STUN message or as
    /// `ChannelData`. The frame's bytes are consumed and the pending kind is
    /// reset before decoding, so the codec's state stays at a frame boundary.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<Self::Item>, Self::Error> {
        let kind = match self.current {
            Some(kind) => kind,
            None => {
                let Some(&header) = src.first_chunk::<4>() else {
                    return Ok(None);
                };
                let kind = RequestKind::detect_kind(header);
                self.current = Some(kind);
                kind
            }
        };

        let len = kind.length();
        if src.len() < len {
            // Make room for the rest of the frame up front, rather than
            // letting the buffer grow piece by piece as bytes arrive.
            src.reserve(len - src.len());
            return Ok(None);
        }

        // The frame is complete: forget its kind before decoding, so a
        // decoding error doesn't leave a stale kind behind.
        self.current = None;
        let raw = src.split_to(len);
        let req = match kind {
            RequestKind::Message(_) => {
                let msg = self
                    .msg_decoder
                    .decode_from_bytes(&raw)
                    .map_err(|e| Error::Decode(*e.kind()))?
                    .map_err(|e| Error::Decode(*e.error().kind()))?;
                Request::Message(msg)
            }
            RequestKind::ChannelData(_) => {
                Request::ChannelData(ChannelData::decode(raw.to_vec())?)
            }
        };
        Ok(Some(req))
    }
}