use std::collections::{HashMap, HashSet};
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
/// Counters reported by the loop-fusion pass.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LoopFusionStats {
    /// Number of fusion groups (two or more nodes each) counted as fused.
    pub loops_fused: usize,
    /// Number of `Reduce` nodes that are members of a fused group.
    pub reductions_merged: usize,
    /// Intermediate results removed, i.e. `group.len() - 1` summed over fused groups.
    pub intermediates_eliminated: usize,
    /// Number of nodes in the graph handed to the pass.
    pub total_processed: usize,
}

impl LoopFusionStats {
    /// Sum of every optimization counter; `total_processed` is not included.
    pub fn total_optimizations(&self) -> usize {
        [
            self.loops_fused,
            self.reductions_merged,
            self.intermediates_eliminated,
        ]
        .iter()
        .sum()
    }
}
/// Tunable knobs for the loop-fusion pass.
#[derive(Debug, Clone)]
pub struct LoopFusionConfig {
    /// Allow grouping `Reduce` nodes that share the same reduction op and axes.
    pub enable_reduction_fusion: bool,
    /// Allow grouping two `ElemUnary` nodes or two `ElemBinary` nodes.
    pub enable_elementwise_fusion: bool,
    /// Upper bound on the number of nodes collected into a single fusion group.
    pub max_fusion_size: usize,
    /// Minimum benefit required of a fusion group (presumably compared against the
    /// score from `estimate_fusion_benefit`; NOTE(review): confirm at the call site).
    pub min_benefit_threshold: f64,
}

