use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Why [`ConnectionGuard::try_acquire`] refused to admit a connection.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so a rejection
/// can be logged for humans or propagated with `?` into a boxed error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RejectReason {
    /// No permit was available on the global semaphore.
    GlobalLimit,
    /// The requesting IP already holds its maximum number of permits.
    PerIpLimit,
}

impl std::fmt::Display for RejectReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::GlobalLimit => "global connection limit reached",
            Self::PerIpLimit => "per-IP connection limit reached",
        })
    }
}

impl std::error::Error for RejectReason {}
/// Shared, mutex-guarded map from a client IP address to the number of permits
/// currently held for that address.
type PerIpState = Arc<Mutex<HashMap<IpAddr, u32>>>;
/// Admission control for incoming connections: enforces a global cap and a
/// per-IP cap, both checked by `try_acquire`.
pub struct ConnectionGuard {
/// Semaphore created with `max_connections` permits; one permit is held per admitted connection.
global: Arc<Semaphore>,
/// Live connection count per client IP, shared with every issued `ConnectionPermit`.
per_ip: PerIpState,
/// Maximum number of simultaneous permits a single IP may hold.
max_per_ip: u32,
}
impl ConnectionGuard {
#[must_use]
pub fn new(max_connections: u32, max_per_ip: u32) -> Self {
assert!(max_connections > 0, "max_connections must be > 0");
assert!(max_per_ip > 0, "max_per_ip must be > 0");
Self {
global: Arc::new(Semaphore::new(max_connections as usize)),
per_ip: Arc::new(Mutex::new(HashMap::new())),
max_per_ip,
}
}
pub fn try_acquire(&self, ip: IpAddr) -> Result<ConnectionPermit, RejectReason> {
let global_permit = Arc::clone(&self.global)
.try_acquire_owned()
.map_err(|_| RejectReason::GlobalLimit)?;
{
let mut map = self.per_ip.lock().unwrap_or_else(|e| {
tracing::warn!("per-IP lock was poisoned, recovering");
e.into_inner()
});
let count = map.entry(ip).or_insert(0);
if *count >= self.max_per_ip {
drop(global_permit);
return Err(RejectReason::PerIpLimit);
}
*count += 1;
}
Ok(ConnectionPermit {
_global: global_permit,
per_ip: Arc::clone(&self.per_ip),
ip,
})
}
}
/// RAII token representing one admitted connection.
///
/// Holding the value keeps one global semaphore permit and one per-IP slot
/// reserved. Dropping it decrements the per-IP count and releases the global
/// permit, so it must be kept alive for as long as the connection is open.
#[must_use = "dropping the permit immediately releases the connection slot"]
pub struct ConnectionPermit {
    /// Global semaphore permit; never read, held only so that it is released on drop.
    _global: OwnedSemaphorePermit,
    /// Shared per-IP counters, needed to decrement this IP's count on drop.
    per_ip: PerIpState,
    /// Address this permit was issued for.
    ip: IpAddr,
}
/// Debug output shows only the permit's IP; the remaining fields are elided
/// and the struct is rendered as non-exhaustive.
impl std::fmt::Debug for ConnectionPermit {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("ConnectionPermit");
        repr.field("ip", &self.ip);
        repr.finish_non_exhaustive()
    }
}
/// Releases the per-IP slot held by this permit. The `_global` semaphore
/// permit is released afterwards, when the struct's fields are dropped.
impl Drop for ConnectionPermit {
    fn drop(&mut self) {
        // Recover from poisoning instead of panicking inside `drop`.
        let mut counts = match self.per_ip.lock() {
            Ok(guard) => guard,
            Err(poisoned) => {
                tracing::warn!("per-IP lock was poisoned, recovering");
                poisoned.into_inner()
            }
        };

        // Decrement this IP's counter and note whether it just reached zero.
        let now_idle = match counts.get_mut(&self.ip) {
            Some(active) => {
                *active -= 1;
                *active == 0
            }
            None => false,
        };

        // Drop empty entries so the map does not grow with every IP ever seen.
        if now_idle {
            counts.remove(&self.ip);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;

    const IP_A: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
    const IP_B: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

    /// The global cap rejects further connections until a permit is dropped.
    #[test]
    fn global_limit_enforced() {
        let guard = ConnectionGuard::new(2, 10);
        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_B).unwrap();
        let err = guard.try_acquire(IP_A).unwrap_err();
        assert_eq!(err, RejectReason::GlobalLimit);
        drop(p1);
        let p3 = guard.try_acquire(IP_A).unwrap();
        drop(p2);
        drop(p3);
    }

    /// A single IP is capped while other IPs can still connect.
    #[test]
    fn per_ip_limit_enforced() {
        let guard = ConnectionGuard::new(100, 2);
        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_A).unwrap();
        let err = guard.try_acquire(IP_A).unwrap_err();
        assert_eq!(err, RejectReason::PerIpLimit);
        let p3 = guard.try_acquire(IP_B).unwrap();
        drop(p1);
        drop(p2);
        drop(p3);
    }

