use std::collections::{HashMap, hash_map::Entry};
use crate::{
Add, AddressingMode, CanAllocate, CanVisitInstructions as _, Div,
DropHighest, DropLowest, Exp, Function, Instruction, InstructionVisitor,
Mod, Mul, Neg, RegisterIndex, Return, RollCustomDice, RollRange,
RollStandardDice, Sub, SumRollingRecord
};
use super::Optimizer;
/// Optimizer that eliminates common subexpressions from a [`Function`].
///
/// Within one pass, every rolling-record sum, arithmetic instruction, and
/// negation is keyed by its scrubbed form (destination replaced by a
/// sentinel). The first instruction with a given key is emitted into a freshly
/// allocated register. Any later instruction with the same key is dropped, and
/// its destination is redirected to that earlier register. Passes repeat until
/// the instruction stream stops changing.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CommonSubexpressionEliminator
{
/// Canonical register for each scrubbed instruction seen so far in the
/// current pass.
originals: HashMap<Instruction, RegisterIndex>,
/// Maps each original destination operand to the canonical register that
/// replaces it in the rewritten instruction stream.
replacements: HashMap<AddressingMode, AddressingMode>,
/// Next register to hand out when a not-yet-seen instruction is encountered.
next_register: RegisterIndex,
/// The rewritten instruction stream produced by the current pass.
instructions: Vec<Instruction>
}
impl Optimizer<()> for CommonSubexpressionEliminator
{
    /// Eliminate common subexpressions from `function`, repeating whole passes
    /// until a pass leaves the instruction stream unchanged.
    ///
    /// Each pass hands out fresh registers starting immediately after the
    /// parameter and external registers, and drops duplicate computations. When
    /// a pass changes anything, the function's instructions and
    /// `register_count` are updated and another pass runs.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    ///
    /// # Panics
    ///
    /// Panics if visiting an instruction reports an error, which the visitor
    /// does not do.
    fn optimize(mut self, mut function: Function) -> Result<Function, ()>
    {
        // Registers below this index are not handed out by the pass;
        // presumably they hold the parameters and externals.
        let start_register =
            function.parameters.len() + function.externals.len();
        loop
        {
            self.originals.clear();
            self.replacements.clear();
            self.next_register = RegisterIndex(start_register);
            self.instructions.clear();
            for instruction in &function.instructions
            {
                instruction
                    .visit(&mut self)
                    .expect("the CSE visitor never reports an error");
            }
            if function.instructions == self.instructions
            {
                // Fixed point reached: nothing left to merge or renumber.
                return Ok(function);
            }
            function.register_count = self.next_register.0;
            // Move the rewritten stream rather than deep-cloning it; the
            // buffer is reset at the top of the next pass anyway.
            function.instructions = std::mem::take(&mut self.instructions);
        }
    }
}
impl InstructionVisitor<()> for CommonSubexpressionEliminator
{
    // Roll and drop instructions are always re-emitted and never merged; only
    // their operands are rewritten. Their destinations are copied verbatim
    // (presumably rolling-record indices, which this pass does not renumber).

    /// Re-emit a range roll with its bounds rewritten.
    fn visit_roll_range(&mut self, range: &crate::RollRange) -> Result<(), ()>
    {
        let start = self.replacement(range.start);
        let end = self.replacement(range.end);
        self.emit(RollRange { dest: range.dest, start, end });
        Ok(())
    }

    /// Re-emit a standard dice roll with its count and faces rewritten.
    fn visit_roll_standard_dice(
        &mut self,
        roll: &RollStandardDice
    ) -> Result<(), ()>
    {
        let count = self.replacement(roll.count);
        let faces = self.replacement(roll.faces);
        self.emit(RollStandardDice { dest: roll.dest, count, faces });
        Ok(())
    }

    /// Re-emit a custom dice roll with its count rewritten; the face list is
    /// copied as-is.
    fn visit_roll_custom_dice(
        &mut self,
        roll: &RollCustomDice
    ) -> Result<(), ()>
    {
        let count = self.replacement(roll.count);
        let faces = roll.faces.clone();
        self.emit(RollCustomDice { dest: roll.dest, count, faces });
        Ok(())
    }

