tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Inference session for a trained 0.7B MoE quality-decision model (gated on `rocm-hip`).
//!
//! The 0.7B MoE training runner saves `arch.json` (the
//! architecture description) and `checkpoint.tkp1` (the
//! trained weights) at the end of each run. This module
//! reconstructs a live `MoEModel` from those two artifacts and
//! exposes a single `forward` call that:
//!
//! - validates the input shape against the topology
//!   (IN_DIM = 96)
//! - pads the batch up to a multiple of 16 (the HIP GEMM
//!   kernel requires it) by replicating the last row
//! - runs a forward pass and returns the logits and the
//!   post-mask router weights for exactly the caller's rows
//!   (the padding rows are sliced off)
//!
// Inference session for a trained 0.7B MoE quality-decision model.
//
// The 0.7B MoE training runner saves `arch.json` (the architecture
// description) and `checkpoint.tkp1` (the trained weights) at the
// end of each run. This module reconstructs a live `MoEModel` from
// those two artifacts and exposes a single `forward` call that:
//   - validates the input shape against the topology (IN_DIM = 96)
//   - pads the batch up to a multiple of 16 (the HIP GEMM kernel
//     requires it) by replicating the last row
//   - runs a forward pass
//   - slices the padding rows off the output so the caller sees
//     exactly the batch they asked for
//
// The `infer_quality_moe` CLI (Step 1 of the integration plan) and
// the `tokitai-model-server` axum sidecar (Step 2) both build on
// top of `ModelSession` so they share the same load + forward path
// and the same name-matching convention for the checkpoint restore.

use std::path::{Path, PathBuf};

use crate::checkpoint::{Checkpoint, load_checkpoint, restore_parameter};
use crate::domain::DomainId;
use crate::model::Layer;
use crate::model::parameter::Parameter;
use crate::model_arch::load_arch;
use crate::moe_model::MoEModel;
use crate::moe_model::topology::{IN_DIM, N_EXPERTS, OUT_DIM};
use crate::object::{Dim, Shape, Tensor};

/// A loaded, ready-to-infer MoE model. Construct via
/// [`ModelSession::load`]; call [`ModelSession::forward`] one or
/// more times on batches of input features.
pub struct ModelSession {
    /// Path of the `arch.json` this session was built from.
    arch_path: PathBuf,
    /// Path of the checkpoint whose weights were restored into `model`.
    checkpoint_path: PathBuf,
    /// The live model, with checkpoint weights already restored.
    model: MoEModel,
}

impl ModelSession {
    /// Row multiple the forward batch is padded up to (the module docs
    /// attribute this requirement to the HIP GEMM kernel).
    const BATCH_MULTIPLE: usize = 16;

    /// Load the model from `arch_path` + `checkpoint_path`.
    ///
    /// The arch's `input_dim` / `output_dim` are checked against the
    /// topology constants `IN_DIM` / `OUT_DIM`, a fresh model is built
    /// from the arch, and the checkpoint snapshots are restored into it
    /// by matching their names against the model's parameter
    /// declaration order. The returned session is ready to serve
    /// `forward` calls.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the arch cannot be loaded
    /// or disagrees with the topology, if building the model fails, if
    /// the checkpoint cannot be loaded, or if its snapshots do not line
    /// up with the model's parameters.
    pub fn load(arch_path: &Path, checkpoint_path: &Path) -> Result<Self, String> {
        let arch =
            load_arch(arch_path).map_err(|e| format!("load_arch({}): {e}", arch_path.display()))?;
        if arch.input_dim != IN_DIM {
            return Err(format!(
                "arch input_dim {} does not match topology IN_DIM {}",
                arch.input_dim, IN_DIM
            ));
        }
        if arch.output_dim != OUT_DIM {
            return Err(format!(
                "arch output_dim {} does not match topology OUT_DIM {}",
                arch.output_dim, OUT_DIM
            ));
        }
        let mut model = arch.build().map_err(|e| format!("arch.build(): {e}"))?;
        let ckpt = load_checkpoint(checkpoint_path)
            .map_err(|e| format!("load_checkpoint({}): {e}", checkpoint_path.display()))?;
        restore_into_model(&mut model, &ckpt)?;
        Ok(Self {
            arch_path: arch_path.to_path_buf(),
            checkpoint_path: checkpoint_path.to_path_buf(),
            model,
        })
    }

