1use crate::{
4 Error,
5 analysis::Closed,
6 symbolic::{ComparisonOp, Constant, Die, ExpressionTree, ExpressionWrapper, Ranker, Symbol},
7};
8use std::{collections::HashMap, ops::Neg};
9
10use itertools::Itertools;
11use num::{ToPrimitive, rational::Ratio};
12
/// Exact probability distribution over integer outcomes.
///
/// The distribution is stored as a dense table of occurrence counts:
/// `occurrence_by_value[i]` is the number of ways to observe the value
/// `offset + i`. A probability is such a count divided by the sum of all
/// counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Distribution {
    /// Number of ways to observe each value, starting at `offset`.
    occurrence_by_value: Vec<usize>,
    /// The value represented by index 0 of `occurrence_by_value`.
    offset: isize,
}
23
24#[derive(Default)]
32pub struct Evaluator {
33 memo: HashMap<Closed, Distribution>,
35 memoize: bool,
36}
37
38impl Evaluator {
39 pub fn new(memoize: bool) -> Self {
41 Self {
42 memoize,
43 ..Default::default()
44 }
45 }
46
47 pub fn eval(&mut self, tree: &Closed) -> Result<Distribution, Error> {
48 if self.memoize {
49 if let Some(dist) = self.memo.get(tree) {
50 return Ok(dist.clone());
51 }
52 }
53 let memo = match tree.inner() {
57 ExpressionTree::Modifier(Constant(constant)) => Distribution::constant(*constant),
58 ExpressionTree::Die(Die(die)) => Distribution::die(*die),
59 ExpressionTree::Symbol(symbol) => {
60 panic!("unbound symbol {symbol} in closed expression")
61 }
63 ExpressionTree::Negated(e) => {
64 let dist = self.eval(e.as_ref())?;
65 -dist
66 }
67 ExpressionTree::Repeated {
68 count,
69 value,
70 ranker,
71 } => self.repeat(tree, count, value, ranker)?,
72 ExpressionTree::Product(a, b) => self.product(a, b)?,
73 ExpressionTree::Floor(a, b) => self.floor(tree, a, b)?,
74 ExpressionTree::Sum(items) => {
75 let distrs: Result<Vec<_>, _> = items.iter().map(|e| self.eval(e)).collect();
76 let distrs = distrs?;
77 distrs.into_iter().sum()
78 }
79 ExpressionTree::Comparison { a, b, op } => self.comparison(a, b, *op)?,
80 ExpressionTree::Binding {
81 symbol,
82 value,
83 tail,
84 } => self.binding(symbol, value, tail)?,
85 };
86 if self.memoize {
87 self.memo.insert(tree.clone(), memo.clone());
88 }
89 Ok(memo)
90 }
91
92 fn product(&mut self, a: &Closed, b: &Closed) -> Result<Distribution, Error> {
93 let a = self.eval(a)?;
94 let b = self.eval(b)?;
95
96 let mut d = Distribution::empty();
97
98 for ((v1, f1), (v2, f2)) in a.occurrences().cartesian_product(b.occurrences()) {
99 d.add_occurrences(v1 * v2, f1 * f2);
100 }
101 Ok(d)
102 }
103
104 fn floor(&mut self, e: &Closed, a: &Closed, b: &Closed) -> Result<Distribution, Error> {
105 let a = self.eval(a)?;
106 let b = self.eval(b)?;
107
108 if *b.probability(0).numer() != 0 {
109 return Err(Error::DivideByZero(e.to_string()));
110 }
111
112 let mut d = Distribution::empty();
113 for ((v1, f1), (v2, f2)) in a.occurrences().cartesian_product(b.occurrences()) {
114 d.add_occurrences(v1 / v2, f1 * f2);
115 }
116 Ok(d)
117 }
118
119 fn repeat(
120 &mut self,
121 expression: &Closed,
122 count: &Closed,
123 value: &Closed,
124 ranker: &Ranker,
125 ) -> Result<Distribution, Error> {
126 let count_dist = self.eval(count)?;
127 let value_dist = self.eval(value)?;
128
129 let mut result = Distribution::empty();
130 if count_dist.min() < 0 {
131 return Err(Error::NegativeCount(expression.to_string()));
132 }
133 if (count_dist.min() as usize) < ranker.min_count() {
134 return Err(Error::KeepTooFew(
135 ranker.min_count(),
136 expression.to_string(),
137 ));
138 }
139
140 #[allow(clippy::ptr_arg)]
143 fn keep_all(v: &mut [isize], _n: usize) -> &[isize] {
144 v
145 }
146 fn keep_highest(v: &mut [isize], n: usize) -> &[isize] {
147 v.sort_by(|v1, v2| v2.cmp(v1));
148 &v[..n]
149 }
150 fn keep_lowest(v: &mut [isize], n: usize) -> &[isize] {
151 v.sort();
152 &v[..n]
153 }
154 let filter = match ranker {
155 Ranker::All => keep_all,
156 Ranker::Highest(_) => keep_highest,
157 Ranker::Lowest(_) => keep_lowest,
158 };
159
160 for (count, count_frequency) in count_dist.occurrences() {
161 let keep_count = ranker.keep(count) as usize;
162 let dice = std::iter::repeat(&value_dist)
164 .map(|d| d.occurrences())
165 .take(count as usize);
166 for value_set in dice.multi_cartesian_product() {
167 let (mut values, frequencies): (Vec<isize>, Vec<usize>) =
168 value_set.into_iter().unzip();
169 let occurrences = frequencies.into_iter().product::<usize>() * count_frequency;
172 let value = filter(&mut values, keep_count).iter().sum();
173 result.add_occurrences(value, occurrences);
174 }
175 }
176 Ok(result)
177 }
178
179 fn comparison(
180 &mut self,
181 a: &Closed,
182 b: &Closed,
183 op: ComparisonOp,
184 ) -> Result<Distribution, Error> {
185 let a = self.eval(a)?;
186 let b = self.eval(b)?;
187
188 let mut dist = Distribution::empty();
189
190 for ((v1, o1), (v2, o2)) in a.occurrences().cartesian_product(b.occurrences()) {
191 let occurrences = o1 * o2;
192 let value = op.compare(v1, v2) as isize;
193 dist.add_occurrences(value, occurrences);
194 }
195 Ok(dist)
196 }
197
198 fn binding(
199 &mut self,
200 symbol: &Symbol,
201 value: &Closed,
202 tail: &Closed,
203 ) -> Result<Distribution, Error> {
204 let value = self.eval(value)?;
205 let mut acc = Distribution::empty();
206 for (value, occ) in value.occurrences() {
207 let tree: Closed = tail.substitute(symbol, value);
208 let table = self.eval(&tree)?;
209 for (v2, o2) in table.occurrences() {
210 acc.add_occurrences(v2, occ * o2);
211 }
212 }
213 Ok(acc)
214 }
215}
216
217impl Distribution {
218 fn die(size: usize) -> Distribution {
221 let mut v = Vec::new();
222 v.resize(size, 1);
223 Distribution {
224 occurrence_by_value: v,
225 offset: 1,
226 }
227 }
228
229 fn constant(value: usize) -> Distribution {
231 Distribution {
232 occurrence_by_value: vec![1],
233 offset: value as isize,
234 }
235 }
236
237 pub fn probability(&self, value: isize) -> Ratio<usize> {
239 let index = value - self.offset;
240 if (0..(self.occurrence_by_value.len() as isize)).contains(&index) {
241 Ratio::new(self.occurrence_by_value[index as usize], self.total())
242 } else {
243 Ratio::new(0, 1)
244 }
245 }
246
247 pub fn probability_f64(&self, value: isize) -> f64 {
248 Ratio::to_f64(&self.probability(value)).expect("should convert probability to f64")
249 }
250
251 pub fn total(&self) -> usize {
254 let v = self.occurrence_by_value.iter().sum();
255 debug_assert_ne!(v, 0);
256 v
257 }
258
259 pub fn occurrences(&self) -> Occurrences {
262 Occurrences {
263 distribution: self,
264 current: self.offset,
265 }
266 }
267
268 pub fn min(&self) -> isize {
270 self.offset
271 }
272
273 pub fn max(&self) -> isize {
275 self.offset + (self.occurrence_by_value.len() as isize) - 1
276 }
277
278 pub fn mean(&self) -> f64 {
280 (self.min()..=self.max())
282 .map(|v| (v as f64) * self.probability_f64(v))
283 .sum()
284 }
285
286 fn clean(&mut self) {
288 let leading_zeros = self
289 .occurrence_by_value
290 .iter()
291 .take_while(|&&f| f == 0)
292 .count();
293 if leading_zeros > 0 {
294 self.occurrence_by_value = self.occurrence_by_value[leading_zeros..].into();
295 self.offset += leading_zeros as isize;
296 }
297 let trailing_zeros = self
298 .occurrence_by_value
299 .iter()
300 .rev()
301 .take_while(|&&f| f == 0)
302 .count();
303 self.occurrence_by_value
304 .truncate(self.occurrence_by_value.len() - trailing_zeros);
305 }
306
307 fn add_occurrences(&mut self, value: isize, occurrences: usize) {
309 if value < self.offset {
310 let diff = (self.offset - value) as usize;
311 let new_len = self.occurrence_by_value.len() + diff;
312 self.occurrence_by_value.resize(new_len, 0);
313 for i in (diff..self.occurrence_by_value.len()).rev() {
315 self.occurrence_by_value.swap(i, i - diff);
316 }
317 self.offset = value;
318 }
319 let index = (value - self.offset) as usize;
320 if index >= self.occurrence_by_value.len() {
321 self.occurrence_by_value.resize(index + 1, 0);
322 }
323 self.occurrence_by_value[index] += occurrences;
324 }
325
326 fn empty() -> Self {
327 Self {
328 occurrence_by_value: vec![],
329 offset: 0,
330 }
331 }
332}
333
334#[derive(Debug, Clone)]
338pub struct Occurrences<'a> {
339 distribution: &'a Distribution,
340 current: isize,
341}
342
343impl Iterator for Occurrences<'_> {
344 type Item = (isize, usize);
345
346 fn next(&mut self) -> Option<Self::Item> {
347 loop {
348 let value = self.current;
349 let index = (value - self.distribution.offset) as usize;
350 if index < self.distribution.occurrence_by_value.len() {
351 self.current += 1;
352 let occ = self.distribution.occurrence_by_value[index];
353 if occ == 0 {
354 continue;
355 } else {
356 break Some((value, occ));
357 }
358 } else {
359 break None;
360 }
361 }
362 }
363}
364
365impl std::ops::Add<&Distribution> for &Distribution {
366 type Output = Distribution;
367
368 fn add(self, rhs: &Distribution) -> Self::Output {
369 let a = self;
370 let b = rhs;
371
372 let mut result = Distribution::empty();
373
374 for ((v1, o1), (v2, o2)) in a.occurrences().cartesian_product(b.occurrences()) {
375 let val = v1 + v2;
376 let occ = o1 * o2;
382 result.add_occurrences(val, occ);
386 }
387
388 debug_assert_eq!(a.total() * b.total(), result.total(), "{result:?}");
389
390 result
391 }
392}
393
394impl std::ops::Add<Distribution> for Distribution {
395 type Output = Distribution;
396
397 fn add(self, rhs: Distribution) -> Self::Output {
398 (&self) + (&rhs)
399 }
400}
401
402impl Neg for &Distribution {
403 type Output = Distribution;
404
405 fn neg(self) -> Self::Output {
406 let magnitude = (self.occurrence_by_value.len() - 1) as isize + self.offset;
408 let occurrence_by_value = self.occurrence_by_value.iter().rev().copied().collect();
409 Distribution {
410 offset: -magnitude,
411 occurrence_by_value,
412 }
413 }
414}
415
416impl Neg for Distribution {
417 type Output = Distribution;
418
419 fn neg(self) -> Self::Output {
420 (&self).neg()
421 }
422}
423
424impl std::iter::Sum for Distribution {
425 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
426 iter.reduce(|a, b| a + b)
427 .unwrap_or_else(|| Distribution::constant(0))
428 }
429}
430
431impl Closed {
432 pub fn distribution(&self) -> Result<Distribution, Error> {
434 let mut eval = Evaluator::default();
435 eval.eval(self)
436 }
437}
438
#[cfg(test)]
mod tests {
    use crate::parse::RawExpression;

