use axum::extract::State;
use axum::http::{HeaderMap, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Json, Response};
use serde_json::json;
use subtle::ConstantTimeEq;
use crate::state::AppState;
pub fn unauthorized(detail: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
Json(json!({
"error": "unauthorized",
"detail": detail,
})),
)
.into_response()
}
pub fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
let value = headers.get("authorization")?.to_str().ok()?;
let mut parts = value.splitn(2, ' ');
let scheme = parts.next()?;
if !scheme.eq_ignore_ascii_case("bearer") {
return None;
}
let token = parts.next()?;
if token.is_empty() {
return None;
}
Some(token)
}
/// Axum middleware that gates requests behind a shared API key.
///
/// If `state.api_key` is `None`, authentication is disabled and the request is
/// forwarded unchanged. Otherwise the request must carry
/// `Authorization: Bearer <key>` whose token equals the configured key. Two
/// cases end the request early with a 401 built by [`unauthorized`]:
/// - the header is missing or malformed;
/// - the key does not match.
///
/// The comparison uses `subtle::ConstantTimeEq`, so its timing does not depend
/// on how many leading bytes happen to match.
/// NOTE(review): `subtle` documents that slice comparison short-circuits when
/// the lengths differ, so the configured key's length is not hidden.
pub async fn require_api_key(
    State(state): State<AppState>,
    request: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let Some(expected) = &state.api_key else {
        // No key configured: authentication is switched off.
        return next.run(request).await;
    };

    let Some(provided) = extract_bearer_token(request.headers()) else {
        return unauthorized("missing or malformed Authorization header");
    };

    // Collapse the `subtle::Choice` into a plain bool before `request` is
    // moved into the next layer.
    let key_matches: bool = provided.as_bytes().ct_eq(expected.as_bytes()).into();
    if !key_matches {
        return unauthorized("invalid API key");
    }

    next.run(request).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;
    use http_body_util::BodyExt;

    /// Builds a header map containing a single `authorization` header.
    fn auth_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", HeaderValue::from_static(value));
        headers
    }

    // Purely synchronous check: no async runtime is needed to inspect the status.
    #[test]
    fn unauthorized_returns_401() {
        let resp = unauthorized("test detail");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn unauthorized_sets_json_content_type() {
        let resp = unauthorized("test detail");
        assert_eq!(
            resp.headers()
                .get("content-type")
                .expect("Json responses set a content type"),
            "application/json"
        );
    }

    #[tokio::test]
    async fn unauthorized_returns_json_body() {
        let resp = unauthorized("bad token");
        let body_bytes = resp.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value =
            serde_json::from_slice(&body_bytes).expect("body must be valid JSON");
        assert_eq!(body["error"], "unauthorized");
        assert_eq!(body["detail"], "bad token");
    }

    #[tokio::test]
    async fn unauthorized_preserves_detail() {
        let resp = unauthorized("custom message");
        let body_bytes = resp.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
        assert_eq!(
            body["detail"], "custom message",
            "detail field must match the provided message"
        );
    }

    #[test]
    fn extract_valid_bearer_token() {
        let headers = auth_headers("Bearer my-secret");
        assert_eq!(extract_bearer_token(&headers), Some("my-secret"));
    }

    #[test]
    fn extract_bearer_case_insensitive() {
        let headers = auth_headers("bearer my-key");
        assert_eq!(extract_bearer_token(&headers), Some("my-key"));
        let headers = auth_headers("BEARER my-key");
        assert_eq!(extract_bearer_token(&headers), Some("my-key"));
    }

    #[test]
    fn extract_missing_header_returns_none() {
        let headers = HeaderMap::new();
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn extract_non_bearer_scheme_returns_none() {
        let headers = auth_headers("Basic dXNlcjpwYXNz");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn extract_bearer_no_token_returns_none() {
        let headers = auth_headers("Bearer");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn extract_bearer_empty_token_returns_none() {
        let headers = auth_headers("Bearer ");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn extract_token_with_spaces_preserved() {
        let headers = auth_headers("Bearer token with spaces");
        assert_eq!(
            extract_bearer_token(&headers),
            Some("token with spaces"),
            "everything after 'Bearer ' should be the token"
        );
    }

    // Replaces a former duplicate of `extract_bearer_no_token_returns_none`:
    // a scheme with no separating space must not be treated as `Bearer`.
    #[test]
    fn extract_scheme_glued_to_token_returns_none() {
        let headers = auth_headers("Bearermy-secret");
        assert_eq!(extract_bearer_token(&headers), None);
    }

    #[test]
    fn extract_non_visible_ascii_header_returns_none() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_bytes(b"Bearer caf\xc3\xa9")
                .expect("obs-text bytes are a valid header value"),
        );
        assert_eq!(extract_bearer_token(&headers), None);
    }
}