1use crate::error::{Error, Result};
4use dashmap::DashMap;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::{Duration, Instant};
7
/// Settings describing a single rate-limit window.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of operations admitted before further ones are rejected
    /// within the same window.
    pub max_operations: u64,
    /// Length of one counting window.
    pub window_duration: Duration,
    /// Master switch; a window built from a config with `enabled == false`
    /// admits every operation without counting it.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Returns a disabled configuration of 1000 operations per 60 seconds.
    fn default() -> Self {
        let window_duration = Duration::from_secs(60);
        RateLimitConfig {
            max_operations: 1_000,
            window_duration,
            enabled: false,
        }
    }
}
28
29#[derive(Debug)]
31struct RateLimitWindow {
32 operations: AtomicU64,
34 window_start: Instant,
36 config: RateLimitConfig,
38}
39
40impl RateLimitWindow {
41 fn new(config: RateLimitConfig) -> Self {
42 Self {
43 operations: AtomicU64::new(0),
44 window_start: Instant::now(),
45 config,
46 }
47 }
48
49 fn check_and_increment(&mut self) -> Result<()> {
50 if !self.config.enabled {
51 return Ok(());
52 }
53
54 let now = Instant::now();
55
56 if now.duration_since(self.window_start) >= self.config.window_duration {
58 self.operations.store(0, Ordering::Relaxed);
59 self.window_start = now;
60 }
61
62 let current_ops = self.operations.load(Ordering::Relaxed);
63
64 if current_ops >= self.config.max_operations {
65 return Err(Error::RateLimitExceeded {
66 subject: "system".to_string(), limit: self.config.max_operations,
68 window: format!("{:?}", self.config.window_duration),
69 });
70 }
71
72 self.operations.fetch_add(1, Ordering::Relaxed);
73 Ok(())
74 }
75
76 fn current_usage(&self) -> (u64, u64) {
77 let current = self.operations.load(Ordering::Relaxed);
78 (current, self.config.max_operations)
79 }
80}
81
82#[derive(Debug)]
84pub struct RateLimiter {
85 subject_windows: DashMap<String, RateLimitWindow>,
87 global_window: RateLimitWindow,
89 subject_config: RateLimitConfig,
91}
92
93impl RateLimiter {
94 pub fn new(global_config: RateLimitConfig, subject_config: RateLimitConfig) -> Self {
96 Self {
97 subject_windows: DashMap::new(),
98 global_window: RateLimitWindow::new(global_config),
99 subject_config,
100 }
101 }
102
103 pub fn check_permission_rate_limit(&mut self, subject_id: &str) -> Result<()> {
105 self.global_window
107 .check_and_increment()
108 .map_err(|_| Error::RateLimitExceeded {
109 subject: "global".to_string(),
110 limit: self.global_window.config.max_operations,
111 window: format!("{:?}", self.global_window.config.window_duration),
112 })?;
113
114 if self.subject_config.enabled {
116 let mut window = self
117 .subject_windows
118 .entry(subject_id.to_string())
119 .or_insert_with(|| RateLimitWindow::new(self.subject_config.clone()));
120
121 window
122 .check_and_increment()
123 .map_err(|_| Error::RateLimitExceeded {
124 subject: subject_id.to_string(),
125 limit: self.subject_config.max_operations,
126 window: format!("{:?}", self.subject_config.window_duration),
127 })?;
128 }
129
130 Ok(())
131 }
132
133 pub fn check_role_assignment_rate_limit(&mut self, subject_id: &str) -> Result<()> {
135 let role_config = RateLimitConfig {
137 max_operations: self.subject_config.max_operations / 10,
138 window_duration: self.subject_config.window_duration,
139 enabled: self.subject_config.enabled,
140 };
141
142 if role_config.enabled {
143 let mut window = self
144 .subject_windows
145 .entry(format!("role_assignment:{}", subject_id))
146 .or_insert_with(|| RateLimitWindow::new(role_config.clone()));
147
148 window
149 .check_and_increment()
150 .map_err(|_| Error::RateLimitExceeded {
151 subject: subject_id.to_string(),
152 limit: role_config.max_operations,
153 window: format!("{:?}", role_config.window_duration),
154 })?;
155 }
156
157 Ok(())
158 }
159
160 pub fn usage_stats(&self) -> RateLimitStats {
162 let global_usage = self.global_window.current_usage();
163
164 let mut subject_usage = Vec::new();
165 for entry in self.subject_windows.iter() {
166 let (subject, window) = (entry.key(), entry.value());
167 let usage = window.current_usage();
168 subject_usage.push((subject.clone(), usage.0, usage.1));
169 }
170
171 RateLimitStats {
172 global_usage: global_usage.0,
173 global_limit: global_usage.1,
174 subject_usage,
175 }
176 }
177
178 pub fn reset_subject(&self, subject_id: &str) {
180 self.subject_windows.remove(subject_id);
181 self.subject_windows
182 .remove(&format!("role_assignment:{}", subject_id));
183 }
184
185 pub fn cleanup_expired(&self) {
187 let now = Instant::now();
188 let mut expired_keys = Vec::new();
189
190 for entry in self.subject_windows.iter() {
191 let (key, window) = (entry.key(), entry.value());
192 if now.duration_since(window.window_start) >= window.config.window_duration * 2 {
193 expired_keys.push(key.clone());
194 }
195 }
196
197 for key in expired_keys {
198 self.subject_windows.remove(&key);
199 }
200 }
201}
202
/// Point-in-time usage figures produced by a rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Operations counted in the global window.
    pub global_usage: u64,
    /// Budget of the global window.
    pub global_limit: u64,
    /// One `(key, operations counted, budget)` triple per tracked window.
    pub subject_usage: Vec<(String, u64, u64)>,
}

