Skip to main content

cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::{ast::*, parser::err::ParseErrors, parser::Loc};
18use miette::Diagnostic;
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use std::{
22    collections::{btree_map, BTreeMap, HashMap},
23    hash::{Hash, Hasher},
24    mem,
25    sync::Arc,
26};
27use thiserror::Error;
28
29#[cfg(feature = "wasm")]
30extern crate tsify;
31
/// Internal AST for expressions used by the policy evaluator.
/// This structure is a wrapper around an `ExprKind`, which is the expression
/// variant this object contains. It also contains source information about
/// where the expression was written in policy source code, and some generic
/// data which is stored on each node of the AST.
/// Cloning is O(1): sub-expressions inside `ExprKind` are shared via `Arc`.
/// NOTE(review): this presumes cloning `T` and the leaf payloads is cheap.
///
/// The derived `PartialEq`/`Eq`/`Hash` take every field into account,
/// including `source_loc` and `data`, so two syntactically identical
/// expressions with different locations compare unequal.
#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
pub struct Expr<T = ()> {
    /// Which expression variant this node is, together with its children.
    expr_kind: ExprKind<T>,
    /// Where this expression was written in policy source, if known.
    source_loc: Option<Loc>,
    /// Generic per-node data (`()` by default).
    data: T,
}
44
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// Sub-expressions are held behind `Arc`, so cloning an `ExprKind` shares its
/// children instead of deep-copying them.
#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template Slots
    Slot(SlotId),
    /// Symbolic Unknown for partial-eval
    Unknown(Unknown),
    /// Ternary expression (`if test_expr then then_expr else else_expr`)
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if true
        then_expr: Arc<Expr<T>>,
        /// Value if false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First arg
        arg1: Arc<Expr<T>>,
        /// Second arg
        arg2: Arc<Expr<T>>,
    },
    /// Application of an extension function to n arguments
    /// INVARIANT (MethodStyleArgs):
    ///   if op.style is MethodStyle then args _cannot_ be empty.
    ///     The first element of args refers to the subject of the method call
    /// Ideally, we find some way to make this non-representable.
    ExtensionFunctionApp {
        /// Extension function to apply
        fn_name: Name,
        /// Args to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Regex-like string matching similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard *, which matches any string.
        /// To match a literal `*` in the test expression, users can use `\*`.
        /// Note that the backslash in `\*` must not itself be part of another
        /// escape sequence: for instance, `\\*` matches a literal backslash
        /// followed by an arbitrary string.
        pattern: Pattern,
    },
    /// Entity type test. Does the first argument have the entity type
    /// specified by the second argument.
    Is {
        /// Expression to test. Must evaluate to an Entity.
        expr: Arc<Expr<T>>,
        /// The entity type `Name` used for the type membership test.
        entity_type: Name,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
    // that are syntactically unequal, may actually be semantically equal --
    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
    // evaluated into `Value`s
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose elements may be arbitrary expressions).
    /// Keys are unique because they are stored in a `BTreeMap`.
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
}
149
150impl From<Value> for Expr {
151    fn from(v: Value) -> Self {
152        Expr::from(v.value).with_maybe_source_loc(v.loc)
153    }
154}
155
156impl From<ValueKind> for Expr {
157    fn from(v: ValueKind) -> Self {
158        match v {
159            ValueKind::Lit(lit) => Expr::val(lit),
160            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
161            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
162            #[allow(clippy::expect_used)]
163            ValueKind::Record(record) => Expr::record(
164                Arc::unwrap_or_clone(record)
165                    .into_iter()
166                    .map(|(k, v)| (k, Expr::from(v))),
167            )
168            .expect("cannot have duplicate key because the input was already a BTreeMap"),
169            ValueKind::ExtensionValue(ev) => Expr::from(ev.as_ref().clone()),
170        }
171    }
172}
173
174impl From<PartialValue> for Expr {
175    fn from(pv: PartialValue) -> Self {
176        match pv {
177            PartialValue::Value(v) => Expr::from(v),
178            PartialValue::Residual(expr) => expr,
179        }
180    }
181}
182
183impl<T> Expr<T> {
184    fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
185        Self {
186            expr_kind,
187            source_loc,
188            data,
189        }
190    }
191
192    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
193    /// `enum` which specifies the expression variant, so it must be accessed by
194    /// any code matching and recursing on an expression.
195    pub fn expr_kind(&self) -> &ExprKind<T> {
196        &self.expr_kind
197    }
198
199    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
200    pub fn into_expr_kind(self) -> ExprKind<T> {
201        self.expr_kind
202    }
203
204    /// Access the data stored on the `Expr`.
205    pub fn data(&self) -> &T {
206        &self.data
207    }
208
209    /// Access the data stored on the `Expr`, taking ownership and consuming the
210    /// `Expr`.
211    pub fn into_data(self) -> T {
212        self.data
213    }
214
215    /// Access the `Loc` stored on the `Expr`.
216    pub fn source_loc(&self) -> Option<&Loc> {
217        self.source_loc.as_ref()
218    }
219
220    /// Return the `Expr`, but with the new `source_loc` (or `None`).
221    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
222        Self { source_loc, ..self }
223    }
224
225    /// Update the data for this `Expr`. A convenient function used by the
226    /// Validator in one place.
227    pub fn set_data(&mut self, data: T) {
228        self.data = data;
229    }
230
231    /// Check whether this expression is an entity reference
232    ///
233    /// This is used for policy scopes, where some syntax is
234    /// required to be an entity reference.
235    pub fn is_ref(&self) -> bool {
236        match &self.expr_kind {
237            ExprKind::Lit(lit) => lit.is_ref(),
238            _ => false,
239        }
240    }
241
242    /// Check whether this expression is a slot.
243    pub fn is_slot(&self) -> bool {
244        matches!(&self.expr_kind, ExprKind::Slot(_))
245    }
246
247    /// Check whether this expression is a set of entity references
248    ///
249    /// This is used for policy scopes, where some syntax is
250    /// required to be an entity reference set.
251    pub fn is_ref_set(&self) -> bool {
252        match &self.expr_kind {
253            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
254            _ => false,
255        }
256    }
257
258    /// Iterate over all sub-expressions in this expression
259    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
260        expr_iterator::ExprIterator::new(self)
261    }
262
263    /// Iterate over all of the slots in this policy AST
264    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
265        self.subexpressions()
266            .filter_map(|exp| match &exp.expr_kind {
267                ExprKind::Slot(slotid) => Some(Slot {
268                    id: *slotid,
269                    loc: exp.source_loc().cloned(),
270                }),
271                _ => None,
272            })
273    }
274
275    /// Determine if the expression is projectable under partial evaluation
276    /// An expression is projectable if it's guaranteed to never error on evaluation
277    /// This is true if the expression is entirely composed of values or unknowns
278    pub fn is_projectable(&self) -> bool {
279        self.subexpressions().all(|e| {
280            matches!(
281                e.expr_kind(),
282                ExprKind::Lit(_)
283                    | ExprKind::Unknown(_)
284                    | ExprKind::Set(_)
285                    | ExprKind::Var(_)
286                    | ExprKind::Record(_)
287            )
288        })
289    }
290}
291
292#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
293#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
294impl Expr {
295    /// Create an `Expr` that's just a single `Literal`.
296    ///
297    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
298    pub fn val(v: impl Into<Literal>) -> Self {
299        ExprBuilder::new().val(v)
300    }
301
302    /// Create an `Expr` that's just a single `Unknown`.
303    pub fn unknown(u: Unknown) -> Self {
304        ExprBuilder::new().unknown(u)
305    }
306
307    /// Create an `Expr` that's just this literal `Var`
308    pub fn var(v: Var) -> Self {
309        ExprBuilder::new().var(v)
310    }
311
312    /// Create an `Expr` that's just this `SlotId`
313    pub fn slot(s: SlotId) -> Self {
314        ExprBuilder::new().slot(s)
315    }
316
317    /// Create a ternary (if-then-else) `Expr`.
318    ///
319    /// `test_expr` must evaluate to a Bool type
320    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
321        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
322    }
323
324    /// Create a ternary (if-then-else) `Expr`.
325    /// Takes `Arc`s instead of owned `Expr`s.
326    /// `test_expr` must evaluate to a Bool type
327    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
328        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
329    }
330
331    /// Create a 'not' expression. `e` must evaluate to Bool type
332    pub fn not(e: Expr) -> Self {
333        ExprBuilder::new().not(e)
334    }
335
336    /// Create a '==' expression
337    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
338        ExprBuilder::new().is_eq(e1, e2)
339    }
340
341    /// Create a '!=' expression
342    pub fn noteq(e1: Expr, e2: Expr) -> Self {
343        ExprBuilder::new().noteq(e1, e2)
344    }
345
346    /// Create an 'and' expression. Arguments must evaluate to Bool type
347    pub fn and(e1: Expr, e2: Expr) -> Self {
348        ExprBuilder::new().and(e1, e2)
349    }
350
351    /// Create an 'or' expression. Arguments must evaluate to Bool type
352    pub fn or(e1: Expr, e2: Expr) -> Self {
353        ExprBuilder::new().or(e1, e2)
354    }
355
356    /// Create a '<' expression. Arguments must evaluate to Long type
357    pub fn less(e1: Expr, e2: Expr) -> Self {
358        ExprBuilder::new().less(e1, e2)
359    }
360
361    /// Create a '<=' expression. Arguments must evaluate to Long type
362    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
363        ExprBuilder::new().lesseq(e1, e2)
364    }
365
366    /// Create a '>' expression. Arguments must evaluate to Long type
367    pub fn greater(e1: Expr, e2: Expr) -> Self {
368        ExprBuilder::new().greater(e1, e2)
369    }
370
371    /// Create a '>=' expression. Arguments must evaluate to Long type
372    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
373        ExprBuilder::new().greatereq(e1, e2)
374    }
375
376    /// Create an 'add' expression. Arguments must evaluate to Long type
377    pub fn add(e1: Expr, e2: Expr) -> Self {
378        ExprBuilder::new().add(e1, e2)
379    }
380
381    /// Create a 'sub' expression. Arguments must evaluate to Long type
382    pub fn sub(e1: Expr, e2: Expr) -> Self {
383        ExprBuilder::new().sub(e1, e2)
384    }
385
386    /// Create a 'mul' expression. Arguments must evaluate to Long type
387    pub fn mul(e1: Expr, e2: Expr) -> Self {
388        ExprBuilder::new().mul(e1, e2)
389    }
390
391    /// Create a 'neg' expression. `e` must evaluate to Long type.
392    pub fn neg(e: Expr) -> Self {
393        ExprBuilder::new().neg(e)
394    }
395
396    /// Create an 'in' expression. First argument must evaluate to Entity type.
397    /// Second argument must evaluate to either Entity type or Set type where
398    /// all set elements have Entity type.
399    pub fn is_in(e1: Expr, e2: Expr) -> Self {
400        ExprBuilder::new().is_in(e1, e2)
401    }
402
403    /// Create a 'contains' expression.
404    /// First argument must have Set type.
405    pub fn contains(e1: Expr, e2: Expr) -> Self {
406        ExprBuilder::new().contains(e1, e2)
407    }
408
409    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
410    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
411        ExprBuilder::new().contains_all(e1, e2)
412    }
413
414    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
415    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
416        ExprBuilder::new().contains_any(e1, e2)
417    }
418
419    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
420    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
421        ExprBuilder::new().set(exprs)
422    }
423
424    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
425    pub fn record(
426        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
427    ) -> Result<Self, ExprConstructionError> {
428        ExprBuilder::new().record(pairs)
429    }
430
431    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
432    ///
433    /// If you have an iterator of pairs, generally prefer calling
434    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
435    /// potentially for efficiency reasons but also because `Expr::record()`
436    /// will properly handle duplicate keys but your own `.collect()` will not
437    /// (by default).
438    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
439        ExprBuilder::new().record_arc(map)
440    }
441
442    /// Create an `Expr` which calls the extension function with the given
443    /// `Name` on `args`
444    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
445        ExprBuilder::new().call_extension_fn(fn_name, args)
446    }
447
448    /// Create an application `Expr` which applies the given built-in unary
449    /// operator to the given `arg`
450    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
451        ExprBuilder::new().unary_app(op, arg)
452    }
453
454    /// Create an application `Expr` which applies the given built-in binary
455    /// operator to `arg1` and `arg2`
456    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
457        ExprBuilder::new().binary_app(op, arg1, arg2)
458    }
459
460    /// Create an `Expr` which gets the attribute of some `Entity` or the field
461    /// of some record.
462    ///
463    /// `expr` must evaluate to either Entity or Record type
464    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
465        ExprBuilder::new().get_attr(expr, attr)
466    }
467
468    /// Create an `Expr` which tests for the existence of a given
469    /// attribute on a given `Entity`, or field on a given record.
470    ///
471    /// `expr` must evaluate to either Entity or Record type
472    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
473        ExprBuilder::new().has_attr(expr, attr)
474    }
475
476    /// Create a 'like' expression.
477    ///
478    /// `expr` must evaluate to a String type
479    pub fn like(expr: Expr, pattern: impl IntoIterator<Item = PatternElem>) -> Self {
480        ExprBuilder::new().like(expr, pattern)
481    }
482
483    /// Create an `is` expression.
484    pub fn is_entity_type(expr: Expr, entity_type: Name) -> Self {
485        ExprBuilder::new().is_entity_type(expr, entity_type)
486    }
487
488    /// Check if an expression contains any symbolic unknowns
489    pub fn is_unknown(&self) -> bool {
490        self.subexpressions()
491            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
492    }
493
494    /// Get all unknowns in an expression
495    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
496        self.subexpressions()
497            .filter_map(|subexpr| match subexpr.expr_kind() {
498                ExprKind::Unknown(u) => Some(u),
499                _ => None,
500            })
501    }
502
503    /// Substitute unknowns with concrete values.
504    ///
505    /// Ignores unmapped unknowns.
506    /// Ignores type annotations on unknowns.
507    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
508        match self.substitute_general::<UntypedSubstitution>(definitions) {
509            Ok(e) => e,
510            Err(empty) => match empty {},
511        }
512    }
513
514    /// Substitute unknowns with concrete values.
515    ///
516    /// Ignores unmapped unknowns.
517    /// Errors if the substituted value does not match the type annotation on the unknown.
518    pub fn substitute_typed(
519        &self,
520        definitions: &HashMap<SmolStr, Value>,
521    ) -> Result<Expr, SubstitutionError> {
522        self.substitute_general::<TypedSubstitution>(definitions)
523    }
524
525    /// Substitute unknowns with values
526    ///
527    /// Generic over the function implementing the substitution to allow for multiple error behaviors
528    fn substitute_general<T: SubstitutionFunction>(
529        &self,
530        definitions: &HashMap<SmolStr, Value>,
531    ) -> Result<Expr, T::Err> {
532        match self.expr_kind() {
533            ExprKind::Lit(_) => Ok(self.clone()),
534            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
535            ExprKind::Var(_) => Ok(self.clone()),
536            ExprKind::Slot(_) => Ok(self.clone()),
537            ExprKind::If {
538                test_expr,
539                then_expr,
540                else_expr,
541            } => Ok(Expr::ite(
542                test_expr.substitute_general::<T>(definitions)?,
543                then_expr.substitute_general::<T>(definitions)?,
544                else_expr.substitute_general::<T>(definitions)?,
545            )),
546            ExprKind::And { left, right } => Ok(Expr::and(
547                left.substitute_general::<T>(definitions)?,
548                right.substitute_general::<T>(definitions)?,
549            )),
550            ExprKind::Or { left, right } => Ok(Expr::or(
551                left.substitute_general::<T>(definitions)?,
552                right.substitute_general::<T>(definitions)?,
553            )),
554            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
555                *op,
556                arg.substitute_general::<T>(definitions)?,
557            )),
558            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
559                *op,
560                arg1.substitute_general::<T>(definitions)?,
561                arg2.substitute_general::<T>(definitions)?,
562            )),
563            ExprKind::ExtensionFunctionApp { fn_name, args } => {
564                let args = args
565                    .iter()
566                    .map(|e| e.substitute_general::<T>(definitions))
567                    .collect::<Result<Vec<Expr>, _>>()?;
568
569                Ok(Expr::call_extension_fn(fn_name.clone(), args))
570            }
571            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
572                expr.substitute_general::<T>(definitions)?,
573                attr.clone(),
574            )),
575            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
576                expr.substitute_general::<T>(definitions)?,
577                attr.clone(),
578            )),
579            ExprKind::Like { expr, pattern } => Ok(Expr::like(
580                expr.substitute_general::<T>(definitions)?,
581                pattern.iter().cloned(),
582            )),
583            ExprKind::Set(members) => {
584                let members = members
585                    .iter()
586                    .map(|e| e.substitute_general::<T>(definitions))
587                    .collect::<Result<Vec<_>, _>>()?;
588                Ok(Expr::set(members))
589            }
590            ExprKind::Record(map) => {
591                let map = map
592                    .iter()
593                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
594                    .collect::<Result<BTreeMap<_, _>, _>>()?;
595                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
596                #[allow(clippy::expect_used)]
597                Ok(Expr::record(map)
598                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
599            }
600            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
601                expr.substitute_general::<T>(definitions)?,
602                entity_type.clone(),
603            )),
604        }
605    }
606}
607
608/// A trait for customizing the error behavior of substitution
609trait SubstitutionFunction {
610    /// The potential errors this substitution function can return
611    type Err;
612    /// The function for implementing the substitution.
613    ///
614    /// Takes the expression being substituted,
615    /// The substitution from the map (if present)
616    /// and the type annotation from the unknown (if present)
617    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
618}
619
620struct TypedSubstitution {}
621
622impl SubstitutionFunction for TypedSubstitution {
623    type Err = SubstitutionError;
624
625    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
626        match (substitute, &value.type_annotation) {
627            (None, _) => Ok(Expr::unknown(value.clone())),
628            (Some(v), None) => Ok(v.clone().into()),
629            (Some(v), Some(t)) => {
630                if v.type_of() == *t {
631                    Ok(v.clone().into())
632                } else {
633                    Err(SubstitutionError::TypeError {
634                        expected: t.clone(),
635                        actual: v.type_of(),
636                    })
637                }
638            }
639        }
640    }
641}
642
643struct UntypedSubstitution {}
644
645impl SubstitutionFunction for UntypedSubstitution {
646    type Err = std::convert::Infallible;
647
648    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
649        Ok(substitute
650            .map(|v| v.clone().into())
651            .unwrap_or_else(|| Expr::unknown(value.clone())))
652    }
653}
654
655impl std::fmt::Display for Expr {
656    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
657        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
658        // we just convert to EST and use the EST pretty-printer.
659        // Note that converting AST->EST is lossless and infallible.
660        write!(f, "{}", crate::est::Expr::from(self.clone()))
661    }
662}
663
664impl std::str::FromStr for Expr {
665    type Err = ParseErrors;
666
667    fn from_str(s: &str) -> Result<Expr, Self::Err> {
668        crate::parser::parse_expr(s)
669    }
670}
671
/// Enum for errors encountered during substitution.
///
/// Produced by `Expr::substitute_typed`. The untyped `Expr::substitute`
/// cannot fail.
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
684
/// Representation of a partial-evaluation Unknown at the AST level
#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// The name of the unknown. Substitution looks this name up in the
    /// definitions map.
    pub name: SmolStr,
    /// The type of the values that can be substituted in for the unknown.
    /// If `None`, we have no type annotation, and thus a value of any type can
    /// be substituted.
    /// The annotation is only enforced by `Expr::substitute_typed`;
    /// `Expr::substitute` ignores it.
    pub type_annotation: Option<Type>,
}
695
696impl Unknown {
697    /// Create a new untyped `Unknown`
698    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
699        Self {
700            name: name.into(),
701            type_annotation: None,
702        }
703    }
704
705    /// Create a new `Unknown` with type annotation. (Only values of the given
706    /// type can be substituted.)
707    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
708        Self {
709            name: name.into(),
710            type_annotation: Some(ty),
711        }
712    }
713}
714
715impl std::fmt::Display for Unknown {
716    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
717        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
718        // to avoid code duplication
719        write!(f, "{}", crate::est::Expr::from(Expr::unknown(self.clone())))
720    }
721}
722
/// Builder for constructing `Expr` objects annotated with some `data`
/// (possibly taking default value) and optionally a `source_loc`.
///
/// Constructor methods take the builder by value, so each builder produces a
/// single `Expr` node.
#[derive(Debug)]
pub struct ExprBuilder<T> {
    /// Location attached to the built node. `None` unless one is supplied
    /// via `with_source_loc` / `with_same_source_loc`.
    source_loc: Option<Loc>,
    /// Data stored on the built node.
    data: T,
}
730
731impl<T> ExprBuilder<T>
732where
733    T: Default,
734{
735    /// Construct a new `ExprBuilder` where the data used for an expression
736    /// takes a default value.
737    pub fn new() -> Self {
738        Self {
739            source_loc: None,
740            data: T::default(),
741        }
742    }
743
744    /// Create a '!=' expression.
745    /// Defined only for `T: Default` because the caller would otherwise need to
746    /// provide a `data` for the intermediate `not` Expr node.
747    pub fn noteq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
748        match &self.source_loc {
749            Some(source_loc) => ExprBuilder::new().with_source_loc(source_loc.clone()),
750            None => ExprBuilder::new(),
751        }
752        .not(self.with_expr_kind(ExprKind::BinaryApp {
753            op: BinaryOp::Eq,
754            arg1: Arc::new(e1),
755            arg2: Arc::new(e2),
756        }))
757    }
758}
759
760impl<T: Default> Default for ExprBuilder<T> {
761    fn default() -> Self {
762        Self::new()
763    }
764}
765
766impl<T> ExprBuilder<T> {
767    /// Construct a new `ExprBuild` where the specified data will be stored on
768    /// the `Expr`. This constructor does not populate the `source_loc` field,
769    /// so `with_source_loc` should be called if constructing an `Expr` where
770    /// the source location is known.
771    pub fn with_data(data: T) -> Self {
772        Self {
773            source_loc: None,
774            data,
775        }
776    }
777
778    /// Update the `ExprBuilder` to build an expression with some known location
779    /// in policy source code.
780    pub fn with_source_loc(self, source_loc: Loc) -> Self {
781        self.with_maybe_source_loc(Some(source_loc))
782    }
783
784    /// Utility used the validator to get an expression with the same source
785    /// location as an existing expression. This is done when reconstructing the
786    /// `Expr` with type information.
787    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
788        self.with_maybe_source_loc(expr.source_loc.clone())
789    }
790
791    /// internally used to update `.source_loc` to the given `Some` or `None`
792    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<Loc>) -> Self {
793        self.source_loc = maybe_source_loc;
794        self
795    }
796
797    /// Internally used by the following methods to construct an `Expr`
798    /// containing the `data` and `source_loc` in this `ExprBuilder` with some
799    /// inner `ExprKind`.
800    fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
801        Expr::new(expr_kind, self.source_loc, self.data)
802    }
803
804    /// Create an `Expr` that's just a single `Literal`.
805    ///
806    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
807    pub fn val(self, v: impl Into<Literal>) -> Expr<T> {
808        self.with_expr_kind(ExprKind::Lit(v.into()))
809    }
810
811    /// Create an `Unknown` `Expr`
812    pub fn unknown(self, u: Unknown) -> Expr<T> {
813        self.with_expr_kind(ExprKind::Unknown(u))
814    }
815
816    /// Create an `Expr` that's just this literal `Var`
817    pub fn var(self, v: Var) -> Expr<T> {
818        self.with_expr_kind(ExprKind::Var(v))
819    }
820
821    /// Create an `Expr` that's just this `SlotId`
822    pub fn slot(self, s: SlotId) -> Expr<T> {
823        self.with_expr_kind(ExprKind::Slot(s))
824    }
825
826    /// Create a ternary (if-then-else) `Expr`.
827    ///
828    /// `test_expr` must evaluate to a Bool type
829    pub fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
830        self.with_expr_kind(ExprKind::If {
831            test_expr: Arc::new(test_expr),
832            then_expr: Arc::new(then_expr),
833            else_expr: Arc::new(else_expr),
834        })
835    }
836
837    /// Create a ternary (if-then-else) `Expr`.
838    /// Takes `Arc`s instead of owned `Expr`s.
839    /// `test_expr` must evaluate to a Bool type
840    pub fn ite_arc(
841        self,
842        test_expr: Arc<Expr<T>>,
843        then_expr: Arc<Expr<T>>,
844        else_expr: Arc<Expr<T>>,
845    ) -> Expr<T> {
846        self.with_expr_kind(ExprKind::If {
847            test_expr,
848            then_expr,
849            else_expr,
850        })
851    }
852
853    /// Create a 'not' expression. `e` must evaluate to Bool type
854    pub fn not(self, e: Expr<T>) -> Expr<T> {
855        self.with_expr_kind(ExprKind::UnaryApp {
856            op: UnaryOp::Not,
857            arg: Arc::new(e),
858        })
859    }
860
861    /// Create a '==' expression
862    pub fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
863        self.with_expr_kind(ExprKind::BinaryApp {
864            op: BinaryOp::Eq,
865            arg1: Arc::new(e1),
866            arg2: Arc::new(e2),
867        })
868    }
869
870    /// Create an 'and' expression. Arguments must evaluate to Bool type
871    pub fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
872        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
873            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
874                ExprKind::Lit(Literal::Bool(*b1 && *b2))
875            }
876            _ => ExprKind::And {
877                left: Arc::new(e1),
878                right: Arc::new(e2),
879            },
880        })
881    }
882
883    /// Create an 'or' expression. Arguments must evaluate to Bool type
884    pub fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
885        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
886            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
887                ExprKind::Lit(Literal::Bool(*b1 || *b2))
888            }
889
890            _ => ExprKind::Or {
891                left: Arc::new(e1),
892                right: Arc::new(e2),
893            },
894        })
895    }
896
897    /// Create a '<' expression. Arguments must evaluate to Long type
898    pub fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
899        self.with_expr_kind(ExprKind::BinaryApp {
900            op: BinaryOp::Less,
901            arg1: Arc::new(e1),
902            arg2: Arc::new(e2),
903        })
904    }
905
906    /// Create a '<=' expression. Arguments must evaluate to Long type
907    pub fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
908        self.with_expr_kind(ExprKind::BinaryApp {
909            op: BinaryOp::LessEq,
910            arg1: Arc::new(e1),
911            arg2: Arc::new(e2),
912        })
913    }
914
915    /// Create an 'add' expression. Arguments must evaluate to Long type
916    pub fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
917        self.with_expr_kind(ExprKind::BinaryApp {
918            op: BinaryOp::Add,
919            arg1: Arc::new(e1),
920            arg2: Arc::new(e2),
921        })
922    }
923
924    /// Create a 'sub' expression. Arguments must evaluate to Long type
925    pub fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
926        self.with_expr_kind(ExprKind::BinaryApp {
927            op: BinaryOp::Sub,
928            arg1: Arc::new(e1),
929            arg2: Arc::new(e2),
930        })
931    }
932
933    /// Create a 'mul' expression. Arguments must evaluate to Long type
934    pub fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
935        self.with_expr_kind(ExprKind::BinaryApp {
936            op: BinaryOp::Mul,
937            arg1: Arc::new(e1),
938            arg2: Arc::new(e2),
939        })
940    }
941
942    /// Create a 'neg' expression. `e` must evaluate to Long type.
943    pub fn neg(self, e: Expr<T>) -> Expr<T> {
944        self.with_expr_kind(ExprKind::UnaryApp {
945            op: UnaryOp::Neg,
946            arg: Arc::new(e),
947        })
948    }
949
950    /// Create an 'in' expression. First argument must evaluate to Entity type.
951    /// Second argument must evaluate to either Entity type or Set type where
952    /// all set elements have Entity type.
953    pub fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
954        self.with_expr_kind(ExprKind::BinaryApp {
955            op: BinaryOp::In,
956            arg1: Arc::new(e1),
957            arg2: Arc::new(e2),
958        })
959    }
960
961    /// Create a 'contains' expression.
962    /// First argument must have Set type.
963    pub fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
964        self.with_expr_kind(ExprKind::BinaryApp {
965            op: BinaryOp::Contains,
966            arg1: Arc::new(e1),
967            arg2: Arc::new(e2),
968        })
969    }
970
971    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
972    pub fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
973        self.with_expr_kind(ExprKind::BinaryApp {
974            op: BinaryOp::ContainsAll,
975            arg1: Arc::new(e1),
976            arg2: Arc::new(e2),
977        })
978    }
979
980    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
981    pub fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
982        self.with_expr_kind(ExprKind::BinaryApp {
983            op: BinaryOp::ContainsAny,
984            arg1: Arc::new(e1),
985            arg2: Arc::new(e2),
986        })
987    }
988
989    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
990    pub fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
991        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
992    }
993
994    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
995    pub fn record(
996        self,
997        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
998    ) -> Result<Expr<T>, ExprConstructionError> {
999        let mut map = BTreeMap::new();
1000        for (k, v) in pairs {
1001            match map.entry(k) {
1002                btree_map::Entry::Occupied(oentry) => {
1003                    return Err(expression_construction_errors::DuplicateKeyError {
1004                        key: oentry.key().clone(),
1005                        context: "in record literal",
1006                    }
1007                    .into());
1008                }
1009                btree_map::Entry::Vacant(ventry) => {
1010                    ventry.insert(v);
1011                }
1012            }
1013        }
1014        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1015    }
1016
1017    /// Create an `Expr` which evalutes to a Record with the given key-value mapping.
1018    ///
1019    /// If you have an iterator of pairs, generally prefer calling `.record()`
1020    /// instead of `.collect()`-ing yourself and calling this, potentially for
1021    /// efficiency reasons but also because `.record()` will properly handle
1022    /// duplicate keys but your own `.collect()` will not (by default).
1023    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1024        self.with_expr_kind(ExprKind::Record(map))
1025    }
1026
1027    /// Create an `Expr` which calls the extension function with the given
1028    /// `Name` on `args`
1029    pub fn call_extension_fn(
1030        self,
1031        fn_name: Name,
1032        args: impl IntoIterator<Item = Expr<T>>,
1033    ) -> Expr<T> {
1034        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1035            fn_name,
1036            args: Arc::new(args.into_iter().collect()),
1037        })
1038    }
1039
1040    /// Create an application `Expr` which applies the given built-in unary
1041    /// operator to the given `arg`
1042    pub fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1043        self.with_expr_kind(ExprKind::UnaryApp {
1044            op: op.into(),
1045            arg: Arc::new(arg),
1046        })
1047    }
1048
1049    /// Create an application `Expr` which applies the given built-in binary
1050    /// operator to `arg1` and `arg2`
1051    pub fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1052        self.with_expr_kind(ExprKind::BinaryApp {
1053            op: op.into(),
1054            arg1: Arc::new(arg1),
1055            arg2: Arc::new(arg2),
1056        })
1057    }
1058
1059    /// Create an `Expr` which gets the attribute of some `Entity` or the field
1060    /// of some record.
1061    ///
1062    /// `expr` must evaluate to either Entity or Record type
1063    pub fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1064        self.with_expr_kind(ExprKind::GetAttr {
1065            expr: Arc::new(expr),
1066            attr,
1067        })
1068    }
1069
1070    /// Create an `Expr` which tests for the existence of a given
1071    /// attribute on a given `Entity`, or field on a given record.
1072    ///
1073    /// `expr` must evaluate to either Entity or Record type
1074    pub fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1075        self.with_expr_kind(ExprKind::HasAttr {
1076            expr: Arc::new(expr),
1077            attr,
1078        })
1079    }
1080
1081    /// Create a 'like' expression.
1082    ///
1083    /// `expr` must evaluate to a String type
1084    pub fn like(self, expr: Expr<T>, pattern: impl IntoIterator<Item = PatternElem>) -> Expr<T> {
1085        self.with_expr_kind(ExprKind::Like {
1086            expr: Arc::new(expr),
1087            pattern: Pattern::new(pattern),
1088        })
1089    }
1090
1091    /// Create an 'is' expression.
1092    pub fn is_entity_type(self, expr: Expr<T>, entity_type: Name) -> Expr<T> {
1093        self.with_expr_kind(ExprKind::Is {
1094            expr: Arc::new(expr),
1095            entity_type,
1096        })
1097    }
1098}
1099
1100impl<T: Clone> ExprBuilder<T> {
1101    /// Create an `and` expression that may have more than two subexpressions (A && B && C)
1102    /// or may have only one subexpression, in which case no `&&` is performed at all.
1103    /// Arguments must evaluate to Bool type.
1104    ///
1105    /// This may create multiple AST `&&` nodes. If it does, all the nodes will have the same
1106    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1107    /// with another call to `with_source_loc()`.
1108    pub fn and_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1109        others.into_iter().fold(first, |acc, next| {
1110            Self::with_data(self.data.clone())
1111                .with_maybe_source_loc(self.source_loc.clone())
1112                .and(acc, next)
1113        })
1114    }
1115
1116    /// Create an `or` expression that may have more than two subexpressions (A || B || C)
1117    /// or may have only one subexpression, in which case no `||` is performed at all.
1118    /// Arguments must evaluate to Bool type.
1119    ///
1120    /// This may create multiple AST `||` nodes. If it does, all the nodes will have the same
1121    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1122    /// with another call to `with_source_loc()`.
1123    pub fn or_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1124        others.into_iter().fold(first, |acc, next| {
1125            Self::with_data(self.data.clone())
1126                .with_maybe_source_loc(self.source_loc.clone())
1127                .or(acc, next)
1128        })
1129    }
1130
1131    /// Create a '>' expression. Arguments must evaluate to Long type
1132    pub fn greater(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1133        // e1 > e2 is defined as !(e1 <= e2)
1134        let leq = Self::with_data(self.data.clone())
1135            .with_maybe_source_loc(self.source_loc.clone())
1136            .lesseq(e1, e2);
1137        self.not(leq)
1138    }
1139
1140    /// Create a '>=' expression. Arguments must evaluate to Long type
1141    pub fn greatereq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1142        // e1 >= e2 is defined as !(e1 < e2)
1143        let leq = Self::with_data(self.data.clone())
1144            .with_maybe_source_loc(self.source_loc.clone())
1145            .less(e1, e2);
1146        self.not(leq)
1147    }
1148}
1149
1150/// Errors when constructing an `Expr`
1151#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1152pub enum ExprConstructionError {
1153    /// The same key occurred two or more times
1154    #[error(transparent)]
1155    #[diagnostic(transparent)]
1156    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
1157}
1158
1159/// Error subtypes for [`ExprConstructionError`]
1160pub mod expression_construction_errors {
1161    use miette::Diagnostic;
1162    use smol_str::SmolStr;
1163    use thiserror::Error;
1164
1165    /// The same key occurred two or more times
1166    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1167    #[error("duplicate key `{key}` {context}")]
1168    pub struct DuplicateKeyError {
1169        /// The key which occurred two or more times
1170        pub(crate) key: SmolStr,
1171        /// Information about where the duplicate key occurred (e.g., "in record literal")
1172        pub(crate) context: &'static str,
1173    }
1174
1175    impl DuplicateKeyError {
1176        /// Get the key which occurred two or more times
1177        pub fn key(&self) -> &str {
1178            &self.key
1179        }
1180
1181        /// Make a new error with an updated `context` field
1182        pub(crate) fn with_context(self, context: &'static str) -> Self {
1183            Self { context, ..self }
1184        }
1185    }
1186}
1187
1188/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1189/// implementations that ignore any source information or other generic data
1190/// used to annotate the `Expr`.
1191#[derive(Eq, Debug, Clone)]
1192pub struct ExprShapeOnly<'a, T = ()>(&'a Expr<T>);
1193
1194impl<'a, T> ExprShapeOnly<'a, T> {
1195    /// Construct an `ExprShapeOnly` from an `Expr`. The `Expr` is not modified,
1196    /// but any comparisons on the resulting `ExprShapeOnly` will ignore source
1197    /// information and generic data.
1198    pub fn new(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1199        ExprShapeOnly(e)
1200    }
1201}
1202
1203impl<'a, T> PartialEq for ExprShapeOnly<'a, T> {
1204    fn eq(&self, other: &Self) -> bool {
1205        self.0.eq_shape(other.0)
1206    }
1207}
1208
1209impl<'a, T> Hash for ExprShapeOnly<'a, T> {
1210    fn hash<H: Hasher>(&self, state: &mut H) {
1211        self.0.hash_shape(state);
1212    }
1213}
1214
1215impl<T> Expr<T> {
1216    /// Return true if this expression (recursively) has the same expression
1217    /// kind as the argument expression. This accounts for the full recursive
1218    /// shape of the expression, but does not consider source information or any
1219    /// generic data annotated on expression. This should behave the same as the
1220    /// default implementation of `Eq` before source information and generic
1221    /// data were added.
1222    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1223        use ExprKind::*;
1224        match (self.expr_kind(), other.expr_kind()) {
1225            (Lit(lit), Lit(lit1)) => lit == lit1,
1226            (Var(v), Var(v1)) => v == v1,
1227            (Slot(s), Slot(s1)) => s == s1,
1228            (
1229                Unknown(self::Unknown {
1230                    name: name1,
1231                    type_annotation: ta_1,
1232                }),
1233                Unknown(self::Unknown {
1234                    name: name2,
1235                    type_annotation: ta_2,
1236                }),
1237            ) => (name1 == name2) && (ta_1 == ta_2),
1238            (
1239                If {
1240                    test_expr,
1241                    then_expr,
1242                    else_expr,
1243                },
1244                If {
1245                    test_expr: test_expr1,
1246                    then_expr: then_expr1,
1247                    else_expr: else_expr1,
1248                },
1249            ) => {
1250                test_expr.eq_shape(test_expr1)
1251                    && then_expr.eq_shape(then_expr1)
1252                    && else_expr.eq_shape(else_expr1)
1253            }
1254            (
1255                And { left, right },
1256                And {
1257                    left: left1,
1258                    right: right1,
1259                },
1260            )
1261            | (
1262                Or { left, right },
1263                Or {
1264                    left: left1,
1265                    right: right1,
1266                },
1267            ) => left.eq_shape(left1) && right.eq_shape(right1),
1268            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1269                op == op1 && arg.eq_shape(arg1)
1270            }
1271            (
1272                BinaryApp { op, arg1, arg2 },
1273                BinaryApp {
1274                    op: op1,
1275                    arg1: arg11,
1276                    arg2: arg21,
1277                },
1278            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1279            (
1280                ExtensionFunctionApp { fn_name, args },
1281                ExtensionFunctionApp {
1282                    fn_name: fn_name1,
1283                    args: args1,
1284                },
1285            ) => fn_name == fn_name1 && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1)),
1286            (
1287                GetAttr { expr, attr },
1288                GetAttr {
1289                    expr: expr1,
1290                    attr: attr1,
1291                },
1292            )
1293            | (
1294                HasAttr { expr, attr },
1295                HasAttr {
1296                    expr: expr1,
1297                    attr: attr1,
1298                },
1299            ) => attr == attr1 && expr.eq_shape(expr1),
1300            (
1301                Like { expr, pattern },
1302                Like {
1303                    expr: expr1,
1304                    pattern: pattern1,
1305                },
1306            ) => pattern == pattern1 && expr.eq_shape(expr1),
1307            (Set(elems), Set(elems1)) => elems
1308                .iter()
1309                .zip(elems1.iter())
1310                .all(|(e, e1)| e.eq_shape(e1)),
1311            (Record(map), Record(map1)) => {
1312                map.len() == map1.len()
1313                    && map
1314                        .iter()
1315                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1316                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1317            }
1318            (
1319                Is { expr, entity_type },
1320                Is {
1321                    expr: expr1,
1322                    entity_type: entity_type1,
1323                },
1324            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1325            _ => false,
1326        }
1327    }
1328
1329    /// Implementation of hashing corresponding to equality as implemented by
1330    /// `eq_shape`. Must satisfy the usual relationship between equality and
1331    /// hashing.
1332    pub fn hash_shape<H>(&self, state: &mut H)
1333    where
1334        H: Hasher,
1335    {
1336        mem::discriminant(self).hash(state);
1337        match self.expr_kind() {
1338            ExprKind::Lit(lit) => lit.hash(state),
1339            ExprKind::Var(v) => v.hash(state),
1340            ExprKind::Slot(s) => s.hash(state),
1341            ExprKind::Unknown(u) => u.hash(state),
1342            ExprKind::If {
1343                test_expr,
1344                then_expr,
1345                else_expr,
1346            } => {
1347                test_expr.hash_shape(state);
1348                then_expr.hash_shape(state);
1349                else_expr.hash_shape(state);
1350            }
1351            ExprKind::And { left, right } => {
1352                left.hash_shape(state);
1353                right.hash_shape(state);
1354            }
1355            ExprKind::Or { left, right } => {
1356                left.hash_shape(state);
1357                right.hash_shape(state);
1358            }
1359            ExprKind::UnaryApp { op, arg } => {
1360                op.hash(state);
1361                arg.hash_shape(state);
1362            }
1363            ExprKind::BinaryApp { op, arg1, arg2 } => {
1364                op.hash(state);
1365                arg1.hash_shape(state);
1366                arg2.hash_shape(state);
1367            }
1368            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1369                fn_name.hash(state);
1370                state.write_usize(args.len());
1371                args.iter().for_each(|a| {
1372                    a.hash_shape(state);
1373                });
1374            }
1375            ExprKind::GetAttr { expr, attr } => {
1376                expr.hash_shape(state);
1377                attr.hash(state);
1378            }
1379            ExprKind::HasAttr { expr, attr } => {
1380                expr.hash_shape(state);
1381                attr.hash(state);
1382            }
1383            ExprKind::Like { expr, pattern } => {
1384                expr.hash_shape(state);
1385                pattern.hash(state);
1386            }
1387            ExprKind::Set(elems) => {
1388                state.write_usize(elems.len());
1389                elems.iter().for_each(|e| {
1390                    e.hash_shape(state);
1391                })
1392            }
1393            ExprKind::Record(map) => {
1394                state.write_usize(map.len());
1395                map.iter().for_each(|(s, a)| {
1396                    s.hash(state);
1397                    a.hash_shape(state);
1398                });
1399            }
1400            ExprKind::Is { expr, entity_type } => {
1401                expr.hash_shape(state);
1402                entity_type.hash(state);
1403            }
1404        }
1405    }
1406}
1407
1408/// AST variables
1409#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
1410#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1411#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
1412#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
1413pub enum Var {
1414    /// the Principal of the given request
1415    #[serde(rename = "principal")]
1416    Principal,
1417    /// the Action of the given request
1418    #[serde(rename = "action")]
1419    Action,
1420    /// the Resource of the given request
1421    #[serde(rename = "resource")]
1422    Resource,
1423    /// the Context of the given request
1424    #[serde(rename = "context")]
1425    Context,
1426}
1427
#[cfg(test)]
pub(crate) mod var_generator {
    use super::Var;

