use super::errors::MoeError;
/// Top-1 routing: copies the output row of one selected expert into `out`.
///
/// `expert_outputs` is read as `num_experts` consecutive rows of `hidden_size`
/// values each. Row `chosen_expert` is copied into `out[..hidden_size]`. Elements
/// of `out` past `hidden_size` are left untouched. Elements of `expert_outputs`
/// past `num_experts * hidden_size` are ignored.
///
/// # Errors
///
/// Returns [`MoeError::ShapeMismatch`] (leaving `out` unmodified) when:
/// - `num_experts` or `hidden_size` is zero,
/// - `chosen_expert` is not less than `num_experts`,
/// - `num_experts * hidden_size` overflows `usize`,
/// - `expert_outputs` holds fewer than `num_experts * hidden_size` elements, or
/// - `out` holds fewer than `hidden_size` elements.
pub fn route_top1(
    expert_outputs: &[f32],
    num_experts: usize,
    hidden_size: usize,
    chosen_expert: usize,
    out: &mut [f32],
) -> Result<(), MoeError> {
    let dims_valid = num_experts > 0 && hidden_size > 0 && chosen_expert < num_experts;
    if !dims_valid {
        return Err(MoeError::ShapeMismatch);
    }

    let required = match num_experts.checked_mul(hidden_size) {
        Some(n) => n,
        None => return Err(MoeError::ShapeMismatch),
    };

    let buffers_fit = expert_outputs.len() >= required && out.len() >= hidden_size;
    if !buffers_fit {
        return Err(MoeError::ShapeMismatch);
    }

    // `chosen_expert < num_experts` and the product above did not overflow, so
    // `row_start + hidden_size <= required` and neither the addition nor the
    // slice below can go out of range.
    let row_start = chosen_expert * hidden_size;
    let row = &expert_outputs[row_start..row_start + hidden_size];
    out[..hidden_size].copy_from_slice(row);
    Ok(())
}