1use crate::traits::PlotResult;
4use crate::traits::PlotTable;
5use crate::traits::RollResult;
6use rand::Rng;
7use std::collections::HashMap;
8
9use crate::operators::{
10 advantage, difference, disadvantage, divide, equal_to, greater_than, greater_than_or_equal_to,
11 less_than, less_than_or_equal_to, multiply, sum, BinaryOperator,
12};
13use crate::traits::Rollable;
14
/// Relational operator applied by [`Expression::Compare`] to its two operands.
///
/// Each variant selects the matching operator from `crate::operators`
/// (`greater_than`, `greater_than_or_equal_to`, `less_than`,
/// `less_than_or_equal_to`, `equal_to`). In the plots exercised by this
/// module's tests, a comparison only ever yields the outcomes `0` and `1`.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Comparison {
    /// Uses the `greater_than` operator.
    GreaterThan,
    /// Uses the `greater_than_or_equal_to` operator.
    GreaterThanOrEqualTo,
    /// Uses the `less_than` operator.
    LessThan,
    /// Uses the `less_than_or_equal_to` operator.
    LessThanOrEqualTo,
    /// Uses the `equal_to` operator.
    EqualTo,
}
23
24use Comparison::*;
25
/// A dice expression tree that can be rolled once or plotted as an exact
/// outcome distribution (see the `Rollable` impl below).
#[derive(Clone, Debug, PartialEq)]
pub enum Expression {
    /// A single die whose faces `1..=n` are all equally likely.
    Die(u32),
    /// A fixed value that always evaluates to itself.
    Constant(i32),

