use std::collections::HashMap;
use crate::{
Add, AddressingMode, CanAllocate, CanVisitInstructions as _, Div,
DropHighest, DropLowest, Exp, Function, Immediate, InstructionVisitor, Mod,
Mul, Neg, ProgramCounter, RegisterIndex, Return, RollCustomDice, RollRange,
RollStandardDice, RollingRecordIndex, Sub, SumRollingRecord,
ir::Instruction
};
use crate::{Optimizer, add, div, exp, r#mod, mul, neg, sub};
/// Constant-folding optimizer for [`Function`]s.
///
/// Each pass walks the instruction stream once. It replaces instructions whose
/// results can be computed from immediate operands with immediate values, and
/// renumbers the registers and rolling records of the surviving instructions.
/// Passes repeat until one leaves the instruction stream unchanged.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConstantFolder
{
	/// Snapshot of the instructions being rewritten by the current pass;
	/// consulted to look ahead for drop instructions that target a roll.
	previous: Option<Vec<Instruction>>,
	/// Maps operands of the pass input to their rewritten form: either a
	/// renumbered register / rolling record or a folded immediate.
	replacements: HashMap<AddressingMode, AddressingMode>,
	/// Index (within `previous`) of the instruction currently being visited.
	pc: ProgramCounter,
	/// Next register to allocate in the rewritten instruction stream.
	next_register: RegisterIndex,
	/// Next rolling record to allocate in the rewritten instruction stream.
	next_rolling_record: RollingRecordIndex,
	/// The rewritten instruction stream produced by the current pass.
	instructions: Vec<Instruction>
}
impl Optimizer<()> for ConstantFolder
{
	/// Constant-folds `function` repeatedly until a pass leaves its
	/// instructions unchanged, then returns the optimized function.
	///
	/// Every pass renumbers registers starting just after the parameter and
	/// external registers, and renumbers rolling records from zero. The
	/// function's register and rolling-record counts are therefore refreshed
	/// after every pass that changed the instructions.
	///
	/// # Errors
	///
	/// Propagates any error reported while visiting an instruction.
	fn optimize(mut self, mut function: Function) -> Result<Function, ()>
	{
		// Parameters and externals occupy the lowest registers and are never
		// renumbered, so freshly allocated registers start right after them.
		let start_register =
			function.parameters.len() + function.externals.len();
		loop
		{
			// Start every pass from a clean slate, including the first one,
			// so no stale state from a previously used folder can leak in.
			self.replacements.clear();
			self.pc = ProgramCounter::default();
			self.next_register = RegisterIndex(start_register);
			self.next_rolling_record = RollingRecordIndex::default();
			self.instructions.clear();
			// Keep a snapshot of the pass input so that the roll visitors can
			// look ahead for the drop instructions that affect them.
			self.previous = Some(function.instructions.clone());
			for instruction in &function.instructions
			{
				instruction.visit(&mut self)?;
				self.pc.allocate();
			}
			if function.instructions == self.instructions
			{
				// Fixed point: another pass would change nothing.
				return Ok(function);
			}
			function.register_count = self.next_register.0;
			function.rolling_record_count = self.next_rolling_record.0;
			// Move the rewritten stream out instead of cloning and clearing it.
			function.instructions = std::mem::take(&mut self.instructions);
		}
	}
}
impl InstructionVisitor<()> for ConstantFolder
{
	/// Folds or rewrites a range roll.
	///
	/// When both bounds are immediates, an empty range folds to `0`. A range
	/// with exactly one outcome folds to that outcome, or to `0` if constant
	/// drops discard it. Anything else is re-emitted against a fresh rolling
	/// record with rewritten operands.
	fn visit_roll_range(&mut self, range: &RollRange) -> Result<(), ()>
	{
		if let Some((start, end)) = range.const_ops()
		{
			match end.saturating_sub(start).saturating_add(1)
			{
				..=0 =>
				{
					self.replace(range.dest, Immediate(0));
					return Ok(())
				},
				1 =>
				{
					// Once the roll is folded to an immediate, the drops that
					// target it are discarded (see `visit_drop_lowest` and
					// `visit_drop_highest`). Their effect must therefore be
					// applied here, treating the range as a single roll, and
					// the fold is only possible when that effect is known.
					let drops =
						self.find_drop_instructions(range.dest).map(drop_count);
					if let Some(kept) = surviving_dice(1, drops)
					{
						let value = if kept == 0 { 0 } else { start };
						self.replace(range.dest, Immediate(value));
						return Ok(())
					}
				},
				_ =>
				{}
			}
		}
		let dest = self.next_rolling_record();
		self.replace(range.dest, dest);
		self.emit(RollRange {
			dest,
			start: self.replacement(range.start),
			end: self.replacement(range.end)
		});
		Ok(())
	}

	/// Folds or rewrites a standard dice roll.
	///
	/// A non-positive immediate count or face count folds to `0`. With both
	/// operands constant and every drop count constant, the roll folds to `0`
	/// when no dice survive the drops, or to the number of surviving dice when
	/// each die has a single face. Otherwise it is re-emitted against a fresh
	/// rolling record.
	fn visit_roll_standard_dice(
		&mut self,
		roll: &RollStandardDice
	) -> Result<(), ()>
	{
		if let AddressingMode::Immediate(Immediate(..=0)) = roll.count
		{
			self.replace(roll.dest, Immediate(0));
			return Ok(())
		}
		if let AddressingMode::Immediate(Immediate(..=0)) = roll.faces
		{
			self.replace(roll.dest, Immediate(0));
			return Ok(())
		}
		if let Some((count, faces)) = roll.const_ops()
		{
			let drops = self.find_drop_instructions(roll.dest).map(drop_count);
			if let Some(kept) = surviving_dice(count, drops)
			{
				if kept == 0
				{
					self.replace(roll.dest, Immediate(0));
					return Ok(())
				}
				else if faces == 1
				{
					// Every surviving die shows 1.
					self.replace(roll.dest, Immediate(kept));
					return Ok(())
				}
			}
		}
		let dest = self.next_rolling_record();
		self.replace(roll.dest, dest);
		self.emit(RollStandardDice {
			dest,
			count: self.replacement(roll.count),
			faces: self.replacement(roll.faces)
		});
		Ok(())
	}

	/// Folds or rewrites a custom dice roll.
	///
	/// A non-positive immediate count or an empty face list folds to `0`. With
	/// a constant count and constant drop counts, the roll folds to `0` when no
	/// dice survive, or to `surviving * face` when all faces are equal.
	/// Otherwise it is re-emitted against a fresh rolling record.
	fn visit_roll_custom_dice(
		&mut self,
		roll: &RollCustomDice
	) -> Result<(), ()>
	{
		if let AddressingMode::Immediate(Immediate(..=0)) = roll.count
		{
			self.replace(roll.dest, Immediate(0));
			return Ok(())
		}
		if roll.faces.is_empty()
		{
			self.replace(roll.dest, Immediate(0));
			return Ok(())
		}
		if let Some(count) = roll.const_ops()
		{
			let drops = self.find_drop_instructions(roll.dest).map(drop_count);
			if let Some(kept) = surviving_dice(count, drops)
			{
				if kept == 0
				{
					self.replace(roll.dest, Immediate(0));
					return Ok(())
				}
				if roll.distinct_faces() == 1
				{
					// Every surviving die shows the same face; the face list
					// is known to be non-empty here.
					let face = roll.faces[0];
					self.replace(
						roll.dest,
						Immediate(kept.saturating_mul(face))
					);
					return Ok(())
				}
			}
		}
		let dest = self.next_rolling_record();
		self.replace(roll.dest, dest);
		self.emit(RollCustomDice {
			dest,
			count: self.replacement(roll.count),
			faces: roll.faces.clone()
		});
		Ok(())
	}

	/// Re-emits a drop-lowest against the renumbered rolling record. The drop
	/// is discarded when its roll was folded to an immediate, because the
	/// folding roll visitor has already accounted for it.
	fn visit_drop_lowest(&mut self, drop: &DropLowest) -> Result<(), ()>
	{
		match self.replacement(drop.dest)
		{
			AddressingMode::Immediate(_) =>
			{
				Ok(())
			},
			AddressingMode::RollingRecord(dest) =>
			{
				self.emit(DropLowest {
					dest,
					count: self.replacement(drop.count)
				});
				Ok(())
			},
			_ => unreachable!()
		}
	}

	/// Re-emits a drop-highest against the renumbered rolling record. The drop
	/// is discarded when its roll was folded to an immediate, because the
	/// folding roll visitor has already accounted for it.
	fn visit_drop_highest(&mut self, drop: &DropHighest) -> Result<(), ()>
	{
		match self.replacement(drop.dest)
		{
			AddressingMode::Immediate(_) =>
			{
				Ok(())
			},
			AddressingMode::RollingRecord(dest) =>
			{
				self.emit(DropHighest {
					dest,
					count: self.replacement(drop.count)
				});
				Ok(())
			},
			_ => unreachable!()
		}
	}

	/// Forwards the folded value when the summed record was folded to an
	/// immediate; otherwise re-emits the sum into a fresh register.
	fn visit_sum_rolling_record(
		&mut self,
		sum: &SumRollingRecord
	) -> Result<(), ()>
	{
		match self.replacement(sum.src)
		{
			src @ AddressingMode::Immediate(_) =>
			{
				self.replace(sum.dest, src);
				Ok(())
			},
			AddressingMode::RollingRecord(src) =>
			{
				let dest = self.next_register();
				self.replace(sum.dest, dest);
				self.emit(SumRollingRecord { dest, src });
				Ok(())
			},
			_ => unreachable!()
		}
	}

	/// Folds an addition of immediates, otherwise rewrites it.
	fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			add,
			|dest, op1, op2| Add { dest, op1, op2 }
		)
	}

	/// Folds a subtraction of immediates, otherwise rewrites it.
	fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			sub,
			|dest, op1, op2| Sub { dest, op1, op2 }
		)
	}

	/// Folds a multiplication of immediates, otherwise rewrites it.
	fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			mul,
			|dest, op1, op2| Mul { dest, op1, op2 }
		)
	}

	/// Folds a division of immediates (with the semantics of `div`),
	/// otherwise rewrites it.
	fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			div,
			|dest, op1, op2| Div { dest, op1, op2 }
		)
	}

	/// Folds a remainder of immediates (with the semantics of `mod`),
	/// otherwise rewrites it.
	fn visit_mod(&mut self, inst: &Mod) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			r#mod,
			|dest, op1, op2| Mod { dest, op1, op2 }
		)
	}

	/// Folds an exponentiation of immediates, otherwise rewrites it.
	fn visit_exp(&mut self, inst: &Exp) -> Result<(), ()>
	{
		self.maybe_fold_binary_instruction(
			*inst,
			|| (inst.op1, inst.op2),
			exp,
			|dest, op1, op2| Exp { dest, op1, op2 }
		)
	}

	/// Folds the negation of an immediate, otherwise re-emits it into a fresh
	/// register with a rewritten operand.
	fn visit_neg(&mut self, inst: &Neg) -> Result<(), ()>
	{
		if let Some(op1) = inst.const_ops()
		{
			self.replace(inst.dest, Immediate(neg(op1)));
		}
		else
		{
			let new = self.next_register();
			self.replace(inst.dest, new);
			self.emit(Neg {
				dest: new,
				op: self.replacement(inst.op)
			});
		}
		Ok(())
	}

	/// Re-emits a return with its source operand rewritten.
	fn visit_return(&mut self, inst: &Return) -> Result<(), ()>
	{
		self.emit(Return {
			src: self.replacement(inst.src)
		});
		Ok(())
	}
}

