use crate::parse::{Program, Term};
use crate::tree::Tree;
use ::id_arena::Arena;
use ::rand::Rng;
/// Compile-time guard that `Term::DiceRoll` is a leaf term holding two `i64`s.
///
/// The expansion only type-checks while `DiceRoll` is a tuple variant of exactly two
/// `i64` fields, i.e. while it carries no child term ids. `postorder` invokes this next
/// to its `DiceRoll` arm, which visits the node without recursing, so a change to the
/// variant's shape becomes a compile error there instead of a silently incomplete walk.
macro_rules! assert_dice_roll_terminal {
() => {{
// Unnamed const: exists only so the constructor call below is type-checked.
const _: $crate::parse::Term = $crate::parse::Term::DiceRoll(0i64, 0i64);
}};
}
/// Walks the expression tree of `program` in post-order, calling `visit` once per term.
///
/// Every child is visited before its parent, and the operands of binary terms are
/// visited left before right. `visit` receives the term together with its parent
/// (`None` for the root term).
pub fn postorder<F>(program: &Program, mut visit: F)
where
    F: FnMut(&Term, Option<&Term>),
{
    let Program {
        tree: Tree { top, arena },
    } = program;

    // Recursive worker: visits the whole subtree below `term`, then `term` itself.
    fn postorder_term<F>(term: &Term, parent: Option<&Term>, arena: &Arena<Term>, visit: &mut F)
    where
        F: FnMut(&Term, Option<&Term>),
    {
        use Term::*;
        // The leaf arm below does not recurse; this fails to compile if `DiceRoll`
        // ever stops being a leaf.
        assert_dice_roll_terminal!();
        match term {
            Constant(_) | DiceRoll(_, _) => {}
            // Recurse rather than visiting the operand directly, so the walk stays a
            // complete post-order even if a modifier's operand is not itself a leaf.
            KeepHigh(roll, _) | KeepLow(roll, _) | Explode(roll) => {
                postorder_term(&arena[*roll], Some(term), arena, &mut *visit);
            }
            Add(left, right) | Subtract(left, right) => {
                postorder_term(&arena[*left], Some(term), arena, &mut *visit);
                postorder_term(&arena[*right], Some(term), arena, &mut *visit);
            }
            UnaryAdd(only) | UnarySubtract(only) => {
                postorder_term(&arena[*only], Some(term), arena, &mut *visit);
            }
        }
        // All children are done; now the term itself.
        visit(term, parent);
    }

    postorder_term(&arena[*top], None, arena, &mut visit);
}
/// A compiled program: the instructions a [`Machine`] executes, in order.
pub struct StackProgram(Vec<Instruction>);

impl StackProgram {
    /// Size in bytes of the stored instructions.
    ///
    /// Only the initialized elements are counted (length, not capacity), and the
    /// `Vec` header itself is excluded.
    pub fn len_bytes(&self) -> usize {
        let instructions: &[Instruction] = &self.0;
        ::core::mem::size_of_val(instructions)
    }
}
/// One step of the stack machine. `compile` emits these in post-order, so operands
/// are always pushed before the instruction that consumes them.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Instruction {
    /// Push the constant.
    Value(i64),
    /// Roll `.0` dice with `.1` sides each and push the sum.
    DiceRoll(i64, i64),
    /// Roll `count` dice with `sides` sides each and push the sum of the
    /// `keep_count` highest results.
    DiceRollKeepHigh {
        count: i64,
        sides: i64,
        keep_count: i64,
    },
    /// Pop two operands and push their sum.
    Add,
    /// Pop the right operand, then the left one, and push `left - right`.
    Subtract,
    /// Pop one operand and push its negation.
    UnarySubtract,
    /// Do nothing; `compile` emits this for `UnaryAdd` and `KeepHigh` terms.
    NoOp,
}
pub fn compile(program: &Program) -> StackProgram {
let mut instructions = Vec::with_capacity(program.arena.len());
postorder(program, |term, parent| {
use Term::*;
let next = match term {
Constant(value) => Instruction::Value(*value),
DiceRoll(count, sides) => match parent {
Some(KeepHigh(_, keep_count)) => Instruction::DiceRollKeepHigh {
count: *count,
sides: *sides,
keep_count: *keep_count,
},
_ => Instruction::DiceRoll(*count, *sides),
},
KeepHigh(_, _) => Instruction::NoOp,
KeepLow(_, _) => unimplemented!("keep low on old stack machine"),
Explode(_) => unimplemented!("explosion on old stack machine"),
Add(_, _) => Instruction::Add,
Subtract(_, _) => Instruction::Subtract,
UnaryAdd(_) => Instruction::NoOp,
UnarySubtract(_) => Instruction::UnarySubtract,
};
instructions.push(next);
});
StackProgram(instructions)
}
/// Stack machine that evaluates a compiled [`StackProgram`].
pub struct Machine {
// Operand stack: instructions pop their inputs from it and push their results onto it.
// `eval_with` pops the final value on success and clears the stack after an overflow.
stack: Vec<i64>,
}
/// Direction in which an evaluation result left the `i64` range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Overflow {
    /// The true result was greater than `i64::MAX`.
    Positive,
    /// The true result was less than `i64::MIN`.
    Negative,
}

impl ::core::fmt::Display for Overflow {
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        f.write_str(match self {
            Overflow::Positive => "result overflowed above i64::MAX",
            Overflow::Negative => "result overflowed below i64::MIN",
        })
    }
}

