Skip to main content

ferrum_models/moe/
dispatch.rs

1//! Expert dispatch — load per-layer expert weights from a GGUF file and run
2//! the per-token MoE forward (top-K experts per token, weighted combine).
3//!
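//! The CPU-only `moe_forward_cpu` is the reference implementation (used by
//! parity tests; production dispatches through the backend-generic
//! `moe_forward` / `moe_forward_bucketed` paths). The algorithm is: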
6//!
7//! ```text
8//! for each token b in batch:
9//!     route token b → (expert_ids[K], weights[K])
10//!     out[b] = 0
11//!     for each (expert_id, weight) pair:
12//!         gate_up = experts.gate_up[expert_id].forward(x[b])     # [2*ffn]
13//!         silu_mul = silu(gate_up[..ffn]) * gate_up[ffn..]       # [ffn]
14//!         contribution = experts.down[expert_id].forward(silu_mul) # [hidden]
15//!         out[b] += weight * contribution
16//! ```
17//!
18//! The fused `gate || up` per-expert layout means we can call
19//! `Backend::fused_silu_mul_split` directly on the projection's output
20//! — same kernel ferrum already uses for dense Llama-family models.
21
22use std::path::Path;
23use std::sync::{
24    atomic::{AtomicU64, Ordering},
25    OnceLock,
26};
27
28use candle_core::quantized::GgmlDType;
29use candle_core::{Device, Result as CandleResult};
30use ferrum_kernels::backend::cpu::CpuBackend;
31use ferrum_kernels::backend::{
32    Backend, BackendMoeFused, BackendPagedKv, BackendQuantGguf, BackendQuantMarlin, GgufQuantType,
33    LlmBackend, QuantLlmBackend,
34};
35use ferrum_kernels::{Linear, StackedExpertGgufLinear};
36use ferrum_quantization::gguf::GgufFile;
37use ferrum_quantization::{DenseLinear, QuantLinear};
38use ferrum_types::{FerrumError, Result};
39
40use crate::moe::router::RouterOutput;
41
// MoE per-op profiling timers. Public so the model wrapper can drain and
// print them at the end of decode. Each `*_US` static is a running total in
// microseconds and each `*_CALLS` static is its paired call counter; all are
// updated atomically.
//
// Gated by `FERRUM_MOE_PROFILE`: the runtime-config parser enables profiling
// whenever that variable is *present*, whatever its value — so
// `FERRUM_MOE_PROFILE=0` turns profiling on as well. Unset it to disable.

/// Accumulated µs for the per-pair `sync` step.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_SYNC_US`].
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the per-pair `gate_up` GEMV.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_GEMV_GATE_UP_US`].
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the per-pair SiLU·mul step.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_SILU_US`].
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the per-pair `down` GEMV.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_GEMV_DOWN_US`].
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the per-pair weighted scaled-add combine.
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_SCALED_ADD_US`].
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for per-pair buffer copies.
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_COPY_US`].
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for host-side top-k routing.
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Call counter paired with [`MOE_HOST_TOPK_US`].
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);

// Bucketed-path per-phase timers (µs), drained by the model wrapper
// alongside the per-pair counters above and gated by the same
// `FERRUM_MOE_PROFILE` check. Unlike the per-pair group, this group has a
// single counter, `MOE_BUCKET_LAYER_CALLS`.

/// Accumulated µs for the bucketed path's `sync` phase.
pub static MOE_BUCKET_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's device-to-host transfer phase.
pub static MOE_BUCKET_D2H_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's routing phase.
pub static MOE_BUCKET_ROUTE_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's planning phase.
pub static MOE_BUCKET_PLAN_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's gather phase.
pub static MOE_BUCKET_GATHER_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's first expert GEMM.
pub static MOE_BUCKET_GEMM1_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's SiLU phase.
pub static MOE_BUCKET_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's second (`down`) expert GEMM.
pub static MOE_BUCKET_GEMM3_US: AtomicU64 = AtomicU64::new(0);
/// Accumulated µs for the bucketed path's combine phase.
pub static MOE_BUCKET_COMBINE_US: AtomicU64 = AtomicU64::new(0);
/// Call counter for the bucketed-path timers above (one per layer
/// invocation, judging by the name — confirm at the increment site).
pub static MOE_BUCKET_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
72
/// Runtime knobs for MoE dispatch, read from `FERRUM_*` environment
/// variables (see [`MoeDispatchRuntimeConfig::from_env_vars`] for the exact
/// variable → field mapping).
///
/// Note the two flag styles: some flags are enabled by the variable's mere
/// presence (any value, including `0`), others only by the exact value `1`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MoeDispatchRuntimeConfig {
    /// `FERRUM_MOE_PROFILE` — enabled by presence (any value).
    moe_profile: bool,
    /// `FERRUM_DECODE_OP_PROFILE` — enabled by presence (any value).
    decode_op_profile: bool,
    /// `FERRUM_VLLM_MOE_ZERO_WS` — enabled only by the value `1`.
    vllm_moe_zero_ws: bool,
    /// `FERRUM_VLLM_MOE_PAIR_IDS` — enabled only by the value `1`.
    vllm_moe_pair_ids: bool,
    /// `FERRUM_MOE_LOAD_TRACE` — enabled by presence (any value).
    moe_load_trace: bool,
    /// `FERRUM_MOE_BLOCK_SIZE`; `None` unless the value is 8/16/32/48/64.
    moe_block_size: Option<usize>,
    /// `FERRUM_MOE_LARGE_M_BLOCK_SIZE`; same accepted values as `moe_block_size`.
    moe_large_m_block_size: Option<usize>,
    /// `FERRUM_MOE_LARGE_M_MIN_PAIRS`; falls back to
    /// [`MoeDispatchRuntimeConfig::DEFAULT_LARGE_M_MIN_PAIRS`] when unset or
    /// unparsable.
    moe_large_m_min_pairs: usize,
    /// `FERRUM_VLLM_MOE` — enabled only by the value `1`.
    vllm_moe: bool,
    /// `FERRUM_MOE_HOST_ROUTE` — enabled only by the value `1`.
    moe_host_route: bool,
}

impl Default for MoeDispatchRuntimeConfig {
    /// Everything off, no block-size overrides, and the default
    /// large-M pair threshold.
    fn default() -> Self {
        Self {
            moe_profile: false,
            decode_op_profile: false,
            vllm_moe_zero_ws: false,
            vllm_moe_pair_ids: false,
            moe_load_trace: false,
            moe_block_size: None,
            moe_large_m_block_size: None,
            moe_large_m_min_pairs: Self::DEFAULT_LARGE_M_MIN_PAIRS,
            vllm_moe: false,
            moe_host_route: false,
        }
    }
}

impl MoeDispatchRuntimeConfig {
    /// Default for `moe_large_m_min_pairs`, also used as the fallback when
    /// `FERRUM_MOE_LARGE_M_MIN_PAIRS` does not parse as a `usize`.
    const DEFAULT_LARGE_M_MIN_PAIRS: usize = 1024;

    /// Parses the config from the current process environment.
    ///
    /// Iterates `std::env::vars_os` rather than `std::env::vars`: the latter
    /// panics mid-iteration if *any* variable in the environment (not just a
    /// `FERRUM_*` one) has a non-UTF-8 key or value. Non-UTF-8 keys can never
    /// equal a `FERRUM_*` name, so they are skipped; values are decoded
    /// lossily, which leaves every valid UTF-8 value byte-for-byte intact.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().filter_map(|(key, value)| {
            key.into_string()
                .ok()
                .map(|key| (key, value.to_string_lossy().into_owned()))
        }))
    }

    /// Builds the config from an arbitrary `(name, value)` sequence, starting
    /// from [`Default`] and applying each recognised `FERRUM_*` variable.
    /// Unrecognised names are ignored; if a name repeats, the last one wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self::default();
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                // Presence-only flags: the value is deliberately not inspected.
                "FERRUM_MOE_PROFILE" => config.moe_profile = true,
                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
                "FERRUM_MOE_LOAD_TRACE" => config.moe_load_trace = true,
                // Exact-`"1"` flags.
                "FERRUM_VLLM_MOE_ZERO_WS" => config.vllm_moe_zero_ws = value == "1",
                "FERRUM_VLLM_MOE_PAIR_IDS" => config.vllm_moe_pair_ids = value == "1",
                "FERRUM_VLLM_MOE" => config.vllm_moe = value == "1",
                "FERRUM_MOE_HOST_ROUTE" => config.moe_host_route = value == "1",
                // Numeric knobs.
                "FERRUM_MOE_BLOCK_SIZE" => {
                    config.moe_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_BLOCK_SIZE" => {
                    config.moe_large_m_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_MIN_PAIRS" => {
                    config.moe_large_m_min_pairs = value
                        .parse::<usize>()
                        .unwrap_or(Self::DEFAULT_LARGE_M_MIN_PAIRS);
                }
                _ => {}
            }
        }
        config
    }
}

/// Parses a MoE block-size override. Only the block sizes in the allow-list
/// below are accepted; anything else (including unparsable text) yields
/// `None`, i.e. "no override".
fn parse_moe_block_size_value(value: &str) -> Option<usize> {
    value
        .parse::<usize>()
        .ok()
        .filter(|bs| matches!(*bs, 8 | 16 | 32 | 48 | 64))
}
148
149fn moe_dispatch_runtime_config() -> &'static MoeDispatchRuntimeConfig {
150    static CONFIG: OnceLock<MoeDispatchRuntimeConfig> = OnceLock::new();
151    CONFIG.get_or_init(MoeDispatchRuntimeConfig::from_env)
152}
153
154fn moe_profile_enabled() -> bool {
155    moe_dispatch_runtime_config().moe_profile
156}
157
158/// Per-layer expert weights, materialised as `[num_experts]`-long vectors
159/// of `Box<dyn Linear<B>>`. Each entry runs the corresponding expert's
160/// fused `[gate; up]` projection or its `down` projection.
161///
162/// `B::Buffer` is hidden behind `Linear<B>` so this struct is generic
163/// over backend. Production (`Qwen3MoeModel::forward`) dispatches through
164/// the generic [`moe_forward<B>`] (this file, line ~960) and
165/// [`moe_forward_bucketed<B>`]; the CPU-only `moe_forward_cpu` is the
166/// reference path used by parity tests + `Qwen3MoeLayer::forward_cpu`.
167pub struct ExpertStack<B: QuantLlmBackend + BackendMoeFused> {
168    /// Fused `[gate; up]` projection per expert. Output shape per token:
169    /// `[2 * expert_intermediate]` — the lower half is gate, upper is up.
170    pub gate_up: Vec<Box<dyn Linear<B>>>,
171    /// `down` projection per expert. Output shape per token: `[hidden_size]`.
172    pub down: Vec<Box<dyn Linear<B>>>,
173    /// Stacked-experts representation for backends that have a batched
174    /// MoE indirect-dispatch kernel (Metal `gemv_q4kw_moe_id_f32` /
175    /// `gemv_q6kw_moe_id_f32`). Holds **all experts** for one matmul
176    /// role behind a `StackedExpertGgufLinear<B>` (typically backed by a
177    /// single GPU buffer with byte stride between expert slabs), so a
178    /// single dispatch can cover all selected (token, expert) pairs at
179    /// decode m=1.
180    ///
181    /// `None` on backends without the kernel (CPU, CUDA-without-MoE-kernel)
182    /// and on quant flavours that don't have a stacked path yet — callers
183    /// fall back to the per-expert `gate_up` / `down` Linears in those
184    /// cases.
185    pub gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
186    pub up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
187    pub down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
188
189    /// Stacked Marlin GPTQ expert tiles for the bucketed CUDA path.
190    /// When both are Some, [`moe_forward_bucketed`] dispatches expert
191    /// GEMMs through trait-object methods (`store.gemm_phase_*` /
192    /// `store.zero_workspace`). None on CPU / Metal / GGUF.
193    ///
194    /// Phase C step 3: replaces `Option<Arc<B::GptqStore>>` with a
195    /// `Box<dyn MarlinExpertStack<B>>` trait object — kills the
196    /// `type GptqStore` leak through the model layer.
197    pub gate_up_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
198    pub down_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
199}
200
201impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
202    /// Returns the shared stacked Marlin expert tile for `gate_up` if
203    /// loaded via the bucketed/Marlin path. Used by
204    /// [`moe_forward_bucketed`].
205    pub fn gate_up_stacked_store(
206        &self,
207        _expert_idx: usize,
208    ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
209        self.gate_up_marlin_stack.as_ref()
210    }
211
212    /// Same for `down`.
213    pub fn down_stacked_store(
214        &self,
215        _expert_idx: usize,
216    ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
217        self.down_marlin_stack.as_ref()
218    }
219
220    // ── MoE GEMV dispatch (hides B::QuantStore + in_stride from callers) ──
221    //
222    // These wrap `B::gemv_quant_moe_id*` so the MoE forward path goes
223    // through the ExpertStack abstraction instead of reaching into
224    // `self.gate_stacked` / `self.up_stacked` / `self.down_stacked`
225    // directly. The weight + correct in_stride are picked from `self`,
226    // so callers only pass activations + routing + scratch out.
227
228    /// Gate projection: `out_stacked[k] = gate_weight[expert_id[k]] · input`,
229    /// broadcast input across all top_k slots.
230    pub fn gemv_gate(
231        &self,
232        ctx: &mut B::Context,
233        input: &B::Buffer,
234        ids: &B::Buffer,
235        out: &mut B::Buffer,
236        top_k: usize,
237    ) -> Result<()> {
238        let weight = self.gate_stacked.as_deref().ok_or_else(|| {
239            FerrumError::unsupported("ExpertStack::gemv_gate: gate_stacked not loaded")
240        })?;
241        weight.gemv_moe_id(ctx, input, ids, out, top_k, 0)
242    }
243
244    /// Up projection: same shape as gate, broadcast input.
245    pub fn gemv_up(
246        &self,
247        ctx: &mut B::Context,
248        input: &B::Buffer,
249        ids: &B::Buffer,
250        out: &mut B::Buffer,
251        top_k: usize,
252    ) -> Result<()> {
253        let weight = self.up_stacked.as_deref().ok_or_else(|| {
254            FerrumError::unsupported("ExpertStack::gemv_up: up_stacked not loaded")
255        })?;
256        weight.gemv_moe_id(ctx, input, ids, out, top_k, 0)
257    }
258
259    /// Down projection: per-slot input via `in_stride = expert_intermediate`.
260    /// Caller's `input` is the SiLU-mul stacked output (`top_k × inter` floats).
261    pub fn gemv_down(
262        &self,
263        ctx: &mut B::Context,
264        input: &B::Buffer,
265        ids: &B::Buffer,
266        out: &mut B::Buffer,
267        top_k: usize,
268        expert_intermediate: usize,
269    ) -> Result<()> {
270        let weight = self.down_stacked.as_deref().ok_or_else(|| {
271            FerrumError::unsupported("ExpertStack::gemv_down: down_stacked not loaded")
272        })?;
273        weight.gemv_moe_id(ctx, input, ids, out, top_k, expert_intermediate)
274    }
275
276    /// Fused gate + up + SiLU·gate: replaces 3 separate dispatches with 1.
277    /// Backend must support the fused path
278    /// (`B::supports_fused_moe_gate_up_silu()`); caller checks first.
279    pub fn gemv_gate_up_silu_fused(
280        &self,
281        ctx: &mut B::Context,
282        input: &B::Buffer,
283        ids: &B::Buffer,
284        out_silu_stacked: &mut B::Buffer,
285        top_k: usize,
286    ) -> Result<()> {
287        let gate = self.gate_stacked.as_deref().ok_or_else(|| {
288            FerrumError::unsupported(
289                "ExpertStack::gemv_gate_up_silu_fused: gate_stacked not loaded",
290            )
291        })?;
292        let up = self.up_stacked.as_deref().ok_or_else(|| {
293            FerrumError::unsupported("ExpertStack::gemv_gate_up_silu_fused: up_stacked not loaded")
294        })?;
295        gate.gemv_moe_id_gate_up_silu(ctx, input, up, ids, out_silu_stacked, top_k)
296    }
297
298    // ── Prefill GEMM dispatch (Phase 3d) ──
299    //
300    // Same role as the gemv wrappers above, but for the m>1 path that
301    // emits batched mul_mm_id instead of per-pair gemv. `args_buf`
302    // toggles between direct (`gemm_quant_moe_id`) and indirect-grid
303    // (`gemm_quant_moe_id_indirect`) dispatch — the indirect form lets
304    // `compute_ids_tpe_gpu` produce a tighter grid sized to `max(tpe[e])`.
305    //
306    // ne11 is fixed by role: gate/up = 1 (broadcast across slots),
307    // down = top_k (per-slot src1 read). Callers no longer pass it.
308
309    /// Gate prefill GEMM. `dst` shape: `[batch, top_k, expert_inter]`.
310    /// `args_buf=Some` triggers indirect-grid dispatch.
311    #[allow(clippy::too_many_arguments)]
312    pub fn gemm_gate(
313        &self,
314        ctx: &mut B::Context,
315        src1: &B::Buffer,
316        ids: &B::Buffer,
317        tpe: &B::Buffer,
318        dst: &mut B::Buffer,
319        args_buf: Option<&B::Buffer>,
320        top_k: usize,
321        max_per_expert: usize,
322        tokens: usize,
323    ) -> Result<()> {
324        let weight = self.gate_stacked.as_deref().ok_or_else(|| {
325            FerrumError::unsupported("ExpertStack::gemm_gate: gate_stacked not loaded")
326        })?;
327        match args_buf {
328            Some(args) => weight.gemm_moe_id_indirect(
329                ctx,
330                src1,
331                ids,
332                tpe,
333                dst,
334                args,
335                1,
336                top_k,
337                max_per_expert,
338                tokens,
339            ),
340            None => weight.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens),
341        }
342    }
343
344    /// Up prefill GEMM. Same shape contract as [`Self::gemm_gate`].
345    #[allow(clippy::too_many_arguments)]
346    pub fn gemm_up(
347        &self,
348        ctx: &mut B::Context,
349        src1: &B::Buffer,
350        ids: &B::Buffer,
351        tpe: &B::Buffer,
352        dst: &mut B::Buffer,
353        args_buf: Option<&B::Buffer>,
354        top_k: usize,
355        max_per_expert: usize,
356        tokens: usize,
357    ) -> Result<()> {
358        let weight = self.up_stacked.as_deref().ok_or_else(|| {
359            FerrumError::unsupported("ExpertStack::gemm_up: up_stacked not loaded")
360        })?;
361        match args_buf {
362            Some(args) => weight.gemm_moe_id_indirect(
363                ctx,
364                src1,
365                ids,
366                tpe,
367                dst,
368                args,
369                1,
370                top_k,
371                max_per_expert,
372                tokens,
373            ),
374            None => weight.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens),
375        }
376    }
377
378    /// Down prefill GEMM. `dst` shape: `[batch, top_k, hidden]`.
379    /// ne11=top_k (per-slot src1 read from `silu_stacked[batch, top_k, inter]`).
380    #[allow(clippy::too_many_arguments)]
381    pub fn gemm_down(
382        &self,
383        ctx: &mut B::Context,
384        src1: &B::Buffer,
385        ids: &B::Buffer,
386        tpe: &B::Buffer,
387        dst: &mut B::Buffer,
388        args_buf: Option<&B::Buffer>,
389        top_k: usize,
390        max_per_expert: usize,
391        tokens: usize,
392    ) -> Result<()> {
393        let weight = self.down_stacked.as_deref().ok_or_else(|| {
394            FerrumError::unsupported("ExpertStack::gemm_down: down_stacked not loaded")
395        })?;
396        match args_buf {
397            Some(args) => weight.gemm_moe_id_indirect(
398                ctx,
399                src1,
400                ids,
401                tpe,
402                dst,
403                args,
404                top_k,
405                top_k,
406                max_per_expert,
407                tokens,
408            ),
409            None => weight.gemm_moe_id(
410                ctx,
411                src1,
412                ids,
413                tpe,
414                dst,
415                top_k,
416                top_k,
417                max_per_expert,
418                tokens,
419            ),
420        }
421    }
422
423    // ── Batched-decode GEMV dispatch (Phase 3d) ──
424    //
425    // For the small-m batched-decode range (c=2..32). Single Metal
426    // launch covers all m*top_k (token, expert) pairs.
427
428    /// Gate batched gemv: `dst[m * top_k]` with broadcast input
429    /// (slots within a token share the activation row).
430    #[allow(clippy::too_many_arguments)]
431    pub fn gemv_gate_batched(
432        &self,
433        ctx: &mut B::Context,
434        input: &B::Buffer,
435        ids: &B::Buffer,
436        dst: &mut B::Buffer,
437        m: usize,
438        top_k: usize,
439        src1_outer_stride: usize,
440        src1_inner_stride: usize,
441    ) -> Result<()> {
442        let weight = self.gate_stacked.as_deref().ok_or_else(|| {
443            FerrumError::unsupported("ExpertStack::gemv_gate_batched: gate_stacked not loaded")
444        })?;
445        weight.gemv_moe_id_batched(
446            ctx,
447            input,
448            ids,
449            dst,
450            m,
451            top_k,
452            src1_outer_stride,
453            src1_inner_stride,
454        )
455    }
456
457    /// Up batched gemv: same shape as [`Self::gemv_gate_batched`].
458    #[allow(clippy::too_many_arguments)]
459    pub fn gemv_up_batched(
460        &self,
461        ctx: &mut B::Context,
462        input: &B::Buffer,
463        ids: &B::Buffer,
464        dst: &mut B::Buffer,
465        m: usize,
466        top_k: usize,
467        src1_outer_stride: usize,
468        src1_inner_stride: usize,
469    ) -> Result<()> {
470        let weight = self.up_stacked.as_deref().ok_or_else(|| {
471            FerrumError::unsupported("ExpertStack::gemv_up_batched: up_stacked not loaded")
472        })?;
473        weight.gemv_moe_id_batched(
474            ctx,
475            input,
476            ids,
477            dst,
478            m,
479            top_k,
480            src1_outer_stride,
481            src1_inner_stride,
482        )
483    }
484
485    /// Down batched gemv: src1 = `silu_stacked[m, top_k, inter]` per-slot read.
486    #[allow(clippy::too_many_arguments)]
487    pub fn gemv_down_batched(
488        &self,
489        ctx: &mut B::Context,
490        input: &B::Buffer,
491        ids: &B::Buffer,
492        dst: &mut B::Buffer,
493        m: usize,
494        top_k: usize,
495        src1_outer_stride: usize,
496        src1_inner_stride: usize,
497    ) -> Result<()> {
498        let weight = self.down_stacked.as_deref().ok_or_else(|| {
499            FerrumError::unsupported("ExpertStack::gemv_down_batched: down_stacked not loaded")
500        })?;
501        weight.gemv_moe_id_batched(
502            ctx,
503            input,
504            ids,
505            dst,
506            m,
507            top_k,
508            src1_outer_stride,
509            src1_inner_stride,
510        )
511    }
512
513    /// Fused batched gate + up + SiLU·gate. Single dispatch over `m * top_k`
514    /// pairs. Caller gates on `B::supports_batched_moe_gate_up_silu()` first.
515    #[allow(clippy::too_many_arguments)]
516    pub fn gemv_gate_up_silu_batched_fused(
517        &self,
518        ctx: &mut B::Context,
519        input: &B::Buffer,
520        ids: &B::Buffer,
521        silu_out: &mut B::Buffer,
522        m: usize,
523        top_k: usize,
524        src1_outer_stride: usize,
525        src1_inner_stride: usize,
526    ) -> Result<()> {
527        let gate = self.gate_stacked.as_deref().ok_or_else(|| {
528            FerrumError::unsupported(
529                "ExpertStack::gemv_gate_up_silu_batched_fused: gate_stacked not loaded",
530            )
531        })?;
532        let up = self.up_stacked.as_deref().ok_or_else(|| {
533            FerrumError::unsupported(
534                "ExpertStack::gemv_gate_up_silu_batched_fused: up_stacked not loaded",
535            )
536        })?;
537        gate.gemv_moe_id_gate_up_silu_batched(
538            ctx,
539            input,
540            up,
541            ids,
542            silu_out,
543            m,
544            top_k,
545            src1_outer_stride,
546            src1_inner_stride,
547        )
548    }
549
550    // ── Per-item offset GEMV (Phase 3d, qwen3_moe.rs decode path) ──
551    //
552    // Used by the per-item batched-decode loop in `Qwen3MoeModel::forward`
553    // when offset variants are supported. Reads `src1` at `src1_offset`
554    // floats and `ids` at `ids_offset` ids, writes `dst` from offset 0.
555
556    /// Gate offset gemv. `src1_stride=0` → broadcast.
557    #[allow(clippy::too_many_arguments)]
558    pub fn gemv_gate_offset(
559        &self,
560        ctx: &mut B::Context,
561        src1: &B::Buffer,
562        src1_offset: usize,
563        ids: &B::Buffer,
564        ids_offset: usize,
565        dst: &mut B::Buffer,
566        top_k: usize,
567        src1_stride: usize,
568    ) -> Result<()> {
569        let weight = self.gate_stacked.as_deref().ok_or_else(|| {
570            FerrumError::unsupported("ExpertStack::gemv_gate_offset: gate_stacked not loaded")
571        })?;
572        weight.gemv_moe_id_offset(
573            ctx,
574            src1,
575            src1_offset,
576            ids,
577            ids_offset,
578            dst,
579            top_k,
580            src1_stride,
581        )
582    }
583
584    /// Up offset gemv.
585    #[allow(clippy::too_many_arguments)]
586    pub fn gemv_up_offset(
587        &self,
588        ctx: &mut B::Context,
589        src1: &B::Buffer,
590        src1_offset: usize,
591        ids: &B::Buffer,
592        ids_offset: usize,
593        dst: &mut B::Buffer,
594        top_k: usize,
595        src1_stride: usize,
596    ) -> Result<()> {
597        let weight = self.up_stacked.as_deref().ok_or_else(|| {
598            FerrumError::unsupported("ExpertStack::gemv_up_offset: up_stacked not loaded")
599        })?;
600        weight.gemv_moe_id_offset(
601            ctx,
602            src1,
603            src1_offset,
604            ids,
605            ids_offset,
606            dst,
607            top_k,
608            src1_stride,
609        )
610    }
611
612    /// Down offset gemv.
613    #[allow(clippy::too_many_arguments)]
614    pub fn gemv_down_offset(
615        &self,
616        ctx: &mut B::Context,
617        src1: &B::Buffer,
618        src1_offset: usize,
619        ids: &B::Buffer,
620        ids_offset: usize,
621        dst: &mut B::Buffer,
622        top_k: usize,
623        src1_stride: usize,
624    ) -> Result<()> {
625        let weight = self.down_stacked.as_deref().ok_or_else(|| {
626            FerrumError::unsupported("ExpertStack::gemv_down_offset: down_stacked not loaded")
627        })?;
628        weight.gemv_moe_id_offset(
629            ctx,
630            src1,
631            src1_offset,
632            ids,
633            ids_offset,
634            dst,
635            top_k,
636            src1_stride,
637        )
638    }
639}
640
641impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
642    /// Build from raw fp32 stacked tensors (test helper). Caller has
643    /// already dequantised and laid out the data:
644    ///   `gate_stack`: `[num_experts * expert_inter * hidden]`
645    ///   `up_stack`:   `[num_experts * expert_inter * hidden]`
646    ///   `down_stack`: `[num_experts * hidden * expert_inter]`
647    /// Each per-expert slice is row-major in the natural Linear shape.
648    pub fn from_dense_stacks(
649        gate_stack: &[f32],
650        up_stack: &[f32],
651        down_stack: &[f32],
652        num_experts: usize,
653        hidden_size: usize,
654        expert_intermediate: usize,
655    ) -> Result<Self> {
656        let gate_up_per_expert = expert_intermediate * hidden_size;
657        let down_per_expert = hidden_size * expert_intermediate;
658
659        check_size(
660            gate_stack.len(),
661            num_experts * gate_up_per_expert,
662            "gate_stack",
663        )?;
664        check_size(up_stack.len(), num_experts * gate_up_per_expert, "up_stack")?;
665        check_size(
666            down_stack.len(),
667            num_experts * down_per_expert,
668            "down_stack",
669        )?;
670
671        let mut gate_up = Vec::with_capacity(num_experts);
672        let mut down = Vec::with_capacity(num_experts);
673        for e in 0..num_experts {
674            let g_off = e * gate_up_per_expert;
675            let g_slice = &gate_stack[g_off..g_off + gate_up_per_expert];
676            let u_slice = &up_stack[g_off..g_off + gate_up_per_expert];
677
678            // Fused [gate; up] is [2 * expert_inter, hidden] row-major.
679            // We concatenate row-blocks so the first expert_inter rows are
680            // gate, the next expert_inter rows are up — the layout
681            // fused_silu_mul_split expects.
682            let mut fused = Vec::with_capacity(2 * gate_up_per_expert);
683            fused.extend_from_slice(g_slice);
684            fused.extend_from_slice(u_slice);
685            gate_up.push(Box::new(DenseLinear::<B>::from_rows(
686                &fused,
687                2 * expert_intermediate,
688                hidden_size,
689            )) as Box<dyn Linear<B>>);
690
691            let d_off = e * down_per_expert;
692            let d_slice = &down_stack[d_off..d_off + down_per_expert];
693            down.push(Box::new(DenseLinear::<B>::from_rows(
694                d_slice,
695                hidden_size,
696                expert_intermediate,
697            )) as Box<dyn Linear<B>>);
698        }
699        Ok(Self {
700            gate_up,
701            down,
702            gate_stacked: None,
703            up_stacked: None,
704            down_stacked: None,
705            gate_up_marlin_stack: None,
706            down_marlin_stack: None,
707        })
708    }
709
710    /// Load all experts for one MoE layer from a GGUF file. Names follow
711    /// the GGUF convention: `blk.{layer_idx}.ffn_{gate,up,down}_exps.weight`.
712    ///
713    /// The loader picks between two strategies based on the on-disk dtype
714    /// of the expert tensors:
715    ///
716    ///   - **Quantised path** (Q4_K / Q6_K only): each expert's
717    ///     `gate || up` becomes a single `QuantLinear<B>` (Fused
718    ///     QuantStore — gate + up share `n_cols = hidden`), and `down` is
719    ///     a plain `QuantLinear<B>`. Block bytes stay compressed in
720    ///     backend memory; per-call dequant happens inside `gemm_quant`.
721    ///   - **Dense fallback** (everything else, e.g. F32 / F16 / Q5_K
722    ///     until a kernel ships): eager-dequant to fp32 and wrap
723    ///     `DenseLinear<B>`. Memory inflates ~7× vs Q4_K_M but the
724    ///     algorithm is correctness-equivalent and this is the path the
725    ///     synthetic-MoE test fixtures need.
726    ///
727    /// The runtime dispatcher (`moe_forward<B>`) doesn't see which path
728    /// was taken — it just calls `Linear::forward` per (token, expert).
729    pub fn load_from_gguf(
730        gguf: &GgufFile,
731        layer_idx: usize,
732        num_experts: usize,
733        hidden_size: usize,
734        expert_intermediate: usize,
735    ) -> Result<Self> {
736        let runtime_config = moe_dispatch_runtime_config();
737        if let Some(quant) = Self::try_load_quantised(
738            gguf,
739            layer_idx,
740            num_experts,
741            hidden_size,
742            expert_intermediate,
743        )? {
744            if runtime_config.moe_load_trace {
745                eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
746            }
747            return Ok(quant);
748        }
749
750        if runtime_config.moe_load_trace {
751            eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
752        }
753
754        let device = Device::Cpu;
755        let gate = read_dequant_flat(
756            gguf,
757            &format!("blk.{layer_idx}.ffn_gate_exps.weight"),
758            &device,
759        )?;
760        let up = read_dequant_flat(
761            gguf,
762            &format!("blk.{layer_idx}.ffn_up_exps.weight"),
763            &device,
764        )?;
765        let down = read_dequant_flat(
766            gguf,
767            &format!("blk.{layer_idx}.ffn_down_exps.weight"),
768            &device,
769        )?;
770        // Eager-dense path leaves stacked variants as None — no MoE
771        // fast path for synthesised / non-quantised expert tensors.
772        Self::from_dense_stacks(
773            &gate,
774            &up,
775            &down,
776            num_experts,
777            hidden_size,
778            expert_intermediate,
779        )
780    }
781
782    /// Attempt the quantised path. Returns `Ok(None)` if any of the three
783    /// tensors isn't a supported k-quant flavour (Q4_K / Q6_K) or if the
784    /// shape doesn't match the expected per-expert tile size — caller
785    /// then takes the eager-dequant fallback. Returns `Err` only on a
786    /// genuine load failure (missing tensor, byte-count mismatch).
787    fn try_load_quantised(
788        gguf: &GgufFile,
789        layer_idx: usize,
790        num_experts: usize,
791        hidden_size: usize,
792        expert_intermediate: usize,
793    ) -> Result<Option<Self>> {
794        let device = Device::Cpu;
795
796        let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
797        let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
798        let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");
799
800        // Inspect tensor info up front — if any tensor isn't a k-quant
801        // flavour the backend can dispatch on, bail to the dense path
802        // before paying the byte-read cost.
803        let gate_kind = match quant_kind(gguf, &gate_name)? {
804            Some(k) => k,
805            None => return Ok(None),
806        };
807        let up_kind = match quant_kind(gguf, &up_name)? {
808            Some(k) => k,
809            None => return Ok(None),
810        };
811        let down_kind = match quant_kind(gguf, &down_name)? {
812            Some(k) => k,
813            None => return Ok(None),
814        };
815
816        // Slice the three 3-D quantised expert stacks directly from
817        // the mmap. These are the dominant memory cost on Qwen3-MoE
818        // (~14 GB for Qwen3-30B-A3B); going through candle's
819        // `read_tensor` would copy them into a heap `Vec<u8>` first,
820        // then `load_quant_experts` would copy again into the Metal
821        // buffer — together doubling the working set and pushing a
822        // 32 GB Mac into swap. With this slice + the Metal mmap
823        // registry, we avoid both copies (steady state: just the
824        // file mmap).
825        let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
826            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
827        })?;
828        let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
829            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
830        })?;
831        let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
832            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
833        })?;
834        let _ = device; // candle device no longer needed for the byte read
835
836        // Per-expert byte stride for each tensor. The 3-D layout is
837        // contiguous, [num_experts, rows, cols] row-major, so each
838        // expert's slab is exactly `total_bytes / num_experts`.
839        let gate_per = block_bytes_for(
840            gate_kind,
841            expert_intermediate * hidden_size,
842            "ffn_gate_exps",
843        )?;
844        let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
845        let down_per = block_bytes_for(
846            down_kind,
847            hidden_size * expert_intermediate,
848            "ffn_down_exps",
849        )?;
850
851        check_size(
852            gate_bytes.len(),
853            num_experts * gate_per,
854            "ffn_gate_exps bytes",
855        )?;
856        check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
857        check_size(
858            down_bytes.len(),
859            num_experts * down_per,
860            "ffn_down_exps bytes",
861        )?;
862
863        // Try the stacked-experts fast path FIRST. If the backend has a
864        // batched MoE kernel (Metal `gemv_q*kw_moe_id_f32`), we want to
865        // hold the experts only as one big stacked buffer per role —
866        // not as 128 per-expert MetalQuantStores PLUS the stacked one
867        // (that would double-allocate ~17 GB on a 32 GB Mac, which on
868        // Qwen3-30B-A3B Q4_K_M sends the model into swap and tanks
869        // both load and forward time).
870        let gate_stacked = B::load_quant_experts(
871            gate_kind,
872            gate_bytes,
873            num_experts,
874            expert_intermediate,
875            hidden_size,
876        )
877        .ok();
878        let up_stacked = B::load_quant_experts(
879            up_kind,
880            up_bytes,
881            num_experts,
882            expert_intermediate,
883            hidden_size,
884        )
885        .ok();
886        let down_stacked = B::load_quant_experts(
887            down_kind,
888            down_bytes,
889            num_experts,
890            hidden_size,
891            expert_intermediate,
892        )
893        .ok();
894
895        // Decide the storage shape:
896        //   * Stacked-only (Metal MoE fast path): all three stacked
897        //     loaders succeeded — skip per-expert and use stacked
898        //     for both decode and prefill. Cuts memory in half.
899        //   * Per-expert: stacked path is incomplete or unsupported —
900        //     load 128-per-layer QuantLinears and let `moe_forward`
901        //     drive the per-(token, expert) loop on top of them.
902        let stacked_complete =
903            gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
904
905        let (gate_up, down) = if stacked_complete {
906            // No per-expert needed — `moe_forward_stacked_decode_impl`
907            // and the per-token prefill loop both use the stacked buffers.
908            (Vec::new(), Vec::new())
909        } else {
910            let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
911            let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
912            for e in 0..num_experts {
913                let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
914                let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
915                let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
916
917                let parts: [(GgufQuantType, &[u8], usize); 2] = [
918                    (gate_kind, g_slice, expert_intermediate),
919                    (up_kind, u_slice, expert_intermediate),
920                ];
921                let gate_up_e = match QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) {
922                    Ok(q) => q,
923                    Err(_) => return Ok(None),
924                };
925                gate_up.push(Box::new(gate_up_e) as Box<dyn Linear<B>>);
926
927                let down_e = match QuantLinear::<B>::from_gguf_bytes(
928                    down_kind,
929                    d_slice,
930                    hidden_size,
931                    expert_intermediate,
932                ) {
933                    Ok(q) => q,
934                    Err(_) => return Ok(None),
935                };
936                down.push(Box::new(down_e) as Box<dyn Linear<B>>);
937            }
938            (gate_up, down)
939        };
940
941        Ok(Some(Self {
942            gate_up,
943            down,
944            gate_stacked,
945            up_stacked,
946            down_stacked,
947            gate_up_marlin_stack: None,
948            down_marlin_stack: None,
949        }))
950    }
951
952    /// Convenience: open a GGUF and load layer `layer_idx`. The GGUF
953    /// stays open inside this call only — for multi-layer loads use
954    /// [`Self::load_from_gguf`] with a shared [`GgufFile`].
955    pub fn open_and_load(
956        path: impl AsRef<Path>,
957        layer_idx: usize,
958        num_experts: usize,
959        hidden_size: usize,
960        expert_intermediate: usize,
961    ) -> Result<Self> {
962        let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
963        Self::load_from_gguf(
964            &gguf,
965            layer_idx,
966            num_experts,
967            hidden_size,
968            expert_intermediate,
969        )
970    }
971
972    /// `num_experts` for the layer (consistency check helper).
973    ///
974    /// Returns the per-expert Vec length, OR — when the stacked-only
975    /// path is in effect (Metal MoE fast path with empty per-expert
976    /// Vecs) — falls back to a stored count via the stacked variants.
977    /// In the stacked-only case there's no Vec to count, so this method
978    /// is mostly used by tests on the per-expert path.
979    pub fn num_experts(&self) -> usize {
980        debug_assert_eq!(
981            self.gate_up.len(),
982            self.down.len(),
983            "ExpertStack: gate_up and down disagree on expert count"
984        );
985        self.gate_up.len()
986    }
987}
988
989/// Backend-generic MoE forward.
990///
991/// Equivalent of [`moe_forward_cpu`] but parameterised on `B: Backend`
992/// so Metal / CUDA paths can dispatch the same per-(token, expert) loop
993/// using their own kernels for the gemv + silu + scaled-add primitives.
994///
995/// The caller pre-supplies all scratch buffers — this function does no
996/// allocation, which matters because it's invoked from inside the
997/// transformer's `forward_layer` where allocation during graph capture
998/// (CUDA) would corrupt the captured graph.
999///
1000/// Buffer contract (lengths, sized at scratch alloc time):
1001///   - `x`            : `[batch * hidden]` post-RMSNorm activations
1002///   - `router_logits`: `[batch * num_experts]` raw router output
1003///   - `out`          : `[batch * hidden]` — caller is responsible for
1004///                      zeroing this before the call (we accumulate,
1005///                      not assign)
1006///   - `x_single`     : `[hidden]` per-token input slice
1007///   - `acc_buf`      : `[hidden]` per-token output accumulator (kept
1008///                      separate from `x_single` so the gate_up gemv
1009///                      can consume `x_single` repeatedly across the
1010///                      top_k loop without an inter-pair restore)
1011///   - `gate_up_buf`  : `[2 * expert_inter]` per-(token, expert) gemv out
1012///   - `silu_buf`     : `[expert_inter]`
1013///   - `down_buf`     : `[hidden]` per-(token, expert) accumulate src
1014///
1015/// Routing (softmax + top-K + optional renorm) runs host-side using
1016/// `B::to_vec(router_logits, …)` — the routing computation is small
1017/// (`batch * num_experts` floats) and the top-K is a sort, both of
1018/// which dwarf in cost any plausible host↔device transfer.
1019///
1020/// Per-pair dispatch budget (m=1, Metal):
1021///   gate_up Fused gemv (2 parts) + silu + down gemv + scaled_add
1022///   = 5 dispatches/pair. Plus 2 copy_slice/token (load x_single,
1023///   write acc_buf back to out[b]). With top_k=8 and 48 layers, that's
1024///   8×5 + 2 = 42 dispatches/layer × 48 ≈ 2k/token (vs. ~3.5k in the
1025///   pre-PR scheme that round-tripped through `out` per pair).
1026pub struct MoeForwardParams<'a, B: QuantLlmBackend + BackendMoeFused> {
1027    pub ctx: &'a mut B::Context,
1028    pub x: &'a B::Buffer,
1029    pub router_logits: &'a B::Buffer,
1030    pub out: &'a mut B::Buffer,
1031    pub batch: usize,
1032    pub hidden_size: usize,
1033    pub expert_intermediate: usize,
1034    pub num_experts: usize,
1035    pub top_k: usize,
1036    pub norm_topk_prob: bool,
1037    pub experts: &'a ExpertStack<B>,
1038    pub x_single: &'a mut B::Buffer,
1039    pub acc_buf: &'a mut B::Buffer,
1040    pub gate_up_buf: &'a mut B::Buffer,
1041    pub silu_buf: &'a mut B::Buffer,
1042    pub down_buf: &'a mut B::Buffer,
1043    pub zero_hidden: &'a B::Buffer,
1044}
1045
1046pub fn moe_forward<B: QuantLlmBackend + BackendMoeFused>(
1047    params: MoeForwardParams<'_, B>,
1048) -> Result<()> {
1049    let MoeForwardParams {
1050        ctx,
1051        x,
1052        router_logits,
1053        out,
1054        batch,
1055        hidden_size,
1056        expert_intermediate,
1057        num_experts,
1058        top_k,
1059        norm_topk_prob,
1060        experts,
1061        x_single,
1062        acc_buf,
1063        gate_up_buf,
1064        silu_buf,
1065        down_buf,
1066        zero_hidden,
1067    } = params;
1068    let n_experts = experts.num_experts();
1069    if n_experts != num_experts {
1070        return Err(FerrumError::model(format!(
1071            "moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
1072        )));
1073    }
1074
1075    let prof = moe_profile_enabled();
1076
1077    // Routing on host. Sized batch*num_experts (e.g. 512*128 = 64k floats
1078    // per layer for Qwen3-30B-A3B prefill); cheap relative to the per-
1079    // expert gemvs that follow.
1080    let t0 = if prof {
1081        Some(std::time::Instant::now())
1082    } else {
1083        None
1084    };
1085    B::sync(ctx);
1086    if let Some(t) = t0 {
1087        MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1088        MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
1089    }
1090
1091    let t0 = if prof {
1092        Some(std::time::Instant::now())
1093    } else {
1094        None
1095    };
1096    let logits_host = B::to_vec(router_logits, batch * num_experts);
1097    let route_out =
1098        crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
1099    if let Some(t) = t0 {
1100        MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1101        MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
1102    }
1103
1104    for b in 0..batch {
1105        // Load x[b] into x_single + reset accumulator.
1106        let t0 = if prof {
1107            Some(std::time::Instant::now())
1108        } else {
1109            None
1110        };
1111        B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
1112        B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
1113        if let Some(t) = t0 {
1114            MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1115            MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
1116        }
1117
1118        for k in 0..top_k {
1119            let pair = b * top_k + k;
1120            let expert_id = route_out.expert_ids[pair] as usize;
1121            let weight = route_out.expert_weights[pair];
1122            if expert_id >= num_experts {
1123                return Err(FerrumError::model(format!(
1124                    "moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
1125                )));
1126            }
1127
1128            // Fused gate||up gemv → [2 * expert_inter]
1129            let t0 = if prof {
1130                B::sync(ctx);
1131                Some(std::time::Instant::now())
1132            } else {
1133                None
1134            };
1135            experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
1136            if let Some(t) = t0 {
1137                B::sync(ctx);
1138                MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1139                MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
1140            }
1141
1142            // SiLU(gate) * up → [expert_inter]
1143            let t0 = if prof {
1144                Some(std::time::Instant::now())
1145            } else {
1146                None
1147            };
1148            B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
1149            if let Some(t) = t0 {
1150                B::sync(ctx);
1151                MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1152                MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
1153            }
1154
1155            // down gemv → [hidden]
1156            let t0 = if prof {
1157                Some(std::time::Instant::now())
1158            } else {
1159                None
1160            };
1161            experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
1162            if let Some(t) = t0 {
1163                B::sync(ctx);
1164                MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1165                MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
1166            }
1167
1168            // acc_buf += weight * down_buf
1169            let t0 = if prof {
1170                Some(std::time::Instant::now())
1171            } else {
1172                None
1173            };
1174            B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
1175            if let Some(t) = t0 {
1176                B::sync(ctx);
1177                MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1178                MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
1179            }
1180        }
1181
1182        // Final write: out[b] = acc_buf
1183        let t0 = if prof {
1184            Some(std::time::Instant::now())
1185        } else {
1186            None
1187        };
1188        B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
1189        if let Some(t) = t0 {
1190            MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1191            MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
1192        }
1193    }
1194
1195    Ok(())
1196}
1197
1198/// Largest moe_block_size we'd ever pick. Drives Qwen3MoeScratch
1199/// `route_sorted_tokens_dev` sizing (allocates `t*top_k + n_exp*MAX`).
1200pub const MOE_BLOCK_SIZE_MAX: usize = 64;
1201
1202/// Pick `moe_block_size` ∈ {16, 32, 64} based on routing distribution.
1203///
1204/// Marlin-MoE templates instantiate for thread_m_blocks ∈ {1, 2, 3, 4}
1205/// → block_size ∈ {16, 32, 48, 64}. Larger block_size enables the
1206/// "large_batch tile" path (thread_n=256, num_threads=256, 8× more work
1207/// per kernel launch) but pads each expert's tokens up to a multiple of
1208/// block_size — sparse routing wastes most of that.
1209///
1210/// Decision: pick the largest size whose **total padded tokens** stays
1211/// within 30% of actual. If we can't keep overhead below the threshold,
1212/// stick with 16. Skips block_size=48 for simplicity (rare sweet spot).
1213///
1214/// Device-routing path doesn't expose `plan` host-side; fall back to 16
1215/// (no regression vs pre-PR behaviour).
1216fn pick_moe_block_size(
1217    plan: Option<&MoeBucketPlan>,
1218    num_experts: usize,
1219    use_device_route: bool,
1220    total_pairs: usize,
1221) -> usize {
1222    pick_moe_block_size_with_config(
1223        moe_dispatch_runtime_config(),
1224        plan,
1225        num_experts,
1226        use_device_route,
1227        total_pairs,
1228    )
1229}
1230
1231fn pick_moe_block_size_with_config(
1232    config: &MoeDispatchRuntimeConfig,
1233    plan: Option<&MoeBucketPlan>,
1234    num_experts: usize,
1235    use_device_route: bool,
1236    total_pairs: usize,
1237) -> usize {
1238    const CANDIDATES: &[usize] = &[64, 32, 16];
1239    const PADDING_BUDGET: f64 = 1.30; // ≤ 30% overhead vs actual tokens
1240                                      // Manual override (testing / autotuning): FERRUM_MOE_BLOCK_SIZE=8/16/32/48/64.
1241                                      // vLLM 0.20.2 often selects 8 for small-M MoE; keep it override-only
1242                                      // until full-model correctness + throughput beats the 16 default.
1243    if let Some(bs) = config.moe_block_size {
1244        return bs;
1245    }
1246    if use_device_route {
1247        if let Some(bs) = config.moe_large_m_block_size {
1248            if total_pairs >= config.moe_large_m_min_pairs {
1249                return bs;
1250            }
1251        }
1252        // Empirical 2026-05-13: block_size=64 (`thread_m_blocks=4`,
1253        // matching vLLM's tile) regresses M3 c=32 by 5.7% on RTX 4090
1254        // because sparse routing (top_k=8 / num_experts=128 / m=32 ≈
1255        // 2 pairs per active expert) pads each expert's tile by ~32×,
1256        // and the wasted sentinel-row compute exceeds the tile-width
1257        // win. block_size=32 is within noise of 16. Keep 16 as default;
1258        // FERRUM_MOE_BLOCK_SIZE override stays for future autotuning
1259        // when m / routing density changes (e.g. dense Llama at m=32).
1260        return 16;
1261    }
1262    let Some(plan) = plan else {
1263        return 16;
1264    };
1265    let m_e: Vec<usize> = (0..num_experts)
1266        .map(|e| plan.expert_offsets[e + 1] - plan.expert_offsets[e])
1267        .collect();
1268    let total_actual: usize = m_e.iter().sum();
1269    if total_actual == 0 {
1270        return 16;
1271    }
1272    for &bs in CANDIDATES {
1273        let total_padded: usize = m_e.iter().map(|&m| m.div_ceil(bs) * bs).sum();
1274        if (total_padded as f64) <= (total_actual as f64) * PADDING_BUDGET {
1275            return bs;
1276        }
1277    }
1278    16
1279}
1280
1281/// Bucket plan: per-expert lists of which (token, k_slot) pairs route
1282/// through that expert. Built host-side from the router output and used
1283/// by [`moe_forward_bucketed`] to issue ONE m=tokens_per_expert Marlin
1284/// GEMM per active expert instead of `batch * top_k` m=1 GEMMs.
1285pub struct MoeBucketPlan {
1286    /// `expert_offsets[e+1] - expert_offsets[e]` = tokens routed to expert e.
1287    /// Length: `num_experts + 1`. `expert_offsets[num_experts]` = total_pairs
1288    /// (always `batch * top_k`).
1289    pub expert_offsets: Vec<usize>,
1290    /// `[total_pairs]` flat: which input token each packed-row gathers
1291    /// from. Index into `x[batch, hidden]`.
1292    pub packed_token_idx: Vec<u32>,
1293    /// `[batch, top_k]` row-major: for each (b, k_slot), which row of the
1294    /// packed buffers carries that pair's contribution. Used by
1295    /// `B::moe_combine` to scatter weighted sums back to `out[b]`.
1296    pub pairs_by_token: Vec<i32>,
1297    /// `[batch, top_k]` row-major: combine weight for the (b, k_slot)
1298    /// pair, copied verbatim from the router output. Used by
1299    /// `B::moe_combine`.
1300    pub pair_weights: Vec<f32>,
1301    /// Cached cursor scratch for [`Self::rebuild_into`] — sized to
1302    /// `num_experts` on first build and reused (one alloc total instead
1303    /// of one per call).
1304    cursors: Vec<usize>,
1305}
1306
1307impl MoeBucketPlan {
1308    /// Empty plan with no allocation. Use [`Self::rebuild_into`] before
1309    /// reuse — this is the cheap constructor for putting the plan in a
1310    /// scratch struct.
1311    pub fn empty() -> Self {
1312        Self {
1313            expert_offsets: Vec::new(),
1314            packed_token_idx: Vec::new(),
1315            pairs_by_token: Vec::new(),
1316            pair_weights: Vec::new(),
1317            cursors: Vec::new(),
1318        }
1319    }
1320
1321    /// Allocate a fresh plan. Convenience wrapper over [`Self::rebuild_into`]
1322    /// for tests and code paths that don't care about reuse.
1323    pub fn build(route: &RouterOutput, batch: usize, num_experts: usize, top_k: usize) -> Self {
1324        let mut p = Self::empty();
1325        p.rebuild_into(route, batch, num_experts, top_k);
1326        p
1327    }
1328
1329    /// Allocation-free rebuild. Reuses the existing `expert_offsets`,
1330    /// `packed_token_idx`, `pairs_by_token`, `pair_weights` buffers via
1331    /// `clear() + resize()`. Uses the trailing tail of `expert_offsets`
1332    /// as the host-side cursor scratch (saves the per-call `cursors.clone()`).
1333    pub fn rebuild_into(
1334        &mut self,
1335        route: &RouterOutput,
1336        batch: usize,
1337        num_experts: usize,
1338        top_k: usize,
1339    ) {
1340        debug_assert_eq!(route.expert_ids.len(), batch * top_k);
1341        debug_assert_eq!(route.expert_weights.len(), batch * top_k);
1342        let total_pairs = batch * top_k;
1343
1344        self.expert_offsets.clear();
1345        self.expert_offsets.resize(num_experts + 1, 0);
1346        self.packed_token_idx.clear();
1347        self.packed_token_idx.resize(total_pairs, 0);
1348        self.pairs_by_token.clear();
1349        self.pairs_by_token.resize(total_pairs, -1);
1350
1351        // Pass 1: count pairs per expert. Stored into expert_offsets[1..]
1352        // so the inclusive-prefix-sum in Pass 2 can run in place — no
1353        // separate `counts` Vec.
1354        for &eid in &route.expert_ids {
1355            self.expert_offsets[eid as usize + 1] += 1;
1356        }
1357
1358        // Pass 2: in-place inclusive prefix sum → expert_offsets[].
1359        for e in 0..num_experts {
1360            self.expert_offsets[e + 1] += self.expert_offsets[e];
1361        }
1362
1363        // Pass 3: fill packed_token_idx + pairs_by_token by walking pairs
1364        // in (b, k) order and bucketing. The `cursors` scratch tracks how
1365        // many pairs each expert has already received; on first call it
1366        // grows to `num_experts`, subsequent calls reuse the allocation.
1367        self.cursors.clear();
1368        self.cursors
1369            .extend_from_slice(&self.expert_offsets[..num_experts]);
1370
1371        for b in 0..batch {
1372            for k in 0..top_k {
1373                let pair_flat = b * top_k + k;
1374                let eid = route.expert_ids[pair_flat] as usize;
1375                let slot = self.cursors[eid];
1376                self.cursors[eid] += 1;
1377                self.packed_token_idx[slot] = b as u32;
1378                self.pairs_by_token[pair_flat] = slot as i32;
1379            }
1380        }
1381
1382        // Pair weights: replicate from RouterOutput. Reuse self's vector
1383        // via clear() + extend rather than the per-call `clone()`.
1384        self.pair_weights.clear();
1385        self.pair_weights.extend_from_slice(&route.expert_weights);
1386    }
1387}
1388
1389/// Reusable host-side scratch for [`moe_forward_bucketed`]. Holds the
1390/// router output, softmax scratch buffer, and bucket plan, all reused
1391/// across layers so the inner MoE forward path is allocation-free.
1392///
1393/// At c=32 / Qwen3-MoE / 48 layers, the previous fresh-`Vec`-per-layer
1394/// pattern accounted for ~10 ms / token of pure CPU softmax+sort+alloc
1395/// (25% of MoE wallclock — see `docs/bench/cuda-rtx4090-2026-05-08-m3-moe`).
1396pub struct MoeRouteScratch {
1397    pub output: RouterOutput,
1398    /// Softmax buffer reused across all rows of all layers — sized to
1399    /// `num_experts` on first use.
1400    pub probs: Vec<f32>,
1401    pub plan: MoeBucketPlan,
1402}
1403
1404impl MoeRouteScratch {
1405    pub fn new() -> Self {
1406        Self {
1407            output: RouterOutput::empty(),
1408            probs: Vec::new(),
1409            plan: MoeBucketPlan::empty(),
1410        }
1411    }
1412}
1413
1414impl Default for MoeRouteScratch {
1415    fn default() -> Self {
1416        Self::new()
1417    }
1418}
1419
1420/// Bundle of pre-allocated device buffers for the graph-capturable
1421/// device-routing path in [`moe_forward_bucketed`]. Pass `Some` to
1422/// take the device path (under `FERRUM_MOE_DEVICE_ROUTE=1`); pass
1423/// `None` for the legacy host-mediated path (used by tests + the
1424/// non-vLLM CUDA bucketed path).
1425///
1426/// Pre-allocated on Qwen3MoeScratch (`route_pairs_dev` etc.) so the
1427/// per-layer call doesn't alloc inside a captured stream.
1428pub struct DeviceRouteScratch<'a, B: crate::moe::dispatch::Backend> {
1429    pub selected_ids: &'a mut B::Buffer,
1430    pub pair_weights: &'a mut B::Buffer,
1431    pub pairs_by_token: &'a mut B::Buffer,
1432    pub packed_token_idx: &'a mut B::Buffer,
1433    pub expert_offsets: &'a mut B::Buffer,
1434    // Phase 2: moe_align_block_size outputs for the vLLM marlin_moe
1435    // fused GEMM path. Same shape as host `vllm_routing` builder
1436    // produces, but device-resident.
1437    pub sorted_tokens: &'a mut B::Buffer,
1438    pub block_ids: &'a mut B::Buffer,
1439    pub total_post_pad: &'a mut B::Buffer,
1440}
1441
1442/// Bucketed MoE forward: gather → per-expert m=N Marlin GEMM → silu_mul →
1443/// per-expert m=N Marlin GEMM → moe_combine.
1444///
1445/// Replaces the `batch × top_k` m=1 dispatch loop in [`moe_forward`] with
1446/// `num_active_experts × 2` m=tokens_per_expert dispatches. For prefill
1447/// (m=512+), this is a 30× reduction in GEMM launches AND each GEMM runs
1448/// at a much more efficient m than the m=1 path. For decode (m=1), the
1449/// number of dispatches is similar but we still benefit from the
1450/// gather/combine kernel pattern (one launch each instead of 2 per pair).
1451///
1452/// **Requires**: scratch buffers `x_packed [total_pairs, hidden]`,
1453/// `gate_up_packed [total_pairs, 2*expert_inter]`,
1454/// `silu_packed [total_pairs, expert_inter]`, and
1455/// `down_packed [total_pairs, hidden]` provisioned by the caller. The
1456/// caller is responsible for sizing these to `batch * top_k` rows
1457/// (worst-case all top_k pairs alive).
1458pub struct MoeForwardBucketedParams<'a, B: QuantLlmBackend + BackendMoeFused> {
1459    pub ctx: &'a mut B::Context,
1460    pub x: &'a B::Buffer,
1461    pub router_logits: &'a B::Buffer,
1462    pub out: &'a mut B::Buffer,
1463    pub batch: usize,
1464    pub hidden_size: usize,
1465    pub expert_intermediate: usize,
1466    pub num_experts: usize,
1467    pub top_k: usize,
1468    pub norm_topk_prob: bool,
1469    pub experts: &'a ExpertStack<B>,
1470    pub x_packed: &'a mut B::Buffer,
1471    pub gate_up_packed: &'a mut B::Buffer,
1472    pub silu_packed: &'a mut B::Buffer,
1473    pub down_packed: &'a mut B::Buffer,
1474    pub route_scratch: &'a mut MoeRouteScratch,
1475    // Optional device routing scratch — when Some AND
1476    // FERRUM_MOE_DEVICE_ROUTE=1 AND FERRUM_VLLM_MOE=1, runs the
1477    // graph-capturable device-routing branch. None / unset = legacy
1478    // host-mediated path (used by tests + non-vLLM path).
1479    pub device_route: Option<DeviceRouteScratch<'a, B>>,
1480}
1481
1482pub fn moe_forward_bucketed<B: QuantLlmBackend + BackendMoeFused>(
1483    params: MoeForwardBucketedParams<'_, B>,
1484) -> Result<()> {
1485    let MoeForwardBucketedParams {
1486        ctx,
1487        x,
1488        router_logits,
1489        out,
1490        batch,
1491        hidden_size,
1492        expert_intermediate,
1493        num_experts,
1494        top_k,
1495        norm_topk_prob,
1496        experts,
1497        x_packed,
1498        gate_up_packed,
1499        silu_packed,
1500        down_packed,
1501        route_scratch,
1502        device_route,
1503    } = params;
1504    if experts.num_experts() != num_experts {
1505        return Err(FerrumError::model(format!(
1506            "moe_forward_bucketed: experts {} != num_experts {num_experts}",
1507            experts.num_experts()
1508        )));
1509    }
1510
1511    let runtime_config = moe_dispatch_runtime_config();
1512    // Bucket profiling fires on either FERRUM_MOE_PROFILE=1 (legacy)
1513    // or FERRUM_DECODE_OP_PROFILE=1 (the gate the print site uses).
1514    let prof = runtime_config.moe_profile || runtime_config.decode_op_profile;
1515    if prof {
1516        MOE_BUCKET_LAYER_CALLS.fetch_add(1, Ordering::Relaxed);
1517    }
1518
1519    // ── Device-route fast path (opt-in via FERRUM_MOE_DEVICE_ROUTE=1
1520    //     + FERRUM_VLLM_MOE=1 + device_route Some) ────────────────────
1521    //
1522    // Skips ALL host round-trips in the routing + bucket-plan stages:
1523    //   1. B::route_topk_softmax → device expert_ids + weights
1524    //   2. B::moe_build_pairs_by_token → device pairs / packed_idx /
1525    //      expert_offsets
1526    //   3. Gather via B::embedding_lookup_dev (device packed_idx)
1527    //   4. (rest of function reuses these device buffers; the vLLM
1528    //      MoE GEMM consumes them directly)
1529    //
1530    // This is the prerequisite for CUDA Graph capture over the MoE
1531    // layer loop in Qwen3MoeModel::decode_batch_internal.
1532    let use_vllm_moe = runtime_config.vllm_moe;
1533    // Device-routing path: enabled whenever the caller passes
1534    // pre-allocated `DeviceRouteScratch` AND `FERRUM_VLLM_MOE=1` is on.
1535    // No separate env var — the device path is strictly faster than
1536    // the host path (+15.4% c=32 on Qwen3-30B-A3B-GPTQ-Int4, RTX 4090
1537    // bench docs/bench/moe-phase3-vast-2026-05-12); the host path's
1538    // per-layer `try_gpu_route_topk_into_host` (D2H + cuStreamSynchronize)
1539    // was a per-layer GPU stall that compounded over 48 layers.
1540    //
1541    // Requires use_vllm_moe because the non-vLLM bucketed path needs
1542    // host phase1_dispatches / phase3_dispatches lists (one entry per
1543    // active expert with expert-id-dependent shape), which can't be
1544    // built on-device.
1545    //
1546    // Callers that need to force the host path for diagnostics can set
1547    // FERRUM_MOE_HOST_ROUTE=1 (opt-out).
1548    let use_device_route = device_route.is_some() && use_vllm_moe && !runtime_config.moe_host_route;
1549    let use_vllm_pair_ids = use_device_route && runtime_config.vllm_moe_pair_ids;
1550
1551    // Run device-side routing kernels EARLY so `dr.packed_token_idx`
1552    // is available for the device-buffer gather (embedding_lookup_dev),
1553    // `dr.pairs_by_token` / `dr.pair_weights` for moe_combine, and
1554    // `dr.sorted_tokens` / `dr.block_ids` / `dr.total_post_pad` (via
1555    // moe_align_block_size below) for the vLLM marlin_moe GEMM phases.
1556    // Kept alive in `dr_kept` until end of function.
1557    let mut dr_kept: Option<DeviceRouteScratch<'_, B>> = if use_device_route {
1558        let dr = device_route.expect("device_route is Some when use_device_route");
1559        B::route_topk_softmax(
1560            ctx,
1561            router_logits,
1562            dr.selected_ids,
1563            dr.pair_weights,
1564            batch,
1565            num_experts,
1566            top_k,
1567            norm_topk_prob,
1568        )?;
1569        if !use_vllm_pair_ids {
1570            B::moe_build_pairs_by_token(
1571                ctx,
1572                dr.selected_ids,
1573                dr.pairs_by_token,
1574                dr.packed_token_idx,
1575                dr.expert_offsets,
1576                batch * top_k,
1577                num_experts,
1578                top_k,
1579            )?;
1580        }
1581        Some(dr)
1582    } else {
1583        None
1584    };
1585
1586    // ── Routing + bucket plan (host) ─────────────────────────────────
1587    //
1588    // Skipped entirely under use_device_route — the device kernels run
1589    // by `dr_kept` above produce equivalent on-device buffers. The host
1590    // path stays for the legacy non-vllm bucketed dispatch and for
1591    // tests (where device_route is None).
1592    //
1593    // GPU fast-path: `try_gpu_route_topk_into_host` runs the same
1594    // route_topk_softmax kernel and D2Hs only `[batch, top_k]` ids +
1595    // weights (~1 KB at c=32) into RouterOutput. Host fallback covers
1596    // backends without the override (CPU / Metal / future).
1597    let plan: Option<&crate::moe::MoeBucketPlan> = if !use_device_route {
1598        let t_route_total = if prof {
1599            Some(std::time::Instant::now())
1600        } else {
1601            None
1602        };
1603        let gpu_route = B::try_gpu_route_topk_into_host(
1604            ctx,
1605            router_logits,
1606            &mut route_scratch.output.expert_ids,
1607            &mut route_scratch.output.expert_weights,
1608            batch,
1609            num_experts,
1610            top_k,
1611            norm_topk_prob,
1612        );
1613        if gpu_route.is_err() {
1614            let t_sync = if prof {
1615                Some(std::time::Instant::now())
1616            } else {
1617                None
1618            };
1619            B::sync(ctx);
1620            if let Some(t) = t_sync {
1621                MOE_BUCKET_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1622            }
1623            let t_d2h = if prof {
1624                Some(std::time::Instant::now())
1625            } else {
1626                None
1627            };
1628            let logits_host = B::to_vec(router_logits, batch * num_experts);
1629            if let Some(t) = t_d2h {
1630                MOE_BUCKET_D2H_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1631            }
1632            let t_route = if prof {
1633                Some(std::time::Instant::now())
1634            } else {
1635                None
1636            };
1637            crate::moe::router::route_into(
1638                &logits_host,
1639                batch,
1640                num_experts,
1641                top_k,
1642                norm_topk_prob,
1643                &mut route_scratch.output,
1644                &mut route_scratch.probs,
1645            );
1646            if let Some(t) = t_route {
1647                MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1648            }
1649        } else if let Some(t) = t_route_total {
1650            MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1651        }
1652        let t_plan = if prof {
1653            Some(std::time::Instant::now())
1654        } else {
1655            None
1656        };
1657        route_scratch
1658            .plan
1659            .rebuild_into(&route_scratch.output, batch, num_experts, top_k);
1660        if let Some(t) = t_plan {
1661            MOE_BUCKET_PLAN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1662        }
1663        Some(&route_scratch.plan)
1664    } else {
1665        None
1666    };
1667
1668    // ── Gather: x_packed[i] = x[packed_token_idx[i]] ───────────────────
1669    // Under use_device_route, read packed_token_idx from device (no
1670    // host roundtrip → graph-capturable). Else use the host plan.
1671    if !use_vllm_pair_ids {
1672        let t_gather = if prof {
1673            Some(std::time::Instant::now())
1674        } else {
1675            None
1676        };
1677        if let Some(ref dr) = dr_kept {
1678            B::embedding_lookup_dev(
1679                ctx,
1680                x,
1681                dr.packed_token_idx,
1682                x_packed,
1683                batch * top_k,
1684                hidden_size,
1685            );
1686        } else {
1687            let plan = plan.expect("plan is Some when !use_device_route");
1688            B::embedding_lookup(ctx, x, &plan.packed_token_idx, x_packed, hidden_size);
1689        }
1690        if let Some(t) = t_gather {
1691            B::sync(ctx);
1692            MOE_BUCKET_GATHER_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1693        }
1694    }
1695
1696    // ── Per-expert dispatch: gate_up + down GEMMs at m=tokens_per_expert
1697    //
1698    // Uses the strided GPTQ + silu_mul methods so we can pump through
1699    // the BIG packed buffers (allocated once at scratch alloc) without
1700    // any per-expert copies. Each expert gets its column-slice of the
1701    // shared stacked Marlin tile via expert_offset; the row-slice of
1702    // the packed input/output buffers via in_row_offset / out_row_offset.
1703    let gate_up_dim_per_expert = 2 * expert_intermediate;
1704    let down_n_per_expert = hidden_size;
1705    // Bulk-zero the gate_up workspace ONCE before phase 1 for the
1706    // non-vLLM Marlin paths. The vLLM marlin_moe_wna16 kernel resets
1707    // its lock slots internally on the reduce path; vLLM itself only
1708    // zeros this workspace at allocation time. Keep
1709    // FERRUM_VLLM_MOE_ZERO_WS=1 as an A/B escape hatch.
1710    let gu_store = experts.gate_up_stacked_store(0).ok_or_else(|| {
1711        FerrumError::model(
1712            "moe_forward_bucketed requires stacked gate_up store \
1713             (load via Qwen3MoeModel::new_safetensors)",
1714        )
1715    })?;
1716    let zero_marlin_workspace = !use_vllm_moe || runtime_config.vllm_moe_zero_ws;
1717    if zero_marlin_workspace {
1718        let _ = gu_store.zero_workspace(ctx);
1719    }
1720
1721    // Decide path: vLLM marlin_moe_wna16 fused or per-expert bucketed
1722    // GEMMs. Under use_device_route, build routing on-device via
1723    // `moe_align_block_size`; under use_vllm_moe alone, host-build it.
1724    // Either way the GEMM dispatcher takes `&Buffer` for the 3 routing
1725    // arrays.
1726    let total_pairs_active = batch * top_k;
1727    // ── Dynamic moe_block_size policy ────────────────────────────────────
1728    //
1729    // Marlin-MoE kernel template is instantiated for thread_m_blocks ∈
1730    // {1, 2, 3, 4} via COMMON_GET_IF_M1 / COMMON_GET_IF_M234 in
1731    // vllm_marlin_moe/ops.cu. Each maps to block_size = thread_m_blocks
1732    // × 16 ∈ {16, 32, 48, 64}. Picking the right one is a classic
1733    // throughput-vs-padding-waste tradeoff:
1734    //
1735    //   block_size=16 :  16 thread_n=128 num_threads=128 (small_batch tile)
1736    //   block_size=32+:  16 thread_n=256 num_threads=256 (large_batch tile,
1737    //                    8× more work per kernel launch)
1738    //
1739    // Larger tile = more arithmetic per memory load = higher DRAM
1740    // utilization. But each expert pads its actual token count up to
1741    // a multiple of block_size — sparse routing (many experts, few
1742    // tokens each) bleeds into massive padding waste.
1743    //
1744    // Test data (commit ccba35f static block=64 vs reverted block=16):
1745    //   bench/v0.2-cuda dmon @ c=32:
1746    //     block=16  →  SM=99%  DRAM=50%  (mem-stalled, tile too small)
1747    //     block=64  →  varies wildly:
1748    //                    same-prompt c=32  : 2078 tok/s (+100% vs block=16)
1749    //                    apples c=32 diverse: 921 tok/s (-11% vs block=16)
1750    //
1751    // Decision rule: pick the largest block_size whose padding overhead
1752    // would still be ≤ ~30%. The host-routing path has `plan.expert_offsets`
1753    // which gives exact m_e per expert — pick by actual data. The
1754    // device-routing path doesn't have host visibility so falls back to
1755    // a conservative 16 (matches pre-PR behaviour, no regression).
1756    //
1757    // Worst-case scratch sizing: 64 (the largest block_size we'd pick).
1758    // `Qwen3MoeScratch.route_sorted_tokens_dev` capacity is allocated
1759    // assuming this upper bound.
1760    let max_block_size: usize = 64;
1761    let moe_block_size: usize = pick_moe_block_size_with_config(
1762        runtime_config,
1763        plan,
1764        num_experts,
1765        use_device_route,
1766        total_pairs_active,
1767    );
1768    debug_assert!(
1769        moe_block_size <= max_block_size,
1770        "moe_block_size {moe_block_size} exceeds scratch worst-case {max_block_size}"
1771    );
1772    // sorted_max bound — passed to moe_align as a runtime cap so it
1773    // never writes past `total_padded` for the chosen block_size. We use
1774    // the picked `moe_block_size`, not `max_block_size`, so moe_align
1775    // doesn't sentinel-fill the slack between the actual padded count
1776    // and the worst-case buffer capacity (saves ~6 KB of writes per
1777    // layer × 48 layers × 32 layer-loop iters when block_size lands at 16).
1778    // The buffer itself is sized for max_block_size in qwen3_moe.rs.
1779    let sorted_max_size = batch * top_k + num_experts * moe_block_size;
1780    let vllm_routing_owned: Option<ferrum_kernels::backend::MoeRouting<B>> =
1781        if use_vllm_moe && !use_device_route {
1782            let plan = plan.expect("plan is Some when host vllm builder runs");
1783            let mut padded_offsets = Vec::with_capacity(num_experts + 1);
1784            let mut acc = 0usize;
1785            for e in 0..num_experts {
1786                padded_offsets.push(acc);
1787                let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1788                let pe = m_e.div_ceil(moe_block_size) * moe_block_size;
1789                acc += pe;
1790            }
1791            padded_offsets.push(acc);
1792            let total_padded = acc;
1793            let total_blocks = total_padded / moe_block_size;
1794            let sentinel = total_pairs_active as i32;
1795
1796            let mut sorted_token_ids = vec![sentinel; total_padded];
1797            let mut expert_ids = vec![0i32; total_blocks];
1798            for e in 0..num_experts {
1799                let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1800                if m_e == 0 {
1801                    continue;
1802                }
1803                let p_off = padded_offsets[e];
1804                let real_off = plan.expert_offsets[e];
1805                for i in 0..m_e {
1806                    sorted_token_ids[p_off + i] = (real_off + i) as i32;
1807                }
1808                let blocks_for_e = (padded_offsets[e + 1] - p_off) / moe_block_size;
1809                let block_start = p_off / moe_block_size;
1810                for b in 0..blocks_for_e {
1811                    expert_ids[block_start + b] = e as i32;
1812                }
1813            }
1814            let num_tokens_past_padded = vec![total_padded as i32];
1815            Some(B::upload_moe_routing(
1816                ctx,
1817                &sorted_token_ids,
1818                &expert_ids,
1819                &num_tokens_past_padded,
1820            )?)
1821        } else {
1822            None
1823        };
1824
1825    // Device-side moe_align_block_size — under use_device_route, fill
1826    // dr.{sorted_tokens, block_ids, total_post_pad} on device from
1827    // dr.selected_ids. No host roundtrip → captures cleanly.
1828    if use_device_route {
1829        let dr = dr_kept
1830            .as_mut()
1831            .expect("dr_kept is Some when use_device_route");
1832        if use_vllm_pair_ids {
1833            B::moe_align_block_size_pair_ids(
1834                ctx,
1835                dr.selected_ids,
1836                dr.sorted_tokens,
1837                dr.block_ids,
1838                dr.total_post_pad,
1839                batch * top_k,
1840                num_experts,
1841                moe_block_size,
1842                sorted_max_size,
1843            )?;
1844        } else {
1845            B::moe_align_block_size(
1846                ctx,
1847                dr.selected_ids,
1848                dr.sorted_tokens,
1849                dr.block_ids,
1850                dr.total_post_pad,
1851                batch * top_k,
1852                num_experts,
1853                moe_block_size,
1854                sorted_max_size,
1855            )?;
1856        }
1857    }
1858
1859    // Resolve the 3 routing buffers for vLLM phase 1/3 GEMM. Either
1860    // from dr_kept (device-built by moe_align_block_size) or from
1861    // vllm_routing_owned (host-built + uploaded). None → use legacy
1862    // per-expert batched GEMM path.
1863    let vllm_refs: Option<(&B::Buffer, &B::Buffer, &B::Buffer)> = if use_device_route {
1864        let dr = dr_kept
1865            .as_ref()
1866            .expect("dr_kept is Some when use_device_route");
1867        Some((&*dr.sorted_tokens, &*dr.block_ids, &*dr.total_post_pad))
1868    } else if let Some(r) = vllm_routing_owned.as_ref() {
1869        Some((
1870            &r.sorted_token_ids,
1871            &r.expert_ids,
1872            &r.num_tokens_past_padded,
1873        ))
1874    } else {
1875        None
1876    };
1877
1878    // Phase 1/3 batched-GEMM dispatch lists. Only built (and read) for
1879    // the non-vLLM path. Under use_device_route the host plan is None
1880    // anyway, so we'd fail to build them — skip via vllm_refs.is_some.
1881    let phase1_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
1882        let plan = plan.expect("plan is Some when batched GEMM path runs");
1883        let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
1884        for e in 0..num_experts {
1885            let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1886            if m_e == 0 {
1887                continue;
1888            }
1889            let pair_off = plan.expert_offsets[e];
1890            v.push((e, pair_off, pair_off, m_e));
1891        }
1892        v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
1893        v
1894    } else {
1895        Vec::new()
1896    };
1897    let t_gemm1 = if prof {
1898        Some(std::time::Instant::now())
1899    } else {
1900        None
1901    };
1902    if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
1903        // fp32_reduce path: kernel writes C directly via global reduce.
1904        if use_vllm_pair_ids {
1905            gu_store.gemm_phase_vllm(
1906                ctx,
1907                x,
1908                sorted_tokens,
1909                block_ids,
1910                total_post_pad,
1911                gate_up_packed,
1912                batch,
1913                moe_block_size,
1914                top_k,
1915            )?;
1916        } else {
1917            gu_store.gemm_phase_vllm(
1918                ctx,
1919                x_packed,
1920                sorted_tokens,
1921                block_ids,
1922                total_post_pad,
1923                gate_up_packed,
1924                total_pairs_active,
1925                moe_block_size,
1926                1, // top_k=1: pre-gathered rows already index packed input directly
1927            )?;
1928        }
1929    } else {
1930        gu_store.gemm_phase_batched(
1931            ctx,
1932            x_packed,
1933            &phase1_dispatches,
1934            gate_up_packed,
1935            hidden_size,
1936        )?;
1937    }
1938    if let Some(t) = t_gemm1 {
1939        B::sync(ctx);
1940        MOE_BUCKET_GEMM1_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1941    }
1942
1943    // Phase 2: SiLU(gate) * up — single launch covering ALL active
1944    // expert rows in the packed buffer. The unused rows (zeros from
1945    // experts with m_e=0) just produce zeros that the combine step
1946    // ignores via pairs_by_token. Saves num_active_experts-1 launches
1947    // per layer.
1948    let total_pairs_active = batch * top_k;
1949    let t_silu = if prof {
1950        Some(std::time::Instant::now())
1951    } else {
1952        None
1953    };
1954    B::fused_silu_mul_split(
1955        ctx,
1956        gate_up_packed,
1957        silu_packed,
1958        total_pairs_active,
1959        expert_intermediate,
1960    );
1961    if let Some(t) = t_silu {
1962        B::sync(ctx);
1963        MOE_BUCKET_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
1964    }
1965
1966    // Phase 3: down GEMM per active expert. Multi-stream batched.
1967    let d_store = experts.down_stacked_store(0).ok_or_else(|| {
1968        FerrumError::model(
1969            "moe_forward_bucketed requires stacked down store \
1970             (load via Qwen3MoeModel::new_safetensors)",
1971        )
1972    })?;
1973    if zero_marlin_workspace {
1974        let _ = d_store.zero_workspace(ctx);
1975    }
1976    let phase3_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
1977        let plan = plan.expect("plan is Some when batched GEMM path runs");
1978        let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
1979        for e in 0..num_experts {
1980            let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
1981            if m_e == 0 {
1982                continue;
1983            }
1984            let pair_off = plan.expert_offsets[e];
1985            v.push((e, pair_off, pair_off, m_e));
1986        }
1987        v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
1988        v
1989    } else {
1990        Vec::new()
1991    };
1992    let t_gemm3 = if prof {
1993        Some(std::time::Instant::now())
1994    } else {
1995        None
1996    };
1997    if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
1998        d_store.gemm_phase_vllm(
1999            ctx,
2000            silu_packed,
2001            sorted_tokens,
2002            block_ids,
2003            total_post_pad,
2004            down_packed,
2005            total_pairs_active,
2006            moe_block_size,
2007            1,
2008        )?;
2009    } else {
2010        d_store.gemm_phase_batched(
2011            ctx,
2012            silu_packed,
2013            &phase3_dispatches,
2014            down_packed,
2015            expert_intermediate,
2016        )?;
2017    }
2018    if let Some(t) = t_gemm3 {
2019        B::sync(ctx);
2020        MOE_BUCKET_GEMM3_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
2021    }
2022
2023    // ── Combine: out[b, h] = Σ_k weights[b,k] * down_packed[pairs_by_token[b,k], h]
2024    //
2025    // Two paths for the pairs/weights device buffers:
2026    //
2027    //   (a) device-route mode: reuse `dr_kept` populated up top by
2028    //       B::route_topk_softmax + B::moe_build_pairs_by_token. No
2029    //       host→device upload, so this is graph-capturable when
2030    //       wrapped in begin_graph_capture.
2031    //
2032    //   (b) legacy: upload host plan (plan.pairs_by_token /
2033    //       plan.pair_weights) via from_slice_typed. Records host
2034    //       pointer; captures stale on replay.
2035    //
2036    // Both produce mathematically equivalent outputs — device path
2037    // does the same counting-sort the host plan rebuild does, just
2038    // on-device via the moe_build_pairs kernel.
2039    let total_pairs = batch * top_k;
2040    let t_comb = if prof {
2041        Some(std::time::Instant::now())
2042    } else {
2043        None
2044    };
2045    if use_vllm_pair_ids {
2046        let dr = dr_kept
2047            .as_ref()
2048            .expect("dr_kept is Some when use_vllm_pair_ids");
2049        B::weighted_sum_batched(
2050            ctx,
2051            down_packed,
2052            dr.pair_weights,
2053            out,
2054            batch,
2055            top_k,
2056            hidden_size,
2057        )?;
2058    } else {
2059        let (pairs_ref, weights_ref);
2060        let _pairs_owned;
2061        let _weights_owned;
2062        if let Some(ref dr) = dr_kept {
2063            pairs_ref = &*dr.pairs_by_token;
2064            weights_ref = &*dr.pair_weights;
2065        } else {
2066            let plan = plan.expect("plan is Some when host moe_combine runs");
2067            _pairs_owned = B::from_slice_typed::<i32>(&plan.pairs_by_token);
2068            _weights_owned = B::from_slice_typed::<f32>(&plan.pair_weights);
2069            pairs_ref = &_pairs_owned;
2070            weights_ref = &_weights_owned;
2071        }
2072        B::moe_combine(
2073            ctx,
2074            down_packed,
2075            pairs_ref,
2076            weights_ref,
2077            out,
2078            batch,
2079            hidden_size,
2080            top_k,
2081            total_pairs,
2082        );
2083    }
2084    if let Some(t) = t_comb {
2085        B::sync(ctx);
2086        MOE_BUCKET_COMBINE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
2087    }
2088
2089    Ok(())
2090}
2091
2092/// Run MoE forward on CPU.
2093///
2094/// Inputs:
2095///   - `x`: `[batch, hidden_size]` row-major hidden states (post-attention,
2096///          post-residual — i.e. what the dense MLP would normally see).
2097///   - `router`: top-K assignments + weights from [`super::router::route`].
2098///   - `experts`: per-layer expert weights from [`ExpertStack::load_from_gguf`].
2099///
2100/// Output:
2101///   - `out`: `[batch, hidden_size]`. Resized + zero-initialised.
2102///
2103/// The function recomputes its scratch buffers each call. For tight
2104/// inner loops, callers will eventually want a pre-allocated workspace
2105/// (Phase 2F refactor). For now, this is the readable reference.
2106pub fn moe_forward_cpu(
2107    x: &[f32],
2108    batch: usize,
2109    hidden_size: usize,
2110    expert_intermediate: usize,
2111    top_k: usize,
2112    router: &RouterOutput,
2113    experts: &ExpertStack<CpuBackend>,
2114    out: &mut Vec<f32>,
2115) -> Result<()> {
2116    let n_experts = experts.num_experts();
2117
2118    if x.len() != batch * hidden_size {
2119        return Err(FerrumError::model(format!(
2120            "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
2121            x.len(),
2122            batch,
2123            hidden_size,
2124            batch * hidden_size
2125        )));
2126    }
2127    if router.expert_ids.len() != batch * top_k {
2128        return Err(FerrumError::model(format!(
2129            "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
2130            router.expert_ids.len(),
2131            batch,
2132            top_k,
2133            batch * top_k
2134        )));
2135    }
2136
2137    out.clear();
2138    out.resize(batch * hidden_size, 0.0);
2139
2140    let mut ctx = <CpuBackend as Backend>::new_context();
2141    let mut x_b: Vec<f32> = vec![0.0; hidden_size];
2142    let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
2143    let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
2144    let mut down_out: Vec<f32> = vec![0.0; hidden_size];
2145
2146    for b in 0..batch {
2147        x_b.copy_from_slice(&x[b * hidden_size..(b + 1) * hidden_size]);
2148
2149        for k in 0..top_k {
2150            let pair_idx = b * top_k + k;
2151            let expert_id = router.expert_ids[pair_idx] as usize;
2152            let weight = router.expert_weights[pair_idx];
2153
2154            if expert_id >= n_experts {
2155                return Err(FerrumError::model(format!(
2156                    "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
2157                )));
2158            }
2159
2160            // Gate||Up projection (fused) → [1, 2*expert_inter]
2161            experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
2162
2163            // SiLU(gate) * up → [1, expert_inter]
2164            <CpuBackend as Backend>::fused_silu_mul_split(
2165                &mut ctx,
2166                &gate_up_buf,
2167                &mut silu_mul_buf,
2168                1,
2169                expert_intermediate,
2170            );
2171
2172            // Down projection → [1, hidden]
2173            experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);
2174
2175            // Weighted accumulate into out[b, :]. Done host-side because
2176            // CpuBackend::Buffer = Vec<f32> and the trait doesn't yet
2177            // expose scaled-add.
2178            let out_row = &mut out[b * hidden_size..(b + 1) * hidden_size];
2179            for (o, d) in out_row.iter_mut().zip(down_out.iter()) {
2180                *o += weight * *d;
2181            }
2182        }
2183    }
2184
2185    Ok(())
2186}
2187
2188fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
2189    if actual != expected {
2190        return Err(FerrumError::model(format!(
2191            "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
2192        )));
2193    }
2194    Ok(())
2195}
2196
2197/// Map candle's `GgmlDType` to the kernel-side `GgufQuantType` for the
2198/// dtypes a backend can dispatch on. Returns `None` for any other dtype
2199/// (callers fall back to eager dequant).
2200fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
2201    let info = gguf.tensor_info(name).ok_or_else(|| {
2202        FerrumError::model(format!("ExpertStack: tensor info missing for '{name}'"))
2203    })?;
2204    Ok(match info.ggml_dtype {
2205        GgmlDType::Q4K => Some(GgufQuantType::Q4K),
2206        GgmlDType::Q6K => Some(GgufQuantType::Q6K),
2207        _ => None,
2208    })
2209}
2210
2211/// Per-expert block-byte count for a given k-quant flavour and element
2212/// count. Q4_K = 144 B / 256 elems, Q6_K = 210 B / 256 elems. Errors if
2213/// `n_elems` is not a multiple of the super-block size (256) — a Q-quant
2214/// invariant.
2215fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
2216    const QK_K: usize = 256;
2217    if n_elems % QK_K != 0 {
2218        return Err(FerrumError::model(format!(
2219            "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
2220        )));
2221    }
2222    let block_bytes = match kind {
2223        GgufQuantType::Q4K => 144,
2224        GgufQuantType::Q6K => 210,
2225        // Other k-quants are filtered out earlier via `quant_kind`; reaching here
2226        // with one would be a programming error.
2227        other => {
2228            return Err(FerrumError::model(format!(
2229                "ExpertStack {label}: unsupported k-quant flavour {other:?}"
2230            )))
2231        }
2232    };
2233    Ok((n_elems / QK_K) * block_bytes)
2234}
2235
2236fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
2237    let qt = gguf.read_tensor(name, device).map_err(candle_to_ferrum)?;
2238    let dense = qt.dequantize(device).map_err(candle_to_ferrum)?;
2239    let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
2240    flat.to_vec1::<f32>().map_err(candle_to_ferrum)
2241}
2242
2243fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
2244    FerrumError::model(format!("candle: {e}"))
2245}
2246
// Keeps the `CandleResult` import (`candle_core::Result`, imported at the top
// of the file) referenced, so it does not raise an unused-import warning.
// Candle errors in this chunk are converted via `map_err(candle_to_ferrum)`,
// which names `candle_core::Error` rather than this alias.
// `#[allow(dead_code)]` silences the warning for the alias itself.
// NOTE(review): if nothing else in the file uses `CandleResult`, removing both
// the import and this alias would be cleaner.
#[allow(dead_code)]
type _CandleResult<T> = CandleResult<T>;
2251
#[cfg(test)]
mod tests {
    use super::{pick_moe_block_size_with_config, MoeDispatchRuntimeConfig};

    /// All startup knobs are picked up from a single env-var list.
    ///
    /// `FERRUM_MOE_PROFILE="0"` and an empty `FERRUM_MOE_LOAD_TRACE` still
    /// enable their flags. The assertions pin that presence-only behaviour.
    #[test]
    fn moe_dispatch_runtime_config_parses_m3_startup_knobs() {
        let env = [
            ("FERRUM_MOE_PROFILE", "0"),
            ("FERRUM_DECODE_OP_PROFILE", "true"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "1"),
            ("FERRUM_VLLM_MOE_PAIR_IDS", "1"),
            ("FERRUM_MOE_LOAD_TRACE", ""),
            ("FERRUM_MOE_BLOCK_SIZE", "8"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "2048"),
            ("FERRUM_VLLM_MOE", "1"),
            ("FERRUM_MOE_HOST_ROUTE", "1"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        // Boolean toggles.
        assert!(cfg.moe_profile);
        assert!(cfg.decode_op_profile);
        assert!(cfg.vllm_moe_zero_ws);
        assert!(cfg.vllm_moe_pair_ids);
        assert!(cfg.moe_load_trace);
        assert!(cfg.vllm_moe);
        assert!(cfg.moe_host_route);

        // Numeric block-size policy.
        assert_eq!(cfg.moe_block_size, Some(8));
        assert_eq!(cfg.moe_large_m_block_size, Some(64));
        assert_eq!(cfg.moe_large_m_min_pairs, 2048);
    }

    /// Out-of-range or unparsable values fall back to the defaults.
    ///
    /// Non-`"1"` values for the `ZERO_WS` / `HOST_ROUTE` switches leave them
    /// off.
    #[test]
    fn moe_dispatch_runtime_config_bounds_invalid_block_values() {
        let env = [
            ("FERRUM_MOE_BLOCK_SIZE", "12"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "128"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "bad"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "true"),
            ("FERRUM_MOE_HOST_ROUTE", "0"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        assert_eq!(cfg.moe_block_size, None);
        assert_eq!(cfg.moe_large_m_block_size, None);
        assert_eq!(cfg.moe_large_m_min_pairs, 1024);
        assert!(!cfg.vllm_moe_zero_ws);
        assert!(!cfg.moe_host_route);
    }

    /// On the device-route path, the large-M block size applies only once
    /// the active pair count reaches `FERRUM_MOE_LARGE_M_MIN_PAIRS`.
    #[test]
    fn device_route_large_m_block_size_is_thresholded() {
        let cfg = MoeDispatchRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ]);

        let below_threshold = pick_moe_block_size_with_config(&cfg, None, 128, true, 256);
        assert_eq!(below_threshold, 16);

        let at_threshold = pick_moe_block_size_with_config(&cfg, None, 128, true, 1024);
        assert_eq!(at_threshold, 64);
    }

    /// An explicit `FERRUM_MOE_BLOCK_SIZE` takes precedence over the
    /// large-M policy, even above the pair threshold.
    #[test]
    fn global_moe_block_size_override_still_wins() {
        let cfg = MoeDispatchRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_BLOCK_SIZE", "32"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ]);

        let picked = pick_moe_block_size_with_config(&cfg, None, 128, true, 2048);
        assert_eq!(picked, 32);
    }
}