use std::collections::HashMap;
use crate::{
Add, AddressingMode, CanAllocate as _, CanVisitInstructions as _, Div,
DropHighest, DropLowest, Exp, Function, Immediate, Instruction,
InstructionVisitor, Mod, Mul, Neg, ProgramCounter, RegisterIndex, Return,
RollCustomDice, RollRange, RollStandardDice, RollingRecordIndex, Sub,
SumRollingRecord
};
use crate::Optimizer;
/// An optimizer pass that visits every instruction of a function and
/// re-emits either the instruction itself or a cheaper equivalent, renumbering
/// registers and rolling records as it goes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StrengthReducer
{
/// Maps an operand as it appears in the input instructions to the operand
/// that must be used in its place in the output instructions.
replacements: HashMap<AddressingMode, AddressingMode>,
/// The next register index to hand out to an emitted instruction.
next_register: RegisterIndex,
/// The next rolling record index to hand out to an emitted instruction.
next_rolling_record: RollingRecordIndex,
/// The instructions emitted so far during the current pass.
instructions: Vec<Instruction>
}
/// Runs strength reduction over a [`Function`] repeatedly until a pass
/// produces exactly the instructions it was given.
impl Optimizer<()> for StrengthReducer
{
	/// Optimize `function`, returning the rewritten function.
	///
	/// Each pass visits every instruction, collecting the rewritten stream in
	/// `self.instructions`. When a pass reproduces its input verbatim the
	/// function is returned; otherwise the rewritten stream (and the register
	/// and rolling record counts that go with it) is installed into the
	/// function and another pass is run.
	///
	/// # Errors
	///
	/// Propagates any error reported by a visitor method.
	fn optimize(mut self, mut function: Function) -> Result<Function, ()>
	{
		// Registers handed out by a pass are numbered from the combined
		// count of parameters and externals.
		let start_register =
			function.parameters.len() + function.externals.len();
		self.next_register = RegisterIndex(start_register);
		loop
		{
			for instruction in &function.instructions
			{
				// Propagate a visitor failure to the caller instead of
				// panicking inside the optimizer.
				instruction.visit(&mut self)?;
			}
			if function.instructions == self.instructions
			{
				// Fixed point: nothing changed during this pass.
				return Ok(function)
			}
			function.register_count = self.next_register.0;
			function.rolling_record_count = self.next_rolling_record.0;
			// Move the emitted instructions into the function rather than
			// cloning them; this also leaves `self.instructions` empty for
			// the next pass.
			function.instructions = std::mem::take(&mut self.instructions);
			self.replacements.clear();
			self.next_register = RegisterIndex(start_register);
			self.next_rolling_record = RollingRecordIndex::default();
		}
	}
}
impl InstructionVisitor<()> for StrengthReducer
{
/// Reduce a range roll.
///
/// When both bounds resolve to the same operand, the rolling record is
/// replaced by that operand and nothing is emitted. Otherwise the roll is
/// re-emitted with a freshly numbered rolling record and resolved bounds.
fn visit_roll_range(&mut self, range: &RollRange) -> Result<(), ()>
{
	// Resolve the bounds before using them. Recording the raw input operand
	// as a replacement would leak a register name from the input numbering
	// into the output, because replacements are looked up only one level
	// deep.
	let start = self.replacement(range.start);
	let end = self.replacement(range.end);
	if start == end
	{
		self.replace(range.dest, start);
		return Ok(());
	}
	let dest = self.next_rolling_record();
	self.replace(range.dest, dest);
	self.emit(RollRange { dest, start, end });
	Ok(())
}
/// Reduce a roll of standard dice.
///
/// * A die with exactly one face is replaced by the (resolved) dice count,
///   and nothing is emitted.
/// * A single die is emitted as a range roll from 1 to the face count.
/// * Anything else is re-emitted with a freshly numbered rolling record and
///   resolved operands.
fn visit_roll_standard_dice(
	&mut self,
	roll: &RollStandardDice
) -> Result<(), ()>
{
	if let AddressingMode::Immediate(Immediate(1)) = roll.faces
	{
		// Resolve the count first: it may name a register from the input
		// numbering that has since been renumbered or folded away, and
		// replacements are looked up only one level deep.
		let count = self.replacement(roll.count);
		self.replace(roll.dest, count);
		return Ok(());
	}
	if let AddressingMode::Immediate(Immediate(1)) = roll.count
	{
		let dest = self.next_rolling_record();
		self.replace(roll.dest, dest);
		self.emit(RollRange {
			dest,
			start: Immediate(1).into(),
			end: self.replacement(roll.faces)
		});
		return Ok(());
	}
	let dest = self.next_rolling_record();
	self.replace(roll.dest, dest);
	self.emit(RollStandardDice {
		dest,
		count: self.replacement(roll.count),
		faces: self.replacement(roll.faces)
	});
	Ok(())
}
/// Reduce a roll of custom dice.
///
/// * Dice with a single distinct face become a multiplication of the dice
///   count by that face.
/// * A single die whose sorted faces are consecutive becomes a range roll.
/// * Dice whose sorted faces are consecutive starting at 1 become standard
///   dice.
/// * Anything else is re-emitted with a freshly numbered rolling record.
fn visit_roll_custom_dice(
	&mut self,
	roll: &RollCustomDice
) -> Result<(), ()>
{
	if roll.distinct_faces() == 1
	{
		let product = self.next_register();
		self.replace(roll.dest, product);
		let count = self.replacement(roll.count);
		self.emit(Mul {
			dest: product,
			op1: count,
			op2: Immediate(roll.faces[0]).into()
		});
		return Ok(());
	}

	// Work on a sorted copy so contiguity can be tested pairwise.
	let mut sorted = roll.faces.to_vec();
	sorted.sort();
	let lowest = sorted[0];
	let highest = sorted[sorted.len() - 1];
	let contiguous =
		sorted.windows(2).all(|pair| pair[1] == pair[0] + 1);
	let single_die =
		matches!(roll.count, AddressingMode::Immediate(Immediate(1)));

	// Every remaining form produces a rolling record.
	let dest = self.next_rolling_record();
	self.replace(roll.dest, dest);
	if single_die && contiguous
	{
		self.emit(RollRange {
			dest,
			start: Immediate(lowest).into(),
			end: Immediate(highest).into()
		});
	}
	else if contiguous && lowest == 1
	{
		let count = self.replacement(roll.count);
		self.emit(RollStandardDice {
			dest,
			count,
			faces: Immediate(highest).into()
		});
	}
	else
	{
		let count = self.replacement(roll.count);
		self.emit(RollCustomDice {
			dest,
			count,
			faces: roll.faces.clone()
		});
	}
	Ok(())
}
/// Reduce a "drop lowest" instruction.
///
/// * If the target has already been reduced to a register or immediate, the
///   drop becomes a subtraction of the drop count from that value.
/// * Dropping an immediate zero emits nothing.
/// * If a "drop lowest" on the same rolling record was already emitted, it is
///   removed and replaced by a single drop of the summed counts.
/// * Otherwise the drop is re-emitted against the renumbered rolling record.
fn visit_drop_lowest(&mut self, drop: &DropLowest) -> Result<(), ()>
{
	let record = match self.replacement(drop.dest)
	{
		AddressingMode::RollingRecord(record) => record,
		scalar @ (AddressingMode::Register(_)
		| AddressingMode::Immediate(_)) =>
		{
			let difference = self.next_register();
			let dropped = self.replacement(drop.count);
			self.emit(Sub {
				dest: difference,
				op1: scalar,
				op2: dropped
			});
			self.replace(drop.dest, difference);
			return Ok(())
		}
	};

	let count = self.replacement(drop.count);
	if matches!(count, AddressingMode::Immediate(Immediate(0)))
	{
		self.replace(drop.dest, record);
		return Ok(())
	}

	// Look for the most recently emitted drop of the same kind on this
	// record.
	let earlier = self.find_drop_instruction(|candidate| {
		DropLowest::try_from(candidate.clone())
			.is_ok_and(|other| other.dest == record)
	});
	match earlier
	{
		Some(pc) =>
		{
			let merged = self.instructions.remove(pc.0);
			let earlier_count = *merged.sources().last().unwrap();
			let total = self.next_register();
			self.emit(Add {
				dest: total,
				op1: earlier_count,
				op2: count
			});
			self.emit(DropLowest {
				dest: record,
				count: total.into()
			});
		},
		None => self.emit(DropLowest {
			dest: record,
			count
		})
	}
	Ok(())
}
/// Reduce a "drop highest" instruction.
///
/// * If the target has already been reduced to a register or immediate, the
///   drop becomes a subtraction of the drop count from that value.
/// * Dropping an immediate zero emits nothing.
/// * If a "drop highest" on the same rolling record was already emitted, it
///   is removed and replaced by a single drop of the summed counts.
/// * Otherwise the drop is re-emitted against the renumbered rolling record.
fn visit_drop_highest(&mut self, drop: &DropHighest) -> Result<(), ()>
{
	let record = match self.replacement(drop.dest)
	{
		AddressingMode::RollingRecord(record) => record,
		scalar @ (AddressingMode::Register(_)
		| AddressingMode::Immediate(_)) =>
		{
			let difference = self.next_register();
			let dropped = self.replacement(drop.count);
			self.emit(Sub {
				dest: difference,
				op1: scalar,
				op2: dropped
			});
			self.replace(drop.dest, difference);
			return Ok(())
		}
	};

	let count = self.replacement(drop.count);
	if matches!(count, AddressingMode::Immediate(Immediate(0)))
	{
		self.replace(drop.dest, record);
		return Ok(())
	}

	// Look for the most recently emitted drop of the same kind on this
	// record.
	let earlier = self.find_drop_instruction(|candidate| {
		DropHighest::try_from(candidate.clone())
			.is_ok_and(|other| other.dest == record)
	});
	match earlier
	{
		Some(pc) =>
		{
			let merged = self.instructions.remove(pc.0);
			let earlier_count = *merged.sources().last().unwrap();
			let total = self.next_register();
			self.emit(Add {
				dest: total,
				op1: earlier_count,
				op2: count
			});
			self.emit(DropHighest {
				dest: record,
				count: total.into()
			});
		},
		None => self.emit(DropHighest {
			dest: record,
			count
		})
	}
	Ok(())
}
/// Reduce the summation of a rolling record.
///
/// If the source has already been reduced to an immediate or a register, the
/// sum is simply that value and nothing is emitted. Otherwise the sum is
/// re-emitted into a freshly numbered register.
fn visit_sum_rolling_record(
	&mut self,
	sum: &SumRollingRecord
) -> Result<(), ()>
{
	let source = self.replacement(sum.src);
	match source
	{
		AddressingMode::Immediate(constant) =>
		{
			self.replace(sum.dest, constant)
		},
		AddressingMode::Register(register) =>
		{
			self.replace(sum.dest, register)
		},
		AddressingMode::RollingRecord(_) =>
		{
			let total = self.next_register();
			self.replace(sum.dest, total);
			self.emit(SumRollingRecord {
				dest: total,
				src: source.try_into().unwrap()
			});
		}
	}
	Ok(())
}
fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
{
if let AddressingMode::Immediate(Immediate(0)) = inst.op1
{
self.replace(inst.dest, self.replacement(inst.op2));
return Ok(())
}
if let AddressingMode::Immediate(Immediate(0)) = inst.op2
{
self.replace(inst.dest, self.replacement(inst.op1));
return Ok(())
}
let dest = self.next_register();
self.replace(inst.dest, dest);
self.emit(Add {
dest,
op1: self.replacement(inst.op1),
op2: self.replacement(inst.op2)
});
Ok(())
}
/// Reduce a subtraction. The cases are tried in this order:
///
/// * `0 - x` becomes a negation of `x`.
/// * `x - 0` becomes `x`.
/// * `x - x` (identical operands) becomes the immediate `0`.
/// * Anything else is re-emitted with a freshly numbered destination.
fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
{
	match (inst.op1, inst.op2)
	{
		(AddressingMode::Immediate(Immediate(0)), subtrahend) =>
		{
			let negated = self.next_register();
			self.replace(inst.dest, negated);
			let operand = self.replacement(subtrahend);
			self.emit(Neg {
				dest: negated,
				op: operand
			});
		},
		(minuend, AddressingMode::Immediate(Immediate(0))) =>
		{
			let resolved = self.replacement(minuend);
			self.replace(inst.dest, resolved);
		},
		(lhs, rhs) if lhs == rhs =>
		{
			self.replace(inst.dest, Immediate(0));
		},
		(lhs, rhs) =>
		{
			let difference = self.next_register();
			self.replace(inst.dest, difference);
			let minuend = self.replacement(lhs);
			let subtrahend = self.replacement(rhs);
			self.emit(Sub {
				dest: difference,
				op1: minuend,
				op2: subtrahend
			});
		}
	}
	Ok(())
}
/// Reduce a multiplication. The cases are tried in this order, and within
/// each case the first operand is examined before the second:
///
/// * Multiplying by the immediate `0` becomes the immediate `0`.
/// * Multiplying by the immediate `1` becomes the other operand.
/// * Multiplying by the immediate `-1` becomes a negation of the other
///   operand.
/// * Multiplying by the immediate `2` becomes the other operand added to
///   itself.
/// * Anything else is re-emitted with a freshly numbered destination.
fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
{
	match (inst.op1, inst.op2)
	{
		(AddressingMode::Immediate(Immediate(0)), _)
		| (_, AddressingMode::Immediate(Immediate(0))) =>
		{
			self.replace(inst.dest, Immediate(0));
		},
		(AddressingMode::Immediate(Immediate(1)), other)
		| (other, AddressingMode::Immediate(Immediate(1))) =>
		{
			let resolved = self.replacement(other);
			self.replace(inst.dest, resolved);
		},
		(AddressingMode::Immediate(Immediate(-1)), other)
		| (other, AddressingMode::Immediate(Immediate(-1))) =>
		{
			let negated = self.next_register();
			self.replace(inst.dest, negated);
			let operand = self.replacement(other);
			self.emit(Neg {
				dest: negated,
				op: operand
			});
		},
		(AddressingMode::Immediate(Immediate(2)), other)
		| (other, AddressingMode::Immediate(Immediate(2))) =>
		{
			let doubled = self.next_register();
			self.replace(inst.dest, doubled);
			let operand = self.replacement(other);
			self.emit(Add {
				dest: doubled,
				op1: operand,
				op2: operand
			});
		},
		(lhs, rhs) =>
		{
			let product = self.next_register();
			self.replace(inst.dest, product);
			let multiplicand = self.replacement(lhs);
			let multiplier = self.replacement(rhs);
			self.emit(Mul {
				dest: product,
				op1: multiplicand,
				op2: multiplier
			});
		}
	}
	Ok(())
}
/// Reduce a division. The cases are tried in this order:
///
/// * An immediate `0` as either operand becomes the immediate `0`.
/// * Dividing by the immediate `1` becomes the dividend.
/// * Dividing by the immediate `-1` becomes a negation of the dividend.
/// * Identical operands become the immediate `1`.
/// * Anything else is re-emitted with a freshly numbered destination.
fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
{
	let (dividend, divisor) = (inst.op1, inst.op2);
	let zero_operand =
		matches!(dividend, AddressingMode::Immediate(Immediate(0)))
			|| matches!(divisor, AddressingMode::Immediate(Immediate(0)));
	if zero_operand
	{
		self.replace(inst.dest, Immediate(0));
		return Ok(())
	}
	match divisor
	{
		AddressingMode::Immediate(Immediate(1)) =>
		{
			let resolved = self.replacement(dividend);
			self.replace(inst.dest, resolved);
		},
		AddressingMode::Immediate(Immediate(-1)) =>
		{
			let negated = self.next_register();
			self.replace(inst.dest, negated);
			let operand = self.replacement(dividend);
			self.emit(Neg {
				dest: negated,
				op: operand
			});
		},
		// NOTE(review): this folds `x / x` to 1 even when `x` is a register
		// that could hold zero at run time, while division by an immediate
		// zero is folded to 0 above — confirm the two agree with the
		// evaluator's division-by-zero behaviour.
		_ if dividend == divisor =>
		{
			self.replace(inst.dest, Immediate(1));
		},
		_ =>
		{
			let quotient = self.next_register();
			self.replace(inst.dest, quotient);
			let numerator = self.replacement(dividend);
			let denominator = self.replacement(divisor);
			self.emit(Div {
				dest: quotient,
				op1: numerator,
				op2: denominator
			});
		}
	}
	Ok(())
}
/// Reduce a remainder operation: an immediate `0` dividend, or an immediate
/// `0` or `1` divisor, becomes the immediate `0`; anything else is re-emitted
/// with a freshly numbered destination.
fn visit_mod(&mut self, inst: &Mod) -> Result<(), ()>
{
	let always_zero =
		matches!(inst.op1, AddressingMode::Immediate(Immediate(0)))
			|| matches!(
				inst.op2,
				AddressingMode::Immediate(Immediate(0 | 1))
			);
	if always_zero
	{
		self.replace(inst.dest, Immediate(0));
	}
	else
	{
		let remainder = self.next_register();
		self.replace(inst.dest, remainder);
		let dividend = self.replacement(inst.op1);
		let divisor = self.replacement(inst.op2);
		self.emit(Mod {
			dest: remainder,
			op1: dividend,
			op2: divisor
		});
	}
	Ok(())
}
/// Reduce an exponentiation. The cases are tried in this order:
///
/// * An immediate `0` exponent becomes the immediate `1`.
/// * An immediate `1` exponent becomes the base.
/// * An immediate `1` base becomes the immediate `1`.
/// * An immediate `2` exponent becomes the base multiplied by itself.
/// * Anything else is re-emitted with a freshly numbered destination.
fn visit_exp(&mut self, inst: &Exp) -> Result<(), ()>
{
	match (inst.op1, inst.op2)
	{
		(_, AddressingMode::Immediate(Immediate(0))) =>
		{
			self.replace(inst.dest, Immediate(1));
		},
		(base, AddressingMode::Immediate(Immediate(1))) =>
		{
			let resolved = self.replacement(base);
			self.replace(inst.dest, resolved);
		},
		(AddressingMode::Immediate(Immediate(1)), _) =>
		{
			self.replace(inst.dest, Immediate(1));
		},
		(base, AddressingMode::Immediate(Immediate(2))) =>
		{
			let square = self.next_register();
			self.replace(inst.dest, square);
			let operand = self.replacement(base);
			self.emit(Mul {
				dest: square,
				op1: operand,
				op2: operand
			});
		},
		(base, exponent) =>
		{
			let power = self.next_register();
			self.replace(inst.dest, power);
			let resolved_base = self.replacement(base);
			let resolved_exponent = self.replacement(exponent);
			self.emit(Exp {
				dest: power,
				op1: resolved_base,
				op2: resolved_exponent
			});
		}
	}
	Ok(())
}
fn visit_neg(&mut self, inst: &Neg) -> Result<(), ()>
{
if let Some((pc, src)) = self.find_neg_instruction(inst.op)
{
assert_eq!(pc.0, self.instructions.len() - 1);
self.next_register.0 -= 1;
self.instructions.remove(pc.0);
self.replace(inst.dest, src);
return Ok(())
}
let dest = self.next_register();
self.replace(inst.dest, dest);
self.emit(Neg {
dest,
op: self.replacement(inst.op)
});
Ok(())
}
/// Re-emit a return instruction with its source operand resolved through the
/// replacement table.
fn visit_return(&mut self, inst: &Return) -> Result<(), ()>
{
	let src = self.replacement(inst.src);
	self.emit(Return { src });
	Ok(())
}
}
/// Private helpers shared by the visitor methods.
impl StrengthReducer
{
	/// Allocate the next register index.
	#[inline]
	fn next_register(&mut self) -> RegisterIndex
	{
		self.next_register.allocate()
	}

