Skip to main content

gam_gpu/
backend_probe.rs

1//! Shared CUDA backend-probe contract for every cudarc-backed module under
2//! `src/gpu/*`.
3//!
4//! Before this module existed, every GPU backend (`bms_flex`,
5//! `survival_flex`, `cubic_bspline_moments`, `cubic_cell`, `pirls_row`,
6//! `sphere`, ...) carried its own near-identical `probe_linux` prologue:
7//!
8//!   1. Fetch the process-wide [`GpuRuntime`] or fail with a
9//!      `DriverLibraryUnavailable { reason: "<module> backend: no CUDA
10//!      runtime available" }`.
11//!   2. Read the runtime's selected device ordinal.
12//!   3. Create (or reuse) the per-ordinal [`CudaContext`] or fail with a
13//!      `DriverCallFailed { reason: "<module> backend: failed to create
14//!      CUDA context for device N" }`.
15//!   4. Open the context's default [`CudaStream`].
16//!   5. Carry the device's compute capability alongside the handles.
17//!
18//! Those five steps are identical apart from the per-module label that gets
19//! woven into the two error messages. Drift between copies meant error
20//! wording, capability handling, context reuse, and stream choice could
21//! diverge module to module. This module hosts the single contract: each
22//! backend now calls [`probe_cuda_backend`] with its label and keeps only
23//! its module caches and optional eager-compilation step.
24//!
25//! The migration is atomic: no backend re-implements the prologue, and
26//! there is no transitional shim.
27
28// `CudaBackendParts` is re-exported alongside the probe entry points: sibling
29// crates (`gam-terms`, `gam-models`, ...) call `probe_cuda_backend` and receive
30// a `CudaBackendParts` value (without ever naming the type), so `probe_cuda_backend`'s
31// public return type must itself be reachable or `-D warnings` rejects the leak
32// (`private_interfaces`). It carries no fields a caller can misuse out of context.
33#[cfg(target_os = "linux")]
34pub use linux::{
35    CudaBackendContext, CudaBackendParts, probe_backend_with_compile, probe_cuda_backend,
36};
37
38#[cfg(target_os = "linux")]
39mod linux {
40    use crate::device::GpuCapability;
41    use crate::device_cache::{DeviceArena, PtxModuleCache};
42    use crate::device_runtime::{GpuRuntime, cuda_context_for};
43    use crate::gpu_error::GpuError;
44    use cudarc::driver::{CudaContext, CudaStream};
45    use std::sync::{Arc, Mutex};
46
47    /// The handles every cudarc backend shares once the probe succeeds:
48    /// a context on the runtime's selected device, that context's default
49    /// stream, and the device's compute capability. Module-specific
50    /// backends layer their own caches and optional eager compilation on
51    /// top of these.
52    #[derive(Debug)]
53    pub struct CudaBackendParts {
54        pub ctx: Arc<CudaContext>,
55        pub stream: Arc<CudaStream>,
56        pub capability: GpuCapability,
57    }
58
59    /// Probe the process-wide CUDA backend for the calling module.
60    ///
61    /// Resolves the global [`GpuRuntime`], creates (or reuses) the
62    /// [`CudaContext`] for its selected device, opens that context's
63    /// default stream, and returns the trio bundled in [`CudaBackendParts`].
64    /// `label` names the calling module (e.g. `"bms_flex"`) and is woven
65    /// into both failure messages so the uniform contract still attributes
66    /// errors to their originating backend.
67    pub fn probe_cuda_backend(label: &'static str) -> Result<CudaBackendParts, GpuError> {
68        let runtime = GpuRuntime::global().ok_or_else(|| GpuError::DriverLibraryUnavailable {
69            reason: format!("{label} backend: no CUDA runtime available"),
70        })?;
71        let ordinal = runtime.selected_device().ordinal;
72        let ctx = cuda_context_for(ordinal).ok_or_else(|| {
73            gpu_err!("{label} backend: failed to create CUDA context for device {ordinal}")
74        })?;
75        let stream = ctx.default_stream();
76        let capability = runtime.selected_device().capability.clone();
77        Ok(CudaBackendParts {
78            ctx,
79            stream,
80            capability,
81        })
82    }
83
84    /// Probe the CUDA backend for `label` and run a backend-specific build
85    /// step on the resolved handles.
86    ///
87    /// This is [`probe_cuda_backend`] plus the one piece that genuinely
88    /// differs between backends: the NVRTC compile (and any per-backend cache
89    /// construction). The runtime resolution, context creation, and stream
90    /// selection — together with their uniform, label-attributed error
91    /// messages — live in the shared probe; `build` receives the resolved
92    /// [`CudaBackendParts`] (so it can clone the `Arc<CudaContext>` /
93    /// `Arc<CudaStream>` it needs) and returns the backend's own state `T`.
94    pub fn probe_backend_with_compile<F, T>(label: &'static str, build: F) -> Result<T, GpuError>
95    where
96        F: FnOnce(&CudaBackendParts) -> Result<T, GpuError>,
97    {
98        let parts = probe_cuda_backend(label)?;
99        build(&parts)
100    }
101
102    /// The process-wide device handles every cudarc backend stores after a
103    /// successful probe: the [`CudaContext`], its default [`CudaStream`], the
104    /// lazily NVRTC-compiled [`PtxModuleCache`], and the bucketed
105    /// [`DeviceArena`] of reusable f64 device buffers (held under a `Mutex`
106    /// because large-scale fits dispatch from multiple rayon worker threads; the
107    /// mutex is only held during `alloc` / `release`, not across kernel
108    /// launches). Module-specific backends (`bms_flex`, `survival_flex`, …)
109    /// wrap one of these as their `inner` context so the host-side
110    /// scaffolding (arena pooling, module cache, mutex around alloc) is
111    /// uniform instead of duplicated per backend.
112    pub struct CudaBackendContext {
113        pub ctx: Arc<CudaContext>,
114        pub stream: Arc<CudaStream>,
115        pub module: PtxModuleCache,
116        pub arena: Mutex<DeviceArena>,
117    }
118
119    impl CudaBackendContext {
120        /// Build the stored context from a fresh [`CudaBackendParts`] probe
121        /// result: adopt its context and stream, start an empty module cache
122        /// (the backend's eager-compile step fills it), and an empty device
123        /// arena. The probe's compute `capability` is consumed by the probe
124        /// path itself and is not retained here.
125        pub fn from_parts(parts: CudaBackendParts) -> Self {
126            CudaBackendContext {
127                ctx: parts.ctx,
128                stream: parts.stream,
129                module: PtxModuleCache::new(),
130                arena: Mutex::new(DeviceArena::default()),
131            }
132        }
133    }
134}
135
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::probe_cuda_backend;
    use crate::device_runtime::GpuRuntime;
    use crate::gpu_error::GpuError;

    /// Labels of the backends routed through the shared probe prologue.
    const MIGRATED_BACKENDS: [&str; 6] = [
        "bms_flex",
        "survival_flex",
        "cubic_bspline_moments",
        "cubic_cell",
        "pirls_row",
        "sphere",
    ];

    /// Parity: every backend's probe must agree with the shared contract on
    /// the same device. On a host with no CUDA runtime, the shared probe
    /// must return the uniform `DriverLibraryUnavailable` carrying each
    /// caller's own label. On a host with a runtime, the probe must resolve the
    /// *same* selected-device ordinal and compute capability the runtime
    /// advertises, with a context bound to that ordinal and a usable
    /// default stream. This is the regression guard that keeps the six
    /// migrated backends (`bms_flex`, `survival_flex`,
    /// `cubic_bspline_moments`, `cubic_cell`, `pirls_row`, `sphere`) routed
    /// through one prologue instead of drifting copies.
    #[test]
    fn shared_probe_matches_runtime_device_and_labels_errors() {
        match GpuRuntime::global() {
            None => {
                // No runtime: the shared probe must fail uniformly and
                // attribute the failure to whichever label was supplied.
                for label in MIGRATED_BACKENDS {
                    match probe_cuda_backend(label) {
                        Err(GpuError::DriverLibraryUnavailable { reason }) => {
                            assert_eq!(
                                reason,
                                format!("{label} backend: no CUDA runtime available"),
                                "{label}: shared probe must emit the uniform no-runtime message"
                            );
                        }
                        other => panic!(
                            "{label}: expected DriverLibraryUnavailable on a host without a \
                             CUDA runtime, got {other:?}"
                        ),
                    }
                }
            }
            Some(runtime) => {
                // Runtime present: every label resolves the same selected
                // device and the same compute capability the runtime
                // advertises, and the context binds to that ordinal.
                let expected_ordinal = runtime.selected_device().ordinal;
                let expected_capability = &runtime.selected_device().capability;
                for label in MIGRATED_BACKENDS {
                    let parts = probe_cuda_backend(label)
                        .unwrap_or_else(|err| panic!("probe for {label} must succeed: {err:?}"));
                    assert_eq!(
                        parts.ctx.ordinal(),
                        expected_ordinal,
                        "{label}: context must bind the runtime's selected device ordinal"
                    );
                    assert_eq!(
                        &parts.capability, expected_capability,
                        "{label}: probe capability must match the runtime's selected device"
                    );
                    parts
                        .stream
                        .synchronize()
                        .unwrap_or_else(|err| panic!("{label}: default stream must sync: {err:?}"));
                }
            }
        }
    }
}