optirs_core/optimizer_composition/mod.rs

// Optimizer composition framework
//
// This module provides compositions of optimizers to create more sophisticated
// optimization strategies. It includes four main types of compositions:
//
// 1. **Sequential**: Apply multiple optimizers in sequence
// 2. **Parallel**: Apply different optimizers to different parameter groups
// 3. **Chained**: Wrap an optimizer with another (similar to Lookahead wrapping other optimizers)
// 4. **Weighted**: Run several optimizers on the same parameters and blend their outputs by weight
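//
// Because each composition implements the `Optimizer` trait itself, compositions
// can be nested. A minimal sketch (type annotations omitted; it reuses the `SGD`
// and `Adam` constructors shown in the examples below):
//
//     let inner = SequentialOptimizer::new(vec![Box::new(SGD::new(0.1))]);
//     let nested = ChainedOptimizer::new(Box::new(inner), Box::new(Adam::new(0.01)));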

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// A sequential composition of optimizers
///
/// This applies multiple optimizers in sequence to the same parameters.
/// Each optimizer's output becomes the input to the next optimizer.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::SequentialOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Combine them sequentially
/// let mut seq_optimizer = SequentialOptimizer::new(vec![
///     Box::new(sgd),
///     Box::new(adam),
/// ]);
///
/// // Use the sequential optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = seq_optimizer.step(&params, &gradients).expect("step failed");
/// ```
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply in sequence
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}

impl<A, D> SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new sequential optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to apply in sequence
    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
        Self { optimizers }
    }

    /// Add an optimizer to the sequence
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
        self.optimizers.push(optimizer);
    }

    /// Get the number of optimizers in the sequence
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }
}

impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Check if we have any optimizers
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "SequentialOptimizer has no optimizers".to_string(),
            ));
        }

        // Start with the initial parameters
        let mut current_params = params.clone();

        // Apply each optimizer in sequence
        for optimizer in &mut self.optimizers {
            current_params = optimizer.step(&current_params, gradients)?;
        }

        Ok(current_params)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate") // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A group of parameters paired with the index of the optimizer that updates it
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The parameters in this group
    pub params: Array<A, D>,
    /// The index of the optimizer to use for this group
    pub optimizerindex: usize,
}

impl<A, D> ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizerindex` - The index of the optimizer to use for this group
    pub fn new(params: Array<A, D>, optimizerindex: usize) -> Self {
        Self {
            params,
            optimizerindex,
        }
    }
}

/// A parallel composition of optimizers
///
/// This applies different optimizers to different parameter groups.
/// Each group of parameters is updated using its assigned optimizer.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::{ParallelOptimizer, ParameterGroup};
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let sgd = SGD::new(0.1);
/// let adam = Adam::new(0.01);
///
/// // Create parameter groups
/// let params1 = Array1::zeros(3);
/// let params2 = Array1::zeros(5);
///
/// let group1 = ParameterGroup::new(params1, 0); // Use SGD
/// let group2 = ParameterGroup::new(params2, 1); // Use Adam
///
/// // Combine them in parallel
/// let mut parallel_optimizer = ParallelOptimizer::new(
///     vec![Box::new(sgd), Box::new(adam)],
///     vec![group1, group2],
/// );
///
/// // ParallelOptimizer does not support the single-array `step` method; instead,
/// // call `update_all_parameters` with one gradient array per parameter group.
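/// //
/// // A minimal sketch of such an update (the gradient shapes are assumed to
/// // match the parameter shapes of the corresponding groups):
/// let gradients1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let gradients2 = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
/// let updated = parallel_optimizer
///     .update_all_parameters(&[gradients1, gradients2])
///     .expect("update failed");
/// assert_eq!(updated.len(), 2);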
/// ```
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// List of optimizers to apply to different parameter groups
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Groups of parameters with their assigned optimizer indices
    parameter_groups: Vec<ParameterGroup<A, D>>,
}

impl<A, D> ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new parallel optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizers` - List of optimizers to use
    /// * `parameter_groups` - Groups of parameters with their assigned optimizer indices
    pub fn new(
        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
        parameter_groups: Vec<ParameterGroup<A, D>>,
    ) -> Self {
        Self {
            optimizers,
            parameter_groups,
        }
    }

    /// Add an optimizer
    ///
    /// # Arguments
    ///
    /// * `optimizer` - The optimizer to add
    ///
    /// # Returns
    ///
    /// The index of the added optimizer
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
        let index = self.optimizers.len();
        self.optimizers.push(optimizer);
        index
    }

    /// Add a parameter group
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters in this group
    /// * `optimizerindex` - The index of the optimizer to use for this group
    ///
    /// # Returns
    ///
    /// Result with the index of the added parameter group, or an error if the optimizer index is invalid
    pub fn add_parameter_group(
        &mut self,
        params: Array<A, D>,
        optimizerindex: usize,
    ) -> Result<usize> {
        // Check if the optimizer index is valid
        if optimizerindex >= self.optimizers.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Invalid optimizer index: {}. Only {} optimizers available.",
                optimizerindex,
                self.optimizers.len()
            )));
        }

        let index = self.parameter_groups.len();
        self.parameter_groups
            .push(ParameterGroup::new(params, optimizerindex));
        Ok(index)
    }

    /// Get the number of optimizers
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get the number of parameter groups
    pub fn num_parameter_groups(&self) -> usize {
        self.parameter_groups.len()
    }

    /// Get a reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Get a mutable reference to an optimizer by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the optimizer
    ///
    /// # Returns
    ///
    /// A mutable reference to the optimizer at the given index, or None if out of bounds
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }

    /// Get a reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
        self.parameter_groups.get(index)
    }

    /// Get a mutable reference to a parameter group by index
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the parameter group
    ///
    /// # Returns
    ///
    /// A mutable reference to the parameter group at the given index, or None if out of bounds
    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
        self.parameter_groups.get_mut(index)
    }

    /// Get the current parameter values of all groups
    ///
    /// # Returns
    ///
    /// A result containing the parameter arrays of every group, in group order
    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
        Ok(self
            .parameter_groups
            .iter()
            .map(|group| group.params.clone())
            .collect())
    }

    /// Update all parameter groups using their assigned optimizers
    ///
    /// # Arguments
    ///
    /// * `gradients` - List of gradient arrays corresponding to parameter groups
    ///
    /// # Returns
    ///
    /// Result with the updated parameter values, or an error
    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
        // Check if the number of gradients matches the number of parameter groups
        if gradients.len() != self.parameter_groups.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients.len(),
                self.parameter_groups.len()
            )));
        }

        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());

        // Update each parameter group using its assigned optimizer
        for (i, group) in self.parameter_groups.iter_mut().enumerate() {
            let optimizerindex = group.optimizerindex;

            // Check if the optimizer index is valid
            if optimizerindex >= self.optimizers.len() {
                return Err(OptimError::InvalidConfig(format!(
                    "Invalid optimizer index: {}. Only {} optimizers available.",
                    optimizerindex,
                    self.optimizers.len()
                )));
            }

            // Get the optimizer and update the parameters
            let optimizer = &mut self.optimizers[optimizerindex];
            let params = &group.params;
            let gradient = &gradients[i];

            // Update the parameters
            let updated = optimizer.step(params, gradient)?;
            group.params = updated.clone();
            updated_params.push(updated);
        }

        Ok(updated_params)
    }
}

impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, _params: &Array<A, D>, _gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // A single parameter/gradient pair cannot be mapped onto multiple parameter
        // groups, so direct callers to update_all_parameters instead
        Err(OptimError::InvalidConfig(
            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
                .to_string(),
        ))
    }

    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // Guard against an empty optimizer list before assigning groups
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "ParallelOptimizer has no optimizers".to_string(),
            ));
        }

        // Convert params_list to owned arrays
        let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();

        // Set parameter groups based on the input params
        self.parameter_groups = params_vec
            .into_iter()
            .enumerate()
            .map(|(i, params)| {
                // Assign optimizers by position, clamping to the last optimizer
                // when there are more parameter groups than optimizers
                let optimizerindex = i.min(self.optimizers.len() - 1);
                ParameterGroup::new(params, optimizerindex)
            })
            .collect();

        // Convert gradients_list to owned arrays
        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();

        // Update parameter groups using their assigned optimizers
        self.update_all_parameters(&gradients_vec)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the first optimizer, or a default if empty
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate") // Default learning rate
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for all optimizers
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

/// A chained composition of optimizers
///
/// This wraps one optimizer with another, similar to how Lookahead wraps
/// another optimizer. The inner optimizer is applied first, and then the
/// outer optimizer is applied to the result.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::ChainedOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create optimizers
/// let inner = SGD::new(0.1);
/// let outer = Adam::new(0.01);
///
/// // Chain them together
/// let mut chained_optimizer = ChainedOptimizer::new(Box::new(inner), Box::new(outer));
///
/// // Use the chained optimizer
/// let params = Array1::zeros(5);
/// let gradients = Array1::ones(5);
/// let updated_params = chained_optimizer.step(&params, &gradients).expect("step failed");
/// ```
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The inner optimizer, applied first
    inner: Box<dyn Optimizer<A, D>>,
    /// The outer optimizer, applied to the result of the inner optimizer
    outer: Box<dyn Optimizer<A, D>>,
}

impl<A, D> ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new chained optimizer
    ///
    /// # Arguments
    ///
    /// * `inner` - The inner optimizer, applied first
    /// * `outer` - The outer optimizer, applied to the result of the inner optimizer
    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
        Self { inner, outer }
    }

    /// Get a reference to the inner optimizer
    pub fn inner(&self) -> &dyn Optimizer<A, D> {
        self.inner.as_ref()
    }

    /// Get a mutable reference to the inner optimizer
    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.inner.as_mut()
    }

    /// Get a reference to the outer optimizer
    pub fn outer(&self) -> &dyn Optimizer<A, D> {
        self.outer.as_ref()
    }

    /// Get a mutable reference to the outer optimizer
    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.outer.as_mut()
    }
}

impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Apply the inner optimizer first
        let intermediate_params = self.inner.step(params, gradients)?;

        // Then apply the outer optimizer to the result
        self.outer.step(&intermediate_params, gradients)
    }

    fn get_learning_rate(&self) -> A {
        // Return the learning rate of the inner optimizer
        self.inner.get_learning_rate()
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        // Set the learning rate for both optimizers
        self.inner.set_learning_rate(learning_rate);
        self.outer.set_learning_rate(learning_rate);
    }
}

/// A weighted composition of optimizers
///
/// Runs all optimizers on the same parameters/gradients and returns the
/// weighted average of their outputs. This allows blending the behavior
/// of multiple optimization strategies.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizer_composition::WeightedOptimizer;
/// use optirs_core::optimizers::{SGD, Adam, Optimizer};
///
/// // Create a weighted combination of SGD and Adam
/// let mut weighted = WeightedOptimizer::new()
///     .add_optimizer(Box::new(SGD::new(0.1)), 0.7)
///     .add_optimizer(Box::new(Adam::new(0.01)), 0.3);
///
/// let params = Array1::zeros(3);
/// let gradients = Array1::ones(3);
/// let updated = weighted.step(&params, &gradients).expect("step failed");
/// ```
pub struct WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// The optimizers to combine
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// The weight for each optimizer
    weights: Vec<A>,
}

impl<A, D> Default for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<A, D> WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Create a new empty weighted optimizer
    pub fn new() -> Self {
        Self {
            optimizers: Vec::new(),
            weights: Vec::new(),
        }
    }

    /// Add an optimizer with a given weight (builder pattern)
    ///
    /// # Arguments
    ///
    /// * `opt` - The optimizer to add
    /// * `weight` - The weight for this optimizer
    pub fn add_optimizer(mut self, opt: Box<dyn Optimizer<A, D>>, weight: A) -> Self {
        self.optimizers.push(opt);
        self.weights.push(weight);
        self
    }

    /// Add multiple optimizers at once (builder pattern)
    ///
    /// # Arguments
    ///
    /// * `opts` - A vector of (optimizer, weight) pairs
    pub fn with_optimizers(mut self, opts: Vec<(Box<dyn Optimizer<A, D>>, A)>) -> Self {
        for (opt, weight) in opts {
            self.optimizers.push(opt);
            self.weights.push(weight);
        }
        self
    }

    /// Normalize weights so they sum to 1
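    ///
    /// A small usage sketch (it mirrors `test_weighted_optimizer_normalize_weights`
    /// in this module's tests):
    ///
    /// ```
    /// use optirs_core::optimizer_composition::WeightedOptimizer;
    /// use optirs_core::optimizers::SGD;
    ///
    /// let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
    ///     WeightedOptimizer::new()
    ///         .add_optimizer(Box::new(SGD::new(0.1)), 2.0)
    ///         .add_optimizer(Box::new(SGD::new(0.2)), 8.0);
    ///
    /// weighted.normalize_weights();
    /// assert!((weighted.weights()[0] - 0.2).abs() < 1e-12);
    /// assert!((weighted.weights()[1] - 0.8).abs() < 1e-12);
    /// ```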
    pub fn normalize_weights(&mut self) {
        let sum: A = self.weights.iter().copied().fold(A::zero(), |a, b| a + b);
        if sum > A::zero() {
            for w in &mut self.weights {
                *w = *w / sum;
            }
        }
    }

    /// Get the number of optimizers
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Get the current weights
    pub fn weights(&self) -> &[A] {
        &self.weights
    }
}

impl<A, D> Optimizer<A, D> for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer has no optimizers".to_string(),
            ));
        }

        // Compute the weight sum for normalization
        let weight_sum: A = self.weights.iter().copied().fold(A::zero(), |a, b| a + b);
        if weight_sum <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer weight sum must be positive".to_string(),
            ));
        }

        // Run each optimizer and accumulate the weighted result
        let mut result: Option<Array<A, D>> = None;

        for (optimizer, &weight) in self.optimizers.iter_mut().zip(self.weights.iter()) {
            let updated = optimizer.step(params, gradients)?;
            let normalized_weight = weight / weight_sum;

            match result {
                None => {
                    result = Some(updated * normalized_weight);
                }
                Some(ref mut acc) => {
                    acc.zip_mut_with(&updated, |a, &b| {
                        *a = *a + b * normalized_weight;
                    });
                }
            }
        }

        result.ok_or_else(|| {
            OptimError::InvalidConfig("WeightedOptimizer produced no result".to_string())
        })
    }

    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate")
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the sequential optimizer
        let updated_params = seq_optimizer
            .step(&params, &gradients)
            .expect("step failed");

        // Verify the result
        // First SGD updates: params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then Adam makes additional updates
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0); // Use SGD
        let group2 = ParameterGroup::new(params2.clone(), 1); // Use Adam

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        // Create test gradients
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        // Update the parameters
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .expect("update failed");

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (Adam): The update will be different due to Adam's adaptive nature
        // Just verify it's different from the original params
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Create test parameters and gradients
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Apply the chained optimizer
        let updated_params = chained_optimizer
            .step(&params, &gradients)
            .expect("step failed");

        // Verify the result
        // Inner (SGD): params - 0.1 * gradients = [0, 0, 0] - 0.1 * [1, 2, 3] = [-0.1, -0.2, -0.3]
        // Then outer (Adam) applies another update
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        // Create a sequential optimizer with SGD followed by Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        // Test getting the learning rate (should be from the first optimizer)
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for all optimizers
        seq_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(0)
                .expect("optimizer 0 should exist")
                .get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(1)
                .expect("optimizer 1 should exist")
                .get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        // Create a parallel optimizer with SGD and Adam
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        // Create test parameters and gradients
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        // Use step_list to update all parameters
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .expect("step_list failed");

        // Verify the results
        // Group 1 (SGD): params - 0.1 * gradients = [0, 0] - 0.1 * [1, 2] = [-0.1, -0.2]
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // Group 2 (index 1) is assigned optimizer index min(1, 1) = 1, i.e. Adam
        // Adam's update differs from the plain SGD step
        assert!(updated_params[1][0] != -0.3);

        // Group 3 (index 2) is clamped to the last optimizer, so it also uses Adam
        // Just check that it has been updated away from zero
        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        // Create a chained optimizer with SGD as inner and Adam as outer
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        // Test getting the learning rate (should be from the inner optimizer)
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        // Test setting the learning rate for both optimizers
        chained_optimizer.set_learning_rate(0.05);

        // Verify the learning rate has been set for both optimizers
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_basic() {
        // Create two SGD optimizers with different learning rates
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);

        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 0.5)
                .add_optimizer(Box::new(sgd2), 0.5);

        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);

        let updated = weighted.step(&params, &gradients).expect("step failed");

        // SGD1: params - 0.1 * grads = [-0.1, -0.1, -0.1]
        // SGD2: params - 0.2 * grads = [-0.2, -0.2, -0.2]
        // Weighted avg (0.5 each): 0.5*(-0.1) + 0.5*(-0.2) = -0.15
        assert_abs_diff_eq!(updated[0], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[1], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[2], -0.15, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_unequal_weights() {
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);

        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 3.0)
                .add_optimizer(Box::new(sgd2), 1.0);

        let params = Array1::zeros(2);
        let gradients = Array1::ones(2);

        let updated = weighted.step(&params, &gradients).expect("step failed");

        // SGD1: [-0.1, -0.1], SGD2: [-0.2, -0.2]
        // Weights normalized: 3/4=0.75, 1/4=0.25
        // Result: 0.75*(-0.1) + 0.25*(-0.2) = -0.075 - 0.05 = -0.125
        assert_abs_diff_eq!(updated[0], -0.125, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_empty() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new();

        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);

        let result = weighted.step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_optimizer_normalize_weights() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 2.0)
                .add_optimizer(Box::new(SGD::new(0.2)), 8.0);

        weighted.normalize_weights();

        assert_abs_diff_eq!(weighted.weights()[0], 0.2, epsilon = 1e-10);
        assert_abs_diff_eq!(weighted.weights()[1], 0.8, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_learning_rate() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 1.0)
                .add_optimizer(Box::new(Adam::new(0.01)), 1.0);

        // Learning rate comes from the first optimizer
        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.1);

        // Setting learning rate applies to all
        weighted.set_learning_rate(0.05);
        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_with_optimizers() {
        let opts: Vec<(Box<dyn Optimizer<f64, scirs2_core::ndarray::Ix1>>, f64)> = vec![
            (Box::new(SGD::new(0.1)), 1.0),
            (Box::new(SGD::new(0.2)), 1.0),
        ];

        let weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new().with_optimizers(opts);

        assert_eq!(weighted.num_optimizers(), 2);
        assert_abs_diff_eq!(weighted.weights()[0], 1.0);
        assert_abs_diff_eq!(weighted.weights()[1], 1.0);
    }
1031}