gam_gpu/backend_probe.rs
1//! Shared CUDA backend-probe contract for every cudarc-backed module under
2//! `src/gpu/*`.
3//!
4//! Before this module existed, every GPU backend (`bms_flex`,
5//! `survival_flex`, `cubic_bspline_moments`, `cubic_cell`, `pirls_row`,
6//! `sphere`, ...) carried its own near-identical `probe_linux` prologue:
7//!
8//! 1. Fetch the process-wide [`GpuRuntime`] or fail with a
9//! `DriverLibraryUnavailable { reason: "<module> backend: no CUDA
10//! runtime available" }`.
11//! 2. Read the runtime's selected device ordinal.
12//! 3. Create (or reuse) the per-ordinal [`CudaContext`] or fail with a
13//! `DriverCallFailed { reason: "<module> backend: failed to create
14//! CUDA context for device N" }`.
15//! 4. Open the context's default [`CudaStream`].
16//! 5. Carry the device's compute capability alongside the handles.
17//!
18//! Those five steps are identical apart from the per-module label that gets
19//! woven into the two error messages. Drift between copies meant error
20//! wording, capability handling, context reuse, and stream choice could
21//! diverge module to module. This module hosts the single contract: each
22//! backend now calls [`probe_cuda_backend`] with its label and keeps only
23//! its module caches and optional eager-compilation step.
24//!
25//! The migration is atomic: no backend re-implements the prologue, and
26//! there is no transitional shim.
27
28// `CudaBackendParts` is re-exported alongside the probe entry points: sibling
29// crates (`gam-terms`, `gam-models`, ...) call `probe_cuda_backend` and receive
30// a `CudaBackendParts` value (without ever naming the type), so `probe_cuda_backend`'s
31// public return type must itself be reachable or `-D warnings` rejects the leak
32// (`private_interfaces`). It carries no fields a caller can misuse out of context.
33#[cfg(target_os = "linux")]
34pub use linux::{
35 CudaBackendContext, CudaBackendParts, probe_backend_with_compile, probe_cuda_backend,
36};
37
38#[cfg(target_os = "linux")]
39mod linux {
40 use crate::device::GpuCapability;
41 use crate::device_cache::{DeviceArena, PtxModuleCache};
42 use crate::device_runtime::{GpuRuntime, cuda_context_for};
43 use crate::gpu_error::GpuError;
44 use cudarc::driver::{CudaContext, CudaStream};
45 use std::sync::{Arc, Mutex};
46
47 /// The handles every cudarc backend shares once the probe succeeds:
48 /// a context on the runtime's selected device, that context's default
49 /// stream, and the device's compute capability. Module-specific
50 /// backends layer their own caches and optional eager compilation on
51 /// top of these.
52 #[derive(Debug)]
53 pub struct CudaBackendParts {
54 pub ctx: Arc<CudaContext>,
55 pub stream: Arc<CudaStream>,
56 pub capability: GpuCapability,
57 }
58
59 /// Probe the process-wide CUDA backend for the calling module.
60 ///
61 /// Resolves the global [`GpuRuntime`], creates (or reuses) the
62 /// [`CudaContext`] for its selected device, opens that context's
63 /// default stream, and returns the trio bundled in [`CudaBackendParts`].
64 /// `label` names the calling module (e.g. `"bms_flex"`) and is woven
65 /// into both failure messages so the uniform contract still attributes
66 /// errors to their originating backend.
67 pub fn probe_cuda_backend(label: &'static str) -> Result<CudaBackendParts, GpuError> {
68 let runtime = GpuRuntime::global().ok_or_else(|| GpuError::DriverLibraryUnavailable {
69 reason: format!("{label} backend: no CUDA runtime available"),
70 })?;
71 let ordinal = runtime.selected_device().ordinal;
72 let ctx = cuda_context_for(ordinal).ok_or_else(|| {
73 gpu_err!("{label} backend: failed to create CUDA context for device {ordinal}")
74 })?;
75 let stream = ctx.default_stream();
76 let capability = runtime.selected_device().capability.clone();
77 Ok(CudaBackendParts {
78 ctx,
79 stream,
80 capability,
81 })
82 }
83
84 /// Probe the CUDA backend for `label` and run a backend-specific build
85 /// step on the resolved handles.
86 ///
87 /// This is [`probe_cuda_backend`] plus the one piece that genuinely
88 /// differs between backends: the NVRTC compile (and any per-backend cache
89 /// construction). The runtime resolution, context creation, and stream
90 /// selection — together with their uniform, label-attributed error
91 /// messages — live in the shared probe; `build` receives the resolved
92 /// [`CudaBackendParts`] (so it can clone the `Arc<CudaContext>` /
93 /// `Arc<CudaStream>` it needs) and returns the backend's own state `T`.
94 pub fn probe_backend_with_compile<F, T>(label: &'static str, build: F) -> Result<T, GpuError>
95 where
96 F: FnOnce(&CudaBackendParts) -> Result<T, GpuError>,
97 {
98 let parts = probe_cuda_backend(label)?;
99 build(&parts)
100 }
101
102 /// The process-wide device handles every cudarc backend stores after a
103 /// successful probe: the [`CudaContext`], its default [`CudaStream`], the
104 /// lazily NVRTC-compiled [`PtxModuleCache`], and the bucketed
105 /// [`DeviceArena`] of reusable f64 device buffers (held under a `Mutex`
106 /// because large-scale fits dispatch from multiple rayon worker threads; the
107 /// mutex is only held during `alloc` / `release`, not across kernel
108 /// launches). Module-specific backends (`bms_flex`, `survival_flex`, …)
109 /// wrap one of these as their `inner` context so the host-side
110 /// scaffolding (arena pooling, module cache, mutex around alloc) is
111 /// uniform instead of duplicated per backend.
112 pub struct CudaBackendContext {
113 pub ctx: Arc<CudaContext>,
114 pub stream: Arc<CudaStream>,
115 pub module: PtxModuleCache,
116 pub arena: Mutex<DeviceArena>,
117 }
118
119 impl CudaBackendContext {
120 /// Build the stored context from a fresh [`CudaBackendParts`] probe
121 /// result: adopt its context and stream, start an empty module cache
122 /// (the backend's eager-compile step fills it), and an empty device
123 /// arena. The probe's compute `capability` is consumed by the probe
124 /// path itself and is not retained here.
125 pub fn from_parts(parts: CudaBackendParts) -> Self {
126 CudaBackendContext {
127 ctx: parts.ctx,
128 stream: parts.stream,
129 module: PtxModuleCache::new(),
130 arena: Mutex::new(DeviceArena::default()),
131 }
132 }
133 }
134}
135
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::probe_cuda_backend;
    use crate::device_runtime::GpuRuntime;
    use crate::gpu_error::GpuError;

    /// Regression guard for the shared probe contract.
    ///
    /// Two host configurations are covered:
    ///
    /// * **No CUDA runtime** — the probe has to fail with
    ///   `DriverLibraryUnavailable`, and the reason must be the uniform
    ///   no-runtime message prefixed with the caller's label.
    /// * **Runtime present** — for each of the six migrated backend labels
    ///   (`bms_flex`, `survival_flex`, `cubic_bspline_moments`, `cubic_cell`,
    ///   `pirls_row`, `sphere`) the probe has to succeed, report a context
    ///   whose ordinal equals the runtime's selected-device ordinal, report
    ///   the same capability the runtime advertises, and return a default
    ///   stream that can be synchronized.
    ///
    /// Together these keep every backend routed through one prologue instead
    /// of drifting copies.
    #[test]
    fn shared_probe_matches_runtime_device_and_labels_errors() {
        const MIGRATED_BACKENDS: [&str; 6] = [
            "bms_flex",
            "survival_flex",
            "cubic_bspline_moments",
            "cubic_cell",
            "pirls_row",
            "sphere",
        ];

        let Some(runtime) = GpuRuntime::global() else {
            // Host without a runtime: expect the uniform, label-attributed
            // failure and nothing else.
            match probe_cuda_backend("bms_flex") {
                Err(GpuError::DriverLibraryUnavailable { reason }) => {
                    assert_eq!(
                        reason, "bms_flex backend: no CUDA runtime available",
                        "shared probe must emit the uniform no-runtime message"
                    );
                }
                other => panic!(
                    "expected DriverLibraryUnavailable on a host without a CUDA runtime, \
                     got {other:?}"
                ),
            }
            return;
        };

        // Host with a runtime: whatever the label, the probe must land on the
        // device the runtime selected and echo its capability.
        let expected_ordinal = runtime.selected_device().ordinal;
        let expected_capability = &runtime.selected_device().capability;

        for label in MIGRATED_BACKENDS {
            let parts = match probe_cuda_backend(label) {
                Ok(parts) => parts,
                Err(err) => panic!("probe for {label} must succeed: {err:?}"),
            };

            assert_eq!(
                parts.ctx.ordinal(),
                expected_ordinal,
                "{label}: context must bind the runtime's selected device ordinal"
            );
            assert_eq!(
                &parts.capability, expected_capability,
                "{label}: probe capability must match the runtime's selected device"
            );

            // A successful sync shows the returned stream is usable.
            if let Err(err) = parts.stream.synchronize() {
                panic!("{label}: default stream must sync: {err:?}");
            }
        }
    }
}