tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! MoE topology builder.
//!
//! `MoEModel::new(size, seed)` is the public entry point for constructing
//! an `MoEModel`. The `MoESize` preset fixes the hidden width and the
//! number of hidden blocks. The input dim, output dim, number of experts,
//! and per-token top-k come from the constants in `super::topology`. It
//! allocates the per-expert MLPs and the router.
//!
// Phase 2.4 MoE model builder.
//
// `MoEModel::new(size, seed)` constructs the full model:
//   - One `Router` (Linear + Softmax + top-K mask) over N_EXPERTS experts.
//   - N_EXPERTS `Sequential` experts, each a stack of
//     [Linear, LayerNorm, GELU, (Linear, LayerNorm, GELU) x n_hidden, Linear].
//
// For `Nano` (hidden=128, depth=2) each expert is
//   Sequential[
//     Linear(96, 128), LayerNorm(128), GELU,
//     Linear(128, 128), LayerNorm(128), GELU,
//     Linear(128, 128), LayerNorm(128), GELU,
//     Linear(128, 20),
//   ]
// — same shape, same router, ~1000x fewer parameters than `Tiny`.
//
// All `Parameter`s are seeded deterministically from the supplied
// `seed: u64` (high 32 bits => base seed, low 32 bits => router seed).
// Re-running with the same seed reproduces the same initial weights —
// required for the `expert_substitution_loss_decreases` smoke test.
//
// We deliberately compose the existing `Linear`, `LayerNorm`, `GELU`,
// `Router`, and `Sequential` types from `crate::model` rather than
// re-implementing them, so the new module only owns the *gluing* logic.

use crate::model::layer::{GELU, Layer, LayerNorm, Linear, Router};
use crate::model::sequential::Sequential;

use super::topology::{IN_DIM, N_EXPERTS, OUT_DIM};
use super::{ExpertConfig, MoEModel, MoESize, RouterConfig};

impl MoEModel {
    /// Build a fresh `MoEModel` of the given `size`, with all parameters
    /// seeded deterministically from `seed`.
    ///
    /// Seed derivation:
    /// - the router is seeded with the low 32 bits of `seed`;
    /// - the high 32 bits, offset by `0xA5A5_0001` (wrapping), form a base
    ///   seed, and expert `ei` uses `base + ei * 0x1000` (wrapping);
    /// - inside an expert, the input projection uses `+0x10`, hidden block
    ///   `hi` uses `+0x20 + hi`, and the output projection uses
    ///   `+0x30 + n_hidden`.
    ///
    /// The same `(size, seed)` therefore always feeds the same seeds to
    /// `Router::new` / `Linear::new` (reproducibility of the weights
    /// themselves assumes those constructors are deterministic in their
    /// seed argument).
    pub fn new(size: MoESize, seed: u64) -> Self {
        let (hidden, n_hidden) = size.dims();
        // Low half seeds the router; the (offset) high half seeds the experts.
        let router_seed = (seed & 0xFFFF_FFFF) as u32;
        let base_seed = ((seed >> 32) as u32).wrapping_add(0xA5A5_0001);

        // Router: Linear(IN_DIM, N_EXPERTS) -> Softmax -> Top-K.
        let router = Router::new(IN_DIM, N_EXPERTS, super::topology::TOP_K, router_seed);

        // Every expert has the same shape
        //   [Linear, LayerNorm, GELU, (Linear, LayerNorm, GELU) x n_hidden, Linear]
        // and differs only in the seeds fed to its Linear layers.
        let experts: Vec<Sequential> = (0..N_EXPERTS)
            .map(|ei| {
                let expert_seed = base_seed.wrapping_add((ei as u32).wrapping_mul(0x1000));
                let mut layers: Vec<Box<dyn Layer>> = Vec::new();

                // Input projection.
                layers.push(Box::new(Linear::new(
                    IN_DIM,
                    hidden,
                    expert_seed.wrapping_add(0x10),
                )));
                layers.push(Box::new(LayerNorm::new(hidden, 1e-5)));
                layers.push(Box::new(GELU));

                // Hidden blocks.
                for hi in 0..n_hidden {
                    layers.push(Box::new(Linear::new(
                        hidden,
                        hidden,
                        expert_seed.wrapping_add(0x20 + hi as u32),
                    )));
                    layers.push(Box::new(LayerNorm::new(hidden, 1e-5)));
                    layers.push(Box::new(GELU));
                }

                // Output projection (no LN/GELU after).
                layers.push(Box::new(Linear::new(
                    hidden,
                    OUT_DIM,
                    expert_seed.wrapping_add(0x30 + n_hidden as u32),
                )));

                Sequential::new(layers)
            })
            .collect();

        Self {
            size,
            router,
            experts,
            last_cache: std::cell::RefCell::new(None),
        }
    }

    /// Router config (used by the smoke test to inspect architecture).
    ///
    /// Reports the same constants that `new` passes to `Router::new`.
    pub fn router_config(&self) -> RouterConfig {
        RouterConfig {
            in_features: IN_DIM,
            n_experts: N_EXPERTS,
            top_k: super::topology::TOP_K,
        }
    }

    /// Config of expert `idx` (used by the smoke test for assertions).
    ///
    /// All experts share one shape derived from `self.size`, so the result
    /// does not depend on `idx`. In debug builds `idx` is still checked, so
    /// that asking about an expert that does not exist is caught rather
    /// than answered with a plausible-looking config.
    ///
    /// # Panics
    /// In debug builds, if `idx >= ` the number of experts.
    pub fn expert_config(&self, idx: usize) -> ExpertConfig {
        debug_assert!(
            idx < self.experts.len(),
            "expert index {} out of range (model has {} experts)",
            idx,
            self.experts.len()
        );
        let (hidden, n_hidden) = self.size.dims();
        ExpertConfig {
            in_features: IN_DIM,
            hidden,
            n_hidden_layers: n_hidden,
            out_features: OUT_DIM,
        }
    }
}