use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::http::{Request, Response, StatusCode};
use tower::{Layer, Service};
use crate::security::trusted_proxies::ResolvedClientIdentity;
/// Form field consulted for the override method when no custom name is configured.
pub const DEFAULT_METHOD_OVERRIDE_FIELD: &str = "_method";
/// Hard ceiling (2 MiB) on how many form-body bytes are buffered while scanning for the
/// override field; `MethodOverrideLayer::with_max_scan_bytes` may only lower it.
const MAX_BODY_SCAN_BYTES: usize = 2 << 20;
/// Settings for [`MethodOverrideLayer`].
#[derive(Debug, Clone)]
pub struct MethodOverrideConfig {
    /// Name of the URL-encoded form field whose value selects the effective method.
    pub field_name: String,
}

impl Default for MethodOverrideConfig {
    /// Uses [`DEFAULT_METHOD_OVERRIDE_FIELD`] (`_method`) as the field name.
    fn default() -> Self {
        let field_name = String::from(DEFAULT_METHOD_OVERRIDE_FIELD);
        Self { field_name }
    }
}
/// Request extension recording a method override applied by `MethodOverrideService`.
///
/// It is inserted only when the override field held an accepted value and the request
/// method was rewritten. Downstream handlers can extract it to see the original method.
#[derive(Clone, Debug)]
pub struct OverriddenMethod {
/// Method the request arrived with on the wire (always `POST` for overrides applied by
/// this module, since only `POST` requests are scanned).
pub transport: http::Method,
/// Method taken from the override field (`PUT`, `PATCH` or `DELETE`).
pub effective: http::Method,
}
/// Request extension stamped by `MethodOverrideService` when an override attempt fails.
///
/// The request is still forwarded downstream. `method_override_rejection_filter` turns
/// this marker into a 400 / 413 response.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MethodOverrideRejection {
/// The override field was present but its value was not `PUT`, `PATCH` or `DELETE`.
InvalidValue,
/// The form body could not be buffered within the scan limit (presumably also used when
/// the body could not be read at all).
BodyTooLarge,
}
pub async fn method_override_rejection_filter(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
use crate::middleware::exception_filter::AutumnErrorInfo;
use axum::response::IntoResponse;
if let Some(rejection) = request
.extensions()
.get::<MethodOverrideRejection>()
.copied()
{
let (status, message) = match rejection {
MethodOverrideRejection::InvalidValue => (
StatusCode::BAD_REQUEST,
"Invalid method override value: must be PUT, PATCH, or DELETE.",
),
MethodOverrideRejection::BodyTooLarge => (
StatusCode::PAYLOAD_TOO_LARGE,
"Form body too large for method-override scanning.",
),
};
let mut response = (
status,
[(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("text/plain; charset=utf-8"),
)],
message,
)
.into_response();
response.extensions_mut().insert(AutumnErrorInfo {
status,
message: message.to_owned(),
details: None,
problem_type: None,
backtrace_string: None,
});
return response;
}
next.run(request).await
}
/// Tower [`Layer`] that enables HTML-form method override (`_method=PUT|PATCH|DELETE`).
///
/// Wrapping a service yields a [`MethodOverrideService`]. That service only stamps
/// rejections on the request; `method_override_rejection_filter` turns them into responses.
#[derive(Clone, Debug)]
pub struct MethodOverrideLayer {
    /// Shared configuration handed to every service this layer builds.
    config: Arc<MethodOverrideConfig>,
    /// Cap on body bytes buffered while scanning for the override field.
    max_scan_bytes: usize,
}

impl MethodOverrideLayer {
    /// Creates a layer with the default configuration (field name `_method`).
    #[must_use]
    pub fn new() -> Self {
        let config = MethodOverrideConfig::default();
        Self::from_config(config)
    }

