cedar_policy_core/ast/
value.rs

1/*
2 * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use core::fmt;
19use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
20use std::sync::Arc;
21
22use itertools::Either;
23use serde::{Deserialize, Serialize};
24use smol_str::SmolStr;
25use thiserror::Error;
26
27/// This describes all the values which could be the dynamic result of evaluating an `Expr`.
28/// Cloning is O(1).
29#[derive(Debug, Clone, PartialOrd, Ord, Serialize, Deserialize)]
30#[serde(into = "Expr")]
31#[serde(try_from = "Expr")]
32pub enum Value {
33    /// anything that is a Literal can also be the dynamic result of evaluating an `Expr`
34    Lit(Literal),
35    /// Evaluating an `Expr` can result in a first-class set
36    Set(Set),
37    /// Evaluating an `Expr` can result in a first-class anonymous record (keyed on String)
38    Record(Arc<BTreeMap<SmolStr, Value>>),
39    /// Evaluating an `Expr` can result in an extension value
40    ExtensionValue(Arc<ExtensionValueWithArgs>),
41}
42
43#[derive(Debug, Error)]
44/// An error that can be thrown converting an expression to a value
45pub enum NotValue {
46    /// General error for non-values
47    #[error("Not A Value")]
48    NotValue,
49}
50
51impl TryFrom<Expr> for Value {
52    type Error = NotValue;
53
54    fn try_from(value: Expr) -> Result<Self, Self::Error> {
55        match value.into_expr_kind() {
56            ExprKind::Lit(l) => Ok(Value::Lit(l)),
57            ExprKind::Unknown { .. } => Err(NotValue::NotValue),
58            ExprKind::Var(_) => Err(NotValue::NotValue),
59            ExprKind::Slot(_) => Err(NotValue::NotValue),
60            ExprKind::If { .. } => Err(NotValue::NotValue),
61            ExprKind::And { .. } => Err(NotValue::NotValue),
62            ExprKind::Or { .. } => Err(NotValue::NotValue),
63            ExprKind::UnaryApp { .. } => Err(NotValue::NotValue),
64            ExprKind::BinaryApp { .. } => Err(NotValue::NotValue),
65            ExprKind::MulByConst { .. } => Err(NotValue::NotValue),
66            ExprKind::ExtensionFunctionApp { .. } => Err(NotValue::NotValue),
67            ExprKind::GetAttr { .. } => Err(NotValue::NotValue),
68            ExprKind::HasAttr { .. } => Err(NotValue::NotValue),
69            ExprKind::Like { .. } => Err(NotValue::NotValue),
70            ExprKind::Set(members) => members
71                .iter()
72                .map(|e| e.clone().try_into())
73                .collect::<Result<Set, _>>()
74                .map(Value::Set),
75            ExprKind::Record { pairs } => pairs
76                .iter()
77                .map(|(k, v)| v.clone().try_into().map(|v: Value| (k.clone(), v)))
78                .collect::<Result<BTreeMap<SmolStr, Value>, _>>()
79                .map(|m| Value::Record(Arc::new(m))),
80        }
81    }
82}
83
84#[derive(Debug, Clone, PartialEq)]
85/// Intermediate results of partial evaluation
86pub enum PartialValue {
87    /// Fully evaluated values
88    Value(Value),
89    /// Residual expressions containing unknowns
90    /// INVARIANT: A residual _must_ have an unknown contained within
91    Residual(Expr),
92}
93
94impl<V: Into<Value>> From<V> for PartialValue {
95    fn from(into_v: V) -> Self {
96        PartialValue::Value(into_v.into())
97    }
98}
99
100impl From<Expr> for PartialValue {
101    fn from(e: Expr) -> Self {
102        debug_assert!(e.is_unknown());
103        PartialValue::Residual(e)
104    }
105}
106
107impl From<PartialValue> for Expr {
108    fn from(val: PartialValue) -> Self {
109        match val {
110            PartialValue::Value(v) => v.into(),
111            PartialValue::Residual(e) => e,
112        }
113    }
114}
115
116impl TryFrom<PartialValue> for Value {
117    type Error = NotValue;
118
119    fn try_from(value: PartialValue) -> Result<Self, Self::Error> {
120        match value {
121            PartialValue::Value(v) => Ok(v),
122            PartialValue::Residual(e) => e.try_into(),
123        }
124    }
125}
126
127impl fmt::Display for PartialValue {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        match self {
130            PartialValue::Value(v) => write!(f, "{v}"),
131            PartialValue::Residual(r) => write!(f, "{r}"),
132        }
133    }
134}
135
136/// Collect an iterator of either residuals or values into one of the following
137///  a) An iterator over values, if everything evaluated to values
138///  b) An iterator over residuals expressions, if anything only evaluated to a residual
139/// Order is preserved.
140pub fn split<I>(i: I) -> Either<impl Iterator<Item = Value>, impl Iterator<Item = Expr>>
141where
142    I: IntoIterator<Item = PartialValue>,
143{
144    let mut values = vec![];
145    let mut residuals = vec![];
146
147    for item in i.into_iter() {
148        match item {
149            PartialValue::Value(a) => {
150                if residuals.is_empty() {
151                    values.push(a)
152                } else {
153                    residuals.push(a.into())
154                }
155            }
156            PartialValue::Residual(r) => {
157                residuals.push(r);
158            }
159        }
160    }
161
162    if residuals.is_empty() {
163        Either::Left(values.into_iter())
164    } else {
165        let mut exprs: Vec<Expr> = values.into_iter().map(|x| x.into()).collect();
166        exprs.append(&mut residuals);
167        Either::Right(exprs.into_iter())
168    }
169}
170
171/// `Value`'s internal representation of a `Set`
172#[derive(Debug, Clone)]
173pub struct Set {
174    /// the values in the set, stored in a `BTreeSet`
175    pub authoritative: Arc<BTreeSet<Value>>,
176    /// if possible, `HashSet<Literal>` representation of the set.
177    /// (This is possible if all the elements are literals.)
178    /// Some operations are much faster in this case.
179    ///
180    /// INVARIANT (FastRepr)
181    /// we guarantee that if the elements are all
182    /// literals, then this will be `Some`. (This allows us to further
183    /// optimize e.g. equality checks between sets: for instance, we know
184    /// that if one set has `fast` and another does not, the sets can't be
185    /// equal.)
186    pub fast: Option<Arc<HashSet<Literal>>>,
187}
188
189impl Set {
190    /// Get the number of items in the set
191    pub fn len(&self) -> usize {
192        self.authoritative.len()
193    }
194    /// Convenience method to check if a set is empty
195    pub fn is_empty(&self) -> bool {
196        self.len() == 0
197    }
198
199    /// Borrowed iterator
200    pub fn iter(&self) -> impl Iterator<Item = &Value> {
201        self.authoritative.iter()
202    }
203}
204
205impl FromIterator<Value> for Set {
206    fn from_iter<T: IntoIterator<Item = Value>>(iter: T) -> Self {
207        let (literals, non_literals): (BTreeSet<_>, BTreeSet<_>) =
208            iter.into_iter().partition(|v| matches!(v, Value::Lit(_)));
209
210        if non_literals.is_empty() {
211            // INVARIANT (FastRepr)
212            // There are 0 non-literals, so we need to populate `fast`
213            Self {
214                authoritative: Arc::new(literals.clone()), // non_literals is empty, so this drops no items
215                fast: Some(Arc::new(
216                    literals
217                        .into_iter()
218                        .map(|v| match v {
219                            Value::Lit(l) => l,
220                            _ => unreachable!(), // SAFETY: This is unreachable as every item in `literals` matches Value::Lit
221                        })
222                        .collect(),
223                )),
224            }
225        } else {
226            // INVARIANT (FastRepr)
227            // There are non-literals, so we need `fast` should be `None`
228            // We also need to add all the literals back into the set
229            let mut all_items = non_literals;
230            let mut literals = literals;
231            all_items.append(&mut literals);
232            Self {
233                authoritative: Arc::new(all_items),
234                fast: None,
235            }
236        }
237    }
238}
239
240impl Value {
241    /// If the value is a Literal, get a reference to the underlying Literal
242    pub(crate) fn try_as_lit(&self) -> Option<&Literal> {
243        match self {
244            Self::Lit(lit) => Some(lit),
245            _ => None,
246        }
247    }
248}
249
250// Trying to derive `PartialEq` for `Value` fails with a compile error (at
251// least, as of this writing) due to the `Arc<dyn>`, so we write out the
252// implementation manually
253impl PartialEq for Value {
254    fn eq(&self, other: &Value) -> bool {
255        match (self, other) {
256            (Value::Lit(l1), Value::Lit(l2)) => l1 == l2,
257            (
258                Value::Set(Set {
259                    fast: Some(rc1), ..
260                }),
261                Value::Set(Set {
262                    fast: Some(rc2), ..
263                }),
264            ) => rc1 == rc2,
265            (Value::Set(Set { fast: Some(_), .. }), Value::Set(Set { fast: None, .. })) => false, // due to internal invariant documented on `Set`, we know that one set contains a non-literal and the other does not
266            (Value::Set(Set { fast: None, .. }), Value::Set(Set { fast: Some(_), .. })) => false, // due to internal invariant documented on `Set`, we know that one set contains a non-literal and the other does not
267            (
268                Value::Set(Set {
269                    authoritative: a1, ..
270                }),
271                Value::Set(Set {
272                    authoritative: a2, ..
273                }),
274            ) => a1 == a2,
275            (Value::Record(r1), Value::Record(r2)) => r1 == r2,
276            (Value::ExtensionValue(ev1), Value::ExtensionValue(ev2)) => ev1 == ev2,
277            (_, _) => false, // values of different types are not equal
278        }
279    }
280}
281
282impl Eq for Value {}
283
284// PartialEq on Set compares only the `authoritative` version
285impl PartialEq for Set {
286    fn eq(&self, other: &Self) -> bool {
287        self.authoritative.as_ref() == other.authoritative.as_ref()
288    }
289}
290
291impl Eq for Set {}
292
293// PartialOrd on Set compares only the `authoritative` version; note that
294// HashSet doesn't implement PartialOrd
295impl PartialOrd<Set> for Set {
296    fn partial_cmp(&self, other: &Set) -> Option<std::cmp::Ordering> {
297        self.authoritative
298            .as_ref()
299            .partial_cmp(other.authoritative.as_ref())
300    }
301}
302
303// Ord on Set compares only the `authoritative` version; note that HashSet
304// doesn't implement Ord
305impl Ord for Set {
306    fn cmp(&self, other: &Set) -> std::cmp::Ordering {
307        self.authoritative
308            .as_ref()
309            .cmp(other.authoritative.as_ref())
310    }
311}
312
313impl StaticallyTyped for Value {
314    fn type_of(&self) -> Type {
315        match self {
316            Self::Lit(lit) => lit.type_of(),
317            Self::Set(_) => Type::Set,
318            Self::Record(_) => Type::Record,
319            Self::ExtensionValue(ev) => ev.type_of(),
320        }
321    }
322}
323
324impl std::fmt::Display for Value {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        match self {
327            Self::Lit(lit) => write!(f, "{}", lit),
328            Self::Set(Set {
329                fast,
330                authoritative,
331            }) => {
332                let len = fast
333                    .as_ref()
334                    .map(|set| set.len())
335                    .unwrap_or_else(|| authoritative.len());
336                match len {
337                    0 => write!(f, "[]"),
338                    1..=5 => {
339                        write!(f, "[")?;
340                        if let Some(rc) = fast {
341                            for item in rc.as_ref() {
342                                write!(f, "{item}, ")?;
343                            }
344                        } else {
345                            for item in authoritative.as_ref() {
346                                write!(f, "{item}, ")?;
347                            }
348                        }
349                        write!(f, "]")?;
350                        Ok(())
351                    }
352                    n => write!(f, "<set with {} elements>", n),
353                }
354            }
355            Self::Record(record) => write!(f, "<first-class record with {} fields>", record.len()),
356            Self::ExtensionValue(ev) => write!(f, "{}", ev),
357        }
358    }
359}
360
361/// Create a `Value` directly from a `Vec<Value>`, or `Vec<T> where T: Into<Value>`
362/// (so `Vec<i64>`, `Vec<String>`, etc)
363impl<T: Into<Value>> From<Vec<T>> for Value {
364    fn from(v: Vec<T>) -> Self {
365        Self::set(v.into_iter().map(Into::into))
366    }
367}
368
369/// Create a `Value::Record` from a map of `String` to `Value`
370impl<S> From<BTreeMap<S, Value>> for Value
371where
372    S: Into<SmolStr>,
373{
374    fn from(map: BTreeMap<S, Value>) -> Self {
375        Self::Record(Arc::new(
376            map.into_iter().map(|(k, v)| (k.into(), v)).collect(),
377        ))
378    }
379}
380
381/// As above, create a `Value::Record` from a map of `SmolStr` to `Value`.
382/// This implementation provides conversion from `HashMap` while the earlier
383/// implementation provides conversion from `BTreeMap`
384impl<S> From<HashMap<S, Value>> for Value
385where
386    S: Into<SmolStr>,
387{
388    fn from(map: HashMap<S, Value>) -> Self {
389        Self::Record(Arc::new(
390            map.into_iter().map(|(k, v)| (k.into(), v)).collect(),
391        ))
392    }
393}
394
395/// Create a `Value` directly from a `Vec` of `(String, Value)` pairs, which
396/// will be interpreted as (field, value) pairs for a first-class record
397impl From<Vec<(SmolStr, Value)>> for Value {
398    fn from(v: Vec<(SmolStr, Value)>) -> Self {
399        Self::Record(Arc::new(v.into_iter().collect()))
400    }
401}
402
403/// Create a `Value` directly from a `Literal`, or from anything that implements
404/// `Into<Literal>` (so `i64`, `&str`, `EntityUID`, etc)
405impl<T: Into<Literal>> From<T> for Value {
406    fn from(lit: T) -> Self {
407        Self::Lit(lit.into())
408    }
409}
410
411impl Value {
412    /// Create a new empty set
413    pub fn empty_set() -> Self {
414        Self::Set(Set {
415            authoritative: Arc::new(BTreeSet::new()),
416            fast: Some(Arc::new(HashSet::new())),
417        })
418    }
419
420    /// Create a new empty record
421    pub fn empty_record() -> Self {
422        Self::Record(Arc::new(BTreeMap::new()))
423    }
424
425    /// Create a set with the given `Value`s as elements
426    pub fn set(vals: impl IntoIterator<Item = Value>) -> Self {
427        let authoritative: BTreeSet<Value> = vals.into_iter().collect();
428        let fast: Option<HashSet<Literal>> = authoritative
429            .iter()
430            .map(|v| v.try_as_lit().cloned())
431            .collect();
432        if let Some(fast) = fast {
433            Self::Set(Set {
434                authoritative: Arc::new(authoritative),
435                fast: Some(Arc::new(fast)),
436            })
437        } else {
438            Self::Set(Set {
439                authoritative: Arc::new(authoritative),
440                fast: None,
441            })
442        }
443    }
444
445    /// Create a set with the given `Literal`s as elements
446    pub fn set_of_lits(lits: impl IntoIterator<Item = Literal>) -> Self {
447        let fast: HashSet<Literal> = lits.into_iter().collect();
448        let authoritative: BTreeSet<Value> =
449            fast.iter().map(|lit| Value::Lit(lit.clone())).collect();
450        Self::Set(Set {
451            authoritative: Arc::new(authoritative),
452            fast: Some(Arc::new(fast)),
453        })
454    }
455}
456
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn values() {
        assert_eq!(Value::from(true), Value::Lit(Literal::Bool(true)));
        assert_eq!(Value::from(false), Value::Lit(Literal::Bool(false)));
        assert_eq!(Value::from(23), Value::Lit(Literal::Long(23)));
        assert_eq!(Value::from(-47), Value::Lit(Literal::Long(-47)));
        assert_eq!(
            Value::from("hello"),
            Value::Lit(Literal::String("hello".into()))
        );
        assert_eq!(
            Value::from("hello".to_owned()),
            Value::Lit(Literal::String("hello".into()))
        );
        assert_eq!(
            Value::from(String::new()),
            Value::Lit(Literal::String(SmolStr::default()))
        );
        assert_eq!(
            Value::from(""),
            Value::Lit(Literal::String(SmolStr::default()))
        );
        assert_eq!(
            Value::from(vec![2, -3, 40]),
            Value::set(vec![Value::from(2), Value::from(-3), Value::from(40)])
        );
        assert_eq!(
            Value::from(vec![Literal::from(false), Literal::from("eggs")]),
            Value::set(vec![Value::from(false), Value::from("eggs")])
        );
        assert_eq!(
            Value::set(vec![Value::from(false), Value::from("eggs")]),
            Value::set_of_lits(vec![Literal::from(false), Literal::from("eggs")])
        );

        let mut rec1: BTreeMap<SmolStr, Value> = BTreeMap::new();
        rec1.insert("ham".into(), 3.into());
        rec1.insert("eggs".into(), "hickory".into());
        assert_eq!(Value::from(rec1.clone()), Value::Record(Arc::new(rec1)));

        let mut rec2: BTreeMap<SmolStr, Value> = BTreeMap::new();
        rec2.insert("hi".into(), "ham".into());
        rec2.insert("eggs".into(), "hickory".into());
        assert_eq!(
            Value::from(vec![
                ("hi".into(), "ham".into()),
                ("eggs".into(), "hickory".into())
            ]),
            Value::Record(Arc::new(rec2))
        );

        assert_eq!(
            Value::from(EntityUID::with_eid("foo")),
            Value::Lit(Literal::EntityUID(Arc::new(EntityUID::with_eid("foo"))))
        );
    }

    #[test]
    fn value_types() {
        assert_eq!(Value::from(false).type_of(), Type::Bool);
        assert_eq!(Value::from(23).type_of(), Type::Long);
        assert_eq!(Value::from(-47).type_of(), Type::Long);
        assert_eq!(Value::from("hello").type_of(), Type::String);
        assert_eq!(Value::from(vec![2, -3, 40]).type_of(), Type::Set);
        assert_eq!(Value::empty_set().type_of(), Type::Set);
        assert_eq!(Value::empty_record().type_of(), Type::Record);
        assert_eq!(
            Value::from(vec![("hello".into(), Value::from("ham"))]).type_of(),
            Type::Record
        );
        assert_eq!(
            Value::from(EntityUID::with_eid("foo")).type_of(),
            Type::entity_type(
                Name::parse_unqualified_name("test_entity_type").expect("valid identifier")
            )
        );
    }

    #[test]
    fn test_set_is_empty_for_empty_set() {
        let set = Set {
            authoritative: Arc::new(BTreeSet::new()),
            fast: Some(Arc::new(HashSet::new())),
        };
        assert!(set.is_empty());
    }

    #[test]
    fn test_set_is_not_empty_for_set_with_values() {
        // Built through `FromIterator` so the FastRepr invariant holds: a set whose
        // members are all literals must have `fast` populated. (Hand-building it with
        // `fast: None` would create a `Set` that the rest of the module treats as invalid.)
        let set: Set = std::iter::once(Value::from("abc")).collect();
        assert!(!set.is_empty());
        assert!(set.fast.is_some());
    }

    #[test]
    fn set_fast_repr_invariant() {
        // only literals: `fast` must be populated
        let lits: Set = vec![Value::from(1), Value::from("two")]
            .into_iter()
            .collect();
        assert!(lits.fast.is_some());
        assert_eq!(lits.len(), 2);

        // at least one non-literal: `fast` must be absent
        let mixed: Set = vec![Value::from(1), Value::empty_set()]
            .into_iter()
            .collect();
        assert!(mixed.fast.is_none());
        assert_eq!(mixed.len(), 2);

        // `Value::set` must agree with `FromIterator`
        match Value::set(vec![Value::from(1), Value::empty_record()]) {
            Value::Set(s) => assert!(s.fast.is_none()),
            other => panic!("expected a set, got {other}"),
        }
        match Value::set(vec![Value::from(true), Value::from(true)]) {
            Value::Set(s) => {
                assert!(s.fast.is_some());
                assert_eq!(s.len(), 1);
            }
            other => panic!("expected a set, got {other}"),
        }
    }

    #[test]
    fn set_equality() {
        assert_eq!(
            Value::set(vec![Value::from(1), Value::from(2)]),
            Value::from(vec![2, 1])
        );
        assert_ne!(
            Value::set(vec![Value::from(1)]),
            Value::set(vec![Value::empty_set()])
        );
        assert_eq!(
            Value::set(vec![Value::empty_set()]),
            Value::set(vec![Value::empty_set(), Value::empty_set()])
        );
    }

    #[test]
    fn pretty_printer() {
        assert_eq!(Value::from("abc").to_string(), r#""abc""#);
        assert_eq!(Value::from("\t").to_string(), r#""\t""#);
        assert_eq!(Value::from("🐈").to_string(), r#""🐈""#);
    }

    #[test]
    fn set_collect() {
        let v = vec![Value::Lit(1.into())];
        let s: Set = v.into_iter().collect();
        assert_eq!(s.len(), 1);
        let v2 = vec![Value::Set(s)];
        let s2: Set = v2.into_iter().collect();
        assert_eq!(s2.len(), 1);
    }

    #[test]
    fn split_empty() {
        match split(Vec::<PartialValue>::new()) {
            Either::Left(vs) => assert_eq!(vs.count(), 0),
            Either::Right(_) => panic!("Got residuals"),
        }
    }

    #[test]
    fn split_values() {
        let vs = [
            PartialValue::Value(Value::Lit(1.into())),
            PartialValue::Value(Value::Lit(2.into())),
        ];
        match split(vs) {
            Either::Left(vs) => assert_eq!(
                vs.collect::<Vec<_>>(),
                vec![Value::Lit(1.into()), Value::Lit(2.into())]
            ),
            Either::Right(_) => panic!("Got residuals"),
        }
    }

    #[test]
    fn split_residuals() {
        let rs = [
            PartialValue::Value(Value::Lit(1.into())),
            PartialValue::Residual(Expr::val(2)),
            PartialValue::Value(Value::Lit(3.into())),
            PartialValue::Residual(Expr::val(4)),
        ];
        let expected = vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)];
        match split(rs) {
            Either::Left(_) => panic!("Got values"),
            Either::Right(rs) => assert_eq!(rs.collect::<Vec<_>>(), expected),
        }
    }

    #[test]
    fn split_residuals2() {
        let rs = [
            PartialValue::Value(Value::Lit(1.into())),
            PartialValue::Value(Value::Lit(2.into())),
            PartialValue::Residual(Expr::val(3)),
            PartialValue::Residual(Expr::val(4)),
        ];
        let expected = vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)];
        match split(rs) {
            Either::Left(_) => panic!("Got values"),
            Either::Right(rs) => assert_eq!(rs.collect::<Vec<_>>(), expected),
        }
    }

    #[test]
    fn split_residuals3() {
        let rs = [
            PartialValue::Residual(Expr::val(1)),
            PartialValue::Residual(Expr::val(2)),
            PartialValue::Value(Value::Lit(3.into())),
            PartialValue::Value(Value::Lit(4.into())),
        ];
        let expected = vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)];
        match split(rs) {
            Either::Left(_) => panic!("Got values"),
            Either::Right(rs) => assert_eq!(rs.collect::<Vec<_>>(), expected),
        }
    }
}