use fraction::One;
use super::{
dice::Dice,
dice_string_parser::{self, DiceBuildingError},
};
use core::panic;
use std::{
collections::HashMap,
fmt::Display,
ops::{Add, Mul},
};
pub type Value = i64;
pub type Prob = fraction::BigFraction;
pub type AggrValue = fraction::BigFraction;
type Distribution = Box<dyn Iterator<Item = (Value, Prob)>>;
pub type DistributionHashMap = HashMap<Value, Prob>;
/// Expression tree describing a dice roll.
///
/// Leaves are constants or fair dice. Every compound variant combines the
/// outcomes of its children from left to right.
#[derive(PartialEq, Eq, Debug)]
pub enum DiceBuilder {
    /// A fixed value that occurs with probability 1.
    Constant(Value),
    /// A fair die, uniformly distributed over `min..=max` (both ends inclusive).
    FairDie { min: Value, max: Value },
    /// Sum of the children's outcomes (rendered as `a+b+...`).
    SumCompound(Vec<DiceBuilder>),
    /// Product of the children's outcomes (rendered as `a*b*...`).
    ProductCompound(Vec<DiceBuilder>),
    /// Children's outcomes divided left to right with `rounded_div::i64` (rendered as `a/b/...`).
    DivisionCompound(Vec<DiceBuilder>),
    /// Largest of the children's outcomes (rendered as `max(a,b,...)`).
    MaxCompound(Vec<DiceBuilder>),
    /// Smallest of the children's outcomes (rendered as `min(a,b,...)`).
    MinCompound(Vec<DiceBuilder>),
    /// Sample sum (rendered as `axb...`).
    ///
    /// The outcome of one child is the number of samples of the next child
    /// that are summed; longer chains are folded left to right.
    SampleSumCompound(Vec<DiceBuilder>),
}
impl DiceBuilder {
    /// Parses a dice expression string into a [`DiceBuilder`] tree.
    ///
    /// Delegates to `dice_string_parser::string_to_factor`.
    ///
    /// # Errors
    /// Returns the parser's [`DiceBuildingError`] if `input` cannot be parsed.
    pub fn from_string(input: &str) -> Result<Self, DiceBuildingError> {
        dice_string_parser::string_to_factor(input)
    }

    /// Consumes the builder and constructs the corresponding [`Dice`] via `Dice::from_builder`.
    ///
    /// With the `console_error_panic_hook` feature enabled, that hook is
    /// installed (once) before construction starts.
    pub fn build(self) -> Dice {
        #[cfg(feature = "console_error_panic_hook")]
        console_error_panic_hook::set_once();
        Dice::from_builder(self)
    }

    /// Parses `input` and builds the resulting [`Dice`] in one step.
    ///
    /// # Errors
    /// Propagates any [`DiceBuildingError`] from [`DiceBuilder::from_string`].
    pub fn build_from_string(input: &str) -> Result<Dice, DiceBuildingError> {
        Self::from_string(input).map(DiceBuilder::build)
    }

    /// Renders the builder tree back into a dice expression string.
    ///
    /// Children of `+`, `*`, `/` and `x` compounds are joined with that
    /// operator. `max` and `min` are written in function form, `max(a,b)`.
    ///
    /// NOTE(review): no parentheses are inserted around nested compounds, and
    /// a fair die whose minimum is not 1 renders as an empty string. The output
    /// therefore may not round-trip through [`DiceBuilder::from_string`].
    pub fn reconstruct_string(&self) -> String {
        match self {
            DiceBuilder::Constant(i) => i.to_string(),
            DiceBuilder::FairDie { min, max } => {
                if *min == 1 {
                    format!("d{max}")
                } else {
                    // Only dice starting at 1 have a textual form here.
                    String::new()
                }
            }
            DiceBuilder::SumCompound(v) => Self::join_factors(v, "+"),
            DiceBuilder::ProductCompound(v) => Self::join_factors(v, "*"),
            DiceBuilder::DivisionCompound(v) => Self::join_factors(v, "/"),
            DiceBuilder::SampleSumCompound(v) => Self::join_factors(v, "x"),
            DiceBuilder::MaxCompound(v) => format!("max({})", Self::join_factors(v, ",")),
            DiceBuilder::MinCompound(v) => format!("min({})", Self::join_factors(v, ",")),
        }
    }