impl RateLimitStats {
    /// Global usage as a percentage of the global limit.
    ///
    /// Yields `0.0` when the limit is zero, avoiding a division by zero.
    pub fn global_usage_percentage(&self) -> f64 {
        match self.global_limit {
            0 => 0.0,
            limit => (self.global_usage as f64 / limit as f64) * 100.0,
        }
    }

    /// Lists the keys whose usage is at or above `threshold_percentage` of
    /// their limit, in the order they appear in `subject_usage`.
    ///
    /// Entries with a zero limit are never reported.
    pub fn subjects_approaching_limit(&self, threshold_percentage: f64) -> Vec<String> {
        let mut flagged = Vec::new();

        for (subject, used, limit) in &self.subject_usage {
            if *limit == 0 {
                continue;
            }

            let percentage = (*used as f64 / *limit as f64) * 100.0;
            if percentage >= threshold_percentage {
                flagged.push(subject.clone());
            }
        }

        flagged
    }
}
240
/// Interface for components that apply rate limiting and report usage.
///
/// NOTE(review): no implementor is visible in this file; the method
/// descriptions below are inferred from the signatures — confirm against
/// the implementing types.
pub trait RateLimited {
    /// Checks `operation` for `subject_id` against the rate limits.
    ///
    /// Presumably `Ok(())` means the operation may proceed and `Err` means
    /// it was rejected — TODO confirm. Takes `&mut self`, so implementors
    /// may update counters while checking.
    fn is_rate_limited(&mut self, subject_id: &str, operation: &str) -> Result<()>;

    /// Returns the current usage figures as a [`RateLimitStats`].
    fn rate_limit_stats(&self) -> RateLimitStats;
}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Builds an enabled config with the given budget and window length.
    fn enabled(max_operations: u64, window_duration: Duration) -> RateLimitConfig {
        RateLimitConfig {
            max_operations,
            window_duration,
            enabled: true,
        }
    }

    /// Builds a disabled config with the given budget and window length.
    fn disabled(max_operations: u64, window_duration: Duration) -> RateLimitConfig {
        RateLimitConfig {
            max_operations,
            window_duration,
            enabled: false,
        }
    }

    /// The per-subject budget is enforced before the global one is reached.
    #[test]
    fn test_rate_limit_basic() {
        let second = Duration::from_secs(1);
        let mut limiter = RateLimiter::new(enabled(5, second), enabled(3, second));

        (0..3).for_each(|_| limiter.check_permission_rate_limit("user1").unwrap());

        assert!(limiter.check_permission_rate_limit("user1").is_err());
    }

    /// A spent budget becomes available again once the window has elapsed.
    #[test]
    fn test_rate_limit_window_reset() {
        let window = Duration::from_millis(100);
        let mut limiter = RateLimiter::new(enabled(100, window), enabled(2, window));

        limiter.check_permission_rate_limit("user1").unwrap();
        limiter.check_permission_rate_limit("user1").unwrap();
        assert!(limiter.check_permission_rate_limit("user1").is_err());

        // Wait past the 100 ms window so the next check starts a fresh one.
        thread::sleep(StdDuration::from_millis(150));

        limiter.check_permission_rate_limit("user1").unwrap();
    }

    /// Disabled configs admit everything regardless of the budget.
    #[test]
    fn test_rate_limit_disabled() {
        let second = Duration::from_secs(1);
        let mut limiter = RateLimiter::new(disabled(1, second), disabled(1, second));

        let mut remaining = 100;
        while remaining > 0 {
            limiter.check_permission_rate_limit("user1").unwrap();
            remaining -= 1;
        }
    }

    /// Role assignments get one tenth of the per-subject budget.
    #[test]
    fn test_role_assignment_rate_limit() {
        let mut limiter = RateLimiter::new(
            RateLimitConfig::default(),
            enabled(100, Duration::from_secs(1)),
        );

        (0..10).for_each(|_| limiter.check_role_assignment_rate_limit("user1").unwrap());

        assert!(limiter.check_role_assignment_rate_limit("user1").is_err());
    }

    /// Usage statistics reflect the global count and list every subject seen.
    #[test]
    fn test_usage_stats() {
        let second = Duration::from_secs(1);
        let mut limiter = RateLimiter::new(enabled(10, second), enabled(5, second));

        for subject in ["user1", "user1", "user2"].iter() {
            limiter.check_permission_rate_limit(subject).unwrap();
        }

        let stats = limiter.usage_stats();
        assert_eq!(stats.global_usage, 3);
        assert_eq!(stats.global_limit, 10);
        assert_eq!(stats.global_usage_percentage(), 30.0);

        let tracked = |wanted: &str| stats.subject_usage.iter().any(|(id, _, _)| id == wanted);
        assert!(tracked("user1"));
        assert!(tracked("user2"));
    }

    /// Only subjects at or above the threshold are reported.
    #[test]
    fn test_subjects_approaching_limit() {
        let mut limiter = RateLimiter::new(
            RateLimitConfig::default(),
            enabled(10, Duration::from_secs(1)),
        );

        (0..9).for_each(|_| limiter.check_permission_rate_limit("user1").unwrap());
        (0..5).for_each(|_| limiter.check_permission_rate_limit("user2").unwrap());

        let approaching = limiter.usage_stats().subjects_approaching_limit(80.0);

        assert!(approaching.contains(&"user1".to_string()));
        assert!(!approaching.contains(&"user2".to_string()));
    }
}