use ndarray::{Array1, Array2, ArrayView1};
use crate::gpu::gpu_error::GpuError;
/// Per-sigma-point inputs for the GPU stream-pool evaluation.
///
/// Every matrix and vector is sized by the design matrix's column count `p`;
/// `validate_sigma_point_inputs` rejects points that do not match before any
/// GPU work is dispatched.
#[derive(Debug, Clone)]
pub struct SigmaPointGpuInput {
    /// `p × p` `S` matrix in the transformed basis, handed to the GPU PIRLS /
    /// penalised least-squares solver (presumably the penalty matrix).
    pub s_transformed: Array2<f64>,
    /// `p × p` basis transform: fitted coefficients are mapped back to the
    /// original basis as `qs · β` and the penalised Hessian as `qs · H · qsᵀ`.
    pub qs: Array2<f64>,
    /// Length-`p` linear shift passed to the solver alongside `s_transformed`.
    pub linear_shift: Array1<f64>,
    /// Scalar shift passed to the GPU PIRLS loop; must be finite.
    pub constant_shift: f64,
}
/// Upper bound on the number of GPU PIRLS workspaces the stream pool allocates.
#[cfg(target_os = "linux")]
const STREAM_POOL_MAX: usize = 8;
/// Initial Levenberg–Marquardt `lambda` handed to each sigma point's GPU PIRLS loop.
#[cfg(target_os = "linux")]
const SIGMA_PIRLS_INITIAL_LM_LAMBDA: f64 = 1.0e-6;
/// Number of workspaces to allocate for `m` sigma points: never fewer than
/// one, never more than [`STREAM_POOL_MAX`].
#[cfg(target_os = "linux")]
#[inline]
fn pool_size(m: usize) -> usize {
    m.clamp(1, STREAM_POOL_MAX)
}
/// Try to evaluate every sigma point on the GPU stream pool.
///
/// Returns `Ok(Some(results))` with one entry per sigma point when the GPU
/// path ran. Returns `Ok(None)` when dispatch is not possible (non-Linux target,
/// no global GPU runtime, or an admission without a family), so the caller can
/// take another path. An empty `per_sigma` yields `Ok(Some(vec![]))` on every
/// target, before any other check.
///
/// # Errors
/// Fails on every target if a sigma point's shapes disagree with
/// `x_original.ncols()` or its constant shift is non-finite. On Linux it also
/// fails if the admitted family has no JIT-cached row kernel, or if the GPU
/// evaluation itself returns an error.
pub fn try_gpu_sigma_stream_pool_eval(
    x_original: ndarray::ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    prior_w: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    per_sigma: &[SigmaPointGpuInput],
    admission: crate::gpu::policy::PirlsLoopAdmission,
    gamma_shape: f64,
    convergence_tol: f64,
    max_iter: usize,
) -> Result<Option<Vec<Option<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>, GpuError> {
    if per_sigma.is_empty() {
        return Ok(Some(Vec::new()));
    }
    // Shape checks run on all targets, so bad inputs are reported consistently.
    validate_sigma_point_inputs(x_original.ncols(), per_sigma)?;
    #[cfg(target_os = "linux")]
    {
        // No runtime: let the caller fall back rather than erroring.
        if crate::gpu::device_runtime::GpuRuntime::global().is_none() {
            return Ok(None);
        }
        let family_kind = match admission.family {
            Some(kind) => kind,
            None => return Ok(None),
        };
        let row_family = linux_impl::family_kind_to_row(family_kind).ok_or_else(|| {
            crate::gpu_err!("sigma stream pool: family not in JIT-cached set")
        })?;
        return linux_impl::stream_pool_eval(
            x_original,
            y,
            prior_w,
            offset,
            per_sigma,
            row_family,
            linux_impl::curvature_kind_to_row(admission.curvature),
            gamma_shape,
            convergence_tol,
            max_iter,
        );
    }
    #[cfg(not(target_os = "linux"))]
    {
        let (x_rows, x_cols) = (x_original.nrows(), x_original.ncols());
        log::trace!(
            "[sigma stream pool] non-Linux target: skipping dispatch \
             (x_original={}x{}, y_len={}, prior_w_len={}, offset_len={}, \
             n_sigma={}, family={:?}, curvature={:?}, gamma_shape={}, \
             tol={}, max_iter={})",
            x_rows,
            x_cols,
            y.len(),
            prior_w.len(),
            offset.len(),
            per_sigma.len(),
            admission.family,
            admission.curvature,
            gamma_shape,
            convergence_tol,
            max_iter,
        );
        Ok(None)
    }
}
/// Check every sigma point against the design's coefficient dimension `p`.
///
/// Each point must carry a `p × p` `s_transformed`, a `p × p` `qs`, a
/// length-`p` `linear_shift` whose entries are all finite, and a finite
/// `constant_shift`.
///
/// # Errors
/// Returns a `GpuError` describing the first violation, tagged with the
/// offending point's index.
fn validate_sigma_point_inputs(p: usize, per_sigma: &[SigmaPointGpuInput]) -> Result<(), GpuError> {
    for (idx, pt) in per_sigma.iter().enumerate() {
        if pt.s_transformed.shape() != [p, p] {
            return Err(crate::gpu_err!(
                "sigma stream pool: point[{idx}] S shape {:?} != [{p}, {p}]",
                pt.s_transformed.shape()
            ));
        }
        if pt.qs.shape() != [p, p] {
            return Err(crate::gpu_err!(
                "sigma stream pool: point[{idx}] Qs shape {:?} != [{p}, {p}]",
                pt.qs.shape()
            ));
        }
        if pt.linear_shift.len() != p {
            return Err(crate::gpu_err!(
                "sigma stream pool: point[{idx}] linear shift len {} != {p}",
                pt.linear_shift.len()
            ));
        }
        // A NaN/inf linear shift would otherwise propagate silently into the
        // fitted coefficients; reject it just like a bad constant shift.
        if let Some(pos) = pt.linear_shift.iter().position(|v| !v.is_finite()) {
            return Err(crate::gpu_err!(
                "sigma stream pool: point[{idx}] non-finite linear shift at [{pos}]: {}",
                pt.linear_shift[pos]
            ));
        }
        if !pt.constant_shift.is_finite() {
            return Err(crate::gpu_err!(
                "sigma stream pool: point[{idx}] non-finite constant shift {}",
                pt.constant_shift
            ));
        }
    }
    Ok(())
}
#[cfg(target_os = "linux")]
mod linux_impl {
use crate::gpu::kernels::pirls_row::{CurvatureMode, PirlsRowFamily};
use crate::gpu::kernels::sigma_cubature::SigmaPointGpuInput;
use crate::gpu::policy::{PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
use crate::linalg::utils::matrix_inversewith_regularization;
use ndarray::{Array1, Array2, ArrayView1};
type SigmaPointResult = Option<(Array2<f64>, Array1<f64>)>;
pub(super) fn family_kind_to_row(f: PirlsLoopFamilyKind) -> Option<PirlsRowFamily> {
match f {
PirlsLoopFamilyKind::BernoulliLogit => Some(PirlsRowFamily::BernoulliLogit),
PirlsLoopFamilyKind::BernoulliProbit => Some(PirlsRowFamily::BernoulliProbit),
PirlsLoopFamilyKind::BernoulliCLogLog => Some(PirlsRowFamily::BernoulliCLogLog),
PirlsLoopFamilyKind::PoissonLog => Some(PirlsRowFamily::PoissonLog),
PirlsLoopFamilyKind::GaussianIdentity => Some(PirlsRowFamily::GaussianIdentity),
PirlsLoopFamilyKind::GammaLog => Some(PirlsRowFamily::GammaLog),
}
}
pub(super) fn curvature_kind_to_row(c: PirlsLoopCurvatureKind) -> CurvatureMode {
match c {
PirlsLoopCurvatureKind::Fisher => CurvatureMode::Fisher,
PirlsLoopCurvatureKind::Observed => CurvatureMode::Observed,
}
}
fn hessian_to_original(
h_transformed: &ndarray::Array2<f64>,
qs: &ndarray::Array2<f64>,
) -> ndarray::Array2<f64> {
let tmp = qs.dot(h_transformed);
let mut h_orig = tmp.dot(&qs.t());
crate::families::custom_family::symmetrize_dense_in_place(&mut h_orig);
h_orig
}
pub(super) fn stream_pool_eval(
x_original: ndarray::ArrayView2<'_, f64>,
y: ArrayView1<'_, f64>,
prior_w: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
per_sigma: &[SigmaPointGpuInput],
family: PirlsRowFamily,
curvature: CurvatureMode,
gamma_shape: f64,
convergence_tol: f64,
max_iter: usize,
) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
use crate::gpu::kernels::sigma_cubature::pool_size;
use crate::solver::gpu::pirls_gpu;
let m = per_sigma.len();
let p = x_original.ncols();
for (idx, pt) in per_sigma.iter().enumerate() {
if pt.s_transformed.shape() != [p, p] || pt.qs.shape() != [p, p] {
return Err(crate::gpu_err!(
"sigma stream pool: point[{idx}] shape mismatch against point[0]"
));
}
}
if family == PirlsRowFamily::GaussianIdentity {
return gaussian_sigma_pool_eval(x_original, y, prior_w, offset, per_sigma, p);
}
let bootstrap_shared =
pirls_gpu::upload_shared_pirls_gpu(x_original, y, prior_w, offset)
.map_err(|e| crate::gpu_err!("sigma stream pool bootstrap upload: {e}"))?;
let n_streams = pool_size(m);
let mut workspace_pairs: Vec<(
crate::solver::gpu::pirls_gpu::SigmaPirlsGpuWorkspace,
crate::solver::gpu::pirls_gpu::cuda::PirlsLoopWorkspace,
)> = Vec::with_capacity(n_streams);
for _ in 0..n_streams {
let ws = pirls_gpu::allocate_sigma_pirls_workspace(&bootstrap_shared)
.map_err(|e| crate::gpu_err!("sigma stream pool alloc workspace: {e}"))?;
let loop_ws = pirls_gpu::allocate_pirls_loop_workspace(&bootstrap_shared, &ws)
.map_err(|e| crate::gpu_err!("sigma stream pool alloc loop_ws: {e}"))?;
workspace_pairs.push((ws, loop_ws));
}
let beta0: Array1<f64> = Array1::zeros(p);
let mut outcomes: Vec<SigmaPointResult> = Vec::with_capacity(m);
for (idx, pt) in per_sigma.iter().enumerate() {
let stream_idx = idx % n_streams;
let (ws, loop_ws) = &mut workspace_pairs[stream_idx];
pirls_gpu::upload_qs_pirls(ws, pt.qs.view())
.map_err(|e| crate::gpu_err!("sigma stream pool upload Qs pt[{idx}]: {e}"))?;
let shared = &bootstrap_shared;
let outcome = pirls_gpu::pirls_loop_on_stream(
shared,
ws,
loop_ws,
family,
curvature,
gamma_shape,
beta0.view(),
pt.s_transformed.view(),
pt.linear_shift.view(),
pt.constant_shift,
super::SIGMA_PIRLS_INITIAL_LM_LAMBDA,
0.0,
max_iter,
convergence_tol,
None,
);
let sigma_result = match outcome {
Ok(loop_out) => {
let h_orig = hessian_to_original(&loop_out.penalized_hessian, &pt.qs);
let cov = matrix_inversewith_regularization(&h_orig, "gpu sigma point")
.ok_or_else(|| {
crate::gpu_err!(
"gpu sigma point: penalised Hessian inverse not well-defined"
)
})?;
let beta_orig = pt.qs.dot(&loop_out.beta);
Some((cov, beta_orig))
}
Err(e) => {
log::warn!(
"[sigma-cubature gpu] point[{idx}] pirls_loop_on_stream failed: {e}"
);
None
}
};
outcomes.push(sigma_result);
}
Ok(Some(outcomes))
}
fn gaussian_sigma_pool_eval(
x_original: ndarray::ArrayView2<'_, f64>,
y: ArrayView1<'_, f64>,
prior_w: ArrayView1<'_, f64>,
offset: ArrayView1<'_, f64>,
per_sigma: &[SigmaPointGpuInput],
p: usize,
) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
use ndarray::Array1;
let xtwx = crate::solver::gpu::pirls_gpu::weighted_crossprod_gpu(x_original, prior_w)
.map_err(|e| crate::gpu_err!("gaussian sigma: XᵀWX gpu failed: {e}"))?;
let mut yw = y.to_owned();
yw -= &offset;
yw *= &prior_w;
let xtwy: Array1<f64> = x_original.t().dot(&yw);
let prior_mean_zero: Array1<f64> = Array1::zeros(p);
let mut outcomes: Vec<SigmaPointResult> = Vec::with_capacity(per_sigma.len());
for (idx, pt) in per_sigma.iter().enumerate() {
let pls = crate::solver::gpu::pirls_gpu::solve_gaussian_pls_gpu(
xtwx.view(),
xtwy.view(),
pt.s_transformed.view(),
pt.linear_shift.view(),
prior_mean_zero.view(),
0.0,
Some(pt.qs.view()),
)
.map_err(|e| crate::gpu_err!("gaussian sigma pool: point[{idx}] pls failed: {e}"))?;
let h_orig = hessian_to_original(&pls.penalized_hessian, &pt.qs);
let cov = matrix_inversewith_regularization(&h_orig, "gaussian sigma point")
.ok_or_else(|| {
crate::gpu_err!(
"gaussian sigma point: penalised Hessian inverse not well-defined"
)
})?;
let beta_orig = pt.qs.dot(&pls.beta);
outcomes.push(Some((cov, beta_orig)));
}
Ok(Some(outcomes))
}
}