use axonml_autograd::Variable;
use axonml_nn::{
    AdaptiveAvgPool2d, BatchNorm2d, Conv2d, Linear, MaxPool2d, Module, Parameter, ReLU,
};
use axonml_tensor::Tensor;

/// Flattens an `[N, C, H, W]` (or any higher-rank) variable to `[N, C * H * W]`.
/// Inputs that are already 1-D or 2-D are returned unchanged. Note that the
/// data is copied into a freshly constructed `Variable` rather than reshaped
/// in place.
fn flatten(input: &Variable) -> Variable {
    let data = input.data();
    let shape = data.shape();

    if shape.len() <= 2 {
        return input.clone();
    }

    let batch_size = shape[0];
    let features: usize = shape[1..].iter().product();

    Variable::new(
        Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
        input.requires_grad(),
    )
}

/// Residual block used by ResNet-18/34: two 3x3 convolutions, each followed
/// by batch normalization, with an identity (or projected) skip connection.
pub struct BasicBlock {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    conv2: Conv2d,
    bn2: BatchNorm2d,
    downsample: Option<(Conv2d, BatchNorm2d)>,
    relu: ReLU,
}

impl BasicBlock {
    /// Ratio of a block's output channels to its base channel count.
    pub const EXPANSION: usize = 1;

    /// Builds a block mapping `in_channels` to `out_channels`; the first
    /// convolution applies `stride`. Pass a `downsample` projection whenever
    /// the stride or channel count changes, so the skip path matches shapes.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        downsample: Option<(Conv2d, BatchNorm2d)>,
    ) -> Self {
        Self {
            conv1: Conv2d::with_options(
                in_channels,
                out_channels,
                (3, 3),
                (stride, stride),
                (1, 1),
                true,
            ),
            bn1: BatchNorm2d::new(out_channels),
            conv2: Conv2d::with_options(out_channels, out_channels, (3, 3), (1, 1), (1, 1), true),
            bn2: BatchNorm2d::new(out_channels),
            downsample,
            relu: ReLU,
        }
    }
}

impl Module for BasicBlock {
    fn forward(&self, x: &Variable) -> Variable {
        let identity = x.clone();

        let out = self.conv1.forward(x);
        let out = self.bn1.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv2.forward(&out);
        let out = self.bn2.forward(&out);

        // Project the identity when the skip path must change shape.
        let identity = match &self.downsample {
            Some((conv, bn)) => {
                let ds = conv.forward(&identity);
                bn.forward(&ds)
            }
            None => identity,
        };

        // Residual connection, then the final activation.
        let out = out.add_var(&identity);
        self.relu.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.bn2.parameters());
        if let Some((conv, bn)) = &self.downsample {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        self.bn2.train();
        if let Some((_, bn)) = &mut self.downsample {
            bn.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        self.bn2.eval();
        if let Some((_, bn)) = &mut self.downsample {
            bn.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}

/// Bottleneck residual block used by deeper ResNets (50/101/152): a 1x1
/// reduction, a 3x3 convolution, and a 1x1 expansion back to
/// `out_channels * EXPANSION` channels.
pub struct Bottleneck {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    conv2: Conv2d,
    bn2: BatchNorm2d,
    conv3: Conv2d,
    bn3: BatchNorm2d,
    downsample: Option<(Conv2d, BatchNorm2d)>,
    relu: ReLU,
}

impl Bottleneck {
    /// Ratio of a block's output channels to its base channel count.
    pub const EXPANSION: usize = 4;

    /// Builds a block mapping `in_channels` to `out_channels * EXPANSION`;
    /// the middle 3x3 convolution applies `stride`.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        downsample: Option<(Conv2d, BatchNorm2d)>,
    ) -> Self {
        let width = out_channels;

        Self {
            conv1: Conv2d::with_options(in_channels, width, (1, 1), (1, 1), (0, 0), true),
            bn1: BatchNorm2d::new(width),
            conv2: Conv2d::with_options(width, width, (3, 3), (stride, stride), (1, 1), true),
            bn2: BatchNorm2d::new(width),
            conv3: Conv2d::with_options(
                width,
                out_channels * Self::EXPANSION,
                (1, 1),
                (1, 1),
                (0, 0),
                true,
            ),
            bn3: BatchNorm2d::new(out_channels * Self::EXPANSION),
            downsample,
            relu: ReLU,
        }
    }
}

impl Module for Bottleneck {
    fn forward(&self, x: &Variable) -> Variable {
        let identity = x.clone();

        let out = self.conv1.forward(x);
        let out = self.bn1.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv2.forward(&out);
        let out = self.bn2.forward(&out);
        let out = self.relu.forward(&out);

        let out = self.conv3.forward(&out);
        let out = self.bn3.forward(&out);

        let identity = match &self.downsample {
            Some((conv, bn)) => {
                let ds = conv.forward(&identity);
                bn.forward(&ds)
            }
            None => identity,
        };

        let out = out.add_var(&identity);
        self.relu.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.bn2.parameters());
        params.extend(self.conv3.parameters());
        params.extend(self.bn3.parameters());
        if let Some((conv, bn)) = &self.downsample {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        self.bn2.train();
        self.bn3.train();
        if let Some((_, bn)) = &mut self.downsample {
            bn.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        self.bn2.eval();
        self.bn3.eval();
        if let Some((_, bn)) = &mut self.downsample {
            bn.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}
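
// Note: `Bottleneck` is not wired into `ResNet` below, which only builds the
// BasicBlock variants (ResNet-18/34). Supporting ResNet-50 and deeper would
// need bottleneck layer fields and a builder analogous to `make_basic_layer`,
// where each stage's output channel count is `out_channels * EXPANSION`
// (e.g. the first ResNet-50 stage maps 64 -> 256 channels).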

/// ResNet image classifier (BasicBlock variants: ResNet-18 and ResNet-34).
pub struct ResNet {
    conv1: Conv2d,
    bn1: BatchNorm2d,
    relu: ReLU,
    maxpool: MaxPool2d,
    layer1: Vec<BasicBlock>,
    layer2: Vec<BasicBlock>,
    layer3: Vec<BasicBlock>,
    layer4: Vec<BasicBlock>,
    avgpool: AdaptiveAvgPool2d,
    fc: Linear,
}

impl ResNet {
    /// ResNet-18: stages of [2, 2, 2, 2] basic blocks.
    #[must_use]
    pub fn resnet18(num_classes: usize) -> Self {
        Self::new_basic(&[2, 2, 2, 2], num_classes)
    }

    /// ResNet-34: stages of [3, 4, 6, 3] basic blocks.
    #[must_use]
    pub fn resnet34(num_classes: usize) -> Self {
        Self::new_basic(&[3, 4, 6, 3], num_classes)
    }

    fn new_basic(layers: &[usize; 4], num_classes: usize) -> Self {
        Self {
            conv1: Conv2d::with_options(3, 64, (7, 7), (2, 2), (3, 3), true),
            bn1: BatchNorm2d::new(64),
            relu: ReLU,
            maxpool: MaxPool2d::with_options((3, 3), (2, 2), (1, 1)),
            layer1: Self::make_basic_layer(64, 64, layers[0], 1),
            layer2: Self::make_basic_layer(64, 128, layers[1], 2),
            layer3: Self::make_basic_layer(128, 256, layers[2], 2),
            layer4: Self::make_basic_layer(256, 512, layers[3], 2),
            avgpool: AdaptiveAvgPool2d::new((1, 1)),
            fc: Linear::new(512 * BasicBlock::EXPANSION, num_classes),
        }
    }

    fn make_basic_layer(
        in_channels: usize,
        out_channels: usize,
        blocks: usize,
        stride: usize,
    ) -> Vec<BasicBlock> {
        let mut layers = Vec::new();

        // The first block of a stage needs a 1x1 projection on the skip path
        // whenever it changes resolution or channel count.
        let downsample = if stride != 1 || in_channels != out_channels {
            Some((
                Conv2d::with_options(
                    in_channels,
                    out_channels,
                    (1, 1),
                    (stride, stride),
                    (0, 0),
                    false,
                ),
                BatchNorm2d::new(out_channels),
            ))
        } else {
            None
        };

        layers.push(BasicBlock::new(
            in_channels,
            out_channels,
            stride,
            downsample,
        ));

        // Remaining blocks keep the resolution and channel count.
        for _ in 1..blocks {
            layers.push(BasicBlock::new(out_channels, out_channels, 1, None));
        }

        layers
    }
}
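
// Shape walk-through for a standard [N, 3, 224, 224] ImageNet input:
//   conv1 (7x7, stride 2)   -> [N, 64, 112, 112]
//   maxpool (3x3, stride 2) -> [N, 64, 56, 56]
//   layer1 -> [N, 64, 56, 56]     layer2 -> [N, 128, 28, 28]
//   layer3 -> [N, 256, 14, 14]    layer4 -> [N, 512, 7, 7]
//   avgpool + flatten -> [N, 512] -> fc -> [N, num_classes]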

impl Module for ResNet {
    fn forward(&self, x: &Variable) -> Variable {
        let mut out = self.conv1.forward(x);
        out = self.bn1.forward(&out);
        out = self.relu.forward(&out);
        out = self.maxpool.forward(&out);

        for block in &self.layer1 {
            out = block.forward(&out);
        }
        for block in &self.layer2 {
            out = block.forward(&out);
        }
        for block in &self.layer3 {
            out = block.forward(&out);
        }
        for block in &self.layer4 {
            out = block.forward(&out);
        }

        out = self.avgpool.forward(&out);
        out = flatten(&out);

        self.fc.forward(&out)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.bn1.parameters());
        for block in &self.layer1 {
            params.extend(block.parameters());
        }
        for block in &self.layer2 {
            params.extend(block.parameters());
        }
        for block in &self.layer3 {
            params.extend(block.parameters());
        }
        for block in &self.layer4 {
            params.extend(block.parameters());
        }
        params.extend(self.fc.parameters());
        params
    }

    fn train(&mut self) {
        self.bn1.train();
        for block in &mut self.layer1 {
            block.train();
        }
        for block in &mut self.layer2 {
            block.train();
        }
        for block in &mut self.layer3 {
            block.train();
        }
        for block in &mut self.layer4 {
            block.train();
        }
    }

    fn eval(&mut self) {
        self.bn1.eval();
        for block in &mut self.layer1 {
            block.eval();
        }
        for block in &mut self.layer2 {
            block.eval();
        }
        for block in &mut self.layer3 {
            block.eval();
        }
        for block in &mut self.layer4 {
            block.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.bn1.is_training()
    }
}

/// ResNet-18 with the standard 1000 ImageNet classes.
#[must_use]
pub fn resnet18() -> ResNet {
    ResNet::resnet18(1000)
}

/// ResNet-34 with the standard 1000 ImageNet classes.
#[must_use]
pub fn resnet34() -> ResNet {
    ResNet::resnet34(1000)
}
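
// A minimal usage sketch, assuming only the APIs used elsewhere in this file:
// build a model, switch it to eval mode, and classify one dummy image.
//
//     let mut model = resnet18();
//     model.eval();
//     let image = Variable::new(
//         Tensor::from_vec(vec![0.0; 3 * 224 * 224], &[1, 3, 224, 224]).unwrap(),
//         false,
//     );
//     let logits = model.forward(&image); // shape [1, 1000]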

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_block() {
        let block = BasicBlock::new(64, 64, 1, None);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 64, 8, 8]);
    }

    #[test]
    fn test_basic_block_with_downsample() {
        let downsample = (
            Conv2d::with_options(64, 128, (1, 1), (2, 2), (0, 0), false),
            BatchNorm2d::new(128),
        );

        let block = BasicBlock::new(64, 128, 2, Some(downsample));

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 128, 4, 4]);
    }
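
    // A sketch of the analogous shape test for `Bottleneck`, mirroring the
    // BasicBlock tests above and using only constructors already in this file.
    // The skip path must project 64 -> 64 * Bottleneck::EXPANSION channels so
    // the residual addition lines up.
    #[test]
    fn test_bottleneck_with_downsample() {
        let downsample = (
            Conv2d::with_options(64, 64 * Bottleneck::EXPANSION, (1, 1), (1, 1), (0, 0), false),
            BatchNorm2d::new(64 * Bottleneck::EXPANSION),
        );

        let block = Bottleneck::new(64, 64, 1, Some(downsample));

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 64 * 8 * 8], &[1, 64, 8, 8]).unwrap(),
            false,
        );

        let output = block.forward(&input);
        assert_eq!(output.data().shape(), &[1, 256, 8, 8]);
    }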

    #[test]
    fn test_resnet18_creation() {
        let model = ResNet::resnet18(10);
        let params = model.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_resnet18_forward_small() {
        let model = ResNet::resnet18(10);

        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 3 * 32 * 32], &[1, 3, 32, 32]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape()[0], 1);
        assert_eq!(output.data().shape()[1], 10);
    }
}