1use http::{header::HeaderValue, Method};
26use std::time::Duration;
27use tower_http::cors::{AllowOrigin, CorsLayer};
28
29pub const ALLOWED_HEADERS: [http::header::HeaderName; 2] =
31 [http::header::CONTENT_TYPE, http::header::AUTHORIZATION];
32
33pub const ALLOWED_METHODS: [Method; 3] = [Method::GET, Method::POST, Method::OPTIONS];
35
36pub const DEFAULT_MAX_AGE_SECS: u64 = 3600;
38
39pub fn cors_layer() -> CorsLayer {
62 CorsLayer::new()
63 .allow_origin(AllowOrigin::predicate(|origin, _| {
64 is_localhost_origin(origin)
65 }))
66 .allow_methods(ALLOWED_METHODS)
67 .allow_headers(ALLOWED_HEADERS)
68 .max_age(Duration::from_secs(DEFAULT_MAX_AGE_SECS))
69}
70
71pub fn cors_layer_with_config(config: CorsConfig) -> CorsLayer {
89 let mut layer = CorsLayer::new()
90 .allow_origin(AllowOrigin::predicate(move |origin, _| {
91 if config.allow_all_localhost {
92 is_localhost_origin(origin)
93 } else {
94 false
96 }
97 }))
98 .allow_methods(config.allowed_methods.clone())
99 .allow_headers(config.allowed_headers.clone())
100 .max_age(Duration::from_secs(config.max_age_secs));
101
102 if config.allow_credentials {
103 layer = layer.allow_credentials(true);
104 }
105
106 if config.expose_headers {
107 layer = layer.expose_headers([http::header::CONTENT_LENGTH, http::header::CONTENT_TYPE]);
108 }
109
110 layer
111}
112
113pub fn cors_layer_permissive() -> CorsLayer {
130 CorsLayer::new()
131 .allow_origin(tower_http::cors::Any)
132 .allow_methods(tower_http::cors::Any)
133 .allow_headers(tower_http::cors::Any)
134 .max_age(Duration::from_secs(DEFAULT_MAX_AGE_SECS))
135}
136
137#[derive(Debug, Clone)]
139pub struct CorsConfig {
140 pub allow_all_localhost: bool,
142 pub allow_credentials: bool,
144 pub expose_headers: bool,
146 pub max_age_secs: u64,
148 pub allowed_methods: Vec<Method>,
150 pub allowed_headers: Vec<http::header::HeaderName>,
152}
153
154impl Default for CorsConfig {
155 fn default() -> Self {
156 Self {
157 allow_all_localhost: true,
158 allow_credentials: false,
159 expose_headers: false,
160 max_age_secs: DEFAULT_MAX_AGE_SECS,
161 allowed_methods: ALLOWED_METHODS.to_vec(),
162 allowed_headers: ALLOWED_HEADERS.to_vec(),
163 }
164 }
165}
166
167impl CorsConfig {
168 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn with_max_age(mut self, secs: u64) -> Self {
175 self.max_age_secs = secs;
176 self
177 }
178
179 pub fn with_allow_credentials(mut self, allow: bool) -> Self {
181 self.allow_credentials = allow;
182 self
183 }
184
185 pub fn with_expose_headers(mut self, expose: bool) -> Self {
187 self.expose_headers = expose;
188 self
189 }
190
191 pub fn with_methods(mut self, methods: Vec<Method>) -> Self {
193 self.allowed_methods = methods;
194 self
195 }
196
197 pub fn with_headers(mut self, headers: Vec<http::header::HeaderName>) -> Self {
199 self.allowed_headers = headers;
200 self
201 }
202
203 pub fn with_strict_origins(mut self) -> Self {
205 self.allow_all_localhost = false;
206 self
207 }
208}
209
210pub fn is_localhost_origin(origin: &HeaderValue) -> bool {
246 let origin_str = match origin.to_str() {
247 Ok(s) => s,
248 Err(_) => return false, };
250
251 let origin_lower = origin_str.to_lowercase();
254
255 if origin_lower.starts_with("http://localhost") || origin_lower.starts_with("https://localhost")
258 {
259 return validate_localhost_format(&origin_lower, "localhost");
260 }
261
262 if origin_lower.starts_with("http://127.0.0.1") || origin_lower.starts_with("https://127.0.0.1")
265 {
266 return validate_localhost_format(&origin_lower, "127.0.0.1");
267 }
268
269 if origin_lower.starts_with("http://[::1]") || origin_lower.starts_with("https://[::1]") {
271 return validate_ipv6_localhost_format(&origin_lower);
272 }
273
274 false
275}
276
/// Checks that `origin` is exactly `http://<host>` or `https://<host>`, followed
/// by nothing, by a `/`-prefixed path, or by `:<port>`. The port must be ASCII
/// digits in the range 1–65535 and may itself be followed by a `/`-path.
///
/// Returns `false` (rather than panicking on an out-of-range slice) when
/// `origin` does not begin with an HTTP(S) scheme immediately followed by `host`.
fn validate_localhost_format(origin: &str, host: &str) -> bool {
    let after_scheme = match origin
        .strip_prefix("https://")
        .or_else(|| origin.strip_prefix("http://"))
    {
        Some(rest) => rest,
        None => return false,
    };
    let remaining = match after_scheme.strip_prefix(host) {
        Some(rest) => rest,
        None => return false,
    };

    if remaining.is_empty() || remaining.starts_with('/') {
        return true;
    }

    match remaining.strip_prefix(':') {
        Some(port_and_path) => {
            let port = match port_and_path.find('/') {
                Some(end) => &port_and_path[..end],
                None => port_and_path,
            };
            // `u16::from_str` accepts a leading `+`, which is not a valid URL
            // port, so require plain ASCII digits before parsing.
            port.bytes().all(|b| b.is_ascii_digit())
                && matches!(port.parse::<u16>(), Ok(n) if n > 0)
        }
        // Anything else directly after the host (e.g. `.evil.com`) is a different host.
        None => false,
    }
}
318
/// Checks that `origin` is exactly `http://[::1]` or `https://[::1]`, followed
/// by nothing, by a `/`-prefixed path, or by `:<port>`. The port must be ASCII
/// digits in the range 1–65535 and may itself be followed by a `/`-path.
///
/// The host must be literally `[::1]`. Other bracketed IPv6 hosts are rejected.
fn validate_ipv6_localhost_format(origin: &str) -> bool {
    let after_scheme = match origin
        .strip_prefix("https://")
        .or_else(|| origin.strip_prefix("http://"))
    {
        Some(rest) => rest,
        None => return false,
    };
    let remaining = match after_scheme.strip_prefix("[::1]") {
        Some(rest) => rest,
        None => return false,
    };

    if remaining.is_empty() || remaining.starts_with('/') {
        return true;
    }

    match remaining.strip_prefix(':') {
        Some(port_and_path) => {
            let port = match port_and_path.find('/') {
                Some(end) => &port_and_path[..end],
                None => port_and_path,
            };
            // `u16::from_str` accepts a leading `+`, which is not a valid URL
            // port, so require plain ASCII digits before parsing.
            port.bytes().all(|b| b.is_ascii_digit())
                && matches!(port.parse::<u16>(), Ok(n) if n > 0)
        }
        None => false,
    }
}
348
/// Outcome of [`validate_origin`]: whether the origin is allowed, the origin
/// as supplied, and a human-readable explanation.
#[derive(Debug, Clone)]
pub struct CorsValidationResult {
    /// `true` when the origin would be accepted.
    pub allowed: bool,
    /// The origin string that was checked, unmodified.
    pub origin: String,
    /// Why the origin was accepted or rejected.
    pub reason: String,
}

