Skip to main content

scirs2_optimize/nas/
differentiable.rs

//! DARTS: Differentiable Architecture Search (Liu et al., ICLR 2019).
//!
//! Relaxes the discrete choice among a set of candidate operations into a
//! continuous mixture weighted by softmax(α). During search, both the
//! network weights `w` and the architecture weights `α` are optimised as a
//! bi-level problem. After search, the discrete architecture is recovered
//! by taking argmax(α) on each edge.
//!
//! This module provides the core data structure and the step that derives
//! the discrete architecture. Actual gradient-based updates require coupling
//! with a neural-network trainer; `update_alpha` provides a hook for applying
//! externally computed gradients.
13
14use crate::error::OptimizeError;
15use crate::nas::search_space::{ArchEdge, ArchNode, Architecture, OpType, SearchSpace};
16
/// Continuous (relaxed) architecture parameters for a DARTS cell.
///
/// Row `e` of `alpha` holds one un-normalised log-weight per candidate
/// operation on edge `e`; applying a softmax to a row yields that edge's
/// operation-mixing weights.
#[derive(Debug, Clone)]
pub struct DARTSSearch {
    /// How many intermediate nodes each cell contains.
    pub n_nodes: usize,
    /// How many candidate operations compete on every edge.
    pub n_ops: usize,
    /// Architecture log-weights, laid out as `[n_edges][n_ops]`.
    pub alpha: Vec<Vec<f64>>,
    /// Step size used when applying gradients to `alpha`.
    pub learning_rate: f64,
    /// How many input nodes feed the cell (outputs of earlier cells).
    pub n_input_nodes: usize,
}
34
35impl DARTSSearch {
36    /// Initialise DARTS with uniform architecture weights.
37    ///
38    /// # Arguments
39    /// - `n_nodes`: Number of intermediate nodes per cell.
40    /// - `operations`: Slice of candidate operations.
41    /// - `n_input_nodes`: Number of fixed input nodes (e.g., 2 for DARTS).
42    pub fn new(n_nodes: usize, operations: &[OpType], n_input_nodes: usize) -> Self {
43        let n_ops = operations.len();
44        // In DARTS each intermediate node i receives edges from all
45        // previous nodes (i nodes including the n_input_nodes inputs).
46        // Total edges = sum_{i=0}^{n_nodes-1} (n_input_nodes + i)
47        let n_edges: usize = (0..n_nodes).map(|i| n_input_nodes + i).sum();
48        let init_weight = if n_ops > 0 { 1.0 / n_ops as f64 } else { 0.0 };
49        let alpha = vec![vec![init_weight; n_ops]; n_edges.max(1)];
50
51        Self {
52            n_nodes,
53            n_ops,
54            alpha,
55            learning_rate: 3e-4,
56            n_input_nodes,
57        }
58    }
59
60    /// Number of edges in the DARTS cell.
61    pub fn n_edges(&self) -> usize {
62        self.alpha.len()
63    }
64
65    /// Softmax-normalised operation weights for a given edge.
66    ///
67    /// Returns a zero vector if `edge_idx` is out of range.
68    pub fn get_op_weights(&self, edge_idx: usize) -> Vec<f64> {
69        if edge_idx >= self.alpha.len() {
70            return vec![0.0; self.n_ops];
71        }
72        let raw = &self.alpha[edge_idx];
73        let max = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
74        let exp: Vec<f64> = raw.iter().map(|x| (x - max).exp()).collect();
75        let sum: f64 = exp.iter().sum();
76        if sum == 0.0 {
77            return vec![1.0 / self.n_ops as f64; self.n_ops];
78        }
79        exp.iter().map(|e| e / sum).collect()
80    }
81
82    /// Derive a discrete architecture by taking argmax(α) per edge.
83    ///
84    /// Returns an `Architecture` whose edges carry the selected operations.
85    pub fn derive_architecture(
86        &self,
87        space: &SearchSpace,
88        n_cells: usize,
89        channels: usize,
90        n_classes: usize,
91    ) -> Architecture {
92        let mut arch = Architecture::new(n_cells, channels, n_classes);
93
94        // Add input nodes (two previous cell outputs)
95        for i in 0..self.n_input_nodes {
96            arch.nodes.push(ArchNode {
97                id: i,
98                name: format!("input{}", i),
99                output_channels: channels,
100            });
101        }
102
103        // Add intermediate nodes
104        let mut edge_idx = 0usize;
105        for i in 0..self.n_nodes {
106            let node_id = self.n_input_nodes + i;
107            arch.nodes.push(ArchNode {
108                id: node_id,
109                name: format!("node{}", i),
110                output_channels: channels,
111            });
112
113            // Edges from all previous nodes to this one
114            let n_prev = self.n_input_nodes + i;
115            for from_id in 0..n_prev {
116                let weights = self.get_op_weights(edge_idx);
117                let best_op_idx = weights
118                    .iter()
119                    .enumerate()
120                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
121                    .map(|(idx, _)| idx)
122                    .unwrap_or(0);
123
124                let op = space
125                    .operations
126                    .get(best_op_idx)
127                    .cloned()
128                    .unwrap_or(OpType::Skip);
129
130                arch.edges.push(ArchEdge {
131                    from: from_id,
132                    to: node_id,
133                    op,
134                });
135                edge_idx += 1;
136            }
137        }
138
139        arch
140    }
141
142    /// Apply an external gradient update to a single (edge, op) weight.
143    ///
144    /// Typical usage: call after computing `∂L/∂α[edge_idx][op_idx]`
145    /// from a validation loss.
146    pub fn update_alpha(
147        &mut self,
148        edge_idx: usize,
149        op_idx: usize,
150        grad: f64,
151    ) -> Result<(), OptimizeError> {
152        if edge_idx >= self.alpha.len() {
153            return Err(OptimizeError::InvalidParameter(format!(
154                "edge_idx {} out of range (n_edges = {})",
155                edge_idx,
156                self.alpha.len()
157            )));
158        }
159        if op_idx >= self.n_ops {
160            return Err(OptimizeError::InvalidParameter(format!(
161                "op_idx {} out of range (n_ops = {})",
162                op_idx, self.n_ops
163            )));
164        }
165        self.alpha[edge_idx][op_idx] += self.learning_rate * grad;
166        Ok(())
167    }
168
169    /// Batch update: apply gradient matrix `grads[edge][op]` to all weights.
170    ///
171    /// `grads` must have shape `[n_edges, n_ops]`.
172    pub fn update_alpha_batch(&mut self, grads: &[Vec<f64>]) -> Result<(), OptimizeError> {
173        if grads.len() != self.alpha.len() {
174            return Err(OptimizeError::InvalidParameter(format!(
175                "grads has {} rows but alpha has {}",
176                grads.len(),
177                self.alpha.len()
178            )));
179        }
180        for (e, row) in grads.iter().enumerate() {
181            if row.len() != self.n_ops {
182                return Err(OptimizeError::InvalidParameter(format!(
183                    "grads[{}] has {} columns but n_ops = {}",
184                    e,
185                    row.len(),
186                    self.n_ops
187                )));
188            }
189            for (k, &g) in row.iter().enumerate() {
190                self.alpha[e][k] += self.learning_rate * g;
191            }
192        }
193        Ok(())
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::SearchSpace;

    /// Standard fixture: 4 intermediate nodes, 2 inputs, DARTS-like op set.
    fn make_darts() -> DARTSSearch {
        let space = SearchSpace::darts_like(4);
        DARTSSearch::new(4, &space.operations, 2)
    }

    #[test]
    fn test_new_edge_count() {
        // Node i has 2 + i predecessors: 2 + 3 + 4 + 5 = 14 edges.
        let darts = make_darts();
        assert_eq!(darts.n_edges(), 14);
        assert!(darts.alpha.iter().all(|row| row.len() == darts.n_ops));
    }

    #[test]
    fn test_get_op_weights_sum_to_one() {
        let darts = make_darts();
        for e in 0..darts.n_edges() {
            let w = darts.get_op_weights(e);
            let sum: f64 = w.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "weights do not sum to 1: {}",
                sum
            );
        }
    }

    #[test]
    fn test_derive_architecture_correct_structure() {
        let space = SearchSpace::darts_like(4);
        let darts = DARTSSearch::new(4, &space.operations, 2);
        let arch = darts.derive_architecture(&space, 2, 64, 10);

        // Should have n_input_nodes + n_nodes nodes.
        assert_eq!(arch.nodes.len(), 2 + 4);
        // One derived edge per row of alpha.
        assert_eq!(arch.edges.len(), darts.n_edges());
        // All edges have valid indices and point forward (the cell is a DAG).
        for e in &arch.edges {
            assert!(e.from < arch.nodes.len());
            assert!(e.to < arch.nodes.len());
            assert!(e.from < e.to, "edge {} -> {} is not forward", e.from, e.to);
        }
    }

    #[test]
    fn test_update_alpha_changes_weights() {
        let mut darts = make_darts();
        let before = darts.alpha[0][0];
        darts.update_alpha(0, 0, 1.0).expect("update failed");
        assert!(
            (darts.alpha[0][0] - before).abs() > 1e-12,
            "alpha did not change"
        );
    }

    #[test]
    fn test_update_alpha_out_of_range_errors() {
        let mut darts = make_darts();
        assert!(darts.update_alpha(9999, 0, 1.0).is_err());
        assert!(darts.update_alpha(0, 9999, 1.0).is_err());
    }

    #[test]
    fn test_update_alpha_batch_correct_shape() {
        let mut darts = make_darts();
        let n_e = darts.n_edges();
        let n_o = darts.n_ops;
        let grads = vec![vec![0.1; n_o]; n_e];
        let before = darts.alpha.clone();
        darts
            .update_alpha_batch(&grads)
            .expect("batch update failed");
        assert_ne!(darts.alpha, before, "batch update did not change alpha");
    }

    #[test]
    fn test_update_alpha_batch_wrong_shape_errors() {
        let mut darts = make_darts();
        let grads = vec![vec![0.1; darts.n_ops]; darts.n_edges() + 1];
        assert!(darts.update_alpha_batch(&grads).is_err());
    }

    #[test]
    fn test_update_alpha_batch_wrong_columns_errors() {
        let mut darts = make_darts();
        let mut grads = vec![vec![0.1; darts.n_ops]; darts.n_edges()];
        let last = grads.len() - 1;
        grads[last].push(0.1);
        assert!(darts.update_alpha_batch(&grads).is_err());
    }

    #[test]
    fn test_argmax_selects_highest_weight() {
        let space = SearchSpace::darts_like(2);
        let mut darts = DARTSSearch::new(2, &space.operations, 2);
        assert!(darts.n_ops > 3, "fixture needs at least 4 candidate ops");
        // Make edge 0 strongly prefer op 3.
        darts.alpha[0].fill(0.0);
        darts.alpha[0][3] = 10.0;

        let arch = darts.derive_architecture(&space, 1, 32, 10);
        // The op on the first edge should be space.operations[3].
        let first = arch
            .edges
            .first()
            .expect("derived architecture should contain edges");
        assert_eq!(first.op, space.operations[3]);
    }
}