use rumus::autograd;
use rumus::nn::{self, Conv2d, Flatten, Linear, MaxPool2d, Module};
use rumus::optim::{Adam, Optimizer};
use rumus::tensor::{self, Tensor};
/// Small convolutional network used by `test_cnn_spatial_classification`.
///
/// Forward pipeline (see `MiniCNN::forward`): conv → ReLU → max-pool
/// (applied per sample) → flatten → fc1 → ReLU → fc2.
#[derive(Module)]
struct MiniCNN {
// Convolution over the input batch; built with `Conv2d::new(1, 4, 3, 1, 0, true)`.
conv: Conv2d,
// Pooling layer; built with `MaxPool2d::new(2, 2)`.
pool: MaxPool2d,
// Flattens the stacked pooled feature maps before the dense head.
flat: Flatten,
// First dense layer; built with `Linear::new(36, 16, true)`.
// NOTE(review): 36 presumably equals the flattened pooled feature size — confirm.
fc1: Linear,
// Output layer; built with `Linear::new(16, 1, true)` (presumably one value per sample).
fc2: Linear,
}
impl MiniCNN {
    /// Creates the network with fixed layer sizes.
    ///
    /// Layers are constructed in the order conv, pool, flat, fc1, fc2. That
    /// order is preserved deliberately in case construction consumes a shared
    /// random source for weight initialisation.
    fn new() -> Self {
        // NOTE(review): argument meaning assumed to be
        // (in_channels, out_channels, kernel, stride, padding, bias) — confirm.
        let conv = Conv2d::new(1, 4, 3, 1, 0, true);
        let pool = MaxPool2d::new(2, 2);
        let flat = Flatten::new();
        // Under the reading above, 36 = 4 channels x 3x3 for an 8x8 input
        // (3x3 conv: 8 -> 6, 2x2 pool: 6 -> 3).
        let fc1 = Linear::new(36, 16, true);
        let fc2 = Linear::new(16, 1, true);
        Self {
            conv,
            pool,
            flat,
            fc1,
            fc2,
        }
    }

    /// Runs a forward pass over a batch whose first dimension is the batch size.
    ///
    /// The convolution sees the whole batch at once. ReLU and max-pooling are
    /// applied one sample at a time via `slice_batch`. The pooled samples are
    /// re-stacked before the flatten → fc1 → ReLU → fc2 head.
    fn forward(&self, input: &Tensor) -> Tensor {
        let batch_size = input.shape()[0];
        let features = self.conv.forward(input);

        let pooled_per_sample: Vec<Tensor> = (0..batch_size)
            .map(|index| {
                let sample = features.slice_batch(index);
                let activated = nn::relu(&sample);
                self.pool.forward(&activated)
            })
            .collect();

        let batched = tensor::stack(&pooled_per_sample);
        let flattened = self.flat.forward(&batched);
        let hidden = nn::relu(&self.fc1.forward(&flattened));
        self.fc2.forward(&hidden)
    }
}
/// Builds a flattened, row-major 8x8 single-channel image (64 values).
///
/// The "center" region is the 4x4 square at rows and columns `2..6`.
/// When `center` is `true`, pixels inside that square are `1.0` and the rest are `0.0`.
/// When `center` is `false`, the complementary border pixels are `1.0` and the
/// square is `0.0`.
fn make_image(center: bool) -> Vec<f32> {
    const SIDE: usize = 8;
    const MARGIN: usize = 2;
    let inner = MARGIN..SIDE - MARGIN;

    (0..SIDE * SIDE)
        .map(|idx| {
            let (row, col) = (idx / SIDE, idx % SIDE);
            let in_center = inner.contains(&row) && inner.contains(&col);
            // A pixel is lit exactly when its region matches the requested one.
            if in_center == center {
                1.0
            } else {
                0.0
            }
        })
        .collect()
}
/// Trains `MiniCNN` to tell "border" images (target 0.0) from "center" images
/// (target 1.0). It then checks that the MSE loss from the last epoch is below 0.05.
#[test]
fn test_cnn_spatial_classification() {
    // Per-sample labels: `true` selects a center image, `false` a border image.
    // These must stay in step with the target values below.
    let labels = [false, true, false, true];
    let input_data: Vec<f32> = labels
        .iter()
        .flat_map(|&is_center| make_image(is_center))
        .collect();
    let inputs = Tensor::new(input_data, vec![4, 1, 8, 8]);
    let targets = Tensor::new(vec![0.0, 1.0, 0.0, 1.0], vec![4, 1]);

    let model = MiniCNN::new();
    let mut optimizer = Adam::new(model.parameters(), 0.01);

    // Loss of the final epoch, measured before that epoch's optimizer update.
    let mut final_loss = f32::MAX;
    for _ in 0..100 {
        let prediction = model.forward(&inputs);
        let loss = nn::mse_loss(&prediction, &targets);
        // Keep the `data()` borrow confined to this block so it ends before `backward`.
        final_loss = {
            let values = loss.data();
            values[0]
        };
        let mut grads = autograd::backward(&loss).expect("backward failed");
        optimizer.step(&mut grads).expect("optimizer step failed");
    }

    assert!(
        final_loss < 0.05,
        "CNN training did not converge: final loss = {:.6}",
        final_loss,
    );
}