impl Default for LoopFusionConfig {
    /// Both fusion kinds enabled, groups of at most 8 nodes, threshold 1.1.
    fn default() -> Self {
        let max_fusion_size = 8;
        let min_benefit_threshold = 1.1;
        LoopFusionConfig {
            enable_reduction_fusion: true,
            enable_elementwise_fusion: true,
            max_fusion_size,
            min_benefit_threshold,
        }
    }
}
/// Runs the loop-fusion pass on `graph` with [`LoopFusionConfig::default`].
///
/// Shorthand for `fuse_loops_with_config(graph, &LoopFusionConfig::default())`.
pub fn fuse_loops(graph: &EinsumGraph) -> (EinsumGraph, LoopFusionStats) {
    let default_config = LoopFusionConfig::default();
    fuse_loops_with_config(graph, &default_config)
}
pub fn fuse_loops_with_config(
graph: &EinsumGraph,
config: &LoopFusionConfig,
) -> (EinsumGraph, LoopFusionStats) {
let optimized = graph.clone();
let mut stats = LoopFusionStats::default();
let dependencies = build_dependency_graph(&optimized);
let fusion_groups = find_fusion_groups(&optimized, &dependencies, config);
stats.total_processed = optimized.nodes.len();
for group in fusion_groups {
if group.len() >= 2 {
stats.loops_fused += 1;
stats.intermediates_eliminated += group.len() - 1;
for &node_idx in &group {
if let Some(node) = optimized.nodes.get(node_idx) {
if matches!(node.op, OpType::Reduce { .. }) {
stats.reductions_merged += 1;
}
}
}
}
}
(optimized, stats)
}
/// Maps every node index to the set of node indices that produce its inputs.
///
/// A node `p` is a dependency of node `n` when some tensor listed in
/// `n.inputs` also appears in `p.outputs`. Every node gets an entry, even
/// when its dependency set is empty.
fn build_dependency_graph(graph: &EinsumGraph) -> HashMap<usize, HashSet<usize>> {
    // Index each tensor by the node(s) that write it. Each input lookup is then
    // a hash probe rather than a scan over every node, which was O(n^2 * k).
    let mut producers_of: HashMap<_, Vec<usize>> = HashMap::new();
    for (producer_idx, producer) in graph.nodes.iter().enumerate() {
        for &tensor in &producer.outputs {
            producers_of.entry(tensor).or_default().push(producer_idx);
        }
    }

    graph
        .nodes
        .iter()
        .enumerate()
        .map(|(idx, node)| {
            let node_deps: HashSet<usize> = node
                .inputs
                .iter()
                .filter_map(|input| producers_of.get(input))
                .flatten()
                .copied()
                .collect();
            (idx, node_deps)
        })
        .collect()
}
/// Greedily partitions graph nodes into fusion groups.
///
/// Each not-yet-assigned node seeds a group. Later unassigned nodes join the
/// group when all of these hold:
/// - `can_fuse_nodes` accepts them against the seed;
/// - they have no dependency conflict with the current members;
/// - the group is below `config.max_fusion_size`.
///
/// Only groups with more than one member are returned.
fn find_fusion_groups(
    graph: &EinsumGraph,
    dependencies: &HashMap<usize, HashSet<usize>>,
    config: &LoopFusionConfig,
) -> Vec<Vec<usize>> {
    let mut assigned: HashSet<usize> = HashSet::new();
    let mut fusion_groups: Vec<Vec<usize>> = Vec::new();

    for (seed_idx, seed) in graph.nodes.iter().enumerate() {
        // `insert` returns false when the node already belongs to a group.
        if !assigned.insert(seed_idx) {
            continue;
        }
        let mut members = vec![seed_idx];

        // Every index before `seed_idx` has already been assigned, either as a
        // seed or as a member, so scanning can start right after the seed.
        for (cand_idx, cand) in graph.nodes.iter().enumerate().skip(seed_idx + 1) {
            if assigned.contains(&cand_idx) {
                continue;
            }
            if members.len() >= config.max_fusion_size {
                break;
            }
            let joinable = can_fuse_nodes(seed, cand, config)
                && !has_dependency_conflict(&members, cand_idx, dependencies);
            if joinable {
                members.push(cand_idx);
                assigned.insert(cand_idx);
            }
        }

        if members.len() > 1 {
            fusion_groups.push(members);
        }
    }
    fusion_groups
}
/// Decides whether two nodes are compatible for horizontal fusion.
///
/// - Two `Reduce` nodes are compatible only when reduction fusion is enabled
///   and they share both the reduction op and the axes.
/// - Two `ElemUnary` nodes, or two `ElemBinary` nodes, are compatible whenever
///   elementwise fusion is enabled, whatever their specific op.
/// - Any other pairing is never fusable.
fn can_fuse_nodes(node1: &EinsumNode, node2: &EinsumNode, config: &LoopFusionConfig) -> bool {
    let pair = (&node1.op, &node2.op);

    if let (
        OpType::Reduce {
            op: lhs_op,
            axes: lhs_axes,
        },
        OpType::Reduce {
            op: rhs_op,
            axes: rhs_axes,
        },
    ) = pair
    {
        return config.enable_reduction_fusion && lhs_op == rhs_op && lhs_axes == rhs_axes;
    }

    let both_unary = matches!(pair, (OpType::ElemUnary { .. }, OpType::ElemUnary { .. }));
    let both_binary = matches!(pair, (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }));
    (both_unary || both_binary) && config.enable_elementwise_fusion
}
/// Returns `true` if `candidate` must not be fused with the nodes in `group`.
///
/// `dependencies` maps a node to the set of nodes it directly depends on, as
/// built by `build_dependency_graph`. Nodes absent from the map are treated
/// as having no dependencies.
///
/// A conflict exists when a dependency path connects `candidate` and any
/// group member, in either direction. Direct edges and paths through
/// intermediate nodes both count. Transitive paths matter: if C depends on B
/// and B depends on A, fusing A and C into one node would require B to run
/// both after and before the fused node.
fn has_dependency_conflict(
    group: &[usize],
    candidate: usize,
    dependencies: &HashMap<usize, HashSet<usize>>,
) -> bool {
    // candidate (transitively) depends on a group member, or
    // a group member (transitively) depends on candidate.
    reaches_any_dependency(&[candidate], group, dependencies)
        || reaches_any_dependency(group, &[candidate], dependencies)
}

