Skip to main content

gam_models/bms/gpu/
flex.rs

1//! Bernoulli marginal-slope FLEX GPU policy and backend probe.
2
3use std::sync::OnceLock;
4
5use gam_gpu::gpu_error::GpuError;
6#[cfg(target_os = "linux")]
7use gam_gpu::gpu_error::GpuResultExt;
8use gam_gpu::{GpuDecision, GpuKernel, decide};
9
10#[cfg(target_os = "linux")]
11use std::sync::Arc;
12
13#[cfg(target_os = "linux")]
14use cudarc::driver::CudaModule;
15
16/// Decide whether the GPU row-primary Hessian path is eligible for this
17/// fit's `(n, r)`. Always-`use_gpu=false` for `r == 0` (no flex jets to
18/// process) and below the runtime row-kernel threshold.
19#[must_use]
20pub fn row_primary_hessian_decision(n: usize, r: usize) -> GpuDecision {
21    let large_enough = gam_gpu::device_runtime::GpuRuntime::global()
22        .map(|runtime| n >= runtime.policy().row_kernel_min_n && r > 0)
23        .unwrap_or(false);
24    decide(
25        GpuKernel::MarginalSlopeRows,
26        gam_gpu::GpuEligibility::from_flags(BmsFlexGpuBackend::compiled(), large_enough),
27    )
28}
29
30/// Same as [`row_primary_hessian_decision`] but turns
31/// `gpu=force`-without-support into an `Err` string at the call site.
32pub fn require_row_primary_hessian_supported(n: usize, r: usize) -> Result<GpuDecision, String> {
33    let decision = row_primary_hessian_decision(n, r);
34    decision.clone().log();
35    decision.require_supported()?;
36    Ok(decision)
37}
38
/// CUDA C source of the BMS flex probe kernel.
///
/// Note that this is kernel *source*, not PTX: it is the text handed to
/// the module cache's `get_or_compile` by
/// `BmsFlexGpuBackend::compile_probe_module`, which NVRTC-compiles it (or
/// fetches the cached module). It defines a single `extern "C"` entry
/// point, `bms_flex_probe`, with an empty body.
///
/// The kernel exists so the scaffolding can exercise the host-side path
/// (compile, module load, function lookup, launch, stream synchronize)
/// and surface failures there before the real row kernel is dispatched.
/// Linux-only, like the rest of the CUDA-backed code in this file.
#[cfg(target_os = "linux")]
pub(crate) const PROBE_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void bms_flex_probe() {
    // Intentionally empty. This kernel exists only so the scaffolding can
    // verify NVRTC compile + module load + launch + synchronize on the
    // selected device. The real row math lives in the bms_flex_row module.
}
"#;
52
/// Process-wide BMS-flex GPU backend handle.
///
/// The shared instance is created lazily on the first call to
/// [`BmsFlexGpuBackend::probe`] and cached for the lifetime of the
/// process; later calls hand back the same `&'static` reference (or a
/// clone of the cached error).
#[must_use]
pub struct BmsFlexGpuBackend {
    /// CUDA-side state for this backend. The methods in this file use its
    /// `ctx`, `stream`, `module` and `arena` members. The field only
    /// exists on Linux builds; on every other target the struct has no
    /// fields.
    #[cfg(target_os = "linux")]
    pub(crate) inner: gam_gpu::backend_probe::CudaBackendContext,
}
60
61impl BmsFlexGpuBackend {
62    /// Returns `true` if the BMS flex GPU backend is compiled into this
63    /// build (Linux + cudarc). On non-Linux builds returns `false` so the
64    /// policy gate reports `cpu-gpu-backend-not-compiled` like the rest
65    /// of the GPU layer.
66    pub const fn compiled() -> bool {
67        cfg!(target_os = "linux")
68    }
69
70    /// Lazily initialise the process-wide BMS flex backend. On the first
71    /// successful call this creates a CUDA context on the runtime's
72    /// selected device, opens a stream, and NVRTC-compiles the probe
73    /// kernel. Subsequent calls return the cached handle.
74    pub fn probe() -> Result<&'static Self, GpuError> {
75        static BACKEND: OnceLock<Result<BmsFlexGpuBackend, GpuError>> = OnceLock::new();
76        BACKEND
77            .get_or_init(|| {
78                #[cfg(target_os = "linux")]
79                {
80                    Self::probe_linux()
81                }
82                #[cfg(not(target_os = "linux"))]
83                {
84                    Err(GpuError::DriverLibraryUnavailable {
85                        reason: "bms_flex GPU backend is Linux-only".to_string(),
86                    })
87                }
88            })
89            .as_ref()
90            .map_err(GpuError::clone)
91    }
92
93    #[cfg(target_os = "linux")]
94    pub(crate) fn probe_linux() -> Result<Self, GpuError> {
95        let parts = gam_gpu::backend_probe::probe_cuda_backend("bms_flex")?;
96        let backend = BmsFlexGpuBackend {
97            inner: gam_gpu::backend_probe::CudaBackendContext::from_parts(parts),
98        };
99        // Eagerly compile the probe kernel so any NVRTC failure surfaces
100        // here, not at first dispatch.
101        backend.compile_probe_module()?;
102        Ok(backend)
103    }
104
105    /// NVRTC-compile (or fetch from cache) the probe module.
106    #[cfg(target_os = "linux")]
107    pub(crate) fn compile_probe_module(&self) -> Result<&Arc<CudaModule>, GpuError> {
108        self.inner
109            .module
110            .get_or_compile(&self.inner.ctx, "bms_flex", PROBE_KERNEL_SOURCE)
111    }
112
113    /// Launch the probe kernel and synchronize. Used by tests and by the
114    /// dispatcher's policy gate to verify the full host-orchestration
115    /// path before the real row kernel is dispatched.
116    #[cfg(target_os = "linux")]
117    pub fn launch_probe(&self) -> Result<(), GpuError> {
118        use cudarc::driver::LaunchConfig;
119        let module = self.compile_probe_module()?;
120        let func = module
121            .load_function("bms_flex_probe")
122            .gpu_ctx("bms_flex probe load_function")?;
123        let cfg = LaunchConfig {
124            grid_dim: (1, 1, 1),
125            block_dim: (1, 1, 1),
126            shared_mem_bytes: 0,
127        };
128        let mut builder = self.inner.stream.launch_builder(&func);
129        // SAFETY: probe kernel takes no arguments and does no memory
130        // access, so launch parameters and lack of args are trivially
131        // valid for any device.
132        unsafe { builder.launch(cfg) }.gpu_ctx("bms_flex probe launch")?;
133        self.inner
134            .stream
135            .synchronize()
136            .gpu_ctx("bms_flex probe synchronize")?;
137        Ok(())
138    }
139
140    #[cfg(not(target_os = "linux"))]
141    pub fn launch_probe(&self) -> Result<(), GpuError> {
142        Err(GpuError::DriverLibraryUnavailable {
143            reason: "bms_flex GPU backend is Linux-only".to_string(),
144        })
145    }
146
147    /// Round-trip the arena: allocate a slab, immediately release it.
148    /// Used by the device-side smoke test to verify the arena code path
149    /// is exercised; production milestones will hold slabs across the
150    /// whole row sweep instead.
151    #[cfg(target_os = "linux")]
152    pub fn arena_round_trip(&self, elements: usize) -> Result<usize, GpuError> {
153        let mut guard = self
154            .inner
155            .arena
156            .lock()
157            .gpu_ctx("bms_flex arena mutex poisoned")?;
158        let (bucket, slab) = guard.alloc(&self.inner.stream, elements, "bms_flex")?;
159        guard.release(bucket, slab);
160        Ok(bucket)
161    }
162
163    /// Return a short string describing the backend state, for logs.
164    pub fn describe(&self) -> String {
165        #[cfg(target_os = "linux")]
166        {
167            return format!(
168                "bms_flex backend: device={:?} module_loaded={}",
169                self.inner.ctx.name().ok(),
170                self.inner.module.get().is_some()
171            );
172        }
173        #[cfg(not(target_os = "linux"))]
174        {
175            "bms_flex backend: unavailable (not Linux)".to_string()
176        }
177    }
178}
179
180// ────────────────────────────────────────────────────────────────────────
181// Tests. Run via `cargo test -p gam bms_flex_gpu -- --nocapture`.
182// ────────────────────────────────────────────────────────────────────────
183
#[cfg(test)]
mod bms_flex_gpu_tests {
    use super::*;

