use std::collections::HashSet;
use std::{collections::BTreeMap, sync::Arc};
use crate::ast::{
Annotations, Effect, EntityUID, Literal, Policy, PolicyID, UnwrapInfallible, ValueKind,
};
#[cfg(feature = "tolerant-ast")]
use crate::tpe::err::ErrorNotSupportedError;
use crate::tpe::err::{
ExprToResidualError, MissingTypeAnnotationError, SlotNotSupportedError,
UnknownNotSupportedError,
};
use crate::validator::types::Type;
use crate::{
ast::{self, BinaryOp, EntityType, Expr, Name, Pattern, UnaryOp, Value, Var},
expr_builder::ExprBuilder,
};
use smol_str::SmolStr;
/// A typed residual expression: what remains of a type-annotated expression,
/// either still containing unevaluated structure (`Partial`), fully reduced to a
/// value (`Concrete`), or standing for an error (`Error`).
///
/// Every variant carries the [`Type`] of the expression it represents.
#[derive(Debug, Clone)]
pub enum Residual {
/// An expression that still has structure left to evaluate.
Partial {
/// The top-level node of the remaining expression; its children are residuals.
kind: ResidualKind,
/// The type of this expression.
ty: Type,
},
/// A fully-evaluated value (e.g. a literal) together with its type.
Concrete {
/// The value this residual reduced to.
value: Value,
/// The type of this value.
ty: Type,
},
/// A residual standing for an error; only the expression's type is kept.
/// When lowered back to an `Expr` it becomes a call to the `error` extension function.
Error(Type),
}
impl Residual {
pub fn to_policy(self, id: PolicyID, effect: Effect, annotations: Annotations) -> Policy {
Policy::from_when_clause_annos(
effect,
Arc::new(self.into()),
id,
None,
Arc::new(annotations),
)
}
pub fn all_literal_uids(&self) -> HashSet<EntityUID> {
match self {
Residual::Partial { kind, .. } => kind.all_literal_uids(),
Residual::Concrete { value, .. } => value.all_literal_uids(),
Residual::Error(_) => HashSet::new(),
}
}
pub fn ty(&self) -> &Type {
match self {
Residual::Partial { ty, .. } => ty,
Residual::Concrete { ty, .. } => ty,
Residual::Error(ty) => ty,
}
}
pub fn can_error_assuming_well_formed(&self) -> bool {
match self {
Residual::Concrete { .. } => false,
Residual::Error(_) => true,
Residual::Partial { kind, .. } => match kind {
ResidualKind::Var(_) => false,
ResidualKind::And { left, right } => {
left.can_error_assuming_well_formed() || right.can_error_assuming_well_formed()
}
ResidualKind::Or { left, right } => {
left.can_error_assuming_well_formed() || right.can_error_assuming_well_formed()
}
ResidualKind::If {
test_expr,
then_expr,
else_expr,
} => {
test_expr.can_error_assuming_well_formed()
|| then_expr.can_error_assuming_well_formed()
|| else_expr.can_error_assuming_well_formed()
}
ResidualKind::Is { expr, .. } => expr.can_error_assuming_well_formed(),
ResidualKind::Like { expr, .. } => expr.can_error_assuming_well_formed(),
ResidualKind::BinaryApp { op, arg1, arg2 } => match op {
ast::BinaryOp::Add => true,
ast::BinaryOp::Mul => true,
ast::BinaryOp::Sub => true,
ast::BinaryOp::GetTag => true,
ast::BinaryOp::Contains
| ast::BinaryOp::ContainsAll
| ast::BinaryOp::ContainsAny
| ast::BinaryOp::Eq
| ast::BinaryOp::HasTag
| ast::BinaryOp::In
| ast::BinaryOp::Less
| ast::BinaryOp::LessEq => {
arg1.can_error_assuming_well_formed()
|| arg2.can_error_assuming_well_formed()
}
},
ResidualKind::ExtensionFunctionApp { .. } => true,
ResidualKind::GetAttr { .. } => true,
ResidualKind::HasAttr { expr, .. } => expr.can_error_assuming_well_formed(),
ResidualKind::UnaryApp { op, arg } => match op {
ast::UnaryOp::Neg => true,
ast::UnaryOp::IsEmpty | ast::UnaryOp::Not => {
arg.can_error_assuming_well_formed()
}
},
ResidualKind::Set(items) => items.iter().any(Self::can_error_assuming_well_formed),
ResidualKind::Record(attrs) => attrs
.iter()
.any(|(_, e)| e.can_error_assuming_well_formed()),
},
}
}
}
impl TryFrom<Residual> for Value {
    type Error = ();

