use super::{
ManagerContext, manager::handle_manager_as_server, rate_limiter::RateLimiter,
worker::handle_worker_join,
};
use crate::config::ServerConfigOpts;
use eyre::Report;
use quinn::Endpoint;
use socket2::{SockRef, TcpKeepalive};
use std::{
sync::Arc,
sync::atomic::{AtomicUsize, Ordering},
};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::Duration;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, warn};
use volli_core::env_config;
use volli_transport::{QuicTransport, TcpTransport, Transport};
/// Atomically claims one connection slot, provided fewer than `max` are in use.
///
/// Returns `true` when the counter was incremented and `false` when it already
/// held `max` or more. The limit check and the increment are applied as one
/// atomic read-modify-write, so concurrent callers cannot push the counter
/// past `max`.
pub(crate) fn try_increment_connection(active: &Arc<AtomicUsize>, max: usize) -> bool {
    active
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |in_use| {
            // Returning `None` aborts the update and leaves the counter untouched.
            // The closure form keeps `in_use + 1` from being evaluated at the limit.
            (in_use < max).then(|| in_use + 1)
        })
        .is_ok()
}
async fn dispatch_conn(
ctx: ManagerContext,
proto: Option<&str>,
transport: Box<dyn Transport>,
addr: std::net::SocketAddr,
worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
tracing::debug!(
"Dispatching connection from {} with ALPN protocol: {:?}",
addr,
proto
);
match proto {
Some("volli/worker") => {
tracing::debug!("Handling worker client connection from {}", addr);
if !worker_limiter.lock().await.check(addr.ip()) {
warn!(peer=%addr, "worker connection rate limit exceeded");
return;
}
handle_worker_join(&ctx, transport, addr).await.ok();
}
Some("volli/manager") => {
tracing::debug!("Handling manager client connection from {}", addr);
handle_manager_as_server(&ctx, transport, addr).await.ok();
}
_ => {
tracing::warn!(
"Unknown or missing ALPN protocol {:?} from {}, dropping connection",
proto,
addr
);
}
}
}
async fn process_tcp(
ctx: ManagerContext,
stream: TcpStream,
addr: std::net::SocketAddr,
acceptor: TlsAcceptor,
handshake_timeout: Duration,
worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
let stream = match stream.into_std() {
Ok(s) => {
let keepalive = TcpKeepalive::new()
.with_time(std::time::Duration::from_secs(
env_config().heartbeat_secs(),
))
.with_interval(std::time::Duration::from_secs(
env_config().heartbeat_secs(),
));
let sock = SockRef::from(&s);
sock.set_keepalive(true).ok();
sock.set_tcp_keepalive(&keepalive).ok();
s.set_nonblocking(true).ok();
match TcpStream::from_std(s) {
Ok(ts) => ts,
Err(e) => {
warn!("failed to construct tokio stream: {}", e);
return;
}
}
}
Err(e) => {
warn!("failed to convert tcp stream: {}", e);
return;
}
};
if let Ok(Ok(tls)) = tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
let proto = tls
.get_ref()
.1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
dispatch_conn(
ctx,
proto.as_deref(),
Box::new(TcpTransport::new(tls)),
addr,
worker_limiter,
)
.await;
}
}
/// Completes a QUIC handshake, accepts the first bidirectional stream, and dispatches it.
///
/// Both the handshake and the wait for the first bidirectional stream are
/// bounded by `handshake_timeout`. The ALPN protocol is read from the
/// connection's rustls handshake data; a value that is missing or not valid
/// UTF-8 is passed to `dispatch_conn` as `None`.
///
/// Every failure path logs the reason at debug level. If the stream cannot be
/// accepted the connection is closed explicitly.
async fn process_quic(
    ctx: ManagerContext,
    connecting: quinn::Connecting,
    addr: std::net::SocketAddr,
    handshake_timeout: Duration,
    worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
    // Outer `Err` is the timeout elapsing; inner `Err` is the handshake itself failing.
    let conn = match tokio::time::timeout(handshake_timeout, connecting).await {
        Ok(Ok(conn)) => conn,
        Ok(Err(err)) => {
            debug!(peer=%addr, "QUIC handshake failed: {}", err);
            return;
        }
        Err(_) => {
            debug!(peer=%addr, "QUIC handshake timed out");
            return;
        }
    };
    debug!(peer=%addr, "Opening QUIC connection");

    let protocol = conn
        .handshake_data()
        .and_then(|d| d.downcast::<quinn::crypto::rustls::HandshakeData>().ok())
        .and_then(|hd| hd.protocol.clone());

    match tokio::time::timeout(handshake_timeout, conn.accept_bi()).await {
        Ok(Ok((send, recv))) => {
            let proto_str = protocol
                .as_deref()
                .and_then(|p| std::str::from_utf8(p).ok());
            dispatch_conn(
                ctx,
                proto_str,
                Box::new(QuicTransport::new(send, recv)),
                addr,
                worker_limiter,
            )
            .await;
        }
        Ok(Err(err)) => {
            debug!(peer=%addr, "QUIC handshake ok but stream accept failed: {}; closing", err);
            conn.close(0u32.into(), b"accept-failed");
        }
        Err(_) => {
            debug!(peer=%addr, "QUIC handshake ok but stream accept timed out; closing");
            conn.close(0u32.into(), b"accept-timeout");
        }
    }
}
pub(crate) async fn accept_loop(
ctx: ManagerContext,
cfg: ServerConfigOpts,
tcp_listener: TcpListener,
tls_acceptor: TlsAcceptor,
quic_endpoint: Endpoint,
active_connections: Arc<AtomicUsize>,
) -> Result<(), Report> {
let base_ctx = ctx;
let mut limiter = RateLimiter::new(cfg.max_conn_rate, cfg.conn_rate_interval);
let worker_limit = if cfg.max_worker_conn_rate > 0 {
cfg.max_worker_conn_rate
} else {
cfg.max_conn_rate
};
let worker_interval = if cfg.max_worker_conn_rate > 0 {
cfg.worker_conn_rate_interval
} else {
cfg.conn_rate_interval
};
let worker_limiter = Arc::new(tokio::sync::Mutex::new(RateLimiter::new(
worker_limit,
worker_interval,
)));
let handshake_timeout = Duration::from_millis(cfg.handshake_timeout_ms);
loop {
tokio::select! {
Ok((stream, addr)) = tcp_listener.accept() => {
if !limiter.check(addr.ip()) {
warn!(peer=%addr, "connection rate limit exceeded");
drop(stream);
continue;
}
if !try_increment_connection(&active_connections, cfg.max_connections) {
warn!(peer=%addr, "connection limit reached");
drop(stream);
continue;
}
tracing::trace!(peer=%addr, "Accepted TCP connection");
let counter = active_connections.clone();
let ctx = base_ctx.clone();
let acceptor = tls_acceptor.clone();
let timeout = handshake_timeout;
let worker_limiter = worker_limiter.clone();
tokio::spawn(async move {
process_tcp(ctx, stream, addr, acceptor, timeout, worker_limiter).await;
counter.fetch_sub(1, Ordering::SeqCst);
});
}
Some(incoming) = quic_endpoint.accept() => {
let addr = incoming.remote_address();
if !limiter.check(addr.ip()) {
warn!(peer=%addr, "connection rate limit exceeded");
incoming.refuse();
continue;
}
if !try_increment_connection(&active_connections, cfg.max_connections) {
warn!(peer=%addr, "connection limit reached");
incoming.refuse();
continue;
}
let connecting = match incoming.accept() {
Ok(connecting) => connecting,
Err(err) => {
warn!(peer=%addr, "failed to accept incoming QUIC connection: {}", err);
active_connections.fetch_sub(1, Ordering::SeqCst);
continue;
}
};
tracing::trace!(peer=%addr, "Accepted QUIC connection");
let counter = active_connections.clone();
let ctx = base_ctx.clone();
let timeout = handshake_timeout;
let worker_limiter = worker_limiter.clone();
tokio::spawn(async move {
process_quic(ctx, connecting, addr, timeout, worker_limiter).await;
counter.fetch_sub(1, Ordering::SeqCst);
});
}
}
}
}
#[cfg(test)]
mod tests {
    use super::try_increment_connection;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    /// Builds a shared counter with no slots claimed.
    fn idle_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    /// With a limit of one, the first claim succeeds and the second is rejected.
    #[test]
    fn refuses_when_limit_reached() {
        let slots = idle_counter();
        let first_claim = try_increment_connection(&slots, 1);
        let second_claim = try_increment_connection(&slots, 1);
        assert!(first_claim);
        assert!(!second_claim);
    }

    /// Releasing a claimed slot makes room for another claim under the same limit.
    #[test]
    fn allows_after_decrement() {
        let slots = idle_counter();
        assert!(try_increment_connection(&slots, 1));
        slots.fetch_sub(1, Ordering::SeqCst);
        assert!(try_increment_connection(&slots, 1));
    }
}