1use std::fmt;
9
/// Errors produced by the security layer.
///
/// The variants cover rate limiting, query limits (depth, complexity, size),
/// CORS, CSRF, audit logging, configuration, TLS / mutual TLS, token-based
/// authentication and introspection. Each variant carries exactly the data
/// that its `Display` message interpolates.
#[derive(Debug, Clone)]
pub enum SecurityError {
    /// A rate limit was exceeded.
    RateLimitExceeded {
        /// Number of seconds after which the caller may retry.
        retry_after: u64,
        /// The limit that applies per window.
        limit: usize,
        /// Length of the rate-limit window, in seconds.
        window_secs: u64,
    },

    /// A query is nested deeper than allowed.
    QueryTooDeep {
        /// Depth of the offending query, in levels.
        depth: usize,
        /// Maximum depth allowed.
        max_depth: usize,
    },

    /// A query's complexity is above the allowed maximum.
    QueryTooComplex {
        /// Complexity value of the offending query.
        complexity: usize,
        /// Maximum complexity allowed.
        max_complexity: usize,
    },

    /// A query is larger than allowed.
    QueryTooLarge {
        /// Size of the offending query, in bytes.
        size: usize,
        /// Maximum size allowed (presumably also in bytes — confirm with the caller).
        max_size: usize,
    },

    /// A CORS origin is not allowed; the payload is the rejected origin.
    OriginNotAllowed(String),

    /// A CORS method is not allowed; the payload is the rejected method.
    MethodNotAllowed(String),

    /// A CORS header is not allowed; the payload is the rejected header.
    HeaderNotAllowed(String),

    /// A CSRF token is invalid; the payload is the reason.
    InvalidCSRFToken(String),

    /// A CSRF token does not match the session.
    CSRFSessionMismatch,

    /// Audit logging failed; the payload is the reason.
    AuditLogFailure(String),

    /// The security configuration is invalid; the payload is the reason.
    SecurityConfigError(String),

    /// TLS/HTTPS is required.
    TlsRequired {
        /// Free-form detail appended to the error message.
        detail: String,
    },

    /// The TLS version in use is older than the required one.
    TlsVersionTooOld {
        /// TLS version currently in use.
        current: crate::security::TlsVersion,
        /// TLS version that is required.
        required: crate::security::TlsVersion,
    },

    /// Mutual TLS is required.
    MtlsRequired {
        /// Free-form detail appended to the error message.
        detail: String,
    },

    /// The client certificate is invalid.
    InvalidClientCert {
        /// Free-form detail appended to the error message.
        detail: String,
    },

    /// Authentication is required.
    AuthRequired,

    /// The authentication token is invalid.
    InvalidToken,

    /// The token has expired.
    TokenExpired {
        /// UTC instant at which the token expired.
        expired_at: chrono::DateTime<chrono::Utc>,
    },

    /// The token lacks a required claim.
    TokenMissingClaim {
        /// Name of the missing claim.
        claim: String,
    },

    /// The token's algorithm is invalid.
    InvalidTokenAlgorithm {
        /// Name of the offending algorithm.
        algorithm: String,
    },

