use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::uniform::SampleUniform;
use scirs2_core::random::{Rng, RngExt};
use scirs2_core::simd_ops::{AutoOptimizer, SimdUnifiedOps};
/// Draws `n` samples from a normal distribution with location `mean` and scale
/// `std_dev` using the Box–Muller transform.
///
/// Uniform variates are generated in pairs. `u1` is drawn from
/// `[F::epsilon(), 1)` so that `ln(u1)` is always finite, and `u2` from `[0, 1)`.
/// Each pair produces two standard normals, `r·cos(2π·u2)` and `r·sin(2π·u2)`
/// with `r = sqrt(-2·ln(u1))`. Both are then scaled by `std_dev` and shifted by
/// `mean`. Output is interleaved `[cos₀, sin₀, cos₁, sin₁, …]` and truncated to
/// `n`, so for odd `n` the last `sin` sample is discarded.
///
/// `AutoOptimizer::should_use_simd(n_pairs)` selects the path. The vectorised
/// path uses the `SimdUnifiedOps` kernels (`simd_fma` for the final scale/shift).
/// The scalar path evaluates `mean + std_dev * r * trig(θ)` per pair.
///
/// # Arguments
/// * `n` - number of samples to return (zero yields an empty array)
/// * `mean` - location of the distribution
/// * `std_dev` - scale of the distribution; must be strictly positive
/// * `seed` - seed for reproducible output; `None` draws a seed from the thread RNG
///
/// # Errors
/// Returns `StatsError::invalid_argument` if `std_dev` is not strictly positive,
/// including when it is NaN.
pub fn box_muller_simd<F>(
    n: usize,
    mean: F,
    std_dev: F,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
{
    // NaN compares false with everything, so `std_dev <= 0` alone would let it through
    // and every sample would silently become NaN.
    if std_dev.is_nan() || std_dev <= F::zero() {
        return Err(StatsError::invalid_argument(
            "Standard deviation must be positive",
        ));
    }
    // Samples are produced two at a time; round up so odd `n` is covered.
    let n_pairs = (n + 1) / 2;
    let mut rng = {
        let s = seed.unwrap_or_else(|| scirs2_core::random::thread_rng().random());
        scirs2_core::random::seeded_rng(s)
    };
    let optimizer = AutoOptimizer::new();

    // Lower bound for u1 keeps ln(u1) finite (u1 == 0 would give r = +inf).
    let u1_lower = F::epsilon().to_f64().unwrap_or(1e-10);
    let u1: Array1<F> = Array1::from_shape_fn(n_pairs, |_| {
        F::from(rng.gen_range(u1_lower..1.0)).unwrap_or_else(F::one)
    });
    let u2: Array1<F> = Array1::from_shape_fn(n_pairs, |_| {
        F::from(rng.gen_range(0.0..1.0)).unwrap_or_else(F::zero)
    });

    // Built by addition so the constant is exact and cannot fail to convert.
    let two = F::one() + F::one();
    let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap_or_else(F::one);

    if optimizer.should_use_simd(n_pairs) {
        let ln_u1 = F::simd_ln(&u1.view());
        let minus_two_array = Array1::from_elem(n_pairs, -two);
        let neg_two_ln_u1 = F::simd_mul(&minus_two_array.view(), &ln_u1.view());
        let radius = F::simd_sqrt(&neg_two_ln_u1.view());
        let two_pi_array = Array1::from_elem(n_pairs, two_pi);
        let angle = F::simd_mul(&two_pi_array.view(), &u2.view());
        let cos_term = F::simd_cos(&angle.view());
        let sin_term = F::simd_sin(&angle.view());
        let z0 = F::simd_mul(&radius.view(), &cos_term.view());
        let z1 = F::simd_mul(&radius.view(), &sin_term.view());
        let std_dev_array = Array1::from_elem(n_pairs, std_dev);
        let mean_array = Array1::from_elem(n_pairs, mean);
        let z0_scaled = F::simd_fma(&z0.view(), &std_dev_array.view(), &mean_array.view());
        let z1_scaled = F::simd_fma(&z1.view(), &std_dev_array.view(), &mean_array.view());
        // Interleave the two halves directly into an `n`-element array, instead of
        // filling a `2 * n_pairs` buffer and copying a slice of it.
        Ok(z0_scaled
            .iter()
            .zip(z1_scaled.iter())
            .flat_map(|(&cos_sample, &sin_sample)| [cos_sample, sin_sample])
            .take(n)
            .collect())
    } else {
        Ok(u1
            .iter()
            .zip(u2.iter())
            .flat_map(|(&a, &b)| {
                let r = (-two * a.ln()).sqrt();
                let theta = two_pi * b;
                [
                    mean + std_dev * r * theta.cos(),
                    mean + std_dev * r * theta.sin(),
                ]
            })
            .take(n)
            .collect())
    }
}
/// Draws `n` samples by inverse-transform sampling. Each uniform variate `u` is
/// passed through `inverse_cdf`, and the results are returned in draw order.
///
/// Uniforms come from the half-open interval `[0, 1)`, so `inverse_cdf` may be
/// called with exactly `0`. An inverse CDF for a distribution with unbounded
/// lower support should handle that point.
///
/// NOTE(review): when `F` is narrower than `f64` (e.g. `f32`), converting a draw
/// very close to 1 may round to exactly `1`. Confirm callers' inverse CDFs
/// tolerate `u == 1`.
///
/// Despite the name, no SIMD kernel is used. `inverse_cdf` is an arbitrary
/// closure and is applied element by element.
///
/// # Arguments
/// * `n` - number of samples to return
/// * `inverse_cdf` - quantile function mapping a uniform variate to a sample
/// * `seed` - seed for reproducible output; `None` draws a seed from the thread RNG
///
/// # Errors
/// This function currently never returns an error. The `Result` keeps the
/// signature consistent with the other samplers in this module.
pub fn inverse_transform_simd<F, InvCDF>(
    n: usize,
    inverse_cdf: InvCDF,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
    InvCDF: Fn(F) -> F,
{
    let mut rng = {
        let s = seed.unwrap_or_else(|| scirs2_core::random::thread_rng().random());
        scirs2_core::random::seeded_rng(s)
    };
    // Generate and transform in a single pass. `inverse_cdf` cannot touch the
    // local `rng`, so the sequence of uniforms (and the output for a given seed)
    // is the same as drawing every uniform first. This avoids allocating an
    // intermediate array.
    Ok(Array1::from_shape_fn(n, |_| {
        let u = F::from(rng.gen_range(0.0..1.0)).unwrap_or_else(F::zero);
        inverse_cdf(u)
    }))
}
/// Draws `n` samples from an exponential distribution with rate `rate` (λ),
/// using inverse-transform sampling `x = -ln(u) / λ`.
///
/// Uniforms are drawn from `[F::epsilon(), 1)` so that `ln(u)` is always finite.
///
/// `AutoOptimizer::should_use_simd(n)` selects the path:
/// * The vectorised path multiplies `ln(u)` by the precomputed factor `-(1/λ)`.
/// * The scalar path computes `-ln(u) / λ` directly.
///
/// The two paths can therefore differ in the last bit.
///
/// # Arguments
/// * `n` - number of samples to return
/// * `rate` - rate parameter λ; must be strictly positive
/// * `seed` - seed for reproducible output; `None` draws a seed from the thread RNG
///
/// # Errors
/// Returns `StatsError::invalid_argument` if `rate` is not strictly positive,
/// including when it is NaN.
pub fn exponential_simd<F>(n: usize, rate: F, seed: Option<u64>) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
{
    // NaN compares false with everything, so `rate <= 0` alone would accept it
    // and every sample would be NaN.
    if rate.is_nan() || rate <= F::zero() {
        return Err(StatsError::invalid_argument("Rate must be positive"));
    }
    let mut rng = {
        let s = seed.unwrap_or_else(|| scirs2_core::random::thread_rng().random());
        scirs2_core::random::seeded_rng(s)
    };
    let optimizer = AutoOptimizer::new();
    let u_lower = F::epsilon().to_f64().unwrap_or(1e-10);
    let u: Array1<F> = Array1::from_shape_fn(n, |_| {
        F::from(rng.gen_range(u_lower..1.0)).unwrap_or_else(F::one)
    });
    if optimizer.should_use_simd(n) {
        let ln_u = F::simd_ln(&u.view());
        // (-1 · ln u) · (1/λ) equals ln u · (-(1/λ)) exactly in IEEE arithmetic,
        // because negation is exact. One multiply therefore replaces two
        // full-array passes and one temporary.
        let scale = Array1::from_elem(n, -(F::one() / rate));
        Ok(F::simd_mul(&ln_u.view(), &scale.view()))
    } else {
        Ok(u.mapv(|ui| -ui.ln() / rate))
    }
}
/// Builds a bootstrap resample of length `n_samples` from `data`. Each output
/// element is a copy of `data` at an index chosen uniformly at random, with
/// replacement.
///
/// # Arguments
/// * `data` - source observations; must be non-empty
/// * `n_samples` - length of the resample
/// * `seed` - seed for reproducible output; `None` draws a seed from the thread RNG
///
/// # Errors
/// Returns `StatsError::invalid_argument` if `data` is empty.
pub fn bootstrap_simd<F>(
    data: &ArrayView1<F>,
    n_samples: usize,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    let len = data.len();
    if len == 0 {
        return Err(StatsError::invalid_argument("Data array cannot be empty"));
    }
    let seed_value = match seed {
        Some(given) => given,
        None => scirs2_core::random::thread_rng().random(),
    };
    let mut generator = scirs2_core::random::seeded_rng(seed_value);
    // One index draw per output slot, in output order.
    Ok(Array1::from_shape_fn(n_samples, |_| {
        data[generator.gen_range(0..len)]
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_box_muller_simd_basic() {
        let samples = box_muller_simd(1000, 0.0, 1.0, Some(42)).expect("Sampling failed");
        assert_eq!(samples.len(), 1000);
        let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
        assert_abs_diff_eq!(mean, 0.0, epsilon = 0.1);
        let variance: f64 =
            samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
        assert_abs_diff_eq!(variance.sqrt(), 1.0, epsilon = 0.1);
    }

    /// Odd `n` must drop the trailing sample of the last pair; `n == 0` must be empty.
    #[test]
    fn test_box_muller_simd_odd_and_empty_lengths() {
        let odd = box_muller_simd::<f64>(7, 0.0, 1.0, Some(3)).expect("Sampling failed");
        assert_eq!(odd.len(), 7);
        assert!(odd.iter().all(|x| x.is_finite()));
        let empty = box_muller_simd::<f64>(0, 0.0, 1.0, Some(3)).expect("Sampling failed");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_box_muller_simd_is_reproducible_with_seed() {
        let a = box_muller_simd::<f64>(64, 1.5, 2.0, Some(11)).expect("Sampling failed");
        let b = box_muller_simd::<f64>(64, 1.5, 2.0, Some(11)).expect("Sampling failed");
        assert_eq!(a, b);
    }

    #[test]
    fn test_box_muller_simd_rejects_non_positive_std_dev() {
        assert!(box_muller_simd::<f64>(10, 0.0, 0.0, Some(1)).is_err());
        assert!(box_muller_simd::<f64>(10, 0.0, -1.0, Some(1)).is_err());
    }

    #[test]
    fn test_exponential_simd_basic() {
        let rate = 2.0;
        let samples = exponential_simd(1000, rate, Some(42)).expect("Sampling failed");
        assert_eq!(samples.len(), 1000);
        let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
        let expected_mean = 1.0 / rate;
        assert_abs_diff_eq!(mean, expected_mean, epsilon = 0.1);
    }

    #[test]
    fn test_exponential_simd_samples_are_non_negative() {
        let samples = exponential_simd::<f64>(500, 3.0, Some(5)).expect("Sampling failed");
        assert!(samples.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_exponential_simd_rejects_non_positive_rate() {
        assert!(exponential_simd::<f64>(10, 0.0, Some(1)).is_err());
        assert!(exponential_simd::<f64>(10, -2.0, Some(1)).is_err());
    }

    /// With the identity as inverse CDF, the output is the raw uniform draws in [0, 1).
    #[test]
    fn test_inverse_transform_simd_identity_is_unit_interval() {
        let samples =
            inverse_transform_simd(200, |u: f64| u, Some(9)).expect("Sampling failed");
        assert_eq!(samples.len(), 200);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_bootstrap_simd_basic() {
        use scirs2_core::ndarray::array;
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let samples = bootstrap_simd(&data.view(), 100, Some(42)).expect("Bootstrap failed");
        assert_eq!(samples.len(), 100);
        for &sample in samples.iter() {
            assert!(data.iter().any(|&x| (x - sample).abs() < 1e-10));
        }
    }

    #[test]
    fn test_bootstrap_simd_rejects_empty_data() {
        let empty: Array1<f64> = Array1::zeros(0);
        assert!(bootstrap_simd(&empty.view(), 10, Some(1)).is_err());
    }
}