// ebi/choose_randomly.rs

1use anyhow::{Context, Result, anyhow};
2use fraction::{GenericFraction, Ratio, Sign};
3use num::BigUint;
4use num_bigint::RandBigInt;
5use rand::Rng;
6
7use crate::{
8    exact::{MaybeExact, is_exact_globally},
9    fraction_enum::FractionEnum,
10    fraction_exact::FractionExact,
11    fraction_f64::FractionF64,
12    traits::Zero,
13};
14
15pub trait ChooseRandomly {
16    type Cache;
17    /**
18     * Return a random index from 0 (inclusive) to the length of the list (exclusive).
19     * The likelihood of each index to be returned is proportional to the value of the fraction at that index.
20     *
21     * The fractions do not need to sum to 1.
22     */
23    fn choose_randomly(fractions: &Vec<Self>) -> Result<usize>
24    where
25        Self: Sized;
26
27    fn choose_randomly_create_cache<'a>(
28        fractions: impl Iterator<Item = &'a Self>,
29    ) -> Result<Self::Cache>
30    where
31        Self: Sized,
32        Self: 'a;
33
34    fn choose_randomly_cached(cache: &Self::Cache) -> usize
35    where
36        Self: Sized;
37}
38
39#[cfg(any(
40    all(
41        not(feature = "exactarithmetic"),
42        not(feature = "approximatearithmetic")
43    ),
44    all(feature = "exactarithmetic", feature = "approximatearithmetic")
45))]
46pub type FractionRandomCache = FractionRandomCacheEnum;
47
48#[cfg(all(not(feature = "exactarithmetic"), feature = "approximatearithmetic"))]
49pub type FractionRandomCache = super::fraction_f64::FractionRandomCacheF64;
50
51#[cfg(all(feature = "exactarithmetic", not(feature = "approximatearithmetic")))]
52pub type FractionRandomCache = super::fraction_exact::FractionRandomCacheExact;
53
54pub enum FractionRandomCacheEnum {
55    Exact(Vec<fraction::BigFraction>, BigUint),
56    Approx(Vec<f64>),
57}
58
59impl ChooseRandomly for FractionEnum {
60    type Cache = FractionRandomCacheEnum;
61
62    fn choose_randomly(fractions: &Vec<FractionEnum>) -> Result<usize> {
63        if fractions.is_empty() {
64            return Err(anyhow!("cannot take an element of an empty list"));
65        }
66
67        //normalise the probabilities
68        let mut probabilities: Vec<FractionEnum> = fractions.iter().cloned().collect();
69        let sum = probabilities
70            .iter()
71            .fold(FractionEnum::zero(), |x, y| &x + y);
72        if sum == FractionEnum::CannotCombineExactAndApprox {
73            return Err(anyhow!("cannot combine exact and approximate arithmetic"));
74        }
75        probabilities.retain_mut(|v| {
76            *v /= &sum;
77            true
78        });
79
80        //select a random value
81        let mut rng = rand::thread_rng();
82        let rand_val = if sum.is_exact() {
83            //strategy: the highest denominator determines how much precision we need
84            let temp_zero = BigUint::zero();
85            let max_denom = probabilities
86                .iter()
87                .map(|f| {
88                    if let FractionEnum::Exact(e) = f {
89                        e.denom().unwrap()
90                    } else {
91                        &temp_zero
92                    }
93                })
94                .max()
95                .unwrap();
96            //Generate a random value with the number of bits of the highest denominator. Repeat until this value is <= the max denominator.
97            let mut rand_val = rng.gen_biguint(max_denom.bits());
98            while &rand_val > max_denom {
99                rand_val = rng.gen_biguint(max_denom.bits());
100            }
101            //create the fraction from the random nominator and the max denominator
102            FractionEnum::try_from((rand_val, max_denom.clone())).unwrap()
103        } else {
104            FractionEnum::Approx(rng.gen_range(0.0..=1.0))
105        };
106
107        let mut cum_prob = FractionEnum::zero();
108        for (index, value) in probabilities.iter().enumerate() {
109            cum_prob += value;
110            if rand_val < cum_prob {
111                return Ok(index);
112            }
113        }
114        Ok(probabilities.len() - 1)
115    }
116
117    fn choose_randomly_create_cache<'a>(
118        mut fractions: impl Iterator<Item = &'a Self>,
119    ) -> Result<FractionRandomCache>
120    where
121        Self: Sized,
122        Self: 'a,
123    {
124        if is_exact_globally() {
125            //exact mode
126            if let Some(first) = fractions.next() {
127                let mut cumulative_probabilities = vec![
128                    first
129                        .extract_exact()
130                        .with_context(|| "cannot combine exact and approximate arithmetic")?
131                        .clone(),
132                ];
133                let mut highest_denom = first.extract_exact()?.denom().unwrap();
134
135                while let Some(fraction) = fractions.next() {
136                    highest_denom = highest_denom.max(fraction.extract_exact()?.denom().unwrap());
137
138                    let mut x = fraction
139                        .extract_exact()
140                        .with_context(|| "cannot combine exact and approximate arithmetic")?
141                        .clone();
142                    x += cumulative_probabilities.last().unwrap();
143                    cumulative_probabilities.push(x);
144                }
145                let highest_denom = highest_denom.clone();
146
147                Ok(FractionRandomCacheEnum::Exact(
148                    cumulative_probabilities,
149                    highest_denom,
150                ))
151            } else {
152                Err(anyhow!("cannot take an element of an empty list"))
153            }
154        } else {
155            //approximate mode
156            if let Some(first) = fractions.next() {
157                let mut cumulative_probabilities = vec![
158                    first
159                        .extract_approx()
160                        .with_context(|| "cannot combine exact and approximate arithmetic")?,
161                ];
162
163                while let Some(fraction) = fractions.next() {
164                    cumulative_probabilities.push(
165                        fraction
166                            .extract_approx()
167                            .with_context(|| "cannot combine exact and approximate arithmetic")?
168                            + cumulative_probabilities.last().unwrap(),
169                    );
170                }
171
172                Ok(FractionRandomCacheEnum::Approx(cumulative_probabilities))
173            } else {
174                Err(anyhow!("cannot take an element of an empty list"))
175            }
176        }
177    }
178
179    fn choose_randomly_cached(cache: &FractionRandomCache) -> usize
180    where
181        Self: Sized,
182    {
183        match cache {
184            FractionRandomCacheEnum::Exact(cumulative_probabilities, highest_denom) => {
185                //select a random value
186                let mut rng = rand::thread_rng();
187                let rand_val = {
188                    //strategy: the highest denominator determines how much precision we need
189
190                    //Generate a random value with the number of bits of the highest denominator. Repeat until this value is <= the max denominator.
191                    let mut rand_val = rng.gen_biguint(highest_denom.bits());
192                    while &rand_val > highest_denom {
193                        rand_val = rng.gen_biguint(highest_denom.bits());
194                    }
195                    //create the fraction from the random nominator and the max denominator
196                    GenericFraction::Rational(
197                        Sign::Plus,
198                        Ratio::new(rand_val, highest_denom.clone()),
199                    )
200                };
201
202                match cumulative_probabilities.binary_search(&rand_val) {
203                    Ok(index) | Err(index) => index,
204                }
205            }
206            FractionRandomCacheEnum::Approx(cumulative_probabilities) => {
207                //select a random value
208                let mut rng = rand::thread_rng();
209                let rand_val = rng.gen_range(0.0..=*cumulative_probabilities.last().unwrap());
210
211                match cumulative_probabilities.binary_search_by(|probe| probe.total_cmp(&rand_val))
212                {
213                    Ok(index) | Err(index) => index,
214                }
215            }
216        }
217    }
218}
219
220pub struct FractionRandomCacheExact {
221    cumulative_probabilities: Vec<FractionExact>,
222    highest_denom: BigUint,
223}
224
225impl ChooseRandomly for FractionExact {
226    type Cache = FractionRandomCacheExact;
227
228    fn choose_randomly(fractions: &Vec<FractionExact>) -> Result<usize> {
229        if fractions.is_empty() {
230            return Err(anyhow!("cannot take an element of an empty list"));
231        }
232
233        //normalise the probabilities
234        let mut probabilities: Vec<FractionExact> = fractions.iter().cloned().collect();
235        let sum = probabilities
236            .iter()
237            .fold(FractionExact::zero(), |x, y| &x + y);
238        probabilities.retain_mut(|v| {
239            *v /= &sum;
240            true
241        });
242
243        //select a random value
244        let mut rng = rand::thread_rng();
245        let rand_val = {
246            //strategy: the highest denominator determines how much precision we need
247            let max_denom = probabilities
248                .iter()
249                .map(|f| match f {
250                    FractionExact(e) => e.denom().unwrap(),
251                })
252                .max()
253                .unwrap();
254            //Generate a random value with the number of bits of the highest denominator. Repeat until this value is <= the max denominator.
255            let mut rand_val = rng.gen_biguint(max_denom.bits());
256            while &rand_val > max_denom {
257                rand_val = rng.gen_biguint(max_denom.bits());
258            }
259            //create the fraction from the random nominator and the max denominator
260            FractionExact::try_from((rand_val, max_denom.clone())).unwrap()
261        };
262
263        let mut cum_prob = FractionExact::zero();
264        for (index, value) in probabilities.iter().enumerate() {
265            cum_prob += value;
266            if rand_val < cum_prob {
267                return Ok(index);
268            }
269        }
270        Ok(probabilities.len() - 1)
271    }
272
273    fn choose_randomly_create_cache<'a>(
274        mut fractions: impl Iterator<Item = &'a Self>,
275    ) -> Result<FractionRandomCacheExact>
276    where
277        Self: Sized,
278        Self: 'a,
279    {
280        if let Some(first) = fractions.next() {
281            let mut cumulative_probabilities = vec![first.clone()];
282            let mut highest_denom = first.0.denom().unwrap();
283
284            while let Some(fraction) = fractions.next() {
285                highest_denom = highest_denom.max(fraction.0.denom().unwrap());
286
287                cumulative_probabilities.push(fraction + cumulative_probabilities.last().unwrap());
288            }
289            let highest_denom = highest_denom.clone();
290
291            Ok(FractionRandomCacheExact {
292                cumulative_probabilities,
293                highest_denom,
294            })
295        } else {
296            Err(anyhow!("cannot take an element of an empty list"))
297        }
298    }
299
300    fn choose_randomly_cached(cache: &FractionRandomCacheExact) -> usize
301    where
302        Self: Sized,
303    {
304        //select a random value
305        let mut rng = rand::thread_rng();
306        let rand_val = {
307            //strategy: the highest denominator determines how much precision we need
308
309            //Generate a random value with the number of bits of the highest denominator. Repeat until this value is <= the max denominator.
310            let mut rand_val = rng.gen_biguint(cache.highest_denom.bits());
311            while rand_val > cache.highest_denom {
312                rand_val = rng.gen_biguint(cache.highest_denom.bits());
313            }
314            //create the fraction from the random nominator and the max denominator
315            FractionExact::try_from((rand_val, cache.highest_denom.clone())).unwrap()
316        };
317
318        match cache.cumulative_probabilities.binary_search(&rand_val) {
319            Ok(index) | Err(index) => index,
320        }
321    }
322}
323
324pub struct FractionRandomCacheF64 {
325    cumulative_probabilities: Vec<FractionF64>,
326}
327
328impl ChooseRandomly for FractionF64 {
329    type Cache = FractionRandomCacheF64;
330
331    fn choose_randomly(fractions: &Vec<FractionF64>) -> Result<usize> {
332        if fractions.is_empty() {
333            return Err(anyhow!("cannot take an element of an empty list"));
334        }
335
336        //normalise the probabilities
337        let mut probabilities: Vec<FractionF64> = fractions.iter().cloned().collect();
338        let sum = probabilities
339            .iter()
340            .fold(FractionF64::zero(), |x, y| &x + y);
341        probabilities.retain_mut(|v| {
342            *v /= &sum;
343            true
344        });
345
346        //select a random value
347        let mut rng = rand::thread_rng();
348        let rand_val = FractionF64(rng.gen_range(0.0..=1.0));
349
350        let mut cum_prob = FractionF64::zero();
351        for (index, value) in probabilities.iter().enumerate() {
352            cum_prob += value;
353            if rand_val < cum_prob {
354                return Ok(index);
355            }
356        }
357        Ok(probabilities.len() - 1)
358    }
359
360    fn choose_randomly_create_cache<'a>(
361        mut fractions: impl Iterator<Item = &'a Self>,
362    ) -> Result<FractionRandomCacheF64>
363    where
364        Self: Sized,
365        Self: 'a,
366    {
367        if let Some(first) = fractions.next() {
368            let mut cumulative_probabilities = vec![*first];
369
370            while let Some(fraction) = fractions.next() {
371                cumulative_probabilities.push(fraction + cumulative_probabilities.last().unwrap());
372            }
373
374            Ok(FractionRandomCacheF64 {
375                cumulative_probabilities,
376            })
377        } else {
378            Err(anyhow!("cannot take an element of an empty list"))
379        }
380    }
381
382    fn choose_randomly_cached(cache: &FractionRandomCacheF64) -> usize
383    where
384        Self: Sized,
385    {
386        //select a random value
387        let mut rng = rand::thread_rng();
388        let rand_val = FractionF64::from(
389            rng.gen_range(
390                0.0..=cache
391                    .cumulative_probabilities
392                    .last()
393                    .unwrap()
394                    .extract_approx()
395                    .unwrap(),
396            ),
397        );
398
399        match cache.cumulative_probabilities.binary_search(&rand_val) {
400            Ok(index) | Err(index) => index,
401        }
402    }
403}