bool_logic/transforms/
flatten_single.rs

1#![allow(clippy::single_match)]
2
3use crate::ast::{All, Any, Expr, Not};
4use crate::visit_mut::VisitMut;
5use crate::visit_mut::walk_mut_expr;
6
7fn take<T>(expr: &mut Expr<T>) -> Expr<T> {
8    std::mem::replace(expr, Expr::Const(false))
9}
10
/// Mutable visitor that collapses trivial wrappers in an expression tree.
///
/// Through its `VisitMut` implementation it rewrites, bottom-up:
/// an empty `Any` into `Const(false)`, an empty `All` into `Const(true)`,
/// an `Any`/`All` with exactly one operand into that operand, and a double
/// negation into the doubly-negated operand.
///
/// The type is a stateless unit struct, so it derives the common traits to be
/// freely copyable, printable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct FlattenSingle;
12
13impl<T> VisitMut<T> for FlattenSingle {
14    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
15        walk_mut_expr(self, expr);
16
17        match expr {
18            Expr::Any(Any(any)) => match any.as_mut_slice() {
19                [] => *expr = Expr::Const(false),
20                [sub] => *expr = take(sub),
21                _ => {}
22            },
23            Expr::All(All(all)) => match all.as_mut_slice() {
24                [] => *expr = Expr::Const(true),
25                [sub] => *expr = take(sub),
26                _ => {}
27            },
28            Expr::Not(Not(not)) => match &mut **not {
29                Expr::Not(Not(sub)) => *expr = take(sub),
30                _ => {}
31            },
32            _ => {}
33        }
34    }
35}
36
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, const_, expr, not, var};

    /// Runs `FlattenSingle` over `input` and asserts that the result renders
    /// to the same string as `expected`.
    #[track_caller]
    fn assert_flattens_to(mut input: Expr<u32>, expected: Expr<u32>) {
        FlattenSingle.visit_mut_expr(&mut input);

        assert_eq!(input.to_string(), expected.to_string());
    }

    #[test]
    fn flatten_empty_any() {
        assert_flattens_to(expr(any(())), expr(const_(false)));
    }

    #[test]
    fn flatten_empty_all() {
        assert_flattens_to(expr(all(())), expr(const_(true)));
    }

    #[test]
    fn flatten_single_any() {
        assert_flattens_to(expr(any((var(1),))), expr(var(1)));
    }

    #[test]
    fn flatten_single_all() {
        assert_flattens_to(expr(all((var(1),))), expr(var(1)));
    }

    #[test]
    fn flatten_double_negation() {
        assert_flattens_to(expr(not(not(var(1)))), expr(var(1)));
    }

    #[test]
    fn no_flatten_multiple_any() {
        assert_flattens_to(expr(any((var(1), var(2)))), expr(any((var(1), var(2)))));
    }

    #[test]
    fn no_flatten_multiple_all() {
        assert_flattens_to(expr(all((var(1), var(2)))), expr(all((var(1), var(2)))));
    }

    #[test]
    fn no_flatten_single_negation() {
        assert_flattens_to(expr(not(var(1))), expr(not(var(1))));
    }

    #[test]
    fn flatten_nested_expressions() {
        // The inner single-operand `all` collapses first, then the outer `any`.
        assert_flattens_to(expr(any((all((var(1),)),))), expr(var(1)));
    }
}