/// Error type returned by this module's wrappers around `native_neural_network::moe`.
///
/// Each variant corresponds one-to-one to the native crate's `MoeError` variant of the
/// same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MoeStdError {
    /// Counterpart of the native `MoeError::Empty`.
    /// NOTE(review): presumably reported for empty inputs — confirm against the native crate.
    Empty,
    /// Counterpart of the native `MoeError::ShapeMismatch`.
    /// NOTE(review): presumably reported when slice lengths disagree with the given
    /// dimensions — confirm against the native crate.
    ShapeMismatch,
}

// Human-readable rendering so the error can be printed and used as a `dyn Error`.
impl std::fmt::Display for MoeStdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            MoeStdError::Empty => "empty input",
            MoeStdError::ShapeMismatch => "shape mismatch",
        };
        f.write_str(msg)
    }
}

// Leaf error: no underlying `source()`, so the default trait methods suffice.
impl std::error::Error for MoeStdError {}
/// Converts the native crate's `MoeError` into the same-named [`MoeStdError`] variant.
///
/// This impl is what lets the wrappers in this module turn native errors into
/// `MoeStdError` (via `.into()`, `From::from`, or `?`).
impl From<native_neural_network::moe::MoeError> for MoeStdError {
    fn from(err: native_neural_network::moe::MoeError) -> Self {
        use native_neural_network::moe::MoeError as Native;
        // Exhaustive on purpose: a new native variant should fail to compile here.
        match err {
            Native::ShapeMismatch => Self::ShapeMismatch,
            Native::Empty => Self::Empty,
        }
    }
}
/// Top-1 gating over `scores`; forwards to `native_neural_network::moe::top1_gating`.
///
/// On success, returns the `usize` produced by the native call unchanged.
/// NOTE(review): presumably this is the index of the selected expert — confirm against the
/// native crate.
///
/// # Errors
/// Any error reported by the native call, converted into [`MoeStdError`].
pub fn top1_gating(scores: &[f32], num_experts: usize) -> Result<usize, MoeStdError> {
    // `?` applies the `From<MoeError>` conversion, the same one `.into()` would use.
    let chosen = native_neural_network::moe::top1_gating(scores, num_experts)?;
    Ok(chosen)
}
pub fn route_top1(
expert_outputs: &[f32],
num_experts: usize,
hidden_size: usize,
chosen_expert: usize,
out: &mut [f32],
) -> Result<(), MoeStdError> {
native_neural_network::moe::route_top1(
expert_outputs,
num_experts,
hidden_size,
chosen_expert,
out,
)
.map_err(|e| e.into())
}