use crate::dataset::Dataset;
use crate::utils::*;
use nalgebra::DMatrix;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{fs, marker::PhantomData, path::Path};
/// A feed-forward neural network whose activation function is selected at
/// compile time through the type parameter `A`.
///
/// The whole struct is (de)serializable so a trained network can be written to
/// and restored from disk.
#[derive(Serialize, Deserialize)]
pub struct NeuralNet<A: Activation> {
/// Node values of every layer, input layer first. `new` creates one
/// `count x 1` zero matrix per requested layer; `guess` overwrites them.
layers: Vec<DMatrix<f64>>,
/// `weights[i]` is multiplied with `layers[i]` to produce `layers[i + 1]`,
/// so there is one entry fewer than there are layers.
weights: Vec<DMatrix<f64>>,
/// `biases[i]` is added to the weighted sum that becomes `layers[i + 1]`.
biases: Vec<DMatrix<f64>>,
/// Scratch storage used during backpropagation: `errors[i]` holds the error
/// attributed to `layers[i + 1]`.
errors: Vec<DMatrix<f64>>,
/// Zero-sized marker tying the network to its activation type; `A` is only
/// ever used through its associated functions.
activation: PhantomData<A>,
}
impl<A: Activation + Serialize + DeserializeOwned> NeuralNet<A> {
    /// Creates a network with one layer per entry of `node_counts`.
    ///
    /// `node_counts[0]` is the size of the input layer and the last entry is
    /// the size of the output layer. Layer values and error buffers start at
    /// zero; weights and biases are filled by `gen_random_matrix`.
    ///
    /// # Panics
    ///
    /// Panics if fewer than two layer sizes are supplied.
    pub fn new(node_counts: &[usize]) -> Self {
        let num_layers = node_counts.len();
        assert!(
            num_layers >= 2,
            "not enough layers supplied (expected at least 2, found {})",
            num_layers
        );

        // Everything except `layers` has one entry per *non-input* layer.
        let non_input_counts = &node_counts[1..];
        Self {
            layers: node_counts
                .iter()
                .map(|&count| DMatrix::zeros(count, 1))
                .collect(),
            // Each adjacent pair (previous, next) gets a `next x previous` matrix
            // so that `weights[i] * layers[i]` has the shape of `layers[i + 1]`.
            weights: node_counts
                .windows(2)
                .map(|pair| gen_random_matrix(pair[1], pair[0]))
                .collect(),
            biases: non_input_counts
                .iter()
                .map(|&count| gen_random_matrix(count, 1))
                .collect(),
            errors: non_input_counts
                .iter()
                .map(|&count| DMatrix::zeros(count, 1))
                .collect(),
            activation: PhantomData,
        }
    }

