// cedar_policy_core/ast/restricted_expr.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17use super::{
18    EntityUID, Expr, ExprKind, ExpressionConstructionError, Literal, Name, PartialValue, Type,
19    Unknown, Value, ValueKind,
20};
21use crate::entities::json::err::JsonSerializationError;
22use crate::extensions::Extensions;
23use crate::parser::err::ParseErrors;
24use crate::parser::{self, Loc};
25use miette::Diagnostic;
26use smol_str::{SmolStr, ToSmolStr};
27use std::hash::{Hash, Hasher};
28use std::ops::Deref;
29use std::sync::Arc;
30use thiserror::Error;
31
32/// A few places in Core use these "restricted expressions" (for lack of a
33/// better term) which are in some sense the minimal subset of `Expr` required
34/// to express all possible `Value`s.
35///
36/// Specifically, "restricted" expressions are
37/// defined as expressions containing only the following:
38///   - bool, int, and string literals
39///   - literal EntityUIDs such as User::"alice"
40///   - extension function calls, where the arguments must be other things
41///     on this list
42///   - set and record literals, where the values must be other things on
43///     this list
44///
45/// That means the following are not allowed in "restricted" expressions:
46///   - `principal`, `action`, `resource`, `context`
47///   - builtin operators and functions, including `.`, `in`, `has`, `like`,
48///     `.contains()`
49///   - if-then-else expressions
50///
51/// These restrictions represent the expressions that are allowed to appear as
52/// attribute values in `Slice` and `Context`.
53#[derive(Hash, Debug, Clone, PartialEq, Eq)]
54pub struct RestrictedExpr(Expr);
55
56impl RestrictedExpr {
57    /// Create a new `RestrictedExpr` from an `Expr`.
58    ///
59    /// This function is "safe" in the sense that it will verify that the
60    /// provided `expr` does indeed qualify as a "restricted" expression,
61    /// returning an error if not.
62    ///
63    /// Note this check requires recursively walking the AST. For a version of
64    /// this function that doesn't perform this check, see `new_unchecked()`
65    /// below.
66    pub fn new(expr: Expr) -> Result<Self, RestrictedExpressionError> {
67        is_restricted(&expr)?;
68        Ok(Self(expr))
69    }
70
71    /// Create a new `RestrictedExpr` from an `Expr`, where the caller is
72    /// responsible for ensuring that the `Expr` is a valid "restricted
73    /// expression". If it is not, internal invariants will be violated, which
74    /// may lead to other errors later, panics, or even incorrect results.
75    ///
76    /// For a "safer" version of this function that returns an error for invalid
77    /// inputs, see `new()` above.
78    pub fn new_unchecked(expr: Expr) -> Self {
79        // in debug builds, this does the check anyway, panicking if it fails
80        if cfg!(debug_assertions) {
81            // PANIC SAFETY: We're in debug mode and panicking intentionally
82            #[allow(clippy::unwrap_used)]
83            Self::new(expr).unwrap()
84        } else {
85            Self(expr)
86        }
87    }
88
89    /// Return the `RestrictedExpr`, but with the new `source_loc` (or `None`).
90    pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
91        Self(self.0.with_maybe_source_loc(source_loc))
92    }
93
94    /// Create a `RestrictedExpr` that's just a single `Literal`.
95    ///
96    /// Note that you can pass this a `Literal`, an `Integer`, a `String`, etc.
97    pub fn val(v: impl Into<Literal>) -> Self {
98        // All literals are valid restricted-exprs
99        Self::new_unchecked(Expr::val(v))
100    }
101
102    /// Create a `RestrictedExpr` that's just a single `Unknown`.
103    pub fn unknown(u: Unknown) -> Self {
104        // All unknowns are valid restricted-exprs
105        Self::new_unchecked(Expr::unknown(u))
106    }
107
108    /// Create a `RestrictedExpr` which evaluates to a Set of the given `RestrictedExpr`s
109    pub fn set(exprs: impl IntoIterator<Item = RestrictedExpr>) -> Self {
110        // Set expressions are valid restricted-exprs if their elements are; and
111        // we know the elements are because we require `RestrictedExpr`s in the
112        // parameter
113        Self::new_unchecked(Expr::set(exprs.into_iter().map(Into::into)))
114    }
115
116    /// Create a `RestrictedExpr` which evaluates to a Record with the given
117    /// (key, value) pairs.
118    ///
119    /// Throws an error if any key occurs two or more times.
120    pub fn record(
121        pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
122    ) -> Result<Self, ExpressionConstructionError> {
123        // Record expressions are valid restricted-exprs if their elements are;
124        // and we know the elements are because we require `RestrictedExpr`s in
125        // the parameter
126        Ok(Self::new_unchecked(Expr::record(
127            pairs.into_iter().map(|(k, v)| (k, v.into())),
128        )?))
129    }
130
131    /// Create a `RestrictedExpr` which calls the given extension function
132    pub fn call_extension_fn(
133        function_name: Name,
134        args: impl IntoIterator<Item = RestrictedExpr>,
135    ) -> Self {
136        // Extension-function calls are valid restricted-exprs if their
137        // arguments are; and we know the arguments are because we require
138        // `RestrictedExpr`s in the parameter
139        Self::new_unchecked(Expr::call_extension_fn(
140            function_name,
141            args.into_iter().map(Into::into).collect(),
142        ))
143    }
144
145    /// Write a RestrictedExpr in "natural JSON" format.
146    ///
147    /// Used to output the context as a map from Strings to JSON Values
148    pub fn to_natural_json(&self) -> Result<serde_json::Value, JsonSerializationError> {
149        self.as_borrowed().to_natural_json()
150    }
151
152    /// Get the `bool` value of this `RestrictedExpr` if it's a boolean, or
153    /// `None` if it is not a boolean
154    pub fn as_bool(&self) -> Option<bool> {
155        // the only way a `RestrictedExpr` can be a boolean is if it's a literal
156        match self.expr_kind() {
157            ExprKind::Lit(Literal::Bool(b)) => Some(*b),
158            _ => None,
159        }
160    }
161
162    /// Get the `i64` value of this `RestrictedExpr` if it's a long, or `None`
163    /// if it is not a long
164    pub fn as_long(&self) -> Option<i64> {
165        // the only way a `RestrictedExpr` can be a long is if it's a literal
166        match self.expr_kind() {
167            ExprKind::Lit(Literal::Long(i)) => Some(*i),
168            _ => None,
169        }
170    }
171
172    /// Get the `SmolStr` value of this `RestrictedExpr` if it's a string, or
173    /// `None` if it is not a string
174    pub fn as_string(&self) -> Option<&SmolStr> {
175        // the only way a `RestrictedExpr` can be a string is if it's a literal
176        match self.expr_kind() {
177            ExprKind::Lit(Literal::String(s)) => Some(s),
178            _ => None,
179        }
180    }
181
182    /// Get the `EntityUID` value of this `RestrictedExpr` if it's an entity
183    /// reference, or `None` if it is not an entity reference
184    pub fn as_euid(&self) -> Option<&EntityUID> {
185        // the only way a `RestrictedExpr` can be an entity reference is if it's
186        // a literal
187        match self.expr_kind() {
188            ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
189            _ => None,
190        }
191    }
192
193    /// Get `Unknown` value of this `RestrictedExpr` if it's an `Unknown`, or
194    /// `None` if it is not an `Unknown`
195    pub fn as_unknown(&self) -> Option<&Unknown> {
196        match self.expr_kind() {
197            ExprKind::Unknown(u) => Some(u),
198            _ => None,
199        }
200    }
201
202    /// Iterate over the elements of the set if this `RestrictedExpr` is a set,
203    /// or `None` if it is not a set
204    pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
205        match self.expr_kind() {
206            ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)), // since the RestrictedExpr invariant holds for the input set, it will hold for each element as well
207            _ => None,
208        }
209    }
210
211    /// Iterate over the (key, value) pairs of the record if this
212    /// `RestrictedExpr` is a record, or `None` if it is not a record
213    pub fn as_record_pairs(
214        &self,
215    ) -> Option<impl Iterator<Item = (&SmolStr, BorrowedRestrictedExpr<'_>)>> {
216        match self.expr_kind() {
217            ExprKind::Record(map) => Some(
218                map.iter()
219                    .map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
220            ), // since the RestrictedExpr invariant holds for the input record, it will hold for each attr value as well
221            _ => None,
222        }
223    }
224
225    /// Get the name and args of the called extension function if this
226    /// `RestrictedExpr` is an extension function call, or `None` if it is not
227    /// an extension function call
228    pub fn as_extn_fn_call(
229        &self,
230    ) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
231        match self.expr_kind() {
232            ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
233                fn_name,
234                args.iter().map(BorrowedRestrictedExpr::new_unchecked),
235            )), // since the RestrictedExpr invariant holds for the input call, it will hold for each argument as well
236            _ => None,
237        }
238    }
239}
240
241impl From<Value> for RestrictedExpr {
242    fn from(value: Value) -> RestrictedExpr {
243        RestrictedExpr::from(value.value).with_maybe_source_loc(value.loc)
244    }
245}
246
247impl From<ValueKind> for RestrictedExpr {
248    fn from(value: ValueKind) -> RestrictedExpr {
249        match value {
250            ValueKind::Lit(lit) => RestrictedExpr::val(lit),
251            ValueKind::Set(set) => {
252                RestrictedExpr::set(set.iter().map(|val| RestrictedExpr::from(val.clone())))
253            }
254            // PANIC SAFETY: cannot have duplicate key because the input was already a BTreeMap
255            #[allow(clippy::expect_used)]
256            ValueKind::Record(record) => RestrictedExpr::record(
257                Arc::unwrap_or_clone(record)
258                    .into_iter()
259                    .map(|(k, v)| (k, RestrictedExpr::from(v))),
260            )
261            .expect("can't have duplicate keys, because the input `map` was already a BTreeMap"),
262            ValueKind::ExtensionValue(ev) => {
263                let ev = Arc::unwrap_or_clone(ev);
264                ev.into()
265            }
266        }
267    }
268}
269
270impl TryFrom<PartialValue> for RestrictedExpr {
271    type Error = PartialValueToRestrictedExprError;
272    fn try_from(pvalue: PartialValue) -> Result<RestrictedExpr, PartialValueToRestrictedExprError> {
273        match pvalue {
274            PartialValue::Value(v) => Ok(RestrictedExpr::from(v)),
275            PartialValue::Residual(expr) => match RestrictedExpr::new(expr) {
276                Ok(e) => Ok(e),
277                Err(RestrictedExpressionError::InvalidRestrictedExpression(
278                    restricted_expr_errors::InvalidRestrictedExpressionError { expr, .. },
279                )) => Err(PartialValueToRestrictedExprError::NontrivialResidual {
280                    residual: Box::new(expr),
281                }),
282            },
283        }
284    }
285}
286
287/// Errors when converting `PartialValue` to `RestrictedExpr`
288#[derive(Debug, PartialEq, Eq, Diagnostic, Error)]
289pub enum PartialValueToRestrictedExprError {
290    /// The `PartialValue` contains a nontrivial residual that isn't a valid `RestrictedExpr`
291    #[error("residual is not a valid restricted expression: `{residual}`")]
292    NontrivialResidual {
293        /// Residual that isn't a valid `RestrictedExpr`
294        residual: Box<Expr>,
295    },
296}
297
298impl std::str::FromStr for RestrictedExpr {
299    type Err = RestrictedExpressionParseError;
300
301    fn from_str(s: &str) -> Result<RestrictedExpr, Self::Err> {
302        parser::parse_restrictedexpr(s)
303    }
304}
305
306/// While `RestrictedExpr` wraps an _owned_ `Expr`, `BorrowedRestrictedExpr`
307/// wraps a _borrowed_ `Expr`, with the same invariants.
308///
309/// We derive `Copy` for this type because it's just a single reference, and
310/// `&T` is `Copy` for all `T`.
311#[derive(Hash, Debug, Clone, PartialEq, Eq, Copy)]
312pub struct BorrowedRestrictedExpr<'a>(&'a Expr);
313
314impl<'a> BorrowedRestrictedExpr<'a> {
315    /// Create a new `BorrowedRestrictedExpr` from an `&Expr`.
316    ///
317    /// This function is "safe" in the sense that it will verify that the
318    /// provided `expr` does indeed qualify as a "restricted" expression,
319    /// returning an error if not.
320    ///
321    /// Note this check requires recursively walking the AST. For a version of
322    /// this function that doesn't perform this check, see `new_unchecked()`
323    /// below.
324    pub fn new(expr: &'a Expr) -> Result<Self, RestrictedExpressionError> {
325        is_restricted(expr)?;
326        Ok(Self(expr))
327    }
328
329    /// Create a new `BorrowedRestrictedExpr` from an `&Expr`, where the caller
330    /// is responsible for ensuring that the `Expr` is a valid "restricted
331    /// expression". If it is not, internal invariants will be violated, which
332    /// may lead to other errors later, panics, or even incorrect results.
333    ///
334    /// For a "safer" version of this function that returns an error for invalid
335    /// inputs, see `new()` above.
336    pub fn new_unchecked(expr: &'a Expr) -> Self {
337        // in debug builds, this does the check anyway, panicking if it fails
338        if cfg!(debug_assertions) {
339            // PANIC SAFETY: We're in debug mode and panicking intentionally
340            #[allow(clippy::unwrap_used)]
341            Self::new(expr).unwrap()
342        } else {
343            Self(expr)
344        }
345    }
346
347    /// Write a BorrowedRestrictedExpr in "natural JSON" format.
348    ///
349    /// Used to output the context as a map from Strings to JSON Values
350    pub fn to_natural_json(self) -> Result<serde_json::Value, JsonSerializationError> {
351        Ok(serde_json::to_value(
352            crate::entities::json::CedarValueJson::from_expr(self)?,
353        )?)
354    }
355
356    /// Convert `BorrowedRestrictedExpr` to `RestrictedExpr`.
357    /// This has approximately the cost of cloning the `Expr`.
358    pub fn to_owned(self) -> RestrictedExpr {
359        RestrictedExpr::new_unchecked(self.0.clone())
360    }
361
362    /// Get the `bool` value of this `RestrictedExpr` if it's a boolean, or
363    /// `None` if it is not a boolean
364    pub fn as_bool(&self) -> Option<bool> {
365        // the only way a `RestrictedExpr` can be a boolean is if it's a literal
366        match self.expr_kind() {
367            ExprKind::Lit(Literal::Bool(b)) => Some(*b),
368            _ => None,
369        }
370    }
371
372    /// Get the `i64` value of this `RestrictedExpr` if it's a long, or `None`
373    /// if it is not a long
374    pub fn as_long(&self) -> Option<i64> {
375        // the only way a `RestrictedExpr` can be a long is if it's a literal
376        match self.expr_kind() {
377            ExprKind::Lit(Literal::Long(i)) => Some(*i),
378            _ => None,
379        }
380    }
381
382    /// Get the `SmolStr` value of this `RestrictedExpr` if it's a string, or
383    /// `None` if it is not a string
384    pub fn as_string(&self) -> Option<&SmolStr> {
385        // the only way a `RestrictedExpr` can be a string is if it's a literal
386        match self.expr_kind() {
387            ExprKind::Lit(Literal::String(s)) => Some(s),
388            _ => None,
389        }
390    }
391
392    /// Get the `EntityUID` value of this `RestrictedExpr` if it's an entity
393    /// reference, or `None` if it is not an entity reference
394    pub fn as_euid(&self) -> Option<&EntityUID> {
395        // the only way a `RestrictedExpr` can be an entity reference is if it's
396        // a literal
397        match self.expr_kind() {
398            ExprKind::Lit(Literal::EntityUID(e)) => Some(e),
399            _ => None,
400        }
401    }
402
403    /// Get `Unknown` value of this `RestrictedExpr` if it's an `Unknown`, or
404    /// `None` if it is not an `Unknown`
405    pub fn as_unknown(&self) -> Option<&Unknown> {
406        match self.expr_kind() {
407            ExprKind::Unknown(u) => Some(u),
408            _ => None,
409        }
410    }
411
412    /// Iterate over the elements of the set if this `RestrictedExpr` is a set,
413    /// or `None` if it is not a set
414    pub fn as_set_elements(&self) -> Option<impl Iterator<Item = BorrowedRestrictedExpr<'_>>> {
415        match self.expr_kind() {
416            ExprKind::Set(set) => Some(set.iter().map(BorrowedRestrictedExpr::new_unchecked)), // since the RestrictedExpr invariant holds for the input set, it will hold for each element as well
417            _ => None,
418        }
419    }
420
421    /// Iterate over the (key, value) pairs of the record if this
422    /// `RestrictedExpr` is a record, or `None` if it is not a record
423    pub fn as_record_pairs(
424        &self,
425    ) -> Option<impl Iterator<Item = (&'_ SmolStr, BorrowedRestrictedExpr<'_>)>> {
426        match self.expr_kind() {
427            ExprKind::Record(map) => Some(
428                map.iter()
429                    .map(|(k, v)| (k, BorrowedRestrictedExpr::new_unchecked(v))),
430            ), // since the RestrictedExpr invariant holds for the input record, it will hold for each attr value as well
431            _ => None,
432        }
433    }
434
435    /// Get the name and args of the called extension function if this
436    /// `RestrictedExpr` is an extension function call, or `None` if it is not
437    /// an extension function call
438    pub fn as_extn_fn_call(
439        &self,
440    ) -> Option<(&Name, impl Iterator<Item = BorrowedRestrictedExpr<'_>>)> {
441        match self.expr_kind() {
442            ExprKind::ExtensionFunctionApp { fn_name, args } => Some((
443                fn_name,
444                args.iter().map(BorrowedRestrictedExpr::new_unchecked),
445            )), // since the RestrictedExpr invariant holds for the input call, it will hold for each argument as well
446            _ => None,
447        }
448    }
449
450    /// Try to compute the runtime type of this expression. See
451    /// [`Expr::try_type_of`] for exactly what this computes.
452    ///
453    /// On a restricted expression, there are fewer cases where we might fail to
454    /// compute the type, but there are still `unknown`s and extension function
455    /// calls which may cause this function to return `None` .
456    pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
457        self.0.try_type_of(extensions)
458    }
459}
460
461/// Helper function: does the given `Expr` qualify as a "restricted" expression.
462///
463/// Returns `Ok(())` if yes, or a `RestrictedExpressionError` if no.
464fn is_restricted(expr: &Expr) -> Result<(), RestrictedExpressionError> {
465    match expr.expr_kind() {
466        ExprKind::Lit(_) => Ok(()),
467        ExprKind::Unknown(_) => Ok(()),
468        ExprKind::Var(_) => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
469            feature: "variables".into(),
470            expr: expr.clone(),
471        }
472        .into()),
473        ExprKind::Slot(_) => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
474            feature: "template slots".into(),
475            expr: expr.clone(),
476        }
477        .into()),
478        ExprKind::If { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
479            feature: "if-then-else".into(),
480            expr: expr.clone(),
481        }
482        .into()),
483        ExprKind::And { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
484            feature: "&&".into(),
485            expr: expr.clone(),
486        }
487        .into()),
488        ExprKind::Or { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
489            feature: "||".into(),
490            expr: expr.clone(),
491        }
492        .into()),
493        ExprKind::UnaryApp { op, .. } => {
494            Err(restricted_expr_errors::InvalidRestrictedExpressionError {
495                feature: op.to_smolstr(),
496                expr: expr.clone(),
497            }
498            .into())
499        }
500        ExprKind::BinaryApp { op, .. } => {
501            Err(restricted_expr_errors::InvalidRestrictedExpressionError {
502                feature: op.to_smolstr(),
503                expr: expr.clone(),
504            }
505            .into())
506        }
507        ExprKind::GetAttr { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
508            feature: "attribute accesses".into(),
509            expr: expr.clone(),
510        }
511        .into()),
512        ExprKind::HasAttr { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
513            feature: "'has'".into(),
514            expr: expr.clone(),
515        }
516        .into()),
517        ExprKind::Like { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
518            feature: "'like'".into(),
519            expr: expr.clone(),
520        }
521        .into()),
522        ExprKind::Is { .. } => Err(restricted_expr_errors::InvalidRestrictedExpressionError {
523            feature: "'is'".into(),
524            expr: expr.clone(),
525        }
526        .into()),
527        ExprKind::ExtensionFunctionApp { args, .. } => args.iter().try_for_each(is_restricted),
528        ExprKind::Set(exprs) => exprs.iter().try_for_each(is_restricted),
529        ExprKind::Record(map) => map.values().try_for_each(is_restricted),
530        #[cfg(feature = "tolerant-ast")]
531        ExprKind::Error { .. } => Ok(()),
532    }
533}
534
535// converting into Expr is always safe; restricted exprs are always valid Exprs
536impl From<RestrictedExpr> for Expr {
537    fn from(r: RestrictedExpr) -> Expr {
538        r.0
539    }
540}
541
542impl AsRef<Expr> for RestrictedExpr {
543    fn as_ref(&self) -> &Expr {
544        &self.0
545    }
546}
547
548impl Deref for RestrictedExpr {
549    type Target = Expr;
550    fn deref(&self) -> &Expr {
551        self.as_ref()
552    }
553}
554
555impl std::fmt::Display for RestrictedExpr {
556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        write!(f, "{}", &self.0)
558    }
559}
560
561// converting into Expr is always safe; restricted exprs are always valid Exprs
562impl<'a> From<BorrowedRestrictedExpr<'a>> for &'a Expr {
563    fn from(r: BorrowedRestrictedExpr<'a>) -> &'a Expr {
564        r.0
565    }
566}
567
568impl<'a> AsRef<Expr> for BorrowedRestrictedExpr<'a> {
569    fn as_ref(&self) -> &'a Expr {
570        self.0
571    }
572}
573
574impl RestrictedExpr {
575    /// Turn an `&RestrictedExpr` into a `BorrowedRestrictedExpr`
576    pub fn as_borrowed(&self) -> BorrowedRestrictedExpr<'_> {
577        BorrowedRestrictedExpr::new_unchecked(self.as_ref())
578    }
579}
580
581impl<'a> Deref for BorrowedRestrictedExpr<'a> {
582    type Target = Expr;
583    fn deref(&self) -> &'a Expr {
584        self.0
585    }
586}
587
588impl std::fmt::Display for BorrowedRestrictedExpr<'_> {
589    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
590        write!(f, "{}", &self.0)
591    }
592}
593
594/// Like `ExprShapeOnly`, but for restricted expressions.
595///
596/// A newtype wrapper around (borrowed) restricted expressions that provides
597/// `Eq` and `Hash` implementations that ignore any source information or other
598/// generic data used to annotate the expression.
599#[derive(Eq, Debug, Clone)]
600pub struct RestrictedExprShapeOnly<'a>(BorrowedRestrictedExpr<'a>);
601
602impl<'a> RestrictedExprShapeOnly<'a> {
603    /// Construct a `RestrictedExprShapeOnly` from a `BorrowedRestrictedExpr`.
604    /// The `BorrowedRestrictedExpr` is not modified, but any comparisons on the
605    /// resulting `RestrictedExprShapeOnly` will ignore source information and
606    /// generic data.
607    pub fn new(e: BorrowedRestrictedExpr<'a>) -> RestrictedExprShapeOnly<'a> {
608        RestrictedExprShapeOnly(e)
609    }
610}
611
612impl PartialEq for RestrictedExprShapeOnly<'_> {
613    fn eq(&self, other: &Self) -> bool {
614        self.0.eq_shape(&other.0)
615    }
616}
617
618impl Hash for RestrictedExprShapeOnly<'_> {
619    fn hash<H: Hasher>(&self, state: &mut H) {
620        self.0.hash_shape(state);
621    }
622}
623
624/// Error when constructing a restricted expression from unrestricted
625/// expression
626//
627// CAUTION: this type is publicly exported in `cedar-policy`.
628// Don't make fields `pub`, don't make breaking changes, and use caution
629// when adding public methods.
630#[derive(Debug, Clone, PartialEq, Eq, Error, Diagnostic)]
631pub enum RestrictedExpressionError {
632    /// An expression was expected to be a "restricted" expression, but contained
633    /// a feature that is not allowed in restricted expressions.
634    #[error(transparent)]
635    #[diagnostic(transparent)]
636    InvalidRestrictedExpression(#[from] restricted_expr_errors::InvalidRestrictedExpressionError),
637}
638
639/// Error subtypes for [`RestrictedExpressionError`]
640pub mod restricted_expr_errors {
641    use super::Expr;
642    use crate::impl_diagnostic_from_method_on_field;
643    use miette::Diagnostic;
644    use smol_str::SmolStr;
645    use thiserror::Error;
646
647    /// An expression was expected to be a "restricted" expression, but contained
648    /// a feature that is not allowed in restricted expressions.
649    //
650    // CAUTION: this type is publicly exported in `cedar-policy`.
651    // Don't make fields `pub`, don't make breaking changes, and use caution
652    // when adding public methods.
653    #[derive(Debug, Clone, PartialEq, Eq, Error)]
654    #[error("not allowed to use {feature} in a restricted expression: `{expr}`")]
655    pub struct InvalidRestrictedExpressionError {
656        /// String description of what disallowed feature appeared in the expression
657        pub(crate) feature: SmolStr,
658        /// the (sub-)expression that uses the disallowed feature. This may be a
659        /// sub-expression of a larger expression.
660        pub(crate) expr: Expr,
661    }
662
663    // custom impl of `Diagnostic`: take source location from the `expr` field's `.source_loc()` method
664    impl Diagnostic for InvalidRestrictedExpressionError {
665        impl_diagnostic_from_method_on_field!(expr, source_loc);
666    }
667}
668
669/// Errors possible from `RestrictedExpr::from_str()`
670//
671// This is NOT a publicly exported error type.
672#[derive(Debug, Clone, PartialEq, Eq, Diagnostic, Error)]
673pub enum RestrictedExpressionParseError {
674    /// Failed to parse the expression
675    #[error(transparent)]
676    #[diagnostic(transparent)]
677    Parse(#[from] ParseErrors),
678    /// Parsed successfully as an expression, but failed to construct a
679    /// restricted expression, for the reason indicated in the underlying error
680    #[error(transparent)]
681    #[diagnostic(transparent)]
682    InvalidRestrictedExpression(#[from] RestrictedExpressionError),
683}
684
#[cfg(test)]
mod test {
    use super::*;
    use crate::ast::expression_construction_errors;
    use crate::parser::err::{ParseError, ToASTError, ToASTErrorKind};
    use crate::parser::Loc;
    use std::str::FromStr;
    use std::sync::Arc;

    /// The error every case below expects: key `foo` repeated in a record literal.
    fn duplicate_foo() -> expression_construction_errors::DuplicateKeyError {
        expression_construction_errors::DuplicateKeyError {
            key: "foo".into(),
            context: "in record literal",
        }
    }

    #[test]
    fn duplicate_key() {
        let cases: [(&str, Vec<(SmolStr, RestrictedExpr)>); 4] = [
            (
                "mapped to values of different types",
                vec![
                    ("foo".into(), RestrictedExpr::val(37)),
                    ("foo".into(), RestrictedExpr::val("hello")),
                ],
            ),
            (
                "mapped to different values of the same type",
                vec![
                    ("foo".into(), RestrictedExpr::val(37)),
                    ("foo".into(), RestrictedExpr::val(101)),
                ],
            ),
            (
                "mapped to the same value multiple times",
                vec![
                    ("foo".into(), RestrictedExpr::val(37)),
                    ("foo".into(), RestrictedExpr::val(37)),
                ],
            ),
            (
                "other keys appear in between",
                vec![
                    ("bar".into(), RestrictedExpr::val(-3)),
                    ("foo".into(), RestrictedExpr::val(37)),
                    ("spam".into(), RestrictedExpr::val("eggs")),
                    ("foo".into(), RestrictedExpr::val(37)),
                    ("eggs".into(), RestrictedExpr::val("spam")),
                ],
            ),
        ];
        for (description, pairs) in cases {
            assert_eq!(
                RestrictedExpr::record(pairs),
                Err(duplicate_foo().into()),
                "duplicate key should be an error when {description}"
            );
        }

        // duplicate keys are also rejected when parsing from a string
        let src = r#"{ foo: 37, bar: "hi", foo: 101 }"#;
        let expected = RestrictedExpressionParseError::Parse(ParseErrors::singleton(
            ParseError::ToAST(ToASTError::new(
                ToASTErrorKind::ExpressionConstructionError(duplicate_foo().into()),
                Some(Loc::new(0..32, Arc::from(src))),
            )),
        ));
        assert_eq!(RestrictedExpr::from_str(src), Err(expected));
    }
}