1#[allow(dead_code)]
8use crate::error::Result;
9use crate::optimizers::*;
10use crate::schedulers::*;
11use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
12use scirs2_core::numeric::Float;
13use scirs2_core::random::{thread_rng, Rng};
14use std::collections::{HashMap, VecDeque};
15use std::fmt::Debug;
16use std::time::{Duration, Instant};
17
/// Configuration for the self-tuning optimizer: adaptation windows,
/// thresholds, and which automatic tuning passes are enabled.
#[derive(Debug, Clone)]
pub struct SelfTuningConfig {
    /// Number of recent `PerformanceStats` samples retained for evaluation.
    pub evaluation_window: usize,

    /// Absolute difference in the windowed metric below which progress is
    /// considered stalled (triggers optimizer adaptation).
    pub improvement_threshold: f64,

    /// Upper bound on optimizer switches within a single epoch.
    pub max_switches_per_epoch: usize,

    /// Enables heuristic learning-rate adjustment based on gradient norms.
    pub auto_lr_adjustment: bool,

    /// Enables automatic switching between optimizer candidates.
    pub auto_optimizer_selection: bool,

    /// Enables batch-size tuning (not referenced by the visible code —
    /// presumably reserved for `maybe_adapt_hyperparameters`; TODO confirm).
    pub auto_batch_size_tuning: bool,

    /// Number of initial steps during which no adaptation is performed.
    pub warmup_steps: usize,

    /// Probability of exploring a random candidate in epsilon-greedy selection.
    pub exploration_rate: f64,

    /// Multiplicative decay factor for the exploration rate (not applied by
    /// the visible code — TODO confirm intended use).
    pub exploration_decay: f64,

    /// Metric the tuner optimizes (see `TargetMetric`).
    pub target_metric: TargetMetric,
}
51
52impl Default for SelfTuningConfig {
53 fn default() -> Self {
54 Self {
55 evaluation_window: 100,
56 improvement_threshold: 0.01,
57 max_switches_per_epoch: 3,
58 auto_lr_adjustment: true,
59 auto_optimizer_selection: true,
60 auto_batch_size_tuning: false,
61 warmup_steps: 1000,
62 exploration_rate: 0.1,
63 exploration_decay: 0.99,
64 target_metric: TargetMetric::Loss,
65 }
66 }
67}
68
/// Metric the self-tuning optimizer steers toward.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TargetMetric {
    /// Minimize training loss.
    Loss,
    /// Maximize accuracy (only usable when `PerformanceStats::accuracy` is set).
    Accuracy,
    /// Minimize per-step wall-clock time.
    ConvergenceTime,
    /// Maximize throughput.
    Throughput,
    /// Use the first entry of `PerformanceStats::custom_metrics`
    /// (higher is treated as better).
    Custom,
}
83
/// A single training-step performance sample fed to the tuner.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Training loss for this step.
    pub loss: f64,

    /// Accuracy for this step, if the caller measures it.
    pub accuracy: Option<f64>,

    /// Norm of the gradients (used by the LR-adjustment heuristic).
    pub gradient_norm: f64,

    /// Throughput for this step (units defined by the caller — e.g.
    /// samples/sec; not interpreted further here).
    pub throughput: f64,

    /// Memory usage reported by the caller (units defined by the caller).
    pub memory_usage: f64,

    /// Wall-clock duration of this step.
    pub step_time: Duration,

    /// Learning rate in effect when the sample was taken.
    pub learning_rate: f64,

    /// Name of the optimizer that produced this step.
    pub optimizer_type: String,

    /// Arbitrary caller-defined metrics (used by `TargetMetric::Custom`).
    pub custom_metrics: HashMap<String, f64>,
}
114
/// An optimizer wrapper that monitors training performance and adapts its own
/// configuration: switching between optimizer candidates, adjusting the
/// learning rate, and (eventually) tuning other hyperparameters.
pub struct SelfTuningOptimizer<A: Float, D: Dimension> {
    /// Tuning behavior configuration.
    config: SelfTuningConfig,

    /// The optimizer currently applied in `step`.
    current_optimizer: Box<dyn OptimizerTrait<A, D>>,

    /// Pool of optimizers that can be switched to.
    optimizer_candidates: Vec<OptimizerCandidate<A, D>>,

    /// Sliding window of recent samples, bounded by `config.evaluation_window`.
    performance_history: VecDeque<PerformanceStats>,

    /// Hyperparameter-search state (LR bounds, batch size, algorithm).
    search_state: HyperparameterSearchState,

