use crate::traits::PlotResult;
use crate::traits::PlotTable;
use crate::traits::RollResult;
use rand::Rng;
use std::collections::HashMap;
use crate::operators::{
advantage, difference, disadvantage, divide, equal_to, greater_than, greater_than_or_equal_to,
less_than, less_than_or_equal_to, multiply, sum, BinaryOperator,
};
use crate::traits::Rollable;
/// The relational test carried by [`Expression::Compare`].
///
/// Each variant selects the operator of the matching name from
/// `crate::operators` when the expression is evaluated. The enum has no
/// payload, so it is `Copy` and `Eq` and can be passed around by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Comparison {
    /// Evaluated with the `greater_than` operator.
    GreaterThan,
    /// Evaluated with the `greater_than_or_equal_to` operator.
    GreaterThanOrEqualTo,
    /// Evaluated with the `less_than` operator.
    LessThan,
    /// Evaluated with the `less_than_or_equal_to` operator.
    LessThanOrEqualTo,
    /// Evaluated with the `equal_to` operator.
    EqualTo,
}
use Comparison::*;
/// A dice expression tree.
///
/// Leaves are `Die` and `Constant`; every other variant boxes its operand
/// expression(s) and is evaluated by combining the operands' results with an
/// operator from `crate::operators`.
#[derive(Clone, Debug, PartialEq)]
pub enum Expression {
/// A die with the given number of faces. Rolling draws from
/// `gen_range(1, faces + 1)`; plotting gives each face `1..faces + 1` an
/// equal chance.
Die(u32),
/// A fixed value: rolls to itself and plots as a single certain outcome.
Constant(i32),
/// Left and right operands combined with the `sum` operator.
Sum(Box<Expression>, Box<Expression>),
/// Left and right operands combined with the `difference` operator.
Diff(Box<Expression>, Box<Expression>),
/// Left and right operands combined with the `multiply` operator.
Multiply(Box<Expression>, Box<Expression>),
/// Left and right operands combined with the `divide` operator.
Divide(Box<Expression>, Box<Expression>),
/// The inner expression is evaluated twice (it is used as both operands)
/// and the two results are combined with the `advantage` operator.
Advantage(Box<Expression>),
/// The inner expression is evaluated twice (it is used as both operands)
/// and the two results are combined with the `disadvantage` operator.
Disadvantage(Box<Expression>),
/// Left and right operands combined with the operator selected by the
/// `Comparison`.
Compare(Box<Expression>, Box<Expression>, Comparison),
}
use Expression::*;
impl Expression {
fn get_operation(&self) -> Option<(BinaryOperator, &Box<Expression>, &Box<Expression>)> {
match self {
Constant(_) => None,
Die(_) => None,
Sum(left, right) => Some((sum, left, right)),
Diff(left, right) => Some((difference, left, right)),
Multiply(left, right) => Some((multiply, left, right)),
Divide(left, right) => Some((divide, left, right)),
Advantage(expr) => Some((advantage, expr, expr)),
Disadvantage(expr) => Some((disadvantage, expr, expr)),
Compare(left, right, comparison) => {
let compare = match comparison {
GreaterThan => greater_than,
GreaterThanOrEqualTo => greater_than_or_equal_to,
LessThan => less_than,
LessThanOrEqualTo => less_than_or_equal_to,
EqualTo => equal_to,
};
Some((compare, left, right))
}
}
}
}
impl Rollable for Expression {
    /// Evaluates the expression once.
    ///
    /// A `Constant` yields its value, a `Die` draws from
    /// `gen_range(1, faces + 1)` on the thread-local RNG, and every other
    /// variant rolls its left operand, then its right operand, and applies
    /// the operator reported by `get_operation`.
    fn roll(&self) -> RollResult {
        match self {
            Constant(value) => *value,
            Die(faces) => rand::thread_rng().gen_range(1, *faces as RollResult + 1),
            _ => {
                let (apply, lhs, rhs) = self
                    .get_operation()
                    .expect("expression does not represent an operation");
                // Left operand is rolled before the right one.
                let lhs_roll = lhs.roll();
                let rhs_roll = rhs.roll();
                apply(&lhs_roll, &rhs_roll)
            }
        }
    }

