tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! MoE forward pass.
//!
//! The forward pass: project the input through the router, keep
//! the top-k expert weights per sample (all other weights are
//! zeroed), run every expert on the full batch, and combine the
//! expert outputs weighted by the router. The forward cache lives
//! on the model itself (`RefCell<Option<...>>`) so the backward
//! pass can read it.
//!
// Phase 2.4 MoE forward pass.
//
// Forward per sample (B is the batch dim, E = N_EXPERTS = 4, K = IN_DIM,
// D = OUT_DIM = 20):
//
//     logits = Linear_router(x)              # [B, E]
//     probs  = softmax(logits)               # [B, E]
//     w      = top_k_mask(probs)             # [B, E], non-top-K entries = 0
//     out_e  = Expert_e(x) for e in 0..E     # each [B, D]
//     y      = sum_e w[:, e] * out_e         # [B, D]
//
// All four experts run on the full batch (dense MoE pattern). The
// router weights at non-top-K positions are exactly 0, so those
// experts do not contribute to the output, but their forward
// computation is still done (a deliberate first-cut simplification;
// sparse per-sample dispatch is a Phase 2.5 optimization).
//
// The `top_k_indices` recorded inside the router's forward cache
// tell us which experts were activated for each sample; the smoke
// test reads them via the saved `MoEForwardCache`.

use crate::model::layer::Layer;
use crate::model::sequential::Model;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};

use super::topology::{N_EXPERTS, OUT_DIM};
use super::{MoEForwardCache, MoEModel, MoEOutput};

impl MoEModel {
    /// Runs the dense MoE forward pass on `input`.
    ///
    /// The router produces per-sample expert weights with the top-K mask
    /// already applied. Every expert is then run on the full batch, and the
    /// outputs are combined as `y[b, d] = sum_e w[b, e] * expert_out_e[b, d]`.
    /// For each sample, experts whose router weight is exactly `0.0` are
    /// skipped during the combine.
    ///
    /// On success, the forward cache is stored in `self.last_cache`,
    /// replacing any previous cache, so the backward pass can read it. The
    /// cache holds the input, the router cache, the router weights and the
    /// expert outputs.
    ///
    /// Returns the combined logits (`[B, OUT_DIM]`) and the post-mask router
    /// weights, so the caller can inspect routing decisions.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the router's forward pass.
    /// - `Error::shape` in any of these cases:
    ///   - the router weights are not a rank-2 tensor with static dims;
    ///   - their second dim is not `N_EXPERTS`;
    ///   - an expert's output does not hold exactly `B * OUT_DIM` elements;
    ///   - the number of experts differs from `N_EXPERTS`.
    /// - `Error::backend` if an expert's forward fails, or if it returns
    ///   anything other than exactly one tensor.
    ///
    /// # Panics
    ///
    /// Panics if `self.last_cache` is already borrowed when the cache is
    /// stored (`RefCell::borrow_mut`).
    pub fn forward(&self, input: &Tensor<f32>) -> Result<MoEOutput> {
        // 1. Router: x[B, K] -> w[B, E] (top-K mask applied).
        let (router_weights, router_cache) = self.router.forward(input)?;
        // Match on the whole dims slice. A rank other than 2 is then reported
        // as a shape error, instead of panicking on an out-of-bounds index
        // (rank < 2) or being silently accepted (rank > 2).
        let shape = &router_weights.meta.shape;
        let (b, e) = match &shape.dims[..] {
            [crate::object::Dim::Static(b), crate::object::Dim::Static(e)] => (*b, *e),
            _ => {
                return Err(Error::shape(format!(
                    "MoEModel::forward router weights must be 2D static, got {:?}",
                    shape
                )));
            }
        };
        if e != N_EXPERTS {
            return Err(Error::shape(format!(
                "MoEModel::forward expected router weights dim 1 = N_EXPERTS={}, got {}",
                N_EXPERTS, e
            )));
        }

        // 2. Each expert processes the full batch (dense MoE).
        let expected_len = b * OUT_DIM;
        let mut expert_outputs: Vec<Tensor<f32>> = Vec::with_capacity(N_EXPERTS);
        for (ei, expert) in self.experts.iter().enumerate() {
            // `from_ref` views `input` as a one-element slice, so the input
            // tensor is not cloned once per expert.
            let outs = expert
                .forward(std::slice::from_ref(input))
                .map_err(|err| {
                    Error::backend(format!("MoEModel::forward expert {} forward: {}", ei, err))
                })?;
            if outs.len() != 1 {
                return Err(Error::backend(format!(
                    "MoEModel::forward expert {} returned {} tensors, expected 1",
                    ei,
                    outs.len()
                )));
            }
            let out = outs.into_iter().next().expect("len=1");
            // The combine below slices OUT_DIM-wide rows out of each expert
            // output. Check the size here so a misbehaving expert yields an
            // error rather than an out-of-bounds panic or silent truncation.
            if out.data.len() != expected_len {
                return Err(Error::shape(format!(
                    "MoEModel::forward expert {} output has {} elements, expected B*OUT_DIM={}",
                    ei,
                    out.data.len(),
                    expected_len
                )));
            }
            expert_outputs.push(out);
        }
        // The combine reads one output per router column. A different expert
        // count would either panic (too few) or be silently ignored (too many).
        if expert_outputs.len() != N_EXPERTS {
            return Err(Error::shape(format!(
                "MoEModel::forward expected N_EXPERTS={} experts, got {}",
                N_EXPERTS,
                expert_outputs.len()
            )));
        }

        // 3. Combine: y[b, d] = sum_e w[b, e] * expert_out_e[b, d].
        let mut out_data = vec![0.0f32; expected_len];
        for (ei, expert_out) in expert_outputs.iter().enumerate() {
            for (bi, out_row) in out_data.chunks_exact_mut(OUT_DIM).enumerate() {
                let w = router_weights.data[bi * N_EXPERTS + ei];
                if w == 0.0 {
                    // Top-K masked: this expert does not contribute to sample `bi`.
                    continue;
                }
                let expert_row = &expert_out.data[bi * OUT_DIM..(bi + 1) * OUT_DIM];
                for (acc, &x) in out_row.iter_mut().zip(expert_row) {
                    *acc += w * x;
                }
            }
        }
        let logits = Tensor::dense_cpu(
            input.meta.domain.clone(),
            Shape::from(vec![b, OUT_DIM]),
            out_data,
        );

        // 4. Stash the forward cache for backward (replaces any previous one).
        let cache = MoEForwardCache {
            input: input.clone(),
            router_cache,
            router_weights: router_weights.clone(),
            expert_outputs,
        };
        *self.last_cache.borrow_mut() = Some(cache);

        Ok(MoEOutput {
            logits,
            router_weights,
        })
    }
}