use anyhow::Result;
use std::str::FromStr;
use tensorlogic_compiler::passes::{
optimize_einsum_graph as compiler_optimize_graph, EinsumOptResult,
};
use tensorlogic_ir::EinsumGraph;
use crate::output::{print_info, print_success};
/// Counters accumulated across one or more einsum optimization passes.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Identity operations eliminated from the graph.
    pub identity_simplifications: usize,
    /// Einsum operations merged together.
    pub merged_einsums: usize,
    /// Operations reordered.
    pub reordered_ops: usize,
    /// Multiplicative estimate of the overall speedup (1.0 = no speedup).
    /// NOTE(review): `Default` initializes this to 0.0, not 1.0 — callers
    /// that fold per-pass speedups in multiplicatively must set it to 1.0
    /// explicitly first.
    pub estimated_speedup: f64,
}
impl From<EinsumOptResult> for OptimizationStats {
    /// Translate the compiler pass result into the CLI-facing stat names.
    /// The mapping is a pure field-to-field rename; no values are derived.
    fn from(result: EinsumOptResult) -> Self {
        Self {
            identity_simplifications: result.identity_eliminated,
            merged_einsums: result.merged_count,
            reordered_ops: result.reordered_count,
            estimated_speedup: result.estimated_speedup,
        }
    }
}
/// Run a single compiler optimization pass over `graph` and convert the
/// result into CLI statistics.
///
/// A failed pass is deliberately reported as "no optimizations applied"
/// (all-zero stats) rather than an error, so the surrounding pipeline keeps
/// going on a best-effort basis.
fn optimize_graph_internal(graph: &mut EinsumGraph) -> OptimizationStats {
    compiler_optimize_graph(graph)
        .map(OptimizationStats::from)
        .unwrap_or_default()
}
/// How aggressively the einsum graph should be optimized.
///
/// Parsed from the CLI via `FromStr` ("none"/"0" through "aggressive"/"3").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationLevel {
    /// Skip optimization entirely.
    None,
    /// A single optimization pass.
    Basic,
    /// Two optimization passes; this is the default level.
    #[default]
    Standard,
    /// Iterate (capped at 10 passes) until the graph stops changing.
    Aggressive,
}
impl OptimizationLevel {
pub fn num_passes(&self) -> usize {
match self {
OptimizationLevel::None => 0,
OptimizationLevel::Basic => 1,
OptimizationLevel::Standard => 2,
OptimizationLevel::Aggressive => 10, }
}
pub fn description(&self) -> &'static str {
match self {
OptimizationLevel::None => "No optimizations",
OptimizationLevel::Basic => "Basic (1 pass: DCE + CSE)",
OptimizationLevel::Standard => "Standard (2 passes: DCE + CSE + Identity)",
OptimizationLevel::Aggressive => "Aggressive (until convergence)",
}
}
}
impl FromStr for OptimizationLevel {
    type Err = anyhow::Error;

