use std::sync::OnceLock;
use crate::gpu::gpu_error::GpuError;
#[cfg(target_os = "linux")]
use crate::gpu::gpu_error::GpuResultExt;
use crate::gpu::{GpuDecision, GpuKernel, decide};
#[cfg(target_os = "linux")]
use std::sync::Arc;
#[cfg(target_os = "linux")]
use cudarc::driver::CudaModule;
/// Builds the GPU dispatch decision for the [`GpuKernel::MarginalSlopeRows`] kernel.
///
/// The eligibility handed to [`decide`] is assembled from two flags:
///
/// * whether the bms_flex backend is compiled for this target
///   ([`BmsFlexGpuBackend::compiled`]), and
/// * whether the problem passes the size test: a global GPU runtime must
///   exist, `n` must be at least the runtime policy's `row_kernel_min_n`,
///   and `r` must be non-zero.
///
/// When no global runtime is available the size test is treated as failed.
#[must_use]
pub fn row_primary_hessian_decision(n: usize, r: usize) -> GpuDecision {
    // No runtime means no policy to consult, so the size flag is simply false.
    let meets_size_policy = match crate::gpu::device_runtime::GpuRuntime::global() {
        Some(rt) => n >= rt.policy().row_kernel_min_n && r > 0,
        None => false,
    };
    let eligibility =
        crate::gpu::GpuEligibility::from_flags(BmsFlexGpuBackend::compiled(), meets_size_policy);
    decide(GpuKernel::MarginalSlopeRows, eligibility)
}
/// Computes the row-primary Hessian GPU decision and insists that it is supported.
///
/// The decision from [`row_primary_hessian_decision`] is logged (via a clone)
/// before the support check runs, so a rejected decision is still logged.
///
/// # Errors
///
/// Propagates the error produced by the decision's `require_supported` check.
pub fn require_row_primary_hessian_supported(n: usize, r: usize) -> Result<GpuDecision, String> {
    let verdict = row_primary_hessian_decision(n, r);
    // Log a copy so the original can still be checked and handed back.
    verdict.clone().log();
    verdict.require_supported()?;
    Ok(verdict)
}
/// CUDA source text defining a single empty kernel named `bms_flex_probe`.
///
/// `BmsFlexGpuBackend::compile_probe_module` passes this string to the
/// backend context's `get_or_compile`, and `BmsFlexGpuBackend::launch_probe`
/// loads and launches the `bms_flex_probe` function from the resulting module.
/// Only present in Linux builds.
#[cfg(target_os = "linux")]
pub(crate) const PROBE_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void bms_flex_probe() {
// Intentionally empty. This kernel exists only so the scaffolding can
// verify NVRTC compile + module load + launch + synchronize on the
// selected device. The real row math lives in the bms_flex_row module.
}
"#;
/// Handle to the bms_flex GPU backend.
///
/// In Linux builds it wraps a `CudaBackendContext`; on every other target the
/// struct has no fields and its methods report the backend as unavailable.
#[must_use]
pub struct BmsFlexGpuBackend {
/// Backend context (the impl reads its `ctx`, `stream`, `module` and
/// `arena` members). Linux builds only.
#[cfg(target_os = "linux")]
pub(crate) inner: crate::gpu::backend_probe::CudaBackendContext,
}
impl BmsFlexGpuBackend {
    /// Reports whether this build contains the CUDA-backed implementation.
    ///
    /// This is a compile-time fact: `true` when the target OS is Linux,
    /// `false` otherwise.
    pub const fn compiled() -> bool {
        cfg!(target_os = "linux")
    }

