// tokitai-operator 0.1.0
//
// Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf,
// contract-carrying ops. Paper-artifact grade.
//! 0.7B-style sparse Mixture-of-Experts (MoE) model (gated on `rocm-hip`).
//!
//! Phase 2.4 of the 0.7B MoE quality-decision training project.
//! Implements the sparse MoE architecture: a router that selects
//! the top-k experts per token, and a stack of expert MLPs.
//!
//! - `topology` — the MoE topology (router + experts + head).
//! - `forward` — the forward pass through the topology.
//! - `backward` — the backward pass with the gradient flow split
//!   into the expert path and the router path.
//! - `builder` — the topology builder.
//! - `diagnose_ref` — the reference diagnosis (used by
//!   `tests/moe_router_diagnose.rs`).
//!
// Phase 2.4 MoE quality-decision model.
//
// Implements the 0.7B-style sparse Mixture-of-Experts architecture
// described in `docs/MOE_DESIGN.md` §3-§4. The module glues together
// the existing `Layer` trait, `Linear`, `LayerNorm`, `GELU`, `Router`,
// and `Sequential` types from `crate::model` and adds:
//
//   - `MoESize` — model size variant (Nano / Tiny / Medium / Full).
//   - `MoEModel` — the full model (router + N_EXPERTS experts).
//   - `MoEForwardCache` — per-call cache that travels from `forward`
//     to `backward`.
//   - `MoEOutput` — forward return value (logits + router weights).
//   - `RouterConfig`, `ExpertConfig` — read-only architecture
//     descriptors.
//
// Memory accounting (per spec, for the patent / disclosure):
//   Nano  (hidden=128,  depth=2):  ~195K params ->  ~390KB fp16 +
//                                                     ~1.5MB AdamW
//                                                     state   (~2MB)
//   Tiny  (hidden=2048, depth=5):  ~85M params  ->  170MB fp16 +
//                                                      680MB AdamW
//                                                      state   (~850MB)
//   Medium(hidden=4096, depth=5):  ~338M params ->  676MB fp16 +
//                                                     2.7GB AdamW
//                                                     state   (~3.4GB)
//   Full  (hidden=4096, depth=12): ~808M params -> 1.6GB fp16 +
//                                                     6.5GB AdamW
//                                                     state   (~8GB)
//
// All four fit on a 16GB GDDR6 device; Tiny and Medium are the
// "production" sizes for the smoke test, Full is config-only. Nano
// is the hyperparameter-iteration size — small enough that a full
// optimizer step runs in milliseconds.

use std::any::Any;
use std::cell::RefCell;

use crate::model::layer::{Layer, Router};
use crate::model::sequential::{Model, Sequential};
use crate::object::Tensor;

pub mod backward;
pub mod builder;
pub mod diagnose_ref;
pub mod forward;
pub mod topology;

pub use topology::{IN_DIM, N_EXPERTS, OUT_DIM, TOP_K, param_count, params_for_size};

/// MoE model size variants.
///
/// `Nano` is the fast-iteration smoke size (hidden=128, depth=2,
/// ~195K params) — small enough that a full forward+backward
/// optimizer step runs in milliseconds, so it is the right choice
/// for hyperparameter sweeps (LR, optimizer, router scaling) where
/// the larger sizes would each take minutes per step.
///
/// `Nano`, `Tiny` and `Medium` are fully implemented; `Full` is a
/// config-only stub kept for the spec's 0.7B target — constructing it
/// would allocate ~8GB of fp16 weights plus AdamW state (~1.6GB +
/// ~6.5GB), so it is not constructed in tests.
///
/// `MoESize::dims` gives the `(hidden_dim, n_hidden_layers)` pair and
/// `MoESize::param_count` the trainable scalar-parameter count.
// NOTE: variant order is observable. The derived `Hash` hashes the
// discriminant, and serde's derived `Serialize` passes the variant
// index to the serializer, so index-based formats encode a variant by
// its position. Reordering or inserting variants changes both.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum MoESize {
    /// hidden=128, depth=2. ~195K params — roughly 435x fewer than
    /// Tiny's ~85M.
    /// Keeps the 4-expert sheaf router and the same 96-in/20-out
    /// schema, so it works with the existing training data and
    /// integration code — the point of Nano is to exercise the
    /// *topology* (router, top-K, gating) at sub-second cost.
    Nano,
    /// hidden=2048, depth=5. ~85M params. Fits on a 7800 XT with
    /// massive headroom; used for the smoke test.
    Tiny,
    /// hidden=4096, depth=5. ~338M params. The "production" size
    /// in the patent disclosure's revised scope.
    Medium,
    /// hidden=4096, depth=12. ~808M params. The 0.7B target. Not
    /// constructed in tests (would allocate ~8GB of fp16 weights plus
    /// AdamW state).
    Full,
}

impl MoESize {
    /// Architecture dimensions for this size, as
    /// `(hidden_dim, n_hidden_layers)`.
    ///
    /// `Medium` and `Full` share the 4096-wide hidden layer and differ
    /// only in depth (5 vs 12); `Tiny` and `Medium` share depth 5.
    pub fn dims(self) -> (usize, usize) {
        let hidden_dim = match self {
            Self::Nano => 128,
            Self::Tiny => 2048,
            Self::Medium | Self::Full => 4096,
        };
        let n_hidden_layers = match self {
            Self::Nano => 2,
            Self::Tiny | Self::Medium => 5,
            Self::Full => 12,
        };
        (hidden_dim, n_hidden_layers)
    }

