use crate::parse::{Program, Term};
use crate::tree::Tree;
use ::id_arena::{Arena, Id};
use ::rand::Rng;
/// Reasons evaluation of a dice program can fail.
///
/// Every variant is an `i64` overflow; the variant records the direction in
/// which the interpreter reported it.
#[derive(Debug)]
pub enum InterpError {
    /// Overflow reported as positive (e.g. an addition or a dice sum past `i64::MAX`).
    OverflowPositive,
    /// Overflow reported as negative (e.g. an addition or subtraction below `i64::MIN`).
    OverflowNegative,
}
/// Everything produced by `interpret`: the program's final total plus the tree
/// of per-term output nodes, each paired with the id of the term it came from.
#[derive(Debug)]
pub struct ProgramOutput {
    total: i64,
    tree: Tree<(Id<Term>, TermOutput)>,
}

impl ProgramOutput {
    /// The value the whole program evaluated to.
    pub fn total(&self) -> i64 {
        self.total
    }

    /// The output node of the outermost term.
    pub fn top(&self) -> &TermOutput {
        self.get(self.tree.top)
    }

    /// The output node stored under `id` in this output's arena.
    pub fn get(&self, id: Id<(Id<Term>, TermOutput)>) -> &TermOutput {
        let (_term, output) = &self.tree.arena[id];
        output
    }
}

/// Read access to the output tree (its `arena` and `top`).
impl ::core::ops::Deref for ProgramOutput {
    type Target = Tree<(Id<Term>, TermOutput)>;

    fn deref(&self) -> &Tree<(Id<Term>, TermOutput)> {
        &self.tree
    }
}

/// Mutable access to the output tree.
///
/// `total` is a separately stored field; it is not recomputed if the tree is
/// modified through this.
impl ::core::ops::DerefMut for ProgramOutput {
    fn deref_mut(&mut self) -> &mut Tree<(Id<Term>, TermOutput)> {
        &mut self.tree
    }
}
/// Evaluates `program`, drawing every dice roll from `rng`.
///
/// On success, returns the final total together with an output node for
/// every evaluated term.
///
/// # Errors
/// Propagates the first [`InterpError`] raised while evaluating any term.
pub fn interpret<R: Rng>(rng: &mut R, program: &Program) -> Result<ProgramOutput, InterpError> {
    let mut results = Arena::new();
    let root = interpret_term(rng, &program.tree.arena, &mut results, program.tree.top)?;
    let grand_total = results[root].1.total();
    let tree = Tree {
        arena: results,
        top: root,
    };
    Ok(ProgramOutput {
        total: grand_total,
        tree,
    })
}
/// One element of the output arena: the evaluated term's id and its result.
type Out = (Id<Term>, TermOutput);

/// Result of a keep-highest dice term (rendered as `NdMkK`).
#[derive(Clone, Debug)]
pub struct KeepHigh {
    /// Value this keep-high term evaluated to.
    total: i64,
    /// Number of dice the term asked to keep.
    keep_count: i64,
    /// Output node of the wrapped dice roll.
    roll: Id<Out>,
}
/// The evaluated result of a single term.
///
/// The first field of every variant (or `KeepHigh::total`) is the value the
/// term evaluated to; `Id<Out>` fields point at the outputs of sub-terms.
#[non_exhaustive]
#[derive(Clone, Debug, ::derive_more::Unwrap)]
pub enum TermOutput {
    /// A literal number.
    Constant(i64),
    /// A dice roll: its total and the individual rolls. The rolls are `None`
    /// for one-sided dice, whose total is simply the dice count. When wrapped
    /// by a keep-high term, the rolls are re-sorted highest first.
    DiceRoll(i64, Option<Vec<i64>>),
    /// A keep-highest roll.
    KeepHigh(KeepHigh),
    /// `left + right`: total, left operand, right operand.
    Add(i64, Id<Out>, Id<Out>),
    /// `left - right`: total, left operand, right operand.
    Subtract(i64, Id<Out>, Id<Out>),
    /// Negation: total, operand.
    UnarySubtract(i64, Id<Out>),
    /// Unary plus: total, operand.
    UnaryAdd(i64, Id<Out>),
}

