mod coalescence;
mod commutation;
mod cse;
mod dce;
mod folding;
mod reduction;
use std::ops::BitOr;
use coalescence::RegisterCoalescer;
use commutation::ConstantCommuter;
use cse::CommonSubexpressionEliminator;
use dce::DeadCodeEliminator;
use folding::ConstantFolder;
use reduction::StrengthReducer;
use crate::Function;
/// A stage that consumes a [`Function`] and produces a rewritten one.
///
/// The type parameter `E` is the error type the stage may report.
pub trait Optimizer<E>
{
	/// Rewrites `function`, consuming the optimizer in the process.
	///
	/// # Errors
	///
	/// Returns `Err(E)` when the implementation reports a failure; which
	/// conditions cause that is up to each implementor.
	fn optimize(self, function: Function) -> Result<Function, E>;
}
#[cfg_attr(doc, aquamarine::aquamarine)]
/// The standard optimization pipeline, configured by a set of [`Passes`].
///
/// Each enabled pass among common subexpression elimination, constant
/// commuting, constant folding, strength reduction and dead code elimination
/// runs in that order, round after round, until a round leaves the function
/// unchanged. Register coalescing, if enabled, then runs exactly once on the
/// result. With no passes enabled the function is returned untouched.
///
/// ```mermaid
/// flowchart TD
///     input([function]) --> enabled{any passes enabled?}
///     enabled -- no --> untouched([return function unchanged])
///     enabled -- yes --> cse[common subexpression elimination]
///     cse --> commute[constant commuting]
///     commute --> fold[constant folding]
///     fold --> reduce[strength reduction]
///     reduce --> dce[dead code elimination]
///     dce --> changed{changed this round?}
///     changed -- yes --> cse
///     changed -- no --> coalesce[register coalescing]
///     coalesce --> result([return optimized function])
/// ```
///
/// Every stage in the diagram is skipped when its pass is not enabled.
pub struct StandardOptimizer(Passes);

impl StandardOptimizer
{
	/// Creates an optimizer that will run exactly the given `passes`.
	#[inline]
	pub fn new(passes: Passes) -> Self
	{
		StandardOptimizer(passes)
	}
}
impl StandardOptimizer
{
	/// Applies one round of the iterative passes to `function`.
	///
	/// Each enabled pass runs once, in this fixed order: common
	/// subexpression elimination, constant commuting, constant folding,
	/// strength reduction, dead code elimination. Register coalescing is
	/// deliberately excluded; it only runs after the rounds converge.
	///
	/// # Panics
	///
	/// Panics if an enabled pass's `optimize` does not succeed, because its
	/// result is `unwrap`ped.
	fn run_round(&self, function: Function) -> Function
	{
		let enabled = self.0;
		let mut current = function;
		if enabled.contains(Pass::CommonSubexpressionElimination)
		{
			current = CommonSubexpressionEliminator::default()
				.optimize(current)
				.unwrap();
		}
		if enabled.contains(Pass::ConstantCommuting)
		{
			current.instructions =
				ConstantCommuter::commute(&current.instructions);
		}
		if enabled.contains(Pass::ConstantFolding)
		{
			current = ConstantFolder::default().optimize(current).unwrap();
		}
		if enabled.contains(Pass::StrengthReduction)
		{
			current = StrengthReducer::default().optimize(current).unwrap();
		}
		if enabled.contains(Pass::DeadCodeElimination)
		{
			current = DeadCodeEliminator::default().optimize(current).unwrap();
		}
		current
	}
}

impl Optimizer<()> for StandardOptimizer
{
	/// Runs the enabled passes to a fixed point, then coalesces registers.
	///
	/// With no passes enabled, `function` is returned as is. Otherwise
	/// rounds of [`StandardOptimizer::run_round`] are applied until a round
	/// produces a function equal to its input. If register coalescing is
	/// enabled, it then runs once on that result.
	///
	/// This never returns `Err`: a failing pass panics instead (see
	/// `run_round`). NOTE(review): if the passes never converge, the loop
	/// below never terminates.
	fn optimize(self, function: Function) -> Result<Function, ()>
	{
		if self.0.is_empty()
		{
			return Ok(function);
		}
		let mut current = function;
		loop
		{
			// Work on a copy so the round's output can be compared with its
			// input.
			let next = self.run_round(current.clone());
			if next == current
			{
				// Fixed point reached: coalesce registers once, at the end.
				let finished = if self.0.contains(Pass::RegisterCoalescing)
				{
					RegisterCoalescer::default().optimize(next).unwrap()
				}
				else
				{
					next
				};
				return Ok(finished);
			}
			current = next;
		}
	}
}
/// A single optimization pass that a [`StandardOptimizer`] can run.
///
/// Every discriminant is a distinct bit, so passes combine into a
/// [`Passes`] set with `|`. Bits `1 << 5` and `1 << 6` are currently
/// unassigned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Pass
{
	/// Common subexpression elimination (iterated until a fixed point).
	CommonSubexpressionElimination = 1 << 0,
	/// Constant commuting (iterated until a fixed point).
	ConstantCommuting = 1 << 1,
	/// Constant folding (iterated until a fixed point).
	ConstantFolding = 1 << 2,
	/// Strength reduction (iterated until a fixed point).
	StrengthReduction = 1 << 3,
	/// Dead code elimination (iterated until a fixed point).
	DeadCodeElimination = 1 << 4,
	/// Register coalescing (run once, after the other passes converge).
	RegisterCoalescing = 1 << 7,
}
/// Combines two individual passes into a [`Passes`] set holding both.
impl BitOr<Pass> for Pass
{
	type Output = Passes;