/// Depth-first search along dependency edges that starts from the direct
/// dependencies of `sources`. Returns `true` once any node in `targets` is
/// reached. A visited set keeps the search terminating on cyclic maps.
fn reaches_any_dependency(
    sources: &[usize],
    targets: &[usize],
    dependencies: &HashMap<usize, HashSet<usize>>,
) -> bool {
    let mut stack: Vec<usize> = sources
        .iter()
        .filter_map(|source| dependencies.get(source))
        .flat_map(|deps| deps.iter().copied())
        .collect();
    let mut seen: HashSet<usize> = HashSet::new();
    while let Some(node) = stack.pop() {
        if !seen.insert(node) {
            continue;
        }
        if targets.contains(&node) {
            return true;
        }
        if let Some(next) = dependencies.get(&node) {
            stack.extend(next.iter().copied());
        }
    }
    false
}
pub fn estimate_fusion_benefit(graph: &EinsumGraph, group: &[usize]) -> f64 {
if group.len() < 2 {
return 1.0;
}
let base_speedup = 1.0 + (group.len() as f64 - 1.0) * 0.3;
let intermediate_bonus = (group.len() - 1) as f64 * 0.2;
let mut reduction_count = 0;
for &node_idx in group {
if let Some(node) = graph.nodes.get(node_idx) {
if matches!(node.op, OpType::Reduce { .. }) {
reduction_count += 1;
}
}
}
let reduction_bonus = reduction_count as f64 * 0.1;
base_speedup + intermediate_bonus + reduction_bonus
}
#[cfg(test)]
mod tests {
use super::*;
// Builds a graph that registers two tensors but adds no nodes, so every
// pass over `graph.nodes` sees an empty list.
fn create_test_graph() -> EinsumGraph {
let mut graph = EinsumGraph::new();
let _t0 = graph.add_tensor("t0");
let _t1 = graph.add_tensor("t1");
graph
}
// With no nodes there are no dependency entries at all.
#[test]
fn test_build_dependency_graph() {
let graph = create_test_graph();
let deps = build_dependency_graph(&graph);
assert_eq!(deps.len(), 0); }
// Identical reduction op and axes are fusable under the default config.
#[test]
fn test_can_fuse_same_reductions() {
let config = LoopFusionConfig::default();
let node1 = EinsumNode::reduce("sum", vec![0], 0, 1);
let node2 = EinsumNode::reduce("sum", vec![0], 2, 3);
assert!(can_fuse_nodes(&node1, &node2, &config));
}
// Same reduction op but different axes must not be fused.
#[test]
fn test_cannot_fuse_different_axes() {
let config = LoopFusionConfig::default();
let node1 = EinsumNode::reduce("sum", vec![0], 0, 1);
let node2 = EinsumNode::reduce("sum", vec![1], 2, 3);
assert!(!can_fuse_nodes(&node1, &node2, &config));
}
// Two unary elementwise nodes fuse even when their ops differ.
#[test]
fn test_can_fuse_elementwise() {
let config = LoopFusionConfig::default();
let node1 = EinsumNode::elem_unary("exp", 0, 1);
let node2 = EinsumNode::elem_unary("log", 2, 3);
assert!(can_fuse_nodes(&node1, &node2, &config));
}
// Singleton groups score exactly 1.0. A two-member group scores
// 1.0 + 0.3 + 0.2 = 1.5 here, because the node-less graph contributes no
// reduction bonus.
#[test]
fn test_estimate_fusion_benefit() {
let graph = create_test_graph();
let benefit = estimate_fusion_benefit(&graph, &[0]);
assert_eq!(benefit, 1.0);
let benefit = estimate_fusion_benefit(&graph, &[0, 1]);
assert!(benefit > 1.0);
assert!(benefit < 3.0);
}
// `total_processed` reflects the node count, which is zero for the test graph.
#[test]
fn test_fuse_loops_stats() {
let graph = create_test_graph();
let (_optimized, stats) = fuse_loops(&graph);
assert_eq!(stats.total_processed, 0); }
// Struct-literal construction round-trips every field.
#[test]
fn test_config_builder() {
let config = LoopFusionConfig {
enable_reduction_fusion: false,
enable_elementwise_fusion: true,
max_fusion_size: 4,
min_benefit_threshold: 1.5,
};
assert!(!config.enable_reduction_fusion);
assert!(config.enable_elementwise_fusion);
assert_eq!(config.max_fusion_size, 4);
assert_eq!(config.min_benefit_threshold, 1.5);
}
// Node 1 directly depends on node 0, so it conflicts with group [0].
// Node 2 is absent from the map, so it has no dependencies and no conflict.
#[test]
fn test_dependency_conflict_detection() {
let mut deps = HashMap::new();
deps.insert(0, HashSet::new());
deps.insert(1, vec![0].into_iter().collect());
assert!(has_dependency_conflict(&[0], 1, &deps));
assert!(!has_dependency_conflict(&[0], 2, &deps));
}
// The total sums the three optimization counters and ignores `total_processed`.
#[test]
fn test_stats_total_optimizations() {
let stats = LoopFusionStats {
loops_fused: 2,
reductions_merged: 3,
intermediates_eliminated: 1,
total_processed: 10,
};
assert_eq!(stats.total_optimizations(), 6);
}
}