cedar_policy_core/ast/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::{
18    ast::*,
19    extensions::Extensions,
20    parser::{err::ParseErrors, Loc},
21};
22use miette::Diagnostic;
23use serde::{Deserialize, Serialize};
24use smol_str::SmolStr;
25use std::{
26    collections::{btree_map, BTreeMap, HashMap},
27    hash::{Hash, Hasher},
28    mem,
29    sync::Arc,
30};
31use thiserror::Error;
32
33#[cfg(feature = "wasm")]
34extern crate tsify;
35
36/// Internal AST for expressions used by the policy evaluator.
37/// This structure is a wrapper around an `ExprKind`, which is the expression
38/// variant this object contains. It also contains source information about
39/// where the expression was written in policy source code, and some generic
40/// data which is stored on each node of the AST.
41/// Cloning is O(1).
42#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
43pub struct Expr<T = ()> {
44    expr_kind: ExprKind<T>,
45    source_loc: Option<Loc>,
46    data: T,
47}
48
49/// The possible expression variants. This enum should be matched on by code
50/// recursively traversing the AST.
51#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
52pub enum ExprKind<T = ()> {
53    /// Literal value
54    Lit(Literal),
55    /// Variable
56    Var(Var),
57    /// Template Slots
58    Slot(SlotId),
59    /// Symbolic Unknown for partial-eval
60    Unknown(Unknown),
61    /// Ternary expression
62    If {
63        /// Condition for the ternary expression. Must evaluate to Bool type
64        test_expr: Arc<Expr<T>>,
65        /// Value if true
66        then_expr: Arc<Expr<T>>,
67        /// Value if false
68        else_expr: Arc<Expr<T>>,
69    },
70    /// Boolean AND
71    And {
72        /// Left operand, which will be eagerly evaluated
73        left: Arc<Expr<T>>,
74        /// Right operand, which may not be evaluated due to short-circuiting
75        right: Arc<Expr<T>>,
76    },
77    /// Boolean OR
78    Or {
79        /// Left operand, which will be eagerly evaluated
80        left: Arc<Expr<T>>,
81        /// Right operand, which may not be evaluated due to short-circuiting
82        right: Arc<Expr<T>>,
83    },
84    /// Application of a built-in unary operator (single parameter)
85    UnaryApp {
86        /// Unary operator to apply
87        op: UnaryOp,
88        /// Argument to apply operator to
89        arg: Arc<Expr<T>>,
90    },
91    /// Application of a built-in binary operator (two parameters)
92    BinaryApp {
93        /// Binary operator to apply
94        op: BinaryOp,
95        /// First arg
96        arg1: Arc<Expr<T>>,
97        /// Second arg
98        arg2: Arc<Expr<T>>,
99    },
100    /// Application of an extension function to n arguments
101    /// INVARIANT (MethodStyleArgs):
102    ///   if op.style is MethodStyle then args _cannot_ be empty.
103    ///     The first element of args refers to the subject of the method call
104    /// Ideally, we find some way to make this non-representable.
105    ExtensionFunctionApp {
106        /// Extension function to apply
107        fn_name: Name,
108        /// Args to apply the function to
109        args: Arc<Vec<Expr<T>>>,
110    },
111    /// Get an attribute of an entity, or a field of a record
112    GetAttr {
113        /// Expression to get an attribute/field of. Must evaluate to either
114        /// Entity or Record type
115        expr: Arc<Expr<T>>,
116        /// Attribute or field to get
117        attr: SmolStr,
118    },
119    /// Does the given `expr` have the given `attr`?
120    HasAttr {
121        /// Expression to test. Must evaluate to either Entity or Record type
122        expr: Arc<Expr<T>>,
123        /// Attribute or field to check for
124        attr: SmolStr,
125    },
126    /// Regex-like string matching similar to IAM's `StringLike` operator.
127    Like {
128        /// Expression to test. Must evaluate to String type
129        expr: Arc<Expr<T>>,
130        /// Pattern to match on; can include the wildcard *, which matches any string.
131        /// To match a literal `*` in the test expression, users can use `\*`.
132        /// Be careful the backslash in `\*` must not be another escape sequence. For instance, `\\*` matches a backslash plus an arbitrary string.
133        pattern: Pattern,
134    },
135    /// Entity type test. Does the first argument have the entity type
136    /// specified by the second argument.
137    Is {
138        /// Expression to test. Must evaluate to an Entity.
139        expr: Arc<Expr<T>>,
140        /// The [`EntityType`] used for the type membership test.
141        entity_type: EntityType,
142    },
143    /// Set (whose elements may be arbitrary expressions)
144    //
145    // This is backed by `Vec` (and not e.g. `HashSet`), because two `Expr`s
146    // that are syntactically unequal, may actually be semantically equal --
147    // i.e., we can't do the dedup of duplicates until all of the `Expr`s are
148    // evaluated into `Value`s
149    Set(Arc<Vec<Expr<T>>>),
150    /// Anonymous record (whose elements may be arbitrary expressions)
151    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
152}
153
154impl From<Value> for Expr {
155    fn from(v: Value) -> Self {
156        Expr::from(v.value).with_maybe_source_loc(v.loc)
157    }
158}
159
160impl From<ValueKind> for Expr {
161    fn from(v: ValueKind) -> Self {
162        match v {
163            ValueKind::Lit(lit) => Expr::val(lit),
164            ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
165            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
166            #[allow(clippy::expect_used)]
167            ValueKind::Record(record) => Expr::record(
168                Arc::unwrap_or_clone(record)
169                    .into_iter()
170                    .map(|(k, v)| (k, Expr::from(v))),
171            )
172            .expect("cannot have duplicate key because the input was already a BTreeMap"),
173            ValueKind::ExtensionValue(ev) => Expr::from(ev.as_ref().clone()),
174        }
175    }
176}
177
178impl From<PartialValue> for Expr {
179    fn from(pv: PartialValue) -> Self {
180        match pv {
181            PartialValue::Value(v) => Expr::from(v),
182            PartialValue::Residual(expr) => expr,
183        }
184    }
185}
186
187impl<T> Expr<T> {
188    fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
189        Self {
190            expr_kind,
191            source_loc,
192            data,
193        }
194    }
195
196    /// Access the inner `ExprKind` for this `Expr`. The `ExprKind` is the
197    /// `enum` which specifies the expression variant, so it must be accessed by
198    /// any code matching and recursing on an expression.
199    pub fn expr_kind(&self) -> &ExprKind<T> {
200        &self.expr_kind
201    }
202
203    /// Access the inner `ExprKind`, taking ownership and consuming the `Expr`.
204    pub fn into_expr_kind(self) -> ExprKind<T> {
205        self.expr_kind
206    }
207
208    /// Access the data stored on the `Expr`.
209    pub fn data(&self) -> &T {
210        &self.data
211    }
212
213    /// Access the data stored on the `Expr`, taking ownership and consuming the
214    /// `Expr`.
215    pub fn into_data(self) -> T {
216        self.data
217    }
218
219    /// Access the `Loc` stored on the `Expr`.
220    pub fn source_loc(&self) -> Option<&Loc> {
221        self.source_loc.as_ref()
222    }
223
224    /// Return the `Expr`, but with the new `source_loc` (or `None`).
225    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
226        Self { source_loc, ..self }
227    }
228
229    /// Update the data for this `Expr`. A convenient function used by the
230    /// Validator in one place.
231    pub fn set_data(&mut self, data: T) {
232        self.data = data;
233    }
234
235    /// Check whether this expression is an entity reference
236    ///
237    /// This is used for policy scopes, where some syntax is
238    /// required to be an entity reference.
239    pub fn is_ref(&self) -> bool {
240        match &self.expr_kind {
241            ExprKind::Lit(lit) => lit.is_ref(),
242            _ => false,
243        }
244    }
245
246    /// Check whether this expression is a slot.
247    pub fn is_slot(&self) -> bool {
248        matches!(&self.expr_kind, ExprKind::Slot(_))
249    }
250
251    /// Check whether this expression is a set of entity references
252    ///
253    /// This is used for policy scopes, where some syntax is
254    /// required to be an entity reference set.
255    pub fn is_ref_set(&self) -> bool {
256        match &self.expr_kind {
257            ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
258            _ => false,
259        }
260    }
261
262    /// Iterate over all sub-expressions in this expression
263    pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
264        expr_iterator::ExprIterator::new(self)
265    }
266
267    /// Iterate over all of the slots in this policy AST
268    pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
269        self.subexpressions()
270            .filter_map(|exp| match &exp.expr_kind {
271                ExprKind::Slot(slotid) => Some(Slot {
272                    id: *slotid,
273                    loc: exp.source_loc().cloned(),
274                }),
275                _ => None,
276            })
277    }
278
279    /// Determine if the expression is projectable under partial evaluation
280    /// An expression is projectable if it's guaranteed to never error on evaluation
281    /// This is true if the expression is entirely composed of values or unknowns
282    pub fn is_projectable(&self) -> bool {
283        self.subexpressions().all(|e| {
284            matches!(
285                e.expr_kind(),
286                ExprKind::Lit(_)
287                    | ExprKind::Unknown(_)
288                    | ExprKind::Set(_)
289                    | ExprKind::Var(_)
290                    | ExprKind::Record(_)
291            )
292        })
293    }
294
295    /// Try to compute the runtime type of this expression. This operation may
296    /// fail (returning `None`), for example, when asked to get the type of any
297    /// variables, any attributes of entities or records, or an `unknown`
298    /// without an explicitly annotated type.
299    ///
300    /// Also note that this is _not_ typechecking the expression. It does not
301    /// check that the expression actually evaluates to a value (as opposed to
302    /// erroring).
303    ///
304    /// Because of these limitations, this function should only be used to
305    /// obtain a type for use in diagnostics such as error strings.
306    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
307        match &self.expr_kind {
308            ExprKind::Lit(l) => Some(l.type_of()),
309            ExprKind::Var(_) => None,
310            ExprKind::Slot(_) => None,
311            ExprKind::Unknown(u) => u.type_annotation.clone(),
312            ExprKind::If {
313                then_expr,
314                else_expr,
315                ..
316            } => {
317                let type_of_then = then_expr.try_type_of(extensions);
318                let type_of_else = else_expr.try_type_of(extensions);
319                if type_of_then == type_of_else {
320                    type_of_then
321                } else {
322                    None
323                }
324            }
325            ExprKind::And { .. } => Some(Type::Bool),
326            ExprKind::Or { .. } => Some(Type::Bool),
327            ExprKind::UnaryApp {
328                op: UnaryOp::Neg, ..
329            } => Some(Type::Long),
330            ExprKind::UnaryApp {
331                op: UnaryOp::Not, ..
332            } => Some(Type::Bool),
333            ExprKind::BinaryApp {
334                op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
335                ..
336            } => Some(Type::Long),
337            ExprKind::BinaryApp {
338                op:
339                    BinaryOp::Contains
340                    | BinaryOp::ContainsAll
341                    | BinaryOp::ContainsAny
342                    | BinaryOp::Eq
343                    | BinaryOp::In
344                    | BinaryOp::Less
345                    | BinaryOp::LessEq,
346                ..
347            } => Some(Type::Bool),
348            ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
349                .func(fn_name)
350                .ok()?
351                .return_type()
352                .map(|rty| rty.clone().into()),
353            // We could try to be more complete here, but we can't do all that
354            // much better without evaluating the argument. Even if we know it's
355            // a record `Type::Record` tells us nothing about the type of the
356            // attribute.
357            ExprKind::GetAttr { .. } => None,
358            ExprKind::HasAttr { .. } => Some(Type::Bool),
359            ExprKind::Like { .. } => Some(Type::Bool),
360            ExprKind::Is { .. } => Some(Type::Bool),
361            ExprKind::Set(_) => Some(Type::Set),
362            ExprKind::Record(_) => Some(Type::Record),
363        }
364    }
365}
366
367#[allow(dead_code)] // some constructors are currently unused, or used only in tests, but provided for completeness
368#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
369impl Expr {
370    /// Create an `Expr` that's just a single `Literal`.
371    ///
372    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
373    pub fn val(v: impl Into<Literal>) -> Self {
374        ExprBuilder::new().val(v)
375    }
376
377    /// Create an `Expr` that's just a single `Unknown`.
378    pub fn unknown(u: Unknown) -> Self {
379        ExprBuilder::new().unknown(u)
380    }
381
382    /// Create an `Expr` that's just this literal `Var`
383    pub fn var(v: Var) -> Self {
384        ExprBuilder::new().var(v)
385    }
386
387    /// Create an `Expr` that's just this `SlotId`
388    pub fn slot(s: SlotId) -> Self {
389        ExprBuilder::new().slot(s)
390    }
391
392    /// Create a ternary (if-then-else) `Expr`.
393    ///
394    /// `test_expr` must evaluate to a Bool type
395    pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
396        ExprBuilder::new().ite(test_expr, then_expr, else_expr)
397    }
398
399    /// Create a ternary (if-then-else) `Expr`.
400    /// Takes `Arc`s instead of owned `Expr`s.
401    /// `test_expr` must evaluate to a Bool type
402    pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
403        ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
404    }
405
406    /// Create a 'not' expression. `e` must evaluate to Bool type
407    pub fn not(e: Expr) -> Self {
408        ExprBuilder::new().not(e)
409    }
410
411    /// Create a '==' expression
412    pub fn is_eq(e1: Expr, e2: Expr) -> Self {
413        ExprBuilder::new().is_eq(e1, e2)
414    }
415
416    /// Create a '!=' expression
417    pub fn noteq(e1: Expr, e2: Expr) -> Self {
418        ExprBuilder::new().noteq(e1, e2)
419    }
420
421    /// Create an 'and' expression. Arguments must evaluate to Bool type
422    pub fn and(e1: Expr, e2: Expr) -> Self {
423        ExprBuilder::new().and(e1, e2)
424    }
425
426    /// Create an 'or' expression. Arguments must evaluate to Bool type
427    pub fn or(e1: Expr, e2: Expr) -> Self {
428        ExprBuilder::new().or(e1, e2)
429    }
430
431    /// Create a '<' expression. Arguments must evaluate to Long type
432    pub fn less(e1: Expr, e2: Expr) -> Self {
433        ExprBuilder::new().less(e1, e2)
434    }
435
436    /// Create a '<=' expression. Arguments must evaluate to Long type
437    pub fn lesseq(e1: Expr, e2: Expr) -> Self {
438        ExprBuilder::new().lesseq(e1, e2)
439    }
440
441    /// Create a '>' expression. Arguments must evaluate to Long type
442    pub fn greater(e1: Expr, e2: Expr) -> Self {
443        ExprBuilder::new().greater(e1, e2)
444    }
445
446    /// Create a '>=' expression. Arguments must evaluate to Long type
447    pub fn greatereq(e1: Expr, e2: Expr) -> Self {
448        ExprBuilder::new().greatereq(e1, e2)
449    }
450
451    /// Create an 'add' expression. Arguments must evaluate to Long type
452    pub fn add(e1: Expr, e2: Expr) -> Self {
453        ExprBuilder::new().add(e1, e2)
454    }
455
456    /// Create a 'sub' expression. Arguments must evaluate to Long type
457    pub fn sub(e1: Expr, e2: Expr) -> Self {
458        ExprBuilder::new().sub(e1, e2)
459    }
460
461    /// Create a 'mul' expression. Arguments must evaluate to Long type
462    pub fn mul(e1: Expr, e2: Expr) -> Self {
463        ExprBuilder::new().mul(e1, e2)
464    }
465
466    /// Create a 'neg' expression. `e` must evaluate to Long type.
467    pub fn neg(e: Expr) -> Self {
468        ExprBuilder::new().neg(e)
469    }
470
471    /// Create an 'in' expression. First argument must evaluate to Entity type.
472    /// Second argument must evaluate to either Entity type or Set type where
473    /// all set elements have Entity type.
474    pub fn is_in(e1: Expr, e2: Expr) -> Self {
475        ExprBuilder::new().is_in(e1, e2)
476    }
477
478    /// Create a 'contains' expression.
479    /// First argument must have Set type.
480    pub fn contains(e1: Expr, e2: Expr) -> Self {
481        ExprBuilder::new().contains(e1, e2)
482    }
483
484    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
485    pub fn contains_all(e1: Expr, e2: Expr) -> Self {
486        ExprBuilder::new().contains_all(e1, e2)
487    }
488
489    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
490    pub fn contains_any(e1: Expr, e2: Expr) -> Self {
491        ExprBuilder::new().contains_any(e1, e2)
492    }
493
494    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
495    pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
496        ExprBuilder::new().set(exprs)
497    }
498
499    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
500    pub fn record(
501        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
502    ) -> Result<Self, ExpressionConstructionError> {
503        ExprBuilder::new().record(pairs)
504    }
505
506    /// Create an `Expr` which evaluates to a Record with the given key-value mapping.
507    ///
508    /// If you have an iterator of pairs, generally prefer calling
509    /// `Expr::record()` instead of `.collect()`-ing yourself and calling this,
510    /// potentially for efficiency reasons but also because `Expr::record()`
511    /// will properly handle duplicate keys but your own `.collect()` will not
512    /// (by default).
513    pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
514        ExprBuilder::new().record_arc(map)
515    }
516
517    /// Create an `Expr` which calls the extension function with the given
518    /// `Name` on `args`
519    pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
520        ExprBuilder::new().call_extension_fn(fn_name, args)
521    }
522
523    /// Create an application `Expr` which applies the given built-in unary
524    /// operator to the given `arg`
525    pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
526        ExprBuilder::new().unary_app(op, arg)
527    }
528
529    /// Create an application `Expr` which applies the given built-in binary
530    /// operator to `arg1` and `arg2`
531    pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
532        ExprBuilder::new().binary_app(op, arg1, arg2)
533    }
534
535    /// Create an `Expr` which gets the attribute of some `Entity` or the field
536    /// of some record.
537    ///
538    /// `expr` must evaluate to either Entity or Record type
539    pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
540        ExprBuilder::new().get_attr(expr, attr)
541    }
542
543    /// Create an `Expr` which tests for the existence of a given
544    /// attribute on a given `Entity`, or field on a given record.
545    ///
546    /// `expr` must evaluate to either Entity or Record type
547    pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
548        ExprBuilder::new().has_attr(expr, attr)
549    }
550
551    /// Create a 'like' expression.
552    ///
553    /// `expr` must evaluate to a String type
554    pub fn like(expr: Expr, pattern: impl IntoIterator<Item = PatternElem>) -> Self {
555        ExprBuilder::new().like(expr, pattern)
556    }
557
558    /// Create an `is` expression.
559    pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
560        ExprBuilder::new().is_entity_type(expr, entity_type)
561    }
562
563    /// Check if an expression contains any symbolic unknowns
564    pub fn contains_unknown(&self) -> bool {
565        self.subexpressions()
566            .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
567    }
568
569    /// Get all unknowns in an expression
570    pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
571        self.subexpressions()
572            .filter_map(|subexpr| match subexpr.expr_kind() {
573                ExprKind::Unknown(u) => Some(u),
574                _ => None,
575            })
576    }
577
578    /// Substitute unknowns with concrete values.
579    ///
580    /// Ignores unmapped unknowns.
581    /// Ignores type annotations on unknowns.
582    pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
583        match self.substitute_general::<UntypedSubstitution>(definitions) {
584            Ok(e) => e,
585            Err(empty) => match empty {},
586        }
587    }
588
589    /// Substitute unknowns with concrete values.
590    ///
591    /// Ignores unmapped unknowns.
592    /// Errors if the substituted value does not match the type annotation on the unknown.
593    pub fn substitute_typed(
594        &self,
595        definitions: &HashMap<SmolStr, Value>,
596    ) -> Result<Expr, SubstitutionError> {
597        self.substitute_general::<TypedSubstitution>(definitions)
598    }
599
600    /// Substitute unknowns with values
601    ///
602    /// Generic over the function implementing the substitution to allow for multiple error behaviors
603    fn substitute_general<T: SubstitutionFunction>(
604        &self,
605        definitions: &HashMap<SmolStr, Value>,
606    ) -> Result<Expr, T::Err> {
607        match self.expr_kind() {
608            ExprKind::Lit(_) => Ok(self.clone()),
609            ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
610            ExprKind::Var(_) => Ok(self.clone()),
611            ExprKind::Slot(_) => Ok(self.clone()),
612            ExprKind::If {
613                test_expr,
614                then_expr,
615                else_expr,
616            } => Ok(Expr::ite(
617                test_expr.substitute_general::<T>(definitions)?,
618                then_expr.substitute_general::<T>(definitions)?,
619                else_expr.substitute_general::<T>(definitions)?,
620            )),
621            ExprKind::And { left, right } => Ok(Expr::and(
622                left.substitute_general::<T>(definitions)?,
623                right.substitute_general::<T>(definitions)?,
624            )),
625            ExprKind::Or { left, right } => Ok(Expr::or(
626                left.substitute_general::<T>(definitions)?,
627                right.substitute_general::<T>(definitions)?,
628            )),
629            ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
630                *op,
631                arg.substitute_general::<T>(definitions)?,
632            )),
633            ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
634                *op,
635                arg1.substitute_general::<T>(definitions)?,
636                arg2.substitute_general::<T>(definitions)?,
637            )),
638            ExprKind::ExtensionFunctionApp { fn_name, args } => {
639                let args = args
640                    .iter()
641                    .map(|e| e.substitute_general::<T>(definitions))
642                    .collect::<Result<Vec<Expr>, _>>()?;
643
644                Ok(Expr::call_extension_fn(fn_name.clone(), args))
645            }
646            ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
647                expr.substitute_general::<T>(definitions)?,
648                attr.clone(),
649            )),
650            ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
651                expr.substitute_general::<T>(definitions)?,
652                attr.clone(),
653            )),
654            ExprKind::Like { expr, pattern } => Ok(Expr::like(
655                expr.substitute_general::<T>(definitions)?,
656                pattern.iter().cloned(),
657            )),
658            ExprKind::Set(members) => {
659                let members = members
660                    .iter()
661                    .map(|e| e.substitute_general::<T>(definitions))
662                    .collect::<Result<Vec<_>, _>>()?;
663                Ok(Expr::set(members))
664            }
665            ExprKind::Record(map) => {
666                let map = map
667                    .iter()
668                    .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
669                    .collect::<Result<BTreeMap<_, _>, _>>()?;
670                // PANIC SAFETY: cannot have a duplicate key because the input was already a BTreeMap
671                #[allow(clippy::expect_used)]
672                Ok(Expr::record(map)
673                    .expect("cannot have a duplicate key because the input was already a BTreeMap"))
674            }
675            ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
676                expr.substitute_general::<T>(definitions)?,
677                entity_type.clone(),
678            )),
679        }
680    }
681}
682
683/// A trait for customizing the error behavior of substitution
684trait SubstitutionFunction {
685    /// The potential errors this substitution function can return
686    type Err;
687    /// The function for implementing the substitution.
688    ///
689    /// Takes the expression being substituted,
690    /// The substitution from the map (if present)
691    /// and the type annotation from the unknown (if present)
692    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
693}
694
695struct TypedSubstitution {}
696
697impl SubstitutionFunction for TypedSubstitution {
698    type Err = SubstitutionError;
699
700    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
701        match (substitute, &value.type_annotation) {
702            (None, _) => Ok(Expr::unknown(value.clone())),
703            (Some(v), None) => Ok(v.clone().into()),
704            (Some(v), Some(t)) => {
705                if v.type_of() == *t {
706                    Ok(v.clone().into())
707                } else {
708                    Err(SubstitutionError::TypeError {
709                        expected: t.clone(),
710                        actual: v.type_of(),
711                    })
712                }
713            }
714        }
715    }
716}
717
718struct UntypedSubstitution {}
719
720impl SubstitutionFunction for UntypedSubstitution {
721    type Err = std::convert::Infallible;
722
723    fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
724        Ok(substitute
725            .map(|v| v.clone().into())
726            .unwrap_or_else(|| Expr::unknown(value.clone())))
727    }
728}
729
730impl std::fmt::Display for Expr {
731    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
732        // To avoid code duplication between pretty-printers for AST Expr and EST Expr,
733        // we just convert to EST and use the EST pretty-printer.
734        // Note that converting AST->EST is lossless and infallible.
735        write!(f, "{}", crate::est::Expr::from(self.clone()))
736    }
737}
738
739impl std::str::FromStr for Expr {
740    type Err = ParseErrors;
741
742    fn from_str(s: &str) -> Result<Expr, Self::Err> {
743        crate::parser::parse_expr(s)
744    }
745}
746
747/// Enum for errors encountered during substitution
748#[derive(Debug, Clone, Diagnostic, Error)]
749pub enum SubstitutionError {
750    /// The supplied value did not match the type annotation on the unknown.
751    #[error("expected a value of type {expected}, got a value of type {actual}")]
752    TypeError {
753        /// The expected type, ie: the type the unknown was annotated with
754        expected: Type,
755        /// The type of the provided value
756        actual: Type,
757    },
758}
759
760/// Representation of a partial-evaluation Unknown at the AST level
761#[derive(Serialize, Deserialize, Hash, Debug, Clone, PartialEq, Eq)]
762pub struct Unknown {
763    /// The name of the unknown
764    pub name: SmolStr,
765    /// The type of the values that can be substituted in for the unknown.
766    /// If `None`, we have no type annotation, and thus a value of any type can
767    /// be substituted.
768    pub type_annotation: Option<Type>,
769}
770
771impl Unknown {
772    /// Create a new untyped `Unknown`
773    pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
774        Self {
775            name: name.into(),
776            type_annotation: None,
777        }
778    }
779
780    /// Create a new `Unknown` with type annotation. (Only values of the given
781    /// type can be substituted.)
782    pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
783        Self {
784            name: name.into(),
785            type_annotation: Some(ty),
786        }
787    }
788}
789
790impl std::fmt::Display for Unknown {
791    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
792        // Like the Display impl for Expr, we delegate to the EST pretty-printer,
793        // to avoid code duplication
794        write!(f, "{}", crate::est::Expr::from(Expr::unknown(self.clone())))
795    }
796}
797
798/// Builder for constructing `Expr` objects annotated with some `data`
799/// (possibly taking default value) and optionally a `source_loc`.
800#[derive(Debug)]
801pub struct ExprBuilder<T> {
802    source_loc: Option<Loc>,
803    data: T,
804}
805
806impl<T> ExprBuilder<T>
807where
808    T: Default,
809{
810    /// Construct a new `ExprBuilder` where the data used for an expression
811    /// takes a default value.
812    pub fn new() -> Self {
813        Self {
814            source_loc: None,
815            data: T::default(),
816        }
817    }
818
819    /// Create a '!=' expression.
820    /// Defined only for `T: Default` because the caller would otherwise need to
821    /// provide a `data` for the intermediate `not` Expr node.
822    pub fn noteq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
823        match &self.source_loc {
824            Some(source_loc) => ExprBuilder::new().with_source_loc(source_loc.clone()),
825            None => ExprBuilder::new(),
826        }
827        .not(self.with_expr_kind(ExprKind::BinaryApp {
828            op: BinaryOp::Eq,
829            arg1: Arc::new(e1),
830            arg2: Arc::new(e2),
831        }))
832    }
833}
834
835impl<T: Default> Default for ExprBuilder<T> {
836    fn default() -> Self {
837        Self::new()
838    }
839}
840
841impl<T> ExprBuilder<T> {
842    /// Construct a new `ExprBuild` where the specified data will be stored on
843    /// the `Expr`. This constructor does not populate the `source_loc` field,
844    /// so `with_source_loc` should be called if constructing an `Expr` where
845    /// the source location is known.
846    pub fn with_data(data: T) -> Self {
847        Self {
848            source_loc: None,
849            data,
850        }
851    }
852
853    /// Update the `ExprBuilder` to build an expression with some known location
854    /// in policy source code.
855    pub fn with_source_loc(self, source_loc: Loc) -> Self {
856        self.with_maybe_source_loc(Some(source_loc))
857    }
858
859    /// Utility used the validator to get an expression with the same source
860    /// location as an existing expression. This is done when reconstructing the
861    /// `Expr` with type information.
862    pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
863        self.with_maybe_source_loc(expr.source_loc.clone())
864    }
865
866    /// internally used to update `.source_loc` to the given `Some` or `None`
867    fn with_maybe_source_loc(mut self, maybe_source_loc: Option<Loc>) -> Self {
868        self.source_loc = maybe_source_loc;
869        self
870    }
871
872    /// Internally used by the following methods to construct an `Expr`
873    /// containing the `data` and `source_loc` in this `ExprBuilder` with some
874    /// inner `ExprKind`.
875    fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
876        Expr::new(expr_kind, self.source_loc, self.data)
877    }
878
879    /// Create an `Expr` that's just a single `Literal`.
880    ///
881    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
882    pub fn val(self, v: impl Into<Literal>) -> Expr<T> {
883        self.with_expr_kind(ExprKind::Lit(v.into()))
884    }
885
886    /// Create an `Unknown` `Expr`
887    pub fn unknown(self, u: Unknown) -> Expr<T> {
888        self.with_expr_kind(ExprKind::Unknown(u))
889    }
890
891    /// Create an `Expr` that's just this literal `Var`
892    pub fn var(self, v: Var) -> Expr<T> {
893        self.with_expr_kind(ExprKind::Var(v))
894    }
895
896    /// Create an `Expr` that's just this `SlotId`
897    pub fn slot(self, s: SlotId) -> Expr<T> {
898        self.with_expr_kind(ExprKind::Slot(s))
899    }
900
901    /// Create a ternary (if-then-else) `Expr`.
902    ///
903    /// `test_expr` must evaluate to a Bool type
904    pub fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
905        self.with_expr_kind(ExprKind::If {
906            test_expr: Arc::new(test_expr),
907            then_expr: Arc::new(then_expr),
908            else_expr: Arc::new(else_expr),
909        })
910    }
911
912    /// Create a ternary (if-then-else) `Expr`.
913    /// Takes `Arc`s instead of owned `Expr`s.
914    /// `test_expr` must evaluate to a Bool type
915    pub fn ite_arc(
916        self,
917        test_expr: Arc<Expr<T>>,
918        then_expr: Arc<Expr<T>>,
919        else_expr: Arc<Expr<T>>,
920    ) -> Expr<T> {
921        self.with_expr_kind(ExprKind::If {
922            test_expr,
923            then_expr,
924            else_expr,
925        })
926    }
927
928    /// Create a 'not' expression. `e` must evaluate to Bool type
929    pub fn not(self, e: Expr<T>) -> Expr<T> {
930        self.with_expr_kind(ExprKind::UnaryApp {
931            op: UnaryOp::Not,
932            arg: Arc::new(e),
933        })
934    }
935
936    /// Create a '==' expression
937    pub fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
938        self.with_expr_kind(ExprKind::BinaryApp {
939            op: BinaryOp::Eq,
940            arg1: Arc::new(e1),
941            arg2: Arc::new(e2),
942        })
943    }
944
945    /// Create an 'and' expression. Arguments must evaluate to Bool type
946    pub fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
947        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
948            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
949                ExprKind::Lit(Literal::Bool(*b1 && *b2))
950            }
951            _ => ExprKind::And {
952                left: Arc::new(e1),
953                right: Arc::new(e2),
954            },
955        })
956    }
957
958    /// Create an 'or' expression. Arguments must evaluate to Bool type
959    pub fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
960        self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
961            (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
962                ExprKind::Lit(Literal::Bool(*b1 || *b2))
963            }
964
965            _ => ExprKind::Or {
966                left: Arc::new(e1),
967                right: Arc::new(e2),
968            },
969        })
970    }
971
972    /// Create a '<' expression. Arguments must evaluate to Long type
973    pub fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
974        self.with_expr_kind(ExprKind::BinaryApp {
975            op: BinaryOp::Less,
976            arg1: Arc::new(e1),
977            arg2: Arc::new(e2),
978        })
979    }
980
981    /// Create a '<=' expression. Arguments must evaluate to Long type
982    pub fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
983        self.with_expr_kind(ExprKind::BinaryApp {
984            op: BinaryOp::LessEq,
985            arg1: Arc::new(e1),
986            arg2: Arc::new(e2),
987        })
988    }
989
990    /// Create an 'add' expression. Arguments must evaluate to Long type
991    pub fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
992        self.with_expr_kind(ExprKind::BinaryApp {
993            op: BinaryOp::Add,
994            arg1: Arc::new(e1),
995            arg2: Arc::new(e2),
996        })
997    }
998
999    /// Create a 'sub' expression. Arguments must evaluate to Long type
1000    pub fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1001        self.with_expr_kind(ExprKind::BinaryApp {
1002            op: BinaryOp::Sub,
1003            arg1: Arc::new(e1),
1004            arg2: Arc::new(e2),
1005        })
1006    }
1007
1008    /// Create a 'mul' expression. Arguments must evaluate to Long type
1009    pub fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1010        self.with_expr_kind(ExprKind::BinaryApp {
1011            op: BinaryOp::Mul,
1012            arg1: Arc::new(e1),
1013            arg2: Arc::new(e2),
1014        })
1015    }
1016
1017    /// Create a 'neg' expression. `e` must evaluate to Long type.
1018    pub fn neg(self, e: Expr<T>) -> Expr<T> {
1019        self.with_expr_kind(ExprKind::UnaryApp {
1020            op: UnaryOp::Neg,
1021            arg: Arc::new(e),
1022        })
1023    }
1024
1025    /// Create an 'in' expression. First argument must evaluate to Entity type.
1026    /// Second argument must evaluate to either Entity type or Set type where
1027    /// all set elements have Entity type.
1028    pub fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1029        self.with_expr_kind(ExprKind::BinaryApp {
1030            op: BinaryOp::In,
1031            arg1: Arc::new(e1),
1032            arg2: Arc::new(e2),
1033        })
1034    }
1035
1036    /// Create a 'contains' expression.
1037    /// First argument must have Set type.
1038    pub fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1039        self.with_expr_kind(ExprKind::BinaryApp {
1040            op: BinaryOp::Contains,
1041            arg1: Arc::new(e1),
1042            arg2: Arc::new(e2),
1043        })
1044    }
1045
1046    /// Create a 'contains_all' expression. Arguments must evaluate to Set type
1047    pub fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1048        self.with_expr_kind(ExprKind::BinaryApp {
1049            op: BinaryOp::ContainsAll,
1050            arg1: Arc::new(e1),
1051            arg2: Arc::new(e2),
1052        })
1053    }
1054
1055    /// Create an 'contains_any' expression. Arguments must evaluate to Set type
1056    pub fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1057        self.with_expr_kind(ExprKind::BinaryApp {
1058            op: BinaryOp::ContainsAny,
1059            arg1: Arc::new(e1),
1060            arg2: Arc::new(e2),
1061        })
1062    }
1063
1064    /// Create an `Expr` which evaluates to a Set of the given `Expr`s
1065    pub fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1066        self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1067    }
1068
1069    /// Create an `Expr` which evaluates to a Record with the given (key, value) pairs.
1070    pub fn record(
1071        self,
1072        pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1073    ) -> Result<Expr<T>, ExpressionConstructionError> {
1074        let mut map = BTreeMap::new();
1075        for (k, v) in pairs {
1076            match map.entry(k) {
1077                btree_map::Entry::Occupied(oentry) => {
1078                    return Err(expression_construction_errors::DuplicateKeyError {
1079                        key: oentry.key().clone(),
1080                        context: "in record literal",
1081                    }
1082                    .into());
1083                }
1084                btree_map::Entry::Vacant(ventry) => {
1085                    ventry.insert(v);
1086                }
1087            }
1088        }
1089        Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1090    }
1091
1092    /// Create an `Expr` which evalutes to a Record with the given key-value mapping.
1093    ///
1094    /// If you have an iterator of pairs, generally prefer calling `.record()`
1095    /// instead of `.collect()`-ing yourself and calling this, potentially for
1096    /// efficiency reasons but also because `.record()` will properly handle
1097    /// duplicate keys but your own `.collect()` will not (by default).
1098    pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1099        self.with_expr_kind(ExprKind::Record(map))
1100    }
1101
1102    /// Create an `Expr` which calls the extension function with the given
1103    /// `Name` on `args`
1104    pub fn call_extension_fn(
1105        self,
1106        fn_name: Name,
1107        args: impl IntoIterator<Item = Expr<T>>,
1108    ) -> Expr<T> {
1109        self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1110            fn_name,
1111            args: Arc::new(args.into_iter().collect()),
1112        })
1113    }
1114
1115    /// Create an application `Expr` which applies the given built-in unary
1116    /// operator to the given `arg`
1117    pub fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1118        self.with_expr_kind(ExprKind::UnaryApp {
1119            op: op.into(),
1120            arg: Arc::new(arg),
1121        })
1122    }
1123
1124    /// Create an application `Expr` which applies the given built-in binary
1125    /// operator to `arg1` and `arg2`
1126    pub fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1127        self.with_expr_kind(ExprKind::BinaryApp {
1128            op: op.into(),
1129            arg1: Arc::new(arg1),
1130            arg2: Arc::new(arg2),
1131        })
1132    }
1133
1134    /// Create an `Expr` which gets the attribute of some `Entity` or the field
1135    /// of some record.
1136    ///
1137    /// `expr` must evaluate to either Entity or Record type
1138    pub fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1139        self.with_expr_kind(ExprKind::GetAttr {
1140            expr: Arc::new(expr),
1141            attr,
1142        })
1143    }
1144
1145    /// Create an `Expr` which tests for the existence of a given
1146    /// attribute on a given `Entity`, or field on a given record.
1147    ///
1148    /// `expr` must evaluate to either Entity or Record type
1149    pub fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1150        self.with_expr_kind(ExprKind::HasAttr {
1151            expr: Arc::new(expr),
1152            attr,
1153        })
1154    }
1155
1156    /// Create a 'like' expression.
1157    ///
1158    /// `expr` must evaluate to a String type
1159    pub fn like(self, expr: Expr<T>, pattern: impl IntoIterator<Item = PatternElem>) -> Expr<T> {
1160        self.with_expr_kind(ExprKind::Like {
1161            expr: Arc::new(expr),
1162            pattern: Pattern::new(pattern),
1163        })
1164    }
1165
1166    /// Create an 'is' expression.
1167    pub fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1168        self.with_expr_kind(ExprKind::Is {
1169            expr: Arc::new(expr),
1170            entity_type,
1171        })
1172    }
1173}
1174
1175impl<T: Clone> ExprBuilder<T> {
1176    /// Create an `and` expression that may have more than two subexpressions (A && B && C)
1177    /// or may have only one subexpression, in which case no `&&` is performed at all.
1178    /// Arguments must evaluate to Bool type.
1179    ///
1180    /// This may create multiple AST `&&` nodes. If it does, all the nodes will have the same
1181    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1182    /// with another call to `with_source_loc()`.
1183    pub fn and_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1184        others.into_iter().fold(first, |acc, next| {
1185            Self::with_data(self.data.clone())
1186                .with_maybe_source_loc(self.source_loc.clone())
1187                .and(acc, next)
1188        })
1189    }
1190
1191    /// Create an `or` expression that may have more than two subexpressions (A || B || C)
1192    /// or may have only one subexpression, in which case no `||` is performed at all.
1193    /// Arguments must evaluate to Bool type.
1194    ///
1195    /// This may create multiple AST `||` nodes. If it does, all the nodes will have the same
1196    /// source location and the same `T` data (taken from this builder) unless overridden, e.g.,
1197    /// with another call to `with_source_loc()`.
1198    pub fn or_nary(self, first: Expr<T>, others: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1199        others.into_iter().fold(first, |acc, next| {
1200            Self::with_data(self.data.clone())
1201                .with_maybe_source_loc(self.source_loc.clone())
1202                .or(acc, next)
1203        })
1204    }
1205
1206    /// Create a '>' expression. Arguments must evaluate to Long type
1207    pub fn greater(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1208        // e1 > e2 is defined as !(e1 <= e2)
1209        let leq = Self::with_data(self.data.clone())
1210            .with_maybe_source_loc(self.source_loc.clone())
1211            .lesseq(e1, e2);
1212        self.not(leq)
1213    }
1214
1215    /// Create a '>=' expression. Arguments must evaluate to Long type
1216    pub fn greatereq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1217        // e1 >= e2 is defined as !(e1 < e2)
1218        let leq = Self::with_data(self.data.clone())
1219            .with_maybe_source_loc(self.source_loc.clone())
1220            .less(e1, e2);
1221        self.not(leq)
1222    }
1223}
1224
/// Errors when constructing an expression (for example, from
/// `ExprBuilder::record` when a record literal repeats a key).
//
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key occurred two or more times.
    ///
    /// Both the error message and the diagnostic are forwarded unchanged from
    /// the wrapped `DuplicateKeyError` (`transparent`).
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1237
1238/// Error subtypes for [`ExpressionConstructionError`]
1239pub mod expression_construction_errors {
1240    use miette::Diagnostic;
1241    use smol_str::SmolStr;
1242    use thiserror::Error;
1243
1244    /// The same key occurred two or more times
1245    //
1246    // CAUTION: this type is publicly exported in `cedar-policy`.
1247    // Don't make fields `pub`, don't make breaking changes, and use caution
1248    // when adding public methods.
1249    #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1250    #[error("duplicate key `{key}` {context}")]
1251    pub struct DuplicateKeyError {
1252        /// The key which occurred two or more times
1253        pub(crate) key: SmolStr,
1254        /// Information about where the duplicate key occurred (e.g., "in record literal")
1255        pub(crate) context: &'static str,
1256    }
1257
1258    impl DuplicateKeyError {
1259        /// Get the key which occurred two or more times
1260        pub fn key(&self) -> &str {
1261            &self.key
1262        }
1263
1264        /// Make a new error with an updated `context` field
1265        pub(crate) fn with_context(self, context: &'static str) -> Self {
1266            Self { context, ..self }
1267        }
1268    }
1269}
1270
/// A new type wrapper around `Expr` that provides `Eq` and `Hash`
/// implementations that ignore any source information or other generic data
/// used to annotate the `Expr`.
///
/// Equality is implemented via `Expr::eq_shape` and hashing via
/// `Expr::hash_shape`.
// NOTE(review): the derived `Eq` and `Clone` impls add `T: Eq` / `T: Clone`
// bounds even though the comparison itself ignores `T` data.
#[derive(Eq, Debug, Clone)]
pub struct ExprShapeOnly<'a, T = ()>(&'a Expr<T>);
1276
1277impl<'a, T> ExprShapeOnly<'a, T> {
1278    /// Construct an `ExprShapeOnly` from an `Expr`. The `Expr` is not modified,
1279    /// but any comparisons on the resulting `ExprShapeOnly` will ignore source
1280    /// information and generic data.
1281    pub fn new(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1282        ExprShapeOnly(e)
1283    }
1284}
1285
1286impl<'a, T> PartialEq for ExprShapeOnly<'a, T> {
1287    fn eq(&self, other: &Self) -> bool {
1288        self.0.eq_shape(other.0)
1289    }
1290}
1291
1292impl<'a, T> Hash for ExprShapeOnly<'a, T> {
1293    fn hash<H: Hasher>(&self, state: &mut H) {
1294        self.0.hash_shape(state);
1295    }
1296}
1297
1298impl<T> Expr<T> {
1299    /// Return true if this expression (recursively) has the same expression
1300    /// kind as the argument expression. This accounts for the full recursive
1301    /// shape of the expression, but does not consider source information or any
1302    /// generic data annotated on expression. This should behave the same as the
1303    /// default implementation of `Eq` before source information and generic
1304    /// data were added.
1305    pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1306        use ExprKind::*;
1307        match (self.expr_kind(), other.expr_kind()) {
1308            (Lit(lit), Lit(lit1)) => lit == lit1,
1309            (Var(v), Var(v1)) => v == v1,
1310            (Slot(s), Slot(s1)) => s == s1,
1311            (
1312                Unknown(self::Unknown {
1313                    name: name1,
1314                    type_annotation: ta_1,
1315                }),
1316                Unknown(self::Unknown {
1317                    name: name2,
1318                    type_annotation: ta_2,
1319                }),
1320            ) => (name1 == name2) && (ta_1 == ta_2),
1321            (
1322                If {
1323                    test_expr,
1324                    then_expr,
1325                    else_expr,
1326                },
1327                If {
1328                    test_expr: test_expr1,
1329                    then_expr: then_expr1,
1330                    else_expr: else_expr1,
1331                },
1332            ) => {
1333                test_expr.eq_shape(test_expr1)
1334                    && then_expr.eq_shape(then_expr1)
1335                    && else_expr.eq_shape(else_expr1)
1336            }
1337            (
1338                And { left, right },
1339                And {
1340                    left: left1,
1341                    right: right1,
1342                },
1343            )
1344            | (
1345                Or { left, right },
1346                Or {
1347                    left: left1,
1348                    right: right1,
1349                },
1350            ) => left.eq_shape(left1) && right.eq_shape(right1),
1351            (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1352                op == op1 && arg.eq_shape(arg1)
1353            }
1354            (
1355                BinaryApp { op, arg1, arg2 },
1356                BinaryApp {
1357                    op: op1,
1358                    arg1: arg11,
1359                    arg2: arg21,
1360                },
1361            ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1362            (
1363                ExtensionFunctionApp { fn_name, args },
1364                ExtensionFunctionApp {
1365                    fn_name: fn_name1,
1366                    args: args1,
1367                },
1368            ) => fn_name == fn_name1 && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1)),
1369            (
1370                GetAttr { expr, attr },
1371                GetAttr {
1372                    expr: expr1,
1373                    attr: attr1,
1374                },
1375            )
1376            | (
1377                HasAttr { expr, attr },
1378                HasAttr {
1379                    expr: expr1,
1380                    attr: attr1,
1381                },
1382            ) => attr == attr1 && expr.eq_shape(expr1),
1383            (
1384                Like { expr, pattern },
1385                Like {
1386                    expr: expr1,
1387                    pattern: pattern1,
1388                },
1389            ) => pattern == pattern1 && expr.eq_shape(expr1),
1390            (Set(elems), Set(elems1)) => elems
1391                .iter()
1392                .zip(elems1.iter())
1393                .all(|(e, e1)| e.eq_shape(e1)),
1394            (Record(map), Record(map1)) => {
1395                map.len() == map1.len()
1396                    && map
1397                        .iter()
1398                        .zip(map1.iter()) // relying on BTreeMap producing an iterator sorted by key
1399                        .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1400            }
1401            (
1402                Is { expr, entity_type },
1403                Is {
1404                    expr: expr1,
1405                    entity_type: entity_type1,
1406                },
1407            ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1408            _ => false,
1409        }
1410    }
1411
1412    /// Implementation of hashing corresponding to equality as implemented by
1413    /// `eq_shape`. Must satisfy the usual relationship between equality and
1414    /// hashing.
1415    pub fn hash_shape<H>(&self, state: &mut H)
1416    where
1417        H: Hasher,
1418    {
1419        mem::discriminant(self).hash(state);
1420        match self.expr_kind() {
1421            ExprKind::Lit(lit) => lit.hash(state),
1422            ExprKind::Var(v) => v.hash(state),
1423            ExprKind::Slot(s) => s.hash(state),
1424            ExprKind::Unknown(u) => u.hash(state),
1425            ExprKind::If {
1426                test_expr,
1427                then_expr,
1428                else_expr,
1429            } => {
1430                test_expr.hash_shape(state);
1431                then_expr.hash_shape(state);
1432                else_expr.hash_shape(state);
1433            }
1434            ExprKind::And { left, right } => {
1435                left.hash_shape(state);
1436                right.hash_shape(state);
1437            }
1438            ExprKind::Or { left, right } => {
1439                left.hash_shape(state);
1440                right.hash_shape(state);
1441            }
1442            ExprKind::UnaryApp { op, arg } => {
1443                op.hash(state);
1444                arg.hash_shape(state);
1445            }
1446            ExprKind::BinaryApp { op, arg1, arg2 } => {
1447                op.hash(state);
1448                arg1.hash_shape(state);
1449                arg2.hash_shape(state);
1450            }
1451            ExprKind::ExtensionFunctionApp { fn_name, args } => {
1452                fn_name.hash(state);
1453                state.write_usize(args.len());
1454                args.iter().for_each(|a| {
1455                    a.hash_shape(state);
1456                });
1457            }
1458            ExprKind::GetAttr { expr, attr } => {
1459                expr.hash_shape(state);
1460                attr.hash(state);
1461            }
1462            ExprKind::HasAttr { expr, attr } => {
1463                expr.hash_shape(state);
1464                attr.hash(state);
1465            }
1466            ExprKind::Like { expr, pattern } => {
1467                expr.hash_shape(state);
1468                pattern.hash(state);
1469            }
1470            ExprKind::Set(elems) => {
1471                state.write_usize(elems.len());
1472                elems.iter().for_each(|e| {
1473                    e.hash_shape(state);
1474                })
1475            }
1476            ExprKind::Record(map) => {
1477                state.write_usize(map.len());
1478                map.iter().for_each(|(s, a)| {
1479                    s.hash(state);
1480                    a.hash_shape(state);
1481                });
1482            }
1483            ExprKind::Is { expr, entity_type } => {
1484                expr.hash_shape(state);
1485                entity_type.hash(state);
1486            }
1487        }
1488    }
1489}
1490
/// AST variables
///
/// One variant per request variable. The serde representation
/// (`rename_all = "camelCase"`) and the `Display` form are both the lowercase
/// keyword, e.g. `principal`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// the Principal of the given request
    Principal,
    /// the Action of the given request
    Action,
    /// the Resource of the given request
    Resource,
    /// the Context of the given request
    Context,
}
1507
#[cfg(test)]
pub mod var_generator {
    use super::Var;

