tiny_nn/
tiny_nn.rs

1use elara_log::prelude::*;
2use elara_math::prelude::*;
3use std::time::Instant;
4
/// Training-loop length: `main` runs epochs `0` through `EPOCHS` inclusive,
/// i.e. `EPOCHS + 1` forward/backward/update iterations.
const EPOCHS: usize = 10_000;
/// Step size handed to `update` on both `weights` and `biases` after every
/// backward pass.
const LR: f64 = 1.0e-5;
8fn forward_pass(data: &Tensor, weights: &Tensor, biases: &Tensor) -> Tensor {
9    (&data.matmul(&weights) + biases).relu()
10}
11
12fn main() {
13    // Initialize logging library
14    Logger::new().init().unwrap();
15
16    #[rustfmt::skip]
17    let train_data = tensor![
18        [0.0, 0.0, 1.0],
19        [1.0, 1.0, 1.0],
20        [1.0, 0.0, 1.0],
21        [0.0, 1.0, 1.0]];
22    #[rustfmt::skip]
23    let train_labels = tensor![
24        [0.0],
25        [1.0],
26        [1.0],
27        [0.0]
28    ].reshape([4, 1]);
29    let weights = Tensor::rand([3, 1]);
30    let biases = Tensor::rand([4, 1]);
31    println!("Weights before training:\n{:?}", weights);
32    let now = Instant::now();
33    for epoch in 0..(EPOCHS + 1) {
34        let output = forward_pass(&train_data, &weights, &biases);
35        let loss = elara_math::mse(&output, &train_labels);
36        println!("Epoch {}, loss {:?}", epoch, loss);
37        loss.backward();
38        weights.update(LR);
39        weights.zero_grad();
40        biases.update(LR);
41        biases.zero_grad();
42    }
43    println!("{:?}", now.elapsed());
44    let pred_data = tensor![[1.0, 0.0, 0.0]];
45    let pred = forward_pass(&pred_data, &weights, &biases);
46    println!("Weights after training:\n{:?}", weights);
47    println!("Prediction [1, 0, 0] -> {:?}", pred);
48}