cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#[cfg(feature = "tolerant-ast")]
18use super::expr_allows_errors::AstExprErrorKind;
19use crate::{
20    ast::*,
21    expr_builder::{self, ExprBuilder as _},
22    extensions::Extensions,
23    parser::{err::ParseErrors, Loc},
24};
25use educe::Educe;
26use miette::Diagnostic;
27use serde::{Deserialize, Serialize};
28use smol_str::SmolStr;
29use std::{
30    borrow::Cow,
31    collections::{btree_map, BTreeMap, HashMap},
32    hash::{Hash, Hasher},
33    mem,
34    sync::Arc,
35};
36use thiserror::Error;
37
38#[cfg(feature = "wasm")]
39extern crate tsify;
40
/// Internal AST for expressions used by the policy evaluator.
/// This structure is a wrapper around an `ExprKind`, which is the expression
/// variant this object contains. It also contains source information about
/// where the expression was written in policy source code, and some generic
/// data which is stored on each node of the AST.
///
/// Equality and hashing look only at `expr_kind` and `data`; `source_loc` is
/// ignored, so the same expression written in two places compares equal.
///
/// Cloning is intended to be O(1): child expressions are stored behind `Arc`
/// inside `ExprKind`. (This assumes cloning `T` is cheap as well.)
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Expr<T = ()> {
    /// The expression variant, including any child expressions.
    expr_kind: ExprKind<T>,
    /// Where this expression was written in policy source, if known.
    /// Excluded from `PartialEq`/`Eq`/`Hash`.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    source_loc: Option<Loc>,
    /// Generic per-node data (`()` by default).
    data: T,
}
56
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// Child expressions are stored behind `Arc`, so cloning an `ExprKind` shares
/// those children rather than deep-copying them.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template slot
    Slot(SlotId),
    /// Symbolic `Unknown` used by partial evaluation; substitution may later
    /// replace it with a concrete value
    Unknown(Unknown),
    /// Ternary (`if`-`then`-`else`) expression
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if the condition is true
        then_expr: Arc<Expr<T>>,
        /// Value if the condition is false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply the operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First argument
        arg1: Arc<Expr<T>>,
        /// Second argument
        arg2: Arc<Expr<T>>,
    },
    /// Application of an extension function to n arguments
    /// INVARIANT (MethodStyleArgs):
    ///   if the function is called method-style, then `args` _cannot_ be empty;
    ///   the first element of `args` is the subject (receiver) of the method call.
    /// Ideally, we find some way to make this non-representable.
    ExtensionFunctionApp {
        /// Extension function to apply
        fn_name: Name,
        /// Args to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Wildcard string matching, similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard `*`, which matches any string.
        /// To match a literal `*` in the test expression, users can write `\*`.
        /// Be careful: the backslash in `\*` must not itself belong to another escape
        /// sequence. For instance, `\\*` matches a backslash followed by an arbitrary string.
        pattern: Pattern,
    },
    /// Entity type test: does the first argument have the entity type
    /// specified by the second argument?
    Is {
        /// Expression to test. Must evaluate to an Entity.
        expr: Arc<Expr<T>>,
        /// The [`EntityType`] used for the type membership test.
        entity_type: EntityType,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
    // that are syntactically unequal may actually be semantically equal --
    // i.e., we can't deduplicate until all of the `Expr`s have been
    // evaluated into `Value`s
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose values may be arbitrary expressions)
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
    #[cfg(feature = "tolerant-ast")]
    /// Error expression: a placeholder that lets parsing continue past errors
    /// (only present with the `tolerant-ast` feature)
    Error {
        /// Type of error that led to the failure
        error_kind: AstExprErrorKind,
    },
}
167
168impl From<Value> for Expr {
169    fn from(v: Value) -> Self {
170        Expr::from(v.value).with_maybe_source_loc(v.loc)
171    }
172}
173
174impl From<ValueKind> for Expr {
175    fn from(v: ValueKind) -> Self {
176        match v {
177            ValueKind::Lit(lit) => Expr::val(lit),
178            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
179            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
180            #[allow(clippy::expect_used)]
181            ValueKind::Record(record) => Expr::record(
182                Arc::unwrap_or_clone(record)
183                    .into_iter()
184                    .map(|(k, v)| (k, Expr::from(v))),
185            )
186            .expect("cannot have duplicate key because the input was already a BTreeMap"),
187            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
188        }
189    }
190}
191
192impl From<PartialValue> for Expr {
193    fn from(pv: PartialValue) -> Self {
194        match pv {
195            PartialValue::Value(v) => Expr::from(v),
196            PartialValue::Residual(expr) => expr,
197        }
198    }
199}
200
201impl<T> Expr<T> {
202    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
203        Self {
204            expr_kind,
205            source_loc,
206            data,
207        }
208    }
209
210    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
211    /// `enum` which specifies the expression variant, so it must be accessed by
212    /// any code matching and recursing on an expression.
213    pub fn expr_kind(&self) -> &ExprKind<T> {
214        &self.expr_kind
215    }
216
217    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
218    pub fn into_expr_kind(self) -> ExprKind<T> {
219        self.expr_kind
220    }
221
222    /// Access the data stored on the `Expr`.
223    pub fn data(&self) -> &T {
224        &self.data
225    }
226
227    /// Access the data stored on the `Expr`, taking ownership and consuming the
228    /// `Expr`.
229    pub fn into_data(self) -> T {
230        self.data
231    }
232
233    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
234    /// data.
235    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
236        (self.expr_kind, self.source_loc, self.data)
237    }
238
239    /// Access the `Loc` stored on the `Expr`.
240    pub fn source_loc(&self) -> Option<&Loc> {
241        self.source_loc.as_ref()
242    }
243
244    /// Return the `Expr`, but with the new `source_loc` (or `None`).
245    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
246        Self { source_loc, ..self }
247    }
248
249    /// Update the data for this `Expr`. A convenient function used by the
250    /// Validator in one place.
251    pub fn set_data(&mut self, data: T) {
252        self.data = data;
253    }
254
255    /// Check whether this expression is an entity reference
256    ///
257    /// This is used for policy scopes, where some syntax is
258    /// required to be an entity reference.
259    pub fn is_ref(&self) -> bool {
260        match &self.expr_kind {
261            ExprKind::Lit(lit) => lit.is_ref(),
262            _ => false,
263        }
264    }
265
266    /// Check whether this expression is a slot.
267    pub fn is_slot(&self) -> bool {
268        matches!(&self.expr_kind, ExprKind::Slot(_))
269    }
270
271    /// Check whether this expression is a set of entity references
272    ///
273    /// This is used for policy scopes, where some syntax is
274    /// required to be an entity reference set.
275    pub fn is_ref_set(&self) -> bool {
276        match &self.expr_kind {
277            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
278            _ => false,
279        }
280    }
281
282    /// Iterate over all sub-expressions in this expression
283    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
284        expr_iterator::ExprIterator::new(self)
285    }
286
287    /// Iterate over all of the slots in this policy AST
288    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
289        self.subexpressions()
290            .filter_map(|exp| match &exp.expr_kind {
291                ExprKind::Slot(slotid) => Some(Slot {
292                    id: *slotid,
293                    loc: exp.source_loc().cloned(),
294                }),
295                _ => None,
296            })
297    }
298
299    /// Determine if the expression is projectable under partial evaluation
300    /// An expression is projectable if it's guaranteed to never error on evaluation
301    /// This is true if the expression is entirely composed of values or unknowns
302    pub fn is_projectable(&self) -> bool {
303        self.subexpressions().all(|e| {
304            matches!(
305                e.expr_kind(),
306                ExprKind::Lit(_)
307                    | ExprKind::Unknown(_)
308                    | ExprKind::Set(_)
309                    | ExprKind::Var(_)
310                    | ExprKind::Record(_)
311            )
312        })
313    }
314
315    /// Try to compute the runtime type of this expression. This operation may
316    /// fail (returning `None`), for example, when asked to get the type of any
317    /// variables, any attributes of entities or records, or an `unknown`
318    /// without an explicitly annotated type.
319    ///
320    /// Also note that this is _not_ typechecking the expression. It does not
321    /// check that the expression actually evaluates to a value (as opposed to
322    /// erroring).
323    ///
324    /// Because of these limitations, this function should only be used to
325    /// obtain a type for use in diagnostics such as error strings.
326    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
327        match &self.expr_kind {
328            ExprKind::Lit(l) => Some(l.type_of()),
329            ExprKind::Var(_) => None,
330            ExprKind::Slot(_) => None,
331            ExprKind::Unknown(u) => u.type_annotation.clone(),
332            ExprKind::If {
333                then_expr,
334                else_expr,
335                ..
336            } => {
337                let type_of_then = then_expr.try_type_of(extensions);
338                let type_of_else = else_expr.try_type_of(extensions);
339                if type_of_then == type_of_else {
340                    type_of_then
341                } else {
342                    None
343                }
344            }
345            ExprKind::And { .. } => Some(Type::Bool),
346            ExprKind::Or { .. } => Some(Type::Bool),
347            ExprKind::UnaryApp {
348                op: UnaryOp::Neg, ..
349            } => Some(Type::Long),
350            ExprKind::UnaryApp {
351                op: UnaryOp::Not, ..
352            } => Some(Type::Bool),
353            ExprKind::UnaryApp {
354                op: UnaryOp::IsEmpty,
355                ..
356            } => Some(Type::Bool),
357            ExprKind::BinaryApp {
358                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
359                ..
360            } => Some(Type::Long),
361            ExprKind::BinaryApp {
362                op:
363                    BinaryOp::Contains
364                    | BinaryOp::ContainsAll
365                    | BinaryOp::ContainsAny
366                    | BinaryOp::Eq
367                    | BinaryOp::In
368                    | BinaryOp::Less
369                    | BinaryOp::LessEq,
370                ..
371            } => Some(Type::Bool),
372            ExprKind::BinaryApp {
373                op: BinaryOp::HasTag,
374                ..
375            } => Some(Type::Bool),
376            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
377                .func(fn_name)
378                .ok()?
379                .return_type()
380                .map(|rty| rty.clone().into()),
381            // We could try to be more complete here, but we can't do all that
382            // much better without evaluating the argument. Even if we know it's
383            // a record `Type::Record` tells us nothing about the type of the
384            // attribute.
385            ExprKind::GetAttr { .. } => None,
386            // similarly to `GetAttr`
387            ExprKind::BinaryApp {
388                op: BinaryOp::GetTag,
389                ..
390            } => None,
391            ExprKind::HasAttr { .. } => Some(Type::Bool),
392            ExprKind::Like { .. } => Some(Type::Bool),
393            ExprKind::Is { .. } => Some(Type::Bool),
394            ExprKind::Set(_) => Some(Type::Set),
395            ExprKind::Record(_) => Some(Type::Record),
396            #[cfg(feature = "tolerant-ast")]
397            ExprKind::Error { .. } => None,
398        }
399    }
400}
401
402#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
403#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
404impl Expr {
405    /// Create an `Expr` that's just a single `Literal`.
406    ///
407    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
408    pub fn val(v: impl Into<Literal>) -> Self {
409        ExprBuilder::new().val(v)
410    }
411
412    /// Create an `Expr` that's just a single `Unknown`.
413    pub fn unknown(u: Unknown) -> Self {
414        ExprBuilder::new().unknown(u)
415    }
416
417    /// Create an `Expr` that's just this literal `Var`
418    pub fn var(v: Var) -> Self {
419        ExprBuilder::new().var(v)
420    }
421
422    /// Create an `Expr` that's just this `SlotId`
423    pub fn slot(s: SlotId) -> Self {
424        ExprBuilder::new().slot(s)
425    }
426
427    /// Create a ternary (if-then-else) `Expr`.
428    ///
429    /// `test_expr` must evaluate to a Bool type
430    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
431        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
432    }
433
434    /// Create a ternary (if-then-else) `Expr`.
435    /// Takes `Arc`s instead of owned `Expr`s.
436    /// `test_expr` must evaluate to a Bool type
437    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
438        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
439    }
440
441    /// Create a 'not' expression. `e` must evaluate to Bool type
442    pub fn not(e: Expr) -> Self {
443        ExprBuilder::new().not(e)
444    }
445
446    /// Create a '==' expression
447    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
448        ExprBuilder::new().is_eq(e1, e2)
449    }
450
451    /// Create a '!=' expression
452    pub fn noteq(e1: Expr, e2: Expr) -> Self {
453        ExprBuilder::new().noteq(e1, e2)
454    }
455
456    /// Create an 'and' expression. Arguments must evaluate to Bool type
457    pub fn and(e1: Expr, e2: Expr) -> Self {
458        ExprBuilder::new().and(e1, e2)
459    }
460
461    /// Create an 'or' expression. Arguments must evaluate to Bool type
462    pub fn or(e1: Expr, e2: Expr) -> Self {
463        ExprBuilder::new().or(e1, e2)
464    }
465
466    /// Create a '<' expression. Arguments must evaluate to Long type
467    pub fn less(e1: Expr, e2: Expr) -> Self {
468        ExprBuilder::new().less(e1, e2)
469    }
470
471    /// Create a '<=' expression. Arguments must evaluate to Long type
472    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
473        ExprBuilder::new().lesseq(e1, e2)
474    }
475
476    /// Create a '>' expression. Arguments must evaluate to Long type
477    pub fn greater(e1: Expr, e2: Expr) -> Self {
478        ExprBuilder::new().greater(e1, e2)
479    }
480
481    /// Create a '>=' expression. Arguments must evaluate to Long type
482    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
483        ExprBuilder::new().greatereq(e1, e2)
484    }
485
486    /// Create an 'add' expression. Arguments must evaluate to Long type
487    pub fn add(e1: Expr, e2: Expr) -> Self {
488        ExprBuilder::new().add(e1, e2)
489    }
490
491    /// Create a 'sub' expression. Arguments must evaluate to Long type
492    pub fn sub(e1: Expr, e2: Expr) -> Self {
493        ExprBuilder::new().sub(e1, e2)
494    }
495
496    /// Create a 'mul' expression. Arguments must evaluate to Long type
497    pub fn mul(e1: Expr, e2: Expr) -> Self {
498        ExprBuilder::new().mul(e1, e2)
499    }
500
501    /// Create a 'neg' expression. `e` must evaluate to Long type.
502    pub fn neg(e: Expr) -> Self {
503        ExprBuilder::new().neg(e)
504    }
505
506    /// Create an 'in' expression. First argument must evaluate to Entity type.
507    /// Second argument must evaluate to either Entity type or Set type where
508    /// all set elements have Entity type.
509    pub fn is_in(e1: Expr, e2: Expr) -> Self {
510        ExprBuilder::new().is_in(e1, e2)
511    }
512
513    /// Create a `contains` expression.
514    /// First argument must have Set type.
515    pub fn contains(e1: Expr, e2: Expr) -> Self {
516        ExprBuilder::new().contains(e1, e2)
517    }
518
519    /// Create a `containsAll` expression. Arguments must evaluate to Set type
520    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
521        ExprBuilder::new().contains_all(e1, e2)
522    }
523
524    /// Create a `containsAny` expression. Arguments must evaluate to Set type
525    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
526        ExprBuilder::new().contains_any(e1, e2)
527    }
528
529    /// Create a `isEmpty` expression. Argument must evaluate to Set type
530    pub fn is_empty(e: Expr) -> Self {
531        ExprBuilder::new().is_empty(e)
532    }
533
534    /// Create a `getTag` expression.
535    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
536    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
537        ExprBuilder::new().get_tag(expr, tag)
538    }
539
540    /// Create a `hasTag` expression.
541    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
542    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
543        ExprBuilder::new().has_tag(expr, tag)
544    }
545
546    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
547    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
548        ExprBuilder::new().set(exprs)
549    }
550
551    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
552    pub fn record(
553        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
554    ) -> Result<Self, ExpressionConstructionError> {
555        ExprBuilder::new().record(pairs)
556    }
557
558    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
559    ///
560    /// If you have an iterator of pairs, generally prefer calling
561    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
562    /// potentially for efficiency reasons but also because `Expr::record()`
563    /// will properly handle duplicate keys but your own `.collect()` will not
564    /// (by default).
565    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
566        ExprBuilder::new().record_arc(map)
567    }
568
569    /// Create an `Expr` which calls the extension function with the given
570    /// `Name` on `args`
571    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
572        ExprBuilder::new().call_extension_fn(fn_name, args)
573    }
574
575    /// Create an application `Expr` which applies the given built-in unary
576    /// operator to the given `arg`
577    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
578        ExprBuilder::new().unary_app(op, arg)
579    }
580
581    /// Create an application `Expr` which applies the given built-in binary
582    /// operator to `arg1` and `arg2`
583    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
584        ExprBuilder::new().binary_app(op, arg1, arg2)
585    }
586
587    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
588    ///
589    /// `expr` must evaluate to either Entity or Record type
590    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
591        ExprBuilder::new().get_attr(expr, attr)
592    }
593
594    /// Create an `Expr` which tests for the existence of a given
595    /// attribute on a given `Entity` or record.
596    ///
597    /// `expr` must evaluate to either Entity or Record type
598    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
599        ExprBuilder::new().has_attr(expr, attr)
600    }
601
602    /// Create a 'like' expression.
603    ///
604    /// `expr` must evaluate to a String type
605    pub fn like(expr: Expr, pattern: Pattern) -> Self {
606        ExprBuilder::new().like(expr, pattern)
607    }
608
609    /// Create an `is` expression.
610    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
611        ExprBuilder::new().is_entity_type(expr, entity_type)
612    }
613
614    /// Check if an expression contains any symbolic unknowns
615    pub fn contains_unknown(&self) -> bool {
616        self.subexpressions()
617            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
618    }
619
620    /// Get all unknowns in an expression
621    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
622        self.subexpressions()
623            .filter_map(|subexpr| match subexpr.expr_kind() {
624                ExprKind::Unknown(u) => Some(u),
625                _ => None,
626            })
627    }
628
629    /// Substitute unknowns with concrete values.
630    ///
631    /// Ignores unmapped unknowns.
632    /// Ignores type annotations on unknowns.
633    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
634        match self.substitute_general::<UntypedSubstitution>(definitions) {
635            Ok(e) => e,
636            Err(empty) => match empty {},
637        }
638    }
639
640    /// Substitute unknowns with concrete values.
641    ///
642    /// Ignores unmapped unknowns.
643    /// Errors if the substituted value does not match the type annotation on the unknown.
644    pub fn substitute_typed(
645        &self,
646        definitions: &HashMap<SmolStr, Value>,
647    ) -> Result<Expr, SubstitutionError> {
648        self.substitute_general::<TypedSubstitution>(definitions)
649    }
650
651    /// Substitute unknowns with values
652    ///
653    /// Generic over the function implementing the substitution to allow for multiple error behaviors
654    fn substitute_general<T: SubstitutionFunction>(
655        &self,
656        definitions: &HashMap<SmolStr, Value>,
657    ) -> Result<Expr, T::Err> {
658        match self.expr_kind() {
659            ExprKind::Lit(_) => Ok(self.clone()),
660            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
661            ExprKind::Var(_) => Ok(self.clone()),
662            ExprKind::Slot(_) => Ok(self.clone()),
663            ExprKind::If {
664                test_expr,
665                then_expr,
666                else_expr,
667            } => Ok(Expr::ite(
668                test_expr.substitute_general::<T>(definitions)?,
669                then_expr.substitute_general::<T>(definitions)?,
670                else_expr.substitute_general::<T>(definitions)?,
671            )),
672            ExprKind::And { left, right } => Ok(Expr::and(
673                left.substitute_general::<T>(definitions)?,
674                right.substitute_general::<T>(definitions)?,
675            )),
676            ExprKind::Or { left, right } => Ok(Expr::or(
677                left.substitute_general::<T>(definitions)?,
678                right.substitute_general::<T>(definitions)?,
679            )),
680            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
681                *op,
682                arg.substitute_general::<T>(definitions)?,
683            )),
684            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
685                *op,
686                arg1.substitute_general::<T>(definitions)?,
687                arg2.substitute_general::<T>(definitions)?,
688            )),
689            ExprKind::ExtensionFunctionApp { fn_name, args } => {
690                let args = args
691                    .iter()
692                    .map(|e| e.substitute_general::<T>(definitions))
693                    .collect::<Result<Vec<Expr>, _>>()?;
694
695                Ok(Expr::call_extension_fn(fn_name.clone(), args))
696            }
697            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
698                expr.substitute_general::<T>(definitions)?,
699                attr.clone(),
700            )),
701            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
702                expr.substitute_general::<T>(definitions)?,
703                attr.clone(),
704            )),
705            ExprKind::Like { expr, pattern } => Ok(Expr::like(
706                expr.substitute_general::<T>(definitions)?,
707                pattern.clone(),
708            )),
709            ExprKind::Set(members) => {
710                let members = members
711                    .iter()
712                    .map(|e| e.substitute_general::<T>(definitions))
713                    .collect::<Result<Vec<_>, _>>()?;
714                Ok(Expr::set(members))
715            }
716            ExprKind::Record(map) => {
717                let map = map
718                    .iter()
719                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
720                    .collect::<Result<BTreeMap<_, _>, _>>()?;
721                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
722                #[allow(clippy::expect_used)]
723                Ok(Expr::record(map)
724                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
725            }
726            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
727                expr.substitute_general::<T>(definitions)?,
728                entity_type.clone(),
729            )),
730            #[cfg(feature = "tolerant-ast")]
731            ExprKind::Error { .. } => Ok(self.clone()),
732        }
733    }
734}
735
736/// A trait for customizing the error behavior of substitution
737trait SubstitutionFunction {
738    /// The potential errors this substitution function can return
739    type Err;
740    /// The function for implementing the substitution.
741    ///
742    /// Takes the expression being substituted,
743    /// The substitution from the map (if present)
744    /// and the type annotation from the unknown (if present)
745    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
746}
747
748struct TypedSubstitution {}
749
750impl SubstitutionFunction for TypedSubstitution {
751    type Err = SubstitutionError;
752
753    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
754        match (substitute, &value.type_annotation) {
755            (None, _) => Ok(Expr::unknown(value.clone())),
756            (Some(v), None) => Ok(v.clone().into()),
757            (Some(v), Some(t)) => {
758                if v.type_of() == *t {
759                    Ok(v.clone().into())
760                } else {
761                    Err(SubstitutionError::TypeError {
762                        expected: t.clone(),
763                        actual: v.type_of(),
764                    })
765                }
766            }
767        }
768    }
769}
770
771struct UntypedSubstitution {}
772
773impl SubstitutionFunction for UntypedSubstitution {
774    type Err = std::convert::Infallible;
775
776    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
777        Ok(substitute
778            .map(|v| v.clone().into())
779            .unwrap_or_else(|| Expr::unknown(value.clone())))
780    }
781}
782
783impl<T: Clone> std::fmt::Display for Expr<T> {
784    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
785        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
786        // we just convert to EST and use the EST pretty-printer.
787        // Note that converting AST->EST is lossless and infallible.
788        write!(f, "{}", crate::est::Expr::from(self.clone()))
789    }
790}
791
792impl<T: Clone> BoundedDisplay for Expr<T> {
793    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
794        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
795        // pretty-printer. Note that converting AST->EST is lossless and infallible.
796        BoundedDisplay::fmt(&crate::est::Expr::from(self.clone()), f, n)
797    }
798}
799
800impl std::str::FromStr for Expr {
801    type Err = ParseErrors;
802
803    fn from_str(s: &str) -> Result<Expr, Self::Err> {
804        crate::parser::parse_expr(s)
805    }
806}
807
/// Enum for errors encountered during substitution
/// (returned by `Expr::substitute_typed`).
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
820
/// Representation of a partial-evaluation Unknown at the AST level
///
/// Via the derived `PartialEq`/`Hash` impls, two `Unknown`s are equal (and
/// hash equally) only when both `name` and `type_annotation` match.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// The name of the unknown
    pub name: SmolStr,
    /// The type of the values that can be substituted in for the unknown.
    /// If `None`, we have no type annotation, and thus a value of any type can
    /// be substituted.
    pub type_annotation: Option<Type>,
}
831
832impl Unknown {
833    /// Create a new untyped `Unknown`
834    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
835        Self {
836            name: name.into(),
837            type_annotation: None,
838        }
839    }
840
841    /// Create a new `Unknown` with type annotation. (Only values of the given
842    /// type can be substituted.)
843    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
844        Self {
845            name: name.into(),
846            type_annotation: Some(ty),
847        }
848    }
849}
850
851impl std::fmt::Display for Unknown {
852    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
853        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
854        // to avoid code duplication
855        write!(f, "{}", crate::est::Expr::from(Expr::unknown(self.clone())))
856    }
857}
858
/// Builder for constructing `Expr` objects annotated with some `data`
/// (possibly taking default value) and optionally a `source_loc`.
///
/// Building an expression consumes the builder, and the resulting `Expr`
/// carries the builder's `source_loc` and `data`.
#[derive(Clone, Debug)]
pub struct ExprBuilder<T> {
    /// Source location attached to the expression this builder creates
    source_loc: Option<Loc>,
    /// Generic data attached to the expression this builder creates
    data: T,
}
866
867impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
868    type Expr = Expr<T>;
869
870    type Data = T;
871
872    #[cfg(feature = "tolerant-ast")]
873    type ErrorType = ParseErrors;
874
875    fn loc(&self) -> Option<&Loc> {
876        self.source_loc.as_ref()
877    }
878
879    fn data(&self) -> &Self::Data {
880        &self.data
881    }
882
883    fn with_data(data: T) -> Self {
884        Self {
885            source_loc: None,
886            data,
887        }
888    }
889
890    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
891        self.source_loc = maybe_source_loc.cloned();
892        self
893    }
894
895    /// Create an `Expr` that's just a single `Literal`.
896    ///
897    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
898    fn val(self, v: impl Into<Literal>) -> Expr<T> {
899        self.with_expr_kind(ExprKind::Lit(v.into()))
900    }
901
902    /// Create an `Unknown` `Expr`
903    fn unknown(self, u: Unknown) -> Expr<T> {
904        self.with_expr_kind(ExprKind::Unknown(u))
905    }
906
907    /// Create an `Expr` that's just this literal `Var`
908    fn var(self, v: Var) -> Expr<T> {
909        self.with_expr_kind(ExprKind::Var(v))
910    }
911
912    /// Create an `Expr` that's just this `SlotId`
913    fn slot(self, s: SlotId) -> Expr<T> {
914        self.with_expr_kind(ExprKind::Slot(s))
915    }
916
917    /// Create a ternary (if-then-else) `Expr`.
918    ///
919    /// `test_expr` must evaluate to a Bool type
920    fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
921        self.with_expr_kind(ExprKind::If {
922            test_expr: Arc::new(test_expr),
923            then_expr: Arc::new(then_expr),
924            else_expr: Arc::new(else_expr),
925        })
926    }
927
928    /// Create a 'not' expression. `e` must evaluate to Bool type
929    fn not(self, e: Expr<T>) -> Expr<T> {
930        self.with_expr_kind(ExprKind::UnaryApp {
931            op: UnaryOp::Not,
932            arg: Arc::new(e),
933        })
934    }
935
936    /// Create a '==' expression
937    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
938        self.with_expr_kind(ExprKind::BinaryApp {
939            op: BinaryOp::Eq,
940            arg1: Arc::new(e1),
941            arg2: Arc::new(e2),
942        })
943    }
944
945    /// Create an 'and' expression. Arguments must evaluate to Bool type
946    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
947        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
948            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
949                ExprKind::Lit(Literal::Bool(*b1 && *b2))
950            }
951            _ => ExprKind::And {
952                left: Arc::new(e1),
953                right: Arc::new(e2),
954            },
955        })
956    }
957
958    /// Create an 'or' expression. Arguments must evaluate to Bool type
959    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
960        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
961            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
962                ExprKind::Lit(Literal::Bool(*b1 || *b2))
963            }
964
965            _ => ExprKind::Or {
966                left: Arc::new(e1),
967                right: Arc::new(e2),
968            },
969        })
970    }
971
972    /// Create a '<' expression. Arguments must evaluate to Long type
973    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
974        self.with_expr_kind(ExprKind::BinaryApp {
975            op: BinaryOp::Less,
976            arg1: Arc::new(e1),
977            arg2: Arc::new(e2),
978        })
979    }
980
981    /// Create a '<=' expression. Arguments must evaluate to Long type
982    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
983        self.with_expr_kind(ExprKind::BinaryApp {
984            op: BinaryOp::LessEq,
985            arg1: Arc::new(e1),
986            arg2: Arc::new(e2),
987        })
988    }
989
990    /// Create an 'add' expression. Arguments must evaluate to Long type
991    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
992        self.with_expr_kind(ExprKind::BinaryApp {
993            op: BinaryOp::Add,
994            arg1: Arc::new(e1),
995            arg2: Arc::new(e2),
996        })
997    }
998
999    /// Create a 'sub' expression. Arguments must evaluate to Long type
1000    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1001        self.with_expr_kind(ExprKind::BinaryApp {
1002            op: BinaryOp::Sub,
1003            arg1: Arc::new(e1),
1004            arg2: Arc::new(e2),
1005        })
1006    }
1007
1008    /// Create a 'mul' expression. Arguments must evaluate to Long type
1009    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1010        self.with_expr_kind(ExprKind::BinaryApp {
1011            op: BinaryOp::Mul,
1012            arg1: Arc::new(e1),
1013            arg2: Arc::new(e2),
1014        })
1015    }
1016
1017    /// Create a 'neg' expression. `e` must evaluate to Long type.
1018    fn neg(self, e: Expr<T>) -> Expr<T> {
1019        self.with_expr_kind(ExprKind::UnaryApp {
1020            op: UnaryOp::Neg,
1021            arg: Arc::new(e),
1022        })
1023    }
1024
1025    /// Create an 'in' expression. First argument must evaluate to Entity type.
1026    /// Second argument must evaluate to either Entity type or Set type where
1027    /// all set elements have Entity type.
1028    fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1029        self.with_expr_kind(ExprKind::BinaryApp {
1030            op: BinaryOp::In,
1031            arg1: Arc::new(e1),
1032            arg2: Arc::new(e2),
1033        })
1034    }
1035
1036    /// Create a 'contains' expression.
1037    /// First argument must have Set type.
1038    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1039        self.with_expr_kind(ExprKind::BinaryApp {
1040            op: BinaryOp::Contains,
1041            arg1: Arc::new(e1),
1042            arg2: Arc::new(e2),
1043        })
1044    }
1045
1046    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1047    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1048        self.with_expr_kind(ExprKind::BinaryApp {
1049            op: BinaryOp::ContainsAll,
1050            arg1: Arc::new(e1),
1051            arg2: Arc::new(e2),
1052        })
1053    }
1054
1055    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1056    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1057        self.with_expr_kind(ExprKind::BinaryApp {
1058            op: BinaryOp::ContainsAny,
1059            arg1: Arc::new(e1),
1060            arg2: Arc::new(e2),
1061        })
1062    }
1063
1064    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1065    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1066        self.with_expr_kind(ExprKind::UnaryApp {
1067            op: UnaryOp::IsEmpty,
1068            arg: Arc::new(expr),
1069        })
1070    }
1071
1072    /// Create a 'getTag' expression.
1073    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1074    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1075        self.with_expr_kind(ExprKind::BinaryApp {
1076            op: BinaryOp::GetTag,
1077            arg1: Arc::new(expr),
1078            arg2: Arc::new(tag),
1079        })
1080    }
1081
1082    /// Create a 'hasTag' expression.
1083    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1084    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1085        self.with_expr_kind(ExprKind::BinaryApp {
1086            op: BinaryOp::HasTag,
1087            arg1: Arc::new(expr),
1088            arg2: Arc::new(tag),
1089        })
1090    }
1091
1092    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1093    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1094        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1095    }
1096
1097    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1098    fn record(
1099        self,
1100        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1101    ) -> Result<Expr<T>, ExpressionConstructionError> {
1102        let mut map = BTreeMap::new();
1103        for (k, v) in pairs {
1104            match map.entry(k) {
1105                btree_map::Entry::Occupied(oentry) => {
1106                    return Err(expression_construction_errors::DuplicateKeyError {
1107                        key: oentry.key().clone(),
1108                        context: "in record literal",
1109                    }
1110                    .into());
1111                }
1112                btree_map::Entry::Vacant(ventry) => {
1113                    ventry.insert(v);
1114                }
1115            }
1116        }
1117        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1118    }
1119
1120    /// Create an `Expr` which calls the extension function with the given
1121    /// `Name` on `args`
1122    fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1123        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1124            fn_name,
1125            args: Arc::new(args.into_iter().collect()),
1126        })
1127    }
1128
1129    /// Create an application `Expr` which applies the given built-in unary
1130    /// operator to the given `arg`
1131    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1132        self.with_expr_kind(ExprKind::UnaryApp {
1133            op: op.into(),
1134            arg: Arc::new(arg),
1135        })
1136    }
1137
1138    /// Create an application `Expr` which applies the given built-in binary
1139    /// operator to `arg1` and `arg2`
1140    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1141        self.with_expr_kind(ExprKind::BinaryApp {
1142            op: op.into(),
1143            arg1: Arc::new(arg1),
1144            arg2: Arc::new(arg2),
1145        })
1146    }
1147
1148    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1149    ///
1150    /// `expr` must evaluate to either Entity or Record type
1151    fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1152        self.with_expr_kind(ExprKind::GetAttr {
1153            expr: Arc::new(expr),
1154            attr,
1155        })
1156    }
1157
1158    /// Create an `Expr` which tests for the existence of a given
1159    /// attribute on a given `Entity` or record.
1160    ///
1161    /// `expr` must evaluate to either Entity or Record type
1162    fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1163        self.with_expr_kind(ExprKind::HasAttr {
1164            expr: Arc::new(expr),
1165            attr,
1166        })
1167    }
1168
1169    /// Create a 'like' expression.
1170    ///
1171    /// `expr` must evaluate to a String type
1172    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1173        self.with_expr_kind(ExprKind::Like {
1174            expr: Arc::new(expr),
1175            pattern,
1176        })
1177    }
1178
1179    /// Create an 'is' expression.
1180    fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1181        self.with_expr_kind(ExprKind::Is {
1182            expr: Arc::new(expr),
1183            entity_type,
1184        })
1185    }
1186
1187    /// Don't support AST Error nodes - return the error right back
1188    #[cfg(feature = "tolerant-ast")]
1189    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1190        Err(parse_errors)
1191    }
1192}
1193
1194impl<T> ExprBuilder<T> {
1195    /// Construct an `Expr` containing the `data` and `source_loc` in this
1196    /// `ExprBuilder` and the given `ExprKind`.
1197    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1198        Expr::new(expr_kind, self.source_loc, self.data)
1199    }
1200
1201    /// Create a ternary (if-then-else) `Expr`.
1202    /// Takes `Arc`s instead of owned `Expr`s.
1203    /// `test_expr` must evaluate to a Bool type
1204    pub fn ite_arc(
1205        self,
1206        test_expr: Arc<Expr<T>>,
1207        then_expr: Arc<Expr<T>>,
1208        else_expr: Arc<Expr<T>>,
1209    ) -> Expr<T> {
1210        self.with_expr_kind(ExprKind::If {
1211            test_expr,
1212            then_expr,
1213            else_expr,
1214        })
1215    }
1216
1217    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1218    ///
1219    /// If you have an iterator of pairs, generally prefer calling `.record()`
1220    /// instead of `.collect()`-ing yourself and calling this, potentially for
1221    /// efficiency reasons but also because `.record()` will properly handle
1222    /// duplicate keys but your own `.collect()` will not (by default).
1223    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1224        self.with_expr_kind(ExprKind::Record(map))
1225    }
1226}
1227
1228impl<T: Clone + Default> ExprBuilder<T> {
1229    /// Utility used the validator to get an expression with the same source
1230    /// location as an existing expression. This is done when reconstructing the
1231    /// `Expr` with type information.
1232    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1233        self.with_maybe_source_loc(expr.source_loc.as_ref())
1234    }
1235}
1236
/// Errors when constructing an expression
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times
    // Both the error message and the diagnostic are forwarded unchanged from
    // the wrapped `DuplicateKeyError` (`transparent`).
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1249
1250/// Error subtypes for [`ExpressionConstructionError`]
1251pub mod expression_construction_errors {
1252    use miette::Diagnostic;
1253    use smol_str::SmolStr;
1254    use thiserror::Error;
1255
1256    /// The same key occurred two or more times
1257    //
1258    // CAUTION: this type is publicly exported in `cedar-policy`.
1259    // Don't make fields `pub`, don't make breaking changes, and use caution
1260    // when adding public methods.
1261    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1262    #[error("duplicate key `{key}` {context}")]
1263    pub struct DuplicateKeyError {
1264        /// The key which occurred two or more times
1265        pub(crate) key: SmolStr,
1266        /// Information about where the duplicate key occurred (e.g., "in record literal")
1267        pub(crate) context: &'static str,
1268    }
1269
1270    impl DuplicateKeyError {
1271        /// Get the key which occurred two or more times
1272        pub fn key(&self) -> &str {
1273            &self.key
1274        }
1275
1276        /// Make a new error with an updated `context` field
1277        pub(crate) fn with_context(self, context: &'static str) -> Self {
1278            Self { context, ..self }
1279        }
1280    }
1281}
1282
1283/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1284/// implementations that ignore any source information or other generic data
1285/// used to annotate the `Expr`.
1286#[derive(Eq, Debug, Clone)]
1287pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1288
1289impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1290    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1291    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1292    /// ignore source information and generic data.
1293    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1294        ExprShapeOnly(Cow::Borrowed(e))
1295    }
1296
1297    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1298    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1299    /// ignore source information and generic data.
1300    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1301        ExprShapeOnly(Cow::Owned(e))
1302    }
1303}
1304
1305impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1306    fn eq(&self, other: &Self) -> bool {
1307        self.0.eq_shape(&other.0)
1308    }
1309}
1310
1311impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1312    fn hash<H: Hasher>(&self, state: &mut H) {
1313        self.0.hash_shape(state);
1314    }
1315}
1316
1317impl<T> Expr<T> {
1318    /// Return true if this expression (recursively) has the same expression
1319    /// kind as the argument expression. This accounts for the full recursive
1320    /// shape of the expression, but does not consider source information or any
1321    /// generic data annotated on expression. This should behave the same as the
1322    /// default implementation of `Eq` before source information and generic
1323    /// data were added.
1324    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1325        use ExprKind::*;
1326        match (self.expr_kind(), other.expr_kind()) {
1327            (Lit(lit), Lit(lit1)) => lit == lit1,
1328            (Var(v), Var(v1)) => v == v1,
1329            (Slot(s), Slot(s1)) => s == s1,
1330            (
1331                Unknown(self::Unknown {
1332                    name: name1,
1333                    type_annotation: ta_1,
1334                }),
1335                Unknown(self::Unknown {
1336                    name: name2,
1337                    type_annotation: ta_2,
1338                }),
1339            ) => (name1 == name2) && (ta_1 == ta_2),
1340            (
1341                If {
1342                    test_expr,
1343                    then_expr,
1344                    else_expr,
1345                },
1346                If {
1347                    test_expr: test_expr1,
1348                    then_expr: then_expr1,
1349                    else_expr: else_expr1,
1350                },
1351            ) => {
1352                test_expr.eq_shape(test_expr1)
1353                    && then_expr.eq_shape(then_expr1)
1354                    && else_expr.eq_shape(else_expr1)
1355            }
1356            (
1357                And { left, right },
1358                And {
1359                    left: left1,
1360                    right: right1,
1361                },
1362            )
1363            | (
1364                Or { left, right },
1365                Or {
1366                    left: left1,
1367                    right: right1,
1368                },
1369            ) => left.eq_shape(left1) && right.eq_shape(right1),
1370            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1371                op == op1 && arg.eq_shape(arg1)
1372            }
1373            (
1374                BinaryApp { op, arg1, arg2 },
1375                BinaryApp {
1376                    op: op1,
1377                    arg1: arg11,
1378                    arg2: arg21,
1379                },
1380            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1381            (
1382                ExtensionFunctionApp { fn_name, args },
1383                ExtensionFunctionApp {
1384                    fn_name: fn_name1,
1385                    args: args1,
1386                },
1387            ) => {
1388                fn_name == fn_name1
1389                    && args.len() == args1.len()
1390                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1391            }
1392            (
1393                GetAttr { expr, attr },
1394                GetAttr {
1395                    expr: expr1,
1396                    attr: attr1,
1397                },
1398            )
1399            | (
1400                HasAttr { expr, attr },
1401                HasAttr {
1402                    expr: expr1,
1403                    attr: attr1,
1404                },
1405            ) => attr == attr1 && expr.eq_shape(expr1),
1406            (
1407                Like { expr, pattern },
1408                Like {
1409                    expr: expr1,
1410                    pattern: pattern1,
1411                },
1412            ) => pattern == pattern1 && expr.eq_shape(expr1),
1413            (Set(elems), Set(elems1)) => {
1414                elems.len() == elems1.len()
1415                    && elems
1416                        .iter()
1417                        .zip(elems1.iter())
1418                        .all(|(e, e1)| e.eq_shape(e1))
1419            }
1420            (Record(map), Record(map1)) => {
1421                map.len() == map1.len()
1422                    && map
1423                        .iter()
1424                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1425                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1426            }
1427            (
1428                Is { expr, entity_type },
1429                Is {
1430                    expr: expr1,
1431                    entity_type: entity_type1,
1432                },
1433            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1434            _ => false,
1435        }
1436    }
1437
1438    /// Implementation of hashing corresponding to equality as implemented by
1439    /// `eq_shape`. Must satisfy the usual relationship between equality and
1440    /// hashing.
1441    pub fn hash_shape<H>(&self, state: &mut H)
1442    where
1443        H: Hasher,
1444    {
1445        mem::discriminant(self).hash(state);
1446        match self.expr_kind() {
1447            ExprKind::Lit(lit) => lit.hash(state),
1448            ExprKind::Var(v) => v.hash(state),
1449            ExprKind::Slot(s) => s.hash(state),
1450            ExprKind::Unknown(u) => u.hash(state),
1451            ExprKind::If {
1452                test_expr,
1453                then_expr,
1454                else_expr,
1455            } => {
1456                test_expr.hash_shape(state);
1457                then_expr.hash_shape(state);
1458                else_expr.hash_shape(state);
1459            }
1460            ExprKind::And { left, right } => {
1461                left.hash_shape(state);
1462                right.hash_shape(state);
1463            }
1464            ExprKind::Or { left, right } => {
1465                left.hash_shape(state);
1466                right.hash_shape(state);
1467            }
1468            ExprKind::UnaryApp { op, arg } => {
1469                op.hash(state);
1470                arg.hash_shape(state);
1471            }
1472            ExprKind::BinaryApp { op, arg1, arg2 } => {
1473                op.hash(state);
1474                arg1.hash_shape(state);
1475                arg2.hash_shape(state);
1476            }
1477            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1478                fn_name.hash(state);
1479                state.write_usize(args.len());
1480                args.iter().for_each(|a| {
1481                    a.hash_shape(state);
1482                });
1483            }
1484            ExprKind::GetAttr { expr, attr } => {
1485                expr.hash_shape(state);
1486                attr.hash(state);
1487            }
1488            ExprKind::HasAttr { expr, attr } => {
1489                expr.hash_shape(state);
1490                attr.hash(state);
1491            }
1492            ExprKind::Like { expr, pattern } => {
1493                expr.hash_shape(state);
1494                pattern.hash(state);
1495            }
1496            ExprKind::Set(elems) => {
1497                state.write_usize(elems.len());
1498                elems.iter().for_each(|e| {
1499                    e.hash_shape(state);
1500                })
1501            }
1502            ExprKind::Record(map) => {
1503                state.write_usize(map.len());
1504                map.iter().for_each(|(s, a)| {
1505                    s.hash(state);
1506                    a.hash_shape(state);
1507                });
1508            }
1509            ExprKind::Is { expr, entity_type } => {
1510                expr.hash_shape(state);
1511                entity_type.hash(state);
1512            }
1513            #[cfg(feature = "tolerant-ast")]
1514            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1515        }
1516    }
1517}
1518
/// AST variables
///
/// Their `Display` form (see the `Display` impl) and their serde names (via
/// `rename_all = "camelCase"`) are the lowercase keywords, e.g. `principal`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// the Principal of the given request
    Principal,
    /// the Action of the given request
    Action,
    /// the Resource of the given request
    Resource,
    /// the Context of the given request
    Context,
}
1535
1536impl From<PrincipalOrResource> for Var {
1537    fn from(v: PrincipalOrResource) -> Self {
1538        match v {
1539            PrincipalOrResource::Principal => Var::Principal,
1540            PrincipalOrResource::Resource => Var::Resource,
1541        }
1542    }
1543}
1544
1545// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1546#[allow(clippy::fallible_impl_from)]
1547impl From<Var> for Id {
1548    fn from(var: Var) -> Self {
1549        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1550        #[allow(clippy::unwrap_used)]
1551        format!("{var}").parse().unwrap()
1552    }
1553}
1554
1555// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1556#[allow(clippy::fallible_impl_from)]
1557impl From<Var> for UnreservedId {
1558    fn from(var: Var) -> Self {
1559        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`
1560        #[allow(clippy::unwrap_used)]
1561        Id::from(var).try_into().unwrap()
1562    }
1563}
1564
1565impl std::fmt::Display for Var {
1566    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1567        match self {
1568            Self::Principal => write!(f, "principal"),
1569            Self::Action => write!(f, "action"),
1570            Self::Resource => write!(f, "resource"),
1571            Self::Context => write!(f, "context"),
1572        }
1573    }
1574}
1575
1576#[cfg(test)]
1577mod test {
1578    use cool_asserts::assert_matches;
1579    use itertools::Itertools;
1580    use smol_str::ToSmolStr;
1581    use std::collections::{hash_map::DefaultHasher, HashSet};
1582
1583    use crate::expr_builder::ExprBuilder as _;
1584
1585    use super::*;
1586
1587    pub fn all_vars() -> impl Iterator<Item = Var> {
1588        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
1589    }
1590
1591    // Tests that Var::Into never panics
1592    #[test]
1593    fn all_vars_are_ids() {
1594        for var in all_vars() {
1595            let _id: Id = var.into();
1596            let _id: UnreservedId = var.into();
1597        }
1598    }
1599
1600    #[test]
1601    fn exprs() {
1602        assert_eq!(
1603            Expr::val(33),
1604            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1605        );
1606        assert_eq!(
1607            Expr::val("hello"),
1608            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1609        );
1610        assert_eq!(
1611            Expr::val(EntityUID::with_eid("foo")),
1612            Expr::new(
1613                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1614                None,
1615                ()
1616            )
1617        );
1618        assert_eq!(
1619            Expr::var(Var::Principal),
1620            Expr::new(ExprKind::Var(Var::Principal), None, ())
1621        );
1622        assert_eq!(
1623            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1624            Expr::new(
1625                ExprKind::If {
1626                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1627                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1628                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1629                },
1630                None,
1631                ()
1632            )
1633        );
1634        assert_eq!(
1635            Expr::not(Expr::val(false)),
1636            Expr::new(
1637                ExprKind::UnaryApp {
1638                    op: UnaryOp::Not,
1639                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1640                },
1641                None,
1642                ()
1643            )
1644        );
1645        assert_eq!(
1646            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1647            Expr::new(
1648                ExprKind::GetAttr {
1649                    expr: Arc::new(Expr::new(
1650                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1651                        None,
1652                        ()
1653                    )),
1654                    attr: "some_attr".into()
1655                },
1656                None,
1657                ()
1658            )
1659        );
1660        assert_eq!(
1661            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1662            Expr::new(
1663                ExprKind::HasAttr {
1664                    expr: Arc::new(Expr::new(
1665                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1666                        None,
1667                        ()
1668                    )),
1669                    attr: "some_attr".into()
1670                },
1671                None,
1672                ()
1673            )
1674        );
1675        assert_eq!(
1676            Expr::is_entity_type(
1677                Expr::val(EntityUID::with_eid("foo")),
1678                "Type".parse().unwrap()
1679            ),
1680            Expr::new(
1681                ExprKind::Is {
1682                    expr: Arc::new(Expr::new(
1683                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1684                        None,
1685                        ()
1686                    )),
1687                    entity_type: "Type".parse().unwrap()
1688                },
1689                None,
1690                ()
1691            ),
1692        );
1693    }
1694
1695    #[test]
1696    fn like_display() {
1697        // `\0` escaped form is `\0`.
1698        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
1699        assert_eq!(format!("{e}"), r#""a" like "\0""#);
1700        // `\`'s escaped form is `\\`
1701        let e = Expr::like(
1702            Expr::val("a"),
1703            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
1704        );
1705        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
1706        // `\`'s escaped form is `\\`
1707        let e = Expr::like(
1708            Expr::val("a"),
1709            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
1710        );
1711        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
1712        // literal star's escaped from is `\*`
1713        let e = Expr::like(
1714            Expr::val("a"),
1715            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
1716        );
1717        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
1718    }
1719
1720    #[test]
1721    fn has_display() {
1722        // `\0` escaped form is `\0`.
1723        let e = Expr::has_attr(Expr::val("a"), "\0".into());
1724        assert_eq!(format!("{e}"), r#""a" has "\0""#);
1725        // `\`'s escaped form is `\\`
1726        let e = Expr::has_attr(Expr::val("a"), r"\".into());
1727        assert_eq!(format!("{e}"), r#""a" has "\\""#);
1728    }
1729
1730    #[test]
1731    fn slot_display() {
1732        let e = Expr::slot(SlotId::principal());
1733        assert_eq!(format!("{e}"), "?principal");
1734        let e = Expr::slot(SlotId::resource());
1735        assert_eq!(format!("{e}"), "?resource");
1736        let e = Expr::val(EntityUID::with_eid("eid"));
1737        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
1738    }
1739
1740    #[test]
1741    fn simple_slots() {
1742        let e = Expr::slot(SlotId::principal());
1743        let p = SlotId::principal();
1744        let r = SlotId::resource();
1745        let set: HashSet<SlotId> = HashSet::from_iter([p]);
1746        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1747        let e = Expr::or(
1748            Expr::slot(SlotId::principal()),
1749            Expr::ite(
1750                Expr::val(true),
1751                Expr::slot(SlotId::resource()),
1752                Expr::val(false),
1753            ),
1754        );
1755        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
1756        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1757    }
1758
1759    #[test]
1760    fn unknowns() {
1761        let e = Expr::ite(
1762            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1763            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1764            Expr::unknown(Unknown::new_untyped("c")),
1765        );
1766        let unknowns = e.unknowns().collect_vec();
1767        assert_eq!(unknowns.len(), 3);
1768        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
1769        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
1770        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
1771    }
1772
1773    #[test]
1774    fn is_unknown() {
1775        let e = Expr::ite(
1776            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1777            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1778            Expr::unknown(Unknown::new_untyped("c")),
1779        );
1780        assert!(e.contains_unknown());
1781        let e = Expr::ite(
1782            Expr::not(Expr::val(true)),
1783            Expr::and(Expr::val(1), Expr::val(3)),
1784            Expr::val(1),
1785        );
1786        assert!(!e.contains_unknown());
1787    }
1788
1789    #[test]
1790    fn expr_with_data() {
1791        let e = ExprBuilder::with_data("data").val(1);
1792        assert_eq!(e.into_data(), "data");
1793    }
1794
1795    #[test]
1796    fn expr_shape_only_eq() {
1797        let temp = ExprBuilder::with_data(1).val(1);
1798        let exprs = &[
1799            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
1800            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
1801            (
1802                ExprBuilder::with_data(1).var(Var::Principal),
1803                Expr::var(Var::Principal),
1804            ),
1805            (
1806                ExprBuilder::with_data(1).slot(SlotId::principal()),
1807                Expr::slot(SlotId::principal()),
1808            ),
1809            (
1810                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
1811                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
1812            ),
1813            (
1814                ExprBuilder::with_data(1).not(temp.clone()),
1815                Expr::not(Expr::val(1)),
1816            ),
1817            (
1818                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
1819                Expr::is_eq(Expr::val(1), Expr::val(1)),
1820            ),
1821            (
1822                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
1823                Expr::and(Expr::val(1), Expr::val(1)),
1824            ),
1825            (
1826                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
1827                Expr::or(Expr::val(1), Expr::val(1)),
1828            ),
1829            (
1830                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
1831                Expr::less(Expr::val(1), Expr::val(1)),
1832            ),
1833            (
1834                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
1835                Expr::lesseq(Expr::val(1), Expr::val(1)),
1836            ),
1837            (
1838                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
1839                Expr::greater(Expr::val(1), Expr::val(1)),
1840            ),
1841            (
1842                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
1843                Expr::greatereq(Expr::val(1), Expr::val(1)),
1844            ),
1845            (
1846                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
1847                Expr::add(Expr::val(1), Expr::val(1)),
1848            ),
1849            (
1850                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
1851                Expr::sub(Expr::val(1), Expr::val(1)),
1852            ),
1853            (
1854                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
1855                Expr::mul(Expr::val(1), Expr::val(1)),
1856            ),
1857            (
1858                ExprBuilder::with_data(1).neg(temp.clone()),
1859                Expr::neg(Expr::val(1)),
1860            ),
1861            (
1862                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
1863                Expr::is_in(Expr::val(1), Expr::val(1)),
1864            ),
1865            (
1866                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
1867                Expr::contains(Expr::val(1), Expr::val(1)),
1868            ),
1869            (
1870                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
1871                Expr::contains_all(Expr::val(1), Expr::val(1)),
1872            ),
1873            (
1874                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
1875                Expr::contains_any(Expr::val(1), Expr::val(1)),
1876            ),
1877            (
1878                ExprBuilder::with_data(1).is_empty(temp.clone()),
1879                Expr::is_empty(Expr::val(1)),
1880            ),
1881            (
1882                ExprBuilder::with_data(1).set([temp.clone()]),
1883                Expr::set([Expr::val(1)]),
1884            ),
1885            (
1886                ExprBuilder::with_data(1)
1887                    .record([("foo".into(), temp.clone())])
1888                    .unwrap(),
1889                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
1890            ),
1891            (
1892                ExprBuilder::with_data(1)
1893                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
1894                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
1895            ),
1896            (
1897                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
1898                Expr::get_attr(Expr::val(1), "foo".into()),
1899            ),
1900            (
1901                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
1902                Expr::has_attr(Expr::val(1), "foo".into()),
1903            ),
1904            (
1905                ExprBuilder::with_data(1)
1906                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
1907                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
1908            ),
1909            (
1910                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
1911                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
1912            ),
1913        ];
1914
1915        for (e0, e1) in exprs {
1916            assert!(e0.eq_shape(e0));
1917            assert!(e1.eq_shape(e1));
1918            assert!(e0.eq_shape(e1));
1919            assert!(e1.eq_shape(e0));
1920
1921            let mut hasher0 = DefaultHasher::new();
1922            e0.hash_shape(&mut hasher0);
1923            let hash0 = hasher0.finish();
1924
1925            let mut hasher1 = DefaultHasher::new();
1926            e1.hash_shape(&mut hasher1);
1927            let hash1 = hasher1.finish();
1928
1929            assert_eq!(hash0, hash1);
1930        }
1931    }
1932
1933    #[test]
1934    fn expr_shape_only_not_eq() {
1935        let expr1 = ExprBuilder::with_data(1).val(1);
1936        let expr2 = ExprBuilder::with_data(1).val(2);
1937        assert_ne!(
1938            ExprShapeOnly::new_from_borrowed(&expr1),
1939            ExprShapeOnly::new_from_borrowed(&expr2)
1940        );
1941    }
1942
1943    #[test]
1944    fn expr_shape_only_set_prefix_ne() {
1945        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
1946        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
1947        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));
1948
1949        assert_ne!(e1, e2);
1950        assert_ne!(e1, e3);
1951        assert_ne!(e2, e1);
1952        assert_ne!(e2, e3);
1953        assert_ne!(e3, e1);
1954        assert_ne!(e2, e1);
1955    }
1956
1957    #[test]
1958    fn expr_shape_only_ext_fn_arg_prefix_ne() {
1959        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
1960            "decimal".parse().unwrap(),
1961            vec![],
1962        ));
1963        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
1964            "decimal".parse().unwrap(),
1965            vec![Expr::val("0.0")],
1966        ));
1967        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
1968            "decimal".parse().unwrap(),
1969            vec![Expr::val("0.0"), Expr::val("0.0")],
1970        ));
1971
1972        assert_ne!(e1, e2);
1973        assert_ne!(e1, e3);
1974        assert_ne!(e2, e1);
1975        assert_ne!(e2, e3);
1976        assert_ne!(e3, e1);
1977        assert_ne!(e2, e1);
1978    }
1979
1980    #[test]
1981    fn expr_shape_only_record_attr_prefix_ne() {
1982        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
1983        let e2 = ExprShapeOnly::new_from_owned(
1984            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
1985        );
1986        let e3 = ExprShapeOnly::new_from_owned(
1987            Expr::record([
1988                ("a".to_smolstr(), Expr::val(1)),
1989                ("b".to_smolstr(), Expr::val(2)),
1990            ])
1991            .unwrap(),
1992        );
1993
1994        assert_ne!(e1, e2);
1995        assert_ne!(e1, e3);
1996        assert_ne!(e2, e1);
1997        assert_ne!(e2, e3);
1998        assert_ne!(e3, e1);
1999        assert_ne!(e2, e1);
2000    }
2001
2002    #[test]
2003    fn untyped_subst_present() {
2004        let u = Unknown {
2005            name: "foo".into(),
2006            type_annotation: None,
2007        };
2008        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2009        match r {
2010            Ok(e) => assert_eq!(e, Expr::val(1)),
2011            Err(empty) => match empty {},
2012        }
2013    }
2014
2015    #[test]
2016    fn untyped_subst_present_correct_type() {
2017        let u = Unknown {
2018            name: "foo".into(),
2019            type_annotation: Some(Type::Long),
2020        };
2021        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2022        match r {
2023            Ok(e) => assert_eq!(e, Expr::val(1)),
2024            Err(empty) => match empty {},
2025        }
2026    }
2027
2028    #[test]
2029    fn untyped_subst_present_wrong_type() {
2030        let u = Unknown {
2031            name: "foo".into(),
2032            type_annotation: Some(Type::Bool),
2033        };
2034        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2035        match r {
2036            Ok(e) => assert_eq!(e, Expr::val(1)),
2037            Err(empty) => match empty {},
2038        }
2039    }
2040
2041    #[test]
2042    fn untyped_subst_not_present() {
2043        let u = Unknown {
2044            name: "foo".into(),
2045            type_annotation: Some(Type::Bool),
2046        };
2047        let r = UntypedSubstitution::substitute(&u, None);
2048        match r {
2049            Ok(n) => assert_eq!(n, Expr::unknown(u)),
2050            Err(empty) => match empty {},
2051        }
2052    }
2053
2054    #[test]
2055    fn typed_subst_present() {
2056        let u = Unknown {
2057            name: "foo".into(),
2058            type_annotation: None,
2059        };
2060        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2061        assert_eq!(e, Expr::val(1));
2062    }
2063
2064    #[test]
2065    fn typed_subst_present_correct_type() {
2066        let u = Unknown {
2067            name: "foo".into(),
2068            type_annotation: Some(Type::Long),
2069        };
2070        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2071        assert_eq!(e, Expr::val(1));
2072    }
2073
2074    #[test]
2075    fn typed_subst_present_wrong_type() {
2076        let u = Unknown {
2077            name: "foo".into(),
2078            type_annotation: Some(Type::Bool),
2079        };
2080        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
2081        assert_matches!(
2082            r,
2083            SubstitutionError::TypeError {
2084                expected: Type::Bool,
2085                actual: Type::Long,
2086            }
2087        );
2088    }
2089
2090    #[test]
2091    fn typed_subst_not_present() {
2092        let u = Unknown {
2093            name: "foo".into(),
2094            type_annotation: None,
2095        };
2096        let r = TypedSubstitution::substitute(&u, None).unwrap();
2097        assert_eq!(r, Expr::unknown(u));
2098    }
2099}