nice_dice/
discrete.rs

1//! Probability computation via discrete (integral) math and combinatorics.
2
3use crate::{
4    Error,
5    analysis::Closed,
6    symbolic::{ComparisonOp, Constant, Die, ExpressionTree, ExpressionWrapper, Ranker, Symbol},
7};
8use std::{collections::HashMap, ops::Neg};
9
10use itertools::Itertools;
11use num::{ToPrimitive, rational::Ratio};
12
/// The exact outcome table of a bounded dice expression.
/// ("Bounded" means that exploding dice are not supported.)
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Distribution {
    /// Integer weight of each outcome. Probabilities are never stored directly:
    /// the probability of an outcome is its weight divided by the sum of all weights.
    occurrence_by_value: Vec<usize>,
    /// The outcome represented by index 0: index `i` holds the weight of the value `i + offset`.
    offset: isize,
}
23
24/// An evaluator: evaluates distributions for a closed expression.
25///
26/// Evaluators provide memoization for sub-expressions.
27/// It may be useful to re-use an Evaluator when experimenting with new expressions,
28/// trading (some) memory for (some) processing.
29///
30// TODO: Benchmark with and without memoization.
31#[derive(Default)]
32pub struct Evaluator {
33    /// Memoization table.
34    memo: HashMap<Closed, Distribution>,
35    memoize: bool,
36}
37
38impl Evaluator {
39    /// Create a new Evaluator, with or without memoization enabled.
40    pub fn new(memoize: bool) -> Self {
41        Self {
42            memoize,
43            ..Default::default()
44        }
45    }
46
47    pub fn eval(&mut self, tree: &Closed) -> Result<Distribution, Error> {
48        if self.memoize {
49            if let Some(dist) = self.memo.get(tree) {
50                return Ok(dist.clone());
51            }
52        }
53        // We begin with native-stack recursion.
54
55        // Need to evaluate.
56        let memo = match tree.inner() {
57            ExpressionTree::Modifier(Constant(constant)) => Distribution::constant(*constant),
58            ExpressionTree::Die(Die(die)) => Distribution::die(*die),
59            ExpressionTree::Symbol(symbol) => {
60                panic!("unbound symbol {symbol} in closed expression")
61                // return Err(Error::UnboundSymbols([symbol].into()))
62            }
63            ExpressionTree::Negated(e) => {
64                let dist = self.eval(e.as_ref())?;
65                -dist
66            }
67            ExpressionTree::Repeated {
68                count,
69                value,
70                ranker,
71            } => self.repeat(tree, count, value, ranker)?,
72            ExpressionTree::Product(a, b) => self.product(a, b)?,
73            ExpressionTree::Floor(a, b) => self.floor(tree, a, b)?,
74            ExpressionTree::Sum(items) => {
75                let distrs: Result<Vec<_>, _> = items.iter().map(|e| self.eval(e)).collect();
76                let distrs = distrs?;
77                distrs.into_iter().sum()
78            }
79            ExpressionTree::Comparison { a, b, op } => self.comparison(a, b, *op)?,
80            ExpressionTree::Binding {
81                symbol,
82                value,
83                tail,
84            } => self.binding(symbol, value, tail)?,
85        };
86        if self.memoize {
87            self.memo.insert(tree.clone(), memo.clone());
88        }
89        Ok(memo)
90    }
91
92    fn product(&mut self, a: &Closed, b: &Closed) -> Result<Distribution, Error> {
93        let a = self.eval(a)?;
94        let b = self.eval(b)?;
95
96        let mut d = Distribution::empty();
97
98        for ((v1, f1), (v2, f2)) in a.occurrences().cartesian_product(b.occurrences()) {
99            d.add_occurrences(v1 * v2, f1 * f2);
100        }
101        Ok(d)
102    }
103
104    fn floor(&mut self, e: &Closed, a: &Closed, b: &Closed) -> Result<Distribution, Error> {
105        let a = self.eval(a)?;
106        let b = self.eval(b)?;
107
108        if *b.probability(0).numer() != 0 {
109            return Err(Error::DivideByZero(e.to_string()));
110        }
111
112        let mut d = Distribution::empty();
113        for ((v1, f1), (v2, f2)) in a.occurrences().cartesian_product(b.occurrences()) {
114            d.add_occurrences(v1 / v2, f1 * f2);
115        }
116        Ok(d)
117    }
118
119    fn repeat(
120        &mut self,
121        expression: &Closed,
122        count: &Closed,
123        value: &Closed,
124        ranker: &Ranker,
125    ) -> Result<Distribution, Error> {
126        let count_dist = self.eval(count)?;
127        let value_dist = self.eval(value)?;
128
129        let mut result = Distribution::empty();
130        if count_dist.min() < 0 {
131            return Err(Error::NegativeCount(expression.to_string()));
132        }
133        if (count_dist.min() as usize) < ranker.min_count() {
134            return Err(Error::KeepTooFew(
135                ranker.min_count(),
136                expression.to_string(),
137            ));
138        }
139
140        // We have to have the same type signature for each of these,
141        // and we want to truncate in the other cases.
142        #[allow(clippy::ptr_arg)]
143        fn keep_all(v: &mut [isize], _n: usize) -> &[isize] {
144            v
145        }
146        fn keep_highest(v: &mut [isize], n: usize) -> &[isize] {
147            v.sort_by(|v1, v2| v2.cmp(v1));
148            &v[..n]
149        }
150        fn keep_lowest(v: &mut [isize], n: usize) -> &[isize] {
151            v.sort();
152            &v[..n]
153        }
154        let filter = match ranker {
155            Ranker::All => keep_all,
156            Ranker::Highest(_) => keep_highest,
157            Ranker::Lowest(_) => keep_lowest,
158        };
159
160        for (count, count_frequency) in count_dist.occurrences() {
161            let keep_count = ranker.keep(count) as usize;
162            // Assuming this count happens this often...
163            let dice = std::iter::repeat(&value_dist)
164                .map(|d| d.occurrences())
165                .take(count as usize);
166            for value_set in dice.multi_cartesian_product() {
167                let (mut values, frequencies): (Vec<isize>, Vec<usize>) =
168                    value_set.into_iter().unzip();
169                // We have to compute the overall frquency including the dice we dropped;
170                // in other universes (other combinations), we'd keep them.
171                let occurrences = frequencies.into_iter().product::<usize>() * count_frequency;
172                let value = filter(&mut values, keep_count).iter().sum();
173                result.add_occurrences(value, occurrences);
174            }
175        }
176        Ok(result)
177    }
178
179    fn comparison(
180        &mut self,
181        a: &Closed,
182        b: &Closed,
183        op: ComparisonOp,
184    ) -> Result<Distribution, Error> {
185        let a = self.eval(a)?;
186        let b = self.eval(b)?;
187
188        let mut dist = Distribution::empty();
189
190        for ((v1, o1), (v2, o2)) in a.occurrences().cartesian_product(b.occurrences()) {
191            let occurrences = o1 * o2;
192            let value = op.compare(v1, v2) as isize;
193            dist.add_occurrences(value, occurrences);
194        }
195        Ok(dist)
196    }
197
198    fn binding(
199        &mut self,
200        symbol: &Symbol,
201        value: &Closed,
202        tail: &Closed,
203    ) -> Result<Distribution, Error> {
204        let value = self.eval(value)?;
205        let mut acc = Distribution::empty();
206        for (value, occ) in value.occurrences() {
207            let tree: Closed = tail.substitute(symbol, value);
208            let table = self.eval(&tree)?;
209            for (v2, o2) in table.occurrences() {
210                acc.add_occurrences(v2, occ * o2);
211            }
212        }
213        Ok(acc)
214    }
215}
216
217impl Distribution {
218    /// Generate a uniform distribution on the closed interval `[1, size]`;
219    /// i.e. the distribution for rolling a die with the given number of faces.
220    fn die(size: usize) -> Distribution {
221        let mut v = Vec::new();
222        v.resize(size, 1);
223        Distribution {
224            occurrence_by_value: v,
225            offset: 1,
226        }
227    }
228
229    /// Generate a "modifier" distribution, which has probability 1 of producing the given value.
230    fn constant(value: usize) -> Distribution {
231        Distribution {
232            occurrence_by_value: vec![1],
233            offset: value as isize,
234        }
235    }
236
237    /// Give the probability of this value occurring in this distribution.
238    pub fn probability(&self, value: isize) -> Ratio<usize> {
239        let index = value - self.offset;
240        if (0..(self.occurrence_by_value.len() as isize)).contains(&index) {
241            Ratio::new(self.occurrence_by_value[index as usize], self.total())
242        } else {
243            Ratio::new(0, 1)
244        }
245    }
246
247    pub fn probability_f64(&self, value: isize) -> f64 {
248        Ratio::to_f64(&self.probability(value)).expect("should convert probability to f64")
249    }
250
251    /// Report the total number of occurrences in this expression, i.e. the number of possible
252    /// rolls (rather than the number of distinct values).
253    pub fn total(&self) -> usize {
254        let v = self.occurrence_by_value.iter().sum();
255        debug_assert_ne!(v, 0);
256        v
257    }
258
259    /// Iterator over (value, occurrences) tuples in this distribution.
260    /// Reports values with nonzero occurrence in ascending order of value.
261    pub fn occurrences(&self) -> Occurrences {
262        Occurrences {
263            distribution: self,
264            current: self.offset,
265        }
266    }
267
268    /// The minimum value with nonzero occurrence in this distribution.
269    pub fn min(&self) -> isize {
270        self.offset
271    }
272
273    /// The minimum value with nonzero occurrence in this distribution (note: inclusive)
274    pub fn max(&self) -> isize {
275        self.offset + (self.occurrence_by_value.len() as isize) - 1
276    }
277
278    /// The average value (expected value) from this distribution.
279    pub fn mean(&self) -> f64 {
280        // This might be a hefty sum, so keep each term in the f64 range, and sum f64s.
281        (self.min()..=self.max())
282            .map(|v| (v as f64) * self.probability_f64(v))
283            .sum()
284    }
285
286    /// Clean up the distribution by removing extraneous zero-valued entries.
287    fn clean(&mut self) {
288        let leading_zeros = self
289            .occurrence_by_value
290            .iter()
291            .take_while(|&&f| f == 0)
292            .count();
293        if leading_zeros > 0 {
294            self.occurrence_by_value = self.occurrence_by_value[leading_zeros..].into();
295            self.offset += leading_zeros as isize;
296        }
297        let trailing_zeros = self
298            .occurrence_by_value
299            .iter()
300            .rev()
301            .take_while(|&&f| f == 0)
302            .count();
303        self.occurrence_by_value
304            .truncate(self.occurrence_by_value.len() - trailing_zeros);
305    }
306
307    /// Add the given occurrences to the values table.
308    fn add_occurrences(&mut self, value: isize, occurrences: usize) {
309        if value < self.offset {
310            let diff = (self.offset - value) as usize;
311            let new_len = self.occurrence_by_value.len() + diff;
312            self.occurrence_by_value.resize(new_len, 0);
313            // Swap "upwards", starting from the newly long end
314            for i in (diff..self.occurrence_by_value.len()).rev() {
315                self.occurrence_by_value.swap(i, i - diff);
316            }
317            self.offset = value;
318        }
319        let index = (value - self.offset) as usize;
320        if index >= self.occurrence_by_value.len() {
321            self.occurrence_by_value.resize(index + 1, 0);
322        }
323        self.occurrence_by_value[index] += occurrences;
324    }
325
326    fn empty() -> Self {
327        Self {
328            occurrence_by_value: vec![],
329            offset: 0,
330        }
331    }
332}
333
334/// An iterator over the occurrences in a distribution.
335///
336/// Implemented explicitly for its Clone implementation.
337#[derive(Debug, Clone)]
338pub struct Occurrences<'a> {
339    distribution: &'a Distribution,
340    current: isize,
341}
342
343impl Iterator for Occurrences<'_> {
344    type Item = (isize, usize);
345
346    fn next(&mut self) -> Option<Self::Item> {
347        loop {
348            let value = self.current;
349            let index = (value - self.distribution.offset) as usize;
350            if index < self.distribution.occurrence_by_value.len() {
351                self.current += 1;
352                let occ = self.distribution.occurrence_by_value[index];
353                if occ == 0 {
354                    continue;
355                } else {
356                    break Some((value, occ));
357                }
358            } else {
359                break None;
360            }
361        }
362    }
363}
364
365impl std::ops::Add<&Distribution> for &Distribution {
366    type Output = Distribution;
367
368    fn add(self, rhs: &Distribution) -> Self::Output {
369        let a = self;
370        let b = rhs;
371
372        let mut result = Distribution::empty();
373
374        for ((v1, o1), (v2, o2)) in a.occurrences().cartesian_product(b.occurrences()) {
375            let val = v1 + v2;
376            // aocc and bocc each represent the numerator of a fraction, aocc/atotal and
377            // bocc/btotal. That fraction is the probability that the given value will turn up
378            // on a roll.
379            //
380            // The events are independent, so we can combine the probabilities by summing them.
381            let occ = o1 * o2;
382            // This represents _only one way_ to get this value: this roll from A, this roll
383            // from B.
384            // Accumulate from different rolls:
385            result.add_occurrences(val, occ);
386        }
387
388        debug_assert_eq!(a.total() * b.total(), result.total(), "{result:?}");
389
390        result
391    }
392}
393
394impl std::ops::Add<Distribution> for Distribution {
395    type Output = Distribution;
396
397    fn add(self, rhs: Distribution) -> Self::Output {
398        (&self) + (&rhs)
399    }
400}
401
402impl Neg for &Distribution {
403    type Output = Distribution;
404
405    fn neg(self) -> Self::Output {
406        // The largest magnitude entry has
407        let magnitude = (self.occurrence_by_value.len() - 1) as isize + self.offset;
408        let occurrence_by_value = self.occurrence_by_value.iter().rev().copied().collect();
409        Distribution {
410            offset: -magnitude,
411            occurrence_by_value,
412        }
413    }
414}
415
416impl Neg for Distribution {
417    type Output = Distribution;
418
419    fn neg(self) -> Self::Output {
420        (&self).neg()
421    }
422}
423
424impl std::iter::Sum for Distribution {
425    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
426        iter.reduce(|a, b| a + b)
427            .unwrap_or_else(|| Distribution::constant(0))
428    }
429}
430
431impl Closed {
432    /// Retrieve the distribution for the expression.
433    pub fn distribution(&self) -> Result<Distribution, Error> {
434        let mut eval = Evaluator::default();
435        eval.eval(self)
436    }
437}
438
#[cfg(test)]
mod tests {
    use crate::parse::RawExpression;

