use std::collections::HashMap;
use crate::error::{TrainError, TrainResult};
use crate::hyperparameter::{HyperparamConfig, HyperparamValue};
/// Description of a single layer: which operation it applies, how wide it is,
/// and which activation it uses.
///
/// Every field is a plain `String`/`usize`, so the type also derives `Eq` and
/// `Hash`; this allows specs to be deduplicated in a `HashSet` or used as map
/// keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LayerSpec {
    /// Name of the layer operation.
    pub op: String,
    /// Width of the layer.
    pub width: usize,
    /// Name of the activation function.
    pub activation: String,
}
/// An ordered stack of [`LayerSpec`]s.
///
/// The position of a spec inside `layers` is its layer index.
#[derive(Clone, Debug, PartialEq)]
pub struct Architecture {
    /// Layer specifications, one entry per layer.
    pub layers: Vec<LayerSpec>,
}
impl Architecture {
pub fn param_count(&self) -> usize {
if self.layers.len() < 2 {
return 0;
}
self.layers
.windows(2)
.map(|w| w[0].width * w[1].width)
.sum()
}
pub fn depth(&self) -> usize {
self.layers.len()
}
pub fn to_config(&self) -> HyperparamConfig {
let mut m: HashMap<String, HyperparamValue> = HashMap::new();
m.insert(
"depth".to_string(),
HyperparamValue::Int(self.layers.len() as i64),
);
for (i, layer) in self.layers.iter().enumerate() {
m.insert(
format!("layer_{i}_op"),
HyperparamValue::String(layer.op.clone()),
);
m.insert(
format!("layer_{i}_width"),
HyperparamValue::Int(layer.width as i64),
);
m.insert(
format!("layer_{i}_activation"),
HyperparamValue::String(layer.activation.clone()),
);
}
m
}
pub fn from_config(cfg: &HyperparamConfig, max_depth: usize) -> TrainResult<Self> {
let depth = cfg.get("depth").and_then(|v| v.as_int()).ok_or_else(|| {
TrainError::InvalidParameter("config missing 'depth' Int key".to_string())
})?;
if depth < 1 {
return Err(TrainError::InvalidParameter(format!(
"decoded depth {depth} must be ≥ 1"
)));
}
if depth as usize > max_depth {
return Err(TrainError::InvalidParameter(format!(
"decoded depth {depth} exceeds max_depth {max_depth}"
)));
}
let mut layers = Vec::with_capacity(depth as usize);
for i in 0..depth as usize {
let op = cfg
.get(&format!("layer_{i}_op"))
.and_then(|v| v.as_string())
.ok_or_else(|| {
TrainError::InvalidParameter(format!(
"config missing 'layer_{i}_op' String key"
))
})?
.to_string();
let width = cfg
.get(&format!("layer_{i}_width"))
.and_then(|v| v.as_int())
.ok_or_else(|| {
TrainError::InvalidParameter(format!(
"config missing 'layer_{i}_width' Int key"
))
})?;
if width < 1 {
return Err(TrainError::InvalidParameter(format!(
"layer {i} width {width} must be ≥ 1"
)));
}
let activation = cfg
.get(&format!("layer_{i}_activation"))
.and_then(|v| v.as_string())
.ok_or_else(|| {
TrainError::InvalidParameter(format!(
"config missing 'layer_{i}_activation' String key"
))
})?
.to_string();
layers.push(LayerSpec {
op,
width: width as usize,
activation,
});
}
Ok(Architecture { layers })
}
}
/// Bounds and candidate values describing a set of architectures.
///
/// All fields are public; [`ArchSearchSpace::new`] is the validating
/// constructor.
#[derive(Clone, Debug)]
pub struct ArchSearchSpace {
    /// Lower depth bound; `new` rejects values below 1.
    pub min_depth: usize,
    /// Upper depth bound; `new` rejects values below `min_depth`.
    pub max_depth: usize,
    /// Candidate layer widths; `new` rejects an empty list.
    pub width_options: Vec<usize>,
    /// Candidate activation names; `new` rejects an empty list.
    pub activation_options: Vec<String>,
    /// Candidate operation names; `new` rejects an empty list.
    pub op_options: Vec<String>,
}
impl ArchSearchSpace {
    /// Creates a search space after validating its bounds and option lists.
    ///
    /// # Errors
    ///
    /// Returns `TrainError::InvalidParameter` when, checked in this order:
    /// - `min_depth` is below 1;
    /// - `max_depth` is below `min_depth`;
    /// - `width_options`, `activation_options` or `op_options` is empty;
    /// - `width_options` contains a 0 entry.
    pub fn new(
        min_depth: usize,
        max_depth: usize,
        width_options: Vec<usize>,
        activation_options: Vec<String>,
        op_options: Vec<String>,
    ) -> TrainResult<Self> {
        if min_depth < 1 {
            return Err(TrainError::InvalidParameter(
                "min_depth must be ≥ 1".to_string(),
            ));
        }
        if max_depth < min_depth {
            return Err(TrainError::InvalidParameter(format!(
                "max_depth ({max_depth}) must be ≥ min_depth ({min_depth})"
            )));
        }
        if width_options.is_empty() {
            return Err(TrainError::InvalidParameter(
                "width_options must be non-empty".to_string(),
            ));
        }
        if activation_options.is_empty() {
            return Err(TrainError::InvalidParameter(
                "activation_options must be non-empty".to_string(),
            ));
        }
        if op_options.is_empty() {
            return Err(TrainError::InvalidParameter(
                "op_options must be non-empty".to_string(),
            ));
        }
        // `Architecture::from_config` rejects any layer width below 1, so a
        // zero-width option could only ever yield undecodable architectures.
        // This check runs last so inputs that already failed keep their
        // original error.
        if let Some(idx) = width_options.iter().position(|&w| w == 0) {
            return Err(TrainError::InvalidParameter(format!(
                "width_options[{idx}] must be ≥ 1"
            )));
        }
        Ok(Self {
            min_depth,
            max_depth,
            width_options,
            activation_options,
            op_options,
        })
    }
}