use crate::error::Result;
use crate::traits::*;
use crate::types::FloatBounds;
use rayon::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
/// Controls whether, and for which input sizes, the parallel helpers fan work
/// out across worker threads instead of running on the calling thread.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Requested worker-thread count; `None` leaves sizing to rayon.
    ///
    /// NOTE(review): the chunking helpers query `rayon::current_num_threads()`
    /// rather than this field. `utils::initialize_thread_pool` is the routine
    /// that accepts a thread count and configures rayon's global pool.
    pub num_threads: Option<usize>,
    /// Size threshold (rows, columns or elements, depending on the helper)
    /// below which size-gated helpers stay sequential to avoid scheduling
    /// overhead.
    pub min_parallel_batch_size: usize,
    /// Master switch: when `false`, helpers always take their sequential path.
    pub enabled: bool,
}

impl Default for ParallelConfig {
    /// Parallelism on, rayon-chosen thread count, and a threshold of 1000
    /// items below which work stays on the calling thread.
    fn default() -> Self {
        let min_parallel_batch_size = 1000;
        Self {
            enabled: true,
            num_threads: None,
            min_parallel_batch_size,
        }
    }
}
/// Prediction that may split its work across multiple threads.
///
/// `X` is the input type and `Output` the prediction type.
pub trait ParallelPredict<X, Output> {
/// Predicts for `x` using implementation-chosen parallel settings.
///
/// # Errors
/// Returns an error if prediction fails.
fn predict_parallel(&self, x: &X) -> Result<Output>;
/// Predicts for `x`, with `config` governing parallel execution.
/// How each `config` field is honored is up to the implementor — confirm per impl.
///
/// # Errors
/// Returns an error if prediction fails.
fn predict_parallel_with_config(&self, x: &X, config: &ParallelConfig) -> Result<Output>;
}
/// Transformation that may split its work across multiple threads.
///
/// `Output` defaults to the input type `X`.
pub trait ParallelTransform<X, Output = X> {
/// Transforms `x` using implementation-chosen parallel settings.
///
/// # Errors
/// Returns an error if the transformation fails.
fn transform_parallel(&self, x: &X) -> Result<Output>;
/// Transforms `x`, with `config` governing parallel execution.
/// How each `config` field is honored is up to the implementor — confirm per impl.
///
/// # Errors
/// Returns an error if the transformation fails.
fn transform_parallel_with_config(&self, x: &X, config: &ParallelConfig) -> Result<Output>;
}
/// Training that may split its work across multiple threads.
///
/// Fitting consumes the unfitted estimator and yields [`ParallelFit::Fitted`].
pub trait ParallelFit<X, Y> {
/// The trained model produced by fitting.
type Fitted;
/// Fits on inputs `x` and targets `y` using implementation-chosen parallel settings.
///
/// # Errors
/// Returns an error if fitting fails.
fn fit_parallel(self, x: &X, y: &Y) -> Result<Self::Fitted>;
/// Fits on inputs `x` and targets `y`, with `config` governing parallel execution.
///
/// # Errors
/// Returns an error if fitting fails.
fn fit_parallel_with_config(
self,
x: &X,
y: &Y,
config: &ParallelConfig,
) -> Result<Self::Fitted>;
}
/// Cross-validation whose evaluations may run on multiple threads.
pub trait ParallelCrossValidation<X, Y> {
/// Numeric score type reported for each evaluation.
type Score: FloatBounds;
/// Evaluates `model` on `(x, y)` with `cv_folds` folds and returns the scores.
/// Presumably one score per fold — confirm against implementors.
///
/// The `Clone + Send + Sync` bounds on `model`, `X` and `Y` allow the
/// implementation to hand copies or shared references to worker threads.
///
/// # Errors
/// Returns an error if any evaluation fails.
fn cross_validate_parallel(
&self,
model: impl Fit<X, Y> + Clone + Send + Sync,
x: &X,
y: &Y,
cv_folds: usize,
) -> Result<Vec<Self::Score>>
where
X: Clone + Send + Sync,
Y: Clone + Send + Sync,
<Self as ParallelCrossValidation<X, Y>>::Score: Send;
}
/// Ensemble-level operations that may train or query many models concurrently.
///
/// Both functions are associated functions (no `self` receiver).
pub trait ParallelEnsemble<X, Y, Output> {
/// Fits every model in `models` on `(x, y)` and returns the trained models as
/// boxed `Predict` trait objects — presumably in input order; confirm per impl.
///
/// # Errors
/// Returns an error if fitting fails.
fn fit_ensemble_parallel(
models: Vec<impl Fit<X, Y> + Clone + Send + Sync>,
x: &X,
y: &Y,
) -> Result<Vec<Box<dyn Predict<X, Output>>>>
where
X: Clone + Send + Sync,
Y: Clone + Send + Sync;
/// Runs `predict(x)` on every model and collects the outputs.
/// Presumably one output per model, in slice order — confirm per impl.
///
/// # Errors
/// Returns an error if any prediction fails.
fn predict_ensemble_parallel(
models: &[impl Predict<X, Output> + Sync],
x: &X,
) -> Result<Vec<Output>>
where
X: Sync,
Output: Send;
}
pub fn predict_parallel_ndarray<T, M>(
model: &M,
x: &Array2<T>,
config: &ParallelConfig,
) -> Result<Array1<T>>
where
T: FloatBounds + Send + Sync,
M: Predict<Array2<T>, Array1<T>> + Sync,
{
if !config.enabled || x.nrows() < config.min_parallel_batch_size {
return model.predict(x);
}
let chunk_size = (x.nrows() / rayon::current_num_threads()).max(1);
let chunks: Vec<_> = x.axis_chunks_iter(Axis(0), chunk_size).collect();
let results: Result<Vec<_>> = chunks
.into_par_iter()
.map(|chunk| {
let chunk_array = chunk.to_owned();
model.predict(&chunk_array)
})
.collect();
let predictions = results?;
let total_len: usize = predictions.iter().map(|p| p.len()).sum();
let mut result = Array1::zeros(total_len);
let mut offset = 0;
for pred in predictions {
let end = offset + pred.len();
result
.slice_mut(scirs2_core::ndarray::s![offset..end])
.assign(&pred);
offset = end;
}
Ok(result)
}
pub struct ParallelMatrixOps;
impl ParallelMatrixOps {
pub fn matrix_multiply_parallel<T: FloatBounds + Send + Sync>(
a: &Array2<T>,
b: &Array2<T>,
config: &ParallelConfig,
) -> Array2<T> {
let (m, k) = a.dim();
let (k2, n) = b.dim();
assert_eq!(k, k2, "Matrix dimensions must match");
let mut result = Array2::zeros((m, n));
if !config.enabled || m < config.min_parallel_batch_size {
result.assign(&a.dot(b));
return result;
}
result
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(i, mut row)| {
for j in 0..n {
let mut sum = T::zero();
for ki in 0..k {
sum += a[[i, ki]] * b[[ki, j]];
}
row[j] = sum;
}
});
result
}
pub fn elementwise_op_parallel<T, F>(
a: &Array2<T>,
b: &Array2<T>,
op: F,
config: &ParallelConfig,
) -> Array2<T>
where
T: FloatBounds + Send + Sync,
F: Fn(T, T) -> T + Send + Sync,
{
assert_eq!(a.shape(), b.shape());
let mut result = Array2::zeros(a.dim());
if !config.enabled || a.len() < config.min_parallel_batch_size {
result
.iter_mut()
.zip(a.iter())
.zip(b.iter())
.for_each(|((r, &ai), &bi)| *r = op(ai, bi));
} else {
if let (Some(result_slice), Some(a_slice), Some(b_slice)) =
(result.as_slice_mut(), a.as_slice(), b.as_slice())
{
result_slice
.par_iter_mut()
.zip(a_slice.par_iter())
.zip(b_slice.par_iter())
.for_each(|((r, &ai), &bi)| *r = op(ai, bi));
} else {
result
.iter_mut()
.zip(a.iter())
.zip(b.iter())
.for_each(|((r, &ai), &bi)| *r = op(ai, bi));
}
}
result
}
pub fn apply_row_parallel<T, F>(matrix: &Array2<T>, op: F, config: &ParallelConfig) -> Array1<T>
where
T: FloatBounds + Send + Sync,
F: Fn(ArrayView1<T>) -> T + Send + Sync,
{
let mut result = Array1::zeros(matrix.nrows());
if !config.enabled || matrix.nrows() < config.min_parallel_batch_size {
result
.iter_mut()
.zip(matrix.axis_iter(Axis(0)))
.for_each(|(r, row)| *r = op(row));
} else {
if let Some(result_slice) = result.as_slice_mut() {
result_slice.par_iter_mut().enumerate().for_each(|(i, r)| {
let row = matrix.row(i);
*r = op(row);
});
} else {
result
.iter_mut()
.zip(matrix.axis_iter(Axis(0)))
.for_each(|(r, row)| *r = op(row));
}
}
result
}
pub fn apply_column_parallel<T, F>(
matrix: &Array2<T>,
op: F,
config: &ParallelConfig,
) -> Array1<T>
where
T: FloatBounds + Send + Sync,
F: Fn(ArrayView1<T>) -> T + Send + Sync,
{
let mut result = Array1::zeros(matrix.ncols());
if !config.enabled || matrix.ncols() < config.min_parallel_batch_size {
result
.iter_mut()
.zip(matrix.axis_iter(Axis(1)))
.for_each(|(r, col)| *r = op(col));
} else {
if let Some(result_slice) = result.as_slice_mut() {
result_slice.par_iter_mut().enumerate().for_each(|(j, r)| {
let col = matrix.column(j);
*r = op(col);
});
} else {
result
.iter_mut()
.zip(matrix.axis_iter(Axis(1)))
.for_each(|(r, col)| *r = op(col));
}
}
result
}
}
/// Runs repeated fit-and-score rounds for a model, optionally in parallel,
/// producing scores of type `T`.
///
/// NOTE(review): despite the `k_fold` naming, no fold splitting happens here.
/// Every round fits a fresh clone on the full `(x, y)` and scores it on that
/// same data. The `k` scores are therefore training scores (and, for a
/// deterministic model, presumably identical), not held-out estimates.
pub struct ParallelCrossValidator<T: FloatBounds> {
    config: ParallelConfig,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds> ParallelCrossValidator<T> {
    /// Creates a validator whose `config.enabled` flag decides whether rounds
    /// run in parallel.
    pub fn new(config: ParallelConfig) -> Self {
        let _phantom = std::marker::PhantomData;
        Self { _phantom, config }
    }

    /// Returns `k` scores. Each comes from fitting a fresh clone of `model`
    /// on `(x, y)` and scoring the fitted model on the same `(x, y)`.
    ///
    /// Rounds run on rayon worker threads unless `config.enabled` is false or
    /// `k < 2`; `config.min_parallel_batch_size` is not consulted.
    ///
    /// NOTE(review): the `Output` type parameter is unused by the body and
    /// cannot be inferred, so callers must name it via turbofish.
    ///
    /// # Errors
    /// Returns an error if any fit or score call fails.
    pub fn k_fold_parallel<X, Y, M, Output>(
        &self,
        model: M,
        x: &X,
        y: &Y,
        k: usize,
    ) -> Result<Vec<T>>
    where
        M: Fit<X, Y> + Clone + Send + Sync,
        M::Fitted: Score<X, Y, Float = T>,
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync,
        T: Send,
    {
        let run_sequentially = !self.config.enabled || k < 2;
        if run_sequentially {
            return self.k_fold_sequential(model, x, y, k);
        }
        (0..k)
            .into_par_iter()
            .map(|_round| {
                let trained = model.clone().fit(x, y)?;
                trained.score(x, y)
            })
            .collect()
    }

    /// Sequential counterpart of `k_fold_parallel`: runs the `k` rounds in
    /// order and stops at the first error.
    fn k_fold_sequential<X, Y, M>(&self, model: M, x: &X, y: &Y, k: usize) -> Result<Vec<T>>
    where
        M: Fit<X, Y> + Clone,
        M::Fitted: Score<X, Y, Float = T>,
    {
        (0..k)
            .map(|_round| {
                let trained = model.clone().fit(x, y)?;
                trained.score(x, y)
            })
            .collect()
    }
}
/// Helpers that train or query a collection of independent models, with one
/// rayon task per model.
pub struct ParallelEnsembleOps;

impl ParallelEnsembleOps {
    /// Fits every model in `models` on the same `(x, y)` and returns the
    /// fitted models in input order.
    ///
    /// Models are fitted in parallel only when `config.enabled` is set and
    /// there are at least two models; `config.min_parallel_batch_size` is not
    /// consulted.
    ///
    /// # Errors
    /// Returns an error if any fit fails.
    pub fn train_models_parallel<X, Y, M>(
        models: Vec<M>,
        x: &X,
        y: &Y,
        config: &ParallelConfig,
    ) -> Result<Vec<M::Fitted>>
    where
        M: Fit<X, Y> + Send,
        M::Fitted: Send,
        X: Sync,
        Y: Sync,
    {
        let fan_out = config.enabled && models.len() >= 2;
        if fan_out {
            models
                .into_par_iter()
                .map(|estimator| estimator.fit(x, y))
                .collect()
        } else {
            models
                .into_iter()
                .map(|estimator| estimator.fit(x, y))
                .collect()
        }
    }

