use crate::error::{StatsError, StatsResult};
use crate::error_standardization::ErrorMessages;
use crate::{mean, quantile, var, QuantileInterpolation};
use scirs2_core::ndarray::{s, Array1, ArrayBase, ArrayView1, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::parallel_ops::{num_threads, par_chunks, parallel_map, ParallelIterator};
/// Work-size cut-off below which the functions in this module skip their
/// parallel path and call the sequential implementation instead.
///
/// Each function measures "work" in its own way: element count for 1-D
/// inputs, `n_vars * n_vars` for the correlation matrix, and rows/columns or
/// samples/length combinations for the row-wise and bootstrap helpers.
const PARALLEL_THRESHOLD: usize = 10_000;
/// Computes the arithmetic mean of `x`, splitting the summation across threads
/// for large inputs.
///
/// Inputs shorter than `PARALLEL_THRESHOLD` are delegated to the sequential
/// [`mean`]. Larger inputs are summed in chunks of at least 1000 elements in
/// parallel when they are contiguous in memory, and with a sequential fold
/// otherwise.
///
/// # Errors
///
/// Returns an error if `x` is empty, and propagates any error from [`mean`]
/// on the small-input path.
#[allow(dead_code)]
pub fn mean_parallel<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + std::iter::Sum<F>
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps,
    D: Data<Elem = F> + Sync,
{
    if x.is_empty() {
        return Err(ErrorMessages::empty_array("x"));
    }

    let len = x.len();
    if len < PARALLEL_THRESHOLD {
        return mean(&x.view());
    }

    // Give each worker roughly an equal share, but never less than 1000 items.
    let per_chunk = (len / num_threads()).max(1000);
    let add = |a: F, b: F| a + b;

    let total = match x.as_slice() {
        Some(contiguous) => par_chunks(contiguous, per_chunk)
            .map(|chunk| chunk.iter().copied().fold(F::zero(), add))
            .reduce(F::zero, add),
        // Strided views have no backing slice; sum them on this thread.
        None => x.iter().copied().fold(F::zero(), add),
    };

    let count = F::from(len).expect("Failed to convert to float");
    Ok(total / count)
}
/// Computes the variance of `x` with `ddof` delta degrees of freedom,
/// parallelizing the squared-deviation sum for large inputs.
///
/// The result is `sum((x_i - mean)^2) / (n - ddof)`. Inputs shorter than
/// `PARALLEL_THRESHOLD` are delegated to the sequential [`var`]. For larger
/// inputs the mean comes from [`mean_parallel`]. The squared deviations are
/// then summed in parallel chunks when `x` is contiguous in memory, and
/// sequentially otherwise (e.g. strided or reversed views), mirroring
/// `mean_parallel`.
///
/// # Errors
///
/// Returns an invalid-argument error if `x.len() <= ddof`. This includes an
/// empty `x` with `ddof == 0`. Errors from [`var`] and [`mean_parallel`] are
/// propagated.
#[allow(dead_code)]
pub fn variance_parallel<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + std::iter::Sum<F>
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps,
    D: Data<Elem = F> + Sync,
{
    let n = x.len();
    if n <= ddof {
        return Err(StatsError::invalid_argument(
            "Not enough data points for the given degrees of freedom",
        ));
    }
    if n < PARALLEL_THRESHOLD {
        return var(&x.view(), ddof, None);
    }

    let mean_val = mean_parallel(x)?;
    let sq_dev = |val: F| {
        let dev = val - mean_val;
        dev * dev
    };

    let sum_sq_dev: F = match x.as_slice() {
        Some(slice) => {
            let chunksize = (n / num_threads()).max(1000);
            par_chunks(slice, chunksize)
                .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + sq_dev(val)))
                .reduce(|| F::zero(), |a, b| a + b)
        }
        // Non-contiguous views have no backing slice. Sum them sequentially
        // rather than panicking on `as_slice()`.
        None => x.iter().fold(F::zero(), |acc, &val| acc + sq_dev(val)),
    };

    Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
}
/// Computes several quantiles of `x`, evaluating them concurrently for large
/// inputs.
///
/// Each requested quantile `q` must lie in `[0, 1]`. Every quantile is
/// computed with [`quantile`] using the given interpolation `method`, on both
/// the sequential path (fewer than `PARALLEL_THRESHOLD` elements or fewer
/// than 4 quantiles) and the parallel path. The result therefore does not
/// depend on which path is taken. The returned array has one entry per
/// element of `quantiles`, in the same order.
///
/// # Errors
///
/// * An invalid-argument error if `x` is empty.
/// * A domain error if any `q` is outside `[0, 1]` or is NaN.
/// * Any error reported by [`quantile`].
#[allow(dead_code)]
pub fn quantiles_parallel<F, D>(
    x: &ArrayBase<D, Ix1>,
    quantiles: &[F],
    method: QuantileInterpolation,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + Send + Sync + std::fmt::Display,
    D: Data<Elem = F> + Sync,
{
    if x.is_empty() {
        return Err(StatsError::invalid_argument(
            "Cannot compute quantiles of empty array",
        ));
    }
    // `contains` is false for NaN, so NaN quantiles are rejected here instead
    // of failing later during index conversion.
    let unit_interval = F::zero()..=F::one();
    if quantiles.iter().any(|q| !unit_interval.contains(q)) {
        return Err(StatsError::domain("Quantiles must be between 0 and 1"));
    }

    let view = x.view();
    if x.len() < PARALLEL_THRESHOLD || quantiles.len() < 4 {
        let mut results = Array1::zeros(quantiles.len());
        for (i, &q) in quantiles.iter().enumerate() {
            results[i] = quantile(&view, q, method)?;
        }
        return Ok(results);
    }

    // Delegate each quantile to `quantile` so that `method` is honoured and
    // NaN data / non-standard memory layouts are handled exactly as on the
    // sequential path.
    let results = parallel_map(quantiles, |&q| quantile(&view, q, method))
        .into_iter()
        .collect::<StatsResult<Vec<F>>>()?;
    Ok(Array1::from_vec(results))
}
/// Applies `stat_fn` to every row of the 2-D array `data`, evaluating rows in
/// parallel when the array is large.
///
/// The parallel path is used only when `nrows * ncols` reaches
/// `PARALLEL_THRESHOLD`. Smaller inputs, including arrays with zero columns,
/// are processed sequentially. Results are returned in row order.
///
/// # Errors
///
/// Propagates the first (in row order) error returned by `stat_fn`.
#[allow(dead_code)]
pub fn row_statistics_parallel<F, D, S>(
    data: &ArrayBase<D, Ix2>,
    stat_fn: S,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + Send + Sync,
    D: Data<Elem = F> + Sync,
    // No `Display` bound: closures never implement `Display`, so requiring it
    // made this function impossible to call with a closure.
    S: Fn(&ArrayView1<F>) -> StatsResult<F> + Send + Sync,
{
    let nrows = data.nrows();
    // Multiply rather than divide by `ncols`, which would divide by zero for
    // arrays without columns.
    if nrows.saturating_mul(data.ncols()) < PARALLEL_THRESHOLD {
        return data
            .outer_iter()
            .map(|row| stat_fn(&row))
            .collect::<StatsResult<Vec<F>>>()
            .map(Array1::from_vec);
    }

    let row_indices: Vec<usize> = (0..nrows).collect();
    let results = parallel_map(&row_indices, |&i| stat_fn(&data.slice(s![i, ..])))
        .into_iter()
        .collect::<StatsResult<Vec<F>>>()?;
    Ok(Array1::from_vec(results))
}
/// Computes the Pearson correlation matrix of the columns of `data`,
/// evaluating the column pairs in parallel for wide inputs.
///
/// Let `n_vars = data.ncols()`. When `n_vars * n_vars` is below
/// `PARALLEL_THRESHOLD`, the call is forwarded to
/// `crate::corrcoef(.., "pearson")`. Otherwise `pearson_r` is evaluated in
/// parallel for every column pair `i < j`. Each coefficient is written to both
/// `(i, j)` and `(j, i)`, and the diagonal is set to 1.
///
/// # Errors
///
/// Propagates the first error returned by `crate::corrcoef` or `pearson_r`.
#[allow(dead_code)]
pub fn corrcoef_parallel<F, D>(
    data: &ArrayBase<D, Ix2>,
) -> StatsResult<scirs2_core::ndarray::Array2<F>>
where
    F: Float
        + NumCast
        + Send
        + Sync
        + std::iter::Sum<F>
        + std::fmt::Debug
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps
        + 'static,
    D: Data<Elem = F> + Sync,
{
    use crate::pearson_r;
    use scirs2_core::ndarray::Array2;

    let dim = data.ncols();
    if dim * dim < PARALLEL_THRESHOLD {
        return crate::corrcoef(&data.view(), "pearson");
    }

    // Strictly-upper-triangle column pairs; the diagonal is filled by `eye`.
    let mut upper: Vec<(usize, usize)> = Vec::new();
    for row in 0..dim {
        for col in (row + 1)..dim {
            upper.push((row, col));
        }
    }

    let coefficients = parallel_map(&upper, |&(row, col)| {
        pearson_r(&data.slice(s![.., row]), &data.slice(s![.., col])).map(|r| (row, col, r))
    })
    .into_iter()
    .collect::<StatsResult<Vec<(usize, usize, F)>>>()?;

    let mut matrix = Array2::<F>::eye(dim);
    for (row, col, r) in coefficients {
        matrix[(row, col)] = r;
        matrix[(col, row)] = r;
    }
    Ok(matrix)
}
/// Evaluates `statistic` on `n_samples_` bootstrap resamples of `data`,
/// drawing and evaluating the resamples in parallel for large workloads.
///
/// When `n_samples_ * data.len()` is below `PARALLEL_THRESHOLD`, all resamples
/// are drawn with a single `bootstrap(.., n_samples_, seed)` call and evaluated
/// sequentially. This also covers empty `data`, which is left to `bootstrap`
/// to handle.
///
/// On the parallel path, resample `i` is drawn with
/// `bootstrap(.., 1, Some(base + i))` and evaluated concurrently. Here `base`
/// is `seed` when given and a fresh random value otherwise. The two paths draw
/// resamples differently, so the same `seed` does not reproduce identical
/// output across the threshold.
///
/// # Errors
///
/// Propagates errors from `bootstrap` and from `statistic`.
#[allow(dead_code)]
pub fn bootstrap_parallel<F, S>(
    data: &Array1<F>,
    n_samples_: usize,
    statistic: S,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + Send + Sync,
    // No `Display` bound: closures never implement `Display`, so requiring it
    // made this function impossible to call with a closure.
    S: Fn(&ArrayBase<scirs2_core::ndarray::ViewRepr<&F>, Ix1>) -> StatsResult<F>
        + Send
        + Sync,
{
    use crate::sampling::bootstrap;
    use std::hash::{BuildHasher, Hasher};

    // Multiply rather than divide by `data.len()`, which would divide by zero
    // for empty input.
    if n_samples_.saturating_mul(data.len()) < PARALLEL_THRESHOLD {
        let samples = bootstrap(&data.view(), n_samples_, seed)?;
        let mut results = Array1::zeros(n_samples_);
        for (i, sample) in samples.outer_iter().enumerate() {
            results[i] = statistic(&sample)?;
        }
        return Ok(results);
    }

    // Without an explicit seed, draw a random base so that repeated unseeded
    // calls produce independent resamples instead of always reusing one fixed
    // seed. `RandomState` is std's randomly-keyed hasher builder.
    let base_seed = seed.unwrap_or_else(|| {
        std::collections::hash_map::RandomState::new()
            .build_hasher()
            .finish()
    });
    let seeds: Vec<u64> = (0..n_samples_)
        .map(|i| base_seed.wrapping_add(i as u64))
        .collect();

    let results = parallel_map(&seeds, |&sample_seed| {
        let sample = bootstrap(&data.view(), 1, Some(sample_seed))?;
        statistic(&sample.slice(s![0, ..]))
    })
    .into_iter()
    .collect::<StatsResult<Vec<F>>>()?;
    Ok(Array1::from_vec(results))
}