use crate::error::TorshResult;
use crate::pattern_matching::graph::ComputationGraph;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
pub mod constant_folding;
pub mod dead_code_elimination;
pub mod pattern_optimization;
pub use constant_folding::{ConstantFoldingPass, ConstantValue, FoldingConfig, FoldingStatistics};
pub use dead_code_elimination::{
DeadCodeEliminationPass, EliminationConfig, EliminationStatistics,
};
pub use pattern_optimization::{
OptimizationConfig, OptimizationStatistics, PatternOptimizationPass,
};
/// Configuration for [`PassManager`]: which passes run, how many fixed-point
/// iterations `PassManager::run_all` may perform, and when it stops early.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PassConfig {
/// Run the pattern optimization pass. When `false`, `PassManager::with_config`
/// creates no instance for it and `run_all` skips it.
pub enable_pattern_optimization: bool,
/// Run the dead code elimination pass. When `false`, `PassManager::with_config`
/// creates no instance for it and `run_all` skips it.
pub enable_dead_code_elimination: bool,
/// Run the constant folding pass. When `false`, `PassManager::with_config`
/// creates no instance for it and `run_all` skips it.
pub enable_constant_folding: bool,
/// Upper bound on the number of iterations `run_all` performs; `0` runs no passes.
pub max_iterations: usize,
/// `run_all` stops once the relative node-count change of one iteration,
/// `|before - after| / before`, is below this value.
pub convergence_threshold: f64,
/// Print per-pass timing and convergence messages to stdout.
pub verbose: bool,
/// Order in which passes run within each iteration. A pass may be listed more than
/// once, and disabled passes are skipped. `None` means pattern optimization, then
/// constant folding, then dead code elimination.
pub custom_order: Option<Vec<PassType>>,
}
impl Default for PassConfig {
fn default() -> Self {
Self {
enable_pattern_optimization: true,
enable_dead_code_elimination: true,
enable_constant_folding: true,
max_iterations: 10,
convergence_threshold: 1e-6,
verbose: false,
custom_order: None,
}
}
}
/// Identifies one of the passes managed by [`PassManager`]. Used in
/// `PassConfig::custom_order` and as the key of `PassResult::pass_times`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PassType {
/// Runs `PatternOptimizationPass`.
PatternOptimization,
/// Runs `DeadCodeEliminationPass`.
DeadCodeElimination,
/// Runs `ConstantFoldingPass`.
ConstantFolding,
}
/// Summary of one `PassManager::run_all` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PassResult {
/// Wall-clock time of the whole `run_all` call.
pub total_time: Duration,
/// Time spent in each pass, summed over all iterations and over repeated entries
/// in the pass order. Passes that never ran have no entry.
pub pass_times: HashMap<PassType, Duration>,
/// Number of iterations started (at most `PassConfig::max_iterations`).
pub iterations: usize,
/// `true` if the loop stopped because no pass reported a change, or because the
/// relative node-count change fell below `PassConfig::convergence_threshold`.
/// `false` if it ran out of iterations (including `max_iterations == 0`).
pub converged: bool,
/// Statistics cloned after the most recent run of the pattern optimization pass;
/// `None` if it never ran.
pub pattern_stats: Option<OptimizationStatistics>,
/// Statistics cloned after the most recent run of the dead code elimination pass;
/// `None` if it never ran.
pub elimination_stats: Option<EliminationStatistics>,
/// Statistics cloned after the most recent run of the constant folding pass;
/// `None` if it never ran.
pub folding_stats: Option<FoldingStatistics>,
/// `graph.nodes.len()` before the first pass ran.
pub nodes_before: usize,
/// `graph.nodes.len()` after the last pass ran.
pub nodes_after: usize,
/// Fraction of nodes removed, `(before - after) / before`, clamped to `[0, 1]`;
/// `0.0` when the input graph was empty.
pub improvement_score: f64,
}
/// Owns the optimization passes and runs them over a `ComputationGraph`.
#[derive(Debug)]
pub struct PassManager {
/// Active configuration; its `enable_*` flags decide which pass instances exist.
config: PassConfig,
/// Pattern optimization instance; `None` when that pass is disabled.
pattern_pass: Option<PatternOptimizationPass>,
/// Dead code elimination instance; `None` when that pass is disabled.
elimination_pass: Option<DeadCodeEliminationPass>,
/// Constant folding instance; `None` when that pass is disabled.
folding_pass: Option<ConstantFoldingPass>,
}
impl PassManager {
    /// Creates a manager with [`PassConfig::default`], i.e. every pass enabled.
    pub fn new() -> Self {
        Self::with_config(PassConfig::default())
    }

