Skip to main content

scirs2_optimize/darts/
snas.rs

1//! SNAS: Stochastic Neural Architecture Search (Xie et al., ICLR 2019).
2//!
3//! SNAS uses the concrete (Gumbel-Softmax) distribution as mixing weights for
4//! the forward pass — unlike GDAS's straight-through hard selection, SNAS
5//! computes a **weighted sum** over **all** operations using concrete-sample
6//! weights.  This allows the resource cost to be differentiated through the
7//! architecture parameters via the expected-cost term.
8//!
9//! The full objective is:
10//!
11//! ```text
12//! L_total = L_task + λ · E[cost(architecture)]
13//! ```
14//!
15//! where the expectation is approximated via `softmax(α / τ)` weights.
16//!
17//! ## References
18//!
19//! - Xie, S., Zheng, H., Liu, C. and Lin, L. (2019). "SNAS: Stochastic Neural
20//!   Architecture Search". ICLR 2019.
21
22use super::{AnnealingStrategy, Lcg, Operation, TemperatureSchedule};
23use crate::error::{OptimizeError, OptimizeResult};
24
25// ─────────────────────────────────────────────────────────── SnasConfig ──
26
27/// Configuration for a SNAS architecture search experiment.
28#[derive(Debug, Clone)]
29pub struct SnasConfig {
30    /// Number of cells stacked in the super-network.
31    pub n_cells: usize,
32    /// Number of candidate operations per edge.
33    pub n_operations: usize,
34    /// Number of feature channels (used for FLOP cost estimation).
35    pub channels: usize,
36    /// Number of intermediate nodes per cell.
37    pub n_nodes: usize,
38    /// Learning rate for architecture parameter updates.
39    pub arch_lr: f64,
40    /// Learning rate for network weight updates.
41    pub weight_lr: f64,
42    /// Temperature schedule for the concrete distribution.
43    pub temperature_schedule: TemperatureSchedule,
44    /// Resource penalty weight λ (multiplies expected FLOP cost in the loss).
45    pub resource_weight: f64,
46    /// Random seed for the internal LCG.
47    pub seed: u64,
48}
49
50impl Default for SnasConfig {
51    fn default() -> Self {
52        Self {
53            n_cells: 3,
54            n_operations: 6,
55            channels: 32,
56            n_nodes: 4,
57            arch_lr: 3e-4,
58            weight_lr: 1e-3,
59            temperature_schedule: TemperatureSchedule::new(
60                1.0,
61                0.1,
62                AnnealingStrategy::Exponential,
63                100,
64            ),
65            resource_weight: 0.001,
66            seed: 42,
67        }
68    }
69}
70
71// ──────────────────────────────────────────────────── SnasMixedOperation ──
72
/// A mixed operation sitting on one directed edge of the SNAS cell DAG.
///
/// GDAS hard-selects a single operation per edge. SNAS instead blends *all*
/// candidate operations using differentiable concrete-sample weights.
#[derive(Clone, Debug)]
pub struct SnasMixedOperation {
    /// Raw (un-normalised) architecture logits `α_k`, one per candidate op.
    pub arch_params: Vec<f64>,
    /// Weights from the most recent concrete sample. `new` initialises this to
    /// the uniform vector; no method in this module writes it afterwards.
    pub last_concrete_weights: Vec<f64>,
}
84
85impl SnasMixedOperation {
86    /// Create a new `SnasMixedOperation` initialised to uniform weights.
87    pub fn new(n_ops: usize) -> Self {
88        Self {
89            arch_params: vec![0.0_f64; n_ops],
90            last_concrete_weights: vec![1.0 / n_ops as f64; n_ops],
91        }
92    }
93
94    /// Draw a concrete (Gumbel-Softmax) sample.
95    ///
96    /// Returns the soft weights `w_k = softmax((α_k + g_k) / τ)` where `g_k`
97    /// are i.i.d. Gumbel(0,1) noise samples.  The weights sum to 1 and are
98    /// non-negative.
99    pub fn concrete_sample(&self, temperature: f64, rng: &mut Lcg) -> Vec<f64> {
100        let eps = 1e-20_f64;
101        let temp = temperature.max(1e-8);
102        let n = self.arch_params.len();
103
104        let mut logits = vec![0.0_f64; n];
105        for k in 0..n {
106            let u = rng.next_f64().max(eps);
107            let gumbel_noise = -(-u.ln()).ln();
108            logits[k] = self.arch_params[k] + gumbel_noise;
109        }
110
111        // Numerically-stable softmax at temperature τ.
112        let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
113        let mut exp_vals: Vec<f64> = logits.iter().map(|&l| ((l - max_l) / temp).exp()).collect();
114        let sum = exp_vals.iter().sum::<f64>().max(eps);
115        for v in &mut exp_vals {
116            *v /= sum;
117        }
118        exp_vals
119    }
120
121    /// Compute standard softmax-normalised weights (no noise) at `temperature`.
122    pub fn weights(&self, temperature: f64) -> Vec<f64> {
123        let t = temperature.max(1e-8);
124        let scaled: Vec<f64> = self.arch_params.iter().map(|a| a / t).collect();
125        let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
126        let exps: Vec<f64> = scaled.iter().map(|s| (s - max_val).exp()).collect();
127        let sum: f64 = exps.iter().sum();
128        if sum == 0.0 {
129            let n = self.arch_params.len();
130            vec![1.0 / n as f64; n]
131        } else {
132            exps.iter().map(|e| e / sum).collect()
133        }
134    }
135
136    /// Compute the expected FLOP cost for this edge.
137    ///
138    /// `E[cost] = Σ_k w_k * cost_flops(op_k)` where `w_k` are softmax weights
139    /// at the current temperature.
140    pub fn expected_cost(&self, temperature: f64, channels: usize) -> f64 {
141        let ops = Operation::all();
142        let w = self.weights(temperature);
143        w.iter()
144            .zip(ops.iter())
145            .take(self.arch_params.len())
146            .map(|(&wk, op)| wk * op.cost_flops(channels))
147            .sum()
148    }
149
150    /// Index of the operation with the highest architecture weight (argmax).
151    pub fn argmax_op(&self) -> usize {
152        self.arch_params
153            .iter()
154            .enumerate()
155            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
156            .map(|(i, _)| i)
157            .unwrap_or(0)
158    }
159
160    /// Apply a gradient-descent step to architecture parameters.
161    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) {
162        for (p, g) in self.arch_params.iter_mut().zip(grads.iter()) {
163            *p -= lr * g;
164        }
165    }
166}
167
168// ─────────────────────────────────────────────────────────────── SnasCell ──
169
170/// A SNAS cell: DAG with fixed input nodes and learnable intermediate nodes.
171#[derive(Debug, Clone)]
172pub struct SnasCell {
173    /// Number of intermediate (learnable) nodes.
174    pub n_nodes: usize,
175    /// Number of fixed input nodes (typically 2).
176    pub n_input_nodes: usize,
177    /// `edges[i][j]` is the `SnasMixedOperation` from node j to intermediate
178    /// node i.
179    pub edges: Vec<Vec<SnasMixedOperation>>,
180}
181
182impl SnasCell {
183    /// Create a new SNAS cell.
184    pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
185        let edges: Vec<Vec<SnasMixedOperation>> = (0..n_intermediate_nodes)
186            .map(|i| {
187                let n_predecessors = n_input_nodes + i;
188                (0..n_predecessors)
189                    .map(|_| SnasMixedOperation::new(n_ops))
190                    .collect()
191            })
192            .collect();
193        Self {
194            n_nodes: n_intermediate_nodes,
195            n_input_nodes,
196            edges,
197        }
198    }
199
200    /// Collect all architecture parameters from every edge, flattened.
201    pub fn arch_parameters(&self) -> Vec<f64> {
202        self.edges
203            .iter()
204            .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
205            .collect()
206    }
207
208    /// Compute total expected FLOP cost for this cell.
209    pub fn total_expected_cost(&self, temperature: f64, channels: usize) -> f64 {
210        self.edges
211            .iter()
212            .flat_map(|row| row.iter())
213            .map(|mo| mo.expected_cost(temperature, channels))
214            .sum()
215    }
216
217    /// Apply gradient updates to architecture parameters.
218    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
219        let n_params: usize = self
220            .edges
221            .iter()
222            .flat_map(|row| row.iter())
223            .map(|mo| mo.arch_params.len())
224            .sum();
225        if grads.len() != n_params {
226            return Err(OptimizeError::InvalidInput(format!(
227                "SnasCell::update_arch_params: expected {n_params} grads, got {}",
228                grads.len()
229            )));
230        }
231        let mut idx = 0;
232        for row in self.edges.iter_mut() {
233            for mo in row.iter_mut() {
234                let n = mo.arch_params.len();
235                mo.update_arch_params(&grads[idx..idx + n], lr);
236                idx += n;
237            }
238        }
239        Ok(())
240    }
241
242    /// Derive the discrete architecture: argmax operation index per edge.
243    pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
244        self.edges
245            .iter()
246            .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
247            .collect()
248    }
249}
250
251// ──────────────────────────────────────────────────────────── SnasSearch ──
252
253/// Top-level SNAS search controller.
254///
255/// Implements the bi-level optimisation loop with resource penalty.
256pub struct SnasSearch {
257    /// Stack of cells forming the super-network.
258    pub cells: Vec<SnasCell>,
259    /// Configuration.
260    pub config: SnasConfig,
261    /// Flat network weights (one scalar per cell in this toy model).
262    weights: Vec<f64>,
263    /// Internal pseudo-random number generator.
264    rng: Lcg,
265    /// Current training step (advances the temperature schedule).
266    current_step: usize,
267}
268
269impl SnasSearch {
270    /// Construct a `SnasSearch` from the given config.
271    pub fn new(config: SnasConfig) -> Self {
272        let cells: Vec<SnasCell> = (0..config.n_cells)
273            .map(|_| SnasCell::new(2, config.n_nodes, config.n_operations))
274            .collect();
275        let weights = vec![0.01_f64; config.n_cells];
276        let rng = Lcg::new(config.seed);
277        Self {
278            cells,
279            config,
280            weights,
281            rng,
282            current_step: 0,
283        }
284    }
285
286    /// Current concrete-distribution temperature.
287    pub fn current_temperature(&self) -> f64 {
288        self.config
289            .temperature_schedule
290            .temperature_at(self.current_step)
291    }
292
293    /// Return all architecture parameters across all cells, flattened.
294    pub fn arch_parameters(&self) -> Vec<f64> {
295        self.cells
296            .iter()
297            .flat_map(|c| c.arch_parameters())
298            .collect()
299    }
300
301    /// Total number of architecture parameters.
302    pub fn n_arch_params(&self) -> usize {
303        self.cells.iter().map(|c| c.arch_parameters().len()).sum()
304    }
305
306    /// Compute total expected FLOP cost over all cells (resource penalty term).
307    pub fn total_expected_cost(&self) -> f64 {
308        let temp = self.current_temperature();
309        let channels = self.config.channels;
310        self.cells
311            .iter()
312            .map(|c| c.total_expected_cost(temp, channels))
313            .sum()
314    }
315
316    /// Apply a gradient step to architecture parameters.
317    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
318        let total = self.n_arch_params();
319        if grads.len() != total {
320            return Err(OptimizeError::InvalidInput(format!(
321                "SnasSearch::update_arch_params: expected {total} grads, got {}",
322                grads.len()
323            )));
324        }
325        let mut offset = 0;
326        for cell in self.cells.iter_mut() {
327            let n = cell.arch_parameters().len();
328            cell.update_arch_params(&grads[offset..offset + n], lr)?;
329            offset += n;
330        }
331        Ok(())
332    }
333
334    /// Derive discrete architecture: argmax op index per edge, per cell.
335    pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
336        self.cells.iter().map(|c| c.derive_discrete()).collect()
337    }
338
339    /// Compute finite-difference gradients of the SNAS validation objective
340    /// (task loss + λ · resource_cost) w.r.t. architecture parameters.
341    ///
342    /// `val_fn(arch_params)` should return the **task** loss.  The resource
343    /// penalty is computed analytically via `total_expected_cost`.
344    pub fn arch_grads_fd(&self, val_fn: impl Fn(&[f64]) -> f64, step: f64) -> Vec<f64> {
345        let params = self.arch_parameters();
346        let n = params.len();
347        let lambda = self.config.resource_weight;
348        let temp = self.current_temperature();
349        let channels = self.config.channels;
350
351        let mut grads = vec![0.0_f64; n];
352        for i in 0..n {
353            let mut p_plus = params.clone();
354            p_plus[i] += step;
355            let mut p_minus = params.clone();
356            p_minus[i] -= step;
357
358            // Resource cost gradient via FD on arch params.
359            let cost_plus = resource_cost_at(&p_plus, &self.cells, temp, channels, lambda);
360            let cost_minus = resource_cost_at(&p_minus, &self.cells, temp, channels, lambda);
361
362            let task_grad = (val_fn(&p_plus) - val_fn(&p_minus)) / (2.0 * step);
363            let cost_grad = (cost_plus - cost_minus) / (2.0 * step);
364            grads[i] = task_grad + cost_grad;
365        }
366        grads
367    }
368
369    /// One bi-level optimisation step with resource penalty.
370    ///
371    /// The full validation objective is `L_task(arch_params) + λ · E[cost]`.
372    pub fn bilevel_step(
373        &mut self,
374        weight_grad_fn: impl Fn(&[f64]) -> Vec<f64>,
375        val_fn: impl Fn(&[f64]) -> f64,
376    ) -> OptimizeResult<()> {
377        self.current_step += 1;
378
379        // Inner: gradient step on network weights.
380        let w_grads = weight_grad_fn(&self.weights);
381        if w_grads.len() != self.weights.len() {
382            return Err(OptimizeError::InvalidInput(format!(
383                "weight_grad_fn returned {} grads, expected {}",
384                w_grads.len(),
385                self.weights.len()
386            )));
387        }
388        let lr_w = self.config.weight_lr;
389        for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
390            *w -= lr_w * g;
391        }
392
393        // Outer: gradient step on architecture params (task loss + resource).
394        let a_grads = self.arch_grads_fd(&val_fn, 1e-4);
395        if !a_grads.is_empty() {
396            self.update_arch_params(&a_grads, self.config.arch_lr)?;
397        }
398
399        Ok(())
400    }
401}
402
403// ─────────────────────────────────────────────── helper: resource cost at params ──
404
405/// Compute `λ * E[FLOP cost]` for a given flat arch-param vector without
406/// mutably borrowing `SnasSearch`.  Used inside FD gradient computation.
407fn resource_cost_at(
408    params: &[f64],
409    cells: &[SnasCell],
410    temperature: f64,
411    channels: usize,
412    lambda: f64,
413) -> f64 {
414    let ops = Operation::all();
415    let n_ops_canonical = ops.len();
416    let eps = 1e-8_f64;
417    let temp = temperature.max(eps);
418
419    let mut total_cost = 0.0_f64;
420    let mut offset = 0_usize;
421
422    for cell in cells.iter() {
423        for node_edges in cell.edges.iter() {
424            for mo in node_edges.iter() {
425                let n = mo.arch_params.len().min(n_ops_canonical);
426                let slice = &params[offset..offset + mo.arch_params.len()];
427
428                // Softmax of this edge's params.
429                let scaled: Vec<f64> = slice.iter().map(|a| a / temp).collect();
430                let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
431                let exps: Vec<f64> = scaled.iter().map(|s| (s - max_val).exp()).collect();
432                let sum: f64 = exps.iter().sum::<f64>().max(eps);
433
434                for k in 0..n {
435                    let wk = exps[k] / sum;
436                    total_cost += wk * ops[k].cost_flops(channels);
437                }
438
439                offset += mo.arch_params.len();
440            }
441        }
442    }
443
444    lambda * total_cost
445}
446
447// ═══════════════════════════════════════════════════════════════════ tests ═══
448
#[cfg(test)]
mod tests {
    // Unit tests for the SNAS mixed operation, cell, search controller and the
    // finite-difference resource-cost helper.
    use super::*;

