rpg_dice_roller/evaluate/expression.rs

1use rand::Rng;
2
3use crate::parse::{Dice, DiceKind, Expression, MathFn1, MathFn2, Modifier, Operator};
4
5use super::{
6    group_rolls::apply_group_modifiers,
7    roll::{GroupRollOutput, RollOutput},
8};
9
10impl Expression {
11    pub fn roll(self, rng: &mut impl Rng) -> RolledExpression {
12        match self {
13            Expression::Value(float) => RolledExpression::Value(float),
14            Expression::DiceFudge1(qty, mods) => {
15                let dice = dice_from_expression(qty.map(|q| *q), DiceKind::Fudge1, mods, rng);
16                RolledExpression::DiceRoll(dice.roll_all_with(rng))
17            }
18            Expression::DiceFudge2(qty, mods) => {
19                let dice = dice_from_expression(qty.map(|q| *q), DiceKind::Fudge2, mods, rng);
20                RolledExpression::DiceRoll(dice.roll_all_with(rng))
21            }
22            Expression::DicePercentile(qty, mods) => {
23                let dice =
24                    dice_from_expression(qty.map(|q| *q), DiceKind::Standard(100), mods, rng);
25                RolledExpression::DiceRoll(dice.roll_all_with(rng))
26            }
27            Expression::DiceStandard(qty, sides, mods) => {
28                // Max sides is i32::MAX
29                let sides = sides.roll(rng).value().round() as i32;
30
31                // Make sure sides aren't negative, not sure it makes sense for
32                // the minimum amount of sides to be 1, but I'd have to make
33                // sure that parsing also doesn't accept a 1 in that case.
34                let sides = sides.max(1);
35
36                let dice =
37                    dice_from_expression(qty.map(|q| *q), DiceKind::Standard(sides), mods, rng);
38                RolledExpression::DiceRoll(dice.roll_all_with(rng))
39            }
40            Expression::Parens(expr) => expr.roll(rng),
41            Expression::Group(expressions, modifiers) => {
42                let rolled_expressions =
43                    expressions.into_iter().map(|expr| expr.roll(rng)).collect();
44                let output = apply_group_modifiers(rolled_expressions, &modifiers);
45                RolledExpression::Group(output)
46            }
47            Expression::Infix(op, lhs, rhs) => {
48                RolledExpression::Infix(op, Box::new(lhs.roll(rng)), Box::new(rhs.roll(rng)))
49            }
50            Expression::Fn1(f, arg) => RolledExpression::Fn1(f, Box::new(arg.roll(rng))),
51            Expression::Fn2(f, arg1, arg2) => {
52                RolledExpression::Fn2(f, Box::new(arg1.roll(rng)), Box::new(arg2.roll(rng)))
53            }
54        }
55    }
56}
57
58#[derive(Debug, Clone, PartialEq)]
59pub enum RolledExpression {
60    DiceRoll(RollOutput),
61    Group(GroupRollOutput),
62    Value(f64),
63    Parens(Box<RolledExpression>),
64    Infix(Operator, Box<RolledExpression>, Box<RolledExpression>),
65    Fn1(MathFn1, Box<RolledExpression>),
66    Fn2(MathFn2, Box<RolledExpression>, Box<RolledExpression>),
67}
68
69impl RolledExpression {
70    pub fn value(&self) -> f64 {
71        match self {
72            RolledExpression::DiceRoll(roll_output) => roll_output.value(),
73            RolledExpression::Group(group_output) => group_output.value(),
74            RolledExpression::Value(float) => *float,
75            RolledExpression::Parens(expr) => expr.value(),
76            RolledExpression::Infix(operator, lhs, rhs) => operator.evaluate(lhs, rhs),
77            RolledExpression::Fn1(f, arg) => f.evaluate(arg),
78            RolledExpression::Fn2(f, arg1, arg2) => f.evaluate(arg1, arg2),
79        }
80    }
81}
82
83fn dice_from_expression(
84    quantity: Option<Expression>,
85    kind: DiceKind,
86    modifiers: Vec<Modifier>,
87    rng: &mut impl Rng,
88) -> Dice {
89    let quantity = match quantity {
90        Some(expression) => expression.roll(rng).value().round() as u32,
91        None => 1,
92    };
93
94    Dice::new(quantity, kind, &modifiers)
95}
96
97impl Operator {
98    pub fn evaluate(&self, lhs: &RolledExpression, rhs: &RolledExpression) -> f64 {
99        match self {
100            Operator::Add => lhs.value() + rhs.value(),
101            Operator::Sub => lhs.value() - rhs.value(),
102            Operator::Mul => lhs.value() * rhs.value(),
103            Operator::Div => lhs.value() / rhs.value(),
104            Operator::Pow => lhs.value() * rhs.value(),
105            Operator::Rem => lhs.value() % rhs.value(),
106        }
107    }
108}
109
110impl MathFn1 {
111    pub fn evaluate(&self, arg: &RolledExpression) -> f64 {
112        let arg = arg.value();
113        match self {
114            MathFn1::Abs => arg.abs(),
115            MathFn1::Floor => arg.floor(),
116            MathFn1::Ceil => arg.ceil(),
117            MathFn1::Round => arg.round(),
118            MathFn1::Sign => arg.signum(),
119            MathFn1::Sqrt => arg.sqrt(),
120            MathFn1::Log => arg.log(std::f64::consts::E),
121            MathFn1::Exp => arg.exp(),
122            MathFn1::Sin => arg.sin(),
123            MathFn1::Cos => arg.cos(),
124            MathFn1::Tan => arg.tan(),
125        }
126    }
127}
128
129impl MathFn2 {
130    pub fn evaluate(&self, arg1: &RolledExpression, arg2: &RolledExpression) -> f64 {
131        let arg1 = arg1.value();
132        let arg2 = arg2.value();
133
134        match self {
135            MathFn2::Min => arg1.min(arg2),
136            MathFn2::Max => arg1.max(arg2),
137            MathFn2::Pow => arg1.powf(arg2),
138        }
139    }
140}
141
142impl std::fmt::Display for Expression {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            Expression::Value(val) => write!(f, "{val}"),
146            Expression::DiceStandard(None, sides, mods) => {
147                write!(f, "d{sides}{}", Modifier::join_all(mods))
148            }
149            Expression::DiceStandard(Some(qty), sides, mods) => {
150                write!(f, "{qty}d{sides}{}", Modifier::join_all(mods))
151            }
152            Expression::DiceFudge1(None, mods) => write!(f, "dF.1{}", Modifier::join_all(mods)),
153            Expression::DiceFudge1(Some(qty), mods) => {
154                write!(f, "{qty}dF.1{}", Modifier::join_all(mods))
155            }
156            Expression::DiceFudge2(None, mods) => write!(f, "dF.2{}", Modifier::join_all(mods)),
157            Expression::DiceFudge2(Some(qty), mods) => {
158                write!(f, "{qty}dF.2{}", Modifier::join_all(mods))
159            }
160            Expression::DicePercentile(None, mods) => write!(f, "d%{}", Modifier::join_all(mods)),
161            Expression::DicePercentile(Some(qty), mods) => {
162                write!(f, "{qty}d%{}", Modifier::join_all(mods))
163            }
164            Expression::Parens(expr) => write!(f, "({expr})"),
165            Expression::Group(expressions, mods) => {
166                let expressions = expressions
167                    .iter()
168                    .map(|e| e.to_string())
169                    .collect::<Vec<_>>()
170                    .join(", ");
171                write!(f, "{{{expressions}}}{}", Modifier::join_all(mods))
172            }
173            Expression::Infix(op, expr1, expr2) => write!(f, "{expr1} {op} {expr2}"),
174            // no parens on the function call because there's always a parens expression following the function call
175            Expression::Fn1(func, arg) => write!(f, "{func}{arg}"),
176            Expression::Fn2(func, arg1, arg2) => write!(f, "{func}({arg1}, {arg2})"),
177        }
178    }
179}
180
181impl std::fmt::Display for RolledExpression {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        match self {
184            RolledExpression::DiceRoll(roll_output) => write!(f, "{roll_output}"),
185            RolledExpression::Group(group_output) => write!(f, "{{{group_output}}}"),
186            RolledExpression::Value(float) => write!(f, "{float}"),
187            RolledExpression::Parens(expr) => write!(f, "({expr})"),
188            RolledExpression::Infix(op, lhs, rhs) => write!(f, "{lhs} {op} {rhs}"),
189            RolledExpression::Fn1(func, arg) => write!(f, "{func}{arg}"),
190            RolledExpression::Fn2(func, arg1, arg2) => write!(f, "{func}({arg1}, {arg2})"),
191        }
192    }
193}
194
195impl std::fmt::Display for Operator {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        let str = match self {
198            Operator::Add => "+",
199            Operator::Sub => "-",
200            Operator::Mul => "*",
201            Operator::Div => "/",
202            Operator::Rem => "%",
203            Operator::Pow => "**",
204        };
205        write!(f, "{str}")
206    }
207}
208
209impl std::fmt::Display for MathFn1 {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        let str = match self {
212            MathFn1::Abs => "abs",
213            MathFn1::Floor => "floor",
214            MathFn1::Ceil => "ceil",
215            MathFn1::Round => "round",
216            MathFn1::Sign => "sign",
217            MathFn1::Sqrt => "sqrt",
218            MathFn1::Log => "log",
219            MathFn1::Exp => "exp",
220            MathFn1::Sin => "sin",
221            MathFn1::Cos => "cos",
222            MathFn1::Tan => "tan",
223        };
224        write!(f, "{str}")
225    }
226}
227
228impl std::fmt::Display for MathFn2 {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        let str = match self {
231            MathFn2::Min => "min",
232            MathFn2::Max => "max",
233            MathFn2::Pow => "pow",
234        };
235        write!(f, "{str}")
236    }
237}
238
#[cfg(test)]
mod tests {
    use crate::{parse::Operator, Expression::*};

