deepslate 0.3.1

A high-performance Minecraft server proxy written in Rust.
Documentation
//! Connection rate limiting and concurrency guards.
//!
//! Provides a [`ConnectionGuard`] that enforces a global connection cap
//! (via a [`tokio::sync::Semaphore`]) and per-IP concurrency limits (via a
//! simple `HashMap` behind a `Mutex`). Acquiring a [`ConnectionPermit`]
//! reserves capacity on both axes; dropping it releases both automatically.

use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};

use tokio::sync::{OwnedSemaphorePermit, Semaphore};

/// Why [`ConnectionGuard::try_acquire`] refused a connection.
///
/// Exactly one reason is reported per rejection: the global limit is checked
/// first, so `PerIpLimit` is only returned when a global slot was available.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RejectReason {
    /// Every global connection slot is currently in use.
    GlobalLimit,
    /// The connecting IP already holds its maximum number of concurrent
    /// connections.
    PerIpLimit,
}

/// Live connection count for each client IP, shared between the guard and
/// every permit it has issued.
///
/// The `Arc` lets a [`ConnectionPermit`] decrement its IP's counter when it is
/// dropped without holding a borrow of the [`ConnectionGuard`].
type PerIpState = Arc<Mutex<HashMap<IpAddr, u32>>>;

/// Enforces global and per-IP connection concurrency limits.
///
/// Create one instance per proxy run and call [`try_acquire`](Self::try_acquire)
/// for every accepted TCP connection. If it returns `Err`, the connection
/// should be dropped immediately; the [`RejectReason`] says which limit was hit.
pub struct ConnectionGuard {
    /// Global connection semaphore; each issued permit holds one slot.
    global: Arc<Semaphore>,
    /// Current connection count per IP address, shared with issued permits.
    per_ip: PerIpState,
    /// Maximum concurrent connections from a single IP.
    max_per_ip: u32,
}

impl ConnectionGuard {
    /// Create a new guard with the given limits.
    ///
    /// `max_connections` is clamped to [`Semaphore::MAX_PERMITS`], so a very
    /// large configured value behaves as "effectively unlimited" instead of
    /// tripping tokio's internal permit-count panic. This matters on 32-bit
    /// targets, where `u32::MAX` exceeds that bound.
    ///
    /// # Panics
    ///
    /// Panics if `max_connections` or `max_per_ip` is zero. These invariants
    /// are enforced by [`Config::validate`](crate::config::Config::validate)
    /// before the guard is constructed.
    #[must_use]
    pub fn new(max_connections: u32, max_per_ip: u32) -> Self {
        assert!(max_connections > 0, "max_connections must be > 0");
        assert!(max_per_ip > 0, "max_per_ip must be > 0");
        // `u32 -> usize` is lossless on every target tokio supports; the
        // `unwrap_or` only exists to avoid a silently-truncating `as` cast.
        let global_permits = usize::try_from(max_connections)
            .unwrap_or(usize::MAX)
            .min(Semaphore::MAX_PERMITS);
        Self {
            global: Arc::new(Semaphore::new(global_permits)),
            per_ip: Arc::new(Mutex::new(HashMap::new())),
            max_per_ip,
        }
    }

    /// Attempt to acquire a connection permit for `ip`.
    ///
    /// Returns `Ok(permit)` if both the global and per-IP limits allow the
    /// connection. The permit is an RAII guard — dropping it releases both
    /// the global semaphore slot and the per-IP counter.
    ///
    /// # Errors
    ///
    /// Returns `Err(reason)` if the connection should be rejected because the
    /// global or per-IP limit has been reached. The global limit is checked
    /// first, so `GlobalLimit` wins when both are exhausted.
    pub fn try_acquire(&self, ip: IpAddr) -> Result<ConnectionPermit, RejectReason> {
        // 1. Reserve a global slot without waiting.
        let global_permit = Arc::clone(&self.global)
            .try_acquire_owned()
            .map_err(|_| RejectReason::GlobalLimit)?;

        // 2. Check and bump the per-IP counter under the mutex. The lock is
        //    scoped so it is released before the permit is built.
        {
            let mut map = self.per_ip.lock().unwrap_or_else(|e| {
                tracing::warn!("per-IP lock was poisoned, recovering");
                e.into_inner()
            });
            // `max_per_ip > 0` (asserted in `new`), so a freshly inserted 0
            // entry always passes the check below and never lingers at 0.
            let count = map.entry(ip).or_insert(0);
            if *count >= self.max_per_ip {
                // Roll back step 1: the global slot must not leak on rejection.
                drop(global_permit);
                return Err(RejectReason::PerIpLimit);
            }
            *count += 1;
        }

        Ok(ConnectionPermit {
            _global: global_permit,
            per_ip: Arc::clone(&self.per_ip),
            ip,
        })
    }
}