// Lets callers propagate an overflow with `?` into `Box<dyn Error>` and similar types.
impl ::std::error::Error for Overflow {}
impl Machine {
    /// Creates a machine with an empty operand stack.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self { stack: Vec::new() }
    }

    /// Pops one operand from the stack.
    ///
    /// # Panics
    ///
    /// Panics if the stack is empty. That can only happen with a malformed
    /// [`StackProgram`], because operands are always emitted before their operator.
    fn pop_operand(&mut self) -> i64 {
        self.stack
            .pop()
            .expect("stack underflow: malformed StackProgram")
    }

    /// Executes a single instruction against the operand stack, drawing dice from `rng`.
    ///
    /// # Errors
    ///
    /// Returns [`Overflow`] if a result does not fit in an `i64`. The stack may be left
    /// partially updated in that case; [`Machine::eval_with`] clears it.
    ///
    /// # Panics
    ///
    /// Panics on stack underflow (see [`Machine::pop_operand`]). It also panics if a die
    /// must be sampled from an empty range (`sides < 1` with at least one die to roll),
    /// because `Rng::gen_range` panics on empty ranges.
    fn exec_with<R: Rng>(&mut self, rng: &mut R, instruction: Instruction) -> Result<(), Overflow> {
        use ::core::cmp::{self, Reverse};
        use ::std::collections::BinaryHeap;
        use Instruction::*;
        match instruction {
            Value(value) => self.stack.push(value),
            DiceRoll(count, sides) => {
                if sides == 1 {
                    // Every die shows 1, so the total is the count; no RNG needed.
                    self.stack.push(count);
                } else {
                    let mut total: i64 = 0;
                    for _ in 0..count {
                        let random = rng.gen_range(1..=sides);
                        total = total.checked_add(random).ok_or(Overflow::Positive)?;
                    }
                    self.stack.push(total);
                }
            }
            DiceRollKeepHigh {
                count,
                sides,
                keep_count,
            } => {
                if keep_count <= 0 {
                    // Keeping no dice (or a meaningless negative number of dice) sums to 0.
                    self.stack.push(0);
                } else if sides == 1 {
                    self.stack.push(cmp::min(count, keep_count));
                } else {
                    // `keep_count > 0` here. If it does not fit in `usize` (32-bit
                    // targets), treat it as "keep every die".
                    let keep = usize::try_from(keep_count).unwrap_or(usize::MAX);
                    // At most `keep + 1` rolls are held at once; a negative count rolls nothing.
                    let capacity = usize::try_from(count)
                        .unwrap_or(0)
                        .min(keep.saturating_add(1));
                    // Min-heap of the highest `keep` rolls seen so far. Push each roll and
                    // evict the smallest once over budget: O(count * log keep), instead of
                    // re-sorting the kept rolls after every die.
                    let mut kept: BinaryHeap<Reverse<i64>> = BinaryHeap::with_capacity(capacity);
                    for _ in 0..count {
                        kept.push(Reverse(rng.gen_range(1..=sides)));
                        if kept.len() > keep {
                            kept.pop();
                        }
                    }
                    let total = kept
                        .iter()
                        .try_fold(0i64, |sum, &Reverse(roll)| sum.checked_add(roll))
                        .ok_or(Overflow::Positive)?;
                    self.stack.push(total);
                }
            }
            Add => {
                // Addition is commutative, so the pop order does not matter.
                let top = self.pop_operand();
                let below = self.pop_operand();
                let sum = match top.checked_add(below) {
                    Some(sum) => sum,
                    // Signed addition overflows only when both operands share a sign, so
                    // either one being positive means the true sum exceeds `i64::MAX`.
                    None if top > 0 || below > 0 => return Err(Overflow::Positive),
                    None => return Err(Overflow::Negative),
                };
                self.stack.push(sum);
            }
            Subtract => {
                // The right-hand operand was emitted last, so it is popped first.
                let right = self.pop_operand();
                let left = self.pop_operand();
                let difference = match left.checked_sub(right) {
                    Some(difference) => difference,
                    // `left - right` overflows upward only when `left >= 0` and `right < 0`.
                    None if left > 0 || right < 0 => return Err(Overflow::Positive),
                    None => return Err(Overflow::Negative),
                };
                self.stack.push(difference);
            }
            UnarySubtract => {
                let only = self.pop_operand();
                // Only `-i64::MIN` overflows, and its true value is positive.
                self.stack.push(only.checked_neg().ok_or(Overflow::Positive)?);
            }
            NoOp => {}
        }
        Ok(())
    }

    /// Executes every instruction of `program` in order, stopping at the first overflow.
    fn run_with<R: Rng>(&mut self, rng: &mut R, program: &StackProgram) -> Result<(), Overflow> {
        for instruction in program.0.iter() {
            self.exec_with(&mut *rng, *instruction)?;
        }
        Ok(())
    }

    /// Evaluates `program` and returns the value it leaves on top of the stack.
    ///
    /// # Errors
    ///
    /// Returns [`Overflow`] if any intermediate or final result leaves the `i64` range.
    /// The operand stack is cleared first, so a failed run leaves no stale operands.
    ///
    /// # Panics
    ///
    /// Panics if `program` is empty or malformed (stack underflow), or if a die must be
    /// sampled from an empty range.
    pub fn eval_with<R: Rng>(
        &mut self,
        rng: &mut R,
        program: &StackProgram,
    ) -> Result<i64, Overflow> {
        match self.run_with(rng, program) {
            Ok(()) => Ok(self.pop_operand()),
            Err(overflow) => {
                // Discard operands left over from the aborted run.
                self.stack.clear();
                Err(overflow)
            }
        }
    }
}