use cedar_policy_core::{
ast::{EntityType, EntityUID, PartialValueSerializedAsExpr},
transitive_closure::TCNode,
};
use serde::Serialize;
use smol_str::SmolStr;
use std::collections::{BTreeMap, HashSet};
use crate::types::{Attributes, Type};
/// Validator-side description of a single action: its identity, the
/// principal/resource entity types it applies to, the actions recorded as its
/// descendants, its context type, and its attributes.
#[derive(Clone, Debug, Serialize)]
pub struct ValidatorActionId {
    /// The action's entity UID. This is also the node key the action reports
    /// through its `TCNode` implementation.
    pub(crate) name: EntityUID,
    /// Entity types this action may be applied to as principal and resource.
    /// Serialized under the key `appliesTo`; every other field is serialized
    /// under its Rust field name (there is no `rename_all`).
    #[serde(rename = "appliesTo")]
    pub(crate) applies_to: ValidatorApplySpec,
    /// UIDs of actions recorded as descendants of this one; these are the
    /// out-edges exposed through `TCNode`.
    // NOTE(review): presumably seeded with direct children and then closed
    // transitively before validation runs — confirm with the schema builder.
    pub(crate) descendants: HashSet<EntityUID>,
    /// Declared context type for this action, handed out (cloned) by
    /// `context_type()`.
    // NOTE(review): presumably always a record type (the test fixture uses
    // `Type::any_record()`) — confirm.
    pub(crate) context: Type,
    /// Declared types of this action's attributes.
    // NOTE(review): presumably consumed by the typechecker — confirm callers.
    pub(crate) attribute_types: Attributes,
    /// Attribute values of this action, keyed by attribute name. Values are
    /// held as `PartialValueSerializedAsExpr`, which (per its name) serializes
    /// them in expression form.
    pub(crate) attributes: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
}
impl ValidatorActionId {
    /// Returns an iterator over the entity types this action may be applied to
    /// as a principal.
    ///
    /// The types come from a `HashSet`, so each appears once and the iteration
    /// order is unspecified.
    pub fn principals(&self) -> impl Iterator<Item = &EntityType> {
        // Go through the apply spec's public accessor rather than its private
        // field, so there is a single source of truth for this iteration.
        self.applies_to.applicable_principal_types()
    }

    /// Returns an iterator over the entity types this action may be applied to
    /// as a resource.
    ///
    /// The types come from a `HashSet`, so each appears once and the iteration
    /// order is unspecified.
    pub fn resources(&self) -> impl Iterator<Item = &EntityType> {
        self.applies_to.applicable_resource_types()
    }

    /// Returns an owned copy of the context type declared for this action.
    pub fn context_type(&self) -> Type {
        self.context.clone()
    }

    /// Alias of [`Self::principals`], kept for existing callers.
    pub fn applies_to_principals(&self) -> impl Iterator<Item = &EntityType> {
        // Delegate instead of duplicating the body so the alias cannot drift.
        self.principals()
    }

    /// Alias of [`Self::resources`], kept for existing callers.
    pub fn applies_to_resources(&self) -> impl Iterator<Item = &EntityType> {
        self.resources()
    }
}
impl TCNode<EntityUID> for ValidatorActionId {
fn get_key(&self) -> EntityUID {
self.name.clone()
}
fn add_edge_to(&mut self, k: EntityUID) {
self.descendants.insert(k);
}
fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
Box::new(self.descendants.iter())
}
fn has_edge_to(&self, e: &EntityUID) -> bool {
self.descendants.contains(e)
}
}
/// The entity types an action may be applied to: one set for principals and
/// one for resources. Both are `HashSet`s, so duplicates collapse and the
/// iteration order is unspecified.
#[derive(Clone, Debug, Serialize)]
pub(crate) struct ValidatorApplySpec {
    /// Entity types allowed as the principal; serialized as `principalApplySpec`.
    #[serde(rename = "principalApplySpec")]
    principal_apply_spec: HashSet<EntityType>,
    /// Entity types allowed as the resource; serialized as `resourceApplySpec`.
    #[serde(rename = "resourceApplySpec")]
    resource_apply_spec: HashSet<EntityType>,
}
impl ValidatorApplySpec {
    /// Creates an apply spec from the principal and resource entity-type sets.
    pub fn new(principals: HashSet<EntityType>, resources: HashSet<EntityType>) -> Self {
        Self {
            principal_apply_spec: principals,
            resource_apply_spec: resources,
        }
    }

