use ferray_core::{Array, FerrayError, IxDyn};
use crate::bitgen::BitGenerator;
use crate::distributions::normal::standard_normal_single;
use crate::generator::{Generator, generate_vec, shape_size, vec_to_array_f64};
use crate::shape::IntoShape;
/// Draws one `Gamma(alpha, 1)` variate.
///
/// Shapes that are not `< 1` (including NaN) are handed directly to the
/// Marsaglia–Tsang sampler. Shapes in `(0, 1)` use the boosting identity
/// `Gamma(alpha) = Gamma(alpha + 1) * U^(1 / alpha)`. The uniform `U` is drawn
/// first and redrawn while it is at or below `f64::EPSILON`. Non-positive shapes
/// yield `0.0` without consuming randomness; the public methods validate first.
pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
    let needs_boost = alpha < 1.0;
    if !needs_boost {
        return standard_gamma_ge1(bg, alpha);
    }
    if alpha <= 0.0 {
        return 0.0;
    }
    let inv_alpha = 1.0 / alpha;
    // Keep U away from (numerical) zero before raising it to 1/alpha.
    let u = loop {
        let candidate = bg.next_f64();
        if candidate > f64::EPSILON {
            break candidate;
        }
    };
    standard_gamma_ge1(bg, alpha + 1.0) * u.powf(inv_alpha)
}
/// Marsaglia–Tsang sampler for `Gamma(alpha, 1)`, intended for `alpha >= 1`.
///
/// Each attempt draws a standard normal `z` and then a uniform `u`. The candidate
/// `d * (1 + c * z)^3` is accepted by the cheap polynomial squeeze when possible.
/// Only when the squeeze fails does it fall back to the exact logarithmic test.
fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
    let d = alpha - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let z = standard_normal_single(bg);
        let t = 1.0 + c * z;
        // A non-positive base cannot be cubed into a valid candidate; try again.
        // (Written as `<= 0.0` so a NaN base is *not* rejected here.)
        if t <= 0.0 {
            continue;
        }
        let v = t * t * t;
        let u = bg.next_f64();
        let z_sq = z * z;
        let squeeze = (0.0331 * z_sq).mul_add(-z_sq, 1.0);
        if u < squeeze || u.ln() < (0.5 * z).mul_add(z, d * (1.0 - v + v.ln())) {
            return d * v;
        }
    }
}
impl<B: BitGenerator> Generator<B> {
    /// Draws samples from the standard gamma distribution `Gamma(alpha, 1)`.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `alpha > 0` (NaN is rejected), and
    /// propagates any error from converting `size` or building the output array.
    pub fn standard_gamma(
        &mut self,
        alpha: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("alpha", alpha)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| standard_gamma_single(bg, alpha));
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from `Gamma(alpha, scale)` (shape/scale parameterisation),
    /// computed as `scale` times a standard gamma variate.
    ///
    /// # Errors
    /// Returns an invalid-value error unless both `alpha > 0` and `scale > 0`
    /// (NaN is rejected), and propagates shape / array-construction errors.
    pub fn gamma(
        &mut self,
        alpha: f64,
        scale: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("alpha", alpha)?;
        require_positive("scale", scale)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| scale * standard_gamma_single(bg, alpha));
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from the beta distribution `Beta(a, b)`.
    ///
    /// When both `a <= 1` and `b <= 1`, Jöhnk's rejection algorithm is used.
    /// For tiny shapes the gamma-ratio construction has both gamma draws underflow
    /// to zero together, leaving `X / (X + Y)` undefined. Otherwise the sample is
    /// `X / (X + Y)` with `X ~ Gamma(a)` and `Y ~ Gamma(b)`.
    ///
    /// # Errors
    /// Returns an invalid-value error unless both `a > 0` and `b > 0` (NaN is
    /// rejected), and propagates shape / array-construction errors.
    pub fn beta(
        &mut self,
        a: f64,
        b: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("a", a)?;
        require_positive("b", b)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| beta_single(bg, a, b));
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from the chi-square distribution with `df` degrees of
    /// freedom, computed as `2 * Gamma(df / 2)`.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `df > 0` (NaN is rejected), and
    /// propagates shape / array-construction errors.
    pub fn chisquare(
        &mut self,
        df: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("df", df)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from the F (Fisher–Snedecor) distribution.
    ///
    /// Each sample is `(X1 / dfnum) / (X2 / dfden)` with `X1 ~ Gamma(dfnum / 2)`
    /// and `X2 ~ Gamma(dfden / 2)`. The factor 2 that turns these into chi-square
    /// draws cancels in the ratio. Yields `+inf` when `X2` is exactly zero.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `dfnum > 0` and `dfden > 0` (NaN is
    /// rejected), and propagates shape / array-construction errors.
    pub fn f(
        &mut self,
        dfnum: f64,
        dfden: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("dfnum", dfnum)?;
        require_positive("dfden", dfden)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| {
            let x1 = standard_gamma_single(bg, dfnum / 2.0);
            let x2 = standard_gamma_single(bg, dfden / 2.0);
            if x2 == 0.0 {
                f64::INFINITY
            } else {
                (x1 / dfnum) / (x2 / dfden)
            }
        });
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from Student's t distribution with `df` degrees of freedom,
    /// computed as `Z / sqrt(V / df)` with `Z ~ N(0, 1)` and `V ~ chi2(df)`.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `df > 0` (NaN is rejected), and
    /// propagates shape / array-construction errors.
    pub fn student_t(
        &mut self,
        df: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("df", df)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| {
            let z = standard_normal_single(bg);
            let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
            z / (chi2 / df).sqrt()
        });
        vec_to_array_f64(data, &shape_vec)
    }

