reasonkit_web/
cors.rs

1//! CORS (Cross-Origin Resource Sharing) Configuration for ReasonKit Web
2//!
3//! This module provides a strict CORS policy for the HTTP server, allowing only
4//! localhost origins for security. This is essential for MCP HTTP transport
5//! and browser-based integrations.
6//!
7//! # Security Policy
8//!
9//! - **Allowed Origins**: Only `localhost`, `127.0.0.1`, and `[::1]` (over `http` or `https`) on any port
10//! - **Allowed Methods**: GET, POST, OPTIONS (preflight)
11//! - **Allowed Headers**: Content-Type, Authorization
12//! - **Max Age**: 3600 seconds (1 hour) for preflight caching
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use reasonkit_web::cors::cors_layer;
18//! use axum::Router;
19//!
20//! let app = Router::new()
21//!     .route("/api/mcp", post(mcp_handler))
22//!     .layer(cors_layer());
23//! ```
24
25use http::{header::HeaderValue, Method};
26use std::time::Duration;
27use tower_http::cors::{AllowOrigin, CorsLayer};
28
29/// Standard allowed headers for MCP HTTP transport
30pub const ALLOWED_HEADERS: [http::header::HeaderName; 2] =
31    [http::header::CONTENT_TYPE, http::header::AUTHORIZATION];
32
33/// Standard allowed methods for MCP HTTP transport
34pub const ALLOWED_METHODS: [Method; 3] = [Method::GET, Method::POST, Method::OPTIONS];
35
36/// Default max age for preflight cache (1 hour)
37pub const DEFAULT_MAX_AGE_SECS: u64 = 3600;
38
39/// Creates a strict CORS layer that only allows localhost origins.
40///
41/// This is the recommended configuration for development and local MCP servers.
42/// For production deployments with specific domain requirements, use
43/// `cors_layer_with_origins` instead.
44///
45/// # Security Properties
46///
47/// - Only allows requests from `http://localhost:*` and `http://127.0.0.1:*`
48/// - Blocks all external origins including other private IP ranges
49/// - Properly handles preflight OPTIONS requests
50/// - Does not expose credentials by default
51///
52/// # Example
53///
54/// ```rust,ignore
55/// use reasonkit_web::cors::cors_layer;
56/// use axum::Router;
57///
58/// let app = Router::new()
59///     .layer(cors_layer());
60/// ```
61pub fn cors_layer() -> CorsLayer {
62    CorsLayer::new()
63        .allow_origin(AllowOrigin::predicate(|origin, _| {
64            is_localhost_origin(origin)
65        }))
66        .allow_methods(ALLOWED_METHODS)
67        .allow_headers(ALLOWED_HEADERS)
68        .max_age(Duration::from_secs(DEFAULT_MAX_AGE_SECS))
69}
70
71/// Creates a CORS layer with custom configuration.
72///
73/// # Arguments
74///
75/// * `config` - CORS configuration options
76///
77/// # Example
78///
79/// ```rust,no_run
80/// use reasonkit_web::cors::{cors_layer_with_config, CorsConfig};
81///
82/// let config = CorsConfig::default()
83///     .with_max_age(7200)
84///     .with_allow_credentials(true);
85///
86/// let layer = cors_layer_with_config(config);
87/// ```
88pub fn cors_layer_with_config(config: CorsConfig) -> CorsLayer {
89    let mut layer = CorsLayer::new()
90        .allow_origin(AllowOrigin::predicate(move |origin, _| {
91            if config.allow_all_localhost {
92                is_localhost_origin(origin)
93            } else {
94                // Strict mode: no origins allowed by default
95                false
96            }
97        }))
98        .allow_methods(config.allowed_methods.clone())
99        .allow_headers(config.allowed_headers.clone())
100        .max_age(Duration::from_secs(config.max_age_secs));
101
102    if config.allow_credentials {
103        layer = layer.allow_credentials(true);
104    }
105
106    if config.expose_headers {
107        layer = layer.expose_headers([http::header::CONTENT_LENGTH, http::header::CONTENT_TYPE]);
108    }
109
110    layer
111}
112
113/// Creates a permissive CORS layer for development/testing.
114///
115/// # Warning
116///
117/// This configuration is NOT secure for production use. It allows all origins.
118/// Use only for local development and testing.
119///
120/// # Example
121///
122/// ```rust,no_run
123/// use reasonkit_web::cors::cors_layer_permissive;
124///
125/// // Only for development!
126/// #[cfg(debug_assertions)]
127/// let layer = cors_layer_permissive();
128/// ```
129pub fn cors_layer_permissive() -> CorsLayer {
130    CorsLayer::new()
131        .allow_origin(tower_http::cors::Any)
132        .allow_methods(tower_http::cors::Any)
133        .allow_headers(tower_http::cors::Any)
134        .max_age(Duration::from_secs(DEFAULT_MAX_AGE_SECS))
135}
136
137/// CORS configuration options.
138#[derive(Debug, Clone)]
139pub struct CorsConfig {
140    /// Whether to allow all localhost origins (default: true)
141    pub allow_all_localhost: bool,
142    /// Whether to allow credentials (cookies, auth headers)
143    pub allow_credentials: bool,
144    /// Whether to expose response headers to the client
145    pub expose_headers: bool,
146    /// Maximum age for preflight cache in seconds
147    pub max_age_secs: u64,
148    /// Allowed HTTP methods
149    pub allowed_methods: Vec<Method>,
150    /// Allowed request headers
151    pub allowed_headers: Vec<http::header::HeaderName>,
152}
153
154impl Default for CorsConfig {
155    fn default() -> Self {
156        Self {
157            allow_all_localhost: true,
158            allow_credentials: false,
159            expose_headers: false,
160            max_age_secs: DEFAULT_MAX_AGE_SECS,
161            allowed_methods: ALLOWED_METHODS.to_vec(),
162            allowed_headers: ALLOWED_HEADERS.to_vec(),
163        }
164    }
165}
166
167impl CorsConfig {
168    /// Create a new CORS configuration with default settings.
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    /// Set the maximum age for preflight cache.
174    pub fn with_max_age(mut self, secs: u64) -> Self {
175        self.max_age_secs = secs;
176        self
177    }
178
179    /// Enable or disable credentials support.
180    pub fn with_allow_credentials(mut self, allow: bool) -> Self {
181        self.allow_credentials = allow;
182        self
183    }
184
185    /// Enable or disable header exposure.
186    pub fn with_expose_headers(mut self, expose: bool) -> Self {
187        self.expose_headers = expose;
188        self
189    }
190
191    /// Set allowed HTTP methods.
192    pub fn with_methods(mut self, methods: Vec<Method>) -> Self {
193        self.allowed_methods = methods;
194        self
195    }
196
197    /// Set allowed request headers.
198    pub fn with_headers(mut self, headers: Vec<http::header::HeaderName>) -> Self {
199        self.allowed_headers = headers;
200        self
201    }
202
203    /// Disable localhost origin allowance (for custom origin handling).
204    pub fn with_strict_origins(mut self) -> Self {
205        self.allow_all_localhost = false;
206        self
207    }
208}
209
210/// Checks if the given origin is a localhost origin.
211///
212/// # Valid Origins
213///
214/// - `http://localhost` (any port)
215/// - `http://127.0.0.1` (any port)
216/// - `https://localhost` (any port, for secure contexts)
217/// - `https://127.0.0.1` (any port)
218///
219/// # Invalid Origins
220///
221/// - External domains (e.g., `http://example.com`)
222/// - Other private IPs (e.g., `http://192.168.1.1`)
223/// - IPv6 localhost (currently not supported)
224///
225/// # Arguments
226///
227/// * `origin` - The Origin header value to check
228///
229/// # Returns
230///
231/// `true` if the origin is a valid localhost origin, `false` otherwise.
232///
233/// # Example
234///
235/// ```rust
236/// use http::header::HeaderValue;
237/// use reasonkit_web::cors::is_localhost_origin;
238///
239/// let origin = HeaderValue::from_static("http://localhost:3000");
240/// assert!(is_localhost_origin(&origin));
241///
242/// let external = HeaderValue::from_static("http://example.com");
243/// assert!(!is_localhost_origin(&external));
244/// ```
245pub fn is_localhost_origin(origin: &HeaderValue) -> bool {
246    let origin_str = match origin.to_str() {
247        Ok(s) => s,
248        Err(_) => return false, // Invalid UTF-8, reject
249    };
250
251    // Parse the origin to extract host
252    // Origin format: scheme://host[:port]
253    let origin_lower = origin_str.to_lowercase();
254
255    // Check for localhost patterns
256    // http://localhost or http://localhost:PORT
257    if origin_lower.starts_with("http://localhost") || origin_lower.starts_with("https://localhost")
258    {
259        return validate_localhost_format(&origin_lower, "localhost");
260    }
261
262    // Check for 127.0.0.1 patterns
263    // http://127.0.0.1 or http://127.0.0.1:PORT
264    if origin_lower.starts_with("http://127.0.0.1") || origin_lower.starts_with("https://127.0.0.1")
265    {
266        return validate_localhost_format(&origin_lower, "127.0.0.1");
267    }
268
269    // Check for IPv6 localhost [::1]
270    if origin_lower.starts_with("http://[::1]") || origin_lower.starts_with("https://[::1]") {
271        return validate_ipv6_localhost_format(&origin_lower);
272    }
273
274    false
275}
276
/// Validates that `origin` is `http://` or `https://`, then exactly `host`,
/// then nothing, a `:port` suffix, or a `/path` suffix.
///
/// `origin` is expected to be lowercased already. The port must consist only
/// of ASCII digits and have a value in `1..=65535`. A signed port (`:+80`),
/// an empty port (`:`) or trailing junk (`:80@evil.com`) is rejected. Any
/// other text after the host (e.g. `localhostevil.com`, `localhost.evil.com`)
/// is rejected. Inputs without an http(s) scheme or the host return `false`
/// instead of panicking.
fn validate_localhost_format(origin: &str, host: &str) -> bool {
    let after_scheme = match origin
        .strip_prefix("https://")
        .or_else(|| origin.strip_prefix("http://"))
    {
        Some(rest) => rest,
        None => return false,
    };

    match after_scheme.strip_prefix(host) {
        Some(after_host) => is_valid_host_suffix(after_host),
        None => false,
    }
}