    /// Optional external LR scheduler (not used by the visible code —
    /// TODO confirm intended wiring).
    lr_scheduler: Option<Box<dyn LearningRateScheduler<A>>>,

    /// Strategy used to pick the next optimizer when adapting.
    selection_strategy: OptimizerSelectionStrategy,

    /// Total number of `step` calls.
    step_count: usize,

    /// Optimizer switches performed since the last `reset_epoch`.
    switches_this_epoch: usize,

    /// Best observed value of the target metric so far.
    best_performance: Option<f64>,

    /// Timestamp of the last adaptation (only assigned at construction in the
    /// visible code — TODO confirm).
    last_adaptation_time: Instant,

    /// Multi-armed-bandit bookkeeping, index-aligned with `optimizer_candidates`.
    bandit_state: BanditState,
}
153
/// A selectable optimizer plus the bookkeeping used to rank it.
struct OptimizerCandidate<A: Float, D: Dimension> {
    /// Display name (e.g. "Adam").
    name: String,

    /// Factory producing a fresh instance of this optimizer.
    factory: Box<dyn Fn() -> Box<dyn OptimizerTrait<A, D>>>,

    /// Performance samples recorded for this candidate.
    performance_history: Vec<f64>,

    /// How many times this candidate has been switched to.
    usage_count: usize,

    /// Average of `performance_history` (used by performance-based selection).
    average_performance: f64,

    /// (lower, upper) confidence bounds on the average performance.
    confidence_interval: (f64, f64),
}
174
/// Mutable state of the hyperparameter search.
#[derive(Debug)]
struct HyperparameterSearchState {
    /// Learning rate currently tracked by the tuner.
    learning_rate: f64,

    /// (min, max) bounds that `maybe_adapt_learning_rate` clamps to.
    lr_bounds: (f64, f64),

    /// Current batch size (batch-size tuning is not implemented yet).
    batch_size: usize,

    /// (min, max) bounds for batch-size tuning.
    batch_size_bounds: (usize, usize),

    /// Number of search iterations performed so far.
    search_iterations: usize,

    /// Best hyperparameter assignment found, keyed by parameter name.
    best_hyperparameters: HashMap<String, f64>,

    /// Algorithm driving the search.
    search_algorithm: SearchAlgorithm,
}
199
/// Supported hyperparameter search algorithms.
#[derive(Debug)]
enum SearchAlgorithm {
    /// Uniform random search with a fixed RNG seed.
    Random {
        seed: u64,
    },

    /// Bayesian optimization backed by a Gaussian-process surrogate.
    Bayesian {
        gp_state: GaussianProcessState,
    },

    /// Exhaustive grid search over a discretized space.
    Grid {
        /// Current position in each grid dimension.
        position: Vec<usize>,
        /// Number of points per dimension.
        dimensions: Vec<usize>,
    },

    /// Successive-halving over a set of candidate configurations.
    SuccessiveHalving {
        /// Current halving bracket.
        bracket: usize,
        /// Surviving configurations for this bracket.
        configurations: Vec<HashMap<String, f64>>,
    },
}
231
/// Gaussian-process surrogate state for Bayesian search.
#[derive(Debug)]
struct GaussianProcessState {
    /// Hyperparameter vectors evaluated so far.
    observed_points: Vec<Vec<f64>>,

    /// Objective values corresponding to `observed_points`.
    observed_values: Vec<f64>,

    /// Kernel hyperparameters of the GP.
    kernel_params: Vec<f64>,

    /// Acquisition function used to propose the next point.
    acquisition_function: AcquisitionFunction,
}
247
/// Acquisition functions for Bayesian hyperparameter search.
#[derive(Debug, Clone, Copy)]
enum AcquisitionFunction {
    /// Expected improvement over the current best observation.
    ExpectedImprovement,
    /// Probability of improving on the current best observation.
    ProbabilityOfImprovement,
    /// Mean prediction plus a confidence-scaled exploration bonus.
    UpperConfidenceBound,
}
255
/// How the next optimizer is chosen when adaptation triggers.
#[derive(Debug, Clone)]
enum OptimizerSelectionStrategy {
    /// Treat candidates as bandit arms and pick via the given algorithm.
    MultiArmedBandit {
        algorithm: BanditAlgorithm,
    },

    /// Pick the candidate with the best recorded average performance.
    PerformanceBased {
        /// Minimum performance gap required to justify a switch.
        min_difference: f64,
    },

    /// Cycle through candidates in order.
    RoundRobin {
        /// Index of the most recently chosen candidate.
        current_index: usize,
    },

