cedar_policy_core/est/
expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::FromJsonError;
18use crate::ast;
19use crate::ast::InputInteger;
20use crate::entities::json::{
21    err::EscapeKind, err::JsonDeserializationError, err::JsonDeserializationErrorContext,
22    CedarValueJson, FnAndArg, TypeAndId,
23};
24use crate::extensions::Extensions;
25use crate::parser::cst::{self, Ident};
26use crate::parser::err::{ParseErrors, ToASTError, ToASTErrorKind};
27use crate::parser::unescape::to_unescaped_string;
28use crate::parser::util::flatten_tuple_2;
29use crate::parser::{Loc, Node};
30use either::Either;
31use itertools::Itertools;
32use serde::{Deserialize, Serialize};
33use serde_with::serde_as;
34use smol_str::{SmolStr, ToSmolStr};
35use std::collections::HashMap;
36use std::sync::Arc;
37
/// Serde JSON structure for a Cedar expression in the EST format.
///
/// The enum is `#[serde(untagged)]`, so during deserialization serde tries
/// the variants in declaration order and keeps the first one that matches;
/// the order of the two variants below is therefore significant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Expr {
    /// Any Cedar expression other than an extension function call.
    /// We try to match this first, see docs on #[serde(untagged)].
    ExprNoExt(ExprNoExt),
    /// If that didn't match (because the key is not one of the keys defined in
    /// `ExprNoExt`), we assume we have an extension function call, where the
    /// key is the name of an extension function or method.
    /// Whether the key really names a known extension function is only
    /// checked later, in `Expr::try_into_ast`.
    ExtFuncCall(ExtFuncCall),
}
52
/// Represents one element of a `like` pattern literal in the EST format.
///
/// Unlike `ast::PatternElem`, whose non-wildcard element is a single `Char`,
/// a `Literal` here may hold a whole string; the conversion into
/// `ast::Pattern` expands it into one `Char` element per character.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum PatternElem {
    /// The wildcard asterisk
    Wildcard,
    /// A string without any wildcards
    Literal(SmolStr),
}
63
64impl From<Vec<PatternElem>> for crate::ast::Pattern {
65    fn from(value: Vec<PatternElem>) -> Self {
66        let mut elems = Vec::new();
67        for elem in value {
68            match elem {
69                PatternElem::Wildcard => {
70                    elems.push(crate::ast::PatternElem::Wildcard);
71                }
72                PatternElem::Literal(s) => {
73                    elems.extend(s.chars().map(crate::ast::PatternElem::Char));
74                }
75            }
76        }
77        Self::new(elems)
78    }
79}
80
81impl From<crate::ast::PatternElem> for PatternElem {
82    fn from(value: crate::ast::PatternElem) -> Self {
83        match value {
84            crate::ast::PatternElem::Wildcard => Self::Wildcard,
85            crate::ast::PatternElem::Char(c) => Self::Literal(c.to_smolstr()),
86        }
87    }
88}
89
90impl From<crate::ast::Pattern> for Vec<PatternElem> {
91    fn from(value: crate::ast::Pattern) -> Self {
92        value.iter().map(|elem| (*elem).into()).collect()
93    }
94}
95
/// Serde JSON structure for [any Cedar expression other than an extension
/// function call] in the EST format.
///
/// Operator variants carry a `#[serde(rename = ...)]` giving the JSON key
/// under which they appear (for example `"=="` or `"if-then-else"`).
/// `deny_unknown_fields` is set so that unexpected fields are rejected
/// during deserialization.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum ExprNoExt {
    /// Literal value (including anything that's legal to express in the
    /// attribute-value JSON format)
    Value(CedarValueJson),
    /// Variable (see `ast::Var`)
    Var(ast::Var),
    /// Template slot
    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
    /// Logical negation `!`
    #[serde(rename = "!")]
    Not {
        /// Argument
        arg: Arc<Expr>,
    },
    /// Unary negation `-` (keyed as `neg` in JSON, to distinguish it from
    /// binary subtraction, which is keyed as `-`)
    #[serde(rename = "neg")]
    Neg {
        /// Argument
        arg: Arc<Expr>,
    },
    /// `==`
    #[serde(rename = "==")]
    Eq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `!=`
    #[serde(rename = "!=")]
    NotEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `in`
    #[serde(rename = "in")]
    In {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<`
    #[serde(rename = "<")]
    Less {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<=`
    #[serde(rename = "<=")]
    LessEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>`
    #[serde(rename = ">")]
    Greater {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>=`
    #[serde(rename = ">=")]
    GreaterEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `&&`
    #[serde(rename = "&&")]
    And {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `||`
    #[serde(rename = "||")]
    Or {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `+`
    #[serde(rename = "+")]
    Add {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Binary subtraction `-`
    #[serde(rename = "-")]
    Sub {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `*`
    #[serde(rename = "*")]
    Mul {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `contains()`
    #[serde(rename = "contains")]
    Contains {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAll()`
    #[serde(rename = "containsAll")]
    ContainsAll {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAny()`
    #[serde(rename = "containsAny")]
    ContainsAny {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// Get-attribute (keyed as `.` in JSON)
    #[serde(rename = ".")]
    GetAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `has`
    #[serde(rename = "has")]
    HasAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `like`
    #[serde(rename = "like")]
    Like {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Pattern
        pattern: Vec<PatternElem>,
    },
    /// `<entity> is <entity_type> in <entity_or_entity_set> `
    /// (the `in` part is optional)
    #[serde(rename = "is")]
    Is {
        /// Left-hand entity argument
        left: Arc<Expr>,
        /// Entity type, kept as a string here; it is parsed into an
        /// `ast::EntityType` when converting to the AST
        entity_type: SmolStr,
        /// Entity or entity set; omitted from the serialized JSON when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        #[serde(rename = "in")]
        in_expr: Option<Arc<Expr>>,
    },
    /// Ternary
    #[serde(rename = "if-then-else")]
    If {
        /// Condition
        #[serde(rename = "if")]
        cond_expr: Arc<Expr>,
        /// `then` expression
        #[serde(rename = "then")]
        then_expr: Arc<Expr>,
        /// `else` expression
        #[serde(rename = "else")]
        else_expr: Arc<Expr>,
    },
    /// Set literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use `ExprNoExt::Value`)
    Set(Vec<Expr>),
    /// Record literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use `ExprNoExt::Value`).
    /// Duplicate keys are rejected during deserialization
    /// (`MapPreventDuplicates`).
    Record(
        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
        HashMap<SmolStr, Expr>,
    ),
}
305
/// Serde JSON structure for an extension function call in the EST format.
///
/// The single field is `#[serde(flatten)]`ed, so in JSON the call appears as
/// an object whose one key is the function name and whose value is the
/// argument list.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub struct ExtFuncCall {
    /// maps the name of the function to a JSON list/array of the arguments.
    /// Note that for method calls, the method receiver is the first argument.
    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
    /// second argument is `b`.
    ///
    /// INVARIANT: This map should always have exactly one k-v pair (not more or
    /// less), but we make it a map in order to get the correct JSON structure
    /// we want.
    ///
    /// The invariant is not enforced by deserialization itself:
    /// `Expr::try_into_ast` checks the number of entries and reports an
    /// empty map as `FromJsonError::MissingOperator` and a map with more than
    /// one entry as `FromJsonError::MultipleOperators`.
    #[serde(flatten)]
    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
    call: HashMap<SmolStr, Vec<Expr>>,
}
325
326#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
327impl Expr {
328    /// literal
329    pub fn lit(lit: CedarValueJson) -> Self {
330        Expr::ExprNoExt(ExprNoExt::Value(lit))
331    }
332
333    /// principal, action, resource, context
334    pub fn var(var: ast::Var) -> Self {
335        Expr::ExprNoExt(ExprNoExt::Var(var))
336    }
337
338    /// Template slots
339    pub fn slot(slot: ast::SlotId) -> Self {
340        Expr::ExprNoExt(ExprNoExt::Slot(slot))
341    }
342
343    /// An extension call with one arg, which is the name of the unknown
344    pub fn unknown(name: impl Into<SmolStr>) -> Self {
345        Expr::ext_call(
346            "unknown".into(),
347            vec![Expr::lit(CedarValueJson::String(name.into()))],
348        )
349    }
350
351    /// `!`
352    pub fn not(e: Expr) -> Self {
353        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
354    }
355
356    /// `-`
357    pub fn neg(e: Expr) -> Self {
358        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
359    }
360
361    /// `==`
362    pub fn eq(left: Expr, right: Expr) -> Self {
363        Expr::ExprNoExt(ExprNoExt::Eq {
364            left: Arc::new(left),
365            right: Arc::new(right),
366        })
367    }
368
369    /// `!=`
370    pub fn noteq(left: Expr, right: Expr) -> Self {
371        Expr::ExprNoExt(ExprNoExt::NotEq {
372            left: Arc::new(left),
373            right: Arc::new(right),
374        })
375    }
376
377    /// `in`
378    pub fn _in(left: Expr, right: Expr) -> Self {
379        Expr::ExprNoExt(ExprNoExt::In {
380            left: Arc::new(left),
381            right: Arc::new(right),
382        })
383    }
384
385    /// `<`
386    pub fn less(left: Expr, right: Expr) -> Self {
387        Expr::ExprNoExt(ExprNoExt::Less {
388            left: Arc::new(left),
389            right: Arc::new(right),
390        })
391    }
392
393    /// `<=`
394    pub fn lesseq(left: Expr, right: Expr) -> Self {
395        Expr::ExprNoExt(ExprNoExt::LessEq {
396            left: Arc::new(left),
397            right: Arc::new(right),
398        })
399    }
400
401    /// `>`
402    pub fn greater(left: Expr, right: Expr) -> Self {
403        Expr::ExprNoExt(ExprNoExt::Greater {
404            left: Arc::new(left),
405            right: Arc::new(right),
406        })
407    }
408
409    /// `>=`
410    pub fn greatereq(left: Expr, right: Expr) -> Self {
411        Expr::ExprNoExt(ExprNoExt::GreaterEq {
412            left: Arc::new(left),
413            right: Arc::new(right),
414        })
415    }
416
417    /// `&&`
418    pub fn and(left: Expr, right: Expr) -> Self {
419        Expr::ExprNoExt(ExprNoExt::And {
420            left: Arc::new(left),
421            right: Arc::new(right),
422        })
423    }
424
425    /// `||`
426    pub fn or(left: Expr, right: Expr) -> Self {
427        Expr::ExprNoExt(ExprNoExt::Or {
428            left: Arc::new(left),
429            right: Arc::new(right),
430        })
431    }
432
433    /// `+`
434    pub fn add(left: Expr, right: Expr) -> Self {
435        Expr::ExprNoExt(ExprNoExt::Add {
436            left: Arc::new(left),
437            right: Arc::new(right),
438        })
439    }
440
441    /// `-`
442    pub fn sub(left: Expr, right: Expr) -> Self {
443        Expr::ExprNoExt(ExprNoExt::Sub {
444            left: Arc::new(left),
445            right: Arc::new(right),
446        })
447    }
448
449    /// `*`
450    pub fn mul(left: Expr, right: Expr) -> Self {
451        Expr::ExprNoExt(ExprNoExt::Mul {
452            left: Arc::new(left),
453            right: Arc::new(right),
454        })
455    }
456
457    /// `left.contains(right)`
458    pub fn contains(left: Arc<Expr>, right: Expr) -> Self {
459        Expr::ExprNoExt(ExprNoExt::Contains {
460            left,
461            right: Arc::new(right),
462        })
463    }
464
465    /// `left.containsAll(right)`
466    pub fn contains_all(left: Arc<Expr>, right: Expr) -> Self {
467        Expr::ExprNoExt(ExprNoExt::ContainsAll {
468            left,
469            right: Arc::new(right),
470        })
471    }
472
473    /// `left.containsAny(right)`
474    pub fn contains_any(left: Arc<Expr>, right: Expr) -> Self {
475        Expr::ExprNoExt(ExprNoExt::ContainsAny {
476            left,
477            right: Arc::new(right),
478        })
479    }
480
481    /// `left.attr`
482    pub fn get_attr(left: Expr, attr: SmolStr) -> Self {
483        Expr::ExprNoExt(ExprNoExt::GetAttr {
484            left: Arc::new(left),
485            attr,
486        })
487    }
488
489    /// `left has attr`
490    pub fn has_attr(left: Expr, attr: SmolStr) -> Self {
491        Expr::ExprNoExt(ExprNoExt::HasAttr {
492            left: Arc::new(left),
493            attr,
494        })
495    }
496
497    /// `left like pattern`
498    pub fn like(left: Expr, pattern: impl IntoIterator<Item = PatternElem>) -> Self {
499        Expr::ExprNoExt(ExprNoExt::Like {
500            left: Arc::new(left),
501            pattern: pattern.into_iter().collect(),
502        })
503    }
504
505    /// `left is entity_type`
506    pub fn is_entity_type(left: Expr, entity_type: SmolStr) -> Self {
507        Expr::ExprNoExt(ExprNoExt::Is {
508            left: Arc::new(left),
509            entity_type,
510            in_expr: None,
511        })
512    }
513
514    /// `left is entity_type in entity`
515    pub fn is_entity_type_in(left: Expr, entity_type: SmolStr, entity: Expr) -> Self {
516        Expr::ExprNoExt(ExprNoExt::Is {
517            left: Arc::new(left),
518            entity_type,
519            in_expr: Some(Arc::new(entity)),
520        })
521    }
522
523    /// `if cond_expr then then_expr else else_expr`
524    pub fn ite(cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
525        Expr::ExprNoExt(ExprNoExt::If {
526            cond_expr: Arc::new(cond_expr),
527            then_expr: Arc::new(then_expr),
528            else_expr: Arc::new(else_expr),
529        })
530    }
531
532    /// e.g. [1+2, !(context has department)]
533    pub fn set(elements: Vec<Expr>) -> Self {
534        Expr::ExprNoExt(ExprNoExt::Set(elements))
535    }
536
537    /// e.g. {foo: 1+2, bar: !(context has department)}
538    pub fn record(map: HashMap<SmolStr, Expr>) -> Self {
539        Expr::ExprNoExt(ExprNoExt::Record(map))
540    }
541
542    /// extension function call, including method calls
543    pub fn ext_call(fn_name: SmolStr, args: Vec<Expr>) -> Self {
544        Expr::ExtFuncCall(ExtFuncCall {
545            call: [(fn_name, args)].into_iter().collect(),
546        })
547    }
548
549    /// Consume the `Expr`, producing a string literal if it was a string literal, otherwise returns the literal in the `Err` variant.
550    pub fn into_string_literal(self) -> Result<SmolStr, Self> {
551        match self {
552            Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::String(s))) => Ok(s),
553            _ => Err(self),
554        }
555    }
556}
557
558impl Expr {
559    /// Attempt to convert this `est::Expr` into an `ast::Expr`
560    ///
561    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
562    pub fn try_into_ast(self, id: ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
563        match self {
564            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
565                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
566                .map(Into::into)
567                .map_err(Into::into),
568            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
569            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
570            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
571                Ok(ast::Expr::not((*arg).clone().try_into_ast(id)?))
572            }
573            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
574                Ok(ast::Expr::neg((*arg).clone().try_into_ast(id)?))
575            }
576            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
577                (*left).clone().try_into_ast(id.clone())?,
578                (*right).clone().try_into_ast(id)?,
579            )),
580            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
581                (*left).clone().try_into_ast(id.clone())?,
582                (*right).clone().try_into_ast(id)?,
583            )),
584            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
585                (*left).clone().try_into_ast(id.clone())?,
586                (*right).clone().try_into_ast(id)?,
587            )),
588            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
589                (*left).clone().try_into_ast(id.clone())?,
590                (*right).clone().try_into_ast(id)?,
591            )),
592            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
593                (*left).clone().try_into_ast(id.clone())?,
594                (*right).clone().try_into_ast(id)?,
595            )),
596            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
597                (*left).clone().try_into_ast(id.clone())?,
598                (*right).clone().try_into_ast(id)?,
599            )),
600            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
601                (*left).clone().try_into_ast(id.clone())?,
602                (*right).clone().try_into_ast(id)?,
603            )),
604            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
605                (*left).clone().try_into_ast(id.clone())?,
606                (*right).clone().try_into_ast(id)?,
607            )),
608            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
609                (*left).clone().try_into_ast(id.clone())?,
610                (*right).clone().try_into_ast(id)?,
611            )),
612            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
613                (*left).clone().try_into_ast(id.clone())?,
614                (*right).clone().try_into_ast(id)?,
615            )),
616            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
617                (*left).clone().try_into_ast(id.clone())?,
618                (*right).clone().try_into_ast(id)?,
619            )),
620            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
621                (*left).clone().try_into_ast(id.clone())?,
622                (*right).clone().try_into_ast(id)?,
623            )),
624            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
625                (*left).clone().try_into_ast(id.clone())?,
626                (*right).clone().try_into_ast(id)?,
627            )),
628            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
629                (*left).clone().try_into_ast(id.clone())?,
630                (*right).clone().try_into_ast(id)?,
631            )),
632            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
633                (*left).clone().try_into_ast(id.clone())?,
634                (*right).clone().try_into_ast(id)?,
635            )),
636            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => {
637                Ok(ast::Expr::get_attr((*left).clone().try_into_ast(id)?, attr))
638            }
639            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => {
640                Ok(ast::Expr::has_attr((*left).clone().try_into_ast(id)?, attr))
641            }
642            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => Ok(ast::Expr::like(
643                (*left).clone().try_into_ast(id)?,
644                crate::ast::Pattern::from(pattern).iter().cloned(),
645            )),
646            Expr::ExprNoExt(ExprNoExt::Is {
647                left,
648                entity_type,
649                in_expr,
650            }) => ast::EntityType::from_normalized_str(entity_type.as_str())
651                .map_err(FromJsonError::InvalidEntityType)
652                .and_then(|entity_type_name| {
653                    let left: ast::Expr = (*left).clone().try_into_ast(id.clone())?;
654                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
655                    match in_expr {
656                        // The AST doesn't have an `... is ... in ..` node, so
657                        // we represent it as a conjunction of `is` and `in`.
658                        Some(in_expr) => Ok(ast::Expr::and(
659                            is_expr,
660                            ast::Expr::is_in(left, (*in_expr).clone().try_into_ast(id)?),
661                        )),
662                        None => Ok(is_expr),
663                    }
664                }),
665            Expr::ExprNoExt(ExprNoExt::If {
666                cond_expr,
667                then_expr,
668                else_expr,
669            }) => Ok(ast::Expr::ite(
670                (*cond_expr).clone().try_into_ast(id.clone())?,
671                (*then_expr).clone().try_into_ast(id.clone())?,
672                (*else_expr).clone().try_into_ast(id)?,
673            )),
674            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
675                elements
676                    .into_iter()
677                    .map(|el| el.try_into_ast(id.clone()))
678                    .collect::<Result<Vec<_>, FromJsonError>>()?,
679            )),
680            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
681                // PANIC SAFETY: can't have duplicate keys here because the input was already a HashMap
682                #[allow(clippy::expect_used)]
683                Ok(ast::Expr::record(
684                    map.into_iter()
685                        .map(|(k, v)| Ok((k, v.try_into_ast(id.clone())?)))
686                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
687                )
688                .expect("can't have duplicate keys here because the input was already a HashMap"))
689            }
690            Expr::ExtFuncCall(ExtFuncCall { call }) => {
691                match call.len() {
692                    0 => Err(FromJsonError::MissingOperator),
693                    1 => {
694                        // PANIC SAFETY checked that `call.len() == 1`
695                        #[allow(clippy::expect_used)]
696                        let (fn_name, args) = call
697                            .into_iter()
698                            .next()
699                            .expect("already checked that len was 1");
700                        let fn_name: ast::Name = fn_name.parse().map_err(|errs| {
701                            JsonDeserializationError::parse_escape(
702                                EscapeKind::Extension,
703                                fn_name,
704                                errs,
705                            )
706                        })?;
707                        if !fn_name.is_known_extension_func_name() {
708                            return Err(FromJsonError::UnknownExtensionFunction(fn_name.clone()));
709                        }
710                        Ok(ast::Expr::call_extension_fn(
711                            fn_name,
712                            args.into_iter()
713                                .map(|arg| arg.try_into_ast(id.clone()))
714                                .collect::<Result<_, _>>()?,
715                        ))
716                    }
717                    _ => Err(FromJsonError::MultipleOperators {
718                        ops: call.into_keys().collect(),
719                    }),
720                }
721            }
722        }
723    }
724}
725
726impl From<ast::Expr> for Expr {
727    fn from(expr: ast::Expr) -> Expr {
728        match expr.into_expr_kind() {
729            ast::ExprKind::Lit(lit) => lit.into(),
730            ast::ExprKind::Var(var) => var.into(),
731            ast::ExprKind::Slot(slot) => slot.into(),
732            ast::ExprKind::Unknown(ast::Unknown { name, .. }) => Expr::unknown(name),
733            ast::ExprKind::If {
734                test_expr,
735                then_expr,
736                else_expr,
737            } => Expr::ite(
738                Arc::unwrap_or_clone(test_expr).into(),
739                Arc::unwrap_or_clone(then_expr).into(),
740                Arc::unwrap_or_clone(else_expr).into(),
741            ),
742            ast::ExprKind::And { left, right } => Expr::and(
743                Arc::unwrap_or_clone(left).into(),
744                Arc::unwrap_or_clone(right).into(),
745            ),
746            ast::ExprKind::Or { left, right } => Expr::or(
747                Arc::unwrap_or_clone(left).into(),
748                Arc::unwrap_or_clone(right).into(),
749            ),
750            ast::ExprKind::UnaryApp { op, arg } => {
751                let arg = Arc::unwrap_or_clone(arg).into();
752                match op {
753                    ast::UnaryOp::Not => Expr::not(arg),
754                    ast::UnaryOp::Neg => Expr::neg(arg),
755                }
756            }
757            ast::ExprKind::BinaryApp { op, arg1, arg2 } => {
758                let arg1 = Arc::unwrap_or_clone(arg1).into();
759                let arg2 = Arc::unwrap_or_clone(arg2).into();
760                match op {
761                    ast::BinaryOp::Eq => Expr::eq(arg1, arg2),
762                    ast::BinaryOp::In => Expr::_in(arg1, arg2),
763                    ast::BinaryOp::Less => Expr::less(arg1, arg2),
764                    ast::BinaryOp::LessEq => Expr::lesseq(arg1, arg2),
765                    ast::BinaryOp::Add => Expr::add(arg1, arg2),
766                    ast::BinaryOp::Sub => Expr::sub(arg1, arg2),
767                    ast::BinaryOp::Mul => Expr::mul(arg1, arg2),
768                    ast::BinaryOp::Contains => Expr::contains(Arc::new(arg1), arg2),
769                    ast::BinaryOp::ContainsAll => Expr::contains_all(Arc::new(arg1), arg2),
770                    ast::BinaryOp::ContainsAny => Expr::contains_any(Arc::new(arg1), arg2),
771                }
772            }
773            ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
774                let args = Arc::unwrap_or_clone(args)
775                    .into_iter()
776                    .map(Into::into)
777                    .collect();
778                Expr::ext_call(fn_name.to_string().into(), args)
779            }
780            ast::ExprKind::GetAttr { expr, attr } => {
781                Expr::get_attr(Arc::unwrap_or_clone(expr).into(), attr)
782            }
783            ast::ExprKind::HasAttr { expr, attr } => {
784                Expr::has_attr(Arc::unwrap_or_clone(expr).into(), attr)
785            }
786            ast::ExprKind::Like { expr, pattern } => Expr::like(
787                Arc::unwrap_or_clone(expr).into(),
788                Vec::<PatternElem>::from(pattern),
789            ),
790            ast::ExprKind::Is { expr, entity_type } => Expr::is_entity_type(
791                Arc::unwrap_or_clone(expr).into(),
792                entity_type.to_string().into(),
793            ),
794            ast::ExprKind::Set(set) => Expr::set(
795                Arc::unwrap_or_clone(set)
796                    .into_iter()
797                    .map(Into::into)
798                    .collect(),
799            ),
800            ast::ExprKind::Record(map) => Expr::record(
801                Arc::unwrap_or_clone(map)
802                    .into_iter()
803                    .map(|(k, v)| (k, v.into()))
804                    .collect(),
805            ),
806        }
807    }
808}
809
810impl From<ast::Literal> for Expr {
811    fn from(lit: ast::Literal) -> Expr {
812        Expr::lit(CedarValueJson::from_lit(lit))
813    }
814}
815
816impl From<ast::Var> for Expr {
817    fn from(var: ast::Var) -> Expr {
818        Expr::var(var)
819    }
820}
821
822impl From<ast::SlotId> for Expr {
823    fn from(slot: ast::SlotId) -> Expr {
824        Expr::slot(slot)
825    }
826}
827
828impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
829    type Error = ParseErrors;
830    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
831        match &*e.try_as_inner()?.expr {
832            cst::ExprData::Or(node) => node.try_into(),
833            cst::ExprData::If(if_node, then_node, else_node) => {
834                let cond_expr = if_node.try_into()?;
835                let then_expr = then_node.try_into()?;
836                let else_expr = else_node.try_into()?;
837                Ok(Expr::ite(cond_expr, then_expr, else_expr))
838            }
839        }
840    }
841}
842
843impl TryFrom<&Node<Option<cst::Or>>> for Expr {
844    type Error = ParseErrors;
845    fn try_from(o: &Node<Option<cst::Or>>) -> Result<Expr, ParseErrors> {
846        let o_node = o.try_as_inner()?;
847        let mut expr = (&o_node.initial).try_into()?;
848        for node in &o_node.extended {
849            let rhs = node.try_into()?;
850            expr = Expr::or(expr, rhs);
851        }
852        Ok(expr)
853    }
854}
855
856impl TryFrom<&Node<Option<cst::And>>> for Expr {
857    type Error = ParseErrors;
858    fn try_from(a: &Node<Option<cst::And>>) -> Result<Expr, ParseErrors> {
859        let a_node = a.try_as_inner()?;
860        let mut expr = (&a_node.initial).try_into()?;
861        for node in &a_node.extended {
862            let rhs = node.try_into()?;
863            expr = Expr::and(expr, rhs);
864        }
865        Ok(expr)
866    }
867}
868
869impl TryFrom<&Node<Option<cst::Relation>>> for Expr {
870    type Error = ParseErrors;
871    fn try_from(r: &Node<Option<cst::Relation>>) -> Result<Expr, ParseErrors> {
872        match r.try_as_inner()? {
873            cst::Relation::Common { initial, extended } => {
874                let mut expr = initial.try_into()?;
875                for (op, node) in extended {
876                    let rhs = node.try_into()?;
877                    match op {
878                        cst::RelOp::Eq => {
879                            expr = Expr::eq(expr, rhs);
880                        }
881                        cst::RelOp::NotEq => {
882                            expr = Expr::noteq(expr, rhs);
883                        }
884                        cst::RelOp::In => {
885                            expr = Expr::_in(expr, rhs);
886                        }
887                        cst::RelOp::Less => {
888                            expr = Expr::less(expr, rhs);
889                        }
890                        cst::RelOp::LessEq => {
891                            expr = Expr::lesseq(expr, rhs);
892                        }
893                        cst::RelOp::Greater => {
894                            expr = Expr::greater(expr, rhs);
895                        }
896                        cst::RelOp::GreaterEq => {
897                            expr = Expr::greatereq(expr, rhs);
898                        }
899                        cst::RelOp::InvalidSingleEq => {
900                            return Err(ToASTError::new(
901                                ToASTErrorKind::InvalidSingleEq,
902                                r.loc.clone(),
903                            )
904                            .into());
905                        }
906                    }
907                }
908                Ok(expr)
909            }
910            cst::Relation::Has { target, field } => {
911                let target_expr = target.try_into()?;
912                field
913                    .to_expr_or_special()?
914                    .into_valid_attr()
915                    .map(|attr| Expr::has_attr(target_expr, attr))
916            }
917            cst::Relation::Like { target, pattern } => {
918                let target_expr = target.try_into()?;
919                pattern
920                    .to_expr_or_special()?
921                    .into_pattern()
922                    .map(|pat| Expr::like(target_expr, pat.into_iter().map(PatternElem::from)))
923            }
924            cst::Relation::IsIn {
925                target,
926                entity_type,
927                in_entity,
928            } => {
929                let target = target.try_into()?;
930                let type_str = entity_type.try_as_inner()?.to_string().into();
931                match in_entity {
932                    Some(in_entity) => Ok(Expr::is_entity_type_in(
933                        target,
934                        type_str,
935                        in_entity.try_into()?,
936                    )),
937                    None => Ok(Expr::is_entity_type(target, type_str)),
938                }
939            }
940        }
941    }
942}
943
944impl TryFrom<&Node<Option<cst::Add>>> for Expr {
945    type Error = ParseErrors;
946    fn try_from(a: &Node<Option<cst::Add>>) -> Result<Expr, ParseErrors> {
947        let a_node = a.try_as_inner()?;
948        let mut expr = (&a_node.initial).try_into()?;
949        for (op, node) in &a_node.extended {
950            let rhs = node.try_into()?;
951            match op {
952                cst::AddOp::Plus => {
953                    expr = Expr::add(expr, rhs);
954                }
955                cst::AddOp::Minus => {
956                    expr = Expr::sub(expr, rhs);
957                }
958            }
959        }
960        Ok(expr)
961    }
962}
963
964impl TryFrom<&Node<Option<cst::Mult>>> for Expr {
965    type Error = ParseErrors;
966    fn try_from(m: &Node<Option<cst::Mult>>) -> Result<Expr, ParseErrors> {
967        let m_node = m.try_as_inner()?;
968        let mut expr = (&m_node.initial).try_into()?;
969        for (op, node) in &m_node.extended {
970            let rhs = node.try_into()?;
971            match op {
972                cst::MultOp::Times => {
973                    expr = Expr::mul(expr, rhs);
974                }
975                cst::MultOp::Divide => {
976                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedDivision).into())
977                }
978                cst::MultOp::Mod => {
979                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedModulo).into())
980                }
981            }
982        }
983        Ok(expr)
984    }
985}
986
987impl TryFrom<&Node<Option<cst::Unary>>> for Expr {
988    type Error = ParseErrors;
989    fn try_from(u: &Node<Option<cst::Unary>>) -> Result<Expr, ParseErrors> {
990        let u_node = u.try_as_inner()?;
991
992        match u_node.op {
993            Some(cst::NegOp::Bang(num_bangs)) => {
994                let inner = (&u_node.item).try_into()?;
995                match num_bangs {
996                    0 => Ok(inner),
997                    1 => Ok(Expr::not(inner)),
998                    2 => Ok(Expr::not(Expr::not(inner))),
999                    3 => Ok(Expr::not(Expr::not(Expr::not(inner)))),
1000                    4 => Ok(Expr::not(Expr::not(Expr::not(Expr::not(inner))))),
1001                    _ => Err(u
1002                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
1003                        .into()),
1004                }
1005            }
1006            Some(cst::NegOp::Dash(0)) => Ok((&u_node.item).try_into()?),
1007            Some(cst::NegOp::Dash(mut num_dashes)) => {
1008                let inner = match &u_node.item.to_lit() {
1009                    Some(cst::Literal::Num(num)) => {
1010                        match num.cmp(&(InputInteger::MAX as u64 + 1)) {
1011                            std::cmp::Ordering::Less => {
1012                                num_dashes -= 1;
1013                                Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(
1014                                    -(*num as InputInteger),
1015                                )))
1016                            }
1017                            std::cmp::Ordering::Equal => {
1018                                num_dashes -= 1;
1019                                Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(
1020                                    InputInteger::MIN,
1021                                )))
1022                            }
1023                            std::cmp::Ordering::Greater => {
1024                                return Err(u_node
1025                                    .item
1026                                    .to_ast_err(ToASTErrorKind::IntegerLiteralTooLarge(*num))
1027                                    .into());
1028                            }
1029                        }
1030                    }
1031                    _ => (&u_node.item).try_into()?,
1032                };
1033                match num_dashes {
1034                    0 => Ok(inner),
1035                    1 => Ok(Expr::neg(inner)),
1036                    2 => {
1037                        // not safe to collapse `--` to nothing
1038                        Ok(Expr::neg(Expr::neg(inner)))
1039                    }
1040                    3 => Ok(Expr::neg(Expr::neg(Expr::neg(inner)))),
1041                    4 => Ok(Expr::neg(Expr::neg(Expr::neg(Expr::neg(inner))))),
1042                    _ => Err(u
1043                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1044                        .into()),
1045                }
1046            }
1047            Some(cst::NegOp::OverBang) => Err(u
1048                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
1049                .into()),
1050            Some(cst::NegOp::OverDash) => Err(u
1051                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1052                .into()),
1053            None => Ok((&u_node.item).try_into()?),
1054        }
1055    }
1056}
1057
/// Convert the given `cst::Primary` into either a (possibly namespaced)
/// function name, or an `Expr`.
///
/// (Upstream, the case where the `Primary` is a function name needs special
/// handling, because in that case it is not a valid expression. In all other
/// cases a `Primary` can be converted into an `Expr`.)
///
/// Returns `Either::Left(name)` only for a `Primary::Name` whose base
/// identifier is a plain (non-keyword) identifier; every other successful
/// conversion returns `Either::Right(expr)`.
///
/// # Errors
///
/// Returns `ParseErrors` if the node (or any nested node) is missing, if an
/// entity id or string fails to unescape, if the entity reference is not of
/// the `Type::"id"` form, or if a name is neither one of the four request
/// variables nor a plain identifier.
fn interpret_primary(
    p: &Node<Option<cst::Primary>>,
) -> Result<Either<ast::Name, Expr>, ParseErrors> {
    match p.try_as_inner()? {
        cst::Primary::Literal(lit) => Ok(Either::Right(lit.try_into()?)),
        cst::Primary::Ref(node) => match node.try_as_inner()? {
            cst::Ref::Uid {
                path,
                eid: eid_node,
            } => {
                let maybe_name = path.to_name().map(ast::EntityType::from);
                let maybe_eid = eid_node.as_valid_string();

                // Both the type name and the id must be valid to continue.
                // NOTE(review): `flatten_tuple_2` presumably merges the errors
                // of both sides when both fail -- confirm against its definition.
                let (name, eid) = flatten_tuple_2(maybe_name, maybe_eid)?;
                match to_unescaped_string(eid) {
                    Ok(eid) => Ok(Either::Right(Expr::lit(CedarValueJson::EntityEscape {
                        __entity: TypeAndId::from(ast::EntityUID::from_components(
                            name,
                            ast::Eid::new(eid),
                            None,
                        )),
                    }))),
                    Err(unescape_errs) => {
                        // Report every unescape error, each located at the
                        // entity id node.
                        Err(ParseErrors::new_from_nonempty(unescape_errs.map(|err| {
                            {
                                crate::parser::err::ParseError::from(
                                    eid_node.to_ast_err(ToASTErrorKind::Unescape(err)),
                                )
                            }
                        })))
                    }
                }
            }
            // The other reference form is not accepted as an entity literal here.
            r @ cst::Ref::Ref { .. } => Err(node
                .to_ast_err(ToASTErrorKind::InvalidEntityLiteral(r.to_string()))
                .into()),
        },
        cst::Primary::Name(node) => {
            let name = node.try_as_inner()?;
            let base_name = name.name.try_as_inner()?;
            match (&name.path[..], base_name) {
                // The four request variables are only recognized without a
                // namespace path.
                (&[], cst::Ident::Principal) => Ok(Either::Right(Expr::var(ast::Var::Principal))),
                (&[], cst::Ident::Action) => Ok(Either::Right(Expr::var(ast::Var::Action))),
                (&[], cst::Ident::Resource) => Ok(Either::Right(Expr::var(ast::Var::Resource))),
                (&[], cst::Ident::Context) => Ok(Either::Right(Expr::var(ast::Var::Context))),
                // A plain identifier (with any path) is a candidate function
                // name; the caller decides whether it is actually called.
                (path, cst::Ident::Ident(id)) => Ok(Either::Left(
                    ast::InternalName::new(
                        id.parse()?,
                        path.iter()
                            .map(|node| {
                                node.try_as_inner()
                                    .map_err(Into::into)
                                    .and_then(|id| id.to_string().parse().map_err(Into::into))
                            })
                            .collect::<Result<Vec<ast::Id>, ParseErrors>>()?,
                        Some(node.loc.clone()),
                    )
                    .try_into()?,
                )),
                // Anything else (a keyword as base name, or a namespaced
                // request variable) is reported as an arbitrary variable.
                (path, id) => {
                    // The error span runs from the start of the first path
                    // element to the end of the last path element plus the
                    // length of the base identifier.
                    // NOTE(review): with an empty path the location falls back
                    // to an empty `0..0` span over an empty source string, so
                    // the error carries no useful position in that case.
                    let (l, r, src) = match (path.first(), path.last()) {
                        (Some(l), Some(r)) => (
                            l.loc.start(),
                            r.loc.end() + ident_to_str_len(id),
                            Arc::clone(&l.loc.src),
                        ),
                        (_, _) => (0, 0, Arc::from("")),
                    };
                    Err(ToASTError::new(
                        ToASTErrorKind::ArbitraryVariable(name.to_string().into()),
                        Loc::new(l..r, src),
                    )
                    .into())
                }
            }
        }
        cst::Primary::Slot(node) => Ok(Either::Right(Expr::slot(
            node.try_as_inner()?
                .try_into()
                .map_err(|e| node.to_ast_err(e))?,
        ))),
        // Parenthesized expression.
        cst::Primary::Expr(e) => Ok(Either::Right(e.try_into()?)),
        // Set literal: every element must convert; the first failure is returned.
        cst::Primary::EList(nodes) => nodes
            .iter()
            .map(|node| node.try_into())
            .collect::<Result<Vec<Expr>, _>>()
            .map(Expr::set)
            .map(Either::Right),
        // Record literal: keys must be valid attribute names. Entries are
        // collected into a `HashMap`, so a repeated key keeps only one entry.
        cst::Primary::RInits(nodes) => nodes
            .iter()
            .map(|node| {
                let cst::RecInit(k, v) = node.try_as_inner()?;
                let s = k.to_expr_or_special().and_then(|es| es.into_valid_attr())?;
                Ok((s, v.try_into()?))
            })
            .collect::<Result<HashMap<SmolStr, Expr>, ParseErrors>>()
            .map(Expr::record)
            .map(Either::Right),
    }
}
1164
impl TryFrom<&Node<Option<cst::Member>>> for Expr {
    type Error = ParseErrors;
    /// Converts a CST member expression (a primary followed by any number of
    /// `.field`, `(args)` and `["index"]` accesses) into an EST expression.
    ///
    /// The primary is first interpreted as either a function name or an
    /// expression; each access is then applied in order. A function name is
    /// only valid as the direct target of a call: accessing a field or index
    /// on it, or leaving it uncalled, is an error.
    fn try_from(m: &Node<Option<cst::Member>>) -> Result<Expr, ParseErrors> {
        let m_node = m.try_as_inner()?;
        let mut item: Either<ast::Name, Expr> = interpret_primary(&m_node.item)?;
        for access in &m_node.access {
            match access.try_as_inner()? {
                cst::MemAccess::Field(node) => {
                    let id = node.to_valid_ident()?;
                    item = match item {
                        // `name.field` where `name` is a function name.
                        Either::Left(name) => {
                            return Err(node
                                .to_ast_err(ToASTErrorKind::InvalidAccess(name, id.to_smolstr()))
                                .into())
                        }
                        Either::Right(expr) => Either::Right(Expr::get_attr(expr, id.to_smolstr())),
                    };
                }
                cst::MemAccess::Call(args) => {
                    // we have item(args).  We hope item is either:
                    //   - an `ast::Name`, in which case we have a standard function call
                    //   - or an expr of the form `x.y`, in which case y is the method
                    //      name and not a field name. In the previous iteration of the
                    //      `for` loop we would have made `item` equal to
                    //      `Expr::GetAttr(x, y)`. Now we have to undo that to make a
                    //      method call instead.
                    //   - any other expression: it's an illegal call as the target is a higher order expression
                    item = match item {
                        Either::Left(name) => Either::Right(Expr::ext_call(
                            name.to_string().into(),
                            args.iter()
                                .map(|node| node.try_into())
                                .collect::<Result<Vec<_>, _>>()?,
                        )),
                        Either::Right(Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr })) => {
                            // Arguments are converted before the method name
                            // is inspected, so argument errors are reported
                            // ahead of arity errors.
                            let args = args.iter().map(|node| node.try_into()).collect::<Result<
                                Vec<Expr>,
                                ParseErrors,
                            >>(
                            )?;
                            let args = args.into_iter();
                            match attr.as_str() {
                                // The three built-in set methods take exactly
                                // one argument.
                                "contains" => Either::Right(Expr::contains(
                                    left,
                                    extract_single_argument(args, "contains()", &access.loc)?,
                                )),
                                "containsAll" => Either::Right(Expr::contains_all(
                                    left,
                                    extract_single_argument(args, "containsAll()", &access.loc)?,
                                )),
                                "containsAny" => Either::Right(Expr::contains_any(
                                    left,
                                    extract_single_argument(args, "containsAny()", &access.loc)?,
                                )),
                                // Any other method name is treated as an
                                // extension function call.
                                _ => {
                                    // have to add the "receiver" argument as
                                    // first in the list for the method call
                                    let mut args = args.collect::<Vec<_>>();
                                    args.insert(0, Arc::unwrap_or_clone(left));
                                    Either::Right(Expr::ext_call(attr, args))
                                }
                            }
                        }
                        _ => return Err(access.to_ast_err(ToASTErrorKind::ExpressionCall).into()),
                    };
                }
                cst::MemAccess::Index(node) => {
                    // Only string literals are accepted as an index; the index
                    // is then treated exactly like a field access.
                    let s = Expr::try_from(node)?
                        .into_string_literal()
                        .map_err(|_| node.to_ast_err(ToASTErrorKind::NonStringIndex))?;
                    item = match item {
                        // `name["index"]` where `name` is a function name.
                        Either::Left(name) => {
                            return Err(node
                                .to_ast_err(ToASTErrorKind::InvalidIndex(name, s))
                                .into())
                        }
                        Either::Right(expr) => Either::Right(Expr::get_attr(expr, s)),
                    };
                }
            }
        }
        match item {
            // A function name that was never called is not an expression.
            Either::Left(_) => Err(m.to_ast_err(ToASTErrorKind::MembershipInvariantViolation))?,
            Either::Right(expr) => Ok(expr),
        }
    }
}
1252
1253/// Return the single argument in `args` iterator, or return a wrong arity error
1254/// if the iterator has 0 elements or more than 1 element.
1255pub fn extract_single_argument<T>(
1256    args: impl ExactSizeIterator<Item = T>,
1257    fn_name: &'static str,
1258    loc: &Loc,
1259) -> Result<T, ParseErrors> {
1260    let mut iter = args.fuse().peekable();
1261    let first = iter.next();
1262    let second = iter.peek();
1263    match (first, second) {
1264        (None, _) => Err(ParseErrors::singleton(ToASTError::new(
1265            ToASTErrorKind::wrong_arity(fn_name, 1, 0),
1266            loc.clone(),
1267        ))),
1268        (Some(_), Some(_)) => Err(ParseErrors::singleton(ToASTError::new(
1269            ToASTErrorKind::wrong_arity(fn_name, 1, iter.len() + 1),
1270            loc.clone(),
1271        ))),
1272        (Some(first), None) => Ok(first),
1273    }
1274}
1275
1276impl TryFrom<&Node<Option<cst::Literal>>> for Expr {
1277    type Error = ParseErrors;
1278    fn try_from(lit: &Node<Option<cst::Literal>>) -> Result<Expr, ParseErrors> {
1279        match lit.try_as_inner()? {
1280            cst::Literal::True => Ok(Expr::lit(CedarValueJson::Bool(true))),
1281            cst::Literal::False => Ok(Expr::lit(CedarValueJson::Bool(false))),
1282            cst::Literal::Num(n) => Ok(Expr::lit(CedarValueJson::Long(
1283                (*n).try_into()
1284                    .map_err(|_| lit.to_ast_err(ToASTErrorKind::IntegerLiteralTooLarge(*n)))?,
1285            ))),
1286            cst::Literal::Str(node) => match node.try_as_inner()? {
1287                cst::Str::String(s) => match to_unescaped_string(s) {
1288                    Ok(s) => Ok(Expr::lit(CedarValueJson::String(s))),
1289                    Err(errs) => {
1290                        Err(ParseErrors::new_from_nonempty(errs.map(|err| {
1291                            node.to_ast_err(ToASTErrorKind::Unescape(err)).into()
1292                        })))
1293                    }
1294                },
1295                cst::Str::Invalid(invalid_str) => Err(node
1296                    .to_ast_err(ToASTErrorKind::InvalidString(invalid_str.to_string()))
1297                    .into()),
1298            },
1299        }
1300    }
1301}
1302
1303impl TryFrom<&Node<Option<cst::Name>>> for Expr {
1304    type Error = ParseErrors;
1305    fn try_from(name: &Node<Option<cst::Name>>) -> Result<Expr, ParseErrors> {
1306        let name_node = name.try_as_inner()?;
1307        let base_name = name_node.name.try_as_inner()?;
1308        match (&name_node.path[..], base_name) {
1309            (&[], cst::Ident::Principal) => Ok(Expr::var(ast::Var::Principal)),
1310            (&[], cst::Ident::Action) => Ok(Expr::var(ast::Var::Action)),
1311            (&[], cst::Ident::Resource) => Ok(Expr::var(ast::Var::Resource)),
1312            (&[], cst::Ident::Context) => Ok(Expr::var(ast::Var::Context)),
1313            (_, _) => Err(name
1314                .to_ast_err(ToASTErrorKind::ArbitraryVariable(
1315                    name_node.to_string().into(),
1316                ))
1317                .into()),
1318        }
1319    }
1320}
1321
1322/// Get the string length of an `Ident`. Used to print the source location for error messages
1323fn ident_to_str_len(i: &Ident) -> usize {
1324    match i {
1325        Ident::Principal => 9,
1326        Ident::Action => 6,
1327        Ident::Resource => 8,
1328        Ident::Context => 7,
1329        Ident::True => 4,
1330        Ident::False => 5,
1331        Ident::Permit => 6,
1332        Ident::Forbid => 6,
1333        Ident::When => 4,
1334        Ident::Unless => 6,
1335        Ident::In => 2,
1336        Ident::Has => 3,
1337        Ident::Like => 4,
1338        Ident::If => 2,
1339        Ident::Then => 4,
1340        Ident::Else => 4,
1341        Ident::Ident(s) => s.len(),
1342        Ident::Invalid(s) => s.len(),
1343        Ident::Is => 2,
1344    }
1345}
1346
1347impl std::fmt::Display for Expr {
1348    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1349        match self {
1350            Self::ExprNoExt(e) => write!(f, "{e}"),
1351            Self::ExtFuncCall(e) => write!(f, "{e}"),
1352        }
1353    }
1354}
1355
1356fn display_cedarvaluejson(f: &mut std::fmt::Formatter<'_>, v: &CedarValueJson) -> std::fmt::Result {
1357    match v {
1358        // Add parentheses around negative numeric literals otherwise
1359        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1360        CedarValueJson::Long(n) if *n < 0 => write!(f, "({n})"),
1361        CedarValueJson::Long(n) => write!(f, "{n}"),
1362        CedarValueJson::Bool(b) => write!(f, "{b}"),
1363        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1364        CedarValueJson::EntityEscape { __entity } => {
1365            match ast::EntityUID::try_from(__entity.clone()) {
1366                Ok(euid) => write!(f, "{euid}"),
1367                Err(e) => write!(f, "(invalid entity uid: {})", e),
1368            }
1369        }
1370        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1371        CedarValueJson::ExtnEscape {
1372            __extn: FnAndArg { ext_fn, arg },
1373        } => {
1374            // search for the name and callstyle
1375            let style = Extensions::all_available().all_funcs().find_map(|f| {
1376                if &f.name().to_string() == ext_fn {
1377                    Some(f.style())
1378                } else {
1379                    None
1380                }
1381            });
1382            match style {
1383                Some(ast::CallStyle::MethodStyle) => {
1384                    display_cedarvaluejson(f, arg)?;
1385                    write!(f, ".{ext_fn}()")?;
1386                    Ok(())
1387                }
1388                Some(ast::CallStyle::FunctionStyle) | None => {
1389                    write!(f, "{ext_fn}(")?;
1390                    display_cedarvaluejson(f, arg)?;
1391                    write!(f, ")")?;
1392                    Ok(())
1393                }
1394            }
1395        }
1396        CedarValueJson::Set(v) => {
1397            write!(f, "[")?;
1398            for (i, val) in v.iter().enumerate() {
1399                display_cedarvaluejson(f, val)?;
1400                if i < (v.len() - 1) {
1401                    write!(f, ", ")?;
1402                }
1403            }
1404            write!(f, "]")?;
1405            Ok(())
1406        }
1407        CedarValueJson::Record(m) => {
1408            write!(f, "{{")?;
1409            for (i, (k, v)) in m.iter().enumerate() {
1410                write!(f, "\"{}\": ", k.escape_debug())?;
1411                display_cedarvaluejson(f, v)?;
1412                if i < (m.len() - 1) {
1413                    write!(f, ", ")?;
1414                }
1415            }
1416            write!(f, "}}")?;
1417            Ok(())
1418        }
1419        CedarValueJson::Null => {
1420            write!(f, "null")?;
1421            Ok(())
1422        }
1423    }
1424}
1425
1426impl std::fmt::Display for ExprNoExt {
1427    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1428        match &self {
1429            ExprNoExt::Value(v) => display_cedarvaluejson(f, v),
1430            ExprNoExt::Var(v) => write!(f, "{v}"),
1431            ExprNoExt::Slot(id) => write!(f, "{id}"),
1432            ExprNoExt::Not { arg } => {
1433                write!(f, "!")?;
1434                maybe_with_parens(f, arg)
1435            }
1436            ExprNoExt::Neg { arg } => {
1437                // Always add parentheses instead of calling
1438                // `maybe_with_parens`.
1439                // This makes sure that we always get a negation operation back
1440                // (as opposed to e.g., a negative number) when parsing the
1441                // printed form, thus preserving the round-tripping property.
1442                write!(f, "-({arg})")
1443            }
1444            ExprNoExt::Eq { left, right } => {
1445                maybe_with_parens(f, left)?;
1446                write!(f, " == ")?;
1447                maybe_with_parens(f, right)
1448            }
1449            ExprNoExt::NotEq { left, right } => {
1450                maybe_with_parens(f, left)?;
1451                write!(f, " != ")?;
1452                maybe_with_parens(f, right)
1453            }
1454            ExprNoExt::In { left, right } => {
1455                maybe_with_parens(f, left)?;
1456                write!(f, " in ")?;
1457                maybe_with_parens(f, right)
1458            }
1459            ExprNoExt::Less { left, right } => {
1460                maybe_with_parens(f, left)?;
1461                write!(f, " < ")?;
1462                maybe_with_parens(f, right)
1463            }
1464            ExprNoExt::LessEq { left, right } => {
1465                maybe_with_parens(f, left)?;
1466                write!(f, " <= ")?;
1467                maybe_with_parens(f, right)
1468            }
1469            ExprNoExt::Greater { left, right } => {
1470                maybe_with_parens(f, left)?;
1471                write!(f, " > ")?;
1472                maybe_with_parens(f, right)
1473            }
1474            ExprNoExt::GreaterEq { left, right } => {
1475                maybe_with_parens(f, left)?;
1476                write!(f, " >= ")?;
1477                maybe_with_parens(f, right)
1478            }
1479            ExprNoExt::And { left, right } => {
1480                maybe_with_parens(f, left)?;
1481                write!(f, " && ")?;
1482                maybe_with_parens(f, right)
1483            }
1484            ExprNoExt::Or { left, right } => {
1485                maybe_with_parens(f, left)?;
1486                write!(f, " || ")?;
1487                maybe_with_parens(f, right)
1488            }
1489            ExprNoExt::Add { left, right } => {
1490                maybe_with_parens(f, left)?;
1491                write!(f, " + ")?;
1492                maybe_with_parens(f, right)
1493            }
1494            ExprNoExt::Sub { left, right } => {
1495                maybe_with_parens(f, left)?;
1496                write!(f, " - ")?;
1497                maybe_with_parens(f, right)
1498            }
1499            ExprNoExt::Mul { left, right } => {
1500                maybe_with_parens(f, left)?;
1501                write!(f, " * ")?;
1502                maybe_with_parens(f, right)
1503            }
1504            ExprNoExt::Contains { left, right } => {
1505                maybe_with_parens(f, left)?;
1506                write!(f, ".contains({right})")
1507            }
1508            ExprNoExt::ContainsAll { left, right } => {
1509                maybe_with_parens(f, left)?;
1510                write!(f, ".containsAll({right})")
1511            }
1512            ExprNoExt::ContainsAny { left, right } => {
1513                maybe_with_parens(f, left)?;
1514                write!(f, ".containsAny({right})")
1515            }
1516            ExprNoExt::GetAttr { left, attr } => {
1517                maybe_with_parens(f, left)?;
1518                write!(f, "[\"{}\"]", attr.escape_debug())
1519            }
1520            ExprNoExt::HasAttr { left, attr } => {
1521                maybe_with_parens(f, left)?;
1522                write!(f, " has \"{}\"", attr.escape_debug())
1523            }
1524            ExprNoExt::Like { left, pattern } => {
1525                maybe_with_parens(f, left)?;
1526                write!(
1527                    f,
1528                    " like \"{}\"",
1529                    crate::ast::Pattern::from(pattern.clone())
1530                )
1531            }
1532            ExprNoExt::Is {
1533                left,
1534                entity_type,
1535                in_expr,
1536            } => {
1537                maybe_with_parens(f, left)?;
1538                write!(f, " is {entity_type}")?;
1539                match in_expr {
1540                    Some(in_expr) => {
1541                        write!(f, " in ")?;
1542                        maybe_with_parens(f, in_expr)
1543                    }
1544                    None => Ok(()),
1545                }
1546            }
1547            ExprNoExt::If {
1548                cond_expr,
1549                then_expr,
1550                else_expr,
1551            } => {
1552                write!(f, "if ")?;
1553                maybe_with_parens(f, cond_expr)?;
1554                write!(f, " then ")?;
1555                maybe_with_parens(f, then_expr)?;
1556                write!(f, " else ")?;
1557                maybe_with_parens(f, else_expr)
1558            }
1559            ExprNoExt::Set(v) => write!(f, "[{}]", v.iter().join(", ")),
1560            ExprNoExt::Record(m) => write!(
1561                f,
1562                "{{{}}}",
1563                m.iter()
1564                    .map(|(k, v)| format!("\"{}\": {}", k.escape_debug(), v))
1565                    .join(", ")
1566            ),
1567        }
1568    }
1569}
1570
1571impl std::fmt::Display for ExtFuncCall {
1572    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1573        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1574        #[allow(clippy::unreachable)]
1575        let Some((fn_name, args)) = self.call.iter().next() else {
1576            unreachable!("invariant violated: empty ExtFuncCall")
1577        };
1578        // search for the name and callstyle
1579        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1580            if &ext_fn.name().to_string() == fn_name {
1581                Some(ext_fn.style())
1582            } else {
1583                None
1584            }
1585        });
1586        match (style, args.iter().next()) {
1587            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1588                maybe_with_parens(f, receiver)?;
1589                write!(f, ".{}({})", fn_name, args.iter().skip(1).join(", "))
1590            }
1591            (_, _) => {
1592                write!(f, "{}({})", fn_name, args.iter().join(", "))
1593            }
1594        }
1595    }
1596}
1597
1598/// returns the `Display` representation of the Expr, adding parens around
1599/// the entire string if necessary.
1600/// E.g., won't add parens for constants or `principal` etc, but will for things
1601/// like `(2 < 5)`.
1602/// When in doubt, add the parens.
1603fn maybe_with_parens(f: &mut std::fmt::Formatter<'_>, expr: &Expr) -> std::fmt::Result {
1604    match expr {
1605        Expr::ExprNoExt(ExprNoExt::Set(_)) |
1606        Expr::ExprNoExt(ExprNoExt::Record(_)) |
1607        Expr::ExprNoExt(ExprNoExt::Value(_)) |
1608        Expr::ExprNoExt(ExprNoExt::Var(_)) |
1609        Expr::ExprNoExt(ExprNoExt::Slot(_)) => write!(f, "{expr}"),
1610
1611        // we want parens here because things like parse((!x).y)
1612        // would be printed into !x.y which has a different meaning
1613        Expr::ExprNoExt(ExprNoExt::Not { .. }) |
1614        // we want parens here because things like parse((-x).y)
1615        // would be printed into -x.y which has a different meaning
1616        Expr::ExprNoExt(ExprNoExt::Neg { .. })  |
1617        Expr::ExprNoExt(ExprNoExt::Eq { .. }) |
1618        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) |
1619        Expr::ExprNoExt(ExprNoExt::In { .. }) |
1620        Expr::ExprNoExt(ExprNoExt::Less { .. }) |
1621        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) |
1622        Expr::ExprNoExt(ExprNoExt::Greater { .. }) |
1623        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) |
1624        Expr::ExprNoExt(ExprNoExt::And { .. }) |
1625        Expr::ExprNoExt(ExprNoExt::Or { .. }) |
1626        Expr::ExprNoExt(ExprNoExt::Add { .. }) |
1627        Expr::ExprNoExt(ExprNoExt::Sub { .. }) |
1628        Expr::ExprNoExt(ExprNoExt::Mul { .. }) |
1629        Expr::ExprNoExt(ExprNoExt::Contains { .. }) |
1630        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) |
1631        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) |
1632        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) |
1633        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) |
1634        Expr::ExprNoExt(ExprNoExt::Like { .. }) |
1635        Expr::ExprNoExt(ExprNoExt::Is { .. }) |
1636        Expr::ExprNoExt(ExprNoExt::If { .. }) |
1637        Expr::ExtFuncCall { .. } => write!(f, "({expr})"),
1638    }
1639}
1640
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::err::ParseError;

    use super::*;
    use cool_asserts::assert_matches;

    /// Converting a CST name whose final component is the `else` keyword
    /// (here `some_long_str::else`) into an `Expr` must fail with exactly one
    /// `ArbitraryVariable` error that reports the full name.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        let text = "some_long_str";

        // Final component of the name: the reserved identifier `else`.
        let last_component =
            Node::with_source_loc(Some(cst::Ident::Else), Loc::new(13..16, Arc::from(text)));
        // Namespace prefix: a single ordinary identifier.
        let prefix = vec![Node::with_source_loc(
            Some(cst::Ident::Ident(text.into())),
            Loc::new(0..12, Arc::from(text)),
        )];
        let cst_name = Node::with_source_loc(
            Some(cst::Name {
                path: prefix,
                name: last_component,
            }),
            Loc::new(0..16, Arc::from(text)),
        );

        let conversion = Expr::try_from(&cst_name);
        assert_matches!(conversion, Err(errs) => {
            assert!(errs.len() == 1);
            assert_matches!(&errs[0], ParseError::ToAST(to_ast_error) => {
                assert_matches!(to_ast_error.kind(), ToASTErrorKind::ArbitraryVariable(reported) => {
                    assert_eq!(reported, "some_long_str::else");
                });
            });
        });
    }
}