use crate::error::{StatsError, StatsResult};
use crate::random;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::Float;
use scirs2_core::random::prelude::*;
use scirs2_core::random::SeedableRng;
/// Abstraction over distributions that can hand back a batch of values of type `T`.
///
/// [`sample_distribution`] relies on this trait to obtain its samples.
pub trait SampleableDistribution<T> {
/// Returns a vector of values for the requested `size`, or an error.
///
/// NOTE(review): presumably yields exactly `size` random variates drawn from
/// the distribution — confirm against the implementors.
fn rvs(&self, size: usize) -> StatsResult<Vec<T>>;
}
/// Requests `size` values from `dist` and packs them into a one-dimensional array.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when `size` is zero, and forwards any
/// error produced by `dist.rvs`.
#[allow(dead_code)]
pub fn sample_distribution<T, D>(dist: &D, size: usize) -> StatsResult<Array1<T>>
where
    T: Float + std::iter::Sum<T> + std::ops::Div<Output = T>,
    D: SampleableDistribution<T>,
{
    match size {
        0 => Err(StatsError::InvalidArgument(
            "Size must be positive".to_string(),
        )),
        requested => dist.rvs(requested).map(Array1::from_vec),
    }
}
/// Thin public wrapper that forwards to `random::bootstrap_sample`.
///
/// All arguments are passed through unchanged and the callee's result
/// (including any error) is returned as-is.
///
/// NOTE(review): presumably the result holds one resample per row
/// (`n_resamples` rows) and `seed` makes the draw reproducible — confirm
/// against `random::bootstrap_sample`.
#[allow(dead_code)]
pub fn bootstrap<T>(
x: &ArrayView1<T>,
n_resamples: usize,
seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
T: Copy + scirs2_core::numeric::Zero,
{
random::bootstrap_sample(x, n_resamples, seed)
}
/// Thin public wrapper that forwards to `random::permutation`.
///
/// Both arguments are passed through unchanged and the callee's result
/// (including any error) is returned as-is.
///
/// NOTE(review): presumably returns a shuffled copy of `x`, reproducible for a
/// given `seed` — confirm against `random::permutation`.
#[allow(dead_code)]
pub fn permutation<T>(x: &ArrayView1<T>, seed: Option<u64>) -> StatsResult<Array1<T>>
where
T: Copy,
{
random::permutation(x, seed)
}
/// Draws `size` distinct positions from every stratum defined by `groups`.
///
/// `x` is used only to check that it has the same length as `groups`; the
/// function returns *indices* into the input, not values. For each distinct
/// label in `groups`, `size` positions are selected without replacement using
/// a partial Fisher–Yates shuffle.
///
/// Strata are processed in the order in which their labels first appear in
/// `groups`. This makes the output — and the way the random stream is
/// consumed — fully reproducible for a given `seed`. (Iterating a `HashMap`
/// directly would not be, because its iteration order changes between runs.)
///
/// # Returns
///
/// The selected indices, `size` per stratum, concatenated stratum by stratum
/// in first-appearance order.
///
/// # Errors
///
/// * `StatsError::DimensionMismatch` if `x` and `groups` differ in length.
/// * `StatsError::InvalidArgument` if `size` is zero or any stratum has fewer
///   than `size` members.
#[allow(dead_code)]
pub fn stratified_sample<T, G>(
    x: &ArrayView1<T>,
    groups: &ArrayView1<G>,
    size: usize,
    seed: Option<u64>,
) -> StatsResult<Array1<usize>>
where
    T: Copy,
    G: Copy + Eq + std::hash::Hash,
{
    if x.len() != groups.len() {
        return Err(StatsError::DimensionMismatch(
            "Input array and group array must have the same length".to_string(),
        ));
    }
    if size == 0 {
        return Err(StatsError::InvalidArgument(
            "Size must be positive".to_string(),
        ));
    }

    // Bucket positions by label. The map only translates a label to its slot in
    // `strata`; the Vec preserves first-appearance order for determinism.
    let mut slot_of_group: std::collections::HashMap<G, usize> =
        std::collections::HashMap::new();
    let mut strata: Vec<Vec<usize>> = Vec::new();
    for (position, &label) in groups.iter().enumerate() {
        let slot = *slot_of_group.entry(label).or_insert_with(|| {
            strata.push(Vec::new());
            strata.len() - 1
        });
        strata[slot].push(position);
    }

    // Validate every stratum before consuming any randomness.
    if let Some(too_small) = strata.iter().find(|members| members.len() < size) {
        return Err(StatsError::InvalidArgument(format!(
            "Group size {} is smaller than requested sample size {}",
            too_small.len(),
            size
        )));
    }

    let mut rng = match seed {
        Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            let seed = rng.random::<u64>();
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        }
    };

    let mut result = Vec::with_capacity(strata.len() * size);
    for mut pool in strata {
        // Partial Fisher–Yates: after step `i`, `pool[..=i]` is final because
        // later swaps only touch positions >= their own step index.
        for i in 0..size {
            let j = rng.random_range(i..pool.len());
            pool.swap(i, j);
        }
        result.extend_from_slice(&pool[..size]);
    }
    Ok(Array1::from_vec(result))
}
/// Generates bootstrap resamples that preserve the size of every stratum.
///
/// For each resample, every stratum (distinct label in `groups`) contributes
/// exactly as many observations as it has in the original data, drawn with
/// replacement from that stratum only.
///
/// Strata are laid out in the order in which their labels first appear in
/// `groups`. This fixes both the column layout of the result and the order in
/// which random numbers are consumed, so a given `seed` always yields the same
/// output. (Iterating a `HashMap` directly would not, because its iteration
/// order changes between runs.)
///
/// # Returns
///
/// An `n_resamples x x.len()` array; each row is one resample with the strata
/// stored contiguously in first-appearance order.
///
/// # Errors
///
/// * `StatsError::DimensionMismatch` if `x` and `groups` differ in length.
/// * `StatsError::InvalidArgument` if `n_resamples` is zero.
#[allow(dead_code)]
pub fn stratified_bootstrap<T, G>(
    x: &ArrayView1<T>,
    groups: &ArrayView1<G>,
    n_resamples: usize,
    seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
    T: Copy + scirs2_core::numeric::Zero,
    G: Copy + Eq + std::hash::Hash,
{
    if x.len() != groups.len() {
        return Err(StatsError::DimensionMismatch(
            "Input array and group array must have the same length".to_string(),
        ));
    }
    if n_resamples == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of _resamples must be positive".to_string(),
        ));
    }

    // Bucket positions by label. The map only translates a label to its slot in
    // `strata`; the Vec preserves first-appearance order for determinism.
    let mut slot_of_group: std::collections::HashMap<G, usize> =
        std::collections::HashMap::new();
    let mut strata: Vec<Vec<usize>> = Vec::new();
    for (position, &label) in groups.iter().enumerate() {
        let slot = *slot_of_group.entry(label).or_insert_with(|| {
            strata.push(Vec::new());
            strata.len() - 1
        });
        strata[slot].push(position);
    }

    let mut rng = match seed {
        Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            let seed = rng.random::<u64>();
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        }
    };

    let mut samples = Array2::zeros((n_resamples, x.len()));
    for resample_idx in 0..n_resamples {
        // Stratum sizes sum to x.len(), so `column` ends exactly at the row end.
        let mut column = 0;
        for members in &strata {
            for _ in 0..members.len() {
                let pick = members[rng.random_range(0..members.len())];
                samples[[resample_idx, column]] = x[pick];
                column += 1;
            }
        }
    }
    Ok(samples)
}
/// Builds bootstrap resamples out of contiguous blocks of `x`.
///
/// Each row of the result is filled left to right with blocks of `blocksize`
/// consecutive observations whose start position is drawn uniformly at random.
/// The final block of a row is truncated so that the row has exactly `x.len()`
/// entries. With `circular` set, a block may start anywhere and wraps around
/// the end of the data; otherwise only starts that leave room for a whole
/// block are eligible.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` if `x` is empty, `blocksize` is zero
/// or larger than `x.len()`, or `n_resamples` is zero.
#[allow(dead_code)]
pub fn block_bootstrap<T>(
    x: &ArrayView1<T>,
    blocksize: usize,
    n_resamples: usize,
    circular: bool,
    seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
    T: Copy + scirs2_core::numeric::Zero,
{
    let data_len = x.len();

    // Preconditions in priority order; the first one that fails is reported.
    let checks = [
        (x.is_empty(), "Input array cannot be empty"),
        (blocksize == 0, "Block size must be positive"),
        (blocksize > data_len, "Block size cannot exceed array length"),
        (n_resamples == 0, "Number of _resamples must be positive"),
    ];
    if let Some(&(_, message)) = checks.iter().find(|(failed, _)| *failed) {
        return Err(StatsError::InvalidArgument(message.to_string()));
    }

    // Without an explicit seed, derive one from the thread-local generator.
    let seed_value = match seed {
        Some(given) => given,
        None => {
            let mut entropy = scirs2_core::random::thread_rng();
            entropy.random::<u64>()
        }
    };
    let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value);

    // Exclusive upper bound for admissible block starts.
    let start_bound = if circular {
        data_len
    } else {
        data_len - blocksize + 1
    };

    let mut samples = Array2::zeros((n_resamples, data_len));
    for row in 0..n_resamples {
        let mut filled = 0;
        while filled < data_len {
            let start = rng.random_range(0..start_bound);
            // Truncate the last block so the row is not overrun.
            let take = blocksize.min(data_len - filled);
            for offset in 0..take {
                let source = if circular {
                    (start + offset) % data_len
                } else {
                    start + offset
                };
                samples[[row, filled + offset]] = x[source];
            }
            filled += take;
        }
    }
    Ok(samples)
}
/// Moving-block bootstrap: resamples built from overlapping blocks of `x`.
///
/// The candidate blocks are all `x.len() - blocksize + 1` windows of
/// `blocksize` consecutive observations. Each row of the result is filled left
/// to right with uniformly chosen windows; the last one is truncated so the
/// row has exactly `x.len()` entries.
///
/// The windows are read straight from `x` rather than being copied into a
/// temporary list, so the extra memory used is independent of `blocksize`.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` if `x` is empty, `blocksize` is zero
/// or larger than `x.len()`, or `n_resamples` is zero (the same rule the other
/// block bootstraps in this file apply).
#[allow(dead_code)]
pub fn moving_block_bootstrap<T>(
    x: &ArrayView1<T>,
    blocksize: usize,
    n_resamples: usize,
    seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
    T: Copy + scirs2_core::numeric::Zero,
{
    if x.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    if blocksize == 0 || blocksize > x.len() {
        return Err(StatsError::InvalidArgument(
            "Block size must be positive and not exceed array length".to_string(),
        ));
    }
    if n_resamples == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of _resamples must be positive".to_string(),
        ));
    }

    let mut rng = match seed {
        Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            let seed = rng.random::<u64>();
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        }
    };

    let data_len = x.len();
    // Number of overlapping windows; window `s` covers x[s .. s + blocksize].
    let n_windows = data_len - blocksize + 1;

    let mut samples = Array2::zeros((n_resamples, data_len));
    for resample_idx in 0..n_resamples {
        let mut sample_pos = 0;
        while sample_pos < data_len {
            let start = rng.random_range(0..n_windows);
            // Truncate the final window so the row is not overrun.
            let take = blocksize.min(data_len - sample_pos);
            for offset in 0..take {
                samples[[resample_idx, sample_pos + offset]] = x[start + offset];
            }
            sample_pos += take;
        }
    }
    Ok(samples)
}
/// Stationary bootstrap: resamples built from circular blocks of random length.
///
/// Each row starts at a uniformly chosen position of `x`. After every copied
/// observation the current block ends with probability `p`, in which case a
/// fresh start position is drawn; otherwise the next observation (wrapping
/// around the end of the data) is copied.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` if `x` is empty, if `p` is not
/// strictly between 0 and 1 (a `NaN` is rejected as well), or if `n_resamples`
/// is zero.
#[allow(dead_code)]
pub fn stationary_bootstrap<T>(
    x: &ArrayView1<T>,
    p: f64,
    n_resamples: usize,
    seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
    T: Copy + scirs2_core::numeric::Zero,
{
    if x.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    // Written as a negated conjunction so that NaN fails the check too; with
    // `p <= 0.0 || p >= 1.0` a NaN would slip through and no block would ever end.
    if !(p > 0.0 && p < 1.0) {
        return Err(StatsError::InvalidArgument(
            "Probability parameter p must be between 0 and 1".to_string(),
        ));
    }
    if n_resamples == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of _resamples must be positive".to_string(),
        ));
    }

    let mut rng = match seed {
        Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            let seed = rng.random::<u64>();
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        }
    };

    let data_len = x.len();
    let mut samples = Array2::zeros((n_resamples, data_len));
    for resample_idx in 0..n_resamples {
        let mut current_pos = rng.random_range(0..data_len);
        for sample_pos in 0..data_len {
            samples[[resample_idx, sample_pos]] = x[current_pos];
            // No decision is needed (and no randomness consumed) after the
            // last slot of the row has been filled.
            if sample_pos + 1 < data_len {
                let u: f64 = rng.random();
                current_pos = if u < p {
                    // End of block: restart from a fresh uniform position.
                    rng.random_range(0..data_len)
                } else {
                    (current_pos + 1) % data_len
                };
            }
        }
    }
    Ok(samples)
}
/// Double (nested) bootstrap estimate of the bias of `statistic`.
///
/// A first level of `n_resamples1` bootstrap resamples is drawn from `x`. Each
/// of them is resampled again `n_resamples2` times, and the mean of the
/// second-level statistics minus the first-level statistic gives one bias
/// estimate. The overall bias is the mean of those estimates and is subtracted
/// from the statistic of the original data.
///
/// # Returns
///
/// `(bias_corrected, first_level_stats, overall_bias)` where
/// `first_level_stats` holds the statistic of each first-level resample.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` if `x` is empty or either resample
/// count is zero; errors from `statistic` and from `bootstrap` are forwarded.
#[allow(dead_code)]
pub fn double_bootstrap<T, F>(
    x: &ArrayView1<T>,
    statistic: F,
    n_resamples1: usize,
    n_resamples2: usize,
    seed: Option<u64>,
) -> StatsResult<(f64, Array1<f64>, f64)>
where
    T: Copy + scirs2_core::numeric::Zero,
    F: Fn(&ArrayView1<T>) -> StatsResult<f64> + Copy,
{
    if x.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    if n_resamples1 == 0 || n_resamples2 == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of resamples must be positive".to_string(),
        ));
    }

    let original_stat = statistic(x)?;
    let first_level_samples = bootstrap(x, n_resamples1, seed)?;

    // Generator for the second-level seeds. It is offset from the user seed so
    // it does not replay the first-level stream; `wrapping_add` keeps
    // `seed == u64::MAX` from overflowing (a panic in debug builds).
    let mut rng = match seed {
        Some(seed_value) => {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value.wrapping_add(1))
        }
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            let seed = rng.random::<u64>();
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        }
    };

    let mut first_level_stats = Array1::<f64>::zeros(n_resamples1);
    let mut bias_estimates = Array1::<f64>::zeros(n_resamples1);
    for i in 0..n_resamples1 {
        let first_sample = first_level_samples.row(i);
        let first_stat = statistic(&first_sample)?;
        first_level_stats[i] = first_stat;

        let second_seed = rng.random::<u64>();
        let second_level_samples = bootstrap(&first_sample, n_resamples2, Some(second_seed))?;
        let mut second_level_stats = Array1::<f64>::zeros(n_resamples2);
        for j in 0..n_resamples2 {
            second_level_stats[j] = statistic(&second_level_samples.row(j))?;
        }

        let second_level_mean = second_level_stats
            .mean()
            .expect("n_resamples2 > 0 was validated, so the mean exists");
        bias_estimates[i] = second_level_mean - first_stat;
    }

    let overall_bias = bias_estimates
        .mean()
        .expect("n_resamples1 > 0 was validated, so the mean exists");
    let bias_corrected = original_stat - overall_bias;
    Ok((bias_corrected, first_level_stats, overall_bias))
}
/// Bootstrap confidence intervals for `statistic` evaluated on `x`.
///
/// Three two-sided intervals at level `confidence_level` are computed from the
/// same set of `n_resamples` bootstrap replicates:
///
/// 1. **Percentile** — the `alpha/2` and `1 - alpha/2` empirical quantiles of
///    the replicates, with `alpha = 1 - confidence_level`.
/// 2. **Bias-corrected (BC)** — quantile levels shifted by the bias correction
///    `z0 = Φ⁻¹(fraction of replicates below the original statistic)`.
/// 3. **Bias-corrected and accelerated (BCa)** — as BC, additionally using the
///    acceleration constant estimated from leave-one-out (jackknife) values.
///
/// The adjusted levels are `Φ(z0 + (z0 + z) / (1 - a (z0 + z)))` with
/// `z = Φ⁻¹(alpha/2)` for the lower bound and `z = Φ⁻¹(1 - alpha/2)` for the
/// upper bound; BC uses `a = 0`. When no replicate, or every replicate, lies
/// below the original statistic, `z0` falls back to `0`.
///
/// # Returns
///
/// `(percentile_ci, bc_ci, bca_ci)`, each as `(lower, upper)`.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` if `confidence_level` is not strictly
/// between 0 and 1 (a `NaN` is rejected as well), if `x` is empty, or if
/// `n_resamples` is zero. Errors from `statistic` and `bootstrap` are forwarded.
#[allow(dead_code)]
pub fn bootstrap_confidence_intervals<T, F>(
    x: &ArrayView1<T>,
    statistic: F,
    n_resamples: usize,
    confidence_level: f64,
    seed: Option<u64>,
) -> StatsResult<((f64, f64), (f64, f64), (f64, f64))>
where
    T: Copy + scirs2_core::numeric::Zero,
    F: Fn(&ArrayView1<T>) -> StatsResult<f64> + Copy,
{
    // Standard normal quantile Φ⁻¹(p) for 0 < p < 1, using Acklam's rational
    // approximation (relative error about 1e-9). Kept local to this function
    // so it cannot clash with other items in the module.
    fn standard_normal_quantile(p: f64) -> f64 {
        const A: [f64; 6] = [
            -3.969683028665376e+01,
            2.209460984245205e+02,
            -2.759285104469687e+02,
            1.383577518672690e+02,
            -3.066479806614716e+01,
            2.506628277459239e+00,
        ];
        const B: [f64; 5] = [
            -5.447609879822406e+01,
            1.615858368580409e+02,
            -1.556989798598866e+02,
            6.680131188771972e+01,
            -1.328068155288572e+01,
        ];
        const C: [f64; 6] = [
            -7.784894002430293e-03,
            -3.223964580411365e-01,
            -2.400758277161838e+00,
            -2.549732539343734e+00,
            4.374664141464968e+00,
            2.938163982698783e+00,
        ];
        const D: [f64; 4] = [
            7.784695709041462e-03,
            3.224671290700398e-01,
            2.445134137142996e+00,
            3.754408661907416e+00,
        ];
        // Below P_LOW / above 1 - P_LOW the tail expansion is used.
        const P_LOW: f64 = 0.02425;

        let tail = |q: f64| {
            (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
                / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
        };

        if p < P_LOW {
            tail((-2.0 * p.ln()).sqrt())
        } else if p <= 1.0 - P_LOW {
            let q = p - 0.5;
            let r = q * q;
            (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
                / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
        } else {
            -tail((-2.0 * (1.0 - p).ln()).sqrt())
        }
    }

    // Negated conjunction so that NaN is rejected too.
    if !(confidence_level > 0.0 && confidence_level < 1.0) {
        return Err(StatsError::InvalidArgument(
            "Confidence _level must be between 0 and 1".to_string(),
        ));
    }
    if x.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    // Guards the `n_resamples - 1` index clamp below against underflow.
    if n_resamples == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of _resamples must be positive".to_string(),
        ));
    }

    let original_stat = statistic(x)?;
    let bootstrap_samples = bootstrap(x, n_resamples, seed)?;

    let mut sorted_stats = (0..n_resamples)
        .map(|i| statistic(&bootstrap_samples.row(i)))
        .collect::<StatsResult<Vec<f64>>>()?;
    // `total_cmp` gives a total order, so a NaN replicate cannot cause a panic.
    sorted_stats.sort_by(f64::total_cmp);

    let alpha = 1.0 - confidence_level;
    let n = n_resamples as f64;
    let last = n_resamples - 1;
    // Empirical quantile lookup: floor(level * n), clamped to the last replicate.
    let stat_at = |level: f64| sorted_stats[((level * n) as usize).min(last)];

    let percentile_ci = (stat_at(alpha / 2.0), stat_at(1.0 - alpha / 2.0));

    // Bias correction: normal quantile of the share of replicates below the
    // original estimate. At the extremes the quantile is infinite, so fall
    // back to "no correction".
    let below_original = sorted_stats
        .iter()
        .filter(|&&replicate| replicate < original_stat)
        .count();
    let z0 = if below_original > 0 && below_original < n_resamples {
        standard_normal_quantile(below_original as f64 / n)
    } else {
        0.0
    };

    // Acceleration from the jackknife (leave-one-out) statistics.
    let mut jackknife_stats = Vec::with_capacity(x.len());
    for left_out in 0..x.len() {
        let reduced: Vec<T> = x
            .iter()
            .enumerate()
            .filter(|&(j, _)| j != left_out)
            .map(|(_, &value)| value)
            .collect();
        let reduced = Array1::from_vec(reduced);
        jackknife_stats.push(statistic(&reduced.view())?);
    }
    let jk_mean = jackknife_stats.iter().sum::<f64>() / jackknife_stats.len() as f64;
    let (cubed, squared) = jackknife_stats
        .iter()
        .fold((0.0_f64, 0.0_f64), |(cubed, squared), &jk_stat| {
            let diff = jk_mean - jk_stat;
            (cubed + diff.powi(3), squared + diff.powi(2))
        });
    let acceleration = if squared > 0.0 {
        cubed / (6.0 * squared.powf(1.5))
    } else {
        0.0
    };

    // Standard normal critical values; the upper one follows by symmetry.
    let z_lower = standard_normal_quantile(alpha / 2.0);
    let z_upper = -z_lower;
    let adjusted_level =
        |z: f64, accel: f64| normal_cdf(z0 + (z0 + z) / (1.0 - accel * (z0 + z)));

    let bc_ci = (
        stat_at(adjusted_level(z_lower, 0.0)),
        stat_at(adjusted_level(z_upper, 0.0)),
    );
    let bca_ci = (
        stat_at(adjusted_level(z_lower, acceleration)),
        stat_at(adjusted_level(z_upper, acceleration)),
    );

    Ok((percentile_ci, bc_ci, bca_ci))
}
/// Standard normal cumulative distribution function, expressed through `erf`
/// as `(1 + erf(x / sqrt(2))) / 2`.
#[allow(dead_code)]
fn normal_cdf(x: f64) -> f64 {
    let scaled = x / std::f64::consts::SQRT_2;
    (1.0 + erf(scaled)) * 0.5
}
/// Polynomial-times-Gaussian approximation of the error function.
///
/// The value is computed for `|x|` as `1 - P(t) * t * exp(-x^2)` with
/// `t = 1 / (1 + 0.3275911 * |x|)` and a degree-four polynomial `P`, and the
/// sign of `x` is applied afterwards, which makes the result odd in `x`.
#[allow(dead_code)]
fn erf(x: f64) -> f64 {
    // Polynomial coefficients, highest order first, for Horner evaluation.
    const LEADING: f64 = 1.061405429;
    const REMAINING: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];
    const SCALE: f64 = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let magnitude = x.abs();

    let t = 1.0 / (1.0 + SCALE * magnitude);
    let poly = REMAINING
        .iter()
        .fold(LEADING, |acc, &coefficient| acc * t + coefficient);
    let gaussian = (-magnitude * magnitude).exp();

    sign * (1.0 - poly * t * gaussian)
}