gam_gpu/device_cache.rs
1//! Shared host-side scaffolding for every cudarc-backed module under
2//! `src/gpu/*` and `src/solver/gpu/*`.
3//!
4//! Before this module existed, each device backend (`bms_flex`,
5//! `survival_flex`, `polya_gamma`, `reml_trace`, ...) carried its own
6//! near-identical copy of two patterns:
7//!
8//! 1. A power-of-two bucketed free list of reusable f64 device slices
9//! (the per-backend `DeviceArena`).
10//! 2. A `OnceLock<Result<{module: Arc<CudaModule>}, GpuError>>` that
11//! NVRTC-compiled one source string the first time the backend
12//! dispatched and cached the resulting module for the process lifetime.
13//!
14//! Both are now provided here so every cudarc backend points at the same
15//! implementation. The migration is atomic: no per-backend `DeviceArena`
16//! type, no per-backend ad-hoc OnceLock, no transitional shim.
17
18#[cfg(target_os = "linux")]
19pub use linux::{DeviceArena, PtxModuleCache, compile_ptx_arch};
20
21#[cfg(target_os = "linux")]
22mod linux {
23 use super::super::gpu_error::GpuError;
24 use crate::gpu_error::GpuResultExt;
25 use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
26 use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
27 use std::collections::HashMap;
28 use std::path::Path;
29 use std::sync::Arc;
30
31 /// Power-of-two bucketed free list of f64 device slices.
32 ///
33 /// Allocations round the requested element count up to the next
34 /// `usize::next_power_of_two`. On drop the slab is handed back to the
35 /// arena under the same bucket via [`DeviceArena::release`]. Held under
36 /// a `Mutex` by every backend that uses it because large-scale fits
37 /// dispatch from multiple rayon workers; the mutex is only held during
38 /// `alloc` / `release`, never across kernel launches.
39 #[derive(Default)]
40 pub struct DeviceArena {
41 free: HashMap<usize, Vec<CudaSlice<f64>>>,
42 }
43
44 impl DeviceArena {
45 #[inline]
46 pub fn bucket_of(elements: usize) -> usize {
47 elements.max(1).next_power_of_two()
48 }
49
50 /// Allocate a device slice of at least `elements` f64s. Returns the
51 /// bucket size actually allocated so the caller can release into the
52 /// same bucket on drop. `label` is woven into the error message if
53 /// the underlying `alloc_zeros` fails so failures stay attributable
54 /// to the originating backend (matching the pre-extraction wording).
55 pub fn alloc(
56 &mut self,
57 stream: &Arc<CudaStream>,
58 elements: usize,
59 label: &'static str,
60 ) -> Result<(usize, CudaSlice<f64>), GpuError> {
61 let bucket = Self::bucket_of(elements);
62 if let Some(bucket_vec) = self.free.get_mut(&bucket)
63 && let Some(slot) = bucket_vec.pop()
64 {
65 return Ok((bucket, slot));
66 }
67 let fresh = stream
68 .alloc_zeros::<f64>(bucket)
69 .gpu_ctx_with(|err| format!("{label} arena alloc_zeros<{bucket}>: {err}"))?;
70 Ok((bucket, fresh))
71 }
72
73 pub fn release(&mut self, bucket: usize, slab: CudaSlice<f64>) {
74 self.free.entry(bucket).or_default().push(slab);
75 }
76 }
77
78 /// Process-wide NVRTC module cache for a single PTX source string.
79 ///
80 /// The first call to [`PtxModuleCache::get_or_compile`] compiles the
81 /// source via `cudarc::nvrtc::compile_ptx`, loads the module on the
82 /// supplied context, and stores the resulting `Arc<CudaModule>`.
83 /// Subsequent calls return the cached module without recompiling.
84 ///
85 /// The `label` is woven into the error message so the originating
86 /// backend stays identifiable in logs; the wording matches each
87 /// caller's previous bespoke `format!` so existing log assertions
88 /// continue to hold.
89 #[derive(Default)]
90 pub struct PtxModuleCache {
91 module: std::sync::OnceLock<Arc<CudaModule>>,
92 }
93
94 impl PtxModuleCache {
95 pub const fn new() -> Self {
96 Self {
97 module: std::sync::OnceLock::new(),
98 }
99 }
100
101 pub fn get(&self) -> Option<&Arc<CudaModule>> {
102 self.module.get()
103 }
104
105 /// Compile `source` and load it on `ctx` the first time; return
106 /// the cached `Arc<CudaModule>` on every subsequent call.
107 pub fn get_or_compile(
108 &self,
109 ctx: &Arc<CudaContext>,
110 label: &'static str,
111 source: &str,
112 ) -> Result<&Arc<CudaModule>, GpuError> {
113 if let Some(existing) = self.module.get() {
114 return Ok(existing);
115 }
116 let ptx = compile_ptx_with_opts(source, nvrtc_compile_options())
117 .gpu_ctx_with(|err| format!("{label} NVRTC compile failed: {err}"))?;
118 let module = ctx
119 .load_module(ptx)
120 .gpu_ctx_with(|err| format!("{label} module load failed: {err}"))?;
121 self.module.set(module).ok();
122 Ok(self
123 .module
124 .get()
125 .expect("module slot populated immediately after set"))
126 }
127 }
128
129 /// Compile a kernel source string to PTX with the SAME device-keyed NVRTC
130 /// options [`PtxModuleCache::get_or_compile`] uses — crucially the
131 /// `--gpu-architecture` pin (#1551), without which NVRTC defaults below
132 /// `sm_60` and rejects `atomicAdd(double*, double)`. Call sites that compile
133 /// via the bare `cudarc::nvrtc::compile_ptx` (no options) MUST route through
134 /// this instead when their kernel uses double atomics, or the device path
135 /// silently falls back to the CPU.
136 pub fn compile_ptx_arch<S: AsRef<str>>(source: S) -> Result<cudarc::nvrtc::Ptx, GpuError> {
137 compile_ptx_with_opts(source.as_ref(), nvrtc_compile_options())
138 .gpu_ctx_with(|err| std::format!("NVRTC compile failed: {err}"))
139 }
140
141 fn nvrtc_compile_options() -> CompileOptions {
142 let mut opts = CompileOptions::default();
143 opts.include_paths = nvrtc_include_paths();
144 // #1551: pin the NVRTC virtual arch to the selected device's compute
145 // capability. Without it NVRTC defaults below sm_60, where the
146 // `atomicAdd(double*, double)` overload is absent — so kernels using
147 // double atomics (the SAE arrow/Schur PCG kernels) fail to compile and
148 // the device path silently falls back to the CPU (SAE ran at 0% GPU).
149 // `arch` is `Option<&'static str>`; `nvrtc_arch()` returns a static
150 // `compute_NN` for the device's real capability.
151 if let Some(runtime) = crate::device_runtime::GpuRuntime::global() {
152 opts.arch = Some(runtime.selected_device().capability.nvrtc_arch());
153 }
154 opts
155 }
156
/// Host include directories handed to NVRTC through
/// `CompileOptions::include_paths`.
///
/// Only directories that exist are returned, each at most once, in this
/// search order:
/// 1. the CUDA toolkit headers;
/// 2. the system headers;
/// 3. the multiarch system headers for the target architecture;
/// 4. every GCC toolchain's `include` directory, newest version first.
fn nvrtc_include_paths() -> Vec<String> {
    // Debian/Ubuntu multiarch triplet, e.g. `x86_64-linux-gnu` or
    // `aarch64-linux-gnu`. It is derived from the compilation target rather
    // than hard-coded so non-x86_64 Linux hosts also pick up their multiarch
    // headers. Triplets that don't match the distro's naming simply don't
    // exist on disk and are skipped.
    let triplet = format!("{}-linux-gnu", std::env::consts::ARCH);
    let mut paths = Vec::new();
    push_existing_include_path(&mut paths, Path::new("/usr/local/cuda/include"));
    push_existing_include_path(&mut paths, Path::new("/usr/include"));
    push_existing_include_path(&mut paths, &Path::new("/usr/include").join(&triplet));
    push_gcc_include_paths(&mut paths, &Path::new("/usr/lib/gcc").join(&triplet));
    paths
}

/// Append `<root>/<version>/include` for every GCC version directory under
/// `root` (e.g. `/usr/lib/gcc/x86_64-linux-gnu/12/include`).
///
/// `read_dir` yields entries in an unspecified, filesystem-dependent order.
/// Include directories are searched in list order, so the versions are
/// sorted newest-first to make header resolution reproducible. A missing or
/// unreadable `root` is not an error: nothing is appended.
fn push_gcc_include_paths(paths: &mut Vec<String>, root: &Path) {
    let Ok(entries) = std::fs::read_dir(root) else {
        return;
    };
    let mut version_dirs: Vec<_> = entries.flatten().map(|entry| entry.path()).collect();
    version_dirs.sort_by_cached_key(|dir| std::cmp::Reverse(gcc_version_key(dir)));
    for dir in version_dirs {
        push_existing_include_path(paths, &dir.join("include"));
    }
}

/// Sort key for a GCC version directory name such as `12`, `9` or `4.8.5`.
///
/// The key is the dot-separated numeric components, with non-numeric
/// components counted as 0. The raw name breaks ties, so the ordering is
/// total.
fn gcc_version_key(dir: &Path) -> (Vec<u64>, String) {
    let name = dir
        .file_name()
        .map(|name| name.to_string_lossy().into_owned())
        .unwrap_or_default();
    let components = name
        .split('.')
        .map(|part| part.parse::<u64>().unwrap_or(0))
        .collect();
    (components, name)
}

/// Append `path` to `paths` if it is an existing directory not already
/// listed. Anything else (missing path, regular file, duplicate) is ignored.
fn push_existing_include_path(paths: &mut Vec<String>, path: &Path) {
    if !path.is_dir() {
        return;
    }
    let display = path.to_string_lossy().into_owned();
    if !paths.contains(&display) {
        paths.push(display);
    }
}
184}