ferray_random/distributions/
normal.rs1use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
7use crate::generator::{
8 Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
9};
10use crate::shape::IntoShape;
11
12pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
19 standard_normal_ziggurat(bg)
20}
21
22pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
28 standard_normal_ziggurat_f32(bg)
29}
30
31impl<B: BitGenerator> Generator<B> {
32 pub fn standard_normal(
41 &mut self,
42 size: impl IntoShape,
43 ) -> Result<Array<f64, IxDyn>, FerrayError> {
44 let shape = size.into_shape()?;
45 let n = shape_size(&shape);
46 let data = generate_vec(self, n, standard_normal_single);
47 vec_to_array_f64(data, &shape)
48 }
49
50 pub fn normal(
56 &mut self,
57 loc: f64,
58 scale: f64,
59 size: impl IntoShape,
60 ) -> Result<Array<f64, IxDyn>, FerrayError> {
61 if scale <= 0.0 {
62 return Err(FerrayError::invalid_value(format!(
63 "scale must be positive, got {scale}"
64 )));
65 }
66 let shape = size.into_shape()?;
67 let n = shape_size(&shape);
68 let data = generate_vec(self, n, |bg| scale.mul_add(standard_normal_single(bg), loc));
69 vec_to_array_f64(data, &shape)
70 }
71
72 pub fn standard_normal_f32(
81 &mut self,
82 size: impl IntoShape,
83 ) -> Result<Array<f32, IxDyn>, FerrayError> {
84 let shape = size.into_shape()?;
85 let n = shape_size(&shape);
86 let data = generate_vec_f32(self, n, standard_normal_single_f32);
87 vec_to_array_f32(data, &shape)
88 }
89
90 pub fn normal_f32(
96 &mut self,
97 loc: f32,
98 scale: f32,
99 size: impl IntoShape,
100 ) -> Result<Array<f32, IxDyn>, FerrayError> {
101 if scale <= 0.0 {
102 return Err(FerrayError::invalid_value(format!(
103 "scale must be positive, got {scale}"
104 )));
105 }
106 let shape = size.into_shape()?;
107 let n = shape_size(&shape);
108 let data = generate_vec_f32(self, n, |bg| {
109 scale.mul_add(standard_normal_single_f32(bg), loc)
110 });
111 vec_to_array_f32(data, &shape)
112 }
113
114 pub fn lognormal_f32(
120 &mut self,
121 mean: f32,
122 sigma: f32,
123 size: impl IntoShape,
124 ) -> Result<Array<f32, IxDyn>, FerrayError> {
125 if sigma <= 0.0 {
126 return Err(FerrayError::invalid_value(format!(
127 "sigma must be positive, got {sigma}"
128 )));
129 }
130 let shape = size.into_shape()?;
131 let n = shape_size(&shape);
132 let data = generate_vec_f32(self, n, |bg| {
133 sigma.mul_add(standard_normal_single_f32(bg), mean).exp()
134 });
135 vec_to_array_f32(data, &shape)
136 }
137
138 pub fn lognormal(
145 &mut self,
146 mean: f64,
147 sigma: f64,
148 size: impl IntoShape,
149 ) -> Result<Array<f64, IxDyn>, FerrayError> {
150 if sigma <= 0.0 {
151 return Err(FerrayError::invalid_value(format!(
152 "sigma must be positive, got {sigma}"
153 )));
154 }
155 let shape = size.into_shape()?;
156 let n = shape_size(&shape);
157 let data = generate_vec(self, n, |bg| {
158 sigma.mul_add(standard_normal_single(bg), mean).exp()
159 });
160 vec_to_array_f64(data, &shape)
161 }
162}
163
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Population mean and variance (divisor `n`) of `xs`, accumulated in `f64`.
    ///
    /// Generic over `f32`/`f64` samples so the single- and double-precision
    /// tests share exactly the same arithmetic.
    fn mean_var<T: Copy>(xs: &[T]) -> (f64, f64)
    where
        f64: From<T>,
    {
        let n = xs.len() as f64;
        let mean = xs.iter().map(|&x| f64::from(x)).sum::<f64>() / n;
        let var = xs
            .iter()
            .map(|&x| {
                let d = f64::from(x) - mean;
                d * d
            })
            .sum::<f64>()
            / n;
        (mean, var)
    }

    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let (mean, var) = mean_var(arr.as_slice().unwrap());
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let (mean, var) = mean_var(arr.as_slice().unwrap());
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let (mean, _) = mean_var(arr.as_slice().unwrap());
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        // Var[LogNormal(mu, sigma)] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2).
        let expected_var = (sigma * sigma).exp_m1() * 2.0f64.mul_add(mu, sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn standard_normal_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let (_, var) = mean_var(arr.as_slice().unwrap());
        assert!(
            (var - 1.0).abs() < 0.05,
            "standard_normal variance {var} too far from 1.0"
        );
    }

    #[test]
    fn normal_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        // Collect through `iter()` (not `as_slice`) so that path is exercised too.
        let s: Vec<f64> = arr.iter().copied().collect();
        let (mean, var) = mean_var(&s);
        assert!(
            (mean - loc).abs() < 0.05,
            "normal mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "normal variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn standard_normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
    }

    #[test]
    fn lognormal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
        assert_eq!(arr.shape(), &[5, 5]);
        for &v in arr.iter() {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn standard_normal_f32_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal_f32(1000).unwrap();
        let b = rng2.standard_normal_f32(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let arr = rng.standard_normal_f32(n).unwrap();
        // Moments are accumulated in f64 to avoid f32 summation error.
        let (mean, var) = mean_var(arr.as_slice().unwrap());
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
        assert!(
            (var - 1.0).abs() < 0.05,
            "f32 variance {var} too far from 1"
        );
    }

    #[test]
    fn standard_normal_f32_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_f32([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let loc = 5.0f32;
        let scale = 2.0f32;
        let arr = rng.normal_f32(loc, scale, n).unwrap();
        let (mean, _) = mean_var(arr.as_slice().unwrap());
        assert!(
            (mean - f64::from(loc)).abs() < 0.05,
            "f32 normal mean {mean} too far from {loc}"
        );
    }

    #[test]
    fn normal_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_f32_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
    }

    #[test]
    fn normal_nan_loc_produces_nan_output() {
        let mut rng = default_rng_seeded(42);
        // A NaN location is not validated; it propagates into every sample.
        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_nan(), "expected NaN, got {v}");
        }
    }

    #[test]
    fn normal_inf_scale_produces_inf_output() {
        let mut rng = default_rng_seeded(42);
        // inf * z is +-inf, or NaN in the (measure-zero) case z == 0.
        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
        }
    }

    #[test]
    fn normal_nan_scale_rejected() {
        let mut rng = default_rng_seeded(42);
        // A NaN scale must never yield finite-looking samples: either the call
        // is rejected up front, or every produced sample is NaN.
        if let Ok(arr) = rng.normal(0.0, f64::NAN, 5) {
            for &v in arr.as_slice().unwrap() {
                assert!(v.is_nan(), "NaN scale produced non-NaN sample {v}");
            }
        }
    }
}