// mnist_training/mnist_training.rs
use axonml::prelude::*;
use std::time::Instant;

20fn main() {
21 println!("=== AxonML - MNIST Training (LeNet) ===\n");
22
23 #[cfg(feature = "cuda")]
25 let device = {
26 let cuda = Device::Cuda(0);
27 if cuda.is_available() {
28 println!("GPU detected: using CUDA device 0");
29 cuda
30 } else {
31 println!("CUDA feature enabled but no GPU available, using CPU");
32 Device::Cpu
33 }
34 };
35 #[cfg(not(feature = "cuda"))]
36 let device = {
37 println!("Using CPU (compile with --features cuda for GPU)");
38 Device::Cpu
39 };
40
41 let num_train = 2000;
43 let num_test = 400;
44 println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
45 let train_dataset = SyntheticMNIST::new(num_train);
46 let test_dataset = SyntheticMNIST::new(num_test);
47
48 let batch_size = 64;
50 println!("2. Creating DataLoader (batch_size={batch_size})...");
51 let train_loader = DataLoader::new(train_dataset, batch_size);
52 let test_loader = DataLoader::new(test_dataset, batch_size);
53 println!(" Training batches: {}", train_loader.len());
54
55 println!("3. Creating LeNet model...");
57 let model = LeNet::new();
58 model.to_device(device);
59 let params = model.parameters();
60 let total_params: usize = params
61 .iter()
62 .map(|p| p.variable().data().to_vec().len())
63 .sum();
64 println!(
65 " Parameters: {} ({} total weights)",
66 params.len(),
67 total_params
68 );
69 println!(" Device: {:?}", device);
70
71 println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
73 let mut optimizer = Adam::new(params, 0.001);
74 let criterion = CrossEntropyLoss::new();
75
76 let epochs = 10;
78 println!("5. Training for {epochs} epochs...\n");
79
80 let train_start = Instant::now();
81
82 for epoch in 0..epochs {
83 let epoch_start = Instant::now();
84 let mut total_loss = 0.0;
85 let mut correct = 0usize;
86 let mut total = 0usize;
87 let mut batch_count = 0;
88
89 for batch in train_loader.iter() {
90 let bs = batch.data.shape()[0];
91
92 let input_data = batch.data.to_vec();
94 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
95 let input = Variable::new(
96 if device.is_gpu() {
97 input_tensor.to_device(device).unwrap()
98 } else {
99 input_tensor
100 },
101 true,
102 );
103
104 let target_onehot = batch.targets.to_vec();
106 let mut target_indices = vec![0.0f32; bs];
107 for i in 0..bs {
108 let offset = i * 10;
109 let mut max_idx = 0;
110 let mut max_val = f32::NEG_INFINITY;
111 for c in 0..10 {
112 if target_onehot[offset + c] > max_val {
113 max_val = target_onehot[offset + c];
114 max_idx = c;
115 }
116 }
117 target_indices[i] = max_idx as f32;
118 }
119 let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
120 let target = Variable::new(
121 if device.is_gpu() {
122 target_tensor.to_device(device).unwrap()
123 } else {
124 target_tensor
125 },
126 false,
127 );
128
129 let output = model.forward(&input);
131
132 let loss = criterion.compute(&output, &target);
134
135 let loss_val = loss.data().to_vec()[0];
136 total_loss += loss_val;
137 batch_count += 1;
138
139 let out_data = output.data().to_vec();
141 for i in 0..bs {
142 let offset = i * 10;
143 let mut pred = 0;
144 let mut pred_val = f32::NEG_INFINITY;
145 for c in 0..10 {
146 if out_data[offset + c] > pred_val {
147 pred_val = out_data[offset + c];
148 pred = c;
149 }
150 }
151 if pred == target_indices[i] as usize {
152 correct += 1;
153 }
154 total += 1;
155 }
156
157 loss.backward();
159
160 optimizer.step();
162 optimizer.zero_grad();
163 }
164
165 let epoch_time = epoch_start.elapsed();
166 let avg_loss = total_loss / batch_count as f32;
167 let accuracy = 100.0 * correct as f32 / total as f32;
168 let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
169
170 println!(
171 " Epoch {:2}/{}: Loss={:.4} Acc={:.1}% ({:.0} samples/s, {:.2}s)",
172 epoch + 1,
173 epochs,
174 avg_loss,
175 accuracy,
176 samples_per_sec,
177 epoch_time.as_secs_f64(),
178 );
179 }
180
181 let train_time = train_start.elapsed();
182 println!("\n Total training time: {:.2}s", train_time.as_secs_f64());
183
184 println!("\n6. Evaluating on test set...");
186
187 let (correct, total) = no_grad(|| {
189 let mut correct = 0usize;
190 let mut total = 0usize;
191
192 for batch in test_loader.iter() {
193 let bs = batch.data.shape()[0];
194
195 let input_data = batch.data.to_vec();
196 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
197 let input = Variable::new(
198 if device.is_gpu() {
199 input_tensor.to_device(device).unwrap()
200 } else {
201 input_tensor
202 },
203 false,
204 );
205
206 let target_onehot = batch.targets.to_vec();
207 let output = model.forward(&input);
208 let out_data = output.data().to_vec();
209
210 for i in 0..bs {
211 let offset = i * 10;
213 let mut pred = 0;
214 let mut pred_val = f32::NEG_INFINITY;
215 for c in 0..10 {
216 if out_data[offset + c] > pred_val {
217 pred_val = out_data[offset + c];
218 pred = c;
219 }
220 }
221
222 let mut true_label = 0;
224 let mut true_val = f32::NEG_INFINITY;
225 for c in 0..10 {
226 if target_onehot[i * 10 + c] > true_val {
227 true_val = target_onehot[i * 10 + c];
228 true_label = c;
229 }
230 }
231
232 if pred == true_label {
233 correct += 1;
234 }
235 total += 1;
236 }
237 }
238
239 (correct, total)
240 });
241
242 let test_accuracy = 100.0 * correct as f32 / total as f32;
243 println!(
244 " Test Accuracy: {}/{} ({:.2}%)",
245 correct, total, test_accuracy
246 );
247
248 println!("\n=== Training Complete! ===");
249 println!(" Device: {:?}", device);
250 println!(" Final test accuracy: {:.2}%", test_accuracy);
251}