optirs_core/optimizer_composition/mod.rs
1// Optimizer composition framework
2//
3// This module provides compositions of optimizers to create more sophisticated
4// optimization strategies. It includes three main types of compositions:
5//
6// 1. **Sequential**: Apply multiple optimizers in sequence
7// 2. **Parallel**: Apply different optimizers to different parameter groups
8// 3. **Chained**: Wrap an optimizer with another (similar to Lookahead wrapping other optimizers)
9
10use crate::error::{OptimError, Result};
11use crate::optimizers::Optimizer;
12use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
13use scirs2_core::numeric::Float;
14use std::fmt::Debug;
15
16/// A sequential composition of optimizers
17///
18/// This applies multiple optimizers in sequence to the same parameters.
19/// Each optimizer's output becomes the input to the next optimizer.
20///
21/// # Example
22///
23/// ```
24/// use scirs2_core::ndarray::Array1;
25/// use optirs_core::optimizer_composition::SequentialOptimizer;
26/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
27///
28/// // Create optimizers
29/// let sgd = SGD::new(0.1);
30/// let adam = Adam::new(0.01);
31///
32/// // Combine them sequentially
33/// let mut seq_optimizer = SequentialOptimizer::new(vec![
34/// Box::new(sgd),
35/// Box::new(adam),
36/// ]);
37///
38/// // Use the sequential optimizer
39/// let params = Array1::zeros(5);
40/// let gradients = Array1::ones(5);
41/// let updated_params = seq_optimizer.step(¶ms, &gradients).expect("unwrap failed");
42/// ```
43pub struct SequentialOptimizer<A, D>
44where
45 A: Float + ScalarOperand + Debug,
46 D: Dimension,
47{
48 /// List of optimizers to apply in sequence
49 optimizers: Vec<Box<dyn Optimizer<A, D>>>,
50}
51
52impl<A, D> SequentialOptimizer<A, D>
53where
54 A: Float + ScalarOperand + Debug,
55 D: Dimension,
56{
57 /// Create a new sequential optimizer
58 ///
59 /// # Arguments
60 ///
61 /// * `optimizers` - List of optimizers to apply in sequence
62 pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
63 Self { optimizers }
64 }
65
66 /// Add an optimizer to the sequence
67 ///
68 /// # Arguments
69 ///
70 /// * `optimizer` - The optimizer to add
71 pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
72 self.optimizers.push(optimizer);
73 }
74
75 /// Get the number of optimizers in the sequence
76 pub fn num_optimizers(&self) -> usize {
77 self.optimizers.len()
78 }
79
80 /// Get a reference to an optimizer by index
81 ///
82 /// # Arguments
83 ///
84 /// * `index` - The index of the optimizer
85 ///
86 /// # Returns
87 ///
88 /// A reference to the optimizer at the given index, or None if out of bounds
89 pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
90 if index < self.optimizers.len() {
91 Some(self.optimizers[index].as_ref())
92 } else {
93 None
94 }
95 }
96
97 /// Get a mutable reference to an optimizer by index
98 ///
99 /// # Arguments
100 ///
101 /// * `index` - The index of the optimizer
102 ///
103 /// # Returns
104 ///
105 /// A mutable reference to the optimizer at the given index, or None if out of bounds
106 pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
107 if index < self.optimizers.len() {
108 Some(self.optimizers[index].as_mut())
109 } else {
110 None
111 }
112 }
113}
114
115impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
116where
117 A: Float + ScalarOperand + Debug,
118 D: Dimension,
119{
120 fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
121 // Check if we have any optimizers
122 if self.optimizers.is_empty() {
123 return Err(OptimError::InvalidConfig(
124 "SequentialOptimizer has no optimizers".to_string(),
125 ));
126 }
127
128 // Start with the initial parameters
129 let mut current_params = params.clone();
130
131 // Apply each optimizer in sequence
132 for optimizer in &mut self.optimizers {
133 current_params = optimizer.step(¤t_params, gradients)?;
134 }
135
136 Ok(current_params)
137 }
138
139 fn get_learning_rate(&self) -> A {
140 // Return the learning rate of the first optimizer, or a default if empty
141 if let Some(optimizer) = self.optimizers.first() {
142 optimizer.get_learning_rate()
143 } else {
144 A::from(0.01).expect("unwrap failed") // Default learning rate
145 }
146 }
147
148 fn set_learning_rate(&mut self, learningrate: A) {
149 // Set the learning _rate for all optimizers
150 for optimizer in &mut self.optimizers {
151 optimizer.set_learning_rate(learningrate);
152 }
153 }
154}
155
/// A single group of parameters together with the index of the optimizer
/// (within a [`ParallelOptimizer`]) that is responsible for updating it.
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The parameters in this group
    pub params: Array<A, D>,
    /// The index of the optimizer to use for this group
    pub optimizerindex: usize,
}
167
168impl<A, D> ParameterGroup<A, D>
169where
170 A: Float + ScalarOperand + Debug,
171 D: Dimension,
172{
173 /// Create a new parameter group
174 ///
175 /// # Arguments
176 ///
177 /// * `params` - The parameters in this group
178 /// * `optimizerindex` - The index of the optimizer to use for this group
179 pub fn new(params: Array<A, D>, optimizerindex: usize) -> Self {
180 Self {
181 params,
182 optimizerindex,
183 }
184 }
185}
186
/// A parallel composition of optimizers
///
/// This applies different optimizers to different parameter groups.
/// Each group of parameters is updated using its assigned optimizer.
/// Use [`ParallelOptimizer::update_all_parameters`] to perform an update;
/// the single-array `step` method is intentionally unsupported.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Create parameter groups
/// let params1 = Array1::zeros(3);
/// let params2 = Array1::zeros(5);
///
/// let group1 = ParameterGroup::new(params1, 0); // Use SGD
/// let group2 = ParameterGroup::new(params2, 1); // Use Adam
///
/// // Combine them in parallel
/// let mut parallel_optimizer = ParallelOptimizer::new(
///     vec![Box::new(sgd), Box::new(adam)],
///     vec![group1, group2],
/// );
///
/// // update_all_parameters will update all parameter groups using their
/// // assigned optimizers (in a real use case, you'd provide the
/// // corresponding gradients)
/// ```
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply to different parameter groups
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Groups of parameters with their assigned optimizer indices
    parameter_groups: Vec<ParameterGroup<A, D>>,
}
229
230impl<A, D> ParallelOptimizer<A, D>
231where
232 A: Float + ScalarOperand + Debug,
233 D: Dimension,
234{
235 /// Create a new parallel optimizer
236 ///
237 /// # Arguments
238 ///
239 /// * `optimizers` - List of optimizers to use
240 /// * `parameter_groups` - Groups of parameters with their assigned optimizer indices
241 pub fn new(
242 optimizers: Vec<Box<dyn Optimizer<A, D>>>,
243 parameter_groups: Vec<ParameterGroup<A, D>>,
244 ) -> Self {
245 Self {
246 optimizers,
247 parameter_groups,
248 }
249 }
250
251 /// Add an optimizer
252 ///
253 /// # Arguments
254 ///
255 /// * `optimizer` - The optimizer to add
256 ///
257 /// # Returns
258 ///
259 /// The index of the added optimizer
260 pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
261 let index = self.optimizers.len();
262 self.optimizers.push(optimizer);
263 index
264 }
265
266 /// Add a parameter group
267 ///
268 /// # Arguments
269 ///
270 /// * `params` - The parameters in this group
271 /// * `optimizerindex` - The index of the optimizer to use for this group
272 ///
273 /// # Returns
274 ///
275 /// Result with the index of the added parameter group, or an error if the optimizer index is invalid
276 pub fn add_parameter_group(
277 &mut self,
278 params: Array<A, D>,
279 optimizerindex: usize,
280 ) -> Result<usize> {
281 // Check if the optimizer _index is valid
282 if optimizerindex >= self.optimizers.len() {
283 return Err(OptimError::InvalidConfig(format!(
284 "Invalid optimizer _index: {}. Only {} optimizers available.",
285 optimizerindex,
286 self.optimizers.len()
287 )));
288 }
289
290 let _index = self.parameter_groups.len();
291 self.parameter_groups
292 .push(ParameterGroup::new(params, optimizerindex));
293 Ok(_index)
294 }
295
296 /// Get the number of optimizers
297 pub fn num_optimizers(&self) -> usize {
298 self.optimizers.len()
299 }
300
301 /// Get the number of parameter groups
302 pub fn num_parameter_groups(&self) -> usize {
303 self.parameter_groups.len()
304 }
305
306 /// Get a reference to an optimizer by index
307 ///
308 /// # Arguments
309 ///
310 /// * `index` - The index of the optimizer
311 ///
312 /// # Returns
313 ///
314 /// A reference to the optimizer at the given index, or None if out of bounds
315 pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
316 if index < self.optimizers.len() {
317 Some(self.optimizers[index].as_ref())
318 } else {
319 None
320 }
321 }
322
323 /// Get a mutable reference to an optimizer by index
324 ///
325 /// # Arguments
326 ///
327 /// * `index` - The index of the optimizer
328 ///
329 /// # Returns
330 ///
331 /// A mutable reference to the optimizer at the given index, or None if out of bounds
332 pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
333 if index < self.optimizers.len() {
334 Some(self.optimizers[index].as_mut())
335 } else {
336 None
337 }
338 }
339
340 /// Get a reference to a parameter group by index
341 ///
342 /// # Arguments
343 ///
344 /// * `index` - The index of the parameter group
345 ///
346 /// # Returns
347 ///
348 /// A reference to the parameter group at the given index, or None if out of bounds
349 pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
350 self.parameter_groups.get(index)
351 }
352
353 /// Get a mutable reference to a parameter group by index
354 ///
355 /// # Arguments
356 ///
357 /// * `index` - The index of the parameter group
358 ///
359 /// # Returns
360 ///
361 /// A mutable reference to the parameter group at the given index, or None if out of bounds
362 pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
363 self.parameter_groups.get_mut(index)
364 }
365
366 /// Get all current parameter values as a single array
367 ///
368 /// # Returns
369 ///
370 /// A result containing all parameter values concatenated into a single array
371 pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
372 Ok(self
373 .parameter_groups
374 .iter()
375 .map(|group| group.params.clone())
376 .collect())
377 }
378
379 /// Update all parameter groups using their assigned optimizers
380 ///
381 /// # Arguments
382 ///
383 /// * `gradients` - List of gradient arrays corresponding to parameter groups
384 ///
385 /// # Returns
386 ///
387 /// Result with the updated parameter values, or an error
388 pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
389 // Check if the number of gradients matches the number of parameter groups
390 if gradients.len() != self.parameter_groups.len() {
391 return Err(OptimError::InvalidConfig(format!(
392 "Number of gradients ({}) does not match number of parameter groups ({})",
393 gradients.len(),
394 self.parameter_groups.len()
395 )));
396 }
397
398 let mut updated_params = Vec::with_capacity(self.parameter_groups.len());
399
400 // Update each parameter group using its assigned optimizer
401 for (i, group) in self.parameter_groups.iter_mut().enumerate() {
402 let optimizerindex = group.optimizerindex;
403
404 // Check if the optimizer index is valid
405 if optimizerindex >= self.optimizers.len() {
406 return Err(OptimError::InvalidConfig(format!(
407 "Invalid optimizer index: {}. Only {} optimizers available.",
408 optimizerindex,
409 self.optimizers.len()
410 )));
411 }
412
413 // Get the optimizer and update the parameters
414 let optimizer = &mut self.optimizers[optimizerindex];
415 let params = &group.params;
416 let gradient = &gradients[i];
417
418 // Update the parameters
419 let updated = optimizer.step(params, gradient)?;
420 group.params = updated.clone();
421 updated_params.push(updated);
422 }
423
424 Ok(updated_params)
425 }
426}
427
428impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
429where
430 A: Float + ScalarOperand + Debug,
431 D: Dimension,
432{
433 fn step(&mut self, _params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
434 // This implementation is a bit tricky since we have multiple parameter groups
435 // We'll return an error message directing users to use update_all_parameters instead
436 Err(OptimError::InvalidConfig(
437 "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
438 .to_string(),
439 ))
440 }
441
442 fn step_list(
443 &mut self,
444 params_list: &[&Array<A, D>],
445 gradients_list: &[&Array<A, D>],
446 ) -> Result<Vec<Array<A, D>>> {
447 // Convert params_list to owned arrays
448 let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();
449
450 // Set parameter groups based on the input params
451 self.parameter_groups = params_vec
452 .into_iter()
453 .enumerate()
454 .map(|(i, params)| {
455 // Use the first optimizer for all if there are more params than optimizers
456 let optimizerindex = i.min(self.optimizers.len() - 1);
457 ParameterGroup::new(params, optimizerindex)
458 })
459 .collect();
460
461 // Convert gradients_list to owned arrays
462 let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();
463
464 // Update parameter groups using their assigned optimizers
465 self.update_all_parameters(&gradients_vec)
466 }
467
468 fn get_learning_rate(&self) -> A {
469 // Return the learning rate of the first optimizer, or a default if empty
470 if let Some(optimizer) = self.optimizers.first() {
471 optimizer.get_learning_rate()
472 } else {
473 A::from(0.01).expect("unwrap failed") // Default learning rate
474 }
475 }
476
477 fn set_learning_rate(&mut self, learningrate: A) {
478 // Set the learning _rate for all optimizers
479 for optimizer in &mut self.optimizers {
480 optimizer.set_learning_rate(learningrate);
481 }
482 }
483}
484
485/// A chained composition of optimizers
486///
487/// This wraps one optimizer with another, similar to how Lookahead wraps
488/// another optimizer. The inner optimizer is applied first, and then the
489/// outer optimizer is applied to the result.
490///
491/// # Example
492///
493/// ```
494/// use scirs2_core::ndarray::Array1;
495/// use optirs_core::optimizer_composition::ChainedOptimizer;
496/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
497///
498/// // Create optimizers
499/// let inner = SGD::new(0.1);
500/// let outer = Adam::new(0.01);
501///
502/// // Chain them together
503/// let mut chained_optimizer = ChainedOptimizer::new(Box::new(inner), Box::new(outer));
504///
505/// // Use the chained optimizer
506/// let params = Array1::zeros(5);
507/// let gradients = Array1::ones(5);
508/// let updated_params = chained_optimizer.step(¶ms, &gradients).expect("unwrap failed");
509/// ```
510pub struct ChainedOptimizer<A, D>
511where
512 A: Float + ScalarOperand + Debug,
513 D: Dimension,
514{
515 /// The inner optimizer, applied first
516 inner: Box<dyn Optimizer<A, D>>,
517 /// The outer optimizer, applied to the result of the inner optimizer
518 outer: Box<dyn Optimizer<A, D>>,
519}
520
521impl<A, D> ChainedOptimizer<A, D>
522where
523 A: Float + ScalarOperand + Debug,
524 D: Dimension,
525{
526 /// Create a new chained optimizer
527 ///
528 /// # Arguments
529 ///
530 /// * `inner` - The inner optimizer, applied first
531 /// * `outer` - The outer optimizer, applied to the result of the inner optimizer
532 pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
533 Self { inner, outer }
534 }
535
536 /// Get a reference to the inner optimizer
537 pub fn inner(&self) -> &dyn Optimizer<A, D> {
538 self.inner.as_ref()
539 }
540
541 /// Get a mutable reference to the inner optimizer
542 pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
543 self.inner.as_mut()
544 }
545
546 /// Get a reference to the outer optimizer
547 pub fn outer(&self) -> &dyn Optimizer<A, D> {
548 self.outer.as_ref()
549 }
550
551 /// Get a mutable reference to the outer optimizer
552 pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
553 self.outer.as_mut()
554 }
555}
556
557impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
558where
559 A: Float + ScalarOperand + Debug,
560 D: Dimension,
561{
562 fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
563 // Apply the inner optimizer first
564 let intermediate_params = self.inner.step(params, gradients)?;
565
566 // Then apply the outer optimizer to the result
567 self.outer.step(&intermediate_params, gradients)
568 }
569
570 fn get_learning_rate(&self) -> A {
571 // Return the learning rate of the inner optimizer
572 self.inner.get_learning_rate()
573 }
574
575 fn set_learning_rate(&mut self, learningrate: A) {
576 // Set the learning _rate for both optimizers
577 self.inner.set_learning_rate(learningrate);
578 self.outer.set_learning_rate(learningrate);
579 }
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the sequential optimizer
        let updated_params = seq_optimizer
            .step(&params, &gradients)
            .expect("step failed");

        // Verify the result
        // First SGD updates: params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then Adam makes additional updates, moving further in the negative direction
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0); // Use SGD
        let group2 = ParameterGroup::new(params2.clone(), 1); // Use Adam

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        // Create test gradients
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        // Update the parameters
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .expect("update failed");

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (Adam): The update will be different due to Adam's adaptive nature
        // Just verify it's different from the original params
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the chained optimizer
        let updated_params = chained_optimizer
            .step(&params, &gradients)
            .expect("step failed");

        // Verify the result
        // Inner (SGD): params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then outer (Adam) applies another update
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Getting the learning rate reports the first optimizer's rate
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        // Setting the learning rate applies to all optimizers
        seq_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(0)
                .expect("optimizer 0 missing")
                .get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(1)
                .expect("optimizer 1 missing")
                .get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        // Create a parallel optimizer with SGD and Adam and no initial groups
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        // Create test parameters and gradients
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        // Use step_list to update all parameters
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .expect("step_list failed");

        // Verify the results
        // Group 1 is assigned optimizer 0 (SGD):
        // params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 is assigned optimizer 1 (Adam); its adaptive update
        // differs from a plain SGD step
        assert!(updated_params[1][0] != -0.3);

        // Group 3's index is clamped to the last optimizer (Adam);
        // just check that it moved in the negative direction
        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Getting the learning rate reports the inner optimizer's rate
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        // Setting the learning rate applies to both optimizers
        chained_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }
}
770}