1use crate::error::Result;
4use naga::{Literal, Module};
5use std::collections::HashSet;
6
/// Applies a configurable set of optimization passes to a naga [`Module`].
///
/// The optimizer only stores which passes are switched on; the passes
/// themselves are implemented as methods on this type.
#[derive(Debug, Clone)]
pub struct ShaderOptimizer {
    /// Passes that are currently switched on.
    enabled_passes: HashSet<OptimizationPass>,
}

/// Identifies one optimization pass that can be toggled on a
/// [`ShaderOptimizer`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Dead-code elimination.
    DeadCodeElimination,
    /// Folding of operations whose operands are literals.
    ConstantFolding,
    /// Loop unrolling.
    LoopUnrolling,
    /// Common-subexpression elimination.
    CommonSubexpressionElimination,
    /// Register allocation.
    RegisterAllocation,
    /// Instruction combining.
    InstructionCombining,
}
29
30impl ShaderOptimizer {
31 pub fn new() -> Self {
33 let mut enabled_passes = HashSet::new();
34 enabled_passes.insert(OptimizationPass::DeadCodeElimination);
35 enabled_passes.insert(OptimizationPass::ConstantFolding);
36
37 Self { enabled_passes }
38 }
39
40 pub fn new_aggressive() -> Self {
42 let mut enabled_passes = HashSet::new();
43 enabled_passes.insert(OptimizationPass::DeadCodeElimination);
44 enabled_passes.insert(OptimizationPass::ConstantFolding);
45 enabled_passes.insert(OptimizationPass::LoopUnrolling);
46 enabled_passes.insert(OptimizationPass::CommonSubexpressionElimination);
47 enabled_passes.insert(OptimizationPass::RegisterAllocation);
48 enabled_passes.insert(OptimizationPass::InstructionCombining);
49
50 Self { enabled_passes }
51 }
52
53 pub fn enable_pass(&mut self, pass: OptimizationPass) {
55 self.enabled_passes.insert(pass);
56 }
57
58 pub fn disable_pass(&mut self, pass: OptimizationPass) {
60 self.enabled_passes.remove(&pass);
61 }
62
63 pub fn is_pass_enabled(&self, pass: OptimizationPass) -> bool {
65 self.enabled_passes.contains(&pass)
66 }
67
68 pub fn optimize(&self, module: &Module) -> Result<Module> {
70 let mut optimized = module.clone();
71
72 if self.is_pass_enabled(OptimizationPass::DeadCodeElimination) {
74 optimized = self.eliminate_dead_code(&optimized)?;
75 }
76
77 if self.is_pass_enabled(OptimizationPass::ConstantFolding) {
78 optimized = self.fold_constants(&optimized)?;
79 }
80
81 if self.is_pass_enabled(OptimizationPass::LoopUnrolling) {
82 optimized = self.unroll_loops(&optimized)?;
83 }
84
85 if self.is_pass_enabled(OptimizationPass::CommonSubexpressionElimination) {
86 optimized = self.eliminate_common_subexpressions(&optimized)?;
87 }
88
89 if self.is_pass_enabled(OptimizationPass::InstructionCombining) {
90 optimized = self.combine_instructions(&optimized)?;
91 }
92
93 Ok(optimized)
94 }
95
96 fn eliminate_dead_code(&self, module: &Module) -> Result<Module> {
98 let optimized = module.clone();
99
100 let mut reachable_functions = HashSet::new();
102
103 for entry in optimized.entry_points.iter() {
105 self.collect_called_functions(&optimized, &entry.function, &mut reachable_functions);
108 }
109
110 let mut functions_to_remove = Vec::new();
112 for (handle, _func) in optimized.functions.iter() {
113 if !reachable_functions.contains(&handle) {
114 functions_to_remove.push(handle);
115 }
116 }
117
118 Ok(optimized)
123 }
124
125 fn collect_called_functions(
127 &self,
128 module: &Module,
129 function: &naga::Function,
130 reachable: &mut HashSet<naga::Handle<naga::Function>>,
131 ) {
132 use naga::Expression;
133
134 for statement in function.body.iter() {
136 self.collect_calls_from_statement(module, statement, reachable);
137 }
138
139 for (_handle, expr) in function.expressions.iter() {
141 if let Expression::CallResult(_call_handle) = expr {
142 }
146 }
147 }
148
149 fn collect_calls_from_statement(
151 &self,
152 module: &Module,
153 statement: &naga::Statement,
154 reachable: &mut HashSet<naga::Handle<naga::Function>>,
155 ) {
156 use naga::Statement;
157
158 match statement {
159 Statement::Block(block) => {
160 for stmt in block.iter() {
161 self.collect_calls_from_statement(module, stmt, reachable);
162 }
163 }
164 Statement::If { accept, reject, .. } => {
165 for stmt in accept.iter() {
166 self.collect_calls_from_statement(module, stmt, reachable);
167 }
168 for stmt in reject.iter() {
169 self.collect_calls_from_statement(module, stmt, reachable);
170 }
171 }
172 Statement::Loop {
173 body, continuing, ..
174 } => {
175 for stmt in body.iter() {
176 self.collect_calls_from_statement(module, stmt, reachable);
177 }
178 for stmt in continuing.iter() {
179 self.collect_calls_from_statement(module, stmt, reachable);
180 }
181 }
182 Statement::Switch { cases, .. } => {
183 for case in cases {
184 for stmt in case.body.iter() {
185 self.collect_calls_from_statement(module, stmt, reachable);
186 }
187 }
188 }
189 Statement::Call { function, .. } => {
190 reachable.insert(*function);
191 if let Ok(func) = module.functions.try_get(*function) {
193 self.collect_called_functions(module, func, reachable);
194 }
195 }
196 _ => {}
197 }
198 }
199
200 fn fold_constants(&self, module: &Module) -> Result<Module> {
202 use naga::Expression;
203
204 let mut optimized = module.clone();
205
206 for (_handle, function) in optimized.functions.iter_mut() {
208 let mut modifications = Vec::new();
210
211 for (expr_handle, expr) in function.expressions.iter() {
212 if let Expression::Binary { op, left, right } = expr {
214 let left_val = function.expressions.try_get(*left);
215 let right_val = function.expressions.try_get(*right);
216
217 if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
219 (left_val, right_val)
220 {
221 let folded = self.fold_binary_op(*op, left_lit, right_lit);
223 if let Some(result) = folded {
224 modifications.push((expr_handle, Expression::Literal(result)));
225 }
226 }
227 }
228
229 if let Expression::Unary { op, expr: operand } = expr {
231 if let Ok(Expression::Literal(lit)) = function.expressions.try_get(*operand) {
232 let folded = self.fold_unary_op(*op, lit);
233 if let Some(result) = folded {
234 modifications.push((expr_handle, Expression::Literal(result)));
235 }
236 }
237 }
238 }
239
240 for (handle, new_expr) in modifications {
242 function.expressions[handle] = new_expr;
243 }
244 }
245
246 for entry in optimized.entry_points.iter_mut() {
248 let mut modifications = Vec::new();
250
251 for (expr_handle, expr) in entry.function.expressions.iter() {
252 if let Expression::Binary { op, left, right } = expr {
253 let left_val = entry.function.expressions.try_get(*left);
254 let right_val = entry.function.expressions.try_get(*right);
255
256 if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
257 (left_val, right_val)
258 {
259 let folded = self.fold_binary_op(*op, left_lit, right_lit);
260 if let Some(result) = folded {
261 modifications.push((expr_handle, Expression::Literal(result)));
262 }
263 }
264 }
265
266 if let Expression::Unary { op, expr: operand } = expr {
268 if let Ok(Expression::Literal(lit)) =
269 entry.function.expressions.try_get(*operand)
270 {
271 let folded = self.fold_unary_op(*op, lit);
272 if let Some(result) = folded {
273 modifications.push((expr_handle, Expression::Literal(result)));
274 }
275 }
276 }
277 }
278
279 for (handle, new_expr) in modifications {
281 entry.function.expressions[handle] = new_expr;
282 }
283 }
284
285 Ok(optimized)
286 }
287
288 fn fold_binary_op(
290 &self,
291 op: naga::BinaryOperator,
292 left: &Literal,
293 right: &Literal,
294 ) -> Option<Literal> {
295 use naga::{BinaryOperator, Literal};
296
297 match (left, right) {
298 (Literal::I32(a), Literal::I32(b)) => match op {
299 BinaryOperator::Add => Some(Literal::I32(a.wrapping_add(*b))),
300 BinaryOperator::Subtract => Some(Literal::I32(a.wrapping_sub(*b))),
301 BinaryOperator::Multiply => Some(Literal::I32(a.wrapping_mul(*b))),
302 BinaryOperator::Divide => {
303 if *b != 0 {
304 a.checked_div(*b).map(Literal::I32)
305 } else {
306 None
307 }
308 }
309 _ => None,
310 },
311 (Literal::F32(a), Literal::F32(b)) => match op {
312 BinaryOperator::Add => Some(Literal::F32(a + b)),
313 BinaryOperator::Subtract => Some(Literal::F32(a - b)),
314 BinaryOperator::Multiply => Some(Literal::F32(a * b)),
315 BinaryOperator::Divide => Some(Literal::F32(a / b)),
316 _ => None,
317 },
318 _ => None,
319 }
320 }
321
322 fn fold_unary_op(&self, op: naga::UnaryOperator, operand: &Literal) -> Option<Literal> {
324 use naga::{Literal, UnaryOperator};
325
326 match operand {
327 Literal::I32(val) => match op {
328 UnaryOperator::Negate => Some(Literal::I32(-val)),
329 UnaryOperator::LogicalNot => Some(Literal::Bool(*val == 0)),
330 _ => None,
331 },
332 Literal::F32(val) => match op {
333 UnaryOperator::Negate => Some(Literal::F32(-val)),
334 _ => None,
335 },
336 Literal::Bool(val) => match op {
337 UnaryOperator::LogicalNot => Some(Literal::Bool(!val)),
338 _ => None,
339 },
340 _ => None,
341 }
342 }
343
344 fn unroll_loops(&self, module: &Module) -> Result<Module> {
346 Ok(module.clone())
357 }
358
359 fn eliminate_common_subexpressions(&self, module: &Module) -> Result<Module> {
361 use std::collections::HashMap;
362
363 let mut optimized = module.clone();
364
365 for (_handle, function) in optimized.functions.iter_mut() {
369 let mut expression_map: HashMap<u64, Vec<naga::Handle<naga::Expression>>> =
370 HashMap::new();
371
372 for (handle, expr) in function.expressions.iter() {
374 let hash = self.hash_expression(expr);
375 expression_map.entry(hash).or_default().push(handle);
376 }
377
378 }
382
383 Ok(optimized)
384 }
385
386 fn hash_expression(&self, expr: &naga::Expression) -> u64 {
388 use std::collections::hash_map::DefaultHasher;
389 use std::hash::{Hash, Hasher};
390
391 let mut hasher = DefaultHasher::new();
392 std::mem::discriminant(expr).hash(&mut hasher);
395 hasher.finish()
396 }
397
398 fn combine_instructions(&self, module: &Module) -> Result<Module> {
400 use naga::{BinaryOperator, Expression, Literal};
401
402 let mut optimized = module.clone();
403
404 for (_handle, function) in optimized.functions.iter_mut() {
408 let mut _optimization_candidates: Vec<(naga::Handle<naga::Expression>, &Expression)> =
410 Vec::new();
411
412 for (_expr_handle, expr) in function.expressions.iter() {
413 if let Expression::Binary { op, left: _, right } = expr {
419 let right_val = function.expressions.try_get(*right);
420
421 if matches!(op, BinaryOperator::Multiply) {
423 if let Ok(Expression::Literal(lit)) = right_val {
424 if matches!(lit, Literal::I32(1))
425 || matches!(lit, Literal::F32(v) if *v == 1.0)
426 {
427 }
430 }
431 }
432
433 if matches!(op, BinaryOperator::Add) {
435 if let Ok(Expression::Literal(lit)) = right_val {
436 if matches!(lit, Literal::I32(0))
437 || matches!(lit, Literal::F32(v) if *v == 0.0)
438 {
439 }
441 }
442 }
443 }
444 }
445
446 }
448
449 Ok(optimized)
450 }
451
452 pub fn get_level_preset(level: OptimizationLevel) -> Self {
454 match level {
455 OptimizationLevel::None => Self {
456 enabled_passes: HashSet::new(),
457 },
458 OptimizationLevel::Basic => {
459 let mut optimizer = Self::new();
460 optimizer.enable_pass(OptimizationPass::DeadCodeElimination);
461 optimizer.enable_pass(OptimizationPass::ConstantFolding);
462 optimizer
463 }
464 OptimizationLevel::Aggressive => Self::new_aggressive(),
465 }
466 }
467}
468
/// Coarse optimization presets accepted by `ShaderOptimizer::get_level_preset`.
///
/// Derives `PartialEq`, `Eq` and `Hash` so levels can be compared and used as
/// map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No passes enabled.
    None,
    /// Dead-code elimination and constant folding.
    Basic,
    /// Every available pass enabled.
    Aggressive,
}
479
480impl Default for ShaderOptimizer {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
/// Counters describing what an optimization run achieved.
///
/// Nothing in this file fills these in; callers populate the fields
/// themselves.
#[derive(Debug, Clone, Default)]
pub struct OptimizationMetrics {
    /// Number of instructions removed.
    pub instructions_removed: usize,
    /// Number of constants folded.
    pub constants_folded: usize,
    /// Number of loops unrolled.
    pub loops_unrolled: usize,
    /// Number of common subexpressions eliminated.
    pub cse_eliminated: usize,
    /// Register-pressure reduction as a fraction; `print` multiplies it by
    /// 100 and shows it as a percentage.
    pub register_pressure_reduction: f32,
}

impl OptimizationMetrics {
    /// Creates metrics with every counter at its default (zero) value.
    pub fn new() -> Self {
        OptimizationMetrics::default()
    }

    /// Sums the four integer counters. `register_pressure_reduction` is not
    /// part of the total.
    pub fn total_optimizations(&self) -> usize {
        let Self {
            instructions_removed,
            constants_folded,
            loops_unrolled,
            cse_eliminated,
            ..
        } = self;

        [
            *instructions_removed,
            *constants_folded,
            *loops_unrolled,
            *cse_eliminated,
        ]
        .iter()
        .sum()
    }

    /// Writes a human-readable summary of all metrics to standard output.
    pub fn print(&self) {
        let reduction_percent = self.register_pressure_reduction * 100.0;
        let total = self.total_optimizations();

        println!("\nOptimization Metrics:");
        println!(" Instructions removed: {}", self.instructions_removed);
        println!(" Constants folded: {}", self.constants_folded);
        println!(" Loops unrolled: {}", self.loops_unrolled);
        println!(" CSE eliminated: {}", self.cse_eliminated);
        println!(" Register pressure reduction: {:.1}%", reduction_percent);
        println!(" Total optimizations: {}", total);
    }
}
529
/// Tuning knobs for optimization.
///
/// NOTE(review): nothing in this file reads these fields; the field
/// descriptions below are inferred from the names - confirm against the
/// consumer of this type.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Presumably the maximum number of iterations a loop may be unrolled to.
    pub max_unroll_iterations: usize,
    /// Presumably enables more eager function inlining.
    pub aggressive_inlining: bool,
    /// Presumably an optional register budget to aim for; `None` by default.
    pub target_register_count: Option<usize>,
    /// Presumably enables vectorization.
    pub vectorization: bool,
}

