tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! MoE backward pass.
//!
//! Backward splits the gradient flow into two paths: the expert
//! path (one backward call per expert; tokens that were not routed
//! to an expert receive a zero upstream gradient for it) and the
//! router path (one gating gradient per token and expert). The expert
//! and router parameter gradients are returned together, router first;
//! the router-path gradient flows back through the router linear layer.
//!
// Phase 2.4 MoE backward pass.
//
// Backward splits the gradient flow into two paths:
//
//   1. Expert path. For each expert e:
//        grad_out_e[b, d] = grad_y[b, d] * w[b, e]
//      where `w` is the (post-mask) router output. Calling
//      `expert_e.backward(grad_out_e)` returns:
//        (grad_x_e[b, k], param_grads_e)
//      Because backward is linear in its upstream gradient, grad_x_e
//      already carries the factor w[b, e], so the total input gradient
//      from the experts is:
//        grad_x_experts[b, k] = sum_e grad_x_e[b, k]
//
//   2. Router path. The full derivative w.r.t. the router output is:
//        grad_w[b, e] = sum_d grad_y[b, d] * out_e[b, d]
//      Note that `w` at non-top-K positions was zeroed in the forward
//      pass. The router's own `backward` does not re-apply the top-K
//      mask (the existing `Router` impl assumes the caller has masked
//      `grad_output`); we therefore apply the mask here using
//      `top_k_indices` from the saved router cache, then call
//      `router.backward(grad_w_masked)` which yields:
//        (grad_x_router[b, k], router_param_grads)
//
// Total gradients:
//   grad_x[b, k] = grad_x_experts[b, k] + grad_x_router[b, k]
//   param_grads  = concat(router_param_grads, expert_param_grads_0, ..., expert_param_grads_{N_EXPERTS-1})
//                  (in the same order as `MoEModel::parameters()` would return)

use crate::model::layer::{Layer, RouterCache};
use crate::model::sequential::Model;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};

use super::MoEModel;
use super::topology::{IN_DIM, N_EXPERTS, OUT_DIM, TOP_K};

impl MoEModel {
    /// Backward pass. Returns `(grad_input, param_grads)` in the
    /// order produced by `parameters()` traversal: router first
    /// (weight, bias), then experts 0..N_EXPERTS-1 in declaration
    /// order, each contributing its own Linear and LayerNorm
    /// parameter gradients in declaration order.
    pub fn backward(&self, grad_output: &Tensor<f32>) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        let cache = self
            .last_cache
            .borrow_mut()
            .take()
            .ok_or_else(|| Error::backend("MoEModel::backward called before forward"))?;

        // -- Shape checks -------------------------------------------------
        let b = match &grad_output.meta.shape.dims[0] {
            crate::object::Dim::Static(v) => *v,
            _ => {
                return Err(Error::shape(
                    "MoEModel::backward grad_output batch dim must be static",
                ));
            }
        };
        if grad_output.data.len() != b * OUT_DIM {
            return Err(Error::shape(format!(
                "MoEModel::backward grad_output data {} != batch*OUT_DIM={}*{}",
                grad_output.data.len(),
                b,
                OUT_DIM
            )));
        }

        // -- 1. Expert path ----------------------------------------------
        let mut grad_x_experts = vec![0.0f32; b * IN_DIM];
        let mut all_expert_param_grads: Vec<Tensor<f32>> = Vec::new();
        for ei in 0..N_EXPERTS {
            // grad_out_e[b, d] = grad_y[b, d] * w[b, e]
            let mut grad_out_e = vec![0.0f32; b * OUT_DIM];
            for bi in 0..b {
                let w = cache.router_weights.data[bi * N_EXPERTS + ei];
                for d in 0..OUT_DIM {
                    grad_out_e[bi * OUT_DIM + d] = grad_output.data[bi * OUT_DIM + d] * w;
                }
            }
            let grad_out_e_t = Tensor::dense_cpu(
                grad_output.meta.domain.clone(),
                Shape::from(vec![b, OUT_DIM]),
                grad_out_e,
            );

            let (grad_x_e_t, expert_param_grads) =
                self.experts[ei].backward(&[grad_out_e_t]).map_err(|err| {
                    Error::backend(format!(
                        "MoEModel::backward expert {} backward: {}",
                        ei, err
                    ))
                })?;
            if expert_param_grads.len() != self.experts[ei].parameters().len() {
                return Err(Error::backend(format!(
                    "MoEModel::backward expert {} returned {} param grads, expected {}",
                    ei,
                    expert_param_grads.len(),
                    self.experts[ei].parameters().len()
                )));
            }

            // Accumulate w * grad_x_e into grad_x_experts.
            for bi in 0..b {
                let w = cache.router_weights.data[bi * N_EXPERTS + ei];
                for k in 0..IN_DIM {
                    grad_x_experts[bi * IN_DIM + k] += w * grad_x_e_t[0].data[bi * IN_DIM + k];
                }
            }
            all_expert_param_grads.extend(expert_param_grads);
        }

        // -- 2. Router path ----------------------------------------------
        // grad_w[b, e] = sum_d grad_y[b, d] * out_e[b, d]
        let mut grad_w = vec![0.0f32; b * N_EXPERTS];
        for bi in 0..b {
            for ei in 0..N_EXPERTS {
                let mut s = 0.0f32;
                for d in 0..OUT_DIM {
                    s += grad_output.data[bi * OUT_DIM + d]
                        * cache.expert_outputs[ei].data[bi * OUT_DIM + d];
                }
                grad_w[bi * N_EXPERTS + ei] = s;
            }
        }
        // Apply top-K mask (zeros out non-top-K entries per row).
        let router_cache = cache
            .router_cache
            .downcast_ref::<RouterCache>()
            .ok_or_else(|| Error::backend("MoEModel::backward router cache downcast failed"))?;
        for bi in 0..b {
            let row_top_k = &router_cache.top_k_indices[bi * TOP_K..(bi + 1) * TOP_K];
            for ei in 0..N_EXPERTS {
                if !row_top_k.contains(&ei) {
                    grad_w[bi * N_EXPERTS + ei] = 0.0;
                }
            }
        }
        let grad_w_t = Tensor::dense_cpu(
            grad_output.meta.domain.clone(),
            Shape::from(vec![b, N_EXPERTS]),
            grad_w,
        );

        let (grad_x_router, router_param_grads) = self
            .router
            .backward(&grad_w_t, router_cache as &dyn std::any::Any)
            .map_err(|err| {
                Error::backend(format!("MoEModel::backward router backward: {}", err))
            })?;

        // -- 3. Combine input gradients -----------------------------------
        let mut grad_x_data = vec![0.0f32; b * IN_DIM];
        for i in 0..(b * IN_DIM) {
            grad_x_data[i] = grad_x_experts[i] + grad_x_router.data[i];
        }
        let grad_input = Tensor::dense_cpu(
            grad_output.meta.domain.clone(),
            Shape::from(vec![b, IN_DIM]),
            grad_x_data,
        );

        // -- 4. Concat param gradients (router first, then experts) ------
        let mut all_param_grads = router_param_grads;
        all_param_grads.extend(all_expert_param_grads);
        Ok((grad_input, all_param_grads))
    }
}