use core::ffi::c_void;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, GgufBlockFormat, KernelSku, MathPrecision, MoeKind,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace, U8,
};
use crate::quantize::map_status;
/// Selects which Mixture-of-Experts kernel family a [`MoePlan`] dispatches to.
///
/// The per-variant requirements listed below are the ones enforced by
/// [`MoePlan::select`] and [`MoePlan::run`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum MoeVariant {
    /// Scalar kernel over GGUF-quantised expert matrices.
    ///
    /// Requires `f32` activations and a supported `block_format`.
    ScalarGguf,
    /// WMMA kernel.
    ///
    /// Requires `f16` or `bf16` activations, no `block_format`, and both
    /// scratch buffers at run time.
    Wmma,
    /// WMMA kernel over GGUF-quantised expert matrices.
    ///
    /// Requires `f16` or `bf16` activations, a supported `block_format`,
    /// `d_model` divisible by that format's block size, and both scratch
    /// buffers at run time.
    WmmaGguf,
}
/// Problem description consumed by [`MoePlan::select`].
///
/// The constraints noted on each field are the ones `select` checks; nothing
/// is validated when the struct itself is constructed.
#[derive(Clone, Copy, Debug)]
pub struct MoeDescriptor {
    /// Token count; must be `>= 0`.
    pub num_tokens: i32,
    /// Number of experts; must be in `1..=1024`.
    pub num_experts: i32,
    /// Experts chosen per token; must be `> 0`.
    pub top_k: i32,
    /// Model dimension; must be `> 0`. For [`MoeVariant::WmmaGguf`] it must
    /// also be a multiple of the GGUF block size.
    pub d_model: i32,
    /// Expert dimension; must be `> 0`.
    pub d_expert: i32,
    /// Kernel family to dispatch to.
    pub variant: MoeVariant,
    /// GGUF block format of the expert matrices. Must be `Some` for the GGUF
    /// variants and `None` for [`MoeVariant::Wmma`].
    pub block_format: Option<GgufBlockFormat>,
    /// Activation element kind: `F32` for [`MoeVariant::ScalarGguf`], `F16`
    /// or `Bf16` for the WMMA variants.
    pub element: ElementKind,
    /// Forwarded to the [`MoeVariant::Wmma`] kernels as a `1`/`0` flag; the
    /// other variants do not pass it on.
    pub is_prefill: bool,
}
/// Device tensors handed to [`MoePlan::run`].
///
/// `run` extracts raw device pointers from these handles and forwards them to
/// the selected kernel; it performs no shape validation of its own.
pub struct MoeArgs<'a, T: baracuda_types::DeviceRepr + Copy + 'static> {
    /// Input activations; forwarded as the kernel's first pointer argument.
    pub activations: TensorRef<'a, T, 2>,
    /// Not read by [`MoePlan::run`] in this file.
    pub expert_indices: TensorRef<'a, i32, 2>,
    /// Not read by [`MoePlan::run`] in this file.
    pub expert_weights: TensorRef<'a, T, 2>,
    /// Forwarded to the kernel; its `shape[0]` is also passed as the flat
    /// token count.
    pub sorted_token_ids: TensorRef<'a, i32, 1>,
    /// Forwarded to the kernel as-is.
    pub flat_expert_ids: TensorRef<'a, i32, 1>,
    /// Optional `f32` weights; a null pointer is passed when `None`.
    pub topk_weight_flat: Option<TensorRef<'a, f32, 1>>,
    /// Expert matrices as raw bytes; forwarded as the kernel's weights pointer.
    pub expert_matrices: TensorRef<'a, U8, 1>,
    /// Destination tensor; forwarded as the kernel's output pointer.
    pub output: TensorMut<'a, T, 2>,
    /// Scratch buffer required by the `Wmma` and `WmmaGguf` variants; ignored
    /// by `ScalarGguf`.
    pub expert_counts_scratch: Option<TensorMut<'a, i32, 1>>,
    /// Scratch buffer required by the `Wmma` and `WmmaGguf` variants; ignored
    /// by `ScalarGguf`.
    pub expert_offsets_scratch: Option<TensorMut<'a, i32, 1>>,
}
/// A validated Mixture-of-Experts launch plan.
///
/// Built by `MoePlan::select`, which checks the descriptor before storing it,
/// and executed by `MoePlan::run`.
pub struct MoePlan {
// Copy of the descriptor that passed validation in `select`.
desc: MoeDescriptor,
// Kernel SKU derived from the descriptor by `build_sku`.
sku: KernelSku,
}
impl MoePlan {
    /// Validates `desc` and builds a plan for it.
    ///
    /// `_stream` and `_pref` are accepted but not consulted.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` when a dimension is out of range, when a
    ///   GGUF variant has no `block_format`, when `Wmma` has one, or when
    ///   `WmmaGguf`'s `d_model` is not a multiple of the GGUF block size.
    /// * `Error::Unsupported` when `num_experts > 1024`, when the element
    ///   kind does not fit the variant, or when the GGUF block format has no
    ///   kernel dtype code.
    pub fn select(_stream: &Stream, desc: &MoeDescriptor, _pref: PlanPreference) -> Result<Self> {
        let dims_valid = desc.num_tokens >= 0
            && desc.num_experts > 0
            && desc.top_k > 0
            && desc.d_model > 0
            && desc.d_expert > 0;
        if !dims_valid {
            return Err(Error::InvalidProblem(
                "MoePlan: tokens/experts/top_k/d_model/d_expert must be > 0 (tokens >= 0)",
            ));
        }
        // NOTE(review): this limit is applied to every variant even though the
        // message names the WMMA scan kernel — confirm that is intended.
        if desc.num_experts > 1024 {
            return Err(Error::Unsupported(
                "MoePlan: WMMA scan kernel only supports num_experts <= 1024",
            ));
        }

        let half_precision = matches!(desc.element, ElementKind::F16 | ElementKind::Bf16);
        match desc.variant {
            MoeVariant::ScalarGguf => {
                if desc.element != ElementKind::F32 {
                    return Err(Error::Unsupported(
                        "MoePlan: ScalarGguf variant requires f32 activations",
                    ));
                }
                let format = match desc.block_format {
                    Some(format) => format,
                    None => {
                        return Err(Error::InvalidProblem(
                            "MoePlan: ScalarGguf variant requires block_format = Some(...)",
                        ))
                    }
                };
                // Only the success/failure matters here; `run` recomputes the code.
                fuel_moe_gguf_dtype(format)?;
            }
            MoeVariant::Wmma => {
                if !half_precision {
                    return Err(Error::Unsupported(
                        "MoePlan: Wmma variant requires f16 or bf16 activations",
                    ));
                }
                if desc.block_format.is_some() {
                    return Err(Error::InvalidProblem(
                        "MoePlan: Wmma variant must not set block_format",
                    ));
                }
            }
            MoeVariant::WmmaGguf => {
                if !half_precision {
                    return Err(Error::Unsupported(
                        "MoePlan: WmmaGguf variant requires f16 or bf16 activations",
                    ));
                }
                let format = match desc.block_format {
                    Some(format) => format,
                    None => {
                        return Err(Error::InvalidProblem(
                            "MoePlan: WmmaGguf variant requires block_format = Some(...)",
                        ))
                    }
                };
                fuel_moe_gguf_dtype(format)?;
                let block_len = format.block_size() as i32;
                if desc.d_model % block_len != 0 {
                    return Err(Error::InvalidProblem(
                        "MoePlan: d_model must be a multiple of the GGUF block size",
                    ));
                }
            }
        }

        Ok(Self {
            sku: build_sku(desc),
            desc: *desc,
        })
    }

