// axonml_vision/models/resnet.rs

//! `ResNet` - Deep Residual Networks
//!
//! Implementation of `ResNet` architectures for image classification.
//!
//! # Supported Variants
//!
//! - `ResNet18`: 18 layers, ~11.7M parameters
//! - `ResNet34`: 34 layers, ~21.8M parameters
//! - `ResNet50`: 50 layers, ~25.6M parameters
//! - `ResNet101`: 101 layers, ~44.5M parameters
//! - `ResNet152`: 152 layers, ~60.2M parameters
//!
//! `ResNet18` and `ResNet34` are built from [`BasicBlock`]; the deeper
//! variants use [`Bottleneck`], for which constructors are not yet provided.
//!
//! # Reference
//!
//! "Deep Residual Learning for Image Recognition" (He et al., 2015)
//! <https://arxiv.org/abs/1512.03385>
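//!
//! # Example
//!
//! A minimal usage sketch, mirroring the `Tensor`/`Variable` calls in the
//! tests at the bottom of this file. The `axonml_vision::models::resnet`
//! import path is assumed from this file's location:
//!
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::Module;
//! use axonml_tensor::Tensor;
//! use axonml_vision::models::resnet::ResNet;
//!
//! // ResNet18 with a 10-class head on a dummy [N, C, H, W] batch.
//! let model = ResNet::resnet18(10);
//! let input = Variable::new(
//!     Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
//!     false,
//! );
//! let logits = model.forward(&input); // shape [1, 10]
//! ```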

use axonml_autograd::Variable;
use axonml_nn::{
    AdaptiveAvgPool2d, BatchNorm2d, Conv2d, Linear, MaxPool2d, Module, Parameter, ReLU,
};
use axonml_tensor::Tensor;

// =============================================================================
// Helper Functions
// =============================================================================

/// Flatten a tensor from [N, C, H, W] to [N, C*H*W].
///
/// NOTE: this copies the data into a fresh `Variable`; if `Variable::new`
/// creates a new graph leaf, gradients will not flow back through the reshape.
fn flatten(input: &Variable) -> Variable {
    let data = input.data();
    let shape = data.shape();

    // Already [N] or [N, F]: nothing to flatten.
    if shape.len() <= 2 {
        return input.clone();
    }

    let batch_size = shape[0];
    let features: usize = shape[1..].iter().product();

    Variable::new(
        Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
        input.requires_grad(),
    )
}

// =============================================================================
// Basic Block (for ResNet18, ResNet34)
// =============================================================================

/// Basic residual block for ResNet18/34.
///
/// Structure: conv3x3 -> BN -> `ReLU` -> conv3x3 -> BN + skip -> `ReLU`
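///
/// An input of `[N, in_channels, H, W]` maps to
/// `[N, out_channels, H/stride, W/stride]`; the optional `downsample` pair
/// projects the identity branch whenever its shape must change to match.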
pub struct BasicBlock {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    conv2: Conv2d,
    bn2: BatchNorm2d,
    downsample: Option<(Conv2d, BatchNorm2d)>,
    relu: ReLU,
}

impl BasicBlock {
    /// Expansion factor for this block type.
    pub const EXPANSION: usize = 1;

    /// Create a new `BasicBlock`.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        downsample: Option<(Conv2d, BatchNorm2d)>,
    ) -> Self {
        Self {
            conv1: Conv2d::with_options(
                in_channels,
                out_channels,
                (3, 3),
                (stride, stride),
                (1, 1),
                true,
            ),
            bn1: BatchNorm2d::new(out_channels),
            conv2: Conv2d::with_options(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true),
            bn2: BatchNorm2d::new(out_channels),
            downsample,
            relu: ReLU,
        }
    }
}

impl Module for BasicBlock {
    fn forward(&self, x: &Variable) -> Variable {
        let identity = x.clone();

        let out = self.conv1.forward(x);
        let out = self.bn1.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv2.forward(&out);
        let out = self.bn2.forward(&out);

        let identity = match &self.downsample {
            Some((conv, bn)) => {
                let ds = conv.forward(&identity);
                bn.forward(&ds)
            }
            None => identity,
        };

        // Residual connection: out = out + identity
        let out = out.add_var(&identity);
        self.relu.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.bn2.parameters());
        if let Some((conv, bn)) = &self.downsample {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        self.bn2.train();
        if let Some((_, bn)) = &mut self.downsample {
            bn.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        self.bn2.eval();
        if let Some((_, bn)) = &mut self.downsample {
            bn.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}

// =============================================================================
// Bottleneck Block (for ResNet50, ResNet101, ResNet152)
// =============================================================================

/// Bottleneck residual block for ResNet50/101/152.
///
/// Structure: conv1x1 -> BN -> `ReLU` -> conv3x3 -> BN -> `ReLU` -> conv1x1 -> BN + skip -> `ReLU`
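///
/// The final conv expands channels by [`Self::EXPANSION`], so an input of
/// `[N, in_channels, H, W]` maps to `[N, out_channels * 4, H/stride, W/stride]`.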
pub struct Bottleneck {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    conv2: Conv2d,
    bn2: BatchNorm2d,
    conv3: Conv2d,
    bn3: BatchNorm2d,
    downsample: Option<(Conv2d, BatchNorm2d)>,
    relu: ReLU,
}

impl Bottleneck {
    /// Expansion factor for this block type.
    pub const EXPANSION: usize = 4;

    /// Create a new Bottleneck block.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        downsample: Option<(Conv2d, BatchNorm2d)>,
    ) -> Self {
        let width = out_channels;

        Self {
            // 1x1 conv to reduce channels
            conv1: Conv2d::with_options(in_channels, width, (1, 1), (1, 1), (0, 0), true),
            bn1: BatchNorm2d::new(width),
            // 3x3 conv (carries the stride)
            conv2: Conv2d::with_options(width, width, (3, 3), (stride, stride), (1, 1), true),
            bn2: BatchNorm2d::new(width),
            // 1x1 conv to expand channels
            conv3: Conv2d::with_options(
                width,
                out_channels * Self::EXPANSION,
                (1, 1),
                (1, 1),
                (0, 0),
                true,
            ),
            bn3: BatchNorm2d::new(out_channels * Self::EXPANSION),
            downsample,
            relu: ReLU,
        }
    }
}

impl Module for Bottleneck {
    fn forward(&self, x: &Variable) -> Variable {
        let identity = x.clone();

        let out = self.conv1.forward(x);
        let out = self.bn1.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv2.forward(&out);
        let out = self.bn2.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv3.forward(&out);
        let out = self.bn3.forward(&out);

        let identity = match &self.downsample {
            Some((conv, bn)) => {
                let ds = conv.forward(&identity);
                bn.forward(&ds)
            }
            None => identity,
        };

        // Residual connection: out = out + identity
        let out = out.add_var(&identity);
        self.relu.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.bn2.parameters());
        params.extend(self.conv3.parameters());
        params.extend(self.bn3.parameters());
        if let Some((conv, bn)) = &self.downsample {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        self.bn2.train();
        self.bn3.train();
        if let Some((_, bn)) = &mut self.downsample {
            bn.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        self.bn2.eval();
        self.bn3.eval();
        if let Some((_, bn)) = &mut self.downsample {
            bn.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}

// =============================================================================
// ResNet
// =============================================================================

/// `ResNet` model for image classification.
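///
/// Stem: 7x7/2 conv -> BN -> `ReLU` -> 3x3/2 max-pool, followed by four
/// residual stages (64, 128, 256, 512 channels), adaptive average pooling to
/// 1x1, and a final `Linear` classifier.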
pub struct ResNet {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    relu: ReLU,
    maxpool: MaxPool2d,
    layer1: Vec<BasicBlock>,
    layer2: Vec<BasicBlock>,
    layer3: Vec<BasicBlock>,
    layer4: Vec<BasicBlock>,
    avgpool: AdaptiveAvgPool2d,
    fc: Linear,
}

impl ResNet {
    /// Create `ResNet18`.
    #[must_use]
    pub fn resnet18(num_classes: usize) -> Self {
        Self::new_basic(&[2, 2, 2, 2], num_classes)
    }

    /// Create `ResNet34`.
    #[must_use]
    pub fn resnet34(num_classes: usize) -> Self {
        Self::new_basic(&[3, 4, 6, 3], num_classes)
    }

    /// Create `ResNet` with `BasicBlock`.
    fn new_basic(layers: &[usize; 4], num_classes: usize) -> Self {
        Self {
            conv1: Conv2d::with_options(3, 64, (7, 7), (2, 2), (3, 3), true),
            bn1: BatchNorm2d::new(64),
            relu: ReLU,
            maxpool: MaxPool2d::with_options((3, 3), (2, 2), (1, 1)),
            layer1: Self::make_basic_layer(64, 64, layers[0], 1),
            layer2: Self::make_basic_layer(64, 128, layers[1], 2),
            layer3: Self::make_basic_layer(128, 256, layers[2], 2),
            layer4: Self::make_basic_layer(256, 512, layers[3], 2),
            avgpool: AdaptiveAvgPool2d::new((1, 1)),
            fc: Linear::new(512 * BasicBlock::EXPANSION, num_classes),
        }
    }

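    /// Build one stage of `BasicBlock`s.
    ///
    /// For example, `make_basic_layer(64, 128, 2, 2)` (layer2 of `ResNet18`)
    /// yields a strided first block with a 1x1 downsample projection, followed
    /// by one stride-1 block: `[BasicBlock(64 -> 128, /2), BasicBlock(128 -> 128)]`.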
    fn make_basic_layer(
        in_channels: usize,
        out_channels: usize,
        blocks: usize,
        stride: usize,
    ) -> Vec<BasicBlock> {
        let mut layers = Vec::new();

        // First block may have stride and downsample
        let downsample = if stride != 1 || in_channels != out_channels {
            Some((
                Conv2d::with_options(
                    in_channels,
                    out_channels,
                    (1, 1),
                    (stride, stride),
                    (0, 0),
                    false,
                ),
                BatchNorm2d::new(out_channels),
            ))
        } else {
            None
        };

        layers.push(BasicBlock::new(
            in_channels,
            out_channels,
            stride,
            downsample,
        ));

        // Remaining blocks
        for _ in 1..blocks {
            layers.push(BasicBlock::new(out_channels, out_channels, 1, None));
        }

        layers
    }
}

impl Module for ResNet {
    fn forward(&self, x: &Variable) -> Variable {
        // Initial conv layer
        let mut out = self.conv1.forward(x);
        out = self.bn1.forward(&out);
        out = self.relu.forward(&out);
        out = self.maxpool.forward(&out);

        // Residual layers
        for block in &self.layer1 {
            out = block.forward(&out);
        }
        for block in &self.layer2 {
            out = block.forward(&out);
        }
        for block in &self.layer3 {
            out = block.forward(&out);
        }
        for block in &self.layer4 {
            out = block.forward(&out);
        }

        // Classification head
        out = self.avgpool.forward(&out);
        // Flatten: [batch, channels, 1, 1] -> [batch, channels]
        out = flatten(&out);

        self.fc.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        for block in &self.layer1 {
            params.extend(block.parameters());
        }
        for block in &self.layer2 {
            params.extend(block.parameters());
        }
        for block in &self.layer3 {
            params.extend(block.parameters());
        }
        for block in &self.layer4 {
            params.extend(block.parameters());
        }
        params.extend(self.fc.parameters());
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        for block in &mut self.layer1 {
            block.train();
        }
        for block in &mut self.layer2 {
            block.train();
        }
        for block in &mut self.layer3 {
            block.train();
        }
        for block in &mut self.layer4 {
            block.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        for block in &mut self.layer1 {
            block.eval();
        }
        for block in &mut self.layer2 {
            block.eval();
        }
        for block in &mut self.layer3 {
            block.eval();
        }
        for block in &mut self.layer4 {
            block.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}

// =============================================================================
// Convenience Functions
// =============================================================================

/// Create `ResNet18` for `ImageNet` (1000 classes).
#[must_use]
pub fn resnet18() -> ResNet {
    ResNet::resnet18(1000)
}

/// Create `ResNet34` for `ImageNet` (1000 classes).
#[must_use]
pub fn resnet34() -> ResNet {
    ResNet::resnet34(1000)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_block() {
        let block = BasicBlock::new(64, 64, 1, None);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 64, 8, 8]);
    }

    #[test]
    fn test_basic_block_with_downsample() {
        let downsample = (
            Conv2d::with_options(64, 128, (1, 1), (2, 2), (0, 0), false),
            BatchNorm2d::new(128),
        );

        let block = BasicBlock::new(64, 128, 2, Some(downsample));

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 128, 4, 4]);
    }
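
    // `Bottleneck` is not covered by the tests above; this is a sketch in the
    // same style, assuming the same `Tensor`/`Variable` APIs. The downsample
    // projects the identity to the expanded channel count (64 * EXPANSION = 256).
    #[test]
    fn test_bottleneck_with_downsample() {
        let downsample = (
            Conv2d::with_options(
                64,
                64 * Bottleneck::EXPANSION,
                (1, 1),
                (2, 2),
                (0, 0),
                false,
            ),
            BatchNorm2d::new(64 * Bottleneck::EXPANSION),
        );

        let block = Bottleneck::new(64, 64, 2, Some(downsample));

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        // conv2 halves the spatial dims; conv3 expands channels to 256.
        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 256, 4, 4]);
    }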

    #[test]
    fn test_resnet18_creation() {
        let model = ResNet::resnet18(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_resnet18_forward_small() {
        let model = ResNet::resnet18(10);

        // Small input for quick test
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape()[0], 1);
        assert_eq!(output.data().shape()[1], 10);
    }
}