use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;
use crate::error::Result;
use crate::optimizers::Optimizer;
/// Wrapper around a base optimizer that can step several parameter groups
/// in parallel.
///
/// `O` is the wrapped optimizer type, `A` the array element type and `D` the
/// array dimension type. `A` and `D` are not stored; they are carried as
/// `PhantomData` markers only.
#[derive(Debug)]
pub struct ParallelOptimizer<O, A, D>
where
O: Optimizer<A, D> + Clone + Send + Sync,
A: Float + ScalarOperand + Debug + Send + Sync,
D: Dimension,
{
/// The wrapped optimizer; it is cloned once per parameter group when
/// stepping in parallel.
base_optimizer: O,
/// Zero-sized marker for the element type `A`.
_phantom_a: std::marker::PhantomData<A>,
/// Zero-sized marker for the dimension type `D`.
_phantom_d: std::marker::PhantomData<D>,
}
impl<O, A, D> ParallelOptimizer<O, A, D>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Builds a parallel wrapper that owns `base_optimizer`.
    pub fn new(base_optimizer: O) -> Self {
        let _phantom_a = std::marker::PhantomData::<A>;
        let _phantom_d = std::marker::PhantomData::<D>;
        Self {
            base_optimizer,
            _phantom_a,
            _phantom_d,
        }
    }

    /// Steps every parameter group in parallel and returns the updated groups
    /// in the same order as `params_list`.
    ///
    /// `params_list[i]` is paired with `grads_list[i]`. Each pair is stepped
    /// with its own fresh clone of the wrapped optimizer; the clone is dropped
    /// afterwards and the wrapped optimizer is only read (cloned), never
    /// stepped directly.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when the two slices differ in
    /// length. Otherwise returns the first error (in group order) produced by
    /// the optimizer's `step`.
    pub fn step_parallel_groups(
        &mut self,
        params_list: &[Array<A, D>],
        grads_list: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>>
    where
        Array<A, D>: Clone + Send + Sync,
    {
        let (param_groups, grad_groups) = (params_list.len(), grads_list.len());
        if param_groups != grad_groups {
            return Err(crate::error::OptimError::InvalidConfig(format!(
                "Parameter groups ({}) and gradient groups ({}) must have same length",
                param_groups, grad_groups
            )));
        }

        let template: &O = &self.base_optimizer;
        // Every group is evaluated before any error is surfaced, so the
        // per-group results are gathered first and inspected afterwards.
        let outcomes: Vec<Result<Array<A, D>>> = params_list
            .par_iter()
            .zip(grads_list.par_iter())
            .map(|(group_params, group_grads)| {
                let mut worker: O = template.clone();
                worker.step(group_params, group_grads)
            })
            .collect();

        // Sequential collect: stops at the first `Err` in group order.
        outcomes.into_iter().collect()
    }

    /// Shared access to the wrapped optimizer.
    pub fn inner(&self) -> &O {
        &self.base_optimizer
    }

    /// Exclusive access to the wrapped optimizer.
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.base_optimizer
    }

    /// Learning rate currently reported by the wrapped optimizer.
    pub fn get_learning_rate(&self) -> A {
        self.inner().get_learning_rate()
    }

    /// Forwards `learning_rate` to the wrapped optimizer.
    pub fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner_mut().set_learning_rate(learning_rate);
    }
}
/// Heuristics for deciding whether, and in what chunk sizes, a batch should be
/// processed in parallel.
///
/// Derives `Debug` and `Clone` so the configuration can be logged and copied,
/// matching the `Debug` derive on `ParallelOptimizer`.
#[derive(Debug, Clone)]
pub struct ParallelBatchProcessor {
    /// Smallest chunk size the heuristics will hand out.
    min_chunk_size: usize,
    /// Explicit worker count; `None` means the count is detected at call time.
    num_threads: Option<usize>,
}
impl ParallelBatchProcessor {
    /// Creates a processor with the given minimum chunk size and no explicit
    /// thread count (the CPU count is detected when needed).
    pub fn new(min_chunk_size: usize) -> Self {
        Self {
            min_chunk_size,
            num_threads: None,
        }
    }

    /// Sets (or clears, with `None`) the explicit worker count used by the
    /// heuristics. A count of `Some(0)` is treated as one worker.
    pub fn with_threads(mut self, num_threads: Option<usize>) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Returns `true` when `size` is large enough to give every worker at
    /// least `min_chunk_size` elements.
    ///
    /// The worker count is the configured `num_threads` when set, otherwise
    /// the detected CPU count, so this agrees with `optimal_chunk_size`.
    pub fn should_use_parallel(&self, size: usize) -> bool {
        // Saturate instead of overflowing for very large minimum chunk sizes.
        let threshold = self.min_chunk_size.saturating_mul(self.effective_workers());
        size >= threshold
    }

    /// Returns the chunk size to use for `total_size` elements: an even split
    /// across the workers, but never smaller than `min_chunk_size`.
    pub fn optimal_chunk_size(&self, total_size: usize) -> usize {
        let even_split = total_size / self.effective_workers();
        even_split.max(self.min_chunk_size)
    }

