meganeura 0.2.0

E-graph optimized neural network training on Blade
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
use crate::graph::{Graph, Node, Op};
use egglog::{Term, TermDag, TermId};
use std::collections::HashMap;
use std::{fmt, time::Instant};

/// Report from the e-graph optimization pass.
///
/// Produced by `optimize_with_report`. The `Display` impl renders a
/// multi-line human-readable summary without a trailing newline. Sections
/// for fired rules and applied fusions are omitted when they are empty.
#[derive(Debug, Clone, Default)]
pub struct OptimizeReport {
    /// The egglog program text (for external inspection / replay).
    pub egglog_program: String,
    /// Number of e-classes after saturation.
    pub num_eclasses: usize,
    /// Number of e-nodes after saturation.
    pub num_enodes: usize,
    /// Which rewrite rules fired and how many times.
    pub rules_fired: Vec<(String, usize)>,
    /// Graph node count before optimization.
    pub nodes_before: usize,
    /// Graph node count after optimization (excluding Nop).
    pub nodes_after: usize,
    /// Fusions applied: list of (fusion_name, node_index) pairs.
    pub fusions_applied: Vec<(String, u32)>,
    /// Wall-clock time for egglog saturation.
    pub egglog_time: std::time::Duration,
    /// Wall-clock time for graph extraction + fusion rewrites.
    pub extract_time: std::time::Duration,
}

impl fmt::Display for OptimizeReport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Durations are shown in milliseconds with one decimal place.
        let millis = |d: std::time::Duration| d.as_secs_f64() * 1000.0;

        writeln!(f, "=== Optimization Report ===")?;
        writeln!(
            f,
            "Egglog saturation: {:.1}ms ({} e-classes, {} e-nodes)",
            millis(self.egglog_time),
            self.num_eclasses,
            self.num_enodes,
        )?;
        if !self.rules_fired.is_empty() {
            writeln!(f, "Rules fired:")?;
            for (rule, count) in &self.rules_fired {
                writeln!(f, "  {}  x{}", rule, count)?;
            }
        }
        // saturating_sub: "fused away" never goes negative even if the
        // counts were filled in inconsistently.
        writeln!(
            f,
            "Graph: {} nodes -> {} active nodes ({} fused away)",
            self.nodes_before,
            self.nodes_after,
            self.nodes_before.saturating_sub(self.nodes_after),
        )?;
        if !self.fusions_applied.is_empty() {
            write!(f, "Fusions:")?;
            for (i, (name, node_idx)) in self.fusions_applied.iter().enumerate() {
                let sep = if i == 0 { "" } else { "," };
                write!(f, "{} {} @node{}", sep, name, node_idx)?;
            }
            writeln!(f)?;
        }
        write!(f, "Extract time: {:.1}ms", millis(self.extract_time))
    }
}

/// Optimize `graph` by encoding it as an egglog program, running equality
/// saturation with the rewrite rules, and rebuilding the fused graph.
///
/// Thin wrapper over `optimize_with_report` that drops the report.
pub fn optimize(graph: &Graph) -> Graph {
    optimize_with_report(graph).0
}

/// Like `optimize`, but also returns a detailed report for debugging.
///
/// Graphs with more than `EGGLOG_MAX_NODES` active (non-`Nop`) nodes skip
/// egglog saturation. They rely solely on the direct pattern-matching fusion
/// passes, so their report has zero e-graph statistics and a zero
/// `egglog_time`.
///
/// If egglog returns an error, a warning is logged and the pattern-matching
/// passes still run. In every case `nodes_after` counts the active nodes of
/// the returned graph.
pub fn optimize_with_report(graph: &Graph) -> (Graph, OptimizeReport) {
    // egglog saturation time grows superlinearly with node count. Graphs with
    // shared parameters (e.g. transformers, where one weight feeds many
    // MatMul nodes) also create large e-classes that make pattern matching
    // slow. For the SmolVLA training graph (~750 nodes) saturation takes
    // minutes. 300 nodes handles most small models; larger graphs fall back
    // to direct pattern matching.
    // TODO: investigate egglog performance or use incremental matching.
    const EGGLOG_MAX_NODES: usize = 300;

    // Number of nodes that will actually execute (Nop placeholders excluded).
    let count_active = |g: &Graph| {
        g.nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .count()
    };

    let program = graph_to_egglog(graph);
    log::debug!("egglog program:\n{}", program);

    let nodes_before = graph.nodes().len();
    let node_count = count_active(graph);

    let mut num_eclasses = 0;
    let mut num_enodes = 0;
    let mut extractions: Vec<(TermDag, TermId)> = Vec::new();
    let egglog_time;

    if node_count > EGGLOG_MAX_NODES {
        log::debug!(
            "egglog: {} nodes, falling back to pattern matching",
            node_count
        );
        egglog_time = std::time::Duration::ZERO;
    } else {
        let egglog_start = Instant::now();
        let mut egraph = egglog::EGraph::default();
        let egglog_result = egraph.parse_and_run_program(None, &program);
        log::debug!(
            "egglog: saturation took {:.1}ms",
            egglog_start.elapsed().as_secs_f64() * 1000.0
        );
        let egglog_ok = match egglog_result {
            Ok(outputs) => {
                for out in &outputs {
                    if let egglog::CommandOutput::ExtractBest(ref dag, _cost, term_id) = *out {
                        log::debug!("egglog extracted: {}", dag.to_string(term_id));
                        extractions.push((dag.clone(), term_id));
                    }
                }
                true
            }
            Err(e) => {
                // The graph is still optimized by the pattern-matching passes.
                log::warn!(
                    "egglog optimization failed: {}, falling back to pattern matching",
                    e
                );
                false
            }
        };
        // Saturation time excludes the statistics serialization below.
        egglog_time = egglog_start.elapsed();

        if egglog_ok {
            let serialized = egraph.serialize(egglog::SerializeConfig::default());
            num_eclasses = serialized.egraph.class_data.len();
            num_enodes = serialized.egraph.nodes.len();
        }
    }

    let extract_start = Instant::now();
    let (optimized, fusions_applied) = rebuild_graph_from_extractions(graph, &extractions);
    let extract_time = extract_start.elapsed();

    let nodes_after = count_active(&optimized);

    // Tally applied fusions per rule label, in first-seen order.
    let mut rules_fired: Vec<(String, usize)> = Vec::new();
    for fusion in &fusions_applied {
        if let Some(entry) = rules_fired.iter_mut().find(|e| e.0 == fusion.0) {
            entry.1 += 1;
        } else {
            rules_fired.push((fusion.0.clone(), 1));
        }
    }

    let report = OptimizeReport {
        egglog_program: program,
        num_eclasses,
        num_enodes,
        rules_fired,
        nodes_before,
        nodes_after,
        fusions_applied,
        egglog_time,
        extract_time,
    };

    (optimized, report)
}

/// Dump the egglog program for a graph (for standalone debugging).
pub fn dump_egglog_program(graph: &Graph) -> String {
    graph_to_egglog(graph)
}

