use crate::auth::AuthService;
use crate::middleware::rate_limit::ClientTier;
use axum::{
body::Body,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
/// Identity of a caller whose API key was accepted by the auth middleware.
///
/// Both middlewares in this module insert a value of this type into the request
/// extensions after a successful key validation, so downstream handlers can read it
/// (presumably through axum's extension extraction — confirm at call sites).
#[derive(Clone, Debug)]
pub struct AuthenticatedUser {
    /// `id` of the validated API key.
    pub key_id: String,
    /// `name` recorded for the validated API key.
    pub name: String,
    /// Rate-limit tier attached to the validated API key.
    pub tier: crate::config::rate_limit::RateLimitTier,
}
pub async fn auth_middleware(
auth_service: Arc<AuthService>,
mut request: Request,
next: Next,
) -> Response {
let auth_header = request
.headers()
.get("authorization")
.and_then(|h| h.to_str().ok());
let api_key = match auth_header {
Some(header) => {
if let Some(key) = header.strip_prefix("Bearer ") {
key
} else {
return create_unauthorized_response("Invalid authorization header format");
}
}
None => {
return create_unauthorized_response("Missing authorization header");
}
};
let key = match auth_service.validate_key(api_key).await {
Ok(key) => key,
Err(e) => {
tracing::warn!("API key validation failed: {}", e);
return create_unauthorized_response("Invalid or expired API key");
}
};
let user = AuthenticatedUser {
key_id: key.id.clone(),
name: key.name.clone(),
tier: key.tier,
};
request.extensions_mut().insert(ClientTier(key.tier));
request.extensions_mut().insert(user);
tracing::debug!(
"Authenticated request from key: {} ({})",
key.id,
key.name
);
next.run(request).await
}
fn create_unauthorized_response(message: &str) -> Response {
let body = serde_json::json!({
"error": "Unauthorized",
"message": message,
});
(
StatusCode::UNAUTHORIZED,
serde_json::to_string(&body).unwrap(),
)
.into_response()
}
pub async fn optional_auth_middleware(
auth_service: Arc<AuthService>,
mut request: Request,
next: Next,
) -> Response {
let auth_header = request
.headers()
.get("authorization")
.and_then(|h| h.to_str().ok());
if let Some(header) = auth_header {
if let Some(api_key) = header.strip_prefix("Bearer ") {
if let Ok(key) = auth_service.validate_key(api_key).await {
let user = AuthenticatedUser {
key_id: key.id.clone(),
name: key.name.clone(),
tier: key.tier,
};
request.extensions_mut().insert(ClientTier(key.tier));
request.extensions_mut().insert(user);
tracing::debug!(
"Authenticated optional request from key: {} ({})",
key.id,
key.name
);
}
}
}
next.run(request).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{storage::MemoryKeyStorage, AuthService};
    use crate::config::rate_limit::RateLimitTier;
    use axum::http::StatusCode;

    /// Builds an [`AuthService`] over an in-memory store and issues one `Pro`-tier key named
    /// `"test-key"`. Returns the service together with the `key` field of the creation
    /// response (presumably the raw key string a client would send — confirm).
    ///
    /// Not called by any test yet. It is kept as the fixture for async middleware tests.
    #[allow(dead_code)] // Fixture for upcoming async tests; avoids the unused-fn warning.
    async fn create_test_service_and_key() -> (Arc<AuthService>, String) {
        let storage = Arc::new(MemoryKeyStorage::new());
        let service = Arc::new(AuthService::new(storage));
        let response = service
            .create_key("test-key".to_string(), RateLimitTier::Pro, None)
            .await
            .unwrap();
        (service, response.key)
    }

    #[test]
    fn test_create_unauthorized_response() {
        let response = create_unauthorized_response("Test message");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_authenticated_user_creation() {
        let user = AuthenticatedUser {
            key_id: "test-id".to_string(),
            name: "test-name".to_string(),
            tier: RateLimitTier::Pro,
        };
        assert_eq!(user.key_id, "test-id");
        assert_eq!(user.name, "test-name");
        assert_eq!(user.tier, RateLimitTier::Pro);
    }
}