    /// Deterministic RNG stub whose every output is the largest possible
    /// value. The tests below rely on it to make a die land on its highest
    /// face.
    struct MaxValRng;

    impl rand::RngCore for MaxValRng {
        fn next_u32(&mut self) -> u32 {
            // Truncation is intentional: the low 32 bits of u64::MAX are u32::MAX.
            self.next_u64() as u32
        }

        fn next_u64(&mut self) -> u64 {
            u64::MAX
        }

        fn fill_bytes(&mut self, _dest: &mut [u8]) {
            unimplemented!()
        }

        fn try_fill_bytes(&mut self, _dest: &mut [u8]) -> Result<(), rand::Error> {
            unimplemented!();
        }
    }

    /// Rolls a single unmodified standard die whose side count is given by
    /// `sides`, using [`MaxValRng`], and returns the rolled value.
    fn max_roll_with_sides(sides: crate::Expression) -> f64 {
        let expression = DiceStandard(None, Box::new(sides), Vec::new());
        expression.roll(&mut MaxValRng).value()
    }

    #[test]
    fn test_infinity_sides_dice_max_is_i32_max() {
        assert_eq!(max_roll_with_sides(Value(f64::INFINITY)), i32::MAX as f64);
    }

    #[test]
    fn test_greater_than_i32_max_sides_dice_max_is_i32_max() {
        assert_eq!(max_roll_with_sides(Value(i32::MAX as f64 + 1.0)), i32::MAX as f64);
    }

    #[test]
    fn test_divide_by_zero_max_is_i32_max() {
        let sides = Infix(Operator::Div, Box::new(Value(3.0)), Box::new(Value(0.0)));
        assert_eq!(max_roll_with_sides(sides), i32::MAX as f64);
    }
}