use scirs2_core::random::{SeedableRng, StdRng};
use crate::error::{TrainError, TrainResult};
use super::space::{ArchSearchSpace, Architecture, LayerSpec};
/// Random sampler and mutator for architectures drawn from an
/// [`ArchSearchSpace`].
///
/// All randomness comes from a single internal RNG that is seeded at
/// construction time (see `ArchSampler::new`).
pub struct ArchSampler {
/// Search space supplying the depth bounds and the op / width / activation
/// option lists that sampling and mutation read from.
space: ArchSearchSpace,
/// Random number generator backing every sampling and mutation decision.
rng: StdRng,
}
impl ArchSampler {
pub fn new(space: ArchSearchSpace, seed: u64) -> Self {
Self {
space,
rng: StdRng::seed_from_u64(seed),
}
}
pub fn random_architecture(&mut self) -> TrainResult<Architecture> {
let depth_range = self.space.max_depth - self.space.min_depth + 1;
let depth = self.space.min_depth + self.rng.gen_range(0..depth_range);
let mut layers = Vec::with_capacity(depth);
for _ in 0..depth {
layers.push(self.sample_layer()?);
}
Ok(Architecture { layers })
}
pub fn mutate(&mut self, arch: &Architecture) -> TrainResult<Architecture> {
let mut new_arch = arch.clone();
let mutation_type = self.rng.gen_range(0..4_usize);
match mutation_type {
0 => {
let layer_idx = self.pick_layer_index(&new_arch)?;
let new_op = self.pick_option(&self.space.op_options.clone())?;
new_arch.layers[layer_idx].op = new_op;
}
1 => {
let layer_idx = self.pick_layer_index(&new_arch)?;
let new_width = self.pick_width()?;
new_arch.layers[layer_idx].width = new_width;
}
2 => {
let layer_idx = self.pick_layer_index(&new_arch)?;
let new_act = self.pick_option(&self.space.activation_options.clone())?;
new_arch.layers[layer_idx].activation = new_act;
}
3 => {
let can_add = new_arch.depth() < self.space.max_depth;
let can_remove = new_arch.depth() > self.space.min_depth;
if can_add && can_remove {
if self.rng.gen_range(0..2_usize) == 0 {
self.add_random_layer(&mut new_arch)?;
} else {
self.remove_random_layer(&mut new_arch)?;
}
} else if can_add {
self.add_random_layer(&mut new_arch)?;
} else if can_remove {
self.remove_random_layer(&mut new_arch)?;
} else {
let layer_idx = self.pick_layer_index(&new_arch)?;
let new_op = self.pick_option(&self.space.op_options.clone())?;
new_arch.layers[layer_idx].op = new_op;
}
}
_ => unreachable!("gen_range(0..4) is always in 0..3"),
}
Ok(new_arch)
}
fn sample_layer(&mut self) -> TrainResult<LayerSpec> {
let op = self.pick_option(&self.space.op_options.clone())?;
let width = self.pick_width()?;
let activation = self.pick_option(&self.space.activation_options.clone())?;
Ok(LayerSpec {
op,
width,
activation,
})
}
fn pick_option(&mut self, options: &[String]) -> TrainResult<String> {
if options.is_empty() {
return Err(TrainError::InvalidParameter(
"option list must be non-empty".to_string(),
));
}
let idx = self.rng.gen_range(0..options.len());
Ok(options[idx].clone())
}
fn pick_width(&mut self) -> TrainResult<usize> {
if self.space.width_options.is_empty() {
return Err(TrainError::InvalidParameter(
"width_options must be non-empty".to_string(),
));
}
let idx = self.rng.gen_range(0..self.space.width_options.len());
Ok(self.space.width_options[idx])
}
fn pick_layer_index(&mut self, arch: &Architecture) -> TrainResult<usize> {
if arch.layers.is_empty() {
return Err(TrainError::InvalidParameter(
"architecture has no layers to mutate".to_string(),
));
}
Ok(self.rng.gen_range(0..arch.layers.len()))
}
fn add_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
let new_layer = self.sample_layer()?;
let pos = self.rng.gen_range(0..=arch.layers.len());
arch.layers.insert(pos, new_layer);
Ok(())
}
fn remove_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
let idx = self.pick_layer_index(arch)?;
arch.layers.remove(idx);
Ok(())
}
pub fn gen_range_usize(&mut self, lower: usize, upper: usize) -> usize {
if upper <= lower {
return lower;
}
lower + self.rng.gen_range(0..(upper - lower))
}
}