use crate::{AuthError, AuthFramework, AuthToken};
use axum::{
Json, Router,
extract::{FromRequestParts, Request, State},
http::{StatusCode, header::AUTHORIZATION, request::Parts},
middleware::Next,
response::{IntoResponse, Response},
routing::post,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Authentication requirement describing the permissions and roles a request should carry.
///
/// Built with [`RequireAuth::new`] followed by `with_permissions` / `with_roles`; it is also
/// re-exported from this module as `AuthMiddleware`.
///
/// NOTE(review): no code in this view reads these fields or enforces them — confirm where
/// (or whether) the requirements are actually checked.
#[derive(Clone)]
pub struct RequireAuth {
    /// Permission names the caller is meant to hold.
    pub required_permissions: Vec<String>,
    /// Role names the caller is meant to hold.
    pub required_roles: Vec<String>,
}
/// A single required permission, optionally scoped to a named resource.
///
/// Built with [`RequirePermission::new`] and, optionally, `for_resource`.
/// NOTE(review): no code in this view consumes this type — confirm where it is enforced.
#[derive(Clone)]
pub struct RequirePermission {
    /// Name of the permission that is required.
    pub permission: String,
    /// Resource the permission applies to; `None` when not scoped to a resource.
    pub resource: Option<String>,
}
/// Identity attached to a request by this module's `FromRequestParts` extractor.
///
/// NOTE(review): `token` carries the raw access token, and both `Debug` and `Serialize` are
/// derived, so the token may surface in logs or serialized output depending on
/// `AuthToken`'s own impls — confirm before logging or returning this type.
#[derive(Debug, Clone, Serialize)]
pub struct AuthenticatedUser {
    /// Identifier of the authenticated user.
    pub user_id: String,
    /// Permission names granted to the user.
    pub permissions: Vec<String>,
    /// Role names assigned to the user.
    pub roles: Vec<String>,
    /// The token the user authenticated with.
    pub token: AuthToken,
}
/// Builder for the authentication routes; each field is the URL path of one endpoint.
///
/// Defaults (`/auth/login`, `/auth/logout`, `/auth/refresh`, `/auth/profile`) are set by
/// `AuthRouter::new` and can be overridden with the `*_route` builder methods.
pub struct AuthRouter {
    /// Path of the login endpoint.
    login_path: String,
    /// Path of the logout endpoint (NOTE(review): not mounted by `build`).
    logout_path: String,
    /// Path of the token-refresh endpoint.
    refresh_path: String,
    /// Path of the profile endpoint (NOTE(review): not mounted by `build`).
    profile_path: String,
}
/// JSON body accepted by the login endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so the password is redacted and
/// never ends up in logs, traces or panic messages.
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Username to log in as.
    pub username: String,
    /// Password exactly as sent by the client.
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("username", &self.username)
            .field("password", &"[REDACTED]")
            .finish()
    }
}
/// JSON body returned by the login endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so the access and refresh tokens
/// are redacted rather than written to logs; the serialized JSON is unaffected.
#[derive(Serialize)]
pub struct LoginResponse {
    /// Access token issued for the session.
    pub access_token: String,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Access-token lifetime in seconds.
    pub expires_in: u64,
    /// Details about the logged-in user.
    pub user: UserInfo,
}

