#![forbid(unsafe_code)]
#![warn(clippy::all)]
mod connection_handler;
mod state;
use anyhow::Context;
use anyhow::anyhow;
use clap::Parser;
use connection_handler::handle_connection;
use log::{debug, error, info, warn};
use socket2::{Domain, Protocol, TcpKeepalive, Type};
use state::State;
use std::future::Future;
use std::net::SocketAddr;
use std::{
io::BufReader,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use tokio::task::JoinSet;
use tokio_rustls::{
TlsAcceptor,
rustls::{self, pki_types::CertificateDer},
};
// Command-line arguments for the server, parsed with clap's derive API.
//
// These are plain `//` comments on purpose: `///` doc comments on a clap `Parser` struct are
// turned into `--help` text, so using them here would change the program's output.
#[derive(Parser, Debug)]
#[command(author, version, about)]
pub struct Args {
    // PEM private-key file used for TLS (`-k` / `--key`).
    // Required unless `--unencrypted` is given.
    #[arg(short, long, required_unless_present("unencrypted"))]
    pub key: Option<PathBuf>,
    // PEM certificate(s) file used for TLS (`-c` / `--certificate`).
    // Required unless `--unencrypted` is given.
    #[arg(short, long, required_unless_present("unencrypted"))]
    pub certificate: Option<PathBuf>,
    // Serve plain TCP without TLS (`-u` / `--unencrypted`).
    // Mutually exclusive with `--key` and `--certificate`.
    #[arg(short, long, conflicts_with_all(["key", "certificate"]))]
    pub unencrypted: bool,
    // Socket addresses to listen on (`-a` / `--addresses`). Each one gets its own listener.
    // The default binds both the IPv4 and the IPv6 wildcard address on port 2311.
    #[arg(short, long, default_values = ["0.0.0.0:2311", "[::]:2311"])]
    pub addresses: Vec<SocketAddr>,
    // Number of seconds before a new room is deleted (`-t` / `--timeout`).
    // Passed to `State::new` as a `Duration`.
    #[arg(short, long, default_value = "600")]
    pub timeout: u64,
    // Critical requests per minute allowed per IP address (`-r` / `--request-limit`).
    // Passed to `State::new`.
    #[arg(short, long, default_value = "10")]
    pub request_limit: u32,
    // Log level filter handed to `env_logger` (`-v` / `--verbosity`), e.g. `info`, `debug`.
    #[arg(short, long, default_value = "debug")]
    pub verbosity: log::LevelFilter,
}
pub fn start_server(args: Args) -> anyhow::Result<(Vec<SocketAddr>, impl Future<Output = ()>)> {
if let Err(err) = env_logger::builder()
.filter_level(args.verbosity)
.try_init()
{
error!("Non-fatal error. Couldn't initialize logger: {err}")
}
let tcp_listeners: anyhow::Result<Vec<tokio::net::TcpListener>> =
args.addresses.into_iter().map(get_tcp_listener).collect();
let tcp_listeners = tcp_listeners?;
let addresses: std::io::Result<Vec<SocketAddr>> =
tcp_listeners.iter().map(|l| l.local_addr()).collect();
let addresses = addresses.context("Couldn't determine local address")?;
let tls_acceptor = if let (Some(key), Some(cert)) = (args.key, args.certificate) {
Some(get_tls_acceptor(&key, &cert)?)
} else {
None
};
let state = State::new(
args.request_limit,
std::time::Duration::from_secs(args.timeout),
);
let mut joinset = JoinSet::new();
for tcp_listener in tcp_listeners {
joinset.spawn(run_single_server(
state.clone(),
tcp_listener,
tls_acceptor.clone(),
));
}
let handle = async {
joinset.join_all().await;
};
info!("Listening on these addresses: {addresses:?}");
info!("Is encrypted?: {}", tls_acceptor.is_some());
info!(
"Critical requests per minute per IP address limit: {}",
args.request_limit
);
info!(
"Number of seconds before a new room is deleted: {}",
args.timeout
);
info!("Server is now running.");
Ok((addresses, handle))
}
async fn run_single_server(
state: State,
tcp_listener: tokio::net::TcpListener,
tls_acceptor: Option<TlsAcceptor>,
) {
loop {
let (stream, origin) = match tcp_listener.accept().await {
Ok(ok) => ok,
Err(err) => {
error!("Error accepting incoming TCP connection: {err}.");
continue;
}
};
debug!("Accepted incoming TCP connection from {origin}.");
tokio::spawn(handle_connection(
stream,
origin,
tls_acceptor.clone(),
state.clone(),
));
}
}
/// Creates a non-blocking Tokio TCP listener bound to `addr`.
///
/// The socket is configured before binding:
/// - IPv6 sockets are restricted to IPv6 traffic (`IPV6_V6ONLY`). This lets an IPv4 and an IPv6
///   wildcard address share the same port as separate listeners.
/// - `SO_REUSEADDR` is enabled on non-Windows platforms, matching `std::net::TcpListener::bind`.
///   A restarted server can then rebind while connections from a previous run linger in
///   `TIME_WAIT`.
/// - TCP keepalive is configured with 60 s of idle time before probing and 10 s between probes.
///
/// The listen backlog is 128.
///
/// # Errors
/// Returns an error if creating the socket, setting any option, binding, listening, or
/// registering the listener with Tokio fails.
fn get_tcp_listener(addr: SocketAddr) -> anyhow::Result<tokio::net::TcpListener> {
    let socket = socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
        .context("Couldn't create TCP socket")?;
    if addr.is_ipv6() {
        socket
            .set_only_v6(true)
            .with_context(|| format!("Couldn't set IPV6_V6ONLY on {addr}"))?;
    }
    // On Unix-like systems, SO_REUSEADDR only allows rebinding over TIME_WAIT leftovers.
    // On Windows it would let this process bind a port another process is actively using,
    // so (like std) it is skipped there.
    if !cfg!(windows) {
        socket
            .set_reuse_address(true)
            .with_context(|| format!("Couldn't set SO_REUSEADDR on {addr}"))?;
    }
    // Set on the listening socket. NOTE(review): this presumably relies on accepted connections
    // inheriting the keepalive options from the listener — confirm on each target platform.
    let tcp_keepalive = TcpKeepalive::new()
        .with_time(Duration::from_secs(60))
        .with_interval(Duration::from_secs(10));
    socket
        .set_tcp_keepalive(&tcp_keepalive)
        .context("Couldn't set TCP keepalive")?;
    socket
        .bind(&addr.into())
        .with_context(|| format!("Couldn't bind socket to address {addr}"))?;
    socket
        .listen(128)
        .with_context(|| format!("Couldn't listen on {addr}"))?;
    let listener: std::net::TcpListener = socket.into();
    // Tokio requires the std listener to be non-blocking before it is handed over.
    listener
        .set_nonblocking(true)
        .context("Couldn't set TCP socket to non blocking")?;
    let listener = tokio::net::TcpListener::from_std(listener)
        .context("Couldn't create async TCP listener")?;
    Ok(listener)
}
fn get_tls_acceptor(key_path: &Path, cert_path: &Path) -> anyhow::Result<TlsAcceptor> {
let key = std::fs::File::open(key_path)
.with_context(|| format!("Couldn't open key file {key_path:?}."))?;
let mut key = BufReader::new(key);
let key = rustls_pemfile::private_key(&mut key)
.with_context(|| format!("Couldn't parse key file {key_path:?}."))?
.ok_or(anyhow!("No private keys found in file {key_path:?}."))?;
let cert = std::fs::File::open(cert_path)
.with_context(|| format!("Couldn't open certificate file {cert_path:?}."))?;
let mut cert = BufReader::new(cert);
let cert: Result<Vec<CertificateDer<'static>>, _> = rustls_pemfile::certs(&mut cert).collect();
let cert = cert.with_context(|| format!("Couldn't parse certificate file {cert_path:?}."))?;
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key)
.context("Couldn't configure TLS")?;
Ok(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config)))
}