use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Activation selector accepted by [`MoEOps::moe_grouped_gemm_fused`].
///
/// The enum carries no data; how each variant is applied is left to the
/// implementor of the trait method that receives it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MoEActivation {
    /// Selects SiLU (meaning taken from the name; the exact formula is
    /// defined by the implementor).
    SiLU,
    /// Selects GeLU (meaning taken from the name; the exact formula or
    /// approximation is defined by the implementor).
    GeLU,
    /// NOTE(review): presumably "apply no activation" — confirm against
    /// implementors.
    None,
}
/// Mixture-of-experts (MoE) operations, generic over the runtime `R`.
///
/// The trait declares signatures only: every method borrows `self` and its
/// tensor arguments immutably and reports failure through the crate-level
/// [`Result`]. Tensor shapes, dtypes and the ordering of returned tuples are
/// not fixed by these declarations and must be taken from the implementors.
pub trait MoEOps<R>
where
    R: Runtime,
{
    /// Top-k routing over `logits`, with `k` as the selection count.
    ///
    /// On success returns a pair of tensors.
    ///
    /// NOTE(review): presumably (routing weights, selected expert indices) —
    /// confirm the order and shapes against implementors.
    ///
    /// # Errors
    ///
    /// Error conditions are decided by the implementor.
    fn moe_top_k_routing(
        &self,
        logits: &Tensor<R>,
        k: usize,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// Permutes `tokens` according to `indices` for `num_experts` experts.
    ///
    /// On success returns a triple of tensors.
    ///
    /// NOTE(review): sibling methods take `permuted_tokens`, `expert_offsets`
    /// and `sort_indices` arguments, which suggests the triple holds those
    /// three — confirm the contents and order against implementors.
    ///
    /// # Errors
    ///
    /// Error conditions are decided by the implementor.
    fn moe_permute_tokens(
        &self,
        tokens: &Tensor<R>,
        indices: &Tensor<R>,
        num_experts: usize,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;

    /// Counterpart of [`MoEOps::moe_permute_tokens`]: builds a single output
    /// tensor from `expert_output`, `sort_indices`, `weights` and
    /// `num_tokens`.
    ///
    /// NOTE(review): presumably restores the original token order and applies
    /// the routing `weights` while combining — confirm against implementors.
    ///
    /// # Errors
    ///
    /// Error conditions are decided by the implementor.
    fn moe_unpermute_tokens(
        &self,
        expert_output: &Tensor<R>,
        sort_indices: &Tensor<R>,
        weights: &Tensor<R>,
        num_tokens: usize,
    ) -> Result<Tensor<R>>;

    /// Grouped GEMM over `permuted_tokens` and `expert_weights`, with
    /// `expert_offsets` supplied alongside; returns one tensor.
    ///
    /// NOTE(review): `expert_offsets` presumably marks where each expert's
    /// tokens begin inside `permuted_tokens` — confirm against implementors.
    ///
    /// # Errors
    ///
    /// Error conditions are decided by the implementor.
    fn moe_grouped_gemm(
        &self,
        permuted_tokens: &Tensor<R>,
        expert_weights: &Tensor<R>,
        expert_offsets: &Tensor<R>,
    ) -> Result<Tensor<R>>;

    /// Same inputs as [`MoEOps::moe_grouped_gemm`] plus an `activation`
    /// selector; returns one tensor.
    ///
    /// NOTE(review): presumably the activation is applied to the GEMM result
    /// in the same pass — confirm against implementors.
    ///
    /// # Errors
    ///
    /// Error conditions are decided by the implementor.
    fn moe_grouped_gemm_fused(
        &self,
        permuted_tokens: &Tensor<R>,
        expert_weights: &Tensor<R>,
        expert_offsets: &Tensor<R>,
        activation: MoEActivation,
    ) -> Result<Tensor<R>>;
}