use crate::tree::for_;
/// A cost tally, tagged with whether it is bounded.
///
/// Both variants carry a `u64` count. Adding two prices (see the `Add` impl
/// below) sums the counts with saturation and yields `Unbounded` as soon as
/// either operand is `Unbounded`.
#[non_exhaustive]
pub enum Price {
    /// A tally tagged as bounded.
    Bounded(u64),
    /// A tally tagged as unbounded; the payload is still the accumulated count.
    Unbounded(u64),
}

impl ::core::ops::Add<Price> for Price {
    type Output = Price;

    /// Adds the two tallies, saturating at `u64::MAX`.
    ///
    /// The result is `Bounded` only when both operands are `Bounded`;
    /// otherwise it is `Unbounded`.
    fn add(self, rhs: Price) -> Self::Output {
        // Split each operand into its count and its boundedness flag.
        let (left, left_bounded) = match self {
            Price::Bounded(amount) => (amount, true),
            Price::Unbounded(amount) => (amount, false),
        };
        let (right, right_bounded) = match rhs {
            Price::Bounded(amount) => (amount, true),
            Price::Unbounded(amount) => (amount, false),
        };
        let total = left.saturating_add(right);
        if left_bounded && right_bounded {
            Price::Bounded(total)
        } else {
            Price::Unbounded(total)
        }
    }
}

impl ::core::iter::Sum<Price> for Price {
    /// Sums an iterator of prices, starting from `Price::Bounded(0)`.
    ///
    /// An empty iterator therefore yields `Price::Bounded(0)`.
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Price>,
    {
        let mut total = Price::Bounded(0);
        for price in iter {
            total = total + price;
        }
        total
    }
}
/// Computes a [`Price`] for `Self` under the cost model named by `Ctx`.
///
/// `Ctx` does not appear in the trait body; it only distinguishes
/// implementations, so a single type can implement `Cost` once per model.
pub trait Cost<'a, Ctx> {
/// Extra input handed to [`Cost::cost`]. The implementations visible in this
/// file all use `()`.
type Param;
/// Returns the price of `self` given `param`.
fn cost(&'a self, param: Self::Param) -> Price;
}
pub fn cost<'a, Ctx, T: Cost<'a, Ctx>>(thing: &'a T, param: T::Param) -> Price {
<T as Cost<'a, Ctx>>::cost(thing, param)
}
use crate::parse::Program;
use crate::stack::postorder;
/// Marker type used as the `Ctx` parameter of [`Cost`] to select the
/// cost model implemented for [`Program`] below.
pub struct AstInterp;

impl<'a> Cost<'a, AstInterp> for Program {
    type Param = ();

    /// Tallies a price over the terms reported by `postorder`.
    ///
    /// Per term:
    /// - `Constant`, `KeepHigh`, `Add`, `Subtract`, `UnarySubtract`: 1
    /// - `DiceRoll(count, sides)`: `count` when `sides > 1`, otherwise 1
    /// - `UnaryAdd`: 0
    ///
    /// The tally saturates at `u64::MAX` and is always returned as
    /// `Price::Bounded`.
    ///
    /// # Panics
    ///
    /// Panics via `unimplemented!` when a `KeepLow` or `Explode` term is
    /// encountered.
    fn cost(&'a self, _param: Self::Param) -> Price {
        use crate::parse::Term;

        let mut total = 0u64;
        postorder(self, |child, _parent| {
            // Work out this term's contribution first, then accumulate once.
            let step: u64 = match child {
                Term::DiceRoll(count, sides) => {
                    if *sides > 1 {
                        *count as u64
                    } else {
                        1
                    }
                }
                Term::KeepLow(_, _) => unimplemented!("cost of keep low on AST interp"),
                Term::Explode(_) => unimplemented!("cost of explosion on AST interp"),
                Term::UnaryAdd(_) => 0,
                Term::Constant(_)
                | Term::KeepHigh(_, _)
                | Term::Add(_, _)
                | Term::Subtract(_, _)
                | Term::UnarySubtract(_) => 1,
            };
            total = total.saturating_add(step);
        });
        Price::Bounded(total)
    }
}
/// Marker type used as the `Ctx` parameter of [`Cost`] to select the
/// cost model implemented for [`Program`] below.
pub struct StackInterp;

impl<'a> Cost<'a, StackInterp> for Program {
    type Param = ();

