use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use sea_orm::EntityTrait;
use sea_orm::sea_query::{Condition, Expr};
use crate::action::Action;
use crate::predicate::Predicate;
/// The set of subject fields a rule grants access to.
///
/// `All` (the default) means the rule is not limited to particular fields; `Only`
/// lists the field names that may be exposed. When masking, these names are compared
/// against the keys of the model's JSON serialization.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FieldSet {
    /// Every field of the subject.
    #[default]
    All,
    /// Only the listed fields.
    Only(HashSet<&'static str>),
}
/// A single authorization rule as stored by `Ability`, kept under an (action, subject type) key.
pub(crate) struct Rule {
/// `true` for a deny rule, `false` for a grant rule; a matching deny rule makes `Ability::can` return `false`.
pub(crate) inverted: bool,
/// Query condition for this rule; `Ability::condition_for` ORs grant conditions and negates the OR of deny conditions.
pub(crate) condition: Condition,
/// Type-erased in-memory predicate, downcast to `Predicate<E>` for the entity the rule is keyed under.
/// NOTE(review): nothing in `add_rule` enforces that the concrete type matches the key — a mismatch panics on use.
pub(crate) predicate: Box<dyn Any + Send + Sync>,
/// Fields this rule covers; only consulted for grant rules when computing permitted fields.
pub(crate) fields: FieldSet,
}
/// Key under which rules are grouped: the action and the `TypeId` of the subject entity.
type RuleKey = (Action, TypeId);

/// A collection of authorization rules, grouped per action and subject entity type.
///
/// An empty `Ability` (its `Default`) grants nothing.
#[derive(Default)]
pub struct Ability {
    /// Rules for each (action, subject) pair, appended in registration order.
    rules: HashMap<RuleKey, Vec<Rule>>,
}
impl Ability {
pub(crate) fn add_rule(&mut self, action: Action, subject: TypeId, rule: Rule) {
self.rules.entry((action, subject)).or_default().push(rule);
}
fn rules_for(&self, action: Action, subject: TypeId) -> impl Iterator<Item = &Rule> {
let specific = self.rules.get(&(action, subject)).into_iter().flatten();
let wildcard = if action == Action::Manage {
None
} else {
self.rules.get(&(Action::Manage, subject))
};
specific.chain(wildcard.into_iter().flatten())
}
pub fn can_class(&self, action: Action, subject: TypeId) -> bool {
self.rules_for(action, subject).any(|rule| !rule.inverted)
}
pub fn condition_for<E: EntityTrait>(&self, action: Action) -> Condition {
let mut grant = Condition::any();
let mut deny = Condition::any();
for rule in self.rules_for(action, TypeId::of::<E>()) {
if rule.inverted {
deny = deny.add(rule.condition.clone());
} else {
grant = grant.add(rule.condition.clone());
}
}
if grant.is_empty() {
return Condition::all().add(Expr::cust("1 = 0"));
}
let mut out = Condition::all().add(grant);
if !deny.is_empty() {
out = out.add(deny.not());
}
out
}
pub fn can<E: EntityTrait>(&self, action: Action, model: &E::Model) -> bool {
let mut allowed = false;
for rule in self.rules_for(action, TypeId::of::<E>()) {
if predicate_of::<E>(rule).matches(model) {
if rule.inverted {
return false;
}
allowed = true;
}
}
allowed
}
pub fn mask<E>(&self, action: Action, model: &E::Model) -> serde_json::Value
where
E: EntityTrait,
E::Model: serde::Serialize,
{
let mut json = serde_json::to_value(model).unwrap_or(serde_json::Value::Null);
if let FieldSet::Only(allowed) = self.permitted_fields::<E>(action, model)
&& let serde_json::Value::Object(map) = &mut json
{
map.retain(|key, _| allowed.contains(key.as_str()));
}
json
}
pub fn mask_many<'m, E>(
&self,
action: Action,
models: impl IntoIterator<Item = &'m E::Model>,
) -> Vec<serde_json::Value>
where
E: EntityTrait,
E::Model: serde::Serialize + 'm,
{
models
.into_iter()
.filter(|model| self.can::<E>(action, model))
.map(|model| self.mask::<E>(action, model))
.collect()
}
pub fn permitted_fields<E: EntityTrait>(&self, action: Action, model: &E::Model) -> FieldSet {
let mut acc: HashSet<&'static str> = HashSet::new();
for rule in self
.rules_for(action, TypeId::of::<E>())
.filter(|rule| !rule.inverted)
{
if !predicate_of::<E>(rule).matches(model) {
continue;
}
match &rule.fields {
FieldSet::All => return FieldSet::All,
FieldSet::Only(cols) => acc.extend(cols.iter().copied()),
}
}
FieldSet::Only(acc)
}
}
/// Recovers the typed predicate of `rule` for entity `E`.
///
/// # Panics
///
/// Panics if the rule's type-erased predicate is not a `Predicate<E>`. That can only
/// happen if the rule was registered under a subject `TypeId` other than `E`'s.
fn predicate_of<E: EntityTrait>(rule: &Rule) -> &Predicate<E> {
    // Borrow the boxed trait object directly, so the downcast runs on the erased value
    // rather than on the `Box` itself.
    let erased = &*rule.predicate;
    erased
        .downcast_ref::<Predicate<E>>()
        .expect("rule predicate type matches the subject it is keyed under")
}