use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::debug;
/// Complete rate-limiter configuration: per-agent, per-IP and service-wide limits,
/// plus the cadence of the background cleanup task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Limits applied to each agent id.
    pub agent_limits: AgentLimits,
    /// Limits applied to each client IP address.
    pub ip_limits: IpLimits,
    /// Limits applied across all traffic.
    pub global_limits: GlobalLimits,
    /// Tick period, in seconds, of the task spawned by `RateLimiter::start_cleanup_task`.
    pub cleanup_interval_seconds: u64,
}

impl Default for RateLimitConfig {
    /// Built-in configuration.
    ///
    /// NOTE(review): the agent and IP values here differ from `AgentLimits::default()`
    /// and `IpLimits::default()`. Writing `AgentLimits { .., ..Default::default() }`
    /// inside a config therefore does NOT fall back to the numbers below — confirm
    /// which set is intended.
    fn default() -> Self {
        let agent_limits = AgentLimits {
            requests_per_minute: 60,
            requests_per_hour: 1_000,
            requests_per_day: 10_000,
            concurrent_sessions: 5,
            bandwidth_mb_per_hour: 1_000,
        };
        let ip_limits = IpLimits {
            requests_per_minute: 100,
            requests_per_hour: 2_000,
            requests_per_day: 20_000,
            max_agents_per_ip: 10,
        };
        let global_limits = GlobalLimits {
            total_requests_per_minute: 10_000,
            total_requests_per_hour: 100_000,
            total_concurrent_sessions: 1_000,
        };
        RateLimitConfig {
            agent_limits,
            ip_limits,
            global_limits,
            cleanup_interval_seconds: 60,
        }
    }
}
/// Limits enforced independently for every agent id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentLimits {
    /// Maximum admitted requests for one agent within the last minute.
    pub requests_per_minute: u32,
    /// Maximum admitted requests for one agent within the last hour.
    pub requests_per_hour: u32,
    /// Maximum admitted requests for one agent within the last day.
    pub requests_per_day: u32,
    /// Maximum sessions counted for one agent by `RateLimiter::check_session_creation`.
    pub concurrent_sessions: u32,
    /// NOTE(review): not read by `RateLimiter`; bandwidth is not enforced by this module.
    pub bandwidth_mb_per_hour: u32,
}

impl Default for AgentLimits {
    fn default() -> Self {
        AgentLimits {
            bandwidth_mb_per_hour: 1_000,
            concurrent_sessions: 10,
            requests_per_day: 10_000,
            requests_per_hour: 1_000,
            requests_per_minute: 100,
        }
    }
}
/// Limits enforced independently for every client IP address.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpLimits {
    /// Maximum admitted requests from one IP within the last minute.
    pub requests_per_minute: u32,
    /// Maximum admitted requests from one IP within the last hour.
    pub requests_per_hour: u32,
    /// Maximum admitted requests from one IP within the last day.
    pub requests_per_day: u32,
    /// Maximum distinct agent ids tracked for one IP (agents expire after 24 hours).
    pub max_agents_per_ip: u32,
}

impl Default for IpLimits {
    fn default() -> Self {
        IpLimits {
            max_agents_per_ip: 100,
            requests_per_day: 100_000,
            requests_per_hour: 10_000,
            requests_per_minute: 1_000,
        }
    }
}
/// Service-wide limits, summed over every agent and IP.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLimits {
    /// Maximum requests admitted across all clients within the last minute.
    pub total_requests_per_minute: u32,
    /// Maximum requests admitted across all clients within the last hour.
    pub total_requests_per_hour: u32,
    /// Maximum sessions tracked at once across all agents.
    pub total_concurrent_sessions: u32,
}

