use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

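/// Optimizer families that the selector can choose from.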
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizerType {
    SGD,
    SGDMomentum,
    Adam,
    AdamW,
    RMSprop,
    AdaGrad,
    RAdam,
    Lookahead,
    LAMB,
    LARS,
    LBFGS,
    SAM,
}

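/// Characteristics of an optimization problem, used as input to optimizer selection.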
#[derive(Debug, Clone)]
pub struct ProblemCharacteristics {
    pub dataset_size: usize,
    pub input_dim: usize,
    pub output_dim: usize,
    pub problem_type: ProblemType,
    pub gradient_sparsity: f64,
    pub gradient_noise: f64,
    pub memory_budget: usize,
    pub time_budget: f64,
    pub batch_size: usize,
    pub lr_sensitivity: f64,
    pub regularization_strength: f64,
    pub architecture_type: Option<String>,
}

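/// High-level category of the learning problem.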
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ProblemType {
    Classification,
    Regression,
    Unsupervised,
    ReinforcementLearning,
    TimeSeries,
    ComputerVision,
    NaturalLanguage,
    Recommendation,
}

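/// Performance observed after running an optimizer on a problem.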
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    pub final_loss: f64,
    pub convergence_steps: usize,
    pub training_time: f64,
    pub memory_usage: usize,
    pub validation_performance: f64,
    pub stability: f64,
    pub generalization_gap: f64,
}

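/// Strategy used to pick an optimizer for the current problem.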
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
    RuleBased,
    LearningBased,
    Ensemble {
        num_candidates: usize,
        evaluation_steps: usize,
    },
    Bandit {
        epsilon: f64,
        confidence: f64,
    },
    MetaLearning {
        feature_dim: usize,
        k_nearest: usize,
    },
}

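/// Selects an optimizer for a given problem and adapts its choices from observed performance.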
#[derive(Debug)]
pub struct AdaptiveOptimizerSelector<A: Float> {
    strategy: SelectionStrategy,
    performance_history: HashMap<OptimizerType, Vec<PerformanceMetrics>>,
    problem_optimizer_map: Vec<(ProblemCharacteristics, OptimizerType, PerformanceMetrics)>,
    current_problem: Option<ProblemCharacteristics>,
    arm_counts: HashMap<OptimizerType, usize>,
    arm_rewards: HashMap<OptimizerType, f64>,
    selection_network: Option<SelectionNetwork<A>>,
    available_optimizers: Vec<OptimizerType>,
    current_performance: VecDeque<f64>,
    last_confidence: f64,
}

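/// Small two-layer network that maps problem features to a probability distribution over optimizers.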
#[derive(Debug)]
pub struct SelectionNetwork<A: Float> {
    input_weights: Array2<A>,
    output_weights: Array2<A>,
    input_bias: Array1<A>,
    output_bias: Array1<A>,
    #[allow(dead_code)]
    hidden_size: usize,
}

impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
    SelectionNetwork<A>
{
    pub fn new(input_size: usize, hidden_size: usize, num_optimizers: usize) -> Self {
        let mut rng = thread_rng();

        let input_weights = Array2::from_shape_fn((hidden_size, input_size), |_| {
            A::from(rng.random::<f64>()).unwrap() * A::from(0.1).unwrap() - A::from(0.05).unwrap()
        });

        let output_weights = Array2::from_shape_fn((num_optimizers, hidden_size), |_| {
            A::from(rng.random::<f64>()).unwrap() * A::from(0.1).unwrap() - A::from(0.05).unwrap()
        });

        let input_bias = Array1::zeros(hidden_size);
        let output_bias = Array1::zeros(num_optimizers);

        Self {
            input_weights,
            output_weights,
            input_bias,
            output_bias,
            hidden_size,
        }
    }

    pub fn forward(&self, features: &Array1<A>) -> Result<Array1<A>> {
        // Hidden layer with ReLU activation.
        let hidden = self.input_weights.dot(features) + self.input_bias.clone();
        let hidden_activated = hidden.mapv(|x| if x > A::zero() { x } else { A::zero() });

        let output = self.output_weights.dot(&hidden_activated) + &self.output_bias;

        // Numerically stable softmax over the optimizer logits.
        let max_val = output.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
        let exp_output = output.mapv(|x| A::exp(x - max_val));
        let sum_exp = exp_output.sum();
        let probabilities = exp_output.mapv(|x| x / sum_exp);

        Ok(probabilities)
    }

    pub fn train(
        &mut self,
        features: &[Array1<A>],
        optimizer_labels: &[usize],
        learning_rate: A,
        epochs: usize,
    ) -> Result<()> {
        for _ in 0..epochs {
            for (feature, &label) in features.iter().zip(optimizer_labels.iter()) {
                let probabilities = self.forward(feature)?;

                // Cross-entropy loss on the labelled optimizer (kept only for reference).
                let target_prob = probabilities[label];
                let _loss = -A::ln(target_prob);

                // Gradient of softmax + cross-entropy: p - one_hot(label).
                let mut output_grad = probabilities;
                output_grad[label] = output_grad[label] - A::one();

                // Recompute the hidden activations for the update step.
                let hidden = self.input_weights.dot(feature) + self.input_bias.clone();
                let hidden_activated = hidden.mapv(|x| if x > A::zero() { x } else { A::zero() });

                // Simplified training: only the output layer is updated, so the input
                // layer acts as a fixed random projection of the problem features.
                for i in 0..self.output_weights.nrows() {
                    for j in 0..self.output_weights.ncols() {
                        self.output_weights[[i, j]] = self.output_weights[[i, j]]
                            - learning_rate * output_grad[i] * hidden_activated[j];
                    }
                }

                for i in 0..self.output_bias.len() {
                    self.output_bias[i] = self.output_bias[i] - learning_rate * output_grad[i];
                }
            }
        }
        Ok(())
    }
}

impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
    AdaptiveOptimizerSelector<A>
{
    pub fn new(strategy: SelectionStrategy) -> Self {
        let available_optimizers = vec![
            OptimizerType::SGD,
            OptimizerType::SGDMomentum,
            OptimizerType::Adam,
            OptimizerType::AdamW,
            OptimizerType::RMSprop,
            OptimizerType::AdaGrad,
            OptimizerType::RAdam,
            OptimizerType::LAMB,
        ];

        let mut arm_counts = HashMap::new();
        let mut arm_rewards = HashMap::new();
        for &optimizer in &available_optimizers {
            arm_counts.insert(optimizer, 0);
            arm_rewards.insert(optimizer, 0.0);
        }

        Self {
            strategy,
            performance_history: HashMap::new(),
            problem_optimizer_map: Vec::new(),
            current_problem: None,
            arm_counts,
            arm_rewards,
            selection_network: None,
            available_optimizers,
            current_performance: VecDeque::new(),
            last_confidence: 0.0,
        }
    }

    pub fn set_problem(&mut self, problem: ProblemCharacteristics) {
        self.current_problem = Some(problem);
    }

    pub fn select_optimizer(&mut self) -> Result<OptimizerType> {
        let problem = self.current_problem.clone().ok_or_else(|| {
            OptimError::InvalidConfig("No problem characteristics set".to_string())
        })?;

        match &self.strategy {
            SelectionStrategy::RuleBased => self.rule_based_selection(&problem),
            SelectionStrategy::LearningBased => self.learning_based_selection(&problem),
            SelectionStrategy::Ensemble {
                num_candidates,
                evaluation_steps,
            } => self.ensemble_selection(&problem, *num_candidates, *evaluation_steps),
            SelectionStrategy::Bandit {
                epsilon,
                confidence,
            } => self.bandit_selection(&problem, *epsilon, *confidence),
            SelectionStrategy::MetaLearning {
                feature_dim: _,
                k_nearest,
            } => self.meta_learning_selection(&problem, *k_nearest),
        }
    }

    fn rule_based_selection(&self, problem: &ProblemCharacteristics) -> Result<OptimizerType> {
        if problem.dataset_size > 100_000 {
            match problem.problem_type {
                ProblemType::ComputerVision => return Ok(OptimizerType::AdamW),
                ProblemType::NaturalLanguage => return Ok(OptimizerType::AdamW),
                _ => return Ok(OptimizerType::Adam),
            }
        }

        if problem.dataset_size < 1000 {
            return Ok(OptimizerType::LBFGS);
        }

        if problem.gradient_sparsity > 0.5 {
            return Ok(OptimizerType::AdaGrad);
        }

        if problem.batch_size > 256 {
            return Ok(OptimizerType::LAMB);
        }

        if problem.memory_budget < 1_000_000 {
            return Ok(OptimizerType::SGD);
        }

        if problem.gradient_noise > 0.3 {
            return Ok(OptimizerType::RMSprop);
        }

        Ok(OptimizerType::Adam)
    }

    fn learning_based_selection(
        &mut self,
        problem: &ProblemCharacteristics,
    ) -> Result<OptimizerType> {
        if self.problem_optimizer_map.is_empty() {
            return self.rule_based_selection(problem);
        }

        let mut best_similarity = -1.0;
        let mut best_optimizer = OptimizerType::Adam;

        for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
            let similarity = self.compute_problem_similarity(problem, hist_problem);

            let weighted_similarity = similarity * metrics.validation_performance;

            if weighted_similarity > best_similarity {
                best_similarity = weighted_similarity;
                best_optimizer = *optimizer;
            }
        }

        self.last_confidence = best_similarity;
        Ok(best_optimizer)
    }

    fn ensemble_selection(
        &self,
        _problem: &ProblemCharacteristics,
        num_candidates: usize,
        _evaluation_steps: usize,
    ) -> Result<OptimizerType> {
        // Placeholder ensemble: trial evaluation of the candidates is not implemented,
        // so the first candidate from the available list is returned.
        let mut candidates = self.available_optimizers.clone();
        candidates.truncate(num_candidates.min(candidates.len()).max(1));

        Ok(candidates[0])
    }

    fn bandit_selection(
        &self,
        _problem: &ProblemCharacteristics,
        epsilon: f64,
        confidence: f64,
    ) -> Result<OptimizerType> {
        let mut rng = thread_rng();

        // Epsilon-greedy exploration: occasionally pick a random optimizer.
        if rng.random::<f64>() < epsilon {
            let idx = rng.gen_range(0..self.available_optimizers.len());
            return Ok(self.available_optimizers[idx]);
        }

        // Otherwise pick the arm with the highest UCB1 score.
        let mut best_ucb = f64::NEG_INFINITY;
        let mut best_optimizer = OptimizerType::Adam;
        let total_counts: usize = self.arm_counts.values().sum();

        for &optimizer in &self.available_optimizers {
            let count = self.arm_counts[&optimizer] as f64;
            let reward = if count > 0.0 {
                self.arm_rewards[&optimizer] / count
            } else {
                0.0
            };

            let ucb = if count > 0.0 {
                reward + confidence * ((total_counts as f64).ln() / count).sqrt()
            } else {
                // Unplayed arms are explored first.
                f64::INFINITY
            };

            if ucb > best_ucb {
                best_ucb = ucb;
                best_optimizer = optimizer;
            }
        }

        Ok(best_optimizer)
    }

    fn meta_learning_selection(
        &mut self,
        problem: &ProblemCharacteristics,
        k_nearest: usize,
    ) -> Result<OptimizerType> {
        let features = self.extract_problem_features(problem);

        // If a trained selection network is available, use its highest-probability optimizer.
        if let Some(network) = &self.selection_network {
            let probabilities = network.forward(&features)?;

            let mut best_prob = A::neg_infinity();
            let mut best_idx = 0;

            for (i, &prob) in probabilities.iter().enumerate() {
                if prob > best_prob {
                    best_prob = prob;
                    best_idx = i;
                }
            }

            if best_idx < self.available_optimizers.len() {
                return Ok(self.available_optimizers[best_idx]);
            }
        }

        // Otherwise fall back to a similarity-weighted k-nearest-neighbour vote
        // over previously seen problems.
        if self.problem_optimizer_map.len() >= k_nearest {
            let mut similarities = Vec::new();

            for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
                let similarity = self.compute_problem_similarity(problem, hist_problem);
                similarities.push((similarity, *optimizer, metrics.validation_performance));
            }

            similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());

            let mut votes: HashMap<OptimizerType, f64> = HashMap::new();
            for (similarity, optimizer, performance) in similarities.iter().take(k_nearest) {
                let weight = similarity * performance;
                *votes.entry(*optimizer).or_insert(0.0) += weight;
            }

            let best_optimizer = votes
                .iter()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(optimizer_, _)| *optimizer_)
                .unwrap_or(OptimizerType::Adam);

            return Ok(best_optimizer);
        }

        // With too little history, use the rule-based heuristics.
        self.rule_based_selection(problem)
    }

    pub fn update_performance(
        &mut self,
        optimizer: OptimizerType,
        metrics: PerformanceMetrics,
    ) -> Result<()> {
        self.performance_history
            .entry(optimizer)
            .or_default()
            .push(metrics.clone());

        *self.arm_counts.entry(optimizer).or_insert(0) += 1;
        *self.arm_rewards.entry(optimizer).or_insert(0.0) += metrics.validation_performance;

        if let Some(problem) = &self.current_problem {
            self.problem_optimizer_map
                .push((problem.clone(), optimizer, metrics.clone()));
        }

        self.current_performance
            .push_back(metrics.validation_performance);
        if self.current_performance.len() > 100 {
            self.current_performance.pop_front();
        }

        Ok(())
    }

    pub fn train_selection_network(&mut self, learning_rate: A, epochs: usize) -> Result<()> {
        if self.problem_optimizer_map.is_empty() {
            return Ok(());
        }

        let mut features = Vec::new();
        let mut labels = Vec::new();

        for (problem, optimizer_, _metrics) in &self.problem_optimizer_map {
            // Only keep examples whose optimizer is in the candidate set, so that
            // features and labels stay aligned.
            if let Some(label) = self
                .available_optimizers
                .iter()
                .position(|&opt| opt == *optimizer_)
            {
                features.push(self.extract_problem_features(problem));
                labels.push(label);
            }
        }

        if features.is_empty() {
            return Ok(());
        }

        if self.selection_network.is_none() {
            let feature_dim = features[0].len();
            let num_optimizers = self.available_optimizers.len();
            self.selection_network = Some(SelectionNetwork::new(feature_dim, 32, num_optimizers));
        }

        if let Some(network) = &mut self.selection_network {
            network.train(&features, &labels, learning_rate, epochs)?;
        }

        Ok(())
    }

    fn compute_problem_similarity(
        &self,
        problem1: &ProblemCharacteristics,
        problem2: &ProblemCharacteristics,
    ) -> f64 {
        let mut similarity = 0.0;
        let mut weight_sum = 0.0;

        // Dataset size similarity on a log scale (weight 0.2).
        let size_sim = 1.0
            - ((problem1.dataset_size as f64).ln() - (problem2.dataset_size as f64).ln()).abs()
                / 10.0;
        similarity += size_sim.max(0.0) * 0.2;
        weight_sum += 0.2;

        // Matching problem type (weight 0.3).
        if problem1.problem_type == problem2.problem_type {
            similarity += 0.3;
        }
        weight_sum += 0.3;

        // Batch size similarity (weight 0.1).
        let batch_sim = 1.0
            - ((problem1.batch_size as f64 - problem2.batch_size as f64).abs() / 256.0).min(1.0);
        similarity += batch_sim * 0.1;
        weight_sum += 0.1;

        // Gradient sparsity and noise similarity (weight 0.2 each).
        let sparsity_sim = 1.0 - (problem1.gradient_sparsity - problem2.gradient_sparsity).abs();
        let noise_sim = 1.0 - (problem1.gradient_noise - problem2.gradient_noise).abs();
        similarity += (sparsity_sim + noise_sim) * 0.2;
        weight_sum += 0.4;

        similarity / weight_sum
    }

    fn extract_problem_features(&self, problem: &ProblemCharacteristics) -> Array1<A> {
        Array1::from_vec(vec![
            A::from((problem.dataset_size as f64).ln()).unwrap(),
            A::from((problem.input_dim as f64).ln()).unwrap(),
            A::from((problem.output_dim as f64).ln()).unwrap(),
            A::from(problem.problem_type as u8 as f64).unwrap(),
            A::from(problem.gradient_sparsity).unwrap(),
            A::from(problem.gradient_noise).unwrap(),
            A::from((problem.memory_budget as f64).ln()).unwrap(),
            A::from(problem.time_budget.ln()).unwrap(),
            A::from((problem.batch_size as f64).ln()).unwrap(),
            A::from(problem.lr_sensitivity).unwrap(),
            A::from(problem.regularization_strength).unwrap(),
        ])
    }

    pub fn get_optimizer_statistics(
        &self,
        optimizer: OptimizerType,
    ) -> Option<OptimizerStatistics> {
        if let Some(history) = self.performance_history.get(&optimizer) {
            if history.is_empty() {
                return None;
            }

            let performances: Vec<f64> =
                history.iter().map(|m| m.validation_performance).collect();
            let mean = performances.iter().sum::<f64>() / performances.len() as f64;
            let variance = performances.iter().map(|p| (p - mean).powi(2)).sum::<f64>()
                / performances.len() as f64;
            let std_dev = variance.sqrt();

            Some(OptimizerStatistics {
                optimizer,
                num_trials: history.len(),
                mean_performance: mean,
                std_performance: std_dev,
                best_performance: performances
                    .iter()
                    .copied()
                    .fold(f64::NEG_INFINITY, f64::max),
                worst_performance: performances.iter().copied().fold(f64::INFINITY, f64::min),
                success_rate: performances.iter().filter(|&&p| p > 0.7).count() as f64
                    / performances.len() as f64,
            })
        } else {
            None
        }
    }

    pub fn get_all_statistics(&self) -> Vec<OptimizerStatistics> {
        self.available_optimizers
            .iter()
            .filter_map(|&opt| self.get_optimizer_statistics(opt))
            .collect()
    }

    pub fn get_selection_confidence(&self) -> f64 {
        self.last_confidence
    }

    pub fn reset(&mut self) {
        self.performance_history.clear();
        self.problem_optimizer_map.clear();
        self.current_problem = None;
        for count in self.arm_counts.values_mut() {
            *count = 0;
        }
        for reward in self.arm_rewards.values_mut() {
            *reward = 0.0;
        }
        self.current_performance.clear();
        self.last_confidence = 0.0;
    }
}

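/// Summary statistics aggregated from the recorded performance history of one optimizer.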
#[derive(Debug, Clone)]
pub struct OptimizerStatistics {
    pub optimizer: OptimizerType,
    pub num_trials: usize,
    pub mean_performance: f64,
    pub std_performance: f64,
    pub best_performance: f64,
    pub worst_performance: f64,
    pub success_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_problem_characteristics() {
        let problem = ProblemCharacteristics {
            dataset_size: 10000,
            input_dim: 784,
            output_dim: 10,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 3600.0,
            batch_size: 64,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: Some("CNN".to_string()),
        };

        assert_eq!(problem.dataset_size, 10000);
        assert_eq!(problem.problem_type, ProblemType::Classification);
    }

    #[test]
    fn test_rule_based_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let large_problem = ProblemCharacteristics {
            dataset_size: 100_001,
            input_dim: 224,
            output_dim: 1000,
            problem_type: ProblemType::ComputerVision,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 10_000_000,
            time_budget: 7200.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: Some("ResNet".to_string()),
        };

        selector.set_problem(large_problem);
        let optimizer = selector.select_optimizer().unwrap();
        assert_eq!(optimizer, OptimizerType::AdamW);
    }

    #[test]
    fn test_selection_network() {
        let network = SelectionNetwork::<f64>::new(5, 10, 3);
        let features = Array1::from_vec(vec![1.0, 0.5, 2.0, 0.8, 1.5]);

        let probabilities = network.forward(&features).unwrap();
        assert_eq!(probabilities.len(), 3);

        let sum: f64 = probabilities.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        for &prob in probabilities.iter() {
            assert!(prob >= 0.0);
        }
    }

    #[test]
    fn test_bandit_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::Bandit {
            epsilon: 0.1,
            confidence: 2.0,
        });

        let problem = ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.0,
            gradient_noise: 0.1,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        };

        selector.set_problem(problem);

        let optimizer = selector.select_optimizer().unwrap();
        assert!(selector.available_optimizers.contains(&optimizer));
    }
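
    // Illustrative end-to-end sketch (not from the original source): exercises the
    // select -> observe -> update loop of the bandit strategy using only the public
    // API defined in this module; all metric values are made up.
    #[test]
    fn test_bandit_feedback_loop() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::Bandit {
            epsilon: 0.2,
            confidence: 2.0,
        });

        let problem = ProblemCharacteristics {
            dataset_size: 5000,
            input_dim: 32,
            output_dim: 4,
            problem_type: ProblemType::Regression,
            gradient_sparsity: 0.2,
            gradient_noise: 0.1,
            memory_budget: 2_000_000,
            time_budget: 1200.0,
            batch_size: 64,
            lr_sensitivity: 0.4,
            regularization_strength: 0.001,
            architecture_type: None,
        };
        selector.set_problem(problem);

        // Each selection is followed by a simulated performance report.
        for trial in 0..5 {
            let optimizer = selector.select_optimizer().unwrap();
            let metrics = PerformanceMetrics {
                final_loss: 0.5 / (trial as f64 + 1.0),
                convergence_steps: 200,
                training_time: 30.0,
                memory_usage: 400_000,
                validation_performance: 0.6 + 0.05 * trial as f64,
                stability: 0.05,
                generalization_gap: 0.1,
            };
            selector.update_performance(optimizer, metrics).unwrap();
        }

        // After feedback, statistics are available for at least one optimizer.
        assert!(!selector.get_all_statistics().is_empty());
    }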

    #[test]
    fn test_performance_update() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let metrics = PerformanceMetrics {
            final_loss: 0.1,
            convergence_steps: 100,
            training_time: 60.0,
            memory_usage: 500_000,
            validation_performance: 0.95,
            stability: 0.02,
            generalization_gap: 0.05,
        };

        selector
            .update_performance(OptimizerType::Adam, metrics)
            .unwrap();

        let stats = selector
            .get_optimizer_statistics(OptimizerType::Adam)
            .unwrap();
        assert_eq!(stats.num_trials, 1);
        assert_relative_eq!(stats.mean_performance, 0.95, epsilon = 1e-6);
    }

    #[test]
    fn test_problem_similarity() {
        let selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let problem1 = ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        };

        let problem2 = problem1.clone();
        let similarity = selector.compute_problem_similarity(&problem1, &problem2);
        assert_relative_eq!(similarity, 1.0, epsilon = 1e-6);

        let mut problem3 = problem1.clone();
        problem3.problem_type = ProblemType::Regression;
        let similarity = selector.compute_problem_similarity(&problem1, &problem3);
        assert!(similarity < 1.0);
    }
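
    // Illustrative sketch (not from the original source): records a few trials,
    // trains the selection network, and then selects with the meta-learning
    // strategy; all numeric values are made up.
    #[test]
    fn test_meta_learning_flow() {
        let mut selector =
            AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::MetaLearning {
                feature_dim: 11,
                k_nearest: 2,
            });

        let problem = ProblemCharacteristics {
            dataset_size: 2000,
            input_dim: 20,
            output_dim: 5,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 900.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        };
        selector.set_problem(problem);

        // Record a few observations so there is history to learn from.
        for _ in 0..3 {
            let metrics = PerformanceMetrics {
                final_loss: 0.2,
                convergence_steps: 150,
                training_time: 45.0,
                memory_usage: 300_000,
                validation_performance: 0.8,
                stability: 0.03,
                generalization_gap: 0.07,
            };
            selector
                .update_performance(OptimizerType::Adam, metrics)
                .unwrap();
        }

        selector.train_selection_network(0.01, 5).unwrap();

        let optimizer = selector.select_optimizer().unwrap();
        assert!(selector.available_optimizers.contains(&optimizer));
    }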
}