use axum::extract::FromRequestParts;
use axum::http::StatusCode;
use axum::http::header::{AUTHORIZATION, HeaderMap, WWW_AUTHENTICATE};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Response};
use crate::state::AppState;
/// Extractor that admits a request only when its `Authorization: Bearer`
/// credential equals the server's configured push token.
///
/// Rejections, in the order they are checked:
/// - [`AuthRejection::Unconfigured`] when `AppState::push_token` is `None`;
/// - [`AuthRejection::Missing`] when no well-formed bearer credential is present;
/// - [`AuthRejection::Mismatch`] when the presented token differs from the configured one.
#[derive(Debug)]
pub(crate) struct RequireBearer;

impl FromRequestParts<AppState> for RequireBearer {
    type Rejection = AuthRejection;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // With no configured secret nothing can authenticate, so bail out before
        // even looking at the request headers.
        let expected = match state.push_token.as_deref() {
            Some(token) => token,
            None => return Err(AuthRejection::Unconfigured),
        };

        let Some(presented) = extract_bearer(&parts.headers) else {
            return Err(AuthRejection::Missing);
        };

        // Compare without early exit on the first differing byte.
        constant_time_eq(presented.as_bytes(), expected.as_bytes())
            .then_some(Self)
            .ok_or(AuthRejection::Mismatch)
    }
}
/// Pulls the token out of an `Authorization: Bearer <token>` header.
///
/// The scheme is matched ASCII case-insensitively. It must be separated from the
/// token by a space, and the token is trimmed of surrounding whitespace.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - the value cannot be viewed as a `&str` (`HeaderValue::to_str` fails);
/// - there is no space separator;
/// - the scheme is not `Bearer`;
/// - the trimmed token is empty.
fn extract_bearer(headers: &HeaderMap) -> Option<String> {
    let header_value = headers.get(AUTHORIZATION)?;
    let text = header_value.to_str().ok()?;
    let (scheme, remainder) = text.split_once(' ')?;

    let credential = remainder.trim();
    let acceptable = scheme.eq_ignore_ascii_case("Bearer") && !credential.is_empty();
    acceptable.then(|| credential.to_owned())
}
/// Reason a request was refused by the `RequireBearer` extractor.
#[derive(Debug, Clone, Copy)]
pub(crate) enum AuthRejection {
    /// The `Authorization` header is absent or is not a usable `Bearer <token>` credential.
    Missing,
    /// A bearer token was presented but did not equal the configured push token.
    Mismatch,
    /// The server has no push token configured, so no request can be authorized.
    Unconfigured,
}