    /// Static display name of the variant (spelled like its identifier).
    pub fn name(self) -> &'static str {
        match self {
            Self::Nano => "Nano",
            Self::Tiny => "Tiny",
            Self::Medium => "Medium",
            Self::Full => "Full",
        }
    }

    /// Trainable scalar-parameter count for this size.
    ///
    /// Delegates to `topology::param_count` so the two can never
    /// disagree.
    pub fn param_count(self) -> usize {
        topology::param_count(self)
    }
}

/// Read-only descriptor of the router's shape.
///
/// A plain `Copy` value type; it owns no tensors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RouterConfig {
    /// Width of the router's input features.
    pub in_features: usize,
    /// Number of experts the router chooses among.
    pub n_experts: usize,
    /// Number of experts selected per token (the top-K).
    pub top_k: usize,
}

/// Read-only descriptor of a single expert MLP's shape.
///
/// A plain `Copy` value type; it owns no tensors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ExpertConfig {
    /// Width of the expert's input features.
    pub in_features: usize,
    /// Width of each hidden layer.
    pub hidden: usize,
    /// Number of hidden layers (the "depth" of the size variant).
    pub n_hidden_layers: usize,
    /// Width of the expert's output.
    pub out_features: usize,
}

/// Value returned by the forward pass: the model logits together with
/// the router's post-mask gating weights.
#[derive(Clone, Debug)]
pub struct MoEOutput {
    /// Model output logits, shape `[B, OUT_DIM]`.
    pub logits: Tensor<f32>,
    /// Post-mask gating weight per expert, shape `[B, N_EXPERTS]`;
    /// every position outside the top-K is exactly 0.
    pub router_weights: Tensor<f32>,
}

/// Per-call forward cache, stored on the model and consumed by
/// `backward`.
///
/// The router's own cache is type-erased as `Box<dyn Any + Send>` so
/// this struct stays decoupled from the router's concrete cache type;
/// the consumer has to downcast it back (presumably to the router's
/// `RouterCache` — confirm in `backward`).
///
/// The struct has no lifetime parameter: every field is owned, so a
/// cache can outlive the forward call that produced it.
// NOTE: fields drop in declaration order; keep that in mind before
// reordering them.
pub struct MoEForwardCache {
    /// Original input tensor (an owned clone of the caller's input).
    pub input: Tensor<f32>,
    /// Router's own `RouterCache` (consumed by the router's backward).
    pub router_cache: Box<dyn Any + Send>,
    /// Post-mask router weights `[B, N_EXPERTS]` (snapshot for the
    /// expert path's grad_expert_output computation).
    pub router_weights: Tensor<f32>,
    /// Per-expert forward outputs, each `[B, OUT_DIM]`. Presumably one
    /// entry per expert in expert-index order — confirm in `forward`.
    pub expert_outputs: Vec<Tensor<f32>>,
}

/// The full MoE model: one router + N_EXPERTS deep-MLP experts.
///
/// The last forward cache lives in a `RefCell`, which makes
/// `MoEModel` `!Sync`: a `&MoEModel` cannot be shared across threads.
/// NOTE(review): the `RefCell` presumably lets `forward` record the
/// cache through `&self` — confirm in `forward`.
// NOTE: fields drop in declaration order; keep that in mind before
// reordering them.
pub struct MoEModel {
    /// Size variant this model was built for.
    pub size: MoESize,
    /// Router (Linear + Softmax + top-K mask).
    pub router: Router,
    /// One `Sequential` per expert; `expert(idx)` / `expert_mut(idx)`
    /// index into this vector.
    pub experts: Vec<Sequential>,
    /// Last forward's cache (taken by `backward`). Presumably `None`
    /// before the first forward and after `backward` consumes it —
    /// confirm.
    pub last_cache: RefCell<Option<MoEForwardCache>>,
}

impl MoEModel {
    /// Number of trainable parameter *tensors* (not scalars): the
    /// router's tensors plus every expert's, i.e. the `Linear.weight`,
    /// `Linear.bias`, `LayerNorm.gamma` and `LayerNorm.beta` entries.
    ///
    /// This is also the length of the `param_grads` vector returned by
    /// `backward`. For the scalar total see `scalar_param_count`.
    pub fn parameter_count(&self) -> usize {
        // Router first, then experts in order.
        let router_tensors = self.router.parameters().len();
        let expert_tensors: usize = self
            .experts
            .iter()
            .map(|expert| expert.parameters().len())
            .sum();
        router_tensors + expert_tensors
    }

    /// Number of trainable scalars: the sum of `numel()` over every
    /// parameter tensor (router first, then each expert in order).
    /// Matches `size.param_count()`.
    pub fn scalar_param_count(&self) -> usize {
        let router_scalars: usize = self
            .router
            .parameters()
            .into_iter()
            .map(|tensor| tensor.numel())
            .sum();
        let expert_scalars: usize = self
            .experts
            .iter()
            .flat_map(|expert| expert.parameters())
            .map(|tensor| tensor.numel())
            .sum();
        router_scalars + expert_scalars
    }

    /// Shared reference to expert `idx` (for tests / smoke runs).
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.experts.len()`.
    pub fn expert(&self, idx: usize) -> &Sequential {
        &self.experts[idx]
    }

    /// Mutable reference to expert `idx`, for tests / smoke runs that
    /// apply in-place optimizer steps.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.experts.len()`.
    pub fn expert_mut(&mut self, idx: usize) -> &mut Sequential {
        &mut self.experts[idx]
    }
}