/// Generate egglog program text from a Graph.
///
/// The program has four parts:
/// 1. an `Op` datatype covering forward, backward, and fused ops;
/// 2. rewrite rules expressing algebraic simplifications and kernel fusions;
/// 3. one `(let n<id> <expr>)` binding per non-`Nop` node, so the FULL graph
///    (forward + backward) is encoded;
/// 4. a single `(run 1)`, followed by one `(extract n<id>)` per non-`Nop`
///    graph output.
fn graph_to_egglog(graph: &Graph) -> String {
    use std::fmt::Write as _;

    // Sort and constructors — covers forward, backward, and fused ops.
    const OP_DATATYPE: &str = "\
(datatype Op
  ; --- Leaf nodes ---
  (Input String)
  (Parameter String)
  (Const i64)
  ; --- Forward matmul variants ---
  (MatMul Op Op)
  (MatMulAT Op Op)
  (MatMulBT Op Op)
  ; --- Fused matmul+add (targets for fusion rules) ---
  (FusedMatMulAdd Op Op Op)
  (FusedMatMulATAdd Op Op Op)
  (FusedMatMulBTAdd Op Op Op)
  ; --- Element-wise ---
  (Add Op Op)
  (Mul Op Op)
  (BiasAdd Op Op)
  (Relu Op)
  (Sigmoid Op)
  (Tanh Op)
  (Neg Op)
  (Abs Op)
  (Log Op)
  (Recip Op)
  (ScatterAdd i64 Op Op)
  (Silu Op)
  (Gelu Op)
  (Identity Op)
  ; --- Shape / reduction ---
  (Transpose Op)
  (Softmax Op)
  (LogSoftmax Op)
  (SumAll Op)
  (MeanAll Op)
  (SumRows Op)
  (CrossEntropyLoss Op Op)
  (BceLoss Op Op)
  (Greater Op Op)
  ; --- Transformer forward ---
  (SwiGLU Op Op)
  (SwiGLUConcat Op)
  (RmsNorm Op Op)
  (FusedRmsNormMatMul Op Op Op)
  (Embedding Op Op)
  (RoPE Op)
  (RoPEGrad Op)
  (CausalAttention Op Op Op)
  (SlidingWindowAttention Op Op Op)
  (LayerNorm Op Op Op)
  (FullAttention Op Op Op)
  (CrossAttention Op Op Op)
  (MultiHeadAttn Op Op Op)
  ; --- GroupNorm, Concat, Upsample, Conv2d ops ---
  (GroupNorm Op Op Op)
  (GroupNormSilu Op Op Op)
  (GroupNormGradInput Op Op Op)
  (GroupNormGradWeightBias Op Op)
  (Concat Op Op)
  (SplitA Op)
  (SplitB Op)
  (Upsample2x Op)
  (Upsample2xGrad Op)
  (Conv2d Op Op)
  (Conv2dGradInput Op Op)
  (Conv2dGradWeight Op Op)
  (MaxPool2d Op)
  (GlobalAvgPool Op)
  (GlobalAvgPoolGrad Op)
  ; --- KV cache ops ---
  (CacheWrite Op Op Op)
  (CachedAttention Op Op Op Op)
  ; --- Backward / gradient ops ---
  (SiluGrad Op Op)
  (SwiGLUGradGate Op Op Op)
  (SwiGLUGradUp Op Op)
  (SwiGLUConcatGrad Op Op)
  (RmsNormGradW Op Op Op)
  (RmsNormGradX Op Op Op)
  (LayerNormGradWB Op Op Op)
  (LayerNormGradX Op Op Op)
  (MHAGradQ Op Op Op Op)
  (MHAGradK Op Op Op Op)
  (MHAGradV Op Op Op Op)
)

";

    // Rewrite rules — these are the optimizations egglog discovers.
    const REWRITE_RULES: &str = "\
; --- Algebraic simplifications ---
(rewrite (Neg (Neg ?x)) ?x)
(rewrite (Transpose (Transpose ?x)) ?x)
(rewrite (Relu (Relu ?x)) (Relu ?x))

; --- Kernel fusion: Add(MatMul*(a,b), d) → FusedMatMul*Add(a,b,d) ---
; Both argument orders handled explicitly (no general Add commutativity
; rule, which causes exponential blowup on large graphs).
(rewrite (Add (MatMul ?a ?b) ?d)    (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add ?d (MatMul ?a ?b))    (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add (MatMulAT ?a ?b) ?d)  (FusedMatMulATAdd ?a ?b ?d))
(rewrite (Add ?d (MatMulAT ?a ?b))  (FusedMatMulATAdd ?a ?b ?d))
(rewrite (Add (MatMulBT ?a ?b) ?d)  (FusedMatMulBTAdd ?a ?b ?d))
(rewrite (Add ?d (MatMulBT ?a ?b))  (FusedMatMulBTAdd ?a ?b ?d))

; --- RmsNorm+MatMul fusion ---
(rewrite (MatMul (RmsNorm ?x ?w_norm) ?w_proj) (FusedRmsNormMatMul ?x ?w_norm ?w_proj))

; --- SwiGLU fusion: two matmuls sharing input → single wide matmul ---
; SwiGLU(MatMul(h, w1), MatMul(h, w2)) can use SwiGLUConcat on a
; concatenated [h, w1|w2] matmul. Pattern matcher handles weight
; concatenation since egglog can't create new tensors.
; (documented here; applied by apply_swiglu_concat_fusions)

; --- ONNX decomposed op recognition ---
; PyTorch decomposes compound ops when exporting to ONNX.
; These rules recognize the decomposed patterns and fuse them back
; into our efficient compound kernels.

; Silu: x * sigmoid(x) → Silu(x)
(rewrite (Mul ?x (Sigmoid ?x)) (Silu ?x))
(rewrite (Mul (Sigmoid ?x) ?x) (Silu ?x))

; SwiGLU: silu(gate) * up → SwiGLU(gate, up)
(rewrite (Mul (Silu ?gate) ?up) (SwiGLU ?gate ?up))

";

    let mut prog = String::with_capacity(OP_DATATYPE.len() + REWRITE_RULES.len());
    prog.push_str(OP_DATATYPE);
    prog.push_str(REWRITE_RULES);

    // Encode every node (forward AND backward). Writing into a `String`
    // cannot fail, so the returned `fmt::Result` is discarded.
    for node in graph.nodes().iter().filter(|n| !matches!(n.op, Op::Nop)) {
        let _ = writeln!(prog, "(let n{} {})", node.id, node_to_egglog_expr(node));
    }

    // A single saturation iteration, which bounds egglog's running time (there
    // is no node limit). One iteration applies each rule to the terms present
    // at the start. Rules that only match terms another rule creates need
    // further iterations and are not discovered here. An example is SwiGLU
    // over a Silu that was itself recognized from Mul/Sigmoid.
    prog.push_str("(run 1)\n\n");

    // Extract all output nodes (after saturation).
    for &out in graph.outputs() {
        if !matches!(graph.node(out).op, Op::Nop) {
            let _ = writeln!(prog, "(extract n{})", out);
        }
    }

    prog
}