    /// Combines both operands with the `sum` operator.
    Sum(Box<Expression>, Box<Expression>),
    /// Combines both operands with the `difference` operator (left minus right
    /// in the tests, which expect outcomes from `-3` to `3` for `d4 - d4`).
    Diff(Box<Expression>, Box<Expression>),
    /// Combines both operands with the `multiply` operator.
    Multiply(Box<Expression>, Box<Expression>),
    /// Combines both operands with the `divide` operator; the tests expect
    /// integer results (e.g. `0` is a possible outcome of `d4 / d4`).
    Divide(Box<Expression>, Box<Expression>),
    /// Evaluates the inner expression twice independently and combines the two
    /// results with `advantage`; the tests expect the higher of the two.
    Advantage(Box<Expression>),
    /// Evaluates the inner expression twice independently and combines the two
    /// results with `disadvantage`; the tests expect the lower of the two.
    Disadvantage(Box<Expression>),
    /// Compares `left` against `right` using the given [`Comparison`].
    Compare(Box<Expression>, Box<Expression>, Comparison),
}
40
41use Expression::*;
42
43impl Expression {
44 fn get_operation(&self) -> Option<(BinaryOperator, &Box<Expression>, &Box<Expression>)> {
47 match self {
48 Constant(_) => None,
49 Die(_) => None,
50
51 Sum(left, right) => Some((sum, left, right)),
52 Diff(left, right) => Some((difference, left, right)),
53 Multiply(left, right) => Some((multiply, left, right)),
54 Divide(left, right) => Some((divide, left, right)),
55 Advantage(expr) => Some((advantage, expr, expr)),
56 Disadvantage(expr) => Some((disadvantage, expr, expr)),
57 Compare(left, right, comparison) => {
58 let compare = match comparison {
59 GreaterThan => greater_than,
60 GreaterThanOrEqualTo => greater_than_or_equal_to,
61 LessThan => less_than,
62 LessThanOrEqualTo => less_than_or_equal_to,
63 EqualTo => equal_to,
64 };
65 Some((compare, left, right))
66 }
67 }
68 }
69}
70
71impl Rollable for Expression {
72 fn roll(&self) -> RollResult {
74 if let Constant(num) = self {
76 return *num;
77 }
78 if let Die(max) = self {
79 return rand::thread_rng().gen_range(1, *max as RollResult + 1);
80 }
81
82 let (operator, left, right) = self
83 .get_operation()
84 .expect("expression does not represent an operation");
85
86 (operator)(&left.roll(), &right.roll())
87 }
88
89 fn plot(&self) -> PlotResult {
91 if let Constant(num) = self {
93 return PlotResult {
94 total: 1.0,
95 plot: [(*num, 1.0)].iter().cloned().collect(),
96 };
97 }
98 if let Die(num) = self {
99 let total = *num as f32;
100 return PlotResult {
101 total,
102 plot: (1..num + 1)
103 .map(|i| (i as RollResult, 1.0 / total))
104 .collect(),
105 };
106 }
107
108 let (operator, left, right) = self
110 .get_operation()
111 .expect("expression does not represent an operation");
112
113 let left = left.plot();
114 let right = right.plot();
115
116 let mut product: PlotTable = HashMap::new();
117
118 left.plot
119 .iter()
120 .flat_map(|(left_value, left_chance)| {
121 right.plot.iter().map(move |(right_value, right_chance)| {
122 let value = (operator)(left_value, right_value);
123 (value, left_chance * right_chance)
124 })
125 })
126 .for_each(|(value, count)| {
127 *product.entry(value).or_insert(0.0) += count;
128 });
129
130 PlotResult {
131 total: left.total * right.total,
132 plot: product,
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    impl PlotResult {
        /// Converts the probability table back into whole outcome counts.
        ///
        /// Each chance is scaled by `total`, the number of equally likely
        /// elementary outcomes, and rounded to the nearest integer.
        ///
        /// The rounding matters. Chances are products of inexact floats such as
        /// `1/6`, so `chance * total` can land just below the exact count. A bare
        /// `as` cast truncates, and would then silently drop one outcome.
        pub fn simplify(&self) -> HashMap<i32, i32> {
            self.plot
                .iter()
                .map(|(value, chance)| {
                    let outcomes = (chance * self.total).round() as i32;
                    (*value, outcomes)
                })
                .collect()
        }
    }

    #[test]
    fn simplify_produces_correct_table() {
        let plot_result = PlotResult {
            total: 10.0,
            plot: [(1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)]
                .iter()
                .cloned()
                .collect(),
        };
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 2), (3, 3), (4, 4)].iter().cloned().collect();

        let actual = plot_result.simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn multiply_produces_correct_plot() {
        let expression =
            Expression::Multiply(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> = [
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 3),
            (6, 2),
            (8, 2),
            (9, 1),
            (12, 2),
            (16, 1),
        ]
        .iter()
        .cloned()
        .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn divide_produces_correct_plot() {
        let expression =
            Expression::Divide(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> = [(0, 6), (1, 6), (2, 2), (3, 1), (4, 1)]
            .iter()
            .cloned()
            .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn sum_produces_correct_plot() {
        let expression =
            Expression::Sum(Box::new(Expression::Die(6)), Box::new(Expression::Die(6)));
        let expected: HashMap<i32, i32> = [
            (2, 1),
            (3, 2),
            (4, 3),
            (5, 4),
            (6, 5),
            (7, 6),
            (8, 5),
            (9, 4),
            (10, 3),
            (11, 2),
            (12, 1),
        ]
        .iter()
        .cloned()
        .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn difference_produces_correct_plot() {
        let expression =
            Expression::Diff(Box::new(Expression::Die(4)), Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(-3, 1), (-2, 2), (-1, 3), (0, 4), (1, 3), (2, 2), (3, 1)]
                .iter()
                .cloned()
                .collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn advantage_produces_correct_plot() {
        let expression = Expression::Advantage(Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(1, 1), (2, 3), (3, 5), (4, 7)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn disadvantage_produces_correct_plot() {
        let expression = Expression::Disadvantage(Box::new(Expression::Die(4)));
        let expected: HashMap<i32, i32> =
            [(1, 7), (2, 5), (3, 3), (4, 1)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    // Expects a three-way contest result for d2 vs d3:
    // - `-1`: left loses (3 ways)
    // - `0`: tie (2 ways)
    // - `1`: left wins (1 way)
    // `Compare` currently yields only 0/1, hence the ignore.
    #[test]
    #[ignore = "not implemented"]
    fn contest_produces_correct_plot() {
        let expression = Expression::Compare(
            Box::new(Expression::Die(2)),
            Box::new(Expression::Die(3)),
            Comparison::GreaterThan,
        );

        let expected: HashMap<i32, i32> = [(-1, 3), (0, 2), (1, 1)].iter().cloned().collect();

        let actual = expression.plot().simplify();

        assert_eq!(expected, actual);
    }

    #[test]
    fn compare_produces_correct_plot() {
        let left = Expression::Die(3);
        let right = Expression::Die(3);
        let cases: Vec<(Comparison, &[(i32, i32)])> = vec![
            (Comparison::GreaterThan, &[(0, 6), (1, 3)]),
            (Comparison::GreaterThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::LessThan, &[(0, 6), (1, 3)]),
            (Comparison::LessThanOrEqualTo, &[(0, 3), (1, 6)]),
            (Comparison::EqualTo, &[(0, 6), (1, 3)]),
        ];

        for (comparison, options) in cases {
            let expression = Expression::Compare(
                Box::new(left.clone()),
                Box::new(right.clone()),
                comparison.clone(),
            );

            let expected: HashMap<i32, i32> = options.iter().cloned().collect();
            let actual = expression.plot().simplify();

            assert_eq!(expected, actual, "{:?}", comparison);
        }
    }
}
392}