use std::any::Any;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_nn::{Conv2d, Linear, Module, Parameter};
use axonml_tensor::Tensor;
/// LeNet-style convolutional classifier: two conv/pool stages followed by
/// a three-layer fully-connected head producing 10 logits.
/// Layer sizes are fixed by the constructors (`new` for MNIST-shaped input,
/// `for_cifar10` for CIFAR-10-shaped input).
pub struct LeNet {
// First convolution stage (1->6 channels for MNIST, 3->6 for CIFAR-10).
conv1: Conv2d,
// Second convolution stage (6->16 channels in both configurations).
conv2: Conv2d,
// Classifier head: flattened conv features -> 120 -> 84 -> 10 logits.
fc1: Linear,
fc2: Linear,
fc3: Linear,
}
impl LeNet {
/// LeNet configured for single-channel 28x28 inputs (MNIST).
///
/// The `16 * 4 * 4` flattened size implies the spatial chain
/// 28 -> 24 (conv k=5) -> 12 (pool/2) -> 8 (conv k=5) -> 4 (pool/2),
/// i.e. valid (no-padding), stride-1 convolutions — assumption based on
/// these numbers; confirm against `Conv2d::new` semantics.
#[must_use]
pub fn new() -> Self {
Self {
conv1: Conv2d::new(1, 6, 5), conv2: Conv2d::new(6, 16, 5), fc1: Linear::new(16 * 4 * 4, 120), fc2: Linear::new(120, 84),
fc3: Linear::new(84, 10),
}
}
/// LeNet configured for 3-channel 32x32 inputs (CIFAR-10).
///
/// `16 * 5 * 5` implies 32 -> 28 -> 14 -> 10 -> 5 under the same
/// valid-conv / pool-by-2 assumption as `new`.
#[must_use]
pub fn for_cifar10() -> Self {
Self {
conv1: Conv2d::new(3, 6, 5), conv2: Conv2d::new(6, 16, 5), fc1: Linear::new(16 * 5 * 5, 120), fc2: Linear::new(120, 84),
fc3: Linear::new(84, 10),
}
}
/// Non-overlapping max pooling with a `kernel_size` x `kernel_size` window.
///
/// Handles 4-D `[N, C, H, W]` and 3-D `[C, H, W]` inputs; any other rank is
/// returned unchanged. Output spatial sizes use integer division, so trailing
/// rows/columns that do not fill a full window are dropped. When the input
/// requires grad and grad mode is on, the flat input index of each window's
/// maximum is recorded and a `MaxPool2dBackward` node is attached so the
/// backward pass can scatter gradients to the winning positions.
fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
let data = input.data();
let shape = data.shape();
if shape.len() == 4 {
// Batched path: [N, C, H, W].
let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
// Integer division drops partial windows at the bottom/right edges.
let out_h = h / kernel_size;
let out_w = w / kernel_size;
let data_vec = data.to_vec();
let out_size = n * c * out_h * out_w;
let mut result = vec![0.0f32; out_size];
// For each output element, the flat index (into the input) of its max;
// consumed by `MaxPool2dBackward::apply`.
let mut max_indices = vec![0usize; out_size];
for batch in 0..n {
for ch in 0..c {
for oh in 0..out_h {
for ow in 0..out_w {
let mut max_val = f32::NEG_INFINITY;
let mut max_idx = 0usize;
// Scan the kernel_size x kernel_size window.
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let ih = oh * kernel_size + kh;
let iw = ow * kernel_size + kw;
// Row-major NCHW flat index.
let idx = batch * c * h * w + ch * h * w + ih * w + iw;
if data_vec[idx] > max_val {
max_val = data_vec[idx];
max_idx = idx;
}
}
}
let out_idx =
batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
result[out_idx] = max_val;
max_indices[out_idx] = max_idx;
}
}
}
}
let output_tensor = Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap();
if input.requires_grad() && is_grad_enabled() {
// Register the op on the autograd graph so gradients reach `input`.
let grad_fn = GradFn::new(MaxPool2dBackward {
next_fns: vec![input.grad_fn().cloned()],
max_indices,
input_shape: shape.to_vec(),
});
Variable::from_operation(output_tensor, grad_fn, true)
} else {
Variable::new(output_tensor, false)
}
} else if shape.len() == 3 {
// Unbatched path: [C, H, W]. Same algorithm with the batch loop removed.
let (c, h, w) = (shape[0], shape[1], shape[2]);
let out_h = h / kernel_size;
let out_w = w / kernel_size;
let data_vec = data.to_vec();
let out_size = c * out_h * out_w;
let mut result = vec![0.0f32; out_size];
let mut max_indices = vec![0usize; out_size];
for ch in 0..c {
for oh in 0..out_h {
for ow in 0..out_w {
let mut max_val = f32::NEG_INFINITY;
let mut max_idx = 0usize;
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let ih = oh * kernel_size + kh;
let iw = ow * kernel_size + kw;
let idx = ch * h * w + ih * w + iw;
if data_vec[idx] > max_val {
max_val = data_vec[idx];
max_idx = idx;
}
}
}
let out_idx = ch * out_h * out_w + oh * out_w + ow;
result[out_idx] = max_val;
max_indices[out_idx] = max_idx;
}
}
}
let output_tensor = Tensor::from_vec(result, &[c, out_h, out_w]).unwrap();
if input.requires_grad() && is_grad_enabled() {
let grad_fn = GradFn::new(MaxPool2dBackward {
next_fns: vec![input.grad_fn().cloned()],
max_indices,
input_shape: shape.to_vec(),
});
Variable::from_operation(output_tensor, grad_fn, true)
} else {
Variable::new(output_tensor, false)
}
} else {
// Unsupported rank: pass the input through unchanged.
input.clone()
}
}
/// Flattens `[batch, ...]` to `[batch, features]` via `reshape` (which keeps
/// the operation on the autograd graph); 1-D/2-D inputs pass through as-is.
fn flatten(&self, input: &Variable) -> Variable {
let shape = input.shape();
if shape.len() <= 2 {
return input.clone();
}
let batch_size = shape[0];
let features: usize = shape[1..].iter().product();
input.reshape(&[batch_size, features])
}
}
impl Default for LeNet {
fn default() -> Self {
Self::new()
}
}
impl Module for LeNet {
    /// Runs the classic LeNet pipeline — two conv/ReLU/pool stages, flatten,
    /// then three fully-connected layers. Returns raw logits (no softmax).
    fn forward(&self, input: &Variable) -> Variable {
        let mut x = self.max_pool2d(&self.conv1.forward(input).relu(), 2);
        x = self.max_pool2d(&self.conv2.forward(&x).relu(), 2);
        x = self.flatten(&x);
        x = self.fc1.forward(&x).relu();
        x = self.fc2.forward(&x).relu();
        self.fc3.forward(&x)
    }