    /// Renders each factor with its `Display` impl and joins them with `separator`.
    fn join_factors(factors: &[DiceBuilder], separator: &str) -> String {
        factors
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<String>>()
            .join(separator)
    }

    /// Computes the exact distribution of this builder as a map from outcome
    /// to probability.
    ///
    /// Compound variants first compute each child's distribution and then
    /// fold those distributions from left to right.
    ///
    /// # Panics
    /// - Panics if a `FairDie` has `max < min`.
    /// - Panics if a compound variant has no children.
    ///
    /// NOTE(review): sum and product use plain `i64` arithmetic, which
    /// panics on overflow in debug builds. Division goes through
    /// `rounded_div::i64`, whose zero-divisor behavior is not visible here — confirm.
    fn distribution_hashmap(&self) -> DistributionHashMap {
        match self {
            DiceBuilder::Constant(v) => {
                let mut m = DistributionHashMap::new();
                m.insert(*v, Prob::one());
                m
            }
            DiceBuilder::FairDie { min, max } => {
                let (min, max) = (*min, *max);
                assert!(max >= min);
                // Every face of a fair die is equally likely.
                let prob: Prob = Prob::new(1u64, (max - min + 1) as u64);
                (min..=max).map(|v| (v, prob.clone())).collect()
            }
            DiceBuilder::SampleSumCompound(children) => {
                sample_sum_convolute_hashmaps(&Self::child_hashmaps(children))
            }
            DiceBuilder::SumCompound(children) => {
                convolute_hashmaps(&Self::child_hashmaps(children), |a, b| a + b)
            }
            DiceBuilder::ProductCompound(children) => {
                convolute_hashmaps(&Self::child_hashmaps(children), |a, b| a * b)
            }
            DiceBuilder::DivisionCompound(children) => {
                convolute_hashmaps(&Self::child_hashmaps(children), rounded_div::i64)
            }
            DiceBuilder::MaxCompound(children) => {
                convolute_hashmaps(&Self::child_hashmaps(children), std::cmp::max)
            }
            DiceBuilder::MinCompound(children) => {
                convolute_hashmaps(&Self::child_hashmaps(children), std::cmp::min)
            }
        }
    }

    /// Computes the distribution of every child, preserving their order.
    fn child_hashmaps(children: &[DiceBuilder]) -> Vec<DistributionHashMap> {
        children
            .iter()
            .map(DiceBuilder::distribution_hashmap)
            .collect()
    }