impl Default for OptimizationConfig {
    /// Defaults: unroll limit 4, no aggressive inlining, no register target,
    /// vectorization on.
    fn default() -> Self {
        const DEFAULT_MAX_UNROLL_ITERATIONS: usize = 4;

        OptimizationConfig {
            vectorization: true,
            target_register_count: None,
            aggressive_inlining: false,
            max_unroll_iterations: DEFAULT_MAX_UNROLL_ITERATIONS,
        }
    }
}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created optimizer has the two baseline passes switched on.
    #[test]
    fn test_optimizer_creation() {
        let optimizer = ShaderOptimizer::new();
        let baseline = [
            OptimizationPass::DeadCodeElimination,
            OptimizationPass::ConstantFolding,
        ];

        for pass in baseline.iter() {
            assert!(optimizer.is_pass_enabled(*pass));
        }
    }

    /// The aggressive optimizer additionally enables loop unrolling and CSE.
    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = ShaderOptimizer::new_aggressive();
        let expected = [
            OptimizationPass::LoopUnrolling,
            OptimizationPass::CommonSubexpressionElimination,
        ];

        for pass in expected.iter() {
            assert!(optimizer.is_pass_enabled(*pass));
        }
    }

    /// Passes can be switched off and on after construction.
    #[test]
    fn test_pass_enable_disable() {
        let mut optimizer = ShaderOptimizer::new();

        optimizer.disable_pass(OptimizationPass::DeadCodeElimination);
        let dce_enabled = optimizer.is_pass_enabled(OptimizationPass::DeadCodeElimination);
        assert!(!dce_enabled);

        optimizer.enable_pass(OptimizationPass::LoopUnrolling);
        let unrolling_enabled = optimizer.is_pass_enabled(OptimizationPass::LoopUnrolling);
        assert!(unrolling_enabled);
    }

    /// The total is the sum of the four integer counters (10 + 5 + 2 + 3).
    #[test]
    fn test_optimization_metrics() {
        let metrics = OptimizationMetrics {
            instructions_removed: 10,
            constants_folded: 5,
            loops_unrolled: 2,
            cse_eliminated: 3,
            register_pressure_reduction: 0.15,
        };

        let total = metrics.total_optimizations();
        assert_eq!(total, 20);
    }
}