Skip to main content

pylon_runtime/
ip_limit.rs

//! Per-IP concurrent connection limiter used by every streaming endpoint
//! (WS notifier, SSE, shard WS). A single misbehaving peer should not be
//! able to exhaust the server's thread budget or per-client mutex pool by
//! opening hundreds of long-lived sockets.
//!
//! The limiter is cheap: one mutex, one HashMap entry per active IP. Each
//! admitted connection holds an RAII guard that decrements the count when it
//! is dropped on disconnect, so callers cannot leak a slot by forgetting to
//! release it, even on panic.
9
10use std::collections::HashMap;
11use std::net::IpAddr;
12use std::sync::{Arc, Mutex};
13
/// Per-IP ceiling on concurrent streaming connections, used whenever an
/// endpoint does not choose its own limit.
///
/// High enough for several browser tabs, chatty mobile apps, or many users
/// sharing one NAT address; low enough that a single peer cannot hold
/// thousands of sockets open. Endpoints that need a different ceiling pass it
/// to [`IpConnCounter::new`] instead.
pub const DEFAULT_MAX_CONNECTIONS_PER_IP: u32 = 64;
19
/// Tracks how many concurrent streaming connections each IP currently holds.
///
/// Share it behind an `Arc`: [`IpConnCounter::acquire`] takes
/// `self: &Arc<Self>` so every guard it hands out can keep the counter alive.
#[derive(Debug)]
pub struct IpConnCounter {
    /// Live connection count per IP, protected by a single mutex.
    counts: Mutex<HashMap<IpAddr, u32>>,
    /// Maximum number of concurrent connections admitted per IP.
    cap: u32,
}
25
26impl Default for IpConnCounter {
27    fn default() -> Self {
28        Self::new(DEFAULT_MAX_CONNECTIONS_PER_IP)
29    }
30}
31
32impl IpConnCounter {
33    pub fn new(cap: u32) -> Self {
34        Self {
35            counts: Mutex::new(HashMap::new()),
36            cap,
37        }
38    }
39
40    /// Increment the counter for `ip` if it hasn't hit the cap. Returns a
41    /// guard that decrements on drop.
42    pub fn acquire(self: &Arc<Self>, ip: IpAddr) -> Option<IpConnGuard> {
43        let mut map = self.counts.lock().unwrap();
44        let slot = map.entry(ip).or_insert(0);
45        if *slot >= self.cap {
46            return None;
47        }
48        *slot += 1;
49        Some(IpConnGuard {
50            counter: Arc::clone(self),
51            ip,
52        })
53    }
54
55    #[cfg(test)]
56    pub(crate) fn get(&self, ip: IpAddr) -> u32 {
57        self.counts.lock().unwrap().get(&ip).copied().unwrap_or(0)
58    }
59}
60
61/// RAII guard: decrements the IP's connection count when dropped. Hold it
62/// for the full lifetime of the connection (thread, task) so the slot is
63/// only released on actual disconnect.
64pub struct IpConnGuard {
65    counter: Arc<IpConnCounter>,
66    ip: IpAddr,
67}
68
69impl Drop for IpConnGuard {
70    fn drop(&mut self) {
71        let mut map = self.counter.counts.lock().unwrap();
72        if let Some(count) = map.get_mut(&self.ip) {
73            *count = count.saturating_sub(1);
74            if *count == 0 {
75                map.remove(&self.ip);
76            }
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn respects_cap() {
        let counter = Arc::new(IpConnCounter::new(3));
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        let g1 = counter.acquire(ip).unwrap();
        let _g2 = counter.acquire(ip).unwrap();
        let _g3 = counter.acquire(ip).unwrap();
        assert!(counter.acquire(ip).is_none(), "at cap, next acquire fails");
        drop(g1);
        assert!(counter.acquire(ip).is_some(), "freed slot is reusable");
    }

    #[test]
    fn frees_on_drop() {
        let counter = Arc::new(IpConnCounter::new(3));
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        {
            let _g = counter.acquire(ip).unwrap();
            assert_eq!(counter.get(ip), 1);
        }
        assert_eq!(counter.get(ip), 0, "empty entries evicted");
    }

    #[test]
    fn isolates_ips() {
        let counter = Arc::new(IpConnCounter::new(2));
        let a: IpAddr = "192.0.2.1".parse().unwrap();
        let b: IpAddr = "192.0.2.2".parse().unwrap();
        let _a1 = counter.acquire(a).unwrap();
        let _a2 = counter.acquire(a).unwrap();
        assert!(counter.acquire(a).is_none());
        assert!(counter.acquire(b).is_some(), "other IP not starved");
    }

    /// Boundary: a cap of zero must never admit a connection.
    #[test]
    fn zero_cap_rejects_everything() {
        let counter = Arc::new(IpConnCounter::new(0));
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        assert!(counter.acquire(ip).is_none(), "cap 0 admits nothing");
        assert_eq!(counter.get(ip), 0);
    }

    /// IPv6 peers are keyed like IPv4 ones, and every guard gives back its slot.
    #[test]
    fn ipv6_slots_all_released() {
        let counter = Arc::new(IpConnCounter::new(4));
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        let guards: Vec<_> = (0..4).map(|_| counter.acquire(ip).unwrap()).collect();
        assert_eq!(counter.get(ip), 4);
        assert!(counter.acquire(ip).is_none(), "at cap, next acquire fails");
        drop(guards);
        assert_eq!(counter.get(ip), 0, "all slots released");
    }

    /// `Default` enforces `DEFAULT_MAX_CONNECTIONS_PER_IP`.
    #[test]
    fn default_uses_default_cap() {
        let counter = Arc::new(IpConnCounter::default());
        let ip: IpAddr = "192.0.2.9".parse().unwrap();
        let guards: Vec<_> = (0..DEFAULT_MAX_CONNECTIONS_PER_IP)
            .map(|_| counter.acquire(ip).unwrap())
            .collect();
        assert_eq!(counter.get(ip), DEFAULT_MAX_CONNECTIONS_PER_IP);
        assert!(counter.acquire(ip).is_none(), "default cap enforced");
        drop(guards);
        assert_eq!(counter.get(ip), 0);
    }
}