mse

Function mse 

Source
pub fn mse(predicted: &Tensor, target: &Tensor) -> Tensor 
Expand description

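Mean squared error (MSE) loss between `predicted` and `target`, i.e. the mean of the squared element-wise differences, $\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n} (\text{predicted}_i - \text{target}_i)^2$, returned as a `Tensor`.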

Examples found in repository:
examples/tiny_nn.rs (line 35)
12fn main() {
13    // Initialize logging library
14    Logger::new().init().unwrap();
15
16    #[rustfmt::skip]
17    let train_data = tensor![
18        [0.0, 0.0, 1.0],
19        [1.0, 1.0, 1.0],
20        [1.0, 0.0, 1.0],
21        [0.0, 1.0, 1.0]];
22    #[rustfmt::skip]
23    let train_labels = tensor![
24        [0.0],
25        [1.0],
26        [1.0],
27        [0.0]
28    ].reshape([4, 1]);
29    let weights = Tensor::rand([3, 1]);
30    let biases = Tensor::rand([4, 1]);
31    println!("Weights before training:\n{:?}", weights);
32    let now = Instant::now();
33    for epoch in 0..(EPOCHS + 1) {
34        let output = forward_pass(&train_data, &weights, &biases);
35        let loss = elara_math::mse(&output, &train_labels);
36        println!("Epoch {}, loss {:?}", epoch, loss);
37        loss.backward();
38        weights.update(LR);
39        weights.zero_grad();
40        biases.update(LR);
41        biases.zero_grad();
42    }
43    println!("{:?}", now.elapsed());
44    let pred_data = tensor![[1.0, 0.0, 0.0]];
45    let pred = forward_pass(&pred_data, &weights, &biases);
46    println!("Weights after training:\n{:?}", weights);
47    println!("Prediction [1, 0, 0] -> {:?}", pred);
48}