cedar_policy_core/ast/expr.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#[cfg(feature = "tolerant-ast")]
18use {
19    super::expr_allows_errors::AstExprErrorKind, crate::parser::err::ToASTError,
20    crate::parser::err::ToASTErrorKind,
21};
22
23use crate::{
24    ast::*,
25    expr_builder::{self, ExprBuilder as _},
26    extensions::Extensions,
27    parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34    borrow::Cow,
35    collections::{btree_map, BTreeMap, HashMap},
36    hash::{Hash, Hasher},
37    mem,
38    sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
/// Internal AST for expressions used by the policy evaluator.
///
/// This structure is a wrapper around an [`ExprKind`], which is the
/// expression variant this object contains. It also records where the
/// expression was written in policy source code (if known), plus some generic
/// per-node data of type `T` (defaulting to `()`).
///
/// `PartialEq`/`Eq`/`Hash` are derived via `educe` and ignore `source_loc`:
/// two expressions differing only in source location compare equal and hash
/// identically. `data` does participate in equality and hashing.
///
/// Cloning is O(1) in the size of the tree: sub-expressions inside
/// [`ExprKind`] are held behind `Arc`s, so a clone shares them rather than
/// copying them.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Expr<T = ()> {
    /// The expression variant, including any sub-expressions.
    expr_kind: ExprKind<T>,
    /// Where this expression appeared in policy source, if known.
    /// Excluded from equality and hashing.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    source_loc: Option<Loc>,
    /// Arbitrary data attached to this node.
    data: T,
}
60
/// The possible expression variants. This enum should be matched on by code
/// recursively traversing the AST.
///
/// Sub-expressions are stored behind `Arc`s, so cloning an `ExprKind` shares
/// its children instead of copying them.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// Literal value
    Lit(Literal),
    /// Variable
    Var(Var),
    /// Template slot
    Slot(SlotId),
    /// Symbolic unknown, used by partial evaluation
    Unknown(Unknown),
    /// Ternary (if-then-else) expression
    If {
        /// Condition for the ternary expression. Must evaluate to Bool type
        test_expr: Arc<Expr<T>>,
        /// Value if the condition is true
        then_expr: Arc<Expr<T>>,
        /// Value if the condition is false
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean AND
    And {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Boolean OR
    Or {
        /// Left operand, which will be eagerly evaluated
        left: Arc<Expr<T>>,
        /// Right operand, which may not be evaluated due to short-circuiting
        right: Arc<Expr<T>>,
    },
    /// Application of a built-in unary operator (single parameter)
    UnaryApp {
        /// Unary operator to apply
        op: UnaryOp,
        /// Argument to apply the operator to
        arg: Arc<Expr<T>>,
    },
    /// Application of a built-in binary operator (two parameters)
    BinaryApp {
        /// Binary operator to apply
        op: BinaryOp,
        /// First argument
        arg1: Arc<Expr<T>>,
        /// Second argument
        arg2: Arc<Expr<T>>,
    },
    /// Application of an extension function to n arguments
    ///
    /// INVARIANT (MethodStyleArgs):
    ///   if the extension function named by `fn_name` is method-style, then
    ///   `args` _cannot_ be empty; the first element of `args` is the subject
    ///   (receiver) of the method call.
    /// Ideally, we find some way to make this non-representable.
    ExtensionFunctionApp {
        /// Extension function to apply
        fn_name: Name,
        /// Arguments to apply the function to
        args: Arc<Vec<Expr<T>>>,
    },
    /// Get an attribute of an entity, or a field of a record
    GetAttr {
        /// Expression to get an attribute/field of. Must evaluate to either
        /// Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to get
        attr: SmolStr,
    },
    /// Does the given `expr` have the given `attr`?
    HasAttr {
        /// Expression to test. Must evaluate to either Entity or Record type
        expr: Arc<Expr<T>>,
        /// Attribute or field to check for
        attr: SmolStr,
    },
    /// Regex-like string matching, similar to IAM's `StringLike` operator.
    Like {
        /// Expression to test. Must evaluate to String type
        expr: Arc<Expr<T>>,
        /// Pattern to match on; can include the wildcard `*`, which matches any string.
        /// To match a literal `*` in the test expression, users can write `\*`.
        /// Be careful: the backslash in `\*` must not itself be part of another
        /// escape sequence. For instance, `\\*` matches a backslash followed by
        /// an arbitrary string.
        pattern: Pattern,
    },
    /// Entity type test: does `expr` evaluate to an entity whose type is
    /// `entity_type`?
    Is {
        /// Expression to test. Must evaluate to an Entity.
        expr: Arc<Expr<T>>,
        /// The [`EntityType`] used for the type membership test.
        entity_type: EntityType,
    },
    /// Set (whose elements may be arbitrary expressions)
    //
    // This is backed by `Vec` (and not e.g. `HashSet`) because two `Expr`s
    // that are syntactically unequal may actually be semantically equal --
    // i.e., we can't deduplicate until all of the `Expr`s are evaluated into
    // `Value`s.
    Set(Arc<Vec<Expr<T>>>),
    /// Anonymous record (whose elements may be arbitrary expressions).
    /// Keys are field names; the `BTreeMap` guarantees each appears once.
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
    #[cfg(feature = "tolerant-ast")]
    /// Error expression - allows us to continue parsing even when we have errors
    Error {
        /// Type of error that led to the failure
        error_kind: AstExprErrorKind,
    },
}
171
impl<T> ExprKind<T> {
    /// Get a numeric rank for this variant, for ordering `ExprKind`s of
    /// different variants relative to one another.
    ///
    /// NOTE(review): this was documented as "same as derive(Ord) for enums",
    /// but it does not match. `derive(Ord)` ranks variants by declaration
    /// order, where `Is` is declared right after `Like` and before `Set` and
    /// `Record`. Here, `Set` is 13, `Record` is 14 and `Is` is 15. Confirm
    /// which ordering the callers of this function rely on before changing
    /// either these numbers or the enum declaration.
    fn variant_order(&self) -> u8 {
        match self {
            ExprKind::Lit(_) => 0,
            ExprKind::Var(_) => 1,
            ExprKind::Slot(_) => 2,
            ExprKind::Unknown(_) => 3,
            ExprKind::If { .. } => 4,
            ExprKind::And { .. } => 5,
            ExprKind::Or { .. } => 6,
            ExprKind::UnaryApp { .. } => 7,
            ExprKind::BinaryApp { .. } => 8,
            ExprKind::ExtensionFunctionApp { .. } => 9,
            ExprKind::GetAttr { .. } => 10,
            ExprKind::HasAttr { .. } => 11,
            ExprKind::Like { .. } => 12,
            ExprKind::Set(_) => 13,
            ExprKind::Record(_) => 14,
            ExprKind::Is { .. } => 15,
            // Only present when the `tolerant-ast` feature is enabled.
            #[cfg(feature = "tolerant-ast")]
            ExprKind::Error { .. } => 16,
        }
    }
}
197
198impl From<Value> for Expr {
199    fn from(v: Value) -> Self {
200        Expr::from(v.value).with_maybe_source_loc(v.loc)
201    }
202}
203
204impl From<ValueKind> for Expr {
205    fn from(v: ValueKind) -> Self {
206        match v {
207            ValueKind::Lit(lit) => Expr::val(lit),
208            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
209            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
210            #[allow(clippy::expect_used)]
211            ValueKind::Record(record) => Expr::record(
212                Arc::unwrap_or_clone(record)
213                    .into_iter()
214                    .map(|(k, v)| (k, Expr::from(v))),
215            )
216            .expect("cannot have duplicate key because the input was already a BTreeMap"),
217            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
218        }
219    }
220}
221
222impl From<PartialValue> for Expr {
223    fn from(pv: PartialValue) -> Self {
224        match pv {
225            PartialValue::Value(v) => Expr::from(v),
226            PartialValue::Residual(expr) => expr,
227        }
228    }
229}
230
231impl<T> Expr<T> {
232    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
233        Self {
234            expr_kind,
235            source_loc,
236            data,
237        }
238    }
239
240    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
241    /// `enum` which specifies the expression variant, so it must be accessed by
242    /// any code matching and recursing on an expression.
243    pub fn expr_kind(&self) -> &ExprKind<T> {
244        &self.expr_kind
245    }
246
247    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
248    pub fn into_expr_kind(self) -> ExprKind<T> {
249        self.expr_kind
250    }
251
252    /// Access the data stored on the `Expr`.
253    pub fn data(&self) -> &T {
254        &self.data
255    }
256
257    /// Access the data stored on the `Expr`, taking ownership and consuming the
258    /// `Expr`.
259    pub fn into_data(self) -> T {
260        self.data
261    }
262
263    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
264    /// data.
265    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
266        (self.expr_kind, self.source_loc, self.data)
267    }
268
269    /// Access the `Loc` stored on the `Expr`.
270    pub fn source_loc(&self) -> Option<&Loc> {
271        self.source_loc.as_ref()
272    }
273
274    /// Return the `Expr`, but with the new `source_loc` (or `None`).
275    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
276        Self { source_loc, ..self }
277    }
278
279    /// Update the data for this `Expr`. A convenient function used by the
280    /// Validator in one place.
281    pub fn set_data(&mut self, data: T) {
282        self.data = data;
283    }
284
285    /// Check whether this expression is an entity reference
286    ///
287    /// This is used for policy scopes, where some syntax is
288    /// required to be an entity reference.
289    pub fn is_ref(&self) -> bool {
290        match &self.expr_kind {
291            ExprKind::Lit(lit) => lit.is_ref(),
292            _ => false,
293        }
294    }
295
296    /// Check whether this expression is a slot.
297    pub fn is_slot(&self) -> bool {
298        matches!(&self.expr_kind, ExprKind::Slot(_))
299    }
300
301    /// Check whether this expression is a set of entity references
302    ///
303    /// This is used for policy scopes, where some syntax is
304    /// required to be an entity reference set.
305    pub fn is_ref_set(&self) -> bool {
306        match &self.expr_kind {
307            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
308            _ => false,
309        }
310    }
311
312    /// Iterate over all sub-expressions in this expression
313    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
314        expr_iterator::ExprIterator::new(self)
315    }
316
317    /// Iterate over all of the slots in this policy AST
318    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
319        self.subexpressions()
320            .filter_map(|exp| match &exp.expr_kind {
321                ExprKind::Slot(slotid) => Some(Slot {
322                    id: *slotid,
323                    loc: exp.source_loc().cloned(),
324                }),
325                _ => None,
326            })
327    }
328
329    /// Determine if the expression is projectable under partial evaluation
330    /// An expression is projectable if it's guaranteed to never error on evaluation
331    /// This is true if the expression is entirely composed of values or unknowns
332    pub fn is_projectable(&self) -> bool {
333        self.subexpressions().all(|e| {
334            matches!(
335                e.expr_kind(),
336                ExprKind::Lit(_)
337                    | ExprKind::Unknown(_)
338                    | ExprKind::Set(_)
339                    | ExprKind::Var(_)
340                    | ExprKind::Record(_)
341            )
342        })
343    }
344
345    /// Try to compute the runtime type of this expression. This operation may
346    /// fail (returning `None`), for example, when asked to get the type of any
347    /// variables, any attributes of entities or records, or an `unknown`
348    /// without an explicitly annotated type.
349    ///
350    /// Also note that this is _not_ typechecking the expression. It does not
351    /// check that the expression actually evaluates to a value (as opposed to
352    /// erroring).
353    ///
354    /// Because of these limitations, this function should only be used to
355    /// obtain a type for use in diagnostics such as error strings.
356    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
357        match &self.expr_kind {
358            ExprKind::Lit(l) => Some(l.type_of()),
359            ExprKind::Var(_) => None,
360            ExprKind::Slot(_) => None,
361            ExprKind::Unknown(u) => u.type_annotation.clone(),
362            ExprKind::If {
363                then_expr,
364                else_expr,
365                ..
366            } => {
367                let type_of_then = then_expr.try_type_of(extensions);
368                let type_of_else = else_expr.try_type_of(extensions);
369                if type_of_then == type_of_else {
370                    type_of_then
371                } else {
372                    None
373                }
374            }
375            ExprKind::And { .. } => Some(Type::Bool),
376            ExprKind::Or { .. } => Some(Type::Bool),
377            ExprKind::UnaryApp {
378                op: UnaryOp::Neg, ..
379            } => Some(Type::Long),
380            ExprKind::UnaryApp {
381                op: UnaryOp::Not, ..
382            } => Some(Type::Bool),
383            ExprKind::UnaryApp {
384                op: UnaryOp::IsEmpty,
385                ..
386            } => Some(Type::Bool),
387            ExprKind::BinaryApp {
388                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
389                ..
390            } => Some(Type::Long),
391            ExprKind::BinaryApp {
392                op:
393                    BinaryOp::Contains
394                    | BinaryOp::ContainsAll
395                    | BinaryOp::ContainsAny
396                    | BinaryOp::Eq
397                    | BinaryOp::In
398                    | BinaryOp::Less
399                    | BinaryOp::LessEq,
400                ..
401            } => Some(Type::Bool),
402            ExprKind::BinaryApp {
403                op: BinaryOp::HasTag,
404                ..
405            } => Some(Type::Bool),
406            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
407                .func(fn_name)
408                .ok()?
409                .return_type()
410                .map(|rty| rty.clone().into()),
411            // We could try to be more complete here, but we can't do all that
412            // much better without evaluating the argument. Even if we know it's
413            // a record `Type::Record` tells us nothing about the type of the
414            // attribute.
415            ExprKind::GetAttr { .. } => None,
416            // similarly to `GetAttr`
417            ExprKind::BinaryApp {
418                op: BinaryOp::GetTag,
419                ..
420            } => None,
421            ExprKind::HasAttr { .. } => Some(Type::Bool),
422            ExprKind::Like { .. } => Some(Type::Bool),
423            ExprKind::Is { .. } => Some(Type::Bool),
424            ExprKind::Set(_) => Some(Type::Set),
425            ExprKind::Record(_) => Some(Type::Record),
426            #[cfg(feature = "tolerant-ast")]
427            ExprKind::Error { .. } => None,
428        }
429    }
430
431    /// Converts an `Expr<V>` to `B::Expr` using the provided builder.
432    ///
433    /// Preserves source location information and recursively transforms each expression node.
434    /// Note: Data may be cloned if the source expression is retained elsewhere.
435    pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
436    where
437        T: Clone,
438    {
439        let builder = B::new().with_maybe_source_loc(self.source_loc());
440        match self.into_expr_kind() {
441            ExprKind::Lit(lit) => builder.val(lit),
442            ExprKind::Var(var) => builder.var(var),
443            ExprKind::Slot(slot) => builder.slot(slot),
444            ExprKind::Unknown(u) => builder.unknown(u),
445            ExprKind::If {
446                test_expr,
447                then_expr,
448                else_expr,
449            } => builder.ite(
450                Arc::unwrap_or_clone(test_expr).into_expr::<B>(),
451                Arc::unwrap_or_clone(then_expr).into_expr::<B>(),
452                Arc::unwrap_or_clone(else_expr).into_expr::<B>(),
453            ),
454            ExprKind::And { left, right } => builder.and(
455                Arc::unwrap_or_clone(left).into_expr::<B>(),
456                Arc::unwrap_or_clone(right).into_expr::<B>(),
457            ),
458            ExprKind::Or { left, right } => builder.or(
459                Arc::unwrap_or_clone(left).into_expr::<B>(),
460                Arc::unwrap_or_clone(right).into_expr::<B>(),
461            ),
462            ExprKind::UnaryApp { op, arg } => {
463                let arg = Arc::unwrap_or_clone(arg).into_expr::<B>();
464                builder.unary_app(op, arg)
465            }
466            ExprKind::BinaryApp { op, arg1, arg2 } => {
467                let arg1 = Arc::unwrap_or_clone(arg1).into_expr::<B>();
468                let arg2 = Arc::unwrap_or_clone(arg2).into_expr::<B>();
469                builder.binary_app(op, arg1, arg2)
470            }
471            ExprKind::ExtensionFunctionApp { fn_name, args } => {
472                let args = Arc::unwrap_or_clone(args)
473                    .into_iter()
474                    .map(|e| e.into_expr::<B>());
475                builder.call_extension_fn(fn_name, args)
476            }
477            ExprKind::GetAttr { expr, attr } => {
478                builder.get_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
479            }
480            ExprKind::HasAttr { expr, attr } => {
481                builder.has_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
482            }
483            ExprKind::Like { expr, pattern } => {
484                builder.like(Arc::unwrap_or_clone(expr).into_expr::<B>(), pattern)
485            }
486            ExprKind::Is { expr, entity_type } => {
487                builder.is_entity_type(Arc::unwrap_or_clone(expr).into_expr::<B>(), entity_type)
488            }
489            ExprKind::Set(set) => builder.set(
490                Arc::unwrap_or_clone(set)
491                    .into_iter()
492                    .map(|e| e.into_expr::<B>()),
493            ),
494            // PANIC SAFETY: `map` is a map, so it will not have duplicates keys, so the `record` constructor cannot error.
495            #[allow(clippy::unwrap_used)]
496            ExprKind::Record(map) => builder
497                .record(
498                    Arc::unwrap_or_clone(map)
499                        .into_iter()
500                        .map(|(k, v)| (k, v.into_expr::<B>())),
501                )
502                .unwrap(),
503            #[cfg(feature = "tolerant-ast")]
504            // PANIC SAFETY: error type is Infallible so can never happen
505            #[allow(clippy::unwrap_used)]
506            ExprKind::Error { .. } => builder
507                .error(ParseErrors::singleton(ToASTError::new(
508                    ToASTErrorKind::ASTErrorNode,
509                    Some(Loc::new(0..1, "AST_ERROR_NODE".into())),
510                )))
511                .unwrap(),
512        }
513    }
514}
515
516#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
517#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
518impl Expr {
    /// Create an `Expr` that's just a single `Literal`.
    ///
    /// Note that you can pass this a `Literal`, an `Integer`, a `String`,
    /// etc. — anything implementing `Into<Literal>`. Built via
    /// `ExprBuilder::new()`.
    pub fn val(v: impl Into<Literal>) -> Self {
        ExprBuilder::new().val(v)
    }
525
    /// Create an `Expr` that's just a single symbolic `Unknown`
    /// (as used by partial evaluation).
    pub fn unknown(u: Unknown) -> Self {
        ExprBuilder::new().unknown(u)
    }
530
    /// Create an `Expr` that's just a reference to the variable `v`.
    pub fn var(v: Var) -> Self {
        ExprBuilder::new().var(v)
    }
535
    /// Create an `Expr` that's just the template slot `s`.
    pub fn slot(s: SlotId) -> Self {
        ExprBuilder::new().slot(s)
    }
540
    /// Create a ternary (if-then-else) `Expr`.
    ///
    /// `test_expr` must evaluate to a Bool type. `then_expr` is the value when
    /// it is true and `else_expr` the value when it is false.
    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
    }