    /// Each IP has its own counter.
    #[test]
    fn different_ips_are_independent() {
        let guard = ConnectionGuard::new(100, 1);
        let _p1 = guard.try_acquire(IP_A).unwrap();
        let _p2 = guard.try_acquire(IP_B).unwrap();
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::PerIpLimit
        );
        assert_eq!(
            guard.try_acquire(IP_B).unwrap_err(),
            RejectReason::PerIpLimit
        );
    }

    /// Dropping a permit frees both the global and the per-IP slot. While both
    /// limits are exhausted, the global limit is the one reported.
    #[test]
    fn permits_release_on_drop() {
        let guard = ConnectionGuard::new(1, 1);
        let permit = guard.try_acquire(IP_A).unwrap();
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::GlobalLimit
        );
        drop(permit);
        let _permit = guard.try_acquire(IP_A).unwrap();
    }

    /// Regression test: a per-IP rejection must give its global permit back,
    /// otherwise rejected attempts would eat into the global capacity.
    #[test]
    fn per_ip_rejection_returns_global_permit() {
        let guard = ConnectionGuard::new(2, 1);
        let _p1 = guard.try_acquire(IP_A).unwrap();
        assert_eq!(
            guard.try_acquire(IP_A).unwrap_err(),
            RejectReason::PerIpLimit
        );
        // Exactly one global slot is left; it must still be available.
        let _p2 = guard.try_acquire(IP_B).unwrap();
    }

    /// Counters that reach zero are removed from the map.
    #[test]
    fn zero_count_entries_cleaned_up() {
        let guard = ConnectionGuard::new(10, 2);
        let p1 = guard.try_acquire(IP_A).unwrap();
        let p2 = guard.try_acquire(IP_A).unwrap();
        drop(p1);
        drop(p2);
        assert!(!guard.per_ip.lock().unwrap().contains_key(&IP_A));
    }

    /// An IP with free per-IP capacity is still rejected when the global cap is hit.
    #[test]
    fn per_ip_freed_but_global_exhausted() {
        let guard = ConnectionGuard::new(2, 1);
        let _p1 = guard.try_acquire(IP_A).unwrap();
        let _p2 = guard.try_acquire(IP_B).unwrap();
        let ip_c = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3));
        assert_eq!(
            guard.try_acquire(ip_c).unwrap_err(),
            RejectReason::GlobalLimit
        );
    }

    #[test]
    #[should_panic(expected = "max_connections must be > 0")]
    fn zero_max_connections_panics() {
        let _ = ConnectionGuard::new(0, 1);
    }

    #[test]
    #[should_panic(expected = "max_per_ip must be > 0")]
    fn zero_max_per_ip_panics() {
        let _ = ConnectionGuard::new(1, 0);
    }
}