use crate::{Scirs2Tensor, TlBackendResult};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
/// Individual optimization passes that a [`GraphOptimizer`] can run.
///
/// Passes are run in the order they were added (see `GraphOptimizer::optimize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Count nodes whose inputs all reference known tensor slots and seed
    /// the constant cache (analysis-only for now; see `fold_constants`).
    ConstantFolding,
    /// Detect structurally identical nodes by hash and record the first
    /// occurrence for reuse.
    SubgraphCaching,
    /// Recognize algebraic identities (e.g. identity einsum specs such as
    /// `"i->i"`).
    AlgebraicSimplification,
    /// Remove nodes whose outputs never reach the graph's outputs.
    DeadCodeElimination,
    /// Reorder operations; currently a placeholder that returns the graph
    /// unchanged (see `reorder_operations`).
    OperationReordering,
}
/// Counters collected during a single `GraphOptimizer::optimize` run.
///
/// Reset at the start of each `optimize` call.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Nodes identified as constant-foldable.
    pub constants_folded: usize,
    /// Duplicate subgraphs detected and recorded in the subgraph cache.
    pub subgraphs_cached: usize,
    /// Nodes matching an algebraic-simplification pattern.
    pub simplifications: usize,
    /// Nodes removed by dead-code elimination.
    pub dead_code_eliminated: usize,
    /// Operations moved by reordering (currently never incremented, since
    /// the reordering pass is a placeholder).
    pub operations_reordered: usize,
    /// Node count before any pass ran.
    pub nodes_before: usize,
    /// Node count after all passes ran.
    pub nodes_after: usize,
}
impl OptimizationStats {
    /// Percentage of nodes removed by optimization, relative to the
    /// pre-optimization node count.
    ///
    /// Returns `0.0` when no nodes existed before optimization (avoiding a
    /// division by zero). The difference is computed in `f64` so that a
    /// pass which *adds* nodes yields a negative percentage instead of
    /// panicking on `usize` underflow (the original computed
    /// `nodes_before - nodes_after` in `usize`, which panics in debug
    /// builds when `nodes_after > nodes_before`).
    pub fn reduction_percentage(&self) -> f64 {
        if self.nodes_before == 0 {
            return 0.0;
        }
        let removed = self.nodes_before as f64 - self.nodes_after as f64;
        removed / self.nodes_before as f64 * 100.0
    }
}
impl std::fmt::Display for OptimizationStats {
    /// Renders a human-readable, multi-line summary of all counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Optimization Statistics:")?;
        writeln!(f, " Constants folded: {}", self.constants_folded)?;
        writeln!(f, " Subgraphs cached: {}", self.subgraphs_cached)?;
        writeln!(f, " Simplifications: {}", self.simplifications)?;
        writeln!(f, " Dead code eliminated: {}", self.dead_code_eliminated)?;
        // Fix: `operations_reordered` is a tracked counter but was the only
        // field omitted from the display output.
        writeln!(f, " Operations reordered: {}", self.operations_reordered)?;
        writeln!(
            f,
            " Nodes: {} -> {} ({:.1}% reduction)",
            self.nodes_before,
            self.nodes_after,
            self.reduction_percentage()
        )
    }
}
/// Applies a configurable sequence of optimization passes to an
/// [`EinsumGraph`], collecting statistics along the way.
pub struct GraphOptimizer {
    /// Enabled passes, run in insertion order; duplicates are rejected by
    /// `add_pass`.
    passes: Vec<OptimizationPass>,
    /// Placeholder tensors keyed by output-tensor index, seeded during
    /// constant folding.
    constant_cache: HashMap<usize, Scirs2Tensor>,
    /// Maps a structural node hash to the index of its first occurrence.
    subgraph_cache: HashMap<u64, usize>,
    /// Counters for the most recent `optimize` call.
    stats: OptimizationStats,
}
impl Default for GraphOptimizer {
    /// Equivalent to [`GraphOptimizer::new`]: no passes, empty caches.
    fn default() -> Self {
        Self::new()
    }
}
impl GraphOptimizer {
    /// Creates an optimizer with no passes enabled and empty caches.
    pub fn new() -> Self {
        Self {
            passes: Vec::new(),
            constant_cache: HashMap::new(),
            subgraph_cache: HashMap::new(),
            stats: OptimizationStats::default(),
        }
    }

    /// Creates an optimizer with the standard, safe set of passes in their
    /// intended order: fold constants, simplify, eliminate dead code, then
    /// cache duplicated subgraphs.
    pub fn with_all_passes() -> Self {
        let mut optimizer = Self::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        optimizer.add_pass(OptimizationPass::SubgraphCaching);
        optimizer
    }

