bool_logic/transforms/simplify_by_short_circuit.rs

1use std::ops::Not as _;
2
3use crate::ast::{All, Any, Expr, Not, Var};
4use crate::visit_mut::VisitMut;
5
6fn find_vars<T: Eq + Clone>(list: &mut [Expr<T>], marker: bool) -> Vec<Var<T>> {
7    let mut ans: Vec<Var<T>> = Vec::new();
8    for x in list {
9        if let Expr::Var(var) = x {
10            if ans.contains(var) {
11                *x = Expr::Const(marker);
12            } else {
13                ans.push(var.clone());
14            }
15        }
16    }
17    ans
18}
19
20fn replace_vars<T: Eq>(x: &mut Expr<T>, vars: &[Var<T>], marker: bool) {
21    match x {
22        Expr::Any(Any(any)) => any.iter_mut().for_each(|x| replace_vars(x, vars, marker)),
23        Expr::All(All(all)) => all.iter_mut().for_each(|x| replace_vars(x, vars, marker)),
24        Expr::Not(Not(not)) => replace_vars(not, vars, marker),
25        Expr::Var(var) => {
26            if vars.contains(var) {
27                *x = Expr::Const(marker);
28            }
29        }
30        Expr::Const(_) => {}
31    }
32}
33
/// Rewrite pass that simplifies `Any`/`All` nodes by short-circuit reasoning.
///
/// When a variable appears directly as an operand of an `Any`, its other
/// occurrences among that node's non-variable operands are replaced with
/// `false`. For an `All`, they are replaced with `true`. Repeated direct
/// operands are collapsed the same way.
///
/// This is a stateless unit struct, so it is cheap to copy and default-construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimplifyByShortCircuit;
35
36impl<T: Eq + Clone> VisitMut<T> for SimplifyByShortCircuit {
37    fn visit_mut_any(&mut self, Any(any): &mut Any<T>) {
38        let marker = false;
39        let vars = find_vars(any, marker);
40        for x in any.iter_mut().filter(|x| x.is_var().not()) {
41            replace_vars(x, &vars, marker);
42        }
43    }
44
45    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
46        let marker = true;
47        let vars = find_vars(all, marker);
48        for x in all.iter_mut().filter(|x| x.is_var().not()) {
49            replace_vars(x, &vars, marker);
50        }
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{all, any, const_, expr, var};

    #[test]
    fn test_simplify_any_with_duplicates() {
        let mut expr_val = expr(any((
            var(1),
            var(2),
            var(1), // duplicate
            all((var(3), var(1))),
        )));

        SimplifyByShortCircuit.visit_mut_expr(&mut expr_val);

        let any_val = expr_val.as_any().unwrap();
        assert_eq!(any_val.0.len(), 4);
        // First occurrences of direct variables are kept.
        assert!(any_val.0[0].is_var());
        assert!(any_val.0[1].is_var());
        // A repeated direct variable is redundant in a disjunction.
        assert_eq!(any_val.0[2], expr(const_(false)));
        // Inside the nested All, var(1) may be assumed false; var(3) is untouched.
        let all_val = any_val.0[3].as_all().unwrap();
        assert!(all_val.0[0].is_var());
        assert_eq!(all_val.0[1], expr(const_(false)));
    }

    #[test]
    fn test_simplify_all_with_duplicates() {
        let mut expr_val = expr(all((
            var(1),
            var(2),
            var(1), // duplicate
            any((var(3), var(1))),
        )));

        SimplifyByShortCircuit.visit_mut_expr(&mut expr_val);

        let all_val = expr_val.as_all().unwrap();
        assert_eq!(all_val.0.len(), 4);
        // First occurrences of direct variables are kept.
        assert!(all_val.0[0].is_var());
        assert!(all_val.0[1].is_var());
        // A repeated direct variable is redundant in a conjunction.
        assert_eq!(all_val.0[2], expr(const_(true)));
        // Inside the nested Any, var(1) may be assumed true; var(3) is untouched.
        let any_val = all_val.0[3].as_any().unwrap();
        assert!(any_val.0[0].is_var());
        assert_eq!(any_val.0[1], expr(const_(true)));
    }

    #[test]
    fn test_nested_occurrence_without_direct_duplicate() {
        let mut expr_val = expr(all((var(1), any((var(2), var(1))))));

        SimplifyByShortCircuit.visit_mut_expr(&mut expr_val);

        let all_val = expr_val.as_all().unwrap();
        assert!(all_val.0[0].is_var());
        let any_val = all_val.0[1].as_any().unwrap();
        assert!(any_val.0[0].is_var());
        assert_eq!(any_val.0[1], expr(const_(true)));
    }

    #[test]
    fn test_nested_only_variables_untouched() {
        // var(2)/var(3) never appear as direct operands of the outer Any,
        // so nothing may be substituted for them.
        let mut expr_val = expr(any((var(1), all((var(2), var(3))))));
        let original = expr_val.clone();

        SimplifyByShortCircuit.visit_mut_expr(&mut expr_val);

        assert_eq!(expr_val, original);
    }

    #[test]
    fn test_no_duplicates() {
        let mut expr_val = expr(any((var(1), var(2), var(3))));
        let original = expr_val.clone();

        SimplifyByShortCircuit.visit_mut_expr(&mut expr_val);

        // Structural equality is stricter than comparing Display output.
        assert_eq!(expr_val, original);
    }
}