Skip to main content

roll/
lib.rs

1#![warn(clippy::pedantic)]
2//! A dice roller library for tabletop RPGs.
3//!
4//! Supports standard dice notation (e.g. `2d10+4`), advantage/disadvantage,
5//! keep-highest/lowest (`4d6kh3`), and Monte Carlo / exact probability estimation.
6//!
7//! # Examples
8//!
9//! ```
10//! use roll::{parse_expr, Modifier};
11//!
12//! let expr = parse_expr("2d10+4").unwrap();
13//! assert_eq!(expr.flat_bonus, 4);
14//! assert_eq!(expr.modifier, Modifier::None);
15//! assert_eq!(expr.groups.len(), 1);
16//! assert_eq!(expr.groups[0].count, 2);
17//! assert_eq!(expr.groups[0].sides, 10);
18//! ```
19
20use rand::Rng;
21use std::collections::BTreeMap;
22use std::fmt;
23
24// ── Error type ────────────────────────────────────────────────────────────────
25
/// Error produced when a dice expression cannot be parsed.
///
/// Variants carrying a `String` hold the offending fragment (or a short
/// explanation); the `Display` impl echoes it inside single quotes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The dice count before `d` could not be parsed or was zero.
    InvalidDiceCount(String),
    /// The number of sides (or another numeric suffix) was invalid.
    InvalidSides(String),
    /// A dice group carried a negative sign, e.g. `d20-2d6`.
    NegativeDiceGroup,
    /// The expression was empty or contained no dice groups.
    NoDiceFound,
    /// A token was neither valid dice notation nor a valid integer bonus.
    InvalidToken(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::NegativeDiceGroup => f.write_str("negative dice groups are not supported"),
            Self::NoDiceFound => f.write_str("no dice found in expression"),
            // Messages that quote the offending detail.
            Self::InvalidDiceCount(detail) => write!(f, "invalid dice count: '{detail}'"),
            Self::InvalidSides(detail) => write!(f, "invalid sides: '{detail}'"),
            Self::InvalidToken(detail) => write!(f, "invalid token: '{detail}'"),
        }
    }
}

