//! Parallel optimizer utilities built on top of `scirs2_core`'s parallel operations.

use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Wraps an optimizer so that several parameter groups can be stepped in parallel.
///
/// Each group is updated with an independent clone of the base optimizer, so no
/// optimizer state is shared between groups during a parallel step.
#[derive(Debug)]
pub struct ParallelOptimizer<O, A, D>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    base_optimizer: O,
    _phantom_a: std::marker::PhantomData<A>,
    _phantom_d: std::marker::PhantomData<D>,
}

impl<O, A, D> ParallelOptimizer<O, A, D>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    /// Creates a new parallel wrapper around `base_optimizer`.
    pub fn new(base_optimizer: O) -> Self {
        Self {
            base_optimizer,
            _phantom_a: std::marker::PhantomData,
            _phantom_d: std::marker::PhantomData,
        }
    }

    /// Steps each parameter group in parallel using a clone of the base optimizer.
    ///
    /// Returns the updated parameter groups in the same order as the input, or an
    /// error if the number of parameter groups and gradient groups differ.
    pub fn step_parallel_groups(
        &mut self,
        params_list: &[Array<A, D>],
        grads_list: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>>
    where
        Array<A, D>: Clone + Send + Sync,
    {
        if params_list.len() != grads_list.len() {
            return Err(crate::error::OptimError::InvalidConfig(format!(
                "Parameter groups ({}) and gradient groups ({}) must have same length",
                params_list.len(),
                grads_list.len()
            )));
        }

        // Each group gets its own clone of the base optimizer, so the groups can be
        // processed concurrently without sharing mutable state.
        let results: Vec<Result<Array<A, D>>> = params_list
            .par_iter()
            .zip(grads_list.par_iter())
            .map(|(params, grads)| {
                let mut opt_clone = self.base_optimizer.clone();
                opt_clone.step(params, grads)
            })
            .collect();

        // Propagate the first error, otherwise collect the updated groups in order.
        let mut updated_params = Vec::with_capacity(results.len());
        for result in results {
            updated_params.push(result?);
        }

        Ok(updated_params)
    }

    /// Returns a reference to the wrapped optimizer.
    pub fn inner(&self) -> &O {
        &self.base_optimizer
    }

    /// Returns a mutable reference to the wrapped optimizer.
    pub fn inner_mut(&mut self) -> &mut O {
        &mut self.base_optimizer
    }

    /// Returns the learning rate of the wrapped optimizer.
    pub fn get_learning_rate(&self) -> A {
        self.base_optimizer.get_learning_rate()
    }

    /// Sets the learning rate of the wrapped optimizer.
    pub fn set_learning_rate(&mut self, learning_rate: A) {
        self.base_optimizer.set_learning_rate(learning_rate);
    }
}
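
// Usage sketch (illustrative; mirrors the unit tests at the bottom of this file):
//
//     let mut parallel_opt = ParallelOptimizer::new(SGD::new(0.1));
//     let updated = parallel_opt.step_parallel_groups(&params_list, &grads_list)?;
//
// Here `SGD` stands in for any optimizer from this crate that satisfies
// `Optimizer + Clone + Send + Sync`, and `params_list`/`grads_list` are
// caller-provided slices of parameter and gradient arrays.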

/// Helper for deciding when and how to split batch work across threads.
pub struct ParallelBatchProcessor {
    /// Minimum number of elements per chunk before parallel processing pays off.
    min_chunk_size: usize,
    /// Optional thread-count override used when computing chunk sizes
    /// (defaults to the number of CPU cores).
    num_threads: Option<usize>,
}

impl ParallelBatchProcessor {
    /// Creates a processor with the given minimum chunk size.
    pub fn new(min_chunk_size: usize) -> Self {
        Self {
            min_chunk_size,
            num_threads: None,
        }
    }

    /// Sets an explicit thread count (`None` means use all available cores).
    pub fn with_threads(mut self, num_threads: Option<usize>) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Returns `true` when `size` elements are enough to justify parallel processing,
    /// i.e. when every core would receive at least `min_chunk_size` elements.
    pub fn should_use_parallel(&self, size: usize) -> bool {
        let num_cores = num_cpus::get();
        size >= self.min_chunk_size * num_cores
    }

    /// Computes a per-thread chunk size for `total_size` elements, never smaller
    /// than `min_chunk_size`.
    pub fn optimal_chunk_size(&self, total_size: usize) -> usize {
        let num_cores = self.num_threads.unwrap_or_else(num_cpus::get);
        let chunk_size = total_size / num_cores;
        chunk_size.max(self.min_chunk_size)
    }
}

impl Default for ParallelBatchProcessor {
    fn default() -> Self {
        Self::new(1024)
    }
}
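
// Usage sketch (illustrative): `ParallelBatchProcessor` can gate the parallel path,
// e.g.
//
//     let processor = ParallelBatchProcessor::default().with_threads(Some(4));
//     if processor.should_use_parallel(total_elements) {
//         let chunk = processor.optimal_chunk_size(total_elements);
//         // split the work into chunks of `chunk` elements and process them in parallel
//     }
//
// `total_elements` is a placeholder for the caller's total workload size.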

/// Convenience function: steps several parameter groups in parallel with clones of
/// `optimizer`, without constructing a [`ParallelOptimizer`].
///
/// Returns an error if `params_list` and `grads_list` have different lengths.
pub fn parallel_step<O, A, D>(
    optimizer: &mut O,
    params_list: &[Array<A, D>],
    grads_list: &[Array<A, D>],
) -> Result<Vec<Array<A, D>>>
where
    O: Optimizer<A, D> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
    Array<A, D>: Clone + Send + Sync,
{
    if params_list.len() != grads_list.len() {
        return Err(crate::error::OptimError::InvalidConfig(format!(
            "Parameter groups ({}) and gradient groups ({}) must have same length",
            params_list.len(),
            grads_list.len()
        )));
    }

    // Clone the optimizer per group and update all groups in parallel.
    let results: Vec<Result<Array<A, D>>> = params_list
        .par_iter()
        .zip(grads_list.par_iter())
        .map(|(params, grads)| {
            let mut opt_clone = optimizer.clone();
            opt_clone.step(params, grads)
        })
        .collect();

    // Propagate the first error, otherwise return the updated groups in order.
    let mut updated_params = Vec::with_capacity(results.len());
    for result in results {
        updated_params.push(result?);
    }

    Ok(updated_params)
}

/// Convenience wrapper around the same strategy as [`parallel_step`], specialized to
/// one-dimensional arrays.
pub fn parallel_step_array1<O, A>(
    optimizer: &mut O,
    params_list: &[Array1<A>],
    grads_list: &[Array1<A>],
) -> Result<Vec<Array1<A>>>
where
    O: Optimizer<A, scirs2_core::ndarray::Ix1> + Clone + Send + Sync,
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    if params_list.len() != grads_list.len() {
        return Err(crate::error::OptimError::InvalidConfig(format!(
            "Parameter groups ({}) and gradient groups ({}) must have same length",
            params_list.len(),
            grads_list.len()
        )));
    }

    // Same approach as `parallel_step`: one optimizer clone per parameter group.
    let results: Vec<Result<Array1<A>>> = params_list
        .par_iter()
        .zip(grads_list.par_iter())
        .map(|(params, grads)| {
            let mut opt_clone = optimizer.clone();
            opt_clone.step(params, grads)
        })
        .collect();

    let mut updated_params = Vec::with_capacity(results.len());
    for result in results {
        updated_params.push(result?);
    }

    Ok(updated_params)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::SGD;
    use approx::assert_relative_eq;

    #[test]
    fn test_parallel_optimizer_basic() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];

        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_optimizer_multiple_groups() {
        let optimizer = SGD::new(0.01);
        let mut parallel_opt = ParallelOptimizer::new(optimizer);

        let params_list: Vec<Array1<f32>> =
            (0..10).map(|i| Array1::from_elem(100, i as f32)).collect();
        let grads_list: Vec<Array1<f32>> = (0..10).map(|_| Array1::from_elem(100, 0.1)).collect();

        let results = parallel_opt
            .step_parallel_groups(&params_list, &grads_list)
            .unwrap();

        assert_eq!(results.len(), 10);
        assert_relative_eq!(results[0][0], 0.0 - 0.01 * 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_step_function() {
        let mut optimizer = SGD::new(0.1);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2]),
            Array1::from_vec(vec![0.3, 0.4]),
        ];

        let results = parallel_step(&mut optimizer, &params_list, &grads_list).unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 2.97, epsilon = 1e-6);
    }
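
    // Sketch of an additional check: the free function `parallel_step` and the
    // `ParallelOptimizer` wrapper should apply the same per-group update, assuming
    // `SGD` is deterministic for a fixed learning rate as the tests above rely on.
    #[test]
    fn test_parallel_step_matches_wrapper() {
        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2]),
            Array1::from_vec(vec![0.3, 0.4]),
        ];

        // Update the groups via the free function.
        let mut optimizer = SGD::new(0.1);
        let via_fn = parallel_step(&mut optimizer, &params_list, &grads_list).unwrap();

        // Update the same groups via the wrapper, using an identically configured SGD.
        let mut wrapper = ParallelOptimizer::new(SGD::new(0.1));
        let via_wrapper = wrapper
            .step_parallel_groups(&params_list, &grads_list)
            .unwrap();

        assert_eq!(via_fn.len(), via_wrapper.len());
        for (a, b) in via_fn.iter().zip(via_wrapper.iter()) {
            assert_relative_eq!(a[0], b[0], epsilon = 1e-6);
        }
    }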

    #[test]
    fn test_parallel_batch_processor() {
        let processor = ParallelBatchProcessor::new(1024);

        // Too small to be worth parallelizing.
        assert!(!processor.should_use_parallel(100));

        // Large enough that every core gets at least `min_chunk_size` elements.
        let num_cores = num_cpus::get();
        assert!(processor.should_use_parallel(1024 * num_cores * 2));

        // Chunk size never drops below the configured minimum.
        let chunk_size = processor.optimal_chunk_size(10000);
        assert!(chunk_size >= 1024);
    }

    #[test]
    fn test_parallel_batch_processor_threads() {
        let processor = ParallelBatchProcessor::new(1024).with_threads(Some(4));

        let chunk_size = processor.optimal_chunk_size(10000);
        assert!(chunk_size >= 1024);
        assert!(chunk_size <= 10000);
    }

    #[test]
    fn test_parallel_optimizer_learning_rate() {
        let optimizer = SGD::new(0.1);
        let mut parallel_opt: ParallelOptimizer<_, f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(optimizer);

        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.1, epsilon = 1e-6);

        parallel_opt.set_learning_rate(0.2);
        assert_relative_eq!(parallel_opt.get_learning_rate(), 0.2, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_step_array1() {
        let mut optimizer = SGD::new(0.1);

        let params_list = vec![
            Array1::from_vec(vec![1.0f32, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let grads_list = vec![
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
            Array1::from_vec(vec![0.1, 0.2, 0.3]),
        ];

        let results = parallel_step_array1(&mut optimizer, &params_list, &grads_list).unwrap();

        assert_eq!(results.len(), 2);
        assert_relative_eq!(results[0][0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(results[1][0], 3.99, epsilon = 1e-6);
    }
}