    /// Alias for [`Generator::student_t`]. It produces identical output for the
    /// same generator state.
    ///
    /// # Errors
    /// Same as [`Generator::student_t`].
    pub fn standard_t(
        &mut self,
        df: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        self.student_t(df, size)
    }

    /// Draws samples from the noncentral chi-square distribution
    /// `chi'^2(df, nonc)`.
    ///
    /// With `nonc == 0` this consumes randomness exactly like [`Generator::chisquare`].
    /// For `df > 1` it uses the exact decomposition
    /// `chi2(df - 1) + (Z + sqrt(nonc))^2`, which costs O(1) for any `nonc`.
    /// For `df <= 1` it uses the Poisson mixture `chi2(df + 2N)` with
    /// `N ~ Poisson(nonc / 2)`, whose cost per sample grows linearly with `nonc`.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `df > 0` and `nonc` is finite and
    /// `>= 0` (NaN and infinity are rejected; either would stall the Poisson
    /// draw). Propagates shape / array-construction errors.
    pub fn noncentral_chisquare(
        &mut self,
        df: f64,
        nonc: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("df", df)?;
        require_finite_non_negative("nonc", nonc)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| noncentral_chisquare_single(bg, df, nonc));
        vec_to_array_f64(data, &shape_vec)
    }

    /// Draws samples from the noncentral F distribution.
    ///
    /// Each sample is `(chi'^2(dfnum, nonc) / dfnum) / (chi2(dfden) / dfden)`,
    /// yielding `+inf` when the denominator draw is exactly zero.
    ///
    /// # Errors
    /// Returns an invalid-value error unless `dfnum > 0`, `dfden > 0`, and `nonc`
    /// is finite and `>= 0` (NaN is rejected). Propagates shape /
    /// array-construction errors.
    pub fn noncentral_f(
        &mut self,
        dfnum: f64,
        dfden: f64,
        nonc: f64,
        size: impl IntoShape,
    ) -> Result<Array<f64, IxDyn>, FerrayError> {
        require_positive("dfnum", dfnum)?;
        require_positive("dfden", dfden)?;
        require_finite_non_negative("nonc", nonc)?;
        let shape_vec = size.into_shape()?;
        let n = shape_size(&shape_vec);
        let data = generate_vec(self, n, |bg| {
            let chi2_num = noncentral_chisquare_single(bg, dfnum, nonc);
            let chi2_den = 2.0 * standard_gamma_single(bg, dfden / 2.0);
            if chi2_den == 0.0 {
                f64::INFINITY
            } else {
                (chi2_num / dfnum) / (chi2_den / dfden)
            }
        });
        vec_to_array_f64(data, &shape_vec)
    }
}

/// Rejects any `value` that is not strictly positive.
///
/// The test is written as `value > 0.0` so that NaN fails it as well.
fn require_positive(name: &str, value: f64) -> Result<(), FerrayError> {
    if value > 0.0 {
        Ok(())
    } else {
        Err(FerrayError::invalid_value(format!(
            "{name} must be positive, got {value}"
        )))
    }
}