    /// Loads a network previously written by [`NeuralNet::save`].
    ///
    /// # Errors
    ///
    /// Returns [`LoadErr::FileRead`] if the file cannot be opened and
    /// [`LoadErr::Deserialize`] if bincode fails to decode its contents.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, LoadErr> {
        let file = fs::File::open(path)?;
        // bincode pulls the data in many small reads; buffering avoids issuing
        // one system call for each of them.
        let reader = std::io::BufReader::new(file);
        let decoded: Self = bincode::deserialize_from(reader)?;
        Ok(decoded)
    }

    /// Trains the network for `iterations` passes over `training_dataset`.
    ///
    /// The dataset is shuffled before every pass; each sample is fed forward
    /// and then backpropagated with the given `learning_rate`. A progress bar
    /// is shown while training and cleared when it finishes.
    pub fn train(&mut self, mut training_dataset: Dataset, iterations: u64, learning_rate: f64) {
        let progress_bar = indicatif::ProgressBar::new(iterations);
        progress_bar.set_style(
            indicatif::ProgressStyle::default_bar()
                .template("Training [{bar:30}] {percent:>3}% ETA: {eta}")
                .progress_chars("=> "),
        );

        // Advance the bar roughly once per percent. The step is clamped to 1 so
        // that runs with fewer than 100 iterations do not compute `i % 0`.
        let percentile = (iterations / 100).max(1);

        // Inclusive upper bound: exactly `iterations` passes are performed.
        for i in 1..=iterations {
            training_dataset.shuffle();
            for (inputs, targets) in &training_dataset {
                let guesses = self.guess(inputs);
                self.backpropagate(&guesses, targets, learning_rate);
            }
            if i % percentile == 0 {
                progress_bar.inc(percentile);
            }
        }
        progress_bar.finish_and_clear();
    }

    /// Evaluates the network and returns its mean absolute error.
    ///
    /// For every sample the absolute differences between outputs and targets
    /// are averaged over the number of outputs; those per-sample costs are then
    /// summed and divided by `testing_dataset.rows()`.
    ///
    /// NOTE(review): if `rows()` is 0 the final division is by zero and the
    /// result is not a finite number — callers should pass a non-empty dataset.
    pub fn test(&mut self, testing_dataset: Dataset) -> f64 {
        let mut avg_cost = 0.0;
        for (inputs, targets) in &testing_dataset {
            let guesses = self.guess(inputs);
            let cost_sum: f64 = guesses
                .iter()
                .zip(targets)
                .map(|(guess, target)| (target - guess).abs())
                .sum();
            avg_cost += cost_sum / guesses.len() as f64;
        }
        avg_cost /= testing_dataset.rows() as f64;
        avg_cost
    }

    /// Serializes the network with bincode and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns [`SaveErr::Serialize`] if encoding fails and
    /// [`SaveErr::FileWrite`] if the bytes cannot be written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), SaveErr> {
        let encoded = bincode::serialize(self)?;
        fs::write(path, encoded)?;
        Ok(())
    }

    /// Feeds `inputs` forward through the network and returns the values of
    /// the output layer.
    ///
    /// The intermediate layer values are kept in `self.layers`, which is why
    /// this takes `&mut self`; `backpropagate` reads them afterwards.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len()` differs from the size of the input layer.
    pub fn guess(&mut self, inputs: &[f64]) -> Vec<f64> {
        let num_inputs = inputs.len();
        let num_input_layer_rows = self.layers[0].nrows();
        assert!(
            num_inputs == num_input_layer_rows,
            "incorrect number of inputs supplied (expected {}, found {})",
            num_input_layer_rows,
            num_inputs
        );

        let num_layers = self.layers.len();
        self.layers[0] = convert_slice_to_matrix(inputs);
        for i in 0..num_layers - 1 {
            // next = activate(weights * current + bias)
            let mut value = &self.weights[i] * &self.layers[i];
            value += &self.biases[i];
            for x in value.iter_mut() {
                *x = A::activate(*x);
            }
            self.layers[i + 1] = value;
        }
        self.layers[num_layers - 1].iter().copied().collect()
    }

    /// Adjusts weights and biases from the difference between `targets` and
    /// `guesses`.
    ///
    /// Must be called right after the `guess` call that produced `guesses`,
    /// because the per-layer values are read from `self.layers`.
    fn backpropagate(&mut self, guesses: &[f64], targets: &[f64], learning_rate: f64) {
        let guesses = convert_slice_to_matrix(guesses);
        let targets = convert_slice_to_matrix(targets);
        let num_layers = self.layers.len();

        // Error of the output layer; `errors` has no slot for the input layer,
        // hence the index is shifted by one relative to `layers`.
        self.errors[num_layers - 2] = targets - guesses;

        // Walk from the output layer back to the first non-input layer.
        for (i, layer) in self.layers.iter().enumerate().skip(1).rev() {
            // `layer` already holds activated values, which is what
            // `A::derivative` is applied to here.
            let mut gradients = layer.map(A::derivative);
            gradients.component_mul_assign(&self.errors[i - 1]);
            gradients *= learning_rate;

            let deltas = &gradients * self.layers[i - 1].transpose();
            self.weights[i - 1] += deltas;
            self.biases[i - 1] += gradients;

            // Hand the error down to the previous layer (the input layer has no
            // error slot). This uses the weights as just updated above.
            if i != 1 {
                self.errors[i - 2] = self.weights[i - 1].transpose() * &self.errors[i - 1];
            }
        }
    }
}
/// An activation function together with the derivative used during training.
///
/// Both operations are associated functions (no `self`), so implementors are
/// used purely as type-level markers.
pub trait Activation {
/// Maps the weighted input sum of a node to its activated value.
fn activate(x: f64) -> f64;
/// Returns the slope of the activation function.
///
/// `NeuralNet::backpropagate` calls this on the stored layer values, i.e. on
/// numbers that have already been passed through `activate`, so `x` is the
/// activated output rather than the raw weighted sum.
fn derivative(x: f64) -> f64;
}
/// Marker type selecting the logistic sigmoid as a network's activation.
///
/// It carries no data; the serde derives exist because the `NeuralNet` impl
/// requires its activation type to be serializable and deserializable.
#[derive(Serialize, Deserialize)]
pub struct Sigmoid;
impl Activation for Sigmoid {
    /// Logistic function `1 / (1 + e^(-x))`, squashing any real input into
    /// the range from 0 to 1.
    fn activate(x: f64) -> f64 {
        let denominator = 1.0 + f64::exp(-x);
        1.0 / denominator
    }

    /// Slope of the sigmoid expressed through its own output: `x` is expected
    /// to be a value already produced by `activate`, giving `x * (1 - x)`.
    fn derivative(x: f64) -> f64 {
        let complement = 1.0 - x;
        x * complement
    }
}
/// Errors returned by `NeuralNet::save`.
#[derive(thiserror::Error, Debug)]
pub enum SaveErr {
/// Encoding the network with bincode failed.
#[error("failed to serialize network")]
Serialize(#[from] bincode::Error),
/// Writing the encoded bytes to the destination path failed.
#[error("failed to write to file")]
FileWrite(#[from] std::io::Error),
}
/// Errors returned by `NeuralNet::from_file`.
#[derive(thiserror::Error, Debug)]
pub enum LoadErr {
/// bincode could not decode a network from the file. Because decoding reads
/// directly from the opened file, failures reported by bincode while reading
/// also end up in this variant.
#[error("failed to deserialize network")]
Deserialize(#[from] bincode::Error),
/// The file could not be opened.
#[error("failed to read from file")]
FileRead(#[from] std::io::Error),
}