/// Extracts the count of a drop instruction when it is an immediate.
fn drop_count(drop: &Instruction) -> Option<i32>
{
	drop.const_ops().map(|(count, _)| count)
}

/// Computes how many of `count` dice survive the given drops, when that is
/// known at compile time.
///
/// `drop_counts` yields the count of every drop that targets the roll.
/// Returns `None` if any count is unknown (`None`) or negative. The folder
/// cannot predict how the runtime treats a negative drop, so it declines to
/// fold rather than let the drop increase the number of dice. Otherwise the
/// result is `count` minus the total dropped, clamped at zero. The total
/// saturates instead of overflowing.
fn surviving_dice(
	count: i32,
	drop_counts: impl IntoIterator<Item = Option<i32>>
) -> Option<i32>
{
	let dropped = drop_counts.into_iter().try_fold(0i32, |total, drop| {
		drop.filter(|&n| n >= 0).map(|n| total.saturating_add(n))
	})?;
	Some(count.saturating_sub(dropped).max(0))
}
impl ConstantFolder
{
	/// Allocates the next register of the rewritten instruction stream.
	#[inline]
	fn next_register(&mut self) -> RegisterIndex
	{
		self.next_register.allocate()
	}

	/// Allocates the next rolling record of the rewritten instruction stream.
	#[inline]
	fn next_rolling_record(&mut self) -> RollingRecordIndex
	{
		self.next_rolling_record.allocate()
	}

