1use elara_log::prelude::*;
2use elara_math::prelude::*;
3use std::time::Instant;
4
/// Number of training epochs; the loop in `main` runs epochs `0` through `EPOCHS` inclusive.
const EPOCHS: usize = 10_000;
/// Learning rate (step size) handed to each parameter's `update` call during training.
const LR: f64 = 0.000_01;
7
8fn forward_pass(data: &Tensor, weights: &Tensor, biases: &Tensor) -> Tensor {
9 (&data.matmul(&weights) + biases).relu()
10}
11
12fn main() {
13 Logger::new().init().unwrap();
15
16 #[rustfmt::skip]
17 let train_data = tensor![
18 [0.0, 0.0, 1.0],
19 [1.0, 1.0, 1.0],
20 [1.0, 0.0, 1.0],
21 [0.0, 1.0, 1.0]];
22 #[rustfmt::skip]
23 let train_labels = tensor![
24 [0.0],
25 [1.0],
26 [1.0],
27 [0.0]
28 ].reshape([4, 1]);
29 let weights = Tensor::rand([3, 1]);
30 let biases = Tensor::rand([4, 1]);
31 println!("Weights before training:\n{:?}", weights);
32 let now = Instant::now();
33 for epoch in 0..(EPOCHS + 1) {
34 let output = forward_pass(&train_data, &weights, &biases);
35 let loss = elara_math::mse(&output, &train_labels);
36 println!("Epoch {}, loss {:?}", epoch, loss);
37 loss.backward();
38 weights.update(LR);
39 weights.zero_grad();
40 biases.update(LR);
41 biases.zero_grad();
42 }
43 println!("{:?}", now.elapsed());
44 let pred_data = tensor![[1.0, 0.0, 0.0]];
45 let pred = forward_pass(&pred_data, &weights, &biases);
46 println!("Weights after training:\n{:?}", weights);
47 println!("Prediction [1, 0, 0] -> {:?}", pred);
48}