Skip to main content

chomsky_rules/
lib.rs

1#![warn(missing_docs)]
2
3use chomsky_rule_engine::{RewriteRegistry, RewriteRule, RuleCategory};
4use chomsky_uir::egraph::{Analysis, EGraph};
5use chomsky_uir::{IKun, Id};
6
7pub struct ConstantFolding;
8
9impl<A: Analysis<IKun>> RewriteRule<A> for ConstantFolding {
10    fn name(&self) -> &str {
11        "constant-folding"
12    }
13
14    fn apply(&self, egraph: &EGraph<IKun, A>) {
15        let mut matches = Vec::new();
16        for entry in egraph.classes.iter() {
17            let (&id, eclass) = entry.pair();
18            for node in &eclass.nodes {
19                if let IKun::Extension(op, args) = node {
20                    if args.len() == 2 {
21                        let arg1_root = egraph.union_find.find(args[0]);
22                        let arg2_root = egraph.union_find.find(args[1]);
23
24                        let arg1_const = get_const(egraph, arg1_root);
25                        let arg2_const = get_const(egraph, arg2_root);
26
27                        if let (Some(v1), Some(v2)) = (arg1_const, arg2_const) {
28                            match op.as_str() {
29                                "add" => matches.push((id, v1 + v2)),
30                                "sub" => matches.push((id, v1 - v2)),
31                                "mul" => matches.push((id, v1 * v2)),
32                                "div" if v2 != 0 => matches.push((id, v1 / v2)),
33                                _ => {}
34                            }
35                        }
36                    }
37                }
38            }
39        }
40
41        for (id, result) in matches {
42            let const_id = egraph.add(IKun::Constant(result));
43            egraph.union(id, const_id);
44        }
45    }
46}
47
48fn get_bool<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id) -> Option<bool> {
49    let root = egraph.union_find.find(id);
50    let eclass = egraph.classes.get(&root)?;
51    for node in &eclass.nodes {
52        if let IKun::BooleanConstant(v) = node {
53            return Some(*v);
54        }
55    }
56    None
57}
58
59pub struct AlgebraicSimplification;
60
61impl<A: Analysis<IKun>> RewriteRule<A> for AlgebraicSimplification {
62    fn name(&self) -> &str {
63        "algebraic-simplification"
64    }
65
66    fn apply(&self, egraph: &EGraph<IKun, A>) {
67        let mut matches = Vec::new();
68        for entry in egraph.classes.iter() {
69            let id = *entry.key();
70            let eclass = entry.value();
71            for node in &eclass.nodes {
72                if let IKun::Extension(op, args) = node {
73                    if args.len() == 2 {
74                        let arg1_root = egraph.union_find.find(args[0]);
75                        let arg2_root = egraph.union_find.find(args[1]);
76
77                        match op.as_str() {
78                            "add" => {
79                                // x + 0 = x
80                                if is_const(egraph, arg2_root, 0) {
81                                    matches.push((id, args[0]));
82                                }
83                                // 0 + x = x
84                                else if is_const(egraph, arg1_root, 0) {
85                                    matches.push((id, args[1]));
86                                }
87                            }
88                            "sub" => {
89                                // x - 0 = x
90                                if is_const(egraph, arg2_root, 0) {
91                                    matches.push((id, args[0]));
92                                }
93                                // x - x = 0
94                                else if arg1_root == arg2_root {
95                                    let zero_id = egraph.add(IKun::Constant(0));
96                                    matches.push((id, zero_id));
97                                }
98                            }
99                            "mul" => {
100                                // x * 1 = x
101                                if is_const(egraph, arg2_root, 1) {
102                                    matches.push((id, args[0]));
103                                }
104                                // 1 * x = x
105                                else if is_const(egraph, arg1_root, 1) {
106                                    matches.push((id, args[1]));
107                                }
108                                // x * 0 = 0
109                                else if is_const(egraph, arg2_root, 0) {
110                                    matches.push((id, arg2_root));
111                                }
112                                // 0 * x = 0
113                                else if is_const(egraph, arg1_root, 0) {
114                                    matches.push((id, arg1_root));
115                                }
116                            }
117                            "div" => {
118                                // x / 1 = x
119                                if is_const(egraph, arg2_root, 1) {
120                                    matches.push((id, args[0]));
121                                }
122                                // x / x = 1 (if x != 0)
123                                else if arg1_root == arg2_root {
124                                    // In a real compiler we'd check if x can be 0.
125                                    // For simplicity in this demo, we assume safety or handle it via analysis.
126                                    let one_id = egraph.add(IKun::Constant(1));
127                                    matches.push((id, one_id));
128                                }
129                            }
130                            _ => {}
131                        }
132                    }
133                }
134            }
135        }
136
137        for (id, target) in matches {
138            egraph.union(id, target);
139        }
140    }
141}
142
143pub struct TrapSimplification;
144
145impl<A: Analysis<IKun>> RewriteRule<A> for TrapSimplification {
146    fn name(&self) -> &str {
147        "trap-simplification"
148    }
149
150    fn apply(&self, egraph: &EGraph<IKun, A>) {
151        let mut matches = Vec::new();
152        for entry in egraph.classes.iter() {
153            let (&id, eclass) = entry.pair();
154            for node in &eclass.nodes {
155                if let IKun::Trap(inner) = node {
156                    let inner_root = egraph.union_find.find(*inner);
157
158                    // Trap(Trap(x)) -> Trap(x)
159                    if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
160                        for inner_node in &inner_eclass.nodes {
161                            if let IKun::Trap(_) = inner_node {
162                                matches.push((id, *inner));
163                            }
164                        }
165                    }
166                }
167            }
168        }
169
170        for (id, target) in matches {
171            egraph.union(id, target);
172        }
173    }
174}
175
176pub struct UniversalSemanticOptimization;
177
178impl<A: Analysis<IKun>> RewriteRule<A> for UniversalSemanticOptimization {
179    fn name(&self) -> &str {
180        "universal-semantic-optimization"
181    }
182
183    fn apply(&self, egraph: &EGraph<IKun, A>) {
184        let mut matches = Vec::new();
185        for entry in egraph.classes.iter() {
186            let (&id, eclass) = entry.pair();
187            for node in &eclass.nodes {
188                match node {
189                    // 1. Redundant Context elimination: WithContext(ctx, WithContext(ctx, x)) -> WithContext(ctx, x)
190                    IKun::WithContext(ctx_id, inner_id) => {
191                        let ctx_root = egraph.union_find.find(*ctx_id);
192                        let inner_root = egraph.union_find.find(*inner_id);
193
194                        if let Some(inner_class) = egraph.classes.get(&inner_root) {
195                            for inner_node in &inner_class.nodes {
196                                if let IKun::WithContext(nested_ctx_id, _nested_inner_id) =
197                                    inner_node
198                                {
199                                    if egraph.union_find.find(*nested_ctx_id) == ctx_root {
200                                        // Found redundant context
201                                        matches.push((id, *inner_id));
202                                    }
203                                }
204                            }
205                        }
206                    }
207
208                    // 2. Comptime evaluation: WithContext(ComptimeContext, x) -> x (once evaluated)
209                    // In a real system, this would trigger the actual comptime engine.
210                    // Here we just model the fact that it simplifies if evaluated.
211
212                    // 3. Defer normalization: WithContext(DeferContext, List(a, b, Defer(c)))
213                    // This is more complex and depends on how we lower defer.
214                    _ => {}
215                }
216            }
217        }
218
219        for (id, result) in matches {
220            egraph.union(id, result);
221        }
222    }
223}
224
225fn get_const<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id) -> Option<i64> {
226    let root = egraph.union_find.find(id);
227    egraph.classes.get(&root).and_then(|c| {
228        c.nodes.iter().find_map(|n| {
229            if let IKun::Constant(v) = n {
230                Some(*v)
231            } else {
232                None
233            }
234        })
235    })
236}
237
238fn is_const<A: Analysis<IKun>>(egraph: &EGraph<IKun, A>, id: Id, val: i64) -> bool {
239    get_const(egraph, id) == Some(val)
240}
241
242pub struct StrengthReduction;
243
244impl<A: Analysis<IKun>> RewriteRule<A> for StrengthReduction {
245    fn name(&self) -> &str {
246        "strength-reduction"
247    }
248    fn apply(&self, egraph: &EGraph<IKun, A>) {
249        let mut matches = Vec::new();
250        for entry in egraph.classes.iter() {
251            let (&id, eclass) = entry.pair();
252            for node in &eclass.nodes {
253                if let IKun::Extension(op, args) = node {
254                    if args.len() == 2 {
255                        if let Some(val) = get_const(egraph, args[1]) {
256                            match op.as_str() {
257                                "mul" if val > 0 && (val & (val - 1)) == 0 => {
258                                    let n = (val as f64).log2() as i64;
259                                    matches.push((id, "shl", args[0], n));
260                                }
261                                "div" if val > 0 && (val & (val - 1)) == 0 => {
262                                    let n = (val as f64).log2() as i64;
263                                    matches.push((id, "shr", args[0], n));
264                                }
265                                _ => {}
266                            }
267                        }
268                    }
269                }
270            }
271        }
272        for (id, op, arg, n) in matches {
273            let n_id = egraph.add(IKun::Constant(n));
274            let new_id = egraph.add(IKun::Extension(op.to_string(), vec![arg, n_id]));
275            egraph.union(id, new_id);
276        }
277    }
278}
279
280pub struct Peephole;
281
282impl<A: Analysis<IKun>> RewriteRule<A> for Peephole {
283    fn name(&self) -> &str {
284        "peephole"
285    }
286    fn apply(&self, egraph: &EGraph<IKun, A>) {
287        let mut matches = Vec::new();
288        for entry in egraph.classes.iter() {
289            let (&id, eclass) = entry.pair();
290            for node in &eclass.nodes {
291                if let IKun::Extension(op, args) = node {
292                    if args.len() == 2 {
293                        let arg1_root = egraph.union_find.find(args[0]);
294                        let arg2_root = egraph.union_find.find(args[1]);
295                        match op.as_str() {
296                            "add" if arg1_root == arg2_root => {
297                                matches.push((id, "mul", args[0], 2));
298                            }
299                            _ => {}
300                        }
301                    }
302                }
303            }
304        }
305        for (id, op, arg, val) in matches {
306            let val_id = egraph.add(IKun::Constant(val));
307            let new_id = egraph.add(IKun::Extension(op.to_string(), vec![arg, val_id]));
308            egraph.union(id, new_id);
309        }
310    }
311}
312
313pub struct MapFusion;
314
315impl<A: Analysis<IKun>> RewriteRule<A> for MapFusion {
316    fn name(&self) -> &str {
317        "map-fusion"
318    }
319
320    fn apply(&self, egraph: &EGraph<IKun, A>) {
321        let mut matches = Vec::new();
322        for entry in egraph.classes.iter() {
323            let (&id, eclass) = entry.pair();
324            for node in &eclass.nodes {
325                if let IKun::Map(f, inner_id) = node {
326                    let inner_root = egraph.union_find.find(*inner_id);
327                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
328                        for inner_node in &inner_class.nodes {
329                            if let IKun::Map(g, x) = inner_node {
330                                matches.push((id, *f, *g, *x));
331                            }
332                        }
333                    }
334                }
335            }
336        }
337
338        for (id, f, g, x) in matches {
339            let seq_id = egraph.add(IKun::Compose(f, g));
340            let new_map_id = egraph.add(IKun::Map(seq_id, x));
341            egraph.union(id, new_map_id);
342        }
343    }
344}
345
346pub struct FilterFusion;
347
348impl<A: Analysis<IKun>> RewriteRule<A> for FilterFusion {
349    fn name(&self) -> &str {
350        "filter-fusion"
351    }
352
353    fn apply(&self, egraph: &EGraph<IKun, A>) {
354        let mut matches = Vec::new();
355        for entry in egraph.classes.iter() {
356            let (&id, eclass) = entry.pair();
357            for node in &eclass.nodes {
358                if let IKun::Filter(p1, inner_id) = node {
359                    let inner_root = egraph.union_find.find(*inner_id);
360                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
361                        for inner_node in &inner_class.nodes {
362                            if let IKun::Filter(p2, x) = inner_node {
363                                matches.push((id, *p1, *p2, *x));
364                            }
365                        }
366                    }
367                }
368            }
369        }
370
371        for (id, p1, p2, x) in matches {
372            let combined_p = egraph.add(IKun::Extension("and_predicate".to_string(), vec![p2, p1]));
373            let new_filter_id = egraph.add(IKun::Filter(combined_p, x));
374            egraph.union(id, new_filter_id);
375        }
376    }
377}
378
379pub struct FilterMapFusion;
380
381impl<A: Analysis<IKun>> RewriteRule<A> for FilterMapFusion {
382    fn name(&self) -> &str {
383        "filter-map-fusion"
384    }
385
386    fn apply(&self, egraph: &EGraph<IKun, A>) {
387        let mut matches = Vec::new();
388        for entry in egraph.classes.iter() {
389            let (&id, eclass) = entry.pair();
390            for node in &eclass.nodes {
391                if let IKun::Map(f, inner_id) = node {
392                    let inner_root = egraph.union_find.find(*inner_id);
393                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
394                        for inner_node in &inner_class.nodes {
395                            if let IKun::Filter(p, x) = inner_node {
396                                matches.push((id, *f, *p, *x));
397                            }
398                        }
399                    }
400                }
401            }
402        }
403
404        for (id, f, p, x) in matches {
405            let fm_node = egraph.add(IKun::Extension("filter_map".to_string(), vec![f, p, x]));
406            egraph.union(id, fm_node);
407        }
408    }
409}
410
411pub struct MapFilterFusion;
412
413impl<A: Analysis<IKun>> RewriteRule<A> for MapFilterFusion {
414    fn name(&self) -> &str {
415        "map-filter-fusion"
416    }
417
418    fn apply(&self, egraph: &EGraph<IKun, A>) {
419        let mut matches = Vec::new();
420        for entry in egraph.classes.iter() {
421            let (&id, eclass) = entry.pair();
422            for node in &eclass.nodes {
423                if let IKun::Filter(p, inner_id) = node {
424                    let inner_root = egraph.union_find.find(*inner_id);
425                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
426                        for inner_node in &inner_class.nodes {
427                            if let IKun::Map(f, x) = inner_node {
428                                matches.push((id, *p, *f, *x));
429                            }
430                        }
431                    }
432                }
433            }
434        }
435
436        for (id, p, f, x) in matches {
437            let mf_node = egraph.add(IKun::Extension("map_filter".to_string(), vec![p, f, x]));
438            egraph.union(id, mf_node);
439        }
440    }
441}
442
443pub struct MapReduceFusion;
444
445impl<A: Analysis<IKun>> RewriteRule<A> for MapReduceFusion {
446    fn name(&self) -> &str {
447        "map-reduce-fusion"
448    }
449
450    fn apply(&self, egraph: &EGraph<IKun, A>) {
451        let mut matches = Vec::new();
452        for entry in egraph.classes.iter() {
453            let (&id, eclass) = entry.pair();
454            for node in &eclass.nodes {
455                if let IKun::Reduce(g, init, inner_id) = node {
456                    let inner_root = egraph.union_find.find(*inner_id);
457                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
458                        for inner_node in &inner_class.nodes {
459                            if let IKun::Map(f, x) = inner_node {
460                                matches.push((id, *f, *g, *init, *x));
461                            }
462                        }
463                    }
464                }
465            }
466        }
467
468        for (id, f, g, init, x) in matches {
469            let fused_node = egraph.add(IKun::Extension(
470                "loop_map_reduce".to_string(),
471                vec![f, g, init, x],
472            ));
473            egraph.union(id, fused_node);
474        }
475    }
476}
477
478pub struct FilterReduceFusion;
479
480impl<A: Analysis<IKun>> RewriteRule<A> for FilterReduceFusion {
481    fn name(&self) -> &str {
482        "filter-reduce-fusion"
483    }
484
485    fn apply(&self, egraph: &EGraph<IKun, A>) {
486        let mut matches = Vec::new();
487        for entry in egraph.classes.iter() {
488            let (&id, eclass) = entry.pair();
489            for node in &eclass.nodes {
490                if let IKun::Reduce(f, init, inner_id) = node {
491                    let inner_root = egraph.union_find.find(*inner_id);
492                    if let Some(inner_class) = egraph.classes.get(&inner_root) {
493                        for inner_node in &inner_class.nodes {
494                            if let IKun::Filter(p, x) = inner_node {
495                                matches.push((id, *p, *f, *init, *x));
496                            }
497                        }
498                    }
499                }
500            }
501        }
502
503        for (id, p, f, init, x) in matches {
504            let fused_node = egraph.add(IKun::Extension(
505                "loop_filter_reduce".to_string(),
506                vec![p, f, init, x],
507            ));
508            egraph.union(id, fused_node);
509        }
510    }
511}
512
513pub struct LayoutTransformation;
514
515impl<A: Analysis<IKun>> RewriteRule<A> for LayoutTransformation {
516    fn name(&self) -> &str {
517        "layout-transformation"
518    }
519    fn apply(&self, egraph: &EGraph<IKun, A>) {
520        let mut matches = Vec::new();
521        for entry in egraph.classes.iter() {
522            let eclass = entry.value();
523            for node in &eclass.nodes {
524                if let IKun::WithContext(ctx_id, body_id) = node {
525                    let is_spatial = egraph.classes.get(ctx_id).map_or(false, |c| {
526                        c.nodes
527                            .iter()
528                            .any(|n| matches!(n, IKun::SpatialContext | IKun::GpuContext))
529                    });
530
531                    if is_spatial {
532                        if let Some(body_class) = egraph.classes.get(body_id) {
533                            for body_node in &body_class.nodes {
534                                if let IKun::Map(f, x) = body_node {
535                                    matches.push((*body_id, *f, *x));
536                                }
537                            }
538                        }
539                    }
540                }
541            }
542        }
543        for (id, f, x) in matches {
544            let soa_map = egraph.add(IKun::SoAMap(f, x));
545            egraph.union(id, soa_map);
546        }
547    }
548}
549
550pub struct LoopTiling;
551
552impl<A: Analysis<IKun>> RewriteRule<A> for LoopTiling {
553    fn name(&self) -> &str {
554        "loop-tiling"
555    }
556    fn apply(&self, egraph: &EGraph<IKun, A>) {
557        let mut matches = Vec::new();
558        for entry in egraph.classes.iter() {
559            let eclass = entry.value();
560            for node in &eclass.nodes {
561                if let IKun::Map(f, x) = node {
562                    matches.push((*entry.key(), *f, *x));
563                }
564            }
565        }
566        for (id, f, x) in matches {
567            let tiled_map = egraph.add(IKun::TiledMap(32, f, x));
568            egraph.union(id, tiled_map);
569        }
570    }
571}
572
573pub struct AutoVectorization;
574
575impl<A: Analysis<IKun>> RewriteRule<A> for AutoVectorization {
576    fn name(&self) -> &str {
577        "auto-vectorization"
578    }
579    fn apply(&self, egraph: &EGraph<IKun, A>) {
580        let mut matches = Vec::new();
581        for entry in egraph.classes.iter() {
582            let eclass = entry.value();
583            for node in &eclass.nodes {
584                if let IKun::Map(f, x) = node {
585                    matches.push((*entry.key(), *f, *x));
586                }
587            }
588        }
589        for (id, f, x) in matches {
590            let vectorized_map = egraph.add(IKun::VectorizedMap(8, f, x));
591            egraph.union(id, vectorized_map);
592        }
593    }
594}
595
596pub struct GpuSpecialization;
597
598impl<A: Analysis<IKun>> RewriteRule<A> for GpuSpecialization {
599    fn name(&self) -> &str {
600        "gpu-specialization"
601    }
602
603    fn apply(&self, egraph: &EGraph<IKun, A>) {
604        let mut matches = Vec::new();
605        for entry in egraph.classes.iter() {
606            let (&id, eclass) = entry.pair();
607            for node in &eclass.nodes {
608                if let IKun::WithContext(ctx_id, body_id) = node {
609                    let is_gpu = egraph.classes.get(ctx_id).map_or(false, |c| {
610                        c.nodes.iter().any(|n| matches!(n, IKun::GpuContext))
611                    });
612
613                    if is_gpu {
614                        if let Some(body_class) = egraph.classes.get(body_id) {
615                            for body_node in &body_class.nodes {
616                                if let IKun::Map(f, x) = body_node {
617                                    matches.push((id, *f, *x));
618                                }
619                            }
620                        }
621                    }
622                }
623            }
624        }
625    }
626}
627
628pub struct CpuSpecialization;
629
630impl<A: Analysis<IKun>> RewriteRule<A> for CpuSpecialization {
631    fn name(&self) -> &str {
632        "cpu-specialization"
633    }
634
635    fn apply(&self, _egraph: &EGraph<IKun, A>) {}
636}
637
638pub struct MapToLoop;
639
640impl<A: Analysis<IKun>> RewriteRule<A> for MapToLoop {
641    fn name(&self) -> &str {
642        "map-to-loop"
643    }
644
645    fn apply(&self, egraph: &EGraph<IKun, A>) {
646        let mut matches = Vec::new();
647        for entry in egraph.classes.iter() {
648            let (&id, eclass) = entry.pair();
649            for node in &eclass.nodes {
650                if let IKun::Map(f, x) = node {
651                    matches.push((id, *f, *x));
652                }
653            }
654        }
655
656        for (id, f, x) in matches {
657            let loop_node = egraph.add(IKun::Extension("loop_map".to_string(), vec![f, x]));
658            egraph.union(id, loop_node);
659        }
660
661        let mut filter_matches = Vec::new();
662        for entry in egraph.classes.iter() {
663            let (&id, eclass) = entry.pair();
664            for node in &eclass.nodes {
665                if let IKun::Filter(p, x) = node {
666                    filter_matches.push((id, *p, *x));
667                }
668            }
669        }
670
671        for (id, p, x) in filter_matches {
672            let loop_node = egraph.add(IKun::Extension("loop_filter".to_string(), vec![p, x]));
673            egraph.union(id, loop_node);
674        }
675
676        let mut reduce_matches = Vec::new();
677        for entry in egraph.classes.iter() {
678            let (&id, eclass) = entry.pair();
679            for node in &eclass.nodes {
680                if let IKun::Reduce(f, init, x) = node {
681                    reduce_matches.push((id, *f, *init, *x));
682                }
683            }
684        }
685
686        for (id, f, init, x) in reduce_matches {
687            let loop_node =
688                egraph.add(IKun::Extension("loop_reduce".to_string(), vec![f, init, x]));
689            egraph.union(id, loop_node);
690        }
691    }
692}
693
/// Registers all standard optimization rules on `registry`.
///
/// The rules cover general algebraic optimizations, functional fusion, and
/// general programming-language (PL) semantic optimizations. Each rule is
/// registered under the `RuleCategory` shown next to it, in the order listed.
pub fn register_standard_rules<A: Analysis<IKun> + 'static>(registry: &mut RewriteRegistry<A>) {
    // 1. Algebraic rules
    registry.register(RuleCategory::Algebraic, Box::new(ConstantFolding));
    registry.register(RuleCategory::Algebraic, Box::new(AlgebraicSimplification));
    registry.register(RuleCategory::Algebraic, Box::new(StrengthReduction));
    registry.register(RuleCategory::Algebraic, Box::new(Peephole));
    registry.register(RuleCategory::Algebraic, Box::new(TrapSimplification));
    registry.register(RuleCategory::Algebraic, Box::new(MetaSimplification));
    registry.register(RuleCategory::Algebraic, Box::new(ContextSimplification));
    registry.register(RuleCategory::Algebraic, Box::new(LifeCycleSimplification));
    registry.register(RuleCategory::Algebraic, Box::new(SafeElimination));

    // 2. Architectural & fusion rules
    registry.register(RuleCategory::Architectural, Box::new(MapFusion));
    registry.register(RuleCategory::Architectural, Box::new(FilterFusion));
    registry.register(RuleCategory::Architectural, Box::new(FilterMapFusion));
    registry.register(RuleCategory::Architectural, Box::new(MapFilterFusion));
    registry.register(RuleCategory::Architectural, Box::new(MapReduceFusion));
    registry.register(RuleCategory::Architectural, Box::new(FilterReduceFusion));
    registry.register(RuleCategory::Architectural, Box::new(LayoutTransformation));
    registry.register(RuleCategory::Architectural, Box::new(LoopTiling));
    registry.register(RuleCategory::Architectural, Box::new(AutoVectorization));

    // 3. Specialization rules (registered under the Architectural category)
    registry.register(RuleCategory::Architectural, Box::new(GpuSpecialization));
    registry.register(RuleCategory::Architectural, Box::new(CpuSpecialization));

    // 4. Concretization rules
    registry.register(RuleCategory::Concretization, Box::new(MapToLoop));
}
726
727pub struct MetaSimplification;
728
729impl<A: Analysis<IKun>> RewriteRule<A> for MetaSimplification {
730    fn name(&self) -> &str {
731        "meta-simplification"
732    }
733
734    fn apply(&self, egraph: &EGraph<IKun, A>) {
735        let mut matches = Vec::new();
736        for entry in egraph.classes.iter() {
737            let (&id, eclass) = entry.pair();
738            for node in &eclass.nodes {
739                if let IKun::Meta(inner) = node {
740                    let inner_root = egraph.union_find.find(*inner);
741
742                    // 1. Meta(Meta(x)) -> Meta(x)
743                    if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
744                        for inner_node in &inner_eclass.nodes {
745                            if let IKun::Meta(_) = inner_node {
746                                matches.push((id, *inner));
747                            }
748                        }
749                    }
750
751                    // 2. Meta(Constant(x)) -> Constant(x)
752                    if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
753                        for inner_node in &inner_eclass.nodes {
754                            if let IKun::Constant(_) = inner_node {
755                                matches.push((id, *inner));
756                            }
757                        }
758                    }
759                }
760            }
761        }
762
763        for (id, target) in matches {
764            egraph.union(id, target);
765        }
766    }
767}
768
769pub struct ContextSimplification;
770
771impl<A: Analysis<IKun>> RewriteRule<A> for ContextSimplification {
772    fn name(&self) -> &str {
773        "context-simplification"
774    }
775
776    fn apply(&self, egraph: &EGraph<IKun, A>) {
777        let mut matches = Vec::new();
778        for entry in egraph.classes.iter() {
779            let (&id, eclass) = entry.pair();
780            for node in &eclass.nodes {
781                if let IKun::WithContext(ctx, inner) = node {
782                    let inner_root = egraph.union_find.find(*inner);
783
784                    // 1. Redundant Context: WithContext(ctx, WithContext(ctx, x)) -> WithContext(ctx, x)
785                    if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
786                        for inner_node in &inner_eclass.nodes {
787                            if let IKun::WithContext(inner_ctx, _) = inner_node {
788                                if ctx == inner_ctx {
789                                    matches.push((id, *inner));
790                                }
791                            }
792                        }
793                    }
794                }
795            }
796        }
797
798        for (id, target) in matches {
799            egraph.union(id, target);
800        }
801    }
802}
803
804pub struct LifeCycleSimplification;
805
806impl<A: Analysis<IKun>> RewriteRule<A> for LifeCycleSimplification {
807    fn name(&self) -> &str {
808        "lifecycle-simplification"
809    }
810    fn apply(&self, egraph: &EGraph<IKun, A>) {
811        let mut matches = Vec::new();
812        for entry in egraph.classes.iter() {
813            let (&id, eclass) = entry.pair();
814            for node in &eclass.nodes {
815                if let IKun::LifeCycle(setup, cleanup) = node {
816                    let setup_root = egraph.union_find.find(*setup);
817
818                    // 1. LifeCycle(LifeCycle(s1, c1), c2) -> LifeCycle(s1, Seq([c1, c2]))
819                    if let Some(setup_eclass) = egraph.classes.get(&setup_root) {
820                        for setup_node in &setup_eclass.nodes {
821                            if let IKun::LifeCycle(s1, c1) = setup_node {
822                                matches.push((id, *s1, *c1, *cleanup));
823                            }
824                        }
825                    }
826
827                    // 2. LifeCycle(Constant(x), c) -> Constant(x) if we can prove c is not needed or side-effect free.
828                    // For now, let's just do a simple version: LifeCycle(x, EmptySeq) -> x
829                    let cleanup_root = egraph.union_find.find(*cleanup);
830                    if let Some(cleanup_eclass) = egraph.classes.get(&cleanup_root) {
831                        for cleanup_node in &cleanup_eclass.nodes {
832                            if let IKun::Seq(items) = cleanup_node {
833                                if items.is_empty() {
834                                    matches.push((id, *setup, *setup, *setup)); // dummy target to signal simplification
835                                }
836                            }
837                        }
838                    }
839                }
840            }
841        }
842
843        for (id, s1, c1, c2) in matches {
844            if s1 == c1 && c1 == c2 {
845                // Simplification case 2
846                egraph.union(id, s1);
847            } else {
848                // Simplification case 1
849                let new_cleanup = egraph.add(IKun::Seq(vec![c1, c2]));
850                let new_lifecycle = egraph.add(IKun::LifeCycle(s1, new_cleanup));
851                egraph.union(id, new_lifecycle);
852            }
853        }
854    }
855}
856
857pub struct SafeElimination;
858
859impl<A: Analysis<IKun>> RewriteRule<A> for SafeElimination {
860    fn name(&self) -> &str {
861        "safe-elimination"
862    }
863    fn apply(&self, egraph: &EGraph<IKun, A>) {
864        let mut matches = Vec::new();
865        for entry in egraph.classes.iter() {
866            let (&id, eclass) = entry.pair();
867            for node in &eclass.nodes {
868                if let IKun::Trap(inner) = node {
869                    let inner_root = egraph.union_find.find(*inner);
870                    if let Some(inner_eclass) = egraph.classes.get(&inner_root) {
871                        for inner_node in &inner_eclass.nodes {
872                            if let IKun::WithContext(ctx_id, body_id) = inner_node {
873                                let ctx_root = egraph.union_find.find(*ctx_id);
874                                if let Some(ctx_eclass) = egraph.classes.get(&ctx_root) {
875                                    if ctx_eclass
876                                        .nodes
877                                        .iter()
878                                        .any(|n| matches!(n, IKun::SafeContext))
879                                    {
880                                        matches.push((id, *body_id));
881                                    }
882                                }
883                            }
884                        }
885                    }
886                }
887            }
888        }
889        for (id, target) in matches {
890            egraph.union(id, target);
891        }
892    }
893}