tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Training parameter: fp16 weight tensor + fp32 AdamW state.
//!
//! The weight is stored as `Tensor<u16>` where each `u16` is the
//! IEEE 754 binary16 representation. The AdamW moments (m, v)
//! and step counter are stored as `f32`. `Parameter::adamw_step`
//! is the pure-CPU AdamW step; `hip_adamw_bridge::run_adamw_step`
//! is the HIP version (gated on `rocm-hip`).
//!
// Training parameter: an fp16 weight tensor plus fp32 AdamW state.
//
// The weight is stored as `Tensor<u16>` where each `u16` is the IEEE
// binary16 bit pattern. This matches the storage convention used by the
// ROCm/HIP kernels in `src/backend/hip_*.rs`, so the same bit pattern can
// be shipped to the AdamW kernel without any further conversion. AdamW
// moments (`m`, `v`) stay in fp32 to preserve the precision the bias
// correction depends on (see `hip_adamw.rs` for the matching rationale).

use crate::domain::DomainId;
use crate::object::{Shape, Tensor};

/// A trainable parameter: fp16 weight plus fp32 AdamW state.
///
/// The `weight` tensor stores fp16 values as raw `u16` bit patterns. `m` and
/// `v` are the AdamW first and second moments, kept in fp32 per the kernel
/// contract, and are expected to hold one entry per weight element.
/// `Parameter::new` enforces that length invariant at construction time.
/// Because all fields are `pub`, later mutation can break it, so code that
/// mutates the fields directly must keep the three lengths equal.
///
/// `step` is the per-parameter step index used for bias correction. It
/// starts at `0` when built through the constructors and is 1-based after
/// the first update. In a typical workflow the global optimizer counter is
/// the source of truth, and this parameter-level counter is kept in sync
/// for inspection.
#[derive(Debug, Clone)]
pub struct Parameter {
    /// fp16 weight tensor, stored as IEEE 754 binary16 bit patterns.
    pub weight: Tensor<u16>,
    /// AdamW first moment (fp32); expected length == `weight.data.len()`.
    pub m: Vec<f32>,
    /// AdamW second moment (fp32); expected length == `weight.data.len()`.
    pub v: Vec<f32>,
    /// Per-parameter AdamW step index (0 before any update, 1-based after
    /// the first update).
    pub step: u32,
}

impl Parameter {
    /// Build a parameter from a weight tensor, m, v. The lengths of
    /// `weight.data`, `m`, and `v` must match.
    pub fn new(weight: Tensor<u16>, m: Vec<f32>, v: Vec<f32>) -> Self {
        let n = weight.data.len();
        assert_eq!(m.len(), n, "Parameter m length mismatch");
        assert_eq!(v.len(), n, "Parameter v length mismatch");
        Self {
            weight,
            m,
            v,
            step: 0,
        }
    }

    /// Build a parameter from a flat vector of fp16 bit patterns. The
    /// `m` and `v` vectors are zero-initialised. The shape is recorded
    /// on the tensor meta for inspection but is not interpreted by the
    /// AdamW kernel (which treats its inputs as a flat slice).
    pub fn from_fp16_bits(weight_bits: Vec<u16>, shape: Shape) -> Self {
        let n = weight_bits.len();
        let tensor = Tensor::dense_cpu(DomainId::new("fp16"), shape, weight_bits);
        Self::new(tensor, vec![0.0f32; n], vec![0.0f32; n])
    }

    /// Build a parameter from a flat vector of fp32 values, rounded to
    /// fp16 via round-to-nearest-even (matching the kernel's
    /// `__float2half_rn` convention).
    pub fn from_f32(values: &[f32], shape: Shape) -> Self {
        let bits: Vec<u16> = values.iter().map(|&v| f32_to_fp16_bits(v)).collect();
        Self::from_fp16_bits(bits, shape)
    }

    /// Number of scalar weights in this parameter.
    pub fn len(&self) -> usize {
        self.weight.data.len()
    }

    /// True if the parameter is empty.
    pub fn is_empty(&self) -> bool {
        self.weight.data.is_empty()
    }

    /// Read the weights back as fp32 (decoded from the stored fp16
    /// bit patterns).
    pub fn weights_f32(&self) -> Vec<f32> {
        self.weight
            .data
            .iter()
            .map(|&b| fp16_bits_to_f32(b))
            .collect()
    }
}

/// Round an fp32 value to the nearest fp16 (IEEE 754 binary16) bit pattern.
///
/// This is a thin forwarding wrapper around
/// `crate::backend::f16_convert::f32_to_f16`; it does no conversion work of
/// its own. It exists under the historical training-side name so existing
/// callers (and the smoke tests) keep working.
///
/// The rounding behaviour is whatever `f16_convert::f32_to_f16` implements.
/// That function is intended to match HIP's `__float2half_rn`
/// (round-to-nearest-even), so the bit pattern fed to the AdamW kernel is
/// the same one the kernel would observe on-device.
///
/// History: earlier hand-rolled copies of this conversion in this file had
/// the subnormal shift bug (Task #71) and the subnormal-boundary round-up
/// bug (Task #75). Both silently rounded small AdamW gradients to either
/// zero or ±2.0 and likely contributed to 0.7B MoE training divergence.
/// Do not reintroduce a local implementation here.
// `#[inline]`: this is called once per element (e.g. `Parameter::from_f32`),
// and without the hint a non-generic forwarder cannot be inlined into
// other crates such as the smoke tests.
#[inline]
pub fn f32_to_fp16_bits(value: f32) -> u16 {
    crate::backend::f16_convert::f32_to_f16(value)
}

/// Decode an IEEE 754 binary16 bit pattern to fp32.
///
/// This is a thin forwarding wrapper around
/// `crate::backend::f16_convert::f16_to_f32`, kept so training-side code
/// uses the same conversion as the backend. Every binary16 value is exactly
/// representable in binary32, so a correct decoder is lossless.
// `#[inline]`: this is called once per element (e.g. `Parameter::weights_f32`),
// and without the hint a non-generic forwarder cannot be inlined into
// other crates.
#[inline]
pub fn fp16_bits_to_f32(bits: u16) -> f32 {
    crate::backend::f16_convert::f16_to_f32(bits)
}