use core::panic;
use rand::distr::{Distribution, Uniform};
use rand::rng;
// One fully-connected layer; `forward` fills `inputs`/`outputs`,
// `backpropagate` fills `error_derivative` and adjusts `weights`/`biases`.
#[derive(Clone)]
struct Layer {
    // Values fed to this layer on the last forward pass (the previous
    // layer's outputs, or the raw network inputs for the first layer).
    inputs: Vec<f32>,
    // Post-activation outputs produced on the last forward pass.
    outputs: Vec<f32>,
    // One bias per neuron (same length as `outputs`).
    biases: Vec<f32>,
    // weights[j][k] is the connection from input k to neuron j
    // (rows: neurons, columns: inputs).
    weights: Vec<Vec<f32>>,
    // Per-neuron error deltas, written during backpropagation.
    error_derivative: Vec<f32>,
}
/// A fully-connected feed-forward neural network trained by gradient descent.
#[derive(Clone)]
pub struct NeuralNetwork {
    // Layers in forward order; layer 0 receives the network inputs.
    layers: Vec<Layer>,
    // Step size applied to every weight/bias update in `backpropagate`.
    learning_rate: f32,
    // Activation applied to every neuron; boxed so it can be chosen at
    // runtime (clonable via the `ActivationClone` machinery below).
    activation: Box<dyn Activation>,
}
impl NeuralNetwork {
    /// Creates a network with one `Layer` per entry of `layer_sizes`.
    ///
    /// Weights and biases are drawn uniformly from [-1.0, 1.0). Layer 0 maps
    /// `layer_sizes[0]` inputs to `layer_sizes[0]` outputs; every later layer
    /// `i` maps `layer_sizes[i - 1]` inputs to `layer_sizes[i]` outputs.
    pub fn new(
        layer_sizes: Vec<usize>,
        learning_rate: f32,
        activation: Box<dyn Activation>,
    ) -> Self {
        let mut rng = rng();
        let between =
            Uniform::try_from(-1.0..1.0).expect("-1.0..1.0 is a valid, non-empty range");
        let layers = (0..layer_sizes.len())
            .map(|layer_index| {
                // Fan-in: the first layer reads the raw network inputs, so its
                // input width equals its own size; every later layer reads the
                // previous layer's outputs.
                let fan_in = layer_sizes[layer_index.saturating_sub(1)];
                let size = layer_sizes[layer_index];
                Layer {
                    inputs: vec![0.0; fan_in],
                    outputs: vec![0.0; size],
                    biases: (0..size).map(|_| between.sample(&mut rng)).collect(),
                    // weights[j][k]: connection from input k to neuron j.
                    weights: (0..size)
                        .map(|_| (0..fan_in).map(|_| between.sample(&mut rng)).collect())
                        .collect(),
                    error_derivative: vec![0.0; size],
                }
            })
            .collect();
        Self {
            layers,
            learning_rate,
            activation,
        }
    }

    /// Runs a forward pass and returns a copy of the output layer's
    /// activations. Each layer caches its `inputs` and `outputs` so a
    /// subsequent `backpropagate` call can use them.
    ///
    /// # Panics
    /// Panics if `inputs.len()` differs from the first layer's input width.
    pub fn forward(&mut self, inputs: &[f32]) -> Vec<f32> {
        assert_eq!(
            inputs.len(),
            self.layers[0].inputs.len(),
            "The given argument: 'inputs' in the 'forward' method must have the same length as the first 'Layer' inputs defined previously"
        );
        for layer_index in 0..self.layers.len() {
            self.layers[layer_index].inputs = match layer_index {
                0 => inputs.to_vec(),
                _ => self.layers[layer_index - 1].outputs.clone(),
            };
            for j in 0..self.layers[layer_index].outputs.len() {
                // Weighted sum of the layer's inputs plus the neuron's bias,
                // passed through the activation function.
                let mut sum = self.layers[layer_index].biases[j];
                for k in 0..self.layers[layer_index].inputs.len() {
                    sum += self.layers[layer_index].inputs[k]
                        * self.layers[layer_index].weights[j][k];
                }
                self.layers[layer_index].outputs[j] = self.activation.function(sum);
            }
        }
        self.layers.last().unwrap().outputs.clone()
    }

    /// Returns the summed squared error between the most recent forward
    /// pass's outputs and `expected`.
    ///
    /// # Panics
    /// Panics if `expected.len()` differs from the output layer's size.
    pub fn errors(&self, expected: &[f32]) -> f32 {
        let outputs = &self.layers.last().unwrap().outputs;
        assert_eq!(
            expected.len(),
            outputs.len(),
            "The given argument: 'expected' in the 'errors' method must have the same length as the last 'Layer' outputs defined previously"
        );
        outputs
            .iter()
            .zip(expected)
            .map(|(actual, expected)| (actual - expected).powi(2))
            .sum()
    }

