use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// Error returned by [`PrincipalConnectionLimiter::try_acquire`] when a
/// principal already holds as many in-flight permits as the cap allows.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrincipalCapExceeded {
    /// Principal whose acquisition was refused.
    pub principal: String,
    /// Configured per-principal cap at the time of refusal.
    pub limit: usize,
    /// Number of permits the principal held when the request was refused.
    pub current: usize,
}

impl std::fmt::Display for PrincipalCapExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `{:?}` quotes and escapes the principal, so an identifier carrying
        // newlines or other control characters cannot forge extra log lines.
        write!(
            f,
            "principal {:?} has {} in-flight requests (limit {})",
            self.principal, self.current, self.limit
        )
    }
}

/// Lets callers propagate the refusal with `?` into `Box<dyn Error>`-style
/// error types.
impl std::error::Error for PrincipalCapExceeded {}

/// Machine-readable code string `"principal_inflight_exhausted"`.
///
/// NOTE(review): not referenced in this file; presumably attached to
/// client-facing responses for [`PrincipalCapExceeded`] elsewhere — confirm.
pub const PRINCIPAL_INFLIGHT_CODE: &str = "principal_inflight_exhausted";
/// State shared by a limiter, all of its clones, and every permit it issues.
#[derive(Debug)]
struct Inner {
    /// Per-principal cap; zero disables enforcement.
    cap: usize,
    /// Outstanding permit count per principal. An entry is removed when its
    /// count drops back to zero.
    in_use: Mutex<HashMap<String, usize>>,
    /// Number of refused acquisitions.
    rejected: AtomicU64,
}

/// Guard for one in-flight slot held by `principal`.
///
/// Dropping the permit releases the slot. When the principal's last permit is
/// dropped, its entry is removed from the shared map.
#[derive(Debug)]
#[must_use = "dropping the permit immediately releases the in-flight slot"]
pub struct PrincipalInflightPermit {
    inner: Arc<Inner>,
    principal: String,
}

impl Drop for PrincipalInflightPermit {
    fn drop(&mut self) {
        // With cap == 0 nothing is ever recorded in `in_use` (acquisition
        // returns before touching the map, and `cap` never changes after
        // construction). Skipping the lock keeps the disabled mode from
        // serialising every permit drop on the shared mutex.
        if self.inner.cap == 0 {
            return;
        }
        let mut map = self.inner.in_use.lock().expect("principal limiter mutex");
        if let Some(count) = map.get_mut(&self.principal) {
            *count -= 1;
            // Evict idle principals so the map only holds live ones.
            if *count == 0 {
                map.remove(&self.principal);
            }
        }
    }
}
/// Caps how many permits each principal may hold at the same time.
///
/// Cloning is cheap, and every clone shares the same counters and rejection
/// total.
#[derive(Debug, Clone)]
pub struct PrincipalConnectionLimiter {
    inner: Arc<Inner>,
}

impl PrincipalConnectionLimiter {
    /// Builds a limiter that allows at most `cap` concurrent permits per
    /// principal.
    ///
    /// A `cap` of zero disables enforcement: every acquisition succeeds and
    /// nothing is tracked.
    pub fn new(cap: usize) -> Self {
        let shared = Inner {
            cap,
            in_use: Mutex::new(HashMap::new()),
            rejected: AtomicU64::new(0),
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Configured per-principal cap (zero means "not enforced").
    pub fn cap(&self) -> usize {
        self.inner.cap
    }

    /// Whether a non-zero cap is being enforced.
    pub fn is_enforced(&self) -> bool {
        self.cap() != 0
    }

    /// Number of permits `principal` currently holds (zero when untracked).
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn current_for(&self, principal: &str) -> usize {
        let counts = self.inner.in_use.lock().expect("principal limiter mutex");
        counts.get(principal).copied().unwrap_or_default()
    }

    /// Number of principals that currently have an entry in the usage map.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn tracked_principals(&self) -> usize {
        self.inner
            .in_use
            .lock()
            .expect("principal limiter mutex")
            .len()
    }

    /// Total acquisitions refused so far, across all clones.
    pub fn rejected_total(&self) -> u64 {
        self.inner.rejected.load(Ordering::Relaxed)
    }

    /// Tries to reserve one in-flight slot for `principal`.
    ///
    /// When enforcement is disabled, an untracked permit is always returned.
    ///
    /// # Errors
    /// Returns [`PrincipalCapExceeded`], and bumps
    /// [`Self::rejected_total`], when `principal` already holds `cap` permits.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn try_acquire(
        &self,
        principal: &str,
    ) -> Result<PrincipalInflightPermit, PrincipalCapExceeded> {
        if !self.is_enforced() {
            // Disabled: hand out a permit without touching the usage map.
            return Ok(self.issue_permit(principal));
        }
        let limit = self.inner.cap;
        let mut counts = self.inner.in_use.lock().expect("principal limiter mutex");
        let held = counts.entry(principal.to_string()).or_insert(0);
        if *held < limit {
            *held += 1;
            return Ok(self.issue_permit(principal));
        }
        let current = *held;
        // Release the map before recording the refusal.
        drop(counts);
        self.inner.rejected.fetch_add(1, Ordering::Relaxed);
        Err(PrincipalCapExceeded {
            principal: principal.to_string(),
            limit,
            current,
        })
    }

