use axonml_autograd::Variable;
use axonml_nn::{Conv2d, Linear, Module, Parameter};
use axonml_tensor::Tensor;

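/// LeNet-5-style convolutional network: two conv + max-pool stages followed
/// by three fully connected layers.
///
/// A minimal usage sketch, marked `ignore` because the crate this model is
/// exported from is not shown here:
///
/// ```ignore
/// let model = LeNet::new();
/// let input = Variable::new(
///     Tensor::from_vec(vec![0.5; 28 * 28], &[1, 1, 28, 28]).unwrap(),
///     false,
/// );
/// let logits = model.forward(&input); // shape [1, 10]
/// ```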
pub struct LeNet {
    conv1: Conv2d,
    conv2: Conv2d,
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}

impl LeNet {
    /// MNIST variant: 1-channel 28x28 inputs.
    ///
    /// With the stride-1, unpadded convolutions implied by the layer sizes,
    /// the spatial path is 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8
    /// -> pool2 -> 4, so fc1 sees 16 * 4 * 4 = 256 features.
    #[must_use]
    pub fn new() -> Self {
        Self {
            conv1: Conv2d::new(1, 6, 5),
            conv2: Conv2d::new(6, 16, 5),
            fc1: Linear::new(16 * 4 * 4, 120),
            fc2: Linear::new(120, 84),
            fc3: Linear::new(84, 10),
        }
    }

    /// CIFAR-10 variant: 3-channel 32x32 inputs.
    ///
    /// The spatial path is 32 -> conv5 -> 28 -> pool2 -> 14 -> conv5 -> 10
    /// -> pool2 -> 5, so fc1 sees 16 * 5 * 5 = 400 features.
    #[must_use]
    pub fn for_cifar10() -> Self {
        Self {
            conv1: Conv2d::new(3, 6, 5),
            conv2: Conv2d::new(6, 16, 5),
            fc1: Linear::new(16 * 5 * 5, 120),
            fc2: Linear::new(120, 84),
            fc3: Linear::new(84, 10),
        }
    }

    /// Naive max pooling over non-overlapping `kernel_size` x `kernel_size`
    /// windows. The output is rebuilt as a fresh `Variable` from the pooled
    /// data, so no autograd ops are recorded for this step.
    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
        let data = input.data();
        let shape = data.shape();

        if shape.len() == 4 {
            // Batched NCHW input.
            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
            let out_h = h / kernel_size;
            let out_w = w / kernel_size;

            let data_vec = data.to_vec();
            let mut result = vec![0.0f32; n * c * out_h * out_w];

            for batch in 0..n {
                for ch in 0..c {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let mut max_val = f32::NEG_INFINITY;
                            for kh in 0..kernel_size {
                                for kw in 0..kernel_size {
                                    let ih = oh * kernel_size + kh;
                                    let iw = ow * kernel_size + kw;
                                    let idx = batch * c * h * w + ch * h * w + ih * w + iw;
                                    max_val = max_val.max(data_vec[idx]);
                                }
                            }
                            let out_idx =
                                batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
                            result[out_idx] = max_val;
                        }
                    }
                }
            }

            Variable::new(
                Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
                input.requires_grad(),
            )
        } else if shape.len() == 3 {
            // Unbatched CHW input.
            let (c, h, w) = (shape[0], shape[1], shape[2]);
            let out_h = h / kernel_size;
            let out_w = w / kernel_size;

            let data_vec = data.to_vec();
            let mut result = vec![0.0f32; c * out_h * out_w];

            for ch in 0..c {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;
                        for kh in 0..kernel_size {
                            for kw in 0..kernel_size {
                                let ih = oh * kernel_size + kh;
                                let iw = ow * kernel_size + kw;
                                let idx = ch * h * w + ih * w + iw;
                                max_val = max_val.max(data_vec[idx]);
                            }
                        }
                        let out_idx = ch * out_h * out_w + oh * out_w + ow;
                        result[out_idx] = max_val;
                    }
                }
            }

            Variable::new(
                Tensor::from_vec(result, &[c, out_h, out_w]).unwrap(),
                input.requires_grad(),
            )
        } else {
            // Anything else passes through unchanged.
            input.clone()
        }
    }

    /// Flattens rank-3+ inputs into `[batch, features]`.
    fn flatten(&self, input: &Variable) -> Variable {
        let data = input.data();
        let shape = data.shape();

        if shape.len() <= 2 {
            return input.clone();
        }

        let batch_size = shape[0];
        let features: usize = shape[1..].iter().product();

        Variable::new(
            Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
            input.requires_grad(),
        )
    }
}

impl Default for LeNet {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for LeNet {
    fn forward(&self, input: &Variable) -> Variable {
        // Stage 1: conv -> relu -> 2x2 max pool.
        let x = self.conv1.forward(input);
        let x = x.relu();
        let x = self.max_pool2d(&x, 2);

        // Stage 2: conv -> relu -> 2x2 max pool.
        let x = self.conv2.forward(&x);
        let x = x.relu();
        let x = self.max_pool2d(&x, 2);

        // Classifier head.
        let x = self.flatten(&x);
        let x = self.fc1.forward(&x);
        let x = x.relu();
        let x = self.fc2.forward(&x);
        let x = x.relu();
        self.fc3.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params.extend(self.fc3.parameters());
        params
    }

    // No mode-dependent layers (e.g. dropout), so train/eval are no-ops.
    fn train(&mut self) {}

    fn eval(&mut self) {}
}

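/// Compact single-conv CNN: conv -> relu -> pool -> fc -> relu -> fc.
///
/// A minimal usage sketch, marked `ignore` because the crate this model is
/// exported from is not shown here:
///
/// ```ignore
/// let model = SimpleCNN::for_mnist();
/// let input = Variable::new(
///     Tensor::from_vec(vec![0.5; 28 * 28], &[1, 1, 28, 28]).unwrap(),
///     false,
/// );
/// let logits = model.forward(&input); // shape [1, 10]
/// ```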
pub struct SimpleCNN {
    conv1: Conv2d,
    fc1: Linear,
    fc2: Linear,
    input_channels: usize,
    num_classes: usize,
}

impl SimpleCNN {
    /// Builds a SimpleCNN for 28x28 inputs: 28 -> conv3 -> 26 -> pool2 -> 13,
    /// so fc1 sees 32 * 13 * 13 features. Note that fc1 is hard-coded for
    /// 28x28 images regardless of `input_channels`.
    #[must_use]
    pub fn new(input_channels: usize, num_classes: usize) -> Self {
        Self {
            conv1: Conv2d::new(input_channels, 32, 3),
            fc1: Linear::new(32 * 13 * 13, 128),
            fc2: Linear::new(128, num_classes),
            input_channels,
            num_classes,
        }
    }

    /// MNIST variant: 1-channel 28x28 inputs, 10 classes.
    #[must_use]
    pub fn for_mnist() -> Self {
        Self::new(1, 10)
    }

    /// CIFAR-10 variant: 32 -> conv3 -> 30 -> pool2 -> 15, so fc1 sees
    /// 32 * 15 * 15 features.
    #[must_use]
    pub fn for_cifar10() -> Self {
        Self {
            conv1: Conv2d::new(3, 32, 3),
            fc1: Linear::new(32 * 15 * 15, 128),
            fc2: Linear::new(128, 10),
            input_channels: 3,
            num_classes: 10,
        }
    }

    #[must_use]
    pub fn input_channels(&self) -> usize {
        self.input_channels
    }

    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Naive max pooling over non-overlapping windows; see `LeNet::max_pool2d`.
    /// Only batched NCHW input is pooled; other ranks pass through unchanged.
    fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
        let data = input.data();
        let shape = data.shape();

        if shape.len() != 4 {
            return input.clone();
        }

        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
        let out_h = h / kernel_size;
        let out_w = w / kernel_size;

        let data_vec = data.to_vec();
        let mut result = vec![0.0f32; n * c * out_h * out_w];

        for batch in 0..n {
            for ch in 0..c {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut max_val = f32::NEG_INFINITY;
                        for kh in 0..kernel_size {
                            for kw in 0..kernel_size {
                                let ih = oh * kernel_size + kh;
                                let iw = ow * kernel_size + kw;
                                let idx = batch * c * h * w + ch * h * w + ih * w + iw;
                                max_val = max_val.max(data_vec[idx]);
                            }
                        }
                        let out_idx =
                            batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
                        result[out_idx] = max_val;
                    }
                }
            }
        }

        Variable::new(
            Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
            input.requires_grad(),
        )
    }

    /// Flattens rank-3+ inputs into `[batch, features]`.
    fn flatten(&self, input: &Variable) -> Variable {
        let data = input.data();
        let shape = data.shape();

        if shape.len() <= 2 {
            return input.clone();
        }

        let batch_size = shape[0];
        let features: usize = shape[1..].iter().product();

        Variable::new(
            Tensor::from_vec(data.to_vec(), &[batch_size, features]).unwrap(),
            input.requires_grad(),
        )
    }
}

impl Module for SimpleCNN {
    fn forward(&self, input: &Variable) -> Variable {
        let x = self.conv1.forward(input);
        let x = x.relu();
        let x = self.max_pool2d(&x, 2);
        let x = self.flatten(&x);
        let x = self.fc1.forward(&x);
        let x = x.relu();
        self.fc2.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.conv1.parameters());
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params
    }

    // No mode-dependent layers, so train/eval are no-ops.
    fn train(&mut self) {}
    fn eval(&mut self) {}
}

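/// Three-layer perceptron with ReLU activations; the hidden width is halved
/// between the two hidden layers.
///
/// `forward` reshapes rank-1 and rank-3+ inputs to `[batch, features]`, so
/// images can be fed directly. A minimal sketch, marked `ignore` for the
/// same reason as above:
///
/// ```ignore
/// let model = MLP::for_mnist(); // 784 -> 256 -> 128 -> 10
/// let image = Variable::new(
///     Tensor::from_vec(vec![0.5; 28 * 28], &[1, 1, 28, 28]).unwrap(),
///     false,
/// );
/// let logits = model.forward(&image); // auto-flattened to [1, 784]
/// ```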
pub struct MLP {
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}

impl MLP {
    /// Builds an MLP of shape
    /// `input_size -> hidden_size -> hidden_size / 2 -> num_classes`.
    #[must_use]
    pub fn new(input_size: usize, hidden_size: usize, num_classes: usize) -> Self {
        Self {
            fc1: Linear::new(input_size, hidden_size),
            fc2: Linear::new(hidden_size, hidden_size / 2),
            fc3: Linear::new(hidden_size / 2, num_classes),
        }
    }

    /// MNIST variant: 784 (28x28) inputs, 10 classes.
    #[must_use]
    pub fn for_mnist() -> Self {
        Self::new(784, 256, 10)
    }

    /// CIFAR-10 variant: 3072 (3x32x32) inputs, 10 classes.
    #[must_use]
    pub fn for_cifar10() -> Self {
        Self::new(3072, 512, 10)
    }
}

impl Module for MLP {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let shape = data.shape();
        // Normalize the input to [batch, features]: flatten rank-3+ inputs,
        // and treat a rank-1 input as a single-sample batch.
        let x = if shape.len() > 2 {
            let batch = shape[0];
            let features: usize = shape[1..].iter().product();
            Variable::new(
                Tensor::from_vec(data.to_vec(), &[batch, features]).unwrap(),
                input.requires_grad(),
            )
        } else if shape.len() == 1 {
            Variable::new(
                Tensor::from_vec(data.to_vec(), &[1, shape[0]]).unwrap(),
                input.requires_grad(),
            )
        } else {
            input.clone()
        };

        let x = self.fc1.forward(&x);
        let x = x.relu();
        let x = self.fc2.forward(&x);
        let x = x.relu();
        self.fc3.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params.extend(self.fc3.parameters());
        params
    }

    // No mode-dependent layers, so train/eval are no-ops.
    fn train(&mut self) {}
    fn eval(&mut self) {}
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lenet_creation() {
        let model = LeNet::new();
        let params = model.parameters();

        assert!(!params.is_empty());
    }

    #[test]
    fn test_lenet_forward() {
        let model = LeNet::new();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 28 * 28], &[2, 1, 28, 28]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }

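    // Added for symmetry with the MNIST test above: a sketch exercising the
    // CIFAR-10 variant, assuming the stride-1, unpadded Conv2d semantics
    // implied by its fc1 size of 16 * 5 * 5.
    #[test]
    fn test_lenet_cifar10_forward() {
        let model = LeNet::for_cifar10();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 32 * 32], &[2, 3, 32, 32]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }
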
    #[test]
    fn test_simple_cnn_mnist() {
        let model = SimpleCNN::for_mnist();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 28 * 28], &[2, 1, 28, 28]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }

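    // Added companion test: the CIFAR-10 variant, assuming the same conv
    // semantics implied by its fc1 size of 32 * 15 * 15.
    #[test]
    fn test_simple_cnn_cifar10() {
        let model = SimpleCNN::for_cifar10();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 32 * 32], &[2, 3, 32, 32]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }
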
    #[test]
    fn test_mlp_mnist() {
        let model = MLP::for_mnist();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 784], &[2, 784]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_auto_flatten() {
        let model = MLP::for_mnist();

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 28 * 28], &[2, 1, 28, 28]).unwrap(),
            false,
        );

        let output = model.forward(&input);
        assert_eq!(output.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_parameter_count() {
        let model = LeNet::new();
        let params = model.parameters();

        let total: usize = params
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();

        // The MNIST variant should have roughly 44k parameters, assuming
        // standard weight + bias layouts for the conv and linear layers.
        assert!(total > 40_000 && total < 100_000);
    }
}