pylon_plugin/builtin/csrf.rs
1use crate::PluginError;
2
/// CSRF protection plugin.
///
/// Validates the `Origin` or `Referer` header on state-changing requests
/// (every method other than `GET`, `HEAD` and `OPTIONS`) against a list of
/// allowed origins. This is complementary to CORS: CORS controls which
/// origins can *read* responses, while CSRF protection ensures that
/// state-changing requests originate from trusted sources.
///
/// Validation is performed by [`CsrfPlugin::check`], which must be called
/// from the HTTP layer: the generic `Plugin::on_request` hook does not see
/// request headers and performs no check.
#[derive(Debug, Clone)]
pub struct CsrfPlugin {
    /// Allowed origins, e.g. `"http://localhost:3000"`; the entry `"*"`
    /// allows every origin.
    allowed_origins: Vec<String>,
}
13
14impl CsrfPlugin {
15 /// Create a CSRF plugin with explicit allowed origins.
16 pub fn new(allowed_origins: Vec<String>) -> Self {
17 Self { allowed_origins }
18 }
19
20 /// Convenience constructor for local development. Allows both `localhost`
21 /// and `127.0.0.1` on the given port.
22 pub fn with_localhost(port: u16) -> Self {
23 Self::new(vec![
24 format!("http://localhost:{port}"),
25 format!("http://127.0.0.1:{port}"),
26 ])
27 }
28
29 /// Safe (read-only) methods that do not require origin validation.
30 fn is_safe_method(method: &str) -> bool {
31 matches!(method, "GET" | "HEAD" | "OPTIONS")
32 }
33
34 /// Check whether `origin` is in the allowlist. A wildcard entry (`"*"`)
35 /// matches every origin.
36 fn is_allowed_origin(&self, origin: &str) -> bool {
37 self.allowed_origins.iter().any(|o| o == origin || o == "*")
38 }
39
40 /// Extract the origin portion (`scheme://host[:port]`) from a full URL
41 /// such as a `Referer` header value.
42 ///
43 /// ```text
44 /// "http://example.com/path?q=1" -> Some("http://example.com")
45 /// "https://a.b:8080/x" -> Some("https://a.b:8080")
46 /// "garbage" -> None
47 /// ```
48 fn origin_from_referer(referer: &str) -> Option<String> {
49 // Split on '/' keeping at most 4 parts:
50 // "http:" "" "example.com" "path..."
51 let parts: Vec<&str> = referer.splitn(4, '/').collect();
52 if parts.len() >= 3 && !parts[2].is_empty() {
53 Some(format!("{}//{}", parts[0], parts[2]))
54 } else {
55 None
56 }
57 }
58
59 /// Validate an incoming request.
60 ///
61 /// For safe methods this always succeeds. For state-changing
62 /// methods, the `Origin` header is checked first; if absent the
63 /// origin is derived from the `Referer` header.
64 ///
65 /// **CSRF defense model.** Modern browsers always send `Origin`
66 /// on cross-origin state-changing requests — a malicious page
67 /// can't suppress it. Browsers also send `Origin` on same-site
68 /// POSTs in current spec. So a request with NEITHER `Origin` nor
69 /// `Referer` is by definition not a browser request — it's a
70 /// server-to-server caller (Next.js SSR forwarding a session
71 /// cookie, a curl script with `--cookie`, an internal admin
72 /// tool, etc.). Those callers attach the cookie explicitly via
73 /// the `Cookie:` header rather than relying on browser
74 /// auto-attachment, so the cross-site forgery attack surface
75 /// the CSRF gate exists to protect against doesn't apply.
76 ///
77 /// Without this allowance every Next.js dashboard route that
78 /// calls a Pylon mutation server-side (`pylon.json("/api/fn/X",
79 /// {method: "POST"})`) would 403 — Next.js SSR has no Origin to
80 /// send. We learned this the hard way via the dashboard
81 /// "Members" page returning empty after release 0.3.11.
82 ///
83 /// When a header IS present it must match the allowlist; an
84 /// attacker can never inject one, so its presence is always
85 /// trustworthy.
86 pub fn check(
87 &self,
88 method: &str,
89 origin: Option<&str>,
90 referer: Option<&str>,
91 ) -> Result<(), PluginError> {
92 if Self::is_safe_method(method) {
93 return Ok(());
94 }
95
96 let effective_origin = origin
97 .map(String::from)
98 .or_else(|| referer.and_then(Self::origin_from_referer));
99
100 match effective_origin {
101 Some(ref o) if self.is_allowed_origin(o) => Ok(()),
102 Some(ref o) => Err(PluginError {
103 code: "CSRF_REJECTED".into(),
104 message: format!("Origin '{}' not allowed", o),
105 status: 403,
106 }),
107 // Server-to-server caller — see contract above.
108 None => Ok(()),
109 }
110 }
111}
112
113impl crate::Plugin for CsrfPlugin {
114 fn name(&self) -> &str {
115 "csrf"
116 }
117
118 fn on_request(
119 &self,
120 _method: &str,
121 _path: &str,
122 _auth: &pylon_auth::AuthContext,
123 ) -> Result<(), PluginError> {
124 // The Plugin trait's on_request does not receive HTTP headers, so CSRF
125 // validation cannot happen here automatically. Use `check()` at the
126 // HTTP layer where headers are available.
127 Ok(())
128 }
129}
130
131#[cfg(test)]
132mod tests {
133 use super::*;
134
135 fn localhost_plugin() -> CsrfPlugin {
136 CsrfPlugin::with_localhost(3000)
137 }
138
139 // -- Safe methods always pass --
140
141 #[test]
142 fn safe_methods_pass_without_origin() {
143 let csrf = localhost_plugin();
144 for method in &["GET", "HEAD", "OPTIONS"] {
145 assert!(csrf.check(method, None, None).is_ok());
146 }
147 }
148
149 #[test]
150 fn safe_methods_pass_with_bad_origin() {
151 let csrf = localhost_plugin();
152 assert!(csrf.check("GET", Some("https://evil.com"), None).is_ok());
153 }
154
155 // -- Matching origin passes --
156
157 #[test]
158 fn matching_origin_passes() {
159 let csrf = localhost_plugin();
160 assert!(csrf
161 .check("POST", Some("http://localhost:3000"), None)
162 .is_ok());
163 assert!(csrf
164 .check("DELETE", Some("http://127.0.0.1:3000"), None)
165 .is_ok());
166 }
167
168 // -- Wrong origin rejected --
169
170 #[test]
171 fn wrong_origin_rejected() {
172 let csrf = localhost_plugin();
173 let err = csrf
174 .check("POST", Some("https://evil.com"), None)
175 .unwrap_err();
176 assert_eq!(err.code, "CSRF_REJECTED");
177 assert_eq!(err.status, 403);
178 }
179
180 // -- Server-to-server callers (no Origin/Referer) pass --
181
182 #[test]
183 fn server_to_server_no_origin_passes() {
184 // Modern browsers always send Origin on state-changing
185 // requests, so absent Origin = not-a-browser = no CSRF
186 // attack surface. Legitimate server-to-server callers
187 // (Next.js SSR, curl --cookie, internal admin tools)
188 // attach the cookie explicitly via Cookie header. Pre-fix
189 // this returned CSRF_NO_ORIGIN and broke server-side POSTs
190 // from the dashboard.
191 let csrf = localhost_plugin();
192 for method in &["POST", "PUT", "PATCH", "DELETE"] {
193 assert!(
194 csrf.check(method, None, None).is_ok(),
195 "{method} with no Origin/Referer should be allowed (server-to-server)"
196 );
197 }
198 }
199
200 // -- Wildcard allows all --
201
202 #[test]
203 fn wildcard_allows_all() {
204 let csrf = CsrfPlugin::new(vec!["*".into()]);
205 assert!(csrf
206 .check("POST", Some("https://anything.example.com"), None)
207 .is_ok());
208 assert!(csrf.check("DELETE", Some("http://evil.com"), None).is_ok());
209 }
210
211 // -- Referer extraction --
212
213 #[test]
214 fn origin_from_referer_extraction() {
215 assert_eq!(
216 CsrfPlugin::origin_from_referer("http://example.com/path?q=1"),
217 Some("http://example.com".into())
218 );
219 assert_eq!(
220 CsrfPlugin::origin_from_referer("https://a.b:8080/x/y"),
221 Some("https://a.b:8080".into())
222 );
223 assert_eq!(CsrfPlugin::origin_from_referer("garbage"), None);
224 assert_eq!(CsrfPlugin::origin_from_referer(""), None);
225 }
226
227 // -- Referer fallback when Origin is missing --
228
229 #[test]
230 fn referer_fallback_when_origin_missing() {
231 let csrf = localhost_plugin();
232 assert!(csrf
233 .check("POST", None, Some("http://localhost:3000/some/path"))
234 .is_ok());
235 }
236
237 #[test]
238 fn referer_fallback_wrong_origin() {
239 let csrf = localhost_plugin();
240 let err = csrf
241 .check("POST", None, Some("https://evil.com/attack"))
242 .unwrap_err();
243 assert_eq!(err.code, "CSRF_REJECTED");
244 }
245
246 // -- All state-changing methods validate present-but-wrong Origin --
247
248 #[test]
249 fn all_state_changing_methods_reject_wrong_origin() {
250 let csrf = localhost_plugin();
251 for method in &["POST", "PUT", "PATCH", "DELETE"] {
252 let err = csrf
253 .check(method, Some("https://evil.com"), None)
254 .unwrap_err();
255 assert_eq!(err.code, "CSRF_REJECTED", "{method} with bad Origin");
256 }
257 }
258
259 // -- Plugin trait --
260
261 #[test]
262 fn plugin_name() {
263 let csrf = localhost_plugin();
264 assert_eq!(crate::Plugin::name(&csrf), "csrf");
265 }
266}