use ndarray::{Array1, Array2};
use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
/// Result of one arrow-structured Newton solve produced by the solvers in this module.
pub struct ArrowSchurGpuSolution {
/// Row-block step: row `i` occupies entries `i * d..(i + 1) * d`, where `d` is the per-row
/// block size of the solved system.
pub delta_t: Array1<f64>,
/// Step for the shared β block (length `k`).
pub delta_beta: Array1<f64>,
/// Log-determinant of the ridged Hessian, computed as twice the sum of the logs of the
/// Cholesky diagonals (row blocks plus reduced β system on the GPU paths; the full dense
/// factor in the dense reference).
pub log_det_hessian: f64,
}
/// Failure modes reported by the GPU arrow-Schur entry points in this module.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
/// No usable GPU path: non-Linux target, no CUDA runtime/device, an empty or unsupported
/// system, or a failing CUDA driver/library call.
Unavailable,
/// The ridged `H_tt` block of `row` failed Cholesky. `bump` is the extra diagonal ridge the
/// solver computed (presumably to be added to `ridge_t` before retrying).
RidgeBumpRequired { row: usize, bump: f64 },
/// Structural/control validation failed, or the reduced β (Schur) system could not be
/// factored. `reason` is a human-readable description.
SchurFactorFailed { reason: String },
/// The system supplies operator-form blocks that the dense GPU path cannot consume. The
/// flags record which operators were installed.
GpuRequiresDenseSystem {
had_hbb_matvec: bool,
had_htbeta_matvec: bool,
},
}
/// Safety multiplier on `sqrt(f64::EPSILON)` used when sizing the ridge bump reported after a
/// failed row-block Cholesky: `bump = diag_scale [* pivot factor] * sqrt(eps) * margin`, where
/// `diag_scale = max(|diag(H_tt)|, 1)`.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
/// Solves one arrow-structured Newton system `H [Δt; Δβ] = -[g_t; g_β]` on the GPU.
///
/// `H` has per-row blocks `H_tt,i + ridge_t I` on the diagonal, couplings `H_tβ,i`, and the
/// shared block `H_ββ + ridge_beta I`.
///
/// On Linux the dispatch order is:
/// 1. the multi-GPU path, when more than one device is visible;
/// 2. the fused NVRTC kernel, when `system_admits_fused_path` accepts the system;
/// 3. the batched cuSOLVER/cuBLAS path.
///
/// # Errors
/// - `GpuRequiresDenseSystem` when `sys` carries `hbb_matvec` / `htbeta_matvec` operators.
/// - `SchurFactorFailed` for block-shape mismatches, a NaN ridge, or a reduced-β factor failure.
/// - `RidgeBumpRequired` when a row's ridged `H_tt` block fails Cholesky.
/// - `Unavailable` for empty systems, non-Linux targets, or when CUDA cannot be used.
pub fn solve_arrow_newton_step(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    let had_hbb_matvec = sys.hbb_matvec.is_some();
    let had_htbeta_matvec = sys.htbeta_matvec.is_some();
    if had_hbb_matvec || had_htbeta_matvec {
        return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
            had_hbb_matvec,
            had_htbeta_matvec,
        });
    }
    if sys.hbb.dim() != (k, k) {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
        });
    }
    if n == 0 || d == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    if sys
        .rows
        .iter()
        .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
    {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "row block dimension mismatch".to_string(),
        });
    }
    // Reject NaN ridges on every target so they never reach the device factorisations.
    // This matches `solve_arrow_newton_step_fused_force`.
    if ridge_t.is_nan() || ridge_beta.is_nan() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "ridge is NaN".to_string(),
        });
    }
    #[cfg(not(target_os = "linux"))]
    {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        let device_count = gam_gpu::device_runtime::GpuRuntime::global()
            .map(gam_gpu::device_runtime::GpuRuntime::device_count)
            .unwrap_or(0);
        if device_count > 1 {
            match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
                Ok(sol) => return Ok(sol),
                // Numerical failures are reported as-is; the single-device paths would hit
                // the same blocks.
                Err(
                    err @ (ArrowSchurGpuFailure::RidgeBumpRequired { .. }
                    | ArrowSchurGpuFailure::SchurFactorFailed { .. }),
                ) => return Err(err),
                // The multi-GPU path could not run; fall through to a single device.
                Err(
                    ArrowSchurGpuFailure::Unavailable
                    | ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. },
                ) => {}
            }
        }
        if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
            match cuda::solve_fused(sys, ridge_t, ridge_beta) {
                Ok(sol) => return Ok(sol),
                Err(err @ ArrowSchurGpuFailure::RidgeBumpRequired { .. }) => return Err(err),
                // Any other fused-kernel failure (including a Schur factor failure) is
                // retried on the batched path below.
                Err(
                    ArrowSchurGpuFailure::SchurFactorFailed { .. }
                    | ArrowSchurGpuFailure::Unavailable
                    | ArrowSchurGpuFailure::GpuRequiresDenseSystem { .. },
                ) => {}
            }
        }
        cuda::solve(sys, ridge_t, ridge_beta)
    }
}
/// Flattens all row blocks of `sys` into three contiguous host staging buffers.
///
/// The buffers hold, back to back per row:
/// - the ridged `H_tt` blocks (`d*d` each);
/// - the `H_tβ` blocks (`d*k` each);
/// - the `g_t` vectors (`d` each).
///
/// Every block is column-major, in the layout produced by `pack_block`.
#[cfg(target_os = "linux")]
fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let (rows, d, k) = (sys.rows.len(), sys.d, sys.k);
    let mut packed = (
        Vec::with_capacity(rows * d * d),
        Vec::with_capacity(rows * d * k),
        Vec::with_capacity(rows * d),
    );
    sys.rows.iter().for_each(|row| {
        pack_block(row, ridge_t, d, k, &mut packed.0, &mut packed.1, &mut packed.2);
    });
    packed
}
/// Appends one row block to the host staging buffers in column-major order.
///
/// The three pieces appended are:
/// - the `d×d` `H_tt`, with `ridge_t` added on its diagonal, to `d_buf`;
/// - the `d×k` `H_tβ` to `b_buf`;
/// - the length-`d` `g_t` to `g_buf`.
#[cfg(target_os = "linux")]
#[inline]
fn pack_block(
    row: &crate::arrow_schur::ArrowRowBlock,
    ridge_t: f64,
    d: usize,
    k: usize,
    d_buf: &mut Vec<f64>,
    b_buf: &mut Vec<f64>,
    g_buf: &mut Vec<f64>,
) {
    // (row, col) pairs in column-major order: the row index varies fastest.
    let column_major =
        |rows: usize, cols: usize| (0..cols).flat_map(move |c| (0..rows).map(move |r| (r, c)));
    d_buf.extend(column_major(d, d).map(|(r, c)| {
        let entry = row.htt[[r, c]];
        if r == c {
            entry + ridge_t
        } else {
            entry
        }
    }));
    b_buf.extend(column_major(d, k).map(|(r, c)| row.htbeta[[r, c]]));
    g_buf.extend((0..d).map(|r| row.gt[r]));
}
/// Forces the fused NVRTC arrow-Schur kernel, bypassing the other dispatch paths.
///
/// # Errors
/// - `SchurFactorFailed` if either ridge is NaN.
/// - `Unavailable` off Linux, or when no fused launch plan exists for the system's sizes.
/// - Otherwise, whatever `cuda::solve_fused` returns.
#[doc(hidden)]
#[cfg_attr(not(target_os = "linux"), allow(unused_variables))]
pub fn solve_arrow_newton_step_fused_force(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let any_ridge_nan = ridge_t.is_nan() || ridge_beta.is_nan();
    if any_ridge_nan {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "ridge is NaN".to_string(),
        });
    }
    #[cfg(not(target_os = "linux"))]
    {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        use crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch;
        let has_launch_plan = plan_fused_launch(sys.rows.len(), sys.d, sys.k).is_some();
        if has_launch_plan {
            cuda::solve_fused(sys, ridge_t, ridge_beta)
        } else {
            Err(ArrowSchurGpuFailure::Unavailable)
        }
    }
}
/// Reusable solver handle built from one `ArrowSchurSystem` and ridge pair.
///
/// After construction, `solve_gradient` can be called repeatedly with new right-hand sides.
/// On non-Linux targets the only field is `Infallible`, so the handle is uninhabited and can
/// never be constructed.
pub struct ResidentArrowFrameHandle {
#[cfg(target_os = "linux")]
inner: cuda::ResidentArrowFrame,
#[cfg(not(target_os = "linux"))]
_never: std::convert::Infallible,
}
impl ResidentArrowFrameHandle {
    /// Builds a reusable solver handle for `sys` with the given ridges.
    ///
    /// # Errors
    /// - `GpuRequiresDenseSystem` when `sys` carries `hbb_matvec` / `htbeta_matvec` operators.
    /// - `SchurFactorFailed` when either ridge is NaN.
    /// - `Unavailable` off Linux; on Linux, whatever `cuda::ResidentArrowFrame::new` reports.
    pub fn new(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<Self, ArrowSchurGpuFailure> {
        let had_hbb_matvec = sys.hbb_matvec.is_some();
        let had_htbeta_matvec = sys.htbeta_matvec.is_some();
        if had_hbb_matvec || had_htbeta_matvec {
            return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
                had_hbb_matvec,
                had_htbeta_matvec,
            });
        }
        // Validate on every target so a NaN ridge never reaches the device factorisation.
        if ridge_t.is_nan() || ridge_beta.is_nan() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: "ridge is NaN".to_string(),
            });
        }
        #[cfg(not(target_os = "linux"))]
        {
            Err(ArrowSchurGpuFailure::Unavailable)
        }
        #[cfg(target_os = "linux")]
        {
            Ok(Self {
                inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
            })
        }
    }

    /// Solves the stored system against a new gradient `(g_t, g_beta)`.
    ///
    /// # Errors
    /// Whatever the CUDA resident frame reports. Off Linux this method is statically
    /// unreachable, because the handle cannot exist.
    pub fn solve_gradient(
        &self,
        g_t: &[f64],
        g_beta: &[f64],
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        #[cfg(not(target_os = "linux"))]
        {
            let _ = (g_t, g_beta);
            // `Infallible` has no values, so this empty match proves the branch unreachable.
            match self._never {}
        }
        #[cfg(target_os = "linux")]
        {
            self.inner.solve_gradient(g_t, g_beta)
        }
    }

    /// Log-determinant of the ridged Hessian held by this handle.
    #[must_use]
    pub fn log_det_hessian(&self) -> f64 {
        #[cfg(not(target_os = "linux"))]
        {
            // The handle is uninhabited off Linux; prove that at compile time instead of
            // panicking at runtime.
            match self._never {}
        }
        #[cfg(target_os = "linux")]
        {
            self.inner.log_det_hessian()
        }
    }
}
/// Returns a closure computing the reduced-β Schur matvec for `sys`.
///
/// - Systems with an operator-form `H_tβ` (`htbeta_matvec`) use the host-side row-procedural
///   matvec, on every target.
/// - Dense systems use the CUDA backend on Linux and report `Unavailable` elsewhere.
pub fn gpu_schur_matvec_backend(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
    let row_procedural = sys.htbeta_matvec.is_some();
    if row_procedural {
        return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
    }
    #[cfg(not(target_os = "linux"))]
    {
        // The dense backend needs CUDA; the ridges are intentionally unused here.
        let _ = (ridge_t, ridge_beta);
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
    }
}
/// Builds a host-side closure computing the reduced-β Schur matvec for systems whose `H_tβ`
/// coupling is only available through per-row operators: `htbeta_matvec` (forward, `F_i`)
/// and `htbeta_transpose_matvec` (adjoint, `T_i`).
///
/// Each ridged row block `H_tt,i + ridge_t I` is Cholesky-factored once, up front. For a
/// length-`k` vector `x`, the returned closure computes
/// `out = ridge_beta * x + (penalty op applied to x) - Σ_i T_i((H_tt,i + ridge_t I)⁻¹ F_i(x))`,
/// where the penalty op is `sys.effective_penalty_op()`.
///
/// NOTE(review): `sys.hbb` is not referenced here. Presumably `effective_penalty_op` already
/// represents the shared β block for row-procedural systems — confirm.
///
/// # Errors
/// - `Unavailable` if `htbeta_matvec` is absent.
/// - `SchurFactorFailed` if the transpose operator is missing, or if a row's `H_tt` is not
///   square or disagrees with its gradient length.
/// - `RidgeBumpRequired` for the first row whose ridged block fails Cholesky.
fn build_row_procedural_matvec(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
use std::sync::Arc;
let n = sys.rows.len();
let k = sys.k;
let forward = sys
.htbeta_matvec
.clone()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
ArrowSchurGpuFailure::SchurFactorFailed {
reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
forward operator installed without its sparse adjoint"
.to_string(),
}
})?;
// Factor every ridged row block once; the closure reuses these for each matvec.
let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
for (i, row) in sys.rows.iter().enumerate() {
let di = row.htt.nrows();
if row.htt.ncols() != di || row.gt.len() != di {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
});
}
let mut block = row.htt.clone();
for r in 0..di {
block[[r, r]] += ridge_t;
}
let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
.ok_or_else(|| {
// Bump = max(|diag(H_tt)|, 1) * sqrt(eps) * margin (no pivot factor on this path).
let scale = row
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
ArrowSchurGpuFailure::RidgeBumpRequired {
row: i,
bump: scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN,
}
})?;
factors.push(factor);
}
let penalty_op = sys.effective_penalty_op();
let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
let closure: crate::arrow_schur::GpuSchurMatvec =
Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
{
// Start from the β ridge, then apply the penalty operator.
// NOTE(review): this assumes `matvec` accumulates into `out` rather than
// overwriting it (otherwise the ridge term is lost) — confirm.
let x_slice = x.as_slice().expect("x must be contiguous");
let out_slice = out.as_slice_mut().expect("out must be contiguous");
for a in 0..k {
out_slice[a] = ridge_beta * x_slice[a];
}
penalty_op.matvec(x_slice, out_slice);
}
// Parallelise only large systems, and only when not already running on a rayon
// worker (presumably to avoid nested parallelism).
let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
&& rayon::current_thread_index().is_none();
if parallel {
use rayon::prelude::*;
// Each chunk of rows accumulates into its own partial vector. The partials are
// summed serially afterwards, so the floating-point summation order differs from
// the serial branch.
const CHUNK: usize = 64;
let partials: Vec<Array1<f64>> = (0..n)
.into_par_iter()
.chunks(CHUNK)
.map(|idxs| {
let mut neg = Array1::<f64>::zeros(k);
for i in idxs {
// v_i = F_i(x); w_i = (H_tt,i + ridge_t I)⁻¹ v_i; neg += T_i(w_i).
// `transpose` is called repeatedly on the same `neg`, so it is
// presumably accumulating.
let di = row_dims[i];
let mut v_i = Array1::<f64>::zeros(di);
forward(i, x.view(), &mut v_i);
let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
transpose(i, w_i.view(), &mut neg);
}
neg
})
.collect();
let mut neg = Array1::<f64>::zeros(k);
for part in &partials {
for a in 0..k {
neg[a] += part[a];
}
}
for a in 0..k {
out[a] -= neg[a];
}
} else {
// Serial path: identical per-row work, accumulated in row order.
let mut neg = Array1::<f64>::zeros(k);
for i in 0..n {
let di = row_dims[i];
let mut v_i = Array1::<f64>::zeros(di);
forward(i, x.view(), &mut v_i);
let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
transpose(i, w_i.view(), &mut neg);
}
for a in 0..k {
out[a] -= neg[a];
}
}
});
Ok(closure)
}
/// Solves the reduced β system `S x = rhs_beta` with GPU preconditioned CG.
///
/// This is the same solve as `solve_reduced_beta_pcg_with_diagnostics`, with the iteration
/// diagnostics discarded.
///
/// # Errors
/// The same errors as `solve_reduced_beta_pcg_with_diagnostics`.
pub fn solve_reduced_beta_pcg(
    s_acc: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
    let (solution, _diagnostics) = solve_reduced_beta_pcg_with_diagnostics(
        s_acc,
        rhs_beta,
        max_iterations,
        relative_tolerance,
    )?;
    Ok(solution)
}
/// GPU preconditioned CG on the reduced β system, returning the solution and CG diagnostics.
///
/// # Errors
/// - `SchurFactorFailed` if `s_acc` is not `k×k` (`k = rhs_beta.len()`).
/// - `SchurFactorFailed` off Linux when the tolerance is NaN or `max_iterations == 0`.
/// - `Unavailable` when `k == 0` or off Linux.
/// - On Linux, whatever the CUDA implementation reports.
#[doc(hidden)]
pub fn solve_reduced_beta_pcg_with_diagnostics(
    s_acc: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    let schur_dim = s_acc.dim();
    if schur_dim != (k, k) {
        let reason = format!(
            "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
            schur_dim
        );
        return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
    }
    if k == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    #[cfg(not(target_os = "linux"))]
    {
        let controls_invalid = relative_tolerance.is_nan() || max_iterations == 0;
        if controls_invalid {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
            });
        }
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        cuda::solve_reduced_beta_pcg_with_diagnostics(
            s_acc,
            rhs_beta,
            max_iterations,
            relative_tolerance,
        )
    }
}
/// Matrix-free GPU PCG for the SAE reduced β system.
///
/// Dispatches to the framed CUDA kernel when `data.frame` is present, and to the plain
/// kernel otherwise.
///
/// # Errors
/// - `Unavailable` when `sys.k`, `rhs_beta.len()` and `data.beta_dim` disagree, or when
///   `data.p == 0`.
/// - `SchurFactorFailed` off Linux for NaN controls or `max_iterations == 0`.
/// - `Unavailable` otherwise off Linux.
/// - On Linux, whatever the CUDA kernel reports.
pub fn solve_sae_matrix_free_pcg(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let beta_dim = data.beta_dim;
    let shapes_agree = sys.k == beta_dim && rhs_beta.len() == beta_dim && data.p != 0;
    if !shapes_agree {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    #[cfg(not(target_os = "linux"))]
    {
        let any_control_nan = ridge_t.is_nan() || ridge_beta.is_nan() || relative_tolerance.is_nan();
        if any_control_nan || max_iterations == 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
            });
        }
        Err(ArrowSchurGpuFailure::Unavailable)
    }
    #[cfg(target_os = "linux")]
    {
        match data.frame.as_ref() {
            Some(_) => cuda::solve_sae_matrix_free_pcg_framed(
                sys,
                data,
                ridge_t,
                ridge_beta,
                rhs_beta,
                max_iterations,
                relative_tolerance,
            ),
            None => cuda::solve_sae_matrix_free_pcg(
                sys,
                data,
                ridge_t,
                ridge_beta,
                rhs_beta,
                max_iterations,
                relative_tolerance,
            ),
        }
    }
}
/// Dense CPU reference for `solve_arrow_newton_step`, presumably used to cross-check the GPU
/// paths on small systems.
///
/// Assembles the full `(n*d + k) × (n*d + k)` ridged Hessian, factors it with Cholesky, and
/// solves `H [Δt; Δβ] = -[g_t; g_β]`. Memory use is `O((n*d + k)²)`.
///
/// # Errors
/// Returns a description when:
/// - a size computation overflows;
/// - `sys.hbb` / `sys.gb` or any row block disagrees with `sys.d` / `sys.k`;
/// - the dense Cholesky fails.
#[doc(hidden)]
pub fn solve_arrow_newton_step_dense_reference(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, String> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    // Guard every size computation, including `total * total`, the dense element count.
    let nd = n.checked_mul(d).ok_or("dimension overflow")?;
    let total = nd.checked_add(k).ok_or("dimension overflow")?;
    if total.checked_mul(total).is_none() {
        return Err("dimension overflow".to_string());
    }
    // Validate shapes up front so a malformed system returns `Err` instead of panicking on
    // an out-of-bounds index (or silently reading a truncated block) during assembly.
    if sys.hbb.dim() != (k, k) || sys.gb.len() != k {
        return Err(format!(
            "dense reference: shared block {:?} / gradient length {} do not match k={k}",
            sys.hbb.dim(),
            sys.gb.len()
        ));
    }
    if let Some(i) = sys.rows.iter().position(|row| {
        row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d
    }) {
        return Err(format!("dense reference: row {i} block dimension mismatch"));
    }
    let mut h = Array2::<f64>::zeros((total, total));
    let mut rhs = Array1::<f64>::zeros(total);
    for (i, row) in sys.rows.iter().enumerate() {
        let base = i * d;
        for c in 0..d {
            for r in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            h[[base + c, base + c]] += ridge_t;
        }
        // Symmetric coupling: H_tβ in the upper-right block, its transpose in the lower-left.
        for c in 0..k {
            for r in 0..d {
                let value = row.htbeta[[r, c]];
                h[[base + r, nd + c]] = value;
                h[[nd + c, base + r]] = value;
            }
        }
        for r in 0..d {
            rhs[base + r] = -row.gt[r];
        }
    }
    for c in 0..k {
        for r in 0..k {
            h[[nd + r, nd + c]] += sys.hbb[[r, c]];
        }
        h[[nd + c, nd + c]] += ridge_beta;
        rhs[nd + c] = -sys.gb[c];
    }
    let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
        .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
    // log det H = 2 Σ ln L_ii.
    let mut log_det = 0.0_f64;
    for i in 0..total {
        log_det += factor[[i, i]].ln();
    }
    log_det *= 2.0;
    let solved = cholesky_solve_vector(factor.view(), rhs.view());
    let delta_t = solved.slice(ndarray::s![..nd]).to_owned();
    let delta_beta = solved.slice(ndarray::s![nd..]).to_owned();
    Ok(ArrowSchurGpuSolution {
        delta_t,
        delta_beta,
        log_det_hessian: log_det,
    })
}
/// CPU reference for the framed SAE penalty matvec.
///
/// First sets `out[..k] = ridge_beta * x[..k]` (`k = data.beta_dim`), then adds `P x`, where
/// `P` is assembled from two kinds of block:
///
/// * Smooth blocks. From `global_offset`, `x` is viewed as `m` consecutive groups of `r`
///   entries, and `factor_a` (`m×m`) is applied across the groups, i.e. `(A ⊗ I_r)`.
/// * Frame blocks coupling `atom_i` and `atom_j`. Group `li` of `atom_i`'s border range
///   receives `g[li, lj] * (w · group lj of atom_j)`, i.e. `(G ⊗ W)`.
///
/// Zero entries of `factor_a` and `g` are skipped. The summation order is fixed by the loops.
///
/// # Panics
/// If `data.frame` is `None`, or if `x`/`out` are shorter than the offsets implied by `data`.
#[doc(hidden)]
pub fn sae_framed_penalty_matvec_cpu(
data: &DeviceSaePcgData,
ridge_beta: f64,
x: &[f64],
out: &mut [f64],
) {
let frame = data
.frame
.as_ref()
.expect("sae_framed_penalty_matvec_cpu requires frame metadata");
let k = data.beta_dim;
// Initialise with the β ridge; everything below accumulates into `out`.
for a in 0..k {
out[a] = ridge_beta * x[a];
}
// Smooth blocks: (A ⊗ I_r) on each block's range; entry (i_a, i_b) sits at
// `off + i_a * r + i_b`.
for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
let off = blk.global_offset;
let m = blk.factor_a.nrows();
for i_a in 0..m {
for i_b in 0..r {
let mut acc = 0.0_f64;
for j_a in 0..m {
let s = blk.factor_a[[i_a, j_a]];
if s == 0.0 {
continue;
}
acc += s * x[off + j_a * r + i_b];
}
out[off + i_a * r + i_b] += acc;
}
}
}
// Frame blocks: (G ⊗ W) coupling atom_j's border groups (width r_j) into atom_i's
// (width r_i).
for blk in &frame.frame_blocks {
let r_i = frame.ranks[blk.atom_i];
let r_j = frame.ranks[blk.atom_j];
let off_i = frame.border_offsets[blk.atom_i];
let off_j = frame.border_offsets[blk.atom_j];
let (m_i, m_j) = blk.g.dim();
for li in 0..m_i {
let yi_base = off_i + li * r_i;
for lj in 0..m_j {
let g = blk.g[[li, lj]];
if g == 0.0 {
continue;
}
let xj_base = off_j + lj * r_j;
for a in 0..r_i {
let mut acc = 0.0_f64;
for b in 0..r_j {
acc += blk.w[[a, b]] * x[xj_base + b];
}
out[yi_base + a] += g * acc;
}
}
}
}
}
/// CPU reference for the framed SAE reduced-β Schur matvec:
/// `out = ridge_beta * x + P x - Σ_i H_tβ,iᵀ (H_tt,i + ridge_t I)⁻¹ H_tβ,i x`.
///
/// `P x` comes from `sae_framed_penalty_matvec_cpu`. Each `H_tβ,i` is read from the frame's
/// `row_htbeta` slab, stored as `q_i` contiguous length-`k` rows: slab entry `c * k + a` is
/// `H_tβ,i[c, a]`. Rows with an empty slab, or with `q_i == 0`, contribute nothing. The
/// ridged `H_tt,i` blocks are re-factored on every call.
///
/// # Errors
/// Returns a description when:
/// - the frame metadata is missing;
/// - the slab count differs from the row count (reported before `out` is touched);
/// - a non-empty slab is not `q_i * k` long;
/// - a row's `H_tt` is not `q_i × q_i`;
/// - a ridged `H_tt` is not positive definite.
#[doc(hidden)]
pub fn sae_framed_schur_matvec_cpu(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    x: &[f64],
    out: &mut [f64],
) -> Result<(), String> {
    let frame = data
        .frame
        .as_ref()
        .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
    let k = data.beta_dim;
    // Check the slab count before writing the penalty term, so this error leaves `out` intact.
    if frame.row_htbeta.len() != sys.rows.len() {
        return Err(format!(
            "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
            frame.row_htbeta.len(),
            sys.rows.len()
        ));
    }
    sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
    for (i, row) in sys.rows.iter().enumerate() {
        let slab = &frame.row_htbeta[i];
        if slab.is_empty() {
            continue;
        }
        let qi = sys.row_dims[i];
        if qi == 0 {
            continue;
        }
        // A mis-sized slab is an inconsistency. Skipping it would silently drop this row's
        // Schur correction, so report it instead.
        if slab.len() != qi * k {
            return Err(format!(
                "sae_framed_schur_matvec_cpu: row {i} H_tβ slab has {} entries, expected {qi}×{k}",
                slab.len()
            ));
        }
        if row.htt.dim() != (qi, qi) {
            return Err(format!(
                "sae_framed_schur_matvec_cpu: row {i} H_tt is {:?}, expected ({qi}, {qi})",
                row.htt.dim()
            ));
        }
        // h = H_tβ,i x
        let mut h = vec![0.0_f64; qi];
        for c in 0..qi {
            let base = c * k;
            let mut acc = 0.0_f64;
            for a in 0..k {
                acc += slab[base + a] * x[a];
            }
            h[c] = acc;
        }
        let mut block = row.htt.clone();
        for j in 0..qi {
            block[[j, j]] += ridge_t;
        }
        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
            .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
        // s = (H_tt,i + ridge_t I)⁻¹ h;  out -= H_tβ,iᵀ s
        let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
        for c in 0..qi {
            let sc = s[c];
            if sc == 0.0 {
                continue;
            }
            let base = c * k;
            for a in 0..k {
                out[a] -= slab[base + a] * sc;
            }
        }
    }
    Ok(())
}
#[cfg(target_os = "linux")]
mod cuda {
use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
use gam_gpu::driver::to_i32;
use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
use crate::arrow_schur::{
ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
};
use cudarc::cublas::sys::{
cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
};
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
use cudarc::driver::{
CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
PushKernelArg,
};
use ndarray::Array1;
use std::sync::{Arc, OnceLock};
/// Per-row working state for the multi-GPU solve.
///
/// Holds the packed inputs, the forward-pass outputs, the tile-level partial sums and the
/// back-substituted step.
struct RowSlot {
/// Column-major ridged `H_tt` block (`d×d`), from `pack_block`.
d_block: Vec<f64>,
/// Column-major `H_tβ` block (`d×k`), from `pack_block`.
b_block: Vec<f64>,
/// Row gradient `g_t` (length `d`).
g_vec: Vec<f64>,
/// `max(|diag(H_tt)|, 1)`, used to size a ridge bump.
diag_scale: f64,
/// Cholesky factor `L` of the ridged block (set by `forward_tile`).
l_block: Vec<f64>,
/// `L⁻¹ g_t` (set by `forward_tile`).
u_vec: Vec<f64>,
/// `L⁻¹ H_tβ` (set by `forward_tile`).
y_block: Vec<f64>,
/// `Σ_j ln L_jj` for this row (set by `forward_tile`).
log_det_local: f64,
/// Computed ridge bump when this row's Cholesky failed.
bump: Option<f64>,
/// The tile's partial reduced-β matrix; only set on a tile's first slot.
tile_partial_schur: Option<Vec<f64>>,
/// The tile's partial reduced-β RHS; only set on a tile's first slot.
tile_partial_rhs: Option<Vec<f64>>,
/// Back-substituted `Δt` for this row (length `d`, set by `back_sub_tile`).
delta_t_block: Vec<f64>,
}
/// Multi-GPU variant of `solve`, run in three phases:
/// 1. Rows are scattered across devices for the row-block phase (Cholesky, forward solves,
///    per-tile partial reduced-β sums).
/// 2. The reduced β system is factored and solved on the runtime's selected device.
/// 3. The back-substitution is scattered across devices again.
///
/// # Errors
/// - `Unavailable` for any unsupported input or device failure; the caller then falls back
///   to a single-device path.
/// - `RidgeBumpRequired` when a row Cholesky fails.
/// - `SchurFactorFailed` when the reduced-β Cholesky fails.
pub(super) fn solve_multi_gpu(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
if n == 0 || d == 0 || k == 0 {
return Err(ArrowSchurGpuFailure::Unavailable);
}
if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
return Err(ArrowSchurGpuFailure::Unavailable);
}
let runtime = gam_gpu::device_runtime::GpuRuntime::global()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
if runtime.device_count() < 2 {
return Err(ArrowSchurGpuFailure::Unavailable);
}
// Pack every row on the host. `forward_tile` / `back_sub_tile` fill the remaining fields.
let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
for row in &sys.rows {
if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
return Err(ArrowSchurGpuFailure::Unavailable);
}
let mut d_block = Vec::with_capacity(d * d);
let mut b_block = Vec::with_capacity(d * k);
let mut g_vec = Vec::with_capacity(d);
pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
let diag_scale = row
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
slots.push(RowSlot {
d_block,
b_block,
g_vec,
diag_scale,
l_block: Vec::new(),
u_vec: Vec::new(),
y_block: Vec::new(),
log_det_local: 0.0,
bump: None,
tile_partial_schur: None,
tile_partial_rhs: None,
delta_t_block: vec![0.0; d],
});
}
// Phase 1 (per device): row Cholesky, L⁻¹g, L⁻¹H_tβ and each tile's partial reduced-β sums.
let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
forward_tile(ordinal, d, k, tile)
});
if forward_ok.is_none() {
return Err(ArrowSchurGpuFailure::Unavailable);
}
// NOTE(review): each tile's partials live on its first slot, so this must reproduce the
// tiling `scatter_batched` used — confirm in `gam_gpu::pool`.
let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
// Report the first failed row Cholesky before reading partials: a failing tile returns
// early without storing them.
if let Some((row, bump)) = slots
.iter()
.enumerate()
.find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
{
return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
}
// The reduced system starts from H_ββ + ridge_beta·I (column-major) and -g_β; the tile
// partials are then added in.
let mut schur_host = vec![0.0_f64; k * k];
for col in 0..k {
for row in 0..k {
let mut v = sys.hbb[[row, col]];
if row == col {
v += ridge_beta;
}
schur_host[col * k + row] = v;
}
}
let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
let mut log_det = 0.0_f64;
for start in tile_starts(&row_base_of_tile) {
let slot = &slots[start];
let partial_schur = slot
.tile_partial_schur
.as_ref()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let partial_rhs = slot
.tile_partial_rhs
.as_ref()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
for idx in 0..k * k {
schur_host[idx] += partial_schur[idx];
}
for a in 0..k {
rhs_host[a] += partial_rhs[a];
}
}
for slot in &slots {
log_det += slot.log_det_local;
}
// Phase 2 (selected device): factor the reduced system and solve L Lᵀ Δβ = rhs.
let primary = runtime.selected_device().ordinal;
let stream = gam_gpu::device_runtime::cuda_context_for(primary)
.and_then(|ctx| ctx.new_stream().ok())
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let solver =
DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut schur_dev = stream
.clone_htod(&schur_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut rhs_dev = stream
.clone_htod(&rhs_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
if info != 0 {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
});
}
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
let delta_beta_host = stream
.clone_dtoh(&rhs_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let delta_beta = Array1::from_vec(delta_beta_host.clone());
let l_schur_host = stream
.clone_dtoh(&schur_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
// log det H = 2 (Σ ln diag L_rows + Σ ln diag L_schur).
for j in 0..k {
log_det += l_schur_host[j * k + j].ln();
}
log_det *= 2.0;
// Phase 3 (per device): back-substitute each row for Δt.
let delta_beta_ref = &delta_beta_host;
let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
});
if back_ok.is_none() {
return Err(ArrowSchurGpuFailure::Unavailable);
}
let mut delta_t = Array1::<f64>::zeros(n * d);
for (i, slot) in slots.iter().enumerate() {
let base = i * d;
for r in 0..d {
delta_t[base + r] = slot.delta_t_block[r];
}
}
Ok(ArrowSchurGpuSolution {
delta_t,
delta_beta,
log_det_hessian: log_det,
})
}
/// Yields the global index of the first row of every non-empty tile.
///
/// `forward_tile` stores each tile's partial reduced-β sums on that tile's first slot, so
/// these are the slots the partials are read from. An empty tile owns no slot. Yielding its
/// `start` would either alias the next tile's first row, double-counting that tile's
/// partials, or point one past the end of the slot vector. Empty tiles are therefore
/// skipped.
fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
    tiles
        .iter()
        .filter(|(_, range)| !range.is_empty())
        .map(|(_, range)| range.start)
}
/// Forward (row-block) phase of the multi-GPU solve for one tile of rows on device `ordinal`.
///
/// Steps:
/// 1. Upload the tile's packed blocks.
/// 2. Run batched Cholesky.
/// 3. Compute `u = L⁻¹ g_t` and `Y = L⁻¹ H_tβ`.
/// 4. Accumulate the tile's partial reduced-β matrix and RHS with `accumulate_schur` into
///    zero-initialised device buffers.
/// 5. Copy each row's `L`, `u`, `Y` and `Σ ln L_jj` back into its slot, and put the tile
///    partials on `tile[0]`.
///
/// If a row's Cholesky fails, only the first failing slot's `bump` is set, and the function
/// returns `Some(())` early without storing partials. Returns `None` on any CUDA failure.
fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
if tile.is_empty() {
return Some(());
}
let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
.and_then(|ctx| ctx.new_stream().ok())?;
let solver = DnHandle::new(stream.clone()).ok()?;
let blas = CudaBlas::new(stream.clone()).ok()?;
let m = tile.len();
let mut d_host = Vec::with_capacity(m * d * d);
let mut b_host = Vec::with_capacity(m * d * k);
let mut g_host = Vec::with_capacity(m * d);
for slot in tile.iter() {
d_host.extend_from_slice(&slot.d_block);
b_host.extend_from_slice(&slot.b_block);
g_host.extend_from_slice(&slot.g_vec);
}
let mut d_dev = stream.clone_htod(&d_host).ok()?;
let mut b_dev = stream.clone_htod(&b_host).ok()?;
let mut g_dev = stream.clone_htod(&g_host).ok()?;
let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
if let Some(local) = info_host.iter().position(|info| *info != 0) {
let pivot = info_host[local];
// Bump = diag scale * max(|info|, 1) * sqrt(eps) * margin.
tile[local].bump = Some(
tile[local].diag_scale
* (f64::from(pivot).abs()).max(1.0)
* f64::EPSILON.sqrt()
* super::RIDGE_BUMP_EPS_MARGIN,
);
return Some(());
}
trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
let l_host = stream.clone_dtoh(&d_dev).ok()?;
let u_host = stream.clone_dtoh(&g_dev).ok()?;
let y_host = stream.clone_dtoh(&b_dev).ok()?;
let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
for (local, slot) in tile.iter_mut().enumerate() {
let l_base = local * d * d;
let u_base = local * d;
let y_base = local * d * k;
slot.l_block = l_host[l_base..l_base + d * d].to_vec();
slot.u_vec = u_host[u_base..u_base + d].to_vec();
slot.y_block = y_host[y_base..y_base + d * k].to_vec();
let mut log_det_local = 0.0_f64;
for j in 0..d {
log_det_local += l_host[l_base + j * d + j].ln();
}
slot.log_det_local = log_det_local;
}
// One partial per tile, parked on its first slot; `solve_multi_gpu` reads it from there.
tile[0].tile_partial_schur = Some(partial_schur);
tile[0].tile_partial_rhs = Some(partial_rhs);
Some(())
}
/// Back-substitution phase of the multi-GPU solve for one tile of rows on device `ordinal`.
///
/// Steps:
/// 1. Re-upload each slot's `L`, `u` and `Y` together with the solved `Δβ`.
/// 2. Form each row's RHS with `accumulate_back_sub_rhs` (presumably `u + Y Δβ`).
/// 3. Solve `Lᵀ x = rhs` and store `Δt = -x` in each slot.
///
/// Returns `None` on any CUDA failure.
fn back_sub_tile(
ordinal: usize,
d: usize,
k: usize,
delta_beta: &[f64],
tile: &mut [RowSlot],
) -> Option<()> {
if tile.is_empty() {
return Some(());
}
let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
.and_then(|ctx| ctx.new_stream().ok())?;
let blas = CudaBlas::new(stream.clone()).ok()?;
let m = tile.len();
let mut l_host = Vec::with_capacity(m * d * d);
let mut u_host = Vec::with_capacity(m * d);
let mut y_host = Vec::with_capacity(m * d * k);
for slot in tile.iter() {
l_host.extend_from_slice(&slot.l_block);
u_host.extend_from_slice(&slot.u_vec);
y_host.extend_from_slice(&slot.y_block);
}
let d_dev = stream.clone_htod(&l_host).ok()?;
let mut g_dev = stream.clone_htod(&u_host).ok()?;
let b_dev = stream.clone_htod(&y_host).ok()?;
let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
let x_host = stream.clone_dtoh(&g_dev).ok()?;
// Δt = -L⁻ᵀ(u + Y Δβ), matching the single-device `solve`.
for (local, slot) in tile.iter_mut().enumerate() {
let base = local * d;
for r in 0..d {
slot.delta_t_block[r] = -x_host[base + r];
}
}
Some(())
}
pub(super) fn solve(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
.and_then(|ctx| ctx.new_stream().ok())
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let solver =
DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
let mut d_dev = stream
.clone_htod(&d_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut b_dev = stream
.clone_htod(&b_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut g_dev = stream
.clone_htod(&g_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
if let Some(idx) = info_host.iter().position(|info| *info != 0) {
let pivot = info_host[idx];
let scale = sys.rows[idx]
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
row: idx,
bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
});
}
trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
let schur_init: Vec<f64> = {
let mut tmp = Vec::with_capacity(k * k);
for col in 0..k {
for row in 0..k {
let mut v = sys.hbb[[row, col]];
if row == col {
v += ridge_beta;
}
tmp.push(v);
}
}
tmp
};
let mut schur_dev = stream
.clone_htod(&schur_init)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
let mut rhs_dev = stream
.clone_htod(&rhs_init)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
if info != 0 {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("Schur Cholesky failed at pivot {info}"),
});
}
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
let delta_beta_host = stream
.clone_dtoh(&rhs_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let delta_beta = Array1::from_vec(delta_beta_host.clone());
accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
let x_host = stream
.clone_dtoh(&g_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut delta_t = Array1::<f64>::zeros(n * d);
for (i, v) in x_host.iter().enumerate() {
delta_t[i] = -*v;
}
let l_local_host = stream
.clone_dtoh(&d_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let l_schur_host = stream
.clone_dtoh(&schur_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut log_det = 0.0_f64;
for i in 0..n {
let base = i * d * d;
for j in 0..d {
log_det += l_local_host[base + j * d + j].ln();
}
}
for j in 0..k {
log_det += l_schur_host[j * k + j].ln();
}
log_det *= 2.0;
Ok(ArrowSchurGpuSolution {
delta_t,
delta_beta,
log_det_hessian: log_det,
})
}
/// In-place lower Cholesky of `batch` column-major `p×p` matrices stored back-to-back in
/// `matrices`, via `cusolverDnDpotrfBatched`.
///
/// Returns the per-matrix `info` values: 0 on success, non-zero for a failed factorisation
/// (see the cuSOLVER docs for the exact encoding). Size-conversion, library and transfer
/// failures map to `Unavailable`.
///
/// NOTE(review): `matrices.len() >= batch * p * p` is assumed, not checked.
fn potrf_batched(
solver: &DnHandle,
stream: &Arc<CudaStream>,
p: usize,
batch: usize,
matrices: &mut CudaSlice<f64>,
) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let matrix_len = p * p;
let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
// The batched API takes a device array of per-matrix device pointers; build it on the host.
let (base_ptr, _record) = matrices.device_ptr_mut(stream);
let mut ptrs = Vec::with_capacity(batch);
for idx in 0..batch {
ptrs.push(base_ptr + (idx as u64) * bytes_per);
}
let mut ptrs_dev = stream
.clone_htod(&ptrs)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut info_dev = stream
.alloc_zeros::<i32>(batch)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let status = {
let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
// SAFETY: `ptrs_dev` holds `batch` device pointers into `matrices` at a stride of `p*p`
// doubles, and `info_dev` has `batch` entries. The `_record`, `_ptrs_record` and
// `_info_record` guards keep these borrows alive across the call, and `p` / `batch` were
// range-checked by `to_i32`. In-bounds access relies on the length assumption above.
unsafe {
cusolver_sys::cusolverDnDpotrfBatched(
solver.cu(),
cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
p_i,
ptrs_ptr as *mut *mut f64,
p_i,
info_ptr as *mut i32,
batch_i,
)
}
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
stream
.clone_dtoh(&info_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)
}
/// In-place lower Cholesky of one column-major `p×p` matrix via `cusolverDnDpotrf`.
///
/// Queries and allocates the required workspace first. Returns cuSOLVER's `info` (0 on
/// success). Size-conversion, library and transfer failures map to `Unavailable`.
///
/// NOTE(review): `matrix.len() >= p * p` is assumed, not checked.
fn potrf_single(
solver: &DnHandle,
stream: &Arc<CudaStream>,
p: usize,
matrix: &mut CudaSlice<f64>,
) -> Result<i32, ArrowSchurGpuFailure> {
let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
let mut lwork = 0_i32;
{
let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
// SAFETY: `mat_ptr` is a live device pointer to `matrix` (kept alive by `_rec`), and
// `lwork` is a valid host out-parameter.
let status = unsafe {
cusolver_sys::cusolverDnDpotrf_bufferSize(
solver.cu(),
uplo,
p_i,
mat_ptr as *mut f64,
p_i,
&mut lwork,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
}
let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
// Allocate at least one element so the workspace pointer is never a zero-length allocation.
let mut workspace = stream
.alloc_zeros::<f64>(lwork_usize.max(1))
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut info_dev = stream
.alloc_zeros::<i32>(1)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
{
let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
// SAFETY: all three device pointers come from live `CudaSlice`s held by the guards
// above. The workspace has at least `lwork` doubles, as requested by `_bufferSize`.
let status = unsafe {
cusolver_sys::cusolverDnDpotrf(
solver.cu(),
uplo,
p_i,
mat_ptr as *mut f64,
p_i,
work_ptr as *mut f64,
lwork,
info_ptr as *mut i32,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
}
let info_host = stream
.clone_dtoh(&info_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
Ok(info_host[0])
}
/// Batched forward substitution: overwrites each `B_i` in `rhs_stack` with `L_i⁻¹ B_i`.
///
/// `L_i` is the `i`-th lower-triangular `d×d` factor in `l_stack`, and `B_i` is `d×nrhs`.
fn trsm_batched_lower_inplace(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    d: usize,
    n: usize,
    nrhs: usize,
    l_stack: &CudaSlice<f64>,
    rhs_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    const TRANSPOSED: bool = false;
    trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, TRANSPOSED)
}
/// Batched backward substitution: overwrites each `B_i` in `rhs_stack` with `L_i⁻ᵀ B_i`.
///
/// `L_i` is the `i`-th lower-triangular `d×d` factor in `l_stack`, and `B_i` is `d×nrhs`.
fn trsm_batched_lower_inplace_transposed(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    d: usize,
    n: usize,
    nrhs: usize,
    l_stack: &CudaSlice<f64>,
    rhs_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    const TRANSPOSED: bool = true;
    trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, TRANSPOSED)
}
/// Batched in-place triangular solve with lower-triangular, non-unit-diagonal factors.
///
/// For each of the `n` blocks, solves `L_i X_i = B_i` (or `L_iᵀ X_i = B_i` when
/// `transposed`) and overwrites `B_i` with `X_i`. Both stacks are column-major:
/// - `l_stack` holds the `d×d` factors back-to-back;
/// - `rhs_stack` holds the `d×nrhs` right-hand sides back-to-back.
///
/// NOTE(review): stack lengths of at least `n*d*d` and `n*d*nrhs` are assumed, not checked.
fn trsm_batched_inplace_inner(
blas: &CudaBlas,
stream: &Arc<CudaStream>,
d: usize,
n: usize,
nrhs: usize,
l_stack: &CudaSlice<f64>,
rhs_stack: &mut CudaSlice<f64>,
transposed: bool,
) -> Result<(), ArrowSchurGpuFailure> {
let alpha = 1.0_f64;
let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
// Build per-block device pointer arrays for the batched cuBLAS call.
let (l_base, _l_record) = l_stack.device_ptr(stream);
let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
let mut l_ptrs = Vec::with_capacity(n);
let mut rhs_ptrs = Vec::with_capacity(n);
for i in 0..n {
l_ptrs.push(l_base + (i as u64) * l_bytes_per);
rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
}
let mut l_ptrs_dev = stream
.clone_htod(&l_ptrs)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut rhs_ptrs_dev = stream
.clone_htod(&rhs_ptrs)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
let op = if transposed {
cublasOperation_t::CUBLAS_OP_T
} else {
cublasOperation_t::CUBLAS_OP_N
};
let handle = *blas.handle();
// SAFETY: both pointer arrays live on the device and hold `n` pointers into `l_stack` /
// `rhs_stack`, at strides of `d*d` and `d*nrhs` doubles. All `_*record` / `_*rec` guards
// are held until the end of the function, and the sizes were range-checked by `to_i32`.
// `alpha` is a host scalar; this assumes the handle uses the default host pointer mode.
let status = unsafe {
cudarc::cublas::sys::cublasDtrsmBatched(
handle,
cublasSideMode_t::CUBLAS_SIDE_LEFT,
cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
op,
cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
d_i,
nrhs_i,
&alpha,
l_ptrs_ptr as *const *const f64,
d_i,
rhs_ptrs_ptr as *const *mut f64,
d_i,
batch_i,
)
};
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return Err(ArrowSchurGpuFailure::Unavailable);
}
Ok(())
}
/// Solves `op(T) x = rhs` in place for a single right-hand side via `cublasDtrsm_v2`.
///
/// `T` is the `n × n` triangular matrix stored in `l` with leading dimension `n`:
/// * its upper triangle is used when `upper` is set, its lower triangle otherwise;
/// * `op` is the transpose when `transposed` is set.
///
/// The call uses the left side, a non-unit diagonal and `alpha = 1`.
///
/// # Errors
/// `Unavailable` if `n` does not fit `i32` or cuBLAS returns a non-success status.
fn trsm_single(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    l: &CudaSlice<f64>,
    rhs: &mut CudaSlice<f64>,
    upper: bool,
    transposed: bool,
) -> Result<(), ArrowSchurGpuFailure> {
    let unit_alpha = 1.0_f64;
    let dim = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let handle = *blas.handle();
    let (tri_ptr, _tri_guard) = l.device_ptr(stream);
    let (vec_ptr, _vec_guard) = rhs.device_ptr_mut(stream);
    let fill = match upper {
        true => cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
        false => cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
    };
    let op = match transposed {
        true => cublasOperation_t::CUBLAS_OP_T,
        false => cublasOperation_t::CUBLAS_OP_N,
    };
    // SAFETY: `tri_ptr` addresses an `n × n` matrix and `vec_ptr` a length-`n`
    // vector (caller's contract). The guards above keep both borrows alive
    // for the duration of the call.
    let status = unsafe {
        cudarc::cublas::sys::cublasDtrsm_v2(
            handle,
            cublasSideMode_t::CUBLAS_SIDE_LEFT,
            fill,
            op,
            cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
            dim,
            1,
            &unit_alpha,
            tri_ptr as *const f64,
            dim,
            vec_ptr as *mut f64,
            dim,
        )
    };
    if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        Ok(())
    } else {
        Err(ArrowSchurGpuFailure::Unavailable)
    }
}
/// Folds every row block's contribution into the shared β system.
///
/// Row block `i` has `Y_i`, the column-major `d × k` slice `i` of `y_stack`
/// (leading dimension `d`), and `u_i`, the length-`d` slice `i` of `u_stack`.
/// For each row block:
/// * `schur -= Y_iᵀ Y_i` (a `k × k` GEMM, leading dimension `k`);
/// * `rhs += Y_iᵀ u_i` (a GEMV).
///
/// # Errors
/// `Unavailable` if `d` or `k` does not fit `i32`, or a cuBLAS call fails.
fn accumulate_schur(
    blas: &CudaBlas,
    d: usize,
    k: usize,
    n: usize,
    y_stack: &CudaSlice<f64>,
    u_stack: &CudaSlice<f64>,
    schur: &mut CudaSlice<f64>,
    rhs: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    if n == 0 {
        return Ok(());
    }
    // Dimensions are loop-invariant; convert them once instead of 9× per row.
    let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let k_i = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let y_block_elems = d * k;
    for i in 0..n {
        let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
        let u_slice = u_stack.slice(i * d..(i + 1) * d);
        let gemm_cfg = GemmConfig::<f64> {
            transa: cublasOperation_t::CUBLAS_OP_T,
            transb: cublasOperation_t::CUBLAS_OP_N,
            m: k_i,
            n: k_i,
            k: d_i,
            alpha: -1.0,
            lda: d_i,
            ldb: d_i,
            beta: 1.0,
            ldc: k_i,
        };
        // SAFETY: `y_slice` holds exactly `d·k` elements (lda = d), and `schur`
        // is assumed (caller's contract) to hold the `k × k` matrix.
        unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: d_i,
            n: k_i,
            alpha: 1.0,
            lda: d_i,
            incx: 1,
            beta: 1.0,
            incy: 1,
        };
        // SAFETY: `u_slice` has `d` elements; `rhs` is assumed to have `k`.
        unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
/// Right-hand-side half of `accumulate_schur`: `rhs += Σ_i Y_iᵀ u_i`.
///
/// `Y_i` is the column-major `d × k` slice `i` of `y_stack` (leading dimension
/// `d`); `u_i` is the length-`d` slice `i` of `u_stack`. The Schur matrix is
/// left untouched.
///
/// # Errors
/// `Unavailable` if `d` or `k` does not fit `i32`, or a cuBLAS call fails.
fn accumulate_schur_rhs_only(
    blas: &CudaBlas,
    d: usize,
    k: usize,
    n: usize,
    y_stack: &CudaSlice<f64>,
    u_stack: &CudaSlice<f64>,
    rhs: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    if n == 0 {
        return Ok(());
    }
    // Loop-invariant conversions, done once.
    let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let k_i = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let y_block_elems = d * k;
    for i in 0..n {
        let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
        let u_slice = u_stack.slice(i * d..(i + 1) * d);
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_T,
            m: d_i,
            n: k_i,
            alpha: 1.0,
            lda: d_i,
            incx: 1,
            beta: 1.0,
            incy: 1,
        };
        // SAFETY: `y_slice` is `d × k` (lda = d), `u_slice` has `d` elements,
        // and `rhs` is assumed to have `k` elements.
        unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
/// Back-substitution update of the per-row right-hand sides: `u_i += Y_i δβ`.
///
/// `Y_i` is the column-major `d × k` slice `i` of `y_stack` (leading dimension
/// `d`); `u_i` is the length-`d` slice `i` of `u_stack`, updated in place.
///
/// # Errors
/// `Unavailable` if `d` or `k` does not fit `i32`, or a cuBLAS call fails.
fn accumulate_back_sub_rhs(
    blas: &CudaBlas,
    d: usize,
    k: usize,
    n: usize,
    y_stack: &CudaSlice<f64>,
    delta_beta: &CudaSlice<f64>,
    u_stack: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    if n == 0 {
        return Ok(());
    }
    // Loop-invariant conversions, done once.
    let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let k_i = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let y_block_elems = d * k;
    for i in 0..n {
        let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
        let mut u_slice = u_stack.slice_mut(i * d..(i + 1) * d);
        let gemv_cfg = GemvConfig::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: d_i,
            n: k_i,
            alpha: 1.0,
            lda: d_i,
            incx: 1,
            beta: 1.0,
            incy: 1,
        };
        // SAFETY: `y_slice` is `d × k` (lda = d), `delta_beta` is assumed to
        // have `k` elements, and `u_slice` has `d` elements.
        unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
use std::collections::HashMap;
use std::sync::Mutex;
struct FusedModuleCache {
modules: Mutex<
HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
>,
}
fn fused_module_cache() -> &'static FusedModuleCache {
static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
CACHE.get_or_init(|| FusedModuleCache {
modules: Mutex::new(HashMap::new()),
})
}
fn fused_module_for(
ctx: &Arc<CudaContext>,
key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
let cache = fused_module_cache();
if let Ok(guard) = cache.modules.lock() {
if let Some(existing) = guard.get(&key) {
return Ok(existing.clone());
}
}
let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
key.p_max as usize,
key.r_template as usize,
);
let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!(
"arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
key.p_max, key.r_template
),
}
})?;
let module = ctx
.load_module(ptx)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
if let Ok(mut guard) = cache.modules.lock() {
guard.entry(key).or_insert_with(|| module.clone());
}
Ok(module)
}
/// CUDA C source (compiled at runtime through `PtxModuleCache`) for the PCG
/// vector helpers and the SAE / frame reduced-Schur operator kernels.
///
/// The accumulating matvec kernels write into `out` with `atomicAdd`, because
/// different y-blocks may target the same output rows.
const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}
extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}
extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}
/* Accumulating matvec kernels (smooth / sparse-G / frame variants): several
 * y-blocks may add into the same rows of `out` (e.g. G_{ij} for a fixed i and
 * varying j), so every accumulation uses atomicAdd. As with the scatter
 * kernels, the floating-point summation order is therefore not fixed. */
extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    atomicAdd(&out[off + li * p + oc], acc);
}
extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}
extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}
extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}
extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}
extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}
extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}
/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */
extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    atomicAdd(&out[off + li * r + ib], acc);
}
extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri;       // basis row in atom i
    int a = linear - li * ri;   // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    atomicAdd(&out[oi + li * ri + a], acc);
}
/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 *   h_i = H_tβ^(i) · x              (length q_i)
 *   s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i     (apply cached ainv, length q_i)
 *   out -= (H_tβ^(i))ᵀ · s_i         (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}
extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}
extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}
/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
/// Fetches the module built from `PCG_VECTOR_KERNEL_SOURCE`, compiling it on
/// first use through a function-local `PtxModuleCache`.
///
/// The caching granularity (per context or global) is whatever `PtxModuleCache`
/// implements. A failure is logged and reported as `Unavailable`.
fn pcg_vector_module(
    ctx: &Arc<CudaContext>,
) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
    static CACHE: gam_gpu::device_cache::PtxModuleCache =
        gam_gpu::device_cache::PtxModuleCache::new();
    match CACHE.get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE) {
        Ok(module) => Ok(module),
        Err(err) => {
            log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
            Err(ArrowSchurGpuFailure::Unavailable)
        }
    }
}
fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
let threads = 256u32;
let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
Ok(LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: 0,
})
}
/// Launches `arrow_pcg_jacobi_mul`: `z[i] = inv_diag[i] * r[i]` for `i < n`.
///
/// # Errors
/// `Unavailable` if the kernel cannot be loaded, `n` does not fit `i32`, or the
/// launch fails.
fn launch_jacobi_mul(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    inv_diag: &CudaSlice<f64>,
    r: &CudaSlice<f64>,
    z: &mut CudaSlice<f64>,
    n: usize,
) -> Result<(), ArrowSchurGpuFailure> {
    let jacobi = match module.load_function("arrow_pcg_jacobi_mul") {
        Ok(func) => func,
        Err(_) => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    let len = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let cfg = pcg_launch_config(n)?;
    let mut args = stream.launch_builder(&jacobi);
    args.arg(inv_diag);
    args.arg(r);
    args.arg(z);
    args.arg(&len);
    // SAFETY: the argument list matches the kernel signature
    // (const double*, const double*, double*, int).
    match unsafe { args.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(_) => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
/// Launches `arrow_pcg_update_p`: `p[i] = z[i] + beta * p[i]` for `i < n`.
///
/// # Errors
/// `Unavailable` if the kernel cannot be loaded, `n` does not fit `i32`, or the
/// launch fails.
fn launch_update_p(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    z: &CudaSlice<f64>,
    beta: f64,
    p: &mut CudaSlice<f64>,
    n: usize,
) -> Result<(), ArrowSchurGpuFailure> {
    let update = match module.load_function("arrow_pcg_update_p") {
        Ok(func) => func,
        Err(_) => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    let len = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let cfg = pcg_launch_config(n)?;
    let mut args = stream.launch_builder(&update);
    args.arg(z);
    args.arg(&beta);
    args.arg(p);
    args.arg(&len);
    // SAFETY: the argument list matches the kernel signature
    // (const double*, double, double*, int).
    match unsafe { args.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(_) => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
/// Device-side flattening of a `DeviceSaePcgData` problem, produced by
/// `flatten_device_sae_data` and consumed by the `arrow_sae_*` kernels.
///
/// Index/offset arrays are `i32` because the kernels take `int` pointers.
struct DeviceSaePcgBuffers {
    // CSR offsets over rows (length n_rows + 1) into `beta_base` / `phi`.
    row_ptr: CudaSlice<i32>,
    // β offsets of each row's `a_phi` entries.
    beta_base: CudaSlice<i32>,
    // Weights paired with `beta_base`.
    phi: CudaSlice<f64>,
    // CSR offsets over rows (length n_rows + 1) into `jac`; row i spans q_i * p values.
    jac_ptr: CudaSlice<i32>,
    // Concatenated per-row local Jacobians.
    jac: CudaSlice<f64>,
    // Global β offset of each smooth block.
    smooth_offsets: CudaSlice<i32>,
    // Side length m of each (square) smooth factor.
    smooth_m: CudaSlice<i32>,
    // Offsets (length smooth_blocks + 1) into `smooth_data`.
    smooth_ptr: CudaSlice<i32>,
    // Row-major `factor_a` entries of all smooth blocks.
    smooth_data: CudaSlice<f64>,
    // Per sparse-G block: β row offset.
    g_row_off: CudaSlice<i32>,
    // Per sparse-G block: β column offset.
    g_col_off: CudaSlice<i32>,
    // Per sparse-G block: number of rows.
    g_rows: CudaSlice<i32>,
    // Per sparse-G block: number of columns.
    g_cols: CudaSlice<i32>,
    // Offsets (length g_blocks + 1) into `g_data`.
    g_ptr: CudaSlice<i32>,
    // Row-major entries of all sparse-G blocks.
    g_data: CudaSlice<f64>,
    // n_rows stacked max_q × max_q blocks holding (H_tt + ridge_t I)⁻¹, zero-padded.
    ainv: CudaSlice<f64>,
    // Scratch, n_rows * p: gathered per-row values.
    u: CudaSlice<f64>,
    // Scratch, n_rows * max_q: Jacobian-applied values.
    w: CudaSlice<f64>,
    // Scratch, n_rows * max_q: `ainv`-applied values.
    v: CudaSlice<f64>,
    // Number of arrow rows.
    n_rows: usize,
    // Channels per β basis row (`DeviceSaePcgData::p`).
    p: usize,
    // Total β dimension (`DeviceSaePcgData::beta_dim`).
    k: usize,
    // Largest per-row q (Jacobian length / p).
    max_q: usize,
    // Number of smooth blocks.
    smooth_blocks: usize,
    // Number of sparse-G blocks.
    g_blocks: usize,
}
/// Narrows `value` to `i32` for a kernel argument.
///
/// Returns `Unavailable` whenever `to_i32` rejects the value.
fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
    match to_i32(value) {
        Some(narrowed) => Ok(narrowed),
        None => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
/// Builds, on the host, the Jacobi diagonal of the β penalty operator.
///
/// Every one of the `beta_dim` entries starts at `ridge_beta`. The function then
/// adds the diagonal of the two operators applied on device:
/// * `arrow_sae_smooth_matvec`: `factor_a ⊗ I_p` per smooth block at `global_offset`;
/// * `arrow_sae_sparse_g_matvec`: `data ⊗ I_p` per sparse-G block, placed at
///   β-rows `row_off..` and β-columns `col_off..`.
///
/// A G entry `(li, lj)` is diagonal whenever `row_off + li == col_off + lj`. This
/// includes blocks whose ranges overlap without being aligned, not only those
/// with `row_off == col_off`. No per-row Schur correction is applied here.
///
/// # Errors
/// `Unavailable` in any of these cases:
/// * a smooth factor is not square;
/// * an offset computation overflows;
/// * a block would address past `beta_dim`.
fn sae_penalty_diag_host(
    data: &DeviceSaePcgData,
    ridge_beta: f64,
) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
    // Adds `coeff` to the `p` consecutive channel entries starting at `base`.
    fn add_channels(
        diag: &mut [f64],
        base: usize,
        p: usize,
        coeff: f64,
    ) -> Result<(), ArrowSchurGpuFailure> {
        let end = base
            .checked_add(p)
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let slots = diag
            .get_mut(base..end)
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        for slot in slots {
            *slot += coeff;
        }
        Ok(())
    }

    let p = data.p;
    let mut diag = vec![ridge_beta; data.beta_dim];
    for block in &data.smooth_blocks {
        let (rows, cols) = block.factor_a.dim();
        if rows != cols {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        for row in 0..rows {
            let base = block
                .global_offset
                .checked_add(row.checked_mul(p).ok_or(ArrowSchurGpuFailure::Unavailable)?)
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            add_channels(&mut diag, base, p, block.factor_a[[row, row]])?;
        }
    }
    for block in &data.sparse_g_blocks {
        let (rows, cols) = block.data.dim();
        for li in 0..rows {
            let beta_row = block
                .row_off
                .checked_add(li)
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            // Column index within the block that lands on the same β index.
            let lj = match beta_row.checked_sub(block.col_off) {
                Some(lj) if lj < cols => lj,
                _ => continue,
            };
            let base = beta_row
                .checked_mul(p)
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            add_channels(&mut diag, base, p, block.data[[li, lj]])?;
        }
    }
    Ok(diag)
}
/// Flattens the host SAE PCG description into device buffers for the
/// `arrow_sae_*` kernels.
///
/// What gets packed:
/// * `a_phi` per row → CSR `row_ptr` / `beta_base` / `phi`.
/// * `local_jac` per row → CSR `jac_ptr` / `jac`; row `i` holds `q_i * p` values.
/// * Smooth blocks → `smooth_offsets` / `smooth_m` / `smooth_ptr` /
///   `smooth_data`, with `factor_a` stored row-major.
/// * Sparse G blocks → `g_*`, with `data` stored row-major.
/// * Per row, `(H_tt^(i) + ridge_t I)⁻¹` → the zero-padded
///   `n_rows × max_q × max_q` stack `ainv`.
/// * Scratch buffers `u` (`n_rows * p`) and `w` / `v` (`n_rows * max_q`),
///   zero-allocated.
///
/// # Errors
/// * `Unavailable` in any of these cases:
///   - the row counts disagree;
///   - `p == 0`;
///   - a Jacobian length is not divisible by `p`;
///   - every row has `q == 0`;
///   - a smooth factor is not square;
///   - an offset or buffer extent does not fit the kernels' `i32` indexing;
///   - a device upload or allocation fails.
/// * `SchurFactorFailed` if a row's `H_tt` is not `q × q`.
/// * `RidgeBumpRequired` if `H_tt + ridge_t I` cannot be Cholesky-factored for
///   some row. The suggested bump scales with that row's largest `|diag(H_tt)|`.
fn flatten_device_sae_data(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    stream: &Arc<CudaStream>,
) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
    let n_rows = sys.rows.len();
    let p = data.p;
    let k = data.beta_dim;
    if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // `p` is the modulus/divisor for every Jacobian length below; `% 0` and
    // `/ 0` would panic.
    if p == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Row CSR: entries row_ptr[i]..row_ptr[i + 1] of beta_base/phi belong to row i.
    let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
    let mut beta_base_host = Vec::<i32>::new();
    let mut phi_host = Vec::<f64>::new();
    row_ptr_host.push(0_i32);
    for row in data.a_phi.iter() {
        for &(base, phi) in row {
            beta_base_host.push(checked_i32(base)?);
            phi_host.push(phi);
        }
        row_ptr_host.push(checked_i32(beta_base_host.len())?);
    }
    // Jacobian CSR: row i contributes q_i * p values.
    let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
    let mut jac_host = Vec::<f64>::new();
    let mut max_q = 0usize;
    jac_ptr_host.push(0_i32);
    for row_jac in data.local_jac.iter() {
        if row_jac.len() % p != 0 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        max_q = max_q.max(row_jac.len() / p);
        jac_host.extend_from_slice(row_jac);
        jac_ptr_host.push(checked_i32(jac_host.len())?);
    }
    if max_q == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // The kernels index `u`, `w`/`v` and `ainv` with 32-bit `int` products
    // (`row * p + oc`, `row * max_q + c`, `row * max_q * max_q`), so every
    // extent must be addressable as i32. The checked products also avoid a
    // usize overflow in the allocations below.
    let u_len = n_rows
        .checked_mul(p)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let wv_len = n_rows
        .checked_mul(max_q)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ainv_len = wv_len
        .checked_mul(max_q)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    checked_i32(u_len)?;
    checked_i32(ainv_len)?;
    let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
    let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
    let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
    let mut smooth_data_host = Vec::<f64>::new();
    smooth_ptr_host.push(0_i32);
    for block in &data.smooth_blocks {
        let (rows, cols) = block.factor_a.dim();
        if rows != cols {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        smooth_offsets_host.push(checked_i32(block.global_offset)?);
        smooth_m_host.push(checked_i32(rows)?);
        for r in 0..rows {
            for c in 0..cols {
                smooth_data_host.push(block.factor_a[[r, c]]);
            }
        }
        smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
    }
    let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
    let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
    let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
    let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
    let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
    let mut g_data_host = Vec::<f64>::new();
    g_ptr_host.push(0_i32);
    for block in &data.sparse_g_blocks {
        let (rows, cols) = block.data.dim();
        g_row_off_host.push(checked_i32(block.row_off)?);
        g_col_off_host.push(checked_i32(block.col_off)?);
        g_rows_host.push(checked_i32(rows)?);
        g_cols_host.push(checked_i32(cols)?);
        for r in 0..rows {
            for c in 0..cols {
                g_data_host.push(block.data[[r, c]]);
            }
        }
        g_ptr_host.push(checked_i32(g_data_host.len())?);
    }
    // Row-major, zero-padded inverse of each row's ridged H_tt block.
    let block_stride = max_q * max_q;
    let mut ainv_host = vec![0.0_f64; ainv_len];
    for (row_idx, row) in sys.rows.iter().enumerate() {
        let q = data.local_jac[row_idx].len() / p;
        if row.htt.dim() != (q, q) {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
                    row.htt.dim()
                ),
            });
        }
        let mut block = row.htt.clone();
        for d in 0..q {
            block[[d, d]] += ridge_t;
        }
        let factor = gam_linalg::triangular::cholesky_factor_in_place(
            block.view(),
            gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
        )
        .ok_or_else(|| {
            let scale = row
                .htt
                .diag()
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f64, f64::max)
                .max(1.0);
            ArrowSchurGpuFailure::RidgeBumpRequired {
                row: row_idx,
                bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
            }
        })?;
        // Column `col` of the inverse is the solve against the unit vector e_col;
        // a single basis vector is reused across columns.
        let row_base = row_idx * block_stride;
        let mut e = Array1::<f64>::zeros(q);
        for col in 0..q {
            e[col] = 1.0;
            let solved =
                gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
            e[col] = 0.0;
            for r in 0..q {
                ainv_host[row_base + r * max_q + col] = solved[r];
            }
        }
    }
    Ok(DeviceSaePcgBuffers {
        row_ptr: stream
            .clone_htod(&row_ptr_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        beta_base: stream
            .clone_htod(&beta_base_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        phi: stream
            .clone_htod(&phi_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        jac_ptr: stream
            .clone_htod(&jac_ptr_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        jac: stream
            .clone_htod(&jac_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        smooth_offsets: stream
            .clone_htod(&smooth_offsets_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        smooth_m: stream
            .clone_htod(&smooth_m_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        smooth_ptr: stream
            .clone_htod(&smooth_ptr_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        smooth_data: stream
            .clone_htod(&smooth_data_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_row_off: stream
            .clone_htod(&g_row_off_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_col_off: stream
            .clone_htod(&g_col_off_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_rows: stream
            .clone_htod(&g_rows_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_cols: stream
            .clone_htod(&g_cols_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_ptr: stream
            .clone_htod(&g_ptr_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        g_data: stream
            .clone_htod(&g_data_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        ainv: stream
            .clone_htod(&ainv_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        u: stream
            .alloc_zeros::<f64>(u_len)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        w: stream
            .alloc_zeros::<f64>(wv_len)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        v: stream
            .alloc_zeros::<f64>(wv_len)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
        n_rows,
        p,
        k,
        max_q,
        smooth_blocks: data.smooth_blocks.len(),
        g_blocks: data.sparse_g_blocks.len(),
    })
}
/// Launches `arrow_sae_init`: `out[i] = ridge * x[i]` for `i < n`.
///
/// # Errors
/// `Unavailable` if the kernel cannot be loaded, `n` does not fit `i32`, or the
/// launch fails.
fn launch_sae_init(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    out: &mut CudaSlice<f64>,
    x: &CudaSlice<f64>,
    ridge: f64,
    n: usize,
) -> Result<(), ArrowSchurGpuFailure> {
    let init = match module.load_function("arrow_sae_init") {
        Ok(func) => func,
        Err(_) => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    let len = checked_i32(n)?;
    let cfg = pcg_launch_config(n)?;
    let mut args = stream.launch_builder(&init);
    args.arg(out);
    args.arg(x);
    args.arg(&ridge);
    args.arg(&len);
    // SAFETY: the argument list matches (double*, const double*, double, int).
    match unsafe { args.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(_) => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
fn launch_sae_penalty_matvec(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
buffers: &mut DeviceSaePcgBuffers,
x: &CudaSlice<f64>,
out: &mut CudaSlice<f64>,
ridge_beta: f64,
) -> Result<(), ArrowSchurGpuFailure> {
launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
if buffers.smooth_blocks > 0 {
let kernel = module
.load_function("arrow_sae_smooth_matvec")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let max_m = buffers.k;
let p_i32 = checked_i32(buffers.p)?;
let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
let cfg = LaunchConfig {
grid_dim: (
((max_m as u32).saturating_add(255) / 256).max(1),
checked_i32(buffers.smooth_blocks)? as u32,
1,
),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernel);
builder
.arg(x)
.arg(&mut *out)
.arg(&buffers.smooth_offsets)
.arg(&buffers.smooth_m)
.arg(&buffers.smooth_ptr)
.arg(&buffers.smooth_data)
.arg(&p_i32)
.arg(&blocks_i32);
unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
if buffers.g_blocks > 0 {
let kernel = module
.load_function("arrow_sae_sparse_g_matvec")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let max_work = buffers
.k
.checked_div(buffers.p)
.unwrap_or(0)
.saturating_mul(buffers.p);
let p_i32 = checked_i32(buffers.p)?;
let blocks_i32 = checked_i32(buffers.g_blocks)?;
let cfg = LaunchConfig {
grid_dim: (
((max_work as u32).saturating_add(255) / 256).max(1),
checked_i32(buffers.g_blocks)? as u32,
1,
),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernel);
builder
.arg(x)
.arg(&mut *out)
.arg(&buffers.g_row_off)
.arg(&buffers.g_col_off)
.arg(&buffers.g_rows)
.arg(&buffers.g_cols)
.arg(&buffers.g_ptr)
.arg(&buffers.g_data)
.arg(&p_i32)
.arg(&blocks_i32);
unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
Ok(())
}
fn launch_sae_row_schur_sub(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
buffers: &mut DeviceSaePcgBuffers,
x: &CudaSlice<f64>,
out: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
let p_i32 = checked_i32(buffers.p)?;
let max_q_i32 = checked_i32(buffers.max_q)?;
let n_rows_i32 = checked_i32(buffers.n_rows)?;
let cfg_p_rows = LaunchConfig {
grid_dim: (
((buffers.p as u32).saturating_add(255) / 256).max(1),
checked_i32(buffers.n_rows)? as u32,
1,
),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let gather = module
.load_function("arrow_sae_gather_u")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
{
let mut builder = stream.launch_builder(&gather);
builder
.arg(x)
.arg(&buffers.row_ptr)
.arg(&buffers.beta_base)
.arg(&buffers.phi)
.arg(&mut buffers.u)
.arg(&p_i32)
.arg(&n_rows_i32);
unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
let cfg_q_rows = LaunchConfig {
grid_dim: (
((buffers.max_q as u32).saturating_add(255) / 256).max(1),
checked_i32(buffers.n_rows)? as u32,
1,
),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let apply_l = module
.load_function("arrow_sae_apply_l")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
{
let mut builder = stream.launch_builder(&apply_l);
builder
.arg(&buffers.u)
.arg(&buffers.jac_ptr)
.arg(&buffers.jac)
.arg(&mut buffers.w)
.arg(&p_i32)
.arg(&max_q_i32)
.arg(&n_rows_i32);
unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
let apply_ainv = module
.load_function("arrow_sae_apply_ainv")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
{
let mut builder = stream.launch_builder(&apply_ainv);
builder
.arg(&buffers.ainv)
.arg(&buffers.w)
.arg(&mut buffers.v)
.arg(&max_q_i32)
.arg(&n_rows_i32);
unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
let scatter = module
.load_function("arrow_sae_scatter_sub")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
{
let mut builder = stream.launch_builder(&scatter);
builder
.arg(&buffers.v)
.arg(&buffers.jac_ptr)
.arg(&buffers.jac)
.arg(&buffers.row_ptr)
.arg(&buffers.beta_base)
.arg(&buffers.phi)
.arg(out)
.arg(&p_i32)
.arg(&max_q_i32)
.arg(&n_rows_i32);
unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
Ok(())
}
/// Launches `arrow_sae_diag_sub` to subtract each row's Schur contribution from
/// the Jacobi diagonal.
///
/// For every row and channel `oc`, the kernel atomically subtracts
/// `phi[e]² · Σ_{c,d} J[c,oc] · ainv[c,d] · J[d,oc]` from `diag[beta_base[e] + oc]`.
///
/// # Errors
/// `Unavailable` if the kernel cannot be loaded, a size does not fit `i32`, or
/// the launch fails.
fn launch_sae_diag_sub(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    buffers: &DeviceSaePcgBuffers,
    diag: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let diag_sub = match module.load_function("arrow_sae_diag_sub") {
        Ok(func) => func,
        Err(_) => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    let p_arg = checked_i32(buffers.p)?;
    let max_q_arg = checked_i32(buffers.max_q)?;
    let rows_arg = checked_i32(buffers.n_rows)?;
    let x_tiles = ((buffers.p as u32).saturating_add(255) / 256).max(1);
    let cfg = LaunchConfig {
        grid_dim: (x_tiles, rows_arg as u32, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut args = stream.launch_builder(&diag_sub);
    args.arg(diag)
        .arg(&buffers.ainv)
        .arg(&buffers.jac_ptr)
        .arg(&buffers.jac)
        .arg(&buffers.row_ptr)
        .arg(&buffers.beta_base)
        .arg(&buffers.phi)
        .arg(&p_arg)
        .arg(&max_q_arg)
        .arg(&rows_arg);
    // SAFETY: the argument list matches `arrow_sae_diag_sub`'s signature.
    match unsafe { args.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(_) => Err(ArrowSchurGpuFailure::Unavailable),
    }
}
/// Reduced-system matvec on device.
///
/// First applies the penalty part via `launch_sae_penalty_matvec`, then
/// subtracts the per-row Schur term via `launch_sae_row_schur_sub`, both into `out`.
fn launch_sae_matvec(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    buffers: &mut DeviceSaePcgBuffers,
    x: &CudaSlice<f64>,
    out: &mut CudaSlice<f64>,
    ridge_beta: f64,
) -> Result<(), ArrowSchurGpuFailure> {
    launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
    launch_sae_row_schur_sub(stream, module, buffers, x, out)?;
    Ok(())
}
/// Packs every arrow row into zero-padded host buffers for the fused NVRTC kernel.
///
/// Each row's blocks are column-major with leading dimension `p_max`:
/// * `d_buf` (`n × p_max × p_max`): `H_tt^(i) + ridge_t I` in the leading
///   `d × d` corner; row `i` starts at `i·p_max²`.
/// * `b_buf` (`n × p_max × r_template`): `H_tβ^(i)` in the leading `d × k`
///   corner; row `i` starts at `i·p_max·r_template`.
/// * `g_buf` (`n × p_max`): `g_t^(i)` in the leading `d` entries.
///
/// All padding stays zero. Expects `d <= p_max` and `k <= r_template`; this is
/// checked in debug builds.
fn pack_fused_host(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    p_max: usize,
    r_template: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    debug_assert!(d <= p_max, "row block dimension d={d} exceeds p_max={p_max}");
    debug_assert!(k <= r_template, "border width k={k} exceeds r_template={r_template}");
    let mut d_buf = vec![0.0_f64; n * p_max * p_max];
    let mut b_buf = vec![0.0_f64; n * p_max * r_template];
    let mut g_buf = vec![0.0_f64; n * p_max];
    for (i, row) in sys.rows.iter().enumerate() {
        for col in 0..d {
            let base = (i * p_max + col) * p_max;
            for r in 0..d {
                let mut value = row.htt[[r, col]];
                if r == col {
                    value += ridge_t;
                }
                d_buf[base + r] = value;
            }
        }
        // Each row owns r_template columns of height p_max. Using i·p_max²
        // here misplaced every row (or indexed out of bounds) whenever
        // r_template != p_max.
        for col in 0..k {
            let base = (i * r_template + col) * p_max;
            for r in 0..d {
                b_buf[base + r] = row.htbeta[[r, col]];
            }
        }
        let g_base = i * p_max;
        for r in 0..d {
            g_buf[g_base + r] = row.gt[r];
        }
    }
    (d_buf, b_buf, g_buf)
}
/// GPU-resident factorization of a dense arrow-Schur system.
///
/// Built once by [`ResidentArrowFrame::new`]. Afterwards
/// [`ResidentArrowFrame::solve_gradient`] reuses the device factors for any
/// number of gradients without refactoring.
pub(super) struct ResidentArrowFrame {
    /// Number of row blocks.
    n: usize,
    /// Dimension of each row block (`sys.d`).
    d: usize,
    /// Dimension of the shared β block (`sys.k`).
    k: usize,
    /// Stream on which every device buffer below was created.
    stream: Arc<CudaStream>,
    /// cuBLAS handle bound to `stream`.
    blas: CudaBlas,
    /// Batched Cholesky factors of the ridged `H_tt` blocks (`d * d` per row).
    l_dev: CudaSlice<f64>,
    /// `H_tβ` blocks after the batched lower-triangular solve against `l_dev`.
    y_dev: CudaSlice<f64>,
    /// Cholesky factor of the `k x k` Schur complement.
    schur_dev: CudaSlice<f64>,
    /// `2 * (Σ ln diag(L_i) + Σ ln diag(L_S))`, computed once in `new`.
    log_det_hessian: f64,
}
impl ResidentArrowFrame {
    /// Factors `sys` on the GPU and keeps the factors resident on one stream.
    ///
    /// Steps:
    /// 1. Batched Cholesky of every ridged `H_tt` block.
    /// 2. Batched lower solve of the `H_tβ` blocks.
    /// 3. Assembly (via `accumulate_schur`) and Cholesky of the
    ///    `(H_ββ + ridge_beta I)`-based Schur complement.
    /// 4. Log-determinant from the diagonals of both sets of factors.
    ///
    /// # Errors
    /// * `SchurFactorFailed` if either ridge is NaN or the Schur Cholesky fails.
    /// * `RidgeBumpRequired` if a row block's batched Cholesky reports a
    ///   nonzero info. `bump` is scaled by that row's largest `|H_tt|` diagonal
    ///   entry and the reported pivot.
    /// * `Unavailable` for any routing, handle, or transfer failure.
    pub(super) fn new(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<Self, ArrowSchurGpuFailure> {
        if ridge_t.is_nan() || ridge_beta.is_nan() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: "ridge is NaN".to_string(),
            });
        }
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas =
            CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // Gradients are not needed here; `solve_gradient` receives them later.
        let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
        let mut l_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut y_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
            let pivot = info_host[idx];
            let scale = sys.rows[idx]
                .htt
                .diag()
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f64, f64::max)
                .max(1.0);
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
                row: idx,
                bump: scale
                    * (pivot.abs() as f64).max(1.0)
                    * f64::EPSILON.sqrt()
                    * super::RIDGE_BUMP_EPS_MARGIN,
            });
        }
        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
        // Column-major (H_ββ + ridge_beta I), the starting point of the Schur
        // complement before the row contributions are accumulated.
        let schur_init: Vec<f64> = {
            let mut tmp = Vec::with_capacity(k * k);
            for col in 0..k {
                for row in 0..k {
                    let mut v = sys.hbb[[row, col]];
                    if row == col {
                        v += ridge_beta;
                    }
                    tmp.push(v);
                }
            }
            tmp
        };
        let mut schur_dev = stream
            .clone_htod(&schur_init)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // `accumulate_schur` also updates a right-hand side. Feed it zeros and
        // discard that output: only the matrix is needed at construction time.
        let zero_u = stream
            .clone_htod(&vec![0.0_f64; n * d])
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut throwaway_rhs = stream
            .clone_htod(&vec![0.0_f64; k])
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        accumulate_schur(
            &blas,
            d,
            k,
            n,
            &y_dev,
            &zero_u,
            &mut schur_dev,
            &mut throwaway_rhs,
        )?;
        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("Schur Cholesky failed at pivot {info}"),
            });
        }
        let l_local_host = stream
            .clone_dtoh(&l_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // log det H = 2 * (Σ_i Σ_j ln L_i[j, j] + Σ_j ln L_S[j, j]).
        let mut log_det = 0.0_f64;
        for i in 0..n {
            let base = i * d * d;
            for j in 0..d {
                log_det += l_local_host[base + j * d + j].ln();
            }
        }
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;
        Ok(Self {
            n,
            d,
            k,
            stream,
            blas,
            l_dev,
            y_dev,
            schur_dev,
            log_det_hessian: log_det,
        })
    }
    /// Log-determinant of the ridged Hessian, computed once by `new`.
    #[inline]
    pub(super) fn log_det_hessian(&self) -> f64 {
        self.log_det_hessian
    }
    /// Solves the resident factored system for one gradient `(g_t, g_beta)`.
    ///
    /// `g_t` concatenates all row gradients (`n * d` entries) and `g_beta` is
    /// the β gradient (`k` entries). The returned `delta_t` is the negated
    /// result of the final back-substitution. `log_det_hessian` is the value
    /// cached by `new`.
    ///
    /// # Errors
    /// `SchurFactorFailed` on a gradient shape mismatch; `Unavailable` for any
    /// device transfer failure; otherwise whatever the BLAS helpers return.
    pub(super) fn solve_gradient(
        &self,
        g_t: &[f64],
        g_beta: &[f64],
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = self.n;
        let d = self.d;
        let k = self.k;
        if g_t.len() != n * d || g_beta.len() != k {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
                    g_t.len(),
                    n * d,
                    g_beta.len(),
                    k
                ),
            });
        }
        // `clone_htod` accepts a host slice directly; no intermediate Vec needed.
        let mut u_dev = self
            .stream
            .clone_htod(g_t)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
        let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
        let mut rhs_dev = self
            .stream
            .clone_htod(&rhs_init)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
        // Two triangular solves against the Schur factor (forward, then
        // transposed) give the β update in place in `rhs_dev`.
        trsm_single(
            &self.blas,
            &self.stream,
            k,
            &self.schur_dev,
            &mut rhs_dev,
            false,
            false,
        )?;
        trsm_single(
            &self.blas,
            &self.stream,
            k,
            &self.schur_dev,
            &mut rhs_dev,
            false,
            true,
        )?;
        let delta_beta_host = self
            .stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let delta_beta = Array1::from_vec(delta_beta_host);
        accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
        trsm_batched_lower_inplace_transposed(
            &self.blas,
            &self.stream,
            d,
            n,
            1,
            &self.l_dev,
            &mut u_dev,
        )?;
        let x_host = self
            .stream
            .clone_dtoh(&u_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // `u_dev` was created from `g_t`, so `x_host` has exactly `n * d` entries.
        let delta_t = Array1::from_iter(x_host.into_iter().map(|v| -v));
        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: self.log_det_hessian,
        })
    }
}
pub(super) fn solve_fused(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let p_max = plan.p_max;
let r_template = plan.r_template;
let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let stream = ctx
.new_stream()
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let cap = &runtime.device.capability;
let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
cc_major: cap.compute_major,
cc_minor: cap.compute_minor,
p_max: p_max as u32,
r_template: r_template as u32,
};
let module = fused_module_for(&ctx, key)?;
let forward = module
.load_function("arrow_schur_forward_pgroup")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let back_sub = module
.load_function("arrow_schur_back_sub_pgroup")
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
let d_dev = stream
.clone_htod(&d_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let b_dev = stream
.clone_htod(&b_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let g_dev = stream
.clone_htod(&g_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut l_out = stream
.alloc_zeros::<f64>(n * p_max * p_max)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut u_out = stream
.alloc_zeros::<f64>(n * p_max)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut y_out = stream
.alloc_zeros::<f64>(n * p_max * r_template)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut partial_s = stream
.alloc_zeros::<f64>(plan.partial_s_doubles)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut partial_r = stream
.alloc_zeros::<f64>(plan.partial_r_doubles)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut status_dev = stream
.alloc_zeros::<i32>(n)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let cfg = LaunchConfig {
grid_dim: (plan.blocks, 1, 1),
block_dim: (plan.threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
let ridge_arg = ridge_t;
{
let mut builder = stream.launch_builder(&forward);
builder
.arg(&d_dev)
.arg(&b_dev)
.arg(&g_dev)
.arg(&n_i32)
.arg(&p_i32)
.arg(&r_i32)
.arg(&ridge_arg)
.arg(&mut l_out)
.arg(&mut u_out)
.arg(&mut y_out)
.arg(&mut partial_s)
.arg(&mut partial_r)
.arg(&mut status_dev);
unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
stream
.synchronize()
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let status_host = stream
.clone_dtoh(&status_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
if let Some(row) = status_host.iter().position(|s| *s != 0) {
let pivot = status_host[row];
let scale = sys.rows[row]
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
row,
bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
});
}
let partial_s_host = stream
.clone_dtoh(&partial_s)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let partial_r_host = stream
.clone_dtoh(&partial_r)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut schur_host = vec![0.0_f64; k * k];
for col in 0..k {
for row in 0..k {
let mut v = sys.hbb[[row, col]];
if row == col {
v += ridge_beta;
}
schur_host[col * k + row] = v;
}
}
let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
for i in 0..n {
let s_base = i * r_template * r_template;
for col in 0..k {
let col_base = s_base + col * r_template;
let dst_col_base = col * k;
for row in 0..k {
schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
}
}
let r_base = i * r_template;
for a in 0..k {
rhs_host[a] += partial_r_host[r_base + a];
}
}
let mut schur_dev = stream
.clone_htod(&schur_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut rhs_dev = stream
.clone_htod(&rhs_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let solver =
DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
if info != 0 {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("fused Schur Cholesky failed at pivot {info}"),
});
}
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
let delta_beta_host = stream
.clone_dtoh(&rhs_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let delta_beta = Array1::from_vec(delta_beta_host.clone());
let mut delta_t_dev = stream
.alloc_zeros::<f64>(n * p_max)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let back_cfg = LaunchConfig {
grid_dim: (plan.blocks, 1, 1),
block_dim: (plan.threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
{
let mut builder = stream.launch_builder(&back_sub);
builder
.arg(&l_out)
.arg(&u_out)
.arg(&y_out)
.arg(&rhs_dev)
.arg(&n_i32)
.arg(&p_i32)
.arg(&r_i32)
.arg(&mut delta_t_dev);
unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
}
stream
.synchronize()
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let delta_t_host = stream
.clone_dtoh(&delta_t_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut delta_t = Array1::<f64>::zeros(n * d);
for i in 0..n {
let src_base = i * p_max;
let dst_base = i * d;
for r in 0..d {
delta_t[dst_base + r] = delta_t_host[src_base + r];
}
}
let l_local_host = stream
.clone_dtoh(&l_out)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let l_schur_host = stream
.clone_dtoh(&schur_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut log_det = 0.0_f64;
for i in 0..n {
let base = i * p_max * p_max;
for j in 0..d {
log_det += l_local_host[base + j * p_max + j].ln();
}
}
for j in 0..k {
log_det += l_schur_host[j * k + j].ln();
}
log_det *= 2.0;
Ok(ArrowSchurGpuSolution {
delta_t,
delta_beta,
log_det_hessian: log_det,
})
}
pub(super) fn build_schur_matvec_backend(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let p_max = plan.p_max;
let r_template = plan.r_template;
let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
.ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let stream = ctx
.new_stream()
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let cap = &runtime.device.capability;
let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
cc_major: cap.compute_major,
cc_minor: cap.compute_minor,
p_max: p_max as u32,
r_template: r_template as u32,
};
let module = fused_module_for(&ctx, key)?;
let forward = module
.load_function("arrow_schur_forward_pgroup")
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
let d_dev = stream
.clone_htod(&d_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let b_dev = stream
.clone_htod(&b_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let g_dev = stream
.clone_htod(&g_host)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut l_out = stream
.alloc_zeros::<f64>(n * p_max * p_max)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut u_out = stream
.alloc_zeros::<f64>(n * p_max)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut y_out = stream
.alloc_zeros::<f64>(n * p_max * r_template)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut partial_s = stream
.alloc_zeros::<f64>(plan.partial_s_doubles)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut partial_r = stream
.alloc_zeros::<f64>(plan.partial_r_doubles)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let mut status_dev = stream
.alloc_zeros::<i32>(n)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let cfg = LaunchConfig {
grid_dim: (plan.blocks, 1, 1),
block_dim: (plan.threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
let ridge_arg = ridge_t;
{
let mut builder = stream.launch_builder(&forward);
builder
.arg(&d_dev)
.arg(&b_dev)
.arg(&g_dev)
.arg(&n_i32)
.arg(&p_i32)
.arg(&r_i32)
.arg(&ridge_arg)
.arg(&mut l_out)
.arg(&mut u_out)
.arg(&mut y_out)
.arg(&mut partial_s)
.arg(&mut partial_r)
.arg(&mut status_dev);
unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
}
stream
.synchronize()
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let status_host = stream
.clone_dtoh(&status_dev)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
if let Some(row) = status_host.iter().position(|s| *s != 0) {
let pivot = status_host[row];
let scale = sys.rows[row]
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
row,
bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
});
}
let y_host = stream
.clone_dtoh(&y_out)
.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
let hbb_is_kk = sys.hbb.dim() == (k, k);
let hbb_matvec_opt = sys.hbb_matvec.clone();
let closure: crate::arrow_schur::GpuSchurMatvec =
Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
if let Some(ref mv) = hbb_matvec_opt {
mv(x.view(), out);
for a in 0..k {
out[a] += ridge_beta * x[a];
}
} else if hbb_is_kk {
for a in 0..k {
let mut acc = ridge_beta * x[a];
for b in 0..k {
acc += hbb_host[a * k + b] * x[b];
}
out[a] = acc;
}
} else {
for a in 0..k {
out[a] = ridge_beta * x[a];
}
}
let mut z = vec![0.0_f64; d];
for i in 0..n {
let y_base = i * p_max * r_template;
for r in 0..d {
let mut acc = 0.0;
for c in 0..k {
acc += y_host[y_base + c * p_max + r] * x[c];
}
z[r] = acc;
}
for c in 0..k {
let mut acc = 0.0;
for r in 0..d {
acc += y_host[y_base + c * p_max + r] * z[r];
}
out[c] -= acc;
}
}
});
Ok(closure)
}
/// Device-side flattening of the framed SAE PCG operator, built by
/// `flatten_device_sae_frame_data`.
///
/// Variable-size per-block data is stored as concatenated `*_data` arrays,
/// with `*_ptr` prefix offsets that start at 0.
struct DeviceSaeFrameBuffers {
    /// Global β offset of each smooth block.
    s_off: CudaSlice<i32>,
    /// Size `m` of each smooth block's square `factor_a`.
    s_m: CudaSlice<i32>,
    /// Rank `r` of each smooth block (from `frame.smooth_ranks`).
    s_r: CudaSlice<i32>,
    /// Prefix offsets into `s_data`, one more entry than there are smooth blocks.
    s_ptr: CudaSlice<i32>,
    /// Row-major `factor_a` entries of all smooth blocks, concatenated.
    s_data: CudaSlice<f64>,
    /// Number of smooth blocks.
    s_blocks: usize,
    /// Border offset of each frame block's `atom_i`.
    g_off_i: CudaSlice<i32>,
    /// Border offset of each frame block's `atom_j`.
    g_off_j: CudaSlice<i32>,
    /// Rank of each frame block's `atom_i`.
    g_ri: CudaSlice<i32>,
    /// Rank of each frame block's `atom_j`.
    g_rj: CudaSlice<i32>,
    /// Row count of each frame block's `G`.
    g_mi: CudaSlice<i32>,
    /// Column count of each frame block's `G`.
    g_mj: CudaSlice<i32>,
    /// Prefix offsets into `g_data`.
    g_ptr: CudaSlice<i32>,
    /// Row-major `G` entries of all frame blocks, concatenated.
    g_data: CudaSlice<f64>,
    /// Prefix offsets into `w_data`.
    w_ptr: CudaSlice<i32>,
    /// Row-major `W` (`ri x rj`) entries of all frame blocks, concatenated.
    w_data: CudaSlice<f64>,
    /// Number of frame blocks.
    g_blocks: usize,
    /// Maximum `mi * ri` over all frame blocks; used as the x work size of
    /// the `arrow_sae_frame_g_matvec` launch.
    g_max_work: usize,
    /// Prefix offsets into `htb`, one entry per row plus the leading 0.
    htb_ptr: CudaSlice<i32>,
    /// `H_tβ` slabs of rows with a nonzero effective `q`, concatenated.
    htb: CudaSlice<f64>,
    /// Effective row dimension `q`; 0 for rows whose slab is absent or whose
    /// length is not `q * k`.
    q_of: CudaSlice<i32>,
    /// Per-row `(H_tt + ridge_t I)^{-1}`, row-major, zero-padded to
    /// `max_q x max_q` per row.
    ainv: CudaSlice<f64>,
    /// Scratch of `n_rows * max_q` doubles, written by `arrow_sae_frame_apply_h`.
    hvec: CudaSlice<f64>,
    /// Scratch of `n_rows * max_q` doubles, written by `arrow_sae_frame_apply_ainv`.
    svec: CudaSlice<f64>,
    /// Number of rows.
    n_rows: usize,
    /// β dimension.
    k: usize,
    /// Largest effective row dimension, at least 1.
    max_q: usize,
}
fn flatten_device_sae_frame_data(
sys: &ArrowSchurSystem,
data: &DeviceSaePcgData,
frame: &DeviceSaeFrameData,
ridge_t: f64,
stream: &Arc<CudaStream>,
) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
let n_rows = sys.rows.len();
let k = data.beta_dim;
if frame.row_htbeta.len() != n_rows
|| frame.ranks.len() != frame.basis_sizes.len()
|| frame.border_offsets.len() != frame.ranks.len()
|| data.smooth_blocks.len() != frame.smooth_ranks.len()
{
return Err(ArrowSchurGpuFailure::Unavailable);
}
let mut s_off = Vec::new();
let mut s_m = Vec::new();
let mut s_r = Vec::new();
let mut s_ptr = vec![0_i32];
let mut s_data = Vec::<f64>::new();
for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
let (m, mc) = blk.factor_a.dim();
if m != mc {
return Err(ArrowSchurGpuFailure::Unavailable);
}
s_off.push(checked_i32(blk.global_offset)?);
s_m.push(checked_i32(m)?);
s_r.push(checked_i32(r)?);
for ri in 0..m {
for ci in 0..m {
s_data.push(blk.factor_a[[ri, ci]]);
}
}
s_ptr.push(checked_i32(s_data.len())?);
}
let mut g_off_i = Vec::new();
let mut g_off_j = Vec::new();
let mut g_ri = Vec::new();
let mut g_rj = Vec::new();
let mut g_mi = Vec::new();
let mut g_mj = Vec::new();
let mut g_ptr = vec![0_i32];
let mut g_data = Vec::<f64>::new();
let mut w_ptr = vec![0_i32];
let mut w_data = Vec::<f64>::new();
let mut g_max_work = 0usize;
for blk in &frame.frame_blocks {
let ri = frame.ranks[blk.atom_i];
let rj = frame.ranks[blk.atom_j];
let (mi, mj) = blk.g.dim();
if blk.w.dim() != (ri, rj) {
return Err(ArrowSchurGpuFailure::Unavailable);
}
g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
g_ri.push(checked_i32(ri)?);
g_rj.push(checked_i32(rj)?);
g_mi.push(checked_i32(mi)?);
g_mj.push(checked_i32(mj)?);
for r in 0..mi {
for c in 0..mj {
g_data.push(blk.g[[r, c]]);
}
}
g_ptr.push(checked_i32(g_data.len())?);
for a in 0..ri {
for b in 0..rj {
w_data.push(blk.w[[a, b]]);
}
}
w_ptr.push(checked_i32(w_data.len())?);
g_max_work = g_max_work.max(mi * ri);
}
let mut htb_ptr = vec![0_i32];
let mut htb = Vec::<f64>::new();
let mut q_of = Vec::<i32>::with_capacity(n_rows);
let mut max_q = 0usize;
for (i, slab) in frame.row_htbeta.iter().enumerate() {
let qi = sys.row_dims[i];
let q_eff = if !slab.is_empty() && slab.len() == qi * k {
qi
} else {
0
};
q_of.push(checked_i32(q_eff)?);
max_q = max_q.max(q_eff);
if q_eff > 0 {
htb.extend_from_slice(slab);
}
htb_ptr.push(checked_i32(htb.len())?);
}
if max_q == 0 {
max_q = 1;
}
let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
for (i, row) in sys.rows.iter().enumerate() {
let q = q_of[i] as usize;
if q == 0 {
continue;
}
if row.htt.dim() != (q, q) {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!(
"framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
row.htt.dim()
),
});
}
let mut block = row.htt.clone();
for d in 0..q {
block[[d, d]] += ridge_t;
}
let factor = gam_linalg::triangular::cholesky_factor_in_place(
block.view(),
gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
)
.ok_or_else(|| {
let scale = row
.htt
.diag()
.iter()
.map(|v| v.abs())
.fold(0.0_f64, f64::max)
.max(1.0);
ArrowSchurGpuFailure::RidgeBumpRequired {
row: i,
bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
}
})?;
for col in 0..q {
let mut e = Array1::<f64>::zeros(q);
e[col] = 1.0;
let solved =
gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
for r in 0..q {
ainv[i * max_q * max_q + r * max_q + col] = solved[r];
}
}
}
let htod_i = |v: &[i32]| {
stream
.clone_htod(v)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)
};
let htod_f = |v: &[f64]| {
stream
.clone_htod(v)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)
};
Ok(DeviceSaeFrameBuffers {
s_off: htod_i(&s_off)?,
s_m: htod_i(&s_m)?,
s_r: htod_i(&s_r)?,
s_ptr: htod_i(&s_ptr)?,
s_data: htod_f(&s_data)?,
s_blocks: data.smooth_blocks.len(),
g_off_i: htod_i(&g_off_i)?,
g_off_j: htod_i(&g_off_j)?,
g_ri: htod_i(&g_ri)?,
g_rj: htod_i(&g_rj)?,
g_mi: htod_i(&g_mi)?,
g_mj: htod_i(&g_mj)?,
g_ptr: htod_i(&g_ptr)?,
g_data: htod_f(&g_data)?,
w_ptr: htod_i(&w_ptr)?,
w_data: htod_f(&w_data)?,
g_blocks: frame.frame_blocks.len(),
g_max_work,
htb_ptr: htod_i(&htb_ptr)?,
htb: htod_f(&htb)?,
q_of: htod_i(&q_of)?,
ainv: htod_f(&ainv)?,
hvec: stream
.alloc_zeros::<f64>(n_rows * max_q)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
svec: stream
.alloc_zeros::<f64>(n_rows * max_q)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
n_rows,
k,
max_q,
})
}
/// Host-side diagonal of the framed SAE β penalty, used to seed the Jacobi
/// preconditioner before the per-row correction is applied on the device.
///
/// Every β coordinate starts at `ridge_beta`. Then:
/// * each smooth block adds `factor_a[ia, ia]` to the `r` coordinates
///   starting at `global_offset + ia * r`;
/// * each diagonal frame block (`atom_i == atom_j`) adds `g[li, li] * w[a, a]`
///   at `border_offset + li * r + a`, for `li < min(mi, mj)` and `a < r`.
///
/// # Errors
/// `Unavailable` if any of those coordinates falls outside `0..beta_dim`.
fn sae_frame_penalty_diag_host(
    data: &DeviceSaePcgData,
    frame: &DeviceSaeFrameData,
    ridge_beta: f64,
) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
    let mut out = vec![ridge_beta; data.beta_dim];
    for (blk, &rank) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
        for ia in 0..blk.factor_a.nrows() {
            let coeff = blk.factor_a[[ia, ia]];
            let start = blk.global_offset + ia * rank;
            for ib in 0..rank {
                // Bounds-check before touching the slot, exactly as a manual
                // `idx >= len` test would.
                let cell = out
                    .get_mut(start + ib)
                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
                *cell += coeff;
            }
        }
    }
    let diagonal_frames = frame
        .frame_blocks
        .iter()
        .filter(|blk| blk.atom_i == blk.atom_j);
    for blk in diagonal_frames {
        let rank = frame.ranks[blk.atom_i];
        let offset = frame.border_offsets[blk.atom_i];
        let (rows_g, cols_g) = blk.g.dim();
        for li in 0..rows_g.min(cols_g) {
            let g_diag = blk.g[[li, li]];
            let start = offset + li * rank;
            for a in 0..rank {
                let cell = out
                    .get_mut(start + a)
                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
                *cell += g_diag * blk.w[[a, a]];
            }
        }
    }
    Ok(out)
}
fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
Ok(LaunchConfig {
grid_dim: (
((work as u32).saturating_add(255) / 256).max(1),
checked_i32(n_rows)? as u32,
1,
),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
})
}
/// Applies the framed SAE reduced Schur operator to the device vector `x`,
/// writing the result into `out`. All work is queued on `stream`.
///
/// Launch sequence:
/// 1. `launch_sae_init` with `ridge_beta`. Presumably this initialises `out`
///    to `ridge_beta * x`; confirm in its definition.
/// 2. `arrow_sae_frame_smooth_matvec`, only if there are smooth blocks.
/// 3. `arrow_sae_frame_g_matvec`, only if there are frame blocks.
/// 4. `arrow_sae_frame_apply_h`: reads `x` and the `H_tβ` slabs, writes `hvec`.
/// 5. `arrow_sae_frame_apply_ainv`: reads `ainv` and `hvec`, writes `svec`.
/// 6. `arrow_sae_frame_scatter_h`: reads `svec` and the slabs, writes `out`.
///
/// # Errors
/// `Unavailable` if a kernel cannot be loaded or launched; otherwise whatever
/// `launch_sae_init`, `checked_i32`, or `frame_grid` return.
fn launch_sae_frame_matvec(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    buffers: &mut DeviceSaeFrameBuffers,
    x: &CudaSlice<f64>,
    out: &mut CudaSlice<f64>,
    ridge_beta: f64,
) -> Result<(), ArrowSchurGpuFailure> {
    launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
    if buffers.s_blocks > 0 {
        let kernel = module
            .load_function("arrow_sae_frame_smooth_matvec")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blocks_i32 = checked_i32(buffers.s_blocks)?;
        // x extent: k entries; y extent: one slice per smooth block.
        let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(x)
            .arg(&mut *out)
            .arg(&buffers.s_off)
            .arg(&buffers.s_m)
            .arg(&buffers.s_r)
            .arg(&buffers.s_ptr)
            .arg(&buffers.s_data)
            .arg(&blocks_i32);
        // SAFETY: argument order must match the `arrow_sae_frame_smooth_matvec`
        // kernel signature (not visible here); buffers outlive the launch.
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    if buffers.g_blocks > 0 {
        let kernel = module
            .load_function("arrow_sae_frame_g_matvec")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blocks_i32 = checked_i32(buffers.g_blocks)?;
        // x extent: largest `mi * ri` over frame blocks; y extent: one slice
        // per frame block.
        let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(x)
            .arg(&mut *out)
            .arg(&buffers.g_off_i)
            .arg(&buffers.g_off_j)
            .arg(&buffers.g_ri)
            .arg(&buffers.g_rj)
            .arg(&buffers.g_mi)
            .arg(&buffers.g_mj)
            .arg(&buffers.g_ptr)
            .arg(&buffers.g_data)
            .arg(&buffers.w_ptr)
            .arg(&buffers.w_data)
            .arg(&blocks_i32);
        // SAFETY: argument order must match the `arrow_sae_frame_g_matvec`
        // kernel signature (not visible here); buffers outlive the launch.
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    let k_i32 = checked_i32(buffers.k)?;
    let max_q_i32 = checked_i32(buffers.max_q)?;
    let n_rows_i32 = checked_i32(buffers.n_rows)?;
    // Row stage, step 1: fill `hvec` from `x` and the per-row H_tβ slabs.
    {
        let kernel = module
            .load_function("arrow_sae_frame_apply_h")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(x)
            .arg(&buffers.htb_ptr)
            .arg(&buffers.htb)
            .arg(&buffers.q_of)
            .arg(&mut buffers.hvec)
            .arg(&k_i32)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        // SAFETY: argument order must match the `arrow_sae_frame_apply_h`
        // kernel signature (not visible here); buffers outlive the launch.
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    // Row stage, step 2: `svec` from the per-row inverse blocks and `hvec`.
    {
        let kernel = module
            .load_function("arrow_sae_frame_apply_ainv")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(&buffers.ainv)
            .arg(&buffers.hvec)
            .arg(&buffers.q_of)
            .arg(&mut buffers.svec)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        // SAFETY: argument order must match the `arrow_sae_frame_apply_ainv`
        // kernel signature (not visible here); buffers outlive the launch.
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    // Row stage, step 3: scatter `svec` back through the H_tβ slabs into `out`.
    {
        let kernel = module
            .load_function("arrow_sae_frame_scatter_h")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(&buffers.svec)
            .arg(&buffers.htb_ptr)
            .arg(&buffers.htb)
            .arg(&buffers.q_of)
            .arg(out)
            .arg(&k_i32)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        // SAFETY: argument order must match the `arrow_sae_frame_scatter_h`
        // kernel signature (not visible here); buffers outlive the launch.
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
/// Launches `arrow_sae_frame_diag_sub` on `stream`, updating the framed β
/// Jacobi diagonal `diag` in place. The kernel receives the per-row inverse
/// blocks (`ainv`), the `H_tβ` slabs, and the row dimensions held in `buffers`.
///
/// # Errors
/// `Unavailable` if the kernel cannot be loaded or launched; otherwise
/// whatever `checked_i32` / `frame_grid` return for oversized dimensions.
fn launch_sae_frame_diag_sub(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    buffers: &DeviceSaeFrameBuffers,
    diag: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let diag_sub_fn = match module.load_function("arrow_sae_frame_diag_sub") {
        Ok(function) => function,
        Err(_) => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    // Tuple elements are evaluated left to right, so any error surfaces in
    // k, max_q, n_rows order.
    let (beta_len, q_stride, row_count) = (
        checked_i32(buffers.k)?,
        checked_i32(buffers.max_q)?,
        checked_i32(buffers.n_rows)?,
    );
    let launch_cfg = frame_grid(buffers.k, buffers.n_rows)?;
    let mut launcher = stream.launch_builder(&diag_sub_fn);
    launcher.arg(diag);
    launcher.arg(&buffers.ainv);
    launcher.arg(&buffers.htb_ptr);
    launcher.arg(&buffers.htb);
    launcher.arg(&buffers.q_of);
    launcher.arg(&beta_len);
    launcher.arg(&q_stride);
    launcher.arg(&row_count);
    // SAFETY: argument order must match the `arrow_sae_frame_diag_sub` kernel
    // signature (not visible here); every buffer outlives the launch.
    let launched = unsafe { launcher.launch(launch_cfg) };
    launched.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok(())
}
/// Solves the framed SAE reduced Schur system `S x = rhs_beta` on the GPU with
/// Jacobi-preconditioned conjugate gradients.
///
/// `S` is applied matrix-free through `launch_sae_frame_matvec`. The Jacobi
/// diagonal comes from `sae_frame_penalty_diag_host`, corrected on the device
/// by `launch_sae_frame_diag_sub`.
///
/// Iteration stops when the unpreconditioned residual norm is at most
/// `max(relative_tolerance * ||rhs||, 1e-12)`, or after `max(max_iterations, 1)`
/// iterations. A zero `rhs_beta` returns zeros with default diagnostics; note
/// that this check happens only after device setup has succeeded.
///
/// # Errors
/// * `Unavailable` if dimensions disagree, `data.frame` is absent, the runtime
///   policy declines to offload, `rhs_beta` is not contiguous, or any device
///   operation fails.
/// * `SchurFactorFailed` on a non-positive or non-finite Jacobi diagonal entry,
///   `rᵀM⁻¹r`, or `pᵀAp` (loss of positive definiteness).
pub(super) fn solve_sae_matrix_free_pcg_framed(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    let frame = data
        .frame
        .as_ref()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute stopping threshold on ||r||_2, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi preconditioner: host penalty diagonal, corrected on the device,
    // then validated and inverted on the host.
    let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // PCG state with x0 = 0: r0 = rhs, z0 = M⁻¹ r0, p0 = z0.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };
    // Run at least one iteration even when `max_iterations` is 0.
    for _ in 0..max_iterations.max(1) {
        launch_sae_frame_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        // Fletcher–Reeves-style ratio; `launch_update_p` combines `z_dev` and
        // `p_dev` with `beta` (presumably p = z + beta * p — confirm in kernel).
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
/// Solve the reduced Schur system `S · x = rhs_beta` for the unframed SAE
/// layout with a Jacobi-preconditioned conjugate-gradient (PCG) loop whose
/// vectors stay resident on the CUDA device.
///
/// `S` is never materialised. Each iteration applies it matrix-free through
/// `launch_sae_matvec` on the buffers built by `flatten_device_sae_data`.
/// The Jacobi preconditioner starts from the host diagonal returned by
/// `sae_penalty_diag_host`. `launch_sae_diag_sub` then updates it on the
/// device; presumably this subtracts the per-row Schur correction (confirm).
///
/// # Returns
/// The solution (length `rhs_beta.len()`) and the PCG diagnostics. An
/// all-zero right-hand side returns the zero vector with
/// `PcgDiagnostics::default()` before any stream is created or any per-row
/// data is uploaded.
///
/// # Errors
/// * `Unavailable` in any of these cases:
///   - `rhs_beta` is empty, or its length disagrees with `data.beta_dim` /
///     `sys.k`;
///   - `data.frame` is `Some`;
///   - the runtime policy declines the offload;
///   - `rhs_beta` is not contiguous;
///   - a CUDA or cuBLAS call fails.
/// * `SchurFactorFailed` — a Jacobi diagonal entry is non-finite or
///   `<= 1e-18`, or the recurrence meets a non-positive or non-finite
///   `rᵀM⁻¹r` or `pᵀAp`.
/// * Errors from the flatten / diagonal / launch helpers are propagated
///   unchanged.
pub(super) fn solve_sae_matrix_free_pcg(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Framed SAE data is not handled by this entry point; it only builds the
    // unframed buffers via `flatten_device_sae_data`.
    if data.frame.is_some() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    // x = 0 solves S·x = 0 exactly. Decide this before creating a stream and
    // flattening/uploading the per-row SAE data, so the trivial case costs no
    // device work.
    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
    // Absolute stopping threshold on ‖r‖₂, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    // With x₀ = 0 the initial residual is r₀ = rhs, so upload the RHS straight
    // into the residual buffer instead of staging it through a separate copy.
    let mut r_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // Jacobi preconditioner: host penalty diagonal, adjusted on the device,
    // then validated and inverted on the host.
    let penalty_diag = sae_penalty_diag_host(data, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&penalty_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let schur_diag = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in schur_diag.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // PCG state: x₀ = 0, z₀ = M⁻¹ r₀, p₀ = z₀.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };
    for _ in 0..max_iterations.max(1) {
        launch_sae_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        // CG requires positive curvature along every search direction.
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
s_acc: &ndarray::Array2<f64>,
rhs_beta: &Array1<f64>,
max_iterations: usize,
relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
let k = rhs_beta.len();
let cg_iters = max_iterations.max(1);
let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
gam_gpu::linalg_dispatch::DispatchOp::Gemm {
m: k,
n: k,
k: cg_iters,
},
)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
.and_then(|ctx| ctx.new_stream().ok())
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let vector_module = pcg_vector_module(&ctx)?;
let mut inv_diag = vec![0.0_f64; k];
for j in 0..k {
let djj = s_acc[[j, j]];
if !(djj > 0.0) {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!(
"reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
),
});
}
inv_diag[j] = 1.0 / djj;
}
let mut s_host = vec![0.0_f64; k * k];
for col in 0..k {
for row in 0..k {
s_host[col * k + row] = s_acc[[row, col]];
}
}
let s_dev = stream
.clone_htod(&s_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
if rhs_norm == 0.0 {
return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
}
let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
let mut x_dev = stream
.alloc_zeros::<f64>(k)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut r_dev = stream
.clone_htod(
rhs_beta
.as_slice()
.ok_or(ArrowSchurGpuFailure::Unavailable)?,
)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let inv_diag_dev = stream
.clone_htod(&inv_diag)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut z_dev = stream
.alloc_zeros::<f64>(k)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
let mut p_dev = stream
.alloc_zeros::<f64>(k)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
let mut sp_dev = stream
.alloc_zeros::<f64>(k)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
let mut diag = PcgDiagnostics {
precond_apply_calls: 1,
stopping_reason: PcgStopReason::MaxIter,
..PcgDiagnostics::default()
};
if rz <= 0.0 || !rz.is_finite() {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
});
}
let max_iters = max_iterations.max(1);
for _ in 0..max_iters {
let gemv_cfg = GemvConfig::<f64> {
trans: cublasOperation_t::CUBLAS_OP_N,
m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
alpha: 1.0,
lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
incx: 1,
beta: 0.0,
incy: 1,
};
unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
diag.matvec_calls += 1;
diag.iterations += 1;
let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
if !(p_sp > 0.0) {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
});
}
let alpha = rz / p_sp;
device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
if r_norm <= tol {
diag.final_relative_residual = r_norm / rhs_norm;
diag.stopping_reason = PcgStopReason::Converged;
break;
}
launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
diag.precond_apply_calls += 1;
let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
if rz_new <= 0.0 || !rz_new.is_finite() {
return Err(ArrowSchurGpuFailure::SchurFactorFailed {
reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
});
}
let beta = rz_new / rz;
launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
rz = rz_new;
}
if diag.stopping_reason != PcgStopReason::Converged {
let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
diag.final_relative_residual = r_norm / rhs_norm;
diag.stopping_reason = PcgStopReason::MaxIter;
}
let x = stream
.clone_dtoh(&x_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
Ok((Array1::from_vec(x), diag))
}
/// Device-to-device copy of `n` contiguous `f64`s from `src` into `dst`
/// through `cublasDcopy_v2` (unit strides on both sides).
///
/// # Errors
/// `Unavailable` when `n` does not fit in an `i32`, or when cuBLAS returns any
/// status other than `CUBLAS_STATUS_SUCCESS`.
fn device_copy(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    src: &CudaSlice<f64>,
    dst: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let count = match to_i32(n) {
        Some(count) => count,
        None => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    // The second element returned by `device_ptr*` is bound to a `_guard` name
    // (not `_`), so it stays alive until the end of the function, past the
    // cuBLAS call.
    let (src_addr, _src_guard) = src.device_ptr(stream);
    let (dst_addr, _dst_guard) = dst.device_ptr_mut(stream);
    let handle = *blas.handle();
    // SAFETY: both addresses come from live device allocations borrowed for
    // this call. Every caller in this module passes length-`k` buffers with
    // `n == k`, so `count` elements at unit stride stay in bounds.
    let status = unsafe {
        cudarc::cublas::sys::cublasDcopy_v2(
            handle,
            count,
            src_addr as *const f64,
            1,
            dst_addr as *mut f64,
            1,
        )
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    Ok(())
}
/// Device AXPY `y ← alpha·x + y` over the first `n` entries, via
/// `cublasDaxpy_v2` with unit strides. `alpha` is passed by host address.
///
/// # Errors
/// `Unavailable` when `n` does not fit in an `i32`, or when cuBLAS returns any
/// status other than `CUBLAS_STATUS_SUCCESS`.
fn device_axpy(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    alpha: f64,
    x: &CudaSlice<f64>,
    y: &mut CudaSlice<f64>,
) -> Result<(), ArrowSchurGpuFailure> {
    let count = match to_i32(n) {
        Some(count) => count,
        None => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    // Keep the guards returned by `device_ptr*` bound until the call is issued.
    let (x_addr, _x_guard) = x.device_ptr(stream);
    let (y_addr, _y_guard) = y.device_ptr_mut(stream);
    let handle = *blas.handle();
    // SAFETY: `x_addr`/`y_addr` are live device allocations borrowed for this
    // call, and callers pass buffers of at least `n` elements. `&alpha` is a
    // host scalar that outlives the call (assumes host pointer mode, cuBLAS's
    // default — NOTE(review): confirm `CudaBlas` does not change it).
    let status = unsafe {
        cudarc::cublas::sys::cublasDaxpy_v2(
            handle,
            count,
            &alpha,
            x_addr as *const f64,
            1,
            y_addr as *mut f64,
            1,
        )
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    Ok(())
}
/// Inner product `xᵀy` over the first `n` entries of two device vectors, via
/// `cublasDdot_v2` with unit strides. The scalar result is written to a host
/// local and returned.
///
/// # Errors
/// `Unavailable` when `n` does not fit in an `i32`, or when cuBLAS returns any
/// status other than `CUBLAS_STATUS_SUCCESS`.
fn device_dot(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    x: &CudaSlice<f64>,
    y: &CudaSlice<f64>,
) -> Result<f64, ArrowSchurGpuFailure> {
    let count = match to_i32(n) {
        Some(count) => count,
        None => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    // Keep the guards returned by `device_ptr` bound until the call is issued.
    let (x_addr, _x_guard) = x.device_ptr(stream);
    let (y_addr, _y_guard) = y.device_ptr(stream);
    let handle = *blas.handle();
    let mut dot = 0.0_f64;
    // SAFETY: both inputs are live device allocations of at least `n`
    // elements. `&mut dot` is a host scalar that outlives the call (assumes
    // host pointer mode, cuBLAS's default — NOTE(review): confirm).
    let status = unsafe {
        cudarc::cublas::sys::cublasDdot_v2(
            handle,
            count,
            x_addr as *const f64,
            1,
            y_addr as *const f64,
            1,
            &mut dot,
        )
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    Ok(dot)
}
/// Euclidean norm `‖x‖₂` of the first `n` entries of a device vector, via
/// `cublasDnrm2_v2` with unit stride. The scalar result is written to a host
/// local and returned.
///
/// # Errors
/// `Unavailable` when `n` does not fit in an `i32`, or when cuBLAS returns any
/// status other than `CUBLAS_STATUS_SUCCESS`.
fn device_nrm2(
    blas: &CudaBlas,
    stream: &Arc<CudaStream>,
    n: usize,
    x: &CudaSlice<f64>,
) -> Result<f64, ArrowSchurGpuFailure> {
    let count = match to_i32(n) {
        Some(count) => count,
        None => return Err(ArrowSchurGpuFailure::Unavailable),
    };
    // Keep the guard returned by `device_ptr` bound until the call is issued.
    let (x_addr, _x_guard) = x.device_ptr(stream);
    let handle = *blas.handle();
    let mut norm = 0.0_f64;
    // SAFETY: `x_addr` is a live device allocation of at least `n` elements.
    // `&mut norm` is a host scalar that outlives the call (assumes host
    // pointer mode, cuBLAS's default — NOTE(review): confirm).
    let status = unsafe {
        cudarc::cublas::sys::cublasDnrm2_v2(handle, count, x_addr as *const f64, 1, &mut norm)
    };
    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    Ok(norm)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow_schur::{
ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
FactoredFrameGBlock,
};
use ndarray::Array2;
fn device_matvec_once(
sys: &ArrowSchurSystem,
data: &DeviceSaePcgData,
ridge_t: f64,
ridge_beta: f64,
x_host: &[f64],
) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
let k = x_host.len();
let frame = data
.frame
.as_ref()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let runtime = gam_gpu::device_runtime::GpuRuntime::global()
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let ctx =
gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
.ok_or(ArrowSchurGpuFailure::Unavailable)?;
let stream = ctx
.new_stream()
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let vector_module = pcg_vector_module(&ctx)?;
let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
let x_dev = stream
.clone_htod(x_host)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
let mut out_dev = stream
.alloc_zeros::<f64>(k)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
launch_sae_frame_matvec(
&stream,
vector_module,
&mut buffers,
&x_dev,
&mut out_dev,
ridge_beta,
)?;
stream
.clone_dtoh(&out_dev)
.map_err(|_| ArrowSchurGpuFailure::Unavailable)
}
#[test]
fn framed_sae_device_matvec_stage_diff_tiny_1551() {
if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
return;
}
let p = 3usize;
let ranks = vec![2usize, 3usize];
let basis_sizes = vec![2usize, 2usize];
let mut border_offsets = Vec::new();
let mut acc = 0usize;
for k in 0..2 {
border_offsets.push(acc);
acc += basis_sizes[k] * ranks[k];
}
let border_dim = acc; let frame_of = |k: usize| -> Array2<f64> {
Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
})
};
let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
let w_of = |i: usize, j: usize| -> Array2<f64> {
let (ui, uj) = (&frames[i], &frames[j]);
Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
(0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
})
};
let mut frame_blocks = Vec::new();
for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
let mut g =
Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
if i == j {
for r in 0..mi.min(mj) {
g[[r, r]] += mi as f64 + 2.0;
}
}
frame_blocks.push(FactoredFrameGBlock {
atom_i: i,
atom_j: j,
g,
w: w_of(i, j),
});
}
let mut smooth_blocks = Vec::new();
for k in 0..2 {
let m = basis_sizes[k];
let mut s =
Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
for r in 0..m {
s[[r, r]] += 1.0;
}
smooth_blocks.push(DeviceSaeSmoothBlock {
global_offset: border_offsets[k],
factor_a: s,
});
}
let smooth_ranks = ranks.clone();
let n = 2usize;
let q = 2usize;
let mut sys = ArrowSchurSystem::new(n, q, border_dim);
let mut row_htbeta = Vec::new();
for i in 0..n {
let mut htt =
Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
for r in 0..q {
htt[[r, r]] += q as f64 + 2.0;
}
sys.rows[i].htt = htt;
let mut slab = vec![0.0_f64; q * border_dim];
for c in 0..q {
for col in 0..border_dim {
let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
slab[c * border_dim + col] = v;
sys.rows[i].htbeta[[c, col]] = v;
}
}
row_htbeta.push(slab);
}
let data = DeviceSaePcgData {
p,
beta_dim: border_dim,
a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
smooth_blocks,
sparse_g_blocks: Vec::new(),
frame: Some(DeviceSaeFrameData {
ranks,
basis_sizes,
border_offsets,
frame_blocks,
smooth_ranks,
row_htbeta,
}),
};
let ridge_t = 1e-7;
let ridge_beta = 1e-6;
let mut first_bad: Option<usize> = None;
let mut worst = 0.0_f64;
let mut worst_at = 0usize;
let mut worst_dev = 0.0_f64;
let mut worst_cpu = 0.0_f64;
for col in 0..border_dim {
let mut x = vec![0.0_f64; border_dim];
x[col] = 1.0;
let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
Ok(v) => v,
Err(_) => return,
};
let mut cpu = vec![0.0_f64; border_dim];
super::super::sae_framed_schur_matvec_cpu(
&sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
)
.expect("cpu matvec");
for r in 0..border_dim {
let d = (dev[r] - cpu[r]).abs();
if d > 1e-9 && first_bad.is_none() {
first_bad = Some(r * border_dim + col);
}
if d > worst {
worst = d;
worst_at = r * border_dim + col;
worst_dev = dev[r];
worst_cpu = cpu[r];
}
}
}
assert!(
worst <= 1e-9,
"[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
(row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
G⊗W=cross, reduced-Schur=dense per-row)",
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow_schur::ArrowSchurSystem;
use ndarray::{Array2, ArrayView1};
fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
let mut sys = ArrowSchurSystem::new(n, d, k);
let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
let mut sample = || -> f64 {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
};
for row in &mut sys.rows {
let mut a = Array2::<f64>::zeros((d, d));
for r in 0..d {
for c in 0..d {
a[[r, c]] = sample();
}
}
let mut htt = a.t().dot(&a);
for r in 0..d {
htt[[r, r]] += d as f64 + 1.0;
}
row.htt = htt;
for r in 0..d {
for c in 0..k {
row.htbeta[[r, c]] = 0.1 * sample();
}
row.gt[r] = sample();
}
}
let mut hbb_a = Array2::<f64>::zeros((k, k));
for r in 0..k {
for c in 0..k {
hbb_a[[r, c]] = sample();
}
}
let mut hbb = hbb_a.t().dot(&hbb_a);
for r in 0..k {
hbb[[r, r]] += k as f64 + 1.0;
}
sys.hbb = hbb;
for r in 0..k {
sys.gb[r] = sample();
}
sys
}
fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
let mut s = Array2::<f64>::zeros((k, k));
for row in 0..k {
s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
if row + 1 < k {
s[[row, row + 1]] = -0.05;
s[[row + 1, row]] = -0.05;
}
if row + 7 < k {
s[[row, row + 7]] = 0.01;
s[[row + 7, row]] = 0.01;
}
}
let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
(s, rhs)
}
fn dense_pcg_cpu_reference(
s: &Array2<f64>,
rhs: &Array1<f64>,
max_iterations: usize,
relative_tolerance: f64,
) -> Array1<f64> {
let k = rhs.len();
let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
if rhs_norm == 0.0 {
return Array1::<f64>::zeros(k);
}
let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
let mut x = Array1::<f64>::zeros(k);
let mut r = rhs.clone();
let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
let mut p = z.clone();
let mut sp = Array1::<f64>::zeros(k);
let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
for _ in 0..max_iterations.max(1) {
for row in 0..k {
let mut acc = 0.0;
for col in 0..k {
acc += s[[row, col]] * p[col];
}
sp[row] = acc;
}
let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
let alpha = rz / p_sp;
for idx in 0..k {
x[idx] += alpha * p[idx];
r[idx] -= alpha * sp[idx];
}
let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
if r_norm <= tol {
break;
}
for idx in 0..k {
z[idx] = inv_diag[idx] * r[idx];
}
let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
let beta = rz_next / rz;
for idx in 0..k {
p[idx] = z[idx] + beta * p[idx];
}
rz = rz_next;
}
x
}
#[test]
fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
let (s, rhs) = device_pcg_fixture(512);
let max_iterations = 200usize;
let relative_tolerance = 1.0e-12;
let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
&s,
&rhs,
max_iterations,
relative_tolerance,
) {
Ok(result) => result,
Err(failure) => {
assert!(
gam_gpu::device_runtime::GpuRuntime::global().is_none(),
"#1017: CUDA device present but the device reduced-beta PCG \
declined/faulted instead of returning a result (tag: {failure:?}) — \
the kernel does not run correctly on GPU"
);
return;
}
};
let max_err = cpu
.iter()
.zip(device.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f64, f64::max);
assert!(
max_err <= 1.0e-10,
"device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
);
assert!(diag.matvec_calls > 0);
assert_eq!(diag.matvec_calls, diag.iterations);
}
#[test]
fn dense_reference_matches_independent_solve() {
let sys = build_fixture(4, 5, 3, 7);
let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let total = n * d + k;
let mut h = Array2::<f64>::zeros((total, total));
let mut g = ndarray::Array1::<f64>::zeros(total);
for (i, row) in sys.rows.iter().enumerate() {
let base = i * d;
for c in 0..d {
for r in 0..d {
h[[base + r, base + c]] = row.htt[[r, c]];
}
}
for c in 0..k {
for r in 0..d {
h[[base + r, n * d + c]] = row.htbeta[[r, c]];
h[[n * d + c, base + r]] = row.htbeta[[r, c]];
}
}
for r in 0..d {
g[base + r] = row.gt[r];
}
}
for c in 0..k {
for r in 0..k {
h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
}
g[n * d + c] = sys.gb[c];
}
let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
let rhs = g.mapv(|v| -v);
let expected = cholesky_solve_vector(l.view(), rhs.view());
for i in 0..n * d {
assert!(
(solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
"delta_t[{i}] mismatch: got {} expected {}",
solution.delta_t[i],
expected[i]
);
}
for a in 0..k {
assert!(
(solution.delta_beta[a] - expected[n * d + a]).abs()
< 1e-10 * (1.0 + expected[n * d + a].abs()),
"delta_beta[{a}] mismatch"
);
}
}
#[test]
fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; let d = 3usize;
let k = 24usize;
let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
let forward_slabs = slabs.clone();
let transpose_slabs = slabs;
sys.set_row_htbeta_operator(
move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
let h = &forward_slabs[row];
for r in 0..h.nrows() {
let mut acc = 0.0_f64;
for c in 0..h.ncols() {
acc += h[[r, c]] * x[c];
}
out[r] = acc;
}
},
move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
let h = &transpose_slabs[row];
for r in 0..h.nrows() {
for c in 0..h.ncols() {
out[c] += h[[r, c]] * v[r];
}
}
},
);
let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
.expect("row-procedural matvec backend builds for matrix-free system");
let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
let mut out_parallel_a = Array1::<f64>::zeros(k);
matvec(&x, &mut out_parallel_a);
let mut out_parallel_b = Array1::<f64>::zeros(k);
matvec(&x, &mut out_parallel_b);
for a in 0..k {
assert_eq!(
out_parallel_a[a].to_bits(),
out_parallel_b[a].to_bits(),
"row-procedural matvec parallel reduction is non-deterministic at index {a}"
);
}
let mut out_serial = Array1::<f64>::zeros(k);
rayon::ThreadPoolBuilder::new()
.num_threads(2)
.build()
.expect("build rayon pool")
.install(|| matvec(&x, &mut out_serial));
let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
for a in 0..k {
let diff = (out_parallel_a[a] - out_serial[a]).abs();
assert!(
diff <= 1e-12 * (1.0 + max_abs),
"row-procedural matvec parallel vs serial diverged beyond reassociation \
at index {a}: {} vs {} (diff={diff:e})",
out_parallel_a[a],
out_serial[a]
);
}
}
/// Checks the CPU framed SAE Schur matvec (`sae_framed_schur_matvec_cpu`) against an
/// explicitly assembled dense Schur complement on a small 3-atom fixture.
///
/// Reference: `S = hbb + ridge_beta·I − Σ_rows Htβᵀ (Htt + ridge_t·I)⁻¹ Htβ`, where
/// `hbb` is the dense frame-Kronecker term plus every atom's smoothing penalty.
/// Four probe vectors are compared. The error is normalised by `max(max|want|, 1)`
/// and must not exceed `1e-10`.
#[test]
fn framed_sae_schur_matvec_matches_dense_reference() {
// NOTE(review): `BetaPenaltyOp` is not named in the body; presumably it is imported
// for the `to_dense` method used on the two operators below — confirm.
use crate::arrow_schur::{
BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
};
// Three atoms in ambient dimension p = 4; atom 1 is full rank (rank == p).
let p = 4usize;
let ranks = vec![2usize, 4usize, 3usize];
let basis_sizes = vec![2usize, 1usize, 2usize];
let n_atoms = ranks.len();
// Atom k owns basis_sizes[k] * ranks[k] consecutive border coordinates starting at
// border_offsets[k]; border_dim is the total.
let mut border_offsets = Vec::with_capacity(n_atoms);
let mut acc = 0usize;
for k in 0..n_atoms {
border_offsets.push(acc);
acc += basis_sizes[k] * ranks[k];
}
let border_dim = acc;
// Deterministic LCG yielding values in [-1, 1). Every draw below advances it, so
// the order of the sampling loops defines the fixture.
let mut state = 0x1234_5678_9abc_def0u64;
let mut sample = || -> f64 {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
};
// Frames U_k (p × ranks[k]): identity for full-rank atoms, pseudo-random otherwise.
let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
for k in 0..n_atoms {
let r = ranks[k];
let mut u = Array2::<f64>::zeros((p, r));
for i in 0..p {
for j in 0..r {
u[[i, j]] = if r == p && i == j {
1.0
} else if r == p {
0.0
} else {
sample()
};
}
}
frames.push(u);
}
// Frame Gram block W_ij = U_iᵀ U_j (ranks[i] × ranks[j]).
let w_of = |i: usize, j: usize| -> Array2<f64> {
let (ui, uj) = (&frames[i], &frames[j]);
let (ri, rj) = (ranks[i], ranks[j]);
let mut w = Array2::<f64>::zeros((ri, rj));
for a in 0..ri {
for b in 0..rj {
let mut s = 0.0;
for c in 0..p {
s += ui[[c, a]] * uj[[c, b]];
}
w[[a, b]] = s;
}
}
w
};
// Coupling blocks: all diagonal pairs plus the symmetric cross pair (0, 2), visited
// in sorted order. Diagonal G blocks get an (mi + 2) shift on their diagonal.
let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
pairs.sort();
for &(i, j) in &pairs {
let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
let mut g = Array2::<f64>::zeros((mi, mj));
for r in 0..mi {
for c in 0..mj {
g[[r, c]] = 0.3 * sample();
}
}
if i == j {
for r in 0..mi.min(mj) {
g[[r, r]] += mi as f64 + 2.0;
}
}
frame_blocks.push(FactoredFrameGBlock {
atom_i: i,
atom_j: j,
g,
w: w_of(i, j),
});
}
// Per-atom smoothing factor AᵀA + I, placed at the atom's border offset.
let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
for k in 0..n_atoms {
let m = basis_sizes[k];
let mut a = Array2::<f64>::zeros((m, m));
for r in 0..m {
for c in 0..m {
a[[r, c]] = 0.2 * sample();
}
}
let mut s = a.t().dot(&a);
for r in 0..m {
s[[r, r]] += 1.0;
}
smooth_blocks.push(DeviceSaeSmoothBlock {
global_offset: border_offsets[k],
factor_a: s,
});
smooth_ranks.push(ranks[k]);
}
// n rows with q local coordinates each. Htt = AᵀA + (q + 1)·I is SPD. Htβ is
// written both into `sys` and into the flat row-major `row_htbeta` slab.
let n = 6usize;
let q = 3usize;
let mut sys = ArrowSchurSystem::new(n, q, border_dim);
let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
for i in 0..n {
let mut a = Array2::<f64>::zeros((q, q));
for r in 0..q {
for c in 0..q {
a[[r, c]] = sample();
}
}
let mut htt = a.t().dot(&a);
for r in 0..q {
htt[[r, r]] += q as f64 + 1.0;
}
sys.rows[i].htt = htt;
let mut slab = vec![0.0_f64; q * border_dim];
for c in 0..q {
for col in 0..border_dim {
let v = 0.15 * sample();
slab[c * border_dim + col] = v;
sys.rows[i].htbeta[[c, col]] = v;
}
}
row_htbeta.push(slab);
}
// Dense hbb = frame-Kronecker data term + every atom's smoothing penalty.
let data_op =
FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
.expect("frame op");
let mut hbb = data_op.to_dense();
for k in 0..n_atoms {
let op = IdentityRightKroneckerPenaltyOp {
factor_a: smooth_blocks[k].factor_a.clone(),
p: ranks[k],
global_offset: border_offsets[k],
k: border_dim,
};
let d = op.to_dense();
for r in 0..border_dim {
for c in 0..border_dim {
hbb[[r, c]] += d[[r, c]];
}
}
}
sys.hbb = hbb;
let data = DeviceSaePcgData {
p,
beta_dim: border_dim,
a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
smooth_blocks,
sparse_g_blocks: Vec::new(),
frame: Some(DeviceSaeFrameData {
ranks: ranks.clone(),
basis_sizes: basis_sizes.clone(),
border_offsets: border_offsets.clone(),
frame_blocks,
smooth_ranks,
row_htbeta,
}),
};
let ridge_t = 1e-7;
let ridge_beta = 1e-6;
// Dense reference Schur complement. Start from hbb + ridge_beta·I, then for each row
// subtract Htβᵀ (Htt + ridge_t·I)⁻¹ Htβ, solving (Htt + ridge_t·I)⁻¹ Htβ one Htβ
// column at a time via Cholesky.
let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
for r in 0..border_dim {
for c in 0..border_dim {
s_dense[[r, c]] = sys.hbb[[r, c]];
}
s_dense[[r, r]] += ridge_beta;
}
for row in &sys.rows {
let mut htt = row.htt.clone();
for d in 0..q {
htt[[d, d]] += ridge_t;
}
let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
.expect("htt PD");
let mut y = Array2::<f64>::zeros((q, border_dim));
for col in 0..border_dim {
let mut e = Array1::<f64>::zeros(q);
for r in 0..q {
e[r] = row.htbeta[[r, col]];
}
let solved = cholesky_solve_vector(factor.view(), e.view());
for r in 0..q {
y[[r, col]] = solved[r];
}
}
for r in 0..border_dim {
for c in 0..border_dim {
let mut acc = 0.0;
for d in 0..q {
acc += row.htbeta[[d, r]] * y[[d, c]];
}
s_dense[[r, c]] -= acc;
}
}
}
// Compare on four cosine probe vectors; the error is normalised per probe by
// max(max|want|, 1).
let mut max_rel = 0.0_f64;
for trial in 0..4 {
let x: Vec<f64> = (0..border_dim)
.map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
.collect();
let mut got = vec![0.0_f64; border_dim];
sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
.expect("framed matvec");
let mut want = vec![0.0_f64; border_dim];
for r in 0..border_dim {
let mut acc = 0.0;
for c in 0..border_dim {
acc += s_dense[[r, c]] * x[c];
}
want[r] = acc;
}
let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
for a in 0..border_dim {
let rel = (got[a] - want[a]).abs() / scale;
max_rel = max_rel.max(rel);
}
}
assert!(
max_rel <= 1e-10,
"framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
);
}
/// Device SAE PCG parity on a framed 8-atom fixture with 400 rows.
///
/// Runs `solve_sae_matrix_free_pcg` (the one visible to this test module) with up to
/// 400 iterations at relative tolerance 1e-12. The reference is built as follows:
/// - dense `S` is assembled column by column from `sae_framed_schur_matvec_cpu`;
/// - `S x = rhs` is solved by Cholesky.
///
/// The test requires `max |device − cpu| / max(max|cpu|, 1) <= 1e-7`. If the device
/// solve fails, the test only passes when no GPU runtime is registered.
#[test]
fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
// NOTE(review): `BetaPenaltyOp` is not named in the body; presumably it is imported
// for the `to_dense` method used on the two operators below — confirm.
use crate::arrow_schur::{
BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
};
// Eight atoms in ambient dimension p = 6: even atoms have rank 3, odd atoms are
// full rank (rank == p); every atom has 3 basis functions.
let p = 6usize;
let n_atoms = 8usize;
let ranks: Vec<usize> = (0..n_atoms)
.map(|k| if k % 2 == 0 { 3usize } else { p })
.collect();
let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
// Atom k owns basis_sizes[k] * ranks[k] consecutive border coordinates.
let mut border_offsets = Vec::with_capacity(n_atoms);
let mut acc = 0usize;
for k in 0..n_atoms {
border_offsets.push(acc);
acc += basis_sizes[k] * ranks[k];
}
let border_dim = acc;
// Deterministic LCG yielding values in [-1, 1). The order of the sampling loops
// below defines the fixture.
let mut state = 0xfeed_face_dead_beefu64;
let mut sample = || -> f64 {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
};
// Frames U_k (p × ranks[k]): identity for full-rank atoms, pseudo-random otherwise.
let mut frames: Vec<Array2<f64>> = Vec::new();
for k in 0..n_atoms {
let r = ranks[k];
let mut u = Array2::<f64>::zeros((p, r));
for i in 0..p {
for j in 0..r {
u[[i, j]] = if r == p && i == j {
1.0
} else if r == p {
0.0
} else {
sample()
};
}
}
frames.push(u);
}
// Frame Gram block W_ij = U_iᵀ U_j.
let w_of = |i: usize, j: usize| {
let (ui, uj) = (&frames[i], &frames[j]);
let (ri, rj) = (ranks[i], ranks[j]);
let mut w = Array2::<f64>::zeros((ri, rj));
for a in 0..ri {
for b in 0..rj {
let mut s = 0.0;
for c in 0..p {
s += ui[[c, a]] * uj[[c, b]];
}
w[[a, b]] = s;
}
}
w
};
// Coupling blocks: every diagonal pair, plus cross pairs (0,1), (2,4), (3,6) in both
// orientations. Diagonal G blocks get an (mi + 2) diagonal shift.
let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
pairs.push((i, j));
pairs.push((j, i));
}
let mut frame_blocks = Vec::new();
for &(i, j) in &pairs {
let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
let mut g = Array2::<f64>::zeros((mi, mj));
for r in 0..mi {
for c in 0..mj {
g[[r, c]] = 0.25 * sample();
}
}
if i == j {
for r in 0..mi.min(mj) {
g[[r, r]] += mi as f64 + 2.0;
}
}
frame_blocks.push(FactoredFrameGBlock {
atom_i: i,
atom_j: j,
g,
w: w_of(i, j),
});
}
// Per-atom smoothing factor AᵀA + I at the atom's border offset.
let mut smooth_blocks = Vec::new();
let mut smooth_ranks = Vec::new();
for k in 0..n_atoms {
let m = basis_sizes[k];
let mut a = Array2::<f64>::zeros((m, m));
for r in 0..m {
for c in 0..m {
a[[r, c]] = 0.2 * sample();
}
}
let mut s = a.t().dot(&a);
for r in 0..m {
s[[r, r]] += 1.0;
}
smooth_blocks.push(DeviceSaeSmoothBlock {
global_offset: border_offsets[k],
factor_a: s,
});
smooth_ranks.push(ranks[k]);
}
// 400 rows with q = 4 local coordinates. Htt = AᵀA + (q + 1)·I is SPD. Htβ is small
// and written both into `sys` and into the flat row-major `row_htbeta` slab.
let n = 400usize;
let q = 4usize;
let mut sys = ArrowSchurSystem::new(n, q, border_dim);
let mut row_htbeta = Vec::new();
for i in 0..n {
let mut a = Array2::<f64>::zeros((q, q));
for r in 0..q {
for c in 0..q {
a[[r, c]] = sample();
}
}
let mut htt = a.t().dot(&a);
for r in 0..q {
htt[[r, r]] += q as f64 + 1.0;
}
sys.rows[i].htt = htt;
let mut slab = vec![0.0_f64; q * border_dim];
for c in 0..q {
for col in 0..border_dim {
let v = 0.02 * sample();
slab[c * border_dim + col] = v;
sys.rows[i].htbeta[[c, col]] = v;
}
}
row_htbeta.push(slab);
}
// Dense hbb = frame-Kronecker data term + every atom's smoothing penalty.
let data_op =
FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
.expect("frame op");
let mut hbb = data_op.to_dense();
for k in 0..n_atoms {
let op = IdentityRightKroneckerPenaltyOp {
factor_a: smooth_blocks[k].factor_a.clone(),
p: ranks[k],
global_offset: border_offsets[k],
k: border_dim,
};
let d = op.to_dense();
for r in 0..border_dim {
for c in 0..border_dim {
hbb[[r, c]] += d[[r, c]];
}
}
}
sys.hbb = hbb;
let data = DeviceSaePcgData {
p,
beta_dim: border_dim,
a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
smooth_blocks,
sparse_g_blocks: Vec::new(),
frame: Some(DeviceSaeFrameData {
ranks: ranks.clone(),
basis_sizes: basis_sizes.clone(),
border_offsets: border_offsets.clone(),
frame_blocks,
smooth_ranks,
row_htbeta,
}),
};
let ridge_t = 1e-7;
let ridge_beta = 1e-6;
let rhs: Array1<f64> =
Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
// A failure is only acceptable when no GPU runtime is registered.
let (device, diag) =
match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
Ok(result) => result,
Err(failure) => {
assert!(
gam_gpu::device_runtime::GpuRuntime::global().is_none(),
"#1017: CUDA device present but the framed device SAE PCG \
declined/faulted instead of returning a result (tag: {failure:?}) — \
the kernel does not run correctly on GPU"
);
return;
}
};
// Materialise the CPU operator column by column (S·e_col), then solve it directly.
let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
for col in 0..border_dim {
let mut e = vec![0.0_f64; border_dim];
e[col] = 1.0;
let mut sc = vec![0.0_f64; border_dim];
sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &e, &mut sc)
.expect("cpu matvec");
for r in 0..border_dim {
s_dense[[r, col]] = sc[r];
}
}
let factor = cholesky_factor_in_place(s_dense.view(), CholeskyGuard::NonnegativePivot)
.expect("S PD");
let cpu = cholesky_solve_vector(factor.view(), rhs.view());
let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
let mut max_rel = 0.0_f64;
for a in 0..border_dim {
max_rel = max_rel.max((device[a] - cpu[a]).abs() / scale);
}
// Max-abs residuals of the device and CPU solutions against the CPU-assembled S.
// They only feed the failure message, where they separate a device-operator mismatch
// from a PCG / preconditioner / singular-S issue.
let mut s_dev_resid = 0.0_f64;
{
let sx = s_dense.dot(&device);
for a in 0..border_dim {
s_dev_resid = s_dev_resid.max((sx[a] - rhs[a]).abs());
}
}
let s_cpu_resid = {
let sc = s_dense.dot(&cpu);
let mut m = 0.0_f64;
for a in 0..border_dim {
m = m.max((sc[a] - rhs[a]).abs());
}
m
};
assert!(
max_rel <= 1e-7,
"[#1551 framed-triage] max_rel={max_rel:e} | device-vs-CPU-operator residual \
‖S_cpu·device−rhs‖={s_dev_resid:e} (CPU's own ={s_cpu_resid:e}) | device PCG \
stop={:?} iters={} final_rel_resid={:e} — large operator-residual ⇒ device matvec \
is a different operator (kernel bug); small ⇒ PCG/precond or singular-S issue",
diag.stopping_reason,
diag.iterations,
diag.final_relative_residual,
);
}
}