use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Host-side inputs for a single one-shot PIRLS Newton step on the GPU.
///
/// Consumed by `cuda::solve_step`, which uploads `x` into fresh shared data and
/// a fresh workspace (whose `Qs` rotation starts as the identity) and then
/// delegates to `cuda::solve_step_on_stream`.
#[derive(Clone, Debug)]
pub struct PirlsGpuInput<'a> {
/// Design matrix with `n` rows and `p` columns; both must be non-zero.
pub x: ArrayView2<'a, f64>,
/// Working weights, one per row of `x`; must be contiguous in memory.
pub weights: ArrayView1<'a, f64>,
/// Penalty matrix `S`, shape `(p, p)`.
pub penalty_hessian: ArrayView2<'a, f64>,
/// Right-hand side of the Newton system, length `p`; must be contiguous.
/// It is used as given (no rotation by `Qs` is applied to it).
pub gradient: ArrayView1<'a, f64>,
/// Ridge passed to `penalty_with_ridge` for the matrix that is actually
/// Cholesky-factored and solved.
pub step_lm_lambda: f64,
/// Ridge passed to `penalty_with_ridge` for the exported
/// `PirlsGpuStep::penalized_hessian`.
pub objective_ridge: f64,
}
/// Result of one GPU PIRLS Newton step.
#[derive(Clone, Debug)]
pub struct PirlsGpuStep {
/// `Qsᵀ·XᵀWX·Qs + penalty_with_ridge(S, objective_ridge)`, rebuilt on the host.
/// It uses the objective ridge, not the step ridge, so it is not exactly the
/// matrix that was factored.
pub penalized_hessian: Array2<f64>,
/// Solution of the factored step system. The host-input path may apply one
/// round of iterative refinement to it.
pub direction: Array1<f64>,
/// Log-determinant obtained from the Cholesky factor of the factored step
/// matrix (the one built with `step_lm_lambda`), not of `penalized_hessian`.
pub logdet: f64,
}
/// Per-iteration host inputs for `cuda::solve_step_on_stream`, which reuses an
/// already-uploaded design and a preallocated workspace.
#[derive(Clone, Debug)]
pub struct PirlsStepStreamInput<'a> {
/// Working weights, length `n`; must be contiguous.
pub weights: ArrayView1<'a, f64>,
/// Penalty matrix `S`, shape `(p, p)`.
pub penalty_hessian: ArrayView2<'a, f64>,
/// Right-hand side, length `p`; must be contiguous. It is uploaded unchanged,
/// with no `Qs` rotation applied.
pub gradient: ArrayView1<'a, f64>,
/// Ridge used for the factored step matrix.
pub step_lm_lambda: f64,
/// Ridge used for the exported `penalized_hessian`.
pub objective_ridge: f64,
}
/// Per-iteration inputs for the device-input step routines. The per-row
/// quantities (and `beta`) already live on the GPU.
#[cfg(target_os = "linux")]
pub struct PirlsStepStreamDeviceInput<'a, 'b> {
/// Working weights on the device, length `n`; used to build `XᵀWX`.
pub w_solver_dev: &'a cudarc::driver::CudaSlice<f64>,
/// Per-row score on the device, length `n`. It is projected as
/// `Xᵀ·grad_eta` and then rotated by `Qsᵀ` to form the right-hand side.
pub grad_eta_dev: &'b cudarc::driver::CudaSlice<f64>,
/// Penalty matrix `S`, shape `(p, p)`. It is added (with ridge) to the device
/// Hessian and also used on the host to form `S·β`.
pub penalty_hessian: ArrayView2<'b, f64>,
/// Ridge used for the factored step matrix.
pub step_lm_lambda: f64,
/// Ridge used for the exported `penalized_hessian`.
pub objective_ridge: f64,
/// Current coefficients on the device. They are downloaded and multiplied by
/// `penalty_hessian`, so the length must be `p`.
pub beta_dev: &'b cudarc::driver::CudaSlice<f64>,
/// Length-`p` vector added to the right-hand side on the host.
pub linear_shift: ArrayView1<'b, f64>,
}
/// Device-resident model data, uploaded once by `upload_impl` and shared by
/// every `SigmaPirlsGpuWorkspace` allocated from it.
#[cfg(target_os = "linux")]
pub struct PirlsGpuSharedData {
/// CUDA context; each workspace creates its own stream from it.
pub(crate) ctx: std::sync::Arc<cudarc::driver::CudaContext>,
/// Number of rows of the design.
pub(crate) n: usize,
/// Number of columns of the design.
pub(crate) p: usize,
/// Design matrix `X`, stored column-major (`n·p` entries).
pub(crate) x_original_dev: cudarc::driver::CudaSlice<f64>,
/// Response, length `n`. NOTE(review): not read by the step solvers defined
/// next to this struct; presumably consumed elsewhere, confirm.
pub(crate) y_dev: cudarc::driver::CudaSlice<f64>,
/// Prior weights, length `n` (see note on `y_dev`).
pub(crate) prior_w_dev: cudarc::driver::CudaSlice<f64>,
/// Offsets, length `n` (see note on `y_dev`).
pub(crate) offset_dev: cudarc::driver::CudaSlice<f64>,
}
/// Per-stream scratch buffers and library handles for repeated PIRLS steps on
/// one shared design. Created by `allocate_impl`.
#[cfg(target_os = "linux")]
pub struct SigmaPirlsGpuWorkspace {
/// Stream created from the shared context; all of this workspace's work is
/// queued on it.
pub(crate) stream: std::sync::Arc<cudarc::driver::CudaStream>,
/// cuBLAS handle bound to `stream`.
pub(crate) blas: cudarc::cublas::CudaBlas,
/// cuSOLVER dense handle bound to `stream`.
pub(crate) solver: cudarc::cusolver::DnHandle,
/// Scratch for `diag(W)·X` (`n·p`, column-major). It is `Some` only when
/// `p >= FUSED_XTWX_P_THRESHOLD`. `None` selects the fused-kernel path.
pub(crate) wx_dev: Option<cudarc::driver::CudaSlice<f64>>,
/// Weights buffer (length `n`) filled by the host-input step.
pub(crate) w_dev: cudarc::driver::CudaSlice<f64>,
/// `XᵀWX` in the original basis (`p × p`, column-major).
pub(crate) xtwx_dev: cudarc::driver::CudaSlice<f64>,
/// Rotated, penalized step matrix. Overwritten in place by its Cholesky factor.
pub(crate) h_dev: cudarc::driver::CudaSlice<f64>,
/// Right-hand side (length `p`). Overwritten with the solved direction.
pub(crate) rhs_dev: cudarc::driver::CudaSlice<f64>,
/// Penalty-plus-ridge matrix uploaded for each step (`p × p`).
pub(crate) penalty_dev: cudarc::driver::CudaSlice<f64>,
/// Rotation `Qs` (`p × p`, column-major). Identity right after allocation.
pub(crate) qs_dev: cudarc::driver::CudaSlice<f64>,
/// Scratch holding `XᵀWX·Qs`.
pub(crate) qs_tmp_dev: cudarc::driver::CudaSlice<f64>,
/// Length-`p` buffer. The device-input step routines use it as scratch for
/// `Qsᵀ·score`.
pub(crate) beta_orig_dev: cudarc::driver::CudaSlice<f64>,
/// Length-`p` buffer. NOTE(review): not used by the routines defined next to
/// this struct; confirm whether it is still needed.
pub(crate) dir_orig_dev: cudarc::driver::CudaSlice<f64>,
/// cuSOLVER potrf workspace (at least one element).
pub(crate) potrf_work_dev: cudarc::driver::CudaSlice<f64>,
/// Workspace length reported by `potrf_query_lwork`, as `i32`.
pub(crate) potrf_lwork: i32,
/// Device-side potrf status, checked after the fact via `check_deferred_potrf_info`.
pub(crate) potrf_info_dev: cudarc::driver::CudaSlice<i32>,
/// Device-side potrs status, checked after the fact via `check_deferred_potrs_info`.
pub(crate) potrs_info_dev: cudarc::driver::CudaSlice<i32>,
/// Row count the buffers were sized for; must equal the shared data's `n`.
pub(crate) n: usize,
/// Column count the buffers were sized for; must equal the shared data's `p`.
pub(crate) p: usize,
}
#[cfg(target_os = "linux")]
pub(crate) mod cuda {
use super::{
PirlsGpuInput, PirlsGpuSharedData, PirlsGpuStep, PirlsStepStreamDeviceInput,
PirlsStepStreamInput, SigmaPirlsGpuWorkspace,
};
use crate::gpu::common::PtxModuleCache;
use crate::gpu::driver::{from_col_major, to_col_major};
use crate::gpu::solver::{
check_deferred_potrf_info, check_deferred_potrs_info, context_and_stream, pinned_htod,
potrf_in_place_reuse, potrf_query_lwork, potrs_in_place_reuse,
};
use cudarc::cublas::sys::{
cublasDdgmm, cublasDgeam, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
};
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
use cudarc::cusolver::DnHandle;
use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut, LaunchConfig, PushKernelArg};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// CUDA C source (not PTX, despite the name) for `chol_logdet_col_major`.
///
/// The kernel computes `out[0] = 2·Σᵢ log(factor[i + i·p])`, which is twice
/// the sum of the logs of the diagonal of a column-major `p × p` factor `L`.
/// That equals `log det(L·Lᵀ)`.
///
/// Only thread 0 of block 0 does any work, and it walks the diagonal
/// serially, so a single-thread launch suffices. A zero diagonal entry yields
/// `-inf` and a negative one yields NaN; the kernel reports no error.
const CHOL_LOGDET_PTX_SOURCE: &str = r#"
extern "C" __global__ void chol_logdet_col_major(
const double* __restrict__ factor,
int p,
double* __restrict__ out
) {
if (threadIdx.x != 0 || blockIdx.x != 0) return;
double acc = 0.0;
long long pp = (long long)p;
for (long long i = 0; i < pp; ++i) {
acc += log(factor[i * pp + i]);
}
out[0] = 2.0 * acc;
}
"#;
/// Compiled-module cache for `CHOL_LOGDET_PTX_SOURCE`. Presumably populated by
/// `cholesky_logdet_device`, which is defined outside this view.
static CHOL_LOGDET_CACHE: PtxModuleCache = PtxModuleCache::new();
/// Column count at which `allocate_impl` reserves `wx_dev` and the step
/// routines switch to cuBLAS (`ddgmm` + `dgemm`/`dgemv`). Below it, the fused
/// kernels in `FUSED_XTWX_PTX_SOURCE` are used instead.
const FUSED_XTWX_P_THRESHOLD: usize = 256;
/// CUDA C source for three small-`p` kernels. `X` is column-major `n × p`, and
/// `A` is column-major `p × p`.
///
/// The triangular-index arithmetic uses 32-bit `int` (e.g. `p*(p+1)/2`). That
/// is safe for the `p < FUSED_XTWX_P_THRESHOLD` sizes this path is selected for.
const FUSED_XTWX_PTX_SOURCE: &str = concat!(
// xtwx_lower: one thread per lower-triangle pair (j >= k). It recovers
// (j, k) from the linear index t with a sqrt estimate plus corrective
// loops, then writes A[j + k*p] = Σ_i w[i]·X[i,j]·X[i,k]. Only the lower
// triangle (including the diagonal) is written.
"extern \"C\" __global__ void xtwx_lower(",
"const double* __restrict__ X,",
"const double* __restrict__ w,",
"double* __restrict__ A,",
"int n, int p) {",
"int t=blockIdx.x*blockDim.x+threadIdx.x;",
"int np=p*(p+1)/2; if(t>=np)return;",
"int jv=(int)((__dsqrt_rn((double)(8*t+1))-1.0)*0.5);",
"while((long long)(jv+1)*(jv+2)/2<=t)jv++;",
"while(jv>0&&(long long)jv*(jv+1)/2>t)jv--;",
"int kv=t-(int)((long long)jv*(jv+1)/2);",
"double acc=0.0;",
"const double*Xj=X+(long long)jv*n;",
"const double*Xk=X+(long long)kv*n;",
"for(int i=0;i<n;++i)acc+=w[i]*Xj[i]*Xk[i];",
"A[jv+(long long)kv*p]=acc;}",
// xtscore: one thread per column j, s[j] = Σ_i score[i]·X[i,j].
"extern \"C\" __global__ void xtscore(",
"const double* __restrict__ X,",
"const double* __restrict__ score,",
"double* __restrict__ s,",
"int n, int p) {",
"int j=blockIdx.x*blockDim.x+threadIdx.x;",
"if(j>=p)return;",
"double acc=0.0;",
"const double*Xj=X+(long long)j*n;",
"for(int i=0;i<n;++i)acc+=score[i]*Xj[i];",
"s[j]=acc;}",
// symmetrize_lower: one thread per strictly-lower pair (j > k). It copies
// A[j,k] into A[k,j] so the full matrix is symmetric.
"extern \"C\" __global__ void symmetrize_lower(",
"double* __restrict__ A, int p) {",
"int ns=p*(p-1)/2;",
"int t=blockIdx.x*blockDim.x+threadIdx.x;",
"if(t>=ns)return;",
"int jv=(int)((__dsqrt_rn((double)(8*t+1))+1.0)*0.5);",
"while((long long)jv*(jv-1)/2>t)jv--;",
"while((long long)(jv+1)*jv/2<=t)jv++;",
"int kv=t-(int)((long long)jv*(jv-1)/2);",
"A[kv+(long long)jv*p]=A[jv+(long long)kv*p];}",
);
/// Compiled-module cache for `FUSED_XTWX_PTX_SOURCE`, filled on first use by
/// `get_or_compile`.
static FUSED_XTWX_CACHE: PtxModuleCache = PtxModuleCache::new();
impl PirlsGpuSharedData {
    /// Validates the model inputs and copies them to the device once, so every
    /// PIRLS iteration can reuse the same device-resident design.
    ///
    /// `x` is converted to column-major before upload. `y`, `prior_w` and
    /// `offset` must each have `x.nrows()` entries and be contiguous. The
    /// upload stream is synchronized before returning.
    ///
    /// # Errors
    /// Returns an error for an empty design, length mismatches, non-contiguous
    /// vectors, or any CUDA allocation/copy/synchronization failure.
    pub(crate) fn upload_impl(
        x: ArrayView2<'_, f64>,
        y: ArrayView1<'_, f64>,
        prior_w: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
    ) -> Result<Self, String> {
        let (n, p) = x.dim();
        if p == 0 || n == 0 {
            return Err("empty design cannot be uploaded".to_string());
        }
        let (y_len, w_len, offset_len) = (y.len(), prior_w.len(), offset.len());
        let all_lengths_match = [y_len, w_len, offset_len].iter().all(|&len| len == n);
        if !all_lengths_match {
            return Err(format!(
                "y/prior_w/offset length mismatch (y={}, w={}, offset={}, n={n})",
                y_len, w_len, offset_len
            ));
        }

        let (ctx, stream) = context_and_stream()?;
        // Keep the column-major copy bound until after the synchronize below.
        let x_col = to_col_major(&x);
        let x_original_dev = pinned_htod(&stream, &x_col)?;
        let y_dev = {
            let host = y.as_slice().ok_or("y not contiguous")?;
            pinned_htod(&stream, host)?
        };
        let prior_w_dev = {
            let host = prior_w.as_slice().ok_or("prior_w not contiguous")?;
            pinned_htod(&stream, host)?
        };
        let offset_dev = {
            let host = offset.as_slice().ok_or("offset not contiguous")?;
            pinned_htod(&stream, host)?
        };
        stream
            .synchronize()
            .map_err(|e| format!("cuda sync after model upload: {e}"))?;

        Ok(Self {
            ctx,
            n,
            p,
            x_original_dev,
            y_dev,
            prior_w_dev,
            offset_dev,
        })
    }
}
impl SigmaPirlsGpuWorkspace {
    /// Allocates every per-stream buffer the step routines need for a design
    /// of shape `(shared.n, shared.p)`.
    ///
    /// A dedicated stream is created from `shared.ctx`, and the cuBLAS and
    /// cuSOLVER handles are bound to it. `wx_dev` is only allocated when
    /// `p >= FUSED_XTWX_P_THRESHOLD`. `qs_dev` is initialised to the identity,
    /// and the potrf workspace is sized via `potrf_query_lwork` (minimum one
    /// element).
    ///
    /// # Errors
    /// Returns an error if a size product overflows, a handle cannot be
    /// created, the potrf workspace size does not fit in `i32`, or any
    /// allocation/copy fails.
    pub(crate) fn allocate_impl(shared: &PirlsGpuSharedData) -> Result<Self, String> {
        let (n, p) = (shared.n, shared.p);
        let stream = shared
            .ctx
            .new_stream()
            .map_err(|e| format!("cuda stream alloc: {e}"))?;
        let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
        let np = n.checked_mul(p).ok_or("X size overflow")?;
        let pp = p.checked_mul(p).ok_or("H size overflow")?;

        // Only the cuBLAS path needs an explicit diag(W)·X buffer.
        let wx_dev = (p >= FUSED_XTWX_P_THRESHOLD)
            .then(|| {
                stream
                    .alloc_zeros::<f64>(np)
                    .map_err(|e| format!("cuda alloc WX: {e}"))
            })
            .transpose()?;
        let w_dev = stream
            .alloc_zeros::<f64>(n)
            .map_err(|e| format!("cuda alloc W: {e}"))?;
        let xtwx_dev = stream
            .alloc_zeros::<f64>(pp)
            .map_err(|e| format!("cuda alloc XtWX: {e}"))?;
        let h_dev = stream
            .alloc_zeros::<f64>(pp)
            .map_err(|e| format!("cuda alloc H: {e}"))?;
        let rhs_dev = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("cuda alloc RHS: {e}"))?;
        let penalty_dev = stream
            .alloc_zeros::<f64>(pp)
            .map_err(|e| format!("cuda alloc penalty: {e}"))?;

        let mut qs_dev = stream
            .alloc_zeros::<f64>(pp)
            .map_err(|e| format!("cuda alloc Qs: {e}"))?;
        // Column-major identity: diagonal entries sit at multiples of p + 1.
        let identity: Vec<f64> = (0..pp)
            .map(|k| if k % (p + 1) == 0 { 1.0 } else { 0.0 })
            .collect();
        stream
            .memcpy_htod(&identity, &mut qs_dev)
            .map_err(|e| format!("init Qs identity: {e}"))?;

