optirs_core/optimizer_composition/mod.rs
//! Optimizer composition framework
//!
//! This module provides compositions of optimizers to create more sophisticated
//! optimization strategies. It includes three main types of composition:
//!
//! 1. **Sequential**: apply multiple optimizers in sequence to the same parameters
//! 2. **Parallel**: apply different optimizers to different parameter groups
//! 3. **Chained**: wrap one optimizer with another (similar to how Lookahead wraps an inner optimizer)
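//!
//! All composed optimizers implement the `Optimizer` trait themselves, so
//! compositions can be nested. A minimal sketch of nesting (assuming the `SGD`
//! and `Adam` constructors used in the examples below):
//!
//! ```
//! use scirs2_core::ndarray::Array1;
//! use optirs_core::optimizer_composition::{ChainedOptimizer, SequentialOptimizer};
//! use optirs_core::optimizers::{Adam, Optimizer, SGD};
//!
//! // Inner stage: SGD followed by Adam on the same parameters.
//! let seq = SequentialOptimizer::new(vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))]);
//!
//! // Wrap the sequential stage with another SGD pass.
//! let mut chained: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
//!     ChainedOptimizer::new(Box::new(seq), Box::new(SGD::new(0.05)));
//!
//! let params = Array1::zeros(4);
//! let gradients = Array1::ones(4);
//! let updated = chained.step(&params, &gradients).unwrap();
//! assert_eq!(updated.len(), 4);
//! ```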

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// A sequential composition of optimizers
///
/// This applies multiple optimizers in sequence to the same parameters.
/// Each optimizer's output parameters become the input to the next optimizer,
/// while every stage receives the original gradients.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::SequentialOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Combine them sequentially
/// let mut seq_optimizer = SequentialOptimizer::new(vec![
///     Box::new(sgd),
///     Box::new(adam),
/// ]);
///
/// // Use the sequential optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = seq_optimizer.step(&params, &gradients).unwrap();
/// ```
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply in sequence
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}

impl<A, D> SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new sequential optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to apply in sequence
    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
        Self { optimizers }
    }

    /// Add an optimizer to the sequence
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
        self.optimizers.push(optimizer);
    }

    /// Get the number of optimizers in the sequence
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }
}

impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Check that we have at least one optimizer
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "SequentialOptimizer has no optimizers".to_string(),
            ));
        }

        // Start with the initial parameters
        let mut current_params = params.clone();

        // Apply each optimizer in sequence, feeding the updated parameters
        // (and the original gradients) to the next stage
        for optimizer in &mut self.optimizers {
            current_params = optimizer.step(&current_params, gradients)?;
        }

        Ok(current_params)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).unwrap() // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A group of parameters assigned to a specific optimizer for parallel optimization
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The parameters in this group
    pub params: Array<A, D>,
    /// The index of the optimizer to use for this group
    pub optimizerindex: usize,
}

impl<A, D> ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizerindex` - The index of the optimizer to use for this group
    pub fn new(params: Array<A, D>, optimizerindex: usize) -> Self {
        Self {
            params,
            optimizerindex,
        }
    }
}

/// A parallel composition of optimizers
///
/// This applies different optimizers to different parameter groups.
/// Each group of parameters is updated using its assigned optimizer.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Create parameter groups
/// let params1 = Array1::zeros(3);
/// let params2 = Array1::zeros(5);
///
/// let group1 = ParameterGroup::new(params1, 0); // Use SGD
/// let group2 = ParameterGroup::new(params2, 1); // Use Adam
///
/// // Combine them in parallel
/// let mut parallel_optimizer = ParallelOptimizer::new(
///     vec![Box::new(sgd), Box::new(adam)],
///     vec![group1, group2],
/// );
///
/// // Call `update_all_parameters` (or `step_list`) with one gradient array per
/// // group to update every group with its assigned optimizer.
/// ```
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply to different parameter groups
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Groups of parameters with their assigned optimizer indices
    parameter_groups: Vec<ParameterGroup<A, D>>,
}

impl<A, D> ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parallel optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to use
    /// * `parameter_groups` - Groups of parameters with their assigned optimizer indices
    pub fn new(
        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
        parameter_groups: Vec<ParameterGroup<A, D>>,
    ) -> Self {
        Self {
            optimizers,
            parameter_groups,
        }
    }

    /// Add an optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
    ///
    /// # Returns
    ///
    /// The index of the added optimizer
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
        let index = self.optimizers.len();
        self.optimizers.push(optimizer);
        index
    }

    /// Add a parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizerindex` - The index of the optimizer to use for this group
    ///
    /// # Returns
    ///
    /// Result with the index of the added parameter group, or an error if the optimizer index is invalid
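    ///
    /// # Example
    ///
    /// A minimal sketch (assuming the `SGD` constructor used in the examples above):
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizer_composition::ParallelOptimizer;
    /// use optirs_core::optimizers::SGD;
    ///
    /// let mut opt: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     ParallelOptimizer::new(vec![Box::new(SGD::new(0.1))], vec![]);
    ///
    /// // A valid optimizer index: the group is stored and its index is returned.
    /// let group_index = opt.add_parameter_group(Array1::zeros(3), 0).unwrap();
    /// assert_eq!(group_index, 0);
    ///
    /// // An out-of-range optimizer index is rejected.
    /// assert!(opt.add_parameter_group(Array1::zeros(3), 5).is_err());
    /// ```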
    pub fn add_parameter_group(
        &mut self,
        params: Array<A, D>,
        optimizerindex: usize,
    ) -> Result<usize> {
        // Check that the optimizer index is valid
        if optimizerindex >= self.optimizers.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Invalid optimizer index: {}. Only {} optimizers available.",
                optimizerindex,
                self.optimizers.len()
            )));
        }

        let index = self.parameter_groups.len();
        self.parameter_groups
            .push(ParameterGroup::new(params, optimizerindex));
        Ok(index)
    }

    /// Get the number of optimizers
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get the number of parameter groups
    pub fn num_parameter_groups(&self) -> usize {
        self.parameter_groups.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }

    /// Get a reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
        self.parameter_groups.get(index)
    }

    /// Get a mutable reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A mutable reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
        self.parameter_groups.get_mut(index)
    }

    /// Get the current parameter values of all groups
    ///
    /// # Returns
    ///
    /// A result containing a clone of each group's parameters, in group order
    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
        Ok(self
            .parameter_groups
            .iter()
            .map(|group| group.params.clone())
            .collect())
    }

    /// Update all parameter groups using their assigned optimizers
    ///
    /// # Arguments
    ///
    /// * `gradients` - List of gradient arrays, one per parameter group (in group order)
    ///
    /// # Returns
    ///
    /// Result with the updated parameter values, or an error
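    ///
    /// # Example
    ///
    /// A minimal usage sketch (assuming the `SGD` and `Adam` constructors used in
    /// the examples above):
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
    /// use optirs_core::optimizers::{Adam, SGD};
    ///
    /// let mut opt: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> = ParallelOptimizer::new(
    ///     vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))],
    ///     vec![
    ///         ParameterGroup::new(Array1::zeros(2), 0), // updated by SGD
    ///         ParameterGroup::new(Array1::zeros(3), 1), // updated by Adam
    ///     ],
    /// );
    ///
    /// // One gradient array per group, in group order.
    /// let updated = opt
    ///     .update_all_parameters(&[Array1::ones(2), Array1::ones(3)])
    ///     .unwrap();
    /// assert_eq!(updated.len(), 2);
    /// ```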
    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
        // Check that the number of gradients matches the number of parameter groups
        if gradients.len() != self.parameter_groups.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients.len(),
                self.parameter_groups.len()
            )));
        }

        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());

        // Update each parameter group using its assigned optimizer
        for (i, group) in self.parameter_groups.iter_mut().enumerate() {
            let optimizerindex = group.optimizerindex;

            // Check that the optimizer index is valid
            if optimizerindex >= self.optimizers.len() {
                return Err(OptimError::InvalidConfig(format!(
                    "Invalid optimizer index: {}. Only {} optimizers available.",
                    optimizerindex,
                    self.optimizers.len()
                )));
            }

            // Apply the assigned optimizer to this group's parameters
            let optimizer = &mut self.optimizers[optimizerindex];
            let params = &group.params;
            let gradient = &gradients[i];

            // Store the updated parameters back into the group and collect them
            let updated = optimizer.step(params, gradient)?;
            group.params = updated.clone();
            updated_params.push(updated);
        }

        Ok(updated_params)
    }
}

impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, _params: &Array<A, D>, _gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // A single parameter array is ambiguous when there are multiple parameter
        // groups, so direct callers to `update_all_parameters` (or `step_list`) instead
        Err(OptimError::InvalidConfig(
            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
                .to_string(),
        ))
    }

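    /// Update a list of parameter arrays in one call
    ///
    /// Parameter groups are rebuilt from `params_list`: group `i` is assigned
    /// optimizer `i`, clamped to the last optimizer when there are more parameter
    /// groups than optimizers.
    ///
    /// A minimal sketch (assuming the `SGD` and `Adam` constructors used in the
    /// examples above):
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizer_composition::ParallelOptimizer;
    /// use optirs_core::optimizers::{Adam, Optimizer, SGD};
    ///
    /// let mut opt: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     ParallelOptimizer::new(vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))], vec![]);
    ///
    /// let (p1, p2) = (Array1::zeros(2), Array1::zeros(3));
    /// let (g1, g2) = (Array1::ones(2), Array1::ones(3));
    ///
    /// // Group 0 is updated by SGD, group 1 by Adam.
    /// let updated = opt.step_list(&[&p1, &p2], &[&g1, &g2]).unwrap();
    /// assert_eq!(updated.len(), 2);
    /// ```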
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // Guard against an empty optimizer list before indexing into it
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "ParallelOptimizer has no optimizers".to_string(),
            ));
        }

        // Convert params_list to owned arrays
        let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();

        // Rebuild the parameter groups from the input params
        self.parameter_groups = params_vec
            .into_iter()
            .enumerate()
            .map(|(i, params)| {
                // Clamp to the last optimizer if there are more parameter groups than optimizers
                let optimizerindex = i.min(self.optimizers.len() - 1);
                ParameterGroup::new(params, optimizerindex)
            })
            .collect();

        // Convert gradients_list to owned arrays
        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();

        // Update parameter groups using their assigned optimizers
        self.update_all_parameters(&gradients_vec)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).unwrap() // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A chained composition of optimizers
///
/// This wraps one optimizer with another, similar to how Lookahead wraps
/// another optimizer. The inner optimizer is applied first, and then the
/// outer optimizer is applied to the result; both receive the original gradients.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::ChainedOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let inner = SGD::new(0.1);
/// let outer = Adam::new(0.01);
///
/// // Chain them together
/// let mut chained_optimizer = ChainedOptimizer::new(Box::new(inner), Box::new(outer));
///
/// // Use the chained optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = chained_optimizer.step(&params, &gradients).unwrap();
/// ```
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The inner optimizer, applied first
    inner: Box<dyn Optimizer<A, D>>,
    /// The outer optimizer, applied to the result of the inner optimizer
    outer: Box<dyn Optimizer<A, D>>,
}

impl<A, D> ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new chained optimizer
    ///
    /// # Arguments
    ///
    /// * `inner` - The inner optimizer, applied first
    /// * `outer` - The outer optimizer, applied to the result of the inner optimizer
    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
        Self { inner, outer }
    }

    /// Get a reference to the inner optimizer
    pub fn inner(&self) -> &dyn Optimizer<A, D> {
        self.inner.as_ref()
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.inner.as_mut()
    }

    /// Get a reference to the outer optimizer
    pub fn outer(&self) -> &dyn Optimizer<A, D> {
        self.outer.as_ref()
    }

    /// Get a mutable reference to the outer optimizer
    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.outer.as_mut()
    }
}

impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Apply the inner optimizer first
        let intermediate_params = self.inner.step(params, gradients)?;

        // Then apply the outer optimizer to the result, reusing the same gradients
        self.outer.step(&intermediate_params, gradients)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the inner optimizer
        self.inner.get_learning_rate()
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for both optimizers
        self.inner.set_learning_rate(learning_rate);
        self.outer.set_learning_rate(learning_rate);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the sequential optimizer
        let updated_params = seq_optimizer.step(&params, &gradients).unwrap();

        // Verify the result
        // First SGD updates: params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then Adam makes additional updates
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0); // Use SGD
        let group2 = ParameterGroup::new(params2.clone(), 1); // Use Adam

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        // Create test gradients
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        // Update the parameters
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .unwrap();

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (Adam): The update will be different due to Adam's adaptive nature
        // Just verify it's different from the original params
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the chained optimizer
        let updated_params = chained_optimizer.step(&params, &gradients).unwrap();

        // Verify the result
        // Inner (SGD): params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then outer (Adam) applies another update
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Test getting the learning rate (should be from the first optimizer)
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for all optimizers
        seq_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer.get_optimizer(0).unwrap().get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer.get_optimizer(1).unwrap().get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        // Create test parameters and gradients
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        // Use step_list to update all parameters
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .unwrap();

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 is assigned optimizer index 1 (Adam), so its update differs from a plain SGD step
        assert!(updated_params[1][0] != -0.3);

        // Group 3 is clamped to the last optimizer (Adam)
        // Just check that it's been updated from zero
        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Test getting the learning rate (should be from the inner optimizer)
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for both optimizers
        chained_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }
}