baracuda_kernels/moe/mod.rs
1//! Mixture-of-Experts (MoE) inference forward — Phase 8 Milestone 8.5
2//! (Category V).
3//!
4//! Three fused per-token-dispatch + expert-matmul + accumulate kernel
5//! variants:
6//!
7//! - [`MoeVariant::ScalarGguf`] — scalar (no tensor cores) MoE GEMM over
8//! GGUF-quantized expert weights. f32 activations, f32 output, q8_1
9//! activation staging internally.
10//! - [`MoeVariant::Wmma`] — sm_70+ WMMA tensor cores over dense FP
11//! (`f16` / `bf16`) expert weights. The FP MoE hot path.
12//! - [`MoeVariant::WmmaGguf`] — combined WMMA + GGUF path. f16/bf16
13//! activations, GGUF-packed weights, f32 output. The production hot
14//! path for quantized LLM inference.
15//!
16//! ## Lineage
17//!
18//! Vendored from [attention.rs](https://github.com/guoqingbao/attention.rs)
19//! via `fuel-cuda-kernels`. See
20//! `crates/baracuda-kernels-sys/LICENSE-thirdparty.md` for the full
21//! attribution chain and `kernels/include/baracuda_moe.cuh` for kernel-
22//! level lineage notes.
23//!
24//! ## Phase 20.2 — Fuel-replacement FFI surface (2026-05-25)
25//!
26//! The `baracuda_kernels_moe_*_run` C symbols are the canonical MoE
27//! surface; `fuel-cuda-kernels/src/moe/` retires in favour of direct
28//! calls to those symbols. Callers can bypass [`MoePlan`] entirely
29//! and call the FFI directly — see
30//! `crates/baracuda-kernels/tests/moe_ffi_direct_smoke.rs` for the
31//! reference call pattern. The plan layer (this module) and the FFI
32//! layer both reach the same kernel bodies in `baracuda_moe.cuh`.
33//!
34//! ## Block-format coverage
35//!
//! The GGUF variants support `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`, `Q5_K`,
//! and `Q6_K`. This matches the switch in Fuel's `moe_gemm_gguf` /
//! `moe_gemm_gguf_prefill` exactly. The `Q4_0` / `Q4_1` / `Q5_0` / `Q5_1`
//! / `Q8_K` block formats are NOT shipped by upstream for the MoE path
//! (the first four would require new `vec_dot_q*_q8_1` wirings that Fuel
//! itself doesn't carry). [`MoePlan::select`] returns
//! [`Error::Unsupported`] for any unsupported block format / variant
//! combination.
43//!
44//! ## Inference-only
45//!
46//! All three variants are inference-only by convention; backward
47//! passes are not shipped. MoE training composes per-expert FFN ops
48//! manually at the autograd surface above.
49
50use core::ffi::c_void;
51
52use baracuda_cutlass::{Error, Result};
53use baracuda_driver::Stream;
54use baracuda_kernels_types::{
55 ArchSku, BackendKind, ElementKind, GgufBlockFormat, KernelSku, MathPrecision, MoeKind,
56 OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace, U8,
57};
58
59use crate::quantize::map_status;
60
/// Which MoE kernel family a [`MoeDescriptor`] dispatches to.
///
/// The enum is `#[non_exhaustive]` so that further backends (FP8 expert
/// weights, BF16+WMMA on Hopper, multi-block routing) can be added in
/// later phases without a breaking change. Code outside this crate that
/// matches on it therefore needs a `_ =>` arm.
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MoeVariant {
    /// Scalar (no tensor core) kernel over GGUF-packed expert weights;
    /// activations are `f32`.
    ScalarGguf,
    /// WMMA tensor-core kernel over dense FP expert weights, with
    /// `f16`/`bf16` activations.
    Wmma,
    /// WMMA tensor-core kernel over GGUF-packed expert weights, with
    /// `f16`/`bf16` activations. This is the combined hot path for
    /// quantized LLM inference.
    WmmaGguf,
}
76}
77
78/// Descriptor for an MoE forward op.
79#[derive(Copy, Clone, Debug)]
80pub struct MoeDescriptor {
81 /// Total number of tokens to process.
82 pub num_tokens: i32,
83 /// Number of experts in the MoE block.
84 pub num_experts: i32,
85 /// Number of experts each token is routed to (`top_k` in routing).
86 pub top_k: i32,
87 /// Hidden dim of the activation / output (`size_k` in Fuel-speak).
88 pub d_model: i32,
89 /// Per-expert output feature dim (`size_n` in Fuel-speak).
90 pub d_expert: i32,
91 /// Which kernel variant to dispatch.
92 pub variant: MoeVariant,
93 /// GGUF block format — must be `Some(...)` for `ScalarGguf` /
94 /// `WmmaGguf` variants and `None` for the `Wmma` variant.
95 pub block_format: Option<GgufBlockFormat>,
96 /// Activation element type. `F32` for `ScalarGguf`; `F16` or `Bf16`
97 /// for `Wmma` / `WmmaGguf`.
98 pub element: ElementKind,
99 /// `is_prefill` flag for the `Wmma` variant (selects between
100 /// prefill M=16 / N=16 / WARPS_N=2 and decode M=8 / N=32 / WARPS_N=1
101 /// tile geometries). Ignored by the other variants.
102 pub is_prefill: bool,
103}
104
105/// Args bundle for an MoE forward launch.
106///
107/// The expert weight matrix is carried as a raw byte buffer (`&[u8]`)
108/// so the same struct shape works for FP weights (`Wmma` variant) and
109/// GGUF-packed weights (`ScalarGguf` / `WmmaGguf`). Plan-side
110/// validation checks the byte length against the descriptor.
111pub struct MoeArgs<'a, T>
112where
113 T: baracuda_types::DeviceRepr + Copy + 'static,
114{
115 /// Activations `[num_tokens, d_model]`.
116 pub activations: TensorRef<'a, T, 2>,
117 /// Top-k expert indices `[num_tokens, top_k]`.
118 pub expert_indices: TensorRef<'a, i32, 2>,
119 /// Top-k expert mixing weights `[num_tokens, top_k]`.
120 pub expert_weights: TensorRef<'a, T, 2>,
121 /// Per-token sorted-by-expert flat index list `[num_tokens * top_k]`.
122 /// Pre-computed upstream (top-k routing already done).
123 pub sorted_token_ids: TensorRef<'a, i32, 1>,
124 /// Per-token expert id list aligned with `sorted_token_ids`
125 /// `[num_tokens * top_k]`. Already sorted by expert.
126 pub flat_expert_ids: TensorRef<'a, i32, 1>,
127 /// Optional per-token mixing weight `[num_tokens * top_k]`. When
128 /// `None`, the launcher passes `nullptr` and the kernel reads from
129 /// `expert_weights` via the routing path.
130 pub topk_weight_flat: Option<TensorRef<'a, f32, 1>>,
131 /// Packed expert weight bytes. For `Wmma`, must equal
132 /// `num_experts * d_expert * d_model * sizeof(T)`; for GGUF, must
133 /// equal `num_experts * d_expert * (d_model / block_size) * type_size`.
134 pub expert_matrices: TensorRef<'a, U8, 1>,
135 /// Output `[num_tokens, d_expert]`. For `WmmaGguf`, output is `f32`
136 /// regardless of activation dtype (kernel writes float directly).
137 pub output: TensorMut<'a, T, 2>,
138 /// Scratch buffer for the WMMA variants — must be at least
139 /// `num_experts * sizeof(i32)` bytes. Pass `None` for `ScalarGguf`.
140 pub expert_counts_scratch: Option<TensorMut<'a, i32, 1>>,
141 /// Scratch buffer for the WMMA variants — must be at least
142 /// `(num_experts + 1) * sizeof(i32)` bytes. Pass `None` for
143 /// `ScalarGguf`.
144 pub expert_offsets_scratch: Option<TensorMut<'a, i32, 1>>,
145}
146
147/// MoE forward plan.
148///
149/// Fused per-token dispatch + expert GEMM + accumulate over up to
150/// `top_k` experts. Inference-only.
151///
152/// **When to use**: forward MoE FFN pass. No BW plan — MoE training
153/// composes per-expert FFN ops manually at the autograd surface.
154/// Variant is selected at descriptor build time:
155///
156/// | variant | acts | weights | output |
157/// |---------------|-------------|-----------------|--------|
158/// | `ScalarGguf` | `f32` | GGUF-packed | `f32` |
159/// | `Wmma` | `f16`/`bf16`| dense FP | `T` |
160/// | `WmmaGguf` | `f16`/`bf16`| GGUF-packed | `f32` |
161///
162/// **Shape limits**: `num_experts ≤ 1024` (WMMA scan kernel);
163/// `top_k ≥ 1`. For GGUF variants `d_model` must be a multiple of
164/// the block size.
165///
166/// **GGUF coverage**: `Q8_0`, `Q2_K`, `Q3_K`, `Q4_K`, `Q5_K`, `Q6_K`.
167/// `Q4_0`/`Q4_1`/`Q5_0`/`Q5_1`/`Q8K` are NOT shipped (Fuel upstream
168/// doesn't carry the `vec_dot_q*_q8_1` wirings for those).
169///
170/// **Workspace**: zero in [`Workspace`]. WMMA variants require
171/// caller-supplied `expert_counts_scratch` (`num_experts * i32`) and
172/// `expert_offsets_scratch` (`(num_experts + 1) * i32`) in
173/// [`MoeArgs`] instead.
174///
175/// **Precision guarantee**: deterministic, bit-stable on identical
176/// hardware (no atomics — top-k writes are to distinct token rows;
177/// per-token weight scaling is applied in-kernel).
178///
179/// # Variant / `topk_weight` semantics — **PENDING**
180///
181/// The reference CPU math for each variant is a known TODO: the
182/// `kernels/moe.cu` integration tests currently retain the kernel
183/// outputs via `let _ = ...` placeholders rather than asserting
184/// against a verified CPU reference. The exact composition rules
185/// — when the kernel reads `topk_weight_flat` vs `expert_weights`,
186/// the post-mix scaling order, the prefill-vs-decode tile-geometry
187/// numerical drift — are NOT yet pinned down by a reference
188/// implementation. Callers should treat any specific numerical
189/// output as kernel-defined until the reference lands. See
190/// `crates/baracuda-kernels/src/moe/mod.rs` and the integration
191/// tests under `crates/baracuda-kernels/tests/moe*.rs`.
192pub struct MoePlan {
193 desc: MoeDescriptor,
194 sku: KernelSku,
195}
196
197impl MoePlan {
198 /// Pick a kernel for `desc`. Errors on unsupported variant/dtype
199 /// combos, missing block format for GGUF variants, or non-positive
200 /// dims.
201 pub fn select(_stream: &Stream, desc: &MoeDescriptor, _pref: PlanPreference) -> Result<Self> {
202 if desc.num_tokens < 0
203 || desc.num_experts <= 0
204 || desc.top_k <= 0
205 || desc.d_model <= 0
206 || desc.d_expert <= 0
207 {
208 return Err(Error::InvalidProblem(
209 "MoePlan: tokens/experts/top_k/d_model/d_expert must be > 0 (tokens >= 0)",
210 ));
211 }
212 if desc.num_experts > 1024 {
213 return Err(Error::Unsupported(
214 "MoePlan: WMMA scan kernel only supports num_experts <= 1024",
215 ));
216 }
217 match desc.variant {
218 MoeVariant::ScalarGguf => {
219 if desc.element != ElementKind::F32 {
220 return Err(Error::Unsupported(
221 "MoePlan: ScalarGguf variant requires f32 activations",
222 ));
223 }
224 let bf = desc.block_format.ok_or(Error::InvalidProblem(
225 "MoePlan: ScalarGguf variant requires block_format = Some(...)",
226 ))?;
227 fuel_moe_gguf_dtype(bf)?;
228 }
229 MoeVariant::Wmma => {
230 if desc.element != ElementKind::F16 && desc.element != ElementKind::Bf16 {
231 return Err(Error::Unsupported(
232 "MoePlan: Wmma variant requires f16 or bf16 activations",
233 ));
234 }
235 if desc.block_format.is_some() {
236 return Err(Error::InvalidProblem(
237 "MoePlan: Wmma variant must not set block_format",
238 ));
239 }
240 }
241 MoeVariant::WmmaGguf => {
242 if desc.element != ElementKind::F16 && desc.element != ElementKind::Bf16 {
243 return Err(Error::Unsupported(
244 "MoePlan: WmmaGguf variant requires f16 or bf16 activations",
245 ));
246 }
247 let bf = desc.block_format.ok_or(Error::InvalidProblem(
248 "MoePlan: WmmaGguf variant requires block_format = Some(...)",
249 ))?;
250 fuel_moe_gguf_dtype(bf)?;
251 let bs = bf.block_size() as i32;
252 if desc.d_model % bs != 0 {
253 return Err(Error::InvalidProblem(
254 "MoePlan: d_model must be a multiple of the GGUF block size",
255 ));
256 }
257 }
258 }
259 Ok(Self {
260 desc: *desc,
261 sku: build_sku(desc),
262 })
263 }
264
265 /// Workspace bytes — none. The WMMA variants need
266 /// `expert_counts_scratch` and `expert_offsets_scratch` but those
267 /// are carried in `MoeArgs`, not the workspace.
268 #[inline]
269 pub fn workspace_size(&self) -> usize {
270 0
271 }
272
273 /// Identity of the selected kernel.
274 #[inline]
275 pub fn sku(&self) -> KernelSku {
276 self.sku
277 }
278
279 /// Numerical guarantees.
280 #[inline]
281 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
282 self.sku.precision_guarantee
283 }
284
285 /// Launch the MoE forward kernel.
286 ///
287 /// `T` must match `desc.element` (compile-time bound enforced by
288 /// the [`TensorRef`] / [`TensorMut`] views in `args`).
289 pub fn run<T>(
290 &self,
291 stream: &Stream,
292 _workspace: Workspace<'_>,
293 args: MoeArgs<'_, T>,
294 ) -> Result<()>
295 where
296 T: baracuda_types::DeviceRepr + Copy + 'static,
297 {
298 let stream_ptr = stream.as_raw() as *mut c_void;
299 let acts_ptr = args.activations.data.as_raw().0 as *const c_void;
300 let weights_ptr = args.expert_matrices.data.as_raw().0 as *const c_void;
301 let sorted_token_ids_ptr = args.sorted_token_ids.data.as_raw().0 as *const i32;
302 let flat_expert_ids_ptr = args.flat_expert_ids.data.as_raw().0 as *const i32;
303 let topk_weights_ptr = args
304 .topk_weight_flat
305 .as_ref()
306 .map(|tw| tw.data.as_raw().0 as *const f32)
307 .unwrap_or(core::ptr::null());
308 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
309
310 let num_tokens_flat = args.sorted_token_ids.shape[0];
311
312 let status = match self.desc.variant {
313 MoeVariant::ScalarGguf => {
314 let bf = self.desc.block_format.expect("checked in select()");
315 let gguf_dtype = fuel_moe_gguf_dtype(bf).expect("checked in select()");
316 unsafe {
317 baracuda_kernels_sys::baracuda_kernels_moe_scalar_gguf_run(
318 acts_ptr,
319 weights_ptr,
320 sorted_token_ids_ptr,
321 flat_expert_ids_ptr,
322 topk_weights_ptr,
323 out_ptr,
324 self.desc.num_experts,
325 self.desc.top_k,
326 num_tokens_flat,
327 self.desc.d_expert,
328 self.desc.d_model,
329 gguf_dtype,
330 core::ptr::null_mut(),
331 0,
332 stream_ptr,
333 )
334 }
335 }
336 MoeVariant::Wmma => {
337 let ec = args.expert_counts_scratch.as_ref().ok_or(Error::InvalidProblem(
338 "MoePlan::run: Wmma variant requires expert_counts_scratch",
339 ))?;
340 let eo = args.expert_offsets_scratch.as_ref().ok_or(Error::InvalidProblem(
341 "MoePlan::run: Wmma variant requires expert_offsets_scratch",
342 ))?;
343 let ec_ptr = ec.data.as_raw().0 as *mut i32;
344 let eo_ptr = eo.data.as_raw().0 as *mut i32;
345 let is_prefill = if self.desc.is_prefill { 1 } else { 0 };
346 match self.desc.element {
347 ElementKind::F16 => unsafe {
348 baracuda_kernels_sys::baracuda_kernels_moe_wmma_f16_run(
349 acts_ptr,
350 weights_ptr,
351 sorted_token_ids_ptr,
352 flat_expert_ids_ptr,
353 topk_weights_ptr,
354 out_ptr,
355 ec_ptr,
356 eo_ptr,
357 self.desc.num_experts,
358 self.desc.top_k,
359 num_tokens_flat,
360 self.desc.d_expert,
361 self.desc.d_model,
362 is_prefill,
363 core::ptr::null_mut(),
364 0,
365 stream_ptr,
366 )
367 },
368 ElementKind::Bf16 => unsafe {
369 baracuda_kernels_sys::baracuda_kernels_moe_wmma_bf16_run(
370 acts_ptr,
371 weights_ptr,
372 sorted_token_ids_ptr,
373 flat_expert_ids_ptr,
374 topk_weights_ptr,
375 out_ptr,
376 ec_ptr,
377 eo_ptr,
378 self.desc.num_experts,
379 self.desc.top_k,
380 num_tokens_flat,
381 self.desc.d_expert,
382 self.desc.d_model,
383 is_prefill,
384 core::ptr::null_mut(),
385 0,
386 stream_ptr,
387 )
388 },
389 _ => return Err(Error::Unsupported("MoePlan::run: Wmma element unsupported")),
390 }
391 }
392 MoeVariant::WmmaGguf => {
393 let bf = self.desc.block_format.expect("checked in select()");
394 let gguf_dtype = fuel_moe_gguf_dtype(bf).expect("checked in select()");
395 let ec = args.expert_counts_scratch.as_ref().ok_or(Error::InvalidProblem(
396 "MoePlan::run: WmmaGguf variant requires expert_counts_scratch",
397 ))?;
398 let eo = args.expert_offsets_scratch.as_ref().ok_or(Error::InvalidProblem(
399 "MoePlan::run: WmmaGguf variant requires expert_offsets_scratch",
400 ))?;
401 let ec_ptr = ec.data.as_raw().0 as *mut i32;
402 let eo_ptr = eo.data.as_raw().0 as *mut i32;
403 match self.desc.element {
404 ElementKind::F16 => unsafe {
405 baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_f16_run(
406 acts_ptr,
407 weights_ptr,
408 sorted_token_ids_ptr,
409 flat_expert_ids_ptr,
410 topk_weights_ptr,
411 out_ptr,
412 ec_ptr,
413 eo_ptr,
414 self.desc.num_experts,
415 self.desc.top_k,
416 num_tokens_flat,
417 self.desc.d_expert,
418 self.desc.d_model,
419 gguf_dtype,
420 core::ptr::null_mut(),
421 0,
422 stream_ptr,
423 )
424 },
425 ElementKind::Bf16 => unsafe {
426 baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_bf16_run(
427 acts_ptr,
428 weights_ptr,
429 sorted_token_ids_ptr,
430 flat_expert_ids_ptr,
431 topk_weights_ptr,
432 out_ptr,
433 ec_ptr,
434 eo_ptr,
435 self.desc.num_experts,
436 self.desc.top_k,
437 num_tokens_flat,
438 self.desc.d_expert,
439 self.desc.d_model,
440 gguf_dtype,
441 core::ptr::null_mut(),
442 0,
443 stream_ptr,
444 )
445 },
446 _ => return Err(Error::Unsupported("MoePlan::run: WmmaGguf element unsupported")),
447 }
448 }
449 };
450 map_status(status)
451 }
452}
453
454/// Translate baracuda's `GgufBlockFormat` into the Fuel-convention
455/// `gguf_dtype` discriminant expected by the `moe_*_gguf_run` FFI
456/// (matches the switch in Fuel's `moe_gemm_gguf`):
457/// `0 = Q8_0`, `1 = Q4_K`, `2 = Q2_K`, `3 = Q3_K`, `4 = Q5_K`, `5 = Q6_K`.
458fn fuel_moe_gguf_dtype(bf: GgufBlockFormat) -> Result<i32> {
459 match bf {
460 GgufBlockFormat::Q8_0 => Ok(0),
461 GgufBlockFormat::Q4K => Ok(1),
462 GgufBlockFormat::Q2K => Ok(2),
463 GgufBlockFormat::Q3K => Ok(3),
464 GgufBlockFormat::Q5K => Ok(4),
465 GgufBlockFormat::Q6K => Ok(5),
466 GgufBlockFormat::Q4_0
467 | GgufBlockFormat::Q4_1
468 | GgufBlockFormat::Q5_0
469 | GgufBlockFormat::Q5_1
470 | GgufBlockFormat::Q8K => Err(Error::Unsupported(
471 "MoePlan: GGUF MoE variants only support Q8_0 + k-quants (Q2_K..Q6_K)",
472 )),
473 // Defensive arm — `GgufBlockFormat` is `#[non_exhaustive]`.
474 _ => Err(Error::Unsupported(
475 "MoePlan: unsupported GGUF block format",
476 )),
477 }
478}
479
480fn build_sku(desc: &MoeDescriptor) -> KernelSku {
481 let op = match desc.variant {
482 MoeVariant::ScalarGguf => MoeKind::ScalarGguf as u16,
483 MoeVariant::Wmma => MoeKind::Wmma as u16,
484 MoeVariant::WmmaGguf => MoeKind::WmmaGguf as u16,
485 };
486 KernelSku {
487 category: OpCategory::Moe,
488 op,
489 element: desc.element,
490 aux_element: Some(ElementKind::U8),
491 layout: None,
492 epilogue: None,
493 arch: ArchSku::Sm80, // sm_70+; sm_80 is the baseline arch baracuda exposes.
494 backend: BackendKind::Bespoke,
495 precision_guarantee: PrecisionGuarantee {
496 math_precision: MathPrecision::F32,
497 accumulator: ElementKind::F32,
498 // Atomic-free (top-k > 1 writes are to distinct token rows
499 // when `topk_weights == None`; otherwise the per-token-weight
500 // scaling is applied in the kernel and the output is written
501 // directly). Deterministic on identical hardware.
502 bit_stable_on_same_hardware: true,
503 deterministic: true,
504 },
505 }
506}