        let qs_tmp_dev = stream
            .alloc_zeros::<f64>(pp)
            .map_err(|e| format!("cuda alloc Qs tmp: {e}"))?;
        let beta_orig_dev = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("cuda alloc beta_orig: {e}"))?;
        let dir_orig_dev = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("cuda alloc dir_orig: {e}"))?;

        let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
        let potrf_lwork = i32::try_from(potrf_lwork_usize)
            .map_err(|_| format!("potrf lwork {potrf_lwork_usize} exceeds i32"))?;
        // Never allocate a zero-length workspace.
        let potrf_work_dev = stream
            .alloc_zeros::<f64>(potrf_lwork_usize.max(1))
            .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
        let potrf_info_dev = stream
            .alloc_zeros::<i32>(1)
            .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
        let potrs_info_dev = stream
            .alloc_zeros::<i32>(1)
            .map_err(|e| format!("cuda alloc potrs info: {e}"))?;

        Ok(Self {
            stream,
            blas,
            solver,
            wx_dev,
            w_dev,
            xtwx_dev,
            h_dev,
            rhs_dev,
            penalty_dev,
            qs_dev,
            qs_tmp_dev,
            beta_orig_dev,
            dir_orig_dev,
            potrf_work_dev,
            potrf_lwork,
            potrf_info_dev,
            potrs_info_dev,
            n,
            p,
        })
    }
}
/// Replaces the workspace rotation `Qs` with a caller-supplied `p × p` matrix.
///
/// The matrix is converted to column-major before the copy, matching the layout
/// that the cuBLAS calls in this module use for `qs_dev`.
///
/// # Errors
/// Returns an error if `qs` is not `p × p` or the host→device copy fails.
pub(super) fn upload_qs(
    ws: &mut SigmaPirlsGpuWorkspace,
    qs: ArrayView2<'_, f64>,
) -> Result<(), String> {
    let p = ws.p;
    if qs.dim() == (p, p) {
        let qs_col = to_col_major(&qs);
        ws.stream
            .memcpy_htod(qs_col.as_ref(), &mut ws.qs_dev)
            .map_err(|e| format!("upload Qs: {e}"))
    } else {
        Err(format!("upload_qs: Qs shape {:?} != ({p},{p})", qs.dim()))
    }
}
/// Resets the workspace rotation `Qs` to the `p × p` identity matrix.
///
/// # Errors
/// Returns an error if the host→device copy fails.
pub(super) fn upload_qs_identity(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
    let dim = ws.p;
    let mut identity = vec![0.0_f64; dim * dim];
    // In a flat dim×dim buffer the diagonal entries are every (dim + 1)-th slot.
    identity
        .iter_mut()
        .step_by(dim + 1)
        .for_each(|entry| *entry = 1.0);
    ws.stream
        .memcpy_htod(&identity, &mut ws.qs_dev)
        .map_err(|e| format!("upload Qs identity: {e}"))
}
/// Applies at most one round of iterative refinement to a Cholesky-based
/// Newton direction.
///
/// The residual `r = g − (H + δ·I)·x` is formed on the host from
/// `penalized_hessian` (`H`) and `step_lm_delta` (`δ`). If `‖r‖ / ‖g‖`
/// exceeds `GpuDispatchPolicy::REFINEMENT_TOL`, the correction for `r` is
/// solved with the existing device Cholesky factor and added to `x`.
///
/// The input is returned unchanged when `p < GpuDispatchPolicy::REFINEMENT_MIN_P`,
/// when `g` is the zero vector, or when the relative residual is already
/// within tolerance. When a correction is computed, `rhs_dev` and
/// `potrs_info_dev` are overwritten.
///
/// # Errors
/// Propagates CUDA copy failures and cuSOLVER `potrs` errors.
#[allow(clippy::too_many_arguments)]
fn newton_step_refine_once(
    solver: &cudarc::cusolver::DnHandle,
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    p: usize,
    chol_factor_dev: &CudaSlice<f64>,
    rhs_dev: &mut CudaSlice<f64>,
    potrs_info_dev: &mut CudaSlice<i32>,
    mut direction_raw: Vec<f64>,
    g: &[f64],
    penalized_hessian: &ndarray::Array2<f64>,
    step_lm_delta: f64,
) -> Result<Vec<f64>, String> {
    use crate::gpu::policy::GpuDispatchPolicy;

    let euclidean_norm = |values: &[f64]| values.iter().map(|v| v * v).sum::<f64>().sqrt();

    if p < GpuDispatchPolicy::REFINEMENT_MIN_P {
        return Ok(direction_raw);
    }
    let norm_g = euclidean_norm(g);
    if norm_g == 0.0 {
        return Ok(direction_raw);
    }

    // (H + δ·I)·x, accumulated row by row in the same order as a plain loop.
    let hx: Vec<f64> = (0..p)
        .map(|i| {
            let row_dot: f64 = penalized_hessian
                .row(i)
                .iter()
                .zip(&direction_raw)
                .map(|(h_ij, x_j)| h_ij * x_j)
                .sum();
            row_dot + step_lm_delta * direction_raw[i]
        })
        .collect();
    let mut residual = Vec::with_capacity(hx.len().min(g.len()));
    for (g_i, hx_i) in g.iter().zip(&hx) {
        residual.push(g_i - hx_i);
    }

    let rel_res = euclidean_norm(&residual) / norm_g;
    if rel_res <= GpuDispatchPolicy::REFINEMENT_TOL {
        return Ok(direction_raw);
    }

    stream
        .memcpy_htod(&residual, rhs_dev)
        .map_err(|e| format!("upload residual: {e}"))?;
    potrs_in_place_reuse(
        solver,
        stream,
        p,
        1,
        chol_factor_dev,
        rhs_dev,
        potrs_info_dev,
    )?;
    let correction = stream
        .clone_dtoh(rhs_dev)
        .map_err(|e| format!("download correction: {e}"))?;
    check_deferred_potrs_info(stream, potrs_info_dev)?;

    direction_raw
        .iter_mut()
        .zip(&correction)
        .for_each(|(x_i, e_i)| *x_i += e_i);
    Ok(direction_raw)
}
/// Assembles and solves one PIRLS Newton step from host-side weights and
/// gradient. It reuses the device-resident design in `shared` and the buffers
/// in `ws`.
///
/// The following work is queued on `ws.stream`:
/// - Build `XᵀWX`. When `ws.wx_dev` is allocated this uses cuBLAS `ddgmm` +
///   `dgemm`; otherwise it uses the fused `xtwx_lower` + `symmetrize_lower`
///   kernels.
/// - Form `H = Qsᵀ·XᵀWX·Qs + penalty_with_ridge(S, step_lm_lambda)`.
/// - Cholesky-factorize `H` and solve `H·d = gradient`. The gradient is used
///   as the right-hand side unchanged; no `Qs` rotation is applied to it.
///
/// One round of host-side iterative refinement (`newton_step_refine_once`)
/// follows.
///
/// The returned `penalized_hessian` is recomputed on the host with
/// `objective_ridge` in place of `step_lm_lambda`. `logdet` comes from
/// `cholesky_logdet_device` applied to the factor of `H`, i.e. of the step
/// matrix.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the workspace shape, weights length, penalty shape or gradient length
///   disagree with the shared design;
/// - `weights`/`gradient` are not contiguous;
/// - any CUDA, cuBLAS or cuSOLVER operation fails, including the deferred
///   potrf/potrs info checks.
pub(super) fn solve_step_on_stream(
shared: &PirlsGpuSharedData,
ws: &mut SigmaPirlsGpuWorkspace,
input: PirlsStepStreamInput<'_>,
) -> Result<PirlsGpuStep, String> {
let n = shared.n;
let p = shared.p;
if ws.n != n || ws.p != p {
return Err(format!(
"workspace shape ({}, {}) does not match shared design ({n}, {p})",
ws.n, ws.p
));
}
if input.weights.len() != n {
return Err(format!(
"weights length {} does not match rows {n}",
input.weights.len()
));
}
if input.penalty_hessian.dim() != (p, p) {
return Err(format!(
"penalty Hessian shape {:?} does not match p={p}",
input.penalty_hessian.dim()
));
}
if input.gradient.len() != p {
return Err(format!(
"gradient length {} does not match p={p}",
input.gradient.len()
));
}
let w_slice = input
.weights
.as_slice()
.ok_or("weights must be contiguous")?;
ws.stream
.memcpy_htod(w_slice, &mut ws.w_dev)
.map_err(|e| format!("upload W: {e}"))?;
let n_i = to_i32(n)?;
let p_i = to_i32(p)?;
// Large p (wx_dev allocated): WX = diag(W)·X, then XᵀWX = Xᵀ·WX via dgemm.
// Small p: the fused kernel fills the lower triangle, which is then mirrored.
if let Some(ref mut wx_dev) = ws.wx_dev {
left_scale_rows(
&ws.blas,
&ws.stream,
n,
p,
&shared.x_original_dev,
&mut ws.w_dev,
wx_dev,
)?;
let cfg = GemmConfig::<f64> {
transa: cublasOperation_t::CUBLAS_OP_T,
transb: cublasOperation_t::CUBLAS_OP_N,
m: p_i,
n: p_i,
k: n_i,
alpha: 1.0,
lda: n_i,
ldb: n_i,
beta: 0.0,
ldc: p_i,
};
unsafe {
ws.blas
.gemm(cfg, &shared.x_original_dev, wx_dev, &mut ws.xtwx_dev)
}
.map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
} else {
launch_xtwx_lower(
&ws.stream,
&shared.ctx,
n,
p,
&shared.x_original_dev,
&ws.w_dev,
&mut ws.xtwx_dev,
)?;
launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
}
// The penalty for the factored step matrix carries the step (LM) ridge.
let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
let penalty_step_view = penalty_step.view();
let penalty_step_col = to_col_major(&penalty_step_view);
ws.stream
.memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
.map_err(|e| format!("upload penalty: {e}"))?;
// qs_tmp = XᵀWX · Qs
{
let cfg_aq = GemmConfig::<f64> {
transa: cublasOperation_t::CUBLAS_OP_N,
transb: cublasOperation_t::CUBLAS_OP_N,
m: p_i,
n: p_i,
k: p_i,
alpha: 1.0,
lda: p_i,
ldb: p_i,
beta: 0.0,
ldc: p_i,
};
unsafe {
ws.blas
.gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
}
.map_err(|e| format!("dgemm A·Qs (host-input step): {e}"))?;
}
// h = Qsᵀ · qs_tmp
{
let cfg_qt = GemmConfig::<f64> {
transa: cublasOperation_t::CUBLAS_OP_T,
transb: cublasOperation_t::CUBLAS_OP_N,
m: p_i,
n: p_i,
k: p_i,
alpha: 1.0,
lda: p_i,
ldb: p_i,
beta: 0.0,
ldc: p_i,
};
unsafe {
ws.blas
.gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
}
.map_err(|e| format!("dgemm Qsᵀ·A·Qs (host-input step): {e}"))?;
}
// h += penalty_step
geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
// The right-hand side is the caller's gradient, uploaded as-is.
let g_slice = input
.gradient
.as_slice()
.ok_or("gradient must be contiguous")?;
ws.stream
.memcpy_htod(g_slice, &mut ws.rhs_dev)
.map_err(|e| format!("upload gradient: {e}"))?;
// Rebuild the exported Hessian on the host with the objective ridge. h_dev
// cannot be read back for this: potrf below factors it in place, and the
// export uses a different ridge.
let xtwx_col = ws
.stream
.clone_dtoh(&ws.xtwx_dev)
.map_err(|e| format!("download XᵀWX (host-input step): {e}"))?;
let xtwx_host = from_col_major(&xtwx_col, p, p).ok_or("XᵀWX layout conversion failed")?;
let qs_col = ws
.stream
.clone_dtoh(&ws.qs_dev)
.map_err(|e| format!("download Qs (host-input step): {e}"))?;
let qs_host =
from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (host-input step)")?;
let tmp_aq = xtwx_host.dot(&qs_host);
let h_rotated = qs_host.t().dot(&tmp_aq);
let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
let penalized_hessian = h_rotated + &penalty_export;
// Factor h_dev in place, then solve into rhs_dev. Failures are recorded in
// the info buffers and checked after the direction download below.
potrf_in_place_reuse(
&ws.solver,
&ws.stream,
p,
ws.potrf_lwork,
&mut ws.h_dev,
&mut ws.potrf_work_dev,
&mut ws.potrf_info_dev,
)?;
potrs_in_place_reuse(
&ws.solver,
&ws.stream,
p,
1,
&ws.h_dev,
&mut ws.rhs_dev,
&mut ws.potrs_info_dev,
)?;
let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
let direction_raw = ws
.stream
.clone_dtoh(&ws.rhs_dev)
.map_err(|e| format!("download direction: {e}"))?;
check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
// penalized_hessian + lm_ridge_delta·I reproduces the factored matrix
// (assuming penalty_with_ridge adds λ·I; confirm), so the refinement
// residual is measured against the system that was actually solved.
let lm_ridge_delta = input.step_lm_lambda - input.objective_ridge;
let direction_raw = newton_step_refine_once(
&ws.solver,
&ws.stream,
p,
&ws.h_dev,
&mut ws.rhs_dev,
&mut ws.potrs_info_dev,
direction_raw,
g_slice,
&penalized_hessian,
lm_ridge_delta,
)?;
let direction = Array1::from_vec(direction_raw);
Ok(PirlsGpuStep {
penalized_hessian,
direction,
logdet,
})
}
/// Assembles and solves one PIRLS Newton step whose per-row inputs already
/// live on the device.
///
/// All device work is queued on `ws.stream`:
/// 1. Build `XᵀWX` from `input.w_solver_dev` and the score projection
///    `Xᵀ·grad_eta`. When `ws.wx_dev` is allocated this uses cuBLAS `ddgmm` +
///    `dgemm` + `dgemv`; otherwise it uses the fused `xtwx_lower`,
///    `symmetrize_lower` and `xtscore` kernels.
/// 2. Build `H = Qsᵀ·XᵀWX·Qs + penalty_with_ridge(S, step_lm_lambda)` in
///    `ws.h_dev`.
/// 3. Build the right-hand side `Qsᵀ·(Xᵀ·grad_eta) − S·β + linear_shift`.
///    The `− S·β + linear_shift` correction is applied on the host.
/// 4. Cholesky-factorize and solve `H`, and take the log-determinant of the
///    factor.
///
/// The returned `penalized_hessian` is rebuilt on the host with
/// `objective_ridge` instead of `step_lm_lambda`. No iterative refinement is
/// applied on this path.
///
/// # Errors
/// Returns an error when any input length or shape disagrees with the shared
/// design (`n` rows, `p` columns). `linear_shift` and `beta_dev` must both
/// have length `p`. An error is also returned when any CUDA, cuBLAS or
/// cuSOLVER call fails, including the deferred potrf/potrs info checks.
pub(super) fn solve_step_on_stream_device(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    input: PirlsStepStreamDeviceInput<'_, '_>,
) -> Result<PirlsGpuStep, String> {
    let n = shared.n;
    let p = shared.p;
    if ws.n != n || ws.p != p {
        return Err(format!(
            "workspace shape ({}, {}) does not match shared design ({n}, {p})",
            ws.n, ws.p
        ));
    }
    if input.w_solver_dev.len() != n {
        return Err(format!(
            "w_solver_dev length {} does not match n={n}",
            input.w_solver_dev.len()
        ));
    }
    if input.grad_eta_dev.len() != n {
        return Err(format!(
            "grad_eta_dev length {} does not match n={n}",
            input.grad_eta_dev.len()
        ));
    }
    if input.penalty_hessian.dim() != (p, p) {
        return Err(format!(
            "penalty Hessian shape {:?} does not match p={p}",
            input.penalty_hessian.dim()
        ));
    }
    // `linear_shift` is added to the length-p RHS with ndarray `+=`. That would
    // silently broadcast a length-1 view and panic on any other mismatch.
    if input.linear_shift.len() != p {
        return Err(format!(
            "linear_shift length {} does not match p={p}",
            input.linear_shift.len()
        ));
    }
    // `beta_dev` is downloaded and multiplied by the p × p penalty on the
    // host; any other length would panic inside `dot`.
    if input.beta_dev.len() != p {
        return Err(format!(
            "beta_dev length {} does not match p={p}",
            input.beta_dev.len()
        ));
    }
    let n_i = to_i32(n)?;
    let p_i = to_i32(p)?;
    if let Some(ref mut wx_dev_fb) = ws.wx_dev {
        // Large p: WX = diag(w)·X, then XᵀWX via dgemm and Xᵀ·grad_eta via dgemv.
        left_scale_rows_borrowed(
            &ws.blas,
            &ws.stream,
            n,
            p,
            &shared.x_original_dev,
            input.w_solver_dev,
            wx_dev_fb,
        )?;
        let gemm_cfg = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: n_i,
            alpha: 1.0,
            lda: n_i,
            ldb: n_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas.gemm(
                gemm_cfg,
                &shared.x_original_dev,
                wx_dev_fb,
                &mut ws.xtwx_dev,
            )
        }
        .map_err(|e| format!("cublas dgemm XtWX (device-input): {e}"))?;
        let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
        let penalty_step_col = to_col_major(&penalty_step);
        ws.stream
            .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
            .map_err(|e| format!("upload penalty (device-input): {e}"))?;
        // qs_tmp = XᵀWX · Qs
        {
            let cfg_aq = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_N,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            unsafe {
                ws.blas
                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
            }
            .map_err(|e| format!("dgemm A·Qs (device-input large-p): {e}"))?;
        }
        // h = Qsᵀ · qs_tmp
        {
            let cfg_qt = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_T,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            unsafe {
                ws.blas
                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
            }
            .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input large-p): {e}"))?;
        }
        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: n_i,
            n: p_i,
            alpha: 1.0,
            lda: n_i,
            incx: 1,
            beta: 0.0,
            incy: 1,
        };
        unsafe {
            ws.blas.gemv(
                gemv_cfg,
                &shared.x_original_dev,
                input.grad_eta_dev,
                &mut ws.rhs_dev,
            )
        }
        .map_err(|e| format!("cublas dgemv Xtg (device-input): {e}"))?;
    } else {
        // Small p: fused kernels for XᵀWX (lower triangle, then mirrored) and Xᵀ·score.
        launch_xtwx_lower(
            &ws.stream,
            &shared.ctx,
            n,
            p,
            &shared.x_original_dev,
            input.w_solver_dev,
            &mut ws.xtwx_dev,
        )?;
        launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
        launch_xtscore(
            &ws.stream,
            &shared.ctx,
            n,
            p,
            &shared.x_original_dev,
            input.grad_eta_dev,
            &mut ws.rhs_dev,
        )?;
        // qs_tmp = XᵀWX · Qs
        {
            let cfg_aq = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_N,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            unsafe {
                ws.blas
                    .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
            }
            .map_err(|e| format!("dgemm A·Qs (device-input fused): {e}"))?;
        }
        // h = Qsᵀ · qs_tmp
        {
            let cfg_qt = GemmConfig::<f64> {
                transa: cublasOperation_t::CUBLAS_OP_T,
                transb: cublasOperation_t::CUBLAS_OP_N,
                m: p_i,
                n: p_i,
                k: p_i,
                alpha: 1.0,
                lda: p_i,
                ldb: p_i,
                beta: 0.0,
                ldc: p_i,
            };
            unsafe {
                ws.blas
                    .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
            }
            .map_err(|e| format!("dgemm Qsᵀ·A·Qs (device-input fused): {e}"))?;
        }
        let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
        let penalty_step_col = to_col_major(&penalty_step);
        ws.stream
            .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
            .map_err(|e| format!("upload penalty (fused device-input): {e}"))?;
        geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
    }
    // Rotate the score into the Qs basis (beta_orig_dev serves as scratch), then
    // apply the host-side `− S·β + linear_shift` correction.
    {
        let cfg_qts = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: p_i,
            n: p_i,
            alpha: 1.0,
            lda: p_i,
            incx: 1,
            beta: 0.0,
            incy: 1,
        };
        unsafe {
            ws.blas
                .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
        }
        .map_err(|e| format!("dgemv Qsᵀ·score (device-input): {e}"))?;
        ws.stream
            .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
            .map_err(|e| format!("d2d Qsᵀ·score→rhs (device-input): {e}"))?;
        let rhs_raw = ws
            .stream
            .clone_dtoh(&ws.rhs_dev)
            .map_err(|e| format!("download Qsᵀscore (device-input): {e}"))?;
        let beta_raw = ws
            .stream
            .clone_dtoh(input.beta_dev)
            .map_err(|e| format!("download beta (device-input): {e}"))?;
        let mut rhs_host = Array1::from_vec(rhs_raw);
        let beta_host = Array1::from_vec(beta_raw);
        let s_beta = input.penalty_hessian.dot(&beta_host);
        rhs_host -= &s_beta;
        rhs_host += &input.linear_shift;
        ws.stream
            .memcpy_htod(
                rhs_host
                    .as_slice()
                    .ok_or("rhs_host not contiguous (device-input correction)")?,
                &mut ws.rhs_dev,
            )
            .map_err(|e| format!("re-upload corrected rhs (device-input): {e}"))?;
    }
    // Host copy of the exported Hessian (objective ridge). This must happen
    // before potrf overwrites h_dev with its factor.
    let xtwx_col = ws
        .stream
        .clone_dtoh(&ws.xtwx_dev)
        .map_err(|e| format!("download XᵀWX (device-input): {e}"))?;
    let xtwx_host = from_col_major(&xtwx_col, p, p)
        .ok_or("XᵀWX layout conversion failed (device-input)")?;
    let qs_col = ws
        .stream
        .clone_dtoh(&ws.qs_dev)
        .map_err(|e| format!("download Qs (device-input): {e}"))?;
    let qs_host =
        from_col_major(&qs_col, p, p).ok_or("Qs layout conversion failed (device-input)")?;
    let tmp_aq = xtwx_host.dot(&qs_host);
    let h_rotated = qs_host.t().dot(&tmp_aq);
    let penalty_export = penalty_with_ridge(input.penalty_hessian, input.objective_ridge);
    let penalized_hessian = h_rotated + &penalty_export;
    potrf_in_place_reuse(
        &ws.solver,
        &ws.stream,
        p,
        ws.potrf_lwork,
        &mut ws.h_dev,
        &mut ws.potrf_work_dev,
        &mut ws.potrf_info_dev,
    )?;
    potrs_in_place_reuse(
        &ws.solver,
        &ws.stream,
        p,
        1,
        &ws.h_dev,
        &mut ws.rhs_dev,
        &mut ws.potrs_info_dev,
    )?;
    let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
    let direction_raw = ws
        .stream
        .clone_dtoh(&ws.rhs_dev)
        .map_err(|e| format!("download direction (device-input): {e}"))?;
    check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
    check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
    let direction = Array1::from_vec(direction_raw);
    Ok(PirlsGpuStep {
        penalized_hessian,
        direction,
        logdet,
    })
}
/// Builds the same system as `solve_step_on_stream_device`, but keeps the
/// result on the device.
///
/// The solved direction is left in `ws.rhs_dev`, and only the log-determinant
/// of the factored step matrix is returned. No host copy of the Hessian is
/// built and no iterative refinement is applied.
///
/// On return, `ws.h_dev` holds the Cholesky factor. `ws.beta_orig_dev` has
/// been used as scratch for `Qsᵀ·score`.
///
/// # Errors
/// Returns an error when any input length or shape disagrees with the shared
/// design. `linear_shift` and `beta_dev` must both have length `p`. An error
/// is also returned when any CUDA, cuBLAS or cuSOLVER call fails, including
/// the deferred potrf/potrs info checks.
pub(super) fn solve_step_on_stream_device_inplace(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    input: PirlsStepStreamDeviceInput<'_, '_>,
) -> Result<f64, String> {
    let n = shared.n;
    let p = shared.p;
    if ws.n != n || ws.p != p {
        return Err(format!(
            "workspace shape ({}, {}) does not match shared design ({n}, {p})",
            ws.n, ws.p
        ));
    }
    if input.w_solver_dev.len() != n {
        return Err(format!(
            "w_solver_dev length {} does not match n={n}",
            input.w_solver_dev.len()
        ));
    }
    if input.grad_eta_dev.len() != n {
        return Err(format!(
            "grad_eta_dev length {} does not match n={n}",
            input.grad_eta_dev.len()
        ));
    }
    if input.penalty_hessian.dim() != (p, p) {
        return Err(format!(
            "penalty Hessian shape {:?} does not match p={p}",
            input.penalty_hessian.dim()
        ));
    }
    if input.linear_shift.len() != p {
        return Err(format!(
            "linear_shift length {} does not match p={p}",
            input.linear_shift.len()
        ));
    }
    // `beta_dev` is downloaded and multiplied by the p × p penalty on the
    // host; any other length would panic inside `dot`.
    if input.beta_dev.len() != p {
        return Err(format!(
            "beta_dev length {} does not match p={p}",
            input.beta_dev.len()
        ));
    }
    let n_i = to_i32(n)?;
    let p_i = to_i32(p)?;
    if let Some(ref mut wx_dev_ib) = ws.wx_dev {
        // Large p: WX = diag(w)·X, then XᵀWX via dgemm and Xᵀ·grad_eta via dgemv.
        left_scale_rows_borrowed(
            &ws.blas,
            &ws.stream,
            n,
            p,
            &shared.x_original_dev,
            input.w_solver_dev,
            wx_dev_ib,
        )?;
        let cfg_xtx = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: n_i,
            alpha: 1.0,
            lda: n_i,
            ldb: n_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas
                .gemm(cfg_xtx, &shared.x_original_dev, wx_dev_ib, &mut ws.xtwx_dev)
        }
        .map_err(|e| format!("dgemm XtWX inplace (large-p): {e}"))?;
        let cfg_xts = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: n_i,
            n: p_i,
            alpha: 1.0,
            lda: n_i,
            incx: 1,
            beta: 0.0,
            incy: 1,
        };
        unsafe {
            ws.blas.gemv(
                cfg_xts,
                &shared.x_original_dev,
                input.grad_eta_dev,
                &mut ws.rhs_dev,
            )
        }
        .map_err(|e| format!("dgemv Xᵀ·score inplace (large-p): {e}"))?;
    } else {
        // Small p: fused kernels for XᵀWX (lower triangle, then mirrored) and Xᵀ·score.
        launch_xtwx_lower(
            &ws.stream,
            &shared.ctx,
            n,
            p,
            &shared.x_original_dev,
            input.w_solver_dev,
            &mut ws.xtwx_dev,
        )?;
        launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
        launch_xtscore(
            &ws.stream,
            &shared.ctx,
            n,
            p,
            &shared.x_original_dev,
            input.grad_eta_dev,
            &mut ws.rhs_dev,
        )?;
    }
    // qs_tmp = XᵀWX · Qs
    {
        let cfg_aq = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_N,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: p_i,
            alpha: 1.0,
            lda: p_i,
            ldb: p_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas
                .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
        }
        .map_err(|e| format!("dgemm A·Qs inplace: {e}"))?;
    }
    // h = Qsᵀ · qs_tmp
    {
        let cfg_qt = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: p_i,
            alpha: 1.0,
            lda: p_i,
            ldb: p_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas
                .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
        }
        .map_err(|e| format!("dgemm Qsᵀ·A·Qs inplace: {e}"))?;
    }
    let penalty_step = penalty_with_ridge(input.penalty_hessian, input.step_lm_lambda);
    let penalty_step_col = to_col_major(&penalty_step);
    ws.stream
        .memcpy_htod(penalty_step_col.as_ref(), &mut ws.penalty_dev)
        .map_err(|e| format!("upload penalty inplace: {e}"))?;
    geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
    // Rotate the score into the Qs basis, using beta_orig_dev as scratch.
    {
        let cfg_qts = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: p_i,
            n: p_i,
            alpha: 1.0,
            lda: p_i,
            incx: 1,
            beta: 0.0,
            incy: 1,
        };
        unsafe {
            ws.blas
                .gemv(cfg_qts, &ws.qs_dev, &ws.rhs_dev, &mut ws.beta_orig_dev)
        }
        .map_err(|e| format!("dgemv Qsᵀ·score inplace: {e}"))?;
        ws.stream
            .memcpy_dtod(&ws.beta_orig_dev, &mut ws.rhs_dev)
            .map_err(|e| format!("d2d Qsᵀ·score→rhs inplace: {e}"))?;
    }
    // Host-side correction: rhs = Qsᵀ·score − S·β + linear_shift.
    let rhs_raw = ws
        .stream
        .clone_dtoh(&ws.rhs_dev)
        .map_err(|e| format!("download Qsᵀ·score inplace: {e}"))?;
    let beta_raw = ws
        .stream
        .clone_dtoh(input.beta_dev)
        .map_err(|e| format!("download beta inplace: {e}"))?;
    let mut rhs_host = Array1::from_vec(rhs_raw);
    let beta_host = Array1::from_vec(beta_raw);
    let s_beta = input.penalty_hessian.dot(&beta_host);
    rhs_host -= &s_beta;
    rhs_host += &input.linear_shift;
    ws.stream
        .memcpy_htod(
            rhs_host.as_slice().ok_or("rhs_host not contiguous")?,
            &mut ws.rhs_dev,
        )
        .map_err(|e| format!("re-upload corrected rhs inplace: {e}"))?;
    potrf_in_place_reuse(
        &ws.solver,
        &ws.stream,
        p,
        ws.potrf_lwork,
        &mut ws.h_dev,
        &mut ws.potrf_work_dev,
        &mut ws.potrf_info_dev,
    )?;
    potrs_in_place_reuse(
        &ws.solver,
        &ws.stream,
        p,
        1,
        &ws.h_dev,
        &mut ws.rhs_dev,
        &mut ws.potrs_info_dev,
    )?;
    let logdet = cholesky_logdet_device(&ws.stream, &shared.ctx, p, &ws.h_dev)?;
    check_deferred_potrf_info(&ws.stream, &ws.potrf_info_dev)?;
    check_deferred_potrs_info(&ws.stream, &ws.potrs_info_dev)?;
    Ok(logdet)
}
/// Rebuilds the final penalized Hessian
/// `Qsᵀ·(XᵀWX)·Qs + penalty_with_ridge(S, objective_ridge)` and returns it as a
/// host `Array2`. `W` comes from `w_hessian_dev`.
///
/// `ws.h_dev` is left holding the same (unfactored) matrix.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the workspace shape differs from `shared`;
/// - `w_hessian_dev` does not have `n` entries;
/// - `penalty_hessian` is not `p × p`;
/// - any CUDA/cuBLAS call or the layout conversion fails.
pub(super) fn rebuild_h_final(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    w_hessian_dev: &CudaSlice<f64>,
    penalty_hessian: ArrayView2<'_, f64>,
    objective_ridge: f64,
) -> Result<Array2<f64>, String> {
    let n = shared.n;
    let p = shared.p;
    // Validate up front. The ddgmm / fused kernel below read `n` weights with no
    // bounds checks of their own, and the penalty is copied into a fixed
    // p × p device buffer. Bad sizes would otherwise misread or corrupt memory.
    if ws.n != n || ws.p != p {
        return Err(format!(
            "workspace shape ({}, {}) does not match shared design ({n}, {p})",
            ws.n, ws.p
        ));
    }
    if w_hessian_dev.len() != n {
        return Err(format!(
            "w_hessian_dev length {} does not match n={n}",
            w_hessian_dev.len()
        ));
    }
    if penalty_hessian.dim() != (p, p) {
        return Err(format!(
            "penalty Hessian shape {:?} does not match p={p}",
            penalty_hessian.dim()
        ));
    }
    let n_i = to_i32(n)?;
    let p_i = to_i32(p)?;
    if let Some(ref mut wx_dev_rh) = ws.wx_dev {
        // Large p: WX = diag(w)·X, then XᵀWX via dgemm.
        left_scale_rows_borrowed(
            &ws.blas,
            &ws.stream,
            n,
            p,
            &shared.x_original_dev,
            w_hessian_dev,
            wx_dev_rh,
        )?;
        let gemm_cfg = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: n_i,
            alpha: 1.0,
            lda: n_i,
            ldb: n_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas.gemm(
                gemm_cfg,
                &shared.x_original_dev,
                wx_dev_rh,
                &mut ws.xtwx_dev,
            )
        }
        .map_err(|e| format!("cublas dgemm XtWX (final H rebuild): {e}"))?;
    } else {
        // Small p: fused lower-triangle kernel, then mirror to the upper half.
        launch_xtwx_lower(
            &ws.stream,
            &shared.ctx,
            n,
            p,
            &shared.x_original_dev,
            w_hessian_dev,
            &mut ws.xtwx_dev,
        )?;
        launch_symmetrize_lower(&ws.stream, &shared.ctx, p, &mut ws.xtwx_dev)?;
    }
    // qs_tmp = XᵀWX · Qs
    {
        let cfg_aq = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_N,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: p_i,
            alpha: 1.0,
            lda: p_i,
            ldb: p_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas
                .gemm(cfg_aq, &ws.xtwx_dev, &ws.qs_dev, &mut ws.qs_tmp_dev)
        }
        .map_err(|e| format!("dgemm A·Qs (final H rebuild): {e}"))?;
    }
    // h = Qsᵀ · qs_tmp
    {
        let cfg_qt = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: p_i,
            n: p_i,
            k: p_i,
            alpha: 1.0,
            lda: p_i,
            ldb: p_i,
            beta: 0.0,
            ldc: p_i,
        };
        unsafe {
            ws.blas
                .gemm(cfg_qt, &ws.qs_dev, &ws.qs_tmp_dev, &mut ws.h_dev)
        }
        .map_err(|e| format!("dgemm Qsᵀ·A·Qs (final H rebuild): {e}"))?;
    }
    let penalty = penalty_with_ridge(penalty_hessian, objective_ridge);
    let penalty_col = to_col_major(&penalty);
    ws.stream
        .memcpy_htod(penalty_col.as_ref(), &mut ws.penalty_dev)
        .map_err(|e| format!("upload penalty (final H rebuild): {e}"))?;
    geam_add_inplace(&ws.blas, &ws.stream, p, &mut ws.h_dev, &ws.penalty_dev)?;
    let h_col = ws
        .stream
        .clone_dtoh(&ws.h_dev)
        .map_err(|e| format!("download H_final: {e}"))?;
    from_col_major(&h_col, p, p).ok_or_else(|| "H_final layout conversion failed".to_string())
}
/// Computes `XᵀWX` with `W = diag(weights)` on the GPU and returns it as a
/// `p × p` host array.
///
/// This is a one-shot helper. It uploads `x` (column-major) and `weights`,
/// runs cuBLAS `ddgmm` (`WX = diag(w)·X`) followed by `dgemm` (`Xᵀ·WX`), and
/// downloads the result.
///
/// # Errors
/// Returns an error for an empty design, a weights/rows mismatch,
/// non-contiguous `weights`, a size overflow, or any CUDA/cuBLAS failure.
pub(super) fn weighted_crossprod(
    x: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    // Validate all host input before touching CUDA. Bad input is then reported
    // as such (and not masked by a device/driver error), matching `solve_step`.
    let (n, p) = validate_design(x, weights)?;
    let w_host = weights.as_slice().ok_or("weights must be contiguous")?;
    let (_, stream) = context_and_stream()?;
    let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
    let x_col = to_col_major(&x);
    let x_dev = pinned_htod(&stream, &x_col)?;
    let mut w_dev = pinned_htod(&stream, w_host)?;
    let mut wx_dev = stream
        .alloc_zeros::<f64>(n.checked_mul(p).ok_or("X size overflow")?)
        .map_err(|e| format!("cuda alloc WX: {e}"))?;
    left_scale_rows(&blas, &stream, n, p, &x_dev, &mut w_dev, &mut wx_dev)?;
    let mut h_dev = stream
        .alloc_zeros::<f64>(p.checked_mul(p).ok_or("H size overflow")?)
        .map_err(|e| format!("cuda alloc H: {e}"))?;
    let n_i = to_i32(n)?;
    let p_i = to_i32(p)?;
    let cfg = GemmConfig::<f64> {
        transa: cublasOperation_t::CUBLAS_OP_T,
        transb: cublasOperation_t::CUBLAS_OP_N,
        m: p_i,
        n: p_i,
        k: n_i,
        alpha: 1.0,
        lda: n_i,
        ldb: n_i,
        beta: 0.0,
        ldc: p_i,
    };
    unsafe { blas.gemm(cfg, &x_dev, &wx_dev, &mut h_dev) }
        .map_err(|e| format!("cublas dgemm XtWX: {e}"))?;
    let h_col = stream
        .clone_dtoh(&h_dev)
        .map_err(|e| format!("download H: {e}"))?;
    from_col_major(&h_col, p, p).ok_or_else(|| "H layout conversion failed".to_string())
}
pub(super) fn solve_step(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
let (_, p) = validate_design(input.x, input.weights)?;
if input.penalty_hessian.dim() != (p, p) {
return Err(format!(
"penalty Hessian shape {:?} does not match p={p}",
input.penalty_hessian.dim()
));
}
if input.gradient.len() != p {
return Err(format!(
"gradient length {} does not match p={p}",
input.gradient.len()
));
}
let n_rows = input.x.nrows();
let zero_n = ndarray::Array1::<f64>::zeros(n_rows);
let shared =
PirlsGpuSharedData::upload_impl(input.x, zero_n.view(), zero_n.view(), zero_n.view())?;
let mut ws = SigmaPirlsGpuWorkspace::allocate_impl(&shared)?;
solve_step_on_stream(
&shared,
&mut ws,
PirlsStepStreamInput {
weights: input.weights,
penalty_hessian: input.penalty_hessian,
gradient: input.gradient,
step_lm_lambda: input.step_lm_lambda,
objective_ridge: input.objective_ridge,
},
)
}
/// Checks that `weights` has one entry per row of `x` and that `x` is non-empty.
///
/// Returns `(n, p)` on success. The length check runs first, so an empty design
/// with mismatched weights reports the length error.
fn validate_design(
    x: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Result<(usize, usize), String> {
    let (n, p) = x.dim();
    let weights_len = weights.len();
    if weights_len != n {
        Err(format!(
            "weights length {} does not match rows {n}",
            weights_len
        ))
    } else if n == 0 || p == 0 {
        Err("empty design cannot be solved on CUDA".to_string())
    } else {
        Ok((n, p))
    }
}
/// Computes `WX = diag(w)·X` on the device with `cublasDdgmm`
/// (`CUBLAS_SIDE_LEFT`). `X` is an `n × p` column-major matrix.
///
/// The caller must supply `w_dev` with at least `n` entries and `wx_dev` with
/// at least `n·p`; this is not checked here.
fn left_scale_rows(
    blas: &CudaBlas,
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    n: usize,
    p: usize,
    x_dev: &CudaSlice<f64>,
    w_dev: &mut CudaSlice<f64>,
    wx_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    let rows = to_i32(n)?;
    let cols = to_i32(p)?;
    let handle = *blas.handle();
    // The access guards are bound to names (not `_`) so they are dropped only
    // after the cuBLAS call below.
    let (x_ptr, _x_guard) = x_dev.device_ptr(stream);
    let (w_ptr, _w_guard) = w_dev.device_ptr(stream);
    let (wx_ptr, _wx_guard) = wx_dev.device_ptr_mut(stream);
    // SAFETY: the pointers come from live CudaSlices borrowed for this whole
    // call; buffer sizes are the caller's responsibility (see doc comment).
    let status = unsafe {
        cublasDdgmm(
            handle,
            cublasSideMode_t::CUBLAS_SIDE_LEFT,
            rows,
            cols,
            x_ptr as *const f64,
            rows,
            w_ptr as *const f64,
            1,
            wx_ptr as *mut f64,
            rows,
        )
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(format!("cublasDdgmm failed with {status:?}"));
    }
    Ok(())
}
/// Same as `left_scale_rows`, but takes the weights by shared reference. This
/// lets weights owned by a caller (e.g. `PirlsStepStreamDeviceInput`) be used
/// directly.
///
/// Computes `WX = diag(w)·X` with `cublasDdgmm` (`CUBLAS_SIDE_LEFT`) for a
/// column-major `n × p` `X`. The caller must supply `w_dev` with at least `n`
/// entries and `wx_dev` with at least `n·p`.
fn left_scale_rows_borrowed(
    blas: &CudaBlas,
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    n: usize,
    p: usize,
    x_dev: &CudaSlice<f64>,
    w_dev: &CudaSlice<f64>,
    wx_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    let m_rows = to_i32(n)?;
    let n_cols = to_i32(p)?;
    let cublas_handle = *blas.handle();
    // Keep the access guards alive until after the cuBLAS call.
    let (src_ptr, _src_guard) = x_dev.device_ptr(stream);
    let (diag_ptr, _diag_guard) = w_dev.device_ptr(stream);
    let (dst_ptr, _dst_guard) = wx_dev.device_ptr_mut(stream);
    // SAFETY: every pointer is derived from a CudaSlice borrowed for the whole
    // call; buffer sizes are the caller's responsibility (see doc comment).
    let status = unsafe {
        cublasDdgmm(
            cublas_handle,
            cublasSideMode_t::CUBLAS_SIDE_LEFT,
            m_rows,
            n_cols,
            src_ptr as *const f64,
            m_rows,
            diag_ptr as *const f64,
            1,
            dst_ptr as *mut f64,
            m_rows,
        )
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(format!("cublasDdgmm (borrowed) failed with {status:?}"));
    }
    Ok(())
}
/// Compute `a ← 1.0·a + 1.0·b` for two `p × p` device matrices with one `cublasDgeam` call.
///
/// `a` is passed as both the first input and the output: the same device pointer is used for
/// A and C, with the same leading dimension `p`, and both operands untransposed.
/// NOTE(review): this relies on cuBLAS's in-place mode for geam (C == A with ldc == lda and
/// transa == N), which these arguments satisfy. Confirm against the cuBLAS docs if the
/// transpose flags ever change.
///
/// # Errors
/// Fails if `p` does not fit in `i32` or cuBLAS returns a non-success status.
fn geam_add_inplace(
blas: &CudaBlas,
stream: &std::sync::Arc<cudarc::driver::CudaStream>,
p: usize,
a: &mut CudaSlice<f64>,
b: &CudaSlice<f64>,
) -> Result<(), String> {
let p_i = to_i32(p)?;
let alpha = 1.0_f64;
let beta = 1.0_f64;
let handle = *blas.handle();
// The second tuple element is bound to a named local so it lives until after the call.
let (b_ptr, _b_record) = b.device_ptr(stream);
let (a_ptr, _a_record) = a.device_ptr_mut(stream);
// Output aliases the first input: the sum is written back into `a`.
let out_ptr = a_ptr;
let status = unsafe {
cublasDgeam(
handle,
cublasOperation_t::CUBLAS_OP_N,
cublasOperation_t::CUBLAS_OP_N,
p_i,
p_i,
&alpha,
a_ptr as *const f64,
p_i,
&beta,
b_ptr as *const f64,
p_i,
out_ptr as *mut f64,
p_i,
)
};
if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
Ok(())
} else {
Err(format!("cublasDgeam failed with {status:?}"))
}
}
/// Launch the `xtwx_lower` kernel of the `fused_xtwx` module, writing into `a_dev`.
///
/// The grid is sized for `p (p + 1) / 2` work items (the lower triangle of a `p × p`
/// matrix, diagonal included), in blocks of 256 threads. What the kernel writes is defined
/// by `FUSED_XTWX_PTX_SOURCE`, which is not visible here. Presumably it is the lower part of
/// `Xᵀ diag(w) X`; confirm there.
///
/// # Errors
/// Module compile/load failure, `n`/`p` overflowing `i32`, the pair count overflowing `u32`,
/// or a launch failure.
fn launch_xtwx_lower(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
    n: usize,
    p: usize,
    x_dev: &CudaSlice<f64>,
    w_dev: &CudaSlice<f64>,
    a_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    const BLOCK: u32 = 256;
    let xtwx_module = FUSED_XTWX_CACHE
        .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
        .map_err(|e| format!("fused_xtwx module: {e}"))?;
    let kernel = xtwx_module
        .load_function("xtwx_lower")
        .map_err(|e| format!("load xtwx_lower: {e}"))?;
    let rows = to_i32(n)?;
    let cols = to_i32(p)?;
    let num_pairs = p * (p + 1) / 2;
    let blocks = u32::try_from(num_pairs)
        .map_err(|_| format!("xtwx_lower: num_pairs {num_pairs} > u32"))?
        .div_ceil(BLOCK)
        .max(1);
    let mut args = stream.launch_builder(&kernel);
    args.arg(x_dev);
    args.arg(w_dev);
    args.arg(a_dev);
    args.arg(&rows);
    args.arg(&cols);
    let launch_cfg = cudarc::driver::LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: argument order and types must match the `xtwx_lower` signature in
    // FUSED_XTWX_PTX_SOURCE (not visible here).
    unsafe { args.launch(launch_cfg) }.map_err(|e| format!("xtwx_lower launch: {e}"))?;
    Ok(())
}
/// Launch the `xtscore` kernel of the `fused_xtwx` module, writing into `s_dev`.
///
/// The grid is sized for `p` work items in blocks of 256 threads. Presumably the kernel
/// computes `Xᵀ · score` (one output per column); the kernel body lives in
/// `FUSED_XTWX_PTX_SOURCE`, which is not visible here.
///
/// # Errors
/// Module compile/load failure, `n`/`p` overflowing `i32`/`u32`, or a launch failure.
fn launch_xtscore(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
    n: usize,
    p: usize,
    x_dev: &CudaSlice<f64>,
    score_dev: &CudaSlice<f64>,
    s_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    const BLOCK: u32 = 256;
    let xtwx_module = FUSED_XTWX_CACHE
        .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
        .map_err(|e| format!("fused_xtwx module (xtscore): {e}"))?;
    let kernel = xtwx_module
        .load_function("xtscore")
        .map_err(|e| format!("load xtscore: {e}"))?;
    let rows = to_i32(n)?;
    let cols = to_i32(p)?;
    let blocks = u32::try_from(p)
        .map_err(|_| format!("xtscore: p {p} > u32"))?
        .div_ceil(BLOCK)
        .max(1);
    let mut args = stream.launch_builder(&kernel);
    args.arg(x_dev);
    args.arg(score_dev);
    args.arg(s_dev);
    args.arg(&rows);
    args.arg(&cols);
    let launch_cfg = cudarc::driver::LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: argument order and types must match the `xtscore` signature in
    // FUSED_XTWX_PTX_SOURCE (not visible here).
    unsafe { args.launch(launch_cfg) }.map_err(|e| format!("xtscore launch: {e}"))?;
    Ok(())
}
/// Launch the `symmetrize_lower` kernel of the `fused_xtwx` module on a `p × p` matrix.
///
/// The grid is sized for the `p (p - 1) / 2` strictly-lower entries, in blocks of 256
/// threads. For `p <= 1` nothing is launched and `Ok(())` is returned immediately. The
/// kernel body (presumably mirroring the lower triangle into the upper) is in
/// `FUSED_XTWX_PTX_SOURCE`, which is not visible here.
///
/// # Errors
/// Module compile/load failure, `p` overflowing `i32`, the count overflowing `u32`, or a
/// launch failure.
fn launch_symmetrize_lower(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
    p: usize,
    a_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    const BLOCK: u32 = 256;
    // A 0×0 or 1×1 matrix is already symmetric.
    if p < 2 {
        return Ok(());
    }
    let xtwx_module = FUSED_XTWX_CACHE
        .get_or_compile(ctx, "fused_xtwx", FUSED_XTWX_PTX_SOURCE)
        .map_err(|e| format!("fused_xtwx module (sym): {e}"))?;
    let kernel = xtwx_module
        .load_function("symmetrize_lower")
        .map_err(|e| format!("load symmetrize_lower: {e}"))?;
    let dim = to_i32(p)?;
    let num_strict = p * (p - 1) / 2;
    let blocks = u32::try_from(num_strict)
        .map_err(|_| format!("symmetrize_lower: num_strict {num_strict} > u32"))?
        .div_ceil(BLOCK)
        .max(1);
    let mut args = stream.launch_builder(&kernel);
    args.arg(a_dev);
    args.arg(&dim);
    let launch_cfg = cudarc::driver::LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: argument order and types must match the `symmetrize_lower` signature in
    // FUSED_XTWX_PTX_SOURCE (not visible here).
    unsafe { args.launch(launch_cfg) }.map_err(|e| format!("symmetrize_lower launch: {e}"))?;
    Ok(())
}
/// Run the `chol_logdet_col_major` kernel on a device-resident `p × p` factor and return its
/// scalar result.
///
/// The kernel is launched with a single thread (1×1 grid and block). It writes one `f64`
/// into a freshly allocated device scalar, which is then copied to the host. The exact
/// formula lives in `CHOL_LOGDET_PTX_SOURCE` (not visible here). Presumably it is the
/// log-determinant implied by a column-major Cholesky factor; confirm there.
///
/// # Errors
/// Module compile/load failure, allocation failure, `p` overflowing `i32`, a launch failure,
/// or a failed download.
fn cholesky_logdet_device(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    ctx: &std::sync::Arc<cudarc::driver::CudaContext>,
    p: usize,
    factor_dev: &CudaSlice<f64>,
) -> Result<f64, String> {
    let chol_module = CHOL_LOGDET_CACHE
        .get_or_compile(ctx, "pirls_gpu_chol_logdet", CHOL_LOGDET_PTX_SOURCE)
        .map_err(|err| format!("chol_logdet module: {err}"))?;
    let kernel = chol_module
        .load_function("chol_logdet_col_major")
        .map_err(|err| format!("chol_logdet load_function: {err}"))?;
    // One-element device buffer that receives the kernel's scalar output.
    let mut result_dev = stream
        .alloc_zeros::<f64>(1)
        .map_err(|err| format!("alloc chol_logdet out: {err}"))?;
    let dim = to_i32(p)?;
    let single_thread = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut args = stream.launch_builder(&kernel);
    args.arg(factor_dev);
    args.arg(&dim);
    args.arg(&mut result_dev);
    // SAFETY: argument order and types must match `chol_logdet_col_major` in
    // CHOL_LOGDET_PTX_SOURCE (not visible here); `factor_dev` must hold `p * p` entries.
    unsafe { args.launch(single_thread) }.map_err(|err| format!("chol_logdet launch: {err}"))?;
    stream
        .clone_dtoh(&result_dev)
        .map(|host| host[0])
        .map_err(|err| format!("download chol_logdet: {err}"))
}
/// Return an owned copy of `penalty` with `ridge` added to every main-diagonal entry.
///
/// For non-square inputs only the first `min(rows, cols)` diagonal entries are touched.
/// A `ridge` that compares equal to `0.0` (including `-0.0`) leaves the copy unchanged.
fn penalty_with_ridge(penalty: ArrayView2<'_, f64>, ridge: f64) -> Array2<f64> {
    let mut shifted = penalty.to_owned();
    if ridge == 0.0 {
        return shifted;
    }
    for entry in shifted.diag_mut() {
        *entry += ridge;
    }
    shifted
}
/// Narrow a `usize` dimension to the `i32` expected by the CUDA/cuBLAS/cuSOLVER C APIs.
///
/// # Errors
/// Returns a descriptive message when `value` exceeds `i32::MAX`.
fn to_i32(value: usize) -> Result<i32, String> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(format!("CUDA dimension {value} exceeds i32")),
    }
}
/// CUDA C source for the helper kernels used by `pirls_loop`. It is compiled at runtime
/// through `PIRLS_LOOP_CACHE` under the module name `"pirls_loop"`.
///
/// Kernels:
/// * `axpy_n(alpha, x, y, n)`: element-wise `y[i] += alpha * x[i]`, one thread per element.
/// * `deviance_sum(d, n, out)`: single-block sum of `d[0..n)` into `out[0]`.
/// * `linf_norm(v, p, out)`: single-block max of `|v[i]|` into `out[0]` (0.0 when `p == 0`).
///   NOTE(review): the `a > acc` comparison is false for NaN, so NaN entries are skipped
///   rather than propagated.
/// * `negate_n(v, n)`: element-wise `v[i] = -v[i]`. Not loaded by `pirls_loop` in this
///   section of the file.
/// * `status_or(status, n, out)`: single-block bitwise OR of `status[0..n)` into `out[0]`.
///
/// The three reductions ignore `blockIdx` and stage partials in a fixed 1024-entry
/// `__shared__` buffer, halving the stride each round. They must therefore be launched as
/// exactly one block whose `blockDim.x` is a power of two no larger than 1024. The host
/// wrappers use 1024.
const PIRLS_LOOP_PTX_SOURCE: &str = r#"
extern "C" {
double fabs(double);
}
extern "C" __global__ void axpy_n(
double alpha,
const double* __restrict__ x,
double* __restrict__ y,
int n
) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
y[i] += alpha * x[i];
}
extern "C" __global__ void deviance_sum(
const double* __restrict__ d,
int n,
double* __restrict__ out
) {
__shared__ double sm[1024];
int tid = threadIdx.x;
int bdim = blockDim.x;
double acc = 0.0;
for (int i = tid; i < n; i += bdim) {
acc += d[i];
}
sm[tid] = acc;
__syncthreads();
for (int stride = bdim / 2; stride > 0; stride >>= 1) {
if (tid < stride) sm[tid] += sm[tid + stride];
__syncthreads();
}
if (tid == 0) out[0] = sm[0];
}
extern "C" __global__ void linf_norm(
const double* __restrict__ v,
int p,
double* __restrict__ out
) {
__shared__ double sm[1024];
int tid = threadIdx.x;
int bdim = blockDim.x;
double acc = 0.0;
for (int i = tid; i < p; i += bdim) {
double a = fabs(v[i]);
if (a > acc) acc = a;
}
sm[tid] = acc;
__syncthreads();
for (int stride = bdim / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
double r = sm[tid + stride];
if (r > sm[tid]) sm[tid] = r;
}
__syncthreads();
}
if (tid == 0) out[0] = sm[0];
}
extern "C" __global__ void negate_n(
double* __restrict__ v,
int n
) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
v[i] = -v[i];
}
// OR-reduction over a u32 status array (length n). Single-block;
// same launch config as deviance_sum (1 block of 1024 threads).
// out[0] receives the bitwise-OR of all status[i] for i in [0, n).
extern "C" __global__ void status_or(
const unsigned int* __restrict__ status,
int n,
unsigned int* __restrict__ out
) {
__shared__ unsigned int sm[1024];
int tid = threadIdx.x;
int bdim = blockDim.x;
unsigned int acc = 0u;
for (int i = tid; i < n; i += bdim) {
acc |= status[i];
}
sm[tid] = acc;
__syncthreads();
for (int stride = bdim / 2; stride > 0; stride >>= 1) {
if (tid < stride) sm[tid] |= sm[tid + stride];
__syncthreads();
}
if (tid == 0) out[0] = sm[0];
}
"#;
/// Compiled-module cache for `PIRLS_LOOP_PTX_SOURCE`, queried via `get_or_compile` in
/// `pirls_loop`.
static PIRLS_LOOP_CACHE: PtxModuleCache = PtxModuleCache::new();
/// Device buffers reused across iterations of `pirls_loop` for an `n`-row, `p`-column
/// problem. Built by `PirlsLoopWorkspace::allocate`.
pub struct PirlsLoopWorkspace {
// Current coefficients (length p), in the basis that `pirls_loop` maps through `Qs`.
pub beta_dev: CudaSlice<f64>,
// Linear predictor (length n).
pub eta_dev: CudaSlice<f64>,
// Per-row outputs of the row solve (deviance, w_solver, grad_eta are read by pirls_loop).
pub row_solve: crate::gpu::pirls_row::SolveRowBuffers,
// Per-ladder-step objective/status buffers filled by the alpha-ladder line search.
pub alpha_ladder: crate::gpu::pirls_row::AlphaLadderDevBuffers,
// Per-row quantities recomputed at the final eta (mu, grad_eta, w_hessian, w_solver, status).
pub row_final: crate::gpu::pirls_row::RowOutputDevBuffers,
// Step direction (length p), copied from the inner solver's rhs buffer.
pub direction_dev: CudaSlice<f64>,
// X · Qs · direction (length n): change in eta per unit step.
pub xd_dev: CudaSlice<f64>,
// One-element f64 scratch for single-block reductions.
pub scalar_dev: CudaSlice<f64>,
// One-element u32 scratch for the status OR-reduction.
pub status_u32_dev: CudaSlice<u32>,
// Row count this workspace was sized for.
pub n: usize,
// Column count this workspace was sized for.
pub p: usize,
}
impl PirlsLoopWorkspace {
pub fn allocate(
shared: &PirlsGpuSharedData,
stream: &std::sync::Arc<cudarc::driver::CudaStream>,
) -> Result<Self, String> {
let n = shared.n;
let p = shared.p;
let alloc_f64 = |label: &'static str, len: usize| {
stream
.alloc_zeros::<f64>(len)
.map_err(|e| format!("pirls loop alloc {label}: {e}"))
};
Ok(Self {
beta_dev: alloc_f64("beta", p)?,
eta_dev: alloc_f64("eta", n)?,
row_solve: crate::gpu::pirls_row::SolveRowBuffers::allocate(stream, n)
.map_err(|e| format!("pirls loop alloc row_solve: {e}"))?,
alpha_ladder: crate::gpu::pirls_row::AlphaLadderDevBuffers::allocate(stream)
.map_err(|e| format!("pirls loop alloc alpha_ladder: {e}"))?,
row_final: crate::gpu::pirls_row::RowOutputDevBuffers::allocate(stream, n)
.map_err(|e| format!("pirls loop alloc row_final: {e}"))?,
direction_dev: alloc_f64("direction", p)?,
xd_dev: alloc_f64("xd", n)?,
scalar_dev: alloc_f64("scalar", 1)?,
status_u32_dev: stream
.alloc_zeros::<u32>(1)
.map_err(|e| format!("pirls loop alloc status_u32: {e}"))?,
n,
p,
})
}
}
/// Optional host-side context for `pirls_loop`.
///
/// When supplied, `build_loop_outcome` uses it to compute working-weight derivatives,
/// curvature arrays, and constraint diagnostics for the final iterate.
pub struct PirlsLoopExtra<'a> {
// Likelihood passed to the host derivative/curvature helpers.
pub likelihood: &'a crate::types::GlmLikelihoodSpec,
// Inverse link passed to the host derivative/curvature helpers.
pub inverse_link: &'a crate::types::InverseLink,
// Response, used for observed-curvature arrays.
pub y: ndarray::ArrayView1<'a, f64>,
// Prior weights, used by the host derivative/curvature helpers.
pub priorweights: ndarray::ArrayView1<'a, f64>,
// Offset, copied verbatim into `PirlsLoopOutcome::final_offset`.
pub offset: ndarray::ArrayView1<'a, f64>,
// Linear inequality constraints; when present with at least one row, KKT diagnostics are computed.
pub linear_constraints: Option<&'a crate::solver::active_set::LinearInequalityConstraints>,
// Selects Observed vs Fisher curvature arrays for the outcome.
pub exported_curvature: crate::solver::pirls::HessianCurvatureKind,
// Overrides the default scaled-identity ridge passport when set.
pub ridge_passport: Option<crate::types::RidgePassport>,
// Firth diagnostics to report; `Inactive` when `None`.
pub firth: Option<crate::solver::pirls::FirthDiagnostics>,
// Not read by `build_loop_outcome` in this section of the file.
pub qs: Option<ndarray::ArrayView2<'a, f64>>,
// Effective degrees of freedom to report; NaN when `None`.
pub edf: Option<f64>,
}
/// Host-side result of a device PIRLS run, assembled by `build_loop_outcome`.
///
/// Derivative and curvature arrays are empty (length 0) and `derivatives_unsupported` is
/// `true` when the loop was run without a `PirlsLoopExtra`.
#[derive(Clone, Debug)]
pub struct PirlsLoopOutcome {
// Final coefficients downloaded from the loop workspace.
pub beta: Array1<f64>,
// Penalized Hessian rebuilt at the final iterate.
pub penalized_hessian: Array2<f64>,
// Log-determinant as passed in by the loop (from its last inner solve).
pub logdet: f64,
// Total deviance at the last accepted iterate.
pub deviance: f64,
// Number of loop iterations reported by the loop.
pub iterations: usize,
// Whether the convergence test passed.
pub converged: bool,
// Final linear predictor.
pub final_eta: Array1<f64>,
// Final per-row mean.
pub final_mu: Array1<f64>,
// Final per-row gradient with respect to eta.
pub final_grad_eta: Array1<f64>,
// Final per-row Hessian weights.
pub final_w_hessian: Array1<f64>,
// Final per-row solver weights.
pub final_w_solver: Array1<f64>,
// Copy of the caller's offset (empty without extra context).
pub final_offset: Array1<f64>,
// Same values as `beta`.
pub beta_transformed: Array1<f64>,
// Curvature-dependent final weights (empty without extra context).
pub finalweights: Array1<f64>,
// Same values as `final_w_solver`.
pub solveweights: Array1<f64>,
// dmu/deta at the final eta (empty without extra context).
pub solve_dmu_deta: Array1<f64>,
// d2mu/deta2 at the final eta (empty without extra context).
pub solve_d2mu_deta2: Array1<f64>,
// d3mu/deta3 at the final eta (empty without extra context).
pub solve_d3mu_deta3: Array1<f64>,
// Curvature "c" array (empty without extra context).
pub solve_c_array: Array1<f64>,
// Curvature "d" array (empty without extra context).
pub solve_d_array: Array1<f64>,
// True when no extra context was supplied and the derivative arrays are empty.
pub derivatives_unsupported: bool,
// Unstable / Converged / LmStepSearchExhausted / MaxIterationsReached.
pub status: crate::solver::pirls::PirlsStatus,
// Ridge metadata (caller override or scaled identity at the objective ridge).
pub ridge_passport: crate::types::RidgePassport,
// Firth diagnostics (Inactive unless supplied).
pub firth: crate::solver::pirls::FirthDiagnostics,
// Constraint KKT diagnostics when linear constraints were supplied.
pub constraint_kkt: Option<crate::solver::active_set::ConstraintKktDiagnostics>,
// Effective degrees of freedom (NaN unless supplied).
pub edf: f64,
// Absolute change of the penalized objective on the last accepted step (not deviance alone).
pub last_deviance_change: f64,
// Index into the alpha ladder of the last accepted step.
pub last_step_halving: usize,
// Last accepted step length alpha (0.0 when the search was exhausted).
pub last_step_size: f64,
// Step LM lambda the loop was run with.
pub final_lm_lambda: f64,
// Smallest deviance observed across accepted iterates (including the start).
pub min_deviance: f64,
// max |eta| over the final eta (NaN entries are ignored by f64::max).
pub max_abs_eta: f64,
// Bitwise OR of the final per-row status flags.
pub per_row_status_or: u32,
}
/// Run the full PIRLS (penalized iteratively reweighted least squares) loop on the GPU.
///
/// Coefficients live in a transformed basis. Each evaluation maps them back with
/// `beta_orig = Qs · beta` (using `ws.qs_dev`) and forms `eta = X · beta_orig + offset`.
/// The penalized objective tracked on the host is
/// `sum(row deviance) + βᵀSβ − 2βᵀ·linear_shift + constant_shift`, where
/// `S = penalty_hessian`.
///
/// Each iteration does three things:
/// 1. Solves for a step direction `d` on device (`solve_step_on_stream_device_inplace`,
///    using `lm_ridge` as the step LM lambda and `objective_ridge`).
/// 2. Evaluates the row deviance along `eta + α · X·Qs·d` for every α in
///    `crate::gpu::pirls_row::ALPHA_LADDER`.
/// 3. Accepts the first α whose objective is finite, does not exceed the current objective,
///    and whose rows carry no `INVALID_RESPONSE` / `ZERO_PRIOR_WEIGHT` flags.
///
/// The loop stops in one of three ways:
/// * no ladder step is acceptable (step search exhausted);
/// * the convergence test passes: `‖d‖∞ ≤ tol`, `|α|·‖d‖∞ ≤ tol`, and an objective change
///   of at most `tol·(1 + |objective|)`;
/// * `max_iter` iterations have run.
///
/// Afterwards, per-row quantities are recomputed at the final `eta`, the Hessian is rebuilt
/// with `rebuild_h_final`, and the result is assembled by `build_loop_outcome`.
///
/// The reported `iterations` is the number of loop iterations actually executed, including
/// the terminating one (convergence or exhausted step search). It equals `max_iter` only
/// when the iteration budget was used up.
///
/// # Errors
/// * Dimension mismatches between `shared`, `loop_ws`, `beta0_host`, `linear_shift`, and
///   `penalty_hessian`.
/// * A non-contiguous `beta0_host`.
/// * Any CUDA module-load, launch, copy, or inner-solve failure.
pub(super) fn pirls_loop(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    loop_ws: &mut PirlsLoopWorkspace,
    family: crate::gpu::pirls_row::PirlsRowFamily,
    curvature: crate::gpu::pirls_row::CurvatureMode,
    gamma_shape: f64,
    beta0_host: ArrayView1<'_, f64>,
    penalty_hessian: ArrayView2<'_, f64>,
    linear_shift: ArrayView1<'_, f64>,
    constant_shift: f64,
    lm_ridge: f64,
    objective_ridge: f64,
    max_iter: usize,
    tol: f64,
    extra: Option<&PirlsLoopExtra<'_>>,
) -> Result<PirlsLoopOutcome, String> {
    let n = shared.n;
    let p = shared.p;
    if loop_ws.n != n || loop_ws.p != p {
        return Err(format!(
            "loop workspace ({}, {}) ≠ shared ({n}, {p})",
            loop_ws.n, loop_ws.p
        ));
    }
    if beta0_host.len() != p {
        return Err(format!("beta0 length {} ≠ p={p}", beta0_host.len()));
    }
    if linear_shift.len() != p {
        return Err(format!(
            "linear_shift length {} ≠ p={p}",
            linear_shift.len()
        ));
    }
    if penalty_hessian.dim() != (p, p) {
        return Err(format!(
            "penalty_hessian shape {:?} ≠ (p={p}, p={p})",
            penalty_hessian.dim()
        ));
    }
    ws.stream
        .memcpy_htod(
            beta0_host.as_slice().ok_or("beta0 not contiguous")?,
            &mut loop_ws.beta_dev,
        )
        .map_err(|e| format!("upload beta0: {e}"))?;
    let backend = crate::gpu::pirls_row::PirlsRowBackend::probe()
        .map_err(|e| format!("pirls_row backend: {e}"))?;
    let loop_module = PIRLS_LOOP_CACHE
        .get_or_compile(&shared.ctx, "pirls_loop", PIRLS_LOOP_PTX_SOURCE)
        .map_err(|e| format!("pirls loop module: {e}"))?;
    let axpy_func = loop_module
        .load_function("axpy_n")
        .map_err(|e| format!("load axpy_n: {e}"))?;
    let sum_func = loop_module
        .load_function("deviance_sum")
        .map_err(|e| format!("load deviance_sum: {e}"))?;
    let linf_func = loop_module
        .load_function("linf_norm")
        .map_err(|e| format!("load linf_norm: {e}"))?;
    let status_or_func = loop_module
        .load_function("status_or")
        .map_err(|e| format!("load status_or: {e}"))?;
    // Starting point: eta = X · (Qs · beta0) + offset, then per-row deviance and weights.
    gemv_no_trans(
        &ws.blas,
        p,
        p,
        &ws.qs_dev,
        &loop_ws.beta_dev,
        &mut ws.beta_orig_dev,
    )?;
    gemv_no_trans(
        &ws.blas,
        n,
        p,
        &shared.x_original_dev,
        &ws.beta_orig_dev,
        &mut loop_ws.eta_dev,
    )?;
    axpy(
        &ws.stream,
        &axpy_func,
        1.0,
        &shared.offset_dev,
        &mut loop_ws.eta_dev,
        n,
    )?;
    crate::gpu::pirls_row::launch_solve_row_on_stream(
        backend,
        family,
        curvature,
        gamma_shape,
        &ws.stream,
        n,
        &loop_ws.eta_dev,
        &shared.y_dev,
        &shared.prior_w_dev,
        &mut loop_ws.row_solve,
    )
    .map_err(|e| format!("solve-row init: {e}"))?;
    let mut prev_deviance = reduce_scalar(
        &ws.stream,
        &sum_func,
        &loop_ws.row_solve.deviance,
        n,
        &mut loop_ws.scalar_dev,
        "deviance_init",
    )?;
    // NOTE(review): last_logdet comes from the factorization at the start of the final
    // iteration, i.e. before the accepted step. It is not recomputed for `h_final`; confirm
    // callers expect that.
    let mut last_logdet = 0.0_f64;
    let mut converged = false;
    // Loop iterations actually executed, including the terminating one.
    let mut iterations: usize = 0;
    let mut beta_host: Array1<f64> = beta0_host.to_owned();
    let s_beta0 = penalty_hessian.dot(&beta_host);
    let penalty_init =
        beta_host.dot(&s_beta0) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
    let mut prev_objective = prev_deviance + penalty_init;
    let mut last_dev_delta = 0.0_f64;
    let mut last_halving: usize = 0;
    let mut last_step_size = 0.0_f64;
    let mut min_dev = prev_deviance;
    let mut step_search_exhausted = false;
    for it in 0..max_iter {
        iterations = it + 1;
        last_logdet = solve_step_on_stream_device_inplace(
            shared,
            ws,
            PirlsStepStreamDeviceInput {
                w_solver_dev: &loop_ws.row_solve.w_solver,
                grad_eta_dev: &loop_ws.row_solve.grad_eta,
                penalty_hessian,
                step_lm_lambda: lm_ridge,
                objective_ridge,
                beta_dev: &loop_ws.beta_dev,
                linear_shift,
            },
        )
        .map_err(|e| format!("inner step it={it}: {e}"))?;
        ws.stream
            .memcpy_dtod(&ws.rhs_dev, &mut loop_ws.direction_dev)
            .map_err(|e| format!("direction d2d copy it={it}: {e}"))?;
        let dir_linf = reduce_scalar(
            &ws.stream,
            &linf_func,
            &loop_ws.direction_dev,
            p,
            &mut loop_ws.scalar_dev,
            "dir_linf",
        )?;
        // xd = X · (Qs · d): the change in eta per unit step length.
        gemv_no_trans(
            &ws.blas,
            p,
            p,
            &ws.qs_dev,
            &loop_ws.direction_dev,
            &mut ws.dir_orig_dev,
        )?;
        gemv_no_trans(
            &ws.blas,
            n,
            p,
            &shared.x_original_dev,
            &ws.dir_orig_dev,
            &mut loop_ws.xd_dev,
        )?;
        loop_ws
            .alpha_ladder
            .zero(&ws.stream)
            .map_err(|e| format!("ladder zero it={it}: {e}"))?;
        crate::gpu::pirls_row::launch_alpha_ladder_on_stream(
            backend,
            family,
            curvature,
            gamma_shape,
            &ws.stream,
            n,
            &loop_ws.eta_dev,
            &loop_ws.xd_dev,
            &shared.y_dev,
            &shared.prior_w_dev,
            &mut loop_ws.alpha_ladder,
        )
        .map_err(|e| format!("alpha-ladder it={it}: {e}"))?;
        let obj_host: Vec<f64> = ws
            .stream
            .clone_dtoh(&loop_ws.alpha_ladder.objective_dev)
            .map_err(|e| format!("ladder dtoh obj it={it}: {e}"))?;
        let stat_host: Vec<u32> = ws
            .stream
            .clone_dtoh(&loop_ws.alpha_ladder.status_dev)
            .map_err(|e| format!("ladder dtoh stat it={it}: {e}"))?;
        let direction_host: Vec<f64> = ws
            .stream
            .clone_dtoh(&loop_ws.direction_dev)
            .map_err(|e| format!("dtoh direction it={it}: {e}"))?;
        // The penalty along beta + a·d expands to
        // penalty_beta + a · 2dᵀ(Sβ − shift) + a² · dᵀSd (S assumed symmetric).
        let dir_view = ndarray::aview1(&direction_host);
        let sd = penalty_hessian.dot(&dir_view);
        let s_beta = penalty_hessian.dot(&beta_host);
        let dtsd = dir_view.dot(&sd);
        let linear_coeff = 2.0 * dir_view.dot(&(&s_beta - &linear_shift));
        let penalty_beta =
            beta_host.dot(&s_beta) - 2.0 * beta_host.dot(&linear_shift) + constant_shift;
        const FORBIDDEN_LINESEARCH: u32 = crate::gpu::pirls_row::status_flags::INVALID_RESPONSE
            | crate::gpu::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
        let mut alpha = 0.0_f64;
        let mut accepted_dev = prev_deviance;
        let mut accepted_objective = prev_objective;
        let mut halving_count: usize = 0;
        for (k, (&dev_k, &st)) in obj_host.iter().zip(stat_host.iter()).enumerate() {
            let a = crate::gpu::pirls_row::ALPHA_LADDER[k];
            let pen_k = penalty_beta + a * linear_coeff + a * a * dtsd;
            let obj_k = dev_k + pen_k;
            if obj_k.is_finite() && obj_k <= prev_objective && (st & FORBIDDEN_LINESEARCH) == 0
            {
                alpha = a;
                accepted_dev = dev_k;
                accepted_objective = obj_k;
                halving_count = k;
                break;
            }
        }
        if alpha == 0.0 {
            // No admissible step: stop without moving. `iterations` already counts this one.
            step_search_exhausted = true;
            last_halving = 0;
            last_step_size = 0.0;
            last_dev_delta = 0.0;
            break;
        }
        step_search_exhausted = false;
        axpy(
            &ws.stream,
            &axpy_func,
            alpha,
            &loop_ws.direction_dev,
            &mut loop_ws.beta_dev,
            p,
        )?;
        axpy(
            &ws.stream,
            &axpy_func,
            alpha,
            &loop_ws.xd_dev,
            &mut loop_ws.eta_dev,
            n,
        )?;
        // Keep the host mirror of beta in lock-step with the device update.
        for (b, &d) in beta_host.iter_mut().zip(direction_host.iter()) {
            *b += alpha * d;
        }
        crate::gpu::pirls_row::launch_solve_row_on_stream(
            backend,
            family,
            curvature,
            gamma_shape,
            &ws.stream,
            n,
            &loop_ws.eta_dev,
            &shared.y_dev,
            &shared.prior_w_dev,
            &mut loop_ws.row_solve,
        )
        .map_err(|e| format!("solve-row accepted it={it}: {e}"))?;
        let step_norm = alpha.abs() * dir_linf;
        // Absolute change of the penalized objective (not just the deviance).
        let dev_delta = (prev_objective - accepted_objective).abs();
        last_dev_delta = dev_delta;
        last_halving = halving_count;
        last_step_size = alpha;
        if accepted_dev < min_dev {
            min_dev = accepted_dev;
        }
        prev_deviance = accepted_dev;
        prev_objective = accepted_objective;
        if dir_linf <= tol
            && step_norm <= tol
            && dev_delta <= tol * (1.0 + prev_objective.abs())
        {
            converged = true;
            break;
        }
    }
    // Shared finalisation for every exit path. The error labels keep the
    // "converged" / "max_iter" wording; an exhausted step search reports as "max_iter".
    let phase = if converged { "converged" } else { "max_iter" };
    crate::gpu::pirls_row::launch_row_reweight_on_stream(
        backend,
        family,
        curvature,
        gamma_shape,
        &ws.stream,
        n,
        &loop_ws.eta_dev,
        &shared.y_dev,
        &shared.prior_w_dev,
        &mut loop_ws.row_final,
    )
    .map_err(|e| format!("final-row {phase}: {e}"))?;
    let h_final = rebuild_h_final(
        shared,
        ws,
        &loop_ws.row_final.w_hessian,
        penalty_hessian,
        objective_ridge,
    )
    .map_err(|e| format!("rebuild H_final ({phase}): {e}"))?;
    build_loop_outcome(
        ws,
        loop_ws,
        h_final,
        last_logdet,
        prev_deviance,
        iterations,
        converged,
        lm_ridge,
        objective_ridge,
        extra,
        LoopDiagnostics {
            last_deviance_change: last_dev_delta,
            last_step_halving: last_halving,
            last_step_size,
            min_deviance: min_dev,
            step_search_exhausted,
        },
        &status_or_func,
    )
}
/// Line-search bookkeeping that `pirls_loop` forwards to `build_loop_outcome`.
struct LoopDiagnostics {
// Absolute change of the penalized objective on the last accepted step (0.0 if exhausted).
last_deviance_change: f64,
// Alpha-ladder index of the last accepted step (0 if exhausted).
last_step_halving: usize,
// Last accepted step length (0.0 if exhausted).
last_step_size: f64,
// Smallest deviance seen over the initial point and accepted iterates.
min_deviance: f64,
// True when the final iteration found no admissible ladder step.
step_search_exhausted: bool,
}
/// Download the final device state of a `pirls_loop` run and package it as a
/// `PirlsLoopOutcome`.
///
/// * Downloads `beta`, `eta`, and the final per-row `mu`, `grad_eta`, `w_hessian`, and
///   `w_solver`.
/// * OR-reduces the final per-row status flags on device.
/// * Sets the status to `PirlsStatus::Unstable` if any `INVALID_RESPONSE` or
///   `ZERO_PRIOR_WEIGHT` bit is set, or if any `eta`/`mu`/`beta` entry is non-finite.
///   Otherwise the status is `Converged`, `LmStepSearchExhausted`, or
///   `MaxIterationsReached`, in that priority order.
/// * With `extra = Some(..)`, computes working-weight derivatives and curvature arrays on
///   the host from the final `eta`. With `None`, those arrays are empty and
///   `derivatives_unsupported` is `true`.
///
/// `logdet`, `deviance`, `iterations`, `converged`, and `step_lm_lambda` (reported as
/// `final_lm_lambda`) are stored exactly as passed in.
///
/// # Errors
/// Any device download or reduction failure, or a failure of the host-side
/// derivative/curvature helpers when `extra` is provided.
fn build_loop_outcome(
ws: &mut SigmaPirlsGpuWorkspace,
loop_ws: &mut PirlsLoopWorkspace,
penalized_hessian: Array2<f64>,
logdet: f64,
deviance: f64,
iterations: usize,
converged: bool,
step_lm_lambda: f64,
objective_ridge: f64,
extra: Option<&PirlsLoopExtra<'_>>,
diagnostics: LoopDiagnostics,
status_or_func: &cudarc::driver::CudaFunction,
) -> Result<PirlsLoopOutcome, String> {
let beta = download_vec(&ws.stream, &loop_ws.beta_dev)?;
let final_eta = download_vec(&ws.stream, &loop_ws.eta_dev)?;
let final_mu = download_vec(&ws.stream, &loop_ws.row_final.mu)?;
let final_grad_eta = download_vec(&ws.stream, &loop_ws.row_final.grad_eta)?;
let final_w_hessian = download_vec(&ws.stream, &loop_ws.row_final.w_hessian)?;
let final_w_solver = download_vec(&ws.stream, &loop_ws.row_final.w_solver)?;
let n_rows = loop_ws.n;
// Bitwise OR of every row's final status flags, computed on device.
let final_row_status = reduce_status_or(
&ws.stream,
status_or_func,
&loop_ws.row_final.status,
n_rows,
&mut loop_ws.status_u32_dev,
"final_row_status",
)?;
const FORBIDDEN_FINAL: u32 = crate::gpu::pirls_row::status_flags::INVALID_RESPONSE
| crate::gpu::pirls_row::status_flags::ZERO_PRIOR_WEIGHT;
let eta_finite = final_eta.iter().all(|v| v.is_finite());
let mu_finite = final_mu.iter().all(|v| v.is_finite());
let beta_finite = beta.iter().all(|v| v.is_finite());
let stability_ok =
eta_finite && mu_finite && beta_finite && (final_row_status & FORBIDDEN_FINAL) == 0;
// Instability takes precedence over convergence/termination reasons.
let status = if !stability_ok {
crate::solver::pirls::PirlsStatus::Unstable
} else if converged {
crate::solver::pirls::PirlsStatus::Converged
} else if diagnostics.step_search_exhausted {
crate::solver::pirls::PirlsStatus::LmStepSearchExhausted
} else {
crate::solver::pirls::PirlsStatus::MaxIterationsReached
};
let default_ridge = crate::types::RidgePassport::scaled_identity(
objective_ridge,
crate::types::RidgePolicy::explicit_stabilization_full(),
);
// f64::max ignores NaN operands, so NaN entries do not affect this maximum.
let max_abs_eta = final_eta.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
match extra {
Some(ext) => {
let (score_c, score_d, solve_dmu_deta, solve_d2mu_deta2, solve_d3mu_deta3) =
crate::solver::pirls::computeworkingweight_derivatives_from_eta(
ext.likelihood,
ext.inverse_link,
&final_eta,
ext.priorweights,
)
.map_err(|e| format!("pirls postpass dmu/deta: {e:?}"))?;
let (finalweights, solve_c_array, solve_d_array) = match ext.exported_curvature {
crate::solver::pirls::HessianCurvatureKind::Observed => {
crate::solver::pirls::compute_observed_hessian_curvature_arrays(
ext.likelihood,
ext.inverse_link,
&final_eta,
ext.y,
&final_w_solver,
ext.priorweights,
)
.map_err(|e| format!("pirls postpass observed curvature: {e:?}"))?
}
// Fisher curvature: solver weights double as final weights, and the
// score-based c/d arrays are reused.
crate::solver::pirls::HessianCurvatureKind::Fisher => {
(final_w_solver.clone(), score_c.clone(), score_d.clone())
}
};
// beta and beta_transformed report the same device coefficients.
let beta_transformed = beta.clone();
let constraint_kkt = ext.linear_constraints.and_then(|lin| {
if lin.a.nrows() == 0 {
return None;
}
// NOTE(review): this "gradient" is penalized_hessian · beta, with no data-score or
// linear_shift term. Confirm it is the quantity
// compute_constraint_kkt_diagnostics expects.
let grad = penalized_hessian.dot(&beta);
Some(
crate::solver::active_set::compute_constraint_kkt_diagnostics(
&beta, &grad, lin,
),
)
});
let ridge_passport = ext.ridge_passport.unwrap_or(default_ridge);
let firth = ext
.firth
.clone()
.unwrap_or(crate::solver::pirls::FirthDiagnostics::Inactive);
let edf = ext.edf.unwrap_or(f64::NAN);
let derivatives_unsupported = false;
Ok(PirlsLoopOutcome {
beta,
penalized_hessian,
logdet,
deviance,
iterations,
converged,
final_eta,
final_mu,
final_grad_eta,
final_w_hessian,
final_w_solver: final_w_solver.clone(),
final_offset: ext.offset.to_owned(),
beta_transformed,
finalweights,
solveweights: final_w_solver,
solve_dmu_deta,
solve_d2mu_deta2,
solve_d3mu_deta3,
solve_c_array,
solve_d_array,
derivatives_unsupported,
status,
ridge_passport,
firth,
constraint_kkt,
edf,
last_deviance_change: diagnostics.last_deviance_change,
last_step_halving: diagnostics.last_step_halving,
last_step_size: diagnostics.last_step_size,
final_lm_lambda: step_lm_lambda,
min_deviance: diagnostics.min_deviance,
max_abs_eta,
per_row_status_or: final_row_status,
})
}
// Without host context, only device-derived quantities are reported; derivative
// arrays are empty and flagged as unsupported.
None => {
Ok(PirlsLoopOutcome {
beta: beta.clone(),
penalized_hessian,
logdet,
deviance,
iterations,
converged,
final_eta,
final_mu,
final_grad_eta,
final_w_hessian,
final_w_solver: final_w_solver.clone(),
final_offset: Array1::<f64>::zeros(0),
beta_transformed: beta,
finalweights: Array1::<f64>::zeros(0),
solveweights: final_w_solver,
solve_dmu_deta: Array1::<f64>::zeros(0),
solve_d2mu_deta2: Array1::<f64>::zeros(0),
solve_d3mu_deta3: Array1::<f64>::zeros(0),
solve_c_array: Array1::<f64>::zeros(0),
solve_d_array: Array1::<f64>::zeros(0),
derivatives_unsupported: true,
status,
ridge_passport: default_ridge,
firth: crate::solver::pirls::FirthDiagnostics::Inactive,
constraint_kkt: None,
edf: f64::NAN,
last_deviance_change: diagnostics.last_deviance_change,
last_step_halving: diagnostics.last_step_halving,
last_step_size: diagnostics.last_step_size,
final_lm_lambda: step_lm_lambda,
min_deviance: diagnostics.min_deviance,
max_abs_eta,
per_row_status_or: final_row_status,
})
}
}
}
/// Compute `y = A · x` with cuBLAS `dgemv` (no transpose, `alpha = 1`, `beta = 0`).
///
/// `a_dev` is passed as an `n × p` matrix with leading dimension `n` (cuBLAS column-major
/// convention). `x_dev` has length `p`; `y_dev` has length `n` and is overwritten. Both use
/// unit stride.
///
/// # Errors
/// Fails if `n` or `p` does not fit in `i32` or the cuBLAS call fails.
fn gemv_no_trans(
    blas: &CudaBlas,
    n: usize,
    p: usize,
    a_dev: &CudaSlice<f64>,
    x_dev: &CudaSlice<f64>,
    y_dev: &mut CudaSlice<f64>,
) -> Result<(), String> {
    let rows = to_i32(n)?;
    let cols = to_i32(p)?;
    let plain_matvec = GemvConfig::<f64> {
        trans: cublasOperation_t::CUBLAS_OP_N,
        m: rows,
        n: cols,
        alpha: 1.0,
        lda: rows,
        incx: 1,
        beta: 0.0,
        incy: 1,
    };
    // SAFETY: the caller must ensure `a_dev` holds `n * p` entries, `x_dev` at least `p`, and
    // `y_dev` at least `n`; sizes are not re-checked here.
    let outcome = unsafe { blas.gemv(plain_matvec, a_dev, x_dev, y_dev) };
    outcome.map_err(|e| format!("dgemv no-trans: {e}"))
}
/// Launch the `axpy_n` kernel: `y[i] += alpha * x[i]` for `i < n`.
///
/// Uses 256-thread blocks and at least one block.
///
/// # Errors
/// Fails if `n` does not fit in `i32`/`u32` or the launch fails.
fn axpy(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    func: &cudarc::driver::CudaFunction,
    alpha: f64,
    x_dev: &CudaSlice<f64>,
    y_dev: &mut CudaSlice<f64>,
    n: usize,
) -> Result<(), String> {
    const THREADS: u32 = 256;
    let len = to_i32(n)?;
    let blocks = u32::try_from(n)
        .map_err(|_| format!("axpy n={n} > u32"))?
        .div_ceil(THREADS)
        .max(1);
    let mut args = stream.launch_builder(func);
    args.arg(&alpha);
    args.arg(x_dev);
    args.arg(y_dev);
    args.arg(&len);
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the argument order matches `axpy_n(double alpha, const double* x, double* y,
    // int n)`. The kernel bounds-checks `i >= n`; buffers must hold at least `n` entries.
    unsafe { args.launch(launch_cfg) }.map_err(|e| format!("axpy launch: {e}"))?;
    Ok(())
}
/// Reduce the first `len` elements of `src` to one `f64` with a single-block kernel, then
/// copy the result to the host.
///
/// `func` must have the shape `(const double* src, int len, double* out)`, as `deviance_sum`
/// and `linf_norm` do. It is launched as exactly one block of 1024 threads, matching those
/// kernels' 1024-entry shared buffer. The value lands in `scalar_dev[0]`.
///
/// # Errors
/// Fails if `len` does not fit in `i32`, the launch fails, or the download fails. Messages
/// are prefixed/suffixed with `label`.
fn reduce_scalar(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    func: &cudarc::driver::CudaFunction,
    src: &CudaSlice<f64>,
    len: usize,
    scalar_dev: &mut CudaSlice<f64>,
    label: &'static str,
) -> Result<f64, String> {
    const THREADS: u32 = 1024;
    let count = to_i32(len)?;
    let mut args = stream.launch_builder(func);
    args.arg(src);
    args.arg(&count);
    args.arg(&mut *scalar_dev);
    let one_block = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the argument order matches the single-block reduction kernels; `scalar_dev`
    // must hold at least one element.
    unsafe { args.launch(one_block) }.map_err(|e| format!("{label} reduce launch: {e}"))?;
    let downloaded = stream
        .clone_dtoh(scalar_dev)
        .map_err(|e| format!("download {label}: {e}"))?;
    Ok(downloaded[0])
}
/// Bitwise-OR the first `len` status words of `src` with the single-block `status_or`
/// kernel, then copy the result to the host.
///
/// The launch uses one block of 1024 threads, matching the kernel's 1024-entry shared
/// buffer. The value lands in `status_dev[0]`.
///
/// # Errors
/// Fails if `len` does not fit in `i32`, the launch fails, or the download fails. Messages
/// include `label`.
fn reduce_status_or(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    func: &cudarc::driver::CudaFunction,
    src: &CudaSlice<u32>,
    len: usize,
    status_dev: &mut CudaSlice<u32>,
    label: &'static str,
) -> Result<u32, String> {
    const THREADS: u32 = 1024;
    let count = to_i32(len)?;
    let mut args = stream.launch_builder(func);
    args.arg(src);
    args.arg(&count);
    args.arg(&mut *status_dev);
    let one_block = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the argument order matches `status_or(const unsigned int* status, int n,
    // unsigned int* out)`; `status_dev` must hold at least one element.
    unsafe { args.launch(one_block) }.map_err(|e| format!("{label} or reduce launch: {e}"))?;
    let downloaded = stream
        .clone_dtoh(status_dev)
        .map_err(|e| format!("download {label}: {e}"))?;
    Ok(downloaded[0])
}
/// Copy an entire device `f64` buffer to the host as an `Array1`.
///
/// # Errors
/// Returns `"download vec: …"` if the device-to-host copy fails.
fn download_vec(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    dev: &CudaSlice<f64>,
) -> Result<Array1<f64>, String> {
    stream
        .clone_dtoh(dev)
        .map(Array1::from_vec)
        .map_err(|e| format!("download vec: {e}"))
}
/// Result of `solve_gaussian_pls_on_stream`.
pub struct GaussianPlsResult {
// Solution of the (optionally ridge-regularized) penalized normal equations.
pub beta: Array1<f64>,
// H + S, where H is the (optionally Qs-rotated) Gram matrix; the ridge term is NOT included.
pub penalized_hessian: Array2<f64>,
// Value from cholesky_logdet_device on the Cholesky factor of the ridge-regularized matrix.
pub logdet: f64,
}
/// Solve a Gaussian penalized least-squares system on the GPU with a Cholesky factorization.
///
/// Let `H = Qsᵀ A Qs` and `r = Qsᵀ b` when `qs` is given, or `H = A` and `r = b` otherwise.
/// This solves
///
/// `(H + S + ridge·I) β = r + linear_shift + ridge·prior_mean_target`
///
/// where both ridge terms are applied only when `ridge > 0`.
///
/// # Returns
/// * `beta`: the solution;
/// * `penalized_hessian`: `H + S`, **without** the ridge;
/// * `logdet`: the value `cholesky_logdet_device` derives from the factor of the
///   ridge-regularized matrix (presumably its log-determinant; confirm against the kernel).
///
/// # Errors
/// * Shape mismatches among `A`, `b`, `S`, `linear_shift`, `prior_mean_target`, and `qs`.
/// * Any CUDA allocation, copy, or launch failure.
/// * A failed `potrf`/`potrs`, reported through the deferred info checks (for example when
///   the regularized matrix is not positive definite).
pub fn solve_gaussian_pls_on_stream(
    a_orig: ArrayView2<'_, f64>,
    b_orig: ArrayView1<'_, f64>,
    s_transformed: ArrayView2<'_, f64>,
    linear_shift: ArrayView1<'_, f64>,
    prior_mean_target: ArrayView1<'_, f64>,
    ridge: f64,
    qs: Option<ArrayView2<'_, f64>>,
) -> Result<GaussianPlsResult, String> {
    let p = b_orig.len();
    if a_orig.dim() != (p, p) {
        return Err(format!("A shape {:?} != ({p},{p})", a_orig.dim()));
    }
    if s_transformed.dim() != (p, p) {
        return Err(format!("S shape {:?} != ({p},{p})", s_transformed.dim()));
    }
    if linear_shift.len() != p {
        return Err(format!("linear_shift len {} != p={p}", linear_shift.len()));
    }
    if prior_mean_target.len() != p {
        return Err(format!(
            "prior_mean_target len {} != p={p}",
            prior_mean_target.len()
        ));
    }
    // Rotate the Gram matrix and right-hand side into the Qs basis when one is supplied.
    // `dot` accepts borrowed views directly, so no owned copy of Qs is made.
    let (h_rotated, rhs_base) = match qs {
        Some(qs_v) => {
            if qs_v.dim() != (p, p) {
                return Err(format!("qs shape {:?} != ({p},{p})", qs_v.dim()));
            }
            let a_qs = a_orig.dot(&qs_v);
            (qs_v.t().dot(&a_qs), qs_v.t().dot(&b_orig))
        }
        None => (a_orig.to_owned(), b_orig.to_owned()),
    };
    let penalized_hessian: Array2<f64> = &h_rotated + &s_transformed;
    let mut regularized = penalized_hessian.clone();
    let mut rhs_host = rhs_base;
    rhs_host += &linear_shift;
    if ridge > 0.0 {
        // Ridge toward prior_mean_target: ridge·I on the matrix, ridge·m on the rhs.
        for diag_entry in regularized.diag_mut() {
            *diag_entry += ridge;
        }
        rhs_host.scaled_add(ridge, &prior_mean_target);
    }
    let (ctx, stream) = context_and_stream()?;
    let solver = DnHandle::new(stream.clone())
        .map_err(|e| format!("cusolver init (gaussian pls): {e}"))?;
    let pp = p.checked_mul(p).ok_or("p*p overflow (gaussian pls)")?;
    let mut h_dev = stream
        .alloc_zeros::<f64>(pp)
        .map_err(|e| format!("alloc H (gaussian pls): {e}"))?;
    let mut rhs_dev = stream
        .alloc_zeros::<f64>(p)
        .map_err(|e| format!("alloc rhs (gaussian pls): {e}"))?;
    let potrf_lwork_usize = potrf_query_lwork(&solver, &stream, p)?;
    let potrf_lwork = i32::try_from(potrf_lwork_usize)
        .map_err(|_| "potrf lwork overflow (gaussian pls)".to_string())?;
    let mut potrf_work_dev = stream
        .alloc_zeros::<f64>(potrf_lwork_usize.max(1))
        .map_err(|e| format!("alloc potrf workspace (gaussian pls): {e}"))?;
    let mut potrf_info_dev = stream
        .alloc_zeros::<i32>(1)
        .map_err(|e| format!("alloc potrf info (gaussian pls): {e}"))?;
    let mut potrs_info_dev = stream
        .alloc_zeros::<i32>(1)
        .map_err(|e| format!("alloc potrs info (gaussian pls): {e}"))?;
    // cuSOLVER expects column-major storage.
    let reg_col = to_col_major(&regularized);
    stream
        .memcpy_htod(reg_col.as_ref(), &mut h_dev)
        .map_err(|e| format!("upload H (gaussian pls): {e}"))?;
    let rhs_slice = rhs_host
        .as_slice()
        .ok_or("rhs_host not contiguous (gaussian pls)")?;
    stream
        .memcpy_htod(rhs_slice, &mut rhs_dev)
        .map_err(|e| format!("upload rhs (gaussian pls): {e}"))?;
    potrf_in_place_reuse(
        &solver,
        &stream,
        p,
        potrf_lwork,
        &mut h_dev,
        &mut potrf_work_dev,
        &mut potrf_info_dev,
    )?;
    potrs_in_place_reuse(
        &solver,
        &stream,
        p,
        1,
        &h_dev,
        &mut rhs_dev,
        &mut potrs_info_dev,
    )?;
    let logdet = cholesky_logdet_device(&stream, &ctx, p, &h_dev)?;
    let beta_raw = stream
        .clone_dtoh(&rhs_dev)
        .map_err(|e| format!("download beta (gaussian pls): {e}"))?;
    // Factorization/solve status is checked only after the results are downloaded. On a
    // failure, the computed beta/logdet are discarded in favour of the error.
    check_deferred_potrf_info(&stream, &potrf_info_dev)?;
    check_deferred_potrs_info(&stream, &potrs_info_dev)?;
    Ok(GaussianPlsResult {
        beta: Array1::from_vec(beta_raw),
        penalized_hessian,
        logdet,
    })
}
}
/// Computes the weighted cross-product `Xᵀ·diag(w)·X` of a dense design.
///
/// On Linux the CUDA implementation is used whenever a GPU runtime is
/// registered; every other case (no runtime, non-Linux target) runs the host
/// implementation.
///
/// # Errors
/// Propagates the selected backend's error. The host path rejects a `weights`
/// length different from `x.nrows()` and an empty design.
pub fn weighted_crossprod_gpu(
    x: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    #[cfg(target_os = "linux")]
    {
        if crate::gpu::runtime::GpuRuntime::global().is_some() {
            return cuda::weighted_crossprod(x, weights);
        }
    }
    cpu_fallback::weighted_crossprod_cpu(x, weights)
}
/// Solves one PIRLS Newton step from a host-resident design and weights.
///
/// On Linux with a registered GPU runtime the step runs on CUDA. Otherwise it
/// runs on the host via Cholesky of `XᵀWX + S (+ step_lm_lambda·I)`.
///
/// # Errors
/// Propagates the selected backend's error (shape mismatches, factorization
/// failure, CUDA errors).
pub fn solve_pirls_step_gpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
    #[cfg(target_os = "linux")]
    {
        if crate::gpu::runtime::GpuRuntime::global().is_some() {
            return cuda::solve_step(input);
        }
    }
    cpu_fallback::solve_step_cpu(input)
}
/// Uploads the design, response, prior weights and offset that every PIRLS
/// workspace on this device shares.
///
/// # Errors
/// Fails immediately when no CUDA runtime is registered; otherwise propagates
/// upload errors from `PirlsGpuSharedData::upload_impl`.
#[cfg(target_os = "linux")]
pub fn upload_shared_pirls_gpu(
    x: ndarray::ArrayView2<'_, f64>,
    y: ndarray::ArrayView1<'_, f64>,
    prior_w: ndarray::ArrayView1<'_, f64>,
    offset: ndarray::ArrayView1<'_, f64>,
) -> Result<PirlsGpuSharedData, String> {
    // Require a CUDA runtime before touching device memory.
    crate::gpu::runtime::GpuRuntime::global().ok_or_else(|| {
        "cuda runtime unavailable; cannot upload shared GPU PIRLS data".to_string()
    })?;
    PirlsGpuSharedData::upload_impl(x, y, prior_w, offset)
}
/// Allocates a per-σ PIRLS workspace (stream, BLAS/solver handles, scratch
/// buffers) sized for `shared`.
///
/// # Errors
/// Propagates allocation and handle-creation errors from
/// `SigmaPirlsGpuWorkspace::allocate_impl`.
#[cfg(target_os = "linux")]
pub fn allocate_sigma_pirls_workspace(
    shared: &PirlsGpuSharedData,
) -> Result<SigmaPirlsGpuWorkspace, String> {
    let workspace = SigmaPirlsGpuWorkspace::allocate_impl(shared)?;
    Ok(workspace)
}
/// Copies the basis matrix `qs` into the workspace's `qs` buffer.
///
/// # Errors
/// Propagates the error from `cuda::upload_qs`.
#[cfg(target_os = "linux")]
pub fn upload_qs_pirls(
    ws: &mut SigmaPirlsGpuWorkspace,
    qs: ndarray::ArrayView2<'_, f64>,
) -> Result<(), String> {
    cuda::upload_qs(ws, qs)?;
    Ok(())
}
/// Sets the workspace's `qs` buffer through `cuda::upload_qs_identity`
/// (by its name, presumably an identity basis).
///
/// # Errors
/// Propagates the error from `cuda::upload_qs_identity`.
#[cfg(target_os = "linux")]
pub fn upload_qs_identity_pirls(ws: &mut SigmaPirlsGpuWorkspace) -> Result<(), String> {
    cuda::upload_qs_identity(ws)?;
    Ok(())
}
/// Solves one PIRLS step on the workspace stream from host-side weights and
/// gradient, reusing the design already uploaded in `shared`.
///
/// # Errors
/// Propagates the error from `cuda::solve_step_on_stream`.
#[cfg(target_os = "linux")]
pub fn solve_pirls_step_on_stream(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    input: PirlsStepStreamInput<'_>,
) -> Result<PirlsGpuStep, String> {
    let step = cuda::solve_step_on_stream(shared, ws, input)?;
    Ok(step)
}
/// Solves one PIRLS step on the workspace stream when the per-row weights and
/// gradients are already device-resident.
///
/// # Errors
/// Propagates the error from `cuda::solve_step_on_stream_device`.
#[cfg(target_os = "linux")]
pub fn solve_pirls_step_on_stream_device(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    input: PirlsStepStreamDeviceInput<'_, '_>,
) -> Result<PirlsGpuStep, String> {
    let step = cuda::solve_step_on_stream_device(shared, ws, input)?;
    Ok(step)
}
/// Runs the full PIRLS iteration on the device, starting from `beta0`, for at
/// most `max_iter` iterations with tolerance `tol`.
///
/// This is a thin forwarder to `cuda::pirls_loop`, with arguments passed
/// through unchanged and in the same order.
///
/// # Errors
/// Propagates the error from `cuda::pirls_loop`.
// The argument list mirrors `cuda::pirls_loop` one-to-one; bundling it here
// would only add a translation layer.
#[allow(clippy::too_many_arguments)]
#[cfg(target_os = "linux")]
pub fn pirls_loop_on_stream(
    shared: &PirlsGpuSharedData,
    ws: &mut SigmaPirlsGpuWorkspace,
    loop_ws: &mut cuda::PirlsLoopWorkspace,
    family: crate::gpu::pirls_row::PirlsRowFamily,
    curvature: crate::gpu::pirls_row::CurvatureMode,
    gamma_shape: f64,
    beta0: ndarray::ArrayView1<'_, f64>,
    penalty_hessian: ndarray::ArrayView2<'_, f64>,
    linear_shift: ndarray::ArrayView1<'_, f64>,
    constant_shift: f64,
    step_lm_lambda: f64,
    objective_ridge: f64,
    max_iter: usize,
    tol: f64,
    extra: Option<&cuda::PirlsLoopExtra<'_>>,
) -> Result<cuda::PirlsLoopOutcome, String> {
    let outcome = cuda::pirls_loop(
        shared,
        ws,
        loop_ws,
        family,
        curvature,
        gamma_shape,
        beta0,
        penalty_hessian,
        linear_shift,
        constant_shift,
        step_lm_lambda,
        objective_ridge,
        max_iter,
        tol,
        extra,
    )?;
    Ok(outcome)
}
/// Allocates the scratch buffers used by `pirls_loop_on_stream`, bound to
/// `ws`'s stream.
///
/// # Errors
/// Propagates the error from `cuda::PirlsLoopWorkspace::allocate`.
#[cfg(target_os = "linux")]
pub fn allocate_pirls_loop_workspace(
    shared: &PirlsGpuSharedData,
    ws: &SigmaPirlsGpuWorkspace,
) -> Result<cuda::PirlsLoopWorkspace, String> {
    let stream = &ws.stream;
    cuda::PirlsLoopWorkspace::allocate(shared, stream)
}
/// Solves the Gaussian penalized least-squares system on the GPU.
///
/// The arguments are forwarded unchanged to
/// `cuda::solve_gaussian_pls_on_stream`. In that routine, when `qs` is given,
/// `a_orig`/`b_orig` are rotated into the `qs` basis. The system
/// `(H + S + ridge·I) β = b + linear_shift + ridge·prior_mean_target`
/// is then Cholesky-solved on the device.
///
/// # Errors
/// Propagates the error from `cuda::solve_gaussian_pls_on_stream`.
#[cfg(target_os = "linux")]
pub fn solve_gaussian_pls_gpu(
    a_orig: ndarray::ArrayView2<'_, f64>,
    b_orig: ndarray::ArrayView1<'_, f64>,
    s_transformed: ndarray::ArrayView2<'_, f64>,
    linear_shift: ndarray::ArrayView1<'_, f64>,
    prior_mean_target: ndarray::ArrayView1<'_, f64>,
    ridge: f64,
    qs: Option<ndarray::ArrayView2<'_, f64>>,
) -> Result<cuda::GaussianPlsResult, String> {
    let result = cuda::solve_gaussian_pls_on_stream(
        a_orig,
        b_orig,
        s_transformed,
        linear_shift,
        prior_mean_target,
        ridge,
        qs,
    )?;
    Ok(result)
}
/// Inputs to `run_pcg_against_row_hessian_device`: solve `H x = b`, where `H`
/// is the joint Hessian described by the device-resident row storage.
#[cfg(target_os = "linux")]
pub struct DeviceResidentPcgInput<'a> {
    /// Device-resident row Hessians and designs that define the operator `H`.
    pub storage: &'a crate::gpu::bms_flex_row::DeviceResidentRowHess,
    /// Right-hand side; its length must equal `storage.block.p_total`.
    pub b: &'a [f64],
    /// Convergence target on `‖r‖₂ / ‖b‖₂`; must be positive and finite.
    pub rel_tol: f64,
    /// Iteration limit; must be at least 1.
    pub max_iters: usize,
    /// Minimum magnitude applied to each Jacobi-preconditioner diagonal entry
    /// (the entry's sign is preserved); must be positive and finite.
    pub precond_diag_floor: f64,
}
/// Result of `run_pcg_against_row_hessian_device`.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub struct DeviceResidentPcgOutput {
    /// Approximate solution of `H x = b`, of length `p_total`.
    pub x: Vec<f64>,
    /// Number of PCG iterations performed (0 when `b` is the zero vector).
    pub iterations: usize,
    /// `‖r‖₂ / ‖b‖₂` at exit, where `r` is the recurrence-updated residual
    /// (0 when `b` is the zero vector). This can exceed the requested
    /// tolerance when the iteration limit was reached.
    pub final_rel_residual: f64,
}
/// CUDA C source for the device-resident PCG kernels, compiled at runtime with
/// NVRTC by `pcg_device::PcgBackend::probe`.
///
/// All kernels except `pcg_dot_single_block` are elementwise. Thread
/// `blockIdx.x * blockDim.x + threadIdx.x` handles one index, and threads with
/// index `>= n` do nothing, so any grid covering `n` elements works.
///
/// `pcg_dot_single_block` must be launched as a single block. Each thread
/// accumulates a strided partial sum over all `n` entries, so `n` itself is not
/// bounded by the block size; the in-source "n must be <= 1024" note is
/// stricter than the code. The actual launch constraints are:
/// - `blockDim.x <= 1024`, because the shared buffer holds 1024 doubles;
/// - `blockDim.x` is a power of two, because the tree reduction halves the
///   stride each step.
///
/// The buffer is statically sized, so launches pass `shared_mem_bytes: 0`.
///
/// `pcg_init_zero` is not loaded by `pcg_device::run` in this file.
#[cfg(target_os = "linux")]
const PCG_KERNEL_SOURCE: &str = r#"
// y[i] += a * x[i]
extern "C" __global__ void pcg_axpy(int n, double a,
const double * __restrict__ x,
double * __restrict__ y)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] += a * x[i];
}
// y[i] = a * x[i] + b * y[i]
extern "C" __global__ void pcg_axpby(int n, double a,
const double * __restrict__ x,
double b,
double * __restrict__ y)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = a * x[i] + b * y[i];
}
// z[i] = r[i] / clamp(diag[i], floor) (sign-preserving floor on |diag|).
extern "C" __global__ void pcg_apply_diag_precond(int n, double floor_val,
const double * __restrict__ diag,
const double * __restrict__ r,
double * __restrict__ z)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
double d = diag[i];
double ad = d < 0 ? -d : d;
double clamped = ad > floor_val ? d : (d >= 0.0 ? floor_val : -floor_val);
z[i] = r[i] / clamped;
}
}
// Single-block dot product; writes the scalar to out[0]. n must be <= 1024.
extern "C" __global__ void pcg_dot_single_block(int n,
const double * __restrict__ a,
const double * __restrict__ b,
double * __restrict__ out)
{
__shared__ double s[1024];
int tid = threadIdx.x;
double acc = 0.0;
for (int i = tid; i < n; i += blockDim.x) acc += a[i] * b[i];
s[tid] = acc;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) s[tid] += s[tid + stride];
__syncthreads();
}
if (tid == 0) out[0] = s[0];
}
// Set out[i] = 0 for i in [0, n).
extern "C" __global__ void pcg_init_zero(int n, double * __restrict__ out) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) out[i] = 0.0;
}
// Copy y[i] = x[i].
extern "C" __global__ void pcg_copy(int n,
const double * __restrict__ x,
double * __restrict__ y)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = x[i];
}
"#;
#[cfg(target_os = "linux")]
mod pcg_device {
    //! Jacobi-preconditioned conjugate gradient whose vectors stay on the GPU.
    //!
    //! The operator is the row-structured joint Hessian in
    //! `DeviceResidentRowHess`, applied through
    //! `launch_bms_flex_row_hvp_into_device`. Only the scalar dot products
    //! that steer the iteration and the final solution cross to the host.
    use super::DeviceResidentPcgInput;
    use super::DeviceResidentPcgOutput;
    use super::PCG_KERNEL_SOURCE;
    use crate::gpu::bms_flex_row::launch_bms_flex_row_diagonal;
    use crate::gpu::bms_flex_row::launch_bms_flex_row_hvp_into_device;
    use cudarc::driver::{
        CudaFunction, CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg,
    };
    use std::sync::{Arc, OnceLock};