	fn bitor(self, rhs: Pass) -> Self::Output
	{
		let bits = (self as u8) | (rhs as u8);
		Passes(bits)
	}
}
/// Adds a single pass to an existing [`Passes`] set, yielding a new set.
impl BitOr<Passes> for Pass
{
	type Output = Passes;

	fn bitor(self, rhs: Passes) -> Self::Output
	{
		Passes(rhs.0 | self as u8)
	}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Passes(u8);
impl Passes
{
#[inline]
pub fn all() -> Self { Self(0xFF) }
#[inline]
pub fn is_empty(&self) -> bool { self.0 == 0 }
#[inline]
pub fn contains(&self, pass: Pass) -> bool { self.0 & (pass as u8) != 0 }
}
impl std::ops::BitOr<Pass> for Passes
{
type Output = Self;
fn bitor(self, rhs: Pass) -> Self::Output { Self(self.0 | rhs as u8) }
}
impl std::ops::BitOr<Passes> for Passes
{
type Output = Self;
fn bitor(self, rhs: Passes) -> Self::Output { Self(self.0 | rhs.0) }
}
impl std::ops::BitOrAssign<Pass> for Passes
{
fn bitor_assign(&mut self, rhs: Pass) { self.0 |= rhs as u8; }
}
impl From<Pass> for Passes
{
fn from(pass: Pass) -> Self { Self(pass as u8) }
}
#[cfg(test)]
mod tests
{
	use pretty_assertions::assert_eq;
	use crate::{
		Pass, Passes,
		support::{compile_valid, optimize}
	};

	/// Every [`Pass`] variant, used to check set membership exhaustively.
	const EVERY_PASS: [Pass; 6] = [
		Pass::CommonSubexpressionElimination,
		Pass::ConstantCommuting,
		Pass::ConstantFolding,
		Pass::StrengthReduction,
		Pass::DeadCodeElimination,
		Pass::RegisterCoalescing
	];

	/// Asserts that `passes` contains every pass in `expected` and no other.
	fn assert_contains_exactly(passes: Passes, expected: &[Pass])
	{
		for &pass in EVERY_PASS.iter()
		{
			assert_eq!(
				passes.contains(pass),
				expected.contains(&pass),
				"membership of {:?}",
				pass
			);
		}
	}

	#[test]
	fn test_passes()
	{
		// `all` enables every pass.
		assert_contains_exactly(Passes::all(), &EVERY_PASS);

		// A set built from one pass contains only that pass.
		assert_contains_exactly(
			Passes::from(Pass::CommonSubexpressionElimination),
			&[Pass::CommonSubexpressionElimination]
		);

		// Each `|` overload yields the same two-pass set.
		let pair = [Pass::CommonSubexpressionElimination, Pass::ConstantCommuting];
		assert_contains_exactly(
			Pass::CommonSubexpressionElimination | Pass::ConstantCommuting,
			&pair
		);
		assert_contains_exactly(
			Passes::from(Pass::ConstantCommuting)
				| Pass::CommonSubexpressionElimination,
			&pair
		);
		assert_contains_exactly(
			Pass::CommonSubexpressionElimination
				| Passes::from(Pass::ConstantCommuting),
			&pair
		);

		// Adding every pass to the empty set yields the full set.
		let accumulated = EVERY_PASS
			.iter()
			.fold(Passes::default(), |set, &pass| set | pass);
		assert_contains_exactly(accumulated, &EVERY_PASS);
	}

	#[test]
	fn test_no_passes()
	{
		let passes = Passes::default();
		assert!(passes.is_empty());
		let function = compile_valid("x: 3D6 + {x}");
		let optimized = optimize(function.clone(), passes);
		assert_eq!(function, optimized);
	}
}