// optirs_core/optimizer_composition/mod.rs
// Optimizer composition framework
//
// This module provides compositions of optimizers to create more sophisticated
// optimization strategies. It includes three main types of compositions:
//
// 1. **Sequential**: Apply multiple optimizers in sequence
// 2. **Parallel**: Apply different optimizers to different parameter groups
// 3. **Chained**: Wrap an optimizer with another (similar to Lookahead wrapping other optimizers)
10use crate::error::{OptimError, Result};
11use crate::optimizers::Optimizer;
12use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
13use scirs2_core::numeric::Float;
14use std::fmt::Debug;
15
/// A sequential composition of optimizers
///
/// This applies multiple optimizers in sequence to the same parameters.
/// Each optimizer's output becomes the input to the next optimizer; the
/// same raw gradients are fed to every stage.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::SequentialOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Combine them sequentially
/// let mut seq_optimizer = SequentialOptimizer::new(vec![
///     Box::new(sgd),
///     Box::new(adam),
/// ]);
///
/// // Use the sequential optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = seq_optimizer.step(&params, &gradients).expect("unwrap failed");
/// ```
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Optimizers applied in order; each receives the previous one's output.
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}
51
52impl<A, D> SequentialOptimizer<A, D>
53where
54    A: Float + ScalarOperand + Debug,
55    D: Dimension,
56{
57    /// Create a new sequential optimizer
58    ///
59    /// # Arguments
60    ///
61    /// * `optimizers` - List of optimizers to apply in sequence
62    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
63        Self { optimizers }
64    }
65
66    /// Add an optimizer to the sequence
67    ///
68    /// # Arguments
69    ///
70    /// * `optimizer` - The optimizer to add
71    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
72        self.optimizers.push(optimizer);
73    }
74
75    /// Get the number of optimizers in the sequence
76    pub fn num_optimizers(&self) -> usize {
77        self.optimizers.len()
78    }
79
80    /// Get a reference to an optimizer by index
81    ///
82    /// # Arguments
83    ///
84    /// * `index` - The index of the optimizer
85    ///
86    /// # Returns
87    ///
88    /// A reference to the optimizer at the given index, or None if out of bounds
89    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
90        if index < self.optimizers.len() {
91            Some(self.optimizers[index].as_ref())
92        } else {
93            None
94        }
95    }
96
97    /// Get a mutable reference to an optimizer by index
98    ///
99    /// # Arguments
100    ///
101    /// * `index` - The index of the optimizer
102    ///
103    /// # Returns
104    ///
105    /// A mutable reference to the optimizer at the given index, or None if out of bounds
106    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
107        if index < self.optimizers.len() {
108            Some(self.optimizers[index].as_mut())
109        } else {
110            None
111        }
112    }
113}
114
115impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
116where
117    A: Float + ScalarOperand + Debug,
118    D: Dimension,
119{
120    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
121        // Check if we have any optimizers
122        if self.optimizers.is_empty() {
123            return Err(OptimError::InvalidConfig(
124                "SequentialOptimizer has no optimizers".to_string(),
125            ));
126        }
127
128        // Start with the initial parameters
129        let mut current_params = params.clone();
130
131        // Apply each optimizer in sequence
132        for optimizer in &mut self.optimizers {
133            current_params = optimizer.step(&current_params, gradients)?;
134        }
135
136        Ok(current_params)
137    }
138
139    fn get_learning_rate(&self) -> A {
140        // Return the learning rate of the first optimizer, or a default if empty
141        if let Some(optimizer) = self.optimizers.first() {
142            optimizer.get_learning_rate()
143        } else {
144            A::from(0.01).expect("unwrap failed") // Default learning rate
145        }
146    }
147
148    fn set_learning_rate(&mut self, learningrate: A) {
149        // Set the learning _rate for all optimizers
150        for optimizer in &mut self.optimizers {
151            optimizer.set_learning_rate(learningrate);
152        }
153    }
154}
155
/// A struct for assigning parameters to specific groups for parallel optimization
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The parameters in this group
    pub params: Array<A, D>,
    /// Index into the owning `ParallelOptimizer`'s optimizer list that
    /// identifies which optimizer updates this group
    pub optimizerindex: usize,
}
167
168impl<A, D> ParameterGroup<A, D>
169where
170    A: Float + ScalarOperand + Debug,
171    D: Dimension,
172{
173    /// Create a new parameter group
174    ///
175    /// # Arguments
176    ///
177    /// * `params` - The parameters in this group
178    /// * `optimizerindex` - The index of the optimizer to use for this group
179    pub fn new(params: Array<A, D>, optimizerindex: usize) -> Self {
180        Self {
181            params,
182            optimizerindex,
183        }
184    }
185}
186
/// A parallel composition of optimizers
///
/// This applies different optimizers to different parameter groups.
/// Each group of parameters is updated using its assigned optimizer
/// (via `update_all_parameters` or `step_list`; the single-array `step`
/// method is not supported).
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Create parameter groups
/// let params1 = Array1::zeros(3);
/// let params2 = Array1::zeros(5);
///
/// let group1 = ParameterGroup::new(params1, 0); // Use SGD
/// let group2 = ParameterGroup::new(params2, 1); // Use Adam
///
/// // Combine them in parallel
/// let mut parallel_optimizer = ParallelOptimizer::new(
///     vec![Box::new(sgd), Box::new(adam)],
///     vec![group1, group2],
/// );
///
/// // The step method will update all parameter groups using their assigned optimizers
/// // (In a real use case, you'd provide the corresponding gradients)
/// ```
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply to different parameter groups
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Groups of parameters, each carrying the index of its assigned optimizer
    parameter_groups: Vec<ParameterGroup<A, D>>,
}
229
230impl<A, D> ParallelOptimizer<A, D>
231where
232    A: Float + ScalarOperand + Debug,
233    D: Dimension,
234{
235    /// Create a new parallel optimizer
236    ///
237    /// # Arguments
238    ///
239    /// * `optimizers` - List of optimizers to use
240    /// * `parameter_groups` - Groups of parameters with their assigned optimizer indices
241    pub fn new(
242        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
243        parameter_groups: Vec<ParameterGroup<A, D>>,
244    ) -> Self {
245        Self {
246            optimizers,
247            parameter_groups,
248        }
249    }
250
251    /// Add an optimizer
252    ///
253    /// # Arguments
254    ///
255    /// * `optimizer` - The optimizer to add
256    ///
257    /// # Returns
258    ///
259    /// The index of the added optimizer
260    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
261        let index = self.optimizers.len();
262        self.optimizers.push(optimizer);
263        index
264    }
265
266    /// Add a parameter group
267    ///
268    /// # Arguments
269    ///
270    /// * `params` - The parameters in this group
271    /// * `optimizerindex` - The index of the optimizer to use for this group
272    ///
273    /// # Returns
274    ///
275    /// Result with the index of the added parameter group, or an error if the optimizer index is invalid
276    pub fn add_parameter_group(
277        &mut self,
278        params: Array<A, D>,
279        optimizerindex: usize,
280    ) -> Result<usize> {
281        // Check if the optimizer _index is valid
282        if optimizerindex >= self.optimizers.len() {
283            return Err(OptimError::InvalidConfig(format!(
284                "Invalid optimizer _index: {}. Only {} optimizers available.",
285                optimizerindex,
286                self.optimizers.len()
287            )));
288        }
289
290        let _index = self.parameter_groups.len();
291        self.parameter_groups
292            .push(ParameterGroup::new(params, optimizerindex));
293        Ok(_index)
294    }
295
296    /// Get the number of optimizers
297    pub fn num_optimizers(&self) -> usize {
298        self.optimizers.len()
299    }
300
301    /// Get the number of parameter groups
302    pub fn num_parameter_groups(&self) -> usize {
303        self.parameter_groups.len()
304    }
305
306    /// Get a reference to an optimizer by index
307    ///
308    /// # Arguments
309    ///
310    /// * `index` - The index of the optimizer
311    ///
312    /// # Returns
313    ///
314    /// A reference to the optimizer at the given index, or None if out of bounds
315    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
316        if index < self.optimizers.len() {
317            Some(self.optimizers[index].as_ref())
318        } else {
319            None
320        }
321    }
322
323    /// Get a mutable reference to an optimizer by index
324    ///
325    /// # Arguments
326    ///
327    /// * `index` - The index of the optimizer
328    ///
329    /// # Returns
330    ///
331    /// A mutable reference to the optimizer at the given index, or None if out of bounds
332    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
333        if index < self.optimizers.len() {
334            Some(self.optimizers[index].as_mut())
335        } else {
336            None
337        }
338    }
339
340    /// Get a reference to a parameter group by index
341    ///
342    /// # Arguments
343    ///
344    /// * `index` - The index of the parameter group
345    ///
346    /// # Returns
347    ///
348    /// A reference to the parameter group at the given index, or None if out of bounds
349    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
350        self.parameter_groups.get(index)
351    }
352
353    /// Get a mutable reference to a parameter group by index
354    ///
355    /// # Arguments
356    ///
357    /// * `index` - The index of the parameter group
358    ///
359    /// # Returns
360    ///
361    /// A mutable reference to the parameter group at the given index, or None if out of bounds
362    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
363        self.parameter_groups.get_mut(index)
364    }
365
366    /// Get all current parameter values as a single array
367    ///
368    /// # Returns
369    ///
370    /// A result containing all parameter values concatenated into a single array
371    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
372        Ok(self
373            .parameter_groups
374            .iter()
375            .map(|group| group.params.clone())
376            .collect())
377    }
378
379    /// Update all parameter groups using their assigned optimizers
380    ///
381    /// # Arguments
382    ///
383    /// * `gradients` - List of gradient arrays corresponding to parameter groups
384    ///
385    /// # Returns
386    ///
387    /// Result with the updated parameter values, or an error
388    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
389        // Check if the number of gradients matches the number of parameter groups
390        if gradients.len() != self.parameter_groups.len() {
391            return Err(OptimError::InvalidConfig(format!(
392                "Number of gradients ({}) does not match number of parameter groups ({})",
393                gradients.len(),
394                self.parameter_groups.len()
395            )));
396        }
397
398        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());
399
400        // Update each parameter group using its assigned optimizer
401        for (i, group) in self.parameter_groups.iter_mut().enumerate() {
402            let optimizerindex = group.optimizerindex;
403
404            // Check if the optimizer index is valid
405            if optimizerindex >= self.optimizers.len() {
406                return Err(OptimError::InvalidConfig(format!(
407                    "Invalid optimizer index: {}. Only {} optimizers available.",
408                    optimizerindex,
409                    self.optimizers.len()
410                )));
411            }
412
413            // Get the optimizer and update the parameters
414            let optimizer = &mut self.optimizers[optimizerindex];
415            let params = &group.params;
416            let gradient = &gradients[i];
417
418            // Update the parameters
419            let updated = optimizer.step(params, gradient)?;
420            group.params = updated.clone();
421            updated_params.push(updated);
422        }
423
424        Ok(updated_params)
425    }
426}
427
428impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
429where
430    A: Float + ScalarOperand + Debug,
431    D: Dimension,
432{
433    fn step(&mut self, _params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
434        // This implementation is a bit tricky since we have multiple parameter groups
435        // We'll return an error message directing users to use update_all_parameters instead
436        Err(OptimError::InvalidConfig(
437            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
438                .to_string(),
439        ))
440    }
441
442    fn step_list(
443        &mut self,
444        params_list: &[&Array<A, D>],
445        gradients_list: &[&Array<A, D>],
446    ) -> Result<Vec<Array<A, D>>> {
447        // Convert params_list to owned arrays
448        let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();
449
450        // Set parameter groups based on the input params
451        self.parameter_groups = params_vec
452            .into_iter()
453            .enumerate()
454            .map(|(i, params)| {
455                // Use the first optimizer for all if there are more params than optimizers
456                let optimizerindex = i.min(self.optimizers.len() - 1);
457                ParameterGroup::new(params, optimizerindex)
458            })
459            .collect();
460
461        // Convert gradients_list to owned arrays
462        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();
463
464        // Update parameter groups using their assigned optimizers
465        self.update_all_parameters(&gradients_vec)
466    }
467
468    fn get_learning_rate(&self) -> A {
469        // Return the learning rate of the first optimizer, or a default if empty
470        if let Some(optimizer) = self.optimizers.first() {
471            optimizer.get_learning_rate()
472        } else {
473            A::from(0.01).expect("unwrap failed") // Default learning rate
474        }
475    }
476
477    fn set_learning_rate(&mut self, learningrate: A) {
478        // Set the learning _rate for all optimizers
479        for optimizer in &mut self.optimizers {
480            optimizer.set_learning_rate(learningrate);
481        }
482    }
483}
484
/// A chained composition of optimizers
///
/// This wraps one optimizer with another, similar to how Lookahead wraps
/// another optimizer. The inner optimizer is applied first, and then the
/// outer optimizer is applied to the result (both receive the same gradients).
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::ChainedOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let inner = SGD::new(0.1);
/// let outer = Adam::new(0.01);
///
/// // Chain them together
/// let mut chained_optimizer = ChainedOptimizer::new(Box::new(inner), Box::new(outer));
///
/// // Use the chained optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = chained_optimizer.step(&params, &gradients).expect("unwrap failed");
/// ```
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The inner optimizer, applied first
    inner: Box<dyn Optimizer<A, D>>,
    /// The outer optimizer, applied to the result of the inner optimizer
    outer: Box<dyn Optimizer<A, D>>,
}
520
521impl<A, D> ChainedOptimizer<A, D>
522where
523    A: Float + ScalarOperand + Debug,
524    D: Dimension,
525{
526    /// Create a new chained optimizer
527    ///
528    /// # Arguments
529    ///
530    /// * `inner` - The inner optimizer, applied first
531    /// * `outer` - The outer optimizer, applied to the result of the inner optimizer
532    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
533        Self { inner, outer }
534    }
535
536    /// Get a reference to the inner optimizer
537    pub fn inner(&self) -> &dyn Optimizer<A, D> {
538        self.inner.as_ref()
539    }
540
541    /// Get a mutable reference to the inner optimizer
542    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
543        self.inner.as_mut()
544    }
545
546    /// Get a reference to the outer optimizer
547    pub fn outer(&self) -> &dyn Optimizer<A, D> {
548        self.outer.as_ref()
549    }
550
551    /// Get a mutable reference to the outer optimizer
552    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
553        self.outer.as_mut()
554    }
555}
556
557impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
558where
559    A: Float + ScalarOperand + Debug,
560    D: Dimension,
561{
562    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
563        // Apply the inner optimizer first
564        let intermediate_params = self.inner.step(params, gradients)?;
565
566        // Then apply the outer optimizer to the result
567        self.outer.step(&intermediate_params, gradients)
568    }
569
570    fn get_learning_rate(&self) -> A {
571        // Return the learning rate of the inner optimizer
572        self.inner.get_learning_rate()
573    }
574
575    fn set_learning_rate(&mut self, learningrate: A) {
576        // Set the learning _rate for both optimizers
577        self.inner.set_learning_rate(learningrate);
578        self.outer.set_learning_rate(learningrate);
579    }
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the sequential optimizer
        let updated_params = seq_optimizer
            .step(&params, &gradients)
            .expect("unwrap failed");

        // Verify the result
        // First SGD updates: params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then Adam makes additional updates, pushing each element further negative
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0); // Use SGD
        let group2 = ParameterGroup::new(params2.clone(), 1); // Use Adam

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        // Create test gradients
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        // Update the parameters
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .expect("unwrap failed");

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (Adam): The update will be different due to Adam's adaptive nature
        // Just verify it's different from the original params
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the chained optimizer
        let updated_params = chained_optimizer
            .step(&params, &gradients)
            .expect("unwrap failed");

        // Verify the result
        // Inner (SGD): params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then outer (Adam) applies another update, pushing each element further negative
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Test getting the learning rate (should be from the first optimizer)
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for all optimizers
        seq_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(0)
                .expect("unwrap failed")
                .get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(1)
                .expect("unwrap failed")
                .get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        // Create test parameters and gradients
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        // Use step_list to update all parameters
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .expect("unwrap failed");

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 is assigned index 1.min(1) = 1 (Adam);
        // Adam's adaptive update differs from a plain SGD step
        assert!(updated_params[1][0] != -0.3);

        // Group 3's index is clamped to the last optimizer (2.min(1) = 1, Adam)
        // Just check that it's been updated from zero
        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Test getting the learning rate (should be from the inner optimizer)
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for both optimizers
        chained_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }
}