use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::response::Response;
use std::future::Future;
use crate::error::AuthentikError;
use crate::user::AuthentikUser;
/// Axum extractor that resolves the authenticated [`AuthentikUser`] and, when a
/// [`GroupGuard`] has attached a group requirement to the request, rejects users
/// who do not satisfy it.
///
/// The wrapped user is reachable through the public `.0` field or, thanks to the
/// [`Deref`](std::ops::Deref) impl, by calling `AuthentikUser` methods directly
/// on the extractor.
#[derive(Clone, Debug)]
pub struct RequireGroup(pub AuthentikUser);

impl std::ops::Deref for RequireGroup {
    type Target = AuthentikUser;

    /// Borrow the authenticated user held by this extractor.
    fn deref(&self) -> &AuthentikUser {
        let RequireGroup(user) = self;
        user
    }
}
/// Request extension carrying the group names a [`GroupGuard`] demands.
#[derive(Clone, Debug)]
struct RequiredGroups(Vec<String>);

/// How the names in [`RequiredGroups`] are combined when checking a user.
#[derive(Clone, Copy, Debug)]
enum GroupMatchMode {
    /// The user must belong to every listed group.
    All,
    /// Membership in at least one listed group is sufficient.
    Any,
}
impl<S> FromRequestParts<S> for RequireGroup
where
S: Send + Sync,
{
type Rejection = AuthentikError;
fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
let required = parts.extensions.get::<RequiredGroups>().cloned();
let mode = parts.extensions.get::<GroupMatchMode>().copied();
let user_fut = AuthentikUser::from_request_parts(parts, state);
async move {
let user = user_fut.await?;
match required {
Some(groups) if !groups.0.is_empty() => {
let passes = match mode.unwrap_or(GroupMatchMode::Any) {
GroupMatchMode::All => groups.0.iter().all(|g| user.has_group(g)),
GroupMatchMode::Any => groups.0.iter().any(|g| user.has_group(g)),
};
if passes {
Ok(RequireGroup(user))
} else {
let desc = match mode.unwrap_or(GroupMatchMode::Any) {
GroupMatchMode::All => groups.0.join(", "),
GroupMatchMode::Any => groups.0.join(" or "),
};
Err(AuthentikError::Forbidden {
required_group: desc,
})
}
}
_ => Ok(RequireGroup(user)),
}
}
}
}
/// Wrap `handler` so that the [`RequireGroup`] extractor demands membership in
/// `group`.
///
/// The returned [`GroupGuard`] only attaches the requirement to each request;
/// the actual check is performed by the [`RequireGroup`] extractor. A handler
/// that does not extract `RequireGroup` is therefore not restricted.
///
/// NOTE(review): the `Handler<T, ()>` bound means `handler` must be usable with
/// unit state; handlers that extract router state may not satisfy it.
pub fn require_group<H, T>(group: &'static str, handler: H) -> GroupGuard<H>
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    let groups = vec![String::from(group)];
    GroupGuard {
        mode: GroupMatchMode::Any,
        groups,
        handler,
    }
}
/// Wrap `handler` so that the [`RequireGroup`] extractor demands membership in
/// *every* group in `groups`.
///
/// The check itself runs inside [`RequireGroup`]; the guard only attaches the
/// requirement. An empty `groups` slice imposes no group restriction beyond
/// authentication.
///
/// NOTE(review): the `Handler<T, ()>` bound means `handler` must be usable with
/// unit state; handlers that extract router state may not satisfy it.
pub fn require_all_groups<H, T>(groups: &[&'static str], handler: H) -> GroupGuard<H>
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    let required: Vec<String> = groups.iter().copied().map(String::from).collect();
    GroupGuard {
        handler,
        groups: required,
        mode: GroupMatchMode::All,
    }
}
/// Wrap `handler` so that the [`RequireGroup`] extractor demands membership in
/// *at least one* group in `groups`.
///
/// The check itself runs inside [`RequireGroup`]; the guard only attaches the
/// requirement. An empty `groups` slice is treated as "no requirement" (any
/// authenticated user passes), not as "deny everyone".
///
/// NOTE(review): the `Handler<T, ()>` bound means `handler` must be usable with
/// unit state; handlers that extract router state may not satisfy it.
pub fn require_any_group<H, T>(groups: &[&'static str], handler: H) -> GroupGuard<H>
where
    H: axum::handler::Handler<T, ()>,
    T: 'static,
{
    let required: Vec<String> = groups.iter().copied().map(String::from).collect();
    GroupGuard {
        handler,
        groups: required,
        mode: GroupMatchMode::Any,
    }
}
/// Handler wrapper produced by [`require_group`], [`require_all_groups`] and
/// [`require_any_group`].
///
/// Each time it is called it attaches its group list and match mode to the
/// request extensions and then delegates to the wrapped handler. The
/// [`RequireGroup`] extractor reads them back to perform the check.
// `derive(Clone)` generates exactly the `impl<H: Clone> Clone` that used to be
// written out by hand.
#[derive(Clone)]
pub struct GroupGuard<H> {
    /// The wrapped handler, invoked after the requirement is attached.
    handler: H,
    /// Group names the user is checked against.
    groups: Vec<String>,
    /// Whether all or any of `groups` must match.
    mode: GroupMatchMode,
}
impl<H, T, S> axum::handler::Handler<T, S> for GroupGuard<H>
where
H: axum::handler::Handler<T, S> + Clone + Send + 'static,
T: 'static,
S: Send + Sync + 'static,
{
type Future = std::pin::Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
fn call(self, mut req: axum::extract::Request, state: S) -> Self::Future {
req.extensions_mut()
.insert(RequiredGroups(self.groups.clone()));
req.extensions_mut().insert(self.mode);
let handler = self.handler;
Box::pin(async move { handler.call(req, state).await })
}
}