1use super::{AnnealingStrategy, Lcg, Operation, TemperatureSchedule};
23use crate::error::{OptimizeError, OptimizeResult};
24
25#[derive(Debug, Clone)]
29pub struct SnasConfig {
30 pub n_cells: usize,
32 pub n_operations: usize,
34 pub channels: usize,
36 pub n_nodes: usize,
38 pub arch_lr: f64,
40 pub weight_lr: f64,
42 pub temperature_schedule: TemperatureSchedule,
44 pub resource_weight: f64,
46 pub seed: u64,
48}
49
50impl Default for SnasConfig {
51 fn default() -> Self {
52 Self {
53 n_cells: 3,
54 n_operations: 6,
55 channels: 32,
56 n_nodes: 4,
57 arch_lr: 3e-4,
58 weight_lr: 1e-3,
59 temperature_schedule: TemperatureSchedule::new(
60 1.0,
61 0.1,
62 AnnealingStrategy::Exponential,
63 100,
64 ),
65 resource_weight: 0.001,
66 seed: 42,
67 }
68 }
69}
70
/// One edge of a cell: a set of candidate operations mixed by learnable
/// architecture parameters.
#[derive(Debug, Clone)]
pub struct SnasMixedOperation {
    /// One architecture parameter (logit) per candidate operation.
    pub arch_params: Vec<f64>,
    /// Weights initialised to the uniform distribution by `new`.
    // NOTE(review): nothing in this file writes this field after construction;
    // callers that sample are presumably expected to store the result here.
    pub last_concrete_weights: Vec<f64>,
}
84
85impl SnasMixedOperation {
86 pub fn new(n_ops: usize) -> Self {
88 Self {
89 arch_params: vec![0.0_f64; n_ops],
90 last_concrete_weights: vec![1.0 / n_ops as f64; n_ops],
91 }
92 }
93
94 pub fn concrete_sample(&self, temperature: f64, rng: &mut Lcg) -> Vec<f64> {
100 let eps = 1e-20_f64;
101 let temp = temperature.max(1e-8);
102 let n = self.arch_params.len();
103
104 let mut logits = vec![0.0_f64; n];
105 for k in 0..n {
106 let u = rng.next_f64().max(eps);
107 let gumbel_noise = -(-u.ln()).ln();
108 logits[k] = self.arch_params[k] + gumbel_noise;
109 }
110
111 let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
113 let mut exp_vals: Vec<f64> = logits.iter().map(|&l| ((l - max_l) / temp).exp()).collect();
114 let sum = exp_vals.iter().sum::<f64>().max(eps);
115 for v in &mut exp_vals {
116 *v /= sum;
117 }
118 exp_vals
119 }
120
121 pub fn weights(&self, temperature: f64) -> Vec<f64> {
123 let t = temperature.max(1e-8);
124 let scaled: Vec<f64> = self.arch_params.iter().map(|a| a / t).collect();
125 let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
126 let exps: Vec<f64> = scaled.iter().map(|s| (s - max_val).exp()).collect();
127 let sum: f64 = exps.iter().sum();
128 if sum == 0.0 {
129 let n = self.arch_params.len();
130 vec![1.0 / n as f64; n]
131 } else {
132 exps.iter().map(|e| e / sum).collect()
133 }
134 }
135
136 pub fn expected_cost(&self, temperature: f64, channels: usize) -> f64 {
141 let ops = Operation::all();
142 let w = self.weights(temperature);
143 w.iter()
144 .zip(ops.iter())
145 .take(self.arch_params.len())
146 .map(|(&wk, op)| wk * op.cost_flops(channels))
147 .sum()
148 }
149
150 pub fn argmax_op(&self) -> usize {
152 self.arch_params
153 .iter()
154 .enumerate()
155 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
156 .map(|(i, _)| i)
157 .unwrap_or(0)
158 }
159
160 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) {
162 for (p, g) in self.arch_params.iter_mut().zip(grads.iter()) {
163 *p -= lr * g;
164 }
165 }
166}
167
168#[derive(Debug, Clone)]
172pub struct SnasCell {
173 pub n_nodes: usize,
175 pub n_input_nodes: usize,
177 pub edges: Vec<Vec<SnasMixedOperation>>,
180}
181
182impl SnasCell {
183 pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
185 let edges: Vec<Vec<SnasMixedOperation>> = (0..n_intermediate_nodes)
186 .map(|i| {
187 let n_predecessors = n_input_nodes + i;
188 (0..n_predecessors)
189 .map(|_| SnasMixedOperation::new(n_ops))
190 .collect()
191 })
192 .collect();
193 Self {
194 n_nodes: n_intermediate_nodes,
195 n_input_nodes,
196 edges,
197 }
198 }
199
200 pub fn arch_parameters(&self) -> Vec<f64> {
202 self.edges
203 .iter()
204 .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
205 .collect()
206 }
207
208 pub fn total_expected_cost(&self, temperature: f64, channels: usize) -> f64 {
210 self.edges
211 .iter()
212 .flat_map(|row| row.iter())
213 .map(|mo| mo.expected_cost(temperature, channels))
214 .sum()
215 }
216
217 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
219 let n_params: usize = self
220 .edges
221 .iter()
222 .flat_map(|row| row.iter())
223 .map(|mo| mo.arch_params.len())
224 .sum();
225 if grads.len() != n_params {
226 return Err(OptimizeError::InvalidInput(format!(
227 "SnasCell::update_arch_params: expected {n_params} grads, got {}",
228 grads.len()
229 )));
230 }
231 let mut idx = 0;
232 for row in self.edges.iter_mut() {
233 for mo in row.iter_mut() {
234 let n = mo.arch_params.len();
235 mo.update_arch_params(&grads[idx..idx + n], lr);
236 idx += n;
237 }
238 }
239 Ok(())
240 }
241
242 pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
244 self.edges
245 .iter()
246 .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
247 .collect()
248 }
249}
250
251pub struct SnasSearch {
257 pub cells: Vec<SnasCell>,
259 pub config: SnasConfig,
261 weights: Vec<f64>,
263 rng: Lcg,
265 current_step: usize,
267}
268
269impl SnasSearch {
270 pub fn new(config: SnasConfig) -> Self {
272 let cells: Vec<SnasCell> = (0..config.n_cells)
273 .map(|_| SnasCell::new(2, config.n_nodes, config.n_operations))
274 .collect();
275 let weights = vec![0.01_f64; config.n_cells];
276 let rng = Lcg::new(config.seed);
277 Self {
278 cells,
279 config,
280 weights,
281 rng,
282 current_step: 0,
283 }
284 }
285
286 pub fn current_temperature(&self) -> f64 {
288 self.config
289 .temperature_schedule
290 .temperature_at(self.current_step)
291 }
292
293 pub fn arch_parameters(&self) -> Vec<f64> {
295 self.cells
296 .iter()
297 .flat_map(|c| c.arch_parameters())
298 .collect()
299 }
300
301 pub fn n_arch_params(&self) -> usize {
303 self.cells.iter().map(|c| c.arch_parameters().len()).sum()
304 }
305
306 pub fn total_expected_cost(&self) -> f64 {
308 let temp = self.current_temperature();
309 let channels = self.config.channels;
310 self.cells
311 .iter()
312 .map(|c| c.total_expected_cost(temp, channels))
313 .sum()
314 }
315
316 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
318 let total = self.n_arch_params();
319 if grads.len() != total {
320 return Err(OptimizeError::InvalidInput(format!(
321 "SnasSearch::update_arch_params: expected {total} grads, got {}",
322 grads.len()
323 )));
324 }
325 let mut offset = 0;
326 for cell in self.cells.iter_mut() {
327 let n = cell.arch_parameters().len();
328 cell.update_arch_params(&grads[offset..offset + n], lr)?;
329 offset += n;
330 }
331 Ok(())
332 }
333
334 pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
336 self.cells.iter().map(|c| c.derive_discrete()).collect()
337 }
338
339 pub fn arch_grads_fd(&self, val_fn: impl Fn(&[f64]) -> f64, step: f64) -> Vec<f64> {
345 let params = self.arch_parameters();
346 let n = params.len();
347 let lambda = self.config.resource_weight;
348 let temp = self.current_temperature();
349 let channels = self.config.channels;
350
351 let mut grads = vec![0.0_f64; n];
352 for i in 0..n {
353 let mut p_plus = params.clone();
354 p_plus[i] += step;
355 let mut p_minus = params.clone();
356 p_minus[i] -= step;
357
358 let cost_plus = resource_cost_at(&p_plus, &self.cells, temp, channels, lambda);
360 let cost_minus = resource_cost_at(&p_minus, &self.cells, temp, channels, lambda);
361
362 let task_grad = (val_fn(&p_plus) - val_fn(&p_minus)) / (2.0 * step);
363 let cost_grad = (cost_plus - cost_minus) / (2.0 * step);
364 grads[i] = task_grad + cost_grad;
365 }
366 grads
367 }
368
369 pub fn bilevel_step(
373 &mut self,
374 weight_grad_fn: impl Fn(&[f64]) -> Vec<f64>,
375 val_fn: impl Fn(&[f64]) -> f64,
376 ) -> OptimizeResult<()> {
377 self.current_step += 1;
378
379 let w_grads = weight_grad_fn(&self.weights);
381 if w_grads.len() != self.weights.len() {
382 return Err(OptimizeError::InvalidInput(format!(
383 "weight_grad_fn returned {} grads, expected {}",
384 w_grads.len(),
385 self.weights.len()
386 )));
387 }
388 let lr_w = self.config.weight_lr;
389 for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
390 *w -= lr_w * g;
391 }
392
393 let a_grads = self.arch_grads_fd(&val_fn, 1e-4);
395 if !a_grads.is_empty() {
396 self.update_arch_params(&a_grads, self.config.arch_lr)?;
397 }
398
399 Ok(())
400 }
401}
402
403fn resource_cost_at(
408 params: &[f64],
409 cells: &[SnasCell],
410 temperature: f64,
411 channels: usize,
412 lambda: f64,
413) -> f64 {
414 let ops = Operation::all();
415 let n_ops_canonical = ops.len();
416 let eps = 1e-8_f64;
417 let temp = temperature.max(eps);
418
419 let mut total_cost = 0.0_f64;
420 let mut offset = 0_usize;
421
422 for cell in cells.iter() {
423 for node_edges in cell.edges.iter() {
424 for mo in node_edges.iter() {
425 let n = mo.arch_params.len().min(n_ops_canonical);
426 let slice = ¶ms[offset..offset + mo.arch_params.len()];
427
428 let scaled: Vec<f64> = slice.iter().map(|a| a / temp).collect();
430 let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
431 let exps: Vec<f64> = scaled.iter().map(|s| (s - max_val).exp()).collect();
432 let sum: f64 = exps.iter().sum::<f64>().max(eps);
433
434 for k in 0..n {
435 let wk = exps[k] / sum;
436 total_cost += wk * ops[k].cost_flops(channels);
437 }
438
439 offset += mo.arch_params.len();
440 }
441 }
442 }
443
444 lambda * total_cost
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-seed generator so the sampling tests are reproducible.
    fn make_lcg() -> Lcg {
        Lcg::new(99999)
    }

    /// A single concrete sample has one non-negative weight per operation
    /// and sums to one.
    #[test]
    fn test_concrete_sample_valid() {
        let mut rng = make_lcg();
        let weights = SnasMixedOperation::new(6).concrete_sample(1.0, &mut rng);

        assert_eq!(weights.len(), 6);
        let sum: f64 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "concrete sample sum={sum}");
        weights
            .iter()
            .for_each(|&w| assert!(w >= 0.0, "negative concrete weight {w}"));
    }

    /// Repeated sampling keeps producing valid distributions.
    #[test]
    fn test_concrete_sample_multiple_calls_valid() {
        let mo = SnasMixedOperation::new(6);
        let mut rng = make_lcg();
        let mut round = 0;
        while round < 20 {
            let w = mo.concrete_sample(0.5, &mut rng);
            let sum: f64 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-9);
            assert!(w.iter().all(|&v| v >= 0.0));
            round += 1;
        }
    }

    /// At a very low temperature the operation with the dominant parameter
    /// wins at least half of the samples.
    #[test]
    fn test_concrete_sample_low_temp_peaks() {
        let mut mo = SnasMixedOperation::new(6);
        mo.arch_params = vec![5.0, 0.1, 0.1, 0.1, 0.1, 0.1];
        let mut rng = make_lcg();

        let dominant_count = (0..20)
            .filter(|_| mo.concrete_sample(0.01, &mut rng)[0] > 0.5)
            .count();

        assert!(
            dominant_count >= 10,
            "dominant_count={dominant_count} too low"
        );
    }

    /// The expected cost of a fresh edge is never negative.
    #[test]
    fn test_expected_cost_nonneg() {
        let cost = SnasMixedOperation::new(6).expected_cost(1.0, 16);
        assert!(cost >= 0.0, "cost={cost}");
    }

    /// The expected cost of a default search is never negative.
    #[test]
    fn test_total_expected_cost_nonneg() {
        let search = SnasSearch::new(SnasConfig::default());
        let cost = search.total_expected_cost();
        assert!(cost >= 0.0, "total cost={cost}");
    }

    /// With three equally weighted operations the expected cost is still
    /// non-negative.
    #[test]
    fn test_expected_cost_zero_for_no_flop_ops() {
        let mut mo = SnasMixedOperation::new(3);
        mo.arch_params = vec![10.0, 10.0, 10.0];
        let cost = mo.expected_cost(1.0, 16);
        assert!(cost >= 0.0);
    }

    /// 2 inputs and 4 intermediate nodes give 2 + 3 + 4 + 5 = 14 edges,
    /// each with 6 parameters.
    #[test]
    fn test_snas_cell_arch_params_shape() {
        let flat = SnasCell::new(2, 4, 6).arch_parameters();
        assert_eq!(flat.len(), 84);
    }

    /// A gradient slice of the wrong length is rejected.
    #[test]
    fn test_snas_cell_update_wrong_len_errors() {
        let mut cell = SnasCell::new(2, 3, 6);
        let too_short = [0.0; 3];
        assert!(cell.update_arch_params(&too_short, 0.01).is_err());
    }

    /// One bilevel step with zero weight gradients and a quadratic
    /// validation objective completes without error.
    #[test]
    fn test_snas_bilevel_step_runs() {
        let mut search = SnasSearch::new(SnasConfig::default());

        let weight_grad_fn = |weights: &[f64]| vec![0.0_f64; weights.len()];
        let val_fn = |params: &[f64]| params.iter().map(|p| p * p).sum::<f64>();

        search
            .bilevel_step(weight_grad_fn, val_fn)
            .expect("snas bilevel_step should not error");
    }

    /// The temperature does not increase after a step.
    #[test]
    fn test_snas_bilevel_step_advances_temperature() {
        let mut search = SnasSearch::new(SnasConfig::default());
        let t0 = search.current_temperature();
        let _ = search.bilevel_step(|w| vec![0.0; w.len()], |p| p.iter().sum::<f64>());
        let t1 = search.current_temperature();
        assert!(t1 <= t0 + 1e-12, "t1={t1} should be ≤ t0={t0}");
    }

    /// Every derived operation index lies inside the candidate range.
    #[test]
    fn test_derive_discrete_arch_valid() {
        let search = SnasSearch::new(SnasConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        });
        let arch = search.derive_discrete_arch_indices();
        assert_eq!(arch.len(), 2);
        for &op_idx in arch.iter().flatten().flatten() {
            assert!(op_idx < 6, "op_idx={op_idx}");
        }
    }

    /// `n_arch_params` agrees with the length of the flattened parameters.
    #[test]
    fn test_snas_arch_params_consistent() {
        let search = SnasSearch::new(SnasConfig::default());
        let flat_len = search.arch_parameters().len();
        assert_eq!(flat_len, search.n_arch_params());
    }
}