/// Render one graph node as an egglog expression.
///
/// Leaf ops (`Input`, `Parameter`, `Constant`) and `ScatterAdd` embed an
/// inline literal. Every other op becomes `(Ctor nA nB ...)`, which
/// references its first `arity` inputs through the `n<id>` let-bindings.
/// Some ops share or abbreviate a constructor: `CausalAttentionRoPE` is
/// encoded as `CausalAttention`, and the multi-head attention gradients use
/// `MHAGrad*`.
///
/// Panics on `Op::Nop`, and when a node has fewer inputs than its
/// constructor's arity.
fn node_to_egglog_expr(node: &Node) -> String {
    let (ctor, arity): (&str, usize) = match node.op {
        // Leaves and ops carrying an inline literal are rendered directly.
        Op::Input { ref name } => return format!("(Input \"{}\")", name),
        Op::Parameter { ref name } => return format!("(Parameter \"{}\")", name),
        Op::Constant { .. } => return format!("(Const {})", node.id),
        Op::ScatterAdd { vocab_size } => {
            return format!(
                "(ScatterAdd {} n{} n{})",
                vocab_size, node.inputs[0], node.inputs[1]
            )
        }
        Op::Nop => unreachable!("Nop nodes are filtered before encoding"),

        // Unary constructors.
        Op::Relu => ("Relu", 1),
        Op::Sigmoid => ("Sigmoid", 1),
        Op::Tanh => ("Tanh", 1),
        Op::Neg => ("Neg", 1),
        Op::Abs => ("Abs", 1),
        Op::Log => ("Log", 1),
        Op::Recip => ("Recip", 1),
        Op::Transpose => ("Transpose", 1),
        Op::Softmax => ("Softmax", 1),
        Op::LogSoftmax => ("LogSoftmax", 1),
        Op::SumAll => ("SumAll", 1),
        Op::MeanAll => ("MeanAll", 1),
        Op::SumRows => ("SumRows", 1),
        Op::Silu => ("Silu", 1),
        Op::SwiGLUConcat => ("SwiGLUConcat", 1),
        Op::Gelu => ("Gelu", 1),
        Op::RoPE { .. } => ("RoPE", 1),
        Op::RoPEGrad { .. } => ("RoPEGrad", 1),
        Op::SplitA { .. } => ("SplitA", 1),
        Op::SplitB { .. } => ("SplitB", 1),
        Op::Upsample2x { .. } => ("Upsample2x", 1),
        Op::Upsample2xGrad { .. } => ("Upsample2xGrad", 1),
        Op::MaxPool2d { .. } => ("MaxPool2d", 1),
        Op::GlobalAvgPool { .. } => ("GlobalAvgPool", 1),
        Op::GlobalAvgPoolGrad { .. } => ("GlobalAvgPoolGrad", 1),
        Op::Identity => ("Identity", 1),

        // Binary constructors.
        Op::MatMul => ("MatMul", 2),
        Op::MatMulAT => ("MatMulAT", 2),
        Op::MatMulBT => ("MatMulBT", 2),
        Op::Add => ("Add", 2),
        Op::Mul => ("Mul", 2),
        Op::BiasAdd => ("BiasAdd", 2),
        Op::CrossEntropyLoss => ("CrossEntropyLoss", 2),
        Op::BceLoss => ("BceLoss", 2),
        Op::Greater => ("Greater", 2),
        Op::SwiGLU => ("SwiGLU", 2),
        Op::RmsNorm { .. } => ("RmsNorm", 2),
        Op::Embedding => ("Embedding", 2),
        Op::SiluGrad => ("SiluGrad", 2),
        Op::SwiGLUGradUp => ("SwiGLUGradUp", 2),
        Op::SwiGLUConcatGrad => ("SwiGLUConcatGrad", 2),
        Op::GroupNormGradWeightBias { .. } => ("GroupNormGradWeightBias", 2),
        Op::Concat { .. } => ("Concat", 2),
        Op::Conv2d { .. } => ("Conv2d", 2),
        Op::Conv2dGradInput { .. } => ("Conv2dGradInput", 2),
        Op::Conv2dGradWeight { .. } => ("Conv2dGradWeight", 2),

        // Ternary constructors.
        Op::CausalAttention { .. } | Op::CausalAttentionRoPE { .. } => ("CausalAttention", 3),
        Op::SlidingWindowAttention { .. } => ("SlidingWindowAttention", 3),
        Op::LayerNorm { .. } => ("LayerNorm", 3),
        Op::FullAttention { .. } => ("FullAttention", 3),
        Op::CrossAttention { .. } => ("CrossAttention", 3),
        Op::MultiHeadAttn { .. } => ("MultiHeadAttn", 3),
        Op::SwiGLUGradGate => ("SwiGLUGradGate", 3),
        Op::RmsNormGradW { .. } => ("RmsNormGradW", 3),
        Op::RmsNormGradX { .. } => ("RmsNormGradX", 3),
        Op::LayerNormGradWB { .. } => ("LayerNormGradWB", 3),
        Op::LayerNormGradX { .. } => ("LayerNormGradX", 3),
        // Fused ops from a previous optimization pass are encoded as-is.
        Op::FusedMatMulAdd => ("FusedMatMulAdd", 3),
        Op::FusedMatMulATAdd => ("FusedMatMulATAdd", 3),
        Op::FusedMatMulBTAdd => ("FusedMatMulBTAdd", 3),
        Op::FusedRmsNormMatMul { .. } => ("FusedRmsNormMatMul", 3),
        Op::GroupNorm { .. } => ("GroupNorm", 3),
        Op::GroupNormSilu { .. } => ("GroupNormSilu", 3),
        Op::GroupNormGradInput { .. } => ("GroupNormGradInput", 3),
        Op::CacheWrite => ("CacheWrite", 3),

        // Quaternary constructors.
        Op::MultiHeadAttnGradQ { .. } => ("MHAGradQ", 4),
        Op::MultiHeadAttnGradK { .. } => ("MHAGradK", 4),
        Op::MultiHeadAttnGradV { .. } => ("MHAGradV", 4),
        Op::CachedAttention { .. } => ("CachedAttention", 4),
    };

    let operands: String = node.inputs[..arity]
        .iter()
        .map(|input| format!(" n{}", input))
        .collect();
    format!("({}{})", ctor, operands)
}

/// Rebuild the graph, applying kernel fusions.
///
/// `extractions` holds the `(TermDag, TermId)` pairs returned by egglog's
/// `(extract ...)` commands. It is empty when egglog failed or was skipped.
/// Each extracted term is handed to `scan_fusions`, together with a lookup
/// from `egglog_key` to node ids, so extracted terms can be matched back to
/// the original graph nodes.
///
/// The graph mutations themselves come from the direct pattern-matching
/// passes (`apply_*_fusions`). These always run, and are swept repeatedly
/// until a full sweep applies no new fusion.
///
/// Returns the rewritten graph and one `(rule label, node index)` entry per
/// applied fusion.
fn rebuild_graph_from_extractions(
    original: &Graph,
    extractions: &[(TermDag, TermId)],
) -> (Graph, Vec<(String, u32)>) {
    let mut graph = clone_graph(original);
    let mut fusions: Vec<(String, u32)> = Vec::new();

    if !extractions.is_empty() {
        // Index nodes by (op name, input ids) so extracted terms can be
        // mapped back onto graph node ids.
        let mut node_lookup: HashMap<String, Vec<usize>> = HashMap::new();
        for node in graph.nodes() {
            node_lookup
                .entry(egglog_key(node))
                .or_default()
                .push(node.id as usize);
        }

        for (dag, root) in extractions {
            scan_fusions(dag, *root, &graph, &node_lookup, &mut fusions);
        }
    }

    // One pass can expose patterns for another (like e-graph saturation, but
    // on the graph IR), so sweep every pass until a sweep adds nothing.
    let mut applied_before_sweep = usize::MAX;
    while fusions.len() != applied_before_sweep {
        applied_before_sweep = fusions.len();
        apply_matmul_add_fusions(&mut graph, &mut fusions);
        apply_silu_fusions(&mut graph, &mut fusions);
        apply_swiglu_fusions(&mut graph, &mut fusions);
        apply_swiglu_concat_fusions(&mut graph, &mut fusions);
        // Disabled: RmsNorm+MatMul fusion saves 24 barriers, but the coop
        // variant's 64-thread tree-reduction rsqrt prologue makes the fused
        // kernel ~57% slower than separate RmsNorm + coop MatMul, so it only
        // roughly breaks even. Subgroup ops (subgroupAdd) would fix this, but
        // NVIDIA's driver crashes when subgroup and cooperative matrix
        // capabilities are combined in the same SPIR-V module:
        // https://github.com/kvark/blade/issues/333
        // apply_rms_norm_matmul_fusions(&mut graph, &mut fusions);
        // apply_rope_attention_fusions(&mut graph, &mut fusions);
    }

    let active_nodes = graph
        .nodes()
        .iter()
        .filter(|n| !matches!(n.op, Op::Nop))
        .count();
    log::info!(
        "optimizer: {} fusions on {} nodes",
        fusions.len(),
        active_nodes
    );

    // Per-rule counts, logged in sorted label order.
    let mut per_rule = std::collections::BTreeMap::<&str, usize>::new();
    for (label, _) in &fusions {
        *per_rule.entry(label.as_str()).or_default() += 1;
    }
    for (name, count) in per_rule {
        log::info!("  {}x {}", count, name);
    }

    (graph, fusions)
}

