auth_client_axum 0.1.6

A client for integrating axum servers with our auth service.
Documentation
use std::{sync::Arc, time::Duration};

use anyhow::{anyhow, Context};
use axum::{
    body::Body,
    http::{HeaderMap, HeaderValue, Request, StatusCode},
    middleware::Next,
    response::Response,
    Extension, Json,
};
use serde::de::DeserializeOwned;
use serde_json::Value;

/// The authenticated user that `auth_request` attaches to a request.
///
/// Derives the common traits so downstream handlers can log, compare and
/// clone the value without wrapping it themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestUser {
    /// User id, taken from the body of the auth service's response to the
    /// request's `authorization` token.
    pub id: String,
}
/// Extension type wrapping the `Arc<RequestUser>` that `auth_request` inserts
/// into the request extensions; intended for handlers running behind that
/// middleware.
pub type RequestUserExtension = Extension<Arc<RequestUser>>;

/// Result type used by the routes and middleware in this file: on failure it
/// carries a status code paired with a plain-text message.
type ResponseResult<T> = Result<T, (StatusCode, String)>;

pub async fn get_user_route(
    auth: Extension<Arc<AuthClient>>,
    headers: HeaderMap,
) -> ResponseResult<Json<Value>> {
    let token = headers.get("authorization").ok_or((
        StatusCode::UNAUTHORIZED,
        String::from("request does not contain authorization header"),
    ))?;
    let user = auth.get_user(token).await.map_err(handle_anyhow_error)?;
    Ok(Json(user))
}

pub async fn exchange_retrieval_token_route(
    auth: Extension<Arc<AuthClient>>,
    headers: HeaderMap,
) -> ResponseResult<String> {
    let token = headers.get("authorization").ok_or((
        StatusCode::UNAUTHORIZED,
        String::from("request does not contain authorization header"),
    ))?;
    auth.exchange_retrieval_token(&token)
        .await
        .map_err(handle_anyhow_error)
}

// this will attach a RequestUser to the request, it will contain the user_id
pub async fn auth_request(
    mut req: Request<Body>,
    next: Next<Body>,
) -> ResponseResult<Response> {
    let auth = req.extensions().get::<Arc<AuthClient>>();
    if auth.is_none() {
        eprintln!("auth extension not attached correctly, exiting program");
        std::process::exit(1)
    }
    let user_id = auth.unwrap().authenticate_req(&req).await.map_err(handle_anyhow_error)?;
    let req_user = Arc::new(RequestUser {
        id: user_id
    });
    drop(auth);
    req.extensions_mut().insert(req_user);
    Ok(next.run(req).await)
}

pub async fn auth_admin_request(
    req: Request<Body>,
    next: Next<Body>,
) -> Result<Response, (StatusCode, String)> {
    let auth = req.extensions().get::<Arc<AuthClient>>();
    if auth.is_none() {
        eprintln!("auth extension not attached correctly, exiting program");
        std::process::exit(1)
    }
    let token = req.headers().get("authorization").ok_or((
        StatusCode::UNAUTHORIZED,
        String::from("request does not contain authorization header"),
    ))?;
    let user: Value = auth
        .unwrap()
        .get_user(token)
        .await
        .map_err(handle_anyhow_error)?;
    let is_admin = user
        .as_object()
        .ok_or(anyhow!("user is not object"))
        .map_err(handle_anyhow_error)?
        .get("admin")
        .ok_or((
            StatusCode::UNAUTHORIZED,
            String::from("user does not contain 'admin' field"),
        ))?
        .as_bool()
        .unwrap_or(false);
    if is_admin {
        Ok(next.run(req).await)
    } else {
        Err((StatusCode::UNAUTHORIZED, String::from("user is not admin")))
    }
}

/// HTTP client for the auth service.
pub struct AuthClient {
    /// Underlying reqwest client; `AuthClient::new` builds it with a request timeout.
    client: reqwest::Client,
    /// URL prefix of the auth service; endpoint paths such as `/api/user` are
    /// appended to it directly.
    auth_host: String,
}

