1use std::{sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Context};
4use axum::{
5 body::Body,
6 http::{HeaderMap, HeaderValue, Request, StatusCode},
7 middleware::Next,
8 response::Response,
9 Extension, Json,
10};
11use serde::de::DeserializeOwned;
12use serde_json::Value;
13
/// The authenticated caller of a request.
///
/// Inserted into the request extensions (wrapped in an `Arc`) by [`auth_request`]
/// after the auth server has validated the request's `authorization` header.
#[derive(Debug, Clone)]
pub struct RequestUser {
    /// User id, taken from the body of the auth server's `/api/auth_jwt` response.
    pub id: String,
}
17pub type RequestUserExtension = Extension<Arc<RequestUser>>;
18
19type ResponseResult<T> = Result<T, (StatusCode, String)>;
20
21pub async fn get_user_route(
22 auth: Extension<Arc<AuthClient>>,
23 headers: HeaderMap,
24) -> ResponseResult<Json<Value>> {
25 let token = headers.get("authorization").ok_or((
26 StatusCode::UNAUTHORIZED,
27 String::from("request does not contain authorization header"),
28 ))?;
29 let user = auth.get_user(token).await.map_err(handle_anyhow_error)?;
30 Ok(Json(user))
31}
32
33pub async fn exchange_retrieval_token_route(
34 auth: Extension<Arc<AuthClient>>,
35 headers: HeaderMap,
36) -> ResponseResult<String> {
37 let token = headers.get("authorization").ok_or((
38 StatusCode::UNAUTHORIZED,
39 String::from("request does not contain authorization header"),
40 ))?;
41 auth.exchange_retrieval_token(&token)
42 .await
43 .map_err(handle_anyhow_error)
44}
45
46pub async fn auth_request(
48 mut req: Request<Body>,
49 next: Next<Body>,
50) -> ResponseResult<Response> {
51 let auth = req.extensions().get::<Arc<AuthClient>>();
52 if auth.is_none() {
53 eprintln!("auth extension not attached correctly, exiting program");
54 std::process::exit(1)
55 }
56 let user_id = auth.unwrap().authenticate_req(&req).await.map_err(handle_anyhow_error)?;
57 let req_user = Arc::new(RequestUser {
58 id: user_id
59 });
60 drop(auth);
61 req.extensions_mut().insert(req_user);
62 Ok(next.run(req).await)
63}
64
65pub async fn auth_admin_request(
66 req: Request<Body>,
67 next: Next<Body>,
68) -> Result<Response, (StatusCode, String)> {
69 let auth = req.extensions().get::<Arc<AuthClient>>();
70 if auth.is_none() {
71 eprintln!("auth extension not attached correctly, exiting program");
72 std::process::exit(1)
73 }
74 let token = req.headers().get("authorization").ok_or((
75 StatusCode::UNAUTHORIZED,
76 String::from("request does not contain authorization header"),
77 ))?;
78 let user: Value = auth
79 .unwrap()
80 .get_user(token)
81 .await
82 .map_err(handle_anyhow_error)?;
83 let is_admin = user
84 .as_object()
85 .ok_or(anyhow!("user is not object"))
86 .map_err(handle_anyhow_error)?
87 .get("admin")
88 .ok_or((
89 StatusCode::UNAUTHORIZED,
90 String::from("user does not contain 'admin' field"),
91 ))?
92 .as_bool()
93 .unwrap_or(false);
94 if is_admin {
95 Ok(next.run(req).await)
96 } else {
97 Err((StatusCode::UNAUTHORIZED, String::from("user is not admin")))
98 }
99}
100
101pub struct AuthClient {
102 client: reqwest::Client,
103 auth_host: String,
104}
105
106impl AuthClient {
107 pub fn new(auth_host: impl Into<String>, timeout_secs: u64) -> AuthClient {
108 AuthClient {
109 client: reqwest::Client::builder()
110 .timeout(Duration::from_secs(timeout_secs))
111 .build()
112 .expect("failed to create request client for auth client"),
113 auth_host: auth_host.into(),
114 }
115 }
116
117 pub async fn authenticate_req(&self, req: &Request<Body>) -> anyhow::Result<String> {
118 let token = req
119 .headers()
120 .get("authorization")
121 .ok_or(anyhow!("request does not contain authorization header"))?;
122 let res = self
123 .client
124 .get(&format!("{}/api/auth_jwt", self.auth_host))
125 .header("authorization", token)
126 .send()
127 .await
128 .context("failed at request to auth jwt")?;
129 let status = res.status();
130 if status == StatusCode::OK {
131 let user_id = res.text().await.context(format!("failed to extract user id from auth jwt response body"))?;
132 Ok(user_id)
133 } else {
134 let text = res.text().await.context(format!("status: {status} | failed to authenticate request"))?;
135 Err(anyhow!("status: {status} | failed to authenticate request | {text}"))
136 }
137 }
138
139 pub async fn get_user<User: DeserializeOwned>(
140 &self,
141 token: &HeaderValue,
142 ) -> anyhow::Result<User> {
143 let res = self
144 .client
145 .get(&format!("{}/api/user", self.auth_host))
146 .header("authorization", token)
147 .send()
148 .await
149 .context("failed to get user from auth server")?;
150 let status = res.status();
151 if status == StatusCode::OK {
152 let user: User = res.json().await.context("failed to parse user type")?;
153 Ok(user)
154 } else {
155 let text = res.text().await.context(format!(
156 "status: {status} | failed to get user from auth server"
157 ))?;
158 Err(anyhow!(
159 "status: {status} | failed to get user from auth server | {text}"
160 ))
161 }
162 }
163
164 pub async fn exchange_retrieval_token(&self, token: &HeaderValue) -> anyhow::Result<String> {
165 let res = self
166 .client
167 .get(&format!("{}/api/exchange", self.auth_host))
168 .header("authorization", token)
169 .send()
170 .await
171 .context("failed to exchange retrieval token from auth server")?;
172 let status = res.status();
173 if status == StatusCode::OK {
174 let jwt = res
175 .text()
176 .await
177 .context("failed to extract body of exchange token response")?;
178 Ok(jwt)
179 } else {
180 let text = res.text().await.context(format!(
181 "status: {status} | failed to exchange retrieval token for jwt"
182 ))?;
183 Err(anyhow!(
184 "status: {status} | failed to exchange retrieval token for jwt | {text}"
185 ))
186 }
187 }
188}
189
190fn handle_anyhow_error(e: anyhow::Error) -> (StatusCode, String) {
191 (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}"))
192}