1use crate::{
4 config::SecurityConfig,
5 error::{Error, Result},
6};
7
8pub mod compression_bomb;
9pub mod rate_limit;
10
11pub use compression_bomb::{
12 CompressionBombConfig, CompressionBombDetector, CompressionBombError, CompressionBombProtector,
13 CompressionStats,
14};
15pub use rate_limit::{
16 RateLimitConfig, RateLimitError, RateLimitGuard, RateLimitStats, WebSocketRateLimiter,
17};
18
19#[derive(Debug, Clone)]
21pub struct SecurityValidator {
22 config: SecurityConfig,
23}
24
25impl SecurityValidator {
26 pub fn new(config: SecurityConfig) -> Self {
28 Self { config }
29 }
30
31 pub fn validate_input_size(&self, size: usize) -> Result<()> {
33 if size > self.config.json.max_input_size {
34 return Err(Error::Other(format!(
35 "Input size {} exceeds maximum allowed {} bytes",
36 size, self.config.json.max_input_size
37 )));
38 }
39 Ok(())
40 }
41
42 pub fn validate_json_depth(&self, depth: usize) -> Result<()> {
44 if depth > self.config.json.max_depth {
45 return Err(Error::Other(format!(
46 "JSON nesting depth {} exceeds maximum allowed {}",
47 depth, self.config.json.max_depth
48 )));
49 }
50 Ok(())
51 }
52
53 pub fn validate_array_length(&self, length: usize) -> Result<()> {
55 if length > self.config.json.max_array_length {
56 return Err(Error::Other(format!(
57 "Array length {} exceeds maximum allowed {}",
58 length, self.config.json.max_array_length
59 )));
60 }
61 Ok(())
62 }
63
64 pub fn validate_object_keys(&self, key_count: usize) -> Result<()> {
66 if key_count > self.config.json.max_object_keys {
67 return Err(Error::Other(format!(
68 "Object key count {} exceeds maximum allowed {}",
69 key_count, self.config.json.max_object_keys
70 )));
71 }
72 Ok(())
73 }
74
75 pub fn validate_string_length(&self, length: usize) -> Result<()> {
77 if length > self.config.json.max_string_length {
78 return Err(Error::Other(format!(
79 "String length {} exceeds maximum allowed {}",
80 length, self.config.json.max_string_length
81 )));
82 }
83 Ok(())
84 }
85
86 pub fn validate_session_id(&self, session_id: &str) -> Result<()> {
88 let len = session_id.len();
89
90 if len < self.config.sessions.min_session_id_length {
91 return Err(Error::Other(format!(
92 "Session ID too short: {} characters (minimum {})",
93 len, self.config.sessions.min_session_id_length
94 )));
95 }
96
97 if len > self.config.sessions.max_session_id_length {
98 return Err(Error::Other(format!(
99 "Session ID too long: {} characters (maximum {})",
100 len, self.config.sessions.max_session_id_length
101 )));
102 }
103
104 if !session_id
106 .chars()
107 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
108 {
109 return Err(Error::Other(
110 "Session ID contains invalid characters (only alphanumeric, hyphens and underscores allowed)".to_string()
111 ));
112 }
113
114 Ok(())
115 }
116
117 pub fn validate_websocket_frame_size(&self, size: usize) -> Result<()> {
119 if size > self.config.network.max_websocket_frame_size {
120 return Err(Error::Other(format!(
121 "WebSocket frame size {} exceeds maximum allowed {}",
122 size, self.config.network.max_websocket_frame_size
123 )));
124 }
125 Ok(())
126 }
127
128 pub fn validate_buffer_size(&self, size: usize) -> Result<()> {
130 if size > self.config.buffers.max_buffer_size {
131 return Err(Error::Other(format!(
132 "Buffer size {} exceeds maximum allowed {}",
133 size, self.config.buffers.max_buffer_size
134 )));
135 }
136 Ok(())
137 }
138}
139
140impl Default for SecurityValidator {
141 fn default() -> Self {
142 Self::new(SecurityConfig::default())
143 }
144}
145
/// Tracks the current nesting depth while descending into a JSON value and
/// refuses to go deeper than a configured maximum.
///
/// Derives the common traits so callers can log it (`Debug`), snapshot it
/// (`Clone`), and compare states (`PartialEq`/`Eq`), matching
/// `SecurityValidator`, which is also `Debug + Clone`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DepthTracker {
    /// Number of levels currently entered.
    current_depth: usize,
    /// Maximum number of levels `enter` will allow.
    max_depth: usize,
}
151
152impl DepthTracker {
153 pub fn from_config(config: &SecurityConfig) -> Self {
155 Self {
156 current_depth: 0,
157 max_depth: config.json.max_depth,
158 }
159 }
160
161 pub fn with_max_depth(max_depth: usize) -> Self {
163 Self {
164 current_depth: 0,
165 max_depth,
166 }
167 }
168
169 pub fn enter(&mut self) -> Result<()> {
171 if self.current_depth >= self.max_depth {
172 return Err(Error::Other(format!(
173 "JSON nesting depth {} would exceed maximum allowed {}",
174 self.current_depth + 1,
175 self.max_depth
176 )));
177 }
178 self.current_depth += 1;
179 Ok(())
180 }
181
182 pub fn exit(&mut self) {
184 if self.current_depth > 0 {
185 self.current_depth -= 1;
186 }
187 }
188
189 pub fn current_depth(&self) -> usize {
191 self.current_depth
192 }
193}
194
195impl Default for DepthTracker {
196 fn default() -> Self {
197 Self::with_max_depth(64) }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_validator_default() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_input_size(1024).is_ok());
        assert!(validator.validate_json_depth(10).is_ok());
        assert!(validator.validate_session_id("valid-session-123").is_ok());
    }

    #[test]
    fn test_security_validator_with_config() {
        let config = SecurityConfig::low_memory();
        let validator = SecurityValidator::new(config.clone());

        assert!(
            validator
                .validate_input_size(config.json.max_input_size)
                .is_ok()
        );
        assert!(
            validator
                .validate_input_size(config.json.max_input_size + 1)
                .is_err()
        );

        assert!(validator.validate_json_depth(config.json.max_depth).is_ok());
        assert!(
            validator
                .validate_json_depth(config.json.max_depth + 1)
                .is_err()
        );
    }

    #[test]
    fn test_validate_session_id() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_session_id("session-123").is_ok());
        assert!(validator.validate_session_id("abcd1234-5678-90ef").is_ok());
        assert!(validator.validate_session_id("test_session_id").is_ok());

        // Too short, too long, then two disallowed characters.
        assert!(validator.validate_session_id("ab").is_err());
        assert!(validator.validate_session_id(&"a".repeat(200)).is_err());
        assert!(validator.validate_session_id("session@123").is_err());
        assert!(validator.validate_session_id("session 123").is_err());
    }

    #[test]
    fn test_depth_tracker() {
        let mut tracker = DepthTracker::with_max_depth(64);

        assert_eq!(tracker.current_depth(), 0);

        assert!(tracker.enter().is_ok());
        assert_eq!(tracker.current_depth(), 1);

        assert!(tracker.enter().is_ok());
        assert_eq!(tracker.current_depth(), 2);

        tracker.exit();
        assert_eq!(tracker.current_depth(), 1);

        tracker.exit();
        assert_eq!(tracker.current_depth(), 0);
    }

    #[test]
    fn test_depth_tracker_limit() {
        let mut tracker = DepthTracker::with_max_depth(2);

        assert!(tracker.enter().is_ok());
        assert!(tracker.enter().is_ok());
        // A third level would exceed the limit of 2.
        assert!(tracker.enter().is_err());
    }

    #[test]
    fn test_depth_tracker_from_config() {
        let config = SecurityConfig::low_memory();
        let mut tracker = DepthTracker::from_config(&config);

        for _ in 0..config.json.max_depth {
            assert!(tracker.enter().is_ok());
        }
        // One level past the configured maximum is rejected.
        assert!(tracker.enter().is_err());
    }

    #[test]
    fn test_depth_tracker_default() {
        let mut tracker = DepthTracker::default();

        assert_eq!(tracker.current_depth(), 0);
        for _ in 0..64 {
            assert!(tracker.enter().is_ok());
        }
        assert!(tracker.enter().is_err());
        assert_eq!(tracker.current_depth(), 64);
    }

    #[test]
    fn test_high_throughput_config() {
        let config = SecurityConfig::high_throughput();
        let _validator = SecurityValidator::new(config.clone());

        let default_config = SecurityConfig::default();
        assert!(config.json.max_input_size >= default_config.json.max_input_size);
        assert!(config.buffers.max_total_memory >= default_config.buffers.max_total_memory);
    }

    #[test]
    fn test_validate_array_length() {
        let config = SecurityConfig::low_memory();
        let max_len = config.json.max_array_length;
        let validator = SecurityValidator::new(config);

        assert!(validator.validate_array_length(100).is_ok());

        let result = validator.validate_array_length(max_len + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_object_keys() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_object_keys(10).is_ok());

        let config = SecurityConfig::low_memory();
        let max_keys = config.json.max_object_keys;
        let validator = SecurityValidator::new(config);
        let result = validator.validate_object_keys(max_keys + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_string_length() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_string_length(100).is_ok());

        let config = SecurityConfig::low_memory();
        let max_str_len = config.json.max_string_length;
        let validator = SecurityValidator::new(config);
        let result = validator.validate_string_length(max_str_len + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_websocket_frame_size() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_websocket_frame_size(1024).is_ok());

        let config = SecurityConfig::low_memory();
        let max_frame = config.network.max_websocket_frame_size;
        let validator = SecurityValidator::new(config);
        let result = validator.validate_websocket_frame_size(max_frame + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_buffer_size() {
        let validator = SecurityValidator::default();

        assert!(validator.validate_buffer_size(4096).is_ok());

        let config = SecurityConfig::low_memory();
        let max_buf = config.buffers.max_buffer_size;
        let validator = SecurityValidator::new(config);
        let result = validator.validate_buffer_size(max_buf + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_limits_are_inclusive() {
        // A value exactly at each limit must be accepted; only values above it fail.
        let config = SecurityConfig::low_memory();
        let validator = SecurityValidator::new(config.clone());

        assert!(
            validator
                .validate_array_length(config.json.max_array_length)
                .is_ok()
        );
        assert!(
            validator
                .validate_object_keys(config.json.max_object_keys)
                .is_ok()
        );
        assert!(
            validator
                .validate_string_length(config.json.max_string_length)
                .is_ok()
        );
        assert!(
            validator
                .validate_websocket_frame_size(config.network.max_websocket_frame_size)
                .is_ok()
        );
        assert!(
            validator
                .validate_buffer_size(config.buffers.max_buffer_size)
                .is_ok()
        );
    }

    #[test]
    fn test_validate_input_size_boundary() {
        let config = SecurityConfig::low_memory();
        let max_input = config.json.max_input_size;
        let validator = SecurityValidator::new(config);

        assert!(validator.validate_input_size(max_input).is_ok());

        let result = validator.validate_input_size(max_input + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_json_depth_boundary() {
        let config = SecurityConfig::low_memory();
        let max_depth = config.json.max_depth;
        let validator = SecurityValidator::new(config);

        assert!(validator.validate_json_depth(max_depth).is_ok());

        let result = validator.validate_json_depth(max_depth + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_session_id_length_boundaries() {
        let validator = SecurityValidator::default();

        let result = validator.validate_session_id("a");
        assert!(result.is_err());

        let long_id = "a".repeat(200);
        let result = validator.validate_session_id(&long_id);
        assert!(result.is_err());

        assert!(
            validator
                .validate_session_id("valid-session-id-123")
                .is_ok()
        );

        assert!(
            validator
                .validate_session_id("valid_session_id_123")
                .is_ok()
        );
    }

    #[test]
    fn test_session_id_exact_length_bounds() {
        let config = SecurityConfig::default();
        let min = config.sessions.min_session_id_length;
        let max = config.sessions.max_session_id_length;
        let validator = SecurityValidator::new(config);

        assert!(validator.validate_session_id(&"a".repeat(min)).is_ok());
        assert!(validator.validate_session_id(&"a".repeat(max)).is_ok());
        assert!(validator.validate_session_id(&"a".repeat(max + 1)).is_err());
        if min > 0 {
            assert!(validator.validate_session_id(&"a".repeat(min - 1)).is_err());
        }
    }

    #[test]
    fn test_session_id_invalid_characters() {
        let validator = SecurityValidator::default();

        let result = validator.validate_session_id("session@123");
        assert!(result.is_err());

        let result = validator.validate_session_id("session 123");
        assert!(result.is_err());

        let result = validator.validate_session_id("session.123");
        assert!(result.is_err());

        // Non-ASCII letters are rejected even though they are alphanumeric.
        let result = validator.validate_session_id("séssion-123");
        assert!(result.is_err());

        assert!(validator.validate_session_id("session123").is_ok());
    }

    #[test]
    fn test_depth_tracker_boundary_cases() {
        let mut tracker = DepthTracker::with_max_depth(1);

        assert!(tracker.enter().is_ok());
        assert_eq!(tracker.current_depth(), 1);

        // A rejected `enter` must not change the depth.
        assert!(tracker.enter().is_err());
        assert_eq!(tracker.current_depth(), 1);

        tracker.exit();
        assert_eq!(tracker.current_depth(), 0);

        assert!(tracker.enter().is_ok());
    }

    #[test]
    fn test_depth_tracker_exit_at_zero() {
        let mut tracker = DepthTracker::with_max_depth(64);

        assert_eq!(tracker.current_depth(), 0);

        // Exiting at depth zero is a no-op rather than an underflow.
        tracker.exit();
        assert_eq!(tracker.current_depth(), 0);

        assert!(tracker.enter().is_ok());
    }

    #[test]
    fn test_depth_tracker_multiple_cycles() {
        let mut tracker = DepthTracker::with_max_depth(3);

        assert!(tracker.enter().is_ok());
        tracker.exit();
        assert_eq!(tracker.current_depth(), 0);

        assert!(tracker.enter().is_ok());
        assert!(tracker.enter().is_ok());
        tracker.exit();
        assert_eq!(tracker.current_depth(), 1);
        tracker.exit();
        assert_eq!(tracker.current_depth(), 0);

        for _ in 0..3 {
            assert!(tracker.enter().is_ok());
        }
        assert_eq!(tracker.current_depth(), 3);
    }
}