use super::Tensor;
use crate::error::RusTorchResult;
/// Result alias used by the parallel-operation traits in this file; it is exactly
/// `RusTorchResult<T>` under a shorter, local name.
type ParallelResult<T> = RusTorchResult<T>;
use num_traits::Float;
use rayon::prelude::*;
/// Base trait for parallel tensor operations.
///
/// It carries a single policy knob: the element count from which an operation is
/// considered large enough to be worth running in parallel.
pub trait ParallelOp<T>
where
    T: Float + Send + Sync + Clone + 'static,
{
    /// Smallest size for which [`should_parallelize`](Self::should_parallelize) answers
    /// `true`. Implementors may override the default of 1000.
    const MIN_PARALLEL_SIZE: usize = 1000;

    /// Returns `true` when `size` has reached [`Self::MIN_PARALLEL_SIZE`] (inclusive).
    fn should_parallelize(&self, size: usize) -> bool {
        let threshold = Self::MIN_PARALLEL_SIZE;
        threshold <= size
    }
}
/// Batch-oriented operations layered on top of [`ParallelOp`].
///
/// This trait only declares signatures; the actual semantics are defined by implementors.
pub trait BatchParallelOp<T>: ParallelOp<T>
where
    T: Float + Send + Sync + Clone + 'static,
{
    /// Combines `self` with `other` through the binary closure `op`, returning a new tensor
    /// or an error.
    ///
    /// NOTE(review): the name suggests `op` is applied pairwise to matching elements and that
    /// the error covers shape mismatches; confirm against the implementation.
    fn batch_elementwise_op<F>(&self, other: &Tensor<T>, op: F) -> ParallelResult<Tensor<T>>
    where
        F: Fn(T, T) -> T + Send + Sync;

    /// Combines `self` with a single `scalar` through the binary closure `op`.
    ///
    /// Infallible by signature. NOTE(review): the argument order passed to `op`
    /// (element first or scalar first) is not visible here; confirm with implementors.
    fn batch_scalar_op<F>(&self, scalar: T, op: F) -> Tensor<T>
    where
        F: Fn(T, T) -> T + Send + Sync;

    /// Produces a normalized tensor.
    ///
    /// NOTE(review): `epsilon` is presumably a small stabilizer added to the denominator;
    /// the exact normalization formula is implementation-defined.
    fn batch_normalize(&self, epsilon: T) -> Tensor<T>;
}
/// Matrix-style operations layered on top of [`ParallelOp`].
///
/// This trait only declares signatures; layouts and error conditions are decided by
/// implementors.
pub trait MatrixParallelOp<T>: ParallelOp<T>
where
    T: Float + Send + Sync + Clone + 'static,
{
    /// Multiplies `self` by `other`, returning the product tensor or an error.
    ///
    /// NOTE(review): presumably a batched matrix product with leading batch dimensions;
    /// the expected shapes are not visible here, so confirm with implementors.
    fn batch_matmul(&self, other: &Tensor<T>) -> ParallelResult<Tensor<T>>;

    /// Convolves `self` with `kernel` using the given `stride` and `padding`.
    ///
    /// NOTE(review): a single `usize` is taken for each of `stride` and `padding`, which
    /// suggests the same value is used for both spatial axes; confirm with implementors.
    fn batch_conv2d(
        &self,
        kernel: &Tensor<T>,
        stride: usize,
        padding: usize,
    ) -> ParallelResult<Tensor<T>>;
}
/// Reductions along a tensor dimension, layered on top of [`ParallelOp`].
pub trait ReductionParallelOp<T>: ParallelOp<T>
where
    T: Float + Send + Sync + Clone + 'static,
{
    /// Reduces along `dim`, starting from `init` and folding values in with `op`.
    ///
    /// The accumulator type `R` must be convertible into `T`. How partial results from
    /// different workers are merged is left to the implementor.
    fn parallel_reduce<F, R>(&self, dim: usize, init: R, op: F) -> ParallelResult<Tensor<T>>
    where
        F: Fn(R, T) -> R + Send + Sync + Clone,
        R: Send + Sync + Clone + Into<T>;

    /// Sums along `dim`.
    ///
    /// Provided method: delegates to [`parallel_reduce`](Self::parallel_reduce) with
    /// `T::zero()` as the seed and addition as the combining step.
    fn parallel_sum(&self, dim: usize) -> ParallelResult<Tensor<T>> {
        let additive_identity = T::zero();
        self.parallel_reduce(dim, additive_identity, |running_total, element| {
            running_total + element
        })
    }

    /// Averages along `dim`; no default is provided, so implementors must supply it.
    fn parallel_mean(&self, dim: usize) -> ParallelResult<Tensor<T>>;
}
/// `f32`-only operations layered on top of [`ParallelOp<f32>`].
///
/// NOTE(review): the `simd_` prefix suggests implementations combine vectorized kernels
/// with multi-threading, but nothing in this trait enforces that.
pub trait SimdParallelOp: ParallelOp<f32> {
    /// Adds `other` to `self`, returning the resulting tensor or an error.
    fn simd_parallel_add(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>>;

    /// Multiplies `self` by `other` as matrices, returning the product tensor or an error.
    fn simd_parallel_matmul(&self, other: &Tensor<f32>) -> ParallelResult<Tensor<f32>>;

    /// Multiplies `self` by `scalar`; infallible by signature.
    fn simd_parallel_scalar_mul(&self, scalar: f32) -> Tensor<f32>;
}
/// Policy that decides whether work is executed in parallel.
///
/// `PartialEq`/`Eq` are derived so strategies can be compared with `==` instead of
/// pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelStrategy {
    /// Decide from the workload size (see `ParallelContext::should_parallelize`).
    Auto,
    /// Always run in parallel, regardless of size.
    ForceParallel,
    /// Never run in parallel, regardless of size.
    ForceSequential,
}
/// Tunable settings for parallel execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Policy used to decide between parallel and sequential execution.
    pub strategy: ParallelStrategy,

    /// Chunk size for splitting work.
    ///
    /// NOTE(review): presumably measured in elements per chunk; confirm with the code that
    /// consumes this field.
    pub chunk_size: usize,

    /// Optional worker-thread count.
    ///
    /// NOTE(review): `None` presumably means "use the thread pool's default"; confirm with
    /// the code that consumes this field.
    pub num_threads: Option<usize>,
}
impl Default for ParallelConfig {
    /// Builds the baseline configuration: automatic strategy, chunks of 1024, and no
    /// explicit thread count.
    fn default() -> Self {
        let strategy = ParallelStrategy::Auto;
        let chunk_size = 1024;
        let num_threads = None;
        Self { strategy, chunk_size, num_threads }
    }
}
/// Holder for a [`ParallelConfig`] that answers parallelization queries.
///
/// `Debug` and `Clone` are derived (the only field already implements both) so a context
/// can be logged and duplicated like the configuration it wraps.
#[derive(Debug, Clone)]
pub struct ParallelContext {
    /// Configuration this context bases its decisions on.
    pub config: ParallelConfig,
}
impl ParallelContext {
    /// Size from which [`ParallelStrategy::Auto`] chooses parallel execution.
    ///
    /// Matches the default value of `ParallelOp::MIN_PARALLEL_SIZE` so both decision paths
    /// agree out of the box.
    const AUTO_PARALLEL_THRESHOLD: usize = 1000;