547
    /// Create a ternary (if-then-else) `Expr` from `Arc`-wrapped operands.
    ///
    /// Same as `Expr::ite`, but takes `Arc`s instead of owned `Expr`s. This is
    /// convenient when the operands are already shared.
    /// `test_expr` must evaluate to a Bool type.
    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
    }
554
    /// Create a boolean 'not' expression. `e` must evaluate to Bool type.
    pub fn not(e: Expr) -> Self {
        ExprBuilder::new().not(e)
    }
559
    /// Create an equality ('==') expression comparing `e1` and `e2`.
    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().is_eq(e1, e2)
    }
564
    /// Create an inequality ('!=') expression comparing `e1` and `e2`.
    ///
    /// NOTE(review): `BinaryOp` (as matched exhaustively in `try_type_of`)
    /// has no not-equal operator. Presumably this is encoded as
    /// `!(e1 == e2)`; confirm in `ExprBuilder::noteq`.
    pub fn noteq(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().noteq(e1, e2)
    }
569
    /// Create a boolean 'and' expression. Arguments must evaluate to Bool type.
    pub fn and(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().and(e1, e2)
    }
574
    /// Create a boolean 'or' expression. Arguments must evaluate to Bool type.
    pub fn or(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().or(e1, e2)
    }
579
    /// Create a less-than ('<') expression, `e1 < e2`.
    /// Arguments must evaluate to Long type.
    pub fn less(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().less(e1, e2)
    }