    /// Parse a level from its name or numeric alias, case-insensitively.
    fn from_str(s: &str) -> Result<Self> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "none" | "0" => Ok(Self::None),
            "basic" | "1" => Ok(Self::Basic),
            "standard" | "2" => Ok(Self::Standard),
            "aggressive" | "3" => Ok(Self::Aggressive),
            _ => Err(anyhow::anyhow!("Unknown optimization level: {}", s)),
        }
    }
}
/// Settings controlling the optimization pipeline.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Selected level; determines how many passes are attempted.
    pub level: OptimizationLevel,
    /// Enable dead-code elimination. NOTE(review): not consulted by this
    /// module (passes are driven by `level`); kept for per-pass control.
    #[allow(dead_code)]
    pub enable_dce: bool,
    /// Enable common-subexpression elimination. NOTE(review): currently
    /// unused here, same as `enable_dce`.
    #[allow(dead_code)]
    pub enable_cse: bool,
    /// Enable identity-operation elimination. NOTE(review): currently
    /// unused here, same as `enable_dce`.
    #[allow(dead_code)]
    pub enable_identity: bool,
    /// Print the statistics summary even when not verbose.
    pub show_stats: bool,
    /// Print per-pass progress and informational messages.
    pub verbose: bool,
}
impl Default for OptimizationConfig {
fn default() -> Self {
Self {
level: OptimizationLevel::default(),
enable_dce: true,
enable_cse: true,
enable_identity: true,
show_stats: false,
verbose: false,
}
}
}
pub fn optimize_einsum_graph(
mut graph: EinsumGraph,
config: &OptimizationConfig,
) -> Result<(EinsumGraph, OptimizationStats)> {
if config.level == OptimizationLevel::None {
if config.verbose {
print_info("Skipping optimizations (level: None)");
}
return Ok((graph, OptimizationStats::default()));
}
let num_passes = config.level.num_passes();
let mut total_stats = OptimizationStats::default();
if config.verbose {
print_info(&format!(
"Applying {} ({})",
config.level.description(),
num_passes
));
println!(
" Initial: {} nodes, {} tensors",
graph.nodes.len(),
graph.tensors.len()
);
}
for pass in 0..num_passes {
let before_nodes = graph.nodes.len();
let before_tensors = graph.tensors.len();
let stats = optimize_graph_internal(&mut graph);
if stats.identity_simplifications == 0
&& stats.merged_einsums == 0
&& stats.reordered_ops == 0
{
if config.verbose {
println!(" Converged after {} passes", pass + 1);
}
break;
}
total_stats.identity_simplifications += stats.identity_simplifications;
total_stats.merged_einsums += stats.merged_einsums;
total_stats.reordered_ops += stats.reordered_ops;
if stats.estimated_speedup > 1.0 {
total_stats.estimated_speedup *= stats.estimated_speedup;
}
if config.verbose {
println!(
" Pass {}: {} → {} nodes, {} → {} tensors",
pass + 1,
before_nodes,
graph.nodes.len(),
before_tensors,
graph.tensors.len()
);
}
}
if config.show_stats || config.verbose {
print_optimization_stats(&total_stats);
}
let total_improvements = total_stats.identity_simplifications
+ total_stats.merged_einsums
+ total_stats.reordered_ops;
if total_improvements > 0 {
print_success(&format!(
"Optimization complete: {} identities removed, {} einsums merged, {} reordered",
total_stats.identity_simplifications,
total_stats.merged_einsums,
total_stats.reordered_ops
));
} else if config.verbose {
print_info("No optimizations applied (graph already optimal)");
}
Ok((graph, total_stats))
}
fn print_optimization_stats(stats: &OptimizationStats) {
println!("\nOptimization Statistics:");
println!(
" Identity operations eliminated: {}",
stats.identity_simplifications
);
println!(" Einsum operations merged: {}", stats.merged_einsums);
println!(" Operations reordered: {}", stats.reordered_ops);
let total = stats.identity_simplifications + stats.merged_einsums + stats.reordered_ops;
if total > 0 {
println!(" Total improvements: {}", total);
if stats.estimated_speedup > 1.0 {
println!(" Estimated speedup: {:.2}x", stats.estimated_speedup);
}
}
}
/// Print every available optimization level with its description.
#[allow(dead_code)]
pub fn list_optimization_levels() {
    println!("Optimization Levels:");
    println!();
    let levels = [
        OptimizationLevel::None,
        OptimizationLevel::Basic,
        OptimizationLevel::Standard,
        OptimizationLevel::Aggressive,
    ];
    for level in levels {
        println!(" {:?}: {}", level, level.description());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimization_level_from_str() {
        // Both the name form and the numeric alias must parse.
        for (input, expected) in [
            ("none", OptimizationLevel::None),
            ("basic", OptimizationLevel::Basic),
            ("2", OptimizationLevel::Standard),
        ] {
            assert_eq!(OptimizationLevel::from_str(input).unwrap(), expected);
        }
        assert!(OptimizationLevel::from_str("invalid").is_err());
    }

    #[test]
    fn test_optimization_level_num_passes() {
        // Pass counts per level, Aggressive being a capped upper bound.
        for (level, passes) in [
            (OptimizationLevel::None, 0),
            (OptimizationLevel::Basic, 1),
            (OptimizationLevel::Standard, 2),
            (OptimizationLevel::Aggressive, 10),
        ] {
            assert_eq!(level.num_passes(), passes);
        }
    }

    #[test]
    fn test_optimization_config_default() {
        let config = OptimizationConfig::default();
        assert_eq!(config.level, OptimizationLevel::Standard);
        assert!(config.enable_dce);
        assert!(config.enable_cse);
        assert!(config.enable_identity);
    }
}