use std::sync::OnceLock;
use ndarray::Array2;
use super::error::GpuError;
#[cfg(target_os = "linux")]
use super::error::GpuResultExt;
use super::{GpuDecision, GpuKernel, decide};
#[cfg(target_os = "linux")]
use crate::gpu_err;
use super::bms_flex_row::{
BmsFlexRowKernelInputs, BmsFlexRowKernelOutputs, launch_bms_flex_row_kernel,
};
#[cfg(target_os = "linux")]
use std::sync::{Arc, Mutex};
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaContext, CudaModule, CudaStream};
#[cfg(target_os = "linux")]
use super::common::{DeviceArena, PtxModuleCache};
/// Decides whether the row-primary Hessian path may run on the GPU.
///
/// The workload counts as "large enough" only when a global GPU runtime
/// exists, `n` reaches the runtime policy's `row_kernel_min_n`, and `r` is
/// non-zero. That flag is combined with [`BmsFlexGpuBackend::compiled`] into
/// the eligibility handed to `decide` for [`GpuKernel::MarginalSlopeRows`].
#[must_use]
pub fn row_primary_hessian_decision(n: usize, r: usize) -> GpuDecision {
    let large_enough = match super::runtime::GpuRuntime::global() {
        Some(runtime) => n >= runtime.policy().row_kernel_min_n && r > 0,
        None => false,
    };
    let eligibility =
        super::GpuEligibility::from_flags(BmsFlexGpuBackend::compiled(), large_enough);
    decide(GpuKernel::MarginalSlopeRows, eligibility)
}
/// Computes the row-primary Hessian GPU decision for `(n, r)`, logs it, and
/// returns it only when `require_supported` accepts it.
///
/// # Errors
/// Propagates the failure reported by `require_supported` (converted into a
/// `String` by `?`).
pub fn require_row_primary_hessian_supported(n: usize, r: usize) -> Result<GpuDecision, String> {
let decision = row_primary_hessian_decision(n, r);
// `log` is called on a clone so that `decision` itself is still available
// to return below (presumably `log` consumes its receiver — confirm).
decision.clone().log();
decision.require_supported()?;
Ok(decision)
}
/// Borrowed bundle of everything the bms_flex GPU entry points need for one
/// evaluation: dimension bookkeeping, the per-row kernel buffers, and the two
/// flattened design matrices used to pull per-row results back to coefficient
/// space.
///
/// `validate` checks only the dimension identities and the lengths of `beta`,
/// `y`, `weights`, `x_design` and `g_design`; every other slice is passed
/// through to the row kernel without inspection in this file.
#[derive(Clone, Copy, Debug)]
pub struct BmsFlexGpuRowInputs<'a> {
/// Number of rows; `y` and `weights` must have this length.
pub n: usize,
/// Per-row primary dimension; must equal `2 + p_h + p_w`.
pub r: usize,
/// Coefficient-space dimension; must equal `q_dim + g_dim + p_h + p_w`.
pub p: usize,
/// Number of `x_design` entries per row.
pub q_dim: usize,
/// Number of `g_design` entries per row.
pub g_dim: usize,
/// Size of the first trailing coefficient block (see `r` and `p`).
pub p_h: usize,
/// Size of the second trailing coefficient block (see `r` and `p`).
pub p_w: usize,
/// Coefficient vector; length must equal `p`. It is length-checked here but
/// is not among the fields forwarded to the row kernel.
pub beta: &'a [f64],
/// Per-row values, length `n`; forwarded to the row kernel as `y`.
pub y: &'a [f64],
/// Per-row weights, length `n`; forwarded to the row kernel as `w`.
pub weights: &'a [f64],
// The fields from `q` through `r_uv` are forwarded unchanged to the row
// kernel by `as_row_kernel_inputs` (with `cell_moments` wrapped as a
// host-side source). Their shapes are not validated in this file.
pub q: &'a [f64],
pub b: &'a [f64],
pub mu_1: &'a [f64],
pub mu_2: &'a [f64],
pub z_obs: &'a [f64],
pub s_f: f64,
pub cell_offsets: &'a [u32],
pub cell_c0: &'a [f64],
pub cell_c1: &'a [f64],
pub cell_c2: &'a [f64],
pub cell_c3: &'a [f64],
pub cell_a: &'a [f64],
pub cell_aa: &'a [f64],
pub cell_r: &'a [f64],
pub cell_ar: &'a [f64],
pub cell_sbb: &'a [f64],
pub cell_sbh: &'a [f64],
pub cell_sbw: &'a [f64],
pub cell_moments: &'a [f64],
pub chi_obs: &'a [f64],
pub xi_obs: &'a [f64],
pub rho_u: &'a [f64],
pub tau_u: &'a [f64],
pub r_uv: &'a [f64],
/// Flattened design matrix of length `n * q_dim`; row `i` occupies
/// `x_design[i * q_dim..(i + 1) * q_dim]`.
pub x_design: &'a [f64],
/// Flattened design matrix of length `n * g_dim`; row `i` occupies
/// `g_design[i * g_dim..(i + 1) * g_dim]`.
pub g_design: &'a [f64],
}
impl<'a> BmsFlexGpuRowInputs<'a> {
    /// Checks the dimension bookkeeping and the lengths of the buffers that
    /// this module indexes itself.
    ///
    /// Verified, in order: `p == q_dim + g_dim + p_h + p_w`,
    /// `r == 2 + p_h + p_w`, `beta.len() == p`, `y.len() == n`,
    /// `weights.len() == n`, `x_design.len() == n * q_dim` and
    /// `g_design.len() == n * g_dim`. The remaining slices are not inspected.
    ///
    /// # Errors
    /// Bails through `gpu_bail!` with a message describing the first mismatch.
    fn validate(&self) -> Result<(), GpuError> {
        let Self {
            n,
            r,
            p,
            q_dim,
            g_dim,
            p_h,
            p_w,
            beta,
            y,
            weights,
            x_design,
            g_design,
            ..
        } = *self;

        let want_p = q_dim + g_dim + p_h + p_w;
        if p != want_p {
            crate::gpu_bail!(
                "bms_flex inputs: p={} != q_dim({}) + g_dim({}) + p_h({}) + p_w({}) = {}",
                p,
                q_dim,
                g_dim,
                p_h,
                p_w,
                want_p
            );
        }

        let want_r = 2 + p_h + p_w;
        if r != want_r {
            crate::gpu_bail!(
                "bms_flex inputs: r={} != 2 + p_h({}) + p_w({}) = {}",
                r,
                p_h,
                p_w,
                want_r
            );
        }

        if beta.len() != p {
            crate::gpu_bail!("bms_flex inputs: beta.len()={} != p={}", beta.len(), p);
        }
        if y.len() != n {
            crate::gpu_bail!("bms_flex inputs: y.len()={} != n={}", y.len(), n);
        }
        if weights.len() != n {
            crate::gpu_bail!(
                "bms_flex inputs: weights.len()={} != n={}",
                weights.len(),
                n
            );
        }

        let want_x = n * q_dim;
        if x_design.len() != want_x {
            crate::gpu_bail!(
                "bms_flex inputs: x_design.len()={} != n({})*q_dim({}) = {}",
                x_design.len(),
                n,
                q_dim,
                want_x
            );
        }

        let want_g = n * g_dim;
        if g_design.len() != want_g {
            crate::gpu_bail!(
                "bms_flex inputs: g_design.len()={} != n({})*g_dim({}) = {}",
                g_design.len(),
                n,
                g_dim,
                want_g
            );
        }

        Ok(())
    }