    /// Choose based on learned problem-to-optimizer mappings
    /// (selection is currently stubbed — see `select_optimizer_meta_learning`).
    MetaLearning {
        /// Feature vector describing the current problem.
        problem_features: Vec<f64>,
        /// Learned scores per optimizer name.
        optimizer_mappings: HashMap<String, f64>,
    },
}
285
/// Multi-armed-bandit algorithms available for optimizer selection.
#[derive(Debug, Clone, Copy)]
enum BanditAlgorithm {
    /// Explore with fixed probability, otherwise exploit the best estimate.
    EpsilonGreedy,
    /// Upper Confidence Bound (UCB1).
    UCB1,
    /// Sampling around each arm's reward estimate.
    ThompsonSampling,
    /// Contextual UCB (currently falls back to UCB1 — see `select_linucb`).
    LinUCB,
}
294
/// Bandit bookkeeping; all vectors are index-aligned with
/// `SelfTuningOptimizer::optimizer_candidates`.
#[derive(Debug)]
struct BanditState {
    /// Estimated reward per arm.
    reward_estimates: Vec<f64>,

    /// Confidence bound per arm (used as a sampling half-width).
    confidence_bounds: Vec<f64>,

    /// Number of times each arm has been selected.
    selection_counts: Vec<usize>,

    /// Total selections across all arms.
    total_selections: usize,

    /// Exploration strength used in the UCB1 bonus term.
    exploration_param: f64,
}
313
/// Object-safe interface every switchable optimizer must implement.
pub trait OptimizerTrait<A: Float + ScalarOperand + Debug, D: Dimension>: Send + Sync {
    /// Human-readable optimizer name (e.g. "Adam").
    fn name(&self) -> &str;

    /// Applies one update step to each (parameter, gradient) pair.
    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()>;

    /// Returns the current learning rate.
    fn learning_rate(&self) -> A;

    /// Sets the learning rate.
    fn set_learning_rate(&mut self, lr: A);

    /// Serializes optimizer state as opaque byte blobs keyed by name.
    fn get_state(&self) -> HashMap<String, Vec<u8>>;

    /// Restores state previously produced by `get_state`.
    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()>;

