codegen_libc/
transform.rs

1use codegen_cfg::ast::*;
2use codegen_cfg::bool_logic::transform::*;
3use codegen_cfg::bool_logic::visit_mut::*;
4use log::debug;
5use rust_utils::iter::filter_map_collect_vec;
6use rust_utils::iter::map_collect_vec;
7use rust_utils::vec::VecExt;
8
9use std::cmp::Ordering::{self, *};
10use std::mem;
11
12use log::trace;
13
14pub fn simplified_expr(x: impl Into<Expr>) -> Expr {
15    let mut x = x.into();
16
17    debug!("input:                              {x}");
18
19    UnifyTargetFamily.visit_mut_expr(&mut x);
20    trace!("after  UnifyTargetFamily:           {x}");
21
22    for _ in 0..3 {
23        FlattenSingle.visit_mut_expr(&mut x);
24        trace!("after  FlattenSingle:               {x}");
25
26        FlattenNestedList.visit_mut_expr(&mut x);
27        trace!("after  FlattenNestedList:           {x}");
28
29        DedupList.visit_mut_expr(&mut x);
30        trace!("after  DedupList:                   {x}");
31
32        EvalConst.visit_mut_expr(&mut x);
33        trace!("after  EvalConst:                   {x}");
34
35        SimplifyNestedList.visit_mut_expr(&mut x);
36        trace!("after  SimplifyNestedList:          {x}");
37
38        MergeAllOfNotAny.visit_mut_expr(&mut x);
39        trace!("after  MergeAllOfNotAny:            {x}");
40
41        SimplifyAllNotAny.visit_mut_expr(&mut x);
42        trace!("after  SimplifyAllNotAny:           {x}");
43
44        MergeAllOfAny.visit_mut_expr(&mut x);
45        trace!("after  MergeAllOfAny:               {x}");
46
47        ImplyByKey.visit_mut_expr(&mut x);
48        trace!("after  ImplyByKey:                  {x}");
49
50        SuppressTargetFamily.visit_mut_expr(&mut x);
51        trace!("after  SuppressTargetFamily:        {x}");
52
53        EvalConst.visit_mut_expr(&mut x);
54        trace!("after  EvalConst:                   {x}");
55
56        MergePattern.visit_mut_expr(&mut x);
57        trace!("after  MergePattern:                {x}");
58
59        EvalConst.visit_mut_expr(&mut x);
60        trace!("after  EvalConst:                   {x}");
61
62        SimplifyByShortCircuit.visit_mut_expr(&mut x);
63        trace!("after  SimplifyByShortCircuit:      {x}");
64
65        EvalConst.visit_mut_expr(&mut x);
66        trace!("after  EvalConst:                   {x}");
67    }
68
69    SimplifyTargetFamily.visit_mut_expr(&mut x);
70    trace!("after  SimplifyTargetFamily:        {x}");
71
72    SortByPriority.visit_mut_expr(&mut x);
73    trace!("after  SortByPriority:              {x}");
74
75    SortByValue.visit_mut_expr(&mut x);
76    trace!("after  SortByValue:                 {x}");
77
78    debug!("output:                             {x}");
79
80    x
81}
82
83struct SortByPriority;
84
85impl SortByPriority {
86    fn get_priority(x: &Expr) -> Option<u32> {
87        Some(match x {
88            Expr::Not(_) => 103,
89            Expr::Any(_) => 101,
90            Expr::All(_) => 102,
91            Expr::Var(Var(pred)) => match pred.key.as_str() {
92                "target_family" => 1,
93                "target_arch" => 2,
94                "target_vendor" => 3,
95                "target_os" => 4,
96                "target_env" => 5,
97                "target_pointer_width" => 6,
98                _ => 0,
99            },
100            Expr::Const(_) => panic!(),
101        })
102    }
103}
104
105impl VisitMut<Pred> for SortByPriority {
106    fn visit_mut_expr(&mut self, expr: &mut Expr) {
107        if let Some(list) = expr.as_mut_expr_list() {
108            list.sort_by(|lhs, rhs| {
109                let Some(lhs) = Self::get_priority(lhs) else {return Equal};
110                let Some(rhs) = Self::get_priority(rhs) else {return Equal};
111                lhs.cmp(&rhs)
112            })
113        }
114
115        walk_mut_expr(self, expr);
116    }
117}
118
119struct SortByValue;
120
121impl SortByValue {
122    fn cmp_var(lhs: &Expr, rhs: &Expr) -> Ordering {
123        let Expr::Var(Var(lhs)) = lhs else { return Equal };
124        let Expr::Var(Var(rhs)) = rhs else { return Equal };
125
126        let ok = Ord::cmp(lhs.key.as_str(), rhs.key.as_str());
127
128        match (lhs.value.as_deref(), rhs.value.as_deref()) {
129            (None, None) => ok,
130            (Some(lv), Some(rv)) => ok.then_with(|| Ord::cmp(lv, rv)),
131            (None, Some(_)) => Less,
132            (Some(_), None) => Greater,
133        }
134    }
135
136    fn cmp_not(lhs: &Expr, rhs: &Expr) -> Ordering {
137        let Expr::Not(Not(lhs)) = lhs else { return Equal };
138        let Expr::Not(Not(rhs)) = rhs else { return Equal };
139
140        Self::cmp_var(lhs, rhs)
141    }
142}
143
144impl VisitMut<Pred> for SortByValue {
145    fn visit_mut_expr(&mut self, expr: &mut Expr) {
146        if let Some(list) = expr.as_mut_expr_list() {
147            list.sort_by(Self::cmp_var);
148            list.sort_by(Self::cmp_not);
149        }
150
151        walk_mut_expr(self, expr);
152    }
153}
154
155struct UnifyTargetFamily;
156
157impl VisitMut<Pred> for UnifyTargetFamily {
158    fn visit_mut_var(&mut self, Var(pred): &mut Var<Pred>) {
159        if pred.value.is_none() && matches!(pred.key.as_str(), "unix" | "windows" | "wasm") {
160            *pred = key_value("target_family", pred.key.as_str());
161        }
162    }
163}
164
165struct SimplifyTargetFamily;
166
167impl VisitMut<Pred> for SimplifyTargetFamily {
168    fn visit_mut_var(&mut self, Var(pred): &mut Var<Pred>) {
169        if pred.key == "target_family" {
170            if let Some(value) = pred.value.as_deref() {
171                if matches!(value, "unix" | "windows" | "wasm") {
172                    *pred = flag(value);
173                }
174            }
175        }
176    }
177}
178
179struct ImplyByKey;
180
181impl ImplyByKey {
182    const UNIQUE_VALUED_KEYS: &[&'static str] = &[
183        "target_family",
184        "target_arch",
185        "target_vendor",
186        "target_os",
187        "target_env",
188        "target_pointer_width",
189    ];
190
191    fn is_expr_any_pred(any: &[Expr], key: &str) -> bool {
192        any.iter().all(|x| x.as_var().map_or(false, |Var(var)| var.key == key))
193    }
194
195    fn fix(pos_key: &str, pos_any_values: &[&str], expr: &mut Expr) {
196        match expr {
197            Expr::Any(Any(any)) => {
198                any.iter_mut().for_each(|x| Self::fix(pos_key, pos_any_values, x));
199            }
200            Expr::All(All(all)) => {
201                all.iter_mut().for_each(|x| Self::fix(pos_key, pos_any_values, x));
202            }
203            Expr::Not(Not(not)) => {
204                Self::fix(pos_key, pos_any_values, not);
205            }
206            Expr::Var(Var(var)) => {
207                if var.key == pos_key {
208                    let var_value = var.value.as_deref().unwrap();
209                    if pos_any_values.contains(&var_value) {
210                        if pos_any_values.len() == 1 {
211                            *expr = Expr::Const(true)
212                        }
213                    } else {
214                        *expr = Expr::Const(false)
215                    }
216                }
217            }
218            Expr::Const(_) => {}
219        }
220    }
221}
222
223impl VisitMut<Pred> for ImplyByKey {
224    fn visit_mut_all(&mut self, All(all): &mut All<Pred>) {
225        walk_mut_expr_list(self, all);
226
227        let mut i = 0;
228        while i < all.len() {
229            match &all[i] {
230                Expr::Var(Var(pos)) => {
231                    if Self::UNIQUE_VALUED_KEYS.contains(&pos.key.as_str()) {
232                        assert!(pos.value.is_some());
233
234                        let pos = pos.clone();
235                        let pos_key = pos.key.as_str();
236                        let pos_any_values = &[pos.value.as_deref().unwrap()];
237
238                        for (_, x) in all.iter_mut().enumerate().filter(|&(j, _)| j != i) {
239                            Self::fix(pos_key, pos_any_values, x);
240                        }
241                    }
242                }
243                Expr::Any(Any(any)) => {
244                    if let Some(pos_key) = Self::UNIQUE_VALUED_KEYS.iter().find(|k| Self::is_expr_any_pred(any, k)) {
245                        let any = any.clone();
246                        let pos_any_values = map_collect_vec(&any, |x| x.as_var().unwrap().0.value.as_deref().unwrap());
247
248                        for (_, x) in all.iter_mut().enumerate().filter(|&(j, _)| j != i) {
249                            Self::fix(pos_key, &pos_any_values, x);
250                        }
251                    }
252                }
253                _ => {}
254            }
255            i += 1;
256        }
257    }
258}
259
260struct SuppressTargetFamily;
261
262impl SuppressTargetFamily {
263    fn is_target_os_pred(x: &Expr) -> bool {
264        match x {
265            Expr::Var(Var(var)) => var.key == "target_os",
266            _ => false,
267        }
268    }
269
270    fn has_specified_target_os(x: &Expr) -> bool {
271        if Self::is_target_os_pred(x) {
272            return true;
273        }
274
275        if let Expr::Any(Any(any)) = x {
276            return any.iter().all(Self::is_target_os_pred);
277        }
278
279        false
280    }
281
282    #[allow(clippy::match_like_matches_macro)]
283    fn is_suppressed_target_family(pred: &Pred) -> bool {
284        match (pred.key.as_str(), pred.value.as_deref()) {
285            ("target_family", Some("unix")) => true,
286            ("target_family", Some("windows")) => true,
287            _ => false,
288        }
289    }
290}
291
292impl VisitMut<Pred> for SuppressTargetFamily {
293    fn visit_mut_all(&mut self, All(all): &mut All<Pred>) {
294        if all.iter().any(Self::has_specified_target_os) {
295            all.remove_if(|x| match x {
296                Expr::Var(Var(pred)) => Self::is_suppressed_target_family(pred),
297                Expr::Not(Not(not)) => match &**not {
298                    Expr::Var(Var(pred)) => Self::is_suppressed_target_family(pred),
299                    _ => false,
300                },
301                _ => false,
302            })
303        }
304
305        walk_mut_expr_list(self, all)
306    }
307}
308
309struct MergePattern;
310
311impl MergePattern {
312    fn merge(any_list: &mut [Expr]) {
313        let mut pattern_list = filter_map_collect_vec(any_list, |x| {
314            if let Expr::All(All(all)) = x {
315                if let [first, second] = all.as_mut_slice() {
316                    if first.is_any() || first.is_var() {
317                        return Some((first, second));
318                    }
319                }
320            }
321            None
322        });
323
324        if let [head, rest @ ..] = pattern_list.as_mut_slice() {
325            let agg = match head.0 {
326                Expr::Any(Any(any)) => any,
327                Expr::Var(var) => {
328                    *head.0 = expr(any((var.clone(),)));
329                    head.0.as_mut_any().map(|x| &mut x.0).unwrap()
330                }
331                _ => panic!(),
332            };
333
334            for x in rest {
335                let to_agg = if x.1 == head.1 {
336                    &mut *x.0
337                } else if x.0 == head.1 {
338                    &mut *x.1
339                } else {
340                    continue;
341                };
342
343                match mem::replace(to_agg, Expr::Const(false)) {
344                    Expr::Any(Any(any)) => agg.extend(any),
345                    Expr::Var(var) => agg.push(expr(var.clone())),
346                    other => *to_agg = other,
347                }
348            }
349
350            if agg.len() == 1 {
351                *head.0 = agg.pop().unwrap();
352            }
353        }
354    }
355}
356
357impl VisitMut<Pred> for MergePattern {
358    fn visit_mut_any(&mut self, Any(any_list): &mut Any<Pred>) {
359        Self::merge(any_list);
360        Self::merge(&mut any_list[1..]);
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// `SortByPriority` moves a plain predicate in front of a negation.
    #[test]
    fn sort() {
        let mut cfg = expr(all((not(flag("unix")), flag("unix"))));

        SortByPriority.visit_mut_expr(&mut cfg);

        assert_eq!(cfg.to_string(), "all(unix, not(unix))");
    }

    /// `ImplyByKey` turns predicates that contradict a sibling into `false`.
    #[test]
    fn imply() {
        // A single predicate decides another predicate on the same key.
        let mut single = expr(all((target_os("linux"), not(target_os("emscripten")))));
        ImplyByKey.visit_mut_expr(&mut single);
        assert_eq!(single.to_string(), r#"all(target_os = "linux", not(false))"#);

        // An `any(..)` over one key decides the predicates of a sibling `any(..)`.
        let mut alternatives = expr(all((
            any((target_os("ios"), target_os("macos"))),     //
            any((target_os("linux"), target_os("android"))), //
        )));
        ImplyByKey.visit_mut_expr(&mut alternatives);
        assert_eq!(
            alternatives.to_string(),
            r#"all(any(target_os = "ios", target_os = "macos"), any(false, false))"#
        );
    }
}