use crate::error::OptimizeError;
use crate::nas::search_space::{Architecture, SearchSpace};
use scirs2_core::random::{rngs::StdRng, SeedableRng};
/// Outcome of a neural-architecture search run.
#[derive(Debug, Clone)]
pub struct NASResult {
    /// Best architecture reported by the search.
    pub best_arch: Architecture,
    /// Fitness score reported for `best_arch` (larger scores are better).
    pub best_score: f64,
    /// Fitness score of every evaluated architecture, in evaluation order.
    pub all_scores: Vec<f64>,
    /// Number of architectures the search evaluated.
    pub n_evaluated: usize,
}
/// Scores a candidate [`Architecture`]; search strategies keep the largest score.
///
/// The `Send + Sync` bound lets one fitness object be shared between threads.
pub trait ArchFitness: Send + Sync {
    /// Returns the fitness of `candidate` (larger is better), or an error if the
    /// candidate cannot be scored.
    fn evaluate(&self, candidate: &Architecture) -> Result<f64, OptimizeError>;
}
/// Fitness that rewards architectures whose total parameter count is close to a target.
pub struct ParamCountFitness {
    /// Desired total number of parameters.
    pub target_params: usize,
}
impl ParamCountFitness {
    /// Creates a fitness function aiming for `target_params` parameters.
    pub fn new(target_params: usize) -> Self {
        ParamCountFitness { target_params }
    }
}
impl ArchFitness for ParamCountFitness {
    /// Scores `arch` by the relative distance of its parameter count from the target.
    ///
    /// The score is `-|params - target| / target`: an exact match scores `0.0`, and the
    /// score decreases linearly as the count moves away. A zero target cannot be used
    /// as a divisor, so it scores `0.0` for a parameter-free architecture and `-1.0`
    /// for anything else.
    fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
        let actual = arch.total_params() as f64;
        let wanted = self.target_params as f64;
        let score = if self.target_params == 0 {
            if actual == 0.0 {
                0.0
            } else {
                -1.0
            }
        } else {
            let distance = (actual - wanted).abs();
            -distance / wanted
        };
        Ok(score)
    }
}
/// Fitness that rewards architectures for using, but not exceeding, a FLOP budget.
pub struct FlopsFitness {
    /// Maximum number of FLOPs an architecture may use.
    pub flops_budget: usize,
    /// Spatial size passed to `Architecture::total_flops` when counting FLOPs.
    pub spatial: usize,
}
impl FlopsFitness {
    /// Creates a fitness function with the given FLOP budget and spatial size.
    pub fn new(flops_budget: usize, spatial: usize) -> Self {
        FlopsFitness {
            spatial,
            flops_budget,
        }
    }
}
impl ArchFitness for FlopsFitness {
    /// Scores `arch` by how well its FLOP count at `self.spatial` fits the budget.
    ///
    /// Within budget the score is `flops / budget`, so fuller use of the budget scores
    /// higher. Over budget it is the negated relative overshoot `-(flops - budget) / budget`.
    /// A zero budget always scores `0.0`.
    fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
        let used = arch.total_flops(self.spatial) as f64;
        if self.flops_budget == 0 {
            return Ok(0.0);
        }
        let budget = self.flops_budget as f64;
        let score = if used > budget {
            let overshoot = used - budget;
            -overshoot / budget
        } else {
            used / budget
        };
        Ok(score)
    }
}
/// Random-search baseline for neural architecture search.
///
/// Draws `n_trials` architectures with [`SearchSpace::sample_random`] from an RNG
/// seeded by the caller. It scores each one with an [`ArchFitness`] and keeps the
/// highest-scoring architecture.
pub struct RandomNAS {
    /// Number of architectures to sample and evaluate.
    pub n_trials: usize,
}
impl RandomNAS {
    /// Creates a random searcher that evaluates `n_trials` architectures.
    pub fn new(n_trials: usize) -> Self {
        Self { n_trials }
    }

