use axum::{
extract::Request,
http::{HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
use uuid::Uuid;
#[allow(unused_imports)]
use crate::models::{Account, AppState};
/// Axum middleware that authenticates requests by API key.
///
/// Requests whose path is exactly `/health` bypass authentication. Every other
/// request must carry a key that [`extract_api_key`] can read and that is
/// present in `state.accounts`. On success the matching account's `last_used`
/// is set to the current UTC time. A clone of the account is then inserted into
/// the request extensions for downstream handlers, and the request is forwarded.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when no key can be extracted or when the
/// key does not match any loaded account.
pub async fn auth_middleware(
    axum::extract::State(state): axum::extract::State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Health probes must stay reachable without credentials.
    if request.uri().path() == "/health" {
        return Ok(next.run(request).await);
    }

    let api_key = extract_api_key(request.headers()).map_err(|e| {
        eprintln!("[AUTH] Failed to extract API key: {}", e);
        StatusCode::UNAUTHORIZED
    })?;
    // Never write raw key material to the logs; only a short redacted preview.
    eprintln!("[AUTH] Checking API key: {}", redacted_key_preview(&api_key));

    // Look up the key and stamp `last_used` under the write lock. The guard is
    // dropped at the end of this block, before the downstream handler runs.
    let account = {
        let mut accounts = state.accounts.write().await;
        eprintln!("[AUTH] Total accounts loaded: {}", accounts.len());
        if let Some(account) = accounts.get_mut(&api_key) {
            account.last_used = chrono::Utc::now();
            Some(account.clone())
        } else {
            eprintln!("[AUTH] API key not found in {} accounts", accounts.len());
            None
        }
    };

    match account {
        Some(account) => {
            request.extensions_mut().insert(account);
            Ok(next.run(request).await)
        }
        None => Err(StatusCode::UNAUTHORIZED),
    }
}

/// Builds a log-safe preview of an API key.
///
/// Shows at most four leading characters, and never more than a quarter of the
/// key, followed by the key's character count. This lets log lines be
/// correlated without exposing enough of the secret to be useful to a reader
/// of the logs.
fn redacted_key_preview(api_key: &str) -> String {
    let total = api_key.chars().count();
    let shown: String = api_key.chars().take((total / 4).min(4)).collect();
    format!("{}... ({} chars)", shown, total)
}
pub async fn request_id_middleware(
mut request: Request,
next: Next,
) -> Response {
let request_id = Uuid::new_v4().to_string();
request.headers_mut().insert(
"x-request-id",
request_id.parse().unwrap(),
);
next.run(request).await
}
/// Extracts the caller's API key from request headers.
///
/// The `x-api-key` header takes precedence over `authorization` whenever it is
/// present. In either header an optional `Bearer` auth scheme is removed. Per
/// RFC 7235 the scheme name is matched case-insensitively and may be followed
/// by more than one space, so surrounding whitespace is trimmed from the
/// returned key.
///
/// # Errors
///
/// Returns `Err("Missing API key")` when:
/// - neither header is present;
/// - the chosen header is not visible ASCII (`HeaderValue::to_str` fails);
/// - nothing is left once the scheme and whitespace are removed.
pub fn extract_api_key(headers: &HeaderMap) -> Result<String, String> {
    const MISSING: &str = "Missing API key";

    let raw = headers
        .get("x-api-key")
        .or_else(|| headers.get("authorization"))
        .and_then(|value| value.to_str().ok())
        .ok_or_else(|| MISSING.to_string())?;

    let key = strip_bearer_scheme(raw).trim();
    if key.is_empty() {
        // An empty credential (e.g. "Bearer ") can never match an account;
        // report it the same way as an absent header.
        return Err(MISSING.to_string());
    }
    Ok(key.to_string())
}

/// Removes a leading `Bearer` auth scheme (any ASCII case) from a header value.
///
/// Values that do not start with that scheme followed by a space are returned
/// unchanged. Only the scheme and the first space are removed; the caller trims
/// any further whitespace.
fn strip_bearer_scheme(value: &str) -> &str {
    match value.split_once(' ') {
        Some((scheme, credential)) if scheme.eq_ignore_ascii_case("Bearer") => credential,
        _ => value,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_api_key_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "sk_test_123456".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "sk_test_123456");
    }

    #[test]
    fn test_extract_api_key_from_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer sk_test_123456".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "sk_test_123456");
    }

    #[test]
    fn test_extract_api_key_missing() {
        let headers = HeaderMap::new();
        let result = extract_api_key(&headers);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Missing API key");
    }

    #[test]
    fn test_extract_api_key_from_x_api_key_priority() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", "key_from_x_api_key".parse().unwrap());
        headers.insert("authorization", "Bearer key_from_auth".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "key_from_x_api_key");
    }

    #[test]
    fn test_extract_api_key_case_insensitive_header() {
        let mut headers = HeaderMap::new();
        headers.insert("X-API-Key", "sk_test_uppercase".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "sk_test_uppercase");
    }

    #[test]
    fn test_extract_api_key_bearer_with_extra_spaces() {
        // RFC 7235 allows one or more spaces after the scheme; the extra
        // whitespace must not become part of the key.
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer   sk_test_spaces".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "sk_test_spaces");
    }

    #[test]
    fn test_extract_api_key_bearer_scheme_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "bearer sk_test_lower".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap(), "sk_test_lower");
    }

    #[test]
    fn test_extract_api_key_empty_bearer_credential() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        let result = extract_api_key(&headers);
        assert_eq!(result.unwrap_err(), "Missing API key");
    }
}