mod request;
use std::{collections::HashMap, sync::Arc};
use derive_more::with_trait::Debug;
use tokio::{
sync::{
broadcast::{
error::RecvError,
{self},
},
mpsc, oneshot,
},
time::Duration,
};
#[cfg(doc)]
use crate::allocation::Allocation;
use crate::{
AuthHandler, Error,
allocation::{FiveTuple, Info, Manager, ManagerConfig},
relay,
transport::{self, Transport},
};
/// Default lifetime of 10 minutes.
///
/// In this module it is the fallback used whenever
/// [`Config::channel_bind_lifetime`] is zero.
pub(crate) const DEFAULT_LIFETIME: Duration = Duration::from_secs(600);

/// MTU (in bytes) for inbound traffic.
///
/// NOTE(review): not used in this module; presumably sizes inbound receive
/// buffers elsewhere in the crate — confirm at the use sites.
pub(crate) const INBOUND_MTU: usize = 1_500;
/// Configuration of a [`Server`], consumed by [`Server::new()`].
#[derive(Debug)]
pub struct Config<A> {
    /// Transports to serve on.
    ///
    /// [`Server::new()`] spawns one task per each of them. In `Debug` output
    /// this field is rendered as a list of `(local_addr, proto)` pairs.
    #[debug("{:?}", connections.iter()
        .map(|c| (c.local_addr(), c.proto()))
        .collect::<Vec<_>>())]
    pub connections: Vec<Arc<dyn Transport + Send + Sync>>,
    /// [`relay::Allocator`] cloned into the allocation manager of every
    /// connection.
    pub relay_addr_generator: relay::Allocator,
    /// Realm passed to the request handling of every connection.
    pub realm: String,
    /// [`AuthHandler`] shared by all the connections and passed to the
    /// request handling of each of them.
    pub auth_handler: Arc<A>,
    /// Lifetime of channel bindings.
    ///
    /// A zero value is replaced with a 10 minutes default.
    pub channel_bind_lifetime: Duration,
    /// Optional channel cloned into the allocation manager of every
    /// connection.
    ///
    /// NOTE(review): presumably receives the [`Info`] of each closed
    /// allocation — confirm against the allocation manager.
    pub alloc_close_notify: Option<mpsc::Sender<Info>>,
}
/// Handle to a running server, which runs one task per each of its
/// [`Config::connections`].
///
/// This handle holds the only sender of the command channel, so dropping it
/// closes the channel and makes every connection task stop.
#[derive(Debug)]
pub struct Server {
    /// Sender broadcasting commands to all the connection tasks.
    command_tx: broadcast::Sender<Command>,
}
impl Server {
    /// Creates a new [`Server`] serving all the provided
    /// [`Config::connections`].
    ///
    /// A detached task is [`tokio::spawn()`]ed for each connection. The task
    /// owns its own allocation manager and nonce storage. It handles every
    /// message received from its connection and executes the commands
    /// broadcast by the returned [`Server`]. It stops on a fatal transport
    /// error, or once the returned [`Server`] is dropped.
    ///
    /// A zero [`Config::channel_bind_lifetime`] is replaced with a 10 minutes
    /// default.
    ///
    /// # Panics
    ///
    /// If called outside of a [`tokio`] runtime while there is at least one
    /// connection, since [`tokio::spawn()`] panics in that case.
    #[must_use]
    pub fn new<A>(config: Config<A>) -> Self
    where
        A: AuthHandler + Send + Sync + 'static,
    {
        let (command_tx, _) = broadcast::channel(16);
        let this = Self { command_tx: command_tx.clone() };

        // Zero means "not configured", so fall back to the default.
        let channel_bind_lifetime = if config.channel_bind_lifetime.is_zero() {
            DEFAULT_LIFETIME
        } else {
            config.channel_bind_lifetime
        };

        for conn in config.connections {
            let auth_handler = Arc::clone(&config.auth_handler);
            let realm = config.realm.clone();
            let mut nonces = HashMap::new();
            let mut handle_rx = command_tx.subscribe();
            let mut allocation_manager = Manager::new(ManagerConfig {
                relay_addr_generator: config.relay_addr_generator.clone(),
                alloc_close_notify: config.alloc_close_notify.clone(),
            });
            // NOTE(review): both ends of this channel are moved into the task
            // below, and `close_rx` is closed only right before the loop
            // breaks. So the `close_tx.closed()` branch appears never to win.
            // Kept until confirmed safe to remove.
            let (mut close_tx, mut close_rx) = oneshot::channel::<()>();

            drop(tokio::spawn(async move {
                let local_con_addr = conn.local_addr();
                let protocol = conn.proto();

                loop {
                    let (msg, src_addr) = tokio::select! {
                        cmd = handle_rx.recv() => {
                            match cmd {
                                Ok(Command::DeleteAllocations(
                                    name,
                                    completion,
                                )) => {
                                    allocation_manager
                                        .delete_allocations_by_username(&name);
                                    // Dropping this clone of the completion
                                    // handle is what lets the waiting
                                    // `Server::delete_allocations_by_username()`
                                    // resolve once all clones are gone.
                                    drop(completion);
                                }
                                Ok(Command::GetAllocationsInfo(
                                    five_tuples,
                                    tx,
                                )) => {
                                    let infos = allocation_manager
                                        .get_allocations_info(
                                            five_tuples.as_ref(),
                                        );
                                    // Fails only if the requester has gone
                                    // away already, which is fine to ignore.
                                    drop(tx.send(infos).await);
                                }
                                Err(RecvError::Closed) => {
                                    // Every sender, i.e. the `Server` handle,
                                    // has been dropped.
                                    close_rx.close();
                                    break;
                                }
                                Err(RecvError::Lagged(n)) => {
                                    log::warn!(
                                        "`Server` has lagged by {n} messages",
                                    );
                                }
                            }
                            continue;
                        },
                        v = conn.recv_from() => {
                            match v {
                                Ok(v) => v,
                                Err(e) if e.is_fatal() => {
                                    log::error!(
                                        "Exit `Server` read loop on transport \
                                         recv error: {e}",
                                    );
                                    break;
                                }
                                Err(e) => {
                                    log::debug!("`Request` parse error: {e}");
                                    continue;
                                }
                            }
                        },
                        () = close_tx.closed() => break
                    };

                    let handle = request::handle(
                        msg,
                        &conn,
                        FiveTuple {
                            src_addr,
                            dst_addr: local_con_addr,
                            protocol,
                        },
                        &realm,
                        channel_bind_lifetime,
                        &mut allocation_manager,
                        &mut nonces,
                        &auth_handler,
                    );
                    if let Err(e) = handle.await {
                        log::warn!("Error when handling `Request`: {e}");
                    }
                }
            }));
        }

        this
    }

