pylon_runtime/
ip_limit.rs1use std::collections::HashMap;
11use std::net::IpAddr;
12use std::sync::{Arc, Mutex};
13
/// Per-IP connection cap applied by [`IpConnCounter::default`].
pub const DEFAULT_MAX_CONNECTIONS_PER_IP: u32 = 64;

/// Counts live connections per client IP and refuses new ones once an IP
/// reaches the configured cap.
///
/// Slots are handed out as [`IpConnGuard`]s by [`IpConnCounter::acquire`];
/// dropping a guard releases its slot.
#[derive(Debug)]
pub struct IpConnCounter {
    /// Live slot count per IP. Entries are removed when their count reaches
    /// zero, so the map only holds IPs that currently own at least one slot.
    counts: Mutex<HashMap<IpAddr, u32>>,
    /// Maximum number of simultaneous slots a single IP may hold.
    cap: u32,
}

impl Default for IpConnCounter {
    /// Creates a counter capped at [`DEFAULT_MAX_CONNECTIONS_PER_IP`] per IP.
    fn default() -> Self {
        Self::new(DEFAULT_MAX_CONNECTIONS_PER_IP)
    }
}

impl IpConnCounter {
    /// Creates an empty counter allowing at most `cap` concurrent slots per IP.
    ///
    /// A `cap` of zero rejects every [`acquire`](Self::acquire) call.
    pub fn new(cap: u32) -> Self {
        Self {
            counts: Mutex::new(HashMap::new()),
            cap,
        }
    }

    /// Reserves one connection slot for `ip`.
    ///
    /// Returns `None` when `ip` already holds `cap` slots (always, if `cap` is
    /// zero); otherwise returns a guard that keeps the slot reserved until it
    /// is dropped.
    #[must_use = "dropping the returned guard immediately releases the slot"]
    pub fn acquire(self: &Arc<Self>, ip: IpAddr) -> Option<IpConnGuard> {
        // Nothing can ever be admitted under a zero cap. Bail out before touching
        // the map: inserting a 0 entry here would never be cleaned up (no guard
        // exists to remove it), so the table would grow with every rejected IP.
        if self.cap == 0 {
            return None;
        }
        let mut counts = self.lock_counts();
        // cap >= 1 here, so a freshly inserted 0 always passes the check below
        // and gets incremented; no zero-count entry is ever left behind.
        let slot = counts.entry(ip).or_insert(0);
        if *slot >= self.cap {
            return None;
        }
        // Cannot overflow: *slot < cap <= u32::MAX.
        *slot += 1;
        Some(IpConnGuard {
            counter: Arc::clone(self),
            ip,
        })
    }

    /// Number of live slots currently held by `ip` (0 if it has none). Test-only.
    #[cfg(test)]
    pub(crate) fn get(&self, ip: IpAddr) -> u32 {
        self.lock_counts().get(&ip).copied().unwrap_or(0)
    }

    /// Locks the per-IP table, recovering it if a previous holder panicked.
    ///
    /// The critical sections in this module only insert, adjust, or remove a
    /// single entry, so the table behind a poisoned lock is still consistent.
    /// Recovering instead of unwrapping also keeps [`IpConnGuard`]'s `Drop`
    /// from panicking, which would abort the process during unwinding.
    fn lock_counts(&self) -> std::sync::MutexGuard<'_, HashMap<IpAddr, u32>> {
        self.counts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// RAII handle for one reserved connection slot; the slot is returned to its
/// [`IpConnCounter`] when the guard is dropped.
#[must_use = "the slot is released as soon as the guard is dropped"]
pub struct IpConnGuard {
    /// Counter the slot was taken from; kept alive for the guard's lifetime.
    counter: Arc<IpConnCounter>,
    /// IP whose slot count this guard contributes to.
    ip: IpAddr,
}

impl std::fmt::Debug for IpConnGuard {
    // Only the IP is shown: formatting `counter` would dump every tracked IP.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IpConnGuard")
            .field("ip", &self.ip)
            .finish_non_exhaustive()
    }
}

impl Drop for IpConnGuard {
    /// Releases this guard's slot and evicts the IP's entry once it reaches zero.
    fn drop(&mut self) {
        let mut counts = self.counter.lock_counts();
        if let Some(count) = counts.get_mut(&self.ip) {
            // The entry should be >= 1 while a guard exists; saturate anyway so a
            // broken invariant can never wrap the count around.
            *count = count.saturating_sub(1);
            if *count == 0 {
                counts.remove(&self.ip);
            }
        }
    }
}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal test address; every literal used below is valid.
    fn addr(s: &str) -> IpAddr {
        s.parse().expect("test IP literal must parse")
    }

    /// Number of IPs that currently have an entry in the counter's table.
    fn tracked_ips(counter: &IpConnCounter) -> usize {
        counter.counts.lock().unwrap().len()
    }

    #[test]
    fn respects_cap() {
        let counter = Arc::new(IpConnCounter::new(3));
        let ip = addr("192.0.2.1");
        let g1 = counter.acquire(ip).unwrap();
        let _g2 = counter.acquire(ip).unwrap();
        let _g3 = counter.acquire(ip).unwrap();
        assert!(counter.acquire(ip).is_none(), "at cap, next acquire fails");
        drop(g1);
        assert!(counter.acquire(ip).is_some(), "freed slot is reusable");
    }

    #[test]
    fn frees_on_drop() {
        let counter = Arc::new(IpConnCounter::new(3));
        let ip = addr("192.0.2.1");
        {
            let _g = counter.acquire(ip).unwrap();
            assert_eq!(counter.get(ip), 1);
        }
        assert_eq!(counter.get(ip), 0);
        // `get` reports 0 both for a zero-count entry and for a missing one, so
        // inspect the table directly to confirm the entry was actually evicted.
        assert_eq!(tracked_ips(&counter), 0, "empty entries evicted");
    }

    #[test]
    fn isolates_ips() {
        let counter = Arc::new(IpConnCounter::new(2));
        let a = addr("192.0.2.1");
        let b = addr("192.0.2.2");
        let _a1 = counter.acquire(a).unwrap();
        let _a2 = counter.acquire(a).unwrap();
        assert!(counter.acquire(a).is_none());
        assert!(counter.acquire(b).is_some(), "other IP not starved");
    }

    #[test]
    fn default_applies_default_cap() {
        let counter = Arc::new(IpConnCounter::default());
        let ip = addr("192.0.2.1");
        let guards: Vec<_> = (0..DEFAULT_MAX_CONNECTIONS_PER_IP)
            .map(|_| counter.acquire(ip).expect("below the default cap"))
            .collect();
        assert!(counter.acquire(ip).is_none(), "default cap enforced");
        drop(guards);
        assert_eq!(tracked_ips(&counter), 0, "all default slots released");
    }
}