impl std::error::Error for ParseError {}
49
50// ── Core types ────────────────────────────────────────────────────────────────
51
/// Keep rule applied to a dice group after rolling.
///
/// This is a small value type (at most one `u32` payload), so it derives
/// `Copy` like `Modifier` does; callers can pass and match it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Keep {
    /// Keep all dice (default when no `kh`/`kl` suffix is given).
    All,
    /// Keep only the N highest dice (`kh` suffix, e.g. `4d6kh3`).
    Highest(u32),
    /// Keep only the N lowest dice (`kl` suffix, e.g. `4d6kl1`).
    Lowest(u32),
}
62
/// Roll modifier: roll once, or roll twice and keep the better (advantage)
/// or worse (disadvantage) total.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Modifier {
    /// Roll the expression a single time.
    None,
    /// Roll twice and keep the higher total (`adv` prefix).
    Advantage,
    /// Roll twice and keep the lower total (`dis` prefix).
    Disadvantage,
}
70
71/// A group of identical dice, e.g. `2d10` means 2 ten-sided dice.
72#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct DiceGroup {
74    pub count: u32,
75    pub sides: u32,
76    pub keep: Keep,
77}
78
79/// A parsed dice expression such as `adv 2d10+1d4+3`.
80#[derive(Debug, Clone, PartialEq, Eq)]
81pub struct DiceExpr {
82    pub modifier: Modifier,
83    pub groups: Vec<DiceGroup>,
84    pub flat_bonus: i64,
85}
86
87impl fmt::Display for DiceExpr {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self.modifier {
90            Modifier::Advantage => write!(f, "adv ")?,
91            Modifier::Disadvantage => write!(f, "dis ")?,
92            Modifier::None => {}
93        }
94        for (i, g) in self.groups.iter().enumerate() {
95            if i > 0 {
96                write!(f, "+")?;
97            }
98            write!(f, "{}d{}", g.count, g.sides)?;
99            match g.keep {
100                Keep::All => {}
101                Keep::Highest(n) => write!(f, "kh{n}")?,
102                Keep::Lowest(n) => write!(f, "kl{n}")?,
103            }
104        }
105        if self.flat_bonus > 0 {
106            write!(f, "+{}", self.flat_bonus)?;
107        } else if self.flat_bonus < 0 {
108            write!(f, "{}", self.flat_bonus)?;
109        }
110        Ok(())
111    }
112}
113
/// Analytic summary of a dice expression: lowest total, highest total, and
/// expected value.
///
/// These figures are computed without simulation and do not reflect
/// advantage/disadvantage.
#[derive(Debug, Clone, PartialEq)]
pub struct RollStats {
    /// Minimum total.
    pub min: i64,
    /// Maximum total.
    pub max: i64,
    /// Mean (expected) total.
    pub mean: f64,
}
123
124// ── Parsing ───────────────────────────────────────────────────────────────────
125
/// Split an expression string into sign-annotated tokens on `+`/`-` boundaries.
///
/// `"2d10 + 1d4 - 3"` → `[(1, "2d10"), (1, "1d4"), (-1, "3")]`
///
/// Consecutive operators combine like ordinary arithmetic signs, so
/// `"d20 - -3"` yields `(1, "3")` and `"d20 + -3"` or `"d20 - +3"` yield
/// `(-1, "3")`. Empty tokens (e.g. from a trailing operator) are dropped.
fn split_signed_tokens(s: &str) -> Vec<(i64, &str)> {
    let mut tokens = Vec::new();
    let mut sign: i64 = 1;
    let mut token_start = 0usize;

    for (i, ch) in s.char_indices() {
        if ch == '+' || ch == '-' {
            let tok = s[token_start..i].trim();
            if !tok.is_empty() {
                tokens.push((sign, tok));
                // A new token starts here; its sign begins as positive.
                sign = 1;
            }
            // Each '-' flips the pending sign and '+' leaves it alone. This
            // stops a doubled operator from silently discarding an earlier sign.
            if ch == '-' {
                sign = -sign;
            }
            token_start = i + ch.len_utf8();
        }
    }

    let tok = s[token_start..].trim();
    if !tok.is_empty() {
        tokens.push((sign, tok));
    }
    tokens
}
150
151/// Parse a single dice token (already lowercased) that may contain a `kh`/`kl` keep suffix.
152fn parse_dice_token(token: &str) -> Result<DiceGroup, ParseError> {
153    let (dice_part, keep) = if let Some(pos) = token.find("kh") {
154        let n: u32 = token[pos + 2..]
155            .parse()
156            .map_err(|_| ParseError::InvalidSides(token[pos + 2..].to_string()))?;
157        (&token[..pos], Keep::Highest(n))
158    } else if let Some(pos) = token.find("kl") {
159        let n: u32 = token[pos + 2..]
160            .parse()
161            .map_err(|_| ParseError::InvalidSides(token[pos + 2..].to_string()))?;
162        (&token[..pos], Keep::Lowest(n))
163    } else {
164        (token, Keep::All)
165    };
166
167    let d_pos = dice_part
168        .find('d')
169        .ok_or_else(|| ParseError::InvalidToken(dice_part.to_string()))?;
170
171    let count_str = &dice_part[..d_pos];
172    let sides_str = &dice_part[d_pos + 1..];
173
174    let count: u32 = if count_str.is_empty() {
175        1
176    } else {
177        count_str
178            .parse()
179            .map_err(|_| ParseError::InvalidDiceCount(count_str.to_string()))?
180    };
181    if count == 0 {
182        return Err(ParseError::InvalidDiceCount(
183            "count must be at least 1".to_string(),
184        ));
185    }
186
187    let sides: u32 = sides_str
188        .parse()
189        .map_err(|_| ParseError::InvalidSides(sides_str.to_string()))?;
190    if sides == 0 {
191        return Err(ParseError::InvalidSides(
192            "sides must be at least 1".to_string(),
193        ));
194    }
195
196    Ok(DiceGroup { count, sides, keep })
197}
198
199/// Parse a dice expression string into a [`DiceExpr`].
200///
201/// Supports expressions like `"2d10+4"`, `"adv d20+5"`, `"dis d20-1"`,
202/// `"2d6+1d4+3"`, and `"4d6kh3"` (keep highest 3 of 4d6).
203///
204/// # Errors
205///
206/// Returns a [`ParseError`] if the expression is malformed.
207pub fn parse_expr(input: &str) -> Result<DiceExpr, ParseError> {
208    let input = input.trim().to_lowercase();
209    if input.is_empty() {
210        return Err(ParseError::NoDiceFound);
211    }
212
213    let (modifier, rest) = if let Some(r) = input.strip_prefix("adv") {
214        (Modifier::Advantage, r.trim_start())
215    } else if let Some(r) = input.strip_prefix("dis") {
216        (Modifier::Disadvantage, r.trim_start())
217    } else {
218        (Modifier::None, input.as_str())
219    };
220
221    let mut groups = Vec::new();
222    let mut flat_bonus: i64 = 0;
223
224    for (sign, token) in split_signed_tokens(rest) {
225        if token.contains('d') {
226            if sign == -1 {
227                return Err(ParseError::NegativeDiceGroup);
228            }
229            groups.push(parse_dice_token(token)?);
230        } else {
231            let val: i64 = token
232                .parse()
233                .map_err(|_| ParseError::InvalidToken(token.to_string()))?;
234            flat_bonus += sign * val;
235        }
236    }
237
238    if groups.is_empty() {
239        return Err(ParseError::NoDiceFound);
240    }
241
242    Ok(DiceExpr {
243        modifier,
244        groups,
245        flat_bonus,
246    })
247}
248
249// ── Rolling ───────────────────────────────────────────────────────────────────
250
251/// Roll the dice once, returning the total and the kept dice per group.
252///
253/// For groups with a [`Keep`] rule, only the kept dice are included in the inner
254/// `Vec`; the total already reflects the keep logic.
255#[must_use]
256pub fn roll_once(expr: &DiceExpr, rng: &mut impl Rng) -> (i64, Vec<Vec<u32>>) {
257    let mut total: i64 = expr.flat_bonus;
258    let mut all_rolls = Vec::new();
259
260    for g in &expr.groups {
261        let mut rolls: Vec<u32> = (0..g.count).map(|_| rng.random_range(1..=g.sides)).collect();
262
263        let kept = match &g.keep {
264            Keep::All => {
265                total += rolls.iter().map(|&r| i64::from(r)).sum::<i64>();
266                rolls
267            }
268            Keep::Highest(n) => {
269                rolls.sort_unstable_by(|a, b| b.cmp(a));
270                let kept: Vec<u32> = rolls.iter().take(*n as usize).copied().collect();
271                total += kept.iter().map(|&r| i64::from(r)).sum::<i64>();
272                kept
273            }
274            Keep::Lowest(n) => {
275                rolls.sort_unstable();
276                let kept: Vec<u32> = rolls.iter().take(*n as usize).copied().collect();
277                total += kept.iter().map(|&r| i64::from(r)).sum::<i64>();
278                kept
279            }
280        };
281
282        all_rolls.push(kept);
283    }
284
285    (total, all_rolls)
286}
287
288/// Roll and return just the final value, applying advantage/disadvantage.
289#[must_use]
290pub fn roll_value(expr: &DiceExpr, rng: &mut impl Rng) -> i64 {
291    match expr.modifier {
292        Modifier::None => roll_once(expr, rng).0,
293        Modifier::Advantage => {
294            let a = roll_once(expr, rng).0;
295            let b = roll_once(expr, rng).0;
296            a.max(b)
297        }
298        Modifier::Disadvantage => {
299            let a = roll_once(expr, rng).0;
300            let b = roll_once(expr, rng).0;
301            a.min(b)
302        }
303    }
304}
305
306/// Roll with detailed output, returning the total and a human-readable breakdown.
307#[must_use]
308pub fn roll_verbose(expr: &DiceExpr, rng: &mut impl Rng) -> (i64, String) {
309    match expr.modifier {
310        Modifier::None => {
311            let (total, rolls) = roll_once(expr, rng);
312            (total, format_rolls(&rolls))
313        }
314        Modifier::Advantage | Modifier::Disadvantage => {
315            let (a, rolls_a) = roll_once(expr, rng);
316            let (b, rolls_b) = roll_once(expr, rng);
317            let total = if expr.modifier == Modifier::Advantage {
318                a.max(b)
319            } else {
320                a.min(b)
321            };
322            (
323                total,
324                format!("{} vs {}", format_rolls(&rolls_a), format_rolls(&rolls_b)),
325            )
326        }
327    }
328}
329
/// Render per-group dice as text, e.g. `[3, 5] + [2]`.
///
/// Each group becomes a bracketed, comma-separated list, and groups are
/// joined with ` + `. An empty slice yields an empty string, and an empty
/// group yields `[]`.
#[must_use]
pub fn format_rolls(rolls: &[Vec<u32>]) -> String {
    let mut text = String::new();
    for (group_idx, group) in rolls.iter().enumerate() {
        if group_idx > 0 {
            text.push_str(" + ");
        }
        text.push('[');
        for (die_idx, die) in group.iter().enumerate() {
            if die_idx > 0 {
                text.push_str(", ");
            }
            text.push_str(&die.to_string());
        }
        text.push(']');
    }
    text
}
342
343// ── Statistics ────────────────────────────────────────────────────────────────
344
345/// Compute theoretical min, max, and mean for a [`DiceExpr`].
346///
347/// Ignores advantage/disadvantage (those require simulation to compute exactly).
348/// For keep groups, uses the kept count to compute bounds (not statistically
349/// exact for `kh`/`kl`, but gives useful ballpark figures).
350#[must_use]
351pub fn roll_stats(expr: &DiceExpr) -> RollStats {
352    let mut min = expr.flat_bonus;
353    let mut max = expr.flat_bonus;
354    let mut mean = expr.flat_bonus as f64;
355
356    for g in &expr.groups {
357        let keep_count = match g.keep {
358            Keep::All => g.count,
359            Keep::Highest(n) | Keep::Lowest(n) => n,
360        };
361        min += i64::from(keep_count);
362        max += i64::from(keep_count) * i64::from(g.sides);
363        mean += f64::from(g.sides + 1) / 2.0 * f64::from(keep_count);
364    }
365
366    RollStats { min, max, mean }
367}
368
369// ── Distribution ──────────────────────────────────────────────────────────────
370
371/// Run a Monte Carlo simulation and return the count of each result value.
372#[must_use]
373pub fn compute_distribution(expr: &DiceExpr, sims: u64, rng: &mut impl Rng) -> BTreeMap<i64, u64> {
374    let mut counts = BTreeMap::new();
375    for _ in 0..sims {
376        *counts.entry(roll_value(expr, rng)).or_insert(0) += 1;
377    }
378    counts
379}
380
381/// Render a probability distribution histogram as a string.
382#[must_use]
383#[allow(clippy::cast_precision_loss)]
384pub fn render_distribution(expr: &DiceExpr, counts: &BTreeMap<i64, u64>, sims: u64) -> String {
385    let mut out = format!("Distribution for {expr} ({sims} simulations):\n");
386
387    let (&min_val, &max_val) = match (counts.keys().next(), counts.keys().next_back()) {
388        (Some(lo), Some(hi)) => (lo, hi),
389        _ => return out,
390    };
391
392    let max_count = *counts.values().max().unwrap_or(&1);
393    let label_width = max_val.to_string().len().max(min_val.to_string().len());
394    const MAX_BAR: usize = 40;
395
396    for v in min_val..=max_val {
397        let count = counts.get(&v).copied().unwrap_or(0);
398        let pct = count as f64 / sims as f64 * 100.0;
399        let bar_len = if max_count > 0 {
400            (count as f64 / max_count as f64 * MAX_BAR as f64).round() as usize
401        } else {
402            0
403        };
404        let bar: String = "\u{2588}".repeat(bar_len);
405        out.push_str(&format!(
406            " {:>width$} | {:>5.1}% {}\n",
407            v,
408            pct,
409            bar,
410            width = label_width,
411        ));
412    }
413
414    out
415}
416
417// ── Probability ───────────────────────────────────────────────────────────────
418
419/// Compute the exact probability of rolling at least `target` via polynomial convolution.
420///
421/// Returns `None` when the expression is too complex for analytical computation
422/// (i.e. advantage/disadvantage is active, or any group uses a keep rule).
423#[must_use]
424#[allow(clippy::cast_precision_loss)]
425pub fn exact_probability(expr: &DiceExpr, target: i64) -> Option<f64> {
426    if expr.modifier != Modifier::None {
427        return None;
428    }
429    if expr.groups.iter().any(|g| g.keep != Keep::All) {
430        return None;
431    }
432
433    // Convolve uniform distributions for each individual die.
434    let mut dist: BTreeMap<i64, f64> = BTreeMap::new();
435    dist.insert(0, 1.0);
436
437    for g in &expr.groups {
438        let p = 1.0 / f64::from(g.sides);
439        for _ in 0..g.count {
440            let mut new_dist: BTreeMap<i64, f64> = BTreeMap::new();
441            for (&val, &prob) in &dist {
442                for face in 1..=g.sides {
443                    *new_dist
444                        .entry(val + i64::from(face))
445                        .or_insert(0.0) += prob * p;
446                }
447            }
448            dist = new_dist;
449        }
450    }
451
452    // P(dice_total + flat_bonus >= target)  ≡  P(dice_total >= target - flat_bonus)
453    let adjusted = target - expr.flat_bonus;
454    let prob: f64 = dist.range(adjusted..).map(|(_, &p)| p).sum();
455    Some(prob)
456}
457
458/// Estimate the probability of rolling at least `target` using Monte Carlo simulation.
459#[must_use]
460#[allow(clippy::cast_precision_loss)]
461pub fn estimate_probability(expr: &DiceExpr, target: i64, sims: u64, rng: &mut impl Rng) -> f64 {
462    let hits = (0..sims)
463        .filter(|_| roll_value(expr, rng) >= target)
464        .count();
465    hits as f64 / sims as f64
466}
467
468// ── Tests ─────────────────────────────────────────────────────────────────────
469
// Unit tests. Rolling tests use a fixed seed, so they are deterministic.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// A fixed-seed RNG so every rolling test is reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // ---- parse_expr tests ----

    #[test]
    fn parse_simple_dice() {
        let expr = parse_expr("2d10").unwrap();
        assert_eq!(expr.modifier, Modifier::None);
        assert_eq!(expr.groups.len(), 1);
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.groups[0].sides, 10);
        assert_eq!(expr.flat_bonus, 0);
        assert_eq!(expr.groups[0].keep, Keep::All);
    }

    #[test]
    fn parse_single_die_shorthand() {
        let expr = parse_expr("d20").unwrap();
        assert_eq!(expr.groups[0].count, 1);
        assert_eq!(expr.groups[0].sides, 20);
    }

    #[test]
    fn parse_with_positive_bonus() {
        let expr = parse_expr("2d10+4").unwrap();
        assert_eq!(expr.flat_bonus, 4);
    }

    #[test]
    fn parse_with_negative_bonus() {
        let expr = parse_expr("d20-3").unwrap();
        assert_eq!(expr.flat_bonus, -3);
    }

    #[test]
    fn parse_advantage() {
        let expr = parse_expr("adv d20+5").unwrap();
        assert_eq!(expr.modifier, Modifier::Advantage);
        assert_eq!(expr.groups[0].count, 1);
        assert_eq!(expr.groups[0].sides, 20);
        assert_eq!(expr.flat_bonus, 5);
    }

    #[test]
    fn parse_disadvantage() {
        let expr = parse_expr("dis d20-1").unwrap();
        assert_eq!(expr.modifier, Modifier::Disadvantage);
        assert_eq!(expr.flat_bonus, -1);
    }

    #[test]
    fn parse_multiple_groups() {
        let expr = parse_expr("2d6+1d4+3").unwrap();
        assert_eq!(expr.groups.len(), 2);
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.groups[0].sides, 6);
        assert_eq!(expr.groups[1].count, 1);
        assert_eq!(expr.groups[1].sides, 4);
        assert_eq!(expr.flat_bonus, 3);
    }

    #[test]
    fn parse_case_insensitive() {
        let expr = parse_expr("ADV D20+5").unwrap();
        assert_eq!(expr.modifier, Modifier::Advantage);
    }

    #[test]
    fn parse_with_whitespace() {
        let expr = parse_expr("  2d10 + 4  ").unwrap();
        assert_eq!(expr.groups[0].count, 2);
        assert_eq!(expr.flat_bonus, 4);
    }

    #[test]
    fn parse_no_dice_error() {
        assert!(parse_expr("42").is_err());
    }

    #[test]
    fn parse_negative_dice_group_error() {
        assert!(parse_expr("d20-2d6").is_err());
    }

    #[test]
    fn parse_invalid_sides_error() {
        assert!(parse_expr("2dx").is_err());
    }

    #[test]
    fn parse_empty_error() {
        assert!(parse_expr("").is_err());
    }

    #[test]
    fn parse_zero_sides_error() {
        assert_eq!(
            parse_expr("2d0"),
            Err(ParseError::InvalidSides(
                "sides must be at least 1".to_string()
            ))
        );
    }

    #[test]
    fn parse_zero_count_error() {
        assert_eq!(
            parse_expr("0d6"),
            Err(ParseError::InvalidDiceCount(
                "count must be at least 1".to_string()
            ))
        );
    }

    #[test]
    fn parse_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert_eq!(expr.groups[0].count, 4);
        assert_eq!(expr.groups[0].sides, 6);
        assert_eq!(expr.groups[0].keep, Keep::Highest(3));
    }

    #[test]
    fn parse_keep_lowest() {
        let expr = parse_expr("4d6kl1").unwrap();
        assert_eq!(expr.groups[0].keep, Keep::Lowest(1));
    }

    #[test]
    fn parse_keep_with_bonus() {
        let expr = parse_expr("4d6kh3+2").unwrap();
        assert_eq!(expr.groups[0].keep, Keep::Highest(3));
        assert_eq!(expr.flat_bonus, 2);
    }

    // ---- Display tests ----

    #[test]
    fn display_simple() {
        let expr = parse_expr("2d10+4").unwrap();
        assert_eq!(expr.to_string(), "2d10+4");
    }

    #[test]
    fn display_advantage() {
        let expr = parse_expr("adv d20+5").unwrap();
        assert_eq!(expr.to_string(), "adv 1d20+5");
    }

    #[test]
    fn display_negative_bonus() {
        let expr = parse_expr("d20-3").unwrap();
        assert_eq!(expr.to_string(), "1d20-3");
    }

    #[test]
    fn display_no_bonus() {
        let expr = parse_expr("d20").unwrap();
        assert_eq!(expr.to_string(), "1d20");
    }

    #[test]
    fn display_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert_eq!(expr.to_string(), "4d6kh3");
    }

    // ---- Rolling tests ----

    #[test]
    fn roll_once_within_bounds() {
        let expr = parse_expr("2d6").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            assert!((2..=12).contains(&total));
            assert_eq!(rolls.len(), 1);
            assert_eq!(rolls[0].len(), 2);
            for &r in &rolls[0] {
                assert!((1..=6).contains(&r));
            }
        }
    }

    #[test]
    fn roll_once_applies_flat_bonus() {
        let expr = parse_expr("1d6+10").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, _) = roll_once(&expr, &mut rng);
            assert!((11..=16).contains(&total));
        }
    }

    #[test]
    fn roll_once_keep_highest() {
        let expr = parse_expr("4d6kh3").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            // Only 3 dice kept
            assert_eq!(rolls[0].len(), 3);
            // Kept dice are sorted descending
            assert!(rolls[0].windows(2).all(|w| w[0] >= w[1]));
            // Total equals sum of kept dice
            let sum: i64 = rolls[0].iter().map(|&r| i64::from(r)).sum();
            assert_eq!(total, sum);
            // Three kept d6 always total between 3 and 18
            assert!((3..=18).contains(&total));
        }
    }

    #[test]
    fn roll_once_keep_lowest() {
        let expr = parse_expr("4d6kl1").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let (total, rolls) = roll_once(&expr, &mut rng);
            assert_eq!(rolls[0].len(), 1);
            assert!((1..=6).contains(&total));
        }
    }

    #[test]
    fn roll_value_deterministic_with_seed() {
        let expr = parse_expr("d20").unwrap();
        let mut rng1 = seeded_rng();
        let mut rng2 = seeded_rng();
        let a = roll_value(&expr, &mut rng1);
        let b = roll_value(&expr, &mut rng2);
        assert_eq!(a, b);
    }

    #[test]
    fn roll_value_advantage_takes_higher() {
        // Replay the same RNG stream through two plain `roll_once` calls and
        // check that advantage picks the larger of the pair.
        let expr = parse_expr("adv d20").unwrap();
        let mut rng = seeded_rng();
        let mut reference = seeded_rng();
        for _ in 0..100 {
            let adv = roll_value(&expr, &mut rng);
            let first = roll_once(&expr, &mut reference).0;
            let second = roll_once(&expr, &mut reference).0;
            assert_eq!(adv, first.max(second));
            assert!((1..=20).contains(&adv));
        }
    }

    #[test]
    fn roll_value_disadvantage_takes_lower() {
        // Same replay technique: disadvantage must pick the smaller of the pair.
        let expr = parse_expr("dis d20").unwrap();
        let mut rng = seeded_rng();
        let mut reference = seeded_rng();
        for _ in 0..100 {
            let dis = roll_value(&expr, &mut rng);
            let first = roll_once(&expr, &mut reference).0;
            let second = roll_once(&expr, &mut reference).0;
            assert_eq!(dis, first.min(second));
            assert!((1..=20).contains(&dis));
        }
    }

    #[test]
    fn advantage_greater_equal_disadvantage() {
        let adv_expr = parse_expr("adv d20").unwrap();
        let dis_expr = parse_expr("dis d20").unwrap();
        let mut rng = seeded_rng();
        let mut adv_total: i64 = 0;
        let mut dis_total: i64 = 0;
        let n = 10_000;
        for _ in 0..n {
            adv_total += roll_value(&adv_expr, &mut rng);
            dis_total += roll_value(&dis_expr, &mut rng);
        }
        assert!(adv_total > dis_total);
    }

    // ---- roll_verbose tests ----

    #[test]
    fn roll_verbose_includes_rolls() {
        let expr = parse_expr("2d6").unwrap();
        let mut rng = seeded_rng();
        let (_, detail) = roll_verbose(&expr, &mut rng);
        assert!(detail.starts_with('['));
        assert!(detail.contains(']'));
    }

    #[test]
    fn roll_verbose_advantage_shows_vs() {
        let expr = parse_expr("adv d20").unwrap();
        let mut rng = seeded_rng();
        let (_, detail) = roll_verbose(&expr, &mut rng);
        assert!(detail.contains("vs"));
    }

    // ---- format_rolls tests ----

    #[test]
    fn format_rolls_single_group() {
        assert_eq!(format_rolls(&[vec![3, 5]]), "[3, 5]");
    }

    #[test]
    fn format_rolls_multiple_groups() {
        assert_eq!(format_rolls(&[vec![3, 5], vec![2]]), "[3, 5] + [2]");
    }

    // ---- roll_stats tests ----

    #[test]
    fn roll_stats_d6() {
        let expr = parse_expr("d6").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 1);
        assert_eq!(stats.max, 6);
        assert!((stats.mean - 3.5).abs() < f64::EPSILON);
    }

    #[test]
    fn roll_stats_with_bonus() {
        let expr = parse_expr("2d6+5").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 7);
        assert_eq!(stats.max, 17);
        assert!((stats.mean - 12.0).abs() < f64::EPSILON);
    }

    #[test]
    fn roll_stats_keep_highest() {
        // 4d6kh3: keep 3 dice
        let expr = parse_expr("4d6kh3").unwrap();
        let stats = roll_stats(&expr);
        assert_eq!(stats.min, 3);
        assert_eq!(stats.max, 18);
    }

    // ---- compute_distribution tests ----

    #[test]
    fn distribution_d6_has_all_values() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let counts = compute_distribution(&expr, 100_000, &mut rng);
        for v in 1..=6 {
            assert!(counts.contains_key(&v), "missing value {v}");
        }
        assert!(!counts.contains_key(&0));
        assert!(!counts.contains_key(&7));
    }

    #[test]
    fn distribution_counts_sum_to_sims() {
        let expr = parse_expr("2d6+3").unwrap();
        let mut rng = seeded_rng();
        let sims = 50_000;
        let counts = compute_distribution(&expr, sims, &mut rng);
        let total: u64 = counts.values().sum();
        assert_eq!(total, sims);
    }

    // ---- render_distribution tests ----

    #[test]
    fn render_distribution_contains_all_values() {
        let expr = parse_expr("d6").unwrap();
        let mut counts = BTreeMap::new();
        for v in 1..=6 {
            counts.insert(v, 1000);
        }
        let output = render_distribution(&expr, &counts, 6000);
        assert!(output.starts_with("Distribution for"));
        for v in 1..=6 {
            assert!(output.contains(&format!("{v} |")));
        }
    }

    #[test]
    fn render_distribution_percentages() {
        let expr = parse_expr("d6").unwrap();
        let mut counts = BTreeMap::new();
        counts.insert(1, 500);
        counts.insert(2, 500);
        let output = render_distribution(&expr, &counts, 1000);
        assert!(output.contains("50.0%"));
    }

    // ---- exact_probability tests ----

    #[test]
    fn exact_probability_d6_at_least_1_is_100_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 1).unwrap();
        assert!((p - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn exact_probability_d6_at_least_7_is_0_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!(p.abs() < f64::EPSILON);
    }

    #[test]
    fn exact_probability_d6_at_least_4_is_50_percent() {
        let expr = parse_expr("d6").unwrap();
        let p = exact_probability(&expr, 4).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_returns_none_for_advantage() {
        let expr = parse_expr("adv d20").unwrap();
        assert!(exact_probability(&expr, 15).is_none());
    }

    #[test]
    fn exact_probability_returns_none_for_keep() {
        let expr = parse_expr("4d6kh3").unwrap();
        assert!(exact_probability(&expr, 10).is_none());
    }

    #[test]
    fn exact_probability_2d6_known_value() {
        // P(2d6 >= 7) = 21/36 = 7/12
        let expr = parse_expr("2d6").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!((p - 7.0 / 12.0).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_with_flat_bonus() {
        // P(d6+3 >= 7) = P(d6 >= 4) = 3/6 = 0.5
        let expr = parse_expr("d6+3").unwrap();
        let p = exact_probability(&expr, 7).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn exact_probability_sums_to_one() {
        // 2 is the minimum of 2d6, so P(2d6 >= 2) covers every outcome and the
        // whole distribution must sum to 1.0.
        let expr = parse_expr("2d6").unwrap();
        let p = exact_probability(&expr, 2).unwrap();
        assert!((p - 1.0).abs() < 1e-10);
    }

    // ---- ParseError display tests ----

    #[test]
    fn parse_error_display_no_dice() {
        assert_eq!(ParseError::NoDiceFound.to_string(), "no dice found in expression");
    }

    #[test]
    fn parse_error_display_negative_group() {
        assert_eq!(
            ParseError::NegativeDiceGroup.to_string(),
            "negative dice groups are not supported"
        );
    }

    #[test]
    fn parse_error_display_invalid_token() {
        assert_eq!(
            ParseError::InvalidToken("foo".to_string()).to_string(),
            "invalid token: 'foo'"
        );
    }

    #[test]
    fn parse_error_display_invalid_sides() {
        assert_eq!(
            ParseError::InvalidSides("sides must be at least 1".to_string()).to_string(),
            "invalid sides: 'sides must be at least 1'"
        );
    }

    #[test]
    fn parse_error_display_invalid_count() {
        assert_eq!(
            ParseError::InvalidDiceCount("count must be at least 1".to_string()).to_string(),
            "invalid dice count: 'count must be at least 1'"
        );
    }

    // ---- roll_verbose with keep tests ----

    #[test]
    fn roll_verbose_keep_shows_kept_count() {
        // 4d6kh3 keeps 3 dice; the detail should contain exactly 3 numbers in brackets
        let expr = parse_expr("4d6kh3").unwrap();
        let mut rng = seeded_rng();
        for _ in 0..20 {
            let (_, detail) = roll_verbose(&expr, &mut rng);
            // Detail looks like "[a, b, c]"; split on ',' to count dice
            let inner = detail.trim_start_matches('[').trim_end_matches(']');
            assert_eq!(inner.split(',').count(), 3, "expected 3 kept dice, got: {detail}");
        }
    }

    // ---- estimate_probability tests ----

    #[test]
    fn probability_d6_at_least_1_is_100_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 1, 10_000, &mut rng);
        assert!((p - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn probability_d6_at_least_7_is_0_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 7, 10_000, &mut rng);
        assert!(p.abs() < f64::EPSILON);
    }

    #[test]
    fn probability_d6_at_least_4_roughly_50_percent() {
        let expr = parse_expr("d6").unwrap();
        let mut rng = seeded_rng();
        let p = estimate_probability(&expr, 4, 100_000, &mut rng);
        assert!((p - 0.5).abs() < 0.02);
    }
}