use cudarc::driver::{CudaFunction, CudaSlice};
use std::sync::{Arc, Mutex, OnceLock};
use crate::gpu_backend::cuda_full_layer::CudaKvCache;
use crate::gpu_backend::cuda_graph::{compile_or_load_ptx, CudaGraph, CudaGraphError};
use crate::gpu_backend::cuda_k_quant_prefill_kernels::CUDA_K_QUANT_PREFILL_KERNELS_SRC;
use crate::gpu_backend::cuda_prefill::CudaPrefillBuffers;
/// Identifies a k-quant weight format handled by the k-quant prefill path.
///
/// Variant names mirror the kernel suffixes in `CudaKQuantPrefillModules`
/// (`Q2K` ↔ `*_q2k`, …, `Q8K` ↔ `*_q8k`). `Hash` is derived so the format can be
/// used as a key in maps/sets (e.g. per-format caches).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KQuantFormat {
    /// Format served by the `*_q2k` kernels.
    Q2K,
    /// Format served by the `*_q3k` kernels.
    Q3K,
    /// Format served by the `*_q4k` kernels.
    Q4K,
    /// Format served by the `*_q5k` kernels.
    Q5K,
    /// Format served by the `*_q6k` kernels.
    Q6K,
    /// Format served by the `*_q8k` kernels.
    Q8K,
}
/// Kernel handles resolved from the k-quant prefill PTX module
/// (`CUDA_K_QUANT_PREFILL_KERNELS_SRC`) by `init_k_quant_prefill_modules`.
///
/// For every `KQuantFormat` there are three entry points: `gemm_<fmt>`,
/// `gemm_<fmt>_residual` and `fused_gate_up_swiglu_gemm_<fmt>`.
/// NOTE(review): the exact contract of each kernel (argument layout, what the
/// `_residual` variant accumulates into) is defined in the kernel source, not here —
/// confirm there before relying on it.
pub struct CudaKQuantPrefillModules {
    // Q2K kernels.
    pub gemm_q2k: CudaFunction,
    pub gemm_q2k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q2k: CudaFunction,
    // Q3K kernels.
    pub gemm_q3k: CudaFunction,
    pub gemm_q3k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q3k: CudaFunction,
    // Q4K kernels.
    pub gemm_q4k: CudaFunction,
    pub gemm_q4k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q4k: CudaFunction,
    // Q5K kernels.
    pub gemm_q5k: CudaFunction,
    pub gemm_q5k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q5k: CudaFunction,
    // Q6K kernels.
    pub gemm_q6k: CudaFunction,
    pub gemm_q6k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q6k: CudaFunction,
    // Q8K kernels.
    pub gemm_q8k: CudaFunction,
    pub gemm_q8k_residual: CudaFunction,
    pub fused_gate_up_swiglu_gemm_q8k: CudaFunction,
}
// SAFETY: the struct only holds `CudaFunction` handles. It is built once in
// `init_k_quant_prefill_modules` and then shared read-only through `Arc`; there is no
// interior mutability here.
// NOTE(review): soundness still depends on cudarc allowing a loaded `CudaFunction` to be
// used from any thread. If the cudarc version in use already implements `Send`/`Sync` for
// `CudaFunction`, these impls are redundant — confirm against cudarc.
unsafe impl Send for CudaKQuantPrefillModules {}
unsafe impl Sync for CudaKQuantPrefillModules {}
/// Process-wide cache backing the k-quant prefill path.
///
/// Each slot has its own `Mutex`, so the modules, scratch buffers, KV cache and logits
/// buffer are locked independently. Every slot starts as `None` and is filled lazily by
/// `init_k_quant_prefill_modules` or the matching `acquire_k_quant_*` function.
struct CudaKQuantPrefillState {
    /// Loaded kernel handles, handed out as cheap `Arc` clones.
    kquant_modules: Mutex<Option<Arc<CudaKQuantPrefillModules>>>,
    /// Device scratch buffers managed by `acquire_k_quant_prefill_buffers`.
    prefill_buffers: Mutex<Option<CudaPrefillBuffers>>,
    /// K/V cache managed by `acquire_k_quant_kv_cache`.
    kv_cache: Mutex<Option<CudaKvCache>>,
    /// Logits device buffer paired with the element count it was allocated with.
    logits_buf: Mutex<Option<(CudaSlice<f32>, usize)>>,
}
// SAFETY: every field sits behind a `Mutex`, so shared access always goes through the lock.
// NOTE(review): `Mutex<T>` is only auto-`Sync` when `T: Send`. These impls assert that the
// contained CUDA handles (`CudaSlice`, `CudaPrefillBuffers`, `CudaKvCache`) may be moved
// between threads — confirm against cudarc's threading guarantees.
unsafe impl Send for CudaKQuantPrefillState {}
unsafe impl Sync for CudaKQuantPrefillState {}
/// Lazily created singleton returned by `k_quant_prefill_state`.
static K_QUANT_PREFILL_STATE: OnceLock<CudaKQuantPrefillState> = OnceLock::new();
/// Returns the process-wide k-quant prefill state.
///
/// The state is created on first call with every slot empty (`None`). All later calls
/// return the same `'static` reference.
fn k_quant_prefill_state() -> &'static CudaKQuantPrefillState {
    /// Builds a state whose modules, buffers, KV cache and logits slots are all unset.
    fn empty_state() -> CudaKQuantPrefillState {
        CudaKQuantPrefillState {
            kquant_modules: Mutex::new(None),
            prefill_buffers: Mutex::new(None),
            kv_cache: Mutex::new(None),
            logits_buf: Mutex::new(None),
        }
    }
    K_QUANT_PREFILL_STATE.get_or_init(empty_state)
}
/// Returns the k-quant prefill kernels, compiling and loading them on first use.
///
/// The first successful call does three things:
/// - obtains PTX for `CUDA_K_QUANT_PREFILL_KERNELS_SRC` via `compile_or_load_ptx`;
/// - loads that PTX into `graph`'s context;
/// - resolves every kernel entry point.
///
/// The result is cached process-wide, and later calls return a clone of the same `Arc`.
/// The modules lock is held for the whole load, so concurrent first calls compile only once.
///
/// NOTE(review): once cached, the modules are returned regardless of which `graph` is
/// passed, so they stay bound to the context of the first successful caller — confirm
/// that only one CUDA context is ever used here.
///
/// # Errors
/// - `CudaGraphError::LockPoisoned` if the modules mutex is poisoned.
/// - Any error returned by `compile_or_load_ptx`.
/// - `CudaGraphError::DriverError` if loading the module or any kernel fails. Nothing is
///   cached in that case, so a later call retries from scratch.
pub fn init_k_quant_prefill_modules(
    graph: &CudaGraph,
) -> Result<Arc<CudaKQuantPrefillModules>, CudaGraphError> {
    let mut cached = k_quant_prefill_state()
        .kquant_modules
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    if let Some(existing) = cached.as_ref() {
        return Ok(Arc::clone(existing));
    }

    let ptx = compile_or_load_ptx(CUDA_K_QUANT_PREFILL_KERNELS_SRC, "k_quant_prefill_kernels")?;
    let context = graph.context_arc();
    let module = context
        .load_module(ptx)
        .map_err(|e| CudaGraphError::DriverError(format!("load_module k_quant_prefill: {e}")))?;
    let kernel = |name: &str| -> Result<CudaFunction, CudaGraphError> {
        module
            .load_function(name)
            .map_err(|e| CudaGraphError::DriverError(format!("load_function({name}): {e}")))
    };

    // Entry points are resolved in field order; the first missing one aborts the whole load.
    let loaded = CudaKQuantPrefillModules {
        gemm_q2k: kernel("gemm_q2k")?,
        gemm_q2k_residual: kernel("gemm_q2k_residual")?,
        fused_gate_up_swiglu_gemm_q2k: kernel("fused_gate_up_swiglu_gemm_q2k")?,
        gemm_q3k: kernel("gemm_q3k")?,
        gemm_q3k_residual: kernel("gemm_q3k_residual")?,
        fused_gate_up_swiglu_gemm_q3k: kernel("fused_gate_up_swiglu_gemm_q3k")?,
        gemm_q4k: kernel("gemm_q4k")?,
        gemm_q4k_residual: kernel("gemm_q4k_residual")?,
        fused_gate_up_swiglu_gemm_q4k: kernel("fused_gate_up_swiglu_gemm_q4k")?,
        gemm_q5k: kernel("gemm_q5k")?,
        gemm_q5k_residual: kernel("gemm_q5k_residual")?,
        fused_gate_up_swiglu_gemm_q5k: kernel("fused_gate_up_swiglu_gemm_q5k")?,
        gemm_q6k: kernel("gemm_q6k")?,
        gemm_q6k_residual: kernel("gemm_q6k_residual")?,
        fused_gate_up_swiglu_gemm_q6k: kernel("fused_gate_up_swiglu_gemm_q6k")?,
        gemm_q8k: kernel("gemm_q8k")?,
        gemm_q8k_residual: kernel("gemm_q8k_residual")?,
        fused_gate_up_swiglu_gemm_q8k: kernel("fused_gate_up_swiglu_gemm_q8k")?,
    };
    Ok(Arc::clone(cached.insert(Arc::new(loaded))))
}
/// Per-layer inputs for the k-quant prefill path, borrowed for the duration of a call.
///
/// Each weight appears as a `*_handle: u64` paired with a borrowed host slice
/// (`*_bytes`). Norm weights are `f32` slices; quantized matrices are raw `u8` slices in
/// the layout given by `format`.
/// NOTE(review): presumably the handle is a cache key identifying the weight's device
/// copy and the slice is its host contents (uploaded on cache miss) — confirm with the
/// callers that consume these params.
pub struct CudaKQuantPrefillLayerParams<'a> {
    /// K-quant format shared by the quantized matrices below.
    pub format: KQuantFormat,
    /// Attention input norm weights.
    pub attn_norm_handle: u64,
    pub attn_norm_bytes: &'a [f32],
    /// Fused Q/K/V projection weights (quantized).
    pub fused_qkv_handle: u64,
    pub fused_qkv_bytes: &'a [u8],
    /// Q norm weights.
    pub q_norm_handle: u64,
    pub q_norm_bytes: &'a [f32],
    /// K norm weights.
    pub k_norm_handle: u64,
    pub k_norm_bytes: &'a [f32],
    /// Attention output projection weights (quantized).
    pub attn_proj_handle: u64,
    pub attn_proj_bytes: &'a [u8],
    /// FFN input norm weights.
    pub ffn_norm_handle: u64,
    pub ffn_norm_bytes: &'a [f32],
    /// Single handle covering both the gate and up matrices below.
    /// NOTE(review): presumably the two slices are uploaded together under this one
    /// handle — confirm with callers.
    pub gate_up_handle: u64,
    pub gate_bytes: &'a [u8],
    pub up_bytes: &'a [u8],
    /// FFN down projection weights (quantized).
    pub down_handle: u64,
    pub down_bytes: &'a [u8],
}
/// Rounds `n` up to the nearest power of two, used as a buffer row capacity.
///
/// Returns `1` for `n == 0`, and `n` itself when it is already a power of two.
///
/// # Panics
/// Panics if no power of two `>= n` fits in `usize` (i.e. `n > 2^(usize::BITS - 1)`).
/// The previous hand-rolled doubling loop silently wrapped to `0` and spun forever in
/// that case.
fn next_pow2_cap(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_pow2_cap: no power of two >= n fits in usize")
}
/// Locks the process-wide k-quant prefill scratch buffers, (re)allocating them if needed.
///
/// If a buffer set exists and `CudaPrefillBuffers::matches` accepts the requested
/// dimensions, it is reused and only `actual_batch_size` is updated. Otherwise:
/// 1. any previous set is dropped;
/// 2. a fresh zero-filled set is allocated on `graph`'s stream;
/// 3. its row capacity is `next_pow2_cap(batch_size)`.
///
/// The returned guard holds the lock and is always `Some` on success.
///
/// # Errors
/// - `CudaGraphError::LockPoisoned` if the buffer mutex is poisoned.
/// - `CudaGraphError::DriverError` if a buffer's element count overflows `usize` or a
///   device allocation fails. On failure the slot is left `None`, so the next call
///   allocates again.
#[allow(clippy::too_many_arguments)] // mirrors CudaPrefillBuffers::matches' dimension list
pub(super) fn acquire_k_quant_prefill_buffers(
    graph: &CudaGraph,
    batch_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    max_seq: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaPrefillBuffers>>, CudaGraphError> {
    /// Element count of a buffer as the product of `factors`. It is checked so that
    /// oversized dimensions surface as an error instead of wrapping to a too-small
    /// allocation.
    fn elems(factors: &[usize]) -> Result<usize, CudaGraphError> {
        factors
            .iter()
            .try_fold(1usize, |acc, &f| acc.checked_mul(f))
            .ok_or_else(|| CudaGraphError::DriverError(format!("kqpb size overflow: {factors:?}")))
    }

    let state = k_quant_prefill_state();
    let mut guard = state
        .prefill_buffers
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    let reused = match guard.as_mut() {
        Some(bufs)
            if bufs.matches(
                batch_size,
                hidden_size,
                intermediate_size,
                nq,
                nkv,
                head_dim,
                max_seq,
            ) =>
        {
            bufs.actual_batch_size = batch_size;
            true
        }
        _ => false,
    };
    if reused {
        return Ok(guard);
    }

    // Drop the stale set before allocating its replacement, so the old and new buffers
    // never occupy device memory at the same time.
    *guard = None;

    let capacity = next_pow2_cap(batch_size);
    let qkv_heads = nkv
        .checked_mul(2)
        .and_then(|kv| kv.checked_add(nq))
        .ok_or_else(|| {
            CudaGraphError::DriverError(format!("kqpb size overflow: nq={nq} nkv={nkv}"))
        })?;
    let alloc = |n: usize| -> Result<CudaSlice<f32>, CudaGraphError> {
        graph
            .stream_arc()
            .alloc_zeros::<f32>(n)
            .map_err(|e| CudaGraphError::DriverError(format!("alloc_zeros kqpb({n}): {e}")))
    };
    *guard = Some(CudaPrefillBuffers {
        d_input: alloc(elems(&[capacity, hidden_size])?)?,
        d_normed: alloc(elems(&[capacity, hidden_size])?)?,
        d_qkv: alloc(elems(&[capacity, qkv_heads, head_dim])?)?,
        d_attn_out: alloc(elems(&[capacity, nq, head_dim])?)?,
        d_gate_up: alloc(elems(&[2, capacity, intermediate_size])?)?,
        d_swiglu: alloc(elems(&[capacity, intermediate_size])?)?,
        capacity,
        actual_batch_size: batch_size,
        hidden_size,
        intermediate_size,
        nq,
        nkv,
        head_dim,
        max_seq,
    });
    Ok(guard)
}
/// Locks the process-wide k-quant KV cache, (re)allocating it when the geometry changes.
///
/// An existing cache is reused unchanged when `CudaKvCache::matches` accepts
/// `(n_layers, n_kv, max_seq, head_dim)`. Otherwise the previous cache is released first.
/// Then zero-filled `k_cache` and `v_cache` buffers are allocated on `graph`'s stream,
/// each holding `n_layers * n_kv * max_seq * head_dim` `u16` elements.
///
/// The returned guard holds the lock and is always `Some` on success.
///
/// # Errors
/// - `CudaGraphError::LockPoisoned` if the KV-cache mutex is poisoned.
/// - `CudaGraphError::DriverError` if the element count overflows `usize` or either
///   allocation fails. On failure the slot is left `None`, so the next call allocates
///   again.
pub(super) fn acquire_k_quant_kv_cache(
    graph: &CudaGraph,
    n_layers: usize,
    n_kv: usize,
    max_seq: usize,
    head_dim: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaKvCache>>, CudaGraphError> {
    let state = k_quant_prefill_state();
    let mut guard = state
        .kv_cache
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    if guard
        .as_ref()
        .is_some_and(|c| c.matches(n_layers, n_kv, max_seq, head_dim))
    {
        return Ok(guard);
    }

    // Release the stale K/V pair before allocating the new one, so both pairs never
    // occupy device memory at the same time.
    *guard = None;

    // Dimensions come from model metadata; refuse sizes that overflow rather than
    // silently wrapping to an undersized allocation.
    let total = [n_layers, n_kv, max_seq, head_dim]
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or_else(|| {
            CudaGraphError::DriverError(format!(
                "kv cache kquant size overflow: {n_layers}x{n_kv}x{max_seq}x{head_dim}"
            ))
        })?;
    let k_cache = graph
        .stream_arc()
        .alloc_zeros::<u16>(total)
        .map_err(|e| CudaGraphError::DriverError(format!("alloc kv k_cache kquant: {e}")))?;
    let v_cache = graph
        .stream_arc()
        .alloc_zeros::<u16>(total)
        .map_err(|e| CudaGraphError::DriverError(format!("alloc kv v_cache kquant: {e}")))?;
    *guard = Some(CudaKvCache {
        k_cache,
        v_cache,
        n_layers,
        n_kv,
        max_seq,
        head_dim,
    });
    Ok(guard)
}
/// Lock guard over the shared logits slot: the device buffer and its element count.
pub(super) type KQuantLogitsGuard = std::sync::MutexGuard<'static, Option<(CudaSlice<f32>, usize)>>;
/// Locks the process-wide logits buffer, reallocating it unless it holds exactly `n` elements.
///
/// Any change in `n`, larger or smaller, allocates a fresh zero-filled buffer on
/// `graph`'s stream and replaces the old one. The returned guard holds the lock and is
/// always `Some((buffer, n))` on success.
///
/// # Errors
/// - `CudaGraphError::LockPoisoned` if the logits mutex is poisoned.
/// - `CudaGraphError::DriverError` if the device allocation fails. The previous buffer,
///   if any, is left in place.
pub(super) fn acquire_k_quant_logits(
    graph: &CudaGraph,
    n: usize,
) -> Result<KQuantLogitsGuard, CudaGraphError> {
    let mut logits = k_quant_prefill_state()
        .logits_buf
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    let up_to_date = matches!(logits.as_ref(), Some((_, len)) if *len == n);
    if !up_to_date {
        let fresh = graph
            .stream_arc()
            .alloc_zeros::<f32>(n)
            .map_err(|e| CudaGraphError::DriverError(format!("alloc logits kquant({n}): {e}")))?;
        *logits = Some((fresh, n));
    }
    Ok(logits)
}