use axum::{
body::Body,
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use std::sync::Arc;
/// Compares two byte slices for equality without exiting early on the first
/// differing byte.
///
/// Slices of different lengths are rejected immediately, so the length itself is
/// not hidden. For equal-length inputs every byte pair is XOR-ed and OR-ed into
/// an accumulator, and the result is decided only after the whole slice is read.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |diff, (left, right)| diff | (left ^ right))
            == 0
}
/// Returns `true` if `token` is byte-for-byte equal to any entry in `keys`.
///
/// Every key is compared even after a match is found, so the number of
/// comparisons does not depend on which key (if any) matched.
fn any_key_matches(keys: &[String], token: &str) -> bool {
    let needle = token.as_bytes();
    // Bitwise `|` on bools is not short-circuiting: each comparison always runs.
    keys.iter()
        .fold(false, |hit, key| constant_time_eq(key.as_bytes(), needle) | hit)
}
/// Shared authentication configuration passed to the auth middleware as axum state.
///
/// `Debug` is implemented by hand rather than derived so that formatting the
/// state (e.g. in logs) never prints the configured API keys; only their count
/// is shown.
#[derive(Clone)]
pub struct AuthState {
    /// Accepted bearer tokens. An empty list means authentication is disabled.
    pub api_keys: Arc<Vec<String>>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field(
                "api_keys",
                &format_args!("[{} redacted]", self.api_keys.len()),
            )
            .finish()
    }
}
impl AuthState {
    /// Builds the state from the list of accepted API keys.
    ///
    /// Passing an empty list produces a state for which
    /// [`AuthState::auth_enabled`] returns `false`.
    pub fn new(api_keys: Vec<String>) -> Self {
        let shared = Arc::new(api_keys);
        Self { api_keys: shared }
    }

    /// Reports whether at least one API key is configured.
    pub fn auth_enabled(&self) -> bool {
        !self.api_keys.as_slice().is_empty()
    }
}
/// Returns `true` for request paths that bypass authentication.
///
/// This is an exact-match allow-list: sub-paths such as `/health/extra` and
/// any unlisted path (e.g. `/metrics`) still require a valid API key.
fn is_public_path(path: &str) -> bool {
    const PUBLIC_PATHS: [&str; 4] = ["/health", "/ready", "/v1/health", "/v1/ready"];
    PUBLIC_PATHS.contains(&path)
}
/// Extracts the token from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme is matched case-insensitively, and surrounding whitespace around
/// both the whole value and the token is ignored. Returns `None` when the value
/// does not start with `bearer ` or the token is empty.
///
/// Never panics: arbitrary UTF-8 input (for example a multi-byte character
/// straddling the end of the scheme) yields `None` rather than a slicing panic.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const PREFIX: &str = "bearer ";
    let trimmed = header_value.trim();
    // `get` returns None (instead of panicking like `[..n]`) when the input is
    // too short or byte 7 falls inside a multi-byte character.
    let scheme = trimmed.get(..PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(PREFIX) {
        return None;
    }
    // Safe to index: `get` above proved PREFIX.len() is a char boundary.
    let token = trimmed[PREFIX.len()..].trim();
    (!token.is_empty()).then_some(token)
}
pub async fn auth_middleware(
axum::extract::State(state): axum::extract::State<AuthState>,
request: Request<Body>,
next: Next,
) -> Response {
if !state.auth_enabled() {
return next.run(request).await;
}
if is_public_path(request.uri().path()) {
return next.run(request).await;
}
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok());
match auth_header {
Some(value) => match extract_bearer_token(value) {
Some(token) if any_key_matches(&state.api_keys, token) => next.run(request).await,
Some(_) => unauthorized_response("invalid API key"),
None => {
unauthorized_response("invalid Authorization header format, expected: Bearer <key>")
}
},
None => unauthorized_response("missing Authorization header"),
}
}
fn unauthorized_response(message: &str) -> Response {
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "Unauthorized",
"message": message
})),
)
.into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned key list the API expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_auth_state_disabled_when_empty() {
        assert!(!AuthState::new(Vec::new()).auth_enabled());
    }

    #[test]
    fn test_auth_state_enabled_with_keys() {
        assert!(AuthState::new(owned(&["key1"])).auth_enabled());
    }

    #[test]
    fn test_is_public_path_health() {
        assert!(is_public_path("/health"));
    }

    #[test]
    fn test_is_public_path_ready() {
        assert!(is_public_path("/ready"));
    }

    #[test]
    fn test_is_public_path_metrics_is_protected() {
        for path in ["/metrics", "/v1/metrics"] {
            assert!(!is_public_path(path), "{path} should require auth");
        }
    }

    #[test]
    fn test_is_public_path_versioned_health() {
        assert!(is_public_path("/v1/health"));
    }

    #[test]
    fn test_is_public_path_versioned_ready() {
        assert!(is_public_path("/v1/ready"));
    }

    #[test]
    fn test_is_public_path_other() {
        for path in ["/collections", "/query", "/health/extra", "/v1/collections"] {
            assert!(!is_public_path(path), "{path} should require auth");
        }
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        let headers = ["Bearer my-key", "bearer my-key", "BEARER my-key", " Bearer my-key "];
        for value in headers {
            assert_eq!(extract_bearer_token(value), Some("my-key"), "input: {value:?}");
        }
    }

    #[test]
    fn test_extract_bearer_token_invalid() {
        for value in ["Basic abc123", "my-key", "Bearer", ""] {
            assert_eq!(extract_bearer_token(value), None, "input: {value:?}");
        }
    }

    #[test]
    fn test_extract_bearer_token_whitespace_only() {
        assert!(extract_bearer_token("Bearer ").is_none());
    }

    #[test]
    fn test_constant_time_eq_identical() {
        let key: &[u8] = b"secret-key-42";
        assert!(constant_time_eq(key, key));
    }

    #[test]
    fn test_constant_time_eq_different_content() {
        assert!(!constant_time_eq(b"secret-key-42", b"secret-key-43"));
    }

    #[test]
    fn test_constant_time_eq_different_length() {
        assert!(!constant_time_eq(b"short", b"longer-key"));
    }

    #[test]
    fn test_constant_time_eq_empty() {
        assert!(constant_time_eq(&[], &[]));
    }

    #[test]
    fn test_any_key_matches_found() {
        let configured = owned(&["key-a", "key-b"]);
        assert!(any_key_matches(&configured, "key-b"));
    }

    #[test]
    fn test_any_key_matches_not_found() {
        let configured = owned(&["key-a", "key-b"]);
        assert!(!any_key_matches(&configured, "key-c"));
    }

    #[test]
    fn test_any_key_matches_empty_keys() {
        let configured: Vec<String> = Vec::new();
        assert!(!any_key_matches(&configured, "anything"));
    }
}