    /// Yield every `Var` variant exactly once, in declaration order.
    // The enclosing module is already `#[cfg(test)]`, so this function needs
    // no attribute of its own.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        let every_var = [Var::Principal, Var::Action, Var::Resource, Var::Context];
        every_var.into_iter()
    }
}
1516// by default, Coverlay does not track coverage for lines after a line
1517// containing #[cfg(test)].
1518// we use the following sentinel to "turn back on" coverage tracking for
1519// remaining lines of this file, until the next #[cfg(test)]
1520// GRCOV_BEGIN_COVERAGE
1521
1522impl From<PrincipalOrResource> for Var {
1523    fn from(v: PrincipalOrResource) -> Self {
1524        match v {
1525            PrincipalOrResource::Principal => Var::Principal,
1526            PrincipalOrResource::Resource => Var::Resource,
1527        }
1528    }
1529}
1530
1531// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1532#[allow(clippy::fallible_impl_from)]
1533impl From<Var> for Id {
1534    fn from(var: Var) -> Self {
1535        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `Id`. Tested by `test::all_vars_are_ids`
1536        #[allow(clippy::unwrap_used)]
1537        format!("{var}").parse().unwrap()
1538    }
1539}
1540
1541// PANIC SAFETY Tested by `test::all_vars_are_ids`. Never panics.
1542#[allow(clippy::fallible_impl_from)]
1543impl From<Var> for UnreservedId {
1544    fn from(var: Var) -> Self {
1545        // PANIC SAFETY: `Var` is a simple enum and all vars are formatted as valid `UnreservedId`. Tested by `test::all_vars_are_ids`
1546        #[allow(clippy::unwrap_used)]
1547        Id::from(var).try_into().unwrap()
1548    }
1549}
1550
1551impl std::fmt::Display for Var {
1552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1553        match self {
1554            Self::Principal => write!(f, "principal"),
1555            Self::Action => write!(f, "action"),
1556            Self::Resource => write!(f, "resource"),
1557            Self::Context => write!(f, "context"),
1558        }
1559    }
1560}
1561
1562#[cfg(test)]
1563mod test {
1564    use cool_asserts::assert_matches;
1565    use itertools::Itertools;
1566    use std::collections::{hash_map::DefaultHasher, HashSet};
1567
1568    use super::{var_generator::all_vars, *};
1569
1570    // Tests that Var::Into never panics
1571    #[test]
1572    fn all_vars_are_ids() {
1573        for var in all_vars() {
1574            let _id: Id = var.into();
1575            let _id: UnreservedId = var.into();
1576        }
1577    }
1578
1579    #[test]
1580    fn exprs() {
1581        assert_eq!(
1582            Expr::val(33),
1583            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
1584        );
1585        assert_eq!(
1586            Expr::val("hello"),
1587            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
1588        );
1589        assert_eq!(
1590            Expr::val(EntityUID::with_eid("foo")),
1591            Expr::new(
1592                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1593                None,
1594                ()
1595            )
1596        );
1597        assert_eq!(
1598            Expr::var(Var::Principal),
1599            Expr::new(ExprKind::Var(Var::Principal), None, ())
1600        );
1601        assert_eq!(
1602            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
1603            Expr::new(
1604                ExprKind::If {
1605                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
1606                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
1607                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
1608                },
1609                None,
1610                ()
1611            )
1612        );
1613        assert_eq!(
1614            Expr::not(Expr::val(false)),
1615            Expr::new(
1616                ExprKind::UnaryApp {
1617                    op: UnaryOp::Not,
1618                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
1619                },
1620                None,
1621                ()
1622            )
1623        );
1624        assert_eq!(
1625            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1626            Expr::new(
1627                ExprKind::GetAttr {
1628                    expr: Arc::new(Expr::new(
1629                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1630                        None,
1631                        ()
1632                    )),
1633                    attr: "some_attr".into()
1634                },
1635                None,
1636                ()
1637            )
1638        );
1639        assert_eq!(
1640            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
1641            Expr::new(
1642                ExprKind::HasAttr {
1643                    expr: Arc::new(Expr::new(
1644                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1645                        None,
1646                        ()
1647                    )),
1648                    attr: "some_attr".into()
1649                },
1650                None,
1651                ()
1652            )
1653        );
1654        assert_eq!(
1655            Expr::is_entity_type(
1656                Expr::val(EntityUID::with_eid("foo")),
1657                "Type".parse().unwrap()
1658            ),
1659            Expr::new(
1660                ExprKind::Is {
1661                    expr: Arc::new(Expr::new(
1662                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
1663                        None,
1664                        ()
1665                    )),
1666                    entity_type: "Type".parse().unwrap()
1667                },
1668                None,
1669                ()
1670            ),
1671        );
1672    }
1673
1674    #[test]
1675    fn like_display() {
1676        // `\0` escaped form is `\0`.
1677        let e = Expr::like(Expr::val("a"), vec![PatternElem::Char('\0')]);
1678        assert_eq!(format!("{e}"), r#""a" like "\0""#);
1679        // `\`'s escaped form is `\\`
1680        let e = Expr::like(
1681            Expr::val("a"),
1682            vec![PatternElem::Char('\\'), PatternElem::Char('0')],
1683        );
1684        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
1685        // `\`'s escaped form is `\\`
1686        let e = Expr::like(
1687            Expr::val("a"),
1688            vec![PatternElem::Char('\\'), PatternElem::Wildcard],
1689        );
1690        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
1691        // literal star's escaped from is `\*`
1692        let e = Expr::like(
1693            Expr::val("a"),
1694            vec![PatternElem::Char('\\'), PatternElem::Char('*')],
1695        );
1696        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
1697    }
1698
1699    #[test]
1700    fn has_display() {
1701        // `\0` escaped form is `\0`.
1702        let e = Expr::has_attr(Expr::val("a"), "\0".into());
1703        assert_eq!(format!("{e}"), r#""a" has "\0""#);
1704        // `\`'s escaped form is `\\`
1705        let e = Expr::has_attr(Expr::val("a"), r"\".into());
1706        assert_eq!(format!("{e}"), r#""a" has "\\""#);
1707    }
1708
1709    #[test]
1710    fn slot_display() {
1711        let e = Expr::slot(SlotId::principal());
1712        assert_eq!(format!("{e}"), "?principal");
1713        let e = Expr::slot(SlotId::resource());
1714        assert_eq!(format!("{e}"), "?resource");
1715        let e = Expr::val(EntityUID::with_eid("eid"));
1716        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
1717    }
1718
1719    #[test]
1720    fn simple_slots() {
1721        let e = Expr::slot(SlotId::principal());
1722        let p = SlotId::principal();
1723        let r = SlotId::resource();
1724        let set: HashSet<SlotId> = HashSet::from_iter([p]);
1725        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1726        let e = Expr::or(
1727            Expr::slot(SlotId::principal()),
1728            Expr::ite(
1729                Expr::val(true),
1730                Expr::slot(SlotId::resource()),
1731                Expr::val(false),
1732            ),
1733        );
1734        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
1735        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
1736    }
1737
1738    #[test]
1739    fn unknowns() {
1740        let e = Expr::ite(
1741            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1742            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1743            Expr::unknown(Unknown::new_untyped("c")),
1744        );
1745        let unknowns = e.unknowns().collect_vec();
1746        assert_eq!(unknowns.len(), 3);
1747        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
1748        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
1749        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
1750    }
1751
1752    #[test]
1753    fn is_unknown() {
1754        let e = Expr::ite(
1755            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
1756            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
1757            Expr::unknown(Unknown::new_untyped("c")),
1758        );
1759        assert!(e.contains_unknown());
1760        let e = Expr::ite(
1761            Expr::not(Expr::val(true)),
1762            Expr::and(Expr::val(1), Expr::val(3)),
1763            Expr::val(1),
1764        );
1765        assert!(!e.contains_unknown());
1766    }
1767
1768    #[test]
1769    fn expr_with_data() {
1770        let e = ExprBuilder::with_data("data").val(1);
1771        assert_eq!(e.into_data(), "data");
1772    }
1773
1774    #[test]
1775    fn expr_shape_only_eq() {
1776        let temp = ExprBuilder::with_data(1).val(1);
1777        let exprs = &[
1778            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
1779            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
1780            (
1781                ExprBuilder::with_data(1).var(Var::Principal),
1782                Expr::var(Var::Principal),
1783            ),
1784            (
1785                ExprBuilder::with_data(1).slot(SlotId::principal()),
1786                Expr::slot(SlotId::principal()),
1787            ),
1788            (
1789                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
1790                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
1791            ),
1792            (
1793                ExprBuilder::with_data(1).not(temp.clone()),
1794                Expr::not(Expr::val(1)),
1795            ),
1796            (
1797                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
1798                Expr::is_eq(Expr::val(1), Expr::val(1)),
1799            ),
1800            (
1801                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
1802                Expr::and(Expr::val(1), Expr::val(1)),
1803            ),
1804            (
1805                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
1806                Expr::or(Expr::val(1), Expr::val(1)),
1807            ),
1808            (
1809                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
1810                Expr::less(Expr::val(1), Expr::val(1)),
1811            ),
1812            (
1813                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
1814                Expr::lesseq(Expr::val(1), Expr::val(1)),
1815            ),
1816            (
1817                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
1818                Expr::greater(Expr::val(1), Expr::val(1)),
1819            ),
1820            (
1821                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
1822                Expr::greatereq(Expr::val(1), Expr::val(1)),
1823            ),
1824            (
1825                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
1826                Expr::add(Expr::val(1), Expr::val(1)),
1827            ),
1828            (
1829                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
1830                Expr::sub(Expr::val(1), Expr::val(1)),
1831            ),
1832            (
1833                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
1834                Expr::mul(Expr::val(1), Expr::val(1)),
1835            ),
1836            (
1837                ExprBuilder::with_data(1).neg(temp.clone()),
1838                Expr::neg(Expr::val(1)),
1839            ),
1840            (
1841                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
1842                Expr::is_in(Expr::val(1), Expr::val(1)),
1843            ),
1844            (
1845                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
1846                Expr::contains(Expr::val(1), Expr::val(1)),
1847            ),
1848            (
1849                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
1850                Expr::contains_all(Expr::val(1), Expr::val(1)),
1851            ),
1852            (
1853                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
1854                Expr::contains_any(Expr::val(1), Expr::val(1)),
1855            ),
1856            (
1857                ExprBuilder::with_data(1).set([temp.clone()]),
1858                Expr::set([Expr::val(1)]),
1859            ),
1860            (
1861                ExprBuilder::with_data(1)
1862                    .record([("foo".into(), temp.clone())])
1863                    .unwrap(),
1864                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
1865            ),
1866            (
1867                ExprBuilder::with_data(1)
1868                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
1869                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
1870            ),
1871            (
1872                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
1873                Expr::get_attr(Expr::val(1), "foo".into()),
1874            ),
1875            (
1876                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
1877                Expr::has_attr(Expr::val(1), "foo".into()),
1878            ),
1879            (
1880                ExprBuilder::with_data(1).like(temp.clone(), vec![PatternElem::Wildcard]),
1881                Expr::like(Expr::val(1), vec![PatternElem::Wildcard]),
1882            ),
1883            (
1884                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
1885                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
1886            ),
1887        ];
1888
1889        for (e0, e1) in exprs {
1890            assert!(e0.eq_shape(e0));
1891            assert!(e1.eq_shape(e1));
1892            assert!(e0.eq_shape(e1));
1893            assert!(e1.eq_shape(e0));
1894
1895            let mut hasher0 = DefaultHasher::new();
1896            e0.hash_shape(&mut hasher0);
1897            let hash0 = hasher0.finish();
1898
1899            let mut hasher1 = DefaultHasher::new();
1900            e1.hash_shape(&mut hasher1);
1901            let hash1 = hasher1.finish();
1902
1903            assert_eq!(hash0, hash1);
1904        }
1905    }
1906
1907    #[test]
1908    fn expr_shape_only_not_eq() {
1909        let expr1 = ExprBuilder::with_data(1).val(1);
1910        let expr2 = ExprBuilder::with_data(1).val(2);
1911        assert_ne!(ExprShapeOnly::new(&expr1), ExprShapeOnly::new(&expr2));
1912    }
1913
1914    #[test]
1915    fn untyped_subst_present() {
1916        let u = Unknown {
1917            name: "foo".into(),
1918            type_annotation: None,
1919        };
1920        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1921        match r {
1922            Ok(e) => assert_eq!(e, Expr::val(1)),
1923            Err(empty) => match empty {},
1924        }
1925    }
1926
1927    #[test]
1928    fn untyped_subst_present_correct_type() {
1929        let u = Unknown {
1930            name: "foo".into(),
1931            type_annotation: Some(Type::Long),
1932        };
1933        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1934        match r {
1935            Ok(e) => assert_eq!(e, Expr::val(1)),
1936            Err(empty) => match empty {},
1937        }
1938    }
1939
1940    #[test]
1941    fn untyped_subst_present_wrong_type() {
1942        let u = Unknown {
1943            name: "foo".into(),
1944            type_annotation: Some(Type::Bool),
1945        };
1946        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
1947        match r {
1948            Ok(e) => assert_eq!(e, Expr::val(1)),
1949            Err(empty) => match empty {},
1950        }
1951    }
1952
1953    #[test]
1954    fn untyped_subst_not_present() {
1955        let u = Unknown {
1956            name: "foo".into(),
1957            type_annotation: Some(Type::Bool),
1958        };
1959        let r = UntypedSubstitution::substitute(&u, None);
1960        match r {
1961            Ok(n) => assert_eq!(n, Expr::unknown(u)),
1962            Err(empty) => match empty {},
1963        }
1964    }
1965
1966    #[test]
1967    fn typed_subst_present() {
1968        let u = Unknown {
1969            name: "foo".into(),
1970            type_annotation: None,
1971        };
1972        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
1973        assert_eq!(e, Expr::val(1));
1974    }
1975
1976    #[test]
1977    fn typed_subst_present_correct_type() {
1978        let u = Unknown {
1979            name: "foo".into(),
1980            type_annotation: Some(Type::Long),
1981        };
1982        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
1983        assert_eq!(e, Expr::val(1));
1984    }
1985
1986    #[test]
1987    fn typed_subst_present_wrong_type() {
1988        let u = Unknown {
1989            name: "foo".into(),
1990            type_annotation: Some(Type::Bool),
1991        };
1992        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
1993        assert_matches!(
1994            r,
1995            SubstitutionError::TypeError {
1996                expected: Type::Bool,
1997                actual: Type::Long,
1998            }
1999        );
2000    }
2001
2002    #[test]
2003    fn typed_subst_not_present() {
2004        let u = Unknown {
2005            name: "foo".into(),
2006            type_annotation: None,
2007        };
2008        let r = TypedSubstitution::substitute(&u, None).unwrap();
2009        assert_eq!(r, Expr::unknown(u));
2010    }
2011}