//! Training step driver.
//!
//! Composes a user-supplied forward + loss + backward pass with
//! `Parameter::adamw_step` (or the HIP AdamW kernel) into a single
//! `train_step` call. Returns the loss and the per-parameter
//! gradient norms for the metrics logger.
// The forward and backward passes are user-supplied closures (or
// function pointers) so the API stays generic over the model. The user
// is responsible for the parameter grads: the backward closure receives
// (predictions, target) and returns a `BackwardOutput` whose
// `param_grads` vector has one entry per trainable parameter.
//
// During this pilot phase the loss is computed in fp32 on the CPU. The
// fp32 param grads are then cast to fp16 inside `AdamW::step` before
// being passed to the ROCm/HIP AdamW kernel.
use crateTensor;
use AdamW;
use Parameter;
/// Output of a backward pass: scalar loss, gradient of the loss w.r.t.
/// the model output, and a flat fp32 gradient vector per trainable
/// parameter.
/// Generic training step. Runs forward -> backward -> AdamW update.
///
/// * `inputs` - batch of input tensors, forwarded as-is to the user's
///   `forward` closure.
/// * `target` - target tensor.
/// * `loss_fn_name` - name of the loss family in use. The driver does
///   not interpret the name; the user-provided backward closure is
///   responsible for computing the loss. The name is recorded for
///   telemetry and parameter-grad conventions.
/// * `params` - mutable slice of trainable parameters. Each parameter's
///   weight, m, v, and step are updated in place.
/// * `forward` - closure that maps a slice of input tensors to a
///   single prediction tensor.
/// * `backward` - closure that takes (predictions, target) and
///   returns a `BackwardOutput`. The host-side loss is fp32; the
///   `param_grads` are fp32 and are rounded to fp16 inside
///   `AdamW::step`.
/// * `optimizer` - mutable AdamW optimizer, stepped once per
///   parameter.
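
// A minimal host-side sketch of the forward -> backward -> AdamW flow
// documented above. It is an illustration under stated assumptions, not
// this crate's implementation: `HostParam`, `HostAdamW`, and
// `train_step_sketch` are hypothetical stand-ins for `Parameter`,
// `AdamW`, and `train_step`; the `loss` and `output_grad` field names
// are guesses (only `param_grads` is named in the docs); tensors are
// plain `Vec<f32>`; the closures receive the parameters and inputs
// explicitly; and everything stays in fp32 on the CPU (no fp16 cast,
// no HIP kernel).

```rust
/// Hypothetical parameter: weight plus AdamW state (m, v, step).
pub struct HostParam {
    pub weight: Vec<f32>,
    pub m: Vec<f32>,
    pub v: Vec<f32>,
    pub step: u32,
}

impl HostParam {
    pub fn new(weight: Vec<f32>) -> Self {
        let n = weight.len();
        Self { weight, m: vec![0.0; n], v: vec![0.0; n], step: 0 }
    }
}

/// Hypothetical fp32 AdamW with decoupled weight decay.
pub struct HostAdamW {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub eps: f32,
    pub weight_decay: f32,
}

impl HostAdamW {
    /// Applies one update to `p` in place using the gradient `grad`.
    pub fn step(&mut self, p: &mut HostParam, grad: &[f32]) {
        assert_eq!(p.weight.len(), grad.len(), "grad length mismatch");
        p.step += 1;
        let t = p.step as i32;
        // Bias-correction terms for the first and second moments.
        let bc1 = 1.0 - self.beta1.powi(t);
        let bc2 = 1.0 - self.beta2.powi(t);
        for i in 0..grad.len() {
            let g = grad[i];
            p.m[i] = self.beta1 * p.m[i] + (1.0 - self.beta1) * g;
            p.v[i] = self.beta2 * p.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = p.m[i] / bc1;
            let v_hat = p.v[i] / bc2;
            let w = p.weight[i];
            p.weight[i] = w - self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.weight_decay * w);
        }
    }
}

/// Scalar loss, gradient w.r.t. the model output, and one flat fp32
/// gradient vector per trainable parameter.
pub struct BackwardOutput {
    pub loss: f32,
    pub output_grad: Vec<f32>,
    pub param_grads: Vec<Vec<f32>>,
}

/// Runs forward -> backward -> AdamW and returns
/// (loss, per-parameter L2 gradient norms).
pub fn train_step_sketch<F, B>(
    inputs: &[Vec<f32>],
    target: &[f32],
    params: &mut [HostParam],
    forward: F,
    backward: B,
    optimizer: &mut HostAdamW,
) -> (f32, Vec<f32>)
where
    F: Fn(&[Vec<f32>], &[HostParam]) -> Vec<f32>,
    B: Fn(&[Vec<f32>], &[f32], &[f32]) -> BackwardOutput,
{
    let predictions = forward(inputs, &*params);
    let out = backward(inputs, &predictions, target);
    assert_eq!(out.param_grads.len(), params.len(), "one grad per parameter");

    let mut grad_norms = Vec::with_capacity(params.len());
    for (p, g) in params.iter_mut().zip(out.param_grads.iter()) {
        // L2 norm of this parameter's gradient, for the metrics logger.
        grad_norms.push(g.iter().map(|x| x * x).sum::<f32>().sqrt());
        // The optimizer is stepped once per parameter.
        optimizer.step(p, g);
    }
    (out.loss, grad_norms)
}
```

// The gradient norms are computed before the update so they describe
// the gradients that were actually applied in this step.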
/// Reference MSE loss + backward that operates on flat fp32 buffers.
/// Returns the mean-squared-error loss (averaged over all elements)
/// and a flat fp32 gradient buffer of the same length as the inputs.
/// Provided as a convenience so smoke tests and small linear models
/// do not have to hand-roll the MSE math.
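
// A minimal sketch of the MSE math described above, assuming the loss
// is the mean of squared differences over all elements, so the gradient
// with respect to each prediction is 2 * (prediction - target) / n.
// The name `mse_loss_and_grad` and its signature are hypothetical; the
// crate's own helper may differ.

```rust
/// Returns (mean squared error, gradient w.r.t. each prediction).
/// Panics if the buffers are empty or differ in length.
pub fn mse_loss_and_grad(predictions: &[f32], target: &[f32]) -> (f32, Vec<f32>) {
    assert_eq!(predictions.len(), target.len(), "length mismatch");
    assert!(!predictions.is_empty(), "empty buffers");

    let n = predictions.len() as f32;
    let mut sum_sq = 0.0f32;
    let mut grad = Vec::with_capacity(predictions.len());
    for (p, t) in predictions.iter().zip(target.iter()) {
        let diff = *p - *t;
        sum_sq += diff * diff;
        // d/dp of (1/n) * sum (p - t)^2
        grad.push(2.0 * diff / n);
    }
    (sum_sq / n, grad)
}
```

// For example, predictions [1, 2, 3] against targets [1, 1, 1] give
// squared errors 0, 1, 4, so the loss is 5/3 and the gradient is
// [0, 2/3, 4/3].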