    /// Calls `predict(x)` on every model and returns the outputs in slice
    /// order.
    ///
    /// Predictions run in parallel only when `config.enabled` is set and
    /// there are at least two models.
    ///
    /// # Errors
    /// Returns an error if any prediction fails.
    pub fn predict_parallel<X, Output, M>(
        models: &[M],
        x: &X,
        config: &ParallelConfig,
    ) -> Result<Vec<Output>>
    where
        M: Predict<X, Output> + Sync,
        Output: Send,
        X: Sync,
    {
        let fan_out = config.enabled && models.len() >= 2;
        if fan_out {
            models
                .par_iter()
                .map(|estimator| estimator.predict(x))
                .collect()
        } else {
            models
                .iter()
                .map(|estimator| estimator.predict(x))
                .collect()
        }
    }
}
/// Small helpers shared by the parallel routines in this module.
pub mod utils {
    use super::*;

    /// Returns a chunk length that spreads `total_size` items roughly evenly
    /// over rayon's current worker threads. The result is never less than
    /// `min_chunk_size` and never zero.
    ///
    /// The zero guard matters because chunking APIs such as `slice::chunks`
    /// and ndarray's `axis_chunks_iter` panic on a chunk size of 0. Without
    /// it, that size occurs when `min_chunk_size == 0` and `total_size` is
    /// smaller than the thread count.
    pub fn optimal_chunk_size(total_size: usize, min_chunk_size: usize) -> usize {
        let num_threads = rayon::current_num_threads();
        (total_size / num_threads).max(min_chunk_size).max(1)
    }

