use crate::graph::{Graph, Node, Op};
use egglog::{Term, TermDag, TermId};
use std::collections::HashMap;
use std::{fmt, time::Instant};
/// Diagnostics collected during a single `optimize_with_report` run.
pub struct OptimizeReport {
    /// Full egglog program text that was (or would have been) executed.
    pub egglog_program: String,
    /// E-classes in the saturated e-graph (0 when egglog was skipped or failed).
    pub num_eclasses: usize,
    /// E-nodes in the saturated e-graph (0 when egglog was skipped or failed).
    pub num_enodes: usize,
    /// `(rule name, fire count)` pairs, tallied from the applied fusions.
    pub rules_fired: Vec<(String, usize)>,
    /// Total node count of the input graph (including any `Nop` placeholders).
    pub nodes_before: usize,
    /// Active (non-`Nop`) node count after optimization.
    pub nodes_after: usize,
    /// Each applied fusion as `(label, id of the surviving fused node)`.
    pub fusions_applied: Vec<(String, u32)>,
    /// Wall time spent in egglog saturation.
    pub egglog_time: std::time::Duration,
    /// Wall time spent rebuilding the graph / applying fusions.
    pub extract_time: std::time::Duration,
}
impl fmt::Display for OptimizeReport {
    /// Renders a multi-line, human-readable summary of the optimization run.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "=== Optimization Report ===")?;
        let saturation_ms = self.egglog_time.as_secs_f64() * 1000.0;
        writeln!(
            f,
            "Egglog saturation: {:.1}ms ({} e-classes, {} e-nodes)",
            saturation_ms, self.num_eclasses, self.num_enodes,
        )?;
        if !self.rules_fired.is_empty() {
            writeln!(f, "Rules fired:")?;
            for (rule, count) in &self.rules_fired {
                writeln!(f, " {} x{}", rule, count)?;
            }
        }
        let fused_away = self.nodes_before.saturating_sub(self.nodes_after);
        writeln!(
            f,
            "Graph: {} nodes -> {} active nodes ({} fused away)",
            self.nodes_before, self.nodes_after, fused_away,
        )?;
        if !self.fusions_applied.is_empty() {
            write!(f, "Fusions:")?;
            // Comma-separate the entries; the first one has no preceding comma.
            let mut first = true;
            for (name, node_idx) in &self.fusions_applied {
                if !first {
                    write!(f, ",")?;
                }
                first = false;
                write!(f, " {} @node{}", name, node_idx)?;
            }
            writeln!(f)?;
        }
        // Final line intentionally has no trailing newline.
        write!(
            f,
            "Extract time: {:.1}ms",
            self.extract_time.as_secs_f64() * 1000.0
        )
    }
}
/// Optimize `graph`, discarding the diagnostic report.
pub fn optimize(graph: &Graph) -> Graph {
    optimize_with_report(graph).0
}
/// Tally fusion applications into `(rule name, fire count)` pairs,
/// preserving first-seen order.
fn tally_rule_firings(fusions: &[(String, u32)]) -> Vec<(String, usize)> {
    let mut tally: Vec<(String, usize)> = Vec::new();
    for (name, _) in fusions {
        if let Some(entry) = tally.iter_mut().find(|e| e.0 == *name) {
            entry.1 += 1;
        } else {
            tally.push((name.clone(), 1));
        }
    }
    tally
}

/// Number of nodes that still do real work (i.e. are not `Nop`).
fn count_active_nodes(graph: &Graph) -> usize {
    graph
        .nodes()
        .iter()
        .filter(|n| !matches!(n.op, Op::Nop))
        .count()
}

