use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use tensorlogic_ir::{EinsumGraph, OpType};
use thiserror::Error;
/// Identifier for a node in an [`EinsumGraph`]: a thin wrapper over the
/// node's index into the graph's `nodes` vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
/// Errors that can occur while validating or applying a fusion.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum FusionError {
    /// Fusing the proposed nodes would introduce a dependency cycle.
    #[error("Fusion would create a cycle in the graph")]
    WouldCreateCycle,
    /// The two operations cannot be combined into a single kernel.
    #[error("Incompatible operations for fusion: {0:?} and {1:?}")]
    IncompatibleOps(OpType, OpType),
    /// The fused kernel would exceed a resource budget; the message
    /// describes which limit was hit.
    #[error("Fusion exceeds resource limits: {0}")]
    ResourceLimitExceeded(String),
    /// The candidate does not match any supported fusion pattern.
    #[error("Invalid fusion pattern")]
    InvalidPattern,
}
/// Fusion patterns the optimizer can recognize and tag candidates with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FusionPattern {
    /// Matrix multiply followed by a bias addition.
    MatMulBias,
    /// Matrix multiply followed by an activation function.
    MatMulActivation,
    /// Bias addition followed by an activation function.
    BiasActivation,
    /// Batch normalization followed by ReLU.
    BatchNormReLU,
    /// Convolution, batch normalization, and ReLU in sequence.
    ConvBNReLU,
    /// A chain of elementwise operations fused into one kernel.
    ElementwiseChain,
    /// A reduction followed by an elementwise operation.
    ReduceElementwise,
    /// Independent reductions executed together (horizontal fusion).
    ParallelReductions,
    /// A broadcast feeding an elementwise operation.
    BroadcastElementwise,
}
/// Overall aggressiveness of the fusion search; used by
/// `estimate_pattern_benefit` to scale raw cost-model scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Only fuse clear wins (benefit scaled down).
    Conservative,
    /// Fuse as much as possible (benefit scaled up).
    Aggressive,
    /// Use the cost-model benefit unmodified.
    Balanced,
    /// Favor fusions that reduce memory traffic (benefit scaled up the most).
    MemoryAware,
}
/// A proposed group of nodes to fuse, together with its estimated payoff.
#[derive(Debug, Clone, PartialEq)]
pub struct FusionCandidate {
    /// Nodes participating in the fusion, in discovery order
    /// (producer before consumer for vertical/pattern fusions).
    pub nodes: Vec<NodeId>,
    /// The pattern this candidate was matched against.
    pub pattern: FusionPattern,
    /// Relative cost saving estimated by the cost model (higher is better).
    pub benefit_score: f64,
    /// Estimated intermediate storage eliminated — presumably bytes
    /// (element count × 4); TODO confirm units once real shapes are used.
    pub memory_savings: usize,
    /// Estimated compute saving; the analysis passes currently always
    /// record 0.0 here.
    pub compute_savings: f64,
}
/// Tunable settings controlling which fusion opportunities are searched for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Overall aggressiveness preset.
    pub strategy: FusionStrategy,
    /// Maximum number of nodes in a single fused group.
    /// NOTE(review): not yet enforced by the visible analysis passes.
    pub max_fusion_size: usize,
    /// Enable pattern-based fusions (e.g. einsum + elementwise).
    pub enable_patterns: bool,
    /// Enable vertical (producer/consumer) fusion.
    pub enable_vertical: bool,
    /// Enable horizontal (independent sibling) fusion.
    pub enable_horizontal: bool,
    /// Enable loop fusion.
    /// NOTE(review): not referenced by the visible analysis passes yet.
    pub enable_loop_fusion: bool,
    /// Optional memory-bandwidth threshold for the memory-aware strategy
    /// (presumably bytes/sec — default preset uses 100e9); `None` disables it.
    pub memory_bandwidth_threshold: Option<f64>,
    /// Candidates scoring below this benefit are discarded.
    pub min_benefit_score: f64,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
