use itertools::Itertools;
use std::collections::HashMap;
use crate::dice::*;
use crate::item_counter::ItemCounter;
#[cfg(test)]
mod tests;
/// One distinct outcome of a roll, described by how many times each
/// symbol appeared.
///
/// Instances are used as keys of a `HashMap`, hence the `Eq`/`Hash` derives.
#[derive(Clone, Hash, PartialEq, Eq)]
struct RollResultPossibility {
    /// Per-symbol tally for this outcome.
    symbols: ItemCounter<DieSymbol>,
}
impl RollResultPossibility {
    /// Creates an outcome in which no symbol has been counted yet.
    pub fn new() -> RollResultPossibility {
        let symbols = ItemCounter::new();
        RollResultPossibility { symbols }
    }

    /// Returns a new outcome equal to `self` with each entry of `symbols`
    /// counted once more. `self` itself is not modified.
    pub fn add_symbols(&self, symbols: &[DieSymbol]) -> RollResultPossibility {
        let mut tally = self.symbols.clone();
        for sym in symbols.iter() {
            tally.add(sym);
        }
        RollResultPossibility { symbols: tally }
    }
}
/// The comparison a `RollTarget` applies between the observed symbol
/// count and its `amount`.
#[derive(Clone, Copy, Eq, PartialEq)]
enum RollTargetTypes {
    /// Satisfied when the count equals `amount`.
    Exactly,
    /// Satisfied when the count is greater than or equal to `amount`.
    AtLeast,
    /// Satisfied when the count is less than or equal to `amount`.
    AtMost,
}
/// A condition on a roll: the combined count of every symbol in `symbols`
/// is compared against `amount` using `target_type`.
///
/// Construct with `RollTarget::exactly_n_of`, `RollTarget::at_least_n_of`
/// or `RollTarget::at_most_n_of`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct RollTarget<'a> {
    /// Which comparison to apply.
    target_type: RollTargetTypes,
    /// The threshold the combined count is compared with.
    amount: usize,
    /// Symbols whose counts are summed before comparing.
    symbols: &'a [DieSymbol],
}
impl<'a> RollTarget<'a> {
pub fn exactly_n_of(n: usize, symbols: &'a [DieSymbol]) -> RollTarget {
RollTarget {
target_type: RollTargetTypes::Exactly,
amount: n,
symbols
}
}
pub fn at_least_n_of(n: usize, symbols: &'a [DieSymbol]) -> RollTarget {
RollTarget {
target_type: RollTargetTypes::AtLeast,
amount: n,
symbols
}
}
pub fn at_most_n_of(n: usize, symbols: &'a [DieSymbol]) -> RollTarget {
RollTarget {
target_type: RollTargetTypes::AtMost,
amount: n,
symbols
}
}
}
/// Which dice of a roll contribute their symbols to the result.
///
/// Dice are ranked by how many tracked symbols their face shows, so
/// "highest" means "contributes the most tracked symbols".
#[derive(Clone, Copy, Eq, PartialEq)]
enum RollCollectionTypes {
    /// Every die contributes.
    CollectAll,
    /// Only the `n` highest-ranked dice contribute.
    TakeHighestN(usize),
    /// Only the `n` lowest-ranked dice contribute.
    TakeLowestN(usize),
    /// The `n` highest-ranked dice are dropped.
    RemoveHighestN(usize),
    /// The `n` lowest-ranked dice are dropped.
    RemoveLowestN(usize),
}
/// Describes which symbols are tracked in a roll and which dice are kept
/// when tallying them.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct RollCollectionPolicy<'a> {
    /// Keep/drop rule applied to the ranked dice.
    coll_type: RollCollectionTypes,
    /// Symbols that are tracked; all others are ignored.
    symbols: &'a [DieSymbol],
}
impl<'a> RollCollectionPolicy<'a> {
pub fn collect_all(symbols: &'a [DieSymbol]) -> RollCollectionPolicy {
RollCollectionPolicy {
coll_type: RollCollectionTypes::CollectAll,
symbols
}
}
pub fn take_highest_n_of(n:usize, symbols: &'a [DieSymbol]) -> RollCollectionPolicy {
RollCollectionPolicy {
coll_type: RollCollectionTypes::TakeHighestN(n),
symbols
}
}
pub fn take_lowest_n_of(n:usize, symbols: &'a [DieSymbol]) -> RollCollectionPolicy {
RollCollectionPolicy {
coll_type: RollCollectionTypes::TakeLowestN(n),
symbols
}
}
pub fn remove_highest_n_of(n:usize, symbols: &'a [DieSymbol]) -> RollCollectionPolicy {
RollCollectionPolicy {
coll_type: RollCollectionTypes::RemoveHighestN(n),
symbols
}
}
pub fn remove_lowest_n_of(n:usize, symbols: &'a [DieSymbol]) -> RollCollectionPolicy {
RollCollectionPolicy {
coll_type: RollCollectionTypes::RemoveLowestN(n),
symbols
}
}
}
/// Frequency table over every distinct outcome of rolling a set of dice.
///
/// Built by `RollProbabilities::new` and queried with
/// `RollProbabilities::get_odds`.
pub struct RollProbabilities {
    /// Number of face combinations that produced each distinct outcome.
    occurrences: HashMap<RollResultPossibility, usize>,
    /// Sum of every value in `occurrences`, i.e. the number of combinations.
    total: usize,
}
impl RollProbabilities {
    /// Reduces one roll to the symbols `policy` tracks and applies the
    /// policy's keep/drop rule, returning the surviving symbols.
    ///
    /// Each die's face is filtered down to the symbols listed in
    /// `policy.symbols`. Dice are then ordered from most to fewest matching
    /// symbols, so "highest" means "contributes the most tracked symbols".
    /// Dice with equal counts end up in the reverse of their order in `roll`
    /// (stable ascending sort followed by `reverse`).
    ///
    /// When `n` exceeds the number of dice, taking `n` dice keeps all of
    /// them and removing `n` dice keeps none.
    fn collect_symbols(roll: &[&DieSide], policy: &RollCollectionPolicy) -> Vec<DieSymbol> {
        let mut filtered_sides: Vec<Vec<DieSymbol>> = roll
            .iter()
            .map(|side| {
                side.symbols()
                    .iter()
                    .filter(|symbol| policy.symbols.contains(symbol))
                    .cloned()
                    .collect()
            })
            .collect();
        // Stable ascending sort, then reverse: most matching symbols first.
        // Kept as two steps (rather than a descending sort) because it fixes
        // the tie-break order, which decides which symbol kinds survive a
        // Take/Remove policy when several dice have the same count.
        filtered_sides.sort_by_key(|symbols| symbols.len());
        filtered_sides.reverse();
        let sides_len = filtered_sides.len();
        // `saturating_sub` below: plain subtraction underflows when `n`
        // exceeds the number of dice (panic in debug, wrap-around in release).
        match policy.coll_type {
            RollCollectionTypes::CollectAll => filtered_sides.into_iter().flatten().collect(),
            RollCollectionTypes::TakeHighestN(n) => {
                filtered_sides.into_iter().take(n).flatten().collect()
            }
            RollCollectionTypes::TakeLowestN(n) => filtered_sides
                .into_iter()
                .skip(sides_len.saturating_sub(n))
                .flatten()
                .collect(),
            RollCollectionTypes::RemoveHighestN(n) => {
                filtered_sides.into_iter().skip(n).flatten().collect()
            }
            RollCollectionTypes::RemoveLowestN(n) => filtered_sides
                .into_iter()
                .take(sides_len.saturating_sub(n))
                .flatten()
                .collect(),
        }
    }

