1use ndarray::{Array1, Array2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
/// Result of one arrow-structured Newton solve.
///
/// Produced by the GPU solvers in this file and by
/// `solve_arrow_newton_step_dense_reference`.
pub struct ArrowSchurGpuSolution {
    /// Per-row Newton steps concatenated row by row; in the dense solvers
    /// visible here this has `n * d` entries, row `i` occupying
    /// `[i * d, (i + 1) * d)`.
    pub delta_t: Array1<f64>,
    /// Newton step for the shared coefficients (`k` entries).
    pub delta_beta: Array1<f64>,
    /// Log-determinant of the ridged Hessian, computed as twice the sum of
    /// the logarithms of the Cholesky diagonal(s).
    pub log_det_hessian: f64,
}
32
/// Reasons a GPU arrow-Schur operation did not produce a result.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// The requested GPU path cannot be used here. Examples: non-Linux build,
    /// no CUDA runtime/device, a degenerate shape, or a device allocation,
    /// transfer, or library-handle failure. `solve_arrow_newton_step` treats
    /// it from the multi-GPU/fused paths as "try the next path".
    Unavailable,
    /// Row block `row`'s ridged `H_tt` failed Cholesky. `bump` is a suggested
    /// increase of the `t`-block ridge, derived from the block's largest
    /// absolute diagonal entry (floored at 1), `sqrt(f64::EPSILON)` and
    /// `RIDGE_BUMP_EPS_MARGIN`. The batched GPU paths also scale it by the
    /// failing pivot index.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// Any other factorization or validation failure, e.g. the reduced
    /// Schur block failed Cholesky, block dimensions are inconsistent, or a
    /// control value is NaN. `reason` is a human-readable description.
    SchurFactorFailed { reason: String },
    /// The system carries matrix-free operators, but the requested GPU path
    /// needs dense `H_ββ` / `H_tβ` blocks. The flags record which operators
    /// were present.
    GpuRequiresDenseSystem {
        had_hbb_matvec: bool,
        had_htbeta_matvec: bool,
    },
}
56
/// Multiplier applied to `sqrt(f64::EPSILON)` (times the row block's
/// diagonal scale) when suggesting a ridge increase in
/// `ArrowSchurGpuFailure::RidgeBumpRequired`.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
70
71pub fn solve_arrow_newton_step(
75 sys: &ArrowSchurSystem,
76 ridge_t: f64,
77 ridge_beta: f64,
78) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
79 let n = sys.rows.len();
80 let d = sys.d;
81 let k = sys.k;
82
83 let had_hbb_matvec = sys.hbb_matvec.is_some();
88 let had_htbeta_matvec = sys.htbeta_matvec.is_some();
89 if had_hbb_matvec || had_htbeta_matvec {
90 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
91 had_hbb_matvec,
92 had_htbeta_matvec,
93 });
94 }
95
96 if sys.hbb.dim() != (k, k) {
97 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
98 reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
99 });
100 }
101 if n == 0 || d == 0 {
102 return Err(ArrowSchurGpuFailure::Unavailable);
103 }
104 if sys
105 .rows
106 .iter()
107 .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
108 {
109 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
110 reason: "row block dimension mismatch".to_string(),
111 });
112 }
113
114 #[cfg(not(target_os = "linux"))]
115 {
116 if ridge_t.is_nan() || ridge_beta.is_nan() {
117 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
118 reason: "ridge is NaN".to_string(),
119 });
120 }
121 Err(ArrowSchurGpuFailure::Unavailable)
122 }
123
124 #[cfg(target_os = "linux")]
125 {
126 if gam_gpu::device_runtime::GpuRuntime::global()
135 .map(gam_gpu::device_runtime::GpuRuntime::device_count)
136 .unwrap_or(0)
137 > 1
138 {
139 match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
140 Ok(sol) => return Ok(sol),
141 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
142 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
143 }
144 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
145 return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
146 }
147 Err(_) => {}
150 }
151 }
152 if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
158 match cuda::solve_fused(sys, ridge_t, ridge_beta) {
159 Ok(sol) => return Ok(sol),
160 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
164 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
165 }
166 Err(_) => {}
170 }
171 }
172 cuda::solve(sys, ridge_t, ridge_beta)
173 }
174}
175
176#[cfg(target_os = "linux")]
182fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
183 let n = sys.rows.len();
184 let d = sys.d;
185 let k = sys.k;
186 let mut d_buf = Vec::with_capacity(n * d * d);
187 let mut b_buf = Vec::with_capacity(n * d * k);
188 let mut g_buf = Vec::with_capacity(n * d);
189 for row in &sys.rows {
190 pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
191 }
192 (d_buf, b_buf, g_buf)
193}
194
195#[cfg(target_os = "linux")]
196#[inline]
197fn pack_block(
198 row: &crate::arrow_schur::ArrowRowBlock,
199 ridge_t: f64,
200 d: usize,
201 k: usize,
202 d_buf: &mut Vec<f64>,
203 b_buf: &mut Vec<f64>,
204 g_buf: &mut Vec<f64>,
205) {
206 for col in 0..d {
207 for r in 0..d {
208 let mut value = row.htt[[r, col]];
209 if r == col {
210 value += ridge_t;
211 }
212 d_buf.push(value);
213 }
214 }
215 for col in 0..k {
216 for r in 0..d {
217 b_buf.push(row.htbeta[[r, col]]);
218 }
219 }
220 for r in 0..d {
221 g_buf.push(row.gt[r]);
222 }
223}
224
225#[doc(hidden)]
230#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] pub fn solve_arrow_newton_step_fused_force(
232 sys: &ArrowSchurSystem,
233 ridge_t: f64,
234 ridge_beta: f64,
235) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
236 if ridge_t.is_nan() || ridge_beta.is_nan() {
237 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
238 reason: "ridge is NaN".to_string(),
239 });
240 }
241 #[cfg(not(target_os = "linux"))]
242 {
243 Err(ArrowSchurGpuFailure::Unavailable)
248 }
249 #[cfg(target_os = "linux")]
250 {
251 if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
252 .is_none()
253 {
254 return Err(ArrowSchurGpuFailure::Unavailable);
255 }
256 cuda::solve_fused(sys, ridge_t, ridge_beta)
257 }
258}
259
/// Handle to a `cuda::ResidentArrowFrame` built once from an arrow system.
///
/// `solve_gradient` then solves against new gradient vectors without passing
/// the system again. Off Linux the only field is `Infallible`, so the type is
/// uninhabited there and `new` always returns an error.
pub struct ResidentArrowFrameHandle {
    #[cfg(target_os = "linux")]
    inner: cuda::ResidentArrowFrame,
    // Uninhabited placeholder: makes the handle impossible to construct off Linux.
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
275
276impl ResidentArrowFrameHandle {
277 pub fn new(
279 sys: &ArrowSchurSystem,
280 ridge_t: f64,
281 ridge_beta: f64,
282 ) -> Result<Self, ArrowSchurGpuFailure> {
283 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
286 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
287 had_hbb_matvec: sys.hbb_matvec.is_some(),
288 had_htbeta_matvec: sys.htbeta_matvec.is_some(),
289 });
290 }
291 #[cfg(not(target_os = "linux"))]
292 {
293 if ridge_t.is_nan() || ridge_beta.is_nan() {
294 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
295 reason: "ridge is NaN".to_string(),
296 });
297 }
298 Err(ArrowSchurGpuFailure::Unavailable)
299 }
300 #[cfg(target_os = "linux")]
301 {
302 Ok(Self {
303 inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
304 })
305 }
306 }
307
308 pub fn solve_gradient(
310 &self,
311 g_t: &[f64],
312 g_beta: &[f64],
313 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
314 #[cfg(not(target_os = "linux"))]
315 {
316 if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
317 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
318 reason: "non-finite gradient entry".to_string(),
319 });
320 }
321 Err(ArrowSchurGpuFailure::Unavailable)
322 }
323 #[cfg(target_os = "linux")]
324 {
325 self.inner.solve_gradient(g_t, g_beta)
326 }
327 }
328
329 #[must_use]
331 pub fn log_det_hessian(&self) -> f64 {
332 #[cfg(not(target_os = "linux"))]
333 {
334 panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
340 }
341 #[cfg(target_os = "linux")]
342 {
343 self.inner.log_det_hessian()
344 }
345 }
346}
347
348pub fn gpu_schur_matvec_backend(
391 sys: &ArrowSchurSystem,
392 ridge_t: f64,
393 ridge_beta: f64,
394) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
395 if sys.htbeta_matvec.is_some() {
398 return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
399 }
400
401 #[cfg(not(target_os = "linux"))]
402 {
403 if ridge_t.is_nan() || ridge_beta.is_nan() {
406 return Err(ArrowSchurGpuFailure::Unavailable);
407 }
408 Err(ArrowSchurGpuFailure::Unavailable)
409 }
410
411 #[cfg(target_os = "linux")]
412 {
413 cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
414 }
415}
416
417fn build_row_procedural_matvec(
434 sys: &ArrowSchurSystem,
435 ridge_t: f64,
436 ridge_beta: f64,
437) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
438 use std::sync::Arc;
439 let n = sys.rows.len();
440 let k = sys.k;
441 let forward = sys
442 .htbeta_matvec
443 .clone()
444 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
445 let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
446 ArrowSchurGpuFailure::SchurFactorFailed {
451 reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
452 forward operator installed without its sparse adjoint"
453 .to_string(),
454 }
455 })?;
456
457 let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
462 for (i, row) in sys.rows.iter().enumerate() {
463 let di = row.htt.nrows();
464 if row.htt.ncols() != di || row.gt.len() != di {
465 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
466 reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
467 });
468 }
469 let mut block = row.htt.clone();
470 for r in 0..di {
471 block[[r, r]] += ridge_t;
472 }
473 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
474 .ok_or_else(|| {
475 let scale = row
476 .htt
477 .diag()
478 .iter()
479 .map(|v| v.abs())
480 .fold(0.0_f64, f64::max)
481 .max(1.0);
482 ArrowSchurGpuFailure::RidgeBumpRequired {
483 row: i,
484 bump: scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN,
485 }
486 })?;
487 factors.push(factor);
488 }
489
490 let penalty_op = sys.effective_penalty_op();
497 let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
498
499 let closure: crate::arrow_schur::GpuSchurMatvec =
500 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
501 assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
502 assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
503
504 {
507 let x_slice = x.as_slice().expect("x must be contiguous");
508 let out_slice = out.as_slice_mut().expect("out must be contiguous");
509 for a in 0..k {
510 out_slice[a] = ridge_beta * x_slice[a];
511 }
512 penalty_op.matvec(x_slice, out_slice);
513 }
514
515 let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
542 && rayon::current_thread_index().is_none();
543 if parallel {
544 use rayon::prelude::*;
545 const CHUNK: usize = 64;
546 let partials: Vec<Array1<f64>> = (0..n)
547 .into_par_iter()
548 .chunks(CHUNK)
549 .map(|idxs| {
550 let mut neg = Array1::<f64>::zeros(k);
554 for i in idxs {
555 let di = row_dims[i];
556 let mut v_i = Array1::<f64>::zeros(di);
558 forward(i, x.view(), &mut v_i);
559 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
561 transpose(i, w_i.view(), &mut neg);
563 }
564 neg
565 })
566 .collect();
567 let mut neg = Array1::<f64>::zeros(k);
576 for part in &partials {
577 for a in 0..k {
578 neg[a] += part[a];
579 }
580 }
581 for a in 0..k {
582 out[a] -= neg[a];
583 }
584 } else {
585 let mut neg = Array1::<f64>::zeros(k);
587 for i in 0..n {
588 let di = row_dims[i];
589 let mut v_i = Array1::<f64>::zeros(di);
591 forward(i, x.view(), &mut v_i);
592 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
594 transpose(i, w_i.view(), &mut neg);
596 }
597 for a in 0..k {
598 out[a] -= neg[a];
599 }
600 }
601 });
602
603 Ok(closure)
604}
605
606pub fn solve_reduced_beta_pcg(
628 s_acc: &Array2<f64>,
629 rhs_beta: &Array1<f64>,
630 max_iterations: usize,
631 relative_tolerance: f64,
632) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
633 solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
634 .map(|(x, _)| x)
635}
636
637#[doc(hidden)]
638pub fn solve_reduced_beta_pcg_with_diagnostics(
639 s_acc: &Array2<f64>,
640 rhs_beta: &Array1<f64>,
641 max_iterations: usize,
642 relative_tolerance: f64,
643) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
644 let k = rhs_beta.len();
645 if s_acc.dim() != (k, k) {
646 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
647 reason: format!(
648 "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
649 s_acc.dim()
650 ),
651 });
652 }
653 if k == 0 {
654 return Err(ArrowSchurGpuFailure::Unavailable);
655 }
656
657 #[cfg(not(target_os = "linux"))]
658 {
659 if relative_tolerance.is_nan() || max_iterations == 0 {
660 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
661 reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
662 });
663 }
664 Err(ArrowSchurGpuFailure::Unavailable)
665 }
666
667 #[cfg(target_os = "linux")]
668 {
669 cuda::solve_reduced_beta_pcg_with_diagnostics(
670 s_acc,
671 rhs_beta,
672 max_iterations,
673 relative_tolerance,
674 )
675 }
676}
677
678pub fn solve_sae_matrix_free_pcg(
679 sys: &ArrowSchurSystem,
680 data: &DeviceSaePcgData,
681 ridge_t: f64,
682 ridge_beta: f64,
683 rhs_beta: &Array1<f64>,
684 max_iterations: usize,
685 relative_tolerance: f64,
686) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
687 if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
688 return Err(ArrowSchurGpuFailure::Unavailable);
689 }
690 #[cfg(not(target_os = "linux"))]
691 {
692 if ridge_t.is_nan()
693 || ridge_beta.is_nan()
694 || relative_tolerance.is_nan()
695 || max_iterations == 0
696 {
697 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
698 reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
699 });
700 }
701 Err(ArrowSchurGpuFailure::Unavailable)
702 }
703 #[cfg(target_os = "linux")]
704 {
705 if data.frame.is_some() {
712 cuda::solve_sae_matrix_free_pcg_framed(
713 sys,
714 data,
715 ridge_t,
716 ridge_beta,
717 rhs_beta,
718 max_iterations,
719 relative_tolerance,
720 )
721 } else {
722 cuda::solve_sae_matrix_free_pcg(
723 sys,
724 data,
725 ridge_t,
726 ridge_beta,
727 rhs_beta,
728 max_iterations,
729 relative_tolerance,
730 )
731 }
732 }
733}
734
735#[doc(hidden)]
739pub fn solve_arrow_newton_step_dense_reference(
740 sys: &ArrowSchurSystem,
741 ridge_t: f64,
742 ridge_beta: f64,
743) -> Result<ArrowSchurGpuSolution, String> {
744 let n = sys.rows.len();
745 let d = sys.d;
746 let k = sys.k;
747 let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
748 let mut h = Array2::<f64>::zeros((total, total));
749 let mut rhs = Array1::<f64>::zeros(total);
750 for (i, row) in sys.rows.iter().enumerate() {
751 let base = i * d;
752 for c in 0..d {
753 for r in 0..d {
754 h[[base + r, base + c]] = row.htt[[r, c]];
755 }
756 h[[base + c, base + c]] += ridge_t;
757 }
758 for c in 0..k {
759 for r in 0..d {
760 let value = row.htbeta[[r, c]];
761 h[[base + r, n * d + c]] = value;
762 h[[n * d + c, base + r]] = value;
763 }
764 }
765 for r in 0..d {
766 rhs[base + r] = -row.gt[r];
767 }
768 }
769 for c in 0..k {
770 for r in 0..k {
771 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
772 }
773 h[[n * d + c, n * d + c]] += ridge_beta;
774 rhs[n * d + c] = -sys.gb[c];
775 }
776 let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
777 .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
778 let mut log_det = 0.0_f64;
779 for i in 0..total {
780 log_det += factor[[i, i]].ln();
781 }
782 log_det *= 2.0;
783 let solved = cholesky_solve_vector(factor.view(), rhs.view());
784 let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
785 let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
786 Ok(ArrowSchurGpuSolution {
787 delta_t,
788 delta_beta,
789 log_det_hessian: log_det,
790 })
791}
792
793#[doc(hidden)]
804pub fn sae_framed_penalty_matvec_cpu(
805 data: &DeviceSaePcgData,
806 ridge_beta: f64,
807 x: &[f64],
808 out: &mut [f64],
809) {
810 let frame = data
811 .frame
812 .as_ref()
813 .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
814 let k = data.beta_dim;
815 for a in 0..k {
816 out[a] = ridge_beta * x[a];
817 }
818 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
820 let off = blk.global_offset;
821 let m = blk.factor_a.nrows();
822 for i_a in 0..m {
823 for i_b in 0..r {
824 let mut acc = 0.0_f64;
825 for j_a in 0..m {
826 let s = blk.factor_a[[i_a, j_a]];
827 if s == 0.0 {
828 continue;
829 }
830 acc += s * x[off + j_a * r + i_b];
831 }
832 out[off + i_a * r + i_b] += acc;
833 }
834 }
835 }
836 for blk in &frame.frame_blocks {
838 let r_i = frame.ranks[blk.atom_i];
839 let r_j = frame.ranks[blk.atom_j];
840 let off_i = frame.border_offsets[blk.atom_i];
841 let off_j = frame.border_offsets[blk.atom_j];
842 let (m_i, m_j) = blk.g.dim();
843 for li in 0..m_i {
844 let yi_base = off_i + li * r_i;
845 for lj in 0..m_j {
846 let g = blk.g[[li, lj]];
847 if g == 0.0 {
848 continue;
849 }
850 let xj_base = off_j + lj * r_j;
851 for a in 0..r_i {
852 let mut acc = 0.0_f64;
853 for b in 0..r_j {
854 acc += blk.w[[a, b]] * x[xj_base + b];
855 }
856 out[yi_base + a] += g * acc;
857 }
858 }
859 }
860 }
861}
862
863#[doc(hidden)]
872pub fn sae_framed_schur_matvec_cpu(
873 sys: &ArrowSchurSystem,
874 data: &DeviceSaePcgData,
875 ridge_t: f64,
876 ridge_beta: f64,
877 x: &[f64],
878 out: &mut [f64],
879) -> Result<(), String> {
880 let frame = data
881 .frame
882 .as_ref()
883 .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
884 let k = data.beta_dim;
885 sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
886 if frame.row_htbeta.len() != sys.rows.len() {
887 return Err(format!(
888 "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
889 frame.row_htbeta.len(),
890 sys.rows.len()
891 ));
892 }
893 for (i, row) in sys.rows.iter().enumerate() {
894 let slab = &frame.row_htbeta[i];
895 if slab.is_empty() {
896 continue;
897 }
898 let qi = sys.row_dims[i];
899 if qi == 0 || slab.len() != qi * k {
900 continue;
901 }
902 let mut h = vec![0.0_f64; qi];
904 for c in 0..qi {
905 let base = c * k;
906 let mut acc = 0.0_f64;
907 for a in 0..k {
908 acc += slab[base + a] * x[a];
909 }
910 h[c] = acc;
911 }
912 let mut block = row.htt.clone();
914 for d in 0..qi {
915 block[[d, d]] += ridge_t;
916 }
917 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
918 .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
919 let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
920 for c in 0..qi {
922 let sc = s[c];
923 if sc == 0.0 {
924 continue;
925 }
926 let base = c * k;
927 for a in 0..k {
928 out[a] -= slab[base + a] * sc;
929 }
930 }
931 }
932 Ok(())
933}
934
935#[cfg(target_os = "linux")]
936mod cuda {
937 use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
938 use gam_gpu::driver::to_i32;
939 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
940 use crate::arrow_schur::{
941 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
942 };
943 use cudarc::cublas::sys::{
944 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
945 };
946 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
947 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
948 use cudarc::driver::{
949 CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
950 PushKernelArg,
951 };
952 use ndarray::Array1;
953 use std::sync::{Arc, OnceLock};
954
955 struct RowSlot {
960 d_block: Vec<f64>, b_block: Vec<f64>, g_vec: Vec<f64>, diag_scale: f64, l_block: Vec<f64>, u_vec: Vec<f64>, y_block: Vec<f64>, log_det_local: f64,
970 bump: Option<f64>,
973 tile_partial_schur: Option<Vec<f64>>, tile_partial_rhs: Option<Vec<f64>>, delta_t_block: Vec<f64>, }
979
    /// Multi-GPU variant of the dense arrow-Schur Newton solve.
    ///
    /// The solve proceeds in these steps:
    /// 1. Rows are packed on the host into `RowSlot`s and distributed across
    ///    devices with `gam_gpu::pool::scatter_batched`.
    /// 2. Each tile factors its row blocks and computes a partial Schur block
    ///    and RHS (`forward_tile`).
    /// 3. The host adds those partials to `H_ββ + ridge_β I` / `-g_β`.
    /// 4. The `k × k` Schur block is factored on the selected device and
    ///    solved for `Δβ`.
    /// 5. A second scatter performs the per-row back-substitution
    ///    (`back_sub_tile`).
    ///
    /// `log det H = 2 (Σ_rows Σ_j ln L_i[j,j] + Σ_j ln L_S[j,j])`.
    ///
    /// # Errors
    /// * `Unavailable` for empty or non-dense systems, fewer than two devices,
    ///   or any device-side failure.
    /// * `RidgeBumpRequired` for the first slot (in row order) flagged with a
    ///   bump by `forward_tile`.
    /// * `SchurFactorFailed` if the reduced Schur block fails Cholesky.
    pub(super) fn solve_multi_gpu(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        if n == 0 || d == 0 || k == 0 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        // Only dense systems are handled here; the caller falls back otherwise.
        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        if runtime.device_count() < 2 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Pack every row once on the host; the tiles read from these slots.
        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
        for row in &sys.rows {
            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
            let mut d_block = Vec::with_capacity(d * d);
            let mut b_block = Vec::with_capacity(d * k);
            let mut g_vec = Vec::with_capacity(d);
            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
            let diag_scale = row
                .htt
                .diag()
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f64, f64::max)
                .max(1.0);
            slots.push(RowSlot {
                d_block,
                b_block,
                g_vec,
                diag_scale,
                l_block: Vec::new(),
                u_vec: Vec::new(),
                y_block: Vec::new(),
                log_det_local: 0.0,
                bump: None,
                tile_partial_schur: None,
                tile_partial_rhs: None,
                delta_t_block: vec![0.0; d],
            });
        }

        // Forward pass: per-tile factorization plus partial Schur accumulation.
        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            forward_tile(ordinal, d, k, tile)
        });
        if forward_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // NOTE(review): assumes `balanced_partition` reproduces the tiling that
        // `scatter_batched` used, because each tile's partials live on its first slot.
        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
        if let Some((row, bump)) = slots
            .iter()
            .enumerate()
            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
        {
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
        }

        // Seed the Schur block with column-major H_ββ + ridge_β I and the RHS
        // with -g_β, then add every tile's partial contribution.
        let mut schur_host = vec![0.0_f64; k * k];
        for col in 0..k {
            for row in 0..k {
                let mut v = sys.hbb[[row, col]];
                if row == col {
                    v += ridge_beta;
                }
                schur_host[col * k + row] = v;
            }
        }
        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
        let mut log_det = 0.0_f64;
        for start in tile_starts(&row_base_of_tile) {
            let slot = &slots[start];
            let partial_schur = slot
                .tile_partial_schur
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let partial_rhs = slot
                .tile_partial_rhs
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            for idx in 0..k * k {
                schur_host[idx] += partial_schur[idx];
            }
            for a in 0..k {
                rhs_host[a] += partial_rhs[a];
            }
        }
        for slot in &slots {
            log_det += slot.log_det_local;
        }

        // Reduced Schur solve on the selected (primary) device.
        let primary = runtime.selected_device().ordinal;
        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut schur_dev = stream
            .clone_htod(&schur_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_dev = stream
            .clone_htod(&rhs_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
            });
        }
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
        let delta_beta_host = stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // The host copy is kept: back-substitution below still reads it.
        let delta_beta = Array1::from_vec(delta_beta_host.clone());
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;

        // Back-substitution pass: Δt_i for every row, scattered again.
        let delta_beta_ref = &delta_beta_host;
        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
        });
        if back_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let mut delta_t = Array1::<f64>::zeros(n * d);
        for (i, slot) in slots.iter().enumerate() {
            let base = i * d;
            for r in 0..d {
                delta_t[base + r] = slot.delta_t_block[r];
            }
        }

        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: log_det,
        })
    }
