pips/
expression.rs

1//! Expression Module
2
3use crate::traits::PlotResult;
4use crate::traits::PlotTable;
5use crate::traits::RollResult;
6use rand::Rng;
7use std::collections::HashMap;
8
9use crate::operators::{
10    advantage, difference, disadvantage, divide, equal_to, greater_than, greater_than_or_equal_to,
11    less_than, less_than_or_equal_to, multiply, sum, BinaryOperator,
12};
13use crate::traits::Rollable;
14
/// Relation tested by an `Expression::Compare` node.
///
/// The comparison is applied as `op(left, right)` to the evaluated left and
/// right operands; see the `operators` module for the exact result values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Comparison {
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqualTo,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqualTo,
    /// `==`
    EqualTo,
}
23
24use Comparison::*;
25
26/// Represents a dice roll expression
27#[derive(Clone, Debug, PartialEq)]
28pub enum Expression {
29    Die(u32),
30    Constant(i32),
31
32    Sum(Box<Expression>, Box<Expression>),
33    Diff(Box<Expression>, Box<Expression>),
34    Multiply(Box<Expression>, Box<Expression>),
35    Divide(Box<Expression>, Box<Expression>),
36    Advantage(Box<Expression>),
37    Disadvantage(Box<Expression>),
38    Compare(Box<Expression>, Box<Expression>, Comparison),
39}
40
41use Expression::*;
42
43impl Expression {
44    /// retrieve the operation encapsulated by the given `Expression`,
45    /// represented by a binary operator and left/right expressions
46    fn get_operation(&self) -> Option<(BinaryOperator, &Box<Expression>, &Box<Expression>)> {
47        match self {
48            Constant(_) => None,
49            Die(_) => None,
50
51            Sum(left, right) => Some((sum, left, right)),
52            Diff(left, right) => Some((difference, left, right)),
53            Multiply(left, right) => Some((multiply, left, right)),
54            Divide(left, right) => Some((divide, left, right)),
55            Advantage(expr) => Some((advantage, expr, expr)),
56            Disadvantage(expr) => Some((disadvantage, expr, expr)),
57            Compare(left, right, comparison) => {
58                let compare = match comparison {
59                    GreaterThan => greater_than,
60                    GreaterThanOrEqualTo => greater_than_or_equal_to,
61                    LessThan => less_than,
62                    LessThanOrEqualTo => less_than_or_equal_to,
63                    EqualTo => equal_to,
64                };
65                Some((compare, left, right))
66            }
67        }
68    }
69}
70
71impl Rollable for Expression {
72    /// Get a single value from the roll expression
73    fn roll(&self) -> RollResult {
74        // get the root cases out of the way
75        if let Constant(num) = self {
76            return *num;
77        }
78        if let Die(max) = self {
79            return rand::thread_rng().gen_range(1, *max as RollResult + 1);
80        }
81
82        let (operator, left, right) = self
83            .get_operation()
84            .expect("expression does not represent an operation");
85
86        (operator)(&left.roll(), &right.roll())
87    }
88
89    /// Create a list of all possible outcomes and their possibility
90    fn plot(&self) -> PlotResult {
91        // get the root cases out of the way
92        if let Constant(num) = self {
93            return PlotResult {
94                total: 1.0,
95                plot: [(*num, 1.0)].iter().cloned().collect(),
96            };
97        }
98        if let Die(num) = self {
99            let total = *num as f32;
100            return PlotResult {
101                total,
102                plot: (1..num + 1)
103                    .map(|i| (i as RollResult, 1.0 / total))
104                    .collect(),
105            };
106        }
107
108        // handle the more complicated expressions
109        let (operator, left, right) = self
110            .get_operation()
111            .expect("expression does not represent an operation");
112
113        let left = left.plot();
114        let right = right.plot();
115
116        let mut product: PlotTable = HashMap::new();
117
118        left.plot
119            .iter()
120            .flat_map(|(left_value, left_chance)| {
121                right.plot.iter().map(move |(right_value, right_chance)| {
122                    let value = (operator)(left_value, right_value);
123                    (value, left_chance * right_chance)
124                })
125            })
126            .for_each(|(value, count)| {
127                *product.entry(value).or_insert(0.0) += count;
128            });
129
130        PlotResult {
131            total: left.total * right.total,
132            plot: product,
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    impl PlotResult {
        /// De-normalize the table of possible outcomes into outcome counts.
        ///
        /// Each probability is multiplied by `total` and rounded to the nearest
        /// integer. The probabilities are accumulated in `f32`, so a whole
        /// count such as 6 can land slightly below 6.0. Truncating with
        /// `as i32` alone would turn that into 5.
        pub fn simplify(&self) -> HashMap<i32, i32> {
            self.plot
                .iter()
                .map(|(value, chance)| {
                    let outcomes = (chance * self.total).round() as i32;
                    (*value, outcomes)
                })
                .collect()
        }
    }

    #[test]
    fn simplify_produces_correct_table() {
        let plot_result = PlotResult {
            total: 10.0,
            plot: [(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)]
                .iter()
                .cloned()
                .collect(),
        };
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 2), (3, 3), (4, 4)].iter().cloned().collect();

        let actual = plot_result.simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn constant_produces_single_outcome() {
        let expression = Expression::Constant(5);
        let expected: HashMap<i32, i32> = [(5, 1)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn die_produces_uniform_plot() {
        let expression = Expression::Die(6);
        let expected: HashMap<i32, i32> = (1..=6).map(|face| (face, 1)).collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn multiply_produces_correct_plot() {
        let expression =
            Expression::Multiply(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> = [
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 3),
            (6, 2),
            (8, 2),
            (9, 1),
            (12, 2),
            (16, 1),
        ]
        .iter()
        .cloned()
        .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn divide_produces_correct_plot() {
        let expression =
            Expression::Divide(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        // 1 1 -> 1
        // 1 2 -> 0
        // 1 3 -> 0
        // 1 4 -> 0
        // 2 1 -> 2
        // 2 2 -> 1
        // 2 3 -> 0
        // 2 4 -> 0
        // 3 1 -> 3
        // 3 2 -> 1
        // 3 3 -> 1
        // 3 4 -> 0
        // 4 1 -> 4
        // 4 2 -> 2
        // 4 3 -> 1
        // 4 4 -> 1
        let expected: HashMap<i32, i32> = [(0, 6), (1, 6), (2, 2), (3, 1), (4, 1)]
            .iter()
            .cloned()
            .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn sum_produces_correct_plot() {
        let expression =
            Expression::Sum(Box::new(Expression::Die(6)), Box::new(Expression::Die(6)));
        let expected: HashMap<i32, i32> = [
            (2, 1),
            (3, 2),
            (4, 3),
            (5, 4),
            (6, 5),
            (7, 6),
            (8, 5),
            (9, 4),
            (10, 3),
            (11, 2),
            (12, 1),
        ]
        .iter()
        .cloned()
        .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn difference_produces_correct_plot() {
        let expression =
            Expression::Diff(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        // 1 1 -> 0
        // 1 2 -> -1
        // 1 3 -> -2
        // 1 4 -> -3
        // 2 1 -> 1
        // 2 2 -> 0
        // 2 3 -> -1
        // 2 4 -> -2
        // 3 1 -> 2
        // 3 2 -> 1
        // 3 3 -> 0
        // 3 4 -> -1
        // 4 1 -> 3
        // 4 2 -> 2
        // 4 3 -> 1
        // 4 4 -> 0
        let expected: HashMap<i32, i32> =
            [(-3, 1), (-2, 2), (-1, 3), (0, 4), (1, 3), (2, 2), (3, 1)]
                .iter()
                .cloned()
                .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn advantage_produces_correct_plot() {
        let expression = Expression::Advantage(Box::new(Expression::Die(4)));
        // 1 1 -> 1
        // 1 2 -> 2
        // 1 3 -> 3
        // 1 4 -> 4
        // 2 1 -> 2
        // 2 2 -> 2
        // 2 3 -> 3
        // 2 4 -> 4
        // 3 1 -> 3
        // 3 2 -> 3
        // 3 3 -> 3
        // 3 4 -> 4
        // 4 1 -> 4
        // 4 2 -> 4
        // 4 3 -> 4
        // 4 4 -> 4
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 3), (3, 5), (4, 7)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn nested_advantage_produces_correct_plot() {
        let expression = Expression::Advantage(Box::new(Expression::Advantage(Box::new(
            Expression::Die(2),
        ))));
        // inner: adv(d2) -> 1 with 1/4, 2 with 3/4 (total 4)
        // outer: max of two independent inner rolls -> 1 only when both are 1,
        //        i.e. 1 of 16 outcomes; the other 15 give 2
        let expected: HashMap<i32, i32> = [(1, 1), (2, 15)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn disadvantage_produces_correct_plot() {
        let expression = Expression::Disadvantage(Box::new(Expression::Die(4)));
        // 1 1 -> 1
        // 1 2 -> 1
        // 1 3 -> 1
        // 1 4 -> 1
        // 2 1 -> 1
        // 2 2 -> 2
        // 2 3 -> 2
        // 2 4 -> 2
        // 3 1 -> 1
        // 3 2 -> 2
        // 3 3 -> 3
        // 3 4 -> 3
        // 4 1 -> 1
        // 4 2 -> 2
        // 4 3 -> 3
        // 4 4 -> 4
        let expected: HashMap<i32, i32> =
            [(1, 7), (2, 5), (3, 3), (4, 1)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    /// The expected table describes a three-way contest that yields -1, 0 or 1.
    /// `Compare` currently yields only 0 or 1, so this test remains ignored
    /// until such a mode exists.
    #[test]
    #[ignore = "not implemented"]
    fn contest_produces_correct_plot() {
        let expression = Expression::Compare(
            Box::new(Expression::Die(2)),
            Box::new(Expression::Die(3)),
            Comparison::GreaterThan,
        );

        // 1 1 -> 0
        // 1 2 -> -1
        // 1 3 -> -1
        // 2 1 -> 1
        // 2 2 -> 0
        // 2 3 -> -1
        let expected: HashMap<i32, i32> = [(-1, 3), (0, 2), (1, 1)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn compare_produces_correct_plot() {
        let left = Expression::Die(3);
        let right = Expression::Die(3);
        // 9 pairs in total; 3 are equal, 3 have left > right, 3 have left < right
        let cases: Vec<(Comparison, &[(i32, i32)])> = vec![
            (Comparison::GreaterThan, &[(0, 6), (1, 3)]),
            (Comparison::GreaterThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::LessThan, &[(0, 6), (1, 3)]),
            (Comparison::LessThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::EqualTo, &[(0, 6), (1, 3)]),
        ];

        for (comparison, options) in cases {
            let expression = Expression::Compare(
                Box::new(left.clone()),
                Box::new(right.clone()),
                comparison,
            );

            let expected: HashMap<i32, i32> = options.iter().cloned().collect();
            let actual = expression.plot().simplify();

            // the expression embeds the comparison, so a failure names the case
            assert_eq!(expected, actual, "{:?}", expression);
        }
    }
}