use axum::{
extract::{Request, State},
http::{Method, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
/// Credentials a client must present to access the dashboard.
///
/// `Debug` is implemented by hand so that secrets (the Basic password and the
/// bearer token) are never written to logs or panic messages via `{:?}`.
#[derive(Clone)]
pub enum DashboardAuth {
    /// HTTP Basic authentication against a single fixed username/password pair.
    Basic { username: String, password: String },
    /// Bearer-token authentication against a single fixed token.
    Bearer { token: String },
}

impl std::fmt::Debug for DashboardAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only non-secret data is rendered; secret fields are replaced by a marker.
        match self {
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &"<redacted>")
                .finish(),
        }
    }
}
/// Axum middleware that lets a request through only when its `Authorization`
/// header satisfies the configured `DashboardAuth`.
///
/// Basic credentials are verified with `check_basic`, bearer tokens with
/// `check_bearer`. A missing header, or one that is not valid visible ASCII,
/// counts as unauthenticated. Rejected requests get the response built by
/// `unauthorized`.
pub async fn require_auth(
    State(auth): State<Arc<DashboardAuth>>,
    req: Request,
    next: Next,
) -> Response {
    let authorized = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|credentials| match auth.as_ref() {
            DashboardAuth::Basic { username, password } => {
                check_basic(credentials, username, password)
            }
            DashboardAuth::Bearer { token } => check_bearer(credentials, token),
        });

    if !authorized {
        return unauthorized(&auth);
    }
    next.run(req).await
}
/// Axum middleware that rejects cross-origin state-changing requests.
///
/// Safe methods (anything other than POST, PUT, PATCH or DELETE) pass through
/// untouched. For mutations, the authority of the `Origin` header (or of
/// `Referer` when `Origin` is absent) must equal the authority the request was
/// sent to, compared case-insensitively. If either side is missing or
/// unparsable, the request is refused with `403 Forbidden`.
pub async fn csrf_guard(req: Request, next: Next) -> Response {
    let is_mutation = matches!(
        *req.method(),
        Method::POST | Method::PUT | Method::PATCH | Method::DELETE
    );
    if !is_mutation {
        return next.run(req).await;
    }

    // Compute the verdict in its own scope so the header borrows end before
    // `req` is moved into `next.run`.
    let same_origin = {
        let headers = req.headers();
        // HTTP/1.1 sends the target authority in `Host`; HTTP/2 sends it as the
        // `:authority` pseudo-header, which is surfaced through the request URI
        // rather than as a `Host` header. Fall back to the URI so mutations over
        // HTTP/2 are not rejected wholesale.
        let host = headers
            .get(header::HOST)
            .and_then(|v| v.to_str().ok())
            .or_else(|| req.uri().authority().map(|a| a.as_str()));
        // `Referer` is consulted only when `Origin` is absent; a present but
        // unusable `Origin` (e.g. "null") is a failure, not a fallback.
        let source = headers
            .get(header::ORIGIN)
            .or_else(|| headers.get(header::REFERER))
            .and_then(|v| v.to_str().ok())
            .and_then(authority_of);
        matches!(
            (host, source),
            (Some(host), Some(source)) if host.eq_ignore_ascii_case(&source)
        )
    };

    if same_origin {
        next.run(req).await
    } else {
        (StatusCode::FORBIDDEN, "CSRF check failed").into_response()
    }
}
/// Returns `true` if `host` names the local machine.
///
/// Accepts `localhost` (any ASCII case), a bare IPv4 or IPv6 loopback address
/// (any address in 127.0.0.0/8, or `::1`), or an IPv6 loopback literal in a
/// single pair of square brackets (`[::1]`).
///
/// Malformed inputs are rejected rather than guessed at. This covers
/// unbalanced or doubled brackets (`[::1`, `[[::1]]`) and bracketed IPv4
/// (`[127.0.0.1]`). RFC 3986 §3.2.2 reserves brackets for IPv6 (and IPvFuture)
/// literals.
pub(crate) fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    match host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
    {
        Some(literal) => literal
            .parse::<std::net::Ipv6Addr>()
            .is_ok_and(|ip| ip.is_loopback()),
        None => host
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback()),
    }
}
/// Builds the `401 Unauthorized` response sent when authentication fails.
///
/// RFC 7235 §3.1 requires a 401 response to carry at least one
/// `WWW-Authenticate` challenge, so a challenge matching the configured scheme
/// is always attached. For Basic it also prompts browsers for credentials; for
/// Bearer it follows the RFC 6750 §3 form.
fn unauthorized(auth: &DashboardAuth) -> Response {
    let challenge = match auth {
        DashboardAuth::Basic { .. } => "Basic realm=\"qml-dashboard\"",
        DashboardAuth::Bearer { .. } => "Bearer realm=\"qml-dashboard\"",
    };
    let mut resp = (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
    // The challenge strings are compile-time constants of visible ASCII, so
    // `from_static` cannot fail.
    resp.headers_mut().insert(
        header::WWW_AUTHENTICATE,
        header::HeaderValue::from_static(challenge),
    );
    resp
}
/// Returns `true` if `header_value` is an HTTP Basic credential whose decoded
/// `user:password` pair matches `expected_user` / `expected_pass`.
///
/// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires.
/// Whitespace around the base64 payload is ignored. The payload must decode to
/// UTF-8. Any malformed input yields `false`.
fn check_basic(header_value: &str, expected_user: &str, expected_pass: &str) -> bool {
    let Some((scheme, encoded)) = header_value.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Basic") {
        return false;
    }
    let Some(bytes) = base64_decode(encoded.trim()) else {
        return false;
    };
    let Ok(decoded) = std::str::from_utf8(&bytes) else {
        return false;
    };
    // A user-id may not contain ':' (RFC 7617 §2), so split at the first colon;
    // the password keeps any further colons.
    let Some((user, pass)) = decoded.split_once(':') else {
        return false;
    };
    // Evaluate both comparisons and combine with non-short-circuiting `&`, so
    // the work done does not reveal which of the two fields mismatched.
    let user_ok = constant_time_eq(user.as_bytes(), expected_user.as_bytes());
    let pass_ok = constant_time_eq(pass.as_bytes(), expected_pass.as_bytes());
    user_ok & pass_ok
}
/// Returns `true` if `header_value` is a `Bearer` credential whose token equals
/// `expected`.
///
/// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires.
/// Whitespace around the token is ignored. The token itself is compared with
/// `constant_time_eq`.
fn check_bearer(header_value: &str, expected: &str) -> bool {
    let Some((scheme, token)) = header_value.split_once(' ') else {
        return false;
    };
    // The scheme name is not secret, so short-circuiting on it is fine.
    scheme.eq_ignore_ascii_case("Bearer")
        && constant_time_eq(token.trim().as_bytes(), expected.as_bytes())
}
/// Byte-wise equality whose running time does not depend on where the two
/// inputs first differ.
///
/// Inputs of different lengths are rejected immediately. Only the length,
/// never the content, can therefore influence timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.len() == b.len() && bool::from(a.ct_eq(b))
}
/// Extracts the lower-cased authority (`host[:port]`) from an absolute URL.
///
/// The authority is whatever follows the first `://`, up to the first `/`, `?`
/// or `#`. Returns `None` when the URL has no `://` or the authority is empty.
fn authority_of(url: &str) -> Option<String> {
    let after_scheme = &url[url.find("://")? + "://".len()..];
    let end = after_scheme
        .find(['/', '?', '#'])
        .unwrap_or(after_scheme.len());
    let authority = &after_scheme[..end];
    (!authority.is_empty()).then(|| authority.to_ascii_lowercase())
}
fn base64_decode(input: &str) -> Option<Vec<u8>> {
use base64::Engine as _;
base64::engine::general_purpose::STANDARD.decode(input).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Method, Request, StatusCode, header},
        middleware,
        routing::{delete, get, post},
    };
    use tower::ServiceExt;

    // `Authorization` value carrying `admin:secret` as Basic credentials.
    const ADMIN_SECRET_BASIC: &str = "Basic YWRtaW46c2VjcmV0";

    #[test]
    fn base64_decodes_standard_input() {
        let valid: [(&str, &[u8]); 3] = [
            ("YWRtaW46c2VjcmV0", b"admin:secret".as_slice()),
            ("Zm9vOmJhcg==", b"foo:bar".as_slice()),
            ("", b"".as_slice()),
        ];
        for (encoded, expected) in valid {
            assert_eq!(
                base64_decode(encoded).as_deref(),
                Some(expected),
                "input {encoded:?}"
            );
        }
        assert!(base64_decode("not*base64!").is_none());
    }

    #[test]
    fn check_basic_accepts_matching_creds() {
        assert!(check_basic(ADMIN_SECRET_BASIC, "admin", "secret"));
        let rejected = [
            (ADMIN_SECRET_BASIC, "admin", "wrong"),
            (ADMIN_SECRET_BASIC, "root", "secret"),
            ("Bearer xyz", "admin", "secret"),
            ("Basic !!!", "admin", "secret"),
        ];
        for (value, user, pass) in rejected {
            assert!(
                !check_basic(value, user, pass),
                "{value:?} was accepted for {user}:{pass}"
            );
        }
    }

    #[test]
    fn check_bearer_accepts_matching_token() {
        assert!(check_bearer("Bearer abc123", "abc123"));
        assert!(!check_bearer("Bearer abc123", "xyz"));
        assert!(!check_bearer("Basic abc123", "abc123"));
    }

    #[test]
    fn constant_time_eq_handles_length_mismatch() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(!constant_time_eq(b"abc", b"abd"));
    }

    #[test]
    fn authority_of_strips_path_and_scheme() {
        let cases = [
            ("http://example.com:8080/foo?bar=1", Some("example.com:8080")),
            ("https://DASH.local", Some("dash.local")),
            ("not-a-url", None),
        ];
        for (url, expected) in cases {
            assert_eq!(authority_of(url).as_deref(), expected, "url {url:?}");
        }
    }

    #[test]
    fn is_loopback_host_recognizes_expected_values() {
        for host in ["localhost", "127.0.0.1", "::1", "[::1]"] {
            assert!(is_loopback_host(host), "{host} should be loopback");
        }
        for host in ["10.0.0.1", "example.com"] {
            assert!(!is_loopback_host(host), "{host} should not be loopback");
        }
    }

    // Router with one safe route and two mutating routes, wrapped in the CSRF
    // guard and, when `auth` is given, the authentication layer on the outside.
    fn test_app(auth: Option<DashboardAuth>) -> Router {
        let app = Router::new()
            .route("/api/health", get(|| async { "ok" }))
            .route(
                "/api/jobs/{id}/retry",
                post(|| async { StatusCode::NO_CONTENT }),
            )
            .route(
                "/api/jobs/{id}",
                delete(|| async { StatusCode::NO_CONTENT }),
            )
            .layer(middleware::from_fn(csrf_guard));
        match auth {
            Some(auth) => app.layer(middleware::from_fn_with_state(Arc::new(auth), require_auth)),
            None => app,
        }
    }

    fn bearer_app() -> Router {
        test_app(Some(DashboardAuth::Bearer {
            token: "secret".into(),
        }))
    }

    // Builds a body-less request with the given method, URI and headers.
    fn request(
        method: Method,
        uri: &str,
        headers: &[(header::HeaderName, &str)],
    ) -> Request<Body> {
        let mut builder = Request::builder().method(method).uri(uri);
        for (name, value) in headers {
            builder = builder.header(name.clone(), *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    async fn send(app: Router, req: Request<Body>) -> StatusCode {
        app.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn require_auth_rejects_missing_credentials() {
        let req = request(Method::GET, "/api/health", &[]);
        assert_eq!(send(bearer_app(), req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_auth_accepts_matching_bearer_token() {
        let req = request(
            Method::GET,
            "/api/health",
            &[(header::AUTHORIZATION, "Bearer secret")],
        );
        assert_eq!(send(bearer_app(), req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn require_auth_rejects_wrong_bearer_token() {
        let req = request(
            Method::GET,
            "/api/health",
            &[(header::AUTHORIZATION, "Bearer wrong")],
        );
        assert_eq!(send(bearer_app(), req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_auth_accepts_matching_basic_credentials() {
        let app = test_app(Some(DashboardAuth::Basic {
            username: "admin".into(),
            password: "secret".into(),
        }));
        let req = request(
            Method::GET,
            "/api/health",
            &[(header::AUTHORIZATION, ADMIN_SECRET_BASIC)],
        );
        assert_eq!(send(app, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn csrf_guard_blocks_mutation_without_origin() {
        let req = request(
            Method::POST,
            "/api/jobs/abc/retry",
            &[(header::HOST, "localhost:8080")],
        );
        assert_eq!(send(test_app(None), req).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_guard_allows_same_origin_mutation() {
        let req = request(
            Method::POST,
            "/api/jobs/abc/retry",
            &[
                (header::HOST, "localhost:8080"),
                (header::ORIGIN, "http://localhost:8080"),
            ],
        );
        assert_eq!(send(test_app(None), req).await, StatusCode::NO_CONTENT);
    }

    #[tokio::test]
    async fn csrf_guard_rejects_cross_origin_mutation() {
        let req = request(
            Method::DELETE,
            "/api/jobs/abc",
            &[
                (header::HOST, "localhost:8080"),
                (header::ORIGIN, "https://evil.example.com"),
            ],
        );
        assert_eq!(send(test_app(None), req).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_guard_ignores_safe_methods() {
        let req = request(
            Method::GET,
            "/api/health",
            &[(header::HOST, "localhost:8080")],
        );
        assert_eq!(send(test_app(None), req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn csrf_guard_falls_back_to_referer() {
        let req = request(
            Method::POST,
            "/api/jobs/abc/retry",
            &[
                (header::HOST, "localhost:8080"),
                (header::REFERER, "http://localhost:8080/jobs"),
            ],
        );
        assert_eq!(send(test_app(None), req).await, StatusCode::NO_CONTENT);
    }
}