pub fn train_step(
    dev: &GpuDevice,
    opt: &mut AdamW,
    step_num: u32,
    forward_fn: impl FnOnce(&mut Tape<'_>) -> Result<(TensorId, Vec<TensorId>)>,
) -> Result<StepResult>
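Run one training step for an MLP (or any other differentiable graph).
forward_fn builds the computation graph on the tape and returns (loss_id, param_ids): the id of the loss tensor and the ids of the parameters to update.
train_step then runs the backward pass from the loss, extracts the gradient of each parameter, and updates the parameters with opt.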
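
The flow described above (build the graph in a closure, run backward, update the parameters) can be illustrated with a small self-contained sketch. Everything in it is a hypothetical CPU stand-in written for this example: the scalar Tape, the Store that replaces GpuDevice, the AdamW struct, its update method, and the fit helper are assumptions and do not reflect the crate's real types or internals. The sketch also assumes step_num starts at 1, which the bias correction needs.

```rust
// Toy CPU stand-ins for illustration; none of these are the crate's real types.
type TensorId = usize;

/// Hypothetical parameter storage, playing the role of buffers on the device.
struct Store { params: Vec<f32> }

#[derive(Clone, Copy)]
enum Op { Const, Param(usize), Sub(TensorId, TensorId), Mul(TensorId, TensorId) }

/// Hypothetical scalar tape: records each op next to its forward value.
struct Tape<'a> { store: &'a Store, ops: Vec<Op>, vals: Vec<f32> }

impl Tape<'_> {
    fn push(&mut self, op: Op, val: f32) -> TensorId {
        self.ops.push(op);
        self.vals.push(val);
        self.ops.len() - 1
    }
    fn constant(&mut self, v: f32) -> TensorId { self.push(Op::Const, v) }
    fn param(&mut self, slot: usize) -> TensorId {
        let v = self.store.params[slot];
        self.push(Op::Param(slot), v)
    }
    fn sub(&mut self, a: TensorId, b: TensorId) -> TensorId {
        let v = self.vals[a] - self.vals[b];
        self.push(Op::Sub(a, b), v)
    }
    fn mul(&mut self, a: TensorId, b: TensorId) -> TensorId {
        let v = self.vals[a] * self.vals[b];
        self.push(Op::Mul(a, b), v)
    }
    /// Reverse sweep: seed d(loss)/d(loss) = 1, then push gradients to inputs.
    fn backward(&self, loss: TensorId) -> Vec<f32> {
        let mut g = vec![0.0_f32; self.ops.len()];
        g[loss] = 1.0;
        for i in (0..=loss).rev() {
            let gi = g[i];
            match self.ops[i] {
                Op::Sub(a, b) => { g[a] += gi; g[b] -= gi; }
                Op::Mul(a, b) => { g[a] += gi * self.vals[b]; g[b] += gi * self.vals[a]; }
                Op::Const | Op::Param(_) => {}
            }
        }
        g
    }
}

/// Hypothetical scalar AdamW: moment estimates plus decoupled weight decay.
struct AdamW { lr: f32, beta1: f32, beta2: f32, eps: f32, wd: f32, m: Vec<f32>, v: Vec<f32> }

impl AdamW {
    fn update(&mut self, step_num: u32, slot: usize, p: &mut f32, g: f32) {
        self.m[slot] = self.beta1 * self.m[slot] + (1.0 - self.beta1) * g;
        self.v[slot] = self.beta2 * self.v[slot] + (1.0 - self.beta2) * g * g;
        let t = step_num as i32; // assumed to start at 1
        let m_hat = self.m[slot] / (1.0 - self.beta1.powi(t));
        let v_hat = self.v[slot] / (1.0 - self.beta2.powi(t));
        *p -= self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.wd * *p);
    }
}

struct StepResult { loss: f32 }

fn train_step(
    store: &mut Store,
    opt: &mut AdamW,
    step_num: u32,
    forward_fn: impl FnOnce(&mut Tape<'_>) -> Result<(TensorId, Vec<TensorId>), String>,
) -> Result<StepResult, String> {
    let mut tape = Tape { store: &*store, ops: Vec::new(), vals: Vec::new() };
    let (loss_id, param_ids) = forward_fn(&mut tape)?; // forward pass
    let grads = tape.backward(loss_id);                // backward pass
    let loss = tape.vals[loss_id];
    // Copy out (slot, gradient) pairs so the tape's borrow of the store ends.
    let updates: Vec<(usize, f32)> = param_ids
        .iter()
        .filter_map(|&id| match tape.ops[id] {
            Op::Param(slot) => Some((slot, grads[id])),
            _ => None,
        })
        .collect();
    for (slot, g) in updates {
        opt.update(step_num, slot, &mut store.params[slot], g);
    }
    Ok(StepResult { loss })
}

/// Fit w so that w * 2 approaches 6; returns (first loss, last loss).
fn fit(steps: u32) -> (f32, f32) {
    let mut store = Store { params: vec![0.0] };
    let mut opt = AdamW { lr: 0.1, beta1: 0.9, beta2: 0.999, eps: 1e-8, wd: 0.0, m: vec![0.0], v: vec![0.0] };
    let (mut first, mut last) = (0.0, 0.0);
    for step in 1..=steps {
        let result = train_step(&mut store, &mut opt, step, |tape| {
            let w = tape.param(0);
            let (x, y) = (tape.constant(2.0), tape.constant(6.0));
            let wx = tape.mul(w, x);
            let err = tape.sub(wx, y);
            Ok((tape.mul(err, err), vec![w])) // loss = (w * x - y)^2
        })
        .expect("forward pass failed");
        if step == 1 { first = result.loss; }
        last = result.loss;
    }
    (first, last)
}

fn main() {
    let (first, last) = fit(100);
    println!("loss at first step: {first}, loss at last step: {last}");
}
```

The closure only describes the forward computation and names which tensors are parameters. The step function owns the rest of the sequence, so the caller cannot forget the backward pass or apply the optimizer to stale gradients.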