1use std::path::Path;
23use std::sync::{
24 atomic::{AtomicU64, Ordering},
25 OnceLock,
26};
27
28use candle_core::quantized::GgmlDType;
29use candle_core::{Device, Result as CandleResult};
30use ferrum_kernels::backend::cpu::CpuBackend;
31use ferrum_kernels::backend::{
32 Backend, BackendMoeFused, BackendPagedKv, BackendQuantGguf, BackendQuantMarlin, GgufQuantType,
33 LlmBackend, QuantLlmBackend,
34};
35use ferrum_kernels::{Linear, StackedExpertGgufLinear};
36use ferrum_quantization::gguf::GgufFile;
37use ferrum_quantization::{DenseLinear, QuantLinear};
38use ferrum_types::{FerrumError, Result};
39
40use crate::moe::router::RouterOutput;
41
// Profiling counters for the per-token `moe_forward` path. They are only
// updated when MoE profiling is enabled. `*_US` counters accumulate elapsed
// wall-clock microseconds; `*_CALLS` counters accumulate how many operations
// were timed. All updates use `Ordering::Relaxed`.

/// Time spent in the initial `B::sync` at the start of `moe_forward`.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed initial syncs in `moe_forward`.
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in the per-expert fused gate/up `forward` call.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed gate/up `forward` calls.
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in `B::fused_silu_mul_split`.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `B::fused_silu_mul_split` calls.
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in the per-expert down-projection `forward` call.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed down-projection `forward` calls.
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in `B::scaled_add_inplace` (weighted accumulation of expert outputs).
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `B::scaled_add_inplace` calls.
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent issuing `B::copy_slice` calls in `moe_forward` (measured without a trailing sync).
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Number of `B::copy_slice` calls covered by `MOE_COPY_US`.
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent copying router logits to the host and running the host-side router.
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed host-side routing passes.
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);

// Profiling counters for `moe_forward_bucketed`. Only updated when either the
// MoE profile or the decode-op profile switch is enabled.

/// Time spent in `B::sync` on the host-routing fallback of `moe_forward_bucketed`.
pub static MOE_BUCKET_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Time spent copying router logits to the host in `moe_forward_bucketed`.
pub static MOE_BUCKET_D2H_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the host-side `route_into` time — the update site is outside the reviewed range; confirm.
pub static MOE_BUCKET_ROUTE_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the bucket-plan construction time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_PLAN_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the token gather/packing time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_GATHER_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the first (gate/up) GEMM time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_GEMM1_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the SiLU-multiply time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the down-projection GEMM time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_GEMM3_US: AtomicU64 = AtomicU64::new(0);
/// Presumably the weighted combine time — update site not in the reviewed range; confirm.
pub static MOE_BUCKET_COMBINE_US: AtomicU64 = AtomicU64::new(0);
/// Number of `moe_forward_bucketed` invocations observed while profiling.
pub static MOE_BUCKET_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
72
/// Runtime switches for MoE dispatch, parsed from `FERRUM_*` environment variables.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MoeDispatchRuntimeConfig {
    /// `FERRUM_MOE_PROFILE`: enabled by the variable's mere presence (any value).
    moe_profile: bool,
    /// `FERRUM_DECODE_OP_PROFILE`: enabled by the variable's mere presence (any value).
    decode_op_profile: bool,
    /// `FERRUM_VLLM_MOE_ZERO_WS`: enabled only when the value is exactly `"1"`.
    vllm_moe_zero_ws: bool,
    /// `FERRUM_VLLM_MOE_PAIR_IDS`: enabled only when the value is exactly `"1"`.
    vllm_moe_pair_ids: bool,
    /// `FERRUM_MOE_LOAD_TRACE`: enabled by the variable's mere presence (any value).
    moe_load_trace: bool,
    /// `FERRUM_MOE_BLOCK_SIZE`: forced block size; `None` when unset or not an accepted size.
    moe_block_size: Option<usize>,
    /// `FERRUM_MOE_LARGE_M_BLOCK_SIZE`: block size for large device-routed batches;
    /// `None` when unset or not an accepted size.
    moe_large_m_block_size: Option<usize>,
    /// `FERRUM_MOE_LARGE_M_MIN_PAIRS`: minimum pair count at which
    /// `moe_large_m_block_size` applies. Unparseable values fall back to the default.
    moe_large_m_min_pairs: usize,
    /// `FERRUM_VLLM_MOE`: enabled only when the value is exactly `"1"`.
    vllm_moe: bool,
    /// `FERRUM_MOE_HOST_ROUTE`: enabled only when the value is exactly `"1"`.
    moe_host_route: bool,
}

impl Default for MoeDispatchRuntimeConfig {
    /// Everything disabled / unset, with the default large-M pair threshold.
    fn default() -> Self {
        Self {
            moe_profile: false,
            decode_op_profile: false,
            vllm_moe_zero_ws: false,
            vllm_moe_pair_ids: false,
            moe_load_trace: false,
            moe_block_size: None,
            moe_large_m_block_size: None,
            moe_large_m_min_pairs: Self::DEFAULT_LARGE_M_MIN_PAIRS,
            vllm_moe: false,
            moe_host_route: false,
        }
    }
}

impl MoeDispatchRuntimeConfig {
    /// Pair-count threshold used when `FERRUM_MOE_LARGE_M_MIN_PAIRS` is unset or invalid.
    const DEFAULT_LARGE_M_MIN_PAIRS: usize = 1024;

    /// Builds the configuration from the current process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter panics
    /// as soon as *any* variable in the environment is not valid Unicode, which
    /// would let an unrelated variable crash the process. Variables whose name
    /// or value is not valid Unicode are skipped.
    fn from_env() -> Self {
        let unicode_vars = std::env::vars_os().filter_map(|(name, value)| {
            Some((name.into_string().ok()?, value.into_string().ok()?))
        });
        Self::from_env_vars(unicode_vars)
    }

    /// Builds the configuration from an explicit list of `(name, value)` pairs.
    ///
    /// Unknown names are ignored. When a name occurs more than once, the last
    /// occurrence wins for value-based settings; presence-based flags stay set.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self::default();
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                "FERRUM_MOE_PROFILE" => config.moe_profile = true,
                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
                "FERRUM_VLLM_MOE_ZERO_WS" => config.vllm_moe_zero_ws = value == "1",
                "FERRUM_VLLM_MOE_PAIR_IDS" => config.vllm_moe_pair_ids = value == "1",
                "FERRUM_MOE_LOAD_TRACE" => config.moe_load_trace = true,
                "FERRUM_MOE_BLOCK_SIZE" => {
                    config.moe_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_BLOCK_SIZE" => {
                    config.moe_large_m_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_MIN_PAIRS" => {
                    config.moe_large_m_min_pairs = value
                        .parse::<usize>()
                        .unwrap_or(Self::DEFAULT_LARGE_M_MIN_PAIRS);
                }
                "FERRUM_VLLM_MOE" => config.vllm_moe = value == "1",
                "FERRUM_MOE_HOST_ROUTE" => config.moe_host_route = value == "1",
                _ => {}
            }
        }
        config
    }
}

/// Parses a block-size override, accepting only 8, 16, 32, 48 or 64.
///
/// Returns `None` for anything that is not an unsigned integer or is not one
/// of the accepted sizes.
fn parse_moe_block_size_value(value: &str) -> Option<usize> {
    value
        .parse::<usize>()
        .ok()
        .filter(|bs| matches!(*bs, 8 | 16 | 32 | 48 | 64))
}
148
149fn moe_dispatch_runtime_config() -> &'static MoeDispatchRuntimeConfig {
150 static CONFIG: OnceLock<MoeDispatchRuntimeConfig> = OnceLock::new();
151 CONFIG.get_or_init(MoeDispatchRuntimeConfig::from_env)
152}
153
154fn moe_profile_enabled() -> bool {
155 moe_dispatch_runtime_config().moe_profile
156}
157
/// Weights for all experts of one MoE layer, in up to three representations.
///
/// Which representation is populated depends on how the stack was loaded: the
/// dense constructor fills only the per-expert vectors, while the quantised
/// GGUF loader fills the stacked handles and leaves the per-expert vectors
/// empty when all three stacked loads succeed.
pub struct ExpertStack<B: QuantLlmBackend + BackendMoeFused> {
    /// One linear layer per expert holding the gate rows followed by the up
    /// rows (`2 * expert_intermediate` output rows over `hidden_size` inputs).
    pub gate_up: Vec<Box<dyn Linear<B>>>,
    /// One down-projection linear layer per expert
    /// (`hidden_size` output rows over `expert_intermediate` inputs).
    pub down: Vec<Box<dyn Linear<B>>>,
    /// Gate projection for all experts as a single stacked quantised weight, if loaded.
    pub gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    /// Up projection for all experts as a single stacked quantised weight, if loaded.
    pub up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    /// Down projection for all experts as a single stacked quantised weight, if loaded.
    pub down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,

