Skip to main content

axonml_vision/models/
lenet.rs

//! `LeNet` - Classic CNN Architecture
//!
//! Implementation of LeNet-5, one of the earliest successful CNNs.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
7
8use axonml_autograd::Variable;
9use axonml_nn::{Conv2d, Linear, Module, Parameter};
10use axonml_tensor::Tensor;
11
12// =============================================================================
13// LeNet-5
14// =============================================================================
15
16/// LeNet-5 architecture for MNIST digit classification.
17///
18/// Architecture:
19/// - Conv2d(1, 6, 5) -> `ReLU` -> MaxPool2d(2)
20/// - Conv2d(6, 16, 5) -> `ReLU` -> MaxPool2d(2)
21/// - Flatten
22/// - Linear(256, 120) -> `ReLU`
23/// - Linear(120, 84) -> `ReLU`
24/// - Linear(84, 10)
25pub struct LeNet {
26    conv1: Conv2d,
27    conv2: Conv2d,
28    fc1: Linear,
29    fc2: Linear,
30    fc3: Linear,
31}
32
33impl LeNet {
34    /// Creates a new LeNet-5 for MNIST (28x28 input, 10 classes).
35    #[must_use] pub fn new() -> Self {
36        Self {
37            conv1: Conv2d::new(1, 6, 5),       // 28x28 -> 24x24
38            conv2: Conv2d::new(6, 16, 5),      // 12x12 -> 8x8 (after pool)
39            fc1: Linear::new(16 * 4 * 4, 120), // After 2 pools: 8x8 -> 4x4
40            fc2: Linear::new(120, 84),
41            fc3: Linear::new(84, 10),
42        }
43    }
44
45    /// Creates a `LeNet` for CIFAR-10 (32x32 input, 10 classes).
46    #[must_use] pub fn for_cifar10() -> Self {
47        Self {
48            conv1: Conv2d::new(3, 6, 5),       // 32x32 -> 28x28
49            conv2: Conv2d::new(6, 16, 5),      // 14x14 -> 10x10 (after pool)
50            fc1: Linear::new(16 * 5 * 5, 120), // After 2 pools: 10x10 -> 5x5
51            fc2: Linear::new(120, 84),
52            fc3: Linear::new(84, 10),
53        }
54    }
55
56    /// Max pooling 2x2 operation.
57    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
58        let data = input.data();
59        let shape = data.shape();
60
61        if shape.len() == 4 {
62            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
63            let out_h = h / kernel_size;
64            let out_w = w / kernel_size;
65
66            let data_vec = data.to_vec();
67            let mut result = vec![0.0f32; n * c * out_h * out_w];
68
69            for batch in 0..n {
70                for ch in 0..c {
71                    for oh in 0..out_h {
72                        for ow in 0..out_w {
73                            let mut max_val = f32::NEG_INFINITY;
74                            for kh in 0..kernel_size {
75                                for kw in 0..kernel_size {
76                                    let ih = oh * kernel_size + kh;
77                                    let iw = ow * kernel_size + kw;
78                                    let idx = batch * c * h * w + ch * h * w + ih * w + iw;
79                                    max_val = max_val.max(data_vec[idx]);
80                                }
81                            }
82                            let out_idx =
83                                batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
84                            result[out_idx] = max_val;
85                        }
86                    }
87                }
88            }
89
90            Variable::new(
91                Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
92                input.requires_grad(),
93            )
94        } else if shape.len() == 3 {
95            // Single image without batch
96            let (c, h, w) = (shape[0], shape[1], shape[2]);
97            let out_h = h / kernel_size;
98            let out_w = w / kernel_size;
99
100            let data_vec = data.to_vec();
101            let mut result = vec![0.0f32; c * out_h * out_w];
102
103            for ch in 0..c {
104                for oh in 0..out_h {
105                    for ow in 0..out_w {
106                        let mut max_val = f32::NEG_INFINITY;
107                        for kh in 0..kernel_size {
108                            for kw in 0..kernel_size {
109                                let ih = oh * kernel_size + kh;
110                                let iw = ow * kernel_size + kw;
111                                let idx = ch * h * w + ih * w + iw;
112                                max_val = max_val.max(data_vec[idx]);
113                            }
114                        }
115                        let out_idx = ch * out_h * out_w + oh * out_w + ow;
116                        result[out_idx] = max_val;
117                    }
118                }
119            }
120
121            Variable::new(
122                Tensor::from_vec(result, &[c, out_h, out_w]).unwrap(),
123                input.requires_grad(),
124            )
125        } else {
126            input.clone()
127        }
128    }
129
130    /// Flattens a tensor to 2D (batch, features).
131    fn flatten(&self, input: &Variable) -> Variable {
132        let data = input.data();
133        let shape = data.shape();
134
135        if shape.len() <= 2 {
136            return input.clone();
137        }
138
139        let batch_size = shape[0];
140        let features: usize = shape[1..].iter().product();
141
142        Variable::new(
143            Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
144            input.requires_grad(),
145        )
146    }
147}
148
149impl Default for LeNet {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl Module for LeNet {
156    fn forward(&self, input: &Variable) -> Variable {
157        // Conv1 -> ReLU -> Pool
158        let x = self.conv1.forward(input);
159        let x = x.relu();
160        let x = self.max_pool2d(&x, 2);
161
162        // Conv2 -> ReLU -> Pool
163        let x = self.conv2.forward(&x);
164        let x = x.relu();
165        let x = self.max_pool2d(&x, 2);
166
167        // Flatten
168        let x = self.flatten(&x);
169
170        // FC layers
171        let x = self.fc1.forward(&x);
172        let x = x.relu();
173        let x = self.fc2.forward(&x);
174        let x = x.relu();
175        self.fc3.forward(&x)
176    }
177
178    fn parameters(&self) -> Vec<Parameter> {
179        let mut params = Vec::new();
180        params.extend(self.conv1.parameters());
181        params.extend(self.conv2.parameters());
182        params.extend(self.fc1.parameters());
183        params.extend(self.fc2.parameters());
184        params.extend(self.fc3.parameters());
185        params
186    }
187
188    fn train(&mut self) {
189        // LeNet has no training-mode-specific behavior
190    }
191
192    fn eval(&mut self) {
193        // LeNet has no eval-mode-specific behavior
194    }
195}
196
197// =============================================================================
198// SimpleCNN
199// =============================================================================
200
201/// A simple CNN for quick experiments.
202pub struct SimpleCNN {
203    conv1: Conv2d,
204    fc1: Linear,
205    fc2: Linear,
206    input_channels: usize,
207    num_classes: usize,
208}
209
210impl SimpleCNN {
211    /// Creates a new `SimpleCNN`.
212    /// Note: Conv2d with kernel 3 and no padding: 28-3+1=26, after pool: 13
213    #[must_use] pub fn new(input_channels: usize, num_classes: usize) -> Self {
214        Self {
215            conv1: Conv2d::new(input_channels, 32, 3),
216            fc1: Linear::new(32 * 13 * 13, 128), // 28x28 -> 26x26 (conv) -> 13x13 (pool)
217            fc2: Linear::new(128, num_classes),
218            input_channels,
219            num_classes,
220        }
221    }
222
223    /// Creates a `SimpleCNN` for MNIST.
224    #[must_use] pub fn for_mnist() -> Self {
225        Self::new(1, 10)
226    }
227
228    /// Creates a `SimpleCNN` for CIFAR-10.
229    #[must_use] pub fn for_cifar10() -> Self {
230        // 32x32 -> 30x30 (conv with k=3) -> 15x15 (pool)
231        Self {
232            conv1: Conv2d::new(3, 32, 3),
233            fc1: Linear::new(32 * 15 * 15, 128),
234            fc2: Linear::new(128, 10),
235            input_channels: 3,
236            num_classes: 10,
237        }
238    }
239
240    /// Returns the number of input channels.
241    #[must_use] pub fn input_channels(&self) -> usize {
242        self.input_channels
243    }
244
245    /// Returns the number of classes.
246    #[must_use] pub fn num_classes(&self) -> usize {
247        self.num_classes
248    }
249
250    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
251        let data = input.data();
252        let shape = data.shape();
253
254        if shape.len() != 4 {
255            return input.clone();
256        }
257
258        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
259        let out_h = h / kernel_size;
260        let out_w = w / kernel_size;
261
262        let data_vec = data.to_vec();
263        let mut result = vec![0.0f32; n * c * out_h * out_w];
264
265        for batch in 0..n {
266            for ch in 0..c {
267                for oh in 0..out_h {
268                    for ow in 0..out_w {
269                        let mut max_val = f32::NEG_INFINITY;
270                        for kh in 0..kernel_size {
271                            for kw in 0..kernel_size {
272                                let ih = oh * kernel_size + kh;
273                                let iw = ow * kernel_size + kw;
274                                let idx = batch * c * h * w + ch * h * w + ih * w + iw;
275                                max_val = max_val.max(data_vec[idx]);
276                            }
277                        }
278                        let out_idx =
279                            batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
280                        result[out_idx] = max_val;
281                    }
282                }
283            }
284        }
285
286        Variable::new(
287            Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
288            input.requires_grad(),
289        )
290    }
291
292    fn flatten(&self, input: &Variable) -> Variable {
293        let data = input.data();
294        let shape = data.shape();
295
296        if shape.len() <= 2 {
297            return input.clone();
298        }
299
300        let batch_size = shape[0];
301        let features: usize = shape[1..].iter().product();
302
303        Variable::new(
304            Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
305            input.requires_grad(),
306        )
307    }
308}
309
310impl Module for SimpleCNN {
311    fn forward(&self, input: &Variable) -> Variable {
312        let x = self.conv1.forward(input);
313        let x = x.relu();
314        let x = self.max_pool2d(&x, 2);
315        let x = self.flatten(&x);
316        let x = self.fc1.forward(&x);
317        let x = x.relu();
318        self.fc2.forward(&x)
319    }
320
321    fn parameters(&self) -> Vec<Parameter> {
322        let mut params = Vec::new();
323        params.extend(self.conv1.parameters());
324        params.extend(self.fc1.parameters());
325        params.extend(self.fc2.parameters());
326        params
327    }
328
329    fn train(&mut self) {}
330    fn eval(&mut self) {}
331}
332
333// =============================================================================
334// MLP for classification
335// =============================================================================
336
337/// A simple MLP for classification (flattened input).
338pub struct MLP {
339    fc1: Linear,
340    fc2: Linear,
341    fc3: Linear,
342}
343
344impl MLP {
345    /// Creates a new MLP.
346    #[must_use] pub fn new(input_size: usize, hidden_size: usize, num_classes: usize) -> Self {
347        Self {
348            fc1: Linear::new(input_size, hidden_size),
349            fc2: Linear::new(hidden_size, hidden_size / 2),
350            fc3: Linear::new(hidden_size / 2, num_classes),
351        }
352    }
353
354    /// Creates an MLP for MNIST (784 -> 256 -> 128 -> 10).
355    #[must_use] pub fn for_mnist() -> Self {
356        Self::new(784, 256, 10)
357    }
358
359    /// Creates an MLP for CIFAR-10 (3072 -> 512 -> 256 -> 10).
360    #[must_use] pub fn for_cifar10() -> Self {
361        Self::new(3072, 512, 10)
362    }
363}
364
365impl Module for MLP {
366    fn forward(&self, input: &Variable) -> Variable {
367        // Flatten if needed
368        let data = input.data();
369        let shape = data.shape();
370        let x = if shape.len() > 2 {
371            let batch = shape[0];
372            let features: usize = shape[1..].iter().product();
373            Variable::new(
374                Tensor::from_vec(data.to_vec(), &[batch, features]).unwrap(),
375                input.requires_grad(),
376            )
377        } else if shape.len() == 1 {
378            // Add batch dimension
379            Variable::new(
380                Tensor::from_vec(data.to_vec(), &[1, shape[0]]).unwrap(),
381                input.requires_grad(),
382            )
383        } else {
384            input.clone()
385        };
386
387        let x = self.fc1.forward(&x);
388        let x = x.relu();
389        let x = self.fc2.forward(&x);
390        let x = x.relu();
391        self.fc3.forward(&x)
392    }
393
394    fn parameters(&self) -> Vec<Parameter> {
395        let mut params = Vec::new();
396        params.extend(self.fc1.parameters());
397        params.extend(self.fc2.parameters());
398        params.extend(self.fc3.parameters());
399        params
400    }
401
402    fn train(&mut self) {}
403    fn eval(&mut self) {}
404}
405
406// =============================================================================
407// Tests
408// =============================================================================
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-valued batch of `batch` single-channel 28x28 images
    /// that does not track gradients.
    fn mnist_images(batch: usize) -> Variable {
        Variable::new(
            Tensor::from_vec(vec![0.5; batch * 28 * 28], &[batch, 1, 28, 28]).unwrap(),
            false,
        )
    }

    #[test]
    fn test_lenet_creation() {
        // Two conv layers plus three dense layers must expose parameters.
        let params = LeNet::new().parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_lenet_forward() {
        // A batch of 2 MNIST images yields one 10-way score row per image.
        let output = LeNet::new().forward(&mnist_images(2));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_simple_cnn_mnist() {
        let output = SimpleCNN::for_mnist().forward(&mnist_images(2));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_mnist() {
        // Already-flattened MNIST input.
        let flat = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 784], &[2, 784]).unwrap(),
            false,
        );

        let output = MLP::for_mnist().forward(&flat);
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_auto_flatten() {
        // 4D image-shaped input must be flattened by the model itself.
        let output = MLP::for_mnist().forward(&mnist_images(2));
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_parameter_count() {
        let mut total = 0usize;
        for param in &LeNet::new().parameters() {
            total += param.variable().data().to_vec().len();
        }

        // LeNet-5 should have around 44k parameters for MNIST
        assert!(total > 40000 && total < 100000);
    }
}