    /// Worker count used by the heuristics; always at least 1 so that it is
    /// safe to divide by (guards against `with_threads(Some(0))`).
    fn effective_workers(&self) -> usize {
        self.num_threads.unwrap_or_else(num_cpus::get).max(1)
    }
}
impl Default for ParallelBatchProcessor {
fn default() -> Self {
Self::new(1024)
}
}
pub fn parallel_step<O, A, D>(
optimizer: &mut O,
params_list: &[Array<A, D>],
grads_list: &[Array<A, D>],
) -> Result<Vec<Array<A, D>>>
where
O: Optimizer<A, D> + Clone + Send + Sync,
A: Float + ScalarOperand + Debug + Send + Sync,
D: Dimension,
Array<A, D>: Clone + Send + Sync,
{
if params_list.len() != grads_list.len() {
return Err(crate::error::OptimError::InvalidConfig(format!(
"Parameter groups ({}) and gradient groups ({}) must have same length",
params_list.len(),
grads_list.len()
)));
}
let results: Vec<Result<Array<A, D>>> = params_list
.par_iter()
.zip(grads_list.par_iter())
.map(|(params, grads)| {
let mut opt_clone = optimizer.clone();
opt_clone.step(params, grads)
})
.collect();
let mut updated_params = Vec::with_capacity(results.len());
for result in results {
updated_params.push(result?);
}
Ok(updated_params)
}
pub fn parallel_step_array1<O, A>(
optimizer: &mut O,
params_list: &[Array1<A>],
grads_list: &[Array1<A>],
) -> Result<Vec<Array1<A>>>
where
O: Optimizer<A, scirs2_core::ndarray::Ix1> + Clone + Send + Sync,
A: Float + ScalarOperand + Debug + Send + Sync,
{
if params_list.len() != grads_list.len() {
return Err(crate::error::OptimError::InvalidConfig(format!(
"Parameter groups ({}) and gradient groups ({}) must have same length",
params_list.len(),
grads_list.len()
)));
}
let results: Vec<Result<Array1<A>>> = params_list
.par_iter()
.zip(grads_list.par_iter())
.map(|(params, grads)| {
let mut opt_clone = optimizer.clone();
opt_clone.step(params, grads)
})
.collect();
let mut updated_params = Vec::with_capacity(results.len());
for result in results {
updated_params.push(result?);
}
Ok(updated_params)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    /// Two groups stepped through the wrapper: each first element moves by
    /// `lr * grad = 0.1 * 0.1`.
    #[test]
    fn test_parallel_optimizer_basic() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);
        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];
        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .expect("unwrap failed");
        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }

    /// Ten groups of 100 elements each; checks the result count and the first
    /// element of the first group.
    #[test]
    fn test_parallel_optimizer_multiple_groups() {
        let optimizer = SGD::new(0.01);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);
        let params_list: Vec<Array1<f32>> =
            (0..10).map(|i| Array1::from_elem(100, i as f32)).collect();
        let grads_list: Vec<Array1<f32>> = (0..10).map(|_| Array1::from_elem(100, 0.1)).collect();
        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .expect("unwrap failed");
        assert_eq!(results.len(), 10);
        assert_relative_eq!(results[0][0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);
    }

    /// The free function gives the same per-group update as the wrapper.
    #[test]
    fn test_parallel_step_function() {
        let mut optimizer = SGD::new(0.1);
        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2]),
            Array1::from_vec(vec![0.3, 0.4]),
        ];
        let results =
            parallel_step(&mut optimizer, &params_list, &grads_list).expect("unwrap failed");
        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 2.97, epsilon = 1e-6);
    }

    /// Mismatched parameter/gradient group counts are rejected with an error.
    #[test]
    fn test_parallel_step_length_mismatch() {
        let mut optimizer = SGD::new(0.1);
        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let grads_list = vec![Array1::from_vec(vec![0.1f32, 0.2])];
        assert!(parallel_step(&mut optimizer, &params_list, &grads_list).is_err());
    }

    /// Small inputs stay sequential; inputs well above `min_chunk_size * cores`
    /// go parallel; chunk size never drops below the minimum.
    #[test]
    fn test_parallel_batch_processor() {
        let processor = ParallelBatchProcessor::new(1024);
        assert!(!processor.should_use_parallel(100));
        let num_cores = num_cpus::get();
        assert!(processor.should_use_parallel(1024 * num_cores * 2));
        let chunk_size = processor.optimal_chunk_size(10000);
        assert!(chunk_size >= 1024);
    }

    /// With an explicit thread count the chunk size stays within
    /// `[min_chunk_size, total_size]` for this input.
    #[test]
    fn test_parallel_batch_processor_threads() {
        let processor = ParallelBatchProcessor::new(1024).with_threads(Some(4));
        let chunk_size = processor.optimal_chunk_size(10000);
        assert!(chunk_size >= 1024);
        assert!(chunk_size <= 10000);
    }

    /// Learning-rate getter and setter are forwarded to the wrapped optimizer.
    #[test]
    fn test_parallel_optimizer_learning_rate() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt: ParallelOptimizer<_, f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(optimizer);
        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.1, epsilon = 1e-6);
        parallel_opt.set_learning_rate(0.2);
        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.2, epsilon = 1e-6);
    }

    /// The one-dimensional helper produces the same updates as the generic path.
    #[test]
    fn test_parallel_step_array1() {
        let mut optimizer = SGD::new(0.1);
        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];
        let results =
            parallel_step_array1(&mut optimizer, &params_list, &grads_list).expect("unwrap failed");
        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }
}