    /// Optional shared Marlin stack for the gate/up projections. Both
    /// constructors in this file leave it `None`; it is presumably attached
    /// by a different loader — confirm.
    pub gate_up_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
    /// Optional shared Marlin stack for the down projection (same caveat as above).
    pub down_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
}
200
201impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
202 pub fn gate_up_stacked_store(
206 &self,
207 _expert_idx: usize,
208 ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
209 self.gate_up_marlin_stack.as_ref()
210 }
211
212 pub fn down_stacked_store(
214 &self,
215 _expert_idx: usize,
216 ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
217 self.down_marlin_stack.as_ref()
218 }
219
220 pub fn gemv_gate(
231 &self,
232 ctx: &mut B::Context,
233 input: &B::Buffer,
234 ids: &B::Buffer,
235 out: &mut B::Buffer,
236 top_k: usize,
237 ) -> Result<()> {
238 let weight = self.gate_stacked.as_deref().ok_or_else(|| {
239 FerrumError::unsupported("ExpertStack::gemv_gate: gate_stacked not loaded")
240 })?;
241 weight.gemv_moe_id(ctx, input, ids, out, top_k, 0)
242 }
243
244 pub fn gemv_up(
246 &self,
247 ctx: &mut B::Context,
248 input: &B::Buffer,
249 ids: &B::Buffer,
250 out: &mut B::Buffer,
251 top_k: usize,
252 ) -> Result<()> {
253 let weight = self.up_stacked.as_deref().ok_or_else(|| {
254 FerrumError::unsupported("ExpertStack::gemv_up: up_stacked not loaded")
255 })?;
256 weight.gemv_moe_id(ctx, input, ids, out, top_k, 0)
257 }
258
259 pub fn gemv_down(
262 &self,
263 ctx: &mut B::Context,
264 input: &B::Buffer,
265 ids: &B::Buffer,
266 out: &mut B::Buffer,
267 top_k: usize,
268 expert_intermediate: usize,
269 ) -> Result<()> {
270 let weight = self.down_stacked.as_deref().ok_or_else(|| {
271 FerrumError::unsupported("ExpertStack::gemv_down: down_stacked not loaded")
272 })?;
273 weight.gemv_moe_id(ctx, input, ids, out, top_k, expert_intermediate)
274 }
275
276 pub fn gemv_gate_up_silu_fused(
280 &self,
281 ctx: &mut B::Context,
282 input: &B::Buffer,
283 ids: &B::Buffer,
284 out_silu_stacked: &mut B::Buffer,
285 top_k: usize,
286 ) -> Result<()> {
287 let gate = self.gate_stacked.as_deref().ok_or_else(|| {
288 FerrumError::unsupported(
289 "ExpertStack::gemv_gate_up_silu_fused: gate_stacked not loaded",
290 )
291 })?;
292 let up = self.up_stacked.as_deref().ok_or_else(|| {
293 FerrumError::unsupported("ExpertStack::gemv_gate_up_silu_fused: up_stacked not loaded")
294 })?;
295 gate.gemv_moe_id_gate_up_silu(ctx, input, up, ids, out_silu_stacked, top_k)
296 }
297
298 #[allow(clippy::too_many_arguments)]
312 pub fn gemm_gate(
313 &self,
314 ctx: &mut B::Context,
315 src1: &B::Buffer,
316 ids: &B::Buffer,
317 tpe: &B::Buffer,
318 dst: &mut B::Buffer,
319 args_buf: Option<&B::Buffer>,
320 top_k: usize,
321 max_per_expert: usize,
322 tokens: usize,
323 ) -> Result<()> {
324 let weight = self.gate_stacked.as_deref().ok_or_else(|| {
325 FerrumError::unsupported("ExpertStack::gemm_gate: gate_stacked not loaded")
326 })?;
327 match args_buf {
328 Some(args) => weight.gemm_moe_id_indirect(
329 ctx,
330 src1,
331 ids,
332 tpe,
333 dst,
334 args,
335 1,
336 top_k,
337 max_per_expert,
338 tokens,
339 ),
340 None => weight.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens),
341 }
342 }
343
344 #[allow(clippy::too_many_arguments)]
346 pub fn gemm_up(
347 &self,
348 ctx: &mut B::Context,
349 src1: &B::Buffer,
350 ids: &B::Buffer,
351 tpe: &B::Buffer,
352 dst: &mut B::Buffer,
353 args_buf: Option<&B::Buffer>,
354 top_k: usize,
355 max_per_expert: usize,
356 tokens: usize,
357 ) -> Result<()> {
358 let weight = self.up_stacked.as_deref().ok_or_else(|| {
359 FerrumError::unsupported("ExpertStack::gemm_up: up_stacked not loaded")
360 })?;
361 match args_buf {
362 Some(args) => weight.gemm_moe_id_indirect(
363 ctx,
364 src1,
365 ids,
366 tpe,
367 dst,
368 args,
369 1,
370 top_k,
371 max_per_expert,
372 tokens,
373 ),
374 None => weight.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens),
375 }
376 }
377
378 #[allow(clippy::too_many_arguments)]
381 pub fn gemm_down(
382 &self,
383 ctx: &mut B::Context,
384 src1: &B::Buffer,
385 ids: &B::Buffer,
386 tpe: &B::Buffer,
387 dst: &mut B::Buffer,
388 args_buf: Option<&B::Buffer>,
389 top_k: usize,
390 max_per_expert: usize,
391 tokens: usize,
392 ) -> Result<()> {
393 let weight = self.down_stacked.as_deref().ok_or_else(|| {
394 FerrumError::unsupported("ExpertStack::gemm_down: down_stacked not loaded")
395 })?;
396 match args_buf {
397 Some(args) => weight.gemm_moe_id_indirect(
398 ctx,
399 src1,
400 ids,
401 tpe,
402 dst,
403 args,
404 top_k,
405 top_k,
406 max_per_expert,
407 tokens,
408 ),
409 None => weight.gemm_moe_id(
410 ctx,
411 src1,
412 ids,
413 tpe,
414 dst,
415 top_k,
416 top_k,
417 max_per_expert,
418 tokens,
419 ),
420 }
421 }
422
423 #[allow(clippy::too_many_arguments)]
431 pub fn gemv_gate_batched(
432 &self,
433 ctx: &mut B::Context,
434 input: &B::Buffer,
435 ids: &B::Buffer,
436 dst: &mut B::Buffer,
437 m: usize,
438 top_k: usize,
439 src1_outer_stride: usize,
440 src1_inner_stride: usize,
441 ) -> Result<()> {
442 let weight = self.gate_stacked.as_deref().ok_or_else(|| {
443 FerrumError::unsupported("ExpertStack::gemv_gate_batched: gate_stacked not loaded")
444 })?;
445 weight.gemv_moe_id_batched(
446 ctx,
447 input,
448 ids,
449 dst,
450 m,
451 top_k,
452 src1_outer_stride,
453 src1_inner_stride,
454 )
455 }
456
457 #[allow(clippy::too_many_arguments)]
459 pub fn gemv_up_batched(
460 &self,
461 ctx: &mut B::Context,
462 input: &B::Buffer,
463 ids: &B::Buffer,
464 dst: &mut B::Buffer,
465 m: usize,
466 top_k: usize,
467 src1_outer_stride: usize,
468 src1_inner_stride: usize,
469 ) -> Result<()> {
470 let weight = self.up_stacked.as_deref().ok_or_else(|| {
471 FerrumError::unsupported("ExpertStack::gemv_up_batched: up_stacked not loaded")
472 })?;
473 weight.gemv_moe_id_batched(
474 ctx,
475 input,
476 ids,
477 dst,
478 m,
479 top_k,
480 src1_outer_stride,
481 src1_inner_stride,
482 )
483 }
484
485 #[allow(clippy::too_many_arguments)]
487 pub fn gemv_down_batched(
488 &self,
489 ctx: &mut B::Context,
490 input: &B::Buffer,
491 ids: &B::Buffer,
492 dst: &mut B::Buffer,
493 m: usize,
494 top_k: usize,
495 src1_outer_stride: usize,
496 src1_inner_stride: usize,
497 ) -> Result<()> {
498 let weight = self.down_stacked.as_deref().ok_or_else(|| {
499 FerrumError::unsupported("ExpertStack::gemv_down_batched: down_stacked not loaded")
500 })?;
501 weight.gemv_moe_id_batched(
502 ctx,
503 input,
504 ids,
505 dst,
506 m,
507 top_k,
508 src1_outer_stride,
509 src1_inner_stride,
510 )
511 }
512
513 #[allow(clippy::too_many_arguments)]
516 pub fn gemv_gate_up_silu_batched_fused(
517 &self,
518 ctx: &mut B::Context,
519 input: &B::Buffer,
520 ids: &B::Buffer,
521 silu_out: &mut B::Buffer,
522 m: usize,
523 top_k: usize,
524 src1_outer_stride: usize,
525 src1_inner_stride: usize,
526 ) -> Result<()> {
527 let gate = self.gate_stacked.as_deref().ok_or_else(|| {
528 FerrumError::unsupported(
529 "ExpertStack::gemv_gate_up_silu_batched_fused: gate_stacked not loaded",
530 )
531 })?;
532 let up = self.up_stacked.as_deref().ok_or_else(|| {
533 FerrumError::unsupported(
534 "ExpertStack::gemv_gate_up_silu_batched_fused: up_stacked not loaded",
535 )
536 })?;
537 gate.gemv_moe_id_gate_up_silu_batched(
538 ctx,
539 input,
540 up,
541 ids,
542 silu_out,
543 m,
544 top_k,
545 src1_outer_stride,
546 src1_inner_stride,
547 )
548 }
549
550 #[allow(clippy::too_many_arguments)]
558 pub fn gemv_gate_offset(
559 &self,
560 ctx: &mut B::Context,
561 src1: &B::Buffer,
562 src1_offset: usize,
563 ids: &B::Buffer,
564 ids_offset: usize,
565 dst: &mut B::Buffer,
566 top_k: usize,
567 src1_stride: usize,
568 ) -> Result<()> {
569 let weight = self.gate_stacked.as_deref().ok_or_else(|| {
570 FerrumError::unsupported("ExpertStack::gemv_gate_offset: gate_stacked not loaded")
571 })?;
572 weight.gemv_moe_id_offset(
573 ctx,
574 src1,
575 src1_offset,
576 ids,
577 ids_offset,
578 dst,
579 top_k,
580 src1_stride,
581 )
582 }
583
584 #[allow(clippy::too_many_arguments)]
586 pub fn gemv_up_offset(
587 &self,
588 ctx: &mut B::Context,
589 src1: &B::Buffer,
590 src1_offset: usize,
591 ids: &B::Buffer,
592 ids_offset: usize,
593 dst: &mut B::Buffer,
594 top_k: usize,
595 src1_stride: usize,
596 ) -> Result<()> {
597 let weight = self.up_stacked.as_deref().ok_or_else(|| {
598 FerrumError::unsupported("ExpertStack::gemv_up_offset: up_stacked not loaded")
599 })?;
600 weight.gemv_moe_id_offset(
601 ctx,
602 src1,
603 src1_offset,
604 ids,
605 ids_offset,
606 dst,
607 top_k,
608 src1_stride,
609 )
610 }
611
612 #[allow(clippy::too_many_arguments)]
614 pub fn gemv_down_offset(
615 &self,
616 ctx: &mut B::Context,
617 src1: &B::Buffer,
618 src1_offset: usize,
619 ids: &B::Buffer,
620 ids_offset: usize,
621 dst: &mut B::Buffer,
622 top_k: usize,
623 src1_stride: usize,
624 ) -> Result<()> {
625 let weight = self.down_stacked.as_deref().ok_or_else(|| {
626 FerrumError::unsupported("ExpertStack::gemv_down_offset: down_stacked not loaded")
627 })?;
628 weight.gemv_moe_id_offset(
629 ctx,
630 src1,
631 src1_offset,
632 ids,
633 ids_offset,
634 dst,
635 top_k,
636 src1_stride,
637 )
638 }
639}
640
641impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
642 pub fn from_dense_stacks(
649 gate_stack: &[f32],
650 up_stack: &[f32],
651 down_stack: &[f32],
652 num_experts: usize,
653 hidden_size: usize,
654 expert_intermediate: usize,
655 ) -> Result<Self> {
656 let gate_up_per_expert = expert_intermediate * hidden_size;
657 let down_per_expert = hidden_size * expert_intermediate;
658
659 check_size(
660 gate_stack.len(),
661 num_experts * gate_up_per_expert,
662 "gate_stack",
663 )?;
664 check_size(up_stack.len(), num_experts * gate_up_per_expert, "up_stack")?;
665 check_size(
666 down_stack.len(),
667 num_experts * down_per_expert,
668 "down_stack",
669 )?;
670
671 let mut gate_up = Vec::with_capacity(num_experts);
672 let mut down = Vec::with_capacity(num_experts);
673 for e in 0..num_experts {
674 let g_off = e * gate_up_per_expert;
675 let g_slice = &gate_stack[g_off..g_off + gate_up_per_expert];
676 let u_slice = &up_stack[g_off..g_off + gate_up_per_expert];
677
678 let mut fused = Vec::with_capacity(2 * gate_up_per_expert);
683 fused.extend_from_slice(g_slice);
684 fused.extend_from_slice(u_slice);
685 gate_up.push(Box::new(DenseLinear::<B>::from_rows(
686 &fused,
687 2 * expert_intermediate,
688 hidden_size,
689 )) as Box<dyn Linear<B>>);
690
691 let d_off = e * down_per_expert;
692 let d_slice = &down_stack[d_off..d_off + down_per_expert];
693 down.push(Box::new(DenseLinear::<B>::from_rows(
694 d_slice,
695 hidden_size,
696 expert_intermediate,
697 )) as Box<dyn Linear<B>>);
698 }
699 Ok(Self {
700 gate_up,
701 down,
702 gate_stacked: None,
703 up_stacked: None,
704 down_stacked: None,
705 gate_up_marlin_stack: None,
706 down_marlin_stack: None,
707 })
708 }
709
710 pub fn load_from_gguf(
730 gguf: &GgufFile,
731 layer_idx: usize,
732 num_experts: usize,
733 hidden_size: usize,
734 expert_intermediate: usize,
735 ) -> Result<Self> {
736 let runtime_config = moe_dispatch_runtime_config();
737 if let Some(quant) = Self::try_load_quantised(
738 gguf,
739 layer_idx,
740 num_experts,
741 hidden_size,
742 expert_intermediate,
743 )? {
744 if runtime_config.moe_load_trace {
745 eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
746 }
747 return Ok(quant);
748 }
749
750 if runtime_config.moe_load_trace {
751 eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
752 }
753
754 let device = Device::Cpu;
755 let gate = read_dequant_flat(
756 gguf,
757 &format!("blk.{layer_idx}.ffn_gate_exps.weight"),
758 &device,
759 )?;
760 let up = read_dequant_flat(
761 gguf,
762 &format!("blk.{layer_idx}.ffn_up_exps.weight"),
763 &device,
764 )?;
765 let down = read_dequant_flat(
766 gguf,
767 &format!("blk.{layer_idx}.ffn_down_exps.weight"),
768 &device,
769 )?;
770 Self::from_dense_stacks(
773 &gate,
774 &up,
775 &down,
776 num_experts,
777 hidden_size,
778 expert_intermediate,
779 )
780 }
781
782 fn try_load_quantised(
788 gguf: &GgufFile,
789 layer_idx: usize,
790 num_experts: usize,
791 hidden_size: usize,
792 expert_intermediate: usize,
793 ) -> Result<Option<Self>> {
794 let device = Device::Cpu;
795
796 let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
797 let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
798 let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");
799
800 let gate_kind = match quant_kind(gguf, &gate_name)? {
804 Some(k) => k,
805 None => return Ok(None),
806 };
807 let up_kind = match quant_kind(gguf, &up_name)? {
808 Some(k) => k,
809 None => return Ok(None),
810 };
811 let down_kind = match quant_kind(gguf, &down_name)? {
812 Some(k) => k,
813 None => return Ok(None),
814 };
815
816 let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
826 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
827 })?;
828 let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
829 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
830 })?;
831 let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
832 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
833 })?;
834 let _ = device; let gate_per = block_bytes_for(
840 gate_kind,
841 expert_intermediate * hidden_size,
842 "ffn_gate_exps",
843 )?;
844 let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
845 let down_per = block_bytes_for(
846 down_kind,
847 hidden_size * expert_intermediate,
848 "ffn_down_exps",
849 )?;
850
851 check_size(
852 gate_bytes.len(),
853 num_experts * gate_per,
854 "ffn_gate_exps bytes",
855 )?;
856 check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
857 check_size(
858 down_bytes.len(),
859 num_experts * down_per,
860 "ffn_down_exps bytes",
861 )?;
862
863 let gate_stacked = B::load_quant_experts(
871 gate_kind,
872 gate_bytes,
873 num_experts,
874 expert_intermediate,
875 hidden_size,
876 )
877 .ok();
878 let up_stacked = B::load_quant_experts(
879 up_kind,
880 up_bytes,
881 num_experts,
882 expert_intermediate,
883 hidden_size,
884 )
885 .ok();
886 let down_stacked = B::load_quant_experts(
887 down_kind,
888 down_bytes,
889 num_experts,
890 hidden_size,
891 expert_intermediate,
892 )
893 .ok();
894
895 let stacked_complete =
903 gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
904
905 let (gate_up, down) = if stacked_complete {
906 (Vec::new(), Vec::new())
909 } else {
910 let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
911 let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
912 for e in 0..num_experts {
913 let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
914 let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
915 let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
916
917 let parts: [(GgufQuantType, &[u8], usize); 2] = [
918 (gate_kind, g_slice, expert_intermediate),
919 (up_kind, u_slice, expert_intermediate),
920 ];
921 let gate_up_e = match QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) {
922 Ok(q) => q,
923 Err(_) => return Ok(None),
924 };
925 gate_up.push(Box::new(gate_up_e) as Box<dyn Linear<B>>);
926
927 let down_e = match QuantLinear::<B>::from_gguf_bytes(
928 down_kind,
929 d_slice,
930 hidden_size,
931 expert_intermediate,
932 ) {
933 Ok(q) => q,
934 Err(_) => return Ok(None),
935 };
936 down.push(Box::new(down_e) as Box<dyn Linear<B>>);
937 }
938 (gate_up, down)
939 };
940
941 Ok(Some(Self {
942 gate_up,
943 down,
944 gate_stacked,
945 up_stacked,
946 down_stacked,
947 gate_up_marlin_stack: None,
948 down_marlin_stack: None,
949 }))
950 }
951
952 pub fn open_and_load(
956 path: impl AsRef<Path>,
957 layer_idx: usize,
958 num_experts: usize,
959 hidden_size: usize,
960 expert_intermediate: usize,
961 ) -> Result<Self> {
962 let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
963 Self::load_from_gguf(
964 &gguf,
965 layer_idx,
966 num_experts,
967 hidden_size,
968 expert_intermediate,
969 )
970 }
971
972 pub fn num_experts(&self) -> usize {
980 debug_assert_eq!(
981 self.gate_up.len(),
982 self.down.len(),
983 "ExpertStack: gate_up and down disagree on expert count"
984 );
985 self.gate_up.len()
986 }
987}
988
/// Arguments for [`moe_forward`], the token-by-token MoE layer evaluation.
pub struct MoeForwardParams<'a, B: QuantLlmBackend + BackendMoeFused> {
    /// Backend context passed to every backend call.
    pub ctx: &'a mut B::Context,
    /// Input activations; token `b` is read as `hidden_size` elements starting
    /// at element offset `b * hidden_size`.
    pub x: &'a B::Buffer,
    /// Router logits; `batch * num_experts` values are copied to the host.
    pub router_logits: &'a B::Buffer,
    /// Output activations; token `b` is written at element offset `b * hidden_size`.
    pub out: &'a mut B::Buffer,
    /// Number of tokens to process.
    pub batch: usize,
    /// Number of elements per token in `x` and `out`.
    pub hidden_size: usize,
    /// Intermediate width, passed to `B::fused_silu_mul_split`.
    pub expert_intermediate: usize,
    /// Expected expert count; must equal `experts.num_experts()`.
    pub num_experts: usize,
    /// Number of experts evaluated per token.
    pub top_k: usize,
    /// Forwarded to the host router (presumably renormalises the selected
    /// top-k probabilities; confirm against `router::route`).
    pub norm_topk_prob: bool,
    /// Expert weights; this path uses the per-expert `gate_up` and `down` layers.
    pub experts: &'a ExpertStack<B>,
    /// Scratch buffer holding the current token's input row.
    pub x_single: &'a mut B::Buffer,
    /// Scratch accumulator for the weighted sum of expert outputs of one token.
    pub acc_buf: &'a mut B::Buffer,
    /// Scratch buffer receiving the fused gate/up projection output.
    pub gate_up_buf: &'a mut B::Buffer,
    /// Scratch buffer receiving the `fused_silu_mul_split` output.
    pub silu_buf: &'a mut B::Buffer,
    /// Scratch buffer receiving the down projection output.
    pub down_buf: &'a mut B::Buffer,
    /// Source copied into `acc_buf` before each token to reset it; presumably
    /// holds at least `hidden_size` zeros — confirm at the call site.
    pub zero_hidden: &'a B::Buffer,
}
1045
1046pub fn moe_forward<B: QuantLlmBackend + BackendMoeFused>(
1047 params: MoeForwardParams<'_, B>,
1048) -> Result<()> {
1049 let MoeForwardParams {
1050 ctx,
1051 x,
1052 router_logits,
1053 out,
1054 batch,
1055 hidden_size,
1056 expert_intermediate,
1057 num_experts,
1058 top_k,
1059 norm_topk_prob,
1060 experts,
1061 x_single,
1062 acc_buf,
1063 gate_up_buf,
1064 silu_buf,
1065 down_buf,
1066 zero_hidden,
1067 } = params;
1068 let n_experts = experts.num_experts();
1069 if n_experts != num_experts {
1070 return Err(FerrumError::model(format!(
1071 "moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
1072 )));
1073 }
1074
1075 let prof = moe_profile_enabled();
1076
1077 let t0 = if prof {
1081 Some(std::time::Instant::now())
1082 } else {
1083 None
1084 };
1085 B::sync(ctx);
1086 if let Some(t) = t0 {
1087 MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1088 MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
1089 }
1090
1091 let t0 = if prof {
1092 Some(std::time::Instant::now())
1093 } else {
1094 None
1095 };
1096 let logits_host = B::to_vec(router_logits, batch * num_experts);
1097 let route_out =
1098 crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
1099 if let Some(t) = t0 {
1100 MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1101 MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
1102 }
1103
1104 for b in 0..batch {
1105 let t0 = if prof {
1107 Some(std::time::Instant::now())
1108 } else {
1109 None
1110 };
1111 B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
1112 B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
1113 if let Some(t) = t0 {
1114 MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1115 MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
1116 }
1117
1118 for k in 0..top_k {
1119 let pair = b * top_k + k;
1120 let expert_id = route_out.expert_ids[pair] as usize;
1121 let weight = route_out.expert_weights[pair];
1122 if expert_id >= num_experts {
1123 return Err(FerrumError::model(format!(
1124 "moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
1125 )));
1126 }
1127
1128 let t0 = if prof {
1130 B::sync(ctx);
1131 Some(std::time::Instant::now())
1132 } else {
1133 None
1134 };
1135 experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
1136 if let Some(t) = t0 {
1137 B::sync(ctx);
1138 MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1139 MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
1140 }
1141
1142 let t0 = if prof {
1144 Some(std::time::Instant::now())
1145 } else {
1146 None
1147 };
1148 B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
1149 if let Some(t) = t0 {
1150 B::sync(ctx);
1151 MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1152 MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
1153 }
1154
1155 let t0 = if prof {
1157 Some(std::time::Instant::now())
1158 } else {
1159 None
1160 };
1161 experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
1162 if let Some(t) = t0 {
1163 B::sync(ctx);
1164 MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1165 MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
1166 }
1167
1168 let t0 = if prof {
1170 Some(std::time::Instant::now())
1171 } else {
1172 None
1173 };
1174 B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
1175 if let Some(t) = t0 {
1176 B::sync(ctx);
1177 MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1178 MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
1179 }
1180 }
1181
1182 let t0 = if prof {
1184 Some(std::time::Instant::now())
1185 } else {
1186 None
1187 };
1188 B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
1189 if let Some(t) = t0 {
1190 MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1191 MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
1192 }
1193 }
1194
1195 Ok(())
1196}
1197
/// Largest MoE block size used by this module: it equals both the largest
/// value accepted for the block-size environment overrides and the largest
/// candidate tried by the automatic block-size selection.
pub const MOE_BLOCK_SIZE_MAX: usize = 64;
1201
1202fn pick_moe_block_size(
1217 plan: Option<&MoeBucketPlan>,
1218 num_experts: usize,
1219 use_device_route: bool,
1220 total_pairs: usize,
1221) -> usize {
1222 pick_moe_block_size_with_config(
1223 moe_dispatch_runtime_config(),
1224 plan,
1225 num_experts,
1226 use_device_route,
1227 total_pairs,
1228 )
1229}
1230
1231fn pick_moe_block_size_with_config(
1232 config: &MoeDispatchRuntimeConfig,
1233 plan: Option<&MoeBucketPlan>,
1234 num_experts: usize,
1235 use_device_route: bool,
1236 total_pairs: usize,
1237) -> usize {
1238 const CANDIDATES: &[usize] = &[64, 32, 16];
1239 const PADDING_BUDGET: f64 = 1.30; if let Some(bs) = config.moe_block_size {
1244 return bs;
1245 }
1246 if use_device_route {
1247 if let Some(bs) = config.moe_large_m_block_size {
1248 if total_pairs >= config.moe_large_m_min_pairs {
1249 return bs;
1250 }
1251 }
1252 return 16;
1261 }
1262 let Some(plan) = plan else {
1263 return 16;
1264 };
1265 let m_e: Vec<usize> = (0..num_experts)
1266 .map(|e| plan.expert_offsets[e + 1] - plan.expert_offsets[e])
1267 .collect();
1268 let total_actual: usize = m_e.iter().sum();
1269 if total_actual == 0 {
1270 return 16;
1271 }
1272 for &bs in CANDIDATES {
1273 let total_padded: usize = m_e.iter().map(|&m| m.div_ceil(bs) * bs).sum();
1274 if (total_padded as f64) <= (total_actual as f64) * PADDING_BUDGET {
1275 return bs;
1276 }
1277 }
1278 16
1279}
1280
1281pub struct MoeBucketPlan {
1286 pub expert_offsets: Vec<usize>,
1290 pub packed_token_idx: Vec<u32>,
1293 pub pairs_by_token: Vec<i32>,
1297 pub pair_weights: Vec<f32>,
1301 cursors: Vec<usize>,
1305}
1306
1307impl MoeBucketPlan {
1308 pub fn empty() -> Self {
1312 Self {
1313 expert_offsets: Vec::new(),
1314 packed_token_idx: Vec::new(),
1315 pairs_by_token: Vec::new(),
1316 pair_weights: Vec::new(),
1317 cursors: Vec::new(),
1318 }
1319 }
1320
1321 pub fn build(route: &RouterOutput, batch: usize, num_experts: usize, top_k: usize) -> Self {
1324 let mut p = Self::empty();
1325 p.rebuild_into(route, batch, num_experts, top_k);
1326 p
1327 }
1328
1329 pub fn rebuild_into(
1334 &mut self,
1335 route: &RouterOutput,
1336 batch: usize,
1337 num_experts: usize,
1338 top_k: usize,
1339 ) {
1340 debug_assert_eq!(route.expert_ids.len(), batch * top_k);
1341 debug_assert_eq!(route.expert_weights.len(), batch * top_k);
1342 let total_pairs = batch * top_k;
1343
1344 self.expert_offsets.clear();
1345 self.expert_offsets.resize(num_experts + 1, 0);
1346 self.packed_token_idx.clear();
1347 self.packed_token_idx.resize(total_pairs, 0);
1348 self.pairs_by_token.clear();
1349 self.pairs_by_token.resize(total_pairs, -1);
1350
1351 for &eid in &route.expert_ids {
1355 self.expert_offsets[eid as usize + 1] += 1;
1356 }
1357
1358 for e in 0..num_experts {
1360 self.expert_offsets[e + 1] += self.expert_offsets[e];
1361 }
1362
1363 self.cursors.clear();
1368 self.cursors
1369 .extend_from_slice(&self.expert_offsets[..num_experts]);
1370
1371 for b in 0..batch {
1372 for k in 0..top_k {
1373 let pair_flat = b * top_k + k;
1374 let eid = route.expert_ids[pair_flat] as usize;
1375 let slot = self.cursors[eid];
1376 self.cursors[eid] += 1;
1377 self.packed_token_idx[slot] = b as u32;
1378 self.pairs_by_token[pair_flat] = slot as i32;
1379 }
1380 }
1381
1382 self.pair_weights.clear();
1385 self.pair_weights.extend_from_slice(&route.expert_weights);
1386 }
1387}
1388
1389pub struct MoeRouteScratch {
1397 pub output: RouterOutput,
1398 pub probs: Vec<f32>,
1401 pub plan: MoeBucketPlan,
1402}
1403
1404impl MoeRouteScratch {
1405 pub fn new() -> Self {
1406 Self {
1407 output: RouterOutput::empty(),
1408 probs: Vec::new(),
1409 plan: MoeBucketPlan::empty(),
1410 }
1411 }
1412}
1413
1414impl Default for MoeRouteScratch {
1415 fn default() -> Self {
1416 Self::new()
1417 }
1418}
1419
/// Device-side buffers used when routing runs on the device instead of the host.
pub struct DeviceRouteScratch<'a, B: crate::moe::dispatch::Backend> {
    /// Selected expert ids; written by `B::route_topk_softmax`.
    pub selected_ids: &'a mut B::Buffer,
    /// Routing weights of the selected experts; written by `B::route_topk_softmax`.
    pub pair_weights: &'a mut B::Buffer,
    /// Passed to `B::moe_build_pairs_by_token` (skipped in pair-id mode).
    pub pairs_by_token: &'a mut B::Buffer,
    /// Passed to `B::moe_build_pairs_by_token` (skipped in pair-id mode).
    pub packed_token_idx: &'a mut B::Buffer,
    /// Passed to `B::moe_build_pairs_by_token` (skipped in pair-id mode).
    pub expert_offsets: &'a mut B::Buffer,
    /// Not used in the reviewed range; presumably the block-aligned sorted
    /// token list — confirm.
    pub sorted_tokens: &'a mut B::Buffer,
    /// Not used in the reviewed range; presumably the expert id of each
    /// block — confirm.
    pub block_ids: &'a mut B::Buffer,
    /// Not used in the reviewed range; presumably the padded token total —
    /// confirm.
    pub total_post_pad: &'a mut B::Buffer,
}
1441
/// Arguments for [`moe_forward_bucketed`], the expert-bucketed MoE layer evaluation.
pub struct MoeForwardBucketedParams<'a, B: QuantLlmBackend + BackendMoeFused> {
    /// Backend context passed to every backend call.
    pub ctx: &'a mut B::Context,
    /// Input activations (presumably `batch * hidden_size` elements — confirm).
    pub x: &'a B::Buffer,
    /// Router logits, `batch * num_experts` values.
    pub router_logits: &'a B::Buffer,
    /// Output activations (presumably `batch * hidden_size` elements — confirm).
    pub out: &'a mut B::Buffer,
    /// Number of tokens to process.
    pub batch: usize,
    /// Number of elements per token.
    pub hidden_size: usize,
    /// Intermediate width of each expert.
    pub expert_intermediate: usize,
    /// Expected expert count; must equal `experts.num_experts()`.
    pub num_experts: usize,
    /// Number of experts selected per token.
    pub top_k: usize,
    /// Forwarded to the routing kernels / host router.
    pub norm_topk_prob: bool,
    /// Expert weights.
    pub experts: &'a ExpertStack<B>,
    /// Scratch buffer; presumably the inputs gathered in expert order — confirm.
    pub x_packed: &'a mut B::Buffer,
    /// Scratch buffer; presumably the gate/up outputs in expert order — confirm.
    pub gate_up_packed: &'a mut B::Buffer,
    /// Scratch buffer; presumably the SiLU-multiply outputs in expert order — confirm.
    pub silu_packed: &'a mut B::Buffer,
    /// Scratch buffer; presumably the down-projection outputs in expert order — confirm.
    pub down_packed: &'a mut B::Buffer,
    /// Host-side routing scratch, used when routing is not done on the device.
    pub route_scratch: &'a mut MoeRouteScratch,
    /// Device-side routing buffers. Device routing is only used when this is
    /// `Some`, the vLLM MoE switch is on, and host routing is not forced.
    pub device_route: Option<DeviceRouteScratch<'a, B>>,
}
1481
1482pub fn moe_forward_bucketed<B: QuantLlmBackend + BackendMoeFused>(
1483 params: MoeForwardBucketedParams<'_, B>,
1484) -> Result<()> {
1485 let MoeForwardBucketedParams {
1486 ctx,
1487 x,
1488 router_logits,
1489 out,
1490 batch,
1491 hidden_size,
1492 expert_intermediate,
1493 num_experts,
1494 top_k,
1495 norm_topk_prob,
1496 experts,
1497 x_packed,
1498 gate_up_packed,
1499 silu_packed,
1500 down_packed,
1501 route_scratch,
1502 device_route,
1503 } = params;
1504 if experts.num_experts() != num_experts {
1505 return Err(FerrumError::model(format!(
1506 "moe_forward_bucketed: experts {} != num_experts {num_experts}",
1507 experts.num_experts()
1508 )));
1509 }
1510
1511 let runtime_config = moe_dispatch_runtime_config();
1512 let prof = runtime_config.moe_profile || runtime_config.decode_op_profile;
1515 if prof {
1516 MOE_BUCKET_LAYER_CALLS.fetch_add(1, Ordering::Relaxed);
1517 }
1518
1519 let use_vllm_moe = runtime_config.vllm_moe;
1533 let use_device_route = device_route.is_some() && use_vllm_moe && !runtime_config.moe_host_route;
1549 let use_vllm_pair_ids = use_device_route && runtime_config.vllm_moe_pair_ids;
1550
1551 let mut dr_kept: Option<DeviceRouteScratch<'_, B>> = if use_device_route {
1558 let dr = device_route.expect("device_route is Some when use_device_route");
1559 B::route_topk_softmax(
1560 ctx,
1561 router_logits,
1562 dr.selected_ids,
1563 dr.pair_weights,
1564 batch,
1565 num_experts,
1566 top_k,
1567 norm_topk_prob,
1568 )?;
1569 if !use_vllm_pair_ids {
1570 B::moe_build_pairs_by_token(
1571 ctx,
1572 dr.selected_ids,
1573 dr.pairs_by_token,
1574 dr.packed_token_idx,
1575 dr.expert_offsets,
1576 batch * top_k,
1577 num_experts,
1578 top_k,
1579 )?;
1580 }
1581 Some(dr)
1582 } else {
1583 None
1584 };
1585
1586 let plan: Option<&crate::moe::MoeBucketPlan> = if !use_device_route {
1598 let t_route_total = if prof {
1599 Some(std::time::Instant::now())
1600 } else {
1601 None
1602 };
1603 let gpu_route = B::try_gpu_route_topk_into_host(
1604 ctx,
1605 router_logits,
1606 &mut route_scratch.output.expert_ids,
1607 &mut route_scratch.output.expert_weights,
1608 batch,
1609 num_experts,
1610 top_k,
1611 norm_topk_prob,
1612 );
1613 if gpu_route.is_err() {
1614 let t_sync = if prof {
1615 Some(std::time::Instant::now())
1616 } else {
1617 None
1618 };
1619 B::sync(ctx);
1620 if let Some(t) = t_sync {
1621 MOE_BUCKET_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1622 }
1623 let t_d2h = if prof {
1624 Some(std::time::Instant::now())
1625 } else {
1626 None
1627 };
1628 let logits_host = B::to_vec(router_logits, batch * num_experts);
1629 if let Some(t) = t_d2h {
1630 MOE_BUCKET_D2H_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1631 }
1632 let t_route = if prof {
1633 Some(std::time::Instant::now())
1634 } else {
1635 None
1636 };
1637 crate::moe::router::route_into(
1638 &logits_host,
1639 batch,
1640 num_experts,
1641 top_k,
1642 norm_topk_prob,
1643 &mut route_scratch.output,
1644 &mut route_scratch.probs,
1645 );
1646 if let Some(t) = t_route {
1647 MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1648 }
1649 } else if let Some(t) = t_route_total {
1650 MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1651 }
1652 let t_plan = if prof {
1653 Some(std::time::Instant::now())
1654 } else {
1655 None
1656 };
1657 route_scratch
1658 .plan
1659 .rebuild_into(&route_scratch.output, batch, num_experts, top_k);
1660 if let Some(t) = t_plan {
1661 MOE_BUCKET_PLAN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1662 }
1663 Some(&route_scratch.plan)
1664 } else {
1665 None
1666 };
1667
1668 if !use_vllm_pair_ids {
1672 let t_gather = if prof {
1673 Some(std::time::Instant::now())
1674 } else {
1675 None
1676 };
1677 if let Some(ref dr) = dr_kept {
1678 B::embedding_lookup_dev(
1679 ctx,
1680 x,
1681 dr.packed_token_idx,
1682 x_packed,
1683 batch * top_k,
1684 hidden_size,
1685 );
1686 } else {
1687 let plan = plan.expect("plan is Some when !use_device_route");
1688 B::embedding_lookup(ctx, x, &plan.packed_token_idx, x_packed, hidden_size);
1689 }
1690 if let Some(t) = t_gather {
1691 B::sync(ctx);
1692 MOE_BUCKET_GATHER_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1693 }
1694 }
1695
1696 let gate_up_dim_per_expert = 2 * expert_intermediate;
1704 let down_n_per_expert = hidden_size;
1705 let gu_store = experts.gate_up_stacked_store(0).ok_or_else(|| {
1711 FerrumError::model(
1712 "moe_forward_bucketed requires stacked gate_up store \
1713 (load via Qwen3MoeModel::new_safetensors)",
1714 )
1715 })?;
1716 let zero_marlin_workspace = !use_vllm_moe || runtime_config.vllm_moe_zero_ws;
1717 if zero_marlin_workspace {
1718 let _ = gu_store.zero_workspace(ctx);
1719 }
1720
1721 let total_pairs_active = batch * top_k;
1727 let max_block_size: usize = 64;
1761 let moe_block_size: usize = pick_moe_block_size_with_config(
1762 runtime_config,
1763 plan,
1764 num_experts,
1765 use_device_route,
1766 total_pairs_active,
1767 );
1768 debug_assert!(
1769 moe_block_size <= max_block_size,
1770 "moe_block_size {moe_block_size} exceeds scratch worst-case {max_block_size}"
1771 );
1772 let sorted_max_size = batch * top_k + num_experts * moe_block_size;
1780 let vllm_routing_owned: Option<ferrum_kernels::backend::MoeRouting<B>> =
1781 if use_vllm_moe && !use_device_route {
1782 let plan = plan.expect("plan is Some when host vllm builder runs");
1783 let mut padded_offsets = Vec::with_capacity(num_experts + 1);
1784 let mut acc = 0usize;
1785 for e in 0..num_experts {
1786 padded_offsets.push(acc);
1787 let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1788 let pe = m_e.div_ceil(moe_block_size) * moe_block_size;
1789 acc += pe;
1790 }
1791 padded_offsets.push(acc);
1792 let total_padded = acc;
1793 let total_blocks = total_padded / moe_block_size;
1794 let sentinel = total_pairs_active as i32;
1795
1796 let mut sorted_token_ids = vec![sentinel; total_padded];
1797 let mut expert_ids = vec![0i32; total_blocks];
1798 for e in 0..num_experts {
1799 let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1800 if m_e == 0 {
1801 continue;
1802 }
1803 let p_off = padded_offsets[e];
1804 let real_off = plan.expert_offsets[e];
1805 for i in 0..m_e {
1806 sorted_token_ids[p_off + i] = (real_off + i) as i32;
1807 }
1808 let blocks_for_e = (padded_offsets[e + 1] - p_off) / moe_block_size;
1809 let block_start = p_off / moe_block_size;
1810 for b in 0..blocks_for_e {
1811 expert_ids[block_start + b] = e as i32;
1812 }
1813 }
1814 let num_tokens_past_padded = vec![total_padded as i32];
1815 Some(B::upload_moe_routing(
1816 ctx,
1817 &sorted_token_ids,
1818 &expert_ids,
1819 &num_tokens_past_padded,
1820 )?)
1821 } else {
1822 None
1823 };
1824
1825 if use_device_route {
1829 let dr = dr_kept
1830 .as_mut()
1831 .expect("dr_kept is Some when use_device_route");
1832 if use_vllm_pair_ids {
1833 B::moe_align_block_size_pair_ids(
1834 ctx,
1835 dr.selected_ids,
1836 dr.sorted_tokens,
1837 dr.block_ids,
1838 dr.total_post_pad,
1839 batch * top_k,
1840 num_experts,
1841 moe_block_size,
1842 sorted_max_size,
1843 )?;
1844 } else {
1845 B::moe_align_block_size(
1846 ctx,
1847 dr.selected_ids,
1848 dr.sorted_tokens,
1849 dr.block_ids,
1850 dr.total_post_pad,
1851 batch * top_k,
1852 num_experts,
1853 moe_block_size,
1854 sorted_max_size,
1855 )?;
1856 }
1857 }
1858
1859 let vllm_refs: Option<(&B::Buffer, &B::Buffer, &B::Buffer)> = if use_device_route {
1864 let dr = dr_kept
1865 .as_ref()
1866 .expect("dr_kept is Some when use_device_route");
1867 Some((&*dr.sorted_tokens, &*dr.block_ids, &*dr.total_post_pad))
1868 } else if let Some(r) = vllm_routing_owned.as_ref() {
1869 Some((
1870 &r.sorted_token_ids,
1871 &r.expert_ids,
1872 &r.num_tokens_past_padded,
1873 ))
1874 } else {
1875 None
1876 };
1877
1878 let phase1_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
1882 let plan = plan.expect("plan is Some when batched GEMM path runs");
1883 let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
1884 for e in 0..num_experts {
1885 let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1886 if m_e == 0 {
1887 continue;
1888 }
1889 let pair_off = plan.expert_offsets[e];
1890 v.push((e, pair_off, pair_off, m_e));
1891 }
1892 v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
1893 v
1894 } else {
1895 Vec::new()
1896 };
1897 let t_gemm1 = if prof {
1898 Some(std::time::Instant::now())
1899 } else {
1900 None
1901 };
1902 if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
1903 if use_vllm_pair_ids {
1905 gu_store.gemm_phase_vllm(
1906 ctx,
1907 x,
1908 sorted_tokens,
1909 block_ids,
1910 total_post_pad,
1911 gate_up_packed,
1912 batch,
1913 moe_block_size,
1914 top_k,
1915 )?;
1916 } else {
1917 gu_store.gemm_phase_vllm(
1918 ctx,
1919 x_packed,
1920 sorted_tokens,
1921 block_ids,
1922 total_post_pad,
1923 gate_up_packed,
1924 total_pairs_active,
1925 moe_block_size,
1926 1, )?;
1928 }
1929 } else {
1930 gu_store.gemm_phase_batched(
1931 ctx,
1932 x_packed,
1933 &phase1_dispatches,
1934 gate_up_packed,
1935 hidden_size,
1936 )?;
1937 }
1938 if let Some(t) = t_gemm1 {
1939 B::sync(ctx);
1940 MOE_BUCKET_GEMM1_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1941 }
1942
1943 let total_pairs_active = batch * top_k;
1949 let t_silu = if prof {
1950 Some(std::time::Instant::now())
1951 } else {
1952 None
1953 };
1954 B::fused_silu_mul_split(
1955 ctx,
1956 gate_up_packed,
1957 silu_packed,
1958 total_pairs_active,
1959 expert_intermediate,
1960 );
1961 if let Some(t) = t_silu {
1962 B::sync(ctx);
1963 MOE_BUCKET_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1964 }
1965
1966 let d_store = experts.down_stacked_store(0).ok_or_else(|| {
1968 FerrumError::model(
1969 "moe_forward_bucketed requires stacked down store \
1970 (load via Qwen3MoeModel::new_safetensors)",
1971 )
1972 })?;
1973 if zero_marlin_workspace {
1974 let _ = d_store.zero_workspace(ctx);
1975 }
1976 let phase3_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
1977 let plan = plan.expect("plan is Some when batched GEMM path runs");
1978 let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
1979 for e in 0..num_experts {
1980 let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1981 if m_e == 0 {
1982 continue;
1983 }
1984 let pair_off = plan.expert_offsets[e];
1985 v.push((e, pair_off, pair_off, m_e));
1986 }
1987 v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
1988 v
1989 } else {
1990 Vec::new()
1991 };
1992 let t_gemm3 = if prof {
1993 Some(std::time::Instant::now())
1994 } else {
1995 None
1996 };
1997 if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
1998 d_store.gemm_phase_vllm(
1999 ctx,
2000 silu_packed,
2001 sorted_tokens,
2002 block_ids,
2003 total_post_pad,
2004 down_packed,
2005 total_pairs_active,
2006 moe_block_size,
2007 1,
2008 )?;
2009 } else {
2010 d_store.gemm_phase_batched(
2011 ctx,
2012 silu_packed,
2013 &phase3_dispatches,
2014 down_packed,
2015 expert_intermediate,
2016 )?;
2017 }
2018 if let Some(t) = t_gemm3 {
2019 B::sync(ctx);
2020 MOE_BUCKET_GEMM3_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
2021 }
2022
2023 let total_pairs = batch * top_k;
2040 let t_comb = if prof {
2041 Some(std::time::Instant::now())
2042 } else {
2043 None
2044 };
2045 if use_vllm_pair_ids {
2046 let dr = dr_kept
2047 .as_ref()
2048 .expect("dr_kept is Some when use_vllm_pair_ids");
2049 B::weighted_sum_batched(
2050 ctx,
2051 down_packed,
2052 dr.pair_weights,
2053 out,
2054 batch,
2055 top_k,
2056 hidden_size,
2057 )?;
2058 } else {
2059 let (pairs_ref, weights_ref);
2060 let _pairs_owned;
2061 let _weights_owned;
2062 if let Some(ref dr) = dr_kept {
2063 pairs_ref = &*dr.pairs_by_token;
2064 weights_ref = &*dr.pair_weights;
2065 } else {
2066 let plan = plan.expect("plan is Some when host moe_combine runs");
2067 _pairs_owned = B::from_slice_typed::<i32>(&plan.pairs_by_token);
2068 _weights_owned = B::from_slice_typed::<f32>(&plan.pair_weights);
2069 pairs_ref = &_pairs_owned;
2070 weights_ref = &_weights_owned;
2071 }
2072 B::moe_combine(
2073 ctx,
2074 down_packed,
2075 pairs_ref,
2076 weights_ref,
2077 out,
2078 batch,
2079 hidden_size,
2080 top_k,
2081 total_pairs,
2082 );
2083 }
2084 if let Some(t) = t_comb {
2085 B::sync(ctx);
2086 MOE_BUCKET_COMBINE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
2087 }
2088
2089 Ok(())
2090}
2091
2092pub fn moe_forward_cpu(
2107 x: &[f32],
2108 batch: usize,
2109 hidden_size: usize,
2110 expert_intermediate: usize,
2111 top_k: usize,
2112 router: &RouterOutput,
2113 experts: &ExpertStack<CpuBackend>,
2114 out: &mut Vec<f32>,
2115) -> Result<()> {
2116 let n_experts = experts.num_experts();
2117
2118 if x.len() != batch * hidden_size {
2119 return Err(FerrumError::model(format!(
2120 "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
2121 x.len(),
2122 batch,
2123 hidden_size,
2124 batch * hidden_size
2125 )));
2126 }
2127 if router.expert_ids.len() != batch * top_k {
2128 return Err(FerrumError::model(format!(
2129 "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
2130 router.expert_ids.len(),
2131 batch,
2132 top_k,
2133 batch * top_k
2134 )));
2135 }
2136
2137 out.clear();
2138 out.resize(batch * hidden_size, 0.0);
2139
2140 let mut ctx = <CpuBackend as Backend>::new_context();
2141 let mut x_b: Vec<f32> = vec![0.0; hidden_size];
2142 let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
2143 let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
2144 let mut down_out: Vec<f32> = vec![0.0; hidden_size];
2145
2146 for b in 0..batch {
2147 x_b.copy_from_slice(&x[b * hidden_size..(b + 1) * hidden_size]);
2148
2149 for k in 0..top_k {
2150 let pair_idx = b * top_k + k;
2151 let expert_id = router.expert_ids[pair_idx] as usize;
2152 let weight = router.expert_weights[pair_idx];
2153
2154 if expert_id >= n_experts {
2155 return Err(FerrumError::model(format!(
2156 "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
2157 )));
2158 }
2159
2160 experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
2162
2163 <CpuBackend as Backend>::fused_silu_mul_split(
2165 &mut ctx,
2166 &gate_up_buf,
2167 &mut silu_mul_buf,
2168 1,
2169 expert_intermediate,
2170 );
2171
2172 experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);
2174
2175 let out_row = &mut out[b * hidden_size..(b + 1) * hidden_size];
2179 for (o, d) in out_row.iter_mut().zip(down_out.iter()) {
2180 *o += weight * *d;
2181 }
2182 }
2183 }
2184
2185 Ok(())
2186}
2187
2188fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
2189 if actual != expected {
2190 return Err(FerrumError::model(format!(
2191 "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
2192 )));
2193 }
2194 Ok(())
2195}
2196
2197fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
2201 let info = gguf.tensor_info(name).ok_or_else(|| {
2202 FerrumError::model(format!("ExpertStack: tensor info missing for '{name}'"))
2203 })?;
2204 Ok(match info.ggml_dtype {
2205 GgmlDType::Q4K => Some(GgufQuantType::Q4K),
2206 GgmlDType::Q6K => Some(GgufQuantType::Q6K),
2207 _ => None,
2208 })
2209}
2210
2211fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
2216 const QK_K: usize = 256;
2217 if n_elems % QK_K != 0 {
2218 return Err(FerrumError::model(format!(
2219 "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
2220 )));
2221 }
2222 let block_bytes = match kind {
2223 GgufQuantType::Q4K => 144,
2224 GgufQuantType::Q6K => 210,
2225 other => {
2228 return Err(FerrumError::model(format!(
2229 "ExpertStack {label}: unsupported k-quant flavour {other:?}"
2230 )))
2231 }
2232 };
2233 Ok((n_elems / QK_K) * block_bytes)
2234}
2235
2236fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
2237 let qt = gguf.read_tensor(name, device).map_err(candle_to_ferrum)?;
2238 let dense = qt.dequantize(device).map_err(candle_to_ferrum)?;
2239 let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
2240 flat.to_vec1::<f32>().map_err(candle_to_ferrum)
2241}
2242
2243fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
2244 FerrumError::model(format!("candle: {e}"))
2245}
2246
// NOTE(review): this alias appears to exist only to keep the `CandleResult`
// import referenced; nothing visible in this part of the file uses it, hence
// the `dead_code` allowance. Confirm before removing either.
#[allow(dead_code)]
type _CandleResult<T> = CandleResult<T>;
2251
#[cfg(test)]
mod tests {
    use super::{pick_moe_block_size_with_config, MoeDispatchRuntimeConfig};

