1use std::sync::Arc;
7
8use axum::{body::HttpBody, extract::Request};
9use futures_core::future::BoxFuture;
10use http::{Response, StatusCode, header::AUTHORIZATION};
11use pib_service_api_auth::ApiAuth;
12use tower_http::auth::AsyncAuthorizeRequest;
13
14#[derive(Debug, Clone)]
15pub struct Auth(Arc<dyn ApiAuth>);
16
17impl Auth {
18 pub fn new(auth: Arc<dyn ApiAuth>) -> Self {
19 Self(auth)
20 }
21}
22
23impl<B: HttpBody + Send + 'static> AsyncAuthorizeRequest<B> for Auth {
24 type RequestBody = B;
25 type ResponseBody = axum::body::Body;
26 type Future =
27 BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
28
29 fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
30 let unauthorized_response = Response::builder()
31 .status(StatusCode::UNAUTHORIZED)
32 .body(axum::body::Body::empty())
33 .unwrap();
34
35 let Some(header) = request.headers().get(AUTHORIZATION).cloned() else {
36 return Box::pin(async move { Err(unauthorized_response) });
37 };
38
39 let auth = self.0.clone();
40
41 Box::pin(async move {
42 match auth.authorize(header).await {
43 Ok(Some(user_info)) => {
44 request.extensions_mut().insert(user_info);
45
46 Ok(request)
47 }
48 Ok(None) => Err(unauthorized_response),
49 Err(e) => Err(Response::builder()
50 .status(StatusCode::INTERNAL_SERVER_ERROR)
51 .body(format!("Error while checking authentication: {e}").into())
52 .unwrap()),
53 }
54 })
55 }
56}