    /// Repackages the per-row fields as the row kernel's input struct.
    ///
    /// `n` becomes `n_rows`, `weights` becomes `w`, and `cell_moments` is
    /// wrapped as a host-side source. `p`, `q_dim`, `g_dim`, `beta`,
    /// `x_design` and `g_design` are not forwarded.
    fn as_row_kernel_inputs(&self) -> BmsFlexRowKernelInputs<'a> {
        use crate::gpu::bms_flex_row::CellMomentsSource;

        let Self {
            n,
            r,
            p_h,
            p_w,
            y,
            weights,
            q,
            b,
            mu_1,
            mu_2,
            z_obs,
            s_f,
            cell_offsets,
            cell_c0,
            cell_c1,
            cell_c2,
            cell_c3,
            cell_a,
            cell_aa,
            cell_r,
            cell_ar,
            cell_sbb,
            cell_sbh,
            cell_sbw,
            cell_moments,
            chi_obs,
            xi_obs,
            rho_u,
            tau_u,
            r_uv,
            ..
        } = *self;

        BmsFlexRowKernelInputs {
            n_rows: n,
            r,
            p_h,
            p_w,
            q,
            b,
            mu_1,
            mu_2,
            z_obs,
            y,
            w: weights,
            s_f,
            cell_offsets,
            cell_c0,
            cell_c1,
            cell_c2,
            cell_c3,
            cell_a,
            cell_aa,
            cell_r,
            cell_ar,
            cell_sbb,
            cell_sbh,
            cell_sbw,
            cell_moments: CellMomentsSource::Host(cell_moments),
            chi_obs,
            xi_obs,
            rho_u,
            tau_u,
            r_uv,
        }
    }
}
/// CUDA source for the argument-less, empty `bms_flex_probe` kernel.
///
/// `BmsFlexGpuBackend::compile_probe_module` compiles it and
/// `BmsFlexGpuBackend::launch_probe` launches it; it performs no computation.
#[cfg(target_os = "linux")]
const PROBE_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void bms_flex_probe() {
// Intentionally empty. This kernel exists only so the scaffolding can
// verify NVRTC compile + module load + launch + synchronize on the
// selected device. The real row math lives in the bms_flex_row module.
}
"#;
/// Handle to the bms_flex GPU backend.
///
/// On Linux it wraps the CUDA state in `inner`; on every other target the
/// struct has no fields and `probe` always reports the backend as unavailable.
#[must_use]
pub struct BmsFlexGpuBackend {
// Only present on Linux builds.
#[cfg(target_os = "linux")]
inner: BmsFlexGpuContextLinux,
}
/// Linux-only CUDA state owned by `BmsFlexGpuBackend`.
#[cfg(target_os = "linux")]
struct BmsFlexGpuContextLinux {
// Context created for the runtime's selected device ordinal.
ctx: Arc<CudaContext>,
// The context's default stream; used for the probe launch and arena allocation.
stream: Arc<CudaStream>,
// Cache holding the compiled probe module.
module: PtxModuleCache,
// Device arena behind a mutex; exercised by `arena_round_trip`.
arena: Mutex<DeviceArena>,
}
impl BmsFlexGpuBackend {
/// Returns `true` when this build targets Linux, the only target on which
/// the backend is compiled in.
pub const fn compiled() -> bool {
cfg!(target_os = "linux")
}
/// Returns the process-wide backend, initialising it on first use.
///
/// The outcome of the first initialisation (success or failure) is stored
/// in a `OnceLock` and reused by every later call.
///
/// # Errors
/// Returns a clone of the stored initialisation error. On non-Linux
/// targets this is always `GpuError::DriverLibraryUnavailable`.
pub fn probe() -> Result<&'static Self, GpuError> {
static BACKEND: OnceLock<Result<BmsFlexGpuBackend, GpuError>> = OnceLock::new();
BACKEND
.get_or_init(|| {
#[cfg(target_os = "linux")]
{
Self::probe_linux()
}
#[cfg(not(target_os = "linux"))]
{
Err(GpuError::DriverLibraryUnavailable {
reason: "bms_flex GPU backend is Linux-only".to_string(),
})
}
})
.as_ref()
.map_err(GpuError::clone)
}
/// Builds the Linux backend: looks up the global runtime, creates a CUDA
/// context for its selected device, takes that context's default stream,
/// and compiles the probe module once so compile failures surface here.
///
/// # Errors
/// Fails when no global runtime exists, when no context can be created
/// for the selected device, or when the probe module fails to compile.
#[cfg(target_os = "linux")]
fn probe_linux() -> Result<Self, GpuError> {
let runtime = super::runtime::GpuRuntime::global().ok_or_else(|| {
GpuError::DriverLibraryUnavailable {
reason: "bms_flex backend: no CUDA runtime available".to_string(),
}
})?;
let ctx = super::runtime::cuda_context_for(runtime.selected_device().ordinal).ok_or_else(
|| {
gpu_err!(
"bms_flex backend: failed to create CUDA context for device {}",
runtime.selected_device().ordinal
)
},
)?;
let stream = ctx.default_stream();
let backend = BmsFlexGpuBackend {
inner: BmsFlexGpuContextLinux {
ctx,
stream,
module: PtxModuleCache::new(),
arena: Mutex::new(DeviceArena::default()),
},
};
backend.compile_probe_module()?;
Ok(backend)
}
/// Returns the probe module from the module cache, compiling
/// `PROBE_KERNEL_SOURCE` under the name `"bms_flex"` if needed
/// (presumably only on the first call — confirm against `PtxModuleCache`).
#[cfg(target_os = "linux")]
fn compile_probe_module(&self) -> Result<&Arc<CudaModule>, GpuError> {
self.inner
.module
.get_or_compile(&self.inner.ctx, "bms_flex", PROBE_KERNEL_SOURCE)
}
/// Launches the empty `bms_flex_probe` kernel with a single thread on the
/// backend's stream and waits for the stream to synchronise.
///
/// # Errors
/// Returns an error if the module cannot be obtained, the function cannot
/// be loaded, the launch fails, or synchronisation fails.
#[cfg(target_os = "linux")]
pub fn launch_probe(&self) -> Result<(), GpuError> {
use cudarc::driver::LaunchConfig;
let module = self.compile_probe_module()?;
let func = module
.load_function("bms_flex_probe")
.gpu_ctx("bms_flex probe load_function")?;
let cfg = LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = self.inner.stream.launch_builder(&func);
// SAFETY: `bms_flex_probe` is declared with no parameters and an empty
// body, and no arguments are pushed onto the builder, so the launch
// cannot read or write any memory.
unsafe { builder.launch(cfg) }.gpu_ctx("bms_flex probe launch")?;
self.inner
.stream
.synchronize()
.gpu_ctx("bms_flex probe synchronize")?;
Ok(())
}
/// Non-Linux stand-in for `launch_probe`; always fails with
/// `GpuError::DriverLibraryUnavailable`.
#[cfg(not(target_os = "linux"))]
pub fn launch_probe(&self) -> Result<(), GpuError> {
Err(GpuError::DriverLibraryUnavailable {
reason: "bms_flex GPU backend is Linux-only".to_string(),
})
}
/// Allocates `elements` from the device arena, releases the slab straight
/// away, and returns the bucket value reported by the allocation.
///
/// Linux-only: no non-Linux counterpart exists, so callers must gate the
/// call with `#[cfg(target_os = "linux")]`.
///
/// # Errors
/// Returns an error if the arena mutex is poisoned or the allocation fails.
#[cfg(target_os = "linux")]
pub fn arena_round_trip(&self, elements: usize) -> Result<usize, GpuError> {
let mut guard = self
.inner
.arena
.lock()
.gpu_ctx("bms_flex arena mutex poisoned")?;
let (bucket, slab) = guard.alloc(&self.inner.stream, elements, "bms_flex")?;
guard.release(bucket, slab);
Ok(bucket)
}
/// Returns a one-line human-readable summary: on Linux the device name
/// (if it can be queried) and whether a module is currently cached; on
/// other targets a fixed "unavailable" message.
pub fn describe(&self) -> String {
#[cfg(target_os = "linux")]
{
return format!(
"bms_flex backend: device={:?} module_loaded={}",
self.inner.ctx.name().ok(),
self.inner.module.get().is_some()
);
}
#[cfg(not(target_os = "linux"))]
{
"bms_flex backend: unavailable (not Linux)".to_string()
}
}
}
/// Maps a coefficient-space vector `v` onto one row's `r` primary coordinates.
///
/// Output layout:
/// * `out[0]` — dot product of `x_row[..q_dim]` with `v[..q_dim]`;
/// * `out[1]` — dot product of `g_row[..g_dim]` with `v[q_dim..q_dim + g_dim]`;
/// * `out[2..2 + p_h + p_w]` — a copy of
///   `v[q_dim + g_dim..q_dim + g_dim + p_h + p_w]`.
///
/// Any entries of `out` beyond `2 + p_h + p_w` stay zero.
///
/// # Panics
/// Panics if `r < 2`, if `r < 2 + p_h + p_w`, or if `v`, `x_row` or `g_row`
/// is shorter than the dimensions require.
#[inline]
fn project_v_through_row(
    v: &[f64],
    x_row: &[f64],
    g_row: &[f64],
    q_dim: usize,
    g_dim: usize,
    p_h: usize,
    p_w: usize,
    r: usize,
) -> Vec<f64> {
    let design_end = q_dim + g_dim;
    let tail_len = p_h + p_w;

    // Left-to-right accumulation starting from +0.0, so the rounding matches
    // a plain indexed summation loop.
    let dot = |row: &[f64], coeffs: &[f64]| -> f64 {
        row.iter()
            .zip(coeffs)
            .fold(0.0f64, |acc, (a, c)| acc + a * c)
    };

    let mut out = vec![0.0f64; r];
    out[0] = dot(&x_row[..q_dim], &v[..q_dim]);
    out[1] = dot(&g_row[..g_dim], &v[q_dim..design_end]);
    // The two trailing blocks are contiguous in both `v` and `out`.
    out[2..2 + tail_len].copy_from_slice(&v[design_end..design_end + tail_len]);
    out
}
/// Adds one row's primary-space vector `t` into the coefficient-space
/// accumulator `out`.
///
/// * `out[..q_dim]` receives `t[0] * x_row[k]`;
/// * `out[q_dim..q_dim + g_dim]` receives `t[1] * g_row[k]`;
/// * `out[q_dim + g_dim..q_dim + g_dim + p_h + p_w]` receives
///   `t[2..2 + p_h + p_w]` element-wise.
///
/// # Panics
/// Panics if `t`, `x_row`, `g_row` or `out` is shorter than the dimensions
/// require.
#[inline]
fn accumulate_row_vector_pullback(
    t: &[f64],
    x_row: &[f64],
    g_row: &[f64],
    q_dim: usize,
    g_dim: usize,
    p_h: usize,
    p_w: usize,
    out: &mut [f64],
) {
    let design_end = q_dim + g_dim;
    let tail_len = p_h + p_w;

    let t0 = t[0];
    for (slot, x) in out[..q_dim].iter_mut().zip(&x_row[..q_dim]) {
        *slot += t0 * x;
    }

    let t1 = t[1];
    for (slot, g) in out[q_dim..design_end].iter_mut().zip(&g_row[..g_dim]) {
        *slot += t1 * g;
    }

    // The two trailing blocks are contiguous in both `t` and `out`.
    let tail = &t[2..2 + tail_len];
    for (slot, direct) in out[design_end..design_end + tail_len].iter_mut().zip(tail) {
        *slot += direct;
    }
}
/// Adds one row's contribution `Jᵀ · H · J` to the dense accumulator `out`.
///
/// `hess_row` is the row's `r × r` primary-space Hessian in row-major order.
/// `J` is the row's `r × p` Jacobian, which is never materialised because
/// every coefficient column `m` touches exactly one primary coordinate:
///
/// * `m < q_dim` → coordinate `0`, weight `x_row[m]`;
/// * `q_dim <= m < q_dim + g_dim` → coordinate `1`, weight `g_row[m - q_dim]`;
/// * `m >= q_dim + g_dim` → coordinate `2 + (m - q_dim - g_dim)`, weight `1`.
///
/// Hence `out[[m, n]] += H[u(m), u(n)] * w(m) * w(n)`: one multiply-add per
/// entry instead of building `r` dense length-`p` vectors per row and
/// multiplying mostly structural zeros. The product is grouped as
/// `(H * w(m)) * w(n)`, so finite results are unchanged; the only differences
/// are that a non-finite Hessian entry no longer contaminates structurally
/// unrelated entries of the same output row, and no `±0.0` terms are added.
///
/// # Panics
/// Panics if `r != 2 + p_h + p_w`, if `q_dim + g_dim + p_h + p_w > p`, if
/// `x_row`, `g_row` or `hess_row` is shorter than the dimensions require, or
/// if `out` is smaller than the touched index range.
#[inline]
fn accumulate_row_hessian_pullback(
    hess_row: &[f64],
    r: usize,
    x_row: &[f64],
    g_row: &[f64],
    q_dim: usize,
    g_dim: usize,
    p_h: usize,
    p_w: usize,
    p: usize,
    out: &mut Array2<f64>,
) {
    assert_eq!(2 + p_h + p_w, r);
    let tail_off = q_dim + g_dim;
    let active = tail_off + p_h + p_w;
    // Check every precondition up front so a violation cannot leave `out`
    // partially updated.
    assert!(
        active <= p,
        "row Hessian pullback: q_dim + g_dim + p_h + p_w = {active} exceeds p = {p}"
    );
    assert!(
        x_row.len() >= q_dim && g_row.len() >= g_dim,
        "row Hessian pullback: design row shorter than its declared width"
    );
    assert!(
        hess_row.len() >= r * r,
        "row Hessian pullback: hess_row.len() = {} is smaller than r*r = {}",
        hess_row.len(),
        r * r
    );

    // Coefficient column -> (primary coordinate, Jacobian weight).
    let column = |m: usize| -> (usize, f64) {
        if m < q_dim {
            (0, x_row[m])
        } else if m < tail_off {
            (1, g_row[m - q_dim])
        } else {
            (2 + (m - tail_off), 1.0)
        }
    };

    for m in 0..active {
        let (u, w_m) = column(m);
        if w_m == 0.0 {
            continue;
        }
        let h_u = &hess_row[u * r..(u + 1) * r];
        for n in 0..active {
            let (v, w_n) = column(n);
            let h_uv = h_u[v];
            if h_uv == 0.0 {
                continue;
            }
            out[[m, n]] += (h_uv * w_m) * w_n;
        }
    }
}
pub fn gpu_gradient(inputs: BmsFlexGpuRowInputs<'_>) -> Result<(f64, Vec<f64>), GpuError> {
inputs.validate()?;
BmsFlexGpuBackend::probe()?;
let outputs = launch_bms_flex_row_kernel(inputs.as_row_kernel_inputs())?;
let BmsFlexRowKernelOutputs { neglog, grad, .. } = outputs;
let n = inputs.n;
let r = inputs.r;
let p = inputs.p;
let q_dim = inputs.q_dim;
let g_dim = inputs.g_dim;
let p_h = inputs.p_h;
let p_w = inputs.p_w;
let mut neglog_sum = 0.0f64;
for v in &neglog {
neglog_sum += *v;
}
let mut joint_grad = vec![0.0f64; p];
for i in 0..n {
let row_grad = &grad[i * r..(i + 1) * r];
let x_row = &inputs.x_design[i * q_dim..(i + 1) * q_dim];
let g_row = &inputs.g_design[i * g_dim..(i + 1) * g_dim];
accumulate_row_vector_pullback(
row_grad,
x_row,
g_row,
q_dim,
g_dim,
p_h,
p_w,
&mut joint_grad,
);
}
Ok((neglog_sum, joint_grad))
}
pub fn gpu_hessian_matvec(
inputs: BmsFlexGpuRowInputs<'_>,
v: &[f64],
) -> Result<Vec<f64>, GpuError> {
inputs.validate()?;
if v.len() != inputs.p {
crate::gpu_bail!(
"bms_flex gpu_hessian_matvec: v.len()={} != p={}",
v.len(),
inputs.p
);
}
BmsFlexGpuBackend::probe()?;
let outputs = launch_bms_flex_row_kernel(inputs.as_row_kernel_inputs())?;
let BmsFlexRowKernelOutputs { hess, .. } = outputs;
let n = inputs.n;
let r = inputs.r;
let p = inputs.p;
let q_dim = inputs.q_dim;
let g_dim = inputs.g_dim;
let p_h = inputs.p_h;
let p_w = inputs.p_w;
let mut out = vec![0.0f64; p];
for i in 0..n {
let x_row = &inputs.x_design[i * q_dim..(i + 1) * q_dim];
let g_row = &inputs.g_design[i * g_dim..(i + 1) * g_dim];
let pv = project_v_through_row(v, x_row, g_row, q_dim, g_dim, p_h, p_w, r);
let hess_row = &hess[i * r * r..(i + 1) * r * r];
let mut t = vec![0.0f64; r];
for u in 0..r {
let row = &hess_row[u * r..(u + 1) * r];
let mut acc = 0.0f64;
for w in 0..r {
acc += row[w] * pv[w];
}
t[u] = acc;
}
accumulate_row_vector_pullback(&t, x_row, g_row, q_dim, g_dim, p_h, p_w, &mut out);
}
Ok(out)
}
pub fn gpu_hessian_dense(inputs: BmsFlexGpuRowInputs<'_>) -> Result<Array2<f64>, GpuError> {
inputs.validate()?;
BmsFlexGpuBackend::probe()?;
let outputs = launch_bms_flex_row_kernel(inputs.as_row_kernel_inputs())?;
let BmsFlexRowKernelOutputs { hess, .. } = outputs;
let n = inputs.n;
let r = inputs.r;
let p = inputs.p;
let q_dim = inputs.q_dim;
let g_dim = inputs.g_dim;
let p_h = inputs.p_h;
let p_w = inputs.p_w;
let mut out = Array2::<f64>::zeros((p, p));
for i in 0..n {
let x_row = &inputs.x_design[i * q_dim..(i + 1) * q_dim];
let g_row = &inputs.g_design[i * g_dim..(i + 1) * g_dim];
let hess_row = &hess[i * r * r..(i + 1) * r * r];
accumulate_row_hessian_pullback(
hess_row, r, x_row, g_row, q_dim, g_dim, p_h, p_w, p, &mut out,
);
}
Ok(out)
}
// Unit tests for the bms_flex GPU entry points. Tests that need a device
// accept a fixed set of error variants, so they also pass on hosts without a
// usable CUDA runtime.
#[cfg(test)]
mod bms_flex_gpu_tests {
use super::*;
// Owned backing storage for every slice field of `BmsFlexGpuRowInputs`.
struct ScratchBuffers {
beta: Vec<f64>,
y: Vec<f64>,
w: Vec<f64>,
q: Vec<f64>,
b: Vec<f64>,
mu_1: Vec<f64>,
mu_2: Vec<f64>,
z_obs: Vec<f64>,
cell_offsets: Vec<u32>,
cell_c0: Vec<f64>,
cell_c1: Vec<f64>,
cell_c2: Vec<f64>,
cell_c3: Vec<f64>,
cell_a: Vec<f64>,
cell_aa: Vec<f64>,
cell_r: Vec<f64>,
cell_ar: Vec<f64>,
cell_sbb: Vec<f64>,
cell_sbh: Vec<f64>,
cell_sbw: Vec<f64>,
cell_moments: Vec<f64>,
chi_obs: Vec<f64>,
xi_obs: Vec<f64>,
rho_u: Vec<f64>,
tau_u: Vec<f64>,
r_uv: Vec<f64>,
x_design: Vec<f64>,
g_design: Vec<f64>,
}
// Builds buffers sized for `n` rows with exactly one cell per row. All
// values are zero except `w`, `b` and `mu_1`, which are filled with 1.0.
fn zero_buffers(
n: usize,
q_dim: usize,
g_dim: usize,
p_h: usize,
p_w: usize,
) -> ScratchBuffers {
let r = 2 + p_h + p_w;
let p = q_dim + g_dim + p_h + p_w;
let cells_per_row = 1usize;
let total_cells = n * cells_per_row;
// n + 1 offsets: row i owns cells [offsets[i], offsets[i + 1]).
let mut cell_offsets = Vec::with_capacity(n + 1);
for i in 0..=n {
cell_offsets.push((i * cells_per_row) as u32);
}
let r_minus_1 = r.saturating_sub(1);
ScratchBuffers {
beta: vec![0.0; p],
y: vec![0.0; n],
w: vec![1.0; n],
q: vec![0.0; n],
b: vec![1.0; n],
mu_1: vec![1.0; n],
mu_2: vec![0.0; n],
z_obs: vec![0.0; n],
cell_offsets,
cell_c0: vec![0.0; total_cells],
cell_c1: vec![0.0; total_cells],
cell_c2: vec![0.0; total_cells],
cell_c3: vec![0.0; total_cells],
cell_a: vec![0.0; total_cells * 4],
cell_aa: vec![0.0; total_cells * 4],
cell_r: vec![0.0; total_cells * r_minus_1 * 4],
cell_ar: vec![0.0; total_cells * r_minus_1 * 4],
cell_sbb: vec![0.0; total_cells * 4],
cell_sbh: vec![0.0; total_cells * p_h * 4],
cell_sbw: vec![0.0; total_cells * p_w * 4],
cell_moments: vec![0.0; total_cells * 10],
chi_obs: vec![0.0; n],
xi_obs: vec![0.0; n],
rho_u: vec![0.0; n * r],
tau_u: vec![0.0; n * r],
r_uv: vec![0.0; n * r * r],
x_design: vec![0.0; n * q_dim],
g_design: vec![0.0; n * g_dim],
}
}
// Borrows `bufs` as a `BmsFlexGpuRowInputs` with `r` and `p` derived from
// the block sizes and `s_f` fixed at 1.0.
fn inputs_from<'a>(
bufs: &'a ScratchBuffers,
n: usize,
q_dim: usize,
g_dim: usize,
p_h: usize,
p_w: usize,
) -> BmsFlexGpuRowInputs<'a> {
let r = 2 + p_h + p_w;
let p = q_dim + g_dim + p_h + p_w;
BmsFlexGpuRowInputs {
n,
r,
p,
q_dim,
g_dim,
p_h,
p_w,
beta: &bufs.beta,
y: &bufs.y,
weights: &bufs.w,
q: &bufs.q,
b: &bufs.b,
mu_1: &bufs.mu_1,
mu_2: &bufs.mu_2,
z_obs: &bufs.z_obs,
s_f: 1.0,
cell_offsets: &bufs.cell_offsets,
cell_c0: &bufs.cell_c0,
cell_c1: &bufs.cell_c1,
cell_c2: &bufs.cell_c2,
cell_c3: &bufs.cell_c3,
cell_a: &bufs.cell_a,
cell_aa: &bufs.cell_aa,
cell_r: &bufs.cell_r,
cell_ar: &bufs.cell_ar,
cell_sbb: &bufs.cell_sbb,
cell_sbh: &bufs.cell_sbh,
cell_sbw: &bufs.cell_sbw,
cell_moments: &bufs.cell_moments,
chi_obs: &bufs.chi_obs,
xi_obs: &bufs.xi_obs,
rho_u: &bufs.rho_u,
tau_u: &bufs.tau_u,
r_uv: &bufs.r_uv,
x_design: &bufs.x_design,
g_design: &bufs.g_design,
}
}
// The decision always names the MarginalSlopeRows kernel, whatever the
// eligibility outcome on this host.
#[test]
fn bms_flex_gpu_policy_decision_is_explicit() {
let decision = row_primary_hessian_decision(50_000, 4);
assert_eq!(decision.kernel, GpuKernel::MarginalSlopeRows);
}
// `gpu_gradient` must either fail with one of the listed error variants or
// succeed with a length-2 gradient and a `neglog` that is not infinite.
#[test]
fn bms_flex_gpu_gradient_routes_through_row_kernel_or_clean_error() {
let n = 4;
let bufs = zero_buffers(n, 1, 1, 0, 0);
let inputs = inputs_from(&bufs, n, 1, 1, 0, 0);
match gpu_gradient(inputs) {
Err(GpuError::DriverLibraryUnavailable { .. })
| Err(GpuError::DriverCallFailed { .. })
| Err(GpuError::DriverSymbolMissing { .. })
| Err(GpuError::NotYetImplemented { .. }) => {}
Err(other) => panic!("unexpected GpuError variant: {other:?}"),
Ok((neglog, grad)) => {
assert!(neglog.is_finite() || neglog.is_nan(), "neglog: {neglog}");
assert_eq!(grad.len(), 2);
}
}
}
// The `v.len()` check runs before the backend probe, so this test does not
// need a device.
#[test]
fn bms_flex_gpu_hessian_matvec_rejects_wrong_v_length() {
let n = 4;
let bufs = zero_buffers(n, 1, 1, 0, 0);
let inputs = inputs_from(&bufs, n, 1, 1, 0, 0);
let v_wrong = vec![0.0; inputs.p + 1];
match gpu_hessian_matvec(inputs, &v_wrong) {
Err(GpuError::DriverCallFailed { reason }) => {
assert!(
reason.contains("v.len()"),
"expected v.len() mismatch message, got: {reason}"
);
}
other => panic!("expected v.len() mismatch, got {other:?}"),
}
}
// `gpu_hessian_dense` must either fail with one of the listed error
// variants or succeed with a 2 x 2 matrix.
#[test]
fn bms_flex_gpu_hessian_dense_routes_through_row_kernel_or_clean_error() {
let n = 4;
let bufs = zero_buffers(n, 1, 1, 0, 0);
let inputs = inputs_from(&bufs, n, 1, 1, 0, 0);
match gpu_hessian_dense(inputs) {
Err(GpuError::DriverLibraryUnavailable { .. })
| Err(GpuError::DriverCallFailed { .. })
| Err(GpuError::DriverSymbolMissing { .. })
| Err(GpuError::NotYetImplemented { .. }) => {}
Err(other) => panic!("unexpected GpuError variant: {other:?}"),
Ok(h) => {
let p = 2usize;
assert_eq!(h.shape(), &[p, p]);
}
}
}
// A `beta` that is one element too long must be rejected by `validate`.
#[test]
fn bms_flex_gpu_inputs_validate_catches_shape_mismatches() {
let n = 4;
let mut bufs = zero_buffers(n, 1, 1, 0, 0);
bufs.beta.push(0.0); let bad = BmsFlexGpuRowInputs {
beta: &bufs.beta,
..inputs_from(&bufs, n, 1, 1, 0, 0)
};
let err = bad.validate().expect_err("beta length mismatch must fail");
assert!(
matches!(err, GpuError::DriverCallFailed { .. }),
"expected DriverCallFailed, got {err:?}"
);
}
// Host-only parity check of `accumulate_row_hessian_pullback` against a
// hand-written block-by-block reference, followed by a smoke run of
// `gpu_hessian_dense` on the same design matrices.
#[test]
fn bms_flex_gpu_hessian_dense_pullback_matches_cpu_reference() {
let n = 8;
let q_dim = 2;
let g_dim = 2;
let p_h = 0;
let p_w = 0;
let r = 2 + p_h + p_w;
let p = q_dim + g_dim + p_h + p_w;
let mut x_design = Vec::with_capacity(n * q_dim);
let mut g_design = Vec::with_capacity(n * g_dim);
let mut hess_flat = Vec::with_capacity(n * r * r);
// Synthetic per-row data: design rows (1, f) and (1, cos f), plus a
// symmetric 2 x 2 Hessian block that varies with the row.
for i in 0..n {
let f = (i as f64) + 1.0;
x_design.push(1.0);
x_design.push(f);
g_design.push(1.0);
g_design.push(f.cos());
let h00 = 2.0 + f;
let h01 = 0.1 * f;
let h11 = 3.0 + 0.5 * f;
hess_flat.push(h00);
hess_flat.push(h01);
hess_flat.push(h01);
hess_flat.push(h11);
}
// Reference: accumulate the four coefficient-space blocks directly.
let mut h_cpu = Array2::<f64>::zeros((p, p));
for i in 0..n {
let x_row = &x_design[i * q_dim..(i + 1) * q_dim];
let g_row = &g_design[i * g_dim..(i + 1) * g_dim];
let h00 = hess_flat[i * r * r];
let h01 = hess_flat[i * r * r + 1];
let h10 = hess_flat[i * r * r + 2];
let h11 = hess_flat[i * r * r + 3];
for a in 0..q_dim {
for b in 0..q_dim {
h_cpu[[a, b]] += h00 * x_row[a] * x_row[b];
}
}
for a in 0..g_dim {
for b in 0..g_dim {
h_cpu[[q_dim + a, q_dim + b]] += h11 * g_row[a] * g_row[b];
}
}
for a in 0..q_dim {
for b in 0..g_dim {
h_cpu[[a, q_dim + b]] += h01 * x_row[a] * g_row[b];
}
}
for a in 0..g_dim {
for b in 0..q_dim {
h_cpu[[q_dim + a, b]] += h10 * g_row[a] * x_row[b];
}
}
}
// Same accumulation through the helper under test.
let mut h_via_helper = Array2::<f64>::zeros((p, p));
for i in 0..n {
let x_row = &x_design[i * q_dim..(i + 1) * q_dim];
let g_row = &g_design[i * g_dim..(i + 1) * g_dim];
let hess_row = &hess_flat[i * r * r..(i + 1) * r * r];
accumulate_row_hessian_pullback(
hess_row,
r,
x_row,
g_row,
q_dim,
g_dim,
p_h,
p_w,
p,
&mut h_via_helper,
);
}
// Entry-wise comparison with a 1e-12 tolerance scaled by
// max(|a|, |b|, 1).
for m in 0..p {
for nn in 0..p {
let a = h_cpu[[m, nn]];
let b = h_via_helper[[m, nn]];
let diff = (a - b).abs();
assert!(
diff <= 1e-12 * a.abs().max(b.abs()).max(1.0),
"pullback parity mismatch at ({m},{nn}): cpu={a} helper={b} diff={diff}"
);
}
}
// Smoke run of the full entry point; only the output shape is checked.
let bufs = {
let mut b = zero_buffers(n, q_dim, g_dim, p_h, p_w);
b.x_design = x_design.clone();
b.g_design = g_design.clone();
b
};
let inputs = inputs_from(&bufs, n, q_dim, g_dim, p_h, p_w);
match gpu_hessian_dense(inputs) {
Err(GpuError::DriverLibraryUnavailable { .. })
| Err(GpuError::DriverCallFailed { .. })
| Err(GpuError::DriverSymbolMissing { .. })
| Err(GpuError::NotYetImplemented { .. }) => {}
Err(other) => panic!("unexpected GpuError variant: {other:?}"),
Ok(h) => {
assert_eq!(h.shape(), &[p, p]);
}
}
}
// Device smoke test: returns early when no global runtime exists; otherwise
// the probe, the probe launch and (on Linux) the arena round trip must all
// succeed.
#[test]
fn bms_flex_gpu_context_initialises_when_device_present() {
let Some(runtime) = super::super::runtime::GpuRuntime::global() else {
eprintln!("[bms_flex_gpu test] no CUDA runtime — skipping device-side init smoketest");
return;
};
eprintln!(
"[bms_flex_gpu test] runtime selected device ordinal={}",
runtime.selected_device().ordinal
);
let backend = BmsFlexGpuBackend::probe().unwrap_or_else(|err| {
panic!("BmsFlexGpuBackend::probe failed on a host that reports a CUDA runtime: {err}")
});
eprintln!("[bms_flex_gpu test] {}", backend.describe());
backend
.launch_probe()
.expect("probe kernel must launch+sync on a host with a usable device");
// `arena_round_trip` exists only on Linux builds.
#[cfg(target_os = "linux")]
{
let bucket = backend
.arena_round_trip(1024)
.expect("arena round-trip must succeed on a host with a usable device");
assert!(bucket >= 1024, "bucket must be >= requested elements");
let bucket2 = backend
.arena_round_trip(1024)
.expect("arena round-trip must succeed on a host with a usable device");
assert_eq!(bucket, bucket2, "bucket size must be stable for same input");
}
}
}