backpropagation/backpropagation.rs

/// This example shows how to implement a simple adaptive linear neuron using the
/// backpropagation algorithm.
///
/// We will compute the expression `x1*w1 + x2*w2 + b` and apply a non-linear activation
/// function to it: `tanh(x1*w1 + x2*w2 + b)`.
///
/// We will only allow the values of `w1`, `w2` and `b` to be learned, so that the
/// output reaches a fixed target value.
///
/// For more information, see [`Expr`].
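///
/// Training minimizes the squared error `(o - target)^2` by gradient descent:
/// each call to `learn` backpropagates gradients from the loss and adjusts
/// every learnable leaf against its gradient, scaled by the learning rate.
/// When backpropagating through the activation, the local gradient of `tanh`
/// is `1 - tanh(x)^2`.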
extern crate alpha_micrograd_rust;

use alpha_micrograd_rust::value::Expr;

fn main() {
    // the inputs x1 and x2 are fixed leaves of the graph: they carry data and
    // are explicitly excluded from learning
    let mut x1 = Expr::new_leaf(2.0, "x1");
    x1.is_learnable = false;

    let mut x2 = Expr::new_leaf(1.0, "x2");
    x2.is_learnable = false;

    // the weights and bias are left learnable so the training loop below can
    // adjust them
    let w1 = Expr::new_leaf(-3.0, "w1");
    let w2 = Expr::new_leaf(1.0, "w2");
    let b = Expr::new_leaf(6.5, "b");

    // here we compute the expression x1*w1 + x2*w2 + b
    let x1w1 = x1 * w1;
    let x2w2 = x2 * w2;
    let x1w1_x2w2 = x1w1 + x2w2;
    let n = x1w1_x2w2 + b;
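    // with the initial leaves this evaluates to 2.0*(-3.0) + 1.0*1.0 + 6.5 = 1.5.
    // note that the overloaded operators take ownership of their operands and
    // move them into the graph, which is why the leaves are looked up again by
    // name (via `find`) further down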

    // we add a non-linear activation function: tanh(x1*w1 + x2*w2 + b)
    let o = n.tanh("o");
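    // tanh squashes the pre-activation into (-1, 1): tanh(1.5) is roughly 0.91,
    // and the target of 0.2 chosen below lies comfortably inside that range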

    println!("Initial output: {:.2}", o.result);

    // we set the target value
    let target_value = 0.2;
    let mut target = Expr::new_leaf(target_value, "target");
    target.is_learnable = false;

    // the loss is the squared error (o - target)^2; the exponent is itself an
    // Expr leaf, frozen so that it is not treated as a trainable parameter
    let mut squared_exponent = Expr::new_leaf(2.0, "squared_exponent");
    squared_exponent.is_learnable = false;

    let mut loss = (o - target).pow(squared_exponent, "loss");
    loss.is_learnable = false;
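    // the squared error is zero exactly when o equals the target; with the
    // initial weights it starts at (0.91 - 0.2)^2, roughly 0.50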

    // we print the initial loss
    println!("Initial loss: {:.4}", loss.result);

    println!("\nTraining:");
    let learning_rate = 0.01;
    for i in 1..=50 {
        // backpropagate from the loss and adjust the learnable leaves, then
        // run a forward pass so that every node's result is up to date
        loss.learn(learning_rate);
        loss.recalculate();

        let output = loss.find("o").expect("Node not found");

        println!(
            "Iteration {:2}, loss: {:.4} / result: {:.2}",
            i, loss.result, output.result
        );
    }

    // the trained parameters now live inside the graph, so retrieve them by name
    let w1 = loss.find("w1").expect("Node not found");
    let w2 = loss.find("w2").expect("Node not found");
    let b = loss.find("b").expect("Node not found");

    println!(
        "\nFinal values: w1: {:.2}, w2: {:.2}, b: {:.2}",
        w1.result, w2.result, b.result
    );

    let x1 = loss.find("x1").expect("Node not found");
    let x2 = loss.find("x2").expect("Node not found");

    // intermediate nodes get auto-generated names derived from the expression
    // they compute
    let n = loss
        .find("(((x1 * w1) + (x2 * w2)) + b)") // auto-generated node name
        .expect("Node not found");
    let o = loss.find("o").expect("Node not found");

    println!(
        "Final formula: tanh({:.2}*{:.2} + {:.2}*{:.2} + {:.2}) = tanh({:.2}) = {:.2}",
        x1.result, w1.result, x2.result, w2.result, b.result, n.result, o.result
    );
}