use crate::error::{Error, Result};
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Settings describing a single fixed rate-limit window.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of operations admitted within one window before rejections start.
    pub max_operations: u64,
    /// Length of a window; the counter starts over once this much time has passed.
    pub window_duration: Duration,
    /// Master switch: a window whose config is disabled admits every operation
    /// without counting it.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// A disabled limiter sized at 1000 operations per 60-second window.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        RateLimitConfig {
            enabled: false,
            window_duration: one_minute,
            max_operations: 1_000,
        }
    }
}
#[derive(Debug)]
struct RateLimitWindow {
operations: AtomicU64,
window_start: Instant,
config: RateLimitConfig,
}
impl RateLimitWindow {
fn new(config: RateLimitConfig) -> Self {
Self {
operations: AtomicU64::new(0),
window_start: Instant::now(),
config,
}
}
fn check_and_increment(&mut self) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
let now = Instant::now();
if now.duration_since(self.window_start) >= self.config.window_duration {
self.operations.store(0, Ordering::Relaxed);
self.window_start = now;
}
let current_ops = self.operations.load(Ordering::Relaxed);
if current_ops >= self.config.max_operations {
return Err(Error::RateLimitExceeded {
subject: "system".to_string(), limit: self.config.max_operations,
window: format!("{:?}", self.config.window_duration),
});
}
self.operations.fetch_add(1, Ordering::Relaxed);
Ok(())
}
fn current_usage(&self) -> (u64, u64) {
let current = self.operations.load(Ordering::Relaxed);
(current, self.config.max_operations)
}
}
/// Two-level rate limiter: one global window shared by all callers plus one
/// window per subject key, created lazily on first use.
#[derive(Debug)]
pub struct RateLimiter {
    /// Per-subject windows. Role-assignment windows live in the same map under
    /// keys of the form `role_assignment:<subject_id>`.
    subject_windows: DashMap<String, RateLimitWindow>,
    /// Window applied to every permission check regardless of subject.
    global_window: RateLimitWindow,
    /// Template config cloned into each newly created subject window.
    subject_config: RateLimitConfig,
}

impl RateLimiter {
    /// Creates a limiter with the given global limits and per-subject limits.
    pub fn new(global_config: RateLimitConfig, subject_config: RateLimitConfig) -> Self {
        Self {
            subject_windows: DashMap::new(),
            global_window: RateLimitWindow::new(global_config),
            subject_config,
        }
    }

    /// Map key under which a subject's role-assignment window is stored.
    fn role_assignment_key(subject_id: &str) -> String {
        format!("role_assignment:{}", subject_id)
    }

    /// Role-assignment limit derived from the per-subject limit: one tenth of
    /// it, but never rounded down to zero unless the subject limit itself is
    /// zero (otherwise any subject limit below 10 would block every request).
    fn role_assignment_limit(subject_limit: u64) -> u64 {
        match subject_limit {
            0 => 0,
            limit => (limit / 10).max(1),
        }
    }

    /// Records one permission check for `subject_id` against the global window
    /// and, when per-subject limiting is enabled, against that subject's window.
    ///
    /// The global window is checked (and incremented) first, so a request that
    /// is subsequently rejected by the subject window has still been counted
    /// against the global window.
    ///
    /// # Errors
    ///
    /// Returns `Error::RateLimitExceeded` with subject `"global"` when the
    /// global window is full, or with `subject_id` when the subject's window
    /// is full.
    pub fn check_permission_rate_limit(&mut self, subject_id: &str) -> Result<()> {
        self.global_window
            .check_and_increment()
            .map_err(|_| Error::RateLimitExceeded {
                subject: "global".to_string(),
                limit: self.global_window.config.max_operations,
                window: format!("{:?}", self.global_window.config.window_duration),
            })?;
        if self.subject_config.enabled {
            let mut window = self
                .subject_windows
                .entry(subject_id.to_string())
                .or_insert_with(|| RateLimitWindow::new(self.subject_config.clone()));
            window
                .check_and_increment()
                .map_err(|_| Error::RateLimitExceeded {
                    subject: subject_id.to_string(),
                    limit: self.subject_config.max_operations,
                    window: format!("{:?}", self.subject_config.window_duration),
                })?;
        }
        Ok(())
    }

