use axum::{
extract::State,
http::{self, Request, StatusCode},
middleware::Next,
response::Response,
};
use std::sync::Arc;
use subtle::ConstantTimeEq;
use crate::daemon::AppState;
/// Axum middleware that only lets a request through when it carries a valid bearer token.
///
/// The `Authorization` header must use the `Bearer` scheme and carry a non-empty credential
/// equal to `state.token`. The comparison is done with `subtle::ConstantTimeEq` so the time
/// taken does not reveal how many leading bytes matched.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` in any of these cases:
/// - the header is missing or is not valid visible ASCII;
/// - the header does not use the `Bearer` scheme;
/// - the credential is empty;
/// - the credential does not match the configured token.
pub async fn bearer_auth(
    State(state): State<Arc<AppState>>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let provided = request
        .headers()
        .get(http::header::AUTHORIZATION)
        // `to_str` fails on non-visible-ASCII bytes; treat that like a missing header.
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_credential)
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // The borrow of `request` through `provided` ends at this comparison, so `request`
    // can be moved into `next.run` afterwards.
    if bool::from(provided.as_bytes().ct_eq(state.token.as_bytes())) {
        Ok(next.run(request).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

/// Extracts the credential from an `Authorization` header value that uses the `Bearer` scheme.
///
/// Per RFC 7235 the scheme name is matched case-insensitively, and any number of spaces may
/// separate it from the credential. Returns `None` when the scheme is not `Bearer` or the
/// credential is empty. Rejecting an empty credential means an empty configured token can
/// never authenticate a request (fail closed).
fn bearer_credential(header: &str) -> Option<&str> {
    let (scheme, rest) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let credential = rest.trim_start_matches(' ');
    (!credential.is_empty()).then_some(credential)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two strings with `subtle`'s constant-time equality and converts the
    /// resulting `Choice` into a plain `bool`. The argument order is preserved.
    fn ct_matches(lhs: &str, rhs: &str) -> bool {
        lhs.as_bytes().ct_eq(rhs.as_bytes()).into()
    }

    #[test]
    fn test_constant_time_eq_valid_token_matches() {
        let secret = "abc123def456";
        let matched = ct_matches(secret, secret);
        assert!(matched, "identical tokens must match");
    }

    #[test]
    fn test_constant_time_eq_invalid_token_does_not_match() {
        let (expected, provided) = ("correct-token-abc", "wrong-token-xyz");
        let matched = ct_matches(expected, provided);
        assert!(!matched, "different tokens must not match");
    }

    #[test]
    fn test_constant_time_eq_empty_token_does_not_match_nonempty() {
        let (expected, provided) = ("correct-token", "");
        let matched = ct_matches(expected, provided);
        assert!(!matched, "empty token must not match non-empty expected");
    }
}