	/// Records that every later use of `old` must be rewritten as `new`.
	#[inline]
	fn replace(
		&mut self,
		old: impl Into<AddressingMode>,
		new: impl Into<AddressingMode>
	)
	{
		self.replacements.insert(old.into(), new.into());
	}

	/// Returns the rewritten form of `op`, or `op` itself when no replacement
	/// has been recorded for it (such as parameter and external registers).
	#[inline]
	fn replacement(&self, op: impl Into<AddressingMode>) -> AddressingMode
	{
		let op = op.into();
		self.replacements.get(&op).copied().unwrap_or(op)
	}

	/// Iterates over the drop instructions after the current one in the pass
	/// input that target rolling record `dest`.
	///
	/// # Panics
	///
	/// Panics if no pass input has been recorded, i.e. when called outside of
	/// an optimization pass.
	fn find_drop_instructions(
		&self,
		dest: RollingRecordIndex
	) -> impl Iterator<Item = &Instruction>
	{
		let previous = self
			.previous
			.as_deref()
			.expect("drop lookup requires the current pass input");
		// `pc` always indexes a visited instruction, so `pc + 1 <= len`.
		previous[self.pc.0 + 1..]
			.iter()
			.filter(move |inst| match inst
			{
				Instruction::DropLowest(drop) => drop.dest == dest,
				Instruction::DropHighest(drop) => drop.dest == dest,
				_ => false
			})
	}