    /// Every startup knob supplied through the environment list is reflected in
    /// the parsed configuration.
    #[test]
    fn moe_dispatch_runtime_config_parses_m3_startup_knobs() {
        let env = [
            ("FERRUM_MOE_PROFILE", "0"),
            ("FERRUM_DECODE_OP_PROFILE", "true"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "1"),
            ("FERRUM_VLLM_MOE_PAIR_IDS", "1"),
            ("FERRUM_MOE_LOAD_TRACE", ""),
            ("FERRUM_MOE_BLOCK_SIZE", "8"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "2048"),
            ("FERRUM_VLLM_MOE", "1"),
            ("FERRUM_MOE_HOST_ROUTE", "1"),
        ];
        let parsed = MoeDispatchRuntimeConfig::from_env_vars(env);

        // Boolean switches.
        assert!(parsed.moe_profile);
        assert!(parsed.decode_op_profile);
        assert!(parsed.vllm_moe_zero_ws);
        assert!(parsed.vllm_moe_pair_ids);
        assert!(parsed.moe_load_trace);
        // Numeric knobs.
        assert_eq!(parsed.moe_block_size, Some(8));
        assert_eq!(parsed.moe_large_m_block_size, Some(64));
        assert_eq!(parsed.moe_large_m_min_pairs, 2048);
        // Routing switches.
        assert!(parsed.vllm_moe);
        assert!(parsed.moe_host_route);
    }

