// ronn_graph/optimizer.rs

1use crate::error::{OptimizationError, Result};
2use crate::passes::*;
3use ronn_core::ModelGraph;
4use std::collections::HashMap;
5use tracing::{debug, info};
6
/// Optimization levels similar to compiler optimization levels.
///
/// Each level enables everything the previous level does plus more. The
/// variants are declared in increasing order of aggressiveness, so the derived
/// `PartialOrd`/`Ord` make comparisons like `level >= OptimizationLevel::O2`
/// meaningful.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OptimizationLevel {
    /// No optimizations
    O0,
    /// Basic optimizations (constant folding, dead code elimination)
    O1,
    /// Standard optimizations (O1 + node fusion, layout optimization)
    O2,
    /// Aggressive optimizations (O2 + provider-specific passes)
    O3,
}
19
20/// Main optimizer that applies optimization passes to a graph
21pub struct Optimizer {
22    pass_manager: PassManager,
23    level: OptimizationLevel,
24}
25
26impl Optimizer {
27    /// Create a new optimizer with the specified optimization level
28    pub fn new(level: OptimizationLevel) -> Self {
29        let mut pass_manager = PassManager::new();
30
31        // Register passes based on optimization level
32        match level {
33            OptimizationLevel::O0 => {
34                // No optimizations
35            }
36            OptimizationLevel::O1 => {
37                pass_manager.add_pass(Box::new(ConstantFoldingPass));
38                pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
39            }
40            OptimizationLevel::O2 => {
41                pass_manager.add_pass(Box::new(ConstantFoldingPass));
42                pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
43                pass_manager.add_pass(Box::new(NodeFusionPass));
44                pass_manager.add_pass(Box::new(LayoutOptimizationPass));
45            }
46            OptimizationLevel::O3 => {
47                pass_manager.add_pass(Box::new(ConstantFoldingPass));
48                pass_manager.add_pass(Box::new(DeadCodeEliminationPass));
49                pass_manager.add_pass(Box::new(NodeFusionPass));
50                pass_manager.add_pass(Box::new(LayoutOptimizationPass));
51                pass_manager.add_pass(Box::new(CpuOptimizationPass));
52                pass_manager.add_pass(Box::new(GpuOptimizationPass));
53            }
54        }
55
56        Self {
57            pass_manager,
58            level,
59        }
60    }
61
62    /// Optimize a model graph
63    pub fn optimize(&self, graph: &mut ModelGraph) -> Result<OptimizationStats> {
64        info!("Starting optimization with level {:?}", self.level);
65        self.pass_manager.run(graph)
66    }
67
68    /// Get the number of registered passes
69    pub fn pass_count(&self) -> usize {
70        self.pass_manager.pass_count()
71    }
72
73    /// Get the optimization level
74    pub fn level(&self) -> OptimizationLevel {
75        self.level
76    }
77}
78
79/// Manages and executes optimization passes
80pub struct PassManager {
81    passes: Vec<Box<dyn OptimizationPass>>,
82}
83
84impl PassManager {
85    /// Create a new pass manager
86    pub fn new() -> Self {
87        Self { passes: Vec::new() }
88    }
89
90    /// Add an optimization pass
91    pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
92        self.passes.push(pass);
93    }
94
95    /// Run all passes on the graph
96    pub fn run(&self, graph: &mut ModelGraph) -> Result<OptimizationStats> {
97        let mut stats = OptimizationStats::new();
98        let mut modified = true;
99        let mut iteration = 0;
100        const MAX_ITERATIONS: usize = 10;
101
102        // Run passes iteratively until no more changes or max iterations reached
103        while modified && iteration < MAX_ITERATIONS {
104            modified = false;
105            iteration += 1;
106
107            debug!("Optimization iteration {}", iteration);
108
109            for pass in &self.passes {
110                let pass_name = pass.name();
111                debug!("Running pass: {}", pass_name);
112
113                let result = pass.run(graph).map_err(|e| OptimizationError::PassFailed {
114                    pass_name: pass_name.to_string(),
115                    reason: e.to_string(),
116                })?;
117
118                if result.nodes_removed > 0 || result.nodes_fused > 0 || result.nodes_modified > 0 {
119                    modified = true;
120                    stats.merge(result.clone());
121                }
122
123                info!(
124                    "Pass {} completed: {} nodes removed, {} fused, {} modified",
125                    pass_name, result.nodes_removed, result.nodes_fused, result.nodes_modified
126                );
127            }
128        }
129
130        stats.iterations = iteration;
131        info!(
132            "Optimization completed after {} iterations: {} total changes",
133            stats.iterations,
134            stats.total_changes()
135        );
136
137        Ok(stats)
138    }
139
140    /// Get the number of passes
141    pub fn pass_count(&self) -> usize {
142        self.passes.len()
143    }
144}
145
146impl Default for PassManager {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152/// Statistics from running optimization passes
153#[derive(Debug, Clone, Default)]
154pub struct OptimizationStats {
155    pub nodes_removed: usize,
156    pub nodes_fused: usize,
157    pub nodes_modified: usize,
158    pub iterations: usize,
159    pub pass_stats: HashMap<String, PassStats>,
160}
161
162impl OptimizationStats {
163    pub fn new() -> Self {
164        Self::default()
165    }
166
167    pub fn merge(&mut self, other: PassStats) {
168        self.nodes_removed += other.nodes_removed;
169        self.nodes_fused += other.nodes_fused;
170        self.nodes_modified += other.nodes_modified;
171    }
172
173    pub fn total_changes(&self) -> usize {
174        self.nodes_removed + self.nodes_fused + self.nodes_modified
175    }
176}
177
178/// Statistics from a single pass execution
179#[derive(Debug, Clone, Default)]
180pub struct PassStats {
181    pub nodes_removed: usize,
182    pub nodes_fused: usize,
183    pub nodes_modified: usize,
184}