optirs_core/adaptive_selection/
mod.rs

// Adaptive optimization algorithm selection
//
// This module provides automatic selection of the most appropriate optimization algorithm
// based on problem characteristics, performance monitoring, and learned patterns.

use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};

use crate::error::{OptimError, Result};
12
/// Types of optimization algorithms available for selection.
///
/// `Copy + Eq + Hash` so the type can be used directly as a `HashMap` key in
/// the selector's bandit statistics and performance history.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizerType {
    /// Stochastic Gradient Descent
    SGD,
    /// SGD with momentum
    SGDMomentum,
    /// Adam optimizer
    Adam,
    /// AdamW (Adam with decoupled weight decay)
    AdamW,
    /// RMSprop optimizer
    RMSprop,
    /// AdaGrad optimizer
    AdaGrad,
    /// RAdam (Rectified Adam)
    RAdam,
    /// Lookahead wrapper
    Lookahead,
    /// LAMB (Layer-wise Adaptive Moments)
    LAMB,
    /// LARS (Layer-wise Adaptive Rate Scaling)
    LARS,
    /// L-BFGS (Limited-memory BFGS)
    LBFGS,
    /// SAM (Sharpness-Aware Minimization)
    SAM,
}
41
42/// Problem characteristics for optimizer selection
43#[derive(Debug, Clone)]
44pub struct ProblemCharacteristics {
45    /// Dataset size
46    pub dataset_size: usize,
47    /// Input dimensionality
48    pub input_dim: usize,
49    /// Output dimensionality
50    pub output_dim: usize,
51    /// Problem type (classification, regression, etc.)
52    pub problem_type: ProblemType,
53    /// Gradient sparsity (0.0 = dense, 1.0 = very sparse)
54    pub gradient_sparsity: f64,
55    /// Noise level in gradients
56    pub gradient_noise: f64,
57    /// Memory constraints (bytes available)
58    pub memory_budget: usize,
59    /// Computational budget (time constraints)
60    pub time_budget: f64,
61    /// Batch size being used
62    pub batch_size: usize,
63    /// Learning rate range preference
64    pub lr_sensitivity: f64,
65    /// Regularization requirements
66    pub regularization_strength: f64,
67    /// Architecture type (if applicable)
68    pub architecture_type: Option<String>,
69}
70
/// Types of machine learning problems.
///
/// NOTE: variant order matters — `extract_problem_features` casts the
/// discriminant (`as u8`) into the feature vector, so reordering variants
/// silently changes trained-model inputs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ProblemType {
    /// Classification task
    Classification,
    /// Regression task
    Regression,
    /// Unsupervised learning
    Unsupervised,
    /// Reinforcement learning
    ReinforcementLearning,
    /// Time series forecasting
    TimeSeries,
    /// Computer vision
    ComputerVision,
    /// Natural language processing
    NaturalLanguage,
    /// Recommendation systems
    Recommendation,
}
91
/// Performance metrics for optimizer evaluation.
///
/// Recorded after a training run; `validation_performance` is the quantity
/// the selector treats as the reward signal.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Final loss/error achieved
    pub final_loss: f64,
    /// Convergence speed (steps to reach target)
    pub convergence_steps: usize,
    /// Training time taken
    pub training_time: f64,
    /// Memory usage
    pub memory_usage: usize,
    /// Validation performance
    pub validation_performance: f64,
    /// Stability (variance in loss)
    pub stability: f64,
    /// Generalization (validation - training performance)
    pub generalization_gap: f64,
}
110
/// Selection strategy for adaptive optimization.
///
/// Chosen once at selector construction; `select_optimizer` dispatches on it.
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
    /// Rule-based selection using expert knowledge
    RuleBased,
    /// Learning-based selection using historical data
    LearningBased,
    /// Ensemble selection trying multiple optimizers
    Ensemble {
        /// Number of optimizers to try
        num_candidates: usize,
        /// Number of steps for evaluation
        evaluation_steps: usize,
    },
    /// Bandit-based selection with exploration/exploitation
    Bandit {
        /// Exploration parameter
        epsilon: f64,
        /// UCB confidence parameter
        confidence: f64,
    },
    /// Meta-learning based selection
    MetaLearning {
        /// Feature extractor for problems
        feature_dim: usize,
        /// Number of similar problems to consider
        k_nearest: usize,
    },
}
140
141/// Adaptive optimizer selector
142#[derive(Debug)]
143pub struct AdaptiveOptimizerSelector<A: Float> {
144    /// Selection strategy
145    strategy: SelectionStrategy,
146    /// Historical performance data
147    performance_history: HashMap<OptimizerType, Vec<PerformanceMetrics>>,
148    /// Problem-optimizer mapping for learning
149    problem_optimizer_map: Vec<(ProblemCharacteristics, OptimizerType, PerformanceMetrics)>,
150    /// Current problem characteristics
151    current_problem: Option<ProblemCharacteristics>,
152    /// Bandit arm statistics (if using bandit strategy)
153    arm_counts: HashMap<OptimizerType, usize>,
154    arm_rewards: HashMap<OptimizerType, f64>,
155    /// Neural network for learning-based selection
156    selection_network: Option<SelectionNetwork<A>>,
157    /// Available optimizers
158    available_optimizers: Vec<OptimizerType>,
159    /// Performance tracking
160    current_performance: VecDeque<f64>,
161    /// Selection confidence
162    last_confidence: f64,
163}
164
165/// Neural network for optimizer selection
166#[derive(Debug)]
167pub struct SelectionNetwork<A: Float> {
168    /// Input weights (problem features -> hidden)
169    input_weights: Array2<A>,
170    /// Output weights (hidden -> optimizer probabilities)
171    output_weights: Array2<A>,
172    /// Input biases
173    input_bias: Array1<A>,
174    /// Output biases
175    output_bias: Array1<A>,
176    /// Hidden layer size
177    #[allow(dead_code)]
178    hidden_size: usize,
179}
180
181impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
182    SelectionNetwork<A>
183{
184    /// Create a new selection network
185    pub fn new(input_size: usize, hidden_size: usize, num_optimizers: usize) -> Self {
186        let mut rng = thread_rng();
187
188        let input_weights = Array2::from_shape_fn((hidden_size, input_size), |_| {
189            A::from(rng.random::<f64>()).unwrap() * A::from(0.1).unwrap() - A::from(0.05).unwrap()
190        });
191
192        let output_weights = Array2::from_shape_fn((num_optimizers, hidden_size), |_| {
193            A::from(rng.random::<f64>()).unwrap() * A::from(0.1).unwrap() - A::from(0.05).unwrap()
194        });
195
196        let input_bias = Array1::zeros(hidden_size);
197        let output_bias = Array1::zeros(num_optimizers);
198
199        Self {
200            input_weights,
201            output_weights,
202            input_bias,
203            output_bias,
204            hidden_size,
205        }
206    }
207
208    /// Forward pass to get optimizer probabilities
209    pub fn forward(&self, features: &Array1<A>) -> Result<Array1<A>> {
210        // Hidden layer
211        let hidden = self.input_weights.dot(features) + self.input_bias.clone();
212        let hidden_activated = hidden.mapv(|x| {
213            // ReLU activation
214            if x > A::zero() {
215                x
216            } else {
217                A::zero()
218            }
219        });
220
221        // Output layer
222        let output = self.output_weights.dot(&hidden_activated) + &self.output_bias;
223
224        // Softmax activation
225        let max_val = output.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
226        let exp_output = output.mapv(|x| A::exp(x - max_val));
227        let sum_exp = exp_output.sum();
228        let probabilities = exp_output.mapv(|x| x / sum_exp);
229
230        Ok(probabilities)
231    }
232
233    /// Train the network on historical data
234    pub fn train(
235        &mut self,
236        features: &[Array1<A>],
237        optimizer_labels: &[usize],
238        learning_rate: A,
239        epochs: usize,
240    ) -> Result<()> {
241        for _ in 0..epochs {
242            for (feature, &label) in features.iter().zip(optimizer_labels.iter()) {
243                // Forward pass
244                let probabilities = self.forward(feature)?;
245
246                // Compute loss (cross-entropy)
247                let target_prob = probabilities[label];
248                let _loss = -A::ln(target_prob);
249
250                // Backward pass (simplified)
251                let mut output_grad = probabilities;
252                output_grad[label] = output_grad[label] - A::one();
253
254                // Update weights (simplified gradient descent)
255                let hidden = self.input_weights.dot(feature) + self.input_bias.clone();
256                let hidden_activated = hidden.mapv(|x| if x > A::zero() { x } else { A::zero() });
257
258                // Update output weights
259                for i in 0..self.output_weights.nrows() {
260                    for j in 0..self.output_weights.ncols() {
261                        self.output_weights[[i, j]] = self.output_weights[[i, j]]
262                            - learning_rate * output_grad[i] * hidden_activated[j];
263                    }
264                }
265
266                // Update output bias
267                for i in 0..self.output_bias.len() {
268                    self.output_bias[i] = self.output_bias[i] - learning_rate * output_grad[i];
269                }
270            }
271        }
272        Ok(())
273    }
274}
275
276impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
277    AdaptiveOptimizerSelector<A>
278{
279    /// Create a new adaptive optimizer selector
280    pub fn new(strategy: SelectionStrategy) -> Self {
281        let available_optimizers = vec![
282            OptimizerType::SGD,
283            OptimizerType::SGDMomentum,
284            OptimizerType::Adam,
285            OptimizerType::AdamW,
286            OptimizerType::RMSprop,
287            OptimizerType::AdaGrad,
288            OptimizerType::RAdam,
289            OptimizerType::LAMB,
290        ];
291
292        let mut arm_counts = HashMap::new();
293        let mut arm_rewards = HashMap::new();
294        for &optimizer in &available_optimizers {
295            arm_counts.insert(optimizer, 0);
296            arm_rewards.insert(optimizer, 0.0);
297        }
298
299        Self {
300            strategy,
301            performance_history: HashMap::new(),
302            problem_optimizer_map: Vec::new(),
303            current_problem: None,
304            arm_counts,
305            arm_rewards,
306            selection_network: None,
307            available_optimizers,
308            current_performance: VecDeque::new(),
309            last_confidence: 0.0,
310        }
311    }
312
313    /// Set the current problem characteristics
314    pub fn set_problem(&mut self, problem: ProblemCharacteristics) {
315        self.current_problem = Some(problem);
316    }
317
318    /// Select the best optimizer for the current problem
319    pub fn select_optimizer(&mut self) -> Result<OptimizerType> {
320        let problem = self.current_problem.clone().ok_or_else(|| {
321            OptimError::InvalidConfig("No problem characteristics set".to_string())
322        })?;
323
324        match &self.strategy {
325            SelectionStrategy::RuleBased => self.rule_based_selection(&problem),
326            SelectionStrategy::LearningBased => self.learning_based_selection(&problem),
327            SelectionStrategy::Ensemble {
328                num_candidates,
329                evaluation_steps,
330            } => self.ensemble_selection(&problem, *num_candidates, *evaluation_steps),
331            SelectionStrategy::Bandit {
332                epsilon,
333                confidence,
334            } => self.bandit_selection(&problem, *epsilon, *confidence),
335            SelectionStrategy::MetaLearning {
336                feature_dim,
337                k_nearest,
338            } => self.meta_learning_selection(&problem, *feature_dim),
339        }
340    }
341
342    /// Rule-based optimizer selection using expert knowledge
343    fn rule_based_selection(&self, problem: &ProblemCharacteristics) -> Result<OptimizerType> {
344        // Large dataset, use adaptive optimizers
345        if problem.dataset_size > 100000 {
346            match problem.problem_type {
347                ProblemType::ComputerVision => return Ok(OptimizerType::AdamW),
348                ProblemType::NaturalLanguage => return Ok(OptimizerType::AdamW),
349                _ => return Ok(OptimizerType::Adam),
350            }
351        }
352
353        // Small dataset, use SGD with momentum
354        if problem.dataset_size < 1000 {
355            return Ok(OptimizerType::LBFGS);
356        }
357
358        // Sparse gradients
359        if problem.gradient_sparsity > 0.5 {
360            return Ok(OptimizerType::AdaGrad);
361        }
362
363        // Large batch training
364        if problem.batch_size > 256 {
365            return Ok(OptimizerType::LAMB);
366        }
367
368        // Memory constrained
369        if problem.memory_budget < 1_000_000 {
370            return Ok(OptimizerType::SGD);
371        }
372
373        // High noise
374        if problem.gradient_noise > 0.3 {
375            return Ok(OptimizerType::RMSprop);
376        }
377
378        // Default choice
379        Ok(OptimizerType::Adam)
380    }
381
382    /// Learning-based selection using historical performance
383    fn learning_based_selection(
384        &mut self,
385        problem: &ProblemCharacteristics,
386    ) -> Result<OptimizerType> {
387        if self.problem_optimizer_map.is_empty() {
388            // No historical data, fall back to rule-based
389            return self.rule_based_selection(problem);
390        }
391
392        // Find most similar problem in history
393        let mut best_similarity = -1.0;
394        let mut best_optimizer = OptimizerType::Adam;
395
396        for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
397            let similarity = self.compute_problem_similarity(problem, hist_problem);
398
399            // Weight by performance
400            let weighted_similarity = similarity * metrics.validation_performance;
401
402            if weighted_similarity > best_similarity {
403                best_similarity = weighted_similarity;
404                best_optimizer = *optimizer;
405            }
406        }
407
408        self.last_confidence = best_similarity;
409        Ok(best_optimizer)
410    }
411
412    /// Ensemble selection by trying multiple optimizers
413    fn ensemble_selection(
414        &self,
415        problem: &ProblemCharacteristics,
416        num_candidates: usize,
417        _evaluation_steps: usize,
418    ) -> Result<OptimizerType> {
419        // Select top _candidates based on historical performance
420        let mut candidates = self.available_optimizers.clone();
421        candidates.truncate(num_candidates.min(candidates.len()));
422
423        // For simplicity, return the first candidate
424        // In practice, you would evaluate each for evaluation_steps
425        Ok(candidates[0])
426    }
427
428    /// Bandit-based selection with epsilon-greedy strategy
429    fn bandit_selection(
430        &self,
431        problem: &ProblemCharacteristics,
432        epsilon: f64,
433        confidence: f64,
434    ) -> Result<OptimizerType> {
435        let mut rng = thread_rng();
436
437        // Epsilon-greedy exploration
438        if rng.random::<f64>() < epsilon {
439            // Explore: random selection
440            let idx = rng.gen_range(0..self.available_optimizers.len());
441            return Ok(self.available_optimizers[idx]);
442        }
443
444        // Exploit: UCB (Upper Confidence Bound) selection
445        let mut best_ucb = f64::NEG_INFINITY;
446        let mut best_optimizer = OptimizerType::Adam;
447        let total_counts: usize = self.arm_counts.values().sum();
448
449        for &optimizer in &self.available_optimizers {
450            let count = self.arm_counts[&optimizer] as f64;
451            let reward = if count > 0.0 {
452                self.arm_rewards[&optimizer] / count
453            } else {
454                0.0
455            };
456
457            let ucb = if count > 0.0 {
458                reward + confidence * ((total_counts as f64).ln() / count).sqrt()
459            } else {
460                f64::INFINITY // Prefer unvisited arms
461            };
462
463            if ucb > best_ucb {
464                best_ucb = ucb;
465                best_optimizer = optimizer;
466            }
467        }
468
469        Ok(best_optimizer)
470    }
471
472    /// Meta-learning based selection
473    fn meta_learning_selection(
474        &mut self,
475        problem: &ProblemCharacteristics,
476        k_nearest: usize,
477    ) -> Result<OptimizerType> {
478        // Extract features from problem
479        let features = self.extract_problem_features(problem);
480
481        // If we have a trained network, use it
482        if let Some(network) = &self.selection_network {
483            let probabilities = network.forward(&features)?;
484
485            // Select optimizer with highest probability
486            let mut best_prob = A::neg_infinity();
487            let mut best_idx = 0;
488
489            for (i, &prob) in probabilities.iter().enumerate() {
490                if prob > best_prob {
491                    best_prob = prob;
492                    best_idx = i;
493                }
494            }
495
496            if best_idx < self.available_optimizers.len() {
497                return Ok(self.available_optimizers[best_idx]);
498            }
499        }
500
501        // k-NN fallback
502        if self.problem_optimizer_map.len() >= k_nearest {
503            let mut similarities = Vec::new();
504
505            for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
506                let similarity = self.compute_problem_similarity(problem, hist_problem);
507                similarities.push((similarity, *optimizer, metrics.validation_performance));
508            }
509
510            // Sort by similarity
511            similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
512
513            // Take k _nearest and vote
514            let mut votes: HashMap<OptimizerType, f64> = HashMap::new();
515            for (similarity, optimizer, performance) in similarities.iter().take(k_nearest) {
516                let weight = similarity * performance;
517                *votes.entry(*optimizer).or_insert(0.0) += weight;
518            }
519
520            // Return optimizer with highest weighted vote
521            let best_optimizer = votes
522                .iter()
523                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
524                .map(|(optimizer_, _)| *optimizer_)
525                .unwrap_or(OptimizerType::Adam);
526
527            return Ok(best_optimizer);
528        }
529
530        // Fall back to rule-based
531        self.rule_based_selection(problem)
532    }
533
534    /// Update selector with performance feedback
535    pub fn update_performance(
536        &mut self,
537        optimizer: OptimizerType,
538        metrics: PerformanceMetrics,
539    ) -> Result<()> {
540        // Update performance history
541        self.performance_history
542            .entry(optimizer)
543            .or_default()
544            .push(metrics.clone());
545
546        // Update bandit statistics
547        *self.arm_counts.entry(optimizer).or_insert(0) += 1;
548        *self.arm_rewards.entry(optimizer).or_insert(0.0) += metrics.validation_performance;
549
550        // Store problem-optimizer mapping
551        if let Some(problem) = &self.current_problem {
552            self.problem_optimizer_map
553                .push((problem.clone(), optimizer, metrics.clone()));
554        }
555
556        // Update current performance tracking
557        self.current_performance
558            .push_back(metrics.validation_performance);
559        if self.current_performance.len() > 100 {
560            self.current_performance.pop_front();
561        }
562
563        Ok(())
564    }
565
566    /// Train the selection network if using learning-based strategy
567    pub fn train_selection_network(&mut self, learning_rate: A, epochs: usize) -> Result<()> {
568        if self.problem_optimizer_map.is_empty() {
569            return Ok(()); // No data to train on
570        }
571
572        // Extract features and labels
573        let mut features = Vec::new();
574        let mut labels = Vec::new();
575
576        for (problem, optimizer_, metrics) in &self.problem_optimizer_map {
577            let feature_vec = self.extract_problem_features(problem);
578            features.push(feature_vec);
579
580            // Convert optimizer to label
581            if let Some(label) = self
582                .available_optimizers
583                .iter()
584                .position(|&opt| opt == *optimizer_)
585            {
586                labels.push(label);
587            }
588        }
589
590        // Create network if it doesn't exist
591        if self.selection_network.is_none() {
592            let feature_dim = features[0].len();
593            let num_optimizers = self.available_optimizers.len();
594            self.selection_network = Some(SelectionNetwork::new(feature_dim, 32, num_optimizers));
595        }
596
597        // Train the network
598        if let Some(network) = &mut self.selection_network {
599            network.train(&features, &labels, learning_rate, epochs)?;
600        }
601
602        Ok(())
603    }
604
605    /// Compute similarity between two problems
606    fn compute_problem_similarity(
607        &self,
608        problem1: &ProblemCharacteristics,
609        problem2: &ProblemCharacteristics,
610    ) -> f64 {
611        let mut similarity = 0.0;
612        let mut weight_sum = 0.0;
613
614        // Dataset size similarity (log scale)
615        let size_sim = 1.0
616            - ((problem1.dataset_size as f64).ln() - (problem2.dataset_size as f64).ln()).abs()
617                / 10.0;
618        similarity += size_sim.max(0.0) * 0.2;
619        weight_sum += 0.2;
620
621        // Problem type similarity
622        if problem1.problem_type == problem2.problem_type {
623            similarity += 0.3;
624        }
625        weight_sum += 0.3;
626
627        // Batch size similarity
628        let batch_sim = 1.0
629            - ((problem1.batch_size as f64 - problem2.batch_size as f64).abs() / 256.0).min(1.0);
630        similarity += batch_sim * 0.1;
631        weight_sum += 0.1;
632
633        // Gradient characteristics similarity
634        let sparsity_sim = 1.0 - (problem1.gradient_sparsity - problem2.gradient_sparsity).abs();
635        let noise_sim = 1.0 - (problem1.gradient_noise - problem2.gradient_noise).abs();
636        similarity += (sparsity_sim + noise_sim) * 0.2;
637        weight_sum += 0.4;
638
639        similarity / weight_sum
640    }
641
642    /// Extract numerical features from problem characteristics
643    fn extract_problem_features(&self, problem: &ProblemCharacteristics) -> Array1<A> {
644        Array1::from_vec(vec![
645            A::from((problem.dataset_size as f64).ln()).unwrap(),
646            A::from((problem.input_dim as f64).ln()).unwrap(),
647            A::from((problem.output_dim as f64).ln()).unwrap(),
648            A::from(problem.problem_type as u8 as f64).unwrap(),
649            A::from(problem.gradient_sparsity).unwrap(),
650            A::from(problem.gradient_noise).unwrap(),
651            A::from((problem.memory_budget as f64).ln()).unwrap(),
652            A::from(problem.time_budget.ln()).unwrap(),
653            A::from((problem.batch_size as f64).ln()).unwrap(),
654            A::from(problem.lr_sensitivity).unwrap(),
655            A::from(problem.regularization_strength).unwrap(),
656        ])
657    }
658
659    /// Get performance statistics for an optimizer
660    pub fn get_optimizer_statistics(
661        &self,
662        optimizer: OptimizerType,
663    ) -> Option<OptimizerStatistics> {
664        if let Some(history) = self.performance_history.get(&optimizer) {
665            if history.is_empty() {
666                return None;
667            }
668
669            let performances: Vec<f64> = history.iter().map(|m| m.validation_performance).collect();
670            let mean = performances.iter().sum::<f64>() / performances.len() as f64;
671            let variance = performances.iter().map(|p| (p - mean).powi(2)).sum::<f64>()
672                / performances.len() as f64;
673            let std_dev = variance.sqrt();
674
675            Some(OptimizerStatistics {
676                optimizer,
677                num_trials: history.len(),
678                mean_performance: mean,
679                std_performance: std_dev,
680                best_performance: performances
681                    .iter()
682                    .copied()
683                    .fold(f64::NEG_INFINITY, f64::max),
684                worst_performance: performances.iter().copied().fold(f64::INFINITY, f64::min),
685                success_rate: performances.iter().filter(|&&p| p > 0.7).count() as f64
686                    / performances.len() as f64,
687            })
688        } else {
689            None
690        }
691    }
692
693    /// Get all optimizer statistics
694    pub fn get_all_statistics(&self) -> Vec<OptimizerStatistics> {
695        self.available_optimizers
696            .iter()
697            .filter_map(|&opt| self.get_optimizer_statistics(opt))
698            .collect()
699    }
700
701    /// Get current confidence in selection
702    pub fn get_selection_confidence(&self) -> f64 {
703        self.last_confidence
704    }
705
706    /// Reset selector state
707    pub fn reset(&mut self) {
708        self.performance_history.clear();
709        self.problem_optimizer_map.clear();
710        self.current_problem = None;
711        for count in self.arm_counts.values_mut() {
712            *count = 0;
713        }
714        for reward in self.arm_rewards.values_mut() {
715            *reward = 0.0;
716        }
717        self.current_performance.clear();
718        self.last_confidence = 0.0;
719    }
720}
721
722/// Statistics for an optimizer's performance
723#[derive(Debug, Clone)]
724pub struct OptimizerStatistics {
725    /// Optimizer type
726    pub optimizer: OptimizerType,
727    /// Number of trials
728    pub num_trials: usize,
729    /// Mean performance
730    pub mean_performance: f64,
731    /// Standard deviation of performance
732    pub std_performance: f64,
733    /// Best performance achieved
734    pub best_performance: f64,
735    /// Worst performance achieved
736    pub worst_performance: f64,
737    /// Success rate (performance > threshold)
738    pub success_rate: f64,
739}
740
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_problem_characteristics() {
        let problem = ProblemCharacteristics {
            dataset_size: 10000,
            input_dim: 784,
            output_dim: 10,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 3600.0,
            batch_size: 64,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: Some("CNN".to_string()),
        };

        assert_eq!(problem.dataset_size, 10000);
        assert_eq!(problem.problem_type, ProblemType::Classification);
    }

    #[test]
    fn test_rule_based_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        // Large computer-vision dataset should choose AdamW.
        let large_problem = ProblemCharacteristics {
            dataset_size: 100001,
            input_dim: 224,
            output_dim: 1000,
            problem_type: ProblemType::ComputerVision,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 10_000_000,
            time_budget: 7200.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: Some("ResNet".to_string()),
        };

        selector.set_problem(large_problem);
        let optimizer = selector.select_optimizer().unwrap();
        assert_eq!(optimizer, OptimizerType::AdamW);
    }

    #[test]
    fn test_selection_network() {
        let network = SelectionNetwork::<f64>::new(5, 10, 3);
        let features = Array1::from_vec(vec![1.0, 0.5, 2.0, 0.8, 1.5]);

        let probabilities = network.forward(&features).unwrap();
        assert_eq!(probabilities.len(), 3);

        // Softmax output: sums to 1, all entries non-negative.
        let sum: f64 = probabilities.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        for &prob in probabilities.iter() {
            assert!(prob >= 0.0);
        }
    }

    #[test]
    fn test_bandit_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::Bandit {
            epsilon: 0.1,
            confidence: 2.0,
        });

        let problem = ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.0,
            gradient_noise: 0.1,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        };

        selector.set_problem(problem);

        // With no history, any optimizer from the candidate pool is valid.
        let optimizer = selector.select_optimizer().unwrap();
        assert!(selector.available_optimizers.contains(&optimizer));
    }

    #[test]
    fn test_performance_update() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let metrics = PerformanceMetrics {
            final_loss: 0.1,
            convergence_steps: 100,
            training_time: 60.0,
            memory_usage: 500_000,
            validation_performance: 0.95,
            stability: 0.02,
            generalization_gap: 0.05,
        };

        selector
            .update_performance(OptimizerType::Adam, metrics)
            .unwrap();

        let stats = selector
            .get_optimizer_statistics(OptimizerType::Adam)
            .unwrap();
        assert_eq!(stats.num_trials, 1);
        assert_relative_eq!(stats.mean_performance, 0.95, epsilon = 1e-6);
    }

    #[test]
    fn test_problem_similarity() {
        let selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);

        let problem1 = ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        };

        // An identical problem has similarity 1.
        let problem2 = problem1.clone();
        let similarity = selector.compute_problem_similarity(&problem1, &problem2);
        assert_relative_eq!(similarity, 1.0, epsilon = 1e-6);

        // Changing the problem type strictly lowers similarity.
        let mut problem3 = problem1.clone();
        problem3.problem_type = ProblemType::Regression;
        let similarity = selector.compute_problem_similarity(&problem1, &problem3);
        assert!(similarity < 1.0);
    }
}