    use super::*;

    /// Parse `s`, close it, and compute its distribution.
    /// Panics if parsing or closing fails; only evaluation errors are returned.
    fn distribution_of(s: &str) -> Result<Distribution, Error> {
        let raw = s.parse::<RawExpression>().unwrap();
        let closed: Closed = raw.try_into().expect("failed closure");
        closed.distribution()
    }

    #[test]
    fn no_div_zero() {
        let e = distribution_of("20 / (1d20 - 10)").unwrap_err();
        assert!(matches!(e, Error::DivideByZero(_)));
    }

    #[test]
    fn d20() {
        let d = distribution_of("d20").unwrap();

        for i in 1..=20isize {
            assert_eq!(d.probability(i), Ratio::new(1, 20));
        }

        for i in [-1, -2, -3, 0, 21, 22, 32] {
            assert_eq!(*d.probability(i).numer(), 0);
        }
    }

    #[test]
    fn d20_plus1() {
        let d = distribution_of("d20 + 1").unwrap();

        for i in 2..=21isize {
            assert_eq!(d.probability(i), Ratio::new(1, 20));
        }

        for i in [-1, -2, -3, 0, 1, 22, 23, 32] {
            assert_eq!(*d.probability(i).numer(), 0);
        }
    }

    #[test]
    fn two_d4() {
        let d = distribution_of("2d4").unwrap();

        for (v, p) in [(2, 1), (3, 2), (4, 3), (5, 4), (6, 3), (7, 2), (8, 1)] {
            assert_eq!(d.probability(v), Ratio::new(p, 16));
        }
    }