/// Rejects any `value` that is negative, NaN, or infinite.
fn require_finite_non_negative(name: &str, value: f64) -> Result<(), FerrayError> {
    if value.is_finite() && value >= 0.0 {
        Ok(())
    } else if value.is_infinite() && value > 0.0 {
        Err(FerrayError::invalid_value(format!(
            "{name} must be finite, got {value}"
        )))
    } else {
        Err(FerrayError::invalid_value(format!(
            "{name} must be non-negative, got {value}"
        )))
    }
}

/// Draws `N ~ Poisson(lam)` by counting unit-rate exponential inter-arrival
/// times (`-ln U`) that fit inside `[0, lam]`.
///
/// This is Knuth's multiplicative method carried out in log space. It stays
/// exact when `exp(-lam)` would underflow (`lam` above roughly 745), at a cost
/// of O(`lam`) uniform draws. `lam` must be finite and non-negative; otherwise
/// the loop never terminates.
fn poisson_count<G: BitGenerator>(bg: &mut G, lam: f64) -> u64 {
    let mut count: u64 = 0;
    let mut elapsed = 0.0;
    loop {
        // A uniform of exactly 0.0 (if the generator can produce one) gives
        // `+inf`, which ends the loop instead of producing NaN.
        elapsed -= bg.next_f64().ln();
        if elapsed >= lam {
            return count;
        }
        count += 1;
    }
}

/// Draws one `chi'^2(df, nonc)` variate; the caller guarantees `df > 0` and a
/// finite `nonc >= 0`.
fn noncentral_chisquare_single<G: BitGenerator>(bg: &mut G, df: f64, nonc: f64) -> f64 {
    if nonc == 0.0 {
        // Central case: consume randomness exactly like `chisquare`.
        return 2.0 * standard_gamma_single(bg, df / 2.0);
    }
    if df > 1.0 {
        // chi'^2(df, nonc) = chi2(df - 1) + chi'^2(1, nonc), and the latter is
        // (Z + sqrt(nonc))^2. This holds for real df > 1 and is O(1) in nonc.
        let central = 2.0 * standard_gamma_single(bg, (df - 1.0) / 2.0);
        let shifted = standard_normal_single(bg) + nonc.sqrt();
        central + shifted * shifted
    } else {
        // Poisson mixture: chi'^2(df, nonc) = chi2(df + 2N), N ~ Poisson(nonc / 2).
        let extra = poisson_count(bg, nonc / 2.0);
        2.0 * standard_gamma_single(bg, (df + 2.0 * extra as f64) / 2.0)
    }
}

