Skip to main content

reinhardt_middleware/
origin_guard.rs

1//! Origin validation middleware for CSRF protection.
2//!
3//! Validates the `Origin` (or `Referer`) header on state-changing HTTP
4//! requests, providing defense-in-depth CSRF protection alongside
5//! `SameSite=Lax` cookies.
6
7use async_trait::async_trait;
8use std::sync::Arc;
9
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11
/// Middleware that validates the `Origin` or `Referer` header on
/// state-changing requests as a CSRF protection layer.
///
/// Safe methods (`GET`, `HEAD`, `OPTIONS`) are always passed through.
/// Every other method (`POST`, `PUT`, `DELETE`, `PATCH`, ...) is treated
/// as state-changing, and the middleware checks whether the request
/// origin appears in the `allowed_origins` list:
///
/// 1. Reads the `Origin` header directly.
/// 2. If absent, falls back to the `Referer` header and extracts the
///    `scheme://authority` portion.
/// 3. If the origin matches → request proceeds to the next handler.
/// 4. If neither header is present, or the origin does not match →
///    **403 Forbidden** with body `"Origin validation failed"`.
///
/// # Examples
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use reinhardt_middleware::OriginGuardMiddleware;
/// use reinhardt_http::MiddlewareChain;
/// # use reinhardt_http::{Handler, Request, Response, Result};
/// # use async_trait::async_trait;
/// # struct MyHandler;
/// # #[async_trait]
/// # impl Handler for MyHandler {
/// #     async fn handle(&self, _request: Request) -> Result<Response> {
/// #         Ok(Response::ok())
/// #     }
/// # }
/// # let handler = Arc::new(MyHandler);
///
/// let middleware = OriginGuardMiddleware::new(vec![
///     "https://example.com".to_string(),
///     "https://app.example.com".to_string(),
/// ]);
///
/// let app = MiddlewareChain::new(handler)
///     .with_middleware(Arc::new(middleware));
/// ```
#[derive(Debug, Clone)]
pub struct OriginGuardMiddleware {
	/// Origins (`scheme://authority`) permitted to issue state-changing requests.
	allowed_origins: Vec<String>,
}
55
56impl OriginGuardMiddleware {
57	/// Creates a new `OriginGuardMiddleware` with the given list of allowed origins.
58	///
59	/// Each entry should be a `scheme://authority` string such as
60	/// `"https://example.com"` (no trailing slash, no path).
61	///
62	/// # Arguments
63	///
64	/// * `allowed_origins` - Origins that are permitted to make state-changing requests.
65	pub fn new(allowed_origins: Vec<String>) -> Self {
66		Self { allowed_origins }
67	}
68
69	/// Extracts the `scheme://authority` origin from a `Referer` URL string.
70	///
71	/// Returns `None` if the URL cannot be parsed or has no host.
72	fn origin_from_referer(referer: &str) -> Option<String> {
73		let url = url::Url::parse(referer).ok()?;
74		let scheme = url.scheme();
75		let host = url.host_str()?;
76		let port = url.port();
77
78		let origin = if let Some(p) = port {
79			format!("{}://{}:{}", scheme, host, p)
80		} else {
81			format!("{}://{}", scheme, host)
82		};
83
84		Some(origin)
85	}
86
87	/// Returns true if the given origin string appears in `allowed_origins`.
88	fn is_allowed(&self, origin: &str) -> bool {
89		self.allowed_origins.iter().any(|o| o == origin)
90	}
91}
92
93#[async_trait]
94impl Middleware for OriginGuardMiddleware {
95	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
96		let method = request.method.clone();
97
98		// Safe methods always pass through.
99		let is_safe = matches!(method.as_str(), "GET" | "HEAD" | "OPTIONS");
100
101		if is_safe {
102			return next.handle(request).await;
103		}
104
105		// State-changing method: validate Origin / Referer.
106		let origin = request
107			.headers
108			.get("Origin")
109			.and_then(|v| v.to_str().ok())
110			.map(|s| s.to_string())
111			.or_else(|| {
112				request
113					.headers
114					.get("Referer")
115					.and_then(|v| v.to_str().ok())
116					.and_then(Self::origin_from_referer)
117			});
118
119		match origin {
120			Some(ref o) if self.is_allowed(o) => next.handle(request).await,
121			_ => Ok(Response::forbidden().with_body("Origin validation failed")),
122		}
123	}
124}
125
#[cfg(test)]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use reinhardt_http::{Handler, Middleware, Request, Response, Result};

	/// Terminal handler that answers every request with `200 OK` and body `"ok"`,
	/// so a 200 status proves the middleware forwarded the request.
	struct PassThroughHandler;

	#[async_trait::async_trait]
	impl Handler for PassThroughHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::ok().with_body("ok"))
		}
	}

	/// Builds a request to `/submit` with the given method and optional
	/// `Origin` / `Referer` headers.
	fn make_request(method: Method, origin: Option<&str>, referer: Option<&str>) -> Request {
		let mut headers = HeaderMap::new();
		if let Some(o) = origin {
			headers.insert("Origin", o.parse().unwrap());
		}
		if let Some(r) = referer {
			headers.insert("Referer", r.parse().unwrap());
		}
		Request::builder()
			.method(method)
			.uri("/submit")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Middleware under test, allowing `https://example.com` and
	/// `https://app.example.com`.
	fn middleware() -> OriginGuardMiddleware {
		OriginGuardMiddleware::new(vec![
			"https://example.com".to_string(),
			"https://app.example.com".to_string(),
		])
	}

	fn handler() -> Arc<dyn Handler> {
		Arc::new(PassThroughHandler)
	}

	/// Runs one request through the middleware and returns the response.
	async fn send(method: Method, origin: Option<&str>, referer: Option<&str>) -> Response {
		middleware()
			.process(make_request(method, origin, referer), handler())
			.await
			.unwrap()
	}

	/// Asserts that the middleware rejected the request with its 403 response.
	#[track_caller]
	fn assert_forbidden(resp: Response) {
		assert_eq!(resp.status.as_u16(), 403);
		let body = String::from_utf8(resp.body.to_vec()).unwrap();
		assert_eq!(body, "Origin validation failed");
	}

	// Safe methods

	#[tokio::test]
	async fn test_get_always_passes_no_origin() {
		let resp = send(Method::GET, None, None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	#[tokio::test]
	async fn test_head_always_passes() {
		let resp = send(Method::HEAD, None, None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	#[tokio::test]
	async fn test_options_always_passes() {
		let resp = send(Method::OPTIONS, None, None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	// POST with valid origin

	#[tokio::test]
	async fn test_post_with_valid_origin_passes() {
		let resp = send(Method::POST, Some("https://example.com"), None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	// POST with invalid origin

	#[tokio::test]
	async fn test_post_with_invalid_origin_returns_403() {
		let resp = send(Method::POST, Some("https://evil.com"), None).await;
		assert_forbidden(resp);
	}

	// POST with invalid origin and valid referer: Origin takes precedence

	#[tokio::test]
	async fn test_post_invalid_origin_ignores_valid_referer() {
		let resp = send(
			Method::POST,
			Some("https://evil.com"),
			Some("https://example.com/form"),
		)
		.await;
		assert_forbidden(resp);
	}

	// POST with no origin but valid referer

	#[tokio::test]
	async fn test_post_no_origin_valid_referer_passes() {
		let resp = send(
			Method::POST,
			None,
			Some("https://example.com/some/path?foo=bar"),
		)
		.await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	// POST with no origin and a referer from a foreign site

	#[tokio::test]
	async fn test_post_no_origin_invalid_referer_returns_403() {
		let resp = send(Method::POST, None, Some("https://evil.com/attack?x=1")).await;
		assert_forbidden(resp);
	}

	// POST with no origin and a referer that is not an absolute URL

	#[tokio::test]
	async fn test_post_no_origin_unparseable_referer_returns_403() {
		let resp = send(Method::POST, None, Some("not a url")).await;
		assert_forbidden(resp);
	}

	// POST with no origin and a referer on an allowed host but a non-listed port

	#[tokio::test]
	async fn test_post_no_origin_referer_with_other_port_returns_403() {
		let resp = send(Method::POST, None, Some("https://example.com:8443/path")).await;
		assert_forbidden(resp);
	}

	// POST with no origin and no referer

	#[tokio::test]
	async fn test_post_no_origin_no_referer_returns_403() {
		let resp = send(Method::POST, None, None).await;
		assert_forbidden(resp);
	}

	// DELETE with valid origin

	#[tokio::test]
	async fn test_delete_with_valid_origin_passes() {
		let resp = send(Method::DELETE, Some("https://app.example.com"), None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}

	// PUT with invalid origin

	#[tokio::test]
	async fn test_put_with_invalid_origin_returns_403() {
		let resp = send(Method::PUT, Some("https://attacker.example.com"), None).await;
		assert_forbidden(resp);
	}

	// PATCH is state-changing as well

	#[tokio::test]
	async fn test_patch_no_origin_no_referer_returns_403() {
		let resp = send(Method::PATCH, None, None).await;
		assert_forbidden(resp);
	}

	#[tokio::test]
	async fn test_patch_with_valid_origin_passes() {
		let resp = send(Method::PATCH, Some("https://example.com"), None).await;
		assert_eq!(resp.status.as_u16(), 200);
	}
}