    /// Extracts the value held by a [`Residual::Concrete`].
    ///
    /// Partial and error residuals have no value and yield `Err(())`.
    fn try_from(value: Residual) -> std::result::Result<Self, Self::Error> {
        let Residual::Concrete { value, .. } = value else {
            return Err(());
        };
        Ok(value)
    }
}
impl TryFrom<&Expr<Option<Type>>> for Residual {
    type Error = ExprToResidualError;

    /// Converts a type-annotated expression into a residual.
    ///
    /// Literals become [`Residual::Concrete`]. Every other supported node becomes a
    /// [`Residual::Partial`] whose children are converted recursively, in source order.
    ///
    /// # Errors
    /// - `MissingTypeAnnotationError` if any visited node has no type annotation.
    /// - `SlotNotSupportedError` for template slots.
    /// - `UnknownNotSupportedError` for unknowns.
    /// - With the `tolerant-ast` feature, `ErrorNotSupportedError` for error nodes.
    fn try_from(expr: &Expr<Option<Type>>) -> std::result::Result<Self, ExprToResidualError> {
        /// Converts one child expression and places the result behind an `Arc`.
        fn child(
            e: &Expr<Option<Type>>,
        ) -> std::result::Result<Arc<Residual>, ExprToResidualError> {
            Residual::try_from(e).map(Arc::new)
        }

        /// Converts a sequence of child expressions, stopping at the first failure.
        fn children<'a>(
            es: impl Iterator<Item = &'a Expr<Option<Type>>>,
        ) -> std::result::Result<Arc<Vec<Residual>>, ExprToResidualError> {
            es.map(Residual::try_from)
                .collect::<std::result::Result<Vec<_>, _>>()
                .map(Arc::new)
        }