    /// Collects the trainable parameters of every layer, in pipeline order.
    fn parameters(&self) -> Vec<Parameter> {
        [
            self.conv1.parameters(),
            self.conv2.parameters(),
            self.fc1.parameters(),
            self.fc2.parameters(),
            self.fc3.parameters(),
        ]
        .into_iter()
        .flatten()
        .collect()
    }

    // LeNet has no mode-dependent layers (dropout/batch-norm), so train/eval
    // are intentional no-ops.
    fn train(&mut self) {}

    fn eval(&mut self) {}
}
/// Minimal CNN: a single convolution followed by a two-layer classifier head.
pub struct SimpleCNN {
// Sole convolution stage (input_channels -> 32 feature maps).
conv1: Conv2d,
// Classifier head: flattened conv features -> 128 -> num_classes.
fc1: Linear,
fc2: Linear,
// Stored only for introspection via the `input_channels()` / `num_classes()` getters.
input_channels: usize,
num_classes: usize,
}
impl SimpleCNN {
#[must_use]
pub fn new(input_channels: usize, num_classes: usize) -> Self {
Self {
conv1: Conv2d::new(input_channels, 32, 3),
fc1: Linear::new(32 * 13 * 13, 128), fc2: Linear::new(128, num_classes),
input_channels,
num_classes,
}
}
#[must_use]
pub fn for_mnist() -> Self {
Self::new(1, 10)
}
#[must_use]
pub fn for_cifar10() -> Self {
Self {
conv1: Conv2d::new(3, 32, 3),
fc1: Linear::new(32 * 15 * 15, 128),
fc2: Linear::new(128, 10),
input_channels: 3,
num_classes: 10,
}
}
#[must_use]
pub fn input_channels(&self) -> usize {
self.input_channels
}
#[must_use]
pub fn num_classes(&self) -> usize {
self.num_classes
}
fn max_pool2d(&self, input: &Variable, kernel_size: usize) -> Variable {
let data = input.data();
let shape = data.shape();
if shape.len() != 4 {
return input.clone();
}
let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
let out_h = h / kernel_size;
let out_w = w / kernel_size;
let data_vec = data.to_vec();
let mut result = vec![0.0f32; n * c * out_h * out_w];
for batch in 0..n {
for ch in 0..c {
for oh in 0..out_h {
for ow in 0..out_w {
let mut max_val = f32::NEG_INFINITY;
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let ih = oh * kernel_size + kh;
let iw = ow * kernel_size + kw;
let idx = batch * c * h * w + ch * h * w + ih * w + iw;
max_val = max_val.max(data_vec[idx]);
}
}
let out_idx =
batch * c * out_h * out_w + ch * out_h * out_w + oh * out_w + ow;
result[out_idx] = max_val;
}
}
}
}
Variable::new(
Tensor::from_vec(result, &[n, c, out_h, out_w]).unwrap(),
input.requires_grad(),
)
}
fn flatten(&self, input: &Variable) -> Variable {
let shape = input.shape();
if shape.len() <= 2 {
return input.clone();
}
let batch_size = shape[0];
let features: usize = shape[1..].iter().product();
input.reshape(&[batch_size, features])
}
}
impl Module for SimpleCNN {
    /// conv -> ReLU -> 2x2 max-pool -> flatten -> fc1 -> ReLU -> fc2.
    /// Returns raw logits (no softmax).
    fn forward(&self, input: &Variable) -> Variable {
        let features = self.max_pool2d(&self.conv1.forward(input).relu(), 2);
        let flat = self.flatten(&features);
        let hidden = self.fc1.forward(&flat).relu();
        self.fc2.forward(&hidden)
    }

