computation_types/
rand.rs

1pub use self::{rand::*, seeded_rand::*};
2
3#[allow(clippy::module_inception)]
4mod rand {
5    use core::fmt;
6    use std::marker::PhantomData;
7
8    use crate::{impl_core_ops, Computation, ComputationFn, NamedArgs, Names};
9
10    #[derive(Clone, Copy, Debug)]
11    pub struct Rand<Dist, T>
12    where
13        Self: Computation,
14    {
15        pub distribution: Dist,
16        ty: PhantomData<T>,
17    }
18
19    impl<Dist, T> Rand<Dist, T>
20    where
21        Self: Computation,
22    {
23        pub fn new(distribution: Dist) -> Self {
24            Self {
25                distribution,
26                ty: PhantomData,
27            }
28        }
29    }
30
31    impl<Dist, T> Computation for Rand<Dist, T>
32    where
33        Dist: Computation,
34    {
35        type Dim = Dist::Dim;
36        type Item = T;
37    }
38
39    impl<Dist, T> ComputationFn for Rand<Dist, T>
40    where
41        Self: Computation,
42        Dist: ComputationFn,
43        Rand<Dist::Filled, T>: Computation,
44    {
45        type Filled = Rand<Dist::Filled, T>;
46
47        fn fill(self, named_args: NamedArgs) -> Self::Filled {
48            Rand {
49                distribution: self.distribution.fill(named_args),
50                ty: self.ty,
51            }
52        }
53
54        fn arg_names(&self) -> Names {
55            self.distribution.arg_names()
56        }
57    }
58
59    impl_core_ops!(Rand<Dist, T>);
60
61    impl<Dist, T> fmt::Display for Rand<Dist, T>
62    where
63        Self: Computation,
64        Dist: fmt::Display,
65    {
66        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67            write!(f, "rand({})", self.distribution)
68        }
69    }
70}
71
72mod seeded_rand {
73    use core::fmt;
74    use std::marker::PhantomData;
75
76    use crate::{impl_core_ops, peano::Zero, Computation, ComputationFn, NamedArgs, Names};
77
78    #[derive(Clone, Copy, Debug)]
79    pub struct SeededRand<Dist, T, R>
80    where
81        Self: Computation,
82    {
83        pub distribution: Dist,
84        ty: PhantomData<T>,
85        pub rng: R,
86    }
87
88    impl<Dist, T, R> SeededRand<Dist, T, R>
89    where
90        Self: Computation,
91    {
92        pub fn new(distribution: Dist, rng: R) -> Self {
93            Self {
94                distribution,
95                ty: PhantomData,
96                rng,
97            }
98        }
99    }
100
101    impl<Dist, T, R> Computation for SeededRand<Dist, T, R>
102    where
103        Dist: Computation,
104        R: Computation,
105    {
106        type Dim = (Dist::Dim, Zero);
107        type Item = (T, R::Item);
108    }
109
110    impl<Dist, T, R> ComputationFn for SeededRand<Dist, T, R>
111    where
112        Self: Computation,
113        Dist: ComputationFn,
114        R: ComputationFn,
115        SeededRand<Dist::Filled, T, R::Filled>: Computation,
116    {
117        type Filled = SeededRand<Dist::Filled, T, R::Filled>;
118
119        fn fill(self, named_args: NamedArgs) -> Self::Filled {
120            let (args_0, args_1) = named_args
121                .partition(&self.rng.arg_names(), &self.distribution.arg_names())
122                .unwrap_or_else(|e| panic!("{}", e,));
123            SeededRand {
124                distribution: self.distribution.fill(args_1),
125                ty: self.ty,
126                rng: self.rng.fill(args_0),
127            }
128        }
129
130        fn arg_names(&self) -> Names {
131            self.rng.arg_names().union(self.distribution.arg_names())
132        }
133    }
134
135    impl_core_ops!(SeededRand<Dist, T, R>);
136
137    impl<Dist, T, R> fmt::Display for SeededRand<Dist, T, R>
138    where
139        Self: Computation,
140        Dist: fmt::Display,
141        R: fmt::Display,
142    {
143        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144            write!(f, "seeded_rand({}, {})", self.distribution, self.rng)
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{distributions::Standard, rngs::StdRng, SeedableRng};
    use test_strategy::proptest;

    use crate::{
        rand::{rand::Rand, seeded_rand::SeededRand},
        run::Matrix,
        val, val1, val2,
    };

    // Each test renders the expected string before the value under test is
    // built, so the inputs can be moved into the constructor without cloning.

    #[test]
    fn rands_should_display() {
        let dist = val!(Standard);
        let expected = format!("rand({})", dist);
        assert_eq!(Rand::<_, i32>::new(dist).to_string(), expected);
    }

    #[proptest]
    fn rands_should_display_1d(#[strategy(1_usize..10)] x: usize) {
        let dist = val1!(vec![Standard; x]);
        let expected = format!("rand({})", dist);
        prop_assert_eq!(Rand::<_, i32>::new(dist).to_string(), expected);
    }

    #[proptest]
    fn rands_should_display_2d(
        #[strategy(1_usize..10)] x: usize,
        #[strategy(1_usize..10)] y: usize,
    ) {
        let dist = val2!(Matrix::from_vec((x, y), vec![Standard; x * y]).unwrap());
        let expected = format!("rand({})", dist);
        prop_assert_eq!(Rand::<_, i32>::new(dist).to_string(), expected);
    }

    #[proptest]
    fn seededrands_should_display(seed: u64) {
        let dist = val!(Standard);
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        prop_assert_eq!(SeededRand::<_, i32, _>::new(dist, rng).to_string(), expected);
    }

    #[proptest]
    fn seededrands_should_display_1d(seed: u64, #[strategy(1_usize..10)] x: usize) {
        let dist = val1!(vec![Standard; x]);
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        prop_assert_eq!(SeededRand::<_, i32, _>::new(dist, rng).to_string(), expected);
    }

    #[proptest]
    fn seededrands_should_display_2d(
        seed: u64,
        #[strategy(1_usize..10)] x: usize,
        #[strategy(1_usize..10)] y: usize,
    ) {
        let dist = val2!(Matrix::from_vec((x, y), vec![Standard; x * y]).unwrap());
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        prop_assert_eq!(SeededRand::<_, i32, _>::new(dist, rng).to_string(), expected);
    }
}