/// Validates an IPv6 loopback origin: `http(s)://[::1]`, optionally followed
/// by `:port` and/or `/path`.
fn validate_ipv6_localhost_format(origin: &str) -> bool {
    // The bracketed literal is the complete host text, so the generic
    // host validator applies unchanged.
    validate_localhost_format(origin, "[::1]")
}

/// Checks the text that follows the host: empty, `/path`, or `:port[/path]`.
fn is_valid_host_suffix(suffix: &str) -> bool {
    if suffix.is_empty() || suffix.starts_with('/') {
        return true;
    }

    match suffix.strip_prefix(':') {
        Some(port_and_path) => {
            let port = port_and_path
                .split_once('/')
                .map_or(port_and_path, |(port, _path)| port);
            is_valid_port(port)
        }
        None => false,
    }
}

/// A port is one or more ASCII digits whose value lies in `1..=65535`.
fn is_valid_port(port: &str) -> bool {
    // `u16::from_str` also accepts a leading `+`, so require digits explicitly.
    !port.is_empty()
        && port.bytes().all(|b| b.is_ascii_digit())
        && matches!(port.parse::<u16>(), Ok(n) if n > 0)
}
348
/// Outcome of an origin check, with enough detail to log the CORS decision.
#[derive(Debug, Clone)]
pub struct CorsValidationResult {
    /// `true` if the origin may make cross-origin requests.
    pub allowed: bool,
    /// The origin string that was evaluated.
    pub origin: String,
    /// Human-readable explanation of the decision.
    pub reason: String,
}