/// Optimize `graph` and return the optimized graph plus a diagnostic
/// [`OptimizeReport`].
///
/// Graphs with more than 300 active nodes skip egglog saturation entirely
/// (see the fallback branch) and rely on the direct pattern-matching fusion
/// passes in `rebuild_graph_from_extractions`.
pub fn optimize_with_report(graph: &Graph) -> (Graph, OptimizeReport) {
    let program = graph_to_egglog(graph);
    log::debug!("egglog program:\n{}", program);
    let nodes_before = graph.nodes().len();
    let node_count = count_active_nodes(graph);
    let egglog_start = Instant::now();
    if node_count > 300 {
        log::debug!(
            "egglog: {} nodes, falling back to pattern matching",
            node_count
        );
        let extract_start = Instant::now();
        let (optimized, fusions_applied) = rebuild_graph_from_extractions(graph, &[]);
        let extract_time = extract_start.elapsed();
        // BUG FIX: `nodes_after` was previously hard-coded to 0 on this path,
        // which made the report claim every node had been fused away.
        let nodes_after = count_active_nodes(&optimized);
        return (
            optimized,
            OptimizeReport {
                egglog_program: program,
                num_eclasses: 0,
                num_enodes: 0,
                rules_fired: tally_rule_firings(&fusions_applied),
                nodes_before,
                nodes_after,
                fusions_applied,
                egglog_time: std::time::Duration::ZERO,
                extract_time,
            },
        );
    }
    let mut egraph = egglog::EGraph::default();
    let egglog_result = egraph.parse_and_run_program(None, &program);
    log::debug!(
        "egglog: saturation took {:.1}ms",
        egglog_start.elapsed().as_secs_f64() * 1000.0
    );
    let egglog_ok;
    let mut extractions: Vec<(TermDag, TermId)> = Vec::new();
    match egglog_result {
        Ok(outputs) => {
            egglog_ok = true;
            // Collect every (extract ...) result; each one is a best-cost
            // term for one graph output.
            for out in &outputs {
                if let egglog::CommandOutput::ExtractBest(ref dag, _cost, term_id) = *out {
                    log::debug!("egglog extracted: {}", dag.to_string(term_id));
                    extractions.push((dag.clone(), term_id));
                }
            }
        }
        Err(e) => {
            // Egglog failure is non-fatal: the pattern passes below still
            // run on the original graph.
            log::warn!(
                "egglog optimization failed: {}, returning original graph",
                e
            );
            egglog_ok = false;
        }
    };
    let egglog_time = egglog_start.elapsed();
    // E-graph statistics are only meaningful when saturation succeeded.
    let (num_eclasses, num_enodes) = if egglog_ok {
        let serialized = egraph.serialize(egglog::SerializeConfig::default());
        (
            serialized.egraph.class_data.len(),
            serialized.egraph.nodes.len(),
        )
    } else {
        (0, 0)
    };
    let extract_start = Instant::now();
    let (optimized, fusions_applied) = rebuild_graph_from_extractions(graph, &extractions);
    let extract_time = extract_start.elapsed();
    let nodes_after = count_active_nodes(&optimized);
    let report = OptimizeReport {
        egglog_program: program,
        num_eclasses,
        num_enodes,
        rules_fired: tally_rule_firings(&fusions_applied),
        nodes_before,
        nodes_after,
        fusions_applied,
        egglog_time,
        extract_time,
    };
    (optimized, report)
}
/// Return the egglog program that `optimize` would run for `graph`,
/// without executing it. Useful for debugging the rewrite rules.
pub fn dump_egglog_program(graph: &Graph) -> String {
    graph_to_egglog(graph)
}
/// Encode `graph` as a self-contained egglog program: the `Op` datatype,
/// the rewrite rules, one `(let nID ...)` binding per active node, a single
/// rewrite iteration, and one `(extract ...)` per non-Nop graph output.
fn graph_to_egglog(graph: &Graph) -> String {
    let mut prog = String::new();
    // Datatype declaration: one constructor per graph op we encode.
    // Op attributes (shapes, eps, head counts, ...) are not part of the
    // encoding; egglog reasons about structure only.
    prog.push_str(
        "\
(datatype Op
; --- Leaf nodes ---
(Input String)
(Parameter String)
(Const i64)
; --- Forward matmul variants ---
(MatMul Op Op)
(MatMulAT Op Op)
(MatMulBT Op Op)
; --- Fused matmul+add (targets for fusion rules) ---
(FusedMatMulAdd Op Op Op)
(FusedMatMulATAdd Op Op Op)
(FusedMatMulBTAdd Op Op Op)
; --- Element-wise ---
(Add Op Op)
(Mul Op Op)
(BiasAdd Op Op)
(Relu Op)
(Sigmoid Op)
(Tanh Op)
(Neg Op)
(Abs Op)
(Log Op)
(Recip Op)
(ScatterAdd i64 Op Op)
(Silu Op)
(Gelu Op)
(Identity Op)
; --- Shape / reduction ---
(Transpose Op)
(Softmax Op)
(LogSoftmax Op)
(SumAll Op)
(MeanAll Op)
(SumRows Op)
(CrossEntropyLoss Op Op)
(BceLoss Op Op)
(Greater Op Op)
; --- Transformer forward ---
(SwiGLU Op Op)
(SwiGLUConcat Op)
(RmsNorm Op Op)
(FusedRmsNormMatMul Op Op Op)
(Embedding Op Op)
(RoPE Op)
(RoPEGrad Op)
(CausalAttention Op Op Op)
(SlidingWindowAttention Op Op Op)
(LayerNorm Op Op Op)
(FullAttention Op Op Op)
(CrossAttention Op Op Op)
(MultiHeadAttn Op Op Op)
; --- GroupNorm, Concat, Upsample, Conv2d ops ---
(GroupNorm Op Op Op)
(GroupNormSilu Op Op Op)
(GroupNormGradInput Op Op Op)
(GroupNormGradWeightBias Op Op)
(Concat Op Op)
(SplitA Op)
(SplitB Op)
(Upsample2x Op)
(Upsample2xGrad Op)
(Conv2d Op Op)
(Conv2dGradInput Op Op)
(Conv2dGradWeight Op Op)
(MaxPool2d Op)
(GlobalAvgPool Op)
(GlobalAvgPoolGrad Op)
; --- KV cache ops ---
(CacheWrite Op Op Op)
(CachedAttention Op Op Op Op)
; --- Backward / gradient ops ---
(SiluGrad Op Op)
(SwiGLUGradGate Op Op Op)
(SwiGLUGradUp Op Op)
(SwiGLUConcatGrad Op Op)
(RmsNormGradW Op Op Op)
(RmsNormGradX Op Op Op)
(LayerNormGradWB Op Op Op)
(LayerNormGradX Op Op Op)
(MHAGradQ Op Op Op Op)
(MHAGradK Op Op Op Op)
(MHAGradV Op Op Op Op)
)
",
    );
    // Rewrite rules. Note the deliberately duplicated Add patterns in place
    // of a general commutativity rule (see the comment inside the program).
    prog.push_str(
        "\
; --- Algebraic simplifications ---
(rewrite (Neg (Neg ?x)) ?x)
(rewrite (Transpose (Transpose ?x)) ?x)
(rewrite (Relu (Relu ?x)) (Relu ?x))
; --- Kernel fusion: Add(MatMul*(a,b), d) → FusedMatMul*Add(a,b,d) ---
; Both argument orders handled explicitly (no general Add commutativity
; rule, which causes exponential blowup on large graphs).
(rewrite (Add (MatMul ?a ?b) ?d) (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add ?d (MatMul ?a ?b)) (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add (MatMulAT ?a ?b) ?d) (FusedMatMulATAdd ?a ?b ?d))
(rewrite (Add ?d (MatMulAT ?a ?b)) (FusedMatMulATAdd ?a ?b ?d))
(rewrite (Add (MatMulBT ?a ?b) ?d) (FusedMatMulBTAdd ?a ?b ?d))
(rewrite (Add ?d (MatMulBT ?a ?b)) (FusedMatMulBTAdd ?a ?b ?d))
; --- RmsNorm+MatMul fusion ---
(rewrite (MatMul (RmsNorm ?x ?w_norm) ?w_proj) (FusedRmsNormMatMul ?x ?w_norm ?w_proj))
; --- SwiGLU fusion: two matmuls sharing input → single wide matmul ---
; SwiGLU(MatMul(h, w1), MatMul(h, w2)) can use SwiGLUConcat on a
; concatenated [h, w1|w2] matmul. Pattern matcher handles weight
; concatenation since egglog can't create new tensors.
; (documented here; applied by apply_swiglu_concat_fusions)
; --- ONNX decomposed op recognition ---
; PyTorch decomposes compound ops when exporting to ONNX.
; These rules recognize the decomposed patterns and fuse them back
; into our efficient compound kernels.
; Silu: x * sigmoid(x) → Silu(x)
(rewrite (Mul ?x (Sigmoid ?x)) (Silu ?x))
(rewrite (Mul (Sigmoid ?x) ?x) (Silu ?x))
; SwiGLU: silu(gate) * up → SwiGLU(gate, up)
(rewrite (Mul (Silu ?gate) ?up) (SwiGLU ?gate ?up))
",
    );
    // Bind every active (non-Nop) node to an egglog variable named after
    // its node id.
    for node in graph.nodes() {
        if matches!(node.op, Op::Nop) {
            continue;
        }
        let expr = node_to_egglog_expr(node);
        prog.push_str(&format!("(let n{} {})\n", node.id, expr));
    }
    // Single rewrite iteration keeps saturation cheap; multi-step fusion
    // chains are handled afterwards by the pattern passes in
    // rebuild_graph_from_extractions.
    prog.push_str("(run 1)\n\n");
    // Ask egglog for the best (lowest-cost) form of each live output.
    for &out in graph.outputs() {
        if !matches!(graph.node(out).op, Op::Nop) {
            prog.push_str(&format!("(extract n{})\n", out));
        }
    }
    prog
}
/// Render a single node as an egglog expression, referencing its inputs by
/// their `nID` bindings.
///
/// Most op attributes are dropped from the encoding (egglog matches on
/// structure only); `ScatterAdd` keeps `vocab_size` as an i64 payload.
fn node_to_egglog_expr(node: &Node) -> String {
    let i = &node.inputs;
    match node.op {
        Op::Input { ref name } => format!("(Input \"{}\")", name),
        Op::Parameter { ref name } => format!("(Parameter \"{}\")", name),
        // Use the node id as the Const payload so distinct constants
        // encode as distinct terms and never merge in the e-graph.
        Op::Constant { .. } => format!("(Const {})", node.id),
        Op::MatMul => format!("(MatMul n{} n{})", i[0], i[1]),
        Op::MatMulAT => format!("(MatMulAT n{} n{})", i[0], i[1]),
        Op::MatMulBT => format!("(MatMulBT n{} n{})", i[0], i[1]),
        Op::Add => format!("(Add n{} n{})", i[0], i[1]),
        Op::Mul => format!("(Mul n{} n{})", i[0], i[1]),
        Op::BiasAdd => format!("(BiasAdd n{} n{})", i[0], i[1]),
        Op::Relu => format!("(Relu n{})", i[0]),
        Op::Sigmoid => format!("(Sigmoid n{})", i[0]),
        Op::Tanh => format!("(Tanh n{})", i[0]),
        Op::Neg => format!("(Neg n{})", i[0]),
        Op::Abs => format!("(Abs n{})", i[0]),
        Op::Log => format!("(Log n{})", i[0]),
        Op::Recip => format!("(Recip n{})", i[0]),
        Op::ScatterAdd { vocab_size } => {
            format!("(ScatterAdd {} n{} n{})", vocab_size, i[0], i[1])
        }
        Op::Transpose => format!("(Transpose n{})", i[0]),
        Op::Softmax => format!("(Softmax n{})", i[0]),
        Op::LogSoftmax => format!("(LogSoftmax n{})", i[0]),
        Op::SumAll => format!("(SumAll n{})", i[0]),
        Op::MeanAll => format!("(MeanAll n{})", i[0]),
        Op::SumRows => format!("(SumRows n{})", i[0]),
        Op::CrossEntropyLoss => format!("(CrossEntropyLoss n{} n{})", i[0], i[1]),
        Op::BceLoss => format!("(BceLoss n{} n{})", i[0], i[1]),
        Op::Greater => format!("(Greater n{} n{})", i[0], i[1]),
        Op::Silu => format!("(Silu n{})", i[0]),
        Op::SwiGLU => format!("(SwiGLU n{} n{})", i[0], i[1]),
        Op::SwiGLUConcat => format!("(SwiGLUConcat n{})", i[0]),
        Op::Gelu => format!("(Gelu n{})", i[0]),
        Op::RmsNorm { .. } => format!("(RmsNorm n{} n{})", i[0], i[1]),
        Op::Embedding => format!("(Embedding n{} n{})", i[0], i[1]),
        Op::RoPE { .. } => format!("(RoPE n{})", i[0]),
        Op::RoPEGrad { .. } => format!("(RoPEGrad n{})", i[0]),
        // Plain and RoPE-fused causal attention share one encoding; the
        // distinction lives only in the node attributes.
        Op::CausalAttention { .. } | Op::CausalAttentionRoPE { .. } => {
            format!("(CausalAttention n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::SlidingWindowAttention { .. } => {
            format!("(SlidingWindowAttention n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::LayerNorm { .. } => format!("(LayerNorm n{} n{} n{})", i[0], i[1], i[2]),
        Op::FullAttention { .. } => format!("(FullAttention n{} n{} n{})", i[0], i[1], i[2]),
        Op::CrossAttention { .. } => format!("(CrossAttention n{} n{} n{})", i[0], i[1], i[2]),
        Op::MultiHeadAttn { .. } => format!("(MultiHeadAttn n{} n{} n{})", i[0], i[1], i[2]),
        Op::SiluGrad => format!("(SiluGrad n{} n{})", i[0], i[1]),
        Op::SwiGLUGradGate => format!("(SwiGLUGradGate n{} n{} n{})", i[0], i[1], i[2]),
        Op::SwiGLUGradUp => format!("(SwiGLUGradUp n{} n{})", i[0], i[1]),
        Op::SwiGLUConcatGrad => format!("(SwiGLUConcatGrad n{} n{})", i[0], i[1]),
        Op::RmsNormGradW { .. } => format!("(RmsNormGradW n{} n{} n{})", i[0], i[1], i[2]),
        Op::RmsNormGradX { .. } => format!("(RmsNormGradX n{} n{} n{})", i[0], i[1], i[2]),
        Op::LayerNormGradWB { .. } => format!("(LayerNormGradWB n{} n{} n{})", i[0], i[1], i[2]),
        Op::LayerNormGradX { .. } => format!("(LayerNormGradX n{} n{} n{})", i[0], i[1], i[2]),
        Op::MultiHeadAttnGradQ { .. } => {
            format!("(MHAGradQ n{} n{} n{} n{})", i[0], i[1], i[2], i[3])
        }
        Op::MultiHeadAttnGradK { .. } => {
            format!("(MHAGradK n{} n{} n{} n{})", i[0], i[1], i[2], i[3])
        }
        Op::MultiHeadAttnGradV { .. } => {
            format!("(MHAGradV n{} n{} n{} n{})", i[0], i[1], i[2], i[3])
        }
        Op::FusedMatMulAdd => {
            format!("(FusedMatMulAdd n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::FusedMatMulATAdd => {
            format!("(FusedMatMulATAdd n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::FusedMatMulBTAdd => {
            format!("(FusedMatMulBTAdd n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::FusedRmsNormMatMul { .. } => {
            format!("(FusedRmsNormMatMul n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::GroupNorm { .. } => format!("(GroupNorm n{} n{} n{})", i[0], i[1], i[2]),
        Op::GroupNormSilu { .. } => format!("(GroupNormSilu n{} n{} n{})", i[0], i[1], i[2]),
        Op::GroupNormGradInput { .. } => {
            format!("(GroupNormGradInput n{} n{} n{})", i[0], i[1], i[2])
        }
        Op::GroupNormGradWeightBias { .. } => {
            format!("(GroupNormGradWeightBias n{} n{})", i[0], i[1])
        }
        Op::Concat { .. } => format!("(Concat n{} n{})", i[0], i[1]),
        Op::SplitA { .. } => format!("(SplitA n{})", i[0]),
        Op::SplitB { .. } => format!("(SplitB n{})", i[0]),
        Op::Upsample2x { .. } => format!("(Upsample2x n{})", i[0]),
        Op::Upsample2xGrad { .. } => format!("(Upsample2xGrad n{})", i[0]),
        Op::Conv2d { .. } => format!("(Conv2d n{} n{})", i[0], i[1]),
        Op::Conv2dGradInput { .. } => format!("(Conv2dGradInput n{} n{})", i[0], i[1]),
        Op::Conv2dGradWeight { .. } => format!("(Conv2dGradWeight n{} n{})", i[0], i[1]),
        Op::MaxPool2d { .. } => format!("(MaxPool2d n{})", i[0]),
        Op::GlobalAvgPool { .. } => format!("(GlobalAvgPool n{})", i[0]),
        Op::GlobalAvgPoolGrad { .. } => format!("(GlobalAvgPoolGrad n{})", i[0]),
        Op::CacheWrite => format!("(CacheWrite n{} n{} n{})", i[0], i[1], i[2]),
        Op::CachedAttention { .. } => {
            format!("(CachedAttention n{} n{} n{} n{})", i[0], i[1], i[2], i[3])
        }
        // Callers filter Nop nodes before encoding (see graph_to_egglog).
        Op::Nop => unreachable!("Nop nodes are filtered before encoding"),
        Op::Identity => format!("(Identity n{})", i[0]),
    }
}
/// Clone `original` and repeatedly run the pattern-based fusion passes
/// until none of them fires, scanning any egglog extraction results first.
fn rebuild_graph_from_extractions(
    original: &Graph,
    extractions: &[(TermDag, TermId)],
) -> (Graph, Vec<(String, u32)>) {
    let mut graph = clone_graph(original);
    let mut fusions: Vec<(String, u32)> = Vec::new();
    if !extractions.is_empty() {
        // Index nodes by their egglog encoding so extracted terms can be
        // matched back to concrete graph nodes.
        let mut node_lookup: HashMap<String, Vec<usize>> = HashMap::new();
        for node in graph.nodes() {
            node_lookup
                .entry(egglog_key(node))
                .or_default()
                .push(node.id as usize);
        }
        for (dag, root) in extractions {
            scan_fusions(dag, *root, &graph, &node_lookup, &mut fusions);
        }
    }
    // Fixpoint: one pass may expose work for another (e.g. a Silu produced
    // by apply_silu_fusions can then feed apply_swiglu_fusions).
    loop {
        let before = fusions.len();
        apply_matmul_add_fusions(&mut graph, &mut fusions);
        apply_silu_fusions(&mut graph, &mut fusions);
        apply_swiglu_fusions(&mut graph, &mut fusions);
        apply_swiglu_concat_fusions(&mut graph, &mut fusions);
        if fusions.len() == before {
            break;
        }
    }
    let active_nodes = graph
        .nodes()
        .iter()
        .filter(|n| !matches!(n.op, Op::Nop))
        .count();
    log::info!(
        "optimizer: {} fusions on {} nodes",
        fusions.len(),
        active_nodes
    );
    // Log a per-rule tally, sorted by rule name for stable output.
    let mut per_rule = std::collections::BTreeMap::<&str, usize>::new();
    for (name, _) in &fusions {
        *per_rule.entry(name.as_str()).or_default() += 1;
    }
    for (name, count) in per_rule {
        log::info!(" {}x {}", count, name);
    }
    (graph, fusions)
}
/// Build a stable string key for matching egglog terms back to graph nodes:
/// the op identity (leaf name, or the discriminant for everything else)
/// combined with the node's input ids.
fn egglog_key(node: &Node) -> String {
    let op_part = match node.op {
        Op::Input { ref name } => format!("Input:{}", name),
        Op::Parameter { ref name } => format!("Parameter:{}", name),
        // Constants are keyed by node id, mirroring node_to_egglog_expr.
        Op::Constant { .. } => format!("Const:{}", node.id),
        ref other => format!("{:?}", std::mem::discriminant(other)),
    };
    format!("{}:{:?}", op_part, node.inputs)
}
/// Recursively walk an extracted egglog term and report fusion discoveries.
///
/// Currently this only logs fused operator names found in the term; the
/// `_graph`, `_lookup` and `_fusions` parameters are threaded through but
/// unused — presumably reserved for mapping terms back onto concrete graph
/// nodes (the actual rewrites happen in the `apply_*_fusions` passes).
fn scan_fusions(
    dag: &TermDag,
    term_id: TermId,
    _graph: &Graph,
    _lookup: &HashMap<String, Vec<usize>>,
    _fusions: &mut Vec<(String, u32)>,
) {
    if let Term::App(name, children) = dag.get(term_id).clone() {
        // A fused constructor in the extracted term means a rewrite rule
        // fired for this subtree.
        if name.starts_with("FusedMatMul") || name.starts_with("FusedRmsNorm") {
            log::debug!("egglog discovered fusion: {}", name);
        }
        for child in children {
            scan_fusions(dag, child, _graph, _lookup, _fusions);
        }
    }
}
/// Fuse `Add(MatMul*(a, b), d)` — with the matmul on either side of the
/// Add — into a single `FusedMatMul*Add(a, b, d)` node, covering the plain,
/// A-transposed and B-transposed matmul variants. The matmul node is
/// replaced by `Nop` and the Add node becomes the fused op in place.
fn apply_matmul_add_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_ids: Vec<usize> = (0..graph.nodes().len()).collect();
    for &id in &node_ids {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::Add) {
            continue;
        }
        let (lhs, rhs) = (node.inputs[0], node.inputs[1]);
        // Accept the matmul on either operand of the Add.
        let (mm_id, addend_id) =
            if matches!(graph.node(lhs).op, Op::MatMul | Op::MatMulAT | Op::MatMulBT) {
                (lhs, rhs)
            } else if matches!(graph.node(rhs).op, Op::MatMul | Op::MatMulAT | Op::MatMulBT) {
                (rhs, lhs)
            } else {
                continue;
            };
        // Only fuse when this Add is the matmul's sole live consumer;
        // otherwise the matmul result is still needed elsewhere.
        // NOTE(review): graph outputs are not checked here — presumably an
        // output node is never a single-use intermediate; verify.
        let mm_use_count = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&mm_id) && !matches!(n.op, Op::Nop))
            .count();
        if mm_use_count != 1 {
            continue;
        }
        let mm_node = graph.node(mm_id);
        let (a, b) = (mm_node.inputs[0], mm_node.inputs[1]);
        let (fused_op, label) = match mm_node.op {
            Op::MatMul => (Op::FusedMatMulAdd, "MatMul+Add→FusedMatMulAdd"),
            Op::MatMulAT => (Op::FusedMatMulATAdd, "MatMulAT+Add→FusedMatMulATAdd"),
            Op::MatMulBT => (Op::FusedMatMulBTAdd, "MatMulBT+Add→FusedMatMulBTAdd"),
            // Guarded by the matches! checks above.
            _ => unreachable!(),
        };
        // Rewrite the Add in place and retire the matmul.
        graph.nodes_mut()[id].op = fused_op;
        graph.nodes_mut()[id].inputs = vec![a, b, addend_id];
        graph.nodes_mut()[mm_id as usize].op = Op::Nop;
        fusions.push((label.to_string(), id as u32));
    }
}
/// Fuse `SwiGLU(MatMul(h, w_gate), MatMul(h, w_up))` into
/// `SwiGLUConcat(MatMul(h, w_gate|w_up))`: two projections sharing the same
/// activation become one wide matmul over a concatenated weight.
///
/// Egglog cannot express this rewrite because it has to materialize a new
/// tensor (the concatenated weight), so it is done here: a `DerivedParam`
/// entry records how to assemble the wide weight from its two sources.
fn apply_swiglu_concat_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    use crate::graph::TensorType;
    let node_ids: Vec<usize> = (0..graph.nodes().len()).collect();
    for &id in &node_ids {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::SwiGLU) {
            continue;
        }
        let (gate_id, up_id) = (node.inputs[0], node.inputs[1]);
        let gate_node = graph.node(gate_id);
        let up_node = graph.node(up_id);
        // Both SwiGLU operands must be matmuls over the same activation.
        if !matches!(gate_node.op, Op::MatMul) || !matches!(up_node.op, Op::MatMul) {
            continue;
        }
        if gate_node.inputs[0] != up_node.inputs[0] {
            continue;
        }
        // Each matmul must feed only this SwiGLU, or its result is still
        // needed elsewhere and cannot be folded away.
        let gate_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&gate_id) && !matches!(n.op, Op::Nop))
            .count();
        let up_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&up_id) && !matches!(n.op, Op::Nop))
            .count();
        if gate_uses != 1 || up_uses != 1 {
            continue;
        }
        let h = gate_node.inputs[0];
        let w_gate = gate_node.inputs[1];
        let w_up = up_node.inputs[1];
        // Column concatenation requires both weights to be 2-D and
        // identically shaped.
        let gate_shape = &graph.node(w_gate).ty.shape;
        let up_shape = &graph.node(w_up).ty.shape;
        if gate_shape.len() != 2 || up_shape.len() != 2 {
            continue;
        }
        if gate_shape[0] != up_shape[0] || gate_shape[1] != up_shape[1] {
            continue;
        }
        let in_features = gate_shape[0];
        let out_features = gate_shape[1];
        let concat_shape = vec![in_features, 2 * out_features];
        // Fall back to placeholder names when a weight is not a Parameter
        // (e.g. produced by another op).
        let gate_name = match graph.node(w_gate).op {
            Op::Parameter { ref name } => name.clone(),
            _ => "w_gate".to_string(),
        };
        let up_name = match graph.node(w_up).op {
            Op::Parameter { ref name } => name.clone(),
            _ => "w_up".to_string(),
        };
        let concat_name = format!("{}+{}", gate_name, up_name);
        // Record how to build the wide weight: gate columns first, then up.
        graph.derived_params.push(crate::graph::DerivedParam {
            name: concat_name.clone(),
            sources: vec![(gate_name, out_features), (up_name, out_features)],
            rows: in_features,
        });
        // New parameter + wide matmul nodes are appended to the graph.
        let concat_w = graph.add_raw_node(
            Op::Parameter { name: concat_name },
            vec![],
            TensorType::f32(concat_shape.clone()),
        );
        let m = graph.node(h).ty.shape[0];
        let wide_mm = graph.add_raw_node(
            Op::MatMul,
            vec![h, concat_w],
            TensorType::f32(vec![m, 2 * out_features]),
        );
        // Rewrite the SwiGLU in place and retire both original matmuls.
        let swiglu_ty = TensorType::f32(vec![m, out_features]);
        graph.nodes_mut()[id].op = Op::SwiGLUConcat;
        graph.nodes_mut()[id].inputs = vec![wide_mm];
        graph.nodes_mut()[id].ty = swiglu_ty;
        graph.nodes_mut()[gate_id as usize].op = Op::Nop;
        graph.nodes_mut()[up_id as usize].op = Op::Nop;
        fusions.push((
            "SwiGLU(MatMul,MatMul)→SwiGLUConcat(MatMul)".to_string(),
            id as u32,
        ));
    }
}
/// Fuse `Silu(GroupNorm(x, w, b))` into a single `GroupNormSilu` node when
/// the Silu is the GroupNorm's only live consumer. The Silu node becomes
/// the fused op in place; the GroupNorm node is retired to `Nop`.
pub fn apply_group_norm_silu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    for id in 0..graph.nodes().len() {
        if !matches!(graph.nodes()[id].op, Op::Silu) {
            continue;
        }
        let gn_id = graph.nodes()[id].inputs[0];
        let gn_node = graph.node(gn_id);
        // The Silu's producer must be a GroupNorm; carry its attributes over.
        let (num_groups, eps, channels, spatial) = if let Op::GroupNorm {
            num_groups,
            eps,
            channels,
            spatial,
        } = gn_node.op
        {
            (num_groups, eps, channels, spatial)
        } else {
            continue;
        };
        // Only fuse when nothing else consumes the GroupNorm output.
        let consumers = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&gn_id) && !matches!(n.op, Op::Nop))
            .count();
        if consumers != 1 {
            continue;
        }
        let (x, w, b) = (gn_node.inputs[0], gn_node.inputs[1], gn_node.inputs[2]);
        graph.nodes_mut()[id].op = Op::GroupNormSilu {
            num_groups,
            eps,
            channels,
            spatial,
        };
        graph.nodes_mut()[id].inputs = vec![x, w, b];
        graph.nodes_mut()[gn_id as usize].op = Op::Nop;
        fusions.push(("GroupNorm+Silu→GroupNormSilu".to_string(), id as u32));
    }
}
/// Fuse `MatMul(RmsNorm(x, w_norm), w_proj)` into `FusedRmsNormMatMul`.
///
/// NOTE(review): currently unused (`dead_code`); the same rewrite exists as
/// an egglog rule — presumably this is kept for a pattern-matching-only
/// path. Confirm before removing.
#[allow(dead_code)]
fn apply_rms_norm_matmul_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    use crate::graph::TensorType;
    let node_ids: Vec<usize> = (0..graph.nodes().len()).collect();
    for &id in &node_ids {
        let node = &graph.nodes()[id];
        if !matches!(node.op, Op::MatMul) {
            continue;
        }
        let (norm_id, w_proj_id) = (node.inputs[0], node.inputs[1]);
        let norm_node = graph.node(norm_id);
        let eps = match norm_node.op {
            Op::RmsNorm { eps } => eps,
            _ => continue,
        };
        // Only fuse when this MatMul is the norm's sole live consumer.
        let norm_use_count = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&norm_id) && !matches!(n.op, Op::Nop))
            .count();
        if norm_use_count != 1 {
            continue;
        }
        let x = norm_node.inputs[0];
        let w_norm = norm_node.inputs[1];
        // Output shape: [rows of x, cols of w_proj].
        let x_shape = &graph.node(x).ty.shape;
        let w_proj_shape = &graph.node(w_proj_id).ty.shape;
        let m = x_shape[0];
        let n = w_proj_shape[1];
        // Rewrite the MatMul in place and retire the RmsNorm.
        graph.nodes_mut()[id].op = Op::FusedRmsNormMatMul { eps };
        graph.nodes_mut()[id].inputs = vec![x, w_norm, w_proj_id];
        graph.nodes_mut()[id].ty = TensorType::f32(vec![m, n]);
        graph.nodes_mut()[norm_id as usize].op = Op::Nop;
        fusions.push(("RmsNorm+MatMul→FusedRmsNormMatMul".to_string(), id as u32));
    }
}
/// Fuse `CausalAttention(RoPE(q), RoPE(k), v)` into `CausalAttentionRoPE`,
/// applying the rotary embedding inside the attention kernel.
///
/// NOTE(review): currently unused (`dead_code`) — presumably retained for
/// a transformer pipeline that enables it explicitly; confirm.
#[allow(dead_code)]
fn apply_rope_attention_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    let node_ids: Vec<usize> = (0..graph.nodes().len()).collect();
    for &id in &node_ids {
        let node = &graph.nodes()[id];
        let (num_heads, num_kv_heads, head_dim) = match node.op {
            Op::CausalAttention {
                num_heads,
                num_kv_heads,
                head_dim,
            } => (num_heads, num_kv_heads, head_dim),
            _ => continue,
        };
        let q_id = node.inputs[0];
        let k_id = node.inputs[1];
        let v_id = node.inputs[2];
        let q_node = graph.node(q_id);
        let k_node = graph.node(k_id);
        // Both q and k must come straight from RoPE nodes.
        let (q_theta, q_raw) = match q_node.op {
            Op::RoPE { theta, .. } => (theta, q_node.inputs[0]),
            _ => continue,
        };
        let (k_theta, k_raw) = match k_node.op {
            Op::RoPE { theta, .. } => (theta, k_node.inputs[0]),
            _ => continue,
        };
        // The fused kernel applies one theta to both projections.
        if q_theta != k_theta {
            continue;
        }
        // Only fuse when the attention is the sole consumer of each RoPE.
        let q_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&q_id) && !matches!(n.op, Op::Nop))
            .count();
        let k_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&k_id) && !matches!(n.op, Op::Nop))
            .count();
        if q_uses != 1 || k_uses != 1 {
            continue;
        }
        // Rewire the attention to the pre-RoPE tensors and retire both RoPEs.
        graph.nodes_mut()[id].op = Op::CausalAttentionRoPE {
            num_heads,
            num_kv_heads,
            head_dim,
            rope_theta: q_theta,
        };
        graph.nodes_mut()[id].inputs = vec![q_raw, k_raw, v_id];
        graph.nodes_mut()[q_id as usize].op = Op::Nop;
        graph.nodes_mut()[k_id as usize].op = Op::Nop;
        fusions.push((
            "CausalAttn(RoPE,RoPE)→CausalAttnRoPE".to_string(),
            id as u32,
        ));
    }
}
/// Rewrite `x * sigmoid(x)` (either operand order) into a single `Silu`
/// node when the Sigmoid has no other live consumers.
fn apply_silu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    for id in 0..graph.nodes().len() {
        if !matches!(graph.nodes()[id].op, Op::Mul) {
            continue;
        }
        let lhs = graph.nodes()[id].inputs[0];
        let rhs = graph.nodes()[id].inputs[1];
        // `cand` must be Sigmoid(src) for the pair to form x * sigmoid(x).
        let sigmoid_of = |cand, src| {
            matches!(graph.node(cand).op, Op::Sigmoid) && graph.node(cand).inputs[0] == src
        };
        let (x, sig_id) = if sigmoid_of(rhs, lhs) {
            (lhs, rhs)
        } else if sigmoid_of(lhs, rhs) {
            (rhs, lhs)
        } else {
            continue;
        };
        // Only fuse when this Mul is the Sigmoid's sole live consumer.
        let sigmoid_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&sig_id) && !matches!(n.op, Op::Nop))
            .count();
        if sigmoid_uses != 1 {
            continue;
        }
        graph.nodes_mut()[id].op = Op::Silu;
        graph.nodes_mut()[id].inputs = vec![x];
        graph.nodes_mut()[sig_id as usize].op = Op::Nop;
        fusions.push(("Mul+Sigmoid→Silu".to_string(), id as u32));
    }
}
/// Rewrite `silu(gate) * up` (either operand order) into a single `SwiGLU`
/// node when the Silu has no other live consumers.
fn apply_swiglu_fusions(graph: &mut Graph, fusions: &mut Vec<(String, u32)>) {
    for id in 0..graph.nodes().len() {
        if !matches!(graph.nodes()[id].op, Op::Mul) {
            continue;
        }
        let lhs = graph.nodes()[id].inputs[0];
        let rhs = graph.nodes()[id].inputs[1];
        // Whichever operand is a Silu supplies the gate; the other is `up`.
        // The left operand is checked first, matching `silu(gate) * up`.
        let (silu_id, up) = if matches!(graph.node(lhs).op, Op::Silu) {
            (lhs, rhs)
        } else if matches!(graph.node(rhs).op, Op::Silu) {
            (rhs, lhs)
        } else {
            continue;
        };
        let gate = graph.node(silu_id).inputs[0];
        // Only fuse when this Mul is the Silu's sole live consumer.
        let silu_uses = graph
            .nodes()
            .iter()
            .filter(|n| n.inputs.contains(&silu_id) && !matches!(n.op, Op::Nop))
            .count();
        if silu_uses != 1 {
            continue;
        }
        graph.nodes_mut()[id].op = Op::SwiGLU;
        graph.nodes_mut()[id].inputs = vec![gate, up];
        graph.nodes_mut()[silu_id as usize].op = Op::Nop;
        fusions.push(("Silu+Mul→SwiGLU".to_string(), id as u32));
    }
}
/// Deep-copy a graph node by node, preserving outputs and derived params.
fn clone_graph(graph: &Graph) -> Graph {
    let mut copy = Graph::new();
    // Nodes are re-added in order, so ids line up with the original.
    graph.nodes().iter().for_each(|node| {
        copy.add_raw_node(node.op.clone(), node.inputs.clone(), node.ty.clone());
    });
    copy.set_outputs(graph.outputs().to_vec());
    copy.derived_params = graph.derived_params.clone();
    copy
}
#[cfg(test)]
mod tests {
    use super::*;

