use crate::Authentication;
use async_recursion::async_recursion;
use async_trait::async_trait;
use http::Method;
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt, hash::Hash, marker::PhantomData};
/// Lets a user type report whether it holds a named permission.
///
/// `Rights::evaluate` calls [`HasPermission::has`] for every `Rights::Permission`
/// leaf it reaches, so this is the single hook that connects a `Rights` tree to
/// your user model.
#[async_trait]
pub trait HasPermission<Pool>
where
    Pool: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Returns `true` if `self` holds the permission named `perm`.
    ///
    /// `pool` is the optional `Pool` handle that `Auth::validate` and
    /// `Rights::evaluate` forward unchanged; it may be `None`.
    async fn has(&self, perm: &str, pool: &Option<&Pool>) -> bool;
}
/// A tree of permission requirements, checked with `Rights::evaluate`.
///
/// Leaves are [`Rights::Permission`] (one named permission) and [`Rights::None`];
/// the combinator variants nest further `Rights`. `Debug`/`PartialEq`/`Eq` are
/// derived so policies can be logged and compared in tests.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum Rights {
    /// Satisfied only if every nested right is satisfied (`true` when empty).
    All(Box<[Rights]>),
    /// Satisfied if at least one nested right is satisfied (`false` when empty).
    Any(Box<[Rights]>),
    /// Satisfied only if no nested right is satisfied (`true` when empty).
    ///
    /// This is what the `Rights::none(...)` constructor builds; it is unrelated
    /// to the [`Rights::None`] variant.
    NoneOf(Box<[Rights]>),
    /// Satisfied if `HasPermission::has` returns `true` for this name.
    Permission(String),
    /// The default requirement; it is never satisfied.
    #[default]
    None,
}
impl Rights {
    /// Builds a [`Rights::All`] requirement: every right in `rights` must hold.
    pub fn all(rights: impl IntoIterator<Item = Rights>) -> Rights {
        Self::All(rights.into_iter().collect::<Box<[Rights]>>())
    }

    /// Builds a [`Rights::Any`] requirement: at least one right in `rights` must hold.
    pub fn any(rights: impl IntoIterator<Item = Rights>) -> Rights {
        Self::Any(rights.into_iter().collect::<Box<[Rights]>>())
    }

    /// Builds a [`Rights::NoneOf`] requirement: no right in `rights` may hold.
    ///
    /// Despite the name, this does not produce the [`Rights::None`] variant.
    pub fn none(rights: impl IntoIterator<Item = Rights>) -> Rights {
        Self::NoneOf(rights.into_iter().collect::<Box<[Rights]>>())
    }

    /// Builds a [`Rights::Permission`] requirement for a single named permission.
    pub fn permission(permission: impl Into<String>) -> Rights {
        let name: String = permission.into();
        Self::Permission(name)
    }

    /// Evaluates this requirement tree for `user`, forwarding `db` to every
    /// `HasPermission::has` call.
    ///
    /// Children are evaluated in order and evaluation short-circuits: `All`
    /// stops at the first unmet child, while `Any` and `NoneOf` stop at the
    /// first met child. An empty `All` or `NoneOf` yields `true`; an empty
    /// `Any` and the `None` variant yield `false`.
    ///
    /// `#[async_recursion]` is required because this future awaits itself.
    #[async_recursion()]
    pub async fn evaluate<Pool>(
        &self,
        user: &(dyn HasPermission<Pool> + Sync),
        db: &Option<&Pool>,
    ) -> bool
    where
        Pool: Clone + Send + Sync + fmt::Debug + 'static,
    {
        match self {
            Self::All(children) => {
                for child in children.iter() {
                    if !child.evaluate(user, db).await {
                        return false;
                    }
                }
                true
            }
            Self::Any(children) => {
                for child in children.iter() {
                    if child.evaluate(user, db).await {
                        return true;
                    }
                }
                false
            }
            Self::NoneOf(children) => {
                for child in children.iter() {
                    if child.evaluate(user, db).await {
                        return false;
                    }
                }
                true
            }
            Self::Permission(name) => user.has(name, db).await,
            Self::None => false,
        }
    }
}
pub struct Auth<User, Type, Pool>
where
User: Authentication<User, Type, Pool> + HasPermission<Pool> + Send,
Pool: Clone + Send + Sync + fmt::Debug + 'static,
Type: Eq + Default + Clone + Send + Sync + Hash + Serialize + DeserializeOwned + 'static,
{
pub rights: Rights,
pub auth_required: bool,
pub methods: Vec<Method>,
phantom_user: PhantomData<User>,
phantom_pool: PhantomData<Pool>,
phantom_type: PhantomData<Type>,
}
impl<User, Type, Pool> Auth<User, Type, Pool>
where
User: Authentication<User, Type, Pool> + HasPermission<Pool> + Sync + Send,
Pool: Clone + Send + Sync + fmt::Debug + 'static,
Type: Eq + Default + Clone + Send + Sync + Hash + Serialize + DeserializeOwned + 'static,
{
pub fn build(
methods: impl IntoIterator<Item = Method>,
auth_req: bool,
) -> Auth<User, Type, Pool> {
Auth::<User, Type, Pool> {
rights: Rights::None,
auth_required: auth_req,
methods: methods.into_iter().collect(),
phantom_user: Default::default(),
phantom_pool: Default::default(),
phantom_type: Default::default(),
}
}
pub fn requires(&mut self, rights: Rights) -> &mut Self {
self.rights = rights;
self
}
pub async fn validate(&self, user: &User, method: &Method, db: Option<&Pool>) -> bool
where
User: HasPermission<Pool> + Authentication<User, Type, Pool>,
{
if self.auth_required && !user.is_authenticated() {
return false;
}
if self.methods.iter().any(|r| r == method) {
self.rights.evaluate(user, &db).await
} else {
false
}
}
}