use crate::error::Result;
use crate::training::TrainingMetrics;
use ndarray::{Array1, Array2};
/// A small multi-layer perceptron with one hidden layer and a linear output layer.
///
/// The hidden layer applies a ReLU (`max(0, v)`) to its pre-activations; the
/// output layer is purely affine. All parameters are stored as `f64` arrays and
/// are public, so callers can inspect or overwrite them directly.
pub struct SimpleMLP {
/// Input-to-hidden weights; created with shape `(input_size, hidden_size)`.
pub w1: Array2<f64>,
/// Hidden-layer bias; created with length `hidden_size`.
pub b1: Array1<f64>,
/// Hidden-to-output weights; created with shape `(hidden_size, output_size)`.
pub w2: Array2<f64>,
/// Output-layer bias; created with length `output_size`.
pub b2: Array1<f64>,
}
impl SimpleMLP {
    /// Creates a network with randomly initialised weights and zero biases.
    ///
    /// Every weight of a layer is drawn uniformly from `[-s, s)` where
    /// `s = sqrt(2 / (fan_in + fan_out))` for that layer. The random numbers
    /// come from `rand::thread_rng()`, so two calls produce different models.
    pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let scale1 = (2.0 / (input_size + hidden_size) as f64).sqrt();
        let scale2 = (2.0 / (hidden_size + output_size) as f64).sqrt();
        Self {
            w1: Array2::from_shape_fn((input_size, hidden_size), |_| {
                rng.gen_range(-scale1..scale1)
            }),
            b1: Array1::zeros(hidden_size),
            w2: Array2::from_shape_fn((hidden_size, output_size), |_| {
                rng.gen_range(-scale2..scale2)
            }),
            b2: Array1::zeros(output_size),
        }
    }

    /// Runs a forward pass over a batch whose rows are individual samples.
    ///
    /// Returns `(h1, z2)`: the hidden activations after ReLU and the raw
    /// (linear) outputs of the second layer.
    pub fn forward(&self, x: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        // ReLU is applied in place on the freshly computed pre-activations,
        // which avoids allocating a second hidden-sized array.
        let h1 = (x.dot(&self.w1) + &self.b1).mapv_into(|v| v.max(0.0));
        let z2 = h1.dot(&self.w2) + &self.b2;
        (h1, z2)
    }

    /// Returns only the network outputs for `x`, discarding the hidden activations.
    pub fn predict(&self, x: &Array2<f64>) -> Array2<f64> {
        self.forward(x).1
    }

    /// Performs one step of gradient descent on a squared-error objective.
    ///
    /// Returns the mean squared error of the batch, computed **before** the
    /// parameters are updated. Each gradient component is clipped to
    /// `[-5.0, 5.0]` before being scaled by `learning_rate`.
    ///
    /// Note that the reported loss averages over every output element, while
    /// the gradient is scaled by the number of rows in `x` only.
    ///
    /// If the batch contains no elements (no rows, or no output columns) the
    /// mean is undefined; the parameters are then left untouched and `0.0` is
    /// returned instead of panicking.
    pub fn train_step(
        &mut self,
        x: &Array2<f64>,
        y: &Array2<f64>,
        learning_rate: f64,
    ) -> f64 {
        // Bound applied element-wise to every gradient before the update.
        const GRADIENT_CLIP: f64 = 5.0;

        // Applies `param -= clip(grad) * learning_rate` in place, without
        // materialising the clipped or scaled gradient as separate arrays.
        fn clipped_sgd_update<D: ndarray::Dimension>(
            param: &mut ndarray::Array<f64, D>,
            grad: &ndarray::Array<f64, D>,
            learning_rate: f64,
        ) {
            param.zip_mut_with(grad, |p, &g| {
                *p -= g.max(-GRADIENT_CLIP).min(GRADIENT_CLIP) * learning_rate;
            });
        }

        let batch_size = x.nrows() as f64;
        let (h1, output) = self.forward(x);

        // The residual feeds both the loss and the output-layer gradient, so
        // it is computed once.
        let residual = &output - y;
        let loss = match residual.mapv(|r| r * r).mean() {
            Some(loss) => loss,
            // Empty batch: nothing to learn from.
            None => return 0.0,
        };

        // Output layer gradients.
        let dz2 = residual * (2.0 / batch_size);
        let dw2 = h1.t().dot(&dz2);
        let db2 = dz2.sum_axis(ndarray::Axis(0));

        // Hidden layer gradients. This must use `w2` from before its update.
        let mut dz1 = dz2.dot(&self.w2.t());
        // ReLU derivative: pass the gradient through only where the unit was active.
        dz1.zip_mut_with(&h1, |g, &a| *g *= if a > 0.0 { 1.0 } else { 0.0 });
        let dw1 = x.t().dot(&dz1);
        let db1 = dz1.sum_axis(ndarray::Axis(0));

        clipped_sgd_update(&mut self.w1, &dw1, learning_rate);
        clipped_sgd_update(&mut self.b1, &db1, learning_rate);
        clipped_sgd_update(&mut self.w2, &dw2, learning_rate);
        clipped_sgd_update(&mut self.b2, &db2, learning_rate);

        loss
    }
}
/// Settings that control a `SimpleCPUTrainer` training run.
#[derive(Debug, Clone)]
pub struct SimpleCPUTrainingConfig {
    /// Maximum number of passes over the training data.
    pub epochs: usize,
    /// Number of rows taken from the training data for each mini-batch.
    pub batch_size: usize,
    /// Step size handed to every `train_step` call.
    pub learning_rate: f64,
    /// Number of epochs without improvement of the monitored loss after which
    /// training stops early.
    pub early_stopping_patience: usize,
    /// Progress is printed whenever the 1-based epoch number is a multiple of
    /// this value.
    pub print_every: usize,
}

