// simple_linear_train_script/simple-linear-train-script.rs
use hodu::prelude::*;
use std::time::Instant;

4fn main() -> Result<(), Box<dyn std::error::Error>> {
5 let input_data: Vec<Vec<f32>> = (0..10000)
6 .map(|i| {
7 vec![
8 (i % 100) as f32 / 100.0,
9 ((i % 100) + 1) as f32 / 100.0,
10 ((i % 100) + 2) as f32 / 100.0,
11 ]
12 })
13 .collect();
14 let target_data: Vec<Vec<f32>> = (0..10000).map(|i| vec![((i % 100) * 10) as f32 / 1000.0]).collect();
15
16 let input_tensor = Tensor::new(input_data)?;
17 let target_tensor = Tensor::new(target_data)?;
18
19 let builder = Builder::new("linear_training".to_string());
21 builder.start()?;
22
23 let mut linear = Linear::new(3, 1, true, DType::F32)?;
24 let mse_loss = MSE::new();
25 let mut optimizer = SGD::new(0.01);
26
27 let input = Tensor::input("input", &[10000, 3])?;
28 input.requires_grad()?;
29 let target = Tensor::input("target", &[10000, 1])?;
30
31 let epochs = 1000;
32 let mut final_loss = Tensor::full(&[], 0.0)?;
33
34 for _ in 0..epochs {
35 let pred = linear.forward(&input)?;
36 let loss = mse_loss.forward((&pred, &target))?;
37
38 loss.backward()?;
39
40 optimizer.step(&mut linear.parameters())?;
41 optimizer.zero_grad(&mut linear.parameters())?;
42
43 final_loss = loss;
44 }
45
46 let params = linear.parameters();
47 builder.add_output("loss", final_loss)?;
48 builder.add_output("weight", *params[0])?;
49 builder.add_output("bias", *params[1])?;
50
51 builder.end()?;
52
53 let mut script = builder.build()?;
54 #[cfg(feature = "xla")]
55 script.set_backend(Backend::XLA);
56
57 script.add_input("input", input_tensor);
58 script.add_input("target", target_tensor);
59
60 println!("Compiling script...");
61 let compile_start = Instant::now();
62 script.compile()?;
63 let compile_elapsed = compile_start.elapsed();
64 println!("Compilation time: {:?}", compile_elapsed);
65
66 println!("Running script...");
67 let run_start = Instant::now();
68 let output = script.run()?;
69 let run_elapsed = run_start.elapsed();
70
71 println!("Loss: {}", output["loss"]);
72 println!("Weight: {}", output["weight"]);
73 println!("Bias: {}", output["bias"]);
74 println!("Execution time: {:?}", run_elapsed);
75 println!("Total time: {:?}", compile_elapsed + run_elapsed);
76
77 Ok(())
78}