use std::sync::OnceLock;
use crate::jwt::{AuthError, HasJti, JwtManager};
use chopin_core::extract::FromRequest;
use chopin_core::http::{Context, Response};
use serde::Deserialize;
/// Turns an [`AuthError`] into the rejection [`Response`] returned by the [`Auth`]
/// extractor.
///
/// A custom implementation can be installed process-wide with [`set_error_handler`].
/// The `Send + Sync` requirement is what allows it to be stored in a `static`.
pub trait ErrorHandler
where
    Self: Send + Sync,
{
    /// Builds the response to return for `err`.
    fn handle(&self, err: AuthError) -> Response;
}
/// Fallback [`ErrorHandler`] used when no custom handler has been installed.
///
/// Token problems (`Expired`, `Revoked`, `InvalidToken`) map to a `401` response.
/// Every other [`AuthError`] maps to `Response::server_error()`.
struct DefaultErrorHandler;

impl ErrorHandler for DefaultErrorHandler {
    fn handle(&self, err: AuthError) -> Response {
        let is_token_problem = matches!(
            err,
            AuthError::Expired | AuthError::Revoked | AuthError::InvalidToken(_)
        );
        if is_token_problem {
            Response::new(401)
        } else {
            Response::server_error()
        }
    }
}
/// Process-wide custom error handler. It is installed at most once via
/// [`set_error_handler`]; while it is unset, the built-in default handler is used.
static GLOBAL_ERROR_HANDLER: OnceLock<Box<dyn ErrorHandler>> = OnceLock::new();

/// Installs `handler` as the process-wide [`ErrorHandler`] consulted when
/// [`Auth`] extraction fails to decode a token.
///
/// # Panics
///
/// Panics if a handler has already been installed.
pub fn set_error_handler(handler: impl ErrorHandler + 'static) {
    let boxed: Box<dyn ErrorHandler> = Box::new(handler);
    // Collapse to a bool so a rejected handler is dropped before we panic.
    let already_set = GLOBAL_ERROR_HANDLER.set(boxed).is_err();
    assert!(
        !already_set,
        "ErrorHandler already set — call set_error_handler only once"
    );
}
#[inline]
fn dispatch_error(err: AuthError) -> Response {
match GLOBAL_ERROR_HANDLER.get() {
Some(h) => h.handle(err),
None => DefaultErrorHandler.handle(err),
}
}
/// Process-wide [`JwtManager`] that the [`Auth`] extractor uses to decode tokens.
///
/// It is normally set through [`init_jwt_manager`]. While it is unset, extraction
/// fails with `Response::server_error()`.
pub static GLOBAL_JWT_MANAGER: OnceLock<JwtManager> = OnceLock::new();

/// Installs the process-wide [`JwtManager`].
///
/// # Panics
///
/// Panics if a manager has already been installed.
pub fn init_jwt_manager(manager: JwtManager) {
    // Collapse to a bool so a rejected manager is dropped before we panic.
    let already_set = GLOBAL_JWT_MANAGER.set(manager).is_err();
    assert!(
        !already_set,
        "JwtManager already initialised — call init_jwt_manager only once"
    );
}
/// Request extractor that carries the decoded JWT claims of type `T`.
///
/// It is built by its [`FromRequest`] implementation from the bearer token in the
/// request's `Authorization` header.
pub struct Auth<T> {
    /// Claims decoded from the bearer token.
    pub claims: T,
}
/// Extracts the token from an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is compared case-insensitively, because HTTP auth-schemes are
/// case-insensitive (RFC 7235 §2.1). Surrounding whitespace, including extra spaces
/// between the scheme and the token, is ignored.
///
/// Returns `None` unless the value has the form `Bearer <token>` with a non-empty token.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.trim().split_once(' ')?;
    let token = token.trim_start();
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}

impl<'a, T> FromRequest<'a> for Auth<T>
where
    T: for<'de> Deserialize<'de> + HasJti + 'static,
{
    type Error = Response;

    /// Authenticates the request by decoding the bearer token in its `Authorization`
    /// header with the global [`JwtManager`].
    ///
    /// # Errors
    ///
    /// * `Response::new(401)` if no `Authorization` header is present, or if it does
    ///   not carry a non-empty `Bearer` token.
    /// * `Response::server_error()` if [`GLOBAL_JWT_MANAGER`] has not been initialised.
    /// * Otherwise, whatever the installed [`ErrorHandler`] (or the default one)
    ///   returns for the [`AuthError`] produced by `JwtManager::decode`.
    // The rejection type is a by-value `Response`, which clippy flags as large.
    // NOTE(review): this assumes the framework expects an unboxed `Response` here;
    // confirm before boxing it.
    #[allow(clippy::result_large_err)]
    fn from_request(ctx: &'a Context<'a>) -> Result<Self, Self::Error> {
        // Only the first `header_count` entries of `headers` are populated.
        let header_count = ctx.req.header_count as usize;
        let token = (0..header_count)
            .find_map(|i| {
                let (name, value) = ctx.req.headers[i];
                // Header field names are case-insensitive.
                name.eq_ignore_ascii_case("Authorization").then_some(value)
            })
            .and_then(parse_bearer_token)
            .ok_or_else(|| Response::new(401))?;

        let manager = GLOBAL_JWT_MANAGER
            .get()
            .ok_or_else(Response::server_error)?;
        let claims = manager.decode::<T>(token).map_err(dispatch_error)?;
        Ok(Auth { claims })
    }
}