1use ferray_core::dimension::broadcast::broadcast_shapes;
4use ferray_core::{Array, FerrayError, IxDyn};
5
6use crate::bitgen::BitGenerator;
7use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
8use crate::generator::{
9 Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
10};
11use crate::shape::IntoShape;
12
13pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
20 standard_normal_ziggurat(bg)
21}
22
23pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
29 standard_normal_ziggurat_f32(bg)
30}
31
32impl<B: BitGenerator> Generator<B> {
33 pub fn standard_normal(
42 &mut self,
43 size: impl IntoShape,
44 ) -> Result<Array<f64, IxDyn>, FerrayError> {
45 let shape = size.into_shape()?;
46 let n = shape_size(&shape);
47 let data = generate_vec(self, n, standard_normal_single);
48 vec_to_array_f64(data, &shape)
49 }
50
51 pub fn normal(
57 &mut self,
58 loc: f64,
59 scale: f64,
60 size: impl IntoShape,
61 ) -> Result<Array<f64, IxDyn>, FerrayError> {
62 if scale <= 0.0 {
63 return Err(FerrayError::invalid_value(format!(
64 "scale must be positive, got {scale}"
65 )));
66 }
67 let shape = size.into_shape()?;
68 let n = shape_size(&shape);
69 let data = generate_vec(self, n, |bg| scale.mul_add(standard_normal_single(bg), loc));
70 vec_to_array_f64(data, &shape)
71 }
72
73 pub fn standard_normal_into(&mut self, out: &mut Array<f64, IxDyn>) -> Result<(), FerrayError> {
83 let slice = out.as_slice_mut().ok_or_else(|| {
84 FerrayError::invalid_value("standard_normal_into requires a contiguous out buffer")
85 })?;
86 for v in slice.iter_mut() {
87 *v = standard_normal_single(&mut self.bg);
88 }
89 Ok(())
90 }
91
92 pub fn normal_array(
109 &mut self,
110 loc: &Array<f64, IxDyn>,
111 scale: &Array<f64, IxDyn>,
112 ) -> Result<Array<f64, IxDyn>, FerrayError> {
113 let target = broadcast_shapes(loc.shape(), scale.shape())?;
114 let loc_v = loc.broadcast_to(&target)?;
115 let scale_v = scale.broadcast_to(&target)?;
116 let total: usize = target.iter().product();
117 let mut out: Vec<f64> = Vec::with_capacity(total);
118 for (&l, &s) in loc_v.iter().zip(scale_v.iter()) {
119 if s <= 0.0 {
120 return Err(FerrayError::invalid_value(format!(
121 "scale must be positive, got {s}"
122 )));
123 }
124 out.push(s.mul_add(standard_normal_single(&mut self.bg), l));
125 }
126 Array::<f64, IxDyn>::from_vec(IxDyn::new(&target), out)
127 }
128
129 pub fn standard_normal_f32(
138 &mut self,
139 size: impl IntoShape,
140 ) -> Result<Array<f32, IxDyn>, FerrayError> {
141 let shape = size.into_shape()?;
142 let n = shape_size(&shape);
143 let data = generate_vec_f32(self, n, standard_normal_single_f32);
144 vec_to_array_f32(data, &shape)
145 }
146
147 pub fn normal_f32(
153 &mut self,
154 loc: f32,
155 scale: f32,
156 size: impl IntoShape,
157 ) -> Result<Array<f32, IxDyn>, FerrayError> {
158 if scale <= 0.0 {
159 return Err(FerrayError::invalid_value(format!(
160 "scale must be positive, got {scale}"
161 )));
162 }
163 let shape = size.into_shape()?;
164 let n = shape_size(&shape);
165 let data = generate_vec_f32(self, n, |bg| {
166 scale.mul_add(standard_normal_single_f32(bg), loc)
167 });
168 vec_to_array_f32(data, &shape)
169 }
170
171 pub fn lognormal_f32(
177 &mut self,
178 mean: f32,
179 sigma: f32,
180 size: impl IntoShape,
181 ) -> Result<Array<f32, IxDyn>, FerrayError> {
182 if sigma <= 0.0 {
183 return Err(FerrayError::invalid_value(format!(
184 "sigma must be positive, got {sigma}"
185 )));
186 }
187 let shape = size.into_shape()?;
188 let n = shape_size(&shape);
189 let data = generate_vec_f32(self, n, |bg| {
190 sigma.mul_add(standard_normal_single_f32(bg), mean).exp()
191 });
192 vec_to_array_f32(data, &shape)
193 }
194
195 pub fn lognormal(
202 &mut self,
203 mean: f64,
204 sigma: f64,
205 size: impl IntoShape,
206 ) -> Result<Array<f64, IxDyn>, FerrayError> {
207 if sigma <= 0.0 {
208 return Err(FerrayError::invalid_value(format!(
209 "sigma must be positive, got {sigma}"
210 )));
211 }
212 let shape = size.into_shape()?;
213 let n = shape_size(&shape);
214 let data = generate_vec(self, n, |bg| {
215 sigma.mul_add(standard_normal_single(bg), mean).exp()
216 });
217 vec_to_array_f64(data, &shape)
218 }
219}
220
#[cfg(test)]
mod tests {
    use ferray_core::{Array, IxDyn};

