Skip to main content

axonml_vision/models/
lenet.rs

1//! LeNet-5, SimpleCNN, MLP — Classic Classification Architectures
2//!
3//! Three lightweight classification models. `LeNet` implements LeNet-5: two
4//! Conv2d layers with ReLU + MaxPool2d, then three Linear layers (256->120->84->10).
5//! Includes differentiable `MaxPool2dBackward` gradient function. `SimpleCNN` is
6//! a minimal single-conv model (Conv2d->ReLU->MaxPool->Linear->ReLU->Linear).
7//! `MLP` is a three-layer fully-connected network with auto-flattening for image
8//! inputs. All implement `Module` with factory methods for MNIST (1ch, 28x28) and
9//! CIFAR-10 (3ch, 32x32) configurations.
10//!
11//! # File
12//! `crates/axonml-vision/src/models/lenet.rs`
13//!
14//! # Author
15//! Andrew Jewell Sr. — AutomataNexus LLC
16//! ORCID: 0009-0005-2158-7060
17//!
18//! # Updated
19//! April 16, 2026 11:15 PM EST
20//!
21//! # Disclaimer
22//! Use at own risk. This software is provided "as is", without warranty of any
23//! kind, express or implied. The author and AutomataNexus shall not be held
24//! liable for any damages arising from the use of this software.
25
26use std::any::Any;
27
28use axonml_autograd::no_grad::is_grad_enabled;
29use axonml_autograd::{GradFn, GradientFunction, Variable};
30use axonml_nn::{Conv2d, Linear, Module, Parameter};
31use axonml_tensor::Tensor;
32
33// =============================================================================
34// LeNet-5
35// =============================================================================
36
/// LeNet-5 architecture for MNIST digit classification.
///
/// Architecture:
/// - Conv2d(1, 6, 5) -> `ReLU` -> MaxPool2d(2)
/// - Conv2d(6, 16, 5) -> `ReLU` -> MaxPool2d(2)
/// - Flatten
/// - Linear(256, 120) -> `ReLU`
/// - Linear(120, 84) -> `ReLU`
/// - Linear(84, 10)
///
/// The layer sizes above are the MNIST configuration built by [`LeNet::new`].
/// [`LeNet::for_cifar10`] uses 3 input channels and a `Linear(400, 120)` first
/// fully-connected layer (32x32 input -> 5x5 feature maps after the second pool).
/// Pooling is a non-overlapping 2x2 max pool (stride equals the kernel size).
pub struct LeNet {
    /// First convolution (`in_channels -> 6`, 5x5 kernel).
    conv1: Conv2d,
    /// Second convolution (`6 -> 16`, 5x5 kernel).
    conv2: Conv2d,
    /// Flattened conv features -> 120.
    fc1: Linear,
    /// 120 -> 84.
    fc2: Linear,
    /// 84 -> 10 class logits.
    fc3: Linear,
}
53
54impl LeNet {
55    /// Creates a new LeNet-5 for MNIST (28x28 input, 10 classes).
56    #[must_use]
57    pub fn new() -> Self {
58        Self {
59            conv1: Conv2d::new(1, 6, 5),       // 28x28 -> 24x24
60            conv2: Conv2d::new(6, 16, 5),      // 12x12 -> 8x8 (after pool)
61            fc1: Linear::new(16 * 4 * 4, 120), // After 2 pools: 8x8 -> 4x4
62            fc2: Linear::new(120, 84),
63            fc3: Linear::new(84, 10),
64        }
65    }
66
67    /// Creates a `LeNet` for CIFAR-10 (32x32 input, 10 classes).
68    #[must_use]
69    pub fn for_cifar10() -> Self {
70        Self {
71            conv1: Conv2d::new(3, 6, 5),       // 32x32 -> 28x28
72            conv2: Conv2d::new(6, 16, 5),      // 14x14 -> 10x10 (after pool)
73            fc1: Linear::new(16 * 5 * 5, 120), // After 2 pools: 10x10 -> 5x5
74            fc2: Linear::new(120, 84),
75            fc3: Linear::new(84, 10),
76        }
77    }
78
79    /// Max pooling 2x2 operation.
80    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
81        let data = input.data();
82        let shape = data.shape();
83
84        if shape.len() == 4 {
85            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
86            let out_h = h / kernel_size;
87            let out_w = w / kernel_size;
88
89            let data_vec = data.to_vec();
90            let out_size = n * c * out_h * out_w;
91            let mut result = vec![0.0f32; out_size];
92            let mut max_indices = vec![0usize; out_size];
93
94            for batch in 0..n {
95                for ch in 0..c {
96                    for oh in 0..out_h {
97                        for ow in 0..out_w {
98                            let mut max_val = f32::NEG_INFINITY;
99                            let mut max_idx = 0usize;
100                            for kh in 0..kernel_size {
101                                for kw in 0..kernel_size {
102                                    let ih = oh * kernel_size + kh;
103                                    let iw = ow * kernel_size + kw;
104                                    let idx = batch * c * h * w + ch * h * w + ih * w + iw;
105                                    if data_vec[idx] > max_val {
106                                        max_val = data_vec[idx];
107                                        max_idx = idx;
108                                    }
109                                }
110                            }
111                            let out_idx =
112                                batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
113                            result[out_idx] = max_val;
114                            max_indices[out_idx] = max_idx;
115                        }
116                    }
117                }
118            }
119
120            let output_tensor = Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap();
121            if input.requires_grad() && is_grad_enabled() {
122                let grad_fn = GradFn::new(MaxPool2dBackward {
123                    next_fns: vec![input.grad_fn().cloned()],
124                    max_indices,
125                    input_shape: shape.to_vec(),
126                });
127                Variable::from_operation(output_tensor, grad_fn, true)
128            } else {
129                Variable::new(output_tensor, false)
130            }
131        } else if shape.len() == 3 {
132            // Single image without batch
133            let (c, h, w) = (shape[0], shape[1], shape[2]);
134            let out_h = h / kernel_size;
135            let out_w = w / kernel_size;
136
137            let data_vec = data.to_vec();
138            let out_size = c * out_h * out_w;
139            let mut result = vec![0.0f32; out_size];
140            let mut max_indices = vec![0usize; out_size];
141
142            for ch in 0..c {
143                for oh in 0..out_h {
144                    for ow in 0..out_w {
145                        let mut max_val = f32::NEG_INFINITY;
146                        let mut max_idx = 0usize;
147                        for kh in 0..kernel_size {
148                            for kw in 0..kernel_size {
149                                let ih = oh * kernel_size + kh;
150                                let iw = ow * kernel_size + kw;
151                                let idx = ch * h * w + ih * w + iw;
152                                if data_vec[idx] > max_val {
153                                    max_val = data_vec[idx];
154                                    max_idx = idx;
155                                }
156                            }
157                        }
158                        let out_idx = ch * out_h * out_w + oh * out_w + ow;
159                        result[out_idx] = max_val;
160                        max_indices[out_idx] = max_idx;
161                    }
162                }
163            }
164
165            let output_tensor = Tensor::from_vec(result, &[c, out_h, out_w]).unwrap();
166            if input.requires_grad() && is_grad_enabled() {
167                let grad_fn = GradFn::new(MaxPool2dBackward {
168                    next_fns: vec![input.grad_fn().cloned()],
169                    max_indices,
170                    input_shape: shape.to_vec(),
171                });
172                Variable::from_operation(output_tensor, grad_fn, true)
173            } else {
174                Variable::new(output_tensor, false)
175            }
176        } else {
177            input.clone()
178        }
179    }
180
181    /// Flattens a tensor to 2D (batch, features).
182    /// Uses Variable::reshape() to preserve the autograd graph.
183    fn flatten(&self, input: &Variable) -> Variable {
184        let shape = input.shape();
185
186        if shape.len() <= 2 {
187            return input.clone();
188        }
189
190        let batch_size = shape[0];
191        let features: usize = shape[1..].iter().product();
192
193        input.reshape(&[batch_size, features])
194    }
195}
196
197impl Default for LeNet {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl Module for LeNet {
204    fn forward(&self, input: &Variable) -> Variable {
205        // Conv1 -> ReLU -> Pool
206        let x = self.conv1.forward(input);
207        let x = x.relu();
208        let x = self.max_pool2d(&x, 2);
209
210        // Conv2 -> ReLU -> Pool
211        let x = self.conv2.forward(&x);
212        let x = x.relu();
213        let x = self.max_pool2d(&x, 2);
214
215        // Flatten
216        let x = self.flatten(&x);
217
218        // FC layers
219        let x = self.fc1.forward(&x);
220        let x = x.relu();
221        let x = self.fc2.forward(&x);
222        let x = x.relu();
223        self.fc3.forward(&x)
224    }
225
226    fn parameters(&self) -> Vec<Parameter> {
227        let mut params = Vec::new();
228        params.extend(self.conv1.parameters());
229        params.extend(self.conv2.parameters());
230        params.extend(self.fc1.parameters());
231        params.extend(self.fc2.parameters());
232        params.extend(self.fc3.parameters());
233        params
234    }
235
236    fn train(&mut self) {
237        // LeNet has no training-mode-specific behavior
238    }
239
240    fn eval(&mut self) {
241        // LeNet has no eval-mode-specific behavior
242    }
243}
244
245// =============================================================================
246// SimpleCNN
247// =============================================================================
248
/// A simple CNN for quick experiments.
///
/// Architecture: Conv2d(in, 32, 3) -> `ReLU` -> MaxPool2d(2) -> Flatten ->
/// Linear(-> 128) -> `ReLU` -> Linear(128, num_classes).
///
/// The first linear layer's input width is fixed at construction time from the
/// expected image size: `32 * 13 * 13` for 28x28 inputs ([`SimpleCNN::new`],
/// [`SimpleCNN::for_mnist`]) and `32 * 15 * 15` for 32x32 inputs
/// ([`SimpleCNN::for_cifar10`]).
pub struct SimpleCNN {
    /// `input_channels -> 32`, 3x3 kernel.
    conv1: Conv2d,
    /// Flattened pooled features -> 128.
    fc1: Linear,
    /// 128 -> `num_classes` logits.
    fc2: Linear,
    /// Channel count the model was built for (reported by `input_channels()`).
    input_channels: usize,
    /// Number of output classes (reported by `num_classes()`).
    num_classes: usize,
}
257
258impl SimpleCNN {
259    /// Creates a new `SimpleCNN`.
260    /// Note: Conv2d with kernel 3 and no padding: 28-3+1=26, after pool: 13
261    #[must_use]
262    pub fn new(input_channels: usize, num_classes: usize) -> Self {
263        Self {
264            conv1: Conv2d::new(input_channels, 32, 3),
265            fc1: Linear::new(32 * 13 * 13, 128), // 28x28 -> 26x26 (conv) -> 13x13 (pool)
266            fc2: Linear::new(128, num_classes),
267            input_channels,
268            num_classes,
269        }
270    }
271
272    /// Creates a `SimpleCNN` for MNIST.
273    #[must_use]
274    pub fn for_mnist() -> Self {
275        Self::new(1, 10)
276    }
277
278    /// Creates a `SimpleCNN` for CIFAR-10.
279    #[must_use]
280    pub fn for_cifar10() -> Self {
281        // 32x32 -> 30x30 (conv with k=3) -> 15x15 (pool)
282        Self {
283            conv1: Conv2d::new(3, 32, 3),
284            fc1: Linear::new(32 * 15 * 15, 128),
285            fc2: Linear::new(128, 10),
286            input_channels: 3,
287            num_classes: 10,
288        }
289    }
290
291    /// Returns the number of input channels.
292    #[must_use]
293    pub fn input_channels(&self) -> usize {
294        self.input_channels
295    }
296
297    /// Returns the number of classes.
298    #[must_use]
299    pub fn num_classes(&self) -> usize {
300        self.num_classes
301    }
302
303    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
304        let data = input.data();
305        let shape = data.shape();
306
307        if shape.len() != 4 {
308            return input.clone();
309        }
310
311        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
312        let out_h = h / kernel_size;
313        let out_w = w / kernel_size;
314
315        let data_vec = data.to_vec();
316        let mut result = vec![0.0f32; n * c * out_h * out_w];
317
318        for batch in 0..n {
319            for ch in 0..c {
320                for oh in 0..out_h {
321                    for ow in 0..out_w {
322                        let mut max_val = f32::NEG_INFINITY;
323                        for kh in 0..kernel_size {
324                            for kw in 0..kernel_size {
325                                let ih = oh * kernel_size + kh;
326                                let iw = ow * kernel_size + kw;
327                                let idx = batch * c * h * w + ch * h * w + ih * w + iw;
328                                max_val = max_val.max(data_vec[idx]);
329                            }
330                        }
331                        let out_idx =
332                            batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
333                        result[out_idx] = max_val;
334                    }
335                }
336            }
337        }
338
339        Variable::new(
340            Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
341            input.requires_grad(),
342        )
343    }
344
345    fn flatten(&self, input: &Variable) -> Variable {
346        let shape = input.shape();
347
348        if shape.len() <= 2 {
349            return input.clone();
350        }
351
352        let batch_size = shape[0];
353        let features: usize = shape[1..].iter().product();
354
355        input.reshape(&[batch_size, features])
356    }
357}
358
359impl Module for SimpleCNN {
360    fn forward(&self, input: &Variable) -> Variable {
361        let x = self.conv1.forward(input);
362        let x = x.relu();
363        let x = self.max_pool2d(&x, 2);
364        let x = self.flatten(&x);
365        let x = self.fc1.forward(&x);
366        let x = x.relu();
367        self.fc2.forward(&x)
368    }
369
370    fn parameters(&self) -> Vec<Parameter> {
371        let mut params = Vec::new();
372        params.extend(self.conv1.parameters());
373        params.extend(self.fc1.parameters());
374        params.extend(self.fc2.parameters());
375        params
376    }
377
378    fn train(&mut self) {}
379    fn eval(&mut self) {}
380}
381
382// =============================================================================
383// MLP for classification
384// =============================================================================
385
/// A simple MLP for classification (flattened input).
///
/// Layout: `input_size -> hidden_size -> hidden_size / 2 -> num_classes`, with
/// `ReLU` after the first two layers. Inputs of rank > 2 are flattened to
/// `(batch, features)` by `forward`, and a rank-1 input is treated as a
/// single-sample batch.
pub struct MLP {
    /// `input_size -> hidden_size`.
    fc1: Linear,
    /// `hidden_size -> hidden_size / 2`.
    fc2: Linear,
    /// `hidden_size / 2 -> num_classes` logits.
    fc3: Linear,
}
392
393impl MLP {
394    /// Creates a new MLP.
395    #[must_use]
396    pub fn new(input_size: usize, hidden_size: usize, num_classes: usize) -> Self {
397        Self {
398            fc1: Linear::new(input_size, hidden_size),
399            fc2: Linear::new(hidden_size, hidden_size / 2),
400            fc3: Linear::new(hidden_size / 2, num_classes),
401        }
402    }
403
404    /// Creates an MLP for MNIST (784 -> 256 -> 128 -> 10).
405    #[must_use]
406    pub fn for_mnist() -> Self {
407        Self::new(784, 256, 10)
408    }
409
410    /// Creates an MLP for CIFAR-10 (3072 -> 512 -> 256 -> 10).
411    #[must_use]
412    pub fn for_cifar10() -> Self {
413        Self::new(3072, 512, 10)
414    }
415}
416
417impl Module for MLP {
418    fn forward(&self, input: &Variable) -> Variable {
419        // Flatten if needed
420        let data = input.data();
421        let shape = data.shape();
422        let x = if shape.len() > 2 {
423            let batch = shape[0];
424            let features: usize = shape[1..].iter().product();
425            Variable::new(
426                Tensor::from_vec(data.to_vec(), &[batch, features]).unwrap(),
427                input.requires_grad(),
428            )
429        } else if shape.len() == 1 {
430            // Add batch dimension
431            Variable::new(
432                Tensor::from_vec(data.to_vec(), &[1, shape[0]]).unwrap(),
433                input.requires_grad(),
434            )
435        } else {
436            input.clone()
437        };
438
439        let x = self.fc1.forward(&x);
440        let x = x.relu();
441        let x = self.fc2.forward(&x);
442        let x = x.relu();
443        self.fc3.forward(&x)
444    }
445
446    fn parameters(&self) -> Vec<Parameter> {
447        let mut params = Vec::new();
448        params.extend(self.fc1.parameters());
449        params.extend(self.fc2.parameters());
450        params.extend(self.fc3.parameters());
451        params
452    }
453
454    fn train(&mut self) {}
455    fn eval(&mut self) {}
456}
457
458// =============================================================================
459// MaxPool2dBackward
460// =============================================================================
461
/// Gradient function for MaxPool2d.
///
/// Backward pass routes gradient only to the max element in each pooling window.
/// Gradients for input elements that were never selected are zero.
#[derive(Debug)]
struct MaxPool2dBackward {
    /// Upstream graph node(s) — the pooled input's `grad_fn` (a single entry).
    next_fns: Vec<Option<GradFn>>,
    /// For output element `i`, the flat (row-major) index into the input of
    /// the element selected in its window.
    max_indices: Vec<usize>,
    /// Shape of the pooled input; the produced gradient has this shape.
    input_shape: Vec<usize>,
}
471
472impl GradientFunction for MaxPool2dBackward {
473    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
474        let g_vec = grad_output.to_vec();
475        let input_size: usize = self.input_shape.iter().product();
476        let mut grad_input = vec![0.0f32; input_size];
477
478        for (i, &idx) in self.max_indices.iter().enumerate() {
479            if i < g_vec.len() {
480                grad_input[idx] += g_vec[i];
481            }
482        }
483
484        let gi = Tensor::from_vec(grad_input, &self.input_shape).unwrap();
485        vec![Some(gi)]
486    }
487
488    fn name(&self) -> &'static str {
489        "MaxPool2dBackward"
490    }
491
492    fn next_functions(&self) -> &[Option<GradFn>] {
493        &self.next_fns
494    }
495
496    fn as_any(&self) -> &dyn Any {
497        self
498    }
499}
500
501// =============================================================================
502// Tests
503// =============================================================================
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an input filled with `0.5` that does not track gradients.
    fn constant_input(shape: &[usize]) -> Variable {
        let len: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![0.5; len], shape).unwrap(), false)
    }

    #[test]
    fn test_lenet_creation() {
        let model = LeNet::new();
        let params = model.parameters();

        // Should have parameters from 2 conv + 3 fc layers
        assert!(!params.is_empty());
    }

    #[test]
    fn test_lenet_forward() {
        let model = LeNet::new();

        // Batch of 2 MNIST images.
        let output = model.forward(&constant_input(&[2, 1, 28, 28]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_cifar10_forward() {
        let model = LeNet::for_cifar10();

        // Batch of 2 CIFAR-10 images: 32 -> 28 -> 14 -> 10 -> 5, so fc1 sees 400.
        let output = model.forward(&constant_input(&[2, 3, 32, 32]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_max_pool2d_selects_window_max() {
        let model = LeNet::new();

        // One 2x4 plane: rows [1, 5, 2, 0] and [3, 4, 8, 7].
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 5.0, 2.0, 0.0, 3.0, 4.0, 8.0, 7.0], &[1, 1, 2, 4])
                .unwrap(),
            false,
        );

        let pooled = model.max_pool2d(&input, 2);
        assert_eq!(pooled.data().shape(), &[1, 1, 1, 2]);
        assert_eq!(pooled.data().to_vec(), vec![5.0, 8.0]);
    }

    #[test]
    fn test_max_pool2d_backward_routes_gradient_to_argmax() {
        // Flat indices 1 and 6 are the argmax positions for the 2x4 plane above.
        let backward = MaxPool2dBackward {
            next_fns: vec![None],
            max_indices: vec![1, 6],
            input_shape: vec![1, 1, 2, 4],
        };
        let grad_output = Tensor::from_vec(vec![1.0f32, 2.0], &[1, 1, 1, 2]).unwrap();

        let grads = backward.apply(&grad_output);
        assert_eq!(grads.len(), 1);
        let grad_input = grads[0].as_ref().expect("input gradient");
        assert_eq!(grad_input.shape(), &[1, 1, 2, 4]);
        assert_eq!(
            grad_input.to_vec(),
            vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0]
        );
    }

    #[test]
    fn test_simple_cnn_mnist() {
        let model = SimpleCNN::for_mnist();

        let output = model.forward(&constant_input(&[2, 1, 28, 28]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_simple_cnn_cifar10() {
        let model = SimpleCNN::for_cifar10();
        assert_eq!(model.input_channels(), 3);
        assert_eq!(model.num_classes(), 10);

        let output = model.forward(&constant_input(&[2, 3, 32, 32]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_mnist() {
        let model = MLP::for_mnist();

        // Flattened MNIST input
        let output = model.forward(&constant_input(&[2, 784]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_auto_flatten() {
        let model = MLP::for_mnist();

        // 4D input (like image)
        let output = model.forward(&constant_input(&[2, 1, 28, 28]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_cifar10_auto_flatten() {
        let model = MLP::for_cifar10();

        let output = model.forward(&constant_input(&[2, 3, 32, 32]));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_unbatched_input() {
        let model = MLP::for_mnist();

        // A single flattened sample gains a batch dimension of 1.
        let output = model.forward(&constant_input(&[784]));
        assert_eq!(output.data().shape(), &[1, 10]);
    }

    #[test]
    fn test_lenet_parameter_count() {
        let model = LeNet::new();
        let params = model.parameters();

        // Count total parameters
        let total: usize = params
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();

        // LeNet-5 should have around 44k parameters for MNIST (44,426 if every
        // conv/linear layer carries a bias — assumption, hence the loose range).
        assert!(total > 40000 && total < 100000);
    }
}