/// Gives `GlobalLimits` a `Default` like its siblings `AgentLimits` and `IpLimits`, so a
/// partially specified value can use `..Default::default()`.
///
/// The values mirror the global section of `RateLimitConfig::default()`; keep the two
/// in sync if either changes. A derived (all-zero) default is deliberately avoided,
/// because zero limits would reject every request.
impl Default for GlobalLimits {
    fn default() -> Self {
        Self {
            total_requests_per_minute: 10_000,
            total_requests_per_hour: 100_000,
            total_concurrent_sessions: 1_000,
        }
    }
}
/// Sliding-window request log for one subject (an agent, an IP, or all traffic).
///
/// Each vector stores the `Instant` of every recorded request that may still fall
/// inside its window (minute, hour, day). Timestamps are appended in non-decreasing
/// order, because `add_request` takes `&mut self` and reads the monotonic clock itself.
/// `get_counts` relies on that ordering to binary-search the window boundary.
#[derive(Debug, Clone)]
struct RequestTracker {
    minute_requests: Vec<Instant>,
    hour_requests: Vec<Instant>,
    day_requests: Vec<Instant>,
    /// When `add_request` last pruned the vectors.
    last_cleanup: Instant,
}

impl RequestTracker {
    const MINUTE: Duration = Duration::from_secs(60);
    const HOUR: Duration = Duration::from_secs(3_600);
    const DAY: Duration = Duration::from_secs(86_400);
    /// Minimum spacing between the opportunistic prunes done by `add_request`.
    const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

    fn new() -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            day_requests: Vec::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Records one request at the current instant.
    ///
    /// Expired entries are pruned at most once every `CLEANUP_INTERVAL`.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.minute_requests.push(now);
        self.hour_requests.push(now);
        self.day_requests.push(now);
        if now.saturating_duration_since(self.last_cleanup) > Self::CLEANUP_INTERVAL {
            self.cleanup_old_requests(now);
            self.last_cleanup = now;
        }
    }

    /// Drops every timestamp that has left its window as of `now`.
    fn cleanup_old_requests(&mut self, now: Instant) {
        // Elapsed time is compared instead of computing `now - window`. `Instant - Duration`
        // panics when the result would precede the platform clock's origin, which is
        // possible shortly after boot on some platforms.
        Self::prune(&mut self.minute_requests, now, Self::MINUTE);
        Self::prune(&mut self.hour_requests, now, Self::HOUR);
        Self::prune(&mut self.day_requests, now, Self::DAY);
    }

    /// Returns `(last_minute, last_hour, last_day)` request counts as of now.
    ///
    /// Counts are computed against the current time rather than read from the vector
    /// lengths. Pruning only happens inside `add_request`, so a subject whose requests
    /// are being rejected (and therefore never recorded) would otherwise keep its stale
    /// counts forever and never be admitted again.
    fn get_counts(&self) -> (usize, usize, usize) {
        let now = Instant::now();
        (
            Self::count_within(&self.minute_requests, now, Self::MINUTE),
            Self::count_within(&self.hour_requests, now, Self::HOUR),
            Self::count_within(&self.day_requests, now, Self::DAY),
        )
    }

    /// Removes entries older than `window` relative to `now`.
    fn prune(requests: &mut Vec<Instant>, now: Instant, window: Duration) {
        requests.retain(|&at| now.saturating_duration_since(at) < window);
    }

    /// Counts entries younger than `window` relative to `now`.
    ///
    /// Requires `requests` to be sorted ascending (see the type-level docs).
    fn count_within(requests: &[Instant], now: Instant, window: Duration) -> usize {
        let expired =
            requests.partition_point(|&at| now.saturating_duration_since(at) >= window);
        requests.len() - expired
    }
}
/// Per-IP state: the IP's request log plus the agents recently seen from it.
#[derive(Debug, Clone)]
struct IpTracker {
    request_tracker: RequestTracker,
    /// Agent id -> time that agent was last registered from this IP via `add_agent`.
    connected_agents: DashMap<String, DateTime<Utc>>,
    /// Last time a request from this IP was recorded (set by the owning rate limiter).
    last_seen: Instant,
}

impl IpTracker {
    /// Agents not registered again within this many hours stop counting toward the
    /// per-IP agent total.
    const AGENT_TTL_HOURS: i64 = 24;

    fn new() -> Self {
        Self {
            request_tracker: RequestTracker::new(),
            connected_agents: DashMap::new(),
            last_seen: Instant::now(),
        }
    }