	/// Allocate the next rolling record index.
	#[inline]
	fn next_rolling_record(&mut self) -> RollingRecordIndex
	{
		self.next_rolling_record.allocate()
	}

	/// Record that `old` must be substituted by `new`, overwriting any
	/// earlier replacement recorded for `old`.
	#[inline]
	fn replace(
		&mut self,
		old: impl Into<AddressingMode>,
		new: impl Into<AddressingMode>
	)
	{
		let (key, value) = (old.into(), new.into());
		self.replacements.insert(key, value);
	}

	/// Answer the replacement recorded for `op`, or `op` itself if there is
	/// none. The lookup is a single step; it does not follow chains.
	#[inline]
	fn replacement(&self, op: impl Into<AddressingMode>) -> AddressingMode
	{
		let key: AddressingMode = op.into();
		self.replacements.get(&key).copied().unwrap_or(key)
	}

	/// Search the emitted instructions, latest first, for a negation whose
	/// destination is the register `op`. Answers its position and its
	/// operand, or `None` if `op` is not a register or no such negation was
	/// emitted.
	fn find_neg_instruction(
		&self,
		op: AddressingMode
	) -> Option<(ProgramCounter, AddressingMode)>
	{
		let AddressingMode::Register(target) = op
		else
		{
			return None
		};
		self.instructions.iter().enumerate().rev().find_map(
			|(index, candidate)| {
				Neg::try_from(candidate.clone())
					.ok()
					.filter(|neg| neg.dest == target)
					.map(|neg| (index.into(), neg.op))
			}
		)
	}

	/// Answer the position of the latest emitted instruction accepted by
	/// `filter`, if any.
	fn find_drop_instruction(
		&self,
		filter: impl Fn(&Instruction) -> bool
	) -> Option<ProgramCounter>
	{
		self.instructions
			.iter()
			.rposition(|candidate| filter(candidate))
			.map(Into::into)
	}

	/// Append an instruction to the output of the current pass.
	#[inline]
	fn emit(&mut self, inst: impl Into<Instruction>)
	{
		let instruction: Instruction = inst.into();
		self.instructions.push(instruction);
	}
}