    /// Returns `true` when `config` enables parallelism and `data_size` is at
    /// least `config.min_parallel_batch_size`.
    pub fn should_use_parallel(data_size: usize, config: &ParallelConfig) -> bool {
        config.enabled && data_size >= config.min_parallel_batch_size
    }

    /// Configures rayon's global thread pool with `num_threads` workers.
    /// `None` is a no-op that leaves rayon's default sizing in place.
    ///
    /// # Errors
    /// Returns `SklearsError::NumericalError` if rayon's `build_global`
    /// fails. Per rayon's documentation, the global pool can only be
    /// initialized once, so a second successful-looking call will error.
    pub fn initialize_thread_pool(num_threads: Option<usize>) -> Result<()> {
        if let Some(threads) = num_threads {
            rayon::ThreadPoolBuilder::new()
                .num_threads(threads)
                .build_global()
                .map_err(|e| {
                    crate::error::SklearsError::NumericalError(format!(
                        "Failed to initialize thread pool: {e}"
                    ))
                })?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_parallel_matrix_multiply() {
        let a = Array2::from_shape_vec((100, 50), (0..5000).map(|x| x as f64).collect())
            .expect("valid array shape");
        let b = Array2::from_shape_vec((50, 30), (0..1500).map(|x| x as f64 + 1.0).collect())
            .expect("valid array shape");
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 10,
            num_threads: None,
        };
        let result_parallel = ParallelMatrixOps::matrix_multiply_parallel(&a, &b, &config);
        let result_sequential = a.dot(&b);
        for i in 0..result_parallel.nrows() {
            for j in 0..result_parallel.ncols() {
                assert_relative_eq!(
                    result_parallel[[i, j]],
                    result_sequential[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_parallel_elementwise_ops() {
        let a = Array2::from_shape_vec((100, 100), (0..10000).map(|x| x as f64).collect())
            .expect("valid array shape");
        let b = Array2::from_shape_vec((100, 100), (0..10000).map(|x| x as f64 + 1.0).collect())
            .expect("valid array shape");
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 100,
            num_threads: None,
        };
        let result_parallel =
            ParallelMatrixOps::elementwise_op_parallel(&a, &b, |x, y| x + y, &config);
        let result_sequential = &a + &b;
        for i in 0..result_parallel.nrows() {
            for j in 0..result_parallel.ncols() {
                assert_relative_eq!(
                    result_parallel[[i, j]],
                    result_sequential[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_apply_row_and_column_parallel() {
        // Integer-valued floats keep every sum exact regardless of order.
        let m = Array2::from_shape_vec((40, 30), (0..1200).map(|x| x as f64).collect())
            .expect("valid array shape");
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 10,
            num_threads: None,
        };
        let row_sums = ParallelMatrixOps::apply_row_parallel(&m, |r| r.sum(), &config);
        let col_sums = ParallelMatrixOps::apply_column_parallel(&m, |c| c.sum(), &config);
        assert_eq!(row_sums, m.sum_axis(Axis(1)));
        assert_eq!(col_sums, m.sum_axis(Axis(0)));
    }

    #[test]
    fn test_optimal_chunk_size() {
        let num_threads = rayon::current_num_threads();
        assert_eq!(
            utils::optimal_chunk_size(1000, 10),
            (1000 / num_threads).max(10)
        );
        // Derive the expectation from the thread count: with a single worker
        // thread the even split (100) exceeds the minimum (50).
        assert_eq!(
            utils::optimal_chunk_size(100, 50),
            (100 / num_threads).max(50)
        );
        // The minimum acts as a floor for tiny inputs.
        assert_eq!(utils::optimal_chunk_size(1, 50), 50);
    }

    #[test]
    fn test_should_use_parallel() {
        let config = ParallelConfig::default();
        assert!(!utils::should_use_parallel(100, &config));
        assert!(utils::should_use_parallel(1000, &config));
        assert!(utils::should_use_parallel(2000, &config));
        let disabled_config = ParallelConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(!utils::should_use_parallel(2000, &disabled_config));
    }
}