1use crate::error::{OptimizeError, OptimizeResult};
14
/// A candidate operation in the DARTS search space.
///
/// NOTE(review): `SkipConnect` is declared here but deliberately(?) not
/// returned by [`Operation::all`], matching the default of 6 search
/// operations — confirm the exclusion is intended.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    Identity,
    Zero,
    Conv3x3,
    Conv5x5,
    MaxPool,
    AvgPool,
    SkipConnect,
}

impl Operation {
    /// Rough FLOP estimate for applying this operation at the given
    /// channel width (`2 * k*k * C * C` for a k×k convolution; pooling
    /// is counted as one op per channel; pass-throughs are free).
    pub fn cost_flops(&self, channels: usize) -> f64 {
        let c = channels as f64;
        match self {
            Operation::Conv3x3 => 2.0 * 9.0 * c * c,
            Operation::Conv5x5 => 2.0 * 25.0 * c * c,
            Operation::MaxPool | Operation::AvgPool => c,
            Operation::Identity | Operation::Zero | Operation::SkipConnect => 0.0,
        }
    }

    /// Stable lowercase identifier, suitable for logging/serialization.
    pub fn name(&self) -> &'static str {
        match self {
            Operation::Identity => "identity",
            Operation::Zero => "zero",
            Operation::Conv3x3 => "conv3x3",
            Operation::Conv5x5 => "conv5x5",
            Operation::MaxPool => "max_pool",
            Operation::AvgPool => "avg_pool",
            Operation::SkipConnect => "skip_connect",
        }
    }

    /// The operations that make up the search space, in index order.
    pub fn all() -> &'static [Operation] {
        static SEARCH_SPACE: [Operation; 6] = [
            Operation::Identity,
            Operation::Zero,
            Operation::Conv3x3,
            Operation::Conv5x5,
            Operation::MaxPool,
            Operation::AvgPool,
        ];
        &SEARCH_SPACE
    }
}
80
/// Hyper-parameters for a DARTS architecture search.
#[derive(Debug, Clone)]
pub struct DartsConfig {
    /// Number of stacked cells in the super-network.
    pub n_cells: usize,
    /// Number of candidate operations per mixed edge.
    pub n_operations: usize,
    /// Channel width. NOTE(review): not read by the code in this file —
    /// presumably consumed by callers (e.g. for FLOP estimates); confirm.
    pub channels: usize,
    /// Number of intermediate (computed) nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for the architecture (alpha) parameters.
    pub arch_lr: f64,
    /// Learning rate for the network weights.
    pub weight_lr: f64,
    /// Softmax temperature used when mixing candidate operations.
    pub temperature: f64,
}
101
102impl Default for DartsConfig {
103 fn default() -> Self {
104 Self {
105 n_cells: 4,
106 n_operations: 6,
107 channels: 16,
108 n_nodes: 4,
109 arch_lr: 3e-4,
110 weight_lr: 3e-4,
111 temperature: 1.0,
112 }
113 }
114}
115
/// A softmax-weighted mixture over the candidate operations on one edge.
#[derive(Debug, Clone)]
pub struct MixedOperation {
    /// Architecture logits (alpha), one per candidate operation.
    pub arch_params: Vec<f64>,
    /// Per-candidate outputs cached by the most recent `forward` call.
    pub operation_outputs: Option<Vec<Vec<f64>>>,
}

impl MixedOperation {
    /// Creates a mixed operation with `n_ops` zero-initialized logits.
    pub fn new(n_ops: usize) -> Self {
        MixedOperation {
            arch_params: vec![0.0_f64; n_ops],
            operation_outputs: None,
        }
    }

