use candle_core::cuda_backend::CudaStorage;
use candle_core::{op::BackpropOp, DType, Storage, Tensor};
use cudarc::driver::PushKernelArg;
use crate::ptx;
const MODULE_NAME: &str = "decode_attention";
pub fn decode_attention(
q: &Tensor,
k_cache: &Tensor,
v_cache: &Tensor,
num_q_heads: usize,
num_kv_heads: usize,
head_dim: usize,
max_kv_len: usize,
valid_kv_len: usize,
scale: f32,
) -> candle_core::Result<Tensor> {
let dtype = q.dtype();
if dtype != DType::F16 {
candle_core::bail!("decode_attention: only F16 supported, got {dtype:?}");
}
let cuda_dev = q.device().as_cuda_device()?;
let func = cuda_dev.get_or_load_custom_func(
"decode_attention_f16",
MODULE_NAME,
ptx::DECODE_ATTENTION,
)?;
let block_size = 256u32;
let grid_size = num_q_heads as u32;
let shared_mem = (max_kv_len as u32) * 4;
let num_q_heads_i32 = num_q_heads as i32;
let num_kv_heads_i32 = num_kv_heads as i32;
let head_dim_i32 = head_dim as i32;
let max_kv_len_i32 = max_kv_len as i32;
let valid_kv_len_i32 = valid_kv_len as i32;
let (q_s, q_l) = q.storage_and_layout();
let (k_s, k_l) = k_cache.storage_and_layout();
let (v_s, v_l) = v_cache.storage_and_layout();
let q_cuda = match &*q_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("q must be on CUDA"),
};
let k_cuda = match &*k_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("k_cache must be on CUDA"),
};
let v_cuda = match &*v_s {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("v_cache must be on CUDA"),
};
let q_in = q_cuda.as_cuda_slice::<half::f16>()?;
let k_in = k_cuda.as_cuda_slice::<half::f16>()?;
let v_in = v_cuda.as_cuda_slice::<half::f16>()?;
let out = unsafe { cuda_dev.alloc::<half::f16>(num_q_heads * head_dim)? };
let q_in = q_in.slice(q_l.start_offset()..);
let k_in = k_in.slice(k_l.start_offset()..);
let v_in = v_in.slice(v_l.start_offset()..);
let mut builder = func.builder();
builder.arg(&q_in);
builder.arg(&k_in);
builder.arg(&v_in);
builder.arg(&out);
builder.arg(&num_q_heads_i32);
builder.arg(&num_kv_heads_i32);
builder.arg(&head_dim_i32);
builder.arg(&max_kv_len_i32);
builder.arg(&valid_kv_len_i32);
builder.arg(&scale);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: shared_mem,
};
unsafe { builder.launch(cfg) }
.map_err(|e| candle_core::Error::Msg(format!("decode_attention kernel launch: {e}")))?;
drop(q_s);
drop(k_s);
drop(v_s);
let output_storage = CudaStorage::wrap_cuda_slice(out, cuda_dev.clone());
Ok(Tensor::from_storage(
Storage::Cuda(output_storage),
(num_q_heads, head_dim),
BackpropOp::none(),
false,
))
}