use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, GgufBlockFormat, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace, U8,
};
use crate::quantize::gguf::mmvq::GgufMmvqActivation;
use crate::quantize::map_status;
/// Problem description consumed by `GgufMmvqMultiMPlan::select`: a GGUF
/// block-quantized weight matrix (`nrows` rows of `ncols` logical elements)
/// applied to `m` activation rows, producing an `[m, nrows]` f32 output.
#[derive(Copy, Clone, Debug)]
pub struct GgufMmvqMultiMDescriptor {
    /// Number of weight rows; also the length of each output row
    /// (output shape is checked against `[m, nrows]`). Must be non-negative.
    pub nrows: i32,
    /// Logical elements per weight row and per activation row
    /// (activation shape is checked against `[m, ncols]`). Must be
    /// non-negative and a multiple of `block_format`'s block size.
    pub ncols: i32,
    /// Number of activation rows; must be non-negative. Defaults to 1.
    pub m: i32,
    /// Quantization format of the weight bytes. `Q8_K` is rejected by `select`.
    pub block_format: GgufBlockFormat,
    /// Byte offset into the weight tensor at which this matrix's blocks
    /// start; must be non-negative.
    pub w_start_byte_offset: i64,
}
impl Default for GgufMmvqMultiMDescriptor {
    /// An empty problem: one activation row, zero-sized weight, `Q8_0`
    /// format, and the weight starting at byte 0.
    fn default() -> Self {
        let block_format = GgufBlockFormat::Q8_0;
        Self {
            m: 1,
            nrows: 0,
            ncols: 0,
            w_start_byte_offset: 0,
            block_format,
        }
    }
}
/// Per-launch tensors for `GgufMmvqMultiMPlan::run`. `T` is the activation
/// element type (defaults to `f32`).
pub struct GgufMmvqMultiMArgs<'a, T: GgufMmvqActivation = f32> {
    /// Raw bytes of the quantized weight; must hold at least
    /// `w_start_byte_offset + nrows * (ncols / block_size) * type_size` bytes.
    pub weight: TensorRef<'a, U8, 1>,
    /// Activations of shape `[m, ncols]`, contiguous along the last axis.
    pub activations: TensorRef<'a, T, 2>,
    /// f32 output of shape `[m, nrows]`, contiguous along the last axis.
    pub output: TensorMut<'a, f32, 2>,
}
/// A validated multi-M MMVQ launch plan for activation type `T`.
pub struct GgufMmvqMultiMPlan<T: GgufMmvqActivation = f32> {
    // Problem description captured by `select`.
    desc: GgufMmvqMultiMDescriptor,
    // SKU reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    // Bytes of workspace `run` needs for the staged Q8_1 activations.
    workspace_bytes: usize,
    // Ties the plan to its activation type without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<T: GgufMmvqActivation> GgufMmvqMultiMPlan<T> {
pub fn select(
_stream: &Stream,
desc: &GgufMmvqMultiMDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.nrows < 0 || desc.ncols < 0 || desc.m < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: nrows / ncols / m must be non-negative",
));
}
if desc.block_format == GgufBlockFormat::Q8K {
return Err(Error::Unsupported(
"GgufMmvqMultiMPlan: Q8_K not supported — upstream \
llama.cpp / Fuel reserve Q8_K as a CPU-side intermediate; \
use GgufMmvqPlan (bespoke baracuda kernel) for Q8_K MMVQ.",
));
}
let bs = desc.block_format.block_size() as i32;
if desc.ncols % bs != 0 {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: ncols must be a multiple of the \
block size (32 for Q4_0/Q4_1/Q5_0/Q5_1/Q8_0; 256 for k-quants)",
));
}
if desc.w_start_byte_offset < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: w_start_byte_offset must be non-negative",
));
}
let blocks_per_row = ((desc.ncols + 31) / 32) as usize;
let workspace_bytes = (desc.m as usize) * blocks_per_row * 36;
Ok(Self {
desc: *desc,
sku: build_sku(T::KIND),
workspace_bytes,
_phantom: PhantomData,
})
}
pub fn can_implement(&self, args: &GgufMmvqMultiMArgs<'_, T>) -> Result<()> {
if args.activations.shape != [self.desc.m, self.desc.ncols] {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: activations shape != [M, ncols]",
));
}
if args.output.shape != [self.desc.m, self.desc.nrows] {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: output shape != [M, nrows]",
));
}
if args.activations.stride[1] != 1 {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: activations must be contig along K",
));
}
if args.output.stride[1] != 1 {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: output must be contig along nrows",
));
}
let bs = self.desc.block_format.block_size() as i32;
let blocks_per_row = self.desc.ncols / bs;
let expected_bytes =
(self.desc.nrows as i64) * (blocks_per_row as i64) * (self.desc.block_format.type_size() as i64);
let weight_len_bytes = args.weight.shape[0] as i64;
let need_bytes = self.desc.w_start_byte_offset + expected_bytes;
if weight_len_bytes < need_bytes {
return Err(Error::InvalidProblem(
"GgufMmvqMultiMPlan: weight byte length < offset + nrows * blocks_per_row * type_size",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
self.workspace_bytes
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: GgufMmvqMultiMArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.nrows == 0 || self.desc.ncols == 0 || self.desc.m == 0 {
return Ok(());
}
let need = self.workspace_bytes;
let ws_ptr = match workspace {
Workspace::None => {
if need > 0 {
return Err(Error::WorkspaceTooSmall { needed: need, got: 0 });
}
core::ptr::null_mut()
}
Workspace::Borrowed(slice) => {
let got = slice.len();
if got < need {
return Err(Error::WorkspaceTooSmall { needed: need, got });
}
slice.as_raw().0 as *mut c_void
}
};
let w_ptr = args.weight.data.as_raw().0 as *const c_void;
let y_ptr = args.activations.data.as_raw().0 as *const c_void;
let dst_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let ncols = self.desc.ncols;
let nrows = self.desc.nrows;
let w_off = self.desc.w_start_byte_offset;
let stage_status = unsafe {
stage_q8_1::<T>(
ncols as i64,
self.desc.m as i64,
y_ptr,
ws_ptr,
stream_ptr,
)
};
map_status(stage_status)?;
let mut m_done = 0i32;
while m_done < self.desc.m {
let m_remaining = self.desc.m - m_done;
let m_chunk = pick_chunk_size(m_remaining);
let blocks_per_row = ((ncols + 31) / 32) as i64;
let ws_row_bytes = blocks_per_row * 36;
let chunk_ws_ptr = unsafe {
(ws_ptr as *mut u8).offset((m_done as i64 * ws_row_bytes) as isize)
} as *const c_void;
let chunk_dst_ptr = unsafe {
(dst_ptr as *mut f32).offset((m_done as isize) * (nrows as isize))
} as *mut c_void;
let status = unsafe {
dispatch_multim(
self.desc.block_format,
m_chunk, ncols, nrows, w_ptr, w_off, chunk_ws_ptr, chunk_dst_ptr, stream_ptr,
)
};
map_status(status)?;
m_done += m_chunk;
}
Ok(())
}
}
/// Chooses the widest compiled M-tile (8, 4, 2 or 1) that does not exceed
/// `remaining` rows; any value of 1 or less (including negatives) maps to 1.
fn pick_chunk_size(remaining: i32) -> i32 {
    match remaining {
        i32::MIN..=1 => 1,
        2..=3 => 2,
        4..=7 => 4,
        _ => 8,
    }
}
unsafe fn stage_q8_1<T: GgufMmvqActivation>(
kx: i64,
ny: i64,
src: *const c_void,
dst: *mut c_void,
stream: *mut c_void,
) -> i32 {
match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_quantize_q8_1_f32_run(
kx, ny, src, dst, stream,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_quantize_q8_1_f16_run(
kx, ny, src, dst, stream,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_quantize_q8_1_bf16_run(
kx, ny, src, dst, stream,
)
},
_ => 3, }
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines, non_snake_case)]
unsafe fn dispatch_multim(
fmt: GgufBlockFormat,
m: i32,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
w_off: i64,
activations_q8_1: *const c_void,
dst: *mut c_void,
stream: *mut c_void,
) -> i32 {
use baracuda_kernels_sys as sys;
let ws = core::ptr::null_mut();
match (fmt, m) {
(GgufBlockFormat::Q8_0, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q8_0_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q8_0, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q8_0_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q8_0, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q8_0_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q8_0, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q8_0_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_0, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_0_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_0, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_0_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_0, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_0_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_0, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_0_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_1, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_1_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_1, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_1_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_1, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_1_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4_1, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_1_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_0, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_0_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_0, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_0_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_0, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_0_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_0, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_0_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_1, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_1_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_1, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_1_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_1, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_1_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5_1, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_1_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q2K, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q2_K_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q2K, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q2_K_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q2K, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q2_K_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q2K, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q2_K_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q3K, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q3_K_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q3K, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q3_K_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q3K, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q3_K_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q3K, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q3_K_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4K, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_K_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4K, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_K_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4K, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_K_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q4K, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q4_K_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5K, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_K_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5K, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_K_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5K, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_K_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q5K, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q5_K_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q6K, 1) => unsafe {
sys::baracuda_kernels_mmvq_multim_q6_K_m1_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q6K, 2) => unsafe {
sys::baracuda_kernels_mmvq_multim_q6_K_m2_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q6K, 4) => unsafe {
sys::baracuda_kernels_mmvq_multim_q6_K_m4_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
(GgufBlockFormat::Q6K, 8) => unsafe {
sys::baracuda_kernels_mmvq_multim_q6_K_m8_run(
ncols, nrows, w_ptr, w_off, activations_q8_1, dst, ws, 0, stream)
},
_ => 2,
}
}
/// Builds the SKU advertised by `GgufMmvqMultiMPlan`.
///
/// The SKU is a quantization-category `GgufMmvq` op with activation element
/// `act_kind` and a U8 auxiliary element. It targets Sm80 via the bespoke
/// backend, and promises f32 math and accumulation with deterministic,
/// bit-stable results on the same hardware.
fn build_sku(act_kind: ElementKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };
    let op = QuantizeKind::GgufMmvq as u16;
    KernelSku {
        category: OpCategory::Quantization,
        op,
        element: act_kind,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}