reinhardt_middleware/
origin_guard.rs1use async_trait::async_trait;
8use std::sync::Arc;
9
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11
/// Middleware that guards state-changing requests by checking the request's
/// `Origin` header (or an origin derived from `Referer`) against a configured
/// list of allowed origins.
///
/// Requests using `GET`, `HEAD` or `OPTIONS` are forwarded without any check.
#[derive(Debug, Clone)]
pub struct OriginGuardMiddleware {
    /// Serialized origins (`scheme://host[:port]`) that unsafe requests may come from.
    allowed_origins: Vec<String>,
}
55
56impl OriginGuardMiddleware {
57 pub fn new(allowed_origins: Vec<String>) -> Self {
66 Self { allowed_origins }
67 }
68
69 fn origin_from_referer(referer: &str) -> Option<String> {
73 let url = url::Url::parse(referer).ok()?;
74 let scheme = url.scheme();
75 let host = url.host_str()?;
76 let port = url.port();
77
78 let origin = if let Some(p) = port {
79 format!("{}://{}:{}", scheme, host, p)
80 } else {
81 format!("{}://{}", scheme, host)
82 };
83
84 Some(origin)
85 }
86
87 fn is_allowed(&self, origin: &str) -> bool {
89 self.allowed_origins.iter().any(|o| o == origin)
90 }
91}
92
93#[async_trait]
94impl Middleware for OriginGuardMiddleware {
95 async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
96 let method = request.method.clone();
97
98 let is_safe = matches!(method.as_str(), "GET" | "HEAD" | "OPTIONS");
100
101 if is_safe {
102 return next.handle(request).await;
103 }
104
105 let origin = request
107 .headers
108 .get("Origin")
109 .and_then(|v| v.to_str().ok())
110 .map(|s| s.to_string())
111 .or_else(|| {
112 request
113 .headers
114 .get("Referer")
115 .and_then(|v| v.to_str().ok())
116 .and_then(Self::origin_from_referer)
117 });
118
119 match origin {
120 Some(ref o) if self.is_allowed(o) => next.handle(request).await,
121 _ => Ok(Response::forbidden().with_body("Origin validation failed")),
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_http::{Handler, Middleware, Request, Response, Result};

    /// Downstream handler that always succeeds with `200 OK` and body `"ok"`.
    struct PassThroughHandler;

    #[async_trait::async_trait]
    impl Handler for PassThroughHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body("ok"))
        }
    }

    /// Builds an HTTP/1.1 request to `/submit` with an empty body and the
    /// given optional `Origin` and `Referer` headers.
    fn make_request(method: Method, origin: Option<&str>, referer: Option<&str>) -> Request {
        let mut headers = HeaderMap::new();
        let candidates = [("Origin", origin), ("Referer", referer)];
        for (name, value) in candidates {
            if let Some(value) = value {
                headers.insert(name, value.parse().unwrap());
            }
        }

        Request::builder()
            .method(method)
            .uri("/submit")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Guard allowing `https://example.com` and `https://app.example.com`.
    fn middleware() -> OriginGuardMiddleware {
        let allowed = vec![
            "https://example.com".to_string(),
            "https://app.example.com".to_string(),
        ];
        OriginGuardMiddleware::new(allowed)
    }

    fn handler() -> Arc<dyn Handler> {
        Arc::new(PassThroughHandler)
    }

    /// Pushes `request` through a freshly built guard and returns the response.
    async fn run(request: Request) -> Response {
        let guard = middleware();
        guard.process(request, handler()).await.unwrap()
    }

    /// Asserts the request reached the pass-through handler.
    #[track_caller]
    fn assert_passed(response: &Response) {
        assert_eq!(response.status.as_u16(), 200);
    }

    /// Asserts the guard rejected the request with its 403 message.
    #[track_caller]
    fn assert_rejected(response: &Response) {
        assert_eq!(response.status.as_u16(), 403);
        let text = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(text, "Origin validation failed");
    }

    #[tokio::test]
    async fn test_get_always_passes_no_origin() {
        assert_passed(&run(make_request(Method::GET, None, None)).await);
    }

    #[tokio::test]
    async fn test_head_always_passes() {
        assert_passed(&run(make_request(Method::HEAD, None, None)).await);
    }

    #[tokio::test]
    async fn test_options_always_passes() {
        assert_passed(&run(make_request(Method::OPTIONS, None, None)).await);
    }

    #[tokio::test]
    async fn test_post_with_valid_origin_passes() {
        let request = make_request(Method::POST, Some("https://example.com"), None);
        assert_passed(&run(request).await);
    }

    #[tokio::test]
    async fn test_post_with_invalid_origin_returns_403() {
        let request = make_request(Method::POST, Some("https://evil.com"), None);
        assert_rejected(&run(request).await);
    }

    #[tokio::test]
    async fn test_post_no_origin_valid_referer_passes() {
        let referer = "https://example.com/some/path?foo=bar";
        let request = make_request(Method::POST, None, Some(referer));
        assert_passed(&run(request).await);
    }

    #[tokio::test]
    async fn test_post_no_origin_no_referer_returns_403() {
        assert_rejected(&run(make_request(Method::POST, None, None)).await);
    }

    #[tokio::test]
    async fn test_delete_with_valid_origin_passes() {
        let request = make_request(Method::DELETE, Some("https://app.example.com"), None);
        assert_passed(&run(request).await);
    }

    #[tokio::test]
    async fn test_put_with_invalid_origin_returns_403() {
        let request = make_request(Method::PUT, Some("https://attacker.example.com"), None);
        assert_rejected(&run(request).await);
    }
}