tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! The `Parameter` type: trainable tensor + AdamW state.
//!
//! A `Parameter` is a `Tensor<f32>` plus the AdamW first/second
//! moment state and step counter. The Model layer treats every
//! trainable weight tensor as a Parameter; non-trainable buffers
//! (e.g. running statistics) are not Parameters.
//!
// A `Parameter` is a `Tensor<f32>` plus the AdamW first/second moment
// state and step counter. The Model layer treats every trainable
// weight tensor as a Parameter. Trainable affine weights of normalization
// layers (e.g. LayerNorm gamma/beta) are therefore also wrapped as
// `Parameter`; `requires_grad`-style semantics (whether a given
// Parameter actually receives updates) are left to the layer.

use crate::Error;
use crate::domain::DomainId;
use crate::object::{Shape, Tensor};

/// A trainable `f32` tensor together with its optimizer state.
///
/// Every constructor in this module allocates `m` and `v` with the same
/// shape and domain as `data`, filled with zeros, and sets `step` to 0.
#[derive(Debug, Clone)]
pub struct Parameter {
    /// The weights themselves.
    pub data: Tensor<f32>,
    /// AdamW first-moment estimate. `sgd_momentum_step` reuses this
    /// buffer as its momentum accumulator, so switching optimizers on the
    /// same Parameter mid-run is not state-safe.
    pub m: Tensor<f32>,
    /// AdamW second-moment estimate (not touched by the SGD variants).
    pub v: Tensor<f32>,
    /// Number of completed `adamw_step` calls; drives bias correction.
    /// The SGD variants do not advance it.
    pub step: u32,
}

impl Parameter {
    /// New parameter with all weights set to `0.0` and the AdamW
    /// state initialized to `0.0`.
    pub fn zeros(shape: Shape, domain: DomainId) -> Self {
        let n = numel(&shape);
        Self {
            data: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            step: 0,
        }
    }

    /// New parameter filled with samples from `U(lo, hi)` (xorshift32 PRNG,
    /// deterministic given the seed).
    pub fn uniform(shape: Shape, lo: f32, hi: f32, seed: u32, domain: DomainId) -> Self {
        let n = numel(&shape);
        let mut data = Vec::with_capacity(n);
        let mut state: u32 = seed.wrapping_add(1);
        for _ in 0..n {
            state ^= state << 13;
            state ^= state >> 17;
            state ^= state << 5;
            let frac = state as f32 / u32::MAX as f32;
            data.push(lo + (hi - lo) * frac);
        }
        Self {
            data: Tensor::dense_cpu(domain.clone(), shape.clone(), data),
            m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            step: 0,
        }
    }

    /// Wrap an existing tensor as a Parameter (AdamW state reset to 0).
    pub fn from_tensor(t: Tensor<f32>) -> Self {
        let n = t.data.len();
        let shape = t.meta.shape.clone();
        let domain = t.meta.domain.clone();
        Self {
            data: t,
            m: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            v: Tensor::dense_cpu(domain.clone(), shape.clone(), vec![0.0f32; n]),
            step: 0,
        }
    }

    pub fn numel(&self) -> usize {
        self.data.data.len()
    }

