Skip to main content

gam_gpu/
device_cache.rs

//! Shared host-side scaffolding for every cudarc-backed module under
//! `src/gpu/*` and `src/solver/gpu/*`.
//!
//! Before this module existed, each device backend (`bms_flex`,
//! `survival_flex`, `polya_gamma`, `reml_trace`, ...) carried its own
//! near-identical copy of two patterns:
//!
//!   1. A power-of-two bucketed free list of reusable f64 device slices
//!      (the per-backend `DeviceArena`).
//!   2. A `OnceLock<Result<{module: Arc<CudaModule>}, GpuError>>` that
//!      NVRTC-compiled a single source string the first time the backend
//!      dispatched, then cached the resulting module for the process lifetime.
//!
//! Both are now provided here so every cudarc backend points at the same
//! implementation. The migration is atomic: no per-backend `DeviceArena`
//! type, no per-backend ad-hoc `OnceLock`, no transitional shim. The module
//! also exports `compile_ptx_arch` for call sites that compile PTX without
//! the cache but still need the same architecture-pinned NVRTC options.
//! Everything here is Linux-only (`cfg(target_os = "linux")`).
17
18#[cfg(target_os = "linux")]
19pub use linux::{DeviceArena, PtxModuleCache, compile_ptx_arch};
20
21#[cfg(target_os = "linux")]
22mod linux {
23    use super::super::gpu_error::GpuError;
24    use crate::gpu_error::GpuResultExt;
25    use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
26    use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
27    use std::collections::HashMap;
28    use std::path::Path;
29    use std::sync::Arc;
30
31    /// Power-of-two bucketed free list of f64 device slices.
32    ///
33    /// Allocations round the requested element count up to the next
34    /// `usize::next_power_of_two`. On drop the slab is handed back to the
35    /// arena under the same bucket via [`DeviceArena::release`]. Held under
36    /// a `Mutex` by every backend that uses it because large-scale fits
37    /// dispatch from multiple rayon workers; the mutex is only held during
38    /// `alloc` / `release`, never across kernel launches.
39    #[derive(Default)]
40    pub struct DeviceArena {
41        free: HashMap<usize, Vec<CudaSlice<f64>>>,
42    }
43
44    impl DeviceArena {
45        #[inline]
46        pub fn bucket_of(elements: usize) -> usize {
47            elements.max(1).next_power_of_two()
48        }
49
50        /// Allocate a device slice of at least `elements` f64s. Returns the
51        /// bucket size actually allocated so the caller can release into the
52        /// same bucket on drop. `label` is woven into the error message if
53        /// the underlying `alloc_zeros` fails so failures stay attributable
54        /// to the originating backend (matching the pre-extraction wording).
55        pub fn alloc(
56            &mut self,
57            stream: &Arc<CudaStream>,
58            elements: usize,
59            label: &'static str,
60        ) -> Result<(usize, CudaSlice<f64>), GpuError> {
61            let bucket = Self::bucket_of(elements);
62            if let Some(bucket_vec) = self.free.get_mut(&bucket)
63                && let Some(slot) = bucket_vec.pop()
64            {
65                return Ok((bucket, slot));
66            }
67            let fresh = stream
68                .alloc_zeros::<f64>(bucket)
69                .gpu_ctx_with(|err| format!("{label} arena alloc_zeros<{bucket}>: {err}"))?;
70            Ok((bucket, fresh))
71        }
72
73        pub fn release(&mut self, bucket: usize, slab: CudaSlice<f64>) {
74            self.free.entry(bucket).or_default().push(slab);
75        }
76    }
77
78    /// Process-wide NVRTC module cache for a single PTX source string.
79    ///
80    /// The first call to [`PtxModuleCache::get_or_compile`] compiles the
81    /// source via `cudarc::nvrtc::compile_ptx`, loads the module on the
82    /// supplied context, and stores the resulting `Arc<CudaModule>`.
83    /// Subsequent calls return the cached module without recompiling.
84    ///
85    /// The `label` is woven into the error message so the originating
86    /// backend stays identifiable in logs; the wording matches each
87    /// caller's previous bespoke `format!` so existing log assertions
88    /// continue to hold.
89    #[derive(Default)]
90    pub struct PtxModuleCache {
91        module: std::sync::OnceLock<Arc<CudaModule>>,
92    }
93
94    impl PtxModuleCache {
95        pub const fn new() -> Self {
96            Self {
97                module: std::sync::OnceLock::new(),
98            }
99        }
100
101        pub fn get(&self) -> Option<&Arc<CudaModule>> {
102            self.module.get()
103        }
104
105        /// Compile `source` and load it on `ctx` the first time; return
106        /// the cached `Arc<CudaModule>` on every subsequent call.
107        pub fn get_or_compile(
108            &self,
109            ctx: &Arc<CudaContext>,
110            label: &'static str,
111            source: &str,
112        ) -> Result<&Arc<CudaModule>, GpuError> {
113            if let Some(existing) = self.module.get() {
114                return Ok(existing);
115            }
116            let ptx = compile_ptx_with_opts(source, nvrtc_compile_options())
117                .gpu_ctx_with(|err| format!("{label} NVRTC compile failed: {err}"))?;
118            let module = ctx
119                .load_module(ptx)
120                .gpu_ctx_with(|err| format!("{label} module load failed: {err}"))?;
121            self.module.set(module).ok();
122            Ok(self
123                .module
124                .get()
125                .expect("module slot populated immediately after set"))
126        }
127    }
128
129    /// Compile a kernel source string to PTX with the SAME device-keyed NVRTC
130    /// options [`PtxModuleCache::get_or_compile`] uses — crucially the
131    /// `--gpu-architecture` pin (#1551), without which NVRTC defaults below
132    /// `sm_60` and rejects `atomicAdd(double*, double)`. Call sites that compile
133    /// via the bare `cudarc::nvrtc::compile_ptx` (no options) MUST route through
134    /// this instead when their kernel uses double atomics, or the device path
135    /// silently falls back to the CPU.
136    pub fn compile_ptx_arch<S: AsRef<str>>(source: S) -> Result<cudarc::nvrtc::Ptx, GpuError> {
137        compile_ptx_with_opts(source.as_ref(), nvrtc_compile_options())
138            .gpu_ctx_with(|err| std::format!("NVRTC compile failed: {err}"))
139    }
140
141    fn nvrtc_compile_options() -> CompileOptions {
142        let mut opts = CompileOptions::default();
143        opts.include_paths = nvrtc_include_paths();
144        // #1551: pin the NVRTC virtual arch to the selected device's compute
145        // capability. Without it NVRTC defaults below sm_60, where the
146        // `atomicAdd(double*, double)` overload is absent — so kernels using
147        // double atomics (the SAE arrow/Schur PCG kernels) fail to compile and
148        // the device path silently falls back to the CPU (SAE ran at 0% GPU).
149        // `arch` is `Option<&'static str>`; `nvrtc_arch()` returns a static
150        // `compute_NN` for the device's real capability.
151        if let Some(runtime) = crate::device_runtime::GpuRuntime::global() {
152            opts.arch = Some(runtime.selected_device().capability.nvrtc_arch());
153        }
154        opts
155    }
156
157    fn nvrtc_include_paths() -> Vec<String> {
158        let mut paths = Vec::new();
159        push_existing_include_path(&mut paths, Path::new("/usr/local/cuda/include"));
160        push_existing_include_path(&mut paths, Path::new("/usr/include"));
161        push_existing_include_path(&mut paths, Path::new("/usr/include/x86_64-linux-gnu"));
162        push_gcc_include_paths(&mut paths, Path::new("/usr/lib/gcc/x86_64-linux-gnu"));
163        paths
164    }
165
166    fn push_gcc_include_paths(paths: &mut Vec<String>, root: &Path) {
167        let Ok(entries) = std::fs::read_dir(root) else {
168            return;
169        };
170        for entry in entries.flatten() {
171            push_existing_include_path(paths, &entry.path().join("include"));
172        }
173    }
174
175    fn push_existing_include_path(paths: &mut Vec<String>, path: &Path) {
176        if !path.is_dir() {
177            return;
178        }
179        let display = path.to_string_lossy().into_owned();
180        if !paths.iter().any(|existing| existing == &display) {
181            paths.push(display);
182        }
183    }
184}