pub mod gdas;
pub mod predictor_nas;
pub mod snas;
use crate::error::{OptimizeError, OptimizeResult};
/// A candidate operation placed on a DARTS cell edge.
///
/// Marked `#[non_exhaustive]` so new operations can be added without a
/// breaking change for downstream matchers.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Pass-through: output equals input.
    Identity,
    /// Emits zero, effectively pruning the edge.
    Zero,
    /// 3x3 convolution.
    Conv3x3,
    /// 5x5 convolution.
    Conv5x5,
    /// Max pooling.
    MaxPool,
    /// Average pooling.
    AvgPool,
    /// Skip (residual-style) connection.
    SkipConnect,
}

impl Operation {
    /// Rough FLOP estimate for one application of this operation at the
    /// given channel count: convolutions scale quadratically with channels,
    /// pooling linearly, and pass-through ops are free.
    pub fn cost_flops(&self, channels: usize) -> f64 {
        let c = channels as f64;
        match self {
            Operation::Identity | Operation::Zero | Operation::SkipConnect => 0.0,
            // 2 * k*k multiply-adds per output channel pair (k = kernel size).
            Operation::Conv3x3 => 18.0 * c * c,
            Operation::Conv5x5 => 50.0 * c * c,
            Operation::MaxPool | Operation::AvgPool => c,
        }
    }

    /// Stable, lowercase identifier for logging and serialization.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Identity => "identity",
            Self::Zero => "zero",
            Self::Conv3x3 => "conv3x3",
            Self::Conv5x5 => "conv5x5",
            Self::MaxPool => "max_pool",
            Self::AvgPool => "avg_pool",
            Self::SkipConnect => "skip_connect",
        }
    }

    /// The default searchable operation set, in index order.
    ///
    /// NOTE(review): `SkipConnect` is absent here, which matches the default
    /// `DartsConfig { n_operations: 6, .. }` search space — confirm the
    /// omission is intentional, since the variant is otherwise unreachable
    /// from `derive_discrete_arch`.
    pub fn all() -> &'static [Operation] {
        const OPS: [Operation; 6] = [
            Operation::Identity,
            Operation::Zero,
            Operation::Conv3x3,
            Operation::Conv5x5,
            Operation::MaxPool,
            Operation::AvgPool,
        ];
        &OPS
    }
}
/// Hyper-parameters controlling a DARTS search run.
#[derive(Debug, Clone)]
pub struct DartsConfig {
    /// Number of stacked cells in the super-network.
    pub n_cells: usize,
    /// Number of candidate operations per edge.
    pub n_operations: usize,
    /// Channel count (used e.g. for FLOP cost estimates).
    pub channels: usize,
    /// Intermediate nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for the architecture parameters.
    pub arch_lr: f64,
    /// Learning rate for the network weights.
    pub weight_lr: f64,
    /// Softmax temperature used when mixing operations.
    pub temperature: f64,
}

impl Default for DartsConfig {
    /// Small, cheap defaults suitable for examples and smoke tests.
    fn default() -> Self {
        DartsConfig {
            n_nodes: 4,
            n_cells: 4,
            channels: 16,
            n_operations: 6,
            temperature: 1.0,
            arch_lr: 3e-4,
            weight_lr: 3e-4,
        }
    }
}
/// Minimal linear congruential generator: deterministic, dependency-free
/// pseudo-randomness for internal use.
pub(crate) struct Lcg {
    /// Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg {
    /// Creates a generator from an arbitrary seed (any value is valid).
    pub(crate) fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advances the state with Knuth's MMIX multiplier/increment and maps
    /// the top 53 bits of the new state to a uniform value in `[0, 1)`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(INC);
        // Keep the 53 high-quality bits that fit an f64 mantissa exactly;
        // dividing by 2^53 (an exact power of two) is bit-identical to
        // multiplying by its reciprocal.
        let bits = self.state >> 11;
        bits as f64 / (1u64 << 53) as f64
    }
}
/// How the temperature decays from `initial` to `final_temp`.
#[derive(Debug, Clone, PartialEq)]
pub enum AnnealingStrategy {
    /// Straight-line interpolation.
    Linear,
    /// Geometric (constant-ratio) decay.
    Exponential,
    /// Half-cosine decay.
    Cosine,
}

/// A temperature annealing schedule over a fixed number of steps.
#[derive(Debug, Clone)]
pub struct TemperatureSchedule {
    /// Temperature at step 0.
    pub initial: f64,
    /// Temperature at `total_steps` and beyond.
    pub final_temp: f64,
    /// Decay shape.
    pub strategy: AnnealingStrategy,
    /// Number of steps over which to anneal; 0 means "always final".
    pub total_steps: usize,
}

