Skip to main content

scirs2_neural/nas/
controller.rs

1//! NAS controller for building and managing architectures
2
3use crate::error::{NeuralError, Result};
4use crate::models::sequential::Sequential;
5use crate::nas::{
6    architecture_encoding::ArchitectureEncoding,
7    search_space::{Architecture, LayerType, SearchSpaceConfig},
8};
9use std::sync::Arc;
10
/// Settings that control how a [`NASController`] turns an architecture into a model.
#[derive(Debug, Clone)]
pub struct ControllerConfig {
    /// Shape of one input sample (presumably without the batch dimension — TODO confirm);
    /// the controller uses it as the starting shape when sizing the first layers.
    pub input_shape: Vec<usize>,
    /// Output width of the classification head.
    pub num_classes: usize,
    /// When `true`, a final softmax classification layer is appended to built models.
    pub add_softmax: bool,
    /// Optional global seed for reproducibility (`None` means no fixed seed is requested).
    pub seed: Option<u64>,
    /// Target device name (e.g. `"cpu"`, `"cuda"`); not consulted by the controller
    /// methods in this module.
    pub device: String,
}

impl Default for ControllerConfig {
    /// Defaults: input shape `[32, 32, 3]`, 10 classes, softmax head enabled,
    /// no fixed seed, device `"cpu"`.
    fn default() -> Self {
        let device = String::from("cpu");
        ControllerConfig {
            device,
            seed: None,
            add_softmax: true,
            num_classes: 10,
            input_shape: Vec::from([32, 32, 3]),
        }
    }
}
37
38/// NAS Controller for building models from architecture encodings
39pub struct NASController {
40    pub config: ControllerConfig,
41    pub search_space: SearchSpaceConfig,
42}
43
44impl NASController {
45    /// Create a new NAS controller
46    pub fn new(search_space: SearchSpaceConfig) -> Result<Self> {
47        Ok(Self {
48            config: ControllerConfig::default(),
49            search_space,
50        })
51    }
52
53    /// Create with custom configuration
54    pub fn with_config(search_space: SearchSpaceConfig, config: ControllerConfig) -> Result<Self> {
55        Ok(Self {
56            config,
57            search_space,
58        })
59    }
60
61    /// Build a model from an architecture encoding
62    pub fn build_model(&self, encoding: &Arc<dyn ArchitectureEncoding>) -> Result<Sequential<f32>> {
63        let architecture = encoding.to_architecture()?;
64        self.build_from_architecture(&architecture)
65    }
66
67    /// Build a model from an Architecture struct
68    pub fn build_from_architecture(&self, architecture: &Architecture) -> Result<Sequential<f32>> {
69        use scirs2_core::random::{rngs::SmallRng, SeedableRng};
70        let seed = scirs2_core::random::random::<u64>();
71        let mut rng_inst = SmallRng::seed_from_u64(seed);
72        let mut model = Sequential::new();
73        let mut current_shape = self.config.input_shape.clone();
74        let effective_layers = self.apply_multipliers(
75            &architecture.layers,
76            architecture.width_multiplier,
77            architecture.depth_multiplier,
78        )?;
79        for layer_type in effective_layers.iter() {
80            match layer_type {
81                LayerType::Dense(units) => {
82                    let input_size = current_shape.iter().product();
83                    model.add_layer(crate::layers::Dense::new(
84                        input_size,
85                        *units,
86                        None,
87                        &mut rng_inst,
88                    )?);
89                    current_shape = vec![*units];
90                }
91                LayerType::Dropout(rate) => {
92                    model.add_layer(crate::layers::Dropout::new(*rate as f64, &mut rng_inst)?);
93                }
94                LayerType::BatchNorm => {
95                    let features = current_shape.last().copied().unwrap_or(1);
96                    model.add_layer(crate::layers::BatchNorm::new(
97                        features,
98                        0.9,
99                        1e-5,
100                        &mut rng_inst,
101                    )?);
102                }
103                LayerType::Activation(name) => {
104                    let input_size = current_shape.iter().product();
105                    model.add_layer(crate::layers::Dense::new(
106                        input_size,
107                        input_size,
108                        Some(name.as_str()),
109                        &mut rng_inst,
110                    )?);
111                }
112                LayerType::Flatten => {
113                    let input_size: usize = current_shape.iter().product();
114                    model.add_layer(crate::layers::Dense::new(
115                        input_size,
116                        input_size,
117                        None,
118                        &mut rng_inst,
119                    )?);
120                    current_shape = vec![input_size];
121                }
122                _ => {
123                    // Skip unsupported layer types in simplified builder
124                    continue;
125                }
126            }
127        }
128        if self.config.add_softmax {
129            let input_size = current_shape.iter().product();
130            model.add_layer(crate::layers::Dense::new(
131                input_size,
132                self.config.num_classes,
133                Some("softmax"),
134                &mut rng_inst,
135            )?);
136        }
137        Ok(model)
138    }
139
140    /// Apply width and depth multipliers to layers
141    pub fn apply_multipliers(
142        &self,
143        layers: &[LayerType],
144        width_mult: f32,
145        depth_mult: f32,
146    ) -> Result<Vec<LayerType>> {
147        let mut result = Vec::new();
148        for layer in layers {
149            let repetitions = (depth_mult.max(0.1) as usize).max(1);
150            for _ in 0..repetitions {
151                let modified_layer = match layer {
152                    LayerType::Dense(units) => {
153                        LayerType::Dense((*units as f32 * width_mult).round() as usize)
154                    }
155                    LayerType::Conv2D {
156                        filters,
157                        kernel_size,
158                        stride,
159                    } => LayerType::Conv2D {
160                        filters: (*filters as f32 * width_mult).round() as usize,
161                        kernel_size: *kernel_size,
162                        stride: *stride,
163                    },
164                    LayerType::Conv1D {
165                        filters,
166                        kernel_size,
167                        stride,
168                    } => LayerType::Conv1D {
169                        filters: (*filters as f32 * width_mult).round() as usize,
170                        kernel_size: *kernel_size,
171                        stride: *stride,
172                    },
173                    LayerType::LSTM {
174                        units,
175                        return_sequences,
176                    } => LayerType::LSTM {
177                        units: (*units as f32 * width_mult).round() as usize,
178                        return_sequences: *return_sequences,
179                    },
180                    LayerType::GRU {
181                        units,
182                        return_sequences,
183                    } => LayerType::GRU {
184                        units: (*units as f32 * width_mult).round() as usize,
185                        return_sequences: *return_sequences,
186                    },
187                    LayerType::Attention { num_heads, key_dim } => LayerType::Attention {
188                        num_heads: *num_heads,
189                        key_dim: (*key_dim as f32 * width_mult).round() as usize,
190                    },
191                    other => other.clone(),
192                };
193                result.push(modified_layer);
194            }
195        }
196        Ok(result)
197    }
198
199    /// Count parameters in a model
200    pub fn count_parameters(&self, _model: &Sequential<f32>) -> Result<usize> {
201        // Simplified parameter counting
202        Ok(1_000_000)
203    }
204
205    /// Estimate FLOPs for a model
206    pub fn estimate_flops(
207        &self,
208        _model: &Sequential<f32>,
209        _input_shape: &[usize],
210    ) -> Result<usize> {
211        // Simplified FLOPs estimation
212        Ok(1_000_000)
213    }
214
215    /// Compute output shape after a layer
216    pub fn compute_output_shape(
217        &self,
218        layer_type: &LayerType,
219        input_shape: &[usize],
220    ) -> Result<Vec<usize>> {
221        match layer_type {
222            LayerType::Dense(units) => Ok(vec![*units]),
223            LayerType::Conv2D {
224                filters,
225                kernel_size,
226                stride,
227            } => {
228                if input_shape.len() < 2 {
229                    return Err(NeuralError::InvalidArgument(
230                        "Conv2D requires at least 2D input (H, W)".to_string(),
231                    ));
232                }
233                let h = (input_shape[0].saturating_sub(kernel_size.0)) / stride.0 + 1;
234                let w = (input_shape[1].saturating_sub(kernel_size.1)) / stride.1 + 1;
235                Ok(vec![h, w, *filters])
236            }
237            LayerType::MaxPool2D { pool_size, stride }
238            | LayerType::AvgPool2D { pool_size, stride } => {
239                if input_shape.len() < 2 {
240                    return Ok(input_shape.to_vec());
241                }
242                let h = (input_shape[0].saturating_sub(pool_size.0)) / stride.0 + 1;
243                let w = (input_shape[1].saturating_sub(pool_size.1)) / stride.1 + 1;
244                let channels = input_shape.get(2).copied().unwrap_or(1);
245                Ok(vec![h, w, channels])
246            }
247            LayerType::GlobalMaxPool2D | LayerType::GlobalAvgPool2D => {
248                let channels = input_shape.last().copied().unwrap_or(1);
249                Ok(vec![channels])
250            }
251            LayerType::Flatten => {
252                let total_size: usize = input_shape.iter().product();
253                Ok(vec![total_size])
254            }
255            LayerType::Dropout(_)
256            | LayerType::BatchNorm
257            | LayerType::LayerNorm
258            | LayerType::Activation(_)
259            | LayerType::Residual => Ok(input_shape.to_vec()),
260            LayerType::LSTM {
261                units,
262                return_sequences,
263            }
264            | LayerType::GRU {
265                units,
266                return_sequences,
267            } => {
268                if *return_sequences {
269                    if input_shape.is_empty() {
270                        Ok(vec![*units])
271                    } else {
272                        Ok(vec![input_shape[0], *units])
273                    }
274                } else {
275                    Ok(vec![*units])
276                }
277            }
278            LayerType::Attention { key_dim, .. } => {
279                if input_shape.is_empty() {
280                    Ok(vec![*key_dim])
281                } else {
282                    let mut output_shape = input_shape.to_vec();
283                    if let Some(last) = output_shape.last_mut() {
284                        *last = *key_dim;
285                    }
286                    Ok(output_shape)
287                }
288            }
289            LayerType::Embedding { embedding_dim, .. } => Ok(vec![*embedding_dim]),
290            _ => Ok(input_shape.to_vec()),
291        }
292    }
293
294    /// Validate an architecture
295    pub fn validate_architecture(&self, architecture: &Architecture) -> Result<()> {
296        if architecture.layers.is_empty() {
297            return Err(NeuralError::InvalidArgument(
298                "Architecture must have at least one layer".to_string(),
299            ));
300        }
301        for (from, to) in &architecture.connections {
302            if *from >= architecture.layers.len() || *to >= architecture.layers.len() {
303                return Err(NeuralError::InvalidArgument(format!(
304                    "Invalid skip connection: {} -> {}",
305                    from, to
306                )));
307            }
308            if from >= to {
309                return Err(NeuralError::InvalidArgument(
310                    "Skip connections must be forward connections".to_string(),
311                ));
312            }
313        }
314        if architecture.width_multiplier <= 0.0 || architecture.depth_multiplier <= 0.0 {
315            return Err(NeuralError::InvalidArgument(
316                "Multipliers must be positive".to_string(),
317            ));
318        }
319        Ok(())
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::Architecture;

    /// Controller with the default search space and default configuration.
    fn controller() -> NASController {
        NASController::new(SearchSpaceConfig::default()).expect("failed to create controller")
    }

    /// Architecture with unit multipliers built from `layers` and `connections`.
    fn arch(layers: Vec<LayerType>, connections: Vec<(usize, usize)>) -> Architecture {
        Architecture {
            layers,
            connections,
            width_multiplier: 1.0,
            depth_multiplier: 1.0,
        }
    }

    #[test]
    fn test_controller_creation() {
        let controller = controller();
        assert_eq!(controller.config.num_classes, 10);
        assert_eq!(controller.config.input_shape, vec![32, 32, 3]);
        assert!(controller.config.add_softmax);
    }

    #[test]
    fn test_architecture_validation() {
        let controller = controller();

        let valid_arch = arch(
            vec![
                LayerType::Dense(128),
                LayerType::Activation("relu".to_string()),
                LayerType::Dense(10),
            ],
            vec![],
        );
        assert!(controller.validate_architecture(&valid_arch).is_ok());

        let empty_arch = arch(vec![], vec![]);
        assert!(controller.validate_architecture(&empty_arch).is_err());

        let backward_skip = arch(vec![LayerType::Dense(128), LayerType::Dense(10)], vec![(1, 0)]);
        assert!(controller.validate_architecture(&backward_skip).is_err());

        let forward_skip = arch(vec![LayerType::Dense(128), LayerType::Dense(10)], vec![(0, 1)]);
        assert!(controller.validate_architecture(&forward_skip).is_ok());

        let out_of_range = arch(vec![LayerType::Dense(128), LayerType::Dense(10)], vec![(0, 5)]);
        assert!(controller.validate_architecture(&out_of_range).is_err());

        let mut zero_width = arch(vec![LayerType::Dense(8)], vec![]);
        zero_width.width_multiplier = 0.0;
        assert!(controller.validate_architecture(&zero_width).is_err());
    }

    #[test]
    fn test_multiplier_application() {
        let controller = controller();
        let layers = vec![
            LayerType::Dense(100),
            LayerType::Conv2D {
                filters: 32,
                kernel_size: (3, 3),
                stride: (1, 1),
            },
        ];
        let modified = controller
            .apply_multipliers(&layers, 2.0, 1.0)
            .expect("failed to apply multipliers");
        assert_eq!(modified.len(), 2);
        match &modified[0] {
            LayerType::Dense(units) => assert_eq!(*units, 200),
            _ => panic!("Expected Dense layer"),
        }
        match &modified[1] {
            LayerType::Conv2D {
                filters,
                kernel_size,
                stride,
            } => {
                assert_eq!(*filters, 64);
                assert_eq!(*kernel_size, (3, 3));
                assert_eq!(*stride, (1, 1));
            }
            _ => panic!("Expected Conv2D layer"),
        }
    }

    #[test]
    fn test_depth_multiplier_repeats_layers() {
        let controller = controller();
        let layers = vec![LayerType::Dense(8)];

        let doubled = controller
            .apply_multipliers(&layers, 1.0, 2.0)
            .expect("failed to apply multipliers");
        assert_eq!(doubled.len(), 2);

        let shallow = controller
            .apply_multipliers(&layers, 1.0, 0.5)
            .expect("failed to apply multipliers");
        assert_eq!(shallow.len(), 1);
    }

    #[test]
    fn test_output_shapes() {
        let controller = controller();

        let conv = LayerType::Conv2D {
            filters: 16,
            kernel_size: (3, 3),
            stride: (1, 1),
        };
        let shape = controller
            .compute_output_shape(&conv, &[32, 32, 3])
            .expect("conv shape");
        assert_eq!(shape, vec![30, 30, 16]);

        let pool = LayerType::MaxPool2D {
            pool_size: (2, 2),
            stride: (2, 2),
        };
        let shape = controller
            .compute_output_shape(&pool, &[30, 30, 16])
            .expect("pool shape");
        assert_eq!(shape, vec![15, 15, 16]);

        let shape = controller
            .compute_output_shape(&LayerType::Flatten, &[15, 15, 16])
            .expect("flatten shape");
        assert_eq!(shape, vec![3600]);

        let shape = controller
            .compute_output_shape(&LayerType::GlobalAvgPool2D, &[7, 7, 64])
            .expect("global pool shape");
        assert_eq!(shape, vec![64]);

        assert!(controller.compute_output_shape(&conv, &[32]).is_err());
    }
}