// crates/axonml/examples/simple_training.rs
//! Simple Training Example
//!
//! # File
//! `crates/axonml/examples/simple_training.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use axonml::prelude::*;
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}