    /// Records one role assignment for `subject_id` against a dedicated window
    /// whose limit is derived from the per-subject limit (see
    /// `role_assignment_limit`). Does nothing when per-subject limiting is
    /// disabled. The global window is not consulted.
    ///
    /// # Errors
    ///
    /// Returns `Error::RateLimitExceeded` naming `subject_id` when the
    /// role-assignment window is full.
    pub fn check_role_assignment_rate_limit(&mut self, subject_id: &str) -> Result<()> {
        if !self.subject_config.enabled {
            return Ok(());
        }
        let role_config = RateLimitConfig {
            max_operations: Self::role_assignment_limit(self.subject_config.max_operations),
            window_duration: self.subject_config.window_duration,
            enabled: true,
        };
        let mut window = self
            .subject_windows
            .entry(Self::role_assignment_key(subject_id))
            .or_insert_with(|| RateLimitWindow::new(role_config.clone()));
        window
            .check_and_increment()
            .map_err(|_| Error::RateLimitExceeded {
                subject: subject_id.to_string(),
                limit: role_config.max_operations,
                window: format!("{:?}", role_config.window_duration),
            })?;
        Ok(())
    }

    /// Snapshots the counters of the global window and of every window
    /// currently held in the subject map (role-assignment windows included,
    /// under their prefixed keys).
    pub fn usage_stats(&self) -> RateLimitStats {
        let (global_usage, global_limit) = self.global_window.current_usage();
        let subject_usage = self
            .subject_windows
            .iter()
            .map(|entry| {
                let (used, limit) = entry.value().current_usage();
                (entry.key().clone(), used, limit)
            })
            .collect();
        RateLimitStats {
            global_usage,
            global_limit,
            subject_usage,
        }
    }

    /// Forgets both the permission window and the role-assignment window of
    /// `subject_id`; the next check for that subject starts from zero.
    pub fn reset_subject(&self, subject_id: &str) {
        self.subject_windows.remove(subject_id);
        self.subject_windows
            .remove(&Self::role_assignment_key(subject_id));
    }

    /// Drops every subject window that started at least two window durations
    /// ago.
    pub fn cleanup_expired(&self) {
        let now = Instant::now();
        self.subject_windows.retain(|_, window| {
            // Saturate rather than panic if doubling the duration would overflow.
            let expiry_age = window.config.window_duration.saturating_mul(2);
            now.duration_since(window.window_start) < expiry_age
        });
    }
}
/// Point-in-time snapshot of rate-limit counters.
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Operations counted in the global window.
    pub global_usage: u64,
    /// Configured maximum of the global window.
    pub global_limit: u64,
    /// One `(key, operations counted, configured maximum)` triple per tracked
    /// subject window.
    pub subject_usage: Vec<(String, u64, u64)>,
}

impl RateLimitStats {
    /// Usage as a percentage of `limit`, or `None` when `limit` is zero.
    ///
    /// Multiplies before dividing so the result is rounded only once; the
    /// divide-then-multiply form turns e.g. 29 of 100 into
    /// 28.999999999999996, which then fails a `>= 29.0` threshold check.
    fn usage_percentage(current: u64, limit: u64) -> Option<f64> {
        if limit == 0 {
            None
        } else {
            Some(current as f64 * 100.0 / limit as f64)
        }
    }

    /// Global usage as a percentage of the global limit.
    ///
    /// Returns `0.0` when the global limit is zero.
    pub fn global_usage_percentage(&self) -> f64 {
        Self::usage_percentage(self.global_usage, self.global_limit).unwrap_or(0.0)
    }

