use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Dimensions for a single-token MoE dispatch (see [`moe_dispatch`]).
pub struct MoeDispatchParams {
    /// Hidden-state width: length of the input/output vectors in f32 elements.
    pub input_dim: usize,
    /// Expert FFN intermediate width: length of the gate/up/hidden scratch buffers.
    pub intermediate_dim: usize,
    /// Number of experts selected by the router for this token; must match the
    /// lengths of the `expert_weights` and `routing_weights` slices.
    pub n_selected: usize,
}
/// Borrowed weight buffers for one expert's feed-forward network.
pub struct ExpertWeights<'a> {
    /// Gate projection weights; validated as at least input_dim * intermediate_dim f32s.
    pub gate_proj: &'a MlxBuffer,
    /// Up projection weights; validated as at least input_dim * intermediate_dim f32s.
    pub up_proj: &'a MlxBuffer,
    /// Down projection weights; validated as at least intermediate_dim * input_dim f32s.
    pub down_proj: &'a MlxBuffer,
}
/// GPU-side constants for the `fused_gelu_mul` / `moe_swiglu_fused` kernels.
/// `#[repr(C)]` + `Pod` so the struct can be passed to the shader via `as_bytes`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedGeluMulParams {
    /// Number of output elements; the dispatch grid is sized to this count.
    n_elements: u32,
}
/// GPU-side constants for the `moe_accumulate` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeAccumParams {
    /// Number of f32 elements to process.
    n_elements: u32,
    /// Router weight for this expert, passed through to the shader.
    routing_weight: f32,
}
/// GPU-side constants for the `zero_buffer` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuZeroParams {
    /// Number of f32 elements to clear.
    n_elements: u32,
}
/// GPU-side constants for the `naive_matvec_f32` kernel. In this module
/// `m` is always 1 (matrix-vector) and the dispatch grid is sized to `n`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
    m: u32, k: u32, n: u32, }
