use candle_core::cuda_backend::CudaDevice;
use cudarc::cublas::CudaBlas;
use cudarc::driver::{CudaSlice, DeviceSlice, LaunchConfig, PushKernelArg};

use crate::ptx;
use crate::weight_store::{GpuQuantWeight, LinearWeight};
pub fn dequant_int4(
device: &CudaDevice,
qw: &GpuQuantWeight,
output: &mut CudaSlice<half::f16>,
) -> candle_core::Result<()> {
let k = qw.k as i32;
let n = qw.n as i32;
let gs = qw.group_size as i32;
if qw.symmetric {
let func = device.get_or_load_custom_func(
"dequant_int4_sym_to_fp16",
"dequant_int4",
ptx::DEQUANT_INT4,
)?;
let qw_v = qw.qweight.slice(..);
let sc_v = qw.scales.slice(..);
let mut b = func.builder();
b.arg(&qw_v);
b.arg(&sc_v);
b.arg(output);
b.arg(&k);
b.arg(&n);
b.arg(&gs);
unsafe {
b.launch(LaunchConfig {
grid_dim: (((qw.n + 255) / 256) as u32, (qw.k / 8) as u32, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})
}
.map(|_| ())
.map_err(|e| candle_core::Error::Msg(format!("dequant_int4_sym: {e}")))?;
} else {
let func = device.get_or_load_custom_func(
"dequant_int4_to_fp16",
"dequant_int4",
ptx::DEQUANT_INT4,
)?;
let qw_v = qw.qweight.slice(..);
let sc_v = qw.scales.slice(..);
let qz_v = qw
.qzeros
.as_ref()
.ok_or_else(|| candle_core::Error::Msg("non-symmetric quant requires qzeros".into()))?
.slice(..);
let mut b = func.builder();
b.arg(&qw_v);
b.arg(&sc_v);
b.arg(&qz_v);
b.arg(output);
b.arg(&k);
b.arg(&n);
b.arg(&gs);
unsafe {
b.launch(LaunchConfig {
grid_dim: (((qw.n + 255) / 256) as u32, (qw.k / 8) as u32, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})
}
.map(|_| ())
.map_err(|e| candle_core::Error::Msg(format!("dequant_int4: {e}")))?;
}
Ok(())
}
pub fn linear_dispatch(
blas: &CudaBlas,
device: &CudaDevice,
input: &CudaSlice<half::f16>,
weight: &LinearWeight,
output: &mut CudaSlice<half::f16>,
temp_fp16: &mut CudaSlice<half::f16>,
m: i32,
n: i32,
k: i32,
) -> candle_core::Result<()> {
match weight {
LinearWeight::Fp16(w) => crate::cublas::linear_f16(blas, input, &w.slice, output, m, n, k),
LinearWeight::Int4(qw) => {
dequant_int4(device, qw, temp_fp16)?;
crate::cublas::linear_f16(blas, input, temp_fp16, output, m, n, k)
}
LinearWeight::Marlin(_) => {
Err(candle_core::Error::Msg(
"Marlin should be dispatched via CudaDecodeRunner::linear()".into(),
))
}
}
}