use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};

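/// Shape parameters for a single-token MoE dispatch: the hidden size seen by
/// the experts (`input_dim`), the expert FFN intermediate size
/// (`intermediate_dim`), and the number of experts selected by the router
/// (`n_selected`).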
pub struct MoeDispatchParams {
    pub input_dim: usize,
    pub intermediate_dim: usize,
    pub n_selected: usize,
}

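/// Borrowed weight buffers for one selected expert. Each projection is
/// expected to hold `input_dim * intermediate_dim` f32 values (checked in
/// `moe_dispatch`); the exact element layout is whatever the
/// `naive_matvec_f32` kernel expects.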
pub struct ExpertWeights<'a> {
    pub gate_proj: &'a MlxBuffer,
    pub up_proj: &'a MlxBuffer,
    pub down_proj: &'a MlxBuffer,
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedGeluMulParams {
    n_elements: u32,
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeAccumParams {
    n_elements: u32,
    routing_weight: f32,
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuZeroParams {
    n_elements: u32,
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
    m: u32,
    k: u32,
    n: u32,
}

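/// Encodes a full single-token MoE expert dispatch:
///
/// 1. zero the `output` accumulator,
/// 2. for each selected expert: gate and up mat-vecs, fused GELU-gate * up,
///    down mat-vec into `scratch_expert`,
/// 3. accumulate `routing_weight * scratch_expert` into `output`.
///
/// Scratch buffers are reused across experts, so memory barriers are encoded
/// between dependent stages. A rough usage sketch (the surrounding setup
/// names are illustrative, not part of this module):
///
/// ```ignore
/// let params = MoeDispatchParams { input_dim: 1024, intermediate_dim: 4096, n_selected: 2 };
/// moe_dispatch(
///     &mut encoder, &mut registry, device.metal_device(),
///     &input, &experts, &routing_weights, &output,
///     &scratch_gate, &scratch_up, &scratch_hidden, &scratch_expert,
///     &params,
/// )?;
/// ```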
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }

    let input_bytes = params.input_dim * std::mem::size_of::<f32>();
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }

    let intermediate_bytes = params.intermediate_dim * std::mem::size_of::<f32>();
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }

    let gate_up_bytes = params.input_dim * params.intermediate_dim * std::mem::size_of::<f32>();
    let down_bytes = params.intermediate_dim * params.input_dim * std::mem::size_of::<f32>();

    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < down_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i, down_bytes, ew.down_proj.byte_len()
            )));
        }
    }

    {
        registry.get_pipeline("naive_matvec_f32", device)?;
        registry.get_pipeline("fused_gelu_mul", device)?;
        registry.get_pipeline("moe_accumulate", device)?;
        registry.get_pipeline("zero_buffer", device)?;
    }
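    // The registry hands out pipeline references tied to its own borrow, so
    // four live references cannot be taken directly. The block above caches
    // every pipeline, and the raw-pointer round trip below reborrows the
    // cached entries so all four can be used in the per-expert loop. This
    // assumes the registry does not move or evict cached pipelines while this
    // function runs.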
    let matvec_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("naive_matvec_f32", device)?;
        p as *const _
    };
    let gelu_mul_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("fused_gelu_mul", device)?;
        p as *const _
    };
    let accum_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("moe_accumulate", device)?;
        p as *const _
    };
    let zero_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("zero_buffer", device)?;
        p as *const _
    };
    let matvec_pipeline = unsafe { &*matvec_pipeline };
    let gelu_mul_pipeline = unsafe { &*gelu_mul_pipeline };
    let accum_pipeline = unsafe { &*accum_pipeline };
    let zero_pipeline = unsafe { &*zero_pipeline };

    let zero_params = GpuZeroParams {
        n_elements: params.input_dim as u32,
    };
    encode_with_args(
        encoder,
        zero_pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        MTLSize::new(params.input_dim as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
    );

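    // One pass per selected expert: gate/up mat-vecs -> fused GELU-mul ->
    // down mat-vec -> weighted accumulation into `output`.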
    for (i, ew) in expert_weights.iter().enumerate() {
        let w = routing_weights[i];

        if w.abs() < 1e-10 {
            continue;
        }

        encoder.memory_barrier();

        let gate_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&gate_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        let up_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&up_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let gelu_params = GpuFusedGeluMulParams {
            n_elements: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            gelu_mul_pipeline,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let down_params = GpuMatmulParams {
            m: 1,
            k: params.intermediate_dim as u32,
            n: params.input_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&down_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let accum_params = GpuMoeAccumParams {
            n_elements: params.input_dim as u32,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            accum_pipeline,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );
    }

    Ok(())
}

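/// Encodes the `moe_swiglu_fused` kernel over a packed gate/up buffer.
/// `gate_up` must hold `2 * n_elements` f32 values (gate and up halves for
/// one expert slot); `output` receives `n_elements` activated values. The
/// exact packing is defined by the Metal kernel.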
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let gu_required = 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

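/// Encodes the `zero_buffer` kernel to clear the first `n_elements` f32
/// values of `output`, typically before accumulating expert outputs into it.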
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    let params = GpuZeroParams { n_elements: n_elements as u32 };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

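/// Encodes the `moe_accumulate` kernel: as used by `moe_dispatch`, it adds
/// `routing_weight * expert_output[i]` into `accumulator[i]` for the first
/// `n_elements` f32 values.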
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            required, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

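/// Encodes the `moe_swiglu_batch` kernel: one SwiGLU activation per
/// (intermediate, expert-slot) pair. `gate_up` holds `top_k` slots of
/// `2 * intermediate` f32 values and `output` holds `top_k * intermediate`
/// activated values; the dispatch grid is (intermediate, top_k).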
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }
    let gu_required = top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_batch", device)?;
    let intermediate_bytes = (intermediate as u32).to_ne_bytes();
    let top_k_bytes = (top_k as u32).to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&intermediate_bytes)),
            (3, KernelArg::Bytes(&top_k_bytes)),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, 1),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
    hidden_size: u32,
    top_k: u32,
}

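/// Encodes the `moe_weighted_sum` kernel for a single token: combines
/// `top_k` expert output rows of length `hidden_size` into one output row
/// using the per-expert `weights` (one f32 per expert).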
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let expert_required = top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    let params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(hidden_size as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
    n_experts: u32,
    top_k: u32,
}

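/// Encodes the single-threaded `moe_gather_topk_weights` kernel: it reads the
/// router softmax probabilities, the probability-sorted expert indices, and a
/// per-expert scale, and writes the selected `top_k` expert ids and their
/// routing weights. The shader's fixed-size local array caps `top_k` at 8.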
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let u32_size = std::mem::size_of::<u32>();
    if softmax_probs.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("softmax_probs too small".into()));
    }
    if sorted_indices.byte_len() < n_experts * u32_size {
        return Err(MlxError::InvalidArgument("sorted_indices too small".into()));
    }
    if per_expert_scale.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("per_expert_scale too small".into()));
    }
    if out_expert_ids.byte_len() < top_k * u32_size {
        return Err(MlxError::InvalidArgument("out_expert_ids too small".into()));
    }
    if out_weights.byte_len() < top_k * f32_size {
        return Err(MlxError::InvalidArgument("out_weights too small".into()));
    }

    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    let params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(1, 1, 1),
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}

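/// Offset variant of [`moe_swiglu_fused_encode`]: reads the packed gate/up
/// values starting at `gu_byte_offset` and writes activations starting at
/// `out_byte_offset`, so one large buffer can hold several expert slots.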
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let gu_required = gu_byte_offset + 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_required, gu_byte_offset, gate_up.byte_len()
        )));
    }
    let out_required = out_byte_offset + n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_required, out_byte_offset, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

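/// Offset variant of [`moe_accumulate_encode`]: the expert output is read
/// starting at `src_byte_offset`, while the accumulator is updated from its
/// start.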
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    let src_required = src_byte_offset + required;
    if expert_output.byte_len() < src_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_required, src_byte_offset, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

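/// bf16 variant of the fused GELU-mul: `gate_out`, `up_out`, and `output`
/// each hold `n_elements` bf16 values (2 bytes per element, hence the
/// `n_elements * 2` size checks below).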
pub fn fused_gelu_mul_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_out: &MlxBuffer,
    up_out: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_gelu_mul_bf16_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * 2;
    if gate_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: gate_out too small: need {} bytes, have {}",
            required, gate_out.byte_len()
        )));
    }
    if up_out.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: up_out too small: need {} bytes, have {}",
            required, up_out.byte_len()
        )));
    }
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "fused_gelu_mul_bf16_encode: output too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("fused_gelu_mul_bf16", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_out)),
            (1, KernelArg::Buffer(up_out)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
    intermediate: u32,
    top_k: u32,
    n_tokens: u32,
}

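/// Sequence-level SwiGLU: like [`moe_swiglu_batch_encode`], but over
/// `n_tokens * top_k` expert slots at once. `gate_up` is laid out as
/// `[n_tokens, top_k, 2 * intermediate]` f32 and `output` as
/// `[n_tokens, top_k, intermediate]` f32, per the CPU oracle in the backward
/// tests below.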
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }
    let gu_required = n_tokens * top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
    hidden_size: u32,
    top_k: u32,
    n_tokens: u32,
}

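/// Sequence-level weighted sum: for every token, combines its `top_k` expert
/// output rows (`[n_tokens, top_k, hidden_size]` f32) into one output row
/// (`[n_tokens, hidden_size]` f32) using per-(token, expert) weights, i.e.
/// `output[t, d] = sum_k weights[t, k] * expert_outputs[t, k, d]`, matching
/// the CPU oracle in the tests below.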
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }
    let expert_required = n_tokens * top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

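/// bf16 variant of [`moe_swiglu_seq_encode`]; sizes are checked at 2 bytes
/// per element instead of 4.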
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_bf16_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_bf16_encode: all dims must be > 0".into(),
        ));
    }
    let gu_required = n_tokens * top_k * 2 * intermediate * 2;
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_bf16_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * 2;
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_bf16_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq_bf16", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

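/// Variant of [`moe_weighted_sum_seq_encode`] whose `expert_outputs` are bf16
/// (2 bytes per element); `weights` and `output` remain f32.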
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_bf16_input_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_bf16_input_encode: all dims must be > 0".into(),
        ));
    }
    let expert_required = n_tokens * top_k * hidden_size * 2;
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_bf16_input_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_bf16_input", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

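/// Backward pass of the sequence weighted sum w.r.t. the expert outputs:
/// `d_expert_outputs[t, k, d] = weights[t, k] * d_output[t, d]`, as verified
/// against the CPU oracle and finite differences in the tests below.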
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_backward_outputs_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    weights: &MlxBuffer,
    d_output: &MlxBuffer,
    d_expert_outputs: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_backward_outputs_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let weights_required = n_tokens * top_k * f32_size;
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: weights too small: need {} bytes, have {}",
            weights_required,
            weights.byte_len()
        )));
    }
    let dout_required = n_tokens * hidden_size * f32_size;
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: d_output too small: need {} bytes, have {}",
            dout_required,
            d_output.byte_len()
        )));
    }
    let dexp_required = n_tokens * top_k * hidden_size * f32_size;
    if d_expert_outputs.byte_len() < dexp_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_outputs_encode: d_expert_outputs too small: need {} bytes, have {}",
            dexp_required,
            d_expert_outputs.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_backward_outputs_f32", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(weights)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_expert_outputs)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

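/// Backward pass of the sequence weighted sum w.r.t. the routing weights:
/// `d_weights[t, k] = sum_d expert_outputs[t, k, d] * d_output[t, d]`,
/// dispatched with one thread per (token, expert) pair.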
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_backward_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    d_output: &MlxBuffer,
    d_weights: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_backward_weights_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let exp_required = n_tokens * top_k * hidden_size * f32_size;
    if expert_outputs.byte_len() < exp_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_weights_encode: expert_outputs too small: need {} bytes, have {}",
            exp_required,
            expert_outputs.byte_len()
        )));
    }
    let dout_required = n_tokens * hidden_size * f32_size;
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_weights_encode: d_output too small: need {} bytes, have {}",
            dout_required,
            d_output.byte_len()
        )));
    }
    let dw_required = n_tokens * top_k * f32_size;
    if d_weights.byte_len() < dw_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_backward_weights_encode: d_weights too small: need {} bytes, have {}",
            dw_required,
            d_weights.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq_backward_weights_f32", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    let total = (n_tokens * top_k) as u64;
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_weights)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(total, 1, 1),
        MTLSize::new(std::cmp::min(256, total), 1, 1),
    );
    Ok(())
}

#[cfg(test)]
mod backward_weighted_sum_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }
    fn fill_f32(buf: &mut MlxBuffer, vals: &[f32]) {
        buf.as_mut_slice::<f32>().unwrap()[..vals.len()].copy_from_slice(vals);
    }

    fn cpu_forward(
        expert_outs: &[f32],
        weights: &[f32],
        n_tokens: usize,
        top_k: usize,
        hidden: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; n_tokens * hidden];
        for t in 0..n_tokens {
            for d in 0..hidden {
                let mut sum = 0.0f64;
                for k in 0..top_k {
                    let exp_ix = (t * top_k + k) * hidden + d;
                    let w_ix = t * top_k + k;
                    sum += (expert_outs[exp_ix] as f64) * (weights[w_ix] as f64);
                }
                out[t * hidden + d] = sum as f32;
            }
        }
        out
    }

    #[test]
    fn backward_outputs_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 3usize;
        let top_k = 4usize;
        let hidden = 7usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.1 + (i as f32) * 0.013 - (i as f32 * 0.004).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.2 + (i as f32) * 0.07)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.3 + (i as f32) * 0.05 - (i as f32 * 0.011).cos())
            .collect();

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut w_buf = alloc_f32(&device, n_tokens * top_k);
        fill_f32(&mut w_buf, &weights);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dexp_buf.as_slice::<f32>().unwrap().to_vec();

        let h: f32 = 1e-3;
        for idx in 0..(n_tokens * top_k * hidden) {
            let mut ep = expert_outs.clone(); ep[idx] += h;
            let mut em = expert_outs.clone(); em[idx] -= h;
            let yp = cpu_forward(&ep, &weights, n_tokens, top_k, hidden);
            let ym = cpu_forward(&em, &weights, n_tokens, top_k, hidden);
            let lp: f64 = yp.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_expert_outputs[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    #[test]
    fn backward_weights_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 4usize;
        let top_k = 3usize;
        let hidden = 11usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| 0.05 + (i as f32) * 0.017 + (i as f32 * 0.009).sin())
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.1 + (i as f32) * 0.05)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 0.2 + (i as f32) * 0.03 - (i as f32 * 0.013).cos())
            .collect();

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dw_buf.as_slice::<f32>().unwrap().to_vec();

        let h: f32 = 1e-3;
        for idx in 0..(n_tokens * top_k) {
            let mut wp = weights.clone(); wp[idx] += h;
            let mut wm = weights.clone(); wm[idx] -= h;
            let yp = cpu_forward(&expert_outs, &wp, n_tokens, top_k, hidden);
            let ym = cpu_forward(&expert_outs, &wm, n_tokens, top_k, hidden);
            let lp: f64 = yp.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_weights[{}]: analytic={} fd={}",
                idx, analytic[idx], fd
            );
        }
    }

    #[test]
    fn forward_then_backward_round_trip_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 2usize;
        let top_k = 2usize;
        let hidden = 5usize;

        let expert_outs: Vec<f32> = (0..n_tokens * top_k * hidden)
            .map(|i| (i as f32) * 0.1 - 0.2)
            .collect();
        let weights: Vec<f32> = (0..n_tokens * top_k)
            .map(|i| 0.3 + (i as f32) * 0.1)
            .collect();
        let d_output_seed: Vec<f32> = (0..n_tokens * hidden)
            .map(|i| 1.0 - (i as f32) * 0.05)
            .collect();

        let mut cpu_dexp = vec![0.0f32; n_tokens * top_k * hidden];
        let mut cpu_dw = vec![0.0f32; n_tokens * top_k];
        for t in 0..n_tokens {
            for k in 0..top_k {
                for d in 0..hidden {
                    let exp_ix = (t * top_k + k) * hidden + d;
                    let w_ix = t * top_k + k;
                    let dout_ix = t * hidden + d;
                    cpu_dexp[exp_ix] = weights[w_ix] * d_output_seed[dout_ix];
                    cpu_dw[w_ix] += expert_outs[exp_ix] * d_output_seed[dout_ix];
                }
            }
        }

        let mut exp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        fill_f32(&mut exp_buf, &expert_outs);
        let mut w_buf = alloc_f32(&device, n_tokens * top_k);
        fill_f32(&mut w_buf, &weights);
        let mut dout_buf = alloc_f32(&device, n_tokens * hidden);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dexp_buf = alloc_f32(&device, n_tokens * top_k * hidden);
        let dw_buf = alloc_f32(&device, n_tokens * top_k);

        let mut encoder = device.command_encoder().unwrap();
        moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &w_buf, &dout_buf, &dexp_buf, hidden, top_k, n_tokens,
        ).unwrap();
        moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &exp_buf, &dout_buf, &dw_buf, hidden, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu_dexp = dexp_buf.as_slice::<f32>().unwrap().to_vec();
        let gpu_dw = dw_buf.as_slice::<f32>().unwrap().to_vec();

        for i in 0..gpu_dexp.len() {
            assert!(
                (gpu_dexp[i] - cpu_dexp[i]).abs() < 1e-5,
                "d_expert_outputs[{}]: gpu={} cpu={}",
                i, gpu_dexp[i], cpu_dexp[i]
            );
        }
        for i in 0..gpu_dw.len() {
            assert!(
                (gpu_dw[i] - cpu_dw[i]).abs() < 1e-5,
                "d_weights[{}]: gpu={} cpu={}",
                i, gpu_dw[i], cpu_dw[i]
            );
        }
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let too_small = alloc_f32(&device, 4);
        let any = alloc_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        let res = moe_weighted_sum_seq_backward_outputs_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 7, 3, 4,
        );
        assert!(res.is_err());
        let res2 = moe_weighted_sum_seq_backward_weights_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 11, 3, 4,
        );
        assert!(res2.is_err());
    }
}

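/// Backward pass of the sequence SwiGLU: given the saved forward `gate_up`
/// and `d_output`, writes `d_gate_up` with, per element,
/// `d_gate = gelu'(gate) * up * d_out` and `d_up = gelu(gate) * d_out`
/// (tanh-approximation GELU), matching the finite-difference tests below.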
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_backward_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    d_output: &MlxBuffer,
    d_gate_up: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_backward_encode: all dims must be > 0".into(),
        ));
    }
    let f32_size = std::mem::size_of::<f32>();
    let gu_required = n_tokens * top_k * 2 * intermediate * f32_size;
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    if d_gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_gate_up too small: need {} bytes, have {}",
            gu_required, d_gate_up.byte_len()
        )));
    }
    let dout_required = n_tokens * top_k * intermediate * f32_size;
    if d_output.byte_len() < dout_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_backward_encode: d_output too small: need {} bytes, have {}",
            dout_required, d_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq_backward_f32", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(d_output)),
            (2, KernelArg::Buffer(d_gate_up)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

#[cfg(test)]
mod backward_swiglu_seq_tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::dtypes::DType;

    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }
    fn fill_f32(buf: &mut MlxBuffer, vals: &[f32]) {
        buf.as_mut_slice::<f32>().unwrap()[..vals.len()].copy_from_slice(vals);
    }

    fn cpu_gelu(g: f64) -> f64 {
        let s = 0.7978845608 * (g + 0.044715 * g * g * g);
        let t = s.tanh();
        0.5 * g * (1.0 + t)
    }

    fn cpu_forward(
        gate_up: &[f32],
        n_tokens: usize,
        top_k: usize,
        intermediate: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; n_tokens * top_k * intermediate];
        for t in 0..n_tokens {
            for k in 0..top_k {
                let slot_base = (t * top_k + k) * 2 * intermediate;
                for i in 0..intermediate {
                    let g = gate_up[slot_base + i] as f64;
                    let u = gate_up[slot_base + intermediate + i] as f64;
                    let y = cpu_gelu(g) * u;
                    out[(t * top_k + k) * intermediate + i] = y as f32;
                }
            }
        }
        out
    }

    #[test]
    fn backward_finite_difference_falsifier_both_gradients() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 2usize;
        let top_k = 3usize;
        let intermediate = 5usize;
        let gu_n = n_tokens * top_k * 2 * intermediate;
        let dout_n = n_tokens * top_k * intermediate;

        let gate_up: Vec<f32> = (0..gu_n)
            .map(|i| 0.5 + (i as f32) * 0.07 - (i as f32 * 0.013).sin())
            .collect();
        let d_output_seed: Vec<f32> = (0..dout_n)
            .map(|i| 0.3 + (i as f32) * 0.05 - (i as f32 * 0.011).cos())
            .collect();

        let mut gu_buf = alloc_f32(&device, gu_n);
        fill_f32(&mut gu_buf, &gate_up);
        let mut dout_buf = alloc_f32(&device, dout_n);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dgu_buf = alloc_f32(&device, gu_n);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &gu_buf, &dout_buf, &dgu_buf,
            intermediate, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let analytic = dgu_buf.as_slice::<f32>().unwrap().to_vec();

        let h: f32 = 1e-3;
        for idx in 0..gu_n {
            let mut gp = gate_up.clone();
            gp[idx] += h;
            let mut gm = gate_up.clone();
            gm[idx] -= h;
            let yp = cpu_forward(&gp, n_tokens, top_k, intermediate);
            let ym = cpu_forward(&gm, n_tokens, top_k, intermediate);
            let lp: f64 = yp.iter().zip(&d_output_seed)
                .map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let lm: f64 = ym.iter().zip(&d_output_seed)
                .map(|(a, b)| (*a as f64) * (*b as f64)).sum();
            let fd = (lp - lm) / (2.0 * h as f64);
            let tol = 5e-2 * fd.abs().max(1.0);
            assert!(
                (analytic[idx] as f64 - fd).abs() < tol,
                "d_gate_up[{}]: analytic={} fd={} (gate_up_value={})",
                idx, analytic[idx], fd, gate_up[idx]
            );
        }
    }

    #[test]
    fn backward_canonical_asymptotics_match_expected() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n_tokens = 1usize;
        let top_k = 1usize;
        let intermediate = 3usize;
        let gu_n = n_tokens * top_k * 2 * intermediate;
        let dout_n = n_tokens * top_k * intermediate;

        let gate_up = vec![0.0f32, 10.0, -10.0, 2.0, 3.0, 4.0];
        let d_output_seed = vec![0.5f32, 1.0, 1.0];

        let mut gu_buf = alloc_f32(&device, gu_n);
        fill_f32(&mut gu_buf, &gate_up);
        let mut dout_buf = alloc_f32(&device, dout_n);
        fill_f32(&mut dout_buf, &d_output_seed);
        let dgu_buf = alloc_f32(&device, gu_n);

        let mut encoder = device.command_encoder().unwrap();
        moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &gu_buf, &dout_buf, &dgu_buf,
            intermediate, top_k, n_tokens,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let g = dgu_buf.as_slice::<f32>().unwrap();

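        // Expected values (layout: [gate0, gate1, gate2, up0, up1, up2]):
        //   d_gate = gelu'(gate) * up * d_out, d_up = gelu(gate) * d_out.
        //   gate=0:   gelu'(0)=0.5, gelu(0)=0   -> d_gate0 = 0.5*2*0.5 = 0.5, d_up0 = 0
        //   gate=10:  gelu' ~ 1,    gelu ~ 10   -> d_gate1 ~ 3,  d_up1 ~ 10
        //   gate=-10: gelu' ~ 0,    gelu ~ 0    -> d_gate2 ~ 0,  d_up2 ~ 0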
        assert!((g[0] - 0.5).abs() < 1e-5, "∂gate0={}", g[0]);
        assert!((g[1] - 3.0).abs() < 0.05, "∂gate1={}", g[1]);
        assert!(g[2].abs() < 0.05, "∂gate2={}", g[2]);
        assert!(g[3].abs() < 1e-5, "∂up0={}", g[3]);
        assert!((g[4] - 10.0).abs() < 0.05, "∂up1={}", g[4]);
        assert!(g[5].abs() < 0.05, "∂up2={}", g[5]);
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let too_small = alloc_f32(&device, 4);
        let any = alloc_f32(&device, 1024);
        let mut encoder = device.command_encoder().unwrap();
        let res = moe_swiglu_seq_backward_encode(
            &mut encoder, &mut registry, device.metal_device(),
            &too_small, &any, &any, 5, 3, 2,
        );
        assert!(res.is_err());
    }
}