    // A lone MatMul feeding Relu matches no fusion pattern; the output op
    // must remain Relu.
    #[test]
    fn test_no_fusion_cooperative_matrix() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 784]);
        let w = g.parameter("w", &[784, 128]);
        let mm = g.matmul(x, w);
        let h = g.relu(mm);
        g.set_outputs(vec![h]);
        let opt = optimize(&g);
        let output_id = opt.outputs()[0];
        let output_node = opt.node(output_id);
        assert!(
            matches!(output_node.op, Op::Relu),
            "expected Relu (no fusion), got {:?}",
            output_node.op
        );
    }

    // The report is produced and Display-renders even when nothing fuses.
    #[test]
    fn test_optimize_report() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 784]);
        let w1 = g.parameter("w1", &[784, 128]);
        let mm1 = g.matmul(x, w1);
        let h1 = g.relu(mm1);
        let w2 = g.parameter("w2", &[128, 10]);
        let mm2 = g.matmul(h1, w2);
        let h2 = g.relu(mm2);
        g.set_outputs(vec![h2]);
        let (_opt, report) = optimize_with_report(&g);
        assert!(report.fusions_applied.is_empty());
        let display = format!("{}", report);
        assert!(display.contains("Optimization Report"));
    }

    // The generated egglog program parses and runs cleanly.
    #[test]
    fn test_egglog_roundtrip() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 10]);
        let w = g.parameter("w", &[10, 5]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);
        let program = graph_to_egglog(&g);
        assert!(program.contains("(MatMul"));
        assert!(program.contains("(Input \"x\")"));
        let mut egraph = egglog::EGraph::default();
        egraph.parse_and_run_program(None, &program).unwrap();
    }

    // A hand-written program with the fusion rules: egglog's extract must
    // return the fused form as the best term.
    #[test]
    fn test_egglog_extract_returns_fused() {
        let mut egraph = egglog::EGraph::default();
        let outputs = egraph
            .parse_and_run_program(
                None,
                r#"
(datatype Op
(MatMul Op Op)
(MatMulBT Op Op)
(Add Op Op)
(FusedMatMulAdd Op Op Op)
(FusedMatMulBTAdd Op Op Op)
(Input String)
(Parameter String)
)
(rewrite (Add (MatMul ?a ?b) ?d) (FusedMatMulAdd ?a ?b ?d))
(rewrite (Add (MatMulBT ?a ?b) ?d) (FusedMatMulBTAdd ?a ?b ?d))
(rewrite (Add ?x ?y) (Add ?y ?x))
(let n0 (Input "x"))
(let n1 (Parameter "w"))
(let n2 (MatMul n0 n1))
(let n3 (Input "bias"))
(let n4 (Add n2 n3))
(run 10)
(extract n4)
"#,
            )
            .unwrap();
        let mut found_fused = false;
        for out in &outputs {
            if let egglog::CommandOutput::ExtractBest(dag, _cost, term_id) = out {
                let s = dag.to_string(*term_id);
                eprintln!("egglog extracted: {}", s);
                assert!(
                    s.contains("FusedMatMulAdd"),
                    "expected FusedMatMulAdd, got: {}",
                    s
                );
                match dag.get(*term_id) {
                    Term::App(name, _children) => {
                        assert_eq!(name, "FusedMatMulAdd");
                    }
                    other => panic!("expected App, got {:?}", other),
                }
                found_fused = true;
            }
        }
        assert!(found_fused, "no ExtractBest output found");
    }

    // Optimization must not change node count or the output op when no
    // pattern applies.
    #[test]
    fn test_optimize_preserves_graph() {
        let mut g = Graph::new();
        let a = g.input("a", &[4, 8]);
        let b = g.input("b", &[4, 8]);
        let sum = g.add(a, b);
        let neg = g.neg(sum);
        g.set_outputs(vec![neg]);
        let opt = optimize(&g);
        assert_eq!(opt.nodes().len(), g.nodes().len());
        let out = opt.node(opt.outputs()[0]);
        assert!(matches!(out.op, Op::Neg));
    }

    // dump_egglog_program exposes the full program text.
    #[test]
    fn test_dump_egglog_program() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let y = g.matmul(x, w);
        let _h = g.relu(y);
        g.set_outputs(vec![y]);
        let program = dump_egglog_program(&g);
        assert!(program.contains("(datatype Op"));
        assert!(program.contains("(extract n"));
    }

    // Smoke test: a graph touching many op kinds still encodes to a valid
    // egglog program.
    #[test]
    fn test_egglog_all_ops() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let _c = g.constant(vec![0.0; 32], &[4, 8]);
        let mm = g.matmul(x, w);
        let _a = g.add(mm, mm);
        let _m = g.mul(mm, mm);
        let b = g.parameter("b", &[4]);
        let _ba = g.bias_add(mm, b);
        let _r = g.relu(mm);
        let _s = g.sigmoid(mm);
        let _n = g.neg(mm);
        let _t = g.transpose(mm);
        let _sm = g.softmax(mm);
        let _lsm = g.log_softmax(mm);
        let sa = g.sum_all(mm);
        let _ma = g.mean_all(mm);
        let _gt = g.greater(mm, mm);
        let _cel = g.cross_entropy_loss(mm, mm);
        g.set_outputs(vec![sa]);
        let program = graph_to_egglog(&g);
        let mut egraph = egglog::EGraph::default();
        egraph.parse_and_run_program(None, &program).unwrap();
    }

    // clone_graph keeps ids, inputs, shapes and outputs intact.
    #[test]
    fn test_clone_graph_preserves_structure() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);
        let cloned = clone_graph(&g);
        assert_eq!(cloned.nodes().len(), g.nodes().len());
        assert_eq!(cloned.outputs(), g.outputs());
        for (a, b) in cloned.nodes().iter().zip(g.nodes().iter()) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.inputs, b.inputs);
            assert_eq!(a.ty.shape, b.ty.shape);
        }
    }

    // A bare MatMul must not be rewritten into anything else.
    #[test]
    fn test_matmul_stays_as_matmul() {
        let mut g = Graph::new();
        let x = g.input("x", &[2, 1024]);
        let w = g.parameter("w", &[1024, 64]);
        let y = g.matmul(x, w);
        g.set_outputs(vec![y]);
        let opt = optimize(&g);
        let output_id = opt.outputs()[0];
        assert!(
            matches!(opt.node(output_id).op, Op::MatMul),
            "expected MatMul, got {:?}",
            opt.node(output_id).op
        );
    }

    // Timing probe for saturation cost on growing synthetic chains; prints
    // timings to stderr, asserts nothing beyond "runs without error".
    #[test]
    fn test_egglog_scalability() {
        for n in [10, 50, 100, 200, 350] {
            let mut prog = String::from(
                "(datatype Op
(MatMul Op Op) (MatMulAT Op Op) (MatMulBT Op Op)
(Add Op Op) (Input String) (Parameter String)
(FusedMatMulAdd Op Op Op) (FusedMatMulATAdd Op Op Op) (FusedMatMulBTAdd Op Op Op)
)\n",
            );
            prog.push_str("(rewrite (Add (MatMul ?a ?b) ?d) (FusedMatMulAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMul ?a ?b)) (FusedMatMulAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add (MatMulAT ?a ?b) ?d) (FusedMatMulATAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMulAT ?a ?b)) (FusedMatMulATAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add (MatMulBT ?a ?b) ?d) (FusedMatMulBTAdd ?a ?b ?d))\n");
            prog.push_str("(rewrite (Add ?d (MatMulBT ?a ?b)) (FusedMatMulBTAdd ?a ?b ?d))\n");
            prog.push_str("(let n0 (Input \"x\"))\n(let n1 (Parameter \"w\"))\n");
            // Build a chain alternating matmul variants, each followed by an
            // Add back onto the previous link.
            for i in 1..n {
                let prev = (i - 1) * 2 + 2;
                match i % 3 {
                    0 => prog.push_str(&format!("(let n{} (MatMulAT n{} n1))\n", i * 2, prev - 1)),
                    1 => prog.push_str(&format!("(let n{} (MatMulBT n{} n1))\n", i * 2, prev - 1)),
                    _ => prog.push_str(&format!("(let n{} (MatMul n{} n1))\n", i * 2, prev - 1)),
                }
                prog.push_str(&format!(
                    "(let n{} (Add n{} n{}))\n",
                    i * 2 + 1,
                    i * 2,
                    prev - 1
                ));
            }
            prog.push_str("(run 1)\n");
            let last = (n - 1) * 2 + 1;
            prog.push_str(&format!("(extract n{})\n", last));
            let t0 = Instant::now();
            let mut egraph = egglog::EGraph::default();
            egraph.parse_and_run_program(None, &prog).unwrap();
            let elapsed = t0.elapsed();
            eprintln!(
                "egglog scalability: n={:>4} nodes -> {:>8.1}ms",
                n * 2,
                elapsed.as_secs_f64() * 1000.0
            );
        }
    }

    // End-to-end: MatMul followed by Add fuses into FusedMatMulAdd.
    #[test]
    fn test_egglog_discovers_matmul_add_fusion() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let b = g.input("bias", &[4, 4]);
        let mm = g.matmul(x, w);
        let out = g.add(mm, b);
        g.set_outputs(vec![out]);
        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::FusedMatMulAdd),
            "expected FusedMatMulAdd, got {:?}",
            output_node.op
        );
        assert!(!report.fusions_applied.is_empty());
    }

    // Two same-input matmuls feeding SwiGLU collapse into one wide matmul
    // plus SwiGLUConcat; the wide matmul's shape doubles the out features.
    #[test]
    fn test_swiglu_concat_fusion() {
        let mut g = Graph::new();
        let h = g.input("h", &[50, 720]);
        let w_gate = g.parameter("w_gate", &[720, 2048]);
        let w_up = g.parameter("w_up", &[720, 2048]);
        let gate = g.matmul(h, w_gate);
        let up = g.matmul(h, w_up);
        let out = g.swiglu(gate, up);
        g.set_outputs(vec![out]);
        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::SwiGLUConcat),
            "expected SwiGLUConcat, got {:?}",
            output_node.op
        );
        assert!(
            report
                .fusions_applied
                .iter()
                .any(|(name, _)| name.contains("SwiGLU")),
            "no SwiGLU fusion in report: {:?}",
            report.fusions_applied
        );
        let mm_id = output_node.inputs[0];
        let mm_node = opt.node(mm_id);
        assert!(matches!(mm_node.op, Op::MatMul));
        assert_eq!(mm_node.ty.shape, vec![50, 4096]);
    }

    // Transposed matmul variants (used by backward passes) must encode and
    // run through egglog.
    #[test]
    fn test_egglog_encodes_backward_ops() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let w = g.parameter("w", &[8, 4]);
        let at = g.add_raw_node(
            Op::MatMulAT,
            vec![x, x],
            crate::graph::TensorType::f32(vec![8, 8]),
        );
        let bt = g.add_raw_node(
            Op::MatMulBT,
            vec![x, w],
            crate::graph::TensorType::f32(vec![4, 8]),
        );
        g.set_outputs(vec![at, bt]);
        let program = graph_to_egglog(&g);
        assert!(program.contains("MatMulAT"), "MatMulAT not encoded");
        assert!(program.contains("MatMulBT"), "MatMulBT not encoded");
        let mut egraph = egglog::EGraph::default();
        egraph
            .parse_and_run_program(None, &program)
            .expect("egglog failed with backward ops");
    }

    // The B-transposed variant fuses with a following Add as well.
    #[test]
    fn test_egglog_discovers_backward_bt_add_fusion() {
        let mut g = Graph::new();
        let grad = g.input("grad", &[4, 8]);
        let w = g.parameter("w", &[4, 8]);
        let prev = g.input("prev_grad", &[4, 4]);
        let bt = g.add_raw_node(
            Op::MatMulBT,
            vec![grad, w],
            crate::graph::TensorType::f32(vec![4, 4]),
        );
        let out = g.add(bt, prev);
        g.set_outputs(vec![out]);
        let (opt, report) = optimize_with_report(&g);
        let output_node = opt.node(opt.outputs()[0]);
        assert!(
            matches!(output_node.op, Op::FusedMatMulBTAdd),
            "expected FusedMatMulBTAdd, got {:?}",
            output_node.op
        );
        assert!(
            report
                .fusions_applied
                .iter()
                .any(|(name, _)| name.contains("BT")),
            "no BT fusion in report"
        );
    }

    // x * sigmoid(x) collapses to a Silu node.
    #[test]
    fn test_silu_fusion() {
        let mut g = Graph::new();
        let x = g.input("x", &[4, 8]);
        let sig = g.sigmoid(x);
        let out = g.mul(x, sig);
        g.set_outputs(vec![out]);
        let (opt, report) = optimize_with_report(&g);
        let has_silu = opt.nodes().iter().any(|n| matches!(n.op, Op::Silu));
        assert!(
            has_silu,
            "expected Silu fusion, got nodes: {:?}",
            opt.nodes()
                .iter()
                .map(|n| format!("{:?}", n.op))
                .collect::<Vec<_>>()
        );
        assert!(
            !report.fusions_applied.is_empty() || has_silu,
            "no Silu fusion detected"
        );
    }

    // Fully decomposed ONNX-style pattern: (gate * sigmoid(gate)) * up must
    // chain the Silu and SwiGLU fusions into a SwiGLU node.
    #[test]
    fn test_swiglu_from_decomposed() {
        let mut g = Graph::new();
        let gate = g.input("gate", &[4, 8]);
        let up = g.input("up", &[4, 8]);
        let sig = g.sigmoid(gate);
        let silu = g.mul(gate, sig);
        let out = g.mul(silu, up);
        g.set_outputs(vec![out]);
        let (opt, _report) = optimize_with_report(&g);
        let has_swiglu = opt.nodes().iter().any(|n| matches!(n.op, Op::SwiGLU));
        assert!(
            has_swiglu,
            "expected SwiGLU fusion from decomposed silu*up, got nodes: {:?}",
            opt.nodes()
                .iter()
                .map(|n| format!("{:?}", n.op))
                .collect::<Vec<_>>()
        );
    }

    // Pooling ops have no fusion patterns and must pass through untouched.
    #[test]
    fn test_pool_ops_roundtrip() {
        let mut g = Graph::new();
        let x = g.input("x", &[1 * 64 * 8 * 8]);
        let pool = g.max_pool_2d(x, 1, 64, 8, 8, 2, 2, 2, 0);
        let gap = g.global_avg_pool(pool, 1, 64, 16);
        g.set_outputs(vec![gap]);
        let (opt, _report) = optimize_with_report(&g);
        let has_maxpool = opt
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::MaxPool2d { .. }));
        let has_gap = opt
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::GlobalAvgPool { .. }));
        assert!(has_maxpool, "MaxPool2d should survive optimization");
        assert!(has_gap, "GlobalAvgPool should survive optimization");
    }
}