    /// Re-emit a drop-lowest with its count rewritten.
    fn visit_drop_lowest(&mut self, drop: &DropLowest) -> Result<(), ()>
    {
        let count = self.replacement(drop.count);
        self.emit(DropLowest { dest: drop.dest, count });
        Ok(())
    }

    /// Re-emit a drop-highest with its count rewritten.
    fn visit_drop_highest(&mut self, drop: &DropHighest) -> Result<(), ()>
    {
        let count = self.replacement(drop.count);
        self.emit(DropHighest { dest: drop.dest, count });
        Ok(())
    }

    /// Emit a rolling-record sum into its canonical register, unless an
    /// identical sum was already emitted in this pass.
    fn visit_sum_rolling_record(
        &mut self,
        sum: &SumRollingRecord
    ) -> Result<(), ()>
    {
        let (dest, is_new) = self.canonicalize(*sum);
        if is_new
        {
            self.emit(SumRollingRecord { dest, src: sum.src });
        }
        Ok(())
    }

    // Binary arithmetic: the operands are copied out up front, and the shared
    // helper decides whether the instruction is new and emits it if so.

    fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Add { dest, op1, op2 }
        })
    }

    fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Sub { dest, op1, op2 }
        })
    }

    fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Mul { dest, op1, op2 }
        })
    }

    fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Div { dest, op1, op2 }
        })
    }

    fn visit_mod(&mut self, inst: &Mod) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Mod { dest, op1, op2 }
        })
    }

    fn visit_exp(&mut self, inst: &Exp) -> Result<(), ()>
    {
        let (lhs, rhs) = (inst.op1, inst.op2);
        self.canonicalize_binary_op(*inst, move || (lhs, rhs), |dest, op1, op2| {
            Exp { dest, op1, op2 }
        })
    }

    /// Emit a negation into its canonical register, unless an identical
    /// negation was already emitted in this pass. The operand is rewritten
    /// only after canonicalization, matching the binary-op helper's order.
    fn visit_neg(&mut self, inst: &Neg) -> Result<(), ()>
    {
        let (dest, is_new) = self.canonicalize(*inst);
        if is_new
        {
            let op = self.replacement(inst.op);
            self.emit(Neg { dest, op });
        }
        Ok(())
    }

    /// Re-emit a return with its source operand rewritten.
    fn visit_return(&mut self, inst: &Return) -> Result<(), ()>
    {
        let src = self.replacement(inst.src);
        self.emit(Return { src });
        Ok(())
    }
}
impl CommonSubexpressionEliminator
{
    /// Answer the operand to use in place of `op` in the rewritten stream.
    ///
    /// This is the canonical register recorded for `op` earlier in the current
    /// pass, or `op` itself if nothing was recorded for it.
    fn replacement(&self, op: impl Into<AddressingMode>) -> AddressingMode
    {
        let op: AddressingMode = op.into();
        self.replacements.get(&op).copied().unwrap_or(op)
    }

    /// Assign `inst` its canonical destination register for this pass.
    ///
    /// `inst` is keyed by its scrubbed form. If the key was already seen, the
    /// earlier register is reused. Otherwise a fresh register is allocated and
    /// remembered for the key. In both cases, `inst`'s original destination is
    /// mapped to the canonical register so that later operands are rewritten.
    ///
    /// Returns `(register, true)` when `inst` is new and must be emitted, or
    /// `(register, false)` when it duplicates an earlier instruction and should
    /// be dropped.
    ///
    /// # Panics
    ///
    /// Panics if `inst` has no destination.
    fn canonicalize(
        &mut self,
        inst: impl Into<Instruction>
    ) -> (RegisterIndex, bool)
    {
        let inst = inst.into();
        let original = inst
            .destination()
            .expect("only instructions with a destination are canonicalized");
        let (dest, is_new) = match self.originals.entry(inst.scrub())
        {
            Entry::Occupied(entry) => (*entry.get(), false),
            Entry::Vacant(entry) =>
            {
                (*entry.insert(self.next_register.allocate()), true)
            }
        };
        self.replacements.insert(original, dest.into());
        (dest, is_new)
    }