    /// The policy decision must always be tagged with the marginal-slope
    /// row kernel, whatever the host's GPU situation is.
    #[test]
    pub(crate) fn bms_flex_gpu_policy_decision_is_explicit() {
        let verdict = row_primary_hessian_decision(50_000, 4);
        assert_eq!(verdict.kernel, GpuKernel::MarginalSlopeRows);
    }

    /// Device-side smoke test: backend probe, probe-kernel launch and
    /// sync, and (on Linux) two arena round trips. Returns early without
    /// failing when no global GPU runtime is available, so it still
    /// passes on builders that have no device.
    #[test]
    pub(crate) fn bms_flex_gpu_context_initialises_when_device_present() {
        let runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(rt) => rt,
            None => {
                eprintln!(
                    "[bms_flex_gpu test] no CUDA runtime — skipping device-side init smoketest"
                );
                return;
            }
        };
        eprintln!(
            "[bms_flex_gpu test] runtime selected device ordinal={}",
            runtime.selected_device().ordinal
        );

        let backend = match BmsFlexGpuBackend::probe() {
            Ok(handle) => handle,
            Err(err) => panic!(
                "BmsFlexGpuBackend::probe failed on a host that reports a CUDA runtime: {err}"
            ),
        };
        eprintln!("[bms_flex_gpu test] {}", backend.describe());

        backend
            .launch_probe()
            .expect("probe kernel must launch+sync on a host with a usable device");

        #[cfg(target_os = "linux")]
        {
            const REQUESTED: usize = 1024;
            let round_trip = || {
                backend
                    .arena_round_trip(REQUESTED)
                    .expect("arena round-trip must succeed on a host with a usable device")
            };

            let first = round_trip();
            assert!(first >= REQUESTED, "bucket must be >= requested elements");
            // Repeating the same request should land in the same bucket.
            let second = round_trip();
            assert_eq!(first, second, "bucket size must be stable for same input");
        }
    }
}