1use std::any::Any;
27
28use axonml_autograd::no_grad::is_grad_enabled;
29use axonml_autograd::{GradFn, GradientFunction, Variable};
30use axonml_nn::{Conv2d, Linear, Module, Parameter};
31use axonml_tensor::Tensor;
32
/// LeNet-5 style convolutional network for small-image classification.
///
/// Two convolution layers (each followed in `forward` by ReLU and 2x2 max
/// pooling) feed three fully connected layers that produce 10 class logits.
pub struct LeNet {
    // First convolution stage (input channels -> 6 feature maps, 5x5 kernel).
    conv1: Conv2d,
    // Second convolution stage (6 -> 16 feature maps, 5x5 kernel).
    conv2: Conv2d,
    // Classifier head: flattened features -> 120 -> 84 -> 10 logits.
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}
53
54impl LeNet {
55 #[must_use]
57 pub fn new() -> Self {
58 Self {
59 conv1: Conv2d::new(1, 6, 5), conv2: Conv2d::new(6, 16, 5), fc1: Linear::new(16 * 4 * 4, 120), fc2: Linear::new(120, 84),
63 fc3: Linear::new(84, 10),
64 }
65 }
66
67 #[must_use]
69 pub fn for_cifar10() -> Self {
70 Self {
71 conv1: Conv2d::new(3, 6, 5), conv2: Conv2d::new(6, 16, 5), fc1: Linear::new(16 * 5 * 5, 120), fc2: Linear::new(120, 84),
75 fc3: Linear::new(84, 10),
76 }
77 }
78
79 fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
81 let data = input.data();
82 let shape = data.shape();
83
84 if shape.len() == 4 {
85 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
86 let out_h = h / kernel_size;
87 let out_w = w / kernel_size;
88
89 let data_vec = data.to_vec();
90 let out_size = n * c * out_h * out_w;
91 let mut result = vec![0.0f32; out_size];
92 let mut max_indices = vec![0usize; out_size];
93
94 for batch in 0..n {
95 for ch in 0..c {
96 for oh in 0..out_h {
97 for ow in 0..out_w {
98 let mut max_val = f32::NEG_INFINITY;
99 let mut max_idx = 0usize;
100 for kh in 0..kernel_size {
101 for kw in 0..kernel_size {
102 let ih = oh * kernel_size + kh;
103 let iw = ow * kernel_size + kw;
104 let idx = batch * c * h * w + ch * h * w + ih * w + iw;
105 if data_vec[idx] > max_val {
106 max_val = data_vec[idx];
107 max_idx = idx;
108 }
109 }
110 }
111 let out_idx =
112 batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
113 result[out_idx] = max_val;
114 max_indices[out_idx] = max_idx;
115 }
116 }
117 }
118 }
119
120 let output_tensor = Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap();
121 if input.requires_grad() && is_grad_enabled() {
122 let grad_fn = GradFn::new(MaxPool2dBackward {
123 next_fns: vec![input.grad_fn().cloned()],
124 max_indices,
125 input_shape: shape.to_vec(),
126 });
127 Variable::from_operation(output_tensor, grad_fn, true)
128 } else {
129 Variable::new(output_tensor, false)
130 }
131 } else if shape.len() == 3 {
132 let (c, h, w) = (shape[0], shape[1], shape[2]);
134 let out_h = h / kernel_size;
135 let out_w = w / kernel_size;
136
137 let data_vec = data.to_vec();
138 let out_size = c * out_h * out_w;
139 let mut result = vec![0.0f32; out_size];
140 let mut max_indices = vec![0usize; out_size];
141
142 for ch in 0..c {
143 for oh in 0..out_h {
144 for ow in 0..out_w {
145 let mut max_val = f32::NEG_INFINITY;
146 let mut max_idx = 0usize;
147 for kh in 0..kernel_size {
148 for kw in 0..kernel_size {
149 let ih = oh * kernel_size + kh;
150 let iw = ow * kernel_size + kw;
151 let idx = ch * h * w + ih * w + iw;
152 if data_vec[idx] > max_val {
153 max_val = data_vec[idx];
154 max_idx = idx;
155 }
156 }
157 }
158 let out_idx = ch * out_h * out_w + oh * out_w + ow;
159 result[out_idx] = max_val;
160 max_indices[out_idx] = max_idx;
161 }
162 }
163 }
164
165 let output_tensor = Tensor::from_vec(result, &[c, out_h, out_w]).unwrap();
166 if input.requires_grad() && is_grad_enabled() {
167 let grad_fn = GradFn::new(MaxPool2dBackward {
168 next_fns: vec![input.grad_fn().cloned()],
169 max_indices,
170 input_shape: shape.to_vec(),
171 });
172 Variable::from_operation(output_tensor, grad_fn, true)
173 } else {
174 Variable::new(output_tensor, false)
175 }
176 } else {
177 input.clone()
178 }
179 }
180
181 fn flatten(&self, input: &Variable) -> Variable {
184 let shape = input.shape();
185
186 if shape.len() <= 2 {
187 return input.clone();
188 }
189
190 let batch_size = shape[0];
191 let features: usize = shape[1..].iter().product();
192
193 input.reshape(&[batch_size, features])
194 }
195}
196
197impl Default for LeNet {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl Module for LeNet {
204 fn forward(&self, input: &Variable) -> Variable {
205 let x = self.conv1.forward(input);
207 let x = x.relu();
208 let x = self.max_pool2d(&x, 2);
209
210 let x = self.conv2.forward(&x);
212 let x = x.relu();
213 let x = self.max_pool2d(&x, 2);
214
215 let x = self.flatten(&x);
217
218 let x = self.fc1.forward(&x);
220 let x = x.relu();
221 let x = self.fc2.forward(&x);
222 let x = x.relu();
223 self.fc3.forward(&x)
224 }
225
226 fn parameters(&self) -> Vec<Parameter> {
227 let mut params = Vec::new();
228 params.extend(self.conv1.parameters());
229 params.extend(self.conv2.parameters());
230 params.extend(self.fc1.parameters());
231 params.extend(self.fc2.parameters());
232 params.extend(self.fc3.parameters());
233 params
234 }
235
236 fn train(&mut self) {
237 }
239
240 fn eval(&mut self) {
241 }
243}
244
/// Minimal CNN: one convolution + pooling stage feeding two linear layers.
pub struct SimpleCNN {
    // Single convolution stage (see `new`: input channels -> 32 maps, 3x3).
    conv1: Conv2d,
    // Classifier head: flattened features -> 128 -> `num_classes` logits.
    fc1: Linear,
    fc2: Linear,
    // Stored so callers can query the configured geometry via the accessors.
    input_channels: usize,
    num_classes: usize,
}
257
258impl SimpleCNN {
259 #[must_use]
262 pub fn new(input_channels: usize, num_classes: usize) -> Self {
263 Self {
264 conv1: Conv2d::new(input_channels, 32, 3),
265 fc1: Linear::new(32 * 13 * 13, 128), fc2: Linear::new(128, num_classes),
267 input_channels,
268 num_classes,
269 }
270 }
271
272 #[must_use]
274 pub fn for_mnist() -> Self {
275 Self::new(1, 10)
276 }
277
278 #[must_use]
280 pub fn for_cifar10() -> Self {
281 Self {
283 conv1: Conv2d::new(3, 32, 3),
284 fc1: Linear::new(32 * 15 * 15, 128),
285 fc2: Linear::new(128, 10),
286 input_channels: 3,
287 num_classes: 10,
288 }
289 }
290
291 #[must_use]
293 pub fn input_channels(&self) -> usize {
294 self.input_channels
295 }
296
297 #[must_use]
299 pub fn num_classes(&self) -> usize {
300 self.num_classes
301 }
302
303 fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
304 let data = input.data();
305 let shape = data.shape();
306
307 if shape.len() != 4 {
308 return input.clone();
309 }
310
311 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
312 let out_h = h / kernel_size;
313 let out_w = w / kernel_size;
314
315 let data_vec = data.to_vec();
316 let mut result = vec![0.0f32; n * c * out_h * out_w];
317
318 for batch in 0..n {
319 for ch in 0..c {
320 for oh in 0..out_h {
321 for ow in 0..out_w {
322 let mut max_val = f32::NEG_INFINITY;
323 for kh in 0..kernel_size {
324 for kw in 0..kernel_size {
325 let ih = oh * kernel_size + kh;
326 let iw = ow * kernel_size + kw;
327 let idx = batch * c * h * w + ch * h * w + ih * w + iw;
328 max_val = max_val.max(data_vec[idx]);
329 }
330 }
331 let out_idx =
332 batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
333 result[out_idx] = max_val;
334 }
335 }
336 }
337 }
338
339 Variable::new(
340 Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
341 input.requires_grad(),
342 )
343 }
344
345 fn flatten(&self, input: &Variable) -> Variable {
346 let shape = input.shape();
347
348 if shape.len() <= 2 {
349 return input.clone();
350 }
351
352 let batch_size = shape[0];
353 let features: usize = shape[1..].iter().product();
354
355 input.reshape(&[batch_size, features])
356 }
357}
358
359impl Module for SimpleCNN {
360 fn forward(&self, input: &Variable) -> Variable {
361 let x = self.conv1.forward(input);
362 let x = x.relu();
363 let x = self.max_pool2d(&x, 2);
364 let x = self.flatten(&x);
365 let x = self.fc1.forward(&x);
366 let x = x.relu();
367 self.fc2.forward(&x)
368 }
369
370 fn parameters(&self) -> Vec<Parameter> {
371 let mut params = Vec::new();
372 params.extend(self.conv1.parameters());
373 params.extend(self.fc1.parameters());
374 params.extend(self.fc2.parameters());
375 params
376 }
377
378 fn train(&mut self) {}
379 fn eval(&mut self) {}
380}
381
/// Three-layer perceptron over flattened inputs.
///
/// Layer widths are fixed at construction; `forward` flattens any non-2D
/// input to (batch, features) before the first linear layer.
pub struct MLP {
    // input_size -> hidden_size
    fc1: Linear,
    // hidden_size -> hidden_size / 2
    fc2: Linear,
    // hidden_size / 2 -> num_classes
    fc3: Linear,
}
392
393impl MLP {
394 #[must_use]
396 pub fn new(input_size: usize, hidden_size: usize, num_classes: usize) -> Self {
397 Self {
398 fc1: Linear::new(input_size, hidden_size),
399 fc2: Linear::new(hidden_size, hidden_size / 2),
400 fc3: Linear::new(hidden_size / 2, num_classes),
401 }
402 }
403
404 #[must_use]
406 pub fn for_mnist() -> Self {
407 Self::new(784, 256, 10)
408 }
409
410 #[must_use]
412 pub fn for_cifar10() -> Self {
413 Self::new(3072, 512, 10)
414 }
415}
416
417impl Module for MLP {
418 fn forward(&self, input: &Variable) -> Variable {
419 let data = input.data();
421 let shape = data.shape();
422 let x = if shape.len() > 2 {
423 let batch = shape[0];
424 let features: usize = shape[1..].iter().product();
425 Variable::new(
426 Tensor::from_vec(data.to_vec(), &[batch, features]).unwrap(),
427 input.requires_grad(),
428 )
429 } else if shape.len() == 1 {
430 Variable::new(
432 Tensor::from_vec(data.to_vec(), &[1, shape[0]]).unwrap(),
433 input.requires_grad(),
434 )
435 } else {
436 input.clone()
437 };
438
439 let x = self.fc1.forward(&x);
440 let x = x.relu();
441 let x = self.fc2.forward(&x);
442 let x = x.relu();
443 self.fc3.forward(&x)
444 }
445
446 fn parameters(&self) -> Vec<Parameter> {
447 let mut params = Vec::new();
448 params.extend(self.fc1.parameters());
449 params.extend(self.fc2.parameters());
450 params.extend(self.fc3.parameters());
451 params
452 }
453
454 fn train(&mut self) {}
455 fn eval(&mut self) {}
456}
457
/// Backward pass for the max-pooling ops above: routes each element of the
/// output gradient to the input position that won its pooling window.
#[derive(Debug)]
struct MaxPool2dBackward {
    // Upstream gradient functions (one entry per forward input).
    next_fns: Vec<Option<GradFn>>,
    // For each pooled output element, the flat index of the input maximum.
    max_indices: Vec<usize>,
    // Shape of the pooling input, used to size the gradient tensor.
    input_shape: Vec<usize>,
}
471
472impl GradientFunction for MaxPool2dBackward {
473 fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
474 let g_vec = grad_output.to_vec();
475 let input_size: usize = self.input_shape.iter().product();
476 let mut grad_input = vec![0.0f32; input_size];
477
478 for (i, &idx) in self.max_indices.iter().enumerate() {
479 if i < g_vec.len() {
480 grad_input[idx] += g_vec[i];
481 }
482 }
483
484 let gi = Tensor::from_vec(grad_input, &self.input_shape).unwrap();
485 vec![Some(gi)]
486 }
487
488 fn name(&self) -> &'static str {
489 "MaxPool2dBackward"
490 }
491
492 fn next_functions(&self) -> &[Option<GradFn>] {
493 &self.next_fns
494 }
495
496 fn as_any(&self) -> &dyn Any {
497 self
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Constant-valued NCHW batch of two 28x28 single-channel images.
    fn mnist_batch() -> Variable {
        Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 28 * 28], &[2, 1, 28, 28]).unwrap(),
            false,
        )
    }

    #[test]
    fn test_lenet_creation() {
        assert!(!LeNet::new().parameters().is_empty());
    }

    #[test]
    fn test_lenet_forward() {
        let net = LeNet::new();
        let logits = net.forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_simple_cnn_mnist() {
        let net = SimpleCNN::for_mnist();
        let logits = net.forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_mnist() {
        let net = MLP::for_mnist();
        let batch = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 784], &[2, 784]).unwrap(),
            false,
        );
        assert_eq!(net.forward(&batch).data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_auto_flatten() {
        // A 4D batch must be flattened internally before fc1.
        let net = MLP::for_mnist();
        let logits = net.forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_parameter_count() {
        let total: usize = LeNet::new()
            .parameters()
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();

        // LeNet-5 has roughly 61k weights; accept a generous band.
        assert!(total > 40000 && total < 100000);
    }
}