use crate::EvaluateError;
use crate::ast::{DiceExpr, Expr};
use rand::Rng;
use rand::distr::{Uniform, uniform};
use rand::prelude::ThreadRng;
use std::ops::Range;
use thiserror::Error;
use tracing::{debug, trace};
#[derive(Error, Debug)]
pub enum RandomSourceErrors {
#[error(transparent)]
UniformError(#[from] uniform::Error),
}
/// A provider of random values of type `Bound`, used to roll dice.
///
/// Callers in this module pass half-open ranges of acceptable faces
/// (for example `1..sides + 1` for a die with `sides` faces).
pub trait RandomSource<Bound> {
    /// Produces one value for the half-open interval `range`, or an error if
    /// the source cannot produce one.
    fn random_range(
        &mut self,
        range: Range<Bound>,
    ) -> Result<Bound, RandomSourceErrors>;
}
/// Something that evaluates to an `Out`, drawing any randomness it needs from
/// a [`RandomSource`] that yields `Bound` values.
pub trait StochasticEvaluable<Bound, Out> {
    /// Evaluates `self`, pulling random values from `rand_src` as required.
    fn evaluate<R>(&self, rand_src: &mut R) -> Result<Out, EvaluateError>
    where
        R: RandomSource<Bound>;
}
/// [`RandomSource`] backed by the `rand` crate's thread-local generator.
pub struct RandCrateSource {
    source: ThreadRng,
}

impl RandCrateSource {
    /// Creates a source drawing from the current thread's `rand::rng()`.
    pub fn new() -> RandCrateSource {
        RandCrateSource {
            source: rand::rng(),
        }
    }
}

// A public zero-argument `new` should be mirrored by `Default`
// (clippy::new_without_default) so the type works with `Default`-based APIs.
impl Default for RandCrateSource {
    /// Equivalent to [`RandCrateSource::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RandomSource<u32> for RandCrateSource {
    /// Samples uniformly from `range` using the thread-local generator.
    ///
    /// A range rejected by `Uniform::try_from` is propagated as
    /// [`RandomSourceErrors::UniformError`] via `?`.
    #[mutants::skip]
    fn random_range(&mut self, range: Range<u32>) -> Result<u32, RandomSourceErrors> {
        let distribution = Uniform::try_from(range)?;
        let drawn = self.source.sample(distribution);
        Ok(drawn)
    }
}
/// Which end of a requested range a [`FixedSource`] always returns.
// Public type: derive the standard traits so callers can print, copy and
// compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FixedValue {
    /// Always return the lowest value of the range (`range.start`).
    Minimum,
    /// Always return the highest value of the range (`range.end - 1`).
    Maximum,
}
pub struct FixedSource {
pub value: FixedValue,
}
impl RandomSource<u32> for FixedSource {
fn random_range(&mut self, range: Range<u32>) -> Result<u32, RandomSourceErrors> {
match self.value {
FixedValue::Minimum => Ok(range.start),
FixedValue::Maximum => Ok(range.end - 1),
}
}
}
impl StochasticEvaluable<u32, i64> for Expr {
    /// Reduces the expression tree to a single integer.
    ///
    /// Binary operators evaluate their left operand before their right one, so
    /// the random source is consulted in left-to-right order. Each node logs its
    /// value at `debug` level; raw dice results are logged at `trace` level.
    ///
    /// NOTE(review): the `+`, `-`, `*` and negation on `i64` are unchecked and
    /// panic on overflow in debug builds.
    ///
    /// # Errors
    ///
    /// Propagates any [`EvaluateError`] raised by a sub-expression or dice roll.
    fn evaluate<R: RandomSource<u32>>(&self, rand_src: &mut R) -> Result<i64, EvaluateError> {
        // Evaluates both operands of a binary node, left operand first.
        fn both<S: RandomSource<u32>>(
            left: &Expr,
            right: &Expr,
            rand_src: &mut S,
        ) -> Result<(i64, i64), EvaluateError> {
            let lhs = left.evaluate(rand_src)?;
            let rhs = right.evaluate(rand_src)?;
            Ok((lhs, rhs))
        }

        match self {
            Expr::Number(n) => {
                debug!("Evaluation: Number = {}", n);
                Ok(*n as i64)
            }
            Expr::Add(e1, e2) => {
                let (lhs, rhs) = both(e1, e2, rand_src)?;
                let sum = lhs + rhs;
                debug!("Evaluation: Add {} + {} = {}", lhs, rhs, sum);
                Ok(sum)
            }
            Expr::Subtract(e1, e2) => {
                let (lhs, rhs) = both(e1, e2, rand_src)?;
                let difference = lhs - rhs;
                debug!("Evaluation: Subtract {} - {} = {}", lhs, rhs, difference);
                Ok(difference)
            }
            Expr::Multiply(e1, e2) => {
                let (lhs, rhs) = both(e1, e2, rand_src)?;
                let product = lhs * rhs;
                debug!("Evaluation: Multiply {} * {} = {}", lhs, rhs, product);
                Ok(product)
            }
            Expr::Negate(inner) => {
                let value = inner.evaluate(rand_src)?;
                debug!("Evaluation: Negate {}", value);
                Ok(-value)
            }
            Expr::DiceTotal(dice_expr) => {
                let rolled = dice_expr.evaluate(rand_src)?;
                trace!("Dices results: {:?}", rolled);
                let total: i64 = rolled.results.iter().map(|&face| i64::from(face)).sum();
                debug!("Evaluation: DiceTotal {}", total);
                Ok(total)
            }
        }
    }
}
/// Intermediate result of evaluating a dice expression.
#[derive(Debug)]
struct AllDiceResults {
    /// Values of the dice that survived every modifier applied so far.
    /// After a counting modifier this holds a single element: the count.
    pub results: Vec<u32>,
    /// Number of faces of the dice that were rolled.
    pub number_of_sides: u32,
}
impl StochasticEvaluable<u32, AllDiceResults> for DiceExpr {
    /// Rolls the dice of this expression and applies its modifier, returning
    /// the surviving dice values together with the die size.
    ///
    /// Keep/drop counts larger than the number of rolled dice are clamped:
    /// keeping more dice than were rolled keeps all of them, and dropping more
    /// than were rolled leaves none. This avoids panicking on an integer
    /// underflow or an out-of-bounds slice.
    ///
    /// # Errors
    ///
    /// Random-source failures are wrapped in [`EvaluateError::OtherErrors`].
    /// Errors from the reroll/minimum/maximum helpers are propagated when their
    /// argument does not fit the die.
    fn evaluate<R: RandomSource<u32>>(
        &self,
        rand_src: &mut R,
    ) -> Result<AllDiceResults, EvaluateError> {
        match self {
            DiceExpr::Dice(number_of_dices, number_of_sides) => {
                let mut results: Vec<u32> = Vec::new();
                for _ in 0..*number_of_dices {
                    // Faces are 1..=number_of_sides.
                    // NOTE(review): `number_of_sides + 1` overflows for u32::MAX sides.
                    let dice_rolled = rand_src
                        .random_range(1..*number_of_sides + 1)
                        .map_err(|err| EvaluateError::OtherErrors(Box::new(err)))?;
                    results.push(dice_rolled);
                }
                results.sort_unstable();
                debug!(
                    "Dice Evaluation: Dice {}d{}: {:?}",
                    number_of_dices, number_of_sides, results
                );
                Ok(AllDiceResults {
                    results,
                    number_of_sides: *number_of_sides,
                })
            }
            DiceExpr::ModKeepHighest(expr, to_keep) => {
                let mut results = expr.evaluate(rand_src)?;
                results.results.sort_unstable();
                // Discard the lowest dice; keeping more than were rolled keeps all.
                let surplus = results.results.len().saturating_sub(*to_keep as usize);
                results.results.drain(..surplus);
                debug!(
                    "Dice Evaluation: Keep Highest({}): {:?}",
                    to_keep, results.results
                );
                Ok(results)
            }
            DiceExpr::ModKeepLowest(expr, to_keep) => {
                let mut results = expr.evaluate(rand_src)?;
                results.results.sort_unstable();
                // `truncate` is a no-op when keeping more dice than were rolled.
                results.results.truncate(*to_keep as usize);
                debug!(
                    "Dice Evaluation: Keep Lowest({}): {:?}",
                    to_keep, results.results
                );
                Ok(results)
            }
            DiceExpr::ModDropHighest(expr, to_drop) => {
                let mut results = expr.evaluate(rand_src)?;
                results.results.sort_unstable();
                // Dropping more dice than were rolled leaves none.
                let kept = results.results.len().saturating_sub(*to_drop as usize);
                results.results.truncate(kept);
                debug!(
                    "Dice Evaluation: Drop Highest({}): {:?}",
                    to_drop, results.results
                );
                Ok(results)
            }
            DiceExpr::ModDropLowest(expr, to_drop) => {
                let mut results = expr.evaluate(rand_src)?;
                results.results.sort_unstable();
                // Dropping more dice than were rolled leaves none.
                let dropped = results.results.len().min(*to_drop as usize);
                results.results.drain(..dropped);
                debug!(
                    "Dice Evaluation: Drop Lowest({}): {:?}",
                    to_drop, results.results
                );
                Ok(results)
            }
            DiceExpr::ModReroll(expr, reroll_under) => {
                let results = expr.evaluate(rand_src)?;
                let rerolled_result = Self::reroll(rand_src, &results, reroll_under)?;
                debug!(
                    "Dice Evaluation: Reroll({}): {:?}",
                    reroll_under, rerolled_result
                );
                Ok(AllDiceResults {
                    results: rerolled_result,
                    number_of_sides: results.number_of_sides,
                })
            }
            DiceExpr::ModRerollOnce(expr, reroll_under) => {
                let results = expr.evaluate(rand_src)?;
                let rerolled_result = Self::reroll_once(rand_src, &results, *reroll_under)?;
                debug!(
                    "Dice Evaluation: Reroll Once({}): {:?}",
                    reroll_under, rerolled_result
                );
                Ok(AllDiceResults {
                    results: rerolled_result,
                    number_of_sides: results.number_of_sides,
                })
            }
            DiceExpr::ModMininum(expr, minimum) => {
                let results = expr.evaluate(rand_src)?;
                let new_result = Self::upgrade_to(&results, *minimum)?;
                debug!("Dice Evaluation: Minimum({}): {:?}", minimum, new_result);
                Ok(AllDiceResults {
                    results: new_result,
                    number_of_sides: results.number_of_sides,
                })
            }
            DiceExpr::ModMaximum(expr, maximum) => {
                let results = expr.evaluate(rand_src)?;
                let new_result = Self::downgrade_to(&results, *maximum)?;
                debug!("Dice Evaluation: Maximum({}): {:?}", maximum, new_result);
                Ok(AllDiceResults {
                    results: new_result,
                    number_of_sides: results.number_of_sides,
                })
            }
            DiceExpr::ModCountGreaterOrEqual(expr, pivot) => {
                let results = expr.evaluate(rand_src)?;
                let count = Self::count_greater_or_equal(&results, pivot);
                debug!(
                    "Dice Evaluation: Count Greater or Equal({}): {:?}",
                    pivot, count
                );
                // The count replaces the dice values; the die size is kept.
                Ok(AllDiceResults {
                    results: vec![count],
                    number_of_sides: results.number_of_sides,
                })
            }
            DiceExpr::ModCountLowerOrEqual(expr, pivot) => {
                let results = expr.evaluate(rand_src)?;
                let count = Self::count_lower_or_equal(&results, pivot);
                debug!(
                    "Dice Evaluation: Count Lower or Equal({}): {:?}",
                    pivot, count
                );
                // The count replaces the dice values; the die size is kept.
                Ok(AllDiceResults {
                    results: vec![count],
                    number_of_sides: results.number_of_sides,
                })
            }
        }
    }
}
impl DiceExpr {
    /// Rerolls every die showing `reroll_under` or less ("reroll until higher").
    ///
    /// Instead of looping, each such die is drawn once from the faces strictly
    /// above the threshold (`reroll_under + 1..number_of_sides + 1`). Dice above
    /// the threshold are kept, and dice order is preserved.
    ///
    /// NOTE(review): when `reroll_under == number_of_sides` the draw range is
    /// empty; the outcome then depends on how `rand_src` treats empty ranges.
    ///
    /// # Errors
    ///
    /// * [`EvaluateError::RerollValueHigherThanDiceSides`] if `reroll_under`
    ///   exceeds the die size.
    /// * [`EvaluateError::OtherErrors`] if the random source fails; remaining
    ///   dice are not rolled.
    #[mutants::skip]
    fn reroll<R: RandomSource<u32>>(
        rand_src: &mut R,
        results: &AllDiceResults,
        reroll_under: &u32,
    ) -> Result<Vec<u32>, EvaluateError> {
        let sides = results.number_of_sides;
        let threshold = *reroll_under;
        if threshold > sides {
            return Err(EvaluateError::RerollValueHigherThanDiceSides {
                reroll: threshold,
                dice_sides: sides,
            });
        }
        results
            .results
            .iter()
            .map(|&rolled| {
                if rolled > threshold {
                    Ok(rolled)
                } else {
                    rand_src
                        .random_range(threshold + 1..sides + 1)
                        .map_err(|err| EvaluateError::OtherErrors(Box::new(err)))
                }
            })
            .collect()
    }

