use crate::adaptive_selection::{OptimizerType, ProblemCharacteristics};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
/// Application domain an optimizer configuration is tailored for.
///
/// Each variant carries boolean feature flags that
/// `DomainSpecificSelector::select_optimal_config` forwards to the matching
/// domain-specific builder.
///
/// Every payload is a plain `bool`, so the type is cheap to copy and can be
/// compared or hashed (for example to key caches by strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DomainStrategy {
    /// Image / vision workloads.
    ComputerVision {
        /// Scale the learning rate by the input size.
        resolution_adaptive: bool,
        /// Switch to AdamW and emit batch-norm momentum/epsilon parameters.
        batch_norm_tuning: bool,
        /// Raise regularization and emit mixup/cutmix parameters.
        augmentation_aware: bool,
    },
    /// Text / sequence workloads.
    NaturalLanguage {
        /// Lower the learning rate and clip harder for inputs longer than 512.
        sequence_adaptive: bool,
        /// Switch to AdamW and emit attention/layer-decay parameters.
        attention_optimized: bool,
        /// Emit embedding tying/dropout for outputs larger than 30_000.
        vocab_aware: bool,
    },
    /// Recommender workloads.
    RecommendationSystems {
        /// Adam with stronger regularization and negative sampling.
        collaborative_filtering: bool,
        /// Fixed learning rate plus embedding/factorization sizes.
        matrix_factorization: bool,
        /// Emit content-weight and popularity-bias parameters.
        cold_start_aware: bool,
    },
    /// Sequential forecasting workloads.
    TimeSeries {
        /// RMSprop with the input length recorded as `sequence_length`.
        temporal_aware: bool,
        /// Emit seasonal-period and trend-strength parameters.
        seasonality_adaptive: bool,
        /// Emit prediction-horizon and multi-step loss-weight parameters.
        multi_step: bool,
    },
    /// Reinforcement-learning workloads.
    ReinforcementLearning {
        /// Adam at 3e-4 with an entropy coefficient.
        policy_gradient: bool,
        /// Emit value-loss and Huber-delta parameters.
        value_function: bool,
        /// Emit epsilon-greedy schedule parameters.
        exploration_aware: bool,
    },
    /// Numerical / scientific workloads.
    ScientificComputing {
        /// LBFGS with a line-search tolerance.
        stability_focused: bool,
        /// Emit convergence-tolerance and iteration-limit parameters.
        precision_critical: bool,
        /// Switch to Adam and emit a sparsity threshold.
        sparse_optimized: bool,
    },
}
/// Per-domain defaults, built by `DomainSpecificSelector::default_config_for_strategy`
/// when a selector is created.
///
/// NOTE(review): of these fields, only `base_learning_rate` is read by the
/// selection methods in this file; the rest are informational.
#[derive(Debug, Clone)]
pub struct DomainConfig<A: Float> {
/// Domain baseline learning rate; read by the computer-vision and
/// natural-language builders.
pub base_learning_rate: A,
/// Batch sizes listed as suitable for the domain.
pub recommended_batch_sizes: Vec<usize>,
/// Candidate gradient-clipping values (empty for scientific computing).
pub gradient_clip_values: Vec<A>,
/// `(low, high)` range for regularization strength.
pub regularization_range: (A, A),
/// Optimizer ranking for the domain (presumably most preferred first — confirm).
pub optimizer_ranking: Vec<OptimizerType>,
/// Extra named parameters; always empty in the built-in defaults.
pub domain_params: HashMap<String, A>,
}
/// Chooses optimizer settings for one `DomainStrategy` and keeps a log of
/// observed per-domain performance and cross-domain transfer results.
#[derive(Debug)]
pub struct DomainSpecificSelector<A: Float> {
/// Domain and feature flags that drive `select_optimal_config`.
strategy: DomainStrategy,
/// Domain defaults derived from `strategy` at construction.
config: DomainConfig<A>,
/// Performance history keyed by a caller-chosen domain name.
domain_performance: HashMap<String, Vec<DomainPerformanceMetrics<A>>>,
/// Recorded cross-domain transfer results, in insertion order.
transfer_knowledge: Vec<CrossDomainKnowledge<A>>,
/// Context set by `setcontext`; `select_optimal_config` errors while this is `None`.
currentcontext: Option<OptimizationContext<A>>,
}
/// Outcome metrics for one run, recorded via
/// `DomainSpecificSelector::update_domain_performance`.
///
/// NOTE(review): only `validation_accuracy` is consumed by the code in this
/// file (averaged into the performance-baseline recommendation); scales of the
/// score fields are not defined here.
#[derive(Debug, Clone)]
pub struct DomainPerformanceMetrics<A: Float> {
/// Validation accuracy of the run.
pub validation_accuracy: A,
/// Domain-specific quality score.
pub domain_specific_score: A,
/// Training stability score.
pub stability_score: A,
/// Number of epochs until convergence.
pub convergence_epochs: usize,
/// Resource-efficiency score.
pub resource_efficiency: A,
/// Transferability score.
pub transfer_score: A,
}
/// Result of carrying optimizer knowledge from `source_domain` to
/// `target_domain`, recorded via `DomainSpecificSelector::record_transfer_knowledge`.
#[derive(Debug, Clone)]
pub struct CrossDomainKnowledge<A: Float> {
/// Domain the knowledge came from.
pub source_domain: String,
/// Domain the knowledge applies to; matched against the queried domain name.
pub target_domain: String,
/// Named parameter values that carried over. NOTE(review): not read in this file.
pub transferable_params: HashMap<String, A>,
/// Effectiveness score; reported as the recommendation's confidence.
pub transfer_score: A,
/// Optimizer that worked; suggested in the recommendation's action text.
pub successful_strategy: OptimizerType,
}
/// Description of the problem being optimized, installed with
/// `DomainSpecificSelector::setcontext`.
#[derive(Debug, Clone)]
pub struct OptimizationContext<A: Float> {
/// Problem description; the builders read `input_dim`, `output_dim`, and `dataset_size`.
pub problem_chars: ProblemCharacteristics,
/// Resource budget; `max_memory` drives the batch-size choice.
pub resource_constraints: ResourceConstraints<A>,
/// Training setup; the computer-vision builder uses `max_epochs` as the cosine `t_max`.
pub training_config: TrainingConfiguration<A>,
/// Free-form metadata. NOTE(review): not read by the selector code in this file.
pub domain_metadata: HashMap<String, String>,
}
/// Resource budget available for a run.
///
/// NOTE(review): only `max_memory` is read by the selector code in this file.
#[derive(Debug, Clone)]
pub struct ResourceConstraints<A: Float> {
/// Memory budget, compared against thresholds such as 8_000_000_000 and
/// 16_000_000_000 when picking batch sizes (presumably bytes — confirm).
pub max_memory: usize,
/// Time budget (units not defined here).
pub max_time: A,
/// Whether a GPU is available.
pub gpu_available: bool,
/// Whether distributed training is possible.
pub distributed_capable: bool,
/// Whether energy-efficient operation is requested.
pub energy_efficient: bool,
}
/// Training-loop settings supplied with an `OptimizationContext`.
///
/// NOTE(review): only `max_epochs` is read by the selector code in this file.
#[derive(Debug, Clone)]
pub struct TrainingConfiguration<A: Float> {
/// Epoch budget; used as the cosine-annealing `t_max` for computer vision.
pub max_epochs: usize,
/// Early-stopping patience.
pub early_stopping_patience: usize,
/// How often validation runs.
pub validation_frequency: usize,
/// Schedule the caller is currently using.
pub lr_schedule_type: LearningRateScheduleType,
/// Regularization the caller is currently using.
pub regularization_approach: RegularizationApproach<A>,
}
/// Learning-rate schedule attached to an optimizer configuration.
///
/// All payloads are plain numbers, so the type is `Copy` and supports `==`
/// (for example to assert which schedule a builder chose).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LearningRateScheduleType {
    /// Keep the learning rate fixed.
    Constant,
    /// Exponential decay controlled by `decay_rate`.
    ExponentialDecay {
        decay_rate: f64,
    },
    /// Cosine annealing; the computer-vision builder sets `t_max` to the epoch budget.
    CosineAnnealing {
        t_max: usize,
    },
    /// Reduce-on-plateau; the time-series builder uses `patience = 10`, `factor = 0.5`.
    ReduceOnPlateau {
        patience: usize,
        factor: f64,
    },
    /// One-cycle schedule; the NLP builder sets `max_lr` to its chosen learning rate.
    OneCycle {
        max_lr: f64,
    },
}
/// Regularization scheme reported in a `TrainingConfiguration`.
#[derive(Debug, Clone)]
pub enum RegularizationApproach<A: Float> {
/// L2 penalty only.
L2Only {
weight: A,
},
/// L1 penalty only.
L1Only {
weight: A,
},
/// Separate L1 and L2 penalty weights.
ElasticNet {
l1_weight: A,
l2_weight: A,
},
/// Dropout only.
Dropout {
dropout_rate: A,
},
/// L2 penalty plus dropout, with any further techniques listed by name.
Combined {
l2_weight: A,
dropout_rate: A,
additional_techniques: Vec<String>,
},
}
impl<A: Float + ScalarOperand + Debug + std::iter::Sum + Send + Sync> DomainSpecificSelector<A> {
/// Creates a selector for `strategy`, seeded with that domain's default
/// `DomainConfig`, an empty history, and no optimization context.
pub fn new(strategy: DomainStrategy) -> Self {
    Self {
        // Borrow `strategy` for the defaults before it is moved into the struct.
        config: Self::default_config_for_strategy(&strategy),
        domain_performance: HashMap::new(),
        transfer_knowledge: Vec::new(),
        currentcontext: None,
        strategy,
    }
}
/// Installs the context consulted by `select_optimal_config`, discarding any
/// previously installed one.
pub fn setcontext(&mut self, context: OptimizationContext<A>) {
    let _ = self.currentcontext.replace(context);
}
/// Builds a `DomainOptimizationConfig` for the selector's strategy using the
/// context installed via `setcontext`.
///
/// Dispatches on the `DomainStrategy` variant and forwards its feature flags
/// to the matching domain-specific builder.
///
/// This only reads selector state, so it takes `&self`. Existing call sites
/// that hold a `&mut` selector keep compiling, and shared references can now
/// call it too.
///
/// # Errors
///
/// Returns `OptimError::InvalidConfig` if no context has been set.
pub fn select_optimal_config(&self) -> Result<DomainOptimizationConfig<A>> {
    let context = self
        .currentcontext
        .as_ref()
        .ok_or_else(|| OptimError::InvalidConfig("No optimization context set".to_string()))?;
    // The flag bindings are `bool`s, so they are copied out of the borrowed
    // strategy without moving it.
    match self.strategy {
        DomainStrategy::ComputerVision {
            resolution_adaptive,
            batch_norm_tuning,
            augmentation_aware,
        } => self.optimize_computer_vision(
            context,
            resolution_adaptive,
            batch_norm_tuning,
            augmentation_aware,
        ),
        DomainStrategy::NaturalLanguage {
            sequence_adaptive,
            attention_optimized,
            vocab_aware,
        } => self.optimize_natural_language(
            context,
            sequence_adaptive,
            attention_optimized,
            vocab_aware,
        ),
        DomainStrategy::RecommendationSystems {
            collaborative_filtering,
            matrix_factorization,
            cold_start_aware,
        } => self.optimize_recommendation_systems(
            context,
            collaborative_filtering,
            matrix_factorization,
            cold_start_aware,
        ),
        DomainStrategy::TimeSeries {
            temporal_aware,
            seasonality_adaptive,
            multi_step,
        } => self.optimize_time_series(context, temporal_aware, seasonality_adaptive, multi_step),
        DomainStrategy::ReinforcementLearning {
            policy_gradient,
            value_function,
            exploration_aware,
        } => self.optimize_reinforcement_learning(
            context,
            policy_gradient,
            value_function,
            exploration_aware,
        ),
        DomainStrategy::ScientificComputing {
            stability_focused,
            precision_critical,
            sparse_optimized,
        } => self.optimize_scientific_computing(
            context,
            stability_focused,
            precision_critical,
            sparse_optimized,
        ),
    }
}
fn optimize_computer_vision(
&self,
context: &OptimizationContext<A>,
resolution_adaptive: bool,
batch_norm_tuning: bool,
augmentation_aware: bool,
) -> Result<DomainOptimizationConfig<A>> {
let mut config = DomainOptimizationConfig::default();
if resolution_adaptive {
let resolution_factor = self.estimate_resolution_factor(&context.problem_chars);
config.learning_rate =
self.config.base_learning_rate * A::from(resolution_factor).expect("unwrap failed");
if context.problem_chars.input_dim > 512 * 512 {
config.learning_rate = config.learning_rate * A::from(0.5).expect("unwrap failed");
}
}
if batch_norm_tuning {
config.optimizer_type = OptimizerType::AdamW; config.specialized_params.insert(
"batch_norm_momentum".to_string(),
A::from(0.99).expect("unwrap failed"),
);
config.specialized_params.insert(
"batch_norm_eps".to_string(),
A::from(1e-5).expect("unwrap failed"),
);
}
if augmentation_aware {
config.regularization_strength =
config.regularization_strength * A::from(1.5).expect("unwrap failed");
config.specialized_params.insert(
"mixup_alpha".to_string(),
A::from(0.2).expect("unwrap failed"),
);
config.specialized_params.insert(
"cutmix_alpha".to_string(),
A::from(1.0).expect("unwrap failed"),
);
}
config.batch_size = self.select_cv_batch_size(&context.resource_constraints);
config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
config.lr_schedule = LearningRateScheduleType::CosineAnnealing {
t_max: context.training_config.max_epochs,
};
Ok(config)
}
fn optimize_natural_language(
&self,
context: &OptimizationContext<A>,
sequence_adaptive: bool,
attention_optimized: bool,
vocab_aware: bool,
) -> Result<DomainOptimizationConfig<A>> {
let mut config = DomainOptimizationConfig::default();
if sequence_adaptive {
let seq_length = context.problem_chars.input_dim;
if seq_length > 512 {
config.learning_rate =
self.config.base_learning_rate * A::from(0.7).expect("unwrap failed");
config.gradient_clip_norm = Some(A::from(0.5).expect("unwrap failed"));
} else {
config.learning_rate = self.config.base_learning_rate;
config.gradient_clip_norm = Some(A::from(1.0).expect("unwrap failed"));
}
}
if attention_optimized {
config.optimizer_type = OptimizerType::AdamW; config.specialized_params.insert(
"attention_dropout".to_string(),
A::from(0.1).expect("unwrap failed"),
);
config.specialized_params.insert(
"attention_head_dim".to_string(),
A::from(64.0).expect("unwrap failed"),
);
config.specialized_params.insert(
"layer_decay_rate".to_string(),
A::from(0.95).expect("unwrap failed"),
);
}
if vocab_aware {
let vocab_size = context.problem_chars.output_dim;
if vocab_size > 30000 {
config.specialized_params.insert(
"tie_embeddings".to_string(),
A::from(1.0).expect("unwrap failed"),
);
config.specialized_params.insert(
"embedding_dropout".to_string(),
A::from(0.1).expect("unwrap failed"),
);
}
}
config.batch_size = self.select_nlp_batch_size(&context.resource_constraints);
config.lr_schedule = LearningRateScheduleType::OneCycle {
max_lr: config.learning_rate.to_f64().expect("unwrap failed"),
};
config.specialized_params.insert(
"warmup_steps".to_string(),
A::from(1000.0).expect("unwrap failed"),
);
Ok(config)
}
/// Builds a recommendation-system configuration.
///
/// - `collaborative_filtering`: Adam, regularization 0.01, negative sampling rate 5.
/// - `matrix_factorization`: learning rate 0.01 plus embedding dim / factorization rank.
/// - `cold_start_aware`: content-weight and popularity-bias parameters.
///
/// The batch size follows the memory budget, and gradients are clipped at 5.0.
fn optimize_recommendation_systems(
    &self,
    context: &OptimizationContext<A>,
    collaborative_filtering: bool,
    matrix_factorization: bool,
    cold_start_aware: bool,
) -> Result<DomainOptimizationConfig<A>> {
    let num = |v: f64| A::from(v).expect("unwrap failed");
    let mut cfg = DomainOptimizationConfig::default();
    if collaborative_filtering {
        cfg.optimizer_type = OptimizerType::Adam;
        cfg.regularization_strength = num(0.01);
    }
    if matrix_factorization {
        cfg.learning_rate = num(0.01);
    }
    // (enabling flag, parameter name, value) for each optional extra knob.
    let extras: [(bool, &str, f64); 5] = [
        (collaborative_filtering, "negative_sampling_rate", 5.0),
        (matrix_factorization, "embedding_dim", 128.0),
        (matrix_factorization, "factorization_rank", 50.0),
        (cold_start_aware, "content_weight", 0.3),
        (cold_start_aware, "popularity_bias", 0.1),
    ];
    for (enabled, key, value) in extras {
        if enabled {
            cfg.specialized_params.insert(key.to_string(), num(value));
        }
    }
    cfg.batch_size = self.select_recsys_batch_size(&context.resource_constraints);
    cfg.gradient_clip_norm = Some(num(5.0));
    Ok(cfg)
}
/// Builds a time-series configuration.
///
/// - `temporal_aware`: RMSprop at 0.001, with `input_dim` recorded as `sequence_length`.
/// - `seasonality_adaptive`: seasonal-period (24) and trend-strength (0.1) parameters.
/// - `multi_step`: prediction-horizon (12) and multi-step loss-weight (0.8) parameters.
///
/// The builder always uses batch size 32, clips gradients at 1.0, and
/// reduces on plateau (patience 10, factor 0.5).
fn optimize_time_series(
    &self,
    context: &OptimizationContext<A>,
    temporal_aware: bool,
    seasonality_adaptive: bool,
    multi_step: bool,
) -> Result<DomainOptimizationConfig<A>> {
    let num = |v: f64| A::from(v).expect("unwrap failed");
    let mut cfg = DomainOptimizationConfig::default();
    if temporal_aware {
        cfg.optimizer_type = OptimizerType::RMSprop;
        cfg.learning_rate = num(0.001);
        let window = context.problem_chars.input_dim as f64;
        cfg.specialized_params
            .insert("sequence_length".to_string(), num(window));
    }
    // (enabling flag, parameter name, value) for each optional extra knob.
    let extras: [(bool, &str, f64); 4] = [
        (seasonality_adaptive, "seasonal_periods", 24.0),
        (seasonality_adaptive, "trend_strength", 0.1),
        (multi_step, "prediction_horizon", 12.0),
        (multi_step, "multi_step_loss_weight", 0.8),
    ];
    for (enabled, key, value) in extras {
        if enabled {
            cfg.specialized_params.insert(key.to_string(), num(value));
        }
    }
    cfg.batch_size = 32;
    cfg.gradient_clip_norm = Some(num(1.0));
    cfg.lr_schedule = LearningRateScheduleType::ReduceOnPlateau {
        patience: 10,
        factor: 0.5,
    };
    Ok(cfg)
}
/// Builds a reinforcement-learning configuration.
///
/// - `policy_gradient`: Adam at 3e-4 with an `entropy_coeff` of 0.01.
/// - `value_function`: `value_loss_coeff` (0.5) and `huber_loss_delta` (1.0).
/// - `exploration_aware`: epsilon start 1.0, end 0.1, decay 0.995.
///
/// The builder always uses batch size 64, clips gradients at 0.5, and keeps a
/// constant schedule. The context is not consulted yet. The parameter is named
/// `_context` to avoid an `unused_variables` warning while keeping the same
/// call shape as the other builders.
fn optimize_reinforcement_learning(
    &self,
    _context: &OptimizationContext<A>,
    policy_gradient: bool,
    value_function: bool,
    exploration_aware: bool,
) -> Result<DomainOptimizationConfig<A>> {
    let num = |v: f64| A::from(v).expect("unwrap failed");
    let mut config = DomainOptimizationConfig::default();
    if policy_gradient {
        config.optimizer_type = OptimizerType::Adam;
        config.learning_rate = num(3e-4);
        config
            .specialized_params
            .insert("entropy_coeff".to_string(), num(0.01));
    }
    if value_function {
        config
            .specialized_params
            .insert("value_loss_coeff".to_string(), num(0.5));
        config
            .specialized_params
            .insert("huber_loss_delta".to_string(), num(1.0));
    }
    if exploration_aware {
        config
            .specialized_params
            .insert("epsilon_start".to_string(), num(1.0));
        config
            .specialized_params
            .insert("epsilon_end".to_string(), num(0.1));
        config
            .specialized_params
            .insert("epsilon_decay".to_string(), num(0.995));
    }
    config.batch_size = 64;
    config.gradient_clip_norm = Some(num(0.5));
    config.lr_schedule = LearningRateScheduleType::Constant;
    Ok(config)
}
/// Builds a scientific-computing configuration.
///
/// - `stability_focused`: LBFGS at 0.1 with a line-search tolerance of 1e-6.
/// - `precision_critical`: convergence tolerance 1e-8 and up to 1000 iterations.
/// - `sparse_optimized`: switch to Adam and emit a sparsity threshold.
///
/// The batch size is the dataset size, capped at 1024 and never below 1.
/// Gradient clipping is disabled, and the schedule is constant.
fn optimize_scientific_computing(
    &self,
    context: &OptimizationContext<A>,
    stability_focused: bool,
    precision_critical: bool,
    sparse_optimized: bool,
) -> Result<DomainOptimizationConfig<A>> {
    let num = |v: f64| A::from(v).expect("unwrap failed");
    let mut config = DomainOptimizationConfig::default();
    if stability_focused {
        config.optimizer_type = OptimizerType::LBFGS;
        config.learning_rate = num(0.1);
        config
            .specialized_params
            .insert("line_search_tolerance".to_string(), num(1e-6));
    }
    if precision_critical {
        config
            .specialized_params
            .insert("convergence_tolerance".to_string(), num(1e-8));
        config
            .specialized_params
            .insert("max_iterations".to_string(), num(1000.0));
    }
    if sparse_optimized {
        // NOTE(review): when combined with `stability_focused`, this replaces
        // LBFGS with Adam but keeps the 0.1 learning rate chosen above — confirm intended.
        config.optimizer_type = OptimizerType::Adam;
        config
            .specialized_params
            .insert("sparsity_threshold".to_string(), num(1e-6));
    }
    // Use the whole dataset per batch up to 1024 samples. The lower bound of 1
    // keeps an empty dataset from producing an invalid zero batch size.
    config.batch_size = context.problem_chars.dataset_size.clamp(1, 1024);
    config.gradient_clip_norm = None;
    config.lr_schedule = LearningRateScheduleType::Constant;
    Ok(config)
}
/// Appends `metrics` to the performance history for `domain`, creating an
/// empty history the first time the domain is seen.
pub fn update_domain_performance(
    &mut self,
    domain: String,
    metrics: DomainPerformanceMetrics<A>,
) {
    let history = self.domain_performance.entry(domain).or_default();
    history.push(metrics);
}
/// Stores a cross-domain transfer result. `get_domain_recommendations` later
/// surfaces it when queried for the result's target domain.
pub fn record_transfer_knowledge(&mut self, knowledge: CrossDomainKnowledge<A>) {
    let log = &mut self.transfer_knowledge;
    log.push(knowledge);
}
/// Summarises what is known about `domain`.
///
/// The result starts with at most one `PerformanceBaseline` entry: the mean
/// validation accuracy over the recorded history, emitted only if the history
/// is non-empty. It is followed by one `TransferLearning` entry per recorded
/// transfer whose target is `domain`, in recording order.
pub fn get_domain_recommendations(&self, domain: &str) -> Vec<DomainRecommendation<A>> {
    let baseline = self
        .domain_performance
        .get(domain)
        .filter(|history| !history.is_empty())
        .map(|history| {
            let total: A = history.iter().map(|m| m.validation_accuracy).sum();
            let mean = total / A::from(history.len()).expect("unwrap failed");
            DomainRecommendation {
                recommendation_type: RecommendationType::PerformanceBaseline,
                description: format!(
                    "Historical average performance: {:.4}",
                    mean.to_f64().expect("unwrap failed")
                ),
                confidence: A::from(0.8).expect("unwrap failed"),
                action: "Consider this as baseline for improvements".to_string(),
            }
        });
    let transfers = self
        .transfer_knowledge
        .iter()
        .filter(|k| k.target_domain == domain)
        .map(|k| DomainRecommendation {
            recommendation_type: RecommendationType::TransferLearning,
            description: format!(
                "Transfer from {} domain with {:.2} effectiveness",
                k.source_domain,
                k.transfer_score.to_f64().expect("unwrap failed")
            ),
            confidence: k.transfer_score,
            action: format!("Use {:?} optimizer", k.successful_strategy),
        });
    baseline.into_iter().chain(transfers).collect()
}
/// Maps the flattened input size (`input_dim`) to a learning-rate multiplier.
/// Larger inputs get smaller factors: 0.5, 0.7, or 0.9, falling back to 1.0.
fn estimate_resolution_factor(&self, problem_chars: &ProblemCharacteristics) -> f64 {
    // (exclusive lower bound, factor), checked from the largest bound down.
    const TIERS: [(f64, f64); 3] = [(1_000_000.0, 0.5), (250_000.0, 0.7), (50_000.0, 0.9)];
    let size = problem_chars.input_dim as f64;
    TIERS
        .iter()
        .find(|&&(threshold, _)| size > threshold)
        .map_or(1.0, |&(_, factor)| factor)
}
/// Picks a computer-vision batch size from `max_memory`: 128 above
/// 16_000_000_000, 64 above 8_000_000_000, otherwise 32.
fn select_cv_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
    // Compare in u64. These thresholds do not fit in a 32-bit `usize` (e.g.
    // wasm32), where the literals would trip the deny-by-default
    // `overflowing_literals` lint. `usize -> u64` is lossless on all Rust targets.
    let memory = constraints.max_memory as u64;
    if memory > 16_000_000_000 {
        128
    } else if memory > 8_000_000_000 {
        64
    } else {
        32
    }
}
/// Picks an NLP batch size from `max_memory`: 64 above 32_000_000_000,
/// 32 above 16_000_000_000, otherwise 16.
fn select_nlp_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
    // Compare in u64 so the thresholds compile on 32-bit targets, where they
    // would overflow `usize`. The widening cast is lossless.
    let memory = constraints.max_memory as u64;
    if memory > 32_000_000_000 {
        64
    } else if memory > 16_000_000_000 {
        32
    } else {
        16
    }
}
/// Picks a recommendation-system batch size from `max_memory`: 512 above
/// 8_000_000_000, otherwise 256.
fn select_recsys_batch_size(&self, constraints: &ResourceConstraints<A>) -> usize {
    // Compare in u64 so the threshold compiles on 32-bit targets, where it
    // would overflow `usize`. The widening cast is lossless.
    let memory = constraints.max_memory as u64;
    if memory > 8_000_000_000 {
        512
    } else {
        256
    }
}
fn default_config_for_strategy(strategy: &DomainStrategy) -> DomainConfig<A> {
match strategy {
DomainStrategy::ComputerVision { .. } => DomainConfig {
base_learning_rate: A::from(0.001).expect("unwrap failed"),
recommended_batch_sizes: vec![32, 64, 128],
gradient_clip_values: vec![
A::from(1.0).expect("unwrap failed"),
A::from(2.0).expect("unwrap failed"),
],
regularization_range: (
A::from(1e-5).expect("unwrap failed"),
A::from(1e-2).expect("unwrap failed"),
),
optimizer_ranking: vec![
OptimizerType::AdamW,
OptimizerType::SGDMomentum,
OptimizerType::Adam,
],
domain_params: HashMap::new(),
},
DomainStrategy::NaturalLanguage { .. } => DomainConfig {
base_learning_rate: A::from(2e-5).expect("unwrap failed"),
recommended_batch_sizes: vec![16, 32, 64],
gradient_clip_values: vec![
A::from(0.5).expect("unwrap failed"),
A::from(1.0).expect("unwrap failed"),
],
regularization_range: (
A::from(1e-4).expect("unwrap failed"),
A::from(1e-1).expect("unwrap failed"),
),
optimizer_ranking: vec![OptimizerType::AdamW, OptimizerType::Adam],
domain_params: HashMap::new(),
},
DomainStrategy::RecommendationSystems { .. } => DomainConfig {
base_learning_rate: A::from(0.01).expect("unwrap failed"),
recommended_batch_sizes: vec![128, 256, 512],
gradient_clip_values: vec![
A::from(5.0).expect("unwrap failed"),
A::from(10.0).expect("unwrap failed"),
],
regularization_range: (
A::from(1e-3).expect("unwrap failed"),
A::from(1e-1).expect("unwrap failed"),
),
optimizer_ranking: vec![OptimizerType::Adam, OptimizerType::AdaGrad],
domain_params: HashMap::new(),
},
DomainStrategy::TimeSeries { .. } => DomainConfig {
base_learning_rate: A::from(0.001).expect("unwrap failed"),
recommended_batch_sizes: vec![16, 32, 64],
gradient_clip_values: vec![A::from(1.0).expect("unwrap failed")],
regularization_range: (
A::from(1e-4).expect("unwrap failed"),
A::from(1e-2).expect("unwrap failed"),
),
optimizer_ranking: vec![OptimizerType::RMSprop, OptimizerType::Adam],
domain_params: HashMap::new(),
},
DomainStrategy::ReinforcementLearning { .. } => DomainConfig {
base_learning_rate: A::from(3e-4).expect("unwrap failed"),
recommended_batch_sizes: vec![32, 64, 128],
gradient_clip_values: vec![A::from(0.5).expect("unwrap failed")],
regularization_range: (
A::from(1e-4).expect("unwrap failed"),
A::from(1e-2).expect("unwrap failed"),
),
optimizer_ranking: vec![OptimizerType::Adam],
domain_params: HashMap::new(),
},
DomainStrategy::ScientificComputing { .. } => DomainConfig {
base_learning_rate: A::from(0.1).expect("unwrap failed"),
recommended_batch_sizes: vec![64, 128, 256, 512],
gradient_clip_values: vec![],
regularization_range: (
A::from(1e-6).expect("unwrap failed"),
A::from(1e-3).expect("unwrap failed"),
),
optimizer_ranking: vec![OptimizerType::LBFGS, OptimizerType::Adam],
domain_params: HashMap::new(),
},
}
}
}
/// Concrete optimizer settings returned by
/// `DomainSpecificSelector::select_optimal_config`.
#[derive(Debug, Clone)]
pub struct DomainOptimizationConfig<A: Float> {
/// Optimizer to use.
pub optimizer_type: OptimizerType,
/// Initial learning rate.
pub learning_rate: A,
/// Training batch size.
pub batch_size: usize,
/// Gradient-clipping norm; `None` means no clipping norm is suggested
/// (the scientific-computing builder sets `None`).
pub gradient_clip_norm: Option<A>,
/// Regularization strength.
pub regularization_strength: A,
/// Learning-rate schedule.
pub lr_schedule: LearningRateScheduleType,
/// Domain-specific extra knobs keyed by name (e.g. `"warmup_steps"`, `"mixup_alpha"`).
pub specialized_params: HashMap<String, A>,
}
impl<A: Float + Send + Sync> Default for DomainOptimizationConfig<A> {
    /// Generic starting point shared by every domain builder: Adam at 1e-3,
    /// batch size 32, gradient clipping at 1.0, regularization 1e-4, a constant
    /// schedule, and no extra parameters.
    fn default() -> Self {
        let num = |v: f64| A::from(v).expect("unwrap failed");
        let learning_rate = num(0.001);
        let gradient_clip_norm = Some(num(1.0));
        let regularization_strength = num(1e-4);
        Self {
            optimizer_type: OptimizerType::Adam,
            learning_rate,
            batch_size: 32,
            gradient_clip_norm,
            regularization_strength,
            lr_schedule: LearningRateScheduleType::Constant,
            specialized_params: HashMap::new(),
        }
    }
}
/// Human-readable suggestion produced by
/// `DomainSpecificSelector::get_domain_recommendations`.
#[derive(Debug, Clone)]
pub struct DomainRecommendation<A: Float> {
/// Category of the suggestion.
pub recommendation_type: RecommendationType,
/// Explanation including the supporting numbers.
pub description: String,
/// Confidence: a fixed 0.8 for baselines, the transfer score for transfer suggestions.
pub confidence: A,
/// Suggested next step.
pub action: String,
}
/// Category of a `DomainRecommendation`.
///
/// The enum is fieldless, so it is `Copy` and can be compared with `==` or
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendationType {
    /// Mean historical validation accuracy for the domain.
    PerformanceBaseline,
    /// Suggestion derived from a recorded cross-domain transfer.
    TransferLearning,
    /// NOTE(review): this and the two variants below are not produced by the
    /// code in this file.
    HyperparameterTuning,
    ArchitectureChange,
    ResourceOptimization,
}
// Unit tests for domain-specific selection. Several fixtures are chosen to sit
// on specific thresholds in the selector (e.g. `max_memory` just above
// 16_000_000_000 to land in the 128 computer-vision batch tier), so adjust
// values with care.
#[cfg(test)]
mod tests {
use super::*;
use crate::adaptive_selection::ProblemType;
// Construction seeds the computer-vision defaults, whose ranking starts with AdamW.
#[test]
fn test_domain_specific_selector_creation() {
let strategy = DomainStrategy::ComputerVision {
resolution_adaptive: true,
batch_norm_tuning: true,
augmentation_aware: true,
};
let selector = DomainSpecificSelector::<f64>::new(strategy);
assert_eq!(selector.config.optimizer_ranking[0], OptimizerType::AdamW);
}
// All CV flags on: batch_norm_tuning selects AdamW, and the memory budget
// selects the largest batch tier; clipping is always set for CV.
#[test]
fn test_computer_vision_optimization() {
let strategy = DomainStrategy::ComputerVision {
resolution_adaptive: true,
batch_norm_tuning: true,
augmentation_aware: true,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let context = OptimizationContext {
problem_chars: ProblemCharacteristics {
dataset_size: 50000,
// 224 * 224 * 3 = 150_528 input elements.
input_dim: 224 * 224 * 3, output_dim: 1000,
problem_type: ProblemType::ComputerVision,
gradient_sparsity: 0.1,
gradient_noise: 0.05,
memory_budget: 8_000_000_000,
time_budget: 3600.0,
batch_size: 64,
lr_sensitivity: 0.5,
regularization_strength: 0.01,
architecture_type: Some("ResNet".to_string()),
},
resource_constraints: ResourceConstraints {
// Just above the 16_000_000_000 CV threshold, so batch size 128 is expected.
max_memory: 17_000_000_000, max_time: 7200.0,
gpu_available: true,
distributed_capable: false,
energy_efficient: false,
},
training_config: TrainingConfiguration {
max_epochs: 100,
early_stopping_patience: 10,
validation_frequency: 1,
lr_schedule_type: LearningRateScheduleType::CosineAnnealing { t_max: 100 },
regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
},
domain_metadata: HashMap::new(),
};
selector.setcontext(context);
let config = selector.select_optimal_config().expect("unwrap failed");
assert_eq!(config.optimizer_type, OptimizerType::AdamW);
assert_eq!(config.batch_size, 128); assert!(config.gradient_clip_norm.is_some());
}
// attention_optimized selects AdamW; warmup_steps is always emitted; an
// output_dim above 30_000 with vocab_aware emits tie_embeddings. input_dim 512
// is not > 512, so the short-sequence branch is taken.
#[test]
fn test_natural_language_optimization() {
let strategy = DomainStrategy::NaturalLanguage {
sequence_adaptive: true,
attention_optimized: true,
vocab_aware: true,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let context = OptimizationContext {
problem_chars: ProblemCharacteristics {
dataset_size: 100000,
input_dim: 512, output_dim: 50000, problem_type: ProblemType::NaturalLanguage,
gradient_sparsity: 0.2,
gradient_noise: 0.1,
memory_budget: 32_000_000_000,
time_budget: 7200.0,
batch_size: 32,
lr_sensitivity: 0.8,
regularization_strength: 0.1,
architecture_type: Some("Transformer".to_string()),
},
resource_constraints: ResourceConstraints {
max_memory: 32_000_000_000,
max_time: 10800.0,
gpu_available: true,
distributed_capable: true,
energy_efficient: false,
},
training_config: TrainingConfiguration {
max_epochs: 50,
early_stopping_patience: 5,
validation_frequency: 1,
lr_schedule_type: LearningRateScheduleType::OneCycle { max_lr: 2e-5 },
regularization_approach: RegularizationApproach::Dropout { dropout_rate: 0.1 },
},
domain_metadata: HashMap::new(),
};
selector.setcontext(context);
let config = selector.select_optimal_config().expect("unwrap failed");
assert_eq!(config.optimizer_type, OptimizerType::AdamW);
assert!(config.specialized_params.contains_key("warmup_steps"));
assert!(config.specialized_params.contains_key("tie_embeddings")); }
// temporal_aware selects RMSprop; the time-series builder always uses batch
// size 32; the seasonality and multi-step flags emit their parameters.
#[test]
fn test_time_series_optimization() {
let strategy = DomainStrategy::TimeSeries {
temporal_aware: true,
seasonality_adaptive: true,
multi_step: true,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let context = OptimizationContext {
problem_chars: ProblemCharacteristics {
dataset_size: 10000,
input_dim: 168, output_dim: 24, problem_type: ProblemType::TimeSeries,
gradient_sparsity: 0.05,
gradient_noise: 0.2,
memory_budget: 4_000_000_000,
time_budget: 1800.0,
batch_size: 32,
lr_sensitivity: 0.7,
regularization_strength: 0.01,
architecture_type: Some("LSTM".to_string()),
},
resource_constraints: ResourceConstraints {
max_memory: 8_000_000_000,
max_time: 3600.0,
gpu_available: true,
distributed_capable: false,
energy_efficient: true,
},
training_config: TrainingConfiguration {
max_epochs: 200,
early_stopping_patience: 20,
validation_frequency: 5,
lr_schedule_type: LearningRateScheduleType::ReduceOnPlateau {
patience: 10,
factor: 0.5,
},
regularization_approach: RegularizationApproach::L2Only { weight: 1e-4 },
},
domain_metadata: HashMap::new(),
};
selector.setcontext(context);
let config = selector.select_optimal_config().expect("unwrap failed");
assert_eq!(config.optimizer_type, OptimizerType::RMSprop);
assert_eq!(config.batch_size, 32);
assert!(config.specialized_params.contains_key("seasonal_periods"));
assert!(config.specialized_params.contains_key("prediction_horizon"));
}
// A single recorded run makes the baseline mean equal 0.95, formatted with
// four decimals as "0.9500", which contains "0.95".
#[test]
fn test_performance_tracking() {
let strategy = DomainStrategy::ComputerVision {
resolution_adaptive: true,
batch_norm_tuning: false,
augmentation_aware: false,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let metrics = DomainPerformanceMetrics {
validation_accuracy: 0.95,
domain_specific_score: 0.92,
stability_score: 0.88,
convergence_epochs: 50,
resource_efficiency: 0.85,
transfer_score: 0.7,
};
selector.update_domain_performance("computer_vision".to_string(), metrics);
let recommendations = selector.get_domain_recommendations("computer_vision");
assert!(!recommendations.is_empty());
assert!(recommendations[0].description.contains("0.95"));
}
// A transfer whose target matches the queried domain yields a
// TransferLearning recommendation.
#[test]
fn test_cross_domain_transfer() {
let strategy = DomainStrategy::ComputerVision {
resolution_adaptive: true,
batch_norm_tuning: true,
augmentation_aware: true,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let transfer_knowledge = CrossDomainKnowledge {
source_domain: "natural_language".to_string(),
target_domain: "computer_vision".to_string(),
transferable_params: HashMap::from([
("learning_rate".to_string(), 0.001),
("weight_decay".to_string(), 0.01),
]),
transfer_score: 0.8,
successful_strategy: OptimizerType::AdamW,
};
selector.record_transfer_knowledge(transfer_knowledge);
let recommendations = selector.get_domain_recommendations("computer_vision");
assert!(recommendations
.iter()
.any(|r| matches!(r.recommendation_type, RecommendationType::TransferLearning)));
}
// stability_focused selects LBFGS; the scientific builder disables gradient
// clipping; precision_critical emits convergence_tolerance.
#[test]
fn test_scientific_computing_optimization() {
let strategy = DomainStrategy::ScientificComputing {
stability_focused: true,
precision_critical: true,
sparse_optimized: false,
};
let mut selector = DomainSpecificSelector::<f64>::new(strategy);
let context = OptimizationContext {
problem_chars: ProblemCharacteristics {
dataset_size: 1000,
input_dim: 100,
output_dim: 1,
problem_type: ProblemType::Regression,
gradient_sparsity: 0.01,
gradient_noise: 0.01,
memory_budget: 16_000_000_000,
time_budget: 7200.0,
batch_size: 100,
lr_sensitivity: 0.3,
regularization_strength: 1e-6,
architecture_type: Some("MLP".to_string()),
},
resource_constraints: ResourceConstraints {
max_memory: 16_000_000_000,
max_time: 7200.0,
gpu_available: false,
distributed_capable: false,
energy_efficient: false,
},
training_config: TrainingConfiguration {
max_epochs: 1000,
early_stopping_patience: 100,
validation_frequency: 10,
lr_schedule_type: LearningRateScheduleType::Constant,
regularization_approach: RegularizationApproach::L2Only { weight: 1e-6 },
},
domain_metadata: HashMap::new(),
};
selector.setcontext(context);
let config = selector.select_optimal_config().expect("unwrap failed");
assert_eq!(config.optimizer_type, OptimizerType::LBFGS);
assert!(config.gradient_clip_norm.is_none()); assert!(config
.specialized_params
.contains_key("convergence_tolerance"));
}
}