use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Shape parameters for [`moe_dispatch`].
///
/// All values are element counts (not bytes). `moe_dispatch` rejects any of
/// them being zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MoeDispatchParams {
    /// Length (in f32 elements) of the `input` and `output` vectors.
    pub input_dim: usize,
    /// Length (in f32 elements) of the gate / up / hidden scratch vectors.
    pub intermediate_dim: usize,
    /// Number of experts to apply; must equal both `expert_weights.len()`
    /// and `routing_weights.len()`.
    pub n_selected: usize,
}
/// Borrowed weight buffers for one routed expert, as consumed by
/// [`moe_dispatch`].
///
/// `moe_dispatch` requires each buffer to hold at least
/// `input_dim * intermediate_dim` f32 values. The matrix orientation is
/// whatever the `naive_matvec_f32` kernel expects (not visible here).
///
/// `Clone`/`Copy` are free because the struct only holds shared references.
#[derive(Clone, Copy)]
pub struct ExpertWeights<'a> {
    /// Gate projection; multiplied against the input vector.
    pub gate_proj: &'a MlxBuffer,
    /// Up projection; multiplied against the input vector.
    pub up_proj: &'a MlxBuffer,
    /// Down projection; multiplied against the gated hidden vector.
    pub down_proj: &'a MlxBuffer,
}
/// Single-field kernel parameter block carrying an element count.
///
/// Passed by value as raw bytes (`as_bytes` + `KernelArg::Bytes`) to the
/// `fused_gelu_mul`, `moe_swiglu_fused` and `fused_gelu_mul_bf16` kernels.
/// `#[repr(C)]` + `Pod` give it a well-defined byte layout; presumably it must
/// match the shader-side struct field for field (TODO confirm against the
/// Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedGeluMulParams {
/// Number of elements to process; callers narrow it from `usize` with `as u32`.
n_elements: u32,
}
/// Kernel parameter block for the `moe_accumulate` kernel.
///
/// Passed as raw bytes via `as_bytes`. With `#[repr(C)]` the `u32` count sits
/// at byte offset 0 and the `f32` weight at byte offset 4; presumably this
/// mirrors the shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeAccumParams {
/// Number of elements to accumulate (narrowed from `usize` by callers).
n_elements: u32,
/// Scalar routing weight applied to the expert output.
routing_weight: f32,
}
/// Kernel parameter block for the `zero_buffer` kernel.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuZeroParams {
/// Number of f32 elements to clear (narrowed from `usize` by callers).
n_elements: u32,
}
/// Kernel parameter block for the `naive_matvec_f32` kernel.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
// In `moe_dispatch`: `m` is always 1, `k` is the length of the vector being
// multiplied and `n` is the length of the produced vector (the grid size).
m: u32, k: u32, n: u32, }
/// Encodes a complete single-token mixture-of-experts feed-forward pass.
///
/// `output` is first cleared with the `zero_buffer` kernel. Then, for every
/// expert `i` whose routing weight satisfies `|routing_weights[i]| >= 1e-10`,
/// these kernels are encoded (all buffers are read as `f32`):
///
/// 1. `naive_matvec_f32(gate_proj_i, input) -> scratch_gate`
/// 2. `naive_matvec_f32(up_proj_i, input) -> scratch_up`
/// 3. `fused_gelu_mul(scratch_gate, scratch_up) -> scratch_hidden`
/// 4. `naive_matvec_f32(down_proj_i, scratch_hidden) -> scratch_expert`
/// 5. `moe_accumulate(output, scratch_expert, routing_weights[i])`
///
/// Presumably this leaves `output` holding the routing-weighted sum of the
/// expert outputs; the exact math lives in the shaders (TODO confirm).
///
/// The same `scratch_*` buffers are reused for every expert. Memory barriers
/// are inserted between dependent steps and before each expert. No barrier
/// is emitted after the final accumulation, so a later dispatch that reads
/// `output` on the same encoder presumably needs its own. Commands are only
/// recorded into `encoder`; the caller commits it.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if a dimension is zero, if
/// `expert_weights` / `routing_weights` do not have `params.n_selected`
/// entries, or if any buffer is smaller than the shapes require. Pipeline
/// lookup errors from [`KernelRegistry::get_pipeline`] are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let input_bytes = params.input_dim * f32_size;
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }
    let intermediate_bytes = params.intermediate_dim * f32_size;
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }

    // gate/up map input_dim -> intermediate_dim and down maps back, so all
    // three projection matrices hold the same number of f32 elements.
    let weight_bytes = params.input_dim * params.intermediate_dim * f32_size;
    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < weight_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i,
                weight_bytes,
                ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < weight_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i,
                weight_bytes,
                ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < weight_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i,
                weight_bytes,
                ew.down_proj.byte_len()
            )));
        }
    }

    // `get_pipeline` borrows the registry mutably, so two returned references
    // cannot be alive at once. Cloning each `ComputePipelineState` (an
    // Objective-C retain in metal-rs) yields independent owned handles. This
    // avoids holding raw pointers into the registry and relying on it never
    // moving its entries.
    let matvec_pipeline: metal::ComputePipelineState =
        registry.get_pipeline("naive_matvec_f32", device)?.clone();
    let gelu_mul_pipeline: metal::ComputePipelineState =
        registry.get_pipeline("fused_gelu_mul", device)?.clone();
    let accum_pipeline: metal::ComputePipelineState =
        registry.get_pipeline("moe_accumulate", device)?.clone();
    let zero_pipeline: metal::ComputePipelineState =
        registry.get_pipeline("zero_buffer", device)?.clone();

    // 1-D launch geometry: `n` threads, threadgroups of at most 256.
    let grid = |n: usize| MTLSize::new(n as u64, 1, 1);
    let threadgroup = |n: usize| MTLSize::new(std::cmp::min(256, n as u64), 1, 1);

    let input_dim = params.input_dim as u32;
    let intermediate_dim = params.intermediate_dim as u32;

    let zero_params = GpuZeroParams {
        n_elements: input_dim,
    };
    encode_with_args(
        encoder,
        &zero_pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        grid(params.input_dim),
        threadgroup(params.input_dim),
    );

    // These shapes are identical for every expert, so build them once.
    let in_to_hidden = GpuMatmulParams {
        m: 1,
        k: input_dim,
        n: intermediate_dim,
    };
    let gelu_params = GpuFusedGeluMulParams {
        n_elements: intermediate_dim,
    };
    let hidden_to_out = GpuMatmulParams {
        m: 1,
        k: intermediate_dim,
        n: input_dim,
    };

    // Lengths were validated above, so zipping visits every expert.
    for (ew, &w) in expert_weights.iter().zip(routing_weights) {
        // A negligible routing weight cannot change `output`; skip the work.
        if w.abs() < 1e-10 {
            continue;
        }
        // Order this expert after the zero fill / previous accumulation into
        // `output`, and after the previous expert's reads of the shared
        // scratch buffers that are about to be overwritten.
        encoder.memory_barrier();

        // Gate and up projections read only `input` and write different
        // scratch buffers, so no barrier is needed between them.
        encode_with_args(
            encoder,
            &matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&in_to_hidden))),
            ],
            grid(params.intermediate_dim),
            threadgroup(params.intermediate_dim),
        );
        encode_with_args(
            encoder,
            &matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&in_to_hidden))),
            ],
            grid(params.intermediate_dim),
            threadgroup(params.intermediate_dim),
        );

        encoder.memory_barrier();
        encode_with_args(
            encoder,
            &gelu_mul_pipeline,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            grid(params.intermediate_dim),
            threadgroup(params.intermediate_dim),
        );

        encoder.memory_barrier();
        encode_with_args(
            encoder,
            &matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&hidden_to_out))),
            ],
            grid(params.input_dim),
            threadgroup(params.input_dim),
        );

        encoder.memory_barrier();
        let accum_params = GpuMoeAccumParams {
            n_elements: input_dim,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            &accum_pipeline,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            grid(params.input_dim),
            threadgroup(params.input_dim),
        );
    }
    Ok(())
}
/// Encodes the `moe_swiglu_fused` kernel over `n_elements` outputs.
///
/// `gate_up` must hold at least `2 * n_elements` f32 values and `output` at
/// least `n_elements` f32 values. How the gate and up halves are arranged
/// inside `gate_up` is defined by the shader (not visible here; TODO
/// document). The grid has `n_elements` threads with threadgroups of at most
/// 256.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero or either buffer is
/// too small; pipeline lookup errors are propagated.
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let gu_required = 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required,
            gate_up.byte_len()
        )));
    }
    let out_required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    // The kernel only needs an element count, so it reuses the single-field
    // parameter block shared with `fused_gelu_mul`.
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Encodes the `zero_buffer` kernel over the first `n_elements` f32 values
/// of `output`.
///
/// Only the command is recorded. A later dispatch that reads `output` on the
/// same encoder presumably needs a memory barrier first.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero or `output` holds
/// fewer than `n_elements` f32 values; pipeline lookup errors are propagated.
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            required,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    let params = GpuZeroParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Encodes the `moe_accumulate` kernel over `n_elements` f32 values.
///
/// The kernel presumably performs
/// `accumulator[i] += routing_weight * expert_output[i]`. This is inferred from
/// the kernel name and how `moe_dispatch` uses it; confirm against the shader.
/// Both buffers are bound at offset 0.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero or either buffer
/// holds fewer than `n_elements` f32 values; pipeline lookup errors are
/// propagated.
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            required,
            accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            required,
            expert_output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Encodes the `moe_swiglu_batch` kernel for `top_k` expert rows of one token.