impl std::fmt::Debug for LoginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a refresh token exists, never its value.
        let refresh = self.refresh_token.as_ref().map(|_| "[REDACTED]");
        f.debug_struct("LoginResponse")
            .field("access_token", &"[REDACTED]")
            .field("refresh_token", &refresh)
            .field("expires_in", &self.expires_in)
            .field("user", &self.user)
            .finish()
    }
}
/// Public user details returned by the login and profile handlers.
#[derive(Debug, Serialize)]
pub struct UserInfo {
    /// User identifier.
    pub id: String,
    /// Username when known (set on login, `None` in the profile response).
    pub username: Option<String>,
    /// Email address; the handlers in this module always leave it `None`.
    pub email: Option<String>,
    /// Role names of the user.
    pub roles: Vec<String>,
}
/// Wraps `handler` in a [`ProtectedHandler`] that starts with no permission or role
/// requirements; chain `require_permissions` / `require_roles` (or their singular forms)
/// to add them.
///
/// NOTE(review): `T` appears nowhere in the signature, so it can never be inferred and
/// callers must name it explicitly (e.g. `protected::<_, ()>(handler)`).
pub fn protected<F: Clone, T>(handler: F) -> ProtectedHandler<F> {
    ProtectedHandler::<F>::new(handler)
}
/// A handler paired with the permissions and roles it is meant to require.
///
/// Created through [`protected`] and configured with the `require_*` builder methods.
/// NOTE(review): no code in this view reads these fields or makes this type usable as an
/// axum handler — confirm where the requirements are enforced.
#[derive(Clone)]
pub struct ProtectedHandler<F> {
    /// The wrapped handler.
    handler: F,
    /// Permission names the handler is meant to require.
    required_permissions: Vec<String>,
    /// Role names the handler is meant to require.
    required_roles: Vec<String>,
}
impl RequireAuth {
pub fn new() -> Self {
Self {
required_permissions: Vec::new(),
required_roles: Vec::new(),
}
}
pub fn with_permissions(mut self, permissions: &[&str]) -> Self {
self.required_permissions = permissions.iter().map(|p| p.to_string()).collect();
self
}
pub fn with_roles(mut self, roles: &[&str]) -> Self {
self.required_roles = roles.iter().map(|r| r.to_string()).collect();
self
}
}
impl RequirePermission {
    /// Creates a requirement for `permission` that is not scoped to any resource.
    pub fn new(permission: impl Into<String>) -> Self {
        let permission: String = permission.into();
        Self {
            resource: None,
            permission,
        }
    }

    /// Scopes the requirement to `resource`, replacing any previously set resource.
    pub fn for_resource(self, resource: impl Into<String>) -> Self {
        Self {
            resource: Some(resource.into()),
            ..self
        }
    }
}
impl<F> ProtectedHandler<F> {
fn new(handler: F) -> Self {
Self {
handler,
required_permissions: Vec::new(),
required_roles: Vec::new(),
}
}
pub fn require_permissions(mut self, permissions: &[&str]) -> Self {
self.required_permissions = permissions.iter().map(|p| p.to_string()).collect();
self
}
pub fn require_roles(mut self, roles: &[&str]) -> Self {
self.required_roles = roles.iter().map(|r| r.to_string()).collect();
self
}
pub fn require_permission(mut self, permission: &str) -> Self {
self.required_permissions = vec![permission.to_string()];
self
}
pub fn require_role(mut self, role: &str) -> Self {
self.required_roles = vec![role.to_string()];
self
}
}
impl AuthRouter {
    /// Creates a builder with the default `/auth/login`, `/auth/logout`, `/auth/refresh`
    /// and `/auth/profile` paths.
    pub fn new() -> Self {
        let under_auth = |endpoint: &str| format!("/auth/{endpoint}");
        Self {
            login_path: under_auth("login"),
            logout_path: under_auth("logout"),
            refresh_path: under_auth("refresh"),
            profile_path: under_auth("profile"),
        }
    }

    /// Overrides the login endpoint path.
    pub fn login_route(self, path: impl Into<String>) -> Self {
        Self {
            login_path: path.into(),
            ..self
        }
    }

    /// Overrides the logout endpoint path.
    pub fn logout_route(self, path: impl Into<String>) -> Self {
        Self {
            logout_path: path.into(),
            ..self
        }
    }

    /// Overrides the token-refresh endpoint path.
    pub fn refresh_route(self, path: impl Into<String>) -> Self {
        Self {
            refresh_path: path.into(),
            ..self
        }
    }

    /// Overrides the profile endpoint path.
    pub fn profile_route(self, path: impl Into<String>) -> Self {
        Self {
            profile_path: path.into(),
            ..self
        }
    }

