use std::fmt;
use crate::error::{BlasError, BlasResult};
use crate::types::Transpose;
/// Operation computed by a [`GemmNode`].
#[derive(Debug, Clone, PartialEq)]
pub enum GemmOp {
/// Matrix product of the node's two inputs, yielding an `m x n` result with
/// inner dimension `k`.
Gemm,
/// Scaling of a single input node's output; created by `GemmGraph::add_scale`.
Scale {
/// Scalar factor of the scale operation.
alpha: f64,
},
/// Sum of two node outputs; `GemmGraph::add_add` requires both inputs to
/// have the same `m x n` shape.
Add,
/// Transpose of an input. NOTE(review): no builder in this file creates this
/// variant, and the fusion analysis never fuses through it.
Transpose,
}
/// Source of an operand consumed by a [`GemmNode`].
#[derive(Debug, Clone, PartialEq)]
pub enum NodeInput {
/// Operand supplied from outside the graph, identified by name
/// (`GemmGraph::chain` uses `"M0"`, `"M1"`, ...).
External {
/// Caller-chosen identifier of the external operand.
name: String,
},
/// Output of another node in the same graph.
NodeOutput {
/// Id (index in the graph's node list) of the producing node.
node_id: usize,
},
}
/// One operation in a [`GemmGraph`] together with its output shape and operands.
#[derive(Debug, Clone)]
pub struct GemmNode {
/// Position of this node in the graph's node list; assigned by the builders.
pub id: usize,
/// Operation this node performs.
pub op: GemmOp,
/// Rows of the node's output.
pub m: u32,
/// Columns of the node's output.
pub n: u32,
/// Inner (shared) dimension for `Gemm` nodes; the builders set it to 0 for
/// `Scale` and `Add` nodes.
pub k: u32,
/// Transpose flag of the first operand; the `GemmGraph` builders always use `NoTrans`.
pub transpose_a: Transpose,
/// Transpose flag of the second operand; the `GemmGraph` builders always use `NoTrans`.
pub transpose_b: Transpose,
/// Operands in order (`Gemm` and `Add` nodes get two, `Scale` nodes one).
pub inputs: Vec<NodeInput>,
}
/// Graph of GEMM-related operations, stored as an append-only node list.
#[derive(Debug, Clone)]
pub struct GemmGraph {
/// Nodes in insertion order; each node's `id` equals its index here.
nodes: Vec<GemmNode>,
/// Node whose result is the graph's final output, if one has been set.
output_node: Option<usize>,
}
impl GemmGraph {
#[must_use]
pub fn new() -> Self {
Self {
nodes: Vec::new(),
output_node: None,
}
}
pub fn add_gemm(
&mut self,
m: u32,
n: u32,
k: u32,
input_a: NodeInput,
input_b: NodeInput,
) -> usize {
let id = self.nodes.len();
self.nodes.push(GemmNode {
id,
op: GemmOp::Gemm,
m,
n,
k,
transpose_a: Transpose::NoTrans,
transpose_b: Transpose::NoTrans,
inputs: vec![input_a, input_b],
});
id
}
pub fn add_scale(&mut self, node_id: usize, alpha: f64) -> BlasResult<usize> {
let src = self
.nodes
.get(node_id)
.ok_or_else(|| BlasError::InvalidArgument(format!("node {node_id} not found")))?;
let (m, n) = (src.m, src.n);
let id = self.nodes.len();
self.nodes.push(GemmNode {
id,
op: GemmOp::Scale { alpha },
m,
n,
k: 0,
transpose_a: Transpose::NoTrans,
transpose_b: Transpose::NoTrans,
inputs: vec![NodeInput::NodeOutput { node_id }],
});
Ok(id)
}
pub fn add_add(&mut self, a_id: usize, b_id: usize) -> BlasResult<usize> {
let a = self
.nodes
.get(a_id)
.ok_or_else(|| BlasError::InvalidArgument(format!("node {a_id} not found")))?;
let b = self
.nodes
.get(b_id)
.ok_or_else(|| BlasError::InvalidArgument(format!("node {b_id} not found")))?;
if a.m != b.m || a.n != b.n {
return Err(BlasError::DimensionMismatch(format!(
"add nodes {a_id} ({}x{}) and {b_id} ({}x{}) differ",
a.m, a.n, b.m, b.n,
)));
}
let (m, n) = (a.m, a.n);
let id = self.nodes.len();
self.nodes.push(GemmNode {
id,
op: GemmOp::Add,
m,
n,
k: 0,
transpose_a: Transpose::NoTrans,
transpose_b: Transpose::NoTrans,
inputs: vec![
NodeInput::NodeOutput { node_id: a_id },
NodeInput::NodeOutput { node_id: b_id },
],
});
Ok(id)
}
pub fn set_output(&mut self, node_id: usize) -> BlasResult<()> {
if node_id >= self.nodes.len() {
return Err(BlasError::InvalidArgument(format!(
"node {node_id} does not exist"
)));
}
self.output_node = Some(node_id);
Ok(())
}
#[must_use]
pub fn output_node(&self) -> Option<usize> {
self.output_node
}
#[must_use]
pub fn nodes(&self) -> &[GemmNode] {
&self.nodes
}
#[must_use]
pub fn len(&self) -> usize {
self.nodes.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn chain(matrices: &[(u32, u32)]) -> BlasResult<Self> {
if matrices.len() < 2 {
return Err(BlasError::InvalidArgument(
"chain requires at least 2 matrices".into(),
));
}
for i in 0..matrices.len() - 1 {
if matrices[i].1 != matrices[i + 1].0 {
return Err(BlasError::DimensionMismatch(format!(
"matrix {} cols ({}) != matrix {} rows ({})",
i,
matrices[i].1,
i + 1,
matrices[i + 1].0,
)));
}
}
let mut graph = Self::new();
let (m0, k0) = (matrices[0].0, matrices[0].1);
let n0 = matrices[1].1;
let prev = graph.add_gemm(
m0,
n0,
k0,
NodeInput::External {
name: format!("M{}", 0),
},
NodeInput::External {
name: format!("M{}", 1),
},
);
let mut last = prev;
for (i, mat) in matrices.iter().enumerate().skip(2) {
let prev_node = &graph.nodes[last];
let m = prev_node.m;
let k = prev_node.n;
let n = mat.1;
last = graph.add_gemm(
m,
n,
k,
NodeInput::NodeOutput { node_id: last },
NodeInput::External {
name: format!("M{i}"),
},
);
}
graph.set_output(last)?;
Ok(graph)
}
fn consumers_of(&self, node_id: usize) -> Vec<usize> {
self.nodes
.iter()
.filter(|n| {
n.inputs.iter().any(|inp| match inp {
NodeInput::NodeOutput { node_id: id } => *id == node_id,
_ => false,
})
})
.map(|n| n.id)
.collect()
}
#[must_use]
pub fn fan_out(&self, node_id: usize) -> usize {
self.consumers_of(node_id).len()
}
}
impl Default for GemmGraph {
fn default() -> Self {
Self::new()
}
}
/// How a producer GEMM and its sole GEMM consumer could be combined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionStrategy {
/// Chosen when the producer's `m x n` output matches the consumer's `m x k`
/// operand shape. Counted as removing one GEMM call and the producer's
/// intermediate buffer.
EpilogueFusion,
/// Chosen when `m` and `n` of both GEMMs are at most 256. Also counted as
/// removing one GEMM call and the producer's intermediate buffer.
SharedMemoryReuse,
/// Fallback for other pairs; only marks the producer's stage as able to
/// overlap with the next one and does not reduce the GEMM call count.
StreamPipelining,
/// No fusion. NOTE(review): never produced by `FusionPass` in this file.
None,
}
/// A producer GEMM whose output is read only by `consumer` (also a GEMM),
/// together with the strategy selected for the pair.
#[derive(Debug, Clone)]
pub struct FusiblePair {
/// Node id of the producing GEMM.
pub producer: usize,
/// Node id of the single GEMM reading `producer`'s output.
pub consumer: usize,
/// Strategy picked by the fusion pass for this pair.
pub strategy: FusionStrategy,
}
/// Nodes sharing the same dependency depth in a [`FusionPlan`].
#[derive(Debug, Clone)]
pub struct FusionStage {
/// Dependency depth of the nodes in this stage (0 = no node inputs).
pub stage_index: u32,
/// Ids of the nodes at this depth, in ascending order.
pub nodes: Vec<usize>,
/// True when the producer of a `StreamPipelining` pair sits in this stage.
pub can_overlap_with_next: bool,
/// Estimated bytes of the non-output results produced in this stage
/// (`m * n * 4` per node).
pub intermediate_bytes: usize,
}
/// Result of `FusionPass::analyze`: dependency stages, fusible pairs and
/// GEMM call counts.
#[derive(Debug, Clone)]
pub struct FusionPlan {
/// Dependency stages, ordered by `stage_index`.
pub stages: Vec<FusionStage>,
/// Every producer/consumer GEMM pair found, whatever its strategy.
pub fused_pairs: Vec<FusiblePair>,
/// GEMM calls remaining after epilogue and shared-memory fusion.
pub total_gemm_calls: u32,
/// Number of `Gemm` nodes in the analysed graph.
pub original_gemm_calls: u32,
}
impl fmt::Display for FusionPlan {
    /// Renders a one-line summary followed by one line per stage and one line
    /// per fusible pair.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            stages,
            fused_pairs,
            total_gemm_calls,
            original_gemm_calls,
        } = self;
        writeln!(
            f,
            "FusionPlan: {} stages, {} -> {} GEMM calls ({} fused pairs)",
            stages.len(),
            original_gemm_calls,
            total_gemm_calls,
            fused_pairs.len(),
        )?;
        stages.iter().try_for_each(|stage| {
            let overlap_marker = if stage.can_overlap_with_next {
                " [overlap]"
            } else {
                ""
            };
            writeln!(
                f,
                " stage {}: nodes {:?}, {} intermediate bytes{}",
                stage.stage_index, stage.nodes, stage.intermediate_bytes, overlap_marker,
            )
        })?;
        fused_pairs.iter().try_for_each(|pair| {
            writeln!(
                f,
                " fused: {} -> {} ({:?})",
                pair.producer, pair.consumer, pair.strategy,
            )
        })
    }
}
/// Category of a fusion opportunity reported by `FusionPass::analyze_opportunities`.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionType {
/// A GEMM whose only consumer is an `Add` node. NOTE(review): `GemmOp` has no
/// activation variant, so the name reflects intended use, not detection.
GemmBiasActivation,
/// GEMM followed by layer normalisation. NOTE(review): never reported by the
/// analysis in this file; only speed-up and overhead constants exist for it.
GemmLayerNorm,
/// A GEMM whose only consumer is another GEMM.
ConsecutiveGemm,
/// A GEMM whose only consumer is a `Scale` node.
GemmScale,
}
impl fmt::Display for FusionType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::GemmBiasActivation => write!(f, "GemmBiasActivation"),
Self::GemmLayerNorm => write!(f, "GemmLayerNorm"),
Self::ConsecutiveGemm => write!(f, "ConsecutiveGemm"),
Self::GemmScale => write!(f, "GemmScale"),
}
}
}
/// A candidate fusion found by `FusionPass::analyze_opportunities`.
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
/// Ids of the nodes to fuse; the analysis emits `[producer, consumer]`.
pub node_ids: Vec<usize>,
/// Heuristic speed-up multiplier from `FusionPass::estimate_speedup`.
pub estimated_speedup: f32,
/// Kind of fusion this opportunity represents.
pub fusion_type: FusionType,
}
/// One entry of a [`FusedKernelPlan`].
#[derive(Debug, Clone)]
pub enum FusedOp {
/// A producer node combined with its consumer; holds copies of both nodes.
FusedGemmEpilogue {
/// The producing node.
gemm: GemmNode,
/// The consuming node (e.g. a `Scale` or `Add`) combined with the producer.
epilogue: GemmNode,
},
/// A node executed on its own, referenced by id.
Standalone {
/// Id of the node in the source graph.
node_id: usize,
},
}
/// Execution plan produced by `FusionPass::apply`.
#[derive(Debug, Clone)]
pub struct FusedKernelPlan {
/// The fused operation first (when one was formed), then the remaining
/// nodes as standalone operations in id order.
pub operations: Vec<FusedOp>,
/// Estimated bytes no longer materialised (the producer's `m * n * 4`).
pub memory_saved: usize,
/// Heuristic relative compute overhead, looked up per fusion type.
pub compute_overhead: f32,
}
/// Stateless namespace for the GEMM fusion analysis; all functionality is
/// provided as associated functions.
pub struct FusionPass;
impl FusionPass {
    /// Bytes per matrix element assumed by every memory estimate in this pass.
    ///
    /// NOTE(review): the graph carries no element type; 4 bytes presumably
    /// corresponds to single precision — confirm against the executor.
    const ELEMENT_BYTES: usize = 4;

