use crate::neural_architecture_search::types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// A single candidate neural-network architecture explored during search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Architecture {
/// Unique identifier, assigned randomly at construction (see `Architecture::new`).
pub id: Uuid,
/// Per-layer configuration of the network.
pub layers: Vec<LayerConfig>,
/// Settings that apply to the whole network rather than any single layer.
pub global_config: GlobalArchConfig,
/// Evaluation results; `None` until the architecture has been evaluated.
pub performance: Option<PerformanceMetrics>,
/// Search generation this candidate belongs to; `new` initializes it to 0.
pub generation: usize,
/// IDs of parent architectures; empty for candidates created from scratch.
pub parents: Vec<Uuid>,
}
/// Configuration of one layer of a candidate architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerConfig {
/// Kind of layer (Linear, Conv1D, LSTM, GRU, Transformer, ...) with its structural parameters.
pub layer_type: LayerType,
/// Activation function applied by this layer.
pub activation: ActivationType,
/// Normalization scheme applied by this layer.
pub normalization: NormalizationType,
/// Skip/residual connection pattern around this layer.
pub skip_pattern: SkipPattern,
/// Free-form numeric hyperparameters keyed by name.
/// NOTE(review): the expected keys presumably depend on `layer_type` — confirm with the consumer.
pub hyperparameters: HashMap<String, f64>,
}
/// Network-wide settings shared by all layers of an architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalArchConfig {
/// Dimensionality of the network's input.
pub input_dim: usize,
/// Dimensionality of the network's output.
pub output_dim: usize,
/// Base learning rate for training.
pub learning_rate: f64,
/// Optimizer family and its hyperparameters.
pub optimizer: OptimizerType,
/// Regularization settings (L1/L2, dropout, label smoothing, early stopping).
pub regularization: RegularizationConfig,
/// Training-loop settings (batch size, epochs, LR schedule, loss).
pub training_config: TrainingConfig,
}
/// Optimizer family, each variant carrying its own hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum OptimizerType {
/// Adam with moment-decay coefficients `beta1`/`beta2` and numerical-stability `eps`.
Adam { beta1: f64, beta2: f64, eps: f64 },
/// Adam with decoupled weight decay.
AdamW { beta1: f64, beta2: f64, eps: f64, weight_decay: f64 },
/// Stochastic gradient descent with momentum.
SGD { momentum: f64 },
/// RMSprop with smoothing constant `alpha` and numerical-stability `eps`.
RMSprop { alpha: f64, eps: f64 },
/// Lion optimizer with its two interpolation coefficients.
Lion { beta1: f64, beta2: f64 },
}
/// Regularization knobs applied during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationConfig {
/// Weight of the L1 penalty term.
pub l1_weight: f64,
/// Weight of the L2 penalty term.
pub l2_weight: f64,
/// Dropout probability. NOTE(review): presumably in [0, 1] — confirm with the trainer.
pub dropout_rate: f64,
/// Label-smoothing factor for classification-style losses.
pub label_smoothing: f64,
/// Number of epochs without improvement before early stopping triggers.
pub early_stopping_patience: usize,
}
/// Training-loop configuration for evaluating an architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Mini-batch size.
pub batch_size: usize,
/// Maximum number of training epochs.
pub epochs: usize,
/// Fraction of data held out for validation.
/// NOTE(review): presumably in (0, 1) — confirm with the data loader.
pub validation_split: f64,
/// Learning-rate schedule applied over the course of training.
pub lr_schedule: LRScheduleType,
/// Loss function optimized during training.
pub loss_function: LossFunction,
}
/// Learning-rate schedule, each variant carrying its own parameters.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum LRScheduleType {
/// Fixed learning rate throughout training.
Constant,
/// Multiply the LR by `gamma` every `step_size` epochs.
StepLR { step_size: usize, gamma: f64 },
/// Multiply the LR by `gamma` every epoch.
ExponentialLR { gamma: f64 },
/// Cosine annealing over a cycle of `t_max` epochs.
CosineAnnealingLR { t_max: usize },
/// Reduce the LR by `factor` after `patience` epochs without improvement.
ReduceLROnPlateau { factor: f64, patience: usize },
/// Linear warmup for `warmup_epochs`, then cosine decay.
WarmupCosine { warmup_epochs: usize },
}
/// Loss function used to train embedding models, with per-variant parameters.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum LossFunction {
/// Mean squared error.
MSE,
/// Cosine-similarity-based loss.
CosineSimilarity,
/// Triplet loss with the given margin.
TripletLoss { margin: f64 },
/// Contrastive loss with the given margin.
ContrastiveLoss { margin: f64 },
/// InfoNCE contrastive loss with the given softmax temperature.
InfoNCE { temperature: f64 },
/// ArcFace angular-margin loss with feature `scale` and angular `margin`.
ArcFace { scale: f64, margin: f64 },
}
/// Measured results of evaluating one architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
/// Overall quality score of the produced embeddings.
/// NOTE(review): scale/range not defined here — confirm with the evaluator.
pub embedding_quality: f64,
/// Final loss on the training set.
pub training_loss: f64,
/// Final loss on the validation split.
pub validation_loss: f64,
/// Average inference latency in milliseconds.
pub inference_latency_ms: f64,
/// Total number of trainable parameters.
pub model_size_params: usize,
/// Peak memory footprint in megabytes.
pub memory_usage_mb: f64,
/// Estimated floating-point operations per inference.
pub flops: u64,
/// Wall-clock training time in minutes.
pub training_time_minutes: f64,
/// Energy consumed during training/evaluation (units not defined here — confirm).
pub energy_consumption: f64,
/// Additional task-specific metric values keyed by metric name.
pub task_metrics: HashMap<String, f64>,
}
/// The space of options the architecture search samples from.
/// Not serialized — used only at search time (see `Default` impl for baseline values).
#[derive(Debug, Clone)]
pub struct ArchitectureSearchSpace {
/// Candidate layer kinds (with template structural parameters).
pub layer_types: Vec<LayerType>,
/// (min, max) number of layers. NOTE(review): inclusivity of bounds not defined here — confirm with the sampler.
pub depth_range: (usize, usize),
/// (min, max) layer width. NOTE(review): inclusivity of bounds not defined here — confirm with the sampler.
pub width_range: (usize, usize),
/// Candidate activation functions.
pub activations: Vec<ActivationType>,
/// Candidate normalization schemes.
pub normalizations: Vec<NormalizationType>,
/// Candidate attention mechanisms.
pub attention_types: Vec<AttentionType>,
/// Candidate skip-connection patterns.
pub skip_patterns: Vec<SkipPattern>,
/// Candidate embedding dimensionalities.
pub embedding_dims: Vec<usize>,
}
impl Architecture {
pub fn new(layers: Vec<LayerConfig>, global_config: GlobalArchConfig) -> Self {
Self {
id: Uuid::new_v4(),
layers,
global_config,
performance: None,
generation: 0,
parents: Vec::new(),
}
}
pub fn estimate_complexity(&self) -> usize {
self.layers.iter().map(|layer| {
match &layer.layer_type {
LayerType::Linear { input_dim, output_dim } => input_dim * output_dim,
LayerType::Conv1D { filters, kernel_size, .. } => filters * kernel_size,
LayerType::LSTM { hidden_size, num_layers } => hidden_size * num_layers * 4,
LayerType::GRU { hidden_size, num_layers } => hidden_size * num_layers * 3,
LayerType::Transformer { d_model, num_heads, num_layers } => {
d_model * d_model * num_heads * num_layers
},
_ => 1000, }
}).sum()
}
}
impl Default for ArchitectureSearchSpace {
fn default() -> Self {
Self {
layer_types: vec![
LayerType::Linear { input_dim: 512, output_dim: 256 },
LayerType::LSTM { hidden_size: 256, num_layers: 2 },
LayerType::Transformer { d_model: 256, num_heads: 8, num_layers: 4 },
],
depth_range: (2, 12),
width_range: (64, 1024),
activations: vec![ActivationType::ReLU, ActivationType::GELU, ActivationType::Swish],
normalizations: vec![NormalizationType::LayerNorm, NormalizationType::BatchNorm],
attention_types: vec![AttentionType::MultiHead { num_heads: 8 }],
skip_patterns: vec![SkipPattern::None, SkipPattern::Residual],
embedding_dims: vec![128, 256, 512, 768, 1024],
}
}
}