use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use cedar_policy_core::ast::{
BinaryOp, EntityUID, Expr, ExprKind, Literal, PolicyID, PolicySet, RequestType, UnaryOp, Var,
};
use cedar_policy_core::entities::err::EntitiesError;
use cedar_policy_core::impl_diagnostic_from_source_loc_opt_field;
use cedar_policy_core::parser::Loc;
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use smol_str::SmolStr;
use thiserror::Error;
use crate::{
typecheck::{PolicyCheck, Typechecker},
types::{EntityRecordKind, Type},
ValidationMode, ValidatorSchema,
};
use crate::{ValidationResult, Validator};
/// Per-request-type description of the entity data that a policy set may access,
/// as computed by `compute_entity_manifest`.
///
/// `T` is extra data attached to every trie node; it is not serialized and is
/// filled with `T::default()` on deserialization.
#[doc = include_str!("../experimental_warning.md")]
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EntityManifest<T = ()>
where
T: Clone,
{
/// Maps each request type to the union of the access tries of all policies
/// that typecheck for it. Serialized as a list of `(key, value)` pairs
/// because `RequestType` keys are not strings.
#[serde_as(as = "Vec<(_, _)>")]
#[serde(bound(deserialize = "T: Default"))]
per_action: HashMap<RequestType, RootAccessTrie<T>>,
}
/// Child nodes of an [`AccessTrie`], keyed by attribute name.
#[doc = include_str!("../experimental_warning.md")]
pub type Fields<T> = HashMap<SmolStr, Box<AccessTrie<T>>>;
/// The starting point of an access path: either a literal entity UID written
/// in a policy, or a request variable (`principal`, `action`, `resource`, ...).
#[doc = include_str!("../experimental_warning.md")]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[serde(rename_all = "camelCase")]
pub enum EntityRoot {
/// A literal entity UID appearing in a policy.
Literal(EntityUID),
/// A request variable, or a template slot resolved to `principal`/`resource`.
Var(Var),
}
impl Display for EntityRoot {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
EntityRoot::Literal(l) => write!(f, "{l}"),
EntityRoot::Var(v) => write!(f, "{v}"),
}
}
}
/// A forest of [`AccessTrie`]s, one per root (request variable or literal
/// entity) from which the policies start an attribute access.
#[doc = include_str!("../experimental_warning.md")]
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootAccessTrie<T = ()>
where
T: Clone,
{
/// One trie per root. Serialized as a list of `(root, trie)` pairs because
/// `EntityRoot` keys are not strings.
#[serde_as(as = "Vec<(_, _)>")]
#[serde(bound(deserialize = "T: Default"))]
trie: HashMap<EntityRoot, AccessTrie<T>>,
}
/// A node in the access trie: the attributes accessed beneath this point, plus
/// whether the entity at this point needs its ancestors.
#[doc = include_str!("../experimental_warning.md")]
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccessTrie<T = ()> {
/// Accessed attributes, keyed by attribute name. Serialized as a list of
/// `(name, trie)` pairs.
#[serde_as(as = "Vec<(_, _)>")]
children: Fields<T>,
/// Whether the ancestors of the entity at this node are needed, e.g. for the
/// left-hand operand of an `in` expression. Merged with logical OR when two
/// tries are combined.
ancestors_required: bool,
/// Extra per-node data. Never serialized; deserialization fills it with
/// `T::default()`, hence the `T: Default` bound.
#[serde(skip_serializing, skip_deserializing)]
#[serde(bound(deserialize = "T: Default"))]
data: T,
}
/// A single flat access path: a root followed by a chain of attribute names,
/// e.g. `principal.manager.name`. Converted into a `RootAccessTrie` before
/// being merged into the manifest.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AccessPath {
/// Where the path starts.
pub root: EntityRoot,
/// Attribute names accessed in order, starting from `root`.
pub path: Vec<SmolStr>,
/// Set to `true` for the left-hand operand of an `in` expression.
pub ancestors_required: bool,
}
#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)]
#[error("for policy `{policy_id}`, failed to analyze expression while computing entity manifest`")]
pub struct FailedAnalysisError {
source_loc: Option<Loc>,
policy_id: PolicyID,
expr_kind: ExprKind<Option<Type>>,
}
impl Diagnostic for FailedAnalysisError {
    impl_diagnostic_from_source_loc_opt_field!(source_loc);

