cedar_policy_core/est/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::FromJsonError;
18use crate::ast::{self, EntityUID, InputInteger};
19use crate::entities::json::{
20    err::EscapeKind, err::JsonDeserializationError, err::JsonDeserializationErrorContext,
21    CedarValueJson, FnAndArg, TypeAndId,
22};
23use crate::extensions::Extensions;
24use crate::jsonvalue::JsonValueWithNoDuplicateKeys;
25use crate::parser::cst::{self, Ident};
26use crate::parser::cst_to_ast;
27use crate::parser::err::{ParseErrors, ToASTError, ToASTErrorKind};
28use crate::parser::unescape::to_unescaped_string;
29use crate::parser::util::flatten_tuple_2;
30use crate::parser::{Loc, Node};
31use either::Either;
32use itertools::Itertools;
33use serde::{de::Visitor, Deserialize, Serialize};
34use serde_with::serde_as;
35use smol_str::{SmolStr, ToSmolStr};
36use std::collections::{BTreeMap, HashMap};
37use std::sync::Arc;
38
/// Serde JSON structure for a Cedar expression in the EST format
///
/// Serialization is `untagged`: the JSON of the wrapped `ExprNoExt` or
/// `ExtFuncCall` is emitted directly, with no extra wrapper naming the `Expr`
/// variant. Deserialization is not derived; this module provides a
/// hand-written `Deserialize` impl that picks the variant from the object's
/// single key.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Expr {
    /// Any Cedar expression other than an extension function call.
    ExprNoExt(ExprNoExt),
    /// Extension function call, where the key is the name of an extension
    /// function or method.
    ExtFuncCall(ExtFuncCall),
}
51
52// Manual implementation of `Deserialize` is more efficient than the derived
53// implementation with `serde(untagged)`. In particular, if the key is valid for
54// `ExprNoExt` but there is a deserialization problem within the corresponding
55// value, the derived implementation would backtrack and try to deserialize as
56// `ExtFuncCall` with that key as the extension function name, but this manual
57// implementation instead eagerly errors out, taking advantage of the fact that
58// none of the keys for `ExprNoExt` are valid extension function names.
59//
60// See #1284.
61impl<'de> Deserialize<'de> for Expr {
62    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63    where
64        D: serde::Deserializer<'de>,
65    {
66        struct ExprVisitor;
67        impl<'de> Visitor<'de> for ExprVisitor {
68            type Value = Expr;
69            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70                formatter.write_str("JSON object representing an expression")
71            }
72
73            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
74            where
75                A: serde::de::MapAccess<'de>,
76            {
77                let (k, v): (SmolStr, JsonValueWithNoDuplicateKeys) = match map.next_entry()? {
78                    None => {
79                        return Err(serde::de::Error::custom(
80                            "empty map is not a valid expression",
81                        ))
82                    }
83                    Some((k, v)) => (k, v),
84                };
85                match map.next_key()? {
86                    None => (),
87                    Some(k2) => {
88                        let k2: SmolStr = k2;
89                        return Err(serde::de::Error::custom(format!("JSON object representing an `Expr` should have only one key, but found two keys: `{k}` and `{k2}`")));
90                    }
91                };
92                if cst_to_ast::is_known_extension_func_str(&k) {
93                    // `k` is the name of an extension function or method. We assume that
94                    // no such keys are valid keys for `ExprNoExt`, so we must parse as an
95                    // `ExtFuncCall`.
96                    let obj = serde_json::json!({ k: v });
97                    let extfunccall =
98                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
99                    Ok(Expr::ExtFuncCall(extfunccall))
100                } else {
101                    // not a valid extension function or method, so we expect it
102                    // to work for `ExprNoExt`.
103                    let obj = serde_json::json!({ k: v });
104                    let exprnoext =
105                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
106                    Ok(Expr::ExprNoExt(exprnoext))
107                }
108            }
109        }
110
111        deserializer.deserialize_map(ExprVisitor)
112    }
113}
114
/// Represent an element of a pattern literal
///
/// A `like` pattern in the EST is a list of these elements. Serde's default
/// (externally tagged) enum representation is used, so `Wildcard` serializes
/// as the JSON string `"Wildcard"` and `Literal(s)` as `{"Literal": s}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum PatternElem {
    /// The wildcard asterisk
    Wildcard,
    /// A string without any wildcards; each of its characters is matched
    /// literally
    Literal(SmolStr),
}
125
126impl From<Vec<PatternElem>> for crate::ast::Pattern {
127    fn from(value: Vec<PatternElem>) -> Self {
128        let mut elems = Vec::new();
129        for elem in value {
130            match elem {
131                PatternElem::Wildcard => {
132                    elems.push(crate::ast::PatternElem::Wildcard);
133                }
134                PatternElem::Literal(s) => {
135                    elems.extend(s.chars().map(crate::ast::PatternElem::Char));
136                }
137            }
138        }
139        Self::new(elems)
140    }
141}
142
143impl From<crate::ast::PatternElem> for PatternElem {
144    fn from(value: crate::ast::PatternElem) -> Self {
145        match value {
146            crate::ast::PatternElem::Wildcard => Self::Wildcard,
147            crate::ast::PatternElem::Char(c) => Self::Literal(c.to_smolstr()),
148        }
149    }
150}
151
152impl From<crate::ast::Pattern> for Vec<PatternElem> {
153    fn from(value: crate::ast::Pattern) -> Self {
154        value.iter().map(|elem| (*elem).into()).collect()
155    }
156}
157
/// Serde JSON structure for any Cedar expression other than an extension
/// function call, in the EST format
///
/// Uses serde's default externally tagged enum representation. Each expression
/// is a JSON object with a single key naming the operator. The key is the
/// `rename`d string where one is given, otherwise the variant name. Its value
/// holds the operands, e.g. `{"==": {"left": <expr>, "right": <expr>}}`.
/// `deny_unknown_fields` makes deserialization reject unexpected fields inside
/// the operand objects of struct-like variants.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum ExprNoExt {
    /// Literal value (including anything that's legal to express in the
    /// attribute-value JSON format)
    Value(CedarValueJson),
    /// Variable: `principal`, `action`, `resource`, or `context`
    Var(ast::Var),
    /// Template slot
    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
    /// `!`
    #[serde(rename = "!")]
    Not {
        /// Argument
        arg: Arc<Expr>,
    },
    /// Unary `-` (negation); its JSON key is `neg`, distinct from binary `-`
    #[serde(rename = "neg")]
    Neg {
        /// Argument
        arg: Arc<Expr>,
    },
    /// `==`
    #[serde(rename = "==")]
    Eq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `!=`
    #[serde(rename = "!=")]
    NotEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `in`
    #[serde(rename = "in")]
    In {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<`
    #[serde(rename = "<")]
    Less {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<=`
    #[serde(rename = "<=")]
    LessEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>`
    #[serde(rename = ">")]
    Greater {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>=`
    #[serde(rename = ">=")]
    GreaterEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `&&`
    #[serde(rename = "&&")]
    And {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `||`
    #[serde(rename = "||")]
    Or {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `+`
    #[serde(rename = "+")]
    Add {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Binary `-` (subtraction); unary negation is `Neg`
    #[serde(rename = "-")]
    Sub {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `*`
    #[serde(rename = "*")]
    Mul {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `contains()`
    #[serde(rename = "contains")]
    Contains {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAll()`
    #[serde(rename = "containsAll")]
    ContainsAll {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAny()`
    #[serde(rename = "containsAny")]
    ContainsAny {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `getTag()`
    #[serde(rename = "getTag")]
    GetTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `hasTag()`
    #[serde(rename = "hasTag")]
    HasTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// Get-attribute (`left.attr`); its JSON key is `.`
    #[serde(rename = ".")]
    GetAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `has`
    #[serde(rename = "has")]
    HasAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `like`
    #[serde(rename = "like")]
    Like {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Pattern
        pattern: Vec<PatternElem>,
    },
    /// `<entity> is <entity_type>`, optionally followed by
    /// `in <entity_or_entity_set>`
    #[serde(rename = "is")]
    Is {
        /// Left-hand entity argument
        left: Arc<Expr>,
        /// Entity type
        entity_type: SmolStr,
        /// Optional entity or entity set for the `in` clause; serialized under
        /// the JSON key `in` and omitted from the output when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        #[serde(rename = "in")]
        in_expr: Option<Arc<Expr>>,
    },
    /// Ternary; JSON key `if-then-else` with fields `if`, `then`, and `else`
    #[serde(rename = "if-then-else")]
    If {
        /// Condition
        #[serde(rename = "if")]
        cond_expr: Arc<Expr>,
        /// `then` expression
        #[serde(rename = "then")]
        then_expr: Arc<Expr>,
        /// `else` expression
        #[serde(rename = "else")]
        else_expr: Arc<Expr>,
    },
    /// Set literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use Expr::Value)
    Set(Vec<Expr>),
    /// Record literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use Expr::Value)
    Record(
        // `MapPreventDuplicates` rejects duplicate record keys when
        // deserializing instead of silently keeping one of them.
        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
        HashMap<SmolStr, Expr>,
    ),
}
383
/// Serde JSON structure for an extension function call in the EST format
///
/// Serialized as a JSON object whose single key is the function or method
/// name and whose value is the array of argument expressions. For example,
/// `Expr::unknown` in this module builds `{"unknown": [<string literal expr>]}`.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub struct ExtFuncCall {
    /// maps the name of the function to a JSON list/array of the arguments.
    /// Note that for method calls, the method receiver is the first argument.
    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
    /// second argument is `b`.
    ///
    /// INVARIANT: This map should always have exactly one k-v pair (not more or
    /// less), but we make it a map in order to get the correct JSON structure
    /// we want.
    // `flatten` makes this map's entries the entries of the enclosing JSON
    // object itself; `MapPreventDuplicates` rejects a repeated function name.
    // NOTE(review): the derived `Deserialize` does not itself check the
    // one-entry invariant. `Expr`'s manual `Deserialize` only hands this type
    // single-key objects.
    #[serde(flatten)]
    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
    call: HashMap<SmolStr, Vec<Expr>>,
}
403
404#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
405impl Expr {
406    /// literal
407    pub fn lit(lit: CedarValueJson) -> Self {
408        Expr::ExprNoExt(ExprNoExt::Value(lit))
409    }
410
411    /// principal, action, resource, context
412    pub fn var(var: ast::Var) -> Self {
413        Expr::ExprNoExt(ExprNoExt::Var(var))
414    }
415
416    /// Template slots
417    pub fn slot(slot: ast::SlotId) -> Self {
418        Expr::ExprNoExt(ExprNoExt::Slot(slot))
419    }
420
421    /// An extension call with one arg, which is the name of the unknown
422    pub fn unknown(name: impl Into<SmolStr>) -> Self {
423        Expr::ext_call(
424            "unknown".into(),
425            vec![Expr::lit(CedarValueJson::String(name.into()))],
426        )
427    }
428
429    /// `!`
430    pub fn not(e: Expr) -> Self {
431        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
432    }
433
434    /// `-`
435    pub fn neg(e: Expr) -> Self {
436        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
437    }
438
439    /// `==`
440    pub fn eq(left: Expr, right: Expr) -> Self {
441        Expr::ExprNoExt(ExprNoExt::Eq {
442            left: Arc::new(left),
443            right: Arc::new(right),
444        })
445    }
446
447    /// `!=`
448    pub fn noteq(left: Expr, right: Expr) -> Self {
449        Expr::ExprNoExt(ExprNoExt::NotEq {
450            left: Arc::new(left),
451            right: Arc::new(right),
452        })
453    }
454
455    /// `in`
456    pub fn _in(left: Expr, right: Expr) -> Self {
457        Expr::ExprNoExt(ExprNoExt::In {
458            left: Arc::new(left),
459            right: Arc::new(right),
460        })
461    }
462
463    /// `<`
464    pub fn less(left: Expr, right: Expr) -> Self {
465        Expr::ExprNoExt(ExprNoExt::Less {
466            left: Arc::new(left),
467            right: Arc::new(right),
468        })
469    }
470
471    /// `<=`
472    pub fn lesseq(left: Expr, right: Expr) -> Self {
473        Expr::ExprNoExt(ExprNoExt::LessEq {
474            left: Arc::new(left),
475            right: Arc::new(right),
476        })
477    }
478
479    /// `>`
480    pub fn greater(left: Expr, right: Expr) -> Self {
481        Expr::ExprNoExt(ExprNoExt::Greater {
482            left: Arc::new(left),
483            right: Arc::new(right),
484        })
485    }
486
487    /// `>=`
488    pub fn greatereq(left: Expr, right: Expr) -> Self {
489        Expr::ExprNoExt(ExprNoExt::GreaterEq {
490            left: Arc::new(left),
491            right: Arc::new(right),
492        })
493    }
494
495    /// `&&`
496    pub fn and(left: Expr, right: Expr) -> Self {
497        Expr::ExprNoExt(ExprNoExt::And {
498            left: Arc::new(left),
499            right: Arc::new(right),
500        })
501    }
502
503    /// `||`
504    pub fn or(left: Expr, right: Expr) -> Self {
505        Expr::ExprNoExt(ExprNoExt::Or {
506            left: Arc::new(left),
507            right: Arc::new(right),
508        })
509    }
510
511    /// `+`
512    pub fn add(left: Expr, right: Expr) -> Self {
513        Expr::ExprNoExt(ExprNoExt::Add {
514            left: Arc::new(left),
515            right: Arc::new(right),
516        })
517    }
518
519    /// `-`
520    pub fn sub(left: Expr, right: Expr) -> Self {
521        Expr::ExprNoExt(ExprNoExt::Sub {
522            left: Arc::new(left),
523            right: Arc::new(right),
524        })
525    }
526
527    /// `*`
528    pub fn mul(left: Expr, right: Expr) -> Self {
529        Expr::ExprNoExt(ExprNoExt::Mul {
530            left: Arc::new(left),
531            right: Arc::new(right),
532        })
533    }
534
535    /// `left.contains(right)`
536    pub fn contains(left: Arc<Expr>, right: Expr) -> Self {
537        Expr::ExprNoExt(ExprNoExt::Contains {
538            left,
539            right: Arc::new(right),
540        })
541    }
542
543    /// `left.containsAll(right)`
544    pub fn contains_all(left: Arc<Expr>, right: Expr) -> Self {
545        Expr::ExprNoExt(ExprNoExt::ContainsAll {
546            left,
547            right: Arc::new(right),
548        })
549    }
550
551    /// `left.containsAny(right)`
552    pub fn contains_any(left: Arc<Expr>, right: Expr) -> Self {
553        Expr::ExprNoExt(ExprNoExt::ContainsAny {
554            left,
555            right: Arc::new(right),
556        })
557    }
558
559    /// `left.getTag(right)`
560    pub fn get_tag(left: Arc<Expr>, right: Expr) -> Self {
561        Expr::ExprNoExt(ExprNoExt::GetTag {
562            left,
563            right: Arc::new(right),
564        })
565    }
566
567    /// `left.hasTag(right)`
568    pub fn has_tag(left: Arc<Expr>, right: Expr) -> Self {
569        Expr::ExprNoExt(ExprNoExt::HasTag {
570            left,
571            right: Arc::new(right),
572        })
573    }
574
575    /// `left.attr`
576    pub fn get_attr(left: Expr, attr: SmolStr) -> Self {
577        Expr::ExprNoExt(ExprNoExt::GetAttr {
578            left: Arc::new(left),
579            attr,
580        })
581    }
582
583    /// `left has attr`
584    pub fn has_attr(left: Expr, attr: SmolStr) -> Self {
585        Expr::ExprNoExt(ExprNoExt::HasAttr {
586            left: Arc::new(left),
587            attr,
588        })
589    }
590
591    /// `left like pattern`
592    pub fn like(left: Expr, pattern: impl IntoIterator<Item = PatternElem>) -> Self {
593        Expr::ExprNoExt(ExprNoExt::Like {
594            left: Arc::new(left),
595            pattern: pattern.into_iter().collect(),
596        })
597    }
598
599    /// `left is entity_type`
600    pub fn is_entity_type(left: Expr, entity_type: SmolStr) -> Self {
601        Expr::ExprNoExt(ExprNoExt::Is {
602            left: Arc::new(left),
603            entity_type,
604            in_expr: None,
605        })
606    }
607
608    /// `left is entity_type in entity`
609    pub fn is_entity_type_in(left: Expr, entity_type: SmolStr, entity: Expr) -> Self {
610        Expr::ExprNoExt(ExprNoExt::Is {
611            left: Arc::new(left),
612            entity_type,
613            in_expr: Some(Arc::new(entity)),
614        })
615    }
616
617    /// `if cond_expr then then_expr else else_expr`
618    pub fn ite(cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
619        Expr::ExprNoExt(ExprNoExt::If {
620            cond_expr: Arc::new(cond_expr),
621            then_expr: Arc::new(then_expr),
622            else_expr: Arc::new(else_expr),
623        })
624    }
625
626    /// e.g. [1+2, !(context has department)]
627    pub fn set(elements: Vec<Expr>) -> Self {
628        Expr::ExprNoExt(ExprNoExt::Set(elements))
629    }
630
631    /// e.g. {foo: 1+2, bar: !(context has department)}
632    pub fn record(map: HashMap<SmolStr, Expr>) -> Self {
633        Expr::ExprNoExt(ExprNoExt::Record(map))
634    }
635
636    /// extension function call, including method calls
637    pub fn ext_call(fn_name: SmolStr, args: Vec<Expr>) -> Self {
638        Expr::ExtFuncCall(ExtFuncCall {
639            call: [(fn_name, args)].into_iter().collect(),
640        })
641    }
642
643    /// Consume the `Expr`, producing a string literal if it was a string literal, otherwise returns the literal in the `Err` variant.
644    pub fn into_string_literal(self) -> Result<SmolStr, Self> {
645        match self {
646            Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::String(s))) => Ok(s),
647            _ => Err(self),
648        }
649    }
650
651    /// Substitute entity literals
652    pub fn sub_entity_literals(
653        self,
654        mapping: &BTreeMap<EntityUID, EntityUID>,
655    ) -> Result<Self, JsonDeserializationError> {
656        match self.clone() {
657            Expr::ExprNoExt(e) => match e.clone() {
658                ExprNoExt::Value(v) => Ok(Expr::ExprNoExt(ExprNoExt::Value(
659                    v.sub_entity_literals(mapping)?,
660                ))),
661                ExprNoExt::Var(_) => Ok(self.clone()),
662                ExprNoExt::Slot(_) => Ok(self.clone()),
663                ExprNoExt::Not { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Not {
664                    arg: Arc::new((*arg).clone().sub_entity_literals(mapping)?),
665                })),
666                ExprNoExt::Neg { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Neg {
667                    arg: Arc::new((*arg).clone().sub_entity_literals(mapping)?),
668                })),
669                ExprNoExt::Eq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Eq {
670                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
671                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
672                })),
673                ExprNoExt::NotEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::NotEq {
674                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
675                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
676                })),
677                ExprNoExt::In { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::In {
678                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
679                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
680                })),
681                ExprNoExt::Less { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Less {
682                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
683                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
684                })),
685                ExprNoExt::LessEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::LessEq {
686                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
687                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
688                })),
689                ExprNoExt::Greater { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Greater {
690                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
691                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
692                })),
693                ExprNoExt::GreaterEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GreaterEq {
694                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
695                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
696                })),
697                ExprNoExt::And { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::And {
698                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
699                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
700                })),
701                ExprNoExt::Or { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Or {
702                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
703                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
704                })),
705                ExprNoExt::Add { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Add {
706                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
707                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
708                })),
709                ExprNoExt::Sub { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Sub {
710                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
711                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
712                })),
713                ExprNoExt::Mul { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Mul {
714                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
715                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
716                })),
717                ExprNoExt::Contains { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Contains {
718                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
719                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
720                })),
721                ExprNoExt::ContainsAll { left, right } => {
722                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAll {
723                        left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
724                        right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
725                    }))
726                }
727                ExprNoExt::ContainsAny { left, right } => {
728                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAny {
729                        left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
730                        right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
731                    }))
732                }
733                ExprNoExt::GetTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GetTag {
734                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
735                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
736                })),
737                ExprNoExt::HasTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::HasTag {
738                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
739                    right: Arc::new((*right).clone().sub_entity_literals(mapping)?),
740                })),
741                ExprNoExt::GetAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::GetAttr {
742                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
743                    attr,
744                })),
745                ExprNoExt::HasAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::HasAttr {
746                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
747                    attr,
748                })),
749                ExprNoExt::Like { left, pattern } => Ok(Expr::ExprNoExt(ExprNoExt::Like {
750                    left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
751                    pattern,
752                })),
753                ExprNoExt::Is {
754                    left,
755                    entity_type,
756                    in_expr,
757                } => match in_expr {
758                    Some(in_expr) => Ok(Expr::ExprNoExt(ExprNoExt::Is {
759                        left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
760                        entity_type,
761                        in_expr: Some(Arc::new((*in_expr).clone().sub_entity_literals(mapping)?)),
762                    })),
763                    None => Ok(Expr::ExprNoExt(ExprNoExt::Is {
764                        left: Arc::new((*left).clone().sub_entity_literals(mapping)?),
765                        entity_type,
766                        in_expr: None,
767                    })),
768                },
769                ExprNoExt::If {
770                    cond_expr,
771                    then_expr,
772                    else_expr,
773                } => Ok(Expr::ExprNoExt(ExprNoExt::If {
774                    cond_expr: Arc::new((*cond_expr).clone().sub_entity_literals(mapping)?),
775                    then_expr: Arc::new((*then_expr).clone().sub_entity_literals(mapping)?),
776                    else_expr: Arc::new((*else_expr).clone().sub_entity_literals(mapping)?),
777                })),
778                ExprNoExt::Set(v) => {
779                    let mut new_v = vec![];
780                    for e in v {
781                        new_v.push(e.sub_entity_literals(mapping)?);
782                    }
783                    Ok(Expr::ExprNoExt(ExprNoExt::Set(new_v)))
784                }
785                ExprNoExt::Record(m) => {
786                    let mut new_m = HashMap::new();
787                    for (k, v) in m {
788                        new_m.insert(k, v.sub_entity_literals(mapping)?);
789                    }
790                    Ok(Expr::ExprNoExt(ExprNoExt::Record(new_m)))
791                }
792            },
793            Expr::ExtFuncCall(e_fn_call) => {
794                let mut new_m = HashMap::new();
795                for (k, v) in e_fn_call.call {
796                    let mut new_v = vec![];
797                    for e in v {
798                        new_v.push(e.sub_entity_literals(mapping)?);
799                    }
800                    new_m.insert(k, new_v);
801                }
802                Ok(Expr::ExtFuncCall(ExtFuncCall { call: new_m }))
803            }
804        }
805    }
806}
807
808impl Expr {
809    /// Attempt to convert this `est::Expr` into an `ast::Expr`
810    ///
811    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
812    pub fn try_into_ast(self, id: ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
813        match self {
814            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
815                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
816                .map(Into::into)
817                .map_err(Into::into),
818            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
819            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
820            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
821                Ok(ast::Expr::not((*arg).clone().try_into_ast(id)?))
822            }
823            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
824                Ok(ast::Expr::neg((*arg).clone().try_into_ast(id)?))
825            }
826            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
827                (*left).clone().try_into_ast(id.clone())?,
828                (*right).clone().try_into_ast(id)?,
829            )),
830            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
831                (*left).clone().try_into_ast(id.clone())?,
832                (*right).clone().try_into_ast(id)?,
833            )),
834            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
835                (*left).clone().try_into_ast(id.clone())?,
836                (*right).clone().try_into_ast(id)?,
837            )),
838            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
839                (*left).clone().try_into_ast(id.clone())?,
840                (*right).clone().try_into_ast(id)?,
841            )),
842            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
843                (*left).clone().try_into_ast(id.clone())?,
844                (*right).clone().try_into_ast(id)?,
845            )),
846            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
847                (*left).clone().try_into_ast(id.clone())?,
848                (*right).clone().try_into_ast(id)?,
849            )),
850            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
851                (*left).clone().try_into_ast(id.clone())?,
852                (*right).clone().try_into_ast(id)?,
853            )),
854            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
855                (*left).clone().try_into_ast(id.clone())?,
856                (*right).clone().try_into_ast(id)?,
857            )),
858            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
859                (*left).clone().try_into_ast(id.clone())?,
860                (*right).clone().try_into_ast(id)?,
861            )),
862            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
863                (*left).clone().try_into_ast(id.clone())?,
864                (*right).clone().try_into_ast(id)?,
865            )),
866            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
867                (*left).clone().try_into_ast(id.clone())?,
868                (*right).clone().try_into_ast(id)?,
869            )),
870            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
871                (*left).clone().try_into_ast(id.clone())?,
872                (*right).clone().try_into_ast(id)?,
873            )),
874            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
875                (*left).clone().try_into_ast(id.clone())?,
876                (*right).clone().try_into_ast(id)?,
877            )),
878            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
879                (*left).clone().try_into_ast(id.clone())?,
880                (*right).clone().try_into_ast(id)?,
881            )),
882            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
883                (*left).clone().try_into_ast(id.clone())?,
884                (*right).clone().try_into_ast(id)?,
885            )),
886            Expr::ExprNoExt(ExprNoExt::GetTag { left, right }) => Ok(ast::Expr::get_tag(
887                (*left).clone().try_into_ast(id.clone())?,
888                (*right).clone().try_into_ast(id)?,
889            )),
890            Expr::ExprNoExt(ExprNoExt::HasTag { left, right }) => Ok(ast::Expr::has_tag(
891                (*left).clone().try_into_ast(id.clone())?,
892                (*right).clone().try_into_ast(id)?,
893            )),
894            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => {
895                Ok(ast::Expr::get_attr((*left).clone().try_into_ast(id)?, attr))
896            }
897            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => {
898                Ok(ast::Expr::has_attr((*left).clone().try_into_ast(id)?, attr))
899            }
900            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => Ok(ast::Expr::like(
901                (*left).clone().try_into_ast(id)?,
902                crate::ast::Pattern::from(pattern).iter().cloned(),
903            )),
904            Expr::ExprNoExt(ExprNoExt::Is {
905                left,
906                entity_type,
907                in_expr,
908            }) => ast::EntityType::from_normalized_str(entity_type.as_str())
909                .map_err(FromJsonError::InvalidEntityType)
910                .and_then(|entity_type_name| {
911                    let left: ast::Expr = (*left).clone().try_into_ast(id.clone())?;
912                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
913                    match in_expr {
914                        // The AST doesn't have an `... is ... in ..` node, so
915                        // we represent it as a conjunction of `is` and `in`.
916                        Some(in_expr) => Ok(ast::Expr::and(
917                            is_expr,
918                            ast::Expr::is_in(left, (*in_expr).clone().try_into_ast(id)?),
919                        )),
920                        None => Ok(is_expr),
921                    }
922                }),
923            Expr::ExprNoExt(ExprNoExt::If {
924                cond_expr,
925                then_expr,
926                else_expr,
927            }) => Ok(ast::Expr::ite(
928                (*cond_expr).clone().try_into_ast(id.clone())?,
929                (*then_expr).clone().try_into_ast(id.clone())?,
930                (*else_expr).clone().try_into_ast(id)?,
931            )),
932            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
933                elements
934                    .into_iter()
935                    .map(|el| el.try_into_ast(id.clone()))
936                    .collect::<Result<Vec<_>, FromJsonError>>()?,
937            )),
938            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
939                // PANIC SAFETY: can't have duplicate keys here because the input was already a HashMap
940                #[allow(clippy::expect_used)]
941                Ok(ast::Expr::record(
942                    map.into_iter()
943                        .map(|(k, v)| Ok((k, v.try_into_ast(id.clone())?)))
944                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
945                )
946                .expect("can't have duplicate keys here because the input was already a HashMap"))
947            }
948            Expr::ExtFuncCall(ExtFuncCall { call }) => {
949                match call.len() {
950                    0 => Err(FromJsonError::MissingOperator),
951                    1 => {
952                        // PANIC SAFETY checked that `call.len() == 1`
953                        #[allow(clippy::expect_used)]
954                        let (fn_name, args) = call
955                            .into_iter()
956                            .next()
957                            .expect("already checked that len was 1");
958                        let fn_name: ast::Name = fn_name.parse().map_err(|errs| {
959                            JsonDeserializationError::parse_escape(
960                                EscapeKind::Extension,
961                                fn_name,
962                                errs,
963                            )
964                        })?;
965                        if !cst_to_ast::is_known_extension_func_name(&fn_name) {
966                            return Err(FromJsonError::UnknownExtensionFunction(fn_name.clone()));
967                        }
968                        Ok(ast::Expr::call_extension_fn(
969                            fn_name,
970                            args.into_iter()
971                                .map(|arg| arg.try_into_ast(id.clone()))
972                                .collect::<Result<_, _>>()?,
973                        ))
974                    }
975                    _ => Err(FromJsonError::MultipleOperators {
976                        ops: call.into_keys().collect(),
977                    }),
978                }
979            }
980        }
981    }
982}
983
984impl<T: Clone> From<ast::Expr<T>> for Expr {
985    fn from(expr: ast::Expr<T>) -> Expr {
986        match expr.into_expr_kind() {
987            ast::ExprKind::Lit(lit) => lit.into(),
988            ast::ExprKind::Var(var) => var.into(),
989            ast::ExprKind::Slot(slot) => slot.into(),
990            ast::ExprKind::Unknown(ast::Unknown { name, .. }) => Expr::unknown(name),
991            ast::ExprKind::If {
992                test_expr,
993                then_expr,
994                else_expr,
995            } => Expr::ite(
996                Arc::unwrap_or_clone(test_expr).into(),
997                Arc::unwrap_or_clone(then_expr).into(),
998                Arc::unwrap_or_clone(else_expr).into(),
999            ),
1000            ast::ExprKind::And { left, right } => Expr::and(
1001                Arc::unwrap_or_clone(left).into(),
1002                Arc::unwrap_or_clone(right).into(),
1003            ),
1004            ast::ExprKind::Or { left, right } => Expr::or(
1005                Arc::unwrap_or_clone(left).into(),
1006                Arc::unwrap_or_clone(right).into(),
1007            ),
1008            ast::ExprKind::UnaryApp { op, arg } => {
1009                let arg = Arc::unwrap_or_clone(arg).into();
1010                match op {
1011                    ast::UnaryOp::Not => Expr::not(arg),
1012                    ast::UnaryOp::Neg => Expr::neg(arg),
1013                }
1014            }
1015            ast::ExprKind::BinaryApp { op, arg1, arg2 } => {
1016                let arg1 = Arc::unwrap_or_clone(arg1).into();
1017                let arg2 = Arc::unwrap_or_clone(arg2).into();
1018                match op {
1019                    ast::BinaryOp::Eq => Expr::eq(arg1, arg2),
1020                    ast::BinaryOp::In => Expr::_in(arg1, arg2),
1021                    ast::BinaryOp::Less => Expr::less(arg1, arg2),
1022                    ast::BinaryOp::LessEq => Expr::lesseq(arg1, arg2),
1023                    ast::BinaryOp::Add => Expr::add(arg1, arg2),
1024                    ast::BinaryOp::Sub => Expr::sub(arg1, arg2),
1025                    ast::BinaryOp::Mul => Expr::mul(arg1, arg2),
1026                    ast::BinaryOp::Contains => Expr::contains(Arc::new(arg1), arg2),
1027                    ast::BinaryOp::ContainsAll => Expr::contains_all(Arc::new(arg1), arg2),
1028                    ast::BinaryOp::ContainsAny => Expr::contains_any(Arc::new(arg1), arg2),
1029                    ast::BinaryOp::GetTag => Expr::get_tag(Arc::new(arg1), arg2),
1030                    ast::BinaryOp::HasTag => Expr::has_tag(Arc::new(arg1), arg2),
1031                }
1032            }
1033            ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
1034                let args = Arc::unwrap_or_clone(args)
1035                    .into_iter()
1036                    .map(Into::into)
1037                    .collect();
1038                Expr::ext_call(fn_name.to_string().into(), args)
1039            }
1040            ast::ExprKind::GetAttr { expr, attr } => {
1041                Expr::get_attr(Arc::unwrap_or_clone(expr).into(), attr)
1042            }
1043            ast::ExprKind::HasAttr { expr, attr } => {
1044                Expr::has_attr(Arc::unwrap_or_clone(expr).into(), attr)
1045            }
1046            ast::ExprKind::Like { expr, pattern } => Expr::like(
1047                Arc::unwrap_or_clone(expr).into(),
1048                Vec::<PatternElem>::from(pattern),
1049            ),
1050            ast::ExprKind::Is { expr, entity_type } => Expr::is_entity_type(
1051                Arc::unwrap_or_clone(expr).into(),
1052                entity_type.to_string().into(),
1053            ),
1054            ast::ExprKind::Set(set) => Expr::set(
1055                Arc::unwrap_or_clone(set)
1056                    .into_iter()
1057                    .map(Into::into)
1058                    .collect(),
1059            ),
1060            ast::ExprKind::Record(map) => Expr::record(
1061                Arc::unwrap_or_clone(map)
1062                    .into_iter()
1063                    .map(|(k, v)| (k, v.into()))
1064                    .collect(),
1065            ),
1066        }
1067    }
1068}
1069
1070impl From<ast::Literal> for Expr {
1071    fn from(lit: ast::Literal) -> Expr {
1072        Expr::lit(CedarValueJson::from_lit(lit))
1073    }
1074}
1075
1076impl From<ast::Var> for Expr {
1077    fn from(var: ast::Var) -> Expr {
1078        Expr::var(var)
1079    }
1080}
1081
1082impl From<ast::SlotId> for Expr {
1083    fn from(slot: ast::SlotId) -> Expr {
1084        Expr::slot(slot)
1085    }
1086}
1087
1088impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
1089    type Error = ParseErrors;
1090    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
1091        match &*e.try_as_inner()?.expr {
1092            cst::ExprData::Or(node) => node.try_into(),
1093            cst::ExprData::If(if_node, then_node, else_node) => {
1094                let cond_expr = if_node.try_into()?;
1095                let then_expr = then_node.try_into()?;
1096                let else_expr = else_node.try_into()?;
1097                Ok(Expr::ite(cond_expr, then_expr, else_expr))
1098            }
1099        }
1100    }
1101}
1102
1103impl TryFrom<&Node<Option<cst::Or>>> for Expr {
1104    type Error = ParseErrors;
1105    fn try_from(o: &Node<Option<cst::Or>>) -> Result<Expr, ParseErrors> {
1106        let o_node = o.try_as_inner()?;
1107        let mut expr = (&o_node.initial).try_into()?;
1108        for node in &o_node.extended {
1109            let rhs = node.try_into()?;
1110            expr = Expr::or(expr, rhs);
1111        }
1112        Ok(expr)
1113    }
1114}
1115
1116impl TryFrom<&Node<Option<cst::And>>> for Expr {
1117    type Error = ParseErrors;
1118    fn try_from(a: &Node<Option<cst::And>>) -> Result<Expr, ParseErrors> {
1119        let a_node = a.try_as_inner()?;
1120        let mut expr = (&a_node.initial).try_into()?;
1121        for node in &a_node.extended {
1122            let rhs = node.try_into()?;
1123            expr = Expr::and(expr, rhs);
1124        }
1125        Ok(expr)
1126    }
1127}
1128
1129impl TryFrom<&Node<Option<cst::Relation>>> for Expr {
1130    type Error = ParseErrors;
1131    fn try_from(r: &Node<Option<cst::Relation>>) -> Result<Expr, ParseErrors> {
1132        match r.try_as_inner()? {
1133            cst::Relation::Common { initial, extended } => {
1134                let mut expr = initial.try_into()?;
1135                for (op, node) in extended {
1136                    let rhs = node.try_into()?;
1137                    match op {
1138                        cst::RelOp::Eq => {
1139                            expr = Expr::eq(expr, rhs);
1140                        }
1141                        cst::RelOp::NotEq => {
1142                            expr = Expr::noteq(expr, rhs);
1143                        }
1144                        cst::RelOp::In => {
1145                            expr = Expr::_in(expr, rhs);
1146                        }
1147                        cst::RelOp::Less => {
1148                            expr = Expr::less(expr, rhs);
1149                        }
1150                        cst::RelOp::LessEq => {
1151                            expr = Expr::lesseq(expr, rhs);
1152                        }
1153                        cst::RelOp::Greater => {
1154                            expr = Expr::greater(expr, rhs);
1155                        }
1156                        cst::RelOp::GreaterEq => {
1157                            expr = Expr::greatereq(expr, rhs);
1158                        }
1159                        cst::RelOp::InvalidSingleEq => {
1160                            return Err(ToASTError::new(
1161                                ToASTErrorKind::InvalidSingleEq,
1162                                r.loc.clone(),
1163                            )
1164                            .into());
1165                        }
1166                    }
1167                }
1168                Ok(expr)
1169            }
1170            cst::Relation::Has { target, field } => {
1171                let target_expr = target.try_into()?;
1172                field
1173                    .to_expr_or_special()?
1174                    .into_valid_attr()
1175                    .map(|attr| Expr::has_attr(target_expr, attr))
1176            }
1177            cst::Relation::Like { target, pattern } => {
1178                let target_expr = target.try_into()?;
1179                pattern
1180                    .to_expr_or_special()?
1181                    .into_pattern()
1182                    .map(|pat| Expr::like(target_expr, pat.into_iter().map(PatternElem::from)))
1183            }
1184            cst::Relation::IsIn {
1185                target,
1186                entity_type,
1187                in_entity,
1188            } => {
1189                let target = target.try_into()?;
1190                let type_str = entity_type.try_as_inner()?.to_string().into();
1191                match in_entity {
1192                    Some(in_entity) => Ok(Expr::is_entity_type_in(
1193                        target,
1194                        type_str,
1195                        in_entity.try_into()?,
1196                    )),
1197                    None => Ok(Expr::is_entity_type(target, type_str)),
1198                }
1199            }
1200        }
1201    }
1202}
1203
1204impl TryFrom<&Node<Option<cst::Add>>> for Expr {
1205    type Error = ParseErrors;
1206    fn try_from(a: &Node<Option<cst::Add>>) -> Result<Expr, ParseErrors> {
1207        let a_node = a.try_as_inner()?;
1208        let mut expr = (&a_node.initial).try_into()?;
1209        for (op, node) in &a_node.extended {
1210            let rhs = node.try_into()?;
1211            match op {
1212                cst::AddOp::Plus => {
1213                    expr = Expr::add(expr, rhs);
1214                }
1215                cst::AddOp::Minus => {
1216                    expr = Expr::sub(expr, rhs);
1217                }
1218            }
1219        }
1220        Ok(expr)
1221    }
1222}
1223
1224impl TryFrom<&Node<Option<cst::Mult>>> for Expr {
1225    type Error = ParseErrors;
1226    fn try_from(m: &Node<Option<cst::Mult>>) -> Result<Expr, ParseErrors> {
1227        let m_node = m.try_as_inner()?;
1228        let mut expr = (&m_node.initial).try_into()?;
1229        for (op, node) in &m_node.extended {
1230            let rhs = node.try_into()?;
1231            match op {
1232                cst::MultOp::Times => {
1233                    expr = Expr::mul(expr, rhs);
1234                }
1235                cst::MultOp::Divide => {
1236                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedDivision).into())
1237                }
1238                cst::MultOp::Mod => {
1239                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedModulo).into())
1240                }
1241            }
1242        }
1243        Ok(expr)
1244    }
1245}
1246
1247impl TryFrom<&Node<Option<cst::Unary>>> for Expr {
1248    type Error = ParseErrors;
1249    fn try_from(u: &Node<Option<cst::Unary>>) -> Result<Expr, ParseErrors> {
1250        let u_node = u.try_as_inner()?;
1251
1252        match u_node.op {
1253            Some(cst::NegOp::Bang(num_bangs)) => {
1254                let inner = (&u_node.item).try_into()?;
1255                match num_bangs {
1256                    0 => Ok(inner),
1257                    1 => Ok(Expr::not(inner)),
1258                    2 => Ok(Expr::not(Expr::not(inner))),
1259                    3 => Ok(Expr::not(Expr::not(Expr::not(inner)))),
1260                    4 => Ok(Expr::not(Expr::not(Expr::not(Expr::not(inner))))),
1261                    _ => Err(u
1262                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
1263                        .into()),
1264                }
1265            }
1266            Some(cst::NegOp::Dash(0)) => Ok((&u_node.item).try_into()?),
1267            Some(cst::NegOp::Dash(mut num_dashes)) => {
1268                let inner = match &u_node.item.to_lit() {
1269                    Some(cst::Literal::Num(num)) => {
1270                        match num.cmp(&(InputInteger::MAX as u64 + 1)) {
1271                            std::cmp::Ordering::Less => {
1272                                num_dashes -= 1;
1273                                Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(
1274                                    -(*num as InputInteger),
1275                                )))
1276                            }
1277                            std::cmp::Ordering::Equal => {
1278                                num_dashes -= 1;
1279                                Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(
1280                                    InputInteger::MIN,
1281                                )))
1282                            }
1283                            std::cmp::Ordering::Greater => {
1284                                return Err(u_node
1285                                    .item
1286                                    .to_ast_err(ToASTErrorKind::IntegerLiteralTooLarge(*num))
1287                                    .into());
1288                            }
1289                        }
1290                    }
1291                    _ => (&u_node.item).try_into()?,
1292                };
1293                match num_dashes {
1294                    0 => Ok(inner),
1295                    1 => Ok(Expr::neg(inner)),
1296                    2 => {
1297                        // not safe to collapse `--` to nothing
1298                        Ok(Expr::neg(Expr::neg(inner)))
1299                    }
1300                    3 => Ok(Expr::neg(Expr::neg(Expr::neg(inner)))),
1301                    4 => Ok(Expr::neg(Expr::neg(Expr::neg(Expr::neg(inner))))),
1302                    _ => Err(u
1303                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1304                        .into()),
1305                }
1306            }
1307            Some(cst::NegOp::OverBang) => Err(u
1308                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
1309                .into()),
1310            Some(cst::NegOp::OverDash) => Err(u
1311                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1312                .into()),
1313            None => Ok((&u_node.item).try_into()?),
1314        }
1315    }
1316}
1317
1318/// Convert the given `cst::Primary` into either a (possibly namespaced)
1319/// function name, or an `Expr`.
1320///
1321/// (Upstream, the case where the `Primary` is a function name needs special
1322/// handling, because in that case it is not a valid expression. In all other
1323/// cases a `Primary` can be converted into an `Expr`.)
1324fn interpret_primary(
1325    p: &Node<Option<cst::Primary>>,
1326) -> Result<Either<ast::Name, Expr>, ParseErrors> {
1327    match p.try_as_inner()? {
1328        cst::Primary::Literal(lit) => Ok(Either::Right(lit.try_into()?)),
1329        cst::Primary::Ref(node) => match node.try_as_inner()? {
1330            cst::Ref::Uid {
1331                path,
1332                eid: eid_node,
1333            } => {
1334                let maybe_name = path.to_name().map(ast::EntityType::from);
1335                let maybe_eid = eid_node.as_valid_string();
1336
1337                let (name, eid) = flatten_tuple_2(maybe_name, maybe_eid)?;
1338                match to_unescaped_string(eid) {
1339                    Ok(eid) => Ok(Either::Right(Expr::lit(CedarValueJson::EntityEscape {
1340                        __entity: TypeAndId::from(ast::EntityUID::from_components(
1341                            name,
1342                            ast::Eid::new(eid),
1343                            None,
1344                        )),
1345                    }))),
1346                    Err(unescape_errs) => {
1347                        Err(ParseErrors::new_from_nonempty(unescape_errs.map(|err| {
1348                            {
1349                                crate::parser::err::ParseError::from(
1350                                    eid_node.to_ast_err(ToASTErrorKind::Unescape(err)),
1351                                )
1352                            }
1353                        })))
1354                    }
1355                }
1356            }
1357            r @ cst::Ref::Ref { .. } => Err(node
1358                .to_ast_err(ToASTErrorKind::InvalidEntityLiteral(r.to_string()))
1359                .into()),
1360        },
1361        cst::Primary::Name(node) => {
1362            let name = node.try_as_inner()?;
1363            let base_name = name.name.try_as_inner()?;
1364            match (&name.path[..], base_name) {
1365                (&[], cst::Ident::Principal) => Ok(Either::Right(Expr::var(ast::Var::Principal))),
1366                (&[], cst::Ident::Action) => Ok(Either::Right(Expr::var(ast::Var::Action))),
1367                (&[], cst::Ident::Resource) => Ok(Either::Right(Expr::var(ast::Var::Resource))),
1368                (&[], cst::Ident::Context) => Ok(Either::Right(Expr::var(ast::Var::Context))),
1369                (path, cst::Ident::Ident(id)) => Ok(Either::Left(
1370                    ast::InternalName::new(
1371                        id.parse()?,
1372                        path.iter()
1373                            .map(|node| {
1374                                node.try_as_inner()
1375                                    .map_err(Into::into)
1376                                    .and_then(|id| id.to_string().parse().map_err(Into::into))
1377                            })
1378                            .collect::<Result<Vec<ast::Id>, ParseErrors>>()?,
1379                        Some(node.loc.clone()),
1380                    )
1381                    .try_into()?,
1382                )),
1383                (path, id) => {
1384                    let (l, r, src) = match (path.first(), path.last()) {
1385                        (Some(l), Some(r)) => (
1386                            l.loc.start(),
1387                            r.loc.end() + ident_to_str_len(id),
1388                            Arc::clone(&l.loc.src),
1389                        ),
1390                        (_, _) => (0, 0, Arc::from("")),
1391                    };
1392                    Err(ToASTError::new(
1393                        ToASTErrorKind::ArbitraryVariable(name.to_string().into()),
1394                        Loc::new(l..r, src),
1395                    )
1396                    .into())
1397                }
1398            }
1399        }
1400        cst::Primary::Slot(node) => Ok(Either::Right(Expr::slot(
1401            node.try_as_inner()?
1402                .try_into()
1403                .map_err(|e| node.to_ast_err(e))?,
1404        ))),
1405        cst::Primary::Expr(e) => Ok(Either::Right(e.try_into()?)),
1406        cst::Primary::EList(nodes) => nodes
1407            .iter()
1408            .map(|node| node.try_into())
1409            .collect::<Result<Vec<Expr>, _>>()
1410            .map(Expr::set)
1411            .map(Either::Right),
1412        cst::Primary::RInits(nodes) => nodes
1413            .iter()
1414            .map(|node| {
1415                let cst::RecInit(k, v) = node.try_as_inner()?;
1416                let s = k.to_expr_or_special().and_then(|es| es.into_valid_attr())?;
1417                Ok((s, v.try_into()?))
1418            })
1419            .collect::<Result<HashMap<SmolStr, Expr>, ParseErrors>>()
1420            .map(Expr::record)
1421            .map(Either::Right),
1422    }
1423}
1424
impl TryFrom<&Node<Option<cst::Member>>> for Expr {
    type Error = ParseErrors;
    /// Converts a CST member expression into an EST expression.
    ///
    /// A member expression is a primary followed by a sequence of `.field`,
    /// `(args)` and `["index"]` accesses, which are processed left to right. A
    /// `(args)` access directly after a bare name becomes an extension function
    /// call. Directly after an attribute access, it becomes a method call on
    /// that attribute's receiver.
    ///
    /// The extension function name is not validated here. `try_into_ast`
    /// rejects unknown extension functions later.
    ///
    /// # Errors
    ///
    /// Returns `ParseErrors` in any of these cases:
    /// - the primary or any argument fails to convert;
    /// - a field or index access is applied to a bare name;
    /// - an index is not a string literal;
    /// - a call target is any other kind of expression;
    /// - a built-in method (`contains`, `containsAll`, `containsAny`, `getTag`,
    ///   `hasTag`) does not receive exactly one argument;
    /// - the result is still a bare name after all accesses are consumed.
    fn try_from(m: &Node<Option<cst::Member>>) -> Result<Expr, ParseErrors> {
        let m_node = m.try_as_inner()?;
        // `Left` is a (possibly namespaced) name, which is only valid as a call
        // target. `Right` is an ordinary expression.
        let mut item: Either<ast::Name, Expr> = interpret_primary(&m_node.item)?;
        for access in &m_node.access {
            match access.try_as_inner()? {
                cst::MemAccess::Field(node) => {
                    let id = node.to_valid_ident()?;
                    // Record `.id` as an attribute access for now. If a `Call`
                    // access follows, it is reinterpreted as a method call.
                    item = match item {
                        Either::Left(name) => {
                            return Err(node
                                .to_ast_err(ToASTErrorKind::InvalidAccess(name, id.to_smolstr()))
                                .into())
                        }
                        Either::Right(expr) => Either::Right(Expr::get_attr(expr, id.to_smolstr())),
                    };
                }
                cst::MemAccess::Call(args) => {
                    // we have item(args).  We hope item is either:
                    //   - an `ast::Name`, in which case we have a standard function call
                    //   - or an expr of the form `x.y`, in which case y is the method
                    //      name and not a field name. In the previous iteration of the
                    //      `for` loop we would have made `item` equal to
                    //      `Expr::GetAttr(x, y)`. Now we have to undo that to make a
                    //      method call instead.
                    //   - any other expression: it's an illegal call as the target is a higher order expression
                    //
                    // NOTE(review): the `GetAttr` pattern matches any attribute
                    // access, not only one produced by a `.field` access. An
                    // `["index"]` access (which also builds `get_attr`) or a
                    // parenthesized `(x.y)` primary is therefore also treated as
                    // a method call here. Confirm that a later stage rejects
                    // `x["contains"](..)` and `(x.y)(..)`.
                    item = match item {
                        Either::Left(name) => Either::Right(Expr::ext_call(
                            name.to_string().into(),
                            args.iter()
                                .map(|node| node.try_into())
                                .collect::<Result<Vec<_>, _>>()?,
                        )),
                        Either::Right(Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr })) => {
                            let args = args.iter().map(|node| node.try_into()).collect::<Result<
                                Vec<Expr>,
                                ParseErrors,
                            >>(
                            )?;
                            let args = args.into_iter();
                            // Built-in methods get dedicated EST nodes and must
                            // receive exactly one argument.
                            match attr.as_str() {
                                "contains" => Either::Right(Expr::contains(
                                    left,
                                    extract_single_argument(args, "contains()", &access.loc)?,
                                )),
                                "containsAll" => Either::Right(Expr::contains_all(
                                    left,
                                    extract_single_argument(args, "containsAll()", &access.loc)?,
                                )),
                                "containsAny" => Either::Right(Expr::contains_any(
                                    left,
                                    extract_single_argument(args, "containsAny()", &access.loc)?,
                                )),
                                "getTag" => Either::Right(Expr::get_tag(
                                    left,
                                    extract_single_argument(args, "getTag()", &access.loc)?,
                                )),
                                "hasTag" => Either::Right(Expr::has_tag(
                                    left,
                                    extract_single_argument(args, "hasTag()", &access.loc)?,
                                )),
                                _ => {
                                    // have to add the "receiver" argument as
                                    // first in the list for the method call
                                    let mut args = args.collect::<Vec<_>>();
                                    args.insert(0, Arc::unwrap_or_clone(left));
                                    Either::Right(Expr::ext_call(attr, args))
                                }
                            }
                        }
                        _ => return Err(access.to_ast_err(ToASTErrorKind::ExpressionCall).into()),
                    };
                }
                cst::MemAccess::Index(node) => {
                    // `x["s"]` is an attribute access, so the index must be a
                    // string literal.
                    let s = Expr::try_from(node)?
                        .into_string_literal()
                        .map_err(|_| node.to_ast_err(ToASTErrorKind::NonStringIndex))?;
                    item = match item {
                        Either::Left(name) => {
                            return Err(node
                                .to_ast_err(ToASTErrorKind::InvalidIndex(name, s))
                                .into())
                        }
                        Either::Right(expr) => Either::Right(Expr::get_attr(expr, s)),
                    };
                }
            }
        }
        // A bare name with no call after it is not a valid expression.
        match item {
            Either::Left(_) => Err(m.to_ast_err(ToASTErrorKind::MembershipInvariantViolation))?,
            Either::Right(expr) => Ok(expr),
        }
    }
}
1520
1521/// Return the single argument in `args` iterator, or return a wrong arity error
1522/// if the iterator has 0 elements or more than 1 element.
1523pub fn extract_single_argument<T>(
1524    args: impl ExactSizeIterator<Item = T>,
1525    fn_name: &'static str,
1526    loc: &Loc,
1527) -> Result<T, ParseErrors> {
1528    let mut iter = args.fuse().peekable();
1529    let first = iter.next();
1530    let second = iter.peek();
1531    match (first, second) {
1532        (None, _) => Err(ParseErrors::singleton(ToASTError::new(
1533            ToASTErrorKind::wrong_arity(fn_name, 1, 0),
1534            loc.clone(),
1535        ))),
1536        (Some(_), Some(_)) => Err(ParseErrors::singleton(ToASTError::new(
1537            ToASTErrorKind::wrong_arity(fn_name, 1, iter.len() + 1),
1538            loc.clone(),
1539        ))),
1540        (Some(first), None) => Ok(first),
1541    }
1542}
1543
1544impl TryFrom<&Node<Option<cst::Literal>>> for Expr {
1545    type Error = ParseErrors;
1546    fn try_from(lit: &Node<Option<cst::Literal>>) -> Result<Expr, ParseErrors> {
1547        match lit.try_as_inner()? {
1548            cst::Literal::True => Ok(Expr::lit(CedarValueJson::Bool(true))),
1549            cst::Literal::False => Ok(Expr::lit(CedarValueJson::Bool(false))),
1550            cst::Literal::Num(n) => Ok(Expr::lit(CedarValueJson::Long(
1551                (*n).try_into()
1552                    .map_err(|_| lit.to_ast_err(ToASTErrorKind::IntegerLiteralTooLarge(*n)))?,
1553            ))),
1554            cst::Literal::Str(node) => match node.try_as_inner()? {
1555                cst::Str::String(s) => match to_unescaped_string(s) {
1556                    Ok(s) => Ok(Expr::lit(CedarValueJson::String(s))),
1557                    Err(errs) => {
1558                        Err(ParseErrors::new_from_nonempty(errs.map(|err| {
1559                            node.to_ast_err(ToASTErrorKind::Unescape(err)).into()
1560                        })))
1561                    }
1562                },
1563                cst::Str::Invalid(invalid_str) => Err(node
1564                    .to_ast_err(ToASTErrorKind::InvalidString(invalid_str.to_string()))
1565                    .into()),
1566            },
1567        }
1568    }
1569}
1570
1571impl TryFrom<&Node<Option<cst::Name>>> for Expr {
1572    type Error = ParseErrors;
1573    fn try_from(name: &Node<Option<cst::Name>>) -> Result<Expr, ParseErrors> {
1574        let name_node = name.try_as_inner()?;
1575        let base_name = name_node.name.try_as_inner()?;
1576        match (&name_node.path[..], base_name) {
1577            (&[], cst::Ident::Principal) => Ok(Expr::var(ast::Var::Principal)),
1578            (&[], cst::Ident::Action) => Ok(Expr::var(ast::Var::Action)),
1579            (&[], cst::Ident::Resource) => Ok(Expr::var(ast::Var::Resource)),
1580            (&[], cst::Ident::Context) => Ok(Expr::var(ast::Var::Context)),
1581            (_, _) => Err(name
1582                .to_ast_err(ToASTErrorKind::ArbitraryVariable(
1583                    name_node.to_string().into(),
1584                ))
1585                .into()),
1586        }
1587    }
1588}
1589
1590/// Get the string length of an `Ident`. Used to print the source location for error messages
1591fn ident_to_str_len(i: &Ident) -> usize {
1592    match i {
1593        Ident::Principal => 9,
1594        Ident::Action => 6,
1595        Ident::Resource => 8,
1596        Ident::Context => 7,
1597        Ident::True => 4,
1598        Ident::False => 5,
1599        Ident::Permit => 6,
1600        Ident::Forbid => 6,
1601        Ident::When => 4,
1602        Ident::Unless => 6,
1603        Ident::In => 2,
1604        Ident::Has => 3,
1605        Ident::Like => 4,
1606        Ident::If => 2,
1607        Ident::Then => 4,
1608        Ident::Else => 4,
1609        Ident::Ident(s) => s.len(),
1610        Ident::Invalid(s) => s.len(),
1611        Ident::Is => 2,
1612    }
1613}
1614
1615impl std::fmt::Display for Expr {
1616    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1617        match self {
1618            Self::ExprNoExt(e) => write!(f, "{e}"),
1619            Self::ExtFuncCall(e) => write!(f, "{e}"),
1620        }
1621    }
1622}
1623
1624fn display_cedarvaluejson(f: &mut std::fmt::Formatter<'_>, v: &CedarValueJson) -> std::fmt::Result {
1625    match v {
1626        // Add parentheses around negative numeric literals otherwise
1627        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1628        CedarValueJson::Long(n) if *n < 0 => write!(f, "({n})"),
1629        CedarValueJson::Long(n) => write!(f, "{n}"),
1630        CedarValueJson::Bool(b) => write!(f, "{b}"),
1631        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1632        CedarValueJson::EntityEscape { __entity } => {
1633            match ast::EntityUID::try_from(__entity.clone()) {
1634                Ok(euid) => write!(f, "{euid}"),
1635                Err(e) => write!(f, "(invalid entity uid: {})", e),
1636            }
1637        }
1638        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1639        CedarValueJson::ExtnEscape {
1640            __extn: FnAndArg { ext_fn, arg },
1641        } => {
1642            // search for the name and callstyle
1643            let style = Extensions::all_available().all_funcs().find_map(|f| {
1644                if &f.name().to_string() == ext_fn {
1645                    Some(f.style())
1646                } else {
1647                    None
1648                }
1649            });
1650            match style {
1651                Some(ast::CallStyle::MethodStyle) => {
1652                    display_cedarvaluejson(f, arg)?;
1653                    write!(f, ".{ext_fn}()")?;
1654                    Ok(())
1655                }
1656                Some(ast::CallStyle::FunctionStyle) | None => {
1657                    write!(f, "{ext_fn}(")?;
1658                    display_cedarvaluejson(f, arg)?;
1659                    write!(f, ")")?;
1660                    Ok(())
1661                }
1662            }
1663        }
1664        CedarValueJson::Set(v) => {
1665            write!(f, "[")?;
1666            for (i, val) in v.iter().enumerate() {
1667                display_cedarvaluejson(f, val)?;
1668                if i < (v.len() - 1) {
1669                    write!(f, ", ")?;
1670                }
1671            }
1672            write!(f, "]")?;
1673            Ok(())
1674        }
1675        CedarValueJson::Record(m) => {
1676            write!(f, "{{")?;
1677            for (i, (k, v)) in m.iter().enumerate() {
1678                write!(f, "\"{}\": ", k.escape_debug())?;
1679                display_cedarvaluejson(f, v)?;
1680                if i < (m.len() - 1) {
1681                    write!(f, ", ")?;
1682                }
1683            }
1684            write!(f, "}}")?;
1685            Ok(())
1686        }
1687        CedarValueJson::Null => {
1688            write!(f, "null")?;
1689            Ok(())
1690        }
1691    }
1692}
1693
1694impl std::fmt::Display for ExprNoExt {
1695    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1696        match &self {
1697            ExprNoExt::Value(v) => display_cedarvaluejson(f, v),
1698            ExprNoExt::Var(v) => write!(f, "{v}"),
1699            ExprNoExt::Slot(id) => write!(f, "{id}"),
1700            ExprNoExt::Not { arg } => {
1701                write!(f, "!")?;
1702                maybe_with_parens(f, arg)
1703            }
1704            ExprNoExt::Neg { arg } => {
1705                // Always add parentheses instead of calling
1706                // `maybe_with_parens`.
1707                // This makes sure that we always get a negation operation back
1708                // (as opposed to e.g., a negative number) when parsing the
1709                // printed form, thus preserving the round-tripping property.
1710                write!(f, "-({arg})")
1711            }
1712            ExprNoExt::Eq { left, right } => {
1713                maybe_with_parens(f, left)?;
1714                write!(f, " == ")?;
1715                maybe_with_parens(f, right)
1716            }
1717            ExprNoExt::NotEq { left, right } => {
1718                maybe_with_parens(f, left)?;
1719                write!(f, " != ")?;
1720                maybe_with_parens(f, right)
1721            }
1722            ExprNoExt::In { left, right } => {
1723                maybe_with_parens(f, left)?;
1724                write!(f, " in ")?;
1725                maybe_with_parens(f, right)
1726            }
1727            ExprNoExt::Less { left, right } => {
1728                maybe_with_parens(f, left)?;
1729                write!(f, " < ")?;
1730                maybe_with_parens(f, right)
1731            }
1732            ExprNoExt::LessEq { left, right } => {
1733                maybe_with_parens(f, left)?;
1734                write!(f, " <= ")?;
1735                maybe_with_parens(f, right)
1736            }
1737            ExprNoExt::Greater { left, right } => {
1738                maybe_with_parens(f, left)?;
1739                write!(f, " > ")?;
1740                maybe_with_parens(f, right)
1741            }
1742            ExprNoExt::GreaterEq { left, right } => {
1743                maybe_with_parens(f, left)?;
1744                write!(f, " >= ")?;
1745                maybe_with_parens(f, right)
1746            }
1747            ExprNoExt::And { left, right } => {
1748                maybe_with_parens(f, left)?;
1749                write!(f, " && ")?;
1750                maybe_with_parens(f, right)
1751            }
1752            ExprNoExt::Or { left, right } => {
1753                maybe_with_parens(f, left)?;
1754                write!(f, " || ")?;
1755                maybe_with_parens(f, right)
1756            }
1757            ExprNoExt::Add { left, right } => {
1758                maybe_with_parens(f, left)?;
1759                write!(f, " + ")?;
1760                maybe_with_parens(f, right)
1761            }
1762            ExprNoExt::Sub { left, right } => {
1763                maybe_with_parens(f, left)?;
1764                write!(f, " - ")?;
1765                maybe_with_parens(f, right)
1766            }
1767            ExprNoExt::Mul { left, right } => {
1768                maybe_with_parens(f, left)?;
1769                write!(f, " * ")?;
1770                maybe_with_parens(f, right)
1771            }
1772            ExprNoExt::Contains { left, right } => {
1773                maybe_with_parens(f, left)?;
1774                write!(f, ".contains({right})")
1775            }
1776            ExprNoExt::ContainsAll { left, right } => {
1777                maybe_with_parens(f, left)?;
1778                write!(f, ".containsAll({right})")
1779            }
1780            ExprNoExt::ContainsAny { left, right } => {
1781                maybe_with_parens(f, left)?;
1782                write!(f, ".containsAny({right})")
1783            }
1784            ExprNoExt::GetTag { left, right } => {
1785                maybe_with_parens(f, left)?;
1786                write!(f, ".getTag({right})")
1787            }
1788            ExprNoExt::HasTag { left, right } => {
1789                maybe_with_parens(f, left)?;
1790                write!(f, ".hasTag({right})")
1791            }
1792            ExprNoExt::GetAttr { left, attr } => {
1793                maybe_with_parens(f, left)?;
1794                write!(f, "[\"{}\"]", attr.escape_debug())
1795            }
1796            ExprNoExt::HasAttr { left, attr } => {
1797                maybe_with_parens(f, left)?;
1798                write!(f, " has \"{}\"", attr.escape_debug())
1799            }
1800            ExprNoExt::Like { left, pattern } => {
1801                maybe_with_parens(f, left)?;
1802                write!(
1803                    f,
1804                    " like \"{}\"",
1805                    crate::ast::Pattern::from(pattern.clone())
1806                )
1807            }
1808            ExprNoExt::Is {
1809                left,
1810                entity_type,
1811                in_expr,
1812            } => {
1813                maybe_with_parens(f, left)?;
1814                write!(f, " is {entity_type}")?;
1815                match in_expr {
1816                    Some(in_expr) => {
1817                        write!(f, " in ")?;
1818                        maybe_with_parens(f, in_expr)
1819                    }
1820                    None => Ok(()),
1821                }
1822            }
1823            ExprNoExt::If {
1824                cond_expr,
1825                then_expr,
1826                else_expr,
1827            } => {
1828                write!(f, "if ")?;
1829                maybe_with_parens(f, cond_expr)?;
1830                write!(f, " then ")?;
1831                maybe_with_parens(f, then_expr)?;
1832                write!(f, " else ")?;
1833                maybe_with_parens(f, else_expr)
1834            }
1835            ExprNoExt::Set(v) => write!(f, "[{}]", v.iter().join(", ")),
1836            ExprNoExt::Record(m) => write!(
1837                f,
1838                "{{{}}}",
1839                m.iter()
1840                    .map(|(k, v)| format!("\"{}\": {}", k.escape_debug(), v))
1841                    .join(", ")
1842            ),
1843        }
1844    }
1845}
1846
1847impl std::fmt::Display for ExtFuncCall {
1848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1849        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1850        #[allow(clippy::unreachable)]
1851        let Some((fn_name, args)) = self.call.iter().next() else {
1852            unreachable!("invariant violated: empty ExtFuncCall")
1853        };
1854        // search for the name and callstyle
1855        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1856            if &ext_fn.name().to_string() == fn_name {
1857                Some(ext_fn.style())
1858            } else {
1859                None
1860            }
1861        });
1862        match (style, args.iter().next()) {
1863            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1864                maybe_with_parens(f, receiver)?;
1865                write!(f, ".{}({})", fn_name, args.iter().skip(1).join(", "))
1866            }
1867            (_, _) => {
1868                write!(f, "{}({})", fn_name, args.iter().join(", "))
1869            }
1870        }
1871    }
1872}
1873
1874/// returns the `Display` representation of the Expr, adding parens around
1875/// the entire string if necessary.
1876/// E.g., won't add parens for constants or `principal` etc, but will for things
1877/// like `(2 < 5)`.
1878/// When in doubt, add the parens.
1879fn maybe_with_parens(f: &mut std::fmt::Formatter<'_>, expr: &Expr) -> std::fmt::Result {
1880    match expr {
1881        Expr::ExprNoExt(ExprNoExt::Set(_)) |
1882        Expr::ExprNoExt(ExprNoExt::Record(_)) |
1883        Expr::ExprNoExt(ExprNoExt::Value(_)) |
1884        Expr::ExprNoExt(ExprNoExt::Var(_)) |
1885        Expr::ExprNoExt(ExprNoExt::Slot(_)) => write!(f, "{expr}"),
1886
1887        // we want parens here because things like parse((!x).y)
1888        // would be printed into !x.y which has a different meaning
1889        Expr::ExprNoExt(ExprNoExt::Not { .. }) |
1890        // we want parens here because things like parse((-x).y)
1891        // would be printed into -x.y which has a different meaning
1892        Expr::ExprNoExt(ExprNoExt::Neg { .. })  |
1893        Expr::ExprNoExt(ExprNoExt::Eq { .. }) |
1894        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) |
1895        Expr::ExprNoExt(ExprNoExt::In { .. }) |
1896        Expr::ExprNoExt(ExprNoExt::Less { .. }) |
1897        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) |
1898        Expr::ExprNoExt(ExprNoExt::Greater { .. }) |
1899        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) |
1900        Expr::ExprNoExt(ExprNoExt::And { .. }) |
1901        Expr::ExprNoExt(ExprNoExt::Or { .. }) |
1902        Expr::ExprNoExt(ExprNoExt::Add { .. }) |
1903        Expr::ExprNoExt(ExprNoExt::Sub { .. }) |
1904        Expr::ExprNoExt(ExprNoExt::Mul { .. }) |
1905        Expr::ExprNoExt(ExprNoExt::Contains { .. }) |
1906        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) |
1907        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) |
1908        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) |
1909        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) |
1910        Expr::ExprNoExt(ExprNoExt::GetTag { .. }) |
1911        Expr::ExprNoExt(ExprNoExt::HasTag { .. }) |
1912        Expr::ExprNoExt(ExprNoExt::Like { .. }) |
1913        Expr::ExprNoExt(ExprNoExt::Is { .. }) |
1914        Expr::ExprNoExt(ExprNoExt::If { .. }) |
1915        Expr::ExtFuncCall { .. } => write!(f, "({expr})"),
1916    }
1917}
1918
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::err::ParseError;

    use super::*;
    use cool_asserts::assert_matches;

    /// A qualified name is never a built-in variable, so converting
    /// `some_long_str::else` must fail with exactly one `ArbitraryVariable`
    /// error that carries the full printed name.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        // The source text matches the CST built below, so every `Loc` is an
        // in-bounds span of it: `some_long_str` is 0..13, `else` is 15..19,
        // and the whole name is 0..19.
        let src = "some_long_str::else";
        let path = vec![Node::with_source_loc(
            Some(cst::Ident::Ident("some_long_str".into())),
            Loc::new(0..13, Arc::from(src)),
        )];
        let name = Node::with_source_loc(Some(cst::Ident::Else), Loc::new(15..19, Arc::from(src)));
        let cst_name = Node::with_source_loc(
            Some(cst::Name { path, name }),
            Loc::new(0..19, Arc::from(src)),
        );

        assert_matches!(Expr::try_from(&cst_name), Err(e) => {
            assert_eq!(e.len(), 1);
            assert_matches!(&e[0],
                ParseError::ToAST(to_ast_error) => {
                    assert_matches!(to_ast_error.kind(), ToASTErrorKind::ArbitraryVariable(s) => {
                        assert_eq!(s, "some_long_str::else");
                    });
                }
            );
        });
    }
}