impl CorsValidationResult {
    /// Bundles a decision, the origin it concerns, and the reason for it.
    pub fn new(allowed: bool, origin: String, reason: String) -> Self {
        Self {
            origin,
            reason,
            allowed,
        }
    }
}
370
371/// Validates an origin and returns detailed information.
372///
373/// Useful for debugging and logging CORS decisions.
374///
375/// # Example
376///
377/// ```rust
378/// use reasonkit_web::cors::validate_origin;
379///
380/// let result = validate_origin("http://localhost:3000");
381/// assert!(result.allowed);
382/// println!("Reason: {}", result.reason);
383/// ```
384pub fn validate_origin(origin: &str) -> CorsValidationResult {
385    let header_value = match HeaderValue::from_str(origin) {
386        Ok(v) => v,
387        Err(_) => {
388            return CorsValidationResult::new(
389                false,
390                origin.to_string(),
391                "Invalid header value format".to_string(),
392            );
393        }
394    };
395
396    let allowed = is_localhost_origin(&header_value);
397    let reason = if allowed {
398        "Localhost origin allowed".to_string()
399    } else {
400        determine_rejection_reason(origin)
401    };
402
403    CorsValidationResult::new(allowed, origin.to_string(), reason)
404}
405
406/// Determines the specific reason why an origin was rejected.
407fn determine_rejection_reason(origin: &str) -> String {
408    let origin_lower = origin.to_lowercase();
409
410    if !origin_lower.starts_with("http://") && !origin_lower.starts_with("https://") {
411        return "Invalid scheme: must be http:// or https://".to_string();
412    }
413
414    if origin_lower.contains("localhost") && !is_valid_localhost_pattern(&origin_lower) {
415        return "Invalid localhost format: possible subdomain attack".to_string();
416    }
417
418    if origin_lower.contains("127.0.0.1") && !is_valid_loopback_pattern(&origin_lower) {
419        return "Invalid 127.0.0.1 format".to_string();
420    }
421
422    // Check for other private IPs that we don't allow
423    if is_private_ip_origin(&origin_lower) {
424        return "Private IP origins other than 127.0.0.1 are not allowed".to_string();
425    }
426
427    "External origin not allowed: only localhost origins permitted".to_string()
428}
429
/// Returns `true` if `origin` names `localhost` as its complete host:
/// `http(s)://localhost` followed by end of string, `:` or `/`.
///
/// Used to choose a rejection reason. `origin` is expected to be lowercased.
fn is_valid_localhost_pattern(origin: &str) -> bool {
    has_host_boundary(origin, "localhost")
}