    /// In-place AdamW update using the host-side pure-Rust reference
    /// kernel. We deliberately avoid the HIP AdamW pilot for now
    /// because the kernel takes fp16 weights/gradients and the model
    /// layer stores everything in fp32; the round-trip would only add
    /// noise. Phase 2.2 (the training driver) can swap this for the
    /// HIP kernel once the rest of the pipeline is in place.
    pub fn adamw_step(
        &mut self,
        grad: &Tensor<f32>,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
    ) -> Result<(), Error> {
        if grad.data.len() != self.data.data.len() {
            return Err(Error::shape(format!(
                "adamw_step grad length {} != parameter length {}",
                grad.data.len(),
                self.data.data.len()
            )));
        }
        if self.step == u32::MAX {
            return Err(Error::backend("adamw_step parameter step counter overflow"));
        }
        self.step += 1;
        let t = self.step as f32;
        let bc1 = 1.0 - beta1.powf(t);
        let bc2 = 1.0 - beta2.powf(t);
        for i in 0..self.data.data.len() {
            let g = grad.data[i] + weight_decay * self.data.data[i];
            self.m.data[i] = beta1 * self.m.data[i] + (1.0 - beta1) * g;
            self.v.data[i] = beta2 * self.v.data[i] + (1.0 - beta2) * g * g;
            let m_hat = self.m.data[i] / bc1;
            let v_hat = self.v.data[i] / bc2;
            self.data.data[i] -= lr * m_hat / (v_hat.sqrt() + eps);
        }
        Ok(())
    }

    /// In-place stochastic gradient descent (no momentum). The smoke
    /// test uses this for the "loss decreases" assertion because
    /// AdamW's bias correction makes the comparison noisier.
    pub fn sgd_step(&mut self, grad: &Tensor<f32>, lr: f32) -> Result<(), Error> {
        if grad.data.len() != self.data.data.len() {
            return Err(Error::shape(format!(
                "sgd_step grad length {} != parameter length {}",
                grad.data.len(),
                self.data.data.len()
            )));
        }
        for (theta, g) in self.data.data.iter_mut().zip(grad.data.iter()) {
            *theta -= lr * g;
        }
        Ok(())
    }

    /// In-place SGD with momentum. Reuses the `m` buffer as the
    /// momentum accumulator. The 10K MLP gate test
    /// (`tests/mlp_10k_e2e.rs`) uses this because AdamW's bias
    /// correction amplifies fp16 GEMM noise past the gate. The
    /// 0.7B MoE runner (`bin/train_quality_moe.rs`) opts into this
    /// via `--optimizer sgd` for the same reason on the fp16
    /// HIP path. Update form (Sutskever et al., identical to
    /// PyTorch `nesterov=False`):
    ///   v  <- momentum * v + grad
    ///   w  <- w - lr * v
    pub fn sgd_momentum_step(
        &mut self,
        grad: &Tensor<f32>,
        lr: f32,
        momentum: f32,
    ) -> Result<(), Error> {
        if grad.data.len() != self.data.data.len() {
            return Err(Error::shape(format!(
                "sgd_momentum_step grad length {} != parameter length {}",
                grad.data.len(),
                self.data.data.len()
            )));
        }
        // `m` is already initialised to zeros by every Parameter
        // constructor, so no lazy init is needed. Reusing it as the
        // momentum buffer means an SGD->AdamW swap is not state-safe
        // mid-run; the runner is responsible for choosing one
        // optimizer per run.
        for ((theta, g), momentum_buf) in self
            .data
            .data
            .iter_mut()
            .zip(grad.data.iter())
            .zip(self.m.data.iter_mut())
        {
            *momentum_buf = momentum * (*momentum_buf) + *g;
            *theta -= lr * *momentum_buf;
        }
        Ok(())
    }
}

/// Number of scalar elements described by `shape`.
///
/// Symbolic/dynamic dimensions are not supported in the model layer.
/// As soon as a non-`Static` dimension is seen, this returns `1`, so the
/// caller allocates a 1-element buffer; the intent is for the mismatch to
/// surface in later shape checks (NOTE(review): confirm a downstream check
/// actually rejects it).
///
/// # Panics
/// Panics if the product of the static dimensions overflows `usize`.
/// Such a tensor could never be allocated anyway. A silently wrapped
/// product would instead yield a buffer far smaller than the shape claims.
fn numel(shape: &Shape) -> usize {
    let mut n = 1usize;
    for d in &shape.dims {
        match d {
            crate::object::Dim::Static(v) => {
                n = n
                    .checked_mul(*v)
                    .expect("numel: product of static dimensions overflows usize");
            }
            _ => return 1,
        }
    }
    n
}