use crate::error::{Error, Result};
use crate::ops::traits::cache::kv_cache_quant::{Int4GroupSize, KvCacheQuantOps};
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::LaunchConfig;
use numr::autograd::Var;
use numr::dtype::DType;
use numr::runtime::Device;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
use crate::ops::cuda::kernels::{
self, KV_CACHE_FP8_MODULE, KV_CACHE_INT4_MODULE, KV_CACHE_QUANT_MODULE,
};
impl KvCacheQuantOps<CudaRuntime> for CudaClient {
fn quantize_kv_fp8_per_token(
&self,
input: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<(Tensor<CudaRuntime>, Tensor<CudaRuntime>)> {
let input_to_use = if input.dtype() == DType::F32 {
let vars = Var::new(input.clone(), false);
let cast_var =
numr::autograd::var_cast(&vars, DType::F16, self).map_err(Error::Numr)?;
cast_var.tensor().clone()
} else {
input.clone()
};
let dtype = input_to_use.dtype();
let kernel_name = match dtype {
DType::F16 => "quantize_kv_fp8_per_token_fp16",
DType::BF16 => "quantize_kv_fp8_per_token_bf16",
_ => {
return Err(Error::KernelError {
reason: format!("FP8 quant: unsupported input dtype {dtype:?}, need F16/BF16"),
});
}
};
let device = input.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_FP8_MODULE)?;
let func = kernels::get_kernel_function(&module, kernel_name)?;
let quantized = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim], DType::U8, device);
let scales = Tensor::<CudaRuntime>::empty(&[num_tokens], DType::F32, device);
let cfg = LaunchConfig {
grid_dim: (num_tokens as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
let q_ptr = quantized.ptr();
let i_ptr = input_to_use.ptr();
let s_ptr = scales.ptr();
let batch_i32 = 1i32; let nkh_i32 = 1i32;
let sl_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&q_ptr);
builder.arg(&i_ptr);
builder.arg(&s_ptr);
builder.arg(&batch_i32);
builder.arg(&nkh_i32);
builder.arg(&sl_i32);
builder.arg(&hd_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("FP8 per-token quant failed: {e:?}"),
})?;
}
Ok((quantized, scales))
}
fn dequantize_kv_fp8_per_token(
&self,
quantized: &Tensor<CudaRuntime>,
scales: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
output_dtype: DType,
) -> Result<Tensor<CudaRuntime>> {
let target_dtype = match output_dtype {
DType::F32 => DType::F16, other => other,
};
let kernel_name = match target_dtype {
DType::F16 => "dequantize_kv_fp8_per_token_fp16",
DType::BF16 => "dequantize_kv_fp8_per_token_bf16",
_ => {
return Err(Error::KernelError {
reason: format!("FP8 dequant: unsupported output dtype {target_dtype:?}"),
});
}
};
let device = quantized.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_FP8_MODULE)?;
let func = kernels::get_kernel_function(&module, kernel_name)?;
let output = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim], target_dtype, device);
let cfg = LaunchConfig {
grid_dim: (num_tokens as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let o_ptr = output.ptr();
let q_ptr = quantized.ptr();
let s_ptr = scales.ptr();
let batch_i32 = 1i32;
let nkh_i32 = 1i32;
let sl_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&o_ptr);
builder.arg(&q_ptr);
builder.arg(&s_ptr);
builder.arg(&batch_i32);
builder.arg(&nkh_i32);
builder.arg(&sl_i32);
builder.arg(&hd_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("FP8 per-token dequant failed: {e:?}"),
})?;
}
if output_dtype == DType::F32 && target_dtype == DType::F16 {
let output_var = Var::new(output, false);
let cast_var =
numr::autograd::var_cast(&output_var, DType::F32, self).map_err(Error::Numr)?;
Ok(cast_var.tensor().clone())
} else {
Ok(output)
}
}
fn quantize_kv_int4(
&self,
input: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<(
Tensor<CudaRuntime>,
Tensor<CudaRuntime>,
Tensor<CudaRuntime>,
)> {
let dtype = input.dtype();
let dtype_suffix = match dtype {
DType::F32 => "fp32",
DType::F16 => "fp16",
DType::BF16 => "bf16",
_ => {
return Err(Error::KernelError {
reason: format!("INT4 quant: unsupported dtype {dtype:?}"),
});
}
};
let kernel_name = format!("quantize_kv_int4_per_group_{dtype_suffix}");
let device = input.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_INT4_MODULE)?;
let func = kernels::get_kernel_function(&module, &kernel_name)?;
let gs = group_size as usize;
let total = num_tokens * head_dim;
let num_groups = total.div_ceil(gs);
let packed = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim / 2], DType::U8, device);
let scales_t = Tensor::<CudaRuntime>::empty(&[num_groups], DType::F32, device);
let zeros_t = Tensor::<CudaRuntime>::empty(&[num_groups], DType::F32, device);
let cfg = LaunchConfig {
grid_dim: (num_groups as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
let i_ptr = input.ptr();
let p_ptr = packed.ptr();
let s_ptr = scales_t.ptr();
let z_ptr = zeros_t.ptr();
let nt_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
let gs_i32 = gs as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&i_ptr);
builder.arg(&p_ptr);
builder.arg(&s_ptr);
builder.arg(&z_ptr);
builder.arg(&nt_i32);
builder.arg(&hd_i32);
builder.arg(&gs_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("INT4 quant failed: {e:?}"),
})?;
}
Ok((packed, scales_t, zeros_t))
}
fn dequantize_kv_int4(
&self,
packed: &Tensor<CudaRuntime>,
scales: &Tensor<CudaRuntime>,
zeros: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<Tensor<CudaRuntime>> {
let device = packed.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_INT4_MODULE)?;
let func = kernels::get_kernel_function(&module, "dequantize_kv_int4_per_group_fp32")?;
let gs = group_size as usize;
let total = num_tokens * head_dim;
let num_groups = total.div_ceil(gs);
let output = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim], DType::F32, device);
let cfg = LaunchConfig {
grid_dim: (num_groups as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let p_ptr = packed.ptr();
let o_ptr = output.ptr();
let s_ptr = scales.ptr();
let z_ptr = zeros.ptr();
let nt_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
let gs_i32 = gs as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&p_ptr);
builder.arg(&s_ptr);
builder.arg(&z_ptr);
builder.arg(&o_ptr);
builder.arg(&nt_i32);
builder.arg(&hd_i32);
builder.arg(&gs_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("INT4 dequant failed: {e:?}"),
})?;
}
Ok(output)
}
fn quantize_kv_int8(
&self,
input: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<(Tensor<CudaRuntime>, Tensor<CudaRuntime>)> {
let dtype = input.dtype();
let dtype_suffix = match dtype {
DType::F32 => "fp32",
DType::F16 => "fp16",
DType::BF16 => "bf16",
_ => {
return Err(Error::KernelError {
reason: format!("INT8 quant: unsupported dtype {dtype:?}"),
});
}
};
let kernel_name = format!("quantize_kv_int8_per_token_{dtype_suffix}");
let device = input.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_QUANT_MODULE)?;
let func = kernels::get_kernel_function(&module, &kernel_name)?;
let quantized = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim], DType::I8, device);
let scales = Tensor::<CudaRuntime>::empty(&[num_tokens], DType::F32, device);
let cfg = LaunchConfig {
grid_dim: (num_tokens as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
let i_ptr = input.ptr();
let q_ptr = quantized.ptr();
let s_ptr = scales.ptr();
let nt_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&i_ptr);
builder.arg(&q_ptr);
builder.arg(&s_ptr);
builder.arg(&nt_i32);
builder.arg(&hd_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("INT8 quant failed: {e:?}"),
})?;
}
Ok((quantized, scales))
}
fn dequantize_kv_int8(
&self,
quantized: &Tensor<CudaRuntime>,
scales: &Tensor<CudaRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<Tensor<CudaRuntime>> {
let device = quantized.device();
let device_index = device.id();
let module =
kernels::get_or_load_module(self.context(), device_index, KV_CACHE_QUANT_MODULE)?;
let func = kernels::get_kernel_function(&module, "dequantize_kv_int8_per_token_fp32")?;
let output = Tensor::<CudaRuntime>::empty(&[num_tokens, head_dim], DType::F32, device);
let cfg = LaunchConfig {
grid_dim: (num_tokens as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let q_ptr = quantized.ptr();
let o_ptr = output.ptr();
let s_ptr = scales.ptr();
let nt_i32 = num_tokens as i32;
let hd_i32 = head_dim as i32;
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&q_ptr);
builder.arg(&o_ptr);
builder.arg(&s_ptr);
builder.arg(&nt_i32);
builder.arg(&hd_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("INT8 dequant failed: {e:?}"),
})?;
}
Ok(output)
}
}