use std::collections::VecDeque;
use crate::error::{TrainError, TrainResult};
use super::sampler::ArchSampler;
use super::space::{ArchSearchSpace, Architecture};
/// Summary of a finished (or in-progress) architecture search.
#[derive(Debug, Clone)]
pub struct NasResult {
    /// Best-scoring architecture found by the search.
    pub best: Architecture,
    /// Score associated with `best`.
    pub best_score: f64,
    /// Evaluated `(architecture, score)` pairs collected during the search.
    pub history: Vec<(Architecture, f64)>,
}
/// Regularized (aging) evolution searcher driven through an ask/tell interface.
///
/// Members live in arrival order; once the population is full, every newly
/// reported result evicts the oldest member rather than the worst one.
pub struct RegularizedEvolution {
    /// Living members as `(architecture, score)`, oldest at the front.
    population: VecDeque<(Architecture, f64)>,
    /// Number of members kept alive once the population is full.
    pub population_size: usize,
    /// Maximum number of members sampled per tournament when choosing a parent.
    pub tournament_size: usize,
    /// Seeded source of random architectures and mutations.
    sampler: ArchSampler,
    /// Every `(architecture, score)` ever reported, in report order.
    history: Vec<(Architecture, f64)>,
    /// Becomes `true` the first time `population` reaches `population_size`;
    /// from then on proposals are mutations of tournament winners instead of
    /// fresh random samples.
    filled: bool,
}
impl RegularizedEvolution {
    /// Creates a regularized-evolution searcher over `space`, seeding its sampler
    /// with `seed`.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] if `population_size < 2`,
    /// `tournament_size == 0`, or `tournament_size > population_size`.
    pub fn new(
        space: ArchSearchSpace,
        population_size: usize,
        tournament_size: usize,
        seed: u64,
    ) -> TrainResult<Self> {
        if population_size < 2 {
            return Err(TrainError::InvalidParameter(format!(
                "population_size ({population_size}) must be ≥ 2"
            )));
        }
        if tournament_size == 0 {
            return Err(TrainError::InvalidParameter(
                "tournament_size must be ≥ 1".to_string(),
            ));
        }
        if tournament_size > population_size {
            return Err(TrainError::InvalidParameter(format!(
                "tournament_size ({tournament_size}) must be ≤ population_size ({population_size})"
            )));
        }
        Ok(Self {
            population: VecDeque::new(),
            population_size,
            tournament_size,
            sampler: ArchSampler::new(space, seed),
            history: Vec::new(),
            filled: false,
        })
    }

    /// Proposes the next architecture to evaluate.
    ///
    /// Until `population_size` results have been reported through
    /// [`tell`](Self::tell), this returns a fresh random architecture. Afterwards
    /// it runs a tournament over the current population and returns a mutation of
    /// the winner.
    ///
    /// # Errors
    ///
    /// Propagates errors from the sampler and from tournament selection.
    pub fn ask(&mut self) -> TrainResult<Architecture> {
        if !self.filled {
            return self.sampler.random_architecture();
        }
        let parent = self.tournament_select()?;
        self.sampler.mutate(&parent)
    }

    /// Records the evaluated `score` for `arch`.
    ///
    /// The pair is appended to the full history and to the back of the
    /// population. When the population grows past `population_size`, the oldest
    /// member (front of the queue) is evicted regardless of its score. NaN scores
    /// are accepted but rank below every real score in all selections.
    pub fn tell(&mut self, arch: Architecture, score: f64) {
        self.history.push((arch.clone(), score));
        self.population.push_back((arch, score));
        if self.population.len() >= self.population_size {
            self.filled = true;
        }
        if self.population.len() > self.population_size {
            self.population.pop_front();
        }
    }

    /// Returns the highest-scoring member of the *current* population, or `None`
    /// if nothing has been reported yet.
    ///
    /// NaN scores rank below every real score; among equal scores the most
    /// recently added member wins.
    pub fn best(&self) -> Option<&(Architecture, f64)> {
        self.population
            .iter()
            .max_by(|a, b| Self::compare_scores(a.1, b.1))
    }

    /// Summarizes the whole search: the best entry ever reported (including ones
    /// already evicted from the population) plus a copy of the full history.
    ///
    /// Returns `None` if nothing has been reported yet. NaN scores rank below
    /// every real score; among equal scores the most recently reported entry wins.
    pub fn result(&self) -> Option<NasResult> {
        let (best_arch, best_score) = self
            .history
            .iter()
            .max_by(|a, b| Self::compare_scores(a.1, b.1))?;
        Some(NasResult {
            best: best_arch.clone(),
            best_score: *best_score,
            history: self.history.clone(),
        })
    }

    /// Samples `min(tournament_size, population.len())` distinct members and
    /// returns a clone of the best one's architecture.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] if the population is empty or the
    /// sample ends up empty (only possible if `tournament_size` was set to 0 after
    /// construction).
    fn tournament_select(&mut self) -> TrainResult<Architecture> {
        let pop_len = self.population.len();
        if pop_len == 0 {
            return Err(TrainError::InvalidParameter(
                "tournament_select called on empty population".to_string(),
            ));
        }
        let sample_size = self.tournament_size.min(pop_len);

        // Partial Fisher–Yates shuffle: after the loop, indices[..sample_size] holds
        // distinct population indices. Uniformity assumes gen_range_usize(lo, hi)
        // draws uniformly from lo..hi — NOTE(review): confirm against ArchSampler.
        let mut indices: Vec<usize> = (0..pop_len).collect();
        for i in 0..sample_size {
            let j = self.sampler.gen_range_usize(i, pop_len);
            indices.swap(i, j);
        }

        let population = &self.population;
        let best_idx = indices[..sample_size]
            .iter()
            .copied()
            .max_by(|&a, &b| Self::compare_scores(population[a].1, population[b].1))
            .ok_or_else(|| {
                TrainError::InvalidParameter("tournament sample was empty".to_string())
            })?;
        Ok(self.population[best_idx].0.clone())
    }

    /// Consistent ordering on scores used by every "pick the best" decision.
    ///
    /// `f64::partial_cmp` yields `None` when either side is NaN. Mapping that to
    /// `Equal` is not a consistent order: `max_by` then keeps whichever element
    /// comes later, so a NaN, or a score below the true maximum, could win.
    /// Here NaN compares below every real score and equal to other NaNs. All other
    /// values keep their natural order.
    fn compare_scores(a: f64, b: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
        }
    }
}