use crate::{BatchEvalCtx, EvalCtx, FactLoadResult, Policy, PolicyEvalResult, RelationshipQuery};
use async_trait::async_trait;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
/// Relationship-based access control (ReBAC) policy.
///
/// Evaluation asks the session's relationship facts whether the evaluated subject holds
/// `relation` on the evaluated resource; only a positive fact grants access. The identifiers
/// used in each `RelationshipQuery` are produced by the two extractor closures given to `new`.
///
/// Type parameters: `S` is the subject type and `R` the resource type (read through
/// `ctx.subject` / `ctx.resource` during evaluation). `A` and `C` are not used by any field;
/// presumably they are the action and context types of `Policy<S, R, A, C>` — TODO confirm.
pub struct RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation> {
// Maps a subject to the value placed in `RelationshipQuery::subject_id`.
subject_id: Arc<dyn Fn(&S) -> SubjectId + Send + Sync>,
// Maps a resource to the value placed in `RelationshipQuery::resource_id`.
resource_id: Arc<dyn Fn(&R) -> ResourceId + Send + Sync>,
// The single relation every query checks; cloned into each query.
relation: Relation,
// `A` and `C` appear in no other field, so this marker is what keeps them as type
// parameters. As a side effect the struct is only `Send`/`Sync` when `A` and `C` are.
_marker: std::marker::PhantomData<(A, C)>,
}
impl<S, R, A, C, SubjectId, ResourceId, Relation>
    RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
{
    /// Creates a policy that requires the subject to hold `relation` on the resource.
    ///
    /// * `subject_id` — extracts the lookup identifier from the evaluated subject.
    /// * `resource_id` — extracts the lookup identifier from the evaluated resource.
    /// * `relation` — the relationship that every query checks.
    ///
    /// Both extractors are stored type-erased behind `Arc`, which is why they must be
    /// `Send + Sync + 'static`.
    pub fn new<SubjectFn, ResourceFn>(
        subject_id: SubjectFn,
        resource_id: ResourceFn,
        relation: Relation,
    ) -> Self
    where
        SubjectFn: Fn(&S) -> SubjectId + Send + Sync + 'static,
        ResourceFn: Fn(&R) -> ResourceId + Send + Sync + 'static,
    {
        // Boxed here; the unsized coercion to `Arc<dyn Fn…>` happens at the field sites below.
        let subject_extractor = Arc::new(subject_id);
        let resource_extractor = Arc::new(resource_id);
        Self {
            relation,
            subject_id: subject_extractor,
            resource_id: resource_extractor,
            _marker: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<S, R, A, C, SubjectId, ResourceId, Relation> Policy<S, R, A, C>
    for RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
where
    S: Sync + Send,
    R: Sync + Send,
    A: Sync + Send,
    C: Sync + Send,
    SubjectId: Eq + Hash + Clone + Send + Sync + 'static,
    ResourceId: Eq + Hash + Clone + Send + Sync + 'static,
    Relation: Eq + Hash + Clone + Send + Sync + fmt::Display + 'static,
{
    /// Looks up whether `ctx.subject` holds the configured relation on `ctx.resource`
    /// and converts the loaded fact into a decision via `result_from_fact`.
    async fn evaluate(&self, ctx: &EvalCtx<'_, S, R, A, C>) -> PolicyEvalResult {
        // Extractors run subject-first, then resource, then the relation is cloned.
        let subject_id = (self.subject_id)(ctx.subject);
        let resource_id = (self.resource_id)(ctx.resource);
        let query = RelationshipQuery {
            subject_id,
            resource_id,
            relation: self.relation.clone(),
        };
        let fact = ctx.session.get(query).await;
        self.result_from_fact(fact)
    }

    /// Evaluates one subject against many resources with a single `get_many` call.
    ///
    /// The subject identifier is extracted once and cloned into every query. If the fact
    /// source returns a different number of results than there are items, the results
    /// cannot be matched to items, so every item is denied (fail closed).
    async fn evaluate_batch<'item>(
        &self,
        ctx: &BatchEvalCtx<'item, S, R, A, C>,
    ) -> Vec<PolicyEvalResult> {
        let shared_subject = (self.subject_id)(ctx.subject);
        let mut queries = Vec::with_capacity(ctx.items.len());
        for item in ctx.items.iter() {
            queries.push(RelationshipQuery {
                subject_id: shared_subject.clone(),
                resource_id: (self.resource_id)(item.resource),
                relation: self.relation.clone(),
            });
        }

        let facts = ctx.session.get_many(&queries).await;

        if facts.len() == ctx.items.len() {
            // Results line up positionally with `ctx.items`.
            facts
                .into_iter()
                .map(|fact| self.result_from_fact(fact))
                .collect()
        } else {
            ctx.items
                .iter()
                .map(|_| PolicyEvalResult::Denied {
                    policy_type: self.policy_type().to_string(),
                    reason: "Relationship fact source returned the wrong number of results"
                        .to_string(),
                })
                .collect()
        }
    }

    /// Identifier reported in every decision this policy produces.
    fn policy_type(&self) -> &str {
        "RebacPolicy"
    }
}
impl<S, R, A, C, SubjectId, ResourceId, Relation>
RebacPolicy<S, R, A, C, SubjectId, ResourceId, Relation>
where
Relation: fmt::Display,
{
fn result_from_fact(&self, fact: FactLoadResult<bool>) -> PolicyEvalResult {
match fact {
FactLoadResult::Found(true) => PolicyEvalResult::Granted {
policy_type: "RebacPolicy".to_string(),
reason: Some(format!(
"Subject has '{}' relationship with resource",
self.relation
)),
},
FactLoadResult::Found(false) => PolicyEvalResult::Denied {
policy_type: "RebacPolicy".to_string(),
reason: format!(
"Subject does not have '{}' relationship with resource",
self.relation
),
},
FactLoadResult::Missing => PolicyEvalResult::Denied {
policy_type: "RebacPolicy".to_string(),
reason: format!("Relationship '{}' fact is missing", self.relation),
},
FactLoadResult::Error(error) => PolicyEvalResult::Denied {
policy_type: "RebacPolicy".to_string(),
reason: format!("Relationship '{}' fact load failed: {error}", self.relation),
},
}
}
}