use std::sync::Arc;
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use crate::api::errors::{api_error, ApiErrorCode};
use crate::auth::{AuthOutcome, AuthProvider};
use crate::server::AppState;
/// Extracts the credential from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The scheme name is compared case-insensitively, because HTTP authentication
/// scheme names are case-insensitive. Whitespace surrounding the credential is
/// trimmed. Returns `None` when the scheme is not `Bearer` or when the
/// credential is empty.
fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, credentials) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}

/// Middleware that only lets a request through when it carries a Bearer token
/// that the configured auth provider accepts.
///
/// On success the authenticated principal is inserted into the request
/// extensions and the request is passed on to `next`.
///
/// # Errors
///
/// * `401 Unauthorized` when the `Authorization` header is absent, is not
///   valid text, does not use the `Bearer` scheme, or carries an empty token.
/// * `401 Unauthorized` when the provider reports the token as unknown,
///   revoked, or expired.
/// * `500 Internal Server Error` when the call to the auth provider itself
///   fails.
pub async fn require_authenticated_principal(
    State(state): State<Arc<AppState>>,
    mut req: Request,
    next: Next,
) -> Result<Response, (StatusCode, axum::Json<serde_json::Value>)> {
    // Every rejection of the caller's credentials shares the same status and
    // error code; only the human-readable reason differs.
    let unauthenticated = |reason: &'static str| {
        api_error(
            StatusCode::UNAUTHORIZED,
            ApiErrorCode::Unauthenticated,
            reason,
        )
    };

    // The token is copied out of the header map so that no borrow of `req` is
    // kept alive while awaiting the provider or mutating the extensions below.
    let token = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(bearer_token)
        .map(str::to_owned)
        .ok_or_else(|| unauthenticated("missing Bearer token"))?;

    let outcome = state
        .auth_provider
        .authenticate(&token)
        .await
        .map_err(|e| {
            // NOTE(review): the backend error text is returned to the client
            // verbatim; confirm it cannot contain sensitive details.
            api_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                ApiErrorCode::InternalError,
                format!("auth backend error: {e}"),
            )
        })?;

    // Kept exhaustive on purpose: a new `AuthOutcome` variant should fail to
    // compile here rather than be silently accepted or rejected.
    let principal = match outcome {
        AuthOutcome::Authenticated(principal) => principal,
        AuthOutcome::Unauthenticated => {
            return Err(unauthenticated("invalid or unknown token"));
        }
        AuthOutcome::Revoked { .. } => return Err(unauthenticated("token revoked")),
        AuthOutcome::Expired { .. } => return Err(unauthenticated("token expired")),
    };

    // Make the principal available to downstream handlers via extensions.
    req.extensions_mut().insert(principal);
    Ok(next.run(req).await)
}