1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//! Shared host-side scaffolding for every cudarc-backed module under
//! `src/gpu/*` and `src/solver/gpu/*`.
//!
//! Before this module existed, each device backend (`bms_flex`,
//! `survival_flex`, `polya_gamma`, `reml_trace`, ...) carried its own
//! near-identical copy of two patterns:
//!
//! 1. A power-of-two bucketed free list of reusable f64 device slices
//! (the per-backend `DeviceArena`).
//! 2. A `OnceLock<Result<{module: Arc<CudaModule>}, GpuError>>` that
//! NVRTC-compiled one source string the first time the backend
//! dispatched and cached the resulting module for the process lifetime.
//!
//! Both are now provided here so every cudarc backend points at the same
//! implementation. The migration is atomic: no per-backend `DeviceArena`
//! type, no per-backend ad-hoc OnceLock, no transitional shim.
#[cfg(target_os = "linux")]
pub use linux::{DeviceArena, PtxModuleCache};
#[cfg(target_os = "linux")]
mod linux {
use super::super::error::GpuError;
use crate::gpu::error::GpuResultExt;
use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
use std::collections::HashMap;
use std::sync::Arc;
/// Power-of-two bucketed free list of f64 device slices.
///
/// Allocations round the requested element count up to the next
/// `usize::next_power_of_two` (see [`DeviceArena::bucket_of`]). The arena
/// keeps no record of slabs it has handed out: on drop the holder hands
/// the slab back to the arena under the same bucket via
/// [`DeviceArena::release`], passing the bucket size that
/// [`DeviceArena::alloc`] reported.
///
/// Both `alloc` and `release` take `&mut self`, so the arena does no
/// locking of its own. Held under a `Mutex` by every backend that uses it
/// because biobank-scale fits dispatch from multiple rayon workers; the
/// mutex is only held during `alloc` / `release`, never across kernel
/// launches.
///
/// The derived `Default` yields an arena with no idle slabs.
#[derive(Default)]
pub struct DeviceArena {
// Idle slabs keyed by bucket size (element count). `alloc` pops from the
// back of a bucket's `Vec` and `release` pushes onto it, so the most
// recently released slab of a bucket is reused first.
free: HashMap<usize, Vec<CudaSlice<f64>>>,
}
impl DeviceArena {
#[inline]
pub fn bucket_of(elements: usize) -> usize {
elements.max(1).next_power_of_two()
}
/// Allocate a device slice of at least `elements` f64s. Returns the
/// bucket size actually allocated so the caller can release into the
/// same bucket on drop. `label` is woven into the error message if
/// the underlying `alloc_zeros` fails so failures stay attributable
/// to the originating backend (matching the pre-extraction wording).
pub fn alloc(
&mut self,
stream: &Arc<CudaStream>,
elements: usize,
label: &'static str,
) -> Result<(usize, CudaSlice<f64>), GpuError> {
let bucket = Self::bucket_of(elements);
if let Some(bucket_vec) = self.free.get_mut(&bucket)
&& let Some(slot) = bucket_vec.pop()
{
return Ok((bucket, slot));
}
let fresh = stream
.alloc_zeros::<f64>(bucket)
.gpu_ctx_with(|err| format!("{label} arena alloc_zeros<{bucket}>: {err}"))?;
Ok((bucket, fresh))
}
pub fn release(&mut self, bucket: usize, slab: CudaSlice<f64>) {
self.free.entry(bucket).or_default().push(slab);
}
}
/// Process-wide NVRTC module cache for a single PTX source string.
///
/// The first call to [`PtxModuleCache::get_or_compile`] compiles the
/// source via `cudarc::nvrtc::compile_ptx`, loads the module on the
/// supplied context, and stores the resulting `Arc<CudaModule>`.
/// Subsequent calls return the cached module without recompiling.
///
/// Only a successfully loaded module is ever stored: when compilation or
/// loading fails the slot stays empty and the error goes to that caller
/// alone.
///
/// The `label` is woven into the error message so the originating
/// backend stays identifiable in logs; the wording matches each
/// caller's previous bespoke `format!` so existing log assertions
/// continue to hold.
#[derive(Default)]
pub struct PtxModuleCache {
// Write-once slot; empty until the first successful compile-and-load.
module: std::sync::OnceLock<Arc<CudaModule>>,
}
impl PtxModuleCache {
    /// Create an empty cache.
    ///
    /// This is a `const fn`, so it can be used in a `static` initialiser.
    pub const fn new() -> Self {
        Self { module: std::sync::OnceLock::new() }
    }

    /// Peek at the cached module without compiling anything.
    ///
    /// Returns `None` until a [`PtxModuleCache::get_or_compile`] call has
    /// succeeded.
    pub fn get(&self) -> Option<&Arc<CudaModule>> {
        self.module.get()
    }

    /// Return the cached module, compiling and loading it first if the cache
    /// is still empty.
    ///
    /// `ctx`, `label` and `source` are consulted only while the cache is
    /// empty. Once a module is stored, every call returns it regardless of
    /// the arguments passed.
    ///
    /// Callers that race on an empty cache each compile and load their own
    /// module; one of them is stored, the others are dropped, and every
    /// caller gets the stored one back.
    ///
    /// # Errors
    ///
    /// Returns a `GpuError` carrying `label` when NVRTC compilation or the
    /// module load fails. Nothing is cached in that case, so the next call
    /// attempts the compile again.
    pub fn get_or_compile(
        &self,
        ctx: &Arc<CudaContext>,
        label: &'static str,
        source: &str,
    ) -> Result<&Arc<CudaModule>, GpuError> {
        match self.module.get() {
            Some(cached) => Ok(cached),
            None => {
                let ptx = cudarc::nvrtc::compile_ptx(source)
                    .gpu_ctx_with(|err| format!("{label} NVRTC compile failed: {err}"))?;
                let loaded = ctx
                    .load_module(ptx)
                    .gpu_ctx_with(|err| format!("{label} module load failed: {err}"))?;
                // Stores `loaded` if the slot is still empty; if another
                // caller filled it meanwhile, `loaded` is dropped and the
                // already-stored module is returned instead.
                Ok(self.module.get_or_init(|| loaded))
            }
        }
    }
}
}