    /// Registers (or refreshes) `agent_id` as seen now, then drops expired agents.
    fn add_agent(&self, agent_id: String) {
        self.connected_agents.insert(agent_id, Utc::now());
        self.cleanup_old_agents();
    }

    /// Removes agents last registered more than `AGENT_TTL_HOURS` ago.
    fn cleanup_old_agents(&self) {
        let cutoff = Utc::now() - chrono::Duration::hours(Self::AGENT_TTL_HOURS);
        self.connected_agents.retain(|_, last_seen| *last_seen > cutoff);
    }

    /// Returns the number of agents registered from this IP within the TTL.
    ///
    /// Expired agents are pruned first. Previously only `add_agent` pruned, and it runs
    /// only after a request is admitted. An IP whose agent slots were all held by
    /// departed agents therefore stayed at its limit and kept rejecting new agents.
    fn get_agent_count(&self) -> usize {
        self.cleanup_old_agents();
        self.connected_agents.len()
    }
}
/// Enforces request-rate and concurrent-session limits at three levels: global
/// (all traffic), per client IP, and per agent id.
///
/// All mutable state sits behind `Arc`s, so `clone()` returns a handle to the *same*
/// counters. The background task from `start_cleanup_task` relies on this: it must
/// prune the maps that request handling actually writes to. The old hand-written
/// `Clone` recreated the agent/IP maps empty, so cleanup never touched the real data.
#[derive(Clone)]
pub struct RateLimiter {
    config: RateLimitConfig,
    /// Request log per agent id.
    agent_trackers: Arc<DashMap<String, RequestTracker>>,
    /// Request log and recently seen agents per client IP.
    ip_trackers: Arc<DashMap<IpAddr, IpTracker>>,
    /// Request log across all agents and IPs.
    global_tracker: Arc<RwLock<RequestTracker>>,
    /// Session id -> time the session was added.
    active_sessions: Arc<RwLock<DashMap<String, Instant>>>,
}

impl RateLimiter {
    /// IP trackers with no recorded request for this long are evicted by
    /// `cleanup_expired_data`.
    const IP_TRACKER_TTL: Duration = Duration::from_secs(86_400);
    /// Sessions added longer ago than this are evicted by `cleanup_expired_data`,
    /// whether or not `remove_session` was called for them.
    const SESSION_TTL: Duration = Duration::from_secs(7_200);

