clawbox_proxy/
rate_limiter.rs1use std::collections::HashMap;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7#[non_exhaustive]
9pub struct RateLimiter {
10 buckets: Mutex<HashMap<String, TokenBucket>>,
11 default_capacity: u32,
12 default_refill_rate: f64, }
14
/// A single token bucket: holds at most `capacity` tokens and regains
/// `refill_rate` tokens for every second that passes.
struct TokenBucket {
    /// Tokens currently available. Kept as `f64` so fractional refills
    /// accumulate between calls.
    tokens: f64,
    /// Upper bound on `tokens` (the burst size).
    capacity: u32,
    /// Tokens credited per elapsed second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date. It is updated on every
    /// consume attempt, so it also records the bucket's last use.
    last_refill: Instant,
}

impl TokenBucket {
    /// Builds a bucket that starts full.
    fn new(capacity: u32, refill_rate: f64) -> Self {
        let full = f64::from(capacity);
        Self {
            tokens: full,
            capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refills for elapsed time, then spends one token if one is available.
    ///
    /// Returns `true` when a token was spent. A failed attempt still advances
    /// `last_refill`.
    fn try_consume(&mut self) -> bool {
        self.refill(Instant::now());
        // After `refill`, `tokens` is never NaN (`f64::min` returns the
        // non-NaN operand), so this guard is the exact negation of `>= 1.0`.
        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }

    /// Credits the tokens earned since `last_refill` (capped at `capacity`)
    /// and moves `last_refill` to `now`.
    fn refill(&mut self, now: Instant) {
        let secs = now.duration_since(self.last_refill).as_secs_f64();
        let ceiling = f64::from(self.capacity);
        self.tokens = f64::min(self.tokens + secs * self.refill_rate, ceiling);
        self.last_refill = now;
    }
}
46
47impl RateLimiter {
48 pub fn new(capacity: u32, refill_rate: f64) -> Self {
51 Self {
52 buckets: Mutex::new(HashMap::new()),
53 default_capacity: capacity,
54 default_refill_rate: refill_rate,
55 }
56 }
57
58 pub fn check(&self, key: &str) -> bool {
61 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
66 let bucket = buckets
67 .entry(key.to_string())
68 .or_insert_with(|| TokenBucket::new(self.default_capacity, self.default_refill_rate));
69 bucket.try_consume()
70 }
71
72 pub fn reset(&self, key: &str) {
74 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
75 buckets.remove(key);
76 }
77
78 pub fn remove(&self, key: &str) -> bool {
80 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
81 buckets.remove(key).is_some()
82 }
83
84 pub fn cleanup_stale(&self, max_age: Duration) {
89 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
90 let cutoff = Instant::now() - max_age;
91 buckets.retain(|_, bucket| bucket.last_refill > cutoff);
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_burst_capacity() {
        let limiter = RateLimiter::new(3, 1.0);
        assert!(limiter.check("tool-a"));
        assert!(limiter.check("tool-a"));
        assert!(limiter.check("tool-a"));
        assert!(!limiter.check("tool-a"));
    }

    #[test]
    fn test_zero_capacity_always_denies() {
        // Even a fast refill cannot raise tokens above a capacity of zero.
        let limiter = RateLimiter::new(0, 1_000.0);
        assert!(!limiter.check("none"));
        assert!(!limiter.check("none"));
    }

    #[test]
    fn test_refill_after_wait() {
        // 10 tokens/s over 150 ms credits ~1.5 tokens, capped at capacity 1.
        let limiter = RateLimiter::new(1, 10.0);
        assert!(limiter.check("tool-b"));
        assert!(!limiter.check("tool-b"));
        thread::sleep(Duration::from_millis(150));
        assert!(limiter.check("tool-b"));
    }

    #[test]
    fn test_independent_keys() {
        let limiter = RateLimiter::new(1, 1.0);
        assert!(limiter.check("a"));
        assert!(!limiter.check("a"));
        assert!(limiter.check("b"));
    }

    #[test]
    fn test_reset() {
        // With a zero refill rate, only a reset can restore tokens.
        let limiter = RateLimiter::new(1, 0.0);
        assert!(limiter.check("x"));
        assert!(!limiter.check("x"));
        limiter.reset("x");
        assert!(limiter.check("x"));
    }

    #[test]
    fn test_reset_unknown_key_is_noop() {
        let limiter = RateLimiter::new(1, 0.0);
        limiter.reset("ghost");
        assert!(!limiter.remove("ghost"));
    }

    #[test]
    fn test_remove() {
        let limiter = RateLimiter::new(1, 0.0);
        assert!(limiter.check("y"));
        assert!(limiter.remove("y"));
        assert!(!limiter.remove("nonexistent"));
        assert!(limiter.check("y"));
    }

    #[test]
    fn test_cleanup_stale() {
        let limiter = RateLimiter::new(1, 0.0);
        assert!(limiter.check("old"));
        thread::sleep(Duration::from_millis(100));
        assert!(limiter.check("new"));
        limiter.cleanup_stale(Duration::from_millis(50));
        // "old" was evicted, so it comes back with a full bucket...
        assert!(limiter.check("old"));
        // ...while "new" survived and is still drained (rate 0 never refills).
        assert!(!limiter.check("new"));
    }

    #[test]
    fn test_cleanup_zero_max_age_evicts_everything() {
        let limiter = RateLimiter::new(1, 0.0);
        assert!(limiter.check("k"));
        assert!(!limiter.check("k"));
        limiter.cleanup_stale(Duration::ZERO);
        assert!(limiter.check("k"));
    }
}