#![no_std]
use core::marker::PhantomData;
use activator::Activator;
use na::{SMatrix, SVector};
pub use nalgebra as na;
pub mod activator;
/// Convenience re-exports for typical users: all activators, the network
/// type itself, and `SVector` for constructing inputs and targets.
pub mod prelude {
    pub use crate::activator::*;
    pub use crate::FeedForward;
    pub use nalgebra::SVector;
}
/// A two-layer (single hidden layer) feed-forward neural network with `f64`
/// parameters, sized at compile time via const generics and usable in
/// `no_std` environments.
///
/// * `A` — the activation function applied after both layers (zero-sized
///   marker type; never stored).
/// * `INPUTS` / `HIDDEN` / `OUTPUT` — layer widths; `OUTPUT` defaults to 1.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct FeedForward<
    A: Activator,
    const INPUTS: usize,
    const HIDDEN: usize,
    const OUTPUT: usize = 1,
> {
    /// Input-to-hidden weights, one row per hidden neuron (HIDDEN x INPUTS).
    hidden_weights: SMatrix<f64, HIDDEN, INPUTS>,
    /// Hidden-to-output weights, one row per output neuron (OUTPUT x HIDDEN).
    output_weights: SMatrix<f64, OUTPUT, HIDDEN>,
    /// Bias added to each hidden neuron before activation.
    hidden_bias: SVector<f64, HIDDEN>,
    /// Bias added to each output neuron before activation.
    output_bias: SVector<f64, OUTPUT>,
    /// Ties the activator type parameter to the struct without storing it.
    _phantom: PhantomData<A>,
}
impl<A: Activator, const INPUTS: usize, const HIDDEN: usize, const OUTPUT: usize> Default
    for FeedForward<A, INPUTS, HIDDEN, OUTPUT>
{
    /// Equivalent to [`FeedForward::new`]: deterministic, hash-seeded
    /// weights and biases.
    fn default() -> Self {
        Self::new()
    }
}
impl<A: Activator, const INPUTS: usize, const HIDDEN: usize, const OUTPUT: usize>
    FeedForward<A, INPUTS, HIDDEN, OUTPUT>
{
    /// Creates a network with deterministic, position-derived parameters.
    ///
    /// Every weight and bias is seeded from its (row, column) position via
    /// `simple_hash`, so construction needs no RNG (useful under `no_std`)
    /// and two networks with identical dimensions start out identical.
    pub fn new() -> Self {
        Self {
            hidden_weights: SMatrix::<f64, HIDDEN, INPUTS>::from_fn(|i, j| simple_hash(i, j)),
            // Offset the row index by HIDDEN so the output layer does not
            // simply repeat the hidden layer's values.
            output_weights: SMatrix::<f64, OUTPUT, HIDDEN>::from_fn(|i, j| {
                simple_hash(i + HIDDEN, j)
            }),
            // Each bias reuses the column-0 hash of the matching weight row.
            hidden_bias: SVector::<f64, HIDDEN>::from_fn(|i, _| simple_hash(i, 0)),
            output_bias: SVector::<f64, OUTPUT>::from_fn(|i, _| simple_hash(i + HIDDEN, 0)),
            _phantom: PhantomData,
        }
    }

    /// Runs a full forward pass: input -> hidden layer -> output layer,
    /// applying the activation `A` after each layer.
    pub fn forward(&self, input: &SVector<f64, INPUTS>) -> SVector<f64, OUTPUT> {
        let (_hidden, _hidden_activated, output) = self.propagate(input);
        output.map(A::activate)
    }

    /// Performs one step of online gradient descent (the classic delta rule)
    /// on a single sample, updating all weights and biases in place.
    ///
    /// * `input` — the sample's feature vector.
    /// * `target` — the desired network output for `input`.
    /// * `learning_rate` — step size applied to every parameter update.
    ///
    /// NOTE(review): `A::derivative` is evaluated on the pre-activation
    /// sums here — confirm the `Activator` contract expects the raw sum
    /// rather than the activated output.
    pub fn train(
        &mut self,
        input: &SVector<f64, INPUTS>,
        target: &SVector<f64, OUTPUT>,
        learning_rate: f64,
    ) {
        let (hidden, hidden_activated, output) = self.propagate(input);
        let output_activated = output.map(A::activate);
        // Output-layer delta: (target - prediction) elementwise-multiplied
        // by the activation derivative at the pre-activation values.
        let output_error = target - output_activated;
        let output_delta = output_error.component_mul(&output.map(A::derivative));
        // Backpropagate the output delta through the output weights.
        let hidden_error = self.output_weights.transpose() * output_delta;
        let hidden_delta = hidden_error.component_mul(&hidden.map(A::derivative));
        // Apply the updates; `+=` because the error is target - prediction.
        self.output_weights += learning_rate * (output_delta * hidden_activated.transpose());
        self.output_bias += learning_rate * output_delta;
        self.hidden_weights += learning_rate * (hidden_delta * input.transpose());
        self.hidden_bias += learning_rate * hidden_delta;
    }

    /// Shared forward computation used by both `forward` and `train`.
    /// Returns `(hidden pre-activation, hidden activation, output
    /// pre-activation)` so `train` can reuse the intermediates.
    fn propagate(
        &self,
        input: &SVector<f64, INPUTS>,
    ) -> (
        SVector<f64, HIDDEN>,
        SVector<f64, HIDDEN>,
        SVector<f64, OUTPUT>,
    ) {
        let hidden = self.hidden_weights * input + self.hidden_bias;
        let hidden_activated = hidden.map(A::activate);
        let output = self.output_weights * hidden_activated + self.output_bias;
        (hidden, hidden_activated, output)
    }
}
/// Deterministic pseudo-random value in `[-0.5, 0.5)` derived from two
/// indices.
///
/// Mixes the indices with a small multiplicative hash, folds the result into
/// `[0, 100)` and rescales — a cheap, RNG-free initializer for `no_std`
/// targets. Identical inputs always yield identical outputs.
fn simple_hash(x: usize, y: usize) -> f64 {
    let mixed = x.wrapping_mul(31).wrapping_add(y) as f64;
    let bucket = mixed % 100.0;
    bucket / 100.0 - 0.5
}
#[cfg(test)]
mod tests {
    use super::prelude::*;