/// RAII permit returned by [`ConnectionGuard::try_acquire`].
///
/// Holds both a global semaphore slot and a per-IP counter increment.
/// Dropping this permit releases both, allowing new connections, so it must
/// be kept alive for as long as the connection it guards.
#[must_use = "dropping a ConnectionPermit immediately releases its connection slot"]
pub struct ConnectionPermit {
    /// Held for the connection's lifetime; released automatically on drop.
    _global: OwnedSemaphorePermit,
    /// Shared per-IP state so we can decrement on drop.
    per_ip: PerIpState,
    /// The IP address this permit was issued for.
    ip: IpAddr,
}

impl std::fmt::Debug for ConnectionPermit {
    /// Shows only the permit's IP. The semaphore permit and the shared map are
    /// intentionally left out, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ConnectionPermit");
        builder.field("ip", &self.ip);
        builder.finish_non_exhaustive()
    }
}

impl Drop for ConnectionPermit {
    fn drop(&mut self) {
        let mut map = self.per_ip.lock().unwrap_or_else(|e| {
            tracing::warn!("per-IP lock was poisoned, recovering");
            e.into_inner()
        });
        if let Some(count) = map.get_mut(&self.ip) {
            *count -= 1;
            if *count == 0 {
                map.remove(&self.ip);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;

    const IP_A: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
    const IP_B: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

    #[test]
    fn global_limit_enforced() {
        let guard = ConnectionGuard::new(2, 10);

        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_B).unwrap();
        let err = guard.try_acquire(IP_A).unwrap_err();
        assert_eq!(err, RejectReason::GlobalLimit);

        // Dropping a permit frees the slot.
        drop(p1);
        let p3 = guard.try_acquire(IP_A).unwrap();

        drop(p2);
        drop(p3);
    }

    #[test]
    fn per_ip_limit_enforced() {
        let guard = ConnectionGuard::new(100, 2);

        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_A).unwrap();
        let err = guard.try_acquire(IP_A).unwrap_err();
        assert_eq!(err, RejectReason::PerIpLimit);

        // Different IP is still allowed.
        let p3 = guard.try_acquire(IP_B).unwrap();

        drop(p1);
        drop(p2);
        drop(p3);
    }

    #[test]
    fn different_ips_are_independent() {
        let guard = ConnectionGuard::new(100, 1);

        let _p1 = guard.try_acquire(IP_A).unwrap();
        let _p2 = guard.try_acquire(IP_B).unwrap();

        // Both at their per-IP limit — each blocks only their own IP.
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::PerIpLimit
        );
        assert_eq!(
            guard.try_acquire(IP_B).unwrap_err(),
            RejectReason::PerIpLimit
        );
    }

    #[test]
    fn permits_release_on_drop() {
        let guard = ConnectionGuard::new(1, 1);

        let permit = guard.try_acquire(IP_A).unwrap();
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::GlobalLimit
        );

        drop(permit);

        // Both global and per-IP should be available again.
        let _permit = guard.try_acquire(IP_A).unwrap();
    }

    #[test]
    fn zero_count_entries_cleaned_up() {
        let guard = ConnectionGuard::new(10, 2);

        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_A).unwrap();

        drop(p1);
        drop(p2);

        // The map should have no entry for IP_A after all permits are dropped.
        assert!(!guard.per_ip.lock().unwrap().contains_key(&IP_A));
    }

    #[test]
    fn per_ip_freed_but_global_exhausted() {
        // Global allows 2, per-IP allows 1.
        let guard = ConnectionGuard::new(2, 1);

        let _p1 = guard.try_acquire(IP_A).unwrap();
        let _p2 = guard.try_acquire(IP_B).unwrap();

        // Both global slots are taken (one per IP). A third IP still has
        // per-IP headroom, so it must be rejected by the global limit.
        let ip_c = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3));
        assert_eq!(
            guard.try_acquire(ip_c).unwrap_err(),
            RejectReason::GlobalLimit
        );
    }

    #[test]
    fn per_ip_rejection_releases_global_slot() {
        // Global allows 2, per-IP allows 1.
        let guard = ConnectionGuard::new(2, 1);

        let _p1 = guard.try_acquire(IP_A).unwrap();
        // Rejected on the per-IP axis after briefly taking a global slot.
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::PerIpLimit
        );
        // That slot must have been handed back, and IP_A's count untouched.
        assert_eq!(guard.global.available_permits(), 1);
        assert_eq!(guard.per_ip.lock().unwrap().get(&IP_A), Some(&1));

        // So another IP can still take the remaining global slot.
        let _p2 = guard.try_acquire(IP_B).unwrap();
        assert_eq!(guard.global.available_permits(), 0);
    }

    #[test]
    #[should_panic(expected = "max_connections must be > 0")]
    fn zero_max_connections_panics() {
        let _ = ConnectionGuard::new(0, 1);
    }

    #[test]
    #[should_panic(expected = "max_per_ip must be > 0")]
    fn zero_max_per_ip_panics() {
        let _ = ConnectionGuard::new(1, 0);
    }
}