    /// Largest `m` / `n` for which a GEMM pair is treated as small enough for
    /// [`FusionStrategy::SharedMemoryReuse`].
    const SMALL_DIM_THRESHOLD: u32 = 256;

    /// Analyses `graph` and returns a [`FusionPlan`]. The plan holds the
    /// producer/consumer GEMM pairs that can be fused, the dependency stages,
    /// and the GEMM call counts before and after fusion.
    ///
    /// # Errors
    ///
    /// Returns [`BlasError::InvalidArgument`] if:
    /// * the graph is empty;
    /// * a node input references a node id outside the graph;
    /// * node inputs form a cycle, so no stage ordering exists.
    pub fn analyze(graph: &GemmGraph) -> BlasResult<FusionPlan> {
        if graph.is_empty() {
            return Err(BlasError::InvalidArgument(
                "cannot analyse an empty graph".into(),
            ));
        }
        // Validate references before anything indexes by input id. A dangling
        // `NodeOutput` used to make stage construction panic despite the
        // `Result` return type.
        let levels = Self::compute_levels(graph)?;
        let fusible = Self::find_fusible_pairs(graph);
        let stages = Self::build_stages(graph, &fusible, &levels);
        let original_gemm_calls = graph
            .nodes()
            .iter()
            .filter(|n| n.op == GemmOp::Gemm)
            .count() as u32;
        let fused_count = fusible
            .iter()
            .filter(|p| Self::eliminates_launch(p.strategy))
            .count() as u32;
        let total_gemm_calls = original_gemm_calls.saturating_sub(fused_count);
        Ok(FusionPlan {
            stages,
            fused_pairs: fusible,
            total_gemm_calls,
            original_gemm_calls,
        })
    }