    /// Maximum `blockDim.x` of `pcg_dot_single_block` (its shared buffer holds
    /// 1024 doubles). It is also, conservatively, the largest `p_total` that
    /// [`run`] accepts.
    const MAX_DOT_THREADS: usize = 1024;

    /// Threads per block for the elementwise vector kernels.
    const VEC_THREADS: u32 = 64;

    /// Process-wide stream and compiled PCG module for the selected device.
    struct PcgBackend {
        stream: Arc<CudaStream>,
        module: Arc<CudaModule>,
    }

    impl PcgBackend {
        /// Returns the lazily initialized backend.
        ///
        /// The first call creates the CUDA context, compiles
        /// `PCG_KERNEL_SOURCE` with NVRTC and loads the module. The outcome,
        /// including an error, is cached for the life of the process, so a
        /// failed probe is never retried.
        fn probe() -> Result<&'static Self, String> {
            static BACKEND: OnceLock<Result<PcgBackend, String>> = OnceLock::new();
            BACKEND
                .get_or_init(|| {
                    let runtime = crate::gpu::runtime::GpuRuntime::global()
                        .ok_or_else(|| "pcg backend: no CUDA runtime available".to_string())?;
                    let ctx =
                        crate::gpu::runtime::cuda_context_for(runtime.selected_device().ordinal)
                            .ok_or_else(|| {
                                format!(
                                    "pcg backend: failed to create CUDA context for device {}",
                                    runtime.selected_device().ordinal
                                )
                            })?;
                    let stream = ctx.default_stream();
                    let ptx = cudarc::nvrtc::compile_ptx(PCG_KERNEL_SOURCE)
                        .map_err(|err| format!("pcg NVRTC compile failed: {err}"))?;
                    let module = ctx
                        .load_module(ptx)
                        .map_err(|err| format!("pcg module load failed: {err}"))?;
                    Ok(PcgBackend { stream, module })
                })
                .as_ref()
                .map_err(String::clone)
        }
    }

    /// Number of `threads`-wide blocks needed to cover `p` elements.
    ///
    /// Callers guarantee `p <= MAX_DOT_THREADS`, so the `u32` cast cannot truncate.
    fn launch_blocks(p: usize, threads: u32) -> u32 {
        ((p as u32) + threads - 1) / threads
    }

    /// One-dimensional launch configuration for the elementwise kernels.
    fn vec_launch_config(blocks: u32) -> LaunchConfig {
        LaunchConfig {
            grid_dim: (blocks, 1, 1),
            block_dim: (VEC_THREADS, 1, 1),
            shared_mem_bytes: 0,
        }
    }

    /// `pcg_dot_single_block` bound to a stream, a length and a block width.
    struct DotKernel<'k> {
        stream: &'k Arc<CudaStream>,
        func: &'k CudaFunction,
        n: i32,
        threads: u32,
    }

    impl DotKernel<'_> {
        /// Returns `a·b`. It launches the single-block reduction into
        /// `out[0]`, synchronizes the stream, and downloads the scalar.
        ///
        /// `stage` and `at` only shape error messages, e.g.
        /// `"pcg p·q launch iter 3: …"`.
        fn eval(
            &self,
            a: &CudaSlice<f64>,
            b: &CudaSlice<f64>,
            out: &mut CudaSlice<f64>,
            stage: &str,
            at: &str,
        ) -> Result<f64, String> {
            // SAFETY: the argument list matches
            // `pcg_dot_single_block(int, const double*, const double*, double*)`.
            // `a` and `b` hold `n` elements and `out` holds one. `threads` is a
            // power of two <= 1024, as the kernel's reduction requires.
            unsafe {
                self.stream
                    .launch_builder(self.func)
                    .arg(&self.n)
                    .arg(a)
                    .arg(b)
                    .arg(&mut *out)
                    .launch(LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (self.threads, 1, 1),
                        shared_mem_bytes: 0,
                    })
            }
            .map_err(|e| format!("pcg {stage} launch{at}: {e}"))?;
            self.stream
                .synchronize()
                .map_err(|e| format!("pcg {stage} sync{at}: {e}"))?;
            let host = self
                .stream
                .clone_dtoh(&*out)
                .map_err(|e| format!("pcg {stage} download{at}: {e}"))?;
            Ok(host[0])
        }
    }

    /// Solves `H x = b` by preconditioned conjugate gradient.
    ///
    /// `H` is the joint Hessian described by `input.storage`. The
    /// preconditioner is `diag(H)`, with each entry's magnitude floored at
    /// `input.precond_diag_floor` and its sign preserved. The iteration starts
    /// from `x₀ = 0` and stops once `‖r‖₂ / ‖b‖₂ <= rel_tol` or after
    /// `max_iters` iterations. Hitting the limit is not an error; inspect
    /// `final_rel_residual`.
    ///
    /// # Errors
    /// - Invalid inputs: `b` length mismatch, non-positive or non-finite
    ///   `rel_tol` or floor, `max_iters == 0`, or `p_total > 1024`.
    /// - CUDA failures.
    /// - Non-finite intermediate scalars.
    /// - A non-positive curvature `pᵀHp`, meaning `H` is not positive-definite.
    pub(super) fn run(
        input: DeviceResidentPcgInput<'_>,
    ) -> Result<DeviceResidentPcgOutput, String> {
        let p = input.storage.block.p_total;
        if input.b.len() != p {
            return Err(format!(
                "device-resident pcg: b.len()={} != p_total={p}",
                input.b.len()
            ));
        }
        if !input.rel_tol.is_finite() || input.rel_tol <= 0.0 {
            return Err(format!(
                "device-resident pcg: rel_tol must be positive and finite (got {})",
                input.rel_tol
            ));
        }
        if input.max_iters == 0 {
            return Err("device-resident pcg: max_iters must be >= 1".to_string());
        }
        if !input.precond_diag_floor.is_finite() || input.precond_diag_floor <= 0.0 {
            return Err(format!(
                "device-resident pcg: precond_diag_floor must be positive and finite (got {})",
                input.precond_diag_floor
            ));
        }
        // Checked before any device work so oversized problems fail cheaply.
        // NOTE(review): the strided accumulation in `pcg_dot_single_block`
        // appears to handle n > blockDim, so this cap may be conservative.
        // Confirm the row-Hessian HVP/diagonal kernels before lifting it.
        if p > MAX_DOT_THREADS {
            return Err(format!(
                "device-resident pcg: p_total={p} exceeds single-block dot capacity (1024); \
                 widen pcg_dot_single_block to multi-block reduce before raising the cap"
            ));
        }
        let backend = PcgBackend::probe()?;
        if p == 0 {
            // Empty system: nothing to solve. This also avoids zero-length
            // device allocations and zero-block kernel launches.
            return Ok(DeviceResidentPcgOutput {
                x: Vec::new(),
                iterations: 0,
                final_rel_residual: 0.0,
            });
        }
        let stream = backend.stream.clone();
        let module = backend.module.clone();
        let load = |name: &str| {
            module
                .load_function(name)
                .map_err(|e| format!("pcg load {name}: {e}"))
        };
        let f_axpy = load("pcg_axpy")?;
        let f_axpby = load("pcg_axpby")?;
        let f_precond = load("pcg_apply_diag_precond")?;
        let f_dot = load("pcg_dot_single_block")?;
        let f_copy = load("pcg_copy")?;

        let mut d_x = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("pcg alloc x: {e}"))?;
        // r₀ = b − H·x₀ = b, since x₀ = 0.
        let mut d_r = stream
            .clone_htod(input.b)
            .map_err(|e| format!("pcg upload b -> r: {e}"))?;
        let mut d_z = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("pcg alloc z: {e}"))?;
        let mut d_p = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("pcg alloc p: {e}"))?;
        let mut d_q = stream
            .alloc_zeros::<f64>(p)
            .map_err(|e| format!("pcg alloc q: {e}"))?;
        let mut d_scalar = stream
            .alloc_zeros::<f64>(1)
            .map_err(|e| format!("pcg alloc scalar: {e}"))?;
        let diag_host = launch_bms_flex_row_diagonal(input.storage)
            .map_err(|e| format!("pcg diag fetch: {e}"))?;
        if diag_host.len() != p {
            return Err(format!(
                "pcg: diag length {} != p_total {p}",
                diag_host.len()
            ));
        }
        let d_diag = stream
            .clone_htod(&diag_host)
            .map_err(|e| format!("pcg upload diag: {e}"))?;
        let n_i32 = i32::try_from(p).map_err(|_| format!("pcg: p_total={p} exceeds i32 range"))?;
        let vec_blocks = launch_blocks(p, VEC_THREADS);
        // Smallest power of two >= p, clamped to [64, 1024]. The tree
        // reduction in `pcg_dot_single_block` needs a power-of-two block.
        let dot_threads = p.next_power_of_two().clamp(64, MAX_DOT_THREADS) as u32;
        let dot = DotKernel {
            stream: &stream,
            func: &f_dot,
            n: n_i32,
            threads: dot_threads,
        };

        let bb = dot.eval(&d_r, &d_r, &mut d_scalar, "b·b", "")?;
        if !bb.is_finite() {
            return Err(format!("pcg: b·b not finite ({bb})"));
        }
        let b_norm = bb.sqrt();
        if b_norm == 0.0 {
            return Ok(DeviceResidentPcgOutput {
                x: vec![0.0; p],
                iterations: 0,
                final_rel_residual: 0.0,
            });
        }

        // z₀ = M⁻¹ r₀
        // SAFETY: the argument list matches
        // `pcg_apply_diag_precond(int, double, const double*, const double*, double*)`,
        // and every slice holds `p` = `n` elements.
        unsafe {
            stream
                .launch_builder(&f_precond)
                .arg(&n_i32)
                .arg(&input.precond_diag_floor)
                .arg(&d_diag)
                .arg(&d_r)
                .arg(&mut d_z)
                .launch(vec_launch_config(vec_blocks))
        }
        .map_err(|e| format!("pcg precond z₀: {e}"))?;
        // p₀ = z₀
        // SAFETY: matches `pcg_copy(int, const double*, double*)`; both
        // slices hold `n` elements.
        unsafe {
            stream
                .launch_builder(&f_copy)
                .arg(&n_i32)
                .arg(&d_z)
                .arg(&mut d_p)
                .launch(vec_launch_config(vec_blocks))
        }
        .map_err(|e| format!("pcg copy p₀: {e}"))?;
        let mut rho = dot.eval(&d_r, &d_z, &mut d_scalar, "ρ₀", "")?;
        if !rho.is_finite() {
            return Err(format!("pcg: ρ₀ not finite ({rho})"));
        }

        let mut iters_taken: usize = 0;
        // r₀ = b, so the initial relative residual is exactly 1.
        let mut final_rel_residual: f64 = 1.0;
        for iter in 0..input.max_iters {
            iters_taken = iter + 1;
            let at = format!(" iter {iter}");
            // q = H p
            // NOTE(review): this HVP is launched without `stream`. Confirm it
            // is ordered with respect to the kernels queued on `stream`.
            launch_bms_flex_row_hvp_into_device(input.storage, &d_p, &mut d_q)
                .map_err(|e| format!("pcg Hv iter {iter}: {e}"))?;
            let pq = dot.eval(&d_p, &d_q, &mut d_scalar, "p·q", &at)?;
            // For an SPD operator pᵀHp > 0 whenever p ≠ 0. A zero or negative
            // curvature means CG's step length is meaningless.
            if !pq.is_finite() || pq <= 0.0 {
                return Err(format!(
                    "pcg iter {iter}: p·q={pq} (non-finite or non-positive); operator is not positive-definite"
                ));
            }
            let alpha = rho / pq;
            // x += α p
            // SAFETY: matches `pcg_axpy(int, double, const double*, double*)`;
            // both slices hold `n` elements.
            unsafe {
                stream
                    .launch_builder(&f_axpy)
                    .arg(&n_i32)
                    .arg(&alpha)
                    .arg(&d_p)
                    .arg(&mut d_x)
                    .launch(vec_launch_config(vec_blocks))
            }
            .map_err(|e| format!("pcg x+=αp iter {iter}: {e}"))?;
            // r -= α q
            let neg_alpha = -alpha;
            // SAFETY: as above.
            unsafe {
                stream
                    .launch_builder(&f_axpy)
                    .arg(&n_i32)
                    .arg(&neg_alpha)
                    .arg(&d_q)
                    .arg(&mut d_r)
                    .launch(vec_launch_config(vec_blocks))
            }
            .map_err(|e| format!("pcg r-=αq iter {iter}: {e}"))?;
            let rr = dot.eval(&d_r, &d_r, &mut d_scalar, "‖r‖₂²", &at)?;
            if !rr.is_finite() {
                return Err(format!("pcg iter {iter}: ‖r‖₂²={rr} non-finite"));
            }
            let rel = rr.sqrt() / b_norm;
            final_rel_residual = rel;
            if rel <= input.rel_tol {
                break;
            }
            // z = M⁻¹ r
            // SAFETY: matches `pcg_apply_diag_precond`; all slices hold `n` elements.
            unsafe {
                stream
                    .launch_builder(&f_precond)
                    .arg(&n_i32)
                    .arg(&input.precond_diag_floor)
                    .arg(&d_diag)
                    .arg(&d_r)
                    .arg(&mut d_z)
                    .launch(vec_launch_config(vec_blocks))
            }
            .map_err(|e| format!("pcg z=M⁻¹r iter {iter}: {e}"))?;
            let rho_new = dot.eval(&d_r, &d_z, &mut d_scalar, "ρ_new", &at)?;
            if !rho_new.is_finite() {
                return Err(format!("pcg iter {iter}: ρ_new={rho_new} non-finite"));
            }
            let beta_pcg = rho_new / rho;
            // p = z + β p
            // SAFETY: matches `pcg_axpby(int, double, const double*, double, double*)`;
            // both slices hold `n` elements.
            unsafe {
                stream
                    .launch_builder(&f_axpby)
                    .arg(&n_i32)
                    .arg(&1.0_f64)
                    .arg(&d_z)
                    .arg(&beta_pcg)
                    .arg(&mut d_p)
                    .launch(vec_launch_config(vec_blocks))
            }
            .map_err(|e| format!("pcg p=z+βp iter {iter}: {e}"))?;
            rho = rho_new;
        }
        let x_host = stream
            .clone_dtoh(&d_x)
            .map_err(|e| format!("pcg final x DtoH: {e}"))?;
        Ok(DeviceResidentPcgOutput {
            x: x_host,
            iterations: iters_taken,
            final_rel_residual,
        })
    }
}
/// Solves `H x = b` with device-resident Jacobi-preconditioned conjugate
/// gradient, where `H` is the joint row-structured Hessian in `input.storage`.
///
/// # Errors
/// Returns the error from `pcg_device::run`: invalid inputs, CUDA failures, or
/// numerical breakdown, including a non-positive-definite operator.
#[cfg(target_os = "linux")]
pub fn run_pcg_against_row_hessian_device(
    input: DeviceResidentPcgInput<'_>,
) -> Result<DeviceResidentPcgOutput, String> {
    let output = pcg_device::run(input)?;
    Ok(output)
}
mod cpu_fallback {
    //! Host implementations used when no CUDA runtime is available.
    use super::{PirlsGpuInput, PirlsGpuStep};
    use crate::linalg::faer_ndarray::FaerCholesky;
    use crate::solver::estimate::reml::assembly::xt_diag_x_dense_into;
    use faer::Side;
    use ndarray::{Array1, Array2, ArrayView1, ArrayView2};

