use std::sync::Arc;
use axum::{body::HttpBody, extract::Request};
use futures_core::future::BoxFuture;
use http::{Response, StatusCode, header::AUTHORIZATION};
use pib_service_api_auth::ApiAuth;
use tower_http::auth::AsyncAuthorizeRequest;
/// Request authorizer that delegates credential checks to an [`ApiAuth`] backend.
///
/// Used as the state of tower-http's `AsyncAuthorizeRequest` middleware. The only
/// field is an [`Arc`], so cloning an `Auth` just bumps a reference count.
#[derive(Clone, Debug)]
pub struct Auth(Arc<dyn ApiAuth>);

impl Auth {
    /// Wraps the given authentication backend.
    pub fn new(auth: Arc<dyn ApiAuth>) -> Self {
        Auth(auth)
    }
}
impl<B: HttpBody + Send + 'static> AsyncAuthorizeRequest<B> for Auth {
type RequestBody = B;
type ResponseBody = axum::body::Body;
type Future =
BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
let Some(header) = request.headers().get(AUTHORIZATION).cloned() else {
return Box::pin(async move { Err(unauthorized_response) });
};
let auth = self.0.clone();
Box::pin(async move {
match auth.authorize(header).await {
Ok(Some(user_info)) => {
request.extensions_mut().insert(user_info);
Ok(request)
}
Ok(None) => Err(unauthorized_response),
Err(e) => Err(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(format!("Error while checking authentication: {e}").into())
.unwrap()),
}
})
}
}