	/// Folds a binary instruction whose operands are both immediates, or
	/// otherwise re-emits it into a fresh register with rewritten operands.
	///
	/// * `inst` - the instruction being visited.
	/// * `extractor` - yields the instruction's original `(op1, op2)`.
	/// * `folder` - evaluates the operation on constant operands.
	/// * `constructor` - rebuilds the instruction from a destination register
	///   and rewritten operands.
	fn maybe_fold_binary_instruction<I>(
		&mut self,
		inst: impl Into<Instruction>,
		extractor: impl FnOnce() -> (AddressingMode, AddressingMode),
		folder: impl FnOnce(i32, i32) -> i32,
		constructor: impl FnOnce(RegisterIndex, AddressingMode, AddressingMode) -> I
	) -> Result<(), ()>
	where
		I: Into<Instruction>
	{
		let inst = inst.into();
		let old = inst
			.destination()
			.expect("binary instructions always have a destination");
		if let Some((op1, op2)) = inst.const_ops()
		{
			let op2 = op2.expect("binary instructions have two operands");
			self.replace(old, Immediate(folder(op1, op2)));
		}
		else
		{
			let new = self.next_register();
			// Route through `replace` like every other visitor does.
			self.replace(old, new);
			let (op1, op2) = extractor();
			self.emit(constructor(
				new,
				self.replacement(op1),
				self.replacement(op2)
			));
		}
		Ok(())
	}