    /// Builds a [`Router`] with the login and refresh endpoints mounted as `POST` routes.
    ///
    /// NOTE(review): `logout_path` and `profile_path` are stored but never mounted here, so
    /// `logout_route` / `profile_route` currently have no effect on the built router.
    pub fn build(self) -> Router<Arc<AuthFramework>> {
        let Self {
            login_path,
            refresh_path,
            ..
        } = self;
        Router::new()
            .route(login_path.as_str(), post(login_handler))
            .route(refresh_path.as_str(), post(refresh_handler))
    }
}
/// Handles login requests by issuing an auth token for `request.username`.
///
/// Responds with the access/refresh tokens, the token lifetime in seconds (derived from the
/// issued token rather than assumed) and basic user info.
///
/// SECURITY(review): `request.password` is never checked — any username is issued a token
/// with `read`/`write` scopes. Wire credential verification through the framework before
/// exposing this route.
///
/// # Errors
/// Propagates any [`AuthError`] returned by `create_auth_token`.
async fn login_handler(
    State(auth): State<Arc<AuthFramework>>,
    Json(request): Json<LoginRequest>,
) -> Result<impl IntoResponse, AuthError> {
    let token = auth
        .create_auth_token(
            &request.username,
            vec!["read".to_string(), "write".to_string()],
            "jwt",
            None,
        )
        .await?;
    // Report the lifetime the framework actually assigned instead of a hard-coded hour;
    // a negative span (expiry before issuance) is clamped to 0.
    let expires_in =
        u64::try_from((token.expires_at - token.issued_at).num_seconds()).unwrap_or(0);
    let response = LoginResponse {
        access_token: token.access_token.clone(),
        refresh_token: token.refresh_token.clone(),
        expires_in,
        user: UserInfo {
            id: token.user_id.clone(),
            username: Some(request.username),
            email: None,
            roles: token.roles.clone(),
        },
    };
    Ok(Json(response))
}
/// Logs the authenticated caller out.
///
/// Only records the logout through `tracing` and replies with a confirmation message; the
/// token is not revoked here, so it remains usable until it expires.
/// NOTE(review): this handler is not mounted by `AuthRouter::build`.
async fn logout_handler(
    State(_auth): State<Arc<AuthFramework>>,
    user: AuthenticatedUser,
) -> Result<impl IntoResponse, AuthError> {
    let who = &user.user_id;
    tracing::info!("User {} logged out", who);
    let confirmation = serde_json::json!({"message": "Successfully logged out"});
    Ok(Json(confirmation))
}
async fn refresh_handler(
State(_auth): State<Arc<AuthFramework>>,
) -> Result<impl IntoResponse, AuthError> {
Ok(Json(
serde_json::json!({"message": "Token refresh not implemented"}),
))
}
async fn profile_handler(user: AuthenticatedUser) -> Result<impl IntoResponse, AuthError> {
Ok(Json(UserInfo {
id: user.user_id,
username: None, email: None, roles: user.roles,
}))
}
/// Axum middleware that requires a valid JWT bearer token.
///
/// Extracts the `Authorization: Bearer <token>` value and validates it with the framework's
/// token manager. On success it stores the raw token `String` in the request extensions
/// and runs the rest of the stack. A missing, non-Bearer or invalid token short-circuits
/// with the corresponding [`AuthError`].
///
/// NOTE(review): the validated claims are discarded; downstream code only sees the raw
/// token string in the extensions.
pub async fn auth_middleware(
    State(auth): State<Arc<AuthFramework>>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let bearer = extract_bearer_token(&request)?;
    auth.token_manager().validate_jwt_token(&bearer)?;
    request.extensions_mut().insert(bearer);
    Ok(next.run(request).await)
}
/// Returns the token from a request's `Authorization: Bearer <token>` header.
///
/// # Errors
/// - `TokenError::Missing` if the header is absent or not valid visible ASCII;
/// - `TokenError::Invalid` if the header does not start with `"Bearer "`.
fn extract_bearer_token(request: &Request) -> Result<String, AuthError> {
    let Some(header_value) = request
        .headers()
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
    else {
        return Err(AuthError::Token(crate::errors::TokenError::Missing));
    };
    match header_value.strip_prefix("Bearer ") {
        Some(token) => Ok(token.to_owned()),
        None => Err(AuthError::Token(crate::errors::TokenError::Invalid {
            message: "Authorization header must use Bearer scheme".to_string(),
        })),
    }
}
impl<S> FromRequestParts<S> for AuthenticatedUser
where
S: Send + Sync,
Arc<AuthFramework>: FromRequestParts<S>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let _auth = Arc::<AuthFramework>::from_request_parts(parts, state)
.await
.map_err(|_| AuthError::internal("Failed to extract auth framework from state"))?;
let token_str = extract_bearer_token_from_parts(parts)?;
let user_id = "demo_user".to_string(); let permissions = vec!["read".to_string(), "write".to_string()]; let roles = vec!["user".to_string()];
let token = AuthToken {
token_id: "demo_token_id".to_string(),
user_id: user_id.clone(),
access_token: token_str,
token_type: Some("Bearer".to_string()),
subject: Some(user_id.clone()),
issuer: Some("auth-framework".to_string()),
refresh_token: None,
issued_at: chrono::Utc::now(),
expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
scopes: vec!["read".to_string(), "write".to_string()],
auth_method: "jwt".to_string(),
client_id: None,
user_profile: None,
permissions: permissions.clone(),
roles: roles.clone(),
metadata: crate::tokens::TokenMetadata::default(),
};
Ok(AuthenticatedUser {
user_id,
permissions,
roles,
token,
})
}
}
/// Returns the token from the `Authorization: Bearer <token>` header in `parts`.
///
/// # Errors
/// - `TokenError::Missing` if the header is absent or not valid visible ASCII;
/// - `TokenError::Invalid` if the header does not start with `"Bearer "`.
fn extract_bearer_token_from_parts(parts: &Parts) -> Result<String, AuthError> {
    let raw_header = parts
        .headers
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok());
    let header = raw_header.ok_or_else(|| AuthError::Token(crate::errors::TokenError::Missing))?;
    header
        .strip_prefix("Bearer ")
        .map(str::to_string)
        .ok_or_else(|| {
            AuthError::Token(crate::errors::TokenError::Invalid {
                message: "Authorization header must use Bearer scheme".to_string(),
            })
        })
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, message) = match &self {
AuthError::Token(_) => (StatusCode::UNAUTHORIZED, "Authentication required"),
AuthError::Permission(_) => (StatusCode::FORBIDDEN, "Insufficient permissions"),
AuthError::RateLimit { .. } => (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded"),
AuthError::Configuration { .. } | AuthError::Storage(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error")
}
_ => (StatusCode::BAD_REQUEST, "Bad request"),
};
let body = Json(serde_json::json!({
"error": message,
"details": self.to_string()
}));
(status, body).into_response()
}
}
/// Extension methods intended to attach authentication requirements to an axum [`Router`].
///
/// Check the implementation for what each method actually enforces.
/// NOTE(review): the type parameter `S` is not used by any method signature; presumably it
/// mirrors the router's state type — confirm.
pub trait AuthRouterExt<S> {
    /// Intended to require authentication for requests routed through `self`.
    fn require_auth(self) -> Self;
    /// Intended to require `permission` for requests routed through `self`.
    fn require_permission(self, permission: &str) -> Self;
    /// Intended to require `role` for requests routed through `self`.
    fn require_role(self, role: &str) -> Self;
}
impl<S> AuthRouterExt<S> for Router<S>
where
S: Clone + Send + Sync + 'static,
{
fn require_auth(self) -> Self {
self.layer(axum::middleware::from_fn_with_state(
(), |_state: (), request: axum::extract::Request, next: axum::middleware::Next| async move {
next.run(request).await
},
))
}
fn require_permission(self, _permission: &str) -> Self {
self
}
fn require_role(self, _role: &str) -> Self {
self
}
}
pub use RequireAuth as AuthMiddleware;