Skip to main content

pjson_rs/infrastructure/websocket/
security.rs

1//! WebSocket security integration with rate limiting
2
3use crate::security::{RateLimitConfig, RateLimitError, RateLimitGuard, WebSocketRateLimiter};
4use std::net::IpAddr;
5use std::sync::Arc;
6use tracing::{error, info, warn};
7
8/// Security-enhanced WebSocket handler with rate limiting
9#[derive(Debug)]
10pub struct SecureWebSocketHandler {
11    rate_limiter: Arc<WebSocketRateLimiter>,
12}
13
14impl SecureWebSocketHandler {
15    /// Create new secure handler with rate limiting configuration
16    pub fn new(rate_limit_config: RateLimitConfig) -> Self {
17        let rate_limiter = Arc::new(WebSocketRateLimiter::new(rate_limit_config));
18
19        // Start background cleanup task
20        let limiter_cleanup = rate_limiter.clone();
21        tokio::spawn(async move {
22            let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); // 5 minutes
23            loop {
24                interval.tick().await;
25                limiter_cleanup.cleanup_expired();
26                info!("Rate limiter cleanup completed");
27            }
28        });
29
30        Self { rate_limiter }
31    }
32
33    /// Create with default security configuration
34    pub fn with_default_security() -> Self {
35        Self::new(RateLimitConfig::default())
36    }
37
38    /// Create with high-traffic configuration
39    pub fn with_high_traffic_security() -> Self {
40        Self::new(RateLimitConfig::high_traffic())
41    }
42
43    /// Create with low-resource configuration
44    pub fn with_low_resource_security() -> Self {
45        Self::new(RateLimitConfig::low_resource())
46    }
47
48    /// Check if HTTP upgrade request is allowed
49    pub fn check_upgrade_request(&self, client_ip: IpAddr) -> Result<(), RateLimitError> {
50        match self.rate_limiter.check_request(client_ip) {
51            Ok(()) => {
52                info!("WebSocket upgrade request allowed for IP: {}", client_ip);
53                Ok(())
54            }
55            Err(e) => {
56                warn!(
57                    "WebSocket upgrade request denied for IP {}: {}",
58                    client_ip, e
59                );
60                Err(e)
61            }
62        }
63    }
64
65    /// Create connection guard for a new WebSocket connection
66    pub fn create_connection_guard(
67        &self,
68        client_ip: IpAddr,
69    ) -> Result<RateLimitGuard, RateLimitError> {
70        match RateLimitGuard::new(self.rate_limiter.clone(), client_ip) {
71            Ok(guard) => {
72                info!("WebSocket connection established for IP: {}", client_ip);
73                Ok(guard)
74            }
75            Err(e) => {
76                warn!("WebSocket connection denied for IP {}: {}", client_ip, e);
77                Err(e)
78            }
79        }
80    }
81
82    /// Validate WebSocket message before processing
83    pub fn validate_message(
84        &self,
85        guard: &RateLimitGuard,
86        frame_size: usize,
87    ) -> Result<(), RateLimitError> {
88        match guard.check_message(frame_size) {
89            Ok(()) => Ok(()),
90            Err(e) => {
91                error!("WebSocket message validation failed: {}", e);
92                Err(e)
93            }
94        }
95    }
96
97    /// Get current rate limiting statistics
98    pub fn get_security_stats(&self) -> crate::security::RateLimitStats {
99        self.rate_limiter.get_stats()
100    }
101
102    /// Force cleanup of expired rate limit entries
103    pub fn force_cleanup(&self) {
104        self.rate_limiter.cleanup_expired();
105    }
106}
107
108impl Clone for SecureWebSocketHandler {
109    fn clone(&self) -> Self {
110        Self {
111            rate_limiter: self.rate_limiter.clone(),
112        }
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    #[tokio::test]
    async fn test_secure_websocket_handler() {
        let handler = SecureWebSocketHandler::with_default_security();
        let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);

        // A fresh client passes the upgrade check and can open a connection.
        assert!(handler.check_upgrade_request(localhost).is_ok());
        let guard = handler.create_connection_guard(localhost).unwrap();

        // A 1 KiB frame on that connection must be accepted.
        let outcome = handler.validate_message(&guard, 1024);
        assert!(
            outcome.is_ok(),
            "validate_message failed: {:?}",
            outcome.err()
        );

        // The limiter now tracks at least this client.
        assert!(handler.get_security_stats().total_clients > 0);
    }

    #[tokio::test]
    async fn test_rate_limiting_enforcement() {
        let strict = RateLimitConfig {
            max_requests_per_window: 1,
            max_connections_per_ip: 1,
            max_messages_per_second: 1,
            burst_allowance: 0,
            ..Default::default()
        };
        let handler = SecureWebSocketHandler::new(strict);
        let client = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));

        // With a one-request window, only the first upgrade request gets through.
        assert!(handler.check_upgrade_request(client).is_ok());
        assert!(handler.check_upgrade_request(client).is_err());

        // While the first connection is held open, a second one from the same IP is refused.
        let _first_connection = handler.create_connection_guard(client).unwrap();
        assert!(handler.create_connection_guard(client).is_err());
    }

    #[tokio::test]
    async fn test_different_security_levels() {
        let handlers = [
            SecureWebSocketHandler::with_default_security(),
            SecureWebSocketHandler::with_high_traffic_security(),
            SecureWebSocketHandler::with_low_resource_security(),
        ];
        let client = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        // Every preset accepts an initial upgrade request...
        for handler in &handlers {
            assert!(handler.check_upgrade_request(client).is_ok());
        }

        // ...and an initial connection; the guards stay alive until the test ends.
        let _connections: Vec<_> = handlers
            .iter()
            .map(|handler| handler.create_connection_guard(client).unwrap())
            .collect();
    }
}