Skip to main content

cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#[cfg(feature = "tolerant-ast")]
18use {
19    super::expr_allows_errors::AstExprErrorKind,
20    crate::parser::err::{ToASTError, ToASTErrorKind},
21};
22
23use crate::{
24    ast::*,
25    expr_builder::{self, ExprBuilder as _},
26    extensions::Extensions,
27    parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34    borrow::Cow,
35    collections::{btree_map, BTreeMap, HashMap},
36    hash::{Hash, Hasher},
37    mem,
38    sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
45/// Internal AST for expressions used by the policy evaluator.
46/// This structure is a wrapper around an `ExprKind`, which is the expression
47/// variant this object contains. It also contains source information about
48/// where the expression was written in policy source code, and some generic
49/// data which is stored on each node of the AST.
50/// Cloning is O(1).
51#[derive(Educe, Debug, Clone)]
52#[educe(PartialEq, Eq, Hash)]
53pub struct Expr<T = ()> {
54    expr_kind: ExprKind<T>,
55    #[educe(PartialEq(ignore))]
56    #[educe(Hash(ignore))]
57    source_loc: Option<Loc>,
58    data: T,
59}
60
61/// The possible expression variants. This enum should be matched on by code
62/// recursively traversing the AST.
63#[derive(Hash, Debug, Clone, PartialEq, Eq)]
64pub enum ExprKind<T = ()> {
65    /// Literal value
66    Lit(Literal),
67    /// Variable
68    Var(Var),
69    /// Template Slots
70    Slot(SlotId),
71    /// Symbolic Unknown for partial-eval
72    Unknown(Unknown),
73    /// Ternary expression
74    If {
75        /// Condition for the ternary expression. Must evaluate to Bool type
76        test_expr: Arc<Expr<T>>,
77        /// Value if true
78        then_expr: Arc<Expr<T>>,
79        /// Value if false
80        else_expr: Arc<Expr<T>>,
81    },
82    /// Boolean AND
83    And {
84        /// Left operand, which will be eagerly evaluated
85        left: Arc<Expr<T>>,
86        /// Right operand, which may not be evaluated due to short-circuiting
87        right: Arc<Expr<T>>,
88    },
89    /// Boolean OR
90    Or {
91        /// Left operand, which will be eagerly evaluated
92        left: Arc<Expr<T>>,
93        /// Right operand, which may not be evaluated due to short-circuiting
94        right: Arc<Expr<T>>,
95    },
96    /// Application of a built-in unary operator (single parameter)
97    UnaryApp {
98        /// Unary operator to apply
99        op: UnaryOp,
100        /// Argument to apply operator to
101        arg: Arc<Expr<T>>,
102    },
103    /// Application of a built-in binary operator (two parameters)
104    BinaryApp {
105        /// Binary operator to apply
106        op: BinaryOp,
107        /// First arg
108        arg1: Arc<Expr<T>>,
109        /// Second arg
110        arg2: Arc<Expr<T>>,
111    },
112    /// Application of an extension function to n arguments
113    /// INVARIANT (MethodStyleArgs):
114    ///   if op.style is MethodStyle then args _cannot_ be empty.
115    ///     The first element of args refers to the subject of the method call
116    /// Ideally, we find some way to make this non-representable.
117    ExtensionFunctionApp {
118        /// Extension function to apply
119        fn_name: Name,
120        /// Args to apply the function to
121        args: Arc<Vec<Expr<T>>>,
122    },
123    /// Get an attribute of an entity, or a field of a record
124    GetAttr {
125        /// Expression to get an attribute/field of. Must evaluate to either
126        /// Entity or Record type
127        expr: Arc<Expr<T>>,
128        /// Attribute or field to get
129        attr: SmolStr,
130    },
131    /// Does the given `expr` have the given `attr`?
132    HasAttr {
133        /// Expression to test. Must evaluate to either Entity or Record type
134        expr: Arc<Expr<T>>,
135        /// Attribute or field to check for
136        attr: SmolStr,
137    },
138    /// Regex-like string matching similar to IAM's `StringLike` operator.
139    Like {
140        /// Expression to test. Must evaluate to String type
141        expr: Arc<Expr<T>>,
142        /// Pattern to match on; can include the wildcard *, which matches any string.
143        /// To match a literal `*` in the test expression, users can use `\*`.
144        /// Be careful the backslash in `\*` must not be another escape sequence. For instance, `\\*` matches a backslash plus an arbitrary string.
145        pattern: Pattern,
146    },
147    /// Entity type test. Does the first argument have the entity type
148    /// specified by the second argument.
149    Is {
150        /// Expression to test. Must evaluate to an Entity.
151        expr: Arc<Expr<T>>,
152        /// The [`EntityType`] used for the type membership test.
153        entity_type: EntityType,
154    },
155    /// Set (whose elements may be arbitrary expressions)
156    //
157    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
158    // that are syntactically unequal, may actually be semantically equal --
159    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
160    // evaluated into `Value`s
161    Set(Arc<Vec<Expr<T>>>),
162    /// Anonymous record (whose elements may be arbitrary expressions)
163    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
164    #[cfg(feature = "tolerant-ast")]
165    /// Error expression - allows us to continue parsing even when we have errors
166    Error {
167        /// Type of error that led to the failure
168        error_kind: AstExprErrorKind,
169    },
170}
171
172impl<T> ExprKind<T> {
173    /// Get the variant order (same as derive(Ord) for enums)
174    fn variant_order(&self) -> u8 {
175        match self {
176            ExprKind::Lit(_) => 0,
177            ExprKind::Var(_) => 1,
178            ExprKind::Slot(_) => 2,
179            ExprKind::Unknown(_) => 3,
180            ExprKind::If { .. } => 4,
181            ExprKind::And { .. } => 5,
182            ExprKind::Or { .. } => 6,
183            ExprKind::UnaryApp { .. } => 7,
184            ExprKind::BinaryApp { .. } => 8,
185            ExprKind::ExtensionFunctionApp { .. } => 9,
186            ExprKind::GetAttr { .. } => 10,
187            ExprKind::HasAttr { .. } => 11,
188            ExprKind::Like { .. } => 12,
189            ExprKind::Set(_) => 13,
190            ExprKind::Record(_) => 14,
191            ExprKind::Is { .. } => 15,
192            #[cfg(feature = "tolerant-ast")]
193            ExprKind::Error { .. } => 16,
194        }
195    }
196}
197
198impl From<Value> for Expr {
199    fn from(v: Value) -> Self {
200        Expr::from(v.value).with_maybe_source_loc(v.loc)
201    }
202}
203
204impl From<ValueKind> for Expr {
205    fn from(v: ValueKind) -> Self {
206        match v {
207            ValueKind::Lit(lit) => Expr::val(lit),
208            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
209            #[expect(
210                clippy::expect_used,
211                reason = "cannot have duplicate key because the input was already a BTreeMap"
212            )]
213            ValueKind::Record(record) => Expr::record(
214                Arc::unwrap_or_clone(record)
215                    .into_iter()
216                    .map(|(k, v)| (k, Expr::from(v))),
217            )
218            .expect("cannot have duplicate key because the input was already a BTreeMap"),
219            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
220        }
221    }
222}
223
224impl From<PartialValue> for Expr {
225    fn from(pv: PartialValue) -> Self {
226        match pv {
227            PartialValue::Value(v) => Expr::from(v),
228            PartialValue::Residual(expr) => expr,
229        }
230    }
231}
232
233impl<T> Expr<T> {
234    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
235        Self {
236            expr_kind,
237            source_loc,
238            data,
239        }
240    }
241
242    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
243    /// `enum` which specifies the expression variant, so it must be accessed by
244    /// any code matching and recursing on an expression.
245    pub fn expr_kind(&self) -> &ExprKind<T> {
246        &self.expr_kind
247    }
248
249    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
250    pub fn into_expr_kind(self) -> ExprKind<T> {
251        self.expr_kind
252    }
253
254    /// Access the data stored on the `Expr`.
255    pub fn data(&self) -> &T {
256        &self.data
257    }
258
259    /// Access the data stored on the `Expr`, taking ownership and consuming the
260    /// `Expr`.
261    pub fn into_data(self) -> T {
262        self.data
263    }
264
265    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
266    /// data.
267    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
268        (self.expr_kind, self.source_loc, self.data)
269    }
270
271    /// Access the `Loc` stored on the `Expr`.
272    pub fn source_loc(&self) -> Option<&Loc> {
273        self.source_loc.as_ref()
274    }
275
276    /// Return the `Expr`, but with the new `source_loc` (or `None`).
277    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
278        Self { source_loc, ..self }
279    }
280
281    /// Update the data for this `Expr`. A convenient function used by the
282    /// Validator in one place.
283    pub fn set_data(&mut self, data: T) {
284        self.data = data;
285    }
286
287    /// Check whether this expression is an entity reference
288    ///
289    /// This is used for policy scopes, where some syntax is
290    /// required to be an entity reference.
291    pub fn is_ref(&self) -> bool {
292        match &self.expr_kind {
293            ExprKind::Lit(lit) => lit.is_ref(),
294            _ => false,
295        }
296    }
297
298    /// Check whether this expression is a slot.
299    pub fn is_slot(&self) -> bool {
300        matches!(&self.expr_kind, ExprKind::Slot(_))
301    }
302
303    /// Check whether this expression is a set of entity references
304    ///
305    /// This is used for policy scopes, where some syntax is
306    /// required to be an entity reference set.
307    pub fn is_ref_set(&self) -> bool {
308        match &self.expr_kind {
309            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
310            _ => false,
311        }
312    }
313
314    /// Iterate over all sub-expressions in this expression
315    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
316        expr_iterator::ExprIterator::new(self)
317    }
318
319    /// Iterate over all of the slots in this policy AST
320    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
321        self.subexpressions()
322            .filter_map(|exp| match &exp.expr_kind {
323                ExprKind::Slot(slotid) => Some(Slot {
324                    id: *slotid,
325                    loc: exp.source_loc().cloned(),
326                }),
327                _ => None,
328            })
329    }
330
331    /// Determine if the expression is projectable under partial evaluation
332    /// An expression is projectable if it's guaranteed to never error on evaluation
333    /// This is true if the expression is entirely composed of values or unknowns
334    pub fn is_projectable(&self) -> bool {
335        self.subexpressions().all(|e| {
336            matches!(
337                e.expr_kind(),
338                ExprKind::Lit(_)
339                    | ExprKind::Unknown(_)
340                    | ExprKind::Set(_)
341                    | ExprKind::Var(_)
342                    | ExprKind::Record(_)
343            )
344        })
345    }
346
347    /// Try to compute the runtime type of this expression. This operation may
348    /// fail (returning `None`), for example, when asked to get the type of any
349    /// variables, any attributes of entities or records, or an `unknown`
350    /// without an explicitly annotated type.
351    ///
352    /// Also note that this is _not_ typechecking the expression. It does not
353    /// check that the expression actually evaluates to a value (as opposed to
354    /// erroring).
355    ///
356    /// Because of these limitations, this function should only be used to
357    /// obtain a type for use in diagnostics such as error strings.
358    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
359        match &self.expr_kind {
360            ExprKind::Lit(l) => Some(l.type_of()),
361            ExprKind::Var(_) => None,
362            ExprKind::Slot(_) => None,
363            ExprKind::Unknown(u) => u.type_annotation.clone(),
364            ExprKind::If {
365                then_expr,
366                else_expr,
367                ..
368            } => {
369                let type_of_then = then_expr.try_type_of(extensions);
370                let type_of_else = else_expr.try_type_of(extensions);
371                if type_of_then == type_of_else {
372                    type_of_then
373                } else {
374                    None
375                }
376            }
377            ExprKind::And { .. } => Some(Type::Bool),
378            ExprKind::Or { .. } => Some(Type::Bool),
379            ExprKind::UnaryApp {
380                op: UnaryOp::Neg, ..
381            } => Some(Type::Long),
382            ExprKind::UnaryApp {
383                op: UnaryOp::Not, ..
384            } => Some(Type::Bool),
385            ExprKind::UnaryApp {
386                op: UnaryOp::IsEmpty,
387                ..
388            } => Some(Type::Bool),
389            ExprKind::BinaryApp {
390                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
391                ..
392            } => Some(Type::Long),
393            ExprKind::BinaryApp {
394                op:
395                    BinaryOp::Contains
396                    | BinaryOp::ContainsAll
397                    | BinaryOp::ContainsAny
398                    | BinaryOp::Eq
399                    | BinaryOp::In
400                    | BinaryOp::Less
401                    | BinaryOp::LessEq,
402                ..
403            } => Some(Type::Bool),
404            ExprKind::BinaryApp {
405                op: BinaryOp::HasTag,
406                ..
407            } => Some(Type::Bool),
408            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
409                .func(fn_name)
410                .ok()?
411                .return_type()
412                .map(|rty| rty.clone().into()),
413            // We could try to be more complete here, but we can't do all that
414            // much better without evaluating the argument. Even if we know it's
415            // a record `Type::Record` tells us nothing about the type of the
416            // attribute.
417            ExprKind::GetAttr { .. } => None,
418            // similarly to `GetAttr`
419            ExprKind::BinaryApp {
420                op: BinaryOp::GetTag,
421                ..
422            } => None,
423            ExprKind::HasAttr { .. } => Some(Type::Bool),
424            ExprKind::Like { .. } => Some(Type::Bool),
425            ExprKind::Is { .. } => Some(Type::Bool),
426            ExprKind::Set(_) => Some(Type::Set),
427            ExprKind::Record(_) => Some(Type::Record),
428            #[cfg(feature = "tolerant-ast")]
429            ExprKind::Error { .. } => None,
430        }
431    }
432
433    /// Converts an `Expr<V>` to `B::Expr` using the provided builder.
434    ///
435    /// Preserves source location information and recursively transforms each expression node.
436    /// Note: Data may be cloned if the source expression is retained elsewhere.
437    /// Convert this expression to a `B::Expr`, where `B` is a fallible builder.
438    /// Uses `try_call_extension_fn` for extension function calls.
439    pub fn try_into_expr<B: expr_builder::ExprBuilder>(self) -> Result<B::Expr, B::BuildError>
440    where
441        T: Clone,
442    {
443        let builder = B::new().with_maybe_source_loc(self.source_loc());
444        match self.into_expr_kind() {
445            ExprKind::Lit(lit) => Ok(builder.val(lit)),
446            ExprKind::Var(var) => Ok(builder.var(var)),
447            ExprKind::Slot(slot) => Ok(builder.slot(slot)),
448            ExprKind::Unknown(u) => Ok(builder.unknown(u)),
449            ExprKind::If {
450                test_expr,
451                then_expr,
452                else_expr,
453            } => Ok(builder.ite(
454                Arc::unwrap_or_clone(test_expr).try_into_expr::<B>()?,
455                Arc::unwrap_or_clone(then_expr).try_into_expr::<B>()?,
456                Arc::unwrap_or_clone(else_expr).try_into_expr::<B>()?,
457            )),
458            ExprKind::And { left, right } => Ok(builder.and(
459                Arc::unwrap_or_clone(left).try_into_expr::<B>()?,
460                Arc::unwrap_or_clone(right).try_into_expr::<B>()?,
461            )),
462            ExprKind::Or { left, right } => Ok(builder.or(
463                Arc::unwrap_or_clone(left).try_into_expr::<B>()?,
464                Arc::unwrap_or_clone(right).try_into_expr::<B>()?,
465            )),
466            ExprKind::UnaryApp { op, arg } => {
467                Ok(builder.unary_app(op, Arc::unwrap_or_clone(arg).try_into_expr::<B>()?))
468            }
469            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(builder.binary_app(
470                op,
471                Arc::unwrap_or_clone(arg1).try_into_expr::<B>()?,
472                Arc::unwrap_or_clone(arg2).try_into_expr::<B>()?,
473            )),
474            ExprKind::ExtensionFunctionApp { fn_name, args } => {
475                let args: Vec<_> = Arc::unwrap_or_clone(args)
476                    .into_iter()
477                    .map(|e| e.try_into_expr::<B>())
478                    .collect::<Result<_, _>>()?;
479                builder.call_extension_fn(fn_name, args)
480            }
481            ExprKind::GetAttr { expr, attr } => {
482                Ok(builder.get_attr(Arc::unwrap_or_clone(expr).try_into_expr::<B>()?, attr))
483            }
484            ExprKind::HasAttr { expr, attr } => {
485                Ok(builder.has_attr(Arc::unwrap_or_clone(expr).try_into_expr::<B>()?, attr))
486            }
487            ExprKind::Like { expr, pattern } => {
488                Ok(builder.like(Arc::unwrap_or_clone(expr).try_into_expr::<B>()?, pattern))
489            }
490            ExprKind::Is { expr, entity_type } => Ok(builder.is_entity_type(
491                Arc::unwrap_or_clone(expr).try_into_expr::<B>()?,
492                entity_type,
493            )),
494            ExprKind::Set(set) => Ok(builder.set(
495                Arc::unwrap_or_clone(set)
496                    .into_iter()
497                    .map(|e| e.try_into_expr::<B>())
498                    .collect::<Result<Vec<_>, _>>()?,
499            )),
500            #[expect(
501                clippy::unwrap_used,
502                reason = "`map` is a map, so it will not have duplicate keys, so the `.record()` constructor cannot error"
503            )]
504            ExprKind::Record(map) => Ok(builder
505                .record(
506                    Arc::unwrap_or_clone(map)
507                        .into_iter()
508                        .map(|(k, v)| Ok((k, v.try_into_expr::<B>()?)))
509                        .collect::<Result<Vec<_>, _>>()?,
510                )
511                .unwrap()),
512            #[cfg(feature = "tolerant-ast")]
513            #[expect(
514                clippy::unwrap_used,
515                reason = "error type is Infallible so can never happen"
516            )]
517            ExprKind::Error { .. } => Ok(builder
518                .error(ParseErrors::singleton(ToASTError::new(
519                    ToASTErrorKind::ASTErrorNode,
520                    Some(Loc::new(0..1, "AST_ERROR_NODE".into())),
521                )))
522                .unwrap()), // we could have unwrap_infallible + trait bound  but attributes in where clauses are unstable
523        }
524    }
525
526    /// Convert this expression to a `B::Expr`, where `B` is an infallible builder.
527    pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
528    where
529        T: Clone,
530        B::BuildError: IsInfallible,
531    {
532        self.try_into_expr::<B>().unwrap_infallible()
533    }
534}
535
536#[expect(
537    clippy::should_implement_trait,
538    reason = "the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`"
539)]
540impl Expr {
541    /// Create an `Expr` that's just a single `Literal`.
542    ///
543    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
544    pub fn val(v: impl Into<Literal>) -> Self {
545        ExprBuilder::new().val(v)
546    }
547
548    /// Create an `Expr` that's just a single `Unknown`.
549    pub fn unknown(u: Unknown) -> Self {
550        ExprBuilder::new().unknown(u)
551    }
552
553    /// Create an `Expr` that's just this literal `Var`
554    pub fn var(v: Var) -> Self {
555        ExprBuilder::new().var(v)
556    }
557
558    /// Create an `Expr` that's just this `SlotId`
559    pub fn slot(s: SlotId) -> Self {
560        ExprBuilder::new().slot(s)
561    }
562
563    /// Create a ternary (if-then-else) `Expr`.
564    ///
565    /// `test_expr` must evaluate to a Bool type
566    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
567        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
568    }
569
570    /// Create a ternary (if-then-else) `Expr`.
571    /// Takes `Arc`s instead of owned `Expr`s.
572    /// `test_expr` must evaluate to a Bool type
573    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
574        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
575    }
576
577    /// Create a 'not' expression. `e` must evaluate to Bool type
578    pub fn not(e: Expr) -> Self {
579        ExprBuilder::new().not(e)
580    }
581
582    /// Create a '==' expression
583    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
584        ExprBuilder::new().is_eq(e1, e2)
585    }
586
587    /// Create a '!=' expression
588    pub fn noteq(e1: Expr, e2: Expr) -> Self {
589        ExprBuilder::new().noteq(e1, e2)
590    }
591
592    /// Create an 'and' expression. Arguments must evaluate to Bool type
593    pub fn and(e1: Expr, e2: Expr) -> Self {
594        ExprBuilder::new().and(e1, e2)
595    }
596
597    /// Create an 'or' expression. Arguments must evaluate to Bool type
598    pub fn or(e1: Expr, e2: Expr) -> Self {
599        ExprBuilder::new().or(e1, e2)
600    }
601
602    /// Create a '<' expression. Arguments must evaluate to Long type
603    pub fn less(e1: Expr, e2: Expr) -> Self {
604        ExprBuilder::new().less(e1, e2)
605    }
606
607    /// Create a '<=' expression. Arguments must evaluate to Long type
608    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
609        ExprBuilder::new().lesseq(e1, e2)
610    }
611
612    /// Create a '>' expression. Arguments must evaluate to Long type
613    pub fn greater(e1: Expr, e2: Expr) -> Self {
614        ExprBuilder::new().greater(e1, e2)
615    }
616
617    /// Create a '>=' expression. Arguments must evaluate to Long type
618    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
619        ExprBuilder::new().greatereq(e1, e2)
620    }
621
622    /// Create an 'add' expression. Arguments must evaluate to Long type
623    pub fn add(e1: Expr, e2: Expr) -> Self {
624        ExprBuilder::new().add(e1, e2)
625    }
626
627    /// Create a 'sub' expression. Arguments must evaluate to Long type
628    pub fn sub(e1: Expr, e2: Expr) -> Self {
629        ExprBuilder::new().sub(e1, e2)
630    }
631
632    /// Create a 'mul' expression. Arguments must evaluate to Long type
633    pub fn mul(e1: Expr, e2: Expr) -> Self {
634        ExprBuilder::new().mul(e1, e2)
635    }
636
637    /// Create a 'neg' expression. `e` must evaluate to Long type.
638    pub fn neg(e: Expr) -> Self {
639        ExprBuilder::new().neg(e)
640    }
641
642    /// Create an 'in' expression. First argument must evaluate to Entity type.
643    /// Second argument must evaluate to either Entity type or Set type where
644    /// all set elements have Entity type.
645    pub fn is_in(e1: Expr, e2: Expr) -> Self {
646        ExprBuilder::new().is_in(e1, e2)
647    }
648
649    /// Create a `contains` expression.
650    /// First argument must have Set type.
651    pub fn contains(e1: Expr, e2: Expr) -> Self {
652        ExprBuilder::new().contains(e1, e2)
653    }
654
655    /// Create a `containsAll` expression. Arguments must evaluate to Set type
656    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
657        ExprBuilder::new().contains_all(e1, e2)
658    }
659
660    /// Create a `containsAny` expression. Arguments must evaluate to Set type
661    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
662        ExprBuilder::new().contains_any(e1, e2)
663    }
664
665    /// Create a `isEmpty` expression. Argument must evaluate to Set type
666    pub fn is_empty(e: Expr) -> Self {
667        ExprBuilder::new().is_empty(e)
668    }
669
670    /// Create a `getTag` expression.
671    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
672    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
673        ExprBuilder::new().get_tag(expr, tag)
674    }
675
676    /// Create a `hasTag` expression.
677    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
678    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
679        ExprBuilder::new().has_tag(expr, tag)
680    }
681
682    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
683    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
684        ExprBuilder::new().set(exprs)
685    }
686
687    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
688    pub fn record(
689        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
690    ) -> Result<Self, ExpressionConstructionError> {
691        ExprBuilder::new().record(pairs)
692    }
693
694    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
695    ///
696    /// If you have an iterator of pairs, generally prefer calling
697    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
698    /// potentially for efficiency reasons but also because `Expr::record()`
699    /// will properly handle duplicate keys but your own `.collect()` will not
700    /// (by default).
701    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
702        ExprBuilder::new().record_arc(map)
703    }
704
705    /// Create an `Expr` which calls the extension function with the given
706    /// `Name` on `args`
707    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
708        ExprBuilder::new()
709            .call_extension_fn(fn_name, args)
710            .unwrap_infallible()
711    }
712
713    /// Create an application `Expr` which applies the given built-in unary
714    /// operator to the given `arg`
715    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
716        ExprBuilder::new().unary_app(op, arg)
717    }
718
719    /// Create an application `Expr` which applies the given built-in binary
720    /// operator to `arg1` and `arg2`
721    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
722        ExprBuilder::new().binary_app(op, arg1, arg2)
723    }
724
725    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
726    ///
727    /// `expr` must evaluate to either Entity or Record type
728    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
729        ExprBuilder::new().get_attr(expr, attr)
730    }
731
732    /// Create an `Expr` which tests for the existence of a given
733    /// attribute on a given `Entity` or record.
734    ///
735    /// `expr` must evaluate to either Entity or Record type
736    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
737        ExprBuilder::new().has_attr(expr, attr)
738    }
739
740    /// Create a 'like' expression.
741    ///
742    /// `expr` must evaluate to a String type
743    pub fn like(expr: Expr, pattern: Pattern) -> Self {
744        ExprBuilder::new().like(expr, pattern)
745    }
746
747    /// Create an `is` expression.
748    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
749        ExprBuilder::new().is_entity_type(expr, entity_type)
750    }
751
752    /// Check if an expression contains any symbolic unknowns
753    pub fn contains_unknown(&self) -> bool {
754        self.subexpressions()
755            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
756    }
757
758    /// Get all unknowns in an expression
759    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
760        self.subexpressions()
761            .filter_map(|subexpr| match subexpr.expr_kind() {
762                ExprKind::Unknown(u) => Some(u),
763                _ => None,
764            })
765    }
766
767    /// Substitute unknowns with concrete values.
768    ///
769    /// Ignores unmapped unknowns.
770    /// Ignores type annotations on unknowns.
771    /// Note that there might be "undiscovered unknowns" in the Expr, which
772    /// this function does not notice if evaluation of this Expr did not
773    /// traverse all entities and attributes during evaluation, leading to
774    /// this function only substituting one unknown at a time.
775    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
776        match self.substitute_general::<UntypedSubstitution>(definitions) {
777            Ok(e) => e,
778            Err(empty) => match empty {},
779        }
780    }
781
782    /// Substitute unknowns with concrete values.
783    ///
784    /// Ignores unmapped unknowns.
785    /// Errors if the substituted value does not match the type annotation on the unknown.
786    /// Note that there might be "undiscovered unknowns" in the Expr, which
787    /// this function does not notice if evaluation of this Expr did not
788    /// traverse all entities and attributes during evaluation, leading to
789    /// this function only substituting one unknown at a time.
790    pub fn substitute_typed(
791        &self,
792        definitions: &HashMap<SmolStr, Value>,
793    ) -> Result<Expr, SubstitutionError> {
794        self.substitute_general::<TypedSubstitution>(definitions)
795    }
796
797    /// Substitute unknowns with values
798    ///
799    /// Generic over the function implementing the substitution to allow for multiple error behaviors
800    fn substitute_general<T: SubstitutionFunction>(
801        &self,
802        definitions: &HashMap<SmolStr, Value>,
803    ) -> Result<Expr, T::Err> {
804        match self.expr_kind() {
805            ExprKind::Lit(_) => Ok(self.clone()),
806            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
807            ExprKind::Var(_) => Ok(self.clone()),
808            ExprKind::Slot(_) => Ok(self.clone()),
809            ExprKind::If {
810                test_expr,
811                then_expr,
812                else_expr,
813            } => Ok(Expr::ite(
814                test_expr.substitute_general::<T>(definitions)?,
815                then_expr.substitute_general::<T>(definitions)?,
816                else_expr.substitute_general::<T>(definitions)?,
817            )),
818            ExprKind::And { left, right } => Ok(Expr::and(
819                left.substitute_general::<T>(definitions)?,
820                right.substitute_general::<T>(definitions)?,
821            )),
822            ExprKind::Or { left, right } => Ok(Expr::or(
823                left.substitute_general::<T>(definitions)?,
824                right.substitute_general::<T>(definitions)?,
825            )),
826            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
827                *op,
828                arg.substitute_general::<T>(definitions)?,
829            )),
830            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
831                *op,
832                arg1.substitute_general::<T>(definitions)?,
833                arg2.substitute_general::<T>(definitions)?,
834            )),
835            ExprKind::ExtensionFunctionApp { fn_name, args } => {
836                let args = args
837                    .iter()
838                    .map(|e| e.substitute_general::<T>(definitions))
839                    .collect::<Result<Vec<Expr>, _>>()?;
840
841                Ok(Expr::call_extension_fn(fn_name.clone(), args))
842            }
843            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
844                expr.substitute_general::<T>(definitions)?,
845                attr.clone(),
846            )),
847            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
848                expr.substitute_general::<T>(definitions)?,
849                attr.clone(),
850            )),
851            ExprKind::Like { expr, pattern } => Ok(Expr::like(
852                expr.substitute_general::<T>(definitions)?,
853                pattern.clone(),
854            )),
855            ExprKind::Set(members) => {
856                let members = members
857                    .iter()
858                    .map(|e| e.substitute_general::<T>(definitions))
859                    .collect::<Result<Vec<_>, _>>()?;
860                Ok(Expr::set(members))
861            }
862            ExprKind::Record(map) => {
863                let map = map
864                    .iter()
865                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
866                    .collect::<Result<BTreeMap<_, _>, _>>()?;
867                #[expect(
868                    clippy::expect_used,
869                    reason = "cannot have a duplicate key because the input was already a BTreeMap"
870                )]
871                Ok(Expr::record(map)
872                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
873            }
874            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
875                expr.substitute_general::<T>(definitions)?,
876                entity_type.clone(),
877            )),
878            #[cfg(feature = "tolerant-ast")]
879            ExprKind::Error { .. } => Ok(self.clone()),
880        }
881    }
882}
883
/// A trait for customizing the error behavior of substitution.
///
/// Each implementor is a strategy describing how a single [`Unknown`] is
/// turned into an expression, and which errors (if any) that can produce.
trait SubstitutionFunction {
    /// The potential errors this substitution function can return
    type Err;
    /// The function for implementing the substitution.
    ///
    /// * `value` is the unknown being substituted; it carries the unknown's
    ///   name and its type annotation (if present).
    /// * `substitute` is the replacement value looked up for that unknown, or
    ///   `None` if the caller has no mapping for it.
    ///
    /// Returns the expression to use in place of the unknown, or `Self::Err`
    /// if this strategy rejects the replacement.
    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
}
895
896struct TypedSubstitution {}
897
898impl SubstitutionFunction for TypedSubstitution {
899    type Err = SubstitutionError;
900
901    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
902        match (substitute, &value.type_annotation) {
903            (None, _) => Ok(Expr::unknown(value.clone())),
904            (Some(v), None) => Ok(v.clone().into()),
905            (Some(v), Some(t)) => {
906                if v.type_of() == *t {
907                    Ok(v.clone().into())
908                } else {
909                    Err(SubstitutionError::TypeError {
910                        expected: t.clone(),
911                        actual: v.type_of(),
912                    })
913                }
914            }
915        }
916    }
917}
918
919struct UntypedSubstitution {}
920
921impl SubstitutionFunction for UntypedSubstitution {
922    type Err = std::convert::Infallible;
923
924    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
925        Ok(substitute
926            .map(|v| v.clone().into())
927            .unwrap_or_else(|| Expr::unknown(value.clone())))
928    }
929}
930
931impl<T: Clone> std::fmt::Display for Expr<T> {
932    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
933        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
934        // we just convert to EST and use the EST pretty-printer.
935        // Note that converting AST->EST is lossless and infallible.
936        write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
937    }
938}
939
940impl<T: Clone> BoundedDisplay for Expr<T> {
941    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
942        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
943        // pretty-printer. Note that converting AST->EST is lossless and infallible.
944        BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
945    }
946}
947
948impl std::str::FromStr for Expr {
949    type Err = ParseErrors;
950
951    fn from_str(s: &str) -> Result<Expr, Self::Err> {
952        crate::parser::parse_expr(s)
953    }
954}
955
/// Enum for errors encountered during substitution
/// (returned by `Expr::substitute_typed`).
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
968
/// Representation of a partial-evaluation Unknown at the AST level
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// The name of the unknown. Substitution looks up replacement values
    /// by this name.
    pub name: SmolStr,
    /// The type of the values that can be substituted in for the unknown.
    /// If `None`, we have no type annotation, and thus a value of any type can
    /// be substituted.
    pub type_annotation: Option<Type>,
}
979
980impl Unknown {
981    /// Create a new untyped `Unknown`
982    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
983        Self {
984            name: name.into(),
985            type_annotation: None,
986        }
987    }
988
989    /// Create a new `Unknown` with type annotation. (Only values of the given
990    /// type can be substituted.)
991    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
992        Self {
993            name: name.into(),
994            type_annotation: Some(ty),
995        }
996    }
997}
998
999impl std::fmt::Display for Unknown {
1000    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1001        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
1002        // to avoid code duplication
1003        write!(
1004            f,
1005            "{}",
1006            Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
1007        )
1008    }
1009}
1010
/// Builder for constructing `Expr` objects annotated with some `data`
/// (possibly taking default value) and optionally a `source_loc`.
#[derive(Clone, Debug)]
pub struct ExprBuilder<T> {
    /// Source location to attach to the expression being built, if any
    source_loc: Option<Loc>,
    /// Data to attach to the expression being built
    data: T,
}

