1use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::{
6 net::IpAddr,
7 sync::Arc,
8 time::{Duration, Instant},
9};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
14pub enum RateLimitError {
15 #[error("Rate limit exceeded: {limit} requests per {window:?}")]
16 LimitExceeded { limit: u32, window: Duration },
17
18 #[error("Connection limit exceeded: {current}/{max} connections")]
19 ConnectionLimitExceeded { current: usize, max: usize },
20
21 #[error("Frame size limit exceeded: {size} bytes > {max} bytes")]
22 FrameSizeExceeded { size: usize, max: usize },
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct RateLimitConfig {
28 pub max_requests_per_window: u32,
30 pub window_duration: Duration,
32 pub max_connections_per_ip: usize,
34 pub max_frame_size: usize,
36 pub max_messages_per_second: u32,
38 pub burst_allowance: u32,
40}
41
42impl Default for RateLimitConfig {
43 fn default() -> Self {
44 Self {
45 max_requests_per_window: 100,
46 window_duration: Duration::from_secs(60),
47 max_connections_per_ip: 10,
48 max_frame_size: 1024 * 1024, max_messages_per_second: 30,
50 burst_allowance: 5,
51 }
52 }
53}
54
55impl RateLimitConfig {
56 pub fn high_traffic() -> Self {
58 Self {
59 max_requests_per_window: 1000,
60 max_connections_per_ip: 50,
61 max_messages_per_second: 100,
62 burst_allowance: 20,
63 ..Default::default()
64 }
65 }
66
67 pub fn low_resource() -> Self {
69 Self {
70 max_requests_per_window: 20,
71 max_connections_per_ip: 2,
72 max_frame_size: 256 * 1024, max_messages_per_second: 5,
74 burst_allowance: 2,
75 ..Default::default()
76 }
77 }
78}
79
/// Per-IP bookkeeping kept by the rate limiter.
#[derive(Debug)]
struct ClientRateLimit {
    /// Timestamps of accepted requests that are still inside the window.
    requests: Vec<Instant>,
    /// Number of connections currently open for this address.
    connection_count: usize,
    /// Message tokens currently available (fractional while refilling).
    tokens: f64,
    /// Moment the token balance was last brought up to date.
    last_refill: Instant,
}
92
93impl ClientRateLimit {
94 fn new(burst_allowance: u32) -> Self {
95 let now = Instant::now();
96 Self {
97 requests: Vec::new(),
98 connection_count: 0,
99 tokens: burst_allowance as f64, last_refill: now,
101 }
102 }
103
104 fn refill_tokens(&mut self, config: &RateLimitConfig) {
106 let now = Instant::now();
107 let time_passed = now.duration_since(self.last_refill).as_secs_f64();
108
109 let tokens_to_add = time_passed * config.max_messages_per_second as f64;
111 let max_tokens = (config.max_messages_per_second + config.burst_allowance) as f64;
112
113 self.tokens = (self.tokens + tokens_to_add).min(max_tokens);
114 self.last_refill = now;
115 }
116
117 fn check_message_rate(&mut self, config: &RateLimitConfig) -> Result<(), RateLimitError> {
119 self.refill_tokens(config);
120
121 if self.tokens >= 1.0 {
122 self.tokens -= 1.0;
123 Ok(())
124 } else {
125 Err(RateLimitError::LimitExceeded {
126 limit: config.max_messages_per_second,
127 window: Duration::from_secs(1),
128 })
129 }
130 }
131}
132
133#[derive(Debug)]
135pub struct WebSocketRateLimiter {
136 config: RateLimitConfig,
137 clients: Arc<DashMap<IpAddr, ClientRateLimit>>,
138}
139
140impl Default for WebSocketRateLimiter {
141 fn default() -> Self {
142 Self::new(RateLimitConfig::default())
143 }
144}
145
146impl WebSocketRateLimiter {
147 pub fn new(config: RateLimitConfig) -> Self {
149 Self {
150 config,
151 clients: Arc::new(DashMap::new()),
152 }
153 }
154
155 pub fn check_request(&self, ip: IpAddr) -> Result<(), RateLimitError> {
157 let now = Instant::now();
158 let burst = self.config.burst_allowance;
159 let mut client = self
160 .clients
161 .entry(ip)
162 .or_insert_with(|| ClientRateLimit::new(burst));
163
164 let window_start = now - self.config.window_duration;
166 client.requests.retain(|&time| time > window_start);
167
168 if client.requests.len() >= self.config.max_requests_per_window as usize {
170 return Err(RateLimitError::LimitExceeded {
171 limit: self.config.max_requests_per_window,
172 window: self.config.window_duration,
173 });
174 }
175
176 client.requests.push(now);
178 Ok(())
179 }
180
181 pub fn check_connection(&self, ip: IpAddr) -> Result<(), RateLimitError> {
183 let burst = self.config.burst_allowance;
184 let mut client = self
185 .clients
186 .entry(ip)
187 .or_insert_with(|| ClientRateLimit::new(burst));
188
189 if client.connection_count >= self.config.max_connections_per_ip {
190 return Err(RateLimitError::ConnectionLimitExceeded {
191 current: client.connection_count,
192 max: self.config.max_connections_per_ip,
193 });
194 }
195
196 client.connection_count += 1;
197 Ok(())
198 }
199
200 pub fn close_connection(&self, ip: IpAddr) {
202 if let Some(mut client) = self.clients.get_mut(&ip) {
203 client.connection_count = client.connection_count.saturating_sub(1);
204 }
205 }
206
207 pub fn check_message(&self, ip: IpAddr, frame_size: usize) -> Result<(), RateLimitError> {
209 if frame_size > self.config.max_frame_size {
211 return Err(RateLimitError::FrameSizeExceeded {
212 size: frame_size,
213 max: self.config.max_frame_size,
214 });
215 }
216
217 if let Some(mut client) = self.clients.get_mut(&ip) {
219 client.check_message_rate(&self.config)?;
220 }
221
222 Ok(())
223 }
224
225 pub fn get_stats(&self) -> RateLimitStats {
227 let mut stats = RateLimitStats::default();
228
229 for entry in self.clients.iter() {
230 stats.total_clients += 1;
231 stats.total_connections += entry.value().connection_count;
232
233 if entry.value().connection_count > 0 {
234 stats.active_clients += 1;
235 }
236 }
237
238 stats
239 }
240
241 pub fn cleanup_expired(&self) {
243 let now = Instant::now();
244 let cutoff = now - self.config.window_duration * 2; self.clients.retain(|_, client| {
247 !(client.connection_count == 0
249 && client.requests.last().is_none_or(|&time| time < cutoff))
250 });
251 }
252}
253
/// Point-in-time counters reported by the rate limiter.
///
/// The type is a plain value: it can be copied and compared for equality.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Number of client records currently tracked.
    pub total_clients: usize,
    /// Number of tracked clients with at least one open connection.
    pub active_clients: usize,
    /// Sum of open connections across all tracked clients.
    pub total_connections: usize,
}
261
262#[derive(Debug, Clone)]
264pub struct RateLimitGuard {
265 rate_limiter: Arc<WebSocketRateLimiter>,
266 client_ip: IpAddr,
267}
268
269impl RateLimitGuard {
270 pub fn new(
272 rate_limiter: Arc<WebSocketRateLimiter>,
273 client_ip: IpAddr,
274 ) -> Result<Self, RateLimitError> {
275 rate_limiter.check_connection(client_ip)?;
276
277 Ok(Self {
278 rate_limiter,
279 client_ip,
280 })
281 }
282
283 pub fn check_message(&self, frame_size: usize) -> Result<(), RateLimitError> {
285 self.rate_limiter.check_message(self.client_ip, frame_size)
286 }
287}
288
289impl Drop for RateLimitGuard {
290 fn drop(&mut self) {
291 self.rate_limiter.close_connection(self.client_ip);
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::thread;
    use std::time::Duration;

    /// Builds an IPv4 `IpAddr` from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Ensures a record exists for `ip` and sets its token balance.
    ///
    /// The map guard is released before returning so later limiter calls on
    /// the same address do not block.
    fn seed_tokens(limiter: &WebSocketRateLimiter, ip: IpAddr, tokens: f64) {
        let burst = limiter.config.burst_allowance;
        let mut client = limiter
            .clients
            .entry(ip)
            .or_insert_with(|| ClientRateLimit::new(burst));
        client.tokens = tokens;
    }

    #[test]
    fn test_rate_limit_requests() {
        let config = RateLimitConfig {
            max_requests_per_window: 2,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_request(ip).is_ok());

        // Third request inside the window is rejected.
        assert!(matches!(
            limiter.check_request(ip),
            Err(RateLimitError::LimitExceeded { limit: 2, .. })
        ));

        // Once the window has elapsed the earlier requests no longer count.
        thread::sleep(Duration::from_millis(110));

        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_connection_limits() {
        let config = RateLimitConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        assert!(limiter.check_connection(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        assert!(matches!(
            limiter.check_connection(ip),
            Err(RateLimitError::ConnectionLimitExceeded { current: 2, max: 2 })
        ));

        limiter.close_connection(ip);

        assert!(limiter.check_connection(ip).is_ok());
    }

    #[test]
    fn test_message_rate_limiting() {
        let config = RateLimitConfig {
            max_messages_per_second: 2,
            burst_allowance: 2,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        // A fresh record starts with `burst_allowance` tokens.
        seed_tokens(&limiter, ip, 2.0);

        assert!(limiter.check_message(ip, 1024).is_ok());
        assert!(limiter.check_message(ip, 1024).is_ok());

        assert!(matches!(
            limiter.check_message(ip, 1024),
            Err(RateLimitError::LimitExceeded { limit: 2, .. })
        ));
    }

    #[test]
    fn test_frame_size_limits() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        assert!(limiter.check_message(ip, 512).is_ok());

        assert!(matches!(
            limiter.check_message(ip, 2048),
            Err(RateLimitError::FrameSizeExceeded { size: 2048, max: 1024 })
        ));
    }

    #[test]
    fn test_rate_limit_guard() {
        let config = RateLimitConfig {
            max_connections_per_ip: 1,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = v4(127, 0, 0, 1);

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        // The single slot is taken while the guard is alive.
        assert!(RateLimitGuard::new(limiter.clone(), ip).is_err());

        drop(guard);

        assert!(RateLimitGuard::new(limiter, ip).is_ok());
    }

    #[test]
    fn test_token_refill_over_time() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        // Half a token is not enough to send a message.
        seed_tokens(&limiter, ip, 0.5);

        assert!(limiter.check_message(ip, 512).is_err());

        // At one token per second, 1.1 s is enough to reach a whole token.
        thread::sleep(Duration::from_millis(1100));

        let result = limiter.check_message(ip, 512);
        assert!(result.is_ok(), "Expected refilled tokens to allow message");
    }

    #[test]
    fn test_cleanup_expired_entries() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = v4(127, 0, 0, 1);
        let ip2 = v4(127, 0, 0, 2);

        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        assert_eq!(limiter.get_stats().total_clients, 2);

        limiter.close_connection(ip1);

        thread::sleep(Duration::from_millis(250));

        limiter.cleanup_expired();

        // ip1 has no connection and no requests, so it is dropped; ip2 still
        // holds a connection and must survive.
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_multiple_ips_isolation() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = v4(127, 0, 0, 1);
        let ip2 = v4(127, 0, 0, 2);

        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip1).is_err());

        // Exhausting ip1's budget must not affect ip2.
        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_request(ip2).is_err());
    }

    #[test]
    fn test_burst_allowance_boundary() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        seed_tokens(&limiter, ip, 0.0);

        assert!(limiter.check_message(ip, 512).is_err());
    }

    #[test]
    fn test_rate_limit_config_high_traffic() {
        let config = RateLimitConfig::high_traffic();

        assert_eq!(config.max_requests_per_window, 1000);
        assert_eq!(config.max_connections_per_ip, 50);
        assert_eq!(config.max_messages_per_second, 100);
        assert_eq!(config.burst_allowance, 20);
        assert!(config.max_frame_size >= 1024 * 1024);
    }

    #[test]
    fn test_rate_limit_config_low_resource() {
        let config = RateLimitConfig::low_resource();

        assert_eq!(config.max_requests_per_window, 20);
        assert_eq!(config.max_connections_per_ip, 2);
        assert_eq!(config.max_messages_per_second, 5);
        assert_eq!(config.burst_allowance, 2);
        assert_eq!(config.max_frame_size, 256 * 1024);
    }

    #[test]
    fn test_frame_size_boundary_exact() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        // Exactly at the limit is allowed; one byte over is not.
        assert!(limiter.check_message(ip, 1024).is_ok());

        assert!(limiter.check_message(ip, 1025).is_err());

        assert!(limiter.check_message(ip, 0).is_ok());
    }

    #[test]
    fn test_get_stats_accuracy() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = v4(127, 0, 0, 1);
        let ip2 = v4(127, 0, 0, 2);

        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 2);
        assert_eq!(stats.total_connections, 3);
        assert_eq!(stats.active_clients, 2);

        limiter.close_connection(ip1);

        // ip1 still has one connection left, so both clients remain active.
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 2);
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.active_clients, 2);
    }

    #[test]
    fn test_window_duration_respected() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(127, 0, 0, 1);

        assert!(limiter.check_request(ip).is_ok());

        assert!(limiter.check_request(ip).is_err());

        thread::sleep(Duration::from_millis(60));

        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_default_limiter() {
        let limiter = WebSocketRateLimiter::default();
        let ip = v4(127, 0, 0, 1);

        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.active_clients, 1);
    }

    #[test]
    fn test_cleanup_expired_removes_inactive_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = v4(127, 0, 0, 1);
        let ip2 = v4(127, 0, 0, 2);
        let ip3 = v4(127, 0, 0, 3);

        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_connection(ip3).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 3);

        // Sleep past twice the window so the request timestamps expire.
        thread::sleep(Duration::from_millis(150));

        limiter.cleanup_expired();

        // Only ip3, which still holds a connection, should remain.
        let after_cleanup = limiter.get_stats();
        assert_eq!(after_cleanup.total_clients, 1);
        assert_eq!(after_cleanup.active_clients, 1);
    }

    #[test]
    fn test_client_with_zero_connections_and_no_recent_requests_cleaned() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = v4(192, 168, 1, 100);

        assert!(limiter.check_request(ip).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 1);

        thread::sleep(Duration::from_millis(250));

        limiter.cleanup_expired();

        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 0);
    }

    #[test]
    fn test_cleanup_preserves_active_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = v4(10, 0, 0, 1);
        let ip2 = v4(10, 0, 0, 2);

        assert!(limiter.check_connection(ip1).is_ok());

        assert!(limiter.check_request(ip2).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 2);

        thread::sleep(Duration::from_millis(80));

        let _ = limiter.check_request(ip2);

        limiter.cleanup_expired();

        // ip1 holds a connection and ip2 has a fresh request: both stay.
        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 2);
    }

    #[test]
    fn test_close_connection_on_nonexistent_ip() {
        let limiter = WebSocketRateLimiter::default();
        let ip = v4(127, 0, 0, 99);

        // Closing an unknown address must neither panic nor create a record.
        limiter.close_connection(ip);

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 0);
    }

    #[test]
    fn test_check_message_on_nonexistent_client() {
        let limiter = WebSocketRateLimiter::default();
        let ip = v4(127, 0, 0, 88);

        // Unknown addresses skip the token check and create no record.
        assert!(limiter.check_message(ip, 512).is_ok());
        assert_eq!(limiter.get_stats().total_clients, 0);
    }
}