volli-manager 0.1.12

Manager for volli
Documentation
use super::{
    ManagerContext, manager::handle_manager_as_server, rate_limiter::RateLimiter,
    worker::handle_worker_join,
};
use crate::config::ServerConfigOpts;
use eyre::Report;

use quinn::Endpoint;
use socket2::{SockRef, TcpKeepalive};
use std::{
    sync::Arc,
    sync::atomic::{AtomicUsize, Ordering},
};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::Duration;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, warn};
use volli_core::env_config;
use volli_transport::{QuicTransport, TcpTransport, Transport};

/// Tries to reserve one connection slot without letting `active` exceed `max`.
///
/// The counter is bumped by one atomically only while its current value is
/// strictly below `max`. Returns `true` when a slot was reserved. On `false`
/// the counter is left unchanged, so callers must only release (decrement)
/// after a successful reservation.
pub(crate) fn try_increment_connection(active: &Arc<AtomicUsize>, max: usize) -> bool {
    let reserve = |current: usize| {
        if current < max {
            Some(current + 1)
        } else {
            None
        }
    };
    active
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, reserve)
        .is_ok()
}

async fn dispatch_conn(
    ctx: ManagerContext,
    proto: Option<&str>,
    transport: Box<dyn Transport>,
    addr: std::net::SocketAddr,
    worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
    tracing::debug!(
        "Dispatching connection from {} with ALPN protocol: {:?}",
        addr,
        proto
    );

    match proto {
        Some("volli/worker") => {
            tracing::debug!("Handling worker client connection from {}", addr);
            if !worker_limiter.lock().await.check(addr.ip()) {
                warn!(peer=%addr, "worker connection rate limit exceeded");
                return;
            }
            handle_worker_join(&ctx, transport, addr).await.ok();
        }
        Some("volli/manager") => {
            tracing::debug!("Handling manager client connection from {}", addr);
            handle_manager_as_server(&ctx, transport, addr).await.ok();
        }
        _ => {
            tracing::warn!(
                "Unknown or missing ALPN protocol {:?} from {}, dropping connection",
                proto,
                addr
            );
        }
    }
}

async fn process_tcp(
    ctx: ManagerContext,
    stream: TcpStream,
    addr: std::net::SocketAddr,
    acceptor: TlsAcceptor,
    handshake_timeout: Duration,
    worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
    let stream = match stream.into_std() {
        Ok(s) => {
            let keepalive = TcpKeepalive::new()
                .with_time(std::time::Duration::from_secs(
                    env_config().heartbeat_secs(),
                ))
                .with_interval(std::time::Duration::from_secs(
                    env_config().heartbeat_secs(),
                ));
            let sock = SockRef::from(&s);
            sock.set_keepalive(true).ok();
            sock.set_tcp_keepalive(&keepalive).ok();
            s.set_nonblocking(true).ok();
            match TcpStream::from_std(s) {
                Ok(ts) => ts,
                Err(e) => {
                    warn!("failed to construct tokio stream: {}", e);
                    return;
                }
            }
        }
        Err(e) => {
            warn!("failed to convert tcp stream: {}", e);
            return;
        }
    };
    if let Ok(Ok(tls)) = tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
        let proto = tls
            .get_ref()
            .1
            .alpn_protocol()
            .map(|p| String::from_utf8_lossy(p).to_string());
        dispatch_conn(
            ctx,
            proto.as_deref(),
            Box::new(TcpTransport::new(tls)),
            addr,
            worker_limiter,
        )
        .await;
    }
}

/// Completes a QUIC handshake, accepts the client's first bidirectional
/// stream, and dispatches it.
///
/// Both the handshake and the wait for the first bidirectional stream are
/// bounded by `handshake_timeout`.
///
/// - Handshake failures and timeouts are logged at debug level and the
///   connection is dropped.
/// - If the stream cannot be accepted in time, the connection is closed
///   explicitly (code 0, reason `accept-timeout`) so the client is told.
/// - The ALPN protocol from the rustls handshake data is used only if it is
///   valid UTF-8, and selects the handler via `dispatch_conn`.
async fn process_quic(
    ctx: ManagerContext,
    connecting: quinn::Connecting,
    addr: std::net::SocketAddr,
    handshake_timeout: Duration,
    worker_limiter: Arc<tokio::sync::Mutex<RateLimiter>>,
) {
    let conn = match tokio::time::timeout(handshake_timeout, connecting).await {
        Ok(Ok(conn)) => conn,
        Ok(Err(e)) => {
            debug!(peer=%addr, "QUIC handshake failed: {}", e);
            return;
        }
        Err(_) => {
            debug!(peer=%addr, "QUIC handshake timed out after {:?}", handshake_timeout);
            return;
        }
    };

    debug!(peer=%addr, "Opening QUIC connection");
    let protocol = conn
        .handshake_data()
        .and_then(|d| d.downcast::<quinn::crypto::rustls::HandshakeData>().ok())
        .and_then(|hd| hd.protocol.clone());

    match tokio::time::timeout(handshake_timeout, conn.accept_bi()).await {
        Ok(Ok((send, recv))) => {
            let proto_str = protocol
                .as_deref()
                .and_then(|p| std::str::from_utf8(p).ok());
            dispatch_conn(
                ctx,
                proto_str,
                Box::new(QuicTransport::new(send, recv)),
                addr,
                worker_limiter,
            )
            .await;
        }
        // If the stream accept times out or fails, proactively close to signal the client.
        _ => {
            tracing::debug!("QUIC handshake ok but stream accept failed/timed out; closing");
            conn.close(0u32.into(), b"accept-timeout");
        }
    }
}

