1#![warn(missing_docs)]
2
3use chomsky_rule_engine::{RewriteRegistry, RewriteRule, RuleCategory};
4use chomsky_uir::egraph::{Analysis, EGraph};
5use chomsky_uir::{IKun, Id};
6
7pub struct ConstantFolding;
8
9impl<A: Analysis<IKun>> RewriteRule<A> for ConstantFolding {
10 fn name(&self) -> &str {
11 "constant-folding"
12 }
13
14 fn apply(&self, egraph: &EGraph<IKun, A>) {
15 let mut matches = Vec::new();
16 for entry in egraph.classes.iter() {
17 let (&id, eclass) = entry.pair();
18 for node in &eclass.nodes {
19 if let IKun::Extension(op, args) = node {
20 if args.len() == 2 {
21 let arg1_root = egraph.union_find.find(args[0]);
22 let arg2_root = egraph.union_find.find(args[1]);
23
24 let arg1_const = get_const(egraph, arg1_root);
25 let arg2_const = get_const(egraph, arg2_root);
26
27 if let (Some(v1), Some(v2)) = (arg1_const, arg2_const) {
28 match op.as_str() {
29 "add" => matches.push((id, v1 + v2)),
30 "sub" => matches.push((id, v1 - v2)),
31 "mul" => matches.push((id, v1 * v2)),
32 "div" if v2 != 0 => matches.push((id, v1 / v2)),
33 _ => {}
34 }
35 }
36 }
37 }
38 }
39 }
40
41 for (id, result) in matches {
42 let const_id = egraph.add(IKun::Constant(result));
43 egraph.union(id, const_id);
44 }
45 }
46}
47
48fn get_bool<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id) -> Option<bool> {
49 let root = egraph.union_find.find(id);
50 let eclass = egraph.classes.get(&root)?;
51 for node in &eclass.nodes {
52 if let IKun::BooleanConstant(v) = node {
53 return Some(*v);
54 }
55 }
56 None
57}
58
59pub struct AlgebraicSimplification;
60
61impl<A: Analysis<IKun>> RewriteRule<A> for AlgebraicSimplification {
62 fn name(&self) -> &str {
63 "algebraic-simplification"
64 }
65
66 fn apply(&self, egraph: &EGraph<IKun, A>) {
67 let mut matches = Vec::new();
68 for entry in egraph.classes.iter() {
69 let id = *entry.key();
70 let eclass = entry.value();
71 for node in &eclass.nodes {
72 if let IKun::Extension(op, args) = node {
73 if args.len() == 2 {
74 let arg1_root = egraph.union_find.find(args[0]);
75 let arg2_root = egraph.union_find.find(args[1]);
76
77 match op.as_str() {
78 "add" => {
79 if is_const(egraph, arg2_root, 0) {
81 matches.push((id, args[0]));
82 }
83 else if is_const(egraph, arg1_root, 0) {
85 matches.push((id, args[1]));
86 }
87 }
88 "sub" => {
89 if is_const(egraph, arg2_root, 0) {
91 matches.push((id, args[0]));
92 }
93 else if arg1_root == arg2_root {
95 let zero_id = egraph.add(IKun::Constant(0));
96 matches.push((id, zero_id));
97 }
98 }
99 "mul" => {
100 if is_const(egraph, arg2_root, 1) {
102 matches.push((id, args[0]));
103 }
104 else if is_const(egraph, arg1_root, 1) {
106 matches.push((id, args[1]));
107 }
108 else if is_const(egraph, arg2_root, 0) {
110 matches.push((id, arg2_root));
111 }
112 else if is_const(egraph, arg1_root, 0) {
114 matches.push((id, arg1_root));
115 }
116 }
117 "div" => {
118 if is_const(egraph, arg2_root, 1) {
120 matches.push((id, args[0]));
121 }
122 else if arg1_root == arg2_root {
124 let one_id = egraph.add(IKun::Constant(1));
127 matches.push((id, one_id));
128 }
129 }
130 _ => {}
131 }
132 }
133 }
134 }
135 }
136
137 for (id, target) in matches {
138 egraph.union(id, target);
139 }
140 }
141}
142
143pub struct TrapSimplification;
144
145impl<A: Analysis<IKun>> RewriteRule<A> for TrapSimplification {
146 fn name(&self) -> &str {
147 "trap-simplification"
148 }
149
150 fn apply(&self, egraph: &EGraph<IKun, A>) {
151 let mut matches = Vec::new();
152 for entry in egraph.classes.iter() {
153 let (&id, eclass) = entry.pair();
154 for node in &eclass.nodes {
155 if let IKun::Trap(inner) = node {
156 let inner_root = egraph.union_find.find(*inner);
157
158 if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
160 for inner_node in &inner_eclass.nodes {
161 if let IKun::Trap(_) = inner_node {
162 matches.push((id, *inner));
163 }
164 }
165 }
166 }
167 }
168 }
169
170 for (id, target) in matches {
171 egraph.union(id, target);
172 }
173 }
174}
175
176pub struct UniversalSemanticOptimization;
177
178impl<A: Analysis<IKun>> RewriteRule<A> for UniversalSemanticOptimization {
179 fn name(&self) -> &str {
180 "universal-semantic-optimization"
181 }
182
183 fn apply(&self, egraph: &EGraph<IKun, A>) {
184 let mut matches = Vec::new();
185 for entry in egraph.classes.iter() {
186 let (&id, eclass) = entry.pair();
187 for node in &eclass.nodes {
188 match node {
189 IKun::WithContext(ctx_id, inner_id) => {
191 let ctx_root = egraph.union_find.find(*ctx_id);
192 let inner_root = egraph.union_find.find(*inner_id);
193
194 if let Some(inner_class) = egraph.classes.get(&inner_root) {
195 for inner_node in &inner_class.nodes {
196 if let IKun::WithContext(nested_ctx_id, _nested_inner_id) =
197 inner_node
198 {
199 if egraph.union_find.find(*nested_ctx_id) == ctx_root {
200 matches.push((id, *inner_id));
202 }
203 }
204 }
205 }
206 }
207
208 _ => {}
215 }
216 }
217 }
218
219 for (id, result) in matches {
220 egraph.union(id, result);
221 }
222 }
223}
224
225fn get_const<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id) -> Option<i64> {
226 let root = egraph.union_find.find(id);
227 egraph.classes.get(&root).and_then(|c| {
228 c.nodes.iter().find_map(|n| {
229 if let IKun::Constant(v) = n {
230 Some(*v)
231 } else {
232 None
233 }
234 })
235 })
236}
237
238fn is_const<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id, val: i64) -> bool {
239 get_const(egraph, id) == Some(val)
240}
241
242pub struct StrengthReduction;
243
244impl<A: Analysis<IKun>> RewriteRule<A> for StrengthReduction {
245 fn name(&self) -> &str {
246 "strength-reduction"
247 }
248 fn apply(&self, egraph: &EGraph<IKun, A>) {
249 let mut matches = Vec::new();
250 for entry in egraph.classes.iter() {
251 let (&id, eclass) = entry.pair();
252 for node in &eclass.nodes {
253 if let IKun::Extension(op, args) = node {
254 if args.len() == 2 {
255 if let Some(val) = get_const(egraph, args[1]) {
256 match op.as_str() {
257 "mul" if val > 0 && (val & (val - 1)) == 0 => {
258 let n = (val as f64).log2() as i64;
259 matches.push((id, "shl", args[0], n));
260 }
261 "div" if val > 0 && (val & (val - 1)) == 0 => {
262 let n = (val as f64).log2() as i64;
263 matches.push((id, "shr", args[0], n));
264 }
265 _ => {}
266 }
267 }
268 }
269 }
270 }
271 }
272 for (id, op, arg, n) in matches {
273 let n_id = egraph.add(IKun::Constant(n));
274 let new_id = egraph.add(IKun::Extension(op.to_string(), vec![arg, n_id]));
275 egraph.union(id, new_id);
276 }
277 }
278}
279
280pub struct Peephole;
281
282impl<A: Analysis<IKun>> RewriteRule<A> for Peephole {
283 fn name(&self) -> &str {
284 "peephole"
285 }
286 fn apply(&self, egraph: &EGraph<IKun, A>) {
287 let mut matches = Vec::new();
288 for entry in egraph.classes.iter() {
289 let (&id, eclass) = entry.pair();
290 for node in &eclass.nodes {
291 if let IKun::Extension(op, args) = node {
292 if args.len() == 2 {
293 let arg1_root = egraph.union_find.find(args[0]);
294 let arg2_root = egraph.union_find.find(args[1]);
295 match op.as_str() {
296 "add" if arg1_root == arg2_root => {
297 matches.push((id, "mul", args[0], 2));
298 }
299 _ => {}
300 }
301 }
302 }
303 }
304 }
305 for (id, op, arg, val) in matches {
306 let val_id = egraph.add(IKun::Constant(val));
307 let new_id = egraph.add(IKun::Extension(op.to_string(), vec![arg, val_id]));
308 egraph.union(id, new_id);
309 }
310 }
311}
312
313pub struct MapFusion;
314
315impl<A: Analysis<IKun>> RewriteRule<A> for MapFusion {
316 fn name(&self) -> &str {
317 "map-fusion"
318 }
319
320 fn apply(&self, egraph: &EGraph<IKun, A>) {
321 let mut matches = Vec::new();
322 for entry in egraph.classes.iter() {
323 let (&id, eclass) = entry.pair();
324 for node in &eclass.nodes {
325 if let IKun::Map(f, inner_id) = node {
326 let inner_root = egraph.union_find.find(*inner_id);
327 if let Some(inner_class) = egraph.classes.get(&inner_root) {
328 for inner_node in &inner_class.nodes {
329 if let IKun::Map(g, x) = inner_node {
330 matches.push((id, *f, *g, *x));
331 }
332 }
333 }
334 }
335 }
336 }
337
338 for (id, f, g, x) in matches {
339 let seq_id = egraph.add(IKun::Compose(f, g));
340 let new_map_id = egraph.add(IKun::Map(seq_id, x));
341 egraph.union(id, new_map_id);
342 }
343 }
344}
345
346pub struct FilterFusion;
347
348impl<A: Analysis<IKun>> RewriteRule<A> for FilterFusion {
349 fn name(&self) -> &str {
350 "filter-fusion"
351 }
352
353 fn apply(&self, egraph: &EGraph<IKun, A>) {
354 let mut matches = Vec::new();
355 for entry in egraph.classes.iter() {
356 let (&id, eclass) = entry.pair();
357 for node in &eclass.nodes {
358 if let IKun::Filter(p1, inner_id) = node {
359 let inner_root = egraph.union_find.find(*inner_id);
360 if let Some(inner_class) = egraph.classes.get(&inner_root) {
361 for inner_node in &inner_class.nodes {
362 if let IKun::Filter(p2, x) = inner_node {
363 matches.push((id, *p1, *p2, *x));
364 }
365 }
366 }
367 }
368 }
369 }
370
371 for (id, p1, p2, x) in matches {
372 let combined_p = egraph.add(IKun::Extension("and_predicate".to_string(), vec![p2, p1]));
373 let new_filter_id = egraph.add(IKun::Filter(combined_p, x));
374 egraph.union(id, new_filter_id);
375 }
376 }
377}
378
379pub struct FilterMapFusion;
380
381impl<A: Analysis<IKun>> RewriteRule<A> for FilterMapFusion {
382 fn name(&self) -> &str {
383 "filter-map-fusion"
384 }
385
386 fn apply(&self, egraph: &EGraph<IKun, A>) {
387 let mut matches = Vec::new();
388 for entry in egraph.classes.iter() {
389 let (&id, eclass) = entry.pair();
390 for node in &eclass.nodes {
391 if let IKun::Map(f, inner_id) = node {
392 let inner_root = egraph.union_find.find(*inner_id);
393 if let Some(inner_class) = egraph.classes.get(&inner_root) {
394 for inner_node in &inner_class.nodes {
395 if let IKun::Filter(p, x) = inner_node {
396 matches.push((id, *f, *p, *x));
397 }
398 }
399 }
400 }
401 }
402 }
403
404 for (id, f, p, x) in matches {
405 let fm_node = egraph.add(IKun::Extension("filter_map".to_string(), vec![f, p, x]));
406 egraph.union(id, fm_node);
407 }
408 }
409}
410
411pub struct MapFilterFusion;
412
413impl<A: Analysis<IKun>> RewriteRule<A> for MapFilterFusion {
414 fn name(&self) -> &str {
415 "map-filter-fusion"
416 }
417
418 fn apply(&self, egraph: &EGraph<IKun, A>) {
419 let mut matches = Vec::new();
420 for entry in egraph.classes.iter() {
421 let (&id, eclass) = entry.pair();
422 for node in &eclass.nodes {
423 if let IKun::Filter(p, inner_id) = node {
424 let inner_root = egraph.union_find.find(*inner_id);
425 if let Some(inner_class) = egraph.classes.get(&inner_root) {
426 for inner_node in &inner_class.nodes {
427 if let IKun::Map(f, x) = inner_node {
428 matches.push((id, *p, *f, *x));
429 }
430 }
431 }
432 }
433 }
434 }
435
436 for (id, p, f, x) in matches {
437 let mf_node = egraph.add(IKun::Extension("map_filter".to_string(), vec![p, f, x]));
438 egraph.union(id, mf_node);
439 }
440 }
441}
442
443pub struct MapReduceFusion;
444
445impl<A: Analysis<IKun>> RewriteRule<A> for MapReduceFusion {
446 fn name(&self) -> &str {
447 "map-reduce-fusion"
448 }
449
450 fn apply(&self, egraph: &EGraph<IKun, A>) {
451 let mut matches = Vec::new();
452 for entry in egraph.classes.iter() {
453 let (&id, eclass) = entry.pair();
454 for node in &eclass.nodes {
455 if let IKun::Reduce(g, init, inner_id) = node {
456 let inner_root = egraph.union_find.find(*inner_id);
457 if let Some(inner_class) = egraph.classes.get(&inner_root) {
458 for inner_node in &inner_class.nodes {
459 if let IKun::Map(f, x) = inner_node {
460 matches.push((id, *f, *g, *init, *x));
461 }
462 }
463 }
464 }
465 }
466 }
467
468 for (id, f, g, init, x) in matches {
469 let fused_node = egraph.add(IKun::Extension(
470 "loop_map_reduce".to_string(),
471 vec![f, g, init, x],
472 ));
473 egraph.union(id, fused_node);
474 }
475 }
476}
477
478pub struct FilterReduceFusion;
479
480impl<A: Analysis<IKun>> RewriteRule<A> for FilterReduceFusion {
481 fn name(&self) -> &str {
482 "filter-reduce-fusion"
483 }
484
485 fn apply(&self, egraph: &EGraph<IKun, A>) {
486 let mut matches = Vec::new();
487 for entry in egraph.classes.iter() {
488 let (&id, eclass) = entry.pair();
489 for node in &eclass.nodes {
490 if let IKun::Reduce(f, init, inner_id) = node {
491 let inner_root = egraph.union_find.find(*inner_id);
492 if let Some(inner_class) = egraph.classes.get(&inner_root) {
493 for inner_node in &inner_class.nodes {
494 if let IKun::Filter(p, x) = inner_node {
495 matches.push((id, *p, *f, *init, *x));
496 }
497 }
498 }
499 }
500 }
501 }
502
503 for (id, p, f, init, x) in matches {
504 let fused_node = egraph.add(IKun::Extension(
505 "loop_filter_reduce".to_string(),
506 vec![p, f, init, x],
507 ));
508 egraph.union(id, fused_node);
509 }
510 }
511}
512
513pub struct LayoutTransformation;
514
515impl<A: Analysis<IKun>> RewriteRule<A> for LayoutTransformation {
516 fn name(&self) -> &str {
517 "layout-transformation"
518 }
519 fn apply(&self, egraph: &EGraph<IKun, A>) {
520 let mut matches = Vec::new();
521 for entry in egraph.classes.iter() {
522 let eclass = entry.value();
523 for node in &eclass.nodes {
524 if let IKun::WithContext(ctx_id, body_id) = node {
525 let is_spatial = egraph.classes.get(ctx_id).map_or(false, |c| {
526 c.nodes
527 .iter()
528 .any(|n| matches!(n, IKun::SpatialContext | IKun::GpuContext))
529 });
530
531 if is_spatial {
532 if let Some(body_class) = egraph.classes.get(body_id) {
533 for body_node in &body_class.nodes {
534 if let IKun::Map(f, x) = body_node {
535 matches.push((*body_id, *f, *x));
536 }
537 }
538 }
539 }
540 }
541 }
542 }
543 for (id, f, x) in matches {
544 let soa_map = egraph.add(IKun::SoAMap(f, x));
545 egraph.union(id, soa_map);
546 }
547 }
548}
549
550pub struct LoopTiling;
551
552impl<A: Analysis<IKun>> RewriteRule<A> for LoopTiling {
553 fn name(&self) -> &str {
554 "loop-tiling"
555 }
556 fn apply(&self, egraph: &EGraph<IKun, A>) {
557 let mut matches = Vec::new();
558 for entry in egraph.classes.iter() {
559 let eclass = entry.value();
560 for node in &eclass.nodes {
561 if let IKun::Map(f, x) = node {
562 matches.push((*entry.key(), *f, *x));
563 }
564 }
565 }
566 for (id, f, x) in matches {
567 let tiled_map = egraph.add(IKun::TiledMap(32, f, x));
568 egraph.union(id, tiled_map);
569 }
570 }
571}
572
573pub struct AutoVectorization;
574
575impl<A: Analysis<IKun>> RewriteRule<A> for AutoVectorization {
576 fn name(&self) -> &str {
577 "auto-vectorization"
578 }
579 fn apply(&self, egraph: &EGraph<IKun, A>) {
580 let mut matches = Vec::new();
581 for entry in egraph.classes.iter() {
582 let eclass = entry.value();
583 for node in &eclass.nodes {
584 if let IKun::Map(f, x) = node {
585 matches.push((*entry.key(), *f, *x));
586 }
587 }
588 }
589 for (id, f, x) in matches {
590 let vectorized_map = egraph.add(IKun::VectorizedMap(8, f, x));
591 egraph.union(id, vectorized_map);
592 }
593 }
594}
595
596pub struct GpuSpecialization;
597
598impl<A: Analysis<IKun>> RewriteRule<A> for GpuSpecialization {
599 fn name(&self) -> &str {
600 "gpu-specialization"
601 }
602
603 fn apply(&self, egraph: &EGraph<IKun, A>) {
604 let mut matches = Vec::new();
605 for entry in egraph.classes.iter() {
606 let (&id, eclass) = entry.pair();
607 for node in &eclass.nodes {
608 if let IKun::WithContext(ctx_id, body_id) = node {
609 let is_gpu = egraph.classes.get(ctx_id).map_or(false, |c| {
610 c.nodes.iter().any(|n| matches!(n, IKun::GpuContext))
611 });
612
613 if is_gpu {
614 if let Some(body_class) = egraph.classes.get(body_id) {
615 for body_node in &body_class.nodes {
616 if let IKun::Map(f, x) = body_node {
617 matches.push((id, *f, *x));
618 }
619 }
620 }
621 }
622 }
623 }
624 }
625 }
626}
627
628pub struct CpuSpecialization;
629
630impl<A: Analysis<IKun>> RewriteRule<A> for CpuSpecialization {
631 fn name(&self) -> &str {
632 "cpu-specialization"
633 }
634
635 fn apply(&self, _egraph: &EGraph<IKun, A>) {}
636}
637
638pub struct MapToLoop;
639
640impl<A: Analysis<IKun>> RewriteRule<A> for MapToLoop {
641 fn name(&self) -> &str {
642 "map-to-loop"
643 }
644
645 fn apply(&self, egraph: &EGraph<IKun, A>) {
646 let mut matches = Vec::new();
647 for entry in egraph.classes.iter() {
648 let (&id, eclass) = entry.pair();
649 for node in &eclass.nodes {
650 if let IKun::Map(f, x) = node {
651 matches.push((id, *f, *x));
652 }
653 }
654 }
655
656 for (id, f, x) in matches {
657 let loop_node = egraph.add(IKun::Extension("loop_map".to_string(), vec![f, x]));
658 egraph.union(id, loop_node);
659 }
660
661 let mut filter_matches = Vec::new();
662 for entry in egraph.classes.iter() {
663 let (&id, eclass) = entry.pair();
664 for node in &eclass.nodes {
665 if let IKun::Filter(p, x) = node {
666 filter_matches.push((id, *p, *x));
667 }
668 }
669 }
670
671 for (id, p, x) in filter_matches {
672 let loop_node = egraph.add(IKun::Extension("loop_filter".to_string(), vec![p, x]));
673 egraph.union(id, loop_node);
674 }
675
676 let mut reduce_matches = Vec::new();
677 for entry in egraph.classes.iter() {
678 let (&id, eclass) = entry.pair();
679 for node in &eclass.nodes {
680 if let IKun::Reduce(f, init, x) = node {
681 reduce_matches.push((id, *f, *init, *x));
682 }
683 }
684 }
685
686 for (id, f, init, x) in reduce_matches {
687 let loop_node =
688 egraph.add(IKun::Extension("loop_reduce".to_string(), vec![f, init, x]));
689 egraph.union(id, loop_node);
690 }
691 }
692}
693
694pub fn register_standard_rules<A: Analysis<IKun> + 'static>(registry: &mut RewriteRegistry<A>) {
697 registry.register(RuleCategory::Algebraic, Box::new(ConstantFolding));
699 registry.register(RuleCategory::Algebraic, Box::new(AlgebraicSimplification));
700 registry.register(RuleCategory::Algebraic, Box::new(StrengthReduction));
701 registry.register(RuleCategory::Algebraic, Box::new(Peephole));
702 registry.register(RuleCategory::Algebraic, Box::new(TrapSimplification));
703 registry.register(RuleCategory::Algebraic, Box::new(MetaSimplification));
704 registry.register(RuleCategory::Algebraic, Box::new(ContextSimplification));
705 registry.register(RuleCategory::Algebraic, Box::new(LifeCycleSimplification));
706 registry.register(RuleCategory::Algebraic, Box::new(SafeElimination));
707
708 registry.register(RuleCategory::Architectural, Box::new(MapFusion));
710 registry.register(RuleCategory::Architectural, Box::new(FilterFusion));
711 registry.register(RuleCategory::Architectural, Box::new(FilterMapFusion));
712 registry.register(RuleCategory::Architectural, Box::new(MapFilterFusion));
713 registry.register(RuleCategory::Architectural, Box::new(MapReduceFusion));
714 registry.register(RuleCategory::Architectural, Box::new(FilterReduceFusion));
715 registry.register(RuleCategory::Architectural, Box::new(LayoutTransformation));
716 registry.register(RuleCategory::Architectural, Box::new(LoopTiling));
717 registry.register(RuleCategory::Architectural, Box::new(AutoVectorization));
718
719 registry.register(RuleCategory::Architectural, Box::new(GpuSpecialization));
721 registry.register(RuleCategory::Architectural, Box::new(CpuSpecialization));
722
723 registry.register(RuleCategory::Concretization, Box::new(MapToLoop));
725}
726
727pub struct MetaSimplification;
728
729impl<A: Analysis<IKun>> RewriteRule<A> for MetaSimplification {
730 fn name(&self) -> &str {
731 "meta-simplification"
732 }
733
734 fn apply(&self, egraph: &EGraph<IKun, A>) {
735 let mut matches = Vec::new();
736 for entry in egraph.classes.iter() {
737 let (&id, eclass) = entry.pair();
738 for node in &eclass.nodes {
739 if let IKun::Meta(inner) = node {
740 let inner_root = egraph.union_find.find(*inner);
741
742 if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
744 for inner_node in &inner_eclass.nodes {
745 if let IKun::Meta(_) = inner_node {
746 matches.push((id, *inner));
747 }
748 }
749 }
750
751 if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
753 for inner_node in &inner_eclass.nodes {
754 if let IKun::Constant(_) = inner_node {
755 matches.push((id, *inner));
756 }
757 }
758 }
759 }
760 }
761 }
762
763 for (id, target) in matches {
764 egraph.union(id, target);
765 }
766 }
767}
768
769pub struct ContextSimplification;
770
771impl<A: Analysis<IKun>> RewriteRule<A> for ContextSimplification {
772 fn name(&self) -> &str {
773 "context-simplification"
774 }
775
776 fn apply(&self, egraph: &EGraph<IKun, A>) {
777 let mut matches = Vec::new();
778 for entry in egraph.classes.iter() {
779 let (&id, eclass) = entry.pair();
780 for node in &eclass.nodes {
781 if let IKun::WithContext(ctx, inner) = node {
782 let inner_root = egraph.union_find.find(*inner);
783
784 if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
786 for inner_node in &inner_eclass.nodes {
787 if let IKun::WithContext(inner_ctx, _) = inner_node {
788 if ctx == inner_ctx {
789 matches.push((id, *inner));
790 }
791 }
792 }
793 }
794 }
795 }
796 }
797
798 for (id, target) in matches {
799 egraph.union(id, target);
800 }
801 }
802}
803
804pub struct LifeCycleSimplification;
805
806impl<A: Analysis<IKun>> RewriteRule<A> for LifeCycleSimplification {
807 fn name(&self) -> &str {
808 "lifecycle-simplification"
809 }
810 fn apply(&self, egraph: &EGraph<IKun, A>) {
811 let mut matches = Vec::new();
812 for entry in egraph.classes.iter() {
813 let (&id, eclass) = entry.pair();
814 for node in &eclass.nodes {
815 if let IKun::LifeCycle(setup, cleanup) = node {
816 let setup_root = egraph.union_find.find(*setup);
817
818 if let Some(setup_eclass) = egraph.classes.get(&setup_root) {
820 for setup_node in &setup_eclass.nodes {
821 if let IKun::LifeCycle(s1, c1) = setup_node {
822 matches.push((id, *s1, *c1, *cleanup));
823 }
824 }
825 }
826
827 let cleanup_root = egraph.union_find.find(*cleanup);
830 if let Some(cleanup_eclass) = egraph.classes.get(&cleanup_root) {
831 for cleanup_node in &cleanup_eclass.nodes {
832 if let IKun::Seq(items) = cleanup_node {
833 if items.is_empty() {
834 matches.push((id, *setup, *setup, *setup)); }
836 }
837 }
838 }
839 }
840 }
841 }
842
843 for (id, s1, c1, c2) in matches {
844 if s1 == c1 && c1 == c2 {
845 egraph.union(id, s1);
847 } else {
848 let new_cleanup = egraph.add(IKun::Seq(vec![c1, c2]));
850 let new_lifecycle = egraph.add(IKun::LifeCycle(s1, new_cleanup));
851 egraph.union(id, new_lifecycle);
852 }
853 }
854 }
855}
856
857pub struct SafeElimination;
858
859impl<A: Analysis<IKun>> RewriteRule<A> for SafeElimination {
860 fn name(&self) -> &str {
861 "safe-elimination"
862 }
863 fn apply(&self, egraph: &EGraph<IKun, A>) {
864 let mut matches = Vec::new();
865 for entry in egraph.classes.iter() {
866 let (&id, eclass) = entry.pair();
867 for node in &eclass.nodes {
868 if let IKun::Trap(inner) = node {
869 let inner_root = egraph.union_find.find(*inner);
870 if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
871 for inner_node in &inner_eclass.nodes {
872 if let IKun::WithContext(ctx_id, body_id) = inner_node {
873 let ctx_root = egraph.union_find.find(*ctx_id);
874 if let Some(ctx_eclass) = egraph.classes.get(&ctx_root) {
875 if ctx_eclass
876 .nodes
877 .iter()
878 .any(|n| matches!(n, IKun::SafeContext))
879 {
880 matches.push((id, *body_id));
881 }
882 }
883 }
884 }
885 }
886 }
887 }
888 }
889 for (id, target) in matches {
890 egraph.union(id, target);
891 }
892 }
893}