    /// Creates a context that makes its decisions from `config`.
    pub fn new(config: ParallelConfig) -> Self {
        Self { config }
    }

    /// Decides whether a workload of `size` elements should run in parallel.
    ///
    /// * `Auto` returns `true` once `size` reaches 1000 (inclusive).
    /// * `ForceParallel` always returns `true`.
    /// * `ForceSequential` always returns `false`.
    pub fn should_parallelize(&self, size: usize) -> bool {
        match self.config.strategy {
            ParallelStrategy::Auto => size >= Self::AUTO_PARALLEL_THRESHOLD,
            ParallelStrategy::ForceParallel => true,
            ParallelStrategy::ForceSequential => false,
        }
    }
}

/// Implemented as the standard trait rather than an inherent `default()` method, so the
/// type also works with `Default::default()`, `..Default::default()` and `T: Default`
/// bounds. Existing `ParallelContext::default()` calls keep resolving to this impl.
impl Default for ParallelContext {
    /// Creates a context backed by `ParallelConfig::default()`.
    fn default() -> Self {
        Self::new(ParallelConfig::default())
    }
}
pub mod parallel_utils {
use super::*;
/// Computes the half-open element range `[start, end)` covered by chunk `chunk_idx` when
/// `total_size` elements are split into consecutive chunks of `chunk_size` elements.
///
/// Both bounds are clamped to `total_size` and the arithmetic saturates, so the result
/// always satisfies `start <= end <= total_size`. A chunk index past the end of the data,
/// or a `chunk_size` of zero, yields an empty range rather than an inverted range or an
/// arithmetic overflow.
///
/// # Returns
///
/// The `(start, end)` pair; the final chunk may be shorter than `chunk_size`.
pub fn safe_parallel_index(
    total_size: usize,
    chunk_size: usize,
    chunk_idx: usize,
) -> (usize, usize) {
    // Saturate before clamping: an overflowing product is necessarily past the end.
    let start = chunk_idx.saturating_mul(chunk_size).min(total_size);
    let end = start.saturating_add(chunk_size).min(total_size);
    (start, end)
}
/// Splits `data` into slices of at most `chunk_size` elements, runs `process_chunk` on each
/// slice through rayon's `par_chunks`, and collects one result per slice into a `Vec`.
///
/// NOTE(review): rayon's `par_chunks` is expected to panic when `chunk_size` is zero;
/// confirm against the rayon version in use before passing caller-controlled sizes.
pub fn parallel_chunks<T, F, R>(data: &[T], chunk_size: usize, process_chunk: F) -> Vec<R>
where
    T: Send + Sync,
    F: Fn(&[T]) -> R + Send + Sync,
    R: Send,
{
    let slices = data.par_chunks(chunk_size);
    let outputs: Vec<R> = slices.map(process_chunk).collect();
    outputs
}
/// Invokes `process_batch` once for every index in `0..batch_size` through a rayon parallel
/// iterator and collects the results into a `Vec`.
///
/// The type parameter `T` does not appear in the arguments or the return type, so it
/// cannot be inferred; callers have to name it explicitly, e.g.
/// `parallel_batch_process::<(), _, _>(n, f)`. It is kept for interface compatibility.
pub fn parallel_batch_process<T, F, R>(batch_size: usize, process_batch: F) -> Vec<R>
where
    F: Fn(usize) -> R + Send + Sync,
    R: Send,
{
    let batch_indices = 0..batch_size;
    let outputs: Vec<R> = batch_indices.into_par_iter().map(process_batch).collect();
    outputs
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the automatic strategy with 1024-sized chunks.
    #[test]
    fn test_parallel_config() {
        let ParallelConfig {
            strategy,
            chunk_size,
            ..
        } = ParallelConfig::default();
        assert_eq!(chunk_size, 1024);
        assert!(matches!(strategy, ParallelStrategy::Auto));
    }

    /// A default context parallelizes large workloads and keeps small ones sequential.
    #[test]
    fn test_parallel_context() {
        let context = ParallelContext::default();
        let large_workload = context.should_parallelize(2000);
        assert!(large_workload);
        let small_workload = context.should_parallelize(500);
        assert!(!small_workload);
    }

    /// Chunk ranges are `chunk_size` wide and end exactly at the data length.
    #[test]
    fn test_parallel_utils() {
        let middle = parallel_utils::safe_parallel_index(1000, 100, 5);
        assert_eq!(middle.0, 500);
        assert_eq!(middle.1, 600);

        let last = parallel_utils::safe_parallel_index(1000, 100, 9);
        assert_eq!(last.0, 900);
        assert_eq!(last.1, 1000);
    }
}