pub(crate) async fn accept_loop(
    ctx: ManagerContext,
    cfg: ServerConfigOpts,
    tcp_listener: TcpListener,
    tls_acceptor: TlsAcceptor,
    quic_endpoint: Endpoint,
    active_connections: Arc<AtomicUsize>,
) -> Result<(), Report> {
    let base_ctx = ctx;
    let mut limiter = RateLimiter::new(cfg.max_conn_rate, cfg.conn_rate_interval);
    let worker_limit = if cfg.max_worker_conn_rate > 0 {
        cfg.max_worker_conn_rate
    } else {
        cfg.max_conn_rate
    };
    let worker_interval = if cfg.max_worker_conn_rate > 0 {
        cfg.worker_conn_rate_interval
    } else {
        cfg.conn_rate_interval
    };
    let worker_limiter = Arc::new(tokio::sync::Mutex::new(RateLimiter::new(
        worker_limit,
        worker_interval,
    )));
    let handshake_timeout = Duration::from_millis(cfg.handshake_timeout_ms);
    loop {
        tokio::select! {
            Ok((stream, addr)) = tcp_listener.accept() => {
                if !limiter.check(addr.ip()) {
                    warn!(peer=%addr, "connection rate limit exceeded");
                    drop(stream);
                    continue;
                }
                if !try_increment_connection(&active_connections, cfg.max_connections) {
                    warn!(peer=%addr, "connection limit reached");
                    drop(stream);
                    continue;
                }
                tracing::trace!(peer=%addr, "Accepted TCP connection");
                let counter = active_connections.clone();
                let ctx = base_ctx.clone();
                let acceptor = tls_acceptor.clone();
                let timeout = handshake_timeout;
                let worker_limiter = worker_limiter.clone();
                tokio::spawn(async move {
                    process_tcp(ctx, stream, addr, acceptor, timeout, worker_limiter).await;
                    counter.fetch_sub(1, Ordering::SeqCst);
                });
            }
            Some(incoming) = quic_endpoint.accept() => {
                let addr = incoming.remote_address();
                if !limiter.check(addr.ip()) {
                    warn!(peer=%addr, "connection rate limit exceeded");
                    incoming.refuse();
                    continue;
                }
                if !try_increment_connection(&active_connections, cfg.max_connections) {
                    warn!(peer=%addr, "connection limit reached");
                    incoming.refuse();
                    continue;
                }
                let connecting = match incoming.accept() {
                    Ok(connecting) => connecting,
                    Err(err) => {
                        warn!(peer=%addr, "failed to accept incoming QUIC connection: {}", err);
                        active_connections.fetch_sub(1, Ordering::SeqCst);
                        continue;
                    }
                };
                tracing::trace!(peer=%addr, "Accepted QUIC connection");
                let counter = active_connections.clone();
                let ctx = base_ctx.clone();
                let timeout = handshake_timeout;
                let worker_limiter = worker_limiter.clone();
                tokio::spawn(async move {
                    process_quic(ctx, connecting, addr, timeout, worker_limiter).await;
                    counter.fetch_sub(1, Ordering::SeqCst);
                });
            }
        }
    }
}

#[cfg(test)]
mod tests {
    // Exercises the slot-reservation helper that enforces `max_connections`.
    use super::try_increment_connection;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    /// Builds a counter representing zero active connections.
    fn idle_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    #[test]
    fn refuses_when_limit_reached() {
        let slots = idle_counter();
        assert!(try_increment_connection(&slots, 1));
        // The only slot is taken, so a second reservation must fail.
        assert!(!try_increment_connection(&slots, 1));
    }

    #[test]
    fn allows_after_decrement() {
        let slots = idle_counter();
        assert!(try_increment_connection(&slots, 1));
        // Releasing the slot, as a finished connection does, frees it again.
        slots.fetch_sub(1, Ordering::SeqCst);
        assert!(try_increment_connection(&slots, 1));
    }
}