use trueno_gpu::kernels::{AttentionKernel, Kernel, QuantizeKernel};
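
/// Host-side configuration for the CUDA inference backend: GEMM dimensions
/// (`m`, `n`, `k`), attention geometry (`head_dim`, `num_heads`, `max_seq_len`),
/// and lazily populated PTX caches for the quantized GEMM and flash-attention
/// kernels.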
#[cfg(feature = "cuda")]
#[derive(Debug, Clone)]
pub struct CudaBackend {
pub m: u32,
pub n: u32,
pub k: u32,
pub head_dim: u32,
pub num_heads: u32,
pub max_seq_len: u32,
q4k_gemm_ptx_cache: std::cell::RefCell<Option<String>>,
flash_attention_ptx_cache: std::cell::RefCell<Option<String>>,
}
#[cfg(feature = "cuda")]
impl CudaBackend {
#[must_use]
pub fn new(m: u32, n: u32, k: u32, head_dim: u32) -> Self {
Self {
m,
n,
k,
head_dim,
            num_heads: 32,
            max_seq_len: 2048,
            q4k_gemm_ptx_cache: std::cell::RefCell::new(None),
flash_attention_ptx_cache: std::cell::RefCell::new(None),
}
}
#[must_use]
pub const fn with_num_heads(mut self, num_heads: u32) -> Self {
self.num_heads = num_heads;
self
}
#[must_use]
pub const fn with_max_seq_len(mut self, max_seq_len: u32) -> Self {
self.max_seq_len = max_seq_len;
self
}
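    /// Emits PTX for the fused Q4 GEMM kernel, generating it on first use and
    /// caching the result in a `RefCell` so repeated calls reuse the same string.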
#[must_use]
pub fn q4k_gemm_ptx(&self) -> String {
if let Some(cached) = self.q4k_gemm_ptx_cache.borrow().as_ref() {
return cached.clone();
}
let kernel = QuantizeKernel::new(self.m, self.n, self.k);
let ptx = kernel.emit_ptx();
*self.q4k_gemm_ptx_cache.borrow_mut() = Some(ptx.clone());
ptx
}
#[must_use]
    pub const fn q4k_gemm_kernel_name(&self) -> &'static str {
"q4k_gemm_fused"
}
#[must_use]
pub const fn q4k_blocks_per_row(&self) -> u32 {
self.k / 32
}
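    /// Total bytes for the quantized weight matrix: each row of length `k` holds
    /// `k / 32` blocks of 18 bytes (assumed layout: a 2-byte scale plus 16 bytes
    /// of packed 4-bit values per 32-element block), times `n` rows.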
#[must_use]
pub const fn q4k_weight_bytes(&self) -> usize {
let blocks_per_row = self.k / 32;
let bytes_per_row = blocks_per_row * 18;
(self.n as usize) * (bytes_per_row as usize)
}
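    /// Emits PTX for a flash-attention kernel with the given sequence length and
    /// head dimension, optionally with causal masking. Unlike the cached causal
    /// variant below, this regenerates the PTX on every call.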
#[must_use]
pub fn flash_attention_ptx(&self, seq_len: u32, head_dim: u32, causal: bool) -> String {
let kernel = if causal {
AttentionKernel::new(seq_len, head_dim).with_causal()
} else {
AttentionKernel::new(seq_len, head_dim)
};
kernel.emit_ptx()
}
#[must_use]
pub fn flash_attention_causal_ptx(&self) -> String {
if let Some(cached) = self.flash_attention_ptx_cache.borrow().as_ref() {
return cached.clone();
}
let ptx = self.flash_attention_ptx(self.max_seq_len, self.head_dim, true);
*self.flash_attention_ptx_cache.borrow_mut() = Some(ptx.clone());
ptx
}
#[must_use]
pub const fn flash_attention_kernel_name(&self, causal: bool) -> &'static str {
if causal {
"flash_attention_causal"
} else {
"flash_attention"
}
}
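    /// Shared-memory requirement per block: a 64 x `head_dim` Q tile plus
    /// 64 x `head_dim` K and V tiles, all in f32 (4 bytes per element).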
#[must_use]
pub const fn flash_attention_smem_bytes(&self) -> usize {
let tile_q = 64_u32;
let tile_kv = 64_u32;
let d = self.head_dim;
((tile_q * d + tile_kv * d * 2) * 4) as usize
}
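    /// KV-cache size per layer: K and V each store
    /// `num_heads * max_seq_len * head_dim` f32 values.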
#[must_use]
    pub const fn kv_cache_bytes_per_layer(&self) -> usize {
        // Widen to usize before multiplying so long contexts cannot overflow u32.
        let elems =
            (self.num_heads as usize) * (self.max_seq_len as usize) * (self.head_dim as usize);
        // K and V tensors, 4 bytes per f32 element each.
        2 * elems * 4
    }
#[must_use]
pub const fn kv_cache_total_bytes(&self, num_layers: u32) -> usize {
self.kv_cache_bytes_per_layer() * (num_layers as usize)
}
#[must_use]
pub const fn kv_cache_page_tokens(&self) -> u32 {
64
}
#[must_use]
pub const fn kv_cache_pages_needed(&self, seq_len: u32) -> u32 {
let page_size = self.kv_cache_page_tokens();
seq_len.div_ceil(page_size)
}
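    /// Launch configuration for the Q4 GEMM: the grid covers the `n` x `m` output
    /// in 32x32 tiles and each block launches 32 * 32 = 1024 threads.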
#[must_use]
pub const fn q4k_gemm_launch_config(&self) -> ((u32, u32, u32), (u32, u32, u32)) {
let tile_size = 32_u32;
let grid_x = self.n.div_ceil(tile_size);
let grid_y = self.m.div_ceil(tile_size);
let grid = (grid_x, grid_y, 1);
let block = (tile_size * tile_size, 1, 1);
(grid, block)
}
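    /// Launch configuration for flash attention: the grid spans 64-row query
    /// blocks by `num_heads`, and the block size is `tile_q * head_dim` threads
    /// (the kernel is assumed to map one thread per tile element).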
#[must_use]
pub const fn flash_attention_launch_config(
&self,
seq_len: u32,
) -> ((u32, u32, u32), (u32, u32, u32)) {
let tile_q = 64_u32;
let num_q_blocks = seq_len.div_ceil(tile_q);
let grid = (num_q_blocks, self.num_heads, 1);
let block = (tile_q * self.head_dim, 1, 1);
(grid, block)
}
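    /// Checks that `k` is a multiple of the 32-element quantization block, that
    /// `head_dim` is a power of two, and that all dimensions are non-zero.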
#[must_use]
pub const fn validate_dimensions(&self) -> bool {
let k_valid = self.k.is_multiple_of(32);
let head_dim_valid = self.head_dim.is_power_of_two();
let non_zero = self.m > 0 && self.n > 0 && self.k > 0 && self.head_dim > 0;
k_valid && head_dim_valid && non_zero
}
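    /// Target SM architecture for PTX generation (`sm_89`, i.e. Ada Lovelace).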
#[must_use]
pub const fn ptx_target(&self) -> &'static str {
"sm_89"
}
#[must_use]
pub const fn ptx_version(&self) -> (u32, u32) {
(8, 0)
}
}
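
// A minimal usage sketch, not part of the original backend: the module name and
// the Llama-7B-like shapes below are illustrative assumptions. It exercises only
// the host-side sizing and launch-config helpers, so it runs without a GPU.
#[cfg(all(test, feature = "cuda"))]
mod cuda_backend_sketch_tests {
    use super::*;

    #[test]
    fn llama_7b_shaped_config_is_consistent() {
        // Hypothetical decode shapes: one output row (m), 4096-wide projections
        // (n, k), 128-dim heads, 32 heads, 2048-token context.
        let backend = CudaBackend::new(1, 4096, 4096, 128)
            .with_num_heads(32)
            .with_max_seq_len(2048);

        assert!(backend.validate_dimensions());

        // 4096 / 32 = 128 quantized blocks per row, 18 bytes each, 4096 rows.
        assert_eq!(backend.q4k_blocks_per_row(), 128);
        assert_eq!(backend.q4k_weight_bytes(), 4096 * 128 * 18);

        // 64-token KV-cache pages: 100 tokens need 2 pages.
        assert_eq!(backend.kv_cache_pages_needed(100), 2);

        // Grid covers the n x m output in 32x32 tiles; each block is 1024 threads.
        let (grid, block) = backend.q4k_gemm_launch_config();
        assert_eq!(grid, (128, 1, 1));
        assert_eq!(block, (1024, 1, 1));
    }
}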