Skip to main content

roll/
lib.rs

1#![warn(clippy::pedantic)]
2//! A dice roller library for tabletop RPGs.
3//!
4//! Supports standard dice notation (e.g. `2d10+4`), advantage/disadvantage,
5//! keep-highest/lowest (`4d6kh3`), and Monte Carlo / exact probability estimation.
6//!
7//! # Examples
8//!
9//! ```
10//! use roll::{parse_expr, Modifier};
11//!
12//! let expr = parse_expr("2d10+4").unwrap();
13//! assert_eq!(expr.flat_bonus, 4);
14//! assert_eq!(expr.modifier, Modifier::None);
15//! assert_eq!(expr.groups.len(), 1);
16//! assert_eq!(expr.groups[0].count, 2);
17//! assert_eq!(expr.groups[0].sides, 10);
18//! ```
19
20use rand::Rng;
21use std::collections::BTreeMap;
22use std::fmt;
23
24// ── Error type ────────────────────────────────────────────────────────────────
25
/// Reasons a dice expression can fail to parse.
///
/// Returned by [`parse_expr`]. Variants carrying a `String` hold either the
/// offending piece of input or a short explanation; the `Display` impl echoes it
/// in single quotes after a fixed label.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The dice count (before the `d`) is not a positive integer.
    InvalidDiceCount(String),
    /// A numeric field of a dice term (typically the side count after the `d`)
    /// is not a positive integer.
    InvalidSides(String),
    /// A dice group was subtracted (e.g. `d20-2d6`), which is not supported.
    NegativeDiceGroup,
    /// The expression contained no dice group at all.
    NoDiceFound,
    /// A term could not be interpreted as either dice or an integer.
    InvalidToken(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Split each variant into a fixed label plus optional quoted detail.
        let (message, detail) = match self {
            Self::InvalidDiceCount(s) => ("invalid dice count", Some(s)),
            Self::InvalidSides(s) => ("invalid sides", Some(s)),
            Self::InvalidToken(s) => ("invalid token", Some(s)),
            Self::NegativeDiceGroup => ("negative dice groups are not supported", None),
            Self::NoDiceFound => ("no dice found in expression", None),
        };
        match detail {
            Some(text) => write!(f, "{message}: '{text}'"),
            None => f.write_str(message),
        }
    }
}

