1use async_trait::async_trait;
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13use tracing::{debug, warn};
14
/// Per-request CSP nonce made available to handlers.
///
/// When `CspConfig::include_nonce` is enabled, [`CspMiddleware`] inserts one of
/// these into the request extensions before calling the handler, so handlers
/// and templates can emit matching `nonce="..."` attributes.
///
/// The wrapped string is the bare nonce value, without the `'nonce-'` prefix
/// or the surrounding quotes used in the header.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CspNonce(pub String);
18
/// Returns `true` when `nonce` is safe to embed in a `'nonce-...'` source expression.
///
/// Follows the CSP `base64-value` grammar: one or more characters from the
/// standard or URL-safe base64 alphabets (`A-Z a-z 0-9 + / - _`) followed by at
/// most two `=` padding characters. Everything else (whitespace, quotes, `;`,
/// CR/LF, padding in the middle of the value) is rejected, so a malformed nonce
/// can never break out of its directive or of the header itself.
fn is_valid_nonce(nonce: &str) -> bool {
    // Padding is only legal as a suffix, so peel it off and validate the rest.
    let body = nonce.trim_end_matches('=');
    let padding = nonce.len() - body.len();

    padding <= 2
        && !body.is_empty()
        && body
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'-' | b'_'))
}
30
/// Configuration for [`CspMiddleware`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CspConfig {
    /// Policy directives: directive name (e.g. `script-src`) mapped to its list of sources.
    pub directives: HashMap<String, Vec<String>>,
    /// When `true`, the policy is sent as `Content-Security-Policy-Report-Only`
    /// instead of `Content-Security-Policy`.
    pub report_only: bool,
    /// When `true`, a fresh nonce is generated for every non-exempt request,
    /// exposed to handlers as [`CspNonce`] in the request extensions, and appended
    /// to the `script-src` and `style-src` directives present in `directives`.
    pub include_nonce: bool,
    /// Paths for which the middleware adds no CSP header. A request path matches
    /// when it equals an entry or continues it after a `/`.
    pub exempt_paths: HashSet<String>,
}
51
52impl Default for CspConfig {
53 fn default() -> Self {
54 let mut directives = HashMap::new();
55 directives.insert("default-src".to_string(), vec!["'self'".to_string()]);
56
57 Self {
58 directives,
59 report_only: false,
60 include_nonce: false,
61 exempt_paths: HashSet::new(),
62 }
63 }
64}
65
66impl CspConfig {
67 pub fn strict() -> Self {
82 let mut directives = HashMap::new();
83 directives.insert("default-src".to_string(), vec!["'self'".to_string()]);
84 directives.insert("script-src".to_string(), vec!["'self'".to_string()]);
85 directives.insert("style-src".to_string(), vec!["'self'".to_string()]);
86 directives.insert(
87 "img-src".to_string(),
88 vec!["'self'".to_string(), "data:".to_string()],
89 );
90 directives.insert("font-src".to_string(), vec!["'self'".to_string()]);
91 directives.insert("connect-src".to_string(), vec!["'self'".to_string()]);
92 directives.insert("frame-ancestors".to_string(), vec!["'none'".to_string()]);
93 directives.insert("base-uri".to_string(), vec!["'self'".to_string()]);
94 directives.insert("form-action".to_string(), vec!["'self'".to_string()]);
95
96 Self {
97 directives,
98 report_only: false,
99 include_nonce: false,
100 exempt_paths: HashSet::new(),
101 }
102 }
103
104 pub fn add_exempt_path(mut self, path: String) -> Self {
127 self.exempt_paths.insert(path);
128 self
129 }
130}
131
132pub struct CspMiddleware {
134 config: CspConfig,
135}
136
137impl CspMiddleware {
138 pub fn new() -> Self {
179 Self {
180 config: CspConfig::default(),
181 }
182 }
183 pub fn with_config(config: CspConfig) -> Self {
236 Self { config }
237 }
238 pub fn strict() -> Self {
282 Self {
283 config: CspConfig::strict(),
284 }
285 }
286
287 fn generate_nonce(&self) -> String {
289 use base64::Engine;
290 use rand::RngCore;
291
292 let mut bytes = [0u8; 16];
293 rand::rng().fill_bytes(&mut bytes);
294 base64::engine::general_purpose::STANDARD.encode(bytes)
295 }
296
297 fn build_csp_header(&self, nonce: Option<&str>) -> String {
302 let mut parts = Vec::new();
303
304 let validated_nonce = nonce.filter(|n| is_valid_nonce(n));
306
307 for (directive, values) in &self.config.directives {
308 let mut directive_values = values.clone();
309
310 if self.config.include_nonce
312 && (directive == "script-src" || directive == "style-src")
313 && let Some(n) = validated_nonce
314 {
315 directive_values.push(format!("'nonce-{}'", n));
316 }
317
318 parts.push(format!("{} {}", directive, directive_values.join(" ")));
319 }
320
321 parts.join("; ")
322 }
323
324 fn get_header_name(&self) -> &'static str {
326 if self.config.report_only {
327 "Content-Security-Policy-Report-Only"
328 } else {
329 "Content-Security-Policy"
330 }
331 }
332}
333
334impl Default for CspMiddleware {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340#[async_trait]
341impl Middleware for CspMiddleware {
342 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
343 let path = request.uri.path();
347 if self
348 .config
349 .exempt_paths
350 .iter()
351 .any(|exempt| path == exempt.as_str() || path.starts_with(&format!("{}/", exempt)))
352 {
353 debug!(
354 path = path,
355 "Path is CSP-exempt, skipping CSP header insertion"
356 );
357 return match handler.handle(request).await {
358 Ok(resp) => Ok(resp),
359 Err(e) => Ok(Response::from(e)),
360 };
361 }
362
363 let nonce = if self.config.include_nonce {
365 let generated_nonce = self.generate_nonce();
366 request.extensions.insert(CspNonce(generated_nonce.clone()));
368 Some(generated_nonce)
369 } else {
370 None
371 };
372
373 let mut response = match handler.handle(request).await {
377 Ok(resp) => resp,
378 Err(e) => Response::from(e),
379 };
380
381 let header_name = self.get_header_name();
383 if response.headers.contains_key(header_name) {
384 debug!(
385 header = header_name,
386 "CSP header already present in response, skipping middleware insertion"
387 );
388 } else {
389 let csp_value = self.build_csp_header(nonce.as_deref());
390 match csp_value.parse() {
391 Ok(value) => {
392 response.headers.insert(header_name, value);
393 }
394 Err(e) => {
395 warn!(
396 error = %e,
397 "Failed to parse CSP header value, skipping header insertion"
398 );
399 }
400 }
401 }
402
403 Ok(response)
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use rstest::rstest;

    const CSP: &str = "Content-Security-Policy";
    const CSP_REPORT_ONLY: &str = "Content-Security-Policy-Report-Only";

    /// Handler that always answers `200 OK` with a small fixed body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("content")))
        }
    }

    /// Handler that always fails, used to exercise the error-to-response path.
    struct ErrorHandler;

    #[async_trait]
    impl Handler for ErrorHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_http::Error::Http("handler error".to_string()))
        }
    }

    /// Builds an empty-bodied `GET` request for `uri`.
    fn get_request(uri: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Sends a `GET uri` through `middleware` and unwraps the response.
    async fn respond(
        middleware: &CspMiddleware,
        handler: Arc<dyn Handler>,
        uri: &str,
    ) -> Response {
        middleware
            .process(get_request(uri), handler)
            .await
            .unwrap()
    }

    /// Returns the value of header `name`, panicking when it is missing or not text.
    fn header_str<'a>(response: &'a Response, name: &str) -> &'a str {
        response
            .headers
            .get(name)
            .expect("header should be present")
            .to_str()
            .expect("header should be valid text")
    }

    /// Builds a config from `(directive, sources)` pairs with every option switched off.
    fn config_of(entries: Vec<(&str, Vec<&str>)>) -> CspConfig {
        CspConfig {
            directives: entries
                .into_iter()
                .map(|(name, sources)| {
                    (
                        name.to_string(),
                        sources.into_iter().map(str::to_string).collect(),
                    )
                })
                .collect(),
            report_only: false,
            include_nonce: false,
            exempt_paths: HashSet::new(),
        }
    }

    #[tokio::test]
    async fn test_default_csp_header() {
        let middleware = CspMiddleware::new();

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert_eq!(response.status, StatusCode::OK);
        assert!(header_str(&response, CSP).contains("default-src 'self'"));
    }

    #[tokio::test]
    async fn test_custom_csp_directives() {
        let config = config_of(vec![
            ("default-src", vec!["'self'"]),
            ("script-src", vec!["'self'", "https://cdn.example.com"]),
        ]);
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        let csp_header = header_str(&response, CSP);
        assert!(csp_header.contains("default-src 'self'"));
        assert!(csp_header.contains("script-src 'self' https://cdn.example.com"));
    }

    #[tokio::test]
    async fn test_report_only_mode() {
        let config = CspConfig {
            report_only: true,
            ..config_of(vec![("default-src", vec!["'self'"])])
        };
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert!(response.headers.contains_key(CSP_REPORT_ONLY));
        assert!(!response.headers.contains_key(CSP));
    }

    #[tokio::test]
    async fn test_nonce_generation() {
        let config = CspConfig {
            include_nonce: true,
            ..config_of(vec![("script-src", vec!["'self'"])])
        };
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert!(header_str(&response, CSP).contains("'nonce-"));
    }

    #[tokio::test]
    async fn test_strict_csp() {
        let middleware = CspMiddleware::strict();

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        let csp_header = header_str(&response, CSP);
        assert!(csp_header.contains("default-src 'self'"));
        assert!(csp_header.contains("script-src 'self'"));
        assert!(csp_header.contains("style-src 'self'"));
        assert!(csp_header.contains("frame-ancestors 'none'"));
        assert!(csp_header.contains("base-uri 'self'"));
    }

    #[tokio::test]
    async fn test_multiple_directive_values() {
        let config = config_of(vec![("img-src", vec!["'self'", "data:", "https:"])]);
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert!(header_str(&response, CSP).contains("img-src 'self' data: https:"));
    }

    #[tokio::test]
    async fn test_nonce_only_added_to_script_and_style() {
        let config = CspConfig {
            include_nonce: true,
            ..config_of(vec![
                ("script-src", vec!["'self'"]),
                ("style-src", vec!["'self'"]),
                ("img-src", vec!["'self'"]),
            ])
        };
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        // One nonce for script-src, one for style-src, none for img-src.
        let nonce_count = header_str(&response, CSP).matches("'nonce-").count();
        assert_eq!(nonce_count, 2);
    }

    #[tokio::test]
    async fn test_empty_directives() {
        let middleware = CspMiddleware::with_config(config_of(vec![]));

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert!(response.headers.contains_key(CSP));
    }

    #[tokio::test]
    async fn test_frame_ancestors_directive() {
        let config = config_of(vec![(
            "frame-ancestors",
            vec!["'self'", "https://trusted.com"],
        )]);
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/test").await;

        assert!(
            header_str(&response, CSP).contains("frame-ancestors 'self' https://trusted.com")
        );
    }

    #[tokio::test]
    async fn test_nonce_uniqueness_across_requests() {
        /// Pulls the first nonce value out of a CSP header, if there is one.
        fn extract_nonce(csp: &str) -> Option<String> {
            csp.split("'nonce-")
                .nth(1)
                .and_then(|s| s.split('\'').next())
                .map(|s| s.to_string())
        }

        let config = CspConfig {
            include_nonce: true,
            ..config_of(vec![("script-src", vec!["'self'"])])
        };
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);

        let response1 = respond(&middleware, handler.clone(), "/page1").await;
        let response2 = respond(&middleware, handler, "/page2").await;

        let nonce1 = extract_nonce(header_str(&response1, CSP));
        let nonce2 = extract_nonce(header_str(&response2, CSP));

        assert!(nonce1.is_some(), "First CSP should contain nonce");
        assert!(nonce2.is_some(), "Second CSP should contain nonce");

        assert_ne!(nonce1, nonce2, "Nonces should be unique across requests");
    }

    #[tokio::test]
    async fn test_response_body_preserved() {
        struct TestHandlerWithBody;

        #[async_trait]
        impl Handler for TestHandlerWithBody {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK).with_body(Bytes::from("custom response content")))
            }
        }

        let middleware = CspMiddleware::new();

        let response = respond(&middleware, Arc::new(TestHandlerWithBody), "/page").await;

        assert!(response.headers.contains_key(CSP));
        assert_eq!(response.body, Bytes::from("custom response content"));
    }

    #[rstest]
    fn test_nonce_is_valid_base64() {
        use base64::Engine;

        let nonce = CspMiddleware::new().generate_nonce();

        let decoded = base64::engine::general_purpose::STANDARD.decode(&nonce);
        assert!(
            decoded.is_ok(),
            "Nonce should be valid base64, got: {}",
            nonce
        );
    }

    #[rstest]
    fn test_nonce_length() {
        use base64::Engine;

        let nonce = CspMiddleware::new().generate_nonce();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&nonce)
            .unwrap();

        assert_eq!(
            decoded.len(),
            16,
            "Nonce should be exactly 16 bytes (128 bits)"
        );
    }

    #[rstest]
    fn test_is_valid_nonce_accepts_base64() {
        assert!(is_valid_nonce("YWJjZGVmZw=="));
        assert!(is_valid_nonce("abc123+/="));
        assert!(is_valid_nonce("ABCDEFGHIJKLMNOP"));
    }

    #[rstest]
    fn test_is_valid_nonce_rejects_invalid_chars() {
        assert!(!is_valid_nonce(""));
        assert!(!is_valid_nonce("abc\ndef"));
        assert!(!is_valid_nonce("abc;def"));
        assert!(!is_valid_nonce("abc def"));
        assert!(!is_valid_nonce("abc'def"));
        assert!(!is_valid_nonce("abc\rdef"));
    }

    #[rstest]
    fn test_build_csp_header_rejects_invalid_nonce() {
        let config = CspConfig {
            include_nonce: true,
            ..config_of(vec![("script-src", vec!["'self'"])])
        };
        let middleware = CspMiddleware::with_config(config);

        // A nonce carrying CRLF and `;` is a header/directive injection attempt.
        let csp = middleware.build_csp_header(Some("abc\r\ndef;injected"));

        assert!(
            !csp.contains("nonce-"),
            "Invalid nonce should not be embedded in header"
        );
        assert!(csp.contains("script-src 'self'"));
    }

    #[rstest]
    fn test_nonce_entropy() {
        let middleware = CspMiddleware::new();

        let nonces: HashSet<String> = (0..100).map(|_| middleware.generate_nonce()).collect();

        assert_eq!(
            nonces.len(),
            100,
            "All 100 nonces should be unique (statistical randomness)"
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_csp_header() {
        struct HandlerWithCsp;

        #[async_trait]
        impl Handler for HandlerWithCsp {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK).with_header(
                    "Content-Security-Policy",
                    "default-src 'self'; style-src 'self' 'unsafe-inline'",
                ))
            }
        }

        let middleware = CspMiddleware::strict();

        let response = respond(&middleware, Arc::new(HandlerWithCsp), "/admin/").await;

        let csp = header_str(&response, CSP);
        assert!(
            csp.contains("'unsafe-inline'"),
            "Handler-set CSP should be preserved, got: {}",
            csp
        );
    }

    #[tokio::test]
    async fn test_does_not_override_existing_csp_report_only_header() {
        struct HandlerWithReportOnlyCsp;

        #[async_trait]
        impl Handler for HandlerWithReportOnlyCsp {
            async fn handle(&self, _request: Request) -> Result<Response> {
                Ok(Response::new(StatusCode::OK)
                    .with_header("Content-Security-Policy-Report-Only", "default-src 'none'"))
            }
        }

        let config = CspConfig {
            report_only: true,
            ..config_of(vec![("default-src", vec!["'self'"])])
        };
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(HandlerWithReportOnlyCsp), "/test").await;

        let csp = header_str(&response, CSP_REPORT_ONLY);
        assert_eq!(
            csp, "default-src 'none'",
            "Handler-set report-only CSP should be preserved"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_skips_csp() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/admin/dashboard").await;

        assert!(
            !response.headers.contains_key(CSP),
            "CSP should not be set for exempt path"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_exact_match() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/admin").await;

        assert!(
            !response.headers.contains_key(CSP),
            "CSP should not be set for exact exempt path match"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_non_exempt_path_gets_csp() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/api/data").await;

        assert!(
            response.headers.contains_key(CSP),
            "CSP should be set for non-exempt path"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_exempt_path_boundary_prevents_false_match() {
        let config = CspConfig::strict().add_exempt_path("/admin".to_string());
        let middleware = CspMiddleware::with_config(config);

        let response = respond(&middleware, Arc::new(TestHandler), "/administrator/panel").await;

        assert!(
            response.headers.contains_key(CSP),
            "/administrator should NOT be exempt when only /admin is in exempt_paths"
        );
    }

    #[rstest]
    fn test_csp_config_add_exempt_path() {
        let config = CspConfig::default()
            .add_exempt_path("/admin".to_string())
            .add_exempt_path("/static/admin".to_string());

        assert!(config.exempt_paths.contains("/admin"));
        assert!(config.exempt_paths.contains("/static/admin"));
        assert_eq!(config.exempt_paths.len(), 2);
    }

    #[rstest]
    #[tokio::test]
    async fn test_csp_header_applied_on_handler_error() {
        let middleware =
            CspMiddleware::with_config(config_of(vec![("default-src", vec!["'none'"])]));

        let response = respond(&middleware, Arc::new(ErrorHandler), "/test").await;

        assert!(response.status.is_client_error() || response.status.is_server_error());
        assert!(
            response.headers.contains_key(CSP),
            "CSP header should be applied even when handler returns an error"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_csp_exempt_path_error_converted_to_response() {
        let config = CspConfig::strict().add_exempt_path("/exempt".to_string());
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(ErrorHandler);

        let result = middleware
            .process(get_request("/exempt/resource"), handler)
            .await;

        assert!(
            result.is_ok(),
            "Handler error should be converted to response for exempt path"
        );
        let response = result.unwrap();
        assert!(response.status.is_client_error() || response.status.is_server_error());
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_exempt_paths() {
        let config = CspConfig::strict()
            .add_exempt_path("/admin".to_string())
            .add_exempt_path("/static/admin".to_string());
        let middleware = CspMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);

        for uri in ["/admin/dashboard", "/static/admin/style.css"] {
            let response = respond(&middleware, handler.clone(), uri).await;
            assert!(
                !response.headers.contains_key(CSP),
                "Path {} should be exempt from CSP",
                uri
            );
        }
    }
}