use crate::error::{Error, Result};
use cudarc::driver::safe::CudaFunction;
use cudarc::driver::sys;
use numr::runtime::cuda::CudaRuntime;
use numr::tensor::Tensor;
/// Launch parameters for the Flash Attention kernels, as returned by
/// `validate_qkv` after the Q/K/V tensors have been checked.
#[derive(Debug, Clone, Copy)]
pub(super) struct AttentionParams {
    /// Batch size (dim 0 of Q, K and V).
    pub batch_size: usize,
    /// Number of query heads (dim 1 of Q).
    pub num_heads: usize,
    /// Number of key/value heads (dim 1 of K).
    pub num_kv_heads: usize,
    /// Query sequence length (dim 2 of Q).
    pub seq_len_q: usize,
    /// Key sequence length (dim 2 of K, which must equal dim 2 of V).
    pub seq_len_k: usize,
    /// Per-head feature size (dim 3 of Q, K and V).
    pub head_dim: usize,
    /// First tile size chosen by `block_config` (presumably the query-tile
    /// row count — confirm against the kernel).
    pub block_m: usize,
    /// Second tile size chosen by `block_config` (presumably the key/value
    /// tile row count — confirm against the kernel).
    pub block_n: usize,
    /// `true` when `block_config` fell back to its reduced-tile table because
    /// the preferred tiles did not fit in the device's shared memory.
    pub use_sm_kernel: bool,
}
pub(super) fn validate_qkv(
q: &Tensor<CudaRuntime>,
k: &Tensor<CudaRuntime>,
v: &Tensor<CudaRuntime>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
) -> Result<AttentionParams> {
let q_shape = q.shape();
let k_shape = k.shape();
let v_shape = v.shape();
if q_shape.len() != 4 {
return Err(Error::InvalidArgument {
arg: "q",
reason: format!("expected 4D [B, H, S, D], got {}D", q_shape.len()),
});
}
if k_shape.len() != 4 {
return Err(Error::InvalidArgument {
arg: "k",
reason: format!("expected 4D, got {}D", k_shape.len()),
});
}
if v_shape.len() != 4 {
return Err(Error::InvalidArgument {
arg: "v",
reason: format!("expected 4D, got {}D", v_shape.len()),
});
}
if q_shape[1] != num_heads {
return Err(Error::InvalidArgument {
arg: "num_heads",
reason: format!("num_heads={} but q dim 1 is {}", num_heads, q_shape[1]),
});
}
if k_shape[1] != num_kv_heads {
return Err(Error::InvalidArgument {
arg: "num_kv_heads",
reason: format!(
"num_kv_heads={} but k dim 1 is {}",
num_kv_heads, k_shape[1]
),
});
}
if q_shape[3] != head_dim || k_shape[3] != head_dim || v_shape[3] != head_dim {
return Err(Error::InvalidArgument {
arg: "head_dim",
reason: format!(
"head_dim={} but q.D={}, k.D={}, v.D={}",
head_dim, q_shape[3], k_shape[3], v_shape[3]
),
});
}
if q_shape[0] != k_shape[0] || q_shape[0] != v_shape[0] {
return Err(Error::InvalidArgument {
arg: "batch_size",
reason: format!(
"batch mismatch: q.B={}, k.B={}, v.B={}",
q_shape[0], k_shape[0], v_shape[0]
),
});
}
if k_shape[2] != v_shape[2] {
return Err(Error::InvalidArgument {
arg: "v",
reason: format!("k seq_len={} != v seq_len={}", k_shape[2], v_shape[2]),
});
}
if num_heads % num_kv_heads != 0 {
return Err(Error::InvalidArgument {
arg: "num_kv_heads",
reason: format!(
"num_heads ({}) must be divisible by num_kv_heads ({})",
num_heads, num_kv_heads
),
});
}
let dtype = q.dtype();
if k.dtype() != dtype || v.dtype() != dtype {
return Err(Error::InvalidArgument {
arg: "dtype",
reason: format!(
"Q/K/V dtype mismatch: Q={:?}, K={:?}, V={:?}",
dtype,
k.dtype(),
v.dtype()
),
});
}
if !q.is_contiguous() || !k.is_contiguous() || !v.is_contiguous() {
return Err(Error::InvalidArgument {
arg: "contiguity",
reason: "Flash Attention requires contiguous Q, K, V tensors".into(),
});
}
let elem_bytes = q.dtype().size_in_bytes();
let (block_m, block_n, use_sm_kernel) = block_config(head_dim, elem_bytes)?;
Ok(AttentionParams {
batch_size: q_shape[0],
num_heads,
num_kv_heads,
seq_len_q: q_shape[2],
seq_len_k: k_shape[2],
head_dim,
block_m,
block_n,
use_sm_kernel,
})
}
/// Returns the maximum shared memory, in bytes, that one block may opt into
/// on the device of the calling thread's current CUDA context.
///
/// The value comes from
/// `CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN`. If either driver
/// call fails (for example, no context is current on this thread), this
/// returns 48 KiB instead. That is the amount usable without a
/// `cuFuncSetAttribute` opt-in (the same threshold `set_smem_attribute`
/// uses), so callers fall back to a configuration that needs no opt-in.
/// They no longer silently read device 0, or get 0 and report a bogus
/// "0KB" limit.
pub(super) fn device_max_smem() -> usize {
    /// Shared memory per block that requires no explicit opt-in.
    const NO_OPTIN_SMEM_BYTES: usize = 48 * 1024;

    let mut cuda_dev: i32 = 0;
    // SAFETY: `cuda_dev` is a valid, writable out-pointer for the whole call.
    let ctx_status = unsafe { sys::cuCtxGetDevice(&mut cuda_dev) };
    if ctx_status != sys::CUresult::CUDA_SUCCESS {
        return NO_OPTIN_SMEM_BYTES;
    }

    let mut max_smem: i32 = 0;
    // SAFETY: `max_smem` is a valid, writable out-pointer and `cuda_dev` was
    // just returned by the driver for the current context.
    let attr_status = unsafe {
        sys::cuDeviceGetAttribute(
            &mut max_smem,
            sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
            cuda_dev,
        )
    };
    if attr_status != sys::CUresult::CUDA_SUCCESS {
        return NO_OPTIN_SMEM_BYTES;
    }

    // A negative attribute value is meaningless; treat it like a failed query.
    usize::try_from(max_smem).unwrap_or(NO_OPTIN_SMEM_BYTES)
}
/// Shared-memory footprint, in bytes, of `block_m + 2 * block_n` rows of
/// `head_dim + 1` elements of `elem_bytes` each.
///
/// The extra element per row is padding (presumably to avoid shared-memory
/// bank conflicts — confirm against the kernel). The row split looks like
/// one `block_m`-row tile plus two `block_n`-row tiles (Q, K, V — to be
/// confirmed against the kernel).
pub(super) fn compute_smem(
    block_m: usize,
    block_n: usize,
    head_dim: usize,
    elem_bytes: usize,
) -> usize {
    let padded_row_len = head_dim + 1;
    let total_rows = block_m + 2 * block_n;
    total_rows * padded_row_len * elem_bytes
}
/// Preferred `(block_m, block_n)` tile sizes for `head_dim`.
///
/// Returns `None` when `head_dim` has no entry in this table.
fn block_config_large(head_dim: usize) -> Option<(usize, usize)> {
    let tiles = match head_dim {
        32 | 64 => (128, 128),
        96 => (64, 128),
        128 => (128, 64),
        192 | 256 => (64, 64),
        _ => return None,
    };
    Some(tiles)
}
/// Reduced `(block_m, block_n)` tile sizes for `head_dim`, used when the
/// preferred tiles do not fit in shared memory.
///
/// Returns `None` when `head_dim` has no reduced entry (32 and 64 have none).
fn block_config_small(head_dim: usize) -> Option<(usize, usize)> {
    const REDUCED_TILES: [(usize, (usize, usize)); 4] = [
        (96, (32, 32)),
        (128, (64, 32)),
        (192, (32, 16)),
        (256, (16, 16)),
    ];
    REDUCED_TILES
        .iter()
        .find(|(dim, _)| *dim == head_dim)
        .map(|&(_, tiles)| tiles)
}
/// Chooses `(block_m, block_n, use_sm_kernel)` for `head_dim` on the current
/// device.
///
/// The preferred table (`block_config_large`) is tried first. If its
/// shared-memory footprint (`compute_smem`) exceeds the device limit
/// (`device_max_smem`), the reduced table (`block_config_small`) is tried
/// and `use_sm_kernel` is `true`.
///
/// # Errors
///
/// Returns `Error::InvalidArgument` when neither table has an entry for
/// `head_dim` that fits in the device's shared memory.
pub(super) fn block_config(head_dim: usize, elem_bytes: usize) -> Result<(usize, usize, bool)> {
    // Query the driver once; the original re-queried it for every check and
    // again for the error message.
    let max_smem = device_max_smem();
    let fits = |&(bm, bn): &(usize, usize)| compute_smem(bm, bn, head_dim, elem_bytes) <= max_smem;

    if let Some((bm, bn)) = block_config_large(head_dim).filter(fits) {
        return Ok((bm, bn, false));
    }
    if let Some((bm, bn)) = block_config_small(head_dim).filter(fits) {
        return Ok((bm, bn, true));
    }
    Err(Error::InvalidArgument {
        arg: "head_dim",
        reason: format!(
            "unsupported head_dim={} for this GPU (max shared memory: {}KB). Supported: 32, 64, 96, 128, 192, 256",
            head_dim,
            max_smem / 1024
        ),
    })
}
pub(crate) fn set_smem_attribute(func: &CudaFunction, smem_size: usize) -> Result<()> {
if smem_size <= 48 * 1024 {
return Ok(());
}
let max_shared_mem = unsafe {
let mut cuda_dev: i32 = 0;
sys::cuCtxGetDevice(&mut cuda_dev);
let mut max_smem: i32 = 0;
sys::cuDeviceGetAttribute(
&mut max_smem,
sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
cuda_dev,
);
max_smem as usize
};
if smem_size > max_shared_mem {
return Err(Error::KernelError {
reason: format!(
"shared memory {}KB exceeds device limit {}KB",
smem_size / 1024,
max_shared_mem / 1024
),
});
}
let cu_function: sys::CUfunction = unsafe {
let kernel_ptr = func as *const _ as *const usize;
std::ptr::read(kernel_ptr.add(1)) as sys::CUfunction
};
unsafe {
let result = sys::cuFuncSetAttribute(
cu_function,
sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
smem_size as i32,
);
if result != sys::CUresult::CUDA_SUCCESS {
return Err(Error::KernelError {
reason: format!(
"failed to set dynamic shared memory to {}KB: {:?}",
smem_size / 1024,
result
),
});
}
}
Ok(())
}