use axum::{
extract::{Request, State},
http::header,
middleware::Next,
response::Response,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
use crate::api::jwt::{Claims, JwtManager};
use crate::api::models::ApiError;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
pub user_id: String,
pub tenant_id: String,
pub client_id: Uuid,
pub auth_method: AuthMethod,
pub scopes: Vec<String>,
}
impl UserContext {
pub fn from_claims(claims: Claims, auth_method: AuthMethod) -> Self {
Self {
user_id: claims.sub,
tenant_id: claims.tenant_id,
client_id: claims.client_id,
auth_method,
scopes: claims.scopes,
}
}
pub fn has_scope(&self, scope: &str) -> bool {
self.scopes.iter().any(|s| s == scope)
}
pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
scopes.iter().any(|scope| self.has_scope(scope))
}
pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
scopes.iter().all(|scope| self.has_scope(scope))
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AuthMethod {
JwtBearer,
ApiKey,
}
#[derive(Clone)]
pub struct AuthMiddleware {
jwt_manager: Arc<JwtManager>,
api_keys: Arc<std::collections::HashMap<String, ApiKeyInfo>>,
}
#[derive(Debug, Clone)]
struct ApiKeyInfo {
user_id: String,
tenant_id: String,
client_id: Uuid,
scopes: Vec<String>,
name: String,
}
impl AuthMiddleware {
pub fn new(jwt_secret: &[u8]) -> Self {
Self {
jwt_manager: Arc::new(JwtManager::new(jwt_secret)),
api_keys: Arc::new(std::collections::HashMap::new()),
}
}
pub fn from_env_or_default() -> Self {
let secret = std::env::var("HELIOSDB_JWT_SECRET")
.unwrap_or_else(|_| "default-secret-change-in-production".to_string());
Self::new(secret.as_bytes())
}
pub fn with_api_key(
mut self,
key: String,
user_id: String,
tenant_id: String,
client_id: Uuid,
scopes: Vec<String>,
name: String,
) -> Self {
let mut keys = (*self.api_keys).clone();
keys.insert(key, ApiKeyInfo {
user_id,
tenant_id,
client_id,
scopes,
name,
});
self.api_keys = Arc::new(keys);
self
}
pub async fn authenticate(&self, req: &Request) -> Result<UserContext, ApiError> {
if let Some(auth_header) = req.headers().get(header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
return self.authenticate_jwt(token).await;
}
}
}
if let Some(api_key_header) = req.headers().get("x-api-key") {
if let Ok(api_key) = api_key_header.to_str() {
return self.authenticate_api_key(api_key).await;
}
}
Err(ApiError::unauthorized("Missing or invalid authentication credentials"))
}
pub async fn authenticate_jwt(&self, token: &str) -> Result<UserContext, ApiError> {
let claims = self.jwt_manager
.validate_token(token)
.map_err(|_| ApiError::unauthorized("Invalid or expired JWT token"))?;
Ok(UserContext::from_claims(claims, AuthMethod::JwtBearer))
}
pub async fn authenticate_api_key(&self, api_key: &str) -> Result<UserContext, ApiError> {
let key_info = self.api_keys
.get(api_key)
.ok_or_else(|| ApiError::unauthorized("Invalid API key"))?;
Ok(UserContext {
user_id: key_info.user_id.clone(),
tenant_id: key_info.tenant_id.clone(),
client_id: key_info.client_id,
auth_method: AuthMethod::ApiKey,
scopes: key_info.scopes.clone(),
})
}
}
pub async fn auth_middleware(
State(auth): State<Arc<AuthMiddleware>>,
mut req: Request,
next: Next,
) -> Result<Response, ApiError> {
let user_ctx = auth.authenticate(&req).await?;
req.extensions_mut().insert(user_ctx);
Ok(next.run(req).await)
}
#[derive(Debug, Clone)]
pub struct AuthUser(pub UserContext);
#[axum::async_trait]
impl<S> axum::extract::FromRequestParts<S> for AuthUser
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<UserContext>()
.cloned()
.map(AuthUser)
.ok_or_else(|| ApiError::unauthorized("User context not found. Authentication middleware may not be configured."))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn test_user_context_from_claims() {
let claims = Claims::new(
"user123".to_string(),
"tenant456".to_string(),
Uuid::new_v4(),
chrono::Duration::hours(1),
);
let ctx = UserContext::from_claims(claims.clone(), AuthMethod::JwtBearer);
assert_eq!(ctx.user_id, "user123");
assert_eq!(ctx.tenant_id, "tenant456");
assert_eq!(ctx.auth_method, AuthMethod::JwtBearer);
assert!(ctx.has_scope("api:read"));
assert!(ctx.has_scope("api:write"));
}
#[test]
fn test_user_context_scope_checks() {
let mut claims = Claims::new(
"user123".to_string(),
"tenant456".to_string(),
Uuid::new_v4(),
chrono::Duration::hours(1),
);
claims.scopes = vec!["read".to_string(), "write".to_string()];
let ctx = UserContext::from_claims(claims, AuthMethod::JwtBearer);
assert!(ctx.has_scope("read"));
assert!(ctx.has_scope("write"));
assert!(!ctx.has_scope("admin"));
assert!(ctx.has_any_scope(&["read", "admin"]));
assert!(!ctx.has_any_scope(&["admin", "delete"]));
assert!(ctx.has_all_scopes(&["read", "write"]));
assert!(!ctx.has_all_scopes(&["read", "write", "admin"]));
}
#[tokio::test]
async fn test_auth_middleware_jwt() {
let jwt_manager = JwtManager::new(b"test-secret");
let client_id = Uuid::new_v4();
let token = jwt_manager
.generate_token(
"user123".to_string(),
"tenant456".to_string(),
client_id,
)
.unwrap();
let auth = AuthMiddleware::new(b"test-secret");
let req = Request::builder()
.header(header::AUTHORIZATION, format!("Bearer {}", token))
.body(axum::body::Body::empty())
.unwrap();
let result = auth.authenticate(&req).await;
assert!(result.is_ok());
let ctx = result.unwrap();
assert_eq!(ctx.user_id, "user123");
assert_eq!(ctx.tenant_id, "tenant456");
assert_eq!(ctx.auth_method, AuthMethod::JwtBearer);
}
#[tokio::test]
async fn test_auth_middleware_api_key() {
let client_id = Uuid::new_v4();
let auth = AuthMiddleware::new(b"test-secret")
.with_api_key(
"test-key-123".to_string(),
"user123".to_string(),
"tenant456".to_string(),
client_id,
vec!["read".to_string(), "write".to_string()],
"Test Key".to_string(),
);
let req = Request::builder()
.header("x-api-key", "test-key-123")
.body(axum::body::Body::empty())
.unwrap();
let result = auth.authenticate(&req).await;
assert!(result.is_ok());
let ctx = result.unwrap();
assert_eq!(ctx.user_id, "user123");
assert_eq!(ctx.tenant_id, "tenant456");
assert_eq!(ctx.auth_method, AuthMethod::ApiKey);
assert!(ctx.has_scope("read"));
}
#[tokio::test]
async fn test_auth_middleware_missing_credentials() {
let auth = AuthMiddleware::new(b"test-secret");
let req = Request::builder()
.body(axum::body::Body::empty())
.unwrap();
let result = auth.authenticate(&req).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_auth_middleware_invalid_jwt() {
let auth = AuthMiddleware::new(b"test-secret");
let req = Request::builder()
.header(header::AUTHORIZATION, "Bearer invalid.token.here")
.body(axum::body::Body::empty())
.unwrap();
let result = auth.authenticate(&req).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_auth_middleware_invalid_api_key() {
let auth = AuthMiddleware::new(b"test-secret");
let req = Request::builder()
.header("x-api-key", "invalid-key")
.body(axum::body::Body::empty())
.unwrap();
let result = auth.authenticate(&req).await;
assert!(result.is_err());
}
}