use async_trait::async_trait;
use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, HeaderMap},
};
use serde::{Deserialize, Serialize};
use crate::{app::AppContext, auth, errors::Error, model::Authenticable};
/// Authentication-scheme prefix (including its trailing space) that precedes the
/// credential in the authorization header value.
const TOKEN_PREFIX: &str = "Bearer ";
/// Name of the HTTP header the credential is read from.
const AUTH_HEADER: &str = "authorization";
/// Request extractor that validates a bearer JWT from the authorization header
/// and loads the user it refers to.
///
/// The user is resolved through `T::find_by_claims_key` using the token's `pid`
/// claim.
#[derive(Debug, Deserialize, Serialize)]
pub struct JWTWithUser<T: Authenticable> {
    /// Claims decoded from the validated token.
    pub claims: auth::jwt::UserClaims,
    /// User looked up from the token's `pid` claim.
    pub user: T,
}
#[async_trait]
impl<S, T> FromRequestParts<S> for JWTWithUser<T>
where
AppContext: FromRef<S>,
S: Send + Sync,
T: Authenticable,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let token = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let state: AppContext = AppContext::from_ref(state);
let jwt_secret = state.config.get_jwt_config()?;
match auth::jwt::JWT::new(&jwt_secret.secret).validate(&token) {
Ok(claims) => {
let user = T::find_by_claims_key(&state.db, &claims.claims.pid)
.await
.map_err(|_| Error::Unauthorized("token is not valid".to_string()))?;
Ok(Self {
claims: claims.claims,
user,
})
}
Err(_err) => {
return Err(Error::Unauthorized("token is not valid".to_string()));
}
}
}
}
/// Request extractor that validates a bearer JWT from the authorization header
/// and exposes its claims, without loading any user.
#[derive(Debug, Deserialize, Serialize)]
pub struct JWT {
    /// Claims decoded from the validated token.
    pub claims: auth::jwt::UserClaims,
}
#[async_trait]
impl<S> FromRequestParts<S> for JWT
where
AppContext: FromRef<S>,
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let token = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let state: AppContext = AppContext::from_ref(state);
let jwt_secret = state.config.get_jwt_config()?;
match auth::jwt::JWT::new(&jwt_secret.secret).validate(&token) {
Ok(claims) => Ok(Self {
claims: claims.claims,
}),
Err(_err) => {
return Err(Error::Unauthorized("token is not valid".to_string()));
}
}
}
}
pub fn extract_token_from_header(headers: &HeaderMap) -> eyre::Result<String> {
Ok(headers
.get(AUTH_HEADER)
.ok_or_else(|| eyre::eyre!("header {} token not found", AUTH_HEADER))?
.to_str()?
.strip_prefix(TOKEN_PREFIX)
.ok_or_else(|| eyre::eyre!("error strip {} value", AUTH_HEADER))?
.to_string())
}
/// Request extractor that authenticates a request by an API key sent in the
/// authorization header, and loads the user that owns it.
///
/// The user is resolved through `T::find_by_api_key`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ApiToken<T: Authenticable> {
    /// User that owns the presented API key.
    pub user: T,
}
#[async_trait]
impl<S, T> FromRequestParts<S> for ApiToken<T>
where
AppContext: FromRef<S>,
S: Send + Sync,
T: Authenticable,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let api_key = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let state: AppContext = AppContext::from_ref(state);
let user = T::find_by_api_key(&state.db, &api_key)
.await
.map_err(|e| Error::Unauthorized(e.to_string()))?;
Ok(Self { user })
}
}