    /// Counts the dice whose value is at most `pivot`.
    fn count_lower_or_equal(results: &AllDiceResults, pivot: &u32) -> u32 {
        results
            .results
            .iter()
            .filter(|&value| value <= pivot)
            .fold(0, |count, _| count + 1)
    }

    /// Counts the dice whose value is at least `pivot`.
    fn count_greater_or_equal(results: &AllDiceResults, pivot: &u32) -> u32 {
        results
            .results
            .iter()
            .filter(|&value| value >= pivot)
            .fold(0, |count, _| count + 1)
    }

    /// Caps every die at `maximum`.
    ///
    /// # Errors
    ///
    /// [`EvaluateError::MaximumValueCannotBe0`] if `maximum` is zero.
    fn downgrade_to(results: &AllDiceResults, maximum: u32) -> Result<Vec<u32>, EvaluateError> {
        if maximum == 0 {
            return Err(EvaluateError::MaximumValueCannotBe0);
        }
        Ok(results
            .results
            .iter()
            .map(|&value| value.min(maximum))
            .collect())
    }

    /// Raises every die below `minimum` up to `minimum`.
    ///
    /// # Errors
    ///
    /// [`EvaluateError::MinimumValueHigherThanDiceSides`] if `minimum` exceeds
    /// the die size.
    fn upgrade_to(results: &AllDiceResults, minimum: u32) -> Result<Vec<u32>, EvaluateError> {
        if minimum > results.number_of_sides {
            return Err(EvaluateError::MinimumValueHigherThanDiceSides {
                minimum,
                dice_size: results.number_of_sides,
            });
        }
        Ok(results
            .results
            .iter()
            .map(|&value| value.max(minimum))
            .collect())
    }

