use crate::neural_architecture_search::{
architecture::*,
config::*,
types::*,
evaluator::*,
history::*,
dataset::*,
monitoring::*,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use scirs2_core::random::{Random, Rng};
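/// Evolutionary-style neural architecture search: maintains a population of
/// candidate [`Architecture`]s drawn from an [`ArchitectureSearchSpace`],
/// scores them, and evolves the population according to a [`SearchStrategy`],
/// logging progress into a shared [`SearchHistory`].
///
/// A minimal usage sketch (marked `ignore` because it assumes the config,
/// search-space, and evaluator types can be built via `Default`, which this
/// crate may not provide):
///
/// ```ignore
/// let mut nas = NeuralArchitectureSearch::new(
///     NASConfig::default(),
///     ArchitectureSearchSpace::default(),
///     Arc::new(PerformanceEvaluator::default()),
///     SearchStrategy::RandomSearch,
/// )?;
/// let best = nas.search().await?;
/// ```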
pub struct NeuralArchitectureSearch {
config: NASConfig,
search_space: ArchitectureSearchSpace,
evaluator: Arc<PerformanceEvaluator>,
strategy: SearchStrategy,
history: Arc<RwLock<SearchHistory>>,
population: Vec<Architecture>,
rng: Random,
}
impl NeuralArchitectureSearch {
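    /// Creates a searcher with an RNG seeded from `config.random_seed` and a
    /// randomly initialized population of `config.population_size` candidates.
    ///
    /// Returns an error if any option list in `search_space` is empty, since
    /// random initialization needs at least one choice per component.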
pub fn new(
config: NASConfig,
search_space: ArchitectureSearchSpace,
evaluator: Arc<PerformanceEvaluator>,
strategy: SearchStrategy,
) -> Result<Self> {
let mut rng = Random::seed(config.random_seed);
let history = Arc::new(RwLock::new(SearchHistory::default()));
let population = Self::initialize_population(
&search_space,
config.population_size,
&mut rng,
)?;
Ok(Self {
config,
search_space,
evaluator,
strategy,
history,
population,
rng,
})
}
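    /// Runs the search loop for up to `config.max_generations` generations:
    /// evaluate the population, track the best composite score, record
    /// statistics, then evolve. Stops early once `config.target_performance`
    /// is reached or after `config.early_stopping_patience` generations
    /// without improvement.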
    pub async fn search(&mut self) -> Result<Architecture> {
        let mut best_architecture = None;
        let mut best_performance = f64::NEG_INFINITY;
        let mut generations_without_improvement = 0;
        for generation in 0..self.config.max_generations {
            self.evaluate_population(generation).await?;
            if let Some(arch) = self.get_best_architecture()? {
                if let Some(perf) = &arch.performance {
                    if perf.composite_score > best_performance {
                        best_performance = perf.composite_score;
                        best_architecture = Some(arch.clone());
                        generations_without_improvement = 0;
                        // Stop as soon as the target score is reached.
                        if best_performance >= self.config.target_performance {
                            break;
                        }
                    } else {
                        generations_without_improvement += 1;
                    }
                }
            }
            // Early stopping after too many stagnant generations.
            if generations_without_improvement >= self.config.early_stopping_patience {
                break;
            }
            // Record statistics for this generation before evolution replaces
            // part of the population with not-yet-evaluated candidates.
            self.update_search_statistics(generation, best_performance)?;
            self.evolve_population(generation)?;
        }
        best_architecture.ok_or_else(|| anyhow::anyhow!("No valid architecture found"))
    }
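    /// Builds the initial population by sampling `population_size` random
    /// architectures from the search space.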
fn initialize_population(
search_space: &ArchitectureSearchSpace,
population_size: usize,
rng: &mut Random,
) -> Result<Vec<Architecture>> {
let mut population = Vec::with_capacity(population_size);
for _ in 0..population_size {
let architecture = Self::generate_random_architecture(search_space, rng)?;
population.push(architecture);
}
Ok(population)
}
    /// Picks a uniformly random element from `items`, or reports which option
    /// list was empty.
    fn pick_random<'a, T>(items: &'a [T], what: &str, rng: &mut Random) -> Result<&'a T> {
        if items.is_empty() {
            return Err(anyhow::anyhow!("No {what} available"));
        }
        let idx = rng.gen_range(0..items.len());
        Ok(&items[idx])
    }

    /// Samples a random architecture: a random depth, random per-layer
    /// choices, and random global/training hyperparameters.
    fn generate_random_architecture(
        search_space: &ArchitectureSearchSpace,
        rng: &mut Random,
    ) -> Result<Architecture> {
        let depth = rng.gen_range(search_space.depth_range.0..=search_space.depth_range.1);
        let mut layers = Vec::with_capacity(depth);
        for _ in 0..depth {
            let layer_type = Self::pick_random(&search_space.layer_types, "layer types", rng)?.clone();
            let activation = Self::pick_random(&search_space.activations, "activation functions", rng)?.clone();
            let normalization = Self::pick_random(&search_space.normalizations, "normalization types", rng)?.clone();
            let skip_pattern = Self::pick_random(&search_space.skip_patterns, "skip patterns", rng)?.clone();

            let mut hyperparameters = HashMap::new();
            hyperparameters.insert("learning_rate".to_string(), rng.gen_range(1e-5..1e-1));
            hyperparameters.insert("dropout_rate".to_string(), rng.gen_range(0.0..0.5));
            hyperparameters.insert("weight_decay".to_string(), rng.gen_range(1e-6..1e-2));

            layers.push(LayerConfig {
                layer_type,
                activation,
                normalization,
                skip_pattern,
                hyperparameters,
            });
        }

        let embedding_dim = *Self::pick_random(&search_space.embedding_dims, "embedding dimensions", rng)?;
        let global_config = GlobalArchConfig {
            input_dim: rng.gen_range(128..2048),
            output_dim: embedding_dim,
            learning_rate: rng.gen_range(1e-5..1e-2),
            optimizer: OptimizerType::Adam {
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
            },
            regularization: RegularizationConfig {
                l1_weight: rng.gen_range(0.0..1e-3),
                l2_weight: rng.gen_range(0.0..1e-2),
                dropout_rate: rng.gen_range(0.0..0.5),
                label_smoothing: rng.gen_range(0.0..0.1),
                early_stopping_patience: rng.gen_range(5..20),
            },
            training_config: TrainingConfig {
                batch_size: {
                    let batch_sizes = [16, 32, 64, 128, 256];
                    *Self::pick_random(&batch_sizes, "batch sizes", rng)?
                },
                epochs: rng.gen_range(10..100),
                validation_split: rng.gen_range(0.1..0.3),
                lr_schedule: LRScheduleType::CosineAnnealingLR { t_max: 50 },
                loss_function: LossFunction::CosineSimilarity,
            },
        };
        Ok(Architecture::new(layers, global_config))
    }
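    /// Stamps every candidate with the current generation, attaches
    /// performance metrics, and snapshots the population into the history.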
    async fn evaluate_population(&mut self, generation: usize) -> Result<()> {
        for arch in &mut self.population {
            arch.generation = generation;
            // Placeholder evaluation: fills in randomized metrics rather than
            // delegating to `self.evaluator`. A full implementation would
            // train and score each candidate via the `PerformanceEvaluator`.
            let embedding_quality = self.rng.gen_range(0.0..1.0);
            let performance = PerformanceMetrics {
                embedding_quality,
                training_loss: self.rng.gen_range(0.01..1.0),
                validation_loss: self.rng.gen_range(0.01..1.0),
                inference_latency_ms: self.rng.gen_range(1.0..100.0),
                model_size_params: arch.estimate_complexity(),
                memory_usage_mb: self.rng.gen_range(50.0..500.0),
                flops: self.rng.gen_range(1_000_000..100_000_000),
                training_time_minutes: self.rng.gen_range(5.0..120.0),
                energy_consumption: self.rng.gen_range(10.0..200.0),
                // `composite_score` is what the selection logic compares, but
                // it was never initialized here; use the placeholder quality
                // score until a weighted combination of the metrics above is
                // implemented.
                composite_score: embedding_quality,
                task_metrics: HashMap::new(),
            };
            arch.performance = Some(performance);
        }
        {
            let mut history = self.history
                .write()
                .expect("history write lock should not be poisoned");
            for arch in &self.population {
                history.add_architecture(arch.clone());
            }
        }
        Ok(())
    }
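    /// Returns a clone of the evaluated architecture with the highest
    /// composite score, or `None` if nothing has been evaluated yet.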
fn get_best_architecture(&self) -> Result<Option<Architecture>> {
let best = self.population.iter()
.filter_map(|arch| arch.performance.as_ref().map(|perf| (arch, perf)))
.max_by(|(_, perf1), (_, perf2)| {
perf1.composite_score.partial_cmp(&perf2.composite_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(best.map(|(arch, _)| arch.clone()))
}
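    /// Produces the next generation according to the configured strategy.
    /// Only `RandomSearch` is implemented so far; the other strategies
    /// currently leave the population unchanged.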
    fn evolve_population(&mut self, _generation: usize) -> Result<()> {
        match &self.strategy {
            SearchStrategy::RandomSearch => {
                // Replace the first quarter of the population with fresh
                // random candidates; note this overwrites slots without
                // regard to fitness, so it may discard the current best.
                let num_to_replace = self.population.len() / 4;
                for i in 0..num_to_replace {
                    self.population[i] =
                        Self::generate_random_architecture(&self.search_space, &mut self.rng)?;
                }
            }
            // The remaining strategies are not implemented yet; they leave
            // the population unchanged.
            SearchStrategy::EvolutionaryAlgorithm
            | SearchStrategy::ReinforcementLearning
            | SearchStrategy::BayesianOptimization
            | SearchStrategy::GridSearch
            | SearchStrategy::GradientBased
            | SearchStrategy::Hybrid { .. } => {}
        }
        Ok(())
    }
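    /// Records the current generation's best architecture and score in the
    /// shared search history.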
    fn update_search_statistics(&mut self, generation: usize, best_performance: f64) -> Result<()> {
        let best = self
            .get_best_architecture()?
            .ok_or_else(|| anyhow::anyhow!("no evaluated architecture available for statistics"))?;
        let mut history = self.history
            .write()
            .expect("history write lock should not be poisoned");
        history.update_best(best.id, best_performance, generation);
        Ok(())
    }
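    /// Returns a shared handle to the search history.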
pub fn get_search_history(&self) -> Arc<RwLock<SearchHistory>> {
self.history.clone()
}
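    /// Returns the current population of candidate architectures.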
pub fn get_population(&self) -> &[Architecture] {
&self.population
}
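    /// Returns the search configuration.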
pub fn get_config(&self) -> &NASConfig {
&self.config
}
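    /// Returns the architecture search space being explored.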
pub fn get_search_space(&self) -> &ArchitectureSearchSpace {
&self.search_space
}
}
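
// Smoke-test sketch (added for illustration, not part of the original module).
// It assumes `NASConfig`, `ArchitectureSearchSpace`, and `PerformanceEvaluator`
// implement `Default`, and that a default search space has non-empty option
// lists; if either assumption does not hold in this crate, adapt the
// constructors accordingly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_populates_the_initial_generation() -> Result<()> {
        let nas = NeuralArchitectureSearch::new(
            NASConfig::default(),
            ArchitectureSearchSpace::default(),
            Arc::new(PerformanceEvaluator::default()),
            SearchStrategy::RandomSearch,
        )?;
        // The initial population size is taken straight from the config.
        assert_eq!(nas.get_population().len(), nas.get_config().population_size);
        Ok(())
    }
}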