bool_logic/
transform.rs

1use crate::ast::{self, All, Any, Expr, Not, Var};
2use crate::utils::*;
3use crate::visit_mut::*;
4
5use std::ops::Not as _;
6use std::slice;
7
8use replace_with::replace_with_or_abort as replace_with;
9use rust_utils::default::default;
10use rust_utils::iter::{filter_map_collect, map_collect_vec};
11use rust_utils::slice::SliceExt;
12use rust_utils::vec::VecExt;
13
14fn unwrap_not<T>(expr: Expr<T>) -> Expr<T> {
15    if let Expr::Not(Not(not)) = expr {
16        *not
17    } else {
18        panic!()
19    }
20}
21
22pub struct FlattenSingle;
23
24impl<T> VisitMut<T> for FlattenSingle {
25    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
26        walk_mut_expr(self, expr);
27
28        match expr {
29            Expr::Any(Any(any)) => {
30                if any.is_empty() {
31                    *expr = Expr::Const(false);
32                } else if any.len() == 1 {
33                    *expr = any.pop().unwrap();
34                }
35            }
36            Expr::All(All(all)) => {
37                if all.is_empty() {
38                    *expr = Expr::Const(true);
39                } else if all.len() == 1 {
40                    *expr = all.pop().unwrap();
41                }
42            }
43            Expr::Not(Not(not_expr)) => {
44                if not_expr.is_not() {
45                    replace_with(expr, |expr| unwrap_not(unwrap_not(expr)))
46                }
47            }
48            _ => {}
49        };
50
51        walk_mut_expr(self, expr)
52    }
53}
54
55pub struct FlattenNestedList;
56
57impl FlattenNestedList {
58    fn flatten_any<T>(list: &mut Vec<Expr<T>>) {
59        if list.iter().all(|x| x.is_any().not()) {
60            return;
61        }
62
63        let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
64        for expr in list.drain(..) {
65            if let Expr::Any(Any(any)) = expr {
66                ans.extend(any);
67            } else {
68                ans.push(expr);
69            }
70        }
71        *list = ans;
72    }
73
74    fn flatten_all<T>(list: &mut Vec<Expr<T>>) {
75        if list.iter().all(|x| x.is_all().not()) {
76            return;
77        }
78
79        let mut ans: Vec<Expr<T>> = Vec::with_capacity(list.len());
80        for expr in list.drain(..) {
81            if let Expr::All(All(all)) = expr {
82                ans.extend(all);
83            } else {
84                ans.push(expr);
85            }
86        }
87        *list = ans;
88    }
89}
90
91impl<T> VisitMut<T> for FlattenNestedList {
92    fn visit_mut_any(&mut self, Any(list): &mut Any<T>) {
93        Self::flatten_any(list);
94        walk_mut_expr_list(self, list);
95    }
96
97    fn visit_mut_all(&mut self, All(list): &mut All<T>) {
98        Self::flatten_all(list);
99        walk_mut_expr_list(self, list);
100    }
101}
102
103pub struct DedupList;
104
105impl<T> VisitMut<T> for DedupList
106where
107    T: Eq,
108{
109    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
110        if let Some(list) = expr.as_mut_expr_list() {
111            let mut i = 0;
112            while i < list.len() {
113                let mut j = i + 1;
114                while j < list.len() {
115                    if list[i] == list[j] {
116                        list.remove(j);
117                    } else {
118                        j += 1;
119                    }
120                }
121                i += 1;
122            }
123        }
124        walk_mut_expr(self, expr);
125    }
126}
127
128pub struct EvalConst;
129
130impl EvalConst {
131    fn eval_any<T>(any: &mut Vec<Expr<T>>) -> Option<bool> {
132        any.remove_if(|expr| expr.is_const_false());
133
134        if any.is_empty() {
135            return Some(false);
136        }
137
138        if any.iter().any(|expr| expr.is_const_true()) {
139            return Some(true);
140        }
141
142        None
143    }
144
145    fn eval_all<T>(all: &mut Vec<Expr<T>>) -> Option<bool> {
146        all.remove_if(|expr| expr.is_const_true());
147
148        if all.is_empty() {
149            return Some(true);
150        }
151
152        if all.iter().any(|expr| expr.is_const_false()) {
153            return Some(false);
154        }
155
156        None
157    }
158
159    fn eval_not<T>(not: &Expr<T>) -> Option<bool> {
160        if let Expr::Const(val) = not {
161            return Some(val.not());
162        }
163        None
164    }
165}
166
167impl<T> VisitMut<T> for EvalConst {
168    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
169        walk_mut_expr(self, expr);
170
171        match expr {
172            Expr::Any(Any(any)) => {
173                if let Some(val) = Self::eval_any(any) {
174                    *expr = Expr::Const(val);
175                }
176            }
177            Expr::All(All(all)) => {
178                if let Some(val) = Self::eval_all(all) {
179                    *expr = Expr::Const(val);
180                }
181            }
182            Expr::Not(Not(not)) => {
183                if let Some(val) = Self::eval_not(not) {
184                    *expr = Expr::Const(val);
185                }
186            }
187            _ => {}
188        }
189    }
190}
191
192pub struct SimplifyNestedList;
193
194impl SimplifyNestedList {
195    fn contains_cross_same<T: Eq>(lhs: &[T], rhs: &[T]) -> bool {
196        lhs.iter().any(|x| rhs.contains(x))
197    }
198}
199
200impl<T> VisitMut<T> for SimplifyNestedList
201where
202    T: Eq,
203{
204    /// `any(x0, all(x0, x1), x2) => any(x0, x2)`
205    fn visit_mut_any(&mut self, Any(any): &mut Any<T>) {
206        let mut i = 0;
207        while i < any.len() {
208            if let Expr::All(All(all)) = &any[i] {
209                if Self::contains_cross_same(all, any) {
210                    any.remove(i);
211                    continue;
212                }
213            }
214
215            i += 1;
216        }
217
218        walk_mut_expr_list(self, any);
219    }
220
221    /// `all(x0, any(x0, x1), x2) => all(x0, x2)`
222    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
223        let mut i = 0;
224        while i < all.len() {
225            if let Expr::Any(Any(any)) = &all[i] {
226                if Self::contains_cross_same(any, all) {
227                    all.remove(i);
228                    continue;
229                }
230            }
231
232            i += 1;
233        }
234
235        walk_mut_expr_list(self, all);
236    }
237}
238
239pub struct SimplifyAllNotAny;
240
241impl SimplifyAllNotAny {
242    /// Simplify `all(not(any(...)), any(...))`
243    fn counteract<T: Eq>(neg: &[Expr<T>], pos: &mut Vec<Expr<T>>) {
244        let mut i = 0;
245        while i < pos.len() {
246            if neg.contains(&pos[i]) {
247                pos.remove(i);
248            } else {
249                i += 1;
250            }
251        }
252    }
253}
254
255impl<T> VisitMut<T> for SimplifyAllNotAny
256where
257    T: Eq,
258{
259    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
260        if let [Expr::Not(Not(not)), Expr::Any(Any(pos))] = all.as_mut_slice() {
261            let neg = match not.as_mut_any() {
262                Some(Any(neg)) => neg,
263                None => slice::from_mut(&mut **not),
264            };
265            Self::counteract(neg, pos);
266        } else if let [Expr::Any(Any(pos)), Expr::Not(Not(not))] = all.as_mut_slice() {
267            let neg = match not.as_mut_any() {
268                Some(Any(neg)) => neg,
269                None => slice::from_mut(&mut **not),
270            };
271            Self::counteract(neg, pos);
272        }
273
274        walk_mut_expr_list(self, all);
275    }
276}
277
278pub struct FlattenByDeMorgan;
279
280impl<T> VisitMut<T> for FlattenByDeMorgan {
281    fn visit_mut_expr(&mut self, expr: &mut Expr<T>) {
282        if let Expr::Not(Not(not)) = expr {
283            match &mut **not {
284                Expr::Any(Any(any)) => {
285                    let list = map_collect_vec(any.drain(..), |expr| ast::expr(ast::not(expr)));
286                    *expr = ast::expr(ast::all(list));
287                }
288                Expr::All(All(all)) => {
289                    let list = map_collect_vec(all.drain(..), |expr| ast::expr(ast::not(expr)));
290                    *expr = ast::expr(ast::any(list));
291                }
292                _ => {}
293            }
294        }
295
296        walk_mut_expr(self, expr)
297    }
298}
299
300pub struct MergeAllOfNotAny;
301
302impl MergeAllOfNotAny {
303    fn as_mut_not_any<T>(expr: &mut Expr<T>) -> Option<&mut Vec<Expr<T>>> {
304        expr.as_mut_not_any().map(|x| &mut x.0)
305    }
306
307    fn unwrap_expr_not_var<T>(expr: Expr<T>) -> Var<T> {
308        if let Expr::Not(Not(not)) = expr {
309            if let Expr::Var(var) = *not {
310                return var;
311            }
312        }
313        panic!()
314    }
315}
316
317impl<T> VisitMut<T> for MergeAllOfNotAny {
318    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
319        let mut not_any_list: Vec<_> = filter_map_collect(&mut *all, Self::as_mut_not_any);
320
321        if let [first, rest @ ..] = not_any_list.as_mut_slice() {
322            if rest.is_empty().not() {
323                rest.iter_mut().for_each(|x| first.append(x));
324                all.remove_if(|x| x.is_empty_not_any())
325            }
326
327            {
328                let not_var_list: Vec<_> = drain_filter(all, |x| x.is_expr_not_var()).collect();
329                let not_any = all.iter_mut().find_map(Self::as_mut_not_any).unwrap();
330
331                for not_var in not_var_list {
332                    let var = Self::unwrap_expr_not_var(not_var);
333                    not_any.push(ast::expr(var));
334                }
335            }
336        }
337    }
338}
339
340pub struct MergeAllOfAny;
341
342impl MergeAllOfAny {
343    fn is_subset_of<T: Eq>(lhs: &[Expr<T>], rhs: &[Expr<T>]) -> bool {
344        lhs.iter().all(|x| rhs.contains(x))
345    }
346}
347
348impl<T: Eq> VisitMut<T> for MergeAllOfAny {
349    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
350        walk_mut_expr_list(self, all);
351
352        let mut any_list: Vec<_> = filter_map_collect(&mut *all, |x| Expr::as_mut_any(x).map(|x| &mut x.0));
353
354        for i in 0..any_list.len() {
355            for j in 0..any_list.len() {
356                if let Some((lhs, rhs)) = any_list.get2_mut(i, j) {
357                    if Self::is_subset_of(lhs, rhs) {
358                        rhs.clear();
359                        rhs.push(Expr::Const(true));
360                    }
361                }
362            }
363        }
364    }
365}
366
367pub struct SimplifyByShortCircuit;
368
369impl SimplifyByShortCircuit {
370    fn find_vars<T: Eq + Clone>(list: &mut [Expr<T>], marker: bool) -> Vec<Var<T>> {
371        let mut ans: Vec<Var<T>> = default();
372        for x in list {
373            if let Expr::Var(var) = x {
374                if ans.contains(var) {
375                    *x = Expr::Const(marker);
376                } else {
377                    ans.push(var.clone())
378                }
379            }
380        }
381        ans
382    }
383
384    fn replace_vars<T: Eq>(x: &mut Expr<T>, vars: &[Var<T>], marker: bool) {
385        match x {
386            Expr::Any(Any(any)) => any.iter_mut().for_each(|x| Self::replace_vars(x, vars, marker)),
387            Expr::All(All(all)) => all.iter_mut().for_each(|x| Self::replace_vars(x, vars, marker)),
388            Expr::Not(Not(not)) => Self::replace_vars(not, vars, marker),
389            Expr::Var(var) => {
390                if vars.contains(var) {
391                    *x = Expr::Const(marker);
392                }
393            }
394            Expr::Const(_) => {}
395        }
396    }
397}
398
399impl<T: Eq + Clone> VisitMut<T> for SimplifyByShortCircuit {
400    fn visit_mut_any(&mut self, Any(any): &mut Any<T>) {
401        let marker = false;
402        let vars = Self::find_vars(any, marker);
403        for x in any.iter_mut().filter(|x| x.is_var().not()) {
404            Self::replace_vars(x, &vars, marker);
405        }
406
407        walk_mut_expr_list(self, any)
408    }
409
410    fn visit_mut_all(&mut self, All(all): &mut All<T>) {
411        let marker = true;
412        let vars = Self::find_vars(all, marker);
413        for x in all.iter_mut().filter(|x| x.is_var().not()) {
414            Self::replace_vars(x, &vars, marker);
415        }
416
417        walk_mut_expr_list(self, all)
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ast::*;

    /// Runs [`EvalConst`] over `cfg` and renders the folded expression.
    fn eval_const_to_string(mut cfg: Expr<u32>) -> String {
        EvalConst.visit_mut_expr(&mut cfg);
        cfg.to_string()
    }

    /// The empty `any()` folds to `false`, and the two negations around it cancel out.
    #[test]
    fn eval_const() {
        assert_eq!(eval_const_to_string(expr(not(not(any(()))))), "false");
    }
}
431        assert_eq!(cfg.to_string(), "false");
432    }
433}