optirs_core/parallel_optimizer.rs

//! Parallel optimizer operations using scirs2_core
//!
//! This module provides parallel processing capabilities for optimizers,
//! enabling efficient multi-core utilization for large-scale optimization.
//!
//! # Features
//!
//! - Parallel parameter group processing
//! - Parallel batch updates
//! - Automatic work distribution across CPU cores
//! - Zero-copy parameter handling
//!
//! # Performance
//!
//! Typical speedup is roughly 4-8x on multi-core systems when processing
//! multiple independent parameter groups; the exact figure depends on group
//! sizes and core count.
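//!
//! # Example
//!
//! A minimal sketch using the one-off [`parallel_step`] helper defined below:
//!
//! ```
//! use scirs2_core::ndarray::Array1;
//! use optirs_core::optimizers::SGD;
//! use optirs_core::parallel_optimizer::parallel_step;
//!
//! let mut optimizer = SGD::new(0.1);
//! let params = vec![Array1::from_vec(vec![1.0f32, 2.0])];
//! let grads = vec![Array1::from_vec(vec![0.1, 0.2])];
//! let updated = parallel_step(&mut optimizer, &params, &grads).unwrap();
//! assert_eq!(updated.len(), 1);
//! ```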

use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Parallel optimizer wrapper for processing multiple parameter groups
///
/// This wrapper enables parallel processing of multiple parameter groups,
/// providing significant speedup on multi-core systems.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SGD, Optimizer};
/// use optirs_core::parallel_optimizer::ParallelOptimizer;
///
/// // Create base optimizer
/// let optimizer = SGD::new(0.01);
///
/// // Wrap in parallel optimizer
/// let mut parallel_opt = ParallelOptimizer::new(optimizer);
///
/// // Process multiple parameter groups in parallel
/// let params_list = vec![
///     Array1::zeros(1000),
///     Array1::zeros(2000),
///     Array1::zeros(1500),
/// ];
/// let grads_list = vec![
///     Array1::from_elem(1000, 0.1),
///     Array1::from_elem(2000, 0.1),
///     Array1::from_elem(1500, 0.1),
/// ];
///
/// let updated = parallel_opt.step_parallel_groups(&params_list, &grads_list).unwrap();
/// ```
#[derive(Debug)]
pub struct ParallelOptimizer<O, A, D>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    base_optimizer: O,
    _phantom_a: std::marker::PhantomData<A>,
    _phantom_d: std::marker::PhantomData<D>,
}

impl<O, A, D> ParallelOptimizer<O, A, D>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Creates a new parallel optimizer wrapper
    ///
    /// # Arguments
    ///
    /// * `base_optimizer` - The base optimizer to parallelize
    pub fn new(base_optimizer: O) -> Self {
        Self {
            base_optimizer,
            _phantom_a: std::marker::PhantomData,
            _phantom_d: std::marker::PhantomData,
        }
    }

    /// Process multiple parameter groups in parallel
    ///
    /// This method distributes parameter groups across available CPU cores
    /// for parallel processing. The base optimizer is cloned for each group,
    /// so per-step optimizer state stays independent per group and is not
    /// written back to the wrapper.
    ///
    /// # Arguments
    ///
    /// * `params_list` - List of parameter arrays
    /// * `grads_list` - List of gradient arrays
    ///
    /// # Returns
    ///
    /// Updated parameter arrays processed in parallel
    pub fn step_parallel_groups(
        &mut self,
        params_list: &[Array<A, D>],
        grads_list: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>>
    where
        Array<A, D>: Clone + Send + Sync,
    {
        if params_list.len() != grads_list.len() {
            return Err(crate::error::OptimError::InvalidConfig(format!(
                "Parameter groups ({}) and gradient groups ({}) must have same length",
                params_list.len(),
                grads_list.len()
            )));
        }

        // Use parallel iterator from scirs2_core
        let results: Vec<Result<Array<A, D>>> = params_list
            .par_iter()
            .zip(grads_list.par_iter())
            .map(|(params, grads)| {
                let mut opt_clone = self.base_optimizer.clone();
                opt_clone.step(params, grads)
            })
            .collect();

        // Collect results and handle errors
        let mut updated_params = Vec::with_capacity(results.len());
        for result in results {
            updated_params.push(result?);
        }

        Ok(updated_params)
    }

    /// Get the underlying optimizer
    pub fn inner(&self) -> &O {
        &self.base_optimizer
    }

    /// Get mutable reference to underlying optimizer
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.base_optimizer
    }

    /// Get the current learning rate from the base optimizer
    pub fn get_learning_rate(&self) -> A {
        self.base_optimizer.get_learning_rate()
    }

    /// Set the learning rate on the base optimizer
    pub fn set_learning_rate(&mut self, learning_rate: A) {
        self.base_optimizer.set_learning_rate(learning_rate);
    }
}

/// Parallel batch processor for large parameter arrays
///
/// This processor splits large parameter arrays into chunks and processes
/// them in parallel, providing speedup for very large models.
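///
/// # Examples
///
/// A minimal sketch of the size heuristics; note that `should_use_parallel`
/// scales its threshold by the core count reported by `num_cpus`:
///
/// ```
/// use optirs_core::parallel_optimizer::ParallelBatchProcessor;
///
/// let processor = ParallelBatchProcessor::new(1024).with_threads(Some(4));
/// // With 4 threads, a 10_000-element array splits into 2_500-element chunks.
/// assert_eq!(processor.optimal_chunk_size(10_000), 2_500);
/// // Tiny arrays never clear the minimum-chunk threshold.
/// assert!(!processor.should_use_parallel(100));
/// ```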
pub struct ParallelBatchProcessor {
    /// Minimum chunk size for parallel processing
    min_chunk_size: usize,
    /// Number of threads to use (None = automatic)
    num_threads: Option<usize>,
}

impl ParallelBatchProcessor {
    /// Creates a new parallel batch processor
    ///
    /// # Arguments
    ///
    /// * `min_chunk_size` - Minimum size of each chunk (default: 1024)
    pub fn new(min_chunk_size: usize) -> Self {
        Self {
            min_chunk_size,
            num_threads: None,
        }
    }

    /// Set the number of threads to use
    ///
    /// # Arguments
    ///
    /// * `num_threads` - Number of threads (None for automatic)
    pub fn with_threads(mut self, num_threads: Option<usize>) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Determine if parallel processing should be used
    ///
    /// # Arguments
    ///
    /// * `size` - Size of the parameter array
    ///
    /// # Returns
    ///
    /// True if parallel processing would be beneficial
    pub fn should_use_parallel(&self, size: usize) -> bool {
        let num_cores = num_cpus::get();
        size >= self.min_chunk_size * num_cores
    }

    /// Get optimal chunk size for parallel processing
    ///
    /// # Arguments
    ///
    /// * `total_size` - Total size of the array
    ///
    /// # Returns
    ///
    /// Optimal chunk size for parallel processing
    pub fn optimal_chunk_size(&self, total_size: usize) -> usize {
        // num_threads may be user-set, so this is not necessarily the core count
        let num_workers = self.num_threads.unwrap_or_else(num_cpus::get);
        let chunk_size = total_size / num_workers;
        chunk_size.max(self.min_chunk_size)
    }
}

impl Default for ParallelBatchProcessor {
    fn default() -> Self {
        Self::new(1024)
    }
}

/// Helper function to process parameter groups in parallel
///
/// This is a convenience function for one-off parallel processing without
/// creating a [`ParallelOptimizer`] instance.
///
/// # Arguments
///
/// * `optimizer` - The optimizer to use (will be cloned for each group)
/// * `params_list` - List of parameter arrays
/// * `grads_list` - List of gradient arrays
///
/// # Returns
///
/// Updated parameter arrays processed in parallel
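///
/// # Examples
///
/// Mismatched list lengths are rejected up front; a sketch of the error path:
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::SGD;
/// use optirs_core::parallel_optimizer::parallel_step;
///
/// let mut optimizer = SGD::new(0.1);
/// let params = vec![Array1::from_vec(vec![1.0f32, 2.0])];
/// let grads: Vec<Array1<f32>> = vec![];
/// assert!(parallel_step(&mut optimizer, &params, &grads).is_err());
/// ```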
pub fn parallel_step<O, A, D>(
    optimizer: &mut O,
    params_list: &[Array<A, D>],
    grads_list: &[Array<A, D>],
) -> Result<Vec<Array<A, D>>>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
    Array<A, D>: Clone + Send + Sync,
{
    if params_list.len() != grads_list.len() {
        return Err(crate::error::OptimError::InvalidConfig(format!(
            "Parameter groups ({}) and gradient groups ({}) must have same length",
            params_list.len(),
            grads_list.len()
        )));
    }

    let results: Vec<Result<Array<A, D>>> = params_list
        .par_iter()
        .zip(grads_list.par_iter())
        .map(|(params, grads)| {
            let mut opt_clone = optimizer.clone();
            opt_clone.step(params, grads)
        })
        .collect();

    let mut updated_params = Vec::with_capacity(results.len());
    for result in results {
        updated_params.push(result?);
    }

    Ok(updated_params)
}

/// Parallel processing for `Array1` parameters specifically
///
/// A convenience monomorphization of [`parallel_step`] for one-dimensional
/// parameter arrays (`Ix1`); the logic is identical to the generic path.
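///
/// A minimal sketch (mirrors [`parallel_step`], monomorphized for `Ix1`):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::SGD;
/// use optirs_core::parallel_optimizer::parallel_step_array1;
///
/// let mut optimizer = SGD::new(0.1);
/// let params = vec![Array1::from_vec(vec![1.0f32, 2.0])];
/// let grads = vec![Array1::from_vec(vec![0.1, 0.2])];
/// let updated = parallel_step_array1(&mut optimizer, &params, &grads).unwrap();
/// // SGD update: 1.0 - 0.1 * 0.1 = 0.99, as in the tests below
/// assert!((updated[0][0] - 0.99).abs() < 1e-6);
/// ```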
pub fn parallel_step_array1<O, A>(
    optimizer: &mut O,
    params_list: &[Array1<A>],
    grads_list: &[Array1<A>],
) -> Result<Vec<Array1<A>>>
where
    O: Optimizer<A, scirs2_core::ndarray::Ix1> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    if params_list.len() != grads_list.len() {
        return Err(crate::error::OptimError::InvalidConfig(format!(
            "Parameter groups ({}) and gradient groups ({}) must have same length",
            params_list.len(),
            grads_list.len()
        )));
    }

    let results: Vec<Result<Array1<A>>> = params_list
        .par_iter()
        .zip(grads_list.par_iter())
        .map(|(params, grads)| {
            let mut opt_clone = optimizer.clone();
            opt_clone.step(params, grads)
        })
        .collect();

    let mut updated_params = Vec::with_capacity(results.len());
    for result in results {
        updated_params.push(result?);
    }

    Ok(updated_params)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    #[test]
    fn test_parallel_optimizer_basic() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];

        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_optimizer_multiple_groups() {
        let optimizer = SGD::new(0.01);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);

        // Create 10 parameter groups
        let params_list: Vec<Array1<f32>> =
            (0..10).map(|i| Array1::from_elem(100, i as f32)).collect();
        let grads_list: Vec<Array1<f32>> = (0..10).map(|_| Array1::from_elem(100, 0.1)).collect();

        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .unwrap();

        assert_eq!(results.len(), 10);
        // Verify first group was updated correctly
        assert_relative_eq!(results[0][0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_step_function() {
        let mut optimizer = SGD::new(0.1);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2]),
            Array1::from_vec(vec![0.3, 0.4]),
        ];

        let results = parallel_step(&mut optimizer, &params_list, &grads_list).unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 2.97, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_batch_processor() {
        let processor = ParallelBatchProcessor::new(1024);

        // Small array - should not use parallel
        assert!(!processor.should_use_parallel(100));

        // Large array - should use parallel
        let num_cores = num_cpus::get();
        assert!(processor.should_use_parallel(1024 * num_cores * 2));

        // Test optimal chunk size calculation
        let chunk_size = processor.optimal_chunk_size(10000);
        assert!(chunk_size >= 1024);
    }

    #[test]
    fn test_parallel_batch_processor_threads() {
        let processor = ParallelBatchProcessor::new(1024).with_threads(Some(4));

        // With 4 threads, 10000 / 4 = 2500, which already exceeds the 1024 minimum
        let chunk_size = processor.optimal_chunk_size(10000);
        assert_eq!(chunk_size, 2500);
    }

    #[test]
    fn test_parallel_optimizer_learning_rate() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt: ParallelOptimizer<_, f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(optimizer);

        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.1, epsilon = 1e-6);

        parallel_opt.set_learning_rate(0.2);
        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.2, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_step_array1() {
        let mut optimizer = SGD::new(0.1);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];

        let results = parallel_step_array1(&mut optimizer, &params_list, &grads_list).unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }
}