mixed_precision/main.rs

//! Mixed precision training: float16 forward, float32 gradients.
//!
//! Demonstrates the `cast_parameters` + `GradScaler` workflow for training
//! with reduced precision. The scaler dynamically adjusts the loss scale
//! to prevent gradient underflow in float16.
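//!
//! Why scaling helps: float16's smallest positive normal value is about
//! 6.1e-5 and subnormals bottom out near 6.0e-8, so tiny gradients flush
//! to zero. Multiplying the loss by a large factor (e.g. 2^16) before
//! `backward` shifts gradients back into representable range; they are
//! divided back down before the optimizer applies them.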
//!
//! Run: `cargo run --example mixed_precision`

use flodl::*;
use flodl::monitor::Monitor;

fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // Random regression data.
    let batches: Vec<(Tensor, Tensor)> = (0..4)
        .map(|_| {
            let x = Tensor::randn(&[50, 8], opts).unwrap();
            let y = Tensor::randn(&[50, 4], opts).unwrap();
            (x, y)
        })
        .collect();

    // Build the model (starts in float32).
    let model = FlowBuilder::from(Linear::new(8, 32)?)
        .through(GELU)
        .through(LayerNorm::new(32)?)
        .also(Linear::new(32, 32)?)
        .through(Linear::new(32, 4)?)
        .build()?;

    let params = model.parameters();

    // Cast parameters to float16 for reduced memory and faster matmuls.
    cast_parameters(&params, DType::Float16);
    println!("Parameters cast to float16");

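    // Create the optimizer after the cast so its state matches the new
    // dtype (assuming flodl's Adam sizes its moment buffers from the
    // parameters it is given at construction).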
    let mut optimizer = Adam::new(&params, 0.001);
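    // GradScaler adapts its factor over training: the usual policy (exact
    // constants are flodl's defaults, not shown here) grows the scale after
    // a long run of clean steps and cuts it when inf/NaN gradients force a
    // skipped step.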
    let mut scaler = GradScaler::new();
    model.train();

    let num_epochs = 100usize;
    let mut monitor = Monitor::new(num_epochs);

    for epoch in 0..num_epochs {
        let t = std::time::Instant::now();
        let mut steps_taken = 0u32;

        for (xb, yb) in &batches {
            // Cast inputs to match parameter dtype.
            let input = Variable::new(xb.to_dtype(DType::Float16)?, true);
            let target = Variable::new(yb.to_dtype(DType::Float16)?, false);

            optimizer.zero_grad();
            let pred = model.forward(&input)?;
            let loss = mse_loss(&pred, &target)?;

            // Scale loss before backward to prevent gradient underflow.
            let scaled_loss = scaler.scale(&loss)?;
            scaled_loss.backward()?;
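            // Backward propagates the scale factor into every gradient;
            // they must be unscaled before the optimizer consumes them.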

            // Unscale, check for inf/nan, step if clean.
            let stepped = scaler.step(&params, &mut || optimizer.step())?;
            scaler.update();

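            // `stepped` is false when non-finite gradients were found and the
            // optimizer step was skipped; in the usual GradScaler contract,
            // `update()` then backs the scale off before the next iteration.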
            if stepped {
                steps_taken += 1;
            }

            model.record_scalar("loss", loss.item()?);
        }

        model.record_scalar("scale", scaler.scale_factor());
        model.flush(&[]);
        monitor.log(epoch, t.elapsed(), &model);

        if steps_taken == 0 {
            println!(
                "epoch {}: all steps skipped (inf grads), scale={:.0}",
                epoch,
                scaler.scale_factor()
            );
        }
    }

    monitor.finish();

    // Cast back to float32 for inference or checkpointing.
    cast_parameters(&params, DType::Float32);
    println!("Parameters cast back to float32 for export");

    Ok(())
}