use crate::error::{NeuralError, Result};
use scirs2_core::random::{ChaCha20Rng, Rng, RngExt, SeedableRng};
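/// Acquisition functions used to rank candidate architectures by trading off
/// the predictor's mean estimate against its uncertainty.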
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcquisitionFunction {
ExpectedImprovement,
UpperConfidenceBound,
ThompsonSampling,
}
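/// Surrogate model used to predict architecture performance from encoded features.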
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictorType {
GaussianProcess,
MlpRegressor,
}
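/// Configuration for predictor-based neural architecture search. `ucb_kappa`
/// is the exploration weight for `UpperConfidenceBound` and `seed` makes the
/// search reproducible; `n_initial_architectures` and `n_iterations` are not
/// consumed in this module and are presumably read by the surrounding search
/// driver.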
#[derive(Debug, Clone)]
pub struct PredictorNasConfig {
pub n_initial_architectures: usize,
pub n_iterations: usize,
pub n_candidates_per_iter: usize,
pub acquisition: AcquisitionFunction,
pub predictor_type: PredictorType,
pub ucb_kappa: f64,
pub seed: u64,
}
impl Default for PredictorNasConfig {
fn default() -> Self {
Self {
n_initial_architectures: 10,
n_iterations: 20,
n_candidates_per_iter: 100,
acquisition: AcquisitionFunction::ExpectedImprovement,
predictor_type: PredictorType::GaussianProcess,
ucb_kappa: 2.0,
seed: 42,
}
}
}
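/// Encodes a discrete architecture (one operation index per edge) as a flat
/// one-hot feature vector of length `n_edges * n_ops`.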
pub struct ArchFeatureExtractor {
pub n_edges: usize,
pub n_ops: usize,
}
impl ArchFeatureExtractor {
pub fn new(n_edges: usize, n_ops: usize) -> Self {
Self { n_edges, n_ops }
}
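    /// One-hot encodes `arch`. Operation indices beyond `n_ops - 1` are
    /// clamped to the last operation; if `arch` has fewer than `n_edges`
    /// entries, the remaining edge blocks are left all-zero.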
pub fn encode(&self, arch: &[usize]) -> Vec<f64> {
let mut features = vec![0.0_f64; self.feature_dim()];
for (edge_idx, &op_idx) in arch.iter().enumerate().take(self.n_edges) {
let clamped_op = op_idx.min(self.n_ops.saturating_sub(1));
features[edge_idx * self.n_ops + clamped_op] = 1.0;
}
features
}
pub fn feature_dim(&self) -> usize {
self.n_edges * self.n_ops
}
}
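/// Gaussian-process regressor with an RBF kernel, fitted via a Cholesky
/// factorization of the kernel matrix. Used as the default performance
/// predictor for candidate architectures.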
pub struct GpPredictor {
x_train: Vec<Vec<f64>>,
y_train: Vec<f64>,
length_scale: f64,
noise_var: f64,
dual_alpha: Vec<f64>,
chol_l: Vec<Vec<f64>>,
}
impl GpPredictor {
pub fn new(length_scale: f64, noise_var: f64) -> Self {
Self {
x_train: Vec::new(),
y_train: Vec::new(),
length_scale,
noise_var,
dual_alpha: Vec::new(),
chol_l: Vec::new(),
}
}
fn rbf_kernel(&self, a: &[f64], b: &[f64]) -> f64 {
let sq_dist: f64 = a
.iter()
.zip(b.iter())
.map(|(&ai, &bi)| (ai - bi) * (ai - bi))
.sum();
let two_l2 = 2.0 * self.length_scale * self.length_scale;
(-sq_dist / two_l2).exp()
}
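    /// Fits the GP: builds the kernel matrix `K + noise_var * I`, factors it
    /// as `L * L^T`, and precomputes the dual weights `alpha = K^{-1} y`.
    /// Silently returns without fitting if `x` is empty or `x` and `y`
    /// disagree in length.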
pub fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) {
let n = x.len();
if n == 0 || n != y.len() {
return;
}
self.x_train = x.to_vec();
self.y_train = y.to_vec();
let mut k_mat = vec![vec![0.0_f64; n]; n];
for i in 0..n {
for j in 0..n {
k_mat[i][j] = self.rbf_kernel(&x[i], &x[j]);
}
k_mat[i][i] += self.noise_var;
}
self.chol_l = cholesky_decompose(&k_mat);
self.dual_alpha = cholesky_solve(&self.chol_l, y);
}
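    /// Returns the posterior mean and standard deviation at `x`:
    /// `mu = k_*^T alpha` and `sigma^2 = k(x, x) - ||L^{-1} k_*||^2`,
    /// with the variance clamped to a small positive floor. Falls back to
    /// `(0.0, 1.0)` when the model has not been fitted.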
pub fn predict(&self, x: &[f64]) -> (f64, f64) {
if self.x_train.is_empty() {
return (0.0, 1.0);
}
let k_star: Vec<f64> = self
.x_train
.iter()
.map(|xi| self.rbf_kernel(x, xi))
.collect();
let mean: f64 = k_star
.iter()
.zip(self.dual_alpha.iter())
.map(|(&ki, &ai)| ki * ai)
.sum();
let k_xx = self.rbf_kernel(x, x);
let v = forward_solve(&self.chol_l, &k_star);
let v_sq_norm: f64 = v.iter().map(|&vi| vi * vi).sum();
let variance = (k_xx - v_sq_norm).max(1e-10);
let std_dev = variance.sqrt();
(mean, std_dev)
}
pub fn n_train(&self) -> usize {
self.x_train.len()
}
}
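/// Small fully connected regressor with two ReLU hidden layers and a scalar
/// output, used as an alternative performance predictor to the GP. Weights
/// are initialized with He-style `sqrt(2 / fan_in)` scaling.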
pub struct MlpPredictor {
w1: Vec<Vec<f64>>,
b1: Vec<f64>,
w2: Vec<Vec<f64>>,
b2: Vec<f64>,
w3: Vec<f64>,
b3: f64,
hidden_size: usize,
}
impl MlpPredictor {
pub fn new(input_dim: usize, hidden_size: usize, rng: &mut impl Rng) -> Self {
let w1 = random_matrix(hidden_size, input_dim, rng, (2.0 / input_dim as f64).sqrt());
let b1 = vec![0.0_f64; hidden_size];
let w2 = random_matrix(
hidden_size,
hidden_size,
rng,
(2.0 / hidden_size as f64).sqrt(),
);
let b2 = vec![0.0_f64; hidden_size];
let w3 = random_vector(hidden_size, rng, (2.0 / hidden_size as f64).sqrt());
let b3 = 0.0;
Self {
w1,
b1,
w2,
b2,
w3,
b3,
hidden_size,
}
}
pub fn forward(&self, x: &[f64]) -> f64 {
let h1: Vec<f64> = (0..self.hidden_size)
.map(|j| {
let pre_act: f64 = self.w1[j]
.iter()
.zip(x.iter())
.map(|(&w, &xi)| w * xi)
.sum::<f64>()
+ self.b1[j];
relu(pre_act)
})
.collect();
let h2: Vec<f64> = (0..self.hidden_size)
.map(|j| {
let pre_act: f64 = self.w2[j]
.iter()
.zip(h1.iter())
.map(|(&w, &hi)| w * hi)
.sum::<f64>()
+ self.b2[j];
relu(pre_act)
})
.collect();
self.w3
.iter()
.zip(h2.iter())
.map(|(&w, &h)| w * h)
.sum::<f64>()
+ self.b3
}
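    /// Trains by per-sample gradient descent on squared error, estimating
    /// every gradient with central finite differences (two extra forward
    /// passes per parameter). This costs O(parameters) forward passes per
    /// sample, which is tolerable only because the networks and training
    /// sets used here are tiny.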
pub fn fit(&mut self, x: &[Vec<f64>], y: &[f64], n_epochs: usize, lr: f64) {
if x.is_empty() || x.len() != y.len() {
return;
}
let eps = 1e-5;
let n = x.len();
for _ in 0..n_epochs {
for sample_idx in 0..n {
let xi = &x[sample_idx];
let yi = y[sample_idx];
let pred = self.forward(xi);
let loss_scale = 2.0 * (pred - yi);
for k in 0..self.w3.len() {
let orig = self.w3[k];
self.w3[k] = orig + eps;
let p_plus = self.forward(xi);
self.w3[k] = orig - eps;
let p_minus = self.forward(xi);
self.w3[k] = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.w3[k] -= lr * grad;
}
{
let orig = self.b3;
self.b3 = orig + eps;
let p_plus = self.forward(xi);
self.b3 = orig - eps;
let p_minus = self.forward(xi);
self.b3 = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.b3 -= lr * grad;
}
for j in 0..self.hidden_size {
for k in 0..self.hidden_size {
let orig = self.w2[j][k];
self.w2[j][k] = orig + eps;
let p_plus = self.forward(xi);
self.w2[j][k] = orig - eps;
let p_minus = self.forward(xi);
self.w2[j][k] = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.w2[j][k] -= lr * grad;
}
let orig = self.b2[j];
self.b2[j] = orig + eps;
let p_plus = self.forward(xi);
self.b2[j] = orig - eps;
let p_minus = self.forward(xi);
self.b2[j] = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.b2[j] -= lr * grad;
}
for j in 0..self.hidden_size {
for k in 0..xi.len() {
let orig = self.w1[j][k];
self.w1[j][k] = orig + eps;
let p_plus = self.forward(xi);
self.w1[j][k] = orig - eps;
let p_minus = self.forward(xi);
self.w1[j][k] = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.w1[j][k] -= lr * grad;
}
let orig = self.b1[j];
self.b1[j] = orig + eps;
let p_plus = self.forward(xi);
self.b1[j] = orig - eps;
let p_minus = self.forward(xi);
self.b1[j] = orig;
let grad = loss_scale * (p_plus - p_minus) / (2.0 * eps);
self.b1[j] -= lr * grad;
}
}
}
}
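    /// Returns the network output together with a fixed placeholder standard
    /// deviation of 0.1, since the MLP provides no calibrated uncertainty.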
pub fn predict(&self, x: &[f64]) -> (f64, f64) {
(self.forward(x), 0.1)
}
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
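/// Predictor-guided architecture search: maintains a history of evaluated
/// `(architecture, score)` pairs, fits a surrogate on that history, and
/// proposes new architectures that maximize the configured acquisition
/// function over a pool of random candidates.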
pub struct PredictorNasSearch {
config: PredictorNasConfig,
extractor: ArchFeatureExtractor,
evaluated: Vec<(Vec<usize>, f64)>,
rng: ChaCha20Rng,
}
impl PredictorNasSearch {
pub fn new(config: PredictorNasConfig, n_edges: usize, n_ops: usize) -> Self {
let seed = config.seed;
Self {
extractor: ArchFeatureExtractor::new(n_edges, n_ops),
evaluated: Vec::new(),
rng: ChaCha20Rng::seed_from_u64(seed),
config,
}
}
pub fn record_evaluation(&mut self, arch: Vec<usize>, score: f64) {
self.evaluated.push((arch, score));
}
pub fn n_evaluated(&self) -> usize {
self.evaluated.len()
}
pub fn best_architecture(&self) -> Option<(&Vec<usize>, f64)> {
self.evaluated
.iter()
.max_by(|(_, s1), (_, s2)| s1.partial_cmp(s2).unwrap_or(std::cmp::Ordering::Equal))
.map(|(arch, score)| (arch, *score))
}
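    /// Proposes `n_proposals` architectures. With fewer than two recorded
    /// evaluations it falls back to uniform random sampling; otherwise it
    /// fits the configured predictor on the evaluation history, scores a
    /// random candidate pool with the acquisition function, and returns the
    /// top-scoring candidates.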
pub fn propose_next_architectures(&mut self, n_proposals: usize) -> Result<Vec<Vec<usize>>> {
if self.evaluated.len() < 2 {
return Ok(self.random_architectures(n_proposals));
}
let x_train: Vec<Vec<f64>> = self
.evaluated
.iter()
.map(|(arch, _)| self.extractor.encode(arch))
.collect();
let y_train: Vec<f64> = self.evaluated.iter().map(|(_, s)| *s).collect();
let best_score = y_train.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let n_candidates = self.config.n_candidates_per_iter.max(n_proposals);
        let candidates = self.random_architectures(n_candidates);
        let feature_dim = self.extractor.feature_dim();
        // Derive a per-call RNG for acquisition sampling. Seeding it from the
        // search RNG (rather than a fixed constant) keeps runs reproducible
        // while letting Thompson sampling draw a different sample for each
        // candidate instead of the same one every time.
        let mut acq_rng = ChaCha20Rng::seed_from_u64(self.rng.random());
        let mut scored: Vec<(f64, Vec<usize>)> = match self.config.predictor_type {
            PredictorType::GaussianProcess => {
                let mut gp = GpPredictor::new(1.0, 0.01);
                gp.fit(&x_train, &y_train);
                candidates
                    .into_iter()
                    .map(|arch| {
                        let feat = self.extractor.encode(&arch);
                        let (mu, sigma) = gp.predict(&feat);
                        let acq = self.acquisition_score(mu, sigma, best_score, &mut acq_rng);
                        (acq, arch)
                    })
                    .collect()
            }
            PredictorType::MlpRegressor => {
                // Clamp the hidden width to [4, 32] regardless of feature dimension.
                let hidden = 32_usize.min(feature_dim.max(4));
                let mut mlp = MlpPredictor::new(feature_dim, hidden, &mut self.rng);
                mlp.fit(&x_train, &y_train, 20, 1e-3);
                candidates
                    .into_iter()
                    .map(|arch| {
                        let feat = self.extractor.encode(&arch);
                        let (mu, sigma) = mlp.predict(&feat);
                        let acq = self.acquisition_score(mu, sigma, best_score, &mut acq_rng);
                        (acq, arch)
                    })
                    .collect()
            }
        };
        scored.sort_by(|(s1, _), (s2, _)| s2.partial_cmp(s1).unwrap_or(std::cmp::Ordering::Equal));
let proposals: Vec<Vec<usize>> = scored
.into_iter()
.take(n_proposals)
.map(|(_, arch)| arch)
.collect();
if proposals.is_empty() {
return Err(NeuralError::ComputationError(
"No proposals generated — check n_candidates_per_iter".to_string(),
));
}
Ok(proposals)
}
fn random_architectures(&mut self, n: usize) -> Vec<Vec<usize>> {
let n_edges = self.extractor.n_edges;
let n_ops = self.extractor.n_ops;
(0..n)
.map(|_| {
(0..n_edges)
.map(|_| self.rng.random_range(0..n_ops))
.collect()
})
.collect()
}
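    /// Scores a predictive `(mu, sigma)` pair under the configured
    /// acquisition function. Thompson sampling draws a standard normal
    /// sample via the Box-Muller transform.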
fn acquisition_score(&self, mu: f64, sigma: f64, best: f64, rng: &mut impl Rng) -> f64 {
match self.config.acquisition {
AcquisitionFunction::ExpectedImprovement => Self::expected_improvement(mu, sigma, best),
AcquisitionFunction::UpperConfidenceBound => mu + self.config.ucb_kappa * sigma,
AcquisitionFunction::ThompsonSampling => {
let u: f64 = rng.random();
let u2: f64 = rng.random();
let u_clamped = u.max(1e-40);
let normal_sample =
(-2.0 * u_clamped.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
mu + sigma * normal_sample
}
}
}
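    /// Expected improvement over the current best:
    /// `EI = (mu - best) * Phi(z) + sigma * phi(z)` with
    /// `z = (mu - best) / sigma`. For near-zero `sigma` this degenerates to
    /// `max(mu - best, 0)`.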
fn expected_improvement(mu: f64, sigma: f64, best: f64) -> f64 {
if sigma < 1e-6 {
return (mu - best).max(0.0);
}
let z = (mu - best) / sigma;
let phi_z = standard_normal_pdf(z);
let big_phi_z = standard_normal_cdf(z);
(mu - best) * big_phi_z + sigma * phi_z
}
}
fn standard_normal_pdf(z: f64) -> f64 {
let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
inv_sqrt_2pi * (-0.5 * z * z).exp()
}
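/// Standard normal CDF via the Abramowitz and Stegun 26.2.17 polynomial
/// approximation (absolute error below roughly 7.5e-8).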
fn standard_normal_cdf(z: f64) -> f64 {
let t = 1.0 / (1.0 + 0.2316419 * z.abs());
let poly = t
* (0.319_381_53
+ t * (-0.356_563_782
+ t * (1.781_477_937 + t * (-1.821_255_978 + t * 1.330_274_429))));
let phi = 1.0 - standard_normal_pdf(z) * poly;
if z >= 0.0 {
phi
} else {
1.0 - phi
}
}
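/// Cholesky factorization `A = L * L^T`, returning the lower-triangular `L`.
/// Diagonal entries are clamped to a small positive value so that nearly
/// singular (or numerically non-PSD) matrices still factor.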
fn cholesky_decompose(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
let n = a.len();
let mut l = vec![vec![0.0_f64; n]; n];
for i in 0..n {
for j in 0..=i {
let sum: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
if i == j {
let diag = a[i][i] - sum;
l[i][j] = diag.max(1e-12).sqrt();
} else {
l[i][j] = if l[j][j].abs() < 1e-12 {
0.0
} else {
(a[i][j] - sum) / l[j][j]
};
}
}
}
l
}
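/// Solves the lower-triangular system `L v = b` by forward substitution,
/// treating near-zero pivots as producing a zero component.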
fn forward_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
let n = l.len();
let mut v = vec![0.0_f64; n];
for i in 0..n {
let sum: f64 = (0..i).map(|k| l[i][k] * v[k]).sum();
v[i] = if l[i][i].abs() < 1e-12 {
0.0
} else {
(b[i] - sum) / l[i][i]
};
}
v
}
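/// Solves the upper-triangular system `L^T v = b` by back substitution.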
fn backward_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
let n = l.len();
let mut v = vec![0.0_f64; n];
for i in (0..n).rev() {
let sum: f64 = (i + 1..n).map(|k| l[k][i] * v[k]).sum();
v[i] = if l[i][i].abs() < 1e-12 {
0.0
} else {
(b[i] - sum) / l[i][i]
};
}
v
}
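/// Solves `A x = b` given the Cholesky factor `L` of `A`, via `L v = b`
/// followed by `L^T x = v`.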
fn cholesky_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
let v = forward_solve(l, b);
backward_solve(l, &v)
}
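/// Samples a `rows x cols` matrix of i.i.d. normal values (Box-Muller),
/// scaled by `scale`; callers pass a He-style `sqrt(2 / fan_in)` scale.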
fn random_matrix(rows: usize, cols: usize, rng: &mut impl Rng, scale: f64) -> Vec<Vec<f64>> {
(0..rows)
.map(|_| {
(0..cols)
.map(|_| {
let u: f64 = rng.random();
let u2: f64 = rng.random();
let u_clamped = u.max(1e-40);
let normal =
(-2.0 * u_clamped.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
scale * normal
})
.collect()
})
.collect()
}
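/// Samples a length-`len` vector of i.i.d. normal values (Box-Muller),
/// scaled by `scale`.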
fn random_vector(len: usize, rng: &mut impl Rng, scale: f64) -> Vec<f64> {
(0..len)
.map(|_| {
let u: f64 = rng.random();
let u2: f64 = rng.random();
let u_clamped = u.max(1e-40);
let normal = (-2.0 * u_clamped.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
scale * normal
})
.collect()
}
fn relu(x: f64) -> f64 {
x.max(0.0)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
fn make_rng(seed: u64) -> StdRng {
StdRng::seed_from_u64(seed)
}
#[test]
fn test_gdas_gumbel_top1_returns_valid_index() {
use crate::nas::gdas::{GdasConfig, GdasSearch};
let config = GdasConfig::default();
let search = GdasSearch::new(config.clone());
let mut rng = make_rng(0);
let logits = vec![1.0, 2.0, 0.5, 3.0, 1.5, 0.0, 2.5, 1.0];
let (idx, weights) = search.gumbel_top1_sample(&logits, &mut rng).unwrap();
assert!(idx < config.n_ops);
let sum: f64 = weights.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
assert_eq!(weights[idx], 1.0);
}
#[test]
fn test_gdas_anneal_temperature_decreases() {
use crate::nas::gdas::{GdasConfig, GdasSearch, TemperatureSchedule};
let config = GdasConfig {
schedule: TemperatureSchedule::Exponential,
..Default::default()
};
let mut search = GdasSearch::new(config);
let initial = search.temperature;
search.anneal_temperature(50);
assert!(search.temperature < initial);
}
#[test]
fn test_gdas_derive_architecture_argmax() {
use crate::nas::gdas::{GdasConfig, GdasSearch};
let config = GdasConfig {
n_nodes: 3,
n_ops: 4,
..GdasConfig::default()
};
let mut search = GdasSearch::new(config);
search.alpha[0] = vec![0.1, 0.2, 5.0, 0.0];
search.alpha[1] = vec![3.0, 0.1, 0.2, 0.0];
search.alpha[2] = vec![0.0, 0.1, 0.2, 7.0];
let arch = search.derive_architecture();
assert_eq!(arch, vec![2, 0, 3]);
}
#[test]
fn test_gdas_n_edges_correct() {
use crate::nas::gdas::{GdasConfig, GdasSearch};
let config = GdasConfig {
n_nodes: 4,
..GdasConfig::default()
};
let search = GdasSearch::new(config);
        assert_eq!(search.n_edges(), 6);
    }
#[test]
fn test_snas_gumbel_softmax_sums_to_one() {
use crate::nas::snas::{SnasConfig, SnasSearch};
let config = SnasConfig::default();
let search = SnasSearch::new(config);
let mut rng = make_rng(42);
let logits = vec![1.0, 0.5, -1.0, 2.0, 0.0, -0.5, 1.5, 0.3];
let weights = search
.gumbel_softmax_sample(&logits, 1.0, &mut rng)
.unwrap();
let sum: f64 = weights.iter().sum();
assert!((sum - 1.0).abs() < 1e-10);
}
#[test]
fn test_snas_kl_divergence_uniform_alpha_near_zero() {
use crate::nas::snas::{SnasConfig, SnasSearch};
let config = SnasConfig::default();
let search = SnasSearch::new(config);
let kl = search.kl_divergence_from_uniform(0).unwrap();
assert!(kl < 1e-10, "KL from uniform alpha should be ~0, got {kl}");
}
#[test]
fn test_snas_kl_divergence_peaked_positive() {
use crate::nas::snas::{SnasConfig, SnasSearch};
let config = SnasConfig {
n_nodes: 3,
n_ops: 4,
..SnasConfig::default()
};
let mut search = SnasSearch::new(config);
search.alpha[0] = vec![20.0, -10.0, -10.0, -10.0];
let kl = search.kl_divergence_from_uniform(0).unwrap();
assert!(
kl > 0.1,
"Peaked distribution should have positive KL, got {kl}"
);
}
#[test]
fn test_snas_architecture_entropy_uniform_max() {
use crate::nas::snas::{SnasConfig, SnasSearch};
let config = SnasConfig {
n_nodes: 3,
n_ops: 8,
..SnasConfig::default()
};
let search = SnasSearch::new(config);
let entropy = search.architecture_entropy();
let expected = (8.0_f64).ln();
assert!((entropy - expected).abs() < 1e-6);
}
#[test]
fn test_arch_feature_extractor_length() {
let extractor = ArchFeatureExtractor::new(6, 8);
assert_eq!(extractor.feature_dim(), 48);
let arch = vec![0, 1, 2, 3, 4, 5];
let features = extractor.encode(&arch);
assert_eq!(features.len(), 48);
}
#[test]
fn test_arch_feature_extractor_deterministic() {
let extractor = ArchFeatureExtractor::new(3, 4);
let arch = vec![1, 0, 3];
let f1 = extractor.encode(&arch);
let f2 = extractor.encode(&arch);
assert_eq!(f1, f2);
}
#[test]
fn test_arch_feature_extractor_distinct_architectures() {
let extractor = ArchFeatureExtractor::new(3, 4);
let arch1 = vec![0, 1, 2];
let arch2 = vec![1, 0, 2];
let f1 = extractor.encode(&arch1);
let f2 = extractor.encode(&arch2);
assert_ne!(
f1, f2,
"Different architectures should produce different features"
);
}
#[test]
fn test_arch_feature_extractor_one_hot_property() {
let extractor = ArchFeatureExtractor::new(4, 5);
let arch = vec![2, 0, 4, 1];
let features = extractor.encode(&arch);
for edge in 0..4 {
let block = &features[edge * 5..(edge + 1) * 5];
let sum: f64 = block.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"One-hot property failed at edge {edge}"
);
}
}
#[test]
fn test_gp_predictor_fit_and_predict() {
let mut gp = GpPredictor::new(1.0, 0.01);
let x = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
];
let y = vec![0.0, 1.0, 1.0, 2.0];
gp.fit(&x, &y);
let (mean, std) = gp.predict(&[0.5, 0.5]);
assert!(std > 0.0, "Standard deviation should be positive");
assert!(
mean > -0.5 && mean < 2.5,
"Mean {mean} should be in plausible range"
);
}
#[test]
fn test_gp_predictor_variance_at_training_points_small() {
let mut gp = GpPredictor::new(1.0, 1e-6);
let x = vec![vec![0.0], vec![1.0], vec![2.0]];
let y = vec![1.0, 2.0, 3.0];
gp.fit(&x, &y);
let (_, std_at_train) = gp.predict(&[0.0]);
assert!(
std_at_train < 0.5,
"Variance at training point should be small, got {std_at_train}"
);
}
#[test]
fn test_mlp_predictor_overfit() {
let mut rng = make_rng(0);
let input_dim = 4;
let mut mlp = MlpPredictor::new(input_dim, 16, &mut rng);
let x = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
];
let y = vec![1.0, 2.0, 3.0];
mlp.fit(&x, &y, 200, 0.01);
for (xi, &yi) in x.iter().zip(y.iter()) {
let (pred, _) = mlp.predict(xi);
assert!(
(pred - yi).abs() < 1.0,
"MLP should approximately fit training data: pred={pred:.3}, target={yi}"
);
}
}
#[test]
fn test_expected_improvement_at_best_zero() {
let ei = PredictorNasSearch::expected_improvement(1.0, 1e-9, 1.0);
assert!(
ei.abs() < 1e-6,
"EI at current best with tiny sigma should be ~0, got {ei}"
);
}
#[test]
fn test_expected_improvement_positive_with_uncertainty() {
let ei = PredictorNasSearch::expected_improvement(2.0, 0.5, 1.0);
assert!(ei > 0.0, "EI should be positive when mu > best, got {ei}");
}
#[test]
fn test_expected_improvement_zero_sigma_clamp() {
let ei_above = PredictorNasSearch::expected_improvement(1.5, 1e-8, 1.0);
assert!(
(ei_above - 0.5).abs() < 1e-6,
"Expected 0.5, got {ei_above}"
);
let ei_below = PredictorNasSearch::expected_improvement(0.5, 1e-8, 1.0);
assert!(
ei_below == 0.0,
"EI should be 0 when mu < best with tiny sigma"
);
}
#[test]
fn test_predictor_nas_record_evaluation() {
let config = PredictorNasConfig::default();
let mut search = PredictorNasSearch::new(config, 6, 8);
assert_eq!(search.n_evaluated(), 0);
search.record_evaluation(vec![0, 1, 2, 3, 4, 5], 0.75);
assert_eq!(search.n_evaluated(), 1);
}
#[test]
fn test_predictor_nas_best_architecture() {
let config = PredictorNasConfig::default();
let mut search = PredictorNasSearch::new(config, 3, 4);
search.record_evaluation(vec![0, 1, 2], 0.80);
search.record_evaluation(vec![1, 0, 3], 0.95);
search.record_evaluation(vec![2, 2, 1], 0.70);
let (best_arch, best_score) = search.best_architecture().unwrap();
assert!((best_score - 0.95).abs() < 1e-10);
assert_eq!(best_arch, &vec![1, 0, 3]);
}
#[test]
fn test_predictor_nas_propose_returns_valid_arches() {
let config = PredictorNasConfig {
predictor_type: PredictorType::GaussianProcess,
n_candidates_per_iter: 20,
..PredictorNasConfig::default()
};
let mut search = PredictorNasSearch::new(config, 4, 5);
for i in 0..5 {
let arch: Vec<usize> = (0..4).map(|j| (i + j) % 5).collect();
search.record_evaluation(arch, 0.5 + i as f64 * 0.05);
}
let proposals = search.propose_next_architectures(3).unwrap();
assert_eq!(proposals.len(), 3);
for arch in &proposals {
assert_eq!(arch.len(), 4);
for &op in arch {
assert!(op < 5, "Operation index {op} out of range for n_ops=5");
}
}
}
#[test]
fn test_predictor_nas_n_evaluated_increases() {
let config = PredictorNasConfig::default();
let mut search = PredictorNasSearch::new(config, 3, 4);
for i in 0..10 {
search.record_evaluation(vec![i % 4, (i + 1) % 4, (i + 2) % 4], i as f64 * 0.1);
}
assert_eq!(search.n_evaluated(), 10);
}
#[test]
fn test_predictor_nas_no_data_returns_random() {
let config = PredictorNasConfig {
n_candidates_per_iter: 10,
..PredictorNasConfig::default()
};
let mut search = PredictorNasSearch::new(config, 4, 5);
let proposals = search.propose_next_architectures(5).unwrap();
assert_eq!(proposals.len(), 5);
}
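    // A hedged end-to-end sketch of the intended search loop: the synthetic
    // objective below (fraction of edges assigned op 0) is an illustrative
    // stand-in for a real architecture evaluation, not part of the API.
    #[test]
    fn test_predictor_nas_propose_evaluate_loop() {
        let config = PredictorNasConfig {
            n_candidates_per_iter: 10,
            ..PredictorNasConfig::default()
        };
        let mut search = PredictorNasSearch::new(config, 4, 3);
        // Synthetic score: reward architectures that choose op 0 on more edges.
        let score = |arch: &[usize]| {
            arch.iter().filter(|&&op| op == 0).count() as f64 / arch.len() as f64
        };
        for round in 0..3 {
            // The first round falls back to random sampling (< 2 evaluations);
            // later rounds go through the GP predictor and acquisition scores.
            let proposals = search.propose_next_architectures(2).unwrap();
            for arch in proposals {
                let s = score(&arch);
                search.record_evaluation(arch, s);
            }
            assert_eq!(search.n_evaluated(), (round + 1) * 2);
        }
        let (best, best_score) = search.best_architecture().unwrap();
        assert_eq!(best.len(), 4);
        assert!((0.0..=1.0).contains(&best_score));
    }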
}