Skip to main content

cedar_policy_core/est/
expr.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use super::FromJsonError;
18use crate::ast::InputInteger;
19use crate::entities::{
20    CedarValueJson, EscapeKind, FnAndArg, JsonDeserializationError,
21    JsonDeserializationErrorContext, TypeAndId,
22};
23use crate::extensions::Extensions;
24use crate::parser::cst::{self, Ident};
25use crate::parser::err::{ParseErrors, ToASTError, ToASTErrorKind};
26use crate::parser::{unescape, Loc, Node};
27use crate::{ast, FromNormalizedStr};
28use either::Either;
29use itertools::Itertools;
30use serde::{Deserialize, Serialize};
31use serde_with::serde_as;
32use smol_str::{SmolStr, ToSmolStr};
33use std::collections::HashMap;
34use std::sync::Arc;
35
36/// Serde JSON structure for a Cedar expression in the EST format
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38#[serde(untagged)]
39#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
40#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
41pub enum Expr {
42    /// Any Cedar expression other than an extension function call.
43    /// We try to match this first, see docs on #[serde(untagged)].
44    ExprNoExt(ExprNoExt),
45    /// If that didn't match (because the key is not one of the keys defined in
46    /// `ExprNoExt`), we assume we have an extension function call, where the
47    /// key is the name of an extension function or method.
48    ExtFuncCall(ExtFuncCall),
49}
50
51/// Serde JSON structure for [any Cedar expression other than an extension
52/// function call] in the EST format
53#[serde_as]
54#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55#[serde(deny_unknown_fields)]
56#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
57#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
58pub enum ExprNoExt {
59    /// Literal value (including anything that's legal to express in the
60    /// attribute-value JSON format)
61    Value(CedarValueJson),
62    /// Var
63    Var(ast::Var),
64    /// Template slot
65    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
66    /// Unknown (for partial evaluation)
67    Unknown {
68        /// Name of the unknown
69        #[cfg_attr(feature = "wasm", tsify(type = "string"))]
70        name: SmolStr,
71    },
72    /// `!`
73    #[serde(rename = "!")]
74    Not {
75        /// Argument
76        arg: Arc<Expr>,
77    },
78    /// `-`
79    #[serde(rename = "neg")]
80    Neg {
81        /// Argument
82        arg: Arc<Expr>,
83    },
84    /// `==`
85    #[serde(rename = "==")]
86    Eq {
87        /// Left-hand argument
88        left: Arc<Expr>,
89        /// Right-hand argument
90        right: Arc<Expr>,
91    },
92    /// `!=`
93    #[serde(rename = "!=")]
94    NotEq {
95        /// Left-hand argument
96        left: Arc<Expr>,
97        /// Right-hand argument
98        right: Arc<Expr>,
99    },
100    /// `in`
101    #[serde(rename = "in")]
102    In {
103        /// Left-hand argument
104        left: Arc<Expr>,
105        /// Right-hand argument
106        right: Arc<Expr>,
107    },
108    /// `<`
109    #[serde(rename = "<")]
110    Less {
111        /// Left-hand argument
112        left: Arc<Expr>,
113        /// Right-hand argument
114        right: Arc<Expr>,
115    },
116    /// `<=`
117    #[serde(rename = "<=")]
118    LessEq {
119        /// Left-hand argument
120        left: Arc<Expr>,
121        /// Right-hand argument
122        right: Arc<Expr>,
123    },
124    /// `>`
125    #[serde(rename = ">")]
126    Greater {
127        /// Left-hand argument
128        left: Arc<Expr>,
129        /// Right-hand argument
130        right: Arc<Expr>,
131    },
132    /// `>=`
133    #[serde(rename = ">=")]
134    GreaterEq {
135        /// Left-hand argument
136        left: Arc<Expr>,
137        /// Right-hand argument
138        right: Arc<Expr>,
139    },
140    /// `&&`
141    #[serde(rename = "&&")]
142    And {
143        /// Left-hand argument
144        left: Arc<Expr>,
145        /// Right-hand argument
146        right: Arc<Expr>,
147    },
148    /// `||`
149    #[serde(rename = "||")]
150    Or {
151        /// Left-hand argument
152        left: Arc<Expr>,
153        /// Right-hand argument
154        right: Arc<Expr>,
155    },
156    /// `+`
157    #[serde(rename = "+")]
158    Add {
159        /// Left-hand argument
160        left: Arc<Expr>,
161        /// Right-hand argument
162        right: Arc<Expr>,
163    },
164    /// `-`
165    #[serde(rename = "-")]
166    Sub {
167        /// Left-hand argument
168        left: Arc<Expr>,
169        /// Right-hand argument
170        right: Arc<Expr>,
171    },
172    /// `*`
173    #[serde(rename = "*")]
174    Mul {
175        /// Left-hand argument
176        left: Arc<Expr>,
177        /// Right-hand argument
178        right: Arc<Expr>,
179    },
180    /// `contains()`
181    #[serde(rename = "contains")]
182    Contains {
183        /// Left-hand argument (receiver)
184        left: Arc<Expr>,
185        /// Right-hand argument (inside the `()`)
186        right: Arc<Expr>,
187    },
188    /// `containsAll()`
189    #[serde(rename = "containsAll")]
190    ContainsAll {
191        /// Left-hand argument (receiver)
192        left: Arc<Expr>,
193        /// Right-hand argument (inside the `()`)
194        right: Arc<Expr>,
195    },
196    /// `containsAny()`
197    #[serde(rename = "containsAny")]
198    ContainsAny {
199        /// Left-hand argument (receiver)
200        left: Arc<Expr>,
201        /// Right-hand argument (inside the `()`)
202        right: Arc<Expr>,
203    },
204    /// Get-attribute
205    #[serde(rename = ".")]
206    GetAttr {
207        /// Left-hand argument
208        left: Arc<Expr>,
209        /// Attribute name
210        attr: SmolStr,
211    },
212    /// `has`
213    #[serde(rename = "has")]
214    HasAttr {
215        /// Left-hand argument
216        left: Arc<Expr>,
217        /// Attribute name
218        attr: SmolStr,
219    },
220    /// `like`
221    #[serde(rename = "like")]
222    Like {
223        /// Left-hand argument
224        left: Arc<Expr>,
225        /// Pattern
226        pattern: SmolStr,
227    },
228    /// `<entity> is <entity_type> in <entity_or_entity_set> `
229    #[serde(rename = "is")]
230    Is {
231        /// Left-hand entity argument
232        left: Arc<Expr>,
233        /// Entity type
234        entity_type: SmolStr,
235        /// Entity or entity set
236        #[serde(skip_serializing_if = "Option::is_none")]
237        #[serde(rename = "in")]
238        in_expr: Option<Arc<Expr>>,
239    },
240    /// Ternary
241    #[serde(rename = "if-then-else")]
242    If {
243        /// Condition
244        #[serde(rename = "if")]
245        cond_expr: Arc<Expr>,
246        /// `then` expression
247        #[serde(rename = "then")]
248        then_expr: Arc<Expr>,
249        /// `else` expression
250        #[serde(rename = "else")]
251        else_expr: Arc<Expr>,
252    },
253    /// Set literal, whose elements may be arbitrary expressions
254    /// (which is why we need this case specifically and can't just
255    /// use Expr::Value)
256    Set(Vec<Expr>),
257    /// Record literal, whose elements may be arbitrary expressions
258    /// (which is why we need this case specifically and can't just
259    /// use Expr::Value)
260    Record(
261        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
262        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
263        HashMap<SmolStr, Expr>,
264    ),
265}
266
267/// Serde JSON structure for an extension function call in the EST format
268#[serde_as]
269#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
270#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
271#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
272pub struct ExtFuncCall {
273    /// maps the name of the function to a JSON list/array of the arguments.
274    /// Note that for method calls, the method receiver is the first argument.
275    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
276    /// second argument is `b`.
277    ///
278    /// INVARIANT: This map should always have exactly one k-v pair (not more or
279    /// less), but we make it a map in order to get the correct JSON structure
280    /// we want.
281    #[serde(flatten)]
282    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
283    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
284    call: HashMap<SmolStr, Vec<Expr>>,
285}
286
287#[allow(clippy::should_implement_trait)] // the names of arithmetic constructors alias with those of certain trait methods such as `add` of `std::ops::Add`
288impl Expr {
289    /// literal
290    pub fn lit(lit: CedarValueJson) -> Self {
291        Expr::ExprNoExt(ExprNoExt::Value(lit))
292    }
293
294    /// principal, action, resource, context
295    pub fn var(var: ast::Var) -> Self {
296        Expr::ExprNoExt(ExprNoExt::Var(var))
297    }
298
299    /// Template slots
300    pub fn slot(slot: ast::SlotId) -> Self {
301        Expr::ExprNoExt(ExprNoExt::Slot(slot))
302    }
303
304    /// Partial-evaluation unknowns
305    pub fn unknown(name: impl Into<SmolStr>) -> Self {
306        Expr::ExprNoExt(ExprNoExt::Unknown { name: name.into() })
307    }
308
309    /// `!`
310    pub fn not(e: Expr) -> Self {
311        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
312    }
313
314    /// `-`
315    pub fn neg(e: Expr) -> Self {
316        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
317    }
318
319    /// `==`
320    pub fn eq(left: Expr, right: Expr) -> Self {
321        Expr::ExprNoExt(ExprNoExt::Eq {
322            left: Arc::new(left),
323            right: Arc::new(right),
324        })
325    }
326
327    /// `!=`
328    pub fn noteq(left: Expr, right: Expr) -> Self {
329        Expr::ExprNoExt(ExprNoExt::NotEq {
330            left: Arc::new(left),
331            right: Arc::new(right),
332        })
333    }
334
335    /// `in`
336    pub fn _in(left: Expr, right: Expr) -> Self {
337        Expr::ExprNoExt(ExprNoExt::In {
338            left: Arc::new(left),
339            right: Arc::new(right),
340        })
341    }
342
343    /// `<`
344    pub fn less(left: Expr, right: Expr) -> Self {
345        Expr::ExprNoExt(ExprNoExt::Less {
346            left: Arc::new(left),
347            right: Arc::new(right),
348        })
349    }
350
351    /// `<=`
352    pub fn lesseq(left: Expr, right: Expr) -> Self {
353        Expr::ExprNoExt(ExprNoExt::LessEq {
354            left: Arc::new(left),
355            right: Arc::new(right),
356        })
357    }
358
359    /// `>`
360    pub fn greater(left: Expr, right: Expr) -> Self {
361        Expr::ExprNoExt(ExprNoExt::Greater {
362            left: Arc::new(left),
363            right: Arc::new(right),
364        })
365    }
366
367    /// `>=`
368    pub fn greatereq(left: Expr, right: Expr) -> Self {
369        Expr::ExprNoExt(ExprNoExt::GreaterEq {
370            left: Arc::new(left),
371            right: Arc::new(right),
372        })
373    }
374
375    /// `&&`
376    pub fn and(left: Expr, right: Expr) -> Self {
377        Expr::ExprNoExt(ExprNoExt::And {
378            left: Arc::new(left),
379            right: Arc::new(right),
380        })
381    }
382
383    /// `||`
384    pub fn or(left: Expr, right: Expr) -> Self {
385        Expr::ExprNoExt(ExprNoExt::Or {
386            left: Arc::new(left),
387            right: Arc::new(right),
388        })
389    }
390
391    /// `+`
392    pub fn add(left: Expr, right: Expr) -> Self {
393        Expr::ExprNoExt(ExprNoExt::Add {
394            left: Arc::new(left),
395            right: Arc::new(right),
396        })
397    }
398
399    /// `-`
400    pub fn sub(left: Expr, right: Expr) -> Self {
401        Expr::ExprNoExt(ExprNoExt::Sub {
402            left: Arc::new(left),
403            right: Arc::new(right),
404        })
405    }
406
407    /// `*`
408    pub fn mul(left: Expr, right: Expr) -> Self {
409        Expr::ExprNoExt(ExprNoExt::Mul {
410            left: Arc::new(left),
411            right: Arc::new(right),
412        })
413    }
414
415    /// `left.contains(right)`
416    pub fn contains(left: Arc<Expr>, right: Expr) -> Self {
417        Expr::ExprNoExt(ExprNoExt::Contains {
418            left,
419            right: Arc::new(right),
420        })
421    }
422
423    /// `left.containsAll(right)`
424    pub fn contains_all(left: Arc<Expr>, right: Expr) -> Self {
425        Expr::ExprNoExt(ExprNoExt::ContainsAll {
426            left,
427            right: Arc::new(right),
428        })
429    }
430
431    /// `left.containsAny(right)`
432    pub fn contains_any(left: Arc<Expr>, right: Expr) -> Self {
433        Expr::ExprNoExt(ExprNoExt::ContainsAny {
434            left,
435            right: Arc::new(right),
436        })
437    }
438
439    /// `left.attr`
440    pub fn get_attr(left: Expr, attr: SmolStr) -> Self {
441        Expr::ExprNoExt(ExprNoExt::GetAttr {
442            left: Arc::new(left),
443            attr,
444        })
445    }
446
447    /// `left has attr`
448    pub fn has_attr(left: Expr, attr: SmolStr) -> Self {
449        Expr::ExprNoExt(ExprNoExt::HasAttr {
450            left: Arc::new(left),
451            attr,
452        })
453    }
454
455    /// `left like pattern`
456    pub fn like(left: Expr, pattern: SmolStr) -> Self {
457        Expr::ExprNoExt(ExprNoExt::Like {
458            left: Arc::new(left),
459            pattern,
460        })
461    }
462
463    /// `left is entity_type`
464    pub fn is_entity_type(left: Expr, entity_type: SmolStr) -> Self {
465        Expr::ExprNoExt(ExprNoExt::Is {
466            left: Arc::new(left),
467            entity_type,
468            in_expr: None,
469        })
470    }
471
472    /// `left is entity_type in entity`
473    pub fn is_entity_type_in(left: Expr, entity_type: SmolStr, entity: Expr) -> Self {
474        Expr::ExprNoExt(ExprNoExt::Is {
475            left: Arc::new(left),
476            entity_type,
477            in_expr: Some(Arc::new(entity)),
478        })
479    }
480
481    /// `if cond_expr then then_expr else else_expr`
482    pub fn ite(cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
483        Expr::ExprNoExt(ExprNoExt::If {
484            cond_expr: Arc::new(cond_expr),
485            then_expr: Arc::new(then_expr),
486            else_expr: Arc::new(else_expr),
487        })
488    }
489
490    /// e.g. [1+2, !(context has department)]
491    pub fn set(elements: Vec<Expr>) -> Self {
492        Expr::ExprNoExt(ExprNoExt::Set(elements))
493    }
494
495    /// e.g. {foo: 1+2, bar: !(context has department)}
496    pub fn record(map: HashMap<SmolStr, Expr>) -> Self {
497        Expr::ExprNoExt(ExprNoExt::Record(map))
498    }
499
500    /// extension function call, including method calls
501    pub fn ext_call(fn_name: SmolStr, args: Vec<Expr>) -> Self {
502        Expr::ExtFuncCall(ExtFuncCall {
503            call: [(fn_name, args)].into_iter().collect(),
504        })
505    }
506
507    /// Consume the `Expr`, producing a string literal if it was a string literal, otherwise returns the literal in the `Err` variant.
508    pub fn into_string_literal(self) -> Result<SmolStr, Self> {
509        match self {
510            Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::String(s))) => Ok(s),
511            _ => Err(self),
512        }
513    }
514}
515
516impl Expr {
517    /// Attempt to convert this `est::Expr` into an `ast::Expr`
518    ///
519    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
520    pub fn try_into_ast(self, id: ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
521        match self {
522            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
523                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
524                .map(Into::into)
525                .map_err(Into::into),
526            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
527            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
528            Expr::ExprNoExt(ExprNoExt::Unknown { name }) => {
529                Ok(ast::Expr::unknown(ast::Unknown::new_untyped(name)))
530            }
531            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
532                Ok(ast::Expr::not((*arg).clone().try_into_ast(id)?))
533            }
534            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
535                Ok(ast::Expr::neg((*arg).clone().try_into_ast(id)?))
536            }
537            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
538                (*left).clone().try_into_ast(id.clone())?,
539                (*right).clone().try_into_ast(id)?,
540            )),
541            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
542                (*left).clone().try_into_ast(id.clone())?,
543                (*right).clone().try_into_ast(id)?,
544            )),
545            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
546                (*left).clone().try_into_ast(id.clone())?,
547                (*right).clone().try_into_ast(id)?,
548            )),
549            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
550                (*left).clone().try_into_ast(id.clone())?,
551                (*right).clone().try_into_ast(id)?,
552            )),
553            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
554                (*left).clone().try_into_ast(id.clone())?,
555                (*right).clone().try_into_ast(id)?,
556            )),
557            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
558                (*left).clone().try_into_ast(id.clone())?,
559                (*right).clone().try_into_ast(id)?,
560            )),
561            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
562                (*left).clone().try_into_ast(id.clone())?,
563                (*right).clone().try_into_ast(id)?,
564            )),
565            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
566                (*left).clone().try_into_ast(id.clone())?,
567                (*right).clone().try_into_ast(id)?,
568            )),
569            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
570                (*left).clone().try_into_ast(id.clone())?,
571                (*right).clone().try_into_ast(id)?,
572            )),
573            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
574                (*left).clone().try_into_ast(id.clone())?,
575                (*right).clone().try_into_ast(id)?,
576            )),
577            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
578                (*left).clone().try_into_ast(id.clone())?,
579                (*right).clone().try_into_ast(id)?,
580            )),
581            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
582                (*left).clone().try_into_ast(id.clone())?,
583                (*right).clone().try_into_ast(id)?,
584            )),
585            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
586                (*left).clone().try_into_ast(id.clone())?,
587                (*right).clone().try_into_ast(id)?,
588            )),
589            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
590                (*left).clone().try_into_ast(id.clone())?,
591                (*right).clone().try_into_ast(id)?,
592            )),
593            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
594                (*left).clone().try_into_ast(id.clone())?,
595                (*right).clone().try_into_ast(id)?,
596            )),
597            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => {
598                Ok(ast::Expr::get_attr((*left).clone().try_into_ast(id)?, attr))
599            }
600            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => {
601                Ok(ast::Expr::has_attr((*left).clone().try_into_ast(id)?, attr))
602            }
603            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => {
604                match unescape::to_pattern(&pattern) {
605                    Ok(pattern) => Ok(ast::Expr::like((*left).clone().try_into_ast(id)?, pattern)),
606                    Err(errs) => Err(FromJsonError::UnescapeError(errs.into())),
607                }
608            }
609            Expr::ExprNoExt(ExprNoExt::Is {
610                left,
611                entity_type,
612                in_expr,
613            }) => ast::Name::from_normalized_str(entity_type.as_str())
614                .map_err(FromJsonError::InvalidEntityType)
615                .and_then(|entity_type_name| {
616                    let left: ast::Expr = (*left).clone().try_into_ast(id.clone())?;
617                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
618                    match in_expr {
619                        // The AST doesn't have an `... is ... in ..` node, so
620                        // we represent it as a conjunction of `is` and `in`.
621                        Some(in_expr) => Ok(ast::Expr::and(
622                            is_expr,
623                            ast::Expr::is_in(left, (*in_expr).clone().try_into_ast(id)?),
624                        )),
625                        None => Ok(is_expr),
626                    }
627                }),
628            Expr::ExprNoExt(ExprNoExt::If {
629                cond_expr,
630                then_expr,
631                else_expr,
632            }) => Ok(ast::Expr::ite(
633                (*cond_expr).clone().try_into_ast(id.clone())?,
634                (*then_expr).clone().try_into_ast(id.clone())?,
635                (*else_expr).clone().try_into_ast(id)?,
636            )),
637            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
638                elements
639                    .into_iter()
640                    .map(|el| el.try_into_ast(id.clone()))
641                    .collect::<Result<Vec<_>, FromJsonError>>()?,
642            )),
643            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
644                // PANIC SAFETY: can't have duplicate keys here because the input was already a HashMap
645                #[allow(clippy::expect_used)]
646                Ok(ast::Expr::record(
647                    map.into_iter()
648                        .map(|(k, v)| Ok((k, v.try_into_ast(id.clone())?)))
649                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
650                )
651                .expect("can't have duplicate keys here because the input was already a HashMap"))
652            }
653            Expr::ExtFuncCall(ExtFuncCall { call }) => {
654                match call.len() {
655                    0 => Err(FromJsonError::MissingOperator),
656                    1 => {
657                        // PANIC SAFETY checked that `call.len() == 1`
658                        #[allow(clippy::expect_used)]
659                        let (fn_name, args) = call
660                            .into_iter()
661                            .next()
662                            .expect("already checked that len was 1");
663                        let fn_name = fn_name.parse().map_err(|errs| {
664                            JsonDeserializationError::ParseEscape {
665                                kind: EscapeKind::Extension,
666                                value: fn_name.to_string(),
667                                errs,
668                            }
669                        })?;
670                        Ok(ast::Expr::call_extension_fn(
671                            fn_name,
672                            args.into_iter()
673                                .map(|arg| arg.try_into_ast(id.clone()))
674                                .collect::<Result<_, _>>()?,
675                        ))
676                    }
677                    _ => Err(FromJsonError::MultipleOperators {
678                        ops: call.into_keys().collect(),
679                    }),
680                }
681            }
682        }
683    }
684}
685
686impl From<ast::Expr> for Expr {
687    fn from(expr: ast::Expr) -> Expr {
688        match expr.into_expr_kind() {
689            ast::ExprKind::Lit(lit) => lit.into(),
690            ast::ExprKind::Var(var) => var.into(),
691            ast::ExprKind::Slot(slot) => slot.into(),
692            ast::ExprKind::Unknown(ast::Unknown { name, .. }) => Expr::unknown(name),
693            ast::ExprKind::If {
694                test_expr,
695                then_expr,
696                else_expr,
697            } => Expr::ite(
698                Arc::unwrap_or_clone(test_expr).into(),
699                Arc::unwrap_or_clone(then_expr).into(),
700                Arc::unwrap_or_clone(else_expr).into(),
701            ),
702            ast::ExprKind::And { left, right } => Expr::and(
703                Arc::unwrap_or_clone(left).into(),
704                Arc::unwrap_or_clone(right).into(),
705            ),
706            ast::ExprKind::Or { left, right } => Expr::or(
707                Arc::unwrap_or_clone(left).into(),
708                Arc::unwrap_or_clone(right).into(),
709            ),
710            ast::ExprKind::UnaryApp { op, arg } => {
711                let arg = Arc::unwrap_or_clone(arg).into();
712                match op {
713                    ast::UnaryOp::Not => Expr::not(arg),
714                    ast::UnaryOp::Neg => Expr::neg(arg),
715                }
716            }
717            ast::ExprKind::BinaryApp { op, arg1, arg2 } => {
718                let arg1 = Arc::unwrap_or_clone(arg1).into();
719                let arg2 = Arc::unwrap_or_clone(arg2).into();
720                match op {
721                    ast::BinaryOp::Eq => Expr::eq(arg1, arg2),
722                    ast::BinaryOp::In => Expr::_in(arg1, arg2),
723                    ast::BinaryOp::Less => Expr::less(arg1, arg2),
724                    ast::BinaryOp::LessEq => Expr::lesseq(arg1, arg2),
725                    ast::BinaryOp::Add => Expr::add(arg1, arg2),
726                    ast::BinaryOp::Sub => Expr::sub(arg1, arg2),
727                    ast::BinaryOp::Mul => Expr::mul(arg1, arg2),
728                    ast::BinaryOp::Contains => Expr::contains(Arc::new(arg1), arg2),
729                    ast::BinaryOp::ContainsAll => Expr::contains_all(Arc::new(arg1), arg2),
730                    ast::BinaryOp::ContainsAny => Expr::contains_any(Arc::new(arg1), arg2),
731                }
732            }
733            ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
734                let args = Arc::unwrap_or_clone(args)
735                    .into_iter()
736                    .map(Into::into)
737                    .collect();
738                Expr::ext_call(fn_name.to_string().into(), args)
739            }
740            ast::ExprKind::GetAttr { expr, attr } => {
741                Expr::get_attr(Arc::unwrap_or_clone(expr).into(), attr)
742            }
743            ast::ExprKind::HasAttr { expr, attr } => {
744                Expr::has_attr(Arc::unwrap_or_clone(expr).into(), attr)
745            }
746            ast::ExprKind::Like { expr, pattern } => Expr::like(
747                Arc::unwrap_or_clone(expr).into(),
748                pattern.to_string().into(),
749            ),
750            ast::ExprKind::Is { expr, entity_type } => Expr::is_entity_type(
751                Arc::unwrap_or_clone(expr).into(),
752                entity_type.to_string().into(),
753            ),
754            ast::ExprKind::Set(set) => Expr::set(
755                Arc::unwrap_or_clone(set)
756                    .into_iter()
757                    .map(Into::into)
758                    .collect(),
759            ),
760            ast::ExprKind::Record(map) => Expr::record(
761                Arc::unwrap_or_clone(map)
762                    .into_iter()
763                    .map(|(k, v)| (k, v.into()))
764                    .collect(),
765            ),
766        }
767    }
768}
769
770impl From<ast::Literal> for Expr {
771    fn from(lit: ast::Literal) -> Expr {
772        Expr::lit(CedarValueJson::from_lit(lit))
773    }
774}
775
776impl From<ast::Var> for Expr {
777    fn from(var: ast::Var) -> Expr {
778        Expr::var(var)
779    }
780}
781
782impl From<ast::SlotId> for Expr {
783    fn from(slot: ast::SlotId) -> Expr {
784        Expr::slot(slot)
785    }
786}
787
788impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
789    type Error = ParseErrors;
790    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
791        match &*e.ok_or_missing()?.expr {
792            cst::ExprData::Or(node) => node.try_into(),
793            cst::ExprData::If(if_node, then_node, else_node) => {
794                let cond_expr = if_node.try_into()?;
795                let then_expr = then_node.try_into()?;
796                let else_expr = else_node.try_into()?;
797                Ok(Expr::ite(cond_expr, then_expr, else_expr))
798            }
799        }
800    }
801}
802
803impl TryFrom<&Node<Option<cst::Or>>> for Expr {
804    type Error = ParseErrors;
805    fn try_from(o: &Node<Option<cst::Or>>) -> Result<Expr, ParseErrors> {
806        let o_node = o.ok_or_missing()?;
807        let mut expr = (&o_node.initial).try_into()?;
808        for node in &o_node.extended {
809            let rhs = node.try_into()?;
810            expr = Expr::or(expr, rhs);
811        }
812        Ok(expr)
813    }
814}
815
816impl TryFrom<&Node<Option<cst::And>>> for Expr {
817    type Error = ParseErrors;
818    fn try_from(a: &Node<Option<cst::And>>) -> Result<Expr, ParseErrors> {
819        let a_node = a.ok_or_missing()?;
820        let mut expr = (&a_node.initial).try_into()?;
821        for node in &a_node.extended {
822            let rhs = node.try_into()?;
823            expr = Expr::and(expr, rhs);
824        }
825        Ok(expr)
826    }
827}
828
829impl TryFrom<&Node<Option<cst::Relation>>> for Expr {
830    type Error = ParseErrors;
831    fn try_from(r: &Node<Option<cst::Relation>>) -> Result<Expr, ParseErrors> {
832        match r.ok_or_missing()? {
833            cst::Relation::Common { initial, extended } => {
834                let mut expr = initial.try_into()?;
835                for (op, node) in extended {
836                    let rhs = node.try_into()?;
837                    match op {
838                        cst::RelOp::Eq => {
839                            expr = Expr::eq(expr, rhs);
840                        }
841                        cst::RelOp::NotEq => {
842                            expr = Expr::noteq(expr, rhs);
843                        }
844                        cst::RelOp::In => {
845                            expr = Expr::_in(expr, rhs);
846                        }
847                        cst::RelOp::Less => {
848                            expr = Expr::less(expr, rhs);
849                        }
850                        cst::RelOp::LessEq => {
851                            expr = Expr::lesseq(expr, rhs);
852                        }
853                        cst::RelOp::Greater => {
854                            expr = Expr::greater(expr, rhs);
855                        }
856                        cst::RelOp::GreaterEq => {
857                            expr = Expr::greatereq(expr, rhs);
858                        }
859                        cst::RelOp::InvalidSingleEq => {
860                            return Err(ToASTError::new(
861                                ToASTErrorKind::InvalidSingleEq,
862                                r.loc.clone(),
863                            )
864                            .into());
865                        }
866                    }
867                }
868                Ok(expr)
869            }
870            cst::Relation::Has { target, field } => {
871                let target_expr = target.try_into()?;
872                let mut errs = ParseErrors::new();
873                if let Some(field_expr) = field.to_expr_or_special(&mut errs) {
874                    if let Some(attr) = field_expr.into_valid_attr(&mut errs) {
875                        return Ok(Expr::has_attr(target_expr, attr));
876                    }
877                }
878                if errs.is_empty() {
879                    Err(field
880                        .to_ast_err(ToASTErrorKind::InvalidAttribute(
881                            field.ok_or_missing()?.to_smolstr(),
882                        ))
883                        .into())
884                } else {
885                    Err(errs)
886                }
887            }
888            cst::Relation::Like { target, pattern } => {
889                let target_expr = target.try_into()?;
890                let pat_expr: Expr = pattern.try_into()?;
891                let pat_str = pat_expr.into_string_literal().map_err(|e| {
892                    pattern.to_ast_err(ToASTErrorKind::InvalidPattern(
893                        serde_json::to_string(&e).unwrap_or_else(|_| "<malformed est>".to_string()),
894                    ))
895                })?;
896                Ok(Expr::like(target_expr, pat_str))
897            }
898            cst::Relation::IsIn {
899                target,
900                entity_type,
901                in_entity,
902            } => {
903                let target = target.try_into()?;
904                let type_str = entity_type.ok_or_missing()?.to_string().into();
905                match in_entity {
906                    Some(in_entity) => Ok(Expr::is_entity_type_in(
907                        target,
908                        type_str,
909                        in_entity.try_into()?,
910                    )),
911                    None => Ok(Expr::is_entity_type(target, type_str)),
912                }
913            }
914        }
915    }
916}
917
918impl TryFrom<&Node<Option<cst::Add>>> for Expr {
919    type Error = ParseErrors;
920    fn try_from(a: &Node<Option<cst::Add>>) -> Result<Expr, ParseErrors> {
921        let a_node = a.ok_or_missing()?;
922        let mut expr = (&a_node.initial).try_into()?;
923        for (op, node) in &a_node.extended {
924            let rhs = node.try_into()?;
925            match op {
926                cst::AddOp::Plus => {
927                    expr = Expr::add(expr, rhs);
928                }
929                cst::AddOp::Minus => {
930                    expr = Expr::sub(expr, rhs);
931                }
932            }
933        }
934        Ok(expr)
935    }
936}
937
938impl TryFrom<&Node<Option<cst::Mult>>> for Expr {
939    type Error = ParseErrors;
940    fn try_from(m: &Node<Option<cst::Mult>>) -> Result<Expr, ParseErrors> {
941        let m_node = m.ok_or_missing()?;
942        let mut expr = (&m_node.initial).try_into()?;
943        for (op, node) in &m_node.extended {
944            let rhs = node.try_into()?;
945            match op {
946                cst::MultOp::Times => {
947                    expr = Expr::mul(expr, rhs);
948                }
949                cst::MultOp::Divide => {
950                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedDivision).into())
951                }
952                cst::MultOp::Mod => {
953                    return Err(node.to_ast_err(ToASTErrorKind::UnsupportedModulo).into())
954                }
955            }
956        }
957        Ok(expr)
958    }
959}
960
961impl TryFrom<&Node<Option<cst::Unary>>> for Expr {
962    type Error = ParseErrors;
963    fn try_from(u: &Node<Option<cst::Unary>>) -> Result<Expr, ParseErrors> {
964        let u_node = u.ok_or_missing()?;
965
966        match u_node.op {
967            Some(cst::NegOp::Bang(num_bangs)) => {
968                let inner = (&u_node.item).try_into()?;
969                match num_bangs {
970                    0 => Ok(inner),
971                    1 => Ok(Expr::not(inner)),
972                    2 => Ok(Expr::not(Expr::not(inner))),
973                    3 => Ok(Expr::not(Expr::not(Expr::not(inner)))),
974                    4 => Ok(Expr::not(Expr::not(Expr::not(Expr::not(inner))))),
975                    _ => Err(u
976                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
977                        .into()),
978                }
979            }
980            Some(cst::NegOp::Dash(0)) => Ok((&u_node.item).try_into()?),
981            Some(cst::NegOp::Dash(mut num_dashes)) => {
982                let inner = match &u_node.item.to_lit() {
983                    Some(cst::Literal::Num(num))
984                        if num
985                            .checked_sub(1)
986                            .map(|y| y == InputInteger::MAX as u64)
987                            .unwrap_or(false) =>
988                    {
989                        num_dashes -= 1;
990                        Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(InputInteger::MIN)))
991                    }
992                    _ => (&u_node.item).try_into()?,
993                };
994                let inner = match inner {
995                    Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::Long(n)))
996                        if n != InputInteger::MIN =>
997                    {
998                        // collapse the negated literal into a single negative literal.
999                        // Important for multiplication-by-constant to allow multiplication by negative constants.
1000                        num_dashes -= 1;
1001                        Expr::lit(CedarValueJson::Long(-n))
1002                    }
1003                    _ => inner,
1004                };
1005                match num_dashes {
1006                    0 => Ok(inner),
1007                    1 => Ok(Expr::neg(inner)),
1008                    2 => {
1009                        // not safe to collapse `--` to nothing
1010                        Ok(Expr::neg(Expr::neg(inner)))
1011                    }
1012                    3 => Ok(Expr::neg(Expr::neg(Expr::neg(inner)))),
1013                    4 => Ok(Expr::neg(Expr::neg(Expr::neg(Expr::neg(inner))))),
1014                    _ => Err(u
1015                        .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1016                        .into()),
1017                }
1018            }
1019            Some(cst::NegOp::OverBang) => Err(u
1020                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Not))
1021                .into()),
1022            Some(cst::NegOp::OverDash) => Err(u
1023                .to_ast_err(ToASTErrorKind::UnaryOpLimit(ast::UnaryOp::Neg))
1024                .into()),
1025            None => Ok((&u_node.item).try_into()?),
1026        }
1027    }
1028}
1029
1030/// Convert the given `cst::Primary` into either a (possibly namespaced)
1031/// function name, or an `Expr`.
1032///
1033/// (Upstream, the case where the `Primary` is a function name needs special
1034/// handling, because in that case it is not a valid expression. In all other
1035/// cases a `Primary` can be converted into an `Expr`.)
1036fn interpret_primary(
1037    p: &Node<Option<cst::Primary>>,
1038) -> Result<Either<ast::Name, Expr>, ParseErrors> {
1039    match p.ok_or_missing()? {
1040        cst::Primary::Literal(lit) => Ok(Either::Right(lit.try_into()?)),
1041        cst::Primary::Ref(node) => match node.ok_or_missing()? {
1042            cst::Ref::Uid { path, eid } => {
1043                let mut errs = ParseErrors::new();
1044                let maybe_name = path.to_name(&mut errs);
1045                let maybe_eid = eid.as_valid_string(&mut errs);
1046
1047                match (maybe_name, maybe_eid) {
1048                    (Some(name), Some(eid)) => {
1049                        Ok(Either::Right(Expr::lit(CedarValueJson::EntityEscape {
1050                            __entity: TypeAndId::from(ast::EntityUID::from_components(
1051                                name,
1052                                ast::Eid::new(eid.clone()),
1053                                None,
1054                            )),
1055                        })))
1056                    }
1057                    _ => Err(errs),
1058                }
1059            }
1060            cst::Ref::Ref { .. } => Err(node
1061                .to_ast_err(ToASTErrorKind::UnsupportedEntityLiterals)
1062                .into()),
1063        },
1064        cst::Primary::Name(node) => {
1065            let name = node.ok_or_missing()?;
1066            let base_name = name.name.ok_or_missing()?;
1067            match (&name.path[..], base_name) {
1068                (&[], cst::Ident::Principal) => Ok(Either::Right(Expr::var(ast::Var::Principal))),
1069                (&[], cst::Ident::Action) => Ok(Either::Right(Expr::var(ast::Var::Action))),
1070                (&[], cst::Ident::Resource) => Ok(Either::Right(Expr::var(ast::Var::Resource))),
1071                (&[], cst::Ident::Context) => Ok(Either::Right(Expr::var(ast::Var::Context))),
1072                (path, cst::Ident::Ident(id)) => Ok(Either::Left(ast::Name::new(
1073                    id.parse()?,
1074                    path.iter()
1075                        .map(|node| {
1076                            node.ok_or_missing()
1077                                .map_err(Into::into)
1078                                .and_then(|id| id.to_string().parse().map_err(Into::into))
1079                        })
1080                        .collect::<Result<Vec<ast::Id>, ParseErrors>>()?,
1081                    Some(node.loc.clone()),
1082                ))),
1083                (path, id) => {
1084                    let (l, r, src) = match (path.first(), path.last()) {
1085                        (Some(l), Some(r)) => (
1086                            l.loc.start(),
1087                            r.loc.end() + ident_to_str_len(id),
1088                            Arc::clone(&l.loc.src),
1089                        ),
1090                        (_, _) => (0, 0, Arc::from("")),
1091                    };
1092                    Err(ToASTError::new(
1093                        ToASTErrorKind::InvalidExpression(cst::Name {
1094                            path: path.to_vec(),
1095                            name: Node::with_source_loc(Some(id.clone()), node.loc.span(l..r)),
1096                        }),
1097                        Loc::new(l..r, src),
1098                    )
1099                    .into())
1100                }
1101            }
1102        }
1103        cst::Primary::Slot(node) => Ok(Either::Right(Expr::slot(
1104            node.ok_or_missing()?
1105                .try_into()
1106                .map_err(|e| node.to_ast_err(e))?,
1107        ))),
1108        cst::Primary::Expr(e) => Ok(Either::Right(e.try_into()?)),
1109        cst::Primary::EList(nodes) => nodes
1110            .iter()
1111            .map(|node| node.try_into())
1112            .collect::<Result<Vec<Expr>, _>>()
1113            .map(Expr::set)
1114            .map(Either::Right),
1115        cst::Primary::RInits(nodes) => nodes
1116            .iter()
1117            .map(|node| {
1118                let cst::RecInit(k, v) = node.ok_or_missing()?;
1119                let mut errs = ParseErrors::new();
1120                let s = k
1121                    .to_expr_or_special(&mut errs)
1122                    .and_then(|es| es.into_valid_attr(&mut errs));
1123                if !errs.is_empty() {
1124                    Err(errs)
1125                } else {
1126                    match s {
1127                        Some(s) => Ok((s, v.try_into()?)),
1128                        None => Err(node.to_ast_err(ToASTErrorKind::MissingNodeData).into()),
1129                    }
1130                }
1131            })
1132            .collect::<Result<HashMap<SmolStr, Expr>, ParseErrors>>()
1133            .map(Expr::record)
1134            .map(Either::Right),
1135    }
1136}
1137
1138impl TryFrom<&Node<Option<cst::Member>>> for Expr {
1139    type Error = ParseErrors;
1140    fn try_from(m: &Node<Option<cst::Member>>) -> Result<Expr, ParseErrors> {
1141        let m_node = m.ok_or_missing()?;
1142        let mut item: Either<ast::Name, Expr> = interpret_primary(&m_node.item)?;
1143        for access in &m_node.access {
1144            match access.ok_or_missing()? {
1145                cst::MemAccess::Field(node) => {
1146                    let mut errs = ParseErrors::new();
1147                    let field = node.to_valid_ident(&mut errs);
1148                    // rule out invalid identifiers (`Ident::Invalid` and reserved Ids)
1149                    if !errs.is_empty() {
1150                        return Err(errs);
1151                    }
1152                    match field {
1153                        Some(id) => {
1154                            item = match item {
1155                                Either::Left(name) => {
1156                                    return Err(node
1157                                        .to_ast_err(ToASTErrorKind::InvalidAccess(
1158                                            name,
1159                                            id.to_smolstr(),
1160                                        ))
1161                                        .into())
1162                                }
1163                                Either::Right(expr) => {
1164                                    Either::Right(Expr::get_attr(expr, id.to_smolstr()))
1165                                }
1166                            };
1167                        }
1168                        None => {
1169                            return Err(node
1170                                .to_ast_err(ToASTErrorKind::InvalidIdentifier(
1171                                    node.ok_or_missing()?.to_string(),
1172                                ))
1173                                .into())
1174                        }
1175                    }
1176                }
1177                cst::MemAccess::Call(args) => {
1178                    // we have item(args).  We hope item is either:
1179                    //   - an `ast::Name`, in which case we have a standard function call
1180                    //   - or an expr of the form `x.y`, in which case y is the method
1181                    //      name and not a field name. In the previous iteration of the
1182                    //      `for` loop we would have made `item` equal to
1183                    //      `Expr::GetAttr(x, y)`. Now we have to undo that to make a
1184                    //      method call instead.
1185                    //   - any other expression: it's an illegal call as the target is a higher order expression
1186                    item = match item {
1187                        Either::Left(name) => Either::Right(Expr::ext_call(
1188                            name.to_string().into(),
1189                            args.iter()
1190                                .map(|node| node.try_into())
1191                                .collect::<Result<Vec<_>, _>>()?,
1192                        )),
1193                        Either::Right(Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr })) => {
1194                            let args = args.iter().map(|node| node.try_into()).collect::<Result<
1195                                Vec<Expr>,
1196                                ParseErrors,
1197                            >>(
1198                            )?;
1199                            let args = args.into_iter();
1200                            match attr.as_str() {
1201                                "contains" => Either::Right(Expr::contains(
1202                                    left,
1203                                    extract_single_argument(args, "contains()", &access.loc)?,
1204                                )),
1205                                "containsAll" => Either::Right(Expr::contains_all(
1206                                    left,
1207                                    extract_single_argument(args, "containsAll()", &access.loc)?,
1208                                )),
1209                                "containsAny" => Either::Right(Expr::contains_any(
1210                                    left,
1211                                    extract_single_argument(args, "containsAny()", &access.loc)?,
1212                                )),
1213                                _ => {
1214                                    // have to add the "receiver" argument as
1215                                    // first in the list for the method call
1216                                    let mut args = args.collect::<Vec<_>>();
1217                                    args.insert(0, Arc::unwrap_or_clone(left));
1218                                    Either::Right(Expr::ext_call(attr, args))
1219                                }
1220                            }
1221                        }
1222                        _ => return Err(access.to_ast_err(ToASTErrorKind::ExpressionCall).into()),
1223                    };
1224                }
1225                cst::MemAccess::Index(node) => {
1226                    let s = Expr::try_from(node)?
1227                        .into_string_literal()
1228                        .map_err(|_| node.to_ast_err(ToASTErrorKind::NonStringIndex))?;
1229                    item = match item {
1230                        Either::Left(name) => {
1231                            return Err(node
1232                                .to_ast_err(ToASTErrorKind::InvalidIndex(name, s))
1233                                .into())
1234                        }
1235                        Either::Right(expr) => Either::Right(Expr::get_attr(expr, s)),
1236                    };
1237                }
1238            }
1239        }
1240        match item {
1241            Either::Left(_) => Err(m.to_ast_err(ToASTErrorKind::MembershipInvariantViolation))?,
1242            Either::Right(expr) => Ok(expr),
1243        }
1244    }
1245}
1246
1247/// Return the single argument in `args` iterator, or return a wrong arity error
1248/// if the iterator has 0 elements or more than 1 element.
1249pub fn extract_single_argument<T>(
1250    args: impl ExactSizeIterator<Item = T>,
1251    fn_name: &'static str,
1252    loc: &Loc,
1253) -> Result<T, ToASTError> {
1254    let mut iter = args.fuse().peekable();
1255    let first = iter.next();
1256    let second = iter.peek();
1257    match (first, second) {
1258        (None, _) => Err(ToASTError::new(
1259            ToASTErrorKind::wrong_arity(fn_name, 1, 0),
1260            loc.clone(),
1261        )),
1262        (Some(_), Some(_)) => Err(ToASTError::new(
1263            ToASTErrorKind::wrong_arity(fn_name, 1, iter.len() + 1),
1264            loc.clone(),
1265        )),
1266        (Some(first), None) => Ok(first),
1267    }
1268}
1269
1270impl TryFrom<&Node<Option<cst::Literal>>> for Expr {
1271    type Error = ParseErrors;
1272    fn try_from(lit: &Node<Option<cst::Literal>>) -> Result<Expr, ParseErrors> {
1273        match lit.ok_or_missing()? {
1274            cst::Literal::True => Ok(Expr::lit(CedarValueJson::Bool(true))),
1275            cst::Literal::False => Ok(Expr::lit(CedarValueJson::Bool(false))),
1276            cst::Literal::Num(n) => Ok(Expr::lit(CedarValueJson::Long(
1277                (*n).try_into()
1278                    .map_err(|_| lit.to_ast_err(ToASTErrorKind::IntegerLiteralTooLarge(*n)))?,
1279            ))),
1280            cst::Literal::Str(node) => match node.ok_or_missing()? {
1281                cst::Str::String(s) => Ok(Expr::lit(CedarValueJson::String(s.clone()))),
1282                cst::Str::Invalid(invalid_str) => Err(node
1283                    .to_ast_err(ToASTErrorKind::InvalidString(invalid_str.to_string()))
1284                    .into()),
1285            },
1286        }
1287    }
1288}
1289
1290impl TryFrom<&Node<Option<cst::Name>>> for Expr {
1291    type Error = ParseErrors;
1292    fn try_from(name: &Node<Option<cst::Name>>) -> Result<Expr, ParseErrors> {
1293        let name_node = name.ok_or_missing()?;
1294        let base_name = name_node.name.ok_or_missing()?;
1295        match (&name_node.path[..], base_name) {
1296            (&[], cst::Ident::Principal) => Ok(Expr::var(ast::Var::Principal)),
1297            (&[], cst::Ident::Action) => Ok(Expr::var(ast::Var::Action)),
1298            (&[], cst::Ident::Resource) => Ok(Expr::var(ast::Var::Resource)),
1299            (&[], cst::Ident::Context) => Ok(Expr::var(ast::Var::Context)),
1300            (path, id) => {
1301                let loc = match (path.first(), path.last()) {
1302                    (Some(lnode), Some(rnode)) => {
1303                        let l = lnode.loc.start();
1304                        let r = rnode.loc.end() + ident_to_str_len(id);
1305                        Loc::new(l..r, Arc::clone(&lnode.loc.src))
1306                    }
1307                    (_, _) => Loc::new(0, Arc::from("")),
1308                };
1309                Err(name
1310                    .to_ast_err(ToASTErrorKind::InvalidExpression(cst::Name {
1311                        path: path.to_vec(),
1312                        name: Node::with_source_loc(Some(id.clone()), loc),
1313                    }))
1314                    .into())
1315            }
1316        }
1317    }
1318}
1319
1320/// Get the string length of an `Ident`. Used to print the source location for error messages
1321fn ident_to_str_len(i: &Ident) -> usize {
1322    match i {
1323        Ident::Principal => 9,
1324        Ident::Action => 6,
1325        Ident::Resource => 8,
1326        Ident::Context => 7,
1327        Ident::True => 4,
1328        Ident::False => 5,
1329        Ident::Permit => 6,
1330        Ident::Forbid => 6,
1331        Ident::When => 4,
1332        Ident::Unless => 6,
1333        Ident::In => 2,
1334        Ident::Has => 3,
1335        Ident::Like => 4,
1336        Ident::If => 2,
1337        Ident::Then => 4,
1338        Ident::Else => 4,
1339        Ident::Ident(s) => s.len(),
1340        Ident::Invalid(s) => s.len(),
1341        Ident::Is => 2,
1342    }
1343}
1344
1345impl std::fmt::Display for Expr {
1346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1347        match self {
1348            Self::ExprNoExt(e) => write!(f, "{e}"),
1349            Self::ExtFuncCall(e) => write!(f, "{e}"),
1350        }
1351    }
1352}
1353
1354fn display_cedarvaluejson(f: &mut std::fmt::Formatter<'_>, v: &CedarValueJson) -> std::fmt::Result {
1355    match v {
1356        // Add parentheses around negative numeric literals otherwise
1357        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1358        CedarValueJson::Long(n) if *n < 0 => write!(f, "({n})"),
1359        CedarValueJson::Long(n) => write!(f, "{n}"),
1360        CedarValueJson::Bool(b) => write!(f, "{b}"),
1361        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1362        CedarValueJson::EntityEscape { __entity } => {
1363            match ast::EntityUID::try_from(__entity.clone()) {
1364                Ok(euid) => write!(f, "{euid}"),
1365                Err(e) => write!(f, "(invalid entity uid: {})", e),
1366            }
1367        }
1368        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1369        CedarValueJson::ExtnEscape {
1370            __extn: FnAndArg { ext_fn, arg },
1371        } => {
1372            // search for the name and callstyle
1373            let style = Extensions::all_available().all_funcs().find_map(|f| {
1374                if &f.name().to_string() == ext_fn {
1375                    Some(f.style())
1376                } else {
1377                    None
1378                }
1379            });
1380            match style {
1381                Some(ast::CallStyle::MethodStyle) => {
1382                    display_cedarvaluejson(f, arg)?;
1383                    write!(f, ".{ext_fn}()")?;
1384                    Ok(())
1385                }
1386                Some(ast::CallStyle::FunctionStyle) | None => {
1387                    write!(f, "{ext_fn}(")?;
1388                    display_cedarvaluejson(f, arg)?;
1389                    write!(f, ")")?;
1390                    Ok(())
1391                }
1392            }
1393        }
1394        CedarValueJson::Set(v) => {
1395            write!(f, "[")?;
1396            for (i, val) in v.iter().enumerate() {
1397                display_cedarvaluejson(f, val)?;
1398                if i < (v.len() - 1) {
1399                    write!(f, ", ")?;
1400                }
1401            }
1402            write!(f, "]")?;
1403            Ok(())
1404        }
1405        CedarValueJson::Record(m) => {
1406            write!(f, "{{")?;
1407            for (i, (k, v)) in m.iter().enumerate() {
1408                write!(f, "\"{}\": ", k.escape_debug())?;
1409                display_cedarvaluejson(f, v)?;
1410                if i < (m.len() - 1) {
1411                    write!(f, ", ")?;
1412                }
1413            }
1414            write!(f, "}}")?;
1415            Ok(())
1416        }
1417        CedarValueJson::Null => {
1418            write!(f, "null")?;
1419            Ok(())
1420        }
1421    }
1422}
1423
1424impl std::fmt::Display for ExprNoExt {
1425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1426        match &self {
1427            ExprNoExt::Value(v) => display_cedarvaluejson(f, v),
1428            ExprNoExt::Var(v) => write!(f, "{v}"),
1429            ExprNoExt::Slot(id) => write!(f, "{id}"),
1430            ExprNoExt::Unknown { name } => write!(f, "unknown(\"{}\")", name.escape_debug()),
1431            ExprNoExt::Not { arg } => write!(f, "!{}", maybe_with_parens(arg)),
1432            ExprNoExt::Neg { arg } => {
1433                // Always add parentheses instead of calling
1434                // `maybe_with_parens`.
1435                // This makes sure that we always get a negation operation back
1436                // (as opposed to e.g., a negative number) when parsing the
1437                // printed form, thus preserving the round-tripping property.
1438                write!(f, "-({arg})")
1439            }
1440            ExprNoExt::Eq { left, right } => write!(
1441                f,
1442                "{} == {}",
1443                maybe_with_parens(left),
1444                maybe_with_parens(right)
1445            ),
1446            ExprNoExt::NotEq { left, right } => write!(
1447                f,
1448                "{} != {}",
1449                maybe_with_parens(left),
1450                maybe_with_parens(right)
1451            ),
1452            ExprNoExt::In { left, right } => write!(
1453                f,
1454                "{} in {}",
1455                maybe_with_parens(left),
1456                maybe_with_parens(right)
1457            ),
1458            ExprNoExt::Less { left, right } => write!(
1459                f,
1460                "{} < {}",
1461                maybe_with_parens(left),
1462                maybe_with_parens(right)
1463            ),
1464            ExprNoExt::LessEq { left, right } => write!(
1465                f,
1466                "{} <= {}",
1467                maybe_with_parens(left),
1468                maybe_with_parens(right)
1469            ),
1470            ExprNoExt::Greater { left, right } => write!(
1471                f,
1472                "{} > {}",
1473                maybe_with_parens(left),
1474                maybe_with_parens(right)
1475            ),
1476            ExprNoExt::GreaterEq { left, right } => write!(
1477                f,
1478                "{} >= {}",
1479                maybe_with_parens(left),
1480                maybe_with_parens(right)
1481            ),
1482            ExprNoExt::And { left, right } => write!(
1483                f,
1484                "{} && {}",
1485                maybe_with_parens(left),
1486                maybe_with_parens(right)
1487            ),
1488            ExprNoExt::Or { left, right } => write!(
1489                f,
1490                "{} || {}",
1491                maybe_with_parens(left),
1492                maybe_with_parens(right)
1493            ),
1494            ExprNoExt::Add { left, right } => write!(
1495                f,
1496                "{} + {}",
1497                maybe_with_parens(left),
1498                maybe_with_parens(right)
1499            ),
1500            ExprNoExt::Sub { left, right } => write!(
1501                f,
1502                "{} - {}",
1503                maybe_with_parens(left),
1504                maybe_with_parens(right)
1505            ),
1506            ExprNoExt::Mul { left, right } => write!(
1507                f,
1508                "{} * {}",
1509                maybe_with_parens(left),
1510                maybe_with_parens(right)
1511            ),
1512            ExprNoExt::Contains { left, right } => {
1513                write!(f, "{}.contains({right})", maybe_with_parens(left))
1514            }
1515            ExprNoExt::ContainsAll { left, right } => {
1516                write!(f, "{}.containsAll({right})", maybe_with_parens(left))
1517            }
1518            ExprNoExt::ContainsAny { left, right } => {
1519                write!(f, "{}.containsAny({right})", maybe_with_parens(left))
1520            }
1521            ExprNoExt::GetAttr { left, attr } => write!(
1522                f,
1523                "{}[\"{}\"]",
1524                maybe_with_parens(left),
1525                attr.escape_debug()
1526            ),
1527            ExprNoExt::HasAttr { left, attr } => write!(
1528                f,
1529                "{} has \"{}\"",
1530                maybe_with_parens(left),
1531                attr.escape_debug()
1532            ),
1533            ExprNoExt::Like { left, pattern } => {
1534                write!(f, "{} like \"{}\"", maybe_with_parens(left), pattern) // intentionally not using .escape_debug() for pattern
1535            }
1536            ExprNoExt::Is {
1537                left,
1538                entity_type,
1539                in_expr,
1540            } => {
1541                write!(f, "{} is {}", maybe_with_parens(left), entity_type)?;
1542                match in_expr {
1543                    Some(in_expr) => write!(f, " in {}", maybe_with_parens(in_expr)),
1544                    None => Ok(()),
1545                }
1546            }
1547            ExprNoExt::If {
1548                cond_expr,
1549                then_expr,
1550                else_expr,
1551            } => write!(
1552                f,
1553                "if {} then {} else {}",
1554                maybe_with_parens(cond_expr),
1555                maybe_with_parens(then_expr),
1556                maybe_with_parens(else_expr)
1557            ),
1558            ExprNoExt::Set(v) => write!(f, "[{}]", v.iter().join(", ")),
1559            ExprNoExt::Record(m) => write!(
1560                f,
1561                "{{{}}}",
1562                m.iter()
1563                    .map(|(k, v)| format!("\"{}\": {}", k.escape_debug(), v))
1564                    .join(", ")
1565            ),
1566        }
1567    }
1568}
1569
1570impl std::fmt::Display for ExtFuncCall {
1571    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1572        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1573        #[allow(clippy::unreachable)]
1574        let Some((fn_name, args)) = self.call.iter().next() else {
1575            unreachable!("invariant violated: empty ExtFuncCall")
1576        };
1577        // search for the name and callstyle
1578        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1579            if &ext_fn.name().to_string() == fn_name {
1580                Some(ext_fn.style())
1581            } else {
1582                None
1583            }
1584        });
1585        match (style, args.iter().next()) {
1586            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1587                write!(
1588                    f,
1589                    "{}.{}({})",
1590                    maybe_with_parens(receiver),
1591                    fn_name,
1592                    args.iter().skip(1).join(", ")
1593                )
1594            }
1595            (_, _) => {
1596                write!(f, "{}({})", fn_name, args.iter().join(", "))
1597            }
1598        }
1599    }
1600}
1601
1602/// returns the `Display` representation of the Expr, adding parens around
1603/// the entire string if necessary.
1604/// E.g., won't add parens for constants or `principal` etc, but will for things
1605/// like `(2 < 5)`.
1606/// When in doubt, add the parens.
1607fn maybe_with_parens(expr: &Expr) -> String {
1608    match expr {
1609        Expr::ExprNoExt(ExprNoExt::Value(_)) => expr.to_string(),
1610        Expr::ExprNoExt(ExprNoExt::Var(_)) => expr.to_string(),
1611        Expr::ExprNoExt(ExprNoExt::Slot(_)) => expr.to_string(),
1612        Expr::ExprNoExt(ExprNoExt::Unknown { .. }) => expr.to_string(),
1613        Expr::ExprNoExt(ExprNoExt::Not { .. }) => {
1614            // we want parens here because things like parse((!x).y)
1615            // would be printed into !x.y which has a different meaning
1616            format!("({expr})")
1617        }
1618        Expr::ExprNoExt(ExprNoExt::Neg { .. }) => {
1619            // we want parens here because things like parse((-x).y)
1620            // would be printed into -x.y which has a different meaning
1621            format!("({expr})")
1622        }
1623        Expr::ExprNoExt(ExprNoExt::Eq { .. }) => format!("({expr})"),
1624        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) => format!("({expr})"),
1625        Expr::ExprNoExt(ExprNoExt::In { .. }) => format!("({expr})"),
1626        Expr::ExprNoExt(ExprNoExt::Less { .. }) => format!("({expr})"),
1627        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) => format!("({expr})"),
1628        Expr::ExprNoExt(ExprNoExt::Greater { .. }) => format!("({expr})"),
1629        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) => format!("({expr})"),
1630        Expr::ExprNoExt(ExprNoExt::And { .. }) => format!("({expr})"),
1631        Expr::ExprNoExt(ExprNoExt::Or { .. }) => format!("({expr})"),
1632        Expr::ExprNoExt(ExprNoExt::Add { .. }) => format!("({expr})"),
1633        Expr::ExprNoExt(ExprNoExt::Sub { .. }) => format!("({expr})"),
1634        Expr::ExprNoExt(ExprNoExt::Mul { .. }) => format!("({expr})"),
1635        Expr::ExprNoExt(ExprNoExt::Contains { .. }) => format!("({expr})"),
1636        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) => format!("({expr})"),
1637        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) => format!("({expr})"),
1638        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) => format!("({expr})"),
1639        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) => format!("({expr})"),
1640        Expr::ExprNoExt(ExprNoExt::Like { .. }) => format!("({expr})"),
1641        Expr::ExprNoExt(ExprNoExt::Is { .. }) => format!("({expr})"),
1642        Expr::ExprNoExt(ExprNoExt::If { .. }) => format!("({expr})"),
1643        Expr::ExprNoExt(ExprNoExt::Set(_)) => expr.to_string(),
1644        Expr::ExprNoExt(ExprNoExt::Record(_)) => expr.to_string(),
1645        Expr::ExtFuncCall { .. } => format!("({expr})"),
1646    }
1647}
1648
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::err::ParseError;

    use super::*;
    use cool_asserts::assert_matches;

    /// Converting a CST `Name` whose last component is `cst::Ident::Else`
    /// into an EST `Expr` must fail with exactly one `ToAST` error of kind
    /// `InvalidExpression`. The reported name's source location must end at
    /// offset 16, the end of the `else` component.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        let src = "some_long_str";
        let path = vec![Node::with_source_loc(
            Some(cst::Ident::Ident(src.into())),
            Loc::new(0..12, Arc::from(src)),
        )];
        let name = Node::with_source_loc(Some(cst::Ident::Else), Loc::new(13..16, Arc::from(src)));
        let cst_name = Node::with_source_loc(
            Some(cst::Name { path, name }),
            Loc::new(0..16, Arc::from(src)),
        );

        assert_matches!(Expr::try_from(&cst_name), Err(e) => {
            // `assert_eq!` reports the actual count on failure.
            assert_eq!(e.len(), 1);
            assert_matches!(&e[0],
                ParseError::ToAST(to_ast_error) => {
                    assert_matches!(to_ast_error.kind(), ToASTErrorKind::InvalidExpression(e) => {
                        assert_eq!(e.name.loc.end(), 16);
                    });
                }
            );
        });
    }
}