use approx::assert_relative_eq;
use scirs2_core::ndarray::{array, Array1};
use scirs2_stats::{
distributions::{norm, poisson},
random, sampling,
};
/// Checks that distribution sampling honours the requested sample count, that
/// two independent unseeded draws differ, and that seeded `random::randn`
/// calls are reproducible.
#[test]
fn test_sampling_distribution_consistency() {
    let normal = norm(0.0f64, 1.0).expect("Test: operation failed");
    let samples1 = sampling::sample_distribution(&normal, 100).expect("Test: operation failed");
    let samples2 = normal.rvs(100).expect("Test: operation failed");
    assert_eq!(samples1.len(), 100);
    assert_eq!(samples2.len(), 100);
    // Two independent unseeded draws from a continuous distribution are
    // expected to differ; equality would suggest the generator is not advancing.
    assert_ne!(samples1, samples2);

    // Identical seeds must reproduce identical sequences. Comparing the whole
    // arrays also checks that their lengths agree, which an index loop over
    // only the first array would not.
    let samples_seeded1 = random::randn(20, Some(42)).expect("Test: operation failed");
    let samples_seeded2 = random::randn(20, Some(42)).expect("Test: operation failed");
    assert_eq!(samples_seeded1.len(), 20);
    assert_eq!(samples_seeded1, samples_seeded2);
}
/// Checks that `sampling::bootstrap` returns the requested number of
/// resamples, each as long as the input, and that every resampled value is
/// taken from the original data.
#[test]
fn test_bootstrap_sample_properties() {
    let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
    let samples = sampling::bootstrap(&data.view(), 100, Some(42)).expect("Test: operation failed");
    assert_eq!(samples.shape(), &[100, 5]);
    // Each row of `samples` is one bootstrap resample.
    for bootstrap_sample in samples.outer_iter() {
        assert_eq!(bootstrap_sample.len(), data.len());
        for &value in bootstrap_sample.iter() {
            // Resampled values are copies of inputs, so a tight tolerance suffices.
            let is_from_original = data.iter().any(|&x| (x - value).abs() < 1e-7);
            assert!(
                is_from_original,
                "bootstrap value {value} not found in original data"
            );
        }
    }
}
/// Checks that `sampling::permutation` returns a reordering of its input:
/// the same length and exactly the same multiset of values.
#[test]
fn test_permutation_properties() {
    let data = array![10, 20, 30, 40, 50];
    let perm = sampling::permutation(&data.view(), Some(42)).expect("Test: operation failed");
    assert_eq!(perm.len(), data.len());
    // Equal sorted copies mean every input value appears in `perm` exactly as
    // many times as in `data`, i.e. `perm` is a true permutation.
    let mut sorted_perm: Vec<_> = perm.iter().copied().collect();
    sorted_perm.sort_unstable();
    let mut sorted_data: Vec<_> = data.iter().copied().collect();
    sorted_data.sort_unstable();
    assert_eq!(sorted_perm, sorted_data);
}
/// Checks sample means and standard deviations from the uniform, standard
/// normal, and Poisson generators against their theoretical values.
#[test]
fn test_statistical_properties() {
    // Uniform(0, 1): mean 1/2, standard deviation 1/sqrt(12) ≈ 0.2887.
    let uniform_samples =
        random::uniform(0.0, 1.0, 10000, Some(42)).expect("Test: operation failed");
    let mean = custom_mean(&uniform_samples);
    let std = custom_std(&uniform_samples, 0);
    assert_relative_eq!(mean, 0.5, epsilon = 0.02);
    assert_relative_eq!(std, 1.0 / 12.0f64.sqrt(), epsilon = 0.02);

    // Standard normal: mean 0, standard deviation 1.
    let normal_samples = random::randn(10000, Some(42)).expect("Test: operation failed");
    let mean = custom_mean(&normal_samples);
    let std = custom_std(&normal_samples, 0);
    assert_relative_eq!(mean, 0.0, epsilon = 0.05);
    assert_relative_eq!(std, 1.0, epsilon = 0.05);

    // Poisson(lambda): mean and variance both equal lambda, so the standard
    // deviation is sqrt(lambda). These samples are unseeded; the tolerances
    // are several standard errors wide for 10_000 draws.
    let lambda = 3.0f64;
    let poisson_dist = poisson(lambda, 0.0).expect("Test: operation failed");
    let samples = poisson_dist.rvs(10000).expect("Test: operation failed");
    let samples_array = Array1::from(samples);
    let mean = custom_mean(&samples_array);
    let std = custom_std(&samples_array, 0);
    assert_relative_eq!(mean, lambda, epsilon = 0.1);
    assert_relative_eq!(std, lambda.sqrt(), epsilon = 0.1);
}
/// Arithmetic mean of `arr`.
///
/// Returns `NaN` for an empty array. The mean is undefined there, and `NaN`
/// makes any `assert_relative_eq!` against it fail rather than silently
/// matching an expected mean of `0.0`.
// Only referenced from `#[test]` functions, so it is unused in non-test builds.
#[allow(dead_code)]
fn custom_mean(arr: &Array1<f64>) -> f64 {
    if arr.is_empty() {
        return f64::NAN;
    }
    arr.iter().sum::<f64>() / arr.len() as f64
}
/// Standard deviation of `arr` with `ddof` delta degrees of freedom.
///
/// The sum of squared deviations is divided by `arr.len() - ddof`. `ddof = 0`
/// gives the population form and `ddof = 1` the sample form.
///
/// Returns `NaN` when `arr.len() <= ddof`, because the divisor would be zero
/// or negative and the statistic is undefined.
// Only referenced from `#[test]` functions, so it is unused in non-test builds.
#[allow(dead_code)]
fn custom_std(arr: &Array1<f64>, ddof: usize) -> f64 {
    let n = arr.len();
    if n <= ddof {
        return f64::NAN;
    }
    // n > ddof >= 0 here, so `arr` is non-empty and the mean is well defined.
    let mean = custom_mean(arr);
    let sum_sq: f64 = arr
        .iter()
        .map(|&x| {
            let dev = x - mean;
            dev * dev
        })
        .sum();
    (sum_sq / (n - ddof) as f64).sqrt()
}