#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::hash::{Hash, Hasher};
#[cfg(feature = "cuda")]
use std::sync::{Arc, LazyLock, Mutex};
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaContext, CudaFunction, DriverError};
#[cfg(feature = "cuda")]
use cudarc::nvrtc::Ptx;
// Process-wide cache of compiled kernels whose PTX lives in static storage.
// Keyed by (kernel name, device ordinal); guarded by a Mutex and populated
// lazily on first use. Entries are never evicted — they live for the life of
// the process.
#[cfg(feature = "cuda")]
static MODULE_CACHE: LazyLock<Mutex<HashMap<(&'static str, u32), CudaFunction>>> =
LazyLock::new(|| Mutex::new(HashMap::new()));
// Companion cache for kernels built from owned (runtime-generated) PTX
// strings. Keyed by (hash of the source, device ordinal) because the source
// has no stable address to key on. Same lifetime/locking discipline as above.
#[cfg(feature = "cuda")]
static OWNED_MODULE_CACHE: LazyLock<Mutex<HashMap<(u64, u32), CudaFunction>>> =
LazyLock::new(|| Mutex::new(HashMap::new()));
#[cfg(feature = "cuda")]
/// Returns the `CudaFunction` for `kernel_name` on device `device_ordinal`,
/// compiling `ptx_src` and populating [`MODULE_CACHE`] on the first request;
/// later requests clone the cached handle without recompiling.
///
/// NOTE: the cache key is `(kernel_name, device_ordinal)` only — the PTX
/// source is not part of the key, so callers must use a distinct kernel name
/// per distinct PTX source.
///
/// # Errors
/// Propagates any `DriverError` from module loading or function lookup.
pub fn get_or_compile(
    ctx: &Arc<CudaContext>,
    ptx_src: &'static str,
    kernel_name: &'static str,
    device_ordinal: u32,
) -> Result<CudaFunction, DriverError> {
    let cache_key = (kernel_name, device_ordinal);
    let mut guard = MODULE_CACHE.lock().unwrap();
    match guard.get(&cache_key) {
        // Fast path: this kernel was already compiled for this device.
        Some(cached) => Ok(cached.clone()),
        None => {
            // Slow path: compile while still holding the lock so concurrent
            // callers never build the same module twice.
            let module = ctx.load_module(Ptx::from_src(ptx_src))?;
            let compiled = module.load_function(kernel_name)?;
            guard.insert(cache_key, compiled.clone());
            Ok(compiled)
        }
    }
}
#[cfg(feature = "cuda")]
/// Returns the `CudaFunction` for `kernel_name` compiled from the owned
/// (runtime-generated) PTX string `ptx_src` on device `device_ordinal`,
/// populating [`OWNED_MODULE_CACHE`] on the first request.
///
/// # Errors
/// Propagates any `DriverError` from module loading or function lookup.
pub fn get_or_compile_owned(
    ctx: &Arc<CudaContext>,
    ptx_src: String,
    kernel_name: String,
    device_ordinal: u32,
) -> Result<CudaFunction, DriverError> {
    // Key on a hash of BOTH the PTX source and the kernel name: a single PTX
    // module may export several kernels, so hashing the source alone would
    // make two different kernels from the same module collide in the cache
    // and hand back the wrong CudaFunction on the second request.
    // (DefaultHasher is deterministic within one process run, which is all
    // this in-memory cache needs; a 64-bit hash collision is accepted risk.)
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    ptx_src.hash(&mut hasher);
    kernel_name.hash(&mut hasher);
    let key = (hasher.finish(), device_ordinal);
    let mut cache = OWNED_MODULE_CACHE.lock().unwrap();
    if let Some(func) = cache.get(&key) {
        return Ok(func.clone());
    }
    // Leak the strings to obtain `&'static str` for the load calls — this is
    // a deliberate, bounded leak: it happens at most once per distinct
    // (ptx, kernel) pair, since subsequent calls hit the cache above.
    // NOTE(review): if the cudarc version in use accepts non-'static inputs
    // here, the leaks could be dropped entirely — confirm against the API.
    let leaked_ptx: &'static str = Box::leak(ptx_src.into_boxed_str());
    let leaked_name: &'static str = Box::leak(kernel_name.into_boxed_str());
    let module = ctx.load_module(Ptx::from_src(leaked_ptx))?;
    let func = module.load_function(leaked_name)?;
    cache.insert(key, func.clone());
    Ok(func)
}
// Integration tests for the module cache. All of them require a working CUDA
// device 0 at runtime (they are compiled only under the "cuda" feature).
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
use crate::device::GpuDevice;
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};
// First call compiles gpu_add's kernel; second call must be served from the
// cache and still produce the same (correct) elementwise sum.
#[test]
fn cache_returns_function_on_repeated_calls() {
let dev = crate::device::GpuDevice::new(0).expect("CUDA device 0");
let a = crate::transfer::cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).expect("a");
let b = crate::transfer::cpu_to_gpu(&[4.0f32, 5.0, 6.0], &dev).expect("b");
let r1 = crate::kernels::gpu_add(&a, &b, &dev).expect("first add (compiles)");
let r2 = crate::kernels::gpu_add(&a, &b, &dev).expect("second add (cached)");
let h1 = crate::transfer::gpu_to_cpu(&r1, &dev).expect("r1");
let h2 = crate::transfer::gpu_to_cpu(&r2, &dev).expect("r2");
assert_eq!(h1, h2, "cached kernel should produce identical results");
assert_eq!(h1, vec![5.0, 7.0, 9.0]);
}
// Element-by-element check that the cached kernel matches a host-computed
// reference on both the compiling call and the cached call.
#[test]
fn cached_kernel_produces_correct_results() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
// Host-side reference result for the elementwise add.
let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
let a = cpu_to_gpu(&a_data, &dev).expect("a to gpu");
let b = cpu_to_gpu(&b_data, &dev).expect("b to gpu");
let out1 = crate::kernels::gpu_add(&a, &b, &dev).expect("gpu_add 1st");
let host1 = gpu_to_cpu(&out1, &dev).expect("gpu_to_cpu 1st");
let out2 = crate::kernels::gpu_add(&a, &b, &dev).expect("gpu_add 2nd");
let host2 = gpu_to_cpu(&out2, &dev).expect("gpu_to_cpu 2nd");
for (i, ((&g1, &g2), &e)) in host1
.iter()
.zip(host2.iter())
.zip(expected.iter())
.enumerate()
{
assert!(
(g1 - e).abs() < 1e-6,
"1st call: element {i}: got {g1}, expected {e}",
);
assert!(
(g2 - e).abs() < 1e-6,
"2nd call: element {i}: got {g2}, expected {e}",
);
}
}
// Timing smoke test: reports first-call (compile) vs second-call (cached)
// latency for gpu_mul. NOTE(review): this test only prints the timings and
// asserts nothing about them — intentionally non-flaky, but it cannot fail
// on a cache regression.
#[test]
fn cached_kernel_second_call_is_fast() {
use std::time::Instant;
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a_data = vec![1.0f32; 1024];
let b_data = vec![2.0f32; 1024];
let a = cpu_to_gpu(&a_data, &dev).expect("a to gpu");
let b = cpu_to_gpu(&b_data, &dev).expect("b to gpu");
// Warm-up with a different kernel — presumably to pay one-time context
// setup costs before timing; gpu_mul itself still compiles inside the
// first timed section below. TODO confirm this is the intent.
let _ = crate::kernels::gpu_neg(&a, &dev);
let t1 = Instant::now();
let _ = crate::kernels::gpu_mul(&a, &b, &dev).expect("gpu_mul 1st");
let d1 = t1.elapsed();
let t2 = Instant::now();
let _ = crate::kernels::gpu_mul(&a, &b, &dev).expect("gpu_mul 2nd");
let d2 = t2.elapsed();
eprintln!(
"module_cache timing: 1st call = {:?}, 2nd call = {:?}",
d1, d2,
);
}
// Two calls with the same owned PTX + kernel name must resolve to the same
// cached function (compared via Debug formatting, the only comparison the
// handle type exposes here).
#[test]
fn get_or_compile_owned_returns_same_function_on_repeated_calls() {
// Hand-written identity kernel: out[i] = in[i] for i < n.
let ptx = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry id_kernel_owned_cache_test(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va;
.reg .pred %p;
ld.param.u64 %a, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
st.global.f32 [%out], %va;
DONE:
ret;
}
"
.to_string();
let dev = crate::device::GpuDevice::new(0).expect("CUDA device 0");
let ctx = dev.context();
let f1 = super::get_or_compile_owned(
ctx,
ptx.clone(),
"id_kernel_owned_cache_test".to_string(),
dev.ordinal() as u32,
)
.expect("first compile");
let f2 = super::get_or_compile_owned(
ctx,
ptx.clone(),
"id_kernel_owned_cache_test".to_string(),
dev.ordinal() as u32,
)
.expect("second (cached) compile");
assert_eq!(format!("{f1:?}"), format!("{f2:?}"));
}
// Different PTX sources (and kernel names) must map to different cache
// entries and therefore different functions.
#[test]
fn get_or_compile_owned_different_ptx_returns_different_function() {
// Identity kernel (same body shape as the test above, different name).
let ptx_a = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry id_kernel_owned_diff_a(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va;
.reg .pred %p;
ld.param.u64 %a, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
st.global.f32 [%out], %va;
DONE:
ret;
}
"
.to_string();
// Negation kernel: differs from ptx_a by one `neg.f32` instruction.
let ptx_b = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry id_kernel_owned_diff_b(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va;
.reg .pred %p;
ld.param.u64 %a, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
neg.f32 %va, %va;
st.global.f32 [%out], %va;
DONE:
ret;
}
"
.to_string();
let dev = crate::device::GpuDevice::new(0).expect("CUDA device 0");
let ctx = dev.context();
let f_a = super::get_or_compile_owned(
ctx,
ptx_a,
"id_kernel_owned_diff_a".to_string(),
dev.ordinal() as u32,
)
.expect("compile a");
let f_b = super::get_or_compile_owned(
ctx,
ptx_b,
"id_kernel_owned_diff_b".to_string(),
dev.ordinal() as u32,
)
.expect("compile b");
assert_ne!(format!("{f_a:?}"), format!("{f_b:?}"));
}
// Sanity check that the project's broadcast-div PTX constant loads as a
// module at all (bypasses the cache on purpose).
#[test]
fn broadcast_div_kernel_ptx_loads() {
let ctx = cudarc::driver::CudaContext::new(0).expect("CUDA device 0");
let _module = ctx
.load_module(cudarc::nvrtc::Ptx::from_src(
crate::kernels::BROADCAST_DIV_PTX,
))
.expect("BROADCAST_DIV_PTX must compile");
}
}