// bool_logic/transforms/simplify_nested_list.rs

use crate::ast::{All, Any, Expr};
use crate::visit_mut::VisitMut;
use crate::visit_mut::walk_mut_expr_list;

/// Returns `true` when at least one element of `lhs` also occurs in `rhs`.
///
/// Only `Eq` is required, so every candidate from `lhs` is checked with a
/// linear scan of `rhs`. The cost is `O(lhs.len() * rhs.len())` comparisons,
/// and the scan stops at the first shared element. Either slice being empty
/// yields `false`.
fn contains_cross_same<T: Eq>(lhs: &[T], rhs: &[T]) -> bool {
    for candidate in lhs {
        if rhs.contains(candidate) {
            return true;
        }
    }
    false
}

/// Boolean-expression pass that removes nested `all`/`any` lists absorbed by
/// a sibling operand (absorption law). See the `VisitMut` impl for details.
#[derive(Debug, Default, Clone, Copy)]
pub struct SimplifyNestedList;

11impl<T> VisitMut<T> for SimplifyNestedList
12where
13    T: Eq,
14{
15    /// `any(x0, all(x0, x1), x2) => any(x0, x2)`
16    fn visit_mut_any(&mut self, Any(any): &mut Any<T>) {
17        walk_mut_expr_list(self, any);
18
19        let mut i = 0;
20        while i < any.len() {
21            if let Expr::All(All(all)) = &any[i] {
22                if contains_cross_same(all, any) {
23                    any.remove(i);
24                    continue;
25                }
26            }
27
28            i += 1;
29        }
30    }
31
32    /// `all(x0, any(x0, x1), x2) => all(x0, x2)`
33    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
34        walk_mut_expr_list(self, all);
35
36        let mut i = 0;
37        while i < all.len() {
38            if let Expr::Any(Any(any)) = &all[i] {
39                if contains_cross_same(any, all) {
40                    all.remove(i);
41                    continue;
42                }
43            }
44
45            i += 1;
46        }
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, expr, not, var};

    /// Runs the `SimplifyNestedList` pass over `input`. Asserts that the
    /// result renders (via `to_string()`) identically to `expected`.
    #[track_caller]
    fn assert_simplifies_to(mut input: Expr<u32>, expected: Expr<u32>) {
        SimplifyNestedList.visit_mut_expr(&mut input);
        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn simplify_any_with_nested_all_containing_same_element() {
        // any(x1, all(x1, x2), x3) => any(x1, x3)
        assert_simplifies_to(
            expr(any((var(1), all((var(1), var(2))), var(3)))),
            expr(any((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_all_with_nested_any_containing_same_element() {
        // all(x1, any(x1, x2), x3) => all(x1, x3)
        assert_simplifies_to(
            expr(all((var(1), any((var(1), var(2))), var(3)))),
            expr(all((var(1), var(3)))),
        );
    }

    #[test]
    fn no_simplify_any_with_nested_all_no_common_elements() {
        // any(x1, all(x2, x4), x3): the nested list shares no operand, so it stays.
        assert_simplifies_to(
            expr(any((var(1), all((var(2), var(4))), var(3)))),
            expr(any((var(1), all((var(2), var(4))), var(3)))),
        );
    }

    #[test]
    fn no_simplify_all_with_nested_any_no_common_elements() {
        // all(x1, any(x2, x4), x3): the nested list shares no operand, so it stays.
        assert_simplifies_to(
            expr(all((var(1), any((var(2), var(4))), var(3)))),
            expr(all((var(1), any((var(2), var(4))), var(3)))),
        );
    }

    #[test]
    fn simplify_multiple_nested_expressions_in_any() {
        // any(x1, all(x1, x2), all(x1, x4), x3) => any(x1, x3)
        assert_simplifies_to(
            expr(any((
                var(1),
                all((var(1), var(2))),
                all((var(1), var(4))),
                var(3),
            ))),
            expr(any((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_multiple_nested_expressions_in_all() {
        // all(x1, any(x1, x2), any(x1, x4), x3) => all(x1, x3)
        assert_simplifies_to(
            expr(all((
                var(1),
                any((var(1), var(2))),
                any((var(1), var(4))),
                var(3),
            ))),
            expr(all((var(1), var(3)))),
        );
    }

    #[test]
    fn simplify_nested_with_mixed_content() {
        // any(x1, all(x1, x2), not(x3)) => any(x1, not(x3))
        assert_simplifies_to(
            expr(any((var(1), all((var(1), var(2))), not(var(3))))),
            expr(any((var(1), not(var(3))))),
        );
    }

    #[test]
    fn simplify_deeply_nested_expressions() {
        // any(x1, all(x1, any(x2, x3))) => any(x1)
        assert_simplifies_to(
            expr(any((var(1), all((var(1), any((var(2), var(3)))))))),
            expr(any((var(1),))),
        );
    }

    #[test]
    fn no_change_when_no_simplification_possible() {
        // any(all(x1, x2), all(x3, x4)): no operand is shared, so nothing changes.
        assert_simplifies_to(
            expr(any((all((var(1), var(2))), all((var(3), var(4)))))),
            expr(any((all((var(1), var(2))), all((var(3), var(4)))))),
        );
    }

    #[test]
    fn contains_cross_same_function_test() {
        // Overlapping slices.
        assert!(contains_cross_same(&[1, 2, 3], &[3, 4, 5]));
        assert!(contains_cross_same(&[1, 2], &[2, 3]));
        // Disjoint slices, including empty ones.
        assert!(!contains_cross_same(&[1, 2], &[3, 4]));
        assert!(!contains_cross_same(&[], &[1, 2]));
        assert!(!contains_cross_same(&[1, 2], &[]));
    }
}
170        assert!(!contains_cross_same(&[1, 2], &[]));
171    }
172}