gam_gpu/backend_probe.rs
1//! Shared CUDA backend-probe contract for every cudarc-backed module under
2//! `src/gpu/*`.
3//!
4//! Before this module existed, every GPU backend (`bms_flex`,
5//! `survival_flex`, `cubic_bspline_moments`, `cubic_cell`, `pirls_row`,
6//! `sphere`, ...) carried its own near-identical `probe_linux` prologue:
7//!
8//! 1. Fetch the process-wide [`GpuRuntime`] or fail with a
9//! `DriverLibraryUnavailable { reason: "<module> backend: no CUDA
10//! runtime available" }`.
11//! 2. Read the runtime's selected device ordinal.
12//! 3. Create (or reuse) the per-ordinal [`CudaContext`] or fail with a
13//! `DriverCallFailed { reason: "<module> backend: failed to create
14//! CUDA context for device N" }`.
15//! 4. Open the context's default [`CudaStream`].
16//! 5. Carry the device's compute capability alongside the handles.
17//!
18//! Those five steps are identical apart from the per-module label that gets
19//! woven into the two error messages. Drift between copies meant error
20//! wording, capability handling, context reuse, and stream choice could
21//! diverge module to module. This module hosts the single contract: each
22//! backend now calls [`probe_cuda_backend`] with its label and keeps only
23//! its module caches and optional eager-compilation step.
24//!
25//! The migration is atomic: no backend re-implements the prologue, and
26//! there is no transitional shim.
27
28// `CudaBackendParts` is re-exported alongside the probe entry points: sibling
29// crates (`gam-terms`, `gam-models`, ...) call `probe_cuda_backend` and receive
30// a `CudaBackendParts` value (without ever naming the type), so `probe_cuda_backend`'s
31// public return type must itself be reachable or `-D warnings` rejects the leak
32// (`private_interfaces`). It carries no fields a caller can misuse out of context.
33#[cfg(target_os = "linux")]
34pub use linux::{
35 CudaBackendContext, CudaBackendParts, probe_backend_with_compile, probe_cuda_backend,
36};
37
38#[cfg(target_os = "linux")]
39mod linux {
40 use crate::device::GpuCapability;
41 use crate::device_cache::{DeviceArena, PtxModuleCache};
42 use crate::device_runtime::{GpuRuntime, cuda_context_for};
43 use crate::gpu_error::GpuError;
44 use cudarc::driver::{CudaContext, CudaStream};
45 use std::sync::{Arc, Mutex};
46
47 /// The handles every cudarc backend shares once the probe succeeds:
48 /// a context on the runtime's selected device, that context's default
49 /// stream, and the device's compute capability. Module-specific
50 /// backends layer their own caches and optional eager compilation on
51 /// top of these.
52 #[derive(Debug)]
53 pub struct CudaBackendParts {
54 pub ctx: Arc<CudaContext>,
55 pub stream: Arc<CudaStream>,
56 pub capability: GpuCapability,
57 }
58
59 /// Probe the process-wide CUDA backend for the calling module.
60 ///
61 /// Resolves the global [`GpuRuntime`], creates (or reuses) the
62 /// [`CudaContext`] for its selected device, opens that context's
63 /// default stream, and returns the trio bundled in [`CudaBackendParts`].
64 /// `label` names the calling module (e.g. `"bms_flex"`) and is woven
65 /// into both failure messages so the uniform contract still attributes
66 /// errors to their originating backend.
67 pub fn probe_cuda_backend(label: &'static str) -> Result<CudaBackendParts, GpuError> {
68 let runtime = GpuRuntime::global().ok_or_else(|| GpuError::DriverLibraryUnavailable {
69 reason: format!("{label} backend: no CUDA runtime available"),
70 })?;
71 let ordinal = runtime.selected_device().ordinal;
72 let ctx = cuda_context_for(ordinal).ok_or_else(|| {
73 gpu_err!("{label} backend: failed to create CUDA context for device {ordinal}")
74 })?;
75 let stream = ctx.default_stream();
76 let capability = runtime.selected_device().capability.clone();
77 Ok(CudaBackendParts {
78 ctx,
79 stream,
80 capability,
81 })
82 }
83
84 /// Probe the CUDA backend for `label` and run a backend-specific build
85 /// step on the resolved handles.
86 ///
87 /// This is [`probe_cuda_backend`] plus the one piece that genuinely
88 /// differs between backends: the NVRTC compile (and any per-backend cache
89 /// construction). The runtime resolution, context creation, and stream
90 /// selection — together with their uniform, label-attributed error
91 /// messages — live in the shared probe; `build` receives the resolved
92 /// [`CudaBackendParts`] (so it can clone the `Arc<CudaContext>` /
93 /// `Arc<CudaStream>` it needs) and returns the backend's own state `T`.
94 pub fn probe_backend_with_compile<F, T>(
95 label: &'static str,
96 build: F,
97 ) -> Result<T, GpuError>
98 where
99 F: FnOnce(&CudaBackendParts) -> Result<T, GpuError>,
100 {
101 let parts = probe_cuda_backend(label)?;
102 build(&parts)
103 }
104
105 /// The process-wide device handles every cudarc backend stores after a
106 /// successful probe: the [`CudaContext`], its default [`CudaStream`], the
107 /// lazily NVRTC-compiled [`PtxModuleCache`], and the bucketed
108 /// [`DeviceArena`] of reusable f64 device buffers (held under a `Mutex`
109 /// because large-scale fits dispatch from multiple rayon worker threads; the
110 /// mutex is only held during `alloc` / `release`, not across kernel
111 /// launches). Module-specific backends (`bms_flex`, `survival_flex`, …)
112 /// wrap one of these as their `inner` context so the host-side
113 /// scaffolding (arena pooling, module cache, mutex around alloc) is
114 /// uniform instead of duplicated per backend.
115 pub struct CudaBackendContext {
116 pub ctx: Arc<CudaContext>,
117 pub stream: Arc<CudaStream>,
118 pub module: PtxModuleCache,
119 pub arena: Mutex<DeviceArena>,
120 }
121
122 impl CudaBackendContext {
123 /// Build the stored context from a fresh [`CudaBackendParts`] probe
124 /// result: adopt its context and stream, start an empty module cache
125 /// (the backend's eager-compile step fills it), and an empty device
126 /// arena. The probe's compute `capability` is consumed by the probe
127 /// path itself and is not retained here.
128 pub fn from_parts(parts: CudaBackendParts) -> Self {
129 CudaBackendContext {
130 ctx: parts.ctx,
131 stream: parts.stream,
132 module: PtxModuleCache::new(),
133 arena: Mutex::new(DeviceArena::default()),
134 }
135 }
136 }
137}
138
139#[cfg(all(test, target_os = "linux"))]
140mod tests {
141 use super::probe_cuda_backend;
142 use crate::device_runtime::GpuRuntime;
143 use crate::gpu_error::GpuError;
144
/// Regression guard for the shared probe contract.
///
/// Two host configurations are covered:
///
/// * **No CUDA runtime**: the probe must fail with
///   `DriverLibraryUnavailable`, and the reason must carry the caller's
///   label in the uniform wording.
/// * **Runtime present**: for each of the six migrated backend labels
///   (`bms_flex`, `survival_flex`, `cubic_bspline_moments`, `cubic_cell`,
///   `pirls_row`, `sphere`) the probe must succeed, report the ordinal
///   and compute capability the runtime advertises for its selected
///   device, and return a default stream that synchronizes.
///
/// This keeps every backend routed through one prologue instead of
/// drifting copies.
#[test]
fn shared_probe_matches_runtime_device_and_labels_errors() {
    const MIGRATED_BACKENDS: [&str; 6] = [
        "bms_flex",
        "survival_flex",
        "cubic_bspline_moments",
        "cubic_cell",
        "pirls_row",
        "sphere",
    ];

    let Some(runtime) = GpuRuntime::global() else {
        // Host without a runtime: the failure must be the uniform,
        // label-attributed `DriverLibraryUnavailable`.
        match probe_cuda_backend("bms_flex") {
            Err(GpuError::DriverLibraryUnavailable { reason }) => assert_eq!(
                reason, "bms_flex backend: no CUDA runtime available",
                "shared probe must emit the uniform no-runtime message"
            ),
            other => panic!(
                "expected DriverLibraryUnavailable on a host without a CUDA runtime, got {other:?}"
            ),
        }
        return;
    };

    // Host with a runtime: whatever the label, the probe must agree with
    // the runtime about the selected device.
    let expected_ordinal = runtime.selected_device().ordinal;
    let expected_capability = &runtime.selected_device().capability;

    for label in MIGRATED_BACKENDS {
        let parts = match probe_cuda_backend(label) {
            Ok(parts) => parts,
            Err(err) => panic!("probe for {label} must succeed: {err:?}"),
        };

        assert_eq!(
            parts.ctx.ordinal(),
            expected_ordinal,
            "{label}: context must bind the runtime's selected device ordinal"
        );
        assert_eq!(
            &parts.capability, expected_capability,
            "{label}: probe capability must match the runtime's selected device"
        );

        if let Err(err) = parts.stream.synchronize() {
            panic!("{label}: default stream must sync: {err:?}");
        }
    }
}
207}