pub fn mse(predicted: &Tensor, target: &Tensor) -> Tensor
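Mean squared error (MSE) loss function: computes the mean squared error between `predicted` and `target` and returns it as a `Tensor`.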
Examples found in repository:
examples/tiny_nn.rs (line 35)
12fn main() {
13 // Initialize logging library
14 Logger::new().init().unwrap();
15
16 #[rustfmt::skip]
17 let train_data = tensor![
18 [0.0, 0.0, 1.0],
19 [1.0, 1.0, 1.0],
20 [1.0, 0.0, 1.0],
21 [0.0, 1.0, 1.0]];
22 #[rustfmt::skip]
23 let train_labels = tensor![
24 [0.0],
25 [1.0],
26 [1.0],
27 [0.0]
28 ].reshape([4, 1]);
29 let weights = Tensor::rand([3, 1]);
30 let biases = Tensor::rand([4, 1]);
31 println!("Weights before training:\n{:?}", weights);
32 let now = Instant::now();
33 for epoch in 0..(EPOCHS + 1) {
34 let output = forward_pass(&train_data, &weights, &biases);
35 let loss = elara_math::mse(&output, &train_labels);
36 println!("Epoch {}, loss {:?}", epoch, loss);
37 loss.backward();
38 weights.update(LR);
39 weights.zero_grad();
40 biases.update(LR);
41 biases.zero_grad();
42 }
43 println!("{:?}", now.elapsed());
44 let pred_data = tensor![[1.0, 0.0, 0.0]];
45 let pred = forward_pass(&pred_data, &weights, &biases);
46 println!("Weights after training:\n{:?}", weights);
47 println!("Prediction [1, 0, 0] -> {:?}", pred);
48}