use std::sync::Arc;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use crate::{
error::{AuthError, Result},
jwt::{Claims, JwtValidator},
session::SessionStore,
};
/// An authenticated caller: a user identifier together with the JWT claims
/// it is associated with.
///
/// Role and custom-claim lookups are provided by the methods on this type
/// (`has_role`, `get_custom_claim`), which read from `claims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticatedUser {
    /// Identifier of the user.
    /// NOTE(review): presumably mirrors `claims.sub` — confirm where this is populated.
    pub user_id: String,
    /// Claim set of the token, including custom claims such as `role` / `roles`.
    pub claims: Claims,
}
impl AuthenticatedUser {
pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
self.claims.get_custom(key)
}
pub fn has_role(&self, role: &str) -> bool {
if let Some(serde_json::Value::String(user_role)) = self.claims.get_custom("role") {
user_role == role
} else if let Some(serde_json::Value::Array(roles)) = self.claims.get_custom("roles") {
roles.iter().any(|r| {
if let serde_json::Value::String(r_str) = r {
r_str == role
} else {
false
}
})
} else {
false
}
}
}
/// Holds the JWT validator and verification key used by
/// `AuthMiddleware::validate_token`.
pub struct AuthMiddleware {
    /// Validator invoked by `validate_token`.
    validator: Arc<JwtValidator>,
    /// Session store supplied at construction.
    /// NOTE(review): not read by any code in this file — confirm whether
    /// session/revocation checks are meant to use it.
    _session_store: Arc<dyn SessionStore>,
    /// Key bytes passed to `JwtValidator::validate`.
    /// NOTE(review): presumably the token-signing public key — confirm the
    /// expected encoding (PEM vs DER).
    public_key: Vec<u8>,
    /// Flag supplied at construction.
    /// NOTE(review): not read anywhere in this file; presumably intended to let
    /// unauthenticated requests through — confirm.
    _optional: bool,
}
impl AuthMiddleware {
    /// Builds a middleware from a JWT validator, a session store, the key used
    /// to verify tokens, and an `optional` flag.
    ///
    /// `session_store` and `optional` are stored but not read by any method in
    /// this impl.
    pub fn new(
        validator: Arc<JwtValidator>,
        session_store: Arc<dyn SessionStore>,
        public_key: Vec<u8>,
        optional: bool,
    ) -> Self {
        Self {
            _optional: optional,
            _session_store: session_store,
            public_key,
            validator,
        }
    }