    /// Workspace bytes required by [`MoePlan::run`]; always zero.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0usize
    }

    /// Kernel SKU computed when the plan was selected.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku().precision_guarantee
    }

    /// Launches the kernel chosen by the plan's variant on `stream`.
    ///
    /// `_workspace` is unused: the kernels are always given a null workspace
    /// pointer and a size of `0`. `args.expert_indices` and
    /// `args.expert_weights` are not read.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` when a WMMA variant is run without
    ///   `expert_counts_scratch` or `expert_offsets_scratch`.
    /// * `Error::Unsupported` when a WMMA variant's element kind is neither
    ///   `F16` nor `Bf16`.
    /// * Whatever `map_status` produces for the kernel's status code.
    ///
    /// # Panics
    ///
    /// Panics if a GGUF variant's descriptor lacks a supported block format,
    /// which [`MoePlan::select`] rules out.
    pub fn run<T>(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: MoeArgs<'_, T>,
    ) -> Result<()>
    where
        T: baracuda_types::DeviceRepr + Copy + 'static,
    {
        let desc = &self.desc;

        // Raw handles shared by every kernel variant.
        let raw_stream = stream.as_raw() as *mut c_void;
        let input = args.activations.data.as_raw().0 as *const c_void;
        let experts = args.expert_matrices.data.as_raw().0 as *const c_void;
        let token_ids = args.sorted_token_ids.data.as_raw().0 as *const i32;
        let expert_ids = args.flat_expert_ids.data.as_raw().0 as *const i32;
        // The top-k weights are optional; the kernels receive null when absent.
        let gate_weights = match args.topk_weight_flat.as_ref() {
            Some(tw) => tw.data.as_raw().0 as *const f32,
            None => core::ptr::null(),
        };
        let output = args.output.data.as_raw().0 as *mut c_void;
        let flat_tokens = args.sorted_token_ids.shape[0];

        let rc = match desc.variant {
            MoeVariant::ScalarGguf => {
                let format = desc.block_format.expect("checked in select()");
                let dtype_code = fuel_moe_gguf_dtype(format).expect("checked in select()");
                // SAFETY: every pointer comes from a tensor handle in `args` (or is
                // null for the optional gate weights). NOTE(review): buffer extents
                // are not checked here and are presumably the caller's contract.
                unsafe {
                    baracuda_kernels_sys::baracuda_kernels_moe_scalar_gguf_run(
                        input,
                        experts,
                        token_ids,
                        expert_ids,
                        gate_weights,
                        output,
                        desc.num_experts,
                        desc.top_k,
                        flat_tokens,
                        desc.d_expert,
                        desc.d_model,
                        dtype_code,
                        core::ptr::null_mut(),
                        0,
                        raw_stream,
                    )
                }
            }
            MoeVariant::Wmma => {
                // A missing counts buffer is reported before a missing offsets buffer.
                let (counts, offsets) = match (
                    args.expert_counts_scratch.as_ref(),
                    args.expert_offsets_scratch.as_ref(),
                ) {
                    (None, _) => {
                        return Err(Error::InvalidProblem(
                            "MoePlan::run: Wmma variant requires expert_counts_scratch",
                        ))
                    }
                    (Some(_), None) => {
                        return Err(Error::InvalidProblem(
                            "MoePlan::run: Wmma variant requires expert_offsets_scratch",
                        ))
                    }
                    (Some(c), Some(o)) => (
                        c.data.as_raw().0 as *mut i32,
                        o.data.as_raw().0 as *mut i32,
                    ),
                };
                let prefill_flag = if desc.is_prefill { 1 } else { 0 };
                match desc.element {
                    // SAFETY: pointers originate from `args` tensor handles; see the
                    // note on the scalar launch above.
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_f16_run(
                            input,
                            experts,
                            token_ids,
                            expert_ids,
                            gate_weights,
                            output,
                            counts,
                            offsets,
                            desc.num_experts,
                            desc.top_k,
                            flat_tokens,
                            desc.d_expert,
                            desc.d_model,
                            prefill_flag,
                            core::ptr::null_mut(),
                            0,
                            raw_stream,
                        )
                    },
                    // SAFETY: as for the f16 launch.
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_bf16_run(
                            input,
                            experts,
                            token_ids,
                            expert_ids,
                            gate_weights,
                            output,
                            counts,
                            offsets,
                            desc.num_experts,
                            desc.top_k,
                            flat_tokens,
                            desc.d_expert,
                            desc.d_model,
                            prefill_flag,
                            core::ptr::null_mut(),
                            0,
                            raw_stream,
                        )
                    },
                    _ => return Err(Error::Unsupported("MoePlan::run: Wmma element unsupported")),
                }
            }
            MoeVariant::WmmaGguf => {
                let format = desc.block_format.expect("checked in select()");
                let dtype_code = fuel_moe_gguf_dtype(format).expect("checked in select()");
                // A missing counts buffer is reported before a missing offsets buffer.
                let (counts, offsets) = match (
                    args.expert_counts_scratch.as_ref(),
                    args.expert_offsets_scratch.as_ref(),
                ) {
                    (None, _) => {
                        return Err(Error::InvalidProblem(
                            "MoePlan::run: WmmaGguf variant requires expert_counts_scratch",
                        ))
                    }
                    (Some(_), None) => {
                        return Err(Error::InvalidProblem(
                            "MoePlan::run: WmmaGguf variant requires expert_offsets_scratch",
                        ))
                    }
                    (Some(c), Some(o)) => (
                        c.data.as_raw().0 as *mut i32,
                        o.data.as_raw().0 as *mut i32,
                    ),
                };
                match desc.element {
                    // SAFETY: pointers originate from `args` tensor handles; see the
                    // note on the scalar launch above.
                    ElementKind::F16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_f16_run(
                            input,
                            experts,
                            token_ids,
                            expert_ids,
                            gate_weights,
                            output,
                            counts,
                            offsets,
                            desc.num_experts,
                            desc.top_k,
                            flat_tokens,
                            desc.d_expert,
                            desc.d_model,
                            dtype_code,
                            core::ptr::null_mut(),
                            0,
                            raw_stream,
                        )
                    },
                    // SAFETY: as for the f16 launch.
                    ElementKind::Bf16 => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_moe_wmma_gguf_bf16_run(
                            input,
                            experts,
                            token_ids,
                            expert_ids,
                            gate_weights,
                            output,
                            counts,
                            offsets,
                            desc.num_experts,
                            desc.top_k,
                            flat_tokens,
                            desc.d_expert,
                            desc.d_model,
                            dtype_code,
                            core::ptr::null_mut(),
                            0,
                            raw_stream,
                        )
                    },
                    _ => return Err(Error::Unsupported("MoePlan::run: WmmaGguf element unsupported")),
                }
            }
        };

        map_status(rc)
    }
}
fn fuel_moe_gguf_dtype(bf: GgufBlockFormat) -> Result<i32> {
match bf {
GgufBlockFormat::Q8_0 => Ok(0),
GgufBlockFormat::Q4K => Ok(1),
GgufBlockFormat::Q2K => Ok(2),
GgufBlockFormat::Q3K => Ok(3),
GgufBlockFormat::Q5K => Ok(4),
GgufBlockFormat::Q6K => Ok(5),
GgufBlockFormat::Q4_0
| GgufBlockFormat::Q4_1
| GgufBlockFormat::Q5_0
| GgufBlockFormat::Q5_1
| GgufBlockFormat::Q8K => Err(Error::Unsupported(
"MoePlan: GGUF MoE variants only support Q8_0 + k-quants (Q2_K..Q6_K)",
)),
_ => Err(Error::Unsupported(
"MoePlan: unsupported GGUF block format",
)),
}
}
/// Builds the [`KernelSku`] that identifies the kernel for `desc`.
///
/// Only `desc.variant` and `desc.element` influence the result. The remaining
/// fields are fixed: `Moe` category, `U8` auxiliary element, no layout or
/// epilogue, `Sm80` arch, `Bespoke` backend, and an `F32` math/accumulator
/// precision guarantee flagged as deterministic and bit-stable on the same
/// hardware.
fn build_sku(desc: &MoeDescriptor) -> KernelSku {
    let kind = match desc.variant {
        MoeVariant::ScalarGguf => MoeKind::ScalarGguf,
        MoeVariant::Wmma => MoeKind::Wmma,
        MoeVariant::WmmaGguf => MoeKind::WmmaGguf,
    };
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };
    KernelSku {
        category: OpCategory::Moe,
        op: kind as u16,
        element: desc.element,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}