// ghostflow_ml/nas.rs

//! Neural Architecture Search (NAS)
//!
//! Implements automated neural network architecture discovery:
//! - DARTS (Differentiable Architecture Search)
//! - ENAS (Efficient Neural Architecture Search)
//! - NASNet search space
//! - Progressive Neural Architecture Search
//! - Hardware-aware NAS
10use ghostflow_core::Tensor;
11use std::collections::HashMap;
12use rand::Rng;
13
/// Neural network operation types
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// 3x3 separable convolution
    SepConv3x3,
    /// 5x5 separable convolution
    SepConv5x5,
    /// 3x3 dilated convolution
    DilConv3x3,
    /// 5x5 dilated convolution
    DilConv5x5,
    /// 3x3 max pooling
    MaxPool3x3,
    /// 3x3 average pooling
    AvgPool3x3,
    /// Skip connection
    Skip,
    /// Zero operation (no connection)
    Zero,
}

impl Operation {
    /// Get all available operations, in the canonical search-space order
    /// (the same order used to index per-edge alpha weight vectors).
    pub fn all() -> Vec<Operation> {
        use Operation::*;
        [
            SepConv3x3, SepConv5x5, DilConv3x3, DilConv5x5,
            MaxPool3x3, AvgPool3x3, Skip, Zero,
        ]
        .to_vec()
    }

