use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::debug;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct RateLimitConfig {
12 pub agent_limits: AgentLimits,
13 pub ip_limits: IpLimits,
14 pub global_limits: GlobalLimits,
15 pub cleanup_interval_seconds: u64,
16}
17
18impl Default for RateLimitConfig {
19 fn default() -> Self {
20 Self {
21 agent_limits: AgentLimits {
22 requests_per_minute: 60,
23 requests_per_hour: 1000,
24 requests_per_day: 10000,
25 concurrent_sessions: 5,
26 bandwidth_mb_per_hour: 1000,
27 },
28 ip_limits: IpLimits {
29 requests_per_minute: 100,
30 requests_per_hour: 2000,
31 requests_per_day: 20000,
32 max_agents_per_ip: 10,
33 },
34 global_limits: GlobalLimits {
35 total_requests_per_minute: 10000,
36 total_requests_per_hour: 100000,
37 total_concurrent_sessions: 1000,
38 },
39 cleanup_interval_seconds: 60,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct AgentLimits {
46 pub requests_per_minute: u32,
47 pub requests_per_hour: u32,
48 pub requests_per_day: u32,
49 pub concurrent_sessions: u32,
50 pub bandwidth_mb_per_hour: u32,
51}
52
53impl Default for AgentLimits {
54 fn default() -> Self {
55 Self {
56 requests_per_minute: 100,
57 requests_per_hour: 1000,
58 requests_per_day: 10000,
59 concurrent_sessions: 10,
60 bandwidth_mb_per_hour: 1000,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct IpLimits {
67 pub requests_per_minute: u32,
68 pub requests_per_hour: u32,
69 pub requests_per_day: u32,
70 pub max_agents_per_ip: u32,
71}
72
73impl Default for IpLimits {
74 fn default() -> Self {
75 Self {
76 requests_per_minute: 1000,
77 requests_per_hour: 10000,
78 requests_per_day: 100000,
79 max_agents_per_ip: 100,
80 }
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GlobalLimits {
86 pub total_requests_per_minute: u32,
87 pub total_requests_per_hour: u32,
88 pub total_concurrent_sessions: u32,
89}
90
/// Sliding-window request log for one subject (a single agent, a single IP, or all
/// traffic combined).
///
/// Timestamps are appended as they are taken from [`Instant::now`], which never goes
/// backwards, so the log is always sorted oldest-first. Window counts are therefore a
/// binary search away, and a single log serves the minute, hour and day windows.
#[derive(Debug, Clone)]
struct RequestTracker {
    /// Timestamps of requests from (at most) the last day, oldest first.
    requests: VecDeque<Instant>,
}

impl RequestTracker {
    const MINUTE: Duration = Duration::from_secs(60);
    const HOUR: Duration = Duration::from_secs(60 * 60);
    const DAY: Duration = Duration::from_secs(24 * 60 * 60);

    /// Creates an empty log.
    fn new() -> Self {
        Self {
            requests: VecDeque::new(),
        }
    }

    /// Records one request at the current instant.
    ///
    /// Entries a day old or older are dropped first, which keeps the log bounded by the
    /// day window; this is amortised O(1) because expired entries sit at the front.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.cleanup_old_requests(now);
        self.requests.push_back(now);
    }

    /// Drops every request that is at least one day older than `now`.
    ///
    /// Uses the saturating `Instant::duration_since` instead of computing `now - DAY`,
    /// because `Instant - Duration` panics when the result is not representable
    /// (possible shortly after boot on some platforms). Requests recorded after `now`
    /// (by a concurrent caller) count as zero age and are kept.
    fn cleanup_old_requests(&mut self, now: Instant) {
        let expired = self.expired_prefix_len(now, Self::DAY);
        self.requests.drain(..expired);
    }

    /// Returns the number of requests in the last minute, hour and day.
    ///
    /// Counts are measured from the moment of the call, so they stay accurate even when
    /// no request has been recorded (and nothing has been pruned) for a while.
    fn get_counts(&self) -> (usize, usize, usize) {
        let now = Instant::now();
        let within = |window| self.requests.len() - self.expired_prefix_len(now, window);
        (within(Self::MINUTE), within(Self::HOUR), within(Self::DAY))
    }

    /// Length of the oldest-first prefix whose entries are `window` or more older than
    /// `now`; relies on the log being sorted.
    fn expired_prefix_len(&self, now: Instant, window: Duration) -> usize {
        self.requests
            .partition_point(|&at| now.duration_since(at) >= window)
    }
}
140
141#[derive(Debug, Clone)]
142struct IpTracker {
143 request_tracker: RequestTracker,
144 connected_agents: DashMap<String, DateTime<Utc>>,
145 last_seen: Instant,
146}
147
148impl IpTracker {
149 fn new() -> Self {
150 Self {
151 request_tracker: RequestTracker::new(),
152 connected_agents: DashMap::new(),
153 last_seen: Instant::now(),
154 }
155 }
156
157 fn add_agent(&self, agent_id: String) {
158 self.connected_agents.insert(agent_id.clone(), Utc::now());
159 self.cleanup_old_agents();
160 }
161
162 fn cleanup_old_agents(&self) {
163 let cutoff = Utc::now() - chrono::Duration::hours(24);
164 self.connected_agents
165 .retain(|_, &mut timestamp| timestamp > cutoff);
166 }
167
168 fn get_agent_count(&self) -> usize {
169 self.connected_agents.len()
170 }
171}
172
173pub struct RateLimiter {
174 config: RateLimitConfig,
175 agent_trackers: DashMap<String, RequestTracker>,
176 ip_trackers: DashMap<IpAddr, IpTracker>,
177 global_tracker: Arc<RwLock<RequestTracker>>,
178 active_sessions: Arc<RwLock<DashMap<String, Instant>>>,
179}
180
181impl RateLimiter {
182 pub fn new(config: RateLimitConfig) -> Self {
183 Self {
184 config,
185 agent_trackers: DashMap::new(),
186 ip_trackers: DashMap::new(),
187 global_tracker: Arc::new(RwLock::new(RequestTracker::new())),
188 active_sessions: Arc::new(RwLock::new(DashMap::new())),
189 }
190 }
191
192 pub async fn check_agent_request(
193 &self,
194 agent_id: &str,
195 ip: IpAddr,
196 ) -> Result<(), RateLimitError> {
197 self.check_global_limits().await?;
199
200 self.check_ip_limits(ip).await?;
202
203 self.check_agent_limits(agent_id).await?;
205
206 self.record_request(agent_id, ip).await;
208
209 Ok(())
210 }
211
212 pub async fn check_session_creation(
213 &self,
214 agent_id: &str,
215 _ip: IpAddr,
216 ) -> Result<(), RateLimitError> {
217 let sessions = self.active_sessions.read().await;
219 let agent_sessions = sessions
220 .iter()
221 .filter(|entry| entry.key().starts_with(agent_id))
222 .count();
223
224 if agent_sessions >= self.config.agent_limits.concurrent_sessions as usize {
225 return Err(RateLimitError::AgentSessionLimitExceeded {
226 agent_id: agent_id.to_string(),
227 current: agent_sessions,
228 limit: self.config.agent_limits.concurrent_sessions,
229 });
230 }
231
232 let global_sessions = sessions.len();
233 if global_sessions >= self.config.global_limits.total_concurrent_sessions as usize {
234 return Err(RateLimitError::GlobalSessionLimitExceeded {
235 current: global_sessions,
236 limit: self.config.global_limits.total_concurrent_sessions,
237 });
238 }
239
240 Ok(())
241 }
242
243 pub async fn add_session(&self, session_id: String) {
244 let sessions = self.active_sessions.write().await;
245 sessions.insert(session_id, Instant::now());
246 }
247
248 pub async fn remove_session(&self, session_id: &str) {
249 let sessions = self.active_sessions.write().await;
250 sessions.remove(session_id);
251 }
252
253 async fn check_global_limits(&self) -> Result<(), RateLimitError> {
254 let tracker = self.global_tracker.read().await;
255 let (minute_count, hour_count, _day_count) = tracker.get_counts();
256
257 if minute_count >= self.config.global_limits.total_requests_per_minute as usize {
258 return Err(RateLimitError::GlobalMinuteLimitExceeded {
259 current: minute_count,
260 limit: self.config.global_limits.total_requests_per_minute,
261 });
262 }
263
264 if hour_count >= self.config.global_limits.total_requests_per_hour as usize {
265 return Err(RateLimitError::GlobalHourLimitExceeded {
266 current: hour_count,
267 limit: self.config.global_limits.total_requests_per_hour,
268 });
269 }
270
271 Ok(())
272 }
273
274 async fn check_ip_limits(&self, ip: IpAddr) -> Result<(), RateLimitError> {
275 let ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
276 let (minute_count, hour_count, day_count) = ip_tracker.request_tracker.get_counts();
277
278 if minute_count >= self.config.ip_limits.requests_per_minute as usize {
279 return Err(RateLimitError::IpMinuteLimitExceeded {
280 ip,
281 current: minute_count,
282 limit: self.config.ip_limits.requests_per_minute,
283 });
284 }
285
286 if hour_count >= self.config.ip_limits.requests_per_hour as usize {
287 return Err(RateLimitError::IpHourLimitExceeded {
288 ip,
289 current: hour_count,
290 limit: self.config.ip_limits.requests_per_hour,
291 });
292 }
293
294 if day_count >= self.config.ip_limits.requests_per_day as usize {
295 return Err(RateLimitError::IpDayLimitExceeded {
296 ip,
297 current: day_count,
298 limit: self.config.ip_limits.requests_per_day,
299 });
300 }
301
302 let agent_count = ip_tracker.get_agent_count();
303 if agent_count >= self.config.ip_limits.max_agents_per_ip as usize {
304 return Err(RateLimitError::IpAgentLimitExceeded {
305 ip,
306 current: agent_count,
307 limit: self.config.ip_limits.max_agents_per_ip,
308 });
309 }
310
311 Ok(())
312 }
313
314 async fn check_agent_limits(&self, agent_id: &str) -> Result<(), RateLimitError> {
315 let tracker = self
316 .agent_trackers
317 .entry(agent_id.to_string())
318 .or_insert_with(RequestTracker::new);
319 let (minute_count, hour_count, day_count) = tracker.get_counts();
320
321 if minute_count >= self.config.agent_limits.requests_per_minute as usize {
322 return Err(RateLimitError::AgentMinuteLimitExceeded {
323 agent_id: agent_id.to_string(),
324 current: minute_count,
325 limit: self.config.agent_limits.requests_per_minute,
326 });
327 }
328
329 if hour_count >= self.config.agent_limits.requests_per_hour as usize {
330 return Err(RateLimitError::AgentHourLimitExceeded {
331 agent_id: agent_id.to_string(),
332 current: hour_count,
333 limit: self.config.agent_limits.requests_per_hour,
334 });
335 }
336
337 if day_count >= self.config.agent_limits.requests_per_day as usize {
338 return Err(RateLimitError::AgentDayLimitExceeded {
339 agent_id: agent_id.to_string(),
340 current: day_count,
341 limit: self.config.agent_limits.requests_per_day,
342 });
343 }
344
345 Ok(())
346 }
347
348 async fn record_request(&self, agent_id: &str, ip: IpAddr) {
349 {
351 let mut tracker = self.global_tracker.write().await;
352 tracker.add_request();
353 }
354
355 {
357 let mut ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
358 ip_tracker.request_tracker.add_request();
359 ip_tracker.add_agent(agent_id.to_string());
360 ip_tracker.last_seen = Instant::now();
361 }
362
363 {
365 let mut agent_tracker = self
366 .agent_trackers
367 .entry(agent_id.to_string())
368 .or_insert_with(RequestTracker::new);
369 agent_tracker.add_request();
370 }
371
372 debug!("Recorded request for agent {} from IP {}", agent_id, ip);
373 }
374
375 pub async fn get_rate_limit_stats(&self) -> RateLimitStats {
376 let global_tracker = self.global_tracker.read().await;
377 let (global_minute, global_hour, global_day) = global_tracker.get_counts();
378
379 let active_sessions = self.active_sessions.read().await;
380 let session_count = active_sessions.len();
381
382 let ip_count = self.ip_trackers.len();
383 let agent_count = self.agent_trackers.len();
384
385 RateLimitStats {
386 global_requests_per_minute: global_minute,
387 global_requests_per_hour: global_hour,
388 global_requests_per_day: global_day,
389 active_sessions: session_count,
390 unique_ips: ip_count,
391 unique_agents: agent_count,
392 }
393 }
394
395 pub async fn cleanup_expired_data(&self) {
396 let now = Instant::now();
397
398 self.ip_trackers.retain(|_, ip_tracker| {
400 now.duration_since(ip_tracker.last_seen) < Duration::from_secs(86400)
401 });
403
404 let sessions = self.active_sessions.write().await;
406 sessions.retain(|_, created_at| {
407 now.duration_since(*created_at) < Duration::from_secs(7200) });
409
410 debug!("Cleaned up expired rate limiting data");
411 }
412
413 pub async fn start_cleanup_task(&self) {
414 let config = self.config.clone();
415 let rate_limiter = self.clone(); tokio::spawn(async move {
418 let mut interval =
419 tokio::time::interval(Duration::from_secs(config.cleanup_interval_seconds));
420
421 loop {
422 interval.tick().await;
423 rate_limiter.cleanup_expired_data().await;
424 }
425 });
426 }
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct RateLimitStats {
431 pub global_requests_per_minute: usize,
432 pub global_requests_per_hour: usize,
433 pub global_requests_per_day: usize,
434 pub active_sessions: usize,
435 pub unique_ips: usize,
436 pub unique_agents: usize,
437}
438
439#[derive(Debug, thiserror::Error, Clone, Serialize, Deserialize)]
440pub enum RateLimitError {
441 #[error("Agent {agent_id} exceeded minute limit: {current}/{limit}")]
442 AgentMinuteLimitExceeded {
443 agent_id: String,
444 current: usize,
445 limit: u32,
446 },
447
448 #[error("Agent {agent_id} exceeded hour limit: {current}/{limit}")]
449 AgentHourLimitExceeded {
450 agent_id: String,
451 current: usize,
452 limit: u32,
453 },
454
455 #[error("Agent {agent_id} exceeded day limit: {current}/{limit}")]
456 AgentDayLimitExceeded {
457 agent_id: String,
458 current: usize,
459 limit: u32,
460 },
461
462 #[error("Agent {agent_id} exceeded session limit: {current}/{limit}")]
463 AgentSessionLimitExceeded {
464 agent_id: String,
465 current: usize,
466 limit: u32,
467 },
468
469 #[error("IP {ip} exceeded minute limit: {current}/{limit}")]
470 IpMinuteLimitExceeded {
471 ip: std::net::IpAddr,
472 current: usize,
473 limit: u32,
474 },
475
476 #[error("IP {ip} exceeded hour limit: {current}/{limit}")]
477 IpHourLimitExceeded {
478 ip: std::net::IpAddr,
479 current: usize,
480 limit: u32,
481 },
482
483 #[error("IP {ip} exceeded day limit: {current}/{limit}")]
484 IpDayLimitExceeded {
485 ip: std::net::IpAddr,
486 current: usize,
487 limit: u32,
488 },
489
490 #[error("IP {ip} exceeded agent limit: {current}/{limit}")]
491 IpAgentLimitExceeded {
492 ip: std::net::IpAddr,
493 current: usize,
494 limit: u32,
495 },
496
497 #[error("Global minute limit exceeded: {current}/{limit}")]
498 GlobalMinuteLimitExceeded { current: usize, limit: u32 },
499
500 #[error("Global hour limit exceeded: {current}/{limit}")]
501 GlobalHourLimitExceeded { current: usize, limit: u32 },
502
503 #[error("Global session limit exceeded: {current}/{limit}")]
504 GlobalSessionLimitExceeded { current: usize, limit: u32 },
505}
506
507impl Clone for RateLimiter {
509 fn clone(&self) -> Self {
510 Self {
511 config: self.config.clone(),
512 agent_trackers: DashMap::new(),
513 ip_trackers: DashMap::new(),
514 global_tracker: Arc::clone(&self.global_tracker),
515 active_sessions: Arc::clone(&self.active_sessions),
516 }
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// All test traffic originates from the IPv4 loopback address (127.0.0.1).
    fn loopback() -> IpAddr {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }

    #[tokio::test]
    async fn test_agent_rate_limiting() {
        // Per-agent budget of two requests per minute; everything else at defaults.
        let limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();

        for _ in 0..2 {
            assert!(limiter.check_agent_request("test_agent", ip).await.is_ok());
        }
        // The third request within the same minute exceeds the agent budget.
        assert!(limiter.check_agent_request("test_agent", ip).await.is_err());
    }

    #[tokio::test]
    async fn test_ip_rate_limiting() {
        // Per-IP budget of two requests per minute, shared by all agents on that IP.
        let limiter = RateLimiter::new(RateLimitConfig {
            ip_limits: IpLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();

        for agent in ["agent1", "agent2"] {
            assert!(limiter.check_agent_request(agent, ip).await.is_ok());
        }
        // A fresh agent is still refused because the IP budget is spent.
        assert!(limiter.check_agent_request("agent3", ip).await.is_err());
    }

    #[tokio::test]
    async fn test_session_limits() {
        // Each agent may hold a single concurrent session.
        let limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                concurrent_sessions: 1,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();
        let agent_id = "test_agent";

        assert!(limiter.check_session_creation(agent_id, ip).await.is_ok());
        let first_session = format!("{}_session1", agent_id);
        limiter.add_session(first_session).await;

        // With one session registered, the agent is at its limit.
        assert!(limiter.check_session_creation(agent_id, ip).await.is_err());
    }
}