Skip to main content

ditto_os/security/
rate_limit.rs

1use chrono::{DateTime, Utc};
2use dashmap::DashMap;
3use serde::{Deserialize, Serialize};
4use std::net::IpAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use tracing::debug;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct RateLimitConfig {
12    pub agent_limits: AgentLimits,
13    pub ip_limits: IpLimits,
14    pub global_limits: GlobalLimits,
15    pub cleanup_interval_seconds: u64,
16}
17
18impl Default for RateLimitConfig {
19    fn default() -> Self {
20        Self {
21            agent_limits: AgentLimits {
22                requests_per_minute: 60,
23                requests_per_hour: 1000,
24                requests_per_day: 10000,
25                concurrent_sessions: 5,
26                bandwidth_mb_per_hour: 1000,
27            },
28            ip_limits: IpLimits {
29                requests_per_minute: 100,
30                requests_per_hour: 2000,
31                requests_per_day: 20000,
32                max_agents_per_ip: 10,
33            },
34            global_limits: GlobalLimits {
35                total_requests_per_minute: 10000,
36                total_requests_per_hour: 100000,
37                total_concurrent_sessions: 1000,
38            },
39            cleanup_interval_seconds: 60,
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct AgentLimits {
46    pub requests_per_minute: u32,
47    pub requests_per_hour: u32,
48    pub requests_per_day: u32,
49    pub concurrent_sessions: u32,
50    pub bandwidth_mb_per_hour: u32,
51}
52
53impl Default for AgentLimits {
54    fn default() -> Self {
55        Self {
56            requests_per_minute: 100,
57            requests_per_hour: 1000,
58            requests_per_day: 10000,
59            concurrent_sessions: 10,
60            bandwidth_mb_per_hour: 1000,
61        }
62    }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct IpLimits {
67    pub requests_per_minute: u32,
68    pub requests_per_hour: u32,
69    pub requests_per_day: u32,
70    pub max_agents_per_ip: u32,
71}
72
73impl Default for IpLimits {
74    fn default() -> Self {
75        Self {
76            requests_per_minute: 1000,
77            requests_per_hour: 10000,
78            requests_per_day: 100000,
79            max_agents_per_ip: 100,
80        }
81    }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GlobalLimits {
86    pub total_requests_per_minute: u32,
87    pub total_requests_per_hour: u32,
88    pub total_concurrent_sessions: u32,
89}
90
/// Sliding-window log of request timestamps at minute, hour and day granularity.
///
/// Every request is appended to all three vectors. Each push takes
/// `Instant::now()` (a monotonic clock) under `&mut self`, so every vector is in
/// non-decreasing time order. The counting path relies on that order to
/// binary-search for the first in-window entry.
#[derive(Debug, Clone)]
struct RequestTracker {
    /// Requests that may still fall inside the last 60 seconds.
    minute_requests: Vec<Instant>,
    /// Requests that may still fall inside the last hour.
    hour_requests: Vec<Instant>,
    /// Requests that may still fall inside the last 24 hours.
    day_requests: Vec<Instant>,
    /// When the vectors were last pruned by `cleanup_old_requests`.
    last_cleanup: Instant,
}

impl RequestTracker {
    const MINUTE: Duration = Duration::from_secs(60);
    const HOUR: Duration = Duration::from_secs(3_600);
    const DAY: Duration = Duration::from_secs(86_400);
    /// Minimum spacing between the opportunistic prunes done on insert.
    const CLEANUP_EVERY: Duration = Duration::from_secs(30);

    /// Creates an empty tracker.
    fn new() -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            day_requests: Vec::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Records one request at the current instant.
    fn add_request(&mut self) {
        self.add_request_at(Instant::now());
    }

    /// Records one request at `now`.
    ///
    /// Expired entries are pruned at most every 30 seconds. Pruning only bounds
    /// memory; counts stay exact either way (see `counts_at`).
    fn add_request_at(&mut self, now: Instant) {
        self.minute_requests.push(now);
        self.hour_requests.push(now);
        self.day_requests.push(now);

        if now.duration_since(self.last_cleanup) > Self::CLEANUP_EVERY {
            self.cleanup_old_requests(now);
            self.last_cleanup = now;
        }
    }

    /// Drops timestamps that have aged out of their window as of `now`.
    ///
    /// Uses `Instant::duration_since`, which saturates at zero, instead of
    /// computing `now - window`. Subtracting a day from an `Instant` can panic
    /// on platforms whose monotonic clock cannot represent times before its
    /// origin (e.g. shortly after boot).
    fn cleanup_old_requests(&mut self, now: Instant) {
        Self::retain_within(&mut self.minute_requests, now, Self::MINUTE);
        Self::retain_within(&mut self.hour_requests, now, Self::HOUR);
        Self::retain_within(&mut self.day_requests, now, Self::DAY);
    }

    /// Returns `(last_minute, last_hour, last_day)` request counts as of now.
    fn get_counts(&self) -> (usize, usize, usize) {
        self.counts_at(Instant::now())
    }

    /// Returns the number of requests inside each window ending at `now`.
    ///
    /// Counts come from the timestamps, not the vector lengths. A caller that
    /// is being rejected never reaches `add_request`, so its vectors are never
    /// pruned; counting by length would keep such a caller over its limit forever.
    fn counts_at(&self, now: Instant) -> (usize, usize, usize) {
        (
            Self::count_within(&self.minute_requests, now, Self::MINUTE),
            Self::count_within(&self.hour_requests, now, Self::HOUR),
            Self::count_within(&self.day_requests, now, Self::DAY),
        )
    }

    /// Keeps only entries younger than `window` relative to `now`.
    fn retain_within(log: &mut Vec<Instant>, now: Instant, window: Duration) {
        log.retain(|&t| now.duration_since(t) < window);
    }

    /// Counts entries younger than `window` relative to `now`.
    ///
    /// Relies on `log` being in non-decreasing time order, so stale entries
    /// form a prefix that can be found by binary search.
    fn count_within(log: &[Instant], now: Instant, window: Duration) -> usize {
        let stale = log.partition_point(|&t| now.duration_since(t) >= window);
        log.len() - stale
    }
}
140
141#[derive(Debug, Clone)]
142struct IpTracker {
143    request_tracker: RequestTracker,
144    connected_agents: DashMap<String, DateTime<Utc>>,
145    last_seen: Instant,
146}
147
148impl IpTracker {
149    fn new() -> Self {
150        Self {
151            request_tracker: RequestTracker::new(),
152            connected_agents: DashMap::new(),
153            last_seen: Instant::now(),
154        }
155    }
156
157    fn add_agent(&self, agent_id: String) {
158        self.connected_agents.insert(agent_id.clone(), Utc::now());
159        self.cleanup_old_agents();
160    }
161
162    fn cleanup_old_agents(&self) {
163        let cutoff = Utc::now() - chrono::Duration::hours(24);
164        self.connected_agents
165            .retain(|_, &mut timestamp| timestamp > cutoff);
166    }
167
168    fn get_agent_count(&self) -> usize {
169        self.connected_agents.len()
170    }
171}
172
173pub struct RateLimiter {
174    config: RateLimitConfig,
175    agent_trackers: DashMap<String, RequestTracker>,
176    ip_trackers: DashMap<IpAddr, IpTracker>,
177    global_tracker: Arc<RwLock<RequestTracker>>,
178    active_sessions: Arc<RwLock<DashMap<String, Instant>>>,
179}
180
181impl RateLimiter {
182    pub fn new(config: RateLimitConfig) -> Self {
183        Self {
184            config,
185            agent_trackers: DashMap::new(),
186            ip_trackers: DashMap::new(),
187            global_tracker: Arc::new(RwLock::new(RequestTracker::new())),
188            active_sessions: Arc::new(RwLock::new(DashMap::new())),
189        }
190    }
191
192    pub async fn check_agent_request(
193        &self,
194        agent_id: &str,
195        ip: IpAddr,
196    ) -> Result<(), RateLimitError> {
197        // Check global limits first
198        self.check_global_limits().await?;
199
200        // Check IP limits
201        self.check_ip_limits(ip).await?;
202
203        // Check agent-specific limits
204        self.check_agent_limits(agent_id).await?;
205
206        // Record the request
207        self.record_request(agent_id, ip).await;
208
209        Ok(())
210    }
211
212    pub async fn check_session_creation(
213        &self,
214        agent_id: &str,
215        _ip: IpAddr,
216    ) -> Result<(), RateLimitError> {
217        // Check concurrent session limits
218        let sessions = self.active_sessions.read().await;
219        let agent_sessions = sessions
220            .iter()
221            .filter(|entry| entry.key().starts_with(agent_id))
222            .count();
223
224        if agent_sessions >= self.config.agent_limits.concurrent_sessions as usize {
225            return Err(RateLimitError::AgentSessionLimitExceeded {
226                agent_id: agent_id.to_string(),
227                current: agent_sessions,
228                limit: self.config.agent_limits.concurrent_sessions,
229            });
230        }
231
232        let global_sessions = sessions.len();
233        if global_sessions >= self.config.global_limits.total_concurrent_sessions as usize {
234            return Err(RateLimitError::GlobalSessionLimitExceeded {
235                current: global_sessions,
236                limit: self.config.global_limits.total_concurrent_sessions,
237            });
238        }
239
240        Ok(())
241    }
242
243    pub async fn add_session(&self, session_id: String) {
244        let sessions = self.active_sessions.write().await;
245        sessions.insert(session_id, Instant::now());
246    }
247
248    pub async fn remove_session(&self, session_id: &str) {
249        let sessions = self.active_sessions.write().await;
250        sessions.remove(session_id);
251    }
252
253    async fn check_global_limits(&self) -> Result<(), RateLimitError> {
254        let tracker = self.global_tracker.read().await;
255        let (minute_count, hour_count, _day_count) = tracker.get_counts();
256
257        if minute_count >= self.config.global_limits.total_requests_per_minute as usize {
258            return Err(RateLimitError::GlobalMinuteLimitExceeded {
259                current: minute_count,
260                limit: self.config.global_limits.total_requests_per_minute,
261            });
262        }
263
264        if hour_count >= self.config.global_limits.total_requests_per_hour as usize {
265            return Err(RateLimitError::GlobalHourLimitExceeded {
266                current: hour_count,
267                limit: self.config.global_limits.total_requests_per_hour,
268            });
269        }
270
271        Ok(())
272    }
273
274    async fn check_ip_limits(&self, ip: IpAddr) -> Result<(), RateLimitError> {
275        let ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
276        let (minute_count, hour_count, day_count) = ip_tracker.request_tracker.get_counts();
277
278        if minute_count >= self.config.ip_limits.requests_per_minute as usize {
279            return Err(RateLimitError::IpMinuteLimitExceeded {
280                ip,
281                current: minute_count,
282                limit: self.config.ip_limits.requests_per_minute,
283            });
284        }
285
286        if hour_count >= self.config.ip_limits.requests_per_hour as usize {
287            return Err(RateLimitError::IpHourLimitExceeded {
288                ip,
289                current: hour_count,
290                limit: self.config.ip_limits.requests_per_hour,
291            });
292        }
293
294        if day_count >= self.config.ip_limits.requests_per_day as usize {
295            return Err(RateLimitError::IpDayLimitExceeded {
296                ip,
297                current: day_count,
298                limit: self.config.ip_limits.requests_per_day,
299            });
300        }
301
302        let agent_count = ip_tracker.get_agent_count();
303        if agent_count >= self.config.ip_limits.max_agents_per_ip as usize {
304            return Err(RateLimitError::IpAgentLimitExceeded {
305                ip,
306                current: agent_count,
307                limit: self.config.ip_limits.max_agents_per_ip,
308            });
309        }
310
311        Ok(())
312    }
313
314    async fn check_agent_limits(&self, agent_id: &str) -> Result<(), RateLimitError> {
315        let tracker = self
316            .agent_trackers
317            .entry(agent_id.to_string())
318            .or_insert_with(RequestTracker::new);
319        let (minute_count, hour_count, day_count) = tracker.get_counts();
320
321        if minute_count >= self.config.agent_limits.requests_per_minute as usize {
322            return Err(RateLimitError::AgentMinuteLimitExceeded {
323                agent_id: agent_id.to_string(),
324                current: minute_count,
325                limit: self.config.agent_limits.requests_per_minute,
326            });
327        }
328
329        if hour_count >= self.config.agent_limits.requests_per_hour as usize {
330            return Err(RateLimitError::AgentHourLimitExceeded {
331                agent_id: agent_id.to_string(),
332                current: hour_count,
333                limit: self.config.agent_limits.requests_per_hour,
334            });
335        }
336
337        if day_count >= self.config.agent_limits.requests_per_day as usize {
338            return Err(RateLimitError::AgentDayLimitExceeded {
339                agent_id: agent_id.to_string(),
340                current: day_count,
341                limit: self.config.agent_limits.requests_per_day,
342            });
343        }
344
345        Ok(())
346    }
347
348    async fn record_request(&self, agent_id: &str, ip: IpAddr) {
349        // Record in global tracker
350        {
351            let mut tracker = self.global_tracker.write().await;
352            tracker.add_request();
353        }
354
355        // Record in IP tracker
356        {
357            let mut ip_tracker = self.ip_trackers.entry(ip).or_insert_with(IpTracker::new);
358            ip_tracker.request_tracker.add_request();
359            ip_tracker.add_agent(agent_id.to_string());
360            ip_tracker.last_seen = Instant::now();
361        }
362
363        // Record in agent tracker
364        {
365            let mut agent_tracker = self
366                .agent_trackers
367                .entry(agent_id.to_string())
368                .or_insert_with(RequestTracker::new);
369            agent_tracker.add_request();
370        }
371
372        debug!("Recorded request for agent {} from IP {}", agent_id, ip);
373    }
374
375    pub async fn get_rate_limit_stats(&self) -> RateLimitStats {
376        let global_tracker = self.global_tracker.read().await;
377        let (global_minute, global_hour, global_day) = global_tracker.get_counts();
378
379        let active_sessions = self.active_sessions.read().await;
380        let session_count = active_sessions.len();
381
382        let ip_count = self.ip_trackers.len();
383        let agent_count = self.agent_trackers.len();
384
385        RateLimitStats {
386            global_requests_per_minute: global_minute,
387            global_requests_per_hour: global_hour,
388            global_requests_per_day: global_day,
389            active_sessions: session_count,
390            unique_ips: ip_count,
391            unique_agents: agent_count,
392        }
393    }
394
395    pub async fn cleanup_expired_data(&self) {
396        let now = Instant::now();
397
398        // Cleanup old IP trackers
399        self.ip_trackers.retain(|_, ip_tracker| {
400            now.duration_since(ip_tracker.last_seen) < Duration::from_secs(86400)
401            // 24 hours
402        });
403
404        // Cleanup old sessions
405        let sessions = self.active_sessions.write().await;
406        sessions.retain(|_, created_at| {
407            now.duration_since(*created_at) < Duration::from_secs(7200) // 2 hours
408        });
409
410        debug!("Cleaned up expired rate limiting data");
411    }
412
413    pub async fn start_cleanup_task(&self) {
414        let config = self.config.clone();
415        let rate_limiter = self.clone(); // Note: This would require implementing Clone
416
417        tokio::spawn(async move {
418            let mut interval =
419                tokio::time::interval(Duration::from_secs(config.cleanup_interval_seconds));
420
421            loop {
422                interval.tick().await;
423                rate_limiter.cleanup_expired_data().await;
424            }
425        });
426    }
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct RateLimitStats {
431    pub global_requests_per_minute: usize,
432    pub global_requests_per_hour: usize,
433    pub global_requests_per_day: usize,
434    pub active_sessions: usize,
435    pub unique_ips: usize,
436    pub unique_agents: usize,
437}
438
439#[derive(Debug, thiserror::Error, Clone, Serialize, Deserialize)]
440pub enum RateLimitError {
441    #[error("Agent {agent_id} exceeded minute limit: {current}/{limit}")]
442    AgentMinuteLimitExceeded {
443        agent_id: String,
444        current: usize,
445        limit: u32,
446    },
447
448    #[error("Agent {agent_id} exceeded hour limit: {current}/{limit}")]
449    AgentHourLimitExceeded {
450        agent_id: String,
451        current: usize,
452        limit: u32,
453    },
454
455    #[error("Agent {agent_id} exceeded day limit: {current}/{limit}")]
456    AgentDayLimitExceeded {
457        agent_id: String,
458        current: usize,
459        limit: u32,
460    },
461
462    #[error("Agent {agent_id} exceeded session limit: {current}/{limit}")]
463    AgentSessionLimitExceeded {
464        agent_id: String,
465        current: usize,
466        limit: u32,
467    },
468
469    #[error("IP {ip} exceeded minute limit: {current}/{limit}")]
470    IpMinuteLimitExceeded {
471        ip: std::net::IpAddr,
472        current: usize,
473        limit: u32,
474    },
475
476    #[error("IP {ip} exceeded hour limit: {current}/{limit}")]
477    IpHourLimitExceeded {
478        ip: std::net::IpAddr,
479        current: usize,
480        limit: u32,
481    },
482
483    #[error("IP {ip} exceeded day limit: {current}/{limit}")]
484    IpDayLimitExceeded {
485        ip: std::net::IpAddr,
486        current: usize,
487        limit: u32,
488    },
489
490    #[error("IP {ip} exceeded agent limit: {current}/{limit}")]
491    IpAgentLimitExceeded {
492        ip: std::net::IpAddr,
493        current: usize,
494        limit: u32,
495    },
496
497    #[error("Global minute limit exceeded: {current}/{limit}")]
498    GlobalMinuteLimitExceeded { current: usize, limit: u32 },
499
500    #[error("Global hour limit exceeded: {current}/{limit}")]
501    GlobalHourLimitExceeded { current: usize, limit: u32 },
502
503    #[error("Global session limit exceeded: {current}/{limit}")]
504    GlobalSessionLimitExceeded { current: usize, limit: u32 },
505}
506
507// Implement Clone for RateLimiter (needed for cleanup task)
508impl Clone for RateLimiter {
509    fn clone(&self) -> Self {
510        Self {
511            config: self.config.clone(),
512            agent_trackers: DashMap::new(),
513            ip_trackers: DashMap::new(),
514            global_tracker: Arc::clone(&self.global_tracker),
515            active_sessions: Arc::clone(&self.active_sessions),
516        }
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// 127.0.0.1, the client address every scenario uses.
    fn loopback() -> IpAddr {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }

    #[tokio::test]
    async fn test_agent_rate_limiting() {
        let limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();

        // A per-minute budget of two is consumed by the first two calls...
        for attempt in 1..=2 {
            let outcome = limiter.check_agent_request("test_agent", ip).await;
            assert!(outcome.is_ok(), "request {} should be admitted", attempt);
        }

        // ...so the third from the same agent is rejected.
        let third = limiter.check_agent_request("test_agent", ip).await;
        assert!(third.is_err());
    }

    #[tokio::test]
    async fn test_ip_rate_limiting() {
        let limiter = RateLimiter::new(RateLimitConfig {
            ip_limits: IpLimits {
                requests_per_minute: 2,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();

        // Distinct agents sharing one IP draw from the same per-IP budget.
        for agent in ["agent1", "agent2"] {
            let outcome = limiter.check_agent_request(agent, ip).await;
            assert!(outcome.is_ok(), "{} should be admitted", agent);
        }

        let third = limiter.check_agent_request("agent3", ip).await;
        assert!(third.is_err());
    }

    #[tokio::test]
    async fn test_session_limits() {
        let limiter = RateLimiter::new(RateLimitConfig {
            agent_limits: AgentLimits {
                concurrent_sessions: 1,
                ..Default::default()
            },
            ..Default::default()
        });
        let ip = loopback();
        let agent_id = "test_agent";

        // The first session fits within a limit of one...
        let first = limiter.check_session_creation(agent_id, ip).await;
        assert!(first.is_ok());
        limiter.add_session(format!("{}_session1", agent_id)).await;

        // ...and, once registered, blocks a second one.
        let second = limiter.check_session_creation(agent_id, ip).await;
        assert!(second.is_err());
    }
}