Skip to main content

axonml_vision/models/
vgg.rs

1//! VGG - Very Deep Convolutional Networks
2//!
3//! Implementation of VGG architectures for image classification.
4//!
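//! # Supported Variants
//!
//! - VGG11: 11 weight layers (8 convolutional + 3 fully connected, ~133M parameters)
//! - VGG13: 13 weight layers (10 convolutional + 3 fully connected, ~133M parameters)
//! - VGG16: 16 weight layers (13 convolutional + 3 fully connected, ~138M parameters)
//! - VGG19: 19 weight layers (16 convolutional + 3 fully connected, ~144M parameters)
//!
//! All variants are available with or without batch normalization. The default
//! classifier head expects 7x7 feature maps, which the standard configurations
//! produce from 224x224 input images.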
13//!
14//! # Reference
15//!
16//! "Very Deep Convolutional Networks for Large-Scale Image Recognition"
17//! (Simonyan & Zisserman, 2014)
18//! <https://arxiv.org/abs/1409.1556>
19
20use axonml_autograd::Variable;
21use axonml_nn::{BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, Parameter, ReLU};
22use axonml_tensor::Tensor;
23
24// =============================================================================
25// Helper Functions
26// =============================================================================
27
28/// Flatten a tensor from [N, C, H, W] to [N, C*H*W].
29fn flatten(input: &Variable) -> Variable {
30    let data = input.data();
31    let shape = data.shape();
32
33    if shape.len() <= 2 {
34        return input.clone();
35    }
36
37    let batch_size = shape[0];
38    let features: usize = shape[1..].iter().product();
39
40    Variable::new(
41        Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
42        input.requires_grad(),
43    )
44}
45
46// =============================================================================
47// VGG Configuration
48// =============================================================================
49
/// VGG layer configuration.
///
/// A VGG feature extractor is described as an ordered list of these entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VggLayer {
    /// Convolutional layer with output channels.
    Conv(usize),
    /// Max pooling layer.
    MaxPool,
}

/// Output channels of the five convolutional stages shared by every VGG variant.
const VGG_STAGE_CHANNELS: [usize; 5] = [64, 128, 256, 512, 512];

/// Build a VGG configuration from the number of convolutions in each stage.
///
/// Stage `i` contributes `convs_per_stage[i]` copies of
/// `Conv(VGG_STAGE_CHANNELS[i])`, followed by a single `MaxPool`.
fn config_from_stage_depths(convs_per_stage: [usize; 5]) -> Vec<VggLayer> {
    // One MaxPool per stage in addition to the convolutions.
    let total = convs_per_stage.iter().sum::<usize>() + convs_per_stage.len();
    let mut config = Vec::with_capacity(total);
    for (&channels, &depth) in VGG_STAGE_CHANNELS.iter().zip(convs_per_stage.iter()) {
        let new_len = config.len() + depth;
        config.resize(new_len, VggLayer::Conv(channels));
        config.push(VggLayer::MaxPool);
    }
    config
}

/// Get VGG11 configuration (8 convolutions, stage depths 1-1-2-2-2).
#[must_use]
pub fn vgg11_config() -> Vec<VggLayer> {
    config_from_stage_depths([1, 1, 2, 2, 2])
}

/// Get VGG13 configuration (10 convolutions, stage depths 2-2-2-2-2).
#[must_use]
pub fn vgg13_config() -> Vec<VggLayer> {
    config_from_stage_depths([2, 2, 2, 2, 2])
}

/// Get VGG16 configuration (13 convolutions, stage depths 2-2-3-3-3).
#[must_use]
pub fn vgg16_config() -> Vec<VggLayer> {
    config_from_stage_depths([2, 2, 3, 3, 3])
}

/// Get VGG19 configuration (16 convolutions, stage depths 2-2-4-4-4).
#[must_use]
pub fn vgg19_config() -> Vec<VggLayer> {
    config_from_stage_depths([2, 2, 4, 4, 4])
}
153
154// =============================================================================
155// VGG Feature Extractor
156// =============================================================================
157
158/// VGG feature extraction layers.
159pub struct VggFeatures {
160    layers: Vec<VggFeatureLayer>,
161}
162
163enum VggFeatureLayer {
164    Conv(Conv2d),
165    BatchNorm(BatchNorm2d),
166    ReLU(ReLU),
167    MaxPool(MaxPool2d),
168}
169
170impl VggFeatures {
171    /// Create VGG feature layers from configuration.
172    #[must_use] pub fn new(config: &[VggLayer], batch_norm: bool) -> Self {
173        let mut layers = Vec::new();
174        let mut in_channels = 3;
175
176        for &layer in config {
177            match layer {
178                VggLayer::Conv(out_channels) => {
179                    layers.push(VggFeatureLayer::Conv(Conv2d::with_options(
180                        in_channels,
181                        out_channels,
182                        (3, 3),
183                        (1, 1),
184                        (1, 1),
185                        true,
186                    )));
187                    if batch_norm {
188                        layers.push(VggFeatureLayer::BatchNorm(BatchNorm2d::new(out_channels)));
189                    }
190                    layers.push(VggFeatureLayer::ReLU(ReLU));
191                    in_channels = out_channels;
192                }
193                VggLayer::MaxPool => {
194                    layers.push(VggFeatureLayer::MaxPool(MaxPool2d::with_options(
195                        (2, 2),
196                        (2, 2),
197                        (0, 0),
198                    )));
199                }
200            }
201        }
202
203        Self { layers }
204    }
205}
206
207impl Module for VggFeatures {
208    fn forward(&self, x: &Variable) -> Variable {
209        let mut out = x.clone();
210        for layer in &self.layers {
211            out = match layer {
212                VggFeatureLayer::Conv(conv) => conv.forward(&out),
213                VggFeatureLayer::BatchNorm(bn) => bn.forward(&out),
214                VggFeatureLayer::ReLU(relu) => relu.forward(&out),
215                VggFeatureLayer::MaxPool(pool) => pool.forward(&out),
216            };
217        }
218        out
219    }
220
221    fn parameters(&self) -> Vec<Parameter> {
222        let mut params = Vec::new();
223        for layer in &self.layers {
224            match layer {
225                VggFeatureLayer::Conv(conv) => params.extend(conv.parameters()),
226                VggFeatureLayer::BatchNorm(bn) => params.extend(bn.parameters()),
227                _ => {}
228            }
229        }
230        params
231    }
232
233    fn train(&mut self) {
234        for layer in &mut self.layers {
235            if let VggFeatureLayer::BatchNorm(bn) = layer {
236                bn.train();
237            }
238        }
239    }
240
241    fn eval(&mut self) {
242        for layer in &mut self.layers {
243            if let VggFeatureLayer::BatchNorm(bn) = layer {
244                bn.eval();
245            }
246        }
247    }
248
249    fn is_training(&self) -> bool {
250        for layer in &self.layers {
251            if let VggFeatureLayer::BatchNorm(bn) = layer {
252                return bn.is_training();
253            }
254        }
255        true
256    }
257}
258
259// =============================================================================
260// VGG Classifier
261// =============================================================================
262
263/// VGG classifier head.
264pub struct VggClassifier {
265    fc1: Linear,
266    fc2: Linear,
267    fc3: Linear,
268    relu: ReLU,
269    dropout: Dropout,
270}
271
272impl VggClassifier {
273    /// Create classifier for VGG (assuming 7x7 feature maps).
274    #[must_use] pub fn new(num_classes: usize) -> Self {
275        Self {
276            fc1: Linear::new(512 * 7 * 7, 4096),
277            fc2: Linear::new(4096, 4096),
278            fc3: Linear::new(4096, num_classes),
279            relu: ReLU,
280            dropout: Dropout::new(0.5),
281        }
282    }
283
284    /// Create classifier with custom input size.
285    #[must_use] pub fn with_input_size(input_features: usize, num_classes: usize) -> Self {
286        Self {
287            fc1: Linear::new(input_features, 4096),
288            fc2: Linear::new(4096, 4096),
289            fc3: Linear::new(4096, num_classes),
290            relu: ReLU,
291            dropout: Dropout::new(0.5),
292        }
293    }
294}
295
296impl Module for VggClassifier {
297    fn forward(&self, x: &Variable) -> Variable {
298        let out = self.fc1.forward(x);
299        let out = self.relu.forward(&out);
300        let out = self.dropout.forward(&out);
301
302        let out = self.fc2.forward(&out);
303        let out = self.relu.forward(&out);
304        let out = self.dropout.forward(&out);
305
306        self.fc3.forward(&out)
307    }
308
309    fn parameters(&self) -> Vec<Parameter> {
310        let mut params = Vec::new();
311        params.extend(self.fc1.parameters());
312        params.extend(self.fc2.parameters());
313        params.extend(self.fc3.parameters());
314        params
315    }
316
317    fn train(&mut self) {
318        self.dropout.train();
319    }
320
321    fn eval(&mut self) {
322        self.dropout.eval();
323    }
324
325    fn is_training(&self) -> bool {
326        self.dropout.is_training()
327    }
328}
329
330// =============================================================================
331// VGG Model
332// =============================================================================
333
334/// VGG model for image classification.
335pub struct VGG {
336    features: VggFeatures,
337    classifier: VggClassifier,
338}
339
340impl VGG {
341    /// Create VGG with custom configuration.
342    #[must_use] pub fn new(config: &[VggLayer], num_classes: usize, batch_norm: bool) -> Self {
343        Self {
344            features: VggFeatures::new(config, batch_norm),
345            classifier: VggClassifier::new(num_classes),
346        }
347    }
348
349    /// Create VGG11.
350    #[must_use] pub fn vgg11(num_classes: usize) -> Self {
351        Self::new(&vgg11_config(), num_classes, false)
352    }
353
354    /// Create VGG11 with batch normalization.
355    #[must_use] pub fn vgg11_bn(num_classes: usize) -> Self {
356        Self::new(&vgg11_config(), num_classes, true)
357    }
358
359    /// Create VGG13.
360    #[must_use] pub fn vgg13(num_classes: usize) -> Self {
361        Self::new(&vgg13_config(), num_classes, false)
362    }
363
364    /// Create VGG13 with batch normalization.
365    #[must_use] pub fn vgg13_bn(num_classes: usize) -> Self {
366        Self::new(&vgg13_config(), num_classes, true)
367    }
368
369    /// Create VGG16.
370    #[must_use] pub fn vgg16(num_classes: usize) -> Self {
371        Self::new(&vgg16_config(), num_classes, false)
372    }
373
374    /// Create VGG16 with batch normalization.
375    #[must_use] pub fn vgg16_bn(num_classes: usize) -> Self {
376        Self::new(&vgg16_config(), num_classes, true)
377    }
378
379    /// Create VGG19.
380    #[must_use] pub fn vgg19(num_classes: usize) -> Self {
381        Self::new(&vgg19_config(), num_classes, false)
382    }
383
384    /// Create VGG19 with batch normalization.
385    #[must_use] pub fn vgg19_bn(num_classes: usize) -> Self {
386        Self::new(&vgg19_config(), num_classes, true)
387    }
388}
389
390impl Module for VGG {
391    fn forward(&self, x: &Variable) -> Variable {
392        let out = self.features.forward(x);
393
394        // Flatten: [batch, 512, 7, 7] -> [batch, 512*7*7]
395        let out = flatten(&out);
396
397        self.classifier.forward(&out)
398    }
399
400    fn parameters(&self) -> Vec<Parameter> {
401        let mut params = Vec::new();
402        params.extend(self.features.parameters());
403        params.extend(self.classifier.parameters());
404        params
405    }
406
407    fn train(&mut self) {
408        self.features.train();
409        self.classifier.train();
410    }
411
412    fn eval(&mut self) {
413        self.features.eval();
414        self.classifier.eval();
415    }
416
417    fn is_training(&self) -> bool {
418        self.features.is_training()
419    }
420}
421
422// =============================================================================
423// Convenience Functions
424// =============================================================================
425
426/// Create VGG11 for `ImageNet` (1000 classes).
427#[must_use] pub fn vgg11() -> VGG {
428    VGG::vgg11(1000)
429}
430
431/// Create VGG13 for `ImageNet` (1000 classes).
432#[must_use] pub fn vgg13() -> VGG {
433    VGG::vgg13(1000)
434}
435
436/// Create VGG16 for `ImageNet` (1000 classes).
437#[must_use] pub fn vgg16() -> VGG {
438    VGG::vgg16(1000)
439}
440
441/// Create VGG19 for `ImageNet` (1000 classes).
442#[must_use] pub fn vgg19() -> VGG {
443    VGG::vgg19(1000)
444}
445
446// =============================================================================
447// Tests
448// =============================================================================
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of `Conv` entries in a configuration.
    fn conv_count(config: &[VggLayer]) -> usize {
        config
            .iter()
            .filter(|layer| matches!(layer, VggLayer::Conv(_)))
            .count()
    }

    #[test]
    fn test_vgg_features() {
        let config = vec![VggLayer::Conv(64), VggLayer::MaxPool];
        let features = VggFeatures::new(&config, false);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let output = features.forward(&input);
        // After one conv and one maxpool
        assert_eq!(output.data().shape()[1], 64);
        assert_eq!(output.data().shape()[2], 16); // 32 / 2
    }

    #[test]
    fn test_flatten_shapes() {
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 3 * 2 * 2], &[2, 3, 2, 2]).unwrap(),
            false,
        );
        assert_eq!(flatten(&input).data().shape(), &[2, 12]);

        // Rank-2 input is passed through unchanged.
        let already_flat = Variable::new(Tensor::from_vec(vec![0.0; 6], &[2, 3]).unwrap(), false);
        assert_eq!(flatten(&already_flat).data().shape(), &[2, 3]);
    }

    #[test]
    fn test_config_layer_counts() {
        assert_eq!(conv_count(&vgg11_config()), 8);
        assert_eq!(conv_count(&vgg13_config()), 10);
        assert_eq!(conv_count(&vgg16_config()), 13);
        assert_eq!(conv_count(&vgg19_config()), 16);

        // Every variant has exactly five pooling stages.
        for config in &[vgg11_config(), vgg13_config(), vgg16_config(), vgg19_config()] {
            assert_eq!(config.len() - conv_count(config), 5);
        }
    }

    #[test]
    fn test_vgg11_creation() {
        let model = VGG::vgg11(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vgg11_bn_creation() {
        let model = VGG::vgg11_bn(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vgg16_creation() {
        let model = VGG::vgg16(1000);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vgg_forward_small() {
        // Use small input for quick test
        let config = vec![VggLayer::Conv(64), VggLayer::MaxPool];
        let features = VggFeatures::new(&config, false);

        // Custom small classifier
        let classifier = VggClassifier::with_input_size(64 * 16 * 16, 10);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let out = features.forward(&input);
        let out = flatten(&out);
        let out = classifier.forward(&out);

        assert_eq!(out.data().shape(), &[1, 10]);
    }

    #[test]
    fn test_vgg_train_eval_mode() {
        let mut model = VGG::vgg11_bn(10);

        model.train();
        assert!(model.is_training());

        model.eval();
        assert!(!model.is_training());

        // Switching back must restore training mode.
        model.train();
        assert!(model.is_training());
    }
}