1176
1177 fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1180 tiles.iter().map(|(_, range)| range.start)
1181 }
1182
1183 fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1191 if tile.is_empty() {
1192 return Some(());
1193 }
1194 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1197 .and_then(|ctx| ctx.new_stream().ok())?;
1198 let solver = DnHandle::new(stream.clone()).ok()?;
1199 let blas = CudaBlas::new(stream.clone()).ok()?;
1200 let m = tile.len();
1201
1202 let mut d_host = Vec::with_capacity(m * d * d);
1205 let mut b_host = Vec::with_capacity(m * d * k);
1206 let mut g_host = Vec::with_capacity(m * d);
1207 for slot in tile.iter() {
1208 d_host.extend_from_slice(&slot.d_block);
1209 b_host.extend_from_slice(&slot.b_block);
1210 g_host.extend_from_slice(&slot.g_vec);
1211 }
1212 let mut d_dev = stream.clone_htod(&d_host).ok()?;
1213 let mut b_dev = stream.clone_htod(&b_host).ok()?;
1214 let mut g_dev = stream.clone_htod(&g_host).ok()?;
1215
1216 let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1218 if let Some(local) = info_host.iter().position(|info| *info != 0) {
1219 let pivot = info_host[local];
1220 tile[local].bump = Some(
1221 tile[local].diag_scale
1222 * (f64::from(pivot).abs()).max(1.0)
1223 * f64::EPSILON.sqrt()
1224 * super::RIDGE_BUMP_EPS_MARGIN,
1225 );
1226 return Some(());
1227 }
1228
1229 trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1231 trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1232
1233 let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1235 let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1236 accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1237
1238 let l_host = stream.clone_dtoh(&d_dev).ok()?;
1240 let u_host = stream.clone_dtoh(&g_dev).ok()?;
1241 let y_host = stream.clone_dtoh(&b_dev).ok()?;
1242 let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1243 let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1244
1245 for (local, slot) in tile.iter_mut().enumerate() {
1246 let l_base = local * d * d;
1247 let u_base = local * d;
1248 let y_base = local * d * k;
1249 slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1250 slot.u_vec = u_host[u_base..u_base + d].to_vec();
1251 slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1252 let mut log_det_local = 0.0_f64;
1253 for j in 0..d {
1254 log_det_local += l_host[l_base + j * d + j].ln();
1255 }
1256 slot.log_det_local = log_det_local;
1257 }
1258 tile[0].tile_partial_schur = Some(partial_schur);
1259 tile[0].tile_partial_rhs = Some(partial_rhs);
1260 Some(())
1261 }
1262
    /// Back-substitution for one tile of rows on device `ordinal`.
    ///
    /// Steps:
    /// 1. Re-upload each slot's stored `l_block`, `u_vec` and `y_block`,
    ///    plus `delta_beta`.
    /// 2. Fold `Δβ` into the `u` buffer with `accumulate_back_sub_rhs`.
    /// 3. Apply the transposed batched lower-triangular solve.
    /// 4. Write the negated result into each slot's `delta_t_block`.
    ///
    /// Returns `None` on any device failure.
    fn back_sub_tile(
        ordinal: usize,
        d: usize,
        k: usize,
        delta_beta: &[f64],
        tile: &mut [RowSlot],
    ) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        // Re-assemble the tile's contiguous batches from the per-slot copies.
        let mut l_host = Vec::with_capacity(m * d * d);
        let mut u_host = Vec::with_capacity(m * d);
        let mut y_host = Vec::with_capacity(m * d * k);
        for slot in tile.iter() {
            l_host.extend_from_slice(&slot.l_block);
            u_host.extend_from_slice(&slot.u_vec);
            y_host.extend_from_slice(&slot.y_block);
        }
        let d_dev = stream.clone_htod(&l_host).ok()?;
        let mut g_dev = stream.clone_htod(&u_host).ok()?;
        let b_dev = stream.clone_htod(&y_host).ok()?;
        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;

        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        let x_host = stream.clone_dtoh(&g_dev).ok()?;
        // Sign flip matches the single-device `solve` (delta_t = -x).
        for (local, slot) in tile.iter_mut().enumerate() {
            let base = local * d;
            for r in 0..d {
                slot.delta_t_block[r] = -x_host[base + r];
            }
        }
        Some(())
    }
