aerosocket_server/
rate_limit.rs

1//! Rate limiting and DoS protection for WebSocket server
2//!
3//! This module provides rate limiting capabilities to protect against DoS attacks.
4
5use aerosocket_core::Result;
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
/// Tunable limits applied per client IP.
///
/// Request limiting uses a fixed window: each IP may make at most
/// `max_requests` requests per `window`. Separately, each IP may hold at most
/// `max_connections` connections at once.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Upper bound on requests accepted from one IP within a single `window`.
    pub max_requests: usize,
    /// Length of the fixed window over which requests are counted.
    pub window: Duration,
    /// Upper bound on simultaneously held connections for one IP.
    pub max_connections: usize,
    /// Connection timeout duration.
    ///
    /// NOTE(review): this value is not read anywhere in this module; confirm
    /// where (or whether) it is enforced.
    pub connection_timeout: Duration,
}
23
24impl Default for RateLimitConfig {
25    fn default() -> Self {
26        Self {
27            max_requests: 100,
28            window: Duration::from_secs(60),
29            max_connections: 10,
30            connection_timeout: Duration::from_secs(300),
31        }
32    }
33}
34
35/// Rate limiter for tracking requests and connections
36pub struct RateLimiter {
37    config: RateLimitConfig,
38    /// Request tracking per IP
39    request_counters: Mutex<HashMap<IpAddr, RequestCounter>>,
40    /// Connection tracking per IP
41    connection_counters: Mutex<HashMap<IpAddr, usize>>,
42}
43
/// Fixed-window request tally for a single IP.
#[derive(Clone, Debug)]
struct RequestCounter {
    /// Requests accepted since `window_start`.
    count: usize,
    /// Moment the current window opened.
    window_start: Instant,
}
50
51impl RateLimiter {
52    /// Create a new rate limiter
53    pub fn new(config: RateLimitConfig) -> Self {
54        Self {
55            config,
56            request_counters: Mutex::new(HashMap::new()),
57            connection_counters: Mutex::new(HashMap::new()),
58        }
59    }
60
61    /// Check if an IP is allowed to make a request
62    pub async fn check_request_rate(&self, ip: IpAddr) -> Result<bool> {
63        let mut counters = self.request_counters.lock().await;
64        let now = Instant::now();
65
66        let counter = counters.entry(ip).or_insert_with(|| RequestCounter {
67            count: 0,
68            window_start: now,
69        });
70
71        // Reset window if expired
72        if now.duration_since(counter.window_start) >= self.config.window {
73            counter.count = 0;
74            counter.window_start = now;
75        }
76
77        // Check rate limit
78        if counter.count >= self.config.max_requests {
79            return Ok(false);
80        }
81
82        counter.count += 1;
83        Ok(true)
84    }
85
86    /// Check if an IP can establish a new connection
87    pub async fn can_connect(&self, ip: IpAddr) -> Result<bool> {
88        let mut conn_counters = self.connection_counters.lock().await;
89        let current_count = conn_counters.entry(ip).or_insert(0);
90
91        if *current_count >= self.config.max_connections {
92            return Ok(false);
93        }
94
95        *current_count += 1;
96        Ok(true)
97    }
98
99    /// Remove a connection for an IP
100    pub async fn remove_connection(&self, ip: IpAddr) {
101        let mut conn_counters = self.connection_counters.lock().await;
102        if let Some(count) = conn_counters.get_mut(&ip) {
103            if *count > 0 {
104                *count -= 1;
105            }
106            if *count == 0 {
107                conn_counters.remove(&ip);
108            }
109        }
110    }
111
112    /// Cleanup expired entries
113    pub async fn cleanup(&self) {
114        let now = Instant::now();
115
116        // Cleanup expired request counters
117        {
118            let mut counters = self.request_counters.lock().await;
119            counters.retain(|_, counter| {
120                now.duration_since(counter.window_start) < self.config.window * 2
121            });
122        }
123
124        // Cleanup connection counters (they don't expire naturally)
125        // This is mainly for memory management
126        {
127            let mut conn_counters = self.connection_counters.lock().await;
128            conn_counters.retain(|_, &mut count| count > 0);
129        }
130    }
131
132    /// Get statistics about rate limiting
133    pub async fn get_stats(&self) -> RateLimitStats {
134        let request_count = self.request_counters.lock().await.len();
135        let connection_count = self.connection_counters.lock().await.len();
136
137        RateLimitStats {
138            tracked_ips: request_count,
139            active_connections: connection_count,
140        }
141    }
142}
143
/// Point-in-time table sizes reported by [`RateLimiter::get_stats`].
#[derive(Clone, Debug)]
pub struct RateLimitStats {
    /// How many distinct IPs currently have a request counter.
    pub tracked_ips: usize,
    /// How many distinct IPs are present in the connection table.
    /// This counts IPs, not individual connections.
    pub active_connections: usize,
}
152
153/// Middleware for rate limiting WebSocket connections
154pub struct RateLimitMiddleware {
155    limiter: RateLimiter,
156}
157
158impl RateLimitMiddleware {
159    /// Create new rate limit middleware
160    pub fn new(config: RateLimitConfig) -> Self {
161        Self {
162            limiter: RateLimiter::new(config),
163        }
164    }
165
166    /// Check if a connection is allowed
167    pub async fn check_connection(&self, ip: IpAddr) -> Result<bool> {
168        // Check both request rate and connection limits
169        let request_ok = self.limiter.check_request_rate(ip).await?;
170        let connection_ok = self.limiter.can_connect(ip).await?;
171
172        Ok(request_ok && connection_ok)
173    }
174
175    /// Remove a connection from tracking
176    pub async fn connection_closed(&self, ip: IpAddr) {
177        self.limiter.remove_connection(ip).await;
178    }
179
180    /// Get rate limiting statistics
181    pub async fn stats(&self) -> RateLimitStats {
182        self.limiter.get_stats().await
183    }
184
185    /// Cleanup expired entries
186    pub async fn cleanup(&self) {
187        self.limiter.cleanup().await;
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// 127.0.0.1, the address most tests use.
    fn localhost_v4() -> IpAddr {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        // A short window keeps the test fast. The limiter uses
        // `std::time::Instant`, so this has to be a real sleep.
        let window = Duration::from_millis(200);
        let config = RateLimitConfig {
            max_requests: 2,
            window,
            max_connections: 1,
            connection_timeout: Duration::from_secs(60),
        };

        let limiter = RateLimiter::new(config);
        let ip = localhost_v4();

        // The first two requests fit in the window; the third does not.
        assert!(limiter.check_request_rate(ip).await.unwrap());
        assert!(limiter.check_request_rate(ip).await.unwrap());
        assert!(!limiter.check_request_rate(ip).await.unwrap());

        // Sleep past the end of the window so it resets.
        tokio::time::sleep(window + Duration::from_millis(100)).await;

        assert!(limiter.check_request_rate(ip).await.unwrap());
    }

    #[tokio::test]
    async fn test_connection_limiting() {
        let config = RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            max_connections: 2,
            connection_timeout: Duration::from_secs(60),
        };

        let limiter = RateLimiter::new(config);
        let ip = localhost_v4();

        // Two slots are available; the third request is refused.
        assert!(limiter.can_connect(ip).await.unwrap());
        assert!(limiter.can_connect(ip).await.unwrap());
        assert!(!limiter.can_connect(ip).await.unwrap());

        // Releasing one slot makes room again.
        limiter.remove_connection(ip).await;
        assert!(limiter.can_connect(ip).await.unwrap());
    }

    #[tokio::test]
    async fn test_limits_are_per_ip() {
        let config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(60),
            max_connections: 1,
            connection_timeout: Duration::from_secs(60),
        };
        let limiter = RateLimiter::new(config);
        let a = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let b = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

        // Exhausting IP `a` must not affect IP `b`.
        assert!(limiter.check_request_rate(a).await.unwrap());
        assert!(!limiter.check_request_rate(a).await.unwrap());
        assert!(limiter.check_request_rate(b).await.unwrap());

        assert!(limiter.can_connect(a).await.unwrap());
        assert!(!limiter.can_connect(a).await.unwrap());
        assert!(limiter.can_connect(b).await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_retains_live_entries() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let ip = localhost_v4();

        assert!(limiter.check_request_rate(ip).await.unwrap());
        assert!(limiter.can_connect(ip).await.unwrap());

        // The 60 s window is still open and the slot is still held,
        // so cleanup must keep both entries.
        limiter.cleanup().await;
        let stats = limiter.get_stats().await;
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.active_connections, 1);

        limiter.remove_connection(ip).await;
        limiter.cleanup().await;
        let stats = limiter.get_stats().await;
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.active_connections, 0);
    }

    #[tokio::test]
    async fn test_remove_unknown_connection_is_noop() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        limiter.remove_connection(localhost_v4()).await;

        assert_eq!(limiter.get_stats().await.active_connections, 0);
    }

    #[tokio::test]
    async fn test_middleware() {
        let config = RateLimitConfig::default();
        let middleware = RateLimitMiddleware::new(config);
        let ip = IpAddr::V6(Ipv6Addr::LOCALHOST);

        assert!(middleware.check_connection(ip).await.unwrap());

        let stats = middleware.stats().await;
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.active_connections, 1);

        middleware.connection_closed(ip).await;

        let stats = middleware.stats().await;
        assert_eq!(stats.active_connections, 0);
    }
}