pub use rlx_gpu_kernels::*;
use std::sync::Arc;
use std::sync::OnceLock;
use cudarc::driver::{CudaContext, CudaFunction, CudaModule};
/// A kernel produced by `compile`: the loaded CUDA module together with one
/// function resolved from it by entry-point name.
pub struct CudaKernel {
/// Module the function was resolved from. Holding the `Arc` here keeps the
/// module alive for as long as this `CudaKernel` exists.
pub module: Arc<CudaModule>,
/// Function handle obtained via `module.load_function(entry)`.
pub function: CudaFunction,
}
/// Returns the directory used to cache NVRTC-compiled PTX. Returns `None` when
/// no location can be determined, in which case the caller compiles without
/// caching.
///
/// Resolution order:
/// 1. `RLX_CUDA_PTX_CACHE` (read through `rlx_ir::env::var`), used as-is when
///    non-empty.
/// 2. `$XDG_CACHE_HOME/rlx-cuda/ptx-cuda-12060`, when `XDG_CACHE_HOME` is an
///    absolute path. The XDG Base Directory spec requires empty or relative
///    values to be ignored.
/// 3. `$HOME/.cache/rlx-cuda/ptx-cuda-12060`, when `HOME` is non-empty.
///
/// Empty values are skipped because joining a file name onto an empty path
/// yields a relative path. That would silently scatter cache files into the
/// current working directory.
fn ptx_cache_dir() -> Option<std::path::PathBuf> {
    use std::path::PathBuf;

    if let Some(p) = rlx_ir::env::var("RLX_CUDA_PTX_CACHE") {
        if !p.is_empty() {
            return Some(PathBuf::from(p));
        }
    }

    // `var_os` also accepts non-UTF-8 paths; `filter` rejects "" and relative values.
    let xdg_cache = std::env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|p| p.is_absolute());
    let base = xdg_cache.or_else(|| {
        std::env::var_os("HOME")
            .filter(|h| !h.is_empty())
            .map(|h| PathBuf::from(h).join(".cache"))
    })?;

    // NOTE(review): the `ptx-cuda-12060` leaf presumably tags the targeted CUDA
    // version so PTX from another toolchain is never reused; confirm it tracks
    // the cudarc CUDA feature in use.
    Some(base.join("rlx-cuda").join("ptx-cuda-12060"))
}
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// In this file it only derives a stable PTX cache filename from kernel
/// source text. It is not a cryptographic hash.
fn fnv1a64(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes()
        .fold(OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
fn compile(ctx: &Arc<CudaContext>, src: &str, entry: &str) -> CudaKernel {
let cache_path =
ptx_cache_dir().map(|d| d.join(format!("{}-{:016x}.ptx", entry, fnv1a64(src))));
let ptx = if let Some(ref p) = cache_path {
if let Ok(cached) = std::fs::read_to_string(p) {
cudarc::nvrtc::Ptx::from_src(cached)
} else {
let fresh = cudarc::nvrtc::compile_ptx(src)
.unwrap_or_else(|e| panic!("rlx-cuda: NVRTC compile failed for {entry}: {e}"));
if let Some(dir) = p.parent() {
let _ = std::fs::create_dir_all(dir);
}
let tmp = p.with_extension("ptx.tmp");
if std::fs::write(&tmp, fresh.to_src()).is_ok() {
let _ = std::fs::rename(&tmp, p);
}
fresh
}
} else {
cudarc::nvrtc::compile_ptx(src)
.unwrap_or_else(|e| panic!("rlx-cuda: NVRTC compile failed for {entry}: {e}"))
};
let module = ctx
.load_module(ptx)
.unwrap_or_else(|e| panic!("rlx-cuda: load_module failed for {entry}: {e}"));
let function = module
.load_function(entry)
.unwrap_or_else(|e| panic!("rlx-cuda: load_function {entry}: {e}"));
CudaKernel { module, function }
}
/// Declares a lazily compiled, process-wide kernel.
///
/// `kernel_cache!(STATIC, accessor, SRC, "entry")` expands to two items:
/// - a private `static STATIC: OnceLock<CudaKernel>`;
/// - `pub fn accessor(ctx) -> &'static CudaKernel`, which runs
///   `compile(ctx, SRC, "entry")` the first time it is called and returns the
///   stored kernel on every later call.
///
/// NOTE(review): `ctx` is only used by the call that initialises the
/// `OnceLock`. Later calls return the kernel loaded into that first context,
/// even if a different `ctx` is passed. Confirm that callers only ever use a
/// single CUDA context per process.
macro_rules! kernel_cache {
($static_name:ident, $fn_name:ident, $src:expr, $entry:expr) => {
static $static_name: OnceLock<CudaKernel> = OnceLock::new();
/// Returns the cached kernel, compiling and loading it on first use.
pub fn $fn_name(ctx: &Arc<CudaContext>) -> &'static CudaKernel {
$static_name.get_or_init(|| compile(ctx, $src, $entry))
}
};
}
// Kernel registry. Each `kernel_cache!` invocation declares one private static
// and one public `*_kernel(ctx)` accessor for a (source, entry-point) pair.
// The `*_CU` source constants are not defined in this file; they presumably
// come from the `rlx_gpu_kernels` glob import.
kernel_cache!(BINARY, binary_kernel, BINARY_CU, "binary");
kernel_cache!(
FUSED_BINARY_UNARY,
fused_binary_unary_kernel,
FUSED_BINARY_UNARY_CU,
"fused_binary_unary"
);
kernel_cache!(
CAST_F32_TO_HALF,
cast_f32_to_half_kernel,
CAST_F32_TO_HALF_CU,
"cast_f32_to_half"
);
kernel_cache!(UNARY, unary_kernel, UNARY_CU, "unary");
kernel_cache!(COPY, copy_kernel, COPY_CU, "copy");
kernel_cache!(MATMUL, matmul_kernel, MATMUL_CU, "matmul");
kernel_cache!(
MATMUL_EPILOGUE,
matmul_epilogue_kernel,
MATMUL_EPILOGUE_CU,
"matmul_epilogue"
);
kernel_cache!(
MATMUL_WMMA,
matmul_wmma_kernel,
MATMUL_WMMA_CU,
"matmul_wmma"
);
kernel_cache!(COMPARE, compare_kernel, COMPARE_CU, "compare");
// The entry symbol ("where_select") differs from the accessor name.
kernel_cache!(WHEREK, where_kernel, WHERE_CU, "where_select");
kernel_cache!(REDUCE, reduce_kernel, REDUCE_CU, "reduce");
kernel_cache!(SOFTMAX, softmax_kernel, SOFTMAX_CU, "softmax");
// The entry symbol ("rlx_norm") differs from the accessor name.
kernel_cache!(LAYERNORM, layernorm_kernel, LAYERNORM_CU, "rlx_norm");
// The next two entries share RMS_NORM_BWD_CU. The PTX cache filename includes
// the entry name, so each accessor compiles and caches that source separately.
kernel_cache!(
RMS_NORM_BWD,
rms_norm_backward_kernel,
RMS_NORM_BWD_CU,
"rlx_rms_norm_bwd"
);
kernel_cache!(
RMS_NORM_BWD_ZERO,
rms_norm_bwd_zero_kernel,
RMS_NORM_BWD_CU,
"rlx_zero_f32"
);
kernel_cache!(
CUMSUM_BWD,
cumsum_backward_kernel,
CUMSUM_BWD_CU,
"rlx_cumsum_bwd"
);
kernel_cache!(ROPE_BWD, rope_backward_kernel, ROPE_BWD_CU, "rlx_rope_bwd");
kernel_cache!(
GATHER_BWD,
gather_backward_kernel,
GATHER_BWD_CU,
"rlx_gather_axis_bwd"
);
kernel_cache!(
FUSED_RESIDUAL_LN,
fused_residual_ln_kernel,
FUSED_RESIDUAL_LN_CU,
"fused_residual_ln"
);
kernel_cache!(GATHER, gather_kernel, GATHER_CU, "gather");
kernel_cache!(
GATHER_AXIS,
gather_axis_kernel,
GATHER_AXIS_CU,
"gather_axis"
);
kernel_cache!(NARROW, narrow_kernel, NARROW_CU, "narrow");
kernel_cache!(CONCAT, concat_kernel, CONCAT_CU, "concat");
kernel_cache!(TRANSPOSE, transpose_kernel, TRANSPOSE_CU, "transpose");
kernel_cache!(EXPAND, expand_kernel, EXPAND_CU, "expand");
kernel_cache!(ATTENTION, attention_kernel, ATTENTION_CU, "attention");
kernel_cache!(
ATTENTION_BWD,
attention_bwd_kernel,
ATTENTION_BWD_CU,
"attention_bwd"
);
kernel_cache!(ARGMAX, argmax_kernel, ARGMAX_CU, "argmax");
kernel_cache!(ROPE, rope_kernel, ROPE_CU, "rope");
kernel_cache!(CUMSUM, cumsum_kernel, CUMSUM_CU, "cumsum");
kernel_cache!(TOPK, topk_kernel, TOPK_CU, "topk");
kernel_cache!(
GROUPED_MATMUL,
grouped_matmul_kernel,
GROUPED_MATMUL_CU,
"grouped_matmul"
);
// The next two entries share SCATTER_ADD_CU. Each accessor compiles and
// caches it separately, because the cache filename includes the entry name.
kernel_cache!(
SCATTER_ADD_ZERO,
scatter_add_zero_kernel,
SCATTER_ADD_CU,
"scatter_add_zero"
);
kernel_cache!(
SCATTER_ADD_ACC,
scatter_add_acc_kernel,
SCATTER_ADD_CU,
"scatter_add_acc"
);
kernel_cache!(
DEQUANT_MATMUL,
dequant_matmul_kernel,
DEQUANT_MATMUL_CU,
"dequant_matmul"
);
kernel_cache!(
DEQUANT_GGUF,
dequant_gguf_kernel,
DEQUANT_GGUF_CU,
"dequant_gguf"
);
kernel_cache!(SAMPLE, sample_kernel, SAMPLE_CU, "sample");
kernel_cache!(
SELECTIVE_SCAN,
selective_scan_kernel,
SELECTIVE_SCAN_CU,
"selective_scan"
);
kernel_cache!(POOL1D, pool1d_kernel, POOL1D_CU, "pool1d");
kernel_cache!(POOL2D, pool2d_kernel, POOL2D_CU, "pool2d");
kernel_cache!(POOL3D, pool3d_kernel, POOL3D_CU, "pool3d");
kernel_cache!(CONV1D, conv1d_kernel, CONV1D_CU, "conv1d");
kernel_cache!(CONV2D, conv2d_kernel, CONV2D_CU, "conv2d");
kernel_cache!(CONV3D, conv3d_kernel, CONV3D_CU, "conv3d");
kernel_cache!(
LAYER_NORM2D,
layer_norm2d_kernel,
LAYER_NORM2D_CU,
"layer_norm2d"
);
kernel_cache!(
CONV_TRANSPOSE2D,
conv_transpose2d_kernel,
CONV_TRANSPOSE2D_CU,
"conv_transpose2d"
);
kernel_cache!(GROUP_NORM, group_norm_kernel, GROUP_NORM_CU, "group_norm");
kernel_cache!(
RESIZE_NEAREST_2X,
resize_nearest_2x_kernel,
RESIZE_NEAREST_2X_CU,
"resize_nearest_2x"
);
kernel_cache!(
ELEMENTWISE_REGION,
elementwise_region_kernel,
ELEMENTWISE_REGION_CU,
"elementwise_region"
);
kernel_cache!(
GAUSSIAN_SPLAT_RASTERIZE,
gaussian_splat_rasterize_kernel,
GAUSSIAN_SPLAT_RASTERIZE_CU,
"gaussian_splat_rasterize"
);
// All FFT entry points live in the single source FFT_CU. Each accessor below
// has its own OnceLock and its own PTX cache file (keyed by entry name). On a
// cold cache, FFT_CU is therefore compiled by NVRTC once per entry used.
kernel_cache!(
FFT_RADIX2_FULL,
fft_radix2_full_kernel,
FFT_CU,
"fft_radix2_full"
);
kernel_cache!(
FFT_BIT_REVERSE,
fft_bit_reverse_kernel,
FFT_CU,
"fft_bit_reverse"
);
kernel_cache!(FFT_INNER, fft_inner_kernel, FFT_CU, "fft_inner");
kernel_cache!(FFT_OUTER_R4, fft_outer_r4_kernel, FFT_CU, "fft_outer_r4");
kernel_cache!(FFT_OUTER_R2, fft_outer_r2_kernel, FFT_CU, "fft_outer_r2");
/// Computes a 1-D launch shape that covers `n` work items with blocks of
/// `block_x` threads.
///
/// Returns `(grid_x, block_x)`, where `grid_x = ceil(n / block_x)`. The last
/// block may therefore contain threads past `n`.
///
/// NOTE: `n == 0` yields `grid_x == 0`, which is not a valid CUDA launch
/// configuration. Callers should skip the launch in that case.
///
/// # Panics
/// Panics if `block_x` is zero (division by zero in `div_ceil`).
pub fn dispatch_grid_1d(n: u32, block_x: u32) -> (u32, u32) {
    let grid_x = n.div_ceil(block_x);
    (grid_x, block_x)
}
/// Computes a 2-D launch shape that covers a `width` x `height` domain with
/// `block_x` x `block_y` thread blocks.
///
/// Returns `(grid, block)` as `(x, y, z)` triples. Each grid dimension is
/// `ceil(extent / block)`, and both z components are 1.
///
/// NOTE: a zero `width` or `height` yields a zero grid dimension, which is not
/// a valid CUDA launch configuration. Callers should skip the launch then.
///
/// # Panics
/// Panics if `block_x` or `block_y` is zero (division by zero in `div_ceil`).
pub fn dispatch_grid_2d(
    width: u32,
    height: u32,
    block_x: u32,
    block_y: u32,
) -> ((u32, u32, u32), (u32, u32, u32)) {
    let grid_x = width.div_ceil(block_x);
    let grid_y = height.div_ceil(block_y);
    let grid = (grid_x, grid_y, 1);
    let block = (block_x, block_y, 1);
    (grid, block)
}