    /// Keys of all subject windows whose usage is at or above
    /// `threshold_percentage` percent of their limit.
    ///
    /// Windows with a zero limit are never reported. Results keep the order of
    /// `subject_usage`.
    pub fn subjects_approaching_limit(&self, threshold_percentage: f64) -> Vec<String> {
        self.subject_usage
            .iter()
            .filter(|(_, current, limit)| {
                matches!(
                    Self::usage_percentage(*current, *limit),
                    Some(percentage) if percentage >= threshold_percentage
                )
            })
            .map(|(subject, _, _)| subject.clone())
            .collect()
    }
}
/// Interface for components that apply rate limiting per subject and operation
/// and can report their counters.
///
/// No implementor is visible in this file; the notes below describe only what
/// the signatures establish.
pub trait RateLimited {
/// Runs a rate-limit check for `operation` performed by `subject_id`.
///
/// NOTE(review): despite the `is_` prefix this returns `Result<()>`, not a
/// `bool` — presumably `Ok(())` means the operation is allowed and `Err`
/// means it was rejected; confirm against implementors. Takes `&mut self`,
/// so an implementation may update its counters during the check.
fn is_rate_limited(&mut self, subject_id: &str, operation: &str) -> Result<()>;
/// Returns a `RateLimitStats` snapshot of the implementor's usage counters.
fn rate_limit_stats(&self) -> RateLimitStats;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Shorthand constructor for a `RateLimitConfig`.
    fn config(max_operations: u64, window_duration: Duration, enabled: bool) -> RateLimitConfig {
        RateLimitConfig {
            max_operations,
            window_duration,
            enabled,
        }
    }

    /// Performs `count` permission checks for `subject`, panicking on the first rejection.
    fn permit(limiter: &mut RateLimiter, subject: &str, count: usize) {
        (0..count).for_each(|_| limiter.check_permission_rate_limit(subject).unwrap());
    }

    /// The per-subject limit rejects the first request past its maximum.
    #[test]
    fn test_rate_limit_basic() {
        let one_second = Duration::from_secs(1);
        let mut limiter =
            RateLimiter::new(config(5, one_second, true), config(3, one_second, true));

        permit(&mut limiter, "user1", 3);

        assert!(limiter.check_permission_rate_limit("user1").is_err());
    }

    /// A full window admits requests again once its duration has elapsed.
    #[test]
    fn test_rate_limit_window_reset() {
        let window = Duration::from_millis(100);
        let mut limiter = RateLimiter::new(config(100, window, true), config(2, window, true));

        permit(&mut limiter, "user1", 2);
        assert!(limiter.check_permission_rate_limit("user1").is_err());

        thread::sleep(StdDuration::from_millis(150));
        permit(&mut limiter, "user1", 1);
    }

    /// Disabled configs never reject, however small their limits.
    #[test]
    fn test_rate_limit_disabled() {
        let one_second = Duration::from_secs(1);
        let mut limiter =
            RateLimiter::new(config(1, one_second, false), config(1, one_second, false));

        permit(&mut limiter, "user1", 100);
    }

    /// Role assignments get a tenth of the per-subject allowance.
    #[test]
    fn test_role_assignment_rate_limit() {
        let per_subject = config(100, Duration::from_secs(1), true);
        let mut limiter = RateLimiter::new(RateLimitConfig::default(), per_subject);

        (0..10).for_each(|_| limiter.check_role_assignment_rate_limit("user1").unwrap());

        assert!(limiter.check_role_assignment_rate_limit("user1").is_err());
    }

    /// Stats report the global counter and list every subject that was seen.
    #[test]
    fn test_usage_stats() {
        let one_second = Duration::from_secs(1);
        let mut limiter =
            RateLimiter::new(config(10, one_second, true), config(5, one_second, true));

        permit(&mut limiter, "user1", 2);
        permit(&mut limiter, "user2", 1);

        let stats = limiter.usage_stats();
        assert_eq!(stats.global_usage, 3);
        assert_eq!(stats.global_limit, 10);
        assert_eq!(stats.global_usage_percentage(), 30.0);
        for expected in ["user1", "user2"] {
            assert!(stats.subject_usage.iter().any(|(id, _, _)| id == expected));
        }
    }

    /// Only subjects at or above the threshold percentage are reported.
    #[test]
    fn test_subjects_approaching_limit() {
        let per_subject = config(10, Duration::from_secs(1), true);
        let mut limiter = RateLimiter::new(RateLimitConfig::default(), per_subject);

        permit(&mut limiter, "user1", 9);
        permit(&mut limiter, "user2", 5);

        let approaching = limiter.usage_stats().subjects_approaching_limit(80.0);
        assert!(approaching.contains(&"user1".to_string()));
        assert!(!approaching.contains(&"user2".to_string()));
    }
}