    #[test]
    fn advantage_disadvantage() {
        let a = distribution_of("2d20kh").unwrap();
        let b = distribution_of("1d20").unwrap();
        let c = distribution_of("2d20kl").unwrap();

        assert!(a.mean() > b.mean());
        assert!(b.mean() > c.mean());
    }

    #[test]
    fn stat_roll() {
        let stat = distribution_of("4d6kh3").unwrap();
        // Compare the magnitude of the error: without `abs`, any mean below the expected
        // value would pass.
        let diff = (stat.mean() - 12.25).abs();

        assert!(diff < 0.01, "{}", stat.mean());
    }

    #[test]
    fn require_positive_roll_count() {
        for expr in ["(1d3-2)d4", "(-1)d10"] {
            let e = distribution_of(expr).unwrap_err();
            assert!(matches!(e, Error::NegativeCount(_)));
        }
    }

    #[test]
    fn require_dice_to_keep() {
        for expr in ["2d4kh3", "(1d4)(4)kl2"] {
            let e = distribution_of(expr).unwrap_err();
            assert!(matches!(e, Error::KeepTooFew(..)));
        }
    }

    #[test]
    fn negative_modifier() {
        let d = distribution_of("1d4 + -1").unwrap();
        // 1d4 - 1 covers 0 through 3 inclusive.
        for i in 0..=3isize {
            assert_eq!(d.probability(i), Ratio::new(1, 4));
        }
    }

