cedar_policy_core/ast/
expr_iterator.rs

/*
 * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17use super::{Expr, ExprKind};
18
19/// This structure implements the iterator used to traverse subexpressions of an
20/// expression.
21#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23    /// The stack of expressions that need to be visited. To get the next
24    /// expression, the iterator will pop from the stack. If the stack is empty,
25    /// then the iterator is finished. Otherwise, any subexpressions of that
26    /// expression are then pushed onto the stack, and the popped expression is
27    /// returned.
28    expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32    /// Construct an expr iterator
33    pub fn new(expr: &'a Expr<T>) -> Self {
34        Self {
35            expression_stack: vec![expr],
36        }
37    }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41    type Item = &'a Expr<T>;
42
43    fn next(&mut self) -> Option<Self::Item> {
44        let next_expr = self.expression_stack.pop()?;
45        match next_expr.expr_kind() {
46            ExprKind::Lit(_) => (),
47            ExprKind::Unknown { .. } => (),
48            ExprKind::Slot(_) => (),
49            ExprKind::Var(_) => (),
50            ExprKind::If {
51                test_expr,
52                then_expr,
53                else_expr,
54            } => {
55                self.expression_stack.push(test_expr);
56                self.expression_stack.push(then_expr);
57                self.expression_stack.push(else_expr);
58            }
59            ExprKind::And { left, right } => {
60                self.expression_stack.push(left);
61                self.expression_stack.push(right);
62            }
63            ExprKind::Or { left, right } => {
64                self.expression_stack.push(left);
65                self.expression_stack.push(right);
66            }
67            ExprKind::UnaryApp { arg, .. } => {
68                self.expression_stack.push(arg);
69            }
70            ExprKind::BinaryApp { arg1, arg2, .. } => {
71                self.expression_stack.push(arg1);
72                self.expression_stack.push(arg2);
73            }
74            ExprKind::MulByConst { arg, .. } => {
75                self.expression_stack.push(arg);
76            }
77            ExprKind::ExtensionFunctionApp { args, .. } => {
78                for arg in args.as_ref() {
79                    self.expression_stack.push(arg);
80                }
81            }
82            ExprKind::GetAttr { expr, attr: _ } => {
83                self.expression_stack.push(expr);
84            }
85            ExprKind::HasAttr { expr, attr: _ } => {
86                self.expression_stack.push(expr);
87            }
88            ExprKind::Like { expr, pattern: _ } => {
89                self.expression_stack.push(expr);
90            }
91            ExprKind::Set(elems) => {
92                for expr in elems.as_ref() {
93                    self.expression_stack.push(expr);
94                }
95            }
96            ExprKind::Record { pairs } => {
97                for (_, val_expr) in pairs.as_ref() {
98                    self.expression_stack.push(val_expr);
99                }
100            }
101        }
102        Some(next_expr)
103    }
104}
105
#[cfg(test)]
mod test {
    //! Tests for subexpression iteration via `Expr::subexpressions`.
    //!
    //! Most tests collect into a `HashSet`, which checks *which* expressions
    //! are visited but not how often. The `*duplicates` tests collect into a
    //! `Vec` to check that repeated subexpressions are yielded once per
    //! occurrence.
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&Expr::val(true)]));
    }

    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&Expr::slot(SlotId::principal())]));
    }

    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    #[test]
    fn and() {
        // Using `1 && false` because `true && false` would be simplified to
        // `false` by `Expr::and`.
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn or() {
        // Using `1 || false` because `true || false` would be simplified to
        // `true` by `Expr::or`.
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        // The set itself plus one entry per (equal) element.
        assert_eq!(v.len(), 3);
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        // The `if` node plus one entry per (equal) branch.
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }
}