impl std::error::Error for ParseError {}
49
50// ── Core types ────────────────────────────────────────────────────────────────
51
/// Keep rule applied to a dice group after rolling.
///
/// Written as a `kh<N>` / `kl<N>` suffix in dice notation (e.g. `4d6kh3`).
/// `Keep::default()` is [`Keep::All`], i.e. no suffix.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Keep {
    /// Keep all dice (default; no suffix in dice notation).
    #[default]
    All,
    /// Keep only the N highest dice (`kh<N>`).
    Highest(u32),
    /// Keep only the N lowest dice (`kl<N>`).
    Lowest(u32),
}
62
/// Modifier applied to a dice roll (advantage, disadvantage, or none).
///
/// Written as an `adv` / `dis` prefix in dice notation (e.g. `adv d20+5`).
/// `Modifier::default()` is [`Modifier::None`], i.e. no prefix.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Modifier {
    /// Roll the expression once (default; no prefix).
    #[default]
    None,
    /// Roll the whole expression twice and keep the higher total (`adv`).
    Advantage,
    /// Roll the whole expression twice and keep the lower total (`dis`).
    Disadvantage,
}
70
71/// A group of identical dice, e.g. `2d10` means 2 ten-sided dice.
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct DiceGroup {
74    pub count: u32,
75    pub sides: u32,
76    pub keep: Keep,
77}
78
79/// A parsed dice expression such as `adv 2d10+1d4+3`.
80#[derive(Debug, Clone, PartialEq, Eq)]
81pub struct DiceExpr {
82    pub modifier: Modifier,
83    pub groups: Vec<DiceGroup>,
84    pub flat_bonus: i64,
85}
86
87impl fmt::Display for DiceExpr {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self.modifier {
90            Modifier::Advantage => write!(f, "adv ")?,
91            Modifier::Disadvantage => write!(f, "dis ")?,
92            Modifier::None => {}
93        }
94        for (i, g) in self.groups.iter().enumerate() {
95            if i > 0 {
96                write!(f, "+")?;
97            }
98            write!(f, "{}d{}", g.count, g.sides)?;
99            match g.keep {
100                Keep::All => {}
101                Keep::Highest(n) => write!(f, "kh{n}")?,
102                Keep::Lowest(n) => write!(f, "kl{n}")?,
103            }
104        }
105        if self.flat_bonus > 0 {
106            write!(f, "+{}", self.flat_bonus)?;
107        } else if self.flat_bonus < 0 {
108            write!(f, "{}", self.flat_bonus)?;
109        }
110        Ok(())
111    }
112}
113
/// Theoretical statistics for a dice expression (min, max, mean).
///
/// Produced by [`roll_stats`]; computed analytically and does not account for
/// advantage/disadvantage.
#[derive(Debug, Clone, PartialEq)]
pub struct RollStats {
    /// Minimum total, including the flat bonus.
    pub min: i64,
    /// Maximum total, including the flat bonus.
    pub max: i64,
    /// Mean total, including the flat bonus (approximate for keep groups).
    pub mean: f64,
}
123
124// ── Parsing ───────────────────────────────────────────────────────────────────
125
/// Split an expression string into sign-annotated tokens on `+`/`-` boundaries.
///
/// `"2d10 + 1d4 - 3"` → `[(1, "2d10"), (1, "1d4"), (-1, "3")]`
///
/// An operator that is not preceded by a term acts as a unary sign and is
/// combined with the pending sign. So `"d20+-3"` and `"d20-+3"` both yield
/// `(-1, "3")`, and `"d20--3"` yields `(1, "3")`. A leading sign applies to the
/// first term: `"-3+d20"` → `[(-1, "3"), (1, "d20")]`.
///
/// A trailing operator with no term after it (e.g. `"d20+"`) produces a final
/// empty token, so the caller rejects it instead of silently ignoring it.
/// Tokens are trimmed of surrounding whitespace. Input without any operator or
/// term (e.g. `""`) yields no tokens.
fn split_signed_tokens(s: &str) -> Vec<(i64, &str)> {
    let mut tokens = Vec::new();
    let mut sign: i64 = 1;
    let mut token_start = 0usize;
    // Once any operator has been seen, an empty final token means the
    // expression ended on a dangling operator.
    let mut saw_operator = false;

    for (i, ch) in s.char_indices() {
        if ch == '+' || ch == '-' {
            let tok = s[token_start..i].trim();
            if tok.is_empty() {
                // No term since the previous operator (or start of input):
                // treat this operator as a unary sign on the upcoming term.
                if ch == '-' {
                    sign = -sign;
                }
            } else {
                tokens.push((sign, tok));
                sign = if ch == '-' { -1 } else { 1 };
            }
            token_start = i + ch.len_utf8();
            saw_operator = true;
        }
    }

    let tok = s[token_start..].trim();
    if !tok.is_empty() || saw_operator {
        tokens.push((sign, tok));
    }
    tokens
}
150
151/// Parse a single dice token (already lowercased) that may contain a `kh`/`kl` keep suffix.
152fn parse_dice_token(token: &str) -> Result<DiceGroup, ParseError> {
153    let (dice_part, keep) = if let Some(pos) = token.find("kh") {
154        let n: u32 = token[pos + 2..]
155            .parse()
156            .map_err(|_| ParseError::InvalidSides(token[pos + 2..].to_string()))?;
157        (&token[..pos], Keep::Highest(n))
158    } else if let Some(pos) = token.find("kl") {
159        let n: u32 = token[pos + 2..]
160            .parse()
161            .map_err(|_| ParseError::InvalidSides(token[pos + 2..].to_string()))?;
162        (&token[..pos], Keep::Lowest(n))
163    } else {
164        (token, Keep::All)
165    };
166
167    let d_pos = dice_part
168        .find('d')
169        .ok_or_else(|| ParseError::InvalidToken(dice_part.to_string()))?;
170
171    let count_str = &dice_part[..d_pos];
172    let sides_str = &dice_part[d_pos + 1..];
173
174    let count: u32 = if count_str.is_empty() {
175        1
176    } else {
177        count_str
178            .parse()
179            .map_err(|_| ParseError::InvalidDiceCount(count_str.to_string()))?
180    };
181    if count == 0 {
182        return Err(ParseError::InvalidDiceCount(
183            "count must be at least 1".to_string(),
184        ));
185    }
186
187    let sides: u32 = sides_str
188        .parse()
189        .map_err(|_| ParseError::InvalidSides(sides_str.to_string()))?;
190    if sides == 0 {
191        return Err(ParseError::InvalidSides(
192            "sides must be at least 1".to_string(),
193        ));
194    }
195
196    Ok(DiceGroup { count, sides, keep })
197}
198
199/// Parse a dice expression string into a [`DiceExpr`].
200///
201/// Supports expressions like `"2d10+4"`, `"adv d20+5"`, `"dis d20-1"`,
202/// `"2d6+1d4+3"`, and `"4d6kh3"` (keep highest 3 of 4d6).
203///
204/// # Errors
205///
206/// Returns a [`ParseError`] if the expression is malformed.
207pub fn parse_expr(input: &str) -> Result<DiceExpr, ParseError> {
208    let input = input.trim().to_lowercase();
209    if input.is_empty() {
210        return Err(ParseError::NoDiceFound);
211    }
212
213    let (modifier, rest) = if let Some(r) = input.strip_prefix("adv") {
214        (Modifier::Advantage, r.trim_start())
215    } else if let Some(r) = input.strip_prefix("dis") {
216        (Modifier::Disadvantage, r.trim_start())
217    } else {
218        (Modifier::None, input.as_str())
219    };
220
221    let mut groups = Vec::new();
222    let mut flat_bonus: i64 = 0;
223
224    for (sign, token) in split_signed_tokens(rest) {
225        if token.contains('d') {
226            if sign == -1 {
227                return Err(ParseError::NegativeDiceGroup);
228            }
229            groups.push(parse_dice_token(token)?);
230        } else {
231            let val: i64 = token
232                .parse()
233                .map_err(|_| ParseError::InvalidToken(token.to_string()))?;
234            flat_bonus += sign * val;
235        }
236    }
237
238    if groups.is_empty() {
239        return Err(ParseError::NoDiceFound);
240    }
241
242    Ok(DiceExpr {
243        modifier,
244        groups,
245        flat_bonus,
246    })
247}
248
249// ── Rolling ───────────────────────────────────────────────────────────────────
250
251/// Roll the dice once, returning the total and the kept dice per group.
252///
253/// For groups with a [`Keep`] rule, only the kept dice are included in the inner
254/// `Vec`; the total already reflects the keep logic.
255#[must_use]
256pub fn roll_once(expr: &DiceExpr, rng: &mut impl Rng) -> (i64, Vec<Vec<u32>>) {
257    let mut total: i64 = expr.flat_bonus;
258    let mut all_rolls = Vec::new();
259
260    for g in &expr.groups {
261        let mut rolls: Vec<u32> = (0..g.count)
262            .map(|_| rng.random_range(1..=g.sides))
263            .collect();
264
265        let kept = match &g.keep {
266            Keep::All => {
267                total += rolls.iter().map(|&r| i64::from(r)).sum::<i64>();
268                rolls
269            }
270            Keep::Highest(n) => {
271                rolls.sort_unstable_by(|a, b| b.cmp(a));
272                let kept: Vec<u32> = rolls.iter().take(*n as usize).copied().collect();
273                total += kept.iter().map(|&r| i64::from(r)).sum::<i64>();
274                kept
275            }
276            Keep::Lowest(n) => {
277                rolls.sort_unstable();
278                let kept: Vec<u32> = rolls.iter().take(*n as usize).copied().collect();
279                total += kept.iter().map(|&r| i64::from(r)).sum::<i64>();
280                kept
281            }
282        };
283
284        all_rolls.push(kept);
285    }
286
287    (total, all_rolls)
288}
289
290/// Roll and return just the final value, applying advantage/disadvantage.
291#[must_use]
292pub fn roll_value(expr: &DiceExpr, rng: &mut impl Rng) -> i64 {
293    match expr.modifier {
294        Modifier::None => roll_once(expr, rng).0,
295        Modifier::Advantage => {
296            let a = roll_once(expr, rng).0;
297            let b = roll_once(expr, rng).0;
298            a.max(b)
299        }
300        Modifier::Disadvantage => {
301            let a = roll_once(expr, rng).0;
302            let b = roll_once(expr, rng).0;
303            a.min(b)
304        }
305    }
306}
307
308/// Roll with detailed output, returning the total and a human-readable breakdown.
309#[must_use]
310pub fn roll_verbose(expr: &DiceExpr, rng: &mut impl Rng) -> (i64, String) {
311    match expr.modifier {
312        Modifier::None => {
313            let (total, rolls) = roll_once(expr, rng);
314            (total, format_rolls(&rolls))
315        }
316        Modifier::Advantage | Modifier::Disadvantage => {
317            let (a, rolls_a) = roll_once(expr, rng);
318            let (b, rolls_b) = roll_once(expr, rng);
319            let total = if expr.modifier == Modifier::Advantage {
320                a.max(b)
321            } else {
322                a.min(b)
323            };
324            (
325                total,
326                format!("{} vs {}", format_rolls(&rolls_a), format_rolls(&rolls_b)),
327            )
328        }
329    }
330}
331
/// Format roll results as a human-readable string like `[3, 5] + [2]`.
///
/// Each inner slice becomes a bracketed, comma-separated list, and groups are
/// joined with `" + "`. An empty group renders as `[]`; an empty input renders
/// as an empty string.
#[must_use]
pub fn format_rolls(rolls: &[Vec<u32>]) -> String {
    let mut rendered = String::new();
    for (group_idx, group) in rolls.iter().enumerate() {
        if group_idx > 0 {
            rendered.push_str(" + ");
        }
        rendered.push('[');
        for (die_idx, die) in group.iter().enumerate() {
            if die_idx > 0 {
                rendered.push_str(", ");
            }
            rendered.push_str(&die.to_string());
        }
        rendered.push(']');
    }
    rendered
}
344
345// ── Statistics ────────────────────────────────────────────────────────────────
346
347/// Compute theoretical min, max, and mean for a [`DiceExpr`].
348///
349/// Ignores advantage/disadvantage (those require simulation to compute exactly).
350/// For keep groups, uses the kept count to compute bounds (not statistically
351/// exact for `kh`/`kl`, but gives useful ballpark figures).
352#[must_use]
353pub fn roll_stats(expr: &DiceExpr) -> RollStats {
354    let mut min = expr.flat_bonus;
355    let mut max = expr.flat_bonus;
356    let mut mean = expr.flat_bonus as f64;
357
358    for g in &expr.groups {
359        let keep_count = match g.keep {
360            Keep::All => g.count,
361            Keep::Highest(n) | Keep::Lowest(n) => n,
362        };
363        min += i64::from(keep_count);
364        max += i64::from(keep_count) * i64::from(g.sides);
365        mean += f64::from(g.sides + 1) / 2.0 * f64::from(keep_count);
366    }
367
368    RollStats { min, max, mean }
369}
370
371// ── Distribution ──────────────────────────────────────────────────────────────
372
373/// Run a Monte Carlo simulation and return the count of each result value.
374#[must_use]
375pub fn compute_distribution(expr: &DiceExpr, sims: u64, rng: &mut impl Rng) -> BTreeMap<i64, u64> {
376    let mut counts = BTreeMap::new();
377    for _ in 0..sims {
378        *counts.entry(roll_value(expr, rng)).or_insert(0) += 1;
379    }
380    counts
381}
382
383/// Render a probability distribution histogram as a string.
384#[must_use]
385#[allow(clippy::cast_precision_loss)]
386pub fn render_distribution(expr: &DiceExpr, counts: &BTreeMap<i64, u64>, sims: u64) -> String {
387    let mut out = format!("Distribution for {expr} ({sims} simulations):\n");
388
389    let (&min_val, &max_val) = match (counts.keys().next(), counts.keys().next_back()) {
390        (Some(lo), Some(hi)) => (lo, hi),
391        _ => return out,
392    };
393
394    let max_count = *counts.values().max().unwrap_or(&1);
395    let label_width = max_val.to_string().len().max(min_val.to_string().len());
396    const MAX_BAR: usize = 40;
397
398    for v in min_val..=max_val {
399        let count = counts.get(&v).copied().unwrap_or(0);
400        let pct = count as f64 / sims as f64 * 100.0;
401        let bar_len = if max_count > 0 {
402            (count as f64 / max_count as f64 * MAX_BAR as f64).round() as usize
403        } else {
404            0
405        };
406        let bar: String = "\u{2588}".repeat(bar_len);
407        out.push_str(&format!(
408            " {:>width$} | {:>5.1}% {}\n",
409            v,
410            pct,
411            bar,
412            width = label_width,
413        ));
414    }
415
416    out
417}
418
419// ── Probability ───────────────────────────────────────────────────────────────
420
421/// Compute the exact probability of rolling at least `target` via polynomial convolution.
422///
423/// Returns `None` when the expression is too complex for analytical computation
424/// (i.e. advantage/disadvantage is active, or any group uses a keep rule).
425#[must_use]
426#[allow(clippy::cast_precision_loss)]
427pub fn exact_probability(expr: &DiceExpr, target: i64) -> Option<f64> {
428    if expr.modifier != Modifier::None {
429        return None;
430    }
431    if expr.groups.iter().any(|g| g.keep != Keep::All) {
432        return None;
433    }
434
435    // Convolve uniform distributions for each individual die.
436    let mut dist: BTreeMap<i64, f64> = BTreeMap::new();
437    dist.insert(0, 1.0);
438
439    for g in &expr.groups {
440        let p = 1.0 / f64::from(g.sides);
441        for _ in 0..g.count {
442            let mut new_dist: BTreeMap<i64, f64> = BTreeMap::new();
443            for (&val, &prob) in &dist {
444                for face in 1..=g.sides {
445                    *new_dist.entry(val + i64::from(face)).or_insert(0.0) += prob * p;
446                }
447            }
448            dist = new_dist;
449        }
450    }
451
452    // P(dice_total + flat_bonus >= target)  ≡  P(dice_total >= target - flat_bonus)
453    let adjusted = target - expr.flat_bonus;
454    let prob: f64 = dist.range(adjusted..).map(|(_, &p)| p).sum();
455    Some(prob)
456}
457
458/// Estimate the probability of rolling at least `target` using Monte Carlo simulation.
459#[must_use]
460#[allow(clippy::cast_precision_loss)]
461pub fn estimate_probability(expr: &DiceExpr, target: i64, sims: u64, rng: &mut impl Rng) -> f64 {
462    let hits = (0..sims)
463        .filter(|_| roll_value(expr, rng) >= target)
464        .count();
465    hits as f64 / sims as f64
466}
467
468// ── Tests ─────────────────────────────────────────────────────────────────────
469
// Unit tests for parsing, display, rolling, statistics, and probability.
// Randomized tests draw from a fixed-seed `StdRng` so any failure is reproducible;
// Monte Carlo assertions compare against tolerances rather than exact values.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Deterministic RNG (seed 42) so every test run sees the same rolls.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // ---- parse_expr tests ----

    #[test]
    fn parse_simple_dice() {
        let expr = parse_expr("2d10").unwrap();
        assert_eq!(expr.modifier, Modifier::None);
        assert_eq!(expr.groups.len(), 1);
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.groups[0].sides, 10);
        assert_eq!(expr.flat_bonus, 0);
        assert_eq!(expr.groups[0].keep, Keep::All);
    }

    #[test]
    fn parse_single_die_shorthand() {
        let expr = parse_expr("d20").unwrap();
        assert_eq!(expr.groups[0].count, 1);
        assert_eq!(expr.groups[0].sides, 20);
    }

    #[test]
    fn parse_with_positive_bonus() {
        let expr = parse_expr("2d10+4").unwrap();
        assert_eq!(expr.flat_bonus, 4);
    }

    #[test]
    fn parse_with_negative_bonus() {
        let expr = parse_expr("d20-3").unwrap();
        assert_eq!(expr.flat_bonus, -3);
    }

    #[test]
    fn parse_advantage() {
        let expr = parse_expr("adv d20+5").unwrap();
        assert_eq!(expr.modifier, Modifier::Advantage);
        assert_eq!(expr.groups[0].count, 1);
        assert_eq!(expr.groups[0].sides, 20);
        assert_eq!(expr.flat_bonus, 5);
    }

    #[test]
    fn parse_disadvantage() {
        let expr = parse_expr("dis d20-1").unwrap();
        assert_eq!(expr.modifier, Modifier::Disadvantage);
        assert_eq!(expr.flat_bonus, -1);
    }

    #[test]
    fn parse_multiple_groups() {
        let expr = parse_expr("2d6+1d4+3").unwrap();
        assert_eq!(expr.groups.len(), 2);
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.groups[0].sides, 6);
        assert_eq!(expr.groups[1].count, 1);
        assert_eq!(expr.groups[1].sides, 4);
        assert_eq!(expr.flat_bonus, 3);
    }

    #[test]
    fn parse_case_insensitive() {
        let expr = parse_expr("ADV D20+5").unwrap();
        assert_eq!(expr.modifier, Modifier::Advantage);
    }

    #[test]
    fn parse_with_whitespace() {
        let expr = parse_expr("  2d10 + 4  ").unwrap();
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.flat_bonus, 4);
    }

    #[test]
    fn parse_no_dice_error() {
        assert!(parse_expr("42").is_err());
    }

    #[test]
    fn parse_negative_dice_group_error() {
        assert!(parse_expr("d20-2d6").is_err());
    }

    #[test]
    fn parse_invalid_sides_error() {
        assert!(parse_expr("2dx").is_err());
    }

    #[test]
    fn parse_empty_error() {
        assert!(parse_expr("").is_err());
    }

    #[test]
    fn parse_zero_sides_error() {
        assert_eq!(
            parse_expr("2d0"),
            Err(ParseError::InvalidSides(
                "sides must be at least 1".to_string()
            ))
        );
    }

    #[test]
    fn parse_zero_count_error() {
        assert_eq!(
            parse_expr("0d6"),
            Err(ParseError::InvalidDiceCount(
                "count must be at least 1".to_string()
            ))
        );
    }

    #[test]
    fn parse_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert_eq!(expr.groups[0].count, 4);
        assert_eq!(expr.groups[0].sides, 6);
        assert_eq!(expr.groups[0].keep, Keep::Highest(3));
    }

    #[test]
    fn parse_keep_lowest() {
        let expr = parse_expr("4d6kl1").unwrap();
        assert_eq!(expr.groups[0].keep, Keep::Lowest(1));
    }

    #[test]
    fn parse_keep_with_bonus() {
        let expr = parse_expr("4d6kh3+2").unwrap();
        assert_eq!(expr.groups[0].keep, Keep::Highest(3));
        assert_eq!(expr.flat_bonus, 2);
    }

    // ---- Display tests ----

    #[test]
    fn display_simple() {
        let expr = parse_expr("2d10+4").unwrap();
        assert_eq!(expr.to_string(), "2d10+4");
    }

    #[test]
    fn display_advantage() {
        let expr = parse_expr("adv d20+5").unwrap();
        assert_eq!(expr.to_string(), "adv 1d20+5");
    }

    #[test]
    fn display_negative_bonus() {
        let expr = parse_expr("d20-3").unwrap();
        assert_eq!(expr.to_string(), "1d20-3");
    }

    #[test]
    fn display_no_bonus() {
        let expr = parse_expr("d20").unwrap();
        assert_eq!(expr.to_string(), "1d20");
    }

    #[test]
    fn display_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert_eq!(expr.to_string(), "4d6kh3");
    }

    // ---- Rolling tests ----

    #[test]
    fn roll_once_within_bounds() {
        let expr = parse_expr("2d6").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            assert!(total >= 2 && total <= 12);
            assert_eq!(rolls.len(), 1);
            assert_eq!(rolls[0].len(), 2);
            for &r in &rolls[0] {
                assert!(r >= 1 && r <= 6);
            }
        }
    }

    #[test]
    fn roll_once_applies_flat_bonus() {
        let expr = parse_expr("1d6+10").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, _) = roll_once(&expr, &mut rng);
            assert!(total >= 11 && total <= 16);
        }
    }

    #[test]
    fn roll_once_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            // Only 3 dice kept
            assert_eq!(rolls[0].len(), 3);
            // Kept dice are sorted descending
            assert!(rolls[0].windows(2).all(|w| w[0] >= w[1]));
            // Total equals sum of kept dice
            let sum: i64 = rolls[0].iter().map(|&r| i64::from(r)).sum();
            assert_eq!(total, sum);
            // Each die is within range
            assert!(total >= 3 && total <= 18);
        }
    }

    #[test]
    fn roll_once_keep_lowest() {
        let expr = parse_expr("4d6kl1").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            assert_eq!(rolls[0].len(), 1);
            assert!(total >= 1 && total <= 6);
        }
    }

    #[test]
    fn roll_value_deterministic_with_seed() {
        let expr = parse_expr("d20").unwrap();
        let mut rng1 = seeded_rng();
        let mut rng2 = seeded_rng();
        let a = roll_value(&expr, &mut rng1);
        let b = roll_value(&expr, &mut rng2);
        assert_eq!(a, b);
    }

    #[test]
    fn roll_value_advantage_takes_higher() {
        let expr = parse_expr("adv d20").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let adv = roll_value(&expr, &mut rng);
            assert!(adv >= 1 && adv <= 20);
        }
    }

    #[test]
    fn roll_value_disadvantage_takes_lower() {
        let expr = parse_expr("dis d20").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let dis = roll_value(&expr, &mut rng);
            assert!(dis >= 1 && dis <= 20);
        }
    }

    // Statistical check: over many rolls, advantage should sum strictly higher
    // than disadvantage.
    #[test]
    fn advantage_greater_equal_disadvantage() {
        let adv_expr = parse_expr("adv d20").unwrap();
        let dis_expr = parse_expr("dis d20").unwrap();
        let mut rng = seeded_rng();
        let mut adv_total: i64 = 0;
        let mut dis_total: i64 = 0;
        let n = 10_000;
        for _ in 0..n {
            adv_total += roll_value(&adv_expr, &mut rng);
            dis_total += roll_value(&dis_expr, &mut rng);
        }
        assert!(adv_total > dis_total);
    }

    // ---- roll_verbose tests ----

    #[test]
    fn roll_verbose_includes_rolls() {
        let expr = parse_expr("2d6").unwrap();
        let mut rng = seeded_rng();
        let (_, detail) = roll_verbose(&expr, &mut rng);
        assert!(detail.starts_with('['));
        assert!(detail.contains(']'));
    }

    #[test]
    fn roll_verbose_advantage_shows_vs() {
        let expr = parse_expr("adv d20").unwrap();
        let mut rng = seeded_rng();
        let (_, detail) = roll_verbose(&expr, &mut rng);
        assert!(detail.contains("vs"));
    }

    // ---- format_rolls tests ----

    #[test]
    fn format_rolls_single_group() {
        assert_eq!(format_rolls(&[vec![3, 5]]), "[3, 5]");
    }

    #[test]
    fn format_rolls_multiple_groups() {
        assert_eq!(format_rolls(&[vec![3, 5], vec![2]]), "[3, 5] + [2]");
    }

    // ---- roll_stats tests ----

    #[test]
    fn roll_stats_d6() {
        let expr = parse_expr("d6").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 1);
        assert_eq!(stats.max, 6);
        assert!((stats.mean - 3.5).abs() < f64::EPSILON);
    }

    #[test]
    fn roll_stats_with_bonus() {
        let expr = parse_expr("2d6+5").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 7);
        assert_eq!(stats.max, 17);
        assert!((stats.mean - 12.0).abs() < f64::EPSILON);
    }

    #[test]
    fn roll_stats_keep_highest() {
        // 4d6kh3: keep 3 dice
        let expr = parse_expr("4d6kh3").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 3);
        assert_eq!(stats.max, 18);
    }

    // ---- compute_distribution tests ----

    #[test]
    fn distribution_d6_has_all_values() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let counts = compute_distribution(&expr, 100_000, &mut rng);
        for v in 1..=6 {
            assert!(counts.contains_key(&v), "missing value {v}");
        }
        assert!(!counts.contains_key(&0));
        assert!(!counts.contains_key(&7));
    }

    #[test]
    fn distribution_counts_sum_to_sims() {
        let expr = parse_expr("2d6+3").unwrap();
        let mut rng = seeded_rng();
        let sims = 50_000;
        let counts = compute_distribution(&expr, sims, &mut rng);
        let total: u64 = counts.values().sum();
        assert_eq!(total, sims);
    }

    // ---- render_distribution tests ----

    #[test]
    fn render_distribution_contains_all_values() {
        let expr = parse_expr("d6").unwrap();
        let mut counts = BTreeMap::new();
        for v in 1..=6 {
            counts.insert(v, 1000);
        }
        let output = render_distribution(&expr, &counts, 6000);
        assert!(output.starts_with("Distribution for"));
        for v in 1..=6 {
            assert!(output.contains(&format!("{v} |")));
        }
    }

    #[test]
    fn render_distribution_percentages() {
        let expr = parse_expr("d6").unwrap();
        let mut counts = BTreeMap::new();
        counts.insert(1, 500);
        counts.insert(2, 500);
        let output = render_distribution(&expr, &counts, 1000);
        assert!(output.contains("50.0%"));
    }

    // ---- exact_probability tests ----

    #[test]
    fn exact_probability_d6_at_least_1_is_100_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 1).unwrap();
        assert!((p - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn exact_probability_d6_at_least_7_is_0_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!(p.abs() < f64::EPSILON);
    }

    #[test]
    fn exact_probability_d6_at_least_4_is_50_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 4).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_returns_none_for_advantage() {
        let expr = parse_expr("adv d20").unwrap();
        assert!(exact_probability(&expr, 15).is_none());
    }

    #[test]
    fn exact_probability_returns_none_for_keep() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert!(exact_probability(&expr, 10).is_none());
    }

    #[test]
    fn exact_probability_2d6_known_value() {
        // P(2d6 >= 7) = 21/36 = 7/12
        let expr = parse_expr("2d6").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!((p - 7.0 / 12.0).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_with_flat_bonus() {
        // P(d6+3 >= 7) = P(d6 >= 4) = 3/6 = 0.5
        let expr = parse_expr("d6+3").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_sums_to_one() {
        // Sum of P(2d6 >= k) - P(2d6 >= k+1) across all outcomes should equal 1.
        // Equivalently, P(2d6 >= 2) should be 1.0.
        let expr = parse_expr("2d6").unwrap();
        let p = exact_probability(&expr, 2).unwrap();
        assert!((p - 1.0).abs() < 1e-10);
    }

    // ---- ParseError display tests ----

    #[test]
    fn parse_error_display_no_dice() {
        assert_eq!(
            ParseError::NoDiceFound.to_string(),
            "no dice found in expression"
        );
    }

    #[test]
    fn parse_error_display_negative_group() {
        assert_eq!(
            ParseError::NegativeDiceGroup.to_string(),
            "negative dice groups are not supported"
        );
    }

    #[test]
    fn parse_error_display_invalid_token() {
        assert_eq!(
            ParseError::InvalidToken("foo".to_string()).to_string(),
            "invalid token: 'foo'"
        );
    }

    #[test]
    fn parse_error_display_invalid_sides() {
        assert_eq!(
            ParseError::InvalidSides("sides must be at least 1".to_string()).to_string(),
            "invalid sides: 'sides must be at least 1'"
        );
    }

    #[test]
    fn parse_error_display_invalid_count() {
        assert_eq!(
            ParseError::InvalidDiceCount("count must be at least 1".to_string()).to_string(),
            "invalid dice count: 'count must be at least 1'"
        );
    }

    // ---- roll_verbose with keep tests ----

    #[test]
    fn roll_verbose_keep_shows_kept_count() {
        // 4d6kh3 keeps 3 dice; the detail should contain exactly 3 numbers in brackets
        let expr = parse_expr("4d6kh3").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..20 {
            let (_, detail) = roll_verbose(&expr, &mut rng);
            // Detail looks like "[a, b, c]"; split on ',' to count dice
            let inner = detail.trim_start_matches('[').trim_end_matches(']');
            assert_eq!(
                inner.split(',').count(),
                3,
                "expected 3 kept dice, got: {detail}"
            );
        }
    }

    // ---- estimate_probability tests ----

    #[test]
    fn probability_d6_at_least_1_is_100_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 1, 10_000, &mut rng);
        assert!((p - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn probability_d6_at_least_7_is_0_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 7, 10_000, &mut rng);
        assert!(p.abs() < f64::EPSILON);
    }

    // Monte Carlo estimate: 100k samples, ±0.02 tolerance around the exact 0.5.
    #[test]
    fn probability_d6_at_least_4_roughly_50_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 4, 100_000, &mut rng);
        assert!((p - 0.5).abs() < 0.02);
    }
}