// optirs_core/adaptive_selection/mod.rs
1// Adaptive optimization algorithm selection
2//
3// This module provides automatic selection of the most appropriate optimization algorithm
4// based on problem characteristics, performance monitoring, and learned patterns.
5
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
8use scirs2_core::numeric::Float;
9use scirs2_core::random::{thread_rng, Rng};
10use std::collections::{HashMap, VecDeque};
11use std::fmt::Debug;
12
/// Types of optimization algorithms available for selection.
///
/// Note: not every variant is in the selector's default candidate pool
/// (see `AdaptiveOptimizerSelector::new`); e.g. `LBFGS` is only ever
/// produced by the rule-based fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizerType {
    /// Stochastic Gradient Descent
    SGD,
    /// SGD with momentum
    SGDMomentum,
    /// Adam optimizer
    Adam,
    /// AdamW (Adam with decoupled weight decay)
    AdamW,
    /// RMSprop optimizer
    RMSprop,
    /// AdaGrad optimizer
    AdaGrad,
    /// RAdam (Rectified Adam)
    RAdam,
    /// Lookahead wrapper
    Lookahead,
    /// LAMB (Layer-wise Adaptive Moments)
    LAMB,
    /// LARS (Layer-wise Adaptive Rate Scaling)
    LARS,
    /// L-BFGS (Limited-memory BFGS)
    LBFGS,
    /// SAM (Sharpness-Aware Minimization)
    SAM,
}
41
/// Problem characteristics for optimizer selection.
///
/// A feature description of the optimization task; consumed by
/// `AdaptiveOptimizerSelector` for rule-based, similarity-based, and
/// meta-learning selection (see `extract_problem_features`).
#[derive(Debug, Clone)]
pub struct ProblemCharacteristics {
    /// Dataset size (number of training samples)
    pub dataset_size: usize,
    /// Input dimensionality
    pub input_dim: usize,
    /// Output dimensionality
    pub output_dim: usize,
    /// Problem type (classification, regression, etc.)
    pub problem_type: ProblemType,
    /// Gradient sparsity (0.0 = dense, 1.0 = very sparse)
    pub gradient_sparsity: f64,
    /// Noise level in gradients
    pub gradient_noise: f64,
    /// Memory constraints (bytes available)
    pub memory_budget: usize,
    /// Computational budget (time constraints; presumably seconds — confirm with callers)
    pub time_budget: f64,
    /// Batch size being used
    pub batch_size: usize,
    /// Learning rate range preference
    pub lr_sensitivity: f64,
    /// Regularization requirements
    pub regularization_strength: f64,
    /// Architecture type (if applicable), e.g. "CNN" or "ResNet"
    pub architecture_type: Option<String>,
}
70
/// Types of machine learning problems.
///
/// The variant's discriminant is used as a numeric feature in
/// `extract_problem_features` (`problem_type as u8`), so variant order
/// affects the learned selection model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ProblemType {
    /// Classification task
    Classification,
    /// Regression task
    Regression,
    /// Unsupervised learning
    Unsupervised,
    /// Reinforcement learning
    ReinforcementLearning,
    /// Time series forecasting
    TimeSeries,
    /// Computer vision
    ComputerVision,
    /// Natural language processing
    NaturalLanguage,
    /// Recommendation systems
    Recommendation,
}
91
/// Performance metrics for optimizer evaluation.
///
/// `validation_performance` is treated as "higher is better" by the
/// selector: it is used directly as the bandit reward and for the
/// success-rate threshold in `get_optimizer_statistics`.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Final loss/error achieved
    pub final_loss: f64,
    /// Convergence speed (steps to reach target)
    pub convergence_steps: usize,
    /// Training time taken
    pub training_time: f64,
    /// Memory usage (bytes)
    pub memory_usage: usize,
    /// Validation performance (higher is better)
    pub validation_performance: f64,
    /// Stability (variance in loss; lower is more stable)
    pub stability: f64,
    /// Generalization (validation - training performance)
    pub generalization_gap: f64,
}
110
/// Selection strategy for adaptive optimization.
///
/// Chosen once at construction of `AdaptiveOptimizerSelector` and
/// dispatched on every `select_optimizer` call.
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
    /// Rule-based selection using expert knowledge
    RuleBased,
    /// Learning-based selection using historical data
    LearningBased,
    /// Ensemble selection trying multiple optimizers
    Ensemble {
        /// Number of optimizers to try
        num_candidates: usize,
        /// Number of steps for evaluation
        evaluation_steps: usize,
    },
    /// Bandit-based selection with exploration/exploitation
    Bandit {
        /// Exploration parameter (probability of a uniform random pick)
        epsilon: f64,
        /// UCB confidence parameter (scales the exploration bonus)
        confidence: f64,
    },
    /// Meta-learning based selection
    MetaLearning {
        /// Feature extractor for problems
        feature_dim: usize,
        /// Number of similar problems to consider
        k_nearest: usize,
    },
}
140
/// Adaptive optimizer selector.
///
/// Chooses an `OptimizerType` for a given `ProblemCharacteristics` using
/// the configured `SelectionStrategy`, and learns from performance
/// feedback supplied via `update_performance`.
#[derive(Debug)]
pub struct AdaptiveOptimizerSelector<A: Float> {
    /// Selection strategy
    strategy: SelectionStrategy,
    /// Historical performance data, keyed by optimizer
    performance_history: HashMap<OptimizerType, Vec<PerformanceMetrics>>,
    /// Problem-optimizer mapping for learning (one entry per feedback call)
    problem_optimizer_map: Vec<(ProblemCharacteristics, OptimizerType, PerformanceMetrics)>,
    /// Current problem characteristics
    current_problem: Option<ProblemCharacteristics>,
    /// Bandit arm pull counts (if using bandit strategy)
    arm_counts: HashMap<OptimizerType, usize>,
    /// Cumulative (not mean) reward per arm; divide by count for the mean
    arm_rewards: HashMap<OptimizerType, f64>,
    /// Neural network for learning-based selection (lazily created)
    selection_network: Option<SelectionNetwork<A>>,
    /// Available optimizers (the candidate pool)
    available_optimizers: Vec<OptimizerType>,
    /// Recent validation performances (rolling window, capped at 100)
    current_performance: VecDeque<f64>,
    /// Confidence in the most recent selection
    last_confidence: f64,
}
164
/// Neural network for optimizer selection.
///
/// A single-hidden-layer MLP: ReLU hidden layer followed by a softmax
/// output over the available optimizers (see `forward`).
#[derive(Debug)]
pub struct SelectionNetwork<A: Float> {
    /// Input weights (problem features -> hidden), shape (hidden, input)
    input_weights: Array2<A>,
    /// Output weights (hidden -> optimizer probabilities), shape (num_optimizers, hidden)
    output_weights: Array2<A>,
    /// Input biases
    input_bias: Array1<A>,
    /// Output biases
    output_bias: Array1<A>,
    /// Hidden layer size
    #[allow(dead_code)]
    hidden_size: usize,
}
180
181impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
182    SelectionNetwork<A>
183{
184    /// Create a new selection network
185    pub fn new(input_size: usize, hidden_size: usize, num_optimizers: usize) -> Self {
186        let mut rng = thread_rng();
187
188        let input_weights = Array2::from_shape_fn((hidden_size, input_size), |_| {
189            A::from(rng.random::<f64>()).expect("unwrap failed")
190                * A::from(0.1).expect("unwrap failed")
191                - A::from(0.05).expect("unwrap failed")
192        });
193
194        let output_weights = Array2::from_shape_fn((num_optimizers, hidden_size), |_| {
195            A::from(rng.random::<f64>()).expect("unwrap failed")
196                * A::from(0.1).expect("unwrap failed")
197                - A::from(0.05).expect("unwrap failed")
198        });
199
200        let input_bias = Array1::zeros(hidden_size);
201        let output_bias = Array1::zeros(num_optimizers);
202
203        Self {
204            input_weights,
205            output_weights,
206            input_bias,
207            output_bias,
208            hidden_size,
209        }
210    }
211
212    /// Forward pass to get optimizer probabilities
213    pub fn forward(&self, features: &Array1<A>) -> Result<Array1<A>> {
214        // Hidden layer
215        let hidden = self.input_weights.dot(features) + self.input_bias.clone();
216        let hidden_activated = hidden.mapv(|x| {
217            // ReLU activation
218            if x > A::zero() {
219                x
220            } else {
221                A::zero()
222            }
223        });
224
225        // Output layer
226        let output = self.output_weights.dot(&hidden_activated) + &self.output_bias;
227
228        // Softmax activation
229        let max_val = output.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
230        let exp_output = output.mapv(|x| A::exp(x - max_val));
231        let sum_exp = exp_output.sum();
232        let probabilities = exp_output.mapv(|x| x / sum_exp);
233
234        Ok(probabilities)
235    }
236
237    /// Train the network on historical data
238    pub fn train(
239        &mut self,
240        features: &[Array1<A>],
241        optimizer_labels: &[usize],
242        learning_rate: A,
243        epochs: usize,
244    ) -> Result<()> {
245        for _ in 0..epochs {
246            for (feature, &label) in features.iter().zip(optimizer_labels.iter()) {
247                // Forward pass
248                let probabilities = self.forward(feature)?;
249
250                // Compute loss (cross-entropy)
251                let target_prob = probabilities[label];
252                let _loss = -A::ln(target_prob);
253
254                // Backward pass (simplified)
255                let mut output_grad = probabilities;
256                output_grad[label] = output_grad[label] - A::one();
257
258                // Update weights (simplified gradient descent)
259                let hidden = self.input_weights.dot(feature) + self.input_bias.clone();
260                let hidden_activated = hidden.mapv(|x| if x > A::zero() { x } else { A::zero() });
261
262                // Update output weights
263                for i in 0..self.output_weights.nrows() {
264                    for j in 0..self.output_weights.ncols() {
265                        self.output_weights[[i, j]] = self.output_weights[[i, j]]
266                            - learning_rate * output_grad[i] * hidden_activated[j];
267                    }
268                }
269
270                // Update output bias
271                for i in 0..self.output_bias.len() {
272                    self.output_bias[i] = self.output_bias[i] - learning_rate * output_grad[i];
273                }
274            }
275        }
276        Ok(())
277    }
278}
279
280impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
281    AdaptiveOptimizerSelector<A>
282{
283    /// Create a new adaptive optimizer selector
284    pub fn new(strategy: SelectionStrategy) -> Self {
285        let available_optimizers = vec![
286            OptimizerType::SGD,
287            OptimizerType::SGDMomentum,
288            OptimizerType::Adam,
289            OptimizerType::AdamW,
290            OptimizerType::RMSprop,
291            OptimizerType::AdaGrad,
292            OptimizerType::RAdam,
293            OptimizerType::LAMB,
294        ];
295
296        let mut arm_counts = HashMap::new();
297        let mut arm_rewards = HashMap::new();
298        for &optimizer in &available_optimizers {
299            arm_counts.insert(optimizer, 0);
300            arm_rewards.insert(optimizer, 0.0);
301        }
302
303        Self {
304            strategy,
305            performance_history: HashMap::new(),
306            problem_optimizer_map: Vec::new(),
307            current_problem: None,
308            arm_counts,
309            arm_rewards,
310            selection_network: None,
311            available_optimizers,
312            current_performance: VecDeque::new(),
313            last_confidence: 0.0,
314        }
315    }
316
317    /// Set the current problem characteristics
318    pub fn set_problem(&mut self, problem: ProblemCharacteristics) {
319        self.current_problem = Some(problem);
320    }
321
322    /// Select the best optimizer for the current problem
323    pub fn select_optimizer(&mut self) -> Result<OptimizerType> {
324        let problem = self.current_problem.clone().ok_or_else(|| {
325            OptimError::InvalidConfig("No problem characteristics set".to_string())
326        })?;
327
328        match &self.strategy {
329            SelectionStrategy::RuleBased => self.rule_based_selection(&problem),
330            SelectionStrategy::LearningBased => self.learning_based_selection(&problem),
331            SelectionStrategy::Ensemble {
332                num_candidates,
333                evaluation_steps,
334            } => self.ensemble_selection(&problem, *num_candidates, *evaluation_steps),
335            SelectionStrategy::Bandit {
336                epsilon,
337                confidence,
338            } => self.bandit_selection(&problem, *epsilon, *confidence),
339            SelectionStrategy::MetaLearning {
340                feature_dim,
341                k_nearest,
342            } => self.meta_learning_selection(&problem, *feature_dim),
343        }
344    }
345
346    /// Rule-based optimizer selection using expert knowledge
347    fn rule_based_selection(&self, problem: &ProblemCharacteristics) -> Result<OptimizerType> {
348        // Large dataset, use adaptive optimizers
349        if problem.dataset_size > 100000 {
350            match problem.problem_type {
351                ProblemType::ComputerVision => return Ok(OptimizerType::AdamW),
352                ProblemType::NaturalLanguage => return Ok(OptimizerType::AdamW),
353                _ => return Ok(OptimizerType::Adam),
354            }
355        }
356
357        // Small dataset, use SGD with momentum
358        if problem.dataset_size < 1000 {
359            return Ok(OptimizerType::LBFGS);
360        }
361
362        // Sparse gradients
363        if problem.gradient_sparsity > 0.5 {
364            return Ok(OptimizerType::AdaGrad);
365        }
366
367        // Large batch training
368        if problem.batch_size > 256 {
369            return Ok(OptimizerType::LAMB);
370        }
371
372        // Memory constrained
373        if problem.memory_budget < 1_000_000 {
374            return Ok(OptimizerType::SGD);
375        }
376
377        // High noise
378        if problem.gradient_noise > 0.3 {
379            return Ok(OptimizerType::RMSprop);
380        }
381
382        // Default choice
383        Ok(OptimizerType::Adam)
384    }
385
386    /// Learning-based selection using historical performance
387    fn learning_based_selection(
388        &mut self,
389        problem: &ProblemCharacteristics,
390    ) -> Result<OptimizerType> {
391        if self.problem_optimizer_map.is_empty() {
392            // No historical data, fall back to rule-based
393            return self.rule_based_selection(problem);
394        }
395
396        // Find most similar problem in history
397        let mut best_similarity = -1.0;
398        let mut best_optimizer = OptimizerType::Adam;
399
400        for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
401            let similarity = self.compute_problem_similarity(problem, hist_problem);
402
403            // Weight by performance
404            let weighted_similarity = similarity * metrics.validation_performance;
405
406            if weighted_similarity > best_similarity {
407                best_similarity = weighted_similarity;
408                best_optimizer = *optimizer;
409            }
410        }
411
412        self.last_confidence = best_similarity;
413        Ok(best_optimizer)
414    }
415
416    /// Ensemble selection by trying multiple optimizers
417    fn ensemble_selection(
418        &self,
419        problem: &ProblemCharacteristics,
420        num_candidates: usize,
421        _evaluation_steps: usize,
422    ) -> Result<OptimizerType> {
423        // Select top _candidates based on historical performance
424        let mut candidates = self.available_optimizers.clone();
425        candidates.truncate(num_candidates.min(candidates.len()));
426
427        // For simplicity, return the first candidate
428        // In practice, you would evaluate each for evaluation_steps
429        Ok(candidates[0])
430    }
431
432    /// Bandit-based selection with epsilon-greedy strategy
433    fn bandit_selection(
434        &self,
435        problem: &ProblemCharacteristics,
436        epsilon: f64,
437        confidence: f64,
438    ) -> Result<OptimizerType> {
439        let mut rng = thread_rng();
440
441        // Epsilon-greedy exploration
442        if rng.random::<f64>() < epsilon {
443            // Explore: random selection
444            let idx = rng.gen_range(0..self.available_optimizers.len());
445            return Ok(self.available_optimizers[idx]);
446        }
447
448        // Exploit: UCB (Upper Confidence Bound) selection
449        let mut best_ucb = f64::NEG_INFINITY;
450        let mut best_optimizer = OptimizerType::Adam;
451        let total_counts: usize = self.arm_counts.values().sum();
452
453        for &optimizer in &self.available_optimizers {
454            let count = self.arm_counts[&optimizer] as f64;
455            let reward = if count > 0.0 {
456                self.arm_rewards[&optimizer] / count
457            } else {
458                0.0
459            };
460
461            let ucb = if count > 0.0 {
462                reward + confidence * ((total_counts as f64).ln() / count).sqrt()
463            } else {
464                f64::INFINITY // Prefer unvisited arms
465            };
466
467            if ucb > best_ucb {
468                best_ucb = ucb;
469                best_optimizer = optimizer;
470            }
471        }
472
473        Ok(best_optimizer)
474    }
475
476    /// Meta-learning based selection
477    fn meta_learning_selection(
478        &mut self,
479        problem: &ProblemCharacteristics,
480        k_nearest: usize,
481    ) -> Result<OptimizerType> {
482        // Extract features from problem
483        let features = self.extract_problem_features(problem);
484
485        // If we have a trained network, use it
486        if let Some(network) = &self.selection_network {
487            let probabilities = network.forward(&features)?;
488
489            // Select optimizer with highest probability
490            let mut best_prob = A::neg_infinity();
491            let mut best_idx = 0;
492
493            for (i, &prob) in probabilities.iter().enumerate() {
494                if prob > best_prob {
495                    best_prob = prob;
496                    best_idx = i;
497                }
498            }
499
500            if best_idx < self.available_optimizers.len() {
501                return Ok(self.available_optimizers[best_idx]);
502            }
503        }
504
505        // k-NN fallback
506        if self.problem_optimizer_map.len() >= k_nearest {
507            let mut similarities = Vec::new();
508
509            for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
510                let similarity = self.compute_problem_similarity(problem, hist_problem);
511                similarities.push((similarity, *optimizer, metrics.validation_performance));
512            }
513
514            // Sort by similarity
515            similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
516
517            // Take k _nearest and vote
518            let mut votes: HashMap<OptimizerType, f64> = HashMap::new();
519            for (similarity, optimizer, performance) in similarities.iter().take(k_nearest) {
520                let weight = similarity * performance;
521                *votes.entry(*optimizer).or_insert(0.0) += weight;
522            }
523
524            // Return optimizer with highest weighted vote
525            let best_optimizer = votes
526                .iter()
527                .max_by(|a, b| a.1.partial_cmp(b.1).expect("unwrap failed"))
528                .map(|(optimizer_, _)| *optimizer_)
529                .unwrap_or(OptimizerType::Adam);
530
531            return Ok(best_optimizer);
532        }
533
534        // Fall back to rule-based
535        self.rule_based_selection(problem)
536    }
537
538    /// Update selector with performance feedback
539    pub fn update_performance(
540        &mut self,
541        optimizer: OptimizerType,
542        metrics: PerformanceMetrics,
543    ) -> Result<()> {
544        // Update performance history
545        self.performance_history
546            .entry(optimizer)
547            .or_default()
548            .push(metrics.clone());
549
550        // Update bandit statistics
551        *self.arm_counts.entry(optimizer).or_insert(0) += 1;
552        *self.arm_rewards.entry(optimizer).or_insert(0.0) += metrics.validation_performance;
553
554        // Store problem-optimizer mapping
555        if let Some(problem) = &self.current_problem {
556            self.problem_optimizer_map
557                .push((problem.clone(), optimizer, metrics.clone()));
558        }
559
560        // Update current performance tracking
561        self.current_performance
562            .push_back(metrics.validation_performance);
563        if self.current_performance.len() > 100 {
564            self.current_performance.pop_front();
565        }
566
567        Ok(())
568    }
569
570    /// Train the selection network if using learning-based strategy
571    pub fn train_selection_network(&mut self, learning_rate: A, epochs: usize) -> Result<()> {
572        if self.problem_optimizer_map.is_empty() {
573            return Ok(()); // No data to train on
574        }
575
576        // Extract features and labels
577        let mut features = Vec::new();
578        let mut labels = Vec::new();
579
580        for (problem, optimizer_, metrics) in &self.problem_optimizer_map {
581            let feature_vec = self.extract_problem_features(problem);
582            features.push(feature_vec);
583
584            // Convert optimizer to label
585            if let Some(label) = self
586                .available_optimizers
587                .iter()
588                .position(|&opt| opt == *optimizer_)
589            {
590                labels.push(label);
591            }
592        }
593
594        // Create network if it doesn't exist
595        if self.selection_network.is_none() {
596            let feature_dim = features[0].len();
597            let num_optimizers = self.available_optimizers.len();
598            self.selection_network = Some(SelectionNetwork::new(feature_dim, 32, num_optimizers));
599        }
600
601        // Train the network
602        if let Some(network) = &mut self.selection_network {
603            network.train(&features, &labels, learning_rate, epochs)?;
604        }
605
606        Ok(())
607    }
608
609    /// Compute similarity between two problems
610    fn compute_problem_similarity(
611        &self,
612        problem1: &ProblemCharacteristics,
613        problem2: &ProblemCharacteristics,
614    ) -> f64 {
615        let mut similarity = 0.0;
616        let mut weight_sum = 0.0;
617
618        // Dataset size similarity (log scale)
619        let size_sim = 1.0
620            - ((problem1.dataset_size as f64).ln() - (problem2.dataset_size as f64).ln()).abs()
621                / 10.0;
622        similarity += size_sim.max(0.0) * 0.2;
623        weight_sum += 0.2;
624
625        // Problem type similarity
626        if problem1.problem_type == problem2.problem_type {
627            similarity += 0.3;
628        }
629        weight_sum += 0.3;
630
631        // Batch size similarity
632        let batch_sim = 1.0
633            - ((problem1.batch_size as f64 - problem2.batch_size as f64).abs() / 256.0).min(1.0);
634        similarity += batch_sim * 0.1;
635        weight_sum += 0.1;
636
637        // Gradient characteristics similarity
638        let sparsity_sim = 1.0 - (problem1.gradient_sparsity - problem2.gradient_sparsity).abs();
639        let noise_sim = 1.0 - (problem1.gradient_noise - problem2.gradient_noise).abs();
640        similarity += (sparsity_sim + noise_sim) * 0.2;
641        weight_sum += 0.4;
642
643        similarity / weight_sum
644    }
645
646    /// Extract numerical features from problem characteristics
647    fn extract_problem_features(&self, problem: &ProblemCharacteristics) -> Array1<A> {
648        Array1::from_vec(vec![
649            A::from((problem.dataset_size as f64).ln()).expect("unwrap failed"),
650            A::from((problem.input_dim as f64).ln()).expect("unwrap failed"),
651            A::from((problem.output_dim as f64).ln()).expect("unwrap failed"),
652            A::from(problem.problem_type as u8 as f64).expect("unwrap failed"),
653            A::from(problem.gradient_sparsity).expect("unwrap failed"),
654            A::from(problem.gradient_noise).expect("unwrap failed"),
655            A::from((problem.memory_budget as f64).ln()).expect("unwrap failed"),
656            A::from(problem.time_budget.ln()).expect("unwrap failed"),
657            A::from((problem.batch_size as f64).ln()).expect("unwrap failed"),
658            A::from(problem.lr_sensitivity).expect("unwrap failed"),
659            A::from(problem.regularization_strength).expect("unwrap failed"),
660        ])
661    }
662
663    /// Get performance statistics for an optimizer
664    pub fn get_optimizer_statistics(
665        &self,
666        optimizer: OptimizerType,
667    ) -> Option<OptimizerStatistics> {
668        if let Some(history) = self.performance_history.get(&optimizer) {
669            if history.is_empty() {
670                return None;
671            }
672
673            let performances: Vec<f64> = history.iter().map(|m| m.validation_performance).collect();
674            let mean = performances.iter().sum::<f64>() / performances.len() as f64;
675            let variance = performances.iter().map(|p| (p - mean).powi(2)).sum::<f64>()
676                / performances.len() as f64;
677            let std_dev = variance.sqrt();
678
679            Some(OptimizerStatistics {
680                optimizer,
681                num_trials: history.len(),
682                mean_performance: mean,
683                std_performance: std_dev,
684                best_performance: performances
685                    .iter()
686                    .copied()
687                    .fold(f64::NEG_INFINITY, f64::max),
688                worst_performance: performances.iter().copied().fold(f64::INFINITY, f64::min),
689                success_rate: performances.iter().filter(|&&p| p > 0.7).count() as f64
690                    / performances.len() as f64,
691            })
692        } else {
693            None
694        }
695    }
696
697    /// Get all optimizer statistics
698    pub fn get_all_statistics(&self) -> Vec<OptimizerStatistics> {
699        self.available_optimizers
700            .iter()
701            .filter_map(|&opt| self.get_optimizer_statistics(opt))
702            .collect()
703    }
704
705    /// Get current confidence in selection
706    pub fn get_selection_confidence(&self) -> f64 {
707        self.last_confidence
708    }
709
710    /// Reset selector state
711    pub fn reset(&mut self) {
712        self.performance_history.clear();
713        self.problem_optimizer_map.clear();
714        self.current_problem = None;
715        for count in self.arm_counts.values_mut() {
716            *count = 0;
717        }
718        for reward in self.arm_rewards.values_mut() {
719            *reward = 0.0;
720        }
721        self.current_performance.clear();
722        self.last_confidence = 0.0;
723    }
724}
725
/// Statistics for an optimizer's performance.
///
/// Produced by `AdaptiveOptimizerSelector::get_optimizer_statistics` from
/// the recorded `validation_performance` values.
#[derive(Debug, Clone)]
pub struct OptimizerStatistics {
    /// Optimizer type
    pub optimizer: OptimizerType,
    /// Number of trials recorded
    pub num_trials: usize,
    /// Mean validation performance
    pub mean_performance: f64,
    /// Standard deviation of performance (population, i.e. divides by n)
    pub std_performance: f64,
    /// Best performance achieved
    pub best_performance: f64,
    /// Worst performance achieved
    pub worst_performance: f64,
    /// Success rate (fraction of trials with performance > 0.7)
    pub success_rate: f64,
}
744
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Baseline classification problem; individual tests override fields
    /// via struct-update syntax.
    fn sample_problem() -> ProblemCharacteristics {
        ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        }
    }

    #[test]
    fn test_problem_characteristics() {
        let problem = ProblemCharacteristics {
            dataset_size: 10000,
            input_dim: 784,
            output_dim: 10,
            time_budget: 3600.0,
            batch_size: 64,
            architecture_type: Some("CNN".to_string()),
            ..sample_problem()
        };

        assert_eq!(problem.dataset_size, 10000);
        assert_eq!(problem.problem_type, ProblemType::Classification);
    }

    #[test]
    fn test_rule_based_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        // A large computer-vision dataset must trigger the AdamW rule.
        let vision_problem = ProblemCharacteristics {
            dataset_size: 100001,
            input_dim: 224,
            output_dim: 1000,
            problem_type: ProblemType::ComputerVision,
            memory_budget: 10_000_000,
            time_budget: 7200.0,
            architecture_type: Some("ResNet".to_string()),
            ..sample_problem()
        };

        selector.set_problem(vision_problem);
        let chosen = selector.select_optimizer().expect("selection should succeed");
        assert_eq!(chosen, OptimizerType::AdamW);
    }

    #[test]
    fn test_selection_network() {
        let net = SelectionNetwork::<f64>::new(5, 10, 3);
        let input = Array1::from_vec(vec![1.0, 0.5, 2.0, 0.8, 1.5]);

        let probs = net.forward(&input).expect("forward should succeed");
        assert_eq!(probs.len(), 3);

        // Softmax output: sums to one...
        let total: f64 = probs.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);

        // ...and every entry is non-negative.
        assert!(probs.iter().all(|&p| p >= 0.0));
    }

    #[test]
    fn test_bandit_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::Bandit {
            epsilon: 0.1,
            confidence: 2.0,
        });

        let problem = ProblemCharacteristics {
            gradient_sparsity: 0.0,
            gradient_noise: 0.1,
            ..sample_problem()
        };
        selector.set_problem(problem);

        // With no history, any optimizer from the pool is a valid pick.
        let chosen = selector.select_optimizer().expect("selection should succeed");
        assert!(selector.available_optimizers.contains(&chosen));
    }

    #[test]
    fn test_performance_update() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let feedback = PerformanceMetrics {
            final_loss: 0.1,
            convergence_steps: 100,
            training_time: 60.0,
            memory_usage: 500_000,
            validation_performance: 0.95,
            stability: 0.02,
            generalization_gap: 0.05,
        };

        selector
            .update_performance(OptimizerType::Adam, feedback)
            .expect("update should succeed");

        let stats = selector
            .get_optimizer_statistics(OptimizerType::Adam)
            .expect("stats should exist after one trial");
        assert_eq!(stats.num_trials, 1);
        assert_relative_eq!(stats.mean_performance, 0.95, epsilon = 1e-6);
    }

    #[test]
    fn test_problem_similarity() {
        let selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let base = sample_problem();

        // A problem is maximally similar to an identical copy of itself.
        let twin = base.clone();
        assert_relative_eq!(
            selector.compute_problem_similarity(&base, &twin),
            1.0,
            epsilon = 1e-6
        );

        // Changing the problem type must strictly reduce similarity.
        let mut different = base.clone();
        different.problem_type = ProblemType::Regression;
        assert!(selector.compute_problem_similarity(&base, &different) < 1.0);
    }
}