#![allow(deprecated)]
use axum_core::{__define_rejection as define_rejection, extract::FromRequestParts};
use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};
/// Header consulted by [`Scheme`] as a fallback when the `Forwarded` header
/// carries no `proto` parameter. Its value is used verbatim (not split on commas).
const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";
/// Extractor that resolves the URI scheme (for example `http` or `https`) of a request.
///
/// The scheme is taken from the first of these sources that provides one:
///
/// 1. the `proto` parameter of the first element of the `Forwarded` header,
/// 2. the `X-Forwarded-Proto` header,
/// 3. the scheme component of the request URI.
///
/// If none of them yields a value, extraction is rejected with [`SchemeMissing`].
///
/// Header values are copied verbatim and are not validated. Unless a trusted proxy
/// overwrites these headers, treat the result as client-supplied data.
#[deprecated = "will be removed in the next version; see https://github.com/tokio-rs/axum/issues/3442"]
#[derive(Debug, Clone)]
pub struct Scheme(pub String);
define_rejection! {
#[status = BAD_REQUEST]
#[body = "No scheme found in request"]
/// Rejection used by the [`Scheme`] extractor when neither the `Forwarded` /
/// `X-Forwarded-Proto` headers nor the request URI provide a scheme.
pub struct SchemeMissing;
}
impl<S> FromRequestParts<S> for Scheme
where
    S: Send + Sync,
{
    type Rejection = SchemeMissing;

    /// Resolves the scheme using these sources, highest precedence first:
    ///
    /// 1. the `Forwarded` header's `proto` parameter,
    /// 2. the `X-Forwarded-Proto` header,
    /// 3. the request URI.
    ///
    /// Each later source is only consulted when every earlier one yields nothing.
    /// Fails with [`SchemeMissing`] if no source provides a scheme.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let headers = &parts.headers;
        let uri = &parts.uri;

        parse_forwarded(headers)
            .or_else(|| {
                // Non-visible-ASCII header values are ignored rather than rejected.
                headers
                    .get(X_FORWARDED_PROTO_HEADER_KEY)
                    .and_then(|value| value.to_str().ok())
            })
            .or_else(|| uri.scheme_str())
            .map(|found| Scheme(found.to_owned()))
            .ok_or(SchemeMissing)
    }
}
/// Extracts the `proto` parameter from the `Forwarded` header, if present.
///
/// Only the first comma-separated element of the first `Forwarded` header value
/// is inspected. Parameter names are compared case-insensitively, and whitespace
/// and double quotes surrounding the value are stripped.
///
/// Returns `None` in any of these cases:
///
/// - the header is absent,
/// - the header is not valid visible ASCII,
/// - the first element has no `proto=` pair.
///
/// NOTE(review): quoted values that themselves contain `,` or `;` are split
/// naively on those characters.
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
    let header = headers.get(FORWARDED)?;
    let raw = header.to_str().ok()?;
    let first_element = raw.split(',').next()?;

    for pair in first_element.split(';') {
        // Segments without `=` are not key/value pairs; skip them.
        let Some((name, value)) = pair.split_once('=') else {
            continue;
        };
        if name.trim().eq_ignore_ascii_case("proto") {
            return Some(value.trim().trim_matches('"'));
        }
    }

    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::header::HeaderName;

    /// Builds a client for a router whose only handler echoes the extracted scheme.
    fn test_client() -> TestClient {
        async fn echo_scheme(Scheme(found): Scheme) -> String {
            found
        }

        let app = Router::new().route("/", get(echo_scheme));
        TestClient::new(app)
    }

    #[crate::test]
    async fn forwarded_scheme_parsing() {
        let cases = [
            // `proto` may appear anywhere within an element.
            (
                vec![(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")],
                "http",
            ),
            // Parameter names match case-insensitively.
            (
                vec![(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")],
                "https",
            ),
            // Only the first comma-separated element is considered.
            (vec![(FORWARDED, "proto=ftp, proto=https")], "ftp"),
            // Only the first header line is considered when the header repeats.
            (
                vec![(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")],
                "ftp",
            ),
        ];

        for (entries, expected) in cases {
            let headers = header_map(&entries);
            assert_eq!(parse_forwarded(&headers), Some(expected));
        }
    }

    #[crate::test]
    async fn x_forwarded_scheme_header() {
        let sent = "https";
        let body = test_client()
            .get("/")
            .header(X_FORWARDED_PROTO_HEADER_KEY, sent)
            .await
            .text()
            .await;
        assert_eq!(body, sent);
    }

    #[crate::test]
    async fn precedence_forwarded_over_x_forwarded() {
        let client = test_client();
        let body = client
            .get("/")
            .header(X_FORWARDED_PROTO_HEADER_KEY, "https")
            .header(FORWARDED, "proto=ftp")
            .await
            .text()
            .await;
        assert_eq!(body, "ftp");
    }

    /// Builds a `HeaderMap`, appending (not replacing) so repeated names keep every value.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        values
            .iter()
            .fold(HeaderMap::new(), |mut map, (name, raw)| {
                map.append(name, raw.parse().unwrap());
                map
            })
    }
}