use super::{AnnealingStrategy, Lcg, TemperatureSchedule};
use crate::error::{OptimizeError, OptimizeResult};
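/// Configuration for GDAS (Gumbel-softmax Differentiable Architecture Search).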
#[derive(Debug, Clone)]
pub struct GdasConfig {
    /// Number of stacked cells in the searched network.
    pub n_cells: usize,
    /// Number of candidate operations per edge.
    pub n_operations: usize,
    /// Channel count, kept for parity with the full network (unused by the search math).
    pub channels: usize,
    /// Number of intermediate nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for the architecture logits.
    pub arch_lr: f64,
    /// Learning rate for the network weights.
    pub weight_lr: f64,
    /// Annealing schedule for the Gumbel-softmax temperature.
    pub temperature_schedule: TemperatureSchedule,
    /// RNG seed for reproducible sampling.
    pub seed: u64,
}
impl Default for GdasConfig {
    fn default() -> Self {
        Self {
            n_cells: 3,
            n_operations: 6,
            channels: 32,
            n_nodes: 4,
            arch_lr: 3e-4,
            weight_lr: 1e-3,
            temperature_schedule: TemperatureSchedule::new(
                1.0,
                0.1,
                AnnealingStrategy::Exponential,
                100,
            ),
            seed: 42,
        }
    }
}
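/// A mixed operation on a single edge: one architecture logit per candidate
/// operation. During search, GDAS draws a hard one-hot sample from the
/// Gumbel-softmax over these logits while keeping the soft weights around for
/// the backward pass.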
#[derive(Debug, Clone)]
pub struct GdasMixedOperation {
    /// Architecture logits, one per candidate operation.
    pub arch_params: Vec<f64>,
    /// Index selected by the most recent Gumbel-softmax sample.
    pub last_selected: usize,
    /// Soft weights from the most recent Gumbel-softmax sample.
    pub last_soft_weights: Vec<f64>,
}
impl GdasMixedOperation {
    pub fn new(n_ops: usize) -> Self {
        Self {
            arch_params: vec![0.0_f64; n_ops],
            last_selected: 0,
            last_soft_weights: vec![1.0 / n_ops as f64; n_ops],
        }
    }
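    /// Draws one Gumbel-softmax sample: perturbs each logit with Gumbel(0, 1)
    /// noise `g_k = -ln(-ln(u_k))`, applies a numerically stable softmax at the
    /// given temperature, and returns the argmax index (the "hard" choice)
    /// together with the full soft distribution. The result is also cached in
    /// `last_selected` / `last_soft_weights`.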
    pub fn gumbel_softmax_sample(&mut self, temperature: f64, rng: &mut Lcg) -> (usize, Vec<f64>) {
        let eps = 1e-20_f64;
        let temp = temperature.max(1e-8);
        let n = self.arch_params.len();
        let mut logits = vec![0.0_f64; n];
        for k in 0..n {
            // Gumbel(0, 1) noise from a uniform draw; the clamp guards against
            // ln(0) should the generator return exactly zero.
            let u = rng.next_f64().max(eps);
            let gumbel_noise = -(-u.ln()).ln();
            logits[k] = self.arch_params[k] + gumbel_noise;
        }
        // Numerically stable softmax over the temperature-scaled logits.
        let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let mut exp_vals: Vec<f64> = logits.iter().map(|&l| ((l - max_l) / temp).exp()).collect();
        let sum = exp_vals.iter().sum::<f64>().max(eps);
        for v in &mut exp_vals {
            *v /= sum;
        }
        let selected = exp_vals
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        // Cache the sample so callers can inspect the last hard/soft decision.
        self.last_selected = selected;
        self.last_soft_weights = exp_vals.clone();
        (selected, exp_vals)
    }
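    /// Deterministic softmax of the logits at the given temperature (no Gumbel
    /// noise); lower temperatures yield sharper distributions.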
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        let t = temperature.max(1e-8);
        let scaled: Vec<f64> = self.arch_params.iter().map(|a| a / t).collect();
        let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scaled.iter().map(|s| (s - max_val).exp()).collect();
        let sum: f64 = exps.iter().sum();
        if sum == 0.0 {
            // Defensive fallback (empty or non-finite logits): uniform weights.
            let n = self.arch_params.len();
            vec![1.0 / n as f64; n]
        } else {
            exps.iter().map(|e| e / sum).collect()
        }
    }
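    /// Index of the operation with the largest logit; used to discretize the edge.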
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
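    /// One SGD step on the logits: `p -= lr * g`, pairing grads with params
    /// positionally (extra entries on either side are ignored by `zip`).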
    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) {
        for (p, g) in self.arch_params.iter_mut().zip(grads.iter()) {
            *p -= lr * g;
        }
    }
}
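/// A DARTS-style cell: intermediate node `i` receives one mixed operation from
/// each of its `n_input_nodes + i` predecessors.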
#[derive(Debug, Clone)]
pub struct GdasCell {
    /// Number of intermediate nodes.
    pub n_nodes: usize,
    /// Number of cell input nodes (two in the standard DARTS/GDAS topology).
    pub n_input_nodes: usize,
    /// `edges[i][j]` is the mixed operation on the edge from predecessor `j`
    /// to intermediate node `i`.
    pub edges: Vec<Vec<GdasMixedOperation>>,
}
impl GdasCell {
    pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
        let edges: Vec<Vec<GdasMixedOperation>> = (0..n_intermediate_nodes)
            .map(|i| {
                // Node i sees every cell input plus every earlier intermediate node.
                let n_predecessors = n_input_nodes + i;
                (0..n_predecessors)
                    .map(|_| GdasMixedOperation::new(n_ops))
                    .collect()
            })
            .collect();
        Self {
            n_nodes: n_intermediate_nodes,
            n_input_nodes,
            edges,
        }
    }
    /// All architecture logits of the cell, flattened in edge order.
    pub fn arch_parameters(&self) -> Vec<f64> {
        self.edges
            .iter()
            .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
            .collect()
    }
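    /// Applies one gradient step to every edge, consuming `grads` in the same
    /// flattened order as `arch_parameters`; errors on a length mismatch.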
    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
        let n_params: usize = self
            .edges
            .iter()
            .flat_map(|row| row.iter())
            .map(|mo| mo.arch_params.len())
            .sum();
        if grads.len() != n_params {
            return Err(OptimizeError::InvalidInput(format!(
                "GdasCell::update_arch_params: expected {n_params} grads, got {}",
                grads.len()
            )));
        }
        let mut idx = 0;
        for row in self.edges.iter_mut() {
            for mo in row.iter_mut() {
                let n = mo.arch_params.len();
                mo.update_arch_params(&grads[idx..idx + n], lr);
                idx += n;
            }
        }
        Ok(())
    }
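    /// Discretizes the cell: for every edge, the argmax operation index.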
    pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
        self.edges
            .iter()
            .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
            .collect()
    }
}
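/// Top-level GDAS searcher: a stack of cells plus a small per-cell weight
/// vector standing in for the network weights of the full algorithm.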
pub struct GdasSearch {
    pub cells: Vec<GdasCell>,
    pub config: GdasConfig,
    /// One scalar weight per cell: a lightweight stand-in for the network weights.
    weights: Vec<f64>,
    /// Deterministic RNG seeded from the config.
    rng: Lcg,
    current_step: usize,
}
impl GdasSearch {
    pub fn new(config: GdasConfig) -> Self {
        // Two input nodes per cell, matching the standard DARTS/GDAS topology.
        let cells: Vec<GdasCell> = (0..config.n_cells)
            .map(|_| GdasCell::new(2, config.n_nodes, config.n_operations))
            .collect();
        let weights = vec![0.01_f64; config.n_cells];
        let rng = Lcg::new(config.seed);
        Self {
            cells,
            config,
            weights,
            rng,
            current_step: 0,
        }
    }
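    /// Gumbel-softmax temperature at the current step, per the configured schedule.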
    pub fn current_temperature(&self) -> f64 {
        self.config
            .temperature_schedule
            .temperature_at(self.current_step)
    }
    /// All architecture logits across cells, flattened in cell order.
    pub fn arch_parameters(&self) -> Vec<f64> {
        self.cells
            .iter()
            .flat_map(|c| c.arch_parameters())
            .collect()
    }
    pub fn n_arch_params(&self) -> usize {
        self.cells.iter().map(|c| c.arch_parameters().len()).sum()
    }
    /// One gradient step on all architecture logits; `grads` must match
    /// `n_arch_params()` in length and `arch_parameters()` in order.
    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
        let total = self.n_arch_params();
        if grads.len() != total {
            return Err(OptimizeError::InvalidInput(format!(
                "GdasSearch::update_arch_params: expected {total} grads, got {}",
                grads.len()
            )));
        }
        let mut offset = 0;
        for cell in self.cells.iter_mut() {
            let n = cell.arch_parameters().len();
            cell.update_arch_params(&grads[offset..offset + n], lr)?;
            offset += n;
        }
        Ok(())
    }
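    /// Discretizes the whole search: per cell, per node, the argmax op index
    /// for each incoming edge.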
    pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
        self.cells.iter().map(|c| c.derive_discrete()).collect()
    }
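    /// Central finite-difference gradient of `val_fn` at the current
    /// architecture logits: `g_i ≈ (f(p + h·e_i) - f(p - h·e_i)) / (2h)`,
    /// at a cost of two `val_fn` evaluations per parameter.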
    pub fn arch_grads_fd(&self, val_fn: impl Fn(&[f64]) -> f64, step: f64) -> Vec<f64> {
        let params = self.arch_parameters();
        let n = params.len();
        let mut grads = vec![0.0_f64; n];
        for i in 0..n {
            let mut p_plus = params.clone();
            p_plus[i] += step;
            let mut p_minus = params.clone();
            p_minus[i] -= step;
            grads[i] = (val_fn(&p_plus) - val_fn(&p_minus)) / (2.0 * step);
        }
        grads
    }
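    /// One iteration of the bilevel search loop: an inner gradient step on the
    /// network weights using `weight_grad_fn` (training loss), followed by an
    /// outer finite-difference step on the architecture logits against the
    /// validation objective `val_fn`. Each call also advances the temperature
    /// schedule by one step.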
    pub fn bilevel_step(
        &mut self,
        weight_grad_fn: impl Fn(&[f64]) -> Vec<f64>,
        val_fn: impl Fn(&[f64]) -> f64,
    ) -> OptimizeResult<()> {
        self.current_step += 1;
        // Inner problem: descend the network weights on the training loss.
        let w_grads = weight_grad_fn(&self.weights);
        if w_grads.len() != self.weights.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "weight_grad_fn returned {} grads, expected {}",
                w_grads.len(),
                self.weights.len()
            )));
        }
        let lr_w = self.config.weight_lr;
        for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
            *w -= lr_w * g;
        }
        // Outer problem: descend the architecture logits on the validation objective.
        let a_grads = self.arch_grads_fd(&val_fn, 1e-4);
        if !a_grads.is_empty() {
            self.update_arch_params(&a_grads, self.config.arch_lr)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    fn make_lcg() -> Lcg {
        Lcg::new(12345)
    }
    #[test]
    fn test_gumbel_softmax_valid_distribution() {
        let mut mo = GdasMixedOperation::new(6);
        let mut rng = make_lcg();
        let (selected, soft) = mo.gumbel_softmax_sample(1.0, &mut rng);
        assert!(selected < 6, "selected={selected} out of range");
        assert_eq!(soft.len(), 6);
        let sum: f64 = soft.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "soft sum={sum}");
        for &w in &soft {
            assert!(w >= 0.0, "negative weight {w}");
        }
    }
    #[test]
    fn test_temperature_annealing_sharpens() {
        let mut mo = GdasMixedOperation::new(6);
        mo.arch_params = vec![2.0, 0.5, 0.3, 0.1, 0.1, 0.0];
        let w_hot = mo.weights(10.0);
        let w_cold = mo.weights(0.01);
        assert!(
            w_cold[0] > w_hot[0],
            "cold w[0]={} should exceed hot w[0]={}",
            w_cold[0],
            w_hot[0]
        );
        let entropy_hot: f64 = w_hot
            .iter()
            .map(|&p| if p > 0.0 { -p * p.ln() } else { 0.0 })
            .sum();
        let entropy_cold: f64 = w_cold
            .iter()
            .map(|&p| if p > 0.0 { -p * p.ln() } else { 0.0 })
            .sum();
        assert!(
            entropy_hot > entropy_cold,
            "hot entropy={entropy_hot} should exceed cold entropy={entropy_cold}"
        );
    }
    #[test]
    fn test_gumbel_softmax_selected_in_range() {
        let mut mo = GdasMixedOperation::new(7);
        let mut rng = make_lcg();
        for _ in 0..50 {
            let (sel, _) = mo.gumbel_softmax_sample(0.5, &mut rng);
            assert!(sel < 7, "sel={sel}");
        }
    }
    #[test]
    fn test_gdas_cell_arch_params_shape() {
        // Edges per node: 2 + 3 + 4 + 5 = 14; times 6 ops = 84 logits.
        let cell = GdasCell::new(2, 4, 6);
        let params = cell.arch_parameters();
        assert_eq!(params.len(), 84, "params.len()={}", params.len());
    }
    #[test]
    fn test_gdas_cell_update_wrong_len_errors() {
        let mut cell = GdasCell::new(2, 3, 6);
        let result = cell.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }
    #[test]
    fn test_gdas_bilevel_step_runs() {
        let config = GdasConfig::default();
        let mut search = GdasSearch::new(config);
        let weight_grad_fn = |weights: &[f64]| vec![0.0_f64; weights.len()];
        let val_fn = |params: &[f64]| params.iter().map(|p| p * p).sum::<f64>();
        search
            .bilevel_step(weight_grad_fn, val_fn)
            .expect("bilevel_step should not error");
    }
    #[test]
    fn test_gdas_bilevel_step_advances_temperature() {
        let config = GdasConfig::default();
        let mut search = GdasSearch::new(config);
        let t0 = search.current_temperature();
        let _ = search.bilevel_step(|w| vec![0.0; w.len()], |p| p.iter().sum::<f64>());
        let t1 = search.current_temperature();
        assert!(t1 <= t0 + 1e-12, "t1={t1} should be ≤ t0={t0}");
    }
    #[test]
    fn test_derive_discrete_arch_valid() {
        let config = GdasConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = GdasSearch::new(config);
        let arch = search.derive_discrete_arch_indices();
        assert_eq!(arch.len(), 2);
        for cell_disc in &arch {
            for node_edges in cell_disc {
                for &op_idx in node_edges {
                    assert!(op_idx < 6, "op_idx={op_idx} >= n_operations=6");
                }
            }
        }
    }
    #[test]
    fn test_gdas_arch_parameters_length_consistent() {
        let config = GdasConfig::default();
        let search = GdasSearch::new(config);
        assert_eq!(search.arch_parameters().len(), search.n_arch_params());
    }
    #[test]
    fn test_gdas_update_arch_params_wrong_length_errors() {
        let mut search = GdasSearch::new(GdasConfig::default());
        let result = search.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }
    #[test]
    fn test_gdas_weight_gradient_wrong_length_errors() {
        let mut search = GdasSearch::new(GdasConfig::default());
        let bad_grad_fn = |_: &[f64]| vec![0.0_f64; 9999];
        let result = search.bilevel_step(bad_grad_fn, |p| p.iter().sum::<f64>());
        assert!(result.is_err());
    }
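    // A minimal end-to-end sketch: with a quadratic validation objective
    // minimized at 1.0 per logit, repeated bilevel steps should push the
    // initially-zero logits upward. The config sizes, step count, and the
    // objective itself are arbitrary choices for this example.
    #[test]
    fn test_bilevel_steps_descend_quadratic_objective() {
        let config = GdasConfig {
            n_cells: 1,
            n_operations: 3,
            n_nodes: 2,
            ..Default::default()
        };
        let mut search = GdasSearch::new(config);
        let weight_grad_fn = |w: &[f64]| vec![0.0_f64; w.len()];
        // f(p) = sum_i (p_i - 1)^2, so df/dp_i = 2 * (p_i - 1) < 0 at p_i = 0.
        let val_fn = |p: &[f64]| p.iter().map(|x| (x - 1.0).powi(2)).sum::<f64>();
        for _ in 0..10 {
            search
                .bilevel_step(&weight_grad_fn, &val_fn)
                .expect("bilevel step should not error");
        }
        for p in search.arch_parameters() {
            assert!(p > 0.0, "logit {p} should have moved toward the minimum at 1.0");
        }
    }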
}