use std::marker::PhantomData;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::http::header::{AUTHORIZATION, COOKIE};
use axum::http::{HeaderMap, Request, Response, StatusCode};
use axum::response::IntoResponse;
use axum_extra::extract::cookie::CookieJar;
use tower::{Layer, Service};
use super::config::BearerAuthConfig;
use super::error::VerifyError;
use super::port::AuthProvider;
pub struct BearerAuthLayer<Sess, P: ?Sized>
where
Sess: Clone + Send + Sync + 'static,
P: AuthProvider<Sess> + 'static,
{
provider: Arc<P>,
config: BearerAuthConfig,
_session: PhantomData<fn() -> Sess>,
}
impl<Sess, P: ?Sized> BearerAuthLayer<Sess, P>
where
    Sess: Clone + Send + Sync + 'static,
    P: AuthProvider<Sess> + 'static,
{
    /// Creates a layer that verifies tokens with `provider`, configured by `config`.
    #[must_use]
    pub fn new(provider: Arc<P>, config: BearerAuthConfig) -> Self {
        let marker = PhantomData;
        Self {
            config,
            provider,
            _session: marker,
        }
    }
}
impl<Sess, P: ?Sized> Clone for BearerAuthLayer<Sess, P>
where
Sess: Clone + Send + Sync + 'static,
P: AuthProvider<Sess> + 'static,
{
fn clone(&self) -> Self {
Self {
provider: self.provider.clone(),
config: self.config.clone(),
_session: PhantomData,
}
}
}
impl<Inner, Sess, P: ?Sized> Layer<Inner> for BearerAuthLayer<Sess, P>
where
Sess: Clone + Send + Sync + 'static,
P: AuthProvider<Sess> + 'static,
{
type Service = BearerAuthService<Inner, Sess, P>;
fn layer(&self, inner: Inner) -> Self::Service {
BearerAuthService {
inner,
provider: self.provider.clone(),
config: self.config.clone(),
_session: PhantomData,
}
}
}
pub struct BearerAuthService<Inner, Sess, P: ?Sized>
where
Sess: Clone + Send + Sync + 'static,
P: AuthProvider<Sess> + 'static,
{
inner: Inner,
provider: Arc<P>,
config: BearerAuthConfig,
_session: PhantomData<fn() -> Sess>,
}
impl<Inner, Sess, P: ?Sized> Clone for BearerAuthService<Inner, Sess, P>
where
Inner: Clone,
Sess: Clone + Send + Sync + 'static,
P: AuthProvider<Sess> + 'static,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
provider: self.provider.clone(),
config: self.config.clone(),
_session: PhantomData,
}
}
}
impl<Inner, Sess, P: ?Sized> Service<Request<Body>> for BearerAuthService<Inner, Sess, P>
where
    Inner: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
    Sess: Clone + Send + Sync + 'static,
    P: AuthProvider<Sess> + 'static,
{
    type Response = Response<Body>;
    type Error = Inner::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req`. On success, forwards it to the inner service with
    /// the verified session inserted into the request extensions.
    ///
    /// Outcomes:
    /// - no token in the `Authorization` header or the access cookie →
    ///   `401`, no cookies modified;
    /// - `VerifyError::Rejected` → `401` carrying the jar built by `on_clear`;
    /// - `VerifyError::SubstrateTransient` → `503`;
    /// - verified → whatever the inner service returns.
    ///
    /// Auth failures become responses, so the future resolves to `Err` only
    /// when the inner service itself fails.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let provider = self.provider.clone();
        let cookie_name = self.config.access_cookie_name;
        let on_clear = self.config.on_clear.clone();
        // `poll_ready` was driven on `self.inner`, so that instance is the one
        // permitted to receive this request. Move it into the future and leave
        // a fresh clone behind for the next `poll_ready`/`call` cycle.
        // Calling a freshly cloned (never polled) service instead violates the
        // tower `Service` contract; middleware such as `ConcurrencyLimit`
        // panics in that case.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            let token = match extract_bearer(req.headers(), cookie_name) {
                Some(t) => t,
                None => return Ok(unauthenticated_response(false, &on_clear)),
            };
            match provider.verify_token(&token).await {
                Ok(session) => {
                    req.extensions_mut().insert(session);
                    inner.call(req).await
                }
                Err(VerifyError::SubstrateTransient(reason)) => {
                    tracing::warn!(reason = %reason, "auth substrate transient — 503");
                    Ok(transient_response())
                }
                Err(VerifyError::Rejected(reason)) => {
                    tracing::debug!(reason = %reason, "token rejected at perimeter — 401 + clear");
                    Ok(unauthenticated_response(true, &on_clear))
                }
            }
        })
    }
}
/// Extracts the access token from the request, preferring the `Authorization` header.
///
/// Order of precedence:
/// 1. `Authorization: Bearer <token>`. The scheme name is matched
///    case-insensitively (RFC 7235 §2.1), and spaces between the scheme and
///    the token are skipped. The header is ignored if any of these hold:
///    - it is not valid visible ASCII;
///    - it uses another scheme;
///    - it carries no token.
/// 2. The cookie named `cookie_name` in the `Cookie` header. The first cookie
///    with that name wins; if its value is empty, `None` is returned.
///
/// Returns `None` when neither source yields a non-empty token.
fn extract_bearer(headers: &HeaderMap, cookie_name: &'static str) -> Option<String> {
    if let Some(token) = headers
        .get(AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token_from_authorization)
    {
        return Some(token.to_owned());
    }
    let cookie_header = headers.get(COOKIE).and_then(|v| v.to_str().ok())?;
    cookie_header
        .split(';')
        .filter_map(|s| s.trim().split_once('='))
        .find(|(name, _)| *name == cookie_name)
        .map(|(_, value)| value)
        // Check emptiness before allocating the owned copy.
        .filter(|v| !v.is_empty())
        .map(str::to_owned)
}

/// Returns the credentials part of an `Authorization` value whose scheme is
/// `Bearer` in any ASCII case. Returns `None` for other schemes or a missing
/// or blank token.
fn bearer_token_from_authorization(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // RFC 7235 allows `1*SP` before the credentials, and a bearer token
    // (token68) never contains spaces, so surrounding spaces are not part of it.
    let token = rest.trim_matches(' ');
    (!token.is_empty()).then_some(token)
}
fn unauthenticated_response(
clear_cookies: bool,
on_clear: &Arc<dyn Fn(CookieJar) -> CookieJar + Send + Sync>,
) -> Response<Body> {
if clear_cookies {
let jar = on_clear(CookieJar::new());
(StatusCode::UNAUTHORIZED, jar, "unauthenticated").into_response()
} else {
(StatusCode::UNAUTHORIZED, "unauthenticated").into_response()
}
}
/// Builds the `503 Service Unavailable` response used when the auth provider
/// reports a transient failure, with a short plain-text explanation as the body.
fn transient_response() -> Response<Body> {
    let status = StatusCode::SERVICE_UNAVAILABLE;
    let message = "auth substrate temporarily unavailable";
    (status, message).into_response()
}