// pylon_plugin/builtin/csrf.rs

1use crate::PluginError;
2
/// CSRF protection plugin.
///
/// Validates the `Origin` or `Referer` header on state-changing requests
/// (every method other than GET, HEAD and OPTIONS — e.g. POST, PATCH, DELETE,
/// PUT) against a list of allowed origins. This is complementary to CORS:
/// CORS controls which origins can *read* responses, while CSRF protection
/// ensures that state-changing requests originate from trusted sources.
///
/// Validation happens in [`CsrfPlugin::check`], which the HTTP layer must call
/// itself: the `Plugin::on_request` hook does not receive request headers.
#[derive(Debug, Clone)]
pub struct CsrfPlugin {
    /// Allowed origins in `scheme://host[:port]` form, or `"*"` to allow any.
    allowed_origins: Vec<String>,
}

14impl CsrfPlugin {
15    /// Create a CSRF plugin with explicit allowed origins.
16    pub fn new(allowed_origins: Vec<String>) -> Self {
17        Self { allowed_origins }
18    }
19
20    /// Convenience constructor for local development. Allows both `localhost`
21    /// and `127.0.0.1` on the given port.
22    pub fn with_localhost(port: u16) -> Self {
23        Self::new(vec![
24            format!("http://localhost:{port}"),
25            format!("http://127.0.0.1:{port}"),
26        ])
27    }
28
29    /// Safe (read-only) methods that do not require origin validation.
30    fn is_safe_method(method: &str) -> bool {
31        matches!(method, "GET" | "HEAD" | "OPTIONS")
32    }
33
34    /// Check whether `origin` is in the allowlist. A wildcard entry (`"*"`)
35    /// matches every origin.
36    fn is_allowed_origin(&self, origin: &str) -> bool {
37        self.allowed_origins.iter().any(|o| o == origin || o == "*")
38    }
39
40    /// Extract the origin portion (`scheme://host[:port]`) from a full URL
41    /// such as a `Referer` header value.
42    ///
43    /// ```text
44    /// "http://example.com/path?q=1" -> Some("http://example.com")
45    /// "https://a.b:8080/x"          -> Some("https://a.b:8080")
46    /// "garbage"                      -> None
47    /// ```
48    fn origin_from_referer(referer: &str) -> Option<String> {
49        // Split on '/' keeping at most 4 parts:
50        //   "http:" "" "example.com" "path..."
51        let parts: Vec<&str> = referer.splitn(4, '/').collect();
52        if parts.len() >= 3 && !parts[2].is_empty() {
53            Some(format!("{}//{}", parts[0], parts[2]))
54        } else {
55            None
56        }
57    }
58
59    /// Validate an incoming request.
60    ///
61    /// For safe methods this always succeeds. For state-changing
62    /// methods, the `Origin` header is checked first; if absent the
63    /// origin is derived from the `Referer` header.
64    ///
65    /// **CSRF defense model.** Modern browsers always send `Origin`
66    /// on cross-origin state-changing requests — a malicious page
67    /// can't suppress it. Browsers also send `Origin` on same-site
68    /// POSTs in current spec. So a request with NEITHER `Origin` nor
69    /// `Referer` is by definition not a browser request — it's a
70    /// server-to-server caller (Next.js SSR forwarding a session
71    /// cookie, a curl script with `--cookie`, an internal admin
72    /// tool, etc.). Those callers attach the cookie explicitly via
73    /// the `Cookie:` header rather than relying on browser
74    /// auto-attachment, so the cross-site forgery attack surface
75    /// the CSRF gate exists to protect against doesn't apply.
76    ///
77    /// Without this allowance every Next.js dashboard route that
78    /// calls a Pylon mutation server-side (`pylon.json("/api/fn/X",
79    /// {method: "POST"})`) would 403 — Next.js SSR has no Origin to
80    /// send. We learned this the hard way via the dashboard
81    /// "Members" page returning empty after release 0.3.11.
82    ///
83    /// When a header IS present it must match the allowlist; an
84    /// attacker can never inject one, so its presence is always
85    /// trustworthy.
86    pub fn check(
87        &self,
88        method: &str,
89        origin: Option<&str>,
90        referer: Option<&str>,
91    ) -> Result<(), PluginError> {
92        if Self::is_safe_method(method) {
93            return Ok(());
94        }
95
96        let effective_origin = origin
97            .map(String::from)
98            .or_else(|| referer.and_then(Self::origin_from_referer));
99
100        match effective_origin {
101            Some(ref o) if self.is_allowed_origin(o) => Ok(()),
102            Some(ref o) => Err(PluginError {
103                code: "CSRF_REJECTED".into(),
104                message: format!("Origin '{}' not allowed", o),
105                status: 403,
106            }),
107            // Server-to-server caller — see contract above.
108            None => Ok(()),
109        }
110    }
111}
112
113impl crate::Plugin for CsrfPlugin {
114    fn name(&self) -> &str {
115        "csrf"
116    }
117
118    fn on_request(
119        &self,
120        _method: &str,
121        _path: &str,
122        _auth: &pylon_auth::AuthContext,
123    ) -> Result<(), PluginError> {
124        // The Plugin trait's on_request does not receive HTTP headers, so CSRF
125        // validation cannot happen here automatically. Use `check()` at the
126        // HTTP layer where headers are available.
127        Ok(())
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Methods that `check` must validate.
    const STATE_CHANGING: [&str; 4] = ["POST", "PUT", "PATCH", "DELETE"];

    /// Dev-mode plugin accepting `http://localhost:3000` and `http://127.0.0.1:3000`.
    fn dev() -> CsrfPlugin {
        CsrfPlugin::with_localhost(3000)
    }

    // ---- Safe methods are never checked ----

    #[test]
    fn safe_methods_pass_without_origin() {
        let csrf = dev();
        for m in ["GET", "HEAD", "OPTIONS"] {
            assert!(csrf.check(m, None, None).is_ok(), "{m}");
        }
    }

    #[test]
    fn safe_methods_pass_with_bad_origin() {
        let verdict = dev().check("GET", Some("https://evil.com"), None);
        assert!(verdict.is_ok());
    }

    // ---- Allowlisted origins pass ----

    #[test]
    fn matching_origin_passes() {
        let csrf = dev();
        let cases = [
            ("POST", "http://localhost:3000"),
            ("DELETE", "http://127.0.0.1:3000"),
        ];
        for (method, origin) in cases {
            assert!(
                csrf.check(method, Some(origin), None).is_ok(),
                "{method} from {origin}"
            );
        }
    }

    // ---- Foreign origins are rejected ----

    #[test]
    fn wrong_origin_rejected() {
        let rejection = dev()
            .check("POST", Some("https://evil.com"), None)
            .unwrap_err();
        assert_eq!(rejection.code, "CSRF_REJECTED");
        assert_eq!(rejection.status, 403);
    }

    // ---- No Origin and no Referer: server-to-server caller ----

    #[test]
    fn server_to_server_no_origin_passes() {
        // With neither header present the caller is treated as a non-browser
        // client (Next.js SSR, curl --cookie, internal admin tools) that sends
        // its cookie explicitly. Regression guard: this used to return
        // CSRF_NO_ORIGIN and broke the dashboard's server-side POSTs.
        let csrf = dev();
        for method in STATE_CHANGING {
            let outcome = csrf.check(method, None, None);
            assert!(
                outcome.is_ok(),
                "{method} with no Origin/Referer should be allowed (server-to-server)"
            );
        }
    }

    // ---- Wildcard ----

    #[test]
    fn wildcard_allows_all() {
        let csrf = CsrfPlugin::new(vec![String::from("*")]);
        let cases = [
            ("POST", "https://anything.example.com"),
            ("DELETE", "http://evil.com"),
        ];
        for (method, origin) in cases {
            assert!(csrf.check(method, Some(origin), None).is_ok());
        }
    }

    // ---- Referer parsing ----

    #[test]
    fn origin_from_referer_extraction() {
        let cases: [(&str, Option<&str>); 4] = [
            ("http://example.com/path?q=1", Some("http://example.com")),
            ("https://a.b:8080/x/y", Some("https://a.b:8080")),
            ("garbage", None),
            ("", None),
        ];
        for (referer, expected) in cases {
            assert_eq!(
                CsrfPlugin::origin_from_referer(referer).as_deref(),
                expected,
                "referer {referer:?}"
            );
        }
    }

    // ---- Referer used when Origin is missing ----

    #[test]
    fn referer_fallback_when_origin_missing() {
        let verdict = dev().check("POST", None, Some("http://localhost:3000/some/path"));
        assert!(verdict.is_ok());
    }

    #[test]
    fn referer_fallback_wrong_origin() {
        let rejection = dev()
            .check("POST", None, Some("https://evil.com/attack"))
            .unwrap_err();
        assert_eq!(rejection.code, "CSRF_REJECTED");
    }

    // ---- Every state-changing method validates a present Origin ----

    #[test]
    fn all_state_changing_methods_reject_wrong_origin() {
        let csrf = dev();
        for method in STATE_CHANGING {
            let rejection = csrf
                .check(method, Some("https://evil.com"), None)
                .unwrap_err();
            assert_eq!(rejection.code, "CSRF_REJECTED", "{method} with bad Origin");
        }
    }

    // ---- Plugin trait ----

    #[test]
    fn plugin_name() {
        let csrf = dev();
        assert_eq!(<CsrfPlugin as crate::Plugin>::name(&csrf), "csrf");
    }
}