    /// Returns the shared backend instance, probing it on the first call.
    ///
    /// The outcome of the first probe, success or failure, is stored in a
    /// function-local `OnceLock`, so later calls yield the same reference or
    /// a clone of the same error without probing again.
    ///
    /// # Errors
    ///
    /// On Linux, whatever [`Self::probe_linux`] reported on the first call.
    /// On other targets, always `GpuError::DriverLibraryUnavailable`.
    pub fn probe() -> Result<&'static Self, GpuError> {
        static BACKEND: OnceLock<Result<BmsFlexGpuBackend, GpuError>> = OnceLock::new();
        let cached = BACKEND.get_or_init(|| {
            #[cfg(target_os = "linux")]
            {
                Self::probe_linux()
            }
            #[cfg(not(target_os = "linux"))]
            {
                Err(GpuError::DriverLibraryUnavailable {
                    reason: "bms_flex GPU backend is Linux-only".to_string(),
                })
            }
        });
        // The cache owns the error, so callers receive their own copy.
        match cached {
            Ok(backend) => Ok(backend),
            Err(err) => Err(err.clone()),
        }
    }

    /// Performs the actual (uncached) probe on Linux.
    ///
    /// Probes the CUDA backend under the label `"bms_flex"`, wraps the result
    /// in a backend context, and compiles the probe module before returning so
    /// that a compile failure surfaces here rather than at first launch.
    #[cfg(target_os = "linux")]
    pub(crate) fn probe_linux() -> Result<Self, GpuError> {
        let inner = crate::gpu::backend_probe::CudaBackendContext::from_parts(
            crate::gpu::backend_probe::probe_cuda_backend("bms_flex")?,
        );
        let candidate = Self { inner };
        candidate.compile_probe_module()?;
        Ok(candidate)
    }

    /// Obtains the module built from [`PROBE_KERNEL_SOURCE`] through the
    /// context's `get_or_compile`.
    // NOTE(review): the name `get_or_compile` suggests the module is compiled
    // once and reused afterwards; confirm against its implementation.
    #[cfg(target_os = "linux")]
    pub(crate) fn compile_probe_module(&self) -> Result<&Arc<CudaModule>, GpuError> {
        let cx = &self.inner;
        cx.module
            .get_or_compile(&cx.ctx, "bms_flex", PROBE_KERNEL_SOURCE)
    }

    /// Launches the empty `bms_flex_probe` kernel and waits for the stream to
    /// synchronize.
    ///
    /// # Errors
    ///
    /// Returns a `GpuError` if obtaining the module, loading the function,
    /// launching, or synchronizing fails.
    #[cfg(target_os = "linux")]
    pub fn launch_probe(&self) -> Result<(), GpuError> {
        use cudarc::driver::LaunchConfig;

        // One block of one thread with no dynamic shared memory.
        let single_thread = LaunchConfig {
            block_dim: (1, 1, 1),
            grid_dim: (1, 1, 1),
            shared_mem_bytes: 0,
        };
        let stream = &self.inner.stream;

        let probe_module = self.compile_probe_module()?;
        let kernel = probe_module
            .load_function("bms_flex_probe")
            .gpu_ctx("bms_flex probe load_function")?;

        let mut launch = stream.launch_builder(&kernel);
        // SAFETY: `bms_flex_probe` is declared with no parameters and no
        // arguments are pushed onto the builder, so the launch signature
        // matches the kernel.
        unsafe { launch.launch(single_thread) }.gpu_ctx("bms_flex probe launch")?;

        stream.synchronize().gpu_ctx("bms_flex probe synchronize")?;
        Ok(())
    }

    /// Non-Linux stand-in for the probe launch.
    ///
    /// # Errors
    ///
    /// Always returns `GpuError::DriverLibraryUnavailable`.
    #[cfg(not(target_os = "linux"))]
    pub fn launch_probe(&self) -> Result<(), GpuError> {
        Err(GpuError::DriverLibraryUnavailable {
            reason: "bms_flex GPU backend is Linux-only".to_string(),
        })
    }

    /// Allocates `elements` from the arena on the backend stream, releases
    /// the allocation immediately, and returns the bucket value reported by
    /// the allocation.
    ///
    /// # Errors
    ///
    /// Returns a `GpuError` if the arena mutex is poisoned or the allocation
    /// fails.
    // NOTE(review): the returned bucket looks like the size class chosen for
    // the request (the test expects it to be >= `elements`); confirm against
    // the arena implementation.
    #[cfg(target_os = "linux")]
    pub fn arena_round_trip(&self, elements: usize) -> Result<usize, GpuError> {
        let cx = &self.inner;
        let mut arena = cx.arena.lock().gpu_ctx("bms_flex arena mutex poisoned")?;
        let (bucket, slab) = arena.alloc(&cx.stream, elements, "bms_flex")?;
        arena.release(bucket, slab);
        Ok(bucket)
    }

    /// Produces a one-line, human-readable description of the backend.
    ///
    /// On Linux it includes the device name (if it can be queried) and whether
    /// the probe module is currently loaded; elsewhere it states that the
    /// backend is unavailable.
    pub fn describe(&self) -> String {
        #[cfg(target_os = "linux")]
        let summary = {
            let device = self.inner.ctx.name().ok();
            let module_loaded = self.inner.module.get().is_some();
            format!(
                "bms_flex backend: device={:?} module_loaded={}",
                device, module_loaded
            )
        };
        #[cfg(not(target_os = "linux"))]
        let summary = "bms_flex backend: unavailable (not Linux)".to_string();
        summary
    }
}
#[cfg(test)]
mod bms_flex_gpu_tests {
    use super::*;

    /// The decision must always name the marginal-slope row kernel, whatever
    /// the host's GPU situation is.
    #[test]
    pub(crate) fn bms_flex_gpu_policy_decision_is_explicit() {
        let verdict = row_primary_hessian_decision(50_000, 4);
        assert_eq!(verdict.kernel, GpuKernel::MarginalSlopeRows);
    }

    /// Device-side smoke test: when a runtime is reported, the backend must
    /// probe, launch the probe kernel, and (on Linux) round-trip the arena.
    /// Hosts without a runtime skip the test body.
    #[test]
    pub(crate) fn bms_flex_gpu_context_initialises_when_device_present() {
        let runtime = match crate::gpu::device_runtime::GpuRuntime::global() {
            Some(rt) => rt,
            None => {
                eprintln!(
                    "[bms_flex_gpu test] no CUDA runtime — skipping device-side init smoketest"
                );
                return;
            }
        };
        eprintln!(
            "[bms_flex_gpu test] runtime selected device ordinal={}",
            runtime.selected_device().ordinal
        );

        let backend = match BmsFlexGpuBackend::probe() {
            Ok(handle) => handle,
            Err(err) => {
                panic!("BmsFlexGpuBackend::probe failed on a host that reports a CUDA runtime: {err}")
            }
        };
        eprintln!("[bms_flex_gpu test] {}", backend.describe());

        backend
            .launch_probe()
            .expect("probe kernel must launch+sync on a host with a usable device");

        // The arena API only exists in Linux builds.
        #[cfg(target_os = "linux")]
        {
            let first = backend
                .arena_round_trip(1024)
                .expect("arena round-trip must succeed on a host with a usable device");
            assert!(first >= 1024, "bucket must be >= requested elements");

            let second = backend
                .arena_round_trip(1024)
                .expect("arena round-trip must succeed on a host with a usable device");
            assert_eq!(first, second, "bucket size must be stable for same input");
        }
    }
}