cedar_policy_core/ast/
expr_iterator.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::{Expr, ExprKind};
18
19/// This structure implements the iterator used to traverse subexpressions of an
20/// expression.
21#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23    /// The stack of expressions that need to be visited. To get the next
24    /// expression, the iterator will pop from the stack. If the stack is empty,
25    /// then the iterator is finished. Otherwise, any subexpressions of that
26    /// expression are then pushed onto the stack, and the popped expression is
27    /// returned.
28    expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32    /// Construct an expr iterator
33    pub fn new(expr: &'a Expr<T>) -> Self {
34        Self {
35            expression_stack: vec![expr],
36        }
37    }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41    type Item = &'a Expr<T>;
42
43    fn next(&mut self) -> Option<Self::Item> {
44        let next_expr = self.expression_stack.pop()?;
45        match next_expr.expr_kind() {
46            ExprKind::Lit(_) => (),
47            ExprKind::Unknown(_) => (),
48            ExprKind::Slot(_) => (),
49            ExprKind::Var(_) => (),
50            ExprKind::If {
51                test_expr,
52                then_expr,
53                else_expr,
54            } => {
55                self.expression_stack.push(test_expr);
56                self.expression_stack.push(then_expr);
57                self.expression_stack.push(else_expr);
58            }
59            ExprKind::And { left, right } => {
60                self.expression_stack.push(left);
61                self.expression_stack.push(right);
62            }
63            ExprKind::Or { left, right } => {
64                self.expression_stack.push(left);
65                self.expression_stack.push(right);
66            }
67            ExprKind::UnaryApp { arg, .. } => {
68                self.expression_stack.push(arg);
69            }
70            ExprKind::BinaryApp { arg1, arg2, .. } => {
71                self.expression_stack.push(arg1);
72                self.expression_stack.push(arg2);
73            }
74            ExprKind::ExtensionFunctionApp { args, .. } => {
75                for arg in args.as_ref() {
76                    self.expression_stack.push(arg);
77                }
78            }
79            ExprKind::GetAttr { expr, attr: _ } => {
80                self.expression_stack.push(expr);
81            }
82            ExprKind::HasAttr { expr, attr: _ } => {
83                self.expression_stack.push(expr);
84            }
85            ExprKind::Like { expr, pattern: _ } => {
86                self.expression_stack.push(expr);
87            }
88            ExprKind::Set(elems) => {
89                self.expression_stack.extend(elems.as_ref());
90            }
91            ExprKind::Record(map) => {
92                self.expression_stack.extend(map.values());
93            }
94            ExprKind::Is { expr, .. } => {
95                self.expression_stack.push(expr);
96            }
97        }
98        Some(next_expr)
99    }
100}
101
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&Expr::val(true)]));
    }

    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&Expr::slot(SlotId::principal())]));
    }

    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    #[test]
    fn and() {
        // Using `1 && false` because `true && false` would be simplified to
        // `false` by `Expr::and`.
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn or() {
        // Using `1 || false` because `true || false` would be simplified to
        // `true` by `Expr::or`.
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn set_duplicates() {
        // Structurally equal elements are not deduplicated: each occurrence is
        // yielded separately, and the root is yielded first.
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0], &e);
        assert_eq!(v.iter().filter(|&&x| *x == Expr::val(true)).count(), 2);
    }

    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ])
        .unwrap();
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn is() {
        let e = Expr::is_entity_type(Expr::val(1), "T".parse().unwrap());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(1)])
        );
    }

    #[test]
    fn duplicates() {
        // Each of the three identical branches is yielded, after the root.
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert_eq!(v[0], &e);
        assert_eq!(v.iter().filter(|&&x| *x == Expr::val(true)).count(), 3);
    }

    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        assert_eq!(set.len(), 5);
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }

    #[test]
    fn exhausted_iterator_stays_exhausted() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        let mut it = e.subexpressions();
        assert_eq!(it.by_ref().count(), 2);
        assert!(it.next().is_none());
        assert!(it.next().is_none());
    }
}