pylon_runtime/
rate_limit.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
/// Per-key sliding-window rate limiter.
///
/// Each key (the `ip` string passed to [`RateLimiter::check`]) keeps a log of
/// the instants at which its requests were admitted. A request is admitted when
/// fewer than `max_requests` of those instants fall within the last `window`.
///
/// All state sits behind a single [`Mutex`], so one limiter can be shared
/// between threads. `check` never removes keys; a key whose timestamps have all
/// expired is only dropped by [`RateLimiter::cleanup`].
#[derive(Debug)]
pub struct RateLimiter {
    /// Length of the sliding window.
    window: Duration,
    /// Maximum number of requests admitted per key within `window`.
    max_requests: u32,
    /// Admission instants per key, in non-decreasing order (oldest first).
    buckets: Mutex<HashMap<String, Vec<Instant>>>,
}

impl RateLimiter {
    /// Creates a limiter that admits at most `max_requests` requests per key
    /// within any `window_secs`-second window.
    ///
    /// A `max_requests` of 0 rejects every request. A `window_secs` of 0 admits
    /// every request, because no recorded instant is ever considered live.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        Self {
            window: Duration::from_secs(window_secs),
            max_requests,
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the bucket map, recovering the data if a previous holder panicked.
    ///
    /// The map holds only plain timestamps, so there is no invariant a panic
    /// could half-break. Recovering prevents one panic from making every later
    /// call panic on a poisoned lock.
    fn lock_buckets(&self) -> std::sync::MutexGuard<'_, HashMap<String, Vec<Instant>>> {
        self.buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a request from `ip` if it is under the limit.
    ///
    /// # Errors
    ///
    /// Returns `Err(retry_after)` when `ip` already has `max_requests` live
    /// requests in the window. `retry_after` is the number of seconds, rounded
    /// up and never less than 1, until the oldest of them leaves the window. The
    /// rejected request is not recorded.
    pub fn check(&self, ip: &str) -> Result<(), u64> {
        let mut buckets = self.lock_buckets();
        // Read the clock only while holding the lock. Each log then grows in
        // non-decreasing order, so `first()` below really is the oldest entry.
        let now = Instant::now();
        let timestamps = buckets.entry(ip.to_string()).or_default();

        timestamps.retain(|t| now.duration_since(*t) < self.window);

        // u32 -> usize is lossless on every 32/64-bit target.
        if timestamps.len() >= self.max_requests as usize {
            let retry_after = match timestamps.first() {
                Some(oldest) => {
                    let elapsed = now.duration_since(*oldest).as_secs();
                    self.window.as_secs().saturating_sub(elapsed)
                }
                // An empty log can only be "full" when `max_requests == 0`.
                // No slot will ever free up, so report the whole window.
                None => self.window.as_secs(),
            };
            return Err(retry_after.max(1));
        }

        timestamps.push(now);
        Ok(())
    }

    /// Drops expired instants from every key and removes keys left with none.
    ///
    /// `check` never removes keys. Without this call, the map keeps an entry
    /// for every distinct `ip` ever checked.
    pub fn cleanup(&self) {
        let mut buckets = self.lock_buckets();
        let now = Instant::now();

        buckets.retain(|_ip, timestamps| {
            timestamps.retain(|t| now.duration_since(*t) < self.window);
            !timestamps.is_empty()
        });
    }

    /// Returns how many of `ip`'s recorded requests are still inside the window.
    ///
    /// This is read-only: unlike `check`, it neither prunes nor creates entries.
    pub fn current_count(&self, ip: &str) -> u32 {
        let buckets = self.lock_buckets();
        let now = Instant::now();

        buckets.get(ip).map_or(0, |timestamps| {
            let live = timestamps
                .iter()
                .filter(|t| now.duration_since(**t) < self.window)
                .count();
            // `check` never stores more than `max_requests` (a u32) instants
            // per key, so this cast cannot truncate.
            live as u32
        })
    }
}
83
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn under_limit_passes() {
        let rl = RateLimiter::new(5, 60);
        for _ in 0..5 {
            assert!(rl.check("10.0.0.1").is_ok());
        }
    }

    #[test]
    fn over_limit_rejected() {
        let rl = RateLimiter::new(3, 60);
        for _ in 0..3 {
            assert!(rl.check("10.0.0.1").is_ok());
        }
        let err = rl.check("10.0.0.1").unwrap_err();
        // The oldest request was made moments ago, so the wait is about the
        // full 60-second window.
        assert!(
            (59..=60).contains(&err),
            "retry_after should be about the full window, got {}",
            err
        );
        // A rejected request must not be recorded.
        assert_eq!(rl.current_count("10.0.0.1"), 3);
    }

    #[test]
    fn window_expiry_allows_new_requests() {
        let rl = RateLimiter::new(2, 1);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_err());

        thread::sleep(Duration::from_millis(1100));

        assert!(rl.check("10.0.0.1").is_ok());
    }

    #[test]
    fn different_ips_are_independent() {
        let rl = RateLimiter::new(2, 60);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.1").is_err());

        assert!(rl.check("10.0.0.2").is_ok());
        assert!(rl.check("10.0.0.2").is_ok());
    }

    #[test]
    fn cleanup_removes_expired_buckets() {
        let rl = RateLimiter::new(10, 1);
        assert!(rl.check("10.0.0.1").is_ok());
        assert!(rl.check("10.0.0.2").is_ok());

        thread::sleep(Duration::from_millis(1100));

        rl.cleanup();

        // `current_count` already ignores expired instants, so it reports 0
        // whether or not `cleanup` did anything. Inspect the map directly to
        // prove the buckets were actually removed.
        assert!(rl.buckets.lock().unwrap().is_empty());
        assert_eq!(rl.current_count("10.0.0.1"), 0);
        assert_eq!(rl.current_count("10.0.0.2"), 0);
    }

    #[test]
    fn cleanup_keeps_live_buckets() {
        let rl = RateLimiter::new(10, 60);
        assert!(rl.check("10.0.0.1").is_ok());

        rl.cleanup();

        assert_eq!(rl.buckets.lock().unwrap().len(), 1);
        assert_eq!(rl.current_count("10.0.0.1"), 1);
    }

    #[test]
    fn current_count_reflects_active_requests() {
        let rl = RateLimiter::new(10, 60);
        assert_eq!(rl.current_count("10.0.0.1"), 0);

        rl.check("10.0.0.1").unwrap();
        assert_eq!(rl.current_count("10.0.0.1"), 1);

        rl.check("10.0.0.1").unwrap();
        rl.check("10.0.0.1").unwrap();
        assert_eq!(rl.current_count("10.0.0.1"), 3);
    }
}
167}