1use fraction::One;
2
3use super::{
4 dice::Dice,
5 dice_string_parser::{self, DiceBuildingError},
6};
7use core::panic;
8use std::{
9 collections::HashMap,
10 fmt::Display,
11 ops::{Add, Mul},
12};
/// Integer outcome of a dice expression.
pub type Value = i64;
/// Probability of an outcome, stored as a `fraction::BigFraction`.
pub type Prob = fraction::BigFraction;
/// Fraction type for aggregated values.
/// NOTE(review): not used in this file; presumably holds statistics derived
/// from a distribution — confirm against callers.
pub type AggrValue = fraction::BigFraction;
/// Boxed iterator over `(value, probability)` pairs.
type Distribution = Box<dyn Iterator<Item = (Value, Prob)>>;
/// Maps each outcome value to its probability.
pub type DistributionHashMap = HashMap<Value, Prob>;
18
/// Recipe for a (possibly nested) dice expression.
///
/// `Constant` and `FairDie` are the leaves; every `*Compound` variant combines
/// the distributions of its factors pairwise, from left to right.
#[derive(Debug, PartialEq, Eq)]
pub enum DiceBuilder {
    /// A fixed value: its distribution is that single value with probability one.
    Constant(Value),
    /// A fair die: every integer in `min..=max` is equally likely.
    FairDie {
        /// Smallest face (inclusive).
        min: Value,
        /// Largest face (inclusive); computing the distribution asserts `max >= min`.
        max: Value,
    },
    /// Sum of the factors (`a + b`).
    SumCompound(Vec<DiceBuilder>),
    /// Product of the factors (`a * b`).
    ProductCompound(Vec<DiceBuilder>),
    /// Quotient of the factors, combined with `rounded_div::i64`.
    DivisionCompound(Vec<DiceBuilder>),
    /// Maximum of the factors (`std::cmp::max`).
    MaxCompound(Vec<DiceBuilder>),
    /// Minimum of the factors (`std::cmp::min`).
    MinCompound(Vec<DiceBuilder>),
    /// Sample sum: the value accumulated so far is used as the number of
    /// times the next factor is sampled, and those samples are added up.
    SampleSumCompound(Vec<DiceBuilder>),
}
85
86impl DiceBuilder {
87 pub fn from_string(input: &str) -> Result<Self, DiceBuildingError> {
113 dice_string_parser::string_to_factor(input)
114 }
115
116 pub fn build(self) -> Dice {
121 #[cfg(feature = "console_error_panic_hook")]
122 console_error_panic_hook::set_once();
123 Dice::from_builder(self)
124 }
125
126 pub fn build_from_string(input: &str) -> Result<Dice, DiceBuildingError> {
128 let builder = DiceBuilder::from_string(input)?;
129 Ok(builder.build())
130 }
131
132 pub fn reconstruct_string(&self) -> String {
136 match self {
137 DiceBuilder::Constant(i) => i.to_string(),
138 DiceBuilder::FairDie { min, max } => match *min == 1 {
139 true => format!("d{max}"),
140 false => "".to_owned(), },
142 DiceBuilder::SumCompound(v) => v
144 .iter()
145 .map(|f| f.to_string())
146 .collect::<Vec<String>>()
147 .join("+"),
148 DiceBuilder::ProductCompound(v) => v
149 .iter()
150 .map(|f| f.to_string())
151 .collect::<Vec<String>>()
152 .join("*"),
153 DiceBuilder::DivisionCompound(v) => v
154 .iter()
155 .map(|f| f.to_string())
156 .collect::<Vec<String>>()
157 .join("/"),
158 DiceBuilder::SampleSumCompound(v) => v
159 .iter()
160 .map(|f| f.to_string())
161 .collect::<Vec<String>>()
162 .join("x"),
163 DiceBuilder::MaxCompound(v) => format!(
164 "max({})",
165 v.iter()
166 .map(|f| f.to_string())
167 .collect::<Vec<String>>()
168 .join(",")
169 ),
170 DiceBuilder::MinCompound(v) => format!(
171 "min({})",
172 v.iter()
173 .map(|f| f.to_string())
174 .collect::<Vec<String>>()
175 .join(",")
176 ),
177 }
178 }
179
180 fn distribution_hashmap(&self) -> DistributionHashMap {
181 match self {
182 DiceBuilder::Constant(v) => {
183 let mut m = DistributionHashMap::new();
184 m.insert(*v, Prob::one());
185 m
186 }
187 DiceBuilder::FairDie { min, max } => {
188 assert!(max >= min);
189 let min: i64 = *min;
190 let max: i64 = *max;
191 let prob: Prob = Prob::new(1u64, (max - min + 1) as u64);
192 let mut m = DistributionHashMap::new();
193 for v in min..=max {
194 m.insert(v, prob.clone());
195 }
196 m
197 }
198 DiceBuilder::SampleSumCompound(vec) => {
199 let hashmaps = vec
200 .iter()
201 .map(|e| e.distribution_hashmap())
202 .collect::<Vec<DistributionHashMap>>();
203 sample_sum_convolute_hashmaps(&hashmaps)
204 }
205 DiceBuilder::SumCompound(vec)
206 | DiceBuilder::ProductCompound(vec)
207 | DiceBuilder::DivisionCompound(vec)
208 | DiceBuilder::MaxCompound(vec)
209 | DiceBuilder::MinCompound(vec) => {
210 let operation = match self {
211 DiceBuilder::SumCompound(_) => |a, b| a + b,
212 DiceBuilder::ProductCompound(_) => |a, b| a * b,
213 DiceBuilder::MaxCompound(_) => std::cmp::max,
214 DiceBuilder::MinCompound(_) => std::cmp::min,
215 DiceBuilder::DivisionCompound(_) => rounded_div::i64,
216 _ => panic!("unreachable by match"),
217 };
218 let hashmaps = vec
219 .iter()
220 .map(|e| e.distribution_hashmap())
221 .collect::<Vec<DistributionHashMap>>();
222 convolute_hashmaps(&hashmaps, operation)
223 }
224 }
225 }
226
227 pub fn distribution_iter(&self) -> Distribution {
232 let mut distribution_vec = self
233 .distribution_hashmap()
234 .into_iter()
235 .collect::<Vec<(Value, Prob)>>();
236 distribution_vec.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
237 Box::new(distribution_vec.into_iter())
238 }
239}
240
241impl Display for DiceBuilder {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 write! {f, "{}", self.reconstruct_string()}
244 }
245}
246
247fn convolute_hashmaps(
248 hashmaps: &Vec<DistributionHashMap>,
249 operation: fn(Value, Value) -> Value,
250) -> DistributionHashMap {
251 if hashmaps.is_empty() {
252 panic!("cannot convolute hashmaps from a zero element vector");
253 }
254 let mut convoluted_h = hashmaps[0].clone();
255 for h in hashmaps.iter().skip(1) {
256 convoluted_h = convolute_two_hashmaps(&convoluted_h, h, operation);
257 }
258 convoluted_h
259}
260
261fn convolute_two_hashmaps(
262 h1: &DistributionHashMap,
263 h2: &DistributionHashMap,
264 operation: fn(Value, Value) -> Value,
265) -> DistributionHashMap {
266 let mut m = DistributionHashMap::new();
267 for (v1, p1) in h1.iter() {
268 for (v2, p2) in h2.iter() {
269 let v = operation(*v1, *v2);
270 let p = p1 * p2;
271 match m.entry(v) {
272 std::collections::hash_map::Entry::Occupied(mut e) => {
273 *e.get_mut() += p;
274 }
275 std::collections::hash_map::Entry::Vacant(e) => {
276 e.insert(p);
277 }
278 }
279 }
280 }
281 m
282}
283
284fn sample_sum_convolute_hashmaps(hashmaps: &Vec<DistributionHashMap>) -> DistributionHashMap {
285 if hashmaps.is_empty() {
286 panic!("cannot convolute hashmaps from a zero element vector");
287 }
288 let mut convoluted_h = hashmaps[0].clone();
289 for h in hashmaps.iter().skip(1) {
290 convoluted_h = sample_sum_convolute_two_hashmaps(&convoluted_h, h);
291 }
292 convoluted_h
293}
294
295fn sample_sum_convolute_two_hashmaps(
296 count_factor: &DistributionHashMap,
297 sample_factor: &DistributionHashMap,
298) -> DistributionHashMap {
299 let mut total_hashmap = DistributionHashMap::new();
300 for (count, count_p) in count_factor.iter() {
301 let mut count_hashmap: DistributionHashMap = match count.cmp(&0) {
302 std::cmp::Ordering::Less => {
303 let count: usize = (-count) as usize;
304 let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
305 .take(count)
306 .cloned()
307 .collect();
308 convolute_hashmaps(&sample_vec, |a, b| a + b)
309 }
310 std::cmp::Ordering::Equal => {
311 let mut h = DistributionHashMap::new();
312 h.insert(0, Prob::new(1u64, 1u64));
313 h
314 }
315 std::cmp::Ordering::Greater => {
316 let count: usize = *count as usize;
317 let sample_vec: Vec<DistributionHashMap> = std::iter::repeat(sample_factor)
318 .take(count)
319 .cloned()
320 .collect();
321 convolute_hashmaps(&sample_vec, |a, b| a + b)
322 }
323 };
324 count_hashmap.iter_mut().for_each(|e| {
325 *e.1 *= count_p.clone();
326 });
327 merge_hashmaps(&mut total_hashmap, &count_hashmap);
328 }
329 total_hashmap
330}
331
332impl Mul for Box<DiceBuilder> {
333 type Output = Box<DiceBuilder>;
334
335 fn mul(self, rhs: Self) -> Self::Output {
336 Box::new(DiceBuilder::ProductCompound(vec![*self, *rhs]))
337 }
338}
339
340impl Add for Box<DiceBuilder> {
341 type Output = Box<DiceBuilder>;
342
343 fn add(self, rhs: Self) -> Self::Output {
344 Box::new(DiceBuilder::SumCompound(vec![*self, *rhs]))
345 }
346}
347
348pub fn merge_hashmaps(first: &mut DistributionHashMap, second: &DistributionHashMap) {
349 for (k, v) in second.iter() {
350 match first.get_mut(k) {
351 Some(e) => {
352 *e += v;
353 }
354 None => {
355 first.insert(*k, v.clone());
356 }
357 }
358 }
359}