use super::core::FusedOp;
/// A candidate fusion found by `analyze_fusion_opportunities`: a position in
/// the operation sequence plus the fused op that would replace the window and
/// the modeled payoff of doing so.
#[derive(Debug)]
pub struct FusionOpportunity {
    // Index of the first operation of the fusible window in the input slice.
    pub position: usize,
    // The fused operation that would replace the matched window.
    pub operation: FusedOp,
    // Fraction of total (memory + compute) time saved; lies in [0, 1).
    pub expected_benefit: f64,
    // Estimated bytes of intermediate-tensor traffic avoided (f32 elements assumed).
    pub memory_savings: usize, }
pub fn analyze_fusion_opportunities(
operations: &[&str],
tensor_sizes: &[usize],
memory_bandwidth: f64,
compute_throughput: f64,
) -> Vec<FusionOpportunity> {
let mut opportunities = Vec::new();
for (i, window) in operations.windows(2).enumerate() {
if let Some(op) = match window {
["add", "relu"] => Some(FusedOp::ReluAdd),
["mul", "add"] => Some(FusedOp::MulAdd),
["add", "mul"] => Some(FusedOp::AddMul),
["sigmoid", "mul"] => Some(FusedOp::SigmoidMul),
["tanh", "scale"] => Some(FusedOp::TanhScale),
_ => None,
} {
let benefit = calculate_fusion_benefit(
&op,
tensor_sizes.get(i).copied().unwrap_or(0),
memory_bandwidth,
compute_throughput,
);
if benefit > 0.1 {
opportunities.push(FusionOpportunity {
position: i,
operation: op.clone(),
expected_benefit: benefit,
memory_savings: estimate_memory_savings(
&op,
tensor_sizes.get(i).copied().unwrap_or(0),
),
});
}
}
}
for (i, window) in operations.windows(3).enumerate() {
if let Some(op) = match window {
["add", "relu", "mul"] => Some(FusedOp::AddReluMul),
["batch_norm", "relu", "dropout"] => Some(FusedOp::BatchNormFused),
_ => None,
} {
let benefit = calculate_fusion_benefit(
&op,
tensor_sizes.get(i).copied().unwrap_or(0),
memory_bandwidth,
compute_throughput,
);
if benefit > 0.15 {
opportunities.push(FusionOpportunity {
position: i,
operation: op.clone(),
expected_benefit: benefit,
memory_savings: estimate_memory_savings(
&op,
tensor_sizes.get(i).copied().unwrap_or(0),
),
});
}
}
}
opportunities.sort_by(|a, b| {
b.expected_benefit
.partial_cmp(&a.expected_benefit)
.unwrap_or(std::cmp::Ordering::Equal)
});
opportunities
}
/// Estimates the fraction of total (memory + compute) time that fusing `op`
/// would save for a tensor of `tensor_size` f32 elements.
///
/// The model counts memory passes over the tensor before and after fusion,
/// converts the saved passes into time via `memory_bandwidth` (bytes/sec),
/// and normalizes against that savings plus the compute time implied by
/// `compute_throughput` (bytes/sec). The result lies in [0, 1); a zero-sized
/// tensor yields 0.0.
fn calculate_fusion_benefit(
    op: &FusedOp,
    tensor_size: usize,
    memory_bandwidth: f64,
    compute_throughput: f64,
) -> f64 {
    // Tensors are modeled as f32: 4 bytes per element.
    let total_bytes = (tensor_size * 4) as f64;

    // (unfused, fused) memory passes for each supported fusion pattern.
    let (unfused_accesses, fused_accesses) = match op {
        FusedOp::ReluAdd | FusedOp::AddMul | FusedOp::MulAdd => (6.0, 3.0),
        FusedOp::SigmoidMul => (5.0, 3.0),
        FusedOp::TanhScale => (4.0, 2.0),
        FusedOp::AddReluMul => (8.0, 4.0),
        FusedOp::BatchNormFused => (12.0, 6.0),
        FusedOp::LayerNormFused => (10.0, 5.0),
    };

    let memory_time_savings =
        (unfused_accesses - fused_accesses) * total_bytes / memory_bandwidth;
    let compute_time = total_bytes / compute_throughput;
    let denominator = memory_time_savings + compute_time;

    if denominator > 0.0 {
        memory_time_savings / denominator
    } else {
        0.0
    }
}
/// Estimates the bytes of intermediate-tensor traffic eliminated by fusing
/// `op` over a tensor of `tensor_size` f32 elements: one saved intermediate
/// tensor for pair fusions, two for the triple fusion, three for the
/// normalization fusions.
fn estimate_memory_savings(op: &FusedOp, tensor_size: usize) -> usize {
    // Tensors are modeled as f32: 4 bytes per element.
    const ELEMENT_SIZE: usize = 4;
    let tensor_bytes = tensor_size * ELEMENT_SIZE;
    // Exhaustive match (no wildcard), mirroring calculate_fusion_benefit, so
    // adding a FusedOp variant forces an explicit estimate here instead of
    // silently falling into a default arm. Arms with equal savings are merged.
    match op {
        FusedOp::ReluAdd
        | FusedOp::AddMul
        | FusedOp::MulAdd
        | FusedOp::SigmoidMul
        | FusedOp::TanhScale => tensor_bytes,
        FusedOp::AddReluMul => tensor_bytes * 2,
        FusedOp::BatchNormFused | FusedOp::LayerNormFused => tensor_bytes * 3,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// For a typical memory-bound configuration the benefit model should
    /// produce a strictly positive fraction below 1.
    #[test]
    fn test_fusion_benefit_calculation() {
        let benefit = calculate_fusion_benefit(
            &FusedOp::ReluAdd,
            1000,
            100_000_000.0,   // memory bandwidth, bytes/sec
            1_000_000_000.0, // compute throughput, bytes/sec
        );
        assert!(benefit > 0.0);
        assert!(benefit < 1.0);
    }

    /// A triple fusion saves two intermediate tensors of 4-byte elements.
    #[test]
    fn test_memory_savings_estimation() {
        let savings = estimate_memory_savings(&FusedOp::AddReluMul, 1000);
        assert_eq!(savings, 1000 * 4 * 2);
    }

    /// An add -> relu -> mul chain must surface the AddReluMul triple fusion.
    #[test]
    fn test_fusion_opportunities() {
        let found = analyze_fusion_opportunities(
            &["add", "relu", "mul"],
            &[1000, 1000, 1000],
            100_000_000.0,
            1_000_000_000.0,
        );
        assert!(!found.is_empty());
        assert!(found
            .iter()
            .any(|opp| matches!(opp.operation, FusedOp::AddReluMul)));
    }
}