/// Encodes a complete single-token mixture-of-experts dispatch.
///
/// Zeroes `output`, then for every selected expert encodes the chain
/// gate matvec → up matvec → fused GELU·mul → down matvec → weighted
/// accumulate into `output`, with memory barriers between dependent
/// passes. Experts whose routing weight is effectively zero are skipped.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when a dimension is zero, the
/// weight slices disagree with `params.n_selected`, or any buffer is
/// smaller than required; pipeline-compilation errors are propagated
/// from the registry before any work is encoded.
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }
    let input_bytes = params.input_dim * std::mem::size_of::<f32>();
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }
    let intermediate_bytes = params.intermediate_dim * std::mem::size_of::<f32>();
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }
    let gate_up_bytes = params.input_dim * params.intermediate_dim * std::mem::size_of::<f32>();
    let down_bytes = params.intermediate_dim * params.input_dim * std::mem::size_of::<f32>();
    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < down_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i, down_bytes, ew.down_proj.byte_len()
            )));
        }
    }
    // Compile (or cache-hit) every pipeline up front so compilation
    // failures surface before any work is encoded.
    registry.get_pipeline("naive_matvec_f32", device)?;
    registry.get_pipeline("fused_gelu_mul", device)?;
    registry.get_pipeline("moe_accumulate", device)?;
    registry.get_pipeline("zero_buffer", device)?;
    // Each encode below re-fetches its pipeline from the registry — a
    // cache hit after the warm-up above. This replaces the previous
    // raw-pointer + `unsafe` workaround for holding several `&mut
    // registry` borrows at once, which would have been unsound had the
    // registry's cache ever moved its entries.
    let dim = params.input_dim as u64;
    let inter = params.intermediate_dim as u64;
    // Clear the accumulator before the first expert contribution.
    let zero_params = GpuZeroParams {
        n_elements: params.input_dim as u32,
    };
    encode_with_args(
        encoder,
        registry.get_pipeline("zero_buffer", device)?,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        MTLSize::new(dim, 1, 1),
        MTLSize::new(dim.min(256), 1, 1),
    );
    for (ew, &w) in expert_weights.iter().zip(routing_weights) {
        // A (near-)zero routing weight contributes nothing; skip the
        // whole expert chain.
        if w.abs() < 1e-10 {
            continue;
        }
        encoder.memory_barrier();
        // Gate and up projections share the same [1 x k] * [k x n] shape.
        let proj_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            registry.get_pipeline("naive_matvec_f32", device)?,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&proj_params))),
            ],
            MTLSize::new(inter, 1, 1),
            MTLSize::new(inter.min(256), 1, 1),
        );
        encode_with_args(
            encoder,
            registry.get_pipeline("naive_matvec_f32", device)?,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&proj_params))),
            ],
            MTLSize::new(inter, 1, 1),
            MTLSize::new(inter.min(256), 1, 1),
        );
        encoder.memory_barrier();
        // Combine gate and up activations into the hidden scratch buffer.
        let gelu_params = GpuFusedGeluMulParams {
            n_elements: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            registry.get_pipeline("fused_gelu_mul", device)?,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            MTLSize::new(inter, 1, 1),
            MTLSize::new(inter.min(256), 1, 1),
        );
        encoder.memory_barrier();
        // Project back down to the hidden-state width.
        let down_params = GpuMatmulParams {
            m: 1,
            k: params.intermediate_dim as u32,
            n: params.input_dim as u32,
        };
        encode_with_args(
            encoder,
            registry.get_pipeline("naive_matvec_f32", device)?,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&down_params))),
            ],
            MTLSize::new(dim, 1, 1),
            MTLSize::new(dim.min(256), 1, 1),
        );
        encoder.memory_barrier();
        // Weighted accumulate of this expert's result into `output`
        // (via the `moe_accumulate` kernel).
        let accum_params = GpuMoeAccumParams {
            n_elements: params.input_dim as u32,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            registry.get_pipeline("moe_accumulate", device)?,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            MTLSize::new(dim, 1, 1),
            MTLSize::new(dim.min(256), 1, 1),
        );
    }
    Ok(())
}
/// Encodes the `moe_swiglu_fused` kernel: the shader reads from
/// `gate_up` (validated as at least `2 * n_elements` f32s) and writes
/// `n_elements` f32s to `output`, one thread per output element.
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    // The packed buffer carries two f32 values per output element.
    let gu_required = 2 * n_elements * f32_size;
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_elements * f32_size;
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }
    let gpu_params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    let elems = n_elements as u64;
    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(elems, 1, 1),
        MTLSize::new(elems.min(256), 1, 1),
    );
    Ok(())
}
/// Encodes the `zero_buffer` kernel to clear the first `n_elements`
/// f32 slots of `output`, one thread per element.
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let needed = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            needed, output.byte_len()
        )));
    }
    let gpu_params = GpuZeroParams {
        n_elements: n_elements as u32,
    };
    let elems = n_elements as u64;
    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(elems, 1, 1),
        MTLSize::new(elems.min(256), 1, 1),
    );
    Ok(())
}
/// Encodes the `moe_accumulate` kernel over `n_elements` f32 values,
/// binding `accumulator`, `expert_output`, and the routing weight.
/// One thread per element.
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let needed = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            needed, accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            needed, expert_output.byte_len()
        )));
    }
    let gpu_params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    let elems = n_elements as u64;
    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(elems, 1, 1),
        MTLSize::new(elems.min(256), 1, 1),
    );
    Ok(())
}
/// Encodes the `moe_swiglu_batch` kernel over `top_k` experts at once:
/// the grid is (`intermediate`, `top_k`) threads. The `intermediate`
/// and `top_k` constants are bound as raw little/native-endian u32
/// bytes at argument slots 2 and 3, matching the shader's signature.
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let gu_needed = top_k * 2 * intermediate * f32_size;
    if gate_up.byte_len() < gu_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            gu_needed, gate_up.byte_len()
        )));
    }
    let out_needed = top_k * intermediate * f32_size;
    if output.byte_len() < out_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            out_needed, output.byte_len()
        )));
    }
    let inter_arg = (intermediate as u32).to_ne_bytes();
    let top_k_arg = (top_k as u32).to_ne_bytes();
    let pipeline = registry.get_pipeline("moe_swiglu_batch", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&inter_arg)),
            (3, KernelArg::Bytes(&top_k_arg)),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, 1),
        MTLSize::new((intermediate as u64).min(256), 1, 1),
    );
    Ok(())
}
/// GPU-side constants for the `moe_weighted_sum` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
    /// Width of one expert-output row in f32 elements.
    hidden_size: u32,
    /// Number of expert rows/weights being combined.
    top_k: u32,
}
/// Encodes the `moe_weighted_sum` kernel: validates that
/// `expert_outputs` holds `top_k * hidden_size` f32s, `weights` holds
/// `top_k` f32s, and `output` holds `hidden_size` f32s, then
/// dispatches one thread per output element.
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let expert_needed = top_k * hidden_size * f32_size;
    if expert_outputs.byte_len() < expert_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_needed, expert_outputs.byte_len()
        )));
    }
    let weights_needed = top_k * f32_size;
    if weights.byte_len() < weights_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_needed, weights.byte_len()
        )));
    }
    let out_needed = hidden_size * f32_size;
    if output.byte_len() < out_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_needed, output.byte_len()
        )));
    }
    let gpu_params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    let width = hidden_size as u64;
    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(width, 1, 1),
        MTLSize::new(width.min(256), 1, 1),
    );
    Ok(())
}
/// GPU-side constants for the `moe_gather_topk_weights` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
    /// Total number of experts in the routing tables.
    n_experts: u32,
    /// Number of experts to select; capped at 8 by the caller.
    top_k: u32,
}
/// Encodes the `moe_gather_topk_weights` kernel as a single-thread
/// dispatch. Per its name it gathers the top-k expert ids and weights
/// from the routing softmax — see the shader for exact semantics.
/// `top_k` is capped at 8 to match the shader's fixed-size arrays.
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }
    let sz_f32 = std::mem::size_of::<f32>();
    let sz_u32 = std::mem::size_of::<u32>();
    if softmax_probs.byte_len() < n_experts * sz_f32 {
        return Err(MlxError::InvalidArgument("softmax_probs too small".into()));
    }
    if sorted_indices.byte_len() < n_experts * sz_u32 {
        return Err(MlxError::InvalidArgument("sorted_indices too small".into()));
    }
    if per_expert_scale.byte_len() < n_experts * sz_f32 {
        return Err(MlxError::InvalidArgument("per_expert_scale too small".into()));
    }
    if out_expert_ids.byte_len() < top_k * sz_u32 {
        return Err(MlxError::InvalidArgument("out_expert_ids too small".into()));
    }
    if out_weights.byte_len() < top_k * sz_f32 {
        return Err(MlxError::InvalidArgument("out_weights too small".into()));
    }
    let gpu_params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    // One thread does all the work: the grid and the threadgroup are
    // both a single thread.
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(1, 1, 1),
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}
/// Offset-aware variant of [`moe_swiglu_fused_encode`]: binds
/// `gate_up` and `output` at the given byte offsets so sub-regions of
/// larger buffers can be processed in place.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    // Required sizes include the starting offset of each region.
    let gu_needed = gu_byte_offset + 2 * n_elements * f32_size;
    if gate_up.byte_len() < gu_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_needed, gu_byte_offset, gate_up.byte_len()
        )));
    }
    let out_needed = out_byte_offset + n_elements * f32_size;
    if output.byte_len() < out_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_needed, out_byte_offset, output.byte_len()
        )));
    }
    let gpu_params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    let elems = n_elements as u64;
    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(elems, 1, 1),
        MTLSize::new(elems.min(256), 1, 1),
    );
    Ok(())
}
/// Offset-aware variant of [`moe_accumulate_encode`]: binds
/// `expert_output` starting at `src_byte_offset` so one expert's slice
/// of a larger buffer can be accumulated.
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let needed = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            needed, accumulator.byte_len()
        )));
    }
    // The source region must fit past its starting offset.
    let src_needed = src_byte_offset + needed;
    if expert_output.byte_len() < src_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_needed, src_byte_offset, expert_output.byte_len()
        )));
    }
    let gpu_params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    let elems = n_elements as u64;
    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(elems, 1, 1),
        MTLSize::new(elems.min(256), 1, 1),
    );
    Ok(())
}
/// GPU-side constants for the `moe_swiglu_seq` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
    /// Expert intermediate width in f32 elements.
    intermediate: u32,
    /// Number of experts selected per token.
    top_k: u32,
    /// Number of tokens in the sequence.
    n_tokens: u32,
}
/// Encodes the `moe_swiglu_seq` kernel for a whole sequence: the grid
/// is (`intermediate`, `top_k`, `n_tokens`) threads. `gate_up` must
/// hold `n_tokens * top_k * 2 * intermediate` f32s and `output` must
/// hold `n_tokens * top_k * intermediate` f32s.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let gu_needed = n_tokens * top_k * 2 * intermediate * f32_size;
    if gate_up.byte_len() < gu_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: gate_up too small: need {} bytes, have {}",
            gu_needed, gate_up.byte_len()
        )));
    }
    let out_needed = n_tokens * top_k * intermediate * f32_size;
    if output.byte_len() < out_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: output too small: need {} bytes, have {}",
            out_needed, output.byte_len()
        )));
    }
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new((intermediate as u64).min(256), 1, 1),
    );
    Ok(())
}
/// GPU-side constants for the `moe_weighted_sum_seq` kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
    /// Width of one expert-output row in f32 elements.
    hidden_size: u32,
    /// Number of experts combined per token.
    top_k: u32,
    /// Number of tokens in the sequence.
    n_tokens: u32,
}
/// Encodes the `moe_weighted_sum_seq` kernel for a whole sequence: the
/// grid is (`hidden_size`, `n_tokens`) threads. `expert_outputs` must
/// hold `n_tokens * top_k * hidden_size` f32s, `weights` must hold
/// `n_tokens * top_k` f32s, and `output` must hold
/// `n_tokens * hidden_size` f32s.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let expert_needed = n_tokens * top_k * hidden_size * f32_size;
    if expert_outputs.byte_len() < expert_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: expert_outputs too small: need {} bytes, have {}",
            expert_needed, expert_outputs.byte_len()
        )));
    }
    let weights_needed = n_tokens * top_k * f32_size;
    if weights.byte_len() < weights_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: weights too small: need {} bytes, have {}",
            weights_needed, weights.byte_len()
        )));
    }
    let out_needed = n_tokens * hidden_size * f32_size;
    if output.byte_len() < out_needed {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: output too small: need {} bytes, have {}",
            out_needed, output.byte_len()
        )));
    }
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let pipeline = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new((hidden_size as u64).min(256), 1, 1),
    );
    Ok(())
}