use crate::{
BatchEvalCtx, EvalCtx, FactKey, FactLoadResult, FactOutcome, FactProvenance, Policy,
PolicyEvalResult, RelationshipQuery,
};
use async_trait::async_trait;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
/// Relationship-based access-control (ReBAC) policy.
///
/// Access is granted only when a [`RelationshipQuery`] fact reports that the subject holds
/// `relation` on the resource. The subject and resource identifiers used in that query are
/// produced by the extractor closures supplied to [`RebacPolicy::new`].
///
/// `A` and `C` (presumably action and context types — confirm against `Policy`) are never
/// stored; they only exist so this type can implement `Policy<S, R, A, C>`.
pub struct RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation> {
    /// Maps the evaluated subject to the identifier used in the relationship lookup.
    subject_id: Arc<dyn Fn(&S) -> SubjectId + Send + Sync>,
    /// Maps the evaluated resource to the identifier used in the relationship lookup.
    resource_id: Arc<dyn Fn(&R) -> ResourceId + Send + Sync>,
    /// Relationship that must hold between subject and resource.
    relation: Relation,
    // `fn() -> (A, C)` instead of `(A, C)`: the policy never owns an `A` or a `C`, so it must
    // not inherit their `Send`/`Sync` or drop-check obligations. Variance is unchanged
    // (covariant in both parameters).
    _marker: std::marker::PhantomData<fn() -> (A, C)>,
}
impl<S, R, A, C, SubjectId, ResourceId, Relation>
    RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
{
    /// Creates a policy that requires `relation` to hold between subject and resource.
    ///
    /// * `subject_id` — derives the subject identifier placed in the relationship query.
    /// * `resource_id` — derives the resource identifier placed in the relationship query.
    /// * `relation` — the relationship that is looked up for every evaluation.
    ///
    /// Both extractors are boxed behind an `Arc` so the policy can be shared cheaply.
    pub fn new<SubjectIdFn, ResourceIdFn>(
        subject_id: SubjectIdFn,
        resource_id: ResourceIdFn,
        relation: Relation,
    ) -> Self
    where
        SubjectIdFn: Fn(&S) -> SubjectId + Send + Sync + 'static,
        ResourceIdFn: Fn(&R) -> ResourceId + Send + Sync + 'static,
    {
        let subject_extractor = Arc::new(subject_id);
        let resource_extractor = Arc::new(resource_id);
        Self {
            subject_id: subject_extractor,
            resource_id: resource_extractor,
            relation,
            _marker: Default::default(),
        }
    }
}
#[async_trait]
impl<S, R, A, C, SubjectId, ResourceId, Relation> Policy<S, R, A, C>
    for RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
