1pub use self::{rand::*, seeded_rand::*};
2
3#[allow(clippy::module_inception)]
4mod rand {
5 use core::fmt;
6 use std::marker::PhantomData;
7
8 use crate::{impl_core_ops, Computation, ComputationFn, NamedArgs, Names};
9
10 #[derive(Clone, Copy, Debug)]
11 pub struct Rand<Dist, T>
12 where
13 Self: Computation,
14 {
15 pub distribution: Dist,
16 ty: PhantomData<T>,
17 }
18
19 impl<Dist, T> Rand<Dist, T>
20 where
21 Self: Computation,
22 {
23 pub fn new(distribution: Dist) -> Self {
24 Self {
25 distribution,
26 ty: PhantomData,
27 }
28 }
29 }
30
31 impl<Dist, T> Computation for Rand<Dist, T>
32 where
33 Dist: Computation,
34 {
35 type Dim = Dist::Dim;
36 type Item = T;
37 }
38
39 impl<Dist, T> ComputationFn for Rand<Dist, T>
40 where
41 Self: Computation,
42 Dist: ComputationFn,
43 Rand<Dist::Filled, T>: Computation,
44 {
45 type Filled = Rand<Dist::Filled, T>;
46
47 fn fill(self, named_args: NamedArgs) -> Self::Filled {
48 Rand {
49 distribution: self.distribution.fill(named_args),
50 ty: self.ty,
51 }
52 }
53
54 fn arg_names(&self) -> Names {
55 self.distribution.arg_names()
56 }
57 }
58
59 impl_core_ops!(Rand<Dist, T>);
60
61 impl<Dist, T> fmt::Display for Rand<Dist, T>
62 where
63 Self: Computation,
64 Dist: fmt::Display,
65 {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 write!(f, "rand({})", self.distribution)
68 }
69 }
70}
71
72mod seeded_rand {
73 use core::fmt;
74 use std::marker::PhantomData;
75
76 use crate::{impl_core_ops, peano::Zero, Computation, ComputationFn, NamedArgs, Names};
77
78 #[derive(Clone, Copy, Debug)]
79 pub struct SeededRand<Dist, T, R>
80 where
81 Self: Computation,
82 {
83 pub distribution: Dist,
84 ty: PhantomData<T>,
85 pub rng: R,
86 }
87
88 impl<Dist, T, R> SeededRand<Dist, T, R>
89 where
90 Self: Computation,
91 {
92 pub fn new(distribution: Dist, rng: R) -> Self {
93 Self {
94 distribution,
95 ty: PhantomData,
96 rng,
97 }
98 }
99 }
100
101 impl<Dist, T, R> Computation for SeededRand<Dist, T, R>
102 where
103 Dist: Computation,
104 R: Computation,
105 {
106 type Dim = (Dist::Dim, Zero);
107 type Item = (T, R::Item);
108 }
109
110 impl<Dist, T, R> ComputationFn for SeededRand<Dist, T, R>
111 where
112 Self: Computation,
113 Dist: ComputationFn,
114 R: ComputationFn,
115 SeededRand<Dist::Filled, T, R::Filled>: Computation,
116 {
117 type Filled = SeededRand<Dist::Filled, T, R::Filled>;
118
119 fn fill(self, named_args: NamedArgs) -> Self::Filled {
120 let (args_0, args_1) = named_args
121 .partition(&self.rng.arg_names(), &self.distribution.arg_names())
122 .unwrap_or_else(|e| panic!("{}", e,));
123 SeededRand {
124 distribution: self.distribution.fill(args_1),
125 ty: self.ty,
126 rng: self.rng.fill(args_0),
127 }
128 }
129
130 fn arg_names(&self) -> Names {
131 self.rng.arg_names().union(self.distribution.arg_names())
132 }
133 }
134
135 impl_core_ops!(SeededRand<Dist, T, R>);
136
137 impl<Dist, T, R> fmt::Display for SeededRand<Dist, T, R>
138 where
139 Self: Computation,
140 Dist: fmt::Display,
141 R: fmt::Display,
142 {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 write!(f, "seeded_rand({}, {})", self.distribution, self.rng)
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{distributions::Standard, rngs::StdRng, SeedableRng};
    use test_strategy::proptest;

    use crate::{
        rand::{rand::Rand, seeded_rand::SeededRand},
        run::Matrix,
        val, val1, val2,
    };

    /// Builds a vector holding `len` copies of the `Standard` distribution.
    fn standards(len: usize) -> Vec<Standard> {
        vec![Standard; len]
    }

    // Each test renders the expected string from the inputs first, then moves
    // the inputs into the node under test and compares the node's `Display` output.

    #[test]
    fn rands_should_display() {
        let dist = val!(Standard);
        let expected = format!("rand({})", dist);
        let actual = Rand::<_, i32>::new(dist).to_string();
        assert_eq!(actual, expected);
    }

    #[proptest]
    fn rands_should_display_1d(#[strategy(1_usize..10)] x: usize) {
        let dist = val1!(standards(x));
        let expected = format!("rand({})", dist);
        let actual = Rand::<_, i32>::new(dist).to_string();
        prop_assert_eq!(actual, expected);
    }

    #[proptest]
    fn rands_should_display_2d(
        #[strategy(1_usize..10)] x: usize,
        #[strategy(1_usize..10)] y: usize,
    ) {
        let dist = val2!(Matrix::from_vec((x, y), standards(x * y)).unwrap());
        let expected = format!("rand({})", dist);
        let actual = Rand::<_, i32>::new(dist).to_string();
        prop_assert_eq!(actual, expected);
    }

    #[proptest]
    fn seededrands_should_display(seed: u64) {
        let dist = val!(Standard);
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        let actual = SeededRand::<_, i32, _>::new(dist, rng).to_string();
        prop_assert_eq!(actual, expected);
    }

    #[proptest]
    fn seededrands_should_display_1d(seed: u64, #[strategy(1_usize..10)] x: usize) {
        let dist = val1!(standards(x));
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        let actual = SeededRand::<_, i32, _>::new(dist, rng).to_string();
        prop_assert_eq!(actual, expected);
    }

    #[proptest]
    fn seededrands_should_display_2d(
        seed: u64,
        #[strategy(1_usize..10)] x: usize,
        #[strategy(1_usize..10)] y: usize,
    ) {
        let dist = val2!(Matrix::from_vec((x, y), standards(x * y)).unwrap());
        let rng = val!(StdRng::seed_from_u64(seed));
        let expected = format!("seeded_rand({}, {})", dist, rng);
        let actual = SeededRand::<_, i32, _>::new(dist, rng).to_string();
        prop_assert_eq!(actual, expected);
    }
}