Skip to main content

pylon_plugin/builtin/
cors.rs

1use crate::{Plugin, PluginError};
2use pylon_auth::AuthContext;
3
/// CORS plugin. Validates request origins against an allowlist.
///
/// An empty allowlist, or one containing the literal `"*"` entry, accepts
/// every origin.
#[derive(Debug, Clone)]
pub struct CorsPlugin {
    /// Allowed origins. Empty = allow all ("*").
    pub allowed_origins: Vec<String>,
    /// Whether to allow credentials.
    ///
    /// The constructors only enable this for a concrete (non-wildcard)
    /// allowlist; browsers refuse credentialed responses whose
    /// `Access-Control-Allow-Origin` is `*`.
    pub allow_credentials: bool,
}

impl CorsPlugin {
    /// Allow all origins. Credentials are disabled.
    pub fn allow_all() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allow_credentials: false,
        }
    }

    /// Allow specific origins.
    ///
    /// Credentials are enabled only when `origins` names concrete origins.
    /// If `origins` is empty or contains `"*"` the plugin accepts every
    /// origin, and pairing that with credentials is either rejected by
    /// browsers (`*` header) or exposes authenticated responses to any site
    /// (reflected origin), so credentials are turned off in that case.
    pub fn new(origins: Vec<String>) -> Self {
        // Computed before `origins` is moved into the struct.
        let allow_credentials = !Self::is_wildcard_list(&origins);
        Self {
            allowed_origins: origins,
            allow_credentials,
        }
    }

    /// Returns `true` when `origins` places no restriction on the origin:
    /// it is empty or has an explicit `"*"` entry.
    fn is_wildcard_list(origins: &[String]) -> bool {
        origins.is_empty() || origins.iter().any(|o| o == "*")
    }

    /// Check if an origin is allowed.
    ///
    /// Returns `true` for every origin when the allowlist is empty or
    /// contains `"*"`; otherwise `origin` must exactly equal one of the
    /// configured entries.
    pub fn is_allowed(&self, origin: &str) -> bool {
        Self::is_wildcard_list(&self.allowed_origins)
            || self.allowed_origins.iter().any(|o| o == origin)
    }

    /// Get the Access-Control-Allow-Origin header value.
    ///
    /// * Empty allowlist: always `"*"`, regardless of `request_origin`.
    /// * Otherwise: the request origin echoed back when it is allowed, or an
    ///   empty string when it is not allowed or no origin was supplied.
    pub fn allow_origin_header(&self, request_origin: Option<&str>) -> String {
        if self.allowed_origins.is_empty() {
            return "*".to_string();
        }
        match request_origin {
            Some(origin) if self.is_allowed(origin) => origin.to_string(),
            _ => String::new(),
        }
    }
}
48
49impl Plugin for CorsPlugin {
50    fn name(&self) -> &str {
51        "cors"
52    }
53
54    fn on_request(
55        &self,
56        _method: &str,
57        _path: &str,
58        _auth: &AuthContext,
59    ) -> Result<(), PluginError> {
60        // CORS is handled at the HTTP layer (headers), not here.
61        // This plugin provides the configuration; the server reads it.
62        Ok(())
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// With no allowlist every origin passes and the header is the wildcard.
    #[test]
    fn allow_all() {
        let plugin = CorsPlugin::allow_all();

        for &origin in &["http://localhost:3000", "https://example.com"] {
            assert!(plugin.is_allowed(origin));
        }
        assert_eq!(
            plugin.allow_origin_header(Some("http://localhost:3000")),
            "*"
        );
    }

    /// Only the configured origins are accepted; anything else is rejected.
    #[test]
    fn specific_origins() {
        let configured = ["http://localhost:3000", "https://myapp.com"];
        let plugin = CorsPlugin::new(configured.iter().map(|&o| String::from(o)).collect());

        for &origin in &configured {
            assert!(plugin.is_allowed(origin));
        }
        assert!(!plugin.is_allowed("https://evil.com"));
    }

    /// The header echoes an allowed origin and is empty otherwise.
    #[test]
    fn allow_origin_header_matches() {
        let plugin = CorsPlugin::new(vec![String::from("https://myapp.com")]);

        // (request origin, expected header value)
        let cases = [
            (Some("https://myapp.com"), "https://myapp.com"),
            (Some("https://evil.com"), ""),
            (None, ""),
        ];
        for &(request_origin, expected) in cases.iter() {
            assert_eq!(plugin.allow_origin_header(request_origin), expected);
        }
    }
}