/// Build the lookup key for a graph node: an op tag, then `:`, then the
/// `Debug` rendering of its input ids.
///
/// Leaves are tagged by identity (`Input:<name>`, `Parameter:<name>`,
/// `Const:<node id>`). Every other op is tagged by the `Debug` form of its
/// enum discriminant, so op attributes do not influence the key.
fn egglog_key(node: &Node) -> String {
    let inputs = &node.inputs;
    match node.op {
        Op::Input { ref name } => format!("Input:{}:{:?}", name, inputs),
        Op::Parameter { ref name } => format!("Parameter:{}:{:?}", name, inputs),
        Op::Constant { .. } => format!("Const:{}:{:?}", node.id, inputs),
        ref op => format!("{:?}:{:?}", std::mem::discriminant(op), inputs),
    }
}

/// Walk an extracted term DAG and log the fused ops egglog selected.
///
/// Extracted terms are DAGs: a subterm shared by several parents, such as a
/// residual connection, is stored once. The walk therefore visits each
/// `TermId` at most once. Walking the DAG as a tree, as the previous
/// recursive version did, revisits shared subterms and grows exponentially
/// with depth. An explicit stack is used instead of recursion, so deep terms
/// cannot overflow the call stack.
///
/// `_graph`, `_lookup` and `_fusions` are currently unused. They are kept so
/// that discovered fusions can later be mapped back onto graph nodes.
fn scan_fusions(
    dag: &TermDag,
    term_id: TermId,
    _graph: &Graph,
    _lookup: &HashMap<String, Vec<usize>>,
    _fusions: &mut Vec<(String, u32)>,
) {
    let mut visited = std::collections::HashSet::new();
    let mut stack = vec![term_id];
    while let Some(id) = stack.pop() {
        // `insert` returns false if this subterm was already handled.
        if !visited.insert(id) {
            continue;
        }
        if let Term::App(name, children) = dag.get(id) {
            if name.starts_with("FusedMatMul") || name.starts_with("FusedRmsNorm") {
                log::debug!("egglog discovered fusion: {}", name);
            }
            stack.extend(children.iter().copied());
        }
    }
}

/// Apply Add(MatMul*(a, b), d) → FusedMatMul*Add(a, b, d) fusions.
///
/// This is the concrete graph mutation. It matches the patterns that
/// egglog's rewrite rules express, applying them with the additional
/// single-use constraint (the MatMul must feed only this Add).
fn apply_matmul_add_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_ids: Vec<usize> = (0..graph.nodes().len()).collect();
    for &id in &node_ids {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::Add) {
            continue;
        }
        let (lhs, rhs) = (node.inputs[0], node.inputs[1]);
        let (mm_id, addend_id) =
            if matches!(graph.node(lhs).op, Op::MatMul | Op::MatMulAT | Op::MatMulBT) {
                (lhs, rhs)
            } else if matches!(graph.node(rhs).op, Op::MatMul | Op::MatMulAT | Op::MatMulBT) {
                (rhs, lhs)
            } else {
                continue;
            };
        let mm_use_count = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&mm_id) && !matches!(n.op, Op::Nop))
            .count();
        if mm_use_count != 1 {
            continue;
        }

        let mm_node = graph.node(mm_id);
        let (a, b) = (mm_node.inputs[0], mm_node.inputs[1]);
        let (fused_op, label) = match mm_node.op {
            Op::MatMul => (Op::FusedMatMulAdd, "MatMul+Add→FusedMatMulAdd"),
            Op::MatMulAT => (Op::FusedMatMulATAdd, "MatMulAT+Add→FusedMatMulATAdd"),
            Op::MatMulBT => (Op::FusedMatMulBTAdd, "MatMulBT+Add→FusedMatMulBTAdd"),
            _ => unreachable!(),
        };
        graph.nodes_mut()[id].op = fused_op;
        graph.nodes_mut()[id].inputs = vec![a, b, addend_id];
        graph.nodes_mut()[mm_id as usize].op = Op::Nop;
        fusions.push((label.to_string(), id as u32));
    }
}

