1use super::{Expr, ExprKind};
18
19#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23 expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32 pub fn new(expr: &'a Expr<T>) -> Self {
34 Self {
35 expression_stack: vec![expr],
36 }
37 }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41 type Item = &'a Expr<T>;
42
43 fn next(&mut self) -> Option<Self::Item> {
44 let next_expr = self.expression_stack.pop()?;
45 match next_expr.expr_kind() {
46 ExprKind::Lit(_) => (),
47 ExprKind::Unknown { .. } => (),
48 ExprKind::Slot(_) => (),
49 ExprKind::Var(_) => (),
50 ExprKind::If {
51 test_expr,
52 then_expr,
53 else_expr,
54 } => {
55 self.expression_stack.push(test_expr);
56 self.expression_stack.push(then_expr);
57 self.expression_stack.push(else_expr);
58 }
59 ExprKind::And { left, right } => {
60 self.expression_stack.push(left);
61 self.expression_stack.push(right);
62 }
63 ExprKind::Or { left, right } => {
64 self.expression_stack.push(left);
65 self.expression_stack.push(right);
66 }
67 ExprKind::UnaryApp { arg, .. } => {
68 self.expression_stack.push(arg);
69 }
70 ExprKind::BinaryApp { arg1, arg2, .. } => {
71 self.expression_stack.push(arg1);
72 self.expression_stack.push(arg2);
73 }
74 ExprKind::MulByConst { arg, .. } => {
75 self.expression_stack.push(arg);
76 }
77 ExprKind::ExtensionFunctionApp { args, .. } => {
78 for arg in args.as_ref() {
79 self.expression_stack.push(arg);
80 }
81 }
82 ExprKind::GetAttr { expr, attr: _ } => {
83 self.expression_stack.push(expr);
84 }
85 ExprKind::HasAttr { expr, attr: _ } => {
86 self.expression_stack.push(expr);
87 }
88 ExprKind::Like { expr, pattern: _ } => {
89 self.expression_stack.push(expr);
90 }
91 ExprKind::Set(elems) => {
92 for expr in elems.as_ref() {
93 self.expression_stack.push(expr);
94 }
95 }
96 ExprKind::Record { pairs } => {
97 for (_, val_expr) in pairs.as_ref() {
98 self.expression_stack.push(val_expr);
99 }
100 }
101 }
102 Some(next_expr)
103 }
104}
105
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};
    use super::ExprIterator;

    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();

        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::val(true)));
    }

    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::slot(SlotId::principal())));
    }

    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    #[test]
    fn and() {
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn or() {
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        // Exactly the five distinct nodes below, and nothing else.
        assert_eq!(set.len(), 5);
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }

    #[test]
    fn root_is_yielded_first() {
        let e = Expr::and(Expr::val(true), Expr::val(false));
        let mut it = ExprIterator::new(&e);
        assert_eq!(it.next(), Some(&e));
        // Only the two operands remain after the root.
        assert_eq!(it.count(), 2);
    }
}