impl Default for SimpleCPUTrainingConfig {
    /// Returns the stock configuration: 30 epochs, batches of 32 rows, a
    /// learning rate of 0.01, a patience of 10 epochs and progress output
    /// every 3 epochs.
    fn default() -> Self {
        const EPOCHS: usize = 30;
        const BATCH_SIZE: usize = 32;
        const LEARNING_RATE: f64 = 0.01;
        const EARLY_STOPPING_PATIENCE: usize = 10;
        const PRINT_EVERY: usize = 3;

        SimpleCPUTrainingConfig {
            epochs: EPOCHS,
            batch_size: BATCH_SIZE,
            learning_rate: LEARNING_RATE,
            early_stopping_patience: EARLY_STOPPING_PATIENCE,
            print_every: PRINT_EVERY,
        }
    }
}
/// Drives mini-batch training of a `SimpleMLP` according to a fixed configuration.
pub struct SimpleCPUTrainer {
/// Settings (epochs, batch size, learning rate, ...) used by every `train` call.
config: SimpleCPUTrainingConfig,
}
impl SimpleCPUTrainer {
pub fn new(config: SimpleCPUTrainingConfig) -> Self {
Self { config }
}
pub fn train(
&self,
model: &mut SimpleMLP,
train_x: &Array2<f64>,
train_y: &Array2<f64>,
val_x: Option<&Array2<f64>>,
val_y: Option<&Array2<f64>>,
) -> Result<TrainingMetrics> {
let mut best_val_loss = f64::INFINITY;
let mut patience_counter = 0;
let mut train_losses = Vec::new();
let mut val_losses = Vec::new();
println!("Starting training...");
println!(
"Epochs: {}, Batch size: {}, Learning rate: {}",
self.config.epochs, self.config.batch_size, self.config.learning_rate
);
for epoch in 0..self.config.epochs {
let mut epoch_loss = 0.0;
let num_batches =
(train_x.nrows() + self.config.batch_size - 1) / self.config.batch_size;
for batch_idx in 0..num_batches {
let start = batch_idx * self.config.batch_size;
let end = (start + self.config.batch_size).min(train_x.nrows());
let batch_x = train_x.slice(ndarray::s![start..end, ..]).to_owned();
let batch_y = train_y.slice(ndarray::s![start..end, ..]).to_owned();
let loss = model.train_step(&batch_x, &batch_y, self.config.learning_rate);
epoch_loss += loss;
}
epoch_loss /= num_batches as f64;
train_losses.push(epoch_loss);
let val_loss = if let (Some(vx), Some(vy)) = (val_x, val_y) {
let pred = model.predict(vx);
let loss = (&pred - vy).mapv(|v| v * v).mean().unwrap();
val_losses.push(loss);
loss
} else {
epoch_loss
};
if (epoch + 1) % self.config.print_every == 0 {
println!(
"Epoch {}/{}: train_loss={:.6}, val_loss={:.6}",
epoch + 1,
self.config.epochs,
epoch_loss,
val_loss
);
}
if val_loss < best_val_loss {
best_val_loss = val_loss;
patience_counter = 0;
} else {
patience_counter += 1;
if patience_counter >= self.config.early_stopping_patience {
println!("Early stopping at epoch {}", epoch + 1);
break;
}
}
}
Ok(TrainingMetrics {
train_loss: *train_losses.last().unwrap_or(&0.0),
val_loss: val_losses.last().copied(),
epoch: train_losses.len(),
learning_rate: self.config.learning_rate,
epoch_time_seconds: 0.0,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `rows x cols` matrix whose entries come from `rand::random::<f64>()`.
    fn random_matrix(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |_| rand::random::<f64>())
    }

    /// A freshly built model has weight matrices of the requested shapes.
    #[test]
    fn test_simple_mlp_creation() {
        let net = SimpleMLP::new(10, 20, 5);
        assert_eq!(net.w1.dim(), (10, 20));
        assert_eq!(net.w2.dim(), (20, 5));
    }

    /// Predicting on a batch of four samples yields one output row per sample.
    #[test]
    fn test_simple_mlp_forward() {
        let net = SimpleMLP::new(10, 20, 5);
        let inputs = random_matrix(4, 10);
        let prediction = net.predict(&inputs);
        assert_eq!(prediction.dim(), (4, 5));
    }

    /// Two consecutive training steps on the same batch both report a finite loss.
    #[test]
    fn test_train_step() {
        let mut net = SimpleMLP::new(10, 20, 5);
        let inputs = random_matrix(4, 10);
        let targets = random_matrix(4, 5);
        let first = net.train_step(&inputs, &targets, 0.01);
        let second = net.train_step(&inputs, &targets, 0.01);
        for loss in &[first, second] {
            assert!(loss.is_finite());
        }
    }
}