// dices/dice_builder.rs

1use fraction::One;
2
3use super::{
4    dice::Dice,
5    dice_string_parser::{self, DiceBuildingError},
6};
7use core::panic;
8use std::{
9    collections::HashMap,
10    fmt::Display,
11    ops::{Add, Mul},
12};
13pub type Value = i64;
14pub type Prob = fraction::BigFraction;
15pub type AggrValue = fraction::BigFraction;
16type Distribution = Box<dyn Iterator<Item = (Value, Prob)>>;
17pub type DistributionHashMap = HashMap<Value, Prob>;
18
19/// A [`DiceBuilder`] tree-like data structure representing the components of a dice formula like `max(2d6+4,d20)`
20///
21/// The tree can be used to calculate a discrete probability distribution. This happens when the `build()` method is called and creates a [`Dice`].
22///
23/// # Examples
24/// ```
25/// use dices::DiceBuilder;
26/// use fraction::ToPrimitive;
27/// let dice_builder = DiceBuilder::from_string("2d6+4").unwrap();
28/// let dice = dice_builder.build();
29/// let mean = dice.mean.to_f64().unwrap();
30/// assert_eq!(mean, 11.0);
31/// ```
32#[derive(Debug, PartialEq, Eq)]
33pub enum DiceBuilder {
34    /// A constant value (i64) that does not
35    Constant(Value),
36    /// A discrete uniform distribution over the integer interval `[min, max]`
37    FairDie {
38        /// minimum value of the die, inclusive
39        min: Value,
40        /// maximum value of the die, inclusive
41        max: Value,
42    },
43    /// the sum of multiple [DiceBuilder] instances, like: d6 + 3 + d20
44    SumCompound(Vec<DiceBuilder>),
45    /// the product of multiple [DiceBuilder] instances, like: d6 * 3 * d20
46    ProductCompound(Vec<DiceBuilder>),
47    /// the division of multiple [DiceBuilder] instances, left-associative, rounded up to integers like: d6 / 2 = d3
48    DivisionCompound(Vec<DiceBuilder>),
49    /// the maximum of multiple [DiceBuilder] instances, like: max(d6,3,d20)
50    MaxCompound(Vec<DiceBuilder>),
51    /// the minimum of multiple [DiceBuilder] instances, like: min(d6,3,d20)
52    MinCompound(Vec<DiceBuilder>),
53    /// SampleSumCompound(vec![a,b]) can be interpreted as follows:
54    /// A [`DiceBuilder`] `b` is sampled `a` times independently of each other.
55    /// It is represented by an x in input strings, e.g. "a x b"
56    /// The operator is left-associative, so a x b x c is (a x b) x c.
57    ///
58    /// # Examples
59    /// throwing 5 six-sided dice:
60    /// ```
61    /// use dices::DiceBuilder::*;
62    /// let five_six_sided_dice = SampleSumCompound(
63    ///     vec![Constant(5),FairDie{min: 1, max: 6}]
64    /// );
65    /// ```
66    ///
67    /// throwing 1, 2 or 3 (randomly determined) six-sided and summing them up:
68    /// ```
69    /// use dices::DiceBuilder::*;
70    /// let dice_1_2_or_3 = SampleSumCompound(
71    ///     vec![FairDie{min: 1, max: 3},FairDie{min: 1, max: 6}]
72    /// );
73    /// ```
74    ///
75    /// for two constants, it is the same as multiplication:
76    /// ```
77    /// use dices::DiceBuilder::*;
78    /// let b1 = SampleSumCompound(vec![Constant(2),Constant(3)]);
79    /// let b2 = ProductCompound(vec![Constant(2),Constant(3)]);
80    /// assert_eq!(b1.build().distribution, b2.build().distribution);
81    ///
82    /// ```
83    SampleSumCompound(Vec<DiceBuilder>),
84}
85
86impl DiceBuilder {
87    /// parses the string into a tree-like structure to create a [`DiceBuilder`]
88    ///
89    /// # Syntax Examples:
90    /// |-----|
91    /// |     |
92    /// 4 six-sided dice: "4d6"
93    ///
94    /// # Examples:
95    /// throwing 3 six-sided dice:
96    /// ```
97    /// use dices::DiceBuilder;
98    /// let builder = DiceBuilder::from_string("3d6");
99    /// let builder_2 = DiceBuilder::from_string("3 d6  ");
100    /// let builder_3 = DiceBuilder::from_string("3xd6"); // explicitly using sample sum
101    /// assert_eq!(builder, builder_2);
102    /// assert_eq!(builder_2, builder_3);
103    /// ```
104    ///
105    /// the minimum and maximum of multiple dice:
106    /// ```
107    /// use dices::DiceBuilder;
108    /// let min_builder = DiceBuilder::from_string("min(d6,d6)");
109    /// let max_builder = DiceBuilder::from_string("max(d6,d6,d20)");
110    /// ```
111    ///
112    pub fn from_string(input: &str) -> Result<Self, DiceBuildingError> {
113        dice_string_parser::string_to_factor(input)
114    }
115
116    /// builds a [`Dice`] from [`self`]
117    ///
118    /// this method calculates the distribution and all distribution paramters on the fly, to create the [`Dice`].
119    /// Depending on the complexity of the `dice_builder` heavy lifting like convoluting probability distributions may take place here.
120    pub fn build(self) -> Dice {
121        #[cfg(feature = "console_error_panic_hook")]
122        console_error_panic_hook::set_once();
123        Dice::from_builder(self)
124    }
125
126    /// shortcut for `DiceBuilder::from_string(input).build()`
127    pub fn build_from_string(input: &str) -> Result<Dice, DiceBuildingError> {
128        let builder = DiceBuilder::from_string(input)?;
129        Ok(builder.build())
130    }
131
132    /// constructs a string from the DiceBuilder that can be used to reconstruct an equivalent DiceBuilder from it.
133    ///
134    /// currently fails to construct a correct string in case dices with a non-1 minimum are present. This is because there is no string notation for dices with a non-1 minimum yet.
135    pub fn reconstruct_string(&self) -> String {
136        match self {
137            DiceBuilder::Constant(i) => i.to_string(),
138            DiceBuilder::FairDie { min, max } => match *min == 1 {
139                true => format!("d{max}"),
140                false => "".to_owned(), // this is currently a weak point where errors can occur
141            },
142            // ugly code right now, too much repetition:
143            DiceBuilder::SumCompound(v) => v
144                .iter()
145                .map(|f| f.to_string())
146                .collect::<Vec<String>>()
147                .join("+"),
148            DiceBuilder::ProductCompound(v) => v
149                .iter()
150                .map(|f| f.to_string())
151                .collect::<Vec<String>>()
152                .join("*"),
153            DiceBuilder::DivisionCompound(v) => v
154                .iter()
155                .map(|f| f.to_string())
156                .collect::<Vec<String>>()
157                .join("/"),
158            DiceBuilder::SampleSumCompound(v) => v
159                .iter()
160                .map(|f| f.to_string())
161                .collect::<Vec<String>>()
162                .join("x"),
163            DiceBuilder::MaxCompound(v) => format!(
164                "max({})",
165                v.iter()
166                    .map(|f| f.to_string())
167                    .collect::<Vec<String>>()
168                    .join(",")
169            ),
170            DiceBuilder::MinCompound(v) => format!(
171                "min({})",
172                v.iter()
173                    .map(|f| f.to_string())
174                    .collect::<Vec<String>>()
175                    .join(",")
176            ),
177        }
178    }
179
180    fn distribution_hashmap(&self) -> DistributionHashMap {
181        match self {
182            DiceBuilder::Constant(v) => {
183                let mut m = DistributionHashMap::new();
184                m.insert(*v, Prob::one());
185                m
186            }
187            DiceBuilder::FairDie { min, max } => {
188                assert!(max >= min);
189                let min: i64 = *min;
190                let max: i64 = *max;
191                let prob: Prob = Prob::new(1u64, (max - min + 1) as u64);
192                let mut m = DistributionHashMap::new();
193                for v in min..=max {
194                    m.insert(v, prob.clone());
195                }
196                m
197            }
198            DiceBuilder::SampleSumCompound(vec) => {
199                let hashmaps = vec
200                    .iter()
201                    .map(|e| e.distribution_hashmap())
202                    .collect::<Vec<DistributionHashMap>>();
203                sample_sum_convolute_hashmaps(&hashmaps)
204            }
205            DiceBuilder::SumCompound(vec)
206            | DiceBuilder::ProductCompound(vec)
207            | DiceBuilder::DivisionCompound(vec)
208            | DiceBuilder::MaxCompound(vec)
209            | DiceBuilder::MinCompound(vec) => {
210                let operation = match self {
211                    DiceBuilder::SumCompound(_) => |a, b| a + b,
212                    DiceBuilder::ProductCompound(_) => |a, b| a * b,
213                    DiceBuilder::MaxCompound(_) => std::cmp::max,
214                    DiceBuilder::MinCompound(_) => std::cmp::min,
215                    DiceBuilder::DivisionCompound(_) => rounded_div::i64,
216                    _ => panic!("unreachable by match"),
217                };
218                let hashmaps = vec
219                    .iter()
220                    .map(|e| e.distribution_hashmap())
221                    .collect::<Vec<DistributionHashMap>>();
222                convolute_hashmaps(&hashmaps, operation)
223            }
224        }
225    }
226
227    /// iterator for the probability mass function (pmf) of the [`DiceBuilder`], with tuples for each value with its probability in ascending order (regarding value)
228    ///
229    /// Calculates the distribution and all distribution paramters.
230    /// Depending on the complexity of [`self`] heavy lifting like convoluting probability distributions may take place here.
231    pub fn distribution_iter(&self) -> Distribution {
232        let mut distribution_vec = self
233            .distribution_hashmap()
234            .into_iter()
235            .collect::<Vec<(Value, Prob)>>();
236        distribution_vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
237        Box::new(distribution_vec.into_iter())
238    }
239}
240
241impl Display for DiceBuilder {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        write! {f, "{}", self.reconstruct_string()}
244    }
245}
246
247fn convolute_hashmaps(
248    hashmaps: &Vec<DistributionHashMap>,
249    operation: fn(Value, Value) -> Value,
250) -> DistributionHashMap {
251    if hashmaps.is_empty() {
252        panic!("cannot convolute hashmaps from a zero element vector");
253    }
254    let mut convoluted_h = hashmaps[0].clone();
255    for h in hashmaps.iter().skip(1) {
256        convoluted_h = convolute_two_hashmaps(&convoluted_h, h, operation);
257    }
258    convoluted_h
259}
260
261fn convolute_two_hashmaps(
262    h1: &DistributionHashMap,
263    h2: &DistributionHashMap,
264    operation: fn(Value, Value) -> Value,
265) -> DistributionHashMap {
266    let mut m = DistributionHashMap::new();
267    for (v1, p1) in h1.iter() {
268        for (v2, p2) in h2.iter() {
269            let v = operation(*v1, *v2);
270            let p = p1 * p2;
271            match m.entry(v) {
272                std::collections::hash_map::Entry::Occupied(mut e) => {
273                    *e.get_mut() += p;
274                }
275                std::collections::hash_map::Entry::Vacant(e) => {
276                    e.insert(p);
277                }
278            }
279        }
280    }
281    m
282}
283
284fn sample_sum_convolute_hashmaps(hashmaps: &Vec<DistributionHashMap>) -> DistributionHashMap {
285    if hashmaps.is_empty() {
286        panic!("cannot convolute hashmaps from a zero element vector");
287    }
288    let mut convoluted_h = hashmaps[0].clone();
289    for h in hashmaps.iter().skip(1) {
290        convoluted_h = sample_sum_convolute_two_hashmaps(&convoluted_h, h);
291    }
292    convoluted_h
293}
294
295fn sample_sum_convolute_two_hashmaps(
296    count_factor: &DistributionHashMap,
297    sample_factor: &DistributionHashMap,
298) -> DistributionHashMap {
299    let mut total_hashmap = DistributionHashMap::new();
300    for (count, count_p) in count_factor.iter() {
301        let mut count_hashmap: DistributionHashMap = match count.cmp(&0) {
302            std::cmp::Ordering::Less => {
303                let count: usize = (-count) as usize;
304                let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
305                    .take(count)
306                    .cloned()
307                    .collect();
308                convolute_hashmaps(&sample_vec, |a, b| a + b)
309            }
310            std::cmp::Ordering::Equal => {
311                let mut h = DistributionHashMap::new();
312                h.insert(0, Prob::new(1u64, 1u64));
313                h
314            }
315            std::cmp::Ordering::Greater => {
316                let count: usize = *count as usize;
317                let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
318                    .take(count)
319                    .cloned()
320                    .collect();
321                convolute_hashmaps(&sample_vec, |a, b| a + b)
322            }
323        };
324        count_hashmap.iter_mut().for_each(|e| {
325            *e.1 *= count_p.clone();
326        });
327        merge_hashmaps(&mut total_hashmap, &count_hashmap);
328    }
329    total_hashmap
330}
331
332impl Mul for Box<DiceBuilder> {
333    type Output = Box<DiceBuilder>;
334
335    fn mul(self, rhs: Self) -> Self::Output {
336        Box::new(DiceBuilder::ProductCompound(vec![*self, *rhs]))
337    }
338}
339
340impl Add for Box<DiceBuilder> {
341    type Output = Box<DiceBuilder>;
342
343    fn add(self, rhs: Self) -> Self::Output {
344        Box::new(DiceBuilder::SumCompound(vec![*self, *rhs]))
345    }
346}
347
348pub fn merge_hashmaps(first: &mut DistributionHashMap, second: &DistributionHashMap) {
349    for (k, v) in second.iter() {
350        match first.get_mut(k) {
351            Some(e) => {
352                *e += v;
353            }
354            None => {
355                first.insert(*k, v.clone());
356            }
357        }
358    }
359}