    /// Finds every GEMM whose output is read by exactly one other node, that
    /// node being a GEMM, and picks a [`FusionStrategy`] for each such pair.
    pub fn find_fusible_pairs(graph: &GemmGraph) -> Vec<FusiblePair> {
        graph
            .nodes()
            .iter()
            .filter(|node| node.op == GemmOp::Gemm)
            .filter_map(|node| {
                let consumer_id = Self::sole_consumer(graph, node.id)?;
                let consumer = &graph.nodes()[consumer_id];
                (consumer.op == GemmOp::Gemm).then(|| FusiblePair {
                    producer: node.id,
                    consumer: consumer_id,
                    strategy: Self::select_strategy(node, consumer),
                })
            })
            .collect()
    }

    /// Chooses how to fuse `producer` into `consumer`:
    /// * `EpilogueFusion` when the producer's `m x n` output matches the
    ///   consumer's `m x k` operand shape;
    /// * `SharedMemoryReuse` when both GEMMs have `m` and `n` at most
    ///   [`Self::SMALL_DIM_THRESHOLD`];
    /// * `StreamPipelining` otherwise.
    fn select_strategy(producer: &GemmNode, consumer: &GemmNode) -> FusionStrategy {
        if producer.m == consumer.m && producer.n == consumer.k {
            return FusionStrategy::EpilogueFusion;
        }
        let is_small = |node: &GemmNode| {
            node.m <= Self::SMALL_DIM_THRESHOLD && node.n <= Self::SMALL_DIM_THRESHOLD
        };
        if is_small(producer) && is_small(consumer) {
            FusionStrategy::SharedMemoryReuse
        } else {
            FusionStrategy::StreamPipelining
        }
    }

    /// Estimates the bytes of all intermediate results: the `m x n` output of
    /// every node except the designated output node.
    pub fn estimate_intermediate_memory(graph: &GemmGraph) -> usize {
        let output = graph.output_node();
        graph
            .nodes()
            .iter()
            .filter(|n| Some(n.id) != output)
            .map(Self::output_bytes)
            .sum()
    }

    /// Like [`Self::estimate_intermediate_memory`], but it also drops
    /// producers whose pair in `plan` uses a strategy that eliminates their
    /// intermediate buffer.
    pub fn estimate_fused_memory(graph: &GemmGraph, plan: &FusionPlan) -> usize {
        let output = graph.output_node();
        let eliminated: std::collections::HashSet<usize> = plan
            .fused_pairs
            .iter()
            .filter(|p| Self::eliminates_launch(p.strategy))
            .map(|p| p.producer)
            .collect();
        graph
            .nodes()
            .iter()
            .filter(|n| Some(n.id) != output && !eliminated.contains(&n.id))
            .map(Self::output_bytes)
            .sum()
    }

    /// Fraction of intermediate memory saved by `plan`: `0.0` means no saving
    /// (also returned when there is no intermediate memory at all), and `1.0`
    /// means every intermediate buffer was eliminated.
    pub fn memory_savings(graph: &GemmGraph, plan: &FusionPlan) -> f64 {
        let original = Self::estimate_intermediate_memory(graph);
        if original == 0 {
            return 0.0;
        }
        let fused = Self::estimate_fused_memory(graph, plan);
        1.0 - (fused as f64 / original as f64)
    }

