Skip to main content

pjson_rs/security/
rate_limit.rs

1//! Rate limiting system for WebSocket connections to prevent DoS attacks
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::{
6    net::IpAddr,
7    sync::Arc,
8    time::{Duration, Instant},
9};
10use thiserror::Error;
11
12/// Rate limiting errors
13#[derive(Error, Debug, Clone)]
14pub enum RateLimitError {
15    /// Request count exceeded the per-window limit.
16    #[error("Rate limit exceeded: {limit} requests per {window:?}")]
17    LimitExceeded {
18        /// Configured per-window request limit.
19        limit: u32,
20        /// Configured window duration.
21        window: Duration,
22    },
23
24    /// Per-IP concurrent connection cap was reached.
25    #[error("Connection limit exceeded: {current}/{max} connections")]
26    ConnectionLimitExceeded {
27        /// Current connection count for the IP.
28        current: usize,
29        /// Configured maximum number of connections per IP.
30        max: usize,
31    },
32
33    /// Frame larger than the configured maximum was rejected.
34    #[error("Frame size limit exceeded: {size} bytes > {max} bytes")]
35    FrameSizeExceeded {
36        /// Observed frame size in bytes.
37        size: usize,
38        /// Configured maximum frame size in bytes.
39        max: usize,
40    },
41}
42
43/// Rate limiting configuration
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RateLimitConfig {
46    /// Maximum requests per time window
47    pub max_requests_per_window: u32,
48    /// Time window for rate limiting
49    pub window_duration: Duration,
50    /// Maximum concurrent connections per IP
51    pub max_connections_per_ip: usize,
52    /// Maximum WebSocket frame size
53    pub max_frame_size: usize,
54    /// Maximum message rate (messages per second)
55    pub max_messages_per_second: u32,
56    /// Burst allowance (extra messages above rate)
57    pub burst_allowance: u32,
58}
59
60impl Default for RateLimitConfig {
61    fn default() -> Self {
62        Self {
63            max_requests_per_window: 100,
64            window_duration: Duration::from_secs(60),
65            max_connections_per_ip: 10,
66            max_frame_size: 1024 * 1024, // 1MB
67            max_messages_per_second: 30,
68            burst_allowance: 5,
69        }
70    }
71}
72
73impl RateLimitConfig {
74    /// Configuration for high-traffic scenarios
75    pub fn high_traffic() -> Self {
76        Self {
77            max_requests_per_window: 1000,
78            max_connections_per_ip: 50,
79            max_messages_per_second: 100,
80            burst_allowance: 20,
81            ..Default::default()
82        }
83    }
84
85    /// Configuration for low-resource environments
86    pub fn low_resource() -> Self {
87        Self {
88            max_requests_per_window: 20,
89            max_connections_per_ip: 2,
90            max_frame_size: 256 * 1024, // 256KB
91            max_messages_per_second: 5,
92            burst_allowance: 2,
93            ..Default::default()
94        }
95    }
96}
97
98/// Rate limit tracking for a specific client
99#[derive(Debug)]
100struct ClientRateLimit {
101    /// Request timestamps within current window
102    requests: Vec<Instant>,
103    /// Current connection count
104    connection_count: usize,
105    /// Token bucket for message rate limiting
106    tokens: f64,
107    /// Last token refill time
108    last_refill: Instant,
109}
110
111impl ClientRateLimit {
112    fn new(burst_allowance: u32) -> Self {
113        let now = Instant::now();
114        Self {
115            requests: Vec::new(),
116            connection_count: 0,
117            tokens: burst_allowance as f64, // Start with burst allowance tokens
118            last_refill: now,
119        }
120    }
121
122    /// Refill tokens based on time passed
123    fn refill_tokens(&mut self, config: &RateLimitConfig) {
124        let now = Instant::now();
125        let time_passed = now.duration_since(self.last_refill).as_secs_f64();
126
127        // Add tokens at configured rate
128        let tokens_to_add = time_passed * config.max_messages_per_second as f64;
129        let max_tokens = (config.max_messages_per_second + config.burst_allowance) as f64;
130
131        self.tokens = (self.tokens + tokens_to_add).min(max_tokens);
132        self.last_refill = now;
133    }
134
135    /// Check if message rate is within limits
136    fn check_message_rate(&mut self, config: &RateLimitConfig) -> Result<(), RateLimitError> {
137        self.refill_tokens(config);
138
139        if self.tokens >= 1.0 {
140            self.tokens -= 1.0;
141            Ok(())
142        } else {
143            Err(RateLimitError::LimitExceeded {
144                limit: config.max_messages_per_second,
145                window: Duration::from_secs(1),
146            })
147        }
148    }
149}
150
151/// Rate limiter for WebSocket connections
152#[derive(Debug)]
153pub struct WebSocketRateLimiter {
154    config: RateLimitConfig,
155    clients: Arc<DashMap<IpAddr, ClientRateLimit>>,
156}
157
158impl Default for WebSocketRateLimiter {
159    fn default() -> Self {
160        Self::new(RateLimitConfig::default())
161    }
162}
163
164impl WebSocketRateLimiter {
165    /// Create new rate limiter with configuration
166    pub fn new(config: RateLimitConfig) -> Self {
167        Self {
168            config,
169            clients: Arc::new(DashMap::new()),
170        }
171    }
172
173    /// Check if request is allowed (HTTP upgrade to WebSocket)
174    pub fn check_request(&self, ip: IpAddr) -> Result<(), RateLimitError> {
175        let now = Instant::now();
176        let burst = self.config.burst_allowance;
177        let mut client = self
178            .clients
179            .entry(ip)
180            .or_insert_with(|| ClientRateLimit::new(burst));
181
182        // Clean old requests outside window
183        let window_start = now - self.config.window_duration;
184        client.requests.retain(|&time| time > window_start);
185
186        // Check request rate limit
187        if client.requests.len() >= self.config.max_requests_per_window as usize {
188            return Err(RateLimitError::LimitExceeded {
189                limit: self.config.max_requests_per_window,
190                window: self.config.window_duration,
191            });
192        }
193
194        // Add current request
195        client.requests.push(now);
196        Ok(())
197    }
198
199    /// Check if new connection is allowed
200    pub fn check_connection(&self, ip: IpAddr) -> Result<(), RateLimitError> {
201        let burst = self.config.burst_allowance;
202        let mut client = self
203            .clients
204            .entry(ip)
205            .or_insert_with(|| ClientRateLimit::new(burst));
206
207        if client.connection_count >= self.config.max_connections_per_ip {
208            return Err(RateLimitError::ConnectionLimitExceeded {
209                current: client.connection_count,
210                max: self.config.max_connections_per_ip,
211            });
212        }
213
214        client.connection_count += 1;
215        Ok(())
216    }
217
218    /// Register connection close
219    pub fn close_connection(&self, ip: IpAddr) {
220        if let Some(mut client) = self.clients.get_mut(&ip) {
221            client.connection_count = client.connection_count.saturating_sub(1);
222        }
223    }
224
225    /// Check if WebSocket message is allowed
226    pub fn check_message(&self, ip: IpAddr, frame_size: usize) -> Result<(), RateLimitError> {
227        // Check frame size
228        if frame_size > self.config.max_frame_size {
229            return Err(RateLimitError::FrameSizeExceeded {
230                size: frame_size,
231                max: self.config.max_frame_size,
232            });
233        }
234
235        // Check message rate
236        if let Some(mut client) = self.clients.get_mut(&ip) {
237            client.check_message_rate(&self.config)?;
238        }
239
240        Ok(())
241    }
242
243    /// Get current statistics for monitoring
244    pub fn get_stats(&self) -> RateLimitStats {
245        let mut stats = RateLimitStats::default();
246
247        for entry in self.clients.iter() {
248            stats.total_clients += 1;
249            stats.total_connections += entry.value().connection_count;
250
251            if entry.value().connection_count > 0 {
252                stats.active_clients += 1;
253            }
254        }
255
256        stats
257    }
258
259    /// Clean up expired entries (call periodically)
260    pub fn cleanup_expired(&self) {
261        let now = Instant::now();
262        let cutoff = now - self.config.window_duration * 2; // Keep some history
263
264        self.clients.retain(|_, client| {
265            // Remove clients with no recent activity and no connections
266            !(client.connection_count == 0
267                && client.requests.last().is_none_or(|&time| time < cutoff))
268        });
269    }
270}
271
/// Point-in-time statistics produced by `WebSocketRateLimiter::get_stats`.
#[derive(Debug, Default, Clone)]
pub struct RateLimitStats {
    /// Number of client IPs currently tracked, including idle ones not yet cleaned up.
    pub total_clients: usize,
    /// Number of tracked clients that currently hold at least one connection.
    pub active_clients: usize,
    /// Sum of currently held connections across all clients.
    pub total_connections: usize,
}
282
283/// Rate limiting middleware for tracking client IPs
284#[derive(Debug, Clone)]
285pub struct RateLimitGuard {
286    rate_limiter: Arc<WebSocketRateLimiter>,
287    client_ip: IpAddr,
288}
289
290impl RateLimitGuard {
291    /// Create new guard for a client connection
292    pub fn new(
293        rate_limiter: Arc<WebSocketRateLimiter>,
294        client_ip: IpAddr,
295    ) -> Result<Self, RateLimitError> {
296        rate_limiter.check_connection(client_ip)?;
297
298        Ok(Self {
299            rate_limiter,
300            client_ip,
301        })
302    }
303
304    /// Check if message is allowed
305    pub fn check_message(&self, frame_size: usize) -> Result<(), RateLimitError> {
306        self.rate_limiter.check_message(self.client_ip, frame_size)
307    }
308}
309
310impl Drop for RateLimitGuard {
311    fn drop(&mut self) {
312        self.rate_limiter.close_connection(self.client_ip);
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_rate_limit_requests() {
        let config = RateLimitConfig {
            max_requests_per_window: 2,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // First two requests should succeed
        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_request(ip).is_ok());

        // Third request should be rate limited
        assert!(limiter.check_request(ip).is_err());

        // Wait for window to reset
        thread::sleep(Duration::from_millis(110));

        // Should work again
        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_connection_limits() {
        let config = RateLimitConfig {
            max_connections_per_ip: 2,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Two connections should succeed
        assert!(limiter.check_connection(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        // Third connection should fail
        assert!(limiter.check_connection(ip).is_err());

        // Close one connection
        limiter.close_connection(ip);

        // Should work again
        assert!(limiter.check_connection(ip).is_ok());
    }

    #[test]
    fn test_message_rate_limiting() {
        let config = RateLimitConfig {
            max_messages_per_second: 2,
            burst_allowance: 2, // Allow 2 burst messages
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Create the client entry; its bucket starts with `burst_allowance` tokens
        let client = limiter
            .clients
            .entry(ip)
            .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
        drop(client);

        // Should allow burst messages
        assert!(limiter.check_message(ip, 1024).is_ok());
        assert!(limiter.check_message(ip, 1024).is_ok());

        // Should be rate limited now (no more tokens)
        assert!(limiter.check_message(ip, 1024).is_err());
    }

    #[test]
    fn test_frame_size_limits() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Small frame should succeed
        assert!(limiter.check_message(ip, 512).is_ok());

        // Large frame should fail
        assert!(limiter.check_message(ip, 2048).is_err());
    }

    #[test]
    fn test_rate_limit_guard() {
        let config = RateLimitConfig {
            max_connections_per_ip: 1,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Create guard
        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        // Second connection should fail
        assert!(RateLimitGuard::new(limiter.clone(), ip).is_err());

        // Drop guard
        drop(guard);

        // Should work again
        assert!(RateLimitGuard::new(limiter, ip).is_ok());
    }

    #[test]
    fn test_token_refill_over_time() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Start with a partial token so the next check fails
        {
            let mut client = limiter
                .clients
                .entry(ip)
                .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
            client.tokens = 0.5;
        }

        // Should fail with insufficient tokens
        assert!(limiter.check_message(ip, 512).is_err());

        // At 1 token/second the missing 0.5 token takes 500ms; sleep guarantees at
        // least 600ms, so the bucket reaches its 1-token capacity
        thread::sleep(Duration::from_millis(600));

        let result = limiter.check_message(ip, 512);
        assert!(result.is_ok(), "Expected refilled tokens to allow message");
    }

    #[test]
    fn test_cleanup_expired_entries() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        // Add some client entries
        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        // Should have 2 clients
        assert_eq!(limiter.get_stats().total_clients, 2);

        // Close connection for ip1
        limiter.close_connection(ip1);

        // Wait beyond the cleanup window
        thread::sleep(Duration::from_millis(250));

        limiter.cleanup_expired();

        // ip1 has no connections and no requests, so it is removed;
        // ip2 still holds a connection, so it is kept
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_multiple_ips_isolation() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        // ip1 should be rate limited after 1 request
        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip1).is_err());

        // ip2 should NOT be affected by ip1's limit
        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_request(ip2).is_err());
    }

    #[test]
    fn test_burst_allowance_boundary() {
        let config = RateLimitConfig {
            max_messages_per_second: 1,
            burst_allowance: 0,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config.clone());
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // With no burst and an empty bucket, the next message must be throttled
        let mut client = limiter
            .clients
            .entry(ip)
            .or_insert_with(|| ClientRateLimit::new(config.burst_allowance));
        client.tokens = 0.0;
        drop(client);

        // Should fail with no tokens
        assert!(limiter.check_message(ip, 512).is_err());
    }

    #[test]
    fn test_rate_limit_config_high_traffic() {
        let config = RateLimitConfig::high_traffic();

        assert_eq!(config.max_requests_per_window, 1000);
        assert_eq!(config.max_connections_per_ip, 50);
        assert_eq!(config.max_messages_per_second, 100);
        assert_eq!(config.burst_allowance, 20);
        assert!(config.max_frame_size >= 1024 * 1024);
    }

    #[test]
    fn test_rate_limit_config_low_resource() {
        let config = RateLimitConfig::low_resource();

        assert_eq!(config.max_requests_per_window, 20);
        assert_eq!(config.max_connections_per_ip, 2);
        assert_eq!(config.max_messages_per_second, 5);
        assert_eq!(config.burst_allowance, 2);
        assert_eq!(config.max_frame_size, 256 * 1024);
    }

    #[test]
    fn test_frame_size_boundary_exact() {
        let config = RateLimitConfig {
            max_frame_size: 1024,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Exactly at limit should succeed
        assert!(limiter.check_message(ip, 1024).is_ok());

        // Just over limit should fail
        assert!(limiter.check_message(ip, 1025).is_err());

        // Zero-size frame should succeed (though uncommon)
        assert!(limiter.check_message(ip, 0).is_ok());
    }

    #[test]
    fn test_get_stats_accuracy() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        // Add connections
        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip1).is_ok());
        assert!(limiter.check_connection(ip2).is_ok());

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 2);
        assert_eq!(stats.total_connections, 3);
        assert_eq!(stats.active_clients, 2);

        // Close one of ip1's two connections; both clients still hold one
        limiter.close_connection(ip1);

        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 2);
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.active_clients, 2);
    }

    #[test]
    fn test_window_duration_respected() {
        let config = RateLimitConfig {
            max_requests_per_window: 1,
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // First request succeeds
        assert!(limiter.check_request(ip).is_ok());

        // Second request within window fails
        assert!(limiter.check_request(ip).is_err());

        // Wait for window to pass
        thread::sleep(Duration::from_millis(60));

        // Request after window passes succeeds
        assert!(limiter.check_request(ip).is_ok());
    }

    #[test]
    fn test_default_limiter() {
        // Test Default implementation for WebSocketRateLimiter
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        // Default limiter should allow requests
        assert!(limiter.check_request(ip).is_ok());
        assert!(limiter.check_connection(ip).is_ok());

        // Both calls share a single client entry holding one connection
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 1);
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_cleanup_expired_removes_inactive_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(50),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));
        let ip3 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 3));

        // ip1/ip2 make a request; ip3 opens a connection (and makes no request)
        assert!(limiter.check_request(ip1).is_ok());
        assert!(limiter.check_request(ip2).is_ok());
        assert!(limiter.check_connection(ip3).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 3);

        // Wait past the cleanup cutoff (2 x window_duration = 100ms)
        thread::sleep(Duration::from_millis(150));

        limiter.cleanup_expired();

        // ip1 and ip2 have no connections and only stale requests, so they are removed;
        // ip3 is kept because it still holds an open connection
        let after_cleanup = limiter.get_stats();
        assert_eq!(after_cleanup.total_clients, 1);
        assert_eq!(after_cleanup.total_connections, 1);
    }

    #[test]
    fn test_client_with_zero_connections_and_no_recent_requests_cleaned() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        // Make a request
        assert!(limiter.check_request(ip).is_ok());

        // Verify client exists
        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 1);

        // Wait beyond cleanup window (2x window_duration)
        thread::sleep(Duration::from_millis(250));

        // Cleanup should remove the client (no connections and stale requests)
        limiter.cleanup_expired();

        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 0);
    }

    #[test]
    fn test_cleanup_preserves_active_clients() {
        let config = RateLimitConfig {
            window_duration: Duration::from_millis(100),
            ..Default::default()
        };

        let limiter = WebSocketRateLimiter::new(config);
        let ip1 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));

        // ip1: has active connection
        assert!(limiter.check_connection(ip1).is_ok());

        // ip2: has recent request but no connection
        assert!(limiter.check_request(ip2).is_ok());

        let initial_stats = limiter.get_stats();
        assert_eq!(initial_stats.total_clients, 2);

        // Wait some time (but not beyond the 200ms cleanup cutoff)
        thread::sleep(Duration::from_millis(80));

        // Make another request to ip2 to keep it fresh
        let _ = limiter.check_request(ip2);

        limiter.cleanup_expired();

        // ip1 is kept for its connection, ip2 for its recent request
        let final_stats = limiter.get_stats();
        assert_eq!(final_stats.total_clients, 2);
    }

    #[test]
    fn test_close_connection_on_nonexistent_ip() {
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 99));

        // Closing connection on non-existent IP should not panic
        limiter.close_connection(ip);

        // Stats should be empty
        let stats = limiter.get_stats();
        assert_eq!(stats.total_clients, 0);
    }

    #[test]
    fn test_check_message_on_nonexistent_client() {
        let limiter = WebSocketRateLimiter::default();
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 88));

        // For an unknown IP only the frame size is checked...
        assert!(limiter.check_message(ip, 512).is_ok());

        // ...and no client entry is created
        assert_eq!(limiter.get_stats().total_clients, 0);
    }

    #[test]
    fn test_rate_limit_guard_check_message() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            max_frame_size: 1024,
            max_messages_per_second: 10,
            burst_allowance: 5,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(2048).is_err());
    }

    #[test]
    fn test_rate_limit_guard_check_message_rate_limit() {
        let config = RateLimitConfig {
            max_connections_per_ip: 5,
            max_frame_size: 10_000,
            max_messages_per_second: 2,
            burst_allowance: 2,
            ..Default::default()
        };

        let limiter = Arc::new(WebSocketRateLimiter::new(config));
        let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2));

        let guard = RateLimitGuard::new(limiter.clone(), ip).unwrap();

        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_ok());
        assert!(guard.check_message(512).is_err());
    }
}