584
    /// Create a less-than-or-equal ('<=') expression, `e1 <= e2`.
    /// Arguments must evaluate to Long type.
    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().lesseq(e1, e2)
    }
589
    /// Create a greater-than ('>') expression, `e1 > e2`.
    /// Arguments must evaluate to Long type.
    ///
    /// NOTE(review): `BinaryOp` has no greater-than operator (see the
    /// exhaustive match in `try_type_of`). Presumably this is encoded as `<`
    /// with swapped operands; confirm in `ExprBuilder::greater`.
    pub fn greater(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().greater(e1, e2)
    }
594
    /// Create a greater-than-or-equal ('>=') expression, `e1 >= e2`.
    /// Arguments must evaluate to Long type.
    ///
    /// NOTE(review): `BinaryOp` has no greater-or-equal operator (see the
    /// exhaustive match in `try_type_of`). Presumably this is encoded as `<=`
    /// with swapped operands; confirm in `ExprBuilder::greatereq`.
    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().greatereq(e1, e2)
    }
599
    /// Create an addition ('+') expression. Arguments must evaluate to Long type.
    pub fn add(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().add(e1, e2)
    }
604
    /// Create a subtraction ('-') expression, `e1 - e2`.
    /// Arguments must evaluate to Long type.
    pub fn sub(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().sub(e1, e2)
    }
609
    /// Create a multiplication ('*') expression. Arguments must evaluate to Long type.
    pub fn mul(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().mul(e1, e2)
    }
614
    /// Create an arithmetic negation expression. `e` must evaluate to Long type.
    pub fn neg(e: Expr) -> Self {
        ExprBuilder::new().neg(e)
    }
619
    /// Create an 'in' (entity membership) expression, `e1 in e2`.
    ///
    /// The first argument must evaluate to Entity type. The second must
    /// evaluate either to Entity type or to a Set whose elements all have
    /// Entity type.
    pub fn is_in(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().is_in(e1, e2)
    }
626
    /// Create a `contains` expression testing whether set `e1` contains `e2`.
    /// The first argument must have Set type.
    pub fn contains(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().contains(e1, e2)
    }
632
    /// Create a `containsAll` expression (`e1.containsAll(e2)`).
    /// Arguments must evaluate to Set type.
    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().contains_all(e1, e2)
    }
637
    /// Create a `containsAny` expression (`e1.containsAny(e2)`).
    /// Arguments must evaluate to Set type.
    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
        ExprBuilder::new().contains_any(e1, e2)
    }
642
    /// Create an `isEmpty` expression. The argument must evaluate to Set type.
    pub fn is_empty(e: Expr) -> Self {
        ExprBuilder::new().is_empty(e)
    }
647
    /// Create a `getTag` expression reading tag `tag` from entity `expr`.
    ///
    /// `expr` must evaluate to Entity type, and `tag` must evaluate to String type.
    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
        ExprBuilder::new().get_tag(expr, tag)
    }
653
    /// Create a `hasTag` expression testing whether entity `expr` has tag `tag`.
    ///
    /// `expr` must evaluate to Entity type, and `tag` must evaluate to String type.
    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
        ExprBuilder::new().has_tag(expr, tag)
    }
659
    /// Create an `Expr` which evaluates to a Set of the given `Expr`s.
    ///
    /// `ExprKind::Set` is backed by a `Vec`, so presumably no deduplication
    /// happens here. Duplicates can only be detected once the elements are
    /// evaluated to values.
    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
        ExprBuilder::new().set(exprs)
    }
664
    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
    ///
    /// # Errors
    ///
    /// Returns an `ExpressionConstructionError` if the record cannot be built.
    /// Call sites in this file that pass keys taken from an existing
    /// `BTreeMap` treat this as unreachable because such keys cannot repeat.
    /// Presumably, then, duplicate keys are the failure case; confirm in
    /// `ExprBuilder::record`.
    pub fn record(
        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
    ) -> Result<Self, ExpressionConstructionError> {
        ExprBuilder::new().record(pairs)
    }
671
    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
    ///
    /// This is infallible because a `BTreeMap` cannot hold duplicate keys.
    ///
    /// If you have an iterator of pairs, generally prefer calling
    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this.
    /// That is potentially more efficient, and `Expr::record()` properly
    /// handles duplicate keys, whereas your own `.collect()` will (by default)
    /// silently keep only the last value for a repeated key.
    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
        ExprBuilder::new().record_arc(map)
    }
682
    /// Create an `Expr` which calls the extension function named `fn_name`
    /// on `args`.
    ///
    /// For a method-style extension function, `args` must be non-empty, and
    /// `args[0]` is the receiver of the method call (the MethodStyleArgs
    /// invariant on `ExprKind::ExtensionFunctionApp`).
    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
        ExprBuilder::new().call_extension_fn(fn_name, args)
    }
688
    /// Create an application `Expr` which applies the given built-in unary
    /// operator to the given `arg`.
    ///
    /// `op` may be anything convertible into a `UnaryOp`.
    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
        ExprBuilder::new().unary_app(op, arg)
    }
694
    /// Create an application `Expr` which applies the given built-in binary
    /// operator to `arg1` and `arg2`.
    ///
    /// `op` may be anything convertible into a `BinaryOp`.
    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
        ExprBuilder::new().binary_app(op, arg1, arg2)
    }
700
    /// Create an `Expr` which gets attribute `attr` of a given `Entity` or
    /// field `attr` of a given record.
    ///
    /// `expr` must evaluate to either Entity or Record type.
    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
        ExprBuilder::new().get_attr(expr, attr)
    }
707
    /// Create an `Expr` which tests for the existence of attribute `attr`
    /// on a given `Entity` or record.
    ///
    /// `expr` must evaluate to either Entity or Record type.
    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
        ExprBuilder::new().has_attr(expr, attr)
    }
