//! Axonml simple training example (`simple_training/simple_training.rs`).

use axonml::prelude::*;
19fn main() {
20 println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22 println!("Version: {}", axonml::version());
24 println!("Features: {}\n", axonml::features());
25
26 println!("1. Creating XOR dataset...");
28 let inputs = vec![
29 vec![0.0, 0.0],
30 vec![0.0, 1.0],
31 vec![1.0, 0.0],
32 vec![1.0, 1.0],
33 ];
34 let targets = vec![0.0, 1.0, 1.0, 0.0]; println!(" Inputs: {inputs:?}");
37 println!(" Targets: {targets:?}\n");
38
39 println!("2. Creating MLP model (2 -> 4 -> 1)...");
41 let linear1 = Linear::new(2, 4);
42 let linear2 = Linear::new(4, 1);
43
44 println!(" Layer 1: Linear(2, 4)");
45 println!(" Layer 2: Linear(4, 1)\n");
46
47 println!("3. Creating Adam optimizer (lr=0.1)...");
49 let params = [linear1.parameters(), linear2.parameters()].concat();
50 let mut optimizer = Adam::new(params, 0.1);
51 println!(" Optimizer created!\n");
52
53 println!("4. Training for 1000 epochs...");
55 let epochs = 1000;
56
57 for epoch in 0..epochs {
58 let mut total_loss = 0.0;
59
60 for (input, &target) in inputs.iter().zip(targets.iter()) {
61 let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64 let h = linear1.forward(&x);
66 let h = h.sigmoid();
67 let output = linear2.forward(&h);
68 let output = output.sigmoid();
69
70 let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73 let diff = output.sub_var(&y);
75 let loss = diff.mul_var(&diff);
76
77 total_loss += loss.data().to_vec()[0];
78
79 loss.backward();
81
82 optimizer.step();
84 optimizer.zero_grad();
85 }
86
87 if epoch % 200 == 0 || epoch == epochs - 1 {
88 println!(" Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89 }
90 }
91
92 println!("\n5. Testing trained model...");
94 for (input, &expected) in inputs.iter().zip(targets.iter()) {
95 let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97 let h = linear1.forward(&x);
98 let h = h.sigmoid();
99 let output = linear2.forward(&h);
100 let output = output.sigmoid();
101
102 let pred = output.data().to_vec()[0];
103 let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105 println!(
106 " Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107 );
108 }
109
110 println!("\n=== Training Complete! ===");
111}