1use crate::error::{NeuralError, Result};
4use crate::models::sequential::Sequential;
5use crate::nas::{
6 architecture_encoding::ArchitectureEncoding,
7 search_space::{Architecture, LayerType, SearchSpaceConfig},
8};
9use std::sync::Arc;
10
/// Configuration for [`NASController`] model building.
#[derive(Debug, Clone)]
pub struct ControllerConfig {
    /// Shape of the model input (defaults to `[32, 32, 3]`, i.e. H×W×C).
    pub input_shape: Vec<usize>,
    /// Number of output classes for the final classifier layer.
    pub num_classes: usize,
    /// When true, a softmax-activated Dense head with `num_classes`
    /// outputs is appended to every built model.
    pub add_softmax: bool,
    /// Optional RNG seed for reproducible model construction.
    pub seed: Option<u64>,
    /// Target device identifier (e.g. "cpu").
    /// NOTE(review): not read anywhere in this module — confirm it is
    /// consumed elsewhere or remove.
    pub device: String,
}
25
26impl Default for ControllerConfig {
27 fn default() -> Self {
28 Self {
29 input_shape: vec![32, 32, 3],
30 num_classes: 10,
31 add_softmax: true,
32 seed: None,
33 device: "cpu".to_string(),
34 }
35 }
36}
37
/// Builds concrete [`Sequential`] models from NAS architecture
/// descriptions sampled from a search space.
pub struct NASController {
    /// Model-building configuration (input shape, classifier head, seed, …).
    pub config: ControllerConfig,
    /// The search space this controller builds architectures from.
    pub search_space: SearchSpaceConfig,
}
43
44impl NASController {
45 pub fn new(search_space: SearchSpaceConfig) -> Result<Self> {
47 Ok(Self {
48 config: ControllerConfig::default(),
49 search_space,
50 })
51 }
52
53 pub fn with_config(search_space: SearchSpaceConfig, config: ControllerConfig) -> Result<Self> {
55 Ok(Self {
56 config,
57 search_space,
58 })
59 }
60
61 pub fn build_model(&self, encoding: &Arc<dyn ArchitectureEncoding>) -> Result<Sequential<f32>> {
63 let architecture = encoding.to_architecture()?;
64 self.build_from_architecture(&architecture)
65 }
66
67 pub fn build_from_architecture(&self, architecture: &Architecture) -> Result<Sequential<f32>> {
69 use scirs2_core::random::{rngs::SmallRng, SeedableRng};
70 let seed = scirs2_core::random::random::<u64>();
71 let mut rng_inst = SmallRng::seed_from_u64(seed);
72 let mut model = Sequential::new();
73 let mut current_shape = self.config.input_shape.clone();
74 let effective_layers = self.apply_multipliers(
75 &architecture.layers,
76 architecture.width_multiplier,
77 architecture.depth_multiplier,
78 )?;
79 for layer_type in effective_layers.iter() {
80 match layer_type {
81 LayerType::Dense(units) => {
82 let input_size = current_shape.iter().product();
83 model.add_layer(crate::layers::Dense::new(
84 input_size,
85 *units,
86 None,
87 &mut rng_inst,
88 )?);
89 current_shape = vec![*units];
90 }
91 LayerType::Dropout(rate) => {
92 model.add_layer(crate::layers::Dropout::new(*rate as f64, &mut rng_inst)?);
93 }
94 LayerType::BatchNorm => {
95 let features = current_shape.last().copied().unwrap_or(1);
96 model.add_layer(crate::layers::BatchNorm::new(
97 features,
98 0.9,
99 1e-5,
100 &mut rng_inst,
101 )?);
102 }
103 LayerType::Activation(name) => {
104 let input_size = current_shape.iter().product();
105 model.add_layer(crate::layers::Dense::new(
106 input_size,
107 input_size,
108 Some(name.as_str()),
109 &mut rng_inst,
110 )?);
111 }
112 LayerType::Flatten => {
113 let input_size: usize = current_shape.iter().product();
114 model.add_layer(crate::layers::Dense::new(
115 input_size,
116 input_size,
117 None,
118 &mut rng_inst,
119 )?);
120 current_shape = vec![input_size];
121 }
122 _ => {
123 continue;
125 }
126 }
127 }
128 if self.config.add_softmax {
129 let input_size = current_shape.iter().product();
130 model.add_layer(crate::layers::Dense::new(
131 input_size,
132 self.config.num_classes,
133 Some("softmax"),
134 &mut rng_inst,
135 )?);
136 }
137 Ok(model)
138 }
139
140 pub fn apply_multipliers(
142 &self,
143 layers: &[LayerType],
144 width_mult: f32,
145 depth_mult: f32,
146 ) -> Result<Vec<LayerType>> {
147 let mut result = Vec::new();
148 for layer in layers {
149 let repetitions = (depth_mult.max(0.1) as usize).max(1);
150 for _ in 0..repetitions {
151 let modified_layer = match layer {
152 LayerType::Dense(units) => {
153 LayerType::Dense((*units as f32 * width_mult).round() as usize)
154 }
155 LayerType::Conv2D {
156 filters,
157 kernel_size,
158 stride,
159 } => LayerType::Conv2D {
160 filters: (*filters as f32 * width_mult).round() as usize,
161 kernel_size: *kernel_size,
162 stride: *stride,
163 },
164 LayerType::Conv1D {
165 filters,
166 kernel_size,
167 stride,
168 } => LayerType::Conv1D {
169 filters: (*filters as f32 * width_mult).round() as usize,
170 kernel_size: *kernel_size,
171 stride: *stride,
172 },
173 LayerType::LSTM {
174 units,
175 return_sequences,
176 } => LayerType::LSTM {
177 units: (*units as f32 * width_mult).round() as usize,
178 return_sequences: *return_sequences,
179 },
180 LayerType::GRU {
181 units,
182 return_sequences,
183 } => LayerType::GRU {
184 units: (*units as f32 * width_mult).round() as usize,
185 return_sequences: *return_sequences,
186 },
187 LayerType::Attention { num_heads, key_dim } => LayerType::Attention {
188 num_heads: *num_heads,
189 key_dim: (*key_dim as f32 * width_mult).round() as usize,
190 },
191 other => other.clone(),
192 };
193 result.push(modified_layer);
194 }
195 }
196 Ok(result)
197 }
198
    /// Returns the number of trainable parameters in `_model`.
    ///
    /// NOTE(review): placeholder — always returns 1_000_000 regardless of
    /// the model. Replace with a real per-layer count once the `Sequential`
    /// API exposes parameter sizes; downstream NAS objectives that penalize
    /// model size will be meaningless until then.
    pub fn count_parameters(&self, _model: &Sequential<f32>) -> Result<usize> {
        Ok(1_000_000)
    }
204
    /// Estimates the FLOPs of one forward pass of `_model` on `_input_shape`.
    ///
    /// NOTE(review): placeholder — always returns 1_000_000 and ignores both
    /// arguments. Implement real per-layer FLOP accounting before using this
    /// value in any latency/compute-aware search objective.
    pub fn estimate_flops(
        &self,
        _model: &Sequential<f32>,
        _input_shape: &[usize],
    ) -> Result<usize> {
        Ok(1_000_000)
    }
214
    /// Computes the output shape produced by `layer_type` given `input_shape`.
    ///
    /// Conventions visible in the code: 2D spatial shapes are `[H, W, C]`,
    /// recurrent shapes are `[seq_len, features]`. Convolution/pooling uses
    /// "valid" (no-padding) arithmetic: `out = (in - k) / stride + 1`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if a `Conv2D` layer receives an input with
    /// fewer than two dimensions.
    ///
    /// NOTE(review): a `stride` component of 0 would panic with a
    /// divide-by-zero here — confirm strides are validated upstream.
    pub fn compute_output_shape(
        &self,
        layer_type: &LayerType,
        input_shape: &[usize],
    ) -> Result<Vec<usize>> {
        match layer_type {
            // Dense collapses everything into a flat feature vector.
            LayerType::Dense(units) => Ok(vec![*units]),
            LayerType::Conv2D {
                filters,
                kernel_size,
                stride,
            } => {
                if input_shape.len() < 2 {
                    return Err(NeuralError::InvalidArgument(
                        "Conv2D requires at least 2D input (H, W)".to_string(),
                    ));
                }
                // saturating_sub guards kernels larger than the input
                // (yields a 1×1 output instead of underflowing).
                let h = (input_shape[0].saturating_sub(kernel_size.0)) / stride.0 + 1;
                let w = (input_shape[1].saturating_sub(kernel_size.1)) / stride.1 + 1;
                Ok(vec![h, w, *filters])
            }
            LayerType::MaxPool2D { pool_size, stride }
            | LayerType::AvgPool2D { pool_size, stride } => {
                // Non-spatial inputs pass through unchanged rather than erroring.
                if input_shape.len() < 2 {
                    return Ok(input_shape.to_vec());
                }
                let h = (input_shape[0].saturating_sub(pool_size.0)) / stride.0 + 1;
                let w = (input_shape[1].saturating_sub(pool_size.1)) / stride.1 + 1;
                // Channel count defaults to 1 when the input carries no C dim.
                let channels = input_shape.get(2).copied().unwrap_or(1);
                Ok(vec![h, w, channels])
            }
            // Global pooling reduces all spatial dims, keeping channels only.
            LayerType::GlobalMaxPool2D | LayerType::GlobalAvgPool2D => {
                let channels = input_shape.last().copied().unwrap_or(1);
                Ok(vec![channels])
            }
            LayerType::Flatten => {
                let total_size: usize = input_shape.iter().product();
                Ok(vec![total_size])
            }
            // Shape-preserving layers.
            LayerType::Dropout(_)
            | LayerType::BatchNorm
            | LayerType::LayerNorm
            | LayerType::Activation(_)
            | LayerType::Residual => Ok(input_shape.to_vec()),
            LayerType::LSTM {
                units,
                return_sequences,
            }
            | LayerType::GRU {
                units,
                return_sequences,
            } => {
                if *return_sequences {
                    // Keep the time dimension (input_shape[0]) when the
                    // full sequence is returned.
                    if input_shape.is_empty() {
                        Ok(vec![*units])
                    } else {
                        Ok(vec![input_shape[0], *units])
                    }
                } else {
                    // Last hidden state only.
                    Ok(vec![*units])
                }
            }
            LayerType::Attention { key_dim, .. } => {
                if input_shape.is_empty() {
                    Ok(vec![*key_dim])
                } else {
                    // Attention replaces the trailing feature dim with key_dim.
                    let mut output_shape = input_shape.to_vec();
                    if let Some(last) = output_shape.last_mut() {
                        *last = *key_dim;
                    }
                    Ok(output_shape)
                }
            }
            LayerType::Embedding { embedding_dim, .. } => Ok(vec![*embedding_dim]),
            // Any other layer type is assumed shape-preserving.
            _ => Ok(input_shape.to_vec()),
        }
    }
293
294 pub fn validate_architecture(&self, architecture: &Architecture) -> Result<()> {
296 if architecture.layers.is_empty() {
297 return Err(NeuralError::InvalidArgument(
298 "Architecture must have at least one layer".to_string(),
299 ));
300 }
301 for (from, to) in &architecture.connections {
302 if *from >= architecture.layers.len() || *to >= architecture.layers.len() {
303 return Err(NeuralError::InvalidArgument(format!(
304 "Invalid skip connection: {} -> {}",
305 from, to
306 )));
307 }
308 if from >= to {
309 return Err(NeuralError::InvalidArgument(
310 "Skip connections must be forward connections".to_string(),
311 ));
312 }
313 }
314 if architecture.width_multiplier <= 0.0 || architecture.depth_multiplier <= 0.0 {
315 return Err(NeuralError::InvalidArgument(
316 "Multipliers must be positive".to_string(),
317 ));
318 }
319 Ok(())
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::Architecture;

    /// `new` should produce a controller carrying the default config
    /// (10 classes per `ControllerConfig::default`).
    #[test]
    fn test_controller_creation() {
        let search_space = SearchSpaceConfig::default();
        let controller = NASController::new(search_space).expect("failed to create controller");
        assert_eq!(controller.config.num_classes, 10);
    }

    /// Validation accepts a simple Dense stack, rejects an empty layer
    /// list, and rejects a backward skip connection.
    #[test]
    fn test_architecture_validation() {
        let search_space = SearchSpaceConfig::default();
        let controller = NASController::new(search_space).expect("failed to create controller");
        let valid_arch = Architecture {
            layers: vec![
                LayerType::Dense(128),
                LayerType::Activation("relu".to_string()),
                LayerType::Dense(10),
            ],
            connections: vec![],
            width_multiplier: 1.0,
            depth_multiplier: 1.0,
        };
        assert!(controller.validate_architecture(&valid_arch).is_ok());
        // Empty architectures are invalid.
        let empty_arch = Architecture {
            layers: vec![],
            connections: vec![],
            width_multiplier: 1.0,
            depth_multiplier: 1.0,
        };
        assert!(controller.validate_architecture(&empty_arch).is_err());
        // (1, 0) points backwards, so it must be rejected.
        let invalid_skip = Architecture {
            layers: vec![LayerType::Dense(128), LayerType::Dense(10)],
            connections: vec![(1, 0)],
            width_multiplier: 1.0,
            depth_multiplier: 1.0,
        };
        assert!(controller.validate_architecture(&invalid_skip).is_err());
    }

    /// Width multiplier 2.0 should double Dense units (100 -> 200);
    /// depth multiplier 1.0 keeps layer count unchanged.
    #[test]
    fn test_multiplier_application() {
        let controller =
            NASController::new(SearchSpaceConfig::default()).expect("failed to create controller");
        let layers = vec![
            LayerType::Dense(100),
            LayerType::Conv2D {
                filters: 32,
                kernel_size: (3, 3),
                stride: (1, 1),
            },
        ];
        let modified = controller
            .apply_multipliers(&layers, 2.0, 1.0)
            .expect("failed to apply multipliers");
        match &modified[0] {
            LayerType::Dense(units) => assert_eq!(*units, 200),
            _ => panic!("Expected Dense layer"),
        }
    }
}