cedar_policy_core/est/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::FromJsonError;
18#[cfg(feature = "tolerant-ast")]
19use crate::ast::expr_allows_errors::AstExprErrorKind;
20#[cfg(feature = "tolerant-ast")]
21use crate::ast::Infallible;
22use crate::ast::{self, BoundedDisplay, EntityUID};
23use crate::entities::json::{
24    err::EscapeKind, err::JsonDeserializationError, err::JsonDeserializationErrorContext,
25    CedarValueJson, FnAndArg,
26};
27use crate::expr_builder::ExprBuilder;
28use crate::extensions::Extensions;
29use crate::jsonvalue::JsonValueWithNoDuplicateKeys;
30use crate::parser::cst_to_ast;
31use crate::parser::err::ParseErrors;
32use crate::parser::Node;
33use crate::parser::{cst, Loc};
34use itertools::Itertools;
35use serde::{de::Visitor, Deserialize, Serialize};
36use serde_with::serde_as;
37use smol_str::{SmolStr, ToSmolStr};
38use std::collections::{btree_map, BTreeMap, HashMap};
39use std::sync::Arc;
40
41/// Serde JSON structure for a Cedar expression in the EST format
42#[derive(Debug, Clone, PartialEq, Serialize)]
43#[serde(untagged)]
44#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
45#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
46pub enum Expr {
47    /// Any Cedar expression other than an extension function call.
48    ExprNoExt(ExprNoExt),
49    /// Extension function call, where the key is the name of an extension
50    /// function or method.
51    ExtFuncCall(ExtFuncCall),
52}
53
54// Manual implementation of `Deserialize` is more efficient than the derived
55// implementation with `serde(untagged)`. In particular, if the key is valid for
56// `ExprNoExt` but there is a deserialization problem within the corresponding
57// value, the derived implementation would backtrack and try to deserialize as
58// `ExtFuncCall` with that key as the extension function name, but this manual
59// implementation instead eagerly errors out, taking advantage of the fact that
60// none of the keys for `ExprNoExt` are valid extension function names.
61//
62// See #1284.
63impl<'de> Deserialize<'de> for Expr {
64    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65    where
66        D: serde::Deserializer<'de>,
67    {
68        struct ExprVisitor;
69        impl<'de> Visitor<'de> for ExprVisitor {
70            type Value = Expr;
71            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72                formatter.write_str("JSON object representing an expression")
73            }
74
75            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
76            where
77                A: serde::de::MapAccess<'de>,
78            {
79                let (k, v): (SmolStr, JsonValueWithNoDuplicateKeys) = match map.next_entry()? {
80                    None => {
81                        return Err(serde::de::Error::custom(
82                            "empty map is not a valid expression",
83                        ))
84                    }
85                    Some((k, v)) => (k, v),
86                };
87                match map.next_key()? {
88                    None => (),
89                    Some(k2) => {
90                        let k2: SmolStr = k2;
91                        return Err(serde::de::Error::custom(format!("JSON object representing an `Expr` should have only one key, but found two keys: `{k}` and `{k2}`")));
92                    }
93                };
94                if cst_to_ast::is_known_extension_func_str(&k) {
95                    // `k` is the name of an extension function or method. We assume that
96                    // no such keys are valid keys for `ExprNoExt`, so we must parse as an
97                    // `ExtFuncCall`.
98                    let obj = serde_json::json!({ k: v });
99                    let extfunccall =
100                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
101                    Ok(Expr::ExtFuncCall(extfunccall))
102                } else {
103                    // not a valid extension function or method, so we expect it
104                    // to work for `ExprNoExt`.
105                    let obj = serde_json::json!({ k: v });
106                    let exprnoext =
107                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
108                    Ok(Expr::ExprNoExt(exprnoext))
109                }
110            }
111        }
112
113        deserializer.deserialize_map(ExprVisitor)
114    }
115}
116
117/// Represent an element of a pattern literal
118#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
119#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
120#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
121pub enum PatternElem {
122    /// The wildcard asterisk
123    Wildcard,
124    /// A string without any wildcards
125    Literal(SmolStr),
126}
127
128impl From<&[PatternElem]> for crate::ast::Pattern {
129    fn from(value: &[PatternElem]) -> Self {
130        let mut elems = Vec::new();
131        for elem in value {
132            match elem {
133                PatternElem::Wildcard => {
134                    elems.push(crate::ast::PatternElem::Wildcard);
135                }
136                PatternElem::Literal(s) => {
137                    elems.extend(s.chars().map(crate::ast::PatternElem::Char));
138                }
139            }
140        }
141        Self::from(elems)
142    }
143}
144
145impl From<crate::ast::PatternElem> for PatternElem {
146    fn from(value: crate::ast::PatternElem) -> Self {
147        match value {
148            crate::ast::PatternElem::Wildcard => Self::Wildcard,
149            crate::ast::PatternElem::Char(c) => Self::Literal(c.to_smolstr()),
150        }
151    }
152}
153
154impl From<crate::ast::Pattern> for Vec<PatternElem> {
155    fn from(value: crate::ast::Pattern) -> Self {
156        value.iter().map(|elem| (*elem).into()).collect()
157    }
158}
159
160/// Serde JSON structure for [any Cedar expression other than an extension
161/// function call] in the EST format
162#[serde_as]
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164#[serde(deny_unknown_fields)]
165#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
166#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
167pub enum ExprNoExt {
168    /// Literal value (including anything that's legal to express in the
169    /// attribute-value JSON format)
170    Value(CedarValueJson),
171    /// Var
172    Var(ast::Var),
173    /// Template slot
174    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
175    /// `!`
176    #[serde(rename = "!")]
177    Not {
178        /// Argument
179        arg: Arc<Expr>,
180    },
181    /// `-`
182    #[serde(rename = "neg")]
183    Neg {
184        /// Argument
185        arg: Arc<Expr>,
186    },
187    /// `==`
188    #[serde(rename = "==")]
189    Eq {
190        /// Left-hand argument
191        left: Arc<Expr>,
192        /// Right-hand argument
193        right: Arc<Expr>,
194    },
195    /// `!=`
196    #[serde(rename = "!=")]
197    NotEq {
198        /// Left-hand argument
199        left: Arc<Expr>,
200        /// Right-hand argument
201        right: Arc<Expr>,
202    },
203    /// `in`
204    #[serde(rename = "in")]
205    In {
206        /// Left-hand argument
207        left: Arc<Expr>,
208        /// Right-hand argument
209        right: Arc<Expr>,
210    },
211    /// `<`
212    #[serde(rename = "<")]
213    Less {
214        /// Left-hand argument
215        left: Arc<Expr>,
216        /// Right-hand argument
217        right: Arc<Expr>,
218    },
219    /// `<=`
220    #[serde(rename = "<=")]
221    LessEq {
222        /// Left-hand argument
223        left: Arc<Expr>,
224        /// Right-hand argument
225        right: Arc<Expr>,
226    },
227    /// `>`
228    #[serde(rename = ">")]
229    Greater {
230        /// Left-hand argument
231        left: Arc<Expr>,
232        /// Right-hand argument
233        right: Arc<Expr>,
234    },
235    /// `>=`
236    #[serde(rename = ">=")]
237    GreaterEq {
238        /// Left-hand argument
239        left: Arc<Expr>,
240        /// Right-hand argument
241        right: Arc<Expr>,
242    },
243    /// `&&`
244    #[serde(rename = "&&")]
245    And {
246        /// Left-hand argument
247        left: Arc<Expr>,
248        /// Right-hand argument
249        right: Arc<Expr>,
250    },
251    /// `||`
252    #[serde(rename = "||")]
253    Or {
254        /// Left-hand argument
255        left: Arc<Expr>,
256        /// Right-hand argument
257        right: Arc<Expr>,
258    },
259    /// `+`
260    #[serde(rename = "+")]
261    Add {
262        /// Left-hand argument
263        left: Arc<Expr>,
264        /// Right-hand argument
265        right: Arc<Expr>,
266    },
267    /// `-`
268    #[serde(rename = "-")]
269    Sub {
270        /// Left-hand argument
271        left: Arc<Expr>,
272        /// Right-hand argument
273        right: Arc<Expr>,
274    },
275    /// `*`
276    #[serde(rename = "*")]
277    Mul {
278        /// Left-hand argument
279        left: Arc<Expr>,
280        /// Right-hand argument
281        right: Arc<Expr>,
282    },
283    /// `contains()`
284    #[serde(rename = "contains")]
285    Contains {
286        /// Left-hand argument (receiver)
287        left: Arc<Expr>,
288        /// Right-hand argument (inside the `()`)
289        right: Arc<Expr>,
290    },
291    /// `containsAll()`
292    #[serde(rename = "containsAll")]
293    ContainsAll {
294        /// Left-hand argument (receiver)
295        left: Arc<Expr>,
296        /// Right-hand argument (inside the `()`)
297        right: Arc<Expr>,
298    },
299    /// `containsAny()`
300    #[serde(rename = "containsAny")]
301    ContainsAny {
302        /// Left-hand argument (receiver)
303        left: Arc<Expr>,
304        /// Right-hand argument (inside the `()`)
305        right: Arc<Expr>,
306    },
307    /// `isEmpty()`
308    #[serde(rename = "isEmpty")]
309    IsEmpty {
310        /// Argument
311        arg: Arc<Expr>,
312    },
313    /// `getTag()`
314    #[serde(rename = "getTag")]
315    GetTag {
316        /// Left-hand argument (receiver)
317        left: Arc<Expr>,
318        /// Right-hand argument (inside the `()`)
319        right: Arc<Expr>,
320    },
321    /// `hasTag()`
322    #[serde(rename = "hasTag")]
323    HasTag {
324        /// Left-hand argument (receiver)
325        left: Arc<Expr>,
326        /// Right-hand argument (inside the `()`)
327        right: Arc<Expr>,
328    },
329    /// Get-attribute
330    #[serde(rename = ".")]
331    GetAttr {
332        /// Left-hand argument
333        left: Arc<Expr>,
334        /// Attribute name
335        attr: SmolStr,
336    },
337    /// `has`
338    #[serde(rename = "has")]
339    HasAttr {
340        /// Left-hand argument
341        left: Arc<Expr>,
342        /// Attribute name
343        attr: SmolStr,
344    },
345    /// `like`
346    #[serde(rename = "like")]
347    Like {
348        /// Left-hand argument
349        left: Arc<Expr>,
350        /// Pattern
351        pattern: Vec<PatternElem>,
352    },
353    /// `<entity> is <entity_type> in <entity_or_entity_set> `
354    #[serde(rename = "is")]
355    Is {
356        /// Left-hand entity argument
357        left: Arc<Expr>,
358        /// Entity type
359        entity_type: SmolStr,
360        /// Entity or entity set
361        #[serde(skip_serializing_if = "Option::is_none")]
362        #[serde(rename = "in")]
363        in_expr: Option<Arc<Expr>>,
364    },
365    /// Ternary
366    #[serde(rename = "if-then-else")]
367    If {
368        /// Condition
369        #[serde(rename = "if")]
370        cond_expr: Arc<Expr>,
371        /// `then` expression
372        #[serde(rename = "then")]
373        then_expr: Arc<Expr>,
374        /// `else` expression
375        #[serde(rename = "else")]
376        else_expr: Arc<Expr>,
377    },
378    /// Set literal, whose elements may be arbitrary expressions
379    /// (which is why we need this case specifically and can't just
380    /// use Expr::Value)
381    Set(Vec<Expr>),
382    /// Record literal, whose elements may be arbitrary expressions
383    /// (which is why we need this case specifically and can't just
384    /// use Expr::Value)
385    Record(
386        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
387        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
388        BTreeMap<SmolStr, Expr>,
389    ),
390    /// AST Error node - this represents a parsing error in a partially generated AST
391    #[cfg(feature = "tolerant-ast")]
392    Error(AstExprErrorKind),
393}
394
395/// Serde JSON structure for an extension function call in the EST format
396#[serde_as]
397#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
398#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
399#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
400pub struct ExtFuncCall {
401    /// maps the name of the function to a JSON list/array of the arguments.
402    /// Note that for method calls, the method receiver is the first argument.
403    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
404    /// second argument is `b`.
405    ///
406    /// INVARIANT: This map should always have exactly one k-v pair (not more or
407    /// less), but we make it a map in order to get the correct JSON structure
408    /// we want.
409    #[serde(flatten)]
410    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
411    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
412    call: HashMap<SmolStr, Vec<Expr>>,
413}
414
/// Stateless builder that constructs EST [`Expr`] values.
#[derive(Debug, Clone)]
pub struct Builder;
418
419impl ExprBuilder for Builder {
420    type Expr = Expr;
421
422    type Data = ();
423    #[cfg(feature = "tolerant-ast")]
424    type ErrorType = Infallible;
425
426    fn with_data(_data: Self::Data) -> Self {
427        Self
428    }
429
430    fn with_maybe_source_loc(self, _: Option<&Loc>) -> Self {
431        self
432    }
433
434    fn loc(&self) -> Option<&Loc> {
435        None
436    }
437
438    fn data(&self) -> &Self::Data {
439        &()
440    }
441
442    /// literal
443    fn val(self, lit: impl Into<ast::Literal>) -> Expr {
444        Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::from_lit(lit.into())))
445    }
446
447    /// principal, action, resource, context
448    fn var(self, var: ast::Var) -> Expr {
449        Expr::ExprNoExt(ExprNoExt::Var(var))
450    }
451
452    /// Template slots
453    fn slot(self, slot: ast::SlotId) -> Expr {
454        Expr::ExprNoExt(ExprNoExt::Slot(slot))
455    }
456
457    /// An extension call with one arg, which is the name of the unknown
458    fn unknown(self, u: ast::Unknown) -> Expr {
459        Expr::ExtFuncCall(ExtFuncCall {
460            call: HashMap::from([("unknown".to_smolstr(), vec![Builder::new().val(u.name)])]),
461        })
462    }
463
464    /// `!`
465    fn not(self, e: Expr) -> Expr {
466        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
467    }
468
469    /// `-`
470    fn neg(self, e: Expr) -> Expr {
471        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
472    }
473
474    /// `==`
475    fn is_eq(self, left: Expr, right: Expr) -> Expr {
476        Expr::ExprNoExt(ExprNoExt::Eq {
477            left: Arc::new(left),
478            right: Arc::new(right),
479        })
480    }
481
482    /// `!=`
483    fn noteq(self, left: Expr, right: Expr) -> Expr {
484        Expr::ExprNoExt(ExprNoExt::NotEq {
485            left: Arc::new(left),
486            right: Arc::new(right),
487        })
488    }
489
490    /// `in`
491    fn is_in(self, left: Expr, right: Expr) -> Expr {
492        Expr::ExprNoExt(ExprNoExt::In {
493            left: Arc::new(left),
494            right: Arc::new(right),
495        })
496    }
497
498    /// `<`
499    fn less(self, left: Expr, right: Expr) -> Expr {
500        Expr::ExprNoExt(ExprNoExt::Less {
501            left: Arc::new(left),
502            right: Arc::new(right),
503        })
504    }
505
506    /// `<=`
507    fn lesseq(self, left: Expr, right: Expr) -> Expr {
508        Expr::ExprNoExt(ExprNoExt::LessEq {
509            left: Arc::new(left),
510            right: Arc::new(right),
511        })
512    }
513
514    /// `>`
515    fn greater(self, left: Expr, right: Expr) -> Expr {
516        Expr::ExprNoExt(ExprNoExt::Greater {
517            left: Arc::new(left),
518            right: Arc::new(right),
519        })
520    }
521
522    /// `>=`
523    fn greatereq(self, left: Expr, right: Expr) -> Expr {
524        Expr::ExprNoExt(ExprNoExt::GreaterEq {
525            left: Arc::new(left),
526            right: Arc::new(right),
527        })
528    }
529
530    /// `&&`
531    fn and(self, left: Expr, right: Expr) -> Expr {
532        Expr::ExprNoExt(ExprNoExt::And {
533            left: Arc::new(left),
534            right: Arc::new(right),
535        })
536    }
537
538    /// `||`
539    fn or(self, left: Expr, right: Expr) -> Expr {
540        Expr::ExprNoExt(ExprNoExt::Or {
541            left: Arc::new(left),
542            right: Arc::new(right),
543        })
544    }
545
546    /// `+`
547    fn add(self, left: Expr, right: Expr) -> Expr {
548        Expr::ExprNoExt(ExprNoExt::Add {
549            left: Arc::new(left),
550            right: Arc::new(right),
551        })
552    }
553
554    /// `-`
555    fn sub(self, left: Expr, right: Expr) -> Expr {
556        Expr::ExprNoExt(ExprNoExt::Sub {
557            left: Arc::new(left),
558            right: Arc::new(right),
559        })
560    }
561
562    /// `*`
563    fn mul(self, left: Expr, right: Expr) -> Expr {
564        Expr::ExprNoExt(ExprNoExt::Mul {
565            left: Arc::new(left),
566            right: Arc::new(right),
567        })
568    }
569
570    /// `left.contains(right)`
571    fn contains(self, left: Expr, right: Expr) -> Expr {
572        Expr::ExprNoExt(ExprNoExt::Contains {
573            left: Arc::new(left),
574            right: Arc::new(right),
575        })
576    }
577
578    /// `left.containsAll(right)`
579    fn contains_all(self, left: Expr, right: Expr) -> Expr {
580        Expr::ExprNoExt(ExprNoExt::ContainsAll {
581            left: Arc::new(left),
582            right: Arc::new(right),
583        })
584    }
585
586    /// `left.containsAny(right)`
587    fn contains_any(self, left: Expr, right: Expr) -> Expr {
588        Expr::ExprNoExt(ExprNoExt::ContainsAny {
589            left: Arc::new(left),
590            right: Arc::new(right),
591        })
592    }
593
594    /// `arg.isEmpty()`
595    fn is_empty(self, expr: Expr) -> Expr {
596        Expr::ExprNoExt(ExprNoExt::IsEmpty {
597            arg: Arc::new(expr),
598        })
599    }
600
601    /// `left.getTag(right)`
602    fn get_tag(self, expr: Expr, tag: Expr) -> Expr {
603        Expr::ExprNoExt(ExprNoExt::GetTag {
604            left: Arc::new(expr),
605            right: Arc::new(tag),
606        })
607    }
608
609    /// `left.hasTag(right)`
610    fn has_tag(self, expr: Expr, tag: Expr) -> Expr {
611        Expr::ExprNoExt(ExprNoExt::HasTag {
612            left: Arc::new(expr),
613            right: Arc::new(tag),
614        })
615    }
616
617    /// `left.attr`
618    fn get_attr(self, expr: Expr, attr: SmolStr) -> Expr {
619        Expr::ExprNoExt(ExprNoExt::GetAttr {
620            left: Arc::new(expr),
621            attr,
622        })
623    }
624
625    /// `left has attr`
626    fn has_attr(self, expr: Expr, attr: SmolStr) -> Expr {
627        Expr::ExprNoExt(ExprNoExt::HasAttr {
628            left: Arc::new(expr),
629            attr,
630        })
631    }
632
633    /// `left like pattern`
634    fn like(self, expr: Expr, pattern: ast::Pattern) -> Expr {
635        Expr::ExprNoExt(ExprNoExt::Like {
636            left: Arc::new(expr),
637            pattern: pattern.into(),
638        })
639    }
640
641    /// `left is entity_type`
642    fn is_entity_type(self, left: Expr, entity_type: ast::EntityType) -> Expr {
643        Expr::ExprNoExt(ExprNoExt::Is {
644            left: Arc::new(left),
645            entity_type: entity_type.to_smolstr(),
646            in_expr: None,
647        })
648    }
649
650    /// `left is entity_type in entity`
651    fn is_in_entity_type(self, left: Expr, entity_type: ast::EntityType, entity: Expr) -> Expr {
652        Expr::ExprNoExt(ExprNoExt::Is {
653            left: Arc::new(left),
654            entity_type: entity_type.to_smolstr(),
655            in_expr: Some(Arc::new(entity)),
656        })
657    }
658
659    /// `if cond_expr then then_expr else else_expr`
660    fn ite(self, cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Expr {
661        Expr::ExprNoExt(ExprNoExt::If {
662            cond_expr: Arc::new(cond_expr),
663            then_expr: Arc::new(then_expr),
664            else_expr: Arc::new(else_expr),
665        })
666    }
667
668    /// e.g. [1+2, !(context has department)]
669    fn set(self, elements: impl IntoIterator<Item = Expr>) -> Expr {
670        Expr::ExprNoExt(ExprNoExt::Set(elements.into_iter().collect()))
671    }
672
673    /// e.g. {foo: 1+2, bar: !(context has department)}
674    fn record(
675        self,
676        map: impl IntoIterator<Item = (SmolStr, Expr)>,
677    ) -> Result<Expr, ast::ExpressionConstructionError> {
678        let mut dedup_map = BTreeMap::new();
679        for (k, v) in map {
680            match dedup_map.entry(k) {
681                btree_map::Entry::Occupied(oentry) => {
682                    return Err(ast::expression_construction_errors::DuplicateKeyError {
683                        key: oentry.key().clone(),
684                        context: "in record literal",
685                    }
686                    .into());
687                }
688                btree_map::Entry::Vacant(ventry) => {
689                    ventry.insert(v);
690                }
691            }
692        }
693        Ok(Expr::ExprNoExt(ExprNoExt::Record(dedup_map)))
694    }
695
696    /// extension function call, including method calls
697    fn call_extension_fn(self, fn_name: ast::Name, args: impl IntoIterator<Item = Expr>) -> Expr {
698        Expr::ExtFuncCall(ExtFuncCall {
699            call: HashMap::from([(fn_name.to_smolstr(), args.into_iter().collect())]),
700        })
701    }
702
703    #[cfg(feature = "tolerant-ast")]
704    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
705        Ok(Expr::ExprNoExt(ExprNoExt::Error(
706            AstExprErrorKind::InvalidExpr(parse_errors.to_string()),
707        )))
708    }
709}
710
711impl Expr {
712    /// Substitute entity literals
713    pub fn sub_entity_literals(
714        self,
715        mapping: &BTreeMap<EntityUID, EntityUID>,
716    ) -> Result<Self, JsonDeserializationError> {
717        match self.clone() {
718            Expr::ExprNoExt(e) => match e {
719                ExprNoExt::Value(v) => Ok(Expr::ExprNoExt(ExprNoExt::Value(
720                    v.sub_entity_literals(mapping)?,
721                ))),
722                ExprNoExt::Var(_) => Ok(self),
723                ExprNoExt::Slot(_) => Ok(self),
724                ExprNoExt::Not { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Not {
725                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
726                })),
727                ExprNoExt::Neg { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Neg {
728                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
729                })),
730                ExprNoExt::Eq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Eq {
731                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
732                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
733                })),
734                ExprNoExt::NotEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::NotEq {
735                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
736                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
737                })),
738                ExprNoExt::In { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::In {
739                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
740                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
741                })),
742                ExprNoExt::Less { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Less {
743                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
744                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
745                })),
746                ExprNoExt::LessEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::LessEq {
747                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
748                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
749                })),
750                ExprNoExt::Greater { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Greater {
751                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
752                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
753                })),
754                ExprNoExt::GreaterEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GreaterEq {
755                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
756                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
757                })),
758                ExprNoExt::And { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::And {
759                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
760                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
761                })),
762                ExprNoExt::Or { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Or {
763                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
764                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
765                })),
766                ExprNoExt::Add { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Add {
767                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
768                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
769                })),
770                ExprNoExt::Sub { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Sub {
771                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
772                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
773                })),
774                ExprNoExt::Mul { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Mul {
775                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
776                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
777                })),
778                ExprNoExt::Contains { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Contains {
779                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
780                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
781                })),
782                ExprNoExt::ContainsAll { left, right } => {
783                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAll {
784                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
785                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
786                    }))
787                }
788                ExprNoExt::ContainsAny { left, right } => {
789                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAny {
790                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
791                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
792                    }))
793                }
794                ExprNoExt::IsEmpty { arg } => Ok(Expr::ExprNoExt(ExprNoExt::IsEmpty {
795                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
796                })),
797                ExprNoExt::GetTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GetTag {
798                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
799                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
800                })),
801                ExprNoExt::HasTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::HasTag {
802                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
803                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
804                })),
805                ExprNoExt::GetAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::GetAttr {
806                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
807                    attr,
808                })),
809                ExprNoExt::HasAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::HasAttr {
810                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
811                    attr,
812                })),
813                ExprNoExt::Like { left, pattern } => Ok(Expr::ExprNoExt(ExprNoExt::Like {
814                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
815                    pattern,
816                })),
817                ExprNoExt::Is {
818                    left,
819                    entity_type,
820                    in_expr,
821                } => match in_expr {
822                    Some(in_expr) => Ok(Expr::ExprNoExt(ExprNoExt::Is {
823                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
824                        entity_type,
825                        in_expr: Some(Arc::new(
826                            Arc::unwrap_or_clone(in_expr).sub_entity_literals(mapping)?,
827                        )),
828                    })),
829                    None => Ok(Expr::ExprNoExt(ExprNoExt::Is {
830                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
831                        entity_type,
832                        in_expr: None,
833                    })),
834                },
835                ExprNoExt::If {
836                    cond_expr,
837                    then_expr,
838                    else_expr,
839                } => Ok(Expr::ExprNoExt(ExprNoExt::If {
840                    cond_expr: Arc::new(
841                        Arc::unwrap_or_clone(cond_expr).sub_entity_literals(mapping)?,
842                    ),
843                    then_expr: Arc::new(
844                        Arc::unwrap_or_clone(then_expr).sub_entity_literals(mapping)?,
845                    ),
846                    else_expr: Arc::new(
847                        Arc::unwrap_or_clone(else_expr).sub_entity_literals(mapping)?,
848                    ),
849                })),
850                ExprNoExt::Set(v) => {
851                    let mut new_v = vec![];
852                    for e in v {
853                        new_v.push(e.sub_entity_literals(mapping)?);
854                    }
855                    Ok(Expr::ExprNoExt(ExprNoExt::Set(new_v)))
856                }
857                ExprNoExt::Record(m) => {
858                    let mut new_m = BTreeMap::new();
859                    for (k, v) in m {
860                        new_m.insert(k, v.sub_entity_literals(mapping)?);
861                    }
862                    Ok(Expr::ExprNoExt(ExprNoExt::Record(new_m)))
863                }
864                #[cfg(feature = "tolerant-ast")]
865                ExprNoExt::Error(_) => Err(JsonDeserializationError::ASTErrorNode),
866            },
867            Expr::ExtFuncCall(e_fn_call) => {
868                let mut new_m = HashMap::new();
869                for (k, v) in e_fn_call.call {
870                    let mut new_v = vec![];
871                    for e in v {
872                        new_v.push(e.sub_entity_literals(mapping)?);
873                    }
874                    new_m.insert(k, new_v);
875                }
876                Ok(Expr::ExtFuncCall(ExtFuncCall { call: new_m }))
877            }
878        }
879    }
880}
881
impl Expr {
    /// Attempt to convert this `est::Expr` into an `ast::Expr`.
    ///
    /// The conversion is recursive: each child expression is converted with
    /// its own call to `try_into_ast`, and the first child that fails aborts
    /// the whole conversion (via `?`).
    ///
    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
    ///
    /// # Errors
    ///
    /// - any error produced by `CedarValueJson::into_expr` for a JSON value;
    /// - `FromJsonError::InvalidEntityType` if the entity type of an `is`
    ///   expression does not parse as a normalized entity type name;
    /// - `FromJsonError::MissingOperator` / `FromJsonError::MultipleOperators`
    ///   if an extension-call object has zero / more than one key;
    /// - a parse-escape error if the extension function name does not parse
    ///   as an `ast::Name`, and `FromJsonError::UnknownExtensionFunction` if
    ///   it parses but is not a known extension function;
    /// - `FromJsonError::ASTErrorNode` for error nodes (only with the
    ///   `tolerant-ast` feature).
    pub fn try_into_ast(self, id: ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
        match self {
            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
                .map(Into::into)
                .map_err(Into::into),
            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
                Ok(ast::Expr::not(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
            }
            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
                Ok(ast::Expr::neg(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
            }
            // Binary operators: the left operand is converted (and can fail)
            // before the right one; `id` is cloned for all but the last use.
            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::IsEmpty { arg }) => Ok(ast::Expr::is_empty(
                Arc::unwrap_or_clone(arg).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::GetTag { left, right }) => Ok(ast::Expr::get_tag(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::HasTag { left, right }) => Ok(ast::Expr::has_tag(
                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(right).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => Ok(ast::Expr::get_attr(
                Arc::unwrap_or_clone(left).try_into_ast(id)?,
                attr,
            )),
            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => Ok(ast::Expr::has_attr(
                Arc::unwrap_or_clone(left).try_into_ast(id)?,
                attr,
            )),
            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => Ok(ast::Expr::like(
                Arc::unwrap_or_clone(left).try_into_ast(id)?,
                crate::ast::Pattern::from(pattern.as_slice()),
            )),
            // The entity type name is validated before `left` is converted, so
            // an invalid type name is the error reported even if `left` is
            // also invalid.
            Expr::ExprNoExt(ExprNoExt::Is {
                left,
                entity_type,
                in_expr,
            }) => ast::EntityType::from_normalized_str(entity_type.as_str())
                .map_err(FromJsonError::InvalidEntityType)
                .and_then(|entity_type_name| {
                    let left: ast::Expr = Arc::unwrap_or_clone(left).try_into_ast(id.clone())?;
                    // `left` is cloned because it may also be needed for the
                    // `in` half of the conjunction below.
                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
                    match in_expr {
                        // The AST doesn't have an `... is ... in ..` node, so
                        // we represent it as a conjunction of `is` and `in`.
                        Some(in_expr) => Ok(ast::Expr::and(
                            is_expr,
                            ast::Expr::is_in(left, Arc::unwrap_or_clone(in_expr).try_into_ast(id)?),
                        )),
                        None => Ok(is_expr),
                    }
                }),
            Expr::ExprNoExt(ExprNoExt::If {
                cond_expr,
                then_expr,
                else_expr,
            }) => Ok(ast::Expr::ite(
                Arc::unwrap_or_clone(cond_expr).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(then_expr).try_into_ast(id.clone())?,
                Arc::unwrap_or_clone(else_expr).try_into_ast(id)?,
            )),
            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
                elements
                    .into_iter()
                    .map(|el| el.try_into_ast(id.clone()))
                    .collect::<Result<Vec<_>, FromJsonError>>()?,
            )),
            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
                // PANIC SAFETY: can't have duplicate keys here: the keys come from the record's
                // map and are collected into a `HashMap`, so each key appears at most once
                #[allow(clippy::expect_used)]
                Ok(ast::Expr::record(
                    map.into_iter()
                        .map(|(k, v)| Ok((k, v.try_into_ast(id.clone())?)))
                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
                )
                .expect("can't have duplicate keys here because the input was already a HashMap"))
            }
            // An extension-call object must have exactly one key: the name of
            // the function being called. Its value is the argument list.
            Expr::ExtFuncCall(ExtFuncCall { call }) => {
                match call.len() {
                    0 => Err(FromJsonError::MissingOperator),
                    1 => {
                        // PANIC SAFETY checked that `call.len() == 1`
                        #[allow(clippy::expect_used)]
                        let (fn_name, args) = call
                            .into_iter()
                            .next()
                            .expect("already checked that len was 1");
                        let fn_name: ast::Name = fn_name.parse().map_err(|errs| {
                            JsonDeserializationError::parse_escape(
                                EscapeKind::Extension,
                                fn_name,
                                errs,
                            )
                        })?;
                        // A syntactically valid name is still rejected unless
                        // it names a known extension function.
                        if !cst_to_ast::is_known_extension_func_name(&fn_name) {
                            return Err(FromJsonError::UnknownExtensionFunction(fn_name));
                        }
                        Ok(ast::Expr::call_extension_fn(
                            fn_name,
                            args.into_iter()
                                .map(|arg| arg.try_into_ast(id.clone()))
                                .collect::<Result<_, _>>()?,
                        ))
                    }
                    _ => Err(FromJsonError::MultipleOperators {
                        ops: call.into_keys().collect(),
                    }),
                }
            }
            #[cfg(feature = "tolerant-ast")]
            Expr::ExprNoExt(ExprNoExt::Error(_)) => Err(FromJsonError::ASTErrorNode),
        }
    }
}
1064
1065impl From<ast::Literal> for Expr {
1066    fn from(lit: ast::Literal) -> Expr {
1067        Builder::new().val(lit)
1068    }
1069}
1070
1071impl From<ast::Var> for Expr {
1072    fn from(var: ast::Var) -> Expr {
1073        Builder::new().var(var)
1074    }
1075}
1076
1077impl From<ast::SlotId> for Expr {
1078    fn from(slot: ast::SlotId) -> Expr {
1079        Builder::new().slot(slot)
1080    }
1081}
1082
1083impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
1084    type Error = ParseErrors;
1085    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
1086        e.to_expr::<Builder>()
1087    }
1088}
1089
1090impl std::fmt::Display for Expr {
1091    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1092        match self {
1093            Self::ExprNoExt(e) => write!(f, "{e}"),
1094            Self::ExtFuncCall(e) => write!(f, "{e}"),
1095        }
1096    }
1097}
1098
1099impl BoundedDisplay for Expr {
1100    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1101        match self {
1102            Self::ExprNoExt(e) => BoundedDisplay::fmt(e, f, n),
1103            Self::ExtFuncCall(e) => BoundedDisplay::fmt(e, f, n),
1104        }
1105    }
1106}
1107
1108fn display_cedarvaluejson(
1109    f: &mut impl std::fmt::Write,
1110    v: &CedarValueJson,
1111    n: Option<usize>,
1112) -> std::fmt::Result {
1113    match v {
1114        // Add parentheses around negative numeric literals otherwise
1115        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1116        CedarValueJson::Long(i) if *i < 0 => write!(f, "({i})"),
1117        CedarValueJson::Long(i) => write!(f, "{i}"),
1118        CedarValueJson::Bool(b) => write!(f, "{b}"),
1119        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1120        CedarValueJson::EntityEscape { __entity } => {
1121            match ast::EntityUID::try_from(__entity.clone()) {
1122                Ok(euid) => write!(f, "{euid}"),
1123                Err(e) => write!(f, "(invalid entity uid: {})", e),
1124            }
1125        }
1126        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1127        CedarValueJson::ExtnEscape {
1128            __extn: FnAndArg { ext_fn, arg },
1129        } => {
1130            // search for the name and callstyle
1131            let style = Extensions::all_available().all_funcs().find_map(|f| {
1132                if &f.name().to_smolstr() == ext_fn {
1133                    Some(f.style())
1134                } else {
1135                    None
1136                }
1137            });
1138            match style {
1139                Some(ast::CallStyle::MethodStyle) => {
1140                    display_cedarvaluejson(f, arg, n)?;
1141                    write!(f, ".{ext_fn}()")?;
1142                    Ok(())
1143                }
1144                Some(ast::CallStyle::FunctionStyle) | None => {
1145                    write!(f, "{ext_fn}(")?;
1146                    display_cedarvaluejson(f, arg, n)?;
1147                    write!(f, ")")?;
1148                    Ok(())
1149                }
1150            }
1151        }
1152        CedarValueJson::Set(v) => {
1153            match n {
1154                Some(n) if v.len() > n => {
1155                    // truncate to n elements
1156                    write!(f, "[")?;
1157                    for val in v.iter().take(n) {
1158                        display_cedarvaluejson(f, val, Some(n))?;
1159                        write!(f, ", ")?;
1160                    }
1161                    write!(f, "..]")?;
1162                    Ok(())
1163                }
1164                _ => {
1165                    // no truncation
1166                    write!(f, "[")?;
1167                    for (i, val) in v.iter().enumerate() {
1168                        display_cedarvaluejson(f, val, n)?;
1169                        if i < v.len() - 1 {
1170                            write!(f, ", ")?;
1171                        }
1172                    }
1173                    write!(f, "]")?;
1174                    Ok(())
1175                }
1176            }
1177        }
1178        CedarValueJson::Record(r) => {
1179            match n {
1180                Some(n) if r.len() > n => {
1181                    // truncate to n key-value pairs
1182                    write!(f, "{{")?;
1183                    for (k, v) in r.iter().take(n) {
1184                        write!(f, "\"{}\": ", k.escape_debug())?;
1185                        display_cedarvaluejson(f, v, Some(n))?;
1186                        write!(f, ", ")?;
1187                    }
1188                    write!(f, "..}}")?;
1189                    Ok(())
1190                }
1191                _ => {
1192                    // no truncation
1193                    write!(f, "{{")?;
1194                    for (i, (k, v)) in r.iter().enumerate() {
1195                        write!(f, "\"{}\": ", k.escape_debug())?;
1196                        display_cedarvaluejson(f, v, n)?;
1197                        if i < r.len() - 1 {
1198                            write!(f, ", ")?;
1199                        }
1200                    }
1201                    write!(f, "}}")?;
1202                    Ok(())
1203                }
1204            }
1205        }
1206        CedarValueJson::Null => {
1207            write!(f, "null")?;
1208            Ok(())
1209        }
1210    }
1211}
1212
1213impl std::fmt::Display for ExprNoExt {
1214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1215        BoundedDisplay::fmt_unbounded(self, f)
1216    }
1217}
1218
1219impl BoundedDisplay for ExprNoExt {
1220    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1221        match &self {
1222            ExprNoExt::Value(v) => display_cedarvaluejson(f, v, n),
1223            ExprNoExt::Var(v) => write!(f, "{v}"),
1224            ExprNoExt::Slot(id) => write!(f, "{id}"),
1225            ExprNoExt::Not { arg } => {
1226                write!(f, "!")?;
1227                maybe_with_parens(f, arg, n)
1228            }
1229            ExprNoExt::Neg { arg } => {
1230                // Always add parentheses instead of calling
1231                // `maybe_with_parens`.
1232                // This makes sure that we always get a negation operation back
1233                // (as opposed to e.g., a negative number) when parsing the
1234                // printed form, thus preserving the round-tripping property.
1235                write!(f, "-({arg})")
1236            }
1237            ExprNoExt::Eq { left, right } => {
1238                maybe_with_parens(f, left, n)?;
1239                write!(f, " == ")?;
1240                maybe_with_parens(f, right, n)
1241            }
1242            ExprNoExt::NotEq { left, right } => {
1243                maybe_with_parens(f, left, n)?;
1244                write!(f, " != ")?;
1245                maybe_with_parens(f, right, n)
1246            }
1247            ExprNoExt::In { left, right } => {
1248                maybe_with_parens(f, left, n)?;
1249                write!(f, " in ")?;
1250                maybe_with_parens(f, right, n)
1251            }
1252            ExprNoExt::Less { left, right } => {
1253                maybe_with_parens(f, left, n)?;
1254                write!(f, " < ")?;
1255                maybe_with_parens(f, right, n)
1256            }
1257            ExprNoExt::LessEq { left, right } => {
1258                maybe_with_parens(f, left, n)?;
1259                write!(f, " <= ")?;
1260                maybe_with_parens(f, right, n)
1261            }
1262            ExprNoExt::Greater { left, right } => {
1263                maybe_with_parens(f, left, n)?;
1264                write!(f, " > ")?;
1265                maybe_with_parens(f, right, n)
1266            }
1267            ExprNoExt::GreaterEq { left, right } => {
1268                maybe_with_parens(f, left, n)?;
1269                write!(f, " >= ")?;
1270                maybe_with_parens(f, right, n)
1271            }
1272            ExprNoExt::And { left, right } => {
1273                maybe_with_parens(f, left, n)?;
1274                write!(f, " && ")?;
1275                maybe_with_parens(f, right, n)
1276            }
1277            ExprNoExt::Or { left, right } => {
1278                maybe_with_parens(f, left, n)?;
1279                write!(f, " || ")?;
1280                maybe_with_parens(f, right, n)
1281            }
1282            ExprNoExt::Add { left, right } => {
1283                maybe_with_parens(f, left, n)?;
1284                write!(f, " + ")?;
1285                maybe_with_parens(f, right, n)
1286            }
1287            ExprNoExt::Sub { left, right } => {
1288                maybe_with_parens(f, left, n)?;
1289                write!(f, " - ")?;
1290                maybe_with_parens(f, right, n)
1291            }
1292            ExprNoExt::Mul { left, right } => {
1293                maybe_with_parens(f, left, n)?;
1294                write!(f, " * ")?;
1295                maybe_with_parens(f, right, n)
1296            }
1297            ExprNoExt::Contains { left, right } => {
1298                maybe_with_parens(f, left, n)?;
1299                write!(f, ".contains({right})")
1300            }
1301            ExprNoExt::ContainsAll { left, right } => {
1302                maybe_with_parens(f, left, n)?;
1303                write!(f, ".containsAll({right})")
1304            }
1305            ExprNoExt::ContainsAny { left, right } => {
1306                maybe_with_parens(f, left, n)?;
1307                write!(f, ".containsAny({right})")
1308            }
1309            ExprNoExt::IsEmpty { arg } => {
1310                maybe_with_parens(f, arg, n)?;
1311                write!(f, ".isEmpty()")
1312            }
1313            ExprNoExt::GetTag { left, right } => {
1314                maybe_with_parens(f, left, n)?;
1315                write!(f, ".getTag({right})")
1316            }
1317            ExprNoExt::HasTag { left, right } => {
1318                maybe_with_parens(f, left, n)?;
1319                write!(f, ".hasTag({right})")
1320            }
1321            ExprNoExt::GetAttr { left, attr } => {
1322                maybe_with_parens(f, left, n)?;
1323                write!(f, "[\"{}\"]", attr.escape_debug())
1324            }
1325            ExprNoExt::HasAttr { left, attr } => {
1326                maybe_with_parens(f, left, n)?;
1327                write!(f, " has \"{}\"", attr.escape_debug())
1328            }
1329            ExprNoExt::Like { left, pattern } => {
1330                maybe_with_parens(f, left, n)?;
1331                write!(
1332                    f,
1333                    " like \"{}\"",
1334                    crate::ast::Pattern::from(pattern.as_slice())
1335                )
1336            }
1337            ExprNoExt::Is {
1338                left,
1339                entity_type,
1340                in_expr,
1341            } => {
1342                maybe_with_parens(f, left, n)?;
1343                write!(f, " is {entity_type}")?;
1344                match in_expr {
1345                    Some(in_expr) => {
1346                        write!(f, " in ")?;
1347                        maybe_with_parens(f, in_expr, n)
1348                    }
1349                    None => Ok(()),
1350                }
1351            }
1352            ExprNoExt::If {
1353                cond_expr,
1354                then_expr,
1355                else_expr,
1356            } => {
1357                write!(f, "if ")?;
1358                maybe_with_parens(f, cond_expr, n)?;
1359                write!(f, " then ")?;
1360                maybe_with_parens(f, then_expr, n)?;
1361                write!(f, " else ")?;
1362                maybe_with_parens(f, else_expr, n)
1363            }
1364            ExprNoExt::Set(v) => {
1365                match n {
1366                    Some(n) if v.len() > n => {
1367                        // truncate to n elements
1368                        write!(f, "[")?;
1369                        for element in v.iter().take(n) {
1370                            BoundedDisplay::fmt(element, f, Some(n))?;
1371                            write!(f, ", ")?;
1372                        }
1373                        write!(f, "..]")?;
1374                        Ok(())
1375                    }
1376                    _ => {
1377                        // no truncation
1378                        write!(f, "[")?;
1379                        for (i, element) in v.iter().enumerate() {
1380                            BoundedDisplay::fmt(element, f, n)?;
1381                            if i < v.len() - 1 {
1382                                write!(f, ", ")?;
1383                            }
1384                        }
1385                        write!(f, "]")?;
1386                        Ok(())
1387                    }
1388                }
1389            }
1390            ExprNoExt::Record(m) => {
1391                match n {
1392                    Some(n) if m.len() > n => {
1393                        // truncate to n key-value pairs
1394                        write!(f, "{{")?;
1395                        for (k, v) in m.iter().take(n) {
1396                            write!(f, "\"{}\": ", k.escape_debug())?;
1397                            BoundedDisplay::fmt(v, f, Some(n))?;
1398                            write!(f, ", ")?;
1399                        }
1400                        write!(f, "..}}")?;
1401                        Ok(())
1402                    }
1403                    _ => {
1404                        // no truncation
1405                        write!(f, "{{")?;
1406                        for (i, (k, v)) in m.iter().enumerate() {
1407                            write!(f, "\"{}\": ", k.escape_debug())?;
1408                            BoundedDisplay::fmt(v, f, n)?;
1409                            if i < m.len() - 1 {
1410                                write!(f, ", ")?;
1411                            }
1412                        }
1413                        write!(f, "}}")?;
1414                        Ok(())
1415                    }
1416                }
1417            }
1418            #[cfg(feature = "tolerant-ast")]
1419            ExprNoExt::Error(e) => {
1420                write!(f, "{e}")?;
1421                Ok(())
1422            }
1423        }
1424    }
1425}
1426
1427impl std::fmt::Display for ExtFuncCall {
1428    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1429        BoundedDisplay::fmt_unbounded(self, f)
1430    }
1431}
1432
1433impl BoundedDisplay for ExtFuncCall {
1434    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1435        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1436        #[allow(clippy::unreachable)]
1437        let Some((fn_name, args)) = self.call.iter().next() else {
1438            unreachable!("invariant violated: empty ExtFuncCall")
1439        };
1440        // search for the name and callstyle
1441        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1442            if &ext_fn.name().to_smolstr() == fn_name {
1443                Some(ext_fn.style())
1444            } else {
1445                None
1446            }
1447        });
1448        match (style, args.iter().next()) {
1449            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1450                maybe_with_parens(f, receiver, n)?;
1451                write!(f, ".{}({})", fn_name, args.iter().skip(1).join(", "))
1452            }
1453            (_, _) => {
1454                write!(f, "{}({})", fn_name, args.iter().join(", "))
1455            }
1456        }
1457    }
1458}
1459
1460/// returns the `BoundedDisplay` representation of the Expr, adding parens around
1461/// the entire string if necessary.
1462/// E.g., won't add parens for constants or `principal` etc, but will for things
1463/// like `(2 < 5)`.
1464/// When in doubt, add the parens.
1465fn maybe_with_parens(
1466    f: &mut impl std::fmt::Write,
1467    expr: &Expr,
1468    n: Option<usize>,
1469) -> std::fmt::Result {
1470    match expr {
1471        Expr::ExprNoExt(ExprNoExt::Set(_)) |
1472        Expr::ExprNoExt(ExprNoExt::Record(_)) |
1473        Expr::ExprNoExt(ExprNoExt::Value(_)) |
1474        Expr::ExprNoExt(ExprNoExt::Var(_)) |
1475        Expr::ExprNoExt(ExprNoExt::Slot(_)) => BoundedDisplay::fmt(expr, f, n),
1476
1477        // we want parens here because things like parse((!x).y)
1478        // would be printed into !x.y which has a different meaning
1479        Expr::ExprNoExt(ExprNoExt::Not { .. }) |
1480        // we want parens here because things like parse((-x).y)
1481        // would be printed into -x.y which has a different meaning
1482        Expr::ExprNoExt(ExprNoExt::Neg { .. })  |
1483        Expr::ExprNoExt(ExprNoExt::Eq { .. }) |
1484        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) |
1485        Expr::ExprNoExt(ExprNoExt::In { .. }) |
1486        Expr::ExprNoExt(ExprNoExt::Less { .. }) |
1487        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) |
1488        Expr::ExprNoExt(ExprNoExt::Greater { .. }) |
1489        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) |
1490        Expr::ExprNoExt(ExprNoExt::And { .. }) |
1491        Expr::ExprNoExt(ExprNoExt::Or { .. }) |
1492        Expr::ExprNoExt(ExprNoExt::Add { .. }) |
1493        Expr::ExprNoExt(ExprNoExt::Sub { .. }) |
1494        Expr::ExprNoExt(ExprNoExt::Mul { .. }) |
1495        Expr::ExprNoExt(ExprNoExt::Contains { .. }) |
1496        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) |
1497        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) |
1498        Expr::ExprNoExt(ExprNoExt::IsEmpty { .. }) |
1499        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) |
1500        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) |
1501        Expr::ExprNoExt(ExprNoExt::GetTag { .. }) |
1502        Expr::ExprNoExt(ExprNoExt::HasTag { .. }) |
1503        Expr::ExprNoExt(ExprNoExt::Like { .. }) |
1504        Expr::ExprNoExt(ExprNoExt::Is { .. }) |
1505        Expr::ExprNoExt(ExprNoExt::If { .. }) |
1506        Expr::ExtFuncCall { .. } => {
1507            write!(f, "(")?;
1508            BoundedDisplay::fmt(expr, f, n)?;
1509            write!(f, ")")?;
1510            Ok(())
1511        },
1512        #[cfg(feature = "tolerant-ast")]
1513        Expr::ExprNoExt(ExprNoExt::Error { .. }) => {
1514            write!(f, "(")?;
1515            BoundedDisplay::fmt(expr, f, n)?;
1516            write!(f, ")")?;
1517            Ok(())
1518        }
1519    }
1520}
1521
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::{
        err::{ParseError, ToASTErrorKind},
        parse_expr,
    };

    use super::*;
    use ast::BoundedToString;
    use cool_asserts::assert_matches;

    /// Converting a CST expression whose name ends in the reserved identifier
    /// `else` into an EST must fail with exactly one `ReservedIdentifier` error.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        let cst = crate::parser::text_to_cst::parse_expr("some_long_str::else").unwrap();
        assert_matches!(Expr::try_from(&cst), Err(errs) => {
            // `assert_eq!` (rather than `assert!(.. == ..)`) reports the actual
            // number of errors if this ever fails.
            assert_eq!(errs.len(), 1);
            assert_matches!(&errs[0],
                ParseError::ToAST(to_ast_error) => {
                    assert_matches!(to_ast_error.kind(), ToASTErrorKind::ReservedIdentifier(s) => {
                        assert_eq!(s.to_string(), "else");
                    });
                }
            );
        });
    }

    /// Checks every rendering of `expr`.
    ///
    /// `Display` and an unbounded `BoundedToString` must both produce
    /// `unbounded`. For each `(bound, expected)` pair in `bounded`,
    /// `BoundedToString` with `Some(bound)` must produce `expected`. On failure,
    /// the panic message names the offending bound.
    fn assert_renders(expr: &Expr, unbounded: &str, bounded: &[(usize, &str)]) {
        assert_eq!(format!("{expr}"), unbounded);
        assert_eq!(BoundedToString::to_string(expr, None), unbounded);
        for &(bound, expected) in bounded {
            assert_eq!(
                BoundedToString::to_string(expr, Some(bound)),
                expected,
                "unexpected rendering with bound {bound}"
            );
        }
    }

    /// Bounded display truncates sets and records (recursively) to the given
    /// number of elements/pairs, using `..` to mark elided content. Negative
    /// literals are rendered parenthesized.
    #[test]
    fn display_and_bounded_display() {
        let expr = parse_expr(r#"[100, [3, 4, 5], -20, "foo"]"#)
            .unwrap()
            .into_expr::<Builder>();
        assert_renders(
            &expr,
            r#"[100, [3, 4, 5], (-20), "foo"]"#,
            &[
                (4, r#"[100, [3, 4, 5], (-20), "foo"]"#),
                (3, r#"[100, [3, 4, 5], (-20), ..]"#),
                (2, r#"[100, [3, 4, ..], ..]"#),
                (1, r#"[100, ..]"#),
                (0, r#"[..]"#),
            ],
        );

        let expr = parse_expr(
            r#"{
            a: 12,
            b: [3, 4, true],
            c: -20,
            "hello ∞ world": "∂µß≈¥"
        }"#,
        )
        .unwrap()
        .into_expr::<Builder>();
        assert_renders(
            &expr,
            r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#,
            &[
                (
                    4,
                    r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#,
                ),
                (3, r#"{"a": 12, "b": [3, 4, true], "c": (-20), ..}"#),
                (2, r#"{"a": 12, "b": [3, 4, ..], ..}"#),
                (1, r#"{"a": 12, ..}"#),
                (0, r#"{..}"#),
            ],
        );
    }
}