cedar_policy_core/ast/expr.rs

1/*
2 * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::{ast::*, extensions::Extensions, parser::SourceInfo};
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20use smol_str::SmolStr;
21use std::{
22    collections::HashMap,
23    collections::HashSet,
24    hash::{Hash, Hasher},
25    mem,
26    sync::Arc,
27};
28use thiserror::Error;
29
/// Internal AST for expressions used by the policy evaluator.
/// This structure is a wrapper around an `ExprKind`, which is the expression
/// variant this object contains. It also contains source information about
/// where the expression was written in policy source code, and some generic
/// data which is stored on each node of the AST.
/// Cloning is cheap relative to the size of the tree: sub-expressions live
/// behind `Arc` inside `ExprKind`, so a clone shares them. The remaining cost
/// is cloning this node's own payload, `source_info`, and `data` (which
/// depends on `T`).
#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
pub struct Expr<T = ()> {
    /// The expression variant, which also holds any sub-expressions.
    expr_kind: ExprKind<T>,
    /// Location of this expression in policy source code, or `None` when it
    /// is not known (e.g. for expressions built with `ExprBuilder::new()`).
    source_info: Option<SourceInfo>,
    /// Generic per-node data (`()` by default). The validator uses this slot
    /// when it reconstructs an `Expr` with type information.
    data: T,
}
42
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// Sub-expressions are held behind `Arc`, so cloning an `ExprKind` shares
/// them rather than deep-copying them.
#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template Slots
    Slot(SlotId),
    /// Symbolic Unknown for partial-eval
    Unknown {
        /// The name of the unknown
        name: SmolStr,
        /// The type of the values that can be substituted in for the unknown.
        /// If `None`, we have no type annotation, and thus a value of any type can be substituted.
        type_annotation: Option<Type>,
    },
    /// Ternary expression
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if true
        then_expr: Arc<Expr<T>>,
        /// Value if false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First arg
        arg1: Arc<Expr<T>>,
        /// Second arg
        arg2: Arc<Expr<T>>,
    },
    /// Multiplication by constant
    ///
    /// This isn't just a BinaryOp because its arguments aren't both expressions.
    /// (Similar to how `like` isn't a BinaryOp and has its own AST node as well.)
    MulByConst {
        /// first argument, which may be an arbitrary expression, but must
        /// evaluate to Long type
        arg: Arc<Expr<T>>,
        /// second argument, which must be an integer constant
        constant: i64,
    },
    /// Application of an extension function to n arguments
    /// INVARIANT (MethodStyleArgs):
    ///   if op.style is MethodStyle then args _cannot_ be empty.
    ///     The first element of args refers to the subject of the method call
    /// Ideally, we find some way to make this non-representable.
    /// (If the invariant is violated anyway, `Display` falls back to printing
    /// the call in function-call form.)
    ExtensionFunctionApp {
        /// Extension function to apply
        op: ExtensionFunctionOp,
        /// Args to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Regex-like string matching similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard *, which matches any string.
        /// To match a literal `*` in the test expression, users can use `\*`.
        /// Be careful: the backslash in `\*` must not itself be part of another escape sequence.
        /// For instance, `\\*` matches a backslash followed by an arbitrary string.
        pattern: Pattern,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
    // that are syntactically unequal, may actually be semantically equal --
    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
    // evaluated into `Value`s
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose elements may be arbitrary expressions)
    /// This is a `Vec` for the same reason as above; in particular, duplicate
    /// keys are representable here and are not rejected at this level.
    Record {
        /// key/value pairs
        pairs: Arc<Vec<(SmolStr, Expr<T>)>>,
    },
}
160
161impl From<Value> for Expr {
162    fn from(v: Value) -> Self {
163        match v {
164            Value::Lit(l) => Expr::val(l),
165            Value::Set(s) => Expr::set(s.iter().map(|v| Expr::from(v.clone()))),
166            Value::Record(fields) => Expr::record(
167                fields
168                    .as_ref()
169                    .clone()
170                    .into_iter()
171                    .map(|(k, v)| (k, Expr::from(v))),
172            ),
173            Value::ExtensionValue(ev) => ev.as_ref().clone().into(),
174        }
175    }
176}
177
178impl<T> Expr<T> {
179    fn new(expr_kind: ExprKind<T>, source_info: Option<SourceInfo>, data: T) -> Self {
180        Self {
181            expr_kind,
182            source_info,
183            data,
184        }
185    }
186
187    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
188    /// `enum` which specifies the expression variant, so it must be accessed by
189    /// any code matching and recursing on an expression.
190    pub fn expr_kind(&self) -> &ExprKind<T> {
191        &self.expr_kind
192    }
193
194    /// Access the inner `ExprKind`, taking ownership.
195    pub fn into_expr_kind(self) -> ExprKind<T> {
196        self.expr_kind
197    }
198
199    /// Access the data stored on the `Expr`.
200    pub fn data(&self) -> &T {
201        &self.data
202    }
203
204    /// Access the data stored on the `Expr`, taking ownership.
205    pub fn into_data(self) -> T {
206        self.data
207    }
208
209    /// Access the data stored on the `Expr`.
210    pub fn source_info(&self) -> &Option<SourceInfo> {
211        &self.source_info
212    }
213
214    /// Access the data stored on the `Expr`, taking ownership.
215    pub fn into_source_info(self) -> Option<SourceInfo> {
216        self.source_info
217    }
218
219    /// Update the data for this `Expr`. A convenient function used by the
220    /// Validator in one place.
221    pub fn set_data(&mut self, data: T) {
222        self.data = data;
223    }
224
225    /// Check whether this expression is an entity reference
226    ///
227    /// This is used for policy headers, where some syntax is
228    /// required to be an entity reference.
229    pub fn is_ref(&self) -> bool {
230        match &self.expr_kind {
231            ExprKind::Lit(lit) => lit.is_ref(),
232            _ => false,
233        }
234    }
235
236    /// Check whether this expression is a slot.
237    pub fn is_slot(&self) -> bool {
238        matches!(&self.expr_kind, ExprKind::Slot(_))
239    }
240
241    /// Check whether this expression is a set of entity references
242    ///
243    /// This is used for policy headers, where some syntax is
244    /// required to be an entity reference set.
245    pub fn is_ref_set(&self) -> bool {
246        match &self.expr_kind {
247            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
248            _ => false,
249        }
250    }
251
252    /// Iterate over all sub-expressions in this expression
253    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
254        expr_iterator::ExprIterator::new(self)
255    }
256
257    /// Iterate over all of the slots in this policy AST
258    pub fn slots(&self) -> impl Iterator<Item = &SlotId> {
259        self.subexpressions()
260            .filter_map(|exp| match &exp.expr_kind {
261                ExprKind::Slot(slotid) => Some(slotid),
262                _ => None,
263            })
264    }
265
266    /// Determine if the expression is projectable under partial evaluation
267    /// An expression is projectable if it's guaranteed to never error on evaluation
268    /// This is true if the expression is entirely composed of values or unknowns
269    pub fn is_projectable(&self) -> bool {
270        self.subexpressions().all(|e| match e.expr_kind() {
271            ExprKind::Lit(_) => true,
272            ExprKind::Unknown { .. } => true,
273            ExprKind::Set(_) => true,
274            ExprKind::Var(_) => true,
275            ExprKind::Record { pairs } => {
276                // We need to ensure there are no duplicate keys in the expression
277                let uniq_keys = pairs
278                    .as_ref()
279                    .iter()
280                    .map(|(key, _)| key)
281                    .collect::<HashSet<_>>();
282                pairs.len() == uniq_keys.len()
283            }
284            _ => false,
285        })
286    }
287}
288
289#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
290#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
291impl Expr {
292    /// Create an `Expr` that's just a single `Literal`.
293    ///
294    /// Note that you can pass this a `Literal`, an `i64`, a `String`, etc.
295    pub fn val(v: impl Into<Literal>) -> Self {
296        ExprBuilder::new().val(v)
297    }
298
299    /// Create an unknown value
300    pub fn unknown(name: impl Into<SmolStr>) -> Self {
301        Self::unknown_with_type(name, None)
302    }
303
304    /// Create an unknown value, with an optional type annotation
305    pub fn unknown_with_type(name: impl Into<SmolStr>, t: Option<Type>) -> Self {
306        ExprBuilder::new().unknown(name.into(), t)
307    }
308
309    /// Create an `Expr` that's just this literal `Var`
310    pub fn var(v: Var) -> Self {
311        ExprBuilder::new().var(v)
312    }
313
314    /// Create an `Expr` that's just this `SlotId`
315    pub fn slot(s: SlotId) -> Self {
316        ExprBuilder::new().slot(s)
317    }
318
319    /// Create a ternary (if-then-else) `Expr`.
320    ///
321    /// `test_expr` must evaluate to a Bool type
322    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
323        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
324    }
325
326    /// Create a 'not' expression. `e` must evaluate to Bool type
327    pub fn not(e: Expr) -> Self {
328        ExprBuilder::new().not(e)
329    }
330
331    /// Create a '==' expression
332    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
333        ExprBuilder::new().is_eq(e1, e2)
334    }
335
336    /// Create a '!=' expression
337    pub fn noteq(e1: Expr, e2: Expr) -> Self {
338        ExprBuilder::new().noteq(e1, e2)
339    }
340
341    /// Create an 'and' expression. Arguments must evaluate to Bool type
342    pub fn and(e1: Expr, e2: Expr) -> Self {
343        ExprBuilder::new().and(e1, e2)
344    }
345
346    /// Create an 'or' expression. Arguments must evaluate to Bool type
347    pub fn or(e1: Expr, e2: Expr) -> Self {
348        ExprBuilder::new().or(e1, e2)
349    }
350
351    /// Create a '<' expression. Arguments must evaluate to Long type
352    pub fn less(e1: Expr, e2: Expr) -> Self {
353        ExprBuilder::new().less(e1, e2)
354    }
355
356    /// Create a '<=' expression. Arguments must evaluate to Long type
357    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
358        ExprBuilder::new().lesseq(e1, e2)
359    }
360
361    /// Create a '>' expression. Arguments must evaluate to Long type
362    pub fn greater(e1: Expr, e2: Expr) -> Self {
363        ExprBuilder::new().greater(e1, e2)
364    }
365
366    /// Create a '>=' expression. Arguments must evaluate to Long type
367    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
368        ExprBuilder::new().greatereq(e1, e2)
369    }
370
371    /// Create an 'add' expression. Arguments must evaluate to Long type
372    pub fn add(e1: Expr, e2: Expr) -> Self {
373        ExprBuilder::new().add(e1, e2)
374    }
375
376    /// Create a 'sub' expression. Arguments must evaluate to Long type
377    pub fn sub(e1: Expr, e2: Expr) -> Self {
378        ExprBuilder::new().sub(e1, e2)
379    }
380
381    /// Create a 'mul' expression. First argument must evaluate to Long type.
382    pub fn mul(e: Expr, c: i64) -> Self {
383        ExprBuilder::new().mul(e, c)
384    }
385
386    /// Create a 'neg' expression. `e` must evaluate to Long type.
387    pub fn neg(e: Expr) -> Self {
388        ExprBuilder::new().neg(e)
389    }
390
391    /// Create an 'in' expression. First argument must evaluate to Entity type.
392    /// Second argument must evaluate to either Entity type or Set type where
393    /// all set elements have Entity type.
394    pub fn is_in(e1: Expr, e2: Expr) -> Self {
395        ExprBuilder::new().is_in(e1, e2)
396    }
397
398    /// Create a 'contains' expression.
399    /// First argument must have Set type.
400    pub fn contains(e1: Expr, e2: Expr) -> Self {
401        ExprBuilder::new().contains(e1, e2)
402    }
403
404    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
405    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
406        ExprBuilder::new().contains_all(e1, e2)
407    }
408
409    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
410    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
411        ExprBuilder::new().contains_any(e1, e2)
412    }
413
414    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
415    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
416        ExprBuilder::new().set(exprs)
417    }
418
419    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
420    pub fn record(pairs: impl IntoIterator<Item = (SmolStr, Expr)>) -> Self {
421        ExprBuilder::new().record(pairs)
422    }
423
424    /// Create an `Expr` which calls the extension function with the given
425    /// `Name` on `args`
426    pub fn call_extension_fn(function_name: Name, args: Vec<Expr>) -> Self {
427        ExprBuilder::new().call_extension_fn(function_name, args)
428    }
429
430    /// Create an application `Expr` which applies the given built-in unary
431    /// operator to the given `arg`
432    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
433        ExprBuilder::new().unary_app(op, arg)
434    }
435
436    /// Create an application `Expr` which applies the given built-in binary
437    /// operator to `arg1` and `arg2`
438    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
439        ExprBuilder::new().binary_app(op, arg1, arg2)
440    }
441
442    /// Create an `Expr` which gets the attribute of some `Entity` or the field
443    /// of some record.
444    ///
445    /// `expr` must evaluate to either Entity or Record type
446    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
447        ExprBuilder::new().get_attr(expr, attr)
448    }
449
450    /// Create an `Expr` which tests for the existence of a given
451    /// attribute on a given `Entity`, or field on a given record.
452    ///
453    /// `expr` must evaluate to either Entity or Record type
454    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
455        ExprBuilder::new().has_attr(expr, attr)
456    }
457
458    /// Create a 'like' expression.
459    ///
460    /// `expr` must evaluate to a String type
461    pub fn like(expr: Expr, pattern: impl IntoIterator<Item = PatternElem>) -> Self {
462        ExprBuilder::new().like(expr, pattern)
463    }
464
465    /// Check if an expression contains any symbolic unknowns
466    pub fn is_unknown(&self) -> bool {
467        self.subexpressions()
468            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown { .. }))
469    }
470
471    /// Get all unknowns in an expression
472    pub fn unknowns(&self) -> impl Iterator<Item = &str> {
473        self.subexpressions()
474            .filter_map(|subexpr| match subexpr.expr_kind() {
475                ExprKind::Unknown { name, .. } => Some(name.as_str()),
476                _ => None,
477            })
478    }
479
480    /// Substitute unknowns with values
481    /// If a definition is missing, it will be left as an unknown,
482    /// and can be filled in later.
483    pub fn substitute(
484        &self,
485        definitions: &HashMap<SmolStr, Value>,
486    ) -> Result<Expr, SubstitutionError> {
487        match self.expr_kind() {
488            ExprKind::Lit(_) => Ok(self.clone()),
489            ExprKind::Unknown {
490                name,
491                type_annotation,
492            } => match (definitions.get(name), type_annotation) {
493                (None, _) => Ok(self.clone()),
494                (Some(value), None) => Ok(value.clone().into()),
495                (Some(value), Some(t)) => {
496                    if &value.type_of() == t {
497                        Ok(value.clone().into())
498                    } else {
499                        Err(SubstitutionError::TypeError {
500                            expected: t.clone(),
501                            actual: value.type_of(),
502                        })
503                    }
504                }
505            },
506            ExprKind::Var(_) => Ok(self.clone()),
507            ExprKind::Slot(_) => Ok(self.clone()),
508            ExprKind::If {
509                test_expr,
510                then_expr,
511                else_expr,
512            } => Ok(Expr::ite(
513                test_expr.substitute(definitions)?,
514                then_expr.substitute(definitions)?,
515                else_expr.substitute(definitions)?,
516            )),
517            ExprKind::And { left, right } => Ok(Expr::and(
518                left.substitute(definitions)?,
519                right.substitute(definitions)?,
520            )),
521            ExprKind::Or { left, right } => Ok(Expr::or(
522                left.substitute(definitions)?,
523                right.substitute(definitions)?,
524            )),
525            ExprKind::UnaryApp { op, arg } => {
526                Ok(Expr::unary_app(*op, arg.substitute(definitions)?))
527            }
528            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
529                *op,
530                arg1.substitute(definitions)?,
531                arg2.substitute(definitions)?,
532            )),
533            ExprKind::ExtensionFunctionApp { op, args } => {
534                let args = args
535                    .iter()
536                    .map(|e| e.substitute(definitions))
537                    .collect::<Result<Vec<Expr>, _>>()?;
538
539                Ok(Expr::call_extension_fn(op.function_name.clone(), args))
540            }
541            ExprKind::GetAttr { expr, attr } => {
542                Ok(Expr::get_attr(expr.substitute(definitions)?, attr.clone()))
543            }
544            ExprKind::HasAttr { expr, attr } => {
545                Ok(Expr::has_attr(expr.substitute(definitions)?, attr.clone()))
546            }
547            ExprKind::Like { expr, pattern } => Ok(Expr::like(
548                expr.substitute(definitions)?,
549                pattern.iter().cloned(),
550            )),
551            ExprKind::Set(members) => {
552                let members = members
553                    .iter()
554                    .map(|e| e.substitute(definitions))
555                    .collect::<Result<Vec<_>, _>>()?;
556                Ok(Expr::set(members))
557            }
558            ExprKind::Record { pairs } => {
559                let pairs = pairs
560                    .iter()
561                    .map(|(name, e)| Ok((name.clone(), e.substitute(definitions)?)))
562                    .collect::<Result<Vec<_>, _>>()?;
563                Ok(Expr::record(pairs))
564            }
565            ExprKind::MulByConst { arg, constant } => {
566                Ok(Expr::mul(arg.substitute(definitions)?, *constant))
567            }
568        }
569    }
570}
571
572impl std::fmt::Display for Expr {
573    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
574        match &self.expr_kind {
575            // Add parenthesis around negative numeric literals otherwise
576            // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
577            ExprKind::Lit(Literal::Long(n)) if *n < 0 => write!(f, "({})", n),
578            ExprKind::Lit(l) => write!(f, "{}", l),
579            ExprKind::Var(v) => write!(f, "{}", v),
580            ExprKind::Unknown {
581                name,
582                type_annotation,
583            } => match type_annotation.as_ref() {
584                Some(type_annotation) => write!(f, "unknown({name:?}:{type_annotation})"),
585                None => write!(f, "unknown({name})"),
586            },
587            ExprKind::Slot(id) => write!(f, "{id}"),
588            ExprKind::If {
589                test_expr,
590                then_expr,
591                else_expr,
592            } => write!(
593                f,
594                "if {} then {} else {}",
595                maybe_with_parens(test_expr),
596                maybe_with_parens(then_expr),
597                maybe_with_parens(else_expr)
598            ),
599            ExprKind::And { left, right } => write!(
600                f,
601                "{} && {}",
602                maybe_with_parens(left),
603                maybe_with_parens(right)
604            ),
605            ExprKind::Or { left, right } => write!(
606                f,
607                "{} || {}",
608                maybe_with_parens(left),
609                maybe_with_parens(right)
610            ),
611            ExprKind::UnaryApp { op, arg } => match op {
612                UnaryOp::Not => write!(f, "!{}", maybe_with_parens(arg)),
613                // Always add parentheses instead of calling
614                // `maybe_with_parens`.
615                // This makes sure that we always get a negation operation back
616                // (as opposed to e.g., a negative number) when parsing the
617                // printed form, thus preserving the round-tripping property.
618                UnaryOp::Neg => write!(f, "-({})", arg),
619            },
620            ExprKind::BinaryApp { op, arg1, arg2 } => match op {
621                BinaryOp::Eq => write!(
622                    f,
623                    "{} == {}",
624                    maybe_with_parens(arg1),
625                    maybe_with_parens(arg2),
626                ),
627                BinaryOp::Less => write!(
628                    f,
629                    "{} < {}",
630                    maybe_with_parens(arg1),
631                    maybe_with_parens(arg2),
632                ),
633                BinaryOp::LessEq => write!(
634                    f,
635                    "{} <= {}",
636                    maybe_with_parens(arg1),
637                    maybe_with_parens(arg2),
638                ),
639                BinaryOp::Add => write!(
640                    f,
641                    "{} + {}",
642                    maybe_with_parens(arg1),
643                    maybe_with_parens(arg2),
644                ),
645                BinaryOp::Sub => write!(
646                    f,
647                    "{} - {}",
648                    maybe_with_parens(arg1),
649                    maybe_with_parens(arg2),
650                ),
651                BinaryOp::In => write!(
652                    f,
653                    "{} in {}",
654                    maybe_with_parens(arg1),
655                    maybe_with_parens(arg2),
656                ),
657                BinaryOp::Contains => {
658                    write!(f, "{}.contains({})", maybe_with_parens(arg1), &arg2)
659                }
660                BinaryOp::ContainsAll => {
661                    write!(f, "{}.containsAll({})", maybe_with_parens(arg1), &arg2)
662                }
663                BinaryOp::ContainsAny => {
664                    write!(f, "{}.containsAny({})", maybe_with_parens(arg1), &arg2)
665                }
666            },
667            ExprKind::MulByConst { arg, constant } => {
668                write!(f, "{} * {}", maybe_with_parens(arg), constant)
669            }
670            ExprKind::ExtensionFunctionApp { op, args } => {
671                // search for the name and callstyle
672                let style = Extensions::all_available().all_funcs().find_map(|f| {
673                    if f.name() == &op.function_name {
674                        Some(f.style())
675                    } else {
676                        None
677                    }
678                });
679                if matches!(style, Some(CallStyle::MethodStyle)) && !args.is_empty() {
680                    // This indexing operation is safe, as MethodStyle calls MUST have a non-empty args list.
681                    // See INVARIANT (MethodStyleArgs)
682                    write!(
683                        f,
684                        "{}.{}({})",
685                        maybe_with_parens(&args[0]),
686                        op.function_name,
687                        args[1..].iter().join(", ")
688                    )
689                } else {
690                    // a function without a name or a method-style function
691                    // call with with zero arguments can only occur when
692                    // manually constructing an AST,
693                    // and even then, this form will be representable because the entire name and
694                    // all the args can be displayed with proper syntax. However, it will not parse
695                    // because the parser requires the extention name.
696                    write!(f, "{}({})", op.function_name, args.iter().join(", "))
697                }
698            }
699            ExprKind::GetAttr { expr, attr } => write!(
700                f,
701                "{}[\"{}\"]",
702                maybe_with_parens(expr),
703                attr.escape_debug()
704            ),
705            ExprKind::HasAttr { expr, attr } => {
706                write!(
707                    f,
708                    "{} has \"{}\"",
709                    maybe_with_parens(expr),
710                    attr.escape_debug()
711                )
712            }
713            ExprKind::Like { expr, pattern } => {
714                // during parsing we convert \* in the pattern into \u{0000},
715                // so when printing we need to convert back
716                write!(f, "{} like \"{}\"", maybe_with_parens(expr), pattern,)
717            }
718            ExprKind::Set(v) => write!(f, "[{}]", v.iter().join(", ")),
719            ExprKind::Record { pairs } => write!(
720                f,
721                "{{{}}}",
722                pairs
723                    .iter()
724                    .map(|(k, v)| format!("\"{}\": {}", k.escape_debug(), v))
725                    .join(", ")
726            ),
727        }
728    }
729}
730
731/// returns the `Display` representation of the Expr, adding parens around the
732/// entire Expr if necessary.
733/// E.g., won't add parens for constants or `principal` etc, but will for things
734/// like `(2 < 5)`.
735/// When in doubt, add the parens.
736fn maybe_with_parens(expr: &Expr) -> String {
737    match expr.expr_kind {
738        ExprKind::Lit(_) => expr.to_string(),
739        ExprKind::Var(_) => expr.to_string(),
740        ExprKind::Unknown { .. } => expr.to_string(),
741        ExprKind::Slot(_) => expr.to_string(),
742        ExprKind::If { .. } => format!("({})", expr),
743        ExprKind::And { .. } => format!("({})", expr),
744        ExprKind::Or { .. } => format!("({})", expr),
745        ExprKind::UnaryApp {
746            op: UnaryOp::Not | UnaryOp::Neg,
747            ..
748        } => {
749            // we want parens here because things like parse((!x).y)
750            // would be printed into !x.y which has a different meaning,
751            // albeit being semantically incorrect.
752            format!("({})", expr)
753        }
754        ExprKind::BinaryApp { .. } => format!("({})", expr),
755        ExprKind::MulByConst { .. } => format!("({})", expr),
756        ExprKind::ExtensionFunctionApp { .. } => format!("({})", expr),
757        ExprKind::GetAttr { .. } => format!("({})", expr),
758        ExprKind::HasAttr { .. } => format!("({})", expr),
759        ExprKind::Like { .. } => format!("({})", expr),
760        ExprKind::Set { .. } => expr.to_string(),
761        ExprKind::Record { .. } => expr.to_string(),
762    }
763}
764
/// Enum for errors encountered during substitution (see `Expr::substitute`)
#[derive(Debug, Clone, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("Expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
777
/// Builder for constructing `Expr` objects annotated with some `data`
/// (possibly taking a default value) and, optionally, some `source_info`.
#[derive(Debug)]
pub struct ExprBuilder<T> {
    /// Source location to attach to built expressions; `None` if unknown.
    source_info: Option<SourceInfo>,
    /// Data to store on built expressions.
    data: T,
}
785
786impl<T> ExprBuilder<T>
787where
788    T: Default,
789{
790    /// Construct a new `ExprBuilder` where the data used for an expression
791    /// takes a default value.
792    pub fn new() -> Self {
793        Self {
794            source_info: None,
795            data: T::default(),
796        }
797    }
798
799    /// Create a '!=' expression.
800    /// Defined only for `T: Default` because the caller would otherwise need to
801    /// provide a `data` for the intermediate `not` Expr node.
802    pub fn noteq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
803        match &self.source_info {
804            Some(source_info) => ExprBuilder::new().with_source_info(source_info.clone()),
805            None => ExprBuilder::new(),
806        }
807        .not(self.with_expr_kind(ExprKind::BinaryApp {
808            op: BinaryOp::Eq,
809            arg1: Arc::new(e1),
810            arg2: Arc::new(e2),
811        }))
812    }
813}
814
815impl<T: Default> Default for ExprBuilder<T> {
816    fn default() -> Self {
817        Self::new()
818    }
819}
820
821impl<T> ExprBuilder<T> {
822    /// Construct a new `ExprBuild` where the specified data will be stored on
823    /// the `Expr`. This constructor does not populate the `source_info` field,
824    /// so `with_source_info` should be called if constructing an `Expr` where
825    /// the source location is known.
826    pub fn with_data(data: T) -> Self {
827        Self {
828            source_info: None,
829            data,
830        }
831    }
832
833    /// Update the `ExprBuilder` to build an expression with some known location
834    /// in policy source code.
835    pub fn with_source_info(self, source_info: SourceInfo) -> Self {
836        self.with_maybe_source_info(Some(source_info))
837    }
838
839    /// Utility used the validator to get an expression with the same source
840    /// location as an existing expression. This is done when reconstructing the
841    /// `Expr` with type information.
842    pub fn with_same_source_info<U>(self, expr: &Expr<U>) -> Self {
843        self.with_maybe_source_info(expr.source_info.clone())
844    }
845
846    /// internally used to update SourceInfo to the given `Some` or `None`
847    fn with_maybe_source_info(mut self, maybe_source_info: Option<SourceInfo>) -> Self {
848        self.source_info = maybe_source_info;
849        self
850    }
851
852    /// Internally used by the following methods to construct an `Expr`
853    /// containing the `data` and `source_info` in this `ExprBuilder` with some
854    /// inner `ExprKind`.
855    fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
856        Expr::new(expr_kind, self.source_info, self.data)
857    }
858
859    /// Create an `Expr` that's just a single `Literal`.
860    ///
861    /// Note that you can pass this a `Literal`, an `i64`, a `String`, etc.
862    pub fn val(self, v: impl Into<Literal>) -> Expr<T> {
863        self.with_expr_kind(ExprKind::Lit(v.into()))
864    }
865
866    /// Create an `Unknown` `Expr`
867    pub fn unknown(self, name: impl Into<SmolStr>, type_annotation: Option<Type>) -> Expr<T> {
868        self.with_expr_kind(ExprKind::Unknown {
869            name: name.into(),
870            type_annotation,
871        })
872    }
873
874    /// Create an `Expr` that's just this literal `Var`
875    pub fn var(self, v: Var) -> Expr<T> {
876        self.with_expr_kind(ExprKind::Var(v))
877    }
878
879    /// Create an `Expr` that's just this `SlotId`
880    pub fn slot(self, s: SlotId) -> Expr<T> {
881        self.with_expr_kind(ExprKind::Slot(s))
882    }
883
884    /// Create a ternary (if-then-else) `Expr`.
885    ///
886    /// `test_expr` must evaluate to a Bool type
887    pub fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
888        self.with_expr_kind(ExprKind::If {
889            test_expr: Arc::new(test_expr),
890            then_expr: Arc::new(then_expr),
891            else_expr: Arc::new(else_expr),
892        })
893    }
894
895    /// Create a 'not' expression. `e` must evaluate to Bool type
896    pub fn not(self, e: Expr<T>) -> Expr<T> {
897        self.with_expr_kind(ExprKind::UnaryApp {
898            op: UnaryOp::Not,
899            arg: Arc::new(e),
900        })
901    }
902
903    /// Create a '==' expression
904    pub fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
905        self.with_expr_kind(ExprKind::BinaryApp {
906            op: BinaryOp::Eq,
907            arg1: Arc::new(e1),
908            arg2: Arc::new(e2),
909        })
910    }
911
912    /// Create an 'and' expression. Arguments must evaluate to Bool type
913    pub fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
914        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
915            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
916                ExprKind::Lit(Literal::Bool(*b1 && *b2))
917            }
918            _ => ExprKind::And {
919                left: Arc::new(e1),
920                right: Arc::new(e2),
921            },
922        })
923    }
924
925    /// Create an 'or' expression. Arguments must evaluate to Bool type
926    pub fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
927        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
928            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
929                ExprKind::Lit(Literal::Bool(*b1 || *b2))
930            }
931
932            _ => ExprKind::Or {
933                left: Arc::new(e1),
934                right: Arc::new(e2),
935            },
936        })
937    }
938
939    /// Create a '<' expression. Arguments must evaluate to Long type
940    pub fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
941        self.with_expr_kind(ExprKind::BinaryApp {
942            op: BinaryOp::Less,
943            arg1: Arc::new(e1),
944            arg2: Arc::new(e2),
945        })
946    }
947
948    /// Create a '<=' expression. Arguments must evaluate to Long type
949    pub fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
950        self.with_expr_kind(ExprKind::BinaryApp {
951            op: BinaryOp::LessEq,
952            arg1: Arc::new(e1),
953            arg2: Arc::new(e2),
954        })
955    }
956
957    /// Create a '>' expression. Arguments must evaluate to Long type
958    pub fn greater(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
959        self.less(e2, e1)
960    }
961
962    /// Create a '>=' expression. Arguments must evaluate to Long type
963    pub fn greatereq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
964        self.lesseq(e2, e1)
965    }
966
967    /// Create an 'add' expression. Arguments must evaluate to Long type
968    pub fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
969        self.with_expr_kind(ExprKind::BinaryApp {
970            op: BinaryOp::Add,
971            arg1: Arc::new(e1),
972            arg2: Arc::new(e2),
973        })
974    }
975
976    /// Create a 'sub' expression. Arguments must evaluate to Long type
977    pub fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
978        self.with_expr_kind(ExprKind::BinaryApp {
979            op: BinaryOp::Sub,
980            arg1: Arc::new(e1),
981            arg2: Arc::new(e2),
982        })
983    }
984
985    /// Create a 'mul' expression. First argument must evaluate to Long type.
986    pub fn mul(self, e: Expr<T>, c: i64) -> Expr<T> {
987        self.with_expr_kind(ExprKind::MulByConst {
988            arg: Arc::new(e),
989            constant: c,
990        })
991    }
992
993    /// Create a 'neg' expression. `e` must evaluate to Long type.
994    pub fn neg(self, e: Expr<T>) -> Expr<T> {
995        self.with_expr_kind(ExprKind::UnaryApp {
996            op: UnaryOp::Neg,
997            arg: Arc::new(e),
998        })
999    }
1000
1001    /// Create an 'in' expression. First argument must evaluate to Entity type.
1002    /// Second argument must evaluate to either Entity type or Set type where
1003    /// all set elements have Entity type.
1004    pub fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1005        self.with_expr_kind(ExprKind::BinaryApp {
1006            op: BinaryOp::In,
1007            arg1: Arc::new(e1),
1008            arg2: Arc::new(e2),
1009        })
1010    }
1011
1012    /// Create a 'contains' expression.
1013    /// First argument must have Set type.
1014    pub fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1015        self.with_expr_kind(ExprKind::BinaryApp {
1016            op: BinaryOp::Contains,
1017            arg1: Arc::new(e1),
1018            arg2: Arc::new(e2),
1019        })
1020    }
1021
1022    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1023    pub fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1024        self.with_expr_kind(ExprKind::BinaryApp {
1025            op: BinaryOp::ContainsAll,
1026            arg1: Arc::new(e1),
1027            arg2: Arc::new(e2),
1028        })
1029    }
1030
1031    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1032    pub fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1033        self.with_expr_kind(ExprKind::BinaryApp {
1034            op: BinaryOp::ContainsAny,
1035            arg1: Arc::new(e1),
1036            arg2: Arc::new(e2),
1037        })
1038    }
1039
1040    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1041    pub fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1042        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1043    }
1044
1045    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1046    pub fn record(self, pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>) -> Expr<T> {
1047        self.with_expr_kind(ExprKind::Record {
1048            pairs: Arc::new(pairs.into_iter().collect()),
1049        })
1050    }
1051
1052    /// Create an `Expr` which calls the extension function with the given
1053    /// `Name` on `args`
1054    pub fn call_extension_fn(self, function_name: Name, args: Vec<Expr<T>>) -> Expr<T> {
1055        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1056            op: ExtensionFunctionOp { function_name },
1057            args: Arc::new(args),
1058        })
1059    }
1060
1061    /// Create an application `Expr` which applies the given built-in unary
1062    /// operator to the given `arg`
1063    pub fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1064        self.with_expr_kind(ExprKind::UnaryApp {
1065            op: op.into(),
1066            arg: Arc::new(arg),
1067        })
1068    }
1069
1070    /// Create an application `Expr` which applies the given built-in binary
1071    /// operator to `arg1` and `arg2`
1072    pub fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1073        self.with_expr_kind(ExprKind::BinaryApp {
1074            op: op.into(),
1075            arg1: Arc::new(arg1),
1076            arg2: Arc::new(arg2),
1077        })
1078    }
1079
1080    /// Create an `Expr` which gets the attribute of some `Entity` or the field
1081    /// of some record.
1082    ///
1083    /// `expr` must evaluate to either Entity or Record type
1084    pub fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1085        self.with_expr_kind(ExprKind::GetAttr {
1086            expr: Arc::new(expr),
1087            attr,
1088        })
1089    }
1090
1091    /// Create an `Expr` which tests for the existence of a given
1092    /// attribute on a given `Entity`, or field on a given record.
1093    ///
1094    /// `expr` must evaluate to either Entity or Record type
1095    pub fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1096        self.with_expr_kind(ExprKind::HasAttr {
1097            expr: Arc::new(expr),
1098            attr,
1099        })
1100    }
1101
1102    /// Create a 'like' expression.
1103    ///
1104    /// `expr` must evaluate to a String type
1105    pub fn like(self, expr: Expr<T>, pattern: impl IntoIterator<Item = PatternElem>) -> Expr<T> {
1106        self.with_expr_kind(ExprKind::Like {
1107            expr: Arc::new(expr),
1108            pattern: Pattern::new(pattern),
1109        })
1110    }
1111}
1112
1113impl<T: Clone> ExprBuilder<T> {
1114    /// Create an `and` expression that may have more than two subexpressions (A && B && C)
1115    /// or may have only one subexpression, in which case no `&&` is performed at all.
1116    /// Arguments must evaluate to Bool type.
1117    ///
1118    /// This may create multiple AST `&&` nodes. If it does, all the nodes will have the same
1119    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1120    /// with another call to `with_source_info()`.
1121    pub fn and_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1122        others.into_iter().fold(first, |acc, next| {
1123            Self::with_data(self.data.clone())
1124                .with_maybe_source_info(self.source_info.clone())
1125                .and(acc, next)
1126        })
1127    }
1128
1129    /// Create an `or` expression that may have more than two subexpressions (A || B || C)
1130    /// or may have only one subexpression, in which case no `||` is performed at all.
1131    /// Arguments must evaluate to Bool type.
1132    ///
1133    /// This may create multiple AST `||` nodes. If it does, all the nodes will have the same
1134    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1135    /// with another call to `with_source_info()`.
1136    pub fn or_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1137        others.into_iter().fold(first, |acc, next| {
1138            Self::with_data(self.data.clone())
1139                .with_maybe_source_info(self.source_info.clone())
1140                .or(acc, next)
1141        })
1142    }
1143}
1144
/// A newtype wrapper around a borrowed `Expr` that provides `Eq` and `Hash`
/// implementations that ignore any source information or other generic data
/// used to annotate the `Expr`.
///
/// The `PartialEq` impl for this type delegates to `Expr::eq_shape` and the
/// `Hash` impl delegates to `Expr::hash_shape`, so comparisons look only at
/// the expression's structure. `Eq` is derived as a marker on top of that
/// `PartialEq`.
#[derive(Eq, Debug, Clone)]
pub struct ExprShapeOnly<'a, T = ()>(&'a Expr<T>);
1150
1151impl<'a, T> ExprShapeOnly<'a, T> {
1152    /// Construct an `ExprShapeOnly` from an `Expr`. The `Expr` is not modified,
1153    /// but any comparisons on the resulting `ExprShapeOnly` will ignore source
1154    /// information and generic data.
1155    pub fn new(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1156        ExprShapeOnly(e)
1157    }
1158}
1159
1160impl<'a, T> PartialEq for ExprShapeOnly<'a, T> {
1161    fn eq(&self, other: &Self) -> bool {
1162        self.0.eq_shape(other.0)
1163    }
1164}
1165
1166impl<'a, T> Hash for ExprShapeOnly<'a, T> {
1167    fn hash<H: Hasher>(&self, state: &mut H) {
1168        self.0.hash_shape(state);
1169    }
1170}
1171
1172impl<T> Expr<T> {
1173    /// Return true if this expression (recursively) has the same expression
1174    /// kind as the argument expression. This accounts for the full recursive
1175    /// shape of the expression, but does not consider source information or any
1176    /// generic data annotated on expression. This should behave the same as the
1177    /// default implementation of `Eq` before source information and generic
1178    /// data were added.
1179    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1180        use ExprKind::*;
1181        match (self.expr_kind(), other.expr_kind()) {
1182            (Lit(l), Lit(l1)) => l == l1,
1183            (Var(v), Var(v1)) => v == v1,
1184            (Slot(s), Slot(s1)) => s == s1,
1185            (
1186                Unknown {
1187                    name: name1,
1188                    type_annotation: ta_1,
1189                },
1190                Unknown {
1191                    name: name2,
1192                    type_annotation: ta_2,
1193                },
1194            ) => (name1 == name2) && (ta_1 == ta_2),
1195            (
1196                If {
1197                    test_expr,
1198                    then_expr,
1199                    else_expr,
1200                },
1201                If {
1202                    test_expr: test_expr1,
1203                    then_expr: then_expr1,
1204                    else_expr: else_expr1,
1205                },
1206            ) => {
1207                test_expr.eq_shape(test_expr1)
1208                    && then_expr.eq_shape(then_expr1)
1209                    && else_expr.eq_shape(else_expr1)
1210            }
1211            (
1212                And { left, right },
1213                And {
1214                    left: left1,
1215                    right: right1,
1216                },
1217            )
1218            | (
1219                Or { left, right },
1220                Or {
1221                    left: left1,
1222                    right: right1,
1223                },
1224            ) => left.eq_shape(left1) && right.eq_shape(right1),
1225            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1226                op == op1 && arg.eq_shape(arg1)
1227            }
1228            (
1229                BinaryApp { op, arg1, arg2 },
1230                BinaryApp {
1231                    op: op1,
1232                    arg1: arg11,
1233                    arg2: arg21,
1234                },
1235            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1236            (
1237                MulByConst { arg, constant },
1238                MulByConst {
1239                    arg: arg1,
1240                    constant: constant1,
1241                },
1242            ) => constant == constant1 && arg.eq_shape(arg1),
1243            (
1244                ExtensionFunctionApp { op, args },
1245                ExtensionFunctionApp {
1246                    op: op1,
1247                    args: args1,
1248                },
1249            ) => op == op1 && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1)),
1250            (
1251                GetAttr { expr, attr },
1252                GetAttr {
1253                    expr: expr1,
1254                    attr: attr1,
1255                },
1256            )
1257            | (
1258                HasAttr { expr, attr },
1259                HasAttr {
1260                    expr: expr1,
1261                    attr: attr1,
1262                },
1263            ) => attr == attr1 && expr.eq_shape(expr1),
1264            (
1265                Like { expr, pattern },
1266                Like {
1267                    expr: expr1,
1268                    pattern: pattern1,
1269                },
1270            ) => pattern == pattern1 && expr.eq_shape(expr1),
1271            (Set(elems), Set(elems1)) => elems
1272                .iter()
1273                .zip(elems1.iter())
1274                .all(|(e, e1)| e.eq_shape(e1)),
1275            (Record { pairs }, Record { pairs: pairs1 }) => pairs
1276                .iter()
1277                .zip(pairs1.iter())
1278                .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1)),
1279            _ => false,
1280        }
1281    }
1282
1283    /// Implementation of hashing corresponding to equality as implemented by
1284    /// `eq_shape`. Must satisfy the usual relationship between equality and
1285    /// hashing.
1286    pub fn hash_shape<H>(&self, state: &mut H)
1287    where
1288        H: Hasher,
1289    {
1290        mem::discriminant(self).hash(state);
1291        match self.expr_kind() {
1292            ExprKind::Lit(l) => l.hash(state),
1293            ExprKind::Var(v) => v.hash(state),
1294            ExprKind::Slot(s) => s.hash(state),
1295            ExprKind::Unknown {
1296                name,
1297                type_annotation,
1298            } => {
1299                name.hash(state);
1300                type_annotation.hash(state);
1301            }
1302            ExprKind::If {
1303                test_expr,
1304                then_expr,
1305                else_expr,
1306            } => {
1307                test_expr.hash_shape(state);
1308                then_expr.hash_shape(state);
1309                else_expr.hash_shape(state);
1310            }
1311            ExprKind::And { left, right } => {
1312                left.hash_shape(state);
1313                right.hash_shape(state);
1314            }
1315            ExprKind::Or { left, right } => {
1316                left.hash_shape(state);
1317                right.hash_shape(state);
1318            }
1319            ExprKind::UnaryApp { op, arg } => {
1320                op.hash(state);
1321                arg.hash_shape(state);
1322            }
1323            ExprKind::BinaryApp { op, arg1, arg2 } => {
1324                op.hash(state);
1325                arg1.hash_shape(state);
1326                arg2.hash_shape(state);
1327            }
1328            ExprKind::MulByConst { arg, constant } => {
1329                arg.hash_shape(state);
1330                constant.hash(state);
1331            }
1332            ExprKind::ExtensionFunctionApp { op, args } => {
1333                op.hash(state);
1334                state.write_usize(args.len());
1335                args.iter().for_each(|a| {
1336                    a.hash_shape(state);
1337                });
1338            }
1339            ExprKind::GetAttr { expr, attr } => {
1340                expr.hash_shape(state);
1341                attr.hash(state);
1342            }
1343            ExprKind::HasAttr { expr, attr } => {
1344                expr.hash_shape(state);
1345                attr.hash(state);
1346            }
1347            ExprKind::Like { expr, pattern } => {
1348                expr.hash_shape(state);
1349                pattern.hash(state);
1350            }
1351            ExprKind::Set(elems) => {
1352                state.write_usize(elems.len());
1353                elems.iter().for_each(|e| {
1354                    e.hash_shape(state);
1355                })
1356            }
1357            ExprKind::Record { pairs } => {
1358                state.write_usize(pairs.len());
1359                pairs.iter().for_each(|(s, a)| {
1360                    s.hash(state);
1361                    a.hash_shape(state);
1362                });
1363            }
1364        }
1365    }
1366}
1367
/// AST variables: the four request variables a policy expression can refer to.
///
/// With serde, each variant is (de)serialized under its lowercase name
/// (`"principal"`, `"action"`, `"resource"`, `"context"`), per the
/// `#[serde(rename = ...)]` attributes below.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(fuzzing, derive(arbitrary::Arbitrary))]
pub enum Var {
    /// the Principal of the given request
    #[serde(rename = "principal")]
    Principal,
    /// the Action of the given request
    #[serde(rename = "action")]
    Action,
    /// the Resource of the given request
    #[serde(rename = "resource")]
    Resource,
    /// the Context of the given request
    #[serde(rename = "context")]
    Context,
}
1385
#[cfg(test)]
pub mod var_generator {
    use super::Var;

    /// Iterate over every `Var` variant, in declaration order.
    ///
    /// The whole module is already `#[cfg(test)]`, so no per-item attribute is
    /// needed here.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        let every_var = [Var::Principal, Var::Action, Var::Resource, Var::Context];
        IntoIterator::into_iter(every_var)
    }
}
1394// by default, Coverlay does not track coverage for lines after a line
1395// containing #[cfg(test)].
1396// we use the following sentinel to "turn back on" coverage tracking for
1397// remaining lines of this file, until the next #[cfg(test)]
1398// GRCOV_BEGIN_COVERAGE
1399
1400impl From<PrincipalOrResource> for Var {
1401    fn from(v: PrincipalOrResource) -> Self {
1402        match v {
1403            PrincipalOrResource::Principal => Var::Principal,
1404            PrincipalOrResource::Resource => Var::Resource,
1405        }
1406    }
1407}
1408
1409impl From<Var> for Id {
1410    fn from(var: Var) -> Self {
1411        format!("{var}").parse().expect("Invalid Identifier")
1412    }
1413}
1414
1415impl std::fmt::Display for Var {
1416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1417        match self {
1418            Self::Principal => write!(f, "principal"),
1419            Self::Action => write!(f, "action"),
1420            Self::Resource => write!(f, "resource"),
1421            Self::Context => write!(f, "context"),
1422        }
1423    }
1424}
1425
1426#[cfg(test)]
1427mod test {
1428    use std::{
1429        collections::{hash_map::DefaultHasher, HashSet},
1430        sync::Arc,
1431    };
1432
1433    use super::{var_generator::all_vars, *};
1434
1435    // Tests that Var::Into never panics
1436    #[test]
1437    fn all_vars_are_ids() {
1438        for var in all_vars() {
1439            let _id: Id = var.into();
1440        }
1441    }
1442
1443    #[test]
1444    fn exprs() {
1445        assert_eq!(
1446            Expr::val(33),
1447            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1448        );
1449        assert_eq!(
1450            Expr::val("hello"),
1451            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1452        );
1453        assert_eq!(
1454            Expr::val(EntityUID::with_eid("foo")),
1455            Expr::new(
1456                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1457                None,
1458                ()
1459            )
1460        );
1461        assert_eq!(
1462            Expr::var(Var::Principal),
1463            Expr::new(ExprKind::Var(Var::Principal), None, ())
1464        );
1465        assert_eq!(
1466            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1467            Expr::new(
1468                ExprKind::If {
1469                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1470                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1471                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1472                },
1473                None,
1474                ()
1475            )
1476        );
1477        assert_eq!(
1478            Expr::not(Expr::val(false)),
1479            Expr::new(
1480                ExprKind::UnaryApp {
1481                    op: UnaryOp::Not,
1482                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1483                },
1484                None,
1485                ()
1486            )
1487        );
1488        assert_eq!(
1489            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1490            Expr::new(
1491                ExprKind::GetAttr {
1492                    expr: Arc::new(Expr::new(
1493                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1494                        None,
1495                        ()
1496                    )),
1497                    attr: "some_attr".into()
1498                },
1499                None,
1500                ()
1501            )
1502        );
1503        assert_eq!(
1504            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1505            Expr::new(
1506                ExprKind::HasAttr {
1507                    expr: Arc::new(Expr::new(
1508                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1509                        None,
1510                        ()
1511                    )),
1512                    attr: "some_attr".into()
1513                },
1514                None,
1515                ()
1516            )
1517        );
1518    }
1519
1520    #[test]
1521    fn like_display() {
1522        // `\0` escaped form is `\0`.
1523        let e = Expr::like(Expr::val("a"), vec![PatternElem::Char('\0')]);
1524        assert_eq!(format!("{e}"), r#""a" like "\0""#);
1525        // `\`'s escaped form is `\\`
1526        let e = Expr::like(
1527            Expr::val("a"),
1528            vec![PatternElem::Char('\\'), PatternElem::Char('0')],
1529        );
1530        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
1531        // `\`'s escaped form is `\\`
1532        let e = Expr::like(
1533            Expr::val("a"),
1534            vec![PatternElem::Char('\\'), PatternElem::Wildcard],
1535        );
1536        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
1537        // literal star's escaped from is `\*`
1538        let e = Expr::like(
1539            Expr::val("a"),
1540            vec![PatternElem::Char('\\'), PatternElem::Char('*')],
1541        );
1542        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
1543    }
1544
1545    #[test]
1546    fn slot_display() {
1547        let e = Expr::slot(SlotId::principal());
1548        assert_eq!(format!("{e}"), "?principal");
1549        let e = Expr::slot(SlotId::resource());
1550        assert_eq!(format!("{e}"), "?resource");
1551        let e = Expr::val(EntityUID::with_eid("eid"));
1552        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
1553    }
1554
1555    #[test]
1556    fn simple_slots() {
1557        let e = Expr::slot(SlotId::principal());
1558        let p = SlotId::principal();
1559        let r = SlotId::resource();
1560        let set: HashSet<&SlotId> = [&p].into_iter().collect();
1561        assert_eq!(set, e.slots().collect::<HashSet<_>>());
1562        let e = Expr::or(
1563            Expr::slot(SlotId::principal()),
1564            Expr::ite(
1565                Expr::val(true),
1566                Expr::slot(SlotId::resource()),
1567                Expr::val(false),
1568            ),
1569        );
1570        let set: HashSet<&SlotId> = [&p, &r].into_iter().collect();
1571        assert_eq!(set, e.slots().collect::<HashSet<_>>());
1572    }
1573
1574    #[test]
1575    fn unknowns() {
1576        let e = Expr::ite(
1577            Expr::not(Expr::unknown("a".to_string())),
1578            Expr::and(Expr::unknown("b".to_string()), Expr::val(3)),
1579            Expr::unknown("c".to_string()),
1580        );
1581        let unknowns = e.unknowns().collect_vec();
1582        assert_eq!(unknowns.len(), 3);
1583        assert!(unknowns.contains(&"a"));
1584        assert!(unknowns.contains(&"b"));
1585        assert!(unknowns.contains(&"c"));
1586    }
1587
1588    #[test]
1589    fn is_unknown() {
1590        let e = Expr::ite(
1591            Expr::not(Expr::unknown("a".to_string())),
1592            Expr::and(Expr::unknown("b".to_string()), Expr::val(3)),
1593            Expr::unknown("c".to_string()),
1594        );
1595        assert!(e.is_unknown());
1596        let e = Expr::ite(
1597            Expr::not(Expr::val(true)),
1598            Expr::and(Expr::val(1), Expr::val(3)),
1599            Expr::val(1),
1600        );
1601        assert!(!e.is_unknown());
1602    }
1603
1604    #[test]
1605    fn expr_with_data() {
1606        let e = ExprBuilder::with_data("data").val(1);
1607        assert_eq!(e.into_data(), "data");
1608    }
1609
1610    #[test]
1611    fn expr_shape_only_eq() {
1612        let temp = ExprBuilder::with_data(1).val(1);
1613        let exprs = &[
1614            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
1615            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
1616            (
1617                ExprBuilder::with_data(1).var(Var::Principal),
1618                Expr::var(Var::Principal),
1619            ),
1620            (
1621                ExprBuilder::with_data(1).slot(SlotId::principal()),
1622                Expr::slot(SlotId::principal()),
1623            ),
1624            (
1625                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
1626                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
1627            ),
1628            (
1629                ExprBuilder::with_data(1).not(temp.clone()),
1630                Expr::not(Expr::val(1)),
1631            ),
1632            (
1633                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
1634                Expr::is_eq(Expr::val(1), Expr::val(1)),
1635            ),
1636            (
1637                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
1638                Expr::and(Expr::val(1), Expr::val(1)),
1639            ),
1640            (
1641                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
1642                Expr::or(Expr::val(1), Expr::val(1)),
1643            ),
1644            (
1645                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
1646                Expr::less(Expr::val(1), Expr::val(1)),
1647            ),
1648            (
1649                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
1650                Expr::lesseq(Expr::val(1), Expr::val(1)),
1651            ),
1652            (
1653                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
1654                Expr::greater(Expr::val(1), Expr::val(1)),
1655            ),
1656            (
1657                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
1658                Expr::greatereq(Expr::val(1), Expr::val(1)),
1659            ),
1660            (
1661                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
1662                Expr::add(Expr::val(1), Expr::val(1)),
1663            ),
1664            (
1665                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
1666                Expr::sub(Expr::val(1), Expr::val(1)),
1667            ),
1668            (
1669                ExprBuilder::with_data(1).mul(temp.clone(), 1),
1670                Expr::mul(Expr::val(1), 1),
1671            ),
1672            (
1673                ExprBuilder::with_data(1).neg(temp.clone()),
1674                Expr::neg(Expr::val(1)),
1675            ),
1676            (
1677                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
1678                Expr::is_in(Expr::val(1), Expr::val(1)),
1679            ),
1680            (
1681                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
1682                Expr::contains(Expr::val(1), Expr::val(1)),
1683            ),
1684            (
1685                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
1686                Expr::contains_all(Expr::val(1), Expr::val(1)),
1687            ),
1688            (
1689                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
1690                Expr::contains_any(Expr::val(1), Expr::val(1)),
1691            ),
1692            (
1693                ExprBuilder::with_data(1).set([temp.clone()]),
1694                Expr::set([Expr::val(1)]),
1695            ),
1696            (
1697                ExprBuilder::with_data(1).record([("foo".into(), temp.clone())]),
1698                Expr::record([("foo".into(), Expr::val(1))]),
1699            ),
1700            (
1701                ExprBuilder::with_data(1)
1702                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
1703                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
1704            ),
1705            (
1706                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
1707                Expr::get_attr(Expr::val(1), "foo".into()),
1708            ),
1709            (
1710                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
1711                Expr::has_attr(Expr::val(1), "foo".into()),
1712            ),
1713            (
1714                ExprBuilder::with_data(1).like(temp, vec![PatternElem::Wildcard]),
1715                Expr::like(Expr::val(1), vec![PatternElem::Wildcard]),
1716            ),
1717        ];
1718
1719        for (e0, e1) in exprs {
1720            assert!(e0.eq_shape(e0));
1721            assert!(e1.eq_shape(e1));
1722            assert!(e0.eq_shape(e1));
1723            assert!(e1.eq_shape(e0));
1724
1725            let mut hasher0 = DefaultHasher::new();
1726            e0.hash_shape(&mut hasher0);
1727            let hash0 = hasher0.finish();
1728
1729            let mut hasher1 = DefaultHasher::new();
1730            e1.hash_shape(&mut hasher1);
1731            let hash1 = hasher1.finish();
1732
1733            assert_eq!(hash0, hash1);
1734        }
1735    }
1736
1737    #[test]
1738    fn expr_shape_only_not_eq() {
1739        let expr1 = ExprBuilder::with_data(1).val(1);
1740        let expr2 = ExprBuilder::with_data(1).val(2);
1741        assert_ne!(ExprShapeOnly::new(&expr1), ExprShapeOnly::new(&expr2));
1742    }
1743}