use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use subtle::ConstantTimeEq as _;
/// Shared state for [`bearer_auth_middleware`]: the single token every request
/// must present in its `Authorization: Bearer <token>` header.
///
/// `Debug` is implemented by hand so the secret is never written to logs.
#[derive(Clone)]
pub struct BearerAuthState {
    /// The expected bearer token. Held in an `Arc` so cloning the state is a
    /// reference-count bump rather than a copy of the secret.
    pub token: Arc<String>,
}

impl BearerAuthState {
    /// Creates the state from the expected token.
    #[must_use]
    pub fn new(token: String) -> Self {
        Self {
            token: Arc::new(token),
        }
    }
}

impl std::fmt::Debug for BearerAuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself.
        f.debug_struct("BearerAuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}
pub async fn bearer_auth_middleware(
State(auth_state): State<BearerAuthState>,
request: Request<Body>,
next: Next,
) -> Response {
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
None => {
return (
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, "Bearer")],
"Missing Authorization header",
)
.into_response();
},
Some(header_value) => {
if !header_value.starts_with("Bearer ") {
return (
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, "Bearer")],
"Invalid Authorization header format. Expected: Bearer <token>",
)
.into_response();
}
let token = &header_value[7..];
if !constant_time_compare(token, &auth_state.token) {
return (StatusCode::FORBIDDEN, "Invalid token").into_response();
}
},
}
next.run(request).await
}
/// Extracts the credential from an `Authorization` header value that uses the
/// `Bearer` scheme (RFC 6750 §2.1: `"Bearer" 1*SP token`).
///
/// The scheme name is matched ASCII-case-insensitively, as RFC 7235 §2.1 requires
/// for auth-scheme names. One or more spaces may separate the scheme from the token.
///
/// Returns `None` when the value does not start with `Bearer` followed by at
/// least one space. The returned token may be empty (e.g. for `"Bearer "`), and
/// callers should treat that as an invalid credential.
#[must_use]
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const SCHEME: &str = "Bearer";
    // `get` (rather than indexing) yields `None` for short input or when the
    // split point would fall inside a multi-byte character.
    let scheme = header_value.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    let after_scheme = header_value.get(SCHEME.len()..)?;
    let token = after_scheme.strip_prefix(' ')?;
    Some(token.trim_start_matches(' '))
}
/// Returns `true` when `a` and `b` consist of exactly the same bytes.
///
/// The comparison goes through `subtle::ConstantTimeEq` on the UTF-8 bytes. For
/// inputs of equal length, the running time therefore does not depend on where
/// the first differing byte sits. The `subtle` slice comparison returns early when
/// the lengths differ, so the length of the expected token is not concealed.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    bool::from(lhs.ct_eq(rhs))
}
#[cfg(test)]
mod tests {
    // Lint relaxations for test code: `unwrap()` is the intended failure signal in
    // tests, and the cast/doc/item-ordering lints are aimed at library code.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use axum::{
        Router,
        body::Body,
        http::{HeaderValue, Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    use super::*;

    /// Token the test app is configured with.
    const TEST_TOKEN: &str = "secret-token-12345";

    async fn protected_handler() -> &'static str {
        "secret data"
    }

    /// Builds a router whose only route, `/protected`, sits behind the middleware.
    fn create_test_app(token: &str) -> Router {
        let auth_state = BearerAuthState::new(token.to_string());
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
    }

    /// Builds a `GET /protected` request carrying the given `Authorization` value.
    fn request_with_auth(authorization: &str) -> Request<Body> {
        Request::builder()
            .uri("/protected")
            .header("Authorization", authorization)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_valid_token_allows_access() {
        let app = create_test_app(TEST_TOKEN);
        let response = app
            .oneshot(request_with_auth("Bearer secret-token-12345"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_missing_auth_header_returns_401() {
        let app = create_test_app(TEST_TOKEN);
        let request = Request::builder().uri("/protected").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_invalid_auth_format_returns_401() {
        let app = create_test_app(TEST_TOKEN);
        let response = app
            .oneshot(request_with_auth("Basic dXNlcjpwYXNz"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_non_ascii_auth_header_returns_401() {
        let app = create_test_app(TEST_TOKEN);
        // 0xFF is a legal header byte (obs-text) but not visible ASCII.
        let request = Request::builder()
            .uri("/protected")
            .header("Authorization", HeaderValue::from_bytes(b"Bearer \xff").unwrap())
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_wrong_token_returns_403() {
        let app = create_test_app(TEST_TOKEN);
        let response = app
            .oneshot(request_with_auth("Bearer wrong-token"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_empty_bearer_token_returns_403() {
        let app = create_test_app(TEST_TOKEN);
        let response = app.oneshot(request_with_auth("Bearer ")).await.unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_extract_bearer_token_accepts_bearer_scheme() {
        assert_eq!(extract_bearer_token("Bearer abc"), Some("abc"));
        assert_eq!(extract_bearer_token("Bearer "), Some(""));
    }

    #[test]
    fn test_extract_bearer_token_rejects_other_forms() {
        assert_eq!(extract_bearer_token("Basic dXNlcjpwYXNz"), None);
        assert_eq!(extract_bearer_token("Bearer"), None);
        assert_eq!(extract_bearer_token("BearerX"), None);
        assert_eq!(extract_bearer_token(""), None);
    }

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("hello", "hello"));
        assert!(constant_time_compare("", ""));
        assert!(constant_time_compare("a-long-token-123", "a-long-token-123"));
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        assert!(!constant_time_compare("hello", "world"));
        assert!(!constant_time_compare("hello", "hello!"));
        assert!(!constant_time_compare("hello", "hell"));
        assert!(!constant_time_compare("abc", "abd"));
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        assert!(!constant_time_compare("short", "longer-string"));
        assert!(!constant_time_compare("", "notempty"));
    }

    #[test]
    fn test_subtle_compare_identical_tokens() {
        assert!(constant_time_compare("x", "x"));
        assert!(constant_time_compare(
            "super-secret-32-char-admin-token",
            "super-secret-32-char-admin-token"
        ));
    }

    #[test]
    fn test_subtle_compare_off_by_one_byte() {
        // Differing last byte, then differing first byte.
        assert!(!constant_time_compare("token-abc", "token-abd"));
        assert!(!constant_time_compare("Aoken-abc", "token-abc"));
    }

    #[test]
    fn test_subtle_compare_empty_strings() {
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("", "a"));
        assert!(!constant_time_compare("a", ""));
    }
}