use axum::http::Request;
use std::fmt;
use std::task::{Context, Poll};
use tower::Service;
/// Request extension carrying the configured header prefix (e.g. `"x-authentik"`).
///
/// The middleware inserts one into every request it forwards, so handlers can read
/// it from the request extensions. `PartialEq`/`Eq` are derived so values can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeaderPrefix(pub String);
/// Settings for the Authentik middleware.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthentikConfig {
    /// Prefix attached to each request as a `HeaderPrefix` extension.
    /// Defaults to `"x-authentik"`.
    pub header_prefix: String,
    /// Defaults to `true`.
    ///
    /// NOTE(review): the middleware's `call` does not read this flag, so it does not
    /// currently reject any request — confirm whether it is consumed elsewhere.
    pub require_auth: bool,
}

impl Default for AuthentikConfig {
    fn default() -> Self {
        Self {
            header_prefix: "x-authentik".to_string(),
            require_auth: true,
        }
    }
}

impl AuthentikConfig {
    /// Builds a config with a custom header prefix.
    ///
    /// Every other field takes its [`Default`] value, so defaults are defined in
    /// exactly one place.
    #[must_use]
    pub fn with_prefix(header_prefix: impl Into<String>) -> Self {
        Self {
            header_prefix: header_prefix.into(),
            ..Self::default()
        }
    }
}
/// Tower layer that wraps a service in `AuthentikMiddleware`.
///
/// `Clone` is derived because layer-accepting APIs such as axum's `Router::layer`
/// require the layer to be cloneable.
#[derive(Clone)]
pub struct AuthentikLayer {
    config: AuthentikConfig,
}

impl AuthentikLayer {
    /// Creates a layer using `AuthentikConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(AuthentikConfig::default())
    }

    /// Creates a layer with an explicit configuration.
    #[must_use]
    pub fn with_config(config: AuthentikConfig) -> Self {
        Self { config }
    }
}

impl Default for AuthentikLayer {
    fn default() -> Self {
        Self::new()
    }
}
impl<S> tower::Layer<S> for AuthentikLayer {
    type Service = AuthentikMiddleware<S>;

    /// Wraps `inner`, handing the new middleware its own copy of the layer's config.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        AuthentikMiddleware { inner, config }
    }
}
impl fmt::Debug for AuthentikLayer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AuthentikLayer")
.field("header_prefix", &self.config.header_prefix)
.field("require_auth", &self.config.require_auth)
.finish()
}
}
/// Service built by `AuthentikLayer`.
///
/// Forwards each request to `inner` after attaching a `HeaderPrefix` extension.
#[derive(Clone)]
pub struct AuthentikMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Configuration copied from the layer that built this service.
    config: AuthentikConfig,
}
impl<S, ReqBody> Service<Request<ReqBody>> for AuthentikMiddleware<S>
where
    S: Service<Request<ReqBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = futures_util::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Tags the request with the configured `HeaderPrefix` and forwards it to `inner`.
    ///
    /// NOTE(review): `config.require_auth` is not consulted here, so this middleware
    /// never rejects a request — confirm that is intended.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let prefix = HeaderPrefix(self.config.header_prefix.clone());
        req.extensions_mut().insert(prefix);

        // `poll_ready` was driven on `self.inner`, so that exact instance must serve
        // this request; a fresh clone takes its slot for the next call.
        let replacement = self.inner.clone();
        let mut ready_service = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move { ready_service.call(req).await })
    }
}
impl<S: fmt::Debug> fmt::Debug for AuthentikMiddleware<S> {
    /// Formats the wrapped service together with both configuration fields.
    ///
    /// `require_auth` is included so the output matches the layer's `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AuthentikMiddleware")
            .field("inner", &self.inner)
            .field("header_prefix", &self.config.header_prefix)
            .field("require_auth", &self.config.require_auth)
            .finish()
    }
}