use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source for the `moe_softmax_topk_f32` kernel, embedded into the
/// binary at compile time from `shaders/moe_softmax_topk.metal` and handed to
/// the kernel registry by [`register`].
pub static MOE_SOFTMAX_TOPK_SHADER_SOURCE: &str =
include_str!("../shaders/moe_softmax_topk.metal");
/// Registers the embedded MoE softmax/top-k shader source with `registry`
/// under the kernel name `moe_softmax_topk_f32`, which is the name
/// `dispatch_moe_softmax_topk` later uses to look up the pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "moe_softmax_topk_f32";
    registry.register_source(kernel_name, MOE_SOFTMAX_TOPK_SHADER_SOURCE);
}
/// Scalar parameters for the `moe_softmax_topk_f32` kernel, bound as inline
/// bytes at buffer index 0 by `dispatch_moe_softmax_topk`.
///
/// `#[repr(C)]` fixes the field order, and the four 4-byte fields make the
/// struct 16 bytes with no implicit padding (a requirement for `bytemuck::Pod`).
/// NOTE(review): this layout presumably mirrors a params struct in
/// `shaders/moe_softmax_topk.metal`. Confirm any change against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MoeSoftmaxTopkGpuParams {
    /// Number of token rows; the dispatcher launches one threadgroup per row.
    n_tokens: u32,
    /// Number of expert logits per token row.
    n_experts: u32,
    /// Number of (id, weight) outputs written per token row.
    top_k: u32,
    /// Explicit trailing field that rounds the struct up to 16 bytes; the
    /// dispatcher always writes 0.0 here.
    _pad: f32,
}
/// Encodes one dispatch of the `moe_softmax_topk_f32` kernel onto `encoder`.
///
/// For each of `n_tokens` rows of `n_experts` f32 logits, the kernel writes
/// `top_k` expert ids (u32) into `out_ids` and `top_k` f32 weights into
/// `out_weights`. Judging by the kernel name, this is a per-row softmax
/// followed by top-k selection; see `shaders/moe_softmax_topk.metal` for
/// the exact math.
///
/// Launch shape: `n_tokens` threadgroups of `min(n_experts, 128)` threads.
/// Each threadgroup gets `(2 * tg_size + n_experts) * 4` bytes of threadgroup
/// memory at index 0, rounded up to a multiple of 16 bytes.
///
/// # Arguments
/// * `encoder`     - command encoder the dispatch is recorded into.
/// * `registry`    - kernel registry the pipeline is fetched from.
/// * `device`      - device the pipeline is built for.
/// * `logits`      - input; must hold at least `n_tokens * n_experts` f32 values.
/// * `out_ids`     - output; must hold at least `n_tokens * top_k` u32 values.
/// * `out_weights` - output; must hold at least `n_tokens * top_k` f32 values.
/// * `n_tokens`, `n_experts`, `top_k` - problem dimensions; all must be > 0.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] in any of these cases:
/// * a dimension is zero;
/// * `top_k > 64` or `top_k > n_experts`;
/// * a required buffer size overflows `usize`;
/// * a buffer is smaller than required;
/// * the per-threadgroup scratch memory would exceed 32 KiB.
///
/// Errors from `KernelRegistry::get_pipeline` are propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_moe_softmax_topk(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    logits: &MlxBuffer,
    out_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_tokens: u32,
    n_experts: u32,
    top_k: u32,
) -> Result<()> {
    // Largest top_k the shader supports.
    const MAX_TOP_K: u32 = 64;
    // Threads per threadgroup are capped at this value.
    const MAX_TG_SIZE: u64 = 128;
    // Metal requires threadgroup memory lengths to be a multiple of 16 bytes.
    const THREADGROUP_MEMORY_ALIGN: u64 = 16;
    // `maxThreadgroupMemoryLength` on current Apple/Mac GPU families. A larger
    // request makes the dispatch invalid on the GPU side, so reject it here.
    const MAX_THREADGROUP_MEMORY_BYTES: u64 = 32 * 1024;

    if n_tokens == 0 || n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_softmax_topk: n_tokens, n_experts, and top_k must be > 0".into(),
        ));
    }
    if top_k > MAX_TOP_K {
        return Err(MlxError::InvalidArgument(format!(
            "moe_softmax_topk: top_k ({top_k}) > 64 (kernel supports up to 64)"
        )));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_softmax_topk: top_k ({top_k}) > n_experts ({n_experts})"
        )));
    }

    // Checked arithmetic: u32 * u32 * elem_size can exceed usize::MAX, and a
    // wrapped product would let an undersized buffer pass validation.
    let expected_logits =
        moe_softmax_topk_byte_len("logits", n_tokens, n_experts, DType::F32.size_of())?;
    let expected_ids =
        moe_softmax_topk_byte_len("out_ids", n_tokens, top_k, DType::U32.size_of())?;
    let expected_weights =
        moe_softmax_topk_byte_len("out_weights", n_tokens, top_k, DType::F32.size_of())?;
    moe_softmax_topk_check_len("logits", logits.byte_len(), expected_logits)?;
    moe_softmax_topk_check_len("out_ids", out_ids.byte_len(), expected_ids)?;
    moe_softmax_topk_check_len("out_weights", out_weights.byte_len(), expected_weights)?;

    let tg_size = u64::from(n_experts).min(MAX_TG_SIZE);
    // 2 * tg_size + n_experts slots of 4 bytes each. This cannot overflow u64
    // because tg_size <= 128 and n_experts <= u32::MAX.
    let shmem_raw = (2 * tg_size + u64::from(n_experts)) * 4;
    // Round up to the 16-byte granularity (ALIGN is a power of two).
    let shmem_bytes = (shmem_raw + THREADGROUP_MEMORY_ALIGN - 1) & !(THREADGROUP_MEMORY_ALIGN - 1);
    if shmem_bytes > MAX_THREADGROUP_MEMORY_BYTES {
        return Err(MlxError::InvalidArgument(format!(
            "moe_softmax_topk: n_experts ({n_experts}) needs {shmem_bytes} bytes of \
             threadgroup memory (limit {limit})",
            limit = MAX_THREADGROUP_MEMORY_BYTES
        )));
    }

    let gpu_params = MoeSoftmaxTopkGpuParams {
        n_tokens,
        n_experts,
        top_k,
        _pad: 0.0,
    };
    let pipeline = registry.get_pipeline("moe_softmax_topk_f32", device.metal_device())?;
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(logits)),
            (2, KernelArg::Buffer(out_ids)),
            (3, KernelArg::Buffer(out_weights)),
        ],
        &[(0, shmem_bytes)],
        // One threadgroup per token row.
        MTLSize::new(u64::from(n_tokens), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}

/// Byte size of a `rows x cols` array of `elem_size`-byte elements.
/// Returns `InvalidArgument` instead of wrapping if the product overflows `usize`.
fn moe_softmax_topk_byte_len(what: &str, rows: u32, cols: u32, elem_size: usize) -> Result<usize> {
    (rows as usize)
        .checked_mul(cols as usize)
        .and_then(|n| n.checked_mul(elem_size))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "moe_softmax_topk: {what} byte size overflows usize \
                 ({rows} x {cols} elements of {elem_size} bytes)"
            ))
        })
}

/// Fails with `InvalidArgument` when a buffer of `actual` bytes is smaller
/// than the `expected` byte count.
fn moe_softmax_topk_check_len(what: &str, actual: usize, expected: usize) -> Result<()> {
    if actual < expected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_softmax_topk: {what} too small (expected {expected}, got {actual})"
        )));
    }
    Ok(())
}