use axum::{
body::Body,
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
/// Configuration for the admin API authentication middleware.
///
/// `Debug` is implemented by hand (instead of derived) so the bearer token is never
/// written out when this struct is logged or printed with `{:?}`.
#[derive(Clone)]
pub struct AdminAuth {
    /// Token that admin requests must present as `Authorization: Bearer <token>`.
    /// When `None`, the middleware only admits requests it judges to be from loopback.
    pub token: Option<String>,
}

impl std::fmt::Debug for AdminAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AdminAuth")
            // Show whether a token is set, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Axum middleware that guards the admin API.
///
/// * If `AdminAuth::token` is set, the request must carry a matching
///   `Authorization: Bearer <token>` header (checked by `bearer_token_matches`).
/// * If no token is configured, the request is only forwarded when `is_loopback`
///   accepts it.
///
/// Rejected requests receive the 401 response built by `unauthorized_response`.
pub async fn admin_auth_middleware(
    axum::extract::Extension(auth): axum::extract::Extension<AdminAuth>,
    req: Request<Body>,
    next: Next,
) -> Response {
    match auth.token.as_deref() {
        Some(expected) if bearer_token_matches(req.headers(), expected) => next.run(req).await,
        Some(_) => unauthorized_response("Missing or invalid admin bearer token"),
        None if is_loopback(&req) => next.run(req).await,
        None => unauthorized_response(
            "Admin API requires a bearer token when not accessed from loopback",
        ),
    }
}
/// Returns `true` when the `Authorization` header carries `expected` as a Bearer token.
///
/// * The auth scheme is matched case-insensitively (`Bearer`, `bearer`, ...), and one or
///   more spaces may separate it from the token, as HTTP credentials syntax allows.
/// * An empty `expected` never matches. A blank configured token therefore locks the API
///   instead of admitting `Authorization: Bearer `.
/// * The token comparison does not stop at the first differing byte, so response timing
///   does not reveal how much of a guessed token was correct.
fn bearer_token_matches(headers: &header::HeaderMap, expected: &str) -> bool {
    if expected.is_empty() {
        return false;
    }
    let Some(value) = headers.get(header::AUTHORIZATION) else {
        return false;
    };
    let Ok(credentials) = value.to_str() else {
        return false;
    };
    let Some((scheme, token)) = credentials.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return false;
    }
    constant_time_eq(token.trim_start_matches(' ').as_bytes(), expected.as_bytes())
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// The lengths are compared up front, so the length of the secret is still observable;
/// only the contents are protected from timing analysis.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Decides whether a request should be treated as coming from the local machine.
///
/// The middleware consults this only when no admin bearer token is configured.
/// The checks are:
///
/// * Every address listed in `X-Forwarded-For` must be loopback. This covers all entries
///   and every repeated header line. The leftmost entry is supplied by the client, so
///   trusting only it lets a remote caller pass by sending `X-Forwarded-For: 127.0.0.1`
///   through a proxy that appends the real address.
/// * Every address in `X-Real-IP` must likewise be loopback.
/// * A forwarding header that is present but not valid UTF-8, or holds an entry that is
///   not an IP address, counts as non-loopback (fail closed).
///
/// When neither header is present, the request is assumed to be local.
/// NOTE(review): this relies on the admin listener being bound to loopback whenever no
/// token is configured (see `ensure_admin_security`) — confirm that it is called at startup.
fn is_loopback(req: &Request<Body>) -> bool {
    let headers = req.headers();
    ["x-forwarded-for", "x-real-ip"].iter().all(|name| {
        headers.get_all(*name).iter().all(|value| {
            let Ok(text) = value.to_str() else {
                return false;
            };
            text.split(',')
                .map(str::trim)
                // Empty list elements (e.g. a trailing comma) name no address.
                .filter(|entry| !entry.is_empty())
                .all(|entry| matches!(parse_forwarded_ip(entry), Some(ip) if ip_is_loopback(ip)))
        })
    })
}

/// Parses one forwarded-address entry: a bare IP (`10.0.0.1`, `::1`), an IP with a port
/// (`10.0.0.1:443`, `[::1]:443`), or a bracketed IPv6 address (`[::1]`).
fn parse_forwarded_ip(entry: &str) -> Option<std::net::IpAddr> {
    if let Ok(ip) = entry.parse::<std::net::IpAddr>() {
        return Some(ip);
    }
    if let Ok(socket) = entry.parse::<std::net::SocketAddr>() {
        return Some(socket.ip());
    }
    entry.strip_prefix('[')?.strip_suffix(']')?.parse().ok()
}

/// Loopback test that also accepts IPv4-mapped IPv6 loopback (`::ffff:127.0.0.1`),
/// which dual-stack proxies may report.
fn ip_is_loopback(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => v4.is_loopback(),
        std::net::IpAddr::V6(v6) => {
            v6.is_loopback() || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_loopback())
        }
    }
}
fn unauthorized_response(message: &str) -> Response {
let body = serde_json::json!({
"error": {
"message": message,
"type": "authentication_error",
}
});
(StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
}
/// Aborts the process if the admin API would be reachable off-host without authentication.
///
/// `admin_host` is the admin listen address, e.g. `127.0.0.1:8888`, `[::1]:8888`, `::1`,
/// or `localhost:8888`. A token that is empty or whitespace-only is treated as not
/// configured, because it protects nothing. Without a usable token, a listen host that is
/// neither a loopback IP nor `localhost` (case-insensitive) is fatal: an explanation is
/// printed to stderr and the process exits with status 1.
pub fn ensure_admin_security(admin_host: &str, token: &Option<String>) {
    let has_token = matches!(token.as_deref(), Some(t) if !t.trim().is_empty());
    if has_token || listen_host_is_loopback(admin_host) {
        return;
    }
    eprintln!(
        "FATAL: admin listen address '{admin_host}' is non-loopback \
        but no admin bearer_token is configured.\n\
        Set [admin] bearer_token = \"...\" in server.toml or \
        bind the admin interface to 127.0.0.1."
    );
    std::process::exit(1);
}

/// Returns `true` if the host part of a listen address is a loopback IP or `localhost`.
///
/// Handles bare IPs (`127.0.0.1`, `::1`), `host:port`, and bracketed IPv6 with or without
/// a port (`[::1]`, `[::1]:8888`). Splitting naively on the first `:` would mangle IPv6.
fn listen_host_is_loopback(admin_host: &str) -> bool {
    let addr = admin_host.trim();
    if let Ok(ip) = addr.parse::<std::net::IpAddr>() {
        return ip.is_loopback();
    }
    let host = match addr.strip_prefix('[') {
        Some(bracketed) => bracketed.split_once(']').map_or(bracketed, |(inner, _)| inner),
        None => addr.rsplit_once(':').map_or(addr, |(name, _port)| name),
    };
    let host = host.trim();
    matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
        || host.eq_ignore_ascii_case("localhost")
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Builds a body-less request, attaching `Authorization: Bearer <token>` when a token is given.
    fn bearer_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder();
        if let Some(token) = token {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder.body(Body::empty()).expect("build request")
    }

    #[test]
    fn bearer_matches_correct_token() {
        let request = bearer_request(Some("secret"));
        assert!(bearer_token_matches(request.headers(), "secret"));
    }

    #[test]
    fn bearer_rejects_wrong_token() {
        let request = bearer_request(Some("wrong"));
        assert!(!bearer_token_matches(request.headers(), "secret"));
    }

    #[test]
    fn bearer_rejects_missing_header() {
        let request = bearer_request(None);
        assert!(!bearer_token_matches(request.headers(), "secret"));
    }

    #[test]
    fn bearer_rejects_basic_scheme() {
        // A non-Bearer scheme must never be accepted, even with valid-looking credentials.
        let request = Request::builder()
            .header("authorization", "Basic dXNlcjpwYXNz")
            .body(Body::empty())
            .expect("build");
        assert!(!bearer_token_matches(request.headers(), "secret"));
    }

    #[test]
    fn ensure_admin_security_passes_with_token() {
        // A configured token makes any bind address acceptable; this must not exit.
        ensure_admin_security("0.0.0.0:8888", &Some("tok".to_string()));
    }

    #[test]
    fn ensure_admin_security_passes_loopback_no_token() {
        // Loopback binds are allowed without a token; none of these may exit.
        for host in ["127.0.0.1:8888", "localhost:8888"] {
            ensure_admin_security(host, &None);
        }
    }
}