1use ghostflow_core::Tensor;
11use std::collections::HashMap;
12use rand::Rng;
13
/// Candidate operation that can be placed on a cell edge
/// (DARTS-style search space).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// 3x3 separable convolution.
    SepConv3x3,
    /// 5x5 separable convolution.
    SepConv5x5,
    /// 3x3 dilated convolution.
    DilConv3x3,
    /// 5x5 dilated convolution.
    DilConv5x5,
    /// 3x3 max pooling.
    MaxPool3x3,
    /// 3x3 average pooling.
    AvgPool3x3,
    /// Identity / skip connection.
    Skip,
    /// No connection (edge disabled).
    Zero,
}
34
35impl Operation {
36 pub fn all() -> Vec<Operation> {
38 vec![
39 Operation::SepConv3x3,
40 Operation::SepConv5x5,
41 Operation::DilConv3x3,
42 Operation::DilConv5x5,
43 Operation::MaxPool3x3,
44 Operation::AvgPool3x3,
45 Operation::Skip,
46 Operation::Zero,
47 ]
48 }
49
50 pub fn cost(&self) -> f32 {
52 match self {
53 Operation::SepConv3x3 => 9.0,
54 Operation::SepConv5x5 => 25.0,
55 Operation::DilConv3x3 => 9.0,
56 Operation::DilConv5x5 => 25.0,
57 Operation::MaxPool3x3 => 1.0,
58 Operation::AvgPool3x3 => 1.0,
59 Operation::Skip => 0.0,
60 Operation::Zero => 0.0,
61 }
62 }
63}
64
/// A searchable cell: a small DAG whose edges carry candidate operations.
#[derive(Debug, Clone)]
pub struct Cell {
    // Total node count; nodes 0 and 1 act as inputs, later nodes receive
    // edges from every predecessor.
    pub num_nodes: usize,
    // (from_node, to_node, operation) for every edge in the DAG.
    pub edges: Vec<(usize, usize, Operation)>,
    // Architecture parameters: per edge, one weight per candidate operation
    // (indexed in the same order as `Operation::all()`).
    pub alpha: HashMap<(usize, usize), Vec<f32>>,
}
75
76impl Cell {
77 pub fn random(num_nodes: usize) -> Self {
79 let mut rng = rand::thread_rng();
80 let mut edges = Vec::new();
81 let mut alpha = HashMap::new();
82
83 for to_node in 2..num_nodes {
85 for from_node in 0..to_node {
86 let ops = Operation::all();
88 let op = ops[rng.gen_range(0..ops.len())];
89 edges.push((from_node, to_node, op));
90
91 let num_ops = ops.len();
93 let weights: Vec<f32> = (0..num_ops)
94 .map(|_| rng.gen_range(-0.1..0.1))
95 .collect();
96 alpha.insert((from_node, to_node), weights);
97 }
98 }
99
100 Cell {
101 num_nodes,
102 edges,
103 alpha,
104 }
105 }
106
107 pub fn get_genotype(&self) -> Vec<(usize, usize, Operation)> {
109 let mut genotype = Vec::new();
110 let ops = Operation::all();
111
112 for ((from, to), weights) in &self.alpha {
113 let (max_idx, _) = weights.iter()
115 .enumerate()
116 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
117 .unwrap();
118
119 genotype.push((*from, *to, ops[max_idx]));
120 }
121
122 genotype
123 }
124
125 pub fn compute_cost(&self) -> f32 {
127 self.edges.iter().map(|(_, _, op)| op.cost()).sum()
128 }
129}
130
/// Differentiable architecture search state (Liu et al., DARTS-style):
/// one shared normal cell and one shared reduction cell.
pub struct DARTS {
    // Cell stacked at stride-1 positions of the network.
    pub normal_cell: Cell,
    // Cell stacked at downsampling positions.
    pub reduction_cell: Cell,
    // Number of cells in the final stacked network.
    pub num_cells: usize,
    // Learning rate for the architecture (alpha) parameters.
    pub arch_lr: f32,
    // Learning rate for the model weights.
    pub weight_lr: f32,
}
144
145impl DARTS {
146 pub fn new(num_nodes: usize, num_cells: usize) -> Self {
148 DARTS {
149 normal_cell: Cell::random(num_nodes),
150 reduction_cell: Cell::random(num_nodes),
151 num_cells,
152 arch_lr: 3e-4,
153 weight_lr: 0.025,
154 }
155 }
156
157 pub fn search_step(&mut self, train_loss: f32, val_loss: f32) {
159 for ((from, to), weights) in self.normal_cell.alpha.iter_mut() {
163 let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
165 let exp_sum: f32 = weights.iter().map(|w| (w - max_w).exp()).sum();
166
167 for (i, w) in weights.iter_mut().enumerate() {
169 let prob = (*w - max_w).exp() / exp_sum;
170 let grad = val_loss * (prob - if i == 0 { 1.0 } else { 0.0 });
172 *w -= self.arch_lr * grad;
173 }
174 }
175
176 for ((from, to), weights) in self.reduction_cell.alpha.iter_mut() {
178 let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
179 let exp_sum: f32 = weights.iter().map(|w| (w - max_w).exp()).sum();
180
181 for (i, w) in weights.iter_mut().enumerate() {
182 let prob = (*w - max_w).exp() / exp_sum;
183 let grad = val_loss * (prob - if i == 0 { 1.0 } else { 0.0 });
184 *w -= self.arch_lr * grad;
185 }
186 }
187 }
188
189 pub fn derive_architecture(&self) -> (Vec<(usize, usize, Operation)>, Vec<(usize, usize, Operation)>) {
191 (self.normal_cell.get_genotype(), self.reduction_cell.get_genotype())
192 }
193
194 pub fn total_cost(&self) -> f32 {
196 let normal_cost = self.normal_cell.compute_cost();
197 let reduction_cost = self.reduction_cell.compute_cost();
198
199 let num_reduction = (self.num_cells as f32 / 3.0).ceil() as usize;
201 let num_normal = self.num_cells - num_reduction;
202
203 normal_cost * num_normal as f32 + reduction_cost * num_reduction as f32
204 }
205}
206
/// Efficient NAS (ENAS-style) searcher with weight sharing across sampled
/// architectures.
pub struct ENAS {
    // One shared weight tensor per candidate operation, reused by every
    // sampled architecture.
    pub shared_weights: HashMap<Operation, Tensor>,
    // Controller hidden state (updated with a simple reward-following rule).
    pub controller_state: Vec<f32>,
    // Best sampled architectures so far, paired with their rewards.
    pub architecture_pool: Vec<(Cell, f32)>,
    // Architectures sampled per training step.
    pub num_samples: usize,
}
218
219impl ENAS {
220 pub fn new(num_samples: usize) -> Self {
222 let mut shared_weights = HashMap::new();
223
224 for op in Operation::all() {
226 let weight = Tensor::randn(&[64, 64]); shared_weights.insert(op, weight);
228 }
229
230 ENAS {
231 shared_weights,
232 controller_state: vec![0.0; 128], architecture_pool: Vec::new(),
234 num_samples,
235 }
236 }
237
238 pub fn sample_architecture(&mut self, num_nodes: usize) -> Cell {
240 let mut rng = rand::thread_rng();
241 let mut cell = Cell::random(num_nodes);
242
243 for ((from, to), weights) in cell.alpha.iter_mut() {
246 let temperature = 1.0;
248 let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
249 let exp_sum: f32 = weights.iter()
250 .map(|w| ((w - max_w) / temperature).exp())
251 .sum();
252
253 let sample: f32 = rng.gen();
255 let mut cumsum = 0.0;
256 for (i, w) in weights.iter().enumerate() {
257 let prob = ((w - max_w) / temperature).exp() / exp_sum;
258 cumsum += prob;
259 if sample < cumsum {
260 weights[i] = 1.0;
262 break;
263 }
264 }
265 }
266
267 cell
268 }
269
270 pub fn train_step(&mut self, num_nodes: usize) -> f32 {
272 let mut total_reward = 0.0;
273
274 for _ in 0..self.num_samples {
276 let arch = self.sample_architecture(num_nodes);
277
278 let reward = self.evaluate_architecture(&arch);
280 total_reward += reward;
281
282 self.architecture_pool.push((arch, reward));
284 }
285
286 self.architecture_pool.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
288 self.architecture_pool.truncate(100);
289
290 let avg_reward = total_reward / self.num_samples as f32;
292
293 for state in self.controller_state.iter_mut() {
295 *state += 0.01 * avg_reward;
296 }
297
298 avg_reward
299 }
300
301 fn evaluate_architecture(&self, arch: &Cell) -> f32 {
303 let cost = arch.compute_cost();
305 let lambda = 0.001; let num_skip = arch.edges.iter()
309 .filter(|(_, _, op)| *op == Operation::Skip)
310 .count();
311 let num_zero = arch.edges.iter()
312 .filter(|(_, _, op)| *op == Operation::Zero)
313 .count();
314
315 let base_reward = 0.8 + 0.1 * (num_skip as f32 / arch.edges.len() as f32);
317 let zero_penalty = 0.1 * (num_zero as f32 / arch.edges.len() as f32);
318
319 base_reward - zero_penalty - lambda * cost
320 }
321
322 pub fn best_architecture(&self) -> Option<&Cell> {
324 self.architecture_pool.first().map(|(arch, _)| arch)
325 }
326}
327
/// Progressive NAS: grows a candidate pool stage by stage, mutating the best
/// architectures from the previous stage under a complexity budget.
pub struct ProgressiveNAS {
    // Current stage index (0 before any search has run).
    pub stage: usize,
    // Candidate pool per stage; index 0 is an empty seed stage.
    pub stage_architectures: Vec<Vec<Cell>>,
    // Maximum allowed `Cell::compute_cost` for any candidate.
    pub complexity_budget: f32,
}
337
338impl ProgressiveNAS {
339 pub fn new(complexity_budget: f32) -> Self {
341 ProgressiveNAS {
342 stage: 0,
343 stage_architectures: vec![Vec::new()],
344 complexity_budget,
345 }
346 }
347
348 pub fn next_stage(&mut self, num_nodes: usize, num_candidates: usize) {
350 self.stage += 1;
351 let mut new_stage = Vec::new();
352
353 if self.stage == 1 {
354 for _ in 0..num_candidates {
356 let cell = Cell::random(num_nodes);
357 if cell.compute_cost() <= self.complexity_budget {
358 new_stage.push(cell);
359 }
360 }
361 } else {
362 let prev_stage = &self.stage_architectures[self.stage - 1];
364
365 for parent in prev_stage.iter().take(num_candidates / 2) {
366 for _ in 0..2 {
368 let mut child = parent.clone();
369 self.mutate_cell(&mut child);
370
371 if child.compute_cost() <= self.complexity_budget {
372 new_stage.push(child);
373 }
374 }
375 }
376 }
377
378 self.stage_architectures.push(new_stage);
379 }
380
381 fn mutate_cell(&self, cell: &mut Cell) {
383 let mut rng = rand::thread_rng();
384 let ops = Operation::all();
385
386 if !cell.edges.is_empty() {
388 let idx = rng.gen_range(0..cell.edges.len());
389 let new_op = ops[rng.gen_range(0..ops.len())];
390 cell.edges[idx].2 = new_op;
391 }
392 }
393
394 pub fn current_architectures(&self) -> &[Cell] {
396 &self.stage_architectures[self.stage]
397 }
398}
399
/// NAS with a per-hardware latency model and a hard latency constraint.
pub struct HardwareAwareNAS {
    // Maximum acceptable estimated latency for a candidate cell.
    pub target_latency: f32,
    // Hardware target name as passed to `new` ("mobile", "gpu", "tpu", ...).
    pub target_hardware: String,
    // Estimated per-operation latency on the target hardware.
    pub latency_table: HashMap<Operation, f32>,
}
409
410impl HardwareAwareNAS {
411 pub fn new(target_hardware: &str, target_latency: f32) -> Self {
413 let mut latency_table = HashMap::new();
414
415 match target_hardware {
417 "mobile" => {
418 latency_table.insert(Operation::SepConv3x3, 2.0);
419 latency_table.insert(Operation::SepConv5x5, 5.0);
420 latency_table.insert(Operation::DilConv3x3, 3.0);
421 latency_table.insert(Operation::DilConv5x5, 7.0);
422 latency_table.insert(Operation::MaxPool3x3, 0.5);
423 latency_table.insert(Operation::AvgPool3x3, 0.5);
424 latency_table.insert(Operation::Skip, 0.1);
425 latency_table.insert(Operation::Zero, 0.0);
426 }
427 "gpu" => {
428 latency_table.insert(Operation::SepConv3x3, 0.5);
429 latency_table.insert(Operation::SepConv5x5, 1.2);
430 latency_table.insert(Operation::DilConv3x3, 0.7);
431 latency_table.insert(Operation::DilConv5x5, 1.5);
432 latency_table.insert(Operation::MaxPool3x3, 0.1);
433 latency_table.insert(Operation::AvgPool3x3, 0.1);
434 latency_table.insert(Operation::Skip, 0.05);
435 latency_table.insert(Operation::Zero, 0.0);
436 }
437 "tpu" => {
438 latency_table.insert(Operation::SepConv3x3, 0.2);
439 latency_table.insert(Operation::SepConv5x5, 0.5);
440 latency_table.insert(Operation::DilConv3x3, 0.3);
441 latency_table.insert(Operation::DilConv5x5, 0.6);
442 latency_table.insert(Operation::MaxPool3x3, 0.05);
443 latency_table.insert(Operation::AvgPool3x3, 0.05);
444 latency_table.insert(Operation::Skip, 0.02);
445 latency_table.insert(Operation::Zero, 0.0);
446 }
447 _ => {
448 latency_table.insert(Operation::SepConv3x3, 2.0);
450 latency_table.insert(Operation::SepConv5x5, 5.0);
451 latency_table.insert(Operation::DilConv3x3, 3.0);
452 latency_table.insert(Operation::DilConv5x5, 7.0);
453 latency_table.insert(Operation::MaxPool3x3, 0.5);
454 latency_table.insert(Operation::AvgPool3x3, 0.5);
455 latency_table.insert(Operation::Skip, 0.1);
456 latency_table.insert(Operation::Zero, 0.0);
457 }
458 }
459
460 HardwareAwareNAS {
461 target_latency,
462 target_hardware: target_hardware.to_string(),
463 latency_table,
464 }
465 }
466
467 pub fn estimate_latency(&self, cell: &Cell) -> f32 {
469 cell.edges.iter()
470 .map(|(_, _, op)| self.latency_table.get(op).unwrap_or(&0.0))
471 .sum()
472 }
473
474 pub fn meets_constraint(&self, cell: &Cell) -> bool {
476 self.estimate_latency(cell) <= self.target_latency
477 }
478
479 pub fn search(&self, num_nodes: usize, num_iterations: usize) -> Option<Cell> {
481 let mut best_cell: Option<Cell> = None;
482 let mut best_score = f32::NEG_INFINITY;
483
484 for _ in 0..num_iterations {
485 let cell = Cell::random(num_nodes);
486
487 if self.meets_constraint(&cell) {
488 let score = -self.estimate_latency(&cell);
490
491 if score > best_score {
492 best_score = score;
493 best_cell = Some(cell);
494 }
495 }
496 }
497
498 best_cell
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cell_creation() {
        let cell = Cell::random(5);
        assert_eq!(cell.num_nodes, 5);
        assert!(!cell.edges.is_empty());
    }

    #[test]
    fn test_darts() {
        let mut darts = DARTS::new(4, 8);

        // Fix: `initial_cost` was bound but never used (dead-code warning);
        // cost is a sum of non-negative per-op costs, so assert on it.
        let initial_cost = darts.total_cost();
        assert!(initial_cost >= 0.0);

        darts.search_step(0.5, 0.6);

        let (normal, reduction) = darts.derive_architecture();
        assert!(!normal.is_empty());
        assert!(!reduction.is_empty());
    }

    #[test]
    fn test_enas() {
        let mut enas = ENAS::new(5);
        let reward = enas.train_step(4);

        assert!(!enas.architecture_pool.is_empty());
        assert!(reward.is_finite());
    }

    #[test]
    fn test_progressive_nas() {
        let mut pnas = ProgressiveNAS::new(100.0);
        pnas.next_stage(4, 10);

        assert_eq!(pnas.stage, 1);
        assert!(!pnas.current_architectures().is_empty());
    }

    #[test]
    fn test_hardware_aware_nas() {
        let hwnas = HardwareAwareNAS::new("mobile", 50.0);
        let cell = Cell::random(4);

        let latency = hwnas.estimate_latency(&cell);
        assert!(latency >= 0.0);

        if let Some(arch) = hwnas.search(4, 100) {
            assert!(hwnas.meets_constraint(&arch));
        }
    }
}