    /// Temperature-scaled softmax over the logits, stabilized by
    /// subtracting the maximum before exponentiating. Falls back to a
    /// uniform distribution if every exponential underflows to zero.
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        // Clamp so a zero/negative temperature cannot divide by zero.
        let t = temperature.max(1e-8);
        let mut peak = f64::NEG_INFINITY;
        for &a in &self.arch_params {
            peak = peak.max(a / t);
        }
        let exps: Vec<f64> = self
            .arch_params
            .iter()
            .map(|&a| (a / t - peak).exp())
            .collect();
        let total: f64 = exps.iter().sum();
        if total == 0.0 {
            let n = self.arch_params.len();
            vec![1.0 / n as f64; n]
        } else {
            exps.into_iter().map(|e| e / total).collect()
        }
    }

    /// Evaluates every candidate via `op_fn(k, x)` and returns their
    /// softmax-weighted sum, caching the raw per-candidate outputs.
    ///
    /// The output length is taken from the first candidate's output
    /// (falling back to `x.len()` when there are no candidates).
    pub fn forward(
        &mut self,
        x: &[f64],
        op_fn: impl Fn(usize, &[f64]) -> Vec<f64>,
        temperature: f64,
    ) -> Vec<f64> {
        let mix = self.weights(temperature);
        let outputs: Vec<Vec<f64>> = (0..self.arch_params.len())
            .map(|k| op_fn(k, x))
            .collect();
        let out_len = match outputs.first() {
            Some(first) => first.len(),
            None => x.len(),
        };
        let mut blended = vec![0.0_f64; out_len];
        for (weight, candidate) in mix.iter().zip(outputs.iter()) {
            for (acc, value) in blended.iter_mut().zip(candidate.iter()) {
                *acc += weight * value;
            }
        }
        self.operation_outputs = Some(outputs);
        blended
    }

    /// Index of the largest logit (`max_by` keeps the last of tied
    /// maxima; 0 when there are no logits).
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| {
                lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal)
            })
            .map_or(0, |(i, _)| i)
    }
}
191
/// One DARTS cell: a small DAG whose intermediate nodes each receive a
/// mixed operation from every predecessor node.
#[derive(Debug, Clone)]
pub struct DartsCell {
    /// Number of intermediate (computed) nodes in the cell.
    pub n_nodes: usize,
    /// Number of input nodes feeding the cell.
    pub n_input_nodes: usize,
    /// `edges[i][j]` is the mixed operation on the edge from predecessor
    /// `j` to intermediate node `i`; node `i` has `n_input_nodes + i`
    /// predecessors.
    pub edges: Vec<Vec<MixedOperation>>,
}
208
209impl DartsCell {
210 pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
217 let edges: Vec<Vec<MixedOperation>> = (0..n_intermediate_nodes)
220 .map(|i| {
221 let n_predecessors = n_input_nodes + i;
222 (0..n_predecessors)
223 .map(|_| MixedOperation::new(n_ops))
224 .collect()
225 })
226 .collect();
227
228 Self {
229 n_nodes: n_intermediate_nodes,
230 n_input_nodes,
231 edges,
232 }
233 }
234
235 pub fn forward(&mut self, inputs: &[Vec<f64>], temperature: f64) -> Vec<f64> {
244 if inputs.is_empty() {
245 return Vec::new();
246 }
247 let feature_len = inputs[0].len();
248 let mut node_outputs: Vec<Vec<f64>> = inputs.to_vec();
250
251 for i in 0..self.n_nodes {
252 let n_prev = self.n_input_nodes + i;
253 let mut node_out = vec![0.0_f64; feature_len];
254 for j in 0..n_prev {
255 let src = node_outputs[j].clone();
256 let edge_out = self.edges[i][j].forward(&src, default_op_fn, temperature);
257 for (no, eo) in node_out.iter_mut().zip(edge_out.iter()) {
258 *no += eo;
259 }
260 }
261 node_outputs.push(node_out);
262 }
263
264 let mut result = Vec::with_capacity(self.n_nodes * feature_len);
266 for node_out in node_outputs.iter().skip(self.n_input_nodes) {
267 result.extend_from_slice(node_out);
268 }
269 result
270 }
271
272 pub fn arch_parameters(&self) -> Vec<f64> {
274 self.edges
275 .iter()
276 .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
277 .collect()
278 }
279
280 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
284 let n_params: usize = self
285 .edges
286 .iter()
287 .flat_map(|row| row.iter())
288 .map(|mo| mo.arch_params.len())
289 .sum();
290 if grads.len() != n_params {
291 return Err(OptimizeError::InvalidInput(format!(
292 "Expected {} gradient values, got {}",
293 n_params,
294 grads.len()
295 )));
296 }
297 let mut idx = 0;
298 for row in self.edges.iter_mut() {
299 for mo in row.iter_mut() {
300 for p in mo.arch_params.iter_mut() {
301 *p -= lr * grads[idx];
302 idx += 1;
303 }
304 }
305 }
306 Ok(())
307 }
308
309 pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
314 self.edges
315 .iter()
316 .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
317 .collect()
318 }
319}
320
/// Placeholder candidate-operation evaluator: every candidate behaves
/// as the identity, regardless of the candidate index `_k`.
fn default_op_fn(_k: usize, x: &[f64]) -> Vec<f64> {
    Vec::from(x)
}
325
/// Differentiable architecture search (DARTS) over a stack of cells,
/// with one scalar "network weight" per cell standing in for real
/// network weights.
#[derive(Debug, Clone)]
pub struct DartsSearch {
    /// The stacked search cells, each with its own architecture logits.
    pub cells: Vec<DartsCell>,
    /// Search hyper-parameters.
    pub config: DartsConfig,
    /// One scalar weight per cell; their sum scales the toy prediction
    /// used by the internal loss.
    weights: Vec<f64>,
}
340
341impl DartsSearch {
342 pub fn new(config: DartsConfig) -> Self {
344 let cells: Vec<DartsCell> = (0..config.n_cells)
345 .map(|_| DartsCell::new(2, config.n_nodes, config.n_operations))
346 .collect();
347 let weights = vec![0.01_f64; config.n_cells];
349 Self {
350 cells,
351 config,
352 weights,
353 }
354 }
355
356 pub fn arch_parameters(&self) -> Vec<f64> {
358 self.cells
359 .iter()
360 .flat_map(|c| c.arch_parameters())
361 .collect()
362 }
363
364 pub fn n_arch_params(&self) -> usize {
366 self.cells.iter().map(|c| c.arch_parameters().len()).sum()
367 }
368
369 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
373 let total = self.n_arch_params();
374 if grads.len() != total {
375 return Err(OptimizeError::InvalidInput(format!(
376 "Expected {} arch-param grads, got {}",
377 total,
378 grads.len()
379 )));
380 }
381 let mut offset = 0;
382 for cell in self.cells.iter_mut() {
383 let n = cell.arch_parameters().len();
384 cell.update_arch_params(&grads[offset..offset + n], lr)?;
385 offset += n;
386 }
387 Ok(())
388 }
389
390 pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
394 self.cells.iter().map(|c| c.derive_discrete()).collect()
395 }
396
397 pub fn derive_discrete_arch(&self) -> Vec<Vec<Operation>> {
403 let ops = Operation::all();
404 self.derive_discrete_arch_indices()
405 .iter()
406 .map(|cell_disc| {
407 cell_disc
408 .iter()
409 .flat_map(|node_edges| {
410 node_edges.iter().map(|&idx| {
411 if idx < ops.len() {
412 ops[idx]
413 } else {
414 Operation::Identity
415 }
416 })
417 })
418 .collect()
419 })
420 .collect()
421 }
422
423 fn compute_loss(&self, x: &[Vec<f64>], y: &[f64]) -> f64 {
428 if x.is_empty() || y.is_empty() {
429 return 0.0;
430 }
431 let w_sum: f64 = self.weights.iter().sum();
432 let mut loss = 0.0_f64;
433 let n = x.len().min(y.len());
434 for i in 0..n {
435 let x_mean = if x[i].is_empty() {
436 0.0
437 } else {
438 x[i].iter().sum::<f64>() / x[i].len() as f64
439 };
440 let pred = w_sum * x_mean;
441 let diff = pred - y[i];
442 loss += diff * diff;
443 }
444 loss / n as f64
445 }
446
447 fn weight_grads(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
449 let n = x.len().min(y.len());
450 if n == 0 {
451 return vec![0.0_f64; self.weights.len()];
452 }
453 let w_sum: f64 = self.weights.iter().sum();
454 let mut grad_sum = 0.0_f64;
455 for i in 0..n {
456 let x_mean = if x[i].is_empty() {
457 0.0
458 } else {
459 x[i].iter().sum::<f64>() / x[i].len() as f64
460 };
461 let pred = w_sum * x_mean;
462 let diff = pred - y[i];
463 grad_sum += 2.0 * diff * x_mean / n as f64;
465 }
466 vec![grad_sum; self.weights.len()]
468 }
469
470 fn arch_grads_fd(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
473 let n = self.n_arch_params();
474 if n == 0 {
475 return Vec::new();
476 }
477 let mut grads = vec![0.0_f64; n];
478 let h = 1e-4;
479 let mut offset = 0;
480 for cell_idx in 0..self.cells.len() {
481 let cell_n = self.cells[cell_idx].arch_parameters().len();
482 for local_j in 0..cell_n {
483 let global_j = offset + local_j;
484 let mut search_plus = self.clone();
486 let params_plus = search_plus.cells[cell_idx].arch_parameters();
487 let mut p_plus = params_plus.clone();
488 p_plus[local_j] += h;
489 let _ = search_plus.cells[cell_idx].set_arch_params(&p_plus);
491 let loss_plus = search_plus.compute_loss(x, y);
492
493 let mut search_minus = self.clone();
495 let params_minus = search_minus.cells[cell_idx].arch_parameters();
496 let mut p_minus = params_minus.clone();
497 p_minus[local_j] -= h;
498 let _ = search_minus.cells[cell_idx].set_arch_params(&p_minus);
499 let loss_minus = search_minus.compute_loss(x, y);
500
501 grads[global_j] = (loss_plus - loss_minus) / (2.0 * h);
502 }
503 offset += cell_n;
504 }
505 grads
506 }
507
508 pub fn bilevel_step(
515 &mut self,
516 train_x: &[Vec<f64>],
517 train_y: &[f64],
518 val_x: &[Vec<f64>],
519 val_y: &[f64],
520 ) -> (f64, f64) {
521 let train_loss = self.compute_loss(train_x, train_y);
522 let val_loss = self.compute_loss(val_x, val_y);
523
524 let w_grads = self.weight_grads(train_x, train_y);
526 let lr_w = self.config.weight_lr;
527 for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
528 *w -= lr_w * g;
529 }
530
531 let a_grads = self.arch_grads_fd(val_x, val_y);
533 let lr_a = self.config.arch_lr;
534 if !a_grads.is_empty() {
535 let _ = self.update_arch_params(&a_grads, lr_a);
536 }
537
538 (train_loss, val_loss)
539 }
540}
541
542impl DartsCell {
545 pub fn set_arch_params(&mut self, params: &[f64]) -> OptimizeResult<()> {
547 let total: usize = self
548 .edges
549 .iter()
550 .flat_map(|r| r.iter())
551 .map(|m| m.arch_params.len())
552 .sum();
553 if params.len() != total {
554 return Err(OptimizeError::InvalidInput(format!(
555 "set_arch_params: expected {total} values, got {}",
556 params.len()
557 )));
558 }
559 let mut idx = 0;
560 for row in self.edges.iter_mut() {
561 for mo in row.iter_mut() {
562 for p in mo.arch_params.iter_mut() {
563 *p = params[idx];
564 idx += 1;
565 }
566 }
567 }
568 Ok(())
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    // Zero-initialized logits must softmax to a valid probability
    // distribution.
    #[test]
    fn mixed_operation_weights_sum_to_one() {
        let mo = MixedOperation::new(6);
        let w = mo.weights(1.0);
        assert_eq!(w.len(), 6);
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "weights sum = {sum}");
    }

    // Lower temperature should sharpen the distribution: the top logit's
    // weight grows as the temperature shrinks.
    #[test]
    fn mixed_operation_weights_temperature_effect() {
        let mut mo = MixedOperation::new(4);
        mo.arch_params = vec![1.0, 0.5, 0.3, 0.2];
        let w_hot = mo.weights(10.0);
        let w_cold = mo.weights(0.1);
        assert!(w_cold[0] > w_hot[0], "cold should be sharper");
    }

    // forward() preserves the feature length of the input.
    #[test]
    fn mixed_operation_forward_correct_shape() {
        let mut mo = MixedOperation::new(3);
        let x = vec![1.0_f64; 8];
        let out = mo.forward(&x, |_k, v| v.to_vec(), 1.0);
        assert_eq!(out.len(), 8);
    }

    // A cell concatenates its intermediate-node outputs:
    // 3 nodes * 8 features = 24 values.
    #[test]
    fn darts_cell_forward_output_shape() {
        let mut cell = DartsCell::new(2, 3, 4);
        let inputs = vec![vec![1.0_f64; 8], vec![0.5_f64; 8]];
        let out = cell.forward(&inputs, 1.0);
        assert_eq!(out.len(), 24);
    }

    // Discretization produces one non-empty operation list per cell.
    #[test]
    fn derive_discrete_arch_returns_ops() {
        let config = DartsConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let arch = search.derive_discrete_arch();
        assert_eq!(arch.len(), 2, "one vec per cell");
        for cell_ops in &arch {
            assert!(!cell_ops.is_empty());
        }
    }

    // Smoke test: one bilevel step on tiny data yields finite losses.
    #[test]
    fn bilevel_step_runs_without_error() {
        let config = DartsConfig::default();
        let mut search = DartsSearch::new(config);
        let train_x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let train_y = vec![1.5, 3.5];
        let val_x = vec![vec![0.5, 1.5]];
        let val_y = vec![1.0];
        let (tl, vl) = search.bilevel_step(&train_x, &train_y, &val_x, &val_y);
        assert!(tl.is_finite());
        assert!(vl.is_finite());
    }

    // The flat parameter vector and the parameter count must agree.
    #[test]
    fn arch_parameters_length_consistent() {
        let config = DartsConfig {
            n_cells: 3,
            n_operations: 5,
            n_nodes: 2,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let params = search.arch_parameters();
        assert_eq!(params.len(), search.n_arch_params());
    }

    // A gradient vector of the wrong length must be rejected.
    #[test]
    fn update_arch_params_wrong_length_errors() {
        let mut search = DartsSearch::new(DartsConfig::default());
        let result = search.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }
}