Skip to main content

auth_client_axum/
lib.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Context};
4use axum::{
5    body::Body,
6    http::{HeaderMap, HeaderValue, Request, StatusCode},
7    middleware::Next,
8    response::Response,
9    Extension, Json,
10};
11use serde::de::DeserializeOwned;
12use serde_json::Value;
13
/// Identity of the authenticated caller.
///
/// [`auth_request`] inserts it into the request extensions as an `Arc<RequestUser>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestUser {
    /// User id returned in the body of the auth server's `/api/auth_jwt` response.
    pub id: String,
}
17pub type RequestUserExtension = Extension<Arc<RequestUser>>;
18
19type ResponseResult<T> = Result<T, (StatusCode, String)>;
20
21pub async fn get_user_route(
22    auth: Extension<Arc<AuthClient>>,
23    headers: HeaderMap,
24) -> ResponseResult<Json<Value>> {
25    let token = headers.get("authorization").ok_or((
26        StatusCode::UNAUTHORIZED,
27        String::from("request does not contain authorization header"),
28    ))?;
29    let user = auth.get_user(token).await.map_err(handle_anyhow_error)?;
30    Ok(Json(user))
31}
32
33pub async fn exchange_retrieval_token_route(
34    auth: Extension<Arc<AuthClient>>,
35    headers: HeaderMap,
36) -> ResponseResult<String> {
37    let token = headers.get("authorization").ok_or((
38        StatusCode::UNAUTHORIZED,
39        String::from("request does not contain authorization header"),
40    ))?;
41    auth.exchange_retrieval_token(&token)
42        .await
43        .map_err(handle_anyhow_error)
44}
45
46// this will attach a RequestUser to the request, it will contain the user_id
47pub async fn auth_request(
48    mut req: Request<Body>,
49    next: Next<Body>,
50) -> ResponseResult<Response> {
51    let auth = req.extensions().get::<Arc<AuthClient>>();
52    if auth.is_none() {
53        eprintln!("auth extension not attached correctly, exiting program");
54        std::process::exit(1)
55    }
56    let user_id = auth.unwrap().authenticate_req(&req).await.map_err(handle_anyhow_error)?;
57    let req_user = Arc::new(RequestUser {
58        id: user_id
59    });
60    drop(auth);
61    req.extensions_mut().insert(req_user);
62    Ok(next.run(req).await)
63}
64
65pub async fn auth_admin_request(
66    req: Request<Body>,
67    next: Next<Body>,
68) -> Result<Response, (StatusCode, String)> {
69    let auth = req.extensions().get::<Arc<AuthClient>>();
70    if auth.is_none() {
71        eprintln!("auth extension not attached correctly, exiting program");
72        std::process::exit(1)
73    }
74    let token = req.headers().get("authorization").ok_or((
75        StatusCode::UNAUTHORIZED,
76        String::from("request does not contain authorization header"),
77    ))?;
78    let user: Value = auth
79        .unwrap()
80        .get_user(token)
81        .await
82        .map_err(handle_anyhow_error)?;
83    let is_admin = user
84        .as_object()
85        .ok_or(anyhow!("user is not object"))
86        .map_err(handle_anyhow_error)?
87        .get("admin")
88        .ok_or((
89            StatusCode::UNAUTHORIZED,
90            String::from("user does not contain 'admin' field"),
91        ))?
92        .as_bool()
93        .unwrap_or(false);
94    if is_admin {
95        Ok(next.run(req).await)
96    } else {
97        Err((StatusCode::UNAUTHORIZED, String::from("user is not admin")))
98    }
99}
100
101pub struct AuthClient {
102    client: reqwest::Client,
103    auth_host: String,
104}
105
106impl AuthClient {
107    pub fn new(auth_host: impl Into<String>, timeout_secs: u64) -> AuthClient {
108        AuthClient {
109            client: reqwest::Client::builder()
110                .timeout(Duration::from_secs(timeout_secs))
111                .build()
112                .expect("failed to create request client for auth client"),
113            auth_host: auth_host.into(),
114        }
115    }
116
117    pub async fn authenticate_req(&self, req: &Request<Body>) -> anyhow::Result<String> {
118        let token = req
119            .headers()
120            .get("authorization")
121            .ok_or(anyhow!("request does not contain authorization header"))?;
122        let res = self
123            .client
124            .get(&format!("{}/api/auth_jwt", self.auth_host))
125            .header("authorization", token)
126            .send()
127            .await
128            .context("failed at request to auth jwt")?;
129        let status = res.status();
130        if status == StatusCode::OK {
131            let user_id = res.text().await.context(format!("failed to extract user id from auth jwt response body"))?;
132            Ok(user_id)
133        } else {
134            let text = res.text().await.context(format!("status: {status} | failed to authenticate request"))?;
135            Err(anyhow!("status: {status} | failed to authenticate request | {text}"))
136        }
137    }
138
139    pub async fn get_user<User: DeserializeOwned>(
140        &self,
141        token: &HeaderValue,
142    ) -> anyhow::Result<User> {
143        let res = self
144            .client
145            .get(&format!("{}/api/user", self.auth_host))
146            .header("authorization", token)
147            .send()
148            .await
149            .context("failed to get user from auth server")?;
150        let status = res.status();
151        if status == StatusCode::OK {
152            let user: User = res.json().await.context("failed to parse user type")?;
153            Ok(user)
154        } else {
155            let text = res.text().await.context(format!(
156                "status: {status} | failed to get user from auth server"
157            ))?;
158            Err(anyhow!(
159                "status: {status} | failed to get user from auth server | {text}"
160            ))
161        }
162    }
163
164    pub async fn exchange_retrieval_token(&self, token: &HeaderValue) -> anyhow::Result<String> {
165        let res = self
166            .client
167            .get(&format!("{}/api/exchange", self.auth_host))
168            .header("authorization", token)
169            .send()
170            .await
171            .context("failed to exchange retrieval token from auth server")?;
172        let status = res.status();
173        if status == StatusCode::OK {
174            let jwt = res
175                .text()
176                .await
177                .context("failed to extract body of exchange token response")?;
178            Ok(jwt)
179        } else {
180            let text = res.text().await.context(format!(
181                "status: {status} | failed to exchange retrieval token for jwt"
182            ))?;
183            Err(anyhow!(
184                "status: {status} | failed to exchange retrieval token for jwt | {text}"
185            ))
186        }
187    }
188}
189
190fn handle_anyhow_error(e: anyhow::Error) -> (StatusCode, String) {
191    (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}"))
192}