tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! AdamW optimizer wrapper.
//!
//! Holds the hyper-parameters (lr, beta1, beta2, eps, weight_decay)
//! and a global 1-based step counter. Exposes a `step` method
//! that calls the ROCm/HIP `run_rocm_hip_adamw_step` kernel
//! against a `Parameter`. Pure-CPU AdamW (via
//! `Parameter::adamw_step`) is the default; the HIP kernel is
//! gated on `rocm-hip`.
//!
// AdamW optimizer wrapper. Holds the hyper-parameters and a global
// 1-based step counter, and exposes a `step` method that calls the
// ROCm/HIP `run_rocm_hip_adamw_step` kernel against a `Parameter`.
//
// The fp32 gradient is rounded to fp16 on the host (round-to-nearest-even,
// matching the kernel's `__float2half_rn`) before being shipped to the
// device. This mirrors the cast that the in-kernel grad load already
// performs on the GPU side, and keeps the host-side dispatch
// byte-compatible with the kernel's `__half` input.

use crate::Result;

use super::hip_adamw_bridge::run_rocm_hip_adamw_step;
use super::parameter::{Parameter, f32_to_fp16_bits};

/// AdamW optimizer (decoupled weight decay) over fp16 weights with
/// fp32 moments.
///
/// This struct only carries the hyper-parameters and the global step
/// counter; the moments themselves live on each `Parameter`. The
/// hyper-parameters match the Phase 1 kernel contract exactly: bias
/// correction uses `1 - beta1^t` and `1 - beta2^t` with the global
/// step counter `t`.
///
/// All fields are public and are forwarded to the kernel unchanged on
/// every update, so edits made between updates take effect on the
/// next one.
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Learning rate.
    pub lr: f32,
    /// First-moment decay rate.
    pub beta1: f32,
    /// Second-moment decay rate.
    pub beta2: f32,
    /// Numerical-stability epsilon inside the denominator.
    pub eps: f32,
    /// Decoupled weight decay coefficient.
    pub weight_decay: f32,
    /// Global step counter (shared across all parameters so bias
    /// correction is consistent). It is 0 on a freshly built
    /// optimizer and is incremented (saturating) before each kernel
    /// invocation, so the value the kernel sees is 1-based.
    pub step: u32,
}

impl AdamW {
    /// Build an AdamW optimizer from its hyper-parameters.
    ///
    /// The values are stored as given (no validation is performed).
    /// The step counter starts at 0; the first call to `step`
    /// increments it to 1 before invoking the kernel.
    pub fn new(lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32) -> Self {
        Self {
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            step: 0,
        }
    }

    /// Apply one AdamW update to `param` using the fp32 gradient
    /// `grad`.
    ///
    /// `grad` itself is never modified: each element is converted with
    /// `f32_to_fp16_bits` into a temporary `u16` buffer, and that
    /// buffer is what gets handed to `run_rocm_hip_adamw_step`
    /// together with the parameter's weight data, `m`, `v`, and the
    /// optimizer's hyper-parameters.
    ///
    /// Both the optimizer's global step and `param.step` are
    /// incremented (saturating) before the kernel is invoked, so the
    /// first update runs with step 1.
    ///
    /// # Errors
    ///
    /// * Returns a backend error, without touching any state, when
    ///   `grad.len()` differs from the parameter's weight length.
    /// * Propagates whatever error `run_rocm_hip_adamw_step` returns.
    ///   In that case the step counters have already been advanced.
    pub fn step(&mut self, param: &mut Parameter, grad: &[f32]) -> Result<()> {
        let n = param.weight.data.len();
        if grad.len() != n {
            return Err(crate::Error::backend(format!(
                "AdamW step grad length {} does not match parameter length {}",
                grad.len(),
                n
            )));
        }
        let grad_bits: Vec<u16> = grad.iter().map(|&g| f32_to_fp16_bits(g)).collect();
        // Increment both the global optimizer step and the
        // per-parameter step so callers can inspect either.
        self.step = self.step.saturating_add(1);
        param.step = param.step.saturating_add(1);
        // The kernel takes the step as `i32`. A plain `as` cast would
        // wrap to a negative value once the `u32` counter exceeds
        // `i32::MAX` (a saturated counter would become -1), so clamp
        // first; after the clamp the cast is lossless.
        let kernel_step = self.step.min(i32::MAX as u32) as i32;
        run_rocm_hip_adamw_step(
            &mut param.weight.data,
            &mut param.m,
            &mut param.v,
            &grad_bits,
            self.lr,
            self.beta1,
            self.beta2,
            self.eps,
            self.weight_decay,
            kernel_step,
        )
    }
}