    /// Validates `token` against the configured key and returns its claims.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`JwtValidator::validate`] reports for the token.
    ///
    /// The body performs no `.await`; it is `async` only at the signature level,
    /// so callers must still `.await` it.
    pub async fn validate_token(&self, token: &str) -> Result<Claims> {
        let key = &self.public_key;
        self.validator.validate(token, key)
    }
}
impl IntoResponse for AuthError {
    /// Converts an [`AuthError`] into a GraphQL-style JSON error response.
    ///
    /// The body has the shape
    /// `{"errors": [{"message": <generic text>, "extensions": {"code": <code>}}]}`.
    /// Detailed context (reasons, claim names, backend messages) is logged with
    /// `tracing::warn!` and is never copied into the body. Clients see only a
    /// generic message and a stable machine-readable `code`.
    ///
    /// `RateLimited` responses also carry a `Retry-After` header (RFC 6585 §4),
    /// so clients can back off without parsing the message text.
    fn into_response(self) -> Response {
        use tracing::warn;

        // Client-facing messages, shared so every variant stays in sync.
        const AUTH_FAILED: &str = "Authentication failed";
        const UNAVAILABLE: &str = "Service temporarily unavailable";

        // Captured before `self` is consumed by the match below, so the header
        // can be attached to the finished response.
        let retry_after = if let AuthError::RateLimited { retry_after_secs } = &self {
            Some(retry_after_secs.to_string())
        } else {
            None
        };

        let (status, error_code, sanitized_message) = match self {
            AuthError::TokenExpired => {
                (StatusCode::UNAUTHORIZED, "token_expired", AUTH_FAILED.to_string())
            },
            AuthError::InvalidSignature => {
                (StatusCode::UNAUTHORIZED, "invalid_signature", AUTH_FAILED.to_string())
            },
            AuthError::InvalidToken { ref reason } => {
                warn!("Invalid token error: {}", reason);
                (StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED.to_string())
            },
            AuthError::MissingClaim { ref claim } => {
                warn!("Missing required claim: {}", claim);
                (StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED.to_string())
            },
            AuthError::InvalidClaimValue {
                ref claim,
                ref reason,
            } => {
                warn!("Invalid claim value for '{}': {}", claim, reason);
                (StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED.to_string())
            },
            AuthError::TokenNotFound => {
                (StatusCode::UNAUTHORIZED, "token_not_found", AUTH_FAILED.to_string())
            },
            AuthError::SessionRevoked => {
                (StatusCode::UNAUTHORIZED, "session_revoked", AUTH_FAILED.to_string())
            },
            AuthError::InvalidState => {
                (StatusCode::BAD_REQUEST, "invalid_state", AUTH_FAILED.to_string())
            },
            AuthError::Forbidden { ref message } => {
                warn!("Authorization denied: {}", message);
                (StatusCode::FORBIDDEN, "forbidden", "Permission denied".to_string())
            },
            AuthError::OAuthError { ref message } => {
                warn!("OAuth provider error: {}", message);
                (StatusCode::UNAUTHORIZED, "oauth_error", AUTH_FAILED.to_string())
            },
            AuthError::SessionError { ref message } => {
                warn!("Session error: {}", message);
                (StatusCode::UNAUTHORIZED, "session_error", AUTH_FAILED.to_string())
            },
            AuthError::DatabaseError { ref message } => {
                warn!("Database error (should not reach client): {}", message);
                (StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE.to_string())
            },
            AuthError::ConfigError { ref message } => {
                warn!("Configuration error (should not reach client): {}", message);
                (StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE.to_string())
            },
            AuthError::OidcMetadataError { ref message } => {
                warn!("OIDC metadata error: {}", message);
                (StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE.to_string())
            },
            AuthError::PkceError { ref message } => {
                warn!("PKCE error: {}", message);
                (StatusCode::BAD_REQUEST, "pkce_error", AUTH_FAILED.to_string())
            },
            AuthError::Internal { ref message } => {
                warn!("Internal error (should not reach client): {}", message);
                (StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE.to_string())
            },
            AuthError::SystemTimeError { ref message } => {
                warn!("System time error (should not reach client): {}", message);
                (StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE.to_string())
            },
            AuthError::RateLimited { retry_after_secs } => (
                StatusCode::TOO_MANY_REQUESTS,
                "rate_limited",
                format!("Too many requests. Retry after {} seconds", retry_after_secs),
            ),
        };

        let body = serde_json::json!({
            "errors": [{
                "message": sanitized_message,
                "extensions": {
                    "code": error_code
                }
            }]
        });

        let mut response = (status, axum::Json(body)).into_response();
        if let Some(secs) = retry_after {
            // The `Display` form of the delay is expected to be plain digits. The
            // `Ok` guard skips the header rather than panicking if it is not a
            // valid header value.
            if let Ok(value) = axum::http::HeaderValue::from_str(&secs) {
                response
                    .headers_mut()
                    .insert(axum::http::header::RETRY_AFTER, value);
            }
        }
        response
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline claim set shared by the `AuthenticatedUser` tests.
    fn sample_claims() -> crate::Claims {
        use std::collections::HashMap;
        use crate::Claims;
        Claims {
            sub: "user123".to_string(),
            iat: 1000,
            exp: 2000,
            iss: "https://example.com".to_string(),
            aud: vec!["api".to_string()],
            extra: HashMap::new(),
        }
    }

    /// Wraps `claims` in an `AuthenticatedUser` whose id matches `sample_claims().sub`.
    fn user_from(claims: crate::Claims) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    /// Builds a user whose claims carry the given custom `(key, value)` pairs.
    fn user_with_extras(extras: &[(&str, serde_json::Value)]) -> AuthenticatedUser {
        let mut claims = sample_claims();
        for (key, value) in extras {
            claims.extra.insert(key.to_string(), value.clone());
        }
        user_from(claims)
    }

    /// Converts `error` into a response and checks only its status code.
    fn assert_status(error: AuthError, expected: StatusCode) {
        let response = error.into_response();
        assert_eq!(response.status(), expected);
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_from(sample_claims());
        let cloned = user.clone();
        // Verify the clone itself, not just the original.
        assert_eq!(cloned.user_id, "user123");
        assert_eq!(cloned.user_id, user.user_id);
        assert_eq!(cloned.claims.sub, user.claims.sub);
    }

    #[test]
    fn test_has_role_single_string() {
        let user = user_with_extras(&[("role", serde_json::json!("admin"))]);
        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let user = user_with_extras(&[("roles", serde_json::json!(["admin", "user", "editor"]))]);
        assert!(user.has_role("admin"));
        assert!(user.has_role("user"));
        assert!(user.has_role("editor"));
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_has_role_without_role_claims() {
        let user = user_from(sample_claims());
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_has_role_ignores_non_string_entries() {
        let user = user_with_extras(&[("roles", serde_json::json!(["admin", 7, null]))]);
        assert!(user.has_role("admin"));
        assert!(!user.has_role("7"));
    }

    #[test]
    fn test_has_role_non_string_role_falls_back_to_roles() {
        let user = user_with_extras(&[
            ("role", serde_json::json!(42)),
            ("roles", serde_json::json!(["editor"])),
        ]);
        assert!(user.has_role("editor"));
    }

    #[test]
    fn test_get_custom_claim() {
        let user = user_with_extras(&[("org_id", serde_json::json!("org_456"))]);
        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }

    #[test]
    fn test_invalid_token_sanitized() {
        assert_status(
            AuthError::InvalidToken {
                reason: "RS256 signature mismatch at offset 512 bytes".to_string(),
            },
            StatusCode::UNAUTHORIZED,
        );
    }

    #[test]
    fn test_missing_claim_sanitized() {
        assert_status(
            AuthError::MissingClaim {
                claim: "sensitive_user_id".to_string(),
            },
            StatusCode::UNAUTHORIZED,
        );
    }

    #[test]
    fn test_invalid_claim_value_sanitized() {
        assert_status(
            AuthError::InvalidClaimValue {
                claim: "exp".to_string(),
                reason: "Must match pattern: ^[0-9]{10,}$".to_string(),
            },
            StatusCode::UNAUTHORIZED,
        );
    }

    #[test]
    fn test_database_error_sanitized() {
        assert_status(
            AuthError::DatabaseError {
                message: "Connection to 192.168.1.100:5432 failed: timeout".to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
        );
    }

    #[test]
    fn test_config_error_sanitized() {
        assert_status(
            AuthError::ConfigError {
                message: "Secret key missing in /etc/fraiseql/config.toml".to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
        );
    }

    #[test]
    fn test_oauth_error_sanitized() {
        assert_status(
            AuthError::OAuthError {
                message: "GitHub API returned 500 from https://api.github.com/user (rate limited)"
                    .to_string(),
            },
            StatusCode::UNAUTHORIZED,
        );
    }

    #[test]
    fn test_session_error_sanitized() {
        assert_status(
            AuthError::SessionError {
                message: "Redis connection pool exhausted: 0/10 available".to_string(),
            },
            StatusCode::UNAUTHORIZED,
        );
    }

    #[test]
    fn test_forbidden_error_sanitized() {
        assert_status(
            AuthError::Forbidden {
                message: "User lacks role=admin AND permission=write:config for operation"
                    .to_string(),
            },
            StatusCode::FORBIDDEN,
        );
    }

    #[test]
    fn test_internal_error_sanitized() {
        assert_status(
            AuthError::Internal {
                message: "Panic in JWT validation thread: index out of bounds".to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
        );
    }

    #[test]
    fn test_system_time_error_sanitized() {
        assert_status(
            AuthError::SystemTimeError {
                message: "System clock jumped backward by 3600 seconds".to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
        );
    }

    #[test]
    fn test_rate_limited_error_message() {
        assert_status(
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
            StatusCode::TOO_MANY_REQUESTS,
        );
    }

    #[test]
    fn test_token_expired_returns_generic_message() {
        assert_status(AuthError::TokenExpired, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_invalid_signature_returns_generic_message() {
        assert_status(AuthError::InvalidSignature, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_invalid_state_error() {
        assert_status(AuthError::InvalidState, StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_pkce_error_returns_bad_request() {
        assert_status(
            AuthError::PkceError {
                message: "Challenge verification failed".to_string(),
            },
            StatusCode::BAD_REQUEST,
        );
    }

    #[test]
    fn test_oidc_metadata_error_returns_server_error() {
        assert_status(
            AuthError::OidcMetadataError {
                message: "Failed to fetch metadata".to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
        );
    }

    #[test]
    fn test_all_errors_have_status_codes() {
        let errors = [
            AuthError::TokenExpired,
            AuthError::InvalidSignature,
            AuthError::InvalidState,
            AuthError::TokenNotFound,
            AuthError::SessionRevoked,
            AuthError::InvalidToken {
                reason: "test".to_string(),
            },
            AuthError::MissingClaim {
                claim: "test".to_string(),
            },
            AuthError::InvalidClaimValue {
                claim: "test".to_string(),
                reason: "test".to_string(),
            },
            AuthError::OAuthError {
                message: "test".to_string(),
            },
            AuthError::SessionError {
                message: "test".to_string(),
            },
            AuthError::DatabaseError {
                message: "test".to_string(),
            },
            AuthError::ConfigError {
                message: "test".to_string(),
            },
            AuthError::OidcMetadataError {
                message: "test".to_string(),
            },
            AuthError::PkceError {
                message: "test".to_string(),
            },
            AuthError::Forbidden {
                message: "test".to_string(),
            },
            AuthError::Internal {
                message: "test".to_string(),
            },
            AuthError::SystemTimeError {
                message: "test".to_string(),
            },
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
        ];
        let allowed = [
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for error in errors {
            let status = error.into_response().status();
            assert!(allowed.contains(&status), "Unexpected status code: {}", status);
        }
    }
}