use std::{
cmp::{max, min},
collections::{BTreeSet, HashMap}
};
use crate::{
Add, AddressingMode, CanVisitInstructions, DependencyAnalyzer, Div,
DropHighest, DropLowest, Exp, Instruction, InstructionVisitor, Mod, Mul,
Neg, Return, RollCustomDice, RollRange, RollStandardDice, Sub,
SumRollingRecord
};
/// State for the constant-commuting rewrite pass over a borrowed instruction
/// sequence.
///
/// Arithmetic instructions are partitioned into numbered groups while the
/// program is visited; each group is afterwards rewritten as a unit and the
/// rewritten forms are recorded in `replacements`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstantCommuter<'inst>
{
    /// The program being rewritten.
    instructions: &'inst [Instruction],
    /// Dependency analysis of `instructions`; queried for the readers of each
    /// destination while forming groups.
    analyzer: DependencyAnalyzer<'inst>,
    /// Group id of every grouped instruction. Instructions that share an id
    /// are rewritten together.
    groups: HashMap<Instruction, usize>,
    /// The id handed to the next instruction that is assigned a group.
    next_group: usize,
    /// Maps original instructions to the instructions that replace them in
    /// the output.
    replacements: HashMap<Instruction, Instruction>
}
impl<'inst> ConstantCommuter<'inst>
{
    /// Runs the constant-commuting pass over `instructions` and returns the
    /// rewritten program.
    ///
    /// Every instruction is visited first, which partitions the arithmetic
    /// instructions into groups. Each group is then rewritten according to
    /// the kind of its earliest member (in program order):
    /// - additions, multiplications, and drop-lowest/drop-highest groups have
    ///   their operands pooled and redistributed in sorted order;
    /// - subtraction and division groups are recast as an addition or
    ///   multiplication chain feeding one final subtraction or division.
    ///
    /// Instructions without a recorded replacement are copied unchanged, so
    /// the output has the same length and order as the input.
    ///
    /// # Panics
    ///
    /// Panics if visiting any instruction returns an error.
    pub fn commute(instructions: &'inst [Instruction]) -> Vec<Instruction>
    {
        let mut commuter = Self {
            instructions,
            analyzer: DependencyAnalyzer::analyze(instructions),
            groups: HashMap::new(),
            next_group: 0,
            replacements: HashMap::new()
        };
        instructions
            .iter()
            .for_each(|inst| inst.visit(&mut commuter).unwrap());

        // Visit the surviving group ids in ascending order.
        let group_ids: BTreeSet<usize> =
            commuter.groups.values().copied().collect();
        for group in group_ids
        {
            // The earliest member of the group selects the rewrite strategy.
            let representative = instructions
                .iter()
                .find(|&inst| commuter.groups.get(inst) == Some(&group))
                .unwrap();
            match representative
            {
                Instruction::Add(_) => commuter.rewrite_commutative_group(
                    group,
                    |dest, operands| Add {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    }
                ),
                Instruction::Mul(_) => commuter.rewrite_commutative_group(
                    group,
                    |dest, operands| Mul {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    }
                ),
                Instruction::Sub(_) => commuter.rewrite_nearly_commutative_group(
                    group,
                    |dest, operands| Add {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    },
                    |dest, operands| Sub {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    }
                ),
                Instruction::Div(_) => commuter.rewrite_nearly_commutative_group(
                    group,
                    |dest, operands| Mul {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    },
                    |dest, operands| Div {
                        dest: dest.try_into().unwrap(),
                        op1: operands[0],
                        op2: operands[1]
                    }
                ),
                Instruction::DropLowest(_) => commuter.rewrite_commutative_group(
                    group,
                    |dest, operands| DropLowest {
                        dest: dest.try_into().unwrap(),
                        count: operands[0]
                    }
                ),
                Instruction::DropHighest(_) => commuter.rewrite_commutative_group(
                    group,
                    |dest, operands| DropHighest {
                        dest: dest.try_into().unwrap(),
                        count: operands[0]
                    }
                ),
                _ => ()
            }
        }

        // Emit each instruction's replacement when one exists, else the original.
        let mut rewritten = Vec::with_capacity(instructions.len());
        for inst in instructions
        {
            let chosen = commuter.replacements.get(inst).unwrap_or(inst);
            rewritten.push(chosen.clone());
        }
        rewritten
    }
}
impl InstructionVisitor<()> for ConstantCommuter<'_>
{
fn visit_roll_range(&mut self, _inst: &RollRange) -> Result<(), ()>
{
Ok(())
}
fn visit_roll_standard_dice(
&mut self,
_inst: &RollStandardDice
) -> Result<(), ()>
{
Ok(())
}
fn visit_roll_custom_dice(
&mut self,
_inst: &RollCustomDice
) -> Result<(), ()>
{
Ok(())
}
fn visit_drop_lowest(&mut self, _inst: &DropLowest) -> Result<(), ()>
{
Ok(())
}
fn visit_drop_highest(&mut self, _inst: &DropHighest) -> Result<(), ()>
{
Ok(())
}
fn visit_sum_rolling_record(
&mut self,
_inst: &SumRollingRecord
) -> Result<(), ()>
{
Ok(())
}
fn visit_add(&mut self, inst: &Add) -> Result<(), ()>
{
self.organize(*inst, |inst| matches!(inst, Instruction::Add(_)))
}
fn visit_sub(&mut self, inst: &Sub) -> Result<(), ()>
{
self.organize(*inst, |inst| matches!(inst, Instruction::Sub(_)))
}
fn visit_mul(&mut self, inst: &Mul) -> Result<(), ()>
{
self.organize(*inst, |inst| matches!(inst, Instruction::Mul(_)))
}
fn visit_div(&mut self, inst: &Div) -> Result<(), ()>
{
self.organize(*inst, |inst| matches!(inst, Instruction::Div(_)))
}
fn visit_mod(&mut self, _inst: &Mod) -> Result<(), ()> { Ok(()) }
fn visit_exp(&mut self, _inst: &Exp) -> Result<(), ()> { Ok(()) }
fn visit_neg(&mut self, _inst: &Neg) -> Result<(), ()> { Ok(()) }
fn visit_return(&mut self, _inst: &Return) -> Result<(), ()> { Ok(()) }
}
impl ConstantCommuter<'_>
{
fn organize(
&mut self,
inst: impl Into<Instruction>,
filter: impl Fn(&Instruction) -> bool
) -> Result<(), ()>
{
let inst = inst.into();
let group = self.group_id(&inst);
for reader in self
.analyzer
.readers()
.get(&inst.destination().unwrap())
.unwrap()
.clone()
{
let reader = &self.instructions[reader.0];
if filter(reader)
{
let reader_group = self.group_id(&reader.clone());
self.merge_group_ids(group, reader_group);
}
}
Ok(())
}
fn group_id(&mut self, inst: &Instruction) -> usize
{
match self.groups.get(inst)
{
Some(index) => *index,
None =>
{
let index = self.next_group;
self.groups.insert(inst.clone(), index);
self.next_group += 1;
index
}
}
}
fn merge_group_ids(&mut self, first: usize, second: usize)
{
if first != second
{
let min = min(first, second);
let max = max(first, second);
self.groups.iter_mut().for_each(|(_, group)| {
if *group == max
{
*group = min;
}
});
assert!(self.groups.values().all(|&group| group != max));
}
}
fn group(&self, group: usize) -> Vec<Instruction>
{
let mut instructions = self
.groups
.iter()
.filter_map(|(inst, g)| if *g == group { Some(inst) } else { None })
.map(|inst| self.replacements.get(inst).unwrap_or(inst).clone())
.collect::<Vec<_>>();
instructions.sort_by_key(Instruction::destination);
instructions
}
fn rewrite_commutative_group<I>(
&mut self,
group: usize,
constructor: impl Fn(AddressingMode, &[AddressingMode]) -> I
) where
I: Into<Instruction>
{
let instructions = self.group(group);
let mut ops = instructions
.iter()
.flat_map(Instruction::sources)
.collect::<Vec<_>>();
ops.sort();
ops.reverse();
let arity = ops.len() / instructions.len();
instructions.iter().for_each(|inst| {
let ops =
(0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
let new_inst =
constructor(inst.destination().unwrap(), &ops).into();
self.replacements.insert((*inst).clone(), new_inst);
});
}
fn rewrite_nearly_commutative_group<C, F>(
&mut self,
group: usize,
commutative_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> C,
final_constructor: impl Fn(AddressingMode, &[AddressingMode]) -> F
) where
C: Into<Instruction>,
F: Into<Instruction>
{
let mut instructions = self.group(group);
if instructions.len() > 1
{
let mut targets = instructions
.iter()
.flat_map(Instruction::destination)
.collect::<Vec<_>>();
let mut ops = instructions
.iter()
.flat_map(Instruction::sources)
.collect::<Vec<_>>();
let first_op = ops.remove(0);
let len = ops.len();
ops.swap(len - 2, len - 1);
let commutative_result = ops.pop().unwrap();
ops.sort();
let first_inst = instructions.remove(0);
let terminal = final_constructor(
targets.pop().unwrap(),
&[first_op, commutative_result]
)
.into();
ops.reverse();
targets.reverse();
let arity = ops.len() / instructions.len();
let mut previous = first_inst.clone();
instructions.iter().for_each(|inst| {
let dest = targets.pop().unwrap();
let ops =
(0..arity).map(|_| ops.pop().unwrap()).collect::<Vec<_>>();
let new_inst = commutative_constructor(dest, &ops).into();
self.replacements.insert(previous.clone(), new_inst);
previous = inst.clone();
});
self.replacements.insert(previous, terminal);
}
}
}