    /// Enumerates every combination of faces of `dice`, collects each
    /// combination's symbols through `policy`, and counts how often each
    /// distinct outcome occurs.
    ///
    /// # Errors
    /// Returns `Err` if `dice` is empty.
    pub fn new(dice: &[Die], policy: RollCollectionPolicy) -> Result<RollProbabilities, String> {
        if dice.is_empty() {
            return Err("must include at least one die".to_string());
        }
        let mut occurrences = HashMap::new();
        for roll in dice.iter().map(|die| die.sides()).multi_cartesian_product() {
            let collected = Self::collect_symbols(&roll, &policy);
            let possibility = RollResultPossibility::new().add_symbols(&collected);
            *occurrences.entry(possibility).or_insert(0) += 1;
        }
        let total: usize = occurrences.values().sum();
        Ok(RollProbabilities { occurrences, total })
    }

    /// Returns the fraction of enumerated combinations that satisfy `target`.
    ///
    /// For each distinct outcome, the counts of all symbols in
    /// `target.symbols` are summed and compared with `target.amount` using
    /// the target's comparison. Returns `0.0` if no combinations were
    /// enumerated.
    pub fn get_odds(&self, target: RollTarget) -> f64 {
        if self.total == 0 {
            return 0.0;
        }
        let matching: usize = self
            .occurrences
            .iter()
            .filter(|(possibility, _)| {
                let count: usize = target
                    .symbols
                    .iter()
                    .map(|symbol| possibility.symbols.get_count(symbol))
                    .sum();
                match target.target_type {
                    RollTargetTypes::Exactly => count == target.amount,
                    RollTargetTypes::AtLeast => count >= target.amount,
                    RollTargetTypes::AtMost => count <= target.amount,
                }
            })
            .map(|(_, &weight)| weight)
            .sum();
        matching as f64 / self.total as f64
    }
}