1use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::{
6 net::IpAddr,
7 sync::Arc,
8 time::{Duration, Instant},
9};
10use thiserror::Error;
11
12#[derive(Error, Debug, Clone)]
14pub enum RateLimitError {
15 #[error("Rate limit exceeded: {limit} requests per {window:?}")]
17 LimitExceeded {
18 limit: u32,
20 window: Duration,
22 },
23
24 #[error("Connection limit exceeded: {current}/{max} connections")]
26 ConnectionLimitExceeded {
27 current: usize,
29 max: usize,
31 },
32
33 #[error("Frame size limit exceeded: {size} bytes > {max} bytes")]
35 FrameSizeExceeded {
36 size: usize,
38 max: usize,
40 },
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RateLimitConfig {
46 pub max_requests_per_window: u32,
48 pub window_duration: Duration,
50 pub max_connections_per_ip: usize,
52 pub max_frame_size: usize,
54 pub max_messages_per_second: u32,
56 pub burst_allowance: u32,
58}
59
60impl Default for RateLimitConfig {
61 fn default() -> Self {
62 Self {
63 max_requests_per_window: 100,
64 window_duration: Duration::from_secs(60),
65 max_connections_per_ip: 10,
66 max_frame_size: 1024 * 1024, max_messages_per_second: 30,
68 burst_allowance: 5,
69 }
70 }
71}
72
73impl RateLimitConfig {
74 pub fn high_traffic() -> Self {
76 Self {
77 max_requests_per_window: 1000,
78 max_connections_per_ip: 50,
79 max_messages_per_second: 100,
80 burst_allowance: 20,
81 ..Default::default()
82 }
83 }
84
85 pub fn low_resource() -> Self {
87 Self {
88 max_requests_per_window: 20,
89 max_connections_per_ip: 2,
90 max_frame_size: 256 * 1024, max_messages_per_second: 5,
92 burst_allowance: 2,
93 ..Default::default()
94 }
95 }
96}
97
/// Per-IP bookkeeping shared by the request, connection and message limits.
#[derive(Debug)]
struct ClientRateLimit {
    /// Timestamps of requests accepted by `check_request`; pruned to the
    /// current window each time `check_request` runs.
    requests: Vec<Instant>,
    /// Open connections, incremented by `check_connection` and decremented
    /// (saturating) by `close_connection`.
    connection_count: usize,
    /// Token-bucket balance for message rate limiting; each accepted message
    /// spends one token.
    tokens: f64,
    /// When `tokens` was last topped up by `refill_tokens`.
    last_refill: Instant,
}
110
111impl ClientRateLimit {
112 fn new(burst_allowance: u32) -> Self {
113 let now = Instant::now();
114 Self {
115 requests: Vec::new(),
116 connection_count: 0,
117 tokens: burst_allowance as f64, last_refill: now,
119 }
120 }
121
122 fn refill_tokens(&mut self, config: &RateLimitConfig) {
124 let now = Instant::now();
125 let time_passed = now.duration_since(self.last_refill).as_secs_f64();
126
127 let tokens_to_add = time_passed * config.max_messages_per_second as f64;
129 let max_tokens = (config.max_messages_per_second + config.burst_allowance) as f64;
130
131 self.tokens = (self.tokens + tokens_to_add).min(max_tokens);
132 self.last_refill = now;
133 }
134
135 fn check_message_rate(&mut self, config: &RateLimitConfig) -> Result<(), RateLimitError> {
137 self.refill_tokens(config);
138
139 if self.tokens >= 1.0 {
140 self.tokens -= 1.0;
141 Ok(())
142 } else {
143 Err(RateLimitError::LimitExceeded {
144 limit: config.max_messages_per_second,
145 window: Duration::from_secs(1),
146 })
147 }
148 }
149}
150
/// Per-IP request, connection, frame-size and message-rate limiter for
/// WebSocket clients.
///
/// Client state lives in a `DashMap` keyed by IP address, which is why every
/// check takes `&self`.
#[derive(Debug)]
pub struct WebSocketRateLimiter {
    /// Limits applied to every client.
    config: RateLimitConfig,
    /// Per-IP state, created on first use.
    clients: Arc<DashMap<IpAddr, ClientRateLimit>>,
}
157
158impl Default for WebSocketRateLimiter {
159 fn default() -> Self {
160 Self::new(RateLimitConfig::default())
161 }
162}
163
164impl WebSocketRateLimiter {
165 pub fn new(config: RateLimitConfig) -> Self {
167 Self {
168 config,
169 clients: Arc::new(DashMap::new()),
170 }
171 }
172
173 pub fn check_request(&self, ip: IpAddr) -> Result<(), RateLimitError> {
175 let now = Instant::now();
176 let burst = self.config.burst_allowance;
177 let mut client = self
178 .clients
179 .entry(ip)
180 .or_insert_with(|| ClientRateLimit::new(burst));
181
182 let window_start = now - self.config.window_duration;
184 client.requests.retain(|&time| time > window_start);
185
186 if client.requests.len() >= self.config.max_requests_per_window as usize {
188 return Err(RateLimitError::LimitExceeded {
189 limit: self.config.max_requests_per_window,
190 window: self.config.window_duration,
191 });
192 }
193
194 client.requests.push(now);
196 Ok(())
197 }
198
199 pub fn check_connection(&self, ip: IpAddr) -> Result<(), RateLimitError> {
201 let burst = self.config.burst_allowance;
202 let mut client = self
203 .clients
204 .entry(ip)
205 .or_insert_with(|| ClientRateLimit::new(burst));
206
207 if client.connection_count >= self.config.max_connections_per_ip {
208 return Err(RateLimitError::ConnectionLimitExceeded {
209 current: client.connection_count,
210 max: self.config.max_connections_per_ip,
211 });
212 }
213
214 client.connection_count += 1;
215 Ok(())
216 }
217
218 pub fn close_connection(&self, ip: IpAddr) {
220 if let Some(mut client) = self.clients.get_mut(&ip) {
221 client.connection_count = client.connection_count.saturating_sub(1);
222 }
223 }
224
225 pub fn check_message(&self, ip: IpAddr, frame_size: usize) -> Result<(), RateLimitError> {
227 if frame_size > self.config.max_frame_size {
229 return Err(RateLimitError::FrameSizeExceeded {
230 size: frame_size,
231 max: self.config.max_frame_size,
232 });
233 }
234
235 if let Some(mut client) = self.clients.get_mut(&ip) {
237 client.check_message_rate(&self.config)?;
238 }
239
240 Ok(())
241 }
242
243 pub fn get_stats(&self) -> RateLimitStats {
245 let mut stats = RateLimitStats::default();
246
247 for entry in self.clients.iter() {
248 stats.total_clients += 1;
249 stats.total_connections += entry.value().connection_count;
250
251 if entry.value().connection_count > 0 {
252 stats.active_clients += 1;
253 }
254 }
255
256 stats
257 }
258
259 pub fn cleanup_expired(&self) {
261 let now = Instant::now();
262 let cutoff = now - self.config.window_duration * 2; self.clients.retain(|_, client| {
265 !(client.connection_count == 0
267 && client.requests.last().is_none_or(|&time| time < cutoff))
268 });
269 }
270}
271
/// Aggregate counters returned by [`WebSocketRateLimiter::get_stats`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RateLimitStats {
    /// Number of IPs currently tracked by the limiter.
    pub total_clients: usize,
    /// Tracked IPs with at least one open connection.
    pub active_clients: usize,
    /// Sum of open connections across all tracked IPs.
    pub total_connections: usize,
}
282
283#[derive(Debug, Clone)]
285pub struct RateLimitGuard {
286 rate_limiter: Arc<WebSocketRateLimiter>,
287 client_ip: IpAddr,
288}
289
290impl RateLimitGuard {
291 pub fn new(
293 rate_limiter: Arc<WebSocketRateLimiter>,
294 client_ip: IpAddr,
295 ) -> Result<Self, RateLimitError> {
296 rate_limiter.check_connection(client_ip)?;
297
298 Ok(Self {
299 rate_limiter,
300 client_ip,
301 })
302 }
303
304 pub fn check_message(&self, frame_size: usize) -> Result<(), RateLimitError> {
306 self.rate_limiter.check_message(self.client_ip, frame_size)
307 }
308}
309
310impl Drop for RateLimitGuard {
311 fn drop(&mut self) {
312 self.rate_limiter.close_connection(self.client_ip);
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_rate_limit_requests() {
        let config = RateLimitConfig {
            max_requests_per_window: 2,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_request(ip).is_ok());

        assert!(limiter.check_request(ip).is_err());

        thread::sleep(Duration::from_millis(110));

        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_connection_limits() {
        let config = RateLimitConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_connection(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        assert!(limiter.check_connection(ip).is_err());

        limiter.close_connection(ip);

        assert!(limiter.check_connection(ip).is_ok());
    }

    #[test]
    fn test_message_rate_limiting() {
        let config = RateLimitConfig {
            max_messages_per_second: 2,
            burst_allowance: 2,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        let client = limiter
            .clients
            .entry(ip)
            .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
        drop(client);

        assert!(limiter.check_message(ip, 1024).is_ok());
        assert!(limiter.check_message(ip, 1024).is_ok());

        assert!(limiter.check_message(ip, 1024).is_err());
    }

    #[test]
    fn test_frame_size_limits() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_message(ip, 512).is_ok());

        assert!(limiter.check_message(ip, 2048).is_err());
    }

    #[test]
    fn test_rate_limit_guard() {
        let config = RateLimitConfig {
            max_connections_per_ip: 1,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        assert!(RateLimitGuard::new(limiter.clone(), ip).is_err());

        drop(guard);

        assert!(RateLimitGuard::new(limiter, ip).is_ok());
    }

    #[test]
    fn test_token_refill_over_time() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        {
            let mut client = limiter
                .clients
                .entry(ip)
                .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
            client.tokens = 0.5;
        }

        assert!(limiter.check_message(ip, 512).is_err());

        // At 1 token/s, the remaining 0.5 token deficit refills after 500 ms;
        // 600 ms leaves a margin without making the test slower than needed.
        thread::sleep(Duration::from_millis(600));

        let result = limiter.check_message(ip, 512);
        assert!(result.is_ok(), "Expected refilled tokens to allow message");
    }

    #[test]
    fn test_cleanup_expired_entries() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        assert_eq!(limiter.get_stats().total_clients, 2);

        limiter.close_connection(ip1);

        thread::sleep(Duration::from_millis(250));

        limiter.cleanup_expired();

        // ip1 has no connections and no requests, so it is dropped;
        // ip2 still holds a connection, so it is kept.
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_multiple_ips_isolation() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip1).is_err());

        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_request(ip2).is_err());
    }

    #[test]
    fn test_burst_allowance_boundary() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        let mut client = limiter
            .clients
            .entry(ip)
            .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
        client.tokens = 0.0;
        drop(client);

        assert!(limiter.check_message(ip, 512).is_err());
    }

    #[test]
    fn test_rate_limit_config_high_traffic() {
        let config = RateLimitConfig::high_traffic();

        assert_eq!(config.max_requests_per_window, 1000);
        assert_eq!(config.max_connections_per_ip, 50);
        assert_eq!(config.max_messages_per_second, 100);
        assert_eq!(config.burst_allowance, 20);
        assert!(config.max_frame_size >= 1024 * 1024);
    }

    #[test]
    fn test_rate_limit_config_low_resource() {
        let config = RateLimitConfig::low_resource();

        assert_eq!(config.max_requests_per_window, 20);
        assert_eq!(config.max_connections_per_ip, 2);
        assert_eq!(config.max_messages_per_second, 5);
        assert_eq!(config.burst_allowance, 2);
        assert_eq!(config.max_frame_size, 256 * 1024);
    }

    #[test]
    fn test_frame_size_boundary_exact() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_message(ip, 1024).is_ok());

        assert!(limiter.check_message(ip, 1025).is_err());

        assert!(limiter.check_message(ip, 0).is_ok());
    }

    #[test]
    fn test_get_stats_accuracy() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 2);
        assert_eq!(stats.total_connections, 3);
        assert_eq!(stats.active_clients, 2);

        limiter.close_connection(ip1);

        // ip1 still holds one connection, so both clients remain active.
        let stats = limiter.get_stats();
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.active_clients, 2);
    }

    #[test]
    fn test_window_duration_respected() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_request(ip).is_ok());

        assert!(limiter.check_request(ip).is_err());

        thread::sleep(Duration::from_millis(60));

        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_default_limiter() {
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_cleanup_expired_removes_inactive_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
        let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));

        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_connection(ip3).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 3);

        // Cutoff is 2 * 50 ms; after 150 ms the requests of ip1/ip2 are stale.
        thread::sleep(Duration::from_millis(150));

        limiter.cleanup_expired();

        // Only ip3, which holds an open connection, survives.
        let after_cleanup = limiter.get_stats();
        assert_eq!(after_cleanup.total_clients, 1);
        assert_eq!(after_cleanup.total_connections, 1);
    }

    #[test]
    fn test_client_with_zero_connections_and_no_recent_requests_cleaned() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        assert!(limiter.check_request(ip).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 1);

        thread::sleep(Duration::from_millis(250));

        limiter.cleanup_expired();

        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 0);
    }

    #[test]
    fn test_cleanup_preserves_active_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

        assert!(limiter.check_connection(ip1).is_ok());

        assert!(limiter.check_request(ip2).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 2);

        thread::sleep(Duration::from_millis(80));

        let _ = limiter.check_request(ip2);

        limiter.cleanup_expired();

        // ip1 holds a connection and ip2 just made a request: both are kept.
        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 2);
    }

    #[test]
    fn test_close_connection_on_nonexistent_ip() {
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 99));

        limiter.close_connection(ip);

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 0);
    }

    #[test]
    fn test_check_message_on_nonexistent_client() {
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 88));

        assert!(limiter.check_message(ip, 512).is_ok());
    }

    #[test]
    fn test_rate_limit_guard_check_message() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            max_frame_size: 1024,
            max_messages_per_second: 10,
            burst_allowance: 5,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(2048).is_err());
    }

    #[test]
    fn test_rate_limit_guard_check_message_rate_limit() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            max_frame_size: 10_000,
            max_messages_per_second: 2,
            burst_allowance: 2,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_err());
    }
}