use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};

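/// Dimensions for a single-token MoE dispatch: the model (input) width, the
/// per-expert intermediate width, and the number of experts selected by the router.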
pub struct MoeDispatchParams {
    pub input_dim: usize,
    pub intermediate_dim: usize,
    pub n_selected: usize,
}

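/// Borrowed weight buffers for one expert's gate, up, and down projections.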
pub struct ExpertWeights<'a> {
    pub gate_proj: &'a MlxBuffer,
    pub up_proj: &'a MlxBuffer,
    pub down_proj: &'a MlxBuffer,
}

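/// Kernel arguments for the fused GELU-multiply / fused SwiGLU kernels; `#[repr(C)]`
/// keeps the byte layout stable when passed to the shader via `as_bytes`.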
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedGeluMulParams {
    n_elements: u32,
}

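/// Kernel arguments for `moe_accumulate`: element count plus the routing weight
/// applied to the expert output as it is summed into the accumulator.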
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeAccumParams {
    n_elements: u32,
    routing_weight: f32,
}

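/// Kernel arguments for `zero_buffer`.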
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuZeroParams {
    n_elements: u32,
}

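/// Kernel arguments for `naive_matvec_f32` (in this module always `m == 1`).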
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMatmulParams {
    m: u32,
    k: u32,
    n: u32,
}

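/// Encodes a full MoE expert dispatch for a single token: zeroes `output`, then for
/// each selected expert runs gate/up matvecs, a fused GELU-multiply activation, a
/// down projection into `scratch_expert`, and a weighted accumulation into `output`.
/// Experts whose routing weight is effectively zero are skipped entirely.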
#[allow(clippy::too_many_arguments)]
pub fn moe_dispatch(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    expert_weights: &[ExpertWeights<'_>],
    routing_weights: &[f32],
    output: &MlxBuffer,
    scratch_gate: &MlxBuffer,
    scratch_up: &MlxBuffer,
    scratch_hidden: &MlxBuffer,
    scratch_expert: &MlxBuffer,
    params: &MoeDispatchParams,
) -> Result<()> {
    if params.input_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: input_dim must be > 0".into(),
        ));
    }
    if params.intermediate_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: intermediate_dim must be > 0".into(),
        ));
    }
    if params.n_selected == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: n_selected must be > 0".into(),
        ));
    }
    if expert_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: expert_weights length ({}) must match n_selected ({})",
            expert_weights.len(),
            params.n_selected
        )));
    }
    if routing_weights.len() != params.n_selected {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: routing_weights length ({}) must match n_selected ({})",
            routing_weights.len(),
            params.n_selected
        )));
    }

    let input_bytes = params.input_dim * std::mem::size_of::<f32>();
    if input.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: input buffer too small: need {} bytes, have {}",
            input_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "moe_dispatch: output buffer too small: need {} bytes, have {}",
            input_bytes,
            output.byte_len()
        )));
    }

    let intermediate_bytes = params.intermediate_dim * std::mem::size_of::<f32>();
    if scratch_gate.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_gate buffer too small".into(),
        ));
    }
    if scratch_up.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_up buffer too small".into(),
        ));
    }
    if scratch_hidden.byte_len() < intermediate_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_hidden buffer too small".into(),
        ));
    }
    if scratch_expert.byte_len() < input_bytes {
        return Err(MlxError::InvalidArgument(
            "moe_dispatch: scratch_expert buffer too small".into(),
        ));
    }

    let gate_up_bytes = params.input_dim * params.intermediate_dim * std::mem::size_of::<f32>();
    let down_bytes = params.intermediate_dim * params.input_dim * std::mem::size_of::<f32>();

    for (i, ew) in expert_weights.iter().enumerate() {
        if ew.gate_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} gate_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.gate_proj.byte_len()
            )));
        }
        if ew.up_proj.byte_len() < gate_up_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} up_proj too small: need {} bytes, have {}",
                i, gate_up_bytes, ew.up_proj.byte_len()
            )));
        }
        if ew.down_proj.byte_len() < down_bytes {
            return Err(MlxError::InvalidArgument(format!(
                "moe_dispatch: expert {} down_proj too small: need {} bytes, have {}",
                i, down_bytes, ew.down_proj.byte_len()
            )));
        }
    }

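    // Warm the registry cache up front so all four pipelines are resolved (and any
    // error surfaces) before the raw-pointer borrows below.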
    {
        registry.get_pipeline("naive_matvec_f32", device)?;
        registry.get_pipeline("fused_gelu_mul", device)?;
        registry.get_pipeline("moe_accumulate", device)?;
        registry.get_pipeline("zero_buffer", device)?;
    }
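    // `get_pipeline` takes `&mut self`, so we cannot hold four simultaneous
    // references into the registry. Take raw pointers instead; this is sound only
    // if cache hits never invalidate or move previously returned pipelines, which
    // the warm-up above is meant to guarantee.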
    let matvec_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("naive_matvec_f32", device)?;
        p as *const _
    };
    let gelu_mul_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("fused_gelu_mul", device)?;
        p as *const _
    };
    let accum_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("moe_accumulate", device)?;
        p as *const _
    };
    let zero_pipeline: *const metal::ComputePipelineState = {
        let p = registry.get_pipeline("zero_buffer", device)?;
        p as *const _
    };
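    // SAFETY: the registry is borrowed for the duration of this function and, per
    // the invariant above, the cached pipelines are not moved after the warm-up.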
    let matvec_pipeline = unsafe { &*matvec_pipeline };
    let gelu_mul_pipeline = unsafe { &*gelu_mul_pipeline };
    let accum_pipeline = unsafe { &*accum_pipeline };
    let zero_pipeline = unsafe { &*zero_pipeline };

    let zero_params = GpuZeroParams {
        n_elements: params.input_dim as u32,
    };
    encode_with_args(
        encoder,
        zero_pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&zero_params))),
        ],
        MTLSize::new(params.input_dim as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
    );

    for (i, ew) in expert_weights.iter().enumerate() {
        let w = routing_weights[i];

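        // A near-zero routing weight contributes nothing; skip the whole expert.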
        if w.abs() < 1e-10 {
            continue;
        }

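        // Per-expert pipeline, with barriers between dependent stages:
        // gate/up matvecs -> fused GELU-multiply -> down projection -> accumulate.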
        encoder.memory_barrier();

        let gate_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.gate_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_gate)),
                (3, KernelArg::Bytes(as_bytes(&gate_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        let up_params = GpuMatmulParams {
            m: 1,
            k: params.input_dim as u32,
            n: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.up_proj)),
                (1, KernelArg::Buffer(input)),
                (2, KernelArg::Buffer(scratch_up)),
                (3, KernelArg::Bytes(as_bytes(&up_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let gelu_params = GpuFusedGeluMulParams {
            n_elements: params.intermediate_dim as u32,
        };
        encode_with_args(
            encoder,
            gelu_mul_pipeline,
            &[
                (0, KernelArg::Buffer(scratch_gate)),
                (1, KernelArg::Buffer(scratch_up)),
                (2, KernelArg::Buffer(scratch_hidden)),
                (3, KernelArg::Bytes(as_bytes(&gelu_params))),
            ],
            MTLSize::new(params.intermediate_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.intermediate_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let down_params = GpuMatmulParams {
            m: 1,
            k: params.intermediate_dim as u32,
            n: params.input_dim as u32,
        };
        encode_with_args(
            encoder,
            matvec_pipeline,
            &[
                (0, KernelArg::Buffer(ew.down_proj)),
                (1, KernelArg::Buffer(scratch_hidden)),
                (2, KernelArg::Buffer(scratch_expert)),
                (3, KernelArg::Bytes(as_bytes(&down_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );

        encoder.memory_barrier();

        let accum_params = GpuMoeAccumParams {
            n_elements: params.input_dim as u32,
            routing_weight: w,
        };
        encode_with_args(
            encoder,
            accum_pipeline,
            &[
                (0, KernelArg::Buffer(output)),
                (1, KernelArg::Buffer(scratch_expert)),
                (2, KernelArg::Bytes(as_bytes(&accum_params))),
            ],
            MTLSize::new(params.input_dim as u64, 1, 1),
            MTLSize::new(std::cmp::min(256, params.input_dim as u64), 1, 1),
        );
    }

    Ok(())
}

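/// Encodes the `moe_swiglu_fused` kernel: reads a packed gate/up buffer of
/// `2 * n_elements` f32 values and writes `n_elements` activated values to `output`.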
pub fn moe_swiglu_fused_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode: n_elements must be > 0".into(),
        ));
    }
    let gu_required = 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: gate_up buffer too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode: output buffer too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

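/// Encodes the `zero_buffer` kernel to clear the first `n_elements` f32 values of `output`.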
pub fn moe_zero_buffer_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    output: &MlxBuffer,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_zero_buffer_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_zero_buffer_encode: buffer too small: need {} bytes, have {}",
            required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("zero_buffer", device)?;
    let params = GpuZeroParams { n_elements: n_elements as u32 };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(output)),
            (1, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

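/// Encodes the `moe_accumulate` kernel: scales the first `n_elements` values of
/// `expert_output` by `routing_weight` and adds them into `accumulator`.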
pub fn moe_accumulate_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    if expert_output.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode: expert_output too small: need {} bytes, have {}",
            required, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::Buffer(expert_output)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

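/// Encodes the `moe_swiglu_batch` kernel: applies the fused SwiGLU activation for
/// `top_k` experts at once, reading `top_k * 2 * intermediate` packed f32 values
/// and writing `top_k * intermediate` activated values.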
pub fn moe_swiglu_batch_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_batch_encode: intermediate and top_k must be > 0".into(),
        ));
    }
    let gu_required = top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_batch_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_batch", device)?;
    let intermediate_bytes = (intermediate as u32).to_ne_bytes();
    let top_k_bytes = (top_k as u32).to_ne_bytes();

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(&intermediate_bytes)),
            (3, KernelArg::Bytes(&top_k_bytes)),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, 1),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

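/// Kernel arguments for `moe_weighted_sum`.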
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumParams {
    hidden_size: u32,
    top_k: u32,
}

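/// Encodes the `moe_weighted_sum` kernel: reduces `top_k` expert outputs of
/// `hidden_size` f32 values each into `output`, weighting each expert by the
/// matching entry in `weights`.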
pub fn moe_weighted_sum_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_encode: hidden_size and top_k must be > 0".into(),
        ));
    }
    let expert_required = top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum", device)?;
    let params = GpuMoeWeightedSumParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(hidden_size as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}

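/// Kernel arguments for `moe_gather_topk_weights`.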
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeGatherTopkParams {
    n_experts: u32,
    top_k: u32,
}

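/// Encodes the `moe_gather_topk_weights` kernel: selects the top `top_k` experts
/// from pre-sorted indices and writes their ids and scaled routing weights. The
/// kernel runs as a single thread, so `top_k` is capped at 8 to match the shader's
/// fixed-size local array.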
pub fn moe_gather_topk_weights_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    softmax_probs: &MlxBuffer,
    sorted_indices: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    out_expert_ids: &MlxBuffer,
    out_weights: &MlxBuffer,
    n_experts: usize,
    top_k: usize,
) -> Result<()> {
    if n_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_gather_topk_weights: n_experts and top_k must be > 0".into(),
        ));
    }
    if top_k > n_experts {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > n_experts ({})",
            top_k, n_experts,
        )));
    }
    if top_k > 8 {
        return Err(MlxError::InvalidArgument(format!(
            "moe_gather_topk_weights: top_k ({}) > 8 (shader fixed-size array limit)",
            top_k,
        )));
    }

    let f32_size = std::mem::size_of::<f32>();
    let u32_size = std::mem::size_of::<u32>();
    if softmax_probs.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("softmax_probs too small".into()));
    }
    if sorted_indices.byte_len() < n_experts * u32_size {
        return Err(MlxError::InvalidArgument("sorted_indices too small".into()));
    }
    if per_expert_scale.byte_len() < n_experts * f32_size {
        return Err(MlxError::InvalidArgument("per_expert_scale too small".into()));
    }
    if out_expert_ids.byte_len() < top_k * u32_size {
        return Err(MlxError::InvalidArgument("out_expert_ids too small".into()));
    }
    if out_weights.byte_len() < top_k * f32_size {
        return Err(MlxError::InvalidArgument("out_weights too small".into()));
    }

    let pipeline = registry.get_pipeline("moe_gather_topk_weights", device)?;
    let params = GpuMoeGatherTopkParams {
        n_experts: n_experts as u32,
        top_k: top_k as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(softmax_probs)),
            (1, KernelArg::Buffer(sorted_indices)),
            (2, KernelArg::Buffer(per_expert_scale)),
            (3, KernelArg::Buffer(out_expert_ids)),
            (4, KernelArg::Buffer(out_weights)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(1, 1, 1),
        MTLSize::new(1, 1, 1),
    );
    Ok(())
}

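/// Offset variant of [`moe_swiglu_fused_encode`]: reads and writes at the given
/// byte offsets into `gate_up` and `output`.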
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_fused_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    gu_byte_offset: usize,
    output: &MlxBuffer,
    out_byte_offset: usize,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_fused_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let gu_required = gu_byte_offset + 2 * n_elements * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: gate_up buffer too small: need {} bytes (offset {}), have {}",
            gu_required, gu_byte_offset, gate_up.byte_len()
        )));
    }
    let out_required = out_byte_offset + n_elements * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_fused_encode_offset: output buffer too small: need {} bytes (offset {}), have {}",
            out_required, out_byte_offset, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_fused", device)?;
    let params = GpuFusedGeluMulParams {
        n_elements: n_elements as u32,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::BufferWithOffset(gate_up, gu_byte_offset as u64)),
            (1, KernelArg::BufferWithOffset(output, out_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

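/// Offset variant of [`moe_accumulate_encode`]: reads the expert output at
/// `src_byte_offset` while accumulating into the start of `accumulator`.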
#[allow(clippy::too_many_arguments)]
pub fn moe_accumulate_encode_offset(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    accumulator: &MlxBuffer,
    expert_output: &MlxBuffer,
    src_byte_offset: usize,
    routing_weight: f32,
    n_elements: usize,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_accumulate_encode_offset: n_elements must be > 0".into(),
        ));
    }
    let required = n_elements * std::mem::size_of::<f32>();
    if accumulator.byte_len() < required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: accumulator too small: need {} bytes, have {}",
            required, accumulator.byte_len()
        )));
    }
    let src_required = src_byte_offset + required;
    if expert_output.byte_len() < src_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_accumulate_encode_offset: expert_output too small: need {} bytes (offset {}), have {}",
            src_required, src_byte_offset, expert_output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_accumulate", device)?;
    let params = GpuMoeAccumParams {
        n_elements: n_elements as u32,
        routing_weight,
    };
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(accumulator)),
            (1, KernelArg::BufferWithOffset(expert_output, src_byte_offset as u64)),
            (2, KernelArg::Bytes(as_bytes(&params))),
        ],
        MTLSize::new(n_elements as u64, 1, 1),
        MTLSize::new(std::cmp::min(256, n_elements as u64), 1, 1),
    );
    Ok(())
}

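/// Kernel arguments for `moe_swiglu_seq`.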
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeSwigluSeqParams {
    intermediate: u32,
    top_k: u32,
    n_tokens: u32,
}

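/// Sequence variant of the batched SwiGLU activation: dispatches a 3-D grid of
/// (intermediate element, expert, token), covering all `n_tokens * top_k`
/// activations in a single encode.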
#[allow(clippy::too_many_arguments)]
pub fn moe_swiglu_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate_up: &MlxBuffer,
    output: &MlxBuffer,
    intermediate: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if intermediate == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_swiglu_seq_encode: all dims must be > 0".into(),
        ));
    }
    let gu_required = n_tokens * top_k * 2 * intermediate * std::mem::size_of::<f32>();
    if gate_up.byte_len() < gu_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: gate_up too small: need {} bytes, have {}",
            gu_required, gate_up.byte_len()
        )));
    }
    let out_required = n_tokens * top_k * intermediate * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_swiglu_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_swiglu_seq", device)?;
    let gpu_params = GpuMoeSwigluSeqParams {
        intermediate: intermediate as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(gate_up)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(intermediate as u64, top_k as u64, n_tokens as u64),
        MTLSize::new(std::cmp::min(256, intermediate as u64), 1, 1),
    );
    Ok(())
}

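/// Kernel arguments for `moe_weighted_sum_seq`.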
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuMoeWeightedSumSeqParams {
    hidden_size: u32,
    top_k: u32,
    n_tokens: u32,
}

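/// Sequence variant of [`moe_weighted_sum_encode`]: reduces `top_k` expert outputs
/// per token for `n_tokens` tokens, using a 2-D grid of (hidden element, token).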
#[allow(clippy::too_many_arguments)]
pub fn moe_weighted_sum_seq_encode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    expert_outputs: &MlxBuffer,
    weights: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    top_k: usize,
    n_tokens: usize,
) -> Result<()> {
    if hidden_size == 0 || top_k == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_sum_seq_encode: all dims must be > 0".into(),
        ));
    }
    let expert_required = n_tokens * top_k * hidden_size * std::mem::size_of::<f32>();
    if expert_outputs.byte_len() < expert_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: expert_outputs too small: need {} bytes, have {}",
            expert_required, expert_outputs.byte_len()
        )));
    }
    let weights_required = n_tokens * top_k * std::mem::size_of::<f32>();
    if weights.byte_len() < weights_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: weights too small: need {} bytes, have {}",
            weights_required, weights.byte_len()
        )));
    }
    let out_required = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_required {
        return Err(MlxError::InvalidArgument(format!(
            "moe_weighted_sum_seq_encode: output too small: need {} bytes, have {}",
            out_required, output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("moe_weighted_sum_seq", device)?;
    let gpu_params = GpuMoeWeightedSumSeqParams {
        hidden_size: hidden_size as u32,
        top_k: top_k as u32,
        n_tokens: n_tokens as u32,
    };

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(expert_outputs)),
            (1, KernelArg::Buffer(weights)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(hidden_size as u64, n_tokens as u64, 1),
        MTLSize::new(std::cmp::min(256, hidden_size as u64), 1, 1),
    );
    Ok(())
}