    /// Creates a manager for `config`.
    ///
    /// A fresh instance (built with the pass's own `new()`) is created for every pass
    /// whose `enable_*` flag is set; disabled passes get no instance.
    pub fn with_config(config: PassConfig) -> Self {
        let mut manager = Self {
            config,
            pattern_pass: None,
            elimination_pass: None,
            folding_pass: None,
        };
        manager.sync_passes_with_config();
        manager
    }

    /// Forwards `config` to the pattern optimization pass.
    ///
    /// Does nothing while that pass is disabled (no instance exists). The setting is
    /// therefore not remembered if the pass is enabled later through `set_config`.
    pub fn configure_pattern_optimization(&mut self, config: OptimizationConfig) {
        if let Some(pass) = self.pattern_pass.as_mut() {
            pass.configure(config);
        }
    }

    /// Forwards `config` to the dead code elimination pass.
    ///
    /// Does nothing while that pass is disabled (no instance exists).
    pub fn configure_dead_code_elimination(&mut self, config: EliminationConfig) {
        if let Some(pass) = self.elimination_pass.as_mut() {
            pass.configure(config);
        }
    }

    /// Forwards `config` to the constant folding pass.
    ///
    /// Does nothing while that pass is disabled (no instance exists).
    pub fn configure_constant_folding(&mut self, config: FoldingConfig) {
        if let Some(pass) = self.folding_pass.as_mut() {
            pass.configure(config);
        }
    }

    /// Runs the enabled passes over `graph` repeatedly until it converges or
    /// `max_iterations` is reached.
    ///
    /// Each iteration runs the passes in `custom_order`, or in the default order
    /// (pattern optimization, constant folding, dead code elimination), skipping
    /// disabled ones. The loop stops early, with `converged` set, when no pass
    /// reported a change or when the relative node-count change of the iteration
    /// falls below `convergence_threshold`.
    ///
    /// # Errors
    /// Propagates the first error returned by a pass. Earlier passes may already
    /// have modified `graph` by then.
    pub fn run_all(&mut self, graph: &mut ComputationGraph) -> TorshResult<PassResult> {
        let start_time = Instant::now();
        let nodes_before = graph.nodes.len();
        let mut pass_times: HashMap<PassType, Duration> = HashMap::new();
        let mut pattern_stats = None;
        let mut elimination_stats = None;
        let mut folding_stats = None;
        let mut iterations = 0;
        let mut converged = false;

        // Cloned so the order can be iterated while `run_pass` borrows `self` mutably.
        let pass_order = self
            .config
            .custom_order
            .clone()
            .unwrap_or_else(Self::default_pass_order);

        let mut prev_node_count = nodes_before;
        for iteration in 0..self.config.max_iterations {
            iterations = iteration + 1;
            let mut changed = false;

            for &pass_type in &pass_order {
                if !self.is_pass_enabled(pass_type) {
                    continue;
                }
                let pass_start = Instant::now();
                let pass_changed = self.run_pass(
                    pass_type,
                    graph,
                    &mut pattern_stats,
                    &mut elimination_stats,
                    &mut folding_stats,
                )?;
                let pass_duration = pass_start.elapsed();
                // A pass listed several times in the order accumulates all of its runs.
                *pass_times.entry(pass_type).or_default() += pass_duration;
                changed |= pass_changed;

                if self.config.verbose {
                    println!(
                        "Pass {:?} iteration {} completed in {:?}, changed: {}",
                        pass_type, iterations, pass_duration, pass_changed
                    );
                }
            }

            let current_node_count = graph.nodes.len();
            let relative_change = if prev_node_count > 0 {
                (prev_node_count as f64 - current_node_count as f64).abs() / prev_node_count as f64
            } else {
                0.0
            };
            // Node count is only a proxy for "how much changed": an iteration whose
            // rewrites leave the count unchanged is also treated as converged.
            if !changed || relative_change < self.config.convergence_threshold {
                converged = true;
                if self.config.verbose {
                    println!("Optimization converged after {} iterations", iterations);
                }
                break;
            }
            prev_node_count = current_node_count;
        }

        let total_time = start_time.elapsed();
        let nodes_after = graph.nodes.len();
        let improvement_score = if nodes_before > 0 {
            ((nodes_before as f64 - nodes_after as f64) / nodes_before as f64).clamp(0.0, 1.0)
        } else {
            0.0
        };

        Ok(PassResult {
            total_time,
            pass_times,
            iterations,
            converged,
            pattern_stats,
            elimination_stats,
            folding_stats,
            nodes_before,
            nodes_after,
            improvement_score,
        })
    }

