use crate::auth::{provider::AuthProvider, token::TokenExtractor};
use crate::error::TidewayError;
use axum::{extract::Request, middleware::Next, response::Response};
use std::marker::PhantomData;
/// Zero-sized type-level marker for the authentication-enforcing middleware.
///
/// It stores no data. The type parameter `P` only selects which provider type
/// `RequireAuth::<P>::middleware` looks up in the request extensions.
///
/// The `AuthProvider` requirement is stated on the `impl` block that uses it rather
/// than on the struct definition, so naming `RequireAuth<P>` imposes no extra bounds.
pub struct RequireAuth<P> {
    _provider: PhantomData<P>,
}
impl<P: AuthProvider> RequireAuth<P> {
    /// Axum middleware that only lets a request through once a user has been
    /// authenticated for it.
    ///
    /// Steps, in order:
    /// 1. Take a clone of the provider `P` from the request extensions.
    /// 2. Extract a token from the request parts with `TokenExtractor::from_header`.
    /// 3. Use the provider to verify the token, load the user from the resulting
    ///    claims, and validate that user.
    /// 4. Store the user in the request extensions and forward the request to `next`.
    ///
    /// # Errors
    ///
    /// Returns `TidewayError::internal("Auth provider not found in request extensions")`
    /// when no `P` is present in the extensions. For example, this happens when
    /// `AuthLayer::middleware` has not inserted it earlier in the stack.
    ///
    /// Errors from token extraction, verification, user loading, or validation are
    /// propagated with `?`. In every error case `next` is never invoked.
    pub async fn middleware(request: Request, next: Next) -> Result<Response, TidewayError> {
        let (mut parts, body) = request.into_parts();

        // Owned copy, so no borrow of `parts` is held when the user is inserted below.
        let auth = parts
            .extensions
            .get::<P>()
            .cloned()
            .ok_or_else(|| {
                TidewayError::internal("Auth provider not found in request extensions")
            })?;

        let token = TokenExtractor::from_header(&parts)?;
        let claims = auth.verify_token(&token).await?;
        let user = auth.load_user(&claims).await?;
        auth.validate_user(&user).await?;

        // Expose the authenticated user to downstream handlers.
        parts.extensions.insert(user);
        let authenticated = Request::from_parts(parts, body);
        Ok(next.run(authenticated).await)
    }
}
/// Middleware state that owns an authentication provider `P`.
///
/// `AuthLayer::middleware` inserts a clone of this provider into the extensions of
/// every request it handles.
///
/// The `AuthProvider` requirement is stated on the `impl` blocks that use the
/// provider, not on the struct definition.
pub struct AuthLayer<P> {
    /// The provider cloned into each request's extensions.
    provider: P,
}
impl<P: AuthProvider> AuthLayer<P> {
pub fn new(provider: P) -> Self {
Self { provider }
}
pub async fn middleware(&self, mut request: Request, next: Next) -> Response {
request.extensions_mut().insert(self.provider.clone());
next.run(request).await
}
}
impl<P: AuthProvider> Clone for AuthLayer<P> {
fn clone(&self) -> Self {
Self {
provider: self.provider.clone(),
}
}
}