// Example: crates/axonml/examples/mnist_training.rs
//! MNIST Training Example - LeNet on SyntheticMNIST with GPU support
//!
//! # File
//! `crates/axonml/examples/mnist_training.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 19, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml::prelude::*;
use std::time::Instant;

20fn main() {
21    println!("=== AxonML - MNIST Training (LeNet) ===\n");
22
23    // Detect device
24    #[cfg(feature = "cuda")]
25    let device = {
26        let cuda = Device::Cuda(0);
27        if cuda.is_available() {
28            println!("GPU detected: using CUDA device 0");
29            cuda
30        } else {
31            println!("CUDA feature enabled but no GPU available, using CPU");
32            Device::Cpu
33        }
34    };
35    #[cfg(not(feature = "cuda"))]
36    let device = {
37        println!("Using CPU (compile with --features cuda for GPU)");
38        Device::Cpu
39    };
40
41    // 1. Create dataset
42    let num_train = 2000;
43    let num_test = 400;
44    println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
45    let train_dataset = SyntheticMNIST::new(num_train);
46    let test_dataset = SyntheticMNIST::new(num_test);
47
48    // 2. Create DataLoader
49    let batch_size = 64;
50    println!("2. Creating DataLoader (batch_size={batch_size})...");
51    let train_loader = DataLoader::new(train_dataset, batch_size);
52    let test_loader = DataLoader::new(test_dataset, batch_size);
53    println!("   Training batches: {}", train_loader.len());
54
55    // 3. Create LeNet model and move to device
56    println!("3. Creating LeNet model...");
57    let model = LeNet::new();
58    model.to_device(device);
59    let params = model.parameters();
60    let total_params: usize = params
61        .iter()
62        .map(|p| p.variable().data().to_vec().len())
63        .sum();
64    println!(
65        "   Parameters: {} ({} total weights)",
66        params.len(),
67        total_params
68    );
69    println!("   Device: {:?}", device);
70
71    // 4. Create optimizer and loss
72    println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
73    let mut optimizer = Adam::new(params, 0.001);
74    let criterion = CrossEntropyLoss::new();
75
76    // 5. Training loop
77    let epochs = 10;
78    println!("5. Training for {epochs} epochs...\n");
79
80    let train_start = Instant::now();
81
82    for epoch in 0..epochs {
83        let epoch_start = Instant::now();
84        let mut total_loss = 0.0;
85        let mut correct = 0usize;
86        let mut total = 0usize;
87        let mut batch_count = 0;
88
89        for batch in train_loader.iter() {
90            let bs = batch.data.shape()[0];
91
92            // Reshape to [N, 1, 28, 28] and create Variable
93            let input_data = batch.data.to_vec();
94            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
95            let input = Variable::new(
96                if device.is_gpu() {
97                    input_tensor.to_device(device).unwrap()
98                } else {
99                    input_tensor
100                },
101                true,
102            );
103
104            // Target: convert one-hot [N, 10] to class indices [N]
105            let target_onehot = batch.targets.to_vec();
106            let mut target_indices = vec![0.0f32; bs];
107            for i in 0..bs {
108                let offset = i * 10;
109                let mut max_idx = 0;
110                let mut max_val = f32::NEG_INFINITY;
111                for c in 0..10 {
112                    if target_onehot[offset + c] > max_val {
113                        max_val = target_onehot[offset + c];
114                        max_idx = c;
115                    }
116                }
117                target_indices[i] = max_idx as f32;
118            }
119            let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
120            let target = Variable::new(
121                if device.is_gpu() {
122                    target_tensor.to_device(device).unwrap()
123                } else {
124                    target_tensor
125                },
126                false,
127            );
128
129            // Forward pass
130            let output = model.forward(&input);
131
132            // Cross-entropy loss
133            let loss = criterion.compute(&output, &target);
134
135            let loss_val = loss.data().to_vec()[0];
136            total_loss += loss_val;
137            batch_count += 1;
138
139            // Compute training accuracy
140            let out_data = output.data().to_vec();
141            for i in 0..bs {
142                let offset = i * 10;
143                let mut pred = 0;
144                let mut pred_val = f32::NEG_INFINITY;
145                for c in 0..10 {
146                    if out_data[offset + c] > pred_val {
147                        pred_val = out_data[offset + c];
148                        pred = c;
149                    }
150                }
151                if pred == target_indices[i] as usize {
152                    correct += 1;
153                }
154                total += 1;
155            }
156
157            // Backward pass
158            loss.backward();
159
160            // Update weights
161            optimizer.step();
162            optimizer.zero_grad();
163        }
164
165        let epoch_time = epoch_start.elapsed();
166        let avg_loss = total_loss / batch_count as f32;
167        let accuracy = 100.0 * correct as f32 / total as f32;
168        let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
169
170        println!(
171            "   Epoch {:2}/{}: Loss={:.4}  Acc={:.1}%  ({:.0} samples/s, {:.2}s)",
172            epoch + 1,
173            epochs,
174            avg_loss,
175            accuracy,
176            samples_per_sec,
177            epoch_time.as_secs_f64(),
178        );
179    }
180
181    let train_time = train_start.elapsed();
182    println!("\n   Total training time: {:.2}s", train_time.as_secs_f64());
183
184    // 6. Test evaluation
185    println!("\n6. Evaluating on test set...");
186
187    // Disable gradient computation for evaluation
188    let (correct, total) = no_grad(|| {
189        let mut correct = 0usize;
190        let mut total = 0usize;
191
192        for batch in test_loader.iter() {
193            let bs = batch.data.shape()[0];
194
195            let input_data = batch.data.to_vec();
196            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
197            let input = Variable::new(
198                if device.is_gpu() {
199                    input_tensor.to_device(device).unwrap()
200                } else {
201                    input_tensor
202                },
203                false,
204            );
205
206            let target_onehot = batch.targets.to_vec();
207            let output = model.forward(&input);
208            let out_data = output.data().to_vec();
209
210            for i in 0..bs {
211                // Prediction: argmax of output
212                let offset = i * 10;
213                let mut pred = 0;
214                let mut pred_val = f32::NEG_INFINITY;
215                for c in 0..10 {
216                    if out_data[offset + c] > pred_val {
217                        pred_val = out_data[offset + c];
218                        pred = c;
219                    }
220                }
221
222                // True label: argmax of one-hot target
223                let mut true_label = 0;
224                let mut true_val = f32::NEG_INFINITY;
225                for c in 0..10 {
226                    if target_onehot[i * 10 + c] > true_val {
227                        true_val = target_onehot[i * 10 + c];
228                        true_label = c;
229                    }
230                }
231
232                if pred == true_label {
233                    correct += 1;
234                }
235                total += 1;
236            }
237        }
238
239        (correct, total)
240    });
241
242    let test_accuracy = 100.0 * correct as f32 / total as f32;
243    println!(
244        "   Test Accuracy: {}/{} ({:.2}%)",
245        correct, total, test_accuracy
246    );
247
248    println!("\n=== Training Complete! ===");
249    println!("   Device: {:?}", device);
250    println!("   Final test accuracy: {:.2}%", test_accuracy);
251}