    /// Runs a single pass once over `graph` and reports whether it changed anything.
    ///
    /// The `enable_*` flag is not consulted directly; a pass without an instance
    /// (i.e. disabled) yields `Ok(false)`. Pass statistics are discarded.
    ///
    /// # Errors
    /// Propagates any error returned by the pass.
    pub fn run_single_pass(
        &mut self,
        pass_type: PassType,
        graph: &mut ComputationGraph,
    ) -> TorshResult<bool> {
        let mut pattern_stats = None;
        let mut elimination_stats = None;
        let mut folding_stats = None;
        self.run_pass(
            pass_type,
            graph,
            &mut pattern_stats,
            &mut elimination_stats,
            &mut folding_stats,
        )
    }

    /// Pass order used by `run_all` when `custom_order` is `None`.
    fn default_pass_order() -> Vec<PassType> {
        vec![
            PassType::PatternOptimization,
            PassType::ConstantFolding,
            PassType::DeadCodeElimination,
        ]
    }

    /// Whether `pass_type` is enabled in the current configuration.
    fn is_pass_enabled(&self, pass_type: PassType) -> bool {
        match pass_type {
            PassType::PatternOptimization => self.config.enable_pattern_optimization,
            PassType::DeadCodeElimination => self.config.enable_dead_code_elimination,
            PassType::ConstantFolding => self.config.enable_constant_folding,
        }
    }

