use crate::error::Result;
use naga::{Literal, Module};
use std::collections::HashSet;
/// Runs a configurable set of optimization passes over a naga [`Module`].
///
/// Which passes run is decided by the set of enabled [`OptimizationPass`]
/// values; see `optimize` for the order in which they are applied.
pub struct ShaderOptimizer {
/// Passes that are currently switched on. A pass that is absent from this
/// set is skipped by `optimize`.
enabled_passes: HashSet<OptimizationPass>,
}
/// Identifies an individual optimization pass that can be enabled or disabled
/// on a [`ShaderOptimizer`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum OptimizationPass {
/// Reachability analysis of functions starting from the entry points.
DeadCodeElimination,
/// Replacement of unary/binary operations on literals by their result.
ConstantFolding,
/// Loop unrolling stage.
LoopUnrolling,
/// Common-subexpression stage.
CommonSubexpressionElimination,
/// Register allocation. NOTE(review): `ShaderOptimizer::optimize` has no
/// stage that checks this variant, so enabling it currently has no effect.
RegisterAllocation,
/// Instruction-combining stage.
InstructionCombining,
}
impl ShaderOptimizer {
pub fn new() -> Self {
let mut enabled_passes = HashSet::new();
enabled_passes.insert(OptimizationPass::DeadCodeElimination);
enabled_passes.insert(OptimizationPass::ConstantFolding);
Self { enabled_passes }
}
pub fn new_aggressive() -> Self {
let mut enabled_passes = HashSet::new();
enabled_passes.insert(OptimizationPass::DeadCodeElimination);
enabled_passes.insert(OptimizationPass::ConstantFolding);
enabled_passes.insert(OptimizationPass::LoopUnrolling);
enabled_passes.insert(OptimizationPass::CommonSubexpressionElimination);
enabled_passes.insert(OptimizationPass::RegisterAllocation);
enabled_passes.insert(OptimizationPass::InstructionCombining);
Self { enabled_passes }
}
/// Switches `pass` on. Enabling a pass that is already enabled is a no-op.
pub fn enable_pass(&mut self, pass: OptimizationPass) {
    // The "newly inserted" flag returned by the set is deliberately ignored.
    let _newly_enabled = self.enabled_passes.insert(pass);
}
/// Switches `pass` off. Disabling a pass that is not enabled is a no-op.
pub fn disable_pass(&mut self, pass: OptimizationPass) {
    // The "was present" flag returned by the set is deliberately ignored.
    let _was_enabled = self.enabled_passes.remove(&pass);
}
/// Returns `true` when `pass` is currently enabled on this optimizer.
pub fn is_pass_enabled(&self, pass: OptimizationPass) -> bool {
    self.enabled_passes.get(&pass).is_some()
}
pub fn optimize(&self, module: &Module) -> Result<Module> {
let mut optimized = module.clone();
if self.is_pass_enabled(OptimizationPass::DeadCodeElimination) {
optimized = self.eliminate_dead_code(&optimized)?;
}
if self.is_pass_enabled(OptimizationPass::ConstantFolding) {
optimized = self.fold_constants(&optimized)?;
}
if self.is_pass_enabled(OptimizationPass::LoopUnrolling) {
optimized = self.unroll_loops(&optimized)?;
}
if self.is_pass_enabled(OptimizationPass::CommonSubexpressionElimination) {
optimized = self.eliminate_common_subexpressions(&optimized)?;
}
if self.is_pass_enabled(OptimizationPass::InstructionCombining) {
optimized = self.combine_instructions(&optimized)?;
}
Ok(optimized)
}
/// Dead-code stage: determines which functions are reachable from the
/// module's entry points.
///
/// The list of unreachable function handles is computed but not applied,
/// so the returned module is an unmodified copy of `module`.
fn eliminate_dead_code(&self, module: &Module) -> Result<Module> {
    let result = module.clone();

    // Gather every function that some entry point calls, directly or not.
    let mut live = HashSet::new();
    for entry_point in result.entry_points.iter() {
        self.collect_called_functions(&result, &entry_point.function, &mut live);
    }

    // Candidates for removal; currently informational only.
    let _unreachable: Vec<_> = result
        .functions
        .iter()
        .map(|(handle, _)| handle)
        .filter(|handle| !live.contains(handle))
        .collect();

    Ok(result)
}
/// Adds to `reachable` every function called from the body of `function`,
/// following calls transitively through `collect_calls_from_statement`.
///
/// Only statements are walked. An `Expression::CallResult` merely names the
/// value produced by a call statement, so scanning the expression arena
/// would not discover additional callees (the previous scan had an empty
/// body and has been removed).
fn collect_called_functions(
    &self,
    module: &Module,
    function: &naga::Function,
    reachable: &mut HashSet<naga::Handle<naga::Function>>,
) {
    for statement in function.body.iter() {
        self.collect_calls_from_statement(module, statement, reachable);
    }
}
/// Records in `reachable` every function called by `statement`, descending
/// into nested blocks (`Block`, `If`, `Loop`, `Switch`) and into the bodies
/// of the called functions themselves.
///
/// A callee's body is traversed only the first time it is seen. This keeps
/// the walk linear in the size of the call graph and guarantees
/// termination even if the module contains a call cycle.
fn collect_calls_from_statement(
    &self,
    module: &Module,
    statement: &naga::Statement,
    reachable: &mut HashSet<naga::Handle<naga::Function>>,
) {
    use naga::Statement;
    match statement {
        Statement::Block(block) => {
            for stmt in block.iter() {
                self.collect_calls_from_statement(module, stmt, reachable);
            }
        }
        Statement::If { accept, reject, .. } => {
            for stmt in accept.iter().chain(reject.iter()) {
                self.collect_calls_from_statement(module, stmt, reachable);
            }
        }
        Statement::Loop {
            body, continuing, ..
        } => {
            for stmt in body.iter().chain(continuing.iter()) {
                self.collect_calls_from_statement(module, stmt, reachable);
            }
        }
        Statement::Switch { cases, .. } => {
            for case in cases {
                for stmt in case.body.iter() {
                    self.collect_calls_from_statement(module, stmt, reachable);
                }
            }
        }
        Statement::Call { function, .. } => {
            // `insert` returns false when the callee was already recorded;
            // in that case its body has been (or is being) walked already.
            if reachable.insert(*function) {
                if let Ok(callee) = module.functions.try_get(*function) {
                    self.collect_called_functions(module, callee, reachable);
                }
            }
        }
        // Remaining statement kinds contain neither calls nor nested blocks.
        _ => {}
    }
}
fn fold_constants(&self, module: &Module) -> Result<Module> {
use naga::Expression;
let mut optimized = module.clone();
for (_handle, function) in optimized.functions.iter_mut() {
let mut modifications = Vec::new();
for (expr_handle, expr) in function.expressions.iter() {
if let Expression::Binary { op, left, right } = expr {
let left_val = function.expressions.try_get(*left);
let right_val = function.expressions.try_get(*right);
if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
(left_val, right_val)
{
let folded = self.fold_binary_op(*op, left_lit, right_lit);
if let Some(result) = folded {
modifications.push((expr_handle, Expression::Literal(result)));
}
}
}
if let Expression::Unary { op, expr: operand } = expr {
if let Ok(Expression::Literal(lit)) = function.expressions.try_get(*operand) {
let folded = self.fold_unary_op(*op, lit);
if let Some(result) = folded {
modifications.push((expr_handle, Expression::Literal(result)));
}
}
}
}
for (handle, new_expr) in modifications {
function.expressions[handle] = new_expr;
}
}
for entry in optimized.entry_points.iter_mut() {
let mut modifications = Vec::new();
for (expr_handle, expr) in entry.function.expressions.iter() {
if let Expression::Binary { op, left, right } = expr {
let left_val = entry.function.expressions.try_get(*left);
let right_val = entry.function.expressions.try_get(*right);
if let (Ok(Expression::Literal(left_lit)), Ok(Expression::Literal(right_lit))) =
(left_val, right_val)
{
let folded = self.fold_binary_op(*op, left_lit, right_lit);
if let Some(result) = folded {
modifications.push((expr_handle, Expression::Literal(result)));
}
}
}
if let Expression::Unary { op, expr: operand } = expr {
if let Ok(Expression::Literal(lit)) =
entry.function.expressions.try_get(*operand)
{
let folded = self.fold_unary_op(*op, lit);
if let Some(result) = folded {
modifications.push((expr_handle, Expression::Literal(result)));
}
}
}
}
for (handle, new_expr) in modifications {
entry.function.expressions[handle] = new_expr;
}
}
Ok(optimized)
}
/// Evaluates `left op right` for two literals of the same type.
///
/// Supported combinations:
/// * `I32`: add, subtract and multiply use wrapping arithmetic; divide
///   folds only when the quotient is defined (non-zero divisor and no
///   `i32::MIN / -1` overflow).
/// * `F32`: add, subtract, multiply and divide fold only when the result is
///   finite. NaN and infinite results are not folded, because naga's
///   validator rejects such float literals.
///
/// Returns `None` for every other operator or operand-type combination, in
/// which case the caller leaves the expression unchanged.
fn fold_binary_op(
    &self,
    op: naga::BinaryOperator,
    left: &Literal,
    right: &Literal,
) -> Option<Literal> {
    use naga::{BinaryOperator, Literal};
    match (left, right) {
        (Literal::I32(a), Literal::I32(b)) => match op {
            BinaryOperator::Add => Some(Literal::I32(a.wrapping_add(*b))),
            BinaryOperator::Subtract => Some(Literal::I32(a.wrapping_sub(*b))),
            BinaryOperator::Multiply => Some(Literal::I32(a.wrapping_mul(*b))),
            // `checked_div` already yields `None` for a zero divisor and
            // for the overflowing `i32::MIN / -1` case.
            BinaryOperator::Divide => a.checked_div(*b).map(Literal::I32),
            _ => None,
        },
        (Literal::F32(a), Literal::F32(b)) => {
            let value = match op {
                BinaryOperator::Add => a + b,
                BinaryOperator::Subtract => a - b,
                BinaryOperator::Multiply => a * b,
                BinaryOperator::Divide => a / b,
                _ => return None,
            };
            // Overflow to +/-inf, division by zero and NaN propagation must
            // not be baked into the module as a literal.
            if value.is_finite() {
                Some(Literal::F32(value))
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Evaluates a unary operator applied to a literal operand.
///
/// Supported combinations:
/// * `I32`: `Negate` (wrapping, so `i32::MIN` negates to itself instead of
///   panicking in debug builds) and `LogicalNot` (yields `Bool(val == 0)`).
/// * `F32`: `Negate`.
/// * `Bool`: `LogicalNot`.
///
/// Returns `None` for every other combination, in which case the caller
/// leaves the expression unchanged.
fn fold_unary_op(&self, op: naga::UnaryOperator, operand: &Literal) -> Option<Literal> {
    use naga::{Literal, UnaryOperator};
    match operand {
        Literal::I32(val) => match op {
            // Wrapping keeps this consistent with the wrapping binary ops.
            UnaryOperator::Negate => Some(Literal::I32(val.wrapping_neg())),
            // NOTE(review): this changes the expression's type from i32 to
            // bool; kept for backward compatibility — confirm a validated
            // module can actually contain a logical-not of an integer.
            UnaryOperator::LogicalNot => Some(Literal::Bool(*val == 0)),
            _ => None,
        },
        Literal::F32(val) => match op {
            UnaryOperator::Negate => Some(Literal::F32(-*val)),
            _ => None,
        },
        Literal::Bool(val) => match op {
            UnaryOperator::LogicalNot => Some(Literal::Bool(!*val)),
            _ => None,
        },
        _ => None,
    }
}
/// Loop-unrolling stage. Currently a placeholder: no loop is transformed
/// and the returned module is an unmodified copy of `module`.
fn unroll_loops(&self, module: &Module) -> Result<Module> {
    let unchanged = module.clone();
    Ok(unchanged)
}
/// Common-subexpression stage.
///
/// For every function in `module.functions`, expressions are grouped by
/// the value of `hash_expression`. The grouping is not acted upon, so the
/// returned module is an unmodified copy of `module`. Entry-point functions
/// are not inspected.
fn eliminate_common_subexpressions(&self, module: &Module) -> Result<Module> {
    use std::collections::HashMap;

    let result = module.clone();
    for (_, func) in result.functions.iter() {
        // hash -> handles of all expressions that produced that hash.
        let _buckets: HashMap<u64, Vec<naga::Handle<naga::Expression>>> =
            func.expressions
                .iter()
                .fold(HashMap::new(), |mut groups, (handle, expr)| {
                    groups
                        .entry(self.hash_expression(expr))
                        .or_default()
                        .push(handle);
                    groups
                });
    }
    Ok(result)
}
/// Hashes the *kind* of `expr` only: the enum discriminant is fed to a
/// `DefaultHasher`, the expression's operands are not. Two expressions of
/// the same variant therefore always hash to the same value.
fn hash_expression(&self, expr: &naga::Expression) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let variant_tag = std::mem::discriminant(expr);
    let mut state = DefaultHasher::new();
    Hash::hash(&variant_tag, &mut state);
    state.finish()
}
/// Instruction-combining stage.
///
/// Scans the binary expressions of every function in `module.functions`
/// for the identity patterns `x * 1` and `x + 0` (literal on the right-hand
/// side, `I32` or `F32`). Matches are only detected, not rewritten, so the
/// returned module is an unmodified copy of `module`. Entry-point functions
/// are not inspected.
fn combine_instructions(&self, module: &Module) -> Result<Module> {
    use naga::{BinaryOperator, Expression, Literal};

    let result = module.clone();
    for (_, func) in result.functions.iter() {
        for (_, expression) in func.expressions.iter() {
            if let Expression::Binary { op, right, .. } = expression {
                if let Ok(Expression::Literal(rhs)) = func.expressions.try_get(*right) {
                    // True when the right operand is the identity element
                    // of `op`; no rewrite is performed on a match yet.
                    let _is_identity = match op {
                        BinaryOperator::Multiply => {
                            matches!(rhs, Literal::I32(1))
                                || matches!(rhs, Literal::F32(v) if *v == 1.0)
                        }
                        BinaryOperator::Add => {
                            matches!(rhs, Literal::I32(0))
                                || matches!(rhs, Literal::F32(v) if *v == 0.0)
                        }
                        _ => false,
                    };
                }
            }
        }
    }
    Ok(result)
}
/// Builds an optimizer preconfigured for `level`.
///
/// * `None` — no pass enabled.
/// * `Basic` — the passes of [`ShaderOptimizer::new`], with dead code
///   elimination and constant folding explicitly enabled.
/// * `Aggressive` — the passes of [`ShaderOptimizer::new_aggressive`].
pub fn get_level_preset(level: OptimizationLevel) -> Self {
    match level {
        OptimizationLevel::Aggressive => Self::new_aggressive(),
        OptimizationLevel::Basic => {
            let mut preset = Self::new();
            let required = [
                OptimizationPass::DeadCodeElimination,
                OptimizationPass::ConstantFolding,
            ];
            for pass in required.iter().copied() {
                preset.enable_pass(pass);
            }
            preset
        }
        OptimizationLevel::None => {
            let enabled_passes = HashSet::new();
            Self { enabled_passes }
        }
    }
}
}
/// Coarse optimization presets accepted by `ShaderOptimizer::get_level_preset`.
#[derive(Debug, Clone, Copy)]
pub enum OptimizationLevel {
/// No pass is enabled.
None,
/// The default pass set of `ShaderOptimizer::new` (dead code elimination
/// and constant folding).
Basic,
/// The pass set of `ShaderOptimizer::new_aggressive` (every pass).
Aggressive,
}
impl Default for ShaderOptimizer {
fn default() -> Self {
Self::new()
}
}
/// Counters describing what an optimization run achieved.
#[derive(Debug, Clone, Default)]
pub struct OptimizationMetrics {
    /// Number of instructions removed.
    pub instructions_removed: usize,
    /// Number of constant expressions folded.
    pub constants_folded: usize,
    /// Number of loops unrolled.
    pub loops_unrolled: usize,
    /// Number of common subexpressions eliminated.
    pub cse_eliminated: usize,
    /// Register pressure reduction as a fraction; `print` multiplies it by
    /// 100 to display a percentage.
    pub register_pressure_reduction: f32,
}

impl OptimizationMetrics {
    /// Creates a metrics record with every counter at its default (zero).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sum of the four integer counters. `register_pressure_reduction` is
    /// not part of the total.
    pub fn total_optimizations(&self) -> usize {
        let counters = [
            self.instructions_removed,
            self.constants_folded,
            self.loops_unrolled,
            self.cse_eliminated,
        ];
        counters.iter().sum()
    }

    /// Writes a human-readable summary of all metrics to standard output.
    pub fn print(&self) {
        let pressure_percent = self.register_pressure_reduction * 100.0;
        let total = self.total_optimizations();

        println!("\nOptimization Metrics:");
        println!(" Instructions removed: {}", self.instructions_removed);
        println!(" Constants folded: {}", self.constants_folded);
        println!(" Loops unrolled: {}", self.loops_unrolled);
        println!(" CSE eliminated: {}", self.cse_eliminated);
        println!(" Register pressure reduction: {:.1}%", pressure_percent);
        println!(" Total optimizations: {}", total);
    }
}
/// Tunable parameters for the optimizer.
///
/// NOTE(review): none of the optimizer code visible in this file reads
/// these fields — confirm where they are consumed.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Upper bound on loop unroll iterations (default `4`).
    pub max_unroll_iterations: usize,
    /// Whether aggressive inlining is requested (default `false`).
    pub aggressive_inlining: bool,
    /// Optional register-count target (default `None`).
    pub target_register_count: Option<usize>,
    /// Whether vectorization is requested (default `true`).
    pub vectorization: bool,
}

impl Default for OptimizationConfig {
    /// Returns the default configuration: unroll limit 4, no aggressive
    /// inlining, no register target, vectorization on.
    fn default() -> Self {
        let max_unroll_iterations = 4;
        let target_register_count = None;
        OptimizationConfig {
            max_unroll_iterations,
            aggressive_inlining: false,
            target_register_count,
            vectorization: true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default optimizer enables the two baseline passes.
    #[test]
    fn test_optimizer_creation() {
        let optimizer = ShaderOptimizer::new();
        let baseline = [
            OptimizationPass::DeadCodeElimination,
            OptimizationPass::ConstantFolding,
        ];
        for pass in baseline.iter().copied() {
            assert!(optimizer.is_pass_enabled(pass));
        }
    }

    /// The aggressive optimizer enables the heavier passes as well.
    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = ShaderOptimizer::new_aggressive();
        let heavy = [
            OptimizationPass::LoopUnrolling,
            OptimizationPass::CommonSubexpressionElimination,
        ];
        for pass in heavy.iter().copied() {
            assert!(optimizer.is_pass_enabled(pass));
        }
    }

    /// Passes can be toggled individually after construction.
    #[test]
    fn test_pass_enable_disable() {
        let mut optimizer = ShaderOptimizer::new();

        optimizer.disable_pass(OptimizationPass::DeadCodeElimination);
        let dce_enabled = optimizer.is_pass_enabled(OptimizationPass::DeadCodeElimination);
        assert!(!dce_enabled);

        optimizer.enable_pass(OptimizationPass::LoopUnrolling);
        let unroll_enabled = optimizer.is_pass_enabled(OptimizationPass::LoopUnrolling);
        assert!(unroll_enabled);
    }

    /// The total is the sum of the four integer counters.
    #[test]
    fn test_optimization_metrics() {
        let metrics = OptimizationMetrics {
            instructions_removed: 10,
            constants_folded: 5,
            loops_unrolled: 2,
            cse_eliminated: 3,
            register_pressure_reduction: 0.15,
        };
        let total = metrics.total_optimizations();
        assert_eq!(total, 20);
    }
}