    use super::*;

    /// Parses `s`, closes it, and computes its distribution.
    fn distribution_of(s: &str) -> Result<Distribution, Error> {
        let raw = s.parse::<RawExpression>().unwrap();
        let closed: Closed = raw.try_into().expect("failed closure");
        closed.distribution()
    }

    #[test]
    fn no_div_zero() {
        let e = distribution_of("20 / (1d20 - 10)").unwrap_err();
        assert!(matches!(e, Error::DivideByZero(_)));
    }

    #[test]
    fn d20() {
        let d = distribution_of("d20").unwrap();

        for i in 1..=20isize {
            assert_eq!(d.probability(i), Ratio::new(1, 20));
        }

        for i in [-1, -2, -3, 0, 21, 22, 32] {
            assert_eq!(*d.probability(i).numer(), 0);
        }
    }

    #[test]
    fn d20_plus1() {
        let d = distribution_of("d20 + 1").unwrap();

        for i in 2..=21isize {
            assert_eq!(d.probability(i), Ratio::new(1, 20));
        }

        for i in [-1, -2, -3, 0, 1, 22, 23, 32] {
            assert_eq!(*d.probability(i).numer(), 0);
        }
    }

    #[test]
    fn two_d4() {
        let d = distribution_of("2d4").unwrap();

        for (v, p) in [(2, 1), (3, 2), (4, 3), (5, 4), (6, 3), (7, 2), (8, 1)] {
            assert_eq!(d.probability(v), Ratio::new(p, 16));
        }
    }

