1use std::collections::HashMap;
7
/// Per-key token-bucket rate limiter.
///
/// Every 32-byte key (e.g. a client's public key) owns an independent bucket holding at
/// most `max_tokens` tokens. Each allowed request spends one token, and tokens are
/// replenished continuously at `refill_rate` tokens per second. Time is passed in by the
/// caller as a millisecond timestamp, so the limiter never reads a clock itself.
pub struct RateLimiter {
    /// One bucket per key, created lazily on the key's first `check`.
    buckets: HashMap<[u8; 32], TokenBucket>,
    /// Bucket capacity, which is also the balance a new key starts with.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
}

/// Token balance for a single key.
struct TokenBucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// Caller-supplied timestamp (ms) up to which `tokens` has been refilled.
    last_refill: i64,
}

impl RateLimiter {
    /// Capacity used by [`RateLimiter::new`].
    const DEFAULT_MAX_TOKENS: f64 = 10.0;
    /// Refill rate used by [`RateLimiter::new`]: 10 tokens per minute, in tokens/second.
    const DEFAULT_REFILL_RATE: f64 = 10.0 / 60.0;

    /// Creates a limiter allowing bursts of 10 requests per key, refilled at 10 tokens
    /// per minute.
    pub fn new() -> Self {
        Self::with_params(Self::DEFAULT_MAX_TOKENS, Self::DEFAULT_REFILL_RATE)
    }

    /// Creates a limiter with a custom capacity and refill rate.
    ///
    /// `max_tokens` is both the burst size and the starting balance of a new key;
    /// `refill_rate` is measured in tokens per second.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if either argument is negative or NaN.
    pub fn with_params(max_tokens: f64, refill_rate: f64) -> Self {
        debug_assert!(
            max_tokens >= 0.0,
            "max_tokens must be non-negative, got {}",
            max_tokens
        );
        debug_assert!(
            refill_rate >= 0.0,
            "refill_rate must be non-negative, got {}",
            refill_rate
        );
        Self {
            buckets: HashMap::new(),
            max_tokens,
            refill_rate,
        }
    }

    /// Records a request for `public_key` at `now_ms` and reports whether it is allowed.
    ///
    /// A key seen for the first time starts with a full bucket. The bucket is first topped
    /// up for the time elapsed since its last refill (capped at `max_tokens`), then one
    /// token is spent if at least one is available.
    ///
    /// Returns `true` if the request is allowed and `false` if it is rate limited; a
    /// rejected request consumes nothing. A timestamp at or before the bucket's last
    /// refill (e.g. a clock stepping backwards) adds no tokens.
    pub fn check(&mut self, public_key: &[u8; 32], now_ms: i64) -> bool {
        let bucket = self.buckets.entry(*public_key).or_insert(TokenBucket {
            tokens: self.max_tokens,
            last_refill: now_ms,
        });

        // `saturating_sub` keeps extreme timestamps from overflowing (a debug-build panic,
        // a silent wrap in release); a non-positive difference adds nothing.
        let elapsed_ms = now_ms.saturating_sub(bucket.last_refill);
        if elapsed_ms > 0 {
            let elapsed_secs = elapsed_ms as f64 / 1000.0;
            bucket.tokens = (bucket.tokens + elapsed_secs * self.refill_rate).min(self.max_tokens);
            bucket.last_refill = now_ms;
        }

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Evicts buckets that have been idle for at least `stale_threshold_ms` *and* would be
    /// completely refilled at `now_ms`.
    ///
    /// Dropping a bucket is equivalent to resetting it to full, because the next `check`
    /// recreates it with `max_tokens`. Evicting a bucket that is still partially drained
    /// would therefore grant tokens early, letting a client bypass the limit whenever
    /// `stale_threshold_ms` is shorter than a full refill (`max_tokens / refill_rate`
    /// seconds). Such buckets are kept until they would be full. With a zero
    /// `refill_rate`, a bucket that has spent any tokens is never evicted.
    pub fn cleanup(&mut self, now_ms: i64, stale_threshold_ms: i64) {
        let max_tokens = self.max_tokens;
        let refill_rate = self.refill_rate;
        self.buckets.retain(|_, bucket| {
            let idle_ms = now_ms.saturating_sub(bucket.last_refill);
            if idle_ms < stale_threshold_ms {
                return true;
            }
            // Balance `check` would compute right now; evicting is only safe once full.
            let refilled = bucket.tokens + idle_ms.max(0) as f64 / 1000.0 * refill_rate;
            refilled < max_tokens
        });
    }
}

impl Default for RateLimiter {
    /// Equivalent to [`RateLimiter::new`].
    fn default() -> Self {
        Self::new()
    }
}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary starting timestamp (ms) shared by every test.
    const T0: i64 = 1_000_000;

