use crate::model::layer::{Layer, RouterCache};
use crate::model::sequential::Model;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};
use super::MoEModel;
use super::topology::{IN_DIM, N_EXPERTS, OUT_DIM, TOP_K};
impl MoEModel {
    /// Backpropagates `grad_output` through the mixture-of-experts block.
    ///
    /// `grad_output` must carry `batch * OUT_DIM` values with a static leading
    /// batch dimension. The cache stored in `self.last_cache` is taken (left as
    /// `None`), so every call needs a fresh cache — presumably written by the
    /// forward pass, as the error message below suggests.
    ///
    /// Returns `(grad_input, param_grads)`: `grad_input` is built with shape
    /// `[batch, IN_DIM]`, and `param_grads` lists the router's gradients first,
    /// followed by each expert's gradients in expert-index order.
    ///
    /// NOTE(review): the math assumes the forward pass combines experts as
    /// `y = sum_e w[b, e] * expert_e(x)`; this matches the `grad_w` dot product
    /// computed here, but the forward pass is not visible from this file.
    ///
    /// # Errors
    /// - shape error if the batch dimension is missing or not static, if
    ///   `grad_output` has the wrong number of elements, or if a cached or
    ///   returned buffer is too short for the batch size;
    /// - backend error if no cache is available, if an expert or the router
    ///   fails in its own backward pass, if an expert returns an unexpected
    ///   number of gradients, or if the router cache has the wrong type.
    pub fn backward(&self, grad_output: &Tensor<f32>) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        // Turns a would-be out-of-bounds panic into a shape error.
        fn require_len(what: &str, actual: usize, required: usize) -> Result<()> {
            if actual < required {
                return Err(Error::shape(format!(
                    "MoEModel::backward {} has {} elements, expected at least {}",
                    what, actual, required
                )));
            }
            Ok(())
        }

        // Validate the incoming gradient before taking the cache, so a rejected
        // call does not throw the cache away.
        let b = match grad_output.meta.shape.dims.first() {
            Some(crate::object::Dim::Static(v)) => *v,
            _ => {
                return Err(Error::shape(
                    "MoEModel::backward grad_output batch dim must be static",
                ));
            }
        };
        if grad_output.data.len() != b * OUT_DIM {
            return Err(Error::shape(format!(
                "MoEModel::backward grad_output data {} != batch*OUT_DIM={}*{}",
                grad_output.data.len(),
                b,
                OUT_DIM
            )));
        }

        let cache = self
            .last_cache
            .borrow_mut()
            .take()
            .ok_or_else(|| Error::backend("MoEModel::backward called before forward"))?;
        require_len(
            "cached router weights",
            cache.router_weights.data.len(),
            b * N_EXPERTS,
        )?;

        // ---- Expert branch -------------------------------------------------
        let mut grad_x_experts = vec![0.0f32; b * IN_DIM];
        let mut all_expert_param_grads: Vec<Tensor<f32>> = Vec::new();
        for ei in 0..N_EXPERTS {
            // Upstream gradient seen by expert `ei`: grad_output scaled per row
            // by that row's router weight for this expert.
            let mut grad_out_e = vec![0.0f32; b * OUT_DIM];
            for bi in 0..b {
                let w = cache.router_weights.data[bi * N_EXPERTS + ei];
                for d in 0..OUT_DIM {
                    grad_out_e[bi * OUT_DIM + d] = grad_output.data[bi * OUT_DIM + d] * w;
                }
            }
            let grad_out_e_t = Tensor::dense_cpu(
                grad_output.meta.domain.clone(),
                Shape::from(vec![b, OUT_DIM]),
                grad_out_e,
            );
            let (grad_x_e_t, expert_param_grads) =
                self.experts[ei].backward(&[grad_out_e_t]).map_err(|err| {
                    Error::backend(format!(
                        "MoEModel::backward expert {} backward: {}",
                        ei, err
                    ))
                })?;
            let expected_param_grads = self.experts[ei].parameters().len();
            if expert_param_grads.len() != expected_param_grads {
                return Err(Error::backend(format!(
                    "MoEModel::backward expert {} returned {} param grads, expected {}",
                    ei,
                    expert_param_grads.len(),
                    expected_param_grads
                )));
            }

            let grad_x_e = grad_x_e_t.first().ok_or_else(|| {
                Error::backend(format!(
                    "MoEModel::backward expert {} returned no input gradient",
                    ei
                ))
            })?;
            require_len("expert input gradient", grad_x_e.data.len(), b * IN_DIM)?;
            // The router weight was already applied to the upstream gradient
            // above, so the expert's input gradient is accumulated as-is;
            // scaling by `w` again would weight each contribution by w^2.
            for (i, acc) in grad_x_experts.iter_mut().enumerate() {
                *acc += grad_x_e.data[i];
            }
            all_expert_param_grads.extend(expert_param_grads);
        }

        // ---- Router branch -------------------------------------------------
        // Gradient w.r.t. each router weight: per-row dot product of
        // grad_output with the cached output of the corresponding expert.
        require_len("cached expert outputs", cache.expert_outputs.len(), N_EXPERTS)?;
        for ei in 0..N_EXPERTS {
            require_len(
                "cached expert output",
                cache.expert_outputs[ei].data.len(),
                b * OUT_DIM,
            )?;
        }
        let mut grad_w = vec![0.0f32; b * N_EXPERTS];
        for bi in 0..b {
            for ei in 0..N_EXPERTS {
                let mut s = 0.0f32;
                for d in 0..OUT_DIM {
                    s += grad_output.data[bi * OUT_DIM + d]
                        * cache.expert_outputs[ei].data[bi * OUT_DIM + d];
                }
                grad_w[bi * N_EXPERTS + ei] = s;
            }
        }

        let router_cache = cache
            .router_cache
            .downcast_ref::<RouterCache>()
            .ok_or_else(|| Error::backend("MoEModel::backward router cache downcast failed"))?;
        require_len(
            "router top-k indices",
            router_cache.top_k_indices.len(),
            b * TOP_K,
        )?;
        // Experts outside a row's top-k list receive no weight gradient.
        for bi in 0..b {
            let row_top_k = &router_cache.top_k_indices[bi * TOP_K..(bi + 1) * TOP_K];
            for ei in 0..N_EXPERTS {
                if !row_top_k.contains(&ei) {
                    grad_w[bi * N_EXPERTS + ei] = 0.0;
                }
            }
        }
        let grad_w_t = Tensor::dense_cpu(
            grad_output.meta.domain.clone(),
            Shape::from(vec![b, N_EXPERTS]),
            grad_w,
        );
        let (grad_x_router, router_param_grads) = self
            .router
            .backward(&grad_w_t, router_cache as &dyn std::any::Any)
            .map_err(|err| {
                Error::backend(format!("MoEModel::backward router backward: {}", err))
            })?;
        require_len(
            "router input gradient",
            grad_x_router.data.len(),
            b * IN_DIM,
        )?;

        // ---- Combine -------------------------------------------------------
        // The input feeds both the experts and the router, so its gradient is
        // the sum of the two branches.
        let grad_x_data: Vec<f32> = grad_x_experts
            .iter()
            .enumerate()
            .map(|(i, g)| *g + grad_x_router.data[i])
            .collect();
        let grad_input = Tensor::dense_cpu(
            grad_output.meta.domain.clone(),
            Shape::from(vec![b, IN_DIM]),
            grad_x_data,
        );

        let mut all_param_grads = router_param_grads;
        all_param_grads.extend(all_expert_param_grads);
        Ok((grad_input, all_param_grads))
    }
}