    /// Returns the distribution as `(value, probability)` pairs in ascending
    /// order of value.
    pub fn distribution_iter(&self) -> Distribution {
        let mut distribution: Vec<(Value, Prob)> =
            self.distribution_hashmap().into_iter().collect();
        // HashMap keys are unique, so an unstable sort still yields a single
        // well-defined order.
        distribution.sort_unstable_by_key(|&(value, _)| value);
        Box::new(distribution.into_iter())
    }
}
impl Display for DiceBuilder {
    /// Writes the expression text produced by `reconstruct_string`.
    ///
    /// Formatter flags such as width and fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let expression = self.reconstruct_string();
        f.write_str(&expression)
    }
}
/// Folds `hashmaps` from left to right with `convolute_two_hashmaps`.
///
/// At each step, every pair of outcomes is combined with `operation`.
/// A single map is returned as a clone.
///
/// # Panics
/// Panics if `hashmaps` is empty.
fn convolute_hashmaps(
    hashmaps: &[DistributionHashMap],
    operation: fn(Value, Value) -> Value,
) -> DistributionHashMap {
    let (first, rest) = match hashmaps.split_first() {
        Some(split) => split,
        None => panic!("cannot convolute hashmaps from a zero element vector"),
    };
    rest.iter().fold(first.clone(), |convoluted, h| {
        convolute_two_hashmaps(&convoluted, h, operation)
    })
}
/// Combines two distributions, treating them as independent.
///
/// Every pair of outcomes `(v1, v2)` contributes probability `p1 * p2` to the
/// outcome `operation(v1, v2)`. Pairs that map to the same outcome have their
/// probabilities summed.
fn convolute_two_hashmaps(
    h1: &DistributionHashMap,
    h2: &DistributionHashMap,
    operation: fn(Value, Value) -> Value,
) -> DistributionHashMap {
    h1.iter()
        .flat_map(move |(v1, p1)| {
            h2.iter()
                .map(move |(v2, p2)| (operation(*v1, *v2), p1 * p2))
        })
        .fold(DistributionHashMap::new(), |mut combined, (value, prob)| {
            combined
                .entry(value)
                .and_modify(|acc| *acc += &prob)
                .or_insert(prob);
            combined
        })
}
/// Chains sample sums from left to right.
///
/// The first map supplies the sample count for the second map. The resulting
/// sum then supplies the count for the third map, and so on. A single map is
/// returned as a clone.
///
/// # Panics
/// Panics if `hashmaps` is empty.
fn sample_sum_convolute_hashmaps(hashmaps: &[DistributionHashMap]) -> DistributionHashMap {
    let (first, rest) = match hashmaps.split_first() {
        Some(split) => split,
        None => panic!("cannot convolute hashmaps from a zero element vector"),
    };
    rest.iter().fold(first.clone(), |count_factor, sample_factor| {
        sample_sum_convolute_two_hashmaps(&count_factor, sample_factor)
    })
}
fn sample_sum_convolute_two_hashmaps(
count_factor: &DistributionHashMap,
sample_factor: &DistributionHashMap,
) -> DistributionHashMap {
let mut total_hashmap = DistributionHashMap::new();
for (count, count_p) in count_factor.iter() {
let mut count_hashmap: DistributionHashMap = match count.cmp(&0) {
std::cmp::Ordering::Less => {
let count: usize = (-count) as usize;
let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
.take(count)
.cloned()
.collect();
convolute_hashmaps(&sample_vec, |a, b| a + b)
}
std::cmp::Ordering::Equal => {
let mut h = DistributionHashMap::new();
h.insert(0, Prob::new(1u64, 1u64));
h
}
std::cmp::Ordering::Greater => {
let count: usize = *count as usize;
let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
.take(count)
.cloned()
.collect();
convolute_hashmaps(&sample_vec, |a, b| a + b)
}
};
count_hashmap.iter_mut().for_each(|e| {
*e.1 *= count_p.clone();
});
merge_hashmaps(&mut total_hashmap, &count_hashmap);
}
total_hashmap
}
impl Mul for Box<DiceBuilder> {
    type Output = Box<DiceBuilder>;

    /// Combines two boxed builders into a two-factor `DiceBuilder::ProductCompound`.
    fn mul(self, rhs: Self) -> Self::Output {
        let (left, right) = (*self, *rhs);
        let product = DiceBuilder::ProductCompound(vec![left, right]);
        Box::new(product)
    }
}
impl Add for Box<DiceBuilder> {
    type Output = Box<DiceBuilder>;

    /// Combines two boxed builders into a two-term `DiceBuilder::SumCompound`.
    fn add(self, rhs: Self) -> Self::Output {
        let (left, right) = (*self, *rhs);
        let sum = DiceBuilder::SumCompound(vec![left, right]);
        Box::new(sum)
    }
}
/// Adds every probability in `second` into `first`, key by key.
///
/// Keys present in both maps have their probabilities summed. Keys found only
/// in `second` are copied into `first`. `second` is left unchanged.
pub fn merge_hashmaps(first: &mut DistributionHashMap, second: &DistributionHashMap) {
    for (value, prob) in second {
        // Entry API: a single hash lookup per key, whether present or not.
        first
            .entry(*value)
            .and_modify(|existing| *existing += prob)
            .or_insert_with(|| prob.clone());
    }
}