    fn make_lcg() -> Lcg {
        Lcg::new(99999)
    }

    // ── SnasMixedOperation ─────────────────────────────────────────────────────

    #[test]
    fn test_concrete_sample_valid() {
        let mo = SnasMixedOperation::new(6);
        let mut rng = make_lcg();
        let weights = mo.concrete_sample(1.0, &mut rng);

        assert_eq!(weights.len(), 6);
        let sum: f64 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "concrete sample sum={sum}");
        for &w in &weights {
            assert!(w >= 0.0, "negative concrete weight {w}");
        }
    }

    #[test]
    fn test_concrete_sample_multiple_calls_valid() {
        let mo = SnasMixedOperation::new(6);
        let mut rng = make_lcg();
        for _ in 0..20 {
            let w = mo.concrete_sample(0.5, &mut rng);
            let sum: f64 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-9);
            for &v in &w {
                assert!(v >= 0.0);
            }
        }
    }

    #[test]
    fn test_concrete_sample_low_temp_peaks() {
        // At very low temperature the concrete sample should be near one-hot.
        let mut mo = SnasMixedOperation::new(6);
        mo.arch_params = vec![5.0, 0.1, 0.1, 0.1, 0.1, 0.1];
        let mut rng = make_lcg();
        // Run several trials; the dominant weight should frequently be very large.
        let mut dominant_count = 0;
        for _ in 0..20 {
            let w = mo.concrete_sample(0.01, &mut rng);
            if w[0] > 0.5 {
                dominant_count += 1;
            }
        }
        // Expect most draws to be dominated by index 0 at τ=0.01.
        assert!(
            dominant_count >= 10,
            "dominant_count={dominant_count} too low"
        );
    }

    #[test]
    fn test_argmax_op_picks_strict_maximum() {
        let mut mo = SnasMixedOperation::new(4);
        mo.arch_params = vec![0.1, -2.0, 3.0, 0.2];
        assert_eq!(mo.argmax_op(), 2);
    }

    // ── Expected cost ──────────────────────────────────────────────────────────

    #[test]
    fn test_expected_cost_nonneg() {
        let mo = SnasMixedOperation::new(6);
        let cost = mo.expected_cost(1.0, 16);
        assert!(cost >= 0.0, "cost={cost}");
    }

    #[test]
    fn test_total_expected_cost_nonneg() {
        let config = SnasConfig::default();
        let search = SnasSearch::new(config);
        let cost = search.total_expected_cost();
        assert!(cost >= 0.0, "total cost={cost}");
    }

    #[test]
    fn test_expected_cost_uniform_params_is_mean_cost() {
        // Equal logits give uniform weights, so the expected cost is the plain
        // mean of the first three catalogue costs. The catalogue order itself
        // comes from `Operation::all()`, so this test does not hard-code it.
        let mut mo = SnasMixedOperation::new(3);
        mo.arch_params = vec![10.0, 10.0, 10.0];
        let cost = mo.expected_cost(1.0, 16);
        let ops = Operation::all();
        let mean = ops.iter().take(3).map(|op| op.cost_flops(16)).sum::<f64>() / 3.0;
        assert!(cost >= 0.0, "cost={cost}");
        assert!(
            (cost - mean).abs() <= 1e-9 * mean.abs().max(1.0),
            "cost={cost} mean={mean}"
        );
    }

    #[test]
    fn test_resource_cost_at_matches_total_expected_cost() {
        // With λ = 1 and the search's own parameters, the FD helper must agree
        // with the analytic per-edge expected cost.
        let search = SnasSearch::new(SnasConfig::default());
        let params = search.arch_parameters();
        let temp = search.current_temperature();
        let direct = search.total_expected_cost();
        let via_helper =
            resource_cost_at(&params, &search.cells, temp, search.config.channels, 1.0);
        assert!(
            (direct - via_helper).abs() <= 1e-9 * direct.abs().max(1.0),
            "direct={direct} helper={via_helper}"
        );
    }

    // ── SnasCell ───────────────────────────────────────────────────────────────

    #[test]
    fn test_snas_cell_arch_params_shape() {
        let cell = SnasCell::new(2, 4, 6);
        // Same edge structure as GDAS: total edges = 2+3+4+5 = 14, 6 ops each → 84.
        assert_eq!(cell.arch_parameters().len(), 84);
    }

    #[test]
    fn test_snas_cell_update_wrong_len_errors() {
        let mut cell = SnasCell::new(2, 3, 6);
        let result = cell.update_arch_params(&[0.0; 3], 0.01);
        assert!(result.is_err());
    }

    // ── SnasSearch ─────────────────────────────────────────────────────────────

    #[test]
    fn test_snas_bilevel_step_runs() {
        let config = SnasConfig::default();
        let mut search = SnasSearch::new(config);

        let weight_grad_fn = |weights: &[f64]| vec![0.0_f64; weights.len()];
        let val_fn = |params: &[f64]| params.iter().map(|p| p * p).sum::<f64>();

        search
            .bilevel_step(weight_grad_fn, val_fn)
            .expect("snas bilevel_step should not error");
    }

    #[test]
    fn test_snas_bilevel_step_advances_temperature() {
        let config = SnasConfig::default();
        let mut search = SnasSearch::new(config);
        let t0 = search.current_temperature();
        let _ = search.bilevel_step(|w| vec![0.0; w.len()], |p| p.iter().sum::<f64>());
        let t1 = search.current_temperature();
        assert!(t1 <= t0 + 1e-12, "t1={t1} should be ≤ t0={t0}");
    }

    #[test]
    fn test_bilevel_step_rejects_wrong_weight_grad_len() {
        let mut search = SnasSearch::new(SnasConfig::default());
        let result = search.bilevel_step(|w| vec![0.0; w.len() + 1], |_| 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_arch_grads_fd_linear_task_loss() {
        // With λ = 0 only the task term remains; d/dα_i Σ α = 1 for every i.
        let config = SnasConfig {
            resource_weight: 0.0,
            ..Default::default()
        };
        let search = SnasSearch::new(config);
        let grads = search.arch_grads_fd(|p| p.iter().sum::<f64>(), 1e-4);
        assert_eq!(grads.len(), search.n_arch_params());
        for g in grads {
            assert!((g - 1.0).abs() < 1e-6, "g={g}");
        }
    }

    #[test]
    fn test_bilevel_step_updates_arch_params() {
        // With λ = 0 and a linear task loss every gradient is 1, so each
        // parameter moves from 0 to −arch_lr.
        let config = SnasConfig {
            resource_weight: 0.0,
            ..Default::default()
        };
        let arch_lr = config.arch_lr;
        let mut search = SnasSearch::new(config);
        search
            .bilevel_step(|w| vec![0.0; w.len()], |p| p.iter().sum::<f64>())
            .expect("bilevel_step should succeed");
        for a in search.arch_parameters() {
            assert!((a + arch_lr).abs() < 1e-9, "a={a}");
        }
    }

    #[test]
    fn test_derive_discrete_arch_valid() {
        let config = SnasConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = SnasSearch::new(config);
        let arch = search.derive_discrete_arch_indices();
        assert_eq!(arch.len(), 2);
        for cell_disc in &arch {
            for node_edges in cell_disc {
                for &op_idx in node_edges {
                    assert!(op_idx < 6, "op_idx={op_idx}");
                }
            }
        }
    }

    #[test]
    fn test_snas_arch_params_consistent() {
        let search = SnasSearch::new(SnasConfig::default());
        assert_eq!(search.arch_parameters().len(), search.n_arch_params());
    }
}