use async_trait::async_trait;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;
/// Middleware that rejects state-changing admin requests whose `Origin`
/// (or, when `Origin` is absent, `Referer`) header does not name the same
/// authority as the request's `Host` header.
///
/// `GET`, `HEAD`, and `OPTIONS` requests are forwarded without any origin
/// check; every other method must pass the same-origin check or receives a
/// `403 Forbidden` JSON response.
#[derive(Debug, Clone, Copy, Default)]
pub struct AdminOriginGuardMiddleware;
#[async_trait]
impl Middleware for AdminOriginGuardMiddleware {
    /// Forwards the request to `next` when its method is safe or when it is
    /// provably same-origin; otherwise logs the method and URI at `warn`
    /// level and answers `403 Forbidden` with a JSON error body.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        // `||` short-circuits, so the origin check only runs for unsafe methods.
        let allowed = is_safe_method(&request.method) || is_same_origin(&request.headers);
        if allowed {
            return next.handle(request).await;
        }

        tracing::warn!(
            method = %request.method,
            uri = %request.uri,
            "Admin origin guard: cross-origin or missing origin"
        );
        let forbidden = Response::new(hyper::StatusCode::FORBIDDEN)
            .with_header("Content-Type", "application/json")
            .with_body(
                r#"{"error":"Forbidden: cross-origin admin requests are not allowed"}"#,
            );
        Ok(forbidden)
    }
}
/// Reports whether `method` is exempt from the origin check.
///
/// Only `GET`, `HEAD`, and `OPTIONS` are exempt; every other method
/// (including `TRACE`) must pass the same-origin check.
fn is_safe_method(method: &hyper::Method) -> bool {
    const SAFE_METHODS: [hyper::Method; 3] = [
        hyper::Method::GET,
        hyper::Method::HEAD,
        hyper::Method::OPTIONS,
    ];
    SAFE_METHODS.contains(method)
}
/// Returns `true` when the request's `Origin` header (or, if `Origin` is
/// absent, its `Referer` header) names the same authority as its `Host`
/// header.
///
/// The check fails closed and returns `false` when:
/// - `Host` is missing, not visible ASCII, or empty;
/// - `Origin` is present but not visible ASCII (it is not skipped in favour
///   of `Referer`);
/// - neither `Origin` nor a readable `Referer` is present.
fn is_same_origin(headers: &hyper::HeaderMap) -> bool {
    let host = match headers
        .get(hyper::header::HOST)
        .and_then(|v| v.to_str().ok())
    {
        // An empty Host would "match" degenerate values such as `http://`.
        Some(h) if !h.is_empty() => h,
        _ => return false,
    };

    // When present, `Origin` is authoritative: an unreadable value must not
    // silently downgrade the check to the weaker `Referer` fallback.
    if let Some(origin) = headers.get(hyper::header::ORIGIN) {
        return match origin.to_str() {
            Ok(origin) => origin_matches_host(origin, host),
            Err(_) => false,
        };
    }

    match headers
        .get(hyper::header::REFERER)
        .and_then(|v| v.to_str().ok())
    {
        Some(referer) => referer_matches_host(referer, host),
        None => false,
    }
}
/// Returns `true` when the serialized `origin` (e.g. `https://example.com:8443`)
/// names exactly the authority given in `host` (a `Host` header value).
///
/// - Trailing `/` characters after the authority are ignored.
/// - Comparison ignores ASCII case, because host names are case-insensitive
///   (RFC 3986 §3.2.2).
/// - A value with no `scheme://` prefix never matches. This includes the
///   opaque `null` origin.
/// - An empty authority never matches.
fn origin_matches_host(origin: &str, host: &str) -> bool {
    let authority = match origin.split_once("://") {
        Some((_scheme, rest)) => rest.trim_end_matches('/'),
        None => return false,
    };
    !authority.is_empty() && authority.eq_ignore_ascii_case(host)
}
/// Returns `true` when the authority component of the `referer` URL equals
/// `host` (a `Host` header value), ignoring ASCII case.
///
/// - The authority is the text between `scheme://` and the first `/`, `?`,
///   or `#`.
/// - A referer with no `scheme://` prefix never matches.
/// - An empty authority never matches.
/// - Userinfo is not stripped, so a referer like `https://user@host/` does
///   not match `host`.
fn referer_matches_host(referer: &str, host: &str) -> bool {
    let after_scheme = match referer.split_once("://") {
        Some((_scheme, rest)) => rest,
        None => return false,
    };
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);
    !authority.is_empty() && authority.eq_ignore_ascii_case(host)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::header::{HeaderName, HOST, ORIGIN, REFERER};
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Terminal handler that always answers `200 OK`, so any other status a
    /// test observes was produced by the middleware itself.
    struct PassthroughHandler;

    #[async_trait]
    impl Handler for PassthroughHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body("ok"))
        }
    }

    /// Builds an HTTP/1.1 request for a fixed admin server-fn path.
    fn make_request(method: Method, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri("/api/server_fn/get_list")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Inserts each `(name, value)` pair into a fresh header map.
    fn header_map(pairs: &[(HeaderName, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for (name, value) in pairs {
            map.insert(name.clone(), value.parse().unwrap());
        }
        map
    }

    /// Runs a `method` request carrying `pairs` as headers through the guard
    /// and returns the status code of the resulting response.
    async fn guarded_status(method: Method, pairs: &[(HeaderName, &str)]) -> StatusCode {
        let guard = AdminOriginGuardMiddleware;
        let next: Arc<dyn Handler> = Arc::new(PassthroughHandler);
        let request = make_request(method, header_map(pairs));
        let response = guard.process(request, next).await.unwrap();
        response.status
    }

    #[tokio::test]
    async fn test_get_request_passes_through() {
        assert_eq!(guarded_status(Method::GET, &[]).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_without_origin_returns_403() {
        let status = guarded_status(Method::POST, &[(HOST, "localhost:8000")]).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_post_without_host_returns_403() {
        let status = guarded_status(Method::POST, &[(ORIGIN, "http://localhost:8000")]).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_post_same_origin_passes() {
        let status = guarded_status(
            Method::POST,
            &[(HOST, "localhost:8000"), (ORIGIN, "http://localhost:8000")],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_different_origin_returns_403() {
        let status = guarded_status(
            Method::POST,
            &[(HOST, "localhost:8000"), (ORIGIN, "http://evil.com")],
        )
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_post_referer_same_origin_passes() {
        let status = guarded_status(
            Method::POST,
            &[(HOST, "example.com"), (REFERER, "https://example.com/admin/")],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_referer_different_origin_returns_403() {
        let status = guarded_status(
            Method::POST,
            &[(HOST, "example.com"), (REFERER, "https://evil.com/admin/")],
        )
        .await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_options_request_passes_through() {
        assert_eq!(guarded_status(Method::OPTIONS, &[]).await, StatusCode::OK);
    }

    #[test]
    fn test_origin_matches_host() {
        let cases = [
            ("http://localhost:8000", "localhost:8000", true),
            ("https://example.com", "example.com", true),
            ("http://evil.com", "example.com", false),
            ("http://localhost:9000", "localhost:8000", false),
        ];
        for &(origin, host, expected) in cases.iter() {
            assert_eq!(
                origin_matches_host(origin, host),
                expected,
                "origin={} host={}",
                origin,
                host
            );
        }
    }

    #[test]
    fn test_referer_matches_host() {
        let cases = [
            ("http://localhost:8000/admin/", "localhost:8000", true),
            ("https://example.com/admin/model/", "example.com", true),
            ("http://evil.com/admin/", "example.com", false),
        ];
        for &(referer, host, expected) in cases.iter() {
            assert_eq!(
                referer_matches_host(referer, host),
                expected,
                "referer={} host={}",
                referer,
                host
            );
        }
    }
}