1use std::collections::HashMap;
4use std::net::IpAddr;
5use std::sync::{Arc, Mutex};
6use std::time::Instant;
7
/// Capacity of every bucket, and also the number of tokens a bucket regains
/// per second of elapsed time.
const MAX_TOKENS: u32 = 100;

/// Per-address token-bucket rate limiter.
///
/// Cloning is cheap: all clones share the same bucket table through the
/// inner `Arc`, so a decision made through one clone is visible to the others.
#[derive(Clone, Debug, Default)]
pub struct RateLimiter {
    /// For each address: `(tokens remaining, instant up to which refill has
    /// already been credited)`.
    buckets: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
}

impl RateLimiter {
    /// Creates a limiter that is not tracking any address yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one request from `ip` and reports whether it is allowed.
    ///
    /// An address seen for the first time starts with a full bucket of
    /// `MAX_TOKENS`. Buckets regain `MAX_TOKENS` tokens per second, capped at
    /// `MAX_TOKENS`. Returns `true` and consumes one token when a token is
    /// available; returns `false` otherwise.
    ///
    /// Note: this method never removes entries, so the table grows with the
    /// number of distinct addresses passed in.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn check(&self, ip: IpAddr) -> bool {
        const NANOS_PER_SEC: u64 = 1_000_000_000;

        let mut buckets = self.buckets.lock().expect("rate limiter lock poisoned");
        let now = Instant::now();

        let (tokens, refilled_at) = buckets.entry(ip).or_insert((MAX_TOKENS, now));

        // Whole tokens accrued since the last credited instant. Integer math
        // in u128 cannot overflow here and cannot saturate the way a
        // float-to-u32 cast does for very long idle periods.
        let elapsed_nanos = now.duration_since(*refilled_at).as_nanos();
        let earned = elapsed_nanos * u128::from(MAX_TOKENS) / u128::from(NANOS_PER_SEC);

        if earned > 0 {
            let missing = u128::from(MAX_TOKENS.saturating_sub(*tokens));
            if earned >= missing {
                // Bucket is full again; surplus time is not banked.
                *tokens = MAX_TOKENS;
                *refilled_at = now;
            } else {
                // Lossless: earned < missing <= MAX_TOKENS, which is a u32.
                let earned = earned as u32;
                *tokens += earned;
                // Advance only by the time that paid for whole tokens, so the
                // partial token in progress is kept instead of being reset.
                let credited_nanos = u64::from(earned) * NANOS_PER_SEC / u64::from(MAX_TOKENS);
                *refilled_at += std::time::Duration::from_nanos(credited_nanos);
            }
        }

        if *tokens > 0 {
            *tokens -= 1;
            true
        } else {
            false
        }
    }
}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 `IpAddr` from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Calls `check` for `ip` exactly `MAX_TOKENS` times, ignoring the results.
    fn spend_all(limiter: &RateLimiter, ip: IpAddr) {
        for _ in 0..MAX_TOKENS {
            limiter.check(ip);
        }
    }

    #[test]
    fn allows_requests_under_limit() {
        let limiter = RateLimiter::new();
        let client = v4(127, 0, 0, 1);
        let mut remaining = MAX_TOKENS;
        while remaining > 0 {
            assert!(limiter.check(client));
            remaining -= 1;
        }
    }

    #[test]
    fn rejects_requests_over_limit() {
        let limiter = RateLimiter::new();
        let client = v4(10, 0, 0, 1);
        spend_all(&limiter, client);
        assert!(!limiter.check(client));
    }

    #[test]
    fn separate_buckets_per_ip() {
        let limiter = RateLimiter::new();
        let drained = v4(10, 0, 0, 1);
        let untouched = v4(10, 0, 0, 2);
        spend_all(&limiter, drained);
        assert!(!limiter.check(drained));
        assert!(limiter.check(untouched));
    }

    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let limiter = RateLimiter::new();
        let client = v4(192, 168, 1, 1);
        (0..50).for_each(|i| {
            assert!(limiter.check(client), "request {i} should be allowed");
        });
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::new();
        let client = v4(192, 168, 1, 2);
        spend_all(&limiter, client);
        let admitted = limiter.check(client);
        assert!(
            !admitted,
            "request 101 should be blocked after exhausting tokens"
        );
    }

    #[test]
    fn test_rate_limiter_allows_different_ips() {
        let limiter = RateLimiter::new();
        let first = v4(10, 1, 0, 1);
        let second = v4(10, 1, 0, 2);
        // Interleave the two clients so each draws only from its own bucket.
        (0..MAX_TOKENS).for_each(|i| {
            assert!(limiter.check(first), "ip1 request {i} should be allowed");
            assert!(limiter.check(second), "ip2 request {i} should be allowed");
        });
    }

    #[test]
    fn test_rate_limiter_refills_after_time() {
        let limiter = RateLimiter::new();
        let client = v4(192, 168, 1, 3);
        spend_all(&limiter, client);
        assert!(!limiter.check(client), "should be blocked after exhaustion");

        // Wait long enough for the bucket to regain tokens.
        let pause = std::time::Duration::from_secs(1);
        std::thread::sleep(pause);

        let admitted = limiter.check(client);
        assert!(admitted, "should be allowed again after token refill");
    }

    #[test]
    fn test_rate_limiter_stress() {
        let limiter = RateLimiter::new();
        let client = v4(10, 99, 0, 1);
        // Tally the outcomes of 200 back-to-back requests.
        let (allowed, blocked) = (0..200).fold((0u32, 0u32), |(ok, denied), _| {
            if limiter.check(client) {
                (ok + 1, denied)
            } else {
                (ok, denied + 1)
            }
        });
        assert!(
            allowed >= 100,
            "expected at least 100 allowed, got {allowed}"
        );
        assert!(blocked >= 50, "expected at least 50 blocked, got {blocked}");
    }
}