1use std::sync::Arc;
2
3use crate::collections::trees::TreeNode;
4use crate::ops::{Op, op_names};
5use crate::{Node, TreeRewriterRule};
6
7pub struct OpTreeRewriteRule<T> {
8 pub apply: Arc<dyn for<'a> Fn(&'a mut TreeNode<Op<T>>) -> bool>,
9}
10
11impl<T> OpTreeRewriteRule<T> {
12 pub fn new<F>(f: F) -> Self
13 where
14 F: for<'a> Fn(&'a mut TreeNode<Op<T>>) -> bool + 'static,
15 {
16 OpTreeRewriteRule { apply: Arc::new(f) }
17 }
18}
19
20impl TreeRewriterRule<Op<f32>> for OpTreeRewriteRule<f32> {
21 fn apply<'a>(&self, node: &'a mut TreeNode<Op<f32>>) -> bool {
22 (self.apply)(node)
23 }
24}
25
26pub fn all_rewrite_rules() -> Vec<OpTreeRewriteRule<f32>> {
27 let mut rules = Vec::new();
28
29 rules.extend(neutral_add_sub_mul_div());
30 rules.extend(fold_add_sub_mul_div());
31 rules.extend(neg_rules());
32 rules.extend(sum_prod_rules());
33
34 rules
35}
36
37fn is_zero(n: &TreeNode<Op<f32>>) -> bool {
38 match n.value() {
39 Op::Const(_, v) => v.abs() <= std::f32::EPSILON,
40 _ => false,
41 }
42}
43
44fn is_one(n: &TreeNode<Op<f32>>) -> bool {
45 match n.value() {
46 Op::Const(_, v) => (*v - crate::ops::math::ONE).abs() <= std::f32::EPSILON,
47 _ => false,
48 }
49}
50
51fn replace_with_child_idx(node: &mut TreeNode<Op<f32>>, idx: usize) -> bool {
54 if let Some(children) = node.children_mut() {
55 if idx < children.len() {
56 let mut subtree = children.swap_remove(idx);
57 std::mem::swap(node, &mut subtree);
58 return true;
59 }
60 }
61 false
62}
63
64fn replace_with_const(node: &mut TreeNode<Op<f32>>, name: &'static str, v: f32) -> bool {
65 let mut new_leaf = TreeNode::new(Op::Const(name, v));
66 std::mem::swap(node, &mut new_leaf);
67 true
68}
69
70pub fn neutral_add_sub_mul_div() -> Vec<OpTreeRewriteRule<f32>> {
72 vec![
73 OpTreeRewriteRule::new(|n| {
75 match n.value() {
76 Op::Fn(name, _, _) if *name == op_names::ADD => {
77 if let Some(children) = n.children_mut() {
78 if children.len() == 2 {
79 if is_zero(&children[0]) {
80 return replace_with_child_idx(n, 1);
81 }
82 if is_zero(&children[1]) {
83 return replace_with_child_idx(n, 0);
84 }
85 }
86 }
87 }
88 _ => {}
89 }
90
91 false
92 }),
93 OpTreeRewriteRule::new(|n| {
95 match n.value() {
96 Op::Fn(name, _, _) if *name == op_names::SUB => {
97 if let Some(children) = n.children_mut() {
98 if children.len() == 2 && is_zero(&children[1]) {
99 return replace_with_child_idx(n, 0);
100 }
101 }
102 }
103 _ => {}
104 }
105
106 false
107 }),
108 OpTreeRewriteRule::new(|n| {
110 match n.value() {
111 Op::Fn(name, _, _) if *name == op_names::SUB => {
112 if let Some(children) = n.children() {
113 if children.len() == 2 && children[0] == children[1] {
114 return replace_with_const(n, "0", 0.0);
115 }
116 }
117 }
118 _ => {}
119 }
120
121 false
122 }),
123 OpTreeRewriteRule::new(|n| {
125 match n.value() {
126 Op::Fn(name, _, _) if *name == op_names::MUL => {
127 if let Some(children) = n.children_mut() {
128 if children.len() == 2 {
129 if is_one(&children[0]) {
130 return replace_with_child_idx(n, 1);
131 }
132 if is_one(&children[1]) {
133 return replace_with_child_idx(n, 0);
134 }
135 }
136 }
137 }
138 _ => {}
139 }
140
141 false
142 }),
143 OpTreeRewriteRule::new(|n| {
145 match n.value() {
146 Op::Fn(name, _, _) if *name == op_names::MUL => {
147 if let Some(children) = n.children() {
148 if children.len() == 2 && (is_zero(&children[0]) || is_zero(&children[1])) {
149 return replace_with_const(n, "0", 0.0);
150 }
151 }
152 }
153 _ => {}
154 }
155
156 false
157 }),
158 OpTreeRewriteRule::new(|n| {
160 match n.value() {
161 Op::Fn(name, _, _) if *name == op_names::DIV => {
162 if let Some(children) = n.children_mut() {
163 if children.len() == 2 && is_one(&children[1]) {
164 return replace_with_child_idx(n, 0);
165 }
166 }
167 }
168 _ => {}
169 }
170
171 false
172 }),
173 ]
174}
175
176pub fn fold_add_sub_mul_div() -> Vec<OpTreeRewriteRule<f32>> {
177 let fold = |name: &'static str, f: fn(f32, f32) -> f32| {
178 OpTreeRewriteRule::new(move |n| {
179 if let Op::Fn(op_name, _, _) = n.value() {
180 if *op_name == name {
181 if let Some(children) = n.children() {
182 if children.len() == 2 {
183 match (children[0].value(), children[1].value()) {
184 (Op::Const(_, a), Op::Const(_, b)) => {
185 return replace_with_const(n, "c", f(*a, *b));
186 }
187 _ => {}
188 }
189 }
190 }
191 }
192 }
193 false
194 })
195 };
196
197 vec![
198 fold(op_names::ADD, |a, b| a + b),
199 fold(op_names::SUB, |a, b| a - b),
200 fold(op_names::MUL, |a, b| a * b),
201 fold(op_names::DIV, |a, b| a / b),
202 ]
203}
204
205pub fn neg_rules() -> Vec<OpTreeRewriteRule<f32>> {
206 vec![
207 OpTreeRewriteRule::new(|n| {
209 if let Op::Fn(name, _, _) = n.value() {
210 if *name == op_names::NEG {
211 if let Some(children) = n.children() {
212 if children.len() >= 1 {
213 if let Op::Fn(n2, _, _) = children[0].value() {
214 if *n2 == op_names::NEG {
215 if let Some(grand) = children[0].children() {
217 if let Some(_) = grand.get(0) {
218 if let Some(mut cs) = n.take_children() {
221 if cs.len() == 1 {
222 if let Some(mut gs) = cs[0].take_children() {
223 if !gs.is_empty() {
224 let mut only = gs.swap_remove(0);
225 std::mem::swap(n, &mut only);
226 return true;
227 }
228 }
229 }
230 n.add_child(cs.swap_remove(0));
232 }
233 }
234 }
235 }
236 }
237 }
238 }
239 }
240 }
241
242 false
243 }),
244 OpTreeRewriteRule::new(|n| {
246 if let Op::Fn(name, _, _) = n.value() {
247 if *name == op_names::NEG {
248 if let Some(children) = n.children() {
249 if children.len() == 1 {
250 if let Op::Const(_, v) = children[0].value() {
251 return replace_with_const(n, "c", -*v);
252 }
253 }
254 }
255 }
256 }
257
258 false
259 }),
260 ]
261}
262
263pub fn sum_prod_rules() -> Vec<OpTreeRewriteRule<f32>> {
264 vec![
265 OpTreeRewriteRule::new(|n| {
267 if let Op::Fn(name, _, _) = n.value() {
268 if *name == op_names::SUM {
269 if let Some(mut cs) = n.take_children() {
270 let mut kept = Vec::with_capacity(cs.len());
271 let mut dropped = false;
272 while let Some(ch) = cs.pop() {
273 if is_zero(&ch) {
274 dropped = true;
275 } else {
276 kept.push(ch);
277 }
278 }
279 kept.reverse();
280 if kept.is_empty() {
281 return replace_with_const(n, "0", 0.0);
282 }
283 if kept.len() == 1 {
284 let mut only = kept.swap_remove(0);
285 std::mem::swap(n, &mut only);
286 return true;
287 }
288 if dropped {
289 for k in kept {
291 n.add_child(k);
292 }
293 return true;
294 } else {
295 for c in kept {
297 n.add_child(c);
298 }
299 }
300 }
301 }
302 }
303 false
304 }),
305 OpTreeRewriteRule::new(|n| {
307 if let Op::Fn(name, _, _) = n.value() {
308 if *name == op_names::PROD {
309 if let Some(mut cs) = n.take_children() {
310 let mut kept = Vec::with_capacity(cs.len());
311 while let Some(ch) = cs.pop() {
312 if is_zero(&ch) {
313 return replace_with_const(n, "0", 0.0);
314 }
315 if is_one(&ch) {
316 continue;
317 }
318 kept.push(ch);
319 }
320 kept.reverse();
321 if kept.is_empty() {
322 return replace_with_const(n, "1", 1.0);
323 }
324 if kept.len() == 1 {
325 let mut only = kept.swap_remove(0);
326 std::mem::swap(n, &mut only);
327 return true;
328 }
329 if let Some(_) = n.children_mut() {
330 for k in kept {
331 n.add_child(k);
332 }
333 return true;
334 }
335 }
336 }
337 }
338 false
339 }),
340 ]
341}
342
343pub fn apply_rules_once(root: &mut TreeNode<Op<f32>>, rules: &[OpTreeRewriteRule<f32>]) -> usize {
345 let mut count = 0;
346
347 if let Some(children) = root.children_mut() {
348 for child in children.iter_mut() {
349 count += apply_rules_once(child, rules);
350 }
351 }
352
353 #[cfg(feature = "pgm")]
354 {
355 if let Op::PGM(_, _, programs, _) = root.value_mut() {
356 let progs = Arc::make_mut(programs);
357 for p in progs.iter_mut() {
358 count += apply_rules_once(p, rules);
359 }
360 }
361 }
362
363 for rule in rules {
364 if (rule.apply)(root) {
365 count += 1;
366 break;
367 }
368 }
369
370 count
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the payload of a constant node, panicking on anything else.
    fn const_value(node: &TreeNode<Op<f32>>) -> f32 {
        match node.value() {
            Op::Const(_, v) => *v,
            _ => panic!("Expected constant"),
        }
    }

    #[test]
    fn test_apply_rules_once() {
        let mut root = TreeNode::new(Op::add())
            .attach(Op::named_constant("x", 1.0))
            .attach(Op::named_constant("0", 0.0));

        let rules = neutral_add_sub_mul_div();
        let rewrites = apply_rules_once(&mut root, &rules);

        assert_eq!(rewrites, 1);
        assert_eq!(const_value(&root), 1.0);
    }

    #[test]
    fn test_fold_add_sub_mul_div() {
        let mut root = TreeNode::new(Op::add())
            .attach(Op::named_constant("x", 1.0))
            .attach(Op::named_constant("y", 2.0));

        let rules = fold_add_sub_mul_div();
        let rewrites = apply_rules_once(&mut root, &rules);

        assert_eq!(rewrites, 1);
        assert_eq!(const_value(&root), 3.0);
    }

    #[test]
    fn test_neg_rules() {
        // -(-(3)): the inner negation folds to -3, then the outer one folds back to 3.
        let inner = TreeNode::new(Op::neg()).attach(Op::named_constant("x", 3.0));
        let mut root = TreeNode::new(Op::neg()).attach(inner);

        let rules = neg_rules();
        let rewrites = apply_rules_once(&mut root, &rules);

        assert_eq!(rewrites, 2);
        assert_eq!(const_value(&root), 3.0);
    }

    #[test]
    fn test_sum_prod_rules() {
        let mut root = TreeNode::new(Op::sum())
            .attach(Op::named_constant("x", 2.0))
            .attach(Op::named_constant("0", 0.0))
            .attach(Op::named_constant("y", 3.0));

        let rules = sum_prod_rules();
        let rewrites = apply_rules_once(&mut root, &rules);
        assert_eq!(rewrites, 1);

        let children = root.children().unwrap();
        assert_eq!(children.len(), 2);
        assert_eq!(const_value(&children[0]), 2.0);
        assert_eq!(const_value(&children[1]), 3.0);
    }
}