    /// The arch this session was loaded from. Useful for `/healthz`
    /// responses and `ModelArch` field logging in the model server.
    pub fn arch_path(&self) -> &Path {
        &self.arch_path
    }

    /// The checkpoint this session was loaded from.
    pub fn checkpoint_path(&self) -> &Path {
        &self.checkpoint_path
    }

    /// Run one forward pass on `batch` and return the logits plus
    /// the post-mask router weights, one output row per input row.
    ///
    /// `batch` is a list of `IN_DIM`-length feature rows. Internally
    /// the batch is padded up to a multiple of 16 by repeating its last
    /// row; those padding rows are removed from the returned tensors.
    ///
    /// # Errors
    ///
    /// Returns an error if `batch` is empty, if any row's length is not
    /// `IN_DIM`, if the model's forward pass fails, or if a model output
    /// is not a 2-D static tensor with at least `batch.len()` rows.
    pub fn forward(&mut self, batch: Vec<Vec<f32>>) -> Result<InferOutput, String> {
        if batch.is_empty() {
            return Err("batch is empty".to_string());
        }
        for (i, row) in batch.iter().enumerate() {
            if row.len() != IN_DIM {
                return Err(format!(
                    "batch row {i} has length {}, expected {IN_DIM}",
                    row.len()
                ));
            }
        }

        // Pad up to a multiple of BATCH_MULTIPLE by replicating the last
        // row rather than zero-filling, so padding rows look like real
        // inputs and cannot surface an "output for zero input" artifact.
        // NOTE(review): the earlier comment blamed "LayerNorm running
        // stats", but standard LayerNorm keeps no running statistics —
        // confirm which batch-coupled op (if any) motivates this choice.
        let b = batch.len();
        // `b >= 1` here, so this is already >= BATCH_MULTIPLE.
        let b_padded = b.next_multiple_of(Self::BATCH_MULTIPLE);
        // Borrow rather than clone: `batch` is owned and outlives `data`'s
        // construction.
        let pad_row: &[f32] = batch.last().expect("batch checked non-empty above");
        let mut data = Vec::with_capacity(b_padded * IN_DIM);
        for row in &batch {
            data.extend_from_slice(row);
        }
        for _ in b..b_padded {
            data.extend_from_slice(pad_row);
        }
        let input = Tensor::dense_cpu(
            DomainId::new("infer_input"),
            Shape::from(vec![b_padded, IN_DIM]),
            data,
        );

        // Forward on the padded batch.
        let out = self
            .model
            .forward(&input)
            .map_err(|e| format!("model.forward: {e}"))?;

        // Slice the padding rows off so the caller sees exactly `b` rows.
        let logits = slice_rows(&out.logits, b)?;
        let router_weights = slice_rows(&out.router_weights, b)?;
        Ok(InferOutput {
            logits,
            router_weights,
        })
    }
}

/// Result of a single `ModelSession::forward` call: the model
/// logits `[B, OUT_DIM]` and the (post-mask) router weights
/// `[B, N_EXPERTS]`.
///
/// `B` is exactly the number of rows the caller passed to `forward`;
/// the padding rows added internally have already been sliced off.
pub struct InferOutput {
    /// Model logits, expected shape `[B, OUT_DIM]`.
    pub logits: Tensor<f32>,
    /// Post-mask router weights, expected shape `[B, N_EXPERTS]`.
    pub router_weights: Tensor<f32>,
}

