#![allow(clippy::module_name_repetitions)]
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::solver::arrow_schur::ArrowSchurSystem;
/// Newton step produced by an arrow-Schur solve of the bordered system
///
/// ```text
/// [ Htt + ridge_t I   Htbeta           ] [ delta_t    ]     [ gt ]
/// [ Htbeta^T          Hbb + ridge_beta ] [ delta_beta ] = - [ gb ]
/// ```
///
/// where `Htt` is block-diagonal with one dense `d x d` block per row.
#[derive(Debug, Clone)]
pub struct ArrowSchurGpuSolution {
    /// Step for the per-row parameters: entries `i * d .. (i + 1) * d` belong to row `i`
    /// (total length `n * d`).
    pub delta_t: Array1<f64>,
    /// Step for the shared parameters (length `k`).
    pub delta_beta: Array1<f64>,
    /// `ln det` of the ridged Hessian, accumulated as `2 * sum(ln(diag(L)))` over every
    /// Cholesky factor used by the solve.
    pub log_det_hessian: f64,
}
/// Reasons the GPU arrow-Schur solve can decline or fail.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
/// No usable CUDA path: a non-Linux build, no routed device or launch plan, or a
/// CUDA/cuBLAS/cuSOLVER call (handle creation, allocation, copy, launch) that failed.
Unavailable,
/// A per-row `d x d` block failed to factor. `row` is the first failing row index and
/// `bump` is a suggested ridge increase, scaled from that block's largest absolute
/// diagonal entry (at least 1) and the reported pivot/status value.
RidgeBumpRequired { row: usize, bump: f64 },
/// Input validation failed, NVRTC compilation failed, or the shared `k x k` Schur
/// complement could not be factored; `reason` is a human-readable description.
SchurFactorFailed { reason: String },
/// The system supplies matrix-free operators instead of dense blocks, which the GPU
/// path cannot consume. The flags record which operators were present.
GpuRequiresDenseSystem {
had_hbb_matvec: bool,
had_htbeta_matvec: bool,
},
}
/// Solves one Newton step of the arrow-structured system on the GPU.
///
/// The system is
///
/// ```text
/// [ Htt + ridge_t I   Htbeta           ] [ delta_t    ]     [ gt ]
/// [ Htbeta^T          Hbb + ridge_beta ] [ delta_beta ] = - [ gb ]
/// ```
///
/// where `Htt` is block-diagonal with one `d x d` block per row.
///
/// On Linux, the fused NVRTC kernel is tried first when the system admits it. Any fused
/// failure other than [`ArrowSchurGpuFailure::RidgeBumpRequired`] falls back to the
/// cuSOLVER/cuBLAS path. Other targets return [`ArrowSchurGpuFailure::Unavailable`]
/// after validation.
///
/// # Errors
///
/// * [`ArrowSchurGpuFailure::GpuRequiresDenseSystem`] if either matrix-free operator is set.
/// * [`ArrowSchurGpuFailure::SchurFactorFailed`] in any of these cases:
///   - `hbb` is not `k x k`;
///   - a row block has the wrong shape;
///   - either ridge is NaN;
///   - the Schur complement cannot be factored.
/// * [`ArrowSchurGpuFailure::Unavailable`] if there are no rows, `d == 0`, or no GPU path.
/// * [`ArrowSchurGpuFailure::RidgeBumpRequired`] if a row block fails to factor.
pub fn solve_arrow_newton_step(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    let had_hbb_matvec = sys.hbb_matvec.is_some();
    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
    if had_hbb_matvec || had_htbeta_matvec {
        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
            had_hbb_matvec,
            had_htbeta_matvec,
        });
    }
    if sys.hbb.dim() != (k, k) {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
        });
    }
    if n == 0 || d == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    if sys
        .rows
        .iter()
        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
    {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "row block dimension mismatch".to_string(),
        });
    }
    // Each ridge is added to diagonal entries; a NaN would poison the factorization.
    // Check on every target so the GPU path is never launched with a NaN shift.
    if ridge_t.is_nan() || ridge_beta.is_nan() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "ridge is NaN".to_string(),
        });
    }
    #[cfg(not(target_os = "linux"))]
    {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        if crate::gpu::arrow_schur_nvrtc::system_admits_fused_path(sys) {
            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
                Ok(sol) => return Ok(sol),
                // The row blocks are the same on both paths, so surface a factorization
                // failure instead of retrying it unfused.
                Err(bump @ ArrowSchurGpuFailure::RidgeBumpRequired { .. }) => return Err(bump),
                // Any other fused failure is best-effort: fall back to the batched path.
                Err(_) => {}
            }
        }
        cuda::solve(sys, ridge_t, ridge_beta)
    }
}
/// Packs every row block into three flat column-major host buffers, one contiguous chunk
/// per row:
/// - the ridged `d x d` `htt` blocks;
/// - the `d x k` `htbeta` blocks;
/// - the length-`d` `gt` vectors.
#[cfg(target_os = "linux")]
fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let (rows, d, k) = (sys.rows.len(), sys.d, sys.k);
    let mut buffers = (
        Vec::with_capacity(rows * d * d),
        Vec::with_capacity(rows * d * k),
        Vec::with_capacity(rows * d),
    );
    sys.rows.iter().for_each(|block| {
        pack_block(
            block,
            ridge_t,
            d,
            k,
            &mut buffers.0,
            &mut buffers.1,
            &mut buffers.2,
        )
    });
    buffers
}
/// Appends one row block to the packed buffers in column-major order:
/// - the `d x d` `htt` block, with `ridge_t` added on its diagonal;
/// - the `d x k` `htbeta` block;
/// - the `gt` entries.
#[cfg(target_os = "linux")]
#[inline]
fn pack_block(
    row: &crate::solver::arrow_schur::ArrowRowBlock,
    ridge_t: f64,
    d: usize,
    k: usize,
    d_buf: &mut Vec<f64>,
    b_buf: &mut Vec<f64>,
    g_buf: &mut Vec<f64>,
) {
    d_buf.extend((0..d).flat_map(move |j| {
        (0..d).map(move |i| {
            let entry = row.htt[[i, j]];
            // Off-diagonal entries are copied untouched (no `+ 0.0`, which would flip -0.0).
            if i == j {
                entry + ridge_t
            } else {
                entry
            }
        })
    }));
    b_buf.extend((0..k).flat_map(move |j| (0..d).map(move |i| row.htbeta[[i, j]])));
    g_buf.extend((0..d).map(|i| row.gt[i]));
}
/// Test hook that runs the fused NVRTC path unconditionally.
///
/// It skips the `system_admits_fused_path` check and the unfused fallback used by
/// [`solve_arrow_newton_step`].
///
/// Returns `SchurFactorFailed` for a NaN ridge. Returns `Unavailable` when no fused launch
/// plan exists for the system's `(n, d, k)`, or on non-Linux targets.
#[doc(hidden)]
pub fn solve_arrow_newton_step_fused_force(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let ridge_is_nan = [ridge_t, ridge_beta].iter().any(|ridge| ridge.is_nan());
    if ridge_is_nan {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "ridge is NaN".to_string(),
        });
    }
    let has_plan =
        crate::gpu::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
            .is_some();
    if !has_plan {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    #[cfg(not(target_os = "linux"))]
    {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        cuda::solve_fused(sys, ridge_t, ridge_beta)
    }
}
/// Builds a Schur-complement matvec operator backed by the fused GPU forward pass.
///
/// Only a matrix-free `htbeta_matvec` is rejected here, with
/// [`ArrowSchurGpuFailure::GpuRequiresDenseSystem`]. A matrix-free `hbb_matvec` is handed
/// on to the CUDA builder. Non-Linux targets always return
/// [`ArrowSchurGpuFailure::Unavailable`].
pub fn gpu_schur_matvec_backend(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<crate::solver::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
    let coupling_is_matrix_free = sys.htbeta_matvec.is_some();
    if coupling_is_matrix_free {
        let had_hbb_matvec = sys.hbb_matvec.is_some();
        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
            had_hbb_matvec,
            had_htbeta_matvec: true,
        });
    }
    #[cfg(not(target_os = "linux"))]
    {
        // No CUDA backend on this target; the ridges are intentionally unused.
        let _ = (ridge_t, ridge_beta);
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
    }
}
/// CPU reference solver for the GPU arrow-Schur paths.
///
/// Assembles the full ridged system densely, factors it with a host Cholesky, and returns
/// the same quantities as [`solve_arrow_newton_step`]: the Newton step for both parameter
/// groups and `ln det` of the ridged Hessian.
///
/// # Errors
///
/// Returns `Err` with a description if:
/// - `hbb`, `gb`, or any row block disagrees with `sys.d` / `sys.k`;
/// - the assembled dimension overflows `usize`;
/// - the assembled matrix is not numerically positive definite.
#[doc(hidden)]
pub fn solve_arrow_newton_step_dense_reference(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, String> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    // Validate shapes first so a malformed system yields `Err` rather than an
    // out-of-bounds panic during assembly.
    if sys.hbb.dim() != (k, k) || sys.gb.len() != k {
        return Err("dense reference requires a k x k hbb and a length-k gb".to_string());
    }
    if sys
        .rows
        .iter()
        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
    {
        return Err("row block dimension mismatch".to_string());
    }
    let nd = n.checked_mul(d).ok_or("dimension overflow")?;
    let total = nd.checked_add(k).ok_or("dimension overflow")?;
    // The dense matrix has `total * total` entries; reject element counts that overflow.
    total.checked_mul(total).ok_or("dimension overflow")?;
    let mut h = Array2::<f64>::zeros((total, total));
    let mut rhs = Array1::<f64>::zeros(total);
    for (i, row) in sys.rows.iter().enumerate() {
        let base = i * d;
        for c in 0..d {
            for r in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            h[[base + c, base + c]] += ridge_t;
        }
        for c in 0..k {
            for r in 0..d {
                let value = row.htbeta[[r, c]];
                h[[base + r, nd + c]] = value;
                h[[nd + c, base + r]] = value;
            }
        }
        for r in 0..d {
            rhs[base + r] = -row.gt[r];
        }
    }
    for c in 0..k {
        for r in 0..k {
            h[[nd + r, nd + c]] += sys.hbb[[r, c]];
        }
        h[[nd + c, nd + c]] += ridge_beta;
        rhs[nd + c] = -sys.gb[c];
    }
    let factor = cholesky_lower_host(h.view())
        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
    // det(H) = det(L)^2, so ln det(H) = 2 * sum(ln L_ii).
    let mut log_det = 0.0_f64;
    for i in 0..total {
        log_det += factor[[i, i]].ln();
    }
    log_det *= 2.0;
    let solved = solve_cholesky_lower_host(factor.view(), rhs.view());
    Ok(ArrowSchurGpuSolution {
        delta_t: solved.slice(ndarray::s![..nd]).to_owned(),
        delta_beta: solved.slice(ndarray::s![nd..]).to_owned(),
        log_det_hessian: log_det,
    })
}
/// Computes the lower Cholesky factor `L` (with `A = L L^T`) of a square matrix, row by
/// row (Cholesky–Banachiewicz order).
///
/// Only the lower triangle of `a` (diagonal included) is read; the strict upper triangle
/// of the result is zero.
///
/// Returns `None` if `a` is not square, or if any pivot is not a finite positive number,
/// i.e. the matrix is not numerically positive definite or contains NaN/infinity.
#[inline]
fn cholesky_lower_host(a: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
    let n = a.nrows();
    if n != a.ncols() {
        return None;
    }
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut sum = a[[i, j]];
            for kk in 0..j {
                sum -= l[[i, kk]] * l[[j, kk]];
            }
            if i == j {
                // `sum <= 0.0` is false for NaN, so require a positive finite pivot
                // explicitly; otherwise NaN/inf would propagate silently into `L`.
                if !(sum > 0.0 && sum.is_finite()) {
                    return None;
                }
                l[[i, i]] = sum.sqrt();
            } else {
                l[[i, j]] = sum / l[[j, j]];
            }
        }
    }
    Some(l)
}
#[inline]
fn solve_cholesky_lower_host(l: ArrayView2<'_, f64>, rhs: ArrayView1<'_, f64>) -> Array1<f64> {
let n = l.nrows();
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
let mut sum = rhs[i];
for j in 0..i {
sum -= l[[i, j]] * y[j];
}
y[i] = sum / l[[i, i]];
}
let mut x = Array1::<f64>::zeros(n);
for i in (0..n).rev() {
let mut sum = y[i];
for j in (i + 1)..n {
sum -= l[[j, i]] * x[j];
}
x[i] = sum / l[[i, i]];
}
x
}
#[cfg(target_os = "linux")]
mod cuda {
use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_host};
use crate::gpu::driver::to_i32;
use crate::gpu::linalg::{DispatchOp, route_through_gpu};
use crate::solver::arrow_schur::ArrowSchurSystem;
use cudarc::cublas::sys::{
cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
};
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
use ndarray::Array1;
use std::sync::Arc;
pub(super) fn solve(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let stream = crate::gpu::runtime::cuda_context_for(runtime.device.ordinal)
.and_then(|ctx| ctx.new_stream().ok())
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let solver =
DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
let mut d_dev = stream
.clone_htod(&d_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut b_dev = stream
.clone_htod(&b_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut g_dev = stream
.clone_htod(&g_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
if let Some(idx) = info_host.iter().position(|info| *info != 0) {
let pivot = info_host[idx];
let scale = sys.rows[idx]
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
row: idx,
bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
});
}
trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
let schur_init: Vec<f64> = {
let mut tmp = Vec::with_capacity(k * k);
for col in 0..k {
for row in 0..k {
let mut v = sys.hbb[[row, col]];
if row == col {
v += ridge_beta;
}
tmp.push(v);
}
}
tmp
};
let mut schur_dev = stream
.clone_htod(&schur_init)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
let mut rhs_dev = stream
.clone_htod(&rhs_init)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
if info != 0 {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("Schur Cholesky failed at pivot {info}"),
});
}
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
let delta_beta_host = stream
.clone_dtoh(&rhs_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let delta_beta = Array1::from_vec(delta_beta_host.clone());
accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
let x_host = stream
.clone_dtoh(&g_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut delta_t = Array1::<f64>::zeros(n * d);
for (i, v) in x_host.iter().enumerate() {
delta_t[i] = -*v;
}
let l_local_host = stream
.clone_dtoh(&d_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let l_schur_host = stream
.clone_dtoh(&schur_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut log_det = 0.0_f64;
for i in 0..n {
let base = i * d * d;
for j in 0..d {
log_det += l_local_host[base + j * d + j].ln();
}
}
for j in 0..k {
log_det += l_schur_host[j * k + j].ln();
}
log_det *= 2.0;
Ok(ArrowSchurGpuSolution {
delta_t,
delta_beta,
log_det_hessian: log_det,
})
}
/// Factors `batch` contiguous column-major `p x p` matrices in place with
/// `cusolverDnDpotrfBatched` (lower triangle).
///
/// Returns the per-matrix `info` array copied back to the host (`0` means success).
/// Any non-success status or CUDA error maps to `Unavailable`.
fn potrf_batched(
    solver: &DnHandle,
    stream: &Arc<CudaStream>,
    p: usize,
    batch: usize,
    matrices: &mut CudaSlice<f64>,
) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
    let dim = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let count = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stride_bytes = (p * p * std::mem::size_of::<f64>()) as u64;
    // The guard stays alive until the end of the function, i.e. past the solver call.
    let (first, _matrices_guard) = matrices.device_ptr_mut(stream);
    let pointer_table: Vec<_> = (0..batch as u64)
        .map(|slot| first + slot * stride_bytes)
        .collect();
    let mut pointers_dev = stream
        .clone_htod(&pointer_table)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut info_dev = stream
        .alloc_zeros::<i32>(batch)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let status = {
        let (table_ptr, _table_guard) = pointers_dev.device_ptr_mut(stream);
        let (info_ptr, _info_guard) = info_dev.device_ptr_mut(stream);
        // SAFETY: the pointer table holds `batch` addresses of `p x p` blocks inside
        // `matrices`, and `info_dev` has `batch` slots; all guards are alive here.
        unsafe {
            cusolver_sys::cusolverDnDpotrfBatched(
                solver.cu(),
                cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                dim,
                table_ptr as *mut *mut f64,
                dim,
                info_ptr as *mut i32,
                count,
            )
        }
    };
    if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
        stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
    } else {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
}
/// Factors one column-major `p x p` matrix in place with `cusolverDnDpotrf` (lower
/// triangle), returning the cuSOLVER `info` value (`0` on success).
///
/// First queries the workspace size, then allocates a zero-initialized `f64` workspace of
/// at least one element. Any non-success cuSOLVER status, or any CUDA allocation or copy
/// error, maps to `ArrowSchurGpuFailure::Unavailable`.
fn potrf_single(
solver: &DnHandle,
stream: &Arc<CudaStream>,
p: usize,
matrix: &mut CudaSlice<f64>,
) -> Result<i32, ArrowSchurGpuFailure> {
let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
let mut lwork = 0_i32;
// Workspace-size query; the pointer guard is scoped so it is dropped before the
// factorization below takes a fresh one.
{
let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
let status = unsafe {
cusolver_sys::cusolverDnDpotrf_bufferSize(
solver.cu(),
uplo,
p_i,
mat_ptr as *mut f64,
p_i,
&mut lwork,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
}
let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
// `.max(1)` keeps the allocation non-empty even when `lwork == 0`
// (presumably to avoid a zero-length allocation — TODO confirm it is required).
let mut workspace = stream
.alloc_zeros::<f64>(lwork_usize.max(1))
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut info_dev = stream
.alloc_zeros::<i32>(1)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
// Factorization; the guards are held until the call has been issued.
{
let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
let status = unsafe {
cusolver_sys::cusolverDnDpotrf(
solver.cu(),
uplo,
p_i,
mat_ptr as *mut f64,
p_i,
work_ptr as *mut f64,
lwork,
info_ptr as *mut i32,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
}
// Copy the single `info` value back to the host; the caller interprets it.
let info_host = stream
.clone_dtoh(&info_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
Ok(info_host[0])
}
/// Batched in-place forward solve `X_i <- L_i^{-1} X_i` over `n` row blocks.
///
/// Each `L_i` is a `d x d` lower factor stored contiguously in `l_stack`. Each `X_i` is a
/// `d x nrhs` block in `rhs_stack`.
fn trsm_batched_lower_inplace(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    d: usize,
    n: usize,
    nrhs: usize,
    l_stack: &CudaSlice<f64>,
    rhs_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let apply_transpose = false;
    trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, apply_transpose)
}
/// Batched in-place backward solve `X_i <- L_i^{-T} X_i` over `n` row blocks.
///
/// Uses the same lower factors and block layout as [`trsm_batched_lower_inplace`].
fn trsm_batched_lower_inplace_transposed(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    d: usize,
    n: usize,
    nrhs: usize,
    l_stack: &CudaSlice<f64>,
    rhs_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let apply_transpose = true;
    trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, apply_transpose)
}
/// Shared body of the batched triangular solves.
///
/// Builds device arrays of per-block pointers into `l_stack` (stride `d * d`) and
/// `rhs_stack` (stride `d * nrhs`). It then calls `cublasDtrsmBatched` with a left-side,
/// lower, non-unit factor, applying `L` when `transposed` is false and `L^T` otherwise.
fn trsm_batched_inplace_inner(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    d: usize,
    n: usize,
    nrhs: usize,
    l_stack: &CudaSlice<f64>,
    rhs_stack: &mut CudaSlice<f64>,
    transposed: bool,
) -> Result<(), ArrowSchurGpuFailure> {
    let (rows_i, cols_i, count_i) = match (to_i32(d), to_i32(nrhs), to_i32(n)) {
        (Some(rows), Some(cols), Some(count)) => (rows, cols, count),
        _ => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    let one = 1.0_f64;
    let elem = std::mem::size_of::<f64>();
    let factor_stride = (d * d * elem) as u64;
    let rhs_stride = (d * nrhs * elem) as u64;
    let (factor_origin, _factor_guard) = l_stack.device_ptr(stream);
    let (rhs_origin, _rhs_guard) = rhs_stack.device_ptr_mut(stream);
    let (factor_table, rhs_table): (Vec<_>, Vec<_>) = (0..n as u64)
        .map(|slot| {
            (
                factor_origin + slot * factor_stride,
                rhs_origin + slot * rhs_stride,
            )
        })
        .unzip();
    let mut factor_table_dev = stream
        .clone_htod(&factor_table)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut rhs_table_dev = stream
        .clone_htod(&rhs_table)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let (factor_table_ptr, _factor_table_guard) = factor_table_dev.device_ptr_mut(stream);
    let (rhs_table_ptr, _rhs_table_guard) = rhs_table_dev.device_ptr_mut(stream);
    let operation = if !transposed {
        cublasOperation_t::CUBLAS_OP_N
    } else {
        cublasOperation_t::CUBLAS_OP_T
    };
    let handle = *blas.handle();
    // SAFETY: both pointer tables contain `n` addresses inside live allocations whose
    // guards are held until the end of this function.
    let status = unsafe {
        cudarc::cublas::sys::cublasDtrsmBatched(
            handle,
            cublasSideMode_t::CUBLAS_SIDE_LEFT,
            cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
            operation,
            cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
            rows_i,
            cols_i,
            &one,
            factor_table_ptr as *const *const f64,
            rows_i,
            rhs_table_ptr as *const *mut f64,
            rows_i,
            count_i,
        )
    };
    if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        Ok(())
    } else {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
}
/// Single in-place triangular solve `rhs <- op(L)^{-1} rhs` for an `n x n` factor and one
/// right-hand side, via `cublasDtrsm_v2` (left side, non-unit diagonal).
///
/// `upper` selects which triangle of `l` is referenced; `transposed` applies `op = T`.
fn trsm_single(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    l: &CudaSlice<f64>,
    rhs: &mut CudaSlice<f64>,
    upper: bool,
    transposed: bool,
) -> Result<(), ArrowSchurGpuFailure> {
    let one = 1.0_f64;
    let order = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let triangle = match upper {
        true => cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
        false => cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
    };
    let operation = match transposed {
        true => cublasOperation_t::CUBLAS_OP_T,
        false => cublasOperation_t::CUBLAS_OP_N,
    };
    let handle = *blas.handle();
    let (factor_ptr, _factor_guard) = l.device_ptr(stream);
    let (vector_ptr, _vector_guard) = rhs.device_ptr_mut(stream);
    // SAFETY: `l` holds an `n x n` column-major matrix and `rhs` holds `n` entries; both
    // guards are alive across the call.
    let status = unsafe {
        cudarc::cublas::sys::cublasDtrsm_v2(
            handle,
            cublasSideMode_t::CUBLAS_SIDE_LEFT,
            triangle,
            operation,
            cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
            order,
            1,
            &one,
            factor_ptr as *const f64,
            order,
            vector_ptr as *mut f64,
            order,
        )
    };
    match status {
        cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
        _ => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
/// Accumulates the Schur complement and its right-hand side from the per-row blocks.
///
/// For every row `i` it applies `schur -= Y_i^T Y_i` (`k x k`, column-major) and
/// `rhs += Y_i^T u_i`. Here `y_stack` holds `n` column-major `d x k` blocks and `u_stack`
/// holds `n` length-`d` vectors. Issues one GEMM and one GEMV per row through `blas`.
fn accumulate_schur(
    blas: &CudaBlas,
    d: usize,
    k: usize,
    n: usize,
    y_stack: &CudaSlice<f64>,
    u_stack: &CudaSlice<f64>,
    schur: &mut CudaSlice<f64>,
    rhs: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    // The BLAS dimensions are loop-invariant; convert them once instead of per row.
    let k_i = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let y_block_elems = d * k;
    let u_block_elems = d;
    for i in 0..n {
        let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
        let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
        // schur <- -1 * Y^T Y + 1 * schur
        let gemm_cfg = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: k_i,
            n: k_i,
            k: d_i,
            alpha: -1.0,
            lda: d_i,
            ldb: d_i,
            beta: 1.0,
            ldc: k_i,
        };
        unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // rhs <- 1 * Y^T u + 1 * rhs
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: d_i,
            n: k_i,
            alpha: 1.0,
            lda: d_i,
            incx: 1,
            beta: 1.0,
            incy: 1,
        };
        unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
/// Forms the back-substitution right-hand sides in place: `u_i <- u_i + Y_i delta_beta`
/// for every row `i`.
///
/// Here `y_stack` holds `n` column-major `d x k` blocks and `u_stack` holds `n` length-`d`
/// vectors. Issues one GEMV per row through `blas`.
fn accumulate_back_sub_rhs(
    blas: &CudaBlas,
    d: usize,
    k: usize,
    n: usize,
    y_stack: &CudaSlice<f64>,
    delta_beta: &CudaSlice<f64>,
    u_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    // The BLAS dimensions are loop-invariant; convert them once instead of per row.
    let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let k_i = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let y_block_elems = d * k;
    let u_block_elems = d;
    for i in 0..n {
        let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
        let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: d_i,
            n: k_i,
            alpha: 1.0,
            lda: d_i,
            incx: 1,
            beta: 1.0,
            incy: 1,
        };
        unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
use cudarc::driver::{CudaContext, CudaModule, LaunchConfig, PushKernelArg};
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
/// Process-wide cache of compiled fused arrow-Schur CUDA modules.
///
/// Keyed by `FusedModuleCacheKey` (compute capability plus the `p_max` / `r_template`
/// kernel template sizes).
struct FusedModuleCache {
// Map from cache key to the loaded module; this module only reads or inserts entries.
modules:
Mutex<HashMap<crate::gpu::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>>,
}
fn fused_module_cache() -> &'static FusedModuleCache {
static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
CACHE.get_or_init(|| FusedModuleCache {
modules: Mutex::new(HashMap::new()),
})
}
/// Returns the compiled fused arrow-Schur module for `key`.
///
/// On the first request for a key, the module is compiled with NVRTC and loaded into
/// `ctx`. Compilation runs outside the cache lock, so concurrent first requests for the
/// same key may each compile. The first module inserted wins and is returned to every
/// caller.
///
/// NOTE(review): the key carries compute capability but no device/context identity.
/// With several devices of the same capability, a module loaded into one context would
/// be returned for another — confirm this cannot happen with `cuda_context_for`.
///
/// Errors: `SchurFactorFailed` if NVRTC compilation fails; `Unavailable` if loading the
/// PTX into `ctx` fails.
fn fused_module_for(
    ctx: &Arc<CudaContext>,
    key: crate::gpu::arrow_schur_nvrtc::FusedModuleCacheKey,
) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
    let cache = fused_module_cache();
    // The critical sections only read or insert `Arc`s, so a poisoned lock still guards a
    // consistent map. Recover it rather than silently bypassing the cache (which would
    // recompile with NVRTC on every call after any panic).
    let lock = || {
        cache
            .modules
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    };
    if let Some(existing) = lock().get(&key) {
        return Ok(existing.clone());
    }
    let src = crate::gpu::arrow_schur_nvrtc::forward_kernel_source(
        key.p_max as usize,
        key.r_template as usize,
    );
    let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
        ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!(
                "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
                key.p_max, key.r_template
            ),
        }
    })?;
    let module = ctx
        .load_module(ptx)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // If another thread inserted first, keep its module so all callers share one.
    let cached = lock().entry(key).or_insert(module).clone();
    Ok(cached)
}
/// Packs the ridged row blocks into padded, column-major host buffers for the fused
/// kernel. Every slot uses leading dimension `p_max`, and padding beyond `d` rows / `k`
/// columns stays zero. Row `i` owns:
/// - a `p_max x p_max` slot in `d_buf`, starting at `i * p_max * p_max`;
/// - a `p_max x r_template` slot in `b_buf`, starting at `i * p_max * r_template`;
/// - a `p_max` slot in `g_buf`.
///
/// Assumes `d <= p_max` and `k <= r_template`; presumably the launch plan guarantees this
/// — TODO confirm in `plan_fused_launch`.
fn pack_fused_host(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    p_max: usize,
    r_template: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    let d_slot = p_max * p_max;
    let b_slot = p_max * r_template;
    let mut d_buf = vec![0.0_f64; n * d_slot];
    let mut b_buf = vec![0.0_f64; n * b_slot];
    let mut g_buf = vec![0.0_f64; n * p_max];
    for (i, row) in sys.rows.iter().enumerate() {
        for col in 0..d {
            let base = i * d_slot + col * p_max;
            for r in 0..d {
                let mut value = row.htt[[r, col]];
                if r == col {
                    value += ridge_t;
                }
                d_buf[base + r] = value;
            }
        }
        // Row `i`'s coupling slot has stride `p_max * r_template`. This matches the
        // `b_buf` allocation above and the `y_out` layout read back by
        // `build_schur_matvec_backend`. The previous `p_max * p_max` row stride misplaced
        // blocks (or indexed past the buffer) whenever `r_template != p_max`.
        for col in 0..k {
            let base = i * b_slot + col * p_max;
            for r in 0..d {
                b_buf[base + r] = row.htbeta[[r, col]];
            }
        }
        let g_base = i * p_max;
        for r in 0..d {
            g_buf[g_base + r] = row.gt[r];
        }
    }
    (d_buf, b_buf, g_buf)
}
/// Fused NVRTC arrow-Schur Newton solve.
///
/// 1. Packs padded row blocks and launches `arrow_schur_forward_pgroup`. Its outputs are
///    consumed below as follows:
///    - a per-row status (nonzero triggers `RidgeBumpRequired`);
///    - the row factors `l_out` (diagonals used for the log-determinant);
///    - per-row Schur and rhs contributions `partial_s` / `partial_r`.
/// 2. Reduces those contributions on the host into `S = Hbb + ridge_beta I - sum partial_s`
///    and `r = -gb + sum partial_r`. It then solves `S delta_beta = r` with
///    cuSOLVER/cuBLAS.
/// 3. Launches `arrow_schur_back_sub_pgroup` for `delta_t` and strips the `p_max` padding.
///
/// Errors:
/// - `Unavailable` when no launch plan, device, module function, or CUDA operation is
///   available;
/// - `RidgeBumpRequired` for the first row with a nonzero kernel status;
/// - `SchurFactorFailed` when the Schur factorization or NVRTC compilation fails.
pub(super) fn solve_fused(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    let plan = crate::gpu::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let p_max = plan.p_max;
    let r_template = plan.r_template;
    let runtime = crate::gpu::linalg::route_through_gpu(
        crate::gpu::linalg::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
    )
    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = crate::gpu::runtime::cuda_context_for(runtime.device.ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let cap = &runtime.device.capability;
    let key = crate::gpu::arrow_schur_nvrtc::FusedModuleCacheKey {
        cc_major: cap.compute_major,
        cc_minor: cap.compute_minor,
        p_max: p_max as u32,
        r_template: r_template as u32,
    };
    let module = fused_module_for(&ctx, key)?;
    let forward = module
        .load_function("arrow_schur_forward_pgroup")
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let back_sub = module
        .load_function("arrow_schur_back_sub_pgroup")
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // Step 1: forward pass inputs (padded to the kernel template) and outputs.
    let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
    let d_dev = stream
        .clone_htod(&d_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let b_dev = stream
        .clone_htod(&b_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let g_dev = stream
        .clone_htod(&g_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut l_out = stream
        .alloc_zeros::<f64>(n * p_max * p_max)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut u_out = stream
        .alloc_zeros::<f64>(n * p_max)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut y_out = stream
        .alloc_zeros::<f64>(n * p_max * r_template)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // NOTE(review): the host reduction below indexes these per row (`i * r_template^2`
    // and `i * r_template` for `i < n`). This assumes the plan sizes them as at least
    // `n * r_template^2` / `n * r_template` doubles — confirm in `plan_fused_launch`.
    let mut partial_s = stream
        .alloc_zeros::<f64>(plan.partial_s_doubles)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut partial_r = stream
        .alloc_zeros::<f64>(plan.partial_r_doubles)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut status_dev = stream
        .alloc_zeros::<i32>(n)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let cfg = LaunchConfig {
        grid_dim: (plan.blocks, 1, 1),
        block_dim: (plan.threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    // NOTE(review): `pack_fused_host` already added `ridge_t` to the diagonal of `d_host`;
    // confirm the kernel does not apply `ridge_arg` a second time.
    let ridge_arg = ridge_t;
    {
        // Argument order must match the kernel signature exactly.
        let mut builder = stream.launch_builder(&forward);
        builder
            .arg(&d_dev)
            .arg(&b_dev)
            .arg(&g_dev)
            .arg(&n_i32)
            .arg(&p_i32)
            .arg(&r_i32)
            .arg(&ridge_arg)
            .arg(&mut l_out)
            .arg(&mut u_out)
            .arg(&mut y_out)
            .arg(&mut partial_s)
            .arg(&mut partial_r)
            .arg(&mut status_dev);
        unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    stream
        .synchronize()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let status_host = stream
        .clone_dtoh(&status_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    if let Some(row) = status_host.iter().position(|s| *s != 0) {
        let pivot = status_host[row];
        let scale = sys.rows[row]
            .htt
            .diag()
            .iter()
            .map(|v| v.abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
            row,
            bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
        });
    }

    // Step 2: host reduction into S (column-major k x k) and r, then the Schur solve.
    let partial_s_host = stream
        .clone_dtoh(&partial_s)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let partial_r_host = stream
        .clone_dtoh(&partial_r)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut schur_host = vec![0.0_f64; k * k];
    for col in 0..k {
        for row in 0..k {
            let mut v = sys.hbb[[row, col]];
            if row == col {
                v += ridge_beta;
            }
            schur_host[col * k + row] = v;
        }
    }
    let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
    for i in 0..n {
        let s_base = i * r_template * r_template;
        for col in 0..k {
            let col_base = s_base + col * r_template;
            let dst_col_base = col * k;
            for row in 0..k {
                schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
            }
        }
        let r_base = i * r_template;
        for a in 0..k {
            rhs_host[a] += partial_r_host[r_base + a];
        }
    }
    let mut schur_dev = stream
        .clone_htod(&schur_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut rhs_dev = stream
        .clone_htod(&rhs_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let solver =
        DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
    if info != 0 {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("fused Schur Cholesky failed at pivot {info}"),
        });
    }
    trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
    trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
    let delta_beta_host = stream
        .clone_dtoh(&rhs_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // The host copy is not used again, so move it instead of cloning.
    let delta_beta = Array1::from_vec(delta_beta_host);

    // Step 3: back-substitution kernel, then strip the per-row padding.
    let mut delta_t_dev = stream
        .alloc_zeros::<f64>(n * p_max)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let back_cfg = LaunchConfig {
        grid_dim: (plan.blocks, 1, 1),
        block_dim: (plan.threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    };
    {
        let mut builder = stream.launch_builder(&back_sub);
        builder
            .arg(&l_out)
            .arg(&u_out)
            .arg(&y_out)
            .arg(&rhs_dev)
            .arg(&n_i32)
            .arg(&p_i32)
            .arg(&r_i32)
            .arg(&mut delta_t_dev);
        unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    stream
        .synchronize()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let delta_t_host = stream
        .clone_dtoh(&delta_t_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut delta_t = Array1::<f64>::zeros(n * d);
    for i in 0..n {
        let src_base = i * p_max;
        let dst_base = i * d;
        for r in 0..d {
            delta_t[dst_base + r] = delta_t_host[src_base + r];
        }
    }

    // ln det(H) = 2 * (sum of ln diag over every padded row factor and the Schur factor).
    let l_local_host = stream
        .clone_dtoh(&l_out)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let l_schur_host = stream
        .clone_dtoh(&schur_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut log_det = 0.0_f64;
    for i in 0..n {
        let base = i * p_max * p_max;
        for j in 0..d {
            log_det += l_local_host[base + j * p_max + j].ln();
        }
    }
    for j in 0..k {
        log_det += l_schur_host[j * k + j].ln();
    }
    log_det *= 2.0;
    Ok(ArrowSchurGpuSolution {
        delta_t,
        delta_beta,
        log_det_hessian: log_det,
    })
}
/// Runs the fused forward kernel once and returns a host closure computing
/// `out = (Hbb + ridge_beta I) x - sum_i Y_i^T (Y_i x)`.
///
/// `Y_i` is the padded `p_max x r_template` block of `y_out` for row `i` (leading
/// dimension `p_max`), downloaded to the host. `Hbb x` comes from one of three sources:
/// - `sys.hbb_matvec` when present;
/// - the dense `sys.hbb` when it is `k x k`;
/// - otherwise it is omitted (treated as zero). NOTE(review): confirm that silently
///   dropping `Hbb` in this last case is intended.
///
/// The returned closure panics if `x` or `out` does not have length `k`.
pub(super) fn build_schur_matvec_backend(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<crate::solver::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let plan = crate::gpu::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let p_max = plan.p_max;
let r_template = plan.r_template;
let runtime = crate::gpu::linalg::route_through_gpu(
crate::gpu::linalg::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let ctx = crate::gpu::runtime::cuda_context_for(runtime.device.ordinal)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let stream = ctx
.new_stream()
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let cap = &runtime.device.capability;
let key = crate::gpu::arrow_schur_nvrtc::FusedModuleCacheKey {
cc_major: cap.compute_major,
cc_minor: cap.compute_minor,
p_max: p_max as u32,
r_template: r_template as u32,
};
let module = fused_module_for(&ctx, key)?;
// Only the forward kernel is needed; `y_out` is the one output used below.
let forward = module
.load_function("arrow_schur_forward_pgroup")
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
let d_dev = stream
.clone_htod(&d_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let b_dev = stream
.clone_htod(&b_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let g_dev = stream
.clone_htod(&g_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
// The remaining outputs are still allocated because the kernel signature requires them.
let mut l_out = stream
.alloc_zeros::<f64>(n * p_max * p_max)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut u_out = stream
.alloc_zeros::<f64>(n * p_max)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut y_out = stream
.alloc_zeros::<f64>(n * p_max * r_template)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut partial_s = stream
.alloc_zeros::<f64>(plan.partial_s_doubles)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut partial_r = stream
.alloc_zeros::<f64>(plan.partial_r_doubles)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut status_dev = stream
.alloc_zeros::<i32>(n)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let cfg = LaunchConfig {
grid_dim: (plan.blocks, 1, 1),
block_dim: (plan.threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let ridge_arg = ridge_t;
// Argument order must match the kernel signature exactly; keep it identical to
// the launch in `solve_fused`.
{
let mut builder = stream.launch_builder(&forward);
builder
.arg(&d_dev)
.arg(&b_dev)
.arg(&g_dev)
.arg(&n_i32)
.arg(&p_i32)
.arg(&r_i32)
.arg(&ridge_arg)
.arg(&mut l_out)
.arg(&mut u_out)
.arg(&mut y_out)
.arg(&mut partial_s)
.arg(&mut partial_r)
.arg(&mut status_dev);
unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
}
stream
.synchronize()
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let status_host = stream
.clone_dtoh(&status_dev)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
// Same ridge-bump heuristic as the other solve paths: scale by the failing block's
// largest |diagonal| (at least 1) and the reported status value.
if let Some(row) = status_host.iter().position(|s| *s != 0) {
let pivot = status_host[row];
let scale = sys.rows[row]
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
row,
bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
});
}
let y_host = stream
.clone_dtoh(&y_out)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
// ndarray's `iter()` walks elements in logical row-major order, so when `hbb` is
// `k x k`, `hbb_host[a * k + b]` is `hbb[[a, b]]`.
let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
let hbb_is_kk = sys.hbb.dim() == (k, k);
let hbb_matvec_opt = sys.hbb_matvec.clone();
let closure: crate::solver::arrow_schur::GpuSchurMatvec =
Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
// out = (Hbb + ridge_beta I) x
if let Some(ref mv) = hbb_matvec_opt {
mv(x.view(), out);
for a in 0..k {
out[a] += ridge_beta * x[a];
}
} else if hbb_is_kk {
for a in 0..k {
let mut acc = ridge_beta * x[a];
for b in 0..k {
acc += hbb_host[a * k + b] * x[b];
}
out[a] = acc;
}
} else {
for a in 0..k {
out[a] = ridge_beta * x[a];
}
}
// out -= sum_i Y_i^T (Y_i x); `z` is reused as the length-`d` scratch for Y_i x.
let mut z = vec![0.0_f64; d];
for i in 0..n {
let y_base = i * p_max * r_template;
for r in 0..d {
let mut acc = 0.0;
for c in 0..k {
acc += y_host[y_base + c * p_max + r] * x[c];
}
z[r] = acc;
}
for c in 0..k {
let mut acc = 0.0;
for r in 0..d {
acc += y_host[y_base + c * p_max + r] * z[r];
}
out[c] -= acc;
}
}
});
Ok(closure)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::arrow_schur::ArrowSchurSystem;
    use ndarray::Array2;

    /// Deterministic pseudo-random arrow system.
    ///
    /// Each row block and the beta block are `A^T A + (dim + 1) I`, so they are symmetric
    /// positive definite. Coupling entries are `0.1 * U[-1, 1)`, which keeps the
    /// assembled Hessian well conditioned.
    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
        let mut sys = ArrowSchurSystem::new(n, d, k);
        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };
        for row in &mut sys.rows {
            let mut a = Array2::<f64>::zeros((d, d));
            for r in 0..d {
                for c in 0..d {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..d {
                htt[[r, r]] += d as f64 + 1.0;
            }
            row.htt = htt;
            for r in 0..d {
                for c in 0..k {
                    row.htbeta[[r, c]] = 0.1 * sample();
                }
                row.gt[r] = sample();
            }
        }
        let mut hbb_a = Array2::<f64>::zeros((k, k));
        for r in 0..k {
            for c in 0..k {
                hbb_a[[r, c]] = sample();
            }
        }
        let mut hbb = hbb_a.t().dot(&hbb_a);
        for r in 0..k {
            hbb[[r, r]] += k as f64 + 1.0;
        }
        sys.hbb = hbb;
        for r in 0..k {
            sys.gb[r] = sample();
        }
        sys
    }

    /// Densely assembles the unridged Hessian `H` and gradient `g` of `sys`. Unknowns are
    /// ordered as all row parameters (`i * d + r`) followed by the `k` shared parameters.
    fn assemble_dense(sys: &ArrowSchurSystem) -> (Array2<f64>, ndarray::Array1<f64>) {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        let nd = n * d;
        let total = nd + k;
        let mut h = Array2::<f64>::zeros((total, total));
        let mut g = ndarray::Array1::<f64>::zeros(total);
        for (i, row) in sys.rows.iter().enumerate() {
            let base = i * d;
            for c in 0..d {
                for r in 0..d {
                    h[[base + r, base + c]] = row.htt[[r, c]];
                }
            }
            for c in 0..k {
                for r in 0..d {
                    h[[base + r, nd + c]] = row.htbeta[[r, c]];
                    h[[nd + c, base + r]] = row.htbeta[[r, c]];
                }
            }
            for r in 0..d {
                g[base + r] = row.gt[r];
            }
        }
        for c in 0..k {
            for r in 0..k {
                h[[nd + r, nd + c]] += sys.hbb[[r, c]];
            }
            g[nd + c] = sys.gb[c];
        }
        (h, g)
    }

    #[test]
    fn dense_reference_matches_independent_solve() {
        let sys = build_fixture(4, 5, 3, 7);
        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
        let nd = sys.rows.len() * sys.d;
        let (h, g) = assemble_dense(&sys);
        let l = cholesky_lower_host(h.view()).unwrap();
        let rhs = g.mapv(|v| -v);
        let expected = solve_cholesky_lower_host(l.view(), rhs.view());
        for i in 0..nd {
            assert!(
                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
                "delta_t[{i}] mismatch: got {} expected {}",
                solution.delta_t[i],
                expected[i]
            );
        }
        for a in 0..sys.k {
            assert!(
                (solution.delta_beta[a] - expected[nd + a]).abs()
                    < 1e-10 * (1.0 + expected[nd + a].abs()),
                "delta_beta[{a}] mismatch"
            );
        }
        let expected_log_det = 2.0 * (0..h.nrows()).map(|i| l[[i, i]].ln()).sum::<f64>();
        assert!(
            (solution.log_det_hessian - expected_log_det).abs()
                < 1e-10 * (1.0 + expected_log_det.abs()),
            "log_det mismatch: got {} expected {expected_log_det}",
            solution.log_det_hessian
        );
    }

    #[test]
    fn dense_reference_satisfies_ridged_newton_equation() {
        // Residual check that does not go through the Cholesky helpers:
        // (H + diag(ridges)) * step == -g.
        let (ridge_t, ridge_beta) = (0.25, 0.5);
        let sys = build_fixture(3, 4, 2, 11);
        let solution =
            solve_arrow_newton_step_dense_reference(&sys, ridge_t, ridge_beta).unwrap();
        let nd = sys.rows.len() * sys.d;
        let (mut h, g) = assemble_dense(&sys);
        for i in 0..h.nrows() {
            h[[i, i]] += if i < nd { ridge_t } else { ridge_beta };
        }
        let step: ndarray::Array1<f64> = solution
            .delta_t
            .iter()
            .chain(solution.delta_beta.iter())
            .copied()
            .collect();
        let residual = h.dot(&step) + &g;
        for (i, r) in residual.iter().enumerate() {
            assert!(r.abs() < 1e-9 * (1.0 + g[i].abs()), "residual[{i}] = {r}");
        }
    }

    #[test]
    fn dense_reference_rejects_indefinite_hessian() {
        let mut sys = build_fixture(2, 3, 2, 5);
        sys.rows[0].htt = -Array2::<f64>::eye(3);
        assert!(solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).is_err());
    }
}