impl TemperatureSchedule {
    /// Builds a schedule from its four components.
    pub fn new(
        initial: f64,
        final_temp: f64,
        strategy: AnnealingStrategy,
        total_steps: usize,
    ) -> Self {
        Self {
            initial,
            final_temp,
            strategy,
            total_steps,
        }
    }

    /// Temperature at `step`, clamped to `total_steps`.
    ///
    /// With `total_steps == 0` the schedule is treated as already finished
    /// and `final_temp` is returned. The exponential strategy falls back to
    /// `final_temp` whenever either endpoint is non-positive (a geometric
    /// interpolation would be undefined there).
    pub fn temperature_at(&self, step: usize) -> f64 {
        let frac = if self.total_steps == 0 {
            1.0
        } else {
            step.min(self.total_steps) as f64 / self.total_steps as f64
        };
        let (t0, t1) = (self.initial, self.final_temp);
        match self.strategy {
            AnnealingStrategy::Linear => t0 + (t1 - t0) * frac,
            AnnealingStrategy::Exponential => {
                if t0 <= 0.0 || t1 <= 0.0 {
                    t1
                } else {
                    t0 * (t1 / t0).powf(frac)
                }
            }
            AnnealingStrategy::Cosine => {
                let cosine = (std::f64::consts::PI * frac).cos();
                t1 + 0.5 * (t0 - t1) * (1.0 + cosine)
            }
        }
    }
}
/// A DARTS mixed operation: a softmax-weighted blend of candidate ops on
/// one edge, parameterised by per-op architecture logits.
#[derive(Debug, Clone)]
pub struct MixedOperation {
    /// One logit per candidate operation.
    pub arch_params: Vec<f64>,
    /// Per-op outputs cached by the most recent `forward` call.
    pub operation_outputs: Option<Vec<Vec<f64>>>,
}

impl MixedOperation {
    /// Creates a mix over `n_ops` candidates with all logits at zero
    /// (i.e. an initially uniform softmax).
    pub fn new(n_ops: usize) -> Self {
        MixedOperation {
            arch_params: vec![0.0_f64; n_ops],
            operation_outputs: None,
        }
    }

