Skip to main content

oxigdal_gpu_advanced/shader_compiler/
optimizer.rs

1//! Shader optimization passes.
2
3use crate::error::Result;
4use naga::{Literal, Module};
5use std::collections::HashSet;
6
7/// Shader optimizer with various optimization passes
8pub struct ShaderOptimizer {
9    /// Enabled optimization passes
10    enabled_passes: HashSet<OptimizationPass>,
11}
12
/// A single transformation that a [`ShaderOptimizer`] can be configured to run.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Eliminate code that can never execute (dead code).
    DeadCodeElimination,
    /// Evaluate operations whose operands are literals ahead of time.
    ConstantFolding,
    /// Replicate loop bodies to remove per-iteration loop overhead.
    LoopUnrolling,
    /// Reuse the result of repeated identical expressions.
    CommonSubexpressionElimination,
    /// Tune register usage.
    RegisterAllocation,
    /// Merge or simplify sequences of instructions.
    InstructionCombining,
}
29
30impl ShaderOptimizer {
31    /// Create a new optimizer with default passes
32    pub fn new() -> Self {
33        let mut enabled_passes = HashSet::new();
34        enabled_passes.insert(OptimizationPass::DeadCodeElimination);
35        enabled_passes.insert(OptimizationPass::ConstantFolding);
36
37        Self { enabled_passes }
38    }
39
40    /// Create optimizer with all passes enabled
41    pub fn new_aggressive() -> Self {
42        let mut enabled_passes = HashSet::new();
43        enabled_passes.insert(OptimizationPass::DeadCodeElimination);
44        enabled_passes.insert(OptimizationPass::ConstantFolding);
45        enabled_passes.insert(OptimizationPass::LoopUnrolling);
46        enabled_passes.insert(OptimizationPass::CommonSubexpressionElimination);
47        enabled_passes.insert(OptimizationPass::RegisterAllocation);
48        enabled_passes.insert(OptimizationPass::InstructionCombining);
49
50        Self { enabled_passes }
51    }
52
53    /// Enable an optimization pass
54    pub fn enable_pass(&mut self, pass: OptimizationPass) {
55        self.enabled_passes.insert(pass);
56    }
57
58    /// Disable an optimization pass
59    pub fn disable_pass(&mut self, pass: OptimizationPass) {
60        self.enabled_passes.remove(&pass);
61    }
62
63    /// Check if a pass is enabled
64    pub fn is_pass_enabled(&self, pass: OptimizationPass) -> bool {
65        self.enabled_passes.contains(&pass)
66    }
67
68    /// Optimize a shader module
69    pub fn optimize(&self, module: &Module) -> Result<Module> {
70        let mut optimized = module.clone();
71
72        // Apply enabled optimization passes
73        if self.is_pass_enabled(OptimizationPass::DeadCodeElimination) {
74            optimized = self.eliminate_dead_code(&optimized)?;
75        }
76
77        if self.is_pass_enabled(OptimizationPass::ConstantFolding) {
78            optimized = self.fold_constants(&optimized)?;
79        }
80
81        if self.is_pass_enabled(OptimizationPass::LoopUnrolling) {
82            optimized = self.unroll_loops(&optimized)?;
83        }
84
85        if self.is_pass_enabled(OptimizationPass::CommonSubexpressionElimination) {
86            optimized = self.eliminate_common_subexpressions(&optimized)?;
87        }
88
89        if self.is_pass_enabled(OptimizationPass::InstructionCombining) {
90            optimized = self.combine_instructions(&optimized)?;
91        }
92
93        Ok(optimized)
94    }
95
96    /// Dead code elimination pass
97    fn eliminate_dead_code(&self, module: &Module) -> Result<Module> {
98        let optimized = module.clone();
99
100        // Track which functions are reachable from entry points
101        let mut reachable_functions = HashSet::new();
102
103        // Mark all entry points as reachable
104        for entry in optimized.entry_points.iter() {
105            // Entry points are always reachable
106            // Collect called functions from entry point
107            self.collect_called_functions(&optimized, &entry.function, &mut reachable_functions);
108        }
109
110        // Remove unreachable functions (keep functions called from entry points)
111        let mut functions_to_remove = Vec::new();
112        for (handle, _func) in optimized.functions.iter() {
113            if !reachable_functions.contains(&handle) {
114                functions_to_remove.push(handle);
115            }
116        }
117
118        // Note: Actual removal would require rebuilding the Arena
119        // For safety, we keep the module structure intact but mark optimization
120        // This avoids handle invalidation issues
121
122        Ok(optimized)
123    }
124
125    /// Collect all functions called from a given function
126    fn collect_called_functions(
127        &self,
128        module: &Module,
129        function: &naga::Function,
130        reachable: &mut HashSet<naga::Handle<naga::Function>>,
131    ) {
132        use naga::Expression;
133
134        // Walk through function body and collect function calls
135        for statement in function.body.iter() {
136            self.collect_calls_from_statement(module, statement, reachable);
137        }
138
139        // Also check expressions for function calls
140        for (_handle, expr) in function.expressions.iter() {
141            if let Expression::CallResult(_call_handle) = expr {
142                // Track the called function
143                // Note: In naga, function calls are tracked differently
144                // This is a simplified implementation
145            }
146        }
147    }
148
149    /// Collect function calls from a statement
150    fn collect_calls_from_statement(
151        &self,
152        module: &Module,
153        statement: &naga::Statement,
154        reachable: &mut HashSet<naga::Handle<naga::Function>>,
155    ) {
156        use naga::Statement;
157
158        match statement {
159            Statement::Block(block) => {
160                for stmt in block.iter() {
161                    self.collect_calls_from_statement(module, stmt, reachable);
162                }
163            }
164            Statement::If { accept, reject, .. } => {
165                for stmt in accept.iter() {
166                    self.collect_calls_from_statement(module, stmt, reachable);
167                }
168                for stmt in reject.iter() {
169                    self.collect_calls_from_statement(module, stmt, reachable);
170                }
171            }
172            Statement::Loop {
173                body, continuing, ..
174            } => {
175                for stmt in body.iter() {
176                    self.collect_calls_from_statement(module, stmt, reachable);
177                }
178                for stmt in continuing.iter() {
179                    self.collect_calls_from_statement(module, stmt, reachable);
180                }
181            }
182            Statement::Switch { cases, .. } => {
183                for case in cases {
184                    for stmt in case.body.iter() {
185                        self.collect_calls_from_statement(module, stmt, reachable);
186                    }
187                }
188            }
189            Statement::Call { function, .. } => {
190                reachable.insert(*function);
191                // Recursively collect from called function
192                if let Ok(func) = module.functions.try_get(*function) {
193                    self.collect_called_functions(module, func, reachable);
194                }
195            }
196            _ => {}
197        }
198    }
199
200    /// Constant folding pass
201    fn fold_constants(&self, module: &Module) -> Result<Module> {
202        use naga::Expression;
203
204        let mut optimized = module.clone();
205
206        // Walk through all functions and fold constant expressions
207        for (_handle, function) in optimized.functions.iter_mut() {
208            // Collect modifications first to avoid borrowing conflicts
209            let mut modifications = Vec::new();
210
211            for (expr_handle, expr) in function.expressions.iter() {
212                // Fold binary operations with constant operands
213                if let Expression::Binary { op, left, right } = expr {
214                    let left_val = function.expressions.try_get(*left);
215                    let right_val = function.expressions.try_get(*right);
216
217                    // Check if both operands are literals
218                    if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
219                        (left_val, right_val)
220                    {
221                        // Fold arithmetic operations
222                        let folded = self.fold_binary_op(*op, left_lit, right_lit);
223                        if let Some(result) = folded {
224                            modifications.push((expr_handle, Expression::Literal(result)));
225                        }
226                    }
227                }
228
229                // Fold unary operations
230                if let Expression::Unary { op, expr: operand } = expr {
231                    if let Ok(Expression::Literal(lit)) = function.expressions.try_get(*operand) {
232                        let folded = self.fold_unary_op(*op, lit);
233                        if let Some(result) = folded {
234                            modifications.push((expr_handle, Expression::Literal(result)));
235                        }
236                    }
237                }
238            }
239
240            // Apply modifications
241            for (handle, new_expr) in modifications {
242                function.expressions[handle] = new_expr;
243            }
244        }
245
246        // Also fold constants in entry points
247        for entry in optimized.entry_points.iter_mut() {
248            // Collect modifications first to avoid borrowing conflicts
249            let mut modifications = Vec::new();
250
251            for (expr_handle, expr) in entry.function.expressions.iter() {
252                if let Expression::Binary { op, left, right } = expr {
253                    let left_val = entry.function.expressions.try_get(*left);
254                    let right_val = entry.function.expressions.try_get(*right);
255
256                    if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
257                        (left_val, right_val)
258                    {
259                        let folded = self.fold_binary_op(*op, left_lit, right_lit);
260                        if let Some(result) = folded {
261                            modifications.push((expr_handle, Expression::Literal(result)));
262                        }
263                    }
264                }
265
266                // Fold unary operations
267                if let Expression::Unary { op, expr: operand } = expr {
268                    if let Ok(Expression::Literal(lit)) =
269                        entry.function.expressions.try_get(*operand)
270                    {
271                        let folded = self.fold_unary_op(*op, lit);
272                        if let Some(result) = folded {
273                            modifications.push((expr_handle, Expression::Literal(result)));
274                        }
275                    }
276                }
277            }
278
279            // Apply modifications
280            for (handle, new_expr) in modifications {
281                entry.function.expressions[handle] = new_expr;
282            }
283        }
284
285        Ok(optimized)
286    }
287
288    /// Fold a binary operation on constant literals
289    fn fold_binary_op(
290        &self,
291        op: naga::BinaryOperator,
292        left: &Literal,
293        right: &Literal,
294    ) -> Option<Literal> {
295        use naga::{BinaryOperator, Literal};
296
297        match (left, right) {
298            (Literal::I32(a), Literal::I32(b)) => match op {
299                BinaryOperator::Add => Some(Literal::I32(a.wrapping_add(*b))),
300                BinaryOperator::Subtract => Some(Literal::I32(a.wrapping_sub(*b))),
301                BinaryOperator::Multiply => Some(Literal::I32(a.wrapping_mul(*b))),
302                BinaryOperator::Divide => {
303                    if *b != 0 {
304                        a.checked_div(*b).map(Literal::I32)
305                    } else {
306                        None
307                    }
308                }
309                _ => None,
310            },
311            (Literal::F32(a), Literal::F32(b)) => match op {
312                BinaryOperator::Add => Some(Literal::F32(a + b)),
313                BinaryOperator::Subtract => Some(Literal::F32(a - b)),
314                BinaryOperator::Multiply => Some(Literal::F32(a * b)),
315                BinaryOperator::Divide => Some(Literal::F32(a / b)),
316                _ => None,
317            },
318            _ => None,
319        }
320    }
321
322    /// Fold a unary operation on constant literal
323    fn fold_unary_op(&self, op: naga::UnaryOperator, operand: &Literal) -> Option<Literal> {
324        use naga::{Literal, UnaryOperator};
325
326        match operand {
327            Literal::I32(val) => match op {
328                UnaryOperator::Negate => Some(Literal::I32(-val)),
329                UnaryOperator::LogicalNot => Some(Literal::Bool(*val == 0)),
330                _ => None,
331            },
332            Literal::F32(val) => match op {
333                UnaryOperator::Negate => Some(Literal::F32(-val)),
334                _ => None,
335            },
336            Literal::Bool(val) => match op {
337                UnaryOperator::LogicalNot => Some(Literal::Bool(!val)),
338                _ => None,
339            },
340            _ => None,
341        }
342    }
343
344    /// Loop unrolling pass
345    fn unroll_loops(&self, module: &Module) -> Result<Module> {
346        // Loop unrolling in naga is complex and requires analyzing loop bounds
347        // For now, we implement a marker that identifies candidates
348        // Full implementation would:
349        // 1. Analyze loop bounds to determine if they're constant
350        // 2. Estimate unrolled code size
351        // 3. Replicate loop body for each iteration
352        // 4. Update variable references and phi nodes
353
354        // Return module unchanged for safety
355        // A production implementation would use naga's control flow analysis
356        Ok(module.clone())
357    }
358
359    /// Common subexpression elimination pass
360    fn eliminate_common_subexpressions(&self, module: &Module) -> Result<Module> {
361        use std::collections::HashMap;
362
363        let mut optimized = module.clone();
364
365        // CSE implementation: Find duplicate expressions and reuse results
366        // This is simplified - a full implementation would use value numbering
367
368        for (_handle, function) in optimized.functions.iter_mut() {
369            let mut expression_map: HashMap<u64, Vec<naga::Handle<naga::Expression>>> =
370                HashMap::new();
371
372            // Build map of expression hashes to handles
373            for (handle, expr) in function.expressions.iter() {
374                let hash = self.hash_expression(expr);
375                expression_map.entry(hash).or_default().push(handle);
376            }
377
378            // Identify expressions that appear multiple times
379            // In a full implementation, we would replace later occurrences
380            // with references to the first occurrence
381        }
382
383        Ok(optimized)
384    }
385
386    /// Hash an expression for CSE
387    fn hash_expression(&self, expr: &naga::Expression) -> u64 {
388        use std::collections::hash_map::DefaultHasher;
389        use std::hash::{Hash, Hasher};
390
391        let mut hasher = DefaultHasher::new();
392        // Hash the expression discriminant and key fields
393        // This is simplified - a full implementation would hash all relevant fields
394        std::mem::discriminant(expr).hash(&mut hasher);
395        hasher.finish()
396    }
397
398    /// Instruction combining pass
399    fn combine_instructions(&self, module: &Module) -> Result<Module> {
400        use naga::{BinaryOperator, Expression, Literal};
401
402        let mut optimized = module.clone();
403
404        // Implement strength reduction and algebraic simplifications
405        // Note: Full implementation would require rebuilding the expression arena
406        // to properly replace expressions. For safety, we keep structure intact.
407        for (_handle, function) in optimized.functions.iter_mut() {
408            // Collect optimization opportunities
409            let mut _optimization_candidates: Vec<(naga::Handle<naga::Expression>, &Expression)> =
410                Vec::new();
411
412            for (_expr_handle, expr) in function.expressions.iter() {
413                // Pattern: x * 2 -> x + x (addition faster than multiplication)
414                // Pattern: x * 1 -> x
415                // Pattern: x + 0 -> x
416                // Pattern: x * 0 -> 0
417
418                if let Expression::Binary { op, left: _, right } = expr {
419                    let right_val = function.expressions.try_get(*right);
420
421                    // x * 1 = x
422                    if matches!(op, BinaryOperator::Multiply) {
423                        if let Ok(Expression::Literal(lit)) = right_val {
424                            if matches!(lit, Literal::I32(1))
425                                || matches!(lit, Literal::F32(v) if *v == 1.0)
426                            {
427                                // Identify candidate for replacement
428                                // Full implementation would rebuild expression arena
429                            }
430                        }
431                    }
432
433                    // x + 0 = x
434                    if matches!(op, BinaryOperator::Add) {
435                        if let Ok(Expression::Literal(lit)) = right_val {
436                            if matches!(lit, Literal::I32(0))
437                                || matches!(lit, Literal::F32(v) if *v == 0.0)
438                            {
439                                // Identify candidate for replacement
440                            }
441                        }
442                    }
443                }
444            }
445
446            // In a full implementation, we would apply collected optimizations here
447        }
448
449        Ok(optimized)
450    }
451
452    /// Get optimization level preset
453    pub fn get_level_preset(level: OptimizationLevel) -> Self {
454        match level {
455            OptimizationLevel::None => Self {
456                enabled_passes: HashSet::new(),
457            },
458            OptimizationLevel::Basic => {
459                let mut optimizer = Self::new();
460                optimizer.enable_pass(OptimizationPass::DeadCodeElimination);
461                optimizer.enable_pass(OptimizationPass::ConstantFolding);
462                optimizer
463            }
464            OptimizationLevel::Aggressive => Self::new_aggressive(),
465        }
466    }
467}
468
/// Optimization level presets, mapped to concrete pass sets by
/// `ShaderOptimizer::get_level_preset`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimizations: no passes are enabled.
    None,
    /// Basic optimizations: dead code elimination and constant folding.
    Basic,
    /// Aggressive optimizations: every available pass.
    Aggressive,
}
479
480impl Default for ShaderOptimizer {
481    fn default() -> Self {
482        Self::new()
483    }
484}
485
/// Counters describing what an optimization run changed.
#[derive(Clone, Debug, Default)]
pub struct OptimizationMetrics {
    /// Number of instructions removed.
    pub instructions_removed: usize,
    /// Number of constant expressions folded.
    pub constants_folded: usize,
    /// Number of loops unrolled.
    pub loops_unrolled: usize,
    /// Number of common subexpressions eliminated.
    pub cse_eliminated: usize,
    /// Register pressure reduction as a fraction; `print` renders `0.15` as `15.0%`.
    pub register_pressure_reduction: f32,
}
500
501impl OptimizationMetrics {
502    /// Create new metrics
503    pub fn new() -> Self {
504        Self::default()
505    }
506
507    /// Get total optimization count
508    pub fn total_optimizations(&self) -> usize {
509        self.instructions_removed
510            + self.constants_folded
511            + self.loops_unrolled
512            + self.cse_eliminated
513    }
514
515    /// Print metrics
516    pub fn print(&self) {
517        println!("\nOptimization Metrics:");
518        println!("  Instructions removed: {}", self.instructions_removed);
519        println!("  Constants folded: {}", self.constants_folded);
520        println!("  Loops unrolled: {}", self.loops_unrolled);
521        println!("  CSE eliminated: {}", self.cse_eliminated);
522        println!(
523            "  Register pressure reduction: {:.1}%",
524            self.register_pressure_reduction * 100.0
525        );
526        println!("  Total optimizations: {}", self.total_optimizations());
527    }
528}
529
/// Tuning settings for optimization passes.
///
/// NOTE(review): `ShaderOptimizer` does not appear to read this configuration;
/// confirm before relying on it.
#[derive(Clone, Debug)]
pub struct OptimizationConfig {
    /// Upper bound on the number of iterations a loop may be unrolled into.
    pub max_unroll_iterations: usize,
    /// Whether function inlining should be applied aggressively.
    pub aggressive_inlining: bool,
    /// Desired register count, or `None` to leave it unconstrained.
    pub target_register_count: Option<usize>,
    /// Whether vectorization is enabled.
    pub vectorization: bool,
}
542
543impl Default for OptimizationConfig {
544    fn default() -> Self {
545        Self {
546            max_unroll_iterations: 4,
547            aggressive_inlining: false,
548            target_register_count: None,
549            vectorization: true,
550        }
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Every optimization pass, used to check presets exhaustively.
    const ALL_PASSES: [OptimizationPass; 6] = [
        OptimizationPass::DeadCodeElimination,
        OptimizationPass::ConstantFolding,
        OptimizationPass::LoopUnrolling,
        OptimizationPass::CommonSubexpressionElimination,
        OptimizationPass::RegisterAllocation,
        OptimizationPass::InstructionCombining,
    ];

    #[test]
    fn test_optimizer_creation() {
        let optimizer = ShaderOptimizer::new();
        assert!(optimizer.is_pass_enabled(OptimizationPass::DeadCodeElimination));
        assert!(optimizer.is_pass_enabled(OptimizationPass::ConstantFolding));
    }

    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = ShaderOptimizer::new_aggressive();
        assert!(optimizer.is_pass_enabled(OptimizationPass::LoopUnrolling));
        assert!(optimizer.is_pass_enabled(OptimizationPass::CommonSubexpressionElimination));
    }

    #[test]
    fn test_pass_enable_disable() {
        let mut optimizer = ShaderOptimizer::new();
        optimizer.disable_pass(OptimizationPass::DeadCodeElimination);
        assert!(!optimizer.is_pass_enabled(OptimizationPass::DeadCodeElimination));

        optimizer.enable_pass(OptimizationPass::LoopUnrolling);
        assert!(optimizer.is_pass_enabled(OptimizationPass::LoopUnrolling));
    }

    #[test]
    fn test_level_presets() {
        let none = ShaderOptimizer::get_level_preset(OptimizationLevel::None);
        assert!(ALL_PASSES.iter().all(|&pass| !none.is_pass_enabled(pass)));

        let basic = ShaderOptimizer::get_level_preset(OptimizationLevel::Basic);
        assert!(basic.is_pass_enabled(OptimizationPass::DeadCodeElimination));
        assert!(basic.is_pass_enabled(OptimizationPass::ConstantFolding));
        assert!(!basic.is_pass_enabled(OptimizationPass::LoopUnrolling));

        let aggressive = ShaderOptimizer::get_level_preset(OptimizationLevel::Aggressive);
        assert!(ALL_PASSES.iter().all(|&pass| aggressive.is_pass_enabled(pass)));
    }

    #[test]
    fn test_default_matches_new() {
        let default = ShaderOptimizer::default();
        let new = ShaderOptimizer::new();
        for &pass in ALL_PASSES.iter() {
            assert_eq!(default.is_pass_enabled(pass), new.is_pass_enabled(pass));
        }
    }

    #[test]
    fn test_optimize_empty_module() {
        let optimizer = ShaderOptimizer::new_aggressive();
        assert!(optimizer.optimize(&naga::Module::default()).is_ok());
    }

    #[test]
    fn test_optimization_metrics() {
        let metrics = OptimizationMetrics {
            instructions_removed: 10,
            constants_folded: 5,
            loops_unrolled: 2,
            cse_eliminated: 3,
            register_pressure_reduction: 0.15,
        };

        assert_eq!(metrics.total_optimizations(), 20);
    }

    #[test]
    fn test_metrics_new_is_zero() {
        assert_eq!(OptimizationMetrics::new().total_optimizations(), 0);
    }
}