use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, GgufBlockFormat, KernelSku,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut,
TensorRef, Workspace, U8,
};
use baracuda_types::DeviceRepr;
use half::{bf16, f16};
use crate::quantize::map_status;
/// Activation element types accepted by [`GgufMmvqBatchedPlan`]: `f32`,
/// `f16`, and `bf16`.
///
/// The trait is sealed: its `sealed::Sealed` supertrait lives in a private
/// module, so no type outside this module can implement it.
pub trait GgufMmvqBatchedActivation: Element + sealed::Sealed {}

/// Private supertrait that seals [`GgufMmvqBatchedActivation`].
mod sealed {
    pub trait Sealed {}
}

// Implements both the private seal and the public marker trait for every
// listed type, keeping the two impl lists from drifting apart.
macro_rules! impl_gguf_mmvq_batched_activation {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl sealed::Sealed for $ty {}
            impl GgufMmvqBatchedActivation for $ty {}
        )+
    };
}

impl_gguf_mmvq_batched_activation!(f32, f16, bf16);
/// Weight encoding consumed by [`GgufMmvqBatchedPlan`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum GgufMmvqBatchedFormat {
/// GGUF block-quantized weights. The block format must report an MMVQ
/// kernel (`has_mmvq()`), and `n_cols` must be a multiple of its block size.
Quantized(GgufBlockFormat),
/// Unquantized weights stored in the activation element type `T`
/// (`size_of::<T>()` bytes per weight element).
Fp,
}
/// Problem description for a batched, per-expert MMVQ launch.
///
/// Validated by [`GgufMmvqBatchedPlan::select`].
#[derive(Copy, Clone, Debug)]
pub struct GgufMmvqBatchedDescriptor {
/// Number of experts whose weights are packed into `weights`; must be positive.
pub n_experts: i32,
/// Output rows produced per expert; must equal `output.shape[1]`.
pub n_rows_per_expert: i32,
/// Reduction length; must equal `activations.shape[1]`. For quantized
/// formats it must also be a multiple of the block size.
pub n_cols: i32,
/// Length of `sorted_token_ids` (and of `topk_weights` when supplied).
pub m_total: i32,
/// Number of experts routed per token; must be positive.
pub top_k: i32,
/// Encoding of the weight bytes.
pub format: GgufMmvqBatchedFormat,
}
impl Default for GgufMmvqBatchedDescriptor {
    /// An empty problem: all sizes zero, `top_k = 1`, and Q8_0-quantized weights.
    ///
    /// `n_experts = 0` is rejected by `GgufMmvqBatchedPlan::select`, so callers
    /// must fill in the sizes before planning.
    fn default() -> Self {
        let format = GgufMmvqBatchedFormat::Quantized(GgufBlockFormat::Q8_0);
        Self {
            format,
            top_k: 1,
            n_experts: 0,
            n_rows_per_expert: 0,
            n_cols: 0,
            m_total: 0,
        }
    }
}
/// Tensors passed to one [`GgufMmvqBatchedPlan::run`] call.
pub struct GgufMmvqBatchedArgs<'a, T: DeviceRepr + Copy + 'static = f32> {
/// Packed weights for all experts, viewed as raw bytes. `can_implement`
/// checks that the buffer holds at least the bytes implied by the descriptor.
pub weights: TensorRef<'a, U8, 1>,
/// Activations, shape `[M_tokens, n_cols]`.
pub activations: TensorRef<'a, T, 2>,
/// `m_total` token indices.
/// NOTE(review): presumably grouped by expert as delimited by
/// `expert_offsets` — confirm against the kernel.
pub sorted_token_ids: TensorRef<'a, i32, 1>,
/// `n_experts + 1` entries.
/// NOTE(review): presumably prefix offsets into `sorted_token_ids` — confirm.
pub expert_offsets: TensorRef<'a, i32, 1>,
/// Optional per-entry weights of length `m_total`; passed to the kernel as
/// a null pointer when `None`.
pub topk_weights: Option<TensorRef<'a, f32, 1>>,
/// Output, shape `[M_tokens, n_rows_per_expert]`.
pub output: TensorMut<'a, T, 2>,
}
/// A validated batched MMVQ plan for activation type `T`.
///
/// Built by [`GgufMmvqBatchedPlan::select`] and launched with
/// [`GgufMmvqBatchedPlan::run`].
pub struct GgufMmvqBatchedPlan<T: DeviceRepr + Copy + 'static = f32> {
/// Descriptor this plan was validated against.
desc: GgufMmvqBatchedDescriptor,
/// Kernel SKU reported through `sku()`.
sku: KernelSku,
/// Ties the plan to `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T: GgufMmvqBatchedActivation> GgufMmvqBatchedPlan<T> {
pub fn select(
_stream: &Stream,
desc: &GgufMmvqBatchedDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.n_experts <= 0 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: n_experts must be positive",
));
}
if desc.n_rows_per_expert < 0 || desc.n_cols < 0 || desc.m_total < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: n_rows_per_expert / n_cols / m_total must be non-negative",
));
}
if desc.top_k <= 0 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: top_k must be positive",
));
}
if let GgufMmvqBatchedFormat::Quantized(fmt) = desc.format {
if !fmt.has_mmvq() {
return Err(Error::Unsupported(
"GgufMmvqBatchedPlan: block format reports no MMVQ kernel",
));
}
let bs = fmt.block_size() as i32;
if desc.n_cols % bs != 0 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: n_cols must be a multiple of the block size",
));
}
#[cfg(debug_assertions)]
{
use baracuda_kernels_types::GgufBlockFormat;
let is_type_0_1 = matches!(
fmt,
GgufBlockFormat::Q4_0
| GgufBlockFormat::Q4_1
| GgufBlockFormat::Q5_0
| GgufBlockFormat::Q5_1
| GgufBlockFormat::Q8_0,
);
if is_type_0_1 && desc.n_cols < 64 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: type-0/1 block formats require \
n_cols >= 64 (2 * GGML_CUDA_DMMV_X); smaller n_cols \
produces silent-wrong results because contiguous-batched \
activations make threads' OOB reads hit adjacent tokens' \
rows",
));
}
}
}
Ok(Self {
desc: *desc,
sku: build_sku(&desc.format, T::KIND),
_phantom: PhantomData,
})
}
pub fn can_implement(&self, args: &GgufMmvqBatchedArgs<'_, T>) -> Result<()> {
if args.activations.shape[1] != self.desc.n_cols {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: activations.shape[1] != n_cols",
));
}
if args.output.shape[1] != self.desc.n_rows_per_expert {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: output.shape[1] != n_rows_per_expert",
));
}
if args.output.shape[0] != args.activations.shape[0] {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: output.shape[0] != activations.shape[0] (M_tokens mismatch)",
));
}
if args.sorted_token_ids.shape[0] != self.desc.m_total {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: sorted_token_ids.shape[0] != m_total",
));
}
if args.expert_offsets.shape[0] != self.desc.n_experts + 1 {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: expert_offsets.shape[0] != n_experts + 1",
));
}
if let Some(ref tw) = args.topk_weights {
if tw.shape[0] != self.desc.m_total {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: topk_weights.shape[0] != m_total",
));
}
}
let expected_weight_bytes: i64 = match self.desc.format {
GgufMmvqBatchedFormat::Quantized(fmt) => {
let blocks_per_row = self.desc.n_cols / fmt.block_size() as i32;
(self.desc.n_experts as i64)
* (self.desc.n_rows_per_expert as i64)
* (blocks_per_row as i64)
* (fmt.type_size() as i64)
}
GgufMmvqBatchedFormat::Fp => {
(self.desc.n_experts as i64)
* (self.desc.n_rows_per_expert as i64)
* (self.desc.n_cols as i64)
* (core::mem::size_of::<T>() as i64)
}
};
if (args.weights.shape[0] as i64) < expected_weight_bytes {
return Err(Error::InvalidProblem(
"GgufMmvqBatchedPlan: weights byte length < n_experts * n_rows_per_expert * (n_cols/bs) * type_size",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
(self.desc.m_total as usize) * core::mem::size_of::<i32>()
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: GgufMmvqBatchedArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.m_total == 0
|| self.desc.n_experts == 0
|| self.desc.n_rows_per_expert == 0
|| self.desc.n_cols == 0
{
return Ok(());
}
let need = self.workspace_size();
let (ws_ptr, ws_bytes) = match workspace {
Workspace::None => {
return Err(Error::WorkspaceTooSmall { needed: need, got: 0 });
}
Workspace::Borrowed(slice) => {
let got = slice.len();
if got < need {
return Err(Error::WorkspaceTooSmall { needed: need, got });
}
(slice.as_raw().0 as *mut c_void, got)
}
};
let w_ptr = args.weights.data.as_raw().0 as *const c_void;
let y_ptr = args.activations.data.as_raw().0 as *const c_void;
let dst_ptr = args.output.data.as_raw().0 as *mut c_void;
let tids_ptr = args.sorted_token_ids.data.as_raw().0 as *const i32;
let off_ptr = args.expert_offsets.data.as_raw().0 as *const i32;
let tw_ptr = args
.topk_weights
.as_ref()
.map(|t| t.data.as_raw().0 as *const f32)
.unwrap_or(core::ptr::null());
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe {
dispatch_ffi::<T>(
&self.desc.format,
self.desc.n_experts,
self.desc.n_rows_per_expert,
self.desc.n_cols,
w_ptr,
y_ptr,
tids_ptr,
off_ptr,
tw_ptr,
dst_ptr,
self.desc.top_k,
ws_ptr,
ws_bytes,
stream_ptr,
)
};
map_status(status)
}
}
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_ffi<T: GgufMmvqBatchedActivation>(
format: &GgufMmvqBatchedFormat,
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *mut c_void,
top_k: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32 {
match (format, T::KIND) {
(GgufMmvqBatchedFormat::Fp, ElementKind::F32) => {
unsafe { baracuda_kernels_sys::baracuda_kernels_mmvq_batched_f32_run(
n_experts, n_rows_per_expert, n_cols, weights, activations,
sorted_token_ids, expert_offsets, topk_weights, output, top_k,
workspace, workspace_bytes, stream) }
}
(GgufMmvqBatchedFormat::Fp, ElementKind::F16) => {
unsafe { baracuda_kernels_sys::baracuda_kernels_mmvq_batched_f16_run(
n_experts, n_rows_per_expert, n_cols, weights, activations,
sorted_token_ids, expert_offsets, topk_weights, output, top_k,
workspace, workspace_bytes, stream) }
}
(GgufMmvqBatchedFormat::Fp, ElementKind::Bf16) => {
unsafe { baracuda_kernels_sys::baracuda_kernels_mmvq_batched_bf16_run(
n_experts, n_rows_per_expert, n_cols, weights, activations,
sorted_token_ids, expert_offsets, topk_weights, output, top_k,
workspace, workspace_bytes, stream) }
}
(GgufMmvqBatchedFormat::Quantized(fmt), kind) => {
unsafe { dispatch_quant_ffi(
*fmt, kind, n_experts, n_rows_per_expert, n_cols, weights, activations,
sorted_token_ids, expert_offsets, topk_weights, output, top_k,
workspace, workspace_bytes, stream,
) }
}
_ => -1,
}
}
/// Forwards a batched MMVQ launch with block-quantized weights to the
/// `baracuda_kernels_sys` symbol for the (`fmt`, `kind`) pair.
///
/// Symbols without a suffix take `f32` activations; `_f16` / `_bf16` suffixes
/// select half-precision activations. Returns the kernel's raw status code,
/// or `-1` for any (format, activation kind) pair not listed in the table.
///
/// # Safety
///
/// Every pointer and size is handed to the C entry point unchanged, so the
/// caller must guarantee they are valid for that kernel.
/// NOTE(review): presumably device pointers usable on `stream`, sized per the
/// integer arguments — confirm against the kernel headers.
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_quant_ffi(
fmt: GgufBlockFormat,
kind: ElementKind,
n_experts: i32,
n_rows_per_expert: i32,
n_cols: i32,
weights: *const c_void,
activations: *const c_void,
sorted_token_ids: *const i32,
expert_offsets: *const i32,
topk_weights: *const f32,
output: *mut c_void,
top_k: i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32 {
// Every batched entry point shares the same argument list; only the symbol
// name differs, so the macro takes just that name.
macro_rules! call {
($sym:ident) => {
unsafe { baracuda_kernels_sys::$sym(
n_experts, n_rows_per_expert, n_cols, weights, activations,
sorted_token_ids, expert_offsets, topk_weights, output, top_k,
workspace, workspace_bytes, stream) }
};
}
match (fmt, kind) {
(GgufBlockFormat::Q4_0, ElementKind::F32) => call!(baracuda_kernels_mmvq_q4_0_batched_run),
(GgufBlockFormat::Q4_0, ElementKind::F16) => call!(baracuda_kernels_mmvq_q4_0_batched_f16_run),
(GgufBlockFormat::Q4_0, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q4_0_batched_bf16_run),
(GgufBlockFormat::Q4_1, ElementKind::F32) => call!(baracuda_kernels_mmvq_q4_1_batched_run),
(GgufBlockFormat::Q4_1, ElementKind::F16) => call!(baracuda_kernels_mmvq_q4_1_batched_f16_run),
(GgufBlockFormat::Q4_1, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q4_1_batched_bf16_run),
(GgufBlockFormat::Q5_0, ElementKind::F32) => call!(baracuda_kernels_mmvq_q5_0_batched_run),
(GgufBlockFormat::Q5_0, ElementKind::F16) => call!(baracuda_kernels_mmvq_q5_0_batched_f16_run),
(GgufBlockFormat::Q5_0, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q5_0_batched_bf16_run),
(GgufBlockFormat::Q5_1, ElementKind::F32) => call!(baracuda_kernels_mmvq_q5_1_batched_run),
(GgufBlockFormat::Q5_1, ElementKind::F16) => call!(baracuda_kernels_mmvq_q5_1_batched_f16_run),
(GgufBlockFormat::Q5_1, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q5_1_batched_bf16_run),
(GgufBlockFormat::Q8_0, ElementKind::F32) => call!(baracuda_kernels_mmvq_q8_0_batched_run),
(GgufBlockFormat::Q8_0, ElementKind::F16) => call!(baracuda_kernels_mmvq_q8_0_batched_f16_run),
(GgufBlockFormat::Q8_0, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q8_0_batched_bf16_run),
(GgufBlockFormat::Q2K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q2_K_batched_run),
(GgufBlockFormat::Q2K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q2_K_batched_f16_run),
(GgufBlockFormat::Q2K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q2_K_batched_bf16_run),
(GgufBlockFormat::Q3K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q3_K_batched_run),
(GgufBlockFormat::Q3K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q3_K_batched_f16_run),
(GgufBlockFormat::Q3K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q3_K_batched_bf16_run),
(GgufBlockFormat::Q4K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q4_K_batched_run),
(GgufBlockFormat::Q4K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q4_K_batched_f16_run),
(GgufBlockFormat::Q4K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q4_K_batched_bf16_run),
(GgufBlockFormat::Q5K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q5_K_batched_run),
(GgufBlockFormat::Q5K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q5_K_batched_f16_run),
(GgufBlockFormat::Q5K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q5_K_batched_bf16_run),
(GgufBlockFormat::Q6K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q6_K_batched_run),
(GgufBlockFormat::Q6K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q6_K_batched_f16_run),
(GgufBlockFormat::Q6K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q6_K_batched_bf16_run),
(GgufBlockFormat::Q8K, ElementKind::F32) => call!(baracuda_kernels_mmvq_q8_K_batched_run),
(GgufBlockFormat::Q8K, ElementKind::F16) => call!(baracuda_kernels_mmvq_q8_K_batched_f16_run),
(GgufBlockFormat::Q8K, ElementKind::Bf16) => call!(baracuda_kernels_mmvq_q8_K_batched_bf16_run),
// Any other block format or activation kind has no batched kernel here.
_ => -1,
}
}
/// Builds the [`KernelSku`] advertised by a batched MMVQ plan.
///
/// The SKU varies only with the activation element kind; `_format` is not
/// encoded. Weights are described as raw bytes (`aux_element = U8`), and the
/// advertised math and accumulation precision is F32 on an SM80 bespoke
/// backend.
fn build_sku(_format: &GgufMmvqBatchedFormat, act_kind: ElementKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };
    KernelSku {
        category: OpCategory::Quantization,
        op: QuantizeKind::GgufMmvq as u16,
        element: act_kind,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}