gam_models/bms/gpu/
flex.rs1use std::sync::OnceLock;
4
5use gam_gpu::gpu_error::GpuError;
6#[cfg(target_os = "linux")]
7use gam_gpu::gpu_error::GpuResultExt;
8use gam_gpu::{GpuDecision, GpuKernel, decide};
9
10#[cfg(target_os = "linux")]
11use std::sync::Arc;
12
13#[cfg(target_os = "linux")]
14use cudarc::driver::CudaModule;
15
16#[must_use]
20pub fn row_primary_hessian_decision(n: usize, r: usize) -> GpuDecision {
21 let large_enough = gam_gpu::device_runtime::GpuRuntime::global()
22 .map(|runtime| n >= runtime.policy().row_kernel_min_n && r > 0)
23 .unwrap_or(false);
24 decide(
25 GpuKernel::MarginalSlopeRows,
26 gam_gpu::GpuEligibility::from_flags(BmsFlexGpuBackend::compiled(), large_enough),
27 )
28}
29
30pub fn require_row_primary_hessian_supported(n: usize, r: usize) -> Result<GpuDecision, String> {
33 let decision = row_primary_hessian_decision(n, r);
34 decision.clone().log();
35 decision.require_supported()?;
36 Ok(decision)
37}
38
/// Source text of the `bms_flex_probe` kernel: an `extern "C" __global__`
/// function that takes no arguments and has an empty body.
///
/// Only defined on Linux. `BmsFlexGpuBackend::compile_probe_module` passes this
/// string to the backend's module cache (`get_or_compile`) under the name
/// `"bms_flex"`, and `BmsFlexGpuBackend::launch_probe` launches the resulting
/// function.
39#[cfg(target_os = "linux")]
45pub(crate) const PROBE_KERNEL_SOURCE: &str = r#"
46extern "C" __global__ void bms_flex_probe() {
47 // Intentionally empty. This kernel exists only so the scaffolding can
48 // verify NVRTC compile + module load + launch + synchronize on the
49 // selected device. The real row math lives in the bms_flex_row module.
50}
51"#;
52
53#[must_use]
56pub struct BmsFlexGpuBackend {
57 #[cfg(target_os = "linux")]
58 pub(crate) inner: gam_gpu::backend_probe::CudaBackendContext,
59}
60
61impl BmsFlexGpuBackend {
62 pub const fn compiled() -> bool {
67 cfg!(target_os = "linux")
68 }
69
70 pub fn probe() -> Result<&'static Self, GpuError> {
75 static BACKEND: OnceLock<Result<BmsFlexGpuBackend, GpuError>> = OnceLock::new();
76 BACKEND
77 .get_or_init(|| {
78 #[cfg(target_os = "linux")]
79 {
80 Self::probe_linux()
81 }
82 #[cfg(not(target_os = "linux"))]
83 {
84 Err(GpuError::DriverLibraryUnavailable {
85 reason: "bms_flex GPU backend is Linux-only".to_string(),
86 })
87 }
88 })
89 .as_ref()
90 .map_err(GpuError::clone)
91 }
92
93 #[cfg(target_os = "linux")]
94 pub(crate) fn probe_linux() -> Result<Self, GpuError> {
95 let parts = gam_gpu::backend_probe::probe_cuda_backend("bms_flex")?;
96 let backend = BmsFlexGpuBackend {
97 inner: gam_gpu::backend_probe::CudaBackendContext::from_parts(parts),
98 };
99 backend.compile_probe_module()?;
102 Ok(backend)
103 }
104
105 #[cfg(target_os = "linux")]
107 pub(crate) fn compile_probe_module(&self) -> Result<&Arc<CudaModule>, GpuError> {
108 self.inner
109 .module
110 .get_or_compile(&self.inner.ctx, "bms_flex", PROBE_KERNEL_SOURCE)
111 }
112
113 #[cfg(target_os = "linux")]
117 pub fn launch_probe(&self) -> Result<(), GpuError> {
118 use cudarc::driver::LaunchConfig;
119 let module = self.compile_probe_module()?;
120 let func = module
121 .load_function("bms_flex_probe")
122 .gpu_ctx("bms_flex probe load_function")?;
123 let cfg = LaunchConfig {
124 grid_dim: (1, 1, 1),
125 block_dim: (1, 1, 1),
126 shared_mem_bytes: 0,
127 };
128 let mut builder = self.inner.stream.launch_builder(&func);
129 unsafe { builder.launch(cfg) }.gpu_ctx("bms_flex probe launch")?;
133 self.inner
134 .stream
135 .synchronize()
136 .gpu_ctx("bms_flex probe synchronize")?;
137 Ok(())
138 }
139
140 #[cfg(not(target_os = "linux"))]
141 pub fn launch_probe(&self) -> Result<(), GpuError> {
142 Err(GpuError::DriverLibraryUnavailable {
143 reason: "bms_flex GPU backend is Linux-only".to_string(),
144 })
145 }
146
147 #[cfg(target_os = "linux")]
152 pub fn arena_round_trip(&self, elements: usize) -> Result<usize, GpuError> {
153 let mut guard = self
154 .inner
155 .arena
156 .lock()
157 .gpu_ctx("bms_flex arena mutex poisoned")?;
158 let (bucket, slab) = guard.alloc(&self.inner.stream, elements, "bms_flex")?;
159 guard.release(bucket, slab);
160 Ok(bucket)
161 }
162
163 pub fn describe(&self) -> String {
165 #[cfg(target_os = "linux")]
166 {
167 return format!(
168 "bms_flex backend: device={:?} module_loaded={}",
169 self.inner.ctx.name().ok(),
170 self.inner.module.get().is_some()
171 );
172 }
173 #[cfg(not(target_os = "linux"))]
174 {
175 "bms_flex backend: unavailable (not Linux)".to_string()
176 }
177 }
178}
179
/// Tests for the bms_flex GPU scaffolding. The device-side test returns early
/// when no global GPU runtime is available.
#[cfg(test)]
mod bms_flex_gpu_tests {
    use super::*;

