use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
/// Identifier for an optimization algorithm that the selector can recommend.
///
/// Only some variants are in `AdaptiveOptimizerSelector`'s default candidate
/// pool (`Lookahead`, `LARS`, `LBFGS` and `SAM` are not); the rule-based
/// strategy can nevertheless return `LBFGS`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizerType {
SGD,
SGDMomentum,
Adam,
AdamW,
RMSprop,
AdaGrad,
RAdam,
Lookahead,
LAMB,
LARS,
LBFGS,
SAM,
}
/// Description of a training problem, used as input to optimizer selection.
///
/// Every numeric field feeds the selector's feature vector. Several fields are
/// also compared against fixed thresholds by the rule-based strategy.
#[derive(Debug, Clone)]
pub struct ProblemCharacteristics {
/// Size of the training set (presumably the number of samples — confirm).
/// Rule-based selection treats > 100_000 as "large" and < 1000 as "small".
pub dataset_size: usize,
/// Input dimensionality; only used as a log-scaled feature.
pub input_dim: usize,
/// Output dimensionality; only used as a log-scaled feature.
pub output_dim: usize,
/// Task category; its discriminant is also used as a numeric feature.
pub problem_type: ProblemType,
/// Gradient sparsity; rule-based selection picks AdaGrad above 0.5.
/// NOTE(review): presumably a fraction in [0, 1] — confirm with callers.
pub gradient_sparsity: f64,
/// Gradient noise level; rule-based selection picks RMSprop above 0.3.
pub gradient_noise: f64,
/// Memory budget; rule-based selection picks SGD below 1_000_000.
/// NOTE(review): the unit (bytes?) is not established here — confirm.
pub memory_budget: usize,
/// Time budget; only used as a log-scaled feature.
/// NOTE(review): tests use values like 600.0 and 3600.0, which suggests
/// seconds — confirm.
pub time_budget: f64,
/// Mini-batch size; rule-based selection picks LAMB above 256.
pub batch_size: usize,
/// Learning-rate sensitivity; only used as a raw feature.
pub lr_sensitivity: f64,
/// Regularization strength; only used as a raw feature.
pub regularization_strength: f64,
/// Optional free-form architecture name; not read by the selection logic.
pub architecture_type: Option<String>,
}
/// Broad category of the learning task.
///
/// The selector converts a variant to a number with `as u8` when building
/// features. Reordering or inserting variants therefore changes the feature
/// values that any trained selection network has seen.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ProblemType {
Classification,
Regression,
Unsupervised,
ReinforcementLearning,
TimeSeries,
ComputerVision,
NaturalLanguage,
Recommendation,
}
/// Outcome of one training run with a particular optimizer.
///
/// The selector only reads `validation_performance`. It uses that value as the
/// bandit reward, as the weight on historical records, and as the input to
/// statistics, where higher is treated as better. The remaining fields are
/// stored with the record but are not read by the selection logic.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
pub final_loss: f64,
pub convergence_steps: usize,
pub training_time: f64,
pub memory_usage: usize,
/// Validation score; higher is better. Statistics count a run as a
/// "success" when this exceeds 0.7.
pub validation_performance: f64,
pub stability: f64,
pub generalization_gap: f64,
}
/// How `AdaptiveOptimizerSelector::select_optimizer` chooses an optimizer.
#[derive(Debug, Clone)]
pub enum SelectionStrategy {
/// Fixed heuristics over the problem characteristics.
RuleBased,
/// Picks the optimizer from the recorded problem with the highest
/// similarity × validation-performance score, falling back to the
/// rule-based heuristics when there is no history.
LearningBased,
/// Ensemble over the first `num_candidates` optimizers of the pool.
/// NOTE(review): candidates are not actually evaluated yet and
/// `evaluation_steps` is not used by the selection code.
Ensemble {
num_candidates: usize,
evaluation_steps: usize,
},
/// Epsilon-greedy exploration combined with an upper-confidence-bound
/// (UCB) rule over the optimizer pool.
Bandit {
/// Probability of picking a uniformly random optimizer.
epsilon: f64,
/// Multiplier on the UCB exploration bonus.
confidence: f64,
},
/// Uses the trained selection network when one exists. Otherwise it takes
/// a similarity-weighted vote among the nearest recorded problems.
MetaLearning {
/// NOTE(review): the selector's feature vectors always have 11 entries;
/// confirm how this value is meant to be used.
feature_dim: usize,
/// Number of nearest recorded problems intended to vote.
/// NOTE(review): confirm `select_optimizer` forwards this field (and not
/// `feature_dim`) as the neighbour count.
k_nearest: usize,
},
}
/// Recommends an optimizer for the current problem and learns from reported
/// outcomes.
///
/// `A` is the float type used by the optional neural selection network.
#[derive(Debug)]
pub struct AdaptiveOptimizerSelector<A: Float> {
/// Strategy dispatched on by `select_optimizer`.
strategy: SelectionStrategy,
/// Every metrics record reported through `update_performance`, per optimizer.
performance_history: HashMap<OptimizerType, Vec<PerformanceMetrics>>,
/// (problem, optimizer, metrics) records, appended only when a problem was
/// set at update time; used for similarity search and network training.
problem_optimizer_map: Vec<(ProblemCharacteristics, OptimizerType, PerformanceMetrics)>,
/// Problem that the next `select_optimizer` call is made for.
current_problem: Option<ProblemCharacteristics>,
/// Number of reported runs per optimizer (bandit pulls).
arm_counts: HashMap<OptimizerType, usize>,
/// Sum of `validation_performance` per optimizer (bandit rewards).
arm_rewards: HashMap<OptimizerType, f64>,
/// Lazily created by `train_selection_network`.
selection_network: Option<SelectionNetwork<A>>,
/// Candidate pool; an optimizer's index here is its network output index.
available_optimizers: Vec<OptimizerType>,
/// Most recent validation scores, capped at 100 entries.
current_performance: VecDeque<f64>,
/// Best weighted similarity from the last learning-based selection.
last_confidence: f64,
}
/// One-hidden-layer network that scores candidate optimizers from problem
/// features. It uses a ReLU hidden layer and a softmax output over optimizers.
#[derive(Debug)]
pub struct SelectionNetwork<A: Float> {
/// Shape (hidden_size, input_size).
input_weights: Array2<A>,
/// Shape (num_optimizers, hidden_size).
output_weights: Array2<A>,
/// Length hidden_size.
input_bias: Array1<A>,
/// Length num_optimizers.
output_bias: Array1<A>,
// Stored for introspection/debug output only; no method in this file reads
// it (the weight shapes already encode it), hence the dead_code allowance.
#[allow(dead_code)]
hidden_size: usize,
}
impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
SelectionNetwork<A>
{
pub fn new(input_size: usize, hidden_size: usize, num_optimizers: usize) -> Self {
let mut rng = thread_rng();
let input_weights = Array2::from_shape_fn((hidden_size, input_size), |_| {
A::from(rng.random::<f64>()).expect("unwrap failed")
* A::from(0.1).expect("unwrap failed")
- A::from(0.05).expect("unwrap failed")
});
let output_weights = Array2::from_shape_fn((num_optimizers, hidden_size), |_| {
A::from(rng.random::<f64>()).expect("unwrap failed")
* A::from(0.1).expect("unwrap failed")
- A::from(0.05).expect("unwrap failed")
});
let input_bias = Array1::zeros(hidden_size);
let output_bias = Array1::zeros(num_optimizers);
Self {
input_weights,
output_weights,
input_bias,
output_bias,
hidden_size,
}
}
pub fn forward(&self, features: &Array1<A>) -> Result<Array1<A>> {
let hidden = self.input_weights.dot(features) + self.input_bias.clone();
let hidden_activated = hidden.mapv(|x| {
if x > A::zero() {
x
} else {
A::zero()
}
});
let output = self.output_weights.dot(&hidden_activated) + &self.output_bias;
let max_val = output.iter().fold(A::neg_infinity(), |a, &b| A::max(a, b));
let exp_output = output.mapv(|x| A::exp(x - max_val));
let sum_exp = exp_output.sum();
let probabilities = exp_output.mapv(|x| x / sum_exp);
Ok(probabilities)
}
pub fn train(
&mut self,
features: &[Array1<A>],
optimizer_labels: &[usize],
learning_rate: A,
epochs: usize,
) -> Result<()> {
for _ in 0..epochs {
for (feature, &label) in features.iter().zip(optimizer_labels.iter()) {
let probabilities = self.forward(feature)?;
let target_prob = probabilities[label];
let _loss = -A::ln(target_prob);
let mut output_grad = probabilities;
output_grad[label] = output_grad[label] - A::one();
let hidden = self.input_weights.dot(feature) + self.input_bias.clone();
let hidden_activated = hidden.mapv(|x| if x > A::zero() { x } else { A::zero() });
for i in 0..self.output_weights.nrows() {
for j in 0..self.output_weights.ncols() {
self.output_weights[[i, j]] = self.output_weights[[i, j]]
- learning_rate * output_grad[i] * hidden_activated[j];
}
}
for i in 0..self.output_bias.len() {
self.output_bias[i] = self.output_bias[i] - learning_rate * output_grad[i];
}
}
}
Ok(())
}
}
impl<A: Float + ScalarOperand + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
    AdaptiveOptimizerSelector<A>
{
    /// Hidden-layer width used when the selection network is first created.
    const NETWORK_HIDDEN_SIZE: usize = 32;
    /// Number of most recent validation scores kept in `current_performance`.
    const RECENT_PERFORMANCE_WINDOW: usize = 100;

    /// Creates a selector that uses `strategy`.
    ///
    /// Every optimizer in the default candidate pool starts as an untried
    /// bandit arm. The bandit and ensemble strategies and the selection
    /// network only choose from this pool.
    pub fn new(strategy: SelectionStrategy) -> Self {
        let available_optimizers = vec![
            OptimizerType::SGD,
            OptimizerType::SGDMomentum,
            OptimizerType::Adam,
            OptimizerType::AdamW,
            OptimizerType::RMSprop,
            OptimizerType::AdaGrad,
            OptimizerType::RAdam,
            OptimizerType::LAMB,
        ];
        let arm_counts = available_optimizers.iter().map(|&opt| (opt, 0)).collect();
        let arm_rewards = available_optimizers.iter().map(|&opt| (opt, 0.0)).collect();
        Self {
            strategy,
            performance_history: HashMap::new(),
            problem_optimizer_map: Vec::new(),
            current_problem: None,
            arm_counts,
            arm_rewards,
            selection_network: None,
            available_optimizers,
            current_performance: VecDeque::new(),
            last_confidence: 0.0,
        }
    }

    /// Sets the problem that subsequent selections and performance updates
    /// refer to.
    pub fn set_problem(&mut self, problem: ProblemCharacteristics) {
        self.current_problem = Some(problem);
    }

    /// Chooses an optimizer for the current problem using the configured strategy.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` if no problem has been set. It also
    /// returns any error from the selected strategy, such as an ensemble
    /// configured with zero candidates.
    pub fn select_optimizer(&mut self) -> Result<OptimizerType> {
        let problem = self.current_problem.clone().ok_or_else(|| {
            OptimError::InvalidConfig("No problem characteristics set".to_string())
        })?;
        // Cloning the small strategy value gives owned fields, so the strategy
        // methods can borrow `self` mutably without conflicts.
        match self.strategy.clone() {
            SelectionStrategy::RuleBased => self.rule_based_selection(&problem),
            SelectionStrategy::LearningBased => self.learning_based_selection(&problem),
            SelectionStrategy::Ensemble {
                num_candidates,
                evaluation_steps,
            } => self.ensemble_selection(&problem, num_candidates, evaluation_steps),
            SelectionStrategy::Bandit {
                epsilon,
                confidence,
            } => self.bandit_selection(&problem, epsilon, confidence),
            // The neighbour count is `k_nearest`; `feature_dim` was
            // previously forwarded here by mistake.
            SelectionStrategy::MetaLearning { k_nearest, .. } => {
                self.meta_learning_selection(&problem, k_nearest)
            }
        }
    }

    /// Fixed heuristics. The rules are checked in order and the first match wins.
    /// This can return `LBFGS`, which is not in the candidate pool.
    fn rule_based_selection(&self, problem: &ProblemCharacteristics) -> Result<OptimizerType> {
        if problem.dataset_size > 100000 {
            return Ok(match problem.problem_type {
                ProblemType::ComputerVision | ProblemType::NaturalLanguage => {
                    OptimizerType::AdamW
                }
                _ => OptimizerType::Adam,
            });
        }
        if problem.dataset_size < 1000 {
            return Ok(OptimizerType::LBFGS);
        }
        if problem.gradient_sparsity > 0.5 {
            return Ok(OptimizerType::AdaGrad);
        }
        if problem.batch_size > 256 {
            return Ok(OptimizerType::LAMB);
        }
        if problem.memory_budget < 1_000_000 {
            return Ok(OptimizerType::SGD);
        }
        if problem.gradient_noise > 0.3 {
            return Ok(OptimizerType::RMSprop);
        }
        Ok(OptimizerType::Adam)
    }

    /// Picks the optimizer of the recorded problem that maximises
    /// similarity × validation performance, and stores that score as the
    /// selection confidence. Falls back to the rule-based heuristics when
    /// there is no history.
    fn learning_based_selection(
        &mut self,
        problem: &ProblemCharacteristics,
    ) -> Result<OptimizerType> {
        if self.problem_optimizer_map.is_empty() {
            return self.rule_based_selection(problem);
        }
        let mut best_similarity = -1.0;
        let mut best_optimizer = OptimizerType::Adam;
        for (hist_problem, optimizer, metrics) in &self.problem_optimizer_map {
            let similarity = self.compute_problem_similarity(problem, hist_problem);
            let weighted_similarity = similarity * metrics.validation_performance;
            if weighted_similarity > best_similarity {
                best_similarity = weighted_similarity;
                best_optimizer = *optimizer;
            }
        }
        self.last_confidence = best_similarity;
        Ok(best_optimizer)
    }

    /// Placeholder ensemble: the leading `num_candidates` pool entries are not
    /// evaluated yet, so the first one is returned.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` when `num_candidates` is zero or
    /// the pool is empty. Previously this case panicked on an empty index.
    fn ensemble_selection(
        &self,
        _problem: &ProblemCharacteristics,
        num_candidates: usize,
        _evaluation_steps: usize,
    ) -> Result<OptimizerType> {
        if num_candidates == 0 {
            return Err(OptimError::InvalidConfig(
                "Ensemble selection requires at least one candidate".to_string(),
            ));
        }
        self.available_optimizers.first().copied().ok_or_else(|| {
            OptimError::InvalidConfig("No optimizers available for selection".to_string())
        })
    }

    /// Epsilon-greedy exploration combined with UCB1 over the candidate pool.
    ///
    /// With probability `epsilon` a uniformly random candidate is returned.
    /// Otherwise the candidate with the highest
    /// `mean_reward + confidence * sqrt(ln(total_pulls) / pulls)` wins, where
    /// untried candidates score infinity. `total_pulls` only counts candidates
    /// in the pool.
    fn bandit_selection(
        &self,
        _problem: &ProblemCharacteristics,
        epsilon: f64,
        confidence: f64,
    ) -> Result<OptimizerType> {
        if self.available_optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "No optimizers available for selection".to_string(),
            ));
        }
        let mut rng = thread_rng();
        if rng.random::<f64>() < epsilon {
            let idx = rng.gen_range(0..self.available_optimizers.len());
            return Ok(self.available_optimizers[idx]);
        }
        let pulls_of = |opt: &OptimizerType| self.arm_counts.get(opt).copied().unwrap_or(0);
        let total_pulls: usize = self.available_optimizers.iter().map(|opt| pulls_of(opt)).sum();
        let mut best_ucb = f64::NEG_INFINITY;
        let mut best_optimizer = OptimizerType::Adam;
        for &optimizer in &self.available_optimizers {
            let pulls = pulls_of(&optimizer);
            let ucb = if pulls == 0 {
                // Untried arms are always explored first.
                f64::INFINITY
            } else {
                let n = pulls as f64;
                let mean_reward = self.arm_rewards.get(&optimizer).copied().unwrap_or(0.0) / n;
                mean_reward + confidence * ((total_pulls as f64).ln() / n).sqrt()
            };
            if ucb > best_ucb {
                best_ucb = ucb;
                best_optimizer = optimizer;
            }
        }
        Ok(best_optimizer)
    }

    /// Uses the trained selection network if there is one. Otherwise it takes
    /// a similarity × performance weighted vote among the `k_nearest` most
    /// similar recorded problems. Falls back to the rule-based heuristics when
    /// `k_nearest` is zero or there are fewer records than `k_nearest`.
    fn meta_learning_selection(
        &self,
        problem: &ProblemCharacteristics,
        k_nearest: usize,
    ) -> Result<OptimizerType> {
        if let Some(network) = &self.selection_network {
            let features = self.extract_problem_features(problem);
            let probabilities = network.forward(&features)?;
            let mut best_prob = A::neg_infinity();
            let mut best_idx = 0;
            for (i, &prob) in probabilities.iter().enumerate() {
                if prob > best_prob {
                    best_prob = prob;
                    best_idx = i;
                }
            }
            if let Some(&optimizer) = self.available_optimizers.get(best_idx) {
                return Ok(optimizer);
            }
        }
        if k_nearest > 0 && self.problem_optimizer_map.len() >= k_nearest {
            let mut neighbours: Vec<(f64, OptimizerType, f64)> = self
                .problem_optimizer_map
                .iter()
                .map(|(hist_problem, optimizer, metrics)| {
                    (
                        self.compute_problem_similarity(problem, hist_problem),
                        *optimizer,
                        metrics.validation_performance,
                    )
                })
                .collect();
            // Most similar first. `total_cmp` orders NaN deterministically
            // instead of panicking the way `partial_cmp().expect` did.
            neighbours.sort_by(|a, b| b.0.total_cmp(&a.0));
            let mut votes: HashMap<OptimizerType, f64> = HashMap::new();
            for (similarity, optimizer, performance) in neighbours.into_iter().take(k_nearest) {
                *votes.entry(optimizer).or_insert(0.0) += similarity * performance;
            }
            let best_optimizer = votes
                .into_iter()
                .max_by(|a, b| a.1.total_cmp(&b.1))
                .map(|(optimizer, _)| optimizer)
                .unwrap_or(OptimizerType::Adam);
            return Ok(best_optimizer);
        }
        self.rule_based_selection(problem)
    }

    /// Records the outcome of a run with `optimizer`.
    ///
    /// This appends to the per-optimizer history and updates the bandit arm.
    /// When a problem is set, it also stores a (problem, optimizer, metrics)
    /// record. Finally it pushes the score into the bounded
    /// recent-performance window.
    pub fn update_performance(
        &mut self,
        optimizer: OptimizerType,
        metrics: PerformanceMetrics,
    ) -> Result<()> {
        let score = metrics.validation_performance;
        *self.arm_counts.entry(optimizer).or_insert(0) += 1;
        *self.arm_rewards.entry(optimizer).or_insert(0.0) += score;
        if let Some(problem) = &self.current_problem {
            self.problem_optimizer_map
                .push((problem.clone(), optimizer, metrics.clone()));
        }
        self.performance_history
            .entry(optimizer)
            .or_default()
            .push(metrics);
        self.current_performance.push_back(score);
        if self.current_performance.len() > Self::RECENT_PERFORMANCE_WINDOW {
            self.current_performance.pop_front();
        }
        Ok(())
    }

    /// Trains (creating it on first use) the selection network on the
    /// recorded (problem, optimizer) pairs.
    ///
    /// Records whose optimizer is outside the candidate pool (for example
    /// `LBFGS`) have no output neuron and are skipped. Skipping them together
    /// keeps each feature vector paired with its own label. Does nothing when
    /// no usable record exists.
    pub fn train_selection_network(&mut self, learning_rate: A, epochs: usize) -> Result<()> {
        let (features, labels): (Vec<Array1<A>>, Vec<usize>) = self
            .problem_optimizer_map
            .iter()
            .filter_map(|(problem, optimizer, _metrics)| {
                self.available_optimizers
                    .iter()
                    .position(|opt| opt == optimizer)
                    .map(|label| (self.extract_problem_features(problem), label))
            })
            .unzip();
        let feature_dim = match features.first() {
            Some(first) => first.len(),
            None => return Ok(()),
        };
        let num_optimizers = self.available_optimizers.len();
        let network = self.selection_network.get_or_insert_with(|| {
            SelectionNetwork::new(feature_dim, Self::NETWORK_HIDDEN_SIZE, num_optimizers)
        });
        network.train(&features, &labels, learning_rate, epochs)
    }

    /// Weighted similarity in roughly [0, 1] between two problems.
    ///
    /// The weights are: log dataset size 0.2, problem type 0.3, batch size
    /// 0.1, and gradient sparsity and noise 0.2 each. Identical problems score 1.0.
    fn compute_problem_similarity(
        &self,
        problem1: &ProblemCharacteristics,
        problem2: &ProblemCharacteristics,
    ) -> f64 {
        const SIZE_WEIGHT: f64 = 0.2;
        const TYPE_WEIGHT: f64 = 0.3;
        const BATCH_WEIGHT: f64 = 0.1;
        // Applied separately to the sparsity term and the noise term.
        const GRADIENT_WEIGHT: f64 = 0.2;

        let mut similarity = 0.0;
        let mut weight_sum = 0.0;

        // Sizes are compared on a log scale. `ln_1p` keeps a size of 0 finite,
        // whereas `ln` gave -inf and then NaN.
        let log_gap = ((problem1.dataset_size as f64).ln_1p()
            - (problem2.dataset_size as f64).ln_1p())
        .abs();
        similarity += (1.0 - log_gap / 10.0).max(0.0) * SIZE_WEIGHT;
        weight_sum += SIZE_WEIGHT;

        if problem1.problem_type == problem2.problem_type {
            similarity += TYPE_WEIGHT;
        }
        weight_sum += TYPE_WEIGHT;

        let batch_gap = (problem1.batch_size as f64 - problem2.batch_size as f64).abs();
        similarity += (1.0 - (batch_gap / 256.0).min(1.0)) * BATCH_WEIGHT;
        weight_sum += BATCH_WEIGHT;

        let sparsity_sim = 1.0 - (problem1.gradient_sparsity - problem2.gradient_sparsity).abs();
        let noise_sim = 1.0 - (problem1.gradient_noise - problem2.gradient_noise).abs();
        similarity += (sparsity_sim + noise_sim) * GRADIENT_WEIGHT;
        weight_sum += 2.0 * GRADIENT_WEIGHT;

        similarity / weight_sum
    }

    /// Builds the 11-element feature vector consumed by the selection network.
    ///
    /// Count-like fields and the time budget are log-scaled with `ln_1p`, so
    /// zero values stay finite instead of producing -inf and NaN network
    /// outputs. The element order must stay fixed because a trained network
    /// depends on it.
    fn extract_problem_features(&self, problem: &ProblemCharacteristics) -> Array1<A> {
        let cast = |value: f64| A::from(value).expect("unwrap failed");
        let log_count = |count: usize| (count as f64).ln_1p();
        Array1::from_vec(vec![
            cast(log_count(problem.dataset_size)),
            cast(log_count(problem.input_dim)),
            cast(log_count(problem.output_dim)),
            cast(problem.problem_type as u8 as f64),
            cast(problem.gradient_sparsity),
            cast(problem.gradient_noise),
            cast(log_count(problem.memory_budget)),
            cast(problem.time_budget.max(0.0).ln_1p()),
            cast(log_count(problem.batch_size)),
            cast(problem.lr_sensitivity),
            cast(problem.regularization_strength),
        ])
    }

    /// Summary statistics of the recorded validation performance for
    /// `optimizer`, or `None` if nothing has been recorded.
    ///
    /// `std_performance` is the population standard deviation (it divides by
    /// n). `success_rate` is the fraction of runs scoring above 0.7.
    pub fn get_optimizer_statistics(
        &self,
        optimizer: OptimizerType,
    ) -> Option<OptimizerStatistics> {
        let history = self.performance_history.get(&optimizer)?;
        if history.is_empty() {
            return None;
        }
        let performances: Vec<f64> = history.iter().map(|m| m.validation_performance).collect();
        let n = performances.len() as f64;
        let mean = performances.iter().sum::<f64>() / n;
        let variance = performances.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / n;
        Some(OptimizerStatistics {
            optimizer,
            num_trials: history.len(),
            mean_performance: mean,
            std_performance: variance.sqrt(),
            best_performance: performances
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max),
            worst_performance: performances.iter().copied().fold(f64::INFINITY, f64::min),
            success_rate: performances.iter().filter(|&&p| p > 0.7).count() as f64 / n,
        })
    }

    /// Statistics for every pool optimizer that has at least one recorded run.
    pub fn get_all_statistics(&self) -> Vec<OptimizerStatistics> {
        self.available_optimizers
            .iter()
            .filter_map(|&opt| self.get_optimizer_statistics(opt))
            .collect()
    }

    /// Confidence score from the most recent learning-based selection.
    pub fn get_selection_confidence(&self) -> f64 {
        self.last_confidence
    }

    /// Clears all learned state: history, records, the current problem,
    /// bandit arms, the recent-performance window, the confidence, and the
    /// trained selection network. The network is cleared because it was
    /// trained on the records being discarded.
    pub fn reset(&mut self) {
        self.performance_history.clear();
        self.problem_optimizer_map.clear();
        self.current_problem = None;
        self.arm_counts.values_mut().for_each(|count| *count = 0);
        self.arm_rewards.values_mut().for_each(|reward| *reward = 0.0);
        self.current_performance.clear();
        self.last_confidence = 0.0;
        self.selection_network = None;
    }
}
/// Aggregate validation-performance statistics for one optimizer, as
/// produced by `AdaptiveOptimizerSelector::get_optimizer_statistics`.
#[derive(Debug, Clone)]
pub struct OptimizerStatistics {
pub optimizer: OptimizerType,
/// Number of recorded runs.
pub num_trials: usize,
/// Mean validation performance.
pub mean_performance: f64,
/// Population standard deviation (divides by n) of validation performance.
pub std_performance: f64,
/// Highest recorded validation performance.
pub best_performance: f64,
/// Lowest recorded validation performance.
pub worst_performance: f64,
/// Fraction of runs with validation performance above 0.7.
pub success_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Small classification problem that several tests use as a baseline;
    /// individual tests override fields via struct-update syntax.
    fn small_problem() -> ProblemCharacteristics {
        ProblemCharacteristics {
            dataset_size: 1000,
            input_dim: 10,
            output_dim: 2,
            problem_type: ProblemType::Classification,
            gradient_sparsity: 0.1,
            gradient_noise: 0.05,
            memory_budget: 1_000_000,
            time_budget: 600.0,
            batch_size: 32,
            lr_sensitivity: 0.5,
            regularization_strength: 0.01,
            architecture_type: None,
        }
    }

