use crate::error::{NeuralError, Result};
use scirs2_core::random::{Rng, RngExt, SeedableRng};
/// Strategy for annealing the Gumbel-softmax temperature across epochs.
///
/// All three schedules move the temperature from `tau_start` toward
/// `tau_end` as training progresses (see `GdasSearch::anneal_temperature`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureSchedule {
    /// Geometric decay: `tau_start * (tau_end / tau_start)^t` for t in [0, 1].
    Exponential,
    /// Straight-line interpolation between `tau_start` and `tau_end`.
    Linear,
    /// Half-cosine decay from `tau_start` down to `tau_end`.
    Cosine,
}
/// Configuration for a GDAS architecture search run.
#[derive(Debug, Clone)]
pub struct GdasConfig {
    /// Number of intermediate nodes in the searched cell DAG.
    pub n_nodes: usize,
    /// Number of candidate operations per edge.
    pub n_ops: usize,
    /// Initial Gumbel-softmax temperature.
    pub tau_start: f64,
    /// Final Gumbel-softmax temperature at the end of the schedule.
    pub tau_end: f64,
    /// Total number of epochs the temperature schedule spans.
    pub n_epochs: usize,
    /// Learning rate for the architecture parameters (alpha).
    pub arch_lr: f64,
    /// Learning rate for the network weights.
    pub weight_lr: f64,
    /// Which annealing curve to use for the temperature.
    pub schedule: TemperatureSchedule,
    /// RNG seed for reproducible sampling.
    pub seed: u64,
}
impl Default for GdasConfig {
fn default() -> Self {
Self {
n_nodes: 4,
n_ops: 8,
tau_start: 10.0,
tau_end: 0.1,
n_epochs: 100,
arch_lr: 3e-4,
weight_lr: 1e-3,
schedule: TemperatureSchedule::Exponential,
seed: 42,
}
}
}
/// State of a GDAS search: per-edge operation logits plus the current
/// Gumbel-softmax temperature.
pub struct GdasSearch {
    /// Architecture logits: one `Vec<f64>` of length `n_ops` per edge.
    pub alpha: Vec<Vec<f64>>,
    /// Current Gumbel-softmax temperature (annealed each epoch).
    pub temperature: f64,
    // Immutable search configuration; exposed read-only via `config()`.
    config: GdasConfig,
}
impl GdasSearch {
    /// Creates a searcher with all architecture logits initialised to zero
    /// (a uniform distribution over operations) and the temperature set to
    /// `config.tau_start`.
    ///
    /// Uses `saturating_sub` so a degenerate `n_nodes == 0` yields zero
    /// edges instead of panicking on usize underflow.
    pub fn new(config: GdasConfig) -> Self {
        // A DAG cell with n nodes has n*(n-1)/2 candidate edges.
        let n_edges = config.n_nodes * config.n_nodes.saturating_sub(1) / 2;
        let alpha = vec![vec![0.0_f64; config.n_ops]; n_edges];
        let temperature = config.tau_start;
        Self {
            alpha,
            temperature,
            config,
        }
    }

    /// Number of edges in the cell DAG: n*(n-1)/2 for `n_nodes` nodes.
    /// Saturating so `n_nodes == 0` returns 0 rather than underflowing.
    pub fn n_edges(&self) -> usize {
        self.config.n_nodes * self.config.n_nodes.saturating_sub(1) / 2
    }

