#![cfg(feature = "scirs")]
use approx::assert_relative_eq;
use numrs2::array::Array;
use numrs2::interop::scirs_compat::*;
use numrs2::random::advanced_distributions::{maxwell, vonmises};
use numrs2::random::distributions::{multivariate_normal_with_rotation, set_seed};
use numrs2::random::distributions_enhanced::truncated_normal;
use std::f64::consts::PI;
/// Noncentral chi-square sampling: shape, strictly positive support, sample
/// mean close to the theoretical mean `df + nonc`, and rejection of negative
/// degrees of freedom / non-centrality.
#[test]
fn test_noncentral_chisquare() {
    set_seed(12345);
    let (df, nonc) = (5.0f64, 2.0f64);

    let draws = noncentral_chisquare(df, nonc, &[1000]).unwrap();
    assert_eq!(draws.shape(), vec![1000]);

    // Materialize the samples once and reuse them for every check below.
    let values = draws.to_vec();
    for &val in values.iter() {
        assert!(val > 0.0);
    }

    // E[X] = df + nonc for the noncentral chi-square distribution.
    let sample_mean = values.iter().sum::<f64>() / values.len() as f64;
    assert_relative_eq!(sample_mean, df + nonc, epsilon = 0.5);

    // Either parameter being negative must be reported as an error.
    assert!(noncentral_chisquare(-1.0, 2.0, &[10]).is_err());
    assert!(noncentral_chisquare(2.0, -1.0, &[10]).is_err());
}
/// Noncentral F sampling: shape, strictly positive support, and rejection of
/// a negative value for each of the three parameters.
#[test]
fn test_noncentral_f() {
    set_seed(12345);
    let (dfnum, dfden, nonc) = (5.0f64, 10.0f64, 2.0f64);

    let draws = noncentral_f(dfnum, dfden, nonc, &[100]).unwrap();
    assert_eq!(draws.shape(), vec![100]);

    let values = draws.to_vec();
    for &val in values.iter() {
        assert!(val > 0.0);
    }

    // Negative numerator df, denominator df, and non-centrality, in turn.
    assert!(noncentral_f(-1.0, 10.0, 2.0, &[10]).is_err());
    assert!(noncentral_f(5.0, -1.0, 2.0, &[10]).is_err());
    assert!(noncentral_f(5.0, 10.0, -1.0, &[10]).is_err());
}
/// Von Mises sampling: shape, support within `[-PI, PI]`, circular mean close
/// to `mu`, and rejection of a negative concentration `kappa`.
#[test]
fn test_vonmises() {
    set_seed(12345);
    let mu = 0.0f64;
    let kappa = 2.0f64;
    let samples = vonmises(mu, kappa, &[500]).unwrap();
    assert_eq!(samples.shape(), vec![500]);

    // Copy the samples out once and reuse them for every check.
    let angles = samples.to_vec();
    for &val in angles.iter() {
        assert!((-PI..=PI).contains(&val));
    }

    // Angles wrap around, so use the circular mean atan2(mean sin, mean cos)
    // rather than the arithmetic mean.
    let n = angles.len() as f64;
    let mean_sin = angles.iter().map(|x| x.sin()).sum::<f64>() / n;
    let mean_cos = angles.iter().map(|x| x.cos()).sum::<f64>() / n;
    let circular_mean = mean_sin.atan2(mean_cos);

    // Compare on the circle: wrap the difference into (-PI, PI] so the check
    // stays valid even when `mu` is near the +/-PI seam.
    let diff = circular_mean - mu;
    let angular_error = diff.sin().atan2(diff.cos());
    assert_relative_eq!(angular_error, 0.0, epsilon = 0.2);

    assert!(vonmises(0.0, -1.0, &[10]).is_err());
}
/// Maxwell sampling: shape, strictly positive support, sample mean close to
/// `2 * scale * sqrt(2 / PI)`, and rejection of non-positive scales.
#[test]
fn test_maxwell() {
    set_seed(12345);
    let scale = 1.0f64;

    let draws = maxwell(scale, &[500]).unwrap();
    assert_eq!(draws.shape(), vec![500]);

    let speeds = draws.to_vec();
    for &val in speeds.iter() {
        assert!(val > 0.0);
    }

    let theoretical = 2.0 * scale * (2.0 / PI).sqrt();
    let count = speeds.len() as f64;
    let empirical = speeds.iter().sum::<f64>() / count;
    assert_relative_eq!(empirical, theoretical, epsilon = 0.1);

    // Both a negative and a zero scale are invalid.
    assert!(maxwell(-1.0, &[10]).is_err());
    assert!(maxwell(0.0, &[10]).is_err());
}
/// Truncated normal sampling: shape, every draw inside `[low, high]`, and
/// rejection of a negative std, an inverted interval, and an empty interval.
#[test]
fn test_truncated_normal() {
    set_seed(12345);
    let (mean, std) = (0.0f64, 1.0f64);
    let (low, high) = (-2.0f64, 2.0f64);

    let draws = truncated_normal(mean, std, low, high, &[500]).unwrap();
    assert_eq!(draws.shape(), vec![500]);

    let window = low..=high;
    for val in draws.to_vec() {
        assert!(window.contains(&val));
    }

    // Invalid inputs: negative std, low > high, low == high.
    assert!(truncated_normal(0.0, -1.0, -2.0, 2.0, &[10]).is_err());
    assert!(truncated_normal(0.0, 1.0, 2.0, -2.0, &[10]).is_err());
    assert!(truncated_normal(0.0, 1.0, 2.0, 2.0, &[10]).is_err());
}
/// Multivariate normal sampling with and without an explicit rotation matrix.
/// It checks output shapes and rejects an empty mean vector and a non-square
/// covariance matrix.
#[test]
fn test_multivariate_normal_with_rotation() {
    use std::f64::consts::FRAC_1_SQRT_2;

    set_seed(12345);
    let mean = vec![0.0f64, 0.0f64];
    // 2x2 covariance with unit variances and covariance 0.5 (positive definite:
    // eigenvalues 1.5 and 0.5).
    let cov_data = vec![1.0, 0.5, 0.5, 1.0];
    let cov = Array::from_vec(cov_data).reshape(&[2, 2]);

    let samples = multivariate_normal_with_rotation(&mean, &cov, Some(&[100]), None).unwrap();
    assert_eq!(samples.shape(), vec![100, 2]);

    // 45-degree rotation matrix (orthogonal, determinant 1).
    let rot_data = vec![FRAC_1_SQRT_2, FRAC_1_SQRT_2, -FRAC_1_SQRT_2, FRAC_1_SQRT_2];
    let rotation = Array::from_vec(rot_data).reshape(&[2, 2]);
    let samples_rotated =
        multivariate_normal_with_rotation(&mean, &cov, Some(&[100]), Some(&rotation)).unwrap();
    assert_eq!(samples_rotated.shape(), vec![100, 2]);

    // An empty mean vector is rejected.
    assert!(multivariate_normal_with_rotation(&[], &cov, Some(&[10]), None).is_err());

    // A non-square (3x1) covariance is rejected.
    let bad_cov = Array::from_vec(vec![1.0, 0.5, 0.5]).reshape(&[3, 1]);
    assert!(multivariate_normal_with_rotation(&mean, &bad_cov, Some(&[10]), None).is_err());
}
/// Sampling honours 1-D, 2-D and 3-D output shapes.
#[test]
fn test_distribution_shapes() {
    set_seed(12345);
    // Drawn in the same order as the individual checks: 1-D, then 2-D, then 3-D.
    let shapes: [&[usize]; 3] = [&[10], &[5, 4], &[2, 3, 4]];
    for &shape in shapes.iter() {
        let samples = noncentral_chisquare(5.0f64, 2.0f64, shape).unwrap();
        assert_eq!(samples.shape(), shape.to_vec());
    }
}
/// Re-seeding with the same value must reproduce the same `uniform` draws, and
/// a different seed must produce different draws.
///
/// There are optional reproducibility checks for `maxwell` and
/// `truncated_normal`, gated by `CHECK_DERIVED_DISTRIBUTIONS`.
#[test]
#[ignore = "Seeding behavior changed during SciRS2 migration - requires seeding implementation fix"]
fn test_seed_repeatability() {
    use numrs2::random::distributions::uniform;

    // Gate for the maxwell / truncated_normal reproducibility checks below.
    // They are still compiled, so they keep type-checking, but they do not run.
    // Presumably they are disabled for the same seeding issue named in the
    // `#[ignore]` reason; TODO confirm. Set to `true` to re-enable them.
    // This replaces `if false` + `#[allow(unreachable_code)]`: that attribute
    // suppressed nothing, because `if false` never triggers the lint.
    const CHECK_DERIVED_DISTRIBUTIONS: bool = false;

    let test_seed = 987654321u64;

    let get_uniform_sample = || {
        set_seed(test_seed);
        uniform(0.0f64, 1.0f64, &[3]).unwrap().to_vec()
    };
    let sample1 = get_uniform_sample();
    let sample2 = get_uniform_sample();
    assert_eq!(
        sample1, sample2,
        "Uniform distribution should be reproducible with same seed"
    );

    set_seed(test_seed + 1);
    let sample3 = uniform(0.0f64, 1.0f64, &[3]).unwrap().to_vec();
    assert_ne!(
        sample1, sample3,
        "Different seeds should produce different results"
    );

    if CHECK_DERIVED_DISTRIBUTIONS {
        let get_maxwell_sample = || {
            set_seed(test_seed);
            maxwell(1.0f64, &[2]).unwrap().to_vec()
        };
        let maxwell1 = get_maxwell_sample();
        let maxwell2 = get_maxwell_sample();
        // `zip` stops at the shorter input, so check lengths explicitly first.
        assert_eq!(maxwell1.len(), maxwell2.len());
        for (m1, m2) in maxwell1.iter().zip(maxwell2.iter()) {
            assert!(
                (m1 - m2).abs() < 1e-14,
                "Maxwell distribution should be reproducible: {} vs {}",
                m1,
                m2
            );
        }

        let get_truncnorm_sample = || {
            set_seed(test_seed);
            truncated_normal(0.0f64, 1.0f64, -1.0f64, 1.0f64, &[2])
                .unwrap()
                .to_vec()
        };
        let trunc1 = get_truncnorm_sample();
        let trunc2 = get_truncnorm_sample();
        assert_eq!(trunc1.len(), trunc2.len());
        for (t1, t2) in trunc1.iter().zip(trunc2.iter()) {
            assert!(
                (t1 - t2).abs() < 1e-13,
                "Truncated normal should be reproducible: {} vs {}",
                t1,
                t2
            );
        }
    }
}
/// `maxwell` works at both `f64` and `f32` precision, producing the requested
/// shape and strictly positive values.
#[test]
fn test_type_conversions() {
    // Draw at double precision first, then single precision.
    let wide = maxwell(1.0f64, &[10]).unwrap();
    assert_eq!(wide.shape(), vec![10]);
    let narrow = maxwell(1.0f32, &[10]).unwrap();
    assert_eq!(narrow.shape(), vec![10]);

    let wide_values = wide.to_vec();
    for &val in wide_values.iter() {
        assert!(val > 0.0);
    }
    let narrow_values = narrow.to_vec();
    for &val in narrow_values.iter() {
        assert!(val > 0.0);
    }
}