optirs_core/optimizer_composition/mod.rs

//! Optimizer composition framework
//!
//! This module provides compositions of optimizers to create more sophisticated
//! optimization strategies. It includes three main types of compositions:
//!
//! 1. **Sequential**: Apply multiple optimizers in sequence
//! 2. **Parallel**: Apply different optimizers to different parameter groups
//! 3. **Chained**: Wrap an optimizer with another (similar to Lookahead wrapping other optimizers)

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// A sequential composition of optimizers
///
/// This applies multiple optimizers in sequence to the same parameters.
/// Each optimizer's output becomes the input to the next optimizer.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::SequentialOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Combine them sequentially
/// let mut seq_optimizer = SequentialOptimizer::new(vec![
///     Box::new(sgd),
///     Box::new(adam),
/// ]);
///
/// // Use the sequential optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = seq_optimizer.step(&params, &gradients).unwrap();
/// ```
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply in sequence
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}

impl<A, D> SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new sequential optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to apply in sequence
    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
        Self { optimizers }
    }

    /// Add an optimizer to the sequence
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
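    ///
    /// # Example
    ///
    /// A minimal sketch, reusing the `SGD` and `Adam` constructors from the
    /// type-level example above:
    ///
    /// ```
    /// use optirs_core::optimizer_composition::SequentialOptimizer;
    /// use optirs_core::optimizers::{SGD, Adam};
    ///
    /// let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     SequentialOptimizer::new(vec![Box::new(SGD::new(0.1))]);
    /// seq_optimizer.add_optimizer(Box::new(Adam::new(0.01)));
    /// assert_eq!(seq_optimizer.num_optimizers(), 2);
    /// ```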
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
        self.optimizers.push(optimizer);
    }

    /// Get the number of optimizers in the sequence
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }
}

impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Check if we have any optimizers
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "SequentialOptimizer has no optimizers".to_string(),
            ));
        }

        // Start with the initial parameters
        let mut current_params = params.clone();

        // Apply each optimizer in sequence
        for optimizer in &mut self.optimizers {
            current_params = optimizer.step(&current_params, gradients)?;
        }

        Ok(current_params)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).unwrap() // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A struct for assigning parameters to specific groups for parallel optimization
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The parameters in this group
    pub params: Array<A, D>,
    /// The index of the optimizer to use for this group
    pub optimizer_index: usize,
}

impl<A, D> ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizer_index` - The index of the optimizer to use for this group
    pub fn new(params: Array<A, D>, optimizer_index: usize) -> Self {
        Self {
            params,
            optimizer_index,
        }
    }
}

/// A parallel composition of optimizers
///
/// This applies different optimizers to different parameter groups.
/// Each group of parameters is updated using its assigned optimizer.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Create parameter groups
/// let params1 = Array1::zeros(3);
/// let params2 = Array1::zeros(5);
///
/// let group1 = ParameterGroup::new(params1, 0); // Use SGD
/// let group2 = ParameterGroup::new(params2, 1); // Use Adam
///
/// // Combine them in parallel
/// let mut parallel_optimizer = ParallelOptimizer::new(
///     vec![Box::new(sgd), Box::new(adam)],
///     vec![group1, group2],
/// );
///
/// // Call `update_all_parameters` with one gradient array per group to update
/// // every group with its assigned optimizer (the generic `step` method is not
/// // supported for this composition).
/// ```
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply to different parameter groups
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Groups of parameters with their assigned optimizer indices
    parameter_groups: Vec<ParameterGroup<A, D>>,
}

impl<A, D> ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parallel optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to use
    /// * `parameter_groups` - Groups of parameters with their assigned optimizer indices
    pub fn new(
        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
        parameter_groups: Vec<ParameterGroup<A, D>>,
    ) -> Self {
        Self {
            optimizers,
            parameter_groups,
        }
    }

    /// Add an optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
    ///
    /// # Returns
    ///
    /// The index of the added optimizer
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
        let index = self.optimizers.len();
        self.optimizers.push(optimizer);
        index
    }

    /// Add a parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizer_index` - The index of the optimizer to use for this group
    ///
    /// # Returns
    ///
    /// Result with the index of the added parameter group, or an error if the optimizer index is invalid
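    ///
    /// # Example
    ///
    /// A minimal sketch, reusing the `SGD` constructor from the type-level
    /// example above:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizer_composition::ParallelOptimizer;
    /// use optirs_core::optimizers::SGD;
    ///
    /// let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     ParallelOptimizer::new(vec![Box::new(SGD::new(0.1))], vec![]);
    ///
    /// // The group is assigned to the optimizer at index 0
    /// let group_index = parallel_optimizer.add_parameter_group(Array1::zeros(3), 0).unwrap();
    /// assert_eq!(group_index, 0);
    ///
    /// // An out-of-range optimizer index is rejected
    /// assert!(parallel_optimizer.add_parameter_group(Array1::zeros(2), 5).is_err());
    /// ```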
    pub fn add_parameter_group(
        &mut self,
        params: Array<A, D>,
        optimizer_index: usize,
    ) -> Result<usize> {
        // Check that the optimizer index is valid
        if optimizer_index >= self.optimizers.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Invalid optimizer index: {}. Only {} optimizers available.",
                optimizer_index,
                self.optimizers.len()
            )));
        }

        let index = self.parameter_groups.len();
        self.parameter_groups
            .push(ParameterGroup::new(params, optimizer_index));
        Ok(index)
    }

    /// Get the number of optimizers
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get the number of parameter groups
    pub fn num_parameter_groups(&self) -> usize {
        self.parameter_groups.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }

    /// Get a reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
        self.parameter_groups.get(index)
    }

    /// Get a mutable reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A mutable reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
        self.parameter_groups.get_mut(index)
    }

    /// Get the current parameter values of all groups
    ///
    /// # Returns
    ///
    /// A result containing a clone of each group's parameters, in group order
    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
        Ok(self
            .parameter_groups
            .iter()
            .map(|group| group.params.clone())
            .collect())
    }

    /// Update all parameter groups using their assigned optimizers
    ///
    /// # Arguments
    ///
    /// * `gradients` - List of gradient arrays corresponding to parameter groups
    ///
    /// # Returns
    ///
    /// Result with the updated parameter values, or an error
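    ///
    /// # Example
    ///
    /// A minimal sketch, reusing the `SGD` and `Adam` constructors from the
    /// type-level example above:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
    /// use optirs_core::optimizers::{SGD, Adam};
    ///
    /// let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     ParallelOptimizer::new(
    ///         vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))],
    ///         vec![
    ///             ParameterGroup::new(Array1::zeros(2), 0),
    ///             ParameterGroup::new(Array1::zeros(3), 1),
    ///         ],
    ///     );
    ///
    /// // One gradient array per parameter group, in group order
    /// let gradients = vec![Array1::ones(2), Array1::ones(3)];
    /// let updated = parallel_optimizer.update_all_parameters(&gradients).unwrap();
    /// assert_eq!(updated.len(), 2);
    /// ```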
    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
        // Check if the number of gradients matches the number of parameter groups
        if gradients.len() != self.parameter_groups.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients.len(),
                self.parameter_groups.len()
            )));
        }

        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());

        // Update each parameter group using its assigned optimizer
        for (i, group) in self.parameter_groups.iter_mut().enumerate() {
            let optimizer_index = group.optimizer_index;

            // Check if the optimizer index is valid
            if optimizer_index >= self.optimizers.len() {
                return Err(OptimError::InvalidConfig(format!(
                    "Invalid optimizer index: {}. Only {} optimizers available.",
                    optimizer_index,
                    self.optimizers.len()
                )));
            }

            // Get the optimizer and update the parameters
            let optimizer = &mut self.optimizers[optimizer_index];
            let params = &group.params;
            let gradient = &gradients[i];

            // Update the parameters
            let updated = optimizer.step(params, gradient)?;
            group.params = updated.clone();
            updated_params.push(updated);
        }

        Ok(updated_params)
    }
}

impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, _params: &Array<A, D>, _gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // A single parameter array cannot be mapped onto multiple parameter groups,
        // so direct users to `update_all_parameters` instead.
        Err(OptimError::InvalidConfig(
            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
                .to_string(),
        ))
    }

    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // Guard against an empty optimizer list, which would make the index clamp below underflow
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "ParallelOptimizer has no optimizers".to_string(),
            ));
        }

        // Convert params_list to owned arrays
        let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();

        // Set parameter groups based on the input params
        self.parameter_groups = params_vec
            .into_iter()
            .enumerate()
            .map(|(i, params)| {
                // Clamp to the last optimizer if there are more parameter arrays than optimizers
                let optimizer_index = i.min(self.optimizers.len() - 1);
                ParameterGroup::new(params, optimizer_index)
            })
            .collect();

        // Convert gradients_list to owned arrays
        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();

        // Update parameter groups using their assigned optimizers
        self.update_all_parameters(&gradients_vec)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).unwrap() // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A chained composition of optimizers
///
/// This wraps one optimizer with another, similar to how Lookahead wraps
/// another optimizer. The inner optimizer is applied first, and then the
/// outer optimizer is applied to the result.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::ChainedOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let inner = SGD::new(0.1);
/// let outer = Adam::new(0.01);
///
/// // Chain them together
/// let mut chained_optimizer = ChainedOptimizer::new(Box::new(inner), Box::new(outer));
///
/// // Use the chained optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = chained_optimizer.step(&params, &gradients).unwrap();
/// ```
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The inner optimizer, applied first
    inner: Box<dyn Optimizer<A, D>>,
    /// The outer optimizer, applied to the result of the inner optimizer
    outer: Box<dyn Optimizer<A, D>>,
}

impl<A, D> ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new chained optimizer
    ///
    /// # Arguments
    ///
    /// * `inner` - The inner optimizer, applied first
    /// * `outer` - The outer optimizer, applied to the result of the inner optimizer
    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
        Self { inner, outer }
    }

    /// Get a reference to the inner optimizer
    pub fn inner(&self) -> &dyn Optimizer<A, D> {
        self.inner.as_ref()
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.inner.as_mut()
    }

    /// Get a reference to the outer optimizer
    pub fn outer(&self) -> &dyn Optimizer<A, D> {
        self.outer.as_ref()
    }

    /// Get a mutable reference to the outer optimizer
    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.outer.as_mut()
    }
}

impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Apply the inner optimizer first
        let intermediate_params = self.inner.step(params, gradients)?;

        // Then apply the outer optimizer to the result
        self.outer.step(&intermediate_params, gradients)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the inner optimizer
        self.inner.get_learning_rate()
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for both optimizers
        self.inner.set_learning_rate(learning_rate);
        self.outer.set_learning_rate(learning_rate);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the sequential optimizer
        let updated_params = seq_optimizer.step(&params, &gradients).unwrap();

        // Verify the result
        // First SGD updates: params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then Adam makes additional updates
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }
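
    // Additional check (a minimal sketch): stepping an empty sequence should
    // return the InvalidConfig error described in SequentialOptimizer::step.
    #[test]
    fn test_sequential_optimizer_empty() {
        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![]);

        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);

        assert!(seq_optimizer.step(&params, &gradients).is_err());
    }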

    #[test]
    fn test_parallel_optimizer() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0); // Use SGD
        let group2 = ParameterGroup::new(params2.clone(), 1); // Use Adam

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        // Create test gradients
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        // Update the parameters
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .unwrap();

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (Adam): The update will be different due to Adam's adaptive nature
        // Just verify it's different from the original params
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the chained optimizer
        let updated_params = chained_optimizer.step(&params, &gradients).unwrap();

        // Verify the result
        // Inner (SGD): params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then outer (Adam) applies another update
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Test getting the learning rate (should be from the first optimizer)
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for all optimizers
        seq_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer.get_optimizer(0).unwrap().get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer.get_optimizer(1).unwrap().get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        // Create test parameters and gradients
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        // Use step_list to update all parameters
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .unwrap();

        // Verify the results
        // Group 1 (SGD, index 0): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (index 1) uses Adam, so its update differs from a plain SGD step (-0.3)
        assert!(updated_params[1][0] != -0.3);

        // Group 3 (index 2) is clamped to the last optimizer (Adam)
        // Just check that it's been updated from zero
        assert!(updated_params[2][0] < 0.0);
    }
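
    // Additional check (a minimal sketch): the generic `step` method is
    // intentionally unsupported by ParallelOptimizer and should return an error.
    #[test]
    fn test_parallel_optimizer_step_unsupported() {
        let sgd = SGD::new(0.1);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd)], vec![]);

        let params = Array1::zeros(2);
        let gradients = Array1::ones(2);

        assert!(parallel_optimizer.step(&params, &gradients).is_err());
    }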

    #[test]
    fn test_chained_optimizer_learning_rate() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Test getting the learning rate (should be from the inner optimizer)
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for both optimizers
        chained_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }
}