715
    /// Create a 'like' (wildcard string match) expression.
    ///
    /// `expr` must evaluate to a String type. `pattern` may contain the
    /// wildcard `*`, which matches any string.
    pub fn like(expr: Expr, pattern: Pattern) -> Self {
        ExprBuilder::new().like(expr, pattern)
    }
722
    /// Create an `is` expression testing whether `expr` evaluates to an
    /// entity of type `entity_type`.
    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
        ExprBuilder::new().is_entity_type(expr, entity_type)
    }
727
728    /// Check if an expression contains any symbolic unknowns
729    pub fn contains_unknown(&self) -> bool {
730        self.subexpressions()
731            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
732    }
733
734    /// Get all unknowns in an expression
735    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
736        self.subexpressions()
737            .filter_map(|subexpr| match subexpr.expr_kind() {
738                ExprKind::Unknown(u) => Some(u),
739                _ => None,
740            })
741    }
742
743    /// Substitute unknowns with concrete values.
744    ///
745    /// Ignores unmapped unknowns.
746    /// Ignores type annotations on unknowns.
747    /// Note that there might be "undiscovered unknowns" in the Expr, which
748    /// this function does not notice if evaluation of this Expr did not
749    /// traverse all entities and attributes during evaluation, leading to
750    /// this function only substituting one unknown at a time.
751    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
752        match self.substitute_general::<UntypedSubstitution>(definitions) {
753            Ok(e) => e,
754            Err(empty) => match empty {},
755        }
756    }
757
    /// Substitute unknowns with concrete values, checking type annotations.
    ///
    /// Values are looked up in `definitions` by the unknown's name. Unknowns
    /// with no entry are left in place.
    ///
    /// # Errors
    ///
    /// Returns a `SubstitutionError` if a substituted value does not match
    /// the type annotation on the unknown it replaces.
    ///
    /// Note that there might be "undiscovered unknowns" in the `Expr` that
    /// this function does not notice. If the evaluation that produced this
    /// `Expr` did not traverse all entities and attributes, further unknowns
    /// may only appear after substitution and re-evaluation. In effect,
    /// substitution may then proceed one unknown at a time.
    pub fn substitute_typed(
        &self,
        definitions: &HashMap<SmolStr, Value>,
    ) -> Result<Expr, SubstitutionError> {
        self.substitute_general::<TypedSubstitution>(definitions)
    }