        let ty = expr.data().clone().ok_or(MissingTypeAnnotationError)?;
        let kind = match expr.expr_kind() {
            // A literal is already a value: nothing is left to evaluate.
            ast::ExprKind::Lit(lit) => {
                return Ok(Residual::Concrete {
                    value: Value::new(lit.clone(), None),
                    ty,
                });
            }
            ast::ExprKind::Var(var) => ResidualKind::Var(*var),
            ast::ExprKind::If {
                test_expr,
                then_expr,
                else_expr,
            } => ResidualKind::If {
                test_expr: child(test_expr)?,
                then_expr: child(then_expr)?,
                else_expr: child(else_expr)?,
            },
            ast::ExprKind::And { left, right } => ResidualKind::And {
                left: child(left)?,
                right: child(right)?,
            },
            ast::ExprKind::Or { left, right } => ResidualKind::Or {
                left: child(left)?,
                right: child(right)?,
            },
            ast::ExprKind::UnaryApp { op, arg } => ResidualKind::UnaryApp {
                op: *op,
                arg: child(arg)?,
            },
            ast::ExprKind::BinaryApp { op, arg1, arg2 } => ResidualKind::BinaryApp {
                op: *op,
                arg1: child(arg1)?,
                arg2: child(arg2)?,
            },
            ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
                let args = children(args.iter())?;
                ResidualKind::ExtensionFunctionApp {
                    fn_name: fn_name.clone(),
                    args,
                }
            }
            ast::ExprKind::GetAttr { expr, attr } => ResidualKind::GetAttr {
                expr: child(expr)?,
                attr: attr.clone(),
            },
            ast::ExprKind::HasAttr { expr, attr } => ResidualKind::HasAttr {
                expr: child(expr)?,
                attr: attr.clone(),
            },
            ast::ExprKind::Like { expr, pattern } => ResidualKind::Like {
                expr: child(expr)?,
                pattern: pattern.clone(),
            },
            ast::ExprKind::Is { expr, entity_type } => ResidualKind::Is {
                expr: child(expr)?,
                entity_type: entity_type.clone(),
            },
            ast::ExprKind::Set(elements) => ResidualKind::Set(children(elements.iter())?),
            ast::ExprKind::Record(fields) => {
                let converted = fields
                    .iter()
                    .map(|(name, field)| Ok((name.clone(), Residual::try_from(field)?)))
                    .collect::<std::result::Result<BTreeMap<_, _>, ExprToResidualError>>()?;
                ResidualKind::Record(Arc::new(converted))
            }
            ast::ExprKind::Slot(_) => return Err(SlotNotSupportedError.into()),
            ast::ExprKind::Unknown(_) => return Err(UnknownNotSupportedError.into()),
            #[cfg(feature = "tolerant-ast")]
            ast::ExprKind::Error { .. } => {
                return Err(ErrorNotSupportedError.into());
            }
        };
        Ok(Residual::Partial { kind, ty })
    }
}
/// The node kinds a [`Residual::Partial`] can have.
///
/// These mirror the corresponding `ast::ExprKind` variants, but every child is itself a
/// [`Residual`] (and so carries a type). There are no literal, slot, or unknown nodes.
/// Literals are represented as [`Residual::Concrete`]. Slots and unknowns are rejected
/// when converting from an `Expr`.
#[derive(Debug, Clone)]
pub enum ResidualKind {
/// A request variable.
Var(Var),
/// `if test_expr then then_expr else else_expr`.
If {
test_expr: Arc<Residual>,
then_expr: Arc<Residual>,
else_expr: Arc<Residual>,
},
/// `left && right`.
And {
left: Arc<Residual>,
right: Arc<Residual>,
},
/// `left || right`.
Or {
left: Arc<Residual>,
right: Arc<Residual>,
},
/// A unary operator `op` applied to `arg`.
UnaryApp {
op: UnaryOp,
arg: Arc<Residual>,
},
/// A binary operator `op` applied to `arg1` and `arg2`.
BinaryApp {
op: BinaryOp,
arg1: Arc<Residual>,
arg2: Arc<Residual>,
},
/// A call of the extension function `fn_name` with the given arguments.
ExtensionFunctionApp {
fn_name: Name,
args: Arc<Vec<Residual>>,
},
/// Attribute access `expr.attr`.
GetAttr {
expr: Arc<Residual>,
attr: SmolStr,
},
/// Attribute test `expr has attr`.
HasAttr {
expr: Arc<Residual>,
attr: SmolStr,
},
/// Pattern match `expr like pattern`.
Like {
expr: Arc<Residual>,
pattern: Pattern,
},
/// Entity-type test `expr is entity_type`.
Is {
expr: Arc<Residual>,
entity_type: EntityType,
},
/// A set literal whose elements are residuals.
Set(Arc<Vec<Residual>>),
/// A record literal whose field values are residuals, keyed by field name.
Record(Arc<BTreeMap<SmolStr, Residual>>),
}
impl ResidualKind {
pub fn all_literal_uids(&self) -> HashSet<EntityUID> {
match self {
ResidualKind::Var(_) => HashSet::new(),
ResidualKind::If {
test_expr,
then_expr,
else_expr,
} => {
let mut uids = test_expr.all_literal_uids();
uids.extend(then_expr.all_literal_uids());
uids.extend(else_expr.all_literal_uids());
uids
}
ResidualKind::And { left, right } | ResidualKind::Or { left, right } => {
let mut uids = left.all_literal_uids();
uids.extend(right.all_literal_uids());
uids
}
ResidualKind::UnaryApp { arg, .. } => arg.all_literal_uids(),
ResidualKind::BinaryApp { arg1, arg2, .. } => {
let mut uids = arg1.all_literal_uids();
uids.extend(arg2.all_literal_uids());
uids
}
ResidualKind::ExtensionFunctionApp { args, .. } => {
let mut uids = HashSet::new();
for arg in args.as_ref() {
uids.extend(arg.all_literal_uids());
}
uids
}
ResidualKind::GetAttr { expr, .. }
| ResidualKind::HasAttr { expr, .. }
| ResidualKind::Like { expr, .. }
| ResidualKind::Is { expr, .. } => expr.all_literal_uids(),
ResidualKind::Set(elements) => {
let mut uids = HashSet::new();
for element in elements.as_ref() {
uids.extend(element.all_literal_uids());
}
uids
}
ResidualKind::Record(map) => {
let mut uids = HashSet::new();
for value in map.values() {
uids.extend(value.all_literal_uids());
}
uids
}
}
}
}
impl Residual {
pub fn is_true(&self) -> bool {
matches!(
self,
Residual::Concrete {
value: Value {
value: ValueKind::Lit(Literal::Bool(true)),
..
},
..
}
)
}
pub fn is_false(&self) -> bool {
matches!(
self,
Residual::Concrete {
value: Value {
value: ValueKind::Lit(Literal::Bool(false)),
..
},
..
}
)
}
pub fn is_error(&self) -> bool {
matches!(self, Residual::Error { .. })
}
}
#[expect(
clippy::fallible_impl_from,
reason = "Residual to Expr conversion should always succeed"
)]
impl From<Residual> for Expr {
fn from(value: Residual) -> Expr {
match value {
Residual::Partial { kind, .. } => {
let builder: ast::ExprBuilder<()> = ExprBuilder::with_data(());
match kind {
ResidualKind::And { left, right } => {
builder.and(left.as_ref().clone().into(), right.as_ref().clone().into())
}
ResidualKind::BinaryApp { op, arg1, arg2 } => builder.binary_app(
op,
arg1.as_ref().clone().into(),
arg2.as_ref().clone().into(),
),
ResidualKind::ExtensionFunctionApp { fn_name, args } => builder
.call_extension_fn(
fn_name,
args.as_ref()
.clone()
.into_iter()
.map(|arg| arg.into())
.collect::<Vec<_>>(),
)
.unwrap_infallible(),
ResidualKind::GetAttr { expr, attr } => {
builder.get_attr(expr.as_ref().clone().into(), attr)
}
ResidualKind::HasAttr { expr, attr } => {
builder.has_attr(expr.as_ref().clone().into(), attr)
}
ResidualKind::If {
test_expr,
then_expr,
else_expr,
} => builder.ite(
test_expr.as_ref().clone().into(),
then_expr.as_ref().clone().into(),
else_expr.as_ref().clone().into(),
),
ResidualKind::Is { expr, entity_type } => {
builder.is_entity_type(expr.as_ref().clone().into(), entity_type)
}
ResidualKind::Like { expr, pattern } => {
builder.like(expr.as_ref().clone().into(), pattern)
}
ResidualKind::Or { left, right } => {
builder.or(left.as_ref().clone().into(), right.as_ref().clone().into())
}
#[expect(clippy::expect_used, reason = "record construction should succeed")]
ResidualKind::Record(map) => builder
.record(map.as_ref().clone().into_iter().map(|(k, v)| (k, v.into())))
.expect("should succeed"),
ResidualKind::Set(set) => builder.set(
set.as_ref()
.clone()
.into_iter()
.map(|v| v.into())
.collect::<Vec<_>>(),
),
ResidualKind::UnaryApp { op, arg } => {
builder.unary_app(op, arg.as_ref().clone().into())
}
ResidualKind::Var(v) => builder.var(v),
}
}
Residual::Concrete { value, .. } => value.into(),
Residual::Error(_) => {
let builder: ast::ExprBuilder<()> = ExprBuilder::with_data(());
#[expect(clippy::unwrap_used, reason = "`error` is a valid `Name`")]
builder
.call_extension_fn("error".parse().unwrap(), std::iter::empty())
.unwrap_infallible() }
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::extensions::Extensions;
    use crate::parser::parse_expr;
    use crate::tpe::request::{PartialEntityUID, PartialRequest};
    use crate::validator::typecheck::{PolicyCheck, Typechecker};
    use crate::validator::types::BoolType;
    use crate::validator::{ValidationMode, Validator, ValidatorSchema};
    use similar_asserts::assert_eq;

    /// Prints every error as a `miette` report, then panics.
    ///
    /// Shared by all failure paths of [`parse_residual`], so validation failures show their
    /// diagnostics just like typechecking failures do.
    #[track_caller]
    fn fail_with_errors<E>(errs: Vec<E>) -> !
    where
        E: miette::Diagnostic + Send + Sync + 'static,
    {
        println!("got {} type errors", errs.len());
        for e in errs {
            println!("{:?}", miette::Report::new(e));
        }
        panic!("unexpected type error in expression")
    }

    /// Parses `expr_str` and converts it into a [`Residual`].
    ///
    /// Steps:
    /// 1. Wrap the expression as the `when` clause of a `permit` policy.
    /// 2. Validate its entity types and literals against a small fixed schema.
    /// 3. Typecheck it in strict mode for a `User` / `Action::"get"` / `Document` request
    ///    environment whose principal and resource ids are unknown (`eid: None`).
    /// 4. Convert the typed expression into a [`Residual`].
    ///
    /// # Panics
    /// If parsing fails, validation or typechecking reports any error, or the typed
    /// expression cannot be converted.
    #[track_caller]
    fn parse_residual(expr_str: &str) -> Residual {
        let expr = parse_expr(expr_str).unwrap();
        let policy_id = crate::ast::PolicyID::from_string("test");
        let policy = Policy::from_when_clause(Effect::Permit, expr, policy_id, None);
        let t = policy.template();
        let schema = ValidatorSchema::from_cedarschema_str(r#"
entity User in Organization { foo: Bool, str: String, num: Long, period: __cedar::duration, set: Set<String> } tags String;
entity Organization;
entity Document in Organization;
action get appliesTo { principal: [User], resource: [Document] };"#,
            &Extensions::all_available(),
        )
        .unwrap()
        .0;
        let typechecker = Typechecker::new(&schema, ValidationMode::Strict);
        let request = PartialRequest::new(
            PartialEntityUID {
                ty: "User".parse().unwrap(),
                eid: None,
            },
            r#"Action::"get""#.parse().unwrap(),
            PartialEntityUID {
                ty: "Document".parse().unwrap(),
                eid: None,
            },
            None,
            &schema,
        )
        .unwrap();
        let env = request.find_request_env(&schema).unwrap();
        let errs: Vec<_> = Validator::validate_entity_types_and_literals(&schema, t).collect();
        if !errs.is_empty() {
            // Previously this panicked without showing which validation errors occurred.
            fail_with_errors(errs);
        }
        match typechecker.typecheck_by_single_request_env(t, &env) {
            PolicyCheck::Success(expr) => Residual::try_from(&expr).unwrap(),
            PolicyCheck::Fail(errs) => fail_with_errors(errs),
            PolicyCheck::Irrelevant(errs, expr) => {
                if errs.is_empty() {
                    Residual::try_from(&expr).unwrap()
                } else {
                    fail_with_errors(errs)
                }
            }
        }
    }

    #[test]
    fn test_can_error_assuming_well_formed() {
        assert_eq!(
            parse_residual(
                r#"
principal is User &&
principal in Organization::"foo" &&
action == Action::"get" &&
resource is Document &&
resource in Organization::"foo"
"#
            )
            .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(r#"User::"jane" in [User::"foo", User::"jane"]"#)
                .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(r#"principal has foo || principal.hasTag("foo")"#)
                .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(r#"principal == resource && !(principal in Organization::"foo")"#)
                .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(
                r#"
if principal.hasTag("foo") then
principal in Organization::"foo"
else principal in Organization::"bar"
"#
            )
            .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(
                r#"
1 == 2 ||
!("a" == "b") &&
["a", "b"].contains("a") &&
!["a", "b"].containsAll(["a"]) &&
["a", "b"].containsAny(["a"])
"#
            )
            .can_error_assuming_well_formed(),
            false
        );
        assert_eq!(
            parse_residual(r#"{a: true, b: false}["a"] && false"#).can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"User::"jane".str like "jane-*""#).can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(
                r#"if principal.num > 0 then User::"jane".num >= 100 else User::"foo".num == 1"#
            )
            .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.hasTag("foo") && principal.getTag("foo") == "bar""#)
                .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(
                r#"
!principal.set.isEmpty() && (
principal.set.contains("foo") ||
principal.set.containsAll(["foo", "bar"]) ||
principal.set.containsAny(["foo", "bar"])
)"#
            )
            .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.num + 1 == 100 || true"#).can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"if principal.foo then principal.num - 1 == 100 else true"#)
                .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.foo && principal.num * 2 == 100"#)
                .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.foo || -principal.num == 100"#)
                .can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.num == 1 && principal.period < (if principal.foo then duration("1d") else duration("2d"))"#).can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            parse_residual(r#"principal.period.toDays() == 365"#).can_error_assuming_well_formed(),
            true
        );
        assert_eq!(
            Residual::Error(Type::Bool(BoolType::AnyBool)).can_error_assuming_well_formed(),
            true
        );
    }

    mod literal_uids {
        use similar_asserts::assert_eq;
        use std::collections::HashSet;

        use super::parse_residual;

        #[test]
        fn var() {
            assert_eq!(
                parse_residual("principal.foo").all_literal_uids(),
                HashSet::new()
            );
        }

        #[test]
        fn r#if() {
            assert_eq!(
                parse_residual(
                    r#"if User::"alice".foo then User::"bob".foo else User::"jane".foo"#
                )
                .all_literal_uids(),
                HashSet::from([
                    r#"User::"alice""#.parse().unwrap(),
                    r#"User::"bob""#.parse().unwrap(),
                    r#"User::"jane""#.parse().unwrap(),
                ])
            );
        }

        #[test]
        fn and() {
            assert_eq!(
                parse_residual(r#"User::"alice".foo && User::"jane".foo"#).all_literal_uids(),
                HashSet::from([
                    r#"User::"alice""#.parse().unwrap(),
                    r#"User::"jane""#.parse().unwrap(),
                ])
            );
        }

        #[test]
        fn set() {
            assert_eq!(
                parse_residual(r#"principal in [User::"alice", User::"jane"]"#).all_literal_uids(),
                HashSet::from([
                    r#"User::"alice""#.parse().unwrap(),
                    r#"User::"jane""#.parse().unwrap(),
                ])
            );
        }

        #[test]
        fn record() {
            assert_eq!(
                parse_residual(r#"(if principal.foo then {a: User::"alice", b: true} else {a: User::"jane", b: false}).a.foo"#).all_literal_uids(),
                HashSet::from([
                    r#"User::"alice""#.parse().unwrap(),
                    r#"User::"jane""#.parse().unwrap(),
                ])
            );
        }
    }

    /// Asserts that round-tripping `expr_str` through [`parse_residual`] and back to an
    /// [`Expr`] yields the expected expression.
    ///
    /// The expected expression is `expr_str` nested inside three `true && (...)`
    /// conjunctions. Presumably these are the unconstrained principal/action/resource
    /// scope constraints the policy template prepends to its `when` clause (TODO confirm).
    fn assert_eq_expr(expr_str: &str) {
        let e: Expr = format!("true && (true && (true && ({})))", expr_str)
            .parse()
            .unwrap();
        let residual = parse_residual(expr_str);
        let e2 = Expr::from(residual);
        println!("e: {}", e);
        println!("e2: {}", e2);
        assert_eq!(e, e2);
    }

    #[test]
    fn to_expr() {
        assert_eq_expr(r#"User::"alice".foo && User::"jane".foo"#);
        assert_eq_expr(r#"User::"alice".foo || User::"jane".foo"#);
        assert_eq_expr(r#"[User::"jane".foo].contains(User::"jane".foo)"#);
        assert_eq_expr(r#"User::"alice" has foo"#);
        assert_eq_expr(r#"(if User::"alice".foo then User::"bob" else User::"jane").foo"#);
        assert_eq_expr(r#""foo" like "bar""#);
        assert_eq_expr(r#"principal in [User::"alice", User::"jane"]"#);
        assert_eq_expr(
            r#"(if principal.foo then {a: User::"alice", b: true} else {a: User::"jane", b: false}).a.foo"#,
        );
    }
}