    /// Canonicalize a binary arithmetic instruction. If it is not a duplicate,
    /// emit it with both operands rewritten.
    ///
    /// `extractor` answers the instruction's original `(op1, op2)`.
    /// `constructor` rebuilds the instruction from a destination register and
    /// two operands.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` only mirrors the visitor signature.
    fn canonicalize_binary_op<I>(
        &mut self,
        inst: impl Into<Instruction>,
        extractor: impl Fn() -> (AddressingMode, AddressingMode),
        constructor: impl Fn(RegisterIndex, AddressingMode, AddressingMode) -> I
    ) -> Result<(), ()>
    where
        I: Into<Instruction>
    {
        let (dest, is_new) = self.canonicalize(inst);
        if is_new
        {
            let (op1, op2) = extractor();
            // Reuse the shared lookup rather than duplicating it here.
            let op1 = self.replacement(op1);
            let op2 = self.replacement(op2);
            self.emit(constructor(dest, op1, op2));
        }
        Ok(())
    }

    /// Append `inst` to the rewritten instruction stream.
    fn emit(&mut self, inst: impl Into<Instruction>)
    {
        self.instructions.push(inst.into());
    }
}
trait CanScrub
{
fn scrub(&self) -> Self;
}
impl CanScrub for RollRange
{
fn scrub(&self) -> Self
{
RollRange {
dest: self.dest,
start: self.start,
end: self.end
}
}
}
impl CanScrub for RollStandardDice
{
fn scrub(&self) -> Self
{
RollStandardDice {
dest: self.dest,
count: self.count,
faces: self.faces
}
}
}
impl CanScrub for RollCustomDice
{
fn scrub(&self) -> Self
{
RollCustomDice {
dest: self.dest,
count: self.count,
faces: self.faces.clone()
}
}
}
impl CanScrub for DropLowest
{
fn scrub(&self) -> Self
{
DropLowest {
dest: self.dest,
count: self.count
}
}
}
impl CanScrub for DropHighest
{
fn scrub(&self) -> Self
{
DropHighest {
dest: self.dest,
count: self.count
}
}
}
// Value-producing instructions replace their destination with the sentinel
// `SCRUBBED_DEST`. Two computations of the same value then yield the same key
// regardless of which register they originally wrote. Every other field is
// carried over unchanged via struct-update syntax.

impl CanScrub for SumRollingRecord
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Add
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Sub
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Mul
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Div
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Mod
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Exp
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

impl CanScrub for Neg
{
    fn scrub(&self) -> Self { Self { dest: SCRUBBED_DEST, ..*self } }
}

// `Return` has no destination, so its key is simply a copy of itself.
impl CanScrub for Return
{
    fn scrub(&self) -> Self
    {
        let src = self.src;
        Self { src }
    }
}
/// Scrub whichever concrete instruction this wraps, re-wrapping the result.
impl CanScrub for Instruction
{
    fn scrub(&self) -> Self
    {
        match self
        {
            // Rolls and drops: destination preserved.
            Self::RollRange(x) => x.scrub().into(),
            Self::RollStandardDice(x) => x.scrub().into(),
            Self::RollCustomDice(x) => x.scrub().into(),
            Self::DropLowest(x) => x.scrub().into(),
            Self::DropHighest(x) => x.scrub().into(),
            // Value producers: destination replaced by `SCRUBBED_DEST`.
            Self::SumRollingRecord(x) => x.scrub().into(),
            Self::Add(x) => x.scrub().into(),
            Self::Sub(x) => x.scrub().into(),
            Self::Mul(x) => x.scrub().into(),
            Self::Div(x) => x.scrub().into(),
            Self::Mod(x) => x.scrub().into(),
            Self::Exp(x) => x.scrub().into(),
            Self::Neg(x) => x.scrub().into(),
            // No destination at all.
            Self::Return(x) => x.scrub().into()
        }
    }
}

/// Sentinel destination written into scrubbed value-producing instructions so
/// that their original destination does not affect key equality. `usize::MAX`
/// is used on the assumption that no real function allocates that many
/// registers.
const SCRUBBED_DEST: RegisterIndex = RegisterIndex(usize::MAX);