cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#[cfg(feature = "tolerant-ast")]
18use {
19    super::expr_allows_errors::AstExprErrorKind, crate::parser::err::ToASTError,
20    crate::parser::err::ToASTErrorKind,
21};
22
23use crate::{
24    ast::*,
25    expr_builder::{self, ExprBuilder as _},
26    extensions::Extensions,
27    parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34    borrow::Cow,
35    collections::{btree_map, BTreeMap, HashMap},
36    hash::{Hash, Hasher},
37    mem,
38    sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
/// Internal AST for expressions used by the policy evaluator.
///
/// This structure is a wrapper around an [`ExprKind`], which is the expression
/// variant this object contains. It also contains source information about
/// where the expression was written in policy source code, and some generic
/// data (of type `T`, defaulting to `()`) which is stored on each node of the
/// AST.
///
/// `PartialEq`, `Eq` and `Hash` are derived through `educe` so that they
/// consider `expr_kind` and `data` but ignore `source_loc`: two expressions
/// that differ only in where they were written compare equal and hash alike.
///
/// Cloning is O(1) in the size of the expression tree: subexpressions are
/// stored behind `Arc`s (see [`ExprKind`]) and are shared, not deep-copied.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Expr<T = ()> {
    /// The expression variant held by this node.
    expr_kind: ExprKind<T>,
    /// Location of this expression in policy source, if known.
    /// Excluded from equality and hashing.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    source_loc: Option<Loc>,
    /// Generic per-node data; `()` for plain expressions.
    data: T,
}
60
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// Child expressions are held behind `Arc`s, so cloning an `ExprKind` shares
/// its subexpressions rather than copying them.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template Slots
    Slot(SlotId),
    /// Symbolic Unknown for partial-eval
    Unknown(Unknown),
    /// Ternary (`if`/`then`/`else`) expression
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if the condition is true
        then_expr: Arc<Expr<T>>,
        /// Value if the condition is false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First arg
        arg1: Arc<Expr<T>>,
        /// Second arg
        arg2: Arc<Expr<T>>,
    },
    /// Application of an extension function to n arguments
    /// INVARIANT (MethodStyleArgs):
    ///   if op.style is MethodStyle then args _cannot_ be empty.
    ///     The first element of args refers to the subject of the method call
    /// Ideally, we find some way to make this non-representable.
    ExtensionFunctionApp {
        /// Extension function to apply
        fn_name: Name,
        /// Args to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Regex-like string matching similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard `*`, which matches any
        /// string. To match a literal `*` in the test expression, users can use
        /// `\*`. Be careful that the backslash in `\*` is not part of another
        /// escape sequence: for instance, `\\*` matches a backslash followed by
        /// an arbitrary string.
        pattern: Pattern,
    },
    /// Entity type test. Does the first argument have the entity type
    /// specified by the second argument.
    Is {
        /// Expression to test. Must evaluate to an Entity.
        expr: Arc<Expr<T>>,
        /// The [`EntityType`] used for the type membership test.
        entity_type: EntityType,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
    // that are syntactically unequal, may actually be semantically equal --
    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
    // evaluated into `Value`s
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose elements may be arbitrary expressions)
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
    #[cfg(feature = "tolerant-ast")]
    /// Error expression - allows us to continue parsing even when we have errors
    Error {
        /// Type of error that led to the failure
        error_kind: AstExprErrorKind,
    },
}
171
172impl From<Value> for Expr {
173    fn from(v: Value) -> Self {
174        Expr::from(v.value).with_maybe_source_loc(v.loc)
175    }
176}
177
178impl From<ValueKind> for Expr {
179    fn from(v: ValueKind) -> Self {
180        match v {
181            ValueKind::Lit(lit) => Expr::val(lit),
182            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
183            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
184            #[allow(clippy::expect_used)]
185            ValueKind::Record(record) => Expr::record(
186                Arc::unwrap_or_clone(record)
187                    .into_iter()
188                    .map(|(k, v)| (k, Expr::from(v))),
189            )
190            .expect("cannot have duplicate key because the input was already a BTreeMap"),
191            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
192        }
193    }
194}
195
196impl From<PartialValue> for Expr {
197    fn from(pv: PartialValue) -> Self {
198        match pv {
199            PartialValue::Value(v) => Expr::from(v),
200            PartialValue::Residual(expr) => expr,
201        }
202    }
203}
204
205impl<T> Expr<T> {
206    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
207        Self {
208            expr_kind,
209            source_loc,
210            data,
211        }
212    }
213
214    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
215    /// `enum` which specifies the expression variant, so it must be accessed by
216    /// any code matching and recursing on an expression.
217    pub fn expr_kind(&self) -> &ExprKind<T> {
218        &self.expr_kind
219    }
220
221    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
222    pub fn into_expr_kind(self) -> ExprKind<T> {
223        self.expr_kind
224    }
225
226    /// Access the data stored on the `Expr`.
227    pub fn data(&self) -> &T {
228        &self.data
229    }
230
231    /// Access the data stored on the `Expr`, taking ownership and consuming the
232    /// `Expr`.
233    pub fn into_data(self) -> T {
234        self.data
235    }
236
237    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
238    /// data.
239    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
240        (self.expr_kind, self.source_loc, self.data)
241    }
242
243    /// Access the `Loc` stored on the `Expr`.
244    pub fn source_loc(&self) -> Option<&Loc> {
245        self.source_loc.as_ref()
246    }
247
248    /// Return the `Expr`, but with the new `source_loc` (or `None`).
249    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
250        Self { source_loc, ..self }
251    }
252
253    /// Update the data for this `Expr`. A convenient function used by the
254    /// Validator in one place.
255    pub fn set_data(&mut self, data: T) {
256        self.data = data;
257    }
258
259    /// Check whether this expression is an entity reference
260    ///
261    /// This is used for policy scopes, where some syntax is
262    /// required to be an entity reference.
263    pub fn is_ref(&self) -> bool {
264        match &self.expr_kind {
265            ExprKind::Lit(lit) => lit.is_ref(),
266            _ => false,
267        }
268    }
269
270    /// Check whether this expression is a slot.
271    pub fn is_slot(&self) -> bool {
272        matches!(&self.expr_kind, ExprKind::Slot(_))
273    }
274
275    /// Check whether this expression is a set of entity references
276    ///
277    /// This is used for policy scopes, where some syntax is
278    /// required to be an entity reference set.
279    pub fn is_ref_set(&self) -> bool {
280        match &self.expr_kind {
281            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
282            _ => false,
283        }
284    }
285
286    /// Iterate over all sub-expressions in this expression
287    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
288        expr_iterator::ExprIterator::new(self)
289    }
290
291    /// Iterate over all of the slots in this policy AST
292    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
293        self.subexpressions()
294            .filter_map(|exp| match &exp.expr_kind {
295                ExprKind::Slot(slotid) => Some(Slot {
296                    id: *slotid,
297                    loc: exp.source_loc().cloned(),
298                }),
299                _ => None,
300            })
301    }
302
303    /// Determine if the expression is projectable under partial evaluation
304    /// An expression is projectable if it's guaranteed to never error on evaluation
305    /// This is true if the expression is entirely composed of values or unknowns
306    pub fn is_projectable(&self) -> bool {
307        self.subexpressions().all(|e| {
308            matches!(
309                e.expr_kind(),
310                ExprKind::Lit(_)
311                    | ExprKind::Unknown(_)
312                    | ExprKind::Set(_)
313                    | ExprKind::Var(_)
314                    | ExprKind::Record(_)
315            )
316        })
317    }
318
319    /// Try to compute the runtime type of this expression. This operation may
320    /// fail (returning `None`), for example, when asked to get the type of any
321    /// variables, any attributes of entities or records, or an `unknown`
322    /// without an explicitly annotated type.
323    ///
324    /// Also note that this is _not_ typechecking the expression. It does not
325    /// check that the expression actually evaluates to a value (as opposed to
326    /// erroring).
327    ///
328    /// Because of these limitations, this function should only be used to
329    /// obtain a type for use in diagnostics such as error strings.
330    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
331        match &self.expr_kind {
332            ExprKind::Lit(l) => Some(l.type_of()),
333            ExprKind::Var(_) => None,
334            ExprKind::Slot(_) => None,
335            ExprKind::Unknown(u) => u.type_annotation.clone(),
336            ExprKind::If {
337                then_expr,
338                else_expr,
339                ..
340            } => {
341                let type_of_then = then_expr.try_type_of(extensions);
342                let type_of_else = else_expr.try_type_of(extensions);
343                if type_of_then == type_of_else {
344                    type_of_then
345                } else {
346                    None
347                }
348            }
349            ExprKind::And { .. } => Some(Type::Bool),
350            ExprKind::Or { .. } => Some(Type::Bool),
351            ExprKind::UnaryApp {
352                op: UnaryOp::Neg, ..
353            } => Some(Type::Long),
354            ExprKind::UnaryApp {
355                op: UnaryOp::Not, ..
356            } => Some(Type::Bool),
357            ExprKind::UnaryApp {
358                op: UnaryOp::IsEmpty,
359                ..
360            } => Some(Type::Bool),
361            ExprKind::BinaryApp {
362                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
363                ..
364            } => Some(Type::Long),
365            ExprKind::BinaryApp {
366                op:
367                    BinaryOp::Contains
368                    | BinaryOp::ContainsAll
369                    | BinaryOp::ContainsAny
370                    | BinaryOp::Eq
371                    | BinaryOp::In
372                    | BinaryOp::Less
373                    | BinaryOp::LessEq,
374                ..
375            } => Some(Type::Bool),
376            ExprKind::BinaryApp {
377                op: BinaryOp::HasTag,
378                ..
379            } => Some(Type::Bool),
380            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
381                .func(fn_name)
382                .ok()?
383                .return_type()
384                .map(|rty| rty.clone().into()),
385            // We could try to be more complete here, but we can't do all that
386            // much better without evaluating the argument. Even if we know it's
387            // a record `Type::Record` tells us nothing about the type of the
388            // attribute.
389            ExprKind::GetAttr { .. } => None,
390            // similarly to `GetAttr`
391            ExprKind::BinaryApp {
392                op: BinaryOp::GetTag,
393                ..
394            } => None,
395            ExprKind::HasAttr { .. } => Some(Type::Bool),
396            ExprKind::Like { .. } => Some(Type::Bool),
397            ExprKind::Is { .. } => Some(Type::Bool),
398            ExprKind::Set(_) => Some(Type::Set),
399            ExprKind::Record(_) => Some(Type::Record),
400            #[cfg(feature = "tolerant-ast")]
401            ExprKind::Error { .. } => None,
402        }
403    }
404
405    /// Converts an `Expr<V>` to `B::Expr` using the provided builder.
406    ///
407    /// Preserves source location information and recursively transforms each expression node.
408    /// Note: Data may be cloned if the source expression is retained elsewhere.
409    pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
410    where
411        T: Clone,
412    {
413        let builder = B::new().with_maybe_source_loc(self.source_loc().cloned().as_ref());
414        match self.into_expr_kind() {
415            ExprKind::Lit(lit) => builder.val(lit),
416            ExprKind::Var(var) => builder.var(var),
417            ExprKind::Slot(slot) => builder.slot(slot),
418            ExprKind::Unknown(u) => builder.unknown(u),
419            ExprKind::If {
420                test_expr,
421                then_expr,
422                else_expr,
423            } => builder.ite(
424                Arc::unwrap_or_clone(test_expr).into_expr::<B>(),
425                Arc::unwrap_or_clone(then_expr).into_expr::<B>(),
426                Arc::unwrap_or_clone(else_expr).into_expr::<B>(),
427            ),
428            ExprKind::And { left, right } => builder.and(
429                Arc::unwrap_or_clone(left).into_expr::<B>(),
430                Arc::unwrap_or_clone(right).into_expr::<B>(),
431            ),
432            ExprKind::Or { left, right } => builder.or(
433                Arc::unwrap_or_clone(left).into_expr::<B>(),
434                Arc::unwrap_or_clone(right).into_expr::<B>(),
435            ),
436            ExprKind::UnaryApp { op, arg } => {
437                let arg = Arc::unwrap_or_clone(arg).into_expr::<B>();
438                builder.unary_app(op, arg)
439            }
440            ExprKind::BinaryApp { op, arg1, arg2 } => {
441                let arg1 = Arc::unwrap_or_clone(arg1).into_expr::<B>();
442                let arg2 = Arc::unwrap_or_clone(arg2).into_expr::<B>();
443                builder.binary_app(op, arg1, arg2)
444            }
445            ExprKind::ExtensionFunctionApp { fn_name, args } => {
446                let args = Arc::unwrap_or_clone(args)
447                    .into_iter()
448                    .map(|e| e.into_expr::<B>());
449                builder.call_extension_fn(fn_name, args)
450            }
451            ExprKind::GetAttr { expr, attr } => {
452                builder.get_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
453            }
454            ExprKind::HasAttr { expr, attr } => {
455                builder.has_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
456            }
457            ExprKind::Like { expr, pattern } => {
458                builder.like(Arc::unwrap_or_clone(expr).into_expr::<B>(), pattern)
459            }
460            ExprKind::Is { expr, entity_type } => {
461                builder.is_entity_type(Arc::unwrap_or_clone(expr).into_expr::<B>(), entity_type)
462            }
463            ExprKind::Set(set) => builder.set(
464                Arc::unwrap_or_clone(set)
465                    .into_iter()
466                    .map(|e| e.into_expr::<B>()),
467            ),
468            // PANIC SAFETY: `map` is a map, so it will not have duplicates keys, so the `record` constructor cannot error.
469            #[allow(clippy::unwrap_used)]
470            ExprKind::Record(map) => builder
471                .record(
472                    Arc::unwrap_or_clone(map)
473                        .into_iter()
474                        .map(|(k, v)| (k, v.into_expr::<B>())),
475                )
476                .unwrap(),
477            #[cfg(feature = "tolerant-ast")]
478            // PANIC SAFETY: error type is Infallible so can never happen
479            #[allow(clippy::unwrap_used)]
480            ExprKind::Error { .. } => builder
481                .error(ParseErrors::singleton(ToASTError::new(
482                    ToASTErrorKind::ASTErrorNode,
483                    Loc::new(0..1, "AST_ERROR_NODE".into()),
484                )))
485                .unwrap(),
486        }
487    }
488}
489
490#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
491#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
492impl Expr {
493    /// Create an `Expr` that's just a single `Literal`.
494    ///
495    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
496    pub fn val(v: impl Into<Literal>) -> Self {
497        ExprBuilder::new().val(v)
498    }
499
500    /// Create an `Expr` that's just a single `Unknown`.
501    pub fn unknown(u: Unknown) -> Self {
502        ExprBuilder::new().unknown(u)
503    }
504
505    /// Create an `Expr` that's just this literal `Var`
506    pub fn var(v: Var) -> Self {
507        ExprBuilder::new().var(v)
508    }
509
510    /// Create an `Expr` that's just this `SlotId`
511    pub fn slot(s: SlotId) -> Self {
512        ExprBuilder::new().slot(s)
513    }
514
515    /// Create a ternary (if-then-else) `Expr`.
516    ///
517    /// `test_expr` must evaluate to a Bool type
518    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
519        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
520    }
521
522    /// Create a ternary (if-then-else) `Expr`.
523    /// Takes `Arc`s instead of owned `Expr`s.
524    /// `test_expr` must evaluate to a Bool type
525    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
526        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
527    }
528
529    /// Create a 'not' expression. `e` must evaluate to Bool type
530    pub fn not(e: Expr) -> Self {
531        ExprBuilder::new().not(e)
532    }
533
534    /// Create a '==' expression
535    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
536        ExprBuilder::new().is_eq(e1, e2)
537    }
538
539    /// Create a '!=' expression
540    pub fn noteq(e1: Expr, e2: Expr) -> Self {
541        ExprBuilder::new().noteq(e1, e2)
542    }
543
544    /// Create an 'and' expression. Arguments must evaluate to Bool type
545    pub fn and(e1: Expr, e2: Expr) -> Self {
546        ExprBuilder::new().and(e1, e2)
547    }
548
549    /// Create an 'or' expression. Arguments must evaluate to Bool type
550    pub fn or(e1: Expr, e2: Expr) -> Self {
551        ExprBuilder::new().or(e1, e2)
552    }
553
554    /// Create a '<' expression. Arguments must evaluate to Long type
555    pub fn less(e1: Expr, e2: Expr) -> Self {
556        ExprBuilder::new().less(e1, e2)
557    }
558
559    /// Create a '<=' expression. Arguments must evaluate to Long type
560    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
561        ExprBuilder::new().lesseq(e1, e2)
562    }
563
564    /// Create a '>' expression. Arguments must evaluate to Long type
565    pub fn greater(e1: Expr, e2: Expr) -> Self {
566        ExprBuilder::new().greater(e1, e2)
567    }
568
569    /// Create a '>=' expression. Arguments must evaluate to Long type
570    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
571        ExprBuilder::new().greatereq(e1, e2)
572    }
573
574    /// Create an 'add' expression. Arguments must evaluate to Long type
575    pub fn add(e1: Expr, e2: Expr) -> Self {
576        ExprBuilder::new().add(e1, e2)
577    }
578
579    /// Create a 'sub' expression. Arguments must evaluate to Long type
580    pub fn sub(e1: Expr, e2: Expr) -> Self {
581        ExprBuilder::new().sub(e1, e2)
582    }
583
584    /// Create a 'mul' expression. Arguments must evaluate to Long type
585    pub fn mul(e1: Expr, e2: Expr) -> Self {
586        ExprBuilder::new().mul(e1, e2)
587    }
588
589    /// Create a 'neg' expression. `e` must evaluate to Long type.
590    pub fn neg(e: Expr) -> Self {
591        ExprBuilder::new().neg(e)
592    }
593
594    /// Create an 'in' expression. First argument must evaluate to Entity type.
595    /// Second argument must evaluate to either Entity type or Set type where
596    /// all set elements have Entity type.
597    pub fn is_in(e1: Expr, e2: Expr) -> Self {
598        ExprBuilder::new().is_in(e1, e2)
599    }
600
601    /// Create a `contains` expression.
602    /// First argument must have Set type.
603    pub fn contains(e1: Expr, e2: Expr) -> Self {
604        ExprBuilder::new().contains(e1, e2)
605    }
606
607    /// Create a `containsAll` expression. Arguments must evaluate to Set type
608    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
609        ExprBuilder::new().contains_all(e1, e2)
610    }
611
612    /// Create a `containsAny` expression. Arguments must evaluate to Set type
613    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
614        ExprBuilder::new().contains_any(e1, e2)
615    }
616
617    /// Create a `isEmpty` expression. Argument must evaluate to Set type
618    pub fn is_empty(e: Expr) -> Self {
619        ExprBuilder::new().is_empty(e)
620    }
621
622    /// Create a `getTag` expression.
623    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
624    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
625        ExprBuilder::new().get_tag(expr, tag)
626    }
627
628    /// Create a `hasTag` expression.
629    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
630    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
631        ExprBuilder::new().has_tag(expr, tag)
632    }
633
634    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
635    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
636        ExprBuilder::new().set(exprs)
637    }
638
639    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
640    pub fn record(
641        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
642    ) -> Result<Self, ExpressionConstructionError> {
643        ExprBuilder::new().record(pairs)
644    }
645
646    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
647    ///
648    /// If you have an iterator of pairs, generally prefer calling
649    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
650    /// potentially for efficiency reasons but also because `Expr::record()`
651    /// will properly handle duplicate keys but your own `.collect()` will not
652    /// (by default).
653    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
654        ExprBuilder::new().record_arc(map)
655    }
656
657    /// Create an `Expr` which calls the extension function with the given
658    /// `Name` on `args`
659    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
660        ExprBuilder::new().call_extension_fn(fn_name, args)
661    }
662
663    /// Create an application `Expr` which applies the given built-in unary
664    /// operator to the given `arg`
665    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
666        ExprBuilder::new().unary_app(op, arg)
667    }
668
669    /// Create an application `Expr` which applies the given built-in binary
670    /// operator to `arg1` and `arg2`
671    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
672        ExprBuilder::new().binary_app(op, arg1, arg2)
673    }
674
675    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
676    ///
677    /// `expr` must evaluate to either Entity or Record type
678    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
679        ExprBuilder::new().get_attr(expr, attr)
680    }
681
682    /// Create an `Expr` which tests for the existence of a given
683    /// attribute on a given `Entity` or record.
684    ///
685    /// `expr` must evaluate to either Entity or Record type
686    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
687        ExprBuilder::new().has_attr(expr, attr)
688    }
689
690    /// Create a 'like' expression.
691    ///
692    /// `expr` must evaluate to a String type
693    pub fn like(expr: Expr, pattern: Pattern) -> Self {
694        ExprBuilder::new().like(expr, pattern)
695    }
696
697    /// Create an `is` expression.
698    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
699        ExprBuilder::new().is_entity_type(expr, entity_type)
700    }
701
702    /// Check if an expression contains any symbolic unknowns
703    pub fn contains_unknown(&self) -> bool {
704        self.subexpressions()
705            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
706    }
707
708    /// Get all unknowns in an expression
709    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
710        self.subexpressions()
711            .filter_map(|subexpr| match subexpr.expr_kind() {
712                ExprKind::Unknown(u) => Some(u),
713                _ => None,
714            })
715    }
716
717    /// Substitute unknowns with concrete values.
718    ///
719    /// Ignores unmapped unknowns.
720    /// Ignores type annotations on unknowns.
721    /// Note that there might be "undiscovered unknowns" in the Expr, which
722    /// this function does not notice if evaluation of this Expr did not
723    /// traverse all entities and attributes during evaluation, leading to
724    /// this function only substituting one unknown at a time.
725    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
726        match self.substitute_general::<UntypedSubstitution>(definitions) {
727            Ok(e) => e,
728            Err(empty) => match empty {},
729        }
730    }
731
732    /// Substitute unknowns with concrete values.
733    ///
734    /// Ignores unmapped unknowns.
735    /// Errors if the substituted value does not match the type annotation on the unknown.
736    /// Note that there might be "undiscovered unknowns" in the Expr, which
737    /// this function does not notice if evaluation of this Expr did not
738    /// traverse all entities and attributes during evaluation, leading to
739    /// this function only substituting one unknown at a time.
740    pub fn substitute_typed(
741        &self,
742        definitions: &HashMap<SmolStr, Value>,
743    ) -> Result<Expr, SubstitutionError> {
744        self.substitute_general::<TypedSubstitution>(definitions)
745    }
746
747    /// Substitute unknowns with values
748    ///
749    /// Generic over the function implementing the substitution to allow for multiple error behaviors
750    fn substitute_general<T: SubstitutionFunction>(
751        &self,
752        definitions: &HashMap<SmolStr, Value>,
753    ) -> Result<Expr, T::Err> {
754        match self.expr_kind() {
755            ExprKind::Lit(_) => Ok(self.clone()),
756            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
757            ExprKind::Var(_) => Ok(self.clone()),
758            ExprKind::Slot(_) => Ok(self.clone()),
759            ExprKind::If {
760                test_expr,
761                then_expr,
762                else_expr,
763            } => Ok(Expr::ite(
764                test_expr.substitute_general::<T>(definitions)?,
765                then_expr.substitute_general::<T>(definitions)?,
766                else_expr.substitute_general::<T>(definitions)?,
767            )),
768            ExprKind::And { left, right } => Ok(Expr::and(
769                left.substitute_general::<T>(definitions)?,
770                right.substitute_general::<T>(definitions)?,
771            )),
772            ExprKind::Or { left, right } => Ok(Expr::or(
773                left.substitute_general::<T>(definitions)?,
774                right.substitute_general::<T>(definitions)?,
775            )),
776            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
777                *op,
778                arg.substitute_general::<T>(definitions)?,
779            )),
780            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
781                *op,
782                arg1.substitute_general::<T>(definitions)?,
783                arg2.substitute_general::<T>(definitions)?,
784            )),
785            ExprKind::ExtensionFunctionApp { fn_name, args } => {
786                let args = args
787                    .iter()
788                    .map(|e| e.substitute_general::<T>(definitions))
789                    .collect::<Result<Vec<Expr>, _>>()?;
790
791                Ok(Expr::call_extension_fn(fn_name.clone(), args))
792            }
793            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
794                expr.substitute_general::<T>(definitions)?,
795                attr.clone(),
796            )),
797            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
798                expr.substitute_general::<T>(definitions)?,
799                attr.clone(),
800            )),
801            ExprKind::Like { expr, pattern } => Ok(Expr::like(
802                expr.substitute_general::<T>(definitions)?,
803                pattern.clone(),
804            )),
805            ExprKind::Set(members) => {
806                let members = members
807                    .iter()
808                    .map(|e| e.substitute_general::<T>(definitions))
809                    .collect::<Result<Vec<_>, _>>()?;
810                Ok(Expr::set(members))
811            }
812            ExprKind::Record(map) => {
813                let map = map
814                    .iter()
815                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
816                    .collect::<Result<BTreeMap<_, _>, _>>()?;
817                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
818                #[allow(clippy::expect_used)]
819                Ok(Expr::record(map)
820                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
821            }
822            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
823                expr.substitute_general::<T>(definitions)?,
824                entity_type.clone(),
825            )),
826            #[cfg(feature = "tolerant-ast")]
827            ExprKind::Error { .. } => Ok(self.clone()),
828        }
829    }
830}
831
832/// A trait for customizing the error behavior of substitution
833trait SubstitutionFunction {
834    /// The potential errors this substitution function can return
835    type Err;
836    /// The function for implementing the substitution.
837    ///
838    /// Takes the expression being substituted,
839    /// The substitution from the map (if present)
840    /// and the type annotation from the unknown (if present)
841    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
842}
843
844struct TypedSubstitution {}
845
846impl SubstitutionFunction for TypedSubstitution {
847    type Err = SubstitutionError;
848
849    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
850        match (substitute, &value.type_annotation) {
851            (None, _) => Ok(Expr::unknown(value.clone())),
852            (Some(v), None) => Ok(v.clone().into()),
853            (Some(v), Some(t)) => {
854                if v.type_of() == *t {
855                    Ok(v.clone().into())
856                } else {
857                    Err(SubstitutionError::TypeError {
858                        expected: t.clone(),
859                        actual: v.type_of(),
860                    })
861                }
862            }
863        }
864    }
865}
866
867struct UntypedSubstitution {}
868
869impl SubstitutionFunction for UntypedSubstitution {
870    type Err = std::convert::Infallible;
871
872    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
873        Ok(substitute
874            .map(|v| v.clone().into())
875            .unwrap_or_else(|| Expr::unknown(value.clone())))
876    }
877}
878
879impl<T: Clone> std::fmt::Display for Expr<T> {
880    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
882        // we just convert to EST and use the EST pretty-printer.
883        // Note that converting AST->EST is lossless and infallible.
884        write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
885    }
886}
887
888impl<T: Clone> BoundedDisplay for Expr<T> {
889    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
890        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
891        // pretty-printer. Note that converting AST->EST is lossless and infallible.
892        BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
893    }
894}
895
896impl std::str::FromStr for Expr {
897    type Err = ParseErrors;
898
899    fn from_str(s: &str) -> Result<Expr, Self::Err> {
900        crate::parser::parse_expr(s)
901    }
902}
903
904/// Enum for errors encountered during substitution
905#[derive(Debug, Clone, Diagnostic, Error)]
906pub enum SubstitutionError {
907    /// The supplied value did not match the type annotation on the unknown.
908    #[error("expected a value of type {expected}, got a value of type {actual}")]
909    TypeError {
910        /// The expected type, ie: the type the unknown was annotated with
911        expected: Type,
912        /// The type of the provided value
913        actual: Type,
914    },
915}
916
917/// Representation of a partial-evaluation Unknown at the AST level
918#[derive(Hash, Debug, Clone, PartialEq, Eq)]
919pub struct Unknown {
920    /// The name of the unknown
921    pub name: SmolStr,
922    /// The type of the values that can be substituted in for the unknown.
923    /// If `None`, we have no type annotation, and thus a value of any type can
924    /// be substituted.
925    pub type_annotation: Option<Type>,
926}
927
928impl Unknown {
929    /// Create a new untyped `Unknown`
930    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
931        Self {
932            name: name.into(),
933            type_annotation: None,
934        }
935    }
936
937    /// Create a new `Unknown` with type annotation. (Only values of the given
938    /// type can be substituted.)
939    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
940        Self {
941            name: name.into(),
942            type_annotation: Some(ty),
943        }
944    }
945}
946
947impl std::fmt::Display for Unknown {
948    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
949        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
950        // to avoid code duplication
951        write!(
952            f,
953            "{}",
954            Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
955        )
956    }
957}
958
959/// Builder for constructing `Expr` objects annotated with some `data`
960/// (possibly taking default value) and optionally a `source_loc`.
961#[derive(Clone, Debug)]
962pub struct ExprBuilder<T> {
963    source_loc: Option<Loc>,
964    data: T,
965}
966
967impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
968    type Expr = Expr<T>;
969
970    type Data = T;
971
972    #[cfg(feature = "tolerant-ast")]
973    type ErrorType = ParseErrors;
974
975    fn loc(&self) -> Option<&Loc> {
976        self.source_loc.as_ref()
977    }
978
979    fn data(&self) -> &Self::Data {
980        &self.data
981    }
982
983    fn with_data(data: T) -> Self {
984        Self {
985            source_loc: None,
986            data,
987        }
988    }
989
990    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
991        self.source_loc = maybe_source_loc.cloned();
992        self
993    }
994
995    /// Create an `Expr` that's just a single `Literal`.
996    ///
997    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
998    fn val(self, v: impl Into<Literal>) -> Expr<T> {
999        self.with_expr_kind(ExprKind::Lit(v.into()))
1000    }
1001
1002    /// Create an `Unknown` `Expr`
1003    fn unknown(self, u: Unknown) -> Expr<T> {
1004        self.with_expr_kind(ExprKind::Unknown(u))
1005    }
1006
1007    /// Create an `Expr` that's just this literal `Var`
1008    fn var(self, v: Var) -> Expr<T> {
1009        self.with_expr_kind(ExprKind::Var(v))
1010    }
1011
1012    /// Create an `Expr` that's just this `SlotId`
1013    fn slot(self, s: SlotId) -> Expr<T> {
1014        self.with_expr_kind(ExprKind::Slot(s))
1015    }
1016
1017    /// Create a ternary (if-then-else) `Expr`.
1018    ///
1019    /// `test_expr` must evaluate to a Bool type
1020    fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
1021        self.with_expr_kind(ExprKind::If {
1022            test_expr: Arc::new(test_expr),
1023            then_expr: Arc::new(then_expr),
1024            else_expr: Arc::new(else_expr),
1025        })
1026    }
1027
1028    /// Create a 'not' expression. `e` must evaluate to Bool type
1029    fn not(self, e: Expr<T>) -> Expr<T> {
1030        self.with_expr_kind(ExprKind::UnaryApp {
1031            op: UnaryOp::Not,
1032            arg: Arc::new(e),
1033        })
1034    }
1035
1036    /// Create a '==' expression
1037    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1038        self.with_expr_kind(ExprKind::BinaryApp {
1039            op: BinaryOp::Eq,
1040            arg1: Arc::new(e1),
1041            arg2: Arc::new(e2),
1042        })
1043    }
1044
1045    /// Create an 'and' expression. Arguments must evaluate to Bool type
1046    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1047        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1048            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1049                ExprKind::Lit(Literal::Bool(*b1 && *b2))
1050            }
1051            _ => ExprKind::And {
1052                left: Arc::new(e1),
1053                right: Arc::new(e2),
1054            },
1055        })
1056    }
1057
1058    /// Create an 'or' expression. Arguments must evaluate to Bool type
1059    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1060        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1061            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1062                ExprKind::Lit(Literal::Bool(*b1 || *b2))
1063            }
1064
1065            _ => ExprKind::Or {
1066                left: Arc::new(e1),
1067                right: Arc::new(e2),
1068            },
1069        })
1070    }
1071
1072    /// Create a '<' expression. Arguments must evaluate to Long type
1073    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1074        self.with_expr_kind(ExprKind::BinaryApp {
1075            op: BinaryOp::Less,
1076            arg1: Arc::new(e1),
1077            arg2: Arc::new(e2),
1078        })
1079    }
1080
1081    /// Create a '<=' expression. Arguments must evaluate to Long type
1082    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1083        self.with_expr_kind(ExprKind::BinaryApp {
1084            op: BinaryOp::LessEq,
1085            arg1: Arc::new(e1),
1086            arg2: Arc::new(e2),
1087        })
1088    }
1089
1090    /// Create an 'add' expression. Arguments must evaluate to Long type
1091    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1092        self.with_expr_kind(ExprKind::BinaryApp {
1093            op: BinaryOp::Add,
1094            arg1: Arc::new(e1),
1095            arg2: Arc::new(e2),
1096        })
1097    }
1098
1099    /// Create a 'sub' expression. Arguments must evaluate to Long type
1100    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1101        self.with_expr_kind(ExprKind::BinaryApp {
1102            op: BinaryOp::Sub,
1103            arg1: Arc::new(e1),
1104            arg2: Arc::new(e2),
1105        })
1106    }
1107
1108    /// Create a 'mul' expression. Arguments must evaluate to Long type
1109    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1110        self.with_expr_kind(ExprKind::BinaryApp {
1111            op: BinaryOp::Mul,
1112            arg1: Arc::new(e1),
1113            arg2: Arc::new(e2),
1114        })
1115    }
1116
1117    /// Create a 'neg' expression. `e` must evaluate to Long type.
1118    fn neg(self, e: Expr<T>) -> Expr<T> {
1119        self.with_expr_kind(ExprKind::UnaryApp {
1120            op: UnaryOp::Neg,
1121            arg: Arc::new(e),
1122        })
1123    }
1124
1125    /// Create an 'in' expression. First argument must evaluate to Entity type.
1126    /// Second argument must evaluate to either Entity type or Set type where
1127    /// all set elements have Entity type.
1128    fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1129        self.with_expr_kind(ExprKind::BinaryApp {
1130            op: BinaryOp::In,
1131            arg1: Arc::new(e1),
1132            arg2: Arc::new(e2),
1133        })
1134    }
1135
1136    /// Create a 'contains' expression.
1137    /// First argument must have Set type.
1138    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1139        self.with_expr_kind(ExprKind::BinaryApp {
1140            op: BinaryOp::Contains,
1141            arg1: Arc::new(e1),
1142            arg2: Arc::new(e2),
1143        })
1144    }
1145
1146    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1147    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1148        self.with_expr_kind(ExprKind::BinaryApp {
1149            op: BinaryOp::ContainsAll,
1150            arg1: Arc::new(e1),
1151            arg2: Arc::new(e2),
1152        })
1153    }
1154
1155    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1156    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1157        self.with_expr_kind(ExprKind::BinaryApp {
1158            op: BinaryOp::ContainsAny,
1159            arg1: Arc::new(e1),
1160            arg2: Arc::new(e2),
1161        })
1162    }
1163
1164    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1165    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1166        self.with_expr_kind(ExprKind::UnaryApp {
1167            op: UnaryOp::IsEmpty,
1168            arg: Arc::new(expr),
1169        })
1170    }
1171
1172    /// Create a 'getTag' expression.
1173    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1174    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1175        self.with_expr_kind(ExprKind::BinaryApp {
1176            op: BinaryOp::GetTag,
1177            arg1: Arc::new(expr),
1178            arg2: Arc::new(tag),
1179        })
1180    }
1181
1182    /// Create a 'hasTag' expression.
1183    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1184    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1185        self.with_expr_kind(ExprKind::BinaryApp {
1186            op: BinaryOp::HasTag,
1187            arg1: Arc::new(expr),
1188            arg2: Arc::new(tag),
1189        })
1190    }
1191
1192    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1193    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1194        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1195    }
1196
1197    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1198    fn record(
1199        self,
1200        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1201    ) -> Result<Expr<T>, ExpressionConstructionError> {
1202        let mut map = BTreeMap::new();
1203        for (k, v) in pairs {
1204            match map.entry(k) {
1205                btree_map::Entry::Occupied(oentry) => {
1206                    return Err(expression_construction_errors::DuplicateKeyError {
1207                        key: oentry.key().clone(),
1208                        context: "in record literal",
1209                    }
1210                    .into());
1211                }
1212                btree_map::Entry::Vacant(ventry) => {
1213                    ventry.insert(v);
1214                }
1215            }
1216        }
1217        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1218    }
1219
1220    /// Create an `Expr` which calls the extension function with the given
1221    /// `Name` on `args`
1222    fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1223        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1224            fn_name,
1225            args: Arc::new(args.into_iter().collect()),
1226        })
1227    }
1228
1229    /// Create an application `Expr` which applies the given built-in unary
1230    /// operator to the given `arg`
1231    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1232        self.with_expr_kind(ExprKind::UnaryApp {
1233            op: op.into(),
1234            arg: Arc::new(arg),
1235        })
1236    }
1237
1238    /// Create an application `Expr` which applies the given built-in binary
1239    /// operator to `arg1` and `arg2`
1240    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1241        self.with_expr_kind(ExprKind::BinaryApp {
1242            op: op.into(),
1243            arg1: Arc::new(arg1),
1244            arg2: Arc::new(arg2),
1245        })
1246    }
1247
1248    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1249    ///
1250    /// `expr` must evaluate to either Entity or Record type
1251    fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1252        self.with_expr_kind(ExprKind::GetAttr {
1253            expr: Arc::new(expr),
1254            attr,
1255        })
1256    }
1257
1258    /// Create an `Expr` which tests for the existence of a given
1259    /// attribute on a given `Entity` or record.
1260    ///
1261    /// `expr` must evaluate to either Entity or Record type
1262    fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1263        self.with_expr_kind(ExprKind::HasAttr {
1264            expr: Arc::new(expr),
1265            attr,
1266        })
1267    }
1268
1269    /// Create a 'like' expression.
1270    ///
1271    /// `expr` must evaluate to a String type
1272    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1273        self.with_expr_kind(ExprKind::Like {
1274            expr: Arc::new(expr),
1275            pattern,
1276        })
1277    }
1278
1279    /// Create an 'is' expression.
1280    fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1281        self.with_expr_kind(ExprKind::Is {
1282            expr: Arc::new(expr),
1283            entity_type,
1284        })
1285    }
1286
1287    /// Don't support AST Error nodes - return the error right back
1288    #[cfg(feature = "tolerant-ast")]
1289    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1290        Err(parse_errors)
1291    }
1292}
1293
1294impl<T> ExprBuilder<T> {
1295    /// Construct an `Expr` containing the `data` and `source_loc` in this
1296    /// `ExprBuilder` and the given `ExprKind`.
1297    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1298        Expr::new(expr_kind, self.source_loc, self.data)
1299    }
1300
1301    /// Create a ternary (if-then-else) `Expr`.
1302    /// Takes `Arc`s instead of owned `Expr`s.
1303    /// `test_expr` must evaluate to a Bool type
1304    pub fn ite_arc(
1305        self,
1306        test_expr: Arc<Expr<T>>,
1307        then_expr: Arc<Expr<T>>,
1308        else_expr: Arc<Expr<T>>,
1309    ) -> Expr<T> {
1310        self.with_expr_kind(ExprKind::If {
1311            test_expr,
1312            then_expr,
1313            else_expr,
1314        })
1315    }
1316
1317    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1318    ///
1319    /// If you have an iterator of pairs, generally prefer calling `.record()`
1320    /// instead of `.collect()`-ing yourself and calling this, potentially for
1321    /// efficiency reasons but also because `.record()` will properly handle
1322    /// duplicate keys but your own `.collect()` will not (by default).
1323    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1324        self.with_expr_kind(ExprKind::Record(map))
1325    }
1326}
1327
1328impl<T: Clone + Default> ExprBuilder<T> {
1329    /// Utility used the validator to get an expression with the same source
1330    /// location as an existing expression. This is done when reconstructing the
1331    /// `Expr` with type information.
1332    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1333        self.with_maybe_source_loc(expr.source_loc.as_ref())
1334    }
1335}
1336
1337/// Errors when constructing an expression
1338//
1339// CAUTION: this type is publicly exported in `cedar-policy`.
1340// Don't make fields `pub`, don't make breaking changes, and use caution
1341// when adding public methods.
1342#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1343pub enum ExpressionConstructionError {
1344    /// The same key occurred two or more times
1345    #[error(transparent)]
1346    #[diagnostic(transparent)]
1347    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
1348}
1349
1350/// Error subtypes for [`ExpressionConstructionError`]
1351pub mod expression_construction_errors {
1352    use miette::Diagnostic;
1353    use smol_str::SmolStr;
1354    use thiserror::Error;
1355
1356    /// The same key occurred two or more times
1357    //
1358    // CAUTION: this type is publicly exported in `cedar-policy`.
1359    // Don't make fields `pub`, don't make breaking changes, and use caution
1360    // when adding public methods.
1361    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1362    #[error("duplicate key `{key}` {context}")]
1363    pub struct DuplicateKeyError {
1364        /// The key which occurred two or more times
1365        pub(crate) key: SmolStr,
1366        /// Information about where the duplicate key occurred (e.g., "in record literal")
1367        pub(crate) context: &'static str,
1368    }
1369
1370    impl DuplicateKeyError {
1371        /// Get the key which occurred two or more times
1372        pub fn key(&self) -> &str {
1373            &self.key
1374        }
1375
1376        /// Make a new error with an updated `context` field
1377        pub(crate) fn with_context(self, context: &'static str) -> Self {
1378            Self { context, ..self }
1379        }
1380    }
1381}
1382
1383/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1384/// implementations that ignore any source information or other generic data
1385/// used to annotate the `Expr`.
1386#[derive(Eq, Debug, Clone)]
1387pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1388
1389impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1390    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1391    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1392    /// ignore source information and generic data.
1393    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1394        ExprShapeOnly(Cow::Borrowed(e))
1395    }
1396
1397    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1398    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1399    /// ignore source information and generic data.
1400    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1401        ExprShapeOnly(Cow::Owned(e))
1402    }
1403}
1404
1405impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1406    fn eq(&self, other: &Self) -> bool {
1407        self.0.eq_shape(&other.0)
1408    }
1409}
1410
1411impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1412    fn hash<H: Hasher>(&self, state: &mut H) {
1413        self.0.hash_shape(state);
1414    }
1415}
1416
1417impl<T> Expr<T> {
1418    /// Return true if this expression (recursively) has the same expression
1419    /// kind as the argument expression. This accounts for the full recursive
1420    /// shape of the expression, but does not consider source information or any
1421    /// generic data annotated on expression. This should behave the same as the
1422    /// default implementation of `Eq` before source information and generic
1423    /// data were added.
1424    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1425        use ExprKind::*;
1426        match (self.expr_kind(), other.expr_kind()) {
1427            (Lit(lit), Lit(lit1)) => lit == lit1,
1428            (Var(v), Var(v1)) => v == v1,
1429            (Slot(s), Slot(s1)) => s == s1,
1430            (
1431                Unknown(self::Unknown {
1432                    name: name1,
1433                    type_annotation: ta_1,
1434                }),
1435                Unknown(self::Unknown {
1436                    name: name2,
1437                    type_annotation: ta_2,
1438                }),
1439            ) => (name1 == name2) && (ta_1 == ta_2),
1440            (
1441                If {
1442                    test_expr,
1443                    then_expr,
1444                    else_expr,
1445                },
1446                If {
1447                    test_expr: test_expr1,
1448                    then_expr: then_expr1,
1449                    else_expr: else_expr1,
1450                },
1451            ) => {
1452                test_expr.eq_shape(test_expr1)
1453                    && then_expr.eq_shape(then_expr1)
1454                    && else_expr.eq_shape(else_expr1)
1455            }
1456            (
1457                And { left, right },
1458                And {
1459                    left: left1,
1460                    right: right1,
1461                },
1462            )
1463            | (
1464                Or { left, right },
1465                Or {
1466                    left: left1,
1467                    right: right1,
1468                },
1469            ) => left.eq_shape(left1) && right.eq_shape(right1),
1470            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1471                op == op1 && arg.eq_shape(arg1)
1472            }
1473            (
1474                BinaryApp { op, arg1, arg2 },
1475                BinaryApp {
1476                    op: op1,
1477                    arg1: arg11,
1478                    arg2: arg21,
1479                },
1480            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1481            (
1482                ExtensionFunctionApp { fn_name, args },
1483                ExtensionFunctionApp {
1484                    fn_name: fn_name1,
1485                    args: args1,
1486                },
1487            ) => {
1488                fn_name == fn_name1
1489                    && args.len() == args1.len()
1490                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1491            }
1492            (
1493                GetAttr { expr, attr },
1494                GetAttr {
1495                    expr: expr1,
1496                    attr: attr1,
1497                },
1498            )
1499            | (
1500                HasAttr { expr, attr },
1501                HasAttr {
1502                    expr: expr1,
1503                    attr: attr1,
1504                },
1505            ) => attr == attr1 && expr.eq_shape(expr1),
1506            (
1507                Like { expr, pattern },
1508                Like {
1509                    expr: expr1,
1510                    pattern: pattern1,
1511                },
1512            ) => pattern == pattern1 && expr.eq_shape(expr1),
1513            (Set(elems), Set(elems1)) => {
1514                elems.len() == elems1.len()
1515                    && elems
1516                        .iter()
1517                        .zip(elems1.iter())
1518                        .all(|(e, e1)| e.eq_shape(e1))
1519            }
1520            (Record(map), Record(map1)) => {
1521                map.len() == map1.len()
1522                    && map
1523                        .iter()
1524                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1525                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1526            }
1527            (
1528                Is { expr, entity_type },
1529                Is {
1530                    expr: expr1,
1531                    entity_type: entity_type1,
1532                },
1533            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1534            _ => false,
1535        }
1536    }
1537
1538    /// Implementation of hashing corresponding to equality as implemented by
1539    /// `eq_shape`. Must satisfy the usual relationship between equality and
1540    /// hashing.
1541    pub fn hash_shape<H>(&self, state: &mut H)
1542    where
1543        H: Hasher,
1544    {
1545        mem::discriminant(self).hash(state);
1546        match self.expr_kind() {
1547            ExprKind::Lit(lit) => lit.hash(state),
1548            ExprKind::Var(v) => v.hash(state),
1549            ExprKind::Slot(s) => s.hash(state),
1550            ExprKind::Unknown(u) => u.hash(state),
1551            ExprKind::If {
1552                test_expr,
1553                then_expr,
1554                else_expr,
1555            } => {
1556                test_expr.hash_shape(state);
1557                then_expr.hash_shape(state);
1558                else_expr.hash_shape(state);
1559            }
1560            ExprKind::And { left, right } => {
1561                left.hash_shape(state);
1562                right.hash_shape(state);
1563            }
1564            ExprKind::Or { left, right } => {
1565                left.hash_shape(state);
1566                right.hash_shape(state);
1567            }
1568            ExprKind::UnaryApp { op, arg } => {
1569                op.hash(state);
1570                arg.hash_shape(state);
1571            }
1572            ExprKind::BinaryApp { op, arg1, arg2 } => {
1573                op.hash(state);
1574                arg1.hash_shape(state);
1575                arg2.hash_shape(state);
1576            }
1577            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1578                fn_name.hash(state);
1579                state.write_usize(args.len());
1580                args.iter().for_each(|a| {
1581                    a.hash_shape(state);
1582                });
1583            }
1584            ExprKind::GetAttr { expr, attr } => {
1585                expr.hash_shape(state);
1586                attr.hash(state);
1587            }
1588            ExprKind::HasAttr { expr, attr } => {
1589                expr.hash_shape(state);
1590                attr.hash(state);
1591            }
1592            ExprKind::Like { expr, pattern } => {
1593                expr.hash_shape(state);
1594                pattern.hash(state);
1595            }
1596            ExprKind::Set(elems) => {
1597                state.write_usize(elems.len());
1598                elems.iter().for_each(|e| {
1599                    e.hash_shape(state);
1600                })
1601            }
1602            ExprKind::Record(map) => {
1603                state.write_usize(map.len());
1604                map.iter().for_each(|(s, a)| {
1605                    s.hash(state);
1606                    a.hash_shape(state);
1607                });
1608            }
1609            ExprKind::Is { expr, entity_type } => {
1610                expr.hash_shape(state);
1611                entity_type.hash(state);
1612            }
1613            #[cfg(feature = "tolerant-ast")]
1614            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1615        }
1616    }
1617}
1618
1619/// AST variables
1620#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
1621#[serde(rename_all = "camelCase")]
1622#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1623#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
1624#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
1625pub enum Var {
1626    /// the Principal of the given request
1627    Principal,
1628    /// the Action of the given request
1629    Action,
1630    /// the Resource of the given request
1631    Resource,
1632    /// the Context of the given request
1633    Context,
1634}
1635
1636impl From<PrincipalOrResource> for Var {
1637    fn from(v: PrincipalOrResource) -> Self {
1638        match v {
1639            PrincipalOrResource::Principal => Var::Principal,
1640            PrincipalOrResource::Resource => Var::Resource,
1641        }
1642    }
1643}
1644
1645// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1646#[allow(clippy::fallible_impl_from)]
1647impl From<Var> for Id {
1648    fn from(var: Var) -> Self {
1649        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1650        #[allow(clippy::unwrap_used)]
1651        format!("{var}").parse().unwrap()
1652    }
1653}
1654
1655// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1656#[allow(clippy::fallible_impl_from)]
1657impl From<Var> for UnreservedId {
1658    fn from(var: Var) -> Self {
1659        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`
1660        #[allow(clippy::unwrap_used)]
1661        Id::from(var).try_into().unwrap()
1662    }
1663}
1664
1665impl std::fmt::Display for Var {
1666    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1667        match self {
1668            Self::Principal => write!(f, "principal"),
1669            Self::Action => write!(f, "action"),
1670            Self::Resource => write!(f, "resource"),
1671            Self::Context => write!(f, "context"),
1672        }
1673    }
1674}
1675
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;
    use itertools::Itertools;
    use smol_str::ToSmolStr;
    use std::collections::{hash_map::DefaultHasher, HashSet};

    use crate::expr_builder::ExprBuilder as _;

    use super::*;

    /// Every `Var` variant, used to exhaustively exercise conversions.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
    }

    // Tests that Var::Into never panics
    #[test]
    fn all_vars_are_ids() {
        for var in all_vars() {
            let _id: Id = var.into();
            let _id: UnreservedId = var.into();
        }
    }

    // Each convenience constructor builds the same node as the explicit
    // `Expr::new(ExprKind::..., None, ())` form.
    #[test]
    fn exprs() {
        assert_eq!(
            Expr::val(33),
            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
        );
        assert_eq!(
            Expr::val("hello"),
            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
        );
        assert_eq!(
            Expr::val(EntityUID::with_eid("foo")),
            Expr::new(
                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                None,
                ()
            )
        );
        assert_eq!(
            Expr::var(Var::Principal),
            Expr::new(ExprKind::Var(Var::Principal), None, ())
        );
        assert_eq!(
            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
            Expr::new(
                ExprKind::If {
                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::not(Expr::val(false)),
            Expr::new(
                ExprKind::UnaryApp {
                    op: UnaryOp::Not,
                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::GetAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::HasAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::is_entity_type(
                Expr::val(EntityUID::with_eid("foo")),
                "Type".parse().unwrap()
            ),
            Expr::new(
                ExprKind::Is {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    entity_type: "Type".parse().unwrap()
                },
                None,
                ()
            ),
        );
    }

    // `like` patterns are displayed with special characters escaped.
    #[test]
    fn like_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
        assert_eq!(format!("{e}"), r#""a" like "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
        // `\`'s escaped form is `\\`; a wildcard is written as a bare `*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
        // a literal star's escaped form is `\*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
    }

    // `has` attribute names are displayed as escaped string literals.
    #[test]
    fn has_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::has_attr(Expr::val("a"), "\0".into());
        assert_eq!(format!("{e}"), r#""a" has "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::has_attr(Expr::val("a"), r"\".into());
        assert_eq!(format!("{e}"), r#""a" has "\\""#);
    }

    #[test]
    fn slot_display() {
        let e = Expr::slot(SlotId::principal());
        assert_eq!(format!("{e}"), "?principal");
        let e = Expr::slot(SlotId::resource());
        assert_eq!(format!("{e}"), "?resource");
        let e = Expr::val(EntityUID::with_eid("eid"));
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    // `slots()` finds slots at any depth of the expression tree.
    #[test]
    fn simple_slots() {
        let e = Expr::slot(SlotId::principal());
        let p = SlotId::principal();
        let r = SlotId::resource();
        let set: HashSet<SlotId> = HashSet::from_iter([p]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
        let e = Expr::or(
            Expr::slot(SlotId::principal()),
            Expr::ite(
                Expr::val(true),
                Expr::slot(SlotId::resource()),
                Expr::val(false),
            ),
        );
        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
    }

    #[test]
    fn unknowns() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        let unknowns = e.unknowns().collect_vec();
        assert_eq!(unknowns.len(), 3);
        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
    }

    #[test]
    fn is_unknown() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        assert!(e.contains_unknown());
        let e = Expr::ite(
            Expr::not(Expr::val(true)),
            Expr::and(Expr::val(1), Expr::val(3)),
            Expr::val(1),
        );
        assert!(!e.contains_unknown());
    }

    #[test]
    fn expr_with_data() {
        let e = ExprBuilder::with_data("data").val(1);
        assert_eq!(e.into_data(), "data");
    }

    // Expressions that differ only in their attached data must compare equal
    // under `eq_shape` and produce identical `hash_shape` hashes.
    #[test]
    fn expr_shape_only_eq() {
        let temp = ExprBuilder::with_data(1).val(1);
        let exprs = &[
            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
            (
                ExprBuilder::with_data(1).var(Var::Principal),
                Expr::var(Var::Principal),
            ),
            (
                ExprBuilder::with_data(1).slot(SlotId::principal()),
                Expr::slot(SlotId::principal()),
            ),
            (
                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).not(temp.clone()),
                Expr::not(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
                Expr::is_eq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
                Expr::and(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
                Expr::or(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
                Expr::less(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
                Expr::lesseq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
                Expr::greater(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
                Expr::greatereq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
                Expr::add(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
                Expr::sub(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
                Expr::mul(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).neg(temp.clone()),
                Expr::neg(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
                Expr::is_in(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
                Expr::contains(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
                Expr::contains_all(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
                Expr::contains_any(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_empty(temp.clone()),
                Expr::is_empty(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).set([temp.clone()]),
                Expr::set([Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1)
                    .record([("foo".into(), temp.clone())])
                    .unwrap(),
                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
            ),
            (
                ExprBuilder::with_data(1)
                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
                Expr::get_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
                Expr::has_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1)
                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
            ),
            (
                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
            ),
        ];

        for (e0, e1) in exprs {
            assert!(e0.eq_shape(e0));
            assert!(e1.eq_shape(e1));
            assert!(e0.eq_shape(e1));
            assert!(e1.eq_shape(e0));

            let mut hasher0 = DefaultHasher::new();
            e0.hash_shape(&mut hasher0);
            let hash0 = hasher0.finish();

            let mut hasher1 = DefaultHasher::new();
            e1.hash_shape(&mut hasher1);
            let hash1 = hasher1.finish();

            assert_eq!(hash0, hash1);
        }
    }

    #[test]
    fn expr_shape_only_not_eq() {
        let expr1 = ExprBuilder::with_data(1).val(1);
        let expr2 = ExprBuilder::with_data(1).val(2);
        assert_ne!(
            ExprShapeOnly::new_from_borrowed(&expr1),
            ExprShapeOnly::new_from_borrowed(&expr2)
        );
    }

    // Sets where one element list is a prefix of another must be unequal,
    // checked in both directions for every pair.
    #[test]
    fn expr_shape_only_set_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Extension calls whose argument list is a prefix of another must be
    // unequal, checked in both directions for every pair.
    #[test]
    fn expr_shape_only_ext_fn_arg_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![],
        ));
        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0")],
        ));
        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0"), Expr::val("0.0")],
        ));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Records whose attribute list is a prefix of another must be unequal,
    // checked in both directions for every pair.
    #[test]
    fn expr_shape_only_record_attr_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
        let e2 = ExprShapeOnly::new_from_owned(
            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
        );
        let e3 = ExprShapeOnly::new_from_owned(
            Expr::record([
                ("a".to_smolstr(), Expr::val(1)),
                ("b".to_smolstr(), Expr::val(2)),
            ])
            .unwrap(),
        );

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    #[test]
    fn untyped_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    // Untyped substitution does not check the type annotation: a `Long` value
    // replaces a `Bool`-annotated unknown without error.
    #[test]
    fn untyped_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, None);
        match r {
            Ok(n) => assert_eq!(n, Expr::unknown(u)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn typed_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    // Typed substitution rejects a value whose type differs from the
    // unknown's annotation.
    #[test]
    fn typed_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
        assert_matches!(
            r,
            SubstitutionError::TypeError {
                expected: Type::Bool,
                actual: Type::Long,
            }
        );
    }

    #[test]
    fn typed_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = TypedSubstitution::substitute(&u, None).unwrap();
        assert_eq!(r, Expr::unknown(u));
    }
}