use crate::spec_ai_api::api::auth::AuthService;
use axum::{
Json,
extract::{Request, State},
http::{StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
/// Identity attached to a request once its bearer token has been accepted.
///
/// The authentication middleware in this file stores a value of this type in
/// the request extensions before handing the request to the next layer.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
    /// Username the auth service returned for the validated token.
    pub username: String,
}
/// Middleware that enforces bearer-token authentication on a request.
///
/// Behaviour:
/// - If the auth service reports that it is not enabled, the request is
///   forwarded to `next` untouched.
/// - Otherwise the `Authorization` header must be present and carry
///   `Bearer <token>` credentials. The scheme name is matched
///   case-insensitively, so `bearer <token>` is accepted as well.
/// - The token is handed to `AuthService::validate_token`; when that yields a
///   username, an [`AuthenticatedUser`] is inserted into the request
///   extensions and the request is forwarded.
///
/// Every failure path short-circuits with the response built by
/// `unauthorized_response` and never calls `next`.
pub async fn auth_middleware(
    State(auth_service): State<Arc<AuthService>>,
    mut request: Request,
    next: Next,
) -> Response {
    if !auth_service.is_enabled() {
        return next.run(request).await;
    }

    // A header value that cannot be viewed as a string (`to_str` fails) is
    // handled exactly like an absent header.
    let auth_header = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok());

    let Some(auth_str) = auth_header else {
        return unauthorized_response("Missing Authorization header");
    };

    let Some(token) = extract_bearer_token(auth_str) else {
        return unauthorized_response(
            "Invalid Authorization header format. Expected: Bearer <token>",
        );
    };

    let Some(username) = auth_service.validate_token(token) else {
        return unauthorized_response("Invalid or expired token");
    };

    // Expose the authenticated identity to later layers via the extensions.
    request
        .extensions_mut()
        .insert(AuthenticatedUser { username });

    next.run(request).await
}

/// Returns the credentials that follow a `Bearer ` scheme prefix, or `None`
/// when `header_value` does not start with that prefix.
///
/// HTTP authentication scheme names are case-insensitive, so the prefix is
/// compared ignoring ASCII case. The remainder is returned unmodified.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const SCHEME_PREFIX: &str = "Bearer ";

    // `get` yields `None` for inputs that are too short or would be split
    // inside a multi-byte character, so this can never panic.
    let scheme = header_value.get(..SCHEME_PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME_PREFIX) {
        return None;
    }
    header_value.get(SCHEME_PREFIX.len()..)
}
/// Builds the `401 Unauthorized` response used for every authentication
/// failure.
///
/// The body is a JSON object with the human-readable `message` under
/// `"error"` and the fixed machine-readable code `"unauthorized"` under
/// `"code"`. A `WWW-Authenticate: Bearer` challenge is included because a 401
/// response is required to tell the client which authentication scheme
/// applies.
fn unauthorized_response(message: &str) -> Response {
    let body = serde_json::json!({
        "error": message,
        "code": "unauthorized"
    });

    (
        StatusCode::UNAUTHORIZED,
        [
            (header::CONTENT_TYPE, "application/json"),
            (header::WWW_AUTHENTICATE, "Bearer"),
        ],
        Json(body),
    )
        .into_response()
}
/// Optional static API-key check.
///
/// When no key is configured the check is disabled and every presented key
/// is accepted.
pub struct ApiKeyAuth {
    /// Expected key, or `None` when API-key authentication is disabled.
    api_key: Option<String>,
}

impl ApiKeyAuth {
    /// Creates a checker for `api_key`; pass `None` to disable the check.
    pub fn new(api_key: Option<String>) -> Self {
        Self { api_key }
    }

    /// Returns `true` when a key has been configured.
    pub fn is_enabled(&self) -> bool {
        self.api_key.is_some()
    }

    /// Returns `true` when `key` matches the configured key, or when no key
    /// is configured at all.
    ///
    /// The comparison inspects every byte instead of stopping at the first
    /// mismatch, so response timing does not reveal how long a matching
    /// prefix the caller supplied.
    pub fn validate(&self, key: &str) -> bool {
        match self.api_key.as_deref() {
            Some(expected) => Self::constant_time_eq(expected.as_bytes(), key.as_bytes()),
            // Disabled: nothing to compare against, so everything passes.
            None => true,
        }
    }

    /// Best-effort constant-time equality for byte strings.
    ///
    /// Only a length mismatch returns early; for equal lengths all byte
    /// differences are OR-ed together before the single final check.
    fn constant_time_eq(expected: &[u8], candidate: &[u8]) -> bool {
        if expected.len() != candidate.len() {
            return false;
        }
        let diff = expected
            .iter()
            .zip(candidate)
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no key configured the check is off and any key is accepted.
    #[test]
    fn test_api_key_auth_disabled() {
        let open = ApiKeyAuth::new(None);
        let enabled = open.is_enabled();
        assert!(!enabled);
        assert!(open.validate("any_key"));
    }

    /// With a key configured only that exact key is accepted.
    #[test]
    fn test_api_key_auth_enabled() {
        let configured = String::from("secret123");
        let guarded = ApiKeyAuth::new(Some(configured));
        assert!(guarded.is_enabled());

        let accepts_match = guarded.validate("secret123");
        let accepts_other = guarded.validate("wrong_key");
        assert!(accepts_match);
        assert!(!accepts_other);
    }

    /// The empty string and unrelated keys are both rejected.
    #[test]
    fn test_api_key_validation() {
        let guarded = ApiKeyAuth::new(Some(String::from("my-secret-key")));
        assert!(guarded.validate("my-secret-key"));
        for rejected in ["", "wrong"] {
            assert!(!guarded.validate(rejected));
        }
    }
}