use crate::types::{AuthProvider, HttpRequestParts, RequestParts};
use axum::{
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use micromegas_tracing::prelude::*;
use std::sync::Arc;
/// Axum middleware that authenticates an incoming request before passing it on.
///
/// Copies the request's URI, method and headers into an [`HttpRequestParts`] snapshot and asks
/// `auth_provider` to validate it. On success, the returned auth context is logged at `info`
/// level, inserted into the request extensions (where downstream handlers can extract it), and
/// the request is forwarded to `next`.
///
/// # Errors
///
/// Any failure from `auth_provider.validate_request` is logged at `warn` level and collapsed
/// into [`AuthError::InvalidToken`]; `next` is not called in that case.
pub async fn auth_middleware(
    auth_provider: Arc<dyn AuthProvider>,
    mut req: Request,
    next: Next,
) -> Result<Response, AuthError> {
    // The provider only receives owned copies of the request metadata; the body is not touched.
    let snapshot = HttpRequestParts {
        uri: req.uri().clone(),
        method: req.method().clone(),
        headers: req.headers().clone(),
    };

    let validation = auth_provider
        .validate_request(&snapshot as &dyn RequestParts)
        .await;

    let auth_ctx = match validation {
        Ok(ctx) => ctx,
        Err(e) => {
            // The underlying reason is only logged; the client sees a generic rejection.
            warn!("authentication failed: {e}");
            return Err(AuthError::InvalidToken);
        }
    };

    info!(
        "authenticated: subject={} email={:?} issuer={} admin={}",
        auth_ctx.subject, auth_ctx.email, auth_ctx.issuer, auth_ctx.is_admin
    );

    // Make the authenticated identity available to downstream handlers.
    req.extensions_mut().insert(auth_ctx);

    let response = next.run(req).await;
    Ok(response)
}
/// Error returned by [`auth_middleware`] when a request cannot be authenticated.
///
/// Rendered as an HTTP response through its [`IntoResponse`] implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The auth provider failed to validate the request. Every validation failure reported by
    /// the provider is mapped to this variant.
    InvalidToken,
}
impl IntoResponse for AuthError {
    /// Converts the authentication error into an HTTP response.
    ///
    /// Each variant is mapped to a status code and a short plain-text body. `InvalidToken`
    /// becomes `401 Unauthorized`. A 401 response must carry a `WWW-Authenticate` challenge
    /// (RFC 9110 §15.5.2), so a `Bearer` challenge (RFC 6750) is attached.
    fn into_response(self) -> Response {
        let (status, message) = match self {
            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
        };
        // NOTE(review): assumes clients authenticate with bearer tokens in the `Authorization`
        // header — confirm against the AuthProvider implementations. No `error=` parameter is
        // sent because this variant also covers requests that carried no credentials at all.
        (status, [("www-authenticate", "Bearer")], message).into_response()
    }
}