    /// All standard passes plus (experimental) operation reordering.
    pub fn aggressive() -> Self {
        let mut optimizer = Self::with_all_passes();
        optimizer.add_pass(OptimizationPass::OperationReordering);
        optimizer
    }

    /// Enables `pass`; a no-op if it is already enabled, so each pass runs
    /// at most once per `optimize` call.
    pub fn add_pass(&mut self, pass: OptimizationPass) {
        if !self.passes.contains(&pass) {
            self.passes.push(pass);
        }
    }

    /// Disables `pass` if present.
    pub fn remove_pass(&mut self, pass: OptimizationPass) {
        self.passes.retain(|p| *p != pass);
    }

    /// Statistics collected by the most recent `optimize` call.
    pub fn stats(&self) -> &OptimizationStats {
        &self.stats
    }

    /// Drops all cached constants and subgraph entries. Caches persist
    /// across `optimize` calls unless cleared explicitly.
    pub fn clear_caches(&mut self) {
        self.constant_cache.clear();
        self.subgraph_cache.clear();
    }

    /// Runs every enabled pass, in insertion order, over `graph` and
    /// returns the optimized copy.
    ///
    /// Resets `stats` first and records the before/after node counts.
    ///
    /// # Errors
    /// Propagates any error returned by an individual pass.
    pub fn optimize(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        self.stats = OptimizationStats {
            nodes_before: graph.nodes.len(),
            ..Default::default()
        };
        let mut optimized = graph.clone();
        // Clone the pass list so the loop does not hold a borrow of `self`
        // while each pass mutates `self` (stats, caches).
        for pass in self.passes.clone() {
            optimized = match pass {
                OptimizationPass::ConstantFolding => self.fold_constants(&optimized)?,
                OptimizationPass::SubgraphCaching => self.cache_subgraphs(&optimized)?,
                OptimizationPass::AlgebraicSimplification => self.simplify_algebra(&optimized)?,
                OptimizationPass::DeadCodeElimination => self.eliminate_dead_code(&optimized)?,
                OptimizationPass::OperationReordering => self.reorder_operations(&optimized)?,
            };
        }
        self.stats.nodes_after = optimized.nodes.len();
        Ok(optimized)
    }

    /// Constant-folding pass (analysis only for now): counts nodes whose
    /// inputs all reference existing tensor slots and seeds the constant
    /// cache with placeholder tensors. The graph is returned unchanged;
    /// actual evaluation of foldable nodes is not implemented.
    fn fold_constants(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let result = graph.clone();
        let num_tensors = graph.tensors.len();
        for node in &graph.nodes {
            // "Constant" here means every input index is a valid tensor slot.
            // NOTE(review): this over-approximates — intermediate tensors also
            // satisfy it; confirm the intended semantics before relying on
            // `constants_folded` as a true fold count.
            if node.inputs.iter().all(|&input| input < num_tensors) {
                self.stats.constants_folded += 1;
            }
            // Placeholder cache entries so outputs can be looked up later.
            for &output in &node.outputs {
                self.constant_cache
                    .entry(output)
                    .or_insert_with(|| scirs2_core::ndarray::ArrayD::zeros(vec![1]));
            }
        }
        Ok(result)
    }

    /// Subgraph-caching pass: hashes each node structurally and records,
    /// for every duplicate, the index of the earliest structurally
    /// identical node. The graph itself is returned unchanged.
    fn cache_subgraphs(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let result = graph.clone();
        let mut hash_to_first: HashMap<u64, usize> = HashMap::new();
        // Fix: iterate nodes in index order so the *earliest* occurrence is
        // always the cached representative. The original iterated a HashMap
        // of hashes, so which duplicate "won" was nondeterministic.
        for (idx, node) in graph.nodes.iter().enumerate() {
            let hash = self.compute_node_hash(node);
            match hash_to_first.get(&hash) {
                Some(&first) => {
                    self.stats.subgraphs_cached += 1;
                    self.subgraph_cache.insert(hash, first);
                }
                None => {
                    hash_to_first.insert(hash, idx);
                }
            }
        }
        Ok(result)
    }

