use crate::auth::{provider::AuthProvider, token::TokenExtractor};
use crate::error::TidewayError;
use axum::{extract::FromRequestParts, http::request::Parts};
use std::future::Future;
/// Extractor that requires an authenticated user.
///
/// Extraction performs these steps in order:
/// 1. Fetch the provider `P` from the request extensions and clone it. It must have been
///    inserted there beforehand, presumably by an auth layer elsewhere (TODO confirm).
/// 2. Obtain the token with [`TokenExtractor::from_header`].
/// 3. Verify the token, load the user from the resulting claims, and validate that user.
///
/// # Errors
///
/// Rejects with [`TidewayError`] when:
/// - `P` is not present in the request extensions (internal error);
/// - the token cannot be extracted;
/// - `verify_token`, `load_user`, or `validate_user` fails. The provider's error is
///   propagated unchanged.
///
/// With the `test-auth-bypass` feature enabled, a request carrying an `X-Test-User-Id`
/// header is rejected with an internal error, because the bypass is not implemented here.
pub struct AuthUser<P: AuthProvider>(pub P::User);

impl<P, S> FromRequestParts<S> for AuthUser<P>
where
    P: AuthProvider,
    S: Send + Sync,
{
    type Rejection = TidewayError;

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // The trait only requires `impl Future + Send`. Returning the async block directly
        // avoids the per-request heap allocation that `Box::pin` would add.
        async move {
            let provider = parts
                .extensions
                .get::<P>()
                .ok_or_else(|| {
                    TidewayError::internal("Auth provider not found in request extensions")
                })?
                .clone();

            // The bypass header is recognised but not supported. Fail loudly instead of
            // silently authenticating or silently ignoring it.
            #[cfg(feature = "test-auth-bypass")]
            {
                if parts.headers.contains_key("X-Test-User-Id") {
                    return Err(TidewayError::internal(
                        "Test auth bypass requires Claims to implement Default trait. \
                         Implement Default for your Claims type or disable test-auth-bypass feature.",
                    ));
                }
            }

            let token = TokenExtractor::from_header(parts)?;
            let claims = provider.verify_token(&token).await?;
            let user = provider.load_user(&claims).await?;
            provider.validate_user(&user).await?;
            Ok(AuthUser(user))
        }
    }
}
/// Extractor that yields the authenticated user when one can be established, and `None`
/// otherwise.
///
/// It never rejects the request. It resolves to `OptionalAuth(None)` when any of these
/// fails:
/// - `P` is not present in the request extensions. Unlike [`AuthUser`], a missing
///   provider is *not* reported as an error. NOTE(review): this hides a misconfigured
///   router; confirm that this is intended.
/// - [`TokenExtractor::from_header`] returns an error;
/// - `verify_token`, `load_user`, or `validate_user` returns an error.
///
/// Every such error is discarded, so an internal failure inside `load_user` is
/// indistinguishable from an anonymous request.
pub struct OptionalAuth<P: AuthProvider>(pub Option<P::User>);

impl<P, S> FromRequestParts<S> for OptionalAuth<P>
where
    P: AuthProvider,
    S: Send + Sync,
{
    type Rejection = TidewayError;

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // Returned directly rather than boxed: the trait only requires `impl Future + Send`.
        async move {
            let Some(provider) = parts.extensions.get::<P>() else {
                return Ok(OptionalAuth(None));
            };
            let provider = provider.clone();

            // Each step is best-effort: any failure means "no authenticated user".
            let Ok(token) = TokenExtractor::from_header(parts) else {
                return Ok(OptionalAuth(None));
            };
            let Ok(claims) = provider.verify_token(&token).await else {
                return Ok(OptionalAuth(None));
            };
            let Ok(user) = provider.load_user(&claims).await else {
                return Ok(OptionalAuth(None));
            };
            if provider.validate_user(&user).await.is_err() {
                return Ok(OptionalAuth(None));
            }

            Ok(OptionalAuth(Some(user)))
        }
    }
}
/// Extractor that yields the verified token claims without loading a user.
///
/// It fetches the provider `P` from the request extensions and obtains the token with
/// [`TokenExtractor::from_header`]. It then returns the result of `verify_token` and stops
/// there: neither `load_user` nor `validate_user` is called.
///
/// # Errors
///
/// Rejects with [`TidewayError`] when:
/// - `P` is not present in the request extensions (internal error);
/// - the token cannot be extracted;
/// - `verify_token` fails. The provider's error is propagated unchanged.
pub struct Claims<P: AuthProvider>(pub P::Claims);

impl<P, S> FromRequestParts<S> for Claims<P>
where
    P: AuthProvider,
    S: Send + Sync,
{
    type Rejection = TidewayError;

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // Returned directly rather than boxed: the trait only requires `impl Future + Send`,
        // so `Box::pin` would just cost an allocation per request.
        async move {
            let provider = parts
                .extensions
                .get::<P>()
                .ok_or_else(|| {
                    TidewayError::internal("Auth provider not found in request extensions")
                })?
                .clone();

            let token = TokenExtractor::from_header(parts)?;
            let claims = provider.verify_token(&token).await?;
            Ok(Claims(claims))
        }
    }
}