    /// Runs one pass and stores a clone of its statistics in the matching out-parameter.
    ///
    /// Returns `Ok(true)` if the pass reported any work (optimizations applied, nodes
    /// removed, or nodes folded), and `Ok(false)` if it did nothing or has no instance.
    fn run_pass(
        &mut self,
        pass_type: PassType,
        graph: &mut ComputationGraph,
        pattern_stats: &mut Option<OptimizationStatistics>,
        elimination_stats: &mut Option<EliminationStatistics>,
        folding_stats: &mut Option<FoldingStatistics>,
    ) -> TorshResult<bool> {
        match pass_type {
            PassType::PatternOptimization => {
                if let Some(pass) = self.pattern_pass.as_mut() {
                    let result = pass.optimize(graph)?;
                    *pattern_stats = Some(pass.get_statistics().clone());
                    Ok(result.optimizations_applied > 0)
                } else {
                    Ok(false)
                }
            }
            PassType::DeadCodeElimination => {
                if let Some(pass) = self.elimination_pass.as_mut() {
                    let result = pass.eliminate_dead_code(graph)?;
                    *elimination_stats = Some(pass.get_statistics().clone());
                    Ok(result.nodes_removed > 0)
                } else {
                    Ok(false)
                }
            }
            PassType::ConstantFolding => {
                if let Some(pass) = self.folding_pass.as_mut() {
                    let result = pass.fold_constants(graph)?;
                    *folding_stats = Some(pass.get_statistics().clone());
                    Ok(result.nodes_folded > 0)
                } else {
                    Ok(false)
                }
            }
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &PassConfig {
        &self.config
    }

    /// Replaces the configuration and brings the pass instances in line with it.
    ///
    /// Passes that stay enabled keep their existing instance, and therefore whatever
    /// state it holds. Passes that become disabled are dropped. Newly enabled passes
    /// start from a fresh instance.
    pub fn set_config(&mut self, config: PassConfig) {
        // The flags are read from `self.config` after the move. The previous version
        // read them from the moved-from `config`, which does not compile (E0382)
        // because `PassConfig` is not `Copy`.
        self.config = config;
        self.sync_passes_with_config();
    }

    /// Resets the statistics of every existing pass instance.
    pub fn reset_statistics(&mut self) {
        if let Some(pass) = self.pattern_pass.as_mut() {
            pass.reset_statistics();
        }
        if let Some(pass) = self.elimination_pass.as_mut() {
            pass.reset_statistics();
        }
        if let Some(pass) = self.folding_pass.as_mut() {
            pass.reset_statistics();
        }
    }

    /// Makes the pass instances match the `enable_*` flags of `self.config`:
    /// disabled passes are dropped, newly enabled ones get a fresh instance, and
    /// existing instances of still-enabled passes are left untouched.
    fn sync_passes_with_config(&mut self) {
        Self::sync_slot(
            &mut self.pattern_pass,
            self.config.enable_pattern_optimization,
            PatternOptimizationPass::new,
        );
        Self::sync_slot(
            &mut self.elimination_pass,
            self.config.enable_dead_code_elimination,
            DeadCodeEliminationPass::new,
        );
        Self::sync_slot(
            &mut self.folding_pass,
            self.config.enable_constant_folding,
            ConstantFoldingPass::new,
        );
    }

    /// Ensures `slot` holds an instance exactly when `enabled`, creating one with
    /// `make` only if none exists yet.
    fn sync_slot<P>(slot: &mut Option<P>, enabled: bool, make: impl FnOnce() -> P) {
        if enabled {
            slot.get_or_insert_with(make);
        } else {
            *slot = None;
        }
    }
}
impl Default for PassManager {
fn default() -> Self {
Self::new()
}
}
pub mod utils {
use super::*;
pub fn pattern_only() -> PassManager {
PassManager::with_config(PassConfig {
enable_pattern_optimization: true,
enable_dead_code_elimination: false,
enable_constant_folding: false,
..Default::default()
})
}
pub fn elimination_only() -> PassManager {
PassManager::with_config(PassConfig {
enable_pattern_optimization: false,
enable_dead_code_elimination: true,
enable_constant_folding: false,
..Default::default()
})
}
pub fn folding_only() -> PassManager {
PassManager::with_config(PassConfig {
enable_pattern_optimization: false,
enable_dead_code_elimination: false,
enable_constant_folding: true,
..Default::default()
})
}
pub fn fast_compile() -> PassManager {
PassManager::with_config(PassConfig {
enable_pattern_optimization: true,
enable_dead_code_elimination: true,
enable_constant_folding: false, max_iterations: 3,
convergence_threshold: 1e-3,
verbose: false,
custom_order: Some(vec![
PassType::DeadCodeElimination,
PassType::PatternOptimization,
]),
})
}
pub fn max_optimization() -> PassManager {
PassManager::with_config(PassConfig {
enable_pattern_optimization: true,
enable_dead_code_elimination: true,
enable_constant_folding: true,
max_iterations: 20,
convergence_threshold: 1e-8,
verbose: false,
custom_order: Some(vec![
PassType::ConstantFolding,
PassType::PatternOptimization,
PassType::DeadCodeElimination,
PassType::PatternOptimization, ]),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern_matching::graph::*;

    /// A default manager has every pass switched on.
    #[test]
    fn test_pass_manager_creation() {
        let manager = PassManager::new();
        let cfg = manager.config();
        assert!(
            cfg.enable_pattern_optimization
                && cfg.enable_dead_code_elimination
                && cfg.enable_constant_folding
        );
    }

    /// `PassConfig` survives a JSON round trip.
    #[test]
    fn test_pass_config_serialization() {
        let original = PassConfig::default();
        let json = serde_json::to_string(&original).unwrap();
        let restored: PassConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(
            original.enable_pattern_optimization,
            restored.enable_pattern_optimization
        );
        assert_eq!(original.max_iterations, restored.max_iterations);
    }

    /// The preset constructors in `utils` apply their advertised settings.
    #[test]
    fn test_utility_managers() {
        let pattern = utils::pattern_only();
        let pattern_cfg = pattern.config();
        assert!(pattern_cfg.enable_pattern_optimization);
        assert!(!pattern_cfg.enable_dead_code_elimination);
        assert!(!pattern_cfg.enable_constant_folding);

        let fast = utils::fast_compile();
        let fast_cfg = fast.config();
        assert_eq!(fast_cfg.max_iterations, 3);
        assert!(!fast_cfg.enable_constant_folding);

        let max = utils::max_optimization();
        let max_cfg = max.config();
        assert_eq!(max_cfg.max_iterations, 20);
        assert!(max_cfg.enable_constant_folding);
    }

    /// Running every pass over an empty graph converges and leaves it empty.
    #[test]
    fn test_pass_manager_with_empty_graph() {
        let mut manager = PassManager::new();
        let mut graph = ComputationGraph::new();
        let outcome = manager.run_all(&mut graph).unwrap();
        assert_eq!((outcome.nodes_before, outcome.nodes_after), (0, 0));
        assert!(outcome.converged);
    }
}