use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::error::GpuError;
use super::pirls_row::{CurvatureMode, PirlsRowFamily};
/// Owned per-sigma-point inputs consumed by `try_gpu_sigma_stream_pool_eval`.
///
/// One value is supplied per sigma point; the stream-pool evaluator requires
/// `s_transformed` and `qs` to be `p x p`, where `p` is the column count of
/// the design matrix passed alongside.
pub struct SigmaPointGpuInput {
/// `p x p` matrix forwarded to the GPU solvers (presumably the penalty in the
/// transformed basis — TODO confirm against the solver signature).
pub s_transformed: Array2<f64>,
/// `p x p` basis-change matrix: solver output is mapped back to original
/// coordinates as `qs · beta` and `qs · H · qsᵀ`.
pub qs: Array2<f64>,
/// Linear shift forwarded unchanged to the GPU solvers. Its length is not
/// validated by the code in this file (presumably `p` — TODO confirm).
pub linear_shift: Array1<f64>,
/// Scalar shift forwarded unchanged to the PIRLS loop; the Gaussian
/// closed-form path does not read it.
pub constant_shift: f64,
}
/// Largest number of workspace slots the sigma stream pool will allocate.
#[cfg(target_os = "linux")]
const STREAM_POOL_MAX: usize = 8;

/// Number of pool slots to allocate for `m` sigma points.
///
/// The result is `m` restricted to the inclusive range
/// `1..=STREAM_POOL_MAX`, so at least one slot is always requested.
#[cfg(target_os = "linux")]
#[inline]
fn pool_size(m: usize) -> usize {
    m.clamp(1, STREAM_POOL_MAX)
}
/// Tries to evaluate every sigma point on the GPU stream pool.
///
/// # Returns
/// * `Ok(Some(results))` — one entry per element of `per_sigma`, in input
///   order. Each `Some((cov, beta))` holds the inverse of the penalised
///   Hessian and the coefficients, both mapped back through that point's
///   `qs`; an entry is `None` when the GPU PIRLS loop failed for that point.
///   An empty `per_sigma` yields `Ok(Some(vec![]))` on every target.
/// * `Ok(None)` — the GPU path declined: non-Linux target, no global GPU
///   runtime, or `admission.family` is `None`.
///
/// # Errors
/// Returns an error when the admitted family has no row-kernel counterpart,
/// or when the Linux implementation reports a failure (shape mismatch,
/// upload/allocation failure, or a Hessian that cannot be inverted).
///
/// `gamma_shape`, `convergence_tol` and `max_iter` are forwarded to the
/// iterative PIRLS path; the Gaussian-identity path does not receive them.
pub fn try_gpu_sigma_stream_pool_eval(
    x_original: ndarray::ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    prior_w: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    per_sigma: &[SigmaPointGpuInput],
    admission: crate::gpu::policy::PirlsLoopAdmission,
    gamma_shape: f64,
    convergence_tol: f64,
    max_iter: usize,
) -> Result<Option<Vec<Option<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>, GpuError> {
    // No sigma points: nothing to dispatch, on any platform.
    if per_sigma.is_empty() {
        return Ok(Some(Vec::new()));
    }

    #[cfg(target_os = "linux")]
    {
        // The runtime probe comes first so it runs even when no family was admitted.
        if crate::gpu::runtime::GpuRuntime::global().is_none() {
            return Ok(None);
        }
        let family_kind = match admission.family {
            Some(kind) => kind,
            None => return Ok(None),
        };
        let family = linux_impl::family_kind_to_row(family_kind).ok_or_else(|| {
            crate::gpu_err!("sigma stream pool: family not in JIT-cached set")
        })?;
        let curvature = linux_impl::curvature_kind_to_row(admission.curvature);

        linux_impl::stream_pool_eval(
            x_original,
            y,
            prior_w,
            offset,
            per_sigma,
            family,
            curvature,
            gamma_shape,
            convergence_tol,
            max_iter,
        )
    }

    #[cfg(not(target_os = "linux"))]
    {
        // Every argument is mentioned in the trace line, so none of them is
        // reported as unused on targets without a GPU backend.
        log::trace!(
            "[sigma stream pool] non-Linux target: skipping dispatch \
             (x_original={}x{}, y_len={}, prior_w_len={}, offset_len={}, \
             n_sigma={}, family={:?}, curvature={:?}, gamma_shape={}, \
             tol={}, max_iter={})",
            x_original.nrows(),
            x_original.ncols(),
            y.len(),
            prior_w.len(),
            offset.len(),
            per_sigma.len(),
            admission.family,
            admission.curvature,
            gamma_shape,
            convergence_tol,
            max_iter,
        );
        Ok(None)
    }
}
/// Linux-only implementation of the sigma stream-pool evaluation.
#[cfg(target_os = "linux")]
mod linux_impl {
    use crate::gpu::pirls_row::{CurvatureMode, PirlsRowFamily};
    use crate::gpu::policy::{PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
    use crate::gpu::sigma_cubature::SigmaPointGpuInput;
    use crate::linalg::utils::matrix_inversewith_regularization;
    use ndarray::{Array1, Array2, ArrayView1};

    /// Per-point outcome: `(covariance, beta)` in original coordinates, or
    /// `None` when the point's PIRLS loop failed.
    type SigmaPointResult = Option<(Array2<f64>, Array1<f64>)>;

    /// Maps an admission-policy family onto the row-kernel family enum.
    ///
    /// Every variant currently has a counterpart, so this always returns
    /// `Some`; the `Option` lets callers handle families without one.
    pub(super) fn family_kind_to_row(f: PirlsLoopFamilyKind) -> Option<PirlsRowFamily> {
        match f {
            PirlsLoopFamilyKind::BernoulliLogit => Some(PirlsRowFamily::BernoulliLogit),
            PirlsLoopFamilyKind::BernoulliProbit => Some(PirlsRowFamily::BernoulliProbit),
            PirlsLoopFamilyKind::BernoulliCLogLog => Some(PirlsRowFamily::BernoulliCLogLog),
            PirlsLoopFamilyKind::PoissonLog => Some(PirlsRowFamily::PoissonLog),
            PirlsLoopFamilyKind::GaussianIdentity => Some(PirlsRowFamily::GaussianIdentity),
            PirlsLoopFamilyKind::GammaLog => Some(PirlsRowFamily::GammaLog),
        }
    }

    /// Maps an admission-policy curvature kind onto the row-kernel curvature mode.
    pub(super) fn curvature_kind_to_row(c: PirlsLoopCurvatureKind) -> CurvatureMode {
        match c {
            PirlsLoopCurvatureKind::Fisher => CurvatureMode::Fisher,
            PirlsLoopCurvatureKind::Observed => CurvatureMode::Observed,
        }
    }

    /// Computes `qs · h_transformed · qsᵀ` and symmetrises the product in place.
    fn hessian_to_original(h_transformed: &Array2<f64>, qs: &Array2<f64>) -> Array2<f64> {
        let tmp = qs.dot(h_transformed);
        let mut h_orig = tmp.dot(&qs.t());
        crate::families::custom_family::symmetrize_dense_in_place(&mut h_orig);
        h_orig
    }

    /// Evaluates every sigma point, reusing a bounded pool of GPU workspaces.
    ///
    /// Points are processed in input order; point `idx` uses workspace slot
    /// `idx % n_streams`. The Gaussian-identity family is routed to
    /// [`gaussian_sigma_pool_eval`] instead of the iterative loop.
    ///
    /// # Errors
    /// * any point whose `s_transformed` or `qs` is not `p x p`
    ///   (`p = x_original.ncols()`);
    /// * failure to upload the shared data, allocate a workspace, or upload
    ///   a point's `qs`;
    /// * a penalised Hessian whose regularised inverse is not available.
    ///
    /// A failing PIRLS loop is *not* an error: it is logged and recorded as
    /// `None` for that point.
    pub(super) fn stream_pool_eval(
        x_original: ndarray::ArrayView2<'_, f64>,
        y: ArrayView1<'_, f64>,
        prior_w: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        per_sigma: &[SigmaPointGpuInput],
        family: PirlsRowFamily,
        curvature: CurvatureMode,
        gamma_shape: f64,
        convergence_tol: f64,
        max_iter: usize,
    ) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
        use crate::gpu::sigma_cubature::pool_size;
        use crate::solver::gpu::pirls_gpu;

        let m = per_sigma.len();
        let p = x_original.ncols();

        // Every point is validated against the design-matrix width, so the
        // message reports that expectation and the offending shapes.
        for (idx, pt) in per_sigma.iter().enumerate() {
            if pt.s_transformed.shape() != [p, p] || pt.qs.shape() != [p, p] {
                return Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] shape mismatch: expected [{p}, {p}] \
                     (x_original columns), got s_transformed={:?}, qs={:?}",
                    pt.s_transformed.shape(),
                    pt.qs.shape()
                ));
            }
        }

        if family == PirlsRowFamily::GaussianIdentity {
            return gaussian_sigma_pool_eval(x_original, y, prior_w, offset, per_sigma, p);
        }

        // Data shared by all sigma points is uploaded once.
        let bootstrap_shared =
            pirls_gpu::upload_shared_pirls_gpu(x_original, y, prior_w, offset)
                .map_err(|e| crate::gpu_err!("sigma stream pool bootstrap upload: {e}"))?;

        let n_streams = pool_size(m);
        let mut workspace_pairs: Vec<(
            crate::solver::gpu::pirls_gpu::SigmaPirlsGpuWorkspace,
            crate::solver::gpu::pirls_gpu::cuda::PirlsLoopWorkspace,
        )> = Vec::with_capacity(n_streams);
        for _ in 0..n_streams {
            let ws = pirls_gpu::allocate_sigma_pirls_workspace(&bootstrap_shared)
                .map_err(|e| crate::gpu_err!("sigma stream pool alloc workspace: {e}"))?;
            let loop_ws = pirls_gpu::allocate_pirls_loop_workspace(&bootstrap_shared, &ws)
                .map_err(|e| crate::gpu_err!("sigma stream pool alloc loop_ws: {e}"))?;
            workspace_pairs.push((ws, loop_ws));
        }

        // Every point starts its PIRLS loop from the zero vector.
        let beta0: Array1<f64> = Array1::zeros(p);
        let mut outcomes: Vec<SigmaPointResult> = Vec::with_capacity(m);
        for (idx, pt) in per_sigma.iter().enumerate() {
            let stream_idx = idx % n_streams;
            let (ws, loop_ws) = &mut workspace_pairs[stream_idx];
            pirls_gpu::upload_qs_pirls(ws, pt.qs.view())
                .map_err(|e| crate::gpu_err!("sigma stream pool upload Qs pt[{idx}]: {e}"))?;
            // NOTE(review): `1e-6`, `0.0` and the trailing `None` are fixed
            // here; their meaning is defined by `pirls_loop_on_stream`'s
            // signature, which is not visible in this file — confirm.
            let outcome = pirls_gpu::pirls_loop_on_stream(
                &bootstrap_shared,
                ws,
                loop_ws,
                family,
                curvature,
                gamma_shape,
                beta0.view(),
                pt.s_transformed.view(),
                pt.linear_shift.view(),
                pt.constant_shift,
                1e-6,
                0.0,
                max_iter,
                convergence_tol,
                None,
            );
            let sigma_result = match outcome {
                Ok(loop_out) => {
                    let h_orig = hessian_to_original(&loop_out.penalized_hessian, &pt.qs);
                    let cov = matrix_inversewith_regularization(&h_orig, "gpu sigma point")
                        .ok_or_else(|| {
                            crate::gpu_err!(
                                "gpu sigma point: penalised Hessian inverse not well-defined"
                            )
                        })?;
                    let beta_orig = pt.qs.dot(&loop_out.beta);
                    Some((cov, beta_orig))
                }
                Err(e) => {
                    log::warn!(
                        "[sigma-cubature gpu] point[{idx}] pirls_loop_on_stream failed: {e}"
                    );
                    None
                }
            };
            outcomes.push(sigma_result);
        }
        Ok(Some(outcomes))
    }

    /// Gaussian-identity path: one penalised least-squares solve per point.
    ///
    /// `XᵀWX` is computed once on the GPU and `XᵀW(y − offset)` once on the
    /// host; both are reused for every sigma point. Unlike the iterative
    /// path, any per-point failure here aborts the whole batch with an error,
    /// so every returned entry is `Some`.
    fn gaussian_sigma_pool_eval(
        x_original: ndarray::ArrayView2<'_, f64>,
        y: ArrayView1<'_, f64>,
        prior_w: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        per_sigma: &[SigmaPointGpuInput],
        p: usize,
    ) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
        let xtwx = crate::solver::gpu::pirls_gpu::weighted_crossprod_gpu(x_original, prior_w)
            .map_err(|e| crate::gpu_err!("gaussian sigma: XᵀWX gpu failed: {e}"))?;

        // yw = prior_w ⊙ (y − offset)
        let mut yw = y.to_owned();
        yw -= &offset;
        yw *= &prior_w;
        let xtwy: Array1<f64> = x_original.t().dot(&yw);

        let prior_mean_zero: Array1<f64> = Array1::zeros(p);
        let mut outcomes: Vec<SigmaPointResult> = Vec::with_capacity(per_sigma.len());
        for (idx, pt) in per_sigma.iter().enumerate() {
            let pls = crate::solver::gpu::pirls_gpu::solve_gaussian_pls_gpu(
                xtwx.view(),
                xtwy.view(),
                pt.s_transformed.view(),
                pt.linear_shift.view(),
                prior_mean_zero.view(),
                0.0,
                Some(pt.qs.view()),
            )
            .map_err(|e| crate::gpu_err!("gaussian sigma pool: point[{idx}] pls failed: {e}"))?;
            let h_orig = hessian_to_original(&pls.penalized_hessian, &pt.qs);
            let cov = matrix_inversewith_regularization(&h_orig, "gaussian sigma point")
                .ok_or_else(|| {
                    crate::gpu_err!(
                        "gaussian sigma point: penalised Hessian inverse not well-defined"
                    )
                })?;
            let beta_orig = pt.qs.dot(&pls.beta);
            outcomes.push(Some((cov, beta_orig)));
        }
        Ok(Some(outcomes))
    }
}
/// Borrowed data for a single sigma point of a `SigmaCubatureBatch`.
///
/// `SigmaCubatureBatch::check_shape` requires `eta`, `y` and `prior_w` to
/// share one length `n`, `beta` to have length `p`, and `hessian_inv` to be
/// `p x p`.
pub struct SigmaPointInput<'a> {
/// Length-`n` vector (presumably the linear predictor — TODO confirm).
pub eta: ArrayView1<'a, f64>,
/// Length-`n` vector; its length defines `n` for the whole batch.
pub y: ArrayView1<'a, f64>,
/// Length-`n` vector (presumably prior observation weights — TODO confirm).
pub prior_w: ArrayView1<'a, f64>,
/// Length-`p` vector; its length defines `p` for the whole batch.
pub beta: ArrayView1<'a, f64>,
/// `p x p` matrix, copied out together with `beta` by the device-eval functions.
pub hessian_inv: ArrayView2<'a, f64>,
}
/// A batch of sigma points that share one family and curvature mode.
pub struct SigmaCubatureBatch<'a> {
/// Row-kernel family for the batch. NOTE(review): not read by the dispatch
/// functions visible in this part of the file.
pub family: PirlsRowFamily,
/// Curvature mode for the batch. NOTE(review): not read by the dispatch
/// functions visible in this part of the file.
pub curvature: CurvatureMode,
/// The sigma points; all must agree on `n` and `p` (see `check_shape`).
pub points: &'a [SigmaPointInput<'a>],
}
impl<'a> SigmaCubatureBatch<'a> {
#[inline]
pub fn m(&self) -> usize {
self.points.len()
}
pub fn check_shape(&self) -> Result<(usize, usize), GpuError> {
let first = self
.points
.first()
.ok_or_else(|| crate::gpu_err!("sigma_cubature batch is empty"))?;
let n = first.y.len();
let p = first.beta.len();
if first.eta.len() != n || first.prior_w.len() != n || first.hessian_inv.shape() != [p, p] {
return Err(crate::gpu_err!(
"sigma_cubature batch[0] shape mismatch: n={}, p={}, eta={}, prior_w={}, hessian_inv={:?}",
n,
p,
first.eta.len(),
first.prior_w.len(),
first.hessian_inv.shape()
));
}
for (idx, point) in self.points.iter().enumerate().skip(1) {
if point.eta.len() != n
|| point.y.len() != n
|| point.prior_w.len() != n
|| point.beta.len() != p
|| point.hessian_inv.shape() != [p, p]
{
return Err(crate::gpu_err!(
"sigma_cubature batch[{idx}] shape mismatch against batch[0] n={n}, p={p}"
));
}
}
Ok((n, p))
}
}
/// Per-sigma-point pair `(A, b)`: a `p x p` matrix and a length-`p` vector.
/// The device-eval functions fill it with copies of `(hessian_inv, beta)`.
pub type DeviceSigmaPoint = (Array2<f64>, Array1<f64>);
/// Produces the per-point `(hessian_inv, beta)` pairs for a validated batch
/// when the GPU runtime is available.
///
/// # Returns
/// * `Ok(Some(points))` — owned copies of each point's `hessian_inv` and
///   `beta`, in batch order (Linux with an available runtime).
/// * `Ok(None)` — non-Linux target, or the GPU runtime is unavailable.
///
/// # Errors
/// Fails when the batch is empty (callers are expected to filter empty
/// batches beforehand) or when `check_shape` rejects it.
pub fn try_device_sigma_eval(
    batch: &SigmaCubatureBatch<'_>,
) -> Result<Option<Vec<DeviceSigmaPoint>>, GpuError> {
    // This guard has to run before `check_shape()`: that method also rejects
    // empty batches, which would otherwise make this function-specific
    // diagnostic unreachable.
    if batch.m() == 0 {
        return Err(crate::gpu_err!(
            "try_device_sigma_eval: empty sigma batch (caller must filter)"
        ));
    }
    batch.check_shape()?;

    #[cfg(target_os = "linux")]
    {
        if !super::runtime::GpuRuntime::is_available() {
            return Ok(None);
        }
        let points: Vec<DeviceSigmaPoint> = batch
            .points
            .iter()
            .map(|pt| (pt.hessian_inv.to_owned(), pt.beta.to_owned()))
            .collect();
        Ok(Some(points))
    }
    #[cfg(not(target_os = "linux"))]
    {
        Ok(None)
    }
}
/// Reduces a set of `(A, b)` sigma-point pairs to a single `p x p` matrix on
/// the GPU, when one is available.
///
/// # Returns
/// * `Ok(Some(matrix))` — the reduction computed by the Linux backend.
/// * `Ok(None)` — non-Linux target, or the GPU runtime is unavailable.
///
/// # Errors
/// Fails when `points` is empty, when any `A` is not `p x p`, when any `b`
/// does not have length `p`, or when the GPU backend reports a failure.
/// Input validation happens before the availability check, so malformed
/// input is rejected on every platform.
pub fn try_device_moment_reduce(
    points: &[DeviceSigmaPoint],
    p: usize,
) -> Result<Option<Array2<f64>>, GpuError> {
    if points.is_empty() {
        return Err(crate::gpu_err!(
            "try_device_moment_reduce: empty points (caller must guard)"
        ));
    }

    // Report the first offending point; within a point, the matrix is
    // checked before the vector.
    points.iter().enumerate().try_for_each(
        |(idx, (matrix, vector))| -> Result<(), GpuError> {
            if matrix.shape() != [p, p] {
                return Err(crate::gpu_err!(
                    "try_device_moment_reduce: A[{idx}] shape {:?} != [{p}, {p}]",
                    matrix.shape()
                ));
            }
            if vector.len() != p {
                return Err(crate::gpu_err!(
                    "try_device_moment_reduce: b[{idx}] len {} != {p}",
                    vector.len()
                ));
            }
            Ok(())
        },
    )?;

    #[cfg(target_os = "linux")]
    {
        if !super::runtime::GpuRuntime::is_available() {
            return Ok(None);
        }
        linux::moment_reduce_linux(points, p).map(Some)
    }
    #[cfg(not(target_os = "linux"))]
    {
        Ok(None)
    }
}
/// Batched variant of the device sigma evaluation, gated on batch size.
///
/// # Returns
/// * `Ok(Some(points))` — owned copies of each point's `hessian_inv` and
///   `beta`, in batch order.
/// * `Ok(None)` — the batch has fewer than `BATCHED_DISPATCH_MIN_M` points,
///   `p` exceeds `BATCHED_DISPATCH_MAX_P`, the target is not Linux, or the
///   GPU runtime is unavailable.
///
/// # Errors
/// Propagates the error from `check_shape`, which also rejects an empty
/// batch before the size thresholds are consulted.
pub fn try_device_sigma_eval_batched(
    batch: &SigmaCubatureBatch<'_>,
) -> Result<Option<Vec<DeviceSigmaPoint>>, GpuError> {
    let (_, p) = batch.check_shape()?;

    let within_batched_range =
        batch.m() >= BATCHED_DISPATCH_MIN_M && p <= BATCHED_DISPATCH_MAX_P;
    if !within_batched_range {
        return Ok(None);
    }

    #[cfg(target_os = "linux")]
    {
        if !super::runtime::GpuRuntime::is_available() {
            return Ok(None);
        }
        let mut copied: Vec<DeviceSigmaPoint> = Vec::with_capacity(batch.m());
        for point in batch.points.iter() {
            copied.push((point.hessian_inv.to_owned(), point.beta.to_owned()));
        }
        Ok(Some(copied))
    }
    #[cfg(not(target_os = "linux"))]
    {
        Ok(None)
    }
}
/// Minimum number of sigma points for batched dispatch; smaller batches are
/// declined by `try_device_sigma_eval_batched`.
pub const BATCHED_DISPATCH_MIN_M: usize = 6;
/// Maximum coefficient dimension `p` for batched dispatch; larger problems
/// are declined by `try_device_sigma_eval_batched`.
pub const BATCHED_DISPATCH_MAX_P: usize = 96;
/// Linux-only CUDA implementation of the sigma-point moment reduction.
#[cfg(target_os = "linux")]
mod linux {
    use super::DeviceSigmaPoint;
    use crate::gpu::common::PtxModuleCache;
    use crate::gpu::error::{GpuError, GpuResultExt};
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    use ndarray::Array2;