/// Fuse SwiGLU(MatMul(h, w_gate), MatMul(h, w_up)) → SwiGLUConcat(MatMul(h, w_gate_up))
///
/// When both gate and up projections share the same input `h`, merge
/// them into a single wide matmul [hidden, 2*intermediate] followed by
/// SwiGLUConcat. This halves the number of matmul dispatches for the
/// MLP gate+up path. The optimizer creates a new `concat_weight` parameter
/// node so model code can use the naive two-matmul pattern.
///
/// Every precondition is checked before the graph is touched, so a skipped
/// candidate leaves no partial state (no stray `DerivedParam` or nodes):
/// - both SwiGLU operands are `MatMul`s with the same left operand `h`;
/// - each matmul has exactly one live input edge (this SwiGLU) and is not a
///   graph output, because both are turned into `Op::Nop`;
/// - both weights are `Parameter` nodes with identical rank-2 shapes. The
///   concatenated weight is registered as a `DerivedParam` whose sources are
///   those parameter names, and a non-Parameter weight has no name to list;
/// - `h` is rank 2, which fixes the row count `M` of the wide matmul.
///
/// Each fusion is logged in `fusions` as `(label, swiglu_node_id)`.
fn apply_swiglu_concat_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    use crate::graph::TensorType;
    // Only nodes that exist on entry are visited; nodes appended below are not revisited.
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::SwiGLU) {
            continue;
        }
        let (gate_id, up_id) = (node.inputs[0], node.inputs[1]);
        let gate_node = graph.node(gate_id);
        let up_node = graph.node(up_id);

        // Both must be MatMul
        if !matches!(gate_node.op, Op::MatMul) || !matches!(up_node.op, Op::MatMul) {
            continue;
        }
        // Both must share the same input (first operand)
        if gate_node.inputs[0] != up_node.inputs[0] {
            continue;
        }
        // Both matmuls become Nop, so each must have exactly one live input
        // edge (this SwiGLU). Edges, not consumer nodes, are counted so that a
        // node feeding the same matmul twice is seen as two uses.
        let gate_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == gate_id).count())
            .sum();
        let up_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == up_id).count())
            .sum();
        if gate_uses != 1 || up_uses != 1 {
            continue;
        }
        // A matmul that is also a graph output must stay materialized.
        let outputs = graph.outputs();
        if outputs.contains(&gate_id) || outputs.contains(&up_id) {
            continue;
        }

        let h = gate_node.inputs[0];
        let w_gate = gate_node.inputs[1];
        let w_up = up_node.inputs[1];

        // The concatenated weight is filled from the named source parameters,
        // so both weights must be `Parameter` nodes.
        let (gate_name, up_name) = match (&graph.node(w_gate).op, &graph.node(w_up).op) {
            (Op::Parameter { name: gate }, Op::Parameter { name: up }) => {
                (gate.clone(), up.clone())
            }
            _ => continue,
        };

        // Both weights are [in_features, out_features] with identical shapes.
        let gate_shape = &graph.node(w_gate).ty.shape;
        let up_shape = &graph.node(w_up).ty.shape;
        if gate_shape.len() != 2 || gate_shape != up_shape {
            continue;
        }
        let in_features = gate_shape[0];
        let out_features = gate_shape[1];

        // The wide matmul keeps h's leading dimension M; require a rank-2 `h`
        // so that dimension is well defined (and indexing cannot panic).
        let h_shape = &graph.node(h).ty.shape;
        if h_shape.len() != 2 {
            continue;
        }
        let m = h_shape[0];

        // Record derivation so runtime can fill this from original params
        let concat_name = format!("{}+{}", gate_name, up_name);
        graph.derived_params.push(crate::graph::DerivedParam {
            name: concat_name.clone(),
            sources: vec![(gate_name, out_features), (up_name, out_features)],
            rows: in_features,
        });
        // Concatenated weight parameter: [in_features, 2 * out_features]
        let concat_w = graph.add_raw_node(
            Op::Parameter { name: concat_name },
            vec![],
            TensorType::f32(vec![in_features, 2 * out_features]),
        );

        // MatMul(h, concat_w) → [M, 2*out_features]
        // NOTE(review): concat_w and wide_mm get ids larger than the
        // SwiGLUConcat node `id` that consumes them; confirm later passes do
        // not assume inputs precede their consumers by id.
        let wide_mm = graph.add_raw_node(
            Op::MatMul,
            vec![h, concat_w],
            TensorType::f32(vec![m, 2 * out_features]),
        );

        // SwiGLUConcat(wide_mm) → [M, out_features]
        graph.nodes_mut()[id].op = Op::SwiGLUConcat;
        graph.nodes_mut()[id].inputs = vec![wide_mm];
        graph.nodes_mut()[id].ty = TensorType::f32(vec![m, out_features]);

        // Mark old matmuls as Nop
        graph.nodes_mut()[gate_id as usize].op = Op::Nop;
        graph.nodes_mut()[up_id as usize].op = Op::Nop;

        fusions.push((
            "SwiGLU(MatMul,MatMul)→SwiGLUConcat(MatMul)".to_string(),
            id as u32,
        ));
    }
}

/// Fuse Silu(GroupNorm(x, w, b)) → GroupNormSilu(x, w, b)
///
/// Only fuses if the GroupNorm result is used exclusively by this Silu
/// (exactly one live input edge) and is not itself a graph output, because the
/// GroupNorm node is turned into `Op::Nop`. The Silu node is rewritten in place
/// and each fusion is logged in `fusions` as `(label, silu_node_id)`.
/// This is inference-only (backward pass can't differentiate through the fused op).
pub fn apply_group_norm_silu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::Silu) {
            continue;
        }
        let gn_id = node.inputs[0];
        let gn_node = graph.node(gn_id);
        let (num_groups, eps, channels, spatial) = match gn_node.op {
            Op::GroupNorm {
                num_groups,
                eps,
                channels,
                spatial,
            } => (num_groups, eps, channels, spatial),
            _ => continue,
        };
        // Only fuse if GroupNorm has a single consumer edge among live nodes...
        let gn_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == gn_id).count())
            .sum();
        // ...and is not a graph output (outputs are not consumer nodes, but
        // still need the GroupNorm result materialized).
        if gn_uses != 1 || graph.outputs().contains(&gn_id) {
            continue;
        }
        let (x, w, b) = (gn_node.inputs[0], gn_node.inputs[1], gn_node.inputs[2]);
        // Rewrite Silu node to GroupNormSilu
        graph.nodes_mut()[id].op = Op::GroupNormSilu {
            num_groups,
            eps,
            channels,
            spatial,
        };
        graph.nodes_mut()[id].inputs = vec![x, w, b];
        // Mark old GroupNorm as Nop
        graph.nodes_mut()[gn_id as usize].op = Op::Nop;
        fusions.push(("GroupNorm+Silu→GroupNormSilu".to_string(), id as u32));
    }
}

/// Fuse MatMul(RmsNorm(x, w_norm, eps), w_proj) → FusedRmsNormMatMul(x, w_norm, w_proj, eps)
///
/// Only fuses if the RmsNorm result is used exclusively by this MatMul
/// (exactly one live input edge) and is not a graph output, because the
/// RmsNorm node is turned into `Op::Nop`. The fused node's type is set to
/// `[x.shape[0], w_proj.shape[1]]`. Candidates whose shapes are too low-rank
/// to provide those dimensions are skipped instead of panicking.
/// Each fusion is logged in `fusions` as `(label, matmul_node_id)`.
#[allow(dead_code)]
fn apply_rms_norm_matmul_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    use crate::graph::TensorType;
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::MatMul) {
            continue;
        }
        let (norm_id, w_proj_id) = (node.inputs[0], node.inputs[1]);
        let norm_node = graph.node(norm_id);
        let eps = match norm_node.op {
            Op::RmsNorm { eps } => eps,
            _ => continue,
        };
        // RmsNorm must be single-use. Edges are counted, so MatMul(norm, norm)
        // counts as two uses. Otherwise the fused node would keep `norm_id` as
        // its w_proj input after that node was Nop'd.
        let norm_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == norm_id).count())
            .sum();
        // A graph output must stay materialized even with no consumer nodes.
        if norm_uses != 1 || graph.outputs().contains(&norm_id) {
            continue;
        }

        let x = norm_node.inputs[0];
        let w_norm = norm_node.inputs[1];
        // Output dims: first dim of x, second dim of w_proj (skip if absent).
        let m = match graph.node(x).ty.shape.first() {
            Some(&m) => m,
            None => continue,
        };
        let n = match graph.node(w_proj_id).ty.shape.get(1) {
            Some(&n) => n,
            None => continue,
        };

        // Rewrite the MatMul node to FusedRmsNormMatMul
        graph.nodes_mut()[id].op = Op::FusedRmsNormMatMul { eps };
        graph.nodes_mut()[id].inputs = vec![x, w_norm, w_proj_id];
        graph.nodes_mut()[id].ty = TensorType::f32(vec![m, n]);
        // Mark old RmsNorm as Nop
        graph.nodes_mut()[norm_id as usize].op = Op::Nop;

        fusions.push(("RmsNorm+MatMul→FusedRmsNormMatMul".to_string(), id as u32));
    }
}