    #[test]
    fn advantage_disadvantage() {
        let a = distribution_of("2d20kh").unwrap();
        let b = distribution_of("1d20").unwrap();
        let c = distribution_of("2d20kl").unwrap();

        assert!(a.mean() > b.mean());
        assert!(b.mean() > c.mean());
    }

    #[test]
    fn stat_roll() {
        // The exact mean of 4d6 keep-highest-3 is 15869/1296 ≈ 12.2446. Compare
        // the absolute difference so that a too-low mean also fails.
        let stat = distribution_of("4d6kh3").unwrap();
        let diff = (stat.mean() - 12.25).abs();

        assert!(diff < 0.01, "{}", stat.mean());
    }

    #[test]
    fn require_positive_roll_count() {
        for expr in ["(1d3-2)d4", "(-1)d10"] {
            let e = distribution_of(expr).unwrap_err();
            assert!(matches!(e, Error::NegativeCount(_)));
        }
    }

    #[test]
    fn require_dice_to_keep() {
        for expr in ["2d4kh3", "(1d4)(4)kl2"] {
            let e = distribution_of(expr).unwrap_err();
            assert!(matches!(e, Error::KeepTooFew(..)));
        }
    }

    #[test]
    fn negative_modifier() {
        // 1d4 - 1 is uniform over 0..=3.
        let d = distribution_of("1d4 + -1").unwrap();
        for i in 0..=3isize {
            assert_eq!(d.probability(i), Ratio::new(1, 4));
        }
    }

