aerosocket_server/
rate_limit.rs1use aerosocket_core::Result;
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::time::{Duration, Instant};
9use tokio::sync::Mutex;
10
/// Limits applied independently to each client IP address by [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests accepted from one IP within a single `window`.
    pub max_requests: usize,
    /// Length of the fixed window over which `max_requests` is counted.
    pub window: Duration,
    /// Maximum number of simultaneously reserved connection slots for one IP.
    pub max_connections: usize,
    /// Connection timeout setting.
    ///
    /// NOTE(review): this value is not read anywhere in this module — confirm
    /// whether another component enforces it.
    pub connection_timeout: Duration,
}
23
24impl Default for RateLimitConfig {
25 fn default() -> Self {
26 Self {
27 max_requests: 100,
28 window: Duration::from_secs(60),
29 max_connections: 10,
30 connection_timeout: Duration::from_secs(300),
31 }
32 }
33}
34
35pub struct RateLimiter {
37 config: RateLimitConfig,
38 request_counters: Mutex<HashMap<IpAddr, RequestCounter>>,
40 connection_counters: Mutex<HashMap<IpAddr, usize>>,
42}
43
/// Requests seen from one IP during its current fixed window.
#[derive(Debug, Clone)]
struct RequestCounter {
    /// Requests accepted since `window_start`.
    count: usize,
    /// Moment the current window began.
    window_start: Instant,
}
50
51impl RateLimiter {
52 pub fn new(config: RateLimitConfig) -> Self {
54 Self {
55 config,
56 request_counters: Mutex::new(HashMap::new()),
57 connection_counters: Mutex::new(HashMap::new()),
58 }
59 }
60
61 pub async fn check_request_rate(&self, ip: IpAddr) -> Result<bool> {
63 let mut counters = self.request_counters.lock().await;
64 let now = Instant::now();
65
66 let counter = counters.entry(ip).or_insert_with(|| RequestCounter {
67 count: 0,
68 window_start: now,
69 });
70
71 if now.duration_since(counter.window_start) >= self.config.window {
73 counter.count = 0;
74 counter.window_start = now;
75 }
76
77 if counter.count >= self.config.max_requests {
79 return Ok(false);
80 }
81
82 counter.count += 1;
83 Ok(true)
84 }
85
86 pub async fn can_connect(&self, ip: IpAddr) -> Result<bool> {
88 let mut conn_counters = self.connection_counters.lock().await;
89 let current_count = conn_counters.entry(ip).or_insert(0);
90
91 if *current_count >= self.config.max_connections {
92 return Ok(false);
93 }
94
95 *current_count += 1;
96 Ok(true)
97 }
98
99 pub async fn remove_connection(&self, ip: IpAddr) {
101 let mut conn_counters = self.connection_counters.lock().await;
102 if let Some(count) = conn_counters.get_mut(&ip) {
103 if *count > 0 {
104 *count -= 1;
105 }
106 if *count == 0 {
107 conn_counters.remove(&ip);
108 }
109 }
110 }
111
112 pub async fn cleanup(&self) {
114 let now = Instant::now();
115
116 {
118 let mut counters = self.request_counters.lock().await;
119 counters.retain(|_, counter| {
120 now.duration_since(counter.window_start) < self.config.window * 2
121 });
122 }
123
124 {
127 let mut conn_counters = self.connection_counters.lock().await;
128 conn_counters.retain(|_, &mut count| count > 0);
129 }
130 }
131
132 pub async fn get_stats(&self) -> RateLimitStats {
134 let request_count = self.request_counters.lock().await.len();
135 let connection_count = self.connection_counters.lock().await.len();
136
137 RateLimitStats {
138 tracked_ips: request_count,
139 active_connections: connection_count,
140 }
141 }
142}
143
/// Point-in-time counters reported by [`RateLimiter::get_stats`].
#[derive(Debug, Clone)]
pub struct RateLimitStats {
    /// Number of IP addresses that currently have a request counter.
    pub tracked_ips: usize,
    /// Connection count as computed by [`RateLimiter::get_stats`].
    pub active_connections: usize,
}
152
153pub struct RateLimitMiddleware {
155 limiter: RateLimiter,
156}
157
158impl RateLimitMiddleware {
159 pub fn new(config: RateLimitConfig) -> Self {
161 Self {
162 limiter: RateLimiter::new(config),
163 }
164 }
165
166 pub async fn check_connection(&self, ip: IpAddr) -> Result<bool> {
168 let request_ok = self.limiter.check_request_rate(ip).await?;
170 let connection_ok = self.limiter.can_connect(ip).await?;
171
172 Ok(request_ok && connection_ok)
173 }
174
175 pub async fn connection_closed(&self, ip: IpAddr) {
177 self.limiter.remove_connection(ip).await;
178 }
179
180 pub async fn stats(&self) -> RateLimitStats {
182 self.limiter.get_stats().await
183 }
184
185 pub async fn cleanup(&self) {
187 self.limiter.cleanup().await;
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[tokio::test]
    async fn test_rate_limiting() {
        // A short window keeps the test fast. The sleep below only needs to
        // outlast it, instead of sleeping for whole seconds.
        let window = Duration::from_millis(200);
        let config = RateLimitConfig {
            max_requests: 2,
            window,
            max_connections: 1,
            connection_timeout: Duration::from_secs(60),
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // The first two requests fit in the window; the third is refused.
        assert!(limiter.check_request_rate(ip).await.unwrap());
        assert!(limiter.check_request_rate(ip).await.unwrap());
        assert!(!limiter.check_request_rate(ip).await.unwrap());

        // Once the window has elapsed, the counter resets.
        tokio::time::sleep(window * 2).await;
        assert!(limiter.check_request_rate(ip).await.unwrap());
    }

    #[tokio::test]
    async fn test_connection_limiting() {
        let config = RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            max_connections: 2,
            connection_timeout: Duration::from_secs(60),
        };

        let limiter = RateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Two slots are available, the third is refused.
        assert!(limiter.can_connect(ip).await.unwrap());
        assert!(limiter.can_connect(ip).await.unwrap());
        assert!(!limiter.can_connect(ip).await.unwrap());

        // Releasing one slot makes room for another connection.
        limiter.remove_connection(ip).await;
        assert!(limiter.can_connect(ip).await.unwrap());
    }

    #[tokio::test]
    async fn test_middleware() {
        let config = RateLimitConfig::default();
        let middleware = RateLimitMiddleware::new(config);
        let ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

        assert!(middleware.check_connection(ip).await.unwrap());

        let stats = middleware.stats().await;
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.active_connections, 1);

        middleware.connection_closed(ip).await;

        let stats = middleware.stats().await;
        assert_eq!(stats.active_connections, 0);
    }
}