impl AuthClient {
    /// Creates a client for the auth service reachable at `auth_host`.
    ///
    /// `auth_host` is used as a URL prefix: endpoint paths (for example
    /// `/api/user`) are appended to it verbatim, so a trailing slash would end
    /// up as a double slash in the request URL. `timeout_secs` is the timeout,
    /// in seconds, configured on the underlying HTTP client.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest::Client` cannot be built.
    pub fn new(auth_host: impl Into<String>, timeout_secs: u64) -> AuthClient {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .expect("failed to create request client for auth client");
        AuthClient {
            client,
            auth_host: auth_host.into(),
        }
    }

    /// Authenticates `req` against the auth service and returns the user id.
    ///
    /// The request's `authorization` header is forwarded to `/api/auth_jwt`;
    /// the body of a `200 OK` response is returned as the user id.
    ///
    /// # Errors
    ///
    /// Fails if the header is missing, the request cannot be sent, the auth
    /// service answers with any status other than `200 OK`, or the response
    /// body cannot be read.
    pub async fn authenticate_req(&self, req: &Request<Body>) -> anyhow::Result<String> {
        let token = req
            .headers()
            .get("authorization")
            .ok_or_else(|| anyhow!("request does not contain authorization header"))?;
        let res = self
            .get_authorized(
                "/api/auth_jwt",
                token,
                "failed at request to auth jwt",
                "failed to authenticate request",
            )
            .await?;
        res.text()
            .await
            .context("failed to extract user id from auth jwt response body")
    }

    /// Fetches the user belonging to `token` from `/api/user` and deserializes
    /// the JSON response body into `User`.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent, the auth service answers with any
    /// status other than `200 OK`, or the body does not deserialize into `User`.
    pub async fn get_user<User: DeserializeOwned>(
        &self,
        token: &HeaderValue,
    ) -> anyhow::Result<User> {
        let res = self
            .get_authorized(
                "/api/user",
                token,
                "failed to get user from auth server",
                "failed to get user from auth server",
            )
            .await?;
        res.json::<User>()
            .await
            .context("failed to parse user type")
    }

    /// Exchanges the retrieval token in `token` via `/api/exchange` and returns
    /// the response body (the jwt).
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent, the auth service answers with any
    /// status other than `200 OK`, or the response body cannot be read.
    pub async fn exchange_retrieval_token(&self, token: &HeaderValue) -> anyhow::Result<String> {
        let res = self
            .get_authorized(
                "/api/exchange",
                token,
                "failed to exchange retrieval token from auth server",
                "failed to exchange retrieval token for jwt",
            )
            .await?;
        res.text()
            .await
            .context("failed to extract body of exchange token response")
    }

    /// Sends `GET {auth_host}{path}` with `token` as the `authorization` header
    /// and returns the response only if its status is `200 OK`.
    ///
    /// `send_context` is attached when the request itself fails. For any other
    /// status, the error reads `status: <status> | <failure> | <body>`, or
    /// carries the context `status: <status> | <failure>` when the body cannot
    /// be read.
    async fn get_authorized(
        &self,
        path: &str,
        token: &HeaderValue,
        send_context: &'static str,
        failure: &str,
    ) -> anyhow::Result<reqwest::Response> {
        let url = format!("{}{}", self.auth_host, path);
        let res = self
            .client
            .get(&url)
            .header("authorization", token)
            .send()
            .await
            .context(send_context)?;
        let status = res.status();
        if status == StatusCode::OK {
            return Ok(res);
        }
        // Build the context lazily: it is only needed if reading the body fails.
        let text = res
            .text()
            .await
            .with_context(|| format!("status: {status} | {failure}"))?;
        Err(anyhow!("status: {status} | {failure} | {text}"))
    }
}

/// Converts an internal error into a `500 Internal Server Error` response
/// whose body is the error's `Debug` rendering.
// NOTE(review): the `Debug` rendering of an `anyhow::Error` includes its full
// context chain (and may include a backtrace, depending on how backtraces are
// configured) — confirm that exposing this to clients is intended.
fn handle_anyhow_error(e: anyhow::Error) -> (StatusCode, String) {
    let body = format!("{:?}", e);
    (StatusCode::INTERNAL_SERVER_ERROR, body)
}