use crate::error::{Error, Result};
use crate::ops::traits::{Int4GroupSize, KvCacheQuantOps};
use numr::dtype::DType;
use numr::runtime::wgpu::{WgpuClient, WgpuRuntime, get_buffer};
use numr::tensor::Tensor;
use wgpu::BufferUsages;
const QUANT_FP8_SRC: &str = include_str!("../shaders/cache/kv_cache_quant_fp8.wgsl");
const DEQUANT_FP8_SRC: &str = include_str!("../shaders/cache/kv_cache_dequant_fp8.wgsl");
const QUANT_INT4_SRC: &str = include_str!("../shaders/cache/kv_cache_quant_int4.wgsl");
const DEQUANT_INT4_SRC: &str = include_str!("../shaders/cache/kv_cache_dequant_int4.wgsl");
const QUANT_INT8_SRC: &str = include_str!("../shaders/cache/kv_cache_quant_int8.wgsl");
const DEQUANT_INT8_SRC: &str = include_str!("../shaders/cache/kv_cache_dequant_int8.wgsl");
/// Parameter block handed to the KV-cache quantization shaders.
///
/// `#[repr(C)]` together with `bytemuck::Pod` lets the struct be viewed as raw
/// bytes and copied into a uniform buffer (see `create_params_buf`).
/// NOTE(review): the field order presumably has to match the uniform struct
/// declared in the WGSL sources — confirm before reordering.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct QuantParams {
/// Token count forwarded from the calling operation.
num_tokens: u32,
/// Head dimension forwarded from the calling operation.
head_dim: u32,
/// INT4 group size; the FP8 and INT8 callers in this file pass 0.
group_size: u32,
/// Mode selector: this file passes 1 for the FP8/INT8 kernels and 0 for INT4.
/// NOTE(review): the meaning of each value is defined by the shaders — confirm there.
mode: u32,
}
fn validate_f32(t: &Tensor<WgpuRuntime>, op: &str) -> Result<()> {
if t.dtype() != DType::F32 {
return Err(Error::InvalidArgument {
arg: "dtype",
reason: format!("{}: WebGPU requires F32, got {:?}", op, t.dtype()),
});
}
Ok(())
}
/// Allocates a uniform buffer sized for one [`QuantParams`] and queues a write of
/// `params` into it.
///
/// The buffer is created with `UNIFORM | COPY_DST` usage; the contents are
/// uploaded through the client's queue with `write_buffer` at offset 0.
fn create_params_buf(client: &WgpuClient, params: &QuantParams) -> wgpu::Buffer {
    let descriptor = wgpu::BufferDescriptor {
        label: Some("quant_params"),
        size: std::mem::size_of::<QuantParams>() as u64,
        usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    let uniform = client.wgpu_device().create_buffer(&descriptor);

    // View the Pod struct as raw bytes for the upload.
    let raw = bytemuck::bytes_of(params);
    client.wgpu_queue().write_buffer(&uniform, 0, raw);

    uniform
}
/// Records and submits a single compute pass running `entry` from `shader_src`.
///
/// * `bufs` – buffers bound to bind group 0, in the order given.
///   NOTE(review): callers in this file pass the uniform buffer last; the binding
///   order expected by the layout is decided by the pipeline cache — confirm there.
/// * `num_storage` / `num_readonly` – forwarded to the layout key as the storage
///   and read-only-storage buffer counts; the uniform buffer count is fixed at 1.
/// * `workgroups` – workgroup count along X; Y and Z are always 1.
///
/// Module, layout and pipeline are obtained from the client's pipeline cache,
/// keyed by `entry`. The command buffer is submitted to the queue before
/// returning; no wait for completion happens here. The function currently has no
/// failing path and always returns `Ok(())`.
fn dispatch(
    client: &WgpuClient,
    shader_src: &'static str,
    entry: &'static str,
    bufs: &[&wgpu::Buffer],
    num_storage: u32,
    num_readonly: u32,
    workgroups: u32,
) -> Result<()> {
    let pipelines = client.pipeline_cache();
    let layout_key = numr::runtime::wgpu::shaders::LayoutKey {
        num_storage_buffers: num_storage,
        num_uniform_buffers: 1,
        num_readonly_storage: num_readonly,
    };

    let shader = pipelines.get_or_create_module(entry, shader_src);
    let layout = pipelines.get_or_create_layout(layout_key);
    let pipeline = pipelines.get_or_create_pipeline(entry, entry, &shader, &layout);
    let bindings = pipelines.create_bind_group(&layout, bufs);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(entry) };
    let mut encoder = client.wgpu_device().create_command_encoder(&encoder_desc);

    let pass_desc = wgpu::ComputePassDescriptor {
        label: Some(entry),
        timestamp_writes: None,
    };
    let mut pass = encoder.begin_compute_pass(&pass_desc);
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
    // The pass must end before the encoder can be finished.
    drop(pass);

    let commands = encoder.finish();
    client.wgpu_queue().submit(std::iter::once(commands));
    Ok(())
}
impl KvCacheQuantOps<WgpuRuntime> for WgpuClient {
fn quantize_kv_fp8_per_token(
&self,
input: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<(Tensor<WgpuRuntime>, Tensor<WgpuRuntime>)> {
validate_f32(input, "quantize_kv_fp8_per_token")?;
let quantized =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim], DType::F32, input.device());
let scales = Tensor::<WgpuRuntime>::zeros(&[num_tokens], DType::F32, input.device());
let input_buf = get_buffer(input.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "input buffer not found".into(),
})?;
let quant_buf =
get_buffer(quantized.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "quantized buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: 0,
mode: 1,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
QUANT_FP8_SRC,
"quantize_kv_fp8_per_token_f32",
&[&input_buf, &quant_buf, &scales_buf, ¶ms_buf],
3, 1, (num_tokens as u32).div_ceil(256),
)?;
Ok((quantized, scales))
}
fn dequantize_kv_fp8_per_token(
&self,
quantized: &Tensor<WgpuRuntime>,
scales: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
_output_dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
validate_f32(quantized, "dequantize_kv_fp8_per_token")?;
validate_f32(scales, "dequantize_kv_fp8_per_token")?;
let output =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim], DType::F32, quantized.device());
let quant_buf =
get_buffer(quantized.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "quantized buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let out_buf = get_buffer(output.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "output buffer not found".into(),
})?;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: 0,
mode: 1,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
DEQUANT_FP8_SRC,
"dequantize_kv_fp8_per_token_f32",
&[&quant_buf, &scales_buf, &out_buf, ¶ms_buf],
3, 2, (num_tokens as u32).div_ceil(256),
)?;
Ok(output)
}
fn quantize_kv_int4(
&self,
input: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<(
Tensor<WgpuRuntime>,
Tensor<WgpuRuntime>,
Tensor<WgpuRuntime>,
)> {
validate_f32(input, "quantize_kv_int4")?;
let group_sz = group_size as usize;
let num_groups = (num_tokens * head_dim) / group_sz;
let packed =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim / 2], DType::F32, input.device());
let scales = Tensor::<WgpuRuntime>::zeros(&[num_groups], DType::F32, input.device());
let zeros = Tensor::<WgpuRuntime>::zeros(&[num_groups], DType::F32, input.device());
let input_buf = get_buffer(input.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "input buffer not found".into(),
})?;
let packed_buf = get_buffer(packed.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "packed buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let zeros_buf = get_buffer(zeros.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "zeros buffer not found".into(),
})?;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: group_sz as u32,
mode: 0,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
QUANT_INT4_SRC,
"quantize_kv_int4_f32",
&[
&input_buf,
&packed_buf,
&scales_buf,
&zeros_buf,
¶ms_buf,
],
4,
1,
(num_groups as u32).div_ceil(256),
)?;
Ok((packed, scales, zeros))
}
fn dequantize_kv_int4(
&self,
packed: &Tensor<WgpuRuntime>,
scales: &Tensor<WgpuRuntime>,
zeros: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<Tensor<WgpuRuntime>> {
validate_f32(packed, "dequantize_kv_int4")?;
validate_f32(scales, "dequantize_kv_int4")?;
validate_f32(zeros, "dequantize_kv_int4")?;
let output =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim], DType::F32, packed.device());
let packed_buf = get_buffer(packed.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "packed buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let zeros_buf = get_buffer(zeros.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "zeros buffer not found".into(),
})?;
let out_buf = get_buffer(output.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "output buffer not found".into(),
})?;
let group_sz = group_size as usize;
let num_groups = (num_tokens * head_dim) / group_sz;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: group_sz as u32,
mode: 0,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
DEQUANT_INT4_SRC,
"dequantize_kv_int4_f32",
&[&packed_buf, &scales_buf, &zeros_buf, &out_buf, ¶ms_buf],
4,
3,
(num_groups as u32).div_ceil(256),
)?;
Ok(output)
}
fn quantize_kv_int8(
&self,
input: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<(Tensor<WgpuRuntime>, Tensor<WgpuRuntime>)> {
validate_f32(input, "quantize_kv_int8")?;
let quantized =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim], DType::F32, input.device());
let scales = Tensor::<WgpuRuntime>::zeros(&[num_tokens], DType::F32, input.device());
let input_buf = get_buffer(input.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "input buffer not found".into(),
})?;
let quant_buf =
get_buffer(quantized.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "quantized buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: 0,
mode: 1,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
QUANT_INT8_SRC,
"quantize_kv_int8_f32",
&[&input_buf, &quant_buf, &scales_buf, ¶ms_buf],
3,
1,
(num_tokens as u32).div_ceil(256),
)?;
Ok((quantized, scales))
}
fn dequantize_kv_int8(
&self,
quantized: &Tensor<WgpuRuntime>,
scales: &Tensor<WgpuRuntime>,
num_tokens: usize,
head_dim: usize,
) -> Result<Tensor<WgpuRuntime>> {
validate_f32(quantized, "dequantize_kv_int8")?;
validate_f32(scales, "dequantize_kv_int8")?;
let output =
Tensor::<WgpuRuntime>::zeros(&[num_tokens, head_dim], DType::F32, quantized.device());
let quant_buf =
get_buffer(quantized.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "quantized buffer not found".into(),
})?;
let scales_buf = get_buffer(scales.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "scales buffer not found".into(),
})?;
let out_buf = get_buffer(output.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "output buffer not found".into(),
})?;
let params = QuantParams {
num_tokens: num_tokens as u32,
head_dim: head_dim as u32,
group_size: 0,
mode: 1,
};
let params_buf = create_params_buf(self, ¶ms);
dispatch(
self,
DEQUANT_INT8_SRC,
"dequantize_kv_int8_f32",
&[&quant_buf, &scales_buf, &out_buf, ¶ms_buf],
3,
2,
(num_tokens as u32).div_ceil(256),
)?;
Ok(output)
}
}