use crate::error::{OptimizationError, Result};
use crate::passes::*;
use ronn_core::ModelGraph;
use std::collections::HashMap;
use tracing::{debug, info, warn};
6
/// How aggressively [`Optimizer`] rewrites a model graph.
///
/// Each level registers every pass of the level below it, plus additional
/// passes, in a fixed order (see [`Optimizer::new`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptimizationLevel {
    /// No passes are registered; the graph is left untouched.
    O0,
    /// Constant folding followed by dead-code elimination.
    O1,
    /// Everything in `O1`, then node fusion and layout optimization.
    O2,
    /// Everything in `O2`, then `CpuOptimizationPass` and `GpuOptimizationPass`.
    O3,
}
19
20pub struct Optimizer {
22 pass_manager: PassManager,
23 level: OptimizationLevel,
24}
25
26impl Optimizer {
27 pub fn new(level: OptimizationLevel) -> Self {
29 let mut pass_manager = PassManager::new();
30
31 match level {
33 OptimizationLevel::O0 => {
34 }
36 OptimizationLevel::O1 => {
37 pass_manager.add_pass(Box::new(ConstantFoldingPass));
38 pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
39 }
40 OptimizationLevel::O2 => {
41 pass_manager.add_pass(Box::new(ConstantFoldingPass));
42 pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
43 pass_manager.add_pass(Box::new(NodeFusionPass));
44 pass_manager.add_pass(Box::new(LayoutOptimizationPass));
45 }
46 OptimizationLevel::O3 => {
47 pass_manager.add_pass(Box::new(ConstantFoldingPass));
48 pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
49 pass_manager.add_pass(Box::new(NodeFusionPass));
50 pass_manager.add_pass(Box::new(LayoutOptimizationPass));
51 pass_manager.add_pass(Box::new(CpuOptimizationPass));
52 pass_manager.add_pass(Box::new(GpuOptimizationPass));
53 }
54 }
55
56 Self {
57 pass_manager,
58 level,
59 }
60 }
61
62 pub fn optimize(&self, graph: &mut ModelGraph) -> Result<OptimizationStats> {
64 info!("Starting optimization with level {:?}", self.level);
65 self.pass_manager.run(graph)
66 }
67
68 pub fn pass_count(&self) -> usize {
70 self.pass_manager.pass_count()
71 }
72
73 pub fn level(&self) -> OptimizationLevel {
75 self.level
76 }
77}
78
79pub struct PassManager {
81 passes: Vec<Box<dyn OptimizationPass>>,
82}
83
84impl PassManager {
85 pub fn new() -> Self {
87 Self { passes: Vec::new() }
88 }
89
90 pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
92 self.passes.push(pass);
93 }
94
95 pub fn run(&self, graph: &mut ModelGraph) -> Result<OptimizationStats> {
97 let mut stats = OptimizationStats::new();
98 let mut modified = true;
99 let mut iteration = 0;
100 const MAX_ITERATIONS: usize = 10;
101
102 while modified && iteration < MAX_ITERATIONS {
104 modified = false;
105 iteration += 1;
106
107 debug!("Optimization iteration {}", iteration);
108
109 for pass in &self.passes {
110 let pass_name = pass.name();
111 debug!("Running pass: {}", pass_name);
112
113 let result = pass.run(graph).map_err(|e| OptimizationError::PassFailed {
114 pass_name: pass_name.to_string(),
115 reason: e.to_string(),
116 })?;
117
118 if result.nodes_removed > 0 || result.nodes_fused > 0 || result.nodes_modified > 0 {
119 modified = true;
120 stats.merge(result.clone());
121 }
122
123 info!(
124 "Pass {} completed: {} nodes removed, {} fused, {} modified",
125 pass_name, result.nodes_removed, result.nodes_fused, result.nodes_modified
126 );
127 }
128 }
129
130 stats.iterations = iteration;
131 info!(
132 "Optimization completed after {} iterations: {} total changes",
133 stats.iterations,
134 stats.total_changes()
135 );
136
137 Ok(stats)
138 }
139
140 pub fn pass_count(&self) -> usize {
142 self.passes.len()
143 }
144}
145
146impl Default for PassManager {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct OptimizationStats {
155 pub nodes_removed: usize,
156 pub nodes_fused: usize,
157 pub nodes_modified: usize,
158 pub iterations: usize,
159 pub pass_stats: HashMap<String, PassStats>,
160}
161
162impl OptimizationStats {
163 pub fn new() -> Self {
164 Self::default()
165 }
166
167 pub fn merge(&mut self, other: PassStats) {
168 self.nodes_removed += other.nodes_removed;
169 self.nodes_fused += other.nodes_fused;
170 self.nodes_modified += other.nodes_modified;
171 }
172
173 pub fn total_changes(&self) -> usize {
174 self.nodes_removed + self.nodes_fused + self.nodes_modified
175 }
176}
177
/// Node-change counters reported by a single pass run.
#[derive(Clone, Debug, Default)]
pub struct PassStats {
    /// Nodes deleted from the graph.
    pub nodes_removed: usize,
    /// Nodes combined into fused nodes.
    pub nodes_fused: usize,
    /// Nodes rewritten in place.
    pub nodes_modified: usize,
}