use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

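/// Applies a pipeline of optimizers to the same parameters: each optimizer's
/// updated parameters are fed to the next one, using the original gradients.
///
/// A minimal usage sketch (assuming the crate's `SGD` and `Adam` optimizers,
/// as used in the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut seq: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
///     SequentialOptimizer::new(vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))]);
/// let params = Array1::zeros(3);
/// let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let updated = seq.step(&params, &gradients)?;
/// ```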
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}

impl<A, D> SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a new sequential optimizer from an ordered list of optimizers.
    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
        Self { optimizers }
    }

    /// Appends an optimizer to the end of the pipeline.
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
        self.optimizers.push(optimizer);
    }

    /// Returns the number of optimizers in the pipeline.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Returns a reference to the optimizer at `index`, if it exists.
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Returns a mutable reference to the optimizer at `index`, if it exists.
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }
}

impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "SequentialOptimizer has no optimizers".to_string(),
            ));
        }

        let mut current_params = params.clone();

        for optimizer in &mut self.optimizers {
            current_params = optimizer.step(&current_params, gradients)?;
        }

        Ok(current_params)
    }

    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate")
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

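/// A group of parameters bound to one of the optimizers in a
/// [`ParallelOptimizer`] via `optimizer_index`.
///
/// A minimal sketch (the second argument is the index of the optimizer that
/// should update this group):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let group: ParameterGroup<f64, scirs2_core::ndarray::Ix1> =
///     ParameterGroup::new(Array1::zeros(2), 0);
/// ```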
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    pub params: Array<A, D>,
    pub optimizer_index: usize,
}

impl<A, D> ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a new parameter group assigned to the optimizer at `optimizer_index`.
    pub fn new(params: Array<A, D>, optimizer_index: usize) -> Self {
        Self {
            params,
            optimizer_index,
        }
    }
}

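/// Updates several parameter groups in parallel, each with its own optimizer.
///
/// A minimal usage sketch (assuming the crate's `SGD` and `Adam` optimizers,
/// as used in the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let group1 = ParameterGroup::new(Array1::zeros(2), 0); // updated by SGD
/// let group2 = ParameterGroup::new(Array1::zeros(3), 1); // updated by Adam
/// let mut parallel: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
///     ParallelOptimizer::new(
///         vec![Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01))],
///         vec![group1, group2],
///     );
///
/// let grads1 = Array1::from_vec(vec![1.0, 2.0]);
/// let grads2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
/// let updated = parallel.update_all_parameters(&[grads1, grads2])?;
/// ```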
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    parameter_groups: Vec<ParameterGroup<A, D>>,
}

impl<A, D> ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a new parallel optimizer from optimizers and parameter groups.
    pub fn new(
        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
        parameter_groups: Vec<ParameterGroup<A, D>>,
    ) -> Self {
        Self {
            optimizers,
            parameter_groups,
        }
    }

    /// Adds an optimizer and returns its index.
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
        let index = self.optimizers.len();
        self.optimizers.push(optimizer);
        index
    }

    /// Adds a parameter group assigned to the optimizer at `optimizer_index`
    /// and returns the group's index.
    pub fn add_parameter_group(
        &mut self,
        params: Array<A, D>,
        optimizer_index: usize,
    ) -> Result<usize> {
        if optimizer_index >= self.optimizers.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Invalid optimizer index: {}. Only {} optimizers available.",
                optimizer_index,
                self.optimizers.len()
            )));
        }

        let index = self.parameter_groups.len();
        self.parameter_groups
            .push(ParameterGroup::new(params, optimizer_index));
        Ok(index)
    }

    /// Returns the number of optimizers.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Returns the number of parameter groups.
    pub fn num_parameter_groups(&self) -> usize {
        self.parameter_groups.len()
    }

    /// Returns a reference to the optimizer at `index`, if it exists.
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_ref())
        } else {
            None
        }
    }

    /// Returns a mutable reference to the optimizer at `index`, if it exists.
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        if index < self.optimizers.len() {
            Some(self.optimizers[index].as_mut())
        } else {
            None
        }
    }

    /// Returns a reference to the parameter group at `index`, if it exists.
    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
        self.parameter_groups.get(index)
    }

    /// Returns a mutable reference to the parameter group at `index`, if it exists.
    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
        self.parameter_groups.get_mut(index)
    }

    /// Returns a clone of every parameter group's parameters.
    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
        Ok(self
            .parameter_groups
            .iter()
            .map(|group| group.params.clone())
            .collect())
    }

    /// Updates every parameter group with its assigned optimizer, using one
    /// gradient array per group, and returns the updated parameters.
    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
        if gradients.len() != self.parameter_groups.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients.len(),
                self.parameter_groups.len()
            )));
        }

        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());

        for (i, group) in self.parameter_groups.iter_mut().enumerate() {
            let optimizer_index = group.optimizer_index;

            if optimizer_index >= self.optimizers.len() {
                return Err(OptimError::InvalidConfig(format!(
                    "Invalid optimizer index: {}. Only {} optimizers available.",
                    optimizer_index,
                    self.optimizers.len()
                )));
            }

            let optimizer = &mut self.optimizers[optimizer_index];
            let params = &group.params;
            let gradient = &gradients[i];

            let updated = optimizer.step(params, gradient)?;
            group.params = updated.clone();
            updated_params.push(updated);
        }

        Ok(updated_params)
    }
}

impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, _params: &Array<A, D>, _gradients: &Array<A, D>) -> Result<Array<A, D>> {
        Err(OptimError::InvalidConfig(
            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
                .to_string(),
        ))
    }

    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        let params_vec: Vec<Array<A, D>> = params_list.iter().map(|&p| p.clone()).collect();

        // Rebuild the parameter groups, assigning each group the optimizer with
        // the same index (clamped to the last optimizer when there are more
        // groups than optimizers).
        self.parameter_groups = params_vec
            .into_iter()
            .enumerate()
            .map(|(i, params)| {
                let optimizer_index = i.min(self.optimizers.len().saturating_sub(1));
                ParameterGroup::new(params, optimizer_index)
            })
            .collect();

        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();

        self.update_all_parameters(&gradients_vec)
    }

    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate")
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

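/// Chains two optimizers: the `inner` optimizer's updated parameters are
/// passed to the `outer` optimizer, which performs a second step with the
/// same gradients.
///
/// A minimal usage sketch (assuming the crate's `SGD` and `Adam` optimizers,
/// as used in the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut chained: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
///     ChainedOptimizer::new(Box::new(SGD::new(0.1)), Box::new(Adam::new(0.01)));
/// let params = Array1::zeros(3);
/// let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// let updated = chained.step(&params, &gradients)?;
/// ```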
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    inner: Box<dyn Optimizer<A, D>>,
    outer: Box<dyn Optimizer<A, D>>,
}

impl<A, D> ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a new chained optimizer from an inner and an outer optimizer.
    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
        Self { inner, outer }
    }

    /// Returns a reference to the inner optimizer.
    pub fn inner(&self) -> &dyn Optimizer<A, D> {
        self.inner.as_ref()
    }

    /// Returns a mutable reference to the inner optimizer.
    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.inner.as_mut()
    }

    /// Returns a reference to the outer optimizer.
    pub fn outer(&self) -> &dyn Optimizer<A, D> {
        self.outer.as_ref()
    }

    /// Returns a mutable reference to the outer optimizer.
    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        self.outer.as_mut()
    }
}

impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let intermediate_params = self.inner.step(params, gradients)?;

        self.outer.step(&intermediate_params, gradients)
    }

    fn get_learning_rate(&self) -> A {
        self.inner.get_learning_rate()
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.inner.set_learning_rate(learning_rate);
        self.outer.set_learning_rate(learning_rate);
    }
}

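/// Combines several optimizers by stepping each one from the same starting
/// parameters and averaging the results with (normalized) weights.
///
/// A minimal usage sketch (assuming the crate's `SGD` optimizer, as used in
/// the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
///     WeightedOptimizer::new()
///         .add_optimizer(Box::new(SGD::new(0.1)), 0.5)
///         .add_optimizer(Box::new(SGD::new(0.2)), 0.5);
/// let params = Array1::zeros(3);
/// let gradients = Array1::ones(3);
/// let updated = weighted.step(&params, &gradients)?; // each element is -0.15
/// ```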
pub struct WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    weights: Vec<A>,
}

impl<A, D> Default for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<A, D> WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates an empty weighted optimizer.
    pub fn new() -> Self {
        Self {
            optimizers: Vec::new(),
            weights: Vec::new(),
        }
    }

    /// Adds an optimizer with the given weight (builder style).
    pub fn add_optimizer(mut self, opt: Box<dyn Optimizer<A, D>>, weight: A) -> Self {
        self.optimizers.push(opt);
        self.weights.push(weight);
        self
    }

    /// Adds several optimizer/weight pairs at once (builder style).
    pub fn with_optimizers(mut self, opts: Vec<(Box<dyn Optimizer<A, D>>, A)>) -> Self {
        for (opt, weight) in opts {
            self.optimizers.push(opt);
            self.weights.push(weight);
        }
        self
    }

    /// Normalizes the weights so they sum to one (no-op if the sum is not positive).
    pub fn normalize_weights(&mut self) {
        let sum: A = self.weights.iter().copied().fold(A::zero(), |a, b| a + b);
        if sum > A::zero() {
            for w in &mut self.weights {
                *w = *w / sum;
            }
        }
    }

    /// Returns the number of optimizers.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Returns the current weights.
    pub fn weights(&self) -> &[A] {
        &self.weights
    }
}

impl<A, D> Optimizer<A, D> for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer has no optimizers".to_string(),
            ));
        }

        let weight_sum: A = self.weights.iter().copied().fold(A::zero(), |a, b| a + b);
        if weight_sum <= A::zero() {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer weight sum must be positive".to_string(),
            ));
        }

        let mut result: Option<Array<A, D>> = None;

        // Step every optimizer from the same starting parameters and accumulate
        // the weighted average of the updated parameters.
        for (optimizer, &weight) in self.optimizers.iter_mut().zip(self.weights.iter()) {
            let updated = optimizer.step(params, gradients)?;
            let normalized_weight = weight / weight_sum;

            match result {
                None => {
                    result = Some(updated * normalized_weight);
                }
                Some(ref mut acc) => {
                    acc.zip_mut_with(&updated, |a, &b| {
                        *a = *a + b * normalized_weight;
                    });
                }
            }
        }

        result.ok_or_else(|| {
            OptimError::InvalidConfig("WeightedOptimizer produced no result".to_string())
        })
    }

    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate")
        }
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let updated_params = seq_optimizer
            .step(&params, &gradients)
            .expect("sequential step failed");

        // SGD moves each parameter by -0.1 * gradient; Adam then moves it further negative.
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);

        let group1 = ParameterGroup::new(params1.clone(), 0);
        let group2 = ParameterGroup::new(params2.clone(), 1);
        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);

        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .expect("update_all_parameters failed");

        // The first group is updated by SGD with learning rate 0.1.
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // The second group is updated by Adam; exact values depend on its internals.
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let updated_params = chained_optimizer
            .step(&params, &gradients)
            .expect("chained step failed");

        // The inner SGD step gives -0.1 * gradient; the outer Adam step moves further negative.
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);

        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);

        seq_optimizer.set_learning_rate(0.05);

        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(0)
                .expect("missing optimizer 0")
                .get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(1)
                .expect("missing optimizer 1")
                .get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);

        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);

        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);

        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);

        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];

        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .expect("step_list failed");

        // The first group is assigned to SGD (learning rate 0.1).
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);

        // The second and third groups fall back to the last optimizer (Adam),
        // so they are not plain SGD updates.
        assert!(updated_params[1][0] != -0.3);

        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);

        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));

        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);

        chained_optimizer.set_learning_rate(0.05);

        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_basic() {
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);

        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 0.5)
                .add_optimizer(Box::new(sgd2), 0.5);

        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);

        let updated = weighted.step(&params, &gradients).expect("step failed");

        // Equal weights: 0.5 * (-0.1) + 0.5 * (-0.2) = -0.15 per element.
        assert_abs_diff_eq!(updated[0], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[1], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[2], -0.15, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_unequal_weights() {
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);

        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 3.0)
                .add_optimizer(Box::new(sgd2), 1.0);

        let params = Array1::zeros(2);
        let gradients = Array1::ones(2);

        let updated = weighted.step(&params, &gradients).expect("step failed");

        // Normalized weights 0.75 and 0.25: 0.75 * (-0.1) + 0.25 * (-0.2) = -0.125.
        assert_abs_diff_eq!(updated[0], -0.125, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_empty() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new();

        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);

        let result = weighted.step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_optimizer_normalize_weights() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 2.0)
                .add_optimizer(Box::new(SGD::new(0.2)), 8.0);

        weighted.normalize_weights();

        assert_abs_diff_eq!(weighted.weights()[0], 0.2, epsilon = 1e-10);
        assert_abs_diff_eq!(weighted.weights()[1], 0.8, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_learning_rate() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 1.0)
                .add_optimizer(Box::new(Adam::new(0.01)), 1.0);

        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.1);

        weighted.set_learning_rate(0.05);
        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_with_optimizers() {
        let opts: Vec<(Box<dyn Optimizer<f64, scirs2_core::ndarray::Ix1>>, f64)> = vec![
            (Box::new(SGD::new(0.1)), 1.0),
            (Box::new(SGD::new(0.2)), 1.0),
        ];

        let weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new().with_optimizers(opts);

        assert_eq!(weighted.num_optimizers(), 2);
        assert_abs_diff_eq!(weighted.weights()[0], 1.0);
        assert_abs_diff_eq!(weighted.weights()[1], 1.0);
    }
}