/// Fuse CausalAttention(RoPE(Q), RoPE(K), V) → CausalAttentionRoPE(Q, K, V)
///
/// When both Q and K inputs to CausalAttention are single-use RoPE nodes
/// with the same theta, replace with CausalAttentionRoPE which applies
/// RoPE inside the attention kernel's dot product. Eliminates 2 dispatches
/// + 1 barrier group per attention layer.
///
/// "Single-use" means exactly one live input edge and not a graph output.
/// Both RoPE nodes are turned into `Op::Nop`, so if V (or the other of Q/K)
/// aliased one of them, the fused node would read a dead node.
/// Each fusion is logged in `fusions` as `(label, attention_node_id)`.
#[allow(dead_code)]
fn apply_rope_attention_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        let (num_heads, num_kv_heads, head_dim) = match node.op {
            Op::CausalAttention {
                num_heads,
                num_kv_heads,
                head_dim,
            } => (num_heads, num_kv_heads, head_dim),
            _ => continue,
        };

        let q_id = node.inputs[0];
        let k_id = node.inputs[1];
        let v_id = node.inputs[2];

        // Both Q and K must be RoPE nodes.
        // NOTE(review): only `theta` is carried into CausalAttentionRoPE; any
        // other RoPE fields (matched by `..`) are dropped — confirm none of
        // them affect the rotation.
        let q_node = graph.node(q_id);
        let k_node = graph.node(k_id);
        let (q_theta, q_raw) = match q_node.op {
            Op::RoPE { theta, .. } => (theta, q_node.inputs[0]),
            _ => continue,
        };
        let (k_theta, k_raw) = match k_node.op {
            Op::RoPE { theta, .. } => (theta, k_node.inputs[0]),
            _ => continue,
        };

        // Same theta
        if q_theta != k_theta {
            continue;
        }

        // Both RoPE nodes must have exactly one live input edge. Edges are
        // counted, so V == RoPE(Q) or a shared Q/K RoPE node counts as two uses.
        let q_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == q_id).count())
            .sum();
        let k_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == k_id).count())
            .sum();
        if q_uses != 1 || k_uses != 1 {
            continue;
        }
        // A RoPE node that is also a graph output must stay materialized.
        let outputs = graph.outputs();
        if outputs.contains(&q_id) || outputs.contains(&k_id) {
            continue;
        }

        // Replace CausalAttention with CausalAttentionRoPE using un-rotated Q, K
        graph.nodes_mut()[id].op = Op::CausalAttentionRoPE {
            num_heads,
            num_kv_heads,
            head_dim,
            rope_theta: q_theta,
        };
        graph.nodes_mut()[id].inputs = vec![q_raw, k_raw, v_id];
        // Mark old RoPE nodes as Nop
        graph.nodes_mut()[q_id as usize].op = Op::Nop;
        graph.nodes_mut()[k_id as usize].op = Op::Nop;

        fusions.push((
            "CausalAttn(RoPE,RoPE)→CausalAttnRoPE".to_string(),
            id as u32,
        ));
    }
}

/// Fuse Mul(x, Sigmoid(x)) → Silu(x) via direct pattern matching.
///
/// PyTorch decomposes Silu to x * sigmoid(x) in ONNX exports.
///
/// Both operand orders are recognized. The Mul node is rewritten in place and
/// the Sigmoid node becomes `Op::Nop`. The Sigmoid therefore must have exactly
/// one live input edge (this Mul) and must not be a graph output. Each fusion
/// is logged in `fusions` as `(label, mul_node_id)`.
fn apply_silu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::Mul) {
            continue;
        }
        let (a_id, b_id) = (node.inputs[0], node.inputs[1]);
        // Check: Mul(x, Sigmoid(x)) or Mul(Sigmoid(x), x)
        let (x, sig_id) = if matches!(graph.node(b_id).op, Op::Sigmoid)
            && graph.node(b_id).inputs[0] == a_id
        {
            (a_id, b_id)
        } else if matches!(graph.node(a_id).op, Op::Sigmoid) && graph.node(a_id).inputs[0] == b_id {
            (b_id, a_id)
        } else {
            continue;
        };
        // Only fuse if Sigmoid has a single consumer edge (this Mul)...
        let sig_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == sig_id).count())
            .sum();
        // ...and is not needed as a graph output.
        if sig_uses != 1 || graph.outputs().contains(&sig_id) {
            continue;
        }
        graph.nodes_mut()[id].op = Op::Silu;
        graph.nodes_mut()[id].inputs = vec![x];
        graph.nodes_mut()[sig_id as usize].op = Op::Nop;
        fusions.push(("Mul+Sigmoid→Silu".to_string(), id as u32));
    }
}

/// Fuse Mul(Silu(gate), up) → SwiGLU(gate, up) via direct pattern matching.
///
/// The Silu may be either Mul operand; when both are Silu, the left one is
/// treated as the gate. The Mul node is rewritten in place and the Silu node
/// becomes `Op::Nop`. The Silu therefore must have exactly one live input
/// edge and must not be a graph output. Each fusion is logged in `fusions`
/// as `(label, mul_node_id)`.
fn apply_swiglu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_count = graph.nodes().len();
    for id in 0..node_count {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::Mul) {
            continue;
        }
        let (a_id, b_id) = (node.inputs[0], node.inputs[1]);
        // Check: Mul(Silu(gate), up)
        let (gate, up, silu_id) = if matches!(graph.node(a_id).op, Op::Silu) {
            (graph.node(a_id).inputs[0], b_id, a_id)
        } else if matches!(graph.node(b_id).op, Op::Silu) {
            (graph.node(b_id).inputs[0], a_id, b_id)
        } else {
            continue;
        };
        // Only fuse if Silu has a single consumer edge. Edges are counted so
        // that Mul(Silu(g), Silu(g)) counts as two uses. Otherwise the fused
        // node's `up` operand would be the Nop'd Silu.
        let silu_uses: usize = graph
            .nodes()
            .iter()
            .filter(|n| !matches!(n.op, Op::Nop))
            .map(|n| n.inputs.iter().filter(|&&i| i == silu_id).count())
            .sum();
        // A Silu that is also a graph output must stay materialized.
        if silu_uses != 1 || graph.outputs().contains(&silu_id) {
            continue;
        }
        graph.nodes_mut()[id].op = Op::SwiGLU;
        graph.nodes_mut()[id].inputs = vec![gate, up];
        graph.nodes_mut()[silu_id as usize].op = Op::Nop;
        fusions.push(("Silu+Mul→SwiGLU".to_string(), id as u32));
    }
}

/// Build an independent copy of `graph`.
///
/// Every node's op, inputs and type are cloned into a fresh `Graph` via
/// `add_raw_node`, in the source's node order. The output list and the
/// `derived_params` table are then copied over unchanged.
fn clone_graph(graph: &Graph) -> Graph {
    let mut copy = Graph::new();
    graph.nodes().iter().for_each(|src| {
        copy.add_raw_node(src.op.clone(), src.inputs.clone(), src.ty.clone());
    });
    copy.set_outputs(graph.outputs().to_vec());
    copy.derived_params = graph.derived_params.clone();
    copy
}

// Unit tests for the optimizer: end-to-end `optimize` / `optimize_with_report`
// results, the `graph_to_egglog` encoding, and hand-written egglog programs.
#[cfg(test)]
mod tests {
    use super::*;

