Skip to main content

pylon_plugin/builtin/
csrf.rs

1use crate::PluginError;
2
3/// CSRF protection plugin.
4///
5/// Validates the `Origin` or `Referer` header on state-changing requests
6/// (POST, PATCH, DELETE, PUT) against a list of allowed origins. This is
7/// complementary to CORS: CORS controls which origins can *read* responses,
8/// while CSRF protection ensures that state-changing requests originate from
9/// trusted sources.
10pub struct CsrfPlugin {
11    allowed_origins: Vec<String>,
12}
13
14impl CsrfPlugin {
15    /// Create a CSRF plugin with explicit allowed origins.
16    pub fn new(allowed_origins: Vec<String>) -> Self {
17        Self { allowed_origins }
18    }
19
20    /// Convenience constructor for local development. Allows both `localhost`
21    /// and `127.0.0.1` on the given port.
22    pub fn with_localhost(port: u16) -> Self {
23        Self::new(vec![
24            format!("http://localhost:{port}"),
25            format!("http://127.0.0.1:{port}"),
26        ])
27    }
28
29    /// Safe (read-only) methods that do not require origin validation.
30    fn is_safe_method(method: &str) -> bool {
31        matches!(method, "GET" | "HEAD" | "OPTIONS")
32    }
33
34    /// Check whether `origin` is in the allowlist. A wildcard entry (`"*"`)
35    /// matches every origin.
36    fn is_allowed_origin(&self, origin: &str) -> bool {
37        self.allowed_origins.iter().any(|o| o == origin || o == "*")
38    }
39
40    /// Extract the origin portion (`scheme://host[:port]`) from a full URL
41    /// such as a `Referer` header value.
42    ///
43    /// ```text
44    /// "http://example.com/path?q=1" -> Some("http://example.com")
45    /// "https://a.b:8080/x"          -> Some("https://a.b:8080")
46    /// "garbage"                      -> None
47    /// ```
48    fn origin_from_referer(referer: &str) -> Option<String> {
49        // Split on '/' keeping at most 4 parts:
50        //   "http:" "" "example.com" "path..."
51        let parts: Vec<&str> = referer.splitn(4, '/').collect();
52        if parts.len() >= 3 && !parts[2].is_empty() {
53            Some(format!("{}//{}", parts[0], parts[2]))
54        } else {
55            None
56        }
57    }
58
59    /// Validate an incoming request.
60    ///
61    /// For safe methods this always succeeds. For state-changing methods the
62    /// `Origin` header is checked first; if absent the origin is derived from
63    /// the `Referer` header. If neither header provides a trusted origin the
64    /// request is rejected.
65    pub fn check(
66        &self,
67        method: &str,
68        origin: Option<&str>,
69        referer: Option<&str>,
70    ) -> Result<(), PluginError> {
71        if Self::is_safe_method(method) {
72            return Ok(());
73        }
74
75        let effective_origin = origin
76            .map(String::from)
77            .or_else(|| referer.and_then(Self::origin_from_referer));
78
79        match effective_origin {
80            Some(ref o) if self.is_allowed_origin(o) => Ok(()),
81            Some(ref o) => Err(PluginError {
82                code: "CSRF_REJECTED".into(),
83                message: format!("Origin '{}' not allowed", o),
84                status: 403,
85            }),
86            None => Err(PluginError {
87                code: "CSRF_NO_ORIGIN".into(),
88                message: "Missing Origin header on state-changing request".into(),
89                status: 403,
90            }),
91        }
92    }
93}
94
95impl crate::Plugin for CsrfPlugin {
96    fn name(&self) -> &str {
97        "csrf"
98    }
99
100    fn on_request(
101        &self,
102        _method: &str,
103        _path: &str,
104        _auth: &pylon_auth::AuthContext,
105    ) -> Result<(), PluginError> {
106        // The Plugin trait's on_request does not receive HTTP headers, so CSRF
107        // validation cannot happen here automatically. Use `check()` at the
108        // HTTP layer where headers are available.
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    fn localhost_plugin() -> CsrfPlugin {
        CsrfPlugin::with_localhost(3000)
    }

    // -- Safe methods always pass --

    #[test]
    fn safe_methods_pass_without_origin() {
        let csrf = localhost_plugin();
        for method in &["GET", "HEAD", "OPTIONS"] {
            assert!(csrf.check(method, None, None).is_ok());
        }
    }

    #[test]
    fn safe_methods_pass_with_bad_origin() {
        let csrf = localhost_plugin();
        assert!(csrf.check("GET", Some("https://evil.com"), None).is_ok());
    }

    // -- Method matching is case-sensitive --

    #[test]
    fn method_matching_is_case_sensitive() {
        // HTTP method tokens are case-sensitive: "get" is not GET, so it must
        // be treated as state-changing and validated.
        let csrf = localhost_plugin();
        assert!(csrf.check("get", None, None).is_err());
    }

    // -- Matching origin passes --

    #[test]
    fn matching_origin_passes() {
        let csrf = localhost_plugin();
        assert!(csrf
            .check("POST", Some("http://localhost:3000"), None)
            .is_ok());
        assert!(csrf
            .check("DELETE", Some("http://127.0.0.1:3000"), None)
            .is_ok());
    }

    // -- Wrong origin rejected --

    #[test]
    fn wrong_origin_rejected() {
        let csrf = localhost_plugin();
        let err = csrf
            .check("POST", Some("https://evil.com"), None)
            .unwrap_err();
        assert_eq!(err.code, "CSRF_REJECTED");
        assert_eq!(err.status, 403);
    }

    // -- Missing origin rejected --

    #[test]
    fn missing_origin_rejected() {
        let csrf = localhost_plugin();
        let err = csrf.check("PUT", None, None).unwrap_err();
        assert_eq!(err.code, "CSRF_NO_ORIGIN");
        assert_eq!(err.status, 403);
    }

    // -- Wildcard allows all --

    #[test]
    fn wildcard_allows_all() {
        let csrf = CsrfPlugin::new(vec!["*".into()]);
        assert!(csrf
            .check("POST", Some("https://anything.example.com"), None)
            .is_ok());
        assert!(csrf.check("DELETE", Some("http://evil.com"), None).is_ok());
    }

    // -- Referer extraction --

    #[test]
    fn origin_from_referer_extraction() {
        assert_eq!(
            CsrfPlugin::origin_from_referer("http://example.com/path?q=1"),
            Some("http://example.com".into())
        );
        assert_eq!(
            CsrfPlugin::origin_from_referer("https://a.b:8080/x/y"),
            Some("https://a.b:8080".into())
        );
        assert_eq!(CsrfPlugin::origin_from_referer("garbage"), None);
        assert_eq!(CsrfPlugin::origin_from_referer(""), None);
    }

    // -- Referer fallback when Origin is missing --

    #[test]
    fn referer_fallback_when_origin_missing() {
        let csrf = localhost_plugin();
        assert!(csrf
            .check("POST", None, Some("http://localhost:3000/some/path"))
            .is_ok());
    }

    #[test]
    fn referer_fallback_wrong_origin() {
        let csrf = localhost_plugin();
        let err = csrf
            .check("POST", None, Some("https://evil.com/attack"))
            .unwrap_err();
        assert_eq!(err.code, "CSRF_REJECTED");
    }

    #[test]
    fn unparseable_referer_is_treated_as_missing() {
        let csrf = localhost_plugin();
        let err = csrf.check("POST", None, Some("garbage")).unwrap_err();
        assert_eq!(err.code, "CSRF_NO_ORIGIN");
        assert_eq!(err.status, 403);
    }

    // -- Origin precedence --

    #[test]
    fn origin_takes_precedence_over_referer() {
        // A rejected Origin must not be rescued by a trusted-looking Referer.
        let csrf = localhost_plugin();
        let err = csrf
            .check(
                "POST",
                Some("https://evil.com"),
                Some("http://localhost:3000/form"),
            )
            .unwrap_err();
        assert_eq!(err.code, "CSRF_REJECTED");
        assert_eq!(err.status, 403);
    }

    // -- All state-changing methods are checked --

    #[test]
    fn all_state_changing_methods_require_origin() {
        let csrf = localhost_plugin();
        for method in &["POST", "PUT", "PATCH", "DELETE"] {
            assert!(csrf.check(method, None, None).is_err());
        }
    }

    // -- Plugin trait --

    #[test]
    fn plugin_name() {
        let csrf = localhost_plugin();
        assert_eq!(crate::Plugin::name(&csrf), "csrf");
    }
}