Skip to main content

cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#[cfg(feature = "tolerant-ast")]
18use {
19    super::expr_allows_errors::AstExprErrorKind, crate::parser::err::ToASTError,
20    crate::parser::err::ToASTErrorKind,
21};
22
23use crate::{
24    ast::*,
25    expr_builder::{self, ExprBuilder as _},
26    extensions::Extensions,
27    parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34    borrow::Cow,
35    collections::{btree_map, BTreeMap, HashMap},
36    hash::{Hash, Hasher},
37    mem,
38    sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
/// Internal AST for expressions used by the policy evaluator.
///
/// An `Expr` is a wrapper around an [`ExprKind`], which is the expression
/// variant this object contains. It also carries optional source information
/// about where the expression was written in policy source code, and some
/// generic data of type `T` which is stored on each node of the AST.
///
/// Equality and hashing are derived through `Educe` and look only at the
/// `ExprKind` and the data: `source_loc` is ignored by both.
///
/// Cloning is O(1).
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Expr<T = ()> {
    /// The expression variant held by this node.
    expr_kind: ExprKind<T>,
    /// Where this expression was written in the policy source, if known.
    /// Not taken into account by `PartialEq` or `Hash`.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    source_loc: Option<Loc>,
    /// Generic data stored on this node.
    data: T,
}
60
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// The type parameter `T` is the type of the data stored on every child
/// [`Expr`] node. Child expressions are held behind `Arc`s.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template slot
    Slot(SlotId),
    /// Symbolic unknown for partial evaluation
    Unknown(Unknown),
    /// Ternary (if-then-else) expression
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if the condition is true
        then_expr: Arc<Expr<T>>,
        /// Value if the condition is false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply the operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First argument
        arg1: Arc<Expr<T>>,
        /// Second argument
        arg2: Arc<Expr<T>>,
    },
    /// Application of an extension function to n arguments
    ///
    /// INVARIANT (MethodStyleArgs):
    ///   if op.style is MethodStyle then args _cannot_ be empty.
    ///     The first element of args refers to the subject of the method call
    ///
    /// Ideally, we find some way to make this non-representable.
    ExtensionFunctionApp {
        /// Extension function to apply
        fn_name: Name,
        /// Arguments to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Regex-like string matching similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard `*`, which matches
        /// any string.
        /// To match a literal `*` in the test expression, users can use `\*`.
        /// Be careful: the backslash in `\*` must not itself be part of
        /// another escape sequence. For instance, `\\*` matches a backslash
        /// followed by an arbitrary string.
        pattern: Pattern,
    },
    /// Entity type test. Does the first argument have the entity type
    /// specified by the second argument?
    Is {
        /// Expression to test. Must evaluate to an Entity.
        expr: Arc<Expr<T>>,
        /// The [`EntityType`] used for the type membership test.
        entity_type: EntityType,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
    // that are syntactically unequal may actually be semantically equal --
    // i.e., we can't remove duplicates until all of the `Expr`s have been
    // evaluated into `Value`s
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose elements may be arbitrary expressions)
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
    #[cfg(feature = "tolerant-ast")]
    /// Error expression - allows us to continue parsing even when we have
    /// errors. Only present when the `tolerant-ast` feature is enabled.
    Error {
        /// Type of error that led to the failure
        error_kind: AstExprErrorKind,
    },
}
171
172impl<T> ExprKind<T> {
173    /// Get the variant order (same as derive(Ord) for enums)
174    fn variant_order(&self) -> u8 {
175        match self {
176            ExprKind::Lit(_) => 0,
177            ExprKind::Var(_) => 1,
178            ExprKind::Slot(_) => 2,
179            ExprKind::Unknown(_) => 3,
180            ExprKind::If { .. } => 4,
181            ExprKind::And { .. } => 5,
182            ExprKind::Or { .. } => 6,
183            ExprKind::UnaryApp { .. } => 7,
184            ExprKind::BinaryApp { .. } => 8,
185            ExprKind::ExtensionFunctionApp { .. } => 9,
186            ExprKind::GetAttr { .. } => 10,
187            ExprKind::HasAttr { .. } => 11,
188            ExprKind::Like { .. } => 12,
189            ExprKind::Set(_) => 13,
190            ExprKind::Record(_) => 14,
191            ExprKind::Is { .. } => 15,
192            #[cfg(feature = "tolerant-ast")]
193            ExprKind::Error { .. } => 16,
194        }
195    }
196}
197
198impl From<Value> for Expr {
199    fn from(v: Value) -> Self {
200        Expr::from(v.value).with_maybe_source_loc(v.loc)
201    }
202}
203
204impl From<ValueKind> for Expr {
205    fn from(v: ValueKind) -> Self {
206        match v {
207            ValueKind::Lit(lit) => Expr::val(lit),
208            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
209            #[expect(
210                clippy::expect_used,
211                reason = "cannot have duplicate key because the input was already a BTreeMap"
212            )]
213            ValueKind::Record(record) => Expr::record(
214                Arc::unwrap_or_clone(record)
215                    .into_iter()
216                    .map(|(k, v)| (k, Expr::from(v))),
217            )
218            .expect("cannot have duplicate key because the input was already a BTreeMap"),
219            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
220        }
221    }
222}
223
224impl From<PartialValue> for Expr {
225    fn from(pv: PartialValue) -> Self {
226        match pv {
227            PartialValue::Value(v) => Expr::from(v),
228            PartialValue::Residual(expr) => expr,
229        }
230    }
231}
232
233impl<T> Expr<T> {
234    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
235        Self {
236            expr_kind,
237            source_loc,
238            data,
239        }
240    }
241
242    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
243    /// `enum` which specifies the expression variant, so it must be accessed by
244    /// any code matching and recursing on an expression.
245    pub fn expr_kind(&self) -> &ExprKind<T> {
246        &self.expr_kind
247    }
248
249    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
250    pub fn into_expr_kind(self) -> ExprKind<T> {
251        self.expr_kind
252    }
253
254    /// Access the data stored on the `Expr`.
255    pub fn data(&self) -> &T {
256        &self.data
257    }
258
259    /// Access the data stored on the `Expr`, taking ownership and consuming the
260    /// `Expr`.
261    pub fn into_data(self) -> T {
262        self.data
263    }
264
265    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
266    /// data.
267    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
268        (self.expr_kind, self.source_loc, self.data)
269    }
270
271    /// Access the `Loc` stored on the `Expr`.
272    pub fn source_loc(&self) -> Option<&Loc> {
273        self.source_loc.as_ref()
274    }
275
276    /// Return the `Expr`, but with the new `source_loc` (or `None`).
277    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
278        Self { source_loc, ..self }
279    }
280
281    /// Update the data for this `Expr`. A convenient function used by the
282    /// Validator in one place.
283    pub fn set_data(&mut self, data: T) {
284        self.data = data;
285    }
286
287    /// Check whether this expression is an entity reference
288    ///
289    /// This is used for policy scopes, where some syntax is
290    /// required to be an entity reference.
291    pub fn is_ref(&self) -> bool {
292        match &self.expr_kind {
293            ExprKind::Lit(lit) => lit.is_ref(),
294            _ => false,
295        }
296    }
297
298    /// Check whether this expression is a slot.
299    pub fn is_slot(&self) -> bool {
300        matches!(&self.expr_kind, ExprKind::Slot(_))
301    }
302
303    /// Check whether this expression is a set of entity references
304    ///
305    /// This is used for policy scopes, where some syntax is
306    /// required to be an entity reference set.
307    pub fn is_ref_set(&self) -> bool {
308        match &self.expr_kind {
309            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
310            _ => false,
311        }
312    }
313
314    /// Iterate over all sub-expressions in this expression
315    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
316        expr_iterator::ExprIterator::new(self)
317    }
318
319    /// Iterate over all of the slots in this policy AST
320    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
321        self.subexpressions()
322            .filter_map(|exp| match &exp.expr_kind {
323                ExprKind::Slot(slotid) => Some(Slot {
324                    id: *slotid,
325                    loc: exp.source_loc().cloned(),
326                }),
327                _ => None,
328            })
329    }
330
331    /// Determine if the expression is projectable under partial evaluation
332    /// An expression is projectable if it's guaranteed to never error on evaluation
333    /// This is true if the expression is entirely composed of values or unknowns
334    pub fn is_projectable(&self) -> bool {
335        self.subexpressions().all(|e| {
336            matches!(
337                e.expr_kind(),
338                ExprKind::Lit(_)
339                    | ExprKind::Unknown(_)
340                    | ExprKind::Set(_)
341                    | ExprKind::Var(_)
342                    | ExprKind::Record(_)
343            )
344        })
345    }
346
347    /// Try to compute the runtime type of this expression. This operation may
348    /// fail (returning `None`), for example, when asked to get the type of any
349    /// variables, any attributes of entities or records, or an `unknown`
350    /// without an explicitly annotated type.
351    ///
352    /// Also note that this is _not_ typechecking the expression. It does not
353    /// check that the expression actually evaluates to a value (as opposed to
354    /// erroring).
355    ///
356    /// Because of these limitations, this function should only be used to
357    /// obtain a type for use in diagnostics such as error strings.
358    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
359        match &self.expr_kind {
360            ExprKind::Lit(l) => Some(l.type_of()),
361            ExprKind::Var(_) => None,
362            ExprKind::Slot(_) => None,
363            ExprKind::Unknown(u) => u.type_annotation.clone(),
364            ExprKind::If {
365                then_expr,
366                else_expr,
367                ..
368            } => {
369                let type_of_then = then_expr.try_type_of(extensions);
370                let type_of_else = else_expr.try_type_of(extensions);
371                if type_of_then == type_of_else {
372                    type_of_then
373                } else {
374                    None
375                }
376            }
377            ExprKind::And { .. } => Some(Type::Bool),
378            ExprKind::Or { .. } => Some(Type::Bool),
379            ExprKind::UnaryApp {
380                op: UnaryOp::Neg, ..
381            } => Some(Type::Long),
382            ExprKind::UnaryApp {
383                op: UnaryOp::Not, ..
384            } => Some(Type::Bool),
385            ExprKind::UnaryApp {
386                op: UnaryOp::IsEmpty,
387                ..
388            } => Some(Type::Bool),
389            ExprKind::BinaryApp {
390                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
391                ..
392            } => Some(Type::Long),
393            ExprKind::BinaryApp {
394                op:
395                    BinaryOp::Contains
396                    | BinaryOp::ContainsAll
397                    | BinaryOp::ContainsAny
398                    | BinaryOp::Eq
399                    | BinaryOp::In
400                    | BinaryOp::Less
401                    | BinaryOp::LessEq,
402                ..
403            } => Some(Type::Bool),
404            ExprKind::BinaryApp {
405                op: BinaryOp::HasTag,
406                ..
407            } => Some(Type::Bool),
408            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
409                .func(fn_name)
410                .ok()?
411                .return_type()
412                .map(|rty| rty.clone().into()),
413            // We could try to be more complete here, but we can't do all that
414            // much better without evaluating the argument. Even if we know it's
415            // a record `Type::Record` tells us nothing about the type of the
416            // attribute.
417            ExprKind::GetAttr { .. } => None,
418            // similarly to `GetAttr`
419            ExprKind::BinaryApp {
420                op: BinaryOp::GetTag,
421                ..
422            } => None,
423            ExprKind::HasAttr { .. } => Some(Type::Bool),
424            ExprKind::Like { .. } => Some(Type::Bool),
425            ExprKind::Is { .. } => Some(Type::Bool),
426            ExprKind::Set(_) => Some(Type::Set),
427            ExprKind::Record(_) => Some(Type::Record),
428            #[cfg(feature = "tolerant-ast")]
429            ExprKind::Error { .. } => None,
430        }
431    }
432
433    /// Converts an `Expr<V>` to `B::Expr` using the provided builder.
434    ///
435    /// Preserves source location information and recursively transforms each expression node.
436    /// Note: Data may be cloned if the source expression is retained elsewhere.
437    pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
438    where
439        T: Clone,
440    {
441        let builder = B::new().with_maybe_source_loc(self.source_loc());
442        match self.into_expr_kind() {
443            ExprKind::Lit(lit) => builder.val(lit),
444            ExprKind::Var(var) => builder.var(var),
445            ExprKind::Slot(slot) => builder.slot(slot),
446            ExprKind::Unknown(u) => builder.unknown(u),
447            ExprKind::If {
448                test_expr,
449                then_expr,
450                else_expr,
451            } => builder.ite(
452                Arc::unwrap_or_clone(test_expr).into_expr::<B>(),
453                Arc::unwrap_or_clone(then_expr).into_expr::<B>(),
454                Arc::unwrap_or_clone(else_expr).into_expr::<B>(),
455            ),
456            ExprKind::And { left, right } => builder.and(
457                Arc::unwrap_or_clone(left).into_expr::<B>(),
458                Arc::unwrap_or_clone(right).into_expr::<B>(),
459            ),
460            ExprKind::Or { left, right } => builder.or(
461                Arc::unwrap_or_clone(left).into_expr::<B>(),
462                Arc::unwrap_or_clone(right).into_expr::<B>(),
463            ),
464            ExprKind::UnaryApp { op, arg } => {
465                let arg = Arc::unwrap_or_clone(arg).into_expr::<B>();
466                builder.unary_app(op, arg)
467            }
468            ExprKind::BinaryApp { op, arg1, arg2 } => {
469                let arg1 = Arc::unwrap_or_clone(arg1).into_expr::<B>();
470                let arg2 = Arc::unwrap_or_clone(arg2).into_expr::<B>();
471                builder.binary_app(op, arg1, arg2)
472            }
473            ExprKind::ExtensionFunctionApp { fn_name, args } => {
474                let args = Arc::unwrap_or_clone(args)
475                    .into_iter()
476                    .map(|e| e.into_expr::<B>());
477                builder.call_extension_fn(fn_name, args)
478            }
479            ExprKind::GetAttr { expr, attr } => {
480                builder.get_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
481            }
482            ExprKind::HasAttr { expr, attr } => {
483                builder.has_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
484            }
485            ExprKind::Like { expr, pattern } => {
486                builder.like(Arc::unwrap_or_clone(expr).into_expr::<B>(), pattern)
487            }
488            ExprKind::Is { expr, entity_type } => {
489                builder.is_entity_type(Arc::unwrap_or_clone(expr).into_expr::<B>(), entity_type)
490            }
491            ExprKind::Set(set) => builder.set(
492                Arc::unwrap_or_clone(set)
493                    .into_iter()
494                    .map(|e| e.into_expr::<B>()),
495            ),
496            #[expect(
497                clippy::unwrap_used,
498                reason = "`map` is a map, so it will not have duplicate keys, so the `.record()` constructor cannot error"
499            )]
500            ExprKind::Record(map) => builder
501                .record(
502                    Arc::unwrap_or_clone(map)
503                        .into_iter()
504                        .map(|(k, v)| (k, v.into_expr::<B>())),
505                )
506                .unwrap(),
507            #[cfg(feature = "tolerant-ast")]
508            #[expect(
509                clippy::unwrap_used,
510                reason = "error type is Infallible so can never happen"
511            )]
512            ExprKind::Error { .. } => builder
513                .error(ParseErrors::singleton(ToASTError::new(
514                    ToASTErrorKind::ASTErrorNode,
515                    Some(Loc::new(0..1, "AST_ERROR_NODE".into())),
516                )))
517                .unwrap(),
518        }
519    }
520}
521
522#[expect(
523    clippy::should_implement_trait,
524    reason = "the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`"
525)]
526impl Expr {
527    /// Create an `Expr` that's just a single `Literal`.
528    ///
529    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
530    pub fn val(v: impl Into<Literal>) -> Self {
531        ExprBuilder::new().val(v)
532    }
533
534    /// Create an `Expr` that's just a single `Unknown`.
535    pub fn unknown(u: Unknown) -> Self {
536        ExprBuilder::new().unknown(u)
537    }
538
539    /// Create an `Expr` that's just this literal `Var`
540    pub fn var(v: Var) -> Self {
541        ExprBuilder::new().var(v)
542    }
543
544    /// Create an `Expr` that's just this `SlotId`
545    pub fn slot(s: SlotId) -> Self {
546        ExprBuilder::new().slot(s)
547    }
548
549    /// Create a ternary (if-then-else) `Expr`.
550    ///
551    /// `test_expr` must evaluate to a Bool type
552    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
553        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
554    }
555
556    /// Create a ternary (if-then-else) `Expr`.
557    /// Takes `Arc`s instead of owned `Expr`s.
558    /// `test_expr` must evaluate to a Bool type
559    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
560        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
561    }
562
563    /// Create a 'not' expression. `e` must evaluate to Bool type
564    pub fn not(e: Expr) -> Self {
565        ExprBuilder::new().not(e)
566    }
567
568    /// Create a '==' expression
569    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
570        ExprBuilder::new().is_eq(e1, e2)
571    }
572
573    /// Create a '!=' expression
574    pub fn noteq(e1: Expr, e2: Expr) -> Self {
575        ExprBuilder::new().noteq(e1, e2)
576    }
577
578    /// Create an 'and' expression. Arguments must evaluate to Bool type
579    pub fn and(e1: Expr, e2: Expr) -> Self {
580        ExprBuilder::new().and(e1, e2)
581    }
582
583    /// Create an 'or' expression. Arguments must evaluate to Bool type
584    pub fn or(e1: Expr, e2: Expr) -> Self {
585        ExprBuilder::new().or(e1, e2)
586    }
587
588    /// Create a '<' expression. Arguments must evaluate to Long type
589    pub fn less(e1: Expr, e2: Expr) -> Self {
590        ExprBuilder::new().less(e1, e2)
591    }
592
593    /// Create a '<=' expression. Arguments must evaluate to Long type
594    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
595        ExprBuilder::new().lesseq(e1, e2)
596    }
597
598    /// Create a '>' expression. Arguments must evaluate to Long type
599    pub fn greater(e1: Expr, e2: Expr) -> Self {
600        ExprBuilder::new().greater(e1, e2)
601    }
602
603    /// Create a '>=' expression. Arguments must evaluate to Long type
604    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
605        ExprBuilder::new().greatereq(e1, e2)
606    }
607
608    /// Create an 'add' expression. Arguments must evaluate to Long type
609    pub fn add(e1: Expr, e2: Expr) -> Self {
610        ExprBuilder::new().add(e1, e2)
611    }
612
613    /// Create a 'sub' expression. Arguments must evaluate to Long type
614    pub fn sub(e1: Expr, e2: Expr) -> Self {
615        ExprBuilder::new().sub(e1, e2)
616    }
617
618    /// Create a 'mul' expression. Arguments must evaluate to Long type
619    pub fn mul(e1: Expr, e2: Expr) -> Self {
620        ExprBuilder::new().mul(e1, e2)
621    }
622
623    /// Create a 'neg' expression. `e` must evaluate to Long type.
624    pub fn neg(e: Expr) -> Self {
625        ExprBuilder::new().neg(e)
626    }
627
628    /// Create an 'in' expression. First argument must evaluate to Entity type.
629    /// Second argument must evaluate to either Entity type or Set type where
630    /// all set elements have Entity type.
631    pub fn is_in(e1: Expr, e2: Expr) -> Self {
632        ExprBuilder::new().is_in(e1, e2)
633    }
634
635    /// Create a `contains` expression.
636    /// First argument must have Set type.
637    pub fn contains(e1: Expr, e2: Expr) -> Self {
638        ExprBuilder::new().contains(e1, e2)
639    }
640
641    /// Create a `containsAll` expression. Arguments must evaluate to Set type
642    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
643        ExprBuilder::new().contains_all(e1, e2)
644    }
645
646    /// Create a `containsAny` expression. Arguments must evaluate to Set type
647    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
648        ExprBuilder::new().contains_any(e1, e2)
649    }
650
651    /// Create a `isEmpty` expression. Argument must evaluate to Set type
652    pub fn is_empty(e: Expr) -> Self {
653        ExprBuilder::new().is_empty(e)
654    }
655
656    /// Create a `getTag` expression.
657    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
658    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
659        ExprBuilder::new().get_tag(expr, tag)
660    }
661
662    /// Create a `hasTag` expression.
663    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
664    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
665        ExprBuilder::new().has_tag(expr, tag)
666    }
667
668    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
669    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
670        ExprBuilder::new().set(exprs)
671    }
672
673    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
674    pub fn record(
675        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
676    ) -> Result<Self, ExpressionConstructionError> {
677        ExprBuilder::new().record(pairs)
678    }
679
680    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
681    ///
682    /// If you have an iterator of pairs, generally prefer calling
683    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
684    /// potentially for efficiency reasons but also because `Expr::record()`
685    /// will properly handle duplicate keys but your own `.collect()` will not
686    /// (by default).
687    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
688        ExprBuilder::new().record_arc(map)
689    }
690
691    /// Create an `Expr` which calls the extension function with the given
692    /// `Name` on `args`
693    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
694        ExprBuilder::new().call_extension_fn(fn_name, args)
695    }
696
697    /// Create an application `Expr` which applies the given built-in unary
698    /// operator to the given `arg`
699    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
700        ExprBuilder::new().unary_app(op, arg)
701    }
702
703    /// Create an application `Expr` which applies the given built-in binary
704    /// operator to `arg1` and `arg2`
705    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
706        ExprBuilder::new().binary_app(op, arg1, arg2)
707    }
708
709    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
710    ///
711    /// `expr` must evaluate to either Entity or Record type
712    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
713        ExprBuilder::new().get_attr(expr, attr)
714    }
715
716    /// Create an `Expr` which tests for the existence of a given
717    /// attribute on a given `Entity` or record.
718    ///
719    /// `expr` must evaluate to either Entity or Record type
720    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
721        ExprBuilder::new().has_attr(expr, attr)
722    }
723
724    /// Create a 'like' expression.
725    ///
726    /// `expr` must evaluate to a String type
727    pub fn like(expr: Expr, pattern: Pattern) -> Self {
728        ExprBuilder::new().like(expr, pattern)
729    }
730
731    /// Create an `is` expression.
732    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
733        ExprBuilder::new().is_entity_type(expr, entity_type)
734    }
735
736    /// Check if an expression contains any symbolic unknowns
737    pub fn contains_unknown(&self) -> bool {
738        self.subexpressions()
739            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
740    }
741
742    /// Get all unknowns in an expression
743    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
744        self.subexpressions()
745            .filter_map(|subexpr| match subexpr.expr_kind() {
746                ExprKind::Unknown(u) => Some(u),
747                _ => None,
748            })
749    }
750
751    /// Substitute unknowns with concrete values.
752    ///
753    /// Ignores unmapped unknowns.
754    /// Ignores type annotations on unknowns.
755    /// Note that there might be "undiscovered unknowns" in the Expr, which
756    /// this function does not notice if evaluation of this Expr did not
757    /// traverse all entities and attributes during evaluation, leading to
758    /// this function only substituting one unknown at a time.
759    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
760        match self.substitute_general::<UntypedSubstitution>(definitions) {
761            Ok(e) => e,
762            Err(empty) => match empty {},
763        }
764    }
765
766    /// Substitute unknowns with concrete values.
767    ///
768    /// Ignores unmapped unknowns.
769    /// Errors if the substituted value does not match the type annotation on the unknown.
770    /// Note that there might be "undiscovered unknowns" in the Expr, which
771    /// this function does not notice if evaluation of this Expr did not
772    /// traverse all entities and attributes during evaluation, leading to
773    /// this function only substituting one unknown at a time.
774    pub fn substitute_typed(
775        &self,
776        definitions: &HashMap<SmolStr, Value>,
777    ) -> Result<Expr, SubstitutionError> {
778        self.substitute_general::<TypedSubstitution>(definitions)
779    }
780
781    /// Substitute unknowns with values
782    ///
783    /// Generic over the function implementing the substitution to allow for multiple error behaviors
784    fn substitute_general<T: SubstitutionFunction>(
785        &self,
786        definitions: &HashMap<SmolStr, Value>,
787    ) -> Result<Expr, T::Err> {
788        match self.expr_kind() {
789            ExprKind::Lit(_) => Ok(self.clone()),
790            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
791            ExprKind::Var(_) => Ok(self.clone()),
792            ExprKind::Slot(_) => Ok(self.clone()),
793            ExprKind::If {
794                test_expr,
795                then_expr,
796                else_expr,
797            } => Ok(Expr::ite(
798                test_expr.substitute_general::<T>(definitions)?,
799                then_expr.substitute_general::<T>(definitions)?,
800                else_expr.substitute_general::<T>(definitions)?,
801            )),
802            ExprKind::And { left, right } => Ok(Expr::and(
803                left.substitute_general::<T>(definitions)?,
804                right.substitute_general::<T>(definitions)?,
805            )),
806            ExprKind::Or { left, right } => Ok(Expr::or(
807                left.substitute_general::<T>(definitions)?,
808                right.substitute_general::<T>(definitions)?,
809            )),
810            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
811                *op,
812                arg.substitute_general::<T>(definitions)?,
813            )),
814            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
815                *op,
816                arg1.substitute_general::<T>(definitions)?,
817                arg2.substitute_general::<T>(definitions)?,
818            )),
819            ExprKind::ExtensionFunctionApp { fn_name, args } => {
820                let args = args
821                    .iter()
822                    .map(|e| e.substitute_general::<T>(definitions))
823                    .collect::<Result<Vec<Expr>, _>>()?;
824
825                Ok(Expr::call_extension_fn(fn_name.clone(), args))
826            }
827            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
828                expr.substitute_general::<T>(definitions)?,
829                attr.clone(),
830            )),
831            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
832                expr.substitute_general::<T>(definitions)?,
833                attr.clone(),
834            )),
835            ExprKind::Like { expr, pattern } => Ok(Expr::like(
836                expr.substitute_general::<T>(definitions)?,
837                pattern.clone(),
838            )),
839            ExprKind::Set(members) => {
840                let members = members
841                    .iter()
842                    .map(|e| e.substitute_general::<T>(definitions))
843                    .collect::<Result<Vec<_>, _>>()?;
844                Ok(Expr::set(members))
845            }
846            ExprKind::Record(map) => {
847                let map = map
848                    .iter()
849                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
850                    .collect::<Result<BTreeMap<_, _>, _>>()?;
851                #[expect(
852                    clippy::expect_used,
853                    reason = "cannot have a duplicate key because the input was already a BTreeMap"
854                )]
855                Ok(Expr::record(map)
856                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
857            }
858            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
859                expr.substitute_general::<T>(definitions)?,
860                entity_type.clone(),
861            )),
862            #[cfg(feature = "tolerant-ast")]
863            ExprKind::Error { .. } => Ok(self.clone()),
864        }
865    }
866}
867
/// A trait for customizing the error behavior of substitution
///
/// Each implementor is a strategy selected by the caller of
/// `substitute_general`: it decides what an `Unknown` node is replaced with
/// and which errors (if any) the replacement can produce.
trait SubstitutionFunction {
    /// The potential errors this substitution function can return
    type Err;
    /// The function for implementing the substitution.
    ///
    /// Takes the `Unknown` being substituted (which carries its own optional
    /// type annotation) and the replacement value found for it in the
    /// substitution map, if one is present. Returns the expression that should
    /// take the unknown's place.
    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
}
879
880struct TypedSubstitution {}
881
882impl SubstitutionFunction for TypedSubstitution {
883    type Err = SubstitutionError;
884
885    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
886        match (substitute, &value.type_annotation) {
887            (None, _) => Ok(Expr::unknown(value.clone())),
888            (Some(v), None) => Ok(v.clone().into()),
889            (Some(v), Some(t)) => {
890                if v.type_of() == *t {
891                    Ok(v.clone().into())
892                } else {
893                    Err(SubstitutionError::TypeError {
894                        expected: t.clone(),
895                        actual: v.type_of(),
896                    })
897                }
898            }
899        }
900    }
901}
902
903struct UntypedSubstitution {}
904
905impl SubstitutionFunction for UntypedSubstitution {
906    type Err = std::convert::Infallible;
907
908    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
909        Ok(substitute
910            .map(|v| v.clone().into())
911            .unwrap_or_else(|| Expr::unknown(value.clone())))
912    }
913}
914
915impl<T: Clone> std::fmt::Display for Expr<T> {
916    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
917        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
918        // we just convert to EST and use the EST pretty-printer.
919        // Note that converting AST->EST is lossless and infallible.
920        write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
921    }
922}
923
924impl<T: Clone> BoundedDisplay for Expr<T> {
925    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
926        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
927        // pretty-printer. Note that converting AST->EST is lossless and infallible.
928        BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
929    }
930}
931
932impl std::str::FromStr for Expr {
933    type Err = ParseErrors;
934
935    fn from_str(s: &str) -> Result<Expr, Self::Err> {
936        crate::parser::parse_expr(s)
937    }
938}
939
/// Enum for errors encountered during substitution
///
/// This is the error type of `Expr::substitute_typed`.
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
952
/// Representation of a partial-evaluation Unknown at the AST level
///
/// An unknown is identified by its `name`; substitution looks up the
/// replacement value under that name.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// The name of the unknown
    pub name: SmolStr,
    /// The type of the values that can be substituted in for the unknown.
    /// If `None`, we have no type annotation, and thus a value of any type can
    /// be substituted.
    pub type_annotation: Option<Type>,
}
963
964impl Unknown {
965    /// Create a new untyped `Unknown`
966    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
967        Self {
968            name: name.into(),
969            type_annotation: None,
970        }
971    }
972
973    /// Create a new `Unknown` with type annotation. (Only values of the given
974    /// type can be substituted.)
975    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
976        Self {
977            name: name.into(),
978            type_annotation: Some(ty),
979        }
980    }
981}
982
983impl std::fmt::Display for Unknown {
984    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
985        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
986        // to avoid code duplication
987        write!(
988            f,
989            "{}",
990            Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
991        )
992    }
993}
994
/// Builder for constructing `Expr` objects annotated with some `data`
/// (possibly taking default value) and optionally a `source_loc`.
///
/// The builder is consumed by each constructor method: its `data` and
/// `source_loc` are moved into the `Expr` that is produced.
#[derive(Clone, Debug)]
pub struct ExprBuilder<T> {
    /// Source location to attach to the built expression, if any
    source_loc: Option<Loc>,
    /// Data to attach to the built expression
    data: T,
}
1002
1003impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
1004    type Expr = Expr<T>;
1005
1006    type Data = T;
1007
1008    #[cfg(feature = "tolerant-ast")]
1009    type ErrorType = ParseErrors;
1010
1011    fn loc(&self) -> Option<&Loc> {
1012        self.source_loc.as_ref()
1013    }
1014
1015    fn data(&self) -> &Self::Data {
1016        &self.data
1017    }
1018
1019    fn with_data(data: T) -> Self {
1020        Self {
1021            source_loc: None,
1022            data,
1023        }
1024    }
1025
1026    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
1027        self.source_loc = maybe_source_loc.cloned();
1028        self
1029    }
1030
1031    /// Create an `Expr` that's just a single `Literal`.
1032    ///
1033    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
1034    fn val(self, v: impl Into<Literal>) -> Expr<T> {
1035        self.with_expr_kind(ExprKind::Lit(v.into()))
1036    }
1037
1038    /// Create an `Unknown` `Expr`
1039    fn unknown(self, u: Unknown) -> Expr<T> {
1040        self.with_expr_kind(ExprKind::Unknown(u))
1041    }
1042
1043    /// Create an `Expr` that's just this literal `Var`
1044    fn var(self, v: Var) -> Expr<T> {
1045        self.with_expr_kind(ExprKind::Var(v))
1046    }
1047
1048    /// Create an `Expr` that's just this `SlotId`
1049    fn slot(self, s: SlotId) -> Expr<T> {
1050        self.with_expr_kind(ExprKind::Slot(s))
1051    }
1052
1053    /// Create a ternary (if-then-else) `Expr`.
1054    ///
1055    /// `test_expr` must evaluate to a Bool type
1056    fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
1057        self.with_expr_kind(ExprKind::If {
1058            test_expr: Arc::new(test_expr),
1059            then_expr: Arc::new(then_expr),
1060            else_expr: Arc::new(else_expr),
1061        })
1062    }
1063
1064    /// Create a 'not' expression. `e` must evaluate to Bool type
1065    fn not(self, e: Expr<T>) -> Expr<T> {
1066        self.with_expr_kind(ExprKind::UnaryApp {
1067            op: UnaryOp::Not,
1068            arg: Arc::new(e),
1069        })
1070    }
1071
1072    /// Create a '==' expression
1073    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1074        self.with_expr_kind(ExprKind::BinaryApp {
1075            op: BinaryOp::Eq,
1076            arg1: Arc::new(e1),
1077            arg2: Arc::new(e2),
1078        })
1079    }
1080
1081    /// Create an 'and' expression. Arguments must evaluate to Bool type
1082    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1083        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1084            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1085                ExprKind::Lit(Literal::Bool(*b1 && *b2))
1086            }
1087            _ => ExprKind::And {
1088                left: Arc::new(e1),
1089                right: Arc::new(e2),
1090            },
1091        })
1092    }
1093
1094    /// Create an 'or' expression. Arguments must evaluate to Bool type
1095    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1096        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1097            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1098                ExprKind::Lit(Literal::Bool(*b1 || *b2))
1099            }
1100
1101            _ => ExprKind::Or {
1102                left: Arc::new(e1),
1103                right: Arc::new(e2),
1104            },
1105        })
1106    }
1107
1108    /// Create a '<' expression. Arguments must evaluate to Long type
1109    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1110        self.with_expr_kind(ExprKind::BinaryApp {
1111            op: BinaryOp::Less,
1112            arg1: Arc::new(e1),
1113            arg2: Arc::new(e2),
1114        })
1115    }
1116
1117    /// Create a '<=' expression. Arguments must evaluate to Long type
1118    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1119        self.with_expr_kind(ExprKind::BinaryApp {
1120            op: BinaryOp::LessEq,
1121            arg1: Arc::new(e1),
1122            arg2: Arc::new(e2),
1123        })
1124    }
1125
1126    /// Create an 'add' expression. Arguments must evaluate to Long type
1127    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1128        self.with_expr_kind(ExprKind::BinaryApp {
1129            op: BinaryOp::Add,
1130            arg1: Arc::new(e1),
1131            arg2: Arc::new(e2),
1132        })
1133    }
1134
1135    /// Create a 'sub' expression. Arguments must evaluate to Long type
1136    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1137        self.with_expr_kind(ExprKind::BinaryApp {
1138            op: BinaryOp::Sub,
1139            arg1: Arc::new(e1),
1140            arg2: Arc::new(e2),
1141        })
1142    }
1143
1144    /// Create a 'mul' expression. Arguments must evaluate to Long type
1145    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1146        self.with_expr_kind(ExprKind::BinaryApp {
1147            op: BinaryOp::Mul,
1148            arg1: Arc::new(e1),
1149            arg2: Arc::new(e2),
1150        })
1151    }
1152
1153    /// Create a 'neg' expression. `e` must evaluate to Long type.
1154    fn neg(self, e: Expr<T>) -> Expr<T> {
1155        self.with_expr_kind(ExprKind::UnaryApp {
1156            op: UnaryOp::Neg,
1157            arg: Arc::new(e),
1158        })
1159    }
1160
1161    /// Create an 'in' expression. First argument must evaluate to Entity type.
1162    /// Second argument must evaluate to either Entity type or Set type where
1163    /// all set elements have Entity type.
1164    fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1165        self.with_expr_kind(ExprKind::BinaryApp {
1166            op: BinaryOp::In,
1167            arg1: Arc::new(e1),
1168            arg2: Arc::new(e2),
1169        })
1170    }
1171
1172    /// Create a 'contains' expression.
1173    /// First argument must have Set type.
1174    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1175        self.with_expr_kind(ExprKind::BinaryApp {
1176            op: BinaryOp::Contains,
1177            arg1: Arc::new(e1),
1178            arg2: Arc::new(e2),
1179        })
1180    }
1181
1182    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1183    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1184        self.with_expr_kind(ExprKind::BinaryApp {
1185            op: BinaryOp::ContainsAll,
1186            arg1: Arc::new(e1),
1187            arg2: Arc::new(e2),
1188        })
1189    }
1190
1191    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1192    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1193        self.with_expr_kind(ExprKind::BinaryApp {
1194            op: BinaryOp::ContainsAny,
1195            arg1: Arc::new(e1),
1196            arg2: Arc::new(e2),
1197        })
1198    }
1199
1200    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1201    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1202        self.with_expr_kind(ExprKind::UnaryApp {
1203            op: UnaryOp::IsEmpty,
1204            arg: Arc::new(expr),
1205        })
1206    }
1207
1208    /// Create a 'getTag' expression.
1209    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1210    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1211        self.with_expr_kind(ExprKind::BinaryApp {
1212            op: BinaryOp::GetTag,
1213            arg1: Arc::new(expr),
1214            arg2: Arc::new(tag),
1215        })
1216    }
1217
1218    /// Create a 'hasTag' expression.
1219    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1220    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1221        self.with_expr_kind(ExprKind::BinaryApp {
1222            op: BinaryOp::HasTag,
1223            arg1: Arc::new(expr),
1224            arg2: Arc::new(tag),
1225        })
1226    }
1227
1228    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1229    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1230        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1231    }
1232
1233    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1234    fn record(
1235        self,
1236        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1237    ) -> Result<Expr<T>, ExpressionConstructionError> {
1238        let mut map = BTreeMap::new();
1239        for (k, v) in pairs {
1240            match map.entry(k) {
1241                btree_map::Entry::Occupied(oentry) => {
1242                    return Err(expression_construction_errors::DuplicateKeyError {
1243                        key: oentry.key().clone(),
1244                        context: "in record literal",
1245                    }
1246                    .into());
1247                }
1248                btree_map::Entry::Vacant(ventry) => {
1249                    ventry.insert(v);
1250                }
1251            }
1252        }
1253        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1254    }
1255
1256    /// Create an `Expr` which calls the extension function with the given
1257    /// `Name` on `args`
1258    fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1259        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1260            fn_name,
1261            args: Arc::new(args.into_iter().collect()),
1262        })
1263    }
1264
1265    /// Create an application `Expr` which applies the given built-in unary
1266    /// operator to the given `arg`
1267    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1268        self.with_expr_kind(ExprKind::UnaryApp {
1269            op: op.into(),
1270            arg: Arc::new(arg),
1271        })
1272    }
1273
1274    /// Create an application `Expr` which applies the given built-in binary
1275    /// operator to `arg1` and `arg2`
1276    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1277        self.with_expr_kind(ExprKind::BinaryApp {
1278            op: op.into(),
1279            arg1: Arc::new(arg1),
1280            arg2: Arc::new(arg2),
1281        })
1282    }
1283
1284    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1285    ///
1286    /// `expr` must evaluate to either Entity or Record type
1287    fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1288        self.with_expr_kind(ExprKind::GetAttr {
1289            expr: Arc::new(expr),
1290            attr,
1291        })
1292    }
1293
1294    /// Create an `Expr` which tests for the existence of a given
1295    /// attribute on a given `Entity` or record.
1296    ///
1297    /// `expr` must evaluate to either Entity or Record type
1298    fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1299        self.with_expr_kind(ExprKind::HasAttr {
1300            expr: Arc::new(expr),
1301            attr,
1302        })
1303    }
1304
1305    /// Create a 'like' expression.
1306    ///
1307    /// `expr` must evaluate to a String type
1308    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1309        self.with_expr_kind(ExprKind::Like {
1310            expr: Arc::new(expr),
1311            pattern,
1312        })
1313    }
1314
1315    /// Create an 'is' expression.
1316    fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1317        self.with_expr_kind(ExprKind::Is {
1318            expr: Arc::new(expr),
1319            entity_type,
1320        })
1321    }
1322
1323    /// Don't support AST Error nodes - return the error right back
1324    #[cfg(feature = "tolerant-ast")]
1325    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1326        Err(parse_errors)
1327    }
1328}
1329
1330impl<T> ExprBuilder<T> {
1331    /// Construct an `Expr` containing the `data` and `source_loc` in this
1332    /// `ExprBuilder` and the given `ExprKind`.
1333    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1334        Expr::new(expr_kind, self.source_loc, self.data)
1335    }
1336
1337    /// Create a ternary (if-then-else) `Expr`.
1338    /// Takes `Arc`s instead of owned `Expr`s.
1339    /// `test_expr` must evaluate to a Bool type
1340    pub fn ite_arc(
1341        self,
1342        test_expr: Arc<Expr<T>>,
1343        then_expr: Arc<Expr<T>>,
1344        else_expr: Arc<Expr<T>>,
1345    ) -> Expr<T> {
1346        self.with_expr_kind(ExprKind::If {
1347            test_expr,
1348            then_expr,
1349            else_expr,
1350        })
1351    }
1352
1353    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1354    ///
1355    /// If you have an iterator of pairs, generally prefer calling `.record()`
1356    /// instead of `.collect()`-ing yourself and calling this, potentially for
1357    /// efficiency reasons but also because `.record()` will properly handle
1358    /// duplicate keys but your own `.collect()` will not (by default).
1359    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1360        self.with_expr_kind(ExprKind::Record(map))
1361    }
1362}
1363
1364impl<T: Clone + Default> ExprBuilder<T> {
1365    /// Utility used the validator to get an expression with the same source
1366    /// location as an existing expression. This is done when reconstructing the
1367    /// `Expr` with type information.
1368    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1369        self.with_maybe_source_loc(expr.source_loc.as_ref())
1370    }
1371}
1372
/// Errors when constructing an expression
///
/// Returned by the `record` constructor on `ExprBuilder` when a key is
/// supplied more than once.
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times
    // Both attributes are `transparent`, so the message and the diagnostic
    // are delegated to the wrapped `DuplicateKeyError`.
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1385
1386/// Error subtypes for [`ExpressionConstructionError`]
1387pub mod expression_construction_errors {
1388    use miette::Diagnostic;
1389    use smol_str::SmolStr;
1390    use thiserror::Error;
1391
1392    /// The same key occurred two or more times
1393    //
1394    // CAUTION: this type is publicly exported in `cedar-policy`.
1395    // Don't make fields `pub`, don't make breaking changes, and use caution
1396    // when adding public methods.
1397    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1398    #[error("duplicate key `{key}` {context}")]
1399    pub struct DuplicateKeyError {
1400        /// The key which occurred two or more times
1401        pub(crate) key: SmolStr,
1402        /// Information about where the duplicate key occurred (e.g., "in record literal")
1403        pub(crate) context: &'static str,
1404    }
1405
1406    impl DuplicateKeyError {
1407        /// Get the key which occurred two or more times
1408        pub fn key(&self) -> &str {
1409            &self.key
1410        }
1411
1412        /// Make a new error with an updated `context` field
1413        pub(crate) fn with_context(self, context: &'static str) -> Self {
1414            Self { context, ..self }
1415        }
1416    }
1417}
1418
1419/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1420/// implementations that ignore any source information or other generic data
1421/// used to annotate the `Expr`.
1422#[derive(Debug, Clone)]
1423pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1424
1425impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1426    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1427    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1428    /// ignore source information and generic data.
1429    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1430        ExprShapeOnly(Cow::Borrowed(e))
1431    }
1432
1433    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1434    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1435    /// ignore source information and generic data.
1436    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1437        ExprShapeOnly(Cow::Owned(e))
1438    }
1439}
1440
1441impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1442    fn eq(&self, other: &Self) -> bool {
1443        self.0.eq_shape(&other.0)
1444    }
1445}
1446
1447impl<T: Clone> Eq for ExprShapeOnly<'_, T> {}
1448
1449impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1450    fn hash<H: Hasher>(&self, state: &mut H) {
1451        self.0.hash_shape(state);
1452    }
1453}
1454
1455impl<T: Clone> PartialOrd for ExprShapeOnly<'_, T> {
1456    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1457        Some(self.cmp(other))
1458    }
1459}
1460
1461impl<T: Clone> Ord for ExprShapeOnly<'_, T> {
1462    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1463        self.0.cmp_shape(&other.0)
1464    }
1465}
1466
1467impl<T> Expr<T> {
1468    /// Return true if this expression (recursively) has the same expression
1469    /// kind as the argument expression. This accounts for the full recursive
1470    /// shape of the expression, but does not consider source information or any
1471    /// generic data annotated on expression. This should behave the same as the
1472    /// default implementation of `Eq` before source information and generic
1473    /// data were added.
1474    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1475        use ExprKind::*;
1476        match (self.expr_kind(), other.expr_kind()) {
1477            (Lit(lit), Lit(lit1)) => lit == lit1,
1478            (Var(v), Var(v1)) => v == v1,
1479            (Slot(s), Slot(s1)) => s == s1,
1480            (
1481                Unknown(self::Unknown {
1482                    name: name1,
1483                    type_annotation: ta_1,
1484                }),
1485                Unknown(self::Unknown {
1486                    name: name2,
1487                    type_annotation: ta_2,
1488                }),
1489            ) => (name1 == name2) && (ta_1 == ta_2),
1490            (
1491                If {
1492                    test_expr,
1493                    then_expr,
1494                    else_expr,
1495                },
1496                If {
1497                    test_expr: test_expr1,
1498                    then_expr: then_expr1,
1499                    else_expr: else_expr1,
1500                },
1501            ) => {
1502                test_expr.eq_shape(test_expr1)
1503                    && then_expr.eq_shape(then_expr1)
1504                    && else_expr.eq_shape(else_expr1)
1505            }
1506            (
1507                And { left, right },
1508                And {
1509                    left: left1,
1510                    right: right1,
1511                },
1512            )
1513            | (
1514                Or { left, right },
1515                Or {
1516                    left: left1,
1517                    right: right1,
1518                },
1519            ) => left.eq_shape(left1) && right.eq_shape(right1),
1520            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1521                op == op1 && arg.eq_shape(arg1)
1522            }
1523            (
1524                BinaryApp { op, arg1, arg2 },
1525                BinaryApp {
1526                    op: op1,
1527                    arg1: arg11,
1528                    arg2: arg21,
1529                },
1530            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1531            (
1532                ExtensionFunctionApp { fn_name, args },
1533                ExtensionFunctionApp {
1534                    fn_name: fn_name1,
1535                    args: args1,
1536                },
1537            ) => {
1538                fn_name == fn_name1
1539                    && args.len() == args1.len()
1540                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1541            }
1542            (
1543                GetAttr { expr, attr },
1544                GetAttr {
1545                    expr: expr1,
1546                    attr: attr1,
1547                },
1548            )
1549            | (
1550                HasAttr { expr, attr },
1551                HasAttr {
1552                    expr: expr1,
1553                    attr: attr1,
1554                },
1555            ) => attr == attr1 && expr.eq_shape(expr1),
1556            (
1557                Like { expr, pattern },
1558                Like {
1559                    expr: expr1,
1560                    pattern: pattern1,
1561                },
1562            ) => pattern == pattern1 && expr.eq_shape(expr1),
1563            (Set(elems), Set(elems1)) => {
1564                elems.len() == elems1.len()
1565                    && elems
1566                        .iter()
1567                        .zip(elems1.iter())
1568                        .all(|(e, e1)| e.eq_shape(e1))
1569            }
1570            (Record(map), Record(map1)) => {
1571                map.len() == map1.len()
1572                    && map
1573                        .iter()
1574                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1575                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1576            }
1577            (
1578                Is { expr, entity_type },
1579                Is {
1580                    expr: expr1,
1581                    entity_type: entity_type1,
1582                },
1583            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1584            _ => false,
1585        }
1586    }
1587
1588    /// Implementation of hashing corresponding to equality as implemented by
1589    /// `eq_shape`. Must satisfy the usual relationship between equality and
1590    /// hashing.
1591    pub fn hash_shape<H>(&self, state: &mut H)
1592    where
1593        H: Hasher,
1594    {
1595        mem::discriminant(self).hash(state);
1596        match self.expr_kind() {
1597            ExprKind::Lit(lit) => lit.hash(state),
1598            ExprKind::Var(v) => v.hash(state),
1599            ExprKind::Slot(s) => s.hash(state),
1600            ExprKind::Unknown(u) => u.hash(state),
1601            ExprKind::If {
1602                test_expr,
1603                then_expr,
1604                else_expr,
1605            } => {
1606                test_expr.hash_shape(state);
1607                then_expr.hash_shape(state);
1608                else_expr.hash_shape(state);
1609            }
1610            ExprKind::And { left, right } => {
1611                left.hash_shape(state);
1612                right.hash_shape(state);
1613            }
1614            ExprKind::Or { left, right } => {
1615                left.hash_shape(state);
1616                right.hash_shape(state);
1617            }
1618            ExprKind::UnaryApp { op, arg } => {
1619                op.hash(state);
1620                arg.hash_shape(state);
1621            }
1622            ExprKind::BinaryApp { op, arg1, arg2 } => {
1623                op.hash(state);
1624                arg1.hash_shape(state);
1625                arg2.hash_shape(state);
1626            }
1627            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1628                fn_name.hash(state);
1629                state.write_usize(args.len());
1630                args.iter().for_each(|a| {
1631                    a.hash_shape(state);
1632                });
1633            }
1634            ExprKind::GetAttr { expr, attr } => {
1635                expr.hash_shape(state);
1636                attr.hash(state);
1637            }
1638            ExprKind::HasAttr { expr, attr } => {
1639                expr.hash_shape(state);
1640                attr.hash(state);
1641            }
1642            ExprKind::Like { expr, pattern } => {
1643                expr.hash_shape(state);
1644                pattern.hash(state);
1645            }
1646            ExprKind::Set(elems) => {
1647                state.write_usize(elems.len());
1648                elems.iter().for_each(|e| {
1649                    e.hash_shape(state);
1650                })
1651            }
1652            ExprKind::Record(map) => {
1653                state.write_usize(map.len());
1654                map.iter().for_each(|(s, a)| {
1655                    s.hash(state);
1656                    a.hash_shape(state);
1657                });
1658            }
1659            ExprKind::Is { expr, entity_type } => {
1660                expr.hash_shape(state);
1661                entity_type.hash(state);
1662            }
1663            #[cfg(feature = "tolerant-ast")]
1664            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1665        }
1666    }
1667
1668    /// Implementation of ordering corresponding to equality as implemented by
1669    /// `eq_shape`. Must satisfy the usual relationship between equality and
1670    /// ordering.
1671    pub fn cmp_shape(&self, other: &Expr<T>) -> std::cmp::Ordering {
1672        // First compare variants for early short-circuiting using discriminant
1673        let self_kind = self.expr_kind();
1674        let other_kind = other.expr_kind();
1675        if std::mem::discriminant(self_kind) != std::mem::discriminant(other_kind) {
1676            return self_kind.variant_order().cmp(&other_kind.variant_order());
1677        }
1678
1679        // Same variants, compare contents
1680        use ExprKind::*;
1681        match (self_kind, other_kind) {
1682            (Lit(lit), Lit(lit1)) => lit.cmp(lit1),
1683            (Var(v), Var(v1)) => v.cmp(v1),
1684            (Slot(s), Slot(s1)) => s.cmp(s1),
1685            (
1686                Unknown(self::Unknown {
1687                    name: name1,
1688                    type_annotation: ta_1,
1689                }),
1690                Unknown(self::Unknown {
1691                    name: name2,
1692                    type_annotation: ta_2,
1693                }),
1694            ) => name1.cmp(name2).then_with(|| ta_1.cmp(ta_2)),
1695            (
1696                If {
1697                    test_expr,
1698                    then_expr,
1699                    else_expr,
1700                },
1701                If {
1702                    test_expr: test_expr1,
1703                    then_expr: then_expr1,
1704                    else_expr: else_expr1,
1705                },
1706            ) => test_expr
1707                .cmp_shape(test_expr1)
1708                .then_with(|| then_expr.cmp_shape(then_expr1))
1709                .then_with(|| else_expr.cmp_shape(else_expr1)),
1710            (
1711                And { left, right },
1712                And {
1713                    left: left1,
1714                    right: right1,
1715                },
1716            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1717            (
1718                Or { left, right },
1719                Or {
1720                    left: left1,
1721                    right: right1,
1722                },
1723            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1724            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1725                op.cmp(op1).then_with(|| arg.cmp_shape(arg1))
1726            }
1727            (
1728                BinaryApp { op, arg1, arg2 },
1729                BinaryApp {
1730                    op: op1,
1731                    arg1: arg11,
1732                    arg2: arg21,
1733                },
1734            ) => op
1735                .cmp(op1)
1736                .then_with(|| arg1.cmp_shape(arg11))
1737                .then_with(|| arg2.cmp_shape(arg21)),
1738            (
1739                ExtensionFunctionApp { fn_name, args },
1740                ExtensionFunctionApp {
1741                    fn_name: fn_name1,
1742                    args: args1,
1743                },
1744            ) => fn_name.cmp(fn_name1).then_with(|| {
1745                args.len().cmp(&args1.len()).then_with(|| {
1746                    for (a, a1) in args.iter().zip(args1.iter()) {
1747                        match a.cmp_shape(a1) {
1748                            std::cmp::Ordering::Equal => continue,
1749                            other => return other,
1750                        }
1751                    }
1752                    std::cmp::Ordering::Equal
1753                })
1754            }),
1755            (
1756                GetAttr { expr, attr },
1757                GetAttr {
1758                    expr: expr1,
1759                    attr: attr1,
1760                },
1761            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1762            (
1763                HasAttr { expr, attr },
1764                HasAttr {
1765                    expr: expr1,
1766                    attr: attr1,
1767                },
1768            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1769            (
1770                Like { expr, pattern },
1771                Like {
1772                    expr: expr1,
1773                    pattern: pattern1,
1774                },
1775            ) => pattern.cmp(pattern1).then_with(|| expr.cmp_shape(expr1)),
1776            (Set(elems), Set(elems1)) => elems.len().cmp(&elems1.len()).then_with(|| {
1777                for (e, e1) in elems.iter().zip(elems1.iter()) {
1778                    match e.cmp_shape(e1) {
1779                        std::cmp::Ordering::Equal => continue,
1780                        other => return other,
1781                    }
1782                }
1783                std::cmp::Ordering::Equal
1784            }),
1785            (Record(map), Record(map1)) => map.len().cmp(&map1.len()).then_with(|| {
1786                for ((a, e), (a1, e1)) in map.iter().zip(map1.iter()) {
1787                    match a.cmp(a1).then_with(|| e.cmp_shape(e1)) {
1788                        std::cmp::Ordering::Equal => continue,
1789                        other => return other,
1790                    }
1791                }
1792                std::cmp::Ordering::Equal
1793            }),
1794            (
1795                Is { expr, entity_type },
1796                Is {
1797                    expr: expr1,
1798                    entity_type: entity_type1,
1799                },
1800            ) => entity_type
1801                .cmp(entity_type1)
1802                .then_with(|| expr.cmp_shape(expr1)),
1803            #[cfg(feature = "tolerant-ast")]
1804            (
1805                Error { error_kind },
1806                Error {
1807                    error_kind: error_kind1,
1808                },
1809            ) => error_kind.cmp(error_kind1),
1810            #[expect(
1811                clippy::unreachable,
1812                reason = "This should never be reached since we compare variants first"
1813            )]
1814            _ => unreachable!(
1815                "Different variants should have been handled by variant_order comparison"
1816            ),
1817        }
1818    }
1819}
1820
/// AST variables
///
/// The four request variables that an expression can refer to.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so reordering the
/// variants changes how `Var`s compare. With `rename_all = "camelCase"`,
/// serde (de)serializes each variant under its lower-cased name (for example
/// `Var::Principal` as `"principal"`).
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// the Principal of the given request
    Principal,
    /// the Action of the given request
    Action,
    /// the Resource of the given request
    Resource,
    /// the Context of the given request
    Context,
}
1837
1838impl From<PrincipalOrResource> for Var {
1839    fn from(v: PrincipalOrResource) -> Self {
1840        match v {
1841            PrincipalOrResource::Principal => Var::Principal,
1842            PrincipalOrResource::Resource => Var::Resource,
1843        }
1844    }
1845}
1846
1847#[expect(
1848    clippy::fallible_impl_from,
1849    reason = "Tested by `test::all_vars_are_ids`. Never panics"
1850)]
1851impl From<Var> for Id {
1852    fn from(var: Var) -> Self {
1853        #[expect(
1854            clippy::unwrap_used,
1855            reason = "`Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`"
1856        )]
1857        format!("{var}").parse().unwrap()
1858    }
1859}
1860
1861#[expect(
1862    clippy::fallible_impl_from,
1863    reason = "Tested by `test::all_vars_are_ids`. Never panics"
1864)]
1865impl From<Var> for UnreservedId {
1866    fn from(var: Var) -> Self {
1867        #[expect(
1868            clippy::unwrap_used,
1869            reason = "`Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`"
1870        )]
1871        Id::from(var).try_into().unwrap()
1872    }
1873}
1874
1875impl std::fmt::Display for Var {
1876    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1877        match self {
1878            Self::Principal => write!(f, "principal"),
1879            Self::Action => write!(f, "action"),
1880            Self::Resource => write!(f, "resource"),
1881            Self::Context => write!(f, "context"),
1882        }
1883    }
1884}
1885
1886#[cfg(test)]
1887mod test {
1888    use cool_asserts::assert_matches;
1889    use itertools::Itertools;
1890    use smol_str::ToSmolStr;
1891    use std::collections::{hash_map::DefaultHasher, HashSet};
1892
1893    use crate::expr_builder::ExprBuilder as _;
1894
1895    use super::*;
1896
1897    pub fn all_vars() -> impl Iterator<Item = Var> {
1898        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
1899    }
1900
1901    // Tests that Var::Into never panics
1902    #[test]
1903    fn all_vars_are_ids() {
1904        for var in all_vars() {
1905            let _id: Id = var.into();
1906            let _id: UnreservedId = var.into();
1907        }
1908    }
1909
1910    #[test]
1911    fn exprs() {
1912        assert_eq!(
1913            Expr::val(33),
1914            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1915        );
1916        assert_eq!(
1917            Expr::val("hello"),
1918            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1919        );
1920        assert_eq!(
1921            Expr::val(EntityUID::with_eid("foo")),
1922            Expr::new(
1923                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1924                None,
1925                ()
1926            )
1927        );
1928        assert_eq!(
1929            Expr::var(Var::Principal),
1930            Expr::new(ExprKind::Var(Var::Principal), None, ())
1931        );
1932        assert_eq!(
1933            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1934            Expr::new(
1935                ExprKind::If {
1936                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1937                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1938                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1939                },
1940                None,
1941                ()
1942            )
1943        );
1944        assert_eq!(
1945            Expr::not(Expr::val(false)),
1946            Expr::new(
1947                ExprKind::UnaryApp {
1948                    op: UnaryOp::Not,
1949                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1950                },
1951                None,
1952                ()
1953            )
1954        );
1955        assert_eq!(
1956            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1957            Expr::new(
1958                ExprKind::GetAttr {
1959                    expr: Arc::new(Expr::new(
1960                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1961                        None,
1962                        ()
1963                    )),
1964                    attr: "some_attr".into()
1965                },
1966                None,
1967                ()
1968            )
1969        );
1970        assert_eq!(
1971            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1972            Expr::new(
1973                ExprKind::HasAttr {
1974                    expr: Arc::new(Expr::new(
1975                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1976                        None,
1977                        ()
1978                    )),
1979                    attr: "some_attr".into()
1980                },
1981                None,
1982                ()
1983            )
1984        );
1985        assert_eq!(
1986            Expr::is_entity_type(
1987                Expr::val(EntityUID::with_eid("foo")),
1988                "Type".parse().unwrap()
1989            ),
1990            Expr::new(
1991                ExprKind::Is {
1992                    expr: Arc::new(Expr::new(
1993                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1994                        None,
1995                        ()
1996                    )),
1997                    entity_type: "Type".parse().unwrap()
1998                },
1999                None,
2000                ()
2001            ),
2002        );
2003    }
2004
2005    #[test]
2006    fn like_display() {
2007        // `\0` escaped form is `\0`.
2008        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
2009        assert_eq!(format!("{e}"), r#""a" like "\0""#);
2010        // `\`'s escaped form is `\\`
2011        let e = Expr::like(
2012            Expr::val("a"),
2013            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
2014        );
2015        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
2016        // `\`'s escaped form is `\\`
2017        let e = Expr::like(
2018            Expr::val("a"),
2019            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
2020        );
2021        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
2022        // literal star's escaped from is `\*`
2023        let e = Expr::like(
2024            Expr::val("a"),
2025            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
2026        );
2027        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
2028    }
2029
2030    #[test]
2031    fn has_display() {
2032        // `\0` escaped form is `\0`.
2033        let e = Expr::has_attr(Expr::val("a"), "\0".into());
2034        assert_eq!(format!("{e}"), r#""a" has "\0""#);
2035        // `\`'s escaped form is `\\`
2036        let e = Expr::has_attr(Expr::val("a"), r"\".into());
2037        assert_eq!(format!("{e}"), r#""a" has "\\""#);
2038    }
2039
2040    #[test]
2041    fn slot_display() {
2042        let e = Expr::slot(SlotId::principal());
2043        assert_eq!(format!("{e}"), "?principal");
2044        let e = Expr::slot(SlotId::resource());
2045        assert_eq!(format!("{e}"), "?resource");
2046        let e = Expr::val(EntityUID::with_eid("eid"));
2047        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
2048    }
2049
2050    #[test]
2051    fn simple_slots() {
2052        let e = Expr::slot(SlotId::principal());
2053        let p = SlotId::principal();
2054        let r = SlotId::resource();
2055        let set: HashSet<SlotId> = HashSet::from_iter([p]);
2056        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
2057        let e = Expr::or(
2058            Expr::slot(SlotId::principal()),
2059            Expr::ite(
2060                Expr::val(true),
2061                Expr::slot(SlotId::resource()),
2062                Expr::val(false),
2063            ),
2064        );
2065        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
2066        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
2067    }
2068
2069    #[test]
2070    fn unknowns() {
2071        let e = Expr::ite(
2072            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
2073            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
2074            Expr::unknown(Unknown::new_untyped("c")),
2075        );
2076        let unknowns = e.unknowns().collect_vec();
2077        assert_eq!(unknowns.len(), 3);
2078        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
2079        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
2080        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
2081    }
2082
2083    #[test]
2084    fn is_unknown() {
2085        let e = Expr::ite(
2086            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
2087            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
2088            Expr::unknown(Unknown::new_untyped("c")),
2089        );
2090        assert!(e.contains_unknown());
2091        let e = Expr::ite(
2092            Expr::not(Expr::val(true)),
2093            Expr::and(Expr::val(1), Expr::val(3)),
2094            Expr::val(1),
2095        );
2096        assert!(!e.contains_unknown());
2097    }
2098
2099    #[test]
2100    fn expr_with_data() {
2101        let e = ExprBuilder::with_data("data").val(1);
2102        assert_eq!(e.into_data(), "data");
2103    }
2104
2105    #[test]
2106    fn expr_shape_only_eq() {
2107        let temp = ExprBuilder::with_data(1).val(1);
2108        let exprs = &[
2109            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
2110            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
2111            (
2112                ExprBuilder::with_data(1).var(Var::Principal),
2113                Expr::var(Var::Principal),
2114            ),
2115            (
2116                ExprBuilder::with_data(1).slot(SlotId::principal()),
2117                Expr::slot(SlotId::principal()),
2118            ),
2119            (
2120                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
2121                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
2122            ),
2123            (
2124                ExprBuilder::with_data(1).not(temp.clone()),
2125                Expr::not(Expr::val(1)),
2126            ),
2127            (
2128                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
2129                Expr::is_eq(Expr::val(1), Expr::val(1)),
2130            ),
2131            (
2132                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
2133                Expr::and(Expr::val(1), Expr::val(1)),
2134            ),
2135            (
2136                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
2137                Expr::or(Expr::val(1), Expr::val(1)),
2138            ),
2139            (
2140                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
2141                Expr::less(Expr::val(1), Expr::val(1)),
2142            ),
2143            (
2144                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
2145                Expr::lesseq(Expr::val(1), Expr::val(1)),
2146            ),
2147            (
2148                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
2149                Expr::greater(Expr::val(1), Expr::val(1)),
2150            ),
2151            (
2152                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
2153                Expr::greatereq(Expr::val(1), Expr::val(1)),
2154            ),
2155            (
2156                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
2157                Expr::add(Expr::val(1), Expr::val(1)),
2158            ),
2159            (
2160                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
2161                Expr::sub(Expr::val(1), Expr::val(1)),
2162            ),
2163            (
2164                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
2165                Expr::mul(Expr::val(1), Expr::val(1)),
2166            ),
2167            (
2168                ExprBuilder::with_data(1).neg(temp.clone()),
2169                Expr::neg(Expr::val(1)),
2170            ),
2171            (
2172                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
2173                Expr::is_in(Expr::val(1), Expr::val(1)),
2174            ),
2175            (
2176                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
2177                Expr::contains(Expr::val(1), Expr::val(1)),
2178            ),
2179            (
2180                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
2181                Expr::contains_all(Expr::val(1), Expr::val(1)),
2182            ),
2183            (
2184                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
2185                Expr::contains_any(Expr::val(1), Expr::val(1)),
2186            ),
2187            (
2188                ExprBuilder::with_data(1).is_empty(temp.clone()),
2189                Expr::is_empty(Expr::val(1)),
2190            ),
2191            (
2192                ExprBuilder::with_data(1).set([temp.clone()]),
2193                Expr::set([Expr::val(1)]),
2194            ),
2195            (
2196                ExprBuilder::with_data(1)
2197                    .record([("foo".into(), temp.clone())])
2198                    .unwrap(),
2199                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
2200            ),
2201            (
2202                ExprBuilder::with_data(1)
2203                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
2204                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
2205            ),
2206            (
2207                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
2208                Expr::get_attr(Expr::val(1), "foo".into()),
2209            ),
2210            (
2211                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
2212                Expr::has_attr(Expr::val(1), "foo".into()),
2213            ),
2214            (
2215                ExprBuilder::with_data(1)
2216                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
2217                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
2218            ),
2219            (
2220                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
2221                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
2222            ),
2223        ];
2224
2225        for (e0, e1) in exprs {
2226            assert!(e0.eq_shape(e0));
2227            assert!(e1.eq_shape(e1));
2228            assert!(e0.eq_shape(e1));
2229            assert!(e1.eq_shape(e0));
2230
2231            let mut hasher0 = DefaultHasher::new();
2232            e0.hash_shape(&mut hasher0);
2233            let hash0 = hasher0.finish();
2234
2235            let mut hasher1 = DefaultHasher::new();
2236            e1.hash_shape(&mut hasher1);
2237            let hash1 = hasher1.finish();
2238
2239            assert_eq!(hash0, hash1);
2240        }
2241    }
2242
2243    #[test]
2244    fn expr_shape_only_not_eq() {
2245        let expr1 = ExprBuilder::with_data(1).val(1);
2246        let expr2 = ExprBuilder::with_data(1).val(2);
2247        assert_ne!(
2248            ExprShapeOnly::new_from_borrowed(&expr1),
2249            ExprShapeOnly::new_from_borrowed(&expr2)
2250        );
2251    }
2252
2253    #[test]
2254    fn expr_shape_only_set_prefix_ne() {
2255        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
2256        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
2257        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));
2258
2259        assert_ne!(e1, e2);
2260        assert_ne!(e1, e3);
2261        assert_ne!(e2, e1);
2262        assert_ne!(e2, e3);
2263        assert_ne!(e3, e1);
2264        assert_ne!(e2, e1);
2265    }
2266
2267    #[test]
2268    fn expr_shape_only_ext_fn_arg_prefix_ne() {
2269        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2270            "decimal".parse().unwrap(),
2271            vec![],
2272        ));
2273        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2274            "decimal".parse().unwrap(),
2275            vec![Expr::val("0.0")],
2276        ));
2277        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2278            "decimal".parse().unwrap(),
2279            vec![Expr::val("0.0"), Expr::val("0.0")],
2280        ));
2281
2282        assert_ne!(e1, e2);
2283        assert_ne!(e1, e3);
2284        assert_ne!(e2, e1);
2285        assert_ne!(e2, e3);
2286        assert_ne!(e3, e1);
2287        assert_ne!(e2, e1);
2288    }
2289
2290    #[test]
2291    fn expr_shape_only_record_attr_prefix_ne() {
2292        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
2293        let e2 = ExprShapeOnly::new_from_owned(
2294            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
2295        );
2296        let e3 = ExprShapeOnly::new_from_owned(
2297            Expr::record([
2298                ("a".to_smolstr(), Expr::val(1)),
2299                ("b".to_smolstr(), Expr::val(2)),
2300            ])
2301            .unwrap(),
2302        );
2303
2304        assert_ne!(e1, e2);
2305        assert_ne!(e1, e3);
2306        assert_ne!(e2, e1);
2307        assert_ne!(e2, e3);
2308        assert_ne!(e3, e1);
2309        assert_ne!(e2, e1);
2310    }
2311
2312    #[test]
2313    fn untyped_subst_present() {
2314        let u = Unknown {
2315            name: "foo".into(),
2316            type_annotation: None,
2317        };
2318        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2319        match r {
2320            Ok(e) => assert_eq!(e, Expr::val(1)),
2321            Err(empty) => match empty {},
2322        }
2323    }
2324
2325    #[test]
2326    fn untyped_subst_present_correct_type() {
2327        let u = Unknown {
2328            name: "foo".into(),
2329            type_annotation: Some(Type::Long),
2330        };
2331        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2332        match r {
2333            Ok(e) => assert_eq!(e, Expr::val(1)),
2334            Err(empty) => match empty {},
2335        }
2336    }
2337
2338    #[test]
2339    fn untyped_subst_present_wrong_type() {
2340        let u = Unknown {
2341            name: "foo".into(),
2342            type_annotation: Some(Type::Bool),
2343        };
2344        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2345        match r {
2346            Ok(e) => assert_eq!(e, Expr::val(1)),
2347            Err(empty) => match empty {},
2348        }
2349    }
2350
2351    #[test]
2352    fn untyped_subst_not_present() {
2353        let u = Unknown {
2354            name: "foo".into(),
2355            type_annotation: Some(Type::Bool),
2356        };
2357        let r = UntypedSubstitution::substitute(&u, None);
2358        match r {
2359            Ok(n) => assert_eq!(n, Expr::unknown(u)),
2360            Err(empty) => match empty {},
2361        }
2362    }
2363
2364    #[test]
2365    fn typed_subst_present() {
2366        let u = Unknown {
2367            name: "foo".into(),
2368            type_annotation: None,
2369        };
2370        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2371        assert_eq!(e, Expr::val(1));
2372    }
2373
2374    #[test]
2375    fn typed_subst_present_correct_type() {
2376        let u = Unknown {
2377            name: "foo".into(),
2378            type_annotation: Some(Type::Long),
2379        };
2380        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2381        assert_eq!(e, Expr::val(1));
2382    }
2383
2384    #[test]
2385    fn typed_subst_present_wrong_type() {
2386        let u = Unknown {
2387            name: "foo".into(),
2388            type_annotation: Some(Type::Bool),
2389        };
2390        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
2391        assert_matches!(
2392            r,
2393            SubstitutionError::TypeError {
2394                expected: Type::Bool,
2395                actual: Type::Long,
2396            }
2397        );
2398    }
2399
2400    #[test]
2401    fn typed_subst_not_present() {
2402        let u = Unknown {
2403            name: "foo".into(),
2404            type_annotation: None,
2405        };
2406        let r = TypedSubstitution::substitute(&u, None).unwrap();
2407        assert_eq!(r, Expr::unknown(u));
2408    }
2409}