    /// Samples one operation index via the Gumbel-max trick at the current
    /// temperature and returns it with a hard one-hot weight vector
    /// (GDAS-style discrete forward pass).
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InvalidArgument` if `logits` is empty or the
    /// current temperature is not strictly positive.
    pub fn gumbel_top1_sample(
        &self,
        logits: &[f64],
        rng: &mut impl Rng,
    ) -> Result<(usize, Vec<f64>)> {
        if logits.is_empty() {
            return Err(NeuralError::InvalidArgument(
                "logits must be non-empty for Gumbel-top1 sampling".to_string(),
            ));
        }
        let tau = self.temperature;
        if tau <= 0.0 {
            return Err(NeuralError::InvalidArgument(format!(
                "Temperature must be positive, got {tau}"
            )));
        }
        let mut best_val = f64::NEG_INFINITY;
        let mut best_idx = 0usize;
        for (i, &logit) in logits.iter().enumerate() {
            // Gumbel(0, 1) noise: -ln(-ln(U)) with U ~ Uniform[0, 1).
            let u: f64 = rng.random();
            // Clamp away from zero so ln(u) stays finite.
            let u_clamped = u.max(1e-40);
            let gumbel = -(-u_clamped.ln()).ln();
            let y = (logit + gumbel) / tau;
            if y > best_val {
                best_val = y;
                best_idx = i;
            }
        }
        // Hard one-hot selection over the perturbed logits; in a full
        // straight-through implementation gradients would flow through the
        // soft relaxation instead.
        let mut one_hot = vec![0.0_f64; logits.len()];
        one_hot[best_idx] = 1.0;
        Ok((best_idx, one_hot))
    }

    /// Updates the temperature for `epoch` according to the configured
    /// schedule, then clamps it into [min(tau_start, tau_end),
    /// max(tau_start, tau_end)] so epochs past `n_epochs` cannot overshoot.
    pub fn anneal_temperature(&mut self, epoch: usize) {
        let tau_start = self.config.tau_start;
        let tau_end = self.config.tau_end;
        // Guard against division by zero for a zero-epoch config.
        let n_epochs = self.config.n_epochs.max(1);
        // Normalised progress t in [0, 1] over the schedule.
        let t = (epoch as f64) / (n_epochs as f64);
        self.temperature = match self.config.schedule {
            TemperatureSchedule::Exponential => {
                let ratio = tau_end / tau_start;
                tau_start * ratio.powf(t)
            }
            TemperatureSchedule::Linear => tau_start - (tau_start - tau_end) * t,
            TemperatureSchedule::Cosine => {
                tau_end + 0.5 * (tau_start - tau_end) * (1.0 + (std::f64::consts::PI * t).cos())
            }
        };
        self.temperature = self
            .temperature
            .clamp(tau_end.min(tau_start), tau_end.max(tau_start));
    }

    /// Derives the final discrete architecture by taking the argmax
    /// operation on each edge. Ties (and all-NaN edges) resolve to the
    /// first candidate.
    pub fn derive_architecture(&self) -> Vec<usize> {
        self.alpha
            .iter()
            .map(|edge_logits| {
                edge_logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(idx, _)| idx)
                    .unwrap_or(0)
            })
            .collect()
    }

    /// Softmax of each edge's logits: the current per-edge operation
    /// probability distributions.
    pub fn architecture_probabilities(&self) -> Vec<Vec<f64>> {
        self.alpha
            .iter()
            .map(|edge_logits| softmax(edge_logits))
            .collect()
    }

    /// Applies one SGD step `alpha -= lr * gradient` to a single
    /// architecture parameter. Out-of-range indices are ignored.
    pub fn update_alpha(&mut self, edge_idx: usize, op_idx: usize, gradient: f64, lr: f64) {
        if let Some(edge) = self.alpha.get_mut(edge_idx) {
            if let Some(param) = edge.get_mut(op_idx) {
                *param -= lr * gradient;
            }
        }
    }

    /// Read-only access to the search configuration.
    pub fn config(&self) -> &GdasConfig {
        &self.config
    }
}
/// Numerically stable softmax over a slice of logits.
///
/// Returns an empty vector for empty input. If every exponential
/// underflows to a non-positive sum (e.g. all logits are -inf), falls
/// back to a uniform distribution.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    // Subtract the largest logit before exponentiating to avoid overflow.
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = logits.iter().map(|&l| (l - peak).exp()).collect();
    let total: f64 = weights.iter().sum();
    if total <= 0.0 {
        let uniform = 1.0 / weights.len() as f64;
        return vec![uniform; weights.len()];
    }
    for w in weights.iter_mut() {
        *w /= total;
    }
    weights
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{rngs::StdRng, SeedableRng};

    /// Deterministic RNG so every test run is reproducible.
    fn seeded_rng(seed: u64) -> StdRng {
        StdRng::seed_from_u64(seed)
    }

    #[test]
    fn test_gdas_gumbel_top1_returns_valid_index() {
        let cfg = GdasConfig::default();
        let searcher = GdasSearch::new(cfg.clone());
        let mut rng = seeded_rng(0);
        let scores = [1.0, 2.0, 0.5, 3.0, 1.5, 0.0, 2.5, 1.0];
        let (chosen, weights) = searcher.gumbel_top1_sample(&scores, &mut rng).unwrap();
        assert!(chosen < cfg.n_ops, "Selected index must be in [0, n_ops)");
        assert_eq!(weights.len(), cfg.n_ops);
        let total: f64 = weights.iter().sum();
        assert!((total - 1.0).abs() < 1e-10, "Weights must sum to 1");
        assert_eq!(weights[chosen], 1.0, "Selected operation weight must be 1");
        for (pos, &w) in weights.iter().enumerate() {
            if pos == chosen {
                continue;
            }
            assert_eq!(w, 0.0, "Non-selected weights must be 0");
        }
    }

    #[test]
    fn test_gdas_n_edges_correct() {
        let searcher = GdasSearch::new(GdasConfig {
            n_nodes: 4,
            ..Default::default()
        });
        // 4 nodes => 4 * 3 / 2 = 6 edges.
        assert_eq!(searcher.n_edges(), 6);
        assert_eq!(searcher.alpha.len(), 6);
    }

    #[test]
    fn test_gdas_anneal_temperature_decreases_exponential() {
        let mut searcher = GdasSearch::new(GdasConfig {
            schedule: TemperatureSchedule::Exponential,
            tau_start: 10.0,
            tau_end: 0.1,
            n_epochs: 100,
            ..Default::default()
        });
        let start_temp = searcher.temperature;
        searcher.anneal_temperature(50);
        assert!(
            searcher.temperature < start_temp,
            "Temperature should decrease after annealing"
        );
        searcher.anneal_temperature(100);
        assert!(
            (searcher.temperature - 0.1).abs() < 0.01,
            "Temperature should approach tau_end at end of schedule"
        );
    }

    #[test]
    fn test_gdas_anneal_temperature_decreases_linear() {
        let mut searcher = GdasSearch::new(GdasConfig {
            schedule: TemperatureSchedule::Linear,
            tau_start: 10.0,
            tau_end: 0.1,
            n_epochs: 100,
            ..Default::default()
        });
        searcher.anneal_temperature(50);
        // Halfway between 10.0 and 0.1.
        assert!(
            (searcher.temperature - 5.05).abs() < 0.1,
            "Linear midpoint temperature should be ~5.05"
        );
    }

    #[test]
    fn test_gdas_anneal_temperature_cosine() {
        let mut searcher = GdasSearch::new(GdasConfig {
            schedule: TemperatureSchedule::Cosine,
            tau_start: 10.0,
            tau_end: 0.1,
            n_epochs: 100,
            ..Default::default()
        });
        searcher.anneal_temperature(0);
        assert!(
            (searcher.temperature - 10.0).abs() < 0.01,
            "Cosine schedule at epoch 0 should be tau_start"
        );
    }

    #[test]
    fn test_gdas_derive_architecture_argmax() {
        let mut searcher = GdasSearch::new(GdasConfig {
            n_nodes: 3,
            n_ops: 4,
            ..GdasConfig::default()
        });
        // Plant a clear winner on each of the three edges.
        searcher.alpha[0] = vec![0.1, 0.2, 5.0, 0.0];
        searcher.alpha[1] = vec![3.0, 0.1, 0.2, 0.0];
        searcher.alpha[2] = vec![0.0, 0.1, 0.2, 7.0];
        let derived = searcher.derive_architecture();
        assert_eq!(derived.len(), 3);
        assert_eq!(derived[0], 2);
        assert_eq!(derived[1], 0);
        assert_eq!(derived[2], 3);
    }

    #[test]
    fn test_gdas_architecture_probabilities_sum_to_one() {
        let searcher = GdasSearch::new(GdasConfig::default());
        for edge_probs in &searcher.architecture_probabilities() {
            let total: f64 = edge_probs.iter().sum();
            assert!((total - 1.0).abs() < 1e-10, "Probabilities must sum to 1");
        }
    }

    #[test]
    fn test_gdas_update_alpha_gradient_step() {
        let mut searcher = GdasSearch::new(GdasConfig::default());
        let before = searcher.alpha[0][0];
        searcher.update_alpha(0, 0, 1.0, 0.01);
        let after = before - 0.01 * 1.0;
        assert!(
            (searcher.alpha[0][0] - after).abs() < 1e-12,
            "Alpha should be updated by gradient step"
        );
    }

    #[test]
    fn test_gdas_gumbel_top1_empty_logits_error() {
        let searcher = GdasSearch::new(GdasConfig::default());
        let mut rng = seeded_rng(0);
        let outcome = searcher.gumbel_top1_sample(&[], &mut rng);
        assert!(outcome.is_err(), "Empty logits should return an error");
    }
}