Function train_step 

pub fn train_step(
    dev: &GpuDevice,
    opt: &mut AdamW,
    step_num: u32,
    forward_fn: impl FnOnce(&mut Tape<'_>) -> Result<(TensorId, Vec<TensorId>)>,
) -> Result<StepResult>

Runs one training step for an MLP (or any differentiable graph). forward_fn builds the computation graph on the tape and returns (loss_id, param_ids). train_step then runs the backward pass from loss_id, extracts the gradients for param_ids, and updates those parameters using opt.
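This page does not show how a forward_fn is written or what train_step does internally, so the following is a minimal CPU-only sketch of the same contract under stated assumptions. Every name in it (TensorId, Tape, AdamW, StepResult, train_step_sketch, forward, fit_example) is a hypothetical scalar stand-in, not this crate's type or API. The toy tape supports only sub and mul, and treating step_num as AdamW's 1-based bias-correction counter is an assumption. It shows the pattern described above: the closure records the graph and returns (loss_id, param_ids), then the step runs a reverse sweep over the tape, picks out the gradients of the returned parameter leaves, and applies a bias-corrected Adam update with decoupled weight decay.

```rust
// Toy CPU model of the train_step contract; every name here is hypothetical.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TensorId(usize);
#[derive(Clone, Copy)]
enum Op { Const(f32), Param(usize), Sub(usize, usize), Mul(usize, usize) }
/// Scalar reverse-mode tape; it borrows the current parameter values.
pub struct Tape<'a> { params: &'a [f32], vals: Vec<f32>, ops: Vec<Op> }
impl<'a> Tape<'a> {
    fn push(&mut self, op: Op) -> TensorId {
        let v = match op {
            Op::Const(c) => c,
            Op::Param(i) => self.params[i],
            Op::Sub(a, b) => self.vals[a] - self.vals[b],
            Op::Mul(a, b) => self.vals[a] * self.vals[b],
        };
        self.vals.push(v);
        self.ops.push(op);
        TensorId(self.vals.len() - 1)
    }
    pub fn param(&mut self, i: usize) -> TensorId { self.push(Op::Param(i)) }
    pub fn constant(&mut self, c: f32) -> TensorId { self.push(Op::Const(c)) }
    pub fn sub(&mut self, a: TensorId, b: TensorId) -> TensorId { self.push(Op::Sub(a.0, b.0)) }
    pub fn mul(&mut self, a: TensorId, b: TensorId) -> TensorId { self.push(Op::Mul(a.0, b.0)) }
    /// Nodes are pushed in topological order, so a reverse sweep is a valid backward pass.
    fn backward(&self, loss: TensorId) -> Vec<f32> {
        let mut g = vec![0.0f32; self.vals.len()];
        g[loss.0] = 1.0;
        for i in (0..=loss.0).rev() {
            let gi = g[i];
            match self.ops[i] {
                Op::Const(_) | Op::Param(_) => {}
                Op::Sub(a, b) => { g[a] += gi; g[b] -= gi; }
                Op::Mul(a, b) => { g[a] += gi * self.vals[b]; g[b] += gi * self.vals[a]; }
            }
        }
        g
    }
}

pub struct AdamW { lr: f32, beta1: f32, beta2: f32, eps: f32, wd: f32, m: Vec<f32>, v: Vec<f32> }
impl AdamW {
    pub fn new(n: usize, lr: f32, wd: f32) -> Self {
        AdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, wd, m: vec![0.0; n], v: vec![0.0; n] }
    }
    /// Bias-corrected Adam step plus decoupled weight decay; t is a 1-based step count.
    fn update(&mut self, params: &mut [f32], i: usize, grad: f32, t: u32) {
        self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad;
        self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad * grad;
        let m_hat = self.m[i] / (1.0 - self.beta1.powi(t as i32));
        let v_hat = self.v[i] / (1.0 - self.beta2.powi(t as i32));
        params[i] -= self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.wd * params[i]);
    }
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct StepResult { pub loss: f32 }

pub fn train_step_sketch(
    params: &mut [f32],
    opt: &mut AdamW,
    step_num: u32,
    forward_fn: impl FnOnce(&mut Tape<'_>) -> Result<(TensorId, Vec<TensorId>), String>,
) -> Result<StepResult, String> {
    if step_num == 0 { return Err("step_num is assumed to be 1-based".into()); }
    // Forward and backward passes run while the tape holds a shared borrow of params.
    let (loss, updates) = {
        let mut tape = Tape { params: &*params, vals: Vec::new(), ops: Vec::new() };
        let (loss_id, param_ids) = forward_fn(&mut tape)?;
        let grads = tape.backward(loss_id);
        let mut updates = Vec::with_capacity(param_ids.len());
        for id in param_ids {
            let Op::Param(i) = tape.ops[id.0] else { return Err("id is not a parameter leaf".into()) };
            updates.push((i, grads[id.0]));
        }
        (tape.vals[loss_id.0], updates)
    };
    // The tape has been dropped, so the parameters can now be updated in place.
    for (i, grad) in updates { opt.update(params, i, grad, step_num); }
    Ok(StepResult { loss })
}

/// Example forward_fn: loss = (2w - 6)^2 with w = params[0] as the only trainable value.
fn forward(tape: &mut Tape<'_>) -> Result<(TensorId, Vec<TensorId>), String> {
    let (w, x, y) = (tape.param(0), tape.constant(2.0), tape.constant(6.0));
    let pred = tape.mul(w, x);
    let err = tape.sub(pred, y);
    Ok((tape.mul(err, err), vec![w]))
}

/// Usage: trains from w = 0 and returns the loss reported at each step plus the final w.
pub fn fit_example(steps: u32) -> Result<(Vec<f32>, f32), String> {
    let (mut params, mut opt) = (vec![0.0f32], AdamW::new(1, 0.1, 0.0));
    let mut losses = Vec::new();
    for step in 1..=steps {
        losses.push(train_step_sketch(&mut params, &mut opt, step, forward)?.loss);
    }
    Ok((losses, params[0]))
}
```

Here fit_example fits loss = (2w - 6)^2 from w = 0: the first reported loss is 36.0, and w approaches 3 as steps accumulate. A closure passed inline works the same as the named forward, since the bound on forward_fn is FnOnce(&mut Tape<'_>). Collecting gradients inside a block and updating parameters only after the tape is dropped is one way to satisfy the borrow checker when the tape borrows the parameters; the real crate, which keeps tensors on a GpuDevice, may organize this differently.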