#[cfg(feature = "cuda")]
use super::super::cache::compile_lock_launch;
#[cfg(feature = "cuda")]
use super::super::GpuResidentTensor;
#[cfg(feature = "cuda")]
use crate::driver::{CudaContext, CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use crate::error::Result;
#[cfg(feature = "cuda")]
/// Validated size parameters for an incremental-attention launch, produced by
/// [`validate_incremental_attention`].
struct IncrementalAttentionParams {
    // Expected element count of the query tensor and of the kernel output:
    // n_heads * head_dim.
    q_expected: usize,
}
#[cfg(feature = "cuda")]
/// Validates buffer sizes for one incremental-attention step.
///
/// Checks that `q` holds exactly `n_heads * head_dim` elements, that both
/// caches hold `n_heads * max_seq_len * head_dim` elements, and that
/// `seq_len <= max_seq_len`.
///
/// # Errors
/// Returns `GpuError::InvalidParameter` describing the first mismatch found.
fn validate_incremental_attention(
    q: &GpuResidentTensor<f32>,
    k_cache: &GpuResidentTensor<f32>,
    v_cache: &GpuResidentTensor<f32>,
    n_heads: u32,
    head_dim: u32,
    seq_len: u32,
    max_seq_len: u32,
) -> Result<IncrementalAttentionParams> {
    // Widen to usize BEFORE multiplying: the previous `(a * b) as usize` form
    // multiplied in u32, which panics in debug builds and silently wraps in
    // release builds for large (n_heads, max_seq_len, head_dim) combinations —
    // a wrapped product would make these size checks accept wrongly-sized
    // buffers.
    let q_expected = n_heads as usize * head_dim as usize;
    if q.len() != q_expected {
        return Err(crate::GpuError::InvalidParameter(format!(
            "Q has {} elements, expected {} (n_heads={}, head_dim={})",
            q.len(),
            q_expected,
            n_heads,
            head_dim
        )));
    }
    let cache_expected = n_heads as usize * max_seq_len as usize * head_dim as usize;
    if k_cache.len() != cache_expected {
        return Err(crate::GpuError::InvalidParameter(format!(
            "K cache has {} elements, expected {} (n_heads={}, max_seq_len={}, head_dim={})",
            k_cache.len(),
            cache_expected,
            n_heads,
            max_seq_len,
            head_dim
        )));
    }
    if v_cache.len() != cache_expected {
        return Err(crate::GpuError::InvalidParameter(format!(
            "V cache has {} elements, expected {}",
            v_cache.len(),
            cache_expected
        )));
    }
    if seq_len > max_seq_len {
        return Err(crate::GpuError::InvalidParameter(format!(
            "seq_len ({}) exceeds max_seq_len ({})",
            seq_len, max_seq_len
        )));
    }
    Ok(IncrementalAttentionParams { q_expected })
}
#[cfg(feature = "cuda")]
/// Emits PTX for the incremental-attention kernel and launches it on `stream`
/// via the shared compile cache (`compile_lock_launch`).
///
/// Callers are responsible for validating buffer sizes first (see
/// `validate_incremental_attention`) and for synchronizing `stream` before
/// reading `output`.
fn launch_incremental_attention_kernel(
    ctx: &CudaContext,
    q: &GpuResidentTensor<f32>,
    k_cache: &GpuResidentTensor<f32>,
    v_cache: &GpuResidentTensor<f32>,
    output: &GpuBuffer<f32>,
    n_heads: u32,
    head_dim: u32,
    seq_len: u32,
    max_seq_len: u32,
    stream: &CudaStream,
) -> Result<()> {
    use crate::kernels::{IncrementalAttentionKernel, Kernel};
    // Kernel specialization and its cache key share the same three parameters,
    // so distinct (max_seq_len, head_dim, n_heads) triples never collide.
    let kernel = IncrementalAttentionKernel::new(max_seq_len, head_dim, n_heads);
    let ptx = kernel.emit_ptx();
    let cache_key = format!("incremental_attention:{}:{}:{}", max_seq_len, head_dim, n_heads);
    // One block per head, 32 threads (one warp) per block, no dynamic shared
    // memory.
    let config = LaunchConfig {
        grid: (n_heads, 1, 1),
        block: (32, 1, 1), shared_mem: 0,
    };
    // Copy the device pointers (and seq_len) into locals so that the argument
    // array below can take stable addresses of them; they must stay alive
    // until the launch call returns.
    let q_ptr = q.as_ptr();
    let k_ptr = k_cache.as_ptr();
    let v_ptr = v_cache.as_ptr();
    let out_ptr = output.as_ptr();
    let seq_len_val = seq_len;
    // CUDA kernel arguments are passed as an array of pointers to each
    // argument value, matching the kernel's parameter order.
    let mut args: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of!(q_ptr) as *mut _,
        std::ptr::addr_of!(k_ptr) as *mut _,
        std::ptr::addr_of!(v_ptr) as *mut _,
        std::ptr::addr_of!(out_ptr) as *mut _,
        std::ptr::addr_of!(seq_len_val) as *mut _,
    ];
    compile_lock_launch(ctx, stream, &cache_key, &ptx, kernel.name(), &config, &mut args)?;
    Ok(())
}
#[cfg(feature = "cuda")]
/// Runs one incremental-attention step synchronously.
///
/// Validates all buffer sizes, launches the kernel on a freshly created
/// stream, blocks until it completes, and returns the result as a new
/// GPU-resident tensor of `n_heads * head_dim` elements. A `seq_len` of zero
/// short-circuits to an all-zero output without touching the GPU kernel.
///
/// # Errors
/// Propagates validation failures and any CUDA error from buffer/stream
/// creation, the launch, or synchronization.
pub fn incremental_attention_gpu(
    ctx: &CudaContext,
    q: &GpuResidentTensor<f32>,
    k_cache: &GpuResidentTensor<f32>,
    v_cache: &GpuResidentTensor<f32>,
    n_heads: u32,
    head_dim: u32,
    seq_len: u32,
    max_seq_len: u32,
) -> Result<GpuResidentTensor<f32>> {
    let params =
        validate_incremental_attention(q, k_cache, v_cache, n_heads, head_dim, seq_len, max_seq_len)?;

    // Nothing to attend over yet: the output is defined as all zeros.
    if seq_len == 0 {
        let zeros = vec![0.0f32; params.q_expected];
        return GpuResidentTensor::from_host(ctx, &zeros);
    }

    let output = GpuBuffer::new(ctx, params.q_expected)?;
    let stream = CudaStream::new(ctx)?;
    launch_incremental_attention_kernel(
        ctx, q, k_cache, v_cache, &output, n_heads, head_dim, seq_len, max_seq_len, &stream,
    )?;
    // Block until the kernel finishes so the returned tensor is safe to read.
    stream.synchronize()?;
    Ok(GpuResidentTensor::from_buffer_internal(output, 1))
}
#[cfg(feature = "cuda")]
/// Runs one incremental-attention step on a caller-provided stream.
///
/// Unlike [`incremental_attention_gpu`], this does NOT synchronize: the
/// returned tensor is only valid to read after the caller synchronizes
/// `stream`. A `seq_len` of zero short-circuits to an all-zero output.
///
/// # Errors
/// Propagates validation failures and any CUDA error from buffer creation or
/// the kernel launch.
pub fn incremental_attention_gpu_with_stream(
    ctx: &CudaContext,
    q: &GpuResidentTensor<f32>,
    k_cache: &GpuResidentTensor<f32>,
    v_cache: &GpuResidentTensor<f32>,
    n_heads: u32,
    head_dim: u32,
    seq_len: u32,
    max_seq_len: u32,
    stream: &CudaStream,
) -> Result<GpuResidentTensor<f32>> {
    let IncrementalAttentionParams { q_expected } =
        validate_incremental_attention(q, k_cache, v_cache, n_heads, head_dim, seq_len, max_seq_len)?;

    // Nothing to attend over yet: the output is defined as all zeros.
    if seq_len == 0 {
        let zeros = vec![0.0f32; q_expected];
        return GpuResidentTensor::from_host(ctx, &zeros);
    }

    let output = GpuBuffer::new(ctx, q_expected)?;
    launch_incremental_attention_kernel(
        ctx, q, k_cache, v_cache, &output, n_heads, head_dim, seq_len, max_seq_len, stream,
    )?;
    // No synchronize here by design — the caller owns the stream's lifecycle.
    Ok(GpuResidentTensor::from_buffer_internal(output, 1))
}
#[cfg(feature = "cuda")]
/// Runs one incremental-attention step asynchronously on a new stream.
///
/// Returns the output tensor together with the stream the work was queued on;
/// the tensor is only valid to read after the caller synchronizes that
/// stream. A `seq_len` of zero returns an all-zero output (with a fresh,
/// idle stream) without launching the kernel.
///
/// # Errors
/// Propagates validation failures and any CUDA error from buffer/stream
/// creation or the kernel launch.
pub fn incremental_attention_gpu_async(
    ctx: &CudaContext,
    q: &GpuResidentTensor<f32>,
    k_cache: &GpuResidentTensor<f32>,
    v_cache: &GpuResidentTensor<f32>,
    n_heads: u32,
    head_dim: u32,
    seq_len: u32,
    max_seq_len: u32,
) -> Result<(GpuResidentTensor<f32>, CudaStream)> {
    let params =
        validate_incremental_attention(q, k_cache, v_cache, n_heads, head_dim, seq_len, max_seq_len)?;
    let stream = CudaStream::new(ctx)?;

    // Nothing to attend over yet: hand back an all-zero tensor and the
    // (unused) stream so the caller's synchronize path stays uniform.
    if seq_len == 0 {
        let zeros = vec![0.0f32; params.q_expected];
        let output = GpuResidentTensor::from_host(ctx, &zeros)?;
        return Ok((output, stream));
    }

    let output = GpuBuffer::new(ctx, params.q_expected)?;
    launch_incremental_attention_kernel(
        ctx, q, k_cache, v_cache, &output, n_heads, head_dim, seq_len, max_seq_len, &stream,
    )?;
    // Deliberately no synchronize — completion is the caller's responsibility.
    Ok((GpuResidentTensor::from_buffer_internal(output, 1), stream))
}
#[cfg(feature = "cuda")]
/// Scatters one step's K or V vector (`n_heads * head_dim` elements) into the
/// preallocated cache at sequence position `pos`, asynchronously on `stream`.
///
/// The cache is laid out as `n_heads * max_seq_len * head_dim` elements. This
/// does not synchronize; the write is only complete after the caller
/// synchronizes `stream`.
///
/// # Errors
/// Returns `GpuError::InvalidParameter` when `src` or `cache` has the wrong
/// element count or `pos >= max_seq_len`, and propagates any CUDA error from
/// the kernel compile/launch.
pub fn kv_cache_scatter_gpu(
    ctx: &CudaContext,
    src: &GpuResidentTensor<f32>,
    cache: &mut GpuResidentTensor<f32>,
    pos: u32,
    n_heads: u32,
    head_dim: u32,
    max_seq_len: u32,
    stream: &CudaStream,
) -> Result<()> {
    use crate::kernels::{Kernel, KvCacheScatterKernel};
    // Widen to usize BEFORE multiplying: `(a * b) as usize` multiplies in u32,
    // which panics in debug builds and silently wraps in release builds for
    // large dimension combinations — a wrapped product would make these size
    // checks accept wrongly-sized buffers.
    let src_expected = n_heads as usize * head_dim as usize;
    if src.len() != src_expected {
        return Err(crate::GpuError::InvalidParameter(format!(
            "Source has {} elements, expected {} (n_heads={}, head_dim={})",
            src.len(),
            src_expected,
            n_heads,
            head_dim
        )));
    }
    let cache_expected = n_heads as usize * max_seq_len as usize * head_dim as usize;
    if cache.len() != cache_expected {
        return Err(crate::GpuError::InvalidParameter(format!(
            "Cache has {} elements, expected {} (n_heads={}, max_seq_len={}, head_dim={})",
            cache.len(),
            cache_expected,
            n_heads,
            max_seq_len,
            head_dim
        )));
    }
    if pos >= max_seq_len {
        return Err(crate::GpuError::InvalidParameter(format!(
            "Position {} >= max_seq_len {}",
            pos, max_seq_len
        )));
    }
    let kernel = KvCacheScatterKernel::new(n_heads, head_dim, max_seq_len);
    let ptx = kernel.emit_ptx();
    let cache_key = format!("kv_scatter:{}:{}:{}", n_heads, head_dim, max_seq_len);
    // One block per head; up to 256 threads cover the head dimension.
    let config = LaunchConfig {
        grid: (n_heads, 1, 1),
        block: (head_dim.min(256), 1, 1), shared_mem: 0,
    };
    // Locals keep the device pointers alive while the argument array below
    // holds their addresses; CUDA receives pointers to each argument value.
    let src_ptr = src.as_ptr();
    let cache_ptr = cache.as_ptr();
    let mut args: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of!(src_ptr) as *mut _,
        std::ptr::addr_of!(cache_ptr) as *mut _,
        std::ptr::addr_of!(pos) as *mut _,
        std::ptr::addr_of!(head_dim) as *mut _,
        std::ptr::addr_of!(max_seq_len) as *mut _,
    ];
    compile_lock_launch(ctx, stream, &cache_key, &ptx, kernel.name(), &config, &mut args)?;
    Ok(())
}