use std::sync::Arc;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use crate::{
error::{AuthError, Result},
jwt::{Claims, JwtValidator},
session::SessionStore,
};
/// A user identifier paired with the token claims it is associated with.
///
/// The type is cloneable and (de)serializable through its derives, so it can be
/// copied or round-tripped through serde wherever `Claims` can.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticatedUser {
/// Identifier of the user.
// NOTE(review): presumably mirrors `claims.sub` — confirm against the code that builds this struct.
pub user_id: String,
/// Claims attached to this user; the role and custom-claim helpers read from here.
pub claims: Claims,
}
impl AuthenticatedUser {
    /// Looks up a custom (non-standard) claim by name.
    ///
    /// Delegates to `Claims::get_custom` and returns `None` when the claim is absent.
    pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
        self.claims.get_custom(key)
    }

    /// Reports whether the user holds `role`.
    ///
    /// Two claim shapes are recognised and checked independently:
    ///
    /// * `role`  — a single JSON string that must equal `role`;
    /// * `roles` — a JSON array; any string element equal to `role` matches.
    ///
    /// Claims of any other JSON type, and non-string array elements, never match.
    /// The comparison is exact (case-sensitive).
    pub fn has_role(&self, role: &str) -> bool {
        let single_role_matches = matches!(
            self.claims.get_custom("role"),
            Some(serde_json::Value::String(user_role)) if user_role == role
        );

        // The `roles` array is consulted even when a non-matching `role` string is
        // present. Previously a token carrying both claims had its `roles` list
        // silently ignored, so a role granted only through the array was denied.
        single_role_matches
            || match self.claims.get_custom("roles") {
                Some(serde_json::Value::Array(roles)) => {
                    roles.iter().any(|entry| entry.as_str() == Some(role))
                },
                _ => false,
            }
    }
}
/// Bundles a JWT validator with the public key handed to it on every validation.
///
/// The session store and the `optional` flag are kept under underscore-prefixed
/// field names and are not read by any code shown in this part of the file.
pub struct AuthMiddleware {
/// Validator that `validate_token` delegates to.
validator: Arc<JwtValidator>,
/// Session store supplied at construction; currently stored only.
_session_store: Arc<dyn SessionStore>,
/// Key bytes passed to the validator together with each token.
// NOTE(review): key encoding (PEM/DER/raw) is defined by `JwtValidator::validate` — confirm there.
public_key: Vec<u8>,
/// `optional` flag supplied at construction; currently stored only.
_optional: bool,
}
impl AuthMiddleware {
pub fn new(
validator: Arc<JwtValidator>,
session_store: Arc<dyn SessionStore>,
public_key: Vec<u8>,
optional: bool,
) -> Self {
Self {
validator,
_session_store: session_store,
public_key,
_optional: optional,
}
}
pub async fn validate_token(&self, token: &str) -> Result<Claims> {
self.validator.validate(token, &self.public_key)
}
}
impl AuthError {
    /// Translates the error into the three pieces of its client-facing response:
    /// HTTP status, a stable machine-readable code, and a human-readable message.
    ///
    /// Payload fields of the variants are never copied into the message, with one
    /// exception: `RateLimited` interpolates its retry delay.
    // The match is a flat, exhaustive lookup table; its length is what trips the lint.
    #[allow(clippy::cognitive_complexity)]
    fn response_parts(&self) -> (StatusCode, &'static str, String) {
        // Most variants share the same deliberately vague message and differ only in
        // status and code.
        fn auth_failed(
            status: StatusCode,
            code: &'static str,
        ) -> (StatusCode, &'static str, String) {
            (status, code, "Authentication failed".to_string())
        }

        match self {
            Self::RateLimited { retry_after_secs } => (
                StatusCode::TOO_MANY_REQUESTS,
                "rate_limited",
                format!("Too many requests. Retry after {retry_after_secs} seconds"),
            ),
            Self::Forbidden { .. } => {
                (StatusCode::FORBIDDEN, "forbidden", "Permission denied".to_string())
            },
            Self::DatabaseError { .. }
            | Self::ConfigError { .. }
            | Self::OidcMetadataError { .. }
            | Self::Internal { .. }
            | Self::SystemTimeError { .. } => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Service temporarily unavailable".to_string(),
            ),
            Self::InvalidState => auth_failed(StatusCode::BAD_REQUEST, "invalid_state"),
            Self::PkceError { .. } => auth_failed(StatusCode::BAD_REQUEST, "pkce_error"),
            Self::TokenExpired => auth_failed(StatusCode::UNAUTHORIZED, "token_expired"),
            Self::InvalidSignature => {
                auth_failed(StatusCode::UNAUTHORIZED, "invalid_signature")
            },
            Self::TokenNotFound => auth_failed(StatusCode::UNAUTHORIZED, "token_not_found"),
            Self::SessionRevoked => auth_failed(StatusCode::UNAUTHORIZED, "session_revoked"),
            Self::OAuthError { .. } => auth_failed(StatusCode::UNAUTHORIZED, "oauth_error"),
            Self::SessionError { .. } => auth_failed(StatusCode::UNAUTHORIZED, "session_error"),
            Self::InvalidToken { .. }
            | Self::MissingClaim { .. }
            | Self::InvalidClaimValue { .. }
            | Self::MissingNonce
            | Self::NonceMismatch
            | Self::MissingAuthTime
            | Self::SessionTooOld { .. } => {
                auth_failed(StatusCode::UNAUTHORIZED, "invalid_token")
            },
        }
    }

    /// Emits a warn-level tracing event carrying the internal detail of the error.
    ///
    /// This is the counterpart of `response_parts`: the detail that is withheld from
    /// the client is written to the log instead. Variants listed in the first arm
    /// produce no log output here.
    // One arm per log line; the lint fires on the arm count, not on real branching.
    #[allow(clippy::cognitive_complexity)]
    fn log_security_details(&self) {
        match self {
            // Nothing is logged for these variants.
            Self::TokenExpired
            | Self::InvalidSignature
            | Self::TokenNotFound
            | Self::SessionRevoked
            | Self::InvalidState
            | Self::RateLimited { .. } => {},

            // OIDC checks: the error's own Display output is the detail.
            Self::MissingNonce | Self::NonceMismatch => {
                tracing::warn!("OIDC nonce validation failed: {self}");
            },
            Self::MissingAuthTime | Self::SessionTooOld { .. } => {
                tracing::warn!("OIDC auth_time validation failed: {self}");
            },

            // Token / claim problems.
            Self::InvalidToken { reason } => {
                tracing::warn!("Invalid token error: {reason}");
            },
            Self::MissingClaim { claim } => {
                tracing::warn!("Missing required claim: {claim}");
            },
            Self::InvalidClaimValue { claim, reason } => {
                tracing::warn!("Invalid claim value for '{claim}': {reason}");
            },

            // Authorization, provider and session problems.
            Self::Forbidden { message } => {
                tracing::warn!("Authorization denied: {message}");
            },
            Self::OAuthError { message } => {
                tracing::warn!("OAuth provider error: {message}");
            },
            Self::SessionError { message } => {
                tracing::warn!("Session error: {message}");
            },
            Self::OidcMetadataError { message } => {
                tracing::warn!("OIDC metadata error: {message}");
            },
            Self::PkceError { message } => {
                tracing::warn!("PKCE error: {message}");
            },

            // Server-side failures whose detail must stay out of responses.
            Self::DatabaseError { message } => {
                tracing::warn!("Database error (should not reach client): {message}");
            },
            Self::ConfigError { message } => {
                tracing::warn!("Configuration error (should not reach client): {message}");
            },
            Self::Internal { message } => {
                tracing::warn!("Internal error (should not reach client): {message}");
            },
            Self::SystemTimeError { message } => {
                tracing::warn!("System time error (should not reach client): {message}");
            },
        }
    }
}
impl IntoResponse for AuthError {
    /// Converts the error into an HTTP response with a GraphQL-style JSON body:
    /// `{"errors": [{"message": ..., "extensions": {"code": ...}}]}`.
    ///
    /// Internal details are logged first via `log_security_details`; only the
    /// sanitized message and code from `response_parts` are sent to the client.
    /// `RateLimited` responses additionally carry a `Retry-After` header so that
    /// clients can back off without parsing the message text.
    fn into_response(self) -> Response {
        self.log_security_details();
        let (status, error_code, sanitized_message) = self.response_parts();
        let body = serde_json::json!({
            "errors": [{
                "message": sanitized_message,
                "extensions": {
                    "code": error_code
                }
            }]
        });
        let mut response = (status, axum::Json(body)).into_response();

        if let Self::RateLimited { retry_after_secs } = &self {
            // Going through the Display text keeps this independent of the field's
            // integer type; a value that is not a legal header value is skipped
            // rather than turned into a failure of the error path itself.
            if let Ok(delay) = axum::http::HeaderValue::from_str(&retry_after_secs.to_string()) {
                response.headers_mut().insert(axum::http::header::RETRY_AFTER, delay);
            }
        }

        response
    }
}
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Generic message sent for authentication-type failures.
    const AUTH_FAILED: &str = "Authentication failed";
    /// Generic message sent for server-side failures.
    const UNAVAILABLE: &str = "Service temporarily unavailable";

    /// Builds the fixture user shared by the claim tests, inserting `custom` as
    /// extra claims.
    fn user_with(custom: Vec<(&str, serde_json::Value)>) -> AuthenticatedUser {
        use std::collections::HashMap;

        use crate::Claims;

        let mut claims = Claims {
            sub: "user123".to_string(),
            iat: 1000,
            exp: 2000,
            iss: "https://example.com".to_string(),
            aud: vec!["api".to_string()],
            extra: HashMap::new(),
        };
        for (key, value) in custom {
            claims.extra.insert(key.to_string(), value);
        }
        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    /// Asserts the full client-visible contract of `error`: status, code and the
    /// exact message. Pinning the exact message is what proves that none of the
    /// internal detail carried by the variant leaks into the response.
    fn assert_response(error: AuthError, status: StatusCode, code: &str, message: &str) {
        let (actual_status, actual_code, actual_message) = error.response_parts();
        assert_eq!(actual_status, status);
        assert_eq!(actual_code, code);
        assert_eq!(actual_message, message);
        // The response that is actually built must carry the same status.
        assert_eq!(error.into_response().status(), status);
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_with(Vec::new());
        let cloned = user.clone();
        assert_eq!(user.user_id, "user123");
        assert_eq!(cloned.user_id, user.user_id);
        assert_eq!(cloned.claims.sub, user.claims.sub);
    }

    #[test]
    fn test_has_role_single_string() {
        let user = user_with(vec![("role", serde_json::json!("admin"))]);
        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let user = user_with(vec![("roles", serde_json::json!(["admin", "user", "editor"]))]);
        assert!(user.has_role("admin"));
        assert!(user.has_role("user"));
        assert!(user.has_role("editor"));
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_has_role_without_role_claims() {
        let user = user_with(Vec::new());
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_has_role_ignores_non_string_entries() {
        let user = user_with(vec![("roles", serde_json::json!([1, null, "admin"]))]);
        assert!(user.has_role("admin"));
        assert!(!user.has_role("1"));
    }

    #[test]
    fn test_get_custom_claim() {
        let user = user_with(vec![("org_id", serde_json::json!("org_456"))]);
        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }

    #[test]
    fn test_invalid_token_sanitized() {
        let error = AuthError::InvalidToken {
            reason: "RS256 signature mismatch at offset 512 bytes".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_missing_claim_sanitized() {
        let error = AuthError::MissingClaim {
            claim: "sensitive_user_id".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_invalid_claim_value_sanitized() {
        let error = AuthError::InvalidClaimValue {
            claim: "exp".to_string(),
            reason: "Must match pattern: ^[0-9]{10,}$".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_oidc_validation_errors_sanitized() {
        for error in [
            AuthError::MissingNonce,
            AuthError::NonceMismatch,
            AuthError::MissingAuthTime,
        ] {
            assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
        }
    }

    #[test]
    fn test_database_error_sanitized() {
        let error = AuthError::DatabaseError {
            message: "Connection to 192.168.1.100:5432 failed: timeout".to_string(),
        };
        assert_response(error, StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE);
    }

    #[test]
    fn test_config_error_sanitized() {
        let error = AuthError::ConfigError {
            message: "Secret key missing in /etc/fraiseql/config.toml".to_string(),
        };
        assert_response(error, StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE);
    }

    #[test]
    fn test_oauth_error_sanitized() {
        let error = AuthError::OAuthError {
            message: "GitHub API returned 500 from https://api.github.com/user (rate limited)"
                .to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "oauth_error", AUTH_FAILED);
    }

    #[test]
    fn test_session_error_sanitized() {
        let error = AuthError::SessionError {
            message: "Redis connection pool exhausted: 0/10 available".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "session_error", AUTH_FAILED);
    }

    #[test]
    fn test_forbidden_error_sanitized() {
        let error = AuthError::Forbidden {
            message: "User lacks role=admin AND permission=write:config for operation".to_string(),
        };
        assert_response(error, StatusCode::FORBIDDEN, "forbidden", "Permission denied");
    }

    #[test]
    fn test_internal_error_sanitized() {
        let error = AuthError::Internal {
            message: "Panic in JWT validation thread: index out of bounds".to_string(),
        };
        assert_response(error, StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE);
    }

    #[test]
    fn test_system_time_error_sanitized() {
        let error = AuthError::SystemTimeError {
            message: "System clock jumped backward by 3600 seconds".to_string(),
        };
        assert_response(error, StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE);
    }

    #[test]
    fn test_rate_limited_error_message() {
        let error = AuthError::RateLimited {
            retry_after_secs: 60,
        };
        assert_response(
            error,
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limited",
            "Too many requests. Retry after 60 seconds",
        );
    }

    #[test]
    fn test_token_expired_returns_generic_message() {
        assert_response(
            AuthError::TokenExpired,
            StatusCode::UNAUTHORIZED,
            "token_expired",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_invalid_signature_returns_generic_message() {
        assert_response(
            AuthError::InvalidSignature,
            StatusCode::UNAUTHORIZED,
            "invalid_signature",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_token_not_found_returns_generic_message() {
        assert_response(
            AuthError::TokenNotFound,
            StatusCode::UNAUTHORIZED,
            "token_not_found",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_session_revoked_returns_generic_message() {
        assert_response(
            AuthError::SessionRevoked,
            StatusCode::UNAUTHORIZED,
            "session_revoked",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_invalid_state_error() {
        assert_response(
            AuthError::InvalidState,
            StatusCode::BAD_REQUEST,
            "invalid_state",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_pkce_error_returns_bad_request() {
        let error = AuthError::PkceError {
            message: "Challenge verification failed".to_string(),
        };
        assert_response(error, StatusCode::BAD_REQUEST, "pkce_error", AUTH_FAILED);
    }

    #[test]
    fn test_oidc_metadata_error_returns_server_error() {
        let error = AuthError::OidcMetadataError {
            message: "Failed to fetch metadata".to_string(),
        };
        assert_response(error, StatusCode::INTERNAL_SERVER_ERROR, "server_error", UNAVAILABLE);
    }

    #[test]
    fn test_all_errors_have_status_codes() {
        // `SessionTooOld` is not listed: its fields are not visible from this file,
        // so it cannot be constructed here.
        let errors = vec![
            AuthError::TokenExpired,
            AuthError::InvalidSignature,
            AuthError::InvalidState,
            AuthError::TokenNotFound,
            AuthError::SessionRevoked,
            AuthError::MissingNonce,
            AuthError::NonceMismatch,
            AuthError::MissingAuthTime,
            AuthError::InvalidToken {
                reason: "test".to_string(),
            },
            AuthError::MissingClaim {
                claim: "test".to_string(),
            },
            AuthError::InvalidClaimValue {
                claim: "test".to_string(),
                reason: "test".to_string(),
            },
            AuthError::OAuthError {
                message: "test".to_string(),
            },
            AuthError::SessionError {
                message: "test".to_string(),
            },
            AuthError::DatabaseError {
                message: "test".to_string(),
            },
            AuthError::ConfigError {
                message: "test".to_string(),
            },
            AuthError::OidcMetadataError {
                message: "test".to_string(),
            },
            AuthError::PkceError {
                message: "test".to_string(),
            },
            AuthError::Forbidden {
                message: "test".to_string(),
            },
            AuthError::Internal {
                message: "test".to_string(),
            },
            AuthError::SystemTimeError {
                message: "test".to_string(),
            },
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
        ];
        let allowed = [
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::TOO_MANY_REQUESTS,
        ];
        for error in errors {
            let status = error.into_response().status();
            assert!(allowed.contains(&status), "Unexpected status code: {}", status);
        }
    }
}