    /// Tallies a price over the terms reported by `postorder`.
    ///
    /// Per term:
    /// - `Constant`, `Add`, `Subtract`, `UnarySubtract`: 1
    /// - `DiceRoll(count, sides)` whose parent is `KeepHigh(_, keep_count)`:
    ///   0 when `keep_count == 0`; otherwise 1 when `sides == 1`, else
    ///   `2 * count`
    /// - any other `DiceRoll(count, sides)`: `count` when `sides > 1`,
    ///   otherwise 1
    /// - `KeepHigh`, `UnaryAdd`: 0 (the keep is charged on its dice roll)
    ///
    /// Every accumulation saturates at `u64::MAX`, and the tally is always
    /// returned as `Price::Bounded`.
    ///
    /// # Panics
    ///
    /// Panics via `unimplemented!` when a `KeepLow` or `Explode` term is
    /// encountered.
    fn cost(&'a self, _param: Self::Param) -> Price {
        use crate::parse::Term;

        let mut price = 0u64;
        postorder(self, |child, parent| match child {
            Term::Constant(_) => price = price.saturating_add(1),
            Term::DiceRoll(count, sides) => match parent {
                Some(Term::KeepHigh(_, keep_count)) => match (*keep_count == 0, *sides == 1) {
                    // Keeping zero dice: the roll contributes nothing.
                    (true, _) => (),
                    (false, true) => price = price.saturating_add(1),
                    // Saturate here as everywhere else; a bare `+=` could
                    // overflow (panic in debug builds, wrap in release).
                    (false, false) => {
                        price = price.saturating_add((*count as u64).saturating_mul(2))
                    }
                },
                _ => {
                    if *sides > 1 {
                        price = price.saturating_add(*count as u64);
                    } else {
                        price = price.saturating_add(1);
                    }
                }
            },
            Term::KeepHigh(_, _) => (),
            Term::KeepLow(_, _) => unimplemented!("keep low cost on old stack vm"),
            Term::Explode(_) => unimplemented!("explosion cost"),
            Term::Add(_, _) | Term::Subtract(_, _) | Term::UnarySubtract(_) => {
                price = price.saturating_add(1)
            }
            Term::UnaryAdd(_) => (),
        });
        Price::Bounded(price)
    }
}
/// Marker type used as the `Ctx` parameter of [`Cost`] to select the
/// cost model implemented for [`Program`] below.
pub struct MirStack;

impl<'a> Cost<'a, MirStack> for Program {
    type Param = ();

    /// Tallies a price over the terms reported by `postorder`.
    ///
    /// Per term:
    /// - `Constant`, `Add`, `Subtract`, `UnarySubtract`: 1
    /// - `DiceRoll(count, _)`: `2 * count` when its parent is `KeepHigh` or
    ///   `KeepLow`, otherwise `count`
    /// - `KeepHigh`, `KeepLow`, `UnaryAdd`: 0
    /// - `Explode`: 0, but marks the whole result as unbounded
    ///
    /// The tally saturates at `u64::MAX`. It is returned as
    /// `Price::Unbounded` if any `Explode` term was seen and as
    /// `Price::Bounded` otherwise.
    fn cost(&'a self, _param: Self::Param) -> Price {
        use crate::parse::Term;

        let mut total = 0u64;
        let mut saw_explosion = false;
        postorder(self, |child, parent| {
            // Work out this term's contribution first, then accumulate once.
            let step = match child {
                Term::DiceRoll(count, _sides) => {
                    let dice = *count as u64;
                    match parent {
                        Some(Term::KeepHigh(_, _) | Term::KeepLow(_, _)) => dice.saturating_mul(2),
                        _ => dice,
                    }
                }
                Term::Explode(_) => {
                    saw_explosion = true;
                    0
                }
                Term::KeepHigh(_, _) | Term::KeepLow(_, _) | Term::UnaryAdd(_) => 0,
                Term::Constant(_)
                | Term::Add(_, _)
                | Term::Subtract(_, _)
                | Term::UnarySubtract(_) => 1,
            };
            total = total.saturating_add(step);
        });

        if saw_explosion {
            Price::Unbounded(total)
        } else {
            Price::Bounded(total)
        }
    }
}
/// Marker types for use as the `Ctx` parameter of `Cost`.
///
/// Only `TextFormatOutput<Default>` is used by an implementation visible in
/// this file.
pub mod mbot {
use ::core::marker::PhantomData;
/// Cost context parameterised by a format marker `T`.
///
/// The only field is a private `PhantomData`, and no constructor is defined
/// in this module.
pub struct TextFormatOutput<T> {
_priv: PhantomData<T>,
}
/// Format marker. Inside this module the name `Default` refers to this
/// struct rather than the prelude's `Default` trait.
pub struct Default;
/// Format marker; presumably a shorter output format (not used in the code
/// visible here; confirm against callers).
pub struct Short;
/// Format marker; presumably the shortest output format (not used in the
/// code visible here; confirm against callers).
pub struct Shortest;
/// Format marker (not used in the code visible here; confirm semantics
/// against callers).
pub struct Combined;
}
impl<'a> Cost<'a, mbot::TextFormatOutput<mbot::Default>> for Program {
    type Param = ();

    /// Tallies a price over the terms yielded by `self.postorder()`.
    ///
    /// A `DiceRoll(count, _)` contributes `count`; every other term
    /// contributes 1. The tally saturates at `u64::MAX` and is always
    /// returned as `Price::Bounded`.
    fn cost(&'a self, _param: Self::Param) -> Price {
        use crate::parse::Term;

        let mut total = 0u64;
        for_! { (term, _ancestors) in self.postorder() => {
            // Work out this term's contribution first, then accumulate once.
            let step = match term {
                Term::DiceRoll(count, _sides) => *count as u64,
                _ => 1,
            };
            total = total.saturating_add(step);
        }}
        Price::Bounded(total)
    }
}