rand_functors/strategies/
enumerator.rs1use alloc::vec::Vec;
2
3use rand::distr::uniform::SampleUniform;
4use rand::distr::StandardUniform;
5use rand::prelude::*;
6
7use crate::{
8 FlattenableRandomStrategy, Inner, RandomStrategy, RandomVariable, RandomVariableRange,
9};
10
/// A random strategy that exhaustively enumerates outcomes instead of sampling them.
///
/// Distributions are represented as a `Vec` holding one entry per outcome. The random number
/// generator handed to the sampling methods is ignored: every value of the sampled random
/// variable is paired with every existing outcome.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Enumerator;
25
26impl RandomStrategy for Enumerator {
27 type Functor<I: Inner> = Vec<I>;
28
29 #[inline]
30 fn fmap<A: Inner, B: Inner, F: Fn(A) -> B>(f: Self::Functor<A>, func: F) -> Self::Functor<B> {
31 f.into_iter().map(func).collect()
32 }
33
34 #[inline]
35 fn fmap_rand<A: Inner, B: Inner, R: RandomVariable, F: Fn(A, R) -> B>(
36 f: Self::Functor<A>,
37 _: &mut impl Rng,
38 func: F,
39 ) -> Self::Functor<B>
40 where
41 StandardUniform: Distribution<R>,
42 {
43 f.into_iter()
44 .flat_map(|a| R::sample_space().map(move |r| (a.clone(), r)))
45 .map(|(a, r)| func(a, r))
46 .collect()
47 }
48
49 #[inline]
50 fn fmap_rand_range<A: Inner, B: Inner, R: RandomVariable + SampleUniform, F: Fn(A, R) -> B>(
51 f: Self::Functor<A>,
52 range: impl RandomVariableRange<R>,
53 _: &mut impl Rng,
54 func: F,
55 ) -> Self::Functor<B>
56 where
57 StandardUniform: Distribution<R>,
58 {
59 f.into_iter()
60 .flat_map(|a| range.sample_space().map(move |r| (a.clone(), r)))
61 .map(|(a, r)| func(a, r))
62 .collect()
63 }
64}
65
66impl FlattenableRandomStrategy for Enumerator {
67 #[inline]
68 fn fmap_flat<A: Inner, B: Inner, F: FnMut(A) -> Self::Functor<B>>(
69 f: Self::Functor<A>,
70 func: F,
71 ) -> Self::Functor<B> {
72 let children = f.into_iter().map(func).collect::<Self::Functor<_>>();
73 let Some(length_lcm) = children.iter().fold(None, |lcm, functor| {
74 if let Some(lcm) = lcm {
75 Some(num::integer::lcm(lcm, functor.len()))
76 } else {
77 Some(functor.len())
78 }
79 }) else {
80 return Self::Functor::new();
81 };
82 children
83 .into_iter()
84 .flat_map(|functor| {
85 let scaling = length_lcm / functor.len();
86 core::iter::repeat_n(functor, scaling)
87 })
88 .flatten()
89 .collect()
90 }
91}