impl CudaExecutor {
/// PAR-118: Flash Decoding attention for one layer over a batch of `m`
/// decode-step queries.
///
/// Steps, all enqueued on `self.stream` (this function never
/// synchronizes, so `out_batched` is only valid after a later sync —
/// NOTE(review): assumes stream ordering serializes the scatter
/// launches, async copies, and attention launches; confirm
/// `copy_from_host_async` enqueues on the same stream):
/// 1. Scatter each sequence's new K and V row (`kv_dim` f32s apiece,
///    taken from `k_batched`/`v_batched`) into that sequence's region of
///    the per-layer batched KV caches at position `positions[seq_idx]`.
/// 2. Upload per-sequence cache pointers and lengths into preallocated
///    GPU scratch buffers.
/// 3. Launch the flash-decoding chunk kernel (one partial result per KV
///    chunk) then the reduce kernel that folds partials into `out_batched`.
///
/// # Errors
/// `GpuError::InvalidLaunchConfig` when flash decoding was not
/// initialized or a per-layer cache / scratch buffer is missing;
/// otherwise propagates PTX compilation, launch, and async-copy errors.
#[allow(clippy::too_many_arguments)]
pub fn flash_decoding_attention_into(
&mut self,
layer_idx: usize,
q_batched: &GpuBuffer<f32>,
k_batched: &GpuBuffer<f32>,
v_batched: &GpuBuffer<f32>,
out_batched: &GpuBuffer<f32>,
m: usize,
positions: &[u32],
) -> Result<(), GpuError> {
use trueno_gpu::kernels::{
FlashDecodingChunkKernel, FlashDecodingReduceKernel, Kernel, FLASH_DECODE_CHUNK_SIZE,
};
// init_flash_decoding allocates the scratch buffers consumed below.
if !self.flash_decode_enabled {
return Err(GpuError::InvalidLaunchConfig(
"PAR-118: Flash Decoding not initialized (call init_flash_decoding first)"
.to_string(),
));
}
// Cache geometry recorded at KV-cache initialization time.
let num_heads = self.kv_num_heads;
let num_kv_heads = self.kv_num_kv_heads;
let head_dim = self.kv_head_dim;
let max_len = self.kv_cache_max_len;
// Element stride between consecutive sequences' regions in the batched caches.
let stride = self.batched_kv_stride;
// One token's K (or V) row length across all KV heads.
let kv_dim = num_kv_heads * head_dim;
// Scatter launch shape: one block per KV head, one thread per head element.
let scatter_config = LaunchConfig {
grid: (num_kv_heads as u32, 1, 1),
block: (head_dim as u32, 1, 1),
shared_mem: 0,
};
let scatter_type = KernelType::KvCacheScatter {
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
max_len: max_len as u32,
};
let scatter_name = self.kernels.kernel_name(&scatter_type);
// Compile the scatter kernel once per (num_kv_heads, head_dim) shape.
let scatter_key = format!("kv_scatter_{}_{}", num_kv_heads, head_dim);
if !self.modules.contains_key(&scatter_key) {
let scatter_ptx = self.kernels.generate_ptx(&scatter_type);
let module = self.compile_ptx(&scatter_ptx)?;
self.modules.insert(scatter_key.clone(), module);
}
let k_cache = self.batched_kv_k_caches.get(&layer_idx).ok_or_else(|| {
GpuError::InvalidLaunchConfig(format!(
"PAR-118: Batched K cache not found for layer {}",
layer_idx
))
})?;
let v_cache = self.batched_kv_v_caches.get(&layer_idx).ok_or_else(|| {
GpuError::InvalidLaunchConfig(format!(
"PAR-118: Batched V cache not found for layer {}",
layer_idx
))
})?;
// Scatter this step's K row, then V row, for every sequence.
for seq_idx in 0..m {
let pos = positions[seq_idx] as usize;
let k_src_offset = seq_idx * kv_dim;
let k_dst_offset = seq_idx * stride;
// Device addresses are advanced in bytes, hence the size_of::<f32>() factor.
let k_src_ptr = k_batched.as_ptr() + (k_src_offset * std::mem::size_of::<f32>()) as u64;
let k_dst_ptr = k_cache.as_ptr() + (k_dst_offset * std::mem::size_of::<f32>()) as u64;
// Kernel parameters are passed by address below; these locals must
// stay alive until launch_kernel returns.
let mut k_src = k_src_ptr;
let mut k_dst = k_dst_ptr;
let mut pos_val = pos as u32;
let mut head_dim_val = head_dim as u32;
let mut max_len_val = max_len as u32;
let scatter_module = self.modules.get_mut(&scatter_key).expect("module exists");
unsafe {
self.stream.launch_kernel(
scatter_module,
scatter_name,
&scatter_config,
&mut [
std::ptr::from_mut(&mut k_src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut pos_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut head_dim_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut max_len_val) as *mut std::ffi::c_void,
],
)?;
}
// Same scatter kernel for the V row; reuses pos_val / head_dim_val /
// max_len_val from the K launch above.
let v_src_offset = seq_idx * kv_dim;
let v_dst_offset = seq_idx * stride;
let v_src_ptr = v_batched.as_ptr() + (v_src_offset * std::mem::size_of::<f32>()) as u64;
let v_dst_ptr = v_cache.as_ptr() + (v_dst_offset * std::mem::size_of::<f32>()) as u64;
let mut v_src = v_src_ptr;
let mut v_dst = v_dst_ptr;
// Re-fetch the module handle so no &mut borrow spans the K launch.
let scatter_module = self.modules.get_mut(&scatter_key).expect("module exists");
unsafe {
self.stream.launch_kernel(
scatter_module,
scatter_name,
&scatter_config,
&mut [
std::ptr::from_mut(&mut v_src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut v_dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut pos_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut head_dim_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut max_len_val) as *mut std::ffi::c_void,
],
)?;
}
}
// Record the new lengths: positions are 0-based, so after writing at
// `pos` the sequence holds pos + 1 tokens.
for seq_idx in 0..m {
let pos = positions[seq_idx] as usize;
if seq_idx < self.batched_kv_lengths.len() {
self.batched_kv_lengths[seq_idx] = pos + 1;
}
}
// Build host-side tables of per-sequence cache base pointers and
// lengths, padded to the full tracked batch size (`buf_len`) so the
// uploads below match the preallocated scratch buffers.
let k_cache_base = k_cache.as_ptr();
let v_cache_base = v_cache.as_ptr();
let stride_bytes = (stride * std::mem::size_of::<f32>()) as u64;
let buf_len = self.batched_kv_lengths.len();
let mut k_ptrs: Vec<u64> = (0..m)
.map(|seq_idx| k_cache_base + seq_idx as u64 * stride_bytes)
.collect();
k_ptrs.resize(buf_len, k_cache_base);
let mut v_ptrs: Vec<u64> = (0..m)
.map(|seq_idx| v_cache_base + seq_idx as u64 * stride_bytes)
.collect();
v_ptrs.resize(buf_len, v_cache_base);
// Finished sequences get length 0; presumably the chunk kernel skips
// zero-length entries (kernel source not visible here — confirm).
let mut seq_lens: Vec<u32> = (0..m)
.map(|seq_idx| {
if seq_idx < self.batched_done_mask.len() && self.batched_done_mask[seq_idx] {
0
} else {
self.batched_kv_lengths.get(seq_idx).copied().unwrap_or(1) as u32
}
})
.collect();
seq_lens.resize(buf_len, 0);
// Ceil-divide the longest active sequence into fixed-size KV chunks.
let max_seq_len_actual = seq_lens.iter().copied().max().unwrap_or(1) as usize;
let max_chunks = (max_seq_len_actual + FLASH_DECODE_CHUNK_SIZE as usize - 1)
/ FLASH_DECODE_CHUNK_SIZE as usize;
let chunk_kernel = FlashDecodingChunkKernel::new(
max_len as u32,
head_dim as u32,
num_heads as u32,
num_kv_heads as u32,
m as u32,
);
let chunk_kernel_name = chunk_kernel.name();
// NOTE(review): this cache key omits the batch size `m` passed to the
// kernel constructor above — if the emitted PTX depends on `m`, a
// module compiled for a different batch size would be reused. Confirm.
let chunk_module_key = format!(
"flash_decode_chunk_{}_{}_{}_{}",
max_len, head_dim, num_heads, num_kv_heads
);
if !self.modules.contains_key(&chunk_module_key) {
let chunk_ptx = chunk_kernel.emit_ptx_for_target(&self.kernels.sm_target);
let module = self.compile_ptx(&chunk_ptx)?;
self.modules.insert(chunk_module_key.clone(), module);
}
// Preallocated GPU scratch buffers for the pointer/length tables.
let k_ptrs_buf = self.batched_k_ptrs.as_mut().ok_or_else(|| {
GpuError::InvalidLaunchConfig("PAR-118: batched_k_ptrs not allocated".to_string())
})?;
let v_ptrs_buf = self.batched_v_ptrs.as_mut().ok_or_else(|| {
GpuError::InvalidLaunchConfig("PAR-118: batched_v_ptrs not allocated".to_string())
})?;
let seq_lens_buf = self.batched_seq_lens_gpu.as_mut().ok_or_else(|| {
GpuError::InvalidLaunchConfig("PAR-118: batched_seq_lens_gpu not allocated".to_string())
})?;
// NOTE(review): `m` is shadowed here to the padded batch size
// (`seq_lens.len()` == `buf_len`), which may exceed the caller's `m`.
// The chunk kernel above was constructed with the ORIGINAL `m`, while
// both launch grids below use the shadowed value — confirm this
// mismatch is intentional (padded entries carry seq_len 0).
let m = seq_lens.len();
unsafe {
// Temporary views over the scratch buffers to reach the async-copy
// API. NOTE(review): assumes `from_raw_parts` yields an owning
// wrapper, hence the `mem::forget`s to avoid freeing memory the
// executor still owns — note a failing `?` on the copies would skip
// the forgets and drop the views; confirm that cannot double-free.
let mut k_view = GpuBuffer::<u64>::from_raw_parts(k_ptrs_buf.as_ptr(), m);
let mut v_view = GpuBuffer::<u64>::from_raw_parts(v_ptrs_buf.as_ptr(), m);
let mut s_view = GpuBuffer::<u32>::from_raw_parts(seq_lens_buf.as_ptr(), m);
k_view.copy_from_host_async(&k_ptrs, &self.stream)?;
v_view.copy_from_host_async(&v_ptrs, &self.stream)?;
s_view.copy_from_host_async(&seq_lens, &self.stream)?;
std::mem::forget(k_view);
std::mem::forget(v_view);
std::mem::forget(s_view);
}
let partials_buf = self.flash_decode_partials.as_ref().ok_or_else(|| {
GpuError::InvalidLaunchConfig(
"PAR-118: flash_decode_partials not allocated".to_string(),
)
})?;
// Chunk grid: x = query head, y = sequence, z = KV chunk; one warp
// (32 threads) per block.
let chunk_config = LaunchConfig {
grid: (num_heads as u32, m as u32, max_chunks as u32),
block: (32, 1, 1),
shared_mem: 0,
};
// Kernel parameters are passed by address; these locals must stay
// alive through both launches below.
let mut q_ptr = q_batched.as_ptr();
let mut k_ptrs_ptr = k_ptrs_buf.as_ptr();
let mut v_ptrs_ptr = v_ptrs_buf.as_ptr();
let mut partials_ptr = partials_buf.as_ptr();
let mut seq_lens_ptr = seq_lens_buf.as_ptr();
let mut max_chunks_val = max_chunks as u32;
let chunk_module = self
.modules
.get_mut(&chunk_module_key)
.expect("module just inserted");
unsafe {
self.stream.launch_kernel(
chunk_module,
chunk_kernel_name,
&chunk_config,
&mut [
std::ptr::from_mut(&mut q_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_ptrs_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut v_ptrs_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut partials_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut seq_lens_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut max_chunks_val) as *mut std::ffi::c_void,
],
)?;
}
// Reduce pass: fold each (head, sequence)'s per-chunk partials into
// the final output row in `out_batched`.
let reduce_kernel =
FlashDecodingReduceKernel::new(head_dim as u32, num_heads as u32, m as u32);
let reduce_kernel_name = reduce_kernel.name();
let reduce_module_key = format!("flash_decode_reduce_{}_{}", head_dim, num_heads);
if !self.modules.contains_key(&reduce_module_key) {
let reduce_ptx = reduce_kernel.emit_ptx_for_target(&self.kernels.sm_target);
let module = self.compile_ptx(&reduce_ptx)?;
self.modules.insert(reduce_module_key.clone(), module);
}
// Reduce grid: x = query head, y = sequence; one warp per block.
let reduce_config = LaunchConfig {
grid: (num_heads as u32, m as u32, 1),
block: (32, 1, 1),
shared_mem: 0,
};
let mut out_ptr = out_batched.as_ptr();
let reduce_module = self
.modules
.get_mut(&reduce_module_key)
.expect("module just inserted");
unsafe {
self.stream.launch_kernel(
reduce_module,
reduce_kernel_name,
&reduce_config,
&mut [
// Reuses partials_ptr / seq_lens_ptr / max_chunks_val from the
// chunk launch above.
std::ptr::from_mut(&mut partials_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut out_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut seq_lens_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut max_chunks_val) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
/// Multi-head attention via the Tensor Core (WMMA) kernel.
///
/// Host-side convenience wrapper: validates shapes, uploads `q`/`k`/`v`
/// to fresh device buffers, launches `AttentionTensorCore`, synchronizes
/// the stream, and copies the result back into `output`.
///
/// Layout: `q`, `k`, `v`, and `output` each hold `n_heads` contiguous
/// heads of `seq_len * head_dim` f32 values.
///
/// # Errors
/// Returns `GpuError::InvalidLaunchConfig` when `seq_len` or `head_dim`
/// is not a multiple of 16 (WMMA tile requirement) or any slice length
/// differs from `n_heads * seq_len * head_dim`; propagates PTX
/// compilation, upload, launch, synchronize, and copy-back errors.
#[allow(clippy::too_many_arguments)]
pub fn tensor_core_attention(
    &mut self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    output: &mut [f32],
    seq_len: u32,
    head_dim: u32,
    n_heads: u32,
    causal: bool,
) -> Result<(), GpuError> {
    // WMMA fragments are 16x16; the kernel cannot handle partial tiles.
    if !seq_len.is_multiple_of(16) || !head_dim.is_multiple_of(16) {
        return Err(GpuError::InvalidLaunchConfig(format!(
            "Tensor Core attention requires dimensions multiple of 16: seq_len={}, head_dim={}",
            seq_len, head_dim
        )));
    }
    let head_size = (seq_len * head_dim) as usize;
    let total_size = head_size * n_heads as usize;
    if q.len() != total_size
        || k.len() != total_size
        || v.len() != total_size
        || output.len() != total_size
    {
        return Err(GpuError::InvalidLaunchConfig(format!(
            "Tensor Core attention size mismatch: expected {} ({}×{}×{}), got Q[{}] K[{}] V[{}] O[{}]",
            total_size, n_heads, seq_len, head_dim,
            q.len(), k.len(), v.len(), output.len()
        )));
    }
    // Four device buffers (Q, K, V, O) of total_size f32s (4 bytes) each.
    let accounted_bytes = total_size * 4 * 4;
    self.memory_pool.record_allocation(accounted_bytes);
    // Run the fallible body in a helper so the matching deallocation
    // record executes on BOTH the success and error paths (fix: an early
    // `?` return previously skipped it, permanently inflating the pool's
    // recorded usage).
    let result = self.tensor_core_attention_run(
        q, k, v, output, seq_len, head_dim, n_heads, causal, total_size,
    );
    self.memory_pool.record_deallocation(accounted_bytes);
    result
}

/// Fallible body of [`Self::tensor_core_attention`]: compile-or-reuse
/// the kernel module, upload inputs, launch, synchronize, and copy the
/// result back. Split out so the caller can balance the memory-pool
/// accounting no matter which `?` fails here. Inputs are pre-validated
/// by the caller (`total_size == n_heads * seq_len * head_dim`).
#[allow(clippy::too_many_arguments)]
fn tensor_core_attention_run(
    &mut self,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    output: &mut [f32],
    seq_len: u32,
    head_dim: u32,
    n_heads: u32,
    causal: bool,
    total_size: usize,
) -> Result<(), GpuError> {
    let kernel_type = KernelType::AttentionTensorCore {
        seq_len,
        head_dim,
        n_heads,
        causal,
    };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    // One compiled module per unique (seq_len, head_dim, n_heads, causal).
    let cache_key = format!(
        "tensor_core_attn_{}_{}_{}_{}",
        seq_len, head_dim, n_heads, causal
    );
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        #[cfg(test)]
        eprintln!("Generated Tensor Core attention PTX:\n{}", ptx);
        let module = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), module);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    let buf_q = GpuBuffer::from_host(&self.context, q)?;
    let buf_k = GpuBuffer::from_host(&self.context, k)?;
    let buf_v = GpuBuffer::from_host(&self.context, v)?;
    let buf_output = GpuBuffer::<f32>::new(&self.context, total_size)?;
    // seq_len is a multiple of 16 (validated by the caller), so the tile
    // count is exact — no partial tiles.
    let num_tiles = seq_len / 16;
    // Grid: x = 16-row query tile, y = head; 256 threads per block.
    let config = LaunchConfig::grid_2d(num_tiles, n_heads, 256, 1);
    // Kernel parameters are passed by address; these locals must stay
    // alive until the launch call returns.
    let mut ptr_q = buf_q.as_ptr();
    let mut ptr_k = buf_k.as_ptr();
    let mut ptr_v = buf_v.as_ptr();
    let mut ptr_output = buf_output.as_ptr();
    let mut seq_len_val = seq_len;
    let mut head_dim_val = head_dim;
    let mut n_heads_val = n_heads;
    unsafe {
        self.stream.launch_kernel(
            module,
            kernel_name,
            &config,
            &mut [
                std::ptr::from_mut(&mut ptr_q) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_k) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_v) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut seq_len_val) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut head_dim_val) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut n_heads_val) as *mut std::ffi::c_void,
            ],
        )?;
    }
    // Block until the kernel finishes, then read back the result.
    self.stream.synchronize()?;
    buf_output.copy_to_host(output)?;
    Ok(())
}
}