use scirs2_core::random::*; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use trustformers_core::errors::{invalid_input, Result, TrustformersError};
/// Top-level configuration for one neural-architecture-search run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NASConfig {
/// Search algorithm to execute (see [`SearchStrategy`]).
pub strategy: SearchStrategy,
/// Space of dimensions/choices candidates are sampled from.
pub search_space: SearchSpace,
/// Weighted objectives combined into a scalar fitness.
pub objectives: Vec<OptimizationObjective>,
/// Total evaluation budget for non-generational strategies.
pub max_evaluations: usize,
/// Population size for evolutionary/NSGA-II strategies.
pub population_size: usize,
/// Number of generations for evolutionary/NSGA-II strategies.
pub generations: usize,
// NOTE(review): `patience` is stored but not read by the search loops
// visible in this file — presumably early stopping; confirm before use.
pub patience: usize,
/// Optional hardware budget; not enforced by the loops in this file.
pub hardware_constraints: Option<HardwareConstraints>,
// NOTE(review): flag is not consulted by `search()`; the Progressive
// strategy is selected via `strategy` instead.
pub progressive_search: bool,
/// RNG seed for reproducible runs; `None` seeds from entropy.
pub seed: Option<u64>,
}
impl Default for NASConfig {
fn default() -> Self {
Self {
strategy: SearchStrategy::Evolutionary,
search_space: SearchSpace::default(),
objectives: vec![OptimizationObjective::Accuracy { weight: 1.0 }],
max_evaluations: 1000,
population_size: 50,
generations: 20,
patience: 5,
hardware_constraints: None,
progressive_search: true,
seed: None,
}
}
}
/// Available search algorithms.
///
/// NOTE(review): in this file, `ReinforcementLearning` and `DARTS` are
/// implemented as random-sampling placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
/// Uniform random sampling of the search space.
Random,
/// Generational GA with tournament selection and uniform crossover.
Evolutionary,
/// RL-based controller (placeholder implementation here).
ReinforcementLearning,
/// Differentiable architecture search (placeholder implementation here).
DARTS,
/// Staged search with growing complexity factor.
Progressive,
/// Random warm-up followed by mutation of the incumbent.
BayesianOptimization,
/// Multi-objective selection (delegates to scalar selection here).
NSGA2,
}
/// Defines what architectures may be sampled: numeric dimensions, categorical
/// choices, and validity constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchSpace {
/// Numeric hyperparameters, keyed by name (e.g. "hidden_size").
pub dimensions: HashMap<String, DimensionRange>,
/// Categorical hyperparameters, keyed by name (e.g. "activation").
pub choices: HashMap<String, Vec<String>>,
/// Constraints an architecture must satisfy to be evaluated.
pub constraints: Vec<ArchitectureConstraint>,
}
impl Default for SearchSpace {
/// The default space targets text transformers.
fn default() -> Self {
Self::transformer_space()
}
}
impl SearchSpace {
pub fn transformer_space() -> Self {
let mut dimensions = HashMap::new();
let mut choices = HashMap::new();
dimensions.insert("num_layers".to_string(), DimensionRange::new(6, 24, 1));
dimensions.insert(
"hidden_size".to_string(),
DimensionRange::new(512, 4096, 64),
);
dimensions.insert("num_heads".to_string(), DimensionRange::new(8, 32, 4));
dimensions.insert(
"intermediate_size".to_string(),
DimensionRange::new(2048, 16384, 256),
);
dimensions.insert(
"max_position_embeddings".to_string(),
DimensionRange::new(512, 8192, 512),
);
choices.insert(
"activation".to_string(),
vec![
"gelu".to_string(),
"relu".to_string(),
"swish".to_string(),
"silu".to_string(),
"gelu_new".to_string(),
],
);
choices.insert(
"attention_type".to_string(),
vec![
"standard".to_string(),
"grouped_query".to_string(),
"multi_query".to_string(),
"sparse".to_string(),
"sliding_window".to_string(),
],
);
choices.insert(
"normalization".to_string(),
vec![
"layer_norm".to_string(),
"rms_norm".to_string(),
"group_norm".to_string(),
],
);
choices.insert(
"position_encoding".to_string(),
vec![
"absolute".to_string(),
"relative".to_string(),
"rotary".to_string(),
"alibi".to_string(),
],
);
Self {
dimensions,
choices,
constraints: vec![
ArchitectureConstraint::DivisibilityConstraint {
dimension: "hidden_size".to_string(),
divisor: "num_heads".to_string(),
},
ArchitectureConstraint::RatioConstraint {
numerator: "intermediate_size".to_string(),
denominator: "hidden_size".to_string(),
min_ratio: 2.0,
max_ratio: 8.0,
},
],
}
}
pub fn vision_transformer_space() -> Self {
let mut dimensions = HashMap::new();
let mut choices = HashMap::new();
dimensions.insert("num_layers".to_string(), DimensionRange::new(6, 24, 1));
dimensions.insert(
"hidden_size".to_string(),
DimensionRange::new(384, 1536, 64),
);
dimensions.insert("num_heads".to_string(), DimensionRange::new(6, 24, 2));
dimensions.insert("patch_size".to_string(), DimensionRange::new(8, 32, 4));
dimensions.insert("image_size".to_string(), DimensionRange::new(224, 512, 32));
choices.insert(
"pooling".to_string(),
vec![
"cls_token".to_string(),
"gap".to_string(),
"map".to_string(),
],
);
Self {
dimensions,
choices,
constraints: vec![
ArchitectureConstraint::DivisibilityConstraint {
dimension: "hidden_size".to_string(),
divisor: "num_heads".to_string(),
},
ArchitectureConstraint::DivisibilityConstraint {
dimension: "image_size".to_string(),
divisor: "patch_size".to_string(),
},
],
}
}
pub fn validate_architecture(&self, architecture: &Architecture) -> Result<()> {
for constraint in &self.constraints {
constraint.validate(architecture)?;
}
Ok(())
}
}
/// Inclusive integer range [min, max] walked in increments of `step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DimensionRange {
/// Smallest allowed value (inclusive).
pub min: i32,
/// Largest allowed value (inclusive).
pub max: i32,
/// Grid spacing; valid values are min, min+step, min+2*step, ...
pub step: i32,
}
impl DimensionRange {
    /// Builds a stepped inclusive range `[min, max]`.
    pub fn new(min: i32, max: i32, step: i32) -> Self {
        Self { min, max, step }
    }
    /// Draws a uniformly random value on the step grid.
    /// Integer division floors the step count, so samples never exceed `max`.
    #[allow(deprecated)]
    pub fn sample(&self, rng: &mut impl Rng) -> i32 {
        let step_count = (self.max - self.min) / self.step + 1;
        let picked = rng.random_range(0..step_count);
        self.min + picked * self.step
    }
    /// True when `value` lies within the bounds and sits exactly on the grid.
    pub fn validate(&self, value: i32) -> bool {
        (self.min..=self.max).contains(&value) && (value - self.min) % self.step == 0
    }
}
/// A validity rule an [`Architecture`] must satisfy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArchitectureConstraint {
/// `dimension` must be an exact multiple of `divisor` (both dimension names).
DivisibilityConstraint { dimension: String, divisor: String },
/// `numerator / denominator` must lie within [min_ratio, max_ratio].
RatioConstraint {
numerator: String,
denominator: String,
min_ratio: f32,
max_ratio: f32,
},
/// Estimated parameter count must lie within the optional bounds.
ParameterConstraint {
min_params: Option<usize>,
max_params: Option<usize>,
},
/// Named constraint with no built-in check; validation always errors.
CustomConstraint { name: String, description: String },
}
impl ArchitectureConstraint {
    /// Validates `architecture` against this constraint.
    ///
    /// Returns `Ok(())` when the constraint holds, otherwise an
    /// `invalid_input` error describing the violation (or a missing
    /// dimension). `CustomConstraint` always errors since no check is wired
    /// up for it.
    fn validate(&self, architecture: &Architecture) -> Result<()> {
        match self {
            ArchitectureConstraint::DivisibilityConstraint { dimension, divisor } => {
                let dim_val = architecture
                    .dimensions
                    .get(dimension)
                    .ok_or_else(|| invalid_input(format!("Missing dimension: {}", dimension)))?;
                let div_val = architecture
                    .dimensions
                    .get(divisor)
                    .ok_or_else(|| invalid_input(format!("Missing divisor: {}", divisor)))?;
                // Guard: integer `%` by zero would panic the whole search.
                if *div_val == 0 {
                    return Err(invalid_input(format!(
                        "Divisor {} is zero; cannot check divisibility of {}",
                        divisor, dimension
                    )));
                }
                if dim_val % div_val != 0 {
                    return Err(invalid_input(format!(
                        "{} ({}) must be divisible by {} ({})",
                        dimension, dim_val, divisor, div_val
                    )));
                }
            },
            ArchitectureConstraint::RatioConstraint {
                numerator,
                denominator,
                min_ratio,
                max_ratio,
            } => {
                let num_val = *architecture
                    .dimensions
                    .get(numerator)
                    .ok_or_else(|| invalid_input(format!("Missing numerator: {}", numerator)))?
                    as f32;
                let den_val =
                    *architecture.dimensions.get(denominator).ok_or_else(|| {
                        invalid_input(format!("Missing denominator: {}", denominator))
                    })? as f32;
                let ratio = num_val / den_val;
                // A non-finite ratio must fail: 0/0 yields NaN, which compares
                // false against both bounds and would otherwise slip through.
                if !ratio.is_finite() || ratio < *min_ratio || ratio > *max_ratio {
                    return Err(invalid_input(format!(
                        "Ratio {} / {} ({:.2}) must be between {:.2} and {:.2}",
                        numerator, denominator, ratio, min_ratio, max_ratio
                    )));
                }
            },
            ArchitectureConstraint::ParameterConstraint {
                min_params,
                max_params,
            } => {
                let params = architecture.estimate_parameters();
                if let Some(min) = min_params {
                    if params < *min {
                        return Err(invalid_input(format!(
                            "Architecture has {} parameters, minimum required: {}",
                            params, min
                        )));
                    }
                }
                if let Some(max) = max_params {
                    if params > *max {
                        return Err(invalid_input(format!(
                            "Architecture has {} parameters, maximum allowed: {}",
                            params, max
                        )));
                    }
                }
            },
            ArchitectureConstraint::CustomConstraint { name, .. } => {
                // No pluggable check exists yet; fail loudly rather than
                // silently accepting an unvalidated architecture.
                return Err(invalid_input(format!(
                    "Custom constraint '{}' not implemented",
                    name
                )));
            },
        }
        Ok(())
    }
}
/// A weighted objective contributing to the scalar fitness of an architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationObjective {
/// Maximize (simulated) task accuracy.
Accuracy { weight: f32 },
/// Minimize latency (scored via an inverse transform during evaluation).
Latency { weight: f32 },
/// Minimize memory footprint.
Memory { weight: f32 },
/// Minimize energy use.
Energy { weight: f32 },
/// Minimize parameter count.
ModelSize { weight: f32 },
/// Joint size/latency efficiency score.
Efficiency { weight: f32 },
/// User-defined objective; evaluation assigns it a fixed placeholder score.
Custom { name: String, weight: f32 },
}
impl OptimizationObjective {
pub fn weight(&self) -> f32 {
match self {
OptimizationObjective::Accuracy { weight }
| OptimizationObjective::Latency { weight }
| OptimizationObjective::Memory { weight }
| OptimizationObjective::Energy { weight }
| OptimizationObjective::ModelSize { weight }
| OptimizationObjective::Efficiency { weight }
| OptimizationObjective::Custom { weight, .. } => *weight,
}
}
pub fn name(&self) -> &str {
match self {
OptimizationObjective::Accuracy { .. } => "accuracy",
OptimizationObjective::Latency { .. } => "latency",
OptimizationObjective::Memory { .. } => "memory",
OptimizationObjective::Energy { .. } => "energy",
OptimizationObjective::ModelSize { .. } => "model_size",
OptimizationObjective::Efficiency { .. } => "efficiency",
OptimizationObjective::Custom { name, .. } => name,
}
}
}
/// Hardware budget for candidate architectures.
///
/// NOTE(review): stored in `NASConfig` but not enforced by the search loops
/// visible in this file — confirm where these limits are applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareConstraints {
/// Upper bound on memory, in GiB.
pub max_memory_gb: Option<f32>,
/// Upper bound on inference latency, in milliseconds.
pub max_latency_ms: Option<f32>,
/// Target deployment platform.
pub platform: HardwarePlatform,
/// Upper bound on energy per inference, in millijoules.
pub max_energy_mj: Option<f32>,
/// Lower bound on throughput (units defined by the caller).
pub min_throughput: Option<f32>,
}
/// Deployment target for hardware-aware search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HardwarePlatform {
CPU,
GPU {
/// Device memory in GiB.
memory_gb: f32,
},
TPU,
Mobile,
Edge,
/// Arbitrary platform described by named numeric specs.
Custom {
name: String,
specs: HashMap<String, f32>,
},
}
/// A concrete candidate: one value per numeric dimension, one option per
/// categorical choice, plus lineage metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Architecture {
/// Numeric hyperparameter assignments (e.g. "hidden_size" -> 768).
pub dimensions: HashMap<String, i32>,
/// Categorical hyperparameter assignments (e.g. "activation" -> "gelu").
pub choices: HashMap<String, String>,
/// Identity and lineage (id, generation, parents, creation time).
pub metadata: ArchitectureMetadata,
}
impl Default for Architecture {
/// Same as [`Architecture::new`]: empty maps, fresh metadata.
fn default() -> Self {
Self::new()
}
}
impl Architecture {
/// Creates an empty architecture with fresh metadata (new id, generation 0).
pub fn new() -> Self {
Self {
dimensions: HashMap::new(),
choices: HashMap::new(),
metadata: ArchitectureMetadata::default(),
}
}
/// Analytic parameter-count estimate for a transformer built from these
/// dimensions. Missing entries fall back to defaults (hidden 768, 12 layers,
/// vocab 32000, intermediate = 4 * hidden). Counts input embeddings,
/// 4*h^2 attention projections per layer, two FFN matrices per layer, and
/// two norms per layer; biases and an output head are not included.
pub fn estimate_parameters(&self) -> usize {
let hidden_size = *self.dimensions.get("hidden_size").unwrap_or(&768) as f64;
let num_layers = *self.dimensions.get("num_layers").unwrap_or(&12) as f64;
let vocab_size = *self.dimensions.get("vocab_size").unwrap_or(&32000) as f64;
let intermediate_size =
*self.dimensions.get("intermediate_size").unwrap_or(&(hidden_size as i32 * 4)) as f64;
let embedding_params = vocab_size * hidden_size;
let attention_params = num_layers * (4.0 * hidden_size * hidden_size);
let ffn_params = num_layers * (2.0 * hidden_size * intermediate_size);
let norm_params = num_layers * 2.0 * hidden_size;
(embedding_params + attention_params + ffn_params + norm_params) as usize
}
/// Rough memory footprint in MiB: 4 bytes per parameter (f32) times a 1.5x
/// overhead factor. A heuristic, not a measurement.
pub fn estimate_memory_mb(&self) -> f32 {
let params = self.estimate_parameters() as f32;
(params * 4.0 * 1.5) / (1024.0 * 1024.0)
}
/// Synthetic latency proxy (arbitrary units): layers * hidden^1.5 / 1e6.
/// Not calibrated to any real hardware.
pub fn estimate_latency(&self) -> f32 {
let num_layers = *self.dimensions.get("num_layers").unwrap_or(&12) as f32;
let hidden_size = *self.dimensions.get("hidden_size").unwrap_or(&768) as f32;
num_layers * hidden_size.powf(1.5) / 1000000.0
}
// NOTE(review): HashMap iteration order is unspecified, so the order in
// which this (and `mutate` below) consumes RNG draws can differ between
// runs even with a fixed seed.
/// Samples a random architecture: one value per dimension range and one
/// option per non-empty categorical choice in `search_space`.
#[allow(deprecated)]
pub fn random(search_space: &SearchSpace, rng: &mut impl Rng) -> Self {
let mut architecture = Architecture::new();
for (name, range) in &search_space.dimensions {
architecture.dimensions.insert(name.clone(), range.sample(rng));
}
for (name, options) in &search_space.choices {
if !options.is_empty() {
let choice = options[rng.random_range(0..options.len())].clone();
architecture.choices.insert(name.clone(), choice);
}
}
architecture
}
/// Independently resamples each dimension/choice with probability
/// `mutation_rate` and increments the generation counter. Entries with no
/// corresponding range/options in `search_space` are left untouched.
#[allow(deprecated)]
pub fn mutate(&mut self, search_space: &SearchSpace, mutation_rate: f32, rng: &mut impl Rng) {
for (name, value) in &mut self.dimensions {
if rng.random::<f32>() < mutation_rate {
if let Some(range) = search_space.dimensions.get(name) {
*value = range.sample(rng);
}
}
}
for (name, value) in &mut self.choices {
if rng.random::<f32>() < mutation_rate {
if let Some(options) = search_space.choices.get(name) {
if !options.is_empty() {
*value = options[rng.random_range(0..options.len())].clone();
}
}
}
}
self.metadata.generation += 1;
}
/// Uniform crossover keyed on `self`'s entries: each gene comes from `self`
/// or `other` with probability 0.5, falling back to `self`'s value when
/// `other` lacks the key. Keys present only in `other` are NOT inherited.
/// The child's generation is max(parents) + 1.
#[allow(deprecated)]
pub fn crossover(&self, other: &Architecture, rng: &mut impl Rng) -> Architecture {
let mut child = Architecture::new();
for name in self.dimensions.keys() {
let value = if rng.random::<f32>() < 0.5 {
self.dimensions[name]
} else {
other.dimensions.get(name).copied().unwrap_or(self.dimensions[name])
};
child.dimensions.insert(name.clone(), value);
}
for name in self.choices.keys() {
let value = if rng.random::<f32>() < 0.5 {
self.choices[name].clone()
} else {
other.choices.get(name).cloned().unwrap_or_else(|| self.choices[name].clone())
};
child.choices.insert(name.clone(), value);
}
child.metadata.generation =
std::cmp::max(self.metadata.generation, other.metadata.generation) + 1;
child
}
}
/// Identity and lineage bookkeeping for an [`Architecture`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureMetadata {
/// Unique identifier (UUID v4 string).
pub id: String,
/// Generation counter, bumped by mutation/crossover.
pub generation: u32,
// NOTE(review): `parents` is never populated by `crossover` in this file —
// confirm whether lineage tracking is intended to be filled elsewhere.
pub parents: Vec<String>,
/// Wall-clock creation time.
pub created_at: std::time::SystemTime,
}
impl Default for ArchitectureMetadata {
/// Fresh metadata: random UUID, generation 0, no parents, created now.
fn default() -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
generation: 0,
parents: Vec::new(),
created_at: std::time::SystemTime::now(),
}
}
}
/// An architecture together with its measured/simulated objective scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureEvaluation {
/// The evaluated candidate.
pub architecture: Architecture,
/// Per-objective scores, keyed by objective name (higher is better).
pub metrics: HashMap<String, f32>,
/// Weighted sum of metrics (the scalar optimization target).
pub fitness: f32,
/// Wall-clock time spent evaluating this candidate.
pub evaluation_time: std::time::Duration,
/// Free-form annotations.
pub info: HashMap<String, String>,
}
impl ArchitectureEvaluation {
    /// Wraps a candidate with empty metrics, zero fitness, and zero timing;
    /// the searcher fills these in during evaluation.
    pub fn new(architecture: Architecture) -> Self {
        let metrics = HashMap::new();
        let info = HashMap::new();
        Self {
            architecture,
            metrics,
            fitness: 0.0,
            evaluation_time: std::time::Duration::ZERO,
            info,
        }
    }
}
/// Stateful NAS driver: owns the config, population, incumbent, history,
/// and the RNG used by all strategies.
pub struct NeuralArchitectureSearcher {
// Run configuration (strategy, budgets, objectives, seed).
config: NASConfig,
// Working copy of the search space from the config.
search_space: SearchSpace,
// Current population (evolutionary/NSGA-II strategies only).
population: Vec<ArchitectureEvaluation>,
// Best evaluation seen so far, if any.
best_architecture: Option<ArchitectureEvaluation>,
// Every evaluation performed by non-generational strategies.
evaluation_history: Vec<ArchitectureEvaluation>,
// Seeded RNG shared by sampling, mutation, and crossover.
rng: StdRng,
}
impl NeuralArchitectureSearcher {
    /// Builds a searcher from `config`. The RNG is seeded from `config.seed`
    /// when present (reproducible runs), otherwise from system entropy.
    pub fn new(config: NASConfig) -> Result<Self> {
        let rng = if let Some(seed) = config.seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(random::<u64>())
        };
        Ok(Self {
            // Keep a working copy of the space; `config` itself is moved in.
            search_space: config.search_space.clone(),
            config,
            population: Vec::new(),
            best_architecture: None,
            evaluation_history: Vec::new(),
            rng,
        })
    }
    /// Dispatches to the strategy selected in the configuration and returns
    /// the best architecture found.
    pub fn search(&mut self) -> Result<ArchitectureEvaluation> {
        match self.config.strategy {
            SearchStrategy::Random => self.random_search(),
            SearchStrategy::Evolutionary => self.evolutionary_search(),
            SearchStrategy::ReinforcementLearning => self.rl_search(),
            SearchStrategy::DARTS => self.darts_search(),
            SearchStrategy::Progressive => self.progressive_search(),
            SearchStrategy::BayesianOptimization => self.bayesian_search(),
            SearchStrategy::NSGA2 => self.nsga2_search(),
        }
    }
    /// Shared epilogue for every strategy: clone out the incumbent or report
    /// that the search produced no architecture.
    fn best_or_error(&self) -> Result<ArchitectureEvaluation> {
        self.best_architecture.clone().ok_or_else(|| {
            TrustformersError::invalid_config("No architecture found during search".to_string())
        })
    }
    /// Baseline: evaluate `max_evaluations` uniformly random candidates.
    fn random_search(&mut self) -> Result<ArchitectureEvaluation> {
        for i in 0..self.config.max_evaluations {
            let architecture = Architecture::random(&self.search_space, &mut self.rng);
            let evaluation = self.evaluate_architecture(architecture)?;
            self.update_best(&evaluation);
            self.evaluation_history.push(evaluation);
            // Periodic progress report (also fires at i == 0).
            if i % 100 == 0 {
                if let Some(ref best) = self.best_architecture {
                    println!(
                        "Random search iteration {}, best fitness: {:.4}",
                        i, best.fitness
                    );
                }
            }
        }
        self.best_or_error()
    }
    /// Generational GA: tournament selection, reciprocal uniform crossover,
    /// 10% per-gene mutation, elitist environmental selection.
    #[allow(deprecated)]
    fn evolutionary_search(&mut self) -> Result<ArchitectureEvaluation> {
        self.initialize_population()?;
        for generation in 0..self.config.generations {
            let parents = self.select_parents();
            let mut offspring = Vec::new();
            // Each parent pair yields two reciprocal children.
            for _ in 0..self.config.population_size / 2 {
                let parent1_idx = self.rng.random_range(0..parents.len());
                let parent2_idx = self.rng.random_range(0..parents.len());
                let parent1 = &parents[parent1_idx];
                let parent2 = &parents[parent2_idx];
                let mut child1 =
                    parent1.architecture.crossover(&parent2.architecture, &mut self.rng);
                let mut child2 =
                    parent2.architecture.crossover(&parent1.architecture, &mut self.rng);
                child1.mutate(&self.search_space, 0.1, &mut self.rng);
                child2.mutate(&self.search_space, 0.1, &mut self.rng);
                offspring.push(self.evaluate_architecture(child1)?);
                offspring.push(self.evaluate_architecture(child2)?);
            }
            self.environmental_selection(offspring)?;
            println!(
                "Generation {}, best fitness: {:.4}",
                generation,
                self.best_architecture.as_ref().map_or(0.0, |a| a.fitness)
            );
        }
        self.best_or_error()
    }
    /// Placeholder RL strategy — currently identical to random sampling.
    fn rl_search(&mut self) -> Result<ArchitectureEvaluation> {
        for i in 0..self.config.max_evaluations {
            let architecture = Architecture::random(&self.search_space, &mut self.rng);
            let evaluation = self.evaluate_architecture(architecture)?;
            self.update_best(&evaluation);
            self.evaluation_history.push(evaluation);
            if i % 100 == 0 {
                println!(
                    "RL search iteration {}, best fitness: {:.4}",
                    i,
                    self.best_architecture.as_ref().map_or(0.0, |a| a.fitness)
                );
            }
        }
        self.best_or_error()
    }
    /// Placeholder DARTS strategy — currently identical to random sampling.
    fn darts_search(&mut self) -> Result<ArchitectureEvaluation> {
        for i in 0..self.config.max_evaluations {
            let architecture = Architecture::random(&self.search_space, &mut self.rng);
            let evaluation = self.evaluate_architecture(architecture)?;
            self.update_best(&evaluation);
            self.evaluation_history.push(evaluation);
            if i % 100 == 0 {
                println!(
                    "DARTS iteration {}, best fitness: {:.4}",
                    i,
                    self.best_architecture.as_ref().map_or(0.0, |a| a.fitness)
                );
            }
        }
        self.best_or_error()
    }
    /// Staged search: dimensions are scaled toward their minimum early on and
    /// approach the full range as the complexity factor grows per stage.
    /// Integer division means up to `complexity_stages - 1` evaluations of the
    /// budget are unused.
    fn progressive_search(&mut self) -> Result<ArchitectureEvaluation> {
        let complexity_stages = 5;
        let evaluations_per_stage = self.config.max_evaluations / complexity_stages;
        for stage in 0..complexity_stages {
            let complexity_factor = (stage + 1) as f32 / complexity_stages as f32;
            for i in 0..evaluations_per_stage {
                let mut architecture = Architecture::random(&self.search_space, &mut self.rng);
                // NOTE(review): scaling can push a value off its step grid or
                // break divisibility constraints, in which case
                // evaluate_architecture errors and aborts the whole search —
                // confirm this is intended.
                for (name, value) in &mut architecture.dimensions {
                    if let Some(range) = self.search_space.dimensions.get(name) {
                        let scaled =
                            range.min + (((*value - range.min) as f32 * complexity_factor) as i32);
                        *value = scaled.clamp(range.min, range.max);
                    }
                }
                let evaluation = self.evaluate_architecture(architecture)?;
                self.update_best(&evaluation);
                self.evaluation_history.push(evaluation);
                if i % 50 == 0 {
                    println!(
                        "Progressive search stage {}, iteration {}, best fitness: {:.4}",
                        stage,
                        i,
                        self.best_architecture.as_ref().map_or(0.0, |a| a.fitness)
                    );
                }
            }
        }
        self.best_or_error()
    }
    /// Simplified "Bayesian" strategy: 10 random warm-up evaluations, then
    /// mutations (rate 0.2) of the current incumbent.
    fn bayesian_search(&mut self) -> Result<ArchitectureEvaluation> {
        for i in 0..self.config.max_evaluations {
            let architecture = if i < 10 {
                Architecture::random(&self.search_space, &mut self.rng)
            } else {
                let best = self.best_architecture.as_ref().ok_or_else(|| {
                    TrustformersError::invalid_config(
                        "No best architecture available for guidance".to_string(),
                    )
                })?;
                let mut arch = best.architecture.clone();
                arch.mutate(&self.search_space, 0.2, &mut self.rng);
                arch
            };
            let evaluation = self.evaluate_architecture(architecture)?;
            self.update_best(&evaluation);
            self.evaluation_history.push(evaluation);
            if i % 100 == 0 {
                println!(
                    "Bayesian search iteration {}, best fitness: {:.4}",
                    i,
                    self.best_architecture.as_ref().map_or(0.0, |a| a.fitness)
                );
            }
        }
        self.best_or_error()
    }
    /// NSGA-II-style loop; selection currently delegates to the scalar
    /// elitist selection (see `nsga2_selection`).
    #[allow(deprecated)]
    fn nsga2_search(&mut self) -> Result<ArchitectureEvaluation> {
        self.initialize_population()?;
        for generation in 0..self.config.generations {
            let parents = self.select_parents();
            let mut offspring = Vec::new();
            for _ in 0..self.config.population_size {
                let parent1_idx = self.rng.random_range(0..parents.len());
                let parent2_idx = self.rng.random_range(0..parents.len());
                let parent1 = &parents[parent1_idx];
                let parent2 = &parents[parent2_idx];
                let mut child =
                    parent1.architecture.crossover(&parent2.architecture, &mut self.rng);
                child.mutate(&self.search_space, 0.1, &mut self.rng);
                offspring.push(self.evaluate_architecture(child)?);
            }
            self.nsga2_selection(offspring)?;
            println!(
                "NSGA-II generation {}, population size: {}",
                generation,
                self.population.len()
            );
        }
        self.best_or_error()
    }
    /// Seeds the population with random evaluated candidates and records the
    /// initial incumbent.
    fn initialize_population(&mut self) -> Result<()> {
        self.population.clear();
        for _ in 0..self.config.population_size {
            let architecture = Architecture::random(&self.search_space, &mut self.rng);
            let evaluation = self.evaluate_architecture(architecture)?;
            self.population.push(evaluation);
        }
        // Treat NaN fitness as equal instead of panicking, matching the
        // comparison policy used by the selection routines below.
        if let Some(best) = self.population.iter().max_by(|a, b| {
            a.fitness.partial_cmp(&b.fitness).unwrap_or(std::cmp::Ordering::Equal)
        }) {
            self.best_architecture = Some(best.clone());
        }
        Ok(())
    }
    /// Tournament selection (size 3): draws `population_size` parents with
    /// replacement, each the fittest of a random tournament.
    #[allow(deprecated)]
    fn select_parents(&mut self) -> Vec<ArchitectureEvaluation> {
        let tournament_size = 3;
        let mut parents = Vec::new();
        for _ in 0..self.config.population_size {
            let mut tournament = Vec::new();
            for _ in 0..tournament_size {
                let idx = self.rng.random_range(0..self.population.len());
                tournament.push(self.population[idx].clone());
            }
            let winner = tournament
                .into_iter()
                .max_by(|a, b| {
                    a.fitness.partial_cmp(&b.fitness).unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or_else(|| {
                    // Unreachable in practice: the tournament is never empty.
                    self.population[0].clone()
                });
            parents.push(winner);
        }
        parents
    }
    /// Elitist (mu + lambda) selection: merge population and offspring, keep
    /// the `population_size` fittest, and update the incumbent.
    fn environmental_selection(&mut self, offspring: Vec<ArchitectureEvaluation>) -> Result<()> {
        let mut combined = self.population.clone();
        combined.extend(offspring);
        combined
            .sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap_or(std::cmp::Ordering::Equal));
        self.population = combined.into_iter().take(self.config.population_size).collect();
        if let Some(best) = self.population.first() {
            let should_update = self
                .best_architecture
                .as_ref()
                .is_none_or(|current| best.fitness > current.fitness);
            if should_update {
                self.best_architecture = Some(best.clone());
            }
        }
        Ok(())
    }
    /// NSGA-II selection stub — currently the scalar elitist selection; no
    /// non-dominated sorting or crowding distance is implemented yet.
    fn nsga2_selection(&mut self, offspring: Vec<ArchitectureEvaluation>) -> Result<()> {
        self.environmental_selection(offspring)
    }
    /// Validates and scores a candidate. Each configured objective produces a
    /// metric in (0, 1] (higher is better) from the analytic estimators on
    /// `Architecture`; fitness is the weighted sum of those metrics.
    fn evaluate_architecture(&self, architecture: Architecture) -> Result<ArchitectureEvaluation> {
        let start_time = std::time::Instant::now();
        self.search_space.validate_architecture(&architecture)?;
        let mut evaluation = ArchitectureEvaluation::new(architecture);
        for objective in &self.config.objectives {
            let (metric_name, metric_value) = match objective {
                OptimizationObjective::Accuracy { .. } => {
                    // Simulated accuracy: small bonus for capacity, small
                    // penalty for excessive size, clamped to [0, 1].
                    let complexity =
                        evaluation.architecture.estimate_parameters() as f32 / 1000000.0;
                    let accuracy =
                        0.85 + (complexity / 100.0).min(0.1) - (complexity / 1000.0).max(0.0);
                    ("accuracy", accuracy.clamp(0.0, 1.0))
                },
                OptimizationObjective::Latency { .. } => {
                    // Inverse transform: lower latency -> score closer to 1.
                    let latency = evaluation.architecture.estimate_latency();
                    ("latency", 1.0 / (1.0 + latency))
                },
                OptimizationObjective::Memory { .. } => {
                    let memory = evaluation.architecture.estimate_memory_mb();
                    ("memory", 1.0 / (1.0 + memory / 1000.0))
                },
                OptimizationObjective::ModelSize { .. } => {
                    let params = evaluation.architecture.estimate_parameters() as f32;
                    ("model_size", 1.0 / (1.0 + params / 1000000.0))
                },
                OptimizationObjective::Efficiency { .. } => {
                    let params = evaluation.architecture.estimate_parameters() as f32;
                    let latency = evaluation.architecture.estimate_latency();
                    ("efficiency", 1.0 / (1.0 + params / 1000000.0 + latency))
                },
                OptimizationObjective::Energy { .. } => {
                    // Energy proxy: half the latency estimate.
                    let energy = evaluation.architecture.estimate_latency() * 0.5;
                    ("energy", 1.0 / (1.0 + energy))
                },
                OptimizationObjective::Custom { name, .. } => {
                    // No pluggable evaluator yet; fixed neutral score.
                    (name.as_str(), 0.5)
                },
            };
            evaluation.metrics.insert(metric_name.to_string(), metric_value);
        }
        evaluation.fitness = self
            .config
            .objectives
            .iter()
            .map(|obj| {
                let metric_value = evaluation.metrics.get(obj.name()).unwrap_or(&0.0);
                obj.weight() * metric_value
            })
            .sum();
        evaluation.evaluation_time = start_time.elapsed();
        Ok(evaluation)
    }
    /// Replaces the incumbent when `evaluation` has strictly higher fitness
    /// (or when there is no incumbent yet).
    fn update_best(&mut self, evaluation: &ArchitectureEvaluation) {
        let should_update = self
            .best_architecture
            .as_ref()
            .is_none_or(|current| evaluation.fitness > current.fitness);
        if should_update {
            self.best_architecture = Some(evaluation.clone());
        }
    }
    /// Returns the best evaluation found so far, if any.
    pub fn best_architecture(&self) -> Option<&ArchitectureEvaluation> {
        self.best_architecture.as_ref()
    }
    /// Returns all evaluations recorded by non-generational strategies.
    pub fn evaluation_history(&self) -> &[ArchitectureEvaluation] {
        &self.evaluation_history
    }
    /// Summary statistics (count, best/mean/std of fitness) over the
    /// evaluation history; all zeros when the history is empty.
    pub fn get_statistics(&self) -> SearchStatistics {
        let mut stats = SearchStatistics::default();
        if !self.evaluation_history.is_empty() {
            let fitnesses: Vec<f32> = self.evaluation_history.iter().map(|e| e.fitness).collect();
            stats.num_evaluations = fitnesses.len();
            stats.best_fitness = fitnesses.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            stats.average_fitness = fitnesses.iter().sum::<f32>() / fitnesses.len() as f32;
            stats.fitness_std = {
                // Population (biased) standard deviation.
                let variance =
                    fitnesses.iter().map(|f| (f - stats.average_fitness).powi(2)).sum::<f32>()
                        / fitnesses.len() as f32;
                variance.sqrt()
            };
        }
        stats
    }
}
/// Summary of a completed (or in-progress) search run.
#[derive(Debug, Clone, Default)]
pub struct SearchStatistics {
/// Number of evaluations recorded in the history.
pub num_evaluations: usize,
/// Highest fitness observed.
pub best_fitness: f32,
/// Mean fitness across the history.
pub average_fitness: f32,
/// Population standard deviation of fitness.
pub fitness_std: f32,
// NOTE(review): never set by `get_statistics` in this file — presumably
// reserved for convergence tracking; confirm.
pub convergence_generation: Option<usize>,
}
impl fmt::Display for SearchStatistics {
    /// Human-readable one-line summary of the run statistics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "SearchStatistics {{ evaluations: {n}, best: {b:.4}, avg: {a:.4}, std: {s:.4} }}",
            n = self.num_evaluations,
            b = self.best_fitness,
            a = self.average_fitness,
            s = self.fitness_std
        )
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Default config should pick evolutionary search with the documented budgets.
#[test]
fn test_nas_config_default() {
let config = NASConfig::default();
assert_eq!(config.max_evaluations, 1000);
assert_eq!(config.population_size, 50);
assert!(matches!(config.strategy, SearchStrategy::Evolutionary));
}
// The transformer space must expose its core dimensions and choices.
#[test]
fn test_transformer_search_space() {
let space = SearchSpace::transformer_space();
assert!(space.dimensions.contains_key("num_layers"));
assert!(space.dimensions.contains_key("hidden_size"));
assert!(space.choices.contains_key("activation"));
}
// Random sampling fills every dimension and choice.
#[test]
fn test_architecture_random_generation() {
let space = SearchSpace::transformer_space();
let mut rng = StdRng::seed_from_u64(42);
let arch = Architecture::random(&space, &mut rng);
assert!(!arch.dimensions.is_empty());
assert!(!arch.choices.is_empty());
}
// BERT-base-like dimensions should estimate to >100M parameters.
#[test]
fn test_architecture_parameter_estimation() {
let mut arch = Architecture::new();
arch.dimensions.insert("hidden_size".to_string(), 768);
arch.dimensions.insert("num_layers".to_string(), 12);
arch.dimensions.insert("vocab_size".to_string(), 32000);
let params = arch.estimate_parameters();
assert!(params > 100_000_000); }
// A valid architecture passes; one breaking head-divisibility fails.
#[test]
fn test_architecture_constraint_validation() {
let space = SearchSpace::transformer_space();
let mut arch = Architecture::new();
arch.dimensions.insert("hidden_size".to_string(), 768);
arch.dimensions.insert("num_heads".to_string(), 12);
arch.dimensions.insert("intermediate_size".to_string(), 3072);
assert!(space.validate_architecture(&arch).is_ok());
arch.dimensions.insert("hidden_size".to_string(), 777); assert!(space.validate_architecture(&arch).is_err());
}
// With mutation rate 1.0 every gene is resampled, so at least one dimension
// is expected to change (resampling could reproduce the same value, but not
// for all dimensions with this seed).
#[test]
fn test_architecture_mutation() {
let space = SearchSpace::transformer_space();
let mut rng = StdRng::seed_from_u64(42);
let mut arch = Architecture::random(&space, &mut rng);
let original = arch.clone();
arch.mutate(&space, 1.0, &mut rng);
let mut differences = 0;
for (key, value) in &arch.dimensions {
if original.dimensions.get(key) != Some(value) {
differences += 1;
}
}
assert!(differences > 0);
}
// Construction from a default config must succeed.
#[test]
fn test_neural_architecture_searcher_creation() {
let config = NASConfig::default();
let searcher = NeuralArchitectureSearcher::new(config);
assert!(searcher.is_ok());
}
// Validation accepts only on-grid values; samples are always on-grid.
#[test]
fn test_dimension_range() {
let range = DimensionRange::new(1, 10, 2);
assert!(range.validate(1));
assert!(range.validate(3));
assert!(range.validate(9));
assert!(!range.validate(2));
assert!(!range.validate(11));
let mut rng = StdRng::seed_from_u64(42);
let sample = range.sample(&mut rng);
assert!(range.validate(sample));
}
// Objectives expose their weight and stable metric name.
#[test]
fn test_optimization_objectives() {
let obj1 = OptimizationObjective::Accuracy { weight: 0.7 };
let obj2 = OptimizationObjective::Latency { weight: 0.3 };
assert_eq!(obj1.weight(), 0.7);
assert_eq!(obj2.weight(), 0.3);
assert_eq!(obj1.name(), "accuracy");
assert_eq!(obj2.name(), "latency");
}
// Crossover keyed on parent1 preserves gene counts and bumps generation.
#[test]
fn test_architecture_crossover() {
let space = SearchSpace::transformer_space();
let mut rng = StdRng::seed_from_u64(42);
let parent1 = Architecture::random(&space, &mut rng);
let parent2 = Architecture::random(&space, &mut rng);
let child = parent1.crossover(&parent2, &mut rng);
assert_eq!(child.dimensions.len(), parent1.dimensions.len());
assert_eq!(child.choices.len(), parent1.choices.len());
assert_eq!(
child.metadata.generation,
std::cmp::max(parent1.metadata.generation, parent2.metadata.generation) + 1
);
}
}