bool_logic/transforms/
flatten_nested_list.rs1use crate::ast::{All, Any, Expr};
2use crate::visit_mut::VisitMut;
3use crate::visit_mut::walk_mut_expr_list;
4
/// A `VisitMut` pass that splices directly nested lists of the same kind into
/// their parent.
///
/// For example, `any(any(a, b), c)` becomes `any(a, b, c)`, and likewise for `all`.
/// Lists of a different kind are left in place (`any(all(a, b), c)` is unchanged),
/// although their own operands are still visited and flattened.
#[derive(Debug, Clone, Copy, Default)]
pub struct FlattenNestedList;
6
7impl<T> VisitMut<T> for FlattenNestedList {
8 fn visit_mut_any(&mut self, Any(list): &mut Any<T>) {
9 walk_mut_expr_list(self, list);
10
11 if list.iter().any(Expr::is_any) {
12 let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
13 for item in list.drain(..) {
14 if let Expr::Any(Any(any)) = item {
15 ans.extend(any);
16 } else {
17 ans.push(item);
18 }
19 }
20 *list = ans;
21 }
22 }
23
24 fn visit_mut_all(&mut self, All(list): &mut All<T>) {
25 walk_mut_expr_list(self, list);
26
27 if list.iter().any(Expr::is_all) {
28 let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
29 for item in list.drain(..) {
30 if let Expr::All(All(all)) = item {
31 ans.extend(all);
32 } else {
33 ans.push(item);
34 }
35 }
36 *list = ans;
37 }
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, const_, expr, not, var};

    /// Runs `FlattenNestedList` over `input` and asserts that the result renders
    /// (via `to_string`) exactly like `expected`.
    ///
    /// `#[track_caller]` makes a failure report the calling test's line.
    #[track_caller]
    fn assert_flattens(mut input: Expr<u32>, expected: Expr<u32>) {
        FlattenNestedList.visit_mut_expr(&mut input);
        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn flatten_nested_any() {
        assert_flattens(
            expr(any((any((var(1), var(2))), any((var(3), var(4)))))),
            expr(any((var(1), var(2), var(3), var(4)))),
        );
    }

    #[test]
    fn flatten_nested_all() {
        assert_flattens(
            expr(all((all((var(1), var(2))), all((var(3), var(4)))))),
            expr(all((var(1), var(2), var(3), var(4)))),
        );
    }

    #[test]
    fn flatten_mixed_nested() {
        assert_flattens(
            expr(any((any((var(1), var(2))), var(3), any((var(4), var(5)))))),
            expr(any((var(1), var(2), var(3), var(4), var(5)))),
        );
    }

    #[test]
    fn flatten_deep_nested() {
        assert_flattens(
            expr(any((
                any((any((var(1), var(2))), var(3))),
                any((var(4), var(5))),
            ))),
            expr(any((var(1), var(2), var(3), var(4), var(5)))),
        );
    }

    #[test]
    fn no_flatten_mixed_types() {
        assert_flattens(
            expr(any((all((var(1), var(2))), var(3)))),
            expr(any((all((var(1), var(2))), var(3)))),
        );
    }

    #[test]
    fn flatten_single_nested() {
        assert_flattens(
            expr(any(any((var(1), var(2))))),
            expr(any((var(1), var(2)))),
        );
    }

    #[test]
    fn flatten_complex_expression() {
        assert_flattens(
            expr(all((
                all((var(1), not(var(2)))),
                all((const_(true), var(3))),
                var(4),
            ))),
            expr(all((var(1), not(var(2)), const_(true), var(3), var(4)))),
        );
    }

    #[test]
    fn flatten_inside_other_kind() {
        // Same-kind lists must be flattened even beneath a node of the other kind,
        // i.e. the pass recurses through lists it does not itself splice.
        assert_flattens(
            expr(all((
                all((var(1), any((any((var(2), var(3))), var(4))))),
                var(5),
            ))),
            expr(all((var(1), any((var(2), var(3), var(4))), var(5)))),
        );
    }

    #[test]
    fn flatten_is_idempotent() {
        let mut x: Expr<u32> = expr(any((
            any((var(1), all((all((var(2), var(3))), var(4))))),
            var(5),
        )));

        FlattenNestedList.visit_mut_expr(&mut x);
        let once = x.to_string();
        FlattenNestedList.visit_mut_expr(&mut x);

        assert_eq!(x.to_string(), once);
    }
}