772
773    /// Substitute unknowns with values
774    ///
775    /// Generic over the function implementing the substitution to allow for multiple error behaviors
776    fn substitute_general<T: SubstitutionFunction>(
777        &self,
778        definitions: &HashMap<SmolStr, Value>,
779    ) -> Result<Expr, T::Err> {
780        match self.expr_kind() {
781            ExprKind::Lit(_) => Ok(self.clone()),
782            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
783            ExprKind::Var(_) => Ok(self.clone()),
784            ExprKind::Slot(_) => Ok(self.clone()),
785            ExprKind::If {
786                test_expr,
787                then_expr,
788                else_expr,
789            } => Ok(Expr::ite(
790                test_expr.substitute_general::<T>(definitions)?,
791                then_expr.substitute_general::<T>(definitions)?,
792                else_expr.substitute_general::<T>(definitions)?,
793            )),
794            ExprKind::And { left, right } => Ok(Expr::and(
795                left.substitute_general::<T>(definitions)?,
796                right.substitute_general::<T>(definitions)?,
797            )),
798            ExprKind::Or { left, right } => Ok(Expr::or(
799                left.substitute_general::<T>(definitions)?,
800                right.substitute_general::<T>(definitions)?,
801            )),
802            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
803                *op,
804                arg.substitute_general::<T>(definitions)?,
805            )),
806            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
807                *op,
808                arg1.substitute_general::<T>(definitions)?,
809                arg2.substitute_general::<T>(definitions)?,
810            )),
811            ExprKind::ExtensionFunctionApp { fn_name, args } => {
812                let args = args
813                    .iter()
814                    .map(|e| e.substitute_general::<T>(definitions))
815                    .collect::<Result<Vec<Expr>, _>>()?;
816
817                Ok(Expr::call_extension_fn(fn_name.clone(), args))
818            }
819            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
820                expr.substitute_general::<T>(definitions)?,
821                attr.clone(),
822            )),
823            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
824                expr.substitute_general::<T>(definitions)?,
825                attr.clone(),
826            )),
827            ExprKind::Like { expr, pattern } => Ok(Expr::like(
828                expr.substitute_general::<T>(definitions)?,
829                pattern.clone(),
830            )),
831            ExprKind::Set(members) => {
832                let members = members
833                    .iter()
834                    .map(|e| e.substitute_general::<T>(definitions))
835                    .collect::<Result<Vec<_>, _>>()?;
836                Ok(Expr::set(members))
837            }
838            ExprKind::Record(map) => {
839                let map = map
840                    .iter()
841                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
842                    .collect::<Result<BTreeMap<_, _>, _>>()?;
843                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
844                #[allow(clippy::expect_used)]
845                Ok(Expr::record(map)
846                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
847            }
848            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
849                expr.substitute_general::<T>(definitions)?,
850                entity_type.clone(),
851            )),
852            #[cfg(feature = "tolerant-ast")]
853            ExprKind::Error { .. } => Ok(self.clone()),
854        }
855    }
856}
857
858/// A trait for customizing the error behavior of substitution
859trait SubstitutionFunction {
860    /// The potential errors this substitution function can return
861    type Err;
862    /// The function for implementing the substitution.
863    ///
864    /// Takes the expression being substituted,
865    /// The substitution from the map (if present)
866    /// and the type annotation from the unknown (if present)
867    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
868}
869
870struct TypedSubstitution {}
871
872impl SubstitutionFunction for TypedSubstitution {
873    type Err = SubstitutionError;
874
875    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
876        match (substitute, &value.type_annotation) {
877            (None, _) => Ok(Expr::unknown(value.clone())),
878            (Some(v), None) => Ok(v.clone().into()),
879            (Some(v), Some(t)) => {
880                if v.type_of() == *t {
881                    Ok(v.clone().into())
882                } else {
883                    Err(SubstitutionError::TypeError {
884                        expected: t.clone(),
885                        actual: v.type_of(),
886                    })
887                }
888            }
889        }
890    }
891}
892
893struct UntypedSubstitution {}
894
895impl SubstitutionFunction for UntypedSubstitution {
896    type Err = std::convert::Infallible;
897
898    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
899        Ok(substitute
900            .map(|v| v.clone().into())
901            .unwrap_or_else(|| Expr::unknown(value.clone())))
902    }
903}
904
905impl<T: Clone> std::fmt::Display for Expr<T> {
906    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
907        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
908        // we just convert to EST and use the EST pretty-printer.
909        // Note that converting AST->EST is lossless and infallible.
910        write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
911    }
912}
913
914impl<T: Clone> BoundedDisplay for Expr<T> {
915    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
916        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
917        // pretty-printer. Note that converting AST->EST is lossless and infallible.
918        BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
919    }
920}
921
922impl std::str::FromStr for Expr {
923    type Err = ParseErrors;
924
925    fn from_str(s: &str) -> Result<Expr, Self::Err> {
926        crate::parser::parse_expr(s)
927    }
928}
929
/// Enum for errors encountered during substitution.
///
/// Produced by type-checked substitution (`substitute_typed`) when a supplied
/// value conflicts with an unknown's type annotation.
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value did not match the type annotation on the unknown.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The expected type, i.e., the type the unknown was annotated with
        expected: Type,
        /// The type of the provided value
        actual: Type,
    },
}
942
/// Representation of a partial-evaluation Unknown at the AST level
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// The name of the unknown. Substitution looks up the replacement value
    /// by this name.
    pub name: SmolStr,
    /// The type of the values that can be substituted in for the unknown.
    /// If `None`, we have no type annotation, and thus a value of any type can
    /// be substituted.
    pub type_annotation: Option<Type>,
}
953
954impl Unknown {
955    /// Create a new untyped `Unknown`
956    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
957        Self {
958            name: name.into(),
959            type_annotation: None,
960        }
961    }
962
963    /// Create a new `Unknown` with type annotation. (Only values of the given
964    /// type can be substituted.)
965    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
966        Self {
967            name: name.into(),
968            type_annotation: Some(ty),
969        }
970    }
971}
972
973impl std::fmt::Display for Unknown {
974    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
975        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
976        // to avoid code duplication
977        write!(
978            f,
979            "{}",
980            Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
981        )
982    }
983}
984
/// Builder for constructing `Expr` objects annotated with some `data`
/// (which may simply be `T`'s default value) and optionally a `source_loc`.
#[derive(Clone, Debug)]
pub struct ExprBuilder<T> {
    /// Source location attached to the `Expr` this builder produces, if any.
    source_loc: Option<Loc>,
    /// Generic data attached to the `Expr` this builder produces.
    data: T,
}
992
993impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
994    type Expr = Expr<T>;
995
996    type Data = T;
997
998    #[cfg(feature = "tolerant-ast")]
999    type ErrorType = ParseErrors;
1000
1001    fn loc(&self) -> Option<&Loc> {
1002        self.source_loc.as_ref()
1003    }
1004
1005    fn data(&self) -> &Self::Data {
1006        &self.data
1007    }
1008
1009    fn with_data(data: T) -> Self {
1010        Self {
1011            source_loc: None,
1012            data,
1013        }
1014    }
1015
1016    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
1017        self.source_loc = maybe_source_loc.cloned();
1018        self
1019    }
1020
1021    /// Create an `Expr` that's just a single `Literal`.
1022    ///
1023    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
1024    fn val(self, v: impl Into<Literal>) -> Expr<T> {
1025        self.with_expr_kind(ExprKind::Lit(v.into()))
1026    }
1027
1028    /// Create an `Unknown` `Expr`
1029    fn unknown(self, u: Unknown) -> Expr<T> {
1030        self.with_expr_kind(ExprKind::Unknown(u))
1031    }
1032
1033    /// Create an `Expr` that's just this literal `Var`
1034    fn var(self, v: Var) -> Expr<T> {
1035        self.with_expr_kind(ExprKind::Var(v))
1036    }
1037
1038    /// Create an `Expr` that's just this `SlotId`
1039    fn slot(self, s: SlotId) -> Expr<T> {
1040        self.with_expr_kind(ExprKind::Slot(s))
1041    }
1042
1043    /// Create a ternary (if-then-else) `Expr`.
1044    ///
1045    /// `test_expr` must evaluate to a Bool type
1046    fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
1047        self.with_expr_kind(ExprKind::If {
1048            test_expr: Arc::new(test_expr),
1049            then_expr: Arc::new(then_expr),
1050            else_expr: Arc::new(else_expr),
1051        })
1052    }
1053
1054    /// Create a 'not' expression. `e` must evaluate to Bool type
1055    fn not(self, e: Expr<T>) -> Expr<T> {
1056        self.with_expr_kind(ExprKind::UnaryApp {
1057            op: UnaryOp::Not,
1058            arg: Arc::new(e),
1059        })
1060    }
1061
1062    /// Create a '==' expression
1063    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1064        self.with_expr_kind(ExprKind::BinaryApp {
1065            op: BinaryOp::Eq,
1066            arg1: Arc::new(e1),
1067            arg2: Arc::new(e2),
1068        })
1069    }
1070
1071    /// Create an 'and' expression. Arguments must evaluate to Bool type
1072    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1073        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1074            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1075                ExprKind::Lit(Literal::Bool(*b1 && *b2))
1076            }
1077            _ => ExprKind::And {
1078                left: Arc::new(e1),
1079                right: Arc::new(e2),
1080            },
1081        })
1082    }
1083
1084    /// Create an 'or' expression. Arguments must evaluate to Bool type
1085    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1086        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1087            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1088                ExprKind::Lit(Literal::Bool(*b1 || *b2))
1089            }
1090
1091            _ => ExprKind::Or {
1092                left: Arc::new(e1),
1093                right: Arc::new(e2),
1094            },
1095        })
1096    }
1097
1098    /// Create a '<' expression. Arguments must evaluate to Long type
1099    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1100        self.with_expr_kind(ExprKind::BinaryApp {
1101            op: BinaryOp::Less,
1102            arg1: Arc::new(e1),
1103            arg2: Arc::new(e2),
1104        })
1105    }
1106
1107    /// Create a '<=' expression. Arguments must evaluate to Long type
1108    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1109        self.with_expr_kind(ExprKind::BinaryApp {
1110            op: BinaryOp::LessEq,
1111            arg1: Arc::new(e1),
1112            arg2: Arc::new(e2),
1113        })
1114    }
1115
1116    /// Create an 'add' expression. Arguments must evaluate to Long type
1117    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1118        self.with_expr_kind(ExprKind::BinaryApp {
1119            op: BinaryOp::Add,
1120            arg1: Arc::new(e1),
1121            arg2: Arc::new(e2),
1122        })
1123    }
1124
1125    /// Create a 'sub' expression. Arguments must evaluate to Long type
1126    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1127        self.with_expr_kind(ExprKind::BinaryApp {
1128            op: BinaryOp::Sub,
1129            arg1: Arc::new(e1),
1130            arg2: Arc::new(e2),
1131        })
1132    }
1133
1134    /// Create a 'mul' expression. Arguments must evaluate to Long type
1135    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1136        self.with_expr_kind(ExprKind::BinaryApp {
1137            op: BinaryOp::Mul,
1138            arg1: Arc::new(e1),
1139            arg2: Arc::new(e2),
1140        })
1141    }
1142
1143    /// Create a 'neg' expression. `e` must evaluate to Long type.
1144    fn neg(self, e: Expr<T>) -> Expr<T> {
1145        self.with_expr_kind(ExprKind::UnaryApp {
1146            op: UnaryOp::Neg,
1147            arg: Arc::new(e),
1148        })
1149    }
1150
1151    /// Create an 'in' expression. First argument must evaluate to Entity type.
1152    /// Second argument must evaluate to either Entity type or Set type where
1153    /// all set elements have Entity type.
1154    fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1155        self.with_expr_kind(ExprKind::BinaryApp {
1156            op: BinaryOp::In,
1157            arg1: Arc::new(e1),
1158            arg2: Arc::new(e2),
1159        })
1160    }
1161
1162    /// Create a 'contains' expression.
1163    /// First argument must have Set type.
1164    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1165        self.with_expr_kind(ExprKind::BinaryApp {
1166            op: BinaryOp::Contains,
1167            arg1: Arc::new(e1),
1168            arg2: Arc::new(e2),
1169        })
1170    }
1171
1172    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1173    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1174        self.with_expr_kind(ExprKind::BinaryApp {
1175            op: BinaryOp::ContainsAll,
1176            arg1: Arc::new(e1),
1177            arg2: Arc::new(e2),
1178        })
1179    }
1180
1181    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1182    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1183        self.with_expr_kind(ExprKind::BinaryApp {
1184            op: BinaryOp::ContainsAny,
1185            arg1: Arc::new(e1),
1186            arg2: Arc::new(e2),
1187        })
1188    }
1189
1190    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1191    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1192        self.with_expr_kind(ExprKind::UnaryApp {
1193            op: UnaryOp::IsEmpty,
1194            arg: Arc::new(expr),
1195        })
1196    }
1197
1198    /// Create a 'getTag' expression.
1199    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1200    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1201        self.with_expr_kind(ExprKind::BinaryApp {
1202            op: BinaryOp::GetTag,
1203            arg1: Arc::new(expr),
1204            arg2: Arc::new(tag),
1205        })
1206    }
1207
1208    /// Create a 'hasTag' expression.
1209    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1210    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1211        self.with_expr_kind(ExprKind::BinaryApp {
1212            op: BinaryOp::HasTag,
1213            arg1: Arc::new(expr),
1214            arg2: Arc::new(tag),
1215        })
1216    }
1217
1218    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1219    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1220        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1221    }
1222
1223    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1224    fn record(
1225        self,
1226        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1227    ) -> Result<Expr<T>, ExpressionConstructionError> {
1228        let mut map = BTreeMap::new();
1229        for (k, v) in pairs {
1230            match map.entry(k) {
1231                btree_map::Entry::Occupied(oentry) => {
1232                    return Err(expression_construction_errors::DuplicateKeyError {
1233                        key: oentry.key().clone(),
1234                        context: "in record literal",
1235                    }
1236                    .into());
1237                }
1238                btree_map::Entry::Vacant(ventry) => {
1239                    ventry.insert(v);
1240                }
1241            }
1242        }
1243        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1244    }
1245
1246    /// Create an `Expr` which calls the extension function with the given
1247    /// `Name` on `args`
1248    fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1249        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1250            fn_name,
1251            args: Arc::new(args.into_iter().collect()),
1252        })
1253    }
1254
1255    /// Create an application `Expr` which applies the given built-in unary
1256    /// operator to the given `arg`
1257    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1258        self.with_expr_kind(ExprKind::UnaryApp {
1259            op: op.into(),
1260            arg: Arc::new(arg),
1261        })
1262    }
1263
1264    /// Create an application `Expr` which applies the given built-in binary
1265    /// operator to `arg1` and `arg2`
1266    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1267        self.with_expr_kind(ExprKind::BinaryApp {
1268            op: op.into(),
1269            arg1: Arc::new(arg1),
1270            arg2: Arc::new(arg2),
1271        })
1272    }
1273
1274    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1275    ///
1276    /// `expr` must evaluate to either Entity or Record type
1277    fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1278        self.with_expr_kind(ExprKind::GetAttr {
1279            expr: Arc::new(expr),
1280            attr,
1281        })
1282    }
1283
1284    /// Create an `Expr` which tests for the existence of a given
1285    /// attribute on a given `Entity` or record.
1286    ///
1287    /// `expr` must evaluate to either Entity or Record type
1288    fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1289        self.with_expr_kind(ExprKind::HasAttr {
1290            expr: Arc::new(expr),
1291            attr,
1292        })
1293    }
1294
1295    /// Create a 'like' expression.
1296    ///
1297    /// `expr` must evaluate to a String type
1298    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1299        self.with_expr_kind(ExprKind::Like {
1300            expr: Arc::new(expr),
1301            pattern,
1302        })
1303    }
1304
1305    /// Create an 'is' expression.
1306    fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1307        self.with_expr_kind(ExprKind::Is {
1308            expr: Arc::new(expr),
1309            entity_type,
1310        })
1311    }
1312
1313    /// Don't support AST Error nodes - return the error right back
1314    #[cfg(feature = "tolerant-ast")]
1315    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1316        Err(parse_errors)
1317    }
1318}
1319
1320impl<T> ExprBuilder<T> {
1321    /// Construct an `Expr` containing the `data` and `source_loc` in this
1322    /// `ExprBuilder` and the given `ExprKind`.
1323    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1324        Expr::new(expr_kind, self.source_loc, self.data)
1325    }
1326
1327    /// Create a ternary (if-then-else) `Expr`.
1328    /// Takes `Arc`s instead of owned `Expr`s.
1329    /// `test_expr` must evaluate to a Bool type
1330    pub fn ite_arc(
1331        self,
1332        test_expr: Arc<Expr<T>>,
1333        then_expr: Arc<Expr<T>>,
1334        else_expr: Arc<Expr<T>>,
1335    ) -> Expr<T> {
1336        self.with_expr_kind(ExprKind::If {
1337            test_expr,
1338            then_expr,
1339            else_expr,
1340        })
1341    }
1342
1343    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1344    ///
1345    /// If you have an iterator of pairs, generally prefer calling `.record()`
1346    /// instead of `.collect()`-ing yourself and calling this, potentially for
1347    /// efficiency reasons but also because `.record()` will properly handle
1348    /// duplicate keys but your own `.collect()` will not (by default).
1349    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1350        self.with_expr_kind(ExprKind::Record(map))
1351    }
1352}
1353
1354impl<T: Clone + Default> ExprBuilder<T> {
1355    /// Utility used the validator to get an expression with the same source
1356    /// location as an existing expression. This is done when reconstructing the
1357    /// `Expr` with type information.
1358    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1359        self.with_maybe_source_loc(expr.source_loc.as_ref())
1360    }
1361}
1362
/// Errors when constructing an expression.
///
/// Currently this can only be a duplicate key, e.g. when `record()` is given
/// the same key twice.
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1375
1376/// Error subtypes for [`ExpressionConstructionError`]
1377pub mod expression_construction_errors {
1378    use miette::Diagnostic;
1379    use smol_str::SmolStr;
1380    use thiserror::Error;
1381
1382    /// The same key occurred two or more times
1383    //
1384    // CAUTION: this type is publicly exported in `cedar-policy`.
1385    // Don't make fields `pub`, don't make breaking changes, and use caution
1386    // when adding public methods.
1387    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1388    #[error("duplicate key `{key}` {context}")]
1389    pub struct DuplicateKeyError {
1390        /// The key which occurred two or more times
1391        pub(crate) key: SmolStr,
1392        /// Information about where the duplicate key occurred (e.g., "in record literal")
1393        pub(crate) context: &'static str,
1394    }
1395
1396    impl DuplicateKeyError {
1397        /// Get the key which occurred two or more times
1398        pub fn key(&self) -> &str {
1399            &self.key
1400        }
1401
1402        /// Make a new error with an updated `context` field
1403        pub(crate) fn with_context(self, context: &'static str) -> Self {
1404            Self { context, ..self }
1405        }
1406    }
1407}
1408
1409/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1410/// implementations that ignore any source information or other generic data
1411/// used to annotate the `Expr`.
1412#[derive(Debug, Clone)]
1413pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1414
1415impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1416    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1417    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1418    /// ignore source information and generic data.
1419    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1420        ExprShapeOnly(Cow::Borrowed(e))
1421    }
1422
1423    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1424    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1425    /// ignore source information and generic data.
1426    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1427        ExprShapeOnly(Cow::Owned(e))
1428    }
1429}
1430
1431impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1432    fn eq(&self, other: &Self) -> bool {
1433        self.0.eq_shape(&other.0)
1434    }
1435}
1436
1437impl<T: Clone> Eq for ExprShapeOnly<'_, T> {}
1438
1439impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1440    fn hash<H: Hasher>(&self, state: &mut H) {
1441        self.0.hash_shape(state);
1442    }
1443}
1444
1445impl<T: Clone> PartialOrd for ExprShapeOnly<'_, T> {
1446    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1447        Some(self.cmp(other))
1448    }
1449}
1450
1451impl<T: Clone> Ord for ExprShapeOnly<'_, T> {
1452    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1453        self.0.cmp_shape(&other.0)
1454    }
1455}
1456
1457impl<T> Expr<T> {
1458    /// Return true if this expression (recursively) has the same expression
1459    /// kind as the argument expression. This accounts for the full recursive
1460    /// shape of the expression, but does not consider source information or any
1461    /// generic data annotated on expression. This should behave the same as the
1462    /// default implementation of `Eq` before source information and generic
1463    /// data were added.
1464    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1465        use ExprKind::*;
1466        match (self.expr_kind(), other.expr_kind()) {
1467            (Lit(lit), Lit(lit1)) => lit == lit1,
1468            (Var(v), Var(v1)) => v == v1,
1469            (Slot(s), Slot(s1)) => s == s1,
1470            (
1471                Unknown(self::Unknown {
1472                    name: name1,
1473                    type_annotation: ta_1,
1474                }),
1475                Unknown(self::Unknown {
1476                    name: name2,
1477                    type_annotation: ta_2,
1478                }),
1479            ) => (name1 == name2) && (ta_1 == ta_2),
1480            (
1481                If {
1482                    test_expr,
1483                    then_expr,
1484                    else_expr,
1485                },
1486                If {
1487                    test_expr: test_expr1,
1488                    then_expr: then_expr1,
1489                    else_expr: else_expr1,
1490                },
1491            ) => {
1492                test_expr.eq_shape(test_expr1)
1493                    && then_expr.eq_shape(then_expr1)
1494                    && else_expr.eq_shape(else_expr1)
1495            }
1496            (
1497                And { left, right },
1498                And {
1499                    left: left1,
1500                    right: right1,
1501                },
1502            )
1503            | (
1504                Or { left, right },
1505                Or {
1506                    left: left1,
1507                    right: right1,
1508                },
1509            ) => left.eq_shape(left1) && right.eq_shape(right1),
1510            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1511                op == op1 && arg.eq_shape(arg1)
1512            }
1513            (
1514                BinaryApp { op, arg1, arg2 },
1515                BinaryApp {
1516                    op: op1,
1517                    arg1: arg11,
1518                    arg2: arg21,
1519                },
1520            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1521            (
1522                ExtensionFunctionApp { fn_name, args },
1523                ExtensionFunctionApp {
1524                    fn_name: fn_name1,
1525                    args: args1,
1526                },
1527            ) => {
1528                fn_name == fn_name1
1529                    && args.len() == args1.len()
1530                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1531            }
1532            (
1533                GetAttr { expr, attr },
1534                GetAttr {
1535                    expr: expr1,
1536                    attr: attr1,
1537                },
1538            )
1539            | (
1540                HasAttr { expr, attr },
1541                HasAttr {
1542                    expr: expr1,
1543                    attr: attr1,
1544                },
1545            ) => attr == attr1 && expr.eq_shape(expr1),
1546            (
1547                Like { expr, pattern },
1548                Like {
1549                    expr: expr1,
1550                    pattern: pattern1,
1551                },
1552            ) => pattern == pattern1 && expr.eq_shape(expr1),
1553            (Set(elems), Set(elems1)) => {
1554                elems.len() == elems1.len()
1555                    && elems
1556                        .iter()
1557                        .zip(elems1.iter())
1558                        .all(|(e, e1)| e.eq_shape(e1))
1559            }
1560            (Record(map), Record(map1)) => {
1561                map.len() == map1.len()
1562                    && map
1563                        .iter()
1564                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1565                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1566            }
1567            (
1568                Is { expr, entity_type },
1569                Is {
1570                    expr: expr1,
1571                    entity_type: entity_type1,
1572                },
1573            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1574            _ => false,
1575        }
1576    }
1577
1578    /// Implementation of hashing corresponding to equality as implemented by
1579    /// `eq_shape`. Must satisfy the usual relationship between equality and
1580    /// hashing.
1581    pub fn hash_shape<H>(&self, state: &mut H)
1582    where
1583        H: Hasher,
1584    {
1585        mem::discriminant(self).hash(state);
1586        match self.expr_kind() {
1587            ExprKind::Lit(lit) => lit.hash(state),
1588            ExprKind::Var(v) => v.hash(state),
1589            ExprKind::Slot(s) => s.hash(state),
1590            ExprKind::Unknown(u) => u.hash(state),
1591            ExprKind::If {
1592                test_expr,
1593                then_expr,
1594                else_expr,
1595            } => {
1596                test_expr.hash_shape(state);
1597                then_expr.hash_shape(state);
1598                else_expr.hash_shape(state);
1599            }
1600            ExprKind::And { left, right } => {
1601                left.hash_shape(state);
1602                right.hash_shape(state);
1603            }
1604            ExprKind::Or { left, right } => {
1605                left.hash_shape(state);
1606                right.hash_shape(state);
1607            }
1608            ExprKind::UnaryApp { op, arg } => {
1609                op.hash(state);
1610                arg.hash_shape(state);
1611            }
1612            ExprKind::BinaryApp { op, arg1, arg2 } => {
1613                op.hash(state);
1614                arg1.hash_shape(state);
1615                arg2.hash_shape(state);
1616            }
1617            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1618                fn_name.hash(state);
1619                state.write_usize(args.len());
1620                args.iter().for_each(|a| {
1621                    a.hash_shape(state);
1622                });
1623            }
1624            ExprKind::GetAttr { expr, attr } => {
1625                expr.hash_shape(state);
1626                attr.hash(state);
1627            }
1628            ExprKind::HasAttr { expr, attr } => {
1629                expr.hash_shape(state);
1630                attr.hash(state);
1631            }
1632            ExprKind::Like { expr, pattern } => {
1633                expr.hash_shape(state);
1634                pattern.hash(state);
1635            }
1636            ExprKind::Set(elems) => {
1637                state.write_usize(elems.len());
1638                elems.iter().for_each(|e| {
1639                    e.hash_shape(state);
1640                })
1641            }
1642            ExprKind::Record(map) => {
1643                state.write_usize(map.len());
1644                map.iter().for_each(|(s, a)| {
1645                    s.hash(state);
1646                    a.hash_shape(state);
1647                });
1648            }
1649            ExprKind::Is { expr, entity_type } => {
1650                expr.hash_shape(state);
1651                entity_type.hash(state);
1652            }
1653            #[cfg(feature = "tolerant-ast")]
1654            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1655        }
1656    }
1657
1658    /// Implementation of ordering corresponding to equality as implemented by
1659    /// `eq_shape`. Must satisfy the usual relationship between equality and
1660    /// ordering.
1661    pub fn cmp_shape(&self, other: &Expr<T>) -> std::cmp::Ordering {
1662        // First compare variants for early short-circuiting using discriminant
1663        let self_kind = self.expr_kind();
1664        let other_kind = other.expr_kind();
1665        if std::mem::discriminant(self_kind) != std::mem::discriminant(other_kind) {
1666            return self_kind.variant_order().cmp(&other_kind.variant_order());
1667        }
1668
1669        // Same variants, compare contents
1670        use ExprKind::*;
1671        match (self_kind, other_kind) {
1672            (Lit(lit), Lit(lit1)) => lit.cmp(lit1),
1673            (Var(v), Var(v1)) => v.cmp(v1),
1674            (Slot(s), Slot(s1)) => s.cmp(s1),
1675            (
1676                Unknown(self::Unknown {
1677                    name: name1,
1678                    type_annotation: ta_1,
1679                }),
1680                Unknown(self::Unknown {
1681                    name: name2,
1682                    type_annotation: ta_2,
1683                }),
1684            ) => name1.cmp(name2).then_with(|| ta_1.cmp(ta_2)),
1685            (
1686                If {
1687                    test_expr,
1688                    then_expr,
1689                    else_expr,
1690                },
1691                If {
1692                    test_expr: test_expr1,
1693                    then_expr: then_expr1,
1694                    else_expr: else_expr1,
1695                },
1696            ) => test_expr
1697                .cmp_shape(test_expr1)
1698                .then_with(|| then_expr.cmp_shape(then_expr1))
1699                .then_with(|| else_expr.cmp_shape(else_expr1)),
1700            (
1701                And { left, right },
1702                And {
1703                    left: left1,
1704                    right: right1,
1705                },
1706            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1707            (
1708                Or { left, right },
1709                Or {
1710                    left: left1,
1711                    right: right1,
1712                },
1713            ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1714            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1715                op.cmp(op1).then_with(|| arg.cmp_shape(arg1))
1716            }
1717            (
1718                BinaryApp { op, arg1, arg2 },
1719                BinaryApp {
1720                    op: op1,
1721                    arg1: arg11,
1722                    arg2: arg21,
1723                },
1724            ) => op
1725                .cmp(op1)
1726                .then_with(|| arg1.cmp_shape(arg11))
1727                .then_with(|| arg2.cmp_shape(arg21)),
1728            (
1729                ExtensionFunctionApp { fn_name, args },
1730                ExtensionFunctionApp {
1731                    fn_name: fn_name1,
1732                    args: args1,
1733                },
1734            ) => fn_name.cmp(fn_name1).then_with(|| {
1735                args.len().cmp(&args1.len()).then_with(|| {
1736                    for (a, a1) in args.iter().zip(args1.iter()) {
1737                        match a.cmp_shape(a1) {
1738                            std::cmp::Ordering::Equal => continue,
1739                            other => return other,
1740                        }
1741                    }
1742                    std::cmp::Ordering::Equal
1743                })
1744            }),
1745            (
1746                GetAttr { expr, attr },
1747                GetAttr {
1748                    expr: expr1,
1749                    attr: attr1,
1750                },
1751            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1752            (
1753                HasAttr { expr, attr },
1754                HasAttr {
1755                    expr: expr1,
1756                    attr: attr1,
1757                },
1758            ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1759            (
1760                Like { expr, pattern },
1761                Like {
1762                    expr: expr1,
1763                    pattern: pattern1,
1764                },
1765            ) => pattern.cmp(pattern1).then_with(|| expr.cmp_shape(expr1)),
1766            (Set(elems), Set(elems1)) => elems.len().cmp(&elems1.len()).then_with(|| {
1767                for (e, e1) in elems.iter().zip(elems1.iter()) {
1768                    match e.cmp_shape(e1) {
1769                        std::cmp::Ordering::Equal => continue,
1770                        other => return other,
1771                    }
1772                }
1773                std::cmp::Ordering::Equal
1774            }),
1775            (Record(map), Record(map1)) => map.len().cmp(&map1.len()).then_with(|| {
1776                for ((a, e), (a1, e1)) in map.iter().zip(map1.iter()) {
1777                    match a.cmp(a1).then_with(|| e.cmp_shape(e1)) {
1778                        std::cmp::Ordering::Equal => continue,
1779                        other => return other,
1780                    }
1781                }
1782                std::cmp::Ordering::Equal
1783            }),
1784            (
1785                Is { expr, entity_type },
1786                Is {
1787                    expr: expr1,
1788                    entity_type: entity_type1,
1789                },
1790            ) => entity_type
1791                .cmp(entity_type1)
1792                .then_with(|| expr.cmp_shape(expr1)),
1793            #[cfg(feature = "tolerant-ast")]
1794            (
1795                Error { error_kind },
1796                Error {
1797                    error_kind: error_kind1,
1798                },
1799            ) => error_kind.cmp(error_kind1),
1800            // PANIC SAFETY: This should never be reached since we compare variants first
1801            #[allow(clippy::unreachable)]
1802            _ => unreachable!(
1803                "Different variants should have been handled by variant_order comparison"
1804            ),
1805        }
1806    }
1807}
1808
/// AST variables: the four components of the request being evaluated.
///
/// The derived `PartialOrd`/`Ord` follow variant declaration order
/// (`Principal < Action < Resource < Context`), so reordering the variants
/// changes how they sort. Because of `rename_all = "camelCase"`, serde
/// represents each variant by its lowercased name (e.g. `"principal"`).
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// the Principal of the given request
    Principal,
    /// the Action of the given request
    Action,
    /// the Resource of the given request
    Resource,
    /// the Context of the given request
    Context,
}
1825
1826impl From<PrincipalOrResource> for Var {
1827    fn from(v: PrincipalOrResource) -> Self {
1828        match v {
1829            PrincipalOrResource::Principal => Var::Principal,
1830            PrincipalOrResource::Resource => Var::Resource,
1831        }
1832    }
1833}
1834
1835// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1836#[allow(clippy::fallible_impl_from)]
1837impl From<Var> for Id {
1838    fn from(var: Var) -> Self {
1839        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1840        #[allow(clippy::unwrap_used)]
1841        format!("{var}").parse().unwrap()
1842    }
1843}
1844
1845// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1846#[allow(clippy::fallible_impl_from)]
1847impl From<Var> for UnreservedId {
1848    fn from(var: Var) -> Self {
1849        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`
1850        #[allow(clippy::unwrap_used)]
1851        Id::from(var).try_into().unwrap()
1852    }
1853}
1854
1855impl std::fmt::Display for Var {
1856    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1857        match self {
1858            Self::Principal => write!(f, "principal"),
1859            Self::Action => write!(f, "action"),
1860            Self::Resource => write!(f, "resource"),
1861            Self::Context => write!(f, "context"),
1862        }
1863    }
1864}
1865
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;
    use itertools::Itertools;
    use smol_str::ToSmolStr;
    use std::collections::{hash_map::DefaultHasher, HashSet};

    use crate::expr_builder::ExprBuilder as _;

    use super::*;

    /// Every `Var` variant, for tests that need to exercise all of them.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
    }

    // Tests that converting any `Var` into an `Id` or an `UnreservedId` never
    // panics; this backs the PANIC SAFETY notes on those `From` impls.
    #[test]
    fn all_vars_are_ids() {
        for var in all_vars() {
            let _id: Id = var.into();
            let _id: UnreservedId = var.into();
        }
    }

    #[test]
    fn exprs() {
        assert_eq!(
            Expr::val(33),
            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
        );
        assert_eq!(
            Expr::val("hello"),
            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
        );
        assert_eq!(
            Expr::val(EntityUID::with_eid("foo")),
            Expr::new(
                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                None,
                ()
            )
        );
        assert_eq!(
            Expr::var(Var::Principal),
            Expr::new(ExprKind::Var(Var::Principal), None, ())
        );
        assert_eq!(
            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
            Expr::new(
                ExprKind::If {
                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::not(Expr::val(false)),
            Expr::new(
                ExprKind::UnaryApp {
                    op: UnaryOp::Not,
                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::GetAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::HasAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::is_entity_type(
                Expr::val(EntityUID::with_eid("foo")),
                "Type".parse().unwrap()
            ),
            Expr::new(
                ExprKind::Is {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    entity_type: "Type".parse().unwrap()
                },
                None,
                ()
            ),
        );
    }

    #[test]
    fn like_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
        assert_eq!(format!("{e}"), r#""a" like "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
        // `\`'s escaped form is `\\`; a following wildcard stays a bare `*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
        // literal star's escaped form is `\*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
    }

    #[test]
    fn has_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::has_attr(Expr::val("a"), "\0".into());
        assert_eq!(format!("{e}"), r#""a" has "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::has_attr(Expr::val("a"), r"\".into());
        assert_eq!(format!("{e}"), r#""a" has "\\""#);
    }

    #[test]
    fn slot_display() {
        let e = Expr::slot(SlotId::principal());
        assert_eq!(format!("{e}"), "?principal");
        let e = Expr::slot(SlotId::resource());
        assert_eq!(format!("{e}"), "?resource");
        let e = Expr::val(EntityUID::with_eid("eid"));
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn simple_slots() {
        let e = Expr::slot(SlotId::principal());
        let p = SlotId::principal();
        let r = SlotId::resource();
        let set: HashSet<SlotId> = HashSet::from_iter([p]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
        let e = Expr::or(
            Expr::slot(SlotId::principal()),
            Expr::ite(
                Expr::val(true),
                Expr::slot(SlotId::resource()),
                Expr::val(false),
            ),
        );
        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
    }

    #[test]
    fn unknowns() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        let unknowns = e.unknowns().collect_vec();
        assert_eq!(unknowns.len(), 3);
        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
    }

    #[test]
    fn is_unknown() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        assert!(e.contains_unknown());
        let e = Expr::ite(
            Expr::not(Expr::val(true)),
            Expr::and(Expr::val(1), Expr::val(3)),
            Expr::val(1),
        );
        assert!(!e.contains_unknown());
    }

    #[test]
    fn expr_with_data() {
        let e = ExprBuilder::with_data("data").val(1);
        assert_eq!(e.into_data(), "data");
    }

    #[test]
    fn expr_shape_only_eq() {
        let temp = ExprBuilder::with_data(1).val(1);
        let exprs = &[
            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
            (
                ExprBuilder::with_data(1).var(Var::Principal),
                Expr::var(Var::Principal),
            ),
            (
                ExprBuilder::with_data(1).slot(SlotId::principal()),
                Expr::slot(SlotId::principal()),
            ),
            (
                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).not(temp.clone()),
                Expr::not(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
                Expr::is_eq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
                Expr::and(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
                Expr::or(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
                Expr::less(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
                Expr::lesseq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
                Expr::greater(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
                Expr::greatereq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
                Expr::add(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
                Expr::sub(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
                Expr::mul(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).neg(temp.clone()),
                Expr::neg(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
                Expr::is_in(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
                Expr::contains(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
                Expr::contains_all(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
                Expr::contains_any(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_empty(temp.clone()),
                Expr::is_empty(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).set([temp.clone()]),
                Expr::set([Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1)
                    .record([("foo".into(), temp.clone())])
                    .unwrap(),
                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
            ),
            (
                ExprBuilder::with_data(1)
                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
                Expr::get_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
                Expr::has_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1)
                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
            ),
            (
                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
            ),
        ];

        for (e0, e1) in exprs {
            assert!(e0.eq_shape(e0));
            assert!(e1.eq_shape(e1));
            assert!(e0.eq_shape(e1));
            assert!(e1.eq_shape(e0));

            let mut hasher0 = DefaultHasher::new();
            e0.hash_shape(&mut hasher0);
            let hash0 = hasher0.finish();

            let mut hasher1 = DefaultHasher::new();
            e1.hash_shape(&mut hasher1);
            let hash1 = hasher1.finish();

            assert_eq!(hash0, hash1);
        }
    }

    #[test]
    fn expr_shape_only_not_eq() {
        let expr1 = ExprBuilder::with_data(1).val(1);
        let expr2 = ExprBuilder::with_data(1).val(2);
        assert_ne!(
            ExprShapeOnly::new_from_borrowed(&expr1),
            ExprShapeOnly::new_from_borrowed(&expr2)
        );
    }

    // Every ordered pair of distinct sets must compare unequal, including when
    // one set is a strict prefix of the other.
    #[test]
    fn expr_shape_only_set_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Same as `expr_shape_only_set_prefix_ne`, for extension-function arguments.
    #[test]
    fn expr_shape_only_ext_fn_arg_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![],
        ));
        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0")],
        ));
        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0"), Expr::val("0.0")],
        ));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Same as `expr_shape_only_set_prefix_ne`, for record attributes.
    #[test]
    fn expr_shape_only_record_attr_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
        let e2 = ExprShapeOnly::new_from_owned(
            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
        );
        let e3 = ExprShapeOnly::new_from_owned(
            Expr::record([
                ("a".to_smolstr(), Expr::val(1)),
                ("b".to_smolstr(), Expr::val(2)),
            ])
            .unwrap(),
        );

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // `cmp_shape` must agree with `eq_shape`: each set compares `Equal` to
    // itself, and a shorter set orders before a longer one.
    #[test]
    fn expr_cmp_shape_set_prefix_order() {
        let e1 = Expr::set([]);
        let e2 = Expr::set([Expr::val(1)]);
        let e3 = Expr::set([Expr::val(1), Expr::val(2)]);

        for e in [&e1, &e2, &e3] {
            assert_eq!(e.cmp_shape(e), std::cmp::Ordering::Equal);
        }
        assert_eq!(e1.cmp_shape(&e2), std::cmp::Ordering::Less);
        assert_eq!(e2.cmp_shape(&e3), std::cmp::Ordering::Less);
        assert_eq!(e3.cmp_shape(&e1), std::cmp::Ordering::Greater);
        assert_eq!(e3.cmp_shape(&e2), std::cmp::Ordering::Greater);
    }

    #[test]
    fn untyped_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    // Untyped substitution does not check the annotation: a `Long` value is
    // substituted even though the unknown is annotated `Bool`.
    #[test]
    fn untyped_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, None);
        match r {
            Ok(n) => assert_eq!(n, Expr::unknown(u)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn typed_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
        assert_matches!(
            r,
            SubstitutionError::TypeError {
                expected: Type::Bool,
                actual: Type::Long,
            }
        );
    }

    #[test]
    fn typed_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = TypedSubstitution::substitute(&u, None).unwrap();
        assert_eq!(r, Expr::unknown(u));
    }
}