    // A single MatMul followed by Relu: the output node must still be the
    // Relu after optimization (no fusion applied).
    #[test]
    fn test_no_fusion_cooperative_matrix() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 784]);
        let w = g.parameter("w", &[784, 128]);
        let mm = g.matmul(x, w);
        let h = g.relu(mm);
        g.set_outputs(vec![h]);

        let opt = optimize(&g);
        let output_id = opt.outputs()[0];
        let output_node = opt.node(output_id);
        assert!(
            matches!(output_node.op, Op::Relu),
            "expected Relu (no fusion), got {:?}",
            output_node.op
        );
    }

    // Two MatMul+Relu layers: the report must list no fusions, and its
    // Display output must contain the "Optimization Report" header.
    #[test]
    fn test_optimize_report() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 784]);
        let w1 = g.parameter("w1", &[784, 128]);
        let mm1 = g.matmul(x, w1);
        let h1 = g.relu(mm1);
        let w2 = g.parameter("w2", &[128, 10]);
        let mm2 = g.matmul(h1, w2);
        let h2 = g.relu(mm2);
        g.set_outputs(vec![h2]);

        let (_opt, report) = optimize_with_report(&g);
        assert!(report.fusions_applied.is_empty());
        let display = format!("{}", report);
        assert!(display.contains("Optimization Report"));
    }

    // `graph_to_egglog` must emit MatMul and Input terms, and the generated
    // program must parse and run in egglog without error.
    #[test]
    fn test_egglog_roundtrip() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 10]);
        let w = g.parameter("w", &[10, 5]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);

        let program = graph_to_egglog(&g);
        assert!(program.contains("(MatMul"));
        assert!(program.contains("(Input \"x\")"));

        let mut egraph = egglog::EGraph::default();
        egraph.parse_and_run_program(None, &program).unwrap();
    }

    /// Verify egglog extraction returns fused terms via TermDag.
    #[test]
    fn test_egglog_extract_returns_fused() {
        // Self-contained program: two fusion rewrites plus Add commutativity,
        // applied to Add(MatMul(x, w), bias); the Add node n4 is extracted.
        let mut egraph = egglog::EGraph::default();
        let outputs = egraph
            .parse_and_run_program(
                None,
                r#"
(datatype Op
  (MatMul Op Op)
  (MatMulBT Op Op)
  (Add Op Op)
  (FusedMatMulAdd Op Op Op)
  (FusedMatMulBTAdd Op Op Op)
  (Input String)
  (Parameter String)
)
(rewrite (Add (MatMul ?a ?b) ?d) (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add (MatMulBT ?a ?b) ?d) (FusedMatMulBTAdd ?a ?b ?d))
(rewrite (Add ?x ?y) (Add ?y ?x))

(let n0 (Input "x"))
(let n1 (Parameter "w"))
(let n2 (MatMul n0 n1))
(let n3 (Input "bias"))
(let n4 (Add n2 n3))
(run 10)
(extract n4)
"#,
            )
            .unwrap();
        // Find the ExtractBest output
        let mut found_fused = false;
        for out in &outputs {
            if let egglog::CommandOutput::ExtractBest(dag, _cost, term_id) = out {
                let s = dag.to_string(*term_id);
                eprintln!("egglog extracted: {}", s);
                assert!(
                    s.contains("FusedMatMulAdd"),
                    "expected FusedMatMulAdd, got: {}",
                    s
                );
                // Verify the term tree structure
                match dag.get(*term_id) {
                    Term::App(name, _children) => {
                        assert_eq!(name, "FusedMatMulAdd");
                    }
                    other => panic!("expected App, got {:?}", other),
                }
                found_fused = true;
            }
        }
        assert!(found_fused, "no ExtractBest output found");
    }

    // Add followed by Neg: `optimize` must keep the same node count and leave
    // a Neg node at the output.
    #[test]
    fn test_optimize_preserves_graph() {
        let mut g = Graph::new();
        let a = g.input("a", &[4, 8]);
        let b = g.input("b", &[4, 8]);
        let sum = g.add(a, b);
        let neg = g.neg(sum);
        g.set_outputs(vec![neg]);

        let opt = optimize(&g);
        assert_eq!(opt.nodes().len(), g.nodes().len());
        let out = opt.node(opt.outputs()[0]);
        assert!(matches!(out.op, Op::Neg));
    }

    // `dump_egglog_program` output must contain the datatype declaration and
    // an `(extract n…)` command. Only the MatMul `y` is marked as an output.
    #[test]
    fn test_dump_egglog_program() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let y = g.matmul(x, w);
        let _h = g.relu(y);
        g.set_outputs(vec![y]);

        let program = dump_egglog_program(&g);
        assert!(program.contains("(datatype Op"));
        assert!(program.contains("(extract n"));
    }

    // Builds one node for each builder op listed below and checks that the
    // program produced by `graph_to_egglog` is accepted by egglog. Only `sa`
    // is marked as an output.
    #[test]
    fn test_egglog_all_ops() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let _c = g.constant(vec![0.0; 32], &[4, 8]);
        let mm = g.matmul(x, w);
        let _a = g.add(mm, mm);
        let _m = g.mul(mm, mm);
        let b = g.parameter("b", &[4]);
        let _ba = g.bias_add(mm, b);
        let _r = g.relu(mm);
        let _s = g.sigmoid(mm);
        let _n = g.neg(mm);
        let _t = g.transpose(mm);
        let _sm = g.softmax(mm);
        let _lsm = g.log_softmax(mm);
        let sa = g.sum_all(mm);
        let _ma = g.mean_all(mm);
        let _gt = g.greater(mm, mm);
        let _cel = g.cross_entropy_loss(mm, mm);
        g.set_outputs(vec![sa]);

        let program = graph_to_egglog(&g);
        let mut egraph = egglog::EGraph::default();
        egraph.parse_and_run_program(None, &program).unwrap();
    }

    // `clone_graph` must reproduce the node count, the outputs, and each
    // node's id, inputs and shape.
    #[test]
    fn test_clone_graph_preserves_structure() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);

        let cloned = clone_graph(&g);
        assert_eq!(cloned.nodes().len(), g.nodes().len());
        assert_eq!(cloned.outputs(), g.outputs());
        for (a, b) in cloned.nodes().iter().zip(g.nodes().iter()) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.inputs, b.inputs);
            assert_eq!(a.ty.shape, b.ty.shape);
        }
    }

    // A lone MatMul output must still be a MatMul after optimization.
    #[test]
    fn test_matmul_stays_as_matmul() {
        let mut g = Graph::new();
        let x = g.input("x", &[2, 1024]);
        let w = g.parameter("w", &[1024, 64]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);

        let opt = optimize(&g);
        let output_id = opt.outputs()[0];
        assert!(
            matches!(opt.node(output_id).op, Op::MatMul),
            "expected MatMul, got {:?}",
            opt.node(output_id).op
        );
    }

    /// Measure egglog saturation time vs graph size.
    // NOTE(review): this is a timing probe. It only prints elapsed times and
    // fails only if egglog returns an error; no timing bound is asserted.
    #[test]
    fn test_egglog_scalability() {
        for n in [10, 50, 100, 200, 350] {
            let mut prog = String::from(
                "(datatype Op
  (MatMul Op Op) (MatMulAT Op Op) (MatMulBT Op Op)
  (Add Op Op) (Input String) (Parameter String)
  (FusedMatMulAdd Op Op Op) (FusedMatMulATAdd Op Op Op) (FusedMatMulBTAdd Op Op Op)
)\n",
            );
            prog.push_str("(rewrite (Add (MatMul ?a ?b) ?d) (FusedMatMulAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMul ?a ?b)) (FusedMatMulAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add (MatMulAT ?a ?b) ?d) (FusedMatMulATAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMulAT ?a ?b)) (FusedMatMulATAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add (MatMulBT ?a ?b) ?d) (FusedMatMulBTAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMulBT ?a ?b)) (FusedMatMulBTAdd ?a ?b ?d))\n");

            prog.push_str("(let n0 (Input \"x\"))\n(let n1 (Parameter \"w\"))\n");
            // Chain of n-1 layers: n{2i} = <matmul variant>(n{2i-1}, n1) and
            // n{2i+1} = Add(n{2i}, n{2i-1}); the variant cycles with i % 3.
            for i in 1..n {
                let prev = (i - 1) * 2 + 2;
                match i % 3 {
                    0 => prog.push_str(&format!("(let n{} (MatMulAT n{} n1))\n", i * 2, prev - 1)),
                    1 => prog.push_str(&format!("(let n{} (MatMulBT n{} n1))\n", i * 2, prev - 1)),
                    _ => prog.push_str(&format!("(let n{} (MatMul n{} n1))\n", i * 2, prev - 1)),
                }
                prog.push_str(&format!(
                    "(let n{} (Add n{} n{}))\n",
                    i * 2 + 1,
                    i * 2,
                    prev - 1
                ));
            }
            // One rewrite iteration, then extract the last Add in the chain.
            prog.push_str("(run 1)\n");
            let last = (n - 1) * 2 + 1;
            prog.push_str(&format!("(extract n{})\n", last));

            let t0 = Instant::now();
            let mut egraph = egglog::EGraph::default();
            egraph.parse_and_run_program(None, &prog).unwrap();
            let elapsed = t0.elapsed();
            eprintln!(
                "egglog scalability: n={:>4} nodes -> {:>8.1}ms",
                n * 2,
                elapsed.as_secs_f64() * 1000.0
            );
        }
    }

    /// E-graph discovers MatMul+Add → FusedMatMulAdd.
    #[test]
    fn test_egglog_discovers_matmul_add_fusion() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let b = g.input("bias", &[4, 4]);
        let mm = g.matmul(x, w);
        let out = g.add(mm, b);
        g.set_outputs(vec![out]);

        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::FusedMatMulAdd),
            "expected FusedMatMulAdd, got {:?}",
            output_node.op
        );
        assert!(!report.fusions_applied.is_empty());
    }

    /// SwiGLU(MatMul, MatMul) → SwiGLUConcat(MatMul) fusion.
    #[test]
    fn test_swiglu_concat_fusion() {
        let mut g = Graph::new();
        let h = g.input("h", &[50, 720]);
        let w_gate = g.parameter("w_gate", &[720, 2048]);
        let w_up = g.parameter("w_up", &[720, 2048]);
        let gate = g.matmul(h, w_gate);
        let up = g.matmul(h, w_up);
        let out = g.swiglu(gate, up);
        g.set_outputs(vec![out]);

        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::SwiGLUConcat),
            "expected SwiGLUConcat, got {:?}",
            output_node.op
        );
        assert!(
            report
                .fusions_applied
                .iter()
                .any(|(name, _)| name.contains("SwiGLU")),
            "no SwiGLU fusion in report: {:?}",
            report.fusions_applied
        );
        // The fused matmul should have shape [50, 4096] (2*2048)
        let mm_id = output_node.inputs[0];
        let mm_node = opt.node(mm_id);
        assert!(matches!(mm_node.op, Op::MatMul));
        assert_eq!(mm_node.ty.shape, vec![50, 4096]);
    }

    /// Backward ops are encoded into egglog (not skipped).
    #[test]
    fn test_egglog_encodes_backward_ops() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let at = g.add_raw_node(
            Op::MatMulAT,
            vec![x, x],
            crate::graph::TensorType::f32(vec![8, 8]),
        );
        let bt = g.add_raw_node(
            Op::MatMulBT,
            vec![x, w],
            crate::graph::TensorType::f32(vec![4, 8]),
        );
        g.set_outputs(vec![at, bt]);

        let program = graph_to_egglog(&g);
        assert!(program.contains("MatMulAT"), "MatMulAT not encoded");
        assert!(program.contains("MatMulBT"), "MatMulBT not encoded");

        let mut egraph = egglog::EGraph::default();
        egraph
            .parse_and_run_program(None, &program)
            .expect("egglog failed with backward ops");
    }

    /// E-graph discovers MatMulBT+Add → FusedMatMulBTAdd on backward ops.
    #[test]
    fn test_egglog_discovers_backward_bt_add_fusion() {
        let mut g = Graph::new();
        let grad = g.input("grad", &[4, 8]);
        let w = g.parameter("w", &[4, 8]);
        let prev = g.input("prev_grad", &[4, 4]);
        let bt = g.add_raw_node(
            Op::MatMulBT,
            vec![grad, w],
            crate::graph::TensorType::f32(vec![4, 4]),
        );
        let out = g.add(bt, prev);
        g.set_outputs(vec![out]);

        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::FusedMatMulBTAdd),
            "expected FusedMatMulBTAdd, got {:?}",
            output_node.op
        );
        assert!(
            report
                .fusions_applied
                .iter()
                .any(|(name, _)| name.contains("BT")),
            "no BT fusion in report"
        );
    }

    /// E-graph recognizes x * sigmoid(x) → Silu(x).
    #[test]
    fn test_silu_fusion() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let sig = g.sigmoid(x);
        let out = g.mul(x, sig);
        g.set_outputs(vec![out]);

        let (opt, report) = optimize_with_report(&g);
        // The output should now be Silu
        let has_silu = opt.nodes().iter().any(|n| matches!(n.op, Op::Silu));
        assert!(
            has_silu,
            "expected Silu fusion, got nodes: {:?}",
            opt.nodes()
                .iter()
                .map(|n| format!("{:?}", n.op))
                .collect::<Vec<_>>()
        );
        // NOTE(review): `has_silu` was already asserted above, so this
        // assertion can never fail; it does not check that the report
        // recorded the fusion.
        assert!(
            !report.fusions_applied.is_empty() || has_silu,
            "no Silu fusion detected"
        );
    }

    /// Pattern matcher recognizes Silu+Mul → SwiGLU.
    #[test]
    fn test_swiglu_from_decomposed() {
        let mut g = Graph::new();
        let gate = g.input("gate", &[4, 8]);
        let up = g.input("up", &[4, 8]);
        // Decomposed SwiGLU: silu(gate) * up
        let sig = g.sigmoid(gate);
        let silu = g.mul(gate, sig);
        let out = g.mul(silu, up);
        g.set_outputs(vec![out]);

        let (opt, _report) = optimize_with_report(&g);
        // Any SwiGLU node anywhere in the optimized graph counts as success.
        let has_swiglu = opt.nodes().iter().any(|n| matches!(n.op, Op::SwiGLU));
        assert!(
            has_swiglu,
            "expected SwiGLU fusion from decomposed silu*up, got nodes: {:?}",
            opt.nodes()
                .iter()
                .map(|n| format!("{:?}", n.op))
                .collect::<Vec<_>>()
        );
    }

    /// MaxPool2d and GlobalAvgPool survive e-graph optimization unchanged.
    #[test]
    fn test_pool_ops_roundtrip() {
        let mut g = Graph::new();
        // Flat input of 1 * 64 * 8 * 8 elements; presumably batch × channels ×
        // height × width for the pool arguments below — TODO confirm against
        // the `max_pool_2d` / `global_avg_pool` signatures.
        let x = g.input("x", &[1 * 64 * 8 * 8]);
        let pool = g.max_pool_2d(x, 1, 64, 8, 8, 2, 2, 2, 0);
        let gap = g.global_avg_pool(pool, 1, 64, 16);
        g.set_outputs(vec![gap]);

        let (opt, _report) = optimize_with_report(&g);
        let has_maxpool = opt
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::MaxPool2d { .. }));
        let has_gap = opt
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::GlobalAvgPool { .. }));
        assert!(has_maxpool, "MaxPool2d should survive optimization");
        assert!(has_gap, "GlobalAvgPool should survive optimization");
    }
}