tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Training step driver.
//!
//! Composes a user-supplied forward + loss + backward with the
//! `Parameter::adamw_step` (or HIP AdamW) into a single
//! `train_step` call. Returns the scalar loss reported by the
//! backward closure.
//!
// Training step driver: composes forward, loss, backward, and
// optimizer step into a single `train_step` call.
//
// The forward and backward are user-supplied closures (or function
// pointers) so the API stays generic over the model. The user is
// responsible for parameter grads: the backward closure receives
// (predictions, target) and returns a `BackwardOutput` whose
// `param_grads` vector has one entry per trainable parameter.
//
// Loss is computed in fp32 on the CPU for this pilot phase. The fp32
// param grads are then cast to fp16 inside `AdamW::step` before
// being shipped to the ROCm/HIP AdamW kernel.

use crate::object::Tensor;

use super::adamw::AdamW;
use super::parameter::Parameter;

/// Result of one user-supplied backward pass.
///
/// Bundles the scalar loss, the gradient of the loss w.r.t. the model
/// output, and one flat fp32 gradient vector per trainable parameter.
/// `train_step` reads `loss` and `param_grads` from this value; it does
/// not touch `grad_output`.
#[derive(Debug, Clone)]
pub struct BackwardOutput {
    /// Scalar loss for the batch. `train_step` returns this value
    /// unchanged on success.
    pub loss: f32,
    /// Gradient of the loss w.r.t. the model predictions (expected to
    /// have the same shape as the predictions).
    pub grad_output: Tensor<f32>,
    /// Per-parameter fp32 gradient. `param_grads[i]` is the gradient
    /// for `params[i].weight`, flattened in the same order as the
    /// weight's data buffer. `train_step` rejects the output when the
    /// number of entries differs from the number of parameters.
    pub param_grads: Vec<Vec<f32>>,
}

/// Generic training step. Runs forward -> backward -> AdamW update.
///
/// * `inputs` - batch of input tensors (forwarded as-is to the user's
///   closure).
/// * `target` - target tensor.
/// * `loss_fn_name` - name of the loss family used by the user. The
///   driver itself does not interpret the name; the user-provided
///   backward closure is responsible for computing the loss. The
///   name is recorded for telemetry and parameter-grad conventions.
/// * `params` - mutable slice of trainable parameters. Each parameter's
///   weight, m, v, and step are updated in place.
/// * `forward` - closure that maps a slice of input tensors to a
///   single prediction tensor.
/// * `backward` - closure that takes (predictions, target) and
///   returns a `BackwardOutput`. The host-side loss is fp32; the
///   `param_grads` are fp32 and will be rounded to fp16 inside
///   `AdamW::step`.
/// * `optimizer` - mutable AdamW optimizer that will be stepped once
///   per parameter.
pub fn train_step<F, B>(
    inputs: &[Tensor<f32>],
    target: &Tensor<f32>,
    _loss_fn_name: &str,
    params: &mut [Parameter],
    forward: F,
    backward: B,
    optimizer: &mut AdamW,
) -> crate::Result<f32>
where
    F: Fn(&[Tensor<f32>]) -> Tensor<f32>,
    B: Fn(&Tensor<f32>, &Tensor<f32>) -> BackwardOutput,
{
    // 1. Forward pass.
    let predictions = forward(inputs);

    // 2. Backward pass (computes loss and per-parameter grads in fp32).
    let backward_out = backward(&predictions, target);

    // 3. Sanity check: param_grads must align with `params`.
    if backward_out.param_grads.len() != params.len() {
        return Err(crate::Error::backend(format!(
            "backward returned {} param_grads, expected {}",
            backward_out.param_grads.len(),
            params.len()
        )));
    }

    // 4. Apply the optimizer step to every parameter.
    for (param, grad_f32) in params.iter_mut().zip(backward_out.param_grads.iter()) {
        optimizer.step(param, grad_f32)?;
    }

    Ok(backward_out.loss)
}

/// Reference mean-squared-error loss and its gradient on flat fp32
/// buffers.
///
/// Convenience helper so smoke tests and small linear models do not
/// have to hand-roll the MSE math.
///
/// # Returns
///
/// A pair `(loss, grad)` where `loss` is the mean of the squared
/// element-wise differences `predictions[i] - target[i]`, and
/// `grad[i]` is `2 * (predictions[i] - target[i]) / n` with `n` the
/// number of elements. For empty inputs the divisor is clamped to 1,
/// so the result is `(0.0, vec![])`.
///
/// # Panics
///
/// Panics when `predictions` and `target` have different lengths.
pub fn mse_loss_backward(predictions: &[f32], target: &[f32]) -> (f32, Vec<f32>) {
    assert_eq!(
        predictions.len(),
        target.len(),
        "mse_loss_backward length mismatch"
    );

    // Clamp the element count to 1 so empty input does not divide by zero.
    let denom = predictions.len().max(1) as f32;

    let residuals: Vec<f32> = predictions
        .iter()
        .zip(target)
        .map(|(p, t)| p - t)
        .collect();

    // Left-to-right fp32 accumulation starting from +0.0.
    let sum_sq = residuals.iter().fold(0.0f32, |acc, r| acc + r * r);
    let grad: Vec<f32> = residuals.into_iter().map(|r| 2.0 * r / denom).collect();

    (sum_sq / denom, grad)
}