    /// Rerolls once, over the full die, every die showing `reroll_under` or less;
    /// the new value is kept whatever it is. Dice order is preserved.
    ///
    /// # Errors
    ///
    /// * [`EvaluateError::RerollValueHigherThanDiceSides`] if `reroll_under`
    ///   exceeds the die size.
    /// * [`EvaluateError::OtherErrors`] if the random source fails; remaining
    ///   dice are not rolled.
    fn reroll_once<R: RandomSource<u32>>(
        rand_src: &mut R,
        results: &AllDiceResults,
        reroll_under: u32,
    ) -> Result<Vec<u32>, EvaluateError> {
        let sides = results.number_of_sides;
        if reroll_under > sides {
            return Err(EvaluateError::RerollValueHigherThanDiceSides {
                reroll: reroll_under,
                dice_sides: sides,
            });
        }
        results
            .results
            .iter()
            .map(|&rolled| {
                if rolled > reroll_under {
                    Ok(rolled)
                } else {
                    rand_src
                        .random_range(1..sides + 1)
                        .map_err(|err| EvaluateError::OtherErrors(Box::new(err)))
                }
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use crate::ast::Parser;
    use crate::ast_evaluation::{
        FixedSource, FixedValue, RandomSource, RandomSourceErrors, StochasticEvaluable,
    };
    use std::ops::Range;

    // ---- evaluation helpers -------------------------------------------------

    /// Parses `input` and evaluates it with `source`, panicking on any failure.
    fn roll<S: RandomSource<u32>>(input: &str, source: &mut S) -> i64 {
        let expr = Parser::parse(input).unwrap();
        expr.evaluate(source).unwrap()
    }

    /// Evaluates `input` with every die showing `face`.
    fn roll_with_face(input: &str, face: u32) -> i64 {
        roll(input, &mut TestFixedValue::from(face))
    }

    /// Evaluates `input` with dice faces taken cyclically from `faces`.
    fn roll_with_faces(input: &str, faces: Vec<u32>) -> i64 {
        roll(input, &mut TestVecValues::new(faces))
    }

    /// Evaluates `input` with a `FixedSource` pinned to `value`.
    fn roll_fixed(input: &str, value: FixedValue) -> i64 {
        roll(input, &mut FixedSource { value })
    }

    /// Reports whether evaluating `input` with faces from `faces` fails.
    fn evaluation_fails(input: &str, faces: Vec<u32>) -> bool {
        let expr = Parser::parse(input).unwrap();
        expr.evaluate(&mut TestVecValues::new(faces)).is_err()
    }

    // ---- arithmetic ---------------------------------------------------------

    #[test]
    fn evaluate_d6() {
        let face: u32 = 2;
        assert_eq!(face as i64, roll_with_face("d6", face));
    }

    #[test]
    fn evaluate_2d6() {
        assert_eq!(4, roll_with_face("2d6", 2));
    }

    #[test]
    fn evaluate_2d6_plus_1() {
        assert_eq!(5, roll_with_face("2d6+1", 2));
    }

    #[test]
    fn evaluate_2d6_minus_1() {
        assert_eq!(3, roll_with_face("2d6-1", 2));
    }

    #[test]
    fn evaluate_2d6_multiplied_by_3() {
        assert_eq!(12, roll_with_face("2d6*3", 2));
    }

    #[test]
    fn evaluate_order_of_operations() {
        assert_eq!(14, roll_with_face("2d6*3+2", 2));
    }

    #[test]
    fn evaluate_parenthesis() {
        assert_eq!(20, roll_with_face("2d6*(3+2)", 2));
    }

    #[test]
    fn evaluate_negative() {
        assert_eq!(-2, roll_with_face("-d6", 2));
    }

    #[test]
    fn evaluate_addition_with_negative() {
        assert_eq!(5, roll_with_face("3d6 + -1", 2));
    }

    #[test]
    fn evaluate_addition_with_two_dices() {
        assert_eq!(6, roll_with_face("2d6 + 1d4", 2));
    }

    // ---- keep / drop --------------------------------------------------------

    #[test]
    fn evaluate_keep_highest() {
        assert_eq!(4, roll_with_faces("2d6k1", vec![2, 4]));
    }

    #[test]
    fn evaluate_keep_highest_x() {
        assert_eq!(10, roll_with_faces("4d6k2", vec![1, 2, 4, 6]));
    }

    #[test]
    fn evaluate_keep_lowest() {
        assert_eq!(2, roll_with_faces("2d6kl1", vec![2, 4]));
    }

    #[test]
    fn evaluate_keep_lowest_x() {
        assert_eq!(3, roll_with_faces("4d6kl2", vec![1, 2, 4, 6]));
    }

    #[test]
    fn evaluate_drop_highest() {
        assert_eq!(2, roll_with_faces("2d6dh1", vec![2, 4]));
    }

    #[test]
    fn evaluate_drop_highest_x() {
        assert_eq!(3, roll_with_faces("4d6dh2", vec![1, 2, 4, 6]));
    }

    #[test]
    fn evaluate_drop_lowest() {
        assert_eq!(4, roll_with_faces("2d6d1", vec![2, 4]));
    }

    #[test]
    fn evaluate_drop_lowest_x() {
        assert_eq!(10, roll_with_faces("4d6d2", vec![1, 2, 4, 6]));
    }

    // ---- rerolls ------------------------------------------------------------

    #[test]
    fn evaluate_reroll() {
        assert_eq!(80, roll("10d10r5", &mut TestVecValues::from_nb_faces(10)));
    }

    #[test]
    fn evaluate_reroll_higher_than_max_roll() {
        assert!(evaluation_fails("2d6r7", vec![1]));
    }

    #[test]
    fn evaluate_reroll_once() {
        assert_eq!(39, roll("10d6ro2", &mut TestVecValues::from_nb_faces(6)));
    }

    #[test]
    fn evaluate_reroll_once_higher_than_max_roll() {
        assert!(evaluation_fails("2d6ro7", vec![1]));
    }

    // ---- minimum / maximum --------------------------------------------------

    #[test]
    fn evaluate_minimum() {
        assert_eq!(6, roll_with_faces("2d6mi3", vec![1]));
    }

    #[test]
    fn evaluate_minimum_higher_than_max_roll() {
        assert!(evaluation_fails("2d6mi7", vec![1]));
    }

    #[test]
    fn evaluate_maximum() {
        assert_eq!(5, roll_with_faces("2d6ma3", vec![2, 4]));
    }

    // ---- counting -----------------------------------------------------------

    #[test]
    fn evaluate_count_greater() {
        assert_eq!(2, roll_with_faces("3d6>3", vec![1, 3, 4]));
    }

    #[test]
    fn evaluate_count_lower() {
        assert_eq!(2, roll_with_faces("2d6<3", vec![2, 3]));
    }

    // ---- FixedSource --------------------------------------------------------

    #[test]
    fn evaluate_minimum_double_dice() {
        assert_eq!(2, roll_fixed("2d6", FixedValue::Minimum));
    }

    #[test]
    fn evaluate_minimum_addition() {
        assert_eq!(3, roll_fixed("d8 + 2", FixedValue::Minimum));
    }

    #[test]
    fn evaluate_maximum_double_dice() {
        assert_eq!(12, roll_fixed("2d6", FixedValue::Maximum));
    }

    #[test]
    fn evaluate_maximum_addition() {
        assert_eq!(10, roll_fixed("d8 + 2", FixedValue::Maximum));
    }

    // ---- test random sources ------------------------------------------------

    /// Source that answers every draw with the same value, ignoring the range.
    struct TestFixedValue<T> {
        value: T,
    }

    impl RandomSource<u32> for TestFixedValue<u32> {
        fn random_range(&mut self, _: Range<u32>) -> Result<u32, RandomSourceErrors> {
            Ok(self.value)
        }
    }

    impl From<u32> for TestFixedValue<u32> {
        fn from(value: u32) -> Self {
            Self { value }
        }
    }

    /// Source that cycles through `values`, skipping any value outside the
    /// requested range. It loops forever if no value fits.
    struct TestVecValues<T> {
        values: Vec<T>,
        current_index: usize,
    }

    impl<T> TestVecValues<T> {
        pub fn new(values: Vec<T>) -> TestVecValues<T> {
            Self {
                values,
                current_index: 0,
            }
        }
    }

    impl TestVecValues<u32> {
        /// Builds a source cycling through every face `1..=faces` in order.
        pub fn from_nb_faces(faces: u32) -> TestVecValues<u32> {
            Self::new((1..faces + 1).collect())
        }
    }

    impl<T: Copy> TestVecValues<T> {
        /// Returns the current value and advances the cursor, wrapping at the end.
        fn next_value(&mut self) -> T {
            let value = self.values[self.current_index];
            self.current_index = (self.current_index + 1) % self.values.len();
            value
        }
    }

    impl RandomSource<u32> for TestVecValues<u32> {
        fn random_range(&mut self, range: Range<u32>) -> Result<u32, RandomSourceErrors> {
            loop {
                let candidate = self.next_value();
                if range.contains(&candidate) {
                    return Ok(candidate);
                }
            }
        }
    }
}