///
/// `gate_up` must hold at least `top_k * 2 * intermediate` f32 values and
/// `output` at least `top_k * intermediate`. The grid is
/// `intermediate × top_k` (threadgroups of up to 256 along the first axis).
/// The two dimensions are passed to the kernel as separate native-endian
/// `u32` scalars at argument indices 2 and 3, not as a struct.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if either dimension is zero or a buffer is
/// too small; pipeline lookup errors are propagated.
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if top_k == 0 || intermediate == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }

    let f32_bytes = std::mem::size_of::<f32>();
    let out_elems = top_k * intermediate;
    // gate_up carries two values per output element.
    let need_gate_up = 2 * out_elems * f32_bytes;
    let have_gate_up = gate_up.byte_len();
    if have_gate_up < need_gate_up {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            need_gate_up, have_gate_up
        )));
    }
    let need_output = out_elems * f32_bytes;
    let have_output = output.byte_len();
    if have_output < need_output {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            need_output, have_output
        )));
    }

    let kernel = registry.get_pipeline("moe_swiglu_batch", device)?;
    let intermediate_arg = (intermediate as u32).to_ne_bytes();
    let top_k_arg = (top_k as u32).to_ne_bytes();
    let cols = intermediate as u64;
    let rows = top_k as u64;
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&intermediate_arg)),
            (3, KernelArg::Bytes(&top_k_arg)),
        ],
        MTLSize::new(cols, rows, 1),
        MTLSize::new(cols.min(256), 1, 1),
    );
    Ok(())
}
/// Kernel parameter block for the `moe_weighted_sum` kernel.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
/// Length of each expert output row / of the result vector.
hidden_size: u32,
/// Number of expert rows (and weights) being combined.
top_k: u32,
}
/// Encodes the `moe_weighted_sum` kernel, which combines `top_k` expert output
/// rows into a single vector.
///
/// Buffer requirements (f32 elements):
/// * `expert_outputs`: at least `top_k * hidden_size`,
/// * `weights`: at least `top_k`,
/// * `output`: at least `hidden_size`.
///
/// Presumably `output[d] = Σ_k weights[k] * expert_outputs[k * hidden_size + d]`
/// (TODO confirm against the shader). The grid has `hidden_size` threads with
/// threadgroups of at most 256.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if a dimension is zero or a buffer is too
/// small; pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let expert_required = top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required,
            expert_outputs.byte_len()
        )));
    }
    let weights_required = top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_required,
            weights.byte_len()
        )));
    }
    let out_required = hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_required,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    let params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(hidden_size as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}