    /// Out-of-range or unparsable values are dropped in favour of `None` or the
    /// default, and the listed boolean spellings parse as false.
    #[test]
    fn moe_dispatch_runtime_config_bounds_invalid_block_values() {
        let env = [
            ("FERRUM_MOE_BLOCK_SIZE", "12"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "128"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "bad"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "true"),
            ("FERRUM_MOE_HOST_ROUTE", "0"),
        ];
        let parsed = MoeDispatchRuntimeConfig::from_env_vars(env);

        assert_eq!(parsed.moe_block_size, None);
        assert_eq!(parsed.moe_large_m_block_size, None);
        assert_eq!(parsed.moe_large_m_min_pairs, 1024);
        assert!(!parsed.vllm_moe_zero_ws);
        assert!(!parsed.moe_host_route);
    }

    /// On the device-route path the large-M block size only applies once the
    /// pair count reaches the configured threshold.
    #[test]
    fn device_route_large_m_block_size_is_thresholded() {
        let parsed = MoeDispatchRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ]);

        let below_threshold = pick_moe_block_size_with_config(&parsed, None, 128, true, 256);
        assert_eq!(below_threshold, 16);

        let at_threshold = pick_moe_block_size_with_config(&parsed, None, 128, true, 1024);
        assert_eq!(at_threshold, 64);
    }

    /// An explicit global block size takes precedence over the large-M setting.
    #[test]
    fn global_moe_block_size_override_still_wins() {
        let parsed = MoeDispatchRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_BLOCK_SIZE", "32"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ]);

        let picked = pick_moe_block_size_with_config(&parsed, None, 128, true, 2048);
        assert_eq!(picked, 32);
    }
}