    #[test]
    fn test_problem_characteristics() {
        let mnist_like = ProblemCharacteristics {
            dataset_size: 10000,
            input_dim: 784,
            output_dim: 10,
            time_budget: 3600.0,
            batch_size: 64,
            architecture_type: Some("CNN".to_string()),
            ..small_problem()
        };
        assert_eq!(mnist_like.dataset_size, 10000);
        assert_eq!(mnist_like.problem_type, ProblemType::Classification);
    }

    #[test]
    fn test_rule_based_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);
        selector.set_problem(ProblemCharacteristics {
            dataset_size: 100001,
            input_dim: 224,
            output_dim: 1000,
            problem_type: ProblemType::ComputerVision,
            memory_budget: 10_000_000,
            time_budget: 7200.0,
            architecture_type: Some("ResNet".to_string()),
            ..small_problem()
        });
        let chosen = selector.select_optimizer().expect("unwrap failed");
        assert_eq!(chosen, OptimizerType::AdamW);
    }

    #[test]
    fn test_selection_network() {
        let net = SelectionNetwork::<f64>::new(5, 10, 3);
        let input = Array1::from_vec(vec![1.0, 0.5, 2.0, 0.8, 1.5]);
        let probs = net.forward(&input).expect("unwrap failed");
        assert_eq!(probs.len(), 3);
        let total: f64 = probs.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
        assert!(probs.iter().all(|&p| p >= 0.0));
    }

    #[test]
    fn test_bandit_selection() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::Bandit {
            epsilon: 0.1,
            confidence: 2.0,
        });
        selector.set_problem(ProblemCharacteristics {
            gradient_sparsity: 0.0,
            gradient_noise: 0.1,
            ..small_problem()
        });
        let chosen = selector.select_optimizer().expect("unwrap failed");
        assert!(selector.available_optimizers.contains(&chosen));
    }

    #[test]
    fn test_performance_update() {
        let mut selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);
        let run = PerformanceMetrics {
            final_loss: 0.1,
            convergence_steps: 100,
            training_time: 60.0,
            memory_usage: 500_000,
            validation_performance: 0.95,
            stability: 0.02,
            generalization_gap: 0.05,
        };
        selector
            .update_performance(OptimizerType::Adam, run)
            .expect("unwrap failed");
        let summary = selector
            .get_optimizer_statistics(OptimizerType::Adam)
            .expect("unwrap failed");
        assert_eq!(summary.num_trials, 1);
        assert_relative_eq!(summary.mean_performance, 0.95, epsilon = 1e-6);
    }

    #[test]
    fn test_problem_similarity() {
        let selector = AdaptiveOptimizerSelector::<f64>::new(SelectionStrategy::RuleBased);
        let base = small_problem();

        let same = selector.compute_problem_similarity(&base, &base.clone());
        assert_relative_eq!(same, 1.0, epsilon = 1e-6);

        let regression = ProblemCharacteristics {
            problem_type: ProblemType::Regression,
            ..base.clone()
        };
        let different = selector.compute_problem_similarity(&base, &regression);
        assert!(different < 1.0);
    }
}