/// Kernel parameter block for the `moe_gather_topk_weights` kernel.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
/// Number of router entries in the probability / index / scale inputs.
n_experts: u32,
/// Number of experts to emit (the encoder caps this at 8).
top_k: u32,
}
/// Encodes the `moe_gather_topk_weights` kernel. It turns router outputs into
/// `top_k` expert ids and routing weights.
///
/// Inputs, each with at least `n_experts` entries:
/// * `softmax_probs`: f32 router probabilities,
/// * `sorted_indices`: u32 expert indices,
/// * `per_expert_scale`: f32 per-expert scale factors.
///
/// Outputs, each with at least `top_k` entries: `out_expert_ids` (u32) and
/// `out_weights` (f32). Presumably the kernel takes the first `top_k` entries
/// of `sorted_indices` and combines the matching probabilities with
/// `per_expert_scale` (TODO confirm against the shader).
///
/// The kernel is launched as a single thread (1×1×1 grid and threadgroup).
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] in any of these cases:
/// * `n_experts` or `top_k` is zero,
/// * `top_k > n_experts`,
/// * `top_k > 8` (the shader uses a fixed-size array),
/// * any buffer is too small.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let u32_size = std::mem::size_of::<u32>();
    // Size errors name the function, the buffer, and both byte counts, like
    // the other encoders in this module. The old "<buffer> too small" text
    // is kept as a substring.
    let require = |what: &str, have: usize, need: usize| -> Result<()> {
        if have < need {
            return Err(MlxError::InvalidArgument(format!(
                "moe_gather_topk_weights: {} too small: need {} bytes, have {}",
                what, need, have
            )));
        }
        Ok(())
    };
    require("softmax_probs", softmax_probs.byte_len(), n_experts * f32_size)?;
    require("sorted_indices", sorted_indices.byte_len(), n_experts * u32_size)?;
    require("per_expert_scale", per_expert_scale.byte_len(), n_experts * f32_size)?;
    require("out_expert_ids", out_expert_ids.byte_len(), top_k * u32_size)?;
    require("out_weights", out_weights.byte_len(), top_k * f32_size)?;

    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    let params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        // Single-thread launch: all top_k outputs are produced by one thread.
        MTLSize::new(1, 1, 1),
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}
/// Offset variant of `moe_swiglu_fused_encode`. It binds `gate_up` and
/// `output` starting at the given byte offsets, so a slice of a larger buffer
/// can be processed in place.
///
/// Requirements:
/// * `gate_up` holds at least `2 * n_elements` f32 values after `gu_byte_offset`,
/// * `output` holds at least `n_elements` f32 values after `out_byte_offset`,
/// * both offsets are multiples of 4 bytes, because Metal requires device-buffer
///   offsets to be aligned to the element type the kernel reads.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero, an offset is
/// misaligned, or a buffer is too small; pipeline lookup errors are
/// propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    if gu_byte_offset % f32_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gu_byte_offset ({}) must be a multiple of {} bytes",
            gu_byte_offset, f32_size
        )));
    }
    if out_byte_offset % f32_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: out_byte_offset ({}) must be a multiple of {} bytes",
            out_byte_offset, f32_size
        )));
    }
    // `saturating_add`: an absurd caller-supplied offset must fail the size
    // check rather than wrap around and slip past it.
    let gu_required = gu_byte_offset.saturating_add(2 * n_elements * f32_size);
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_required,
            gu_byte_offset,
            gate_up.byte_len()
        )));
    }
    let out_required = out_byte_offset.saturating_add(n_elements * f32_size);
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_required,
            out_byte_offset,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Offset variant of `moe_accumulate_encode`: reads `expert_output` starting
/// at `src_byte_offset` while `accumulator` is bound at offset 0.
///
/// Presumably performs `accumulator[i] += routing_weight * src[i]` for
/// `i < n_elements` (TODO confirm against the shader).
///
/// Requirements:
/// * `accumulator` holds at least `n_elements` f32 values,
/// * `expert_output` holds at least `n_elements` f32 values after `src_byte_offset`,
/// * `src_byte_offset` is a multiple of 4 bytes, because Metal requires
///   device-buffer offsets to be aligned to the element type the kernel reads.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero, the offset is
/// misaligned, or a buffer is too small; pipeline lookup errors are
/// propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    if src_byte_offset % f32_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: src_byte_offset ({}) must be a multiple of {} bytes",
            src_byte_offset, f32_size
        )));
    }
    let required = n_elements * f32_size;
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            required,
            accumulator.byte_len()
        )));
    }
    // `saturating_add`: an absurd caller-supplied offset must fail the size
    // check rather than wrap around and slip past it.
    let src_required = src_byte_offset.saturating_add(required);
    if expert_output.byte_len() < src_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_required,
            src_byte_offset,
            expert_output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Encodes the `fused_gelu_mul_bf16` kernel over three bf16 buffers.
