// simple_linear_train_script/
// simple-linear-train-script.rs

1use hodu::prelude::*;
2use std::time::Instant;
3
4fn main() -> Result<(), Box<dyn std::error::Error>> {
5    let input_data: Vec<Vec<f32>> = (0..10000)
6        .map(|i| {
7            vec![
8                (i % 100) as f32 / 100.0,
9                ((i % 100) + 1) as f32 / 100.0,
10                ((i % 100) + 2) as f32 / 100.0,
11            ]
12        })
13        .collect();
14    let target_data: Vec<Vec<f32>> = (0..10000).map(|i| vec![((i % 100) * 10) as f32 / 1000.0]).collect();
15
16    let input_tensor = Tensor::new(input_data)?;
17    let target_tensor = Tensor::new(target_data)?;
18
19    // Build script
20    let builder = Builder::new("linear_training".to_string());
21    builder.start()?;
22
23    let mut linear = Linear::new(3, 1, true, DType::F32)?;
24    let mse_loss = MSE::new();
25    let mut optimizer = SGD::new(0.01);
26
27    let input = Tensor::input("input", &[10000, 3])?;
28    input.requires_grad()?;
29    let target = Tensor::input("target", &[10000, 1])?;
30
31    let epochs = 1000;
32    let mut final_loss = Tensor::full(&[], 0.0)?;
33
34    for _ in 0..epochs {
35        let pred = linear.forward(&input)?;
36        let loss = mse_loss.forward((&pred, &target))?;
37
38        loss.backward()?;
39
40        optimizer.step(&mut linear.parameters())?;
41        optimizer.zero_grad(&mut linear.parameters())?;
42
43        final_loss = loss;
44    }
45
46    let params = linear.parameters();
47    builder.add_output("loss", final_loss)?;
48    builder.add_output("weight", *params[0])?;
49    builder.add_output("bias", *params[1])?;
50
51    builder.end()?;
52
53    let mut script = builder.build()?;
54    #[cfg(feature = "xla")]
55    script.set_backend(Backend::XLA);
56
57    script.add_input("input", input_tensor);
58    script.add_input("target", target_tensor);
59
60    println!("Compiling script...");
61    let compile_start = Instant::now();
62    script.compile()?;
63    let compile_elapsed = compile_start.elapsed();
64    println!("Compilation time: {:?}", compile_elapsed);
65
66    println!("Running script...");
67    let run_start = Instant::now();
68    let output = script.run()?;
69    let run_elapsed = run_start.elapsed();
70
71    println!("Loss: {}", output["loss"]);
72    println!("Weight: {}", output["weight"]);
73    println!("Bias: {}", output["bias"]);
74    println!("Execution time: {:?}", run_elapsed);
75    println!("Total time: {:?}", compile_elapsed + run_elapsed);
76
77    Ok(())
78}