    #[test]
    fn negative_die() {
        let d = -Distribution::die(4) + Distribution::constant(1);
        for i in -3..=0isize {
            assert_eq!(d.probability(i), Ratio::new(1, 4), "{d:?}");
        }
    }

    #[test]
    fn product() {
        let d = distribution_of("1d4 * 3").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(3, 1), (6, 1), (9, 1), (12, 1)])
    }

    #[test]
    fn never() {
        distribution_of("0d3").unwrap_err();
    }

    #[test]
    fn critical_slap() {
        let d = distribution_of(
            r#"
            [ATK: 1d20] (ATK >= 12) * 1 + (ATK = 20) * 1
            "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 11), (1, 8), (2, 1)])
    }

    #[test]
    fn critical_fail() {
        let d = distribution_of(
            r#"
            [ATK: 1d20] (ATK > 1) * (1 + (ATK = 20) * 1)
            "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 1), (1, 18), (2, 1)])
    }

    #[test]
    fn even_contest() {
        let d = distribution_of(
            r#"
            (1d20 = 1d20) * 2
            "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 380), (2, 20)])
    }

    #[test]
    fn break_even_contest() {
        let d = distribution_of(
            r#"
            (1d20 >= 1d20) * 2
            "#,
        )
        .unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 190), (2, 210)])
    }

    #[test]
    fn dagger() {
        let d = distribution_of("[ATK: 1d20] (ATK > 10) * 1d4").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 40), (1, 10), (2, 10), (3, 10), (4, 10)])
    }

    #[test]
    fn floor_div() {
        let d = distribution_of("1d4 / 2").unwrap();
        let ps: Vec<_> = d.occurrences().collect();
        assert_eq!(&ps, &vec![(0, 1), (1, 2), (2, 1)])
    }
}