bool_logic/transforms/
eval_const.rs

1use std::ops::Not as _;
2
3use crate::ast::{All, Any, Expr, Not};
4use crate::utils::remove_if;
5use crate::visit_mut::VisitMut;
6use crate::visit_mut::walk_mut_expr;
7
8pub struct EvalConst;
9
10impl EvalConst {
11    fn eval_any<T>(any: &mut Vec<Expr<T>>) -> Option<bool> {
12        remove_if(any, Expr::is_const_false);
13
14        if any.is_empty() {
15            return Some(false);
16        }
17
18        if any.iter().any(Expr::is_const_true) {
19            return Some(true);
20        }
21
22        None
23    }
24
25    fn eval_all<T>(all: &mut Vec<Expr<T>>) -> Option<bool> {
26        remove_if(all, Expr::is_const_true);
27
28        if all.is_empty() {
29            return Some(true);
30        }
31
32        if all.iter().any(Expr::is_const_false) {
33            return Some(false);
34        }
35
36        None
37    }
38
39    fn eval_not<T>(not: &Expr<T>) -> Option<bool> {
40        if let Expr::Const(val) = not {
41            return Some(val.not());
42        }
43        None
44    }
45}
46
47impl<T> VisitMut<T> for EvalConst {
48    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
49        walk_mut_expr(self, expr);
50
51        match expr {
52            Expr::Any(Any(any)) => {
53                if let Some(val) = Self::eval_any(any) {
54                    *expr = Expr::Const(val);
55                }
56            }
57            Expr::All(All(all)) => {
58                if let Some(val) = Self::eval_all(all) {
59                    *expr = Expr::Const(val);
60                }
61            }
62            Expr::Not(Not(not)) => {
63                if let Some(val) = Self::eval_not(not) {
64                    *expr = Expr::Const(val);
65                }
66            }
67            _ => {}
68        }
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, const_, expr, not, var};

    /// Runs `EvalConst` over `input` and asserts that the result renders identically
    /// to `expected`. Expressions are compared through their `Display` output.
    ///
    /// `#[track_caller]` makes a failing assertion report the calling test's line.
    #[track_caller]
    fn assert_evaluates_to(input: Expr<u32>, expected: Expr<u32>) {
        let mut actual = input;
        EvalConst.visit_mut_expr(&mut actual);
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn eval_any_with_false() {
        assert_evaluates_to(expr(any((const_(false), var(1)))), expr(any((var(1),))));
    }

    #[test]
    fn eval_any_with_true() {
        assert_evaluates_to(expr(any((var(1), const_(true), var(2)))), expr(const_(true)));
    }

    #[test]
    fn eval_any_all_false() {
        assert_evaluates_to(expr(any((const_(false), const_(false)))), expr(const_(false)));
    }

    #[test]
    fn eval_any_with_false_and_true() {
        // Dropping the `false` must not hide the absorbing `true`.
        assert_evaluates_to(expr(any((const_(false), const_(true)))), expr(const_(true)));
    }

    #[test]
    fn eval_all_with_true() {
        assert_evaluates_to(expr(all((const_(true), var(1)))), expr(all((var(1),))));
    }

    #[test]
    fn eval_all_with_false() {
        assert_evaluates_to(expr(all((var(1), const_(false), var(2)))), expr(const_(false)));
    }

    #[test]
    fn eval_all_all_true() {
        assert_evaluates_to(expr(all((const_(true), const_(true)))), expr(const_(true)));
    }

    #[test]
    fn eval_all_with_true_and_false() {
        // Dropping the `true` must not hide the absorbing `false`.
        assert_evaluates_to(expr(all((const_(true), const_(false)))), expr(const_(false)));
    }

    #[test]
    fn eval_not_true() {
        assert_evaluates_to(expr(not(const_(true))), expr(const_(false)));
    }

    #[test]
    fn eval_not_false() {
        assert_evaluates_to(expr(not(const_(false))), expr(const_(true)));
    }

    #[test]
    fn eval_double_not() {
        // The inner `Not` folds to `true` first, then the outer one folds to `false`.
        assert_evaluates_to(expr(not(not(const_(false)))), expr(const_(false)));
    }

    #[test]
    fn no_eval_not_variable() {
        assert_evaluates_to(expr(not(var(1))), expr(not(var(1))));
    }

    #[test]
    fn eval_nested_expressions() {
        assert_evaluates_to(
            expr(any((all((const_(true), var(1))), const_(false)))),
            expr(any((all((var(1),)),))),
        );
    }

    #[test]
    fn eval_nested_any_true_inside_all() {
        // The nested `Any` becomes `true`, which is then dropped from the enclosing `All`.
        assert_evaluates_to(
            expr(all((any((var(1), const_(true))), var(2)))),
            expr(all((var(2),))),
        );
    }

    #[test]
    fn eval_complex_nested() {
        assert_evaluates_to(expr(not(all((const_(false), var(1))))), expr(const_(true)));
    }

    #[test]
    fn no_change_for_variables_only() {
        assert_evaluates_to(expr(any((var(1), var(2)))), expr(any((var(1), var(2)))));
    }
}