    /// Iterates over the entity types permitted as principal (unspecified order).
    pub fn applicable_principal_types(&self) -> impl Iterator<Item = &EntityType> {
        HashSet::iter(&self.principal_apply_spec)
    }

    /// Returns `true` when `ty` is one of the permitted principal types.
    pub fn is_applicable_principal_type(&self, ty: &EntityType) -> bool {
        HashSet::contains(&self.principal_apply_spec, ty)
    }

    /// Iterates over the entity types permitted as resource (unspecified order).
    pub fn applicable_resource_types(&self) -> impl Iterator<Item = &EntityType> {
        HashSet::iter(&self.resource_apply_spec)
    }

    /// Returns `true` when `ty` is one of the permitted resource types.
    pub fn is_applicable_resource_type(&self, ty: &EntityType) -> bool {
        HashSet::contains(&self.resource_apply_spec, ty)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Builds `Action::"foo"`, applicable to principals of type `User` and
    /// resources of type `App` or `File`, with no descendants or attributes.
    fn make_action() -> ValidatorActionId {
        ValidatorActionId {
            name: r#"Action::"foo""#.parse().unwrap(),
            applies_to: ValidatorApplySpec {
                principal_apply_spec: HashSet::from([
                    // `User` is listed twice on purpose: the set must collapse
                    // the duplicate, which `test_principals` verifies.
                    EntityType::Specified("User".parse().unwrap()),
                    EntityType::Specified("User".parse().unwrap()),
                ]),
                resource_apply_spec: HashSet::from([
                    EntityType::Specified("App".parse().unwrap()),
                    EntityType::Specified("File".parse().unwrap()),
                ]),
            },
            descendants: HashSet::new(),
            context: Type::any_record(),
            attribute_types: Attributes::default(),
            attributes: BTreeMap::default(),
        }
    }

    #[test]
    fn test_resources() {
        let a = make_action();
        let got = a.resources().cloned().collect::<HashSet<EntityType>>();
        let expected = HashSet::from([
            EntityType::Specified("App".parse().unwrap()),
            EntityType::Specified("File".parse().unwrap()),
        ]);
        assert_eq!(got, expected);
    }

    #[test]
    fn test_principals() {
        let a = make_action();
        // Collect into a `Vec` (not a set) so a surviving duplicate would show up.
        let got = a.principals().cloned().collect::<Vec<EntityType>>();
        let expected: [EntityType; 1] = [EntityType::Specified("User".parse().unwrap())];
        assert_eq!(got, &expected);
    }

    #[test]
    fn test_alias_accessors_match_primary_accessors() {
        let a = make_action();
        let principals = a.principals().collect::<HashSet<_>>();
        let alias_principals = a.applies_to_principals().collect::<HashSet<_>>();
        assert_eq!(principals, alias_principals);
        let resources = a.resources().collect::<HashSet<_>>();
        let alias_resources = a.applies_to_resources().collect::<HashSet<_>>();
        assert_eq!(resources, alias_resources);
    }

    #[test]
    fn test_apply_spec_membership() {
        let a = make_action();
        let user = EntityType::Specified("User".parse().unwrap());
        let app = EntityType::Specified("App".parse().unwrap());
        assert!(a.applies_to.is_applicable_principal_type(&user));
        assert!(!a.applies_to.is_applicable_principal_type(&app));
        assert!(a.applies_to.is_applicable_resource_type(&app));
        assert!(!a.applies_to.is_applicable_resource_type(&user));
    }

    #[test]
    fn test_tc_node_edges() {
        let mut a = make_action();
        let foo: EntityUID = r#"Action::"foo""#.parse().unwrap();
        let bar: EntityUID = r#"Action::"bar""#.parse().unwrap();
        assert_eq!(a.get_key(), foo);
        assert!(!a.has_edge_to(&bar));
        // Adding the same edge twice must not produce a duplicate out-edge.
        a.add_edge_to(bar.clone());
        a.add_edge_to(bar.clone());
        assert!(a.has_edge_to(&bar));
        assert_eq!(a.out_edges().collect::<Vec<_>>(), vec![&bar]);
    }
}