    /// Cache for the compiled reduction module.
    static MOMENT_REDUCE_PTX: PtxModuleCache = PtxModuleCache::new();

    /// CUDA source for the three reductions over `M` sigma points:
    /// the mean of the `p x p` matrices, the mean of the length-`p` vectors,
    /// and the mean of the vectors' outer products. All inputs are flat,
    /// row-major, one point after another.
    const MOMENT_REDUCE_SRC: &str = r#"
extern "C" __global__ void sigma_mean_hinv(int M, int p, const double* __restrict__ A_in, double* __restrict__ out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = p * p;
if (idx >= total) return;
double acc = 0.0;
long long stride = (long long)p * (long long)p;
for (int m = 0; m < M; ++m) {
acc += A_in[(long long)m * stride + (long long)idx];
}
out[idx] = acc / (double)M;
}
extern "C" __global__ void sigma_mean_beta(int M, int p, const double* __restrict__ b_in, double* __restrict__ out) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= p) return;
double acc = 0.0;
for (int m = 0; m < M; ++m) {
acc += b_in[(long long)m * (long long)p + (long long)i];
}
out[i] = acc / (double)M;
}
extern "C" __global__ void sigma_second_beta(int M, int p, const double* __restrict__ b_in, double* __restrict__ out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = p * p;
if (idx >= total) return;
int i = idx / p;
int j = idx - i * p;
double acc = 0.0;
for (int m = 0; m < M; ++m) {
double bi = b_in[(long long)m * (long long)p + (long long)i];
double bj = b_in[(long long)m * (long long)p + (long long)j];
acc += bi * bj;
}
out[idx] = acc / (double)M;
}
"#;

    /// Computes `mean(A_m) + (mean(b_m b_mᵀ) − mean(b_m) mean(b_m)ᵀ)` over
    /// all `points`, with the three means evaluated on the GPU.
    ///
    /// `points` is expected to be non-empty with every `A_m` of shape
    /// `p x p` and every `b_m` of length `p`; the flattened sizes are
    /// re-checked here because the kernels index the buffers unchecked.
    ///
    /// # Errors
    /// Fails when the runtime or CUDA context is unavailable, when `M`, `p`
    /// or `p*p` exceed what the kernels' `int` arithmetic can represent,
    /// when the inputs do not flatten to the expected sizes, or when any
    /// driver call (compile, copy, allocation, launch, synchronise) fails.
    pub(super) fn moment_reduce_linux(
        points: &[DeviceSigmaPoint],
        p: usize,
    ) -> Result<Array2<f64>, GpuError> {
        let runtime = crate::gpu::runtime::GpuRuntime::global().ok_or_else(|| {
            crate::gpu_err!("try_device_moment_reduce: GpuRuntime unavailable after probe accepted")
        })?;
        let ctx = crate::gpu::runtime::cuda_context_for(runtime.selected_device().ordinal)
            .ok_or_else(|| {
                crate::gpu_err!(
                    "try_device_moment_reduce: CUDA context for ordinal {} unavailable",
                    runtime.selected_device().ordinal
                )
            })?;
        ctx.bind_to_thread()
            .gpu_ctx("try_device_moment_reduce bind_to_thread")?;
        let stream = ctx.default_stream();
        let m = points.len();
        let module = MOMENT_REDUCE_PTX.get_or_compile(
            &ctx,
            "sigma_cubature_moment_reduce",
            MOMENT_REDUCE_SRC,
        )?;

        // Validate all sizes before touching device memory. The kernels
        // evaluate `int total = p * p` in signed 32-bit arithmetic, so `p*p`
        // must fit in `i32`, not merely in `u32`.
        let p2 = p
            .checked_mul(p)
            .ok_or_else(|| crate::gpu_err!("sigma_cubature p*p overflows usize (p={p})"))?;
        let m_i32 =
            i32::try_from(m).map_err(|_| crate::gpu_err!("sigma_cubature M={m} overflows i32"))?;
        let p_i32 =
            i32::try_from(p).map_err(|_| crate::gpu_err!("sigma_cubature p={p} overflows i32"))?;
        let _p2_i32 = i32::try_from(p2)
            .map_err(|_| crate::gpu_err!("sigma_cubature p*p={p2} overflows i32"))?;

        // Flatten the inputs point by point in logical (row-major) order.
        // `as_slice` only succeeds for standard layout; other layouts fall
        // back to element iteration, which yields the same logical order.
        let a_len = m * p2;
        let b_len = m * p;
        let mut a_flat: Vec<f64> = Vec::with_capacity(a_len);
        let mut b_flat: Vec<f64> = Vec::with_capacity(b_len);
        for (a, b) in points {
            match a.as_slice() {
                Some(contiguous) => a_flat.extend_from_slice(contiguous),
                None => a_flat.extend(a.iter().copied()),
            }
            match b.as_slice() {
                Some(contiguous) => b_flat.extend_from_slice(contiguous),
                None => b_flat.extend(b.iter().copied()),
            }
        }
        // The kernels read `M * p * p` and `M * p` elements without bounds
        // checks, so the buffers must have exactly those lengths.
        if a_flat.len() != a_len || b_flat.len() != b_len {
            return Err(crate::gpu_err!(
                "sigma_cubature: flattened inputs do not match M={m}, p={p} \
                 (A has {} values, b has {})",
                a_flat.len(),
                b_flat.len()
            ));
        }

        let a_dev = stream
            .clone_htod(&a_flat)
            .gpu_ctx("sigma_cubature htod A")?;
        let b_dev = stream
            .clone_htod(&b_flat)
            .gpu_ctx("sigma_cubature htod b")?;
        let mut mean_hinv_dev = stream
            .alloc_zeros::<f64>(p2)
            .gpu_ctx("sigma_cubature alloc mean_hinv")?;
        let mut mean_beta_dev = stream
            .alloc_zeros::<f64>(p)
            .gpu_ctx("sigma_cubature alloc mean_beta")?;
        let mut second_beta_dev = stream
            .alloc_zeros::<f64>(p2)
            .gpu_ctx("sigma_cubature alloc second_beta")?;

        const THREADS: u32 = 128;

        // Kernel 1: element-wise mean of the A matrices (one thread per entry).
        {
            let func = module
                .load_function("sigma_mean_hinv")
                .gpu_ctx("sigma_cubature load sigma_mean_hinv")?;
            let total = u32::try_from(p2)
                .map_err(|_| crate::gpu_err!("sigma_cubature p*p={p2} overflows u32"))?;
            let cfg = LaunchConfig {
                grid_dim: (total.div_ceil(THREADS).max(1), 1, 1),
                block_dim: (THREADS, 1, 1),
                shared_mem_bytes: 0,
            };
            let mut builder = stream.launch_builder(&func);
            builder.arg(&m_i32);
            builder.arg(&p_i32);
            builder.arg(&a_dev);
            builder.arg(&mut mean_hinv_dev);
            // SAFETY: the arguments match the kernel signature
            // `(int, int, const double*, double*)`; `a_dev` holds exactly
            // `M * p * p` doubles (checked above) and the output holds
            // `p * p`; the kernel exits for thread indices `>= p * p`.
            unsafe { builder.launch(cfg) }
                .map(|_event_pair| ())
                .gpu_ctx("sigma_cubature launch sigma_mean_hinv")?;
        }

        // Kernel 2: element-wise mean of the b vectors (one thread per entry).
        {
            let func = module
                .load_function("sigma_mean_beta")
                .gpu_ctx("sigma_cubature load sigma_mean_beta")?;
            let total = u32::try_from(p)
                .map_err(|_| crate::gpu_err!("sigma_cubature p={p} overflows u32"))?;
            let cfg = LaunchConfig {
                grid_dim: (total.div_ceil(THREADS).max(1), 1, 1),
                block_dim: (THREADS, 1, 1),
                shared_mem_bytes: 0,
            };
            let mut builder = stream.launch_builder(&func);
            builder.arg(&m_i32);
            builder.arg(&p_i32);
            builder.arg(&b_dev);
            builder.arg(&mut mean_beta_dev);
            // SAFETY: the arguments match the kernel signature
            // `(int, int, const double*, double*)`; `b_dev` holds exactly
            // `M * p` doubles (checked above) and the output holds `p`; the
            // kernel exits for thread indices `>= p`.
            unsafe { builder.launch(cfg) }
                .map(|_event_pair| ())
                .gpu_ctx("sigma_cubature launch sigma_mean_beta")?;
        }

        // Kernel 3: mean of the outer products b bᵀ (one thread per entry).
        {
            let func = module
                .load_function("sigma_second_beta")
                .gpu_ctx("sigma_cubature load sigma_second_beta")?;
            let total = u32::try_from(p2)
                .map_err(|_| crate::gpu_err!("sigma_cubature p*p={p2} overflows u32"))?;
            let cfg = LaunchConfig {
                grid_dim: (total.div_ceil(THREADS).max(1), 1, 1),
                block_dim: (THREADS, 1, 1),
                shared_mem_bytes: 0,
            };
            let mut builder = stream.launch_builder(&func);
            builder.arg(&m_i32);
            builder.arg(&p_i32);
            builder.arg(&b_dev);
            builder.arg(&mut second_beta_dev);
            // SAFETY: the arguments match the kernel signature
            // `(int, int, const double*, double*)`; `b_dev` holds exactly
            // `M * p` doubles (checked above) and the output holds `p * p`;
            // the kernel exits for thread indices `>= p * p`, so the derived
            // row/column indices stay below `p`.
            unsafe { builder.launch(cfg) }
                .map(|_event_pair| ())
                .gpu_ctx("sigma_cubature launch sigma_second_beta")?;
        }

        let mean_hinv_host = stream
            .clone_dtoh(&mean_hinv_dev)
            .gpu_ctx("sigma_cubature dtoh mean_hinv")?;
        let mean_beta_host = stream
            .clone_dtoh(&mean_beta_dev)
            .gpu_ctx("sigma_cubature dtoh mean_beta")?;
        let second_beta_host = stream
            .clone_dtoh(&second_beta_dev)
            .gpu_ctx("sigma_cubature dtoh second_beta")?;
        stream
            .synchronize()
            .gpu_ctx("sigma_cubature synchronize after dtoh")?;

        let mean_hinv = Array2::from_shape_vec((p, p), mean_hinv_host).map_err(|err| {
            crate::gpu_err!("sigma_cubature mean_hinv reshape failed (p={p}): {err}")
        })?;
        let second_beta = Array2::from_shape_vec((p, p), second_beta_host).map_err(|err| {
            crate::gpu_err!("sigma_cubature second_beta reshape failed (p={p}): {err}")
        })?;

        // Outer product of the mean vector with itself, formed on the host.
        let mut mean_outer = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                mean_outer[[i, j]] = mean_beta_host[i] * mean_beta_host[j];
            }
        }
        Ok(mean_hinv + (second_beta - mean_outer))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Builds `(diag · I, beta)` with the matrix sized to match `beta`.
    fn small_pair(diag: f64, beta: &[f64]) -> DeviceSigmaPoint {
        let dim = beta.len();
        let scaled_identity =
            Array2::from_shape_fn((dim, dim), |(row, col)| if row == col { diag } else { 0.0 });
        (scaled_identity, Array1::from_vec(beta.to_vec()))
    }

    /// An empty point list is an error on every platform.
    #[test]
    fn moment_reduce_rejects_empty_input() {
        let failure = try_device_moment_reduce(&[], 3).unwrap_err();
        match failure {
            GpuError::DriverCallFailed { reason } => {
                assert!(reason.contains("empty points"));
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    /// A 2x2 matrix offered as `p = 3` is rejected by the shape preflight.
    #[test]
    fn moment_reduce_rejects_shape_mismatch() {
        let undersized = [small_pair(1.0, &[0.1, 0.2])];
        let failure = try_device_moment_reduce(&undersized, 3).unwrap_err();
        match failure {
            GpuError::DriverCallFailed { reason } => {
                assert!(reason.contains("A[0] shape"));
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    /// Without a GPU runtime, well-formed input is declined rather than failed.
    #[test]
    fn moment_reduce_declines_when_runtime_unavailable() {
        let well_formed = [small_pair(1.0, &[0.1, 0.2])];
        let outcome = try_device_moment_reduce(&well_formed, 2);
        if !crate::gpu::runtime::GpuRuntime::is_available() {
            assert!(matches!(outcome, Ok(None)));
        }
    }

    /// One point fewer than the batching threshold must decline cleanly.
    #[test]
    fn batched_dispatch_below_breakeven_declines() {
        let eta = array![0.0, 0.1];
        let y = array![0.0, 1.0];
        let prior_w = array![1.0, 1.0];
        let beta = array![0.0, 0.0];
        let hessian_inv = Array2::<f64>::eye(2);

        let make_point = || SigmaPointInput {
            eta: eta.view(),
            y: y.view(),
            prior_w: prior_w.view(),
            beta: beta.view(),
            hessian_inv: hessian_inv.view(),
        };
        let pts: Vec<SigmaPointInput<'_>> = std::iter::repeat_with(make_point)
            .take(BATCHED_DISPATCH_MIN_M - 1)
            .collect();

        let batch = SigmaCubatureBatch {
            family: PirlsRowFamily::BernoulliLogit,
            curvature: CurvatureMode::Fisher,
            points: &pts,
        };
        let outcome = try_device_sigma_eval_batched(&batch).expect("preflight succeeds");
        assert!(
            outcome.is_none(),
            "below-breakeven batch must decline cleanly"
        );
    }

    /// Without a GPU runtime, a valid single-point batch is declined.
    #[test]
    fn sigma_eval_declines_when_runtime_unavailable() {
        let eta = array![0.0];
        let y = array![1.0];
        let prior_w = array![1.0];
        let beta = array![0.0];
        let hessian_inv = Array2::<f64>::eye(1);

        let single = [SigmaPointInput {
            eta: eta.view(),
            y: y.view(),
            prior_w: prior_w.view(),
            beta: beta.view(),
            hessian_inv: hessian_inv.view(),
        }];
        let batch = SigmaCubatureBatch {
            family: PirlsRowFamily::BernoulliLogit,
            curvature: CurvatureMode::Fisher,
            points: &single,
        };
        let outcome = try_device_sigma_eval(&batch).expect("shape preflight succeeds");
        if !crate::gpu::runtime::GpuRuntime::is_available() {
            assert!(outcome.is_none());
        }
    }
}