    /// Structural hash of a node: operation kind, its parameters, and the
    /// input tensor indices. Equal hashes mark candidates for deduplication
    /// (outputs are deliberately excluded — two nodes computing the same
    /// value into different slots should still collide).
    fn compute_node_hash(&self, node: &EinsumNode) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        // Tag each variant with a distinct literal so different op kinds
        // with identical parameters cannot trivially collide.
        match &node.op {
            OpType::Einsum { spec } => {
                "einsum".hash(&mut hasher);
                spec.hash(&mut hasher);
            }
            OpType::ElemUnary { op } => {
                "unary".hash(&mut hasher);
                op.hash(&mut hasher);
            }
            OpType::ElemBinary { op } => {
                "binary".hash(&mut hasher);
                op.hash(&mut hasher);
            }
            OpType::Reduce { op, axes } => {
                "reduce".hash(&mut hasher);
                op.hash(&mut hasher);
                axes.hash(&mut hasher);
            }
        }
        node.inputs.hash(&mut hasher);
        hasher.finish()
    }

    /// Algebraic-simplification pass: counts nodes matching a known
    /// identity pattern. Nodes are not rewritten yet; only the counter is
    /// updated.
    fn simplify_algebra(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let mut result = graph.clone();
        for node in &mut result.nodes {
            if self.try_simplify_node(node) {
                self.stats.simplifications += 1;
            }
        }
        Ok(result)
    }

    /// Returns `true` when `node` matches a simplifiable pattern.
    /// Currently only identity einsum specs are recognized; binary
    /// identities (x+0, x*1, x-0) would require constant-operand tracking
    /// that is not wired up yet, so everything else reports `false`.
    fn try_simplify_node(&self, node: &mut EinsumNode) -> bool {
        match &node.op {
            OpType::Einsum { spec } => spec == "i->i" || spec == "ij->ij",
            _ => false,
        }
    }

    /// Dead-code-elimination pass: removes every node none of whose
    /// outputs (transitively) contributes to the graph's declared outputs.
    fn eliminate_dead_code(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let mut result = graph.clone();
        // Fix: seed liveness from the graph's declared outputs. The original
        // seeded from the *last node's* outputs, which is wrong whenever a
        // graph output is produced by an earlier node.
        let mut used_tensors: HashSet<usize> = result.outputs.iter().copied().collect();
        // Walk backwards so a node whose outputs are live marks its own
        // inputs live before its producers are visited.
        for node in result.nodes.iter().rev() {
            if node.outputs.iter().any(|o| used_tensors.contains(o)) {
                used_tensors.extend(node.inputs.iter().copied());
            }
        }
        let original_count = result.nodes.len();
        result
            .nodes
            .retain(|n| n.outputs.iter().any(|o| used_tensors.contains(o)));
        self.stats.dead_code_eliminated = original_count - result.nodes.len();
        Ok(result)
    }

    /// Operation-reordering pass: placeholder that returns the graph
    /// unchanged and never increments `operations_reordered`.
    fn reorder_operations(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        Ok(graph.clone())
    }
}
/// Fluent builder for [`GraphOptimizer`]: chain `with_*` methods, then
/// call `build`.
pub struct GraphOptimizerBuilder {
    /// Passes requested so far; duplicates are deduplicated by
    /// `GraphOptimizer::add_pass` at build time.
    passes: Vec<OptimizationPass>,
}
impl Default for GraphOptimizerBuilder {
    /// Equivalent to [`GraphOptimizerBuilder::new`]: no passes selected.
    fn default() -> Self {
        Self::new()
    }
}
impl GraphOptimizerBuilder {
pub fn new() -> Self {
Self { passes: Vec::new() }
}
pub fn with_constant_folding(mut self) -> Self {
self.passes.push(OptimizationPass::ConstantFolding);
self
}
pub fn with_subgraph_caching(mut self) -> Self {
self.passes.push(OptimizationPass::SubgraphCaching);
self
}
pub fn with_algebraic_simplification(mut self) -> Self {
self.passes.push(OptimizationPass::AlgebraicSimplification);
self
}
pub fn with_dead_code_elimination(mut self) -> Self {
self.passes.push(OptimizationPass::DeadCodeElimination);
self
}
pub fn with_operation_reordering(mut self) -> Self {
self.passes.push(OptimizationPass::OperationReordering);
self
}
pub fn build(self) -> GraphOptimizer {
let mut optimizer = GraphOptimizer::new();
for pass in self.passes {
optimizer.add_pass(pass);
}
optimizer
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: one binary-add node, x + y -> z, where z is the graph output.
    fn create_simple_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string(), "z".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0, 1],
                outputs: vec![2],
                op: OpType::ElemBinary {
                    op: "add".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0, 1],
            outputs: vec![2],
            tensor_metadata: HashMap::new(),
        }
    }

    // Fixture: two nodes where only the second (sigmoid -> "result") feeds
    // the graph output; the first (relu -> "dead") is unreachable.
    fn create_graph_with_dead_code() -> EinsumGraph {
        EinsumGraph {
            tensors: vec![
                "x".to_string(),
                "y".to_string(),
                "dead".to_string(),
                "result".to_string(),
            ],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![1],
                    outputs: vec![3],
                    op: OpType::ElemUnary {
                        op: "sigmoid".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0, 1],
            outputs: vec![3],
            tensor_metadata: HashMap::new(),
        }
    }

    // A fresh optimizer has no passes enabled.
    #[test]
    fn test_optimizer_new() {
        let optimizer = GraphOptimizer::new();
        assert!(optimizer.passes.is_empty());
    }

    // with_all_passes enables the four standard (non-aggressive) passes.
    #[test]
    fn test_optimizer_with_all_passes() {
        let optimizer = GraphOptimizer::with_all_passes();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::AlgebraicSimplification));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    // add_pass/remove_pass toggle membership in the pass list.
    #[test]
    fn test_add_remove_pass() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        optimizer.remove_pass(OptimizationPass::ConstantFolding);
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
    }

    // Optimizing an empty graph succeeds and yields an empty graph.
    #[test]
    fn test_optimize_empty_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = EinsumGraph {
            tensors: vec![],
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            tensor_metadata: HashMap::new(),
        };
        let result = optimizer.optimize(&graph).expect("unwrap");
        assert!(result.nodes.is_empty());
    }

    // A graph with one live node is left with exactly that node.
    #[test]
    fn test_optimize_simple_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = create_simple_graph();
        let result = optimizer.optimize(&graph).expect("unwrap");
        assert_eq!(result.nodes.len(), 1);
    }

    // DCE removes the unreachable relu node and reports it in stats.
    #[test]
    fn test_dead_code_elimination() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let graph = create_graph_with_dead_code();
        let result = optimizer.optimize(&graph).expect("unwrap");
        assert_eq!(optimizer.stats().dead_code_eliminated, 1);
        assert_eq!(result.nodes.len(), 1);
    }

    // Node counts and the derived reduction percentage are tracked (2 -> 1
    // is a 50% reduction).
    #[test]
    fn test_optimization_stats() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let graph = create_graph_with_dead_code();
        optimizer.optimize(&graph).expect("unwrap");
        let stats = optimizer.stats();
        assert_eq!(stats.nodes_before, 2);
        assert_eq!(stats.nodes_after, 1);
        assert!((stats.reduction_percentage() - 50.0).abs() < 0.1);
    }

    // The builder enables exactly the passes that were requested.
    #[test]
    fn test_builder() {
        let optimizer = GraphOptimizerBuilder::new()
            .with_constant_folding()
            .with_dead_code_elimination()
            .build();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    // clear_caches empties the constant cache.
    #[test]
    fn test_clear_caches() {
        let mut optimizer = GraphOptimizer::new();
        optimizer
            .constant_cache
            .insert(0, scirs2_core::ndarray::ArrayD::zeros(vec![1]));
        assert!(!optimizer.constant_cache.is_empty());
        optimizer.clear_caches();
        assert!(optimizer.constant_cache.is_empty());
    }

    // The aggressive preset additionally enables operation reordering.
    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = GraphOptimizer::aggressive();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::OperationReordering));
    }

    // Display output includes the counters and the reduction percentage
    // ((10 - 7) / 10 = 30.0%).
    #[test]
    fn test_stats_display() {
        let stats = OptimizationStats {
            constants_folded: 5,
            subgraphs_cached: 3,
            simplifications: 2,
            dead_code_eliminated: 1,
            operations_reordered: 0,
            nodes_before: 10,
            nodes_after: 7,
        };
        let display = format!("{}", stats);
        assert!(display.contains("Constants folded: 5"));
        assert!(display.contains("30.0% reduction"));
    }

    // Two structurally identical relu nodes (same op, same inputs) are
    // detected as a cacheable duplicate.
    #[test]
    fn test_subgraph_caching() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::SubgraphCaching);
        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y1".to_string(), "y2".to_string()],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![1],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0],
            outputs: vec![1, 2],
            tensor_metadata: HashMap::new(),
        };
        let _result = optimizer.optimize(&graph).expect("unwrap");
        assert!(optimizer.stats().subgraphs_cached > 0);
    }

    // An identity einsum spec ("i->i") counts as a simplification.
    #[test]
    fn test_algebraic_simplification() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0],
                outputs: vec![1],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        };
        let _result = optimizer.optimize(&graph).expect("unwrap");
        assert!(optimizer.stats().simplifications > 0);
    }
}