    /// Lists fusion opportunities: each GEMM whose output is read by exactly
    /// one `Add`, `Scale` or `Gemm` node, with an estimated speed-up of at
    /// least `min_speedup_estimate`.
    pub fn analyze_opportunities(
        graph: &GemmGraph,
        min_speedup_estimate: f32,
    ) -> Vec<FusionOpportunity> {
        let mut opps = Vec::new();
        for node in graph.nodes().iter().filter(|n| n.op == GemmOp::Gemm) {
            let consumer_id = match Self::sole_consumer(graph, node.id) {
                Some(id) => id,
                None => continue,
            };
            let fusion_type = match graph.nodes()[consumer_id].op {
                GemmOp::Add => FusionType::GemmBiasActivation,
                GemmOp::Scale { .. } => FusionType::GemmScale,
                GemmOp::Gemm => FusionType::ConsecutiveGemm,
                // Listed explicitly (no catch-all) so that adding a `GemmOp`
                // variant forces this classification to be revisited.
                GemmOp::Transpose => continue,
            };
            let speedup = Self::estimate_speedup(
                &fusion_type,
                node.m as usize,
                node.n as usize,
                node.k as usize,
            );
            if speedup >= min_speedup_estimate {
                opps.push(FusionOpportunity {
                    node_ids: vec![node.id, consumer_id],
                    estimated_speedup: speedup,
                    fusion_type,
                });
            }
        }
        opps
    }

    /// Heuristic speed-up multiplier for a fusion of the given type on an
    /// `m x n x k` GEMM. `GemmBiasActivation` yields 1.8 when `m * n > 4 * k`
    /// and 1.1 otherwise; the other types use fixed constants.
    #[must_use]
    pub fn estimate_speedup(fusion_type: &FusionType, m: usize, n: usize, k: usize) -> f32 {
        match fusion_type {
            FusionType::GemmBiasActivation => {
                // Multiply in f32 so very large `m * n` cannot overflow `usize`.
                let mn = m as f32 * n as f32;
                let k_f = k as f32;
                if mn > k_f * 4.0 {
                    1.8
                } else {
                    1.1
                }
            }
            FusionType::GemmLayerNorm => 1.3,
            FusionType::ConsecutiveGemm => 1.05,
            FusionType::GemmScale => 1.4,
        }
    }

    /// Turns `opp` into a [`FusedKernelPlan`].
    ///
    /// A two-node opportunity `[producer, consumer]` becomes one
    /// `FusedGemmEpilogue`, followed by every other node as `Standalone`.
    /// Any other shape fuses nothing: all nodes are `Standalone`, and both
    /// `memory_saved` and `compute_overhead` are zero.
    ///
    /// # Errors
    ///
    /// [`BlasError::InvalidArgument`] if `opp` references a node outside the
    /// graph, or names the same node as both producer and consumer.
    pub fn apply(graph: &GemmGraph, opp: &FusionOpportunity) -> BlasResult<FusedKernelPlan> {
        if let Some(&nid) = opp.node_ids.iter().find(|&&nid| nid >= graph.len()) {
            return Err(BlasError::InvalidArgument(format!(
                "fusion opportunity references non-existent node {nid}"
            )));
        }
        let nodes = graph.nodes();
        let mut operations = Vec::with_capacity(nodes.len());
        let (memory_saved, compute_overhead) = match opp.node_ids.as_slice() {
            &[producer_id, consumer_id] => {
                if producer_id == consumer_id {
                    return Err(BlasError::InvalidArgument(format!(
                        "fusion opportunity fuses node {producer_id} with itself"
                    )));
                }
                let producer = &nodes[producer_id];
                operations.push(FusedOp::FusedGemmEpilogue {
                    gemm: producer.clone(),
                    epilogue: nodes[consumer_id].clone(),
                });
                operations.extend(
                    nodes
                        .iter()
                        .filter(|n| n.id != producer_id && n.id != consumer_id)
                        .map(|n| FusedOp::Standalone { node_id: n.id }),
                );
                (
                    Self::output_bytes(producer),
                    Self::fusion_overhead(&opp.fusion_type),
                )
            }
            // Nothing is fused, so nothing is saved and no overhead is incurred.
            // Previously any opportunity with 3+ ids still reported savings.
            _ => {
                operations.extend(nodes.iter().map(|n| FusedOp::Standalone { node_id: n.id }));
                (0, 0.0)
            }
        };
        Ok(FusedKernelPlan {
            operations,
            memory_saved,
            compute_overhead,
        })
    }

    /// Fixed heuristic compute overhead of a fused kernel per fusion type.
    fn fusion_overhead(fusion_type: &FusionType) -> f32 {
        match fusion_type {
            FusionType::GemmBiasActivation => 0.02,
            FusionType::GemmLayerNorm => 0.05,
            FusionType::ConsecutiveGemm => 0.01,
            FusionType::GemmScale => 0.01,
        }
    }

    /// Id of the only node reading `node_id`'s output, or `None` when it has
    /// no consumers or more than one.
    fn sole_consumer(graph: &GemmGraph, node_id: usize) -> Option<usize> {
        match graph.consumers_of(node_id).as_slice() {
            &[only] => Some(only),
            _ => None,
        }
    }

    /// Whether `strategy` is counted as removing a GEMM call and the
    /// producer's intermediate buffer. Stream pipelining only overlaps
    /// execution, so it removes neither.
    fn eliminates_launch(strategy: FusionStrategy) -> bool {
        matches!(
            strategy,
            FusionStrategy::EpilogueFusion | FusionStrategy::SharedMemoryReuse
        )
    }

    /// Estimated size in bytes of a node's `m x n` output.
    fn output_bytes(node: &GemmNode) -> usize {
        node.m as usize * node.n as usize * Self::ELEMENT_BYTES
    }

    /// Computes each node's dependency depth: 0 for nodes without node
    /// inputs, otherwise one more than the deepest node it reads from.
    ///
    /// Nodes are processed in topological order, so inputs that reference
    /// later-inserted nodes are handled correctly.
    ///
    /// # Errors
    ///
    /// [`BlasError::InvalidArgument`] if an input references a node id outside
    /// the graph, or if the inputs form a cycle.
    fn compute_levels(graph: &GemmGraph) -> BlasResult<Vec<u32>> {
        let n = graph.len();
        // Number of not-yet-processed inputs per node, and who reads each node.
        let mut pending_inputs = vec![0usize; n];
        let mut readers: Vec<Vec<usize>> = vec![Vec::new(); n];
        for node in graph.nodes() {
            for input in &node.inputs {
                if let NodeInput::NodeOutput { node_id } = input {
                    if *node_id >= n {
                        return Err(BlasError::InvalidArgument(format!(
                            "node {} references non-existent node {node_id}",
                            node.id
                        )));
                    }
                    readers[*node_id].push(node.id);
                    pending_inputs[node.id] += 1;
                }
            }
        }

        let mut levels = vec![0u32; n];
        let mut ready: Vec<usize> = (0..n).filter(|&i| pending_inputs[i] == 0).collect();
        let mut processed = 0usize;
        while let Some(src) = ready.pop() {
            processed += 1;
            // `levels[src]` is final here: all of `src`'s inputs were processed.
            for &dst in &readers[src] {
                levels[dst] = levels[dst].max(levels[src] + 1);
                pending_inputs[dst] -= 1;
                if pending_inputs[dst] == 0 {
                    ready.push(dst);
                }
            }
        }
        if processed != n {
            return Err(BlasError::InvalidArgument(
                "graph contains a dependency cycle".into(),
            ));
        }
        Ok(levels)
    }

