cedar_policy_core/ast/
partial_value.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::{Expr, Unknown, Value};
18use crate::parser::Loc;
19use itertools::Either;
20use miette::Diagnostic;
21use thiserror::Error;
22
/// Intermediate results of partial evaluation
///
/// A `PartialValue` is either a fully evaluated [`Value`], or a residual
/// [`Expr`] which still contains at least one unknown.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartialValue {
    /// Fully evaluated values
    Value(Value),
    /// Residual expressions containing unknowns
    /// INVARIANT: A residual _must_ have an unknown contained within
    // Building this variant directly performs no check of the invariant;
    // the `From<Expr>` conversion checks it with a `debug_assert!`.
    Residual(Expr),
}
32
33impl PartialValue {
34    /// Create a new `PartialValue` consisting of just this single `Unknown`
35    pub fn unknown(u: Unknown) -> Self {
36        Self::Residual(Expr::unknown(u))
37    }
38
39    /// Return the `PartialValue`, but with the given `Loc` (or `None`)
40    pub fn with_maybe_source_loc(self, loc: Option<Loc>) -> Self {
41        match self {
42            Self::Value(v) => Self::Value(v.with_maybe_source_loc(loc)),
43            Self::Residual(e) => Self::Residual(e.with_maybe_source_loc(loc)),
44        }
45    }
46}
47
48impl<V: Into<Value>> From<V> for PartialValue {
49    fn from(into_v: V) -> Self {
50        PartialValue::Value(into_v.into())
51    }
52}
53
54impl From<Expr> for PartialValue {
55    fn from(e: Expr) -> Self {
56        debug_assert!(e.contains_unknown());
57        PartialValue::Residual(e)
58    }
59}
60
/// Errors encountered when converting `PartialValue` to `Value`
///
/// This is the error type of the `TryFrom<PartialValue>` implementation for
/// `Value`.
// CAUTION: this type is publicly exported in `cedar-policy`.
#[derive(Debug, PartialEq, Diagnostic, Error)]
pub enum PartialValueToValueError {
    /// The `PartialValue` is a residual, i.e., contains an unknown
    // Both the error message and the diagnostic are marked `transparent`, so
    // they are taken from the wrapped `ContainsUnknown`. `#[from]` provides
    // the `From<ContainsUnknown>` conversion into this variant.
    #[diagnostic(transparent)]
    #[error(transparent)]
    ContainsUnknown(#[from] ContainsUnknown),
}
70
/// The `PartialValue` is a residual, i.e., contains an unknown
///
/// The error message includes the residual expression itself.
// CAUTION: this type is publicly exported in `cedar-policy`.
// Don't make fields `pub`, don't make breaking changes, and use caution
// when adding public methods.
#[derive(Debug, PartialEq, Diagnostic, Error)]
#[error("value contains a residual expression: `{residual}`")]
pub struct ContainsUnknown {
    /// Residual expression which contains an unknown
    // Kept private on purpose (see the CAUTION note on this type).
    residual: Expr,
}
81
82impl TryFrom<PartialValue> for Value {
83    type Error = PartialValueToValueError;
84
85    fn try_from(value: PartialValue) -> Result<Self, Self::Error> {
86        match value {
87            PartialValue::Value(v) => Ok(v),
88            PartialValue::Residual(e) => Err(ContainsUnknown { residual: e }.into()),
89        }
90    }
91}
92
93impl std::fmt::Display for PartialValue {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            PartialValue::Value(v) => write!(f, "{v}"),
97            PartialValue::Residual(r) => write!(f, "{r}"),
98        }
99    }
100}
101
102/// Collect an iterator of either residuals or values into one of the following
103///  a) An iterator over values, if everything evaluated to values
104///  b) An iterator over residuals expressions, if anything only evaluated to a residual
105/// Order is preserved.
106pub fn split<I>(i: I) -> Either<impl Iterator<Item = Value>, impl Iterator<Item = Expr>>
107where
108    I: IntoIterator<Item = PartialValue>,
109{
110    let mut values = vec![];
111    let mut residuals = vec![];
112
113    for item in i.into_iter() {
114        match item {
115            PartialValue::Value(a) => {
116                if residuals.is_empty() {
117                    values.push(a)
118                } else {
119                    residuals.push(a.into())
120                }
121            }
122            PartialValue::Residual(r) => {
123                residuals.push(r);
124            }
125        }
126    }
127
128    if residuals.is_empty() {
129        Either::Left(values.into_iter())
130    } else {
131        let mut exprs: Vec<Expr> = values.into_iter().map(|x| x.into()).collect();
132        exprs.append(&mut residuals);
133        Either::Right(exprs.into_iter())
134    }
135}
136
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    /// Run `split` on `input` and check that it yields exactly the `expected`
    /// expressions, in order, on the residual side
    fn assert_splits_to_residuals(
        input: impl IntoIterator<Item = PartialValue>,
        expected: Vec<Expr>,
    ) {
        match split(input) {
            Either::Right(exprs) => {
                assert_eq!(exprs.collect::<Vec<_>>(), expected);
            }
            Either::Left(_) => panic!("expected residuals, got values"),
        };
    }

    /// Only values in: values come back out, unchanged and in order
    #[test]
    fn split_values() {
        let input = [
            PartialValue::Value(Value::from(1)),
            PartialValue::Value(Value::from(2)),
        ];
        match split(input) {
            Either::Left(values) => {
                assert_eq!(
                    values.collect::<Vec<_>>(),
                    vec![Value::from(1), Value::from(2)]
                )
            }
            Either::Right(_) => panic!("expected values, got residuals"),
        };
    }

    /// Values and residuals interleaved
    #[test]
    fn split_residuals() {
        assert_splits_to_residuals(
            [
                PartialValue::Value(Value::from(1)),
                PartialValue::Residual(Expr::val(2)),
                PartialValue::Value(Value::from(3)),
                PartialValue::Residual(Expr::val(4)),
            ],
            vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)],
        );
    }

    /// All values first, then all residuals
    #[test]
    fn split_residuals2() {
        assert_splits_to_residuals(
            [
                PartialValue::Value(Value::from(1)),
                PartialValue::Value(Value::from(2)),
                PartialValue::Residual(Expr::val(3)),
                PartialValue::Residual(Expr::val(4)),
            ],
            vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)],
        );
    }

    /// All residuals first, then all values
    #[test]
    fn split_residuals3() {
        assert_splits_to_residuals(
            [
                PartialValue::Residual(Expr::val(1)),
                PartialValue::Residual(Expr::val(2)),
                PartialValue::Value(Value::from(3)),
                PartialValue::Value(Value::from(4)),
            ],
            vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)],
        );
    }
}