// pjson_rs/infrastructure/websocket/security.rs
use crate::security::{RateLimitConfig, RateLimitError, RateLimitGuard, WebSocketRateLimiter};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{error, info, warn};

8#[derive(Debug)]
10pub struct SecureWebSocketHandler {
11 rate_limiter: Arc<WebSocketRateLimiter>,
12}
13
14impl SecureWebSocketHandler {
15 pub fn new(rate_limit_config: RateLimitConfig) -> Self {
17 let rate_limiter = Arc::new(WebSocketRateLimiter::new(rate_limit_config));
18
19 let limiter_cleanup = rate_limiter.clone();
21 tokio::spawn(async move {
22 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); loop {
24 interval.tick().await;
25 limiter_cleanup.cleanup_expired();
26 info!("Rate limiter cleanup completed");
27 }
28 });
29
30 Self { rate_limiter }
31 }
32
33 pub fn with_default_security() -> Self {
35 Self::new(RateLimitConfig::default())
36 }
37
38 pub fn with_high_traffic_security() -> Self {
40 Self::new(RateLimitConfig::high_traffic())
41 }
42
43 pub fn with_low_resource_security() -> Self {
45 Self::new(RateLimitConfig::low_resource())
46 }
47
48 pub fn check_upgrade_request(&self, client_ip: IpAddr) -> Result<(), RateLimitError> {
50 match self.rate_limiter.check_request(client_ip) {
51 Ok(()) => {
52 info!("WebSocket upgrade request allowed for IP: {}", client_ip);
53 Ok(())
54 }
55 Err(e) => {
56 warn!(
57 "WebSocket upgrade request denied for IP {}: {}",
58 client_ip, e
59 );
60 Err(e)
61 }
62 }
63 }
64
65 pub fn create_connection_guard(
67 &self,
68 client_ip: IpAddr,
69 ) -> Result<RateLimitGuard, RateLimitError> {
70 match RateLimitGuard::new(self.rate_limiter.clone(), client_ip) {
71 Ok(guard) => {
72 info!("WebSocket connection established for IP: {}", client_ip);
73 Ok(guard)
74 }
75 Err(e) => {
76 warn!("WebSocket connection denied for IP {}: {}", client_ip, e);
77 Err(e)
78 }
79 }
80 }
81
82 pub fn validate_message(
84 &self,
85 guard: &RateLimitGuard,
86 frame_size: usize,
87 ) -> Result<(), RateLimitError> {
88 match guard.check_message(frame_size) {
89 Ok(()) => Ok(()),
90 Err(e) => {
91 error!("WebSocket message validation failed: {}", e);
92 Err(e)
93 }
94 }
95 }
96
97 pub fn get_security_stats(&self) -> crate::security::RateLimitStats {
99 self.rate_limiter.get_stats()
100 }
101
102 pub fn force_cleanup(&self) {
104 self.rate_limiter.cleanup_expired();
105 }
106}
107
108impl Clone for SecureWebSocketHandler {
109 fn clone(&self) -> Self {
110 Self {
111 rate_limiter: self.rate_limiter.clone(),
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 client address from its four octets.
    fn client(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::from(Ipv4Addr::new(a, b, c, d))
    }

    #[tokio::test]
    async fn test_secure_websocket_handler() {
        let handler = SecureWebSocketHandler::with_default_security();
        let loopback = client(127, 0, 0, 1);

        // Upgrade is admitted, a connection slot is granted, and a 1 KiB frame passes.
        assert!(handler.check_upgrade_request(loopback).is_ok());
        let guard = handler
            .create_connection_guard(loopback)
            .expect("connection guard should be granted");
        let validation = handler.validate_message(&guard, 1024);
        assert!(
            validation.is_ok(),
            "validate_message failed: {:?}",
            validation.err()
        );

        // While the guard is alive, the limiter reports at least one client.
        assert!(handler.get_security_stats().total_clients > 0);
    }

    #[tokio::test]
    async fn test_rate_limiting_enforcement() {
        let strict = RateLimitConfig {
            max_requests_per_window: 1,
            max_connections_per_ip: 1,
            max_messages_per_second: 1,
            burst_allowance: 0,
            ..Default::default()
        };
        let handler = SecureWebSocketHandler::new(strict);
        let lan_client = client(192, 168, 1, 1);

        // Exactly one upgrade request fits in the window.
        assert!(handler.check_upgrade_request(lan_client).is_ok());
        assert!(handler.check_upgrade_request(lan_client).is_err());

        // Exactly one concurrent connection: keep the first guard open while retrying.
        let _first_connection = handler
            .create_connection_guard(lan_client)
            .expect("first connection should be granted");
        assert!(handler.create_connection_guard(lan_client).is_err());
    }

    #[tokio::test]
    async fn test_different_security_levels() {
        let presets = [
            SecureWebSocketHandler::with_default_security(),
            SecureWebSocketHandler::with_high_traffic_security(),
            SecureWebSocketHandler::with_low_resource_security(),
        ];
        let private_client = client(10, 0, 0, 1);

        for handler in &presets {
            assert!(handler.check_upgrade_request(private_client).is_ok());
        }

        // Every preset must admit one connection; the guards stay alive until the test ends.
        let _guards: Vec<_> = presets
            .iter()
            .map(|handler| {
                handler
                    .create_connection_guard(private_client)
                    .expect("each preset should admit a connection")
            })
            .collect();
    }
}