use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast};
use scirs2_core::random::prelude::*;
use scirs2_core::{ChiSquared, Distribution as CoreDistribution};
use scirs2_stats::{distributions::Normal, Distribution};
use std::fmt::{Debug, Display};
pub struct NonCentralChiSquared {
df: f64,
nonc: f64,
}
impl NonCentralChiSquared {
pub fn new(df: f64, nonc: f64) -> Result<Self> {
if df <= 0.0 {
return Err(NumRs2Error::InvalidOperation(format!(
"df must be positive, got {}",
df
)));
}
if nonc < 0.0 {
return Err(NumRs2Error::InvalidOperation(format!(
"non-centrality parameter must be non-negative, got {}",
nonc
)));
}
Ok(Self { df, nonc })
}
pub fn sample(&self) -> f64 {
let normal = Normal::new(self.nonc.sqrt(), 1.0)
.expect("noncentral_chisquare: normal distribution should be valid for sqrt(nonc)");
let z = normal.rvs(1).expect("normal sampling failed")[0];
let chi_squared = if self.df > 1.0 {
let mut rng = thread_rng();
ChiSquared::new(self.df - 1.0)
.expect("noncentral_chisquare: chi-squared should be valid for df-1")
.sample(&mut rng)
} else {
0.0
};
z * z + chi_squared
}
}
/// Non-central F distribution with `df1` numerator degrees of freedom, `df2` denominator
/// degrees of freedom and non-centrality parameter `nonc`.
#[derive(Debug, Clone, Copy)]
pub struct NonCentralF {
    df1: f64,
    df2: f64,
    nonc: f64,
}

impl NonCentralF {
    /// Creates a non-central F distribution.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidOperation`] if `df1` or `df2` is not strictly positive, or
    /// if `nonc` is negative. NaN is rejected for all three. `sample` builds its component
    /// distributions with `expect`, so invalid parameters must be stopped here.
    pub fn new(df1: f64, df2: f64, nonc: f64) -> Result<Self> {
        if df1.is_nan() || df1 <= 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "df1 must be positive, got {}",
                df1
            )));
        }
        if df2.is_nan() || df2 <= 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "df2 must be positive, got {}",
                df2
            )));
        }
        if nonc.is_nan() || nonc < 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "non-centrality parameter must be non-negative, got {}",
                nonc
            )));
        }
        Ok(Self { df1, df2, nonc })
    }

    /// Draws a single sample as `(X / df1) / (Y / df2)`, where `X` and `Y` are independent.
    ///
    /// `X ~ NonCentralChiSquared(df1, nonc)` and `Y ~ ChiSquared(df2)`.
    pub fn sample(&self) -> f64 {
        let numerator = NonCentralChiSquared::new(self.df1, self.nonc)
            .expect("noncentral_f: NonCentralChiSquared should be valid for df1 and nonc")
            .sample()
            / self.df1;
        let mut rng = thread_rng();
        let denominator = ChiSquared::new(self.df2)
            .expect("noncentral_f: chi-squared should be valid for df2")
            .sample(&mut rng)
            / self.df2;
        numerator / denominator
    }
}
/// Von Mises (circular normal) distribution with mean direction `mu` (radians) and
/// concentration `kappa`.
///
/// Samples are angles in radians wrapped into `[-π, π]`.
#[derive(Debug, Clone, Copy)]
pub struct VonMises {
    mu: f64,
    kappa: f64,
}

