1use crate::ast::{All, Any, Expr};
2use crate::visit_mut::VisitMut;
3use crate::visit_mut::walk_mut_expr_list;
4
/// Reports whether `lhs` and `rhs` have at least one element in common.
///
/// Elements are compared with `==`. The result is `false` whenever either
/// slice is empty, because there is nothing to match against.
fn contains_cross_same<T: Eq>(lhs: &[T], rhs: &[T]) -> bool {
    for candidate in lhs {
        // Stop at the first element of `lhs` that also occurs in `rhs`.
        if rhs.iter().any(|other| other == candidate) {
            return true;
        }
    }
    false
}
8
9pub struct SimplifyNestedList;
10
11impl<T> VisitMut<T> for SimplifyNestedList
12where
13 T: Eq,
14{
15 fn visit_mut_any(&mut self, Any(any): &mut Any<T>) {
17 walk_mut_expr_list(self, any);
18
19 let mut i = 0;
20 while i < any.len() {
21 if let Expr::All(All(all)) = &any[i] {
22 if contains_cross_same(all, any) {
23 any.remove(i);
24 continue;
25 }
26 }
27
28 i += 1;
29 }
30 }
31
32 fn visit_mut_all(&mut self, All(all): &mut All<T>) {
34 walk_mut_expr_list(self, all);
35
36 let mut i = 0;
37 while i < all.len() {
38 if let Expr::Any(Any(any)) = &all[i] {
39 if contains_cross_same(any, all) {
40 all.remove(i);
41 continue;
42 }
43 }
44
45 i += 1;
46 }
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, expr, not, var};

    /// Runs `SimplifyNestedList` over `input` and checks that its rendered
    /// form equals the rendered form of `expected`.
    #[track_caller]
    fn assert_simplifies_to(mut input: Expr<u32>, expected: Expr<u32>) {
        SimplifyNestedList.visit_mut_expr(&mut input);

        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn simplify_any_with_nested_all_containing_same_element() {
        assert_simplifies_to(
            expr(any((var(1), all((var(1), var(2))), var(3)))),
            expr(any((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_all_with_nested_any_containing_same_element() {
        assert_simplifies_to(
            expr(all((var(1), any((var(1), var(2))), var(3)))),
            expr(all((var(1), var(3)))),
        );
    }

    #[test]
    fn no_simplify_any_with_nested_all_no_common_elements() {
        assert_simplifies_to(
            expr(any((var(1), all((var(2), var(4))), var(3)))),
            expr(any((var(1), all((var(2), var(4))), var(3)))),
        );
    }

    #[test]
    fn no_simplify_all_with_nested_any_no_common_elements() {
        assert_simplifies_to(
            expr(all((var(1), any((var(2), var(4))), var(3)))),
            expr(all((var(1), any((var(2), var(4))), var(3)))),
        );
    }

    #[test]
    fn simplify_multiple_nested_expressions_in_any() {
        assert_simplifies_to(
            expr(any((
                var(1),
                all((var(1), var(2))),
                all((var(1), var(4))),
                var(3),
            ))),
            expr(any((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_multiple_nested_expressions_in_all() {
        assert_simplifies_to(
            expr(all((
                var(1),
                any((var(1), var(2))),
                any((var(1), var(4))),
                var(3),
            ))),
            expr(all((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_nested_with_mixed_content() {
        assert_simplifies_to(
            expr(any((var(1), all((var(1), var(2))), not(var(3))))),
            expr(any((var(1), not(var(3))))),
        );
    }

    #[test]
    fn simplify_deeply_nested_expressions() {
        assert_simplifies_to(
            expr(any((var(1), all((var(1), any((var(2), var(3)))))))),
            expr(any((var(1),))),
        );
    }

    #[test]
    fn no_change_when_no_simplification_possible() {
        assert_simplifies_to(
            expr(any((all((var(1), var(2))), all((var(3), var(4)))))),
            expr(any((all((var(1), var(2))), all((var(3), var(4)))))),
        );
    }

    #[test]
    fn contains_cross_same_function_test() {
        // (lhs, rhs, whether the two slices share an element)
        let cases: [(&[i32], &[i32], bool); 5] = [
            (&[1, 2, 3], &[3, 4, 5], true),
            (&[1, 2], &[2, 3], true),
            (&[1, 2], &[3, 4], false),
            (&[], &[1, 2], false),
            (&[1, 2], &[], false),
        ];

        for &(lhs, rhs, shared) in cases.iter() {
            assert_eq!(contains_cross_same(lhs, rhs), shared);
        }
    }
}