    /// Groups nodes into one [`FusionStage`] per dependency level in
    /// `levels`. A stage can overlap with the next when it holds the producer
    /// of a `StreamPipelining` pair.
    fn build_stages(
        graph: &GemmGraph,
        fusible: &[FusiblePair],
        levels: &[u32],
    ) -> Vec<FusionStage> {
        let max_level = match levels.iter().copied().max() {
            Some(level) => level,
            None => return Vec::new(),
        };
        let stream_pipeline_stages: std::collections::HashSet<u32> = fusible
            .iter()
            .filter(|p| p.strategy == FusionStrategy::StreamPipelining)
            .filter_map(|p| levels.get(p.producer).copied())
            .collect();
        let output = graph.output_node();
        (0..=max_level)
            .map(|lv| {
                let nodes_in_level: Vec<usize> =
                    (0..levels.len()).filter(|&i| levels[i] == lv).collect();
                let intermediate_bytes = nodes_in_level
                    .iter()
                    .filter(|&&nid| Some(nid) != output)
                    .map(|&nid| Self::output_bytes(&graph.nodes()[nid]))
                    .sum();
                FusionStage {
                    stage_index: lv,
                    nodes: nodes_in_level,
                    can_overlap_with_next: stream_pipeline_stages.contains(&lv),
                    intermediate_bytes,
                }
            })
            .collect()
    }
}
/// Finds the parenthesisation of a matrix chain that minimises the number of
/// scalar multiplications, using the standard O(n³) dynamic program.
///
/// `dimensions[i]` is the `(rows, cols)` shape of matrix `i`. The result lists
/// inclusive index intervals `(i, j)` in the order their products should be
/// formed: each interval appears after the two sub-intervals it combines, and
/// the last entry is always `(0, n - 1)`.
///
/// Costs are accumulated with saturating arithmetic. Chains whose true cost
/// exceeds `u64::MAX` still yield a valid (not necessarily optimal) order
/// instead of overflowing.
///
/// # Errors
///
/// * [`BlasError::InvalidArgument`] if fewer than two matrices are given.
/// * [`BlasError::DimensionMismatch`] if the columns of one matrix differ from
///   the rows of the next.
pub fn optimal_chain_order(dimensions: &[(u32, u32)]) -> BlasResult<Vec<(usize, usize)>> {
    let n = dimensions.len();
    if n < 2 {
        return Err(BlasError::InvalidArgument(
            "optimal_chain_order requires at least 2 matrices".into(),
        ));
    }
    for i in 0..n - 1 {
        if dimensions[i].1 != dimensions[i + 1].0 {
            return Err(BlasError::DimensionMismatch(format!(
                "matrix {} cols ({}) != matrix {} rows ({})",
                i,
                dimensions[i].1,
                i + 1,
                dimensions[i + 1].0,
            )));
        }
    }
    // Matrix i has shape p[i] x p[i + 1].
    let p: Vec<u64> = std::iter::once(dimensions[0].0)
        .chain(dimensions.iter().map(|d| d.1))
        .map(u64::from)
        .collect();
    let mut cost = vec![vec![0u64; n]; n];
    let mut split = vec![vec![0usize; n]; n];
    for chain_len in 1..n {
        for i in 0..n - chain_len {
            let j = i + chain_len;
            // Seed with split `i` so `split[i][j]` always lies in `i..j`, even if
            // every candidate saturates. A split outside that range would make
            // `reconstruct_order` recurse forever.
            let mut best_cost = u64::MAX;
            let mut best_split = i;
            for k in i..j {
                // Three u32 dimensions can overflow u64, so saturate instead of
                // panicking (debug) or wrapping to a bogus minimum (release).
                let q = cost[i][k]
                    .saturating_add(cost[k + 1][j])
                    .saturating_add(p[i].saturating_mul(p[k + 1]).saturating_mul(p[j + 1]));
                if q < best_cost {
                    best_cost = q;
                    best_split = k;
                }
            }
            cost[i][j] = best_cost;
            split[i][j] = best_split;
        }
    }
    let mut order = Vec::with_capacity(n - 1);
    reconstruct_order(&split, 0, n - 1, &mut order);
    Ok(order)
}
/// Appends, in post-order, the intervals of the bracketing recorded in
/// `split` for the sub-chain `i..=j`. Both halves of an interval are emitted
/// before the interval itself; single-matrix intervals (`i == j`) emit nothing.
fn reconstruct_order(split: &[Vec<usize>], i: usize, j: usize, order: &mut Vec<(usize, usize)>) {
    /// Pending work for the explicit-stack traversal.
    enum Step {
        /// Expand an interval into its two halves.
        Visit(usize, usize),
        /// Record an interval whose halves have already been recorded.
        Emit(usize, usize),
    }

    let mut stack = vec![Step::Visit(i, j)];
    while let Some(step) = stack.pop() {
        match step {
            Step::Emit(lo, hi) => order.push((lo, hi)),
            Step::Visit(lo, hi) if lo == hi => {}
            Step::Visit(lo, hi) => {
                let mid = split[lo][hi];
                // Pushed in reverse so the left half is fully processed first.
                stack.push(Step::Emit(lo, hi));
                stack.push(Step::Visit(mid + 1, hi));
                stack.push(Step::Visit(lo, mid));
            }
        }
    }
}
/// Estimates the floating-point operations needed to evaluate a matrix chain
/// in a given multiplication order. Each scalar multiply-add counts as two
/// flops.
///
/// `order` uses the format returned by [`optimal_chain_order`]: inclusive
/// index intervals `(i, j)`, each listed after the two adjacent sub-products
/// it combines. Entries with `i == j` denote a single matrix; they cost
/// nothing and are skipped. Any valid bracketing is costed as written, not
/// only the optimal one.
///
/// # Errors
///
/// * [`BlasError::InvalidArgument`] in any of these cases:
///   - fewer than two matrices are given;
///   - an entry is not an interval inside the chain;
///   - an entry does not join two adjacent products (or single matrices)
///     still available at that point in `order`.
/// * [`BlasError::DimensionMismatch`] if adjacent matrices are incompatible.
pub fn estimate_chain_flops(
    dimensions: &[(u32, u32)],
    order: &[(usize, usize)],
) -> BlasResult<f64> {
    let n = dimensions.len();
    if n < 2 {
        return Err(BlasError::InvalidArgument(
            "need at least 2 matrices".into(),
        ));
    }
    if let Some(i) = (0..n - 1).find(|&i| dimensions[i].1 != dimensions[i + 1].0) {
        return Err(BlasError::DimensionMismatch(format!(
            "matrix {} cols ({}) != matrix {} rows ({})",
            i,
            dimensions[i].1,
            i + 1,
            dimensions[i + 1].0,
        )));
    }
    // The live (not yet consumed) products partition the chain into intervals.
    // For each live interval (a, b): span_end[a] == b and span_start[b] == a.
    // The inner dimension of each product comes from its actual split point.
    // The previous version re-derived the *optimal* split instead, so
    // non-optimal orders were mis-costed.
    let mut span_end: Vec<usize> = (0..n).collect();
    let mut span_start: Vec<usize> = (0..n).collect();
    let mut total = 0.0f64;
    for &(i, j) in order {
        if i == j {
            continue;
        }
        if i > j || j >= n {
            return Err(BlasError::InvalidArgument(format!(
                "order entry ({i}, {j}) is not an interval of a {n}-matrix chain"
            )));
        }
        // Left operand: the live product starting at `i`, ending at `k`.
        // Right operand: must be the live product spanning `k + 1 ..= j`.
        let k = span_end[i];
        let left_live = span_start[k] == i;
        let right_live = k < j && span_start[j] == k + 1 && span_end[k + 1] == j;
        if !(left_live && right_live) {
            return Err(BlasError::InvalidArgument(format!(
                "order entry ({i}, {j}) does not join two adjacent products formed earlier"
            )));
        }
        let rows = f64::from(dimensions[i].0);
        let cols = f64::from(dimensions[j].1);
        let inner = f64::from(dimensions[k].1);
        total += 2.0 * rows * cols * inner;
        span_end[i] = j;
        span_start[j] = i;
    }
    Ok(total)
}
/// Returns the minimum number of scalar multiplications needed to evaluate the
/// matrix chain, i.e. the cost of the order [`optimal_chain_order`] picks.
///
/// Unlike [`estimate_chain_flops`], this counts one unit per multiply-add
/// rather than two flops. The total saturates at `u64::MAX` instead of
/// overflowing for extremely large dimensions.
///
/// # Errors
///
/// * [`BlasError::InvalidArgument`] if fewer than two matrices are given.
/// * [`BlasError::DimensionMismatch`] if adjacent matrices are incompatible.
pub fn minimum_chain_flops(dimensions: &[(u32, u32)]) -> BlasResult<u64> {
    let n = dimensions.len();
    if n < 2 {
        return Err(BlasError::InvalidArgument(
            "need at least 2 matrices".into(),
        ));
    }
    for i in 0..n - 1 {
        if dimensions[i].1 != dimensions[i + 1].0 {
            return Err(BlasError::DimensionMismatch(format!(
                "matrix {} cols ({}) != matrix {} rows ({})",
                i,
                dimensions[i].1,
                i + 1,
                dimensions[i + 1].0,
            )));
        }
    }
    // Matrix i has shape p[i] x p[i + 1].
    let p: Vec<u64> = std::iter::once(dimensions[0].0)
        .chain(dimensions.iter().map(|d| d.1))
        .map(u64::from)
        .collect();
    let mut cost = vec![vec![0u64; n]; n];
    for chain_len in 1..n {
        for i in 0..n - chain_len {
            let j = i + chain_len;
            // Saturate: the product of three u32 dimensions can overflow u64.
            let best = (i..j)
                .map(|k| {
                    cost[i][k]
                        .saturating_add(cost[k + 1][j])
                        .saturating_add(p[i].saturating_mul(p[k + 1]).saturating_mul(p[j + 1]))
                })
                .min()
                .unwrap_or(u64::MAX);
            cost[i][j] = best;
        }
    }
    Ok(cost[0][n - 1])
}
// Unit tests for graph construction, matrix-chain ordering and the fusion pass.
// NOTE(review): `.ok().filter(|_| true).unwrap_or_default()` below is a no-op
// filter. A failed build falls back to an empty graph, which the later
// assertions typically reject; `.expect(...)` would report failures more clearly.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_graph() {
let g = GemmGraph::new();
assert!(g.is_empty());
assert_eq!(g.len(), 0);
assert_eq!(g.output_node(), None);
}
#[test]
fn single_gemm_node() {
let mut g = GemmGraph::new();
let id = g.add_gemm(
128,
256,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
assert_eq!(id, 0);
assert_eq!(g.len(), 1);
assert_eq!(g.nodes()[0].m, 128);
assert_eq!(g.nodes()[0].n, 256);
assert_eq!(g.nodes()[0].k, 64);
assert!(matches!(g.nodes()[0].op, GemmOp::Gemm));
}
// A 10x30 by 30x5 chain becomes one GEMM node with (m, n, k) = (10, 5, 30).
#[test]
fn chain_two_matrices() {
let g = GemmGraph::chain(&[(10, 30), (30, 5)]);
assert!(g.is_ok());
let g = g.ok().filter(|_| true).unwrap_or_default();
assert_eq!(g.len(), 1);
assert_eq!(g.output_node(), Some(0));
let n = &g.nodes()[0];
assert_eq!((n.m, n.n, n.k), (10, 5, 30));
}
// Each extra matrix appends a GEMM that consumes the previous node's output.
#[test]
fn chain_three_matrices() {
let g = GemmGraph::chain(&[(10, 30), (30, 5), (5, 60)]);
assert!(g.is_ok());
let g = g.ok().filter(|_| true).unwrap_or_default();
assert_eq!(g.len(), 2);
assert_eq!(g.output_node(), Some(1));
assert_eq!(
(g.nodes()[0].m, g.nodes()[0].n, g.nodes()[0].k),
(10, 5, 30)
);
assert_eq!(
(g.nodes()[1].m, g.nodes()[1].n, g.nodes()[1].k),
(10, 60, 5)
);
}
#[test]
fn chain_dimension_mismatch() {
let r = GemmGraph::chain(&[(10, 30), (20, 5)]); assert!(r.is_err());
}
#[test]
fn chain_too_few_matrices() {
let r = GemmGraph::chain(&[(10, 30)]);
assert!(r.is_err());
}
#[test]
fn add_scale_node() {
let mut g = GemmGraph::new();
let id0 = g.add_gemm(
64,
64,
32,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let id1 = g.add_scale(id0, 2.5);
assert!(id1.is_ok());
let id1 = id1.unwrap_or(0);
assert_eq!(g.nodes()[id1].m, 64);
assert_eq!(g.nodes()[id1].n, 64);
assert!(
matches!(g.nodes()[id1].op, GemmOp::Scale { alpha } if (alpha - 2.5).abs() < 1e-12)
);
}
#[test]
fn set_output_invalid() {
let mut g = GemmGraph::new();
assert!(g.set_output(0).is_err());
}
// (M0*M1)*M2 costs 10*30*5 + 10*5*60 = 4500 scalar multiplications.
#[test]
fn optimal_order_three_matrices() {
let dims = [(10, 30), (30, 5), (5, 60)];
let order = optimal_chain_order(&dims);
assert!(order.is_ok());
let cost = minimum_chain_flops(&dims);
assert!(cost.is_ok());
assert_eq!(cost.unwrap_or(0), 4500);
}
// (M0*(M1*M2))*M3 costs 20*30*10 + 40*20*10 + 40*10*30 = 26000, while the
// left-to-right order costs 24000 + 12000 + 12000 = 48000.
#[test]
fn optimal_order_four_matrices() {
let dims = [(40, 20), (20, 30), (30, 10), (10, 30)];
let cost = minimum_chain_flops(&dims);
assert!(cost.is_ok());
let c = cost.unwrap_or(0);
assert!(c < 48000, "optimal cost {c} should be < 48000");
assert_eq!(c, 26000);
}
#[test]
fn optimal_order_dimension_mismatch() {
let r = optimal_chain_order(&[(10, 30), (20, 5)]);
assert!(r.is_err());
}
#[test]
fn optimal_order_too_few() {
let r = optimal_chain_order(&[(10, 30)]);
assert!(r.is_err());
}
// Two chained 64x64 GEMMs: node 0 is consumed only by node 1.
#[test]
fn fusible_pairs_linear_chain() {
let g = GemmGraph::chain(&[(64, 64), (64, 64), (64, 64)]);
let g = g.ok().filter(|_| true).unwrap_or_default();
let pairs = FusionPass::find_fusible_pairs(&g);
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0].producer, 0);
assert_eq!(pairs[0].consumer, 1);
}
// Node 0 feeds two GEMMs, so it has no single consumer to fuse with.
#[test]
fn fusible_pairs_no_fusion_fanout() {
let mut g = GemmGraph::new();
let n0 = g.add_gemm(
64,
64,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let _n1 = g.add_gemm(
64,
32,
64,
NodeInput::NodeOutput { node_id: n0 },
NodeInput::External { name: "C".into() },
);
let _n2 = g.add_gemm(
64,
16,
64,
NodeInput::NodeOutput { node_id: n0 },
NodeInput::External { name: "D".into() },
);
let pairs = FusionPass::find_fusible_pairs(&g);
assert!(pairs.is_empty());
}
// Node 1 is the output, so only node 0's 64x64 result (4 bytes/element) counts.
#[test]
fn intermediate_memory_chain() {
let g = GemmGraph::chain(&[(64, 64), (64, 64), (64, 64)]);
let g = g.ok().filter(|_| true).unwrap_or_default();
let mem = FusionPass::estimate_intermediate_memory(&g);
assert_eq!(mem, 64 * 64 * 4);
}
// Node 0 -> node 1 qualifies for epilogue fusion (producer.n == consumer.k),
// which removes the only intermediate buffer.
#[test]
fn memory_savings_with_fusion() {
let g = GemmGraph::chain(&[(64, 64), (64, 64), (64, 64)]);
let g = g.ok().filter(|_| true).unwrap_or_default();
let plan = FusionPass::analyze(&g);
assert!(plan.is_ok());
let plan = plan.unwrap_or_else(|_| FusionPlan {
stages: Vec::new(),
fused_pairs: Vec::new(),
total_gemm_calls: 0,
original_gemm_calls: 0,
});
let savings = FusionPass::memory_savings(&g, &plan);
assert!(savings > 0.0, "expected positive savings, got {savings}");
}
// Each GEMM in a linear chain lands in its own dependency stage.
#[test]
fn stages_linear_chain() {
let g = GemmGraph::chain(&[(128, 64), (64, 32), (32, 256)]);
let g = g.ok().filter(|_| true).unwrap_or_default();
let plan = FusionPass::analyze(&g);
assert!(plan.is_ok());
let plan = plan.unwrap_or_else(|_| FusionPlan {
stages: Vec::new(),
fused_pairs: Vec::new(),
total_gemm_calls: 0,
original_gemm_calls: 0,
});
assert_eq!(plan.stages.len(), 2);
assert_eq!(plan.stages[0].nodes, vec![0]);
assert_eq!(plan.stages[1].nodes, vec![1]);
assert_eq!(plan.original_gemm_calls, 2);
}
#[test]
fn display_fusion_plan() {
let g = GemmGraph::chain(&[(64, 64), (64, 64), (64, 64)]);
let g = g.ok().filter(|_| true).unwrap_or_default();
let plan = FusionPass::analyze(&g);
assert!(plan.is_ok());
let plan = plan.unwrap_or_else(|_| FusionPlan {
stages: Vec::new(),
fused_pairs: Vec::new(),
total_gemm_calls: 0,
original_gemm_calls: 0,
});
let display = format!("{plan}");
assert!(display.contains("FusionPlan"));
assert!(display.contains("stage"));
assert!(display.contains("GEMM calls"));
}
#[test]
fn analyze_empty_graph_errors() {
let g = GemmGraph::new();
assert!(FusionPass::analyze(&g).is_err());
}
// estimate_chain_flops counts two flops per multiply-add: 2*10*5*30 = 3000.
#[test]
fn chain_flops_two_matrices() {
let dims = [(10, 30), (30, 5)];
let order = optimal_chain_order(&dims);
assert!(order.is_ok());
let order = order.unwrap_or_default();
let flops = estimate_chain_flops(&dims, &order);
assert!(flops.is_ok());
let f = flops.unwrap_or(0.0);
assert!((f - 3000.0).abs() < 1e-6, "expected 3000, got {f}");
}
// A GEMM whose sole consumer is an Add node is reported as GemmBiasActivation.
#[test]
fn test_fusion_pass_identifies_gemm_bias_relu() {
let mut g = GemmGraph::new();
let gemm_id = g.add_gemm(
128,
256,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let bias_gemm_id = g.add_gemm(
128,
256,
1,
NodeInput::External {
name: "bias".into(),
},
NodeInput::External {
name: "ones".into(),
},
);
let add_id = g.add_add(gemm_id, bias_gemm_id).unwrap_or(gemm_id);
g.set_output(add_id).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(
opps.iter()
.any(|o| o.fusion_type == FusionType::GemmBiasActivation),
"expected GemmBiasActivation opportunity, got: {opps:?}"
);
}
#[test]
fn test_fusion_pass_identifies_gemm_scale() {
let mut g = GemmGraph::new();
let gemm_id = g.add_gemm(
128,
256,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let scale_id = g.add_scale(gemm_id, 0.5).unwrap_or(gemm_id);
g.set_output(scale_id).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(!opps.is_empty(), "expected at least one opportunity");
assert!(
opps.iter().any(|o| o.fusion_type == FusionType::GemmScale),
"expected GemmScale opportunity"
);
}
#[test]
fn test_fusion_pass_identifies_consecutive_gemm() {
let g = GemmGraph::chain(&[(64, 64), (64, 64), (64, 64)]).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(!opps.is_empty(), "expected ConsecutiveGemm opportunity");
assert!(
opps.iter()
.any(|o| o.fusion_type == FusionType::ConsecutiveGemm),
"expected at least one ConsecutiveGemm opportunity"
);
}
#[test]
fn test_fusion_pass_no_opportunities_for_standalone_gemm() {
let mut g = GemmGraph::new();
let id = g.add_gemm(
64,
64,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
g.set_output(id).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(
opps.is_empty(),
"expected no opportunities for standalone GEMM"
);
}
// 1024*1024 > 4*64 selects the 1.8 estimate; 16*16 <= 4*1024 selects 1.1.
#[test]
fn test_fusion_speedup_large_mn_is_higher_than_small() {
let large = FusionPass::estimate_speedup(&FusionType::GemmBiasActivation, 1024, 1024, 64);
let small = FusionPass::estimate_speedup(&FusionType::GemmBiasActivation, 16, 16, 1024);
assert!(
large > small,
"large-MN speedup {large} should exceed small-MN speedup {small}"
);
}
// Fusing the scale saves the producer's 128x256 output at 4 bytes per element.
#[test]
fn test_fused_plan_memory_savings_nonzero() {
let mut g = GemmGraph::new();
let gemm_id = g.add_gemm(
128,
256,
64,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let scale_id = g.add_scale(gemm_id, 2.0).unwrap_or(gemm_id);
g.set_output(scale_id).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(!opps.is_empty());
let plan = FusionPass::apply(&g, &opps[0]);
assert!(plan.is_ok(), "apply should succeed: {:?}", plan.err());
let plan = plan.unwrap_or_else(|_| FusedKernelPlan {
operations: vec![],
memory_saved: 0,
compute_overhead: 0.0,
});
assert_eq!(plan.memory_saved, 128 * 256 * 4);
}
#[test]
fn test_fusion_type_display() {
assert_eq!(
FusionType::GemmBiasActivation.to_string(),
"GemmBiasActivation"
);
assert_eq!(FusionType::GemmLayerNorm.to_string(), "GemmLayerNorm");
assert_eq!(FusionType::ConsecutiveGemm.to_string(), "ConsecutiveGemm");
assert_eq!(FusionType::GemmScale.to_string(), "GemmScale");
}
#[test]
fn test_apply_invalid_node_id_errors() {
let g = GemmGraph::new();
let opp = FusionOpportunity {
node_ids: vec![99],
estimated_speedup: 1.5,
fusion_type: FusionType::GemmScale,
};
assert!(FusionPass::apply(&g, &opp).is_err());
}
#[test]
fn test_apply_produces_fused_gemm_epilogue() {
let mut g = GemmGraph::new();
let gemm_id = g.add_gemm(
64,
64,
32,
NodeInput::External { name: "A".into() },
NodeInput::External { name: "B".into() },
);
let scale_id = g.add_scale(gemm_id, 1.5).unwrap_or(gemm_id);
g.set_output(scale_id).unwrap_or_default();
let opps = FusionPass::analyze_opportunities(&g, 0.0);
assert!(!opps.is_empty());
let plan = FusionPass::apply(&g, &opps[0]).unwrap_or_else(|_| FusedKernelPlan {
operations: vec![],
memory_saved: 0,
compute_overhead: 0.0,
});
assert!(
matches!(
plan.operations.first(),
Some(FusedOp::FusedGemmEpilogue { .. })
),
"expected FusedGemmEpilogue as first operation"
);
}
#[test]
fn test_gemm_layernorm_speedup() {
let speedup = FusionPass::estimate_speedup(&FusionType::GemmLayerNorm, 512, 512, 256);
assert!((speedup - 1.3).abs() < 1e-6, "expected 1.3, got {speedup}");
}
}