    /// Creates a limiter with the given configuration and empty counters.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            agent_trackers: Arc::new(DashMap::new()),
            ip_trackers: Arc::new(DashMap::new()),
            global_tracker: Arc::new(RwLock::new(RequestTracker::new())),
            active_sessions: Arc::new(RwLock::new(DashMap::new())),
        }
    }

    /// Admits or rejects one request from `agent_id` arriving from `ip`.
    ///
    /// Checks the global, then per-IP, then per-agent limits. Only if all pass is the
    /// request recorded against all three; rejected requests are not counted.
    ///
    /// # Errors
    /// Returns the first `RateLimitError` whose limit has already been reached.
    ///
    /// NOTE(review): checking and recording are separate steps. Concurrent callers can
    /// each pass the check and overshoot a limit by the number of in-flight requests.
    pub async fn check_agent_request(
        &self,
        agent_id: &str,
        ip: IpAddr,
    ) -> Result<(), RateLimitError> {
        self.check_global_limits().await?;
        self.check_ip_limits(ip, agent_id).await?;
        self.check_agent_limits(agent_id).await?;
        self.record_request(agent_id, ip).await;
        Ok(())
    }

    /// Checks whether `agent_id` may open another session.
    ///
    /// The agent's sessions are the tracked session ids that start with `agent_id`.
    /// This does not register the session; call `add_session` afterwards.
    ///
    /// # Errors
    /// `AgentSessionLimitExceeded` or `GlobalSessionLimitExceeded`.
    ///
    /// NOTE(review): prefix matching means agent `"a1"` also counts sessions of
    /// `"a10_..."`. It also assumes session ids are built as `"<agent_id>..."`; confirm
    /// with callers. Check and `add_session` are not atomic.
    pub async fn check_session_creation(
        &self,
        agent_id: &str,
        _ip: IpAddr,
    ) -> Result<(), RateLimitError> {
        let sessions = self.active_sessions.read().await;
        let agent_sessions = sessions
            .iter()
            .filter(|entry| entry.key().starts_with(agent_id))
            .count();
        if agent_sessions >= self.config.agent_limits.concurrent_sessions as usize {
            return Err(RateLimitError::AgentSessionLimitExceeded {
                agent_id: agent_id.to_string(),
                current: agent_sessions,
                limit: self.config.agent_limits.concurrent_sessions,
            });
        }
        let global_sessions = sessions.len();
        if global_sessions >= self.config.global_limits.total_concurrent_sessions as usize {
            return Err(RateLimitError::GlobalSessionLimitExceeded {
                current: global_sessions,
                limit: self.config.global_limits.total_concurrent_sessions,
            });
        }
        Ok(())
    }

    /// Starts tracking `session_id`, timestamped now.
    pub async fn add_session(&self, session_id: String) {
        let sessions = self.active_sessions.write().await;
        sessions.insert(session_id, Instant::now());
    }

    /// Stops tracking `session_id`; a no-op if it is unknown.
    pub async fn remove_session(&self, session_id: &str) {
        let sessions = self.active_sessions.write().await;
        sessions.remove(session_id);
    }

    /// Fails if the global minute or hour request limit has been reached.
    async fn check_global_limits(&self) -> Result<(), RateLimitError> {
        let limits = &self.config.global_limits;
        let (minute_count, hour_count, _day_count) = self.global_tracker.read().await.get_counts();
        if minute_count >= limits.total_requests_per_minute as usize {
            return Err(RateLimitError::GlobalMinuteLimitExceeded {
                current: minute_count,
                limit: limits.total_requests_per_minute,
            });
        }
        if hour_count >= limits.total_requests_per_hour as usize {
            return Err(RateLimitError::GlobalHourLimitExceeded {
                current: hour_count,
                limit: limits.total_requests_per_hour,
            });
        }
        Ok(())
    }

    /// Fails if `ip` has reached a request limit, or if `agent_id` would be a *new*
    /// agent on an IP already at its agent cap.
    async fn check_ip_limits(&self, ip: IpAddr, agent_id: &str) -> Result<(), RateLimitError> {
        let limits = &self.config.ip_limits;
        let ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
        let (minute_count, hour_count, day_count) = ip_tracker.request_tracker.get_counts();
        if minute_count >= limits.requests_per_minute as usize {
            return Err(RateLimitError::IpMinuteLimitExceeded {
                ip,
                current: minute_count,
                limit: limits.requests_per_minute,
            });
        }
        if hour_count >= limits.requests_per_hour as usize {
            return Err(RateLimitError::IpHourLimitExceeded {
                ip,
                current: hour_count,
                limit: limits.requests_per_hour,
            });
        }
        if day_count >= limits.requests_per_day as usize {
            return Err(RateLimitError::IpDayLimitExceeded {
                ip,
                current: day_count,
                limit: limits.requests_per_day,
            });
        }
        // Only a new agent can push the IP past its agent cap. Rejecting agents that are
        // already registered would lock every agent on a full IP out of all requests.
        let agent_count = ip_tracker.get_agent_count();
        if agent_count >= limits.max_agents_per_ip as usize
            && !ip_tracker.connected_agents.contains_key(agent_id)
        {
            return Err(RateLimitError::IpAgentLimitExceeded {
                ip,
                current: agent_count,
                limit: limits.max_agents_per_ip,
            });
        }
        Ok(())
    }

    /// Fails if `agent_id` has reached its minute, hour or day request limit.
    async fn check_agent_limits(&self, agent_id: &str) -> Result<(), RateLimitError> {
        let limits = &self.config.agent_limits;
        let tracker = self
            .agent_trackers
            .entry(agent_id.to_string())
            .or_insert_with(RequestTracker::new);
        let (minute_count, hour_count, day_count) = tracker.get_counts();
        if minute_count >= limits.requests_per_minute as usize {
            return Err(RateLimitError::AgentMinuteLimitExceeded {
                agent_id: agent_id.to_string(),
                current: minute_count,
                limit: limits.requests_per_minute,
            });
        }
        if hour_count >= limits.requests_per_hour as usize {
            return Err(RateLimitError::AgentHourLimitExceeded {
                agent_id: agent_id.to_string(),
                current: hour_count,
                limit: limits.requests_per_hour,
            });
        }
        if day_count >= limits.requests_per_day as usize {
            return Err(RateLimitError::AgentDayLimitExceeded {
                agent_id: agent_id.to_string(),
                current: day_count,
                limit: limits.requests_per_day,
            });
        }
        Ok(())
    }

    /// Records an admitted request against the global, IP and agent trackers.
    async fn record_request(&self, agent_id: &str, ip: IpAddr) {
        self.global_tracker.write().await.add_request();
        {
            let mut ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
            ip_tracker.request_tracker.add_request();
            ip_tracker.add_agent(agent_id.to_string());
            ip_tracker.last_seen = Instant::now();
        }
        self.agent_trackers
            .entry(agent_id.to_string())
            .or_insert_with(RequestTracker::new)
            .add_request();
        debug!("Recorded request for agent {} from IP {}", agent_id, ip);
    }

    /// Returns a snapshot of global request counts and tracked-entity totals.
    pub async fn get_rate_limit_stats(&self) -> RateLimitStats {
        // Each lock guard is a temporary, so it is released before the next lock is
        // awaited.
        let (global_minute, global_hour, global_day) =
            self.global_tracker.read().await.get_counts();
        let session_count = self.active_sessions.read().await.len();
        RateLimitStats {
            global_requests_per_minute: global_minute,
            global_requests_per_hour: global_hour,
            global_requests_per_day: global_day,
            active_sessions: session_count,
            unique_ips: self.ip_trackers.len(),
            unique_agents: self.agent_trackers.len(),
        }
    }

    /// Evicts idle IP trackers, agent trackers with no request in the last day, and
    /// sessions older than `SESSION_TTL`.
    pub async fn cleanup_expired_data(&self) {
        let now = Instant::now();
        self.ip_trackers.retain(|_, ip_tracker| {
            now.saturating_duration_since(ip_tracker.last_seen) < Self::IP_TRACKER_TTL
        });
        // Agent trackers were previously never evicted, so the map grew without bound.
        // A tracker with an empty day window carries no state worth keeping.
        self.agent_trackers.retain(|_, tracker| {
            tracker.cleanup_old_requests(now);
            !tracker.day_requests.is_empty()
        });
        let sessions = self.active_sessions.write().await;
        sessions.retain(|_, created_at| {
            now.saturating_duration_since(*created_at) < Self::SESSION_TTL
        });
        debug!("Cleaned up expired rate limiting data");
    }

    /// Spawns a tokio task that calls `cleanup_expired_data` every
    /// `cleanup_interval_seconds` (the first run happens immediately).
    ///
    /// NOTE(review): the task loops forever and holds a clone of this limiter, so the
    /// shared state outlives every other handle. Consider returning the `JoinHandle`
    /// or using a `Weak` reference if shutdown matters.
    pub async fn start_cleanup_task(&self) {
        // `tokio::time::interval` panics on a zero period, and this value comes from
        // (possibly deserialized) config, so clamp it to at least one second.
        let period = Duration::from_secs(self.config.cleanup_interval_seconds.max(1));
        let rate_limiter = self.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(period);
            loop {
                interval.tick().await;
                rate_limiter.cleanup_expired_data().await;
            }
        });
    }
}

