use axum::{
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::Response,
Json,
};
use serde::Serialize;
use crate::portal::auth::{AuthService, Claims};
/// Serializable error payload that this module pairs with an HTTP status code
/// whenever an authentication or scope check rejects a request.
#[derive(Debug, Serialize)]
pub struct AuthErrorResponse {
/// Human-readable description of why the request was rejected.
pub error: String,
/// Short machine-readable identifier for the failure, e.g. `MISSING_TOKEN`,
/// `INVALID_TOKEN`, `WRONG_TOKEN_TYPE`, `NOT_AUTHENTICATED` or `INSUFFICIENT_SCOPE`.
pub code: String,
}
/// Extracts the bearer credential from the request's `Authorization` header.
///
/// Returns `None` when the header is absent, cannot be read as text, uses an
/// authentication scheme other than `Bearer`, or carries no credential after
/// the scheme.
///
/// The scheme name is compared case-insensitively because HTTP authentication
/// scheme names are case-insensitive (RFC 9110 §11.1), so `bearer abc` and
/// `BEARER abc` are accepted as well as `Bearer abc`.
fn extract_token(req: &Request) -> Option<&str> {
    let value = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;

    // The header has the shape `<scheme> <credentials>`; without a separating
    // space there is no credential to return.
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }

    // Tolerate extra padding around the credential and treat an empty one
    // (e.g. a bare `Bearer `) the same as a missing header.
    Some(credentials.trim()).filter(|token| !token.is_empty())
}
pub async fn require_auth(
req: Request,
next: Next,
) -> Result<Response, (StatusCode, Json<AuthErrorResponse>)> {
let token = extract_token(&req).ok_or_else(|| {
(
StatusCode::UNAUTHORIZED,
Json(AuthErrorResponse {
error: "Missing or invalid Authorization header".to_string(),
code: "MISSING_TOKEN".to_string(),
}),
)
})?;
let auth_service = AuthService::new(Default::default());
let claims = auth_service.validate_token(token).map_err(|e| {
(
StatusCode::UNAUTHORIZED,
Json(AuthErrorResponse {
error: e.to_string(),
code: "INVALID_TOKEN".to_string(),
}),
)
})?;
if claims.token_type != crate::portal::auth::TokenType::Access {
return Err((
StatusCode::UNAUTHORIZED,
Json(AuthErrorResponse {
error: "Invalid token type".to_string(),
code: "WRONG_TOKEN_TYPE".to_string(),
}),
));
}
let mut req = req;
req.extensions_mut().insert(claims);
Ok(next.run(req).await)
}
/// Middleware that attaches [`Claims`] to the request when a valid access
/// token is supplied, and otherwise forwards the request untouched.
///
/// Unlike `require_auth`, this never rejects: a missing token, a token that
/// fails validation, or a token of a non-access type all result in the
/// request reaching `next` without claims in its extensions.
pub async fn optional_auth(mut req: Request, next: Next) -> Response {
    // The auth service is only constructed when a token is actually present.
    let verified = extract_token(&req).and_then(|token| {
        AuthService::new(Default::default())
            .validate_token(token)
            .ok()
            .filter(|claims| claims.token_type == crate::portal::auth::TokenType::Access)
    });

    if let Some(claims) = verified {
        req.extensions_mut().insert(claims);
    }

    next.run(req).await
}
/// Builds a middleware function that enforces `required_scope`.
///
/// The returned closure reads the [`Claims`] already stored in the request
/// extensions and forwards the request only when the claims' scopes contain
/// `required_scope` or the `admin` scope.
///
/// # Errors
///
/// The produced future resolves to:
/// - `401 Unauthorized` / `NOT_AUTHENTICATED` when no claims are present in
///   the request extensions,
/// - `403 Forbidden` / `INSUFFICIENT_SCOPE` when neither the required scope
///   nor `admin` is among the claims' scopes.
// The boxed-future return type is spelled out in full, hence the lint allowance.
#[allow(clippy::type_complexity)]
pub fn require_scope(
    required_scope: &'static str,
) -> impl Fn(
    Request,
    Next,
) -> std::pin::Pin<
    Box<
        dyn std::future::Future<Output = Result<Response, (StatusCode, Json<AuthErrorResponse>)>>
            + Send,
    >,
> + Clone {
    move |req: Request, next: Next| {
        Box::pin(async move {
            // Reduce the claims lookup to a plain bool so that no borrow of
            // `req` is alive when the request is moved into `next`.
            let granted = match req.extensions().get::<Claims>() {
                Some(claims) => claims
                    .scopes
                    .iter()
                    .any(|scope| scope == required_scope || scope == "admin"),
                None => {
                    return Err((
                        StatusCode::UNAUTHORIZED,
                        Json(AuthErrorResponse {
                            error: "Not authenticated".to_string(),
                            code: "NOT_AUTHENTICATED".to_string(),
                        }),
                    ));
                }
            };

            if granted {
                return Ok(next.run(req).await);
            }

            Err((
                StatusCode::FORBIDDEN,
                Json(AuthErrorResponse {
                    error: format!("Missing required scope: {}", required_scope),
                    code: "INSUFFICIENT_SCOPE".to_string(),
                }),
            ))
        })
    }
}
/// Returns the [`Claims`] stored in the request's extensions, if any.
///
/// Yields `None` when no `Claims` value has been inserted into the request.
pub fn get_claims(req: &Request) -> Option<&Claims> {
    let extensions = req.extensions();
    extensions.get::<Claims>()
}
/// Newtype wrapper around [`Claims`] that implements
/// `axum::extract::FromRequestParts`, yielding a clone of the claims found in
/// the request extensions.
#[derive(Debug, Clone)]
pub struct AuthClaims(pub Claims);
#[axum::async_trait]
impl<S> axum::extract::FromRequestParts<S> for AuthClaims
where
S: Send + Sync,
{
type Rejection = (StatusCode, Json<AuthErrorResponse>);
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Claims>()
.cloned()
.map(AuthClaims)
.ok_or_else(|| {
(
StatusCode::UNAUTHORIZED,
Json(AuthErrorResponse {
error: "Not authenticated".to_string(),
code: "NOT_AUTHENTICATED".to_string(),
}),
)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Builds an empty-bodied request, optionally carrying the given
    /// `Authorization` header value.
    fn request_with_authorization(value: Option<&str>) -> Request<Body> {
        let builder = match value {
            Some(header_value) => Request::builder().header("Authorization", header_value),
            None => Request::builder(),
        };
        builder.body(Body::empty()).unwrap()
    }

    /// A well-formed bearer header yields the token that follows the scheme.
    #[test]
    fn test_extract_token_valid() {
        let req = request_with_authorization(Some("Bearer test_token_123"));
        assert_eq!(extract_token(&req), Some("test_token_123"));
    }

    /// A request without an `Authorization` header yields no token.
    #[test]
    fn test_extract_token_missing() {
        let req = request_with_authorization(None);
        assert_eq!(extract_token(&req), None);
    }

    /// A non-bearer authentication scheme yields no token.
    #[test]
    fn test_extract_token_invalid_format() {
        let req = request_with_authorization(Some("Basic user:pass"));
        assert_eq!(extract_token(&req), None);
    }
}