use cudarc::driver::{CudaContext, CudaSlice, CudaView, LaunchConfig, PushKernelArg};
use std::sync::Arc;
use crate::ptx;
#[allow(clippy::too_many_arguments)]
pub fn launch_int8_paged_decode_attention(
ctx: &Arc<CudaContext>,
q: &CudaSlice<half::f16>,
k_pool: &CudaSlice<i8>,
v_pool: &CudaSlice<i8>,
k_scales_pool: &CudaSlice<half::f16>,
v_scales_pool: &CudaSlice<half::f16>,
block_table: &CudaView<'_, i32>,
output: &mut CudaSlice<half::f16>,
num_q_heads: usize,
num_kv_heads: usize,
head_dim: usize,
valid_kv_len: usize,
block_size: usize,
scale: f32,
) -> std::result::Result<(), String> {
let stream = ctx.default_stream();
let func = stream
.context()
.load_module(ptx::INT8_PAGED_DECODE_ATTENTION.into())
.map_err(|e| format!("load int8_paged_decode_attention module: {e}"))?
.load_function("paged_decode_attention_int8")
.map_err(|e| format!("load paged_decode_attention_int8 func: {e}"))?;
let nq = num_q_heads as i32;
let nkv = num_kv_heads as i32;
let hd = head_dim as i32;
let kvl = valid_kv_len as i32;
let bs = block_size as i32;
let mut b = stream.launch_builder(&func);
b.arg(q);
b.arg(k_pool);
b.arg(v_pool);
b.arg(k_scales_pool);
b.arg(v_scales_pool);
b.arg(block_table);
b.arg(output);
b.arg(&nq);
b.arg(&nkv);
b.arg(&hd);
b.arg(&kvl);
b.arg(&bs);
b.arg(&scale);
let shared_bytes = (valid_kv_len as u32) * 12;
let cfg = LaunchConfig {
grid_dim: (num_q_heads as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: shared_bytes,
};
unsafe { b.launch(cfg) }
.map(|_| ())
.map_err(|e| format!("int8_paged_decode_attention launch: {e}"))
}
#[allow(clippy::too_many_arguments)]
pub fn launch_int8_kv_cache_append(
ctx: &Arc<CudaContext>,
k_in: &CudaSlice<half::f16>,
v_in: &CudaSlice<half::f16>,
k_out_pool: &mut CudaSlice<i8>,
v_out_pool: &mut CudaSlice<i8>,
k_scales_pool: &mut CudaSlice<half::f16>,
v_scales_pool: &mut CudaSlice<half::f16>,
slot_mapping: &CudaSlice<i32>,
num_tokens: usize,
num_kv_heads: usize,
head_dim: usize,
) -> std::result::Result<(), String> {
if head_dim > 256 {
return Err(format!(
"int8_kv_cache_append: head_dim {head_dim} > 256 (kernel uses one thread per element)"
));
}
let stream = ctx.default_stream();
let func = stream
.context()
.load_module(ptx::INT8_PAGED_DECODE_ATTENTION.into())
.map_err(|e| format!("load int8_paged_decode_attention module: {e}"))?
.load_function("int8_kv_cache_append")
.map_err(|e| format!("load int8_kv_cache_append func: {e}"))?;
let nkv = num_kv_heads as i32;
let hd = head_dim as i32;
let nt = num_tokens as i32;
let mut b = stream.launch_builder(&func);
b.arg(k_in);
b.arg(v_in);
b.arg(&mut *k_out_pool);
b.arg(&mut *v_out_pool);
b.arg(&mut *k_scales_pool);
b.arg(&mut *v_scales_pool);
b.arg(slot_mapping);
b.arg(&nkv);
b.arg(&hd);
b.arg(&nt);
let cfg = LaunchConfig {
grid_dim: (num_tokens as u32, num_kv_heads as u32, 1),
block_dim: (head_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { b.launch(cfg) }
.map(|_| ())
.map_err(|e| format!("int8_kv_cache_append launch: {e}"))
}