    /// Builds the table of outcomes and their chances for the expression.
    ///
    /// `total` is 1 for a `Constant`, the face count for a `Die`, and the
    /// product of the operands' totals for a composite. For a composite,
    /// every pairing of a left outcome with a right outcome is passed through
    /// the operator and the product of the two chances is added to the
    /// resulting value's entry.
    fn plot(&self) -> PlotResult {
        match self {
            Constant(value) => {
                let mut certain: PlotTable = HashMap::new();
                certain.insert(*value, 1.0);
                PlotResult {
                    total: 1.0,
                    plot: certain,
                }
            }
            Die(faces) => {
                let total = *faces as f32;
                let mut uniform: PlotTable = HashMap::new();
                for face in 1..*faces + 1 {
                    uniform.insert(face as RollResult, 1.0 / total);
                }
                PlotResult {
                    total,
                    plot: uniform,
                }
            }
            _ => {
                let (apply, lhs, rhs) = self
                    .get_operation()
                    .expect("expression does not represent an operation");
                let lhs = lhs.plot();
                let rhs = rhs.plot();

                let mut combined: PlotTable = HashMap::new();
                for (lhs_value, lhs_chance) in &lhs.plot {
                    for (rhs_value, rhs_chance) in &rhs.plot {
                        let outcome = apply(lhs_value, rhs_value);
                        *combined.entry(outcome).or_insert(0.0) += lhs_chance * rhs_chance;
                    }
                }

                PlotResult {
                    total: lhs.total * rhs.total,
                    plot: combined,
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    impl PlotResult {
        /// Converts the chance table into whole outcome counts.
        ///
        /// Each chance is scaled by `total` and rounded to the nearest
        /// integer. Rounding rather than truncating matters because the
        /// chances are accumulated `f32` sums: a product that lands a hair
        /// below the exact count (e.g. `5.9999995`) must still be reported
        /// as 6, not 5.
        pub fn simplify(&self) -> HashMap<i32, i32> {
            self.plot
                .iter()
                .map(|(value, chance)| {
                    let outcomes = (chance * self.total).round() as i32;
                    (*value, outcomes)
                })
                .collect()
        }
    }

    /// `simplify` turns chances back into counts out of `total`.
    #[test]
    fn simplify_produces_correct_table() {
        let plot_result = PlotResult {
            total: 10.0,
            plot: [(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)]
                .iter()
                .cloned()
                .collect(),
        };
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 2), (3, 3), (4, 4)].iter().cloned().collect();
        let actual = plot_result.simplify();
        assert_eq!(expected, actual);
    }

    /// Product of two four-faced dice.
    #[test]
    fn multiply_produces_correct_plot() {
        let expression =
            Expression::Multiply(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> = [
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 3),
            (6, 2),
            (8, 2),
            (9, 1),
            (12, 2),
            (16, 1),
        ]
        .iter()
        .cloned()
        .collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// Quotient of two four-faced dice.
    #[test]
    fn divide_produces_correct_plot() {
        let expression =
            Expression::Divide(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> = [(0, 6), (1, 6), (2, 2), (3, 1), (4, 1)]
            .iter()
            .cloned()
            .collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// Sum of two six-faced dice.
    #[test]
    fn sum_produces_correct_plot() {
        let expression =
            Expression::Sum(Box::new(Expression::Die(6)), Box::new(Expression::Die(6)));
        let expected: HashMap<i32, i32> = [
            (2, 1),
            (3, 2),
            (4, 3),
            (5, 4),
            (6, 5),
            (7, 6),
            (8, 5),
            (9, 4),
            (10, 3),
            (11, 2),
            (12, 1),
        ]
        .iter()
        .cloned()
        .collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// Difference of two four-faced dice.
    #[test]
    fn difference_produces_correct_plot() {
        let expression =
            Expression::Diff(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(-3, 1), (-2, 2), (-1, 3), (0, 4), (1, 3), (2, 2), (3, 1)]
                .iter()
                .cloned()
                .collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// `Advantage` over a four-faced die.
    #[test]
    fn advantage_produces_correct_plot() {
        let expression = Expression::Advantage(Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 3), (3, 5), (4, 7)].iter().cloned().collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// `Disadvantage` over a four-faced die.
    #[test]
    fn disadvantage_produces_correct_plot() {
        let expression = Expression::Disadvantage(Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(1, 7), (2, 5), (3, 3), (4, 1)].iter().cloned().collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// Placeholder for a three-way contest result; ignored until that
    /// behaviour is implemented.
    #[test]
    #[ignore = "not implemented"]
    fn contest_produces_correct_plot() {
        let expression = Expression::Compare(
            Box::new(Expression::Die(2)),
            Box::new(Expression::Die(3)),
            Comparison::GreaterThan,
        );
        let expected: HashMap<i32, i32> = [(-1, 3), (0, 2), (1, 1)].iter().cloned().collect();
        let actual = expression.plot().simplify();
        assert_eq!(expected, actual);
    }

    /// Every `Comparison` variant applied to a pair of three-faced dice.
    #[test]
    fn compare_produces_correct_plot() {
        let left = Expression::Die(3);
        let right = Expression::Die(3);
        let cases: Vec<(Comparison, &[(i32, i32)])> = vec![
            (Comparison::GreaterThan, &[(0, 6), (1, 3)]),
            (Comparison::GreaterThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::LessThan, &[(0, 6), (1, 3)]),
            (Comparison::LessThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::EqualTo, &[(0, 6), (1, 3)]),
        ];
        for (comparison, options) in cases {
            let expression = Expression::Compare(
                Box::new(left.clone()),
                Box::new(right.clone()),
                comparison.clone(),
            );
            let expected: HashMap<i32, i32> = options.iter().cloned().collect();
            let actual = expression.plot().simplify();
            assert_eq!(expected, actual, "{:?}", comparison);
        }
    }
}