pub struct LeNet { /* private fields */ }
Expand description
LeNet-5 architecture for MNIST digit classification.
Architecture:
- Conv2d(1, 6, 5) -> ReLU -> MaxPool2d(2)
- Conv2d(6, 16, 5) -> ReLU -> MaxPool2d(2)
- Flatten
- Linear(256, 120) -> ReLU
- Linear(120, 84) -> ReLU
- Linear(84, 10)
Implementations§
Source§impl LeNet
impl LeNet
Sourcepub fn new() -> LeNet
pub fn new() -> LeNet
Creates a new LeNet-5 for MNIST (28x28 input, 10 classes).
Examples found in repository?
examples/mnist_training.rs (line 83)
34fn main() {
35 println!("=== AxonML - MNIST Training (LeNet) ===\n");
36
37 // -------------------------------------------------------------------------
38 // Device Detection
39 // -------------------------------------------------------------------------
40
41 // Detect device
42 #[cfg(feature = "cuda")]
43 let device = {
44 let cuda = Device::Cuda(0);
45 if cuda.is_available() {
46 println!("GPU detected: using CUDA device 0");
47 cuda
48 } else {
49 println!("CUDA feature enabled but no GPU available, using CPU");
50 Device::Cpu
51 }
52 };
53 #[cfg(not(feature = "cuda"))]
54 let device = {
55 println!("Using CPU (compile with --features cuda for GPU)");
56 Device::Cpu
57 };
58
59 // -------------------------------------------------------------------------
60 // Dataset and DataLoader Setup
61 // -------------------------------------------------------------------------
62
63 // 1. Create dataset
64 let num_train = 2000;
65 let num_test = 400;
66 println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
67 let train_dataset = SyntheticMNIST::new(num_train);
68 let test_dataset = SyntheticMNIST::new(num_test);
69
70 // 2. Create DataLoader
71 let batch_size = 64;
72 println!("2. Creating DataLoader (batch_size={batch_size})...");
73 let train_loader = DataLoader::new(train_dataset, batch_size);
74 let test_loader = DataLoader::new(test_dataset, batch_size);
75 println!(" Training batches: {}", train_loader.len());
76
77 // -------------------------------------------------------------------------
78 // Model, Optimizer, and Loss
79 // -------------------------------------------------------------------------
80
81 // 3. Create LeNet model and move to device
82 println!("3. Creating LeNet model...");
83 let model = LeNet::new();
84 model.to_device(device);
85 let params = model.parameters();
86 let total_params: usize = params
87 .iter()
88 .map(|p| p.variable().data().to_vec().len())
89 .sum();
90 println!(
91 " Parameters: {} ({} total weights)",
92 params.len(),
93 total_params
94 );
95 println!(" Device: {:?}", device);
96
97 // 4. Create optimizer and loss
98 println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
99 let mut optimizer = Adam::new(params, 0.001);
100 let criterion = CrossEntropyLoss::new();
101
102 // -------------------------------------------------------------------------
103 // Training Loop
104 // -------------------------------------------------------------------------
105
106 // 5. Training loop
107 let epochs = 10;
108 println!("5. Training for {epochs} epochs...\n");
109
110 let train_start = Instant::now();
111
112 for epoch in 0..epochs {
113 let epoch_start = Instant::now();
114 let mut total_loss = 0.0;
115 let mut correct = 0usize;
116 let mut total = 0usize;
117 let mut batch_count = 0;
118
119 for batch in train_loader.iter() {
120 let bs = batch.data.shape()[0];
121
122 // Reshape to [N, 1, 28, 28] and create Variable
123 let input_data = batch.data.to_vec();
124 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
125 let input = Variable::new(
126 if device.is_gpu() {
127 input_tensor.to_device(device).unwrap()
128 } else {
129 input_tensor
130 },
131 true,
132 );
133
134 // Target: convert one-hot [N, 10] to class indices [N]
135 let target_onehot = batch.targets.to_vec();
136 let mut target_indices = vec![0.0f32; bs];
137 for i in 0..bs {
138 let offset = i * 10;
139 let mut max_idx = 0;
140 let mut max_val = f32::NEG_INFINITY;
141 for c in 0..10 {
142 if target_onehot[offset + c] > max_val {
143 max_val = target_onehot[offset + c];
144 max_idx = c;
145 }
146 }
147 target_indices[i] = max_idx as f32;
148 }
149 let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
150 let target = Variable::new(
151 if device.is_gpu() {
152 target_tensor.to_device(device).unwrap()
153 } else {
154 target_tensor
155 },
156 false,
157 );
158
159 // Forward pass
160 let output = model.forward(&input);
161
162 // Cross-entropy loss
163 let loss = criterion.compute(&output, &target);
164
165 let loss_val = loss.data().to_vec()[0];
166 total_loss += loss_val;
167 batch_count += 1;
168
169 // Compute training accuracy
170 let out_data = output.data().to_vec();
171 for i in 0..bs {
172 let offset = i * 10;
173 let mut pred = 0;
174 let mut pred_val = f32::NEG_INFINITY;
175 for c in 0..10 {
176 if out_data[offset + c] > pred_val {
177 pred_val = out_data[offset + c];
178 pred = c;
179 }
180 }
181 if pred == target_indices[i] as usize {
182 correct += 1;
183 }
184 total += 1;
185 }
186
187 // Backward pass
188 loss.backward();
189
190 // Update weights
191 optimizer.step();
192 optimizer.zero_grad();
193 }
194
195 let epoch_time = epoch_start.elapsed();
196 let avg_loss = total_loss / batch_count as f32;
197 let accuracy = 100.0 * correct as f32 / total as f32;
198 let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
199
200 println!(
201 " Epoch {:2}/{}: Loss={:.4} Acc={:.1}% ({:.0} samples/s, {:.2}s)",
202 epoch + 1,
203 epochs,
204 avg_loss,
205 accuracy,
206 samples_per_sec,
207 epoch_time.as_secs_f64(),
208 );
209 }
210
211 let train_time = train_start.elapsed();
212 println!("\n Total training time: {:.2}s", train_time.as_secs_f64());
213
214 // -------------------------------------------------------------------------
215 // Test Evaluation
216 // -------------------------------------------------------------------------
217
218 // 6. Test evaluation
219 println!("\n6. Evaluating on test set...");
220
221 // Disable gradient computation for evaluation
222 let (correct, total) = no_grad(|| {
223 let mut correct = 0usize;
224 let mut total = 0usize;
225
226 for batch in test_loader.iter() {
227 let bs = batch.data.shape()[0];
228
229 let input_data = batch.data.to_vec();
230 let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
231 let input = Variable::new(
232 if device.is_gpu() {
233 input_tensor.to_device(device).unwrap()
234 } else {
235 input_tensor
236 },
237 false,
238 );
239
240 let target_onehot = batch.targets.to_vec();
241 let output = model.forward(&input);
242 let out_data = output.data().to_vec();
243
244 for i in 0..bs {
245 // Prediction: argmax of output
246 let offset = i * 10;
247 let mut pred = 0;
248 let mut pred_val = f32::NEG_INFINITY;
249 for c in 0..10 {
250 if out_data[offset + c] > pred_val {
251 pred_val = out_data[offset + c];
252 pred = c;
253 }
254 }
255
256 // True label: argmax of one-hot target
257 let mut true_label = 0;
258 let mut true_val = f32::NEG_INFINITY;
259 for c in 0..10 {
260 if target_onehot[i * 10 + c] > true_val {
261 true_val = target_onehot[i * 10 + c];
262 true_label = c;
263 }
264 }
265
266 if pred == true_label {
267 correct += 1;
268 }
269 total += 1;
270 }
271 }
272
273 (correct, total)
274 });
275
276 let test_accuracy = 100.0 * correct as f32 / total as f32;
277 println!(
278 " Test Accuracy: {}/{} ({:.2}%)",
279 correct, total, test_accuracy
280 );
281
282 println!("\n=== Training Complete! ===");
283 println!(" Device: {:?}", device);
284 println!(" Final test accuracy: {:.2}%", test_accuracy);
285}Sourcepub fn for_cifar10() -> LeNet
pub fn for_cifar10() -> LeNet
Creates a LeNet for CIFAR-10 (32x32 input, 10 classes).
Trait Implementations§
Source§impl Module for LeNet
impl Module for LeNet
Source§fn named_parameters(&self) -> HashMap<String, Parameter>
fn named_parameters(&self) -> HashMap<String, Parameter>
Returns named parameters of this module.
Source§fn num_parameters(&self) -> usize
fn num_parameters(&self) -> usize
Returns the number of trainable parameters.
Source§fn set_training(&mut self, _training: bool)
fn set_training(&mut self, _training: bool)
Sets the training mode.
Sets the training mode. Read more
Source§fn is_training(&self) -> bool
fn is_training(&self) -> bool
Returns whether the module is in training mode. Read more
Auto Trait Implementations§
impl Freeze for LeNet
impl !RefUnwindSafe for LeNet
impl Send for LeNet
impl Sync for LeNet
impl Unpin for LeNet
impl UnsafeUnpin for LeNet
impl !UnwindSafe for LeNet
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where T: ?Sized
impl<T> BorrowMut<T> for T where T: ?Sized
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts
self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts
self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<T> PolicyExt for T where T: ?Sized
impl<T> PolicyExt for T where T: ?Sized
Source§impl<R, P> ReadPrimitive<R> for P
impl<R, P> ReadPrimitive<R> for P
Source§fn read_from_little_endian(read: &mut R) -> Result<Self, Error>
fn read_from_little_endian(read: &mut R) -> Result<Self, Error>
Read this value from the supplied reader. Same as
ReadEndian::read_from_little_endian().