    /// Names the operator that blocked the analysis.
    fn help<'a>(&'a self) -> Option<Box<dyn Display + 'a>> {
        let message = format!(
            "failed to compute entity manifest: {} operators are not allowed before accessing record or entity attributes",
            self.expr_kind.operator_description()
        );
        Some(Box::new(message))
    }
}
/// Returned when a policy expression contains an unknown (`ExprKind::Unknown`),
/// i.e. the policy is not fully concrete.
#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)]
#[error("entity slicing requires fully concrete policies. Got a policy with an unknown expression")]
pub struct PartialExpressionError {}
impl Diagnostic for PartialExpressionError {}
/// Returned when a typechecker request environment cannot be converted into a
/// concrete `RequestType` (`to_request_type` returned `None`).
#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)]
#[error("entity slicing requires a fully concrete request. Got a partial request")]
pub struct PartialRequestError {}
impl Diagnostic for PartialRequestError {}
/// Errors that can arise while computing an entity manifest.
#[derive(Debug, Error)]
pub enum EntityManifestError {
/// Strict validation of the policy set failed; carries the full validation result.
#[error("a validation error occurred")]
Validation(ValidationResult),
/// Wraps an `EntitiesError`.
// NOTE(review): nothing in this section constructs this variant; presumably produced elsewhere — confirm.
#[error(transparent)]
Entities(#[from] EntitiesError),
/// A request environment was not fully concrete.
#[error(transparent)]
PartialRequest(#[from] PartialRequestError),
/// A policy contained an unknown expression.
#[error(transparent)]
PartialExpression(#[from] PartialExpressionError),
/// An expression could not be converted into an access path.
#[error(transparent)]
FailedAnalysis(#[from] FailedAnalysisError),
}
impl<T: Clone> EntityManifest<T> {
    /// Borrow the access trie computed for each request type.
    pub fn per_action(&self) -> &HashMap<RequestType, RootAccessTrie<T>> {
        let Self { per_action } = self;
        per_action
    }
}
/// Merge two child maps.
///
/// Attributes present in only one map are copied; attributes present in both
/// have their sub-tries merged with `AccessTrie::union`. Data for shared keys
/// comes from `first`.
fn union_fields<T: Clone>(first: &Fields<T>, second: &Fields<T>) -> Fields<T> {
    second
        .iter()
        .fold(first.clone(), |mut merged, (name, incoming)| {
            let combined = match merged.get(name) {
                Some(existing) => Box::new(existing.union(incoming)),
                None => incoming.clone(),
            };
            merged.insert(name.clone(), combined);
            merged
        })
}
impl AccessPath {
    /// Convert this path into a single-chain `RootAccessTrie` ending in an
    /// empty leaf.
    ///
    /// The leaf's `ancestors_required` flag comes from
    /// `self.ancestors_required`. Paths built for the left operand of `in`
    /// therefore request ancestors, while plain attribute/`has` paths do not.
    /// (Previously the flag was hard-coded to `true`, which ignored this field
    /// and made every `has` check request ancestors.)
    fn to_root_access_trie(&self) -> RootAccessTrie {
        self.to_root_access_trie_with_leaf(AccessTrie {
            ancestors_required: self.ancestors_required,
            children: Default::default(),
            data: (),
        })
    }

    /// Convert this path into a single-chain `RootAccessTrie` whose final node
    /// is `leaf_trie`.
    ///
    /// Every intermediate node has exactly one child (the next attribute on
    /// the path) and `ancestors_required: false`.
    fn to_root_access_trie_with_leaf(&self, leaf_trie: AccessTrie) -> RootAccessTrie {
        // Build the chain bottom-up: wrap the leaf once per path segment,
        // starting from the last attribute.
        let chain = self
            .path
            .iter()
            .rev()
            .fold(leaf_trie, |child, field| AccessTrie {
                ancestors_required: false,
                children: HashMap::from([(field.clone(), Box::new(child))]),
                data: (),
            });
        RootAccessTrie {
            trie: HashMap::from([(self.root.clone(), chain)]),
        }
    }
}
impl<T: Clone> RootAccessTrie<T> {
    /// Borrow the per-root tries.
    pub fn trie(&self) -> &HashMap<EntityRoot, AccessTrie<T>> {
        let Self { trie } = self;
        trie
    }
}
impl RootAccessTrie {
pub fn new() -> Self {
Self {
trie: Default::default(),
}
}
}
impl<T: Clone> RootAccessTrie<T> {
    /// Merge two root tries.
    ///
    /// Roots present in only one trie are copied as-is. Roots present in both
    /// have their tries merged with `AccessTrie::union`.
    fn union(&self, other: &Self) -> Self {
        let trie = other
            .trie
            .iter()
            .fold(self.trie.clone(), |mut merged, (root, incoming)| {
                let combined = match merged.get(root) {
                    Some(existing) => existing.union(incoming),
                    None => incoming.clone(),
                };
                merged.insert(root.clone(), combined);
                merged
            });
        Self { trie }
    }
}
impl Default for RootAccessTrie {
fn default() -> Self {
Self::new()
}
}
impl<T: Clone> AccessTrie<T> {
    /// Merge two tries node by node.
    ///
    /// Children are merged recursively, and `ancestors_required` is set if
    /// either side sets it. `data` is taken from `self`.
    fn union(&self, other: &Self) -> Self {
        let children = union_fields(&self.children, &other.children);
        let ancestors_required = self.ancestors_required || other.ancestors_required;
        let data = self.data.clone();
        Self {
            children,
            ancestors_required,
            data,
        }
    }

    /// The attributes accessed beneath this node, keyed by attribute name.
    pub fn children(&self) -> &Fields<T> {
        let Self { children, .. } = self;
        children
    }

    /// Whether the entity at this node needs its ancestors loaded.
    pub fn ancestors_required(&self) -> bool {
        let Self {
            ancestors_required, ..
        } = self;
        *ancestors_required
    }

    /// The extra data attached to this node.
    pub fn data(&self) -> &T {
        let Self { data, .. } = self;
        data
    }
}
impl AccessTrie {
fn new() -> Self {
Self {
children: Default::default(),
ancestors_required: false,
data: (),
}
}
}
/// Compute the entity manifest for `policies` under `schema`.
///
/// The policy set is first validated in strict mode. Each policy is then
/// typechecked once per request environment. The entity data accessed by the
/// typechecked expression is merged into the trie for that environment's
/// request type.
///
/// # Errors
///
/// - [`EntityManifestError::Validation`] if strict validation fails.
/// - [`EntityManifestError::PartialRequest`] if a request environment cannot
///   be converted into a concrete request type.
/// - [`EntityManifestError::PartialExpression`] or
///   [`EntityManifestError::FailedAnalysis`] if a policy expression cannot be
///   analyzed.
///
/// # Panics
///
/// Panics if typechecking a policy fails even though strict validation of the
/// whole policy set succeeded.
pub fn compute_entity_manifest(
    schema: &ValidatorSchema,
    policies: &PolicySet,
) -> Result<EntityManifest, EntityManifestError> {
    let validator = Validator::new(schema.clone());
    let validation_res = validator.validate(policies, ValidationMode::Strict);
    if !validation_res.validation_passed() {
        return Err(EntityManifestError::Validation(validation_res));
    }

    let mut per_action: HashMap<RequestType, RootAccessTrie> = HashMap::new();
    for policy in policies.policies() {
        let typechecker = Typechecker::new(schema, ValidationMode::Strict, policy.id().clone());
        let request_envs = typechecker.typecheck_by_request_env(policy.template());
        for (request_env, policy_check) in request_envs {
            let policy_trie = match policy_check {
                PolicyCheck::Success(typechecked_expr) => {
                    compute_root_trie(&typechecked_expr, policy.id())?
                }
                // An irrelevant check contributes an empty trie.
                PolicyCheck::Irrelevant(_, _) => RootAccessTrie::new(),
                // Validation above already succeeded, so a failure here is an internal bug.
                #[allow(clippy::panic)]
                PolicyCheck::Fail(_errors) => {
                    panic!("Policy check failed after validation succeeded")
                }
            };
            let request_type = request_env
                .to_request_type()
                .ok_or(PartialRequestError {})?;
            if let Some(existing) = per_action.get_mut(&request_type) {
                *existing = existing.union(&policy_trie);
            } else {
                per_action.insert(request_type, policy_trie);
            }
        }
    }
    Ok(EntityManifest { per_action })
}
/// Build the access trie for a single typechecked policy expression.
///
/// Analysis starts with `should_load_all = false`: only attributes that are
/// actually accessed are recorded.
fn compute_root_trie(
    expr: &Expr<Option<Type>>,
    policy_id: &PolicyID,
) -> Result<RootAccessTrie, EntityManifestError> {
    let mut root_trie = RootAccessTrie::new();
    add_to_root_trie(&mut root_trie, expr, policy_id, false).map(|()| root_trie)
}
/// Record in `root_trie` every entity-data access made by `expr`.
///
/// Walks `expr` recursively:
/// - Attribute accesses and `has` tests are converted to access paths with
///   `get_expr_path` and merged into `root_trie`.
/// - The left-hand operand of `in` is additionally marked as needing its
///   ancestors.
///
/// `should_load_all` is `true` inside operators that work on whole values
/// (`==`, `contains`, `containsAll`, `containsAny`). There, an attribute
/// access is expanded with the record structure described by the
/// expression's type (see `type_to_access_trie`) instead of recording just
/// the attribute itself.
///
/// # Errors
///
/// - `PartialExpressionError` if `expr` contains an unknown.
/// - `FailedAnalysisError` (from `get_expr_path`) if an accessed or tested
///   expression is not a plain path from a variable, slot, or entity literal.
///
/// # Panics
///
/// Panics (`unimplemented!`) on `BinaryOp::GetTag` / `BinaryOp::HasTag`.
fn add_to_root_trie(
    root_trie: &mut RootAccessTrie,
    expr: &Expr<Option<Type>>,
    policy_id: &PolicyID,
    should_load_all: bool,
) -> Result<(), EntityManifestError> {
    match expr.expr_kind() {
        // Leaves that touch no entity data on their own.
        ExprKind::Lit(_) | ExprKind::Var(_) | ExprKind::Slot(_) => Ok(()),
        ExprKind::Unknown(_) => Err(PartialExpressionError {}.into()),
        ExprKind::If {
            test_expr,
            then_expr,
            else_expr,
        } => {
            for branch in [test_expr, then_expr, else_expr] {
                add_to_root_trie(root_trie, branch, policy_id, should_load_all)?;
            }
            Ok(())
        }
        ExprKind::And { left, right } | ExprKind::Or { left, right } => {
            add_to_root_trie(root_trie, left, policy_id, should_load_all)?;
            add_to_root_trie(root_trie, right, policy_id, should_load_all)
        }
        ExprKind::UnaryApp { op, arg } => match op {
            UnaryOp::Not | UnaryOp::Neg => {
                add_to_root_trie(root_trie, arg, policy_id, should_load_all)
            }
        },
        ExprKind::BinaryApp { op, arg1, arg2 } => match op {
            // Operators that inspect whole values: both operands are fully loaded.
            BinaryOp::Eq | BinaryOp::Contains | BinaryOp::ContainsAll | BinaryOp::ContainsAny => {
                add_to_root_trie(root_trie, arg1, policy_id, true)?;
                add_to_root_trie(root_trie, arg2, policy_id, true)
            }
            BinaryOp::Less | BinaryOp::LessEq | BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => {
                add_to_root_trie(root_trie, arg1, policy_id, should_load_all)?;
                add_to_root_trie(root_trie, arg2, policy_id, should_load_all)
            }
            BinaryOp::In => {
                add_to_root_trie(root_trie, arg2, policy_id, should_load_all)?;
                // Evaluating `in` needs the ancestors of the left-hand entity.
                let mut member = get_expr_path(arg1, policy_id)?;
                member.ancestors_required = true;
                *root_trie = root_trie.union(&member.to_root_access_trie());
                Ok(())
            }
            BinaryOp::GetTag | BinaryOp::HasTag => {
                unimplemented!("interaction between RFCs 74 and 82")
            }
        },
        ExprKind::ExtensionFunctionApp { args, .. } => args
            .iter()
            .try_for_each(|arg| add_to_root_trie(root_trie, arg, policy_id, should_load_all)),
        ExprKind::Like { expr: operand, .. } | ExprKind::Is { expr: operand, .. } => {
            add_to_root_trie(root_trie, operand, policy_id, should_load_all)
        }
        ExprKind::Set(elements) => elements.iter().try_for_each(|element| {
            add_to_root_trie(root_trie, element, policy_id, should_load_all)
        }),
        ExprKind::Record(entries) => entries.values().try_for_each(|value| {
            add_to_root_trie(root_trie, value, policy_id, should_load_all)
        }),
        ExprKind::HasAttr { expr: object, attr } => {
            let mut tested = get_expr_path(object, policy_id)?;
            tested.path.push(attr.clone());
            *root_trie = root_trie.union(&tested.to_root_access_trie());
            Ok(())
        }
        ExprKind::GetAttr { .. } => {
            // `get_expr_path` understands nested attribute accesses, so it gets
            // the whole `GetAttr` expression rather than its object.
            let accessed = get_expr_path(expr, policy_id)?;
            // The typechecker is expected to annotate every expression with a type.
            #[allow(clippy::expect_used)]
            let leaf = if should_load_all {
                type_to_access_trie(
                    expr.data()
                        .as_ref()
                        .expect("Typechecked expression missing type"),
                )
            } else {
                AccessTrie::new()
            };
            *root_trie = root_trie.union(&accessed.to_root_access_trie_with_leaf(leaf));
            Ok(())
        }
    }
}
/// Build the trie that fully covers a value of type `ty`.
///
/// Only entity/record types can have children (see
/// `entity_or_record_to_access_trie`). Every other type yields a childless leaf.
fn type_to_access_trie(ty: &Type) -> AccessTrie {
    match ty {
        Type::EntityOrRecord(kind) => entity_or_record_to_access_trie(kind),
        Type::Never
        | Type::True
        | Type::False
        | Type::Primitive { .. }
        | Type::Set { .. }
        | Type::ExtensionType { .. } => AccessTrie::new(),
    }
}
/// Build the trie that fully covers a value of the given entity/record kind.
///
/// Records and action entities get one child per declared attribute, each
/// expanded recursively by its attribute type. Other entities are not
/// expanded and yield a childless leaf.
fn entity_or_record_to_access_trie(ty: &EntityRecordKind) -> AccessTrie {
    match ty {
        EntityRecordKind::Entity(_) | EntityRecordKind::AnyEntity => AccessTrie::new(),
        EntityRecordKind::ActionEntity { attrs, .. } | EntityRecordKind::Record { attrs, .. } => {
            AccessTrie {
                children: attrs
                    .iter()
                    .map(|(attr_name, attr_type)| {
                        (
                            attr_name.clone(),
                            Box::new(type_to_access_trie(&attr_type.attr_type)),
                        )
                    })
                    .collect(),
                ancestors_required: false,
                data: (),
            }
        }
    }
}
/// Convert an expression into a flat access path.
///
/// Accepts a request variable, a template slot (mapped to `principal` or
/// `resource`), an entity-UID literal, or a chain of attribute accesses on
/// one of those.
///
/// # Errors
///
/// - `PartialExpressionError` for an unknown.
/// - `FailedAnalysisError` for any other expression kind.
///
/// # Panics
///
/// Panics if a slot is neither the principal nor the resource slot.
fn get_expr_path(
    expr: &Expr<Option<Type>>,
    policy_id: &PolicyID,
) -> Result<AccessPath, EntityManifestError> {
    let root = match expr.expr_kind() {
        ExprKind::GetAttr { expr: object, attr } => {
            let mut access = get_expr_path(object, policy_id)?;
            access.path.push(attr.clone());
            return Ok(access);
        }
        ExprKind::Var(var) => EntityRoot::Var(*var),
        ExprKind::Slot(slot_id) if slot_id.is_principal() => EntityRoot::Var(Var::Principal),
        ExprKind::Slot(slot_id) => {
            assert!(slot_id.is_resource());
            EntityRoot::Var(Var::Resource)
        }
        ExprKind::Lit(Literal::EntityUID(literal)) => EntityRoot::Literal((**literal).clone()),
        ExprKind::Unknown(_) => return Err(PartialExpressionError {}.into()),
        _ => {
            return Err(EntityManifestError::FailedAnalysis(FailedAnalysisError {
                source_loc: expr.source_loc().cloned(),
                policy_id: policy_id.clone(),
                expr_kind: expr.expr_kind().clone(),
            }))
        }
    };
    Ok(AccessPath {
        root,
        path: vec![],
        ancestors_required: false,
    })
}
// Unit tests for entity manifest computation. Each test parses one or more
// policies, builds a schema, computes the manifest, and compares it with an
// expected manifest deserialized from JSON.
#[cfg(test)]
mod entity_slice_tests {
use cedar_policy_core::{ast::PolicyID, extensions::Extensions, parser::parse_policy};
use super::*;
// Shared schema: `User` with a `name` attribute, attribute-less `Document`,
// and a `Read` action from `User` to `Document`.
fn schema() -> ValidatorSchema {
ValidatorSchema::from_cedarschema_str(
"
entity User = {
name: String,
};
entity Document;
action Read appliesTo {
principal: [User],
resource: [Document]
};
",
Extensions::all_available(),
)
.unwrap()
.0
}
// A single attribute compared with `==` appears as one leaf under `principal`.
#[test]
fn test_simple_entity_manifest() {
let mut pset = PolicySet::new();
let policy = parse_policy(
None,
"permit(principal, action, resource)
when {
principal.name == \"John\"
};",
)
.expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = schema();
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json! ({
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "principal"
},
{
"children": [
[
"name",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// A policy with no conditions requires no entity data: the trie is empty.
#[test]
fn test_empty_entity_manifest() {
let mut pset = PolicySet::new();
let policy =
parse_policy(None, "permit(principal, action, resource);").expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = schema();
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// Left-hand operands of `in` (here `principal` and `principal.manager`) are
// marked as requiring ancestors.
#[test]
fn test_entity_manifest_ancestors_required() {
let mut pset = PolicySet::new();
let policy = parse_policy(
None,
"permit(principal, action, resource)
when {
principal in resource || principal.manager in resource
};",
)
.expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = ValidatorSchema::from_cedarschema_str(
"
entity User in [Document] = {
name: String,
manager: User
};
entity Document;
action Read appliesTo {
principal: [User],
resource: [Document]
};
",
Extensions::all_available(),
)
.unwrap()
.0;
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "principal"
},
{
"children": [
[
"manager",
{
"children": [],
"ancestorsRequired": true
}
]
],
"ancestorsRequired": true
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// With two principal types, each request type gets its own entry, and only the
// accessed attribute (`name`, not `irrelevant`) is recorded.
#[test]
fn test_entity_manifest_multiple_types() {
let mut pset = PolicySet::new();
let policy = parse_policy(
None,
"permit(principal, action, resource)
when {
principal.name == \"John\"
};",
)
.expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = ValidatorSchema::from_cedarschema_str(
"
entity User = {
name: String,
};
entity OtherUserType = {
name: String,
irrelevant: String,
};
entity Document;
action Read appliesTo {
principal: [User, OtherUserType],
resource: [Document]
};
",
Extensions::all_available(),
)
.unwrap()
.0;
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "principal"
},
{
"children": [
[
"name",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
],
[
{
"principal": "OtherUserType",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "principal"
},
{
"children": [
[
"name",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// Accesses from two policies for the same request type are merged into one trie.
#[test]
fn test_entity_manifest_multiple_branches() {
let mut pset = PolicySet::new();
let policy1 = parse_policy(
None,
r#"
permit(
principal,
action == Action::"Read",
resource
)
when
{
resource.readers.contains(principal)
};"#,
)
.unwrap();
let policy2 = parse_policy(
Some(PolicyID::from_string("Policy2")),
r#"permit(
principal,
action == Action::"Read",
resource
)
when
{
resource.metadata.owner == principal
};"#,
)
.unwrap();
pset.add(policy1.into()).expect("should succeed");
pset.add(policy2.into()).expect("should succeed");
let schema = ValidatorSchema::from_cedarschema_str(
"
entity User;
entity Metadata = {
owner: User,
time: String,
};
entity Document = {
metadata: Metadata,
readers: Set<User>,
};
action Read appliesTo {
principal: [User],
resource: [Document]
};
",
Extensions::all_available(),
)
.unwrap()
.0;
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Read"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "resource"
},
{
"children": [
[
"metadata",
{
"children": [
[
"owner",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
],
[
"readers",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// Comparing a whole record with `==` loads every field of that record.
#[test]
fn test_entity_manifest_struct_equality() {
let mut pset = PolicySet::new();
let policy = parse_policy(
None,
r#"permit(principal, action, resource)
when {
principal.metadata.nickname == "timmy" && principal.metadata == {
"friends": [ "oliver" ],
"nickname": "timmy"
}
};"#,
)
.expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = ValidatorSchema::from_cedarschema_str(
"
entity User = {
name: String,
metadata: {
friends: Set<String>,
nickname: String,
},
};
entity Document;
action BeSad appliesTo {
principal: [User],
resource: [Document]
};
",
Extensions::all_available(),
)
.unwrap()
.0;
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "BeSad"
},
"resource": "Document"
},
{
"trie": [
[
{
"var": "principal"
},
{
"children": [
[
"metadata",
{
"children": [
[
"nickname",
{
"children": [],
"ancestorsRequired": false
}
],
[
"friends",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
// When both sides of `==` are record-valued attributes, both are fully loaded.
#[test]
fn test_entity_manifest_struct_equality_left_right_different() {
let mut pset = PolicySet::new();
let policy = parse_policy(
None,
r#"permit(principal, action, resource)
when {
principal.metadata == resource.metadata
};"#,
)
.expect("should succeed");
pset.add(policy.into()).expect("should succeed");
let schema = ValidatorSchema::from_cedarschema_str(
"
entity User = {
name: String,
metadata: {
friends: Set<String>,
nickname: String,
},
};
entity Document;
action Hello appliesTo {
principal: [User],
resource: [User]
};
",
Extensions::all_available(),
)
.unwrap()
.0;
let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed");
let expected = serde_json::json!(
{
"perAction": [
[
{
"principal": "User",
"action": {
"ty": "Action",
"eid": "Hello"
},
"resource": "User"
},
{
"trie": [
[
{
"var": "resource"
},
{
"children": [
[
"metadata",
{
"children": [
[
"friends",
{
"children": [],
"ancestorsRequired": false
}
],
[
"nickname",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
],
[
{
"var": "principal"
},
{
"children": [
[
"metadata",
{
"children": [
[
"nickname",
{
"children": [],
"ancestorsRequired": false
}
],
[
"friends",
{
"children": [],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
],
"ancestorsRequired": false
}
]
]
}
]
]
});
let expected_manifest = serde_json::from_value(expected).unwrap();
assert_eq!(entity_manifest, expected_manifest);
}
}