use crate::model::layer::Layer;
use crate::model::sequential::Model;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};
use super::topology::{N_EXPERTS, OUT_DIM};
use super::{MoEForwardCache, MoEModel, MoEOutput};
impl MoEModel {
    /// Runs the mixture-of-experts forward pass on `input`.
    ///
    /// The router is evaluated first and must yield a static `[b, N_EXPERTS]`
    /// weight matrix. Every expert is then run on the same `input`, and the
    /// expert outputs are blended row by row into `[b, OUT_DIM]` logits:
    /// `logits[bi] = sum over ei of router_weights[bi, ei] * expert_out[ei][bi]`.
    /// Both buffers are read as dense row-major data.
    ///
    /// On success the intermediate values (input copy, router cache, router
    /// weights and per-expert outputs) are stored in `self.last_cache`,
    /// replacing any previous cache. Nothing is stored when an error is
    /// returned.
    ///
    /// # Errors
    ///
    /// * Any error produced by the router is propagated unchanged.
    /// * A shape error if the router weights are not a 2D static tensor of
    ///   shape `[b, N_EXPERTS]`, if their buffer length disagrees with that
    ///   shape, if the model does not hold exactly `N_EXPERTS` experts, or if an
    ///   expert output does not contain `b * OUT_DIM` values.
    /// * A backend error if an expert fails or does not return exactly one
    ///   tensor.
    ///
    /// # Panics
    ///
    /// NOTE(review): `last_cache` is written through `borrow_mut()`; if it is a
    /// `RefCell`, this panics when the cache is still borrowed by a caller —
    /// confirm against the struct definition.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<MoEOutput> {
        let (router_weights, router_cache) = self.router.forward(input)?;

        // Check the rank before touching individual dims so a malformed router
        // output is reported as an error instead of an index-out-of-bounds panic.
        let shape = &router_weights.meta.shape;
        let static_dims = if shape.dims.len() == 2 {
            match (&shape.dims[0], &shape.dims[1]) {
                (crate::object::Dim::Static(b), crate::object::Dim::Static(e)) => Some((*b, *e)),
                _ => None,
            }
        } else {
            None
        };
        let (b, e) = match static_dims {
            Some(dims) => dims,
            None => {
                return Err(Error::shape(format!(
                    "MoEModel::forward router weights must be 2D static, got {:?}",
                    shape
                )));
            }
        };
        if e != N_EXPERTS {
            return Err(Error::shape(format!(
                "MoEModel::forward expected router weights dim 1 = N_EXPERTS={}, got {}",
                N_EXPERTS, e
            )));
        }
        if router_weights.data.len() != b * N_EXPERTS {
            return Err(Error::shape(format!(
                "MoEModel::forward router weights hold {} values, expected {} for shape [{}, {}]",
                router_weights.data.len(),
                b * N_EXPERTS,
                b,
                N_EXPERTS
            )));
        }

        // The mixing loop below reads exactly one output per router column, so
        // the expert count has to match the router width.
        if self.experts.len() != N_EXPERTS {
            return Err(Error::shape(format!(
                "MoEModel::forward expected N_EXPERTS={} experts, got {}",
                N_EXPERTS,
                self.experts.len()
            )));
        }

        // Clone the input once; the same copy is fed to every expert and later
        // moved into the cache.
        let expert_input = [input.clone()];
        let mut expert_outputs: Vec<Tensor<f32>> = Vec::with_capacity(N_EXPERTS);
        for (ei, expert) in self.experts.iter().enumerate() {
            let outs = expert.forward(&expert_input).map_err(|err| {
                Error::backend(format!("MoEModel::forward expert {} forward: {}", ei, err))
            })?;
            let n_outs = outs.len();
            let expert_out = match (outs.into_iter().next(), n_outs) {
                (Some(out), 1) => out,
                _ => {
                    return Err(Error::backend(format!(
                        "MoEModel::forward expert {} returned {} tensors, expected 1",
                        ei, n_outs
                    )));
                }
            };
            if expert_out.data.len() != b * OUT_DIM {
                return Err(Error::shape(format!(
                    "MoEModel::forward expert {} output holds {} values, expected {} for shape [{}, {}]",
                    ei,
                    expert_out.data.len(),
                    b * OUT_DIM,
                    b,
                    OUT_DIM
                )));
            }
            expert_outputs.push(expert_out);
        }

        // Weighted sum of expert rows. Experts are the outer loop so every
        // output element accumulates its terms in expert order.
        let mut out_data = vec![0.0f32; b * OUT_DIM];
        for (ei, expert_out) in expert_outputs.iter().enumerate() {
            for bi in 0..b {
                let w = router_weights.data[bi * N_EXPERTS + ei];
                // A zero weight contributes nothing; skip the row entirely.
                if w == 0.0 {
                    continue;
                }
                let start = bi * OUT_DIM;
                let end = start + OUT_DIM;
                let expert_row = &expert_out.data[start..end];
                let out_row = &mut out_data[start..end];
                for (acc, x) in out_row.iter_mut().zip(expert_row.iter()) {
                    *acc += w * *x;
                }
            }
        }

        let logits = Tensor::dense_cpu(
            input.meta.domain.clone(),
            Shape::from(vec![b, OUT_DIM]),
            out_data,
        );

        // Reuse the copy made for the experts instead of cloning `input` again.
        let [cached_input] = expert_input;
        let cache = MoEForwardCache {
            input: cached_input,
            router_cache,
            router_weights: router_weights.clone(),
            expert_outputs,
        };
        *self.last_cache.borrow_mut() = Some(cache);

        Ok(MoEOutput {
            logits,
            router_weights,
        })
    }
}