    /// Computes `Xᵀ·diag(w)·X` on the host after validating shapes.
    pub(super) fn weighted_crossprod_cpu(
        x: ArrayView2<'_, f64>,
        weights: ArrayView1<'_, f64>,
    ) -> Result<Array2<f64>, String> {
        validate(x, weights)?;
        let design = x.to_owned();
        let w = weights.to_owned();
        let mut scratch = Array2::<f64>::zeros(design.dim());
        Ok(xt_diag_x_dense_into(&design, &w, &mut scratch))
    }

    /// Adds `shift` to every diagonal entry of `m`; a zero shift is skipped.
    fn add_to_diagonal(m: &mut Array2<f64>, shift: f64) {
        if shift != 0.0 {
            m.diag_mut().iter_mut().for_each(|d| *d += shift);
        }
    }

    /// Host PIRLS step.
    ///
    /// Returns:
    /// - `penalized_hessian = XᵀWX + S + objective_ridge·I`;
    /// - `direction` solving `(XᵀWX + S + step_lm_lambda·I)·d = gradient`;
    /// - `logdet = 2·Σ ln diag(L)` of that same LM-shifted Cholesky factor.
    ///
    /// NOTE(review): `logdet` belongs to the LM-shifted matrix, not to
    /// `penalized_hessian`. The two differ when
    /// `step_lm_lambda != objective_ridge`; confirm this matches the GPU path.
    ///
    /// # Errors
    /// Shape mismatches (via `validate`, penalty or gradient size) and Cholesky
    /// failure.
    pub(super) fn solve_step_cpu(input: PirlsGpuInput<'_>) -> Result<PirlsGpuStep, String> {
        validate(input.x, input.weights)?;
        let p = input.x.ncols();
        if input.penalty_hessian.dim() != (p, p) {
            return Err(format!(
                "penalty Hessian shape {:?} does not match p={p}",
                input.penalty_hessian.dim()
            ));
        }
        if input.gradient.len() != p {
            return Err(format!(
                "gradient length {} does not match p={p}",
                input.gradient.len()
            ));
        }
        // XᵀWX + S is common to both matrices below, so it is formed once.
        let mut base = weighted_crossprod_cpu(input.x, input.weights)?;
        base += &input.penalty_hessian;

        let mut penalized_hessian = base.clone();
        add_to_diagonal(&mut penalized_hessian, input.objective_ridge);

        let mut h_step = base;
        add_to_diagonal(&mut h_step, input.step_lm_lambda);

        let factor = h_step
            .cholesky(Side::Lower)
            .map_err(|e| format!("CPU Cholesky failed in PIRLS fallback: {e:?}"))?;
        let gradient: Array1<f64> = input.gradient.to_owned();
        let direction = factor.solvevec(&gradient);
        let logdet = 2.0 * factor.diag().iter().map(|d| d.ln()).sum::<f64>();
        Ok(PirlsGpuStep {
            penalized_hessian,
            direction,
            logdet,
        })
    }

