use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
/// Shared state for the bearer-token middleware: the single token that callers must present.
///
/// The token lives behind an [`Arc`], so cloning this state only bumps a reference count
/// instead of copying the token string.
#[derive(Clone)]
pub struct BearerAuthState {
    /// The expected token, compared against whatever follows the `Bearer` scheme.
    pub token: Arc<String>,
}

impl BearerAuthState {
    /// Creates the state from an owned token string.
    #[must_use]
    pub fn new(token: String) -> Self {
        let shared_token = Arc::new(token);
        Self {
            token: shared_token,
        }
    }
}
/// Axum middleware that forwards a request only if it carries the configured bearer token.
///
/// Outcomes:
/// - `401 Unauthorized` with `WWW-Authenticate: Bearer` when the `Authorization` header is
///   absent, is not valid visible ASCII, or does not use the `Bearer` scheme.
/// - `403 Forbidden` when a bearer token is supplied but does not match `auth_state.token`.
/// - Otherwise the request is passed on to `next`.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and one or more spaces may
/// separate it from the token (RFC 6750 §2.1).
pub async fn bearer_auth_middleware(
    State(auth_state): State<BearerAuthState>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let Some(raw_header) = request.headers().get(header::AUTHORIZATION) else {
        return unauthorized("Missing Authorization header");
    };

    // A header that is present but not visible ASCII is malformed, not missing.
    let Some(token) = raw_header.to_str().ok().and_then(extract_bearer_token) else {
        return unauthorized("Invalid Authorization header format. Expected: Bearer <token>");
    };

    if !constant_time_compare(token, &auth_state.token) {
        return (StatusCode::FORBIDDEN, "Invalid token").into_response();
    }

    next.run(request).await
}

/// Returns the token from a `Bearer <token>` credential, or `None` if the scheme is not `Bearer`
/// or no space follows the scheme name.
fn extract_bearer_token(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then_some(rest.trim_start_matches(' '))
}

/// Builds a `401 Unauthorized` response carrying a `WWW-Authenticate: Bearer` challenge.
fn unauthorized(message: &'static str) -> Response {
    (
        StatusCode::UNAUTHORIZED,
        [(header::WWW_AUTHENTICATE, "Bearer")],
        message,
    )
        .into_response()
}
/// Compares two strings without short-circuiting on the first differing byte.
///
/// `b` is treated as the secret (the configured token at the call site). The loop always runs
/// exactly `b.len()` iterations, so its duration does not reveal where the first mismatch is or
/// whether `a` has the right length. A length mismatch is folded into the result rather than
/// triggering an early return, which would otherwise leak the secret's length.
///
/// This is best-effort: nothing here formally stops the optimizer from introducing branches.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // Non-zero whenever the lengths differ, which forces a `false` result.
    let mut diff = a.len() ^ b.len();
    for (i, &expected) in b.iter().enumerate() {
        // Past the end of `a`, compare against 0; the length term above already records the mismatch.
        let actual = a.get(i).copied().unwrap_or(0);
        diff |= usize::from(actual ^ expected);
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    use super::*;

    /// Token every test app is configured to expect.
    const EXPECTED_TOKEN: &str = "secret-token-12345";

    /// Route handler behind the middleware; a 200 response means the request got through.
    async fn protected_handler() -> &'static str {
        "secret data"
    }

    /// Builds a router with a single `/protected` route wrapped in `bearer_auth_middleware`.
    fn create_test_app(token: &str) -> Router {
        let auth_state = BearerAuthState::new(token.to_string());
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
    }

    /// Sends a request for `/protected` to a fresh app expecting `EXPECTED_TOKEN`, attaching
    /// `authorization` as the `Authorization` header when one is given.
    async fn request_protected(authorization: Option<&str>) -> axum::response::Response {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        create_test_app(EXPECTED_TOKEN).oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_valid_token_allows_access() {
        let response = request_protected(Some("Bearer secret-token-12345")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_missing_auth_header_returns_401() {
        let response = request_protected(None).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_invalid_auth_format_returns_401() {
        let response = request_protected(Some("Basic dXNlcjpwYXNz")).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_wrong_token_returns_403() {
        let response = request_protected(Some("Bearer wrong-token")).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_empty_bearer_token_returns_403() {
        let response = request_protected(Some("Bearer ")).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_constant_time_compare_equal() {
        for s in ["hello", "", "a-long-token-123"] {
            assert!(constant_time_compare(s, s), "{s:?} should compare equal to itself");
        }
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        let cases = [("hello", "world"), ("hello", "hello!"), ("hello", "hell"), ("abc", "abd")];
        for (a, b) in cases {
            assert!(!constant_time_compare(a, b), "{a:?} vs {b:?} should differ");
        }
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        for (a, b) in [("short", "longer-string"), ("", "notempty")] {
            assert!(!constant_time_compare(a, b), "{a:?} vs {b:?} should differ");
        }
    }
}