Skip to main content

baracuda_kernels/moe/
mod.rs

1//! Mixture-of-Experts (MoE) inference forward — Phase 8 Milestone 8.5
2//! (Category V).
3//!
4//! Three fused per-token-dispatch + expert-matmul + accumulate kernel
5//! variants:
6//!
7//! - [`MoeVariant::ScalarGguf`] — scalar (no tensor cores) MoE GEMM over
8//!   GGUF-quantized expert weights. f32 activations, f32 output, q8_1
9//!   activation staging internally.
10//! - [`MoeVariant::Wmma`] — sm_70+ WMMA tensor cores over dense FP
11//!   (`f16` / `bf16`) expert weights. The FP MoE hot path.
12//! - [`MoeVariant::WmmaGguf`] — combined WMMA + GGUF path. f16/bf16
13//!   activations, GGUF-packed weights, f32 output. The production hot
14//!   path for quantized LLM inference.
15//!
16//! ## Lineage
17//!
18//! Vendored from [attention.rs](https://github.com/guoqingbao/attention.rs)
19//! via `fuel-cuda-kernels`. See
20//! `crates/baracuda-kernels-sys/LICENSE-thirdparty.md` for the full
21//! attribution chain and `kernels/include/baracuda_moe.cuh` for kernel-
22//! level lineage notes.
23//!
24//! ## Phase 20.2 — Fuel-replacement FFI surface (2026-05-25)
25//!
26//! The `baracuda_kernels_moe_*_run` C symbols are the canonical MoE
27//! surface; `fuel-cuda-kernels/src/moe/` retires in favour of direct
28//! calls to those symbols. Callers can bypass [`MoePlan`] entirely
29//! and call the FFI directly — see
30//! `crates/baracuda-kernels/tests/moe_ffi_direct_smoke.rs` for the
31//! reference call pattern. The plan layer (this module) and the FFI
32//! layer both reach the same kernel bodies in `baracuda_moe.cuh`.
33//!
34//! ## Block-format coverage
35//!
//! The GGUF variants support `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`, `Q5_K`,
//! and `Q6_K`. This matches Fuel's `moe_gemm_gguf` / `moe_gemm_gguf_prefill`
//! switch exactly; the `Q4_0` / `Q4_1` / `Q5_0` / `Q5_1` block formats
//! are NOT shipped by upstream for the MoE path (they'd require adding
//! 4 new `vec_dot_q*_q8_1` wirings Fuel itself doesn't carry). `Q8K` is
//! likewise rejected for the MoE path.
//! [`MoePlan::select`] returns [`Error::Unsupported`] for any unsupported
//! block format / variant combination.
43//!
44//! ## Inference-only
45//!
46//! All three variants are inference-only by convention; backward
47//! passes are not shipped. MoE training composes per-expert FFN ops
48//! manually at the autograd surface above.
49
50use core::ffi::c_void;
51
52use baracuda_cutlass::{Error, Result};
53use baracuda_driver::Stream;
54use baracuda_kernels_types::{
55    ArchSku, BackendKind, ElementKind, GgufBlockFormat, KernelSku, MathPrecision, MoeKind,
56    OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace, U8,
57};
58
59use crate::quantize::map_status;
60
/// Selector for the MoE kernel variant a [`MoePlan`] dispatches to.
///
/// Each variant pins the activation dtype and the expert-weight
/// encoding that [`MoePlan::select`] will accept:
///
/// - `ScalarGguf` — `F32` activations, `block_format` must be `Some`.
/// - `Wmma` — `F16` / `Bf16` activations, `block_format` must be `None`.
/// - `WmmaGguf` — `F16` / `Bf16` activations, `block_format` must be `Some`.
///
/// `#[non_exhaustive]` — additional MoE backend variants (FP8 expert
/// weights, BF16+WMMA on Hopper, multi-block routing) may land in
/// future phases. Match expressions written outside this crate must
/// include a `_ =>` catch-all arm.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum MoeVariant {
    /// Scalar (no tensor cores) dispatch over GGUF-packed expert
    /// weights, f32 activations.
    ScalarGguf,
    /// WMMA tensor cores over dense FP expert weights, f16/bf16
    /// activations.
    Wmma,
    /// WMMA tensor cores + GGUF-packed expert weights, f16/bf16
    /// activations. The combined hot path for quantized LLM inference.
    WmmaGguf,
}
77
/// Descriptor for an MoE forward op.
///
/// Plain `Copy` value describing the problem shape and the kernel
/// variant. [`MoePlan::select`] validates it and stores a copy inside
/// the plan; the dims are forwarded to the FFI launchers as `i32`.
#[derive(Copy, Clone, Debug)]
pub struct MoeDescriptor {
    /// Total number of tokens to process. `select` accepts `0` but
    /// rejects negative values. Note that `MoePlan::run` does not read
    /// this field: the flat routed-token count it launches with comes
    /// from `MoeArgs::sorted_token_ids.shape[0]`.
    pub num_tokens: i32,
    /// Number of experts in the MoE block. Must be in `1..=1024`
    /// (`select` applies the upper bound to every variant).
    pub num_experts: i32,
    /// Number of experts each token is routed to (`top_k` in routing).
    /// Must be `> 0`.
    pub top_k: i32,
    /// Hidden dim of the activation / output (`size_k` in Fuel-speak).
    /// Must be `> 0`.
    pub d_model: i32,
    /// Per-expert output feature dim (`size_n` in Fuel-speak).
    /// Must be `> 0`.
    pub d_expert: i32,
    /// Which kernel variant to dispatch.
    pub variant: MoeVariant,
    /// GGUF block format — must be `Some(...)` for `ScalarGguf` /
    /// `WmmaGguf` variants and `None` for the `Wmma` variant.
    pub block_format: Option<GgufBlockFormat>,
    /// Activation element type. `F32` for `ScalarGguf`; `F16` or `Bf16`
    /// for `Wmma` / `WmmaGguf`.
    pub element: ElementKind,
    /// `is_prefill` flag for the `Wmma` variant (selects between
    /// prefill M=16 / N=16 / WARPS_N=2 and decode M=8 / N=32 / WARPS_N=1
    /// tile geometries). Ignored by the other variants: only the `Wmma`
    /// launch path forwards it.
    pub is_prefill: bool,
}
104
/// Args bundle for an MoE forward launch.
///
/// The expert weight matrix is carried as a flat byte tensor
/// (`TensorRef<U8, 1>`) so the same struct shape works for FP weights
/// (`Wmma` variant) and GGUF-packed weights (`ScalarGguf` /
/// `WmmaGguf`).
///
/// NOTE(review): earlier docs said plan-side validation checks the
/// weight byte length against the descriptor. No such check is visible
/// in `MoePlan::run`, so callers should treat every size requirement
/// below as their own responsibility until that is confirmed.
pub struct MoeArgs<'a, T>
where
    T: baracuda_types::DeviceRepr + Copy + 'static,
{
    /// Activations `[num_tokens, d_model]`.
    pub activations: TensorRef<'a, T, 2>,
    /// Top-k expert indices `[num_tokens, top_k]`.
    ///
    /// NOTE(review): `MoePlan::run` does not forward this tensor to any
    /// launcher; routing reaches the kernel through `sorted_token_ids`
    /// and `flat_expert_ids` instead.
    pub expert_indices: TensorRef<'a, i32, 2>,
    /// Top-k expert mixing weights `[num_tokens, top_k]`.
    ///
    /// NOTE(review): `MoePlan::run` does not forward this tensor to any
    /// launcher; only `topk_weight_flat` carries mixing weights to the
    /// kernel.
    pub expert_weights: TensorRef<'a, T, 2>,
    /// Per-token sorted-by-expert flat index list `[num_tokens * top_k]`.
    /// Pre-computed upstream (top-k routing already done). Its
    /// `shape[0]` is the flat token count passed to the launcher.
    pub sorted_token_ids: TensorRef<'a, i32, 1>,
    /// Per-token expert id list aligned with `sorted_token_ids`
    /// `[num_tokens * top_k]`. Already sorted by expert.
    pub flat_expert_ids: TensorRef<'a, i32, 1>,
    /// Optional per-token mixing weight `[num_tokens * top_k]`. When
    /// `None`, the launcher receives a null pointer.
    ///
    /// NOTE(review): earlier docs said the kernel then reads from
    /// `expert_weights`, but that tensor is never passed across the FFI
    /// by `MoePlan::run` — confirm what the kernel does with a null
    /// weight pointer.
    pub topk_weight_flat: Option<TensorRef<'a, f32, 1>>,
    /// Packed expert weight bytes. For `Wmma`, must equal
    /// `num_experts * d_expert * d_model * sizeof(T)`; for GGUF, must
    /// equal `num_experts * d_expert * (d_model / block_size) * type_size`.
    pub expert_matrices: TensorRef<'a, U8, 1>,
    /// Output `[num_tokens, d_expert]`. For `WmmaGguf`, output is `f32`
    /// regardless of activation dtype (kernel writes float directly).
    ///
    /// NOTE(review): the view is typed `T`, so for `WmmaGguf` with
    /// f16/bf16 activations the element type of this view does not
    /// match what the kernel is documented to write — confirm the
    /// buffer is sized for `f32` elements in that case.
    pub output: TensorMut<'a, T, 2>,
    /// Scratch buffer for the WMMA variants — must be at least
    /// `num_experts * sizeof(i32)` bytes. Pass `None` for `ScalarGguf`;
    /// `MoePlan::run` returns `Error::InvalidProblem` if it is `None`
    /// for `Wmma` / `WmmaGguf`.
    pub expert_counts_scratch: Option<TensorMut<'a, i32, 1>>,
    /// Scratch buffer for the WMMA variants — must be at least
    /// `(num_experts + 1) * sizeof(i32)` bytes. Pass `None` for
    /// `ScalarGguf`; `MoePlan::run` returns `Error::InvalidProblem` if
    /// it is `None` for `Wmma` / `WmmaGguf`.
    pub expert_offsets_scratch: Option<TensorMut<'a, i32, 1>>,
}
146
/// MoE forward plan.
///
/// Fused per-token dispatch + expert GEMM + accumulate over up to
/// `top_k` experts. Inference-only.
///
/// **When to use**: forward MoE FFN pass. No BW plan — MoE training
/// composes per-expert FFN ops manually at the autograd surface.
/// Variant is selected at descriptor build time:
///
/// | variant       | acts        | weights         | output |
/// |---------------|-------------|-----------------|--------|
/// | `ScalarGguf`  | `f32`       | GGUF-packed     | `f32`  |
/// | `Wmma`        | `f16`/`bf16`| dense FP        | `T`    |
/// | `WmmaGguf`    | `f16`/`bf16`| GGUF-packed     | `f32`  |
///
/// **Shape limits**: `num_experts ≤ 1024` (a WMMA scan-kernel limit
/// that `select` applies to every variant); `top_k ≥ 1`. For GGUF
/// variants `d_model` must be a multiple of the block size.
///
/// **GGUF coverage**: `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`, `Q5_K`, `Q6_K`.
/// `Q4_0`/`Q4_1`/`Q5_0`/`Q5_1`/`Q8K` are NOT shipped (Fuel upstream
/// doesn't carry the `vec_dot_q*_q8_1` wirings for those).
///
/// **Workspace**: zero in [`Workspace`]. WMMA variants require
/// caller-supplied `expert_counts_scratch` (`num_experts * i32`) and
/// `expert_offsets_scratch` (`(num_experts + 1) * i32`) in
/// [`MoeArgs`] instead.
///
/// **Precision guarantee**: deterministic, bit-stable on identical
/// hardware (no atomics — top-k writes are to distinct token rows;
/// per-token weight scaling is applied in-kernel).
///
/// # Variant / `topk_weight` semantics — **PENDING**
///
/// The reference CPU math for each variant is a known TODO: the
/// `kernels/moe.cu` integration tests currently retain the kernel
/// outputs via `let _ = ...` placeholders rather than asserting
/// against a verified CPU reference. The exact composition rules
/// — when the kernel reads `topk_weight_flat` vs `expert_weights`,
/// the post-mix scaling order, the prefill-vs-decode tile-geometry
/// numerical drift — are NOT yet pinned down by a reference
/// implementation. Callers should treat any specific numerical
/// output as kernel-defined until the reference lands. See
/// `crates/baracuda-kernels/src/moe/mod.rs` and the integration
/// tests under `crates/baracuda-kernels/tests/moe*.rs`.
pub struct MoePlan {
    // Copy of the descriptor captured when the plan was selected.
    desc: MoeDescriptor,
    // Kernel identity and precision metadata derived from `desc`.
    sku: KernelSku,
}
196
197impl MoePlan {
198    /// Pick a kernel for `desc`. Errors on unsupported variant/dtype
199    /// combos, missing block format for GGUF variants, or non-positive
200    /// dims.
201    pub fn select(_stream: &Stream, desc: &MoeDescriptor, _pref: PlanPreference) -> Result<Self> {
202        if desc.num_tokens < 0
203            || desc.num_experts <= 0
204            || desc.top_k <= 0
205            || desc.d_model <= 0
206            || desc.d_expert <= 0
207        {
208            return Err(Error::InvalidProblem(
209                "MoePlan: tokens/experts/top_k/d_model/d_expert must be > 0 (tokens >= 0)",
210            ));
211        }
212        if desc.num_experts > 1024 {
213            return Err(Error::Unsupported(
214                "MoePlan: WMMA scan kernel only supports num_experts <= 1024",
215            ));
216        }
217        match desc.variant {
218            MoeVariant::ScalarGguf => {
219                if desc.element != ElementKind::F32 {
220                    return Err(Error::Unsupported(
221                        "MoePlan: ScalarGguf variant requires f32 activations",
222                    ));
223                }
224                let bf = desc.block_format.ok_or(Error::InvalidProblem(
225                    "MoePlan: ScalarGguf variant requires block_format = Some(...)",
226                ))?;
227                fuel_moe_gguf_dtype(bf)?;
228            }
229            MoeVariant::Wmma => {
230                if desc.element != ElementKind::F16 && desc.element != ElementKind::Bf16 {
231                    return Err(Error::Unsupported(
232                        "MoePlan: Wmma variant requires f16 or bf16 activations",
233                    ));
234                }
235                if desc.block_format.is_some() {
236                    return Err(Error::InvalidProblem(
237                        "MoePlan: Wmma variant must not set block_format",
238                    ));
239                }
240            }
241            MoeVariant::WmmaGguf => {
242                if desc.element != ElementKind::F16 && desc.element != ElementKind::Bf16 {
243                    return Err(Error::Unsupported(
244                        "MoePlan: WmmaGguf variant requires f16 or bf16 activations",
245                    ));
246                }
247                let bf = desc.block_format.ok_or(Error::InvalidProblem(
248                    "MoePlan: WmmaGguf variant requires block_format = Some(...)",
249                ))?;
250                fuel_moe_gguf_dtype(bf)?;
251                let bs = bf.block_size() as i32;
252                if desc.d_model % bs != 0 {
253                    return Err(Error::InvalidProblem(
254                        "MoePlan: d_model must be a multiple of the GGUF block size",
255                    ));
256                }
257            }
258        }
259        Ok(Self {
260            desc: *desc,
261            sku: build_sku(desc),
262        })
263    }
264
265    /// Workspace bytes — none. The WMMA variants need
266    /// `expert_counts_scratch` and `expert_offsets_scratch` but those
267    /// are carried in `MoeArgs`, not the workspace.
268    #[inline]
269    pub fn workspace_size(&self) -> usize {
270        0
271    }
272
273    /// Identity of the selected kernel.
274    #[inline]
275    pub fn sku(&self) -> KernelSku {
276        self.sku
277    }
278
279    /// Numerical guarantees.
280    #[inline]
281    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
282        self.sku.precision_guarantee
283    }
284
285    /// Launch the MoE forward kernel.
286    ///
287    /// `T` must match `desc.element` (compile-time bound enforced by
288    /// the [`TensorRef`] / [`TensorMut`] views in `args`).
289    pub fn run<T>(
290        &self,
291        stream: &Stream,
292        _workspace: Workspace<'_>,
293        args: MoeArgs<'_, T>,
294    ) -> Result<()>
295    where
296        T: baracuda_types::DeviceRepr + Copy + 'static,
297    {
298        let stream_ptr = stream.as_raw() as *mut c_void;
299        let acts_ptr = args.activations.data.as_raw().0 as *const c_void;
300        let weights_ptr = args.expert_matrices.data.as_raw().0 as *const c_void;
301        let sorted_token_ids_ptr = args.sorted_token_ids.data.as_raw().0 as *const i32;
302        let flat_expert_ids_ptr = args.flat_expert_ids.data.as_raw().0 as *const i32;
303        let topk_weights_ptr = args
304            .topk_weight_flat
305            .as_ref()
306            .map(|tw| tw.data.as_raw().0 as *const f32)
307            .unwrap_or(core::ptr::null());
308        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
309
310        let num_tokens_flat = args.sorted_token_ids.shape[0];
311
312        let status = match self.desc.variant {
313            MoeVariant::ScalarGguf => {
314                let bf = self.desc.block_format.expect("checked in select()");
315                let gguf_dtype = fuel_moe_gguf_dtype(bf).expect("checked in select()");
316                unsafe {
317                    baracuda_kernels_sys::baracuda_kernels_moe_scalar_gguf_run(
318                        acts_ptr,
319                        weights_ptr,
320                        sorted_token_ids_ptr,
321                        flat_expert_ids_ptr,
322                        topk_weights_ptr,
323                        out_ptr,
324                        self.desc.num_experts,
325                        self.desc.top_k,
326                        num_tokens_flat,
327                        self.desc.d_expert,
328                        self.desc.d_model,
329                        gguf_dtype,
330                        core::ptr::null_mut(),
331                        0,
332                        stream_ptr,
333                    )
334                }
335            }
336            MoeVariant::Wmma => {
337                let ec = args.expert_counts_scratch.as_ref().ok_or(Error::InvalidProblem(
338                    "MoePlan::run: Wmma variant requires expert_counts_scratch",
339                ))?;
340                let eo = args.expert_offsets_scratch.as_ref().ok_or(Error::InvalidProblem(
341                    "MoePlan::run: Wmma variant requires expert_offsets_scratch",
342                ))?;
343                let ec_ptr = ec.data.as_raw().0 as *mut i32;
344                let eo_ptr = eo.data.as_raw().0 as *mut i32;
345                let is_prefill = if self.desc.is_prefill { 1 } else { 0 };
346                match self.desc.element {
347                    ElementKind::F16 => unsafe {
348                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_f16_run(
349                            acts_ptr,
350                            weights_ptr,
351                            sorted_token_ids_ptr,
352                            flat_expert_ids_ptr,
353                            topk_weights_ptr,
354                            out_ptr,
355                            ec_ptr,
356                            eo_ptr,
357                            self.desc.num_experts,
358                            self.desc.top_k,
359                            num_tokens_flat,
360                            self.desc.d_expert,
361                            self.desc.d_model,
362                            is_prefill,
363                            core::ptr::null_mut(),
364                            0,
365                            stream_ptr,
366                        )
367                    },
368                    ElementKind::Bf16 => unsafe {
369                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_bf16_run(
370                            acts_ptr,
371                            weights_ptr,
372                            sorted_token_ids_ptr,
373                            flat_expert_ids_ptr,
374                            topk_weights_ptr,
375                            out_ptr,
376                            ec_ptr,
377                            eo_ptr,
378                            self.desc.num_experts,
379                            self.desc.top_k,
380                            num_tokens_flat,
381                            self.desc.d_expert,
382                            self.desc.d_model,
383                            is_prefill,
384                            core::ptr::null_mut(),
385                            0,
386                            stream_ptr,
387                        )
388                    },
389                    _ => return Err(Error::Unsupported("MoePlan::run: Wmma element unsupported")),
390                }
391            }
392            MoeVariant::WmmaGguf => {
393                let bf = self.desc.block_format.expect("checked in select()");
394                let gguf_dtype = fuel_moe_gguf_dtype(bf).expect("checked in select()");
395                let ec = args.expert_counts_scratch.as_ref().ok_or(Error::InvalidProblem(
396                    "MoePlan::run: WmmaGguf variant requires expert_counts_scratch",
397                ))?;
398                let eo = args.expert_offsets_scratch.as_ref().ok_or(Error::InvalidProblem(
399                    "MoePlan::run: WmmaGguf variant requires expert_offsets_scratch",
400                ))?;
401                let ec_ptr = ec.data.as_raw().0 as *mut i32;
402                let eo_ptr = eo.data.as_raw().0 as *mut i32;
403                match self.desc.element {
404                    ElementKind::F16 => unsafe {
405                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_f16_run(
406                            acts_ptr,
407                            weights_ptr,
408                            sorted_token_ids_ptr,
409                            flat_expert_ids_ptr,
410                            topk_weights_ptr,
411                            out_ptr,
412                            ec_ptr,
413                            eo_ptr,
414                            self.desc.num_experts,
415                            self.desc.top_k,
416                            num_tokens_flat,
417                            self.desc.d_expert,
418                            self.desc.d_model,
419                            gguf_dtype,
420                            core::ptr::null_mut(),
421                            0,
422                            stream_ptr,
423                        )
424                    },
425                    ElementKind::Bf16 => unsafe {
426                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_bf16_run(
427                            acts_ptr,
428                            weights_ptr,
429                            sorted_token_ids_ptr,
430                            flat_expert_ids_ptr,
431                            topk_weights_ptr,
432                            out_ptr,
433                            ec_ptr,
434                            eo_ptr,
435                            self.desc.num_experts,
436                            self.desc.top_k,
437                            num_tokens_flat,
438                            self.desc.d_expert,
439                            self.desc.d_model,
440                            gguf_dtype,
441                            core::ptr::null_mut(),
442                            0,
443                            stream_ptr,
444                        )
445                    },
446                    _ => return Err(Error::Unsupported("MoePlan::run: WmmaGguf element unsupported")),
447                }
448            }
449        };
450        map_status(status)
451    }
452}
453
454/// Translate baracuda's `GgufBlockFormat` into the Fuel-convention
455/// `gguf_dtype` discriminant expected by the `moe_*_gguf_run` FFI
456/// (matches the switch in Fuel's `moe_gemm_gguf`):
457///   `0 = Q8_0`, `1 = Q4_K`, `2 = Q2_K`, `3 = Q3_K`, `4 = Q5_K`, `5 = Q6_K`.
458fn fuel_moe_gguf_dtype(bf: GgufBlockFormat) -> Result<i32> {
459    match bf {
460        GgufBlockFormat::Q8_0 => Ok(0),
461        GgufBlockFormat::Q4K => Ok(1),
462        GgufBlockFormat::Q2K => Ok(2),
463        GgufBlockFormat::Q3K => Ok(3),
464        GgufBlockFormat::Q5K => Ok(4),
465        GgufBlockFormat::Q6K => Ok(5),
466        GgufBlockFormat::Q4_0
467        | GgufBlockFormat::Q4_1
468        | GgufBlockFormat::Q5_0
469        | GgufBlockFormat::Q5_1
470        | GgufBlockFormat::Q8K => Err(Error::Unsupported(
471            "MoePlan: GGUF MoE variants only support Q8_0 + k-quants (Q2_K..Q6_K)",
472        )),
473        // Defensive arm — `GgufBlockFormat` is `#[non_exhaustive]`.
474        _ => Err(Error::Unsupported(
475            "MoePlan: unsupported GGUF block format",
476        )),
477    }
478}
479
480fn build_sku(desc: &MoeDescriptor) -> KernelSku {
481    let op = match desc.variant {
482        MoeVariant::ScalarGguf => MoeKind::ScalarGguf as u16,
483        MoeVariant::Wmma => MoeKind::Wmma as u16,
484        MoeVariant::WmmaGguf => MoeKind::WmmaGguf as u16,
485    };
486    KernelSku {
487        category: OpCategory::Moe,
488        op,
489        element: desc.element,
490        aux_element: Some(ElementKind::U8),
491        layout: None,
492        epilogue: None,
493        arch: ArchSku::Sm80, // sm_70+; sm_80 is the baseline arch baracuda exposes.
494        backend: BackendKind::Bespoke,
495        precision_guarantee: PrecisionGuarantee {
496            math_precision: MathPrecision::F32,
497            accumulator: ElementKind::F32,
498            // Atomic-free (top-k > 1 writes are to distinct token rows
499            // when `topk_weights == None`; otherwise the per-token-weight
500            // scaling is applied in the kernel and the output is written
501            // directly). Deterministic on identical hardware.
502            bit_stable_on_same_hardware: true,
503            deterministic: true,
504        },
505    }
506}