1use super::{Expr, ExprKind};
18
19#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23 expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32 pub fn new(expr: &'a Expr<T>) -> Self {
34 Self {
35 expression_stack: vec![expr],
36 }
37 }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41 type Item = &'a Expr<T>;
42
43 fn next(&mut self) -> Option<Self::Item> {
44 let next_expr = self.expression_stack.pop()?;
45 match next_expr.expr_kind() {
46 ExprKind::Lit(_) => (),
47 ExprKind::Unknown(_) => (),
48 ExprKind::Slot(_) => (),
49 ExprKind::Var(_) => (),
50 ExprKind::If {
51 test_expr,
52 then_expr,
53 else_expr,
54 } => {
55 self.expression_stack.push(test_expr);
56 self.expression_stack.push(then_expr);
57 self.expression_stack.push(else_expr);
58 }
59 ExprKind::And { left, right } => {
60 self.expression_stack.push(left);
61 self.expression_stack.push(right);
62 }
63 ExprKind::Or { left, right } => {
64 self.expression_stack.push(left);
65 self.expression_stack.push(right);
66 }
67 ExprKind::UnaryApp { arg, .. } => {
68 self.expression_stack.push(arg);
69 }
70 ExprKind::BinaryApp { arg1, arg2, .. } => {
71 self.expression_stack.push(arg1);
72 self.expression_stack.push(arg2);
73 }
74 ExprKind::ExtensionFunctionApp { args, .. } => {
75 for arg in args.as_ref() {
76 self.expression_stack.push(arg);
77 }
78 }
79 ExprKind::GetAttr { expr, attr: _ } => {
80 self.expression_stack.push(expr);
81 }
82 ExprKind::HasAttr { expr, attr: _ } => {
83 self.expression_stack.push(expr);
84 }
85 ExprKind::Like { expr, pattern: _ } => {
86 self.expression_stack.push(expr);
87 }
88 ExprKind::Set(elems) => {
89 self.expression_stack.extend(elems.as_ref());
90 }
91 ExprKind::Record(map) => {
92 self.expression_stack.extend(map.values());
93 }
94 ExprKind::Is { expr, .. } => {
95 self.expression_stack.push(expr);
96 }
97 }
98 Some(next_expr)
99 }
100}
101
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    /// A literal yields only itself.
    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();

        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::val(true)));
    }

    /// A slot yields only itself.
    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::slot(SlotId::principal())));
    }

    /// A variable yields only itself.
    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    /// An if-then-else yields itself plus its test, then, and else branches.
    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    /// An `and` yields itself plus both operands.
    #[test]
    fn and() {
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    /// An `or` yields itself plus both operands.
    #[test]
    fn or() {
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    /// A unary application yields itself plus its argument.
    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// A binary application yields itself plus both arguments.
    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// An extension function call yields itself plus every argument.
    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// A `has` check yields itself plus the expression being inspected.
    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// An attribute access yields itself plus the expression being accessed.
    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// A set yields itself plus each element.
    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// Equal set elements are yielded once per occurrence, not deduplicated.
    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        assert!(v.contains(&&Expr::val(true)));
    }

    /// A record yields itself plus each attribute value.
    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ])
        .unwrap();
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// An `is` test yields itself plus the expression being tested.
    #[test]
    fn is() {
        let e = Expr::is_entity_type(Expr::val(1), "T".parse().unwrap());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(1)])
        );
    }

    /// Structurally equal children are each yielded: the root plus three
    /// copies of `true` gives four items.
    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        assert!(v.contains(&&Expr::val(true)));
    }

    /// Every level of a nested expression is reached, down to the leaves.
    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }
}