use std::{sync::Arc, time::Duration};
use anyhow::{anyhow, Context};
use axum::{
body::Body,
http::{HeaderMap, HeaderValue, Request, StatusCode},
middleware::Next,
response::Response,
Extension, Json,
};
use serde::de::DeserializeOwned;
use serde_json::Value;
/// Identity of the caller of the current request.
///
/// `auth_request` wraps a value of this type in an `Arc` and stores it in the
/// request extensions after the auth server has accepted the request.
#[derive(Debug, Clone)]
pub struct RequestUser {
    /// User id: the response body returned by the auth server for the request's token.
    pub id: String,
}
/// Handler-side `Extension` type matching the `Arc<RequestUser>` value that
/// `auth_request` stores in the request extensions.
pub type RequestUserExtension = Extension<Arc<RequestUser>>;
/// Result type used by the handlers in this file: on failure, an HTTP status
/// code paired with a message string.
type ResponseResult<T> = Result<T, (StatusCode, String)>;
pub async fn get_user_route(
auth: Extension<Arc<AuthClient>>,
headers: HeaderMap,
) -> ResponseResult<Json<Value>> {
let token = headers.get("authorization").ok_or((
StatusCode::UNAUTHORIZED,
String::from("request does not contain authorization header"),
))?;
let user = auth.get_user(token).await.map_err(handle_anyhow_error)?;
Ok(Json(user))
}
/// Route handler that exchanges the token in the `authorization` header for
/// the string returned by `AuthClient::exchange_retrieval_token`.
///
/// # Errors
/// * `401 UNAUTHORIZED` when the request has no `authorization` header.
/// * Whatever `handle_anyhow_error` produces when the exchange fails.
pub async fn exchange_retrieval_token_route(
    auth: Extension<Arc<AuthClient>>,
    headers: HeaderMap,
) -> ResponseResult<String> {
    let missing_header = || {
        (
            StatusCode::UNAUTHORIZED,
            String::from("request does not contain authorization header"),
        )
    };
    let token = headers.get("authorization").ok_or_else(missing_header)?;
    let exchanged = auth.exchange_retrieval_token(token).await;
    exchanged.map_err(handle_anyhow_error)
}
pub async fn auth_request(
mut req: Request<Body>,
next: Next<Body>,
) -> ResponseResult<Response> {
let auth = req.extensions().get::<Arc<AuthClient>>();
if auth.is_none() {
eprintln!("auth extension not attached correctly, exiting program");
std::process::exit(1)
}
let user_id = auth.unwrap().authenticate_req(&req).await.map_err(handle_anyhow_error)?;
let req_user = Arc::new(RequestUser {
id: user_id
});
drop(auth);
req.extensions_mut().insert(req_user);
Ok(next.run(req).await)
}
pub async fn auth_admin_request(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response, (StatusCode, String)> {
let auth = req.extensions().get::<Arc<AuthClient>>();
if auth.is_none() {
eprintln!("auth extension not attached correctly, exiting program");
std::process::exit(1)
}
let token = req.headers().get("authorization").ok_or((
StatusCode::UNAUTHORIZED,
String::from("request does not contain authorization header"),
))?;
let user: Value = auth
.unwrap()
.get_user(token)
.await
.map_err(handle_anyhow_error)?;
let is_admin = user
.as_object()
.ok_or(anyhow!("user is not object"))
.map_err(handle_anyhow_error)?
.get("admin")
.ok_or((
StatusCode::UNAUTHORIZED,
String::from("user does not contain 'admin' field"),
))?
.as_bool()
.unwrap_or(false);
if is_admin {
Ok(next.run(req).await)
} else {
Err((StatusCode::UNAUTHORIZED, String::from("user is not admin")))
}
}
/// HTTP client for the auth server used by the handlers and middleware in
/// this file.
pub struct AuthClient {
/// Underlying reqwest client; `new` builds it with a request timeout.
client: reqwest::Client,
/// Base URL of the auth server; endpoint paths such as `/api/user` are
/// appended to it verbatim.
auth_host: String,
}
impl AuthClient {
    /// Builds a client for the auth server whose base URL is `auth_host`.
    ///
    /// `timeout_secs` is applied as the timeout of the underlying reqwest
    /// client.
    ///
    /// # Panics
    /// Panics if the reqwest client cannot be built.
    pub fn new(auth_host: impl Into<String>, timeout_secs: u64) -> AuthClient {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .expect("failed to create request client for auth client");
        AuthClient {
            client,
            auth_host: auth_host.into(),
        }
    }

    /// Authenticates `req` by forwarding its `authorization` header to
    /// `{auth_host}/api/auth_jwt` and returns the response body, which callers
    /// in this file use as the user id.
    ///
    /// # Errors
    /// Fails if the header is missing, the request cannot be sent, the auth
    /// server answers with a status other than `200 OK`, or the body cannot be
    /// read.
    pub async fn authenticate_req(&self, req: &Request<Body>) -> anyhow::Result<String> {
        let token = req
            .headers()
            .get("authorization")
            .ok_or_else(|| anyhow!("request does not contain authorization header"))?;
        let res = self
            .get_ok(
                "/api/auth_jwt",
                token,
                "failed at request to auth jwt",
                "failed to authenticate request",
            )
            .await?;
        res.text()
            .await
            .context("failed to extract user id from auth jwt response body")
    }

    /// Fetches `{auth_host}/api/user` with `token` as the `authorization`
    /// header and deserializes the JSON response body into `User`.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the auth server answers with a
    /// status other than `200 OK`, or the body does not parse as `User`.
    pub async fn get_user<User: DeserializeOwned>(
        &self,
        token: &HeaderValue,
    ) -> anyhow::Result<User> {
        let res = self
            .get_ok(
                "/api/user",
                token,
                "failed to get user from auth server",
                "failed to get user from auth server",
            )
            .await?;
        res.json::<User>()
            .await
            .context("failed to parse user type")
    }

    /// Calls `{auth_host}/api/exchange` with `token` as the `authorization`
    /// header and returns the response body.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, the auth server answers with a
    /// status other than `200 OK`, or the body cannot be read.
    pub async fn exchange_retrieval_token(&self, token: &HeaderValue) -> anyhow::Result<String> {
        let res = self
            .get_ok(
                "/api/exchange",
                token,
                "failed to exchange retrieval token from auth server",
                "failed to exchange retrieval token for jwt",
            )
            .await?;
        res.text()
            .await
            .context("failed to extract body of exchange token response")
    }

    /// Sends `GET {auth_host}{path}` with `token` as the `authorization`
    /// header and returns the response only when its status is `200 OK`.
    ///
    /// `send_context` is attached when the request itself fails. `failure` is
    /// the message used when the server answers with any other status; the
    /// status and the response body are included in that error.
    async fn get_ok(
        &self,
        path: &str,
        token: &HeaderValue,
        send_context: &'static str,
        failure: &'static str,
    ) -> anyhow::Result<reqwest::Response> {
        let url = format!("{}{}", self.auth_host, path);
        let res = self
            .client
            .get(&url)
            .header("authorization", token)
            .send()
            .await
            .context(send_context)?;
        let status = res.status();
        if status == StatusCode::OK {
            return Ok(res);
        }
        // The context string is only formatted if reading the body fails.
        let text = res
            .text()
            .await
            .with_context(|| format!("status: {status} | {failure}"))?;
        Err(anyhow!("status: {status} | {failure} | {text}"))
    }
}
/// Maps an `anyhow::Error` to a `500 INTERNAL_SERVER_ERROR` response tuple
/// whose message is the error's `Debug` rendering.
///
/// NOTE(review): the `Debug` rendering is sent to the client as-is; confirm
/// that exposing this much internal detail is intended.
fn handle_anyhow_error(e: anyhow::Error) -> (StatusCode, String) {
    let body = format!("{e:?}");
    (StatusCode::INTERNAL_SERVER_ERROR, body)
}