cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#[cfg(feature = "tolerant-ast")]
18use {
19    super::expr_allows_errors::AstExprErrorKind, crate::parser::err::ToASTError,
20    crate::parser::err::ToASTErrorKind,
21};
22
23use crate::{
24    ast::*,
25    expr_builder::{self, ExprBuilder as _},
26    extensions::Extensions,
27    parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34    borrow::Cow,
35    collections::{btree_map, BTreeMap, HashMap},
36    hash::{Hash, Hasher},
37    mem,
38    sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
45/// Internal AST for expressions used by the policy evaluator.
46/// This structure is a wrapper around an `ExprKind`, which is the expression
47/// variant this object contains. It also contains source information about
48/// where the expression was written in policy source code, and some generic
49/// data which is stored on each node of the AST.
50/// Cloning is O(1).
51#[derive(Educe, Debug, Clone)]
52#[educe(PartialEq, Eq, Hash)]
53pub struct Expr<T = ()> {
54    expr_kind: ExprKind<T>,
55    #[educe(PartialEq(ignore))]
56    #[educe(Hash(ignore))]
57    source_loc: Option<Loc>,
58    data: T,
59}
60
61/// The possible expression variants. This enum should be matched on by code
62/// recursively traversing the AST.
63#[derive(Hash, Debug, Clone, PartialEq, Eq)]
64pub enum ExprKind<T = ()> {
65    /// Literal value
66    Lit(Literal),
67    /// Variable
68    Var(Var),
69    /// Template Slots
70    Slot(SlotId),
71    /// Symbolic Unknown for partial-eval
72    Unknown(Unknown),
73    /// Ternary expression
74    If {
75        /// Condition for the ternary expression. Must evaluate to Bool type
76        test_expr: Arc<Expr<T>>,
77        /// Value if true
78        then_expr: Arc<Expr<T>>,
79        /// Value if false
80        else_expr: Arc<Expr<T>>,
81    },
82    /// Boolean AND
83    And {
84        /// Left operand, which will be eagerly evaluated
85        left: Arc<Expr<T>>,
86        /// Right operand, which may not be evaluated due to short-circuiting
87        right: Arc<Expr<T>>,
88    },
89    /// Boolean OR
90    Or {
91        /// Left operand, which will be eagerly evaluated
92        left: Arc<Expr<T>>,
93        /// Right operand, which may not be evaluated due to short-circuiting
94        right: Arc<Expr<T>>,
95    },
96    /// Application of a built-in unary operator (single parameter)
97    UnaryApp {
98        /// Unary operator to apply
99        op: UnaryOp,
100        /// Argument to apply operator to
101        arg: Arc<Expr<T>>,
102    },
103    /// Application of a built-in binary operator (two parameters)
104    BinaryApp {
105        /// Binary operator to apply
106        op: BinaryOp,
107        /// First arg
108        arg1: Arc<Expr<T>>,
109        /// Second arg
110        arg2: Arc<Expr<T>>,
111    },
112    /// Application of an extension function to n arguments
113    /// INVARIANT (MethodStyleArgs):
114    ///   if op.style is MethodStyle then args _cannot_ be empty.
115    ///     The first element of args refers to the subject of the method call
116    /// Ideally, we find some way to make this non-representable.
117    ExtensionFunctionApp {
118        /// Extension function to apply
119        fn_name: Name,
120        /// Args to apply the function to
121        args: Arc<Vec<Expr<T>>>,
122    },
123    /// Get an attribute of an entity, or a field of a record
124    GetAttr {
125        /// Expression to get an attribute/field of. Must evaluate to either
126        /// Entity or Record type
127        expr: Arc<Expr<T>>,
128        /// Attribute or field to get
129        attr: SmolStr,
130    },
131    /// Does the given `expr` have the given `attr`?
132    HasAttr {
133        /// Expression to test. Must evaluate to either Entity or Record type
134        expr: Arc<Expr<T>>,
135        /// Attribute or field to check for
136        attr: SmolStr,
137    },
138    /// Regex-like string matching similar to IAM's `StringLike` operator.
139    Like {
140        /// Expression to test. Must evaluate to String type
141        expr: Arc<Expr<T>>,
142        /// Pattern to match on; can include the wildcard *, which matches any string.
143        /// To match a literal `*` in the test expression, users can use `\*`.
144        /// Be careful the backslash in `\*` must not be another escape sequence. For instance, `\\*` matches a backslash plus an arbitrary string.
145        pattern: Pattern,
146    },
147    /// Entity type test. Does the first argument have the entity type
148    /// specified by the second argument.
149    Is {
150        /// Expression to test. Must evaluate to an Entity.
151        expr: Arc<Expr<T>>,
152        /// The [`EntityType`] used for the type membership test.
153        entity_type: EntityType,
154    },
155    /// Set (whose elements may be arbitrary expressions)
156    //
157    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
158    // that are syntactically unequal, may actually be semantically equal --
159    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
160    // evaluated into `Value`s
161    Set(Arc<Vec<Expr<T>>>),
162    /// Anonymous record (whose elements may be arbitrary expressions)
163    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
164    #[cfg(feature = "tolerant-ast")]
165    /// Error expression - allows us to continue parsing even when we have errors
166    Error {
167        /// Type of error that led to the failure
168        error_kind: AstExprErrorKind,
169    },
170}
171
172impl From<Value> for Expr {
173    fn from(v: Value) -> Self {
174        Expr::from(v.value).with_maybe_source_loc(v.loc)
175    }
176}
177
178impl From<ValueKind> for Expr {
179    fn from(v: ValueKind) -> Self {
180        match v {
181            ValueKind::Lit(lit) => Expr::val(lit),
182            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
183            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
184            #[allow(clippy::expect_used)]
185            ValueKind::Record(record) => Expr::record(
186                Arc::unwrap_or_clone(record)
187                    .into_iter()
188                    .map(|(k, v)| (k, Expr::from(v))),
189            )
190            .expect("cannot have duplicate key because the input was already a BTreeMap"),
191            ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
192        }
193    }
194}
195
196impl From<PartialValue> for Expr {
197    fn from(pv: PartialValue) -> Self {
198        match pv {
199            PartialValue::Value(v) => Expr::from(v),
200            PartialValue::Residual(expr) => expr,
201        }
202    }
203}
204
205impl<T> Expr<T> {
206    pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
207        Self {
208            expr_kind,
209            source_loc,
210            data,
211        }
212    }
213
214    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
215    /// `enum` which specifies the expression variant, so it must be accessed by
216    /// any code matching and recursing on an expression.
217    pub fn expr_kind(&self) -> &ExprKind<T> {
218        &self.expr_kind
219    }
220
221    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
222    pub fn into_expr_kind(self) -> ExprKind<T> {
223        self.expr_kind
224    }
225
226    /// Access the data stored on the `Expr`.
227    pub fn data(&self) -> &T {
228        &self.data
229    }
230
231    /// Access the data stored on the `Expr`, taking ownership and consuming the
232    /// `Expr`.
233    pub fn into_data(self) -> T {
234        self.data
235    }
236
237    /// Consume the `Expr`, returning the `ExprKind`, `source_loc`, and stored
238    /// data.
239    pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
240        (self.expr_kind, self.source_loc, self.data)
241    }
242
243    /// Access the `Loc` stored on the `Expr`.
244    pub fn source_loc(&self) -> Option<&Loc> {
245        self.source_loc.as_ref()
246    }
247
248    /// Return the `Expr`, but with the new `source_loc` (or `None`).
249    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
250        Self { source_loc, ..self }
251    }
252
253    /// Update the data for this `Expr`. A convenient function used by the
254    /// Validator in one place.
255    pub fn set_data(&mut self, data: T) {
256        self.data = data;
257    }
258
259    /// Check whether this expression is an entity reference
260    ///
261    /// This is used for policy scopes, where some syntax is
262    /// required to be an entity reference.
263    pub fn is_ref(&self) -> bool {
264        match &self.expr_kind {
265            ExprKind::Lit(lit) => lit.is_ref(),
266            _ => false,
267        }
268    }
269
270    /// Check whether this expression is a slot.
271    pub fn is_slot(&self) -> bool {
272        matches!(&self.expr_kind, ExprKind::Slot(_))
273    }
274
275    /// Check whether this expression is a set of entity references
276    ///
277    /// This is used for policy scopes, where some syntax is
278    /// required to be an entity reference set.
279    pub fn is_ref_set(&self) -> bool {
280        match &self.expr_kind {
281            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
282            _ => false,
283        }
284    }
285
286    /// Iterate over all sub-expressions in this expression
287    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
288        expr_iterator::ExprIterator::new(self)
289    }
290
291    /// Iterate over all of the slots in this policy AST
292    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
293        self.subexpressions()
294            .filter_map(|exp| match &exp.expr_kind {
295                ExprKind::Slot(slotid) => Some(Slot {
296                    id: *slotid,
297                    loc: exp.source_loc().cloned(),
298                }),
299                _ => None,
300            })
301    }
302
303    /// Determine if the expression is projectable under partial evaluation
304    /// An expression is projectable if it's guaranteed to never error on evaluation
305    /// This is true if the expression is entirely composed of values or unknowns
306    pub fn is_projectable(&self) -> bool {
307        self.subexpressions().all(|e| {
308            matches!(
309                e.expr_kind(),
310                ExprKind::Lit(_)
311                    | ExprKind::Unknown(_)
312                    | ExprKind::Set(_)
313                    | ExprKind::Var(_)
314                    | ExprKind::Record(_)
315            )
316        })
317    }
318
319    /// Try to compute the runtime type of this expression. This operation may
320    /// fail (returning `None`), for example, when asked to get the type of any
321    /// variables, any attributes of entities or records, or an `unknown`
322    /// without an explicitly annotated type.
323    ///
324    /// Also note that this is _not_ typechecking the expression. It does not
325    /// check that the expression actually evaluates to a value (as opposed to
326    /// erroring).
327    ///
328    /// Because of these limitations, this function should only be used to
329    /// obtain a type for use in diagnostics such as error strings.
330    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
331        match &self.expr_kind {
332            ExprKind::Lit(l) => Some(l.type_of()),
333            ExprKind::Var(_) => None,
334            ExprKind::Slot(_) => None,
335            ExprKind::Unknown(u) => u.type_annotation.clone(),
336            ExprKind::If {
337                then_expr,
338                else_expr,
339                ..
340            } => {
341                let type_of_then = then_expr.try_type_of(extensions);
342                let type_of_else = else_expr.try_type_of(extensions);
343                if type_of_then == type_of_else {
344                    type_of_then
345                } else {
346                    None
347                }
348            }
349            ExprKind::And { .. } => Some(Type::Bool),
350            ExprKind::Or { .. } => Some(Type::Bool),
351            ExprKind::UnaryApp {
352                op: UnaryOp::Neg, ..
353            } => Some(Type::Long),
354            ExprKind::UnaryApp {
355                op: UnaryOp::Not, ..
356            } => Some(Type::Bool),
357            ExprKind::UnaryApp {
358                op: UnaryOp::IsEmpty,
359                ..
360            } => Some(Type::Bool),
361            ExprKind::BinaryApp {
362                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
363                ..
364            } => Some(Type::Long),
365            ExprKind::BinaryApp {
366                op:
367                    BinaryOp::Contains
368                    | BinaryOp::ContainsAll
369                    | BinaryOp::ContainsAny
370                    | BinaryOp::Eq
371                    | BinaryOp::In
372                    | BinaryOp::Less
373                    | BinaryOp::LessEq,
374                ..
375            } => Some(Type::Bool),
376            ExprKind::BinaryApp {
377                op: BinaryOp::HasTag,
378                ..
379            } => Some(Type::Bool),
380            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
381                .func(fn_name)
382                .ok()?
383                .return_type()
384                .map(|rty| rty.clone().into()),
385            // We could try to be more complete here, but we can't do all that
386            // much better without evaluating the argument. Even if we know it's
387            // a record `Type::Record` tells us nothing about the type of the
388            // attribute.
389            ExprKind::GetAttr { .. } => None,
390            // similarly to `GetAttr`
391            ExprKind::BinaryApp {
392                op: BinaryOp::GetTag,
393                ..
394            } => None,
395            ExprKind::HasAttr { .. } => Some(Type::Bool),
396            ExprKind::Like { .. } => Some(Type::Bool),
397            ExprKind::Is { .. } => Some(Type::Bool),
398            ExprKind::Set(_) => Some(Type::Set),
399            ExprKind::Record(_) => Some(Type::Record),
400            #[cfg(feature = "tolerant-ast")]
401            ExprKind::Error { .. } => None,
402        }
403    }
404
405    /// Converts an `Expr<V>` to `B::Expr` using the provided builder.
406    ///
407    /// Preserves source location information and recursively transforms each expression node.
408    /// Note: Data may be cloned if the source expression is retained elsewhere.
409    pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
410    where
411        T: Clone,
412    {
413        let builder = B::new().with_maybe_source_loc(self.source_loc().cloned().as_ref());
414        match self.into_expr_kind() {
415            ExprKind::Lit(lit) => builder.val(lit),
416            ExprKind::Var(var) => builder.var(var),
417            ExprKind::Slot(slot) => builder.slot(slot),
418            ExprKind::Unknown(u) => builder.unknown(u),
419            ExprKind::If {
420                test_expr,
421                then_expr,
422                else_expr,
423            } => builder.ite(
424                Arc::unwrap_or_clone(test_expr).into_expr::<B>(),
425                Arc::unwrap_or_clone(then_expr).into_expr::<B>(),
426                Arc::unwrap_or_clone(else_expr).into_expr::<B>(),
427            ),
428            ExprKind::And { left, right } => builder.and(
429                Arc::unwrap_or_clone(left).into_expr::<B>(),
430                Arc::unwrap_or_clone(right).into_expr::<B>(),
431            ),
432            ExprKind::Or { left, right } => builder.or(
433                Arc::unwrap_or_clone(left).into_expr::<B>(),
434                Arc::unwrap_or_clone(right).into_expr::<B>(),
435            ),
436            ExprKind::UnaryApp { op, arg } => {
437                let arg = Arc::unwrap_or_clone(arg).into_expr::<B>();
438                builder.unary_app(op, arg)
439            }
440            ExprKind::BinaryApp { op, arg1, arg2 } => {
441                let arg1 = Arc::unwrap_or_clone(arg1).into_expr::<B>();
442                let arg2 = Arc::unwrap_or_clone(arg2).into_expr::<B>();
443                builder.binary_app(op, arg1, arg2)
444            }
445            ExprKind::ExtensionFunctionApp { fn_name, args } => {
446                let args = Arc::unwrap_or_clone(args)
447                    .into_iter()
448                    .map(|e| e.into_expr::<B>());
449                builder.call_extension_fn(fn_name, args)
450            }
451            ExprKind::GetAttr { expr, attr } => {
452                builder.get_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
453            }
454            ExprKind::HasAttr { expr, attr } => {
455                builder.has_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
456            }
457            ExprKind::Like { expr, pattern } => {
458                builder.like(Arc::unwrap_or_clone(expr).into_expr::<B>(), pattern)
459            }
460            ExprKind::Is { expr, entity_type } => {
461                builder.is_entity_type(Arc::unwrap_or_clone(expr).into_expr::<B>(), entity_type)
462            }
463            ExprKind::Set(set) => builder.set(
464                Arc::unwrap_or_clone(set)
465                    .into_iter()
466                    .map(|e| e.into_expr::<B>()),
467            ),
468            // PANIC SAFETY: `map` is a map, so it will not have duplicates keys, so the `record` constructor cannot error.
469            #[allow(clippy::unwrap_used)]
470            ExprKind::Record(map) => builder
471                .record(
472                    Arc::unwrap_or_clone(map)
473                        .into_iter()
474                        .map(|(k, v)| (k, v.into_expr::<B>())),
475                )
476                .unwrap(),
477            #[cfg(feature = "tolerant-ast")]
478            // PANIC SAFETY: error type is Infallible so can never happen
479            #[allow(clippy::unwrap_used)]
480            ExprKind::Error { .. } => builder
481                .error(ParseErrors::singleton(ToASTError::new(
482                    ToASTErrorKind::ASTErrorNode,
483                    Loc::new(0..1, "AST_ERROR_NODE".into()),
484                )))
485                .unwrap(),
486        }
487    }
488}
489
490#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
491#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
492impl Expr {
493    /// Create an `Expr` that's just a single `Literal`.
494    ///
495    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
496    pub fn val(v: impl Into<Literal>) -> Self {
497        ExprBuilder::new().val(v)
498    }
499
500    /// Create an `Expr` that's just a single `Unknown`.
501    pub fn unknown(u: Unknown) -> Self {
502        ExprBuilder::new().unknown(u)
503    }
504
505    /// Create an `Expr` that's just this literal `Var`
506    pub fn var(v: Var) -> Self {
507        ExprBuilder::new().var(v)
508    }
509
510    /// Create an `Expr` that's just this `SlotId`
511    pub fn slot(s: SlotId) -> Self {
512        ExprBuilder::new().slot(s)
513    }
514
515    /// Create a ternary (if-then-else) `Expr`.
516    ///
517    /// `test_expr` must evaluate to a Bool type
518    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
519        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
520    }
521
522    /// Create a ternary (if-then-else) `Expr`.
523    /// Takes `Arc`s instead of owned `Expr`s.
524    /// `test_expr` must evaluate to a Bool type
525    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
526        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
527    }
528
529    /// Create a 'not' expression. `e` must evaluate to Bool type
530    pub fn not(e: Expr) -> Self {
531        ExprBuilder::new().not(e)
532    }
533
534    /// Create a '==' expression
535    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
536        ExprBuilder::new().is_eq(e1, e2)
537    }
538
539    /// Create a '!=' expression
540    pub fn noteq(e1: Expr, e2: Expr) -> Self {
541        ExprBuilder::new().noteq(e1, e2)
542    }
543
544    /// Create an 'and' expression. Arguments must evaluate to Bool type
545    pub fn and(e1: Expr, e2: Expr) -> Self {
546        ExprBuilder::new().and(e1, e2)
547    }
548
549    /// Create an 'or' expression. Arguments must evaluate to Bool type
550    pub fn or(e1: Expr, e2: Expr) -> Self {
551        ExprBuilder::new().or(e1, e2)
552    }
553
554    /// Create a '<' expression. Arguments must evaluate to Long type
555    pub fn less(e1: Expr, e2: Expr) -> Self {
556        ExprBuilder::new().less(e1, e2)
557    }
558
559    /// Create a '<=' expression. Arguments must evaluate to Long type
560    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
561        ExprBuilder::new().lesseq(e1, e2)
562    }
563
564    /// Create a '>' expression. Arguments must evaluate to Long type
565    pub fn greater(e1: Expr, e2: Expr) -> Self {
566        ExprBuilder::new().greater(e1, e2)
567    }
568
569    /// Create a '>=' expression. Arguments must evaluate to Long type
570    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
571        ExprBuilder::new().greatereq(e1, e2)
572    }
573
574    /// Create an 'add' expression. Arguments must evaluate to Long type
575    pub fn add(e1: Expr, e2: Expr) -> Self {
576        ExprBuilder::new().add(e1, e2)
577    }
578
579    /// Create a 'sub' expression. Arguments must evaluate to Long type
580    pub fn sub(e1: Expr, e2: Expr) -> Self {
581        ExprBuilder::new().sub(e1, e2)
582    }
583
584    /// Create a 'mul' expression. Arguments must evaluate to Long type
585    pub fn mul(e1: Expr, e2: Expr) -> Self {
586        ExprBuilder::new().mul(e1, e2)
587    }
588
589    /// Create a 'neg' expression. `e` must evaluate to Long type.
590    pub fn neg(e: Expr) -> Self {
591        ExprBuilder::new().neg(e)
592    }
593
594    /// Create an 'in' expression. First argument must evaluate to Entity type.
595    /// Second argument must evaluate to either Entity type or Set type where
596    /// all set elements have Entity type.
597    pub fn is_in(e1: Expr, e2: Expr) -> Self {
598        ExprBuilder::new().is_in(e1, e2)
599    }
600
601    /// Create a `contains` expression.
602    /// First argument must have Set type.
603    pub fn contains(e1: Expr, e2: Expr) -> Self {
604        ExprBuilder::new().contains(e1, e2)
605    }
606
607    /// Create a `containsAll` expression. Arguments must evaluate to Set type
608    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
609        ExprBuilder::new().contains_all(e1, e2)
610    }
611
612    /// Create a `containsAny` expression. Arguments must evaluate to Set type
613    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
614        ExprBuilder::new().contains_any(e1, e2)
615    }
616
617    /// Create a `isEmpty` expression. Argument must evaluate to Set type
618    pub fn is_empty(e: Expr) -> Self {
619        ExprBuilder::new().is_empty(e)
620    }
621
622    /// Create a `getTag` expression.
623    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
624    pub fn get_tag(expr: Expr, tag: Expr) -> Self {
625        ExprBuilder::new().get_tag(expr, tag)
626    }
627
628    /// Create a `hasTag` expression.
629    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
630    pub fn has_tag(expr: Expr, tag: Expr) -> Self {
631        ExprBuilder::new().has_tag(expr, tag)
632    }
633
634    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
635    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
636        ExprBuilder::new().set(exprs)
637    }
638
639    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
640    pub fn record(
641        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
642    ) -> Result<Self, ExpressionConstructionError> {
643        ExprBuilder::new().record(pairs)
644    }
645
646    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
647    ///
648    /// If you have an iterator of pairs, generally prefer calling
649    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
650    /// potentially for efficiency reasons but also because `Expr::record()`
651    /// will properly handle duplicate keys but your own `.collect()` will not
652    /// (by default).
653    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
654        ExprBuilder::new().record_arc(map)
655    }
656
657    /// Create an `Expr` which calls the extension function with the given
658    /// `Name` on `args`
659    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
660        ExprBuilder::new().call_extension_fn(fn_name, args)
661    }
662
663    /// Create an application `Expr` which applies the given built-in unary
664    /// operator to the given `arg`
665    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
666        ExprBuilder::new().unary_app(op, arg)
667    }
668
669    /// Create an application `Expr` which applies the given built-in binary
670    /// operator to `arg1` and `arg2`
671    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
672        ExprBuilder::new().binary_app(op, arg1, arg2)
673    }
674
675    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
676    ///
677    /// `expr` must evaluate to either Entity or Record type
678    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
679        ExprBuilder::new().get_attr(expr, attr)
680    }
681
682    /// Create an `Expr` which tests for the existence of a given
683    /// attribute on a given `Entity` or record.
684    ///
685    /// `expr` must evaluate to either Entity or Record type
686    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
687        ExprBuilder::new().has_attr(expr, attr)
688    }
689
690    /// Create a 'like' expression.
691    ///
692    /// `expr` must evaluate to a String type
693    pub fn like(expr: Expr, pattern: Pattern) -> Self {
694        ExprBuilder::new().like(expr, pattern)
695    }
696
697    /// Create an `is` expression.
698    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
699        ExprBuilder::new().is_entity_type(expr, entity_type)
700    }
701
702    /// Check if an expression contains any symbolic unknowns
703    pub fn contains_unknown(&self) -> bool {
704        self.subexpressions()
705            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
706    }
707
708    /// Get all unknowns in an expression
709    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
710        self.subexpressions()
711            .filter_map(|subexpr| match subexpr.expr_kind() {
712                ExprKind::Unknown(u) => Some(u),
713                _ => None,
714            })
715    }
716
717    /// Substitute unknowns with concrete values.
718    ///
719    /// Ignores unmapped unknowns.
720    /// Ignores type annotations on unknowns.
721    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
722        match self.substitute_general::<UntypedSubstitution>(definitions) {
723            Ok(e) => e,
724            Err(empty) => match empty {},
725        }
726    }
727
728    /// Substitute unknowns with concrete values.
729    ///
730    /// Ignores unmapped unknowns.
731    /// Errors if the substituted value does not match the type annotation on the unknown.
732    pub fn substitute_typed(
733        &self,
734        definitions: &HashMap<SmolStr, Value>,
735    ) -> Result<Expr, SubstitutionError> {
736        self.substitute_general::<TypedSubstitution>(definitions)
737    }
738
739    /// Substitute unknowns with values
740    ///
741    /// Generic over the function implementing the substitution to allow for multiple error behaviors
742    fn substitute_general<T: SubstitutionFunction>(
743        &self,
744        definitions: &HashMap<SmolStr, Value>,
745    ) -> Result<Expr, T::Err> {
746        match self.expr_kind() {
747            ExprKind::Lit(_) => Ok(self.clone()),
748            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
749            ExprKind::Var(_) => Ok(self.clone()),
750            ExprKind::Slot(_) => Ok(self.clone()),
751            ExprKind::If {
752                test_expr,
753                then_expr,
754                else_expr,
755            } => Ok(Expr::ite(
756                test_expr.substitute_general::<T>(definitions)?,
757                then_expr.substitute_general::<T>(definitions)?,
758                else_expr.substitute_general::<T>(definitions)?,
759            )),
760            ExprKind::And { left, right } => Ok(Expr::and(
761                left.substitute_general::<T>(definitions)?,
762                right.substitute_general::<T>(definitions)?,
763            )),
764            ExprKind::Or { left, right } => Ok(Expr::or(
765                left.substitute_general::<T>(definitions)?,
766                right.substitute_general::<T>(definitions)?,
767            )),
768            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
769                *op,
770                arg.substitute_general::<T>(definitions)?,
771            )),
772            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
773                *op,
774                arg1.substitute_general::<T>(definitions)?,
775                arg2.substitute_general::<T>(definitions)?,
776            )),
777            ExprKind::ExtensionFunctionApp { fn_name, args } => {
778                let args = args
779                    .iter()
780                    .map(|e| e.substitute_general::<T>(definitions))
781                    .collect::<Result<Vec<Expr>, _>>()?;
782
783                Ok(Expr::call_extension_fn(fn_name.clone(), args))
784            }
785            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
786                expr.substitute_general::<T>(definitions)?,
787                attr.clone(),
788            )),
789            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
790                expr.substitute_general::<T>(definitions)?,
791                attr.clone(),
792            )),
793            ExprKind::Like { expr, pattern } => Ok(Expr::like(
794                expr.substitute_general::<T>(definitions)?,
795                pattern.clone(),
796            )),
797            ExprKind::Set(members) => {
798                let members = members
799                    .iter()
800                    .map(|e| e.substitute_general::<T>(definitions))
801                    .collect::<Result<Vec<_>, _>>()?;
802                Ok(Expr::set(members))
803            }
804            ExprKind::Record(map) => {
805                let map = map
806                    .iter()
807                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
808                    .collect::<Result<BTreeMap<_, _>, _>>()?;
809                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
810                #[allow(clippy::expect_used)]
811                Ok(Expr::record(map)
812                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
813            }
814            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
815                expr.substitute_general::<T>(definitions)?,
816                entity_type.clone(),
817            )),
818            #[cfg(feature = "tolerant-ast")]
819            ExprKind::Error { .. } => Ok(self.clone()),
820        }
821    }
822}
823
824/// A trait for customizing the error behavior of substitution
825trait SubstitutionFunction {
826    /// The potential errors this substitution function can return
827    type Err;
828    /// The function for implementing the substitution.
829    ///
830    /// Takes the expression being substituted,
831    /// The substitution from the map (if present)
832    /// and the type annotation from the unknown (if present)
833    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
834}
835
836struct TypedSubstitution {}
837
838impl SubstitutionFunction for TypedSubstitution {
839    type Err = SubstitutionError;
840
841    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
842        match (substitute, &value.type_annotation) {
843            (None, _) => Ok(Expr::unknown(value.clone())),
844            (Some(v), None) => Ok(v.clone().into()),
845            (Some(v), Some(t)) => {
846                if v.type_of() == *t {
847                    Ok(v.clone().into())
848                } else {
849                    Err(SubstitutionError::TypeError {
850                        expected: t.clone(),
851                        actual: v.type_of(),
852                    })
853                }
854            }
855        }
856    }
857}
858
859struct UntypedSubstitution {}
860
861impl SubstitutionFunction for UntypedSubstitution {
862    type Err = std::convert::Infallible;
863
864    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
865        Ok(substitute
866            .map(|v| v.clone().into())
867            .unwrap_or_else(|| Expr::unknown(value.clone())))
868    }
869}
870
871impl<T: Clone> std::fmt::Display for Expr<T> {
872    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
873        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
874        // we just convert to EST and use the EST pretty-printer.
875        // Note that converting AST->EST is lossless and infallible.
876        write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
877    }
878}
879
880impl<T: Clone> BoundedDisplay for Expr<T> {
881    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
882        // Like the `std::fmt::Display` impl, we convert to EST and use the EST
883        // pretty-printer. Note that converting AST->EST is lossless and infallible.
884        BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
885    }
886}
887
888impl std::str::FromStr for Expr {
889    type Err = ParseErrors;
890
891    fn from_str(s: &str) -> Result<Expr, Self::Err> {
892        crate::parser::parse_expr(s)
893    }
894}
895
896/// Enum for errors encountered during substitution
897#[derive(Debug, Clone, Diagnostic, Error)]
898pub enum SubstitutionError {
899    /// The supplied value did not match the type annotation on the unknown.
900    #[error("expected a value of type {expected}, got a value of type {actual}")]
901    TypeError {
902        /// The expected type, ie: the type the unknown was annotated with
903        expected: Type,
904        /// The type of the provided value
905        actual: Type,
906    },
907}
908
909/// Representation of a partial-evaluation Unknown at the AST level
910#[derive(Hash, Debug, Clone, PartialEq, Eq)]
911pub struct Unknown {
912    /// The name of the unknown
913    pub name: SmolStr,
914    /// The type of the values that can be substituted in for the unknown.
915    /// If `None`, we have no type annotation, and thus a value of any type can
916    /// be substituted.
917    pub type_annotation: Option<Type>,
918}
919
920impl Unknown {
921    /// Create a new untyped `Unknown`
922    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
923        Self {
924            name: name.into(),
925            type_annotation: None,
926        }
927    }
928
929    /// Create a new `Unknown` with type annotation. (Only values of the given
930    /// type can be substituted.)
931    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
932        Self {
933            name: name.into(),
934            type_annotation: Some(ty),
935        }
936    }
937}
938
939impl std::fmt::Display for Unknown {
940    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
942        // to avoid code duplication
943        write!(
944            f,
945            "{}",
946            Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
947        )
948    }
949}
950
951/// Builder for constructing `Expr` objects annotated with some `data`
952/// (possibly taking default value) and optionally a `source_loc`.
953#[derive(Clone, Debug)]
954pub struct ExprBuilder<T> {
955    source_loc: Option<Loc>,
956    data: T,
957}
958
959impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
960    type Expr = Expr<T>;
961
962    type Data = T;
963
964    #[cfg(feature = "tolerant-ast")]
965    type ErrorType = ParseErrors;
966
967    fn loc(&self) -> Option<&Loc> {
968        self.source_loc.as_ref()
969    }
970
971    fn data(&self) -> &Self::Data {
972        &self.data
973    }
974
975    fn with_data(data: T) -> Self {
976        Self {
977            source_loc: None,
978            data,
979        }
980    }
981
982    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
983        self.source_loc = maybe_source_loc.cloned();
984        self
985    }
986
987    /// Create an `Expr` that's just a single `Literal`.
988    ///
989    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
990    fn val(self, v: impl Into<Literal>) -> Expr<T> {
991        self.with_expr_kind(ExprKind::Lit(v.into()))
992    }
993
994    /// Create an `Unknown` `Expr`
995    fn unknown(self, u: Unknown) -> Expr<T> {
996        self.with_expr_kind(ExprKind::Unknown(u))
997    }
998
999    /// Create an `Expr` that's just this literal `Var`
1000    fn var(self, v: Var) -> Expr<T> {
1001        self.with_expr_kind(ExprKind::Var(v))
1002    }
1003
1004    /// Create an `Expr` that's just this `SlotId`
1005    fn slot(self, s: SlotId) -> Expr<T> {
1006        self.with_expr_kind(ExprKind::Slot(s))
1007    }
1008
1009    /// Create a ternary (if-then-else) `Expr`.
1010    ///
1011    /// `test_expr` must evaluate to a Bool type
1012    fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
1013        self.with_expr_kind(ExprKind::If {
1014            test_expr: Arc::new(test_expr),
1015            then_expr: Arc::new(then_expr),
1016            else_expr: Arc::new(else_expr),
1017        })
1018    }
1019
1020    /// Create a 'not' expression. `e` must evaluate to Bool type
1021    fn not(self, e: Expr<T>) -> Expr<T> {
1022        self.with_expr_kind(ExprKind::UnaryApp {
1023            op: UnaryOp::Not,
1024            arg: Arc::new(e),
1025        })
1026    }
1027
1028    /// Create a '==' expression
1029    fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1030        self.with_expr_kind(ExprKind::BinaryApp {
1031            op: BinaryOp::Eq,
1032            arg1: Arc::new(e1),
1033            arg2: Arc::new(e2),
1034        })
1035    }
1036
1037    /// Create an 'and' expression. Arguments must evaluate to Bool type
1038    fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1039        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1040            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1041                ExprKind::Lit(Literal::Bool(*b1 && *b2))
1042            }
1043            _ => ExprKind::And {
1044                left: Arc::new(e1),
1045                right: Arc::new(e2),
1046            },
1047        })
1048    }
1049
1050    /// Create an 'or' expression. Arguments must evaluate to Bool type
1051    fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1052        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1053            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1054                ExprKind::Lit(Literal::Bool(*b1 || *b2))
1055            }
1056
1057            _ => ExprKind::Or {
1058                left: Arc::new(e1),
1059                right: Arc::new(e2),
1060            },
1061        })
1062    }
1063
1064    /// Create a '<' expression. Arguments must evaluate to Long type
1065    fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1066        self.with_expr_kind(ExprKind::BinaryApp {
1067            op: BinaryOp::Less,
1068            arg1: Arc::new(e1),
1069            arg2: Arc::new(e2),
1070        })
1071    }
1072
1073    /// Create a '<=' expression. Arguments must evaluate to Long type
1074    fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1075        self.with_expr_kind(ExprKind::BinaryApp {
1076            op: BinaryOp::LessEq,
1077            arg1: Arc::new(e1),
1078            arg2: Arc::new(e2),
1079        })
1080    }
1081
1082    /// Create an 'add' expression. Arguments must evaluate to Long type
1083    fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1084        self.with_expr_kind(ExprKind::BinaryApp {
1085            op: BinaryOp::Add,
1086            arg1: Arc::new(e1),
1087            arg2: Arc::new(e2),
1088        })
1089    }
1090
1091    /// Create a 'sub' expression. Arguments must evaluate to Long type
1092    fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1093        self.with_expr_kind(ExprKind::BinaryApp {
1094            op: BinaryOp::Sub,
1095            arg1: Arc::new(e1),
1096            arg2: Arc::new(e2),
1097        })
1098    }
1099
1100    /// Create a 'mul' expression. Arguments must evaluate to Long type
1101    fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1102        self.with_expr_kind(ExprKind::BinaryApp {
1103            op: BinaryOp::Mul,
1104            arg1: Arc::new(e1),
1105            arg2: Arc::new(e2),
1106        })
1107    }
1108
1109    /// Create a 'neg' expression. `e` must evaluate to Long type.
1110    fn neg(self, e: Expr<T>) -> Expr<T> {
1111        self.with_expr_kind(ExprKind::UnaryApp {
1112            op: UnaryOp::Neg,
1113            arg: Arc::new(e),
1114        })
1115    }
1116
1117    /// Create an 'in' expression. First argument must evaluate to Entity type.
1118    /// Second argument must evaluate to either Entity type or Set type where
1119    /// all set elements have Entity type.
1120    fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1121        self.with_expr_kind(ExprKind::BinaryApp {
1122            op: BinaryOp::In,
1123            arg1: Arc::new(e1),
1124            arg2: Arc::new(e2),
1125        })
1126    }
1127
1128    /// Create a 'contains' expression.
1129    /// First argument must have Set type.
1130    fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1131        self.with_expr_kind(ExprKind::BinaryApp {
1132            op: BinaryOp::Contains,
1133            arg1: Arc::new(e1),
1134            arg2: Arc::new(e2),
1135        })
1136    }
1137
1138    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1139    fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1140        self.with_expr_kind(ExprKind::BinaryApp {
1141            op: BinaryOp::ContainsAll,
1142            arg1: Arc::new(e1),
1143            arg2: Arc::new(e2),
1144        })
1145    }
1146
1147    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1148    fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1149        self.with_expr_kind(ExprKind::BinaryApp {
1150            op: BinaryOp::ContainsAny,
1151            arg1: Arc::new(e1),
1152            arg2: Arc::new(e2),
1153        })
1154    }
1155
1156    /// Create an 'is_empty' expression. Argument must evaluate to Set type
1157    fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1158        self.with_expr_kind(ExprKind::UnaryApp {
1159            op: UnaryOp::IsEmpty,
1160            arg: Arc::new(expr),
1161        })
1162    }
1163
1164    /// Create a 'getTag' expression.
1165    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1166    fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1167        self.with_expr_kind(ExprKind::BinaryApp {
1168            op: BinaryOp::GetTag,
1169            arg1: Arc::new(expr),
1170            arg2: Arc::new(tag),
1171        })
1172    }
1173
1174    /// Create a 'hasTag' expression.
1175    /// `expr` must evaluate to Entity type, `tag` must evaluate to String type.
1176    fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1177        self.with_expr_kind(ExprKind::BinaryApp {
1178            op: BinaryOp::HasTag,
1179            arg1: Arc::new(expr),
1180            arg2: Arc::new(tag),
1181        })
1182    }
1183
1184    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1185    fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1186        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1187    }
1188
1189    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1190    fn record(
1191        self,
1192        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1193    ) -> Result<Expr<T>, ExpressionConstructionError> {
1194        let mut map = BTreeMap::new();
1195        for (k, v) in pairs {
1196            match map.entry(k) {
1197                btree_map::Entry::Occupied(oentry) => {
1198                    return Err(expression_construction_errors::DuplicateKeyError {
1199                        key: oentry.key().clone(),
1200                        context: "in record literal",
1201                    }
1202                    .into());
1203                }
1204                btree_map::Entry::Vacant(ventry) => {
1205                    ventry.insert(v);
1206                }
1207            }
1208        }
1209        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1210    }
1211
1212    /// Create an `Expr` which calls the extension function with the given
1213    /// `Name` on `args`
1214    fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1215        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1216            fn_name,
1217            args: Arc::new(args.into_iter().collect()),
1218        })
1219    }
1220
1221    /// Create an application `Expr` which applies the given built-in unary
1222    /// operator to the given `arg`
1223    fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1224        self.with_expr_kind(ExprKind::UnaryApp {
1225            op: op.into(),
1226            arg: Arc::new(arg),
1227        })
1228    }
1229
1230    /// Create an application `Expr` which applies the given built-in binary
1231    /// operator to `arg1` and `arg2`
1232    fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1233        self.with_expr_kind(ExprKind::BinaryApp {
1234            op: op.into(),
1235            arg1: Arc::new(arg1),
1236            arg2: Arc::new(arg2),
1237        })
1238    }
1239
1240    /// Create an `Expr` which gets a given attribute of a given `Entity` or record.
1241    ///
1242    /// `expr` must evaluate to either Entity or Record type
1243    fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1244        self.with_expr_kind(ExprKind::GetAttr {
1245            expr: Arc::new(expr),
1246            attr,
1247        })
1248    }
1249
1250    /// Create an `Expr` which tests for the existence of a given
1251    /// attribute on a given `Entity` or record.
1252    ///
1253    /// `expr` must evaluate to either Entity or Record type
1254    fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1255        self.with_expr_kind(ExprKind::HasAttr {
1256            expr: Arc::new(expr),
1257            attr,
1258        })
1259    }
1260
1261    /// Create a 'like' expression.
1262    ///
1263    /// `expr` must evaluate to a String type
1264    fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1265        self.with_expr_kind(ExprKind::Like {
1266            expr: Arc::new(expr),
1267            pattern,
1268        })
1269    }
1270
1271    /// Create an 'is' expression.
1272    fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1273        self.with_expr_kind(ExprKind::Is {
1274            expr: Arc::new(expr),
1275            entity_type,
1276        })
1277    }
1278
1279    /// Don't support AST Error nodes - return the error right back
1280    #[cfg(feature = "tolerant-ast")]
1281    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1282        Err(parse_errors)
1283    }
1284}
1285
1286impl<T> ExprBuilder<T> {
1287    /// Construct an `Expr` containing the `data` and `source_loc` in this
1288    /// `ExprBuilder` and the given `ExprKind`.
1289    pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1290        Expr::new(expr_kind, self.source_loc, self.data)
1291    }
1292
1293    /// Create a ternary (if-then-else) `Expr`.
1294    /// Takes `Arc`s instead of owned `Expr`s.
1295    /// `test_expr` must evaluate to a Bool type
1296    pub fn ite_arc(
1297        self,
1298        test_expr: Arc<Expr<T>>,
1299        then_expr: Arc<Expr<T>>,
1300        else_expr: Arc<Expr<T>>,
1301    ) -> Expr<T> {
1302        self.with_expr_kind(ExprKind::If {
1303            test_expr,
1304            then_expr,
1305            else_expr,
1306        })
1307    }
1308
1309    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
1310    ///
1311    /// If you have an iterator of pairs, generally prefer calling `.record()`
1312    /// instead of `.collect()`-ing yourself and calling this, potentially for
1313    /// efficiency reasons but also because `.record()` will properly handle
1314    /// duplicate keys but your own `.collect()` will not (by default).
1315    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1316        self.with_expr_kind(ExprKind::Record(map))
1317    }
1318}
1319
1320impl<T: Clone + Default> ExprBuilder<T> {
1321    /// Utility used the validator to get an expression with the same source
1322    /// location as an existing expression. This is done when reconstructing the
1323    /// `Expr` with type information.
1324    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1325        self.with_maybe_source_loc(expr.source_loc.as_ref())
1326    }
1327}
1328
/// Errors when constructing an expression
///
/// Currently the only failure is a duplicate key. Its message and diagnostic
/// are forwarded unchanged from the wrapped
/// [`expression_construction_errors::DuplicateKeyError`].
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times
    // `transparent` on both attributes: Display and miette diagnostics come
    // straight from the inner `DuplicateKeyError`; `#[from]` enables `.into()`.
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1341
1342/// Error subtypes for [`ExpressionConstructionError`]
1343pub mod expression_construction_errors {
1344    use miette::Diagnostic;
1345    use smol_str::SmolStr;
1346    use thiserror::Error;
1347
1348    /// The same key occurred two or more times
1349    //
1350    // CAUTION: this type is publicly exported in `cedar-policy`.
1351    // Don't make fields `pub`, don't make breaking changes, and use caution
1352    // when adding public methods.
1353    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1354    #[error("duplicate key `{key}` {context}")]
1355    pub struct DuplicateKeyError {
1356        /// The key which occurred two or more times
1357        pub(crate) key: SmolStr,
1358        /// Information about where the duplicate key occurred (e.g., "in record literal")
1359        pub(crate) context: &'static str,
1360    }
1361
1362    impl DuplicateKeyError {
1363        /// Get the key which occurred two or more times
1364        pub fn key(&self) -> &str {
1365            &self.key
1366        }
1367
1368        /// Make a new error with an updated `context` field
1369        pub(crate) fn with_context(self, context: &'static str) -> Self {
1370            Self { context, ..self }
1371        }
1372    }
1373}
1374
1375/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
1376/// implementations that ignore any source information or other generic data
1377/// used to annotate the `Expr`.
1378#[derive(Eq, Debug, Clone)]
1379pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1380
1381impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1382    /// Construct an `ExprShapeOnly` from a borrowed `Expr`. The `Expr` is not
1383    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1384    /// ignore source information and generic data.
1385    pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1386        ExprShapeOnly(Cow::Borrowed(e))
1387    }
1388
1389    /// Construct an `ExprShapeOnly` from an owned `Expr`. The `Expr` is not
1390    /// modified, but any comparisons on the resulting `ExprShapeOnly` will
1391    /// ignore source information and generic data.
1392    pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1393        ExprShapeOnly(Cow::Owned(e))
1394    }
1395}
1396
1397impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1398    fn eq(&self, other: &Self) -> bool {
1399        self.0.eq_shape(&other.0)
1400    }
1401}
1402
1403impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1404    fn hash<H: Hasher>(&self, state: &mut H) {
1405        self.0.hash_shape(state);
1406    }
1407}
1408
1409impl<T> Expr<T> {
1410    /// Return true if this expression (recursively) has the same expression
1411    /// kind as the argument expression. This accounts for the full recursive
1412    /// shape of the expression, but does not consider source information or any
1413    /// generic data annotated on expression. This should behave the same as the
1414    /// default implementation of `Eq` before source information and generic
1415    /// data were added.
1416    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1417        use ExprKind::*;
1418        match (self.expr_kind(), other.expr_kind()) {
1419            (Lit(lit), Lit(lit1)) => lit == lit1,
1420            (Var(v), Var(v1)) => v == v1,
1421            (Slot(s), Slot(s1)) => s == s1,
1422            (
1423                Unknown(self::Unknown {
1424                    name: name1,
1425                    type_annotation: ta_1,
1426                }),
1427                Unknown(self::Unknown {
1428                    name: name2,
1429                    type_annotation: ta_2,
1430                }),
1431            ) => (name1 == name2) && (ta_1 == ta_2),
1432            (
1433                If {
1434                    test_expr,
1435                    then_expr,
1436                    else_expr,
1437                },
1438                If {
1439                    test_expr: test_expr1,
1440                    then_expr: then_expr1,
1441                    else_expr: else_expr1,
1442                },
1443            ) => {
1444                test_expr.eq_shape(test_expr1)
1445                    && then_expr.eq_shape(then_expr1)
1446                    && else_expr.eq_shape(else_expr1)
1447            }
1448            (
1449                And { left, right },
1450                And {
1451                    left: left1,
1452                    right: right1,
1453                },
1454            )
1455            | (
1456                Or { left, right },
1457                Or {
1458                    left: left1,
1459                    right: right1,
1460                },
1461            ) => left.eq_shape(left1) && right.eq_shape(right1),
1462            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1463                op == op1 && arg.eq_shape(arg1)
1464            }
1465            (
1466                BinaryApp { op, arg1, arg2 },
1467                BinaryApp {
1468                    op: op1,
1469                    arg1: arg11,
1470                    arg2: arg21,
1471                },
1472            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1473            (
1474                ExtensionFunctionApp { fn_name, args },
1475                ExtensionFunctionApp {
1476                    fn_name: fn_name1,
1477                    args: args1,
1478                },
1479            ) => {
1480                fn_name == fn_name1
1481                    && args.len() == args1.len()
1482                    && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1483            }
1484            (
1485                GetAttr { expr, attr },
1486                GetAttr {
1487                    expr: expr1,
1488                    attr: attr1,
1489                },
1490            )
1491            | (
1492                HasAttr { expr, attr },
1493                HasAttr {
1494                    expr: expr1,
1495                    attr: attr1,
1496                },
1497            ) => attr == attr1 && expr.eq_shape(expr1),
1498            (
1499                Like { expr, pattern },
1500                Like {
1501                    expr: expr1,
1502                    pattern: pattern1,
1503                },
1504            ) => pattern == pattern1 && expr.eq_shape(expr1),
1505            (Set(elems), Set(elems1)) => {
1506                elems.len() == elems1.len()
1507                    && elems
1508                        .iter()
1509                        .zip(elems1.iter())
1510                        .all(|(e, e1)| e.eq_shape(e1))
1511            }
1512            (Record(map), Record(map1)) => {
1513                map.len() == map1.len()
1514                    && map
1515                        .iter()
1516                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1517                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1518            }
1519            (
1520                Is { expr, entity_type },
1521                Is {
1522                    expr: expr1,
1523                    entity_type: entity_type1,
1524                },
1525            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1526            _ => false,
1527        }
1528    }
1529
1530    /// Implementation of hashing corresponding to equality as implemented by
1531    /// `eq_shape`. Must satisfy the usual relationship between equality and
1532    /// hashing.
1533    pub fn hash_shape<H>(&self, state: &mut H)
1534    where
1535        H: Hasher,
1536    {
1537        mem::discriminant(self).hash(state);
1538        match self.expr_kind() {
1539            ExprKind::Lit(lit) => lit.hash(state),
1540            ExprKind::Var(v) => v.hash(state),
1541            ExprKind::Slot(s) => s.hash(state),
1542            ExprKind::Unknown(u) => u.hash(state),
1543            ExprKind::If {
1544                test_expr,
1545                then_expr,
1546                else_expr,
1547            } => {
1548                test_expr.hash_shape(state);
1549                then_expr.hash_shape(state);
1550                else_expr.hash_shape(state);
1551            }
1552            ExprKind::And { left, right } => {
1553                left.hash_shape(state);
1554                right.hash_shape(state);
1555            }
1556            ExprKind::Or { left, right } => {
1557                left.hash_shape(state);
1558                right.hash_shape(state);
1559            }
1560            ExprKind::UnaryApp { op, arg } => {
1561                op.hash(state);
1562                arg.hash_shape(state);
1563            }
1564            ExprKind::BinaryApp { op, arg1, arg2 } => {
1565                op.hash(state);
1566                arg1.hash_shape(state);
1567                arg2.hash_shape(state);
1568            }
1569            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1570                fn_name.hash(state);
1571                state.write_usize(args.len());
1572                args.iter().for_each(|a| {
1573                    a.hash_shape(state);
1574                });
1575            }
1576            ExprKind::GetAttr { expr, attr } => {
1577                expr.hash_shape(state);
1578                attr.hash(state);
1579            }
1580            ExprKind::HasAttr { expr, attr } => {
1581                expr.hash_shape(state);
1582                attr.hash(state);
1583            }
1584            ExprKind::Like { expr, pattern } => {
1585                expr.hash_shape(state);
1586                pattern.hash(state);
1587            }
1588            ExprKind::Set(elems) => {
1589                state.write_usize(elems.len());
1590                elems.iter().for_each(|e| {
1591                    e.hash_shape(state);
1592                })
1593            }
1594            ExprKind::Record(map) => {
1595                state.write_usize(map.len());
1596                map.iter().for_each(|(s, a)| {
1597                    s.hash(state);
1598                    a.hash_shape(state);
1599                });
1600            }
1601            ExprKind::Is { expr, entity_type } => {
1602                expr.hash_shape(state);
1603                entity_type.hash(state);
1604            }
1605            #[cfg(feature = "tolerant-ast")]
1606            ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1607        }
1608    }
1609}
1610
/// AST variables: the request components an expression can refer to by name.
///
/// With serde, variants are (de)serialized in camelCase, e.g. `Var::Principal`
/// as `"principal"`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// the Principal of the given request
    Principal,
    /// the Action of the given request
    Action,
    /// the Resource of the given request
    Resource,
    /// the Context of the given request
    Context,
}
1627
1628impl From<PrincipalOrResource> for Var {
1629    fn from(v: PrincipalOrResource) -> Self {
1630        match v {
1631            PrincipalOrResource::Principal => Var::Principal,
1632            PrincipalOrResource::Resource => Var::Resource,
1633        }
1634    }
1635}
1636
1637// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1638#[allow(clippy::fallible_impl_from)]
1639impl From<Var> for Id {
1640    fn from(var: Var) -> Self {
1641        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1642        #[allow(clippy::unwrap_used)]
1643        format!("{var}").parse().unwrap()
1644    }
1645}
1646
1647// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1648#[allow(clippy::fallible_impl_from)]
1649impl From<Var> for UnreservedId {
1650    fn from(var: Var) -> Self {
1651        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`
1652        #[allow(clippy::unwrap_used)]
1653        Id::from(var).try_into().unwrap()
1654    }
1655}
1656
1657impl std::fmt::Display for Var {
1658    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1659        match self {
1660            Self::Principal => write!(f, "principal"),
1661            Self::Action => write!(f, "action"),
1662            Self::Resource => write!(f, "resource"),
1663            Self::Context => write!(f, "context"),
1664        }
1665    }
1666}
1667
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;
    use itertools::Itertools;
    use smol_str::ToSmolStr;
    use std::collections::{hash_map::DefaultHasher, HashSet};

    use crate::expr_builder::ExprBuilder as _;

    use super::*;

    /// Every variant of [`Var`], in declaration order.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
    }

    // Tests that the `Var` -> `Id` and `Var` -> `UnreservedId` conversions never panic
    #[test]
    fn all_vars_are_ids() {
        for var in all_vars() {
            let _id: Id = var.into();
            let _id: UnreservedId = var.into();
        }
    }

    #[test]
    fn exprs() {
        assert_eq!(
            Expr::val(33),
            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
        );
        assert_eq!(
            Expr::val("hello"),
            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
        );
        assert_eq!(
            Expr::val(EntityUID::with_eid("foo")),
            Expr::new(
                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                None,
                ()
            )
        );
        assert_eq!(
            Expr::var(Var::Principal),
            Expr::new(ExprKind::Var(Var::Principal), None, ())
        );
        assert_eq!(
            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
            Expr::new(
                ExprKind::If {
                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::not(Expr::val(false)),
            Expr::new(
                ExprKind::UnaryApp {
                    op: UnaryOp::Not,
                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::GetAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::HasAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::is_entity_type(
                Expr::val(EntityUID::with_eid("foo")),
                "Type".parse().unwrap()
            ),
            Expr::new(
                ExprKind::Is {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    entity_type: "Type".parse().unwrap()
                },
                None,
                ()
            ),
        );
    }

    #[test]
    fn like_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
        assert_eq!(format!("{e}"), r#""a" like "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
        // `\`'s escaped form is `\\`; a wildcard is displayed as a bare `*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
        // literal star's escaped form is `\*`
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
    }

    #[test]
    fn has_display() {
        // `\0` escaped form is `\0`.
        let e = Expr::has_attr(Expr::val("a"), "\0".into());
        assert_eq!(format!("{e}"), r#""a" has "\0""#);
        // `\`'s escaped form is `\\`
        let e = Expr::has_attr(Expr::val("a"), r"\".into());
        assert_eq!(format!("{e}"), r#""a" has "\\""#);
    }

    #[test]
    fn slot_display() {
        let e = Expr::slot(SlotId::principal());
        assert_eq!(format!("{e}"), "?principal");
        let e = Expr::slot(SlotId::resource());
        assert_eq!(format!("{e}"), "?resource");
        let e = Expr::val(EntityUID::with_eid("eid"));
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn simple_slots() {
        let e = Expr::slot(SlotId::principal());
        let p = SlotId::principal();
        let r = SlotId::resource();
        let set: HashSet<SlotId> = HashSet::from_iter([p]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
        let e = Expr::or(
            Expr::slot(SlotId::principal()),
            Expr::ite(
                Expr::val(true),
                Expr::slot(SlotId::resource()),
                Expr::val(false),
            ),
        );
        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
    }

    #[test]
    fn unknowns() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        let unknowns = e.unknowns().collect_vec();
        assert_eq!(unknowns.len(), 3);
        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
    }

    #[test]
    fn is_unknown() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        assert!(e.contains_unknown());
        let e = Expr::ite(
            Expr::not(Expr::val(true)),
            Expr::and(Expr::val(1), Expr::val(3)),
            Expr::val(1),
        );
        assert!(!e.contains_unknown());
    }

    #[test]
    fn expr_with_data() {
        let e = ExprBuilder::with_data("data").val(1);
        assert_eq!(e.into_data(), "data");
    }

    #[test]
    fn expr_shape_only_eq() {
        let temp = ExprBuilder::with_data(1).val(1);
        let exprs = &[
            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
            (
                ExprBuilder::with_data(1).var(Var::Principal),
                Expr::var(Var::Principal),
            ),
            (
                ExprBuilder::with_data(1).slot(SlotId::principal()),
                Expr::slot(SlotId::principal()),
            ),
            (
                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).not(temp.clone()),
                Expr::not(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
                Expr::is_eq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
                Expr::and(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
                Expr::or(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
                Expr::less(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
                Expr::lesseq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
                Expr::greater(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
                Expr::greatereq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
                Expr::add(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
                Expr::sub(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
                Expr::mul(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).neg(temp.clone()),
                Expr::neg(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
                Expr::is_in(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
                Expr::contains(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
                Expr::contains_all(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
                Expr::contains_any(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_empty(temp.clone()),
                Expr::is_empty(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).set([temp.clone()]),
                Expr::set([Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1)
                    .record([("foo".into(), temp.clone())])
                    .unwrap(),
                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
            ),
            (
                ExprBuilder::with_data(1)
                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
                Expr::get_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
                Expr::has_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1)
                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
            ),
            (
                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
            ),
        ];

        // Each pair differs only in its attached data, so shape equality must
        // hold in both directions and the shape hashes must agree.
        for (e0, e1) in exprs {
            assert!(e0.eq_shape(e0));
            assert!(e1.eq_shape(e1));
            assert!(e0.eq_shape(e1));
            assert!(e1.eq_shape(e0));

            let mut hasher0 = DefaultHasher::new();
            e0.hash_shape(&mut hasher0);
            let hash0 = hasher0.finish();

            let mut hasher1 = DefaultHasher::new();
            e1.hash_shape(&mut hasher1);
            let hash1 = hasher1.finish();

            assert_eq!(hash0, hash1);
        }
    }

    #[test]
    fn expr_shape_only_not_eq() {
        let expr1 = ExprBuilder::with_data(1).val(1);
        let expr2 = ExprBuilder::with_data(1).val(2);
        assert_ne!(
            ExprShapeOnly::new_from_borrowed(&expr1),
            ExprShapeOnly::new_from_borrowed(&expr2)
        );
    }

    // Sets where one is a prefix of another must never compare equal; every
    // ordered pair is checked.
    #[test]
    fn expr_shape_only_set_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Extension calls whose argument lists are prefixes of one another must
    // never compare equal; every ordered pair is checked.
    #[test]
    fn expr_shape_only_ext_fn_arg_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![],
        ));
        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0")],
        ));
        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0"), Expr::val("0.0")],
        ));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Records whose attribute sets are prefixes of one another must never
    // compare equal; every ordered pair is checked.
    #[test]
    fn expr_shape_only_record_attr_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
        let e2 = ExprShapeOnly::new_from_owned(
            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
        );
        let e3 = ExprShapeOnly::new_from_owned(
            Expr::record([
                ("a".to_smolstr(), Expr::val(1)),
                ("b".to_smolstr(), Expr::val(2)),
            ])
            .unwrap(),
        );

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    #[test]
    fn untyped_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            // The error type is uninhabited, so this arm can never run.
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        // Untyped substitution does not check the annotation, so a mismatched
        // value is still substituted.
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, None);
        match r {
            Ok(n) => assert_eq!(n, Expr::unknown(u)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn typed_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
        assert_matches!(
            r,
            SubstitutionError::TypeError {
                expected: Type::Bool,
                actual: Type::Long,
            }
        );
    }

    #[test]
    fn typed_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = TypedSubstitution::substitute(&u, None).unwrap();
        assert_eq!(r, Expr::unknown(u));
    }
}