strategy: FusionStrategy::Balanced,
max_fusion_size: 8,
enable_patterns: true,
enable_vertical: true,
enable_horizontal: true,
enable_loop_fusion: true,
memory_bandwidth_threshold: None,
min_benefit_score: 0.1,
}
}
}
impl FusionConfig {
pub fn aggressive() -> Self {
Self {
strategy: FusionStrategy::Aggressive,
max_fusion_size: 16,
min_benefit_score: 0.0,
..Default::default()
}
}
pub fn conservative() -> Self {
Self {
strategy: FusionStrategy::Conservative,
max_fusion_size: 4,
enable_horizontal: false,
enable_loop_fusion: false,
min_benefit_score: 0.3,
..Default::default()
}
}
pub fn memory_aware() -> Self {
Self {
strategy: FusionStrategy::MemoryAware,
memory_bandwidth_threshold: Some(100e9), ..Default::default()
}
}
}
/// Simple analytical cost model used to score fusion candidates.
/// All costs are in arbitrary, mutually consistent units.
#[derive(Debug, Clone)]
pub struct FusionCostModel {
    /// Cost per element of one memory pass.
    pub memory_access_cost: f64,
    /// Per-operation compute overhead inside a fused kernel.
    pub compute_cost: f64,
    /// Fixed cost of launching one kernel.
    pub kernel_launch_cost: f64,
    /// Available memory bandwidth — presumably bytes/sec (default 100e9);
    /// not used by the visible cost formulas yet.
    pub memory_bandwidth: f64,
}
impl Default for FusionCostModel {
fn default() -> Self {
Self {
memory_access_cost: 1.0,
compute_cost: 0.1,
kernel_launch_cost: 10.0,
memory_bandwidth: 100e9, }
}
}
impl FusionCostModel {
    /// Estimated cost of running `num_ops` kernels separately over
    /// `data_size` elements: each op pays a full memory pass plus one
    /// kernel launch.
    pub fn cost_separate(&self, num_ops: usize, data_size: usize) -> f64 {
        let memory_cost = self.memory_access_cost * data_size as f64 * num_ops as f64;
        let launch_cost = self.kernel_launch_cost * num_ops as f64;
        memory_cost + launch_cost
    }

    /// Estimated cost of one fused kernel: a single read and write of the
    /// data (hence the factor 2.0), one launch, and a small per-op compute
    /// overhead.
    pub fn cost_fused(&self, num_ops: usize, data_size: usize) -> f64 {
        let memory_cost = self.memory_access_cost * data_size as f64 * 2.0;
        let launch_cost = self.kernel_launch_cost;
        let compute_overhead = self.compute_cost * num_ops as f64;
        memory_cost + launch_cost + compute_overhead
    }

    /// Relative saving of fusing versus running separately: positive (and
    /// below 1.0) when fusion helps, negative when it would be slower.
    ///
    /// Returns 0.0 when the separate cost is not positive (e.g.
    /// `num_ops == 0`), where the ratio would otherwise be NaN and poison
    /// downstream score comparisons.
    pub fn fusion_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
        let separate_cost = self.cost_separate(num_ops, data_size);
        if separate_cost <= 0.0 {
            return 0.0;
        }
        let fused_cost = self.cost_fused(num_ops, data_size);
        (separate_cost - fused_cost) / separate_cost
    }
}
/// Analyzes an [`EinsumGraph`] for fusion opportunities and ranks them.
pub struct FusionOptimizer {
    // Search settings (which fusion kinds to try, thresholds).
    config: FusionConfig,
    // Cost model used to score each candidate.
    cost_model: FusionCostModel,
    // Candidates found by the most recent `analyze` call.
    candidates: Vec<FusionCandidate>,
}
impl FusionOptimizer {
    /// Creates an optimizer with the given configuration and the default
    /// cost model.
    pub fn new(config: FusionConfig) -> Self {
        Self {
            config,
            cost_model: FusionCostModel::default(),
            candidates: Vec::new(),
        }
    }

    /// Creates an optimizer with an explicit cost model.
    pub fn with_cost_model(config: FusionConfig, cost_model: FusionCostModel) -> Self {
        Self {
            config,
            cost_model,
            candidates: Vec::new(),
        }
    }