/// Draws one `Beta(a, b)` variate; the caller guarantees `a > 0` and `b > 0`.
fn beta_single<G: BitGenerator>(bg: &mut G, a: f64, b: f64) -> f64 {
    if a <= 1.0 && b <= 1.0 {
        // Jöhnk's algorithm: with X = U^(1/a), Y = V^(1/b), conditional on
        // X + Y <= 1 the ratio X / (X + Y) is Beta(a, b).
        loop {
            let u = bg.next_f64();
            let v = bg.next_f64();
            let x = u.powf(1.0 / a);
            let y = v.powf(1.0 / b);
            let sum = x + y;
            // Also reject U = V = 0, for which the ratio is undefined.
            if sum <= 1.0 && u + v > 0.0 {
                if sum > 0.0 {
                    return x / sum;
                }
                // Both powers underflowed: evaluate the ratio in log space.
                let log_x = u.ln() / a;
                let log_y = v.ln() / b;
                let log_max = log_x.max(log_y);
                let (log_x, log_y) = (log_x - log_max, log_y - log_max);
                return (log_x - (log_x.exp() + log_y.exp()).ln()).exp();
            }
        }
    }
    // At least one shape exceeds 1, and that gamma draw is strictly positive,
    // so the denominator cannot be zero.
    let x = standard_gamma_single(bg, a);
    let y = standard_gamma_single(bg, b);
    x / (x + y)
}
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Sample mean, summing in slice order.
    fn mean_of(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Population (divide-by-n) variance of `xs` around a precomputed `mean`.
    fn variance_of(xs: &[f64], mean: f64) -> f64 {
        xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64
    }

    /// Asserts every sample is strictly positive.
    fn assert_all_positive(xs: &[f64]) {
        for &v in xs {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn gamma_positive() {
        let arr = default_rng_seeded(42).gamma(2.0, 1.0, 10_000).unwrap();
        assert_all_positive(arr.as_slice().unwrap());
    }

    #[test]
    fn gamma_mean_variance() {
        let n = 100_000;
        let (shape, scale) = (3.0, 2.0);
        let arr = default_rng_seeded(42).gamma(shape, scale, n).unwrap();
        let xs = arr.as_slice().unwrap();
        let mean = mean_of(xs);
        let var = variance_of(xs, mean);
        let expected_mean = shape * scale;
        let expected_var = expected_mean * scale;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() / expected_var < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn gamma_small_shape() {
        let arr = default_rng_seeded(42).gamma(0.5, 1.0, 10_000).unwrap();
        assert_all_positive(arr.as_slice().unwrap());
    }

    #[test]
    fn beta_in_range() {
        let arr = default_rng_seeded(42).beta(2.0, 5.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)");
        }
    }

    #[test]
    fn beta_mean() {
        let n = 100_000;
        let (a, b) = (2.0, 5.0);
        let arr = default_rng_seeded(42).beta(a, b, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap());
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn chisquare_positive() {
        let arr = default_rng_seeded(42).chisquare(5.0, 10_000).unwrap();
        assert_all_positive(arr.as_slice().unwrap());
    }

    #[test]
    fn chisquare_mean() {
        let n = 100_000;
        let df = 10.0;
        let arr = default_rng_seeded(42).chisquare(df, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap());
        // Var[chi2(k)] = 2k.
        let se = (2.0 * df / n as f64).sqrt();
        assert!(
            (mean - df).abs() < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    #[test]
    fn f_positive() {
        let arr = default_rng_seeded(42).f(5.0, 10.0, 10_000).unwrap();
        assert_all_positive(arr.as_slice().unwrap());
    }

    #[test]
    fn student_t_symmetric() {
        let arr = default_rng_seeded(42).student_t(10.0, 100_000).unwrap();
        let mean = mean_of(arr.as_slice().unwrap());
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    #[test]
    fn standard_gamma_mean() {
        let n = 100_000;
        let shape = 5.0;
        let arr = default_rng_seeded(42).standard_gamma(shape, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap());
        // Var[Gamma(k, 1)] = k, so the standard error of the mean is sqrt(k / n).
        let se = (shape / n as f64).sqrt();
        assert!(
            (mean - shape).abs() < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.gamma(0.0, 1.0, 100).is_err());
        assert!(rng.gamma(1.0, 0.0, 100).is_err());
        assert!(rng.gamma(-1.0, 1.0, 100).is_err());
    }

    #[test]
    fn standard_t_alias_matches_student_t() {
        let a = default_rng_seeded(7).student_t(5.0, 100).unwrap();
        let b = default_rng_seeded(7).standard_t(5.0, 100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn noncentral_chisquare_mean_approx() {
        let arr = default_rng_seeded(42)
            .noncentral_chisquare(5.0, 3.0, 50_000)
            .unwrap();
        // E[chi'^2(k, lambda)] = k + lambda = 5 + 3.
        let mean = mean_of(arr.as_slice().unwrap());
        assert!((mean - 8.0).abs() < 0.5, "noncentral_chisquare mean {mean}");
    }

    #[test]
    fn noncentral_chisquare_zero_lambda_matches_chisquare() {
        let a = default_rng_seeded(11)
            .noncentral_chisquare(4.0, 0.0, 1000)
            .unwrap();
        let b = default_rng_seeded(11).chisquare(4.0, 1000).unwrap();
        let pairs = a.as_slice().unwrap().iter().zip(b.as_slice().unwrap());
        for (x, y) in pairs {
            assert!((x - y).abs() < 1e-12);
        }
    }

    #[test]
    fn noncentral_chisquare_bad_params() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.noncentral_chisquare(0.0, 1.0, 10).is_err());
        assert!(rng.noncentral_chisquare(1.0, -1.0, 10).is_err());
    }

    #[test]
    fn noncentral_f_positive() {
        let arr = default_rng_seeded(100)
            .noncentral_f(5.0, 7.0, 2.0, 1000)
            .unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0.0);
        }
    }

    #[test]
    fn noncentral_f_bad_params() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.noncentral_f(0.0, 1.0, 1.0, 10).is_err());
        assert!(rng.noncentral_f(1.0, 0.0, 1.0, 10).is_err());
        assert!(rng.noncentral_f(1.0, 1.0, -1.0, 10).is_err());
    }
}