    /// Clones this optimizer into a boxed trait object.
    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>>;
}
337
338impl<
339 A: Float + ScalarOperand + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
340 D: Dimension + 'static,
341 > SelfTuningOptimizer<A, D>
342{
343 pub fn new(config: SelfTuningConfig) -> Result<Self> {
345 let mut optimizer_candidates = Vec::new();
346
347 optimizer_candidates.push(OptimizerCandidate {
349 name: "Adam".to_string(),
350 factory: Box::new(|| Box::new(AdamOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.0))),
351 performance_history: Vec::new(),
352 usage_count: 0,
353 average_performance: 0.0,
354 confidence_interval: (0.0, 0.0),
355 });
356
357 optimizer_candidates.push(OptimizerCandidate {
358 name: "SGD".to_string(),
359 factory: Box::new(|| Box::new(SGDOptimizerWrapper::new(0.01, 0.9, 0.0, false))),
360 performance_history: Vec::new(),
361 usage_count: 0,
362 average_performance: 0.0,
363 confidence_interval: (0.0, 0.0),
364 });
365
366 optimizer_candidates.push(OptimizerCandidate {
367 name: "AdamW".to_string(),
368 factory: Box::new(|| {
369 Box::new(AdamWOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.01))
370 }),
371 performance_history: Vec::new(),
372 usage_count: 0,
373 average_performance: 0.0,
374 confidence_interval: (0.0, 0.0),
375 });
376
377 let current_optimizer = (optimizer_candidates[0].factory)();
379
380 let search_state = HyperparameterSearchState {
381 learning_rate: 0.001,
382 lr_bounds: (1e-6, 1.0),
383 batch_size: 32,
384 batch_size_bounds: (8, 512),
385 search_iterations: 0,
386 best_hyperparameters: HashMap::new(),
387 search_algorithm: SearchAlgorithm::Random { seed: 42 },
388 };
389
390 let selection_strategy = OptimizerSelectionStrategy::MultiArmedBandit {
391 algorithm: BanditAlgorithm::UCB1,
392 };
393
394 let bandit_state = BanditState {
395 reward_estimates: vec![0.0; optimizer_candidates.len()],
396 confidence_bounds: vec![1.0; optimizer_candidates.len()],
397 selection_counts: vec![0; optimizer_candidates.len()],
398 total_selections: 0,
399 exploration_param: 2.0,
400 };
401
402 Ok(Self {
403 config,
404 current_optimizer,
405 optimizer_candidates,
406 performance_history: VecDeque::new(),
407 search_state,
408 lr_scheduler: None,
409 selection_strategy,
410 step_count: 0,
411 switches_this_epoch: 0,
412 best_performance: None,
413 last_adaptation_time: Instant::now(),
414 bandit_state,
415 })
416 }
417
418 pub fn add_optimizer_candidate<F>(&mut self, name: String, factory: F)
420 where
421 F: Fn() -> Box<dyn OptimizerTrait<A, D>> + 'static,
422 {
423 self.optimizer_candidates.push(OptimizerCandidate {
424 name,
425 factory: Box::new(factory),
426 performance_history: Vec::new(),
427 usage_count: 0,
428 average_performance: 0.0,
429 confidence_interval: (0.0, 0.0),
430 });
431
432 self.bandit_state.reward_estimates.push(0.0);
434 self.bandit_state.confidence_bounds.push(1.0);
435 self.bandit_state.selection_counts.push(0);
436 }
437
438 pub fn step(
440 &mut self,
441 params: &mut [Array<A, D>],
442 grads: &[Array<A, D>],
443 stats: PerformanceStats,
444 ) -> Result<()> {
445 self.step_count += 1;
446
447 self.performance_history.push_back(stats.clone());
449 if self.performance_history.len() > self.config.evaluation_window {
450 self.performance_history.pop_front();
451 }
452
453 self.current_optimizer.step(params, grads)?;
455
456 if self.step_count > self.config.warmup_steps {
458 self.maybe_adapt_optimizer(&stats)?;
459 self.maybe_adapt_learning_rate(&stats)?;
460 self.maybe_adapt_hyperparameters(&stats)?;
461 }
462
463 let current_performance = self.extract_performance_metric(&stats);
465 if let Some(performance) = current_performance {
466 if self.best_performance.is_none()
467 || self.is_better_performance(
468 performance,
469 self.best_performance.expect("unwrap failed"),
470 )
471 {
472 self.best_performance = Some(performance);
473 }
474 }
475
476 Ok(())
477 }
478
479 fn maybe_adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
481 if !self.config.auto_optimizer_selection {
482 return Ok(());
483 }
484
485 if self.switches_this_epoch >= self.config.max_switches_per_epoch {
486 return Ok(());
487 }
488
489 let should_adapt = self.should_adapt_optimizer(stats);
490
491 if should_adapt {
492 self.adapt_optimizer(stats)?;
493 self.switches_this_epoch += 1;
494 }
495
496 Ok(())
497 }
498
499 fn should_adapt_optimizer(&self, stats: &PerformanceStats) -> bool {
501 if self.performance_history.len() < self.config.evaluation_window / 2 {
502 return false;
503 }
504
505 let recent_performance: Vec<f64> = self
507 .performance_history
508 .iter()
509 .rev()
510 .take(self.config.evaluation_window / 4)
511 .filter_map(|s| self.extract_performance_metric(s))
512 .collect();
513
514 let older_performance: Vec<f64> = self
515 .performance_history
516 .iter()
517 .rev()
518 .skip(self.config.evaluation_window / 4)
519 .take(self.config.evaluation_window / 4)
520 .filter_map(|s| self.extract_performance_metric(s))
521 .collect();
522
523 if recent_performance.is_empty() || older_performance.is_empty() {
524 return false;
525 }
526
527 let recent_avg = recent_performance.iter().sum::<f64>() / recent_performance.len() as f64;
528 let older_avg = older_performance.iter().sum::<f64>() / older_performance.len() as f64;
529
530 match self.config.target_metric {
532 TargetMetric::Loss => {
533 (recent_avg - older_avg).abs() < self.config.improvement_threshold
534 || recent_avg > older_avg
535 }
536 TargetMetric::Accuracy | TargetMetric::Throughput => {
537 (recent_avg - older_avg).abs() < self.config.improvement_threshold
538 || recent_avg < older_avg
539 }
540 _ => false,
541 }
542 }
543
544 fn adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
546 let new_optimizer_idx = match &self.selection_strategy {
547 OptimizerSelectionStrategy::MultiArmedBandit { algorithm } => {
548 self.select_optimizer_bandit(*algorithm)
549 }
550 OptimizerSelectionStrategy::PerformanceBased { .. } => {
551 self.select_optimizer_performance_based()
552 }
553 OptimizerSelectionStrategy::RoundRobin { current_index } => {
554 (*current_index + 1) % self.optimizer_candidates.len()
555 }
556 OptimizerSelectionStrategy::MetaLearning { .. } => {
557 self.select_optimizer_meta_learning(stats)
558 }
559 };
560
561 if new_optimizer_idx < self.optimizer_candidates.len() {
563 let current_lr = self.current_optimizer.learning_rate();
564 let current_state = self.current_optimizer.get_state();
565
566 self.current_optimizer = (self.optimizer_candidates[new_optimizer_idx].factory)();
567 self.current_optimizer.set_learning_rate(current_lr);
568
569 if self.current_optimizer.set_state(current_state).is_err() {
571 }
573
574 self.optimizer_candidates[new_optimizer_idx].usage_count += 1;
576 }
577
578 Ok(())
579 }
580
581 fn select_optimizer_bandit(&mut self, algorithm: BanditAlgorithm) -> usize {
583 match algorithm {
584 BanditAlgorithm::UCB1 => self.select_ucb1(),
585 BanditAlgorithm::EpsilonGreedy => self.select_epsilon_greedy(),
586 BanditAlgorithm::ThompsonSampling => self.select_thompson_sampling(),
587 BanditAlgorithm::LinUCB => self.select_linucb(),
588 }
589 }
590
591 fn select_ucb1(&self) -> usize {
593 if self.bandit_state.total_selections == 0 {
594 return 0;
595 }
596
597 let mut best_score = f64::NEG_INFINITY;
598 let mut best_idx = 0;
599
600 for (i, candidate) in self.optimizer_candidates.iter().enumerate() {
601 let ucb_score = if self.bandit_state.selection_counts[i] == 0 {
602 f64::INFINITY
603 } else {
604 let mean_reward = self.bandit_state.reward_estimates[i];
605 let confidence = (self.bandit_state.exploration_param
606 * (self.bandit_state.total_selections as f64).ln()
607 / self.bandit_state.selection_counts[i] as f64)
608 .sqrt();
609 mean_reward + confidence
610 };
611
612 if ucb_score > best_score {
613 best_score = ucb_score;
614 best_idx = i;
615 }
616 }
617
618 best_idx
619 }
620
621 fn select_epsilon_greedy(&self) -> usize {
623 let mut rng = thread_rng();
624
625 if A::from(rng.random::<f64>()).expect("unwrap failed")
626 < A::from(self.config.exploration_rate).expect("unwrap failed")
627 {
628 rng.gen_range(0..self.optimizer_candidates.len())
630 } else {
631 self.bandit_state
633 .reward_estimates
634 .iter()
635 .enumerate()
636 .max_by(|a, b| a.1.partial_cmp(b.1).expect("unwrap failed"))
637 .map(|(idx, _)| idx)
638 .unwrap_or(0)
639 }
640 }
641
642 fn select_thompson_sampling(&self) -> usize {
644 let mut rng = thread_rng();
646
647 let mut best_sample = f64::NEG_INFINITY;
648 let mut best_idx = 0;
649
650 for (i, _) in self.optimizer_candidates.iter().enumerate() {
651 let mean = self.bandit_state.reward_estimates[i];
652 let std = self.bandit_state.confidence_bounds[i];
653 let sample = rng.gen_range(mean - std..mean + std);
654
655 if sample > best_sample {
656 best_sample = sample;
657 best_idx = i;
658 }
659 }
660
661 best_idx
662 }
663
664 fn select_linucb(&self) -> usize {
666 self.select_ucb1()
668 }
669
670 fn select_optimizer_performance_based(&self) -> usize {
672 self.optimizer_candidates
673 .iter()
674 .enumerate()
675 .max_by(|a, b| {
676 a.1.average_performance
677 .partial_cmp(&b.1.average_performance)
678 .expect("unwrap failed")
679 })
680 .map(|(idx, _)| idx)
681 .unwrap_or(0)
682 }
683
684 fn select_optimizer_meta_learning(&self, stats: &PerformanceStats) -> usize {
686 0
688 }
689
690 fn maybe_adapt_learning_rate(&mut self, stats: &PerformanceStats) -> Result<()> {
692 if !self.config.auto_lr_adjustment {
693 return Ok(());
694 }
695
696 let current_lr = self
698 .current_optimizer
699 .learning_rate()
700 .to_f64()
701 .expect("unwrap failed");
702 let gradient_norm = stats.gradient_norm;
703
704 let new_lr = if gradient_norm > 10.0 {
705 current_lr * 0.9
707 } else if gradient_norm < 0.1 {
708 current_lr * 1.1
710 } else {
711 current_lr
712 };
713
714 let clamped_lr = new_lr
715 .max(self.search_state.lr_bounds.0)
716 .min(self.search_state.lr_bounds.1);
717
718 if (clamped_lr - current_lr).abs() > current_lr * 0.01 {
719 self.current_optimizer
720 .set_learning_rate(A::from(clamped_lr).expect("unwrap failed"));
721 self.search_state.learning_rate = clamped_lr;
722 }
723
724 Ok(())
725 }
726
727 fn maybe_adapt_hyperparameters(&mut self, stats: &PerformanceStats) -> Result<()> {
729 Ok(())
732 }
733
734 fn extract_performance_metric(&self, stats: &PerformanceStats) -> Option<f64> {
736 match self.config.target_metric {
737 TargetMetric::Loss => Some(stats.loss),
738 TargetMetric::Accuracy => stats.accuracy,
739 TargetMetric::Throughput => Some(stats.throughput),
740 TargetMetric::ConvergenceTime => Some(stats.step_time.as_secs_f64()),
741 TargetMetric::Custom => stats.custom_metrics.values().next().copied(),
742 }
743 }
744
745 fn is_better_performance(&self, new_perf: f64, oldperf: f64) -> bool {
747 match self.config.target_metric {
748 TargetMetric::Loss | TargetMetric::ConvergenceTime => new_perf < oldperf,
749 TargetMetric::Accuracy | TargetMetric::Throughput => new_perf > oldperf,
750 TargetMetric::Custom => new_perf > oldperf, }
752 }
753
754 pub fn reset_epoch(&mut self) {
756 self.switches_this_epoch = 0;
757 }
758
759 pub fn get_optimizer_info(&self) -> OptimizerInfo {
761 OptimizerInfo {
762 name: self.current_optimizer.name().to_string(),
763 learning_rate: self
764 .current_optimizer
765 .learning_rate()
766 .to_f64()
767 .expect("unwrap failed"),
768 step_count: self.step_count,
769 switches_this_epoch: self.switches_this_epoch,
770 performance_window_size: self.performance_history.len(),
771 best_performance: self.best_performance,
772 }
773 }
774
775 pub fn get_statistics(&self) -> SelfTuningStatistics {
777 let optimizer_usage: HashMap<String, usize> = self
778 .optimizer_candidates
779 .iter()
780 .map(|c| (c.name.clone(), c.usage_count))
781 .collect();
782
783 SelfTuningStatistics {
784 total_steps: self.step_count,
785 total_optimizer_switches: self
786 .optimizer_candidates
787 .iter()
788 .map(|c| c.usage_count)
789 .sum(),
790 optimizer_usage,
791 current_learning_rate: self.search_state.learning_rate,
792 average_step_time: self
793 .performance_history
794 .iter()
795 .map(|s| s.step_time.as_secs_f64())
796 .sum::<f64>()
797 / self.performance_history.len().max(1) as f64,
798 exploration_rate: self.config.exploration_rate,
799 }
800 }
801}
802
/// Snapshot of the tuner's current optimizer, returned by `get_optimizer_info`.
#[derive(Debug, Clone)]
pub struct OptimizerInfo {
    /// Name of the active optimizer.
    pub name: String,
    /// Active learning rate as `f64`.
    pub learning_rate: f64,
    /// Total `step` calls so far.
    pub step_count: usize,
    /// Optimizer switches since the last `reset_epoch`.
    pub switches_this_epoch: usize,
    /// Number of samples currently held in the evaluation window.
    pub performance_window_size: usize,
    /// Best observed value of the target metric, if any.
    pub best_performance: Option<f64>,
}
813
/// Aggregated tuning statistics, returned by `get_statistics`.
#[derive(Debug, Clone)]
pub struct SelfTuningStatistics {
    /// Total `step` calls so far.
    pub total_steps: usize,
    /// Sum of usage counts over all optimizer candidates.
    pub total_optimizer_switches: usize,
    /// Usage count per candidate name.
    pub optimizer_usage: HashMap<String, usize>,
    /// Learning rate currently tracked by the search state.
    pub current_learning_rate: f64,
    /// Mean step time (seconds) over the current evaluation window.
    pub average_step_time: f64,
    /// Configured epsilon-greedy exploration rate.
    pub exploration_rate: f64,
}
824
/// Adapts `crate::optimizers::Adam` to the object-safe `OptimizerTrait`.
struct AdamOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// The wrapped Adam optimizer.
    inner: crate::optimizers::Adam<A>,
    // Ties the wrapper to a dimension type without storing a `D` value.
    _phantom: std::marker::PhantomData<D>,
}
830
831impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
832 AdamOptimizerWrapper<A, D>
833{
834 fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
835 Self {
836 inner: crate::optimizers::Adam::new_with_config(
837 A::from(_lr).expect("unwrap failed"),
838 A::from(beta1).expect("unwrap failed"),
839 A::from(beta2).expect("unwrap failed"),
840 A::from(eps).expect("unwrap failed"),
841 A::from(weightdecay).expect("unwrap failed"),
842 ),
843 _phantom: std::marker::PhantomData,
844 }
845 }
846}
847
848impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
849 OptimizerTrait<A, D> for AdamOptimizerWrapper<A, D>
850{
851 fn name(&self) -> &str {
852 "Adam"
853 }
854
855 fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
856 if params.len() != grads.len() {
857 return Err(crate::error::OptimError::InvalidParameter(
858 "Mismatched number of parameters and gradients".to_string(),
859 ));
860 }
861
862 for (param, grad) in params.iter_mut().zip(grads.iter()) {
863 let updated = self.inner.step(param, grad)?;
864 *param = updated;
865 }
866 Ok(())
867 }
868
869 fn learning_rate(&self) -> A {
870 self.inner.learning_rate()
871 }
872
873 fn set_learning_rate(&mut self, lr: A) {
874 <crate::optimizers::Adam<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
875 &mut self.inner,
876 lr,
877 );
878 }
879
880 fn get_state(&self) -> HashMap<String, Vec<u8>> {
881 HashMap::new()
882 }
883
884 fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
885 Ok(())
886 }
887
888 fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
889 Box::new(AdamOptimizerWrapper {
890 inner: self.inner.clone(),
891 _phantom: std::marker::PhantomData,
892 })
893 }
894}
895
/// Adapts `crate::optimizers::SGD` to the object-safe `OptimizerTrait`.
struct SGDOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// The wrapped SGD optimizer.
    inner: crate::optimizers::SGD<A>,
    // Ties the wrapper to a dimension type without storing a `D` value.
    _phantom: std::marker::PhantomData<D>,
}
900
901impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
902 SGDOptimizerWrapper<A, D>
903{
904 fn new(_lr: f64, momentum: f64, weightdecay: f64, nesterov: bool) -> Self {
905 Self {
906 inner: crate::optimizers::SGD::new_with_config(
907 A::from(_lr).expect("unwrap failed"),
908 A::from(momentum).expect("unwrap failed"),
909 A::from(weightdecay).expect("unwrap failed"),
910 ),
911 _phantom: std::marker::PhantomData,
912 }
913 }
914}
915
916impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
917 OptimizerTrait<A, D> for SGDOptimizerWrapper<A, D>
918{
919 fn name(&self) -> &str {
920 "SGD"
921 }
922
923 fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
924 if params.len() != grads.len() {
925 return Err(crate::error::OptimError::InvalidParameter(
926 "Mismatched number of parameters and gradients".to_string(),
927 ));
928 }
929
930 for (param, grad) in params.iter_mut().zip(grads.iter()) {
931 let updated = self.inner.step(param, grad)?;
932 *param = updated;
933 }
934 Ok(())
935 }
936
937 fn learning_rate(&self) -> A {
938 self.inner.learning_rate()
939 }
940
941 fn set_learning_rate(&mut self, lr: A) {
942 <crate::optimizers::SGD<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
943 &mut self.inner,
944 lr,
945 );
946 }
947
948 fn get_state(&self) -> HashMap<String, Vec<u8>> {
949 HashMap::new()
950 }
951
952 fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
953 Ok(())
954 }
955
956 fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
957 Box::new(SGDOptimizerWrapper {
958 inner: self.inner.clone(),
959 _phantom: std::marker::PhantomData,
960 })
961 }
962}
963
/// Adapts `crate::optimizers::AdamW` to the object-safe `OptimizerTrait`.
struct AdamWOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// The wrapped AdamW optimizer.
    inner: crate::optimizers::AdamW<A>,
    // Ties the wrapper to a dimension type without storing a `D` value.
    _phantom: std::marker::PhantomData<D>,
}
968
969impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
970 AdamWOptimizerWrapper<A, D>
971{
972 fn new(_lr: f64, beta1: f64, beta2: f64, eps: f64, weightdecay: f64) -> Self {
973 Self {
974 inner: crate::optimizers::AdamW::new_with_config(
975 A::from(_lr).expect("unwrap failed"),
976 A::from(beta1).expect("unwrap failed"),
977 A::from(beta2).expect("unwrap failed"),
978 A::from(eps).expect("unwrap failed"),
979 A::from(weightdecay).expect("unwrap failed"),
980 ),
981 _phantom: std::marker::PhantomData,
982 }
983 }
984}
985
986impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
987 OptimizerTrait<A, D> for AdamWOptimizerWrapper<A, D>
988{
989 fn name(&self) -> &str {
990 "AdamW"
991 }
992
993 fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
994 if params.len() != grads.len() {
995 return Err(crate::error::OptimError::InvalidParameter(
996 "Mismatched number of parameters and gradients".to_string(),
997 ));
998 }
999
1000 for (param, grad) in params.iter_mut().zip(grads.iter()) {
1001 let updated = self.inner.step(param, grad)?;
1002 *param = updated;
1003 }
1004 Ok(())
1005 }
1006
1007 fn learning_rate(&self) -> A {
1008 self.inner.learning_rate()
1009 }
1010
1011 fn set_learning_rate(&mut self, lr: A) {
1012 <crate::optimizers::AdamW<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
1013 &mut self.inner,
1014 lr,
1015 );
1016 }
1017
1018 fn get_state(&self) -> HashMap<String, Vec<u8>> {
1019 HashMap::new()
1020 }
1021
1022 fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()> {
1023 Ok(())
1024 }
1025
1026 fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
1027 Box::new(AdamWOptimizerWrapper {
1028 inner: self.inner.clone(),
1029 _phantom: std::marker::PhantomData,
1030 })
1031 }
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::time::Duration;

    #[test]
    fn test_self_tuning_config_default() {
        // Defaults enable both automatic LR tuning and optimizer selection.
        let cfg = SelfTuningConfig::default();
        assert!(cfg.auto_lr_adjustment);
        assert!(cfg.auto_optimizer_selection);
        assert_eq!(cfg.evaluation_window, 100);
    }

    #[test]
    fn test_self_tuning_optimizer_creation() {
        // Construction with the default configuration must succeed.
        let built: Result<SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1>> =
            SelfTuningOptimizer::new(SelfTuningConfig::default());
        assert!(built.is_ok());
    }

    #[test]
    fn test_performance_stats() {
        // PerformanceStats is a plain data carrier; fields round-trip as-is.
        let sample = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
            learning_rate: 0.001,
            step_time: Duration::from_millis(50),
            memory_usage: 1024.0,
            throughput: 100.0,
            gradient_norm: 1.2,
            accuracy: Some(0.9),
            loss: 0.5,
        };

        assert_eq!(sample.loss, 0.5);
        assert_eq!(sample.accuracy, Some(0.9));
    }

    #[test]
    fn test_optimizer_step() {
        let mut tuner: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        let mut params = vec![Array1::zeros(10)];
        let grads = vec![Array1::ones(10)];

        let sample = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
            learning_rate: 0.001,
            step_time: Duration::from_millis(10),
            memory_usage: 512.0,
            throughput: 50.0,
            gradient_norm: 1.0,
            accuracy: None,
            loss: 1.0,
        };

        assert!(tuner.step(&mut params, &grads, sample).is_ok());

        // After one step the default optimizer (Adam) is still active.
        let info = tuner.get_optimizer_info();
        assert_eq!(info.name, "Adam");
        assert_eq!(info.step_count, 1);
    }

    #[test]
    fn test_bandit_selection() {
        let tuner: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        // UCB1 must always yield an index inside the candidate pool.
        assert!(tuner.select_ucb1() < tuner.optimizer_candidates.len());
    }

    #[test]
    fn test_performance_metric_extraction() {
        let cfg = SelfTuningConfig {
            target_metric: TargetMetric::Loss,
            ..Default::default()
        };
        let tuner: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(cfg).expect("unwrap failed");

        let sample = PerformanceStats {
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
            learning_rate: 0.001,
            step_time: Duration::from_millis(20),
            memory_usage: 800.0,
            throughput: 75.0,
            gradient_norm: 1.1,
            accuracy: Some(0.85),
            loss: 0.8,
        };

        // With TargetMetric::Loss the extracted metric is the raw loss value.
        assert_eq!(tuner.extract_performance_metric(&sample), Some(0.8));
    }

    #[test]
    fn test_statistics() {
        let tuner: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(SelfTuningConfig::default()).expect("unwrap failed");

        let summary = tuner.get_statistics();
        assert_eq!(summary.total_steps, 0);
        assert!(summary.optimizer_usage.contains_key("Adam"));
    }
}
1148}