radiate_gp/ops/
rewrite.rs

1use std::sync::Arc;
2
3use crate::collections::trees::TreeNode;
4use crate::ops::{Op, op_names};
5use crate::{Node, TreeRewriterRule};
6
7pub struct OpTreeRewriteRule<T> {
8    pub apply: Arc<dyn for<'a> Fn(&'a mut TreeNode<Op<T>>) -> bool>,
9}
10
11impl<T> OpTreeRewriteRule<T> {
12    pub fn new<F>(f: F) -> Self
13    where
14        F: for<'a> Fn(&'a mut TreeNode<Op<T>>) -> bool + 'static,
15    {
16        OpTreeRewriteRule { apply: Arc::new(f) }
17    }
18}
19
20impl TreeRewriterRule<Op<f32>> for OpTreeRewriteRule<f32> {
21    fn apply<'a>(&self, node: &'a mut TreeNode<Op<f32>>) -> bool {
22        (self.apply)(node)
23    }
24}
25
26pub fn all_rewrite_rules() -> Vec<OpTreeRewriteRule<f32>> {
27    let mut rules = Vec::new();
28
29    rules.extend(neutral_add_sub_mul_div());
30    rules.extend(fold_add_sub_mul_div());
31    rules.extend(neg_rules());
32    rules.extend(sum_prod_rules());
33
34    rules
35}
36
37fn is_zero(n: &TreeNode<Op<f32>>) -> bool {
38    match n.value() {
39        Op::Const(_, v) => v.abs() <= std::f32::EPSILON,
40        _ => false,
41    }
42}
43
44fn is_one(n: &TreeNode<Op<f32>>) -> bool {
45    match n.value() {
46        Op::Const(_, v) => (*v - crate::ops::math::ONE).abs() <= std::f32::EPSILON,
47        _ => false,
48    }
49}
50
51// Replace current node with one of its children by moving it, no clone.
52// idx must be valid and children must exist. This discards the other child subtree (intended).
53fn replace_with_child_idx(node: &mut TreeNode<Op<f32>>, idx: usize) -> bool {
54    if let Some(children) = node.children_mut() {
55        if idx < children.len() {
56            let mut subtree = children.swap_remove(idx);
57            std::mem::swap(node, &mut subtree);
58            return true;
59        }
60    }
61    false
62}
63
64fn replace_with_const(node: &mut TreeNode<Op<f32>>, name: &'static str, v: f32) -> bool {
65    let mut new_leaf = TreeNode::new(Op::Const(name, v));
66    std::mem::swap(node, &mut new_leaf);
67    true
68}
69
70// Neutral/identity rules (in-place; no subtree clones)
71pub fn neutral_add_sub_mul_div() -> Vec<OpTreeRewriteRule<f32>> {
72    vec![
73        // add(x,0) or add(0,x) -> x
74        OpTreeRewriteRule::new(|n| {
75            match n.value() {
76                Op::Fn(name, _, _) if *name == op_names::ADD => {
77                    if let Some(children) = n.children_mut() {
78                        if children.len() == 2 {
79                            if is_zero(&children[0]) {
80                                return replace_with_child_idx(n, 1);
81                            }
82                            if is_zero(&children[1]) {
83                                return replace_with_child_idx(n, 0);
84                            }
85                        }
86                    }
87                }
88                _ => {}
89            }
90
91            false
92        }),
93        // sub(x,0) -> x
94        OpTreeRewriteRule::new(|n| {
95            match n.value() {
96                Op::Fn(name, _, _) if *name == op_names::SUB => {
97                    if let Some(children) = n.children_mut() {
98                        if children.len() == 2 && is_zero(&children[1]) {
99                            return replace_with_child_idx(n, 0);
100                        }
101                    }
102                }
103                _ => {}
104            }
105
106            false
107        }),
108        // sub(x,x) -> 0
109        OpTreeRewriteRule::new(|n| {
110            match n.value() {
111                Op::Fn(name, _, _) if *name == op_names::SUB => {
112                    if let Some(children) = n.children() {
113                        if children.len() == 2 && children[0] == children[1] {
114                            return replace_with_const(n, "0", 0.0);
115                        }
116                    }
117                }
118                _ => {}
119            }
120
121            false
122        }),
123        // mul(x,1) or mul(1,x) -> x
124        OpTreeRewriteRule::new(|n| {
125            match n.value() {
126                Op::Fn(name, _, _) if *name == op_names::MUL => {
127                    if let Some(children) = n.children_mut() {
128                        if children.len() == 2 {
129                            if is_one(&children[0]) {
130                                return replace_with_child_idx(n, 1);
131                            }
132                            if is_one(&children[1]) {
133                                return replace_with_child_idx(n, 0);
134                            }
135                        }
136                    }
137                }
138                _ => {}
139            }
140
141            false
142        }),
143        // mul(x,0) or mul(0,x) -> 0
144        OpTreeRewriteRule::new(|n| {
145            match n.value() {
146                Op::Fn(name, _, _) if *name == op_names::MUL => {
147                    if let Some(children) = n.children() {
148                        if children.len() == 2 && (is_zero(&children[0]) || is_zero(&children[1])) {
149                            return replace_with_const(n, "0", 0.0);
150                        }
151                    }
152                }
153                _ => {}
154            }
155
156            false
157        }),
158        // div(x,1) -> x
159        OpTreeRewriteRule::new(|n| {
160            match n.value() {
161                Op::Fn(name, _, _) if *name == op_names::DIV => {
162                    if let Some(children) = n.children_mut() {
163                        if children.len() == 2 && is_one(&children[1]) {
164                            return replace_with_child_idx(n, 0);
165                        }
166                    }
167                }
168                _ => {}
169            }
170
171            false
172        }),
173    ]
174}
175
176pub fn fold_add_sub_mul_div() -> Vec<OpTreeRewriteRule<f32>> {
177    let fold = |name: &'static str, f: fn(f32, f32) -> f32| {
178        OpTreeRewriteRule::new(move |n| {
179            if let Op::Fn(op_name, _, _) = n.value() {
180                if *op_name == name {
181                    if let Some(children) = n.children() {
182                        if children.len() == 2 {
183                            match (children[0].value(), children[1].value()) {
184                                (Op::Const(_, a), Op::Const(_, b)) => {
185                                    return replace_with_const(n, "c", f(*a, *b));
186                                }
187                                _ => {}
188                            }
189                        }
190                    }
191                }
192            }
193            false
194        })
195    };
196
197    vec![
198        fold(op_names::ADD, |a, b| a + b),
199        fold(op_names::SUB, |a, b| a - b),
200        fold(op_names::MUL, |a, b| a * b),
201        fold(op_names::DIV, |a, b| a / b),
202    ]
203}
204
205pub fn neg_rules() -> Vec<OpTreeRewriteRule<f32>> {
206    vec![
207        // neg(neg(x)) -> x
208        OpTreeRewriteRule::new(|n| {
209            if let Op::Fn(name, _, _) = n.value() {
210                if *name == op_names::NEG {
211                    if let Some(children) = n.children() {
212                        if children.len() >= 1 {
213                            if let Op::Fn(n2, _, _) = children[0].value() {
214                                if *n2 == op_names::NEG {
215                                    // move the grandchild into place
216                                    if let Some(grand) = children[0].children() {
217                                        if let Some(_) = grand.get(0) {
218                                            // swap with a moved copy of grandchild (avoid clone via take_children):
219                                            // use take from parent:
220                                            if let Some(mut cs) = n.take_children() {
221                                                if cs.len() == 1 {
222                                                    if let Some(mut gs) = cs[0].take_children() {
223                                                        if !gs.is_empty() {
224                                                            let mut only = gs.swap_remove(0);
225                                                            std::mem::swap(n, &mut only);
226                                                            return true;
227                                                        }
228                                                    }
229                                                }
230                                                // restore if failed
231                                                n.add_child(cs.swap_remove(0));
232                                            }
233                                        }
234                                    }
235                                }
236                            }
237                        }
238                    }
239                }
240            }
241
242            false
243        }),
244        // neg(Const) -> Const(-v)
245        OpTreeRewriteRule::new(|n| {
246            if let Op::Fn(name, _, _) = n.value() {
247                if *name == op_names::NEG {
248                    if let Some(children) = n.children() {
249                        if children.len() == 1 {
250                            if let Op::Const(_, v) = children[0].value() {
251                                return replace_with_const(n, "c", -*v);
252                            }
253                        }
254                    }
255                }
256            }
257
258            false
259        }),
260    ]
261}
262
263pub fn sum_prod_rules() -> Vec<OpTreeRewriteRule<f32>> {
264    vec![
265        // sum(...,0,...) -> drop zeros; unwrap; empty->0
266        OpTreeRewriteRule::new(|n| {
267            if let Op::Fn(name, _, _) = n.value() {
268                if *name == op_names::SUM {
269                    if let Some(mut cs) = n.take_children() {
270                        let mut kept = Vec::with_capacity(cs.len());
271                        let mut dropped = false;
272                        while let Some(ch) = cs.pop() {
273                            if is_zero(&ch) {
274                                dropped = true;
275                            } else {
276                                kept.push(ch);
277                            }
278                        }
279                        kept.reverse();
280                        if kept.is_empty() {
281                            return replace_with_const(n, "0", 0.0);
282                        }
283                        if kept.len() == 1 {
284                            let mut only = kept.swap_remove(0);
285                            std::mem::swap(n, &mut only);
286                            return true;
287                        }
288                        if dropped {
289                            // put back pruned children
290                            for k in kept {
291                                n.add_child(k);
292                            }
293                            return true;
294                        } else {
295                            // nothing changed; restore original children
296                            for c in kept {
297                                n.add_child(c);
298                            }
299                        }
300                    }
301                }
302            }
303            false
304        }),
305        // prod: zero short-circuit; drop ones; unwrap; empty->1
306        OpTreeRewriteRule::new(|n| {
307            if let Op::Fn(name, _, _) = n.value() {
308                if *name == op_names::PROD {
309                    if let Some(mut cs) = n.take_children() {
310                        let mut kept = Vec::with_capacity(cs.len());
311                        while let Some(ch) = cs.pop() {
312                            if is_zero(&ch) {
313                                return replace_with_const(n, "0", 0.0);
314                            }
315                            if is_one(&ch) {
316                                continue;
317                            }
318                            kept.push(ch);
319                        }
320                        kept.reverse();
321                        if kept.is_empty() {
322                            return replace_with_const(n, "1", 1.0);
323                        }
324                        if kept.len() == 1 {
325                            let mut only = kept.swap_remove(0);
326                            std::mem::swap(n, &mut only);
327                            return true;
328                        }
329                        if let Some(_) = n.children_mut() {
330                            for k in kept {
331                                n.add_child(k);
332                            }
333                            return true;
334                        }
335                    }
336                }
337            }
338            false
339        }),
340    ]
341}
342
343// Post-order application (in-place). Returns number of rewrites.
344pub fn apply_rules_once(root: &mut TreeNode<Op<f32>>, rules: &[OpTreeRewriteRule<f32>]) -> usize {
345    let mut count = 0;
346
347    if let Some(children) = root.children_mut() {
348        for child in children.iter_mut() {
349            count += apply_rules_once(child, rules);
350        }
351    }
352
353    #[cfg(feature = "pgm")]
354    {
355        if let Op::PGM(_, _, programs, _) = root.value_mut() {
356            let progs = Arc::make_mut(programs);
357            for p in progs.iter_mut() {
358                count += apply_rules_once(p, rules);
359            }
360        }
361    }
362
363    for rule in rules {
364        if (rule.apply)(root) {
365            count += 1;
366            break;
367        }
368    }
369
370    count
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the payload of an `Op::Const` node; any other op fails the test.
    fn const_value(node: &TreeNode<Op<f32>>) -> f32 {
        match node.value() {
            Op::Const(_, v) => *v,
            _ => panic!("Expected constant"),
        }
    }

    #[test]
    fn test_apply_rules_once() {
        // add(1, 0) -> 1
        let mut root = TreeNode::new(Op::add())
            .attach(Op::named_constant("x", 1.0))
            .attach(Op::named_constant("0", 0.0));

        let rewrites = apply_rules_once(&mut root, &neutral_add_sub_mul_div());

        assert_eq!(rewrites, 1);
        assert_eq!(const_value(&root), 1.0);
    }

    #[test]
    fn test_fold_add_sub_mul_div() {
        // add(1, 2) -> 3
        let mut root = TreeNode::new(Op::add())
            .attach(Op::named_constant("x", 1.0))
            .attach(Op::named_constant("y", 2.0));

        let rewrites = apply_rules_once(&mut root, &fold_add_sub_mul_div());

        assert_eq!(rewrites, 1);
        assert_eq!(const_value(&root), 3.0);
    }

    #[test]
    fn test_neg_rules() {
        // neg(neg(3)): post-order folds the inner neg to -3 first, then the outer neg
        // folds that back to 3, for two rewrites in a single pass.
        let mut root = TreeNode::new(Op::neg())
            .attach(TreeNode::new(Op::neg()).attach(Op::named_constant("x", 3.0)));

        let rewrites = apply_rules_once(&mut root, &neg_rules());

        assert_eq!(rewrites, 2);
        assert_eq!(const_value(&root), 3.0);
    }

    #[test]
    fn test_sum_prod_rules() {
        // sum(2, 0, 3) -> sum(2, 3)
        let mut root = TreeNode::new(Op::sum())
            .attach(Op::named_constant("x", 2.0))
            .attach(Op::named_constant("0", 0.0))
            .attach(Op::named_constant("y", 3.0));

        let rewrites = apply_rules_once(&mut root, &sum_prod_rules());

        assert_eq!(rewrites, 1);
        let children = root.children().unwrap();
        assert_eq!(children.len(), 2);
        assert_eq!(const_value(&children[0]), 2.0);
        assert_eq!(const_value(&children[1]), 3.0);
    }
}