use crate::errors::NeuroxResult;
use crate::optimizer::{Adam, SGD};
use crate::{
layers::{Activation, Dense},
loss,
tensor::Tensor,
};
/// A feed-forward neural network: an ordered stack of fully-connected
/// (`Dense`) layers applied in sequence.
pub struct Model {
    // Layers in forward-pass order; trained in place by the optimizers.
    pub layers: Vec<Dense>,
}
impl Model {
/// Builds a network from consecutive layer widths.
///
/// Each adjacent pair `layer_sizes[i]`, `layer_sizes[i + 1]` becomes the
/// input/output size of one `Dense` layer; every layer shares the same
/// `activation`. Fewer than two sizes produces a model with no layers.
pub fn new(layer_sizes: &[usize], activation: Activation) -> Self {
    let layers = layer_sizes
        .windows(2)
        .map(|pair| Dense::new(pair[0], pair[1], activation))
        .collect();
    Self { layers }
}
/// Runs `input` through every layer in order and returns the final output.
///
/// Takes `&mut self` because each layer's `forward` requires mutable
/// access — presumably to cache values for a later `backward` pass
/// (NOTE(review): confirm against `Dense::forward`).
pub fn forward(&mut self, input: &Tensor) -> NeuroxResult<Tensor> {
    self.layers
        .iter_mut()
        .try_fold(input.clone(), |activation, layer| layer.forward(&activation))
}
/// Trains the model with mini-batch stochastic gradient descent.
///
/// For every batch of each epoch: forward pass, softmax over the
/// predictions, cross-entropy loss gradient, backward pass through the
/// layers in reverse order, then one SGD parameter update.
///
/// The final batch of an epoch may be smaller than `batch_size` when
/// `x.rows` is not a multiple of it.
///
/// # Panics
/// Panics if `batch_size` is zero, or if `y` has fewer rows than `x`
/// (the row slicer asserts its bounds).
pub fn train_sgd(
    &mut self,
    x: &Tensor,
    y: &Tensor,
    epochs: usize,
    batch_size: usize,
    lr: f32,
) -> NeuroxResult<()> {
    // Fail fast with a clear message instead of the opaque panic that
    // `step_by` raises on a zero step.
    assert!(batch_size > 0, "batch_size must be non-zero");
    let opt = SGD::new(lr);
    for _epoch in 0..epochs {
        for start in (0..x.rows).step_by(batch_size) {
            let end = (start + batch_size).min(x.rows);
            let bx = slice_rows(x, start, end)?;
            let by = slice_rows(y, start, end)?;
            let preds = self.forward(&bx)?;
            let probs = crate::activations::softmax(&preds);
            let (_loss, grad) = loss::cross_entropy_loss(&probs, &by);
            // Propagate the loss gradient back through the layers,
            // last layer first.
            let mut upstream_grad = grad;
            for layer in self.layers.iter_mut().rev() {
                upstream_grad = layer.backward(&upstream_grad)?;
            }
            opt.step(&mut self.layers);
        }
    }
    Ok(())
}
/// Trains the model with the Adam optimizer.
///
/// Identical loop structure to `train_sgd`: per batch, forward pass,
/// softmax, cross-entropy gradient, backward pass in reverse layer
/// order, then an Adam parameter update (Adam keeps per-layer moment
/// state, hence the mutable optimizer).
///
/// The final batch of an epoch may be smaller than `batch_size` when
/// `x.rows` is not a multiple of it.
///
/// # Panics
/// Panics if `batch_size` is zero, or if `y` has fewer rows than `x`
/// (the row slicer asserts its bounds).
pub fn train_adam(
    &mut self,
    x: &Tensor,
    y: &Tensor,
    epochs: usize,
    batch_size: usize,
    lr: f32,
) -> NeuroxResult<()> {
    // Fail fast with a clear message instead of the opaque panic that
    // `step_by` raises on a zero step.
    assert!(batch_size > 0, "batch_size must be non-zero");
    let mut adam = Adam::new(lr, &self.layers);
    for _epoch in 0..epochs {
        for start in (0..x.rows).step_by(batch_size) {
            let end = (start + batch_size).min(x.rows);
            let bx = slice_rows(x, start, end)?;
            let by = slice_rows(y, start, end)?;
            let preds = self.forward(&bx)?;
            let probs = crate::activations::softmax(&preds);
            let (_loss, grad) = loss::cross_entropy_loss(&probs, &by);
            // Propagate the loss gradient back through the layers,
            // last layer first.
            let mut upstream_grad = grad;
            for layer in self.layers.iter_mut().rev() {
                upstream_grad = layer.backward(&upstream_grad)?;
            }
            adam.step(&mut self.layers);
        }
    }
    Ok(())
}
/// Prints one line per layer (weight shape and parameter count) followed
/// by the total number of parameters.
pub fn summary(&self) {
    println!("Model Summary:");
    let total: usize = self
        .layers
        .iter()
        .enumerate()
        .map(|(idx, layer)| {
            let params = layer.num_params();
            println!(
                " Layer {}: Dense {} -> {} (params {})",
                idx, layer.w.rows, layer.w.cols, params
            );
            params
        })
        .sum();
    println!("Total params: {}", total);
}
}
/// Copies rows `start..end` (half-open) of `t` into a freshly allocated
/// tensor of shape `(end - start, t.cols)`.
///
/// # Panics
/// Panics when the range is empty (`start >= end`) or extends past
/// `t.rows`; the message includes the offending values for debugging.
fn slice_rows(t: &Tensor, start: usize, end: usize) -> NeuroxResult<Tensor> {
    assert!(
        start < end && end <= t.rows,
        "invalid row range {}..{} for tensor with {} rows",
        start,
        end,
        t.rows
    );
    let cols = t.cols;
    let mut out = Tensor::zeros(end - start, cols);
    // Element-wise copy through the Tensor accessors, preserving row order.
    for (i, src_row) in (start..end).enumerate() {
        for j in 0..cols {
            out.set(i, j, t.get(src_row, j));
        }
    }
    Ok(out)
}