Skip to main content

gam_gpu/
backend_probe.rs

1//! Shared CUDA backend-probe contract for every cudarc-backed module under
2//! `src/gpu/*`.
3//!
4//! Before this module existed, every GPU backend (`bms_flex`,
5//! `survival_flex`, `cubic_bspline_moments`, `cubic_cell`, `pirls_row`,
6//! `sphere`, ...) carried its own near-identical `probe_linux` prologue:
7//!
8//!   1. Fetch the process-wide [`GpuRuntime`] or fail with a
9//!      `DriverLibraryUnavailable { reason: "<module> backend: no CUDA
10//!      runtime available" }`.
11//!   2. Read the runtime's selected device ordinal.
12//!   3. Create (or reuse) the per-ordinal [`CudaContext`] or fail with a
13//!      `DriverCallFailed { reason: "<module> backend: failed to create
14//!      CUDA context for device N" }`.
15//!   4. Open the context's default [`CudaStream`].
16//!   5. Carry the device's compute capability alongside the handles.
17//!
18//! Those five steps are identical apart from the per-module label that gets
19//! woven into the two error messages. Drift between copies meant error
20//! wording, capability handling, context reuse, and stream choice could
21//! diverge module to module. This module hosts the single contract: each
22//! backend now calls [`probe_cuda_backend`] with its label and keeps only
23//! its module caches and optional eager-compilation step.
24//!
25//! The migration is atomic: no backend re-implements the prologue, and
26//! there is no transitional shim.
27
28// `CudaBackendParts` is re-exported alongside the probe entry points: sibling
29// crates (`gam-terms`, `gam-models`, ...) call `probe_cuda_backend` and receive
30// a `CudaBackendParts` value (without ever naming the type), so `probe_cuda_backend`'s
31// public return type must itself be reachable or `-D warnings` rejects the leak
32// (`private_interfaces`). It carries no fields a caller can misuse out of context.
33#[cfg(target_os = "linux")]
34pub use linux::{
35    CudaBackendContext, CudaBackendParts, probe_backend_with_compile, probe_cuda_backend,
36};
37
38#[cfg(target_os = "linux")]
39mod linux {
40    use crate::device::GpuCapability;
41    use crate::device_cache::{DeviceArena, PtxModuleCache};
42    use crate::device_runtime::{GpuRuntime, cuda_context_for};
43    use crate::gpu_error::GpuError;
44    use cudarc::driver::{CudaContext, CudaStream};
45    use std::sync::{Arc, Mutex};
46
    /// The handles every cudarc backend shares once the probe succeeds:
    /// a context on the runtime's selected device, that context's default
    /// stream, and the device's compute capability. Module-specific
    /// backends layer their own caches and optional eager compilation on
    /// top of these.
    ///
    /// Values of this type are produced by [`probe_cuda_backend`]; the field
    /// notes below describe what that function stores in each slot.
    #[derive(Debug)]
    pub struct CudaBackendParts {
        /// Context returned by `cuda_context_for` for the runtime's selected
        /// device ordinal.
        pub ctx: Arc<CudaContext>,
        /// Result of calling `default_stream()` on `ctx`.
        pub stream: Arc<CudaStream>,
        /// Clone of the `capability` field of the runtime's selected device.
        pub capability: GpuCapability,
    }
58
59    /// Probe the process-wide CUDA backend for the calling module.
60    ///
61    /// Resolves the global [`GpuRuntime`], creates (or reuses) the
62    /// [`CudaContext`] for its selected device, opens that context's
63    /// default stream, and returns the trio bundled in [`CudaBackendParts`].
64    /// `label` names the calling module (e.g. `"bms_flex"`) and is woven
65    /// into both failure messages so the uniform contract still attributes
66    /// errors to their originating backend.
67    pub fn probe_cuda_backend(label: &'static str) -> Result<CudaBackendParts, GpuError> {
68        let runtime = GpuRuntime::global().ok_or_else(|| GpuError::DriverLibraryUnavailable {
69            reason: format!("{label} backend: no CUDA runtime available"),
70        })?;
71        let ordinal = runtime.selected_device().ordinal;
72        let ctx = cuda_context_for(ordinal).ok_or_else(|| {
73            gpu_err!("{label} backend: failed to create CUDA context for device {ordinal}")
74        })?;
75        let stream = ctx.default_stream();
76        let capability = runtime.selected_device().capability.clone();
77        Ok(CudaBackendParts {
78            ctx,
79            stream,
80            capability,
81        })
82    }
83
84    /// Probe the CUDA backend for `label` and run a backend-specific build
85    /// step on the resolved handles.
86    ///
87    /// This is [`probe_cuda_backend`] plus the one piece that genuinely
88    /// differs between backends: the NVRTC compile (and any per-backend cache
89    /// construction). The runtime resolution, context creation, and stream
90    /// selection — together with their uniform, label-attributed error
91    /// messages — live in the shared probe; `build` receives the resolved
92    /// [`CudaBackendParts`] (so it can clone the `Arc<CudaContext>` /
93    /// `Arc<CudaStream>` it needs) and returns the backend's own state `T`.
94    pub fn probe_backend_with_compile<F, T>(
95        label: &'static str,
96        build: F,
97    ) -> Result<T, GpuError>
98    where
99        F: FnOnce(&CudaBackendParts) -> Result<T, GpuError>,
100    {
101        let parts = probe_cuda_backend(label)?;
102        build(&parts)
103    }
104
    /// The process-wide device handles every cudarc backend stores after a
    /// successful probe: the [`CudaContext`], its default [`CudaStream`], the
    /// lazily NVRTC-compiled [`PtxModuleCache`], and the bucketed
    /// [`DeviceArena`] of reusable f64 device buffers (held under a `Mutex`
    /// because large-scale fits dispatch from multiple rayon worker threads; the
    /// mutex is only held during `alloc` / `release`, not across kernel
    /// launches). Module-specific backends (`bms_flex`, `survival_flex`, …)
    /// wrap one of these as their `inner` context so the host-side
    /// scaffolding (arena pooling, module cache, mutex around alloc) is
    /// uniform instead of duplicated per backend.
    ///
    /// Use [`CudaBackendContext::from_parts`] to build one from a probe
    /// result.
    pub struct CudaBackendContext {
        /// CUDA context adopted from the probe result.
        pub ctx: Arc<CudaContext>,
        /// Stream adopted from the probe result.
        pub stream: Arc<CudaStream>,
        /// Module cache; `from_parts` initialises it with `PtxModuleCache::new()`.
        pub module: PtxModuleCache,
        /// Device-buffer arena behind a mutex; `from_parts` initialises it with
        /// `DeviceArena::default()`.
        // NOTE(review): the lock-scope policy described above (held only during
        // `alloc` / `release`) is enforced by callers, not by this type — confirm
        // at the call sites.
        pub arena: Mutex<DeviceArena>,
    }
121
122    impl CudaBackendContext {
123        /// Build the stored context from a fresh [`CudaBackendParts`] probe
124        /// result: adopt its context and stream, start an empty module cache
125        /// (the backend's eager-compile step fills it), and an empty device
126        /// arena. The probe's compute `capability` is consumed by the probe
127        /// path itself and is not retained here.
128        pub fn from_parts(parts: CudaBackendParts) -> Self {
129            CudaBackendContext {
130                ctx: parts.ctx,
131                stream: parts.stream,
132                module: PtxModuleCache::new(),
133                arena: Mutex::new(DeviceArena::default()),
134            }
135        }
136    }
137}
138
139#[cfg(all(test, target_os = "linux"))]
140mod tests {
141    use super::probe_cuda_backend;
142    use crate::device_runtime::GpuRuntime;
143    use crate::gpu_error::GpuError;
144
145    /// Parity: every backend's probe must agree with the shared contract on
146    /// the same device. On a host with no CUDA runtime, the shared probe
147    /// must return the uniform `DriverLibraryUnavailable` carrying the
148    /// caller's label; on a host with a runtime, the probe must resolve the
149    /// *same* selected-device ordinal and compute capability the runtime
150    /// advertises, with a context bound to that ordinal and a usable
151    /// default stream. This is the regression guard that keeps the six
152    /// migrated backends (`bms_flex`, `survival_flex`,
153    /// `cubic_bspline_moments`, `cubic_cell`, `pirls_row`, `sphere`) routed
154    /// through one prologue instead of drifting copies.
155    #[test]
156    fn shared_probe_matches_runtime_device_and_labels_errors() {
157        match GpuRuntime::global() {
158            None => {
159                // No runtime: the shared probe must fail uniformly and
160                // attribute the failure to the supplied label.
161                match probe_cuda_backend("bms_flex") {
162                    Err(GpuError::DriverLibraryUnavailable { reason }) => {
163                        assert_eq!(
164                            reason, "bms_flex backend: no CUDA runtime available",
165                            "shared probe must emit the uniform no-runtime message"
166                        );
167                    }
168                    other => panic!(
169                        "expected DriverLibraryUnavailable on a host without a CUDA runtime, \
170                         got {other:?}"
171                    ),
172                }
173            }
174            Some(runtime) => {
175                // Runtime present: every label resolves the same selected
176                // device and the same compute capability the runtime
177                // advertises, and the context binds to that ordinal.
178                let expected_ordinal = runtime.selected_device().ordinal;
179                let expected_capability = &runtime.selected_device().capability;
180                for label in [
181                    "bms_flex",
182                    "survival_flex",
183                    "cubic_bspline_moments",
184                    "cubic_cell",
185                    "pirls_row",
186                    "sphere",
187                ] {
188                    let parts = probe_cuda_backend(label)
189                        .unwrap_or_else(|err| panic!("probe for {label} must succeed: {err:?}"));
190                    assert_eq!(
191                        parts.ctx.ordinal(),
192                        expected_ordinal,
193                        "{label}: context must bind the runtime's selected device ordinal"
194                    );
195                    assert_eq!(
196                        &parts.capability, expected_capability,
197                        "{label}: probe capability must match the runtime's selected device"
198                    );
199                    parts
200                        .stream
201                        .synchronize()
202                        .unwrap_or_else(|err| panic!("{label}: default stream must sync: {err:?}"));
203                }
204            }
205        }
206    }
207}