use std::ops::{AddAssign, SubAssign};
use matrix_kit::dynamic::matrix::Matrix;
use rand_distr::Distribution;
use crate::{math::activation::AFI, models::neuralnet::NeuralNet};
use crate::math::loss::LFI;
use super::dataset::{DataItem, DataSet};
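
/// Trains a [`NeuralNet`] with minibatch stochastic gradient descent,
/// tracking cost and accuracy against a held-out testing set.
///
/// ```ignore
/// // Sketch (mirrors the tests at the bottom of this file); assumes the
/// // MNIST data sets have already been loaded into `train_ds`/`test_ds`.
/// let trainer = SGDTrainer::new(train_ds, test_ds, LFI::Squared);
/// let mut net = trainer.random_network(vec![784, 16, 10], vec![AFI::Sigmoid, AFI::Sigmoid]);
/// trainer.train_sgd(&mut net, 0.05, 100, 32);
/// ```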
pub struct SGDTrainer<DI: DataItem> {
    pub training_data_set: DataSet<DI>,
    pub testing_data_set: DataSet<DI>,
    pub loss_function: LFI,
}
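
/// The gradient of a loss function with respect to every weight and bias in
/// a network, stored as a [`NeuralNet`] whose parameters are the partial
/// derivatives.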
#[derive(Clone)]
pub struct NNGradient {
    pub derivatives: NeuralNet,
}
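
/// A descent step: `net -= gradient` subtracts each derivative matrix from
/// the corresponding weight and bias matrix.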
impl SubAssign<NNGradient> for NeuralNet {
    fn sub_assign(&mut self, rhs: NNGradient) {
        for layer in 0..self.weights.len() {
            self.weights[layer] -= rhs.derivatives.weights[layer].clone();
            self.biases[layer] -= rhs.derivatives.biases[layer].clone();
        }
    }
}
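
/// Gradients accumulate entrywise, so the minibatch gradient is the sum of
/// the per-item gradients.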
impl AddAssign for NNGradient {
    fn add_assign(&mut self, rhs: Self) {
        for layer in 0..self.derivatives.weights.len() {
            self.derivatives.weights[layer] += rhs.derivatives.weights[layer].clone();
            self.derivatives.biases[layer] += rhs.derivatives.biases[layer].clone();
        }
    }
}
impl NNGradient {
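    /// Builds a zero-valued gradient with the same shape (layer sizes and
    /// activation functions) as the given network.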
    pub fn from_nn_shape(neuralnet: NeuralNet) -> NNGradient {
        let mut grad = NNGradient { derivatives: neuralnet };
        for layer in 0..grad.derivatives.weights.len() {
            let rows = grad.derivatives.weights[layer].row_count();
            let cols = grad.derivatives.weights[layer].col_count();
            grad.derivatives.weights[layer] = Matrix::from_index_def(rows, cols, &mut |_, _| 0.0);
            grad.derivatives.biases[layer] = Matrix::from_index_def(rows, 1, &mut |_, _| 0.0);
        }
        grad
    }
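
    /// The Euclidean (Frobenius) norm of the full gradient vector.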
    pub fn norm(&self) -> f64 {
        let mut norm_squared = 0.0;
        for layer in 0..self.derivatives.weights.len() {
            // Include both weight and bias derivatives so the norm covers
            // the entire gradient, matching the scaling done in `set_length`.
            norm_squared += self.derivatives.weights[layer].l2_norm_squared();
            norm_squared += self.derivatives.biases[layer].l2_norm_squared();
        }
        norm_squared.sqrt()
    }
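
    /// Rescales the gradient to the given Euclidean length, leaving its
    /// direction unchanged.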
    pub fn set_length(&mut self, length: f64) {
        let norm = self.norm();
        // A zero gradient cannot be rescaled; bail out rather than divide by zero.
        if norm == 0.0 {
            return;
        }
        for layer in 0..self.derivatives.weights.len() {
            self.derivatives.weights[layer] /= norm;
            self.derivatives.weights[layer] *= length;
            self.derivatives.biases[layer] /= norm;
            self.derivatives.biases[layer] *= length;
        }
    }
}
impl<DI: DataItem> SGDTrainer<DI> {
    pub fn new(training_data_set: DataSet<DI>, testing_data_set: DataSet<DI>, loss_function: LFI) -> SGDTrainer<DI> {
        SGDTrainer { training_data_set, testing_data_set, loss_function }
    }
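
    /// Computes the gradient of the loss on a single training item by
    /// backpropagation. Writing δ_l = σ'(z_l) ⊙ ∂C/∂a_l for the error at
    /// layer l, the parameter gradients are ∂C/∂W_l = δ_l a_{l-1}ᵀ and
    /// ∂C/∂b_l = δ_l, and the error flows backwards via
    /// ∂C/∂a_{l-1} = W_lᵀ δ_l.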
    pub fn compute_gradient(&self, training_item: DI, neuralnet: &NeuralNet) -> NNGradient {
        let mut gradient = NNGradient { derivatives: neuralnet.clone() };
        let layers = neuralnet.layer_count() - 1;

        // Forward pass: z[l] are the pre-activations and a[l] the
        // activations, with a[0] the input.
        let (z, a) = neuralnet.compute_raw_and_full_layers(training_item.input());

        // dot_sigma_z[l - 1] = sigma'(z[l]) for l = 1..=layers.
        let dot_sigma_z: Vec<Matrix<f64>> = (1..=layers)
            .map(|l| z[l].applying_to_all(&|x| neuralnet.activation_functions[l - 1].derivative(x)))
            .collect();

        // Backward pass, seeded with the loss derivative at the output layer.
        let mut gradient_wrt_activations = a.clone();
        gradient_wrt_activations[layers] =
            self.loss_function.derivative(&a[layers], &training_item.correct_output());

        for layer in (0..layers).rev() {
            // delta = sigma'(z) ⊙ dC/da at this layer; computed once and
            // reused, since it is both the bias gradient and the factor in
            // the weight gradient.
            let delta = dot_sigma_z[layer].hadamard(gradient_wrt_activations[layer + 1].clone());
            gradient_wrt_activations[layer] = neuralnet.weights[layer].transpose().clone() * delta.clone();
            gradient.derivatives.weights[layer] = delta.clone() * a[layer].transpose();
            gradient.derivatives.biases[layer] = delta;
        }
        gradient
    }
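
    /// One minibatch step: sums the per-item gradients, rescales the sum to
    /// have length `learning_rate` (normalized gradient descent), and applies
    /// it to the network. Returns the gradient's norm before rescaling.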
    pub fn sgd_batch_step(&self, batch: Vec<DI>, neuralnet: &mut NeuralNet, learning_rate: f64) -> f64 {
        let mut gradient = NNGradient::from_nn_shape(neuralnet.clone());
        for item in batch {
            gradient += self.compute_gradient(item, neuralnet);
        }
        let original_length = gradient.norm();
        gradient.set_length(learning_rate);
        *neuralnet -= gradient;
        original_length
    }
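
    /// Runs minibatch SGD for the given number of epochs, printing the
    /// testing cost every 20 epochs.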
    pub fn train_sgd(&self, neuralnet: &mut NeuralNet, learning_rate: f64, epochs: usize, batch_size: usize) {
        for epoch in 1..=epochs {
            println!("Training Epoch {}...", epoch);
            for batch in self.training_data_set.all_minibatches(batch_size) {
                self.sgd_batch_step(batch, neuralnet, learning_rate);
            }
            if epoch % 20 == 0 {
                println!("Current cost: {}", self.cost(neuralnet));
            }
        }
        println!("Completed all epochs of training.");
    }
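
    /// Builds a network of the given shape with weights and biases drawn
    /// independently from a standard normal distribution.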
    pub fn random_network(&self, shape: Vec<usize>, activation_functions: Vec<AFI>) -> NeuralNet {
        let mut rand_gen = rand::rng();
        let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
        let weights = (1..shape.len())
            .map(|layer| {
                Matrix::from_index_def(shape[layer], shape[layer - 1], &mut |_, _| {
                    normal.sample(&mut rand_gen)
                })
            })
            .collect();
        let biases = (1..shape.len())
            .map(|layer| {
                Matrix::from_index_def(shape[layer], 1, &mut |_, _| normal.sample(&mut rand_gen))
            })
            .collect();
        NeuralNet::new(weights, biases, activation_functions)
    }
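
    /// The average loss over the testing set.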
    pub fn cost(&self, network: &NeuralNet) -> f64 {
        let mut total_loss = 0.0;
        let ds = &self.testing_data_set;
        // Iterate by reference to avoid cloning the entire data set.
        for item in &ds.data_items {
            let (x, y) = (item.input(), item.correct_output());
            let a = network.compute_final_layer(x);
            total_loss += self.loss_function.loss(&a, &y);
        }
        total_loss / (ds.data_items.len() as f64)
    }
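
    /// The fraction of testing items the network classifies correctly.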
    pub fn accuracy(&self, network: &NeuralNet) -> f64 {
        let mut num_correct = 0;
        for item in &self.testing_data_set.data_items {
            let (guess, _) = network.classify(item.input());
            if guess == item.label() {
                num_correct += 1;
            }
        }
        (num_correct as f64) / (self.testing_data_set.data_items.len() as f64)
    }
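
    /// Prints the network's output on a random sample of testing items,
    /// followed by the final cost and classification accuracy.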
    pub fn display_behavior(&self, network: &NeuralNet, num_items: usize) {
        println!("Displaying network performance on {} testing items", num_items);
        for item in self.testing_data_set.random_sample(num_items) {
            println!("--- Testing Label: {} ---", item.name());
            println!("{:?}", item);
            println!("Network output: {:?}", network.classify(item.input()));
        }
        println!("--------------------");
        println!("Final cost: {}", self.cost(network));
        println!("Classification accuracy: {}", self.accuracy(network));
    }
}
#[cfg(test)]
mod sgd_tests {
    use std::fs::File;
    use crate::{math::{activation::AFI, LFI}, utility::mnist::mnist_utility::load_mnist};
    use super::SGDTrainer;
    #[test]
    fn test_basic_digits_sgd() {
        println!("Loading data");
        let testing_ds = load_mnist("digits", "t10k");
        println!("Loaded testing data");
        let dataset = load_mnist("digits", "train");
        println!("Loaded training data");
        let trainer = SGDTrainer::new(dataset, testing_ds, LFI::Squared);
        println!("Created trainer");
        let mut neuralnet = trainer.random_network(vec![784, 16, 16, 10], vec![AFI::Sigmoid, AFI::Sigmoid, AFI::Sigmoid]);
        let learning_rate = 0.05;
        let epochs = 100;
        let original_cost = trainer.cost(&neuralnet);
        println!("Original cost: {}", original_cost);
        trainer.train_sgd(&mut neuralnet, learning_rate, epochs, 32);
        let final_cost = trainer.cost(&neuralnet);
        println!("Final cost: {}", final_cost);
    }
    #[test]
    fn test_basic_fashion_sgd() {
        let dataset = load_mnist("fashion", "train");
        let testing_ds = load_mnist("fashion", "t10k");
        let trainer = SGDTrainer::new(dataset, testing_ds, LFI::Squared);
        let mut neuralnet = trainer.random_network(vec![784, 16, 16, 10], vec![AFI::Sigmoid, AFI::Sigmoid, AFI::Sigmoid]);
        let learning_rate = 0.05;
        let epochs = 100;
        let original_cost = trainer.cost(&neuralnet);
        println!("Original cost: {}", original_cost);
        trainer.train_sgd(&mut neuralnet, learning_rate, epochs, 32);
        let final_cost = trainer.cost(&neuralnet);
        println!("Final cost: {}", final_cost);
        trainer.display_behavior(&neuralnet, 10);
        println!("Writing final network to testing folder.");
        match File::create("testing/files/fashion_nn.mlk_nn") {
            Ok(mut f) => neuralnet.write_to_file(&mut f),
            Err(e) => println!("Error writing to file: {:?}", e),
        }
    }
}