    /// Runs random search over `space`, scoring candidates with `fitness`.
    ///
    /// The RNG is seeded from `seed`, so the same inputs produce the same samples.
    /// `best_arch` is always one of the evaluated architectures, and `best_score` is
    /// its entry in `all_scores`. Ties keep the earliest candidate. A NaN score never
    /// displaces a real score, but any real score displaces a NaN best.
    ///
    /// # Errors
    ///
    /// - [`OptimizeError::InvalidParameter`] if `n_trials` is zero.
    /// - Any error returned by `fitness.evaluate`, which aborts the search immediately.
    pub fn search<F: ArchFitness>(
        &self,
        space: &SearchSpace,
        fitness: &F,
        seed: u64,
    ) -> Result<NASResult, OptimizeError> {
        if self.n_trials == 0 {
            return Err(OptimizeError::InvalidParameter(
                "n_trials must be at least 1".to_string(),
            ));
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let mut all_scores = Vec::with_capacity(self.n_trials);

        // Seed the running best from the first *evaluated* trial. Starting from an
        // extra unevaluated sample wasted an RNG draw. It could also be returned
        // (paired with -inf) when every score was NaN.
        let mut best_arch = space.sample_random(&mut rng);
        let mut best_score = fitness.evaluate(&best_arch)?;
        all_scores.push(best_score);

        for _ in 1..self.n_trials {
            let arch = space.sample_random(&mut rng);
            let score = fitness.evaluate(&arch)?;
            all_scores.push(score);
            // `>` is false for NaN on either side, so a real score must explicitly
            // be allowed to replace a NaN incumbent.
            if score > best_score || (best_score.is_nan() && !score.is_nan()) {
                best_score = score;
                best_arch = arch;
            }
        }

        let n_evaluated = all_scores.len();
        Ok(NASResult {
            best_arch,
            best_score,
            all_scores,
            n_evaluated,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::SearchSpace;

    #[test]
    fn test_random_nas_returns_result() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(20);
        let result = nas.search(&space, &fitness, 0).expect("search failed");
        assert_eq!(result.n_evaluated, 20);
        assert_eq!(result.all_scores.len(), 20);
        assert!(result.best_score.is_finite());
    }

    #[test]
    fn test_random_nas_zero_trials_errors() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(0);
        assert!(nas.search(&space, &fitness, 0).is_err());
    }

    /// The same seed must reproduce the same sequence of scores.
    #[test]
    fn test_random_nas_is_deterministic_for_seed() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(10);
        let first = nas.search(&space, &fitness, 42).expect("search failed");
        let second = nas.search(&space, &fitness, 42).expect("search failed");
        assert_eq!(first.all_scores, second.all_scores);
        assert_eq!(first.best_score, second.best_score);
    }

    /// The reported best score must be the maximum of the evaluated scores.
    #[test]
    fn test_random_nas_best_score_is_max_of_all_scores() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let result = RandomNAS::new(15)
            .search(&space, &fitness, 7)
            .expect("search failed");
        let max = result
            .all_scores
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        assert_eq!(result.best_score, max);
        assert!(result.all_scores.contains(&result.best_score));
    }

    /// A zero parameter target scores 0.0 for an empty architecture and -1.0 once
    /// the architecture has parameters.
    #[test]
    fn test_param_count_fitness_zero_target() {
        use crate::nas::search_space::{ArchEdge, ArchNode, OpType};
        let mut arch = Architecture::new(1, 32, 10);
        let fitness = ParamCountFitness::new(0);
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, 0.0);
        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 32,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 32,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Conv3x3,
        });
        let score2 = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score2, -1.0);
    }

    #[test]
    fn test_flops_fitness_under_budget() {
        use crate::nas::search_space::{ArchEdge, ArchNode, OpType};
        let mut arch = Architecture::new(1, 8, 10);
        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 8,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 8,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Skip,
        });
        let fitness = FlopsFitness::new(1_000_000, 8);
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert!((0.0..=1.0).contains(&score));
    }

    /// A zero FLOP budget always scores 0.0.
    #[test]
    fn test_flops_fitness_zero_budget() {
        let arch = Architecture::new(1, 8, 10);
        let fitness = FlopsFitness::new(0, 8);
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, 0.0);
    }
}