where
    S: Sync + Send,
    R: Sync + Send,
    A: Sync + Send,
    C: Sync + Send,
    SubjectId: Eq + Hash + Clone + Send + Sync + fmt::Debug + 'static,
    ResourceId: Eq + Hash + Clone + Send + Sync + fmt::Debug + 'static,
    Relation: Eq + Hash + Clone + Send + Sync + fmt::Display + 'static,
{
    /// Looks up the single `(subject, relation, resource)` relationship fact through the
    /// session and turns it into a decision; only a loaded `true` grants.
    async fn evaluate(&self, ctx: &EvalCtx<'_, S, R, A, C>) -> PolicyEvalResult {
        let query = RelationshipQuery {
            subject_id: (self.subject_id)(ctx.subject),
            resource_id: (self.resource_id)(ctx.resource),
            relation: self.relation.clone(),
        };
        // Render before the query is moved into the session lookup.
        let rendered_query = Self::render_key(&query);
        let loaded = ctx.session.get(query).await;
        self.result_from_fact(
            <RelationshipQuery<SubjectId, ResourceId, Relation> as FactKey>::NAME,
            &rendered_query,
            loaded,
        )
    }

    /// Evaluates every item in `ctx.items` for the same subject with one `get_many` call,
    /// returning one decision per item in input order.
    ///
    /// If the fact source returns a different number of results than queries were issued,
    /// every item is denied.
    async fn evaluate_batch<'item>(
        &self,
        ctx: &BatchEvalCtx<'item, S, R, A, C>,
    ) -> Vec<PolicyEvalResult> {
        // The subject is shared by the whole batch: resolve its id once, clone per query.
        let shared_subject = (self.subject_id)(ctx.subject);
        let mut queries = Vec::with_capacity(ctx.items.len());
        for item in ctx.items.iter() {
            queries.push(RelationshipQuery {
                subject_id: shared_subject.clone(),
                resource_id: (self.resource_id)(item.resource),
                relation: self.relation.clone(),
            });
        }

        let loaded = ctx.session.get_many(&queries).await;
        let expected = ctx.items.len();
        if loaded.len() != expected {
            // Results are paired with items by position, so a count mismatch makes every
            // pairing untrustworthy: deny the whole batch.
            return std::iter::repeat_with(|| {
                PolicyEvalResult::denied(
                    self.policy_type(),
                    "Relationship fact source returned the wrong number of results",
                )
            })
            .take(expected)
            .collect();
        }

        let fact_name = <RelationshipQuery<SubjectId, ResourceId, Relation> as FactKey>::NAME;
        queries
            .iter()
            .zip(loaded)
            .map(|(query, fact)| {
                let rendered_query = Self::render_key(query);
                self.result_from_fact(fact_name, &rendered_query, fact)
            })
            .collect()
    }

    /// Stable type label for this policy.
    fn policy_type(&self) -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::from("RebacPolicy")
    }
}
impl<S, R, A, C, SubjectId, ResourceId, Relation>
    RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
where
    SubjectId: fmt::Debug,
    ResourceId: fmt::Debug,
    Relation: fmt::Display,
{
    /// Policy-type label attached to every decision produced by `result_from_fact`.
    ///
    /// Defined once so the four decision branches cannot drift apart; keep it equal to the
    /// value returned by this type's `Policy::policy_type` implementation.
    const POLICY_TYPE: &'static str = "RebacPolicy";

    /// Renders a relationship query for provenance records, e.g. `1 -[owner]-> 2`.
    ///
    /// Identifiers are formatted with `Debug`, the relation with `Display`.
    fn render_key(key: &RelationshipQuery<SubjectId, ResourceId, Relation>) -> String {
        format!(
            "{:?} -[{}]-> {:?}",
            key.subject_id, key.relation, key.resource_id
        )
    }

    /// Converts one loaded relationship fact into a policy decision carrying provenance.
    ///
    /// Only `FactLoadResult::Found(true)` grants; `Found(false)`, `Missing` and `Error`
    /// all deny (fail closed). The single provenance entry records `fact_name`, `key_repr`,
    /// the load outcome, and — for load errors only — the error text.
    fn result_from_fact(
        &self,
        fact_name: &'static str,
        key_repr: &str,
        fact: FactLoadResult<bool>,
    ) -> PolicyEvalResult {
        let outcome = FactOutcome::from_load_result(&fact);
        // Only load errors contribute extra detail to the provenance record.
        let detail = match &fact {
            FactLoadResult::Error(error) => Some(error.to_string()),
            _ => None,
        };
        let provenance = vec![FactProvenance::new(fact_name, key_repr, outcome, detail)];
        match fact {
            FactLoadResult::Found(true) => PolicyEvalResult::granted_with_facts(
                Self::POLICY_TYPE,
                Some(format!(
                    "Subject has '{}' relationship with resource",
                    self.relation
                )),
                provenance,
            ),
            FactLoadResult::Found(false) => PolicyEvalResult::denied_with_facts(
                Self::POLICY_TYPE,
                format!(
                    "Subject does not have '{}' relationship with resource",
                    self.relation
                ),
                provenance,
            ),
            FactLoadResult::Missing => PolicyEvalResult::denied_with_facts(
                Self::POLICY_TYPE,
                format!("Relationship '{}' fact is missing", self.relation),
                provenance,
            ),
            FactLoadResult::Error(error) => PolicyEvalResult::denied_with_facts(
                Self::POLICY_TYPE,
                format!("Relationship '{}' fact load failed: {error}", self.relation),
                provenance,
            ),
        }
    }
}