    /// Temperature-scaled softmax over the logits.
    ///
    /// The temperature is clamped to at least 1e-8 to avoid division by
    /// zero; a degenerate all-zero exponential sum falls back to a uniform
    /// distribution.
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        let t = temperature.max(1e-8);
        let scaled: Vec<f64> = self.arch_params.iter().map(|&a| a / t).collect();
        // Subtract the max before exponentiating for numerical stability.
        let max_val = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut exps = Vec::with_capacity(scaled.len());
        let mut sum = 0.0_f64;
        for s in &scaled {
            let e = (s - max_val).exp();
            exps.push(e);
            sum += e;
        }
        if sum == 0.0 {
            let uniform = 1.0 / self.arch_params.len() as f64;
            return vec![uniform; self.arch_params.len()];
        }
        exps.into_iter().map(|e| e / sum).collect()
    }

    /// Evaluates every candidate op on `x` via `op_fn(k, x)` and returns
    /// the softmax-weighted sum; the raw per-op outputs are cached in
    /// `operation_outputs`. Output length follows the first op's output
    /// (falling back to `x.len()` when there are no ops).
    pub fn forward(
        &mut self,
        x: &[f64],
        op_fn: impl Fn(usize, &[f64]) -> Vec<f64>,
        temperature: f64,
    ) -> Vec<f64> {
        let w = self.weights(temperature);
        let outputs: Vec<Vec<f64>> = (0..self.arch_params.len()).map(|k| op_fn(k, x)).collect();
        let out_len = outputs.first().map(|v| v.len()).unwrap_or(x.len());
        let mut mixed = vec![0.0_f64; out_len];
        for (weight, out) in w.iter().zip(outputs.iter()) {
            for (acc, val) in mixed.iter_mut().zip(out.iter()) {
                *acc += weight * val;
            }
        }
        self.operation_outputs = Some(outputs);
        mixed
    }

    /// Index of the largest logit (the op a discretised architecture would
    /// keep); 0 when there are no ops.
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i)
    }
}
/// One DARTS cell: a small DAG whose edges carry mixed operations.
#[derive(Debug, Clone)]
pub struct DartsCell {
/// Number of intermediate (computed) nodes in the cell.
pub n_nodes: usize,
/// Number of input nodes feeding the cell.
pub n_input_nodes: usize,
/// `edges[i][j]` is the mixed op on the edge from predecessor `j` to
/// intermediate node `i`; node `i` has `n_input_nodes + i` predecessors
/// (the cell inputs plus all earlier intermediate nodes).
pub edges: Vec<Vec<MixedOperation>>,
}
impl DartsCell {
pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
let edges: Vec<Vec<MixedOperation>> = (0..n_intermediate_nodes)
.map(|i| {
let n_predecessors = n_input_nodes + i;
(0..n_predecessors)
.map(|_| MixedOperation::new(n_ops))
.collect()
})
.collect();
Self {
n_nodes: n_intermediate_nodes,
n_input_nodes,
edges,
}
}
pub fn forward(&mut self, inputs: &[Vec<f64>], temperature: f64) -> Vec<f64> {
if inputs.is_empty() {
return Vec::new();
}
let feature_len = inputs[0].len();
let mut node_outputs: Vec<Vec<f64>> = inputs.to_vec();
for i in 0..self.n_nodes {
let n_prev = self.n_input_nodes + i;
let mut node_out = vec![0.0_f64; feature_len];
for j in 0..n_prev {
let src = node_outputs[j].clone();
let edge_out = self.edges[i][j].forward(&src, default_op_fn, temperature);
for (no, eo) in node_out.iter_mut().zip(edge_out.iter()) {
*no += eo;
}
}
node_outputs.push(node_out);
}
let mut result = Vec::with_capacity(self.n_nodes * feature_len);
for node_out in node_outputs.iter().skip(self.n_input_nodes) {
result.extend_from_slice(node_out);
}
result
}
pub fn arch_parameters(&self) -> Vec<f64> {
self.edges
.iter()
.flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
.collect()
}
pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
let n_params: usize = self
.edges
.iter()
.flat_map(|row| row.iter())
.map(|mo| mo.arch_params.len())
.sum();
if grads.len() != n_params {
return Err(OptimizeError::InvalidInput(format!(
"Expected {} gradient values, got {}",
n_params,
grads.len()
)));
}
let mut idx = 0;
for row in self.edges.iter_mut() {
for mo in row.iter_mut() {
for p in mo.arch_params.iter_mut() {
*p -= lr * grads[idx];
idx += 1;
}
}
}
Ok(())
}
pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
self.edges
.iter()
.map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
.collect()
}
}
/// Placeholder per-edge operation callback used by `DartsCell::forward`:
/// every candidate op index returns the input unchanged (identity).
fn default_op_fn(_k: usize, x: &[f64]) -> Vec<f64> {
    Vec::from(x)
}
/// Differentiable architecture search over a stack of DARTS cells.
#[derive(Debug, Clone)]
pub struct DartsSearch {
/// The searchable cells; each carries its own architecture logits.
pub cells: Vec<DartsCell>,
/// Search hyper-parameters.
pub config: DartsConfig,
/// Scalar surrogate "network weights", one per cell, updated on the
/// training split by `bilevel_step`.
weights: Vec<f64>,
}
impl DartsSearch {
pub fn new(config: DartsConfig) -> Self {
let cells: Vec<DartsCell> = (0..config.n_cells)
.map(|_| DartsCell::new(2, config.n_nodes, config.n_operations))
.collect();
let weights = vec![0.01_f64; config.n_cells];
Self {
cells,
config,
weights,
}
}
pub fn arch_parameters(&self) -> Vec<f64> {
self.cells
.iter()
.flat_map(|c| c.arch_parameters())
.collect()
}
pub fn n_arch_params(&self) -> usize {
self.cells.iter().map(|c| c.arch_parameters().len()).sum()
}
pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
let total = self.n_arch_params();
if grads.len() != total {
return Err(OptimizeError::InvalidInput(format!(
"Expected {} arch-param grads, got {}",
total,
grads.len()
)));
}
let mut offset = 0;
for cell in self.cells.iter_mut() {
let n = cell.arch_parameters().len();
cell.update_arch_params(&grads[offset..offset + n], lr)?;
offset += n;
}
Ok(())
}
pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
self.cells.iter().map(|c| c.derive_discrete()).collect()
}
pub fn derive_discrete_arch(&self) -> Vec<Vec<Operation>> {
let ops = Operation::all();
self.derive_discrete_arch_indices()
.iter()
.map(|cell_disc| {
cell_disc
.iter()
.flat_map(|node_edges| {
node_edges.iter().map(|&idx| {
if idx < ops.len() {
ops[idx]
} else {
Operation::Identity
}
})
})
.collect()
})
.collect()
}
fn compute_loss(&self, x: &[Vec<f64>], y: &[f64]) -> f64 {
if x.is_empty() || y.is_empty() {
return 0.0;
}
let w_sum: f64 = self.weights.iter().sum();
let mut loss = 0.0_f64;
let n = x.len().min(y.len());
for i in 0..n {
let x_mean = if x[i].is_empty() {
0.0
} else {
x[i].iter().sum::<f64>() / x[i].len() as f64
};
let pred = w_sum * x_mean;
let diff = pred - y[i];
loss += diff * diff;
}
loss / n as f64
}
fn weight_grads(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
let n = x.len().min(y.len());
if n == 0 {
return vec![0.0_f64; self.weights.len()];
}
let w_sum: f64 = self.weights.iter().sum();
let mut grad_sum = 0.0_f64;
for i in 0..n {
let x_mean = if x[i].is_empty() {
0.0
} else {
x[i].iter().sum::<f64>() / x[i].len() as f64
};
let pred = w_sum * x_mean;
let diff = pred - y[i];
grad_sum += 2.0 * diff * x_mean / n as f64;
}
vec![grad_sum; self.weights.len()]
}
fn arch_grads_fd(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
let n = self.n_arch_params();
if n == 0 {
return Vec::new();
}
let mut grads = vec![0.0_f64; n];
let h = 1e-4;
let mut offset = 0;
for cell_idx in 0..self.cells.len() {
let cell_n = self.cells[cell_idx].arch_parameters().len();
for local_j in 0..cell_n {
let global_j = offset + local_j;
let mut search_plus = self.clone();
let params_plus = search_plus.cells[cell_idx].arch_parameters();
let mut p_plus = params_plus.clone();
p_plus[local_j] += h;
let _ = search_plus.cells[cell_idx].set_arch_params(&p_plus);
let loss_plus = search_plus.compute_loss(x, y);
let mut search_minus = self.clone();
let params_minus = search_minus.cells[cell_idx].arch_parameters();
let mut p_minus = params_minus.clone();
p_minus[local_j] -= h;
let _ = search_minus.cells[cell_idx].set_arch_params(&p_minus);
let loss_minus = search_minus.compute_loss(x, y);
grads[global_j] = (loss_plus - loss_minus) / (2.0 * h);
}
offset += cell_n;
}
grads
}
pub fn bilevel_step(
&mut self,
train_x: &[Vec<f64>],
train_y: &[f64],
val_x: &[Vec<f64>],
val_y: &[f64],
) -> (f64, f64) {
let train_loss = self.compute_loss(train_x, train_y);
let val_loss = self.compute_loss(val_x, val_y);
let w_grads = self.weight_grads(train_x, train_y);
let lr_w = self.config.weight_lr;
for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
*w -= lr_w * g;
}
let a_grads = self.arch_grads_fd(val_x, val_y);
let lr_a = self.config.arch_lr;
if !a_grads.is_empty() {
let _ = self.update_arch_params(&a_grads, lr_a);
}
(train_loss, val_loss)
}
}
impl DartsCell {
    /// Overwrites every edge's architecture logits from a flat slice, using
    /// the same ordering as [`DartsCell::arch_parameters`].
    ///
    /// # Errors
    /// Returns `InvalidInput` when `params.len()` does not match the number
    /// of architecture parameters in the cell.
    pub fn set_arch_params(&mut self, params: &[f64]) -> OptimizeResult<()> {
        let total: usize = self
            .edges
            .iter()
            .flatten()
            .map(|m| m.arch_params.len())
            .sum();
        if params.len() != total {
            return Err(OptimizeError::InvalidInput(format!(
                "set_arch_params: expected {total} values, got {}",
                params.len()
            )));
        }
        // Peel off each edge's chunk of the flat parameter slice in order.
        let mut remaining = params;
        for mo in self.edges.iter_mut().flatten() {
            let n = mo.arch_params.len();
            mo.arch_params.copy_from_slice(&remaining[..n]);
            remaining = &remaining[n..];
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Zero logits must produce a uniform softmax that sums to exactly 1.
#[test]
fn mixed_operation_weights_sum_to_one() {
let mo = MixedOperation::new(6);
let w = mo.weights(1.0);
assert_eq!(w.len(), 6);
let sum: f64 = w.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "weights sum = {sum}");
}
// Lower temperature should sharpen the distribution toward the top logit.
#[test]
fn mixed_operation_weights_temperature_effect() {
let mut mo = MixedOperation::new(4);
mo.arch_params = vec![1.0, 0.5, 0.3, 0.2];
let w_hot = mo.weights(10.0);
let w_cold = mo.weights(0.1);
assert!(w_cold[0] > w_hot[0], "cold should be sharper");
}
// Mixed forward with identity ops preserves the input length.
#[test]
fn mixed_operation_forward_correct_shape() {
let mut mo = MixedOperation::new(3);
let x = vec![1.0_f64; 8];
let out = mo.forward(&x, |_k, v| v.to_vec(), 1.0);
assert_eq!(out.len(), 8);
}
// 3 intermediate nodes x 8 features = 24 output values (inputs excluded).
#[test]
fn darts_cell_forward_output_shape() {
let mut cell = DartsCell::new(2, 3, 4);
let inputs = vec![vec![1.0_f64; 8], vec![0.5_f64; 8]];
let out = cell.forward(&inputs, 1.0);
assert_eq!(out.len(), 24);
}
// Discretisation yields one non-empty op list per cell.
#[test]
fn derive_discrete_arch_returns_ops() {
let config = DartsConfig {
n_cells: 2,
n_operations: 6,
n_nodes: 3,
..Default::default()
};
let search = DartsSearch::new(config);
let arch = search.derive_discrete_arch();
assert_eq!(arch.len(), 2, "one vec per cell");
for cell_ops in &arch {
assert!(!cell_ops.is_empty());
}
}
// Smoke test: a full bilevel step produces finite losses.
#[test]
fn bilevel_step_runs_without_error() {
let config = DartsConfig::default();
let mut search = DartsSearch::new(config);
let train_x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let train_y = vec![1.5, 3.5];
let val_x = vec![vec![0.5, 1.5]];
let val_y = vec![1.0];
let (tl, vl) = search.bilevel_step(&train_x, &train_y, &val_x, &val_y);
assert!(tl.is_finite());
assert!(vl.is_finite());
}
// Flattened parameter vector length must agree with the reported count.
#[test]
fn arch_parameters_length_consistent() {
let config = DartsConfig {
n_cells: 3,
n_operations: 5,
n_nodes: 2,
..Default::default()
};
let search = DartsSearch::new(config);
let params = search.arch_parameters();
assert_eq!(params.len(), search.n_arch_params());
}
// A gradient slice of the wrong length must be rejected.
#[test]
fn update_arch_params_wrong_length_errors() {
let mut search = DartsSearch::new(DartsConfig::default());
let result = search.update_arch_params(&[1.0, 2.0], 0.01);
assert!(result.is_err());
}
// Linear schedule: exact start, midpoint, and end values.
#[test]
fn temperature_schedule_linear_bounds() {
let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Linear, 100);
let t0 = sched.temperature_at(0);
let t_half = sched.temperature_at(50);
let t_end = sched.temperature_at(100);
assert!((t0 - 10.0).abs() < 1e-10, "t0={t0}");
assert!((t_half - 5.5).abs() < 1e-10, "t_half={t_half}");
assert!((t_end - 1.0).abs() < 1e-10, "t_end={t_end}");
}
// Exponential schedule: correct endpoints, monotone interior point.
#[test]
fn temperature_schedule_exponential_bounds() {
let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Exponential, 100);
let t0 = sched.temperature_at(0);
let t_end = sched.temperature_at(100);
assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
let t_mid = sched.temperature_at(50);
assert!(t_mid > 1.0 && t_mid < 10.0, "t_mid={t_mid}");
}
// Cosine schedule: correct endpoints.
#[test]
fn temperature_schedule_cosine_bounds() {
let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Cosine, 100);
let t0 = sched.temperature_at(0);
let t_end = sched.temperature_at(100);
assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
}
// Steps past total_steps are clamped to the final temperature.
#[test]
fn temperature_schedule_clamped_beyond_total() {
let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 10);
let t_over = sched.temperature_at(999);
let t_end = sched.temperature_at(10);
assert!((t_over - t_end).abs() < 1e-10);
}
// total_steps == 0 means the schedule is already at final_temp.
#[test]
fn temperature_schedule_zero_steps() {
let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 0);
let t = sched.temperature_at(0);
assert!((t - 1.0).abs() < 1e-10, "t={t}");
}
}