	/// Appends an instruction to the rewritten instruction stream.
	#[inline]
	fn emit(&mut self, inst: impl Into<Instruction>)
	{
		self.instructions.push(inst.into());
	}
}
/// Extracts the operands of an instruction when they are compile-time
/// constants, in a shape chosen by each implementation.
trait GetConstantOperands<Ops>
{
	/// Returns the constant operands, or `None` when they are not all known
	/// at compile time.
	fn const_ops(&self) -> Option<Ops>;
}
impl GetConstantOperands<(i32, i32)> for RollRange
{
	/// Returns `(start, end)` when both bounds of the range are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(low)) = self.start
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(high)) = self.end
		else
		{
			return None
		};
		Some((low, high))
	}
}
impl GetConstantOperands<(i32, i32)> for RollStandardDice
{
	/// Returns `(count, faces)` when both are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(dice)) = self.count
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(sides)) = self.faces
		else
		{
			return None
		};
		Some((dice, sides))
	}
}
impl GetConstantOperands<i32> for RollCustomDice
{
	/// Returns the dice count when it is an immediate; the faces are always
	/// literal values and so are not reported.
	fn const_ops(&self) -> Option<i32>
	{
		let AddressingMode::Immediate(Immediate(dice)) = self.count
		else
		{
			return None
		};
		Some(dice)
	}
}
impl GetConstantOperands<i32> for DropLowest
{
	/// Returns the number of dice to drop when it is an immediate.
	fn const_ops(&self) -> Option<i32>
	{
		let AddressingMode::Immediate(Immediate(dropped)) = self.count
		else
		{
			return None
		};
		Some(dropped)
	}
}
impl GetConstantOperands<i32> for DropHighest
{
	/// Returns the number of dice to drop when it is an immediate.
	fn const_ops(&self) -> Option<i32>
	{
		let AddressingMode::Immediate(Immediate(dropped)) = self.count
		else
		{
			return None
		};
		Some(dropped)
	}
}
impl GetConstantOperands<()> for SumRollingRecord
{
	/// Always `None`: a sum is folded by the visitor only when its source
	/// rolling record has already been replaced by an immediate.
	fn const_ops(&self) -> Option<()>
	{
		Option::None
	}
}
impl GetConstantOperands<(i32, i32)> for Add
{
	/// Returns `(op1, op2)` when both addends are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<(i32, i32)> for Sub
{
	/// Returns `(op1, op2)` when both operands are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<(i32, i32)> for Mul
{
	/// Returns `(op1, op2)` when both factors are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<(i32, i32)> for Div
{
	/// Returns `(op1, op2)` when dividend and divisor are both immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<(i32, i32)> for Mod
{
	/// Returns `(op1, op2)` when both operands are immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<(i32, i32)> for Exp
{
	/// Returns `(op1, op2)` when base and exponent are both immediates.
	fn const_ops(&self) -> Option<(i32, i32)>
	{
		let AddressingMode::Immediate(Immediate(lhs)) = self.op1
		else
		{
			return None
		};
		let AddressingMode::Immediate(Immediate(rhs)) = self.op2
		else
		{
			return None
		};
		Some((lhs, rhs))
	}
}
impl GetConstantOperands<i32> for Neg
{
	/// Returns the operand when it is an immediate.
	fn const_ops(&self) -> Option<i32>
	{
		let AddressingMode::Immediate(Immediate(value)) = self.op
		else
		{
			return None
		};
		Some(value)
	}
}
impl GetConstantOperands<()> for Return
{
	/// Always `None`: the visitor only rewrites a return's source operand and
	/// never folds the return itself.
	fn const_ops(&self) -> Option<()>
	{
		Option::None
	}
}
impl GetConstantOperands<(i32, Option<i32>)> for Instruction
{
	/// Reports the constant operands of any instruction in one uniform shape.
	///
	/// Two-operand instructions (ranges, standard dice, and the arithmetic
	/// operators) yield `(first, Some(second))`. Single-operand instructions
	/// (custom dice, drops, negation) yield `(only, None)`. Sums and returns
	/// always yield `None`.
	fn const_ops(&self) -> Option<(i32, Option<i32>)>
	{
		let pair = |ops: Option<(i32, i32)>| -> Option<(i32, Option<i32>)> {
			ops.map(|(first, second)| (first, Some(second)))
		};
		let single = |ops: Option<i32>| -> Option<(i32, Option<i32>)> {
			ops.map(|only| (only, None))
		};
		match self
		{
			Instruction::RollRange(roll) => pair(roll.const_ops()),
			Instruction::RollStandardDice(roll) => pair(roll.const_ops()),
			Instruction::RollCustomDice(roll) => single(roll.const_ops()),
			Instruction::DropLowest(drop) => single(drop.const_ops()),
			Instruction::DropHighest(drop) => single(drop.const_ops()),
			Instruction::SumRollingRecord(_) => None,
			Instruction::Add(inst) => pair(inst.const_ops()),
			Instruction::Sub(inst) => pair(inst.const_ops()),
			Instruction::Mul(inst) => pair(inst.const_ops()),
			Instruction::Div(inst) => pair(inst.const_ops()),
			Instruction::Mod(inst) => pair(inst.const_ops()),
			Instruction::Exp(inst) => pair(inst.const_ops()),
			Instruction::Neg(inst) => single(inst.const_ops()),
			Instruction::Return(_) => None
		}
	}
}