    /// Backpropagates the error against `expected` and applies one
    /// gradient-descent step to every weight and bias.
    ///
    /// Deltas for ALL layers are computed before any weight is touched, so
    /// hidden-layer deltas use the *pre-update* weights of the following
    /// layer. (The previous implementation updated each layer's weights
    /// inside the same reversed loop, so earlier layers' gradients were
    /// computed against already-updated weights.)
    ///
    /// # Panics
    /// Panics if `expected.len()` differs from the output layer's size.
    pub fn backpropagate(&mut self, expected: &[f32]) {
        assert_eq!(
            expected.len(),
            self.layers.last().unwrap().outputs.len(),
            "The given argument: 'expected' in the 'backpropagate' method must have the same length as the last 'Layer' outputs defined previously"
        );
        // Pass 1: delta_k = dE/d(pre-activation_k) for every neuron, walking
        // from the output layer back to the first layer.
        for layer_index in (0..self.layers.len()).rev() {
            for k in 0..self.layers[layer_index].outputs.len() {
                let error: f32 = if layer_index == self.layers.len() - 1 {
                    // Output layer: derivative of the squared-error term.
                    self.layers[layer_index].outputs[k] - expected[k]
                } else {
                    // Hidden layer: fold the next layer's deltas back through
                    // its (not yet updated) weights.
                    (0..self.layers[layer_index + 1].outputs.len())
                        .map(|j| {
                            self.layers[layer_index + 1].weights[j][k]
                                * self.layers[layer_index + 1].error_derivative[j]
                        })
                        .sum()
                };
                // `derivative` receives the post-activation output value.
                self.layers[layer_index].error_derivative[k] =
                    error * self.activation.derivative(self.layers[layer_index].outputs[k]);
            }
        }
        // Pass 2: gradient-descent update of every weight and bias.
        let learning_rate = self.learning_rate;
        for layer in &mut self.layers {
            for j in 0..layer.outputs.len() {
                for k in 0..layer.inputs.len() {
                    layer.weights[j][k] -=
                        learning_rate * layer.error_derivative[j] * layer.inputs[k];
                }
                layer.biases[j] -= learning_rate * layer.error_derivative[j];
            }
        }
    }
}
/// An activation function for `NeuralNetwork`.
///
/// `ActivationClone` is a supertrait so that `Box<dyn Activation>` can
/// implement `Clone`, which in turn lets `NeuralNetwork` derive `Clone`.
pub trait Activation: ActivationClone {
    /// Applies the activation to the pre-activation value `x`.
    fn function(&self, x: f32) -> f32;
    /// The activation's derivative. NOTE: `backpropagate` calls this with the
    /// *post-activation* output, so implementations must express the
    /// derivative in terms of the output value, not the pre-activation input.
    fn derivative(&self, x: f32) -> f32;
}
/// Helper trait that makes `Box<dyn Activation>` clonable; a blanket
/// implementation covers every `Activation + Clone + 'static` type.
pub trait ActivationClone {
    /// Returns a boxed clone of `self` as an `Activation` trait object.
    fn clone_box(&self) -> Box<dyn Activation>;
}
/// Blanket implementation: every clonable, `'static` activation type gets
/// `clone_box` for free, by cloning the concrete value and re-boxing it.
impl<T: Activation + Clone + 'static> ActivationClone for T {
    fn clone_box(&self) -> Box<dyn Activation> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
impl Clone for Box<dyn Activation> {
fn clone(&self) -> Box<dyn Activation> {
self.clone_box()
}
}
/// The logistic sigmoid activation: σ(x) = 1 / (1 + e^(-x)).
///
/// BUG FIX: this type previously computed `tanh(x)` (with the matching tanh
/// derivative `1 - x²`) despite being named `Sigmoid`. It now implements the
/// logistic function its name promises.
#[derive(Clone)]
pub struct Sigmoid;

impl Activation for Sigmoid {
    fn function(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    /// Derivative expressed in terms of the sigmoid's *output* value
    /// (σ'(z) = σ(z)·(1 - σ(z)), so with o = σ(z) this is o·(1 - o)),
    /// matching how `backpropagate` calls it with post-activation outputs.
    fn derivative(&self, x: f32) -> f32 {
        x * (1.0 - x)
    }
}