    /// Get operation cost (rough FLOPs estimate, proportional to kernel area;
    /// pooling is cheap, skip/zero are free).
    pub fn cost(&self) -> f32 {
        use Operation::*;
        match self {
            SepConv3x3 | DilConv3x3 => 9.0,
            SepConv5x5 | DilConv5x5 => 25.0,
            MaxPool3x3 | AvgPool3x3 => 1.0,
            Skip | Zero => 0.0,
        }
    }
}
64
/// Architecture cell (building block).
///
/// A cell is a small DAG: `Cell::random` treats nodes 0 and 1 as inputs and
/// connects every node >= 2 to each earlier node with one operation.
#[derive(Debug, Clone)]
pub struct Cell {
    /// Number of nodes in the cell
    pub num_nodes: usize,
    /// Operations between nodes: (from_node, to_node, operation)
    pub edges: Vec<(usize, usize, Operation)>,
    /// Architecture parameters (for DARTS): per-edge weight vector with one
    /// entry per operation, indexed in the same order as `Operation::all()`
    pub alpha: HashMap<(usize, usize), Vec<f32>>,
}
75
76impl Cell {
77    /// Create a new cell with random architecture
78    pub fn random(num_nodes: usize) -> Self {
79        let mut rng = rand::thread_rng();
80        let mut edges = Vec::new();
81        let mut alpha = HashMap::new();
82        
83        // Connect each node to previous nodes
84        for to_node in 2..num_nodes {
85            for from_node in 0..to_node {
86                // Random operation
87                let ops = Operation::all();
88                let op = ops[rng.gen_range(0..ops.len())];
89                edges.push((from_node, to_node, op));
90                
91                // Initialize architecture parameters
92                let num_ops = ops.len();
93                let weights: Vec<f32> = (0..num_ops)
94                    .map(|_| rng.gen_range(-0.1..0.1))
95                    .collect();
96                alpha.insert((from_node, to_node), weights);
97            }
98        }
99        
100        Cell {
101            num_nodes,
102            edges,
103            alpha,
104        }
105    }
106    
107    /// Get the dominant operation for each edge (for discretization)
108    pub fn get_genotype(&self) -> Vec<(usize, usize, Operation)> {
109        let mut genotype = Vec::new();
110        let ops = Operation::all();
111        
112        for ((from, to), weights) in &self.alpha {
113            // Find operation with highest weight
114            let (max_idx, _) = weights.iter()
115                .enumerate()
116                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
117                .unwrap();
118            
119            genotype.push((*from, *to, ops[max_idx]));
120        }
121        
122        genotype
123    }
124    
125    /// Compute cell cost (total FLOPs)
126    pub fn compute_cost(&self) -> f32 {
127        self.edges.iter().map(|(_, _, op)| op.cost()).sum()
128    }
129}
130
/// DARTS (Differentiable Architecture Search).
///
/// Holds a continuous relaxation of the architecture: each cell carries
/// per-edge alpha weights that `search_step` updates and
/// `derive_architecture` discretizes.
pub struct DARTS {
    /// Normal cell (for feature extraction)
    pub normal_cell: Cell,
    /// Reduction cell (for downsampling)
    pub reduction_cell: Cell,
    /// Number of cells in the network
    pub num_cells: usize,
    /// Learning rate for architecture parameters
    pub arch_lr: f32,
    /// Learning rate for network weights (not read by the search code in
    /// this file; presumably consumed by an external training loop)
    pub weight_lr: f32,
}
144
145impl DARTS {
146    /// Create a new DARTS search
147    pub fn new(num_nodes: usize, num_cells: usize) -> Self {
148        DARTS {
149            normal_cell: Cell::random(num_nodes),
150            reduction_cell: Cell::random(num_nodes),
151            num_cells,
152            arch_lr: 3e-4,
153            weight_lr: 0.025,
154        }
155    }
156    
157    /// Perform one step of architecture search
158    pub fn search_step(&mut self, train_loss: f32, val_loss: f32) {
159        // Update architecture parameters based on validation loss
160        // Gradient: ∇α L_val
161        
162        for ((from, to), weights) in self.normal_cell.alpha.iter_mut() {
163            // Compute softmax of architecture weights
164            let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
165            let exp_sum: f32 = weights.iter().map(|w| (w - max_w).exp()).sum();
166            
167            // Update each operation weight
168            for (i, w) in weights.iter_mut().enumerate() {
169                let prob = (*w - max_w).exp() / exp_sum;
170                // Gradient approximation
171                let grad = val_loss * (prob - if i == 0 { 1.0 } else { 0.0 });
172                *w -= self.arch_lr * grad;
173            }
174        }
175        
176        // Same for reduction cell
177        for ((from, to), weights) in self.reduction_cell.alpha.iter_mut() {
178            let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
179            let exp_sum: f32 = weights.iter().map(|w| (w - max_w).exp()).sum();
180            
181            for (i, w) in weights.iter_mut().enumerate() {
182                let prob = (*w - max_w).exp() / exp_sum;
183                let grad = val_loss * (prob - if i == 0 { 1.0 } else { 0.0 });
184                *w -= self.arch_lr * grad;
185            }
186        }
187    }
188    
189    /// Discretize the continuous architecture
190    pub fn derive_architecture(&self) -> (Vec<(usize, usize, Operation)>, Vec<(usize, usize, Operation)>) {
191        (self.normal_cell.get_genotype(), self.reduction_cell.get_genotype())
192    }
193    
194    /// Compute total network cost
195    pub fn total_cost(&self) -> f32 {
196        let normal_cost = self.normal_cell.compute_cost();
197        let reduction_cost = self.reduction_cell.compute_cost();
198        
199        // Approximate: most cells are normal, few are reduction
200        let num_reduction = (self.num_cells as f32 / 3.0).ceil() as usize;
201        let num_normal = self.num_cells - num_reduction;
202        
203        normal_cost * num_normal as f32 + reduction_cost * num_reduction as f32
204    }
205}
206
/// ENAS (Efficient Neural Architecture Search).
///
/// Child architectures share one weight tensor per operation type; a
/// simplified controller state biases sampling and is nudged by rewards.
pub struct ENAS {
    /// Shared weights for all operations (one tensor per operation type)
    pub shared_weights: HashMap<Operation, Tensor>,
    /// Controller RNN state (simplified stand-in for an LSTM hidden state)
    pub controller_state: Vec<f32>,
    /// Sampled architectures and their rewards; `train_step` keeps this
    /// sorted best-first and capped at 100 entries
    pub architecture_pool: Vec<(Cell, f32)>,
    /// Number of architectures to sample per iteration
    pub num_samples: usize,
}
218
219impl ENAS {
220    /// Create a new ENAS search
221    pub fn new(num_samples: usize) -> Self {
222        let mut shared_weights = HashMap::new();
223        
224        // Initialize shared weights for each operation type
225        for op in Operation::all() {
226            let weight = Tensor::randn(&[64, 64]); // Example dimensions
227            shared_weights.insert(op, weight);
228        }
229        
230        ENAS {
231            shared_weights,
232            controller_state: vec![0.0; 128], // LSTM hidden state
233            architecture_pool: Vec::new(),
234            num_samples,
235        }
236    }
237    
238    /// Sample an architecture using the controller
239    pub fn sample_architecture(&mut self, num_nodes: usize) -> Cell {
240        let mut rng = rand::thread_rng();
241        let mut cell = Cell::random(num_nodes);
242        
243        // Use controller to bias sampling (simplified)
244        // In full implementation, this would use an RNN controller
245        for ((from, to), weights) in cell.alpha.iter_mut() {
246            // Softmax with temperature
247            let temperature = 1.0;
248            let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
249            let exp_sum: f32 = weights.iter()
250                .map(|w| ((w - max_w) / temperature).exp())
251                .sum();
252            
253            // Sample operation based on probabilities
254            let sample: f32 = rng.gen();
255            let mut cumsum = 0.0;
256            for (i, w) in weights.iter().enumerate() {
257                let prob = ((w - max_w) / temperature).exp() / exp_sum;
258                cumsum += prob;
259                if sample < cumsum {
260                    // Set this operation to have highest weight
261                    weights[i] = 1.0;
262                    break;
263                }
264            }
265        }
266        
267        cell
268    }
269    
270    /// Train sampled architectures and update controller
271    pub fn train_step(&mut self, num_nodes: usize) -> f32 {
272        let mut total_reward = 0.0;
273        
274        // Sample architectures
275        for _ in 0..self.num_samples {
276            let arch = self.sample_architecture(num_nodes);
277            
278            // Evaluate architecture (simplified - would train child network)
279            let reward = self.evaluate_architecture(&arch);
280            total_reward += reward;
281            
282            // Store in pool
283            self.architecture_pool.push((arch, reward));
284        }
285        
286        // Keep only top architectures
287        self.architecture_pool.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
288        self.architecture_pool.truncate(100);
289        
290        // Update controller based on rewards (REINFORCE)
291        let avg_reward = total_reward / self.num_samples as f32;
292        
293        // Update controller state (simplified)
294        for state in self.controller_state.iter_mut() {
295            *state += 0.01 * avg_reward;
296        }
297        
298        avg_reward
299    }
300    
301    /// Evaluate an architecture
302    fn evaluate_architecture(&self, arch: &Cell) -> f32 {
303        // Reward = accuracy - λ * cost
304        let cost = arch.compute_cost();
305        let lambda = 0.001; // Cost penalty
306        
307        // Simplified: reward based on architecture properties
308        let num_skip = arch.edges.iter()
309            .filter(|(_, _, op)| *op == Operation::Skip)
310            .count();
311        let num_zero = arch.edges.iter()
312            .filter(|(_, _, op)| *op == Operation::Zero)
313            .count();
314        
315        // Prefer architectures with some skip connections but not too many zeros
316        let base_reward = 0.8 + 0.1 * (num_skip as f32 / arch.edges.len() as f32);
317        let zero_penalty = 0.1 * (num_zero as f32 / arch.edges.len() as f32);
318        
319        base_reward - zero_penalty - lambda * cost
320    }
321    
322    /// Get best architecture found so far
323    pub fn best_architecture(&self) -> Option<&Cell> {
324        self.architecture_pool.first().map(|(arch, _)| arch)
325    }
326}
327
/// Progressive Neural Architecture Search.
///
/// Grows architectures stage by stage: stage 1 samples random cells, later
/// stages mutate candidates from the previous stage, and every candidate
/// must fit within `complexity_budget`.
pub struct ProgressiveNAS {
    /// Current search stage (0 = no search performed yet)
    pub stage: usize,
    /// Architectures at each stage; index i holds stage i's candidates
    pub stage_architectures: Vec<Vec<Cell>>,
    /// Complexity budget (maximum allowed `Cell::compute_cost`)
    pub complexity_budget: f32,
}
337
338impl ProgressiveNAS {
339    /// Create a new progressive NAS
340    pub fn new(complexity_budget: f32) -> Self {
341        ProgressiveNAS {
342            stage: 0,
343            stage_architectures: vec![Vec::new()],
344            complexity_budget,
345        }
346    }
347    
348    /// Progress to next stage
349    pub fn next_stage(&mut self, num_nodes: usize, num_candidates: usize) {
350        self.stage += 1;
351        let mut new_stage = Vec::new();
352        
353        if self.stage == 1 {
354            // First stage: generate random architectures
355            for _ in 0..num_candidates {
356                let cell = Cell::random(num_nodes);
357                if cell.compute_cost() <= self.complexity_budget {
358                    new_stage.push(cell);
359                }
360            }
361        } else {
362            // Later stages: mutate best from previous stage
363            let prev_stage = &self.stage_architectures[self.stage - 1];
364            
365            for parent in prev_stage.iter().take(num_candidates / 2) {
366                // Create mutations
367                for _ in 0..2 {
368                    let mut child = parent.clone();
369                    self.mutate_cell(&mut child);
370                    
371                    if child.compute_cost() <= self.complexity_budget {
372                        new_stage.push(child);
373                    }
374                }
375            }
376        }
377        
378        self.stage_architectures.push(new_stage);
379    }
380    
381    /// Mutate a cell
382    fn mutate_cell(&self, cell: &mut Cell) {
383        let mut rng = rand::thread_rng();
384        let ops = Operation::all();
385        
386        // Randomly change one operation
387        if !cell.edges.is_empty() {
388            let idx = rng.gen_range(0..cell.edges.len());
389            let new_op = ops[rng.gen_range(0..ops.len())];
390            cell.edges[idx].2 = new_op;
391        }
392    }
393    
394    /// Get current stage architectures
395    pub fn current_architectures(&self) -> &[Cell] {
396        &self.stage_architectures[self.stage]
397    }
398}
399
/// Hardware-aware NAS.
///
/// Scores architectures by estimated latency on a target device and
/// searches for cells fitting within `target_latency`.
pub struct HardwareAwareNAS {
    /// Target hardware latency budget (ms)
    pub target_latency: f32,
    /// Target hardware (e.g., "mobile", "gpu", "tpu")
    pub target_hardware: String,
    /// Latency lookup table: estimated ms per operation on the target
    pub latency_table: HashMap<Operation, f32>,
}
409
410impl HardwareAwareNAS {
411    /// Create a new hardware-aware NAS
412    pub fn new(target_hardware: &str, target_latency: f32) -> Self {
413        let mut latency_table = HashMap::new();
414        
415        // Latency estimates for different hardware (ms per operation)
416        match target_hardware {
417            "mobile" => {
418                latency_table.insert(Operation::SepConv3x3, 2.0);
419                latency_table.insert(Operation::SepConv5x5, 5.0);
420                latency_table.insert(Operation::DilConv3x3, 3.0);
421                latency_table.insert(Operation::DilConv5x5, 7.0);
422                latency_table.insert(Operation::MaxPool3x3, 0.5);
423                latency_table.insert(Operation::AvgPool3x3, 0.5);
424                latency_table.insert(Operation::Skip, 0.1);
425                latency_table.insert(Operation::Zero, 0.0);
426            }
427            "gpu" => {
428                latency_table.insert(Operation::SepConv3x3, 0.5);
429                latency_table.insert(Operation::SepConv5x5, 1.2);
430                latency_table.insert(Operation::DilConv3x3, 0.7);
431                latency_table.insert(Operation::DilConv5x5, 1.5);
432                latency_table.insert(Operation::MaxPool3x3, 0.1);
433                latency_table.insert(Operation::AvgPool3x3, 0.1);
434                latency_table.insert(Operation::Skip, 0.05);
435                latency_table.insert(Operation::Zero, 0.0);
436            }
437            "tpu" => {
438                latency_table.insert(Operation::SepConv3x3, 0.2);
439                latency_table.insert(Operation::SepConv5x5, 0.5);
440                latency_table.insert(Operation::DilConv3x3, 0.3);
441                latency_table.insert(Operation::DilConv5x5, 0.6);
442                latency_table.insert(Operation::MaxPool3x3, 0.05);
443                latency_table.insert(Operation::AvgPool3x3, 0.05);
444                latency_table.insert(Operation::Skip, 0.02);
445                latency_table.insert(Operation::Zero, 0.0);
446            }
447            _ => {
448                // Default to mobile
449                latency_table.insert(Operation::SepConv3x3, 2.0);
450                latency_table.insert(Operation::SepConv5x5, 5.0);
451                latency_table.insert(Operation::DilConv3x3, 3.0);
452                latency_table.insert(Operation::DilConv5x5, 7.0);
453                latency_table.insert(Operation::MaxPool3x3, 0.5);
454                latency_table.insert(Operation::AvgPool3x3, 0.5);
455                latency_table.insert(Operation::Skip, 0.1);
456                latency_table.insert(Operation::Zero, 0.0);
457            }
458        }
459        
460        HardwareAwareNAS {
461            target_latency,
462            target_hardware: target_hardware.to_string(),
463            latency_table,
464        }
465    }
466    
467    /// Estimate latency for a cell
468    pub fn estimate_latency(&self, cell: &Cell) -> f32 {
469        cell.edges.iter()
470            .map(|(_, _, op)| self.latency_table.get(op).unwrap_or(&0.0))
471            .sum()
472    }
473    
474    /// Check if architecture meets latency constraint
475    pub fn meets_constraint(&self, cell: &Cell) -> bool {
476        self.estimate_latency(cell) <= self.target_latency
477    }
478    
479    /// Search for architecture meeting latency constraint
480    pub fn search(&self, num_nodes: usize, num_iterations: usize) -> Option<Cell> {
481        let mut best_cell: Option<Cell> = None;
482        let mut best_score = f32::NEG_INFINITY;
483        
484        for _ in 0..num_iterations {
485            let cell = Cell::random(num_nodes);
486            
487            if self.meets_constraint(&cell) {
488                // Score = -latency (prefer faster architectures)
489                let score = -self.estimate_latency(&cell);
490                
491                if score > best_score {
492                    best_score = score;
493                    best_cell = Some(cell);
494                }
495            }
496        }
497        
498        best_cell
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cell_creation() {
        let cell = Cell::random(5);
        assert_eq!(cell.num_nodes, 5);
        assert!(!cell.edges.is_empty());
    }

    #[test]
    fn test_darts() {
        let mut darts = DARTS::new(4, 8);

        // Cost of the initial random architecture is a non-negative sum of
        // operation costs. (Previously bound but never used — warning.)
        let initial_cost = darts.total_cost();
        assert!(initial_cost >= 0.0);

        // Perform search step
        darts.search_step(0.5, 0.6);

        // Architecture should be derivable after the update
        let (normal, reduction) = darts.derive_architecture();
        assert!(!normal.is_empty());
        assert!(!reduction.is_empty());
    }

    #[test]
    fn test_enas() {
        let mut enas = ENAS::new(5);
        let reward = enas.train_step(4);

        // Should have sampled architectures
        assert!(!enas.architecture_pool.is_empty());
        assert!(reward.is_finite());
    }

    #[test]
    fn test_progressive_nas() {
        let mut pnas = ProgressiveNAS::new(100.0);
        pnas.next_stage(4, 10);

        assert_eq!(pnas.stage, 1);
        assert!(!pnas.current_architectures().is_empty());
    }

    #[test]
    fn test_hardware_aware_nas() {
        let hwnas = HardwareAwareNAS::new("mobile", 50.0);
        let cell = Cell::random(4);

        let latency = hwnas.estimate_latency(&cell);
        assert!(latency >= 0.0);

        // Search for architecture; any hit must satisfy the constraint
        if let Some(arch) = hwnas.search(4, 100) {
            assert!(hwnas.meets_constraint(&arch));
        }
    }
}