    use crate::default_rng_seeded;

    /// Population mean and variance of `xs`. The variance divides by
    /// `xs.len()`, not `xs.len() - 1`.
    fn mean_and_variance(xs: &[f64]) -> (f64, f64) {
        let len = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / len;
        let var = xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / len;
        (mean, var)
    }

    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let (mean, var) = mean_and_variance(arr.as_slice().unwrap());
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let (mean, var) = mean_and_variance(arr.as_slice().unwrap());
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn standard_normal_into_matches_allocating_version() {
        let mut a = default_rng_seeded(42);
        let mut b = default_rng_seeded(42);
        let allocated = a.standard_normal([4, 5]).unwrap();
        let mut buf = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[4, 5]), vec![0.0; 20]).unwrap();
        b.standard_normal_into(&mut buf).unwrap();
        assert_eq!(allocated.as_slice().unwrap(), buf.as_slice().unwrap());
    }

    #[test]
    fn normal_array_broadcast_scalar_x_vector() {
        let mut rng = default_rng_seeded(42);
        // A length-1 scale broadcasts against a length-3 loc.
        let loc = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0, 10.0, -5.0]).unwrap();
        let scale = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1]), vec![1.0]).unwrap();
        let out = rng.normal_array(&loc, &scale).unwrap();
        assert_eq!(out.shape(), &[3]);
    }

    #[test]
    fn normal_array_2d_broadcast_means_match_loc() {
        let mut rng = default_rng_seeded(7);
        // loc is a (3, 1) column and scale a (1, 4) row, so the output is
        // (3, 4) and every element of row `r` is centred on loc[r].
        let loc =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3, 1]), vec![0.0, 5.0, -3.0]).unwrap();
        let scale =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1, 4]), vec![1.0, 0.5, 2.0, 0.1]).unwrap();

        let n_trials = 5_000;
        let mut row_sums = [0.0_f64; 3];
        for _ in 0..n_trials {
            let out = rng.normal_array(&loc, &scale).unwrap();
            assert_eq!(out.shape(), &[3, 4]);
            let s = out.as_slice().unwrap();
            for (sum, row) in row_sums.iter_mut().zip(s.chunks_exact(4)) {
                *sum += row.iter().sum::<f64>();
            }
        }
        // Each row contributes 4 samples per trial.
        let denom = (n_trials * 4) as f64;
        let expected = [0.0, 5.0, -3.0];
        for (r, (&sum, &want)) in row_sums.iter().zip(expected.iter()).enumerate() {
            let m = sum / denom;
            assert!(
                (m - want).abs() < 0.05,
                "row {r} mean {m} too far from {want}"
            );
        }
    }

    #[test]
    fn normal_array_bad_scale_errors() {
        let mut rng = default_rng_seeded(0);
        let loc = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0, 0.0]).unwrap();
        let negative = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, -0.5]).unwrap();
        assert!(rng.normal_array(&loc, &negative).is_err());
        let zero = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 0.0]).unwrap();
        assert!(rng.normal_array(&loc, &zero).is_err());
    }

    #[test]
    fn normal_array_shape_mismatch_errors() {
        let mut rng = default_rng_seeded(0);
        let loc = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0; 3]).unwrap();
        let scale = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0; 2]).unwrap();
        assert!(rng.normal_array(&loc, &scale).is_err());
    }

    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let (mean, _) = mean_and_variance(arr.as_slice().unwrap());
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        // Var[X] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2).
        let expected_var = (sigma * sigma).exp_m1() * 2.0f64.mul_add(mu, sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn lognormal_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal(0.0, -0.5, 100).is_err());
    }

    #[test]
    fn standard_normal_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let (_, var) = mean_and_variance(arr.as_slice().unwrap());
        assert!(
            (var - 1.0).abs() < 0.05,
            "standard_normal variance {var} too far from 1.0"
        );
    }

    #[test]
    fn normal_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        // Go through the element iterator rather than `as_slice`.
        let s: Vec<f64> = arr.iter().copied().collect();
        let (mean, var) = mean_and_variance(&s);
        assert!(
            (mean - loc).abs() < 0.05,
            "normal mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "normal variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn standard_normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
    }

    #[test]
    fn lognormal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
        assert_eq!(arr.shape(), &[5, 5]);
        assert!(arr.iter().all(|&v| v > 0.0));
    }

    #[test]
    fn standard_normal_f32_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal_f32(1000).unwrap();
        let b = rng2.standard_normal_f32(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let arr = rng.standard_normal_f32(n).unwrap();
        // Accumulate in f64 so the statistics themselves are not limited by
        // f32 precision.
        let widened: Vec<f64> = arr
            .as_slice()
            .unwrap()
            .iter()
            .map(|&x| f64::from(x))
            .collect();
        let (mean, var) = mean_and_variance(&widened);
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
        assert!(
            (var - 1.0).abs() < 0.05,
            "f32 variance {var} too far from 1"
        );
    }

    #[test]
    fn standard_normal_f32_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_f32([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let loc = 5.0f32;
        let scale = 2.0f32;
        let arr = rng.normal_f32(loc, scale, n).unwrap();
        let widened: Vec<f64> = arr
            .as_slice()
            .unwrap()
            .iter()
            .map(|&x| f64::from(x))
            .collect();
        let (mean, _) = mean_and_variance(&widened);
        assert!(
            (mean - f64::from(loc)).abs() < 0.05,
            "f32 normal mean {mean} too far from {loc}"
        );
    }

    #[test]
    fn normal_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_f32_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
    }

    #[test]
    fn normal_nan_loc_produces_nan_output() {
        let mut rng = default_rng_seeded(42);
        // `loc` is not validated, so NaN must propagate into every sample.
        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_nan(), "expected NaN, got {v}");
        }
    }

    #[test]
    fn normal_inf_scale_produces_inf_output() {
        let mut rng = default_rng_seeded(42);
        // inf * z is +/-inf, or NaN in the unlikely case that z == 0.
        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
        }
    }

    #[test]
    fn normal_nan_scale_never_yields_finite_samples() {
        let mut rng = default_rng_seeded(42);
        // A NaN scale is not a positive number. It must either be rejected
        // up front or poison every sample; silently finite output would hide
        // the bad parameter.
        if let Ok(arr) = rng.normal(0.0, f64::NAN, 5) {
            for &v in arr.as_slice().unwrap() {
                assert!(v.is_nan(), "NaN scale produced non-NaN sample {v}");
            }
        }
    }
}