    /// Deletes all the [`Allocation`]s of the provided `username` on every
    /// connection of this [`Server`].
    ///
    /// Resolves once every clone of the internal completion handle has been
    /// dropped. That happens after each connection task that received the
    /// request has processed it.
    ///
    /// # Errors
    ///
    /// With [`Error::Closed`] if no connection task is running anymore.
    pub async fn delete_allocations_by_username(
        &self,
        username: String,
    ) -> Result<(), Error> {
        // The receiver is never read. It only acts as a shared completion
        // handle whose last drop wakes `closed_tx.closed()` below.
        let (closed_tx, closed_rx) = mpsc::channel(1);
        #[expect(clippy::map_err_ignore, reason = "only errors on closing")]
        let _: usize = self
            .command_tx
            .send(Command::DeleteAllocations(username, Arc::new(closed_rx)))
            .map_err(|_| Error::Closed)?;
        closed_tx.closed().await;
        Ok(())
    }

    /// Collects the [`Info`] about allocations from all the connections of
    /// this [`Server`].
    ///
    /// `five_tuples` is forwarded to every connection's allocation manager.
    /// An explicitly empty list short-circuits to an empty result.
    ///
    /// # Errors
    ///
    /// With [`Error::Closed`] if no connection task is running anymore, or
    /// if the reply channel closes before every counted task has replied.
    pub async fn get_allocations_info(
        &self,
        five_tuples: Option<Vec<FiveTuple>>,
    ) -> Result<HashMap<FiveTuple, Info>, Error> {
        // An explicitly empty filter can match nothing, so don't bother the
        // connection tasks at all.
        if five_tuples.as_ref().is_some_and(Vec::is_empty) {
            return Ok(HashMap::new());
        }

        let (infos_tx, mut infos_rx) = mpsc::channel(1);
        #[expect(clippy::map_err_ignore, reason = "only errors on closing")]
        let _: usize = self
            .command_tx
            .send(Command::GetAllocationsInfo(five_tuples, infos_tx))
            .map_err(|_| Error::Closed)?;

        // Each connection task holds exactly one receiver and replies once
        // per received command, so expect one reply per live receiver.
        let mut info: HashMap<FiveTuple, Info> = HashMap::new();
        for _ in 0..self.command_tx.receiver_count() {
            info.extend(infos_rx.recv().await.ok_or(Error::Closed)?);
        }
        Ok(info)
    }
}
/// Command broadcast by a [`Server`] to all of its connection tasks.
///
/// It is `Clone` because every broadcast receiver gets its own copy.
#[derive(Clone)]
enum Command {
    /// Delete all the allocations of the given username.
    ///
    /// The receiver is never read. Each task just drops its clone when done,
    /// so the paired sender observes closure only once all clones are gone.
    DeleteAllocations(String, Arc<mpsc::Receiver<()>>),
    /// Collect allocation information for the given five-tuples, replying
    /// once via the provided sender.
    ///
    /// NOTE(review): `None` presumably means "all allocations" — confirm
    /// against the allocation manager.
    GetAllocationsInfo(
        Option<Vec<FiveTuple>>,
        mpsc::Sender<HashMap<FiveTuple, Info>>,
    ),
}
/// Classification of errors into fatal and recoverable ones.
trait FatalError {
    /// Indicates whether this error is fatal. A connection's read loop in
    /// [`Server::new()`] stops on a fatal error and skips a non-fatal one.
    fn is_fatal(&self) -> bool;
}
/// I/O failures and a dead transport are fatal. Malformed channel data and
/// decoding failures are recoverable.
impl FatalError for transport::Error {
    fn is_fatal(&self) -> bool {
        // Kept exhaustive on purpose: a new `transport::Error` variant must
        // force a decision on whether it stops the read loop.
        let recoverable = match self {
            Self::ChannelData(_) | Self::Decode(_) => true,
            Self::Io(_) | Self::TransportIsDead => false,
        };
        !recoverable
    }
}