    /// Introspection is disabled.
    IntrospectionDisabled {
        /// Free-form detail appended to the error message.
        detail: String,
    },
}
176
/// Convenience alias for a `std::result::Result` whose error type is [`SecurityError`].
pub type Result<T> = std::result::Result<T, SecurityError>;
181
182impl fmt::Display for SecurityError {
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 match self {
185 Self::RateLimitExceeded {
186 retry_after,
187 limit,
188 window_secs,
189 } => {
190 write!(
191 f,
192 "Rate limit exceeded. Limit: {limit} per {window_secs} seconds. Retry after: {retry_after} seconds"
193 )
194 },
195 Self::QueryTooDeep { depth, max_depth } => {
196 write!(f, "Query too deep: {depth} levels (max: {max_depth})")
197 },
198 Self::QueryTooComplex {
199 complexity,
200 max_complexity,
201 } => {
202 write!(f, "Query too complex: {complexity} (max: {max_complexity})")
203 },
204 Self::QueryTooLarge { size, max_size } => {
205 write!(f, "Query too large: {size} bytes (max: {max_size})")
206 },
207 Self::OriginNotAllowed(origin) => {
208 write!(f, "CORS origin not allowed: {origin}")
209 },
210 Self::MethodNotAllowed(method) => {
211 write!(f, "CORS method not allowed: {method}")
212 },
213 Self::HeaderNotAllowed(header) => {
214 write!(f, "CORS header not allowed: {header}")
215 },
216 Self::InvalidCSRFToken(reason) => {
217 write!(f, "Invalid CSRF token: {reason}")
218 },
219 Self::CSRFSessionMismatch => {
220 write!(f, "CSRF token session mismatch")
221 },
222 Self::AuditLogFailure(reason) => {
223 write!(f, "Audit logging failed: {reason}")
224 },
225 Self::SecurityConfigError(reason) => {
226 write!(f, "Security configuration error: {reason}")
227 },
228 Self::TlsRequired { detail } => {
229 write!(f, "TLS/HTTPS required: {detail}")
230 },
231 Self::TlsVersionTooOld { current, required } => {
232 write!(f, "TLS version too old: {current} (required: {required})")
233 },
234 Self::MtlsRequired { detail } => {
235 write!(f, "Mutual TLS required: {detail}")
236 },
237 Self::InvalidClientCert { detail } => {
238 write!(f, "Invalid client certificate: {detail}")
239 },
240 Self::AuthRequired => {
241 write!(f, "Authentication required")
242 },
243 Self::InvalidToken => {
244 write!(f, "Invalid authentication token")
245 },
246 Self::TokenExpired { expired_at } => {
247 write!(f, "Token expired at {expired_at}")
248 },
249 Self::TokenMissingClaim { claim } => {
250 write!(f, "Token missing required claim: {claim}")
251 },
252 Self::InvalidTokenAlgorithm { algorithm } => {
253 write!(f, "Invalid token algorithm: {algorithm}")
254 },
255 Self::IntrospectionDisabled { detail } => {
256 write!(f, "Introspection disabled: {detail}")
257 },
258 }
259 }
260}
261
// Lets `SecurityError` be used wherever a `std::error::Error` is expected.
// No method is overridden, so the trait's default implementations apply.
impl std::error::Error for SecurityError {}
263
264impl PartialEq for SecurityError {
265 fn eq(&self, other: &Self) -> bool {
266 match (self, other) {
267 (
268 Self::RateLimitExceeded {
269 retry_after: r1,
270 limit: l1,
271 window_secs: w1,
272 },
273 Self::RateLimitExceeded {
274 retry_after: r2,
275 limit: l2,
276 window_secs: w2,
277 },
278 ) => r1 == r2 && l1 == l2 && w1 == w2,
279 (
280 Self::QueryTooDeep {
281 depth: d1,
282 max_depth: m1,
283 },
284 Self::QueryTooDeep {
285 depth: d2,
286 max_depth: m2,
287 },
288 ) => d1 == d2 && m1 == m2,
289 (
290 Self::QueryTooComplex {
291 complexity: c1,
292 max_complexity: m1,
293 },
294 Self::QueryTooComplex {
295 complexity: c2,
296 max_complexity: m2,
297 },
298 ) => c1 == c2 && m1 == m2,
299 (
300 Self::QueryTooLarge {
301 size: s1,
302 max_size: m1,
303 },
304 Self::QueryTooLarge {
305 size: s2,
306 max_size: m2,
307 },
308 ) => s1 == s2 && m1 == m2,
309 (Self::OriginNotAllowed(o1), Self::OriginNotAllowed(o2)) => o1 == o2,
310 (Self::MethodNotAllowed(m1), Self::MethodNotAllowed(m2)) => m1 == m2,
311 (Self::HeaderNotAllowed(h1), Self::HeaderNotAllowed(h2)) => h1 == h2,
312 (Self::InvalidCSRFToken(r1), Self::InvalidCSRFToken(r2)) => r1 == r2,
313 (Self::CSRFSessionMismatch, Self::CSRFSessionMismatch) => true,
314 (Self::AuditLogFailure(r1), Self::AuditLogFailure(r2)) => r1 == r2,
315 (Self::SecurityConfigError(r1), Self::SecurityConfigError(r2)) => r1 == r2,
316 (Self::TlsRequired { detail: d1 }, Self::TlsRequired { detail: d2 }) => d1 == d2,
317 (
318 Self::TlsVersionTooOld {
319 current: c1,
320 required: r1,
321 },
322 Self::TlsVersionTooOld {
323 current: c2,
324 required: r2,
325 },
326 ) => c1 == c2 && r1 == r2,
327 (Self::MtlsRequired { detail: d1 }, Self::MtlsRequired { detail: d2 }) => d1 == d2,
328 (Self::InvalidClientCert { detail: d1 }, Self::InvalidClientCert { detail: d2 }) => {
329 d1 == d2
330 },
331 (Self::AuthRequired, Self::AuthRequired) => true,
332 (Self::InvalidToken, Self::InvalidToken) => true,
333 (Self::TokenExpired { expired_at: e1 }, Self::TokenExpired { expired_at: e2 }) => {
334 e1 == e2
335 },
336 (Self::TokenMissingClaim { claim: c1 }, Self::TokenMissingClaim { claim: c2 }) => {
337 c1 == c2
338 },
339 (
340 Self::InvalidTokenAlgorithm { algorithm: a1 },
341 Self::InvalidTokenAlgorithm { algorithm: a2 },
342 ) => a1 == a2,
343 (
344 Self::IntrospectionDisabled { detail: d1 },
345 Self::IntrospectionDisabled { detail: d2 },
346 ) => d1 == d2,
347 _ => false,
348 }
349 }
350}
351
// Marker impl: declares that `==` on `SecurityError` is a full equivalence
// relation (reflexive, symmetric and transitive).
impl Eq for SecurityError {}
353
/// Unit tests for the `Display` and `PartialEq` behaviour of [`SecurityError`].
///
/// Display tests assert the exact rendered message rather than substrings:
/// a substring check can pass vacuously when one value is contained in
/// another (for example `"10000"` inside `"100000"`), or when a detail string
/// is contained in the fixed prefix of the message.
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Display: rate limiting and query limits
    // ------------------------------------------------------------------

    #[test]
    fn test_rate_limit_error_display() {
        // Every field gets a distinct value so that a dropped or swapped
        // placeholder changes the output.
        let err = SecurityError::RateLimitExceeded {
            retry_after: 30,
            limit: 100,
            window_secs: 60,
        };

        assert_eq!(
            err.to_string(),
            "Rate limit exceeded. Limit: 100 per 60 seconds. Retry after: 30 seconds"
        );
    }

    #[test]
    fn test_query_too_deep_display() {
        let err = SecurityError::QueryTooDeep {
            depth: 20,
            max_depth: 10,
        };

        assert_eq!(err.to_string(), "Query too deep: 20 levels (max: 10)");
    }

    #[test]
    fn test_query_too_complex_display() {
        let err = SecurityError::QueryTooComplex {
            complexity: 500,
            max_complexity: 100,
        };

        assert_eq!(err.to_string(), "Query too complex: 500 (max: 100)");
    }

    #[test]
    fn test_query_too_large_display() {
        let err = SecurityError::QueryTooLarge {
            size: 100_000,
            max_size: 10_000,
        };

        // Exact match: "10000" is a substring of "100000", so a `contains`
        // check could not tell whether `max_size` was rendered at all.
        assert_eq!(
            err.to_string(),
            "Query too large: 100000 bytes (max: 10000)"
        );
    }

    // ------------------------------------------------------------------
    // Display: CORS, CSRF, audit logging, configuration
    // ------------------------------------------------------------------

    #[test]
    fn test_cors_errors() {
        let origin_err = SecurityError::OriginNotAllowed("https://evil.com".to_string());
        assert_eq!(
            origin_err.to_string(),
            "CORS origin not allowed: https://evil.com"
        );

        let method_err = SecurityError::MethodNotAllowed("DELETE".to_string());
        assert_eq!(method_err.to_string(), "CORS method not allowed: DELETE");

        let header_err = SecurityError::HeaderNotAllowed("X-Custom".to_string());
        assert_eq!(header_err.to_string(), "CORS header not allowed: X-Custom");
    }

    #[test]
    fn test_csrf_errors() {
        let invalid = SecurityError::InvalidCSRFToken("expired".to_string());
        assert_eq!(invalid.to_string(), "Invalid CSRF token: expired");

        let mismatch = SecurityError::CSRFSessionMismatch;
        assert_eq!(mismatch.to_string(), "CSRF token session mismatch");
    }

    #[test]
    fn test_audit_error() {
        let err = SecurityError::AuditLogFailure("connection timeout".to_string());
        assert_eq!(err.to_string(), "Audit logging failed: connection timeout");
    }

    #[test]
    fn test_config_error() {
        let err = SecurityError::SecurityConfigError("missing config key".to_string());
        assert_eq!(
            err.to_string(),
            "Security configuration error: missing config key"
        );
    }

    // ------------------------------------------------------------------
    // Equality: rate limiting and query limits
    // ------------------------------------------------------------------

    #[test]
    fn test_error_equality() {
        let err1 = SecurityError::QueryTooDeep {
            depth: 20,
            max_depth: 10,
        };
        let err2 = SecurityError::QueryTooDeep {
            depth: 20,
            max_depth: 10,
        };
        assert_eq!(err1, err2);

        let err3 = SecurityError::QueryTooDeep {
            depth: 30,
            max_depth: 10,
        };
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_rate_limit_equality() {
        let err1 = SecurityError::RateLimitExceeded {
            retry_after: 60,
            limit: 100,
            window_secs: 60,
        };
        let err2 = SecurityError::RateLimitExceeded {
            retry_after: 60,
            limit: 100,
            window_secs: 60,
        };
        assert_eq!(err1, err2);

        let err3 = SecurityError::RateLimitExceeded {
            retry_after: 30,
            limit: 100,
            window_secs: 60,
        };
        assert_ne!(err1, err3);
    }

    // ------------------------------------------------------------------
    // Display: TLS, authentication, tokens, introspection
    // ------------------------------------------------------------------

    #[test]
    fn test_tls_required_error_display() {
        let err = SecurityError::TlsRequired {
            detail: "HTTPS required".to_string(),
        };

        // Exact match: the detail text is itself a substring of the fixed
        // "TLS/HTTPS required" prefix, so `contains` could not verify it.
        assert_eq!(err.to_string(), "TLS/HTTPS required: HTTPS required");
    }

    #[test]
    fn test_tls_version_too_old_error_display() {
        use crate::security::tls_enforcer::TlsVersion;

        let err = SecurityError::TlsVersionTooOld {
            current: TlsVersion::V1_2,
            required: TlsVersion::V1_3,
        };
        let rendered = err.to_string();

        // The exact rendering of `TlsVersion` is defined elsewhere, so the
        // expected message is built from its own `Display` output.
        assert_eq!(
            rendered,
            format!(
                "TLS version too old: {} (required: {})",
                TlsVersion::V1_2,
                TlsVersion::V1_3
            )
        );
        assert!(rendered.contains("1.2"));
        assert!(rendered.contains("1.3"));
    }

    #[test]
    fn test_mtls_required_error_display() {
        let err = SecurityError::MtlsRequired {
            detail: "Client certificate required".to_string(),
        };

        assert_eq!(
            err.to_string(),
            "Mutual TLS required: Client certificate required"
        );
    }

    #[test]
    fn test_invalid_client_cert_error_display() {
        let err = SecurityError::InvalidClientCert {
            detail: "Certificate validation failed".to_string(),
        };

        assert_eq!(
            err.to_string(),
            "Invalid client certificate: Certificate validation failed"
        );
    }

    #[test]
    fn test_auth_required_error_display() {
        let err = SecurityError::AuthRequired;
        assert_eq!(err.to_string(), "Authentication required");
    }

    #[test]
    fn test_invalid_token_error_display() {
        let err = SecurityError::InvalidToken;
        assert_eq!(err.to_string(), "Invalid authentication token");
    }

    #[test]
    fn test_token_expired_error_display() {
        use chrono::{Duration, Utc};

        let expired_at = Utc::now() - Duration::hours(1);
        // Render the expected text before the timestamp is moved into the error.
        let expected = format!("Token expired at {expired_at}");
        let err = SecurityError::TokenExpired { expired_at };

        assert_eq!(err.to_string(), expected);
    }

    #[test]
    fn test_token_missing_claim_error_display() {
        let err = SecurityError::TokenMissingClaim {
            claim: "sub".to_string(),
        };

        assert_eq!(err.to_string(), "Token missing required claim: sub");
    }

    #[test]
    fn test_invalid_token_algorithm_error_display() {
        let err = SecurityError::InvalidTokenAlgorithm {
            algorithm: "HS256".to_string(),
        };

        assert_eq!(err.to_string(), "Invalid token algorithm: HS256");
    }

    #[test]
    fn test_introspection_disabled_error_display() {
        let err = SecurityError::IntrospectionDisabled {
            detail: "Introspection not allowed in production".to_string(),
        };

        assert_eq!(
            err.to_string(),
            "Introspection disabled: Introspection not allowed in production"
        );
    }

    // ------------------------------------------------------------------
    // Equality: TLS, authentication, cross-variant
    // ------------------------------------------------------------------

    #[test]
    fn test_tls_required_equality() {
        let err1 = SecurityError::TlsRequired {
            detail: "test".to_string(),
        };
        let err2 = SecurityError::TlsRequired {
            detail: "test".to_string(),
        };
        assert_eq!(err1, err2);

        let err3 = SecurityError::TlsRequired {
            detail: "different".to_string(),
        };
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_tls_version_too_old_equality() {
        use crate::security::tls_enforcer::TlsVersion;

        let err1 = SecurityError::TlsVersionTooOld {
            current: TlsVersion::V1_2,
            required: TlsVersion::V1_3,
        };
        let err2 = SecurityError::TlsVersionTooOld {
            current: TlsVersion::V1_2,
            required: TlsVersion::V1_3,
        };
        assert_eq!(err1, err2);

        let err3 = SecurityError::TlsVersionTooOld {
            current: TlsVersion::V1_1,
            required: TlsVersion::V1_3,
        };
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_mtls_required_equality() {
        let err1 = SecurityError::MtlsRequired {
            detail: "test".to_string(),
        };
        let err2 = SecurityError::MtlsRequired {
            detail: "test".to_string(),
        };
        assert_eq!(err1, err2);

        let err3 = SecurityError::MtlsRequired {
            detail: "different".to_string(),
        };
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_invalid_token_equality() {
        assert_eq!(SecurityError::InvalidToken, SecurityError::InvalidToken);
    }

    #[test]
    fn test_auth_required_equality() {
        assert_eq!(SecurityError::AuthRequired, SecurityError::AuthRequired);
    }

    #[test]
    fn test_different_variants_not_equal() {
        // Payload-free variants of different kinds must never be equal.
        assert_ne!(SecurityError::AuthRequired, SecurityError::InvalidToken);
        assert_ne!(
            SecurityError::InvalidToken,
            SecurityError::CSRFSessionMismatch
        );

        // An identical payload does not make different variants equal.
        let tls = SecurityError::TlsRequired {
            detail: "test".to_string(),
        };
        let mtls = SecurityError::MtlsRequired {
            detail: "test".to_string(),
        };
        assert_ne!(tls, mtls);

        let origin = SecurityError::OriginNotAllowed("GET".to_string());
        let method = SecurityError::MethodNotAllowed("GET".to_string());
        assert_ne!(origin, method);
    }
}