Skip to main content

ferrotorch_gpu/
module_cache.rs

1//! Global cache for compiled CUDA modules and kernel functions.
2//!
3//! Without caching, every call to a GPU kernel (e.g. [`gpu_add`], [`gpu_conv2d_f32`],
4//! [`gpu_flash_attention_f32`]) recompiles PTX source into a CUBIN via
5//! `CudaContext::load_module(Ptx::from_src(...))`.  This compilation takes
6//! ~1700 us per call -- far longer than the actual kernel execution.
7//!
//! This module provides [`get_or_compile`], which compiles the PTX only on
//! first use and returns a cached [`CudaFunction`] on subsequent calls.  The
//! cache is keyed by `(kernel name, device ordinal)`.  Kernel names are the
//! static entry-point strings.  Because the PTX source itself is not part of
//! the key, each name must be unique across every PTX module in this crate.
12//!
13//! # Thread safety
14//!
15//! The cache uses a global [`Mutex`]-protected [`HashMap`].  The critical
16//! section is short (a hash lookup + optional insert), so contention is
17//! negligible in practice.
18//!
19//! [`gpu_add`]: crate::kernels::gpu_add
20//! [`gpu_conv2d_f32`]: crate::conv::gpu_conv2d_f32
21//! [`gpu_flash_attention_f32`]: crate::flash_attention::gpu_flash_attention_f32
22
23#[cfg(feature = "cuda")]
24use std::collections::HashMap;
25#[cfg(feature = "cuda")]
26use std::sync::{Arc, LazyLock, Mutex};
27
28#[cfg(feature = "cuda")]
29use cudarc::driver::{CudaContext, CudaFunction, DriverError};
30#[cfg(feature = "cuda")]
31use cudarc::nvrtc::Ptx;
32
/// Process-wide table of compiled kernels, indexed by
/// `(kernel entry-point name, CUDA device ordinal)`.
///
/// The ordinal is part of the key because a [`CudaFunction`] loaded for
/// device 0 cannot be launched on device 1.  Values are the functions
/// returned by `CudaModule::load_function`.  Nothing in this module removes
/// entries once they are inserted.
#[cfg(feature = "cuda")]
static MODULE_CACHE: LazyLock<Mutex<HashMap<(&'static str, u32), CudaFunction>>> =
    LazyLock::new(Mutex::default);
42
/// Get a compiled kernel function, compiling the PTX only on first use.
///
/// On a cache hit this is a hash lookup plus a [`CudaFunction`] clone.  On
/// the first call for a given `(kernel_name, device_ordinal)` pair,
/// `ptx_src` is loaded into a new CUDA module, the named entry point is
/// extracted, and the result is stored in the global cache.  Later calls
/// skip the ~1700 us PTX compilation.
///
/// # Concurrency
///
/// The cache lock is *not* held while PTX is being compiled.  A thread
/// requesting an already-cached kernel therefore never waits behind another
/// thread's compilation.
///
/// If several threads miss on the same key at once, each compiles.  Only the
/// first result inserted is kept, and every caller (including the racers)
/// receives that same cached function.
///
/// # Arguments
///
/// - `ctx`            -- CUDA context (from `device.context()`).
/// - `ptx_src`        -- PTX source string (a `&'static str` constant).  It is
///   *not* part of the cache key, so `kernel_name` must uniquely identify it.
/// - `kernel_name`    -- entry-point name inside the PTX module.
/// - `device_ordinal` -- CUDA device ordinal that `ctx` belongs to (so kernels
///   compiled for device 0 are not reused on device 1).  The caller must pass
///   the ordinal matching `ctx`; it is not checked here.
///
/// # Errors
///
/// Returns [`DriverError`] if PTX compilation or function lookup fails.
/// Failures are not cached, so the next call for the same key retries.
#[cfg(feature = "cuda")]
pub fn get_or_compile(
    ctx: &Arc<CudaContext>,
    ptx_src: &'static str,
    kernel_name: &'static str,
    device_ordinal: u32,
) -> Result<CudaFunction, DriverError> {
    let key = (kernel_name, device_ordinal);

    // Fast path.  The guard is confined to this block so the lock is released
    // before any compilation below.
    //
    // Poisoning is recovered rather than propagated: the map is only ever
    // mutated by inserting complete values, so it stays consistent even if a
    // previous holder panicked.
    let cached = {
        let cache = MODULE_CACHE
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        cache.get(&key).cloned()
    };
    if let Some(func) = cached {
        return Ok(func);
    }

    // Slow path: compile without holding the lock.  Errors return before
    // anything is inserted, so failures are never cached.
    let module = ctx.load_module(Ptx::from_src(ptx_src))?;
    let func = module.load_function(kernel_name)?;

    // Another thread may have inserted this key while we were compiling.
    // Keep whichever entry landed first so all callers share one function.
    let mut cache = MODULE_CACHE
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    Ok(cache.entry(key).or_insert(func).clone())
}
79
80
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::{get_or_compile, MODULE_CACHE};
    use crate::device::GpuDevice;
    use crate::kernels::{gpu_add, gpu_mul, gpu_neg};
    use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

    /// Entry-point name of [`NOOP_PTX`].  Only this module's tests use it, so
    /// no other code can have populated its cache entry beforehand.
    const NOOP_KERNEL: &str = "module_cache_test_noop";

    /// Minimal PTX module containing one parameterless kernel that returns
    /// immediately.  It is used to exercise the cache directly, independent of
    /// the crate's real kernels.
    const NOOP_PTX: &str = "\
.version 6.0
.target sm_50
.address_size 64

.visible .entry module_cache_test_noop()
{
    ret;
}
";

    #[test]
    fn get_or_compile_caches_by_name_and_ordinal() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let key = (NOOP_KERNEL, 0u32);

        let _compiled = get_or_compile(dev.context(), NOOP_PTX, NOOP_KERNEL, 0)
            .expect("first call compiles the PTX");
        assert!(
            MODULE_CACHE.lock().unwrap().contains_key(&key),
            "a successful compile must be recorded in the cache",
        );

        // A cache hit never looks at `ptx_src`, so passing invalid PTX must
        // still succeed once the key is cached.  Without the cache this call
        // would fail with a `DriverError` from the PTX compiler.
        let _cached = get_or_compile(dev.context(), "not valid ptx", NOOP_KERNEL, 0)
            .expect("second call is served from the cache");
    }

    #[test]
    fn cache_returns_function_on_repeated_calls() {
        // Two consecutive `gpu_add` calls must both succeed and agree,
        // whichever of them (if either) performed the compilation.
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).expect("a");
        let b = cpu_to_gpu(&[4.0f32, 5.0, 6.0], &dev).expect("b");

        let r1 = gpu_add(&a, &b, &dev).expect("first add");
        let r2 = gpu_add(&a, &b, &dev).expect("second add");

        let h1 = gpu_to_cpu(&r1, &dev).expect("r1");
        let h2 = gpu_to_cpu(&r2, &dev).expect("r2");
        assert_eq!(h1, h2, "cached kernel should produce identical results");
        assert_eq!(h1, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn cached_kernel_produces_correct_results() {
        // Run gpu_add twice and verify both produce correct results,
        // confirming the cached kernel is functional.
        let dev = GpuDevice::new(0).expect("CUDA device 0");

        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();

        let a = cpu_to_gpu(&a_data, &dev).expect("a to gpu");
        let b = cpu_to_gpu(&b_data, &dev).expect("b to gpu");

        let out1 = gpu_add(&a, &b, &dev).expect("gpu_add 1st");
        let host1 = gpu_to_cpu(&out1, &dev).expect("gpu_to_cpu 1st");

        let out2 = gpu_add(&a, &b, &dev).expect("gpu_add 2nd");
        let host2 = gpu_to_cpu(&out2, &dev).expect("gpu_to_cpu 2nd");

        for (i, ((&g1, &g2), &e)) in host1
            .iter()
            .zip(host2.iter())
            .zip(expected.iter())
            .enumerate()
        {
            assert!(
                (g1 - e).abs() < 1e-6,
                "1st call: element {i}: got {g1}, expected {e}",
            );
            assert!(
                (g2 - e).abs() < 1e-6,
                "2nd call: element {i}: got {g2}, expected {e}",
            );
        }
    }

    #[test]
    fn cached_kernel_second_call_is_fast() {
        // Log-only: prints the latency of two consecutive `gpu_mul` calls for
        // manual inspection.  Tests run in parallel in an unspecified order,
        // so the first timed call may or may not be the one that compiles the
        // mul kernel, and CI timing varies.  No ratio is asserted; caching
        // itself is checked by `get_or_compile_caches_by_name_and_ordinal`.
        use std::time::Instant;

        let dev = GpuDevice::new(0).expect("CUDA device 0");

        let a_data = vec![1.0f32; 1024];
        let b_data = vec![2.0f32; 1024];

        let a = cpu_to_gpu(&a_data, &dev).expect("a to gpu");
        let b = cpu_to_gpu(&b_data, &dev).expect("b to gpu");

        // Best-effort warm-up with an unrelated kernel so one-time CUDA setup
        // is not attributed to the first timed call.
        let _ = gpu_neg(&a, &dev);

        let t1 = Instant::now();
        let _ = gpu_mul(&a, &b, &dev).expect("gpu_mul 1st");
        let d1 = t1.elapsed();

        let t2 = Instant::now();
        let _ = gpu_mul(&a, &b, &dev).expect("gpu_mul 2nd");
        let d2 = t2.elapsed();

        eprintln!(
            "module_cache timing: 1st call = {:?}, 2nd call = {:?}",
            d1, d2,
        );
    }
}