axum_util/
auth.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use axum::extract::FromRequestParts;
4use hmac::{Hmac, Mac};
5use http::request::Parts;
6use jwt::{FromBase64, SignWithKey, VerifyWithKey};
7use serde::{de::DeserializeOwned, Serialize};
8use sha2::Sha256;
9
10use crate::errors::{ApiError, ApiResult};
11
12pub struct AuthConfig<T: Serialize + DeserializeOwned + FromBase64> {
13    key: Hmac<Sha256>,
14    prefix: String,
15    _t: PhantomData<T>,
16}
17
18impl<T: Serialize + DeserializeOwned + FromBase64> AuthConfig<T> {
19    pub fn new(key: &[u8]) -> Self {
20        AuthConfig {
21            key: Hmac::new_from_slice(key).unwrap(),
22            prefix: "Token ".to_string(),
23            _t: PhantomData,
24        }
25    }
26
27    pub fn with_prefix(mut self, mut prefix: String) -> Self {
28        if !prefix.is_empty() {
29            prefix.push(' ');
30        }
31        self.prefix = prefix;
32        self
33    }
34
35    pub fn sign(&self, value: &T) -> ApiResult<String> {
36        Ok(value.sign_with_key(&self.key)?)
37    }
38
39    pub fn validate(&self, value: &str) -> ApiResult<T> {
40        let out = value
41            .verify_with_key(&self.key)
42            .map_err(|_| ApiError::Unauthorized("malformed auth token".to_string()))?;
43
44        Ok(out)
45    }
46}
47
48#[async_trait::async_trait]
49pub trait AuthParam<T: Serialize + DeserializeOwned + FromBase64> {
50    fn config() -> Arc<AuthConfig<T>>;
51
52    async fn authenticated(req: &mut Parts, arg: &T) -> ApiResult<()>;
53}
54
/// Axum extractor yielding the verified claims `T` from the `Authorization` header,
/// using the policy `P` (an `AuthParam<T>` implementation).
///
/// Bounds are declared on the `FromRequestParts` impl instead of the struct, so
/// naming `Auth<T, P>` elsewhere does not require repeating them.
pub struct Auth<T, P>(
    /// The decoded, signature-verified claims.
    pub T,
    /// Zero-sized marker tying this extractor to its auth policy `P`.
    pub PhantomData<P>,
);
59
60#[async_trait::async_trait]
61impl<
62        T: Serialize + DeserializeOwned + FromBase64 + Send + Sync,
63        P: AuthParam<T>,
64        S: Send + Sync,
65    > FromRequestParts<S> for Auth<T, P>
66{
67    type Rejection = ApiError;
68
69    async fn from_request_parts(req: &mut Parts, _state: &S) -> ApiResult<Self> {
70        let Some(auth) = req.headers.get("Authorization") else {
71            return Err(ApiError::Unauthorized("missing auth token".to_string()));
72        };
73        let config = P::config();
74        let auth = auth.to_str()?;
75        let Some(auth) = auth.strip_prefix(&config.prefix).map(|x| x.trim()) else {
76            return Err(ApiError::Unauthorized("malformed auth token".to_string()));
77        };
78
79        let out = P::config().validate(auth)?;
80        P::authenticated(req, &out).await?;
81        Ok(Self(out, PhantomData))
82    }
83}