use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
/// Shape of a fusible producer → consumer pair of element-wise operations.
///
/// Variant names read as `<producer><consumer>`: for example
/// `BinaryUnary` is a binary element-wise op whose result feeds a unary
/// element-wise op.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionPattern {
    /// Unary op feeding a unary op.
    UnaryUnary,
    /// Binary op feeding a binary op.
    BinaryBinary,
    /// Unary op feeding a binary op.
    UnaryBinary,
    /// Binary op feeding a unary op.
    BinaryUnary,
}
/// A pair of graph nodes that could be executed as one fused kernel.
#[derive(Clone, Debug)]
pub struct FusionOpportunity {
    /// Positions of the involved nodes in the graph's node list, producer first.
    pub node_indices: Vec<usize>,
    /// Which unary/binary combination the pair forms.
    pub pattern: FusionPattern,
    /// Human-readable summary of the fused ops, e.g. `"Fuse add → relu"`.
    pub description: String,
}
/// Scans `graph` for adjacent node pairs that can be fused.
///
/// Only neighbours in `graph.nodes` order are examined: node `i` and node
/// `i + 1` form a candidate when node `i + 1` lists the *first* output of
/// node `i` among its inputs. Each candidate is then classified by
/// `check_fusion_pair`, and pairs it rejects are dropped. Graphs with fewer
/// than two nodes yield an empty result.
pub fn analyze_fusion_opportunities(graph: &EinsumGraph) -> Vec<FusionOpportunity> {
    graph
        .nodes
        .windows(2)
        .enumerate()
        .filter_map(|(producer_idx, pair)| {
            let (producer, consumer) = (&pair[0], &pair[1]);
            // A producer without outputs cannot feed anything.
            let produced = producer.outputs.first()?;
            if !consumer.inputs.contains(produced) {
                return None;
            }
            check_fusion_pair(producer, consumer, producer_idx, producer_idx + 1)
        })
        .collect()
}
/// Classifies a producer → consumer node pair as a fusion opportunity.
///
/// `node1` is the producer and `node2` the consumer. `idx1` and `idx2` are
/// their positions in the graph and are copied verbatim into the result.
/// Any pairing of element-wise unary and element-wise binary ops is
/// reported, including binary → binary, which maps to
/// `FusionPattern::BinaryBinary`. Every other op combination yields `None`.
fn check_fusion_pair(
    node1: &EinsumNode,
    node2: &EinsumNode,
    idx1: usize,
    idx2: usize,
) -> Option<FusionOpportunity> {
    // The description is formatted inside each arm because the unary and
    // binary `op` payloads are not guaranteed to share a type.
    let (pattern, description) = match (&node1.op, &node2.op) {
        (OpType::ElemUnary { op: op1 }, OpType::ElemUnary { op: op2 }) => (
            FusionPattern::UnaryUnary,
            format!("Fuse {} → {}", op1, op2),
        ),
        // Chains such as `(a + b) * c` are element-wise end to end, so they
        // fuse just like the mixed unary/binary combinations.
        (OpType::ElemBinary { op: op1 }, OpType::ElemBinary { op: op2 }) => (
            FusionPattern::BinaryBinary,
            format!("Fuse {} → {}", op1, op2),
        ),
        (OpType::ElemBinary { op: op1 }, OpType::ElemUnary { op: op2 }) => (
            FusionPattern::BinaryUnary,
            format!("Fuse {} → {}", op1, op2),
        ),
        (OpType::ElemUnary { op: op1 }, OpType::ElemBinary { op: op2 }) => (
            FusionPattern::UnaryBinary,
            format!("Fuse {} → {}", op1, op2),
        ),
        _ => return None,
    };
    Some(FusionOpportunity {
        node_indices: vec![idx1, idx2],
        pattern,
        description,
    })
}
/// Aggregate summary of a set of fusion opportunities.
#[derive(Clone, Debug)]
pub struct FusionStats {
    /// Number of opportunities summarized.
    pub total_opportunities: usize,
    /// Opportunity count per pattern, keyed by the pattern's `Debug` name
    /// (e.g. `"UnaryUnary"`).
    pub by_pattern: std::collections::HashMap<String, usize>,
    /// Heuristic speedup multiplier. A value of `1.0` corresponds to no
    /// opportunities.
    pub estimated_speedup: f64,
}
impl FusionStats {
pub fn from_opportunities(opportunities: &[FusionOpportunity]) -> Self {
let mut by_pattern = std::collections::HashMap::new();
for opp in opportunities {
let pattern_name = format!("{:?}", opp.pattern);
*by_pattern.entry(pattern_name).or_insert(0) += 1;
}
let estimated_speedup = 1.0 + (opportunities.len() as f64 * 0.2);
FusionStats {
total_opportunities: opportunities.len(),
by_pattern,
estimated_speedup,
}
}
}
impl std::fmt::Display for FusionStats {
    /// Renders a multi-line fusion report.
    ///
    /// The output contains a header, the total count, one line per pattern
    /// and the estimated speedup, printed with two decimals.
    /// Pattern lines are sorted by pattern name so the report is identical
    /// across runs. Iterating the `HashMap` directly would yield a
    /// per-process randomized order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Fusion Analysis:")?;
        writeln!(f, " Total opportunities: {}", self.total_opportunities)?;
        writeln!(f, " By pattern:")?;
        let mut entries: Vec<(&String, &usize)> = self.by_pattern.iter().collect();
        // Keys are unique, so an unstable sort is fully deterministic here.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (pattern, count) in entries {
            writeln!(f, " {}: {}", pattern, count)?;
        }
        write!(f, " Estimated speedup: {:.2}x", self.estimated_speedup)
    }
}
/// Runs fusion analysis on `graph` and summarizes the result.
///
/// Returns the opportunities found by `analyze_fusion_opportunities`
/// together with the `FusionStats` computed from exactly those
/// opportunities.
pub fn suggest_fusions(graph: &EinsumGraph) -> (Vec<FusionOpportunity>, FusionStats) {
    let found = analyze_fusion_opportunities(graph);
    let summary = FusionStats::from_opportunities(&found);
    (found, summary)
}
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    /// Checks invariants that every reported opportunity must satisfy,
    /// independent of which ops the compiler happened to emit.
    ///
    /// Each opportunity must name exactly two adjacent, in-bounds nodes and
    /// carry a `Fuse …` description.
    fn assert_well_formed(graph: &EinsumGraph, opportunities: &[FusionOpportunity]) {
        for opp in opportunities {
            assert_eq!(opp.node_indices.len(), 2, "{:?}", opp);
            assert_eq!(opp.node_indices[1], opp.node_indices[0] + 1, "{:?}", opp);
            assert!(opp.node_indices[1] < graph.nodes.len(), "{:?}", opp);
            assert!(opp.description.starts_with("Fuse "), "{:?}", opp);
        }
    }

    #[test]
    fn test_analyze_unary_unary_fusion() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::mul(TLExpr::add(x.clone(), y.clone()), x);
        let graph = compile_to_einsum(&expr).expect("unwrap");
        let opportunities = analyze_fusion_opportunities(&graph);
        assert_well_formed(&graph, &opportunities);
    }

    #[test]
    fn test_analyze_binary_unary_fusion() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let sum = TLExpr::add(x, y);
        let zero = TLExpr::constant(0.0);
        let expr = TLExpr::gt(sum, zero);
        let graph = compile_to_einsum(&expr).expect("unwrap");
        let opportunities = analyze_fusion_opportunities(&graph);
        assert_well_formed(&graph, &opportunities);
    }

    #[test]
    fn test_fusion_stats() {
        let opportunities = vec![
            FusionOpportunity {
                node_indices: vec![0, 1],
                pattern: FusionPattern::UnaryUnary,
                description: "relu → sigmoid".to_string(),
            },
            FusionOpportunity {
                node_indices: vec![2, 3],
                pattern: FusionPattern::BinaryUnary,
                description: "add → relu".to_string(),
            },
            FusionOpportunity {
                node_indices: vec![4, 5],
                pattern: FusionPattern::UnaryUnary,
                description: "sigmoid → relu".to_string(),
            },
        ];
        let stats = FusionStats::from_opportunities(&opportunities);
        assert_eq!(stats.total_opportunities, 3);
        assert_eq!(stats.by_pattern.len(), 2);
        assert_eq!(stats.by_pattern.get("UnaryUnary"), Some(&2));
        assert_eq!(stats.by_pattern.get("BinaryUnary"), Some(&1));
        assert!(stats.estimated_speedup > 1.0);
    }

    #[test]
    fn test_fusion_stats_display() {
        let stats = FusionStats::from_opportunities(&[]);
        let text = stats.to_string();
        assert!(text.starts_with("Fusion Analysis:"));
        assert!(text.contains("Total opportunities: 0"));
        assert!(text.ends_with("Estimated speedup: 1.00x"));
    }

    #[test]
    fn test_suggest_fusions() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::mul(x, y);
        let graph = compile_to_einsum(&expr).expect("unwrap");
        let (opportunities, stats) = suggest_fusions(&graph);
        assert_eq!(opportunities.len(), stats.total_opportunities);
        assert_eq!(
            stats.by_pattern.values().sum::<usize>(),
            stats.total_opportunities
        );
        assert_well_formed(&graph, &opportunities);
    }

    #[test]
    fn test_empty_graph_fusion() {
        let graph = EinsumGraph {
            tensors: vec![],
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            tensor_metadata: HashMap::new(),
        };
        let opportunities = analyze_fusion_opportunities(&graph);
        assert_eq!(opportunities.len(), 0);
    }

    #[test]
    fn test_single_node_no_fusion() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let graph = compile_to_einsum(&x).expect("unwrap");
        let opportunities = analyze_fusion_opportunities(&graph);
        assert_eq!(opportunities.len(), 0);
    }
}