cedar_policy_core/est/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::FromJsonError;
18#[cfg(feature = "tolerant-ast")]
19use crate::ast::expr_allows_errors::AstExprErrorKind;
20#[cfg(feature = "tolerant-ast")]
21use crate::ast::Infallible;
22use crate::ast::{self, BoundedDisplay, EntityUID, Name};
23use crate::entities::json::{
24    err::EscapeKind, err::JsonDeserializationError, err::JsonDeserializationErrorContext,
25    CedarValueJson, FnAndArgs,
26};
27use crate::expr_builder::ExprBuilder;
28use crate::extensions::Extensions;
29use crate::jsonvalue::JsonValueWithNoDuplicateKeys;
30use crate::parser::cst_to_ast;
31use crate::parser::err::ParseErrors;
32use crate::parser::Node;
33use crate::parser::{cst, Loc};
34use crate::FromNormalizedStr;
35use itertools::Itertools;
36use serde::{de::Visitor, Deserialize, Serialize};
37use serde_with::serde_as;
38use smol_str::{SmolStr, ToSmolStr};
39use std::collections::{btree_map, BTreeMap, HashMap};
40use std::sync::Arc;
41
/// Serde JSON structure for a Cedar expression in the EST format
///
/// Serialization is `untagged`: each variant is written as the JSON of its
/// inner value. Deserialization is implemented manually rather than derived,
/// dispatching on the single key of the incoming JSON object.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Expr {
    /// Any Cedar expression other than an extension function call.
    ExprNoExt(ExprNoExt),
    /// Extension function call, where the key is the name of an extension
    /// function or method.
    ExtFuncCall(ExtFuncCall),
}
54
55// Manual implementation of `Deserialize` is more efficient than the derived
56// implementation with `serde(untagged)`. In particular, if the key is valid for
57// `ExprNoExt` but there is a deserialization problem within the corresponding
58// value, the derived implementation would backtrack and try to deserialize as
59// `ExtFuncCall` with that key as the extension function name, but this manual
60// implementation instead eagerly errors out, taking advantage of the fact that
61// none of the keys for `ExprNoExt` are valid extension function names.
62//
63// See #1284.
64impl<'de> Deserialize<'de> for Expr {
65    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
66    where
67        D: serde::Deserializer<'de>,
68    {
69        struct ExprVisitor;
70        impl<'de> Visitor<'de> for ExprVisitor {
71            type Value = Expr;
72            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73                formatter.write_str("JSON object representing an expression")
74            }
75
76            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
77            where
78                A: serde::de::MapAccess<'de>,
79            {
80                let (k, v): (SmolStr, JsonValueWithNoDuplicateKeys) = match map.next_entry()? {
81                    None => {
82                        return Err(serde::de::Error::custom(
83                            "empty map is not a valid expression",
84                        ))
85                    }
86                    Some((k, v)) => (k, v),
87                };
88                match map.next_key()? {
89                    None => (),
90                    Some(k2) => {
91                        let k2: SmolStr = k2;
92                        return Err(serde::de::Error::custom(format!("JSON object representing an `Expr` should have only one key, but found two keys: `{k}` and `{k2}`")));
93                    }
94                };
95                if cst_to_ast::is_known_extension_func_str(&k) {
96                    // `k` is the name of an extension function or method. We assume that
97                    // no such keys are valid keys for `ExprNoExt`, so we must parse as an
98                    // `ExtFuncCall`.
99                    let obj = serde_json::json!({ k: v });
100                    let extfunccall =
101                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
102                    Ok(Expr::ExtFuncCall(extfunccall))
103                } else {
104                    // not a valid extension function or method, so we expect it
105                    // to work for `ExprNoExt`.
106                    let obj = serde_json::json!({ k: v });
107                    let exprnoext =
108                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
109                    Ok(Expr::ExprNoExt(exprnoext))
110                }
111            }
112        }
113
114        deserializer.deserialize_map(ExprVisitor)
115    }
116}
117
/// Represent an element of a pattern literal
///
/// Converting a slice of these into an AST pattern expands each `Literal`
/// into one AST character element per `char`. Converting an AST pattern
/// back yields one single-character `Literal` per AST character element.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum PatternElem {
    /// The wildcard asterisk (`*`)
    Wildcard,
    /// A string without any wildcards
    Literal(SmolStr),
}
128
129impl From<&[PatternElem]> for crate::ast::Pattern {
130    fn from(value: &[PatternElem]) -> Self {
131        let mut elems = Vec::new();
132        for elem in value {
133            match elem {
134                PatternElem::Wildcard => {
135                    elems.push(crate::ast::PatternElem::Wildcard);
136                }
137                PatternElem::Literal(s) => {
138                    elems.extend(s.chars().map(crate::ast::PatternElem::Char));
139                }
140            }
141        }
142        Self::from(elems)
143    }
144}
145
146impl From<crate::ast::PatternElem> for PatternElem {
147    fn from(value: crate::ast::PatternElem) -> Self {
148        match value {
149            crate::ast::PatternElem::Wildcard => Self::Wildcard,
150            crate::ast::PatternElem::Char(c) => Self::Literal(c.to_smolstr()),
151        }
152    }
153}
154
155impl From<crate::ast::Pattern> for Vec<PatternElem> {
156    fn from(value: crate::ast::Pattern) -> Self {
157        value.iter().map(|elem| (*elem).into()).collect()
158    }
159}
160
/// Serde JSON structure for any Cedar expression other than an extension
/// function call, in the EST format.
///
/// Uses serde's default (externally tagged) enum representation: each
/// variant is a single-key JSON object whose key is the variant's tag (the
/// `rename` value where one is given, otherwise the variant name), e.g.
/// `{"==": {"left": ..., "right": ...}}`. Unknown fields inside a variant's
/// object are rejected (`deny_unknown_fields`).
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum ExprNoExt {
    /// Literal value (including anything that's legal to express in the
    /// attribute-value JSON format)
    Value(CedarValueJson),
    /// Variable (`principal`, `action`, `resource`, or `context`)
    Var(ast::Var),
    /// Template slot; typed as a plain string in the generated TypeScript
    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
    /// Logical negation `!arg`
    #[serde(rename = "!")]
    Not {
        /// Argument
        arg: Arc<Expr>,
    },
    /// Arithmetic negation (unary `-arg`); tagged `neg` to distinguish it
    /// from binary subtraction (`-`)
    #[serde(rename = "neg")]
    Neg {
        /// Argument
        arg: Arc<Expr>,
    },
    /// Equality `left == right`
    #[serde(rename = "==")]
    Eq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Inequality `left != right`
    #[serde(rename = "!=")]
    NotEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Hierarchy membership `left in right`
    #[serde(rename = "in")]
    In {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `left < right`
    #[serde(rename = "<")]
    Less {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `left <= right`
    #[serde(rename = "<=")]
    LessEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `left > right`
    #[serde(rename = ">")]
    Greater {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `left >= right`
    #[serde(rename = ">=")]
    GreaterEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Logical conjunction `left && right`
    #[serde(rename = "&&")]
    And {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Logical disjunction `left || right`
    #[serde(rename = "||")]
    Or {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Addition `left + right`
    #[serde(rename = "+")]
    Add {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Binary subtraction `left - right` (unary negation is `Neg`)
    #[serde(rename = "-")]
    Sub {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Multiplication `left * right`
    #[serde(rename = "*")]
    Mul {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `left.contains(right)`
    #[serde(rename = "contains")]
    Contains {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `left.containsAll(right)`
    #[serde(rename = "containsAll")]
    ContainsAll {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `left.containsAny(right)`
    #[serde(rename = "containsAny")]
    ContainsAny {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `arg.isEmpty()`
    #[serde(rename = "isEmpty")]
    IsEmpty {
        /// Argument
        arg: Arc<Expr>,
    },
    /// `left.getTag(right)`
    #[serde(rename = "getTag")]
    GetTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `left.hasTag(right)`
    #[serde(rename = "hasTag")]
    HasTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// Attribute access `left.attr`
    #[serde(rename = ".")]
    GetAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// Attribute presence test `left has attr`
    #[serde(rename = "has")]
    HasAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// Pattern match `left like pattern`
    #[serde(rename = "like")]
    Like {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Pattern
        pattern: Vec<PatternElem>,
    },
    /// Entity type test `left is entity_type`, optionally followed by
    /// `in in_expr`
    #[serde(rename = "is")]
    Is {
        /// Left-hand entity argument
        left: Arc<Expr>,
        /// Entity type
        entity_type: SmolStr,
        /// Optional entity or entity set for the `in` clause; serialized
        /// under the key `in` and omitted entirely when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        #[serde(rename = "in")]
        in_expr: Option<Arc<Expr>>,
    },
    /// Ternary `if cond_expr then then_expr else else_expr`; the fields use
    /// the JSON keys `if`, `then`, and `else`
    #[serde(rename = "if-then-else")]
    If {
        /// Condition
        #[serde(rename = "if")]
        cond_expr: Arc<Expr>,
        /// `then` expression
        #[serde(rename = "then")]
        then_expr: Arc<Expr>,
        /// `else` expression
        #[serde(rename = "else")]
        else_expr: Arc<Expr>,
    },
    /// Set literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use Expr::Value)
    Set(Vec<Expr>),
    /// Record literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use Expr::Value). Duplicate keys are rejected when deserializing.
    Record(
        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
        BTreeMap<SmolStr, Expr>,
    ),
    /// AST Error node - this represents a parsing error in a partially generated AST
    #[cfg(feature = "tolerant-ast")]
    Error(AstExprErrorKind),
}
395
/// Serde JSON structure for an extension function call in the EST format
///
/// Serialized as a JSON object mapping the function name to its argument
/// array, i.e. `{"<function name>": [<arg>, ...]}`.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub struct ExtFuncCall {
    /// maps the name of the function to a JSON list/array of the arguments.
    /// Note that for method calls, the method receiver is the first argument.
    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
    /// second argument is `b`.
    ///
    /// INVARIANT: This map should always have exactly one k-v pair (not more or
    /// less), but we make it a map in order to get the correct JSON structure
    /// we want. `flatten` makes the map's entries the object's own keys.
    #[serde(flatten)]
    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
    call: HashMap<SmolStr, Vec<Expr>>,
}
415
/// Construct an [`Expr`].
///
/// A stateless [`ExprBuilder`] whose `Data` is `()`. Source locations are
/// not recorded: `with_maybe_source_loc` ignores its argument and `loc`
/// always returns `None`.
#[derive(Clone, Debug)]
pub struct Builder;
419
420impl ExprBuilder for Builder {
421    type Expr = Expr;
422
423    type Data = ();
424    #[cfg(feature = "tolerant-ast")]
425    type ErrorType = Infallible;
426
427    fn with_data(_data: Self::Data) -> Self {
428        Self
429    }
430
431    fn with_maybe_source_loc(self, _: Option<&Loc>) -> Self {
432        self
433    }
434
435    fn loc(&self) -> Option<&Loc> {
436        None
437    }
438
439    fn data(&self) -> &Self::Data {
440        &()
441    }
442
443    /// literal
444    fn val(self, lit: impl Into<ast::Literal>) -> Expr {
445        Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::from_lit(lit.into())))
446    }
447
448    /// principal, action, resource, context
449    fn var(self, var: ast::Var) -> Expr {
450        Expr::ExprNoExt(ExprNoExt::Var(var))
451    }
452
453    /// Template slots
454    fn slot(self, slot: ast::SlotId) -> Expr {
455        Expr::ExprNoExt(ExprNoExt::Slot(slot))
456    }
457
458    /// An extension call with one arg, which is the name of the unknown
459    fn unknown(self, u: ast::Unknown) -> Expr {
460        Expr::ExtFuncCall(ExtFuncCall {
461            call: HashMap::from([("unknown".to_smolstr(), vec![Builder::new().val(u.name)])]),
462        })
463    }
464
465    /// `!`
466    fn not(self, e: Expr) -> Expr {
467        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
468    }
469
470    /// `-`
471    fn neg(self, e: Expr) -> Expr {
472        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
473    }
474
475    /// `==`
476    fn is_eq(self, left: Expr, right: Expr) -> Expr {
477        Expr::ExprNoExt(ExprNoExt::Eq {
478            left: Arc::new(left),
479            right: Arc::new(right),
480        })
481    }
482
483    /// `!=`
484    fn noteq(self, left: Expr, right: Expr) -> Expr {
485        Expr::ExprNoExt(ExprNoExt::NotEq {
486            left: Arc::new(left),
487            right: Arc::new(right),
488        })
489    }
490
491    /// `in`
492    fn is_in(self, left: Expr, right: Expr) -> Expr {
493        Expr::ExprNoExt(ExprNoExt::In {
494            left: Arc::new(left),
495            right: Arc::new(right),
496        })
497    }
498
499    /// `<`
500    fn less(self, left: Expr, right: Expr) -> Expr {
501        Expr::ExprNoExt(ExprNoExt::Less {
502            left: Arc::new(left),
503            right: Arc::new(right),
504        })
505    }
506
507    /// `<=`
508    fn lesseq(self, left: Expr, right: Expr) -> Expr {
509        Expr::ExprNoExt(ExprNoExt::LessEq {
510            left: Arc::new(left),
511            right: Arc::new(right),
512        })
513    }
514
515    /// `>`
516    fn greater(self, left: Expr, right: Expr) -> Expr {
517        Expr::ExprNoExt(ExprNoExt::Greater {
518            left: Arc::new(left),
519            right: Arc::new(right),
520        })
521    }
522
523    /// `>=`
524    fn greatereq(self, left: Expr, right: Expr) -> Expr {
525        Expr::ExprNoExt(ExprNoExt::GreaterEq {
526            left: Arc::new(left),
527            right: Arc::new(right),
528        })
529    }
530
531    /// `&&`
532    fn and(self, left: Expr, right: Expr) -> Expr {
533        Expr::ExprNoExt(ExprNoExt::And {
534            left: Arc::new(left),
535            right: Arc::new(right),
536        })
537    }
538
539    /// `||`
540    fn or(self, left: Expr, right: Expr) -> Expr {
541        Expr::ExprNoExt(ExprNoExt::Or {
542            left: Arc::new(left),
543            right: Arc::new(right),
544        })
545    }
546
547    /// `+`
548    fn add(self, left: Expr, right: Expr) -> Expr {
549        Expr::ExprNoExt(ExprNoExt::Add {
550            left: Arc::new(left),
551            right: Arc::new(right),
552        })
553    }
554
555    /// `-`
556    fn sub(self, left: Expr, right: Expr) -> Expr {
557        Expr::ExprNoExt(ExprNoExt::Sub {
558            left: Arc::new(left),
559            right: Arc::new(right),
560        })
561    }
562
563    /// `*`
564    fn mul(self, left: Expr, right: Expr) -> Expr {
565        Expr::ExprNoExt(ExprNoExt::Mul {
566            left: Arc::new(left),
567            right: Arc::new(right),
568        })
569    }
570
571    /// `left.contains(right)`
572    fn contains(self, left: Expr, right: Expr) -> Expr {
573        Expr::ExprNoExt(ExprNoExt::Contains {
574            left: Arc::new(left),
575            right: Arc::new(right),
576        })
577    }
578
579    /// `left.containsAll(right)`
580    fn contains_all(self, left: Expr, right: Expr) -> Expr {
581        Expr::ExprNoExt(ExprNoExt::ContainsAll {
582            left: Arc::new(left),
583            right: Arc::new(right),
584        })
585    }
586
587    /// `left.containsAny(right)`
588    fn contains_any(self, left: Expr, right: Expr) -> Expr {
589        Expr::ExprNoExt(ExprNoExt::ContainsAny {
590            left: Arc::new(left),
591            right: Arc::new(right),
592        })
593    }
594
595    /// `arg.isEmpty()`
596    fn is_empty(self, expr: Expr) -> Expr {
597        Expr::ExprNoExt(ExprNoExt::IsEmpty {
598            arg: Arc::new(expr),
599        })
600    }
601
602    /// `left.getTag(right)`
603    fn get_tag(self, expr: Expr, tag: Expr) -> Expr {
604        Expr::ExprNoExt(ExprNoExt::GetTag {
605            left: Arc::new(expr),
606            right: Arc::new(tag),
607        })
608    }
609
610    /// `left.hasTag(right)`
611    fn has_tag(self, expr: Expr, tag: Expr) -> Expr {
612        Expr::ExprNoExt(ExprNoExt::HasTag {
613            left: Arc::new(expr),
614            right: Arc::new(tag),
615        })
616    }
617
618    /// `left.attr`
619    fn get_attr(self, expr: Expr, attr: SmolStr) -> Expr {
620        Expr::ExprNoExt(ExprNoExt::GetAttr {
621            left: Arc::new(expr),
622            attr,
623        })
624    }
625
626    /// `left has attr`
627    fn has_attr(self, expr: Expr, attr: SmolStr) -> Expr {
628        Expr::ExprNoExt(ExprNoExt::HasAttr {
629            left: Arc::new(expr),
630            attr,
631        })
632    }
633
634    /// `left like pattern`
635    fn like(self, expr: Expr, pattern: ast::Pattern) -> Expr {
636        Expr::ExprNoExt(ExprNoExt::Like {
637            left: Arc::new(expr),
638            pattern: pattern.into(),
639        })
640    }
641
642    /// `left is entity_type`
643    fn is_entity_type(self, left: Expr, entity_type: ast::EntityType) -> Expr {
644        Expr::ExprNoExt(ExprNoExt::Is {
645            left: Arc::new(left),
646            entity_type: entity_type.to_smolstr(),
647            in_expr: None,
648        })
649    }
650
651    /// `left is entity_type in entity`
652    fn is_in_entity_type(self, left: Expr, entity_type: ast::EntityType, entity: Expr) -> Expr {
653        Expr::ExprNoExt(ExprNoExt::Is {
654            left: Arc::new(left),
655            entity_type: entity_type.to_smolstr(),
656            in_expr: Some(Arc::new(entity)),
657        })
658    }
659
660    /// `if cond_expr then then_expr else else_expr`
661    fn ite(self, cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Expr {
662        Expr::ExprNoExt(ExprNoExt::If {
663            cond_expr: Arc::new(cond_expr),
664            then_expr: Arc::new(then_expr),
665            else_expr: Arc::new(else_expr),
666        })
667    }
668
669    /// e.g. [1+2, !(context has department)]
670    fn set(self, elements: impl IntoIterator<Item = Expr>) -> Expr {
671        Expr::ExprNoExt(ExprNoExt::Set(elements.into_iter().collect()))
672    }
673
674    /// e.g. {foo: 1+2, bar: !(context has department)}
675    fn record(
676        self,
677        map: impl IntoIterator<Item = (SmolStr, Expr)>,
678    ) -> Result<Expr, ast::ExpressionConstructionError> {
679        let mut dedup_map = BTreeMap::new();
680        for (k, v) in map {
681            match dedup_map.entry(k) {
682                btree_map::Entry::Occupied(oentry) => {
683                    return Err(ast::expression_construction_errors::DuplicateKeyError {
684                        key: oentry.key().clone(),
685                        context: "in record literal",
686                    }
687                    .into());
688                }
689                btree_map::Entry::Vacant(ventry) => {
690                    ventry.insert(v);
691                }
692            }
693        }
694        Ok(Expr::ExprNoExt(ExprNoExt::Record(dedup_map)))
695    }
696
697    /// extension function call, including method calls
698    fn call_extension_fn(self, fn_name: ast::Name, args: impl IntoIterator<Item = Expr>) -> Expr {
699        Expr::ExtFuncCall(ExtFuncCall {
700            call: HashMap::from([(fn_name.to_smolstr(), args.into_iter().collect())]),
701        })
702    }
703
704    #[cfg(feature = "tolerant-ast")]
705    fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
706        Ok(Expr::ExprNoExt(ExprNoExt::Error(
707            AstExprErrorKind::InvalidExpr(parse_errors.to_string()),
708        )))
709    }
710}
711
712impl Expr {
713    /// Substitute entity literals
714    pub fn sub_entity_literals(
715        self,
716        mapping: &BTreeMap<EntityUID, EntityUID>,
717    ) -> Result<Self, JsonDeserializationError> {
718        match self {
719            Expr::ExprNoExt(e) => match e {
720                ExprNoExt::Value(v) => Ok(Expr::ExprNoExt(ExprNoExt::Value(
721                    v.sub_entity_literals(mapping)?,
722                ))),
723                v @ ExprNoExt::Var(_) => Ok(Expr::ExprNoExt(v)),
724                s @ ExprNoExt::Slot(_) => Ok(Expr::ExprNoExt(s)),
725                ExprNoExt::Not { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Not {
726                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
727                })),
728                ExprNoExt::Neg { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Neg {
729                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
730                })),
731                ExprNoExt::Eq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Eq {
732                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
733                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
734                })),
735                ExprNoExt::NotEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::NotEq {
736                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
737                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
738                })),
739                ExprNoExt::In { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::In {
740                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
741                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
742                })),
743                ExprNoExt::Less { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Less {
744                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
745                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
746                })),
747                ExprNoExt::LessEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::LessEq {
748                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
749                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
750                })),
751                ExprNoExt::Greater { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Greater {
752                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
753                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
754                })),
755                ExprNoExt::GreaterEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GreaterEq {
756                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
757                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
758                })),
759                ExprNoExt::And { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::And {
760                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
761                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
762                })),
763                ExprNoExt::Or { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Or {
764                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
765                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
766                })),
767                ExprNoExt::Add { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Add {
768                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
769                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
770                })),
771                ExprNoExt::Sub { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Sub {
772                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
773                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
774                })),
775                ExprNoExt::Mul { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Mul {
776                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
777                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
778                })),
779                ExprNoExt::Contains { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Contains {
780                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
781                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
782                })),
783                ExprNoExt::ContainsAll { left, right } => {
784                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAll {
785                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
786                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
787                    }))
788                }
789                ExprNoExt::ContainsAny { left, right } => {
790                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAny {
791                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
792                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
793                    }))
794                }
795                ExprNoExt::IsEmpty { arg } => Ok(Expr::ExprNoExt(ExprNoExt::IsEmpty {
796                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
797                })),
798                ExprNoExt::GetTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GetTag {
799                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
800                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
801                })),
802                ExprNoExt::HasTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::HasTag {
803                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
804                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
805                })),
806                ExprNoExt::GetAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::GetAttr {
807                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
808                    attr,
809                })),
810                ExprNoExt::HasAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::HasAttr {
811                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
812                    attr,
813                })),
814                ExprNoExt::Like { left, pattern } => Ok(Expr::ExprNoExt(ExprNoExt::Like {
815                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
816                    pattern,
817                })),
818                ExprNoExt::Is {
819                    left,
820                    entity_type,
821                    in_expr,
822                } => match in_expr {
823                    Some(in_expr) => Ok(Expr::ExprNoExt(ExprNoExt::Is {
824                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
825                        entity_type,
826                        in_expr: Some(Arc::new(
827                            Arc::unwrap_or_clone(in_expr).sub_entity_literals(mapping)?,
828                        )),
829                    })),
830                    None => Ok(Expr::ExprNoExt(ExprNoExt::Is {
831                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
832                        entity_type,
833                        in_expr: None,
834                    })),
835                },
836                ExprNoExt::If {
837                    cond_expr,
838                    then_expr,
839                    else_expr,
840                } => Ok(Expr::ExprNoExt(ExprNoExt::If {
841                    cond_expr: Arc::new(
842                        Arc::unwrap_or_clone(cond_expr).sub_entity_literals(mapping)?,
843                    ),
844                    then_expr: Arc::new(
845                        Arc::unwrap_or_clone(then_expr).sub_entity_literals(mapping)?,
846                    ),
847                    else_expr: Arc::new(
848                        Arc::unwrap_or_clone(else_expr).sub_entity_literals(mapping)?,
849                    ),
850                })),
851                ExprNoExt::Set(v) => {
852                    let mut new_v = vec![];
853                    for e in v {
854                        new_v.push(e.sub_entity_literals(mapping)?);
855                    }
856                    Ok(Expr::ExprNoExt(ExprNoExt::Set(new_v)))
857                }
858                ExprNoExt::Record(m) => {
859                    let mut new_m = BTreeMap::new();
860                    for (k, v) in m {
861                        new_m.insert(k, v.sub_entity_literals(mapping)?);
862                    }
863                    Ok(Expr::ExprNoExt(ExprNoExt::Record(new_m)))
864                }
865                #[cfg(feature = "tolerant-ast")]
866                ExprNoExt::Error(_) => Err(JsonDeserializationError::ASTErrorNode),
867            },
868            Expr::ExtFuncCall(e_fn_call) => {
869                let mut new_m = HashMap::new();
870                for (k, v) in e_fn_call.call {
871                    let mut new_v = vec![];
872                    for e in v {
873                        new_v.push(e.sub_entity_literals(mapping)?);
874                    }
875                    new_m.insert(k, new_v);
876                }
877                Ok(Expr::ExtFuncCall(ExtFuncCall { call: new_m }))
878            }
879        }
880    }
881}
882
883impl Expr {
884    /// Attempt to convert this `est::Expr` into an `ast::Expr`
885    ///
886    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
887    pub fn try_into_ast(self, id: &ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
888        match self {
889            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
890                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
891                .map(Into::into)
892                .map_err(Into::into),
893            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
894            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
895            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
896                Ok(ast::Expr::not(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
897            }
898            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
899                Ok(ast::Expr::neg(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
900            }
901            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
902                Arc::unwrap_or_clone(left).try_into_ast(id)?,
903                Arc::unwrap_or_clone(right).try_into_ast(id)?,
904            )),
905            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
906                Arc::unwrap_or_clone(left).try_into_ast(id)?,
907                Arc::unwrap_or_clone(right).try_into_ast(id)?,
908            )),
909            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
910                Arc::unwrap_or_clone(left).try_into_ast(id)?,
911                Arc::unwrap_or_clone(right).try_into_ast(id)?,
912            )),
913            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
914                Arc::unwrap_or_clone(left).try_into_ast(id)?,
915                Arc::unwrap_or_clone(right).try_into_ast(id)?,
916            )),
917            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
918                Arc::unwrap_or_clone(left).try_into_ast(id)?,
919                Arc::unwrap_or_clone(right).try_into_ast(id)?,
920            )),
921            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
922                Arc::unwrap_or_clone(left).try_into_ast(id)?,
923                Arc::unwrap_or_clone(right).try_into_ast(id)?,
924            )),
925            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
926                Arc::unwrap_or_clone(left).try_into_ast(id)?,
927                Arc::unwrap_or_clone(right).try_into_ast(id)?,
928            )),
929            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
930                Arc::unwrap_or_clone(left).try_into_ast(id)?,
931                Arc::unwrap_or_clone(right).try_into_ast(id)?,
932            )),
933            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
934                Arc::unwrap_or_clone(left).try_into_ast(id)?,
935                Arc::unwrap_or_clone(right).try_into_ast(id)?,
936            )),
937            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
938                Arc::unwrap_or_clone(left).try_into_ast(id)?,
939                Arc::unwrap_or_clone(right).try_into_ast(id)?,
940            )),
941            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
942                Arc::unwrap_or_clone(left).try_into_ast(id)?,
943                Arc::unwrap_or_clone(right).try_into_ast(id)?,
944            )),
945            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
946                Arc::unwrap_or_clone(left).try_into_ast(id)?,
947                Arc::unwrap_or_clone(right).try_into_ast(id)?,
948            )),
949            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
950                Arc::unwrap_or_clone(left).try_into_ast(id)?,
951                Arc::unwrap_or_clone(right).try_into_ast(id)?,
952            )),
953            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
954                Arc::unwrap_or_clone(left).try_into_ast(id)?,
955                Arc::unwrap_or_clone(right).try_into_ast(id)?,
956            )),
957            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
958                Arc::unwrap_or_clone(left).try_into_ast(id)?,
959                Arc::unwrap_or_clone(right).try_into_ast(id)?,
960            )),
961            Expr::ExprNoExt(ExprNoExt::IsEmpty { arg }) => Ok(ast::Expr::is_empty(
962                Arc::unwrap_or_clone(arg).try_into_ast(id)?,
963            )),
964            Expr::ExprNoExt(ExprNoExt::GetTag { left, right }) => Ok(ast::Expr::get_tag(
965                Arc::unwrap_or_clone(left).try_into_ast(id)?,
966                Arc::unwrap_or_clone(right).try_into_ast(id)?,
967            )),
968            Expr::ExprNoExt(ExprNoExt::HasTag { left, right }) => Ok(ast::Expr::has_tag(
969                Arc::unwrap_or_clone(left).try_into_ast(id)?,
970                Arc::unwrap_or_clone(right).try_into_ast(id)?,
971            )),
972            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => Ok(ast::Expr::get_attr(
973                Arc::unwrap_or_clone(left).try_into_ast(id)?,
974                attr,
975            )),
976            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => Ok(ast::Expr::has_attr(
977                Arc::unwrap_or_clone(left).try_into_ast(id)?,
978                attr,
979            )),
980            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => Ok(ast::Expr::like(
981                Arc::unwrap_or_clone(left).try_into_ast(id)?,
982                crate::ast::Pattern::from(pattern.as_slice()),
983            )),
984            Expr::ExprNoExt(ExprNoExt::Is {
985                left,
986                entity_type,
987                in_expr,
988            }) => ast::EntityType::from_normalized_str(entity_type.as_str())
989                .map_err(FromJsonError::InvalidEntityType)
990                .and_then(|entity_type_name| {
991                    let left: ast::Expr = Arc::unwrap_or_clone(left).try_into_ast(id)?;
992                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
993                    match in_expr {
994                        // The AST doesn't have an `... is ... in ..` node, so
995                        // we represent it as a conjunction of `is` and `in`.
996                        Some(in_expr) => Ok(ast::Expr::and(
997                            is_expr,
998                            ast::Expr::is_in(left, Arc::unwrap_or_clone(in_expr).try_into_ast(id)?),
999                        )),
1000                        None => Ok(is_expr),
1001                    }
1002                }),
1003            Expr::ExprNoExt(ExprNoExt::If {
1004                cond_expr,
1005                then_expr,
1006                else_expr,
1007            }) => Ok(ast::Expr::ite(
1008                Arc::unwrap_or_clone(cond_expr).try_into_ast(id)?,
1009                Arc::unwrap_or_clone(then_expr).try_into_ast(id)?,
1010                Arc::unwrap_or_clone(else_expr).try_into_ast(id)?,
1011            )),
1012            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
1013                elements
1014                    .into_iter()
1015                    .map(|el| el.try_into_ast(id))
1016                    .collect::<Result<Vec<_>, FromJsonError>>()?,
1017            )),
1018            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
1019                // PANIC SAFETY: can't have duplicate keys here because the input was already a HashMap
1020                #[allow(clippy::expect_used)]
1021                Ok(ast::Expr::record(
1022                    map.into_iter()
1023                        .map(|(k, v)| Ok((k, v.try_into_ast(id)?)))
1024                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
1025                )
1026                .expect("can't have duplicate keys here because the input was already a HashMap"))
1027            }
1028            Expr::ExtFuncCall(ExtFuncCall { call }) => {
1029                match call.len() {
1030                    0 => Err(FromJsonError::MissingOperator),
1031                    1 => {
1032                        // PANIC SAFETY checked that `call.len() == 1`
1033                        #[allow(clippy::expect_used)]
1034                        let (fn_name, args) = call
1035                            .into_iter()
1036                            .next()
1037                            .expect("already checked that len was 1");
1038                        let fn_name = Name::from_normalized_str(&fn_name).map_err(|errs| {
1039                            JsonDeserializationError::parse_escape(
1040                                EscapeKind::Extension,
1041                                fn_name,
1042                                errs,
1043                            )
1044                        })?;
1045                        if cst_to_ast::is_known_extension_func_name(&fn_name) {
1046                            Ok(ast::Expr::call_extension_fn(
1047                                fn_name,
1048                                args.into_iter()
1049                                    .map(|arg| arg.try_into_ast(id))
1050                                    .collect::<Result<_, _>>()?,
1051                            ))
1052                        } else {
1053                            Err(FromJsonError::UnknownExtensionFunction(fn_name))
1054                        }
1055                    }
1056                    _ => Err(FromJsonError::MultipleOperators {
1057                        ops: call.into_keys().collect(),
1058                    }),
1059                }
1060            }
1061            #[cfg(feature = "tolerant-ast")]
1062            Expr::ExprNoExt(ExprNoExt::Error(_)) => Err(FromJsonError::ASTErrorNode),
1063        }
1064    }
1065}
1066
1067impl From<ast::Literal> for Expr {
1068    fn from(lit: ast::Literal) -> Expr {
1069        Builder::new().val(lit)
1070    }
1071}
1072
1073impl From<ast::Var> for Expr {
1074    fn from(var: ast::Var) -> Expr {
1075        Builder::new().var(var)
1076    }
1077}
1078
1079impl From<ast::SlotId> for Expr {
1080    fn from(slot: ast::SlotId) -> Expr {
1081        Builder::new().slot(slot)
1082    }
1083}
1084
1085impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
1086    type Error = ParseErrors;
1087    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
1088        e.to_expr::<Builder>()
1089    }
1090}
1091
1092impl std::fmt::Display for Expr {
1093    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1094        match self {
1095            Self::ExprNoExt(e) => write!(f, "{e}"),
1096            Self::ExtFuncCall(e) => write!(f, "{e}"),
1097        }
1098    }
1099}
1100
1101impl BoundedDisplay for Expr {
1102    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1103        match self {
1104            Self::ExprNoExt(e) => BoundedDisplay::fmt(e, f, n),
1105            Self::ExtFuncCall(e) => BoundedDisplay::fmt(e, f, n),
1106        }
1107    }
1108}
1109
1110fn display_cedarvaluejson(
1111    f: &mut impl std::fmt::Write,
1112    v: &CedarValueJson,
1113    n: Option<usize>,
1114) -> std::fmt::Result {
1115    match v {
1116        // Add parentheses around negative numeric literals otherwise
1117        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1118        CedarValueJson::Long(i) if *i < 0 => write!(f, "({i})"),
1119        CedarValueJson::Long(i) => write!(f, "{i}"),
1120        CedarValueJson::Bool(b) => write!(f, "{b}"),
1121        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1122        CedarValueJson::EntityEscape { __entity } => {
1123            match ast::EntityUID::try_from(__entity.clone()) {
1124                Ok(euid) => write!(f, "{euid}"),
1125                Err(e) => write!(f, "(invalid entity uid: {e})"),
1126            }
1127        }
1128        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1129        CedarValueJson::ExtnEscape {
1130            __extn: FnAndArgs::Single { ext_fn, arg },
1131        } => {
1132            // search for the name and callstyle
1133            let style = Extensions::all_available().all_funcs().find_map(|f| {
1134                if &f.name().to_smolstr() == ext_fn {
1135                    Some(f.style())
1136                } else {
1137                    None
1138                }
1139            });
1140            match style {
1141                Some(ast::CallStyle::MethodStyle) => {
1142                    display_cedarvaluejson(f, arg, n)?;
1143                    write!(f, ".{ext_fn}()")?;
1144                    Ok(())
1145                }
1146                Some(ast::CallStyle::FunctionStyle) | None => {
1147                    write!(f, "{ext_fn}(")?;
1148                    display_cedarvaluejson(f, arg, n)?;
1149                    write!(f, ")")?;
1150                    Ok(())
1151                }
1152            }
1153        }
1154        CedarValueJson::ExtnEscape {
1155            __extn: FnAndArgs::Multi { ext_fn, args },
1156        } => {
1157            // search for the name and callstyle
1158            let style = Extensions::all_available().all_funcs().find_map(|f| {
1159                if &f.name().to_smolstr() == ext_fn {
1160                    Some(f.style())
1161                } else {
1162                    None
1163                }
1164            });
1165            match style {
1166                Some(ast::CallStyle::MethodStyle) => {
1167                    // PANIC SAFETY: method-style calls must have more than one argument
1168                    #[allow(clippy::indexing_slicing)]
1169                    display_cedarvaluejson(f, &args[0], n)?;
1170                    write!(f, ".{ext_fn}(")?;
1171                    // PANIC SAFETY: method-style calls must have more than one argument
1172                    #[allow(clippy::indexing_slicing)]
1173                    match &args[1..] {
1174                        [] => {}
1175                        [args @ .., last] => {
1176                            for arg in args {
1177                                display_cedarvaluejson(f, arg, n)?;
1178                                write!(f, ", ")?;
1179                            }
1180                            display_cedarvaluejson(f, last, n)?;
1181                        }
1182                    }
1183                    write!(f, ")")?;
1184                    Ok(())
1185                }
1186                Some(ast::CallStyle::FunctionStyle) | None => {
1187                    write!(f, "{ext_fn}(")?;
1188                    match &args[..] {
1189                        [] => {}
1190                        [args @ .., last] => {
1191                            for arg in args {
1192                                display_cedarvaluejson(f, arg, n)?;
1193                                write!(f, ", ")?;
1194                            }
1195                            display_cedarvaluejson(f, last, n)?;
1196                        }
1197                    }
1198                    write!(f, ")")?;
1199                    Ok(())
1200                }
1201            }
1202        }
1203        CedarValueJson::Set(v) => {
1204            match n {
1205                Some(n) if v.len() > n => {
1206                    // truncate to n elements
1207                    write!(f, "[")?;
1208                    for val in v.iter().take(n) {
1209                        display_cedarvaluejson(f, val, Some(n))?;
1210                        write!(f, ", ")?;
1211                    }
1212                    write!(f, "..]")?;
1213                    Ok(())
1214                }
1215                _ => {
1216                    // no truncation
1217                    write!(f, "[")?;
1218                    for (i, val) in v.iter().enumerate() {
1219                        display_cedarvaluejson(f, val, n)?;
1220                        if i < v.len() - 1 {
1221                            write!(f, ", ")?;
1222                        }
1223                    }
1224                    write!(f, "]")?;
1225                    Ok(())
1226                }
1227            }
1228        }
1229        CedarValueJson::Record(r) => {
1230            match n {
1231                Some(n) if r.len() > n => {
1232                    // truncate to n key-value pairs
1233                    write!(f, "{{")?;
1234                    for (k, v) in r.iter().take(n) {
1235                        write!(f, "\"{}\": ", k.escape_debug())?;
1236                        display_cedarvaluejson(f, v, Some(n))?;
1237                        write!(f, ", ")?;
1238                    }
1239                    write!(f, "..}}")?;
1240                    Ok(())
1241                }
1242                _ => {
1243                    // no truncation
1244                    write!(f, "{{")?;
1245                    for (i, (k, v)) in r.iter().enumerate() {
1246                        write!(f, "\"{}\": ", k.escape_debug())?;
1247                        display_cedarvaluejson(f, v, n)?;
1248                        if i < r.len() - 1 {
1249                            write!(f, ", ")?;
1250                        }
1251                    }
1252                    write!(f, "}}")?;
1253                    Ok(())
1254                }
1255            }
1256        }
1257        CedarValueJson::Null => {
1258            write!(f, "null")?;
1259            Ok(())
1260        }
1261    }
1262}
1263
1264impl std::fmt::Display for ExprNoExt {
1265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1266        BoundedDisplay::fmt_unbounded(self, f)
1267    }
1268}
1269
1270impl BoundedDisplay for ExprNoExt {
1271    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1272        match &self {
1273            ExprNoExt::Value(v) => display_cedarvaluejson(f, v, n),
1274            ExprNoExt::Var(v) => write!(f, "{v}"),
1275            ExprNoExt::Slot(id) => write!(f, "{id}"),
1276            ExprNoExt::Not { arg } => {
1277                write!(f, "!")?;
1278                maybe_with_parens(f, arg, n)
1279            }
1280            ExprNoExt::Neg { arg } => {
1281                // Always add parentheses instead of calling
1282                // `maybe_with_parens`.
1283                // This makes sure that we always get a negation operation back
1284                // (as opposed to e.g., a negative number) when parsing the
1285                // printed form, thus preserving the round-tripping property.
1286                write!(f, "-({arg})")
1287            }
1288            ExprNoExt::Eq { left, right } => {
1289                maybe_with_parens(f, left, n)?;
1290                write!(f, " == ")?;
1291                maybe_with_parens(f, right, n)
1292            }
1293            ExprNoExt::NotEq { left, right } => {
1294                maybe_with_parens(f, left, n)?;
1295                write!(f, " != ")?;
1296                maybe_with_parens(f, right, n)
1297            }
1298            ExprNoExt::In { left, right } => {
1299                maybe_with_parens(f, left, n)?;
1300                write!(f, " in ")?;
1301                maybe_with_parens(f, right, n)
1302            }
1303            ExprNoExt::Less { left, right } => {
1304                maybe_with_parens(f, left, n)?;
1305                write!(f, " < ")?;
1306                maybe_with_parens(f, right, n)
1307            }
1308            ExprNoExt::LessEq { left, right } => {
1309                maybe_with_parens(f, left, n)?;
1310                write!(f, " <= ")?;
1311                maybe_with_parens(f, right, n)
1312            }
1313            ExprNoExt::Greater { left, right } => {
1314                maybe_with_parens(f, left, n)?;
1315                write!(f, " > ")?;
1316                maybe_with_parens(f, right, n)
1317            }
1318            ExprNoExt::GreaterEq { left, right } => {
1319                maybe_with_parens(f, left, n)?;
1320                write!(f, " >= ")?;
1321                maybe_with_parens(f, right, n)
1322            }
1323            ExprNoExt::And { left, right } => {
1324                maybe_with_parens(f, left, n)?;
1325                write!(f, " && ")?;
1326                maybe_with_parens(f, right, n)
1327            }
1328            ExprNoExt::Or { left, right } => {
1329                maybe_with_parens(f, left, n)?;
1330                write!(f, " || ")?;
1331                maybe_with_parens(f, right, n)
1332            }
1333            ExprNoExt::Add { left, right } => {
1334                maybe_with_parens(f, left, n)?;
1335                write!(f, " + ")?;
1336                maybe_with_parens(f, right, n)
1337            }
1338            ExprNoExt::Sub { left, right } => {
1339                maybe_with_parens(f, left, n)?;
1340                write!(f, " - ")?;
1341                maybe_with_parens(f, right, n)
1342            }
1343            ExprNoExt::Mul { left, right } => {
1344                maybe_with_parens(f, left, n)?;
1345                write!(f, " * ")?;
1346                maybe_with_parens(f, right, n)
1347            }
1348            ExprNoExt::Contains { left, right } => {
1349                maybe_with_parens(f, left, n)?;
1350                write!(f, ".contains({right})")
1351            }
1352            ExprNoExt::ContainsAll { left, right } => {
1353                maybe_with_parens(f, left, n)?;
1354                write!(f, ".containsAll({right})")
1355            }
1356            ExprNoExt::ContainsAny { left, right } => {
1357                maybe_with_parens(f, left, n)?;
1358                write!(f, ".containsAny({right})")
1359            }
1360            ExprNoExt::IsEmpty { arg } => {
1361                maybe_with_parens(f, arg, n)?;
1362                write!(f, ".isEmpty()")
1363            }
1364            ExprNoExt::GetTag { left, right } => {
1365                maybe_with_parens(f, left, n)?;
1366                write!(f, ".getTag({right})")
1367            }
1368            ExprNoExt::HasTag { left, right } => {
1369                maybe_with_parens(f, left, n)?;
1370                write!(f, ".hasTag({right})")
1371            }
1372            ExprNoExt::GetAttr { left, attr } => {
1373                maybe_with_parens(f, left, n)?;
1374                write!(f, "[\"{}\"]", attr.escape_debug())
1375            }
1376            ExprNoExt::HasAttr { left, attr } => {
1377                maybe_with_parens(f, left, n)?;
1378                write!(f, " has \"{}\"", attr.escape_debug())
1379            }
1380            ExprNoExt::Like { left, pattern } => {
1381                maybe_with_parens(f, left, n)?;
1382                write!(
1383                    f,
1384                    " like \"{}\"",
1385                    crate::ast::Pattern::from(pattern.as_slice())
1386                )
1387            }
1388            ExprNoExt::Is {
1389                left,
1390                entity_type,
1391                in_expr,
1392            } => {
1393                maybe_with_parens(f, left, n)?;
1394                write!(f, " is {entity_type}")?;
1395                match in_expr {
1396                    Some(in_expr) => {
1397                        write!(f, " in ")?;
1398                        maybe_with_parens(f, in_expr, n)
1399                    }
1400                    None => Ok(()),
1401                }
1402            }
1403            ExprNoExt::If {
1404                cond_expr,
1405                then_expr,
1406                else_expr,
1407            } => {
1408                write!(f, "if ")?;
1409                maybe_with_parens(f, cond_expr, n)?;
1410                write!(f, " then ")?;
1411                maybe_with_parens(f, then_expr, n)?;
1412                write!(f, " else ")?;
1413                maybe_with_parens(f, else_expr, n)
1414            }
1415            ExprNoExt::Set(v) => {
1416                match n {
1417                    Some(n) if v.len() > n => {
1418                        // truncate to n elements
1419                        write!(f, "[")?;
1420                        for element in v.iter().take(n) {
1421                            BoundedDisplay::fmt(element, f, Some(n))?;
1422                            write!(f, ", ")?;
1423                        }
1424                        write!(f, "..]")?;
1425                        Ok(())
1426                    }
1427                    _ => {
1428                        // no truncation
1429                        write!(f, "[")?;
1430                        for (i, element) in v.iter().enumerate() {
1431                            BoundedDisplay::fmt(element, f, n)?;
1432                            if i < v.len() - 1 {
1433                                write!(f, ", ")?;
1434                            }
1435                        }
1436                        write!(f, "]")?;
1437                        Ok(())
1438                    }
1439                }
1440            }
1441            ExprNoExt::Record(m) => {
1442                match n {
1443                    Some(n) if m.len() > n => {
1444                        // truncate to n key-value pairs
1445                        write!(f, "{{")?;
1446                        for (k, v) in m.iter().take(n) {
1447                            write!(f, "\"{}\": ", k.escape_debug())?;
1448                            BoundedDisplay::fmt(v, f, Some(n))?;
1449                            write!(f, ", ")?;
1450                        }
1451                        write!(f, "..}}")?;
1452                        Ok(())
1453                    }
1454                    _ => {
1455                        // no truncation
1456                        write!(f, "{{")?;
1457                        for (i, (k, v)) in m.iter().enumerate() {
1458                            write!(f, "\"{}\": ", k.escape_debug())?;
1459                            BoundedDisplay::fmt(v, f, n)?;
1460                            if i < m.len() - 1 {
1461                                write!(f, ", ")?;
1462                            }
1463                        }
1464                        write!(f, "}}")?;
1465                        Ok(())
1466                    }
1467                }
1468            }
1469            #[cfg(feature = "tolerant-ast")]
1470            ExprNoExt::Error(e) => {
1471                write!(f, "{e}")?;
1472                Ok(())
1473            }
1474        }
1475    }
1476}
1477
1478impl std::fmt::Display for ExtFuncCall {
1479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1480        BoundedDisplay::fmt_unbounded(self, f)
1481    }
1482}
1483
1484impl BoundedDisplay for ExtFuncCall {
1485    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1486        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1487        #[allow(clippy::unreachable)]
1488        let Some((fn_name, args)) = self.call.iter().next() else {
1489            unreachable!("invariant violated: empty ExtFuncCall")
1490        };
1491        // search for the name and callstyle
1492        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1493            if &ext_fn.name().to_smolstr() == fn_name {
1494                Some(ext_fn.style())
1495            } else {
1496                None
1497            }
1498        });
1499        match (style, args.iter().next()) {
1500            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1501                maybe_with_parens(f, receiver, n)?;
1502                write!(f, ".{}({})", fn_name, args.iter().skip(1).join(", "))
1503            }
1504            (_, _) => {
1505                write!(f, "{}({})", fn_name, args.iter().join(", "))
1506            }
1507        }
1508    }
1509}
1510
1511/// returns the `BoundedDisplay` representation of the Expr, adding parens around
1512/// the entire string if necessary.
1513/// E.g., won't add parens for constants or `principal` etc, but will for things
1514/// like `(2 < 5)`.
1515/// When in doubt, add the parens.
1516fn maybe_with_parens(
1517    f: &mut impl std::fmt::Write,
1518    expr: &Expr,
1519    n: Option<usize>,
1520) -> std::fmt::Result {
1521    match expr {
1522        Expr::ExprNoExt(ExprNoExt::Set(_)) |
1523        Expr::ExprNoExt(ExprNoExt::Record(_)) |
1524        Expr::ExprNoExt(ExprNoExt::Value(_)) |
1525        Expr::ExprNoExt(ExprNoExt::Var(_)) |
1526        Expr::ExprNoExt(ExprNoExt::Slot(_)) => BoundedDisplay::fmt(expr, f, n),
1527
1528        // we want parens here because things like parse((!x).y)
1529        // would be printed into !x.y which has a different meaning
1530        Expr::ExprNoExt(ExprNoExt::Not { .. }) |
1531        // we want parens here because things like parse((-x).y)
1532        // would be printed into -x.y which has a different meaning
1533        Expr::ExprNoExt(ExprNoExt::Neg { .. })  |
1534        Expr::ExprNoExt(ExprNoExt::Eq { .. }) |
1535        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) |
1536        Expr::ExprNoExt(ExprNoExt::In { .. }) |
1537        Expr::ExprNoExt(ExprNoExt::Less { .. }) |
1538        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) |
1539        Expr::ExprNoExt(ExprNoExt::Greater { .. }) |
1540        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) |
1541        Expr::ExprNoExt(ExprNoExt::And { .. }) |
1542        Expr::ExprNoExt(ExprNoExt::Or { .. }) |
1543        Expr::ExprNoExt(ExprNoExt::Add { .. }) |
1544        Expr::ExprNoExt(ExprNoExt::Sub { .. }) |
1545        Expr::ExprNoExt(ExprNoExt::Mul { .. }) |
1546        Expr::ExprNoExt(ExprNoExt::Contains { .. }) |
1547        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) |
1548        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) |
1549        Expr::ExprNoExt(ExprNoExt::IsEmpty { .. }) |
1550        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) |
1551        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) |
1552        Expr::ExprNoExt(ExprNoExt::GetTag { .. }) |
1553        Expr::ExprNoExt(ExprNoExt::HasTag { .. }) |
1554        Expr::ExprNoExt(ExprNoExt::Like { .. }) |
1555        Expr::ExprNoExt(ExprNoExt::Is { .. }) |
1556        Expr::ExprNoExt(ExprNoExt::If { .. }) |
1557        Expr::ExtFuncCall { .. } => {
1558            write!(f, "(")?;
1559            BoundedDisplay::fmt(expr, f, n)?;
1560            write!(f, ")")?;
1561            Ok(())
1562        },
1563        #[cfg(feature = "tolerant-ast")]
1564        Expr::ExprNoExt(ExprNoExt::Error { .. }) => {
1565            write!(f, "(")?;
1566            BoundedDisplay::fmt(expr, f, n)?;
1567            write!(f, ")")?;
1568            Ok(())
1569        }
1570    }
1571}
1572
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::{
        err::{ParseError, ToASTErrorKind},
        parse_expr,
    };

    use super::*;
    use ast::BoundedToString;
    use cool_asserts::assert_matches;

    /// Converting the CST for `some_long_str::else` into an EST `Expr` must
    /// fail with exactly one `ReservedIdentifier` error that names `else`.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        let e = crate::parser::text_to_cst::parse_expr("some_long_str::else").unwrap();
        assert_matches!(Expr::try_from(&e), Err(e) => {
            // `assert_eq!` reports the actual error count on failure.
            assert_eq!(e.len(), 1);
            assert_matches!(&e[0],
                ParseError::ToAST(to_ast_error) => {
                    assert_matches!(to_ast_error.kind(), ToASTErrorKind::ReservedIdentifier(s) => {
                        assert_eq!(s.to_string(), "else");
                    });
                }
            );
        });
    }

    /// `Display` and unbounded `BoundedToString` produce the same text. With
    /// `Some(k)`, each set/record level shows at most `k` elements and elides
    /// the rest as `..`. Negative literals are printed parenthesized (`(-20)`).
    #[test]
    fn display_and_bounded_display() {
        let expr = parse_expr(r#"[100, [3, 4, 5], -20, "foo"]"#)
            .unwrap()
            .into_expr::<Builder>();
        assert_eq!(format!("{expr}"), r#"[100, [3, 4, 5], (-20), "foo"]"#);
        assert_eq!(
            BoundedToString::to_string(&expr, None),
            r#"[100, [3, 4, 5], (-20), "foo"]"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(4)),
            r#"[100, [3, 4, 5], (-20), "foo"]"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(3)),
            r#"[100, [3, 4, 5], (-20), ..]"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(2)),
            r#"[100, [3, 4, ..], ..]"#
        );
        assert_eq!(BoundedToString::to_string(&expr, Some(1)), r#"[100, ..]"#);
        assert_eq!(BoundedToString::to_string(&expr, Some(0)), r#"[..]"#);

        let expr = parse_expr(
            r#"{
            a: 12,
            b: [3, 4, true],
            c: -20,
            "hello ∞ world": "∂µß≈¥"
        }"#,
        )
        .unwrap()
        .into_expr::<Builder>();
        assert_eq!(
            format!("{expr}"),
            r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, None),
            r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(4)),
            r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(3)),
            r#"{"a": 12, "b": [3, 4, true], "c": (-20), ..}"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(2)),
            r#"{"a": 12, "b": [3, 4, ..], ..}"#
        );
        assert_eq!(
            BoundedToString::to_string(&expr, Some(1)),
            r#"{"a": 12, ..}"#
        );
        assert_eq!(BoundedToString::to_string(&expr, Some(0)), r#"{..}"#);
    }
}