1use std::{marker::PhantomData, sync::Arc};
2
3use axum::extract::FromRequestParts;
4use hmac::{Hmac, Mac};
5use http::request::Parts;
6use jwt::{FromBase64, SignWithKey, VerifyWithKey};
7use serde::{de::DeserializeOwned, Serialize};
8use sha2::Sha256;
9
10use crate::errors::{ApiError, ApiResult};
11
12pub struct AuthConfig<T: Serialize + DeserializeOwned + FromBase64> {
13 key: Hmac<Sha256>,
14 prefix: String,
15 _t: PhantomData<T>,
16}
17
18impl<T: Serialize + DeserializeOwned + FromBase64> AuthConfig<T> {
19 pub fn new(key: &[u8]) -> Self {
20 AuthConfig {
21 key: Hmac::new_from_slice(key).unwrap(),
22 prefix: "Token ".to_string(),
23 _t: PhantomData,
24 }
25 }
26
27 pub fn with_prefix(mut self, mut prefix: String) -> Self {
28 if !prefix.is_empty() {
29 prefix.push(' ');
30 }
31 self.prefix = prefix;
32 self
33 }
34
35 pub fn sign(&self, value: &T) -> ApiResult<String> {
36 Ok(value.sign_with_key(&self.key)?)
37 }
38
39 pub fn validate(&self, value: &str) -> ApiResult<T> {
40 let out = value
41 .verify_with_key(&self.key)
42 .map_err(|_| ApiError::Unauthorized("malformed auth token".to_string()))?;
43
44 Ok(out)
45 }
46}
47
48#[async_trait::async_trait]
49pub trait AuthParam<T: Serialize + DeserializeOwned + FromBase64> {
50 fn config() -> Arc<AuthConfig<T>>;
51
52 async fn authenticated(req: &mut Parts, arg: &T) -> ApiResult<()>;
53}
54
55pub struct Auth<T: Serialize + DeserializeOwned + FromBase64, P: AuthParam<T>>(
56 pub T,
57 pub PhantomData<P>,
58);
59
60#[async_trait::async_trait]
61impl<
62 T: Serialize + DeserializeOwned + FromBase64 + Send + Sync,
63 P: AuthParam<T>,
64 S: Send + Sync,
65 > FromRequestParts<S> for Auth<T, P>
66{
67 type Rejection = ApiError;
68
69 async fn from_request_parts(req: &mut Parts, _state: &S) -> ApiResult<Self> {
70 let Some(auth) = req.headers.get("Authorization") else {
71 return Err(ApiError::Unauthorized("missing auth token".to_string()));
72 };
73 let config = P::config();
74 let auth = auth.to_str()?;
75 let Some(auth) = auth.strip_prefix(&config.prefix).map(|x| x.trim()) else {
76 return Err(ApiError::Unauthorized("malformed auth token".to_string()));
77 };
78
79 let out = P::config().validate(auth)?;
80 P::authenticated(req, &out).await?;
81 Ok(Self(out, PhantomData))
82 }
83}