1308
1309 pub(super) fn solve(
1310 sys: &ArrowSchurSystem,
1311 ridge_t: f64,
1312 ridge_beta: f64,
1313 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1314 let n = sys.rows.len();
1315 let d = sys.d;
1316 let k = sys.k;
1317 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1318 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1319
1320 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1321 .and_then(|ctx| ctx.new_stream().ok())
1322 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1323 let solver =
1324 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1325 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1326
1327 let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1329 let mut d_dev = stream
1330 .clone_htod(&d_host)
1331 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1332 let mut b_dev = stream
1333 .clone_htod(&b_host)
1334 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1335 let mut g_dev = stream
1336 .clone_htod(&g_host)
1337 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1338
1339 let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1347 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1348 let pivot = info_host[idx];
1349 let scale = sys.rows[idx]
1350 .htt
1351 .diag()
1352 .iter()
1353 .map(|v| v.abs())
1354 .fold(0.0_f64, f64::max)
1355 .max(1.0);
1356 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1357 row: idx,
1358 bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
1359 });
1360 }
1361
1362 trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1365 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1368
1369 let schur_init: Vec<f64> = {
1388 let mut tmp = Vec::with_capacity(k * k);
1389 for col in 0..k {
1390 for row in 0..k {
1391 let mut v = sys.hbb[[row, col]];
1392 if row == col {
1393 v += ridge_beta;
1394 }
1395 tmp.push(v);
1396 }
1397 }
1398 tmp
1399 };
1400 let mut schur_dev = stream
1401 .clone_htod(&schur_init)
1402 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1403 let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1404 let mut rhs_dev = stream
1405 .clone_htod(&rhs_init)
1406 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1407
1408 accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1409
1410 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1412 if info != 0 {
1413 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1414 reason: format!("Schur Cholesky failed at pivot {info}"),
1415 });
1416 }
1417 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1419 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1420 let delta_beta_host = stream
1421 .clone_dtoh(&rhs_dev)
1422 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1423 let delta_beta = Array1::from_vec(delta_beta_host.clone());
1424
1425 accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1433 trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1434
1435 let x_host = stream
1436 .clone_dtoh(&g_dev)
1437 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1438 let mut delta_t = Array1::<f64>::zeros(n * d);
1439 for (i, v) in x_host.iter().enumerate() {
1440 delta_t[i] = -*v;
1441 }
1442
1443 let l_local_host = stream
1445 .clone_dtoh(&d_dev)
1446 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1447 let l_schur_host = stream
1448 .clone_dtoh(&schur_dev)
1449 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1450 let mut log_det = 0.0_f64;
1451 for i in 0..n {
1452 let base = i * d * d;
1453 for j in 0..d {
1454 log_det += l_local_host[base + j * d + j].ln();
1455 }
1456 }
1457 for j in 0..k {
1458 log_det += l_schur_host[j * k + j].ln();
1459 }
1460 log_det *= 2.0;
1461
1462 Ok(ArrowSchurGpuSolution {
1463 delta_t,
1464 delta_beta,
1465 log_det_hessian: log_det,
1466 })
1467 }
1468
    /// Batched in-place Cholesky (`cusolverDnDpotrfBatched`, lower triangle).
    ///
    /// Factors `batch` `p × p` matrices stored back to back in `matrices`; the
    /// callers here pack them column-major via `pack_block`. Returns the
    /// per-matrix `info` array copied back to the host, where `0` means success.
    ///
    /// # Errors
    /// `Unavailable` if `p` or `batch` does not fit in `i32`, a device
    /// allocation or transfer fails, or cuSOLVER returns a non-success status.
    fn potrf_batched(
        solver: &DnHandle,
        stream: &Arc<CudaStream>,
        p: usize,
        batch: usize,
        matrices: &mut CudaSlice<f64>,
    ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
        let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let matrix_len = p * p;
        let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
        // `_record` is held until the function returns, covering the launch below.
        let (base_ptr, _record) = matrices.device_ptr_mut(stream);
        // Device-pointer array: one entry per matrix, `bytes_per` bytes apart.
        let mut ptrs = Vec::with_capacity(batch);
        for idx in 0..batch {
            ptrs.push(base_ptr + (idx as u64) * bytes_per);
        }
        let mut ptrs_dev = stream
            .clone_htod(&ptrs)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut info_dev = stream
            .alloc_zeros::<i32>(batch)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let status = {
            let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
            // SAFETY: `ptrs_dev` holds `batch` device pointers into `matrices`,
            // each `p * p * size_of::<f64>()` bytes apart. They stay in bounds
            // provided `matrices.len() >= batch * p * p`, which is assumed and
            // not checked here (TODO confirm at call sites). `info_dev` has
            // `batch` entries. All three buffers and their access records
            // outlive this call, and `p_i` / `batch_i` were range-checked
            // into i32.
            unsafe {
                cusolver_sys::cusolverDnDpotrfBatched(
                    solver.cu(),
                    cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
                    p_i,
                    ptrs_ptr as *mut *mut f64,
                    p_i,
                    info_ptr as *mut i32,
                    batch_i,
                )
            }
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        stream
            .clone_dtoh(&info_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)
    }
1515
1516 fn potrf_single(
1517 solver: &DnHandle,
1518 stream: &Arc<CudaStream>,
1519 p: usize,
1520 matrix: &mut CudaSlice<f64>,
1521 ) -> Result<i32, ArrowSchurGpuFailure> {
1522 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1523 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1524 let mut lwork = 0_i32;
1525 {
1526 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1527 let status = unsafe {
1529 cusolver_sys::cusolverDnDpotrf_bufferSize(
1530 solver.cu(),
1531 uplo,
1532 p_i,
1533 mat_ptr as *mut f64,
1534 p_i,
1535 &mut lwork,
1536 )
1537 };
1538 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1539 return Err(ArrowSchurGpuFailure::Unavailable);
1540 }
1541 }
1542 let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1543 let mut workspace = stream
1544 .alloc_zeros::<f64>(lwork_usize.max(1))
1545 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1546 let mut info_dev = stream
1547 .alloc_zeros::<i32>(1)
1548 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1549 {
1550 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1551 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1552 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1553 let status = unsafe {
1555 cusolver_sys::cusolverDnDpotrf(
1556 solver.cu(),
1557 uplo,
1558 p_i,
1559 mat_ptr as *mut f64,
1560 p_i,
1561 work_ptr as *mut f64,
1562 lwork,
1563 info_ptr as *mut i32,
1564 )
1565 };
1566 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1567 return Err(ArrowSchurGpuFailure::Unavailable);
1568 }
1569 }
1570 let info_host = stream
1571 .clone_dtoh(&info_dev)
1572 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1573 Ok(info_host[0])
1574 }
1575
1576 fn trsm_batched_lower_inplace(
1580 blas: &CudaBlas,
1581 stream: &Arc<CudaStream>,
1582 d: usize,
1583 n: usize,
1584 nrhs: usize,
1585 l_stack: &CudaSlice<f64>,
1586 rhs_stack: &mut CudaSlice<f64>,
1587 ) -> Result<(), ArrowSchurGpuFailure> {
1588 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1589 }
1590
1591 fn trsm_batched_lower_inplace_transposed(
1593 blas: &CudaBlas,
1594 stream: &Arc<CudaStream>,
1595 d: usize,
1596 n: usize,
1597 nrhs: usize,
1598 l_stack: &CudaSlice<f64>,
1599 rhs_stack: &mut CudaSlice<f64>,
1600 ) -> Result<(), ArrowSchurGpuFailure> {
1601 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1602 }
1603
1604 fn trsm_batched_inplace_inner(
1605 blas: &CudaBlas,
1606 stream: &Arc<CudaStream>,
1607 d: usize,
1608 n: usize,
1609 nrhs: usize,
1610 l_stack: &CudaSlice<f64>,
1611 rhs_stack: &mut CudaSlice<f64>,
1612 transposed: bool,
1613 ) -> Result<(), ArrowSchurGpuFailure> {
1614 let alpha = 1.0_f64;
1615 let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1616 let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1617 let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1618 let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1619 let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1620 let (l_base, _l_record) = l_stack.device_ptr(stream);
1621 let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1622 let mut l_ptrs = Vec::with_capacity(n);
1623 let mut rhs_ptrs = Vec::with_capacity(n);
1624 for i in 0..n {
1625 l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1626 rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1627 }
1628 let mut l_ptrs_dev = stream
1629 .clone_htod(&l_ptrs)
1630 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1631 let mut rhs_ptrs_dev = stream
1632 .clone_htod(&rhs_ptrs)
1633 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1634 let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1635 let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1636 let op = if transposed {
1637 cublasOperation_t::CUBLAS_OP_T
1638 } else {
1639 cublasOperation_t::CUBLAS_OP_N
1640 };
1641 let handle = *blas.handle();
1642 let status = unsafe {
1645 cudarc::cublas::sys::cublasDtrsmBatched(
1646 handle,
1647 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1648 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1649 op,
1650 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1651 d_i,
1652 nrhs_i,
1653 &alpha,
1654 l_ptrs_ptr as *const *const f64,
1655 d_i,
1656 rhs_ptrs_ptr as *const *mut f64,
1657 d_i,
1658 batch_i,
1659 )
1660 };
1661 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1662 return Err(ArrowSchurGpuFailure::Unavailable);
1663 }
1664 Ok(())
1665 }
1666
1667 fn trsm_single(
1670 blas: &CudaBlas,
1671 stream: &Arc<CudaStream>,
1672 n: usize,
1673 l: &CudaSlice<f64>,
1674 rhs: &mut CudaSlice<f64>,
1675 upper: bool,
1676 transposed: bool,
1677 ) -> Result<(), ArrowSchurGpuFailure> {
1678 let alpha = 1.0_f64;
1679 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1680 let handle = *blas.handle();
1681 let (l_ptr, _l_rec) = l.device_ptr(stream);
1682 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1683 let status = unsafe {
1685 cudarc::cublas::sys::cublasDtrsm_v2(
1686 handle,
1687 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1688 if upper {
1689 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1690 } else {
1691 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1692 },
1693 if transposed {
1694 cublasOperation_t::CUBLAS_OP_T
1695 } else {
1696 cublasOperation_t::CUBLAS_OP_N
1697 },
1698 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1699 n_i,
1700 1,
1701 &alpha,
1702 l_ptr as *const f64,
1703 n_i,
1704 rhs_ptr as *mut f64,
1705 n_i,
1706 )
1707 };
1708 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1709 return Err(ArrowSchurGpuFailure::Unavailable);
1710 }
1711 Ok(())
1712 }
1713
1714 fn accumulate_schur(
1718 blas: &CudaBlas,
1719 d: usize,
1720 k: usize,
1721 n: usize,
1722 y_stack: &CudaSlice<f64>,
1723 u_stack: &CudaSlice<f64>,
1724 schur: &mut CudaSlice<f64>,
1725 rhs: &mut CudaSlice<f64>,
1726 ) -> Result<(), ArrowSchurGpuFailure> {
1727 let y_block_elems = d * k;
1728 let u_block_elems = d;
1729 for i in 0..n {
1730 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1731 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1732 let gemm_cfg = GemmConfig::<f64> {
1734 transa: cublasOperation_t::CUBLAS_OP_T,
1735 transb: cublasOperation_t::CUBLAS_OP_N,
1736 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1737 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1738 k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1739 alpha: -1.0,
1740 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1741 ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1742 beta: 1.0,
1743 ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1744 };
1745 unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1747 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1748 let gemv_cfg = GemvConfig::<f64> {
1750 trans: cublasOperation_t::CUBLAS_OP_T,
1751 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1752 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1753 alpha: 1.0,
1754 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1755 incx: 1,
1756 beta: 1.0,
1757 incy: 1,
1758 };
1759 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1762 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1763 }
1764 Ok(())
1765 }
1766
1767 fn accumulate_schur_rhs_only(
1775 blas: &CudaBlas,
1776 d: usize,
1777 k: usize,
1778 n: usize,
1779 y_stack: &CudaSlice<f64>,
1780 u_stack: &CudaSlice<f64>,
1781 rhs: &mut CudaSlice<f64>,
1782 ) -> Result<(), ArrowSchurGpuFailure> {
1783 let y_block_elems = d * k;
1784 let u_block_elems = d;
1785 for i in 0..n {
1786 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1787 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1788 let gemv_cfg = GemvConfig::<f64> {
1789 trans: cublasOperation_t::CUBLAS_OP_T,
1790 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1791 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1792 alpha: 1.0,
1793 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1794 incx: 1,
1795 beta: 1.0,
1796 incy: 1,
1797 };
1798 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1801 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1802 }
1803 Ok(())
1804 }
1805
1806 fn accumulate_back_sub_rhs(
1809 blas: &CudaBlas,
1810 d: usize,
1811 k: usize,
1812 n: usize,
1813 y_stack: &CudaSlice<f64>,
1814 delta_beta: &CudaSlice<f64>,
1815 u_stack: &mut CudaSlice<f64>,
1816 ) -> Result<(), ArrowSchurGpuFailure> {
1817 let y_block_elems = d * k;
1818 let u_block_elems = d;
1819 for i in 0..n {
1820 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1821 let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1822 let gemv_cfg = GemvConfig::<f64> {
1823 trans: cublasOperation_t::CUBLAS_OP_N,
1824 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1825 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1826 alpha: 1.0,
1827 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1828 incx: 1,
1829 beta: 1.0,
1830 incy: 1,
1831 };
1832 unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1835 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1836 }
1837 Ok(())
1838 }
1839
1840 use std::collections::HashMap;
1856 use std::sync::Mutex;
1857
1858 struct FusedModuleCache {
1863 modules: Mutex<
1864 HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
1865 >,
1866 }
1867
1868 fn fused_module_cache() -> &'static FusedModuleCache {
1869 static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1870 CACHE.get_or_init(|| FusedModuleCache {
1871 modules: Mutex::new(HashMap::new()),
1872 })
1873 }
1874
1875 fn fused_module_for(
1876 ctx: &Arc<CudaContext>,
1877 key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
1878 ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
1879 let cache = fused_module_cache();
1880 if let Ok(guard) = cache.modules.lock() {
1881 if let Some(existing) = guard.get(&key) {
1882 return Ok(existing.clone());
1883 }
1884 }
1885 let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
1886 key.p_max as usize,
1887 key.r_template as usize,
1888 );
1889 let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
1890 ArrowSchurGpuFailure::SchurFactorFailed {
1891 reason: format!(
1892 "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
1893 key.p_max, key.r_template
1894 ),
1895 }
1896 })?;
1897 let module = ctx
1898 .load_module(ptx)
1899 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1900 if let Ok(mut guard) = cache.modules.lock() {
1901 guard.entry(key).or_insert_with(|| module.clone());
1902 }
1903 Ok(module)
1904 }
1905
1906 const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
1907extern "C" __global__ void arrow_pcg_jacobi_mul(
1908 const double* __restrict__ inv_diag,
1909 const double* __restrict__ r,
1910 double* __restrict__ z,
1911 int n
1912) {
1913 int idx = blockIdx.x * blockDim.x + threadIdx.x;
1914 if (idx < n) {
1915 z[idx] = inv_diag[idx] * r[idx];
1916 }
1917}
1918
1919extern "C" __global__ void arrow_pcg_update_p(
1920 const double* __restrict__ z,
1921 double beta,
1922 double* __restrict__ p,
1923 int n
1924) {
1925 int idx = blockIdx.x * blockDim.x + threadIdx.x;
1926 if (idx < n) {
1927 p[idx] = z[idx] + beta * p[idx];
1928 }
1929}
1930
1931extern "C" __global__ void arrow_sae_init(
1932 double* __restrict__ out,
1933 const double* __restrict__ x,
1934 double ridge,
1935 int n
1936) {
1937 int idx = blockIdx.x * blockDim.x + threadIdx.x;
1938 if (idx < n) {
1939 out[idx] = ridge * x[idx];
1940 }
1941}
1942
1943extern "C" __global__ void arrow_sae_smooth_matvec(
1944 const double* __restrict__ x,
1945 double* __restrict__ out,
1946 const int* __restrict__ block_offsets,
1947 const int* __restrict__ block_m,
1948 const int* __restrict__ factor_ptr,
1949 const double* __restrict__ factors,
1950 int p,
1951 int n_blocks
1952) {
1953 int block_id = blockIdx.y;
1954 int linear = blockIdx.x * blockDim.x + threadIdx.x;
1955 if (block_id >= n_blocks) {
1956 return;
1957 }
1958 int m = block_m[block_id];
1959 int total = m * p;
1960 if (linear >= total) {
1961 return;
1962 }
1963 int li = linear / p;
1964 int oc = linear - li * p;
1965 int off = block_offsets[block_id];
1966 int fbase = factor_ptr[block_id];
1967 double acc = 0.0;
1968 for (int lj = 0; lj < m; ++lj) {
1969 double a = factors[fbase + li * m + lj];
1970 acc += a * x[off + lj * p + oc];
1971 }
1972 out[off + li * p + oc] += acc;
1973}
1974
1975extern "C" __global__ void arrow_sae_sparse_g_matvec(
1976 const double* __restrict__ x,
1977 double* __restrict__ out,
1978 const int* __restrict__ row_off,
1979 const int* __restrict__ col_off,
1980 const int* __restrict__ rows,
1981 const int* __restrict__ cols,
1982 const int* __restrict__ data_ptr,
1983 const double* __restrict__ data,
1984 int p,
1985 int n_blocks
1986) {
1987 int block_id = blockIdx.y;
1988 int linear = blockIdx.x * blockDim.x + threadIdx.x;
1989 if (block_id >= n_blocks) {
1990 return;
1991 }
1992 int m_i = rows[block_id];
1993 int m_j = cols[block_id];
1994 int total = m_i * p;
1995 if (linear >= total) {
1996 return;
1997 }
1998 int li = linear / p;
1999 int oc = linear - li * p;
2000 int rbase = row_off[block_id];
2001 int cbase = col_off[block_id];
2002 int dbase = data_ptr[block_id];
2003 double acc = 0.0;
2004 for (int lj = 0; lj < m_j; ++lj) {
2005 acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
2006 }
2007 // #1017 — a row atom co-occurs with multiple column atoms, so several
2008 // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
2009 // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
2010 // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
2011 // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
2012 atomicAdd(&out[(rbase + li) * p + oc], acc);
2013}
2014
2015extern "C" __global__ void arrow_sae_gather_u(
2016 const double* __restrict__ x,
2017 const int* __restrict__ row_ptr,
2018 const int* __restrict__ beta_base,
2019 const double* __restrict__ phi,
2020 double* __restrict__ u,
2021 int p,
2022 int n_rows
2023) {
2024 int row = blockIdx.y;
2025 int oc = blockIdx.x * blockDim.x + threadIdx.x;
2026 if (row >= n_rows || oc >= p) {
2027 return;
2028 }
2029 double acc = 0.0;
2030 int start = row_ptr[row];
2031 int end = row_ptr[row + 1];
2032 for (int e = start; e < end; ++e) {
2033 acc += phi[e] * x[beta_base[e] + oc];
2034 }
2035 u[row * p + oc] = acc;
2036}
2037
2038extern "C" __global__ void arrow_sae_apply_l(
2039 const double* __restrict__ u,
2040 const int* __restrict__ jac_ptr,
2041 const double* __restrict__ jac,
2042 double* __restrict__ w,
2043 int p,
2044 int max_q,
2045 int n_rows
2046) {
2047 int row = blockIdx.y;
2048 int c = blockIdx.x * blockDim.x + threadIdx.x;
2049 if (row >= n_rows) {
2050 return;
2051 }
2052 int jstart = jac_ptr[row];
2053 int q = (jac_ptr[row + 1] - jstart) / p;
2054 if (c >= q) {
2055 return;
2056 }
2057 double acc = 0.0;
2058 for (int oc = 0; oc < p; ++oc) {
2059 acc += jac[jstart + c * p + oc] * u[row * p + oc];
2060 }
2061 w[row * max_q + c] = acc;
2062}
2063
2064extern "C" __global__ void arrow_sae_apply_ainv(
2065 const double* __restrict__ ainv,
2066 const double* __restrict__ w,
2067 double* __restrict__ v,
2068 int max_q,
2069 int n_rows
2070) {
2071 int row = blockIdx.y;
2072 int c = blockIdx.x * blockDim.x + threadIdx.x;
2073 if (row >= n_rows || c >= max_q) {
2074 return;
2075 }
2076 double acc = 0.0;
2077 int base = row * max_q * max_q;
2078 for (int j = 0; j < max_q; ++j) {
2079 acc += ainv[base + c * max_q + j] * w[row * max_q + j];
2080 }
2081 v[row * max_q + c] = acc;
2082}
2083
2084extern "C" __global__ void arrow_sae_scatter_sub(
2085 const double* __restrict__ v,
2086 const int* __restrict__ jac_ptr,
2087 const double* __restrict__ jac,
2088 const int* __restrict__ row_ptr,
2089 const int* __restrict__ beta_base,
2090 const double* __restrict__ phi,
2091 double* __restrict__ out,
2092 int p,
2093 int max_q,
2094 int n_rows
2095) {
2096 int row = blockIdx.y;
2097 int oc = blockIdx.x * blockDim.x + threadIdx.x;
2098 if (row >= n_rows || oc >= p) {
2099 return;
2100 }
2101 int jstart = jac_ptr[row];
2102 int q = (jac_ptr[row + 1] - jstart) / p;
2103 double lt_v = 0.0;
2104 for (int c = 0; c < q; ++c) {
2105 lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
2106 }
2107 int start = row_ptr[row];
2108 int end = row_ptr[row + 1];
2109 for (int e = start; e < end; ++e) {
2110 atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
2111 }
2112}
2113
2114extern "C" __global__ void arrow_sae_diag_sub(
2115 double* __restrict__ diag,
2116 const double* __restrict__ ainv,
2117 const int* __restrict__ jac_ptr,
2118 const double* __restrict__ jac,
2119 const int* __restrict__ row_ptr,
2120 const int* __restrict__ beta_base,
2121 const double* __restrict__ phi,
2122 int p,
2123 int max_q,
2124 int n_rows
2125) {
2126 int row = blockIdx.y;
2127 int oc = blockIdx.x * blockDim.x + threadIdx.x;
2128 if (row >= n_rows || oc >= p) {
2129 return;
2130 }
2131 int jstart = jac_ptr[row];
2132 int q = (jac_ptr[row + 1] - jstart) / p;
2133 int abase = row * max_q * max_q;
2134 double quad = 0.0;
2135 for (int c = 0; c < q; ++c) {
2136 double lc = jac[jstart + c * p + oc];
2137 for (int d = 0; d < q; ++d) {
2138 quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
2139 }
2140 }
2141 int start = row_ptr[row];
2142 int end = row_ptr[row + 1];
2143 for (int e = start; e < end; ++e) {
2144 double pe = phi[e];
2145 atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
2146 }
2147}
2148
2149/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
2150 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
2151 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
2152 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
2153 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */
2154
2155extern "C" __global__ void arrow_sae_frame_smooth_matvec(
2156 const double* __restrict__ x,
2157 double* __restrict__ out,
2158 const int* __restrict__ block_offsets,
2159 const int* __restrict__ block_m,
2160 const int* __restrict__ block_r,
2161 const int* __restrict__ factor_ptr,
2162 const double* __restrict__ factors,
2163 int n_blocks
2164) {
2165 int block_id = blockIdx.y;
2166 int linear = blockIdx.x * blockDim.x + threadIdx.x;
2167 if (block_id >= n_blocks) {
2168 return;
2169 }
2170 int m = block_m[block_id];
2171 int r = block_r[block_id];
2172 int total = m * r;
2173 if (linear >= total) {
2174 return;
2175 }
2176 int li = linear / r;
2177 int ib = linear - li * r;
2178 int off = block_offsets[block_id];
2179 int fbase = factor_ptr[block_id];
2180 double acc = 0.0;
2181 for (int lj = 0; lj < m; ++lj) {
2182 double a = factors[fbase + li * m + lj];
2183 acc += a * x[off + lj * r + ib];
2184 }
2185 out[off + li * r + ib] += acc;
2186}
2187
2188extern "C" __global__ void arrow_sae_frame_g_matvec(
2189 const double* __restrict__ x,
2190 double* __restrict__ out,
2191 const int* __restrict__ off_i,
2192 const int* __restrict__ off_j,
2193 const int* __restrict__ r_i,
2194 const int* __restrict__ r_j,
2195 const int* __restrict__ m_i,
2196 const int* __restrict__ m_j,
2197 const int* __restrict__ g_ptr,
2198 const double* __restrict__ g_data,
2199 const int* __restrict__ w_ptr,
2200 const double* __restrict__ w_data,
2201 int n_blocks
2202) {
2203 int block_id = blockIdx.y;
2204 int linear = blockIdx.x * blockDim.x + threadIdx.x;
2205 if (block_id >= n_blocks) {
2206 return;
2207 }
2208 int ri = r_i[block_id];
2209 int rj = r_j[block_id];
2210 int mi = m_i[block_id];
2211 int mj = m_j[block_id];
2212 int total = mi * ri;
2213 if (linear >= total) {
2214 return;
2215 }
2216 int li = linear / ri; // basis row in atom i
2217 int a = linear - li * ri; // frame coord in atom i
2218 int oi = off_i[block_id];
2219 int oj = off_j[block_id];
2220 int gbase = g_ptr[block_id];
2221 int wbase = w_ptr[block_id];
2222 double acc = 0.0;
2223 for (int lj = 0; lj < mj; ++lj) {
2224 double g = g_data[gbase + li * mj + lj];
2225 if (g == 0.0) { continue; }
2226 int xj_base = oj + lj * rj;
2227 double inner = 0.0;
2228 for (int b = 0; b < rj; ++b) {
2229 inner += w_data[wbase + a * rj + b] * x[xj_base + b];
2230 }
2231 acc += g * inner;
2232 }
2233 // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
2234 // multiple co-occurring (i,j) frame blocks running concurrently on
2235 // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
2236 // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
2237 atomicAdd(&out[oi + li * ri + a], acc);
2238}
2239
2240/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
2241 * h_i = H_tβ^(i) · x (length q_i)
2242 * s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i (apply cached ainv, length q_i)
2243 * out -= (H_tβ^(i))ᵀ · s_i (scatter into border_dim)
2244 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
2245 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
2246extern "C" __global__ void arrow_sae_frame_apply_h(
2247 const double* __restrict__ x,
2248 const int* __restrict__ htb_ptr,
2249 const double* __restrict__ htb,
2250 const int* __restrict__ q_of,
2251 double* __restrict__ hvec,
2252 int k,
2253 int max_q,
2254 int n_rows
2255) {
2256 int row = blockIdx.y;
2257 int c = blockIdx.x * blockDim.x + threadIdx.x;
2258 if (row >= n_rows) { return; }
2259 int q = q_of[row];
2260 if (c >= q) { return; }
2261 int base = htb_ptr[row] + c * k;
2262 double acc = 0.0;
2263 for (int a = 0; a < k; ++a) {
2264 acc += htb[base + a] * x[a];
2265 }
2266 hvec[row * max_q + c] = acc;
2267}
2268
2269extern "C" __global__ void arrow_sae_frame_apply_ainv(
2270 const double* __restrict__ ainv,
2271 const double* __restrict__ hvec,
2272 const int* __restrict__ q_of,
2273 double* __restrict__ svec,
2274 int max_q,
2275 int n_rows
2276) {
2277 int row = blockIdx.y;
2278 int c = blockIdx.x * blockDim.x + threadIdx.x;
2279 if (row >= n_rows || c >= max_q) { return; }
2280 int q = q_of[row];
2281 double acc = 0.0;
2282 int abase = row * max_q * max_q;
2283 for (int j = 0; j < q; ++j) {
2284 acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
2285 }
2286 svec[row * max_q + c] = acc;
2287}
2288
2289extern "C" __global__ void arrow_sae_frame_scatter_h(
2290 const double* __restrict__ svec,
2291 const int* __restrict__ htb_ptr,
2292 const double* __restrict__ htb,
2293 const int* __restrict__ q_of,
2294 double* __restrict__ out,
2295 int k,
2296 int max_q,
2297 int n_rows
2298) {
2299 int row = blockIdx.y;
2300 int a = blockIdx.x * blockDim.x + threadIdx.x;
2301 if (row >= n_rows || a >= k) { return; }
2302 int q = q_of[row];
2303 int hbase = htb_ptr[row];
2304 double acc = 0.0;
2305 for (int c = 0; c < q; ++c) {
2306 acc += htb[hbase + c * k + a] * svec[row * max_q + c];
2307 }
2308 atomicAdd(&out[a], -acc);
2309}
2310
2311/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
2312extern "C" __global__ void arrow_sae_frame_diag_sub(
2313 double* __restrict__ diag,
2314 const double* __restrict__ ainv,
2315 const int* __restrict__ htb_ptr,
2316 const double* __restrict__ htb,
2317 const int* __restrict__ q_of,
2318 int k,
2319 int max_q,
2320 int n_rows
2321) {
2322 int row = blockIdx.y;
2323 int a = blockIdx.x * blockDim.x + threadIdx.x;
2324 if (row >= n_rows || a >= k) { return; }
2325 int q = q_of[row];
2326 int hbase = htb_ptr[row];
2327 int abase = row * max_q * max_q;
2328 double quad = 0.0;
2329 for (int c = 0; c < q; ++c) {
2330 double hc = htb[hbase + c * k + a];
2331 for (int d = 0; d < q; ++d) {
2332 quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
2333 }
2334 }
2335 atomicAdd(&diag[a], -quad);
2336}
2337"#;
2338
2339 fn pcg_vector_module(
2340 ctx: &Arc<CudaContext>,
2341 ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2342 static CACHE: gam_gpu::device_cache::PtxModuleCache =
2343 gam_gpu::device_cache::PtxModuleCache::new();
2344 CACHE
2345 .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2346 .map_err(|err| {
2347 log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2353 ArrowSchurGpuFailure::Unavailable
2354 })
2355 }
2356
2357 fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2358 let threads = 256u32;
2359 let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2360 Ok(LaunchConfig {
2361 grid_dim: (blocks, 1, 1),
2362 block_dim: (threads, 1, 1),
2363 shared_mem_bytes: 0,
2364 })
2365 }
2366
2367 fn launch_jacobi_mul(
2368 stream: &Arc<CudaStream>,
2369 module: &Arc<CudaModule>,
2370 inv_diag: &CudaSlice<f64>,
2371 r: &CudaSlice<f64>,
2372 z: &mut CudaSlice<f64>,
2373 n: usize,
2374 ) -> Result<(), ArrowSchurGpuFailure> {
2375 let kernel = module
2376 .load_function("arrow_pcg_jacobi_mul")
2377 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2378 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2379 let mut builder = stream.launch_builder(&kernel);
2380 builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2381 unsafe { builder.launch(pcg_launch_config(n)?) }
2384 .map(drop)
2385 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2386 }
2387
2388 fn launch_update_p(
2389 stream: &Arc<CudaStream>,
2390 module: &Arc<CudaModule>,
2391 z: &CudaSlice<f64>,
2392 beta: f64,
2393 p: &mut CudaSlice<f64>,
2394 n: usize,
2395 ) -> Result<(), ArrowSchurGpuFailure> {
2396 let kernel = module
2397 .load_function("arrow_pcg_update_p")
2398 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2399 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2400 let mut builder = stream.launch_builder(&kernel);
2401 builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2402 unsafe { builder.launch(pcg_launch_config(n)?) }
2405 .map(drop)
2406 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2407 }
2408
2409 struct DeviceSaePcgBuffers {
2410 row_ptr: CudaSlice<i32>,
2411 beta_base: CudaSlice<i32>,
2412 phi: CudaSlice<f64>,
2413 jac_ptr: CudaSlice<i32>,
2414 jac: CudaSlice<f64>,
2415 smooth_offsets: CudaSlice<i32>,
2416 smooth_m: CudaSlice<i32>,
2417 smooth_ptr: CudaSlice<i32>,
2418 smooth_data: CudaSlice<f64>,
2419 g_row_off: CudaSlice<i32>,
2420 g_col_off: CudaSlice<i32>,
2421 g_rows: CudaSlice<i32>,
2422 g_cols: CudaSlice<i32>,
2423 g_ptr: CudaSlice<i32>,
2424 g_data: CudaSlice<f64>,
2425 ainv: CudaSlice<f64>,
2426 u: CudaSlice<f64>,
2427 w: CudaSlice<f64>,
2428 v: CudaSlice<f64>,
2429 n_rows: usize,
2430 p: usize,
2431 k: usize,
2432 max_q: usize,
2433 smooth_blocks: usize,
2434 g_blocks: usize,
2435 }
2436
2437 fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2438 to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2439 }
2440
2441 fn sae_penalty_diag_host(
2442 data: &DeviceSaePcgData,
2443 ridge_beta: f64,
2444 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2445 let mut diag = vec![ridge_beta; data.beta_dim];
2446 for block in &data.smooth_blocks {
2447 let (rows, cols) = block.factor_a.dim();
2448 if rows != cols {
2449 return Err(ArrowSchurGpuFailure::Unavailable);
2450 }
2451 for row in 0..rows {
2452 let coeff = block.factor_a[[row, row]];
2453 let base = block
2454 .global_offset
2455 .checked_add(
2456 row.checked_mul(data.p)
2457 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2458 )
2459 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2460 let end = base
2461 .checked_add(data.p)
2462 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2463 if end > diag.len() {
2464 return Err(ArrowSchurGpuFailure::Unavailable);
2465 }
2466 for channel in 0..data.p {
2467 diag[base + channel] += coeff;
2468 }
2469 }
2470 }
2471 for block in &data.sparse_g_blocks {
2472 if block.row_off != block.col_off {
2473 continue;
2474 }
2475 let (rows, cols) = block.data.dim();
2476 for row in 0..rows.min(cols) {
2477 let coeff = block.data[[row, row]];
2478 let beta_row = block
2479 .row_off
2480 .checked_add(row)
2481 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2482 let base = beta_row
2483 .checked_mul(data.p)
2484 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2485 let end = base
2486 .checked_add(data.p)
2487 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2488 if end > diag.len() {
2489 return Err(ArrowSchurGpuFailure::Unavailable);
2490 }
2491 for channel in 0..data.p {
2492 diag[base + channel] += coeff;
2493 }
2494 }
2495 }
2496 Ok(diag)
2497 }
2498
2499 fn flatten_device_sae_data(
2500 sys: &ArrowSchurSystem,
2501 data: &DeviceSaePcgData,
2502 ridge_t: f64,
2503 stream: &Arc<CudaStream>,
2504 ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2505 let n_rows = sys.rows.len();
2506 let p = data.p;
2507 let k = data.beta_dim;
2508 if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2509 return Err(ArrowSchurGpuFailure::Unavailable);
2510 }
2511
2512 let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2513 let mut beta_base_host = Vec::<i32>::new();
2514 let mut phi_host = Vec::<f64>::new();
2515 row_ptr_host.push(0_i32);
2516 for row in data.a_phi.iter() {
2517 for &(base, phi) in row {
2518 beta_base_host.push(checked_i32(base)?);
2519 phi_host.push(phi);
2520 }
2521 row_ptr_host.push(checked_i32(beta_base_host.len())?);
2522 }
2523
2524 let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2525 let mut jac_host = Vec::<f64>::new();
2526 let mut max_q = 0usize;
2527 jac_ptr_host.push(0_i32);
2528 for row_jac in data.local_jac.iter() {
2529 if row_jac.len() % p != 0 {
2530 return Err(ArrowSchurGpuFailure::Unavailable);
2531 }
2532 max_q = max_q.max(row_jac.len() / p);
2533 jac_host.extend_from_slice(row_jac);
2534 jac_ptr_host.push(checked_i32(jac_host.len())?);
2535 }
2536 if max_q == 0 {
2537 return Err(ArrowSchurGpuFailure::Unavailable);
2538 }
2539
2540 let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2541 let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2542 let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2543 let mut smooth_data_host = Vec::<f64>::new();
2544 smooth_ptr_host.push(0_i32);
2545 for block in &data.smooth_blocks {
2546 let (rows, cols) = block.factor_a.dim();
2547 if rows != cols {
2548 return Err(ArrowSchurGpuFailure::Unavailable);
2549 }
2550 smooth_offsets_host.push(checked_i32(block.global_offset)?);
2551 smooth_m_host.push(checked_i32(rows)?);
2552 for r in 0..rows {
2553 for c in 0..cols {
2554 smooth_data_host.push(block.factor_a[[r, c]]);
2555 }
2556 }
2557 smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2558 }
2559
2560 let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2561 let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2562 let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2563 let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2564 let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2565 let mut g_data_host = Vec::<f64>::new();
2566 g_ptr_host.push(0_i32);
2567 for block in &data.sparse_g_blocks {
2568 let (rows, cols) = block.data.dim();
2569 g_row_off_host.push(checked_i32(block.row_off)?);
2570 g_col_off_host.push(checked_i32(block.col_off)?);
2571 g_rows_host.push(checked_i32(rows)?);
2572 g_cols_host.push(checked_i32(cols)?);
2573 for r in 0..rows {
2574 for c in 0..cols {
2575 g_data_host.push(block.data[[r, c]]);
2576 }
2577 }
2578 g_ptr_host.push(checked_i32(g_data_host.len())?);
2579 }
2580
2581 let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2582 for (row_idx, row) in sys.rows.iter().enumerate() {
2583 let q = data.local_jac[row_idx].len() / p;
2584 if row.htt.dim() != (q, q) {
2585 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2586 reason: format!(
2587 "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2588 row.htt.dim()
2589 ),
2590 });
2591 }
2592 let mut block = row.htt.clone();
2593 for d in 0..q {
2594 block[[d, d]] += ridge_t;
2595 }
2596 let factor = gam_linalg::triangular::cholesky_factor_in_place(
2597 block.view(),
2598 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2599 )
2600 .ok_or_else(|| {
2601 let scale = row
2602 .htt
2603 .diag()
2604 .iter()
2605 .map(|v| v.abs())
2606 .fold(0.0_f64, f64::max)
2607 .max(1.0);
2608 ArrowSchurGpuFailure::RidgeBumpRequired {
2609 row: row_idx,
2610 bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
2611 }
2612 })?;
2613 for col in 0..q {
2614 let mut e = Array1::<f64>::zeros(q);
2615 e[col] = 1.0;
2616 let solved =
2617 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2618 for r in 0..q {
2619 ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2620 }
2621 }
2622 }
2623
2624 Ok(DeviceSaePcgBuffers {
2625 row_ptr: stream
2626 .clone_htod(&row_ptr_host)
2627 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2628 beta_base: stream
2629 .clone_htod(&beta_base_host)
2630 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2631 phi: stream
2632 .clone_htod(&phi_host)
2633 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2634 jac_ptr: stream
2635 .clone_htod(&jac_ptr_host)
2636 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2637 jac: stream
2638 .clone_htod(&jac_host)
2639 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2640 smooth_offsets: stream
2641 .clone_htod(&smooth_offsets_host)
2642 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2643 smooth_m: stream
2644 .clone_htod(&smooth_m_host)
2645 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2646 smooth_ptr: stream
2647 .clone_htod(&smooth_ptr_host)
2648 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2649 smooth_data: stream
2650 .clone_htod(&smooth_data_host)
2651 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2652 g_row_off: stream
2653 .clone_htod(&g_row_off_host)
2654 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2655 g_col_off: stream
2656 .clone_htod(&g_col_off_host)
2657 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2658 g_rows: stream
2659 .clone_htod(&g_rows_host)
2660 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2661 g_cols: stream
2662 .clone_htod(&g_cols_host)
2663 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2664 g_ptr: stream
2665 .clone_htod(&g_ptr_host)
2666 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2667 g_data: stream
2668 .clone_htod(&g_data_host)
2669 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2670 ainv: stream
2671 .clone_htod(&ainv_host)
2672 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2673 u: stream
2674 .alloc_zeros::<f64>(n_rows * p)
2675 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2676 w: stream
2677 .alloc_zeros::<f64>(n_rows * max_q)
2678 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2679 v: stream
2680 .alloc_zeros::<f64>(n_rows * max_q)
2681 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2682 n_rows,
2683 p,
2684 k,
2685 max_q,
2686 smooth_blocks: data.smooth_blocks.len(),
2687 g_blocks: data.sparse_g_blocks.len(),
2688 })
2689 }
2690
2691 fn launch_sae_init(
2692 stream: &Arc<CudaStream>,
2693 module: &Arc<CudaModule>,
2694 out: &mut CudaSlice<f64>,
2695 x: &CudaSlice<f64>,
2696 ridge: f64,
2697 n: usize,
2698 ) -> Result<(), ArrowSchurGpuFailure> {
2699 let kernel = module
2700 .load_function("arrow_sae_init")
2701 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2702 let n_i32 = checked_i32(n)?;
2703 let mut builder = stream.launch_builder(&kernel);
2704 builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2705 unsafe { builder.launch(pcg_launch_config(n)?) }
2709 .map(drop)
2710 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2711 }
2712
2713 fn launch_sae_penalty_matvec(
2714 stream: &Arc<CudaStream>,
2715 module: &Arc<CudaModule>,
2716 buffers: &mut DeviceSaePcgBuffers,
2717 x: &CudaSlice<f64>,
2718 out: &mut CudaSlice<f64>,
2719 ridge_beta: f64,
2720 ) -> Result<(), ArrowSchurGpuFailure> {
2721 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2722 if buffers.smooth_blocks > 0 {
2723 let kernel = module
2724 .load_function("arrow_sae_smooth_matvec")
2725 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2726 let max_m = buffers.k;
2727 let p_i32 = checked_i32(buffers.p)?;
2728 let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2729 let cfg = LaunchConfig {
2730 grid_dim: (
2731 ((max_m as u32).saturating_add(255) / 256).max(1),
2732 checked_i32(buffers.smooth_blocks)? as u32,
2733 1,
2734 ),
2735 block_dim: (256, 1, 1),
2736 shared_mem_bytes: 0,
2737 };
2738 let mut builder = stream.launch_builder(&kernel);
2739 builder
2740 .arg(x)
2741 .arg(&mut *out)
2742 .arg(&buffers.smooth_offsets)
2743 .arg(&buffers.smooth_m)
2744 .arg(&buffers.smooth_ptr)
2745 .arg(&buffers.smooth_data)
2746 .arg(&p_i32)
2747 .arg(&blocks_i32);
2748 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2753 }
2754 if buffers.g_blocks > 0 {
2755 let kernel = module
2756 .load_function("arrow_sae_sparse_g_matvec")
2757 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2758 let max_work = buffers
2759 .k
2760 .checked_div(buffers.p)
2761 .unwrap_or(0)
2762 .saturating_mul(buffers.p);
2763 let p_i32 = checked_i32(buffers.p)?;
2764 let blocks_i32 = checked_i32(buffers.g_blocks)?;
2765 let cfg = LaunchConfig {
2766 grid_dim: (
2767 ((max_work as u32).saturating_add(255) / 256).max(1),
2768 checked_i32(buffers.g_blocks)? as u32,
2769 1,
2770 ),
2771 block_dim: (256, 1, 1),
2772 shared_mem_bytes: 0,
2773 };
2774 let mut builder = stream.launch_builder(&kernel);
2775 builder
2776 .arg(x)
2777 .arg(&mut *out)
2778 .arg(&buffers.g_row_off)
2779 .arg(&buffers.g_col_off)
2780 .arg(&buffers.g_rows)
2781 .arg(&buffers.g_cols)
2782 .arg(&buffers.g_ptr)
2783 .arg(&buffers.g_data)
2784 .arg(&p_i32)
2785 .arg(&blocks_i32);
2786 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2791 }
2792 Ok(())
2793 }
2794
2795 fn launch_sae_row_schur_sub(
2796 stream: &Arc<CudaStream>,
2797 module: &Arc<CudaModule>,
2798 buffers: &mut DeviceSaePcgBuffers,
2799 x: &CudaSlice<f64>,
2800 out: &mut CudaSlice<f64>,
2801 ) -> Result<(), ArrowSchurGpuFailure> {
2802 let p_i32 = checked_i32(buffers.p)?;
2803 let max_q_i32 = checked_i32(buffers.max_q)?;
2804 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2805 let cfg_p_rows = LaunchConfig {
2806 grid_dim: (
2807 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2808 checked_i32(buffers.n_rows)? as u32,
2809 1,
2810 ),
2811 block_dim: (256, 1, 1),
2812 shared_mem_bytes: 0,
2813 };
2814 let gather = module
2815 .load_function("arrow_sae_gather_u")
2816 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2817 {
2818 let mut builder = stream.launch_builder(&gather);
2819 builder
2820 .arg(x)
2821 .arg(&buffers.row_ptr)
2822 .arg(&buffers.beta_base)
2823 .arg(&buffers.phi)
2824 .arg(&mut buffers.u)
2825 .arg(&p_i32)
2826 .arg(&n_rows_i32);
2827 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2831 }
2832
2833 let cfg_q_rows = LaunchConfig {
2834 grid_dim: (
2835 ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2836 checked_i32(buffers.n_rows)? as u32,
2837 1,
2838 ),
2839 block_dim: (256, 1, 1),
2840 shared_mem_bytes: 0,
2841 };
2842 let apply_l = module
2843 .load_function("arrow_sae_apply_l")
2844 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2845 {
2846 let mut builder = stream.launch_builder(&apply_l);
2847 builder
2848 .arg(&buffers.u)
2849 .arg(&buffers.jac_ptr)
2850 .arg(&buffers.jac)
2851 .arg(&mut buffers.w)
2852 .arg(&p_i32)
2853 .arg(&max_q_i32)
2854 .arg(&n_rows_i32);
2855 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2859 }
2860
2861 let apply_ainv = module
2862 .load_function("arrow_sae_apply_ainv")
2863 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2864 {
2865 let mut builder = stream.launch_builder(&apply_ainv);
2866 builder
2867 .arg(&buffers.ainv)
2868 .arg(&buffers.w)
2869 .arg(&mut buffers.v)
2870 .arg(&max_q_i32)
2871 .arg(&n_rows_i32);
2872 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2876 }
2877
2878 let scatter = module
2879 .load_function("arrow_sae_scatter_sub")
2880 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2881 {
2882 let mut builder = stream.launch_builder(&scatter);
2883 builder
2884 .arg(&buffers.v)
2885 .arg(&buffers.jac_ptr)
2886 .arg(&buffers.jac)
2887 .arg(&buffers.row_ptr)
2888 .arg(&buffers.beta_base)
2889 .arg(&buffers.phi)
2890 .arg(out)
2891 .arg(&p_i32)
2892 .arg(&max_q_i32)
2893 .arg(&n_rows_i32);
2894 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2898 }
2899 Ok(())
2900 }
2901
2902 fn launch_sae_diag_sub(
2903 stream: &Arc<CudaStream>,
2904 module: &Arc<CudaModule>,
2905 buffers: &DeviceSaePcgBuffers,
2906 diag: &mut CudaSlice<f64>,
2907 ) -> Result<(), ArrowSchurGpuFailure> {
2908 let kernel = module
2909 .load_function("arrow_sae_diag_sub")
2910 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2911 let p_i32 = checked_i32(buffers.p)?;
2912 let max_q_i32 = checked_i32(buffers.max_q)?;
2913 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2914 let cfg = LaunchConfig {
2915 grid_dim: (
2916 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2917 checked_i32(buffers.n_rows)? as u32,
2918 1,
2919 ),
2920 block_dim: (256, 1, 1),
2921 shared_mem_bytes: 0,
2922 };
2923 let mut builder = stream.launch_builder(&kernel);
2924 builder
2925 .arg(diag)
2926 .arg(&buffers.ainv)
2927 .arg(&buffers.jac_ptr)
2928 .arg(&buffers.jac)
2929 .arg(&buffers.row_ptr)
2930 .arg(&buffers.beta_base)
2931 .arg(&buffers.phi)
2932 .arg(&p_i32)
2933 .arg(&max_q_i32)
2934 .arg(&n_rows_i32);
2935 unsafe { builder.launch(cfg) }
2939 .map(drop)
2940 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2941 }
2942
2943 fn launch_sae_matvec(
2944 stream: &Arc<CudaStream>,
2945 module: &Arc<CudaModule>,
2946 buffers: &mut DeviceSaePcgBuffers,
2947 x: &CudaSlice<f64>,
2948 out: &mut CudaSlice<f64>,
2949 ridge_beta: f64,
2950 ) -> Result<(), ArrowSchurGpuFailure> {
2951 launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
2952 launch_sae_row_schur_sub(stream, module, buffers, x, out)
2953 }
2954
2955 fn pack_fused_host(
2960 sys: &ArrowSchurSystem,
2961 ridge_t: f64,
2962 p_max: usize,
2963 r_template: usize,
2964 ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
2965 let n = sys.rows.len();
2966 let d = sys.d;
2967 let k = sys.k;
2968 let mut d_buf = vec![0.0_f64; n * p_max * p_max];
2969 let mut b_buf = vec![0.0_f64; n * p_max * r_template];
2970 let mut g_buf = vec![0.0_f64; n * p_max];
2971 for (i, row) in sys.rows.iter().enumerate() {
2972 for col in 0..d {
2974 let base = (i * p_max + col) * p_max;
2975 for r in 0..d {
2976 let mut value = row.htt[[r, col]];
2977 if r == col {
2978 value += ridge_t;
2979 }
2980 d_buf[base + r] = value;
2981 }
2982 }
2983 for col in 0..k {
2991 let base = (i * r_template + col) * p_max;
2992 for r in 0..d {
2993 b_buf[base + r] = row.htbeta[[r, col]];
2994 }
2995 }
2996 let g_base = i * p_max;
2998 for r in 0..d {
2999 g_buf[g_base + r] = row.gt[r];
3000 }
3001 }
3002 (d_buf, b_buf, g_buf)
3003 }
3004
/// Device-resident factorization of the ridged arrow Hessian.
///
/// Built once by `ResidentArrowFrame::new` and reused by `solve_gradient` for any number of
/// gradient right-hand sides. Field order also determines drop order of the CUDA resources;
/// NOTE(review): check that before reordering fields.
pub(super) struct ResidentArrowFrame {
    /// Number of per-row local blocks (`sys.rows.len()`).
    n: usize,
    /// Dimension of each local block (`sys.d`).
    d: usize,
    /// Dimension of the shared beta block (`sys.k`).
    k: usize,
    /// Stream on which the buffers below were created and all solves are issued.
    stream: Arc<CudaStream>,
    /// cuBLAS handle created on `stream`; used for the triangular solves and Schur products.
    blas: CudaBlas,
    /// Batched Cholesky factors from `potrf_batched` of the `pack_host` local blocks
    /// (presumably `H_tt,i + ridge_t I`, `d x d` each, `n` of them).
    l_dev: CudaSlice<f64>,
    /// Per-row border blocks after `trsm_batched_lower_inplace` against `l_dev`, presumably
    /// `L_i^{-1} H_tβ,i`.
    y_dev: CudaSlice<f64>,
    /// `k x k` Cholesky factor (from `potrf_single`) of the Schur complement accumulated from
    /// `H_ββ + ridge_β I` by `accumulate_schur`.
    schur_dev: CudaSlice<f64>,
    /// `2 * (Σ log diag(local factors) + Σ log diag(Schur factor))`, computed once in `new`.
    log_det_hessian: f64,
}
3050
3051 impl ResidentArrowFrame {
3052 pub(super) fn new(
3056 sys: &ArrowSchurSystem,
3057 ridge_t: f64,
3058 ridge_beta: f64,
3059 ) -> Result<Self, ArrowSchurGpuFailure> {
3060 if ridge_t.is_nan() || ridge_beta.is_nan() {
3061 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3062 reason: "ridge is NaN".to_string(),
3063 });
3064 }
3065 let n = sys.rows.len();
3066 let d = sys.d;
3067 let k = sys.k;
3068 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3069 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3070 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3071 .and_then(|ctx| ctx.new_stream().ok())
3072 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3073 let solver =
3074 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3075 let blas =
3076 CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3077
3078 let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3080 let mut l_dev = stream
3081 .clone_htod(&d_host)
3082 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3083 let mut y_dev = stream
3084 .clone_htod(&b_host)
3085 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3086
3087 let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3089 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3090 let pivot = info_host[idx];
3091 let scale = sys.rows[idx]
3092 .htt
3093 .diag()
3094 .iter()
3095 .map(|v| v.abs())
3096 .fold(0.0_f64, f64::max)
3097 .max(1.0);
3098 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3099 row: idx,
3100 bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3101 });
3102 }
3103
3104 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3106
3107 let schur_init: Vec<f64> = {
3112 let mut tmp = Vec::with_capacity(k * k);
3113 for col in 0..k {
3114 for row in 0..k {
3115 let mut v = sys.hbb[[row, col]];
3116 if row == col {
3117 v += ridge_beta;
3118 }
3119 tmp.push(v);
3120 }
3121 }
3122 tmp
3123 };
3124 let mut schur_dev = stream
3125 .clone_htod(&schur_init)
3126 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3127 let zero_u = stream
3130 .clone_htod(&vec![0.0_f64; n * d])
3131 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3132 let mut throwaway_rhs = stream
3133 .clone_htod(&vec![0.0_f64; k])
3134 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3135 accumulate_schur(
3136 &blas,
3137 d,
3138 k,
3139 n,
3140 &y_dev,
3141 &zero_u,
3142 &mut schur_dev,
3143 &mut throwaway_rhs,
3144 )?;
3145 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3146 if info != 0 {
3147 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3148 reason: format!("Schur Cholesky failed at pivot {info}"),
3149 });
3150 }
3151
3152 let l_local_host = stream
3154 .clone_dtoh(&l_dev)
3155 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3156 let l_schur_host = stream
3157 .clone_dtoh(&schur_dev)
3158 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3159 let mut log_det = 0.0_f64;
3160 for i in 0..n {
3161 let base = i * d * d;
3162 for j in 0..d {
3163 log_det += l_local_host[base + j * d + j].ln();
3164 }
3165 }
3166 for j in 0..k {
3167 log_det += l_schur_host[j * k + j].ln();
3168 }
3169 log_det *= 2.0;
3170
3171 Ok(Self {
3172 n,
3173 d,
3174 k,
3175 stream,
3176 blas,
3177 l_dev,
3178 y_dev,
3179 schur_dev,
3180 log_det_hessian: log_det,
3181 })
3182 }
3183
3184 #[inline]
3185 pub(super) fn log_det_hessian(&self) -> f64 {
3186 self.log_det_hessian
3187 }
3188
3189 pub(super) fn solve_gradient(
3193 &self,
3194 g_t: &[f64],
3195 g_beta: &[f64],
3196 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3197 let n = self.n;
3198 let d = self.d;
3199 let k = self.k;
3200 if g_t.len() != n * d || g_beta.len() != k {
3201 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3202 reason: format!(
3203 "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3204 g_t.len(),
3205 n * d,
3206 g_beta.len(),
3207 k
3208 ),
3209 });
3210 }
3211 let mut u_dev = self
3213 .stream
3214 .clone_htod(&g_t.to_vec())
3215 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3216 trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3217
3218 let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3221 let mut rhs_dev = self
3222 .stream
3223 .clone_htod(&rhs_init)
3224 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3225 accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3226
3227 trsm_single(
3229 &self.blas,
3230 &self.stream,
3231 k,
3232 &self.schur_dev,
3233 &mut rhs_dev,
3234 false,
3235 false,
3236 )?;
3237 trsm_single(
3238 &self.blas,
3239 &self.stream,
3240 k,
3241 &self.schur_dev,
3242 &mut rhs_dev,
3243 false,
3244 true,
3245 )?;
3246 let delta_beta_host = self
3247 .stream
3248 .clone_dtoh(&rhs_dev)
3249 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3250 let delta_beta = Array1::from_vec(delta_beta_host);
3251
3252 accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3254 trsm_batched_lower_inplace_transposed(
3255 &self.blas,
3256 &self.stream,
3257 d,
3258 n,
3259 1,
3260 &self.l_dev,
3261 &mut u_dev,
3262 )?;
3263 let x_host = self
3264 .stream
3265 .clone_dtoh(&u_dev)
3266 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3267 let mut delta_t = Array1::<f64>::zeros(n * d);
3268 for (i, v) in x_host.iter().enumerate() {
3269 delta_t[i] = -*v;
3270 }
3271
3272 Ok(ArrowSchurGpuSolution {
3273 delta_t,
3274 delta_beta,
3275 log_det_hessian: self.log_det_hessian,
3276 })
3277 }
3278 }
3279
3280 pub(super) fn solve_fused(
3281 sys: &ArrowSchurSystem,
3282 ridge_t: f64,
3283 ridge_beta: f64,
3284 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3285 let n = sys.rows.len();
3286 let d = sys.d;
3287 let k = sys.k;
3288 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3289 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3290 let p_max = plan.p_max;
3291 let r_template = plan.r_template;
3292
3293 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3294 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3295 )
3296 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3297 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3298 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3299 let stream = ctx
3300 .new_stream()
3301 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3302 let cap = &runtime.device.capability;
3303 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3304 cc_major: cap.compute_major,
3305 cc_minor: cap.compute_minor,
3306 p_max: p_max as u32,
3307 r_template: r_template as u32,
3308 };
3309 let module = fused_module_for(&ctx, key)?;
3310 let forward = module
3311 .load_function("arrow_schur_forward_pgroup")
3312 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3313 let back_sub = module
3314 .load_function("arrow_schur_back_sub_pgroup")
3315 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3316
3317 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3319 let d_dev = stream
3320 .clone_htod(&d_host)
3321 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3322 let b_dev = stream
3323 .clone_htod(&b_host)
3324 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3325 let g_dev = stream
3326 .clone_htod(&g_host)
3327 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3328 let mut l_out = stream
3329 .alloc_zeros::<f64>(n * p_max * p_max)
3330 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3331 let mut u_out = stream
3332 .alloc_zeros::<f64>(n * p_max)
3333 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3334 let mut y_out = stream
3335 .alloc_zeros::<f64>(n * p_max * r_template)
3336 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3337 let mut partial_s = stream
3338 .alloc_zeros::<f64>(plan.partial_s_doubles)
3339 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3340 let mut partial_r = stream
3341 .alloc_zeros::<f64>(plan.partial_r_doubles)
3342 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3343 let mut status_dev = stream
3344 .alloc_zeros::<i32>(n)
3345 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3346
3347 let cfg = LaunchConfig {
3349 grid_dim: (plan.blocks, 1, 1),
3350 block_dim: (plan.threads_per_block, 1, 1),
3351 shared_mem_bytes: 0,
3352 };
3353 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3354 let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3355 let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3356 let ridge_arg = ridge_t;
3357 {
3358 let mut builder = stream.launch_builder(&forward);
3359 builder
3360 .arg(&d_dev)
3361 .arg(&b_dev)
3362 .arg(&g_dev)
3363 .arg(&n_i32)
3364 .arg(&p_i32)
3365 .arg(&r_i32)
3366 .arg(&ridge_arg)
3367 .arg(&mut l_out)
3368 .arg(&mut u_out)
3369 .arg(&mut y_out)
3370 .arg(&mut partial_s)
3371 .arg(&mut partial_r)
3372 .arg(&mut status_dev);
3373 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3377 }
3378 stream
3379 .synchronize()
3380 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3381
3382 let status_host = stream
3384 .clone_dtoh(&status_dev)
3385 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3386 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3387 let pivot = status_host[row];
3388 let scale = sys.rows[row]
3389 .htt
3390 .diag()
3391 .iter()
3392 .map(|v| v.abs())
3393 .fold(0.0_f64, f64::max)
3394 .max(1.0);
3395 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3396 row,
3397 bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3398 });
3399 }
3400
3401 let partial_s_host = stream
3403 .clone_dtoh(&partial_s)
3404 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3405 let partial_r_host = stream
3406 .clone_dtoh(&partial_r)
3407 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3408 let mut schur_host = vec![0.0_f64; k * k];
3409 for col in 0..k {
3410 for row in 0..k {
3411 let mut v = sys.hbb[[row, col]];
3412 if row == col {
3413 v += ridge_beta;
3414 }
3415 schur_host[col * k + row] = v;
3416 }
3417 }
3418 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3419 for i in 0..n {
3420 let s_base = i * r_template * r_template;
3423 for col in 0..k {
3424 let col_base = s_base + col * r_template;
3425 let dst_col_base = col * k;
3426 for row in 0..k {
3427 schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3428 }
3429 }
3430 let r_base = i * r_template;
3431 for a in 0..k {
3432 rhs_host[a] += partial_r_host[r_base + a];
3433 }
3434 }
3435
3436 let mut schur_dev = stream
3438 .clone_htod(&schur_host)
3439 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3440 let mut rhs_dev = stream
3441 .clone_htod(&rhs_host)
3442 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3443 let solver =
3444 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3446 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3447 if info != 0 {
3448 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3449 reason: format!("fused Schur Cholesky failed at pivot {info}"),
3450 });
3451 }
3452 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3453 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3454 let delta_beta_host = stream
3455 .clone_dtoh(&rhs_dev)
3456 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457 let delta_beta = Array1::from_vec(delta_beta_host.clone());
3458
3459 let mut delta_t_dev = stream
3461 .alloc_zeros::<f64>(n * p_max)
3462 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463 let back_cfg = LaunchConfig {
3464 grid_dim: (plan.blocks, 1, 1),
3465 block_dim: (plan.threads_per_block, 1, 1),
3466 shared_mem_bytes: 0,
3467 };
3468 {
3469 let mut builder = stream.launch_builder(&back_sub);
3470 builder
3471 .arg(&l_out)
3472 .arg(&u_out)
3473 .arg(&y_out)
3474 .arg(&rhs_dev)
3475 .arg(&n_i32)
3476 .arg(&p_i32)
3477 .arg(&r_i32)
3478 .arg(&mut delta_t_dev);
3479 unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3483 }
3484 stream
3485 .synchronize()
3486 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487
3488 let delta_t_host = stream
3489 .clone_dtoh(&delta_t_dev)
3490 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3491 let mut delta_t = Array1::<f64>::zeros(n * d);
3492 for i in 0..n {
3493 let src_base = i * p_max;
3494 let dst_base = i * d;
3495 for r in 0..d {
3496 delta_t[dst_base + r] = delta_t_host[src_base + r];
3497 }
3498 }
3499
3500 let l_local_host = stream
3502 .clone_dtoh(&l_out)
3503 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3504 let l_schur_host = stream
3505 .clone_dtoh(&schur_dev)
3506 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3507 let mut log_det = 0.0_f64;
3508 for i in 0..n {
3509 let base = i * p_max * p_max;
3510 for j in 0..d {
3511 log_det += l_local_host[base + j * p_max + j].ln();
3512 }
3513 }
3514 for j in 0..k {
3515 log_det += l_schur_host[j * k + j].ln();
3516 }
3517 log_det *= 2.0;
3518
3519 Ok(ArrowSchurGpuSolution {
3520 delta_t,
3521 delta_beta,
3522 log_det_hessian: log_det,
3523 })
3524 }
3525
3526 pub(super) fn build_schur_matvec_backend(
3536 sys: &ArrowSchurSystem,
3537 ridge_t: f64,
3538 ridge_beta: f64,
3539 ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
3540 let n = sys.rows.len();
3541 let d = sys.d;
3542 let k = sys.k;
3543 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3544 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3545 let p_max = plan.p_max;
3546 let r_template = plan.r_template;
3547
3548 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3549 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3550 )
3551 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3552 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3553 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3554 let stream = ctx
3555 .new_stream()
3556 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3557 let cap = &runtime.device.capability;
3558 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3559 cc_major: cap.compute_major,
3560 cc_minor: cap.compute_minor,
3561 p_max: p_max as u32,
3562 r_template: r_template as u32,
3563 };
3564 let module = fused_module_for(&ctx, key)?;
3565 let forward = module
3566 .load_function("arrow_schur_forward_pgroup")
3567 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3568
3569 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3570 let d_dev = stream
3571 .clone_htod(&d_host)
3572 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3573 let b_dev = stream
3574 .clone_htod(&b_host)
3575 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3576 let g_dev = stream
3577 .clone_htod(&g_host)
3578 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3579 let mut l_out = stream
3580 .alloc_zeros::<f64>(n * p_max * p_max)
3581 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3582 let mut u_out = stream
3583 .alloc_zeros::<f64>(n * p_max)
3584 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3585 let mut y_out = stream
3586 .alloc_zeros::<f64>(n * p_max * r_template)
3587 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3588 let mut partial_s = stream
3589 .alloc_zeros::<f64>(plan.partial_s_doubles)
3590 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3591 let mut partial_r = stream
3592 .alloc_zeros::<f64>(plan.partial_r_doubles)
3593 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3594 let mut status_dev = stream
3595 .alloc_zeros::<i32>(n)
3596 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3597
3598 let cfg = LaunchConfig {
3599 grid_dim: (plan.blocks, 1, 1),
3600 block_dim: (plan.threads_per_block, 1, 1),
3601 shared_mem_bytes: 0,
3602 };
3603 let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3604 let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3605 let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3606 let ridge_arg = ridge_t;
3607 {
3608 let mut builder = stream.launch_builder(&forward);
3609 builder
3610 .arg(&d_dev)
3611 .arg(&b_dev)
3612 .arg(&g_dev)
3613 .arg(&n_i32)
3614 .arg(&p_i32)
3615 .arg(&r_i32)
3616 .arg(&ridge_arg)
3617 .arg(&mut l_out)
3618 .arg(&mut u_out)
3619 .arg(&mut y_out)
3620 .arg(&mut partial_s)
3621 .arg(&mut partial_r)
3622 .arg(&mut status_dev);
3623 unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3626 }
3627 stream
3628 .synchronize()
3629 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3630
3631 let status_host = stream
3632 .clone_dtoh(&status_dev)
3633 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3634 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3635 let pivot = status_host[row];
3636 let scale = sys.rows[row]
3637 .htt
3638 .diag()
3639 .iter()
3640 .map(|v| v.abs())
3641 .fold(0.0_f64, f64::max)
3642 .max(1.0);
3643 return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
3644 row,
3645 bump: scale * (pivot.abs() as f64).max(1.0) * f64::EPSILON.sqrt() * 1024.0,
3646 });
3647 }
3648
3649 let y_host = stream
3651 .clone_dtoh(&y_out)
3652 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3653
3654 let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
3657 let hbb_is_kk = sys.hbb.dim() == (k, k);
3658 let hbb_matvec_opt = sys.hbb_matvec.clone();
3659
3660 let closure: crate::arrow_schur::GpuSchurMatvec =
3661 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
3662 assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
3663 assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
3664
3665 if let Some(ref mv) = hbb_matvec_opt {
3667 mv(x.view(), out);
3668 for a in 0..k {
3669 out[a] += ridge_beta * x[a];
3670 }
3671 } else if hbb_is_kk {
3672 for a in 0..k {
3674 let mut acc = ridge_beta * x[a];
3675 for b in 0..k {
3676 acc += hbb_host[a * k + b] * x[b];
3677 }
3678 out[a] = acc;
3679 }
3680 } else {
3681 for a in 0..k {
3682 out[a] = ridge_beta * x[a];
3683 }
3684 }
3685
3686 let mut z = vec![0.0_f64; d];
3689 for i in 0..n {
3690 let y_base = i * p_max * r_template;
3691 for r in 0..d {
3692 let mut acc = 0.0;
3693 for c in 0..k {
3694 acc += y_host[y_base + c * p_max + r] * x[c];
3695 }
3696 z[r] = acc;
3697 }
3698 for c in 0..k {
3699 let mut acc = 0.0;
3700 for r in 0..d {
3701 acc += y_host[y_base + c * p_max + r] * z[r];
3702 }
3703 out[c] -= acc;
3704 }
3705 }
3706 });
3707
3708 Ok(closure)
3709 }
3710
/// Device buffers for the SAE frame operator, built by `flatten_device_sae_frame_data`.
///
/// Only the smoothing-block fields are populated in the visible part of that builder. The
/// descriptions of the remaining fields are inferred from their names and should be confirmed
/// (NOTE(review)).
struct DeviceSaeFrameBuffers {
    /// Per smoothing block: `global_offset` of the block.
    s_off: CudaSlice<i32>,
    /// Per smoothing block: side length `m` of its square `factor_a`.
    s_m: CudaSlice<i32>,
    /// Per smoothing block: its rank from `frame.smooth_ranks`.
    s_r: CudaSlice<i32>,
    /// Prefix offsets into `s_data` (starts at 0, one entry per block after that).
    s_ptr: CudaSlice<i32>,
    /// Concatenated row-major `factor_a` matrices of all smoothing blocks.
    s_data: CudaSlice<f64>,
    /// Number of smoothing blocks (presumably).
    s_blocks: usize,
    /// Per G block: row offset (presumably; name-based).
    g_off_i: CudaSlice<i32>,
    /// Per G block: column offset (presumably; name-based).
    g_off_j: CudaSlice<i32>,
    /// Per G block: rank along i / j (presumably; name-based).
    g_ri: CudaSlice<i32>,
    g_rj: CudaSlice<i32>,
    /// Per G block: size along i / j (presumably; name-based).
    g_mi: CudaSlice<i32>,
    g_mj: CudaSlice<i32>,
    /// Prefix offsets into `g_data`, starting at 0.
    g_ptr: CudaSlice<i32>,
    /// Concatenated G block data.
    g_data: CudaSlice<f64>,
    /// Prefix offsets into `w_data`, starting at 0.
    w_ptr: CudaSlice<i32>,
    /// Auxiliary per-G-block data; meaning not visible here — TODO confirm.
    w_data: CudaSlice<f64>,
    /// Number of G blocks (presumably).
    g_blocks: usize,
    /// Largest per-block work size, presumably used to size kernel grids — TODO confirm.
    g_max_work: usize,
    /// Prefix offsets into `htb` (presumably one segment per row, from `frame.row_htbeta`).
    htb_ptr: CudaSlice<i32>,
    /// Concatenated per-row `H_tβ` data (presumably).
    htb: CudaSlice<f64>,
    /// Per-row local dimension `q` (presumably).
    q_of: CudaSlice<i32>,
    /// Per-row inverse local blocks (presumably padded to `max_q x max_q`) — TODO confirm.
    ainv: CudaSlice<f64>,
    /// Scratch vectors; roles not visible here — TODO confirm.
    hvec: CudaSlice<f64>,
    svec: CudaSlice<f64>,
    /// Number of rows (`sys.rows.len()` in the builder).
    n_rows: usize,
    /// Beta dimension (`data.beta_dim` in the builder).
    k: usize,
    /// Largest per-row local dimension.
    max_q: usize,
}
3745
3746 fn flatten_device_sae_frame_data(
3747 sys: &ArrowSchurSystem,
3748 data: &DeviceSaePcgData,
3749 frame: &DeviceSaeFrameData,
3750 ridge_t: f64,
3751 stream: &Arc<CudaStream>,
3752 ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3753 let n_rows = sys.rows.len();
3754 let k = data.beta_dim;
3755 if frame.row_htbeta.len() != n_rows
3756 || frame.ranks.len() != frame.basis_sizes.len()
3757 || frame.border_offsets.len() != frame.ranks.len()
3758 || data.smooth_blocks.len() != frame.smooth_ranks.len()
3759 {
3760 return Err(ArrowSchurGpuFailure::Unavailable);
3761 }
3762
3763 let mut s_off = Vec::new();
3765 let mut s_m = Vec::new();
3766 let mut s_r = Vec::new();
3767 let mut s_ptr = vec![0_i32];
3768 let mut s_data = Vec::<f64>::new();
3769 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3770 let (m, mc) = blk.factor_a.dim();
3771 if m != mc {
3772 return Err(ArrowSchurGpuFailure::Unavailable);
3773 }
3774 s_off.push(checked_i32(blk.global_offset)?);
3775 s_m.push(checked_i32(m)?);
3776 s_r.push(checked_i32(r)?);
3777 for ri in 0..m {
3778 for ci in 0..m {
3779 s_data.push(blk.factor_a[[ri, ci]]);
3780 }
3781 }
3782 s_ptr.push(checked_i32(s_data.len())?);
3783 }
3784
3785 let mut g_off_i = Vec::new();
3787 let mut g_off_j = Vec::new();
3788 let mut g_ri = Vec::new();
3789 let mut g_rj = Vec::new();
3790 let mut g_mi = Vec::new();
3791 let mut g_mj = Vec::new();
3792 let mut g_ptr = vec![0_i32];
3793 let mut g_data = Vec::<f64>::new();
3794 let mut w_ptr = vec![0_i32];
3795 let mut w_data = Vec::<f64>::new();
3796 let mut g_max_work = 0usize;
3797 for blk in &frame.frame_blocks {
3798 let ri = frame.ranks[blk.atom_i];
3799 let rj = frame.ranks[blk.atom_j];
3800 let (mi, mj) = blk.g.dim();
3801 if blk.w.dim() != (ri, rj) {
3802 return Err(ArrowSchurGpuFailure::Unavailable);
3803 }
3804 g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3805 g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3806 g_ri.push(checked_i32(ri)?);
3807 g_rj.push(checked_i32(rj)?);
3808 g_mi.push(checked_i32(mi)?);
3809 g_mj.push(checked_i32(mj)?);
3810 for r in 0..mi {
3811 for c in 0..mj {
3812 g_data.push(blk.g[[r, c]]);
3813 }
3814 }
3815 g_ptr.push(checked_i32(g_data.len())?);
3816 for a in 0..ri {
3817 for b in 0..rj {
3818 w_data.push(blk.w[[a, b]]);
3819 }
3820 }
3821 w_ptr.push(checked_i32(w_data.len())?);
3822 g_max_work = g_max_work.max(mi * ri);
3823 }
3824
3825 let mut htb_ptr = vec![0_i32];
3827 let mut htb = Vec::<f64>::new();
3828 let mut q_of = Vec::<i32>::with_capacity(n_rows);
3829 let mut max_q = 0usize;
3830 for (i, slab) in frame.row_htbeta.iter().enumerate() {
3831 let qi = sys.row_dims[i];
3832 let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3835 qi
3836 } else {
3837 0
3838 };
3839 q_of.push(checked_i32(q_eff)?);
3840 max_q = max_q.max(q_eff);
3841 if q_eff > 0 {
3842 htb.extend_from_slice(slab);
3843 }
3844 htb_ptr.push(checked_i32(htb.len())?);
3845 }
3846 if max_q == 0 {
3847 max_q = 1;
3850 }
3851
3852 let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3853 for (i, row) in sys.rows.iter().enumerate() {
3854 let q = q_of[i] as usize;
3855 if q == 0 {
3856 continue;
3857 }
3858 if row.htt.dim() != (q, q) {
3859 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3860 reason: format!(
3861 "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3862 row.htt.dim()
3863 ),
3864 });
3865 }
3866 let mut block = row.htt.clone();
3867 for d in 0..q {
3868 block[[d, d]] += ridge_t;
3869 }
3870 let factor = gam_linalg::triangular::cholesky_factor_in_place(
3871 block.view(),
3872 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3873 )
3874 .ok_or_else(|| {
3875 let scale = row
3876 .htt
3877 .diag()
3878 .iter()
3879 .map(|v| v.abs())
3880 .fold(0.0_f64, f64::max)
3881 .max(1.0);
3882 ArrowSchurGpuFailure::RidgeBumpRequired {
3883 row: i,
3884 bump: scale * f64::EPSILON.sqrt() * super::RIDGE_BUMP_EPS_MARGIN,
3885 }
3886 })?;
3887 for col in 0..q {
3888 let mut e = Array1::<f64>::zeros(q);
3889 e[col] = 1.0;
3890 let solved =
3891 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3892 for r in 0..q {
3893 ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3894 }
3895 }
3896 }
3897
3898 let htod_i = |v: &[i32]| {
3899 stream
3900 .clone_htod(v)
3901 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3902 };
3903 let htod_f = |v: &[f64]| {
3904 stream
3905 .clone_htod(v)
3906 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3907 };
3908 Ok(DeviceSaeFrameBuffers {
3909 s_off: htod_i(&s_off)?,
3910 s_m: htod_i(&s_m)?,
3911 s_r: htod_i(&s_r)?,
3912 s_ptr: htod_i(&s_ptr)?,
3913 s_data: htod_f(&s_data)?,
3914 s_blocks: data.smooth_blocks.len(),
3915 g_off_i: htod_i(&g_off_i)?,
3916 g_off_j: htod_i(&g_off_j)?,
3917 g_ri: htod_i(&g_ri)?,
3918 g_rj: htod_i(&g_rj)?,
3919 g_mi: htod_i(&g_mi)?,
3920 g_mj: htod_i(&g_mj)?,
3921 g_ptr: htod_i(&g_ptr)?,
3922 g_data: htod_f(&g_data)?,
3923 w_ptr: htod_i(&w_ptr)?,
3924 w_data: htod_f(&w_data)?,
3925 g_blocks: frame.frame_blocks.len(),
3926 g_max_work,
3927 htb_ptr: htod_i(&htb_ptr)?,
3928 htb: htod_f(&htb)?,
3929 q_of: htod_i(&q_of)?,
3930 ainv: htod_f(&ainv)?,
3931 hvec: stream
3932 .alloc_zeros::<f64>(n_rows * max_q)
3933 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
3934 svec: stream
3935 .alloc_zeros::<f64>(n_rows * max_q)
3936 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
3937 n_rows,
3938 k,
3939 max_q,
3940 })
3941 }
3942
3943 fn sae_frame_penalty_diag_host(
3944 data: &DeviceSaePcgData,
3945 frame: &DeviceSaeFrameData,
3946 ridge_beta: f64,
3947 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
3948 let mut diag = vec![ridge_beta; data.beta_dim];
3949 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3951 let m = blk.factor_a.nrows();
3952 for ia in 0..m {
3953 let coeff = blk.factor_a[[ia, ia]];
3954 let base = blk.global_offset + ia * r;
3955 for ib in 0..r {
3956 if base + ib >= diag.len() {
3957 return Err(ArrowSchurGpuFailure::Unavailable);
3958 }
3959 diag[base + ib] += coeff;
3960 }
3961 }
3962 }
3963 for blk in &frame.frame_blocks {
3965 if blk.atom_i != blk.atom_j {
3966 continue;
3967 }
3968 let r = frame.ranks[blk.atom_i];
3969 let off = frame.border_offsets[blk.atom_i];
3970 let (mi, mj) = blk.g.dim();
3971 for li in 0..mi.min(mj) {
3972 let gii = blk.g[[li, li]];
3973 let base = off + li * r;
3974 for a in 0..r {
3975 if base + a >= diag.len() {
3976 return Err(ArrowSchurGpuFailure::Unavailable);
3977 }
3978 diag[base + a] += gii * blk.w[[a, a]];
3979 }
3980 }
3981 }
3982 Ok(diag)
3983 }
3984
3985 fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
3986 Ok(LaunchConfig {
3987 grid_dim: (
3988 ((work as u32).saturating_add(255) / 256).max(1),
3989 checked_i32(n_rows)? as u32,
3990 1,
3991 ),
3992 block_dim: (256, 1, 1),
3993 shared_mem_bytes: 0,
3994 })
3995 }
3996
/// Applies the framed-SAE reduced operator to `x` on the device, writing into `out`.
///
/// All kernels are enqueued on `stream` in this order:
/// 1. `launch_sae_init` (defined elsewhere) initialises `out` from `x` and
///    `ridge_beta` (presumably `out = ridge_beta · x` — confirm).
/// 2. `arrow_sae_frame_smooth_matvec` for the smooth-penalty blocks (skipped if none).
/// 3. `arrow_sae_frame_g_matvec` for the factored frame blocks built from `g`/`w`
///    (skipped if none).
/// 4. `arrow_sae_frame_apply_h` writes per-row products of `x` with the `H_tβ`
///    slabs into the `hvec` scratch.
/// 5. `arrow_sae_frame_apply_ainv` applies the per-row `(H_tt + ridge_t·I)⁻¹`
///    blocks to `hvec`, writing `svec`.
/// 6. `arrow_sae_frame_scatter_h` maps `svec` back through `H_tβ` into `out`
///    (presumably subtracting, forming the Schur complement — confirm in the kernel).
///
/// `buffers.hvec` and `buffers.svec` are overwritten as scratch. Each kernel is
/// looked up by name on every call.
///
/// # Errors
/// `Unavailable` if a kernel cannot be loaded or launched; conversion/launch-shape
/// errors come from `checked_i32` and `frame_grid`.
fn launch_sae_frame_matvec(
    stream: &Arc<CudaStream>,
    module: &Arc<CudaModule>,
    buffers: &mut DeviceSaeFrameBuffers,
    x: &CudaSlice<f64>,
    out: &mut CudaSlice<f64>,
    ridge_beta: f64,
) -> Result<(), ArrowSchurGpuFailure> {
    launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
    if buffers.s_blocks > 0 {
        // Smooth-penalty contribution: grid x spans k, grid y spans the smooth blocks.
        let kernel = module
            .load_function("arrow_sae_frame_smooth_matvec")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blocks_i32 = checked_i32(buffers.s_blocks)?;
        let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
        let mut b = stream.launch_builder(&kernel);
        // Argument order must match the kernel's parameter list exactly.
        b.arg(x)
            .arg(&mut *out)
            .arg(&buffers.s_off)
            .arg(&buffers.s_m)
            .arg(&buffers.s_r)
            .arg(&buffers.s_ptr)
            .arg(&buffers.s_data)
            .arg(&blocks_i32);
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    if buffers.g_blocks > 0 {
        // Frame-block contribution: grid x spans the largest per-block mi*ri output.
        let kernel = module
            .load_function("arrow_sae_frame_g_matvec")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blocks_i32 = checked_i32(buffers.g_blocks)?;
        let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(x)
            .arg(&mut *out)
            .arg(&buffers.g_off_i)
            .arg(&buffers.g_off_j)
            .arg(&buffers.g_ri)
            .arg(&buffers.g_rj)
            .arg(&buffers.g_mi)
            .arg(&buffers.g_mj)
            .arg(&buffers.g_ptr)
            .arg(&buffers.g_data)
            .arg(&buffers.w_ptr)
            .arg(&buffers.w_data)
            .arg(&blocks_i32);
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    let k_i32 = checked_i32(buffers.k)?;
    let max_q_i32 = checked_i32(buffers.max_q)?;
    let n_rows_i32 = checked_i32(buffers.n_rows)?;
    {
        // Per-row H_tβ · x into `hvec` (grid x spans max_q, y spans rows).
        let kernel = module
            .load_function("arrow_sae_frame_apply_h")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(x)
            .arg(&buffers.htb_ptr)
            .arg(&buffers.htb)
            .arg(&buffers.q_of)
            .arg(&mut buffers.hvec)
            .arg(&k_i32)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    {
        // Per-row (H_tt + ridge_t·I)⁻¹ · hvec into `svec`.
        let kernel = module
            .load_function("arrow_sae_frame_apply_ainv")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(&buffers.ainv)
            .arg(&buffers.hvec)
            .arg(&buffers.q_of)
            .arg(&mut buffers.svec)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    {
        // Scatter `svec` back through H_tβ into `out` (grid x spans k, y spans rows).
        let kernel = module
            .load_function("arrow_sae_frame_scatter_h")
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let cfg = frame_grid(buffers.k, buffers.n_rows)?;
        let mut b = stream.launch_builder(&kernel);
        b.arg(&buffers.svec)
            .arg(&buffers.htb_ptr)
            .arg(&buffers.htb)
            .arg(&buffers.q_of)
            .arg(out)
            .arg(&k_i32)
            .arg(&max_q_i32)
            .arg(&n_rows_i32);
        unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    }
    Ok(())
}
4109
4110 fn launch_sae_frame_diag_sub(
4111 stream: &Arc<CudaStream>,
4112 module: &Arc<CudaModule>,
4113 buffers: &DeviceSaeFrameBuffers,
4114 diag: &mut CudaSlice<f64>,
4115 ) -> Result<(), ArrowSchurGpuFailure> {
4116 let kernel = module
4117 .load_function("arrow_sae_frame_diag_sub")
4118 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4119 let k_i32 = checked_i32(buffers.k)?;
4120 let max_q_i32 = checked_i32(buffers.max_q)?;
4121 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4122 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4123 let mut b = stream.launch_builder(&kernel);
4124 b.arg(diag)
4125 .arg(&buffers.ainv)
4126 .arg(&buffers.htb_ptr)
4127 .arg(&buffers.htb)
4128 .arg(&buffers.q_of)
4129 .arg(&k_i32)
4130 .arg(&max_q_i32)
4131 .arg(&n_rows_i32);
4132 unsafe { b.launch(cfg) }
4134 .map(drop)
4135 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4136 }
4137
/// Solves the framed-SAE reduced β system `S · x = rhs_beta` on the GPU with
/// Jacobi-preconditioned conjugate gradients, applying `S` only through
/// `launch_sae_frame_matvec` (it is never assembled).
///
/// Jacobi preconditioner construction:
/// 1. The diagonal starts from `sae_frame_penalty_diag_host` on the host.
/// 2. It is corrected on the device by `launch_sae_frame_diag_sub`.
/// 3. It is copied back and inverted on the host.
///
/// Convergence is judged on the unpreconditioned residual 2-norm against
/// `max(relative_tolerance · ‖rhs_beta‖, 1e-12)`. At least one iteration always
/// runs, even when `max_iterations == 0`. An exactly-zero `rhs_beta` returns a zero
/// solution with default diagnostics, but only after GPU setup and the buffer
/// upload have succeeded.
///
/// # Errors
/// - `Unavailable` when:
///   - `k == 0` or the dimensions disagree, or `data.frame` is `None`;
///   - the runtime offload policy declines;
///   - any CUDA/cuBLAS setup, transfer or launch fails;
///   - `rhs_beta` is not contiguous (`as_slice()` returns `None`).
/// - `SchurFactorFailed` for a Jacobi diagonal entry that is non-finite or
///   `<= 1e-18`, or for a non-positive/non-finite `rᵀM⁻¹r` or `pᵀAp`.
/// - Any error propagated from `flatten_device_sae_frame_data`
///   (e.g. `RidgeBumpRequired`).
pub(super) fn solve_sae_matrix_free_pcg_framed(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    let frame = data
        .frame
        .as_ref()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    // Only proceed when the runtime policy says the reduced-Schur matvec should be offloaded.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute residual target, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi diagonal: host penalty part, device row correction, then host inversion.
    let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // CG state: x0 = 0, r0 = rhs, z0 = M⁻¹ r0, p0 = z0.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };
    for _ in 0..max_iterations.max(1) {
        launch_sae_frame_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        // S must be SPD along every search direction; anything else aborts the solve.
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        // Fletcher–Reeves style update: p ← z + β p with β = rᵀz_new / rᵀz_old.
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4287
/// Solves the unframed SAE reduced β system `S · x = rhs_beta` on the GPU with
/// Jacobi-preconditioned conjugate gradients. `S` is applied only through
/// `launch_sae_matvec` (defined elsewhere in this file) and never assembled.
///
/// Framed data (`data.frame.is_some()`) is rejected here; the framed counterpart in
/// this file is `solve_sae_matrix_free_pcg_framed`.
///
/// Jacobi preconditioner construction:
/// 1. The diagonal starts from `sae_penalty_diag_host` on the host.
/// 2. It is corrected on the device by `launch_sae_diag_sub`.
/// 3. It is copied back and inverted on the host.
///
/// Convergence is judged on the unpreconditioned residual 2-norm against
/// `max(relative_tolerance · ‖rhs_beta‖, 1e-12)`. At least one iteration always
/// runs. An exactly-zero `rhs_beta` returns a zero solution with default
/// diagnostics, but only after GPU setup and the buffer upload have succeeded.
///
/// # Errors
/// - `Unavailable` when:
///   - `k == 0` or the dimensions disagree, or the data is framed;
///   - the offload policy declines;
///   - any CUDA/cuBLAS step fails;
///   - `rhs_beta` is not contiguous.
/// - `SchurFactorFailed` for a non-finite or `<= 1e-18` Jacobi diagonal entry, or a
///   non-positive/non-finite `rᵀM⁻¹r` or `pᵀAp`.
/// - Any error propagated from `flatten_device_sae_data`.
pub(super) fn solve_sae_matrix_free_pcg(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Framed systems use a different device layout; this path only handles unframed data.
    if data.frame.is_some() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Only proceed when the runtime policy says the reduced-Schur matvec should be offloaded.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute residual target, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi diagonal: host penalty part, device Schur correction, then host inversion.
    let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // CG state: x0 = 0, r0 = rhs, z0 = M⁻¹ r0, p0 = z0.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };

    for _ in 0..max_iterations.max(1) {
        launch_sae_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        // S must be SPD along every search direction; anything else aborts the solve.
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        // p ← z + β p with β = rᵀz_new / rᵀz_old.
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4459
4460 pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4461 s_acc: &ndarray::Array2<f64>,
4462 rhs_beta: &Array1<f64>,
4463 max_iterations: usize,
4464 relative_tolerance: f64,
4465 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4466 let k = rhs_beta.len();
4467 let cg_iters = max_iterations.max(1);
4479 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4480 gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4481 m: k,
4482 n: k,
4483 k: cg_iters,
4484 },
4485 )
4486 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4487 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4488 .and_then(|ctx| ctx.new_stream().ok())
4489 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4490 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4491 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4492 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4493 let vector_module = pcg_vector_module(&ctx)?;
4494
4495 let mut inv_diag = vec![0.0_f64; k];
4497 for j in 0..k {
4498 let djj = s_acc[[j, j]];
4499 if !(djj > 0.0) {
4500 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4501 reason: format!(
4502 "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4503 ),
4504 });
4505 }
4506 inv_diag[j] = 1.0 / djj;
4507 }
4508
4509 let mut s_host = vec![0.0_f64; k * k];
4511 for col in 0..k {
4512 for row in 0..k {
4513 s_host[col * k + row] = s_acc[[row, col]];
4514 }
4515 }
4516 let s_dev = stream
4517 .clone_htod(&s_host)
4518 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4519
4520 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4524 if rhs_norm == 0.0 {
4525 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4526 }
4527 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4528
4529 let mut x_dev = stream
4532 .alloc_zeros::<f64>(k)
4533 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4534 let mut r_dev = stream
4535 .clone_htod(
4536 rhs_beta
4537 .as_slice()
4538 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4539 )
4540 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4541 let inv_diag_dev = stream
4542 .clone_htod(&inv_diag)
4543 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544 let mut z_dev = stream
4545 .alloc_zeros::<f64>(k)
4546 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4547 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4548 let mut p_dev = stream
4549 .alloc_zeros::<f64>(k)
4550 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4551 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4552 let mut sp_dev = stream
4553 .alloc_zeros::<f64>(k)
4554 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4555 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4556 let mut diag = PcgDiagnostics {
4557 precond_apply_calls: 1,
4558 stopping_reason: PcgStopReason::MaxIter,
4559 ..PcgDiagnostics::default()
4560 };
4561 if rz <= 0.0 || !rz.is_finite() {
4562 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4563 reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4564 });
4565 }
4566
4567 let max_iters = max_iterations.max(1);
4568 for _ in 0..max_iters {
4569 let gemv_cfg = GemvConfig::<f64> {
4571 trans: cublasOperation_t::CUBLAS_OP_N,
4572 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4573 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4574 alpha: 1.0,
4575 lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4576 incx: 1,
4577 beta: 0.0,
4578 incy: 1,
4579 };
4580 unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4582 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4583 diag.matvec_calls += 1;
4584 diag.iterations += 1;
4585
4586 let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4587 if !(p_sp > 0.0) {
4588 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4591 reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4592 });
4593 }
4594 let alpha = rz / p_sp;
4595 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4596 device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4597 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4598 if r_norm <= tol {
4599 diag.final_relative_residual = r_norm / rhs_norm;
4600 diag.stopping_reason = PcgStopReason::Converged;
4601 break;
4602 }
4603 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4604 diag.precond_apply_calls += 1;
4605 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4606 if rz_new <= 0.0 || !rz_new.is_finite() {
4607 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4608 reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4609 });
4610 }
4611 let beta = rz_new / rz;
4612 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4613 rz = rz_new;
4614 }
4615 if diag.stopping_reason != PcgStopReason::Converged {
4616 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4617 diag.final_relative_residual = r_norm / rhs_norm;
4618 diag.stopping_reason = PcgStopReason::MaxIter;
4619 }
4620
4621 let x = stream
4622 .clone_dtoh(&x_dev)
4623 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4624 Ok((Array1::from_vec(x), diag))
4625 }
4626
4627 fn device_copy(
4628 blas: &CudaBlas,
4629 stream: &Arc<CudaStream>,
4630 n: usize,
4631 src: &CudaSlice<f64>,
4632 dst: &mut CudaSlice<f64>,
4633 ) -> Result<(), ArrowSchurGpuFailure> {
4634 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4635 let (src_ptr, _src_rec) = src.device_ptr(stream);
4636 let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4637 let status = unsafe {
4640 cudarc::cublas::sys::cublasDcopy_v2(
4641 *blas.handle(),
4642 n_i,
4643 src_ptr as *const f64,
4644 1,
4645 dst_ptr as *mut f64,
4646 1,
4647 )
4648 };
4649 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4650 Ok(())
4651 } else {
4652 Err(ArrowSchurGpuFailure::Unavailable)
4653 }
4654 }
4655
4656 fn device_axpy(
4657 blas: &CudaBlas,
4658 stream: &Arc<CudaStream>,
4659 n: usize,
4660 alpha: f64,
4661 x: &CudaSlice<f64>,
4662 y: &mut CudaSlice<f64>,
4663 ) -> Result<(), ArrowSchurGpuFailure> {
4664 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4665 let (x_ptr, _x_rec) = x.device_ptr(stream);
4666 let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4667 let status = unsafe {
4670 cudarc::cublas::sys::cublasDaxpy_v2(
4671 *blas.handle(),
4672 n_i,
4673 &alpha,
4674 x_ptr as *const f64,
4675 1,
4676 y_ptr as *mut f64,
4677 1,
4678 )
4679 };
4680 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4681 Ok(())
4682 } else {
4683 Err(ArrowSchurGpuFailure::Unavailable)
4684 }
4685 }
4686
4687 fn device_dot(
4688 blas: &CudaBlas,
4689 stream: &Arc<CudaStream>,
4690 n: usize,
4691 x: &CudaSlice<f64>,
4692 y: &CudaSlice<f64>,
4693 ) -> Result<f64, ArrowSchurGpuFailure> {
4694 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4695 let (x_ptr, _x_rec) = x.device_ptr(stream);
4696 let (y_ptr, _y_rec) = y.device_ptr(stream);
4697 let mut result = 0.0_f64;
4698 let status = unsafe {
4702 cudarc::cublas::sys::cublasDdot_v2(
4703 *blas.handle(),
4704 n_i,
4705 x_ptr as *const f64,
4706 1,
4707 y_ptr as *const f64,
4708 1,
4709 &mut result,
4710 )
4711 };
4712 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4713 Ok(result)
4714 } else {
4715 Err(ArrowSchurGpuFailure::Unavailable)
4716 }
4717 }
4718
4719 fn device_nrm2(
4720 blas: &CudaBlas,
4721 stream: &Arc<CudaStream>,
4722 n: usize,
4723 x: &CudaSlice<f64>,
4724 ) -> Result<f64, ArrowSchurGpuFailure> {
4725 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4726 let (x_ptr, _x_rec) = x.device_ptr(stream);
4727 let mut result = 0.0_f64;
4728 let status = unsafe {
4732 cudarc::cublas::sys::cublasDnrm2_v2(
4733 *blas.handle(),
4734 n_i,
4735 x_ptr as *const f64,
4736 1,
4737 &mut result,
4738 )
4739 };
4740 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4741 Ok(result)
4742 } else {
4743 Err(ArrowSchurGpuFailure::Unavailable)
4744 }
4745 }
4746
4747 #[cfg(test)]
4748 mod tests {
4749 use super::*;
4754 use crate::arrow_schur::{
4755 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4756 FactoredFrameGBlock,
4757 };
4758 use ndarray::Array2;
4759
4760 fn device_matvec_once(
4763 sys: &ArrowSchurSystem,
4764 data: &DeviceSaePcgData,
4765 ridge_t: f64,
4766 ridge_beta: f64,
4767 x_host: &[f64],
4768 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4769 let k = x_host.len();
4770 let frame = data
4771 .frame
4772 .as_ref()
4773 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4774 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4775 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4776 let ctx =
4777 gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4778 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4779 let stream = ctx
4780 .new_stream()
4781 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4782 let vector_module = pcg_vector_module(&ctx)?;
4783 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4784 let x_dev = stream
4785 .clone_htod(x_host)
4786 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4787 let mut out_dev = stream
4788 .alloc_zeros::<f64>(k)
4789 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4790 launch_sae_frame_matvec(
4791 &stream,
4792 vector_module,
4793 &mut buffers,
4794 &x_dev,
4795 &mut out_dev,
4796 ridge_beta,
4797 )?;
4798 stream
4799 .clone_dtoh(&out_dev)
4800 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4801 }
4802
4803 #[test]
4809 fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4810 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4811 return;
4812 }
4813 let p = 3usize;
4814 let ranks = vec![2usize, 3usize];
4815 let basis_sizes = vec![2usize, 2usize];
4816 let mut border_offsets = Vec::new();
4817 let mut acc = 0usize;
4818 for k in 0..2 {
4819 border_offsets.push(acc);
4820 acc += basis_sizes[k] * ranks[k];
4821 }
4822 let border_dim = acc; let frame_of = |k: usize| -> Array2<f64> {
4824 Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4825 0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4826 })
4827 };
4828 let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4829 let w_of = |i: usize, j: usize| -> Array2<f64> {
4830 let (ui, uj) = (&frames[i], &frames[j]);
4831 Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4832 (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4833 })
4834 };
4835 let mut frame_blocks = Vec::new();
4836 for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4837 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4838 let mut g =
4839 Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
4840 if i == j {
4841 for r in 0..mi.min(mj) {
4842 g[[r, r]] += mi as f64 + 2.0;
4843 }
4844 }
4845 frame_blocks.push(FactoredFrameGBlock {
4846 atom_i: i,
4847 atom_j: j,
4848 g,
4849 w: w_of(i, j),
4850 });
4851 }
4852 let mut smooth_blocks = Vec::new();
4853 for k in 0..2 {
4854 let m = basis_sizes[k];
4855 let mut s =
4856 Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
4857 for r in 0..m {
4858 s[[r, r]] += 1.0;
4859 }
4860 smooth_blocks.push(DeviceSaeSmoothBlock {
4861 global_offset: border_offsets[k],
4862 factor_a: s,
4863 });
4864 }
4865 let smooth_ranks = ranks.clone();
4866 let n = 2usize;
4867 let q = 2usize;
4868 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
4869 let mut row_htbeta = Vec::new();
4870 for i in 0..n {
4871 let mut htt =
4872 Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
4873 for r in 0..q {
4874 htt[[r, r]] += q as f64 + 2.0;
4875 }
4876 sys.rows[i].htt = htt;
4877 let mut slab = vec![0.0_f64; q * border_dim];
4878 for c in 0..q {
4879 for col in 0..border_dim {
4880 let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
4881 slab[c * border_dim + col] = v;
4882 sys.rows[i].htbeta[[c, col]] = v;
4883 }
4884 }
4885 row_htbeta.push(slab);
4886 }
4887 let data = DeviceSaePcgData {
4888 p,
4889 beta_dim: border_dim,
4890 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
4891 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
4892 smooth_blocks,
4893 sparse_g_blocks: Vec::new(),
4894 frame: Some(DeviceSaeFrameData {
4895 ranks,
4896 basis_sizes,
4897 border_offsets,
4898 frame_blocks,
4899 smooth_ranks,
4900 row_htbeta,
4901 }),
4902 };
4903 let ridge_t = 1e-7;
4904 let ridge_beta = 1e-6;
4905 let mut first_bad: Option<usize> = None;
4906 let mut worst = 0.0_f64;
4907 let mut worst_at = 0usize;
4908 let mut worst_dev = 0.0_f64;
4909 let mut worst_cpu = 0.0_f64;
4910 for col in 0..border_dim {
4911 let mut x = vec![0.0_f64; border_dim];
4912 x[col] = 1.0;
4913 let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
4914 Ok(v) => v,
4915 Err(_) => return,
4916 };
4917 let mut cpu = vec![0.0_f64; border_dim];
4918 super::super::sae_framed_schur_matvec_cpu(
4919 &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
4920 )
4921 .expect("cpu matvec");
4922 for r in 0..border_dim {
4923 let d = (dev[r] - cpu[r]).abs();
4924 if d > 1e-9 && first_bad.is_none() {
4925 first_bad = Some(r * border_dim + col);
4926 }
4927 if d > worst {
4928 worst = d;
4929 worst_at = r * border_dim + col;
4930 worst_dev = dev[r];
4931 worst_cpu = cpu[r];
4932 }
4933 }
4934 }
4935 assert!(
4936 worst <= 1e-9,
4937 "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
4938 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
4939 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
4940 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
4941 G⊗W=cross, reduced-Schur=dense per-row)",
4942 );
4943 }
4944 }
4945}
4946
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrow_schur::ArrowSchurSystem;
    use ndarray::{Array2, ArrayView1};

    /// Builds a deterministic arrow system with `n` rows, `d` row parameters per row, and
    /// `k` shared parameters.
    ///
    /// Each row's `htt` is `AᵀA + (d+1)I` for a pseudo-random `d×d` matrix `A`, so it is
    /// symmetric positive definite. The `htbeta` entries are pseudo-random values scaled by
    /// 0.1, and `gt` entries are pseudo-random. The shared `hbb` is `BᵀB + (k+1)I`, and `gb`
    /// is pseudo-random. All samples come from a 64-bit LCG seeded from `seed` and lie in
    /// [-1, 1), so fixtures are reproducible.
    fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
        let mut sys = ArrowSchurSystem::new(n, d, k);
        // Multiply the seed by a fixed odd constant to get the initial LCG state.
        let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
        // 64-bit LCG step; the top 31 bits are mapped to [-1, 1).
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };
        for row in &mut sys.rows {
            let mut a = Array2::<f64>::zeros((d, d));
            for r in 0..d {
                for c in 0..d {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..d {
                htt[[r, r]] += d as f64 + 1.0;
            }
            row.htt = htt;
            for r in 0..d {
                for c in 0..k {
                    row.htbeta[[r, c]] = 0.1 * sample();
                }
                row.gt[r] = sample();
            }
        }
        let mut hbb_a = Array2::<f64>::zeros((k, k));
        for r in 0..k {
            for c in 0..k {
                hbb_a[[r, c]] = sample();
            }
        }
        let mut hbb = hbb_a.t().dot(&hbb_a);
        for r in 0..k {
            hbb[[r, r]] += k as f64 + 1.0;
        }
        sys.hbb = hbb;
        for r in 0..k {
            sys.gb[r] = sample();
        }
        sys
    }

    /// Builds a symmetric, strictly diagonally dominant `k×k` banded matrix and a smooth
    /// right-hand side for the reduced-β PCG tests.
    ///
    /// The diagonal is about 2.5. The first off-diagonal band is −0.05 and the band at
    /// distance 7 is 0.01, so the matrix is symmetric positive definite. The rhs is
    /// `sin(0.013·(i+1))`.
    fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
        let mut s = Array2::<f64>::zeros((k, k));
        for row in 0..k {
            s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
            if row + 1 < k {
                s[[row, row + 1]] = -0.05;
                s[[row + 1, row]] = -0.05;
            }
            if row + 7 < k {
                s[[row, row + 7]] = 0.01;
                s[[row + 7, row]] = 0.01;
            }
        }
        let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
        (s, rhs)
    }

    /// Host oracle: Jacobi-preconditioned conjugate gradient on the dense matrix `s`.
    ///
    /// Returns zeros for a zero `rhs`. Otherwise it iterates at most
    /// `max(max_iterations, 1)` times and stops once the residual 2-norm is at most
    /// `max(max(relative_tolerance, 0)·‖rhs‖, 1e-12)`. It does not guard against a zero
    /// `pᵀSp`, so `s` should be SPD.
    fn dense_pcg_cpu_reference(
        s: &Array2<f64>,
        rhs: &Array1<f64>,
        max_iterations: usize,
        relative_tolerance: f64,
    ) -> Array1<f64> {
        let k = rhs.len();
        let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
        if rhs_norm == 0.0 {
            return Array1::<f64>::zeros(k);
        }
        let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
        let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
        let mut x = Array1::<f64>::zeros(k);
        let mut r = rhs.clone();
        let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
        let mut p = z.clone();
        let mut sp = Array1::<f64>::zeros(k);
        let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
        for _ in 0..max_iterations.max(1) {
            // sp = S·p
            for row in 0..k {
                let mut acc = 0.0;
                for col in 0..k {
                    acc += s[[row, col]] * p[col];
                }
                sp[row] = acc;
            }
            let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
            let alpha = rz / p_sp;
            for idx in 0..k {
                x[idx] += alpha * p[idx];
                r[idx] -= alpha * sp[idx];
            }
            let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
            if r_norm <= tol {
                break;
            }
            for idx in 0..k {
                z[idx] = inv_diag[idx] * r[idx];
            }
            let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
            let beta = rz_next / rz;
            for idx in 0..k {
                p[idx] = z[idx] + beta * p[idx];
            }
            rz = rz_next;
        }
        x
    }

    /// Compares the device-resident reduced-β PCG with `dense_pcg_cpu_reference` on a
    /// 512×512 fixture, to a maximum absolute error of 1e-10.
    ///
    /// It also checks that the device performed exactly one matvec per iteration. A
    /// declined device solve is only accepted when no CUDA runtime exists.
    #[test]
    fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
        let (s, rhs) = device_pcg_fixture(512);
        let max_iterations = 200usize;
        let relative_tolerance = 1.0e-12;
        let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
        let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
            &s,
            &rhs,
            max_iterations,
            relative_tolerance,
        ) {
            Ok(result) => result,
            Err(failure) => {
                // A decline is only acceptable without a CUDA runtime. With a device
                // present it means the kernel path is broken (#1017).
                assert!(
                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                    "#1017: CUDA device present but the device reduced-beta PCG \
                     declined/faulted instead of returning a result (tag: {failure:?}) — \
                     the kernel does not run correctly on GPU"
                );
                return;
            }
        };
        let max_err = cpu
            .iter()
            .zip(device.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        assert!(
            max_err <= 1.0e-10,
            "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
        );
        assert!(diag.matvec_calls > 0);
        assert_eq!(diag.matvec_calls, diag.iterations);
    }

    /// Checks `solve_arrow_newton_step_dense_reference` (with zero ridges) against an
    /// independent solve.
    ///
    /// The test assembles the full `(n·d + k)`-square arrow Hessian `H` and gradient `g`,
    /// solves `H·δ = −g` with a dense Cholesky factorization, and compares every entry to a
    /// relative tolerance of 1e-10.
    #[test]
    fn dense_reference_matches_independent_solve() {
        let sys = build_fixture(4, 5, 3, 7);
        let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        let total = n * d + k;
        let mut h = Array2::<f64>::zeros((total, total));
        let mut g = ndarray::Array1::<f64>::zeros(total);
        for (i, row) in sys.rows.iter().enumerate() {
            let base = i * d;
            // Diagonal t-block of row i.
            for c in 0..d {
                for r in 0..d {
                    h[[base + r, base + c]] = row.htt[[r, c]];
                }
            }
            // Symmetric t/β coupling of row i.
            for c in 0..k {
                for r in 0..d {
                    h[[base + r, n * d + c]] = row.htbeta[[r, c]];
                    h[[n * d + c, base + r]] = row.htbeta[[r, c]];
                }
            }
            for r in 0..d {
                g[base + r] = row.gt[r];
            }
        }
        // Shared β-block and β-gradient.
        for c in 0..k {
            for r in 0..k {
                h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
            }
            g[n * d + c] = sys.gb[c];
        }
        let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
        let rhs = g.mapv(|v| -v);
        let expected = cholesky_solve_vector(l.view(), rhs.view());
        for i in 0..n * d {
            assert!(
                (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
                "delta_t[{i}] mismatch: got {} expected {}",
                solution.delta_t[i],
                expected[i]
            );
        }
        for a in 0..k {
            assert!(
                (solution.delta_beta[a] - expected[n * d + a]).abs()
                    < 1e-10 * (1.0 + expected[n * d + a].abs()),
                "delta_beta[{a}] mismatch"
            );
        }
    }

    /// Tests the row-procedural (matrix-free H_tβ) Schur matvec returned by
    /// `gpu_schur_matvec_backend`.
    ///
    /// Two runs on the ambient rayon pool must be bitwise identical, and they must agree
    /// with a run on a single-worker pool to within reassociation tolerance.
    #[test]
    fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
        use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
        // Enough rows that the matvec is expected to take its row-parallel path
        // (presumably gated on SCHUR_MATVEC_PARALLEL_ROW_MIN — confirm).
        let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96;
        let d = 3usize;
        let k = 24usize;
        let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
        // Snapshot each row's dense H_tβ so the matrix-free closures reproduce the same
        // operator.
        let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
        let forward_slabs = slabs.clone();
        let transpose_slabs = slabs;
        sys.set_row_htbeta_operator(
            // Forward: out = H_tβ[row] · x (overwrites `out`).
            move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
                let h = &forward_slabs[row];
                for r in 0..h.nrows() {
                    let mut acc = 0.0_f64;
                    for c in 0..h.ncols() {
                        acc += h[[r, c]] * x[c];
                    }
                    out[r] = acc;
                }
            },
            // Transpose: out += H_tβ[row]ᵀ · v (accumulates into `out`).
            move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
                let h = &transpose_slabs[row];
                for r in 0..h.nrows() {
                    for c in 0..h.ncols() {
                        out[c] += h[[r, c]] * v[r];
                    }
                }
            },
        );

        let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
            .expect("row-procedural matvec backend builds for matrix-free system");
        let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());

        let mut out_parallel_a = Array1::<f64>::zeros(k);
        matvec(&x, &mut out_parallel_a);
        let mut out_parallel_b = Array1::<f64>::zeros(k);
        matvec(&x, &mut out_parallel_b);
        for a in 0..k {
            assert_eq!(
                out_parallel_a[a].to_bits(),
                out_parallel_b[a].to_bits(),
                "row-procedural matvec parallel reduction is non-deterministic at index {a}"
            );
        }

        // Serial baseline: on a single-worker pool rayon runs the same code path
        // sequentially. A multi-worker pool would be no different from the ambient pool on
        // small machines, which would make this comparison vacuous.
        let mut out_serial = Array1::<f64>::zeros(k);
        rayon::ThreadPoolBuilder::new()
            .num_threads(1)
            .build()
            .expect("build rayon pool")
            .install(|| matvec(&x, &mut out_serial));

        let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
        for a in 0..k {
            let diff = (out_parallel_a[a] - out_serial[a]).abs();
            assert!(
                diff <= 1e-12 * (1.0 + max_abs),
                "row-procedural matvec parallel vs serial diverged beyond reassociation \
                 at index {a}: {} vs {} (diff={diff:e})",
                out_parallel_a[a],
                out_serial[a]
            );
        }
    }

    /// Checks the CPU framed SAE Schur matvec against an explicitly assembled dense
    /// reduced Schur complement:
    /// `S = H_ββ + ridge_β·I − Σ_i H_tβ,iᵀ (H_tt,i + ridge_t·I)⁻¹ H_tβ,i`.
    ///
    /// `H_ββ` is the dense frame-Kronecker operator plus the per-atom smoothing penalties.
    /// Atom 1 has rank p, so its frame is the identity. Over four probe vectors, the maximum
    /// entry error relative to `max(‖want‖∞, 1)` must be at most 1e-10.
    ///
    /// NOTE(review): G(i,j) and G(j,i) are sampled independently, and the diagonal G blocks
    /// are not symmetrized, so H_ββ is not symmetric unless `FactoredFrameKroneckerOp`
    /// symmetrizes internally. That is harmless here, since both sides use the same
    /// operator, but confirm it is intended.
    #[test]
    fn framed_sae_schur_matvec_matches_dense_reference() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 4usize;
        let ranks = vec![2usize, 4usize, 3usize];
        let basis_sizes = vec![2usize, 1usize, 2usize];
        let n_atoms = ranks.len();
        // Atom k occupies basis_sizes[k]·ranks[k] consecutive border columns.
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;
        let mut state = 0x1234_5678_9abc_def0u64;
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };

        // Full-rank atoms (rank == p) get the identity frame and consume no samples.
        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p && i == j {
                        1.0
                    } else if r == p {
                        0.0
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // W_ij = U_iᵀ U_j
        let w_of = |i: usize, j: usize| -> Array2<f64> {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };

        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
        pairs.sort();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.3 * sample();
                }
            }
            if i == j {
                // Boost the diagonal of the self-blocks.
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }

        // Per-atom smoothing factor AᵀA + I (SPD).
        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }

        let n = 6usize;
        let q = 3usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            // Mirror H_tβ into the row-major (q × border_dim) slab used by the framed path.
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.15 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }

        // Dense H_ββ = frame-Kronecker data term + Σ_k smoothing penalty.
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;

        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;

        // s_dense = H_ββ + ridge_β·I − Σ_i H_tβ,iᵀ (H_tt,i + ridge_t·I)⁻¹ H_tβ,i
        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
        for r in 0..border_dim {
            for c in 0..border_dim {
                s_dense[[r, c]] = sys.hbb[[r, c]];
            }
            s_dense[[r, r]] += ridge_beta;
        }
        for row in &sys.rows {
            let mut htt = row.htt.clone();
            for d in 0..q {
                htt[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
                .expect("htt PD");
            // y = (H_tt + ridge_t·I)⁻¹ H_tβ, solved one border column at a time.
            let mut y = Array2::<f64>::zeros((q, border_dim));
            for col in 0..border_dim {
                let mut e = Array1::<f64>::zeros(q);
                for r in 0..q {
                    e[r] = row.htbeta[[r, col]];
                }
                let solved = cholesky_solve_vector(factor.view(), e.view());
                for r in 0..q {
                    y[[r, col]] = solved[r];
                }
            }
            for r in 0..border_dim {
                for c in 0..border_dim {
                    let mut acc = 0.0;
                    for d in 0..q {
                        acc += row.htbeta[[d, r]] * y[[d, c]];
                    }
                    s_dense[[r, c]] -= acc;
                }
            }
        }

        let mut max_rel = 0.0_f64;
        for trial in 0..4 {
            let x: Vec<f64> = (0..border_dim)
                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
                .collect();
            let mut got = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
                .expect("framed matvec");
            let mut want = vec![0.0_f64; border_dim];
            for r in 0..border_dim {
                let mut acc = 0.0;
                for c in 0..border_dim {
                    acc += s_dense[[r, c]] * x[c];
                }
                want[r] = acc;
            }
            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (got[a] - want[a]).abs() / scale;
                max_rel = max_rel.max(rel);
            }
        }
        assert!(
            max_rel <= 1e-10,
            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
        );
    }

    /// Solves `S·x = rhs` with the framed device SAE PCG (`solve_sae_matrix_free_pcg`) and
    /// compares the result with a dense Cholesky solve.
    ///
    /// The dense `S` is assembled column by column from `sae_framed_schur_matvec_cpu`. The
    /// fixture has 8 atoms with alternating ranks 3 and p = 6, a few off-diagonal G pairs,
    /// and n = 400 rows of q = 4. The failure message reports the device-vs-CPU operator
    /// residual, to separate kernel bugs from PCG/preconditioner issues (#1551). A declined
    /// device solve is only accepted without a CUDA runtime.
    ///
    /// NOTE(review): G(i,j) and G(j,i) are sampled independently, and the diagonal G blocks
    /// are not symmetrized, so S may be non-symmetric unless the frame operator symmetrizes
    /// internally. PCG and the Cholesky reference both assume SPD — confirm.
    #[test]
    fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 6usize;
        let n_atoms = 8usize;
        // Even atoms are reduced-rank (3); odd atoms are full-rank (p).
        let ranks: Vec<usize> = (0..n_atoms)
            .map(|k| if k % 2 == 0 { 3usize } else { p })
            .collect();
        let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;
        let mut state = 0xfeed_face_dead_beefu64;
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };
        // Full-rank atoms get the identity frame and consume no samples.
        let mut frames: Vec<Array2<f64>> = Vec::new();
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p && i == j {
                        1.0
                    } else if r == p {
                        0.0
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // W_ij = U_iᵀ U_j
        let w_of = |i: usize, j: usize| {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };
        // All self-pairs plus three off-diagonal pairs, each in both orientations.
        let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
        for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
            pairs.push((i, j));
            pairs.push((j, i));
        }
        let mut frame_blocks = Vec::new();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.25 * sample();
                }
            }
            if i == j {
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }
        let mut smooth_blocks = Vec::new();
        let mut smooth_ranks = Vec::new();
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }
        let n = 400usize;
        let q = 4usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta = Vec::new();
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            // Mirror H_tβ into the row-major (q × border_dim) slab used by the device path.
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.02 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }
        // Dense H_ββ = frame-Kronecker data term + Σ_k smoothing penalty.
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;
        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };
        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;
        let rhs: Array1<f64> =
            Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());

        let (device, diag) =
            match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
                Ok(result) => result,
                Err(failure) => {
                    // A decline is only acceptable without a CUDA runtime (#1017).
                    assert!(
                        gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                        "#1017: CUDA device present but the framed device SAE PCG \
                         declined/faulted instead of returning a result (tag: {failure:?}) — \
                         the kernel does not run correctly on GPU"
                    );
                    return;
                }
            };

        // Assemble S column by column from the CPU framed matvec: S·e_col.
        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
        for col in 0..border_dim {
            let mut e = vec![0.0_f64; border_dim];
            e[col] = 1.0;
            let mut sc = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &e, &mut sc)
                .expect("cpu matvec");
            for r in 0..border_dim {
                s_dense[[r, col]] = sc[r];
            }
        }
        let factor = cholesky_factor_in_place(s_dense.view(), CholeskyGuard::NonnegativePivot)
            .expect("S PD");
        let cpu = cholesky_solve_vector(factor.view(), rhs.view());

        let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
        let mut max_rel = 0.0_f64;
        for a in 0..border_dim {
            max_rel = max_rel.max((device[a] - cpu[a]).abs() / scale);
        }
        // Triage: residual of the device answer under the CPU operator, next to the CPU
        // answer's own residual.
        let mut s_dev_resid = 0.0_f64;
        {
            let sx = s_dense.dot(&device);
            for a in 0..border_dim {
                s_dev_resid = s_dev_resid.max((sx[a] - rhs[a]).abs());
            }
        }
        let s_cpu_resid = {
            let sc = s_dense.dot(&cpu);
            let mut m = 0.0_f64;
            for a in 0..border_dim {
                m = m.max((sc[a] - rhs[a]).abs());
            }
            m
        };
        assert!(
            max_rel <= 1e-7,
            "[#1551 framed-triage] max_rel={max_rel:e} | device-vs-CPU-operator residual \
             ‖S_cpu·device−rhs‖={s_dev_resid:e} (CPU's own ={s_cpu_resid:e}) | device PCG \
             stop={:?} iters={} final_rel_resid={:e} — large operator-residual ⇒ device matvec \
             is a different operator (kernel bug); small ⇒ PCG/precond or singular-S issue",
            diag.stopping_reason,
            diag.iterations,
            diag.final_relative_residual,
        );
    }
}
5769}