//! Self-tuning optimizer: automatically adapts its own configuration during
//! training by switching between candidate optimizers, adjusting the learning
//! rate, and tuning hyperparameters based on observed performance statistics.

#![allow(dead_code)]

use crate::error::Result;
use crate::optimizers::*;
use crate::schedulers::*;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::time::{Duration, Instant};

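/// Configuration for [`SelfTuningOptimizer`].
///
/// All fields have defaults (see the `Default` impl below). A minimal sketch
/// of overriding a few of them with struct-update syntax (the chosen values
/// are illustrative, not recommendations):
///
/// ```ignore
/// let config = SelfTuningConfig {
///     evaluation_window: 50,
///     warmup_steps: 500,
///     target_metric: TargetMetric::Accuracy,
///     ..Default::default()
/// };
/// ```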
#[derive(Debug, Clone)]
pub struct SelfTuningConfig {
    /// Number of recent steps kept for performance evaluation.
    pub evaluation_window: usize,
    /// Minimum change in the target metric that counts as progress.
    pub improvement_threshold: f64,
    /// Maximum number of optimizer switches allowed per epoch.
    pub max_switches_per_epoch: usize,
    /// Enable automatic learning-rate adjustment.
    pub auto_lr_adjustment: bool,
    /// Enable automatic optimizer selection.
    pub auto_optimizer_selection: bool,
    /// Enable automatic batch-size tuning.
    pub auto_batch_size_tuning: bool,
    /// Number of warmup steps before any adaptation kicks in.
    pub warmup_steps: usize,
    /// Exploration rate for optimizer selection.
    pub exploration_rate: f64,
    /// Multiplicative decay applied to the exploration rate.
    pub exploration_decay: f64,
    /// Metric that adaptation decisions optimize for.
    pub target_metric: TargetMetric,
}

impl Default for SelfTuningConfig {
    fn default() -> Self {
        Self {
            evaluation_window: 100,
            improvement_threshold: 0.01,
            max_switches_per_epoch: 3,
            auto_lr_adjustment: true,
            auto_optimizer_selection: true,
            auto_batch_size_tuning: false,
            warmup_steps: 1000,
            exploration_rate: 0.1,
            exploration_decay: 0.99,
            target_metric: TargetMetric::Loss,
        }
    }
}

/// Metric that the self-tuning logic targets.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TargetMetric {
    /// Minimize training loss.
    Loss,
    /// Maximize accuracy.
    Accuracy,
    /// Minimize time per step (used as a proxy for convergence time).
    ConvergenceTime,
    /// Maximize throughput (e.g. samples per second).
    Throughput,
    /// Use the first entry of `PerformanceStats::custom_metrics`.
    Custom,
}

/// Performance statistics reported to the tuner after each training step.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Training loss at this step.
    pub loss: f64,
    /// Accuracy, if available.
    pub accuracy: Option<f64>,
    /// Gradient norm at this step.
    pub gradient_norm: f64,
    /// Throughput (e.g. samples per second).
    pub throughput: f64,
    /// Memory usage (implementation-defined units).
    pub memory_usage: f64,
    /// Wall-clock time of this step.
    pub step_time: Duration,
    /// Learning rate used for this step.
    pub learning_rate: f64,
    /// Name of the optimizer that produced this step.
    pub optimizer_type: String,
    /// Arbitrary user-defined metrics.
    pub custom_metrics: HashMap<String, f64>,
}

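/// An optimizer that tunes itself during training: it tracks recent
/// [`PerformanceStats`], switches between candidate optimizers via a
/// multi-armed bandit, and nudges the learning rate based on gradient norms.
///
/// A minimal usage sketch (the surrounding training loop and the `params`,
/// `grads`, and `stats` bindings are assumed, not prescribed by this module):
///
/// ```ignore
/// let mut tuner: SelfTuningOptimizer<f64, Ix1> =
///     SelfTuningOptimizer::new(SelfTuningConfig::default())?;
///
/// for _ in 0..num_steps {
///     // `stats` is a PerformanceStats built from the current step.
///     tuner.step(&mut params, &grads, stats.clone())?;
/// }
/// println!("{:?}", tuner.get_optimizer_info());
/// ```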
pub struct SelfTuningOptimizer<A: Float, D: Dimension> {
    /// Tuner configuration.
    config: SelfTuningConfig,
    /// Currently active optimizer.
    current_optimizer: Box<dyn OptimizerTrait<A, D>>,
    /// Candidate optimizers the tuner can switch between.
    optimizer_candidates: Vec<OptimizerCandidate<A, D>>,
    /// Sliding window of recent performance statistics.
    performance_history: VecDeque<PerformanceStats>,
    /// State of the hyperparameter search.
    search_state: HyperparameterSearchState,
    /// Optional learning-rate scheduler.
    lr_scheduler: Option<Box<dyn LearningRateScheduler<A>>>,
    /// Strategy used to pick the next optimizer.
    selection_strategy: OptimizerSelectionStrategy,
    /// Total number of steps taken.
    step_count: usize,
    /// Optimizer switches performed in the current epoch.
    switches_this_epoch: usize,
    /// Best value of the target metric seen so far.
    best_performance: Option<f64>,
    /// Time of the last adaptation.
    last_adaptation_time: Instant,
    /// Bookkeeping for bandit-based optimizer selection.
    bandit_state: BanditState,
}

/// A candidate optimizer together with its selection statistics.
struct OptimizerCandidate<A: Float, D: Dimension> {
    /// Human-readable name.
    name: String,
    /// Factory that builds a fresh instance of this optimizer.
    factory: Box<dyn Fn() -> Box<dyn OptimizerTrait<A, D>>>,
    /// Target-metric values observed while this candidate was active.
    performance_history: Vec<f64>,
    /// How often this candidate has been selected.
    usage_count: usize,
    /// Running average of its performance.
    average_performance: f64,
    /// Confidence interval around the average.
    confidence_interval: (f64, f64),
}

/// State of the hyperparameter search.
#[derive(Debug)]
struct HyperparameterSearchState {
    /// Current learning rate.
    learning_rate: f64,
    /// Allowed learning-rate range.
    lr_bounds: (f64, f64),
    /// Current batch size.
    batch_size: usize,
    /// Allowed batch-size range.
    batch_size_bounds: (usize, usize),
    /// Number of search iterations performed.
    search_iterations: usize,
    /// Best hyperparameters found so far.
    best_hyperparameters: HashMap<String, f64>,
    /// Search algorithm in use.
    search_algorithm: SearchAlgorithm,
}

/// Hyperparameter search algorithms.
#[derive(Debug)]
enum SearchAlgorithm {
    /// Random search.
    Random { seed: u64 },
    /// Bayesian optimization with a Gaussian-process surrogate.
    Bayesian { gp_state: GaussianProcessState },
    /// Grid search over a fixed lattice.
    Grid {
        position: Vec<usize>,
        dimensions: Vec<usize>,
    },
    /// Successive halving over a pool of configurations.
    SuccessiveHalving {
        bracket: usize,
        configurations: Vec<HashMap<String, f64>>,
    },
}

/// State of the Gaussian-process surrogate used by Bayesian search.
#[derive(Debug)]
struct GaussianProcessState {
    /// Hyperparameter points evaluated so far.
    observed_points: Vec<Vec<f64>>,
    /// Objective values at those points.
    observed_values: Vec<f64>,
    /// Kernel hyperparameters.
    kernel_params: Vec<f64>,
    /// Acquisition function used to propose the next point.
    acquisition_function: AcquisitionFunction,
}

/// Acquisition functions for Bayesian optimization.
#[derive(Debug, Clone, Copy)]
enum AcquisitionFunction {
    ExpectedImprovement,
    ProbabilityOfImprovement,
    UpperConfidenceBound,
}

/// Strategies for choosing the next optimizer.
#[derive(Debug, Clone)]
enum OptimizerSelectionStrategy {
    /// Treat candidates as bandit arms.
    MultiArmedBandit { algorithm: BanditAlgorithm },
    /// Pick the candidate with the best average performance.
    PerformanceBased { min_difference: f64 },
    /// Cycle through candidates in order.
    RoundRobin { current_index: usize },
    /// Map problem features to optimizer preferences.
    MetaLearning {
        problem_features: Vec<f64>,
        optimizer_mappings: HashMap<String, f64>,
    },
}

/// Multi-armed bandit algorithms.
#[derive(Debug, Clone, Copy)]
enum BanditAlgorithm {
    EpsilonGreedy,
    UCB1,
    ThompsonSampling,
    LinUCB,
}

/// Bookkeeping for bandit-based optimizer selection.
#[derive(Debug)]
struct BanditState {
    /// Estimated mean reward per arm.
    reward_estimates: Vec<f64>,
    /// Confidence bound (width) per arm.
    confidence_bounds: Vec<f64>,
    /// Number of times each arm was selected.
    selection_counts: Vec<usize>,
    /// Total selections across all arms.
    total_selections: usize,
    /// Exploration parameter (the `c` in UCB1).
    exploration_param: f64,
}

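/// Object-safe interface that candidate optimizers implement so the tuner can
/// drive and hot-swap them behind a `Box<dyn OptimizerTrait>`.
///
/// The private wrapper types at the bottom of this file (e.g.
/// `AdamOptimizerWrapper`) show the implementation pattern. A sketch of
/// trait-object usage (`my_optimizer` is a hypothetical implementor):
///
/// ```ignore
/// let mut opt: Box<dyn OptimizerTrait<f64, Ix1>> = Box::new(my_optimizer);
/// opt.set_learning_rate(0.01);
/// opt.step(&mut params, &grads)?;
/// ```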
pub trait OptimizerTrait<A: Float + ScalarOperand + Debug, D: Dimension>: Send + Sync {
    /// Human-readable optimizer name.
    fn name(&self) -> &str;

    /// Applies one update step to `params` given `grads`.
    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()>;

    /// Current learning rate.
    fn learning_rate(&self) -> A;

    /// Sets the learning rate.
    fn set_learning_rate(&mut self, lr: A);

    /// Serializes internal state (momentum buffers, etc.).
    fn get_state(&self) -> HashMap<String, Vec<u8>>;

    /// Restores internal state; may fail if the state is incompatible.
    fn set_state(&mut self, state: HashMap<String, Vec<u8>>) -> Result<()>;

    /// Clones this optimizer behind a trait object.
    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>>;
}

impl<
        A: Float + ScalarOperand + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
        D: Dimension + 'static,
    > SelfTuningOptimizer<A, D>
{
    /// Creates a new self-tuning optimizer with three built-in candidates
    /// (Adam, SGD, AdamW) and UCB1 bandit selection.
    pub fn new(config: SelfTuningConfig) -> Result<Self> {
        let mut optimizer_candidates = Vec::new();

        optimizer_candidates.push(OptimizerCandidate {
            name: "Adam".to_string(),
            factory: Box::new(|| Box::new(AdamOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.0))),
            performance_history: Vec::new(),
            usage_count: 0,
            average_performance: 0.0,
            confidence_interval: (0.0, 0.0),
        });

        optimizer_candidates.push(OptimizerCandidate {
            name: "SGD".to_string(),
            factory: Box::new(|| Box::new(SGDOptimizerWrapper::new(0.01, 0.9, 0.0, false))),
            performance_history: Vec::new(),
            usage_count: 0,
            average_performance: 0.0,
            confidence_interval: (0.0, 0.0),
        });

        optimizer_candidates.push(OptimizerCandidate {
            name: "AdamW".to_string(),
            factory: Box::new(|| {
                Box::new(AdamWOptimizerWrapper::new(0.001, 0.9, 0.999, 1e-8, 0.01))
            }),
            performance_history: Vec::new(),
            usage_count: 0,
            average_performance: 0.0,
            confidence_interval: (0.0, 0.0),
        });

        // Start with the first candidate (Adam).
        let current_optimizer = (optimizer_candidates[0].factory)();

        let search_state = HyperparameterSearchState {
            learning_rate: 0.001,
            lr_bounds: (1e-6, 1.0),
            batch_size: 32,
            batch_size_bounds: (8, 512),
            search_iterations: 0,
            best_hyperparameters: HashMap::new(),
            search_algorithm: SearchAlgorithm::Random { seed: 42 },
        };

        let selection_strategy = OptimizerSelectionStrategy::MultiArmedBandit {
            algorithm: BanditAlgorithm::UCB1,
        };

        let bandit_state = BanditState {
            reward_estimates: vec![0.0; optimizer_candidates.len()],
            confidence_bounds: vec![1.0; optimizer_candidates.len()],
            selection_counts: vec![0; optimizer_candidates.len()],
            total_selections: 0,
            exploration_param: 2.0,
        };

        Ok(Self {
            config,
            current_optimizer,
            optimizer_candidates,
            performance_history: VecDeque::new(),
            search_state,
            lr_scheduler: None,
            selection_strategy,
            step_count: 0,
            switches_this_epoch: 0,
            best_performance: None,
            last_adaptation_time: Instant::now(),
            bandit_state,
        })
    }

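    /// Registers an additional optimizer candidate that the tuner may switch
    /// to, growing the bandit bookkeeping accordingly.
    ///
    /// A sketch of registering a custom candidate (the name and
    /// hyperparameter values are illustrative):
    ///
    /// ```ignore
    /// tuner.add_optimizer_candidate("SGD-high-lr".to_string(), || {
    ///     Box::new(SGDOptimizerWrapper::new(0.1, 0.9, 0.0, false))
    /// });
    /// ```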
    pub fn add_optimizer_candidate<F>(&mut self, name: String, factory: F)
    where
        F: Fn() -> Box<dyn OptimizerTrait<A, D>> + 'static,
    {
        self.optimizer_candidates.push(OptimizerCandidate {
            name,
            factory: Box::new(factory),
            performance_history: Vec::new(),
            usage_count: 0,
            average_performance: 0.0,
            confidence_interval: (0.0, 0.0),
        });

        // Keep the bandit bookkeeping in sync with the candidate list.
        self.bandit_state.reward_estimates.push(0.0);
        self.bandit_state.confidence_bounds.push(1.0);
        self.bandit_state.selection_counts.push(0);
    }

    /// Performs one optimization step and feeds `stats` into the adaptation
    /// logic.
    pub fn step(
        &mut self,
        params: &mut [Array<A, D>],
        grads: &[Array<A, D>],
        stats: PerformanceStats,
    ) -> Result<()> {
        self.step_count += 1;

        // Maintain the sliding window of recent statistics.
        self.performance_history.push_back(stats.clone());
        if self.performance_history.len() > self.config.evaluation_window {
            self.performance_history.pop_front();
        }

        self.current_optimizer.step(params, grads)?;

        // Adaptation only starts after the warmup period.
        if self.step_count > self.config.warmup_steps {
            self.maybe_adapt_optimizer(&stats)?;
            self.maybe_adapt_learning_rate(&stats)?;
            self.maybe_adapt_hyperparameters(&stats)?;
        }

        // Track the best value of the target metric seen so far.
        let current_performance = self.extract_performance_metric(&stats);
        if let Some(performance) = current_performance {
            if self.best_performance.is_none()
                || self.is_better_performance(performance, self.best_performance.unwrap())
            {
                self.best_performance = Some(performance);
            }
        }

        Ok(())
    }

    fn maybe_adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
        if !self.config.auto_optimizer_selection {
            return Ok(());
        }

        if self.switches_this_epoch >= self.config.max_switches_per_epoch {
            return Ok(());
        }

        let should_adapt = self.should_adapt_optimizer(stats);

        if should_adapt {
            self.adapt_optimizer(stats)?;
            self.switches_this_epoch += 1;
        }

        Ok(())
    }

    /// Detects stagnation by comparing the most recent quarter of the
    /// evaluation window against the quarter before it.
    fn should_adapt_optimizer(&self, _stats: &PerformanceStats) -> bool {
        if self.performance_history.len() < self.config.evaluation_window / 2 {
            return false;
        }

        let recent_performance: Vec<f64> = self
            .performance_history
            .iter()
            .rev()
            .take(self.config.evaluation_window / 4)
            .filter_map(|s| self.extract_performance_metric(s))
            .collect();

        let older_performance: Vec<f64> = self
            .performance_history
            .iter()
            .rev()
            .skip(self.config.evaluation_window / 4)
            .take(self.config.evaluation_window / 4)
            .filter_map(|s| self.extract_performance_metric(s))
            .collect();

        if recent_performance.is_empty() || older_performance.is_empty() {
            return false;
        }

        let recent_avg = recent_performance.iter().sum::<f64>() / recent_performance.len() as f64;
        let older_avg = older_performance.iter().sum::<f64>() / older_performance.len() as f64;

        // Adapt when progress has stalled (change below the threshold) or the
        // metric is moving in the wrong direction.
        match self.config.target_metric {
            TargetMetric::Loss => {
                (recent_avg - older_avg).abs() < self.config.improvement_threshold
                    || recent_avg > older_avg
            }
            TargetMetric::Accuracy | TargetMetric::Throughput => {
                (recent_avg - older_avg).abs() < self.config.improvement_threshold
                    || recent_avg < older_avg
            }
            _ => false,
        }
    }

    fn adapt_optimizer(&mut self, stats: &PerformanceStats) -> Result<()> {
        let new_optimizer_idx = match &self.selection_strategy {
            OptimizerSelectionStrategy::MultiArmedBandit { algorithm } => {
                self.select_optimizer_bandit(*algorithm)
            }
            OptimizerSelectionStrategy::PerformanceBased { .. } => {
                self.select_optimizer_performance_based()
            }
            OptimizerSelectionStrategy::RoundRobin { current_index } => {
                (*current_index + 1) % self.optimizer_candidates.len()
            }
            OptimizerSelectionStrategy::MetaLearning { .. } => {
                self.select_optimizer_meta_learning(stats)
            }
        };

        if new_optimizer_idx < self.optimizer_candidates.len() {
            // Carry the learning rate (and, where possible, internal state)
            // over to the new optimizer.
            let current_lr = self.current_optimizer.learning_rate();
            let current_state = self.current_optimizer.get_state();

            self.current_optimizer = (self.optimizer_candidates[new_optimizer_idx].factory)();
            self.current_optimizer.set_learning_rate(current_lr);

            // State transfer between different optimizer types may fail; in
            // that case the new optimizer simply starts from fresh state.
            let _ = self.current_optimizer.set_state(current_state);

            self.optimizer_candidates[new_optimizer_idx].usage_count += 1;
        }

        Ok(())
    }

    // Takes `&self` so it can be called while `selection_strategy` is
    // borrowed in `adapt_optimizer`; none of the selection methods mutate.
    fn select_optimizer_bandit(&self, algorithm: BanditAlgorithm) -> usize {
        match algorithm {
            BanditAlgorithm::UCB1 => self.select_ucb1(),
            BanditAlgorithm::EpsilonGreedy => self.select_epsilon_greedy(),
            BanditAlgorithm::ThompsonSampling => self.select_thompson_sampling(),
            BanditAlgorithm::LinUCB => self.select_linucb(),
        }
    }

    /// UCB1: pick the arm maximizing `mean_i + sqrt(c * ln(N) / n_i)`, where
    /// `N` is the total number of selections and `n_i` the arm's own count.
    fn select_ucb1(&self) -> usize {
        if self.bandit_state.total_selections == 0 {
            return 0;
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best_idx = 0;

        for i in 0..self.optimizer_candidates.len() {
            let ucb_score = if self.bandit_state.selection_counts[i] == 0 {
                // Unexplored arms get priority.
                f64::INFINITY
            } else {
                let mean_reward = self.bandit_state.reward_estimates[i];
                let confidence = (self.bandit_state.exploration_param
                    * (self.bandit_state.total_selections as f64).ln()
                    / self.bandit_state.selection_counts[i] as f64)
                    .sqrt();
                mean_reward + confidence
            };

            if ucb_score > best_score {
                best_score = ucb_score;
                best_idx = i;
            }
        }

        best_idx
    }

    /// Epsilon-greedy: explore a random arm with probability
    /// `exploration_rate`, otherwise exploit the best reward estimate.
    fn select_epsilon_greedy(&self) -> usize {
        let mut rng = thread_rng();

        // Both sides are `f64`; no round-trip through `A` is needed.
        if rng.random::<f64>() < self.config.exploration_rate {
            rng.gen_range(0..self.optimizer_candidates.len())
        } else {
            self.bandit_state
                .reward_estimates
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0)
        }
    }

    /// Thompson-style sampling: draw one sample per arm from a uniform
    /// approximation of its posterior and pick the best draw.
    fn select_thompson_sampling(&self) -> usize {
        let mut rng = thread_rng();

        let mut best_sample = f64::NEG_INFINITY;
        let mut best_idx = 0;

        for i in 0..self.optimizer_candidates.len() {
            let mean = self.bandit_state.reward_estimates[i];
            let std = self.bandit_state.confidence_bounds[i];
            // Guard against an empty range when the confidence width is zero.
            let sample = if std > 0.0 {
                rng.gen_range(mean - std..mean + std)
            } else {
                mean
            };

            if sample > best_sample {
                best_sample = sample;
                best_idx = i;
            }
        }

        best_idx
    }

    fn select_linucb(&self) -> usize {
        // Contextual LinUCB would require per-arm feature vectors; fall back
        // to plain UCB1 until those are available.
        self.select_ucb1()
    }

    fn select_optimizer_performance_based(&self) -> usize {
        self.optimizer_candidates
            .iter()
            .enumerate()
            .max_by(|a, b| {
                a.1.average_performance
                    .partial_cmp(&b.1.average_performance)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(idx, _)| idx)
            .unwrap_or(0)
    }

    fn select_optimizer_meta_learning(&self, _stats: &PerformanceStats) -> usize {
        // Meta-learning selection is not implemented yet; default to the
        // first candidate.
        0
    }

    /// Simple gradient-norm heuristic: shrink the learning rate when
    /// gradients are large, grow it when they are small.
    fn maybe_adapt_learning_rate(&mut self, stats: &PerformanceStats) -> Result<()> {
        if !self.config.auto_lr_adjustment {
            return Ok(());
        }

        let current_lr = self.current_optimizer.learning_rate().to_f64().unwrap();
        let gradient_norm = stats.gradient_norm;

        let new_lr = if gradient_norm > 10.0 {
            current_lr * 0.9
        } else if gradient_norm < 0.1 {
            current_lr * 1.1
        } else {
            current_lr
        };

        let clamped_lr = new_lr
            .max(self.search_state.lr_bounds.0)
            .min(self.search_state.lr_bounds.1);

        // Only apply changes larger than 1% to avoid thrashing.
        if (clamped_lr - current_lr).abs() > current_lr * 0.01 {
            self.current_optimizer
                .set_learning_rate(A::from(clamped_lr).unwrap());
            self.search_state.learning_rate = clamped_lr;
        }

        Ok(())
    }

    fn maybe_adapt_hyperparameters(&mut self, _stats: &PerformanceStats) -> Result<()> {
        // Hook for batch-size and other hyperparameter tuning; not yet
        // implemented.
        Ok(())
    }

    /// Pulls the configured target metric out of `stats`.
    fn extract_performance_metric(&self, stats: &PerformanceStats) -> Option<f64> {
        match self.config.target_metric {
            TargetMetric::Loss => Some(stats.loss),
            TargetMetric::Accuracy => stats.accuracy,
            TargetMetric::Throughput => Some(stats.throughput),
            TargetMetric::ConvergenceTime => Some(stats.step_time.as_secs_f64()),
            TargetMetric::Custom => stats.custom_metrics.values().next().copied(),
        }
    }

    fn is_better_performance(&self, new_perf: f64, old_perf: f64) -> bool {
        match self.config.target_metric {
            TargetMetric::Loss | TargetMetric::ConvergenceTime => new_perf < old_perf,
            TargetMetric::Accuracy | TargetMetric::Throughput => new_perf > old_perf,
            // Custom metrics are assumed to be higher-is-better.
            TargetMetric::Custom => new_perf > old_perf,
        }
    }

    /// Resets the per-epoch switch counter; call at the start of each epoch.
    pub fn reset_epoch(&mut self) {
        self.switches_this_epoch = 0;
    }

    /// Returns information about the currently active optimizer.
    pub fn get_optimizer_info(&self) -> OptimizerInfo {
        OptimizerInfo {
            name: self.current_optimizer.name().to_string(),
            learning_rate: self.current_optimizer.learning_rate().to_f64().unwrap(),
            step_count: self.step_count,
            switches_this_epoch: self.switches_this_epoch,
            performance_window_size: self.performance_history.len(),
            best_performance: self.best_performance,
        }
    }

    /// Returns aggregate statistics about the tuning process.
    pub fn get_statistics(&self) -> SelfTuningStatistics {
        let optimizer_usage: HashMap<String, usize> = self
            .optimizer_candidates
            .iter()
            .map(|c| (c.name.clone(), c.usage_count))
            .collect();

        SelfTuningStatistics {
            total_steps: self.step_count,
            total_optimizer_switches: self
                .optimizer_candidates
                .iter()
                .map(|c| c.usage_count)
                .sum(),
            optimizer_usage,
            current_learning_rate: self.search_state.learning_rate,
            average_step_time: self
                .performance_history
                .iter()
                .map(|s| s.step_time.as_secs_f64())
                .sum::<f64>()
                / self.performance_history.len().max(1) as f64,
            exploration_rate: self.config.exploration_rate,
        }
    }
}

/// Snapshot of the currently active optimizer.
#[derive(Debug, Clone)]
pub struct OptimizerInfo {
    pub name: String,
    pub learning_rate: f64,
    pub step_count: usize,
    pub switches_this_epoch: usize,
    pub performance_window_size: usize,
    pub best_performance: Option<f64>,
}

/// Aggregate statistics about the self-tuning process.
#[derive(Debug, Clone)]
pub struct SelfTuningStatistics {
    pub total_steps: usize,
    pub total_optimizer_switches: usize,
    pub optimizer_usage: HashMap<String, usize>,
    pub current_learning_rate: f64,
    pub average_step_time: f64,
    pub exploration_rate: f64,
}

/// Thin adapter exposing the crate's `Adam` through `OptimizerTrait`.
struct AdamOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    inner: crate::optimizers::Adam<A>,
    _phantom: std::marker::PhantomData<D>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    AdamOptimizerWrapper<A, D>
{
    fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
        Self {
            inner: crate::optimizers::Adam::new_with_config(
                A::from(lr).unwrap(),
                A::from(beta1).unwrap(),
                A::from(beta2).unwrap(),
                A::from(eps).unwrap(),
                A::from(weight_decay).unwrap(),
            ),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
    OptimizerTrait<A, D> for AdamOptimizerWrapper<A, D>
{
    fn name(&self) -> &str {
        "Adam"
    }

    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
        if params.len() != grads.len() {
            return Err(crate::error::OptimError::InvalidParameter(
                "Mismatched number of parameters and gradients".to_string(),
            ));
        }

        for (param, grad) in params.iter_mut().zip(grads.iter()) {
            let updated = self.inner.step(param, grad)?;
            *param = updated;
        }
        Ok(())
    }

    fn learning_rate(&self) -> A {
        self.inner.learning_rate()
    }

    fn set_learning_rate(&mut self, lr: A) {
        <crate::optimizers::Adam<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
            &mut self.inner,
            lr,
        );
    }

    fn get_state(&self) -> HashMap<String, Vec<u8>> {
        // State serialization is not implemented for this wrapper.
        HashMap::new()
    }

    fn set_state(&mut self, _state: HashMap<String, Vec<u8>>) -> Result<()> {
        // No state to restore; accept silently.
        Ok(())
    }

    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
        Box::new(AdamOptimizerWrapper {
            inner: self.inner.clone(),
            _phantom: std::marker::PhantomData,
        })
    }
}

/// Thin adapter exposing the crate's `SGD` through `OptimizerTrait`.
struct SGDOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    inner: crate::optimizers::SGD<A>,
    _phantom: std::marker::PhantomData<D>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    SGDOptimizerWrapper<A, D>
{
    // `nesterov` is accepted for signature compatibility, but the underlying
    // SGD configuration does not expose it here.
    fn new(lr: f64, momentum: f64, weight_decay: f64, _nesterov: bool) -> Self {
        Self {
            inner: crate::optimizers::SGD::new_with_config(
                A::from(lr).unwrap(),
                A::from(momentum).unwrap(),
                A::from(weight_decay).unwrap(),
            ),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
    OptimizerTrait<A, D> for SGDOptimizerWrapper<A, D>
{
    fn name(&self) -> &str {
        "SGD"
    }

    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
        if params.len() != grads.len() {
            return Err(crate::error::OptimError::InvalidParameter(
                "Mismatched number of parameters and gradients".to_string(),
            ));
        }

        for (param, grad) in params.iter_mut().zip(grads.iter()) {
            let updated = self.inner.step(param, grad)?;
            *param = updated;
        }
        Ok(())
    }

    fn learning_rate(&self) -> A {
        self.inner.learning_rate()
    }

    fn set_learning_rate(&mut self, lr: A) {
        <crate::optimizers::SGD<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
            &mut self.inner,
            lr,
        );
    }

    fn get_state(&self) -> HashMap<String, Vec<u8>> {
        // State serialization is not implemented for this wrapper.
        HashMap::new()
    }

    fn set_state(&mut self, _state: HashMap<String, Vec<u8>>) -> Result<()> {
        // No state to restore; accept silently.
        Ok(())
    }

    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
        Box::new(SGDOptimizerWrapper {
            inner: self.inner.clone(),
            _phantom: std::marker::PhantomData,
        })
    }
}

/// Thin adapter exposing the crate's `AdamW` through `OptimizerTrait`.
struct AdamWOptimizerWrapper<A: Float + ScalarOperand + Debug, D: Dimension> {
    inner: crate::optimizers::AdamW<A>,
    _phantom: std::marker::PhantomData<D>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    AdamWOptimizerWrapper<A, D>
{
    fn new(lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64) -> Self {
        Self {
            inner: crate::optimizers::AdamW::new_with_config(
                A::from(lr).unwrap(),
                A::from(beta1).unwrap(),
                A::from(beta2).unwrap(),
                A::from(eps).unwrap(),
                A::from(weight_decay).unwrap(),
            ),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync + 'static, D: Dimension + 'static>
    OptimizerTrait<A, D> for AdamWOptimizerWrapper<A, D>
{
    fn name(&self) -> &str {
        "AdamW"
    }

    fn step(&mut self, params: &mut [Array<A, D>], grads: &[Array<A, D>]) -> Result<()> {
        if params.len() != grads.len() {
            return Err(crate::error::OptimError::InvalidParameter(
                "Mismatched number of parameters and gradients".to_string(),
            ));
        }

        for (param, grad) in params.iter_mut().zip(grads.iter()) {
            let updated = self.inner.step(param, grad)?;
            *param = updated;
        }
        Ok(())
    }

    fn learning_rate(&self) -> A {
        self.inner.learning_rate()
    }

    fn set_learning_rate(&mut self, lr: A) {
        <crate::optimizers::AdamW<A> as crate::optimizers::Optimizer<A, D>>::set_learning_rate(
            &mut self.inner,
            lr,
        );
    }

    fn get_state(&self) -> HashMap<String, Vec<u8>> {
        // State serialization is not implemented for this wrapper.
        HashMap::new()
    }

    fn set_state(&mut self, _state: HashMap<String, Vec<u8>>) -> Result<()> {
        // No state to restore; accept silently.
        Ok(())
    }

    fn clone_optimizer(&self) -> Box<dyn OptimizerTrait<A, D>> {
        Box::new(AdamWOptimizerWrapper {
            inner: self.inner.clone(),
            _phantom: std::marker::PhantomData,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::time::Duration;

    #[test]
    fn test_self_tuning_config_default() {
        let config = SelfTuningConfig::default();
        assert_eq!(config.evaluation_window, 100);
        assert!(config.auto_lr_adjustment);
        assert!(config.auto_optimizer_selection);
    }

    #[test]
    fn test_self_tuning_optimizer_creation() {
        let config = SelfTuningConfig::default();
        let optimizer: Result<SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1>> =
            SelfTuningOptimizer::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_performance_stats() {
        let stats = PerformanceStats {
            loss: 0.5,
            accuracy: Some(0.9),
            gradient_norm: 1.2,
            throughput: 100.0,
            memory_usage: 1024.0,
            step_time: Duration::from_millis(50),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        assert_eq!(stats.loss, 0.5);
        assert_eq!(stats.accuracy, Some(0.9));
    }

    #[test]
    fn test_optimizer_step() {
        let config = SelfTuningConfig::default();
        let mut optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        let mut params = vec![Array1::zeros(10)];
        let grads = vec![Array1::ones(10)];

        let stats = PerformanceStats {
            loss: 1.0,
            accuracy: None,
            gradient_norm: 1.0,
            throughput: 50.0,
            memory_usage: 512.0,
            step_time: Duration::from_millis(10),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        let result = optimizer.step(&mut params, &grads, stats);
        assert!(result.is_ok());

        let info = optimizer.get_optimizer_info();
        assert_eq!(info.name, "Adam");
        assert_eq!(info.step_count, 1);
    }

    #[test]
    fn test_bandit_selection() {
        let config = SelfTuningConfig::default();
        let optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        let selection = optimizer.select_ucb1();
        assert!(selection < optimizer.optimizer_candidates.len());
    }

    #[test]
    fn test_performance_metric_extraction() {
        let config = SelfTuningConfig {
            target_metric: TargetMetric::Loss,
            ..Default::default()
        };
        let optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        let stats = PerformanceStats {
            loss: 0.8,
            accuracy: Some(0.85),
            gradient_norm: 1.1,
            throughput: 75.0,
            memory_usage: 800.0,
            step_time: Duration::from_millis(20),
            learning_rate: 0.001,
            optimizer_type: "Adam".to_string(),
            custom_metrics: HashMap::new(),
        };

        let metric = optimizer.extract_performance_metric(&stats);
        assert_eq!(metric, Some(0.8));
    }

    #[test]
    fn test_statistics() {
        let config = SelfTuningConfig::default();
        let optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        let stats = optimizer.get_statistics();
        assert_eq!(stats.total_steps, 0);
        assert!(stats.optimizer_usage.contains_key("Adam"));
    }
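
    // The two tests below are additions sketching expected behaviour of the
    // comparison and registration helpers; the candidate name and
    // hyperparameter values are assumptions, not part of the original suite.
    #[test]
    fn test_is_better_performance_direction() {
        // Default config targets Loss, for which lower is better.
        let config = SelfTuningConfig::default();
        let optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        assert!(optimizer.is_better_performance(0.4, 0.5));
        assert!(!optimizer.is_better_performance(0.6, 0.5));
    }

    #[test]
    fn test_add_optimizer_candidate_grows_bandit_state() {
        let config = SelfTuningConfig::default();
        let mut optimizer: SelfTuningOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SelfTuningOptimizer::new(config).unwrap();

        let before = optimizer.optimizer_candidates.len();
        optimizer.add_optimizer_candidate("SGD-variant".to_string(), || {
            Box::new(SGDOptimizerWrapper::new(0.1, 0.9, 0.0, false))
        });

        // The bandit bookkeeping must stay in sync with the candidate list.
        assert_eq!(optimizer.optimizer_candidates.len(), before + 1);
        assert_eq!(optimizer.bandit_state.reward_estimates.len(), before + 1);
        assert_eq!(optimizer.bandit_state.selection_counts.len(), before + 1);
    }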
}