// bool_logic/transforms/flatten_nested_list.rs

1use crate::ast::{All, Any, Expr};
2use crate::visit_mut::VisitMut;
3use crate::visit_mut::walk_mut_expr_list;
4
/// Transform pass that removes redundant nesting of same-kind boolean lists
/// (`any` directly inside `any`, `all` directly inside `all`).
///
/// The work is done by its [`VisitMut`] implementation; run it with
/// `FlattenNestedList.visit_mut_expr(&mut expr)`.
#[derive(Debug, Default, Clone, Copy)]
pub struct FlattenNestedList;
6
7impl<T> VisitMut<T> for FlattenNestedList {
8    fn visit_mut_any(&mut self, Any(list): &mut Any<T>) {
9        walk_mut_expr_list(self, list);
10
11        if list.iter().any(Expr::is_any) {
12            let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
13            for item in list.drain(..) {
14                if let Expr::Any(Any(any)) = item {
15                    ans.extend(any);
16                } else {
17                    ans.push(item);
18                }
19            }
20            *list = ans;
21        }
22    }
23
24    fn visit_mut_all(&mut self, All(list): &mut All<T>) {
25        walk_mut_expr_list(self, list);
26
27        if list.iter().any(Expr::is_all) {
28            let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
29            for item in list.drain(..) {
30                if let Expr::All(All(all)) = item {
31                    ans.extend(all);
32                } else {
33                    ans.push(item);
34                }
35            }
36            *list = ans;
37        }
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, const_, expr, not, var};

    /// Runs `FlattenNestedList` over `input` and asserts that the result renders
    /// (via `Display`) identically to `expected`.
    fn assert_flattens_to(mut input: Expr<u32>, expected: Expr<u32>) {
        FlattenNestedList.visit_mut_expr(&mut input);
        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn flatten_nested_any() {
        assert_flattens_to(
            expr(any((any((var(1), var(2))), any((var(3), var(4)))))),
            expr(any((var(1), var(2), var(3), var(4)))),
        );
    }

    #[test]
    fn flatten_nested_all() {
        assert_flattens_to(
            expr(all((all((var(1), var(2))), all((var(3), var(4)))))),
            expr(all((var(1), var(2), var(3), var(4)))),
        );
    }

    #[test]
    fn flatten_mixed_nested() {
        assert_flattens_to(
            expr(any((any((var(1), var(2))), var(3), any((var(4), var(5)))))),
            expr(any((var(1), var(2), var(3), var(4), var(5)))),
        );
    }

    #[test]
    fn flatten_deep_nested() {
        assert_flattens_to(
            expr(any((
                any((any((var(1), var(2))), var(3))),
                any((var(4), var(5))),
            ))),
            expr(any((var(1), var(2), var(3), var(4), var(5)))),
        );
    }

    #[test]
    fn no_flatten_mixed_types() {
        assert_flattens_to(
            expr(any((all((var(1), var(2))), var(3)))),
            expr(any((all((var(1), var(2))), var(3)))),
        );
    }

    #[test]
    fn flatten_single_nested() {
        assert_flattens_to(
            expr(any(any((var(1), var(2))))),
            expr(any((var(1), var(2)))),
        );
    }

    #[test]
    fn flatten_complex_expression() {
        assert_flattens_to(
            expr(all((
                all((var(1), not(var(2)))),
                all((const_(true), var(3))),
                var(4),
            ))),
            expr(all((var(1), not(var(2)), const_(true), var(3), var(4)))),
        );
    }

    // The outer `all` must not absorb the `any`, but the `any` nested beneath it must
    // still be flattened.
    #[test]
    fn flatten_any_nested_under_all() {
        assert_flattens_to(
            expr(all((any((any((var(1), var(2))), var(3))), var(4)))),
            expr(all((any((var(1), var(2), var(3))), var(4)))),
        );
    }

    // Mirror of the case above: same-kind `all` nesting beneath an outer `any`.
    #[test]
    fn flatten_all_nested_under_any() {
        assert_flattens_to(
            expr(any((all((all((var(1), var(2))), var(3))), var(4)))),
            expr(any((all((var(1), var(2), var(3))), var(4)))),
        );
    }
}