//! Cross-site request forgery (CSRF) protection based on the `Origin` and
//! `Referer` request headers (source path: `pylon_plugin/builtin/csrf.rs`).

use crate::PluginError;

/// CSRF guard that rejects state-changing requests whose origin is not in an
/// allow-list.
#[derive(Debug, Clone)]
pub struct CsrfPlugin {
    /// Accepted origins such as `http://localhost:3000`; the entry `"*"`
    /// accepts any origin.
    allowed_origins: Vec<String>,
}

14impl CsrfPlugin {
15 pub fn new(allowed_origins: Vec<String>) -> Self {
17 Self { allowed_origins }
18 }
19
20 pub fn with_localhost(port: u16) -> Self {
23 Self::new(vec![
24 format!("http://localhost:{port}"),
25 format!("http://127.0.0.1:{port}"),
26 ])
27 }
28
29 fn is_safe_method(method: &str) -> bool {
31 matches!(method, "GET" | "HEAD" | "OPTIONS")
32 }
33
34 fn is_allowed_origin(&self, origin: &str) -> bool {
37 self.allowed_origins.iter().any(|o| o == origin || o == "*")
38 }
39
40 fn origin_from_referer(referer: &str) -> Option<String> {
49 let parts: Vec<&str> = referer.splitn(4, '/').collect();
52 if parts.len() >= 3 && !parts[2].is_empty() {
53 Some(format!("{}//{}", parts[0], parts[2]))
54 } else {
55 None
56 }
57 }
58
59 pub fn check(
66 &self,
67 method: &str,
68 origin: Option<&str>,
69 referer: Option<&str>,
70 ) -> Result<(), PluginError> {
71 if Self::is_safe_method(method) {
72 return Ok(());
73 }
74
75 let effective_origin = origin
76 .map(String::from)
77 .or_else(|| referer.and_then(Self::origin_from_referer));
78
79 match effective_origin {
80 Some(ref o) if self.is_allowed_origin(o) => Ok(()),
81 Some(ref o) => Err(PluginError {
82 code: "CSRF_REJECTED".into(),
83 message: format!("Origin '{}' not allowed", o),
84 status: 403,
85 }),
86 None => Err(PluginError {
87 code: "CSRF_NO_ORIGIN".into(),
88 message: "Missing Origin header on state-changing request".into(),
89 status: 403,
90 }),
91 }
92 }
93}
94
95impl crate::Plugin for CsrfPlugin {
96 fn name(&self) -> &str {
97 "csrf"
98 }
99
100 fn on_request(
101 &self,
102 _method: &str,
103 _path: &str,
104 _auth: &pylon_auth::AuthContext,
105 ) -> Result<(), PluginError> {
106 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Plugin that allows only the two localhost origins on port 3000.
    fn localhost_plugin() -> CsrfPlugin {
        CsrfPlugin::with_localhost(3000)
    }

    /// Runs `check` and returns its error, panicking if the request was
    /// unexpectedly accepted.
    fn rejection(
        csrf: &CsrfPlugin,
        method: &str,
        origin: Option<&str>,
        referer: Option<&str>,
    ) -> crate::PluginError {
        match csrf.check(method, origin, referer) {
            Ok(()) => panic!(
                "{method} with origin {origin:?} and referer {referer:?} was unexpectedly accepted"
            ),
            Err(err) => err,
        }
    }

    /// Safe methods need neither an Origin nor a Referer header.
    #[test]
    fn safe_methods_pass_without_origin() {
        let csrf = localhost_plugin();
        for method in ["GET", "HEAD", "OPTIONS"] {
            assert!(csrf.check(method, None, None).is_ok(), "{method} should pass");
        }
    }

    /// Safe methods are not checked at all, even with a foreign origin.
    #[test]
    fn safe_methods_pass_with_bad_origin() {
        let csrf = localhost_plugin();
        assert!(csrf.check("GET", Some("https://evil.com"), None).is_ok());
    }

    /// Both localhost spellings produced by `with_localhost` are accepted.
    #[test]
    fn matching_origin_passes() {
        let csrf = localhost_plugin();
        assert!(csrf
            .check("POST", Some("http://localhost:3000"), None)
            .is_ok());
        assert!(csrf
            .check("DELETE", Some("http://127.0.0.1:3000"), None)
            .is_ok());
    }

    /// A foreign Origin header is rejected with 403 / CSRF_REJECTED.
    #[test]
    fn wrong_origin_rejected() {
        let err = rejection(&localhost_plugin(), "POST", Some("https://evil.com"), None);
        assert_eq!(err.code, "CSRF_REJECTED");
        assert_eq!(err.status, 403);
    }

    /// The rejection message names the offending origin.
    #[test]
    fn rejection_message_names_origin() {
        let err = rejection(&localhost_plugin(), "POST", Some("https://evil.com"), None);
        assert_eq!(err.message, "Origin 'https://evil.com' not allowed");
    }

    /// A state-changing request with no Origin or Referer is rejected.
    #[test]
    fn missing_origin_rejected() {
        let err = rejection(&localhost_plugin(), "PUT", None, None);
        assert_eq!(err.code, "CSRF_NO_ORIGIN");
        assert_eq!(err.status, 403);
        assert_eq!(
            err.message,
            "Missing Origin header on state-changing request"
        );
    }

    /// The opaque `null` origin is treated as a foreign origin.
    #[test]
    fn null_origin_rejected() {
        let err = rejection(&localhost_plugin(), "POST", Some("null"), None);
        assert_eq!(err.code, "CSRF_REJECTED");
    }

    /// A `"*"` entry accepts any origin.
    #[test]
    fn wildcard_allows_all() {
        let csrf = CsrfPlugin::new(vec!["*".into()]);
        assert!(csrf
            .check("POST", Some("https://anything.example.com"), None)
            .is_ok());
        assert!(csrf.check("DELETE", Some("http://evil.com"), None).is_ok());
    }

    /// The origin is reduced to scheme plus authority; junk yields `None`.
    #[test]
    fn origin_from_referer_extraction() {
        assert_eq!(
            CsrfPlugin::origin_from_referer("http://example.com/path?q=1"),
            Some("http://example.com".to_string())
        );
        assert_eq!(
            CsrfPlugin::origin_from_referer("https://a.b:8080/x/y"),
            Some("https://a.b:8080".to_string())
        );
        assert_eq!(CsrfPlugin::origin_from_referer("garbage"), None);
        assert_eq!(CsrfPlugin::origin_from_referer(""), None);
    }

    /// Without an Origin header, an allowed Referer is sufficient.
    #[test]
    fn referer_fallback_when_origin_missing() {
        let csrf = localhost_plugin();
        assert!(csrf
            .check("POST", None, Some("http://localhost:3000/some/path"))
            .is_ok());
    }

    /// Without an Origin header, a foreign Referer is rejected.
    #[test]
    fn referer_fallback_wrong_origin() {
        let err = rejection(
            &localhost_plugin(),
            "POST",
            None,
            Some("https://evil.com/attack"),
        );
        assert_eq!(err.code, "CSRF_REJECTED");
    }

    /// When Origin is present, an allowed Referer cannot override it.
    #[test]
    fn origin_header_takes_precedence_over_referer() {
        let err = rejection(
            &localhost_plugin(),
            "POST",
            Some("https://evil.com"),
            Some("http://localhost:3000/form"),
        );
        assert_eq!(err.code, "CSRF_REJECTED");
    }

    /// Every common state-changing method requires an origin.
    #[test]
    fn all_state_changing_methods_require_origin() {
        let csrf = localhost_plugin();
        for method in ["POST", "PUT", "PATCH", "DELETE"] {
            assert!(csrf.check(method, None, None).is_err(), "{method} should fail");
        }
    }

    /// The plugin registers under the name `csrf`.
    #[test]
    fn plugin_name() {
        let csrf = localhost_plugin();
        assert_eq!(crate::Plugin::name(&csrf), "csrf");
    }
}