use crate::error::TrainResult;
use super::evolution::NasResult;
use super::sampler::ArchSampler;
use super::space::{ArchSearchSpace, Architecture};
/// Random-sampling baseline for architecture search, driven by an ask/tell loop.
///
/// Callers obtain candidates with `ask`, evaluate them externally, and report
/// the resulting score back with `tell`. Higher scores are treated as better.
pub struct RandomArchSearch {
/// Produces the candidate architectures handed out by `ask`.
sampler: ArchSampler,
/// Highest-scoring `(architecture, score)` reported so far; `None` until the
/// first `tell`.
best: Option<(Architecture, f64)>,
/// Every `(architecture, score)` pair reported via `tell`, in call order.
history: Vec<(Architecture, f64)>,
}
impl RandomArchSearch {
    /// Creates a random-search driver over `space`.
    ///
    /// `seed` is forwarded unchanged to [`ArchSampler::new`]; whether that makes
    /// the sequence of proposals reproducible is up to the sampler.
    pub fn new(space: ArchSearchSpace, seed: u64) -> Self {
        Self {
            sampler: ArchSampler::new(space, seed),
            best: None,
            history: Vec::new(),
        }
    }

    /// Proposes the next candidate architecture to evaluate.
    ///
    /// Each proposal comes straight from the sampler; scores reported through
    /// [`tell`](Self::tell) never influence what is sampled.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ArchSampler::random_architecture`.
    pub fn ask(&mut self) -> TrainResult<Architecture> {
        self.sampler.random_architecture()
    }

    /// Records that `arch` achieved `score` (higher is better).
    ///
    /// The pair is always appended to the history. It also becomes the new
    /// best when no best exists yet, when it strictly exceeds the current best
    /// score, or when the current best score is NaN and `score` is not. Ties
    /// keep the earlier entry.
    pub fn tell(&mut self, arch: Architecture, score: f64) {
        let incumbent = self.best.as_ref().map(|&(_, best_score)| best_score);
        if Self::improves_on(score, incumbent) {
            self.best = Some((arch.clone(), score));
        }
        self.history.push((arch, score));
    }

    /// Decides whether `score` should replace the `incumbent` best score.
    fn improves_on(score: f64, incumbent: Option<f64>) -> bool {
        match incumbent {
            // The very first report always becomes the best, even if it is NaN,
            // so `best`/`result` are `Some` after any `tell`.
            None => true,
            // Every comparison with NaN is false; without this arm a NaN first
            // score (e.g. a diverged run) would block all later results.
            Some(best) if best.is_nan() => !score.is_nan(),
            Some(best) => score > best,
        }
    }

    /// Returns the best `(architecture, score)` reported so far, or `None` if
    /// [`tell`](Self::tell) has never been called.
    pub fn best(&self) -> Option<&(Architecture, f64)> {
        self.best.as_ref()
    }

    /// Builds a [`NasResult`] snapshot of the search so far.
    ///
    /// Returns `None` if no result has been reported yet. The best architecture
    /// and the full history are cloned, so the cost grows with the number of
    /// `tell` calls.
    pub fn result(&self) -> Option<NasResult> {
        let (best_arch, best_score) = self.best.as_ref()?;
        Some(NasResult {
            best: best_arch.clone(),
            best_score: *best_score,
            history: self.history.clone(),
        })
    }
}