impl VonMises {
    /// Creates a von Mises distribution.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidOperation`] if `kappa` is negative or NaN.
    pub fn new(mu: f64, kappa: f64) -> Result<Self> {
        if kappa.is_nan() || kappa < 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "kappa must be non-negative, got {}",
                kappa
            )));
        }
        Ok(Self { mu, kappa })
    }

    /// Draws a single angle using the Best–Fisher (1979) rejection algorithm, as NumPy does.
    ///
    /// * `kappa < 1e-8`: the distribution is indistinguishable from uniform on the circle.
    /// * `kappa > 1e6`: the wrapped normal `N(mu, 1 / kappa)` is used. It is accurate there and
    ///   avoids the precision loss of the rejection constants.
    pub fn sample(&self) -> f64 {
        use std::f64::consts::PI;
        let mut rng = thread_rng();
        let kappa = self.kappa;

        if kappa < 1e-8 {
            return Self::wrap_angle(self.mu + PI * (2.0 * rng.random::<f64>() - 1.0));
        }

        if kappa > 1e6 {
            let normal = Normal::new(0.0, 1.0)
                .expect("vonmises: standard normal N(0,1) should always be valid");
            let z = normal.rvs(1).expect("normal sampling failed")[0];
            return Self::wrap_angle(self.mu + z / kappa.sqrt());
        }

        let s = if kappa < 1e-5 {
            // Second-order Taylor expansion of the exact expression below around kappa = 0.
            // The exact form loses all precision for such small kappa.
            1.0 / kappa + kappa
        } else {
            let r = 1.0 + (1.0 + 4.0 * kappa * kappa).sqrt();
            let rho = (r - (2.0 * r).sqrt()) / (2.0 * kappa);
            (1.0 + rho * rho) / (2.0 * rho)
        };

        let w = loop {
            let z = (PI * rng.random::<f64>()).cos();
            let w = (1.0 + s * z) / (s + z);
            let y = kappa * (s - w);
            let v = rng.random::<f64>();
            // Quick squeeze test first; the log test is only needed when it fails.
            if y * (2.0 - y) - v >= 0.0 || (y / v).ln() + 1.0 - y >= 0.0 {
                break w;
            }
        };

        // The clamp keeps acos in its domain if rounding nudges |w| just past 1.
        let angle = w.clamp(-1.0, 1.0).acos();
        let signed = if rng.random::<f64>() < 0.5 {
            -angle
        } else {
            angle
        };
        Self::wrap_angle(self.mu + signed)
    }

    /// Maps an angle into `[-π, π]` in constant time. Non-finite input yields NaN.
    fn wrap_angle(theta: f64) -> f64 {
        use std::f64::consts::{PI, TAU};
        (theta + PI).rem_euclid(TAU) - PI
    }
}
/// Maxwell–Boltzmann distribution with scale parameter `scale`.
///
/// A sample is the Euclidean norm of three independent draws from `Normal::new(0.0, scale)`.
/// `scale` is presumably the standard deviation, following the scipy convention.
#[derive(Debug, Clone, Copy)]
pub struct Maxwell {
    scale: f64,
}

impl Maxwell {
    /// Creates a Maxwell distribution.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidOperation`] if `scale` is not strictly positive (NaN
    /// included).
    pub fn new(scale: f64) -> Result<Self> {
        if scale.is_nan() || scale <= 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "scale must be positive, got {}",
                scale
            )));
        }
        Ok(Self { scale })
    }

    /// Draws a single sample as `|(x, y, z)|` for three independent normal components.
    pub fn sample(&self) -> f64 {
        let normal = Normal::new(0.0, self.scale)
            .expect("maxwell: normal distribution should be valid for mean=0 and scale");
        let samples = normal.rvs(3).expect("normal sampling failed");
        let (x, y, z) = (samples[0], samples[1], samples[2]);
        // `hypot` avoids the intermediate overflow of `x*x + y*y + z*z` for very large scales.
        x.hypot(y).hypot(z)
    }
}
/// Wald (inverse Gaussian) distribution with mean `mean` (μ) and shape `shape` (λ).
#[derive(Debug, Clone, Copy)]
pub struct Wald {
    mean: f64,
    shape: f64,
}

