use cudarc::driver::PushKernelArg;
pub use cudarc::driver::safe::LaunchConfig;
use cudarc::driver::safe::{CudaContext, CudaFunction, CudaModule, CudaStream};
use cudarc::nvrtc::Ptx;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Directory holding the compiled `.ptx` kernels, baked in at compile time
/// from the `CUDA_KERNEL_DIR` environment variable (exported by build.rs,
/// per the error text in `get_or_load_module`).
const KERNEL_DIR: &str = env!("CUDA_KERNEL_DIR");

/// References the PTX for module `name` at `KERNEL_DIR/<name>.ptx`.
fn load_ptx(name: &str) -> Ptx {
    Ptx::from_file(format!("{KERNEL_DIR}/{name}.ptx"))
}
/// Process-wide cache of loaded CUDA modules, keyed by (device index, module name).
static MODULE_CACHE: OnceLock<Mutex<HashMap<(usize, &'static str), Arc<CudaModule>>>> =
    OnceLock::new();

/// Returns the module `module_name` for `device_index`, loading and caching
/// its PTX on first use.
///
/// A poisoned cache lock is recovered rather than propagated: the map only
/// stores `Arc` handles, so it cannot be left in a half-updated state.
pub fn get_or_load_module(
    context: &Arc<CudaContext>,
    device_index: usize,
    module_name: &'static str,
) -> Result<Arc<CudaModule>> {
    let key = (device_index, module_name);
    let cache = MODULE_CACHE.get_or_init(|| Mutex::new(HashMap::new()));
    let mut modules = match cache.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    if let Some(cached) = modules.get(&key) {
        return Ok(Arc::clone(cached));
    }
    // Cache miss: load the PTX from disk and register it with the context.
    // The lock is held across the load so concurrent callers do not load the
    // same module twice.
    let module = context.load_module(load_ptx(module_name)).map_err(|e| {
        Error::Internal(format!(
            "Failed to load CUDA module '{}': {:?}. \
             Ensure CUDA kernels were compiled correctly by build.rs.",
            module_name, e
        ))
    })?;
    modules.insert(key, Arc::clone(&module));
    Ok(module)
}
/// Eagerly loads every module in `module_names` into the cache, stopping at
/// the first failure.
pub fn preload_modules(
    context: &Arc<CudaContext>,
    device_index: usize,
    module_names: &[&'static str],
) -> Result<()> {
    module_names
        .iter()
        .copied()
        .try_for_each(|name| get_or_load_module(context, device_index, name).map(drop))
}
pub fn get_kernel_function(module: &Arc<CudaModule>, kernel_name: &str) -> Result<CudaFunction> {
module.load_function(kernel_name).map_err(|e| {
Error::Internal(format!(
"Failed to get kernel '{}': {:?}. \
Check that the kernel name matches the CUDA source.",
kernel_name, e
))
})
}
/// Default number of threads per block for 1-D launches.
pub const BLOCK_SIZE: u32 = 256;

/// Grid dimensions for a flat elementwise launch over `numel` elements.
///
/// The ceiling division is performed in `usize` before narrowing to `u32`:
/// the previous `(numel as u32) + BLOCK_SIZE - 1` form overflowed the 32-bit
/// intermediate once `numel` approached `u32::MAX` (panic in debug builds,
/// silent wrap in release).
#[inline]
pub fn elementwise_launch_config(numel: usize) -> (u32, u32, u32) {
    let block = BLOCK_SIZE as usize;
    // The quotient fits in u32 for any device-representable element count.
    let grid_size = ((numel + block - 1) / block) as u32;
    (grid_size, 1, 1)
}
#[inline]
#[allow(dead_code)] pub fn reduce_launch_config(numel: usize) -> (u32, u32) {
let block_size = BLOCK_SIZE;
let grid_size = ((numel as u32) + block_size - 1) / block_size;
let grid_size = grid_size.min(1024);
(grid_size, block_size)
}
/// Launch shape for a per-dimension reduction: a 2-D grid with one block per
/// (outer, inner) pair and `BLOCK_SIZE` threads per block.
#[inline]
pub fn reduce_dim_launch_config(outer: usize, inner: usize) -> ((u32, u32, u32), u32) {
    ((outer as u32, inner as u32, 1), BLOCK_SIZE)
}
/// Launch shape for a row-wise softmax: one block per row (`outer` rows).
///
/// Returns `(grid_x, block_x, shared_mem_bytes)`. The thread count is the
/// smallest power of two covering `dim_size`, clamped to `BLOCK_SIZE`.
///
/// NOTE(review): shared memory is sized as 2 values x 4 bytes per thread,
/// i.e. 4-byte staging is assumed regardless of dtype — confirm against the
/// softmax kernels that f64 inputs do not need 8-byte staging here.
#[inline]
pub fn softmax_launch_config(outer: usize, dim_size: usize) -> (u32, u32, u32) {
    let threads = BLOCK_SIZE
        .min(dim_size as u32)
        .next_power_of_two()
        .min(BLOCK_SIZE);
    let shared_bytes = 2 * threads * 4;
    (outer as u32, threads, shared_bytes)
}
/// Launch shape for softmax over a strided dimension: a flat 1-D launch with
/// one thread per element of the `outer * inner` index space.
#[inline]
#[allow(dead_code)]
pub fn softmax_dim_launch_config(
    outer: usize,
    inner: usize,
) -> ((u32, u32, u32), (u32, u32, u32)) {
    let total = (outer * inner) as u32;
    let blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
    ((blocks, 1, 1), (BLOCK_SIZE, 1, 1))
}
/// Bundles grid dimensions, block dimensions, and dynamic shared-memory size
/// into a cudarc `LaunchConfig`.
#[inline]
pub fn launch_config(
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
    shared_mem: u32,
) -> LaunchConfig {
    LaunchConfig {
        grid_dim: grid,
        block_dim: block,
        shared_mem_bytes: shared_mem,
    }
}
/// Names of the compiled kernel modules and helpers for building kernel
/// symbol names.
///
/// Each `*_MODULE` constant corresponds to a `<name>.ptx` file under
/// `KERNEL_DIR` (see `load_ptx`); they must stay in sync with the .cu
/// sources compiled by build.rs.
pub mod kernel_names {
    pub const BINARY_MODULE: &str = "binary";
    pub const UNARY_MODULE: &str = "unary";
    pub const SCALAR_MODULE: &str = "scalar";
    pub const REDUCE_MODULE: &str = "reduce";
    pub const COMPARE_MODULE: &str = "compare";
    pub const ACTIVATION_MODULE: &str = "activation";
    pub const SOFTMAX_MODULE: &str = "softmax";
    pub const NORM_MODULE: &str = "norm";
    pub const FUSED_ADD_NORM_MODULE: &str = "fused_add_norm";
    pub const CAST_MODULE: &str = "cast";
    pub const UTILITY_MODULE: &str = "utility";
    pub const TERNARY_MODULE: &str = "ternary";
    // Sparse-tensor modules, only built with the `sparse` feature.
    #[cfg(feature = "sparse")]
    pub const SCAN_MODULE: &str = "scan";
    #[cfg(feature = "sparse")]
    pub const SPARSE_SPMV_MODULE: &str = "sparse_spmv";
    #[cfg(feature = "sparse")]
    pub const SPARSE_MERGE_MODULE: &str = "sparse_merge";
    #[cfg(feature = "sparse")]
    pub const SPARSE_CONVERT_MODULE: &str = "sparse_convert";
    #[cfg(feature = "sparse")]
    pub const SPARSE_COO_MODULE: &str = "sparse_coo";
    #[cfg(feature = "sparse")]
    pub const DSMM_MODULE: &str = "dsmm";
    // Dense linear-algebra modules.
    pub const LINALG_BASIC_MODULE: &str = "linalg_basic";
    pub const LINALG_BANDED_MODULE: &str = "linalg_banded";
    pub const LINALG_SOLVERS_MODULE: &str = "linalg_solvers";
    pub const LINALG_DECOMP_MODULE: &str = "linalg_decomp";
    pub const LINALG_SVD_MODULE: &str = "linalg_svd";
    pub const LINALG_EIGEN_MODULE: &str = "linalg_eigen";
    pub const LINALG_SCHUR_MODULE: &str = "linalg_schur";
    pub const LINALG_EIGEN_GENERAL_MODULE: &str = "linalg_eigen_general";
    pub const LINALG_ADVANCED_MODULE: &str = "linalg_advanced";
    pub const LINALG_QZ_MODULE: &str = "linalg_qz";
    pub const LINALG_MATRIX_FUNCS_MODULE: &str = "linalg_matrix_funcs";
    pub const MATMUL_MODULE: &str = "matmul";
    pub const GEMV_MODULE: &str = "gemv";
    pub const CUMULATIVE_MODULE: &str = "cumulative";
    // Random-number and statistics modules.
    pub const DISTRIBUTIONS_MODULE: &str = "distributions";
    pub const QUASIRANDOM_MODULE: &str = "quasirandom";
    pub const ADVANCED_RANDOM_MODULE: &str = "advanced_random";
    pub const STATISTICS_MODULE: &str = "statistics";
    pub const SEMIRING_MATMUL_MODULE: &str = "semiring_matmul";
    /// Base name of the full-reduction kernel for `op`, e.g. `reduce_sum`.
    #[inline]
    pub fn reduce_kernel(op: &str) -> String {
        format!("reduce_{}", op)
    }
    /// Base name of the per-dimension reduction kernel for `op`,
    /// e.g. `reduce_sum_dim`.
    #[inline]
    pub fn reduce_dim_kernel(op: &str) -> String {
        format!("reduce_{}_dim", op)
    }
}
/// Maps a `DType` to the suffix used in compiled kernel symbol names
/// (see `kernel_name`, e.g. `add_f32`). Must stay in sync with the kernel
/// names declared in the .cu sources.
pub fn dtype_suffix(dtype: DType) -> &'static str {
    match dtype {
        DType::F32 => "f32",
        DType::F64 => "f64",
        DType::F16 => "f16",
        DType::BF16 => "bf16",
        DType::FP8E4M3 => "fp8_e4m3",
        DType::FP8E5M2 => "fp8_e5m2",
        DType::I64 => "i64",
        DType::I32 => "i32",
        DType::I16 => "i16",
        DType::I8 => "i8",
        DType::U64 => "u64",
        DType::U32 => "u32",
        DType::U16 => "u16",
        DType::U8 => "u8",
        DType::Bool => "bool",
        DType::Complex64 => "c64",
        DType::Complex128 => "c128",
    }
}
/// Builds the dtype-specialized kernel symbol, e.g. `("add", F32)` -> `add_f32`.
#[inline]
pub fn kernel_name(base: &str, dtype: DType) -> String {
    let suffix = dtype_suffix(dtype);
    format!("{base}_{suffix}")
}
/// Launches the elementwise unary kernel `<op>_<dtype>` from `module_name`
/// over `numel` contiguous elements.
///
/// # Safety
/// `input_ptr` and `output_ptr` must be valid device pointers on `stream`'s
/// device, each covering `numel` elements of `dtype`.
pub unsafe fn launch_unary_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    module_name: &'static str,
    op: &str,
    dtype: DType,
    input_ptr: u64,
    output_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, module_name)?;
    let func = get_kernel_function(&module, &kernel_name(op, dtype))?;
    let cfg = launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);
    let n = numel as u32;
    // SAFETY: pointer validity is the caller's obligation (see above); the
    // argument order (input, output, n) matches the kernel signature.
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&input_ptr);
        builder.arg(&output_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA {} kernel '{}' launch failed: {:?}",
                module_name, op, e
            ))
        })?;
    }
    Ok(())
}
/// Launches the elementwise binary kernel `<op>_<dtype>` from `module_name`
/// over `numel` contiguous elements.
///
/// # Safety
/// `a_ptr`, `b_ptr`, and `output_ptr` must be valid device pointers on
/// `stream`'s device, each covering `numel` elements of `dtype`.
pub unsafe fn launch_binary_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    module_name: &'static str,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    output_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, module_name)?;
    let func = get_kernel_function(&module, &kernel_name(op, dtype))?;
    let cfg = launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);
    let n = numel as u32;
    // SAFETY: pointer validity is the caller's obligation (see above); the
    // argument order (a, b, output, n) matches the kernel signature.
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&output_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA {} kernel '{}' launch failed: {:?}",
                module_name, op, e
            ))
        })?;
    }
    Ok(())
}
use crate::algorithm::TileConfig;
/// Builds the launch configuration for the tiled matmul kernels.
///
/// The grid covers the m x n output in `block_m` x `block_n` tiles; each
/// thread computes a `thread_m` x `thread_n` sub-tile, so a block is
/// `(block_n / thread_n)` x `(block_m / thread_m)` threads. Shared memory
/// holds one A tile (`block_m * block_k`) plus one B tile
/// (`block_k * block_n`) at `elem_size` bytes per element.
#[inline]
pub fn matmul_launch_config(
    m: usize,
    n: usize,
    cfg: &TileConfig,
    elem_size: usize,
) -> LaunchConfig {
    let bn = cfg.block_n as u32;
    let bm = cfg.block_m as u32;
    let grid_x = (n as u32 + bn - 1) / bn;
    let grid_y = (m as u32 + bm - 1) / bm;
    let block_x = (cfg.block_n / cfg.thread_n) as u32;
    let block_y = (cfg.block_m / cfg.thread_m) as u32;
    let tile_bytes = (cfg.block_m * cfg.block_k + cfg.block_k * cfg.block_n) * elem_size;
    LaunchConfig {
        grid_dim: (grid_x, grid_y, 1),
        block_dim: (block_x, block_y, 1),
        shared_mem_bytes: tile_bytes as u32,
    }
}
/// Batched variant of `matmul_launch_config`: identical x/y tiling, with one
/// grid layer (`grid_z`) per batch entry.
#[inline]
pub fn matmul_batched_launch_config(
    batch: usize,
    m: usize,
    n: usize,
    cfg: &TileConfig,
    elem_size: usize,
) -> LaunchConfig {
    let bn = cfg.block_n as u32;
    let bm = cfg.block_m as u32;
    let grid_x = (n as u32 + bn - 1) / bn;
    let grid_y = (m as u32 + bm - 1) / bm;
    let block_x = (cfg.block_n / cfg.thread_n) as u32;
    let block_y = (cfg.block_m / cfg.thread_m) as u32;
    let tile_bytes = (cfg.block_m * cfg.block_k + cfg.block_k * cfg.block_n) * elem_size;
    LaunchConfig {
        grid_dim: (grid_x, grid_y, batch as u32),
        block_dim: (block_x, block_y, 1),
        shared_mem_bytes: tile_bytes as u32,
    }
}
/// Default matmul tiling per dtype: f64 (8-byte elements) gets smaller
/// tiles; everything else uses the shared `TileConfig::CUDA` preset.
#[inline]
pub fn default_tile_config(dtype: DType) -> TileConfig {
    if matches!(dtype, DType::F64) {
        TileConfig {
            block_m: 64,
            block_n: 64,
            block_k: 8,
            thread_m: 4,
            thread_n: 4,
        }
    } else {
        TileConfig::CUDA
    }
}
/// Launches a single (unbatched) matmul with the default tile configuration
/// for `dtype`. Small `m` (<= 16) is dispatched to the GEMV kernel instead
/// of the tiled matmul.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    m: usize,
    n: usize,
    k: usize,
) -> Result<()> {
    if m <= 16 {
        // SAFETY: same pointer/shape contract as this function.
        unsafe {
            launch_gemv_kernel(
                context, stream, device_index, dtype, a_ptr, b_ptr, c_ptr, 1, m, n, k, 1, 1,
            )
        }
    } else {
        // SAFETY: same pointer/shape contract as this function.
        unsafe {
            launch_matmul_kernel_with_config(
                context,
                stream,
                device_index,
                dtype,
                a_ptr,
                b_ptr,
                c_ptr,
                m,
                n,
                k,
                &default_tile_config(dtype),
            )
        }
    }
}
/// Launches the GEMV kernel `gemv_<dtype>` with grid
/// `(ceil(n / 256), m, batch)` and 256 threads per block.
///
/// `a_batch` / `b_batch` are forwarded to the kernel as u32 scalars —
/// presumably per-operand batch strides or broadcast selectors; confirm
/// against the gemv .cu source.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, sized consistently with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_gemv_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?;
    let func_name = kernel_name("gemv", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let block_size: u32 = 256;
    // grid: x covers n in 256-wide chunks, y = m, z = batch.
    let grid_x = ((n as u32) + block_size - 1) / block_size;
    let grid_y = m as u32;
    let grid_z = batch as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, grid_y, grid_z),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };
    // Scalar kernel arguments are narrowed to u32.
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder
            .launch(cfg)
            .map_err(|e| Error::Internal(format!("CUDA GEMV kernel launch failed: {:?}", e)))?;
    }
    Ok(())
}
/// Launches the transposed-B GEMV kernel `gemv_bt_<dtype>`.
///
/// Uses 256-thread blocks treated as 8 warps, with `grid_x = ceil(n / 8)` —
/// presumably one warp per output element along n; confirm against the
/// gemv .cu source. `a_batch` / `b_batch` are forwarded as u32 scalars.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, sized consistently with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_gemv_kernel_bt(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?;
    let func_name = kernel_name("gemv_bt", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let warps_per_block: u32 = 8;
    let grid_x = ((n as u32) + warps_per_block - 1) / warps_per_block;
    let grid_y = m as u32;
    let grid_z = batch as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, grid_y, grid_z),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    // Scalar kernel arguments are narrowed to u32.
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder
            .launch(cfg)
            .map_err(|e| Error::Internal(format!("CUDA GEMV-BT kernel launch failed: {:?}", e)))?;
    }
    Ok(())
}
/// Launches the multi-row transposed-B GEMV kernel `gemv_bt_mr_<dtype>`.
///
/// Like `launch_gemv_kernel_bt`, but each of the 8 warps covers 2 outputs
/// along n (16 per block), so `grid_x = ceil(n / 16)`. Confirm the exact
/// per-warp work split against the gemv .cu source.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, sized consistently with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_gemv_kernel_bt_mr(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?;
    let func_name = kernel_name("gemv_bt_mr", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let warps_per_block: u32 = 8;
    let rows_per_warp: u32 = 2;
    let cols_per_block = warps_per_block * rows_per_warp;
    let grid_x = ((n as u32) + cols_per_block - 1) / cols_per_block;
    let grid_y = m as u32;
    let grid_z = batch as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, grid_y, grid_z),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    // Scalar kernel arguments are narrowed to u32.
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA GEMV-BT-MR kernel launch failed: {:?}", e))
        })?;
    }
    Ok(())
}
/// Launches the tiled matmul kernel `matmul_<dtype>` with an explicit tile
/// configuration; tile dimensions are forwarded to the kernel as u32s.
///
/// For f16/bf16 the shared-memory element size is forced to 4 bytes —
/// presumably the kernel stages/accumulates those dtypes in f32; confirm
/// against the matmul .cu source.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_kernel_with_config(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    m: usize,
    n: usize,
    k: usize,
    tile_cfg: &TileConfig,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?;
    let func_name = kernel_name("matmul", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let elem_size = dtype.size_in_bytes();
    // Half-precision inputs use 4-byte shared-memory slots (see doc above).
    let shared_elem_size = match dtype {
        DType::F16 | DType::BF16 => 4,
        _ => elem_size,
    };
    let cfg = matmul_launch_config(m, n, tile_cfg, shared_elem_size);
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let block_m = tile_cfg.block_m as u32;
    let block_n = tile_cfg.block_n as u32;
    let block_k = tile_cfg.block_k as u32;
    let thread_m = tile_cfg.thread_m as u32;
    let thread_n = tile_cfg.thread_n as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&block_m);
        builder.arg(&block_n);
        builder.arg(&block_k);
        builder.arg(&thread_m);
        builder.arg(&thread_n);
        builder
            .launch(cfg)
            .map_err(|e| Error::Internal(format!("CUDA matmul kernel launch failed: {:?}", e)))?;
    }
    Ok(())
}
/// Launches a batched matmul with the default tile configuration for
/// `dtype`. Small `m` (<= 16) is dispatched to the batched GEMV kernel
/// instead of the tiled matmul.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_batched_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    if m <= 16 {
        // SAFETY: same pointer/shape contract as this function.
        unsafe {
            launch_gemv_kernel(
                context, stream, device_index, dtype, a_ptr, b_ptr, c_ptr, batch, m, n, k,
                a_batch, b_batch,
            )
        }
    } else {
        // SAFETY: same pointer/shape contract as this function.
        unsafe {
            launch_matmul_batched_kernel_with_config(
                context,
                stream,
                device_index,
                dtype,
                a_ptr,
                b_ptr,
                c_ptr,
                batch,
                m,
                n,
                k,
                &default_tile_config(dtype),
                a_batch,
                b_batch,
            )
        }
    }
}
/// Launches the batched tiled matmul kernel `matmul_batched_<dtype>` with an
/// explicit tile configuration; one grid layer per batch entry.
///
/// `a_batch` / `b_batch` are forwarded as u32 scalars (presumably batch
/// strides or broadcast selectors — confirm against the matmul .cu source).
/// Half-precision dtypes use 4-byte shared-memory slots, as in
/// `launch_matmul_kernel_with_config`.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_batched_kernel_with_config(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    tile_cfg: &TileConfig,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?;
    let func_name = kernel_name("matmul_batched", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let elem_size = dtype.size_in_bytes();
    let shared_elem_size = match dtype {
        DType::F16 | DType::BF16 => 4,
        _ => elem_size,
    };
    let cfg = matmul_batched_launch_config(batch, m, n, tile_cfg, shared_elem_size);
    let batch_u32 = batch as u32;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let block_m = tile_cfg.block_m as u32;
    let block_n = tile_cfg.block_n as u32;
    let block_k = tile_cfg.block_k as u32;
    let thread_m = tile_cfg.thread_m as u32;
    let thread_n = tile_cfg.thread_n as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&batch_u32);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&block_m);
        builder.arg(&block_n);
        builder.arg(&block_k);
        builder.arg(&thread_m);
        builder.arg(&thread_n);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA batched matmul kernel launch failed: {:?}", e))
        })?;
    }
    Ok(())
}
/// Launches a fused matmul + bias with the default tile configuration for
/// `dtype`; thin wrapper over `launch_matmul_bias_kernel_with_config`.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `bias_ptr`, `c_ptr` must be valid device pointers on
/// `stream`'s device, consistent with `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_bias_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    bias_ptr: u64,
    c_ptr: u64,
    m: usize,
    n: usize,
    k: usize,
) -> Result<()> {
    unsafe {
        launch_matmul_bias_kernel_with_config(
            context,
            stream,
            device_index,
            dtype,
            a_ptr,
            b_ptr,
            bias_ptr,
            c_ptr,
            m,
            n,
            k,
            &default_tile_config(dtype),
        )
    }
}
/// Launches the fused matmul + bias kernel `matmul_bias_<dtype>` with an
/// explicit tile configuration. Same tiling and half-precision
/// shared-memory handling as `launch_matmul_kernel_with_config`, with an
/// extra bias pointer inserted before the output pointer.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `bias_ptr`, `c_ptr` must be valid device pointers on
/// `stream`'s device, consistent with `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_bias_kernel_with_config(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    bias_ptr: u64,
    c_ptr: u64,
    m: usize,
    n: usize,
    k: usize,
    tile_cfg: &TileConfig,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?;
    let func_name = kernel_name("matmul_bias", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let elem_size = dtype.size_in_bytes();
    // Half-precision inputs use 4-byte shared-memory slots.
    let shared_elem_size = match dtype {
        DType::F16 | DType::BF16 => 4,
        _ => elem_size,
    };
    let cfg = matmul_launch_config(m, n, tile_cfg, shared_elem_size);
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let block_m = tile_cfg.block_m as u32;
    let block_n = tile_cfg.block_n as u32;
    let block_k = tile_cfg.block_k as u32;
    let thread_m = tile_cfg.thread_m as u32;
    let thread_n = tile_cfg.thread_n as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&bias_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&block_m);
        builder.arg(&block_n);
        builder.arg(&block_k);
        builder.arg(&thread_m);
        builder.arg(&thread_n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA matmul_bias kernel launch failed: {:?}", e))
        })?;
    }
    Ok(())
}
/// Launches a batched fused matmul + bias with the default tile
/// configuration for `dtype`; thin wrapper over
/// `launch_matmul_bias_batched_kernel_with_config`.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `bias_ptr`, `c_ptr` must be valid device pointers on
/// `stream`'s device, consistent with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_bias_batched_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    bias_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    unsafe {
        launch_matmul_bias_batched_kernel_with_config(
            context,
            stream,
            device_index,
            dtype,
            a_ptr,
            b_ptr,
            bias_ptr,
            c_ptr,
            batch,
            m,
            n,
            k,
            &default_tile_config(dtype),
            a_batch,
            b_batch,
        )
    }
}
/// Launches the batched fused matmul + bias kernel
/// `matmul_bias_batched_<dtype>` with an explicit tile configuration; one
/// grid layer per batch entry, bias pointer inserted before the output.
/// `a_batch` / `b_batch` are forwarded as u32 scalars (presumably batch
/// strides or broadcast selectors — confirm against the matmul .cu source).
///
/// # Safety
/// `a_ptr`, `b_ptr`, `bias_ptr`, `c_ptr` must be valid device pointers on
/// `stream`'s device, consistent with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_matmul_bias_batched_kernel_with_config(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    bias_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    tile_cfg: &TileConfig,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?;
    let func_name = kernel_name("matmul_bias_batched", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let elem_size = dtype.size_in_bytes();
    // Half-precision inputs use 4-byte shared-memory slots.
    let shared_elem_size = match dtype {
        DType::F16 | DType::BF16 => 4,
        _ => elem_size,
    };
    let cfg = matmul_batched_launch_config(batch, m, n, tile_cfg, shared_elem_size);
    let batch_u32 = batch as u32;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let block_m = tile_cfg.block_m as u32;
    let block_n = tile_cfg.block_n as u32;
    let block_k = tile_cfg.block_k as u32;
    let thread_m = tile_cfg.thread_m as u32;
    let thread_n = tile_cfg.thread_n as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&bias_ptr);
        builder.arg(&c_ptr);
        builder.arg(&batch_u32);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&block_m);
        builder.arg(&block_n);
        builder.arg(&block_k);
        builder.arg(&thread_m);
        builder.arg(&thread_n);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA batched matmul_bias kernel launch failed: {:?}",
                e
            ))
        })?;
    }
    Ok(())
}
/// Launches the semiring matmul kernel `semiring_matmul_<dtype>` over a
/// 2-D grid of 16x16-thread blocks covering the m x n output.
///
/// `semiring_op` is passed straight through to the kernel; it selects which
/// semiring (replacement for the usual (+, *) pair) the kernel applies —
/// see the semiring_matmul .cu source for the encoding.
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_semiring_matmul_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    m: usize,
    n: usize,
    k: usize,
    semiring_op: u32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::SEMIRING_MATMUL_MODULE)?;
    let func_name = kernel_name("semiring_matmul", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // One thread per output element, in 16x16 tiles.
    let block_x = 16u32;
    let block_y = 16u32;
    let grid_x = (n as u32 + block_x - 1) / block_x;
    let grid_y = (m as u32 + block_y - 1) / block_y;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, grid_y, 1),
        block_dim: (block_x, block_y, 1),
        shared_mem_bytes: 0,
    };
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&semiring_op);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA semiring matmul kernel launch failed: {:?}",
                e
            ))
        })?;
    }
    Ok(())
}
/// Batched variant of `launch_semiring_matmul_kernel`: same 16x16 tiling
/// with one grid layer per batch entry. `a_batch` / `b_batch` are forwarded
/// as u32 scalars (presumably batch strides or broadcast selectors —
/// confirm against the semiring_matmul .cu source).
///
/// # Safety
/// `a_ptr`, `b_ptr`, `c_ptr` must be valid device pointers on `stream`'s
/// device, consistent with `batch`, `m`, `n`, `k` and `dtype`.
pub unsafe fn launch_semiring_matmul_batched_kernel(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    c_ptr: u64,
    batch: usize,
    m: usize,
    n: usize,
    k: usize,
    semiring_op: u32,
    a_batch: usize,
    b_batch: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::SEMIRING_MATMUL_MODULE)?;
    let func_name = kernel_name("semiring_matmul_batched", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let block_x = 16u32;
    let block_y = 16u32;
    let grid_x = (n as u32 + block_x - 1) / block_x;
    let grid_y = (m as u32 + block_y - 1) / block_y;
    let grid_z = batch as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_x, grid_y, grid_z),
        block_dim: (block_x, block_y, 1),
        shared_mem_bytes: 0,
    };
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let k_u32 = k as u32;
    let batch_u32 = batch as u32;
    let a_batch_u32 = a_batch as u32;
    let b_batch_u32 = b_batch as u32;
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&c_ptr);
        builder.arg(&m_u32);
        builder.arg(&n_u32);
        builder.arg(&k_u32);
        builder.arg(&semiring_op);
        builder.arg(&batch_u32);
        builder.arg(&a_batch_u32);
        builder.arg(&b_batch_u32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA batched semiring matmul kernel launch failed: {:?}",
                e
            ))
        })?;
    }
    Ok(())
}