    /// Trainable parameters of all three layers, in pipeline order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut all = self.conv1.parameters();
        all.extend(self.fc1.parameters());
        all.extend(self.fc2.parameters());
        all
    }

    // No mode-dependent layers; train/eval are intentional no-ops.
    fn train(&mut self) {}
    fn eval(&mut self) {}
}
/// Three-layer fully-connected classifier; layer sizes are fixed at
/// construction (see `MLP::new`).
pub struct MLP {
// input_size -> hidden_size
fc1: Linear,
// hidden_size -> hidden_size / 2
fc2: Linear,
// hidden_size / 2 -> num_classes
fc3: Linear,
}
impl MLP {
    /// Builds a three-layer perceptron:
    /// `input_size -> hidden_size -> hidden_size / 2 -> num_classes`.
    #[must_use]
    pub fn new(input_size: usize, hidden_size: usize, num_classes: usize) -> Self {
        let bottleneck = hidden_size / 2;
        Self {
            fc1: Linear::new(input_size, hidden_size),
            fc2: Linear::new(hidden_size, bottleneck),
            fc3: Linear::new(bottleneck, num_classes),
        }
    }

    /// MLP sized for flattened 28x28 MNIST digits (784 -> 256 -> 128 -> 10).
    #[must_use]
    pub fn for_mnist() -> Self {
        Self::new(28 * 28, 256, 10)
    }

