Skip to main content

ferrum_models/moe/
dispatch.rs

1//! Expert dispatch — load per-layer expert weights from a GGUF file and run
2//! the per-token MoE forward (top-K experts per token, weighted combine).
3//!
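//! Two entry points implement it: `moe_forward_cpu` (CPU reference that
//! allocates its own scratch) and the backend-generic `moe_forward<B>`
//! (caller-supplied scratch buffers). The algorithm is: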
6//!
7//! ```text
8//! for each token b in batch:
9//!     route token b → (expert_ids[K], weights[K])
10//!     out[b] = 0
11//!     for each (expert_id, weight) pair:
12//!         gate_up = experts.gate_up[expert_id].forward(x[b])     # [2*ffn]
13//!         silu_mul = silu(gate_up[..ffn]) * gate_up[ffn..]       # [ffn]
14//!         contribution = experts.down[expert_id].forward(silu_mul) # [hidden]
15//!         out[b] += weight * contribution
16//! ```
17//!
18//! The fused `gate || up` per-expert layout means we can call
19//! `Backend::fused_silu_mul_split` directly on the projection's output
20//! — same kernel ferrum already uses for dense Llama-family models.
21
22use std::path::Path;
23use std::sync::atomic::{AtomicU64, Ordering};
24
25use candle_core::quantized::GgmlDType;
26use candle_core::{Device, Result as CandleResult};
27use ferrum_kernels::backend::cpu::CpuBackend;
28use ferrum_kernels::backend::{Backend, GgufQuantType};
29use ferrum_kernels::Linear;
30use ferrum_quantization::gguf::GgufFile;
31use ferrum_quantization::{DenseLinear, QuantLinear};
32use ferrum_types::{FerrumError, Result};
33
34use crate::moe::router::RouterOutput;
35
// MoE per-op profiling counters. Public so the model wrapper can drain and
// print them (e.g. at end of decode). `*_US` counters hold accumulated wall
// time in microseconds; `*_CALLS` counters hold how many timed operations
// contributed. Both are updated with `Ordering::Relaxed` atomics. In this
// module only `moe_forward` updates them, and only while profiling is
// enabled via the `FERRUM_MOE_PROFILE` env var (see `moe_profile_enabled`).

/// Time (µs) spent in the `B::sync` that `moe_forward` issues before
/// reading the router logits back to the host.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed pre-routing syncs (one per `moe_forward` call).
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent in per-(token, expert) fused `gate || up` projections.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `gate || up` projections.
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent in `fused_silu_mul_split`.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `fused_silu_mul_split` calls.
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent in per-(token, expert) `down` projections.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `down` projections.
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent in `scaled_add_inplace` (weighted accumulate of one
/// expert's output into the per-token accumulator).
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `scaled_add_inplace` calls.
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent in `copy_slice` (loading `x[b]`, resetting the
/// accumulator, writing `out[b]`).
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Number of `copy_slice` calls timed (three per token).
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time (µs) spent reading router logits to the host and running routing
/// (`B::to_vec` + `route`).
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed host routing passes (one per `moe_forward` call).
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
53
/// Whether MoE per-op profiling is enabled.
///
/// Reads `FERRUM_MOE_PROFILE` on every call, so it can be toggled at runtime
/// (e.g. from tests). Profiling is OFF when the variable is unset, not valid
/// Unicode, empty, or one of `0` / `false` / `off` / `no` (case-insensitive,
/// surrounding whitespace ignored). Any other value (e.g. `1`) turns it ON.
///
/// Previously any set value, including `0`, enabled profiling. That was
/// surprising, and expensive: profiling makes `moe_forward` sync the device
/// after every op.
fn moe_profile_enabled() -> bool {
    match std::env::var("FERRUM_MOE_PROFILE") {
        Ok(raw) => {
            let v = raw.trim();
            let disabled = v.is_empty()
                || v == "0"
                || v.eq_ignore_ascii_case("false")
                || v.eq_ignore_ascii_case("off")
                || v.eq_ignore_ascii_case("no");
            !disabled
        }
        Err(_) => false,
    }
}
57
/// Per-layer MoE expert weights, in one of two storage shapes chosen when
/// the stack is built:
///
///   - **Per-expert**: `gate_up` and `down` are `[num_experts]`-long
///     vectors of `Box<dyn Linear<B>>`. Entry `e` runs expert `e`'s fused
///     `[gate; up]` projection or its `down` projection. This is the shape
///     that `moe_forward` and `moe_forward_cpu` consume.
///   - **Stacked-only**: when `ExpertStack::load_from_gguf` gets a stacked
///     store from `B::load_quant_experts` for all three roles, it leaves
///     `gate_up` and `down` **empty** and populates only the `*_stacked`
///     fields. In that shape `num_experts()` returns 0, and
///     `moe_forward` / `moe_forward_cpu` return an error instead of
///     running.
///
/// `B::Buffer` is hidden behind `Linear<B>`, so the struct is generic over
/// backend.
pub struct ExpertStack<B: Backend> {
    /// Fused `[gate; up]` projection per expert. Output shape per token:
    /// `[2 * expert_intermediate]` — the lower half is gate, upper is up.
    /// Empty in the stacked-only shape.
    pub gate_up: Vec<Box<dyn Linear<B>>>,
    /// `down` projection per expert. Output shape per token: `[hidden_size]`.
    /// Empty in the stacked-only shape.
    pub down: Vec<Box<dyn Linear<B>>>,
    /// Stacked-experts representation for backends that have a batched
    /// MoE indirect-dispatch kernel (Metal `gemv_q4kw_moe_id_f32` /
    /// `gemv_q6kw_moe_id_f32`). Holds **all experts** for one matmul
    /// role in a single `B::QuantStore` with byte stride between expert
    /// slabs, so a single dispatch can cover all selected (token, expert)
    /// pairs at decode m=1.
    ///
    /// `None` on backends without the kernel (CPU, CUDA-without-MoE-kernel)
    /// and on quant flavours that don't have a stacked path yet — callers
    /// fall back to the per-expert `gate_up` / `down` Linears in those
    /// cases.
    pub gate_stacked: Option<B::QuantStore>,
    /// Stacked `up` projections for all experts; same layout and
    /// availability rules as `gate_stacked`.
    pub up_stacked: Option<B::QuantStore>,
    /// Stacked `down` projections for all experts; same layout and
    /// availability rules as `gate_stacked`.
    pub down_stacked: Option<B::QuantStore>,
}
87
88impl<B: Backend> ExpertStack<B> {
89    /// Build from raw fp32 stacked tensors (test helper). Caller has
90    /// already dequantised and laid out the data:
91    ///   `gate_stack`: `[num_experts * expert_inter * hidden]`
92    ///   `up_stack`:   `[num_experts * expert_inter * hidden]`
93    ///   `down_stack`: `[num_experts * hidden * expert_inter]`
94    /// Each per-expert slice is row-major in the natural Linear shape.
95    pub fn from_dense_stacks(
96        gate_stack: &[f32],
97        up_stack: &[f32],
98        down_stack: &[f32],
99        num_experts: usize,
100        hidden_size: usize,
101        expert_intermediate: usize,
102    ) -> Result<Self> {
103        let gate_up_per_expert = expert_intermediate * hidden_size;
104        let down_per_expert = hidden_size * expert_intermediate;
105
106        check_size(
107            gate_stack.len(),
108            num_experts * gate_up_per_expert,
109            "gate_stack",
110        )?;
111        check_size(up_stack.len(), num_experts * gate_up_per_expert, "up_stack")?;
112        check_size(
113            down_stack.len(),
114            num_experts * down_per_expert,
115            "down_stack",
116        )?;
117
118        let mut gate_up = Vec::with_capacity(num_experts);
119        let mut down = Vec::with_capacity(num_experts);
120        for e in 0..num_experts {
121            let g_off = e * gate_up_per_expert;
122            let g_slice = &gate_stack[g_off..g_off + gate_up_per_expert];
123            let u_slice = &up_stack[g_off..g_off + gate_up_per_expert];
124
125            // Fused [gate; up] is [2 * expert_inter, hidden] row-major.
126            // We concatenate row-blocks so the first expert_inter rows are
127            // gate, the next expert_inter rows are up — the layout
128            // fused_silu_mul_split expects.
129            let mut fused = Vec::with_capacity(2 * gate_up_per_expert);
130            fused.extend_from_slice(g_slice);
131            fused.extend_from_slice(u_slice);
132            gate_up.push(Box::new(DenseLinear::<B>::from_rows(
133                &fused,
134                2 * expert_intermediate,
135                hidden_size,
136            )) as Box<dyn Linear<B>>);
137
138            let d_off = e * down_per_expert;
139            let d_slice = &down_stack[d_off..d_off + down_per_expert];
140            down.push(Box::new(DenseLinear::<B>::from_rows(
141                d_slice,
142                hidden_size,
143                expert_intermediate,
144            )) as Box<dyn Linear<B>>);
145        }
146        Ok(Self {
147            gate_up,
148            down,
149            gate_stacked: None,
150            up_stacked: None,
151            down_stacked: None,
152        })
153    }
154
155    /// Load all experts for one MoE layer from a GGUF file. Names follow
156    /// the GGUF convention: `blk.{layer_idx}.ffn_{gate,up,down}_exps.weight`.
157    ///
158    /// The loader picks between two strategies based on the on-disk dtype
159    /// of the expert tensors:
160    ///
161    ///   - **Quantised path** (Q4_K / Q6_K only): each expert's
162    ///     `gate || up` becomes a single `QuantLinear<B>` (Fused
163    ///     QuantStore — gate + up share `n_cols = hidden`), and `down` is
164    ///     a plain `QuantLinear<B>`. Block bytes stay compressed in
165    ///     backend memory; per-call dequant happens inside `gemm_quant`.
166    ///   - **Dense fallback** (everything else, e.g. F32 / F16 / Q5_K
167    ///     until a kernel ships): eager-dequant to fp32 and wrap
168    ///     `DenseLinear<B>`. Memory inflates ~7× vs Q4_K_M but the
169    ///     algorithm is correctness-equivalent and this is the path the
170    ///     synthetic-MoE test fixtures need.
171    ///
172    /// The runtime dispatcher (`moe_forward<B>`) doesn't see which path
173    /// was taken — it just calls `Linear::forward` per (token, expert).
174    pub fn load_from_gguf(
175        gguf: &GgufFile,
176        layer_idx: usize,
177        num_experts: usize,
178        hidden_size: usize,
179        expert_intermediate: usize,
180    ) -> Result<Self> {
181        if let Some(quant) = Self::try_load_quantised(
182            gguf,
183            layer_idx,
184            num_experts,
185            hidden_size,
186            expert_intermediate,
187        )? {
188            if std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok() {
189                eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
190            }
191            return Ok(quant);
192        }
193
194        if std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok() {
195            eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
196        }
197
198        let device = Device::Cpu;
199        let gate = read_dequant_flat(
200            gguf,
201            &format!("blk.{layer_idx}.ffn_gate_exps.weight"),
202            &device,
203        )?;
204        let up = read_dequant_flat(
205            gguf,
206            &format!("blk.{layer_idx}.ffn_up_exps.weight"),
207            &device,
208        )?;
209        let down = read_dequant_flat(
210            gguf,
211            &format!("blk.{layer_idx}.ffn_down_exps.weight"),
212            &device,
213        )?;
214        // Eager-dense path leaves stacked variants as None — no MoE
215        // fast path for synthesised / non-quantised expert tensors.
216        Self::from_dense_stacks(
217            &gate,
218            &up,
219            &down,
220            num_experts,
221            hidden_size,
222            expert_intermediate,
223        )
224    }
225
226    /// Attempt the quantised path. Returns `Ok(None)` if any of the three
227    /// tensors isn't a supported k-quant flavour (Q4_K / Q6_K) or if the
228    /// shape doesn't match the expected per-expert tile size — caller
229    /// then takes the eager-dequant fallback. Returns `Err` only on a
230    /// genuine load failure (missing tensor, byte-count mismatch).
231    fn try_load_quantised(
232        gguf: &GgufFile,
233        layer_idx: usize,
234        num_experts: usize,
235        hidden_size: usize,
236        expert_intermediate: usize,
237    ) -> Result<Option<Self>> {
238        let device = Device::Cpu;
239
240        let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
241        let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
242        let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");
243
244        // Inspect tensor info up front — if any tensor isn't a k-quant
245        // flavour the backend can dispatch on, bail to the dense path
246        // before paying the byte-read cost.
247        let gate_kind = match quant_kind(gguf, &gate_name)? {
248            Some(k) => k,
249            None => return Ok(None),
250        };
251        let up_kind = match quant_kind(gguf, &up_name)? {
252            Some(k) => k,
253            None => return Ok(None),
254        };
255        let down_kind = match quant_kind(gguf, &down_name)? {
256            Some(k) => k,
257            None => return Ok(None),
258        };
259
260        // Slice the three 3-D quantised expert stacks directly from
261        // the mmap. These are the dominant memory cost on Qwen3-MoE
262        // (~14 GB for Qwen3-30B-A3B); going through candle's
263        // `read_tensor` would copy them into a heap `Vec<u8>` first,
264        // then `load_quant_experts` would copy again into the Metal
265        // buffer — together doubling the working set and pushing a
266        // 32 GB Mac into swap. With this slice + the Metal mmap
267        // registry, we avoid both copies (steady state: just the
268        // file mmap).
269        let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
270            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
271        })?;
272        let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
273            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
274        })?;
275        let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
276            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
277        })?;
278        let _ = device; // candle device no longer needed for the byte read
279
280        // Per-expert byte stride for each tensor. The 3-D layout is
281        // contiguous, [num_experts, rows, cols] row-major, so each
282        // expert's slab is exactly `total_bytes / num_experts`.
283        let gate_per = block_bytes_for(
284            gate_kind,
285            expert_intermediate * hidden_size,
286            "ffn_gate_exps",
287        )?;
288        let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
289        let down_per = block_bytes_for(
290            down_kind,
291            hidden_size * expert_intermediate,
292            "ffn_down_exps",
293        )?;
294
295        check_size(
296            gate_bytes.len(),
297            num_experts * gate_per,
298            "ffn_gate_exps bytes",
299        )?;
300        check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
301        check_size(
302            down_bytes.len(),
303            num_experts * down_per,
304            "ffn_down_exps bytes",
305        )?;
306
307        // Try the stacked-experts fast path FIRST. If the backend has a
308        // batched MoE kernel (Metal `gemv_q*kw_moe_id_f32`), we want to
309        // hold the experts only as one big stacked buffer per role —
310        // not as 128 per-expert MetalQuantStores PLUS the stacked one
311        // (that would double-allocate ~17 GB on a 32 GB Mac, which on
312        // Qwen3-30B-A3B Q4_K_M sends the model into swap and tanks
313        // both load and forward time).
314        let gate_stacked = B::load_quant_experts(
315            gate_kind,
316            gate_bytes,
317            num_experts,
318            expert_intermediate,
319            hidden_size,
320        )
321        .ok();
322        let up_stacked = B::load_quant_experts(
323            up_kind,
324            up_bytes,
325            num_experts,
326            expert_intermediate,
327            hidden_size,
328        )
329        .ok();
330        let down_stacked = B::load_quant_experts(
331            down_kind,
332            down_bytes,
333            num_experts,
334            hidden_size,
335            expert_intermediate,
336        )
337        .ok();
338
339        // Decide the storage shape:
340        //   * Stacked-only (Metal MoE fast path): all three stacked
341        //     loaders succeeded — skip per-expert and use stacked
342        //     for both decode and prefill. Cuts memory in half.
343        //   * Per-expert: stacked path is incomplete or unsupported —
344        //     load 128-per-layer QuantLinears and let `moe_forward`
345        //     drive the per-(token, expert) loop on top of them.
346        let stacked_complete =
347            gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
348
349        let (gate_up, down) = if stacked_complete {
350            // No per-expert needed — `moe_forward_stacked_decode_impl`
351            // and the per-token prefill loop both use the stacked buffers.
352            (Vec::new(), Vec::new())
353        } else {
354            let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
355            let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
356            for e in 0..num_experts {
357                let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
358                let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
359                let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
360
361                let parts: [(GgufQuantType, &[u8], usize); 2] = [
362                    (gate_kind, g_slice, expert_intermediate),
363                    (up_kind, u_slice, expert_intermediate),
364                ];
365                let gate_up_e = match QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) {
366                    Ok(q) => q,
367                    Err(_) => return Ok(None),
368                };
369                gate_up.push(Box::new(gate_up_e) as Box<dyn Linear<B>>);
370
371                let down_e = match QuantLinear::<B>::from_gguf_bytes(
372                    down_kind,
373                    d_slice,
374                    hidden_size,
375                    expert_intermediate,
376                ) {
377                    Ok(q) => q,
378                    Err(_) => return Ok(None),
379                };
380                down.push(Box::new(down_e) as Box<dyn Linear<B>>);
381            }
382            (gate_up, down)
383        };
384
385        Ok(Some(Self {
386            gate_up,
387            down,
388            gate_stacked,
389            up_stacked,
390            down_stacked,
391        }))
392    }
393
394    /// Convenience: open a GGUF and load layer `layer_idx`. The GGUF
395    /// stays open inside this call only — for multi-layer loads use
396    /// [`Self::load_from_gguf`] with a shared [`GgufFile`].
397    pub fn open_and_load(
398        path: impl AsRef<Path>,
399        layer_idx: usize,
400        num_experts: usize,
401        hidden_size: usize,
402        expert_intermediate: usize,
403    ) -> Result<Self> {
404        let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
405        Self::load_from_gguf(
406            &gguf,
407            layer_idx,
408            num_experts,
409            hidden_size,
410            expert_intermediate,
411        )
412    }
413
414    /// `num_experts` for the layer (consistency check helper).
415    ///
416    /// Returns the per-expert Vec length, OR — when the stacked-only
417    /// path is in effect (Metal MoE fast path with empty per-expert
418    /// Vecs) — falls back to a stored count via the stacked variants.
419    /// In the stacked-only case there's no Vec to count, so this method
420    /// is mostly used by tests on the per-expert path.
421    pub fn num_experts(&self) -> usize {
422        debug_assert_eq!(
423            self.gate_up.len(),
424            self.down.len(),
425            "ExpertStack: gate_up and down disagree on expert count"
426        );
427        self.gate_up.len()
428    }
429}
430
431/// Backend-generic MoE forward.
432///
433/// Equivalent of [`moe_forward_cpu`] but parameterised on `B: Backend`
434/// so Metal / CUDA paths can dispatch the same per-(token, expert) loop
435/// using their own kernels for the gemv + silu + scaled-add primitives.
436///
437/// The caller pre-supplies all scratch buffers — this function does no
438/// allocation, which matters because it's invoked from inside the
439/// transformer's `forward_layer` where allocation during graph capture
440/// (CUDA) would corrupt the captured graph.
441///
442/// Buffer contract (lengths, sized at scratch alloc time):
443///   - `x`            : `[batch * hidden]` post-RMSNorm activations
444///   - `router_logits`: `[batch * num_experts]` raw router output
445///   - `out`          : `[batch * hidden]` — caller is responsible for
446///                      zeroing this before the call (we accumulate,
447///                      not assign)
448///   - `x_single`     : `[hidden]` per-token input slice
449///   - `acc_buf`      : `[hidden]` per-token output accumulator (kept
450///                      separate from `x_single` so the gate_up gemv
451///                      can consume `x_single` repeatedly across the
452///                      top_k loop without an inter-pair restore)
453///   - `gate_up_buf`  : `[2 * expert_inter]` per-(token, expert) gemv out
454///   - `silu_buf`     : `[expert_inter]`
455///   - `down_buf`     : `[hidden]` per-(token, expert) accumulate src
456///
457/// Routing (softmax + top-K + optional renorm) runs host-side using
458/// `B::to_vec(router_logits, …)` — the routing computation is small
459/// (`batch * num_experts` floats) and the top-K is a sort, both of
460/// which dwarf in cost any plausible host↔device transfer.
461///
462/// Per-pair dispatch budget (m=1, Metal):
463///   gate_up Fused gemv (2 parts) + silu + down gemv + scaled_add
464///   = 5 dispatches/pair. Plus 2 copy_slice/token (load x_single,
465///   write acc_buf back to out[b]). With top_k=8 and 48 layers, that's
466///   8×5 + 2 = 42 dispatches/layer × 48 ≈ 2k/token (vs. ~3.5k in the
467///   pre-PR scheme that round-tripped through `out` per pair).
468#[allow(clippy::too_many_arguments)]
469pub fn moe_forward<B: Backend>(
470    ctx: &mut B::Context,
471    x: &B::Buffer,
472    router_logits: &B::Buffer,
473    out: &mut B::Buffer,
474    batch: usize,
475    hidden_size: usize,
476    expert_intermediate: usize,
477    num_experts: usize,
478    top_k: usize,
479    norm_topk_prob: bool,
480    experts: &ExpertStack<B>,
481    x_single: &mut B::Buffer,
482    acc_buf: &mut B::Buffer,
483    gate_up_buf: &mut B::Buffer,
484    silu_buf: &mut B::Buffer,
485    down_buf: &mut B::Buffer,
486    zero_hidden: &B::Buffer,
487) -> Result<()> {
488    let n_experts = experts.num_experts();
489    if n_experts != num_experts {
490        return Err(FerrumError::model(format!(
491            "moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
492        )));
493    }
494
495    let prof = moe_profile_enabled();
496
497    // Routing on host. Sized batch*num_experts (e.g. 512*128 = 64k floats
498    // per layer for Qwen3-30B-A3B prefill); cheap relative to the per-
499    // expert gemvs that follow.
500    let t0 = if prof {
501        Some(std::time::Instant::now())
502    } else {
503        None
504    };
505    B::sync(ctx);
506    if let Some(t) = t0 {
507        MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
508        MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
509    }
510
511    let t0 = if prof {
512        Some(std::time::Instant::now())
513    } else {
514        None
515    };
516    let logits_host = B::to_vec(router_logits, batch * num_experts);
517    let route_out =
518        crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
519    if let Some(t) = t0 {
520        MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
521        MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
522    }
523
524    for b in 0..batch {
525        // Load x[b] into x_single + reset accumulator.
526        let t0 = if prof {
527            Some(std::time::Instant::now())
528        } else {
529            None
530        };
531        B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
532        B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
533        if let Some(t) = t0 {
534            MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
535            MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
536        }
537
538        for k in 0..top_k {
539            let pair = b * top_k + k;
540            let expert_id = route_out.expert_ids[pair] as usize;
541            let weight = route_out.expert_weights[pair];
542            if expert_id >= num_experts {
543                return Err(FerrumError::model(format!(
544                    "moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
545                )));
546            }
547
548            // Fused gate||up gemv → [2 * expert_inter]
549            let t0 = if prof {
550                B::sync(ctx);
551                Some(std::time::Instant::now())
552            } else {
553                None
554            };
555            experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
556            if let Some(t) = t0 {
557                B::sync(ctx);
558                MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
559                MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
560            }
561
562            // SiLU(gate) * up → [expert_inter]
563            let t0 = if prof {
564                Some(std::time::Instant::now())
565            } else {
566                None
567            };
568            B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
569            if let Some(t) = t0 {
570                B::sync(ctx);
571                MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
572                MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
573            }
574
575            // down gemv → [hidden]
576            let t0 = if prof {
577                Some(std::time::Instant::now())
578            } else {
579                None
580            };
581            experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
582            if let Some(t) = t0 {
583                B::sync(ctx);
584                MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
585                MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
586            }
587
588            // acc_buf += weight * down_buf
589            let t0 = if prof {
590                Some(std::time::Instant::now())
591            } else {
592                None
593            };
594            B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
595            if let Some(t) = t0 {
596                B::sync(ctx);
597                MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
598                MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
599            }
600        }
601
602        // Final write: out[b] = acc_buf
603        let t0 = if prof {
604            Some(std::time::Instant::now())
605        } else {
606            None
607        };
608        B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
609        if let Some(t) = t0 {
610            MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
611            MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
612        }
613    }
614
615    Ok(())
616}
617
618/// Run MoE forward on CPU.
619///
620/// Inputs:
621///   - `x`: `[batch, hidden_size]` row-major hidden states (post-attention,
622///          post-residual — i.e. what the dense MLP would normally see).
623///   - `router`: top-K assignments + weights from [`super::router::route`].
624///   - `experts`: per-layer expert weights from [`ExpertStack::load_from_gguf`].
625///
626/// Output:
627///   - `out`: `[batch, hidden_size]`. Resized + zero-initialised.
628///
629/// The function recomputes its scratch buffers each call. For tight
630/// inner loops, callers will eventually want a pre-allocated workspace
631/// (Phase 2F refactor). For now, this is the readable reference.
632pub fn moe_forward_cpu(
633    x: &[f32],
634    batch: usize,
635    hidden_size: usize,
636    expert_intermediate: usize,
637    top_k: usize,
638    router: &RouterOutput,
639    experts: &ExpertStack<CpuBackend>,
640    out: &mut Vec<f32>,
641) -> Result<()> {
642    let n_experts = experts.num_experts();
643
644    if x.len() != batch * hidden_size {
645        return Err(FerrumError::model(format!(
646            "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
647            x.len(),
648            batch,
649            hidden_size,
650            batch * hidden_size
651        )));
652    }
653    if router.expert_ids.len() != batch * top_k {
654        return Err(FerrumError::model(format!(
655            "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
656            router.expert_ids.len(),
657            batch,
658            top_k,
659            batch * top_k
660        )));
661    }
662
663    out.clear();
664    out.resize(batch * hidden_size, 0.0);
665
666    let mut ctx = <CpuBackend as Backend>::new_context();
667    let mut x_b: Vec<f32> = vec![0.0; hidden_size];
668    let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
669    let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
670    let mut down_out: Vec<f32> = vec![0.0; hidden_size];
671
672    for b in 0..batch {
673        x_b.copy_from_slice(&x[b * hidden_size..(b + 1) * hidden_size]);
674
675        for k in 0..top_k {
676            let pair_idx = b * top_k + k;
677            let expert_id = router.expert_ids[pair_idx] as usize;
678            let weight = router.expert_weights[pair_idx];
679
680            if expert_id >= n_experts {
681                return Err(FerrumError::model(format!(
682                    "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
683                )));
684            }
685
686            // Gate||Up projection (fused) → [1, 2*expert_inter]
687            experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
688
689            // SiLU(gate) * up → [1, expert_inter]
690            <CpuBackend as Backend>::fused_silu_mul_split(
691                &mut ctx,
692                &gate_up_buf,
693                &mut silu_mul_buf,
694                1,
695                expert_intermediate,
696            );
697
698            // Down projection → [1, hidden]
699            experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);
700
701            // Weighted accumulate into out[b, :]. Done host-side because
702            // CpuBackend::Buffer = Vec<f32> and the trait doesn't yet
703            // expose scaled-add.
704            let out_row = &mut out[b * hidden_size..(b + 1) * hidden_size];
705            for (o, d) in out_row.iter_mut().zip(down_out.iter()) {
706                *o += weight * *d;
707            }
708        }
709    }
710
711    Ok(())
712}
713
714fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
715    if actual != expected {
716        return Err(FerrumError::model(format!(
717            "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
718        )));
719    }
720    Ok(())
721}
722
723/// Map candle's `GgmlDType` to the kernel-side `GgufQuantType` for the
724/// dtypes a backend can dispatch on. Returns `None` for any other dtype
725/// (callers fall back to eager dequant).
726fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
727    let info = gguf.tensor_info(name).ok_or_else(|| {
728        FerrumError::model(format!("ExpertStack: tensor info missing for '{name}'"))
729    })?;
730    Ok(match info.ggml_dtype {
731        GgmlDType::Q4K => Some(GgufQuantType::Q4K),
732        GgmlDType::Q6K => Some(GgufQuantType::Q6K),
733        _ => None,
734    })
735}
736
737/// Per-expert block-byte count for a given k-quant flavour and element
738/// count. Q4_K = 144 B / 256 elems, Q6_K = 210 B / 256 elems. Errors if
739/// `n_elems` is not a multiple of the super-block size (256) — a Q-quant
740/// invariant.
741fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
742    const QK_K: usize = 256;
743    if n_elems % QK_K != 0 {
744        return Err(FerrumError::model(format!(
745            "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
746        )));
747    }
748    let block_bytes = match kind {
749        GgufQuantType::Q4K => 144,
750        GgufQuantType::Q6K => 210,
751        // Other k-quants are filtered out earlier via `quant_kind`; reaching here
752        // with one would be a programming error.
753        other => {
754            return Err(FerrumError::model(format!(
755                "ExpertStack {label}: unsupported k-quant flavour {other:?}"
756            )))
757        }
758    };
759    Ok((n_elems / QK_K) * block_bytes)
760}
761
762fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
763    let qt = gguf.read_tensor(name, device).map_err(candle_to_ferrum)?;
764    let dense = qt.dequantize(device).map_err(candle_to_ferrum)?;
765    let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
766    flat.to_vec1::<f32>().map_err(candle_to_ferrum)
767}
768
769fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
770    FerrumError::model(format!("candle: {e}"))
771}
772
// Keeps the `CandleResult` import (`use candle_core::{Device, Result as
// CandleResult}`) referenced so it doesn't trigger `unused_imports`.
// Nothing else in this module names `CandleResult`; candle errors are
// converted through `candle_to_ferrum`, which takes `candle_core::Error`.
// Removing this alias together with that import would be the cleaner fix.
#[allow(dead_code)]
type _CandleResult<T> = CandleResult<T>;