// Empty impl: opts `ExprBuilder<T>` into `ExprBuilderInfallibleBuild` without
// providing any items of its own.
impl<T: Default + Clone> expr_builder::ExprBuilderInfallibleBuild for ExprBuilder<T> {}
1020
1021impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
1022    type Expr = Expr<T>;
1023
1024    type Data = T;
1025
1026    type BuildError = Infallible;
1027
1028    #[cfg(feature = "tolerant-ast")]
1029    type ErrorType = ParseErrors;
1030
1031    fn loc(&self) -> Option<&Loc> {
1032        self.source_loc.as_ref()
1033    }
1034
1035    fn data(&self) -> &Self::Data {
1036        &self.data
1037    }
1038
1039    fn with_data(data: T) -> Self {
1040        Self {
1041            source_loc: None,
1042            data,
1043        }
1044    }
1045
1046    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
1047        self.source_loc = maybe_source_loc.cloned();
1048        self
1049    }
1050
1051    /// Create an `Expr` that's just a single `Literal`.
1052    ///
1053    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
1054    fn val(self, v: impl Into<Literal>) -> Expr<T> {
1055        self.with_expr_kind(ExprKind::Lit(v.into()))
1056    }
1057
1058    /// Create an `Unknown` `Expr`
1059    fn unknown(self, u: Unknown) -> Expr<T> {
1060        self.with_expr_kind(ExprKind::Unknown(u))
1061    }
1062
1063    /// Create an `Expr` that's just this literal `Var`
1064    fn var(self, v: Var) -> Expr<T> {
1065        self.with_expr_kind(ExprKind::Var(v))
1066    }
1067
1068    /// Create an `Expr` that's just this `SlotId`
1069    fn slot(self, s: SlotId) -> Expr<T> {
1070        self.with_expr_kind(ExprKind::Slot(s))
1071    }
1072
1073    /// Create a ternary (if-then-else) `Expr`.
1074    /// Takes `Arc`s instead of owned `Expr`s.
1075    /// `test_expr` must evaluate to a Bool type
1076    fn ite_arc(
1077        self,
1078        test_expr: Arc<Expr<T>>,
1079        then_expr: Arc<Expr<T>>,
1080        else_expr: Arc<Expr<T>>,
1081    ) -> Expr<T> {
1082        self.with_expr_kind(ExprKind::If {
1083            test_expr,
1084            then_expr,
1085            else_expr,
1086        })
1087    }
1088
1089    /// Create a 'not' expression. `e` must evaluate to Bool type
1090    fn not(self, e: Expr<T>) -> Expr<T> {
1091        self.with_expr_kind(ExprKind::UnaryApp {
1092            op: UnaryOp::Not,
1093            arg: Arc::new(e),
1094        })
1095    }
1096
1097    /// Create a '==' expression
1098    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1099        self.with_expr_kind(ExprKind::BinaryApp {
1100            op: BinaryOp::Eq,
1101            arg1: Arc::new(e1),
1102            arg2: Arc::new(e2),
1103        })
1104    }
1105
1106    /// Create an 'and' expression. Arguments must evaluate to Bool type
1107    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1108        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1109            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1110                ExprKind::Lit(Literal::Bool(*b1 && *b2))
1111            }
1112            _ => ExprKind::And {
1113                left: Arc::new(e1),
1114                right: Arc::new(e2),
1115            },
1116        })
1117    }
1118
1119    /// Create an 'or' expression. Arguments must evaluate to Bool type
1120    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1121        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1122            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1123                ExprKind::Lit(Literal::Bool(*b1 || *b2))
1124            }
1125
1126            _ => ExprKind::Or {
1127                left: Arc::new(e1),
1128                right: Arc::new(e2),
1129            },
1130        })
1131    }
1132
1133    /// Create a '<' expression. Arguments must evaluate to Long type
1134    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1135        self.with_expr_kind(ExprKind::BinaryApp {
1136            op: BinaryOp::Less,
1137            arg1: Arc::new(e1),
1138            arg2: Arc::new(e2),
1139        })
1140    }
1141
1142    /// Create a '<=' expression. Arguments must evaluate to Long type
1143    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1144        self.with_expr_kind(ExprKind::BinaryApp {
1145            op: BinaryOp::LessEq,
1146            arg1: Arc::new(e1),
1147            arg2: Arc::new(e2),
1148        })
1149    }
1150
1151    /// Create an 'add' expression. Arguments must evaluate to Long type
1152    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1153        self.with_expr_kind(ExprKind::BinaryApp {
1154            op: BinaryOp::Add,
1155            arg1: Arc::new(e1),
1156            arg2: Arc::new(e2),
1157        })
1158    }
1159
1160    /// Create a 'sub' expression. Arguments must evaluate to Long type
1161    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1162        self.with_expr_kind(ExprKind::BinaryApp {
1163            op: BinaryOp::Sub,
1164            arg1: Arc::new(e1),
1165            arg2: Arc::new(e2),
1166        })
1167    }
1168
1169    /// Create a 'mul' expression. Arguments must evaluate to Long type
1170    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1171        self.with_expr_kind(ExprKind::BinaryApp {
1172            op: BinaryOp::Mul,
1173            arg1: Arc::new(e1),
1174            arg2: Arc::new(e2),
1175        })
1176    }
1177
1178    /// Create a 'neg' expression. `e` must evaluate to Long type.
1179    fn neg(self, e: Expr<T>) -> Expr<T> {
1180        self.with_expr_kind(ExprKind::UnaryApp {
1181            op: UnaryOp::Neg,
1182            arg: Arc::new(e),
1183        })
1184    }
1185
1186    /// Create an 'in' expression. First argument must evaluate to Entity type.
1187    /// Second argument must evaluate to either Entity type or Set type where
1188    /// all set elements have Entity type.
1189    fn is_in_arc(self, arg1: Arc<Expr<T>>, arg2: Arc<Expr<T>>) -> Expr<T> {
1190        self.with_expr_kind(ExprKind::BinaryApp {
1191            op: BinaryOp::In,
1192            arg1,
1193            arg2,
1194        })
1195    }
1196
1197    /// Create a 'contains' expression.
1198    /// First argument must have Set type.
1199    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1200        self.with_expr_kind(ExprKind::BinaryApp {
1201            op: BinaryOp::Contains,
1202            arg1: Arc::new(e1),
1203            arg2: Arc::new(e2),
1204        })
1205    }
1206
1207    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1208    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1209        self.with_expr_kind(ExprKind::BinaryApp {
1210            op: BinaryOp::ContainsAll,
1211            arg1: Arc::new(e1),
1212            arg2: Arc::new(e2),
1213        })
1214    }
1215
1216    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1217    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1218        self.with_expr_kind(ExprKind::BinaryApp {
1219            op: BinaryOp::ContainsAny,
1220            arg1: Arc::new(e1),
1221            arg2: Arc::new(e2),
1222        })
1223    }
1224
1225    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1226    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1227        self.with_expr_kind(ExprKind::UnaryApp {
1228            op: UnaryOp::IsEmpty,
1229            arg: Arc::new(expr),
1230        })
1231    }
1232
1233    /// Create a 'getTag' expression.
1234    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1235    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1236        self.with_expr_kind(ExprKind::BinaryApp {
1237            op: BinaryOp::GetTag,
1238            arg1: Arc::new(expr),
1239            arg2: Arc::new(tag),
1240        })
1241    }
1242
1243    /// Create a 'hasTag' expression.
1244    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1245    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1246        self.with_expr_kind(ExprKind::BinaryApp {
1247            op: BinaryOp::HasTag,
1248            arg1: Arc::new(expr),
1249            arg2: Arc::new(tag),
1250        })
1251    }
1252
1253    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1254    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1255        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1256    }
1257
1258    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1259    fn record(
1260        self,
1261        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1262    ) -> Result<Expr<T>, ExpressionConstructionError> {
1263        let mut map = BTreeMap::new();
1264        for (k, v) in pairs {
1265            match map.entry(k) {
1266                btree_map::Entry::Occupied(oentry) => {
1267                    return Err(expression_construction_errors::DuplicateKeyError {
1268                        key: oentry.key().clone(),
1269                        context: "in record literal",
1270                    }
1271                    .into());
1272                }
1273                btree_map::Entry::Vacant(ventry) => {
1274                    ventry.insert(v);
1275                }
1276            }
1277        }
1278        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1279    }
1280
1281    /// Create an `Expr` which calls the extension function with the given
1282    /// `Name` on `args`
1283    fn call_extension_fn(
1284        self,
1285        fn_name: Name,
1286        args: impl IntoIterator<Item = Expr<T>>,
1287    ) -> Result<Expr<T>, Infallible> {
1288        Ok(self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1289            fn_name,
1290            args: Arc::new(args.into_iter().collect()),
1291        }))
1292    }
1293
1294    /// Create an application `Expr` which applies the given built-in unary
1295    /// operator to the given `arg`
1296    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1297        self.with_expr_kind(ExprKind::UnaryApp {
1298            op: op.into(),
1299            arg: Arc::new(arg),
1300        })
1301    }
1302
1303    /// Create an application `Expr` which applies the given built-in binary
1304    /// operator to `arg1` and `arg2`
1305    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1306        self.with_expr_kind(ExprKind::BinaryApp {
1307            op: op.into(),
1308            arg1: Arc::new(arg1),
1309            arg2: Arc::new(arg2),
1310        })
1311    }
1312
1313    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1314    ///
1315    /// `expr` must evaluate to either Entity or Record type
1316    fn get_attr_arc(self, expr: Arc<Expr<T>>, attr: SmolStr) -> Expr<T> {
1317        self.with_expr_kind(ExprKind::GetAttr { expr, attr })
1318    }
1319
1320    /// Create an `Expr` which tests for the existence of a given
1321    /// attribute on a given `Entity` or record.
1322    ///
1323    /// `expr` must evaluate to either Entity or Record type
1324    fn has_attr_arc(self, expr: Arc<Expr<T>>, attr: SmolStr) -> Expr<T> {
1325        self.with_expr_kind(ExprKind::HasAttr { expr, attr })
1326    }
1327
1328    /// Create a 'like' expression.
1329    ///
1330    /// `expr` must evaluate to a String type
1331    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1332        self.with_expr_kind(ExprKind::Like {
1333            expr: Arc::new(expr),
1334            pattern,
1335        })
1336    }
1337
1338    /// Create an 'is' expression.
1339    fn is_entity_type_arc(self, expr: Arc<Expr<T>>, entity_type: EntityType) -> Expr<T> {
1340        self.with_expr_kind(ExprKind::Is { expr, entity_type })
1341    }
1342
1343    /// Don't support AST Error nodes - return the error right back
1344    #[cfg(feature = "tolerant-ast")]
1345    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1346        Err(parse_errors)
1347    }
1348}
1349
1350impl<T> ExprBuilder<T> {
1351    /// Construct an `Expr` containing the `data` and `source_loc` in this
1352    /// `ExprBuilder` and the given `ExprKind`.
1353    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1354        Expr::new(expr_kind, self.source_loc, self.data)
1355    }
1356
1357    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1358    ///
1359    /// If you have an iterator of pairs, generally prefer calling `.record()`
1360    /// instead of `.collect()`-ing yourself and calling this, potentially for
1361    /// efficiency reasons but also because `.record()` will properly handle
1362    /// duplicate keys but your own `.collect()` will not (by default).
1363    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1364        self.with_expr_kind(ExprKind::Record(map))
1365    }
1366}
1367
1368impl<T: Clone + Default> ExprBuilder<T> {
1369    /// Utility used the validator to get an expression with the same source
1370    /// location as an existing expression. This is done when reconstructing the
1371    /// `Expr` with type information.
1372    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1373        self.with_maybe_source_loc(expr.source_loc.as_ref())
1374    }
1375}
1376
/// Errors when constructing an expression
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times.
    ///
    /// The error message and diagnostic information are forwarded from the
    /// wrapped [`expression_construction_errors::DuplicateKeyError`].
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1389
1390/// Error subtypes for [`ExpressionConstructionError`]
1391pub mod expression_construction_errors {
1392    use miette::Diagnostic;
1393    use smol_str::SmolStr;
1394    use thiserror::Error;
1395
1396    /// The same key occurred two or more times
1397    //
1398    // CAUTION: this type is publicly exported in `cedar-policy`.
1399    // Don't make fields `pub`, don't make breaking changes, and use caution
1400    // when adding public methods.
1401    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1402    #[error("duplicate key `{key}` {context}")]
1403    pub struct DuplicateKeyError {
1404        /// The key which occurred two or more times
1405        pub(crate) key: SmolStr,
1406        /// Information about where the duplicate key occurred (e.g., "in record literal")
1407        pub(crate) context: &'static str,
1408    }
1409
1410    impl DuplicateKeyError {
1411        /// Get the key which occurred two or more times
1412        pub fn key(&self) -> &str {
1413            &self.key
1414        }
1415
1416        /// Make a new error with an updated `context` field
1417        pub(crate) fn with_context(self, context: &'static str) -> Self {
1418            Self { context, ..self }
1419        }
1420    }
1421}
1422
1423/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1424/// implementations that ignore any source information or other generic data
1425/// used to annotate the `Expr`.
1426#[derive(Debug, Clone)]
1427pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1428
1429impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1430    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1431    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1432    /// ignore source information and generic data.
1433    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1434        ExprShapeOnly(Cow::Borrowed(e))
1435    }
1436
1437    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1438    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1439    /// ignore source information and generic data.
1440    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1441        ExprShapeOnly(Cow::Owned(e))
1442    }
1443}
1444
1445impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1446    fn eq(&self, other: &Self) -> bool {
1447        self.0.eq_shape(&other.0)
1448    }
1449}
1450
1451impl<T: Clone> Eq for ExprShapeOnly<'_, T> {}
1452
1453impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1454    fn hash<H: Hasher>(&self, state: &mut H) {
1455        self.0.hash_shape(state);
1456    }
1457}
1458
1459impl<T: Clone> PartialOrd for ExprShapeOnly<'_, T> {
1460    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1461        Some(self.cmp(other))
1462    }
1463}
1464
1465impl<T: Clone> Ord for ExprShapeOnly<'_, T> {
1466    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1467        self.0.cmp_shape(&other.0)
1468    }
1469}
1470
1471impl<T> Expr<T> {
1472    /// Return true if this expression (recursively) has the same expression
1473    /// kind as the argument expression. This accounts for the full recursive
1474    /// shape of the expression, but does not consider source information or any
1475    /// generic data annotated on expression. This should behave the same as the
1476    /// default implementation of `Eq` before source information and generic
1477    /// data were added.
1478    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1479        use ExprKind::*;
1480        match (self.expr_kind(), other.expr_kind()) {
1481            (Lit(lit), Lit(lit1)) => lit == lit1,
1482            (Var(v), Var(v1)) => v == v1,
1483            (Slot(s), Slot(s1)) => s == s1,
1484            (
1485                Unknown(self::Unknown {
1486                    name: name1,
1487                    type_annotation: ta_1,
1488                }),
1489                Unknown(self::Unknown {
1490                    name: name2,
1491                    type_annotation: ta_2,
1492                }),
1493            ) => (name1 == name2) && (ta_1 == ta_2),
1494            (
1495                If {
1496                    test_expr,
1497                    then_expr,
1498                    else_expr,
1499                },
1500                If {
1501                    test_expr: test_expr1,
1502                    then_expr: then_expr1,
1503                    else_expr: else_expr1,
1504                },
1505            ) => {
1506                test_expr.eq_shape(test_expr1)
1507                    && then_expr.eq_shape(then_expr1)
1508                    && else_expr.eq_shape(else_expr1)
1509            }
1510            (
1511                And { left, right },
1512                And {
1513                    left: left1,
1514                    right: right1,
1515                },
1516            )
1517            | (
1518                Or { left, right },
1519                Or {
1520                    left: left1,
1521                    right: right1,
1522                },
1523            ) => left.eq_shape(left1) && right.eq_shape(right1),
1524            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1525                op == op1 && arg.eq_shape(arg1)
1526            }
1527            (
1528                BinaryApp { op, arg1, arg2 },
1529                BinaryApp {
1530                    op: op1,
1531                    arg1: arg11,
1532                    arg2: arg21,
1533                },
1534            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1535            (
1536                ExtensionFunctionApp { fn_name, args },
1537                ExtensionFunctionApp {
1538                    fn_name: fn_name1,
1539                    args: args1,
1540                },
1541            ) => {
1542                fn_name == fn_name1
1543                    && args.len() == args1.len()
1544                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1545            }
1546            (
1547                GetAttr { expr, attr },
1548                GetAttr {
1549                    expr: expr1,
1550                    attr: attr1,
1551                },
1552            )
1553            | (
1554                HasAttr { expr, attr },
1555                HasAttr {
1556                    expr: expr1,
1557                    attr: attr1,
1558                },
1559            ) => attr == attr1 && expr.eq_shape(expr1),
1560            (
1561                Like { expr, pattern },
1562                Like {
1563                    expr: expr1,
1564                    pattern: pattern1,
1565                },
1566            ) => pattern == pattern1 && expr.eq_shape(expr1),
1567            (Set(elems), Set(elems1)) => {
1568                elems.len() == elems1.len()
1569                    && elems
1570                        .iter()
1571                        .zip(elems1.iter())
1572                        .all(|(e, e1)| e.eq_shape(e1))
1573            }
1574            (Record(map), Record(map1)) => {
1575                map.len() == map1.len()
1576                    && map
1577                        .iter()
1578                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1579                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1580            }
1581            (
1582                Is { expr, entity_type },
1583                Is {
1584                    expr: expr1,
1585                    entity_type: entity_type1,
1586                },
1587            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1588            _ => false,
1589        }
1590    }
1591
1592    /// Implementation of hashing corresponding to equality as implemented by
1593    /// `eq_shape`. Must satisfy the usual relationship between equality and
1594    /// hashing.
1595    pub fn hash_shape<H>(&self, state: &mut H)
1596    where
1597        H: Hasher,
1598    {
1599        mem::discriminant(self).hash(state);
1600        match self.expr_kind() {
1601            ExprKind::Lit(lit) => lit.hash(state),
1602            ExprKind::Var(v) => v.hash(state),
1603            ExprKind::Slot(s) => s.hash(state),
1604            ExprKind::Unknown(u) => u.hash(state),
1605            ExprKind::If {
1606                test_expr,
1607                then_expr,
1608                else_expr,
1609            } => {
1610                test_expr.hash_shape(state);
1611                then_expr.hash_shape(state);
1612                else_expr.hash_shape(state);
1613            }
1614            ExprKind::And { left, right } => {
1615                left.hash_shape(state);
1616                right.hash_shape(state);
1617            }
1618            ExprKind::Or { left, right } => {
1619                left.hash_shape(state);
1620                right.hash_shape(state);
1621            }
1622            ExprKind::UnaryApp { op, arg } => {
1623                op.hash(state);
1624                arg.hash_shape(state);
1625            }
1626            ExprKind::BinaryApp { op, arg1, arg2 } => {
1627                op.hash(state);
1628                arg1.hash_shape(state);
1629                arg2.hash_shape(state);
1630            }
1631            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1632                fn_name.hash(state);
1633                state.write_usize(args.len());
1634                args.iter().for_each(|a| {
1635                    a.hash_shape(state);
1636                });
1637            }
1638            ExprKind::GetAttr { expr, attr } => {
1639                expr.hash_shape(state);
1640                attr.hash(state);
1641            }
1642            ExprKind::HasAttr { expr, attr } => {
1643                expr.hash_shape(state);
1644                attr.hash(state);
1645            }
1646            ExprKind::Like { expr, pattern } => {
1647                expr.hash_shape(state);
1648                pattern.hash(state);
1649            }
1650            ExprKind::Set(elems) => {
1651                state.write_usize(elems.len());
1652                elems.iter().for_each(|e| {
1653                    e.hash_shape(state);
1654                })
1655            }
1656            ExprKind::Record(map) => {
1657                state.write_usize(map.len());
1658                map.iter().for_each(|(s, a)| {
1659                    s.hash(state);
1660                    a.hash_shape(state);
1661                });
1662            }
1663            ExprKind::Is { expr, entity_type } => {
1664                expr.hash_shape(state);
1665                entity_type.hash(state);
1666            }
1667            #[cfg(feature = "tolerant-ast")]
1668            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1669        }
1670    }
1671
1672    /// Implementation of ordering corresponding to equality as implemented by
1673    /// `eq_shape`. Must satisfy the usual relationship between equality and
1674    /// ordering.
1675    pub fn cmp_shape(&self, other: &Expr<T>) -> std::cmp::Ordering {
1676        // First compare variants for early short-circuiting using discriminant
1677        let self_kind = self.expr_kind();
1678        let other_kind = other.expr_kind();
1679        if std::mem::discriminant(self_kind) != std::mem::discriminant(other_kind) {
1680            return self_kind.variant_order().cmp(&other_kind.variant_order());
1681        }
1682
1683        // Same variants, compare contents
1684        use ExprKind::*;
1685        match (self_kind, other_kind) {
1686            (Lit(lit), Lit(lit1)) => lit.cmp(lit1),
1687            (Var(v), Var(v1)) => v.cmp(v1),
1688            (Slot(s), Slot(s1)) => s.cmp(s1),
1689            (
1690                Unknown(self::Unknown {
1691                    name: name1,
1692                    type_annotation: ta_1,
1693                }),
1694                Unknown(self::Unknown {
1695                    name: name2,
1696                    type_annotation: ta_2,
1697                }),
1698            ) => name1.cmp(name2).then_with(|| ta_1.cmp(ta_2)),
1699            (
1700                If {
1701                    test_expr,
1702                    then_expr,
1703                    else_expr,
1704                },
1705                If {
1706                    test_expr: test_expr1,
1707                    then_expr: then_expr1,
1708                    else_expr: else_expr1,
1709                },
1710            ) => test_expr
1711                .cmp_shape(test_expr1)
1712                .then_with(|| then_expr.cmp_shape(then_expr1))
1713                .then_with(|| else_expr.cmp_shape(else_expr1)),
1714            (
1715                And { left, right },
1716                And {
1717                    left: left1,
1718                    right: right1,
1719                },
1720            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1721            (
1722                Or { left, right },
1723                Or {
1724                    left: left1,
1725                    right: right1,
1726                },
1727            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1728            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1729                op.cmp(op1).then_with(|| arg.cmp_shape(arg1))
1730            }
1731            (
1732                BinaryApp { op, arg1, arg2 },
1733                BinaryApp {
1734                    op: op1,
1735                    arg1: arg11,
1736                    arg2: arg21,
1737                },
1738            ) => op
1739                .cmp(op1)
1740                .then_with(|| arg1.cmp_shape(arg11))
1741                .then_with(|| arg2.cmp_shape(arg21)),
1742            (
1743                ExtensionFunctionApp { fn_name, args },
1744                ExtensionFunctionApp {
1745                    fn_name: fn_name1,
1746                    args: args1,
1747                },
1748            ) => fn_name.cmp(fn_name1).then_with(|| {
1749                args.len().cmp(&args1.len()).then_with(|| {
1750                    for (a, a1) in args.iter().zip(args1.iter()) {
1751                        match a.cmp_shape(a1) {
1752                            std::cmp::Ordering::Equal => continue,
1753                            other => return other,
1754                        }
1755                    }
1756                    std::cmp::Ordering::Equal
1757                })
1758            }),
1759            (
1760                GetAttr { expr, attr },
1761                GetAttr {
1762                    expr: expr1,
1763                    attr: attr1,
1764                },
1765            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1766            (
1767                HasAttr { expr, attr },
1768                HasAttr {
1769                    expr: expr1,
1770                    attr: attr1,
1771                },
1772            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1773            (
1774                Like { expr, pattern },
1775                Like {
1776                    expr: expr1,
1777                    pattern: pattern1,
1778                },
1779            ) => pattern.cmp(pattern1).then_with(|| expr.cmp_shape(expr1)),
1780            (Set(elems), Set(elems1)) => elems.len().cmp(&elems1.len()).then_with(|| {
1781                for (e, e1) in elems.iter().zip(elems1.iter()) {
1782                    match e.cmp_shape(e1) {
1783                        std::cmp::Ordering::Equal => continue,
1784                        other => return other,
1785                    }
1786                }
1787                std::cmp::Ordering::Equal
1788            }),
1789            (Record(map), Record(map1)) => map.len().cmp(&map1.len()).then_with(|| {
1790                for ((a, e), (a1, e1)) in map.iter().zip(map1.iter()) {
1791                    match a.cmp(a1).then_with(|| e.cmp_shape(e1)) {
1792                        std::cmp::Ordering::Equal => continue,
1793                        other => return other,
1794                    }
1795                }
1796                std::cmp::Ordering::Equal
1797            }),
1798            (
1799                Is { expr, entity_type },
1800                Is {
1801                    expr: expr1,
1802                    entity_type: entity_type1,
1803                },
1804            ) => entity_type
1805                .cmp(entity_type1)
1806                .then_with(|| expr.cmp_shape(expr1)),
1807            #[cfg(feature = "tolerant-ast")]
1808            (
1809                Error { error_kind },
1810                Error {
1811                    error_kind: error_kind1,
1812                },
1813            ) => error_kind.cmp(error_kind1),
1814            #[expect(
1815                clippy::unreachable,
1816                reason = "This should never be reached since we compare variants first"
1817            )]
1818            _ => unreachable!(
1819                "Different variants should have been handled by variant_order comparison"
1820            ),
1821        }
1822    }
1823}
1824
1825/// AST variables
1826#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
1827#[serde(rename_all = "camelCase")]
1828#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1829#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
1830#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
1831pub enum Var {
1832    /// the Principal of the given request
1833    Principal,
1834    /// the Action of the given request
1835    Action,
1836    /// the Resource of the given request
1837    Resource,
1838    /// the Context of the given request
1839    Context,
1840}
1841
1842impl From<PrincipalOrResource> for Var {
1843    fn from(v: PrincipalOrResource) -> Self {
1844        match v {
1845            PrincipalOrResource::Principal => Var::Principal,
1846            PrincipalOrResource::Resource => Var::Resource,
1847        }
1848    }
1849}
1850
1851#[expect(
1852    clippy::fallible_impl_from,
1853    reason = "Tested by `test::all_vars_are_ids`. Never panics"
1854)]
1855impl From<Var> for Id {
1856    fn from(var: Var) -> Self {
1857        #[expect(
1858            clippy::unwrap_used,
1859            reason = "`Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`"
1860        )]
1861        format!("{var}").parse().unwrap()
1862    }
1863}
1864
1865#[expect(
1866    clippy::fallible_impl_from,
1867    reason = "Tested by `test::all_vars_are_ids`. Never panics"
1868)]
1869impl From<Var> for UnreservedId {
1870    fn from(var: Var) -> Self {
1871        #[expect(
1872            clippy::unwrap_used,
1873            reason = "`Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`"
1874        )]
1875        Id::from(var).try_into().unwrap()
1876    }
1877}
1878
1879impl std::fmt::Display for Var {
1880    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1881        match self {
1882            Self::Principal => write!(f, "principal"),
1883            Self::Action => write!(f, "action"),
1884            Self::Resource => write!(f, "resource"),
1885            Self::Context => write!(f, "context"),
1886        }
1887    }
1888}
1889
1890#[cfg(test)]
1891mod test {
1892    use cool_asserts::assert_matches;
1893    use itertools::Itertools;
1894    use smol_str::ToSmolStr;
1895    use std::collections::{hash_map::DefaultHasher, HashSet};
1896
1897    use crate::expr_builder::ExprBuilder as _;
1898
1899    use super::*;
1900
1901    pub fn all_vars() -> impl Iterator<Item = Var> {
1902        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
1903    }
1904
1905    // Tests that Var::Into never panics
1906    #[test]
1907    fn all_vars_are_ids() {
1908        for var in all_vars() {
1909            let _id: Id = var.into();
1910            let _id: UnreservedId = var.into();
1911        }
1912    }
1913
1914    #[test]
1915    fn exprs() {
1916        assert_eq!(
1917            Expr::val(33),
1918            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1919        );
1920        assert_eq!(
1921            Expr::val("hello"),
1922            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1923        );
1924        assert_eq!(
1925            Expr::val(EntityUID::with_eid("foo")),
1926            Expr::new(
1927                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1928                None,
1929                ()
1930            )
1931        );
1932        assert_eq!(
1933            Expr::var(Var::Principal),
1934            Expr::new(ExprKind::Var(Var::Principal), None, ())
1935        );
1936        assert_eq!(
1937            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1938            Expr::new(
1939                ExprKind::If {
1940                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1941                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1942                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1943                },
1944                None,
1945                ()
1946            )
1947        );
1948        assert_eq!(
1949            Expr::not(Expr::val(false)),
1950            Expr::new(
1951                ExprKind::UnaryApp {
1952                    op: UnaryOp::Not,
1953                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1954                },
1955                None,
1956                ()
1957            )
1958        );
1959        assert_eq!(
1960            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1961            Expr::new(
1962                ExprKind::GetAttr {
1963                    expr: Arc::new(Expr::new(
1964                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1965                        None,
1966                        ()
1967                    )),
1968                    attr: "some_attr".into()
1969                },
1970                None,
1971                ()
1972            )
1973        );
1974        assert_eq!(
1975            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1976            Expr::new(
1977                ExprKind::HasAttr {
1978                    expr: Arc::new(Expr::new(
1979                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1980                        None,
1981                        ()
1982                    )),
1983                    attr: "some_attr".into()
1984                },
1985                None,
1986                ()
1987            )
1988        );
1989        assert_eq!(
1990            Expr::is_entity_type(
1991                Expr::val(EntityUID::with_eid("foo")),
1992                "Type".parse().unwrap()
1993            ),
1994            Expr::new(
1995                ExprKind::Is {
1996                    expr: Arc::new(Expr::new(
1997                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1998                        None,
1999                        ()
2000                    )),
2001                    entity_type: "Type".parse().unwrap()
2002                },
2003                None,
2004                ()
2005            ),
2006        );
2007    }
2008
2009    #[test]
2010    fn like_display() {
2011        // `\0` escaped form is `\0`.
2012        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
2013        assert_eq!(format!("{e}"), r#""a" like "\0""#);
2014        // `\`'s escaped form is `\\`
2015        let e = Expr::like(
2016            Expr::val("a"),
2017            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
2018        );
2019        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
2020        // `\`'s escaped form is `\\`
2021        let e = Expr::like(
2022            Expr::val("a"),
2023            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
2024        );
2025        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
2026        // literal star's escaped from is `\*`
2027        let e = Expr::like(
2028            Expr::val("a"),
2029            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
2030        );
2031        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
2032    }
2033
2034    #[test]
2035    fn has_display() {
2036        // `\0` escaped form is `\0`.
2037        let e = Expr::has_attr(Expr::val("a"), "\0".into());
2038        assert_eq!(format!("{e}"), r#""a" has "\0""#);
2039        // `\`'s escaped form is `\\`
2040        let e = Expr::has_attr(Expr::val("a"), r"\".into());
2041        assert_eq!(format!("{e}"), r#""a" has "\\""#);
2042    }
2043
2044    #[test]
2045    fn slot_display() {
2046        let e = Expr::slot(SlotId::principal());
2047        assert_eq!(format!("{e}"), "?principal");
2048        let e = Expr::slot(SlotId::resource());
2049        assert_eq!(format!("{e}"), "?resource");
2050        let e = Expr::val(EntityUID::with_eid("eid"));
2051        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
2052    }
2053
2054    #[test]
2055    fn simple_slots() {
2056        let e = Expr::slot(SlotId::principal());
2057        let p = SlotId::principal();
2058        let r = SlotId::resource();
2059        let set: HashSet<SlotId> = HashSet::from_iter([p]);
2060        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
2061        let e = Expr::or(
2062            Expr::slot(SlotId::principal()),
2063            Expr::ite(
2064                Expr::val(true),
2065                Expr::slot(SlotId::resource()),
2066                Expr::val(false),
2067            ),
2068        );
2069        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
2070        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
2071    }
2072
2073    #[test]
2074    fn unknowns() {
2075        let e = Expr::ite(
2076            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
2077            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
2078            Expr::unknown(Unknown::new_untyped("c")),
2079        );
2080        let unknowns = e.unknowns().collect_vec();
2081        assert_eq!(unknowns.len(), 3);
2082        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
2083        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
2084        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
2085    }
2086
2087    #[test]
2088    fn is_unknown() {
2089        let e = Expr::ite(
2090            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
2091            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
2092            Expr::unknown(Unknown::new_untyped("c")),
2093        );
2094        assert!(e.contains_unknown());
2095        let e = Expr::ite(
2096            Expr::not(Expr::val(true)),
2097            Expr::and(Expr::val(1), Expr::val(3)),
2098            Expr::val(1),
2099        );
2100        assert!(!e.contains_unknown());
2101    }
2102
2103    #[test]
2104    fn expr_with_data() {
2105        let e = ExprBuilder::with_data("data").val(1);
2106        assert_eq!(e.into_data(), "data");
2107    }
2108
2109    #[test]
2110    fn expr_shape_only_eq() {
2111        let temp = ExprBuilder::with_data(1).val(1);
2112        let exprs = &[
2113            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
2114            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
2115            (
2116                ExprBuilder::with_data(1).var(Var::Principal),
2117                Expr::var(Var::Principal),
2118            ),
2119            (
2120                ExprBuilder::with_data(1).slot(SlotId::principal()),
2121                Expr::slot(SlotId::principal()),
2122            ),
2123            (
2124                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
2125                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
2126            ),
2127            (
2128                ExprBuilder::with_data(1).not(temp.clone()),
2129                Expr::not(Expr::val(1)),
2130            ),
2131            (
2132                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
2133                Expr::is_eq(Expr::val(1), Expr::val(1)),
2134            ),
2135            (
2136                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
2137                Expr::and(Expr::val(1), Expr::val(1)),
2138            ),
2139            (
2140                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
2141                Expr::or(Expr::val(1), Expr::val(1)),
2142            ),
2143            (
2144                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
2145                Expr::less(Expr::val(1), Expr::val(1)),
2146            ),
2147            (
2148                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
2149                Expr::lesseq(Expr::val(1), Expr::val(1)),
2150            ),
2151            (
2152                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
2153                Expr::greater(Expr::val(1), Expr::val(1)),
2154            ),
2155            (
2156                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
2157                Expr::greatereq(Expr::val(1), Expr::val(1)),
2158            ),
2159            (
2160                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
2161                Expr::add(Expr::val(1), Expr::val(1)),
2162            ),
2163            (
2164                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
2165                Expr::sub(Expr::val(1), Expr::val(1)),
2166            ),
2167            (
2168                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
2169                Expr::mul(Expr::val(1), Expr::val(1)),
2170            ),
2171            (
2172                ExprBuilder::with_data(1).neg(temp.clone()),
2173                Expr::neg(Expr::val(1)),
2174            ),
2175            (
2176                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
2177                Expr::is_in(Expr::val(1), Expr::val(1)),
2178            ),
2179            (
2180                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
2181                Expr::contains(Expr::val(1), Expr::val(1)),
2182            ),
2183            (
2184                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
2185                Expr::contains_all(Expr::val(1), Expr::val(1)),
2186            ),
2187            (
2188                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
2189                Expr::contains_any(Expr::val(1), Expr::val(1)),
2190            ),
2191            (
2192                ExprBuilder::with_data(1).is_empty(temp.clone()),
2193                Expr::is_empty(Expr::val(1)),
2194            ),
2195            (
2196                ExprBuilder::with_data(1).set([temp.clone()]),
2197                Expr::set([Expr::val(1)]),
2198            ),
2199            (
2200                ExprBuilder::with_data(1)
2201                    .record([("foo".into(), temp.clone())])
2202                    .unwrap(),
2203                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
2204            ),
2205            (
2206                ExprBuilder::with_data(1)
2207                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()])
2208                    .unwrap_infallible(),
2209                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
2210            ),
2211            (
2212                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
2213                Expr::get_attr(Expr::val(1), "foo".into()),
2214            ),
2215            (
2216                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
2217                Expr::has_attr(Expr::val(1), "foo".into()),
2218            ),
2219            (
2220                ExprBuilder::with_data(1)
2221                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
2222                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
2223            ),
2224            (
2225                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
2226                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
2227            ),
2228        ];
2229
2230        for (e0, e1) in exprs {
2231            assert!(e0.eq_shape(e0));
2232            assert!(e1.eq_shape(e1));
2233            assert!(e0.eq_shape(e1));
2234            assert!(e1.eq_shape(e0));
2235
2236            let mut hasher0 = DefaultHasher::new();
2237            e0.hash_shape(&mut hasher0);
2238            let hash0 = hasher0.finish();
2239
2240            let mut hasher1 = DefaultHasher::new();
2241            e1.hash_shape(&mut hasher1);
2242            let hash1 = hasher1.finish();
2243
2244            assert_eq!(hash0, hash1);
2245        }
2246    }
2247
2248    #[test]
2249    fn expr_shape_only_not_eq() {
2250        let expr1 = ExprBuilder::with_data(1).val(1);
2251        let expr2 = ExprBuilder::with_data(1).val(2);
2252        assert_ne!(
2253            ExprShapeOnly::new_from_borrowed(&expr1),
2254            ExprShapeOnly::new_from_borrowed(&expr2)
2255        );
2256    }
2257
2258    #[test]
2259    fn expr_shape_only_set_prefix_ne() {
2260        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
2261        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
2262        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));
2263
2264        assert_ne!(e1, e2);
2265        assert_ne!(e1, e3);
2266        assert_ne!(e2, e1);
2267        assert_ne!(e2, e3);
2268        assert_ne!(e3, e1);
2269        assert_ne!(e2, e1);
2270    }
2271
2272    #[test]
2273    fn expr_shape_only_ext_fn_arg_prefix_ne() {
2274        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2275            "decimal".parse().unwrap(),
2276            vec![],
2277        ));
2278        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2279            "decimal".parse().unwrap(),
2280            vec![Expr::val("0.0")],
2281        ));
2282        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
2283            "decimal".parse().unwrap(),
2284            vec![Expr::val("0.0"), Expr::val("0.0")],
2285        ));
2286
2287        assert_ne!(e1, e2);
2288        assert_ne!(e1, e3);
2289        assert_ne!(e2, e1);
2290        assert_ne!(e2, e3);
2291        assert_ne!(e3, e1);
2292        assert_ne!(e2, e1);
2293    }
2294
2295    #[test]
2296    fn expr_shape_only_record_attr_prefix_ne() {
2297        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
2298        let e2 = ExprShapeOnly::new_from_owned(
2299            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
2300        );
2301        let e3 = ExprShapeOnly::new_from_owned(
2302            Expr::record([
2303                ("a".to_smolstr(), Expr::val(1)),
2304                ("b".to_smolstr(), Expr::val(2)),
2305            ])
2306            .unwrap(),
2307        );
2308
2309        assert_ne!(e1, e2);
2310        assert_ne!(e1, e3);
2311        assert_ne!(e2, e1);
2312        assert_ne!(e2, e3);
2313        assert_ne!(e3, e1);
2314        assert_ne!(e2, e1);
2315    }
2316
2317    #[test]
2318    fn untyped_subst_present() {
2319        let u = Unknown {
2320            name: "foo".into(),
2321            type_annotation: None,
2322        };
2323        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2324        match r {
2325            Ok(e) => assert_eq!(e, Expr::val(1)),
2326            Err(empty) => match empty {},
2327        }
2328    }
2329
2330    #[test]
2331    fn untyped_subst_present_correct_type() {
2332        let u = Unknown {
2333            name: "foo".into(),
2334            type_annotation: Some(Type::Long),
2335        };
2336        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2337        match r {
2338            Ok(e) => assert_eq!(e, Expr::val(1)),
2339            Err(empty) => match empty {},
2340        }
2341    }
2342
2343    #[test]
2344    fn untyped_subst_present_wrong_type() {
2345        let u = Unknown {
2346            name: "foo".into(),
2347            type_annotation: Some(Type::Bool),
2348        };
2349        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
2350        match r {
2351            Ok(e) => assert_eq!(e, Expr::val(1)),
2352            Err(empty) => match empty {},
2353        }
2354    }
2355
2356    #[test]
2357    fn untyped_subst_not_present() {
2358        let u = Unknown {
2359            name: "foo".into(),
2360            type_annotation: Some(Type::Bool),
2361        };
2362        let r = UntypedSubstitution::substitute(&u, None);
2363        match r {
2364            Ok(n) => assert_eq!(n, Expr::unknown(u)),
2365            Err(empty) => match empty {},
2366        }
2367    }
2368
2369    #[test]
2370    fn typed_subst_present() {
2371        let u = Unknown {
2372            name: "foo".into(),
2373            type_annotation: None,
2374        };
2375        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2376        assert_eq!(e, Expr::val(1));
2377    }
2378
2379    #[test]
2380    fn typed_subst_present_correct_type() {
2381        let u = Unknown {
2382            name: "foo".into(),
2383            type_annotation: Some(Type::Long),
2384        };
2385        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
2386        assert_eq!(e, Expr::val(1));
2387    }
2388
2389    #[test]
2390    fn typed_subst_present_wrong_type() {
2391        let u = Unknown {
2392            name: "foo".into(),
2393            type_annotation: Some(Type::Bool),
2394        };
2395        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
2396        assert_matches!(
2397            r,
2398            SubstitutionError::TypeError {
2399                expected: Type::Bool,
2400                actual: Type::Long,
2401            }
2402        );
2403    }
2404
2405    #[test]
2406    fn typed_subst_not_present() {
2407        let u = Unknown {
2408            name: "foo".into(),
2409            type_annotation: None,
2410        };
2411        let r = TypedSubstitution::substitute(&u, None).unwrap();
2412        assert_eq!(r, Expr::unknown(u));
2413    }
2414}