use axonml::prelude::*;
use std::time::Instant;

30fn main() {
35 println!("=== AxonML - MNIST Training (LeNet) ===\n");
36
37 #[cfg(feature = "cuda")]
43 let device = {
44 let cuda = Device::Cuda(0);
45 if cuda.is_available() {
46 println!("GPU detected: using CUDA device 0");
47 cuda
48 } else {
49 println!("CUDA feature enabled but no GPU available, using CPU");
50 Device::Cpu
51 }
52 };
53 #[cfg(not(feature = "cuda"))]
54 let device = {
55 println!("Using CPU (compile with --features cuda for GPU)");
56 Device::Cpu
57 };
58
59 let num_train = 2000;
65 let num_test = 400;
66 println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
67 let train_dataset = SyntheticMNIST::new(num_train);
68 let test_dataset = SyntheticMNIST::new(num_test);
69
70 let batch_size = 64;
72 println!("2. Creating DataLoader (batch_size={batch_size})...");
73 let train_loader = DataLoader::new(train_dataset, batch_size);
74 let test_loader = DataLoader::new(test_dataset, batch_size);
75 println!(" Training batches: {}", train_loader.len());
76
77 println!("3. Creating LeNet model...");
83 let model = LeNet::new();
84 model.to_device(device);
85 let params = model.parameters();
86 let total_params: usize = params
87 .iter()
88 .map(|p| p.variable().data().to_vec().len())
89 .sum();
90 println!(
91 " Parameters: {} ({} total weights)",
92 params.len(),
93 total_params
94 );
95 println!(" Device: {:?}", device);
96
97 println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
99 let mut optimizer = Adam::new(params, 0.001);
100 let criterion = CrossEntropyLoss::new();
101
102 let epochs = 10;
108 println!("5. Training for {epochs} epochs...\n");
109
110 let train_start = Instant::now();
111
112 for epoch in 0..epochs {
113 let epoch_start = Instant::now();
114 let mut total_loss = 0.0;
115 let mut correct = 0usize;
116 let mut total = 0usize;
117 let mut batch_count = 0;
118
119 for batch in train_loader.iter() {
120 let bs = batch.data.shape()[0];
121
122 let input_data = batch.data.to_vec();
124 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
125 let input = Variable::new(
126 if device.is_gpu() {
127 input_tensor.to_device(device).unwrap()
128 } else {
129 input_tensor
130 },
131 true,
132 );
133
134 let target_onehot = batch.targets.to_vec();
136 let mut target_indices = vec![0.0f32; bs];
137 for i in 0..bs {
138 let offset = i * 10;
139 let mut max_idx = 0;
140 let mut max_val = f32::NEG_INFINITY;
141 for c in 0..10 {
142 if target_onehot[offset + c] > max_val {
143 max_val = target_onehot[offset + c];
144 max_idx = c;
145 }
146 }
147 target_indices[i] = max_idx as f32;
148 }
149 let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
150 let target = Variable::new(
151 if device.is_gpu() {
152 target_tensor.to_device(device).unwrap()
153 } else {
154 target_tensor
155 },
156 false,
157 );
158
159 let output = model.forward(&input);
161
162 let loss = criterion.compute(&output, &target);
164
165 let loss_val = loss.data().to_vec()[0];
166 total_loss += loss_val;
167 batch_count += 1;
168
169 let out_data = output.data().to_vec();
171 for i in 0..bs {
172 let offset = i * 10;
173 let mut pred = 0;
174 let mut pred_val = f32::NEG_INFINITY;
175 for c in 0..10 {
176 if out_data[offset + c] > pred_val {
177 pred_val = out_data[offset + c];
178 pred = c;
179 }
180 }
181 if pred == target_indices[i] as usize {
182 correct += 1;
183 }
184 total += 1;
185 }
186
187 loss.backward();
189
190 optimizer.step();
192 optimizer.zero_grad();
193 }
194
195 let epoch_time = epoch_start.elapsed();
196 let avg_loss = total_loss / batch_count as f32;
197 let accuracy = 100.0 * correct as f32 / total as f32;
198 let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
199
200 println!(
201 " Epoch {:2}/{}: Loss={:.4} Acc={:.1}% ({:.0} samples/s, {:.2}s)",
202 epoch + 1,
203 epochs,
204 avg_loss,
205 accuracy,
206 samples_per_sec,
207 epoch_time.as_secs_f64(),
208 );
209 }
210
211 let train_time = train_start.elapsed();
212 println!("\n Total training time: {:.2}s", train_time.as_secs_f64());
213
214 println!("\n6. Evaluating on test set...");
220
221 let (correct, total) = no_grad(|| {
223 let mut correct = 0usize;
224 let mut total = 0usize;
225
226 for batch in test_loader.iter() {
227 let bs = batch.data.shape()[0];
228
229 let input_data = batch.data.to_vec();
230 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
231 let input = Variable::new(
232 if device.is_gpu() {
233 input_tensor.to_device(device).unwrap()
234 } else {
235 input_tensor
236 },
237 false,
238 );
239
240 let target_onehot = batch.targets.to_vec();
241 let output = model.forward(&input);
242 let out_data = output.data().to_vec();
243
244 for i in 0..bs {
245 let offset = i * 10;
247 let mut pred = 0;
248 let mut pred_val = f32::NEG_INFINITY;
249 for c in 0..10 {
250 if out_data[offset + c] > pred_val {
251 pred_val = out_data[offset + c];
252 pred = c;
253 }
254 }
255
256 let mut true_label = 0;
258 let mut true_val = f32::NEG_INFINITY;
259 for c in 0..10 {
260 if target_onehot[i * 10 + c] > true_val {
261 true_val = target_onehot[i * 10 + c];
262 true_label = c;
263 }
264 }
265
266 if pred == true_label {
267 correct += 1;
268 }
269 total += 1;
270 }
271 }
272
273 (correct, total)
274 });
275
276 let test_accuracy = 100.0 * correct as f32 / total as f32;
277 println!(
278 " Test Accuracy: {}/{} ({:.2}%)",
279 correct, total, test_accuracy
280 );
281
282 println!("\n=== Training Complete! ===");
283 println!(" Device: {:?}", device);
284 println!(" Final test accuracy: {:.2}%", test_accuracy);
285}