impl IntoResponse for AuthRejection {
    /// Renders the rejection as an `application/problem+json` response body.
    ///
    /// `Missing` and `Mismatch` become `401 Unauthorized` with a
    /// `WWW-Authenticate: Bearer realm="mnem"` challenge. `Unconfigured` becomes
    /// `503 Service Unavailable` with no challenge, because no credential the
    /// client could send would be accepted.
    fn into_response(self) -> Response {
        // One exhaustive match selects every per-variant field. Adding a variant
        // becomes a compile error here instead of a runtime `unreachable!()`.
        let (status, problem_type, title, detail) = match self {
            Self::Missing => (
                StatusCode::UNAUTHORIZED,
                "https://mnem.dev/errors/auth",
                "Unauthorized",
                "missing or malformed Authorization: Bearer header",
            ),
            Self::Mismatch => (
                StatusCode::UNAUTHORIZED,
                "https://mnem.dev/errors/auth",
                "Unauthorized",
                "bearer token did not match",
            ),
            Self::Unconfigured => (
                StatusCode::SERVICE_UNAVAILABLE,
                "https://mnem.dev/errors/auth-unconfigured",
                "Service Unavailable",
                "push authentication not configured on this server",
            ),
        };

        // Take the body's `status` from the HTTP status so the two cannot drift apart.
        let body = serde_json::json!({
            "type": problem_type,
            "title": title,
            "status": status.as_u16(),
            "detail": detail,
        })
        .to_string();

        let content_type = (
            axum::http::header::CONTENT_TYPE,
            "application/problem+json",
        );
        match self {
            Self::Missing | Self::Mismatch => (
                status,
                [(WWW_AUTHENTICATE, "Bearer realm=\"mnem\""), content_type],
                body,
            )
                .into_response(),
            Self::Unconfigured => (status, [content_type], body).into_response(),
        }
    }
}
/// Compares two byte slices for equality without stopping at the first difference.
///
/// Every byte pair is XOR-folded into one accumulator, so the work done for
/// equal-length inputs does not depend on where they differ. Slices of different
/// lengths are rejected immediately. That reveals only whether the lengths match,
/// not any byte content.
#[inline]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&left, &right)| acc | (left ^ right))
            == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    fn state_with(token: Option<&str>) -> AppState {
        crate::state::test_support::state_with_token(token.map(str::to_string))
    }

    /// Header map holding a single `Authorization` value.
    fn headers_with(value: HeaderValue) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(AUTHORIZATION, value);
        h
    }

    /// Request parts for the push endpoint, optionally carrying an `Authorization` header.
    fn push_parts(authorization: Option<&'static str>) -> Parts {
        let mut builder = axum::http::Request::builder().uri("/remote/v1/push-blocks");
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        builder.body(()).unwrap().into_parts().0
    }

    #[test]
    fn constant_time_eq_matches_only_equal_bytes() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn extract_bearer_happy_path() {
        let h = headers_with(HeaderValue::from_static("Bearer tok123"));
        assert_eq!(extract_bearer(&h).as_deref(), Some("tok123"));
    }

    #[test]
    fn extract_bearer_case_insensitive_scheme() {
        let h = headers_with(HeaderValue::from_static("bearer tok123"));
        assert_eq!(extract_bearer(&h).as_deref(), Some("tok123"));
    }

    #[test]
    fn extract_bearer_trims_surrounding_whitespace() {
        let h = headers_with(HeaderValue::from_static("Bearer   tok123  "));
        assert_eq!(extract_bearer(&h).as_deref(), Some("tok123"));
    }

    #[test]
    fn extract_bearer_rejects_wrong_scheme() {
        let h = headers_with(HeaderValue::from_static("Basic dXNlcjpwYXNz"));
        assert!(extract_bearer(&h).is_none());
    }

    #[test]
    fn extract_bearer_rejects_empty_token() {
        let h = headers_with(HeaderValue::from_static("Bearer "));
        assert!(extract_bearer(&h).is_none());
    }

    #[test]
    fn extract_bearer_rejects_whitespace_only_token() {
        let h = headers_with(HeaderValue::from_static("Bearer     "));
        assert!(extract_bearer(&h).is_none());
    }

    #[test]
    fn extract_bearer_rejects_scheme_without_separator() {
        let h = headers_with(HeaderValue::from_static("Bearer"));
        assert!(extract_bearer(&h).is_none());
    }

    #[test]
    fn extract_bearer_rejects_non_ascii_header_value() {
        // 0xFF is a legal header byte but not visible ASCII, so `to_str` fails.
        let h = headers_with(HeaderValue::from_bytes(b"Bearer \xfftok").unwrap());
        assert!(extract_bearer(&h).is_none());
    }

    #[test]
    fn extract_bearer_missing_header() {
        let h = HeaderMap::new();
        assert!(extract_bearer(&h).is_none());
    }

    #[tokio::test]
    async fn extractor_accepts_matching_token() {
        let state = state_with(Some("secret"));
        let mut parts = push_parts(Some("Bearer secret"));
        let r = RequireBearer::from_request_parts(&mut parts, &state).await;
        assert!(r.is_ok(), "expected ok, got {r:?}");
    }

    #[tokio::test]
    async fn extractor_rejects_missing_header() {
        let state = state_with(Some("secret"));
        let mut parts = push_parts(None);
        let r = RequireBearer::from_request_parts(&mut parts, &state).await;
        assert!(matches!(r, Err(AuthRejection::Missing)));
    }

    #[tokio::test]
    async fn extractor_treats_wrong_scheme_as_missing() {
        let state = state_with(Some("secret"));
        let mut parts = push_parts(Some("Basic secret"));
        let r = RequireBearer::from_request_parts(&mut parts, &state).await;
        assert!(matches!(r, Err(AuthRejection::Missing)));
    }

    #[tokio::test]
    async fn extractor_rejects_mismatched_token() {
        let state = state_with(Some("secret"));
        let mut parts = push_parts(Some("Bearer wrong"));
        let r = RequireBearer::from_request_parts(&mut parts, &state).await;
        assert!(matches!(r, Err(AuthRejection::Mismatch)));
    }

    #[tokio::test]
    async fn extractor_returns_unconfigured_when_token_missing() {
        let state = state_with(None);
        let mut parts = push_parts(Some("Bearer anything"));
        let r = RequireBearer::from_request_parts(&mut parts, &state).await;
        assert!(matches!(r, Err(AuthRejection::Unconfigured)));
    }

    #[test]
    fn rejection_401_carries_www_authenticate_bearer_realm() {
        for rej in [AuthRejection::Missing, AuthRejection::Mismatch] {
            let resp = rej.into_response();
            assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
            let hdr = resp
                .headers()
                .get(WWW_AUTHENTICATE)
                .unwrap_or_else(|| panic!("WWW-Authenticate missing for {rej:?}"));
            assert_eq!(
                hdr.to_str().unwrap(),
                "Bearer realm=\"mnem\"",
                "challenge shape drifted for {rej:?}"
            );
            let ct = resp
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .expect("content-type present");
            assert_eq!(ct.to_str().unwrap(), "application/problem+json");
        }
    }

    #[test]
    fn rejection_503_omits_www_authenticate() {
        let resp = AuthRejection::Unconfigured.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert!(
            resp.headers().get(WWW_AUTHENTICATE).is_none(),
            "503 must not emit a bearer challenge"
        );
        let ct = resp
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .expect("content-type present");
        assert_eq!(ct.to_str().unwrap(), "application/problem+json");
    }

    #[tokio::test]
    async fn rejection_body_status_matches_http_status() {
        for rej in [
            AuthRejection::Missing,
            AuthRejection::Mismatch,
            AuthRejection::Unconfigured,
        ] {
            let resp = rej.into_response();
            let status = resp.status();
            let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .expect("body readable");
            let body: serde_json::Value =
                serde_json::from_slice(&bytes).expect("body is JSON");
            assert_eq!(
                body["status"],
                u64::from(status.as_u16()),
                "body status drifted from HTTP status for {rej:?}"
            );
        }
    }
}