    /// Requires one weight per row and a non-empty design.
    fn validate(x: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>) -> Result<(), String> {
        let (rows, cols) = x.dim();
        if weights.len() != rows {
            return Err(format!(
                "weights length {} does not match rows {rows}",
                weights.len()
            ));
        }
        match (rows, cols) {
            (0, _) | (_, 0) => Err("empty design cannot be solved".to_string()),
            _ => Ok(()),
        }
    }
}
/// Cholesky-solves `hessian · X = rhs`, delegating to
/// `crate::gpu::solver::cholesky_solve_gpu`.
///
/// Returns the solution matrix paired with an `f64` produced by that solver.
///
/// # Errors
/// Propagates the solver's error.
pub fn cholesky_solve_gpu(
    hessian: ArrayView2<'_, f64>,
    rhs: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, f64), String> {
    use crate::gpu::solver::cholesky_solve_gpu as device_solve;
    device_solve(hessian, rhs)
}
/// Returns the lower Cholesky factor of `hessian`, delegating to
/// `crate::gpu::solver::cholesky_lower_gpu`.
///
/// # Errors
/// Propagates the solver's error.
pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    use crate::gpu::solver::cholesky_lower_gpu as device_factor;
    device_factor(hessian)
}
#[cfg(all(test, target_os = "linux"))]
mod stream_device_parity_tests {
    //! GPU PIRLS stream tests. Each test returns early, and therefore passes,
    //! when `GpuRuntime::global()` reports no CUDA runtime.
    use super::*;
    use ndarray::arr2;
    /// Host-input vs device-input single PIRLS step on a 5×3 design.
    ///
    /// The host path receives the p-vector gradient `Xᵀ g_eta`. The device path
    /// receives the per-row `g_eta` and weights already on the device. Both use
    /// separate workspaces over one shared upload. The penalized Hessian
    /// (1e-10), log-determinant (1e-9) and direction (1e-9) must agree.
    // NOTE(review): the `_on_v100` suffix is not enforced; any CUDA device runs this.
    #[test]
    fn device_input_step_matches_host_input_step_on_v100() {
        if crate::gpu::runtime::GpuRuntime::global().is_none() {
            eprintln!("[stream_device_parity] no CUDA runtime — skipping");
            return;
        }
        let x = arr2(&[
            [1.0, 0.5, 0.1],
            [0.2, -0.3, 1.4],
            [0.7, 1.1, -0.2],
            [-0.4, 0.9, 0.6],
            [0.3, -0.8, 0.5],
        ]);
        let weights = ndarray::arr1(&[1.0, 0.8, 1.2, 0.9, 1.05]);
        let g_eta = ndarray::arr1(&[0.10_f64, -0.20, 0.05, 0.30, -0.15]);
        // Host-side gradient; the device path receives g_eta instead.
        let gradient: ndarray::Array1<f64> = x.t().dot(&g_eta);
        let penalty = arr2(&[[0.4, 0.0, 0.0], [0.0, 0.9, 0.0], [0.0, 0.0, 1.2]]);
        let lm_ridge = 0.1;
        let n = x.nrows();
        // Placeholders required by `upload_shared_pirls_gpu`; they are not the
        // quantities under test.
        let y_dummy = ndarray::Array1::<f64>::zeros(n);
        let prior_w_dummy = ndarray::Array1::<f64>::ones(n);
        let offset_dummy = ndarray::Array1::<f64>::zeros(n);
        let shared = upload_shared_pirls_gpu(
            x.view(),
            y_dummy.view(),
            prior_w_dummy.view(),
            offset_dummy.view(),
        )
        .expect("upload shared design");
        // Separate workspaces so neither path sees the other's scratch buffers.
        let mut ws_host = allocate_sigma_pirls_workspace(&shared).expect("alloc host-input ws");
        let mut ws_dev = allocate_sigma_pirls_workspace(&shared).expect("alloc device-input ws");
        let host_step = solve_pirls_step_on_stream(
            &shared,
            &mut ws_host,
            PirlsStepStreamInput {
                weights: weights.view(),
                penalty_hessian: penalty.view(),
                gradient: gradient.view(),
                step_lm_lambda: lm_ridge,
                objective_ridge: 0.0,
            },
        )
        .expect("host-input step");
        // Upload per-row weights and gradients on the device-input workspace's stream.
        let mut w_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc w_dev");
        let mut g_dev = ws_dev.stream.alloc_zeros::<f64>(n).expect("alloc g_dev");
        ws_dev
            .stream
            .memcpy_htod(weights.as_slice().unwrap(), &mut w_dev)
            .expect("upload w_dev");
        ws_dev
            .stream
            .memcpy_htod(g_eta.as_slice().unwrap(), &mut g_dev)
            .expect("upload g_dev");
        // β = 0 and a zero linear shift. Presumably this keeps any β- or
        // shift-dependent terms of the device path out of the comparison — confirm.
        let beta_dev_test = ws_dev
            .stream
            .alloc_zeros::<f64>(x.ncols())
            .expect("alloc beta_dev_test");
        let linear_shift_test = ndarray::Array1::<f64>::zeros(x.ncols());
        let dev_step = solve_pirls_step_on_stream_device(
            &shared,
            &mut ws_dev,
            PirlsStepStreamDeviceInput {
                w_solver_dev: &w_dev,
                grad_eta_dev: &g_dev,
                penalty_hessian: penalty.view(),
                step_lm_lambda: lm_ridge,
                objective_ridge: 0.0,
                beta_dev: &beta_dev_test,
                linear_shift: linear_shift_test.view(),
            },
        )
        .expect("device-input step");
        for i in 0..3 {
            for j in 0..3 {
                let diff = (host_step.penalized_hessian[[i, j]]
                    - dev_step.penalized_hessian[[i, j]])
                    .abs();
                assert!(diff <= 1e-10, "H[{i},{j}] mismatch: {diff}");
            }
        }
        assert!(
            (host_step.logdet - dev_step.logdet).abs() <= 1e-9,
            "logdet mismatch: host={} dev={}",
            host_step.logdet,
            dev_step.logdet
        );
        for i in 0..3 {
            let diff = (host_step.direction[i] - dev_step.direction[i]).abs();
            assert!(diff <= 1e-9, "direction[{i}] mismatch: {diff}");
        }
    }
    /// Throughput check on an n = 80 000, p = 44 Bernoulli-logit problem.
    ///
    /// Compares the GPU PIRLS loop (max 30 iterations, tol 1e-6) against 30
    /// hand-rolled CPU Newton iterations, and asserts a ≥10× speedup.
    // NOTE(review): wall-clock assertion. It is sensitive to machine load and
    // to any one-time setup on the first `pirls_loop_on_stream` call, which
    // falls inside the timed region. The GPU loop may also stop early at
    // tol = 1e-6 while the CPU loop always runs 30 iterations. Consider
    // `#[ignore]` or a benchmark harness if this flakes.
    #[test]
    fn hill_climb_loop_beats_cpu_10x_on_biobank_logit() {
        use crate::gpu::pirls_row::{CurvatureMode, PirlsRowFamily, RowInput, row_reweight_cpu};
        use std::time::Instant;
        if crate::gpu::runtime::GpuRuntime::global().is_none() {
            eprintln!("[hill_climb] no CUDA runtime — skipping");
            return;
        }
        let n = 80_000_usize;
        let p = 44_usize;
        let beta_true: ndarray::Array1<f64> = ndarray::Array1::from_iter(
            (0..p).map(|j| 0.05 * ((j as f64) - 0.5 * p as f64) / p as f64),
        );
        // Deterministic smooth design: x[i, j] = sin((i + 17 j) · 0.001).
        let mut x = ndarray::Array2::<f64>::zeros((n, p));
        for i in 0..n {
            for j in 0..p {
                x[[i, j]] = ((i as f64 + j as f64 * 17.0) * 0.001).sin();
            }
        }
        let eta: ndarray::Array1<f64> = x.dot(&beta_true);
        // μ = ½(1 + tanh(η/2)) is the logistic function; y is a deterministic
        // pseudo-Bernoulli draw thresholded on fract(1.31·i).
        let y: ndarray::Array1<f64> = eta
            .iter()
            .enumerate()
            .map(|(i, &e)| {
                let mu = 0.5 * (1.0 + (0.5 * e).tanh());
                if (i as f64 * 1.31).fract() < mu {
                    1.0
                } else {
                    0.0
                }
            })
            .collect();
        let prior_w = ndarray::Array1::<f64>::ones(n);
        let penalty = ndarray::Array2::<f64>::eye(p) * 1e-3;
        let beta0 = ndarray::Array1::<f64>::zeros(p);
        let offset_bench = ndarray::Array1::<f64>::zeros(n);
        let shared =
            upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_bench.view())
                .expect("upload shared design");
        let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
        let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
        let t0 = Instant::now();
        let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
        drop(
            pirls_loop_on_stream(
                &shared,
                &mut ws,
                &mut loop_ws,
                PirlsRowFamily::BernoulliLogit,
                CurvatureMode::Fisher,
                1.0,
                beta0.view(),
                penalty.view(),
                linear_shift_zero.view(),
                0.0,
                0.0,
                0.0,
                30,
                1e-6,
                None,
            )
            .expect("pirls loop"),
        );
        let gpu_secs = t0.elapsed().as_secs_f64();
        // CPU reference: per-row reweighting, explicit XᵀWX + S, Cholesky solve
        // of H d = Xᵀ g, then β -= d; always 30 iterations.
        let t1 = Instant::now();
        let mut beta = ndarray::Array1::<f64>::zeros(p);
        for _ in 0..30 {
            let eta: ndarray::Array1<f64> = x.dot(&beta);
            let mut w = ndarray::Array1::<f64>::zeros(n);
            let mut g = ndarray::Array1::<f64>::zeros(n);
            for i in 0..n {
                let out = row_reweight_cpu(
                    PirlsRowFamily::BernoulliLogit,
                    CurvatureMode::Fisher,
                    RowInput {
                        eta: eta[i],
                        y: y[i],
                        prior_weight: prior_w[i],
                    },
                    1.0,
                );
                w[i] = out.w_solver;
                g[i] = out.grad_eta;
            }
            let mut wx_full = x.clone();
            for j in 0..p {
                for i in 0..n {
                    wx_full[[i, j]] *= w[i];
                }
            }
            let h = x.t().dot(&wx_full) + &penalty;
            let rhs = x.t().dot(&g);
            use crate::linalg::faer_ndarray::FaerCholesky;
            let chol = h
                .cholesky(faer::Side::Lower)
                .expect("CPU PIRLS reference Cholesky");
            let d = chol.solvevec(&rhs);
            for i in 0..p {
                beta[i] -= d[i];
            }
        }
        let cpu_secs = t1.elapsed().as_secs_f64();
        let speedup = cpu_secs / gpu_secs;
        eprintln!(
            "[hill_climb] n={n} p={p} BernoulliLogit/Fisher: gpu={:.3}s cpu={:.3}s speedup={:.2}×",
            gpu_secs, cpu_secs, speedup
        );
        assert!(
            speedup >= 10.0,
            "GPU PIRLS loop must be ≥10× CPU at biobank shape; got speedup={speedup:.2}× (gpu={gpu_secs:.3}s cpu={cpu_secs:.3}s)"
        );
    }
    /// With Gaussian-identity rows, noiseless `y = Xβ_true` and `S = 1e-4·I`,
    /// the GPU PIRLS loop must reproduce the ridge normal-equation solution
    /// `(XᵀX + S)⁻¹ Xᵀy` (to 1e-6) and its Hessian `XᵀX + S` (to 1e-8).
    #[test]
    fn pirls_loop_converges_to_ols_solution_on_gaussian_identity() {
        if crate::gpu::runtime::GpuRuntime::global().is_none() {
            eprintln!("[stage_3_3] no CUDA runtime — skipping");
            return;
        }
        let x = arr2(&[
            [1.0, 0.5, 0.1],
            [0.2, -0.3, 1.4],
            [0.7, 1.1, -0.2],
            [-0.4, 0.9, 0.6],
            [0.3, -0.8, 0.5],
            [1.1, 0.2, -0.4],
            [-0.6, 0.4, 0.3],
            [0.8, -1.0, 0.7],
        ]);
        let n = x.nrows();
        let p = x.ncols();
        let beta_true = ndarray::arr1(&[0.5_f64, -1.2, 0.3]);
        let y: ndarray::Array1<f64> = x.dot(&beta_true);
        let prior_w = ndarray::Array1::<f64>::ones(n);
        let penalty = ndarray::Array2::<f64>::eye(p) * 1e-4; let beta0 = ndarray::Array1::<f64>::zeros(p);
        let offset_ols = ndarray::Array1::<f64>::zeros(n);
        let shared = upload_shared_pirls_gpu(x.view(), y.view(), prior_w.view(), offset_ols.view())
            .expect("upload shared design");
        let mut ws = allocate_sigma_pirls_workspace(&shared).expect("alloc ws");
        let mut loop_ws = allocate_pirls_loop_workspace(&shared, &ws).expect("alloc loop_ws");
        let linear_shift_zero = ndarray::Array1::<f64>::zeros(p);
        let outcome = pirls_loop_on_stream(
            &shared,
            &mut ws,
            &mut loop_ws,
            crate::gpu::pirls_row::PirlsRowFamily::GaussianIdentity,
            crate::gpu::pirls_row::CurvatureMode::Fisher,
            1.0,
            beta0.view(),
            penalty.view(),
            linear_shift_zero.view(),
            0.0,
            0.0,
            0.0,
            20,
            1e-9,
            None,
        )
        .expect("pirls loop");
        // Dense host reference: Cholesky solve of (XᵀX + S) β = Xᵀy.
        let xtx = x.t().dot(&x);
        let xty = x.t().dot(&y);
        let h_ref = xtx + &penalty;
        use crate::linalg::faer_ndarray::FaerCholesky;
        let chol = h_ref
            .cholesky(faer::Side::Lower)
            .expect("OLS reference Cholesky");
        let beta_ref: ndarray::Array1<f64> = chol.solvevec(&xty);
        // NOTE(review): this passes for any run of <= 5 iterations even when
        // `converged` is false; the β comparison below is the effective guard.
        assert!(
            outcome.converged || outcome.iterations <= 5,
            "PIRLS loop did not converge in 20 iters on Gaussian-identity (iters={})",
            outcome.iterations
        );
        for i in 0..p {
            let diff = (outcome.beta[i] - beta_ref[i]).abs();
            assert!(
                diff <= 1e-6,
                "β[{i}] mismatch: gpu={} ref={} diff={}",
                outcome.beta[i],
                beta_ref[i],
                diff
            );
        }
        for i in 0..p {
            for j in 0..p {
                let diff = (outcome.penalized_hessian[[i, j]] - h_ref[[i, j]]).abs();
                assert!(diff <= 1e-8, "H[{i},{j}] mismatch: {diff}");
            }
        }
    }
}
#[cfg(all(test, target_os = "linux"))]
mod pcg_device_parity_tests {
    //! Parity of the device-resident PCG against a dense host reference.
    use super::*;
    use crate::gpu::bms_flex_row::{
        BmsFlexBlockLayout, BmsFlexPrimaryLayout, DeviceResidentRowHess,
    };
    use ndarray::Array2;
    /// Dense host reference for the joint Hessian `Σ_rows Φ_rowᵀ H_row Φ_row`.
    ///
    /// Storage layout:
    /// - `row_hessians` holds one row-major `r×r` block per row;
    /// - `marginal` and `logslope` are row-major `n×p_m` and `n×p_g` designs.
    ///
    /// For each row, the `r×p_total` map Φ sends:
    /// - primary 0 to the marginal design row (columns `0..p_m`);
    /// - primary 1 to the logslope design row (columns `p_m..p_m+p_g`);
    /// - the `h`/`w` primaries one-to-one onto the `h`/`w` coefficient blocks.
    // NOTE(review): assumes r >= 2 and that the logslope block starts at
    // column p_m, as this fixture's layout does.
    fn cpu_dense_joint_hessian(
        row_hessians: &[f64],
        marginal: &[f64],
        logslope: &[f64],
        block: &BmsFlexBlockLayout,
        primary: &BmsFlexPrimaryLayout,
        n: usize,
    ) -> Array2<f64> {
        let p_total = block.p_total;
        let r = primary.r;
        let p_m = block.p_m;
        let p_g = block.p_g;
        let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
        let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
        let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
        let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
        let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
        let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
        let mut h_dense = Array2::<f64>::zeros((p_total, p_total));
        // phi[u][m] = Φ[u, m], rebuilt for every row.
        let mut phi = vec![vec![0.0_f64; p_total]; r];
        for row in 0..n {
            for col in phi.iter_mut() {
                col.iter_mut().for_each(|v| *v = 0.0);
            }
            let mrow = &marginal[row * p_m..(row + 1) * p_m];
            let grow = &logslope[row * p_g..(row + 1) * p_g];
            for k in 0..p_m {
                phi[0][k] = mrow[k];
            }
            for k in 0..p_g {
                phi[1][p_m + k] = grow[k];
            }
            for k in 0..h_block_len {
                phi[h_primary_start + k][h_block_start + k] = 1.0;
            }
            for k in 0..w_block_len {
                phi[w_primary_start + k][w_block_start + k] = 1.0;
            }
            let h_row = &row_hessians[row * r * r..(row + 1) * r * r];
            // Accumulate Σ_{u,v} Φ[u,m] · H[u,v] · Φ[v,nn], skipping zero terms.
            for u in 0..r {
                for v in 0..r {
                    let huv = h_row[u * r + v];
                    if huv == 0.0 {
                        continue;
                    }
                    for m in 0..p_total {
                        let phim = phi[u][m];
                        if phim == 0.0 {
                            continue;
                        }
                        let scaled = huv * phim;
                        for nn in 0..p_total {
                            h_dense[[m, nn]] += scaled * phi[v][nn];
                        }
                    }
                }
            }
        }
        h_dense
    }
    /// Host PCG solve of `h x = b` via `solve_spd_pcg_with_info`. The diagonal
    /// of `h` is passed as the preconditioner diagonal and at most `4p`
    /// iterations are allowed. Panics unless the solve reports convergence.
    fn cpu_pcg_oracle(h: &Array2<f64>, b: &[f64], rel_tol: f64) -> Vec<f64> {
        let p = b.len();
        let diag: ndarray::Array1<f64> =
            ndarray::Array1::from_vec((0..p).map(|i| h[[i, i]]).collect());
        let rhs = ndarray::Array1::from_vec(b.to_vec());
        let h_owned = h.clone();
        let apply = move |v: &ndarray::Array1<f64>| h_owned.dot(v);
        let (x, info) =
            crate::linalg::utils::solve_spd_pcg_with_info(apply, &rhs, &diag, rel_tol, 4 * p)
                .expect("host PCG oracle must converge on SPD fixture");
        assert!(
            info.converged,
            "host PCG oracle failed to converge: iters={} rel_res={}",
            info.iterations, info.relative_residual_norm
        );
        x.to_vec()
    }
    /// Device PCG (rel_tol 1e-10) must match the dense host PCG oracle
    /// (rel_tol 1e-12) to ‖Δx‖∞ <= 1e-7 on an n = 64, r = 20, p_total = 44
    /// fixture. Missing runtime, context or upload failures skip the test
    /// rather than fail it.
    #[test]
    fn pcg_device_matches_dense_oracle_at_n64_r20_p44() {
        let Some(_runtime) = crate::gpu::runtime::GpuRuntime::global() else {
            eprintln!("[pcg_device parity] no CUDA runtime — skipping");
            return;
        };
        let n = 64_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        // Coefficient blocks laid out as [marginal | logslope | h | w].
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        // Primaries laid out as [marginal, logslope, h.., w..].
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };
        // Row Hessians: raw entries satisfy |a| <= 1.3. After symmetrizing and
        // adding 4r to the diagonal, each block is strictly diagonally dominant
        // with a positive diagonal, hence SPD.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for v in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + v] = a;
                }
            }
            for u in 0..r {
                for v in (u + 1)..r {
                    let upper = row_hessians[base + u * r + v];
                    let lower = row_hessians[base + v * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + v] = sym;
                    row_hessians[base + v * r + u] = sym;
                }
                row_hessians[base + u * r + u] += 4.0 * (r as f64);
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }
        let b: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();
        // The joint Hessian is a sum of ΦᵀHΦ terms with SPD H. Presumably it
        // is SPD for this fixture; the oracle's convergence assert backs that.
        let h_dense =
            cpu_dense_joint_hessian(&row_hessians, &marginal, &logslope, &block, &primary, n);
        let x_oracle = cpu_pcg_oracle(&h_dense, &b, 1e-12);
        let runtime = crate::gpu::runtime::GpuRuntime::global()
            .expect("runtime must exist when probe succeeded above");
        let ctx = match crate::gpu::runtime::cuda_context_for(runtime.selected_device().ordinal) {
            Some(c) => c,
            None => {
                eprintln!("[pcg_device parity] cuda_context_for failed; skipping");
                return;
            }
        };
        let stream = ctx.default_stream();
        let d_h = match stream.clone_htod(&row_hessians) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[pcg_device parity] upload h failed: {err}");
                return;
            }
        };
        let d_m = match stream.clone_htod(&marginal) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[pcg_device parity] upload marginal failed: {err}");
                return;
            }
        };
        let d_g = match stream.clone_htod(&logslope) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[pcg_device parity] upload logslope failed: {err}");
                return;
            }
        };
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block,
            primary,
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let out = run_pcg_against_row_hessian_device(DeviceResidentPcgInput {
            storage: &storage,
            b: &b,
            rel_tol: 1e-10,
            max_iters: 4 * p_total,
            precond_diag_floor: 1e-12,
        })
        .expect("device-resident PCG must succeed on SPD fixture");
        assert_eq!(out.x.len(), p_total);
        let mut max_abs = 0.0_f64;
        for i in 0..p_total {
            let diff = (out.x[i] - x_oracle[i]).abs();
            if diff > max_abs {
                max_abs = diff;
            }
        }
        assert!(
            max_abs <= 1e-7,
            "pcg_device parity ‖Δx‖∞={max_abs:.3e} > 1e-7 after {} iters \
             (final rel residual={:.3e})",
            out.iterations,
            out.final_rel_residual
        );
        eprintln!(
            "[pcg_device parity] n={n} p={p_total} r={r}: iters={} rel_res={:.3e} ‖Δx‖∞={:.3e}",
            out.iterations, out.final_rel_residual, max_abs
        );
    }
}