use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumCast, One, Zero};
use scirs2_core::random::{Rng, RngExt};
use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
use std::sync::Arc;
/// Tuning knobs for [`EnhancedParallelProcessor`].
#[derive(Debug, Clone)]
pub struct EnhancedParallelConfig {
/// Worker count used when sizing chunks; `None` falls back to `num_cpus::get()`.
pub num_threads: Option<usize>,
/// Mean, variance and quantile sorting use sequential code for inputs shorter than this;
/// it is also the lower bound on the length of a parallel chunk.
pub min_chunksize: usize,
/// Presumably an upper bound on the number of parallel chunks (defaults to
/// `num_cpus::get() * 4`). NOTE(review): confirm which code paths honour this value.
pub max_chunks: usize,
/// NOTE(review): not read by `EnhancedParallelProcessor` in this file — confirm whether
/// anything else consults it.
pub numa_aware: bool,
/// NOTE(review): not read by `EnhancedParallelProcessor` in this file — confirm whether
/// anything else consults it.
pub work_stealing: bool,
}
impl Default for EnhancedParallelConfig {
    /// Auto-detected thread count, 1000-element minimum chunks, at most four chunks per
    /// logical CPU, with the NUMA and work-stealing flags switched on.
    fn default() -> Self {
        let logical_cpus = num_cpus::get();
        let max_chunks = logical_cpus * 4;
        Self {
            num_threads: None,
            min_chunksize: 1000,
            max_chunks,
            numa_aware: true,
            work_stealing: true,
        }
    }
}
/// Parallel implementations of common descriptive statistics over `ndarray` views of `F`.
pub struct EnhancedParallelProcessor<F> {
/// Chunking / threading settings consulted by the processor's methods.
config: EnhancedParallelConfig,
/// Ties the processor to the element type `F` without storing any `F` values.
_phantom: std::marker::PhantomData<F>,
}
impl<F> Default for EnhancedParallelProcessor<F>
where
F: Float
+ NumCast
+ SimdUnifiedOps
+ Zero
+ One
+ PartialOrd
+ Copy
+ Send
+ Sync
+ std::fmt::Display
+ std::iter::Sum<F>
+ scirs2_core::numeric::FromPrimitive,
{
fn default() -> Self {
Self::new()
}
}
impl<F> EnhancedParallelProcessor<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    /// Creates a processor that uses [`EnhancedParallelConfig::default`].
    pub fn new() -> Self {
        Self::with_config(EnhancedParallelConfig::default())
    }

    /// Creates a processor that uses the given configuration.
    pub fn with_config(config: EnhancedParallelConfig) -> Self {
        Self {
            config,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Arithmetic mean of `data`.
    ///
    /// Inputs with at least `config.min_chunksize` elements are summed in parallel chunks.
    /// Non-contiguous views (for example a column of a row-major matrix) are supported.
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `data` (presumably NaN / ±inf) or if
    /// `data` is empty.
    pub fn mean_parallel_enhanced(&self, data: &ArrayView1<F>) -> StatsResult<F> {
        checkarray_finite(data, "data")?;
        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }
        Ok(self.mean_unchecked(data))
    }

    /// Variance of `data` with `ddof` delta degrees of freedom: the sum of squared deviations
    /// from the mean divided by `data.len() - ddof` (`ddof = 1` gives the sample variance).
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `data`, if `data` is empty, or if
    /// `ddof >= data.len()`. The same error is returned for small and large inputs.
    pub fn variance_parallel_enhanced(&self, data: &ArrayView1<F>, ddof: usize) -> StatsResult<F> {
        checkarray_finite(data, "data")?;
        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }
        // Checked for every input size so the sequential path cannot divide by zero and
        // quietly return NaN / inf.
        let denominator = data.len().saturating_sub(ddof);
        if denominator == 0 {
            return Err(StatsError::InvalidArgument(
                "Insufficient degrees of freedom".to_string(),
            ));
        }
        let mean = self.mean_unchecked(data);
        let sum_sq_diff = self.sum_sq_dev(data, mean);
        Ok(sum_sq_diff / F::from(denominator).expect("Failed to convert to float"))
    }

    /// Pearson correlation matrix of the columns of `matrix` (rows are samples, columns are
    /// features).
    ///
    /// The diagonal is always one. A pair involving a constant column, or a matrix with a
    /// single row, gets zero.
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `matrix`, if it has fewer than two
    /// columns, or if it has no rows.
    pub fn correlation_matrix_parallel(&self, matrix: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        checkarray_finite(matrix, "matrix")?;
        let (n_samples, n_features) = matrix.dim();
        if n_features < 2 {
            return Err(StatsError::InvalidArgument(
                "At least 2 features required for correlation matrix".to_string(),
            ));
        }
        if n_samples == 0 {
            return Err(StatsError::InvalidArgument(
                "Matrix must contain at least one sample".to_string(),
            ));
        }
        let means =
            parallel_map_collect(0..n_features, |i| self.mean_unchecked(&matrix.column(i)));

        // The matrix is symmetric with a unit diagonal, so only the strict upper triangle is
        // computed and mirrored.
        let pairs: Vec<(usize, usize)> = (0..n_features)
            .flat_map(|i| ((i + 1)..n_features).map(move |j| (i, j)))
            .collect();
        let correlations = parallel_map_collect(&pairs, |&(i, j)| {
            let corr = self
                .correlation_coefficient(&matrix.column(i), &matrix.column(j), means[i], means[j])
                .expect("columns of one matrix always have equal length");
            (i, j, corr)
        });

        let mut corr_matrix = Array2::<F>::eye(n_features);
        for (i, j, corr) in correlations {
            corr_matrix[[i, j]] = corr;
            corr_matrix[[j, i]] = corr;
        }
        Ok(corr_matrix)
    }

    /// Bootstrap distribution of `statistic_fn`.
    ///
    /// Draws `n_bootstrap` resamples of `data` (with replacement, each the same length as
    /// `data`) and applies `statistic_fn` to each. Replicate `i` seeds its own generator with
    /// `seed.unwrap_or(0).wrapping_add(i)`, so replicates never share RNG state.
    ///
    /// NOTE(review): `seed: None` produces exactly the same output as `Some(0)` — confirm
    /// whether an entropy-based seed was intended for `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `data` or if `check_positive` rejects
    /// `n_bootstrap` (presumably when it is zero).
    pub fn bootstrap_parallel_enhanced(
        &self,
        data: &ArrayView1<F>,
        n_bootstrap: usize,
        statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
        seed: Option<u64>,
    ) -> StatsResult<Array1<F>> {
        checkarray_finite(data, "data")?;
        check_positive(n_bootstrap, "n_bootstrap")?;
        let statistic_fn = Arc::new(statistic_fn);
        let data_arc = Arc::new(data.to_owned());
        let results = parallel_map_collect(0..n_bootstrap, |i| {
            use scirs2_core::random::Random;
            let mut rng = Random::seed(seed.unwrap_or(0).wrapping_add(i as u64));
            let n = data_arc.len();
            let mut bootstrap_sample = Array1::zeros(n);
            for j in 0..n {
                let idx = rng.random_range(0..n);
                bootstrap_sample[j] = data_arc[idx];
            }
            statistic_fn(&bootstrap_sample.view())
        });
        Ok(Array1::from_vec(results))
    }

    /// Per-row, per-column and whole-matrix means and sample variances (`ddof = 1`).
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `matrix`, or if it has fewer than two
    /// rows or fewer than two columns. A sample variance needs at least two values along
    /// every row and column.
    pub fn matrix_operations_parallel(
        &self,
        matrix: &ArrayView2<F>,
    ) -> StatsResult<MatrixParallelResult<F>> {
        checkarray_finite(matrix, "matrix")?;
        let (rows, cols) = matrix.dim();
        if rows < 2 || cols < 2 {
            return Err(StatsError::InvalidArgument(
                "Matrix must have at least 2 rows and 2 columns to compute sample variances"
                    .to_string(),
            ));
        }
        let row_denom = F::from(cols - 1).expect("Failed to convert to float");
        let col_denom = F::from(rows - 1).expect("Failed to convert to float");

        // Each lane's mean is computed once and reused for its variance.
        let row_stats: Vec<(F, F)> = parallel_map_collect(0..rows, |i| {
            let row = matrix.row(i);
            let mean = self.mean_unchecked(&row);
            (mean, self.sum_sq_dev(&row, mean) / row_denom)
        });
        let col_stats: Vec<(F, F)> = parallel_map_collect(0..cols, |j| {
            let col = matrix.column(j);
            let mean = self.mean_unchecked(&col);
            (mean, self.sum_sq_dev(&col, mean) / col_denom)
        });
        let (row_means, row_vars): (Vec<F>, Vec<F>) = row_stats.into_iter().unzip();
        let (col_means, col_vars): (Vec<F>, Vec<F>) = col_stats.into_iter().unzip();

        let flattened: Array1<F> = matrix.iter().copied().collect();
        let flat_view = flattened.view();
        let overall_mean = self.mean_unchecked(&flat_view);
        let overall_var = self.sum_sq_dev(&flat_view, overall_mean)
            / F::from(rows * cols - 1).expect("Failed to convert to float");

        Ok(MatrixParallelResult {
            row_means: Array1::from_vec(row_means),
            row_vars: Array1::from_vec(row_vars),
            col_means: Array1::from_vec(col_means),
            col_vars: Array1::from_vec(col_vars),
            overall_mean,
            overall_var,
            shape: (rows, cols),
        })
    }

    /// Quantiles of `data` using linear interpolation between the two closest ranks.
    ///
    /// For each `q` the fractional index is `q * (len - 1)` into the sorted data. Sorting
    /// runs in parallel when `data.len() >= config.min_chunksize`.
    ///
    /// # Errors
    ///
    /// Returns an error if `checkarray_finite` rejects `data`, if `data` is empty, or if any
    /// quantile is NaN or outside `[0, 1]`.
    pub fn quantiles_parallel(
        &self,
        data: &ArrayView1<F>,
        quantiles: &[F],
    ) -> StatsResult<Array1<F>> {
        checkarray_finite(data, "data")?;
        if data.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Data cannot be empty".to_string(),
            ));
        }
        // NaN must be rejected explicitly: it fails both range comparisons.
        if quantiles
            .iter()
            .any(|&q| q.is_nan() || q < F::zero() || q > F::one())
        {
            return Err(StatsError::InvalidArgument(
                "Quantiles must be in [0, 1]".to_string(),
            ));
        }

        // `to_vec` always yields a contiguous buffer, whatever the strides of `data`.
        let mut sorted = data.to_vec();
        let by_value =
            |a: &F, b: &F| a.partial_cmp(b).expect("data must not contain NaN values");
        if sorted.len() >= self.config.min_chunksize {
            sorted.par_sort_by(by_value);
        } else {
            sorted.sort_by(by_value);
        }

        let last = F::from(sorted.len() - 1).expect("Failed to convert to float");
        let results = quantiles
            .iter()
            .map(|&q| {
                let index = (q * last).to_f64().expect("Operation failed");
                let lower = index.floor() as usize;
                let upper = index.ceil() as usize;
                if lower == upper {
                    sorted[lower]
                } else {
                    let weight = F::from(index - index.floor()).expect("Operation failed");
                    sorted[lower] * (F::one() - weight) + sorted[upper] * weight
                }
            })
            .collect::<Vec<F>>();
        Ok(Array1::from_vec(results))
    }

    /// Chunk length for splitting `datalen` elements across worker threads.
    ///
    /// Aims for about two chunks per thread, capped by `config.max_chunks`. The result is
    /// never below `config.min_chunksize` and never above `datalen`, and it is always at
    /// least 1 because `par_chunks` panics on a zero chunk size. A configured thread count
    /// or chunk cap of zero is treated as one.
    fn calculate_optimal_chunksize(&self, datalen: usize) -> usize {
        let num_threads = self
            .config
            .num_threads
            .unwrap_or_else(num_cpus::get)
            .max(1);
        let target_chunks = num_threads
            .saturating_mul(2)
            .min(self.config.max_chunks.max(1));
        (datalen / target_chunks)
            .max(self.config.min_chunksize)
            .min(datalen)
            .max(1)
    }

    /// Pearson correlation of `x` and `y`, given their precomputed means.
    ///
    /// Returns zero when there are fewer than two elements or either input has zero spread.
    ///
    /// # Errors
    ///
    /// Returns `DimensionMismatch` if `x` and `y` differ in length.
    fn correlation_coefficient(
        &self,
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        mean_x: F,
        mean_y: F,
    ) -> StatsResult<F> {
        if x.len() != y.len() {
            return Err(StatsError::DimensionMismatch(
                "Arrays must have the same length".to_string(),
            ));
        }
        let n = x.len();
        if n < 2 {
            return Ok(F::zero());
        }
        let chunksize = self.calculate_optimal_chunksize(n);
        let result = parallel_map_reduce_indexed(
            0..n,
            chunksize,
            |indices| {
                let mut sum_xy = F::zero();
                let mut sum_x2 = F::zero();
                let mut sum_y2 = F::zero();
                for &i in indices {
                    let dx = x[i] - mean_x;
                    let dy = y[i] - mean_y;
                    sum_xy = sum_xy + dx * dy;
                    sum_x2 = sum_x2 + dx * dx;
                    sum_y2 = sum_y2 + dy * dy;
                }
                (sum_xy, sum_x2, sum_y2)
            },
            |(xy1, x2_1, y2_1), (xy2, x2_2, y2_2)| (xy1 + xy2, x2_1 + x2_2, y2_1 + y2_2),
        );
        let (sum_xy, sum_x2, sum_y2) = result;
        let denom = (sum_x2 * sum_y2).sqrt();
        if denom > F::zero() {
            Ok(sum_xy / denom)
        } else {
            Ok(F::zero())
        }
    }

    /// Mean of `data` without validation; callers must ensure it is non-empty.
    fn mean_unchecked(&self, data: &ArrayView1<F>) -> F {
        let n = F::from(data.len()).expect("Failed to convert to float");
        self.sum_mapped(data, |x| x) / n
    }

    /// Sum of squared deviations of `data` from `mean`.
    fn sum_sq_dev(&self, data: &ArrayView1<F>, mean: F) -> F {
        self.sum_mapped(data, move |x| {
            let d = x - mean;
            d * d
        })
    }

    /// Sums `f(x)` over every element of `data`.
    ///
    /// Inputs shorter than `config.min_chunksize` are summed sequentially. Longer ones are
    /// split into parallel chunks. A non-contiguous view is first copied into a temporary
    /// `Vec`, because `par_chunks` needs a slice and `as_slice` returns `None` for such views.
    fn sum_mapped<G>(&self, data: &ArrayView1<F>, f: G) -> F
    where
        G: Fn(F) -> F + Send + Sync,
    {
        let n = data.len();
        if n < self.config.min_chunksize {
            return data.iter().map(|&x| f(x)).sum();
        }
        let contiguous_copy;
        let slice: &[F] = match data.as_slice() {
            Some(slice) => slice,
            None => {
                contiguous_copy = data.to_vec();
                contiguous_copy.as_slice()
            }
        };
        let chunksize = self.calculate_optimal_chunksize(n);
        slice
            .par_chunks(chunksize)
            .map(|chunk| chunk.iter().map(|&x| f(x)).sum::<F>())
            .reduce(F::zero, |a, b| a + b)
    }
}
/// Summary statistics returned by `EnhancedParallelProcessor::matrix_operations_parallel`.
#[derive(Debug, Clone)]
pub struct MatrixParallelResult<F> {
/// Mean of each row, in row order.
pub row_means: Array1<F>,
/// Sample variance (ddof = 1) of each row, in row order.
pub row_vars: Array1<F>,
/// Mean of each column, in column order.
pub col_means: Array1<F>,
/// Sample variance (ddof = 1) of each column, in column order.
pub col_vars: Array1<F>,
/// Mean over every element of the matrix.
pub overall_mean: F,
/// Sample variance (ddof = 1) over every element of the matrix.
pub overall_var: F,
/// `(rows, cols)` of the input matrix.
pub shape: (usize, usize),
}
/// Mean of `data` computed by a default-configured [`EnhancedParallelProcessor`].
///
/// # Errors
///
/// Same as `EnhancedParallelProcessor::mean_parallel_enhanced`.
#[allow(dead_code)]
pub fn mean_parallel_advanced<F>(data: &ArrayView1<F>) -> StatsResult<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    EnhancedParallelProcessor::<F>::default().mean_parallel_enhanced(data)
}
/// Variance of `data` with `ddof` delta degrees of freedom, computed by a
/// default-configured [`EnhancedParallelProcessor`].
///
/// # Errors
///
/// Same as `EnhancedParallelProcessor::variance_parallel_enhanced`.
#[allow(dead_code)]
pub fn variance_parallel_advanced<F>(data: &ArrayView1<F>, ddof: usize) -> StatsResult<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    EnhancedParallelProcessor::<F>::default().variance_parallel_enhanced(data, ddof)
}
/// Column-wise correlation matrix of `matrix`, computed by a default-configured
/// [`EnhancedParallelProcessor`].
///
/// # Errors
///
/// Same as `EnhancedParallelProcessor::correlation_matrix_parallel`.
#[allow(dead_code)]
pub fn correlation_matrix_parallel_advanced<F>(matrix: &ArrayView2<F>) -> StatsResult<Array2<F>>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    EnhancedParallelProcessor::<F>::default().correlation_matrix_parallel(matrix)
}
/// Bootstrap distribution of `statistic_fn` over `n_bootstrap` resamples of `data`, computed
/// by a default-configured [`EnhancedParallelProcessor`].
///
/// # Errors
///
/// Same as `EnhancedParallelProcessor::bootstrap_parallel_enhanced`.
#[allow(dead_code)]
pub fn bootstrap_parallel_advanced<F>(
    data: &ArrayView1<F>,
    n_bootstrap: usize,
    statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
    seed: Option<u64>,
) -> StatsResult<Array1<F>>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::numeric::FromPrimitive,
{
    EnhancedParallelProcessor::<F>::default().bootstrap_parallel_enhanced(
        data,
        n_bootstrap,
        statistic_fn,
        seed,
    )
}