/// Point-in-time snapshot returned by `RateLimiter::get_rate_limit_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitStats {
    /// Requests recorded by the global tracker in its minute window.
    pub global_requests_per_minute: usize,
    /// Requests recorded by the global tracker in its hour window.
    pub global_requests_per_hour: usize,
    /// Requests recorded by the global tracker in its day window.
    pub global_requests_per_day: usize,
    /// Sessions currently tracked.
    pub active_sessions: usize,
    /// IP addresses with a live tracker.
    pub unique_ips: usize,
    /// Agent ids with a live tracker.
    pub unique_agents: usize,
}

/// Reason a request or session was refused; `current` is the observed count and
/// `limit` the configured maximum.
#[derive(Debug, thiserror::Error, Clone, Serialize, Deserialize)]
pub enum RateLimitError {
    #[error("Agent {agent_id} exceeded minute limit: {current}/{limit}")]
    AgentMinuteLimitExceeded {
        agent_id: String,
        current: usize,
        limit: u32,
    },
    #[error("Agent {agent_id} exceeded hour limit: {current}/{limit}")]
    AgentHourLimitExceeded {
        agent_id: String,
        current: usize,
        limit: u32,
    },
    #[error("Agent {agent_id} exceeded day limit: {current}/{limit}")]
    AgentDayLimitExceeded {
        agent_id: String,
        current: usize,
        limit: u32,
    },
    #[error("Agent {agent_id} exceeded session limit: {current}/{limit}")]
    AgentSessionLimitExceeded {
        agent_id: String,
        current: usize,
        limit: u32,
    },
    #[error("IP {ip} exceeded minute limit: {current}/{limit}")]
    IpMinuteLimitExceeded {
        ip: std::net::IpAddr,
        current: usize,
        limit: u32,
    },
    #[error("IP {ip} exceeded hour limit: {current}/{limit}")]
    IpHourLimitExceeded {
        ip: std::net::IpAddr,
        current: usize,
        limit: u32,
    },
    #[error("IP {ip} exceeded day limit: {current}/{limit}")]
    IpDayLimitExceeded {
        ip: std::net::IpAddr,
        current: usize,
        limit: u32,
    },
    /// A new agent was refused because the IP already has the maximum number of agents.
    #[error("IP {ip} exceeded agent limit: {current}/{limit}")]
    IpAgentLimitExceeded {
        ip: std::net::IpAddr,
        current: usize,
        limit: u32,
    },
    #[error("Global minute limit exceeded: {current}/{limit}")]
    GlobalMinuteLimitExceeded { current: usize, limit: u32 },
    #[error("Global hour limit exceeded: {current}/{limit}")]
    GlobalHourLimitExceeded { current: usize, limit: u32 },
    #[error("Global session limit exceeded: {current}/{limit}")]
    GlobalSessionLimitExceeded { current: usize, limit: u32 },
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Loopback address shared by every test.
    fn localhost() -> IpAddr {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }

    #[tokio::test]
    async fn test_agent_rate_limiting() {
        let rate_limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = localhost();
        let agent_id = "test_agent";
        for _ in 0..2 {
            assert!(rate_limiter.check_agent_request(agent_id, ip).await.is_ok());
        }
        // Pin the variant so the test cannot pass because some other limit tripped.
        assert!(matches!(
            rate_limiter.check_agent_request(agent_id, ip).await,
            Err(RateLimitError::AgentMinuteLimitExceeded { .. })
        ));
    }

    #[tokio::test]
    async fn test_ip_rate_limiting() {
        let rate_limiter = RateLimiter::new(RateLimitConfig {
            ip_limits: IpLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = localhost();
        assert!(rate_limiter.check_agent_request("agent1", ip).await.is_ok());
        assert!(rate_limiter.check_agent_request("agent2", ip).await.is_ok());
        assert!(matches!(
            rate_limiter.check_agent_request("agent3", ip).await,
            Err(RateLimitError::IpMinuteLimitExceeded { .. })
        ));
    }

    #[tokio::test]
    async fn test_session_limits() {
        let rate_limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                concurrent_sessions: 1,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = localhost();
        let agent_id = "test_agent";
        let session_id = format!("{}_session1", agent_id);

        assert!(rate_limiter
            .check_session_creation(agent_id, ip)
            .await
            .is_ok());
        rate_limiter.add_session(session_id.clone()).await;
        assert!(matches!(
            rate_limiter.check_session_creation(agent_id, ip).await,
            Err(RateLimitError::AgentSessionLimitExceeded { .. })
        ));

        // Releasing the session frees the slot again.
        rate_limiter.remove_session(&session_id).await;
        assert!(rate_limiter
            .check_session_creation(agent_id, ip)
            .await
            .is_ok());
    }
}