use super::{
ActivationType, LearnedOptimizationConfig, LearnedOptimizer, MetaOptimizerState,
OptimizationProblem, TrainingTask,
};
use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
use std::collections::HashMap;
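/// Adaptive neural architecture search (NAS) system that evolves a population
/// of candidate optimizer architectures, guided by an LSTM controller and
/// per-architecture performance statistics.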
#[derive(Debug, Clone)]
pub struct AdaptiveNASSystem {
config: LearnedOptimizationConfig,
architecture_population: Vec<OptimizationArchitecture>,
performance_history: HashMap<ArchitectureId, Vec<f64>>,
controller: ArchitectureController,
meta_state: MetaOptimizerState,
architecture_cache: HashMap<String, OptimizationArchitecture>,
search_stats: NASSearchStats,
generation: usize,
}
type ArchitectureId = String;
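/// A candidate optimizer architecture: the network layers, their connectivity,
/// and the classical optimizer components (momentum, trust region, ...) that
/// the learned optimizer combines.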
#[derive(Debug, Clone)]
pub struct OptimizationArchitecture {
pub id: ArchitectureId,
pub layers: Vec<LayerConfig>,
pub connections: Vec<Connection>,
pub activations: Vec<ActivationType>,
pub skip_connections: Vec<SkipConnection>,
pub optimizer_components: Vec<OptimizerComponent>,
pub complexity: f64,
pub performance_metrics: ArchitectureMetrics,
}
#[derive(Debug, Clone)]
pub struct LayerConfig {
pub layer_type: LayerType,
pub units: usize,
pub dropout: f64,
pub normalization: NormalizationType,
pub parameters: HashMap<String, f64>,
}
#[derive(Debug, Clone)]
pub enum LayerType {
Dense,
Convolution { kernel_size: usize, stride: usize },
Attention { num_heads: usize },
LSTM { hidden_size: usize },
GRU { hidden_size: usize },
Transformer { num_heads: usize, ff_dim: usize },
GraphNN { aggregation: String },
Memory { memory_size: usize },
}
#[derive(Debug, Clone)]
pub enum NormalizationType {
None,
BatchNorm,
LayerNorm,
GroupNorm { groups: usize },
InstanceNorm,
}
#[derive(Debug, Clone)]
pub struct Connection {
pub from: usize,
pub to: usize,
pub weight: f64,
pub connection_type: ConnectionType,
}
#[derive(Debug, Clone)]
pub enum ConnectionType {
Forward,
Residual,
Dense,
Attention,
}
#[derive(Debug, Clone)]
pub struct SkipConnection {
pub source: usize,
pub target: usize,
pub skip_type: SkipType,
}
#[derive(Debug, Clone)]
pub enum SkipType {
Add,
Concat,
Gated { gate_size: usize },
Highway,
}
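/// Classical optimization building blocks that an architecture may combine.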
#[derive(Debug, Clone)]
pub enum OptimizerComponent {
Momentum { decay: f64 },
AdaptiveLR {
adaptation_rate: f64,
min_lr: f64,
max_lr: f64,
},
SecondOrder {
hessian_approximation: HessianApprox,
regularization: f64,
},
TrustRegion {
initial_radius: f64,
max_radius: f64,
shrink_factor: f64,
expand_factor: f64,
},
LineSearch {
method: LineSearchMethod,
max_nit: usize,
},
Regularization {
l1_weight: f64,
l2_weight: f64,
elastic_net_ratio: f64,
},
}
#[derive(Debug, Clone)]
pub enum HessianApprox {
BFGS,
LBFGS { memory_size: usize },
SR1,
DFP,
DiagonalApprox,
}
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
Backtracking,
StrongWolfe,
MoreThuente,
Armijo,
Exact,
}
#[derive(Debug, Clone)]
pub struct ArchitectureMetrics {
pub convergence_rate: f64,
pub success_rate: f64,
pub avg_evaluations: f64,
pub robustness: f64,
pub transfer_score: f64,
pub efficiency: f64,
}
impl Default for ArchitectureMetrics {
fn default() -> Self {
Self {
convergence_rate: 0.0,
success_rate: 0.0,
avg_evaluations: 0.0,
robustness: 0.0,
transfer_score: 0.0,
efficiency: 0.0,
}
}
}
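/// LSTM-based controller that biases architecture generation, in the spirit of
/// REINFORCE-trained NAS controllers (Zoph & Le, 2017).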
#[derive(Debug, Clone)]
pub struct ArchitectureController {
lstm_weights: Array3<f64>,
embedding_layer: Array2<f64>,
output_layer: Array2<f64>,
controller_state: Array1<f64>,
vocabulary: ArchitectureVocabulary,
}
#[derive(Debug, Clone)]
pub struct ArchitectureVocabulary {
pub layer_types: HashMap<String, usize>,
pub activations: HashMap<String, usize>,
pub components: HashMap<String, usize>,
pub vocab_size: usize,
}
#[derive(Debug, Clone)]
pub struct NASSearchStats {
pub architectures_evaluated: usize,
pub best_performance: f64,
pub search_efficiency: f64,
pub population_diversity: f64,
pub convergence_indicators: Vec<f64>,
}
impl Default for NASSearchStats {
fn default() -> Self {
Self {
architectures_evaluated: 0,
best_performance: f64::NEG_INFINITY,
search_efficiency: 0.0,
population_diversity: 1.0,
convergence_indicators: Vec::new(),
}
}
}
impl AdaptiveNASSystem {
pub fn new(config: LearnedOptimizationConfig) -> Self {
let vocabulary = ArchitectureVocabulary::new();
let controller = ArchitectureController::new(&vocabulary, config.hidden_size);
let hidden_size = config.hidden_size;
Self {
config,
architecture_population: Vec::new(),
performance_history: HashMap::new(),
controller,
meta_state: MetaOptimizerState {
meta_params: Array1::zeros(100),
network_weights: Array2::zeros((hidden_size, hidden_size)),
performance_history: Vec::new(),
adaptation_stats: super::AdaptationStatistics::default(),
episode: 0,
},
architecture_cache: HashMap::new(),
search_stats: NASSearchStats::default(),
generation: 0,
}
}
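    /// Runs the evolutionary search loop: evaluate the population on the
    /// training problems, update the controller from the observed rewards,
    /// generate offspring, keep the fittest, and stop early once the best
    /// performance plateaus.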
pub fn search_architectures(
&mut self,
training_problems: &[OptimizationProblem],
) -> OptimizeResult<Vec<OptimizationArchitecture>> {
if self.architecture_population.is_empty() {
self.initialize_population()?;
}
for generation in 0..self.config.meta_training_episodes {
self.generation = generation;
self.evaluate_population(training_problems)?;
self.update_controller()?;
let new_architectures = self.generate_new_architectures()?;
self.select_next_generation(new_architectures)?;
self.update_search_stats()?;
if self.check_convergence() {
break;
}
}
Ok(self.get_best_architectures())
}
fn initialize_population(&mut self) -> OptimizeResult<()> {
for _ in 0..self.config.batch_size {
let architecture = self.generate_random_architecture()?;
self.architecture_population.push(architecture);
}
Ok(())
}
fn generate_random_architecture(&self) -> OptimizeResult<OptimizationArchitecture> {
        let num_layers = 2 + (scirs2_core::random::rng().random_range(0..8));
        let mut layers = Vec::new();
let mut connections = Vec::new();
let mut activations = Vec::new();
let mut optimizer_components = Vec::new();
for i in 0..num_layers {
let layer_type = self.sample_layer_type();
let units = 16 + (scirs2_core::random::rng().random_range(0..256));
layers.push(LayerConfig {
layer_type,
units,
dropout: scirs2_core::random::rng().random_range(0.0..0.5),
normalization: self.sample_normalization(),
parameters: HashMap::new(),
});
activations.push(self.sample_activation());
if i > 0 {
connections.push(Connection {
from: i - 1,
to: i,
weight: 1.0,
connection_type: ConnectionType::Forward,
});
if i > 1 && scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
let skip_source = scirs2_core::random::rng().random_range(0..i);
connections.push(Connection {
from: skip_source,
to: i,
weight: 0.5,
connection_type: ConnectionType::Residual,
});
}
}
}
for _ in 0..(1 + scirs2_core::random::rng().random_range(0..4)) {
optimizer_components.push(self.sample_optimizer_component());
}
let id = format!(
"arch_{}",
scirs2_core::random::rng().random_range(0..u64::MAX)
);
Ok(OptimizationArchitecture {
id,
layers,
connections,
activations,
skip_connections: Vec::new(),
optimizer_components,
complexity: 0.0,
performance_metrics: ArchitectureMetrics::default(),
})
}
fn sample_layer_type(&self) -> LayerType {
match scirs2_core::random::rng().random_range(0..8) {
0 => LayerType::Dense,
1 => LayerType::Attention {
num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
},
2 => LayerType::LSTM {
hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
},
3 => LayerType::GRU {
hidden_size: 32 + scirs2_core::random::rng().random_range(0..128),
},
4 => LayerType::Transformer {
num_heads: 2 + scirs2_core::random::rng().random_range(0..6),
ff_dim: 64 + scirs2_core::random::rng().random_range(0..256),
},
5 => LayerType::Memory {
memory_size: 16 + scirs2_core::random::rng().random_range(0..64),
},
6 => LayerType::Convolution {
kernel_size: 1 + scirs2_core::random::rng().random_range(0..5),
stride: 1 + scirs2_core::random::rng().random_range(0..3),
},
_ => LayerType::GraphNN {
aggregation: "mean".to_string(),
},
}
}
fn sample_normalization(&self) -> NormalizationType {
match scirs2_core::random::rng().random_range(0..5) {
0 => NormalizationType::None,
1 => NormalizationType::BatchNorm,
2 => NormalizationType::LayerNorm,
3 => NormalizationType::GroupNorm {
groups: 2 + scirs2_core::random::rng().random_range(0..6),
},
_ => NormalizationType::InstanceNorm,
}
}
fn sample_activation(&self) -> ActivationType {
match scirs2_core::random::rng().random_range(0..5) {
0 => ActivationType::ReLU,
1 => ActivationType::GELU,
2 => ActivationType::Swish,
3 => ActivationType::Tanh,
_ => ActivationType::LeakyReLU,
}
}
fn sample_optimizer_component(&self) -> OptimizerComponent {
match scirs2_core::random::rng().random_range(0..6) {
0 => OptimizerComponent::Momentum {
decay: 0.8 + scirs2_core::random::rng().random_range(0.0..0.19),
},
1 => OptimizerComponent::AdaptiveLR {
adaptation_rate: 0.001 + scirs2_core::random::rng().random_range(0.0..0.009),
min_lr: 1e-8,
max_lr: 1.0,
},
2 => OptimizerComponent::SecondOrder {
hessian_approximation: HessianApprox::LBFGS {
memory_size: 5 + scirs2_core::random::rng().random_range(0..15),
},
regularization: 1e-6 + scirs2_core::random::rng().random_range(0.0..1e-3),
},
3 => OptimizerComponent::TrustRegion {
initial_radius: 0.1 + scirs2_core::random::rng().random_range(0.0..0.9),
max_radius: 10.0,
shrink_factor: 0.25,
expand_factor: 2.0,
},
4 => OptimizerComponent::LineSearch {
method: LineSearchMethod::StrongWolfe,
max_nit: 10 + scirs2_core::random::rng().random_range(0..20),
},
_ => OptimizerComponent::Regularization {
l1_weight: scirs2_core::random::rng().random_range(0.0..0.01),
l2_weight: scirs2_core::random::rng().random_range(0.0..0.01),
elastic_net_ratio: scirs2_core::random::rng().random_range(0.0..1.0),
},
}
}
fn evaluate_population(
&mut self,
training_problems: &[OptimizationProblem],
) -> OptimizeResult<()> {
let scores: Vec<_> = self
.architecture_population
.iter()
.map(|architecture| {
let mut total_score = 0.0;
let mut num_evaluated = 0;
for problem in training_problems.iter().take(5) {
if let Ok(score) = self.evaluate_architecture_on_problem(architecture, problem)
{
total_score += score;
num_evaluated += 1;
}
}
if num_evaluated > 0 {
Some(total_score / num_evaluated as f64)
} else {
None
}
})
.collect();
for (architecture, score) in self.architecture_population.iter_mut().zip(scores.iter()) {
if let Some(avg_score) = score {
architecture.performance_metrics.convergence_rate = *avg_score;
self.performance_history
.entry(architecture.id.clone())
.or_default()
.push(*avg_score);
}
}
Ok(())
}
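    /// Scores an architecture with a cheap structural heuristic (component and
    /// layer counts minus a complexity penalty) rather than an actual training
    /// run on `_problem`; a full evaluation would train the architecture on
    /// the problem and measure convergence.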
    fn evaluate_architecture_on_problem(
        &self,
        architecture: &OptimizationArchitecture,
        _problem: &OptimizationProblem,
    ) -> OptimizeResult<f64> {
let complexity_penalty = architecture.complexity * 0.01;
let num_components = architecture.optimizer_components.len() as f64;
let num_layers = architecture.layers.len() as f64;
let base_score = (num_components * 0.1 + num_layers * 0.05).min(1.0);
let final_score = base_score - complexity_penalty;
Ok(final_score.max(0.0))
}
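    /// REINFORCE-style controller update: each architecture's reward is
    /// compared against the population-mean baseline, and the controller
    /// weights are nudged in proportion to that advantage.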
    fn update_controller(&mut self) -> OptimizeResult<()> {
        let rewards: Vec<f64> = self
            .architecture_population
            .iter()
            .map(|arch| arch.performance_metrics.convergence_rate)
            .collect();
        if rewards.is_empty() {
            return Ok(());
        }
        let baseline = rewards.iter().sum::<f64>() / rewards.len() as f64;
        let (num_gates, rows, cols) = self.controller.lstm_weights.dim();
        let learning_rate = self.config.meta_learning_rate;
        for (i, &reward) in rewards.iter().enumerate() {
            let advantage = reward - baseline;
            // Index by the first axis, not by `len()`: `len()` is the total
            // element count and would run past the gate dimension.
            let gate = i % num_gates;
            for j in 0..rows {
                for k in 0..cols {
                    self.controller.lstm_weights[(gate, j, k)] +=
                        learning_rate * advantage * 0.01;
                }
            }
        }
        Ok(())
    }
fn generate_new_architectures(&mut self) -> OptimizeResult<Vec<OptimizationArchitecture>> {
let mut new_architectures = Vec::new();
for _ in 0..self.config.batch_size / 2 {
let architecture = self.controller_generate_architecture()?;
new_architectures.push(architecture);
if !self.architecture_population.is_empty() {
let best_idx = self.get_best_architecture_index();
let mutated = self.mutate_architecture(&self.architecture_population[best_idx])?;
new_architectures.push(mutated);
}
}
Ok(new_architectures)
}
    fn controller_generate_architecture(&mut self) -> OptimizeResult<OptimizationArchitecture> {
        let mut architecture = self.generate_random_architecture()?;
        // ndarray's `mean` returns an `Option`; treat an empty state as 0.0.
        let controller_influence = self.controller.controller_state.mean().unwrap_or(0.0);
        if controller_influence > 0.5 {
            if architecture.layers.len() < 10 {
                architecture.layers.push(LayerConfig {
                    layer_type: LayerType::Dense,
                    units: 64,
                    dropout: 0.1,
                    normalization: NormalizationType::LayerNorm,
                    parameters: HashMap::new(),
                });
                // Keep the activation list aligned with the layer list.
                architecture.activations.push(ActivationType::ReLU);
            }
        } else if architecture.layers.len() > 2 {
            architecture.layers.pop();
            architecture.activations.pop();
        }
        Ok(architecture)
    }
fn mutate_architecture(
&self,
base_arch: &OptimizationArchitecture,
) -> OptimizeResult<OptimizationArchitecture> {
let mut mutated = base_arch.clone();
mutated.id = format!(
"mutated_{}",
scirs2_core::random::rng().random_range(0..u64::MAX)
);
        if scirs2_core::random::rng().random_range(0.0..1.0) < 0.3 {
            if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5 && mutated.layers.len() < 12
            {
                mutated.layers.push(LayerConfig {
                    layer_type: self.sample_layer_type(),
                    units: 32 + scirs2_core::random::rng().random_range(0..128),
                    dropout: scirs2_core::random::rng().random_range(0.0..0.5),
                    normalization: self.sample_normalization(),
                    parameters: HashMap::new(),
                });
                // Keep the activation list in lock-step with the layer list.
                mutated.activations.push(self.sample_activation());
            } else if mutated.layers.len() > 2 {
                mutated.layers.pop();
                mutated.activations.pop();
            }
        }
for activation in &mut mutated.activations {
if scirs2_core::random::rng().random_range(0.0..1.0) < 0.2 {
*activation = self.sample_activation();
}
}
if scirs2_core::random::rng().random_range(0.0..1.0) < 0.4 {
if scirs2_core::random::rng().random_range(0.0..1.0) < 0.5
&& mutated.optimizer_components.len() < 6
{
mutated
.optimizer_components
.push(self.sample_optimizer_component());
} else if !mutated.optimizer_components.is_empty() {
let idx =
scirs2_core::random::rng().random_range(0..mutated.optimizer_components.len());
mutated.optimizer_components.remove(idx);
}
}
Ok(mutated)
}
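    /// Elitist truncation selection: merge offspring into the population, sort
    /// best-first by convergence rate, and keep the top `batch_size`.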
fn select_next_generation(
&mut self,
mut new_architectures: Vec<OptimizationArchitecture>,
) -> OptimizeResult<()> {
self.architecture_population.append(&mut new_architectures);
self.architecture_population.sort_by(|a, b| {
b.performance_metrics
.convergence_rate
.partial_cmp(&a.performance_metrics.convergence_rate)
.unwrap_or(std::cmp::Ordering::Equal)
});
self.architecture_population
.truncate(self.config.batch_size);
Ok(())
}
fn update_search_stats(&mut self) -> OptimizeResult<()> {
self.search_stats.architectures_evaluated += self.architecture_population.len();
if let Some(best_arch) = self.architecture_population.first() {
let best_performance = best_arch.performance_metrics.convergence_rate;
if best_performance > self.search_stats.best_performance {
self.search_stats.best_performance = best_performance;
}
}
let performances: Vec<f64> = self
.architecture_population
.iter()
.map(|a| a.performance_metrics.convergence_rate)
.collect();
if performances.len() > 1 {
let mean = performances.iter().sum::<f64>() / performances.len() as f64;
let variance = performances
.iter()
.map(|&p| (p - mean).powi(2))
.sum::<f64>()
/ performances.len() as f64;
self.search_stats.population_diversity = variance.sqrt();
}
self.search_stats
.convergence_indicators
.push(self.search_stats.best_performance);
Ok(())
}
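    /// Declares convergence when the best-performance trace has stopped
    /// improving (average recent improvement below 1e-6).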
    fn check_convergence(&self) -> bool {
        let indicators = &self.search_stats.convergence_indicators;
        if indicators.len() < 10 {
            return false;
        }
        // Only the most recent window should decide convergence; averaging
        // over the full history would mask late-stage stagnation.
        let recent = &indicators[indicators.len() - 10..];
        let avg_improvement =
            recent.windows(2).map(|w| w[1] - w[0]).sum::<f64>() / (recent.len() - 1) as f64;
        avg_improvement < 1e-6
    }
fn get_best_architectures(&self) -> Vec<OptimizationArchitecture> {
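        // The population is kept sorted best-first by `select_next_generation`.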
self.architecture_population.clone()
}
fn get_best_architecture_index(&self) -> usize {
self.architecture_population
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.performance_metrics
.convergence_rate
.partial_cmp(&b.performance_metrics.convergence_rate)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0)
}
pub fn get_search_stats(&self) -> &NASSearchStats {
&self.search_stats
}
pub fn cache_architecture_for_problem(
&mut self,
problem_class: String,
architecture: OptimizationArchitecture,
) {
self.architecture_cache.insert(problem_class, architecture);
}
pub fn get_cached_architecture(
&self,
problem_class: &str,
) -> Option<&OptimizationArchitecture> {
self.architecture_cache.get(problem_class)
}
}
impl ArchitectureController {
pub fn new(vocabulary: &ArchitectureVocabulary, hidden_size: usize) -> Self {
Self {
lstm_weights: Array3::from_shape_fn((4, hidden_size, hidden_size), |_| {
(scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
}),
embedding_layer: Array2::from_shape_fn((hidden_size, vocabulary.vocab_size), |_| {
(scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
}),
output_layer: Array2::from_shape_fn((vocabulary.vocab_size, hidden_size), |_| {
(scirs2_core::random::rng().random_range(0.0..1.0) - 0.5) * 0.1
}),
controller_state: Array1::zeros(hidden_size),
vocabulary: vocabulary.clone(),
}
}
}
impl Default for ArchitectureVocabulary {
fn default() -> Self {
Self::new()
}
}
impl ArchitectureVocabulary {
pub fn new() -> Self {
let mut layer_types = HashMap::new();
layer_types.insert("dense".to_string(), 0);
layer_types.insert("conv".to_string(), 1);
layer_types.insert("attention".to_string(), 2);
layer_types.insert("lstm".to_string(), 3);
layer_types.insert("gru".to_string(), 4);
layer_types.insert("transformer".to_string(), 5);
layer_types.insert("graph".to_string(), 6);
layer_types.insert("memory".to_string(), 7);
let mut activations = HashMap::new();
activations.insert("relu".to_string(), 8);
activations.insert("gelu".to_string(), 9);
activations.insert("swish".to_string(), 10);
activations.insert("tanh".to_string(), 11);
activations.insert("leaky_relu".to_string(), 12);
let mut components = HashMap::new();
components.insert("momentum".to_string(), 13);
components.insert("adaptive_lr".to_string(), 14);
components.insert("second_order".to_string(), 15);
components.insert("trust_region".to_string(), 16);
components.insert("line_search".to_string(), 17);
components.insert("regularization".to_string(), 18);
Self {
layer_types,
activations,
components,
vocab_size: 19,
}
}
}
impl LearnedOptimizer for AdaptiveNASSystem {
fn meta_train(&mut self, training_tasks: &[TrainingTask]) -> OptimizeResult<()> {
let problems: Vec<OptimizationProblem> = training_tasks
.iter()
.map(|task| task.problem.clone())
.collect();
self.search_architectures(&problems)?;
Ok(())
}
    fn adapt_to_problem(
        &mut self,
        problem: &OptimizationProblem,
        _initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<()> {
        // Reuse a cached architecture for this problem class when available.
        if self.get_cached_architecture(&problem.problem_class).is_some() {
            return Ok(());
        }
        let specialized_arch = self.generate_random_architecture()?;
        self.cache_architecture_for_problem(problem.problem_class.clone(), specialized_arch);
        Ok(())
    }
fn optimize<F>(
&mut self,
objective: F,
initial_params: &ArrayView1<f64>,
) -> OptimizeResult<OptimizeResults<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
if self.architecture_population.is_empty() {
self.initialize_population()?;
}
let best_idx = self.get_best_architecture_index();
let best_arch = &self.architecture_population[best_idx];
        let mut current_params = initial_params.to_owned();
        let mut best_params = current_params.clone();
        let mut best_value = objective(initial_params);
        let mut iterations = 0;
        for iter in 0..1000 {
            iterations = iter + 1;
            let step_size = self.compute_step_size(best_arch, iter);
            let direction = self.compute_search_direction(&objective, &current_params, best_arch);
            for i in 0..current_params.len() {
                current_params[i] -= step_size * direction[i];
            }
            let current_value = objective(&current_params.view());
            if current_value < best_value {
                best_value = current_value;
                // Keep the best iterate so the reported `x` matches `fun`.
                best_params = current_params.clone();
            }
            if step_size < 1e-8 {
                break;
            }
        }
Ok(OptimizeResults::<f64> {
            x: best_params,
fun: best_value,
success: true,
nit: iterations,
message: format!(
"NAS optimization completed using architecture: {}",
best_arch.id
),
jac: None,
hess: None,
constr: None,
            // One evaluation up front, then `n + 2` per iteration: `n + 1`
            // for the forward-difference gradient plus one for the new point.
            nfev: 1 + iterations * (initial_params.len() + 2),
            njev: 0,
nhev: 0,
maxcv: 0,
status: 0,
})
}
fn get_state(&self) -> &MetaOptimizerState {
&self.meta_state
}
fn reset(&mut self) {
self.architecture_population.clear();
self.performance_history.clear();
self.search_stats = NASSearchStats::default();
self.generation = 0;
}
}
impl AdaptiveNASSystem {
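    /// Derives a step size from the architecture's optimizer components
    /// (adaptive-LR and trust-region bounds), with a 1/(1 + 0.001 t) decay.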
fn compute_step_size(&self, architecture: &OptimizationArchitecture, iteration: usize) -> f64 {
let mut step_size = 0.01;
for component in &architecture.optimizer_components {
match component {
OptimizerComponent::AdaptiveLR {
adaptation_rate,
min_lr,
max_lr,
} => {
step_size *= 1.0 + adaptation_rate * (iteration as f64).cos();
step_size = step_size.max(*min_lr).min(*max_lr);
}
OptimizerComponent::TrustRegion { initial_radius, .. } => {
step_size = step_size.min(*initial_radius);
}
_ => {}
}
}
step_size / (1.0 + iteration as f64 * 0.001)
}
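    /// Estimates the gradient by forward finite differences and lets the
    /// architecture's components (momentum damping, L1/L2 regularization)
    /// reshape it into a search direction.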
fn compute_search_direction<F>(
&self,
objective: &F,
params: &Array1<f64>,
architecture: &OptimizationArchitecture,
) -> Array1<f64>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let mut direction = Array1::zeros(params.len());
let h = 1e-6;
        let f0 = objective(&params.view());
for i in 0..params.len() {
let mut params_plus = params.clone();
params_plus[i] += h;
            let f_plus = objective(&params_plus.view());
direction[i] = (f_plus - f0) / h;
}
for component in &architecture.optimizer_components {
match component {
OptimizerComponent::Momentum { decay } => {
direction *= 1.0 - decay;
}
OptimizerComponent::Regularization {
l1_weight,
l2_weight,
..
} => {
for i in 0..direction.len() {
direction[i] += l1_weight * params[i].signum() + l2_weight * params[i];
}
}
_ => {}
}
}
direction
}
}
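/// Convenience wrapper: builds a fresh `AdaptiveNASSystem` from `config` (or
/// the default) and minimizes `objective` from `initial_params`.
///
/// A minimal usage sketch (marked `ignore` because the import path depends on
/// where this module is mounted):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, ArrayView1};
///
/// let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
/// let x0 = array![1.0, -2.0];
/// let result = nas_optimize(objective, &x0.view(), None).unwrap();
/// assert!(result.fun <= objective(&x0.view()));
/// ```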
#[allow(dead_code)]
pub fn nas_optimize<F>(
objective: F,
initial_params: &ArrayView1<f64>,
config: Option<LearnedOptimizationConfig>,
) -> super::OptimizeResult<OptimizeResults<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let config = config.unwrap_or_default();
let mut nas_system = AdaptiveNASSystem::new(config);
nas_system.optimize(objective, initial_params)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_nas_system_creation() {
let config = LearnedOptimizationConfig::default();
let nas_system = AdaptiveNASSystem::new(config);
assert_eq!(nas_system.generation, 0);
assert!(nas_system.architecture_population.is_empty());
}
#[test]
fn test_architecture_generation() {
let config = LearnedOptimizationConfig::default();
let nas_system = AdaptiveNASSystem::new(config);
let architecture = nas_system
.generate_random_architecture()
.expect("Operation failed");
assert!(!architecture.layers.is_empty());
assert!(!architecture.activations.is_empty());
assert!(!architecture.optimizer_components.is_empty());
}
#[test]
fn test_vocabulary_creation() {
let vocab = ArchitectureVocabulary::new();
assert!(vocab.layer_types.contains_key("dense"));
assert!(vocab.activations.contains_key("relu"));
assert!(vocab.components.contains_key("momentum"));
assert_eq!(vocab.vocab_size, 19);
}
#[test]
fn test_architecture_mutation() {
let config = LearnedOptimizationConfig::default();
let nas_system = AdaptiveNASSystem::new(config);
let base_arch = nas_system
.generate_random_architecture()
.expect("Operation failed");
let mutated = nas_system
.mutate_architecture(&base_arch)
.expect("Operation failed");
assert_ne!(base_arch.id, mutated.id);
}
#[test]
fn test_nas_optimization() {
let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let initial = Array1::from(vec![2.0, 2.0]);
let config = LearnedOptimizationConfig {
meta_training_episodes: 5,
inner_steps: 10,
..Default::default()
};
let result =
nas_optimize(objective, &initial.view(), Some(config)).expect("Operation failed");
assert!(result.fun >= 0.0);
assert_eq!(result.x.len(), 2);
assert!(result.success);
}
}