1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
pub mod activations;
pub mod layers;
/// A network that maps an input array of `I` values to an output array of `O` values.
///
/// Both sizes are const generics, so the input and output lengths are checked by the
/// compiler through the array types rather than at runtime.
pub trait NeuralNetwork<const I: usize, const O: usize> {
    /// Evaluates the network on `input` and returns its `O` output values.
    fn forward(&self, input: [f32; I]) -> [f32; O];
}
#[cfg(test)]
mod tests {
    use super::NeuralNetwork;
    use crate::activations::ReLu;
    use crate::layers::Dense;
    use mushin_derive::NeuralNetwork;
    use rand::{distributions::Uniform, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    // Three stacked `Dense` layers, all using the `ReLu` activation type.
    // NOTE(review): the field declaration order is kept as-is on purpose — the
    // `NeuralNetwork` derive presumably chains layers in this order; confirm
    // against `mushin_derive` before reordering.
    #[derive(NeuralNetwork)]
    struct TestNetwork {
        input: Dense<ReLu, 2, 3>,
        hidden: Dense<ReLu, 3, 3>,
        output: Dense<ReLu, 3, 1>,
    }

    /// Forward pass over a network whose weights are drawn from a fixed-seed RNG,
    /// so the asserted output is reproducible across runs.
    #[test]
    fn network_forward() {
        // An all-zero 32-byte seed: identical to what `Default::default()` produces.
        let mut seeded = ChaCha8Rng::from_seed([0; 32]);
        let weights = Uniform::from(-1.0..=1.0);

        // Struct-literal fields are evaluated top to bottom, which fixes the order in
        // which each layer consumes the RNG stream — keep this order unchanged.
        let network = TestNetwork {
            input: Dense::random(&mut seeded, &weights),
            hidden: Dense::random(&mut seeded, &weights),
            output: Dense::random(&mut seeded, &weights),
        };

        let sample = [1.0, 1.0];
        let prediction = network.forward(sample);
        approx::assert_relative_eq!(prediction[..], [0.0]);
    }
}