gam_gpu/device_cache.rs
1//! Shared host-side scaffolding for every cudarc-backed module under
2//! `src/gpu/*` and `src/solver/gpu/*`.
3//!
4//! Before this module existed, each device backend (`bms_flex`,
5//! `survival_flex`, `polya_gamma`, `reml_trace`, ...) carried its own
6//! near-identical copy of two patterns:
7//!
8//! 1. A power-of-two bucketed free list of reusable f64 device slices
9//! (the per-backend `DeviceArena`).
10//! 2. A `OnceLock<Result<{module: Arc<CudaModule>}, GpuError>>` that
11//! NVRTC-compiled one source string the first time the backend
12//! dispatched and cached the resulting module for the process lifetime.
13//!
14//! Both are now provided here so every cudarc backend points at the same
15//! implementation. The migration is atomic: no per-backend `DeviceArena`
16//! type, no per-backend ad-hoc OnceLock, no transitional shim.
17
18#[cfg(target_os = "linux")]
19pub use linux::{DeviceArena, PtxModuleCache, compile_ptx_arch};
20
21#[cfg(target_os = "linux")]
22mod linux {
23 use super::super::gpu_error::GpuError;
24 use crate::gpu_error::GpuResultExt;
25 use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
26 use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
27 use std::collections::HashMap;
28 use std::path::Path;
29 use std::sync::Arc;
30
31 /// Power-of-two bucketed free list of f64 device slices.
32 ///
33 /// Allocations round the requested element count up to the next
34 /// `usize::next_power_of_two`. On drop the slab is handed back to the
35 /// arena under the same bucket via [`DeviceArena::release`]. Held under
36 /// a `Mutex` by every backend that uses it because large-scale fits
37 /// dispatch from multiple rayon workers; the mutex is only held during
38 /// `alloc` / `release`, never across kernel launches.
39 #[derive(Default)]
40 pub struct DeviceArena {
41 free: HashMap<usize, Vec<CudaSlice<f64>>>,
42 }
43
44 impl DeviceArena {
45 #[inline]
46 pub fn bucket_of(elements: usize) -> usize {
47 elements.max(1).next_power_of_two()
48 }
49
50 /// Allocate a device slice of at least `elements` f64s. Returns the
51 /// bucket size actually allocated so the caller can release into the
52 /// same bucket on drop. `label` is woven into the error message if
53 /// the underlying `alloc_zeros` fails so failures stay attributable
54 /// to the originating backend (matching the pre-extraction wording).
55 pub fn alloc(
56 &mut self,
57 stream: &Arc<CudaStream>,
58 elements: usize,
59 label: &'static str,
60 ) -> Result<(usize, CudaSlice<f64>), GpuError> {
61 let bucket = Self::bucket_of(elements);
62 if let Some(bucket_vec) = self.free.get_mut(&bucket)
63 && let Some(slot) = bucket_vec.pop()
64 {
65 return Ok((bucket, slot));
66 }
67 let fresh = stream
68 .alloc_zeros::<f64>(bucket)
69 .gpu_ctx_with(|err| format!("{label} arena alloc_zeros<{bucket}>: {err}"))?;
70 Ok((bucket, fresh))
71 }
72
73 pub fn release(&mut self, bucket: usize, slab: CudaSlice<f64>) {
74 self.free.entry(bucket).or_default().push(slab);
75 }
76 }
77
78 /// Process-wide NVRTC module cache for a single PTX source string.
79 ///
80 /// The first call to [`PtxModuleCache::get_or_compile`] compiles the
81 /// source via `cudarc::nvrtc::compile_ptx`, loads the module on the
82 /// supplied context, and stores the resulting `Arc<CudaModule>`.
83 /// Subsequent calls return the cached module without recompiling.
84 ///
85 /// The `label` is woven into the error message so the originating
86 /// backend stays identifiable in logs; the wording matches each
87 /// caller's previous bespoke `format!` so existing log assertions
88 /// continue to hold.
89 #[derive(Default)]
90 pub struct PtxModuleCache {
91 module: std::sync::OnceLock<Arc<CudaModule>>,
92 }
93
94 impl PtxModuleCache {
95 pub const fn new() -> Self {
96 Self {
97 module: std::sync::OnceLock::new(),
98 }
99 }
100
101 pub fn get(&self) -> Option<&Arc<CudaModule>> {
102 self.module.get()
103 }
104
105 /// Compile `source` and load it on `ctx` the first time; return
106 /// the cached `Arc<CudaModule>` on every subsequent call.
107 pub fn get_or_compile(
108 &self,
109 ctx: &Arc<CudaContext>,
110 label: &'static str,
111 source: &str,
112 ) -> Result<&Arc<CudaModule>, GpuError> {
113 if let Some(existing) = self.module.get() {
114 return Ok(existing);
115 }
116 let ptx = compile_ptx_with_opts(source, nvrtc_compile_options())
117 .gpu_ctx_with(|err| format!("{label} NVRTC compile failed: {err}"))?;
118 let module = ctx
119 .load_module(ptx)
120 .gpu_ctx_with(|err| format!("{label} module load failed: {err}"))?;
121 self.module.set(module).ok();
122 Ok(self
123 .module
124 .get()
125 .expect("module slot populated immediately after set"))
126 }
127 }
128
129 /// Compile a kernel source string to PTX with the SAME device-keyed NVRTC
130 /// options [`PtxModuleCache::get_or_compile`] uses — crucially the
131 /// `--gpu-architecture` pin (#1551), without which NVRTC defaults below
132 /// `sm_60` and rejects `atomicAdd(double*, double)`. Call sites that compile
133 /// via the bare `cudarc::nvrtc::compile_ptx` (no options) MUST route through
134 /// this instead when their kernel uses double atomics, or the device path
135 /// silently falls back to the CPU.
136 pub fn compile_ptx_arch<S: AsRef<str>>(source: S) -> Result<cudarc::nvrtc::Ptx, GpuError> {
137 compile_ptx_with_opts(source.as_ref(), nvrtc_compile_options())
138 .gpu_ctx_with(|err| std::format!("NVRTC compile failed: {err}"))
139 }
140
141 fn nvrtc_compile_options() -> CompileOptions {
142 let mut opts = CompileOptions::default();
143 opts.include_paths = nvrtc_include_paths();
144 // GPU↔CPU PARITY: disable FMA contraction. NVRTC's default is
145 // `--fmad=true`, which fuses `a*b + c` into a single fused multiply-add
146 // (ONE rounding). The CPU oracle computes `a*b` then `+ c` as two
147 // SEPARATELY-rounded f64 ops. For shallow kernels the gap is ~1 ULP;
148 // for deep derivative towers (the survival/SAE seeded jets, whose
149 // Hessian + contracted third/fourth channels chain dozens of mul/add
150 // steps) the per-op FMA divergence accumulates to ~5e-8 — enough to
151 // blow a 1e-9 parity gate on a real device (measured on a V100,
152 // compute 7.0: survival_rowjet device-vs-CPU max abs diff 5.09e-8).
153 // `--use_fast_math` was already off, but that does NOT imply fmad off
154 // (use_fast_math only ADDS fmad=true; the converse default is still
155 // on). Pinning fmad=false makes every shared-options kernel
156 // bit-comparable to the separately-rounded CPU path. `Option::None`
157 // would defer to NVRTC's `true` default, so we set it explicitly.
158 opts.fmad = Some(false);
159 // #1551: pin the NVRTC virtual arch to the selected device's compute
160 // capability. Without it NVRTC defaults below sm_60, where the
161 // `atomicAdd(double*, double)` overload is absent — so kernels using
162 // double atomics (the SAE arrow/Schur PCG kernels) fail to compile and
163 // the device path silently falls back to the CPU (SAE ran at 0% GPU).
164 // `arch` is `Option<&'static str>`; `nvrtc_arch()` returns a static
165 // `compute_NN` for the device's real capability.
166 if let Some(runtime) = crate::device_runtime::GpuRuntime::global() {
167 opts.arch = Some(runtime.selected_device().capability.nvrtc_arch());
168 }
169 opts
170 }
171
/// Collect the existing, de-duplicated include directories handed to NVRTC.
///
/// Search order (first match wins when NVRTC resolves a header):
/// 1. `/usr/local/cuda/include`
/// 2. `$CUDA_HOME/include`, then `$CUDA_PATH/include` (when set and
///    non-empty). These are appended after the default CUDA root, so existing
///    setups see the same order as before.
/// 3. `/usr/include`
/// 4. `/usr/include/<triple>` and `/usr/lib/gcc/<triple>/*/include`, where
///    `<triple>` is the Debian-style multiarch tuple for the host architecture
///    (previously hard-coded to x86_64, which silently dropped these dirs on
///    aarch64).
///
/// Directories that do not exist are skipped.
fn nvrtc_include_paths() -> Vec<String> {
    let mut paths = Vec::new();
    push_existing_include_path(&mut paths, Path::new("/usr/local/cuda/include"));
    for var in ["CUDA_HOME", "CUDA_PATH"] {
        // An empty value would turn into the relative path `include`; ignore it.
        if let Some(root) = std::env::var_os(var).filter(|root| !root.is_empty()) {
            push_existing_include_path(&mut paths, &Path::new(&root).join("include"));
        }
    }
    push_existing_include_path(&mut paths, Path::new("/usr/include"));
    if let Some(triple) = host_multiarch_tuple(std::env::consts::ARCH) {
        push_existing_include_path(&mut paths, &Path::new("/usr/include").join(triple));
        push_gcc_include_paths(&mut paths, &Path::new("/usr/lib/gcc").join(triple));
    }
    paths
}

/// Debian/Ubuntu multiarch tuple for a `std::env::consts::ARCH` value, or
/// `None` when the architecture is not one we know the layout for.
fn host_multiarch_tuple(arch: &str) -> Option<&'static str> {
    match arch {
        "x86_64" => Some("x86_64-linux-gnu"),
        "aarch64" => Some("aarch64-linux-gnu"),
        _ => None,
    }
}

/// Push `<root>/<version>/include` for every entry under `root`. An
/// unreadable or missing `root` is silently ignored.
///
/// Entries are sorted by path (lexicographically, not version-aware) so the
/// include order is deterministic regardless of `read_dir`'s unspecified order.
fn push_gcc_include_paths(paths: &mut Vec<String>, root: &Path) {
    let Ok(entries) = std::fs::read_dir(root) else {
        return;
    };
    let mut version_dirs: Vec<_> = entries.flatten().map(|entry| entry.path()).collect();
    version_dirs.sort();
    for dir in version_dirs {
        push_existing_include_path(paths, &dir.join("include"));
    }
}

/// Append `path` (lossily converted to a `String`) if it is an existing
/// directory that is not already in `paths`.
fn push_existing_include_path(paths: &mut Vec<String>, path: &Path) {
    if !path.is_dir() {
        return;
    }
    let display = path.to_string_lossy().into_owned();
    if !paths.iter().any(|existing| existing == &display) {
        paths.push(display);
    }
}
199}