    /// Every `Var` variant, in declaration order.
    const ALL_VARS: [Var; 4] = [Var::Principal, Var::Action, Var::Resource, Var::Context];

    /// Iterate over every `Var` variant, for tests that must cover all of them.
    #[cfg(test)]
    pub fn all_vars() -> impl Iterator<Item = Var> {
        ALL_VARS.into_iter()
    }
}
1436// by default, Coverlay does not track coverage for lines after a line
1437// containing #[cfg(test)].
1438// we use the following sentinel to "turn back on" coverage tracking for
1439// remaining lines of this file, until the next #[cfg(test)]
1440// GRCOV_BEGIN_COVERAGE
1441
1442impl From<PrincipalOrResource> for Var {
1443    fn from(v: PrincipalOrResource) -> Self {
1444        match v {
1445            PrincipalOrResource::Principal => Var::Principal,
1446            PrincipalOrResource::Resource => Var::Resource,
1447        }
1448    }
1449}
1450
1451// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1452#[allow(clippy::fallible_impl_from)]
1453impl From<Var> for Id {
1454    fn from(var: Var) -> Self {
1455        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1456        #[allow(clippy::unwrap_used)]
1457        format!("{var}").parse().unwrap()
1458    }
1459}
1460
1461impl std::fmt::Display for Var {
1462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1463        match self {
1464            Self::Principal => write!(f, "principal"),
1465            Self::Action => write!(f, "action"),
1466            Self::Resource => write!(f, "resource"),
1467            Self::Context => write!(f, "context"),
1468        }
1469    }
1470}
1471
1472#[cfg(test)]
1473mod test {
1474    use cool_asserts::assert_matches;
1475    use itertools::Itertools;
1476    use std::collections::{hash_map::DefaultHasher, HashSet};
1477
1478    use super::{var_generator::all_vars, *};
1479
1480    // Tests that Var::Into never panics
1481    #[test]
1482    fn all_vars_are_ids() {
1483        for var in all_vars() {
1484            let _id: Id = var.into();
1485        }
1486    }
1487
1488    #[test]
1489    fn exprs() {
1490        assert_eq!(
1491            Expr::val(33),
1492            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1493        );
1494        assert_eq!(
1495            Expr::val("hello"),
1496            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1497        );
1498        assert_eq!(
1499            Expr::val(EntityUID::with_eid("foo")),
1500            Expr::new(
1501                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1502                None,
1503                ()
1504            )
1505        );
1506        assert_eq!(
1507            Expr::var(Var::Principal),
1508            Expr::new(ExprKind::Var(Var::Principal), None, ())
1509        );
1510        assert_eq!(
1511            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1512            Expr::new(
1513                ExprKind::If {
1514                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1515                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1516                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1517                },
1518                None,
1519                ()
1520            )
1521        );
1522        assert_eq!(
1523            Expr::not(Expr::val(false)),
1524            Expr::new(
1525                ExprKind::UnaryApp {
1526                    op: UnaryOp::Not,
1527                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1528                },
1529                None,
1530                ()
1531            )
1532        );
1533        assert_eq!(
1534            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1535            Expr::new(
1536                ExprKind::GetAttr {
1537                    expr: Arc::new(Expr::new(
1538                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1539                        None,
1540                        ()
1541                    )),
1542                    attr: "some_attr".into()
1543                },
1544                None,
1545                ()
1546            )
1547        );
1548        assert_eq!(
1549            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1550            Expr::new(
1551                ExprKind::HasAttr {
1552                    expr: Arc::new(Expr::new(
1553                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1554                        None,
1555                        ()
1556                    )),
1557                    attr: "some_attr".into()
1558                },
1559                None,
1560                ()
1561            )
1562        );
1563        assert_eq!(
1564            Expr::is_entity_type(
1565                Expr::val(EntityUID::with_eid("foo")),
1566                "Type".parse().unwrap()
1567            ),
1568            Expr::new(
1569                ExprKind::Is {
1570                    expr: Arc::new(Expr::new(
1571                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1572                        None,
1573                        ()
1574                    )),
1575                    entity_type: "Type".parse().unwrap()
1576                },
1577                None,
1578                ()
1579            ),
1580        );
1581    }
1582
1583    #[test]
1584    fn like_display() {
1585        // `\0` escaped form is `\0`.
1586        let e = Expr::like(Expr::val("a"), vec![PatternElem::Char('\0')]);
1587        assert_eq!(format!("{e}"), r#""a" like "\0""#);
1588        // `\`'s escaped form is `\\`
1589        let e = Expr::like(
1590            Expr::val("a"),
1591            vec![PatternElem::Char('\\'), PatternElem::Char('0')],
1592        );
1593        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
1594        // `\`'s escaped form is `\\`
1595        let e = Expr::like(
1596            Expr::val("a"),
1597            vec![PatternElem::Char('\\'), PatternElem::Wildcard],
1598        );
1599        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
1600        // literal star's escaped from is `\*`
1601        let e = Expr::like(
1602            Expr::val("a"),
1603            vec![PatternElem::Char('\\'), PatternElem::Char('*')],
1604        );
1605        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
1606    }
1607
1608    #[test]
1609    fn has_display() {
1610        // `\0` escaped form is `\0`.
1611        let e = Expr::has_attr(Expr::val("a"), "\0".into());
1612        assert_eq!(format!("{e}"), r#""a" has "\0""#);
1613        // `\`'s escaped form is `\\`
1614        let e = Expr::has_attr(Expr::val("a"), r#"\"#.into());
1615        assert_eq!(format!("{e}"), r#""a" has "\\""#);
1616    }
1617
1618    #[test]
1619    fn slot_display() {
1620        let e = Expr::slot(SlotId::principal());
1621        assert_eq!(format!("{e}"), "?principal");
1622        let e = Expr::slot(SlotId::resource());
1623        assert_eq!(format!("{e}"), "?resource");
1624        let e = Expr::val(EntityUID::with_eid("eid"));
1625        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
1626    }
1627
1628    #[test]
1629    fn simple_slots() {
1630        let e = Expr::slot(SlotId::principal());
1631        let p = SlotId::principal();
1632        let r = SlotId::resource();
1633        let set: HashSet<SlotId> = HashSet::from_iter([p]);
1634        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1635        let e = Expr::or(
1636            Expr::slot(SlotId::principal()),
1637            Expr::ite(
1638                Expr::val(true),
1639                Expr::slot(SlotId::resource()),
1640                Expr::val(false),
1641            ),
1642        );
1643        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
1644        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1645    }
1646
1647    #[test]
1648    fn unknowns() {
1649        let e = Expr::ite(
1650            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1651            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1652            Expr::unknown(Unknown::new_untyped("c")),
1653        );
1654        let unknowns = e.unknowns().collect_vec();
1655        assert_eq!(unknowns.len(), 3);
1656        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
1657        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
1658        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
1659    }
1660
1661    #[test]
1662    fn is_unknown() {
1663        let e = Expr::ite(
1664            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1665            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1666            Expr::unknown(Unknown::new_untyped("c")),
1667        );
1668        assert!(e.is_unknown());
1669        let e = Expr::ite(
1670            Expr::not(Expr::val(true)),
1671            Expr::and(Expr::val(1), Expr::val(3)),
1672            Expr::val(1),
1673        );
1674        assert!(!e.is_unknown());
1675    }
1676
1677    #[test]
1678    fn expr_with_data() {
1679        let e = ExprBuilder::with_data("data").val(1);
1680        assert_eq!(e.into_data(), "data");
1681    }
1682
1683    #[test]
1684    fn expr_shape_only_eq() {
1685        let temp = ExprBuilder::with_data(1).val(1);
1686        let exprs = &[
1687            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
1688            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
1689            (
1690                ExprBuilder::with_data(1).var(Var::Principal),
1691                Expr::var(Var::Principal),
1692            ),
1693            (
1694                ExprBuilder::with_data(1).slot(SlotId::principal()),
1695                Expr::slot(SlotId::principal()),
1696            ),
1697            (
1698                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
1699                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
1700            ),
1701            (
1702                ExprBuilder::with_data(1).not(temp.clone()),
1703                Expr::not(Expr::val(1)),
1704            ),
1705            (
1706                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
1707                Expr::is_eq(Expr::val(1), Expr::val(1)),
1708            ),
1709            (
1710                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
1711                Expr::and(Expr::val(1), Expr::val(1)),
1712            ),
1713            (
1714                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
1715                Expr::or(Expr::val(1), Expr::val(1)),
1716            ),
1717            (
1718                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
1719                Expr::less(Expr::val(1), Expr::val(1)),
1720            ),
1721            (
1722                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
1723                Expr::lesseq(Expr::val(1), Expr::val(1)),
1724            ),
1725            (
1726                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
1727                Expr::greater(Expr::val(1), Expr::val(1)),
1728            ),
1729            (
1730                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
1731                Expr::greatereq(Expr::val(1), Expr::val(1)),
1732            ),
1733            (
1734                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
1735                Expr::add(Expr::val(1), Expr::val(1)),
1736            ),
1737            (
1738                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
1739                Expr::sub(Expr::val(1), Expr::val(1)),
1740            ),
1741            (
1742                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
1743                Expr::mul(Expr::val(1), Expr::val(1)),
1744            ),
1745            (
1746                ExprBuilder::with_data(1).neg(temp.clone()),
1747                Expr::neg(Expr::val(1)),
1748            ),
1749            (
1750                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
1751                Expr::is_in(Expr::val(1), Expr::val(1)),
1752            ),
1753            (
1754                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
1755                Expr::contains(Expr::val(1), Expr::val(1)),
1756            ),
1757            (
1758                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
1759                Expr::contains_all(Expr::val(1), Expr::val(1)),
1760            ),
1761            (
1762                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
1763                Expr::contains_any(Expr::val(1), Expr::val(1)),
1764            ),
1765            (
1766                ExprBuilder::with_data(1).set([temp.clone()]),
1767                Expr::set([Expr::val(1)]),
1768            ),
1769            (
1770                ExprBuilder::with_data(1)
1771                    .record([("foo".into(), temp.clone())])
1772                    .unwrap(),
1773                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
1774            ),
1775            (
1776                ExprBuilder::with_data(1)
1777                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
1778                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
1779            ),
1780            (
1781                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
1782                Expr::get_attr(Expr::val(1), "foo".into()),
1783            ),
1784            (
1785                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
1786                Expr::has_attr(Expr::val(1), "foo".into()),
1787            ),
1788            (
1789                ExprBuilder::with_data(1).like(temp.clone(), vec![PatternElem::Wildcard]),
1790                Expr::like(Expr::val(1), vec![PatternElem::Wildcard]),
1791            ),
1792            (
1793                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
1794                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
1795            ),
1796        ];
1797
1798        for (e0, e1) in exprs {
1799            assert!(e0.eq_shape(e0));
1800            assert!(e1.eq_shape(e1));
1801            assert!(e0.eq_shape(e1));
1802            assert!(e1.eq_shape(e0));
1803
1804            let mut hasher0 = DefaultHasher::new();
1805            e0.hash_shape(&mut hasher0);
1806            let hash0 = hasher0.finish();
1807
1808            let mut hasher1 = DefaultHasher::new();
1809            e1.hash_shape(&mut hasher1);
1810            let hash1 = hasher1.finish();
1811
1812            assert_eq!(hash0, hash1);
1813        }
1814    }
1815
1816    #[test]
1817    fn expr_shape_only_not_eq() {
1818        let expr1 = ExprBuilder::with_data(1).val(1);
1819        let expr2 = ExprBuilder::with_data(1).val(2);
1820        assert_ne!(ExprShapeOnly::new(&expr1), ExprShapeOnly::new(&expr2));
1821    }
1822
1823    #[test]
1824    fn untyped_subst_present() {
1825        let u = Unknown {
1826            name: "foo".into(),
1827            type_annotation: None,
1828        };
1829        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1830        match r {
1831            Ok(e) => assert_eq!(e, Expr::val(1)),
1832            Err(empty) => match empty {},
1833        }
1834    }
1835
1836    #[test]
1837    fn untyped_subst_present_correct_type() {
1838        let u = Unknown {
1839            name: "foo".into(),
1840            type_annotation: Some(Type::Long),
1841        };
1842        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1843        match r {
1844            Ok(e) => assert_eq!(e, Expr::val(1)),
1845            Err(empty) => match empty {},
1846        }
1847    }
1848
1849    #[test]
1850    fn untyped_subst_present_wrong_type() {
1851        let u = Unknown {
1852            name: "foo".into(),
1853            type_annotation: Some(Type::Bool),
1854        };
1855        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1856        match r {
1857            Ok(e) => assert_eq!(e, Expr::val(1)),
1858            Err(empty) => match empty {},
1859        }
1860    }
1861
1862    #[test]
1863    fn untyped_subst_not_present() {
1864        let u = Unknown {
1865            name: "foo".into(),
1866            type_annotation: Some(Type::Bool),
1867        };
1868        let r = UntypedSubstitution::substitute(&u, None);
1869        match r {
1870            Ok(n) => assert_eq!(n, Expr::unknown(u)),
1871            Err(empty) => match empty {},
1872        }
1873    }
1874
1875    #[test]
1876    fn typed_subst_present() {
1877        let u = Unknown {
1878            name: "foo".into(),
1879            type_annotation: None,
1880        };
1881        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
1882        assert_eq!(e, Expr::val(1));
1883    }
1884
1885    #[test]
1886    fn typed_subst_present_correct_type() {
1887        let u = Unknown {
1888            name: "foo".into(),
1889            type_annotation: Some(Type::Long),
1890        };
1891        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
1892        assert_eq!(e, Expr::val(1));
1893    }
1894
1895    #[test]
1896    fn typed_subst_present_wrong_type() {
1897        let u = Unknown {
1898            name: "foo".into(),
1899            type_annotation: Some(Type::Bool),
1900        };
1901        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
1902        assert_matches!(
1903            r,
1904            SubstitutionError::TypeError {
1905                expected: Type::Bool,
1906                actual: Type::Long,
1907            }
1908        );
1909    }
1910
1911    #[test]
1912    fn typed_subst_not_present() {
1913        let u = Unknown {
1914            name: "foo".into(),
1915            type_annotation: None,
1916        };
1917        let r = TypedSubstitution::substitute(&u, None).unwrap();
1918        assert_eq!(r, Expr::unknown(u));
1919    }
1920}