    /// Packages a shared handle to the limiter state and an owned copy of
    /// `principal` into a permit. Does not modify any counter.
    fn issue_permit(&self, principal: &str) -> PrincipalInflightPermit {
        PrincipalInflightPermit {
            inner: Arc::clone(&self.inner),
            principal: principal.to_string(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::thread;

    #[test]
    fn disabled_cap_admits_everything_without_tracking() {
        let limiter = PrincipalConnectionLimiter::new(0);
        assert!(!limiter.is_enforced());
        let mut permits = Vec::new();
        for _ in 0..1_000 {
            permits.push(limiter.try_acquire("alice").expect("disabled admits"));
        }
        assert_eq!(limiter.tracked_principals(), 0);
        assert_eq!(limiter.rejected_total(), 0);
    }

    #[test]
    fn admits_up_to_cap_then_refuses_with_structured_detail() {
        let limiter = PrincipalConnectionLimiter::new(3);
        let p1 = limiter.try_acquire("alice").expect("slot 1");
        let p2 = limiter.try_acquire("alice").expect("slot 2");
        let p3 = limiter.try_acquire("alice").expect("slot 3");
        assert_eq!(limiter.current_for("alice"), 3);
        let err = limiter.try_acquire("alice").expect_err("over cap");
        assert_eq!(
            err,
            PrincipalCapExceeded {
                principal: "alice".to_string(),
                limit: 3,
                current: 3,
            }
        );
        assert_eq!(limiter.rejected_total(), 1);
        drop((p1, p2, p3));
    }

    #[test]
    fn dropping_a_permit_frees_a_slot() {
        let limiter = PrincipalConnectionLimiter::new(1);
        let p = limiter.try_acquire("bob").expect("first slot");
        assert!(limiter.try_acquire("bob").is_err());
        drop(p);
        assert_eq!(limiter.current_for("bob"), 0);
        assert_eq!(limiter.tracked_principals(), 0);
        let _p = limiter.try_acquire("bob").expect("reacquire after drop");
        assert_eq!(limiter.current_for("bob"), 1);
    }

    #[test]
    fn principals_are_isolated() {
        let limiter = PrincipalConnectionLimiter::new(1);
        let _alice = limiter.try_acquire("alice").expect("alice slot");
        assert!(limiter.try_acquire("alice").is_err());
        let _bob = limiter.try_acquire("bob").expect("bob unaffected");
        assert_eq!(limiter.tracked_principals(), 2);
    }

    #[test]
    fn entry_evicted_when_last_permit_drops() {
        let limiter = PrincipalConnectionLimiter::new(4);
        let a = limiter.try_acquire("carol").expect("1");
        let b = limiter.try_acquire("carol").expect("2");
        assert_eq!(limiter.tracked_principals(), 1);
        drop(a);
        assert_eq!(limiter.current_for("carol"), 1);
        assert_eq!(limiter.tracked_principals(), 1);
        drop(b);
        assert_eq!(limiter.tracked_principals(), 0);
    }

    #[test]
    fn concurrent_acquire_never_over_issues_per_principal() {
        let cap = 8;
        let limiter = PrincipalConnectionLimiter::new(cap);
        let success = Arc::new(AtomicUsize::new(0));
        let denied = Arc::new(AtomicUsize::new(0));
        let held: Arc<Mutex<Vec<PrincipalInflightPermit>>> = Arc::new(Mutex::new(Vec::new()));
        let mut handles = Vec::new();
        for _ in 0..64 {
            let l = limiter.clone();
            let s = Arc::clone(&success);
            let d = Arc::clone(&denied);
            let h = Arc::clone(&held);
            handles.push(thread::spawn(move || match l.try_acquire("storm") {
                Ok(p) => {
                    s.fetch_add(1, Ordering::Relaxed);
                    h.lock().unwrap().push(p);
                }
                Err(_) => {
                    d.fetch_add(1, Ordering::Relaxed);
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(success.load(Ordering::Relaxed), cap);
        assert_eq!(denied.load(Ordering::Relaxed), 64 - cap);
        assert_eq!(limiter.current_for("storm"), cap);
        assert_eq!(limiter.rejected_total() as usize, 64 - cap);
        held.lock().unwrap().clear();
        assert_eq!(limiter.current_for("storm"), 0);
    }

    #[test]
    fn clone_shares_state() {
        let a = PrincipalConnectionLimiter::new(2);
        let b = a.clone();
        let _p = a.try_acquire("dave").unwrap();
        assert_eq!(b.current_for("dave"), 1);
        assert_eq!(b.cap(), 2);
    }

    /// Refusals must only bump the rejection counter. They must not alter the
    /// principal's count or the set of tracked principals.
    #[test]
    fn rejection_leaves_bookkeeping_untouched() {
        let limiter = PrincipalConnectionLimiter::new(1);
        let _held = limiter.try_acquire("erin").expect("slot");
        for _ in 0..5 {
            assert!(limiter.try_acquire("erin").is_err());
        }
        assert_eq!(limiter.current_for("erin"), 1);
        assert_eq!(limiter.tracked_principals(), 1);
        assert_eq!(limiter.rejected_total(), 5);
    }

    /// Dropping an untracked (disabled-mode) permit leaves no trace.
    #[test]
    fn disabled_permits_drop_without_side_effects() {
        let limiter = PrincipalConnectionLimiter::new(0);
        let p = limiter.try_acquire("frank").expect("disabled admits");
        drop(p);
        assert_eq!(limiter.current_for("frank"), 0);
        assert_eq!(limiter.tracked_principals(), 0);
        assert_eq!(limiter.rejected_total(), 0);
    }

    /// A permit keeps the shared state alive after the handle that issued it
    /// is dropped, and still releases its slot correctly.
    #[test]
    fn permit_outlives_limiter_handle() {
        let limiter = PrincipalConnectionLimiter::new(1);
        let observer = limiter.clone();
        let p = limiter.try_acquire("gina").expect("slot");
        drop(limiter);
        assert_eq!(observer.current_for("gina"), 1);
        drop(p);
        assert_eq!(observer.current_for("gina"), 0);
        assert_eq!(observer.tracked_principals(), 0);
    }
}