    /// Creates a layer from an explicit configuration.
    ///
    /// The layer buffers at most `MAX_BODY_SCAN_BYTES` (2 MiB) of each form body.
    #[must_use]
    pub fn from_config(config: MethodOverrideConfig) -> Self {
        let config = Arc::new(config);
        Self {
            max_scan_bytes: MAX_BODY_SCAN_BYTES,
            config,
        }
    }

    /// Sets the body-scan cap to `n` bytes.
    ///
    /// Values above `MAX_BODY_SCAN_BYTES` are clamped, so the cap can only be lowered.
    #[must_use]
    pub(crate) fn with_max_scan_bytes(self, n: usize) -> Self {
        Self {
            max_scan_bytes: std::cmp::min(n, MAX_BODY_SCAN_BYTES),
            ..self
        }
    }
}

impl Default for MethodOverrideLayer {
    /// Same as [`MethodOverrideLayer::new`].
    fn default() -> Self {
        MethodOverrideLayer::new()
    }
}

impl<S> Layer<S> for MethodOverrideLayer {
    type Service = MethodOverrideService<S>;

    /// Wraps `inner`, sharing this layer's configuration through the `Arc`.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            config,
            max_scan_bytes,
        } = self;
        MethodOverrideService {
            inner,
            config: Arc::clone(config),
            max_scan_bytes: *max_scan_bytes,
        }
    }
}
/// Service produced by `MethodOverrideLayer`.
///
/// It rewrites same-origin, form-urlencoded `POST`s whose override field names `PUT`,
/// `PATCH` or `DELETE`, then forwards every request to `inner`.
#[derive(Clone, Debug)]
pub struct MethodOverrideService<S> {
/// Wrapped service; receives every request, overridden or not.
inner: S,
/// Shared configuration (name of the override form field).
config: Arc<MethodOverrideConfig>,
/// Maximum number of body bytes buffered while scanning for the override field.
max_scan_bytes: usize,
}
/// Maps a submitted override value to the HTTP method it requests.
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive. Only `PUT`,
/// `PATCH` and `DELETE` are accepted; any other value yields `None`.
fn parse_override_value(value: &str) -> Option<http::Method> {
    match value.trim() {
        v if v.eq_ignore_ascii_case("PUT") => Some(http::Method::PUT),
        v if v.eq_ignore_ascii_case("PATCH") => Some(http::Method::PATCH),
        v if v.eq_ignore_ascii_case("DELETE") => Some(http::Method::DELETE),
        _ => None,
    }
}
/// Result of scanning a form body for the override field.
#[derive(Debug)]
enum OverrideOutcome {
/// The override field is absent; the request is forwarded unchanged.
None,
/// The field named a supported method; the request method is replaced with it.
Replace(http::Method),
/// The field was present but its value is not a supported method.
Invalid,
}
/// Finds the first `field` entry in a URL-encoded form body and classifies its value.
///
/// Keys and values are decoded by `form_urlencoded::parse` before comparison. Only the first
/// occurrence of `field` counts; later duplicates are never examined.
fn scan_form_for_override(bytes: &[u8], field: &str) -> OverrideOutcome {
    let first_match = url::form_urlencoded::parse(bytes).find(|(key, _)| key == field);
    match first_match {
        Some((_, value)) => parse_override_value(&value)
            .map_or(OverrideOutcome::Invalid, OverrideOutcome::Replace),
        None => OverrideOutcome::None,
    }
}
/// Returns `true` when the `Content-Type` media type is `application/x-www-form-urlencoded`.
///
/// Parameters such as `; charset=UTF-8` are ignored. The type/subtype comparison is ASCII
/// case-insensitive, because media types are case-insensitive (RFC 9110 §8.3.1). The
/// essence must match exactly, so look-alikes such as `application/x-www-form-urlencodedX`
/// are rejected. A missing header, or one that is not visible ASCII, yields `false`.
fn is_form_urlencoded(headers: &http::HeaderMap) -> bool {
    headers
        .get(http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|content_type| content_type.split(';').next())
        .is_some_and(|essence| {
            essence
                .trim()
                .eq_ignore_ascii_case("application/x-www-form-urlencoded")
        })
}
/// Decides whether a form submission may be treated as same-origin for method override.
///
/// - `Sec-Fetch-Site: same-origin` or `none` is accepted outright.
/// - Otherwise an `Origin` header is required. It is checked against the request's
///   host/scheme only when `Sec-Fetch-Site` is `same-site` or absent.
/// - Every other combination fails closed, e.g. `cross-site`, unknown values, or no
///   signals at all.
fn is_same_origin_form_request(req: &Request<axum::body::Body>) -> bool {
    let headers = req.headers();
    let fetch_site = headers
        .get("sec-fetch-site")
        .and_then(|v| v.to_str().ok());
    let origin = headers
        .get(http::header::ORIGIN)
        .and_then(|v| v.to_str().ok());

    match (fetch_site, origin) {
        (Some("same-origin" | "none"), _) => true,
        (Some("same-site") | None, Some(origin)) => {
            let identity = req.extensions().get::<ResolvedClientIdentity>();
            origin_matches_request_with_identity(origin, headers, identity)
        }
        _ => false,
    }
}
/// Returns `true` when `origin` names the same host (and, when known, scheme) as the request.
///
/// - Expected host: `identity.host` when a resolved identity is present, otherwise the raw
///   `X-Forwarded-Host` header. The `Host` header is the fallback in both cases.
/// - Expected scheme: `identity.scheme`, or without an identity the first comma-separated
///   entry of `X-Forwarded-Proto`. With no expected scheme, only the host is compared.
/// - Comparisons are ASCII case-insensitive.
/// - Malformed or non-HTTP(S) origins, and a missing expected host, yield `false`.
fn origin_matches_request_with_identity(
    origin: &str,
    headers: &http::HeaderMap,
    identity: Option<&ResolvedClientIdentity>,
) -> bool {
    let Some((origin_scheme, origin_authority)) = parse_origin(origin) else {
        return false;
    };
    let host_header = || {
        headers
            .get(http::header::HOST)
            .and_then(|v| v.to_str().ok())
    };
    // Both values are only compared, never stored, so borrow rather than allocate.
    let (expected_host, expected_scheme): (Option<&str>, Option<&str>) = match identity {
        // With a resolved identity, raw X-Forwarded-* headers are ignored. The resolver's
        // values are presumably the trusted-proxy-validated ones.
        Some(id) => (
            id.host.as_deref().or_else(host_header),
            id.scheme.as_deref(),
        ),
        None => (
            headers
                .get("x-forwarded-host")
                .and_then(|v| v.to_str().ok())
                .or_else(host_header),
            headers
                .get("x-forwarded-proto")
                .and_then(|v| v.to_str().ok())
                // Only the first comma-separated entry of a proxy chain is considered.
                .map(|s| s.split(',').next().unwrap_or(s).trim()),
        ),
    };

    let Some(expected_host) = expected_host else {
        return false;
    };
    if !origin_authority.eq_ignore_ascii_case(expected_host) {
        return false;
    }
    match expected_scheme {
        Some(scheme) => scheme.eq_ignore_ascii_case(origin_scheme),
        None => true,
    }
}
/// Splits an `Origin` header value into `(scheme, authority)`.
///
/// Only `http` and `https` schemes (any ASCII case) are accepted. Anything after the first
/// `/` following `://` is discarded. Returns `None` for opaque origins such as `null`,
/// values without `://`, other schemes, or an empty authority.
fn parse_origin(origin: &str) -> Option<(&str, &str)> {
    let (scheme, rest) = origin.split_once("://")?;
    let is_web_scheme = ["http", "https"]
        .iter()
        .any(|candidate| scheme.eq_ignore_ascii_case(candidate));
    if !is_web_scheme {
        return None;
    }
    let authority = rest.split_once('/').map_or(rest, |(head, _)| head);
    (!authority.is_empty()).then_some((scheme, authority))
}
impl<S, ResBody> Service<Request<axum::body::Body>> for MethodOverrideService<S>
where
S: Service<Request<axum::body::Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Send + 'static,
ResBody: From<&'static str> + Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<axum::body::Body>) -> Self::Future {
if req.method() != http::Method::POST
|| !is_form_urlencoded(req.headers())
|| !is_same_origin_form_request(&req)
{
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
return Box::pin(async move { inner.call(req).await });
}
let config = Arc::clone(&self.config);
let max_scan_bytes = self.max_scan_bytes;
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
let body = std::mem::replace(req.body_mut(), axum::body::Body::empty());
let Ok(bytes) = axum::body::to_bytes(body, max_scan_bytes).await else {
req.extensions_mut()
.insert(MethodOverrideRejection::BodyTooLarge);
let res = inner.call(req).await?;
if res.status() == StatusCode::METHOD_NOT_ALLOWED
|| res.status() == StatusCode::NOT_FOUND
{
use crate::middleware::exception_filter::AutumnErrorInfo;
let (parts, _body) = res.into_parts();
let mut res = Response::from_parts(
parts,
ResBody::from("Form body too large for method-override scanning."),
);
*res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("text/plain; charset=utf-8"),
);
res.extensions_mut().insert(AutumnErrorInfo {
status: StatusCode::PAYLOAD_TOO_LARGE,
message: "Form body too large for method-override scanning.".to_owned(),
details: None,
problem_type: None,
backtrace_string: None,
});
return Ok(res);
}
return Ok(res);
};
let outcome = scan_form_for_override(&bytes, &config.field_name);
*req.body_mut() = axum::body::Body::from(bytes);
match outcome {
OverrideOutcome::None => inner.call(req).await,
OverrideOutcome::Replace(method) => {
let transport = req.method().clone();
*req.method_mut() = method.clone();
req.extensions_mut().insert(OverriddenMethod {
transport,
effective: method,
});
inner.call(req).await
}
OverrideOutcome::Invalid => {
req.extensions_mut()
.insert(MethodOverrideRejection::InvalidValue);
let res = inner.call(req).await?;
if res.status() == StatusCode::METHOD_NOT_ALLOWED
|| res.status() == StatusCode::NOT_FOUND
{
use crate::middleware::exception_filter::AutumnErrorInfo;
let (parts, _body) = res.into_parts();
let mut res = Response::from_parts(
parts,
ResBody::from(
"Invalid method override value: must be PUT, PATCH, or DELETE.",
),
);
*res.status_mut() = StatusCode::BAD_REQUEST;
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("text/plain; charset=utf-8"),
);
res.extensions_mut().insert(AutumnErrorInfo {
status: StatusCode::BAD_REQUEST,
message:
"Invalid method override value: must be PUT, PATCH, or DELETE."
.to_owned(),
details: None,
problem_type: None,
backtrace_string: None,
});
return Ok(res);
}
Ok(res)
}
}
})
}
}
#[cfg(test)]
mod tests {
// Tests for the method-override layer. They drive `MethodOverrideService` with
// `tower::ServiceExt::oneshot` and assert on status codes and bodies. Most tests send
// `sec-fetch-site: same-origin` so the same-origin gate lets the override through.
use super::*;
use axum::Router;
use axum::body::Body;
use axum::routing::{delete, get, patch, post, put};
use tower::ServiceExt;
/// Router with `POST /items` and `PUT`/`PATCH`/`DELETE`/`GET /items/{id}`, with the
/// rejection filter applied as a router layer, wrapped in a default `MethodOverrideLayer`.
fn layered_router() -> MethodOverrideService<Router> {
let router = Router::new()
.route("/items", post(|| async { "created" }))
.route("/items/{id}", put(|| async { "put-ok" }))
.route("/items/{id}", patch(|| async { "patch-ok" }))
.route("/items/{id}", delete(|| async { "delete-ok" }))
.route("/items/{id}", get(|| async { "show" }))
.layer(axum::middleware::from_fn(method_override_rejection_filter));
MethodOverrideLayer::new().layer(router)
}
/// A form POST without `_method` is routed as a plain POST.
#[tokio::test]
async fn post_without_override_field_reaches_post_handler() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("title=hello"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"created");
}
/// `_method=PUT` reroutes the POST to the PUT handler.
#[tokio::test]
async fn post_with_method_put_reaches_put_handler() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=PUT&title=hi"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"put-ok");
}
/// Lower-case `_method=patch` is accepted (case-insensitive match).
#[tokio::test]
async fn post_with_method_patch_reaches_patch_handler() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=patch"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"patch-ok");
}
/// `_method=DELETE` reroutes the POST to the DELETE handler.
#[tokio::test]
async fn post_with_method_delete_reaches_delete_handler() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"delete-ok");
}
/// An unsupported override value is answered with 400.
#[tokio::test]
async fn invalid_override_value_returns_400() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=BREW"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
/// A JSON body containing `_method` is not scanned; the POST handler runs.
#[tokio::test]
async fn override_ignored_without_form_content_type() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items")
.header("content-type", "application/json")
.body(Body::from(r#"{"_method":"DELETE","title":"hi"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"created");
}
/// Non-POST requests pass straight through.
#[tokio::test]
async fn override_ignored_for_get_requests() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("GET")
.uri("/items/1")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"show");
}
/// The buffered form body is restored intact for the overridden handler.
#[tokio::test]
async fn override_preserves_body_for_handler() {
async fn echo(body: String) -> String {
body
}
let router = Router::new().route("/echo", put(echo));
let app = MethodOverrideLayer::new().layer(router);
let payload = "_method=PUT&title=hello&count=3";
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/echo")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from(payload))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(std::str::from_utf8(&body).unwrap(), payload);
}
/// The `OverriddenMethod` extension records transport (POST) and effective (DELETE) methods.
#[tokio::test]
async fn override_marks_request_extension() {
async fn marker(axum::Extension(o): axum::Extension<OverriddenMethod>) -> String {
format!("{}->{}", o.transport, o.effective)
}
let router = Router::new().route("/x", delete(marker));
let app = MethodOverrideLayer::new().layer(router);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/x")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"POST->DELETE");
}
/// A custom `field_name` is used instead of `_method`.
#[tokio::test]
async fn custom_field_name_is_respected() {
let router = Router::new().route("/items/{id}", delete(|| async { "gone" }));
let app = MethodOverrideLayer::from_config(MethodOverrideConfig {
field_name: "x-method".into(),
})
.layer(router);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("x-method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"gone");
}
/// With CSRF enabled, an overridden DELETE carrying only the cookie (no form token) gets 403.
#[tokio::test]
async fn overridden_delete_without_csrf_token_is_rejected() {
let csrf_config = crate::security::CsrfConfig {
enabled: true,
..Default::default()
};
let router = Router::new()
.route("/items/{id}", delete(|| async { "deleted" }))
.layer(crate::security::CsrfLayer::from_config(&csrf_config));
let app = MethodOverrideLayer::new().layer(router);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.header(http::header::ACCEPT, "text/html")
.header("Cookie", "autumn-csrf=valid-cookie-token")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
/// With CSRF enabled, an overridden DELETE whose `_csrf` field matches the cookie succeeds.
#[tokio::test]
async fn overridden_delete_with_csrf_token_reaches_handler() {
let csrf_config = crate::security::CsrfConfig {
enabled: true,
..Default::default()
};
let router = Router::new()
.route("/items/{id}", delete(|| async { "deleted" }))
.layer(crate::security::CsrfLayer::from_config(&csrf_config));
let app = MethodOverrideLayer::new().layer(router);
let token = "valid-cookie-token";
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.header("Cookie", format!("autumn-csrf={token}"))
.body(Body::from(format!("_csrf={token}&_method=DELETE")))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), 1024)
.await
.unwrap();
assert_eq!(&body[..], b"deleted");
}
/// A ~3 MiB form exceeds the 2 MiB scan cap and gets 413, even though the route's own
/// body limit (8 MiB) would allow it.
#[tokio::test]
async fn oversized_form_post_without_override_field_is_rejected() {
async fn measure(body: bytes::Bytes) -> String {
format!("{}", body.len())
}
let router = Router::new()
.route("/upload", post(measure))
.layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024))
.layer(axum::middleware::from_fn(method_override_rejection_filter));
let app = MethodOverrideLayer::new().layer(router);
let big = "x".repeat(3 * 1024 * 1024);
let payload = format!("title={big}");
let payload_len = payload.len();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/upload")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.header("content-length", payload_len.to_string())
.body(Body::from(payload))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
/// Same as above, but without an explicit `content-length` header.
#[tokio::test]
async fn unbounded_oversized_form_post_returns_413() {
async fn handler(body: String) -> String {
body
}
let router = Router::new()
.route("/x", post(handler))
.layer(axum::middleware::from_fn(method_override_rejection_filter));
let app = MethodOverrideLayer::new().layer(router);
let payload = "x".repeat(3 * 1024 * 1024);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/x")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from(payload))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
// Shared state for `capture_rejection`: records the last rejection marker it observed.
use std::sync::Mutex;
use std::sync::OnceLock;
/// Last `MethodOverrideRejection` seen by `capture_rejection` (test-only global).
static OBSERVED_REJECTION: OnceLock<Mutex<Option<MethodOverrideRejection>>> = OnceLock::new();
/// Middleware that records any stamped rejection into `OBSERVED_REJECTION`, then continues.
async fn capture_rejection(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
if let Some(rej) = request
.extensions()
.get::<MethodOverrideRejection>()
.copied()
{
*OBSERVED_REJECTION
.get_or_init(|| Mutex::new(None))
.lock()
.unwrap() = Some(rej);
}
next.run(request).await
}
/// Without the rejection filter, an invalid override is only stamped. The request still
/// reaches the POST handler (200) and the marker is visible downstream.
#[tokio::test]
async fn outer_layer_stamps_extension_without_short_circuiting() {
let cell = OBSERVED_REJECTION.get_or_init(|| Mutex::new(None));
cell.lock().unwrap().take();
let router = Router::new()
.route("/x", post(|| async { "post-ok" }))
.layer(axum::middleware::from_fn(capture_rejection));
let app = MethodOverrideLayer::new().layer(router);
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/x")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=BREW"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
*cell.lock().unwrap(),
Some(MethodOverrideRejection::InvalidValue)
);
}
/// `sec-fetch-site: cross-site` disables the override, so the POST hits a 405.
#[tokio::test]
async fn cross_site_form_is_not_honoured_as_override() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "cross-site")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response.status(),
StatusCode::METHOD_NOT_ALLOWED,
"cross-site override must not be applied; expected the inner \
router to reject the underlying POST"
);
}
/// An `Origin` whose host differs from `Host` disables the override.
#[tokio::test]
async fn origin_host_mismatch_is_not_honoured_as_override() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://evil.example")
.header("host", "app.example")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
}
/// An `Origin` whose host equals `Host` enables the override without Sec-Fetch-Site.
#[tokio::test]
async fn origin_host_match_is_honoured_as_override() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://app.example")
.header("host", "app.example")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
/// `same-site` still requires the Origin host to match; a sibling host is rejected.
#[tokio::test]
async fn same_site_sibling_origin_is_not_honoured_as_override() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-site")
.header("origin", "https://evil.example")
.header("host", "app.example")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response.status(),
StatusCode::METHOD_NOT_ALLOWED,
"same-site with mismatched Origin/Host must not be honoured as override"
);
}
/// With neither Origin nor Sec-Fetch-Site, the override is not applied (fail closed).
#[tokio::test]
async fn missing_origin_signals_fail_closed() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
}
/// An `http` Origin against `X-Forwarded-Proto: https` is a scheme mismatch and is rejected.
#[tokio::test]
async fn origin_scheme_mismatch_via_forwarded_proto_is_rejected() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "http://app.example")
.header("host", "app.example")
.header("x-forwarded-proto", "https")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response.status(),
StatusCode::METHOD_NOT_ALLOWED,
"different Origin scheme is not same-origin"
);
}
/// Only the first entry of a chained `X-Forwarded-Proto` is compared with the Origin scheme.
#[tokio::test]
async fn origin_scheme_match_via_chained_forwarded_proto() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://app.example")
.header("host", "app.example")
.header("x-forwarded-proto", "https, http")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
/// Without a resolved identity, `X-Forwarded-Host` is preferred over `Host`.
#[tokio::test]
async fn forwarded_host_takes_precedence_over_host_header() {
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://app.example")
.header("host", "internal.cluster.local")
.header("x-forwarded-host", "app.example")
.header("x-forwarded-proto", "https")
.body(Body::from("_method=DELETE"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
/// `parse_origin` accepts only well-formed http(s) origins with a non-empty authority.
#[test]
fn parse_origin_accepts_http_and_https_only() {
assert_eq!(
parse_origin("https://app.example"),
Some(("https", "app.example"))
);
assert_eq!(
parse_origin("http://app.example:8080"),
Some(("http", "app.example:8080"))
);
assert_eq!(parse_origin("null"), None);
assert_eq!(parse_origin("file:///etc/passwd"), None);
assert_eq!(parse_origin("javascript:alert(1)"), None);
assert_eq!(parse_origin("app.example"), None);
assert_eq!(parse_origin("https://"), None);
}
/// Builds a request (builder defaults, empty body) carrying exactly `headers`.
fn req_from_headers(headers: http::HeaderMap) -> axum::extract::Request<Body> {
let mut req = axum::http::Request::builder().body(Body::empty()).unwrap();
*req.headers_mut() = headers;
req
}
/// Exercises `is_same_origin_form_request` across Sec-Fetch-Site / Origin / Host combinations.
#[test]
fn is_same_origin_form_decision_matrix() {
for site in ["same-origin", "none"] {
let mut h = http::HeaderMap::new();
h.insert("sec-fetch-site", http::HeaderValue::from_str(site).unwrap());
assert!(
is_same_origin_form_request(&req_from_headers(h)),
"sec-fetch-site={site} should be allowed"
);
}
let mut h = http::HeaderMap::new();
h.insert(
"sec-fetch-site",
http::HeaderValue::from_static("same-site"),
);
assert!(
!is_same_origin_form_request(&req_from_headers(h)),
"same-site alone must not be accepted"
);
let mut h = http::HeaderMap::new();
h.insert(
"sec-fetch-site",
http::HeaderValue::from_static("same-site"),
);
h.insert(
http::header::ORIGIN,
http::HeaderValue::from_static("https://app.example"),
);
h.insert(
http::header::HOST,
http::HeaderValue::from_static("app.example"),
);
assert!(is_same_origin_form_request(&req_from_headers(h)));
let mut h = http::HeaderMap::new();
h.insert(
"sec-fetch-site",
http::HeaderValue::from_static("same-site"),
);
h.insert(
http::header::ORIGIN,
http::HeaderValue::from_static("https://evil.example"),
);
h.insert(
http::header::HOST,
http::HeaderValue::from_static("app.example"),
);
assert!(
!is_same_origin_form_request(&req_from_headers(h)),
"same-site with mismatched Origin/Host must be rejected"
);
let mut h = http::HeaderMap::new();
h.insert(
"sec-fetch-site",
http::HeaderValue::from_static("cross-site"),
);
assert!(!is_same_origin_form_request(&req_from_headers(h)));
let mut h = http::HeaderMap::new();
h.insert(
"sec-fetch-site",
http::HeaderValue::from_static("undefined"),
);
assert!(!is_same_origin_form_request(&req_from_headers(h)));
let mut h = http::HeaderMap::new();
h.insert(
http::header::ORIGIN,
http::HeaderValue::from_static("https://app.example"),
);
h.insert(
http::header::HOST,
http::HeaderValue::from_static("app.example"),
);
assert!(is_same_origin_form_request(&req_from_headers(h)));
let mut h = http::HeaderMap::new();
h.insert(
http::header::ORIGIN,
http::HeaderValue::from_static("https://evil.example"),
);
h.insert(
http::header::HOST,
http::HeaderValue::from_static("app.example"),
);
assert!(!is_same_origin_form_request(&req_from_headers(h)));
let h = http::HeaderMap::new();
assert!(!is_same_origin_form_request(&req_from_headers(h)));
}
/// The 400 rejection response carries an `AutumnErrorInfo` extension with status and message.
#[tokio::test]
async fn rejection_response_carries_autumn_error_info() {
use crate::middleware::exception_filter::AutumnErrorInfo;
let app = layered_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=BREW"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let info = response
.extensions()
.get::<AutumnErrorInfo>()
.expect("override rejection must carry AutumnErrorInfo");
assert_eq!(info.status, StatusCode::BAD_REQUEST);
assert!(info.message.contains("PUT, PATCH, or DELETE"));
}
/// A custom 10-byte scan cap rejects a longer form body with 413.
#[tokio::test]
async fn with_max_scan_bytes_rejects_body_exceeding_custom_cap() {
let router = Router::new()
.route("/items", post(|| async { "ok" }))
.layer(axum::middleware::from_fn(method_override_rejection_filter));
let service =
tower::Layer::layer(&MethodOverrideLayer::new().with_max_scan_bytes(10), router);
let response = service
.oneshot(
Request::builder()
.method("POST")
.uri("/items")
.header("content-type", "application/x-www-form-urlencoded")
.header("sec-fetch-site", "same-origin")
.body(Body::from("_method=DELETE&x=123456789012"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
/// With a `ResolvedClientIdentity`, its host/scheme are used and here match the Origin.
#[tokio::test]
async fn pr791_poc_with_resolved_identity_uses_validated_host() {
use crate::security::trusted_proxies::ResolvedClientIdentity;
let app = layered_router();
let mut req = Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://app.example")
.header("host", "internal.cluster.local")
.header("x-forwarded-host", "app.example")
.header("x-forwarded-proto", "https")
.body(Body::from("_method=DELETE"))
.unwrap();
req.extensions_mut().insert(ResolvedClientIdentity {
addr: None,
host: Some("app.example".to_owned()),
scheme: Some("https".to_owned()),
});
let response = app.oneshot(req).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
/// With a `ResolvedClientIdentity`, raw `X-Forwarded-Host` is ignored; the identity's
/// different host causes the override to be refused.
#[tokio::test]
async fn pr791_poc_resolved_identity_different_host_is_rejected() {
use crate::security::trusted_proxies::ResolvedClientIdentity;
let app = layered_router();
let mut req = Request::builder()
.method("POST")
.uri("/items/1")
.header("content-type", "application/x-www-form-urlencoded")
.header("origin", "https://app.example")
.header("host", "internal.cluster.local")
.header("x-forwarded-host", "app.example")
.header("x-forwarded-proto", "https")
.body(Body::from("_method=DELETE"))
.unwrap();
req.extensions_mut().insert(ResolvedClientIdentity {
addr: None,
host: Some("internal.cluster.local".to_owned()),
scheme: Some("http".to_owned()),
});
let response = app.oneshot(req).await.unwrap();
assert_eq!(
response.status(),
StatusCode::METHOD_NOT_ALLOWED,
"resolver-validated host mismatch must reject the override"
);
}
}