    /// The decision must always name the `MarginalSlopeRows` kernel, whatever
    /// the eligibility outcome is on this host.
    #[test]
    pub(crate) fn bms_flex_gpu_policy_decision_is_explicit() {
        let verdict = row_primary_hessian_decision(50_000, 4);
        assert_eq!(verdict.kernel, GpuKernel::MarginalSlopeRows);
    }

    /// When a global GPU runtime exists, the backend must probe, launch the
    /// probe kernel, and (on Linux) complete two arena round-trips that
    /// report the same bucket.
    #[test]
    pub(crate) fn bms_flex_gpu_context_initialises_when_device_present() {
        let runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!(
                    "[bms_flex_gpu test] no CUDA runtime — skipping device-side init smoketest"
                );
                return;
            }
        };
        eprintln!(
            "[bms_flex_gpu test] runtime selected device ordinal={}",
            runtime.selected_device().ordinal
        );

        let backend = match BmsFlexGpuBackend::probe() {
            Ok(backend) => backend,
            Err(err) => panic!(
                "BmsFlexGpuBackend::probe failed on a host that reports a CUDA runtime: {err}"
            ),
        };
        eprintln!("[bms_flex_gpu test] {}", backend.describe());

        backend
            .launch_probe()
            .expect("probe kernel must launch+sync on a host with a usable device");

        #[cfg(target_os = "linux")]
        {
            let round_trip = || {
                backend
                    .arena_round_trip(1024)
                    .expect("arena round-trip must succeed on a host with a usable device")
            };
            let first = round_trip();
            assert!(first >= 1024, "bucket must be >= requested elements");
            // A second request of the same size must report the same bucket.
            let second = round_trip();
            assert_eq!(first, second, "bucket size must be stable for same input");
        }
    }
}