1extern crate alpha_micrograd_rust;
11
12use alpha_micrograd_rust::nn::{Activation, MLP};
13use alpha_micrograd_rust::value::Expr;
14
15fn main() {
16 let mut targets = vec![
17 Expr::new_leaf(150.0, "t1"),
18 Expr::new_leaf(250.0, "t2"),
19 Expr::new_leaf(350.0, "t3"),
20 ];
21 targets.iter_mut().for_each(|target| {
22 target.is_learnable = false;
23 });
24
25 let mlp = MLP::new(
26 3,
27 Activation::Tanh,
28 vec![2, 2],
29 Activation::Tanh,
30 1,
31 Activation::None,
32 );
33 println!("Initial values: {:}", mlp);
34
35 let mut inputs = vec![
36 vec![
37 Expr::new_leaf(1.0, "x_1,1"),
38 Expr::new_leaf(2.0, "x_1,2"),
39 Expr::new_leaf(3.0, "x_1,3"),
40 ],
41 vec![
42 Expr::new_leaf(4.0, "x_2,1"),
43 Expr::new_leaf(5.0, "x_2,2"),
44 Expr::new_leaf(6.0, "x_2,3"),
45 ],
46 vec![
47 Expr::new_leaf(7.0, "x_3,1"),
48 Expr::new_leaf(8.0, "x_3,2"),
49 Expr::new_leaf(9.0, "x_3,3"),
50 ],
51 ];
52
53 inputs.iter_mut().for_each(|instance| {
54 instance.iter_mut().for_each(|value| {
55 value.is_learnable = false;
56 });
57 });
58
59 let predictions = inputs
60 .iter()
62 .map(|example| mlp.forward(example.clone()))
63 .enumerate()
65 .map(|(i, mut y)| {
66 let mut result = y.remove(0);
68 result.name = format!("y{:}", i + 1).to_string();
69 result
70 })
71 .collect::<Vec<_>>();
73
74 let differences = predictions
75 .iter()
76 .zip(targets.iter())
77 .map(|(y, t)| y.clone() - t.clone())
78 .collect::<Vec<_>>();
79 let mut loss = differences
80 .iter()
81 .map(|d| d.clone() * d.clone())
82 .sum::<Expr>();
83
84 let y1 = loss.find("y1").unwrap();
85 let y2 = loss.find("y2").unwrap();
86 let y3 = loss.find("y3").unwrap();
87 println!("Initial loss: {:.2}", loss.result);
88 println!(
89 "Initial predictions: {:5.2} {:5.2} {:5.2}",
90 y1.result, y2.result, y3.result
91 );
92
93 println!("\nTraining:");
94 let learning_rate = 0.025;
95 for i in 1..=100 {
96 loss.learn(learning_rate);
97 loss.recalculate();
98
99 let t1 = loss.find("t1").unwrap();
100 let t2 = loss.find("t2").unwrap();
101 let t3 = loss.find("t3").unwrap();
102
103 let y1 = loss.find("y1").unwrap();
104 let y2 = loss.find("y2").unwrap();
105 let y3 = loss.find("y3").unwrap();
106
107 println!(
108 "Iteration {:3}, loss: {:11.4} / predicted: {:5.2}, {:5.2}, {:5.2} (targets: {:5.2}, {:5.2}, {:5.2})",
109 i, loss.result, y1.result, y2.result, y3.result, t1.result, t2.result, t3.result
110 );
111 }
112
113 println!("Final values: {:}", mlp);
114}