bool_logic/transforms/
merge_all_of_not_any.rs

1use std::ops::Not as _;
2
3use crate::ast::{All, Expr, Not, Var};
4use crate::utils::{drain_filter, remove_if};
5use crate::visit_mut::{VisitMut, walk_mut_expr_list};
6
7fn as_mut_not_any<T>(expr: &mut Expr<T>) -> Option<&mut Vec<Expr<T>>> {
8    expr.as_mut_not_any().map(|x| &mut x.0)
9}
10
11fn unwrap_expr_not_var<T>(expr: Expr<T>) -> Var<T> {
12    if let Expr::Not(Not(not)) = expr {
13        if let Expr::Var(var) = *not {
14            return var;
15        }
16    }
17    panic!()
18}
19
20pub struct MergeAllOfNotAny;
21
22impl<T> VisitMut<T> for MergeAllOfNotAny {
23    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
24        walk_mut_expr_list(self, all);
25
26        let mut not_any_list: Vec<_> = all.iter_mut().filter_map(as_mut_not_any).collect();
27
28        if let [first, rest @ ..] = not_any_list.as_mut_slice() {
29            if rest.is_empty().not() {
30                for x in rest {
31                    first.append(x);
32                }
33                remove_if(all, Expr::is_empty_not_any);
34            }
35
36            {
37                let not_var_list: Vec<_> = drain_filter(all, |x| x.is_expr_not_var()).collect();
38                let not_any = all.iter_mut().find_map(as_mut_not_any).unwrap();
39
40                for not_var in not_var_list {
41                    let var = unwrap_expr_not_var(not_var);
42                    not_any.push(Expr::Var(var));
43                }
44            }
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, expr, not, var};

    /// Runs the transform on `input` and compares the rendered result with `expected`.
    fn assert_rewrites_to(mut input: Expr<u32>, expected: Expr<u32>) {
        MergeAllOfNotAny.visit_mut_expr(&mut input);

        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn merge_multiple_not_any() {
        let input: Expr<u32> = expr(all((
            not(any((var(1), var(2)))),
            not(any((var(3), var(4)))),
            var(5),
        )));
        let expected: Expr<u32> = expr(all((not(any((var(1), var(2), var(3), var(4)))), var(5))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn merge_not_var_into_not_any() {
        let input: Expr<u32> = expr(all((not(any((var(1), var(2)))), not(var(3)), var(4))));
        let expected: Expr<u32> = expr(all((not(any((var(1), var(2), var(3)))), var(4))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn merge_multiple_not_any_and_not_var() {
        let input: Expr<u32> = expr(all((
            not(any((var(1),))),
            not(any((var(2),))),
            not(var(3)),
            not(var(4)),
            var(5),
        )));
        let expected: Expr<u32> = expr(all((not(any((var(1), var(2), var(3), var(4)))), var(5))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn single_not_any_with_not_var() {
        let input: Expr<u32> = expr(all((not(any((var(1),))), not(var(2)), var(3))));
        let expected: Expr<u32> = expr(all((not(any((var(1), var(2)))), var(3))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn no_merge_when_no_not_any() {
        let input: Expr<u32> = expr(all((var(1), var(2), not(var(3)))));
        let expected: Expr<u32> = expr(all((var(1), var(2), not(var(3)))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn no_merge_when_single_not_any_no_not_var() {
        let input: Expr<u32> = expr(all((not(any((var(1), var(2)))), var(3))));
        let expected: Expr<u32> = expr(all((not(any((var(1), var(2)))), var(3))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn merge_empty_not_any() {
        let input: Expr<u32> = expr(all((not(any((var(1),))), not(any(())), var(2))));
        let expected: Expr<u32> = expr(all((not(any((var(1),))), var(2))));

        assert_rewrites_to(input, expected);
    }

    #[test]
    fn complex_merge_scenario() {
        let input: Expr<u32> = expr(all((
            not(any((var(1), var(2)))),
            var(3),
            not(any((var(4),))),
            not(var(5)),
            not(any((var(6), var(7)))),
            not(var(8)),
        )));
        // Merged `not(any(..))` operands come first, followed by the former `not(var)` ones.
        let expected: Expr<u32> = expr(all((
            not(any((
                var(1),
                var(2),
                var(4),
                var(6),
                var(7),
                var(5),
                var(8),
            ))),
            var(3),
        )));

        assert_rewrites_to(input, expected);
    }
}