impl TermOutput {
    /// The value this node evaluated to.
    fn total(&self) -> i64 {
        let value = match self {
            TermOutput::Constant(literal) => literal,
            TermOutput::DiceRoll(sum, _) => sum,
            TermOutput::KeepHigh(keep) => &keep.total,
            TermOutput::Add(sum, _, _) | TermOutput::Subtract(sum, _, _) => sum,
            TermOutput::UnarySubtract(result, _) | TermOutput::UnaryAdd(result, _) => result,
        };
        *value
    }
}
/// Evaluates `term` and, recursively, its sub-terms. Each evaluated term
/// appends one `(term id, output)` node to `term_outputs`.
///
/// Returns the id of the node recorded for `term` itself.
///
/// # Errors
/// Returns [`InterpError::OverflowPositive`] or [`InterpError::OverflowNegative`]
/// when a dice sum, addition, subtraction or negation does not fit in an `i64`.
///
/// # Panics
/// - If a keep-high term wraps anything other than a plain dice roll.
/// - On term kinds this interpreter does not evaluate yet.
/// - When rolling at least one die with fewer than one side: `gen_range`
///   panics on an empty range. Presumably the parser rules this out; it is not
///   checked here.
fn interpret_term<R: Rng>(
    rng: &mut R,
    terms: &Arena<Term>,
    term_outputs: &mut Arena<(Id<Term>, TermOutput)>,
    term: Id<Term>,
) -> Result<Id<(Id<Term>, TermOutput)>, InterpError> {
    use ::core::convert::TryFrom;

    match terms[term] {
        Term::Constant(total) => Ok(term_outputs.alloc((term, TermOutput::Constant(total)))),
        Term::DiceRoll(count, sides) => {
            if sides == 1 {
                // Every face of a d1 is 1, so the total is just the dice count
                // and the individual rolls are not stored.
                Ok(term_outputs.alloc((term, TermOutput::DiceRoll(count, None))))
            } else {
                let mut total: i64 = 0;
                // A negative count rolls nothing (the loop below is empty). Don't
                // let an `as usize` cast wrap it into an enormous allocation.
                let mut parts = Vec::with_capacity(usize::try_from(count).unwrap_or(0));
                for _ in 0..count {
                    let random = rng.gen_range(1..=sides);
                    total = total
                        .checked_add(random)
                        .ok_or(InterpError::OverflowPositive)?;
                    parts.push(random);
                }
                Ok(term_outputs.alloc((term, TermOutput::DiceRoll(total, Some(parts)))))
            }
        }
        Term::KeepHigh(roll, count) => {
            let roll = interpret_term(rng, terms, term_outputs, roll)?;
            // `count` is how many dice to keep; a negative request keeps none.
            let keep = usize::try_from(count).unwrap_or(0);
            let total: i64 = match &mut term_outputs[roll].1 {
                TermOutput::DiceRoll(_, Some(partials)) => {
                    // Sorted in place, highest first, so the kept dice are the
                    // leading entries of the stored rolls.
                    partials.sort_unstable_by(|a, b| b.cmp(a));
                    // All rolls are >= 1 and their full sum already fit in an
                    // i64, so this partial sum cannot overflow.
                    partials.iter().take(keep).sum()
                }
                // d1 roll: every die is 1, so keeping `count` of `dice` dice sums
                // to the smaller of the two (not the whole roll).
                TermOutput::DiceRoll(dice, None) => {
                    if count >= 1 {
                        (*dice).min(count)
                    } else {
                        0
                    }
                }
                _ => unreachable!("nesting of dice operators is currently not permitted"),
            };
            Ok(term_outputs.alloc((
                term,
                TermOutput::KeepHigh(KeepHigh {
                    total,
                    keep_count: count,
                    roll,
                }),
            )))
        }
        Term::Add(left, right) => {
            let left = interpret_term(&mut *rng, terms, &mut *term_outputs, left)?;
            let right = interpret_term(&mut *rng, terms, &mut *term_outputs, right)?;
            let (left_total, right_total) =
                (term_outputs[left].1.total(), term_outputs[right].1.total());
            // An addition can only overflow when both operands share a sign.
            let total = left_total.checked_add(right_total).ok_or_else(|| {
                if left_total > 0 || right_total > 0 {
                    InterpError::OverflowPositive
                } else {
                    InterpError::OverflowNegative
                }
            })?;
            Ok(term_outputs.alloc((term, TermOutput::Add(total, left, right))))
        }
        Term::Subtract(left, right) => {
            let left = interpret_term(&mut *rng, terms, &mut *term_outputs, left)?;
            let right = interpret_term(&mut *rng, terms, &mut *term_outputs, right)?;
            let (left_total, right_total) =
                (term_outputs[left].1.total(), term_outputs[right].1.total());
            // Overflows upward only for (non-negative) - (negative), and downward
            // only for (negative) - (positive).
            let total = left_total.checked_sub(right_total).ok_or_else(|| {
                if left_total > 0 || right_total < 0 {
                    InterpError::OverflowPositive
                } else {
                    InterpError::OverflowNegative
                }
            })?;
            Ok(term_outputs.alloc((term, TermOutput::Subtract(total, left, right))))
        }
        Term::UnarySubtract(term_0) => {
            let term_0 = interpret_term(&mut *rng, terms, &mut *term_outputs, term_0)?;
            let term_total = term_outputs[term_0].1.total();
            // Only `i64::MIN` fails to negate, and its negation (2^63) lies above
            // `i64::MAX`, so this is a positive overflow.
            let total = term_total
                .checked_neg()
                .ok_or(InterpError::OverflowPositive)?;
            Ok(term_outputs.alloc((term, TermOutput::UnarySubtract(total, term_0))))
        }
        Term::UnaryAdd(term_0) => {
            let term_0 = interpret_term(&mut *rng, terms, &mut *term_outputs, term_0)?;
            let term_total = term_outputs[term_0].1.total();
            Ok(term_outputs.alloc((term, TermOutput::UnaryAdd(term_total, term_0))))
        }
        ref term => unimplemented!("evaluating term {:?}", term),
    }
}
pub mod fmt {
use super::KeepHigh;
use super::Out;
use super::ProgramOutput;
use super::TermOutput;
use crate::parse::Term;
use ::id_arena::{Arena, Id};
fn fmt_default_impl(
buf: &mut String,
current: Id<Out>,
arena: &Arena<Out>,
terms: &Arena<Term>,
) {
let (term, out) = &arena[current];
match out {
TermOutput::Constant(n) => {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*n));
}
TermOutput::DiceRoll(_total, partial_sums) => {
let (count, sides) = terms[*term].clone().unwrap_dice_roll();
let nonzero_dice = match partial_sums.as_deref() {
Some([_, ..]) => true,
None if count != 0 => true,
Some([]) | None => false,
};
if nonzero_dice {
buf.push('(');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push_str(" → ");
match partial_sums.as_deref() {
Some([first, rest @ ..]) => {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*first));
for part in rest {
buf.push_str(" + ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*part));
}
}
Some([]) => unreachable!("groups of zero dice"),
None => match count {
0 => unreachable!("group of zero dice"),
_ => {
buf.push('1');
for _ in 1..count {
buf.push_str(" + 1");
}
}
},
}
buf.push(')');
} else {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
}
}
TermOutput::KeepHigh(KeepHigh {
total: _,
keep_count,
roll,
}) => {
let (term, _) = terms[*term].clone().unwrap_keep_high();
let (count, sides) = terms[term].clone().unwrap_dice_roll();
let (_, partial_sums) = arena[*roll].1.clone().unwrap_dice_roll();
let nonzero_dice = match partial_sums.as_deref() {
Some([_, ..]) => true,
None if count != 0 => true,
Some([]) | None => false,
} && *keep_count != 0;
if nonzero_dice {
buf.push('(');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push('k');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*keep_count));
buf.push_str(" → ");
match partial_sums.as_deref() {
Some([first, rest @ ..]) => {
buf.push_str("**");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*first));
let mut i = 1;
for part in rest {
if i == *keep_count as usize {
buf.push_str("**");
}
buf.push_str(" + ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*part));
i += 1;
}
if i <= *keep_count as usize {
buf.push_str("**");
}
}
Some([]) => unreachable!("groups of zero dice"),
None => match count {
0 => unreachable!("group of zero dice"),
_ => {
buf.push_str("**");
buf.push('1');
let mut i = 1;
for _ in 1..count {
if i == *keep_count as usize {
buf.push_str("**");
}
buf.push_str(" + 1");
i += 1;
}
if i <= *keep_count as usize {
buf.push_str("**");
}
}
},
}
buf.push(')');
} else {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push('k');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*keep_count));
}
}
TermOutput::Add(_total, left, right) => {
fmt_default_impl(&mut *buf, *left, arena, terms);
buf.push_str(" + ");
fmt_default_impl(&mut *buf, *right, arena, terms);
}
TermOutput::Subtract(_total, left, right) => {
fmt_default_impl(&mut *buf, *left, arena, terms);
buf.push_str(" - ");
fmt_default_impl(&mut *buf, *right, arena, terms);
}
TermOutput::UnarySubtract(_total, only) => {
buf.push('-');
fmt_default_impl(buf, *only, arena, terms);
}
TermOutput::UnaryAdd(_total, only) => fmt_default_impl(buf, *only, arena, terms),
}
}
fn fmt_short_impl(buf: &mut String, current: Id<Out>, arena: &Arena<Out>, terms: &Arena<Term>) {
let (term, out) = &arena[current];
match out {
TermOutput::Constant(n) => {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*n));
}
TermOutput::DiceRoll(total, partial_sums) => {
let (count, sides) = terms[*term].clone().unwrap_dice_roll();
let nonzero_dice = match partial_sums.as_deref() {
Some([_, ..]) => true,
None if count != 0 => true,
Some([]) | None => false,
};
if nonzero_dice {
buf.push('(');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push_str(" → ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*total));
buf.push(')');
} else {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
}
}
TermOutput::KeepHigh(KeepHigh {
total,
keep_count,
roll,
}) => {
let (term, _) = terms[*term].clone().unwrap_keep_high();
let (count, sides) = terms[term].clone().unwrap_dice_roll();
let (_, partial_sums) = arena[*roll].1.clone().unwrap_dice_roll();
let nonzero_dice = match partial_sums.as_deref() {
Some([_, ..]) => true,
None if count != 0 => true,
Some([]) | None => false,
} && *keep_count != 0;
if nonzero_dice {
buf.push('(');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push('k');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*keep_count));
buf.push_str(" → ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*total));
buf.push(')');
} else {
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(count));
buf.push('d');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(sides));
buf.push('k');
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(*keep_count));
}
}
TermOutput::Add(_total, left, right) => {
fmt_short_impl(&mut *buf, *left, arena, terms);
buf.push_str(" + ");
fmt_short_impl(&mut *buf, *right, arena, terms);
}
TermOutput::Subtract(_total, left, right) => {
fmt_short_impl(&mut *buf, *left, arena, terms);
buf.push_str(" - ");
fmt_short_impl(&mut *buf, *right, arena, terms);
}
TermOutput::UnarySubtract(_total, only) => {
buf.push('-');
fmt_short_impl(buf, *only, arena, terms);
}
TermOutput::UnaryAdd(_total, only) => fmt_short_impl(buf, *only, arena, terms),
}
}
pub fn mbot_format_default(input: &Arena<Term>, output: &ProgramOutput) -> String {
let mut buf = String::with_capacity(2000);
fmt_default_impl(&mut buf, output.tree.top, &output.tree.arena, input);
buf.push_str(" = ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(output.total));
buf
}
pub fn mbot_format_short(input: &Arena<Term>, output: &ProgramOutput) -> String {
let mut buf = String::with_capacity(2000);
fmt_short_impl(&mut buf, output.tree.top, &output.tree.arena, input);
buf.push_str(" = ");
let mut itoa_buf = itoa::Buffer::new();
buf.push_str(itoa_buf.format(output.total));
buf
}
}