use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, GgufBlockFormat, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
U8,
};
use half::{bf16, f16};
use crate::quantize::map_status;
pub trait GgufMmvqActivation: Element + sealed::Sealed {}
mod sealed {
pub trait Sealed {}
impl Sealed for f32 {}
impl Sealed for half::f16 {}
impl Sealed for half::bf16 {}
}
impl GgufMmvqActivation for f32 {}
impl GgufMmvqActivation for f16 {}
impl GgufMmvqActivation for bf16 {}
/// Problem description for a GGUF quantized matrix-vector product.
///
/// Validated by `GgufMmvqPlan::select`; the buffers themselves are checked
/// against it by `GgufMmvqPlan::can_implement`.
#[derive(Copy, Clone, Debug)]
pub struct GgufMmvqDescriptor {
    /// Number of weight rows; the output tensor must have shape `[nrows]`.
    pub nrows: i32,
    /// Number of weight columns; the activation tensor must have shape
    /// `[ncols]`. Must be a multiple of the block format's block size.
    pub ncols: i32,
    /// Quantization block format of the packed weight bytes.
    pub block_format: GgufBlockFormat,
    /// Byte offset into the weight buffer at which the packed rows start.
    /// A non-zero value routes the launch to the strided kernel entry points.
    pub w_start_byte_offset: i64,
}
impl Default for GgufMmvqDescriptor {
fn default() -> Self {
Self {
nrows: 0,
ncols: 0,
block_format: GgufBlockFormat::Q8_0,
w_start_byte_offset: 0,
}
}
}
/// Tensors for one `GgufMmvqPlan::run` call.
pub struct GgufMmvqArgs<'a, T: GgufMmvqActivation = f32> {
    /// Packed quantized weights viewed as raw bytes. Its length must be at
    /// least `w_start_byte_offset + nrows * (ncols / block_size) * type_size`.
    pub weight: TensorRef<'a, U8, 1>,
    /// Activation vector of shape `[ncols]`; its stride is forwarded to the
    /// kernel, so it does not have to be contiguous.
    pub activation: TensorRef<'a, T, 1>,
    /// Output vector of shape `[nrows]`.
    pub output: TensorMut<'a, T, 1>,
}
/// A validated GGUF MMVQ launch plan for activation element type `T`.
///
/// Built with `select`; holds the descriptor and the SKU computed for it.
pub struct GgufMmvqPlan<T: GgufMmvqActivation = f32> {
    desc: GgufMmvqDescriptor,
    sku: KernelSku,
    // Ties the plan to `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
impl<T: GgufMmvqActivation> GgufMmvqPlan<T> {
pub fn select(
_stream: &Stream,
desc: &GgufMmvqDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.nrows < 0 || desc.ncols < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: nrows / ncols must be non-negative",
));
}
if !desc.block_format.has_mmvq() {
return Err(Error::Unsupported(
"GgufMmvqPlan: block format reports no MMVQ kernel",
));
}
let bs = desc.block_format.block_size() as i32;
if desc.ncols % bs != 0 {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: ncols must be a multiple of the block size",
));
}
#[cfg(debug_assertions)]
{
if desc.w_start_byte_offset < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: w_start_byte_offset must be non-negative",
));
}
let alignment = required_alignment(desc.block_format);
if desc.w_start_byte_offset % alignment != 0 {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: w_start_byte_offset must be aligned to the block format's natural alignment (Q4_1/Q5_1/Q2K/Q4K/Q5K/Q8K = 4, others = 2)",
));
}
}
Ok(Self {
desc: *desc,
sku: build_sku(desc.block_format, T::KIND),
_phantom: PhantomData,
})
}
pub fn can_implement(&self, args: &GgufMmvqArgs<'_, T>) -> Result<()> {
if args.activation.shape != [self.desc.ncols] {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: activation shape != [ncols]",
));
}
if args.output.shape != [self.desc.nrows] {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: output shape != [nrows]",
));
}
let blocks_per_row = self.desc.ncols / self.desc.block_format.block_size() as i32;
let expected_bytes =
(self.desc.nrows as i64) * (blocks_per_row as i64) * (self.desc.block_format.type_size() as i64);
let weight_len_bytes = args.weight.shape[0] as i64;
if self.desc.w_start_byte_offset < 0 {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: w_start_byte_offset must be non-negative",
));
}
let need_bytes = self.desc.w_start_byte_offset + expected_bytes;
if weight_len_bytes < need_bytes {
return Err(Error::InvalidProblem(
"GgufMmvqPlan: weight byte length < w_start_byte_offset + nrows * blocks_per_row * type_size",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: GgufMmvqArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.nrows == 0 || self.desc.ncols == 0 {
return Ok(());
}
let w_ptr = args.weight.data.as_raw().0 as *const c_void;
let y_ptr = args.activation.data.as_raw().0 as *const c_void;
let dst_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let ncols = self.desc.ncols;
let nrows = self.desc.nrows;
let w_off = self.desc.w_start_byte_offset;
let stride_y = args.activation.stride[0];
let use_strided = w_off != 0 || stride_y != 1;
let status =
unsafe { dispatch_ffi::<T>(self.desc.block_format, use_strided, ncols, nrows, w_ptr,
w_off, stride_y, y_ptr, dst_ptr, stream_ptr) };
map_status(status)
}
}
/// Routes an MMVQ launch to the dispatcher for activation type `T`.
///
/// Uses the strided (`actstrided`) dispatcher when `use_strided` is set, and
/// the contiguous one otherwise. `w_off` and `stride_y` are only forwarded on
/// the strided path.
///
/// # Safety
/// Same contract as the per-type dispatchers: every pointer is handed
/// unchecked to the FFI kernel launchers, so it must be valid for the given
/// problem size, and `stream_ptr` must be a live stream handle.
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_ffi<T: GgufMmvqActivation>(
    fmt: GgufBlockFormat,
    use_strided: bool,
    ncols: i32,
    nrows: i32,
    w_ptr: *const c_void,
    w_off: i64,
    stride_y: i64,
    y_ptr: *const c_void,
    dst_ptr: *mut c_void,
    stream_ptr: *mut c_void,
) -> i32 {
    // SAFETY: arguments are forwarded verbatim; the caller upholds this
    // function's contract, which is the callees' contract.
    unsafe {
        match (T::KIND, use_strided) {
            (ElementKind::F32, false) => {
                dispatch_f32_contig(fmt, ncols, nrows, w_ptr, y_ptr, dst_ptr, stream_ptr)
            }
            (ElementKind::F32, true) => dispatch_f32_strided(
                fmt, ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, stream_ptr,
            ),
            (ElementKind::F16, false) => {
                dispatch_f16_contig(fmt, ncols, nrows, w_ptr, y_ptr, dst_ptr, stream_ptr)
            }
            (ElementKind::F16, true) => dispatch_f16_strided(
                fmt, ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, stream_ptr,
            ),
            (ElementKind::Bf16, false) => {
                dispatch_bf16_contig(fmt, ncols, nrows, w_ptr, y_ptr, dst_ptr, stream_ptr)
            }
            (ElementKind::Bf16, true) => dispatch_bf16_strided(
                fmt, ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, stream_ptr,
            ),
            _ => unreachable!("GgufMmvqActivation is sealed to f32 / f16 / bf16"),
        }
    }
}
unsafe fn dispatch_f32_contig(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_f32_strided(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
w_off: i64,
stride_y: i64,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_actstrided_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
unsafe fn dispatch_f16_contig(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_f16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_f16_strided(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
w_off: i64,
stride_y: i64,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_actstrided_f16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
unsafe fn dispatch_bf16_contig(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_bf16_run(
ncols, nrows, w_ptr, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
#[allow(clippy::too_many_arguments)]
unsafe fn dispatch_bf16_strided(
fmt: GgufBlockFormat,
ncols: i32,
nrows: i32,
w_ptr: *const c_void,
w_off: i64,
stride_y: i64,
y_ptr: *const c_void,
dst_ptr: *mut c_void,
stream_ptr: *mut c_void,
) -> i32 { unsafe {
let ws = core::ptr::null_mut();
match fmt {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_0_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_1_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_0_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_1_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_0_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_mmvq_q2_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_mmvq_q3_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_mmvq_q4_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_mmvq_q5_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_mmvq_q6_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_mmvq_q8_K_actstrided_bf16_run(
ncols, nrows, w_ptr, w_off, stride_y, y_ptr, dst_ptr, ws, 0, stream_ptr),
_ => 3,
}
}}
/// Byte alignment that `w_start_byte_offset` must satisfy for `format`.
///
/// `Q4_0`, `Q5_0`, `Q8_0`, `Q3K` and `Q6K` need 2-byte alignment. Every
/// other format, including any not named here, needs 4. Only compiled with
/// debug assertions; its caller in `select` is debug-only as well.
#[cfg(debug_assertions)]
#[inline]
fn required_alignment(format: GgufBlockFormat) -> i64 {
    let two_byte_aligned = matches!(
        format,
        GgufBlockFormat::Q4_0
            | GgufBlockFormat::Q5_0
            | GgufBlockFormat::Q8_0
            | GgufBlockFormat::Q3K
            | GgufBlockFormat::Q6K
    );
    if two_byte_aligned {
        2
    } else {
        4
    }
}
/// Builds the [`KernelSku`] stored in a `GgufMmvqPlan`.
///
/// `_block_format` is currently ignored, so every block format yields the
/// same SKU apart from the activation element kind. The SKU always reports
/// `U8` auxiliary (weight) elements, `ArchSku::Sm80`, the bespoke backend,
/// and F32 math with an F32 accumulator.
fn build_sku(_block_format: GgufBlockFormat, act_kind: ElementKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };
    KernelSku {
        category: OpCategory::Quantization,
        op: QuantizeKind::GgufMmvq as u16,
        element: act_kind,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}