1use axonml_autograd::Variable;
21use axonml_nn::{BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, Parameter, ReLU};
22use axonml_tensor::Tensor;
23
24fn flatten(input: &Variable) -> Variable {
30 let data = input.data();
31 let shape = data.shape();
32
33 if shape.len() <= 2 {
34 return input.clone();
35 }
36
37 let batch_size = shape[0];
38 let features: usize = shape[1..].iter().product();
39
40 Variable::new(
41 Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
42 input.requires_grad(),
43 )
44}
45
/// One entry of a VGG feature-extractor configuration.
///
/// A configuration is an ordered list of these entries; `VggFeatures::new`
/// turns each entry into concrete layers.
///
/// Equality and hashing are derived so configurations can be compared,
/// asserted on in tests, and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VggLayer {
    /// A convolution producing the given number of output channels.
    Conv(usize),
    /// A max-pooling step.
    MaxPool,
}
58
59#[must_use] pub fn vgg11_config() -> Vec<VggLayer> {
61 use VggLayer::{Conv, MaxPool};
62 vec![
63 Conv(64),
64 MaxPool,
65 Conv(128),
66 MaxPool,
67 Conv(256),
68 Conv(256),
69 MaxPool,
70 Conv(512),
71 Conv(512),
72 MaxPool,
73 Conv(512),
74 Conv(512),
75 MaxPool,
76 ]
77}
78
79#[must_use] pub fn vgg13_config() -> Vec<VggLayer> {
81 use VggLayer::{Conv, MaxPool};
82 vec![
83 Conv(64),
84 Conv(64),
85 MaxPool,
86 Conv(128),
87 Conv(128),
88 MaxPool,
89 Conv(256),
90 Conv(256),
91 MaxPool,
92 Conv(512),
93 Conv(512),
94 MaxPool,
95 Conv(512),
96 Conv(512),
97 MaxPool,
98 ]
99}
100
101#[must_use] pub fn vgg16_config() -> Vec<VggLayer> {
103 use VggLayer::{Conv, MaxPool};
104 vec![
105 Conv(64),
106 Conv(64),
107 MaxPool,
108 Conv(128),
109 Conv(128),
110 MaxPool,
111 Conv(256),
112 Conv(256),
113 Conv(256),
114 MaxPool,
115 Conv(512),
116 Conv(512),
117 Conv(512),
118 MaxPool,
119 Conv(512),
120 Conv(512),
121 Conv(512),
122 MaxPool,
123 ]
124}
125
126#[must_use] pub fn vgg19_config() -> Vec<VggLayer> {
128 use VggLayer::{Conv, MaxPool};
129 vec![
130 Conv(64),
131 Conv(64),
132 MaxPool,
133 Conv(128),
134 Conv(128),
135 MaxPool,
136 Conv(256),
137 Conv(256),
138 Conv(256),
139 Conv(256),
140 MaxPool,
141 Conv(512),
142 Conv(512),
143 Conv(512),
144 Conv(512),
145 MaxPool,
146 Conv(512),
147 Conv(512),
148 Conv(512),
149 Conv(512),
150 MaxPool,
151 ]
152}
153
154pub struct VggFeatures {
160 layers: Vec<VggFeatureLayer>,
161}
162
163enum VggFeatureLayer {
164 Conv(Conv2d),
165 BatchNorm(BatchNorm2d),
166 ReLU(ReLU),
167 MaxPool(MaxPool2d),
168}
169
170impl VggFeatures {
171 #[must_use] pub fn new(config: &[VggLayer], batch_norm: bool) -> Self {
173 let mut layers = Vec::new();
174 let mut in_channels = 3;
175
176 for &layer in config {
177 match layer {
178 VggLayer::Conv(out_channels) => {
179 layers.push(VggFeatureLayer::Conv(Conv2d::with_options(
180 in_channels,
181 out_channels,
182 (3, 3),
183 (1, 1),
184 (1, 1),
185 true,
186 )));
187 if batch_norm {
188 layers.push(VggFeatureLayer::BatchNorm(BatchNorm2d::new(out_channels)));
189 }
190 layers.push(VggFeatureLayer::ReLU(ReLU));
191 in_channels = out_channels;
192 }
193 VggLayer::MaxPool => {
194 layers.push(VggFeatureLayer::MaxPool(MaxPool2d::with_options(
195 (2, 2),
196 (2, 2),
197 (0, 0),
198 )));
199 }
200 }
201 }
202
203 Self { layers }
204 }
205}
206
207impl Module for VggFeatures {
208 fn forward(&self, x: &Variable) -> Variable {
209 let mut out = x.clone();
210 for layer in &self.layers {
211 out = match layer {
212 VggFeatureLayer::Conv(conv) => conv.forward(&out),
213 VggFeatureLayer::BatchNorm(bn) => bn.forward(&out),
214 VggFeatureLayer::ReLU(relu) => relu.forward(&out),
215 VggFeatureLayer::MaxPool(pool) => pool.forward(&out),
216 };
217 }
218 out
219 }
220
221 fn parameters(&self) -> Vec<Parameter> {
222 let mut params = Vec::new();
223 for layer in &self.layers {
224 match layer {
225 VggFeatureLayer::Conv(conv) => params.extend(conv.parameters()),
226 VggFeatureLayer::BatchNorm(bn) => params.extend(bn.parameters()),
227 _ => {}
228 }
229 }
230 params
231 }
232
233 fn train(&mut self) {
234 for layer in &mut self.layers {
235 if let VggFeatureLayer::BatchNorm(bn) = layer {
236 bn.train();
237 }
238 }
239 }
240
241 fn eval(&mut self) {
242 for layer in &mut self.layers {
243 if let VggFeatureLayer::BatchNorm(bn) = layer {
244 bn.eval();
245 }
246 }
247 }
248
249 fn is_training(&self) -> bool {
250 for layer in &self.layers {
251 if let VggFeatureLayer::BatchNorm(bn) = layer {
252 return bn.is_training();
253 }
254 }
255 true
256 }
257}
258
259pub struct VggClassifier {
265 fc1: Linear,
266 fc2: Linear,
267 fc3: Linear,
268 relu: ReLU,
269 dropout: Dropout,
270}
271
272impl VggClassifier {
273 #[must_use] pub fn new(num_classes: usize) -> Self {
275 Self {
276 fc1: Linear::new(512 * 7 * 7, 4096),
277 fc2: Linear::new(4096, 4096),
278 fc3: Linear::new(4096, num_classes),
279 relu: ReLU,
280 dropout: Dropout::new(0.5),
281 }
282 }
283
284 #[must_use] pub fn with_input_size(input_features: usize, num_classes: usize) -> Self {
286 Self {
287 fc1: Linear::new(input_features, 4096),
288 fc2: Linear::new(4096, 4096),
289 fc3: Linear::new(4096, num_classes),
290 relu: ReLU,
291 dropout: Dropout::new(0.5),
292 }
293 }
294}
295
296impl Module for VggClassifier {
297 fn forward(&self, x: &Variable) -> Variable {
298 let out = self.fc1.forward(x);
299 let out = self.relu.forward(&out);
300 let out = self.dropout.forward(&out);
301
302 let out = self.fc2.forward(&out);
303 let out = self.relu.forward(&out);
304 let out = self.dropout.forward(&out);
305
306 self.fc3.forward(&out)
307 }
308
309 fn parameters(&self) -> Vec<Parameter> {
310 let mut params = Vec::new();
311 params.extend(self.fc1.parameters());
312 params.extend(self.fc2.parameters());
313 params.extend(self.fc3.parameters());
314 params
315 }
316
317 fn train(&mut self) {
318 self.dropout.train();
319 }
320
321 fn eval(&mut self) {
322 self.dropout.eval();
323 }
324
325 fn is_training(&self) -> bool {
326 self.dropout.is_training()
327 }
328}
329
330pub struct VGG {
336 features: VggFeatures,
337 classifier: VggClassifier,
338}
339
340impl VGG {
341 #[must_use] pub fn new(config: &[VggLayer], num_classes: usize, batch_norm: bool) -> Self {
343 Self {
344 features: VggFeatures::new(config, batch_norm),
345 classifier: VggClassifier::new(num_classes),
346 }
347 }
348
349 #[must_use] pub fn vgg11(num_classes: usize) -> Self {
351 Self::new(&vgg11_config(), num_classes, false)
352 }
353
354 #[must_use] pub fn vgg11_bn(num_classes: usize) -> Self {
356 Self::new(&vgg11_config(), num_classes, true)
357 }
358
359 #[must_use] pub fn vgg13(num_classes: usize) -> Self {
361 Self::new(&vgg13_config(), num_classes, false)
362 }
363
364 #[must_use] pub fn vgg13_bn(num_classes: usize) -> Self {
366 Self::new(&vgg13_config(), num_classes, true)
367 }
368
369 #[must_use] pub fn vgg16(num_classes: usize) -> Self {
371 Self::new(&vgg16_config(), num_classes, false)
372 }
373
374 #[must_use] pub fn vgg16_bn(num_classes: usize) -> Self {
376 Self::new(&vgg16_config(), num_classes, true)
377 }
378
379 #[must_use] pub fn vgg19(num_classes: usize) -> Self {
381 Self::new(&vgg19_config(), num_classes, false)
382 }
383
384 #[must_use] pub fn vgg19_bn(num_classes: usize) -> Self {
386 Self::new(&vgg19_config(), num_classes, true)
387 }
388}
389
390impl Module for VGG {
391 fn forward(&self, x: &Variable) -> Variable {
392 let out = self.features.forward(x);
393
394 let out = flatten(&out);
396
397 self.classifier.forward(&out)
398 }
399
400 fn parameters(&self) -> Vec<Parameter> {
401 let mut params = Vec::new();
402 params.extend(self.features.parameters());
403 params.extend(self.classifier.parameters());
404 params
405 }
406
407 fn train(&mut self) {
408 self.features.train();
409 self.classifier.train();
410 }
411
412 fn eval(&mut self) {
413 self.features.eval();
414 self.classifier.eval();
415 }
416
417 fn is_training(&self) -> bool {
418 self.features.is_training()
419 }
420}
421
422#[must_use] pub fn vgg11() -> VGG {
428 VGG::vgg11(1000)
429}
430
431#[must_use] pub fn vgg13() -> VGG {
433 VGG::vgg13(1000)
434}
435
436#[must_use] pub fn vgg16() -> VGG {
438 VGG::vgg16(1000)
439}
440
441#[must_use] pub fn vgg19() -> VGG {
443 VGG::vgg19(1000)
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// A single conv + pool stage maps 3x32x32 input to 64 channels at half resolution.
    #[test]
    fn test_vgg_features() {
        let config = vec![VggLayer::Conv(64), VggLayer::MaxPool];
        let features = VggFeatures::new(&config, false);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let output = features.forward(&input);
        assert_eq!(output.data().shape()[1], 64);
        assert_eq!(output.data().shape()[2], 16);
    }

    #[test]
    fn test_vgg11_creation() {
        let model = VGG::vgg11(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vgg11_bn_creation() {
        let model = VGG::vgg11_bn(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vgg16_creation() {
        let model = VGG::vgg16(1000);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    /// End-to-end pass through a small feature stack, `flatten`, and a
    /// classifier sized for the resulting 64 * 16 * 16 features.
    #[test]
    fn test_vgg_forward_small() {
        let config = vec![VggLayer::Conv(64), VggLayer::MaxPool];
        let features = VggFeatures::new(&config, false);

        let classifier = VggClassifier::with_input_size(64 * 16 * 16, 10);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let out = features.forward(&input);
        let out = flatten(&out);
        let out = classifier.forward(&out);

        assert_eq!(out.data().shape(), &[1, 10]);
    }

    /// `train()` and `eval()` must both be reflected by `is_training()`.
    #[test]
    fn test_vgg_train_eval_mode() {
        let mut model = VGG::vgg11_bn(10);

        model.train();
        assert!(model.is_training());

        model.eval();
        // Previously `eval()` was called without checking its effect.
        assert!(!model.is_training());
    }
}