/// Copy every snapshot in `ckpt` into the matching live parameter of
/// `model`.
///
/// The expected snapshot names are generated by walking the model,
/// following the convention the training runner uses when writing the
/// checkpoint (`build_checkpoint` in `src/training_runner.rs`):
///
///   - `router.param_{i}` for each entry of `model.router.parameters_mut()`,
///     in order;
///   - `expert_{ei}.param_{i}` for each entry of
///     `model.experts[ei].parameters_mut()`, in order.
///
/// The index `i` restarts at zero for every prefix. Live handles are
/// paired positionally with `ckpt.params`, and each pair must agree on
/// its name before the snapshot is restored.
///
/// # Errors
///
/// Fails if the model and the checkpoint disagree on the number of
/// parameters, or at the first positional name mismatch (snapshots
/// earlier in the list have already been restored at that point).
fn restore_into_model(model: &mut MoEModel, ckpt: &Checkpoint) -> Result<(), String> {
    let snapshots = &ckpt.params;

    let mut live: Vec<(String, &mut Parameter)> = Vec::with_capacity(snapshots.len());
    live.extend(
        model
            .router
            .parameters_mut()
            .into_iter()
            .enumerate()
            .map(|(idx, param)| (format!("router.param_{idx}"), param)),
    );
    for (expert_idx, expert) in model.experts.iter_mut().enumerate() {
        live.extend(
            expert
                .parameters_mut()
                .into_iter()
                .enumerate()
                .map(|(idx, param)| (format!("expert_{expert_idx}.param_{idx}"), param)),
        );
    }

    if live.len() != snapshots.len() {
        return Err(format!(
            "parameter count mismatch: model has {} live params, checkpoint has {} snapshots; \
             the checkpoint was probably saved for a different arch",
            live.len(),
            snapshots.len()
        ));
    }

    live.into_iter()
        .zip(snapshots.iter())
        .try_for_each(|((expected_name, param), snap)| {
            if snap.name == expected_name {
                restore_parameter(param, snap);
                Ok(())
            } else {
                Err(format!(
                    "snapshot name mismatch: expected `{expected_name}` (matching the model's \
                     declaration order), got `{}`; the checkpoint was probably saved for a \
                     different arch",
                    snap.name
                ))
            }
        })
}

/// Keep the first `rows` rows of a 2-D `Tensor<f32>`.
///
/// The data buffer is treated as row-major (row `r` occupies
/// `data[r * cols..(r + 1) * cols]`), so the kept rows are simply the
/// contiguous prefix `data[..rows * cols]`. The result keeps the input's
/// domain and has shape `[rows, cols]`.
///
/// # Errors
///
/// Returns an error if the tensor is not 2-D with static dims, if
/// `rows` exceeds the row count, or if the data buffer holds fewer than
/// `rows * cols` elements (instead of panicking on an out-of-bounds
/// slice).
fn slice_rows(t: &Tensor<f32>, rows: usize) -> Result<Tensor<f32>, String> {
    let dims = &t.meta.shape.dims;
    let [Dim::Static(cur_rows), Dim::Static(cols)] = dims.as_slice() else {
        return Err(format!(
            "slice_rows: expected 2-D static tensor, got {:?}",
            dims
        ));
    };
    let (cur_rows, cols) = (*cur_rows, *cols);
    if rows > cur_rows {
        return Err(format!(
            "slice_rows: asked for {rows} rows but tensor only has {cur_rows}"
        ));
    }
    let len = rows * cols;
    let Some(prefix) = t.data.get(..len) else {
        return Err(format!(
            "slice_rows: tensor data has {} elements, expected at least {len} for a \
             {rows}x{cols} slice",
            t.data.len()
        ));
    };
    Ok(Tensor::dense_cpu(
        t.meta.domain.clone(),
        Shape::from(vec![rows, cols]),
        prefix.to_vec(),
    ))
}

pub use crate::moe_model::topology::N_EXPERTS as INFER_N_EXPERTS;
/// Re-export the topology constants for callers that need them.
pub use crate::moe_model::topology::{IN_DIM as INFER_IN_DIM, OUT_DIM as INFER_OUT_DIM};