impl CorsValidationResult {
    /// Bundles the three fields into a result value.
    pub fn new(allowed: bool, origin: String, reason: String) -> Self {
        Self { allowed, origin, reason }
    }
}
370
371pub fn validate_origin(origin: &str) -> CorsValidationResult {
385 let header_value = match HeaderValue::from_str(origin) {
386 Ok(v) => v,
387 Err(_) => {
388 return CorsValidationResult::new(
389 false,
390 origin.to_string(),
391 "Invalid header value format".to_string(),
392 );
393 }
394 };
395
396 let allowed = is_localhost_origin(&header_value);
397 let reason = if allowed {
398 "Localhost origin allowed".to_string()
399 } else {
400 determine_rejection_reason(origin)
401 };
402
403 CorsValidationResult::new(allowed, origin.to_string(), reason)
404}
405
406fn determine_rejection_reason(origin: &str) -> String {
408 let origin_lower = origin.to_lowercase();
409
410 if !origin_lower.starts_with("http://") && !origin_lower.starts_with("https://") {
411 return "Invalid scheme: must be http:// or https://".to_string();
412 }
413
414 if origin_lower.contains("localhost") && !is_valid_localhost_pattern(&origin_lower) {
415 return "Invalid localhost format: possible subdomain attack".to_string();
416 }
417
418 if origin_lower.contains("127.0.0.1") && !is_valid_loopback_pattern(&origin_lower) {
419 return "Invalid 127.0.0.1 format".to_string();
420 }
421
422 if is_private_ip_origin(&origin_lower) {
424 return "Private IP origins other than 127.0.0.1 are not allowed".to_string();
425 }
426
427 "External origin not allowed: only localhost origins permitted".to_string()
428}
429
/// Reports whether `origin` has `localhost` as its complete host.
///
/// That means `http(s)://localhost` either is the whole string or is followed
/// by `:` (port) or `/` (path). Comparison is case-sensitive; the caller passes
/// an already-lowercased origin.
///
/// The bare scheme+host forms must match exactly. A plain `starts_with` would
/// also accept `http://localhost.evil.com` and `http://localhostevil.com`,
/// hiding the subdomain-attack diagnosis.
fn is_valid_localhost_pattern(origin: &str) -> bool {
    const EXACT: [&str; 2] = ["http://localhost", "https://localhost"];
    const PREFIXES: [&str; 4] = [
        "http://localhost:",
        "https://localhost:",
        "http://localhost/",
        "https://localhost/",
    ];

    EXACT.iter().any(|&exact| origin == exact)
        || PREFIXES.iter().any(|&prefix| origin.starts_with(prefix))
}
449
/// Reports whether `origin` has `127.0.0.1` as its complete host.
///
/// That means `http(s)://127.0.0.1` either is the whole string or is followed
/// by `:` (port) or `/` (path).
///
/// The bare scheme+host forms must match exactly. A plain `starts_with` would
/// also accept `http://127.0.0.10` and `http://127.0.0.1.evil.com`.
fn is_valid_loopback_pattern(origin: &str) -> bool {
    const EXACT: [&str; 2] = ["http://127.0.0.1", "https://127.0.0.1"];
    const PREFIXES: [&str; 4] = [
        "http://127.0.0.1:",
        "https://127.0.0.1:",
        "http://127.0.0.1/",
        "https://127.0.0.1/",
    ];

    EXACT.iter().any(|&exact| origin == exact)
        || PREFIXES.iter().any(|&prefix| origin.starts_with(prefix))
}
469
/// Reports whether the host of `origin` is a private IPv4 address.
///
/// The private ranges are `10.0.0.0/8`, `172.16.0.0/12` and `192.168.0.0/16`,
/// as defined by [`std::net::Ipv4Addr::is_private`].
///
/// The host is the text after `://` (or the whole string when no scheme is
/// present) up to the first `:`, `/`, `?` or `#`. It must parse as a
/// dotted-quad IPv4 address. Substring matching would misclassify hostnames
/// and public IPs such as `example10.com` or `110.1.2.3`.
fn is_private_ip_origin(origin: &str) -> bool {
    let after_scheme = match origin.find("://") {
        Some(idx) => &origin[idx + 3..],
        None => origin,
    };
    let host_end = after_scheme
        .find(|c: char| matches!(c, ':' | '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let host = &after_scheme[..host_end];

    matches!(host.parse::<std::net::Ipv4Addr>(), Ok(ip) if ip.is_private())
}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a static origin string in a `HeaderValue` and runs the localhost check.
    fn allows(origin: &'static str) -> bool {
        is_localhost_origin(&HeaderValue::from_static(origin))
    }

    // ----- localhost -----

    #[test]
    fn test_localhost_origin_http() {
        assert!(allows("http://localhost"), "http://localhost should be allowed");
    }

    #[test]
    fn test_localhost_origin_https() {
        assert!(allows("https://localhost"), "https://localhost should be allowed");
    }

    #[test]
    fn test_localhost_origin_with_port() {
        assert!(allows("http://localhost:3000"), "http://localhost:3000 should be allowed");
    }

    #[test]
    fn test_localhost_origin_with_high_port() {
        assert!(allows("http://localhost:65535"), "http://localhost:65535 should be allowed");
    }

    #[test]
    fn test_localhost_origin_with_path() {
        assert!(allows("http://localhost/api"), "http://localhost/api should be allowed");
    }

    #[test]
    fn test_localhost_origin_with_port_and_path() {
        assert!(
            allows("http://localhost:8080/api/v1"),
            "http://localhost:8080/api/v1 should be allowed"
        );
    }

    // ----- 127.0.0.1 -----

    #[test]
    fn test_loopback_origin_http() {
        assert!(allows("http://127.0.0.1"), "http://127.0.0.1 should be allowed");
    }

    #[test]
    fn test_loopback_origin_https() {
        assert!(allows("https://127.0.0.1"), "https://127.0.0.1 should be allowed");
    }

    #[test]
    fn test_loopback_origin_with_port() {
        assert!(allows("http://127.0.0.1:8000"), "http://127.0.0.1:8000 should be allowed");
    }

    #[test]
    fn test_loopback_origin_with_path() {
        assert!(allows("http://127.0.0.1/mcp"), "http://127.0.0.1/mcp should be allowed");
    }

    // ----- [::1] -----

    #[test]
    fn test_ipv6_localhost_origin() {
        assert!(allows("http://[::1]"), "http://[::1] should be allowed");
    }

    #[test]
    fn test_ipv6_localhost_origin_with_port() {
        assert!(allows("http://[::1]:3000"), "http://[::1]:3000 should be allowed");
    }

    #[test]
    fn test_ipv6_localhost_origin_https() {
        assert!(allows("https://[::1]:8080"), "https://[::1]:8080 should be allowed");
    }

    // ----- external hosts -----

    #[test]
    fn test_external_origin_blocked() {
        assert!(!allows("http://example.com"), "http://example.com should be blocked");
    }

    #[test]
    fn test_external_origin_with_port_blocked() {
        assert!(!allows("http://evil.com:3000"), "http://evil.com:3000 should be blocked");
    }

    #[test]
    fn test_external_https_blocked() {
        assert!(!allows("https://malicious.org"), "https://malicious.org should be blocked");
    }

    // ----- look-alike hosts -----

    #[test]
    fn test_localhost_subdomain_attack_blocked() {
        assert!(
            !allows("http://localhost.evil.com"),
            "http://localhost.evil.com should be blocked (subdomain attack)"
        );
    }

    #[test]
    fn test_localhostevil_blocked() {
        assert!(!allows("http://localhostevil.com"), "http://localhostevil.com should be blocked");
    }

    #[test]
    fn test_subdomain_localhost_blocked() {
        assert!(!allows("http://sub.localhost.com"), "http://sub.localhost.com should be blocked");
    }

    #[test]
    fn test_fake_localhost_blocked() {
        assert!(!allows("http://my-localhost.com"), "http://my-localhost.com should be blocked");
    }

    // ----- private networks -----

    #[test]
    fn test_private_ip_192_blocked() {
        assert!(!allows("http://192.168.1.1"), "http://192.168.1.1 should be blocked");
    }

    #[test]
    fn test_private_ip_10_blocked() {
        assert!(!allows("http://10.0.0.1:8080"), "http://10.0.0.1:8080 should be blocked");
    }

    #[test]
    fn test_private_ip_172_blocked() {
        assert!(!allows("http://172.16.0.1"), "http://172.16.0.1 should be blocked");
    }

    // ----- schemes and ports -----

    #[test]
    fn test_no_scheme_blocked() {
        assert!(!allows("localhost:3000"), "localhost:3000 (no scheme) should be blocked");
    }

    #[test]
    fn test_ftp_scheme_blocked() {
        assert!(!allows("ftp://localhost"), "ftp://localhost should be blocked");
    }

    #[test]
    fn test_file_scheme_blocked() {
        assert!(!allows("file://localhost"), "file://localhost should be blocked");
    }

    #[test]
    fn test_invalid_port_blocked() {
        assert!(
            !allows("http://localhost:notaport"),
            "http://localhost:notaport should be blocked"
        );
    }

    #[test]
    fn test_port_zero_blocked() {
        assert!(
            !allows("http://localhost:0"),
            "http://localhost:0 should be blocked (invalid port)"
        );
    }

    // ----- CorsConfig -----

    #[test]
    fn test_cors_config_default() {
        let defaults = CorsConfig::default();
        assert!(defaults.allow_all_localhost);
        assert!(!defaults.allow_credentials);
        assert!(!defaults.expose_headers);
        assert_eq!(defaults.max_age_secs, DEFAULT_MAX_AGE_SECS);
    }

    #[test]
    fn test_cors_config_builder() {
        let built = CorsConfig::new()
            .with_max_age(7200)
            .with_allow_credentials(true)
            .with_expose_headers(true);

        assert_eq!(built.max_age_secs, 7200);
        assert!(built.allow_credentials);
        assert!(built.expose_headers);
    }

    #[test]
    fn test_cors_config_strict_origins() {
        assert!(!CorsConfig::new().with_strict_origins().allow_all_localhost);
    }

    // ----- validate_origin -----

    #[test]
    fn test_validate_origin_allowed() {
        let outcome = validate_origin("http://localhost:3000");
        assert!(outcome.allowed);
        assert_eq!(outcome.origin, "http://localhost:3000");
        assert!(outcome.reason.contains("allowed"));
    }

    #[test]
    fn test_validate_origin_blocked_external() {
        let outcome = validate_origin("http://example.com");
        assert!(!outcome.allowed);
        assert!(outcome.reason.contains("External") || outcome.reason.contains("not allowed"));
    }

    #[test]
    fn test_validate_origin_blocked_private_ip() {
        let outcome = validate_origin("http://192.168.1.100");
        assert!(!outcome.allowed);
        assert!(outcome.reason.contains("Private IP") || outcome.reason.contains("not allowed"));
    }

    #[test]
    fn test_validate_origin_blocked_subdomain_attack() {
        assert!(!validate_origin("http://localhost.evil.com").allowed);
    }

    // ----- layer construction -----

    #[test]
    fn test_cors_layer_creation() {
        let _rendered = format!("{:?}", cors_layer());
    }

    #[test]
    fn test_cors_layer_with_config_creation() {
        let settings = CorsConfig::new()
            .with_max_age(1800)
            .with_allow_credentials(true);
        let _rendered = format!("{:?}", cors_layer_with_config(settings));
    }

    #[test]
    fn test_cors_layer_permissive_creation() {
        let _rendered = format!("{:?}", cors_layer_permissive());
    }

    // ----- edge cases -----

    #[test]
    fn test_empty_origin_blocked() {
        assert!(!allows(""), "Empty origin should be blocked");
    }

    #[test]
    fn test_case_insensitive_localhost() {
        assert!(
            allows("HTTP://LOCALHOST:3000"),
            "HTTP://LOCALHOST:3000 should be allowed (case insensitive)"
        );
    }

    #[test]
    fn test_case_insensitive_loopback() {
        assert!(
            allows("HTTPS://127.0.0.1:8080"),
            "HTTPS://127.0.0.1:8080 should be allowed (case insensitive)"
        );
    }

    #[test]
    fn test_localhost_with_trailing_slash() {
        assert!(allows("http://localhost/"), "http://localhost/ should be allowed");
    }

    #[test]
    fn test_port_boundary_1() {
        assert!(allows("http://localhost:1"), "http://localhost:1 should be allowed");
    }

    #[test]
    fn test_common_dev_ports() {
        for port in ["3000", "5000", "8000", "8080", "9000", "4200", "5173"].iter() {
            let origin = HeaderValue::from_str(&format!("http://localhost:{}", port)).unwrap();
            assert!(
                is_localhost_origin(&origin),
                "http://localhost:{} should be allowed",
                port
            );
        }
    }
}