/// Returns `true` if `origin` names `127.0.0.1` as its complete host:
/// `http(s)://127.0.0.1` followed by end of string, `:` or `/`.
///
/// Used to choose a rejection reason. `origin` is expected to be lowercased.
fn is_valid_loopback_pattern(origin: &str) -> bool {
    has_host_boundary(origin, "127.0.0.1")
}

/// Returns `true` if `origin` is `http://` or `https://`, then `host`, then
/// end of string, `:` (port) or `/` (path).
///
/// A bare `starts_with("http://localhost")` would also match
/// `http://localhost.evil.com`, which hides the look-alike-host diagnosis.
/// Requiring a boundary after the host prevents that.
fn has_host_boundary(origin: &str, host: &str) -> bool {
    ["http://", "https://"].iter().any(|scheme| {
        matches!(
            origin
                .strip_prefix(*scheme)
                .and_then(|rest| rest.strip_prefix(host)),
            Some(after) if after.is_empty() || after.starts_with(':') || after.starts_with('/')
        )
    })
}
469
/// Returns `true` if the host of `origin` is an IPv4 literal in a private
/// range: `10.0.0.0/8`, `172.16.0.0/12` or `192.168.0.0/16`.
///
/// Only the host component is inspected. Hostnames that merely contain digit
/// runs (e.g. `http://api10.example.com`), public IPs such as `110.0.0.1`,
/// and IPs that appear in a path are therefore not misreported. Loopback
/// (`127.0.0.0/8`) lies outside these ranges.
fn is_private_ip_origin(origin: &str) -> bool {
    // Drop the scheme if present, then cut the host at the port or path separator.
    let authority = origin
        .split_once("://")
        .map_or(origin, |(_scheme, rest)| rest);
    let host = authority
        .find(|c: char| c == ':' || c == '/')
        .map_or(authority, |end| &authority[..end]);

    matches!(host.parse::<std::net::Ipv4Addr>(), Ok(ip) if ip.is_private())
}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `is_localhost_origin` on a static origin string.
    fn accepts(origin: &'static str) -> bool {
        is_localhost_origin(&HeaderValue::from_static(origin))
    }

    /// Fails unless every origin in `origins` is accepted.
    fn expect_allowed(origins: &[&'static str]) {
        for &origin in origins {
            assert!(accepts(origin), "{} should be allowed", origin);
        }
    }

    /// Fails unless every origin in `origins` is rejected.
    fn expect_blocked(origins: &[&'static str]) {
        for &origin in origins {
            assert!(!accepts(origin), "{} should be blocked", origin);
        }
    }

    // ---------- Loopback origins that must be accepted ----------

    #[test]
    fn allows_localhost_variants() {
        expect_allowed(&[
            "http://localhost",
            "https://localhost",
            "http://localhost/",
            "http://localhost/api",
            "http://localhost:1",
            "http://localhost:3000",
            "http://localhost:65535",
            "http://localhost:8080/api/v1",
        ]);
    }

    #[test]
    fn allows_ipv4_loopback_variants() {
        expect_allowed(&[
            "http://127.0.0.1",
            "https://127.0.0.1",
            "http://127.0.0.1:8000",
            "http://127.0.0.1/mcp",
        ]);
    }

    #[test]
    fn allows_ipv6_loopback_variants() {
        expect_allowed(&["http://[::1]", "http://[::1]:3000", "https://[::1]:8080"]);
    }

    #[test]
    fn matches_scheme_and_host_case_insensitively() {
        expect_allowed(&["HTTP://LOCALHOST:3000", "HTTPS://127.0.0.1:8080"]);
    }

    #[test]
    fn allows_common_dev_server_ports() {
        for port in ["3000", "5000", "8000", "8080", "9000", "4200", "5173"] {
            let origin = HeaderValue::from_str(&format!("http://localhost:{}", port)).unwrap();
            assert!(
                is_localhost_origin(&origin),
                "http://localhost:{} should be allowed",
                port
            );
        }
    }

    // ---------- Origins that must be rejected ----------

    #[test]
    fn blocks_external_hosts() {
        expect_blocked(&[
            "http://example.com",
            "http://evil.com:3000",
            "https://malicious.org",
        ]);
    }

    #[test]
    fn blocks_localhost_lookalikes() {
        expect_blocked(&[
            "http://localhost.evil.com",
            "http://localhostevil.com",
            "http://sub.localhost.com",
            "http://my-localhost.com",
        ]);
    }

    #[test]
    fn blocks_private_network_ips() {
        expect_blocked(&["http://192.168.1.1", "http://10.0.0.1:8080", "http://172.16.0.1"]);
    }

    #[test]
    fn blocks_malformed_or_non_web_origins() {
        expect_blocked(&[
            "",
            "localhost:3000",
            "ftp://localhost",
            "file://localhost",
            "http://localhost:notaport",
            "http://localhost:0",
        ]);
    }

    // ---------- CorsConfig ----------

    #[test]
    fn config_defaults() {
        let cfg = CorsConfig::default();
        assert!(cfg.allow_all_localhost);
        assert!(!cfg.allow_credentials);
        assert!(!cfg.expose_headers);
        assert_eq!(cfg.max_age_secs, DEFAULT_MAX_AGE_SECS);
    }

    #[test]
    fn config_builder_sets_fields() {
        let cfg = CorsConfig::new()
            .with_max_age(7200)
            .with_allow_credentials(true)
            .with_expose_headers(true);

        assert_eq!(cfg.max_age_secs, 7200);
        assert!(cfg.allow_credentials);
        assert!(cfg.expose_headers);
    }

    #[test]
    fn config_strict_origins_disables_localhost() {
        assert!(!CorsConfig::new().with_strict_origins().allow_all_localhost);
    }

    // ---------- validate_origin ----------

    #[test]
    fn validate_origin_reports_allowed_localhost() {
        let outcome = validate_origin("http://localhost:3000");
        assert!(outcome.allowed);
        assert_eq!(outcome.origin, "http://localhost:3000");
        assert!(outcome.reason.contains("allowed"));
    }

    #[test]
    fn validate_origin_reports_external_rejection() {
        let outcome = validate_origin("http://example.com");
        assert!(!outcome.allowed);
        assert!(outcome.reason.contains("External") || outcome.reason.contains("not allowed"));
    }

    #[test]
    fn validate_origin_reports_private_ip_rejection() {
        let outcome = validate_origin("http://192.168.1.100");
        assert!(!outcome.allowed);
        assert!(outcome.reason.contains("Private IP") || outcome.reason.contains("not allowed"));
    }

    #[test]
    fn validate_origin_rejects_subdomain_attack() {
        assert!(!validate_origin("http://localhost.evil.com").allowed);
    }

    // ---------- Layer construction (must not panic) ----------

    #[test]
    fn layers_build_and_debug_format() {
        let custom = CorsConfig::new()
            .with_max_age(1800)
            .with_allow_credentials(true);
        let layers = [
            cors_layer(),
            cors_layer_with_config(custom),
            cors_layer_permissive(),
        ];
        for layer in layers {
            let _ = format!("{:?}", layer);
        }
    }
}