    #[test]
    fn negative_die() {
        let d = -Distribution::die(4) + Distribution::constant(1);
        for i in -3..=0isize {
            assert_eq!(d.probability(i), Ratio::new(1, 4), "{d:?}");
        }
    }

    #[test]
    fn product() {
        let d = distribution_of("1d4 * 3").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(3, 1), (6, 1), (9, 1), (12, 1)])
    }

    #[test]
    fn never() {
        distribution_of("0d3").unwrap_err();
    }

    //#[test]
    //fn compare_constant() {
    //    let d = distribution_of("1d20").unwrap();
    //
    //    {
    //        // 10 > 1d20 : 9 times
    //        let d = ComparisonOp::Gt.compare(10, &d);
    //        assert_eq!(d, 9);
    //    }
    //    {
    //        let d = ComparisonOp::Ge.compare(10, &d);
    //        assert_eq!(d, 10);
    //    }
    //    {
    //        let d = ComparisonOp::Eq.compare(20, &d);
    //        assert_eq!(d, 1);
    //    }
    //    {
    //        let d = ComparisonOp::Eq.compare(21, &d);
    //        assert_eq!(d, 0);
    //    }
    //    {
    //        // 0 <= 1d20 always
    //        let d = ComparisonOp::Le.compare(00, &d);
    //        assert_eq!(d, 20);
    //    }
    //    {
    //        let d = ComparisonOp::Le.compare(18, &d);
    //        assert_eq!(d, 3);
    //    }
    //    {
    //        let d = ComparisonOp::Lt.compare(3, &d);
    //        assert_eq!(d, 17);
    //    }
    //}

    #[test]
    fn critical_slap() {
        let d = distribution_of(
            r#"
        [ATK: 1d20] (ATK >= 12) * 1 + (ATK = 20) * 1
    "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 11), (1, 8), (2, 1)])
    }

    #[test]
    fn critical_fail() {
        let d = distribution_of(
            r#"
        [ATK: 1d20] (ATK > 1) * (1 + (ATK = 20) * 1)
        "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 1), (1, 18), (2, 1)])
    }

    #[test]
    fn even_contest() {
        let d = distribution_of(
            r#"
            (1d20 = 1d20) * 2
        "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 380), (2, 20)])
    }

    #[test]
    fn break_even_contest() {
        let d = distribution_of(
            r#"
            (1d20 >= 1d20) * 2
        "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        // >= is slightly biased towards the aggressor, "meets or exceeds"
        assert_eq!(&ps, &vec![(0, 190), (2, 210)])
    }

    #[test]
    fn dagger() {
        let d = distribution_of("[ATK: 1d20] (ATK > 10) * 1d4").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        // In 10/20 cases, we pick the first branch.
        // In 10/20 cases, we pick the second branch.
        // In the second branch, we get each value 1/4 of the time.

        assert_eq!(&ps, &vec![(0, 40), (1, 10), (2, 10), (3, 10), (4, 10)])
    }

    #[test]
    fn floor_div() {
        let d = distribution_of("1d4 / 2").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 1), (1, 2), (2, 1)])
    }
}