    /// XOR is not linearly separable, so converging here shows the hidden
    /// layer is genuinely used (a single linear layer cannot fit XOR).
    #[test]
    fn test_binary_classification_xor() {
        let mut nn = FeedForward::<Sigmoid, 2, 4, 1>::new();
        // The full XOR truth table: (inputs, expected output).
        let training_data = [
            ([0.0, 0.0], [0.0]),
            ([0.0, 1.0], [1.0]),
            ([1.0, 0.0], [1.0]),
            ([1.0, 1.0], [0.0]),
        ];
        // 10_000 epochs of per-sample (online) updates.
        for _ in 0..10_000 {
            for (input, target) in &training_data {
                let input = SVector::<f64, 2>::from_column_slice(input);
                let target = SVector::<f64, 1>::from_column_slice(target);
                nn.train(&input, &target, 0.1);
            }
        }
        for (input, expected) in &training_data {
            let input = SVector::<f64, 2>::from_column_slice(input);
            let output = nn.forward(&input);
            // Loose tolerance: the output only needs to land on the correct
            // side of 0.5 with some margin.
            assert!((output[0] - expected[0]).abs() < 0.2);
        }
    }

    /// Fits eight samples of one sine period (targets look like sin(pi*x))
    /// and checks the fit at the crest.
    ///
    /// NOTE(review): with a sigmoid output layer the negative targets are
    /// presumably unreachable (sigmoid range is (0, 1)); only the crest at
    /// x = 0.5 is asserted, so the test still holds — confirm intent.
    #[test]
    fn test_regression_sine_wave() {
        let mut nn = FeedForward::<Sigmoid, 1, 8, 1>::new();
        let training_data: [(f64, f64); 8] = [
            (0.0, 0.0),
            (0.25, 0.707),
            (0.5, 1.0),
            (0.75, 0.707),
            (1.0, 0.0),
            (1.25, -0.707),
            (1.5, -1.0),
            (1.75, -0.707),
        ];
        for _ in 0..10_000 {
            for &(x, y) in &training_data {
                let input = SVector::<f64, 1>::from_column_slice(&[x]);
                let target = SVector::<f64, 1>::from_column_slice(&[y]);
                nn.train(&input, &target, 0.05);
            }
        }
        // Probe the peak of the wave, which is also a training point.
        let test_x = 0.5;
        let input = SVector::<f64, 1>::from_column_slice(&[test_x]);
        let output = nn.forward(&input);
        assert!((output[0] - 1.0).abs() < 0.2);
    }

    /// Distinguishes two flattened 3x3 pixel patterns (an "X" and an "O"):
    /// the X pattern should map near 1.0, the O pattern near 0.0.
    #[test]
    fn test_pattern_recognition() {
        let mut nn = FeedForward::<Sigmoid, 9, 5, 1>::new();
        let x_pattern = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        let o_pattern = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        // Alternate the two classes on every iteration.
        for _ in 0..1000 {
            let input = SVector::<f64, 9>::from_column_slice(&x_pattern);
            let target = SVector::<f64, 1>::from_column_slice(&[1.0]);
            nn.train(&input, &target, 0.1);
            let input = SVector::<f64, 9>::from_column_slice(&o_pattern);
            let target = SVector::<f64, 1>::from_column_slice(&[0.0]);
            nn.train(&input, &target, 0.1);
        }
        let input = SVector::<f64, 9>::from_column_slice(&x_pattern);
        let output = nn.forward(&input);
        assert!(output[0] > 0.8);
        let input = SVector::<f64, 9>::from_column_slice(&o_pattern);
        let output = nn.forward(&input);
        assert!(output[0] < 0.2);
    }

    /// `forward` must be pure: repeated calls on the same input agree
    /// exactly, and a small input perturbation gives a small output change.
    #[test]
    fn test_network_stability() {
        let nn = FeedForward::<Sigmoid, 3, 4, 2>::new();
        let input = SVector::<f64, 3>::from_column_slice(&[0.5, 0.5, 0.5]);
        let first_output = nn.forward(&input);
        let second_output = nn.forward(&input);
        // Bitwise equality is expected: no interior mutability, no RNG.
        assert_eq!(first_output, second_output);
        let perturbed_input = SVector::<f64, 3>::from_column_slice(&[0.51, 0.5, 0.5]);
        let perturbed_output = nn.forward(&perturbed_input);
        // A 0.01 nudge on one input should move the output by well under 0.1.
        assert!((perturbed_output[0] - first_output[0]).abs() < 0.1);
    }

    /// Repeated training on a single (input, target) pair must strictly
    /// reduce the absolute error on that pair.
    #[test]
    fn test_learning_convergence() {
        let mut nn = FeedForward::<Sigmoid, 1, 3, 1>::new();
        let input = SVector::<f64, 1>::from_column_slice(&[0.5]);
        let target = SVector::<f64, 1>::from_column_slice(&[1.0]);
        let initial_error = (nn.forward(&input)[0] - target[0]).abs();
        for _ in 0..1000 {
            nn.train(&input, &target, 0.1);
        }
        let final_error = (nn.forward(&input)[0] - target[0]).abs();
        assert!(final_error < initial_error);
    }
}