    /// MLP sized for flattened 32x32x3 CIFAR-10 images (3072 -> 512 -> 256 -> 10).
    #[must_use]
    pub fn for_cifar10() -> Self {
        Self::new(32 * 32 * 3, 512, 10)
    }
}
impl Module for MLP {
    /// Forward pass through the three linear layers with ReLU between them.
    /// Returns raw logits (no softmax).
    ///
    /// Inputs with more than two dims are flattened to `[batch, features]`;
    /// 1-D inputs are promoted to a batch of one; 2-D inputs pass through.
    ///
    /// BUG FIX: the flatten path previously rebuilt a fresh `Variable` from
    /// the raw tensor data, which detached the autograd graph — gradients
    /// could not flow back to the caller's input whenever reshaping was
    /// needed. It now uses `Variable::reshape`, the same graph-preserving
    /// call the `LeNet`/`SimpleCNN` flatten helpers use.
    fn forward(&self, input: &Variable) -> Variable {
        // Copy the shape out so the borrow of `input.data()` ends before
        // we call `reshape` on `input`.
        let shape = input.data().shape().to_vec();
        let x = if shape.len() > 2 {
            let batch = shape[0];
            let features: usize = shape[1..].iter().product();
            input.reshape(&[batch, features])
        } else if shape.len() == 1 {
            // Promote a single sample to a batch of one.
            input.reshape(&[1, shape[0]])
        } else {
            input.clone()
        };
        let x = self.fc1.forward(&x);
        let x = x.relu();
        let x = self.fc2.forward(&x);
        let x = x.relu();
        self.fc3.forward(&x)
    }

    /// Trainable parameters of all three layers, in pipeline order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params.extend(self.fc3.parameters());
        params
    }

    // No mode-dependent layers; train/eval are intentional no-ops.
    fn train(&mut self) {}
    fn eval(&mut self) {}
}
/// Backward op for the hand-rolled max pooling: routes each output-gradient
/// element to the single input position that won the max in the forward pass.
#[derive(Debug)]
struct MaxPool2dBackward {
// Upstream gradient functions (one slot: the pooled input's grad_fn, if any).
next_fns: Vec<Option<GradFn>>,
// For each pooled output element, the flat index into the input of its max.
max_indices: Vec<usize>,
// Original input shape, used to size the scattered gradient tensor.
input_shape: Vec<usize>,
}
impl GradientFunction for MaxPool2dBackward {
    /// Scatters each output gradient onto the input element that produced
    /// the corresponding max; all other input positions receive zero.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let upstream = grad_output.to_vec();
        let total: usize = self.input_shape.iter().product();
        let mut scattered = vec![0.0f32; total];
        // `zip` stops at the shorter sequence, which matches the original
        // bounds guard against a gradient shorter than the index list.
        for (&src, &g) in self.max_indices.iter().zip(upstream.iter()) {
            scattered[src] += g;
        }
        vec![Some(
            Tensor::from_vec(scattered, &self.input_shape).unwrap(),
        )]
    }

    fn name(&self) -> &'static str {
        "MaxPool2dBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a constant-valued NCHW batch of two 28x28
    /// single-channel images.
    fn mnist_batch() -> Variable {
        Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 28 * 28], &[2, 1, 28, 28]).unwrap(),
            false,
        )
    }

    #[test]
    fn test_lenet_creation() {
        assert!(!LeNet::new().parameters().is_empty());
    }

    #[test]
    fn test_lenet_forward() {
        let logits = LeNet::new().forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_simple_cnn_mnist() {
        let logits = SimpleCNN::for_mnist().forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_mnist() {
        let flat = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 784], &[2, 784]).unwrap(),
            false,
        );
        assert_eq!(MLP::for_mnist().forward(&flat).data().shape(), &[2, 10]);
    }

    #[test]
    fn test_mlp_auto_flatten() {
        // A 4-D input must be flattened internally before the first Linear.
        let logits = MLP::for_mnist().forward(&mnist_batch());
        assert_eq!(logits.data().shape(), &[2, 10]);
    }

    #[test]
    fn test_lenet_parameter_count() {
        // Classic LeNet is roughly 44k parameters.
        let total: usize = LeNet::new()
            .parameters()
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();
        assert!(total > 40000 && total < 100000);
    }
}