#![cfg(feature = "axum")]
use axum::{
extract::FromRequestParts,
http::{StatusCode, request::Parts},
response::{IntoResponse, Response},
};
use std::marker::PhantomData;
use super::auth::{AuthContext, PermissionScope};
/// Extractor wrapping the [`AuthContext`] found in the request extensions.
///
/// Its `FromRequestParts` implementation rejects the request with
/// `401 Unauthorized` when no `AuthContext` is present. The wrapper derefs to
/// the inner [`AuthContext`], so its fields can be read directly on the extractor.
#[derive(Clone, Debug)]
pub struct RequireAuth(pub AuthContext);

impl std::ops::Deref for RequireAuth {
    type Target = AuthContext;

    /// Borrows the wrapped authentication context.
    fn deref(&self) -> &AuthContext {
        let RequireAuth(inner) = self;
        inner
    }
}
impl<S> FromRequestParts<S> for RequireAuth
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    /// Clones the [`AuthContext`] stored in the request extensions.
    ///
    /// # Errors
    /// Returns `401 Unauthorized` with the body "Authentication required" when
    /// no `AuthContext` extension is present.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<AuthContext>() {
            Some(ctx) => Ok(Self(ctx.clone())),
            None => Err((StatusCode::UNAUTHORIZED, "Authentication required")),
        }
    }
}
/// Compile-time description of one permission check, consumed by the
/// `RequirePermission<P>` extractor.
///
/// Implementors are usually zero-sized marker types, such as those generated
/// by the `require_permission!` macro.
pub trait Permission: Send + Sync + 'static {
    /// Permission name that must appear in the granted list selected by `SCOPE`.
    const PERMISSION: &'static str;
    /// Which permission list of the `AuthContext` is checked (organization or workspace).
    const SCOPE: PermissionScope;
}
/// Extractor requiring an authenticated request whose [`AuthContext`] grants
/// the permission described by the marker type `P`.
///
/// Extraction fails with `401 Unauthorized` when no `AuthContext` is present,
/// and with `403 Forbidden` when the permission is not granted. The wrapper
/// derefs to the inner [`AuthContext`].
pub struct RequirePermission<P: Permission> {
    /// Authentication context of the request.
    pub auth: AuthContext,
    _phantom: PhantomData<P>,
}

// Hand-written rather than derived: `#[derive(Debug)]` would add a `P: Debug`
// bound, but `P` is only a type-level marker (e.g. the unit structs produced by
// `require_permission!`), so the wrapper should be debuggable for any `P`.
impl<P: Permission> std::fmt::Debug for RequirePermission<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequirePermission")
            .field("auth", &self.auth)
            .field("permission", &P::PERMISSION)
            .finish()
    }
}

// Hand-written for the same reason as `Debug`: avoid a spurious `P: Clone`
// bound. Mirrors the `Clone` derive on the sibling `RequireAuth`/`OptionalAuth`.
impl<P: Permission> Clone for RequirePermission<P> {
    fn clone(&self) -> Self {
        Self {
            auth: self.auth.clone(),
            _phantom: PhantomData,
        }
    }
}
impl<P: Permission> std::ops::Deref for RequirePermission<P> {
    type Target = AuthContext;

    /// Borrows the authentication context carried by the extractor.
    fn deref(&self) -> &AuthContext {
        let Self { auth, .. } = self;
        auth
    }
}
/// Returns `true` when `granted` is `Some` and yields an entry equal to `required`.
///
/// Generic over the collection so it works for whatever container holds a
/// scope's permission list, as long as it iterates over string-like items.
/// It compares against `&str` directly, so no `String` is allocated per check.
fn grants_permission<I>(granted: Option<I>, required: &str) -> bool
where
    I: IntoIterator,
    I::Item: AsRef<str>,
{
    granted.is_some_and(|perms| perms.into_iter().any(|p| p.as_ref() == required))
}

impl<S, P> FromRequestParts<S> for RequirePermission<P>
where
    S: Send + Sync,
    P: Permission,
{
    type Rejection = Response;

    /// Checks that the request's [`AuthContext`] grants `P::PERMISSION` in the
    /// permission list selected by `P::SCOPE`.
    ///
    /// # Errors
    /// - `401 Unauthorized` ("Authentication required") when no `AuthContext`
    ///   extension is present.
    /// - `403 Forbidden` ("Missing required permission: ...") when the permission
    ///   is not granted. An absent `permissions` value or an absent list for the
    ///   scope counts as "not granted".
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(auth) = parts.extensions.get::<AuthContext>() else {
            return Err((StatusCode::UNAUTHORIZED, "Authentication required").into_response());
        };

        let scoped = auth.permissions.as_ref();
        let has_permission = match P::SCOPE {
            PermissionScope::Organization => {
                grants_permission(scoped.and_then(|p| p.organization.as_ref()), P::PERMISSION)
            }
            PermissionScope::Workspace => {
                grants_permission(scoped.and_then(|p| p.workspace.as_ref()), P::PERMISSION)
            }
        };

        if !has_permission {
            return Err((
                StatusCode::FORBIDDEN,
                format!("Missing required permission: {}", P::PERMISSION),
            )
                .into_response());
        }

        // Clone the context only once authorization has succeeded; the 403
        // path above works purely on the borrowed extension.
        Ok(RequirePermission {
            auth: auth.clone(),
            _phantom: PhantomData,
        })
    }
}
/// Declares a public zero-sized marker type implementing
/// `middleware::extractors::Permission`, for use as `RequirePermission<Marker>`.
///
/// Arguments: the marker's name, the permission string to require, and the
/// name of a `PermissionScope` variant (e.g. `Organization` or `Workspace`).
/// A trailing comma is accepted.
///
/// The marker derives `Debug`, `Clone`, and `Copy`, so wrappers that derive
/// those traits over it keep working.
///
/// ```ignore
/// require_permission!(ManageMembers, "members:manage", Organization);
/// async fn handler(auth: RequirePermission<ManageMembers>) { /* ... */ }
/// ```
#[macro_export]
macro_rules! require_permission {
    ($name:ident, $permission:expr, $scope:ident $(,)?) => {
        #[derive(Debug, Clone, Copy)]
        pub struct $name;

        impl $crate::middleware::extractors::Permission for $name {
            const PERMISSION: &'static str = $permission;
            const SCOPE: $crate::middleware::PermissionScope =
                $crate::middleware::PermissionScope::$scope;
        }
    };
}
/// Extractor that never rejects.
///
/// Holds `Some` with a clone of the [`AuthContext`] when one is present in the
/// request extensions, and `None` otherwise.
#[derive(Clone, Debug)]
pub struct OptionalAuth(pub Option<AuthContext>);

impl<S> FromRequestParts<S> for OptionalAuth
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    /// Looks up the optional authentication context; always succeeds.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let found = parts.extensions.get::<AuthContext>().cloned();
        Ok(Self(found))
    }
}