    /// A 32-byte key whose bytes are all `fill`.
    fn key_of(fill: u8) -> [u8; 32] {
        [fill; 32]
    }

    /// Issues `times` requests for `key` at `at_ms` and returns how many were allowed.
    fn allowed(limiter: &mut RateLimiter, key: &[u8; 32], at_ms: i64, times: usize) -> usize {
        let mut granted = 0;
        for _ in 0..times {
            if limiter.check(key, at_ms) {
                granted += 1;
            }
        }
        granted
    }

    #[test]
    fn test_rate_limiter_allows_up_to_max() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        assert_eq!(allowed(&mut limiter, &key, T0, 10), 10);
        assert!(!limiter.check(&key, T0));
    }

    #[test]
    fn test_rate_limiter_refills_over_time() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        assert_eq!(allowed(&mut limiter, &key, T0, 10), 10);
        assert!(!limiter.check(&key, T0));

        // The default rate is 10 tokens per minute, so six seconds yields one token.
        let t_plus_6s = T0 + 6_000;
        assert!(limiter.check(&key, t_plus_6s));
        assert!(!limiter.check(&key, t_plus_6s));
    }

    #[test]
    fn test_rate_limiter_full_refill() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        allowed(&mut limiter, &key, T0, 10);

        let t_plus_60s = T0 + 60_000;
        assert_eq!(allowed(&mut limiter, &key, t_plus_60s, 10), 10);
        assert!(!limiter.check(&key, t_plus_60s));
    }

    #[test]
    fn test_rate_limiter_independent_keys() {
        let mut limiter = RateLimiter::new();
        let (key_a, key_b) = (key_of(1), key_of(2));

        allowed(&mut limiter, &key_a, T0, 10);
        assert!(!limiter.check(&key_a, T0));

        // Draining one key must not affect another.
        assert!(limiter.check(&key_b, T0));
    }

    #[test]
    fn test_rate_limiter_cleanup() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        limiter.check(&key, T0);
        assert_eq!(limiter.buckets.len(), 1);

        let stale_threshold = 3_600_000; // one hour in ms
        limiter.cleanup(T0 + stale_threshold + 1, stale_threshold);
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn test_rate_limiter_no_negative_tokens() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        // Twice the capacity: the ten rejected calls must not push the balance below zero.
        allowed(&mut limiter, &key, T0, 20);

        let t_plus_6s = T0 + 6_000;
        assert!(limiter.check(&key, t_plus_6s));
        assert!(!limiter.check(&key, t_plus_6s));
    }

    #[test]
    fn test_rate_limiter_tokens_cap_at_max() {
        let mut limiter = RateLimiter::new();
        let key = key_of(1);

        limiter.check(&key, T0);

        // Far more time than a full refill needs: the balance must stop at capacity.
        let much_later = T0 + 1_000_000;
        assert_eq!(allowed(&mut limiter, &key, much_later, 10), 10);
        assert!(!limiter.check(&key, much_later));
    }

    #[test]
    fn test_rate_limiter_custom_params() {
        let mut limiter = RateLimiter::with_params(3.0, 1.0);
        let key = key_of(1);

        assert_eq!(allowed(&mut limiter, &key, T0, 3), 3);
        assert!(!limiter.check(&key, T0));

        // One token per second.
        assert!(limiter.check(&key, T0 + 1_000));
        assert!(!limiter.check(&key, T0 + 1_000));
    }
}