use crate::pattern_matching::graph::{ComputationGraph, GraphNode};
use crate::pattern_matching::matcher::{MatchingConfig, PatternMatch, PatternMatcher};
use crate::pattern_matching::patterns::{CommonPatterns, PatternCollection};
use crate::{QuantConfig, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use torsh_core::TorshError;
/// Settings that control how a [`PatternOptimizationPass`] rewrites a graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
/// Read by `optimize` after each iteration: when `false`, the loop exits as
/// soon as an iteration applied zero optimizations.
pub aggressive: bool,
/// Upper bound on the number of match-and-rewrite iterations per `optimize` call.
pub max_iterations: usize,
/// When `false`, matches whose pattern name is a known fusion pattern are skipped.
pub enable_fusion: bool,
/// When `false`, matches whose pattern name is a known elimination pattern are skipped.
pub enable_elimination: bool,
/// When `true`, fused nodes are annotated with `original_nodes` and
/// `fusion_confidence` attributes.
pub preserve_debug_info: bool,
/// Per-pattern-name priority used when ordering matches: higher values are
/// considered first, and names without an entry are treated as priority 0.
pub pattern_priorities: HashMap<String, i32>,
}
impl Default for OptimizationConfig {
fn default() -> Self {
Self {
aggressive: false,
max_iterations: 5,
enable_fusion: true,
enable_elimination: true,
preserve_debug_info: false,
pattern_priorities: HashMap::new(),
}
}
}
/// A graph optimization pass driven by pattern matches.
///
/// It owns the matcher that finds candidate patterns, the configuration that
/// decides which of them are applied, and counters describing what was done.
#[derive(Debug)]
pub struct PatternOptimizationPass {
// Finds pattern matches in the graph handed to `optimize`.
matcher: PatternMatcher,
// Controls iteration limits and which pattern families are enabled.
config: OptimizationConfig,
// Counters accumulated across `optimize` calls until `reset_statistics`.
stats: OptimizationStatistics,
}
impl PatternOptimizationPass {
pub fn new() -> Self {
Self {
matcher: PatternMatcher::new(),
config: OptimizationConfig::default(),
stats: OptimizationStatistics::default(),
}
}
/// Creates a pass that uses `config` together with a default matcher and
/// zeroed statistics.
///
/// This is an associated constructor (it takes no `self`), so it always
/// builds a fresh pass rather than reconfiguring an existing one.
pub fn with_config(config: OptimizationConfig) -> Self {
    let matcher = PatternMatcher::new();
    let stats = OptimizationStatistics::default();
    Self {
        matcher,
        config,
        stats,
    }
}
/// Creates a pass whose matcher is built from `patterns`, using the default
/// configuration and zeroed statistics.
pub fn with_patterns(patterns: PatternCollection) -> Self {
    let matcher = PatternMatcher::from_collection(patterns);
    let config = OptimizationConfig::default();
    let stats = OptimizationStatistics::default();
    Self {
        matcher,
        config,
        stats,
    }
}
/// Overwrites the `aggressive` flag of this pass's configuration.
pub fn set_aggressive(&mut self, aggressive: bool) {
self.config.aggressive = aggressive;
}
/// Returns a shared reference to the matcher used by this pass.
pub fn matcher(&self) -> &PatternMatcher {
&self.matcher
}
/// Returns a mutable reference to the matcher used by this pass, so callers
/// can adjust it in place.
pub fn matcher_mut(&mut self) -> &mut PatternMatcher {
&mut self.matcher
}
/// Returns the counters accumulated since construction or the last
/// `reset_statistics` call.
pub fn get_statistics(&self) -> &OptimizationStatistics {
&self.stats
}
pub fn reset_statistics(&mut self) {
self.stats = OptimizationStatistics::default();
}
/// Runs the pass over `graph`, rewriting it in place.
///
/// Each iteration asks the matcher for matches, keeps a non-overlapping
/// subset chosen by `select_optimization_matches`, and dispatches every
/// selected match to its handler. The loop ends when `max_iterations` is
/// reached, when no match is found or selected, or (unless `aggressive` is
/// set) after an iteration in which no handler rewrote the graph.
///
/// A match only counts as applied when its handler performed a rewrite.
/// This is detected through the `fusions_applied` / `eliminations_applied`
/// counters, which the handlers bump on success. Matches with no handler,
/// or whose handler returned without changing anything, are therefore not
/// reported in `optimizations_applied`, `patterns_applied` or
/// `pattern_counts`.
///
/// # Errors
///
/// Propagates any error from the matcher or from a pattern handler. The
/// graph may already have been partially rewritten when that happens.
pub fn optimize(&mut self, graph: &mut ComputationGraph) -> TorshResult<OptimizationResult> {
    let initial_node_count = graph.nodes.len();
    let mut total_optimizations = 0;
    let mut iteration = 0;
    self.stats.optimization_runs += 1;

    while iteration < self.config.max_iterations {
        let matches = self.matcher.find_matches(graph)?;
        if matches.is_empty() {
            break;
        }
        let selected_matches = self.select_optimization_matches(matches);
        if selected_matches.is_empty() {
            break;
        }

        let mut iteration_optimizations = 0;
        for pattern_match in selected_matches {
            let rewrites_before = self.stats.fusions_applied + self.stats.eliminations_applied;
            self.apply_pattern_optimization(graph, &pattern_match)?;
            let rewrites_after = self.stats.fusions_applied + self.stats.eliminations_applied;

            // The handler left the graph alone (unknown pattern or nothing to
            // fuse), so do not report the match as an applied optimization.
            if rewrites_after == rewrites_before {
                continue;
            }

            iteration_optimizations += 1;
            self.stats.patterns_applied += 1;
            *self
                .stats
                .pattern_counts
                .entry(pattern_match.pattern_name.clone())
                .or_insert(0) += 1;
        }

        total_optimizations += iteration_optimizations;
        iteration += 1;

        // Without progress the graph is unchanged; in non-aggressive mode
        // there is no point in matching it again.
        if !self.config.aggressive && iteration_optimizations == 0 {
            break;
        }
    }

    let final_node_count = graph.nodes.len();
    // Fusion can add nodes, so guard against the count having grown.
    let nodes_eliminated = initial_node_count.saturating_sub(final_node_count);
    self.stats.nodes_eliminated += nodes_eliminated;

    Ok(OptimizationResult {
        optimizations_applied: total_optimizations,
        nodes_eliminated,
        iterations: iteration,
        final_node_count,
        success: true,
        details: self.create_optimization_details(),
    })
}
/// Chooses which of `matches` to apply in the current iteration.
///
/// Matches rejected by `should_apply_pattern` are dropped. The rest are
/// ordered by configured priority (highest first, missing entries count as
/// 0) and then by confidence (highest first; incomparable confidences are
/// treated as equal). They are then accepted greedily in that order, and a
/// match is skipped if it shares any node with one already accepted.
fn select_optimization_matches(&self, matches: Vec<PatternMatch>) -> Vec<PatternMatch> {
    let priority_of = |m: &PatternMatch| -> i32 {
        self.config
            .pattern_priorities
            .get(&m.pattern_name)
            .copied()
            .unwrap_or(0)
    };

    let mut candidates: Vec<PatternMatch> = matches
        .into_iter()
        .filter(|m| self.should_apply_pattern(m))
        .collect();

    candidates.sort_by(|a, b| {
        priority_of(b).cmp(&priority_of(a)).then_with(|| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });

    let mut claimed = HashSet::new();
    let mut accepted = Vec::new();
    for candidate in candidates {
        let is_free = candidate
            .matched_node_ids
            .iter()
            .all(|id| !claimed.contains(id));
        if !is_free {
            continue;
        }
        for id in &candidate.matched_node_ids {
            claimed.insert(id.clone());
        }
        accepted.push(candidate);
    }
    accepted
}
/// Returns `false` when the match belongs to a pattern family (fusion or
/// elimination) that the configuration has disabled, and `true` otherwise.
fn should_apply_pattern(&self, pattern_match: &PatternMatch) -> bool {
    let name = pattern_match.pattern_name.as_str();
    let fusion_blocked = !self.config.enable_fusion && self.is_fusion_pattern(name);
    let elimination_blocked =
        !self.config.enable_elimination && self.is_elimination_pattern(name);
    !(fusion_blocked || elimination_blocked)
}
/// Reports whether `pattern_name` is one of the pattern names this pass
/// treats as an operator fusion.
fn is_fusion_pattern(&self, pattern_name: &str) -> bool {
    const FUSION_PATTERNS: &[&str] = &[
        "conv_bn",
        "conv_bn_relu",
        "conv_relu",
        "linear_relu",
        "add_relu",
        "mul_add",
        "matmul_add",
    ];
    FUSION_PATTERNS.contains(&pattern_name)
}
/// Reports whether `pattern_name` is one of the pattern names this pass
/// treats as a redundant-operation elimination.
fn is_elimination_pattern(&self, pattern_name: &str) -> bool {
    const ELIMINATION_PATTERNS: &[&str] = &[
        "quant_dequant",
        "quant_dequant_elimination",
        "transpose_transpose",
        "squeeze_unsqueeze",
        "reshape_reshape",
    ];
    ELIMINATION_PATTERNS.contains(&pattern_name)
}
/// Dispatches `pattern_match` to the handler registered for its pattern name.
///
/// Names without a handler are not an error: they bump the
/// `unknown_patterns` counter and the graph is left untouched.
fn apply_pattern_optimization(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    let name = pattern_match.pattern_name.as_str();

    if matches!(name, "quant_dequant" | "quant_dequant_elimination") {
        return self.apply_quant_dequant_elimination(graph, pattern_match);
    }
    if name == "transpose_transpose" {
        return self.apply_transpose_elimination(graph, pattern_match);
    }
    if name == "squeeze_unsqueeze" {
        return self.apply_squeeze_unsqueeze_elimination(graph, pattern_match);
    }
    if name == "reshape_reshape" {
        return self.apply_reshape_elimination(graph, pattern_match);
    }
    if matches!(
        name,
        "conv_bn"
            | "conv_bn_relu"
            | "conv_relu"
            | "linear_relu"
            | "add_relu"
            | "mul_add"
            | "matmul_add"
    ) {
        return self.apply_fusion_optimization(graph, pattern_match);
    }

    // No handler is registered for this pattern name.
    self.stats.unknown_patterns += 1;
    Ok(())
}
/// Removes a matched quantize/dequantize pair and wires every input of the
/// quantize node to every output of the dequantize node.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the match does not contain
/// exactly two node ids or if either node is missing from `graph`, and
/// propagates any error from `connect_nodes`.
fn apply_quant_dequant_elimination(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    let ids = &pattern_match.matched_node_ids;
    if ids.len() != 2 {
        return Err(TorshError::InvalidArgument(
            "Quant-dequant elimination requires exactly 2 nodes".to_string(),
        ));
    }
    let (quant_node_id, dequant_node_id) = (&ids[0], &ids[1]);

    let producers = match graph.get_node(quant_node_id) {
        Some(node) => node.inputs.clone(),
        None => {
            return Err(TorshError::InvalidArgument(
                "Quantize node not found".to_string(),
            ))
        }
    };
    let consumers = match graph.get_node(dequant_node_id) {
        Some(node) => node.outputs.clone(),
        None => {
            return Err(TorshError::InvalidArgument(
                "Dequantize node not found".to_string(),
            ))
        }
    };

    // Bypass the pair before deleting it.
    for producer in &producers {
        for consumer in &consumers {
            graph.connect_nodes(producer, consumer)?;
        }
    }

    graph.remove_node(quant_node_id);
    graph.remove_node(dequant_node_id);
    self.stats.eliminations_applied += 1;
    Ok(())
}
/// Removes a matched pair of consecutive transpose nodes and wires every
/// input of the first to every output of the second.
///
/// NOTE(review): both nodes are dropped unconditionally, which is only an
/// identity if the two permutations cancel. Presumably the matcher
/// guarantees this; confirm against the pattern definition.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the match does not contain
/// exactly two node ids or if either node is missing from `graph`, and
/// propagates any error from `connect_nodes`.
fn apply_transpose_elimination(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    if pattern_match.matched_node_ids.len() != 2 {
        return Err(TorshError::InvalidArgument(
            "Transpose elimination requires exactly 2 nodes".to_string(),
        ));
    }
    let first_transpose = &pattern_match.matched_node_ids[0];
    let second_transpose = &pattern_match.matched_node_ids[1];

    // A stale match (node already removed) is reported as an error instead
    // of panicking, matching the quant/dequant handler.
    let inputs = graph
        .get_node(first_transpose)
        .ok_or_else(|| {
            TorshError::InvalidArgument("First transpose node not found".to_string())
        })?
        .inputs
        .clone();
    let outputs = graph
        .get_node(second_transpose)
        .ok_or_else(|| {
            TorshError::InvalidArgument("Second transpose node not found".to_string())
        })?
        .outputs
        .clone();

    for input_id in &inputs {
        for output_id in &outputs {
            graph.connect_nodes(input_id, output_id)?;
        }
    }
    graph.remove_node(first_transpose);
    graph.remove_node(second_transpose);
    self.stats.eliminations_applied += 1;
    Ok(())
}
/// Removes a matched squeeze/unsqueeze pair and wires every input of the
/// squeeze node to every output of the unsqueeze node.
///
/// NOTE(review): both nodes are dropped unconditionally, which is only an
/// identity if the unsqueeze restores the axis the squeeze removed.
/// Presumably the matcher guarantees this; confirm against the pattern
/// definition.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the match does not contain
/// exactly two node ids or if either node is missing from `graph`, and
/// propagates any error from `connect_nodes`.
fn apply_squeeze_unsqueeze_elimination(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    if pattern_match.matched_node_ids.len() != 2 {
        return Err(TorshError::InvalidArgument(
            "Squeeze-unsqueeze elimination requires exactly 2 nodes".to_string(),
        ));
    }
    let squeeze_node = &pattern_match.matched_node_ids[0];
    let unsqueeze_node = &pattern_match.matched_node_ids[1];

    // A stale match (node already removed) is reported as an error instead
    // of panicking, matching the quant/dequant handler.
    let inputs = graph
        .get_node(squeeze_node)
        .ok_or_else(|| TorshError::InvalidArgument("Squeeze node not found".to_string()))?
        .inputs
        .clone();
    let outputs = graph
        .get_node(unsqueeze_node)
        .ok_or_else(|| TorshError::InvalidArgument("Unsqueeze node not found".to_string()))?
        .outputs
        .clone();

    for input_id in &inputs {
        for output_id in &outputs {
            graph.connect_nodes(input_id, output_id)?;
        }
    }
    graph.remove_node(squeeze_node);
    graph.remove_node(unsqueeze_node);
    self.stats.eliminations_applied += 1;
    Ok(())
}
/// Removes a matched pair of consecutive reshape nodes and wires every input
/// of the first to every output of the second.
///
/// NOTE(review): both nodes are dropped unconditionally, which is only an
/// identity if the second reshape restores the original shape. Presumably
/// the matcher guarantees this; confirm against the pattern definition.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the match does not contain
/// exactly two node ids or if either node is missing from `graph`, and
/// propagates any error from `connect_nodes`.
fn apply_reshape_elimination(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    if pattern_match.matched_node_ids.len() != 2 {
        return Err(TorshError::InvalidArgument(
            "Reshape elimination requires exactly 2 nodes".to_string(),
        ));
    }
    let first_reshape = &pattern_match.matched_node_ids[0];
    let second_reshape = &pattern_match.matched_node_ids[1];

    // A stale match (node already removed) is reported as an error instead
    // of panicking, matching the quant/dequant handler.
    let inputs = graph
        .get_node(first_reshape)
        .ok_or_else(|| TorshError::InvalidArgument("First reshape node not found".to_string()))?
        .inputs
        .clone();
    let outputs = graph
        .get_node(second_reshape)
        .ok_or_else(|| {
            TorshError::InvalidArgument("Second reshape node not found".to_string())
        })?
        .outputs
        .clone();

    for input_id in &inputs {
        for output_id in &outputs {
            graph.connect_nodes(input_id, output_id)?;
        }
    }
    graph.remove_node(first_reshape);
    graph.remove_node(second_reshape);
    self.stats.eliminations_applied += 1;
    Ok(())
}
/// Replaces the matched nodes with a single `fused_<pattern>` node.
///
/// The fused node takes the inputs of the first matched node and the outputs
/// of the last matched node. If the match carries a quantization config, its
/// scheme and dtype are recorded as attributes. With `preserve_debug_info`
/// set, the original `op_type:id` list and the match confidence are recorded
/// as well. A match with no node ids is a no-op.
///
/// The fused node id is `fused_<pattern>_<n>`, where `n` starts at the
/// current `fusions_applied` counter and is bumped until the id is unused in
/// `graph`. The counter alone is not unique after `reset_statistics` or when
/// a different pass instance has already fused nodes in the same graph.
///
/// NOTE(review): inputs and outputs of nodes in the middle of the match are
/// not carried over to the fused node. Confirm that is intended for patterns
/// whose inner nodes have extra operands.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if the first or last matched node
/// is missing from `graph`, and propagates any error from `connect_nodes`.
fn apply_fusion_optimization(
    &mut self,
    graph: &mut ComputationGraph,
    pattern_match: &PatternMatch,
) -> TorshResult<()> {
    let (first_node_id, last_node_id) = match (
        pattern_match.matched_node_ids.first(),
        pattern_match.matched_node_ids.last(),
    ) {
        (Some(first), Some(last)) => (first, last),
        // Nothing matched, so there is nothing to fuse.
        _ => return Ok(()),
    };

    let first_inputs = graph
        .get_node(first_node_id)
        .ok_or_else(|| TorshError::InvalidArgument("First node not found".to_string()))?
        .inputs
        .clone();
    let last_outputs = graph
        .get_node(last_node_id)
        .ok_or_else(|| TorshError::InvalidArgument("Last node not found".to_string()))?
        .outputs
        .clone();

    // Probe for an id that no node in the graph is using yet.
    let mut suffix = self.stats.fusions_applied;
    let mut fused_node_id = format!("fused_{}_{}", pattern_match.pattern_name, suffix);
    while graph.get_node(&fused_node_id).is_some() {
        suffix += 1;
        fused_node_id = format!("fused_{}_{}", pattern_match.pattern_name, suffix);
    }

    let mut fused_node = GraphNode::new(
        fused_node_id.clone(),
        format!("fused_{}", pattern_match.pattern_name),
    );
    for input_id in &first_inputs {
        fused_node.add_input(input_id.clone());
    }
    for output_id in &last_outputs {
        fused_node.add_output(output_id.clone());
    }

    if let Some(qconfig) = &pattern_match.qconfig {
        fused_node.set_attribute(
            "quantization_scheme".to_string(),
            format!("{:?}", qconfig.scheme),
        );
        fused_node.set_attribute(
            "quantization_dtype".to_string(),
            format!("{:?}", qconfig.dtype),
        );
    }

    if self.config.preserve_debug_info {
        // Must run before the originals are removed so their op types are
        // still available.
        let original_nodes: Vec<String> = pattern_match
            .matched_node_ids
            .iter()
            .map(|id| {
                let op_type = graph
                    .get_node(id)
                    .map(|n| n.op_type.as_str())
                    .unwrap_or("unknown");
                format!("{}:{}", op_type, id)
            })
            .collect();
        fused_node.set_attribute("original_nodes".to_string(), original_nodes.join(","));
        fused_node.set_attribute(
            "fusion_confidence".to_string(),
            pattern_match.confidence.to_string(),
        );
    }

    for node_id in &pattern_match.matched_node_ids {
        graph.remove_node(node_id);
    }
    graph.add_node(fused_node);

    for input_id in &first_inputs {
        graph.connect_nodes(input_id, &fused_node_id)?;
    }
    for output_id in &last_outputs {
        graph.connect_nodes(&fused_node_id, output_id)?;
    }

    self.stats.fusions_applied += 1;
    Ok(())
}
/// Renders the current statistics as a string-to-string map.
///
/// The map holds the four aggregate counters under fixed keys plus one
/// `pattern_<name>` entry per pattern recorded in `pattern_counts`.
fn create_optimization_details(&self) -> HashMap<String, String> {
    let aggregate_counters = [
        ("fusions_applied", self.stats.fusions_applied),
        ("eliminations_applied", self.stats.eliminations_applied),
        ("patterns_applied", self.stats.patterns_applied),
        ("unknown_patterns", self.stats.unknown_patterns),
    ];

    let mut details: HashMap<String, String> = aggregate_counters
        .iter()
        .map(|(key, value)| ((*key).to_string(), value.to_string()))
        .collect();

    details.extend(
        self.stats
            .pattern_counts
            .iter()
            .map(|(pattern, count)| (format!("pattern_{}", pattern), count.to_string())),
    );
    details
}
}
// `Default` simply delegates to `PatternOptimizationPass::new`.
impl Default for PatternOptimizationPass {
/// Equivalent to [`PatternOptimizationPass::new`].
fn default() -> Self {
Self::new()
}
}
/// Counters a [`PatternOptimizationPass`] accumulates across `optimize`
/// calls until `reset_statistics` is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OptimizationStatistics {
/// Number of `optimize` calls started.
pub optimization_runs: usize,
/// Number of pattern matches `optimize` has counted as applied.
pub patterns_applied: usize,
/// Number of fused nodes created by the fusion handler.
pub fusions_applied: usize,
/// Number of node pairs removed by the elimination handlers.
pub eliminations_applied: usize,
/// Sum over runs of how much the graph's node count shrank (never negative per run).
pub nodes_eliminated: usize,
/// Number of matches whose pattern name had no handler.
pub unknown_patterns: usize,
/// Per-pattern-name tally of matches counted as applied.
pub pattern_counts: HashMap<String, usize>,
}
/// Summary returned by a single `PatternOptimizationPass::optimize` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
/// Total number of optimizations counted across all iterations of this call.
pub optimizations_applied: usize,
/// Initial node count minus final node count, saturating at zero.
pub nodes_eliminated: usize,
/// Number of loop iterations that ran to completion.
pub iterations: usize,
/// Number of nodes in the graph when the call finished.
pub final_node_count: usize,
/// Set to `true` by `optimize` whenever it returns a result; failures are
/// reported through the `Err` variant instead.
pub success: bool,
/// String rendering of the pass's accumulated statistics at the end of the call.
pub details: HashMap<String, String>,
}
/// Builds a pass restricted to fusion: the matcher is loaded with
/// `PatternCollection::fusion_only()`, elimination is disabled, the mode is
/// non-aggressive and at most three iterations run.
pub fn create_fusion_pass() -> PatternOptimizationPass {
    let config = OptimizationConfig {
        aggressive: false,
        max_iterations: 3,
        enable_fusion: true,
        enable_elimination: false,
        preserve_debug_info: false,
        pattern_priorities: HashMap::new(),
    };
    // `with_config` is an associated constructor without a `self` receiver,
    // so it cannot be chained onto the pass built by `with_patterns`. Assign
    // the configuration directly so the fusion-only matcher is kept.
    let mut pass = PatternOptimizationPass::with_patterns(PatternCollection::fusion_only());
    pass.config = config;
    pass
}
/// Builds a pass restricted to elimination: the matcher is loaded with
/// `PatternCollection::elimination_only()`, fusion is disabled, the mode is
/// aggressive and at most ten iterations run.
pub fn create_elimination_pass() -> PatternOptimizationPass {
    let config = OptimizationConfig {
        aggressive: true,
        max_iterations: 10,
        enable_fusion: false,
        enable_elimination: true,
        preserve_debug_info: false,
        pattern_priorities: HashMap::new(),
    };
    // `with_config` is an associated constructor without a `self` receiver,
    // so it cannot be chained onto the pass built by `with_patterns`. Assign
    // the configuration directly so the elimination-only matcher is kept.
    let mut pass =
        PatternOptimizationPass::with_patterns(PatternCollection::elimination_only());
    pass.config = config;
    pass
}
/// Builds a pass with the default matcher, both fusion and elimination
/// enabled, aggressive mode on and an iteration cap of fifteen.
pub fn create_aggressive_pass() -> PatternOptimizationPass {
    PatternOptimizationPass::with_config(OptimizationConfig {
        max_iterations: 15,
        aggressive: true,
        enable_fusion: true,
        enable_elimination: true,
        preserve_debug_info: false,
        pattern_priorities: HashMap::new(),
    })
}
// Unit tests for pass construction, configuration, pattern classification
// and statistics bookkeeping.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): `create_branching_graph` is imported but not used by any
// test in this module.
use crate::pattern_matching::graph::{create_branching_graph, create_linear_graph};
// The default pass is non-aggressive with 5 iterations; the aggressive
// factory flips the flag and raises the cap to 15.
#[test]
fn test_optimization_pass_creation() {
let pass = PatternOptimizationPass::new();
assert!(!pass.config.aggressive);
assert_eq!(pass.config.max_iterations, 5);
let aggressive_pass = create_aggressive_pass();
assert!(aggressive_pass.config.aggressive);
assert_eq!(aggressive_pass.config.max_iterations, 15);
}
// Optimizing a conv2d -> relu chain succeeds and never grows the graph.
#[test]
fn test_fusion_optimization() {
let mut pass = PatternOptimizationPass::new();
let mut graph = create_linear_graph(&["conv2d", "relu"]);
let initial_count = graph.nodes.len();
let result = pass.optimize(&mut graph).unwrap();
assert!(result.success);
assert!(graph.nodes.len() <= initial_count);
}
// Fusion and elimination names are classified into disjoint families.
#[test]
fn test_pattern_selection() {
let pass = PatternOptimizationPass::new();
assert!(pass.is_fusion_pattern("conv_relu"));
assert!(pass.is_fusion_pattern("linear_relu"));
assert!(!pass.is_fusion_pattern("quant_dequant"));
assert!(pass.is_elimination_pattern("quant_dequant"));
assert!(pass.is_elimination_pattern("transpose_transpose"));
assert!(!pass.is_elimination_pattern("conv_relu"));
}
// A custom configuration is stored unchanged by `with_config`.
#[test]
fn test_optimization_config() {
let config = OptimizationConfig {
aggressive: true,
enable_fusion: false,
enable_elimination: true,
..Default::default()
};
let pass = PatternOptimizationPass::with_config(config);
assert!(pass.config.aggressive);
assert!(!pass.config.enable_fusion);
assert!(pass.config.enable_elimination);
}
// Each specialized factory enables exactly one pattern family.
#[test]
fn test_specialized_passes() {
let fusion_pass = create_fusion_pass();
assert!(fusion_pass.config.enable_fusion);
assert!(!fusion_pass.config.enable_elimination);
let elimination_pass = create_elimination_pass();
assert!(!elimination_pass.config.enable_fusion);
assert!(elimination_pass.config.enable_elimination);
}
// Counters start at zero and `optimization_runs` increases after a run.
#[test]
fn test_statistics_tracking() {
let mut pass = PatternOptimizationPass::new();
let stats = pass.get_statistics();
assert_eq!(stats.patterns_applied, 0);
assert_eq!(stats.fusions_applied, 0);
assert_eq!(stats.eliminations_applied, 0);
let mut graph = create_linear_graph(&["conv2d", "relu"]);
let _result = pass.optimize(&mut graph).unwrap();
let updated_stats = pass.get_statistics();
assert!(updated_stats.optimization_runs > 0);
}
}