impl Wald {
    /// Creates a Wald distribution.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidOperation`] if `mean` or `shape` is not strictly positive
    /// (NaN included).
    pub fn new(mean: f64, shape: f64) -> Result<Self> {
        if mean.is_nan() || mean <= 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "mean must be positive, got {}",
                mean
            )));
        }
        if shape.is_nan() || shape <= 0.0 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "shape must be positive, got {}",
                shape
            )));
        }
        Ok(Self { mean, shape })
    }

    /// Draws a single sample with the Michael–Schucany–Haas transformation method.
    ///
    /// With `y = μ·ν²`, `ν ~ N(0, 1)`, the smaller root is
    /// `x = μ + (μ / 2λ)·(y − sqrt(4λy + y²))`. The sample is `x` with probability
    /// `μ / (μ + x)` and `μ² / x` otherwise.
    pub fn sample(&self) -> f64 {
        let normal =
            Normal::new(0.0, 1.0).expect("wald: standard normal N(0,1) should always be valid");
        let nu = normal.rvs(1).expect("normal sampling failed")[0];
        let mu = self.mean;
        let lambda = self.shape;
        let y = mu * nu * nu;
        let x = if y > 0.0 {
            // Equivalent to mu * (root - y) / (root + y). Writing root - y as
            // 4*lambda*y / (root + y) avoids catastrophic cancellation when y >> lambda.
            // The result therefore stays in (0, mu].
            let root = (4.0 * lambda * y + y * y).sqrt();
            let sum = root + y;
            mu * (4.0 * lambda * y / sum) / sum
        } else {
            // The limit of the expression above as y -> 0.
            mu
        };
        let mut rng = thread_rng();
        let u = rng.random::<f64>();
        if u <= mu / (mu + x) {
            x
        } else {
            mu * (mu / x)
        }
    }
}
/// Draws an array of samples from a non-central chi-squared distribution.
///
/// # Arguments
///
/// * `df` - degrees of freedom; validated by [`NonCentralChiSquared::new`] (must be positive).
/// * `nonc` - non-centrality parameter; must be non-negative.
/// * `shape` - shape of the returned array. The number of samples is the product of its
///   dimensions.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the global random state or its RNG guard cannot be obtained;
/// * a parameter cannot be converted to `f64`;
/// * the distribution rejects the parameters;
/// * a sample cannot be converted back to `T`.
pub fn noncentral_chisquare<T: Float + NumCast + Clone + Debug + Display>(
    df: T,
    nonc: T,
    shape: &[usize],
) -> Result<Array<T>> {
    // NOTE(review): the guard is held for the whole call but never used for sampling.
    // `NonCentralChiSquared::sample` draws from `thread_rng()` and `Normal::rvs`, so seeding
    // the global state presumably does not make these draws reproducible. Confirm.
    let rng = crate::random::distributions::get_global_random_state()?;
    let _rng_lock = rng.get_rng()?;
    let df_f64 = df
        .to_f64()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to convert df to f64".to_string()))?;
    let nonc_f64 = nonc.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation(
            "Failed to convert non-centrality parameter to f64".to_string(),
        )
    })?;
    let dist = NonCentralChiSquared::new(df_f64, nonc_f64).map_err(|e| {
        NumRs2Error::InvalidOperation(format!(
            "Failed to create non-central chi-squared distribution: {}",
            e
        ))
    })?;
    let size: usize = shape.iter().product();
    let values = (0..size)
        .map(|_| {
            T::from(dist.sample()).ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to convert sample to target type".to_string())
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(values).reshape(shape))
}
/// Draws an array of samples from a non-central F distribution.
///
/// # Arguments
///
/// * `dfnum` - numerator degrees of freedom; must be positive.
/// * `dfden` - denominator degrees of freedom; must be positive.
/// * `nonc` - non-centrality parameter; must be non-negative.
/// * `shape` - shape of the returned array. The number of samples is the product of its
///   dimensions.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the global random state or its RNG guard cannot be obtained;
/// * a parameter cannot be converted to `f64`;
/// * [`NonCentralF::new`] rejects the parameters;
/// * a sample cannot be converted back to `T`.
pub fn noncentral_f<T: Float + NumCast + Clone + Debug + Display>(
    dfnum: T,
    dfden: T,
    nonc: T,
    shape: &[usize],
) -> Result<Array<T>> {
    // NOTE(review): the guard is held for the whole call but never used for sampling.
    // `NonCentralF::sample` draws from `thread_rng()` and `Normal::rvs`, so seeding the global
    // state presumably does not make these draws reproducible. Confirm.
    let rng = crate::random::distributions::get_global_random_state()?;
    let _rng_lock = rng.get_rng()?;
    let dfnum_f64 = dfnum.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert dfnum to f64".to_string())
    })?;
    let dfden_f64 = dfden.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert dfden to f64".to_string())
    })?;
    let nonc_f64 = nonc.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation(
            "Failed to convert non-centrality parameter to f64".to_string(),
        )
    })?;
    let dist = NonCentralF::new(dfnum_f64, dfden_f64, nonc_f64).map_err(|e| {
        NumRs2Error::InvalidOperation(format!(
            "Failed to create non-central F distribution: {}",
            e
        ))
    })?;
    let size: usize = shape.iter().product();
    let values = (0..size)
        .map(|_| {
            T::from(dist.sample()).ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to convert sample to target type".to_string())
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(values).reshape(shape))
}
/// Draws an array of samples from a von Mises distribution (see [`VonMises`]).
///
/// # Arguments
///
/// * `mu` - mean direction, in radians.
/// * `kappa` - concentration; must be non-negative.
/// * `shape` - shape of the returned array. The number of samples is the product of its
///   dimensions.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the global random state or its RNG guard cannot be obtained;
/// * a parameter cannot be converted to `f64`;
/// * [`VonMises::new`] rejects the parameters;
/// * a sample cannot be converted back to `T`.
pub fn vonmises<T: Float + NumCast + Clone + Debug + Display>(
    mu: T,
    kappa: T,
    shape: &[usize],
) -> Result<Array<T>> {
    // NOTE(review): the guard is held for the whole call but never used for sampling.
    // `VonMises::sample` draws from `thread_rng()` and `Normal::rvs`, so seeding the global
    // state presumably does not make these draws reproducible. Confirm.
    let rng = crate::random::distributions::get_global_random_state()?;
    let _rng_lock = rng.get_rng()?;
    let mu_f64 = mu
        .to_f64()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to convert mu to f64".to_string()))?;
    let kappa_f64 = kappa.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert kappa to f64".to_string())
    })?;
    let dist = VonMises::new(mu_f64, kappa_f64).map_err(|e| {
        NumRs2Error::InvalidOperation(format!("Failed to create von Mises distribution: {}", e))
    })?;
    let size: usize = shape.iter().product();
    let values = (0..size)
        .map(|_| {
            T::from(dist.sample()).ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to convert sample to target type".to_string())
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(values).reshape(shape))
}
/// Draws an array of samples from a Maxwell–Boltzmann distribution (see [`Maxwell`]).
///
/// # Arguments
///
/// * `scale` - scale parameter; must be positive.
/// * `shape` - shape of the returned array. The number of samples is the product of its
///   dimensions.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the global random state or its RNG guard cannot be obtained;
/// * `scale` cannot be converted to `f64`;
/// * [`Maxwell::new`] rejects it;
/// * a sample cannot be converted back to `T`.
pub fn maxwell<T: Float + NumCast + Clone + Debug + Display>(
    scale: T,
    shape: &[usize],
) -> Result<Array<T>> {
    // NOTE(review): the guard is held for the whole call but never used for sampling.
    // `Maxwell::sample` draws only through `Normal::rvs`, which does not receive this guard, so
    // seeding the global state presumably does not make these draws reproducible. Confirm.
    let rng = crate::random::distributions::get_global_random_state()?;
    let _rng_lock = rng.get_rng()?;
    let scale_f64 = scale.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert scale to f64".to_string())
    })?;
    let dist = Maxwell::new(scale_f64).map_err(|e| {
        NumRs2Error::InvalidOperation(format!("Failed to create Maxwell distribution: {}", e))
    })?;
    let size: usize = shape.iter().product();
    let values = (0..size)
        .map(|_| {
            T::from(dist.sample()).ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to convert sample to target type".to_string())
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(values).reshape(shape))
}
/// Draws an array of samples from a Wald (inverse Gaussian) distribution (see [`Wald`]).
///
/// # Arguments
///
/// * `mean` - mean of the distribution; must be positive.
/// * `scale` - shape parameter (λ); must be positive.
/// * `shape` - shape of the returned array. The number of samples is the product of its
///   dimensions.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the global random state or its RNG guard cannot be obtained;
/// * a parameter cannot be converted to `f64`;
/// * [`Wald::new`] rejects the parameters;
/// * a sample cannot be converted back to `T`.
pub fn wald<T: Float + NumCast + Clone + Debug + Display>(
    mean: T,
    scale: T,
    shape: &[usize],
) -> Result<Array<T>> {
    // NOTE(review): the guard is held for the whole call but never used for sampling.
    // `Wald::sample` draws from `thread_rng()` and `Normal::rvs`, so seeding the global state
    // presumably does not make these draws reproducible. Confirm.
    let rng = crate::random::distributions::get_global_random_state()?;
    let _rng_lock = rng.get_rng()?;
    let mean_f64 = mean.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert mean to f64".to_string())
    })?;
    let scale_f64 = scale.to_f64().ok_or_else(|| {
        NumRs2Error::InvalidOperation("Failed to convert scale to f64".to_string())
    })?;
    let dist = Wald::new(mean_f64, scale_f64).map_err(|e| {
        NumRs2Error::InvalidOperation(format!("Failed to create Wald distribution: {}", e))
    })?;
    let size: usize = shape.iter().product();
    let values = (0..size)
        .map(|_| {
            T::from(dist.sample()).ok_or_else(|| {
                NumRs2Error::InvalidOperation("Failed to convert sample to target type".to_string())
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(values).reshape(shape))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Fails the test if any value is not strictly positive.
    fn assert_all_positive(values: Vec<f64>) {
        for val in values {
            assert!(val > 0.0);
        }
    }

    /// Repeated single draws from `VonMises` stay on the closed interval `[-π, π]`.
    #[test]
    fn test_vonmises_basic() {
        let dist = VonMises::new(0.0, 1.0).expect("test: vonmises distribution should be valid");
        for sample in (0..10).map(|_| dist.sample()) {
            assert!((-PI..=PI).contains(&sample), "Sample {} out of range", sample);
        }
    }

    /// The array wrapper honours the requested shape and yields positive draws.
    #[test]
    fn test_noncentral_chisquare() {
        let arr = noncentral_chisquare(2.0, 1.0, &[10])
            .expect("test: noncentral_chisquare should succeed");
        assert_eq!(arr.shape(), vec![10]);
        assert_all_positive(arr.to_vec());
    }

    /// Same shape and positivity checks for the non-central F wrapper.
    #[test]
    fn test_noncentral_f() {
        let arr = noncentral_f(2.0, 3.0, 1.0, &[10]).expect("test: noncentral_f should succeed");
        assert_eq!(arr.shape(), vec![10]);
        assert_all_positive(arr.to_vec());
    }

    /// The von Mises wrapper produces an array of the requested shape and size.
    #[test]
    fn test_vonmises() {
        let arr = vonmises(0.0, 1.0, &[10]).expect("test: vonmises should succeed");
        assert_eq!(arr.shape(), vec![10]);
        assert_eq!(arr.size(), 10);
    }

    /// Maxwell draws are norms, so they must be positive.
    #[test]
    fn test_maxwell() {
        let arr = maxwell(1.0, &[10]).expect("test: maxwell should succeed");
        assert_eq!(arr.shape(), vec![10]);
        assert_all_positive(arr.to_vec());
    }

    /// Wald draws have positive support.
    #[test]
    fn test_wald() {
        let arr = wald(1.0, 1.0, &[10]).expect("test: wald should succeed");
        assert_eq!(arr.shape(), vec![10]);
        assert_all_positive(arr.to_vec());
    }
}