    /// Scans `graph` for fusion opportunities according to the enabled
    /// config flags and returns the candidates sorted by descending
    /// benefit score. Results are also retained internally for `stats`.
    pub fn analyze(&mut self, graph: &EinsumGraph) -> Vec<FusionCandidate> {
        self.candidates.clear();
        if self.config.enable_patterns {
            self.find_pattern_fusions(graph);
        }
        if self.config.enable_vertical {
            self.find_vertical_fusions(graph);
        }
        if self.config.enable_horizontal {
            self.find_horizontal_fusions(graph);
        }
        // Highest benefit first; NaN scores compare as Equal so the sort
        // cannot panic on a failed partial_cmp.
        self.candidates.sort_by(|a, b| {
            b.benefit_score
                .partial_cmp(&a.benefit_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.candidates.clone()
    }

    /// Detects known producer/consumer patterns: currently an einsum
    /// feeding an elementwise op, recorded as `MatMulActivation`.
    ///
    /// NOTE(review): benefit and savings use placeholder sizes (2 ops,
    /// 1024 elements) rather than real tensor shapes — revisit once shape
    /// metadata is available.
    fn find_pattern_fusions(&mut self, graph: &EinsumGraph) {
        for idx in 0..graph.nodes.len() {
            let producer = NodeId(idx);
            if !matches!(graph.nodes[producer.0].op, OpType::Einsum { .. }) {
                continue;
            }
            for consumer in self.find_consumers(graph, producer) {
                let consumer_op = &graph.nodes[consumer.0].op;
                if matches!(
                    consumer_op,
                    OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
                ) {
                    let benefit = self.estimate_pattern_benefit(2, 1024);
                    if benefit >= self.config.min_benefit_score {
                        self.candidates.push(FusionCandidate {
                            nodes: vec![producer, consumer],
                            pattern: FusionPattern::MatMulActivation,
                            benefit_score: benefit,
                            memory_savings: 1024 * 4,
                            compute_savings: 0.0,
                        });
                    }
                }
            }
        }
    }

    /// Finds producer nodes whose value has exactly one consumer; such
    /// chains can be fused vertically so the intermediate is never
    /// materialized.
    fn find_vertical_fusions(&mut self, graph: &EinsumGraph) {
        for idx in 0..graph.nodes.len() {
            let producer = NodeId(idx);
            let consumers = self.find_consumers(graph, producer);
            // With more than one consumer the intermediate must still be
            // materialized, so only single-consumer chains qualify.
            if let [consumer] = consumers[..] {
                if self.can_fuse_vertically(graph, producer, consumer) {
                    let benefit = self.cost_model.fusion_benefit(2, 1024);
                    if benefit >= self.config.min_benefit_score {
                        self.candidates.push(FusionCandidate {
                            nodes: vec![producer, consumer],
                            pattern: FusionPattern::ElementwiseChain,
                            benefit_score: benefit,
                            memory_savings: 1024 * 4,
                            compute_savings: 0.0,
                        });
                    }
                }
            }
        }
    }

    /// Finds pairs of independent, same-op nodes at the same graph depth
    /// that could run as one kernel (horizontal fusion).
    fn find_horizontal_fusions(&mut self, graph: &EinsumGraph) {
        // Group nodes by depth: only same-depth nodes are considered.
        let mut depth_groups: HashMap<usize, Vec<NodeId>> = HashMap::new();
        for idx in 0..graph.nodes.len() {
            let depth = self.compute_depth(graph, NodeId(idx));
            depth_groups.entry(depth).or_default().push(NodeId(idx));
        }
        // Visit depths in ascending order so candidate discovery is
        // deterministic (HashMap iteration order is unspecified).
        let mut depths: Vec<usize> = depth_groups.keys().copied().collect();
        depths.sort_unstable();
        for depth in depths {
            let nodes = &depth_groups[&depth];
            for (i, &a) in nodes.iter().enumerate() {
                for &b in &nodes[i + 1..] {
                    if self.are_independent(graph, a, b) && self.have_similar_ops(graph, a, b) {
                        let benefit = self.cost_model.fusion_benefit(2, 512);
                        if benefit >= self.config.min_benefit_score {
                            self.candidates.push(FusionCandidate {
                                nodes: vec![a, b],
                                pattern: FusionPattern::ParallelReductions,
                                // Horizontal fusion mostly saves kernel
                                // launches, not memory passes, so it is
                                // discounted relative to vertical fusion.
                                benefit_score: benefit * 0.8,
                                memory_savings: 512 * 4,
                                compute_savings: 0.0,
                            });
                        }
                    }
                }
            }
        }
    }

    /// Whether `producer`'s output can legally be computed inline in
    /// `consumer`.
    ///
    /// TODO(review): placeholder — always returns true; real legality
    /// checks (shape/layout compatibility, side effects) are not yet
    /// implemented.
    fn can_fuse_vertically(
        &self,
        _graph: &EinsumGraph,
        _producer: NodeId,
        _consumer: NodeId,
    ) -> bool {
        true
    }

    /// True when neither node is (transitively) an input of the other.
    fn are_independent(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
        let a_deps = self.get_all_dependencies(graph, a);
        let b_deps = self.get_all_dependencies(graph, b);
        !a_deps.contains(&b) && !b_deps.contains(&a)
    }

    /// True when both nodes carry the same `OpType` variant (payloads such
    /// as the einsum spec are ignored).
    fn have_similar_ops(&self, graph: &EinsumGraph, a: NodeId, b: NodeId) -> bool {
        let op_a = &graph.nodes[a.0].op;
        let op_b = &graph.nodes[b.0].op;
        std::mem::discriminant(op_a) == std::mem::discriminant(op_b)
    }

    /// Returns every node that lists `producer` among its inputs.
    /// Linear scan over the whole graph per call.
    fn find_consumers(&self, graph: &EinsumGraph, producer: NodeId) -> Vec<NodeId> {
        graph
            .nodes
            .iter()
            .enumerate()
            .filter(|(_, node)| node.inputs.iter().any(|&n| n == producer.0))
            .map(|(i, _)| NodeId(i))
            .collect()
    }

    /// Returns the transitive input closure of `node_id` via iterative DFS.
    /// Note: the returned set includes `node_id` itself.
    fn get_all_dependencies(&self, graph: &EinsumGraph, node_id: NodeId) -> HashSet<NodeId> {
        let mut deps = HashSet::new();
        let mut to_visit = vec![node_id];
        while let Some(current) = to_visit.pop() {
            // `insert` returns false when the node was already visited.
            if !deps.insert(current) {
                continue;
            }
            for &input in &graph.nodes[current.0].inputs {
                to_visit.push(NodeId(input));
            }
        }
        deps
    }

    /// Longest input-chain length from any source node to `node_id`
    /// (sources have depth 0). Memoized so shared subgraphs are visited
    /// once instead of exponentially many times.
    fn compute_depth(&self, graph: &EinsumGraph, node_id: NodeId) -> usize {
        let mut memo: HashMap<usize, usize> = HashMap::new();
        self.compute_depth_memo(graph, node_id, &mut memo)
    }

    // Recursive worker for `compute_depth`; `memo` maps node index -> depth.
    #[allow(clippy::only_used_in_recursion)]
    fn compute_depth_memo(
        &self,
        graph: &EinsumGraph,
        node_id: NodeId,
        memo: &mut HashMap<usize, usize>,
    ) -> usize {
        if let Some(&depth) = memo.get(&node_id.0) {
            return depth;
        }
        let depth = graph.nodes[node_id.0]
            .inputs
            .iter()
            .map(|&input| 1 + self.compute_depth_memo(graph, NodeId(input), memo))
            .max()
            .unwrap_or(0);
        memo.insert(node_id.0, depth);
        depth
    }

    /// Scales the raw cost-model benefit by a strategy-dependent factor:
    /// aggressive and memory-aware strategies boost it, conservative
    /// discounts it.
    fn estimate_pattern_benefit(&self, num_ops: usize, data_size: usize) -> f64 {
        let base = self.cost_model.fusion_benefit(num_ops, data_size);
        match self.config.strategy {
            FusionStrategy::Aggressive => base * 1.2,
            FusionStrategy::Conservative => base * 0.8,
            FusionStrategy::Balanced => base,
            FusionStrategy::MemoryAware => base * 1.5,
        }
    }

    /// Applies the selected fusions to `graph`.
    ///
    /// TODO(review): graph rewriting is not implemented yet — this
    /// currently returns an unmodified clone of the input. The
    /// `FusionError` variants are reserved for the eventual implementation.
    pub fn apply_fusions(
        &self,
        graph: &EinsumGraph,
        _candidates: &[FusionCandidate],
    ) -> Result<EinsumGraph, FusionError> {
        Ok(graph.clone())
    }

    /// Summarizes the candidates found by the most recent `analyze` call.
    pub fn stats(&self) -> FusionStats {
        let total_candidates = self.candidates.len();
        let total_memory_savings: usize = self.candidates.iter().map(|c| c.memory_savings).sum();
        let avg_benefit_score = if total_candidates > 0 {
            self.candidates.iter().map(|c| c.benefit_score).sum::<f64>() / total_candidates as f64
        } else {
            0.0
        };
        let mut pattern_counts = HashMap::new();
        for candidate in &self.candidates {
            *pattern_counts.entry(candidate.pattern).or_insert(0) += 1;
        }
        FusionStats {
            total_candidates,
            total_memory_savings,
            avg_benefit_score,
            pattern_distribution: pattern_counts,
        }
    }
}
/// Summary of the candidates produced by a `FusionOptimizer::analyze` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionStats {
    /// Number of fusion candidates found.
    pub total_candidates: usize,
    /// Sum of the candidates' estimated memory savings.
    pub total_memory_savings: usize,
    /// Mean benefit score across candidates (0.0 when there are none).
    pub avg_benefit_score: f64,
    /// Candidate count per recognized pattern.
    pub pattern_distribution: HashMap<FusionPattern, usize>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Builds a minimal two-node graph: an einsum (node 0) whose output
    /// feeds a ReLU elementwise op (node 1).
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.nodes.push(EinsumNode {
            op: OpType::Einsum {
                spec: "ij,jk->ik".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        });
        graph.nodes.push(EinsumNode {
            op: OpType::ElemUnary {
                op: "relu".to_string(),
            },
            inputs: vec![0],
            outputs: vec![1],
            metadata: Default::default(),
        });
        graph
    }

    #[test]
    fn test_fusion_config() {
        // Presets must set their strategy; aggressive must not shrink the
        // fusion-size cap relative to the default.
        let config = FusionConfig::aggressive();
        assert_eq!(config.strategy, FusionStrategy::Aggressive);
        assert!(config.max_fusion_size >= FusionConfig::default().max_fusion_size);
        let config = FusionConfig::conservative();
        assert_eq!(config.strategy, FusionStrategy::Conservative);
    }

    #[test]
    fn test_cost_model() {
        let model = FusionCostModel::default();
        // Fusing a few ops should help, but never save everything.
        let benefit = model.fusion_benefit(3, 1024);
        assert!(benefit > 0.0);
        assert!(benefit < 1.0);
        // More ops fused over the same data => larger relative saving.
        let benefit_more = model.fusion_benefit(5, 1024);
        assert!(benefit_more > benefit);
    }

    #[test]
    fn test_fusion_optimizer_creation() {
        // A fresh optimizer starts with no candidates.
        let config = FusionConfig::default();
        let optimizer = FusionOptimizer::new(config);
        assert_eq!(optimizer.candidates.len(), 0);
    }

    #[test]
    fn test_fusion_analysis() {
        // With the benefit threshold disabled, the einsum->relu pair must
        // be discovered as at least one candidate.
        let graph = create_test_graph();
        let config = FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        };
        let mut optimizer = FusionOptimizer::new(config);
        let candidates = optimizer.analyze(&graph);
        assert!(!candidates.is_empty());
    }

    #[test]
    fn test_consumer_finding() {
        // Node 1 is the sole consumer of node 0's output.
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        let consumers = optimizer.find_consumers(&graph, NodeId(0));
        assert_eq!(consumers.len(), 1);
        assert_eq!(consumers[0], NodeId(1));
    }

    #[test]
    fn test_depth_computation() {
        // The source einsum has depth 0; its consumer sits one level below.
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert_eq!(optimizer.compute_depth(&graph, NodeId(0)), 0);
        assert_eq!(optimizer.compute_depth(&graph, NodeId(1)), 1);
    }

    #[test]
    fn test_independence_check() {
        // Add a third node (tanh) with no inputs: independent of node 0,
        // while node 1 depends on node 0.
        let mut graph = create_test_graph();
        graph.nodes.push(EinsumNode {
            op: OpType::ElemUnary {
                op: "tanh".to_string(),
            },
            inputs: vec![],
            outputs: vec![2],
            metadata: Default::default(),
        });
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(!optimizer.are_independent(&graph, NodeId(0), NodeId(1)));
        assert!(optimizer.are_independent(&graph, NodeId(0), NodeId(2)));
    }

    #[test]
    fn test_fusion_stats() {
        // Stats must reflect the candidates of the preceding analyze call.
        let graph = create_test_graph();
        let config = FusionConfig {
            min_benefit_score: 0.0,
            ..FusionConfig::default()
        };
        let mut optimizer = FusionOptimizer::new(config);
        optimizer.analyze(&graph);
        let stats = optimizer.stats();
        assert!(stats.total_candidates > 0);
        assert!(stats.avg_benefit_score >= 0.0);
    }

    #[test]
    fn test_similar_ops_check() {
        // Similarity compares OpType variants: a node matches itself, and
        // Einsum (node 0) differs from ElemUnary (node 1).
        let graph = create_test_graph();
        let optimizer = FusionOptimizer::new(FusionConfig::default());
        assert!(optimizer.have_similar_ops(&graph, NodeId(0), NodeId(0)));
        assert!(!optimizer.have_similar_ops(&graph, NodeId(0), NodeId(1)));
    }
}