///
/// `gate_out`, `up_out` and `output` must each hold at least `n_elements`
/// bf16 values (2 bytes each). Presumably the kernel computes
/// `output = gelu(gate_out) * up_out` element-wise (inferred from the name;
/// TODO confirm against the shader). The grid has `n_elements` threads with
/// threadgroups of at most 256.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if `n_elements` is zero or any buffer is too
/// small; pipeline lookup errors are propagated.
pub fn fused_gelu_mul_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_out: &MlxBuffer,
    up_out: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    /// Size of one bf16 element in bytes.
    const BF16_BYTES: usize = 2;

    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_gelu_mul_bf16_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * BF16_BYTES;
    if gate_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: gate_out too small: need {} bytes, have {}",
            required,
            gate_out.byte_len()
        )));
    }
    if up_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: up_out too small: need {} bytes, have {}",
            required,
            up_out.byte_len()
        )));
    }
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: output too small: need {} bytes, have {}",
            required,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("fused_gelu_mul_bf16", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_out)),
            (1, KernelArg::Buffer(up_out)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}
/// Kernel parameter block for the `moe_swiglu_seq` and `moe_swiglu_seq_bf16`
/// kernels.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
/// Output elements per (token, expert) pair; also the grid's x extent.
intermediate: u32,
/// Experts per token; the grid's y extent.
top_k: u32,
/// Number of tokens; the grid's z extent.
n_tokens: u32,
}
/// Encodes the `moe_swiglu_seq` kernel: the multi-token SwiGLU step over
/// routed experts, on f32 data.
///
/// Buffer requirements (f32 elements):
/// * `gate_up`: at least `2 * n_tokens * top_k * intermediate`,
/// * `output`: at least `n_tokens * top_k * intermediate`.
///
/// Launch geometry is a 3-D grid `intermediate × top_k × n_tokens`, with
/// threadgroups of up to 256 threads along the first axis. How gate and up
/// values are arranged inside `gate_up` is defined by the shader (not
/// visible here).
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small; pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if [intermediate, top_k, n_tokens].contains(&0) {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }

    let elem = std::mem::size_of::<f32>();
    // Output lanes across all (token, expert) pairs; gate_up needs twice as many.
    let lanes = n_tokens * top_k * intermediate;
    let ensure_len = |name: &str, have: usize, need: usize| -> Result<()> {
        if have >= need {
            Ok(())
        } else {
            Err(MlxError::InvalidArgument(format!(
                "moe_swiglu_seq_encode: {} too small: need {} bytes, have {}",
                name, need, have
            )))
        }
    };
    ensure_len("gate_up", gate_up.byte_len(), 2 * lanes * elem)?;
    ensure_len("output", output.byte_len(), lanes * elem)?;

    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    let shape = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let (x, y, z) = (intermediate as u64, top_k as u64, n_tokens as u64);
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&shape))),
        ],
        MTLSize::new(x, y, z),
        MTLSize::new(x.min(256), 1, 1),
    );
    Ok(())
}
/// Kernel parameter block shared by `moe_weighted_sum_seq`,
/// `moe_weighted_sum_seq_bf16_input` and the two
/// `moe_weighted_sum_seq_backward_*_f32` kernels.
///
/// Passed as raw bytes via `as_bytes`; the layout presumably mirrors the
/// shader-side struct (TODO confirm against the Metal source).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
/// Length of each expert output row and of each output token row.
hidden_size: u32,
/// Experts (and routing weights) per token.
top_k: u32,
/// Number of tokens.
n_tokens: u32,
}
/// Encodes the `moe_weighted_sum_seq` kernel, the multi-token weighted
/// combination of expert outputs (all f32).
///
/// Buffer requirements (f32 elements):
/// * `expert_outputs`: at least `n_tokens * top_k * hidden_size`,
/// * `weights`: at least `n_tokens * top_k`,
/// * `output`: at least `n_tokens * hidden_size`.
///
/// The grid is `hidden_size × n_tokens`, with threadgroups of up to 256
/// along the first axis. A CPU reference of the presumed math appears in
/// this module's tests (`cpu_forward`).
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small (checked in the order expert_outputs, weights, output). Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if [hidden_size, top_k, n_tokens].contains(&0) {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }

    let f32_bytes = std::mem::size_of::<f32>();
    let routed_rows = n_tokens * top_k;
    // (name, bytes available, bytes required), validated in order.
    let checks: [(&str, usize, usize); 3] = [
        (
            "expert_outputs",
            expert_outputs.byte_len(),
            routed_rows * hidden_size * f32_bytes,
        ),
        ("weights", weights.byte_len(), routed_rows * f32_bytes),
        ("output", output.byte_len(), n_tokens * hidden_size * f32_bytes),
    ];
    if let Some(&(name, have, need)) = checks.iter().find(|&&(_, have, need)| have < need) {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: {} too small: need {} bytes, have {}",
            name, need, have
        )));
    }

    let kernel = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    let shape = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let width = hidden_size as u64;
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&shape))),
        ],
        MTLSize::new(width, n_tokens as u64, 1),
        MTLSize::new(width.min(256), 1, 1),
    );
    Ok(())
}
/// bf16 counterpart of `moe_swiglu_seq_encode`: encodes the
/// `moe_swiglu_seq_bf16` kernel.
///
/// Buffer requirements (bf16 elements, 2 bytes each):
/// * `gate_up`: at least `2 * n_tokens * top_k * intermediate`,
/// * `output`: at least `n_tokens * top_k * intermediate`.
///
/// The grid is `intermediate × top_k × n_tokens`, with threadgroups of up
/// to 256 along the first axis.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small; pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    let any_zero = intermediate == 0 || top_k == 0 || n_tokens == 0;
    if any_zero {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_bf16_encode: all dims must be > 0".into(),
        ));
    }

    // bf16 elements are two bytes wide.
    let bf16_bytes = 2usize;
    let out_elems = n_tokens * top_k * intermediate;
    let check = |label: &str, have: usize, need: usize| -> Result<()> {
        if have < need {
            return Err(MlxError::InvalidArgument(format!(
                "moe_swiglu_seq_bf16_encode: {} too small: need {} bytes, have {}",
                label, need, have
            )));
        }
        Ok(())
    };
    check("gate_up", gate_up.byte_len(), 2 * out_elems * bf16_bytes)?;
    check("output", output.byte_len(), out_elems * bf16_bytes)?;

    let kernel = registry.get_pipeline("moe_swiglu_seq_bf16", device)?;
    let shape = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let grid = MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64);
    let group = MTLSize::new((intermediate as u64).min(256), 1, 1);
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&shape))),
        ],
        grid,
        group,
    );
    Ok(())
}
/// Encodes the `moe_weighted_sum_seq_bf16_input` kernel. It reads bf16 expert
/// outputs and f32 routing weights, and writes an f32 result.
///
/// Buffer requirements:
/// * `expert_outputs`: at least `n_tokens * top_k * hidden_size` bf16 (2 bytes each),
/// * `weights`: at least `n_tokens * top_k` f32,
/// * `output`: at least `n_tokens * hidden_size` f32.
///
/// The grid is `hidden_size × n_tokens`, with threadgroups of up to 256
/// along the first axis.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small (checked in the order expert_outputs, weights, output). Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_bf16_input_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if [hidden_size, top_k, n_tokens].contains(&0) {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_bf16_input_encode: all dims must be > 0".into(),
        ));
    }

    const BF16: usize = 2;
    let f32_bytes = std::mem::size_of::<f32>();
    let routed_rows = n_tokens * top_k;
    // (name, bytes available, bytes required), validated in order.
    let checks: [(&str, usize, usize); 3] = [
        (
            "expert_outputs",
            expert_outputs.byte_len(),
            routed_rows * hidden_size * BF16,
        ),
        ("weights", weights.byte_len(), routed_rows * f32_bytes),
        ("output", output.byte_len(), n_tokens * hidden_size * f32_bytes),
    ];
    for &(name, have, need) in checks.iter() {
        if have < need {
            return Err(MlxError::InvalidArgument(format!(
                "moe_weighted_sum_seq_bf16_input_encode: {} too small: need {} bytes, have {}",
                name, need, have
            )));
        }
    }

    let kernel = registry.get_pipeline("moe_weighted_sum_seq_bf16_input", device)?;
    let shape = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let width = hidden_size as u64;
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&shape))),
        ],
        MTLSize::new(width, n_tokens as u64, 1),
        MTLSize::new(width.min(256), 1, 1),
    );
    Ok(())
}
/// Encodes `moe_weighted_sum_seq_backward_outputs_f32`, the backward pass of
/// `moe_weighted_sum_seq` with respect to the expert outputs.
///
/// Buffer requirements (f32 elements):
/// * `weights`: at least `n_tokens * top_k`,
/// * `d_output`: at least `n_tokens * hidden_size`,
/// * `d_expert_outputs` (written): at least `n_tokens * top_k * hidden_size`.
///
/// The grid is `hidden_size × top_k × n_tokens`, with threadgroups of up to
/// 256 along the first axis. This module's tests check the result against
/// a finite-difference gradient of a CPU forward reference.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small (checked in the order weights, d_output, d_expert_outputs). Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_backward_outputs_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    weights: &MlxBuffer,
    d_output: &MlxBuffer,
    d_expert_outputs: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_backward_outputs_encode: all dims must be > 0".into(),
        ));
    }

    let elem = std::mem::size_of::<f32>();
    let require = |what: &str, have: usize, need: usize| -> Result<()> {
        if have >= need {
            return Ok(());
        }
        Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: {} too small: need {} bytes, have {}",
            what, need, have
        )))
    };
    require("weights", weights.byte_len(), n_tokens * top_k * elem)?;
    require("d_output", d_output.byte_len(), n_tokens * hidden_size * elem)?;
    require(
        "d_expert_outputs",
        d_expert_outputs.byte_len(),
        n_tokens * top_k * hidden_size * elem,
    )?;

    let kernel = registry.get_pipeline("moe_weighted_sum_seq_backward_outputs_f32", device)?;
    let shape = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let width = hidden_size as u64;
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(weights)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_expert_outputs)),
            (3, KernelArg::Bytes(as_bytes(&shape))),
        ],
        MTLSize::new(width, top_k as u64, n_tokens as u64),
        MTLSize::new(width.min(256), 1, 1),
    );
    Ok(())
}
/// Encodes `moe_weighted_sum_seq_backward_weights_f32`, the backward pass of
/// `moe_weighted_sum_seq` with respect to the routing weights.
///
/// Buffer requirements (f32 elements):
/// * `expert_outputs`: at least `n_tokens * top_k * hidden_size`,
/// * `d_output`: at least `n_tokens * hidden_size`,
/// * `d_weights` (written): at least `n_tokens * top_k`.
///
/// The launch is 1-D, with `n_tokens * top_k` threads (threadgroups of up to
/// 256), presumably one per routing weight.
///
/// # Errors
///
/// [`MlxError::InvalidArgument`] if any dimension is zero or a buffer is too
/// small (checked in the order expert_outputs, d_output, d_weights). Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_backward_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    d_output: &MlxBuffer,
    d_weights: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if [hidden_size, top_k, n_tokens].contains(&0) {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_backward_weights_encode: all dims must be > 0".into(),
        ));
    }

    let elem = std::mem::size_of::<f32>();
    let routed = n_tokens * top_k;
    // (name, bytes available, bytes required), validated in order.
    let checks: [(&str, usize, usize); 3] = [
        ("expert_outputs", expert_outputs.byte_len(), routed * hidden_size * elem),
        ("d_output", d_output.byte_len(), n_tokens * hidden_size * elem),
        ("d_weights", d_weights.byte_len(), routed * elem),
    ];
    if let Some(&(name, have, need)) = checks.iter().find(|&&(_, have, need)| have < need) {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_weights_encode: {} too small: need {} bytes, have {}",
            name, need, have
        )));
    }

    let kernel = registry.get_pipeline("moe_weighted_sum_seq_backward_weights_f32", device)?;
    let shape = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };
    let threads = routed as u64;
    encode_with_args(
        encoder,
        kernel,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_weights)),
            (3, KernelArg::Bytes(as_bytes(&shape))),
        ],
        MTLSize::new(threads, 1, 1),
        MTLSize::new(threads.min(256), 1, 1),
    );
    Ok(())
}
/// GPU tests for `moe_weighted_sum_seq_backward_outputs_encode` and
/// `moe_weighted_sum_seq_backward_weights_encode`.
///
/// The CPU reference (`cpu_forward`) computes
/// `out[t, d] = Σ_k expert_outputs[t, k, d] * weights[t, k]`, with
/// `expert_outputs` indexed as `(t * top_k + k) * hidden + d` and `weights`
/// as `t * top_k + k`. Every test needs a Metal device.
#[cfg(test)]
mod backward_weighted_sum_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    /// Allocate an f32 buffer of `n` elements and zero-fill it.
    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocate an f32 buffer of `vals.len()` elements and copy `vals` into it.
    fn upload_f32(d: &MlxDevice, vals: &[f32]) -> MlxBuffer {
        let mut b = alloc_f32(d, vals.len());
        b.as_mut_slice::<f32>().unwrap()[..vals.len()].copy_from_slice(vals);
        b
    }

    /// CPU reference forward pass; accumulates in f64 and rounds each output to f32.
    fn cpu_forward(
        expert_outs: &[f32],
        weights: &[f32],
        n_tokens: usize,
        top_k: usize,
        hidden: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; n_tokens * hidden];
        for t in 0..n_tokens {
            for d in 0..hidden {
                let mut sum = 0.0f64;
                for k in 0..top_k {
                    let w_ix = t * top_k + k;
                    let exp_ix = w_ix * hidden + d;
                    sum += f64::from(expert_outs[exp_ix]) * f64::from(weights[w_ix]);
                }
                out[t * hidden + d] = sum as f32;
            }
        }
        out
    }

    /// Scalar probe loss `L(y) = Σ_i y[i] * seed[i]`, evaluated in f64.
    ///
    /// `∂L/∂y = seed`, so feeding `seed` to the kernels as `d_output` makes
    /// their results directly comparable to finite differences of `L`.
    fn probe_loss(y: &[f32], seed: &[f32]) -> f64 {
        y.iter()
            .zip(seed)
            .map(|(&a, &b)| f64::from(a) * f64::from(b))
            .sum()
    }

    /// Central difference `(L(x + h·e_idx) - L(x - h·e_idx)) / 2h`, where
    /// `L = probe_loss(forward(x), seed)` and `x = base`.
    fn central_difference(
        base: &[f32],
        idx: usize,
        h: f32,
        seed: &[f32],
        forward: impl Fn(&[f32]) -> Vec<f32>,
    ) -> f64 {
        let mut plus = base.to_vec();
        plus[idx] += h;
        let mut minus = base.to_vec();
        minus[idx] -= h;
        let lp = probe_loss(&forward(plus.as_slice()), seed);
        let lm = probe_loss(&forward(minus.as_slice()), seed);
        (lp - lm) / (2.0 * h as f64)
    }

    #[test]
    fn backward_outputs_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 3usize;
        let top_k = 4usize;
        let hidden = 7usize;
        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.1 + (i as f32) * 0.013 - (i as f32 * 0.004).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.2 + (i as f32) * 0.07)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.3 + (i as f32) * 0.05 - (i as f32 * 0.011).cos())
            .collect();
        // The outputs-gradient encoder takes only the routing weights and the
        // upstream gradient as inputs, so `expert_outs` stays CPU-side as
        // finite-difference input and is never uploaded.
        let w_buf = upload_f32(&device, &weights);
        let dout_buf = upload_f32(&device, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dexp_buf.as_slice::<f32>().unwrap().to_vec();
        let h: f32 = 1e-3;
        for idx in 0..(n_tokens * top_k * hidden) {
            let fd = central_difference(&expert_outs, idx, h, &d_output_seed, |e: &[f32]| {
                cpu_forward(e, &weights, n_tokens, top_k, hidden)
            });
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_expert_outputs[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    #[test]
    fn backward_weights_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 4usize;
        let top_k = 3usize;
        let hidden = 11usize;
        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.05 + (i as f32) * 0.017 + (i as f32 * 0.009).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.1 + (i as f32) * 0.05)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.2 + (i as f32) * 0.03 - (i as f32 * 0.013).cos())
            .collect();
        let exp_buf = upload_f32(&device, &expert_outs);
        let dout_buf = upload_f32(&device, &d_output_seed);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);
        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dw_buf.as_slice::<f32>().unwrap().to_vec();
        let h: f32 = 1e-3;
        for idx in 0..(n_tokens * top_k) {
            let fd = central_difference(&weights, idx, h, &d_output_seed, |w: &[f32]| {
                cpu_forward(&expert_outs, w, n_tokens, top_k, hidden)
            });
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_weights[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    /// Runs both backward kernels back-to-back on one command encoder and
    /// compares them against a hand-written CPU oracle:
    /// `d_expert_outputs[t, k, d] = weights[t, k] * d_output[t, d]` and
    /// `d_weights[t, k] = Σ_d expert_outputs[t, k, d] * d_output[t, d]`.
    ///
    /// Despite the name, no forward kernel is dispatched here. The name is kept
    /// so existing test filters keep matching.
    #[test]
    fn forward_then_backward_round_trip_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 2usize;
        let top_k = 2usize;
        let hidden = 5usize;
        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| (i as f32) * 0.1 - 0.2)
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.3 + (i as f32) * 0.1)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 1.0 - (i as f32) * 0.05)
            .collect();
        let mut cpu_dexp = vec![0.0f32; n_tokens * top_k * hidden];
        let mut cpu_dw = vec![0.0f32; n_tokens * top_k];
        for t in 0..n_tokens {
            for k in 0..top_k {
                for d in 0..hidden {
                    let exp_ix = (t * top_k + k) * hidden + d;
                    let w_ix = t * top_k + k;
                    let dout_ix = t * hidden + d;
                    cpu_dexp[exp_ix] = weights[w_ix] * d_output_seed[dout_ix];
                    cpu_dw[w_ix] += expert_outs[exp_ix] * d_output_seed[dout_ix];
                }
            }
        }
        let exp_buf = upload_f32(&device, &expert_outs);
        let w_buf = upload_f32(&device, &weights);
        let dout_buf = upload_f32(&device, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);
        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let gpu_dexp = dexp_buf.as_slice::<f32>().unwrap().to_vec();
        let gpu_dw = dw_buf.as_slice::<f32>().unwrap().to_vec();
        for i in 0..gpu_dexp.len() {
            assert!(
                (gpu_dexp[i] - cpu_dexp[i]).abs() < 1e-5,
                "d_expert_outputs[{}]: gpu={} cpu={}",
                i, gpu_dexp[i], cpu_dexp[i]
            );
        }
        for i in 0..gpu_dw.len() {
            assert!(
                (gpu_dw[i] - cpu_dw[i]).abs() < 1e-5,
                "d_weights[{}]: gpu={} cpu={}",
                i, gpu_dw[i], cpu_dw[i]
            );
        }
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // 4 f32s = 16 bytes, below what either call below needs.
        let too_small = alloc_f32(&device, 4);
        let any = alloc_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        // First buffer is the routing-weights buffer. With top_k=3 and
        // n_tokens=4 this presumably needs 4 * 3 * 4 = 48 bytes.
        let res = moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 7, 3, 4,
        );
        assert!(res.is_err());
        // First buffer is expert_outputs: needs 4 * 3 * 11 * 4 = 528 bytes.
        let res2 = moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 11, 3, 4,
        );
        assert!(res2.is_err());
    }
}
/// Encode the backward pass of the gated activation applied independently to
/// every `(token, slot)` pair of an `[n_tokens, top_k]` MoE grid.
///
/// Buffer layout (contiguous f32):
/// - `gate_up`: `2 * intermediate` values per slot. The gate half comes
///   first, then the up half. Slot `(t, k)` starts at element
///   `(t * top_k + k) * 2 * intermediate`.
/// - `d_output`: `intermediate` upstream-gradient values per slot.
/// - `d_gate_up`: written by the `moe_swiglu_seq_backward_f32` kernel, same
///   layout as `gate_up`.
///
/// NOTE(review): despite the "swiglu" name, the finite-difference and
/// asymptotic tests in `backward_swiglu_seq_tests` validate this kernel
/// against `gelu_tanh(gate) * up`, not `silu(gate) * up`. Confirm against the
/// shader.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] in any of these cases:
/// - a dimension is zero;
/// - a dimension does not fit the `u32` fields of the kernel parameter block;
/// - a required byte size overflows `usize`;
/// - a buffer is smaller than required.
///
/// Errors from `KernelRegistry::get_pipeline` are propagated.
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_backward_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    d_output: &MlxBuffer,
    d_gate_up: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    /// Narrow a dimension to the kernel's `u32` parameter, rejecting values
    /// that a plain `as u32` cast would silently truncate.
    fn dim_to_u32(name: &str, value: usize) -> Result<u32> {
        if value > u32::MAX as usize {
            return Err(MlxError::InvalidArgument(format!(
                "moe_swiglu_seq_backward_encode: {} ({}) exceeds u32::MAX",
                name, value
            )));
        }
        // Lossless: range checked above.
        Ok(value as u32)
    }

    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_backward_encode: all dims must be > 0".into(),
        ));
    }
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: dim_to_u32("intermediate", intermediate)?,
        top_k: dim_to_u32("top_k", top_k)?,
        n_tokens: dim_to_u32("n_tokens", n_tokens)?,
    };

    // Checked arithmetic: a wrapped product would let undersized buffers pass
    // the size checks below, and the kernel would then access memory past
    // their end.
    let overflow = || {
        MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: byte size overflows usize (intermediate={}, top_k={}, n_tokens={})",
            intermediate, top_k, n_tokens
        ))
    };
    let f32_size = std::mem::size_of::<f32>();
    // d_output: `intermediate` values per (token, slot).
    let dout_required = n_tokens
        .checked_mul(top_k)
        .and_then(|n| n.checked_mul(intermediate))
        .and_then(|n| n.checked_mul(f32_size))
        .ok_or_else(overflow)?;
    // gate_up / d_gate_up: a gate half and an up half per (token, slot).
    let gu_required = dout_required.checked_mul(2).ok_or_else(overflow)?;

    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    if d_gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_gate_up too small: need {} bytes, have {}",
            gu_required, d_gate_up.byte_len()
        )));
    }
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_output too small: need {} bytes, have {}",
            dout_required, d_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq_backward_f32", device)?;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_gate_up)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        // Grid spans (intermediate, top_k, n_tokens). Threadgroups are at most
        // 256 wide along the intermediate axis.
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}
/// GPU tests for `moe_swiglu_seq_backward_encode`.
///
/// The CPU reference computes `gelu_tanh(gate) * up` per slot, with each slot
/// holding its gate half followed by its up half. These tests therefore pin
/// the kernel to that activation. Every test needs a Metal device.
#[cfg(test)]
mod backward_swiglu_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    /// Zero-filled f32 buffer holding `len` elements.
    fn zeroed_f32(device: &MlxDevice, len: usize) -> MlxBuffer {
        let mut buf = device
            .alloc_buffer(len * 4, DType::F32, vec![len])
            .unwrap();
        buf.as_mut_slice::<f32>().unwrap().fill(0.0);
        buf
    }

    /// f32 buffer sized to `values` and initialised with them.
    fn upload_f32(device: &MlxDevice, values: &[f32]) -> MlxBuffer {
        let mut buf = zeroed_f32(device, values.len());
        buf.as_mut_slice::<f32>().unwrap()[..values.len()].copy_from_slice(values);
        buf
    }

    /// Tanh approximation of GELU, evaluated in f64.
    fn gelu_tanh(x: f64) -> f64 {
        let inner = 0.7978845608 * (x + 0.044715 * x * x * x);
        0.5 * x * (1.0 + inner.tanh())
    }

    /// Reference forward: for each of the `n_tokens * top_k` slots, split the
    /// slot's `2 * intermediate` values into gate/up halves and emit
    /// `gelu_tanh(gate) * up`, rounded to f32.
    fn reference_forward(
        gate_up: &[f32],
        n_tokens: usize,
        top_k: usize,
        intermediate: usize,
    ) -> Vec<f32> {
        gate_up
            .chunks_exact(2 * intermediate)
            .take(n_tokens * top_k)
            .flat_map(|slot| {
                let (gate, up) = slot.split_at(intermediate);
                gate.iter()
                    .zip(up)
                    .map(|(&g, &u)| (gelu_tanh(f64::from(g)) * f64::from(u)) as f32)
            })
            .collect()
    }

    #[test]
    fn backward_finite_difference_falsifier_both_gradients() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let (n_tokens, top_k, intermediate) = (2usize, 3usize, 5usize);
        let gu_len = n_tokens * top_k * 2 * intermediate;
        let dout_len = n_tokens * top_k * intermediate;

        let gate_up: Vec<f32> = (0..gu_len)
            .map(|i| {
                let x = i as f32;
                0.5 + x * 0.07 - (x * 0.013).sin()
            })
            .collect();
        let seed: Vec<f32> = (0..dout_len)
            .map(|i| {
                let x = i as f32;
                0.3 + x * 0.05 - (x * 0.011).cos()
            })
            .collect();

        let gu_buf = upload_f32(&device, &gate_up);
        let dout_buf = upload_f32(&device, &seed);
        let dgu_buf = zeroed_f32(&device, gu_len);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &gu_buf,
            &dout_buf,
            &dgu_buf,
            intermediate,
            top_k,
            n_tokens,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic: Vec<f32> = dgu_buf.as_slice::<f32>().unwrap().to_vec();

        // Probe loss L = <forward(x), seed>, whose gradient w.r.t. the forward
        // output is exactly the `seed` uploaded as d_output.
        let loss = |x: &[f32]| -> f64 {
            reference_forward(x, n_tokens, top_k, intermediate)
                .iter()
                .zip(&seed)
                .map(|(&y, &s)| f64::from(y) * f64::from(s))
                .sum()
        };

        let step: f32 = 1e-3;
        let mut probe = gate_up.clone();
        for (idx, &grad) in analytic[..gu_len].iter().enumerate() {
            let centre = gate_up[idx];
            probe[idx] = centre + step;
            let loss_plus = loss(probe.as_slice());
            probe[idx] = centre - step;
            let loss_minus = loss(probe.as_slice());
            probe[idx] = centre;

            let fd = (loss_plus - loss_minus) / (2.0 * step as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (grad as f64 - fd).abs() < tol,
                "d_gate_up[{}]: analytic={} fd={} (gate_up_value={})",
                idx, grad, fd, centre
            );
        }
    }

    #[test]
    fn backward_canonical_asymptotics_match_expected() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // One token, one slot, three lanes: gate = [0, 10, -10], up = [2, 3, 4].
        let (n_tokens, top_k, intermediate) = (1usize, 1usize, 3usize);
        let gate_up = [0.0f32, 10.0, -10.0, 2.0, 3.0, 4.0];
        let seed = [0.5f32, 1.0, 1.0];

        let gu_buf = upload_f32(&device, &gate_up);
        let dout_buf = upload_f32(&device, &seed);
        let dgu_buf = zeroed_f32(&device, n_tokens * top_k * 2 * intermediate);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &gu_buf,
            &dout_buf,
            &dgu_buf,
            intermediate,
            top_k,
            n_tokens,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        let grads = dgu_buf.as_slice::<f32>().unwrap();
        // ∂gate = d_out * up * gelu'(gate): gelu'(0) = 0.5, gelu'(±10) ≈ 1 / 0.
        assert!((grads[0] - 0.5).abs() < 1e-5, "∂gate0={}", grads[0]);
        assert!((grads[1] - 3.0).abs() < 0.05, "∂gate1={}", grads[1]);
        assert!(grads[2].abs() < 0.05, "∂gate2={}", grads[2]);
        // ∂up = d_out * gelu(gate): gelu(0) = 0, gelu(10) ≈ 10, gelu(-10) ≈ 0.
        assert!(grads[3].abs() < 1e-5, "∂up0={}", grads[3]);
        assert!((grads[4] - 10.0).abs() < 0.05, "∂up1={}", grads[4]);
        assert!(grads[5].abs() < 0.05, "∂up2={}", grads[5]);
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // 16 bytes, versus 2 * 3 * 2 * 5 * 4 = 240 bytes needed for gate_up
        // with intermediate=5, top_k=3, n_tokens=2.
        let undersized = zeroed_f32(&device, 4);
        let roomy = zeroed_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        let outcome = moe_swiglu_seq_backward_encode(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &undersized,
            &roomy,
            &roomy,
            5,
            3,
            2,
        );
        assert!(outcome.is_err());
    }
}