use core::ffi::c_void;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, GgufBlockFormat, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace, U8,
};
use crate::quantize::map_status;
/// Problem description for dequantizing a GGUF-quantized byte buffer into `f32` values.
///
/// [`GgufDequantizePlan::select`] rejects descriptors whose `numel` is negative
/// or is not a multiple of `block_format.block_size()`.
#[derive(Copy, Clone, Debug)]
pub struct GgufDequantizeDescriptor {
    /// Number of `f32` values to produce; the output tensor must have shape `[numel]`.
    pub numel: i64,
    /// GGUF block layout of the packed input; selects which kernel `run` launches.
    pub block_format: GgufBlockFormat,
}
/// Input/output tensors for one [`GgufDequantizePlan::run`] call.
///
/// Shapes are checked by [`GgufDequantizePlan::can_implement`] against the
/// plan's descriptor before any kernel is launched.
pub struct GgufDequantizeArgs<'a> {
    /// Packed quantized bytes; length must be `(numel / block_size) * type_size`.
    pub input: TensorRef<'a, U8, 1>,
    /// Destination for the dequantized values; shape must be `[numel]`.
    pub output: TensorMut<'a, f32, 1>,
}
/// A validated GGUF dequantization plan, created by [`GgufDequantizePlan::select`].
pub struct GgufDequantizePlan {
    // Descriptor that passed validation in `select`.
    desc: GgufDequantizeDescriptor,
    // Kernel identity / precision metadata reported by `sku()` and `precision_guarantee()`.
    sku: KernelSku,
}
impl GgufDequantizePlan {
pub fn select(
_stream: &Stream,
desc: &GgufDequantizeDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.numel < 0 {
return Err(Error::InvalidProblem(
"GgufDequantizePlan: numel must be non-negative",
));
}
let bs = desc.block_format.block_size() as i64;
if desc.numel % bs != 0 {
return Err(Error::InvalidProblem(
"GgufDequantizePlan: numel must be a multiple of the block size",
));
}
Ok(Self {
desc: *desc,
sku: build_sku(desc.block_format, QuantizeKind::GgufDequantize),
})
}
pub fn can_implement(&self, args: &GgufDequantizeArgs<'_>) -> Result<()> {
if args.output.shape != [self.desc.numel as i32] {
return Err(Error::InvalidProblem(
"GgufDequantizePlan: output shape != [numel]",
));
}
let blocks = self.desc.numel / self.desc.block_format.block_size() as i64;
let expected_bytes = blocks * self.desc.block_format.type_size() as i64;
if args.input.shape != [expected_bytes as i32] {
return Err(Error::InvalidProblem(
"GgufDequantizePlan: input byte length != blocks * type_size",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: GgufDequantizeArgs<'_>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.numel == 0 {
return Ok(());
}
let x_ptr = args.input.data.as_raw().0 as *const c_void;
let y_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let numel = self.desc.numel;
let status = unsafe {
match self.desc.block_format {
GgufBlockFormat::Q4_0 => baracuda_kernels_sys::baracuda_kernels_dequantize_q4_0_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q4_1 => baracuda_kernels_sys::baracuda_kernels_dequantize_q4_1_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q5_0 => baracuda_kernels_sys::baracuda_kernels_dequantize_q5_0_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q5_1 => baracuda_kernels_sys::baracuda_kernels_dequantize_q5_1_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q8_0 => baracuda_kernels_sys::baracuda_kernels_dequantize_q8_0_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q2K => baracuda_kernels_sys::baracuda_kernels_dequantize_q2_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q3K => baracuda_kernels_sys::baracuda_kernels_dequantize_q3_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q4K => baracuda_kernels_sys::baracuda_kernels_dequantize_q4_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q5K => baracuda_kernels_sys::baracuda_kernels_dequantize_q5_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q6K => baracuda_kernels_sys::baracuda_kernels_dequantize_q6_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
GgufBlockFormat::Q8K => baracuda_kernels_sys::baracuda_kernels_dequantize_q8_K_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
),
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::GgufDequantizePlan: unsupported block format",
));
}
}
};
map_status(status)
}
}
/// Builds the [`KernelSku`] describing a quantization kernel for `op`.
///
/// The SKU is fixed apart from `op`:
/// - category `Quantization`;
/// - element kind F32, with auxiliary element kind U8;
/// - no layout or epilogue;
/// - `ArchSku::Sm80`, `BackendKind::Bespoke`;
/// - F32 math and accumulation, marked deterministic and bit-stable on the same hardware.
///
/// `_block_format` is currently unused and does not influence the result.
pub(crate) fn build_sku(_block_format: GgufBlockFormat, op: QuantizeKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        deterministic: true,
        bit_stable_on_same_hardware: true,
    };

    KernelSku {
        category: OpCategory::Quantization,
        op: op as u16,
        backend: BackendKind::Bespoke,
        arch: ArchSku::Sm80,
        element: ElementKind::F32,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        precision_guarantee,
    }
}