use async_trait::async_trait;
use axum::{
extract::{FromRef, FromRequestParts},
http::{request::Parts, HeaderMap},
};
use serde::{Deserialize, Serialize};
use crate::{app::AppContext, auth, errors::Error};
/// Scheme prefix (including the trailing space) expected in front of the token in the header value.
const TOKEN_PREFIX: &str = "Bearer ";
/// Name of the request header the token is read from.
const AUTH_HEADER: &str = "authorization";
/// Request extractor holding the user claims of a JWT.
#[derive(Debug, Deserialize, Serialize)]
pub struct JWT {
/// Claims stored on the extractor, typed as `auth::UserClaims`.
pub claims: auth::UserClaims,
}
#[async_trait]
impl<S> FromRequestParts<S> for JWT
where
AppContext: FromRef<S>,
S: Send + Sync,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let token = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let state: AppContext = AppContext::from_ref(state);
let jwt_secret = state.config.get_jwt_config()?;
match auth::JWT::new(&jwt_secret.secret).validate(&token) {
Ok(claims) => Ok(Self {
claims: claims.claims,
}),
Err(_err) => {
return Err(Error::Unauthorized(
"[Auth] token is not valid.".to_string(),
));
}
}
}
}
pub fn extract_token_from_header(headers: &HeaderMap) -> eyre::Result<String> {
Ok(headers
.get(AUTH_HEADER)
.ok_or_else(|| eyre::eyre!("header {} token not found", AUTH_HEADER))?
.to_str()?
.strip_prefix(TOKEN_PREFIX)
.ok_or_else(|| eyre::eyre!("error strip {} value", AUTH_HEADER))?
.to_string())
}