1use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
/// Newton step and log-determinant produced by the GPU arrow-Schur solvers
/// (and by the dense CPU reference) in this file.
pub struct ArrowSchurGpuSolution {
    /// Step for the per-row blocks, stored flat with `n * d` entries; row `i`
    /// occupies `[i * d, (i + 1) * d)`.
    pub delta_t: Array1<f64>,
    /// Step for the shared border block (`k` entries).
    pub delta_beta: Array1<f64>,
    /// Log-determinant of the ridged system, accumulated as twice the sum of
    /// the logs of the Cholesky diagonal entries.
    pub log_det_hessian: f64,
}
32
/// Reasons a GPU arrow-Schur entry point can decline or fail.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// No usable GPU path: wrong platform, no device/runtime, a device call
    /// failed, or the inputs are degenerate for the GPU route.
    Unavailable,
    /// The Cholesky factorization of per-row block `row` failed; `bump` is the
    /// Gershgorin-based ridge increase suggested by the bump estimators in
    /// this file.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// The reduced (Schur) factorization failed or an input was rejected;
    /// `reason` carries a human-readable explanation.
    SchurFactorFailed { reason: String },
    /// The system is not a plain dense system; the flags record which
    /// matrix-free operators were installed when that was detected.
    GpuRequiresDenseSystem {
        /// `hbb_matvec` was present on the system.
        had_hbb_matvec: bool,
        /// `htbeta_matvec` was present on the system.
        had_htbeta_matvec: bool,
    },
}
61
/// Multiplier applied to `scale * sqrt(f64::EPSILON)` to form the safety
/// margin added on top of the Gershgorin deficit by the ridge-bump estimators.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
73
74#[must_use]
110fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
111 let d = htt.nrows();
112 let mut scale = 1.0_f64;
115 let mut min_gershgorin_edge = f64::INFINITY;
116 for i in 0..d {
117 let diag = htt[[i, i]];
118 scale = scale.max(diag.abs());
119 let mut off_sum = 0.0_f64;
120 for j in 0..d {
121 if j != i {
122 off_sum += htt[[i, j]].abs();
123 }
124 }
125 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
126 }
127 if !min_gershgorin_edge.is_finite() {
128 return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131 }
132 let deficit = -(min_gershgorin_edge + ridge_t);
135 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
136 deficit.max(0.0) + margin
139}
140
141#[cfg(target_os = "linux")]
154#[must_use]
155fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
156 if d == 0 || block.len() < d * d {
157 return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
158 }
159 let mut scale = 1.0_f64;
162 let mut min_gershgorin_edge = f64::INFINITY;
163 for i in 0..d {
164 let diag = block[i * d + i];
165 scale = scale.max(diag.abs());
166 let mut off_sum = 0.0_f64;
167 for j in 0..d {
168 if j != i {
169 off_sum += block[j * d + i].abs();
170 }
171 }
172 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
173 }
174 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
175 if !min_gershgorin_edge.is_finite() {
176 return margin;
177 }
178 (-min_gershgorin_edge).max(0.0) + margin
179}
180
181pub fn solve_arrow_newton_step(
185 sys: &ArrowSchurSystem,
186 ridge_t: f64,
187 ridge_beta: f64,
188) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
189 let n = sys.rows.len();
190 let d = sys.d;
191 let k = sys.k;
192
193 let had_hbb_matvec = sys.hbb_matvec.is_some();
198 let had_htbeta_matvec = sys.htbeta_matvec.is_some();
199 if had_hbb_matvec || had_htbeta_matvec {
200 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
201 had_hbb_matvec,
202 had_htbeta_matvec,
203 });
204 }
205
206 if sys.penalty_op.is_some() {
223 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
224 had_hbb_matvec: false,
225 had_htbeta_matvec: false,
226 });
227 }
228
229 if sys.hbb.dim() != (k, k) {
230 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
244 had_hbb_matvec: false,
245 had_htbeta_matvec: false,
246 });
247 }
248 if n == 0 || d == 0 {
249 return Err(ArrowSchurGpuFailure::Unavailable);
250 }
251 if sys
252 .rows
253 .iter()
254 .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
255 {
256 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
257 reason: "row block dimension mismatch".to_string(),
258 });
259 }
260
261 #[cfg(not(target_os = "linux"))]
262 {
263 if ridge_t.is_nan() || ridge_beta.is_nan() {
264 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
265 reason: "ridge is NaN".to_string(),
266 });
267 }
268 Err(ArrowSchurGpuFailure::Unavailable)
269 }
270
271 #[cfg(target_os = "linux")]
272 {
273 if gam_gpu::device_runtime::GpuRuntime::global()
282 .map(gam_gpu::device_runtime::GpuRuntime::device_count)
283 .unwrap_or(0)
284 > 1
285 {
286 match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
287 Ok(sol) => return Ok(sol),
288 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
289 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
290 }
291 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
292 return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
293 }
294 Err(_) => {}
297 }
298 }
299 if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
305 match cuda::solve_fused(sys, ridge_t, ridge_beta) {
306 Ok(sol) => return Ok(sol),
307 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
311 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
312 }
313 Err(_) => {}
317 }
318 }
319 cuda::solve(sys, ridge_t, ridge_beta)
320 }
321}
322
323#[cfg(target_os = "linux")]
329fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
330 let n = sys.rows.len();
331 let d = sys.d;
332 let k = sys.k;
333 let mut d_buf = Vec::with_capacity(n * d * d);
334 let mut b_buf = Vec::with_capacity(n * d * k);
335 let mut g_buf = Vec::with_capacity(n * d);
336 for row in &sys.rows {
337 pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
338 }
339 (d_buf, b_buf, g_buf)
340}
341
342#[cfg(target_os = "linux")]
343#[inline]
344fn pack_block(
345 row: &crate::arrow_schur::ArrowRowBlock,
346 ridge_t: f64,
347 d: usize,
348 k: usize,
349 d_buf: &mut Vec<f64>,
350 b_buf: &mut Vec<f64>,
351 g_buf: &mut Vec<f64>,
352) {
353 for col in 0..d {
354 for r in 0..d {
355 let mut value = row.htt[[r, col]];
356 if r == col {
357 value += ridge_t;
358 }
359 d_buf.push(value);
360 }
361 }
362 for col in 0..k {
363 for r in 0..d {
364 b_buf.push(row.htbeta[[r, col]]);
365 }
366 }
367 for r in 0..d {
368 g_buf.push(row.gt[r]);
369 }
370}
371
372#[doc(hidden)]
377#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] pub fn solve_arrow_newton_step_fused_force(
379 sys: &ArrowSchurSystem,
380 ridge_t: f64,
381 ridge_beta: f64,
382) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
383 if ridge_t.is_nan() || ridge_beta.is_nan() {
384 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
385 reason: "ridge is NaN".to_string(),
386 });
387 }
388 #[cfg(not(target_os = "linux"))]
389 {
390 Err(ArrowSchurGpuFailure::Unavailable)
395 }
396 #[cfg(target_os = "linux")]
397 {
398 if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
399 .is_none()
400 {
401 return Err(ArrowSchurGpuFailure::Unavailable);
402 }
403 cuda::solve_fused(sys, ridge_t, ridge_beta)
404 }
405}
406
/// Handle around the Linux-only `cuda::ResidentArrowFrame`.
///
/// Off Linux the only field is of the uninhabited type
/// `std::convert::Infallible`, so no value of this type can exist there.
pub struct ResidentArrowFrameHandle {
    /// The wrapped frame that all methods delegate to on Linux.
    #[cfg(target_os = "linux")]
    inner: cuda::ResidentArrowFrame,
    /// Uninhabited placeholder that makes the handle unconstructible.
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
422
423impl ResidentArrowFrameHandle {
424 pub fn new(
426 sys: &ArrowSchurSystem,
427 ridge_t: f64,
428 ridge_beta: f64,
429 ) -> Result<Self, ArrowSchurGpuFailure> {
430 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
433 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
434 had_hbb_matvec: sys.hbb_matvec.is_some(),
435 had_htbeta_matvec: sys.htbeta_matvec.is_some(),
436 });
437 }
438 #[cfg(not(target_os = "linux"))]
439 {
440 if ridge_t.is_nan() || ridge_beta.is_nan() {
441 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
442 reason: "ridge is NaN".to_string(),
443 });
444 }
445 Err(ArrowSchurGpuFailure::Unavailable)
446 }
447 #[cfg(target_os = "linux")]
448 {
449 Ok(Self {
450 inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
451 })
452 }
453 }
454
455 pub fn solve_gradient(
457 &self,
458 g_t: &[f64],
459 g_beta: &[f64],
460 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
461 #[cfg(not(target_os = "linux"))]
462 {
463 if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
464 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
465 reason: "non-finite gradient entry".to_string(),
466 });
467 }
468 Err(ArrowSchurGpuFailure::Unavailable)
469 }
470 #[cfg(target_os = "linux")]
471 {
472 self.inner.solve_gradient(g_t, g_beta)
473 }
474 }
475
476 #[must_use]
478 pub fn log_det_hessian(&self) -> f64 {
479 #[cfg(not(target_os = "linux"))]
480 {
481 panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
487 }
488 #[cfg(target_os = "linux")]
489 {
490 self.inner.log_det_hessian()
491 }
492 }
493}
494
495pub fn gpu_schur_matvec_backend(
538 sys: &ArrowSchurSystem,
539 ridge_t: f64,
540 ridge_beta: f64,
541) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
542 if sys.htbeta_matvec.is_some() {
545 return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
546 }
547
548 #[cfg(not(target_os = "linux"))]
549 {
550 if ridge_t.is_nan() || ridge_beta.is_nan() {
553 return Err(ArrowSchurGpuFailure::Unavailable);
554 }
555 Err(ArrowSchurGpuFailure::Unavailable)
556 }
557
558 #[cfg(target_os = "linux")]
559 {
560 cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
561 }
562}
563
564fn build_row_procedural_matvec(
581 sys: &ArrowSchurSystem,
582 ridge_t: f64,
583 ridge_beta: f64,
584) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
585 use std::sync::Arc;
586 let n = sys.rows.len();
587 let k = sys.k;
588 let forward = sys
589 .htbeta_matvec
590 .clone()
591 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
592 let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
593 ArrowSchurGpuFailure::SchurFactorFailed {
598 reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
599 forward operator installed without its sparse adjoint"
600 .to_string(),
601 }
602 })?;
603
604 let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
609 for (i, row) in sys.rows.iter().enumerate() {
610 let di = row.htt.nrows();
611 if row.htt.ncols() != di || row.gt.len() != di {
612 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
613 reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
614 });
615 }
616 let mut block = row.htt.clone();
617 for r in 0..di {
618 block[[r, r]] += ridge_t;
619 }
620 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
621 .ok_or_else(|| {
622 ArrowSchurGpuFailure::RidgeBumpRequired {
626 row: i,
627 bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
628 }
629 })?;
630 factors.push(factor);
631 }
632
633 let penalty_op = sys.effective_penalty_op();
640 let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
641
642 let closure: crate::arrow_schur::GpuSchurMatvec =
643 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
644 assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
645 assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
646
647 {
650 let x_slice = x.as_slice().expect("x must be contiguous");
651 let out_slice = out.as_slice_mut().expect("out must be contiguous");
652 for a in 0..k {
653 out_slice[a] = ridge_beta * x_slice[a];
654 }
655 penalty_op.matvec(x_slice, out_slice);
656 }
657
658 let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
685 && rayon::current_thread_index().is_none();
686 if parallel {
687 use rayon::prelude::*;
688 const CHUNK: usize = 64;
689 let partials: Vec<Array1<f64>> = (0..n)
690 .into_par_iter()
691 .chunks(CHUNK)
692 .map(|idxs| {
693 let mut neg = Array1::<f64>::zeros(k);
697 for i in idxs {
698 let di = row_dims[i];
699 let mut v_i = Array1::<f64>::zeros(di);
701 forward(i, x.view(), &mut v_i);
702 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
704 transpose(i, w_i.view(), &mut neg);
706 }
707 neg
708 })
709 .collect();
710 let mut neg = Array1::<f64>::zeros(k);
722 for part in &partials {
723 for a in 0..k {
724 neg[a] += part[a];
725 }
726 }
727 for a in 0..k {
728 out[a] -= neg[a];
729 }
730 } else {
731 let mut neg = Array1::<f64>::zeros(k);
733 for i in 0..n {
734 let di = row_dims[i];
735 let mut v_i = Array1::<f64>::zeros(di);
737 forward(i, x.view(), &mut v_i);
738 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
740 transpose(i, w_i.view(), &mut neg);
742 }
743 for a in 0..k {
744 out[a] -= neg[a];
745 }
746 }
747 });
748
749 Ok(closure)
750}
751
752pub fn solve_reduced_beta_pcg(
774 s_acc: &Array2<f64>,
775 rhs_beta: &Array1<f64>,
776 max_iterations: usize,
777 relative_tolerance: f64,
778) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
779 solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
780 .map(|(x, _)| x)
781}
782
783#[doc(hidden)]
784pub fn solve_reduced_beta_pcg_with_diagnostics(
785 s_acc: &Array2<f64>,
786 rhs_beta: &Array1<f64>,
787 max_iterations: usize,
788 relative_tolerance: f64,
789) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
790 let k = rhs_beta.len();
791 if s_acc.dim() != (k, k) {
792 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
793 reason: format!(
794 "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
795 s_acc.dim()
796 ),
797 });
798 }
799 if k == 0 {
800 return Err(ArrowSchurGpuFailure::Unavailable);
801 }
802
803 #[cfg(not(target_os = "linux"))]
804 {
805 if relative_tolerance.is_nan() || max_iterations == 0 {
806 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
807 reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
808 });
809 }
810 Err(ArrowSchurGpuFailure::Unavailable)
811 }
812
813 #[cfg(target_os = "linux")]
814 {
815 cuda::solve_reduced_beta_pcg_with_diagnostics(
816 s_acc,
817 rhs_beta,
818 max_iterations,
819 relative_tolerance,
820 )
821 }
822}
823
824pub fn solve_sae_matrix_free_pcg(
825 sys: &ArrowSchurSystem,
826 data: &DeviceSaePcgData,
827 ridge_t: f64,
828 ridge_beta: f64,
829 rhs_beta: &Array1<f64>,
830 max_iterations: usize,
831 relative_tolerance: f64,
832) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
833 if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
834 return Err(ArrowSchurGpuFailure::Unavailable);
835 }
836 #[cfg(not(target_os = "linux"))]
837 {
838 if ridge_t.is_nan()
839 || ridge_beta.is_nan()
840 || relative_tolerance.is_nan()
841 || max_iterations == 0
842 {
843 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
844 reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
845 });
846 }
847 Err(ArrowSchurGpuFailure::Unavailable)
848 }
849 #[cfg(target_os = "linux")]
850 {
851 if data.frame.is_some() {
858 cuda::solve_sae_matrix_free_pcg_framed(
859 sys,
860 data,
861 ridge_t,
862 ridge_beta,
863 rhs_beta,
864 max_iterations,
865 relative_tolerance,
866 )
867 } else {
868 cuda::solve_sae_matrix_free_pcg(
869 sys,
870 data,
871 ridge_t,
872 ridge_beta,
873 rhs_beta,
874 max_iterations,
875 relative_tolerance,
876 )
877 }
878 }
879}
880
881#[doc(hidden)]
890pub fn framed_schur_matvec_once_on_device(
891 sys: &ArrowSchurSystem,
892 data: &DeviceSaePcgData,
893 ridge_t: f64,
894 ridge_beta: f64,
895 x: &Array1<f64>,
896) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
897 if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
898 return Err(ArrowSchurGpuFailure::Unavailable);
899 }
900 if data.frame.is_none() {
901 return Err(ArrowSchurGpuFailure::Unavailable);
902 }
903 #[cfg(not(target_os = "linux"))]
904 {
905 if ridge_t.is_finite() && ridge_beta.is_finite() {
910 return Err(ArrowSchurGpuFailure::Unavailable);
911 }
912 Err(ArrowSchurGpuFailure::Unavailable)
913 }
914 #[cfg(target_os = "linux")]
915 {
916 cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
917 }
918}
919
920#[doc(hidden)]
924pub fn solve_arrow_newton_step_dense_reference(
925 sys: &ArrowSchurSystem,
926 ridge_t: f64,
927 ridge_beta: f64,
928) -> Result<ArrowSchurGpuSolution, String> {
929 let n = sys.rows.len();
930 let d = sys.d;
931 let k = sys.k;
932 let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
933 let mut h = Array2::<f64>::zeros((total, total));
934 let mut rhs = Array1::<f64>::zeros(total);
935 for (i, row) in sys.rows.iter().enumerate() {
936 let base = i * d;
937 for c in 0..d {
938 for r in 0..d {
939 h[[base + r, base + c]] = row.htt[[r, c]];
940 }
941 h[[base + c, base + c]] += ridge_t;
942 }
943 for c in 0..k {
944 for r in 0..d {
945 let value = row.htbeta[[r, c]];
946 h[[base + r, n * d + c]] = value;
947 h[[n * d + c, base + r]] = value;
948 }
949 }
950 for r in 0..d {
951 rhs[base + r] = -row.gt[r];
952 }
953 }
954 for c in 0..k {
955 for r in 0..k {
956 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
957 }
958 h[[n * d + c, n * d + c]] += ridge_beta;
959 rhs[n * d + c] = -sys.gb[c];
960 }
961 let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
962 .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
963 let mut log_det = 0.0_f64;
964 for i in 0..total {
965 log_det += factor[[i, i]].ln();
966 }
967 log_det *= 2.0;
968 let solved = cholesky_solve_vector(factor.view(), rhs.view());
969 let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
970 let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
971 Ok(ArrowSchurGpuSolution {
972 delta_t,
973 delta_beta,
974 log_det_hessian: log_det,
975 })
976}
977
978#[doc(hidden)]
989pub fn sae_framed_penalty_matvec_cpu(
990 data: &DeviceSaePcgData,
991 ridge_beta: f64,
992 x: &[f64],
993 out: &mut [f64],
994) {
995 let frame = data
996 .frame
997 .as_ref()
998 .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
999 let k = data.beta_dim;
1000 for a in 0..k {
1001 out[a] = ridge_beta * x[a];
1002 }
1003 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
1005 let off = blk.global_offset;
1006 let m = blk.factor_a.nrows();
1007 for i_a in 0..m {
1008 for i_b in 0..r {
1009 let mut acc = 0.0_f64;
1010 for j_a in 0..m {
1011 let s = blk.factor_a[[i_a, j_a]];
1012 if s == 0.0 {
1013 continue;
1014 }
1015 acc += s * x[off + j_a * r + i_b];
1016 }
1017 out[off + i_a * r + i_b] += acc;
1018 }
1019 }
1020 }
1021 for blk in &frame.frame_blocks {
1023 let r_i = frame.ranks[blk.atom_i];
1024 let r_j = frame.ranks[blk.atom_j];
1025 let off_i = frame.border_offsets[blk.atom_i];
1026 let off_j = frame.border_offsets[blk.atom_j];
1027 let (m_i, m_j) = blk.g.dim();
1028 for li in 0..m_i {
1029 let yi_base = off_i + li * r_i;
1030 for lj in 0..m_j {
1031 let g = blk.g[[li, lj]];
1032 if g == 0.0 {
1033 continue;
1034 }
1035 let xj_base = off_j + lj * r_j;
1036 for a in 0..r_i {
1037 let mut acc = 0.0_f64;
1038 for b in 0..r_j {
1039 acc += blk.w[[a, b]] * x[xj_base + b];
1040 }
1041 out[yi_base + a] += g * acc;
1042 }
1043 }
1044 }
1045 }
1046}
1047
1048#[doc(hidden)]
1057pub fn sae_framed_schur_matvec_cpu(
1058 sys: &ArrowSchurSystem,
1059 data: &DeviceSaePcgData,
1060 ridge_t: f64,
1061 ridge_beta: f64,
1062 x: &[f64],
1063 out: &mut [f64],
1064) -> Result<(), String> {
1065 let frame = data
1066 .frame
1067 .as_ref()
1068 .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1069 let k = data.beta_dim;
1070 sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1071 if frame.row_htbeta.len() != sys.rows.len() {
1072 return Err(format!(
1073 "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1074 frame.row_htbeta.len(),
1075 sys.rows.len()
1076 ));
1077 }
1078 for (i, row) in sys.rows.iter().enumerate() {
1079 let slab = &frame.row_htbeta[i];
1080 if slab.is_empty() {
1081 continue;
1082 }
1083 let qi = sys.row_dims[i];
1084 if qi == 0 || slab.len() != qi * k {
1085 continue;
1086 }
1087 let mut h = vec![0.0_f64; qi];
1089 for c in 0..qi {
1090 let base = c * k;
1091 let mut acc = 0.0_f64;
1092 for a in 0..k {
1093 acc += slab[base + a] * x[a];
1094 }
1095 h[c] = acc;
1096 }
1097 let mut block = row.htt.clone();
1099 for d in 0..qi {
1100 block[[d, d]] += ridge_t;
1101 }
1102 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1103 .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1104 let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1105 for c in 0..qi {
1107 let sc = s[c];
1108 if sc == 0.0 {
1109 continue;
1110 }
1111 let base = c * k;
1112 for a in 0..k {
1113 out[a] -= slab[base + a] * sc;
1114 }
1115 }
1116 }
1117 Ok(())
1118}
1119
1120#[cfg(target_os = "linux")]
1121mod cuda {
1122 use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1123 use gam_gpu::driver::to_i32;
1124 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1125 use crate::arrow_schur::{
1126 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1127 };
1128 use cudarc::cublas::sys::{
1129 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1130 };
1131 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1132 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1133 use cudarc::driver::{
1134 CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1135 PushKernelArg,
1136 };
1137 use ndarray::Array1;
1138 use std::sync::{Arc, OnceLock};
1139
    /// Per-row host-side state carried between the forward and back-substitution
    /// passes of the multi-GPU solve.
    struct RowSlot {
        /// Ridged `H_tt` block as packed by `pack_block` (`d * d` values).
        d_block: Vec<f64>,
        /// `H_t_beta` block as packed by `pack_block` (`d * k` values).
        b_block: Vec<f64>,
        /// `g_t` vector (`d` values).
        g_vec: Vec<f64>,
        /// Factored block copied back from the device after the batched Cholesky.
        l_block: Vec<f64>,
        /// `g_vec` after the batched lower-triangular solve.
        u_vec: Vec<f64>,
        /// `b_block` after the batched lower-triangular solve.
        y_block: Vec<f64>,
        /// Sum of `ln` of the diagonal of `l_block`.
        log_det_local: f64,
        /// Suggested ridge bump, set when this row's factorization reported a
        /// non-zero status.
        bump: Option<f64>,
        /// Tile-level Schur partial sum; stored only on the first slot of a tile.
        tile_partial_schur: Option<Vec<f64>>,
        /// Tile-level reduced right-hand-side partial sum; stored only on the
        /// first slot of a tile.
        tile_partial_rhs: Option<Vec<f64>>,
        /// This row's `delta_t` entries, filled by the back-substitution pass.
        delta_t_block: Vec<f64>,
    }
1163
1164 pub(super) fn solve_multi_gpu(
1185 sys: &ArrowSchurSystem,
1186 ridge_t: f64,
1187 ridge_beta: f64,
1188 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1189 let n = sys.rows.len();
1190 let d = sys.d;
1191 let k = sys.k;
1192 if n == 0 || d == 0 || k == 0 {
1193 return Err(ArrowSchurGpuFailure::Unavailable);
1194 }
1195 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
1199 return Err(ArrowSchurGpuFailure::Unavailable);
1200 }
1201
1202 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
1203 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1204 if runtime.device_count() < 2 {
1205 return Err(ArrowSchurGpuFailure::Unavailable);
1206 }
1207
1208 let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
1210 for row in &sys.rows {
1211 if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
1212 return Err(ArrowSchurGpuFailure::Unavailable);
1213 }
1214 let mut d_block = Vec::with_capacity(d * d);
1215 let mut b_block = Vec::with_capacity(d * k);
1216 let mut g_vec = Vec::with_capacity(d);
1217 pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
1218 slots.push(RowSlot {
1219 d_block,
1220 b_block,
1221 g_vec,
1222 l_block: Vec::new(),
1223 u_vec: Vec::new(),
1224 y_block: Vec::new(),
1225 log_det_local: 0.0,
1226 bump: None,
1227 tile_partial_schur: None,
1228 tile_partial_rhs: None,
1229 delta_t_block: vec![0.0; d],
1230 });
1231 }
1232
1233 let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1235 forward_tile(ordinal, d, k, tile)
1236 });
1237 if forward_ok.is_none() {
1238 return Err(ArrowSchurGpuFailure::Unavailable);
1239 }
1240
1241 let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
1243 if let Some((row, bump)) = slots
1244 .iter()
1245 .enumerate()
1246 .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
1247 {
1248 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
1249 }
1250
1251 let mut schur_host = vec![0.0_f64; k * k];
1256 for col in 0..k {
1257 for row in 0..k {
1258 let mut v = sys.hbb[[row, col]];
1259 if row == col {
1260 v += ridge_beta;
1261 }
1262 schur_host[col * k + row] = v;
1263 }
1264 }
1265 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1266 let mut log_det = 0.0_f64;
1267 for start in tile_starts(&row_base_of_tile) {
1268 let slot = &slots[start];
1269 let partial_schur = slot
1270 .tile_partial_schur
1271 .as_ref()
1272 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1273 let partial_rhs = slot
1274 .tile_partial_rhs
1275 .as_ref()
1276 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1277 for idx in 0..k * k {
1282 schur_host[idx] += partial_schur[idx];
1283 }
1284 for a in 0..k {
1285 rhs_host[a] += partial_rhs[a];
1286 }
1287 }
1288 for slot in &slots {
1289 log_det += slot.log_det_local;
1290 }
1291
1292 let primary = runtime.selected_device().ordinal;
1296 let stream = gam_gpu::device_runtime::cuda_context_for(primary)
1297 .and_then(|ctx| ctx.new_stream().ok())
1298 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1299 let solver =
1300 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1301 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1302 let mut schur_dev = stream
1303 .clone_htod(&schur_host)
1304 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1305 let mut rhs_dev = stream
1306 .clone_htod(&rhs_host)
1307 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1308 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1309 if info != 0 {
1310 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1311 reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
1312 });
1313 }
1314 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1315 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1316 let delta_beta_host = stream
1317 .clone_dtoh(&rhs_dev)
1318 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1319 let delta_beta = Array1::from_vec(delta_beta_host.clone());
1320 let l_schur_host = stream
1321 .clone_dtoh(&schur_dev)
1322 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1323 for j in 0..k {
1324 log_det += l_schur_host[j * k + j].ln();
1325 }
1326 log_det *= 2.0;
1327
1328 let delta_beta_ref = &delta_beta_host;
1330 let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
1331 back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
1332 });
1333 if back_ok.is_none() {
1334 return Err(ArrowSchurGpuFailure::Unavailable);
1335 }
1336
1337 let mut delta_t = Array1::<f64>::zeros(n * d);
1339 for (i, slot) in slots.iter().enumerate() {
1340 let base = i * d;
1341 for r in 0..d {
1342 delta_t[base + r] = slot.delta_t_block[r];
1343 }
1344 }
1345
1346 Ok(ArrowSchurGpuSolution {
1347 delta_t,
1348 delta_beta,
1349 log_det_hessian: log_det,
1350 })
1351 }
1352
1353 fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1356 tiles.iter().map(|(_, range)| range.start)
1357 }
1358
1359 fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1367 if tile.is_empty() {
1368 return Some(());
1369 }
1370 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1373 .and_then(|ctx| ctx.new_stream().ok())?;
1374 let solver = DnHandle::new(stream.clone()).ok()?;
1375 let blas = CudaBlas::new(stream.clone()).ok()?;
1376 let m = tile.len();
1377
1378 let mut d_host = Vec::with_capacity(m * d * d);
1381 let mut b_host = Vec::with_capacity(m * d * k);
1382 let mut g_host = Vec::with_capacity(m * d);
1383 for slot in tile.iter() {
1384 d_host.extend_from_slice(&slot.d_block);
1385 b_host.extend_from_slice(&slot.b_block);
1386 g_host.extend_from_slice(&slot.g_vec);
1387 }
1388 let mut d_dev = stream.clone_htod(&d_host).ok()?;
1389 let mut b_dev = stream.clone_htod(&b_host).ok()?;
1390 let mut g_dev = stream.clone_htod(&g_host).ok()?;
1391
1392 let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1398 if let Some(local) = info_host.iter().position(|info| *info != 0) {
1399 tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
1400 &tile[local].d_block,
1401 d,
1402 ));
1403 return Some(());
1404 }
1405
1406 trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1408 trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1409
1410 let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1412 let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1413 accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1414
1415 let l_host = stream.clone_dtoh(&d_dev).ok()?;
1417 let u_host = stream.clone_dtoh(&g_dev).ok()?;
1418 let y_host = stream.clone_dtoh(&b_dev).ok()?;
1419 let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1420 let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1421
1422 for (local, slot) in tile.iter_mut().enumerate() {
1423 let l_base = local * d * d;
1424 let u_base = local * d;
1425 let y_base = local * d * k;
1426 slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1427 slot.u_vec = u_host[u_base..u_base + d].to_vec();
1428 slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1429 let mut log_det_local = 0.0_f64;
1430 for j in 0..d {
1431 log_det_local += l_host[l_base + j * d + j].ln();
1432 }
1433 slot.log_det_local = log_det_local;
1434 }
1435 tile[0].tile_partial_schur = Some(partial_schur);
1436 tile[0].tile_partial_rhs = Some(partial_rhs);
1437 Some(())
1438 }
1439
1440 fn back_sub_tile(
1444 ordinal: usize,
1445 d: usize,
1446 k: usize,
1447 delta_beta: &[f64],
1448 tile: &mut [RowSlot],
1449 ) -> Option<()> {
1450 if tile.is_empty() {
1451 return Some(());
1452 }
1453 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1456 .and_then(|ctx| ctx.new_stream().ok())?;
1457 let blas = CudaBlas::new(stream.clone()).ok()?;
1458 let m = tile.len();
1459
1460 let mut l_host = Vec::with_capacity(m * d * d);
1461 let mut u_host = Vec::with_capacity(m * d);
1462 let mut y_host = Vec::with_capacity(m * d * k);
1463 for slot in tile.iter() {
1464 l_host.extend_from_slice(&slot.l_block);
1465 u_host.extend_from_slice(&slot.u_vec);
1466 y_host.extend_from_slice(&slot.y_block);
1467 }
1468 let d_dev = stream.clone_htod(&l_host).ok()?;
1469 let mut g_dev = stream.clone_htod(&u_host).ok()?;
1470 let b_dev = stream.clone_htod(&y_host).ok()?;
1471 let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1472
1473 accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1475 trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1476 let x_host = stream.clone_dtoh(&g_dev).ok()?;
1477 for (local, slot) in tile.iter_mut().enumerate() {
1478 let base = local * d;
1479 for r in 0..d {
1480 slot.delta_t_block[r] = -x_host[base + r];
1481 }
1482 }
1483 Some(())
1484 }
1485
1486 pub(super) fn solve(
1487 sys: &ArrowSchurSystem,
1488 ridge_t: f64,
1489 ridge_beta: f64,
1490 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1491 let n = sys.rows.len();
1492 let d = sys.d;
1493 let k = sys.k;
1494 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1495 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1496
1497 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1498 .and_then(|ctx| ctx.new_stream().ok())
1499 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1500 let solver =
1501 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1502 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1503
1504 let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1506 let mut d_dev = stream
1507 .clone_htod(&d_host)
1508 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1509 let mut b_dev = stream
1510 .clone_htod(&b_host)
1511 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1512 let mut g_dev = stream
1513 .clone_htod(&g_host)
1514 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1515
1516 let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1524 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1525 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1529 row: idx,
1530 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1531 });
1532 }
1533
1534 trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1537 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1540
1541 let schur_init: Vec<f64> = {
1560 let mut tmp = Vec::with_capacity(k * k);
1561 for col in 0..k {
1562 for row in 0..k {
1563 let mut v = sys.hbb[[row, col]];
1564 if row == col {
1565 v += ridge_beta;
1566 }
1567 tmp.push(v);
1568 }
1569 }
1570 tmp
1571 };
1572 let mut schur_dev = stream
1573 .clone_htod(&schur_init)
1574 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1575 let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1576 let mut rhs_dev = stream
1577 .clone_htod(&rhs_init)
1578 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1579
1580 accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1581
1582 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1584 if info != 0 {
1585 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1586 reason: format!("Schur Cholesky failed at pivot {info}"),
1587 });
1588 }
1589 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1591 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1592 let delta_beta_host = stream
1593 .clone_dtoh(&rhs_dev)
1594 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1595 let delta_beta = Array1::from_vec(delta_beta_host.clone());
1596
1597 accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1605 trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1606
1607 let x_host = stream
1608 .clone_dtoh(&g_dev)
1609 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1610 let mut delta_t = Array1::<f64>::zeros(n * d);
1611 for (i, v) in x_host.iter().enumerate() {
1612 delta_t[i] = -*v;
1613 }
1614
1615 let l_local_host = stream
1617 .clone_dtoh(&d_dev)
1618 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1619 let l_schur_host = stream
1620 .clone_dtoh(&schur_dev)
1621 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1622 let mut log_det = 0.0_f64;
1623 for i in 0..n {
1624 let base = i * d * d;
1625 for j in 0..d {
1626 log_det += l_local_host[base + j * d + j].ln();
1627 }
1628 }
1629 for j in 0..k {
1630 log_det += l_schur_host[j * k + j].ln();
1631 }
1632 log_det *= 2.0;
1633
1634 Ok(ArrowSchurGpuSolution {
1635 delta_t,
1636 delta_beta,
1637 log_det_hessian: log_det,
1638 })
1639 }
1640
1641 fn potrf_batched(
1642 solver: &DnHandle,
1643 stream: &Arc<CudaStream>,
1644 p: usize,
1645 batch: usize,
1646 matrices: &mut CudaSlice<f64>,
1647 ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1648 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1649 let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1650 let matrix_len = p * p;
1651 let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1652 let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1653 let mut ptrs = Vec::with_capacity(batch);
1654 for idx in 0..batch {
1655 ptrs.push(base_ptr + (idx as u64) * bytes_per);
1656 }
1657 let mut ptrs_dev = stream
1658 .clone_htod(&ptrs)
1659 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1660 let mut info_dev = stream
1661 .alloc_zeros::<i32>(batch)
1662 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1663 let status = {
1664 let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1665 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1666 unsafe {
1669 cusolver_sys::cusolverDnDpotrfBatched(
1670 solver.cu(),
1671 cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1672 p_i,
1673 ptrs_ptr as *mut *mut f64,
1674 p_i,
1675 info_ptr as *mut i32,
1676 batch_i,
1677 )
1678 }
1679 };
1680 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1681 return Err(ArrowSchurGpuFailure::Unavailable);
1682 }
1683 stream
1684 .clone_dtoh(&info_dev)
1685 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1686 }
1687
1688 fn potrf_single(
1689 solver: &DnHandle,
1690 stream: &Arc<CudaStream>,
1691 p: usize,
1692 matrix: &mut CudaSlice<f64>,
1693 ) -> Result<i32, ArrowSchurGpuFailure> {
1694 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1695 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1696 let mut lwork = 0_i32;
1697 {
1698 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1699 let status = unsafe {
1701 cusolver_sys::cusolverDnDpotrf_bufferSize(
1702 solver.cu(),
1703 uplo,
1704 p_i,
1705 mat_ptr as *mut f64,
1706 p_i,
1707 &mut lwork,
1708 )
1709 };
1710 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1711 return Err(ArrowSchurGpuFailure::Unavailable);
1712 }
1713 }
1714 let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1715 let mut workspace = stream
1716 .alloc_zeros::<f64>(lwork_usize.max(1))
1717 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1718 let mut info_dev = stream
1719 .alloc_zeros::<i32>(1)
1720 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1721 {
1722 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1723 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1724 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1725 let status = unsafe {
1727 cusolver_sys::cusolverDnDpotrf(
1728 solver.cu(),
1729 uplo,
1730 p_i,
1731 mat_ptr as *mut f64,
1732 p_i,
1733 work_ptr as *mut f64,
1734 lwork,
1735 info_ptr as *mut i32,
1736 )
1737 };
1738 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1739 return Err(ArrowSchurGpuFailure::Unavailable);
1740 }
1741 }
1742 let info_host = stream
1743 .clone_dtoh(&info_dev)
1744 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1745 Ok(info_host[0])
1746 }
1747
1748 fn trsm_batched_lower_inplace(
1752 blas: &CudaBlas,
1753 stream: &Arc<CudaStream>,
1754 d: usize,
1755 n: usize,
1756 nrhs: usize,
1757 l_stack: &CudaSlice<f64>,
1758 rhs_stack: &mut CudaSlice<f64>,
1759 ) -> Result<(), ArrowSchurGpuFailure> {
1760 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1761 }
1762
1763 fn trsm_batched_lower_inplace_transposed(
1765 blas: &CudaBlas,
1766 stream: &Arc<CudaStream>,
1767 d: usize,
1768 n: usize,
1769 nrhs: usize,
1770 l_stack: &CudaSlice<f64>,
1771 rhs_stack: &mut CudaSlice<f64>,
1772 ) -> Result<(), ArrowSchurGpuFailure> {
1773 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1774 }
1775
1776 fn trsm_batched_inplace_inner(
1777 blas: &CudaBlas,
1778 stream: &Arc<CudaStream>,
1779 d: usize,
1780 n: usize,
1781 nrhs: usize,
1782 l_stack: &CudaSlice<f64>,
1783 rhs_stack: &mut CudaSlice<f64>,
1784 transposed: bool,
1785 ) -> Result<(), ArrowSchurGpuFailure> {
1786 let alpha = 1.0_f64;
1787 let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1788 let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1789 let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1790 let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1791 let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1792 let (l_base, _l_record) = l_stack.device_ptr(stream);
1793 let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1794 let mut l_ptrs = Vec::with_capacity(n);
1795 let mut rhs_ptrs = Vec::with_capacity(n);
1796 for i in 0..n {
1797 l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1798 rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1799 }
1800 let mut l_ptrs_dev = stream
1801 .clone_htod(&l_ptrs)
1802 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1803 let mut rhs_ptrs_dev = stream
1804 .clone_htod(&rhs_ptrs)
1805 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1806 let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1807 let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1808 let op = if transposed {
1809 cublasOperation_t::CUBLAS_OP_T
1810 } else {
1811 cublasOperation_t::CUBLAS_OP_N
1812 };
1813 let handle = *blas.handle();
1814 let status = unsafe {
1817 cudarc::cublas::sys::cublasDtrsmBatched(
1818 handle,
1819 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1820 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1821 op,
1822 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1823 d_i,
1824 nrhs_i,
1825 &alpha,
1826 l_ptrs_ptr as *const *const f64,
1827 d_i,
1828 rhs_ptrs_ptr as *const *mut f64,
1829 d_i,
1830 batch_i,
1831 )
1832 };
1833 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1834 return Err(ArrowSchurGpuFailure::Unavailable);
1835 }
1836 Ok(())
1837 }
1838
1839 fn trsm_single(
1842 blas: &CudaBlas,
1843 stream: &Arc<CudaStream>,
1844 n: usize,
1845 l: &CudaSlice<f64>,
1846 rhs: &mut CudaSlice<f64>,
1847 upper: bool,
1848 transposed: bool,
1849 ) -> Result<(), ArrowSchurGpuFailure> {
1850 let alpha = 1.0_f64;
1851 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1852 let handle = *blas.handle();
1853 let (l_ptr, _l_rec) = l.device_ptr(stream);
1854 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1855 let status = unsafe {
1857 cudarc::cublas::sys::cublasDtrsm_v2(
1858 handle,
1859 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1860 if upper {
1861 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1862 } else {
1863 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1864 },
1865 if transposed {
1866 cublasOperation_t::CUBLAS_OP_T
1867 } else {
1868 cublasOperation_t::CUBLAS_OP_N
1869 },
1870 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1871 n_i,
1872 1,
1873 &alpha,
1874 l_ptr as *const f64,
1875 n_i,
1876 rhs_ptr as *mut f64,
1877 n_i,
1878 )
1879 };
1880 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1881 return Err(ArrowSchurGpuFailure::Unavailable);
1882 }
1883 Ok(())
1884 }
1885
1886 fn accumulate_schur(
1890 blas: &CudaBlas,
1891 d: usize,
1892 k: usize,
1893 n: usize,
1894 y_stack: &CudaSlice<f64>,
1895 u_stack: &CudaSlice<f64>,
1896 schur: &mut CudaSlice<f64>,
1897 rhs: &mut CudaSlice<f64>,
1898 ) -> Result<(), ArrowSchurGpuFailure> {
1899 let y_block_elems = d * k;
1900 let u_block_elems = d;
1901 for i in 0..n {
1902 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1903 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1904 let gemm_cfg = GemmConfig::<f64> {
1906 transa: cublasOperation_t::CUBLAS_OP_T,
1907 transb: cublasOperation_t::CUBLAS_OP_N,
1908 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1909 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1910 k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1911 alpha: -1.0,
1912 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1913 ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1914 beta: 1.0,
1915 ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1916 };
1917 unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1919 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1920 let gemv_cfg = GemvConfig::<f64> {
1922 trans: cublasOperation_t::CUBLAS_OP_T,
1923 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1924 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1925 alpha: 1.0,
1926 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1927 incx: 1,
1928 beta: 1.0,
1929 incy: 1,
1930 };
1931 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1934 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1935 }
1936 Ok(())
1937 }
1938
1939 fn accumulate_schur_rhs_only(
1947 blas: &CudaBlas,
1948 d: usize,
1949 k: usize,
1950 n: usize,
1951 y_stack: &CudaSlice<f64>,
1952 u_stack: &CudaSlice<f64>,
1953 rhs: &mut CudaSlice<f64>,
1954 ) -> Result<(), ArrowSchurGpuFailure> {
1955 let y_block_elems = d * k;
1956 let u_block_elems = d;
1957 for i in 0..n {
1958 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1959 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1960 let gemv_cfg = GemvConfig::<f64> {
1961 trans: cublasOperation_t::CUBLAS_OP_T,
1962 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1963 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1964 alpha: 1.0,
1965 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1966 incx: 1,
1967 beta: 1.0,
1968 incy: 1,
1969 };
1970 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1973 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1974 }
1975 Ok(())
1976 }
1977
1978 fn accumulate_back_sub_rhs(
1981 blas: &CudaBlas,
1982 d: usize,
1983 k: usize,
1984 n: usize,
1985 y_stack: &CudaSlice<f64>,
1986 delta_beta: &CudaSlice<f64>,
1987 u_stack: &mut CudaSlice<f64>,
1988 ) -> Result<(), ArrowSchurGpuFailure> {
1989 let y_block_elems = d * k;
1990 let u_block_elems = d;
1991 for i in 0..n {
1992 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1993 let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1994 let gemv_cfg = GemvConfig::<f64> {
1995 trans: cublasOperation_t::CUBLAS_OP_N,
1996 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1997 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1998 alpha: 1.0,
1999 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
2000 incx: 1,
2001 beta: 1.0,
2002 incy: 1,
2003 };
2004 unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
2007 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2008 }
2009 Ok(())
2010 }
2011
2012 use std::collections::HashMap;
2028 use std::sync::Mutex;
2029
/// Process-wide cache of loaded fused arrow-Schur CUDA modules.
///
/// Entries are keyed by `FusedModuleCacheKey`; `fused_module_for` reads the key's
/// `p_max` and `r_template` fields to generate the kernel source for a missing entry.
struct FusedModuleCache {
    /// Loaded modules by cache key, guarded by a mutex so lookups and inserts from
    /// multiple threads are serialized.
    modules: Mutex<
        HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
    >,
}
2039
2040 fn fused_module_cache() -> &'static FusedModuleCache {
2041 static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
2042 CACHE.get_or_init(|| FusedModuleCache {
2043 modules: Mutex::new(HashMap::new()),
2044 })
2045 }
2046
2047 fn fused_module_for(
2048 ctx: &Arc<CudaContext>,
2049 key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2050 ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2051 let cache = fused_module_cache();
2052 if let Ok(guard) = cache.modules.lock() {
2053 if let Some(existing) = guard.get(&key) {
2054 return Ok(existing.clone());
2055 }
2056 }
2057 let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2058 key.p_max as usize,
2059 key.r_template as usize,
2060 );
2061 let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).map_err(|err| {
2062 ArrowSchurGpuFailure::SchurFactorFailed {
2063 reason: format!(
2064 "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2065 key.p_max, key.r_template
2066 ),
2067 }
2068 })?;
2069 let module = ctx
2070 .load_module(ptx)
2071 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2072 if let Ok(mut guard) = cache.modules.lock() {
2073 guard.entry(key).or_insert_with(|| module.clone());
2074 }
2075 Ok(module)
2076 }
2077
/// CUDA C source for the element-wise PCG helpers and the SAE matvec kernels; it is
/// compiled by `pcg_vector_module` under the module name `"arrow_pcg_vector"`.
///
/// Kernels defined here, by entry-point name:
/// * `arrow_pcg_jacobi_mul` — `z[i] = inv_diag[i] * r[i]`.
/// * `arrow_pcg_update_p` — `p[i] = z[i] + beta * p[i]`.
/// * `arrow_sae_init` — `out[i] = ridge * x[i]`.
/// * `arrow_sae_smooth_matvec`, `arrow_sae_sparse_g_matvec` — per-block dense products
///   accumulated into `out` (the sparse-G variant uses `atomicAdd`).
/// * `arrow_sae_gather_u`, `arrow_sae_apply_l`, `arrow_sae_apply_ainv`,
///   `arrow_sae_scatter_sub`, `arrow_sae_diag_sub` — the per-row gather / apply /
///   scatter stages and the matching diagonal correction.
/// * `arrow_sae_frame_*` — the frame ("frames-engaged") counterparts of the above.
///
/// The text is handed to the runtime compiler verbatim, so the comments inside the
/// string are part of the compiled input.
const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri; // basis row in atom i
    int a = linear - li * ri; // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 * h_i = H_tβ^(i) · x (length q_i)
 * s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i (apply cached ainv, length q_i)
 * out -= (H_tβ^(i))ᵀ · s_i (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2510
2511 fn pcg_vector_module(
2512 ctx: &Arc<CudaContext>,
2513 ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2514 static CACHE: gam_gpu::device_cache::PtxModuleCache =
2515 gam_gpu::device_cache::PtxModuleCache::new();
2516 CACHE
2517 .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2518 .map_err(|err| {
2519 log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2525 ArrowSchurGpuFailure::Unavailable
2526 })
2527 }
2528
2529 fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2530 let threads = 256u32;
2531 let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2532 Ok(LaunchConfig {
2533 grid_dim: (blocks, 1, 1),
2534 block_dim: (threads, 1, 1),
2535 shared_mem_bytes: 0,
2536 })
2537 }
2538
2539 fn launch_jacobi_mul(
2540 stream: &Arc<CudaStream>,
2541 module: &Arc<CudaModule>,
2542 inv_diag: &CudaSlice<f64>,
2543 r: &CudaSlice<f64>,
2544 z: &mut CudaSlice<f64>,
2545 n: usize,
2546 ) -> Result<(), ArrowSchurGpuFailure> {
2547 let kernel = module
2548 .load_function("arrow_pcg_jacobi_mul")
2549 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2550 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2551 let mut builder = stream.launch_builder(&kernel);
2552 builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2553 unsafe { builder.launch(pcg_launch_config(n)?) }
2556 .map(drop)
2557 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2558 }
2559
2560 fn launch_update_p(
2561 stream: &Arc<CudaStream>,
2562 module: &Arc<CudaModule>,
2563 z: &CudaSlice<f64>,
2564 beta: f64,
2565 p: &mut CudaSlice<f64>,
2566 n: usize,
2567 ) -> Result<(), ArrowSchurGpuFailure> {
2568 let kernel = module
2569 .load_function("arrow_pcg_update_p")
2570 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2571 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2572 let mut builder = stream.launch_builder(&kernel);
2573 builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2574 unsafe { builder.launch(pcg_launch_config(n)?) }
2577 .map(drop)
2578 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2579 }
2580
/// Device-resident buffers and dimensions for the SAE PCG kernels.
///
/// The field names mirror the parameter names of the `arrow_sae_*` kernels.
/// NOTE(review): the code that fills this struct is not fully visible from here, so the
/// per-field notes below are inferred from those kernel signatures — confirm against
/// `flatten_device_sae_data`.
struct DeviceSaePcgBuffers {
    // Per-row offsets into `beta_base` / `phi` (one more entry than there are rows).
    row_ptr: CudaSlice<i32>,
    // Presumably the beta offset of each (row, entry) pair; paired with `phi`.
    beta_base: CudaSlice<i32>,
    // Presumably the weight of each (row, entry) pair; paired with `beta_base`.
    phi: CudaSlice<f64>,
    // Per-row offsets into `jac` (one more entry than there are rows).
    jac_ptr: CudaSlice<i32>,
    // Concatenated per-row local Jacobian values.
    jac: CudaSlice<f64>,
    // Smooth-penalty blocks: presumably offset, size, data offset and data per block.
    smooth_offsets: CudaSlice<i32>,
    smooth_m: CudaSlice<i32>,
    smooth_ptr: CudaSlice<i32>,
    smooth_data: CudaSlice<f64>,
    // Sparse G blocks: presumably row/column offsets, block shape, data offset and data.
    g_row_off: CudaSlice<i32>,
    g_col_off: CudaSlice<i32>,
    g_rows: CudaSlice<i32>,
    g_cols: CudaSlice<i32>,
    g_ptr: CudaSlice<i32>,
    g_data: CudaSlice<f64>,
    // Presumably the cached per-row inverse blocks read by the `*_apply_ainv` kernels.
    ainv: CudaSlice<f64>,
    // Presumably per-row scratch vectors passed between the gather/apply/scatter stages.
    u: CudaSlice<f64>,
    w: CudaSlice<f64>,
    v: CudaSlice<f64>,
    // Problem dimensions kept alongside the buffers.
    n_rows: usize,
    p: usize,
    k: usize,
    max_q: usize,
    // Presumably the number of smooth blocks and sparse G blocks respectively.
    smooth_blocks: usize,
    g_blocks: usize,
}
2608
2609 fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2610 to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2611 }
2612
2613 fn sae_penalty_diag_host(
2614 data: &DeviceSaePcgData,
2615 ridge_beta: f64,
2616 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2617 let mut diag = vec![ridge_beta; data.beta_dim];
2618 for block in &data.smooth_blocks {
2619 let (rows, cols) = block.factor_a.dim();
2620 if rows != cols {
2621 return Err(ArrowSchurGpuFailure::Unavailable);
2622 }
2623 for row in 0..rows {
2624 let coeff = block.factor_a[[row, row]];
2625 let base = block
2626 .global_offset
2627 .checked_add(
2628 row.checked_mul(data.p)
2629 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2630 )
2631 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2632 let end = base
2633 .checked_add(data.p)
2634 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2635 if end > diag.len() {
2636 return Err(ArrowSchurGpuFailure::Unavailable);
2637 }
2638 for channel in 0..data.p {
2639 diag[base + channel] += coeff;
2640 }
2641 }
2642 }
2643 for block in &data.sparse_g_blocks {
2644 if block.row_off != block.col_off {
2645 continue;
2646 }
2647 let (rows, cols) = block.data.dim();
2648 for row in 0..rows.min(cols) {
2649 let coeff = block.data[[row, row]];
2650 let beta_row = block
2651 .row_off
2652 .checked_add(row)
2653 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2654 let base = beta_row
2655 .checked_mul(data.p)
2656 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2657 let end = base
2658 .checked_add(data.p)
2659 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2660 if end > diag.len() {
2661 return Err(ArrowSchurGpuFailure::Unavailable);
2662 }
2663 for channel in 0..data.p {
2664 diag[base + channel] += coeff;
2665 }
2666 }
2667 }
2668 Ok(diag)
2669 }
2670
2671 fn flatten_device_sae_data(
2672 sys: &ArrowSchurSystem,
2673 data: &DeviceSaePcgData,
2674 ridge_t: f64,
2675 stream: &Arc<CudaStream>,
2676 ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2677 let n_rows = sys.rows.len();
2678 let p = data.p;
2679 let k = data.beta_dim;
2680 if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2681 return Err(ArrowSchurGpuFailure::Unavailable);
2682 }
2683
2684 let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2685 let mut beta_base_host = Vec::<i32>::new();
2686 let mut phi_host = Vec::<f64>::new();
2687 row_ptr_host.push(0_i32);
2688 for row in data.a_phi.iter() {
2689 for &(base, phi) in row {
2690 beta_base_host.push(checked_i32(base)?);
2691 phi_host.push(phi);
2692 }
2693 row_ptr_host.push(checked_i32(beta_base_host.len())?);
2694 }
2695
2696 let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2697 let mut jac_host = Vec::<f64>::new();
2698 let mut max_q = 0usize;
2699 jac_ptr_host.push(0_i32);
2700 for row_jac in data.local_jac.iter() {
2701 if row_jac.len() % p != 0 {
2702 return Err(ArrowSchurGpuFailure::Unavailable);
2703 }
2704 max_q = max_q.max(row_jac.len() / p);
2705 jac_host.extend_from_slice(row_jac);
2706 jac_ptr_host.push(checked_i32(jac_host.len())?);
2707 }
2708 if max_q == 0 {
2709 return Err(ArrowSchurGpuFailure::Unavailable);
2710 }
2711
2712 let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2713 let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2714 let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2715 let mut smooth_data_host = Vec::<f64>::new();
2716 smooth_ptr_host.push(0_i32);
2717 for block in &data.smooth_blocks {
2718 let (rows, cols) = block.factor_a.dim();
2719 if rows != cols {
2720 return Err(ArrowSchurGpuFailure::Unavailable);
2721 }
2722 smooth_offsets_host.push(checked_i32(block.global_offset)?);
2723 smooth_m_host.push(checked_i32(rows)?);
2724 for r in 0..rows {
2725 for c in 0..cols {
2726 smooth_data_host.push(block.factor_a[[r, c]]);
2727 }
2728 }
2729 smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2730 }
2731
2732 let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2733 let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2734 let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2735 let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2736 let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2737 let mut g_data_host = Vec::<f64>::new();
2738 g_ptr_host.push(0_i32);
2739 for block in &data.sparse_g_blocks {
2740 let (rows, cols) = block.data.dim();
2741 g_row_off_host.push(checked_i32(block.row_off)?);
2742 g_col_off_host.push(checked_i32(block.col_off)?);
2743 g_rows_host.push(checked_i32(rows)?);
2744 g_cols_host.push(checked_i32(cols)?);
2745 for r in 0..rows {
2746 for c in 0..cols {
2747 g_data_host.push(block.data[[r, c]]);
2748 }
2749 }
2750 g_ptr_host.push(checked_i32(g_data_host.len())?);
2751 }
2752
2753 let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2754 for (row_idx, row) in sys.rows.iter().enumerate() {
2755 let q = data.local_jac[row_idx].len() / p;
2756 if row.htt.dim() != (q, q) {
2757 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2758 reason: format!(
2759 "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2760 row.htt.dim()
2761 ),
2762 });
2763 }
2764 let mut block = row.htt.clone();
2765 for d in 0..q {
2766 block[[d, d]] += ridge_t;
2767 }
2768 let factor = gam_linalg::triangular::cholesky_factor_in_place(
2769 block.view(),
2770 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2771 )
2772 .ok_or_else(|| {
2773 ArrowSchurGpuFailure::RidgeBumpRequired {
2776 row: row_idx,
2777 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2778 }
2779 })?;
2780 for col in 0..q {
2781 let mut e = Array1::<f64>::zeros(q);
2782 e[col] = 1.0;
2783 let solved =
2784 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2785 for r in 0..q {
2786 ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2787 }
2788 }
2789 }
2790
2791 Ok(DeviceSaePcgBuffers {
2792 row_ptr: stream
2793 .clone_htod(&row_ptr_host)
2794 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795 beta_base: stream
2796 .clone_htod(&beta_base_host)
2797 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798 phi: stream
2799 .clone_htod(&phi_host)
2800 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801 jac_ptr: stream
2802 .clone_htod(&jac_ptr_host)
2803 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804 jac: stream
2805 .clone_htod(&jac_host)
2806 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2807 smooth_offsets: stream
2808 .clone_htod(&smooth_offsets_host)
2809 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2810 smooth_m: stream
2811 .clone_htod(&smooth_m_host)
2812 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2813 smooth_ptr: stream
2814 .clone_htod(&smooth_ptr_host)
2815 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2816 smooth_data: stream
2817 .clone_htod(&smooth_data_host)
2818 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2819 g_row_off: stream
2820 .clone_htod(&g_row_off_host)
2821 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2822 g_col_off: stream
2823 .clone_htod(&g_col_off_host)
2824 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2825 g_rows: stream
2826 .clone_htod(&g_rows_host)
2827 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2828 g_cols: stream
2829 .clone_htod(&g_cols_host)
2830 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2831 g_ptr: stream
2832 .clone_htod(&g_ptr_host)
2833 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2834 g_data: stream
2835 .clone_htod(&g_data_host)
2836 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2837 ainv: stream
2838 .clone_htod(&ainv_host)
2839 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2840 u: stream
2841 .alloc_zeros::<f64>(n_rows * p)
2842 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2843 w: stream
2844 .alloc_zeros::<f64>(n_rows * max_q)
2845 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2846 v: stream
2847 .alloc_zeros::<f64>(n_rows * max_q)
2848 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2849 n_rows,
2850 p,
2851 k,
2852 max_q,
2853 smooth_blocks: data.smooth_blocks.len(),
2854 g_blocks: data.sparse_g_blocks.len(),
2855 })
2856 }
2857
2858 fn launch_sae_init(
2859 stream: &Arc<CudaStream>,
2860 module: &Arc<CudaModule>,
2861 out: &mut CudaSlice<f64>,
2862 x: &CudaSlice<f64>,
2863 ridge: f64,
2864 n: usize,
2865 ) -> Result<(), ArrowSchurGpuFailure> {
2866 let kernel = module
2867 .load_function("arrow_sae_init")
2868 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2869 let n_i32 = checked_i32(n)?;
2870 let mut builder = stream.launch_builder(&kernel);
2871 builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2872 unsafe { builder.launch(pcg_launch_config(n)?) }
2876 .map(drop)
2877 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2878 }
2879
2880 fn launch_sae_penalty_matvec(
2881 stream: &Arc<CudaStream>,
2882 module: &Arc<CudaModule>,
2883 buffers: &mut DeviceSaePcgBuffers,
2884 x: &CudaSlice<f64>,
2885 out: &mut CudaSlice<f64>,
2886 ridge_beta: f64,
2887 ) -> Result<(), ArrowSchurGpuFailure> {
2888 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2889 if buffers.smooth_blocks > 0 {
2890 let kernel = module
2891 .load_function("arrow_sae_smooth_matvec")
2892 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2893 let max_m = buffers.k;
2894 let p_i32 = checked_i32(buffers.p)?;
2895 let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2896 let cfg = LaunchConfig {
2897 grid_dim: (
2898 ((max_m as u32).saturating_add(255) / 256).max(1),
2899 checked_i32(buffers.smooth_blocks)? as u32,
2900 1,
2901 ),
2902 block_dim: (256, 1, 1),
2903 shared_mem_bytes: 0,
2904 };
2905 let mut builder = stream.launch_builder(&kernel);
2906 builder
2907 .arg(x)
2908 .arg(&mut *out)
2909 .arg(&buffers.smooth_offsets)
2910 .arg(&buffers.smooth_m)
2911 .arg(&buffers.smooth_ptr)
2912 .arg(&buffers.smooth_data)
2913 .arg(&p_i32)
2914 .arg(&blocks_i32);
2915 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2920 }
2921 if buffers.g_blocks > 0 {
2922 let kernel = module
2923 .load_function("arrow_sae_sparse_g_matvec")
2924 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2925 let max_work = buffers
2926 .k
2927 .checked_div(buffers.p)
2928 .unwrap_or(0)
2929 .saturating_mul(buffers.p);
2930 let p_i32 = checked_i32(buffers.p)?;
2931 let blocks_i32 = checked_i32(buffers.g_blocks)?;
2932 let cfg = LaunchConfig {
2933 grid_dim: (
2934 ((max_work as u32).saturating_add(255) / 256).max(1),
2935 checked_i32(buffers.g_blocks)? as u32,
2936 1,
2937 ),
2938 block_dim: (256, 1, 1),
2939 shared_mem_bytes: 0,
2940 };
2941 let mut builder = stream.launch_builder(&kernel);
2942 builder
2943 .arg(x)
2944 .arg(&mut *out)
2945 .arg(&buffers.g_row_off)
2946 .arg(&buffers.g_col_off)
2947 .arg(&buffers.g_rows)
2948 .arg(&buffers.g_cols)
2949 .arg(&buffers.g_ptr)
2950 .arg(&buffers.g_data)
2951 .arg(&p_i32)
2952 .arg(&blocks_i32);
2953 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2958 }
2959 Ok(())
2960 }
2961
2962 fn launch_sae_row_schur_sub(
2963 stream: &Arc<CudaStream>,
2964 module: &Arc<CudaModule>,
2965 buffers: &mut DeviceSaePcgBuffers,
2966 x: &CudaSlice<f64>,
2967 out: &mut CudaSlice<f64>,
2968 ) -> Result<(), ArrowSchurGpuFailure> {
2969 let p_i32 = checked_i32(buffers.p)?;
2970 let max_q_i32 = checked_i32(buffers.max_q)?;
2971 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2972 let cfg_p_rows = LaunchConfig {
2973 grid_dim: (
2974 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2975 checked_i32(buffers.n_rows)? as u32,
2976 1,
2977 ),
2978 block_dim: (256, 1, 1),
2979 shared_mem_bytes: 0,
2980 };
2981 let gather = module
2982 .load_function("arrow_sae_gather_u")
2983 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2984 {
2985 let mut builder = stream.launch_builder(&gather);
2986 builder
2987 .arg(x)
2988 .arg(&buffers.row_ptr)
2989 .arg(&buffers.beta_base)
2990 .arg(&buffers.phi)
2991 .arg(&mut buffers.u)
2992 .arg(&p_i32)
2993 .arg(&n_rows_i32);
2994 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2998 }
2999
3000 let cfg_q_rows = LaunchConfig {
3001 grid_dim: (
3002 ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
3003 checked_i32(buffers.n_rows)? as u32,
3004 1,
3005 ),
3006 block_dim: (256, 1, 1),
3007 shared_mem_bytes: 0,
3008 };
3009 let apply_l = module
3010 .load_function("arrow_sae_apply_l")
3011 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3012 {
3013 let mut builder = stream.launch_builder(&apply_l);
3014 builder
3015 .arg(&buffers.u)
3016 .arg(&buffers.jac_ptr)
3017 .arg(&buffers.jac)
3018 .arg(&mut buffers.w)
3019 .arg(&p_i32)
3020 .arg(&max_q_i32)
3021 .arg(&n_rows_i32);
3022 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3026 }
3027
3028 let apply_ainv = module
3029 .load_function("arrow_sae_apply_ainv")
3030 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3031 {
3032 let mut builder = stream.launch_builder(&apply_ainv);
3033 builder
3034 .arg(&buffers.ainv)
3035 .arg(&buffers.w)
3036 .arg(&mut buffers.v)
3037 .arg(&max_q_i32)
3038 .arg(&n_rows_i32);
3039 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3043 }
3044
3045 let scatter = module
3046 .load_function("arrow_sae_scatter_sub")
3047 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3048 {
3049 let mut builder = stream.launch_builder(&scatter);
3050 builder
3051 .arg(&buffers.v)
3052 .arg(&buffers.jac_ptr)
3053 .arg(&buffers.jac)
3054 .arg(&buffers.row_ptr)
3055 .arg(&buffers.beta_base)
3056 .arg(&buffers.phi)
3057 .arg(out)
3058 .arg(&p_i32)
3059 .arg(&max_q_i32)
3060 .arg(&n_rows_i32);
3061 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3065 }
3066 Ok(())
3067 }
3068
3069 fn launch_sae_diag_sub(
3070 stream: &Arc<CudaStream>,
3071 module: &Arc<CudaModule>,
3072 buffers: &DeviceSaePcgBuffers,
3073 diag: &mut CudaSlice<f64>,
3074 ) -> Result<(), ArrowSchurGpuFailure> {
3075 let kernel = module
3076 .load_function("arrow_sae_diag_sub")
3077 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3078 let p_i32 = checked_i32(buffers.p)?;
3079 let max_q_i32 = checked_i32(buffers.max_q)?;
3080 let n_rows_i32 = checked_i32(buffers.n_rows)?;
3081 let cfg = LaunchConfig {
3082 grid_dim: (
3083 ((buffers.p as u32).saturating_add(255) / 256).max(1),
3084 checked_i32(buffers.n_rows)? as u32,
3085 1,
3086 ),
3087 block_dim: (256, 1, 1),
3088 shared_mem_bytes: 0,
3089 };
3090 let mut builder = stream.launch_builder(&kernel);
3091 builder
3092 .arg(diag)
3093 .arg(&buffers.ainv)
3094 .arg(&buffers.jac_ptr)
3095 .arg(&buffers.jac)
3096 .arg(&buffers.row_ptr)
3097 .arg(&buffers.beta_base)
3098 .arg(&buffers.phi)
3099 .arg(&p_i32)
3100 .arg(&max_q_i32)
3101 .arg(&n_rows_i32);
3102 unsafe { builder.launch(cfg) }
3106 .map(drop)
3107 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3108 }
3109
3110 fn launch_sae_matvec(
3111 stream: &Arc<CudaStream>,
3112 module: &Arc<CudaModule>,
3113 buffers: &mut DeviceSaePcgBuffers,
3114 x: &CudaSlice<f64>,
3115 out: &mut CudaSlice<f64>,
3116 ridge_beta: f64,
3117 ) -> Result<(), ArrowSchurGpuFailure> {
3118 launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3119 launch_sae_row_schur_sub(stream, module, buffers, x, out)
3120 }
3121
3122 fn pack_fused_host(
3127 sys: &ArrowSchurSystem,
3128 ridge_t: f64,
3129 p_max: usize,
3130 r_template: usize,
3131 ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3132 let n = sys.rows.len();
3133 let d = sys.d;
3134 let k = sys.k;
3135 let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3136 let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3137 let mut g_buf = vec![0.0_f64; n * p_max];
3138 for (i, row) in sys.rows.iter().enumerate() {
3139 for col in 0..d {
3141 let base = (i * p_max + col) * p_max;
3142 for r in 0..d {
3143 let mut value = row.htt[[r, col]];
3144 if r == col {
3145 value += ridge_t;
3146 }
3147 d_buf[base + r] = value;
3148 }
3149 }
3150 for col in 0..k {
3158 let base = (i * r_template + col) * p_max;
3159 for r in 0..d {
3160 b_buf[base + r] = row.htbeta[[r, col]];
3161 }
3162 }
3163 let g_base = i * p_max;
3165 for r in 0..d {
3166 g_buf[g_base + r] = row.gt[r];
3167 }
3168 }
3169 (d_buf, b_buf, g_buf)
3170 }
3171
3172 pub(super) struct ResidentArrowFrame {
3199 n: usize,
3200 d: usize,
3201 k: usize,
3202 stream: Arc<CudaStream>,
3203 blas: CudaBlas,
3204 l_dev: CudaSlice<f64>,
3207 y_dev: CudaSlice<f64>,
3210 schur_dev: CudaSlice<f64>,
3213 log_det_hessian: f64,
3216 }
3217
3218 impl ResidentArrowFrame {
3219 pub(super) fn new(
3223 sys: &ArrowSchurSystem,
3224 ridge_t: f64,
3225 ridge_beta: f64,
3226 ) -> Result<Self, ArrowSchurGpuFailure> {
3227 if ridge_t.is_nan() || ridge_beta.is_nan() {
3228 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3229 reason: "ridge is NaN".to_string(),
3230 });
3231 }
3232 let n = sys.rows.len();
3233 let d = sys.d;
3234 let k = sys.k;
3235 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3236 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3237 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3238 .and_then(|ctx| ctx.new_stream().ok())
3239 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3240 let solver =
3241 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3242 let blas =
3243 CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3244
3245 let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3247 let mut l_dev = stream
3248 .clone_htod(&d_host)
3249 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3250 let mut y_dev = stream
3251 .clone_htod(&b_host)
3252 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3253
3254 let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3256 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3257 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3261 row: idx,
3262 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3263 });
3264 }
3265
3266 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3268
3269 let schur_init: Vec<f64> = {
3274 let mut tmp = Vec::with_capacity(k * k);
3275 for col in 0..k {
3276 for row in 0..k {
3277 let mut v = sys.hbb[[row, col]];
3278 if row == col {
3279 v += ridge_beta;
3280 }
3281 tmp.push(v);
3282 }
3283 }
3284 tmp
3285 };
3286 let mut schur_dev = stream
3287 .clone_htod(&schur_init)
3288 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3289 let zero_u = stream
3292 .clone_htod(&vec![0.0_f64; n * d])
3293 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3294 let mut throwaway_rhs = stream
3295 .clone_htod(&vec![0.0_f64; k])
3296 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3297 accumulate_schur(
3298 &blas,
3299 d,
3300 k,
3301 n,
3302 &y_dev,
3303 &zero_u,
3304 &mut schur_dev,
3305 &mut throwaway_rhs,
3306 )?;
3307 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3308 if info != 0 {
3309 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3310 reason: format!("Schur Cholesky failed at pivot {info}"),
3311 });
3312 }
3313
3314 let l_local_host = stream
3316 .clone_dtoh(&l_dev)
3317 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3318 let l_schur_host = stream
3319 .clone_dtoh(&schur_dev)
3320 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3321 let mut log_det = 0.0_f64;
3322 for i in 0..n {
3323 let base = i * d * d;
3324 for j in 0..d {
3325 log_det += l_local_host[base + j * d + j].ln();
3326 }
3327 }
3328 for j in 0..k {
3329 log_det += l_schur_host[j * k + j].ln();
3330 }
3331 log_det *= 2.0;
3332
3333 Ok(Self {
3334 n,
3335 d,
3336 k,
3337 stream,
3338 blas,
3339 l_dev,
3340 y_dev,
3341 schur_dev,
3342 log_det_hessian: log_det,
3343 })
3344 }
3345
3346 #[inline]
3347 pub(super) fn log_det_hessian(&self) -> f64 {
3348 self.log_det_hessian
3349 }
3350
3351 pub(super) fn solve_gradient(
3355 &self,
3356 g_t: &[f64],
3357 g_beta: &[f64],
3358 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3359 let n = self.n;
3360 let d = self.d;
3361 let k = self.k;
3362 if g_t.len() != n * d || g_beta.len() != k {
3363 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3364 reason: format!(
3365 "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3366 g_t.len(),
3367 n * d,
3368 g_beta.len(),
3369 k
3370 ),
3371 });
3372 }
3373 let mut u_dev = self
3375 .stream
3376 .clone_htod(&g_t.to_vec())
3377 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3378 trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3379
3380 let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3383 let mut rhs_dev = self
3384 .stream
3385 .clone_htod(&rhs_init)
3386 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3387 accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3388
3389 trsm_single(
3391 &self.blas,
3392 &self.stream,
3393 k,
3394 &self.schur_dev,
3395 &mut rhs_dev,
3396 false,
3397 false,
3398 )?;
3399 trsm_single(
3400 &self.blas,
3401 &self.stream,
3402 k,
3403 &self.schur_dev,
3404 &mut rhs_dev,
3405 false,
3406 true,
3407 )?;
3408 let delta_beta_host = self
3409 .stream
3410 .clone_dtoh(&rhs_dev)
3411 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3412 let delta_beta = Array1::from_vec(delta_beta_host);
3413
3414 accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3416 trsm_batched_lower_inplace_transposed(
3417 &self.blas,
3418 &self.stream,
3419 d,
3420 n,
3421 1,
3422 &self.l_dev,
3423 &mut u_dev,
3424 )?;
3425 let x_host = self
3426 .stream
3427 .clone_dtoh(&u_dev)
3428 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3429 let mut delta_t = Array1::<f64>::zeros(n * d);
3430 for (i, v) in x_host.iter().enumerate() {
3431 delta_t[i] = -*v;
3432 }
3433
3434 Ok(ArrowSchurGpuSolution {
3435 delta_t,
3436 delta_beta,
3437 log_det_hessian: self.log_det_hessian,
3438 })
3439 }
3440 }
3441
3442 pub(super) fn solve_fused(
3443 sys: &ArrowSchurSystem,
3444 ridge_t: f64,
3445 ridge_beta: f64,
3446 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3447 let n = sys.rows.len();
3448 let d = sys.d;
3449 let k = sys.k;
3450 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3451 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3452 let p_max = plan.p_max;
3453 let r_template = plan.r_template;
3454
3455 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3456 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3457 )
3458 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3459 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3460 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3461 let stream = ctx
3462 .new_stream()
3463 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3464 let cap = &runtime.device.capability;
3465 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3466 cc_major: cap.compute_major,
3467 cc_minor: cap.compute_minor,
3468 p_max: p_max as u32,
3469 r_template: r_template as u32,
3470 };
3471 let module = fused_module_for(&ctx, key)?;
3472 let forward = module
3473 .load_function("arrow_schur_forward_pgroup")
3474 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3475 let back_sub = module
3476 .load_function("arrow_schur_back_sub_pgroup")
3477 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3478
3479 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3481 let d_dev = stream
3482 .clone_htod(&d_host)
3483 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3484 let b_dev = stream
3485 .clone_htod(&b_host)
3486 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487 let g_dev = stream
3488 .clone_htod(&g_host)
3489 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3490 let mut l_out = stream
3491 .alloc_zeros::<f64>(n * p_max * p_max)
3492 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3493 let mut u_out = stream
3494 .alloc_zeros::<f64>(n * p_max)
3495 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3496 let mut y_out = stream
3497 .alloc_zeros::<f64>(n * p_max * r_template)
3498 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3499 let mut partial_s = stream
3500 .alloc_zeros::<f64>(plan.partial_s_doubles)
3501 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3502 let mut partial_r = stream
3503 .alloc_zeros::<f64>(plan.partial_r_doubles)
3504 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3505 let mut status_dev = stream
3506 .alloc_zeros::<i32>(n)
3507 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3508
3509 let cfg = LaunchConfig {
3511 grid_dim: (plan.blocks, 1, 1),
3512 block_dim: (plan.threads_per_block, 1, 1),
3513 shared_mem_bytes: 0,
3514 };
3515 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3516 let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3517 let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3518 let ridge_arg = ridge_t;
3519 {
3520 let mut builder = stream.launch_builder(&forward);
3521 builder
3522 .arg(&d_dev)
3523 .arg(&b_dev)
3524 .arg(&g_dev)
3525 .arg(&n_i32)
3526 .arg(&p_i32)
3527 .arg(&r_i32)
3528 .arg(&ridge_arg)
3529 .arg(&mut l_out)
3530 .arg(&mut u_out)
3531 .arg(&mut y_out)
3532 .arg(&mut partial_s)
3533 .arg(&mut partial_r)
3534 .arg(&mut status_dev);
3535 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3539 }
3540 stream
3541 .synchronize()
3542 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3543
3544 let status_host = stream
3546 .clone_dtoh(&status_dev)
3547 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3548 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3549 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3553 row,
3554 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3555 });
3556 }
3557
3558 let partial_s_host = stream
3560 .clone_dtoh(&partial_s)
3561 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3562 let partial_r_host = stream
3563 .clone_dtoh(&partial_r)
3564 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3565 let mut schur_host = vec![0.0_f64; k * k];
3566 for col in 0..k {
3567 for row in 0..k {
3568 let mut v = sys.hbb[[row, col]];
3569 if row == col {
3570 v += ridge_beta;
3571 }
3572 schur_host[col * k + row] = v;
3573 }
3574 }
3575 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3576 for i in 0..n {
3577 let s_base = i * r_template * r_template;
3580 for col in 0..k {
3581 let col_base = s_base + col * r_template;
3582 let dst_col_base = col * k;
3583 for row in 0..k {
3584 schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3585 }
3586 }
3587 let r_base = i * r_template;
3588 for a in 0..k {
3589 rhs_host[a] += partial_r_host[r_base + a];
3590 }
3591 }
3592
3593 let mut schur_dev = stream
3595 .clone_htod(&schur_host)
3596 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3597 let mut rhs_dev = stream
3598 .clone_htod(&rhs_host)
3599 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3600 let solver =
3601 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3602 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3603 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3604 if info != 0 {
3605 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3606 reason: format!("fused Schur Cholesky failed at pivot {info}"),
3607 });
3608 }
3609 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3610 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3611 let delta_beta_host = stream
3612 .clone_dtoh(&rhs_dev)
3613 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3614 let delta_beta = Array1::from_vec(delta_beta_host.clone());
3615
3616 let mut delta_t_dev = stream
3618 .alloc_zeros::<f64>(n * p_max)
3619 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3620 let back_cfg = LaunchConfig {
3621 grid_dim: (plan.blocks, 1, 1),
3622 block_dim: (plan.threads_per_block, 1, 1),
3623 shared_mem_bytes: 0,
3624 };
3625 {
3626 let mut builder = stream.launch_builder(&back_sub);
3627 builder
3628 .arg(&l_out)
3629 .arg(&u_out)
3630 .arg(&y_out)
3631 .arg(&rhs_dev)
3632 .arg(&n_i32)
3633 .arg(&p_i32)
3634 .arg(&r_i32)
3635 .arg(&mut delta_t_dev);
3636 unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3640 }
3641 stream
3642 .synchronize()
3643 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3644
3645 let delta_t_host = stream
3646 .clone_dtoh(&delta_t_dev)
3647 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3648 let mut delta_t = Array1::<f64>::zeros(n * d);
3649 for i in 0..n {
3650 let src_base = i * p_max;
3651 let dst_base = i * d;
3652 for r in 0..d {
3653 delta_t[dst_base + r] = delta_t_host[src_base + r];
3654 }
3655 }
3656
3657 let l_local_host = stream
3659 .clone_dtoh(&l_out)
3660 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3661 let l_schur_host = stream
3662 .clone_dtoh(&schur_dev)
3663 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3664 let mut log_det = 0.0_f64;
3665 for i in 0..n {
3666 let base = i * p_max * p_max;
3667 for j in 0..d {
3668 log_det += l_local_host[base + j * p_max + j].ln();
3669 }
3670 }
3671 for j in 0..k {
3672 log_det += l_schur_host[j * k + j].ln();
3673 }
3674 log_det *= 2.0;
3675
3676 Ok(ArrowSchurGpuSolution {
3677 delta_t,
3678 delta_beta,
3679 log_det_hessian: log_det,
3680 })
3681 }
3682
3683 pub(super) fn build_schur_matvec_backend(
3693 sys: &ArrowSchurSystem,
3694 ridge_t: f64,
3695 ridge_beta: f64,
3696 ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
3697 let n = sys.rows.len();
3698 let d = sys.d;
3699 let k = sys.k;
3700 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3701 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3702 let p_max = plan.p_max;
3703 let r_template = plan.r_template;
3704
3705 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3706 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3707 )
3708 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3709 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3710 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3711 let stream = ctx
3712 .new_stream()
3713 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3714 let cap = &runtime.device.capability;
3715 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3716 cc_major: cap.compute_major,
3717 cc_minor: cap.compute_minor,
3718 p_max: p_max as u32,
3719 r_template: r_template as u32,
3720 };
3721 let module = fused_module_for(&ctx, key)?;
3722 let forward = module
3723 .load_function("arrow_schur_forward_pgroup")
3724 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3725
3726 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3727 let d_dev = stream
3728 .clone_htod(&d_host)
3729 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3730 let b_dev = stream
3731 .clone_htod(&b_host)
3732 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3733 let g_dev = stream
3734 .clone_htod(&g_host)
3735 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3736 let mut l_out = stream
3737 .alloc_zeros::<f64>(n * p_max * p_max)
3738 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3739 let mut u_out = stream
3740 .alloc_zeros::<f64>(n * p_max)
3741 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3742 let mut y_out = stream
3743 .alloc_zeros::<f64>(n * p_max * r_template)
3744 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3745 let mut partial_s = stream
3746 .alloc_zeros::<f64>(plan.partial_s_doubles)
3747 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3748 let mut partial_r = stream
3749 .alloc_zeros::<f64>(plan.partial_r_doubles)
3750 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3751 let mut status_dev = stream
3752 .alloc_zeros::<i32>(n)
3753 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3754
3755 let cfg = LaunchConfig {
3756 grid_dim: (plan.blocks, 1, 1),
3757 block_dim: (plan.threads_per_block, 1, 1),
3758 shared_mem_bytes: 0,
3759 };
3760 let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3761 let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3762 let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3763 let ridge_arg = ridge_t;
3764 {
3765 let mut builder = stream.launch_builder(&forward);
3766 builder
3767 .arg(&d_dev)
3768 .arg(&b_dev)
3769 .arg(&g_dev)
3770 .arg(&n_i32)
3771 .arg(&p_i32)
3772 .arg(&r_i32)
3773 .arg(&ridge_arg)
3774 .arg(&mut l_out)
3775 .arg(&mut u_out)
3776 .arg(&mut y_out)
3777 .arg(&mut partial_s)
3778 .arg(&mut partial_r)
3779 .arg(&mut status_dev);
3780 unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3783 }
3784 stream
3785 .synchronize()
3786 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3787
3788 let status_host = stream
3789 .clone_dtoh(&status_dev)
3790 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3791 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3792 return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
3796 row,
3797 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3798 });
3799 }
3800
3801 let y_host = stream
3803 .clone_dtoh(&y_out)
3804 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3805
3806 let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
3809 let hbb_is_kk = sys.hbb.dim() == (k, k);
3810 let hbb_matvec_opt = sys.hbb_matvec.clone();
3811
3812 let closure: crate::arrow_schur::GpuSchurMatvec =
3813 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
3814 assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
3815 assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
3816
3817 if let Some(ref mv) = hbb_matvec_opt {
3819 mv(x.view(), out);
3820 for a in 0..k {
3821 out[a] += ridge_beta * x[a];
3822 }
3823 } else if hbb_is_kk {
3824 for a in 0..k {
3826 let mut acc = ridge_beta * x[a];
3827 for b in 0..k {
3828 acc += hbb_host[a * k + b] * x[b];
3829 }
3830 out[a] = acc;
3831 }
3832 } else {
3833 for a in 0..k {
3834 out[a] = ridge_beta * x[a];
3835 }
3836 }
3837
3838 let mut z = vec![0.0_f64; d];
3841 for i in 0..n {
3842 let y_base = i * p_max * r_template;
3843 for r in 0..d {
3844 let mut acc = 0.0;
3845 for c in 0..k {
3846 acc += y_host[y_base + c * p_max + r] * x[c];
3847 }
3848 z[r] = acc;
3849 }
3850 for c in 0..k {
3851 let mut acc = 0.0;
3852 for r in 0..d {
3853 acc += y_host[y_base + c * p_max + r] * z[r];
3854 }
3855 out[c] -= acc;
3856 }
3857 }
3858 });
3859
3860 Ok(closure)
3861 }
3862
/// Device-resident buffers for the framed SAE matrix-free Schur operator.
///
/// Filled by `flatten_device_sae_frame_data`. Every `*_ptr` array starts with
/// 0 and holds one running offset per block/row into the matching flat data
/// array, so it has one more entry than there are blocks/rows.
struct DeviceSaeFrameBuffers {
    /// Per smooth block: the block's `global_offset`.
    s_off: CudaSlice<i32>,
    /// Per smooth block: side length `m` of its square `factor_a`.
    s_m: CudaSlice<i32>,
    /// Per smooth block: the rank taken from `frame.smooth_ranks`.
    s_r: CudaSlice<i32>,
    /// Running offsets into `s_data`.
    s_ptr: CudaSlice<i32>,
    /// Concatenated `factor_a` entries, each block stored row-major.
    s_data: CudaSlice<f64>,
    /// Number of smooth blocks.
    s_blocks: usize,
    /// Per frame block: `frame.border_offsets[atom_i]`.
    g_off_i: CudaSlice<i32>,
    /// Per frame block: `frame.border_offsets[atom_j]`.
    g_off_j: CudaSlice<i32>,
    /// Per frame block: `frame.ranks[atom_i]`.
    g_ri: CudaSlice<i32>,
    /// Per frame block: `frame.ranks[atom_j]`.
    g_rj: CudaSlice<i32>,
    /// Per frame block: row count of its `g` matrix.
    g_mi: CudaSlice<i32>,
    /// Per frame block: column count of its `g` matrix.
    g_mj: CudaSlice<i32>,
    /// Running offsets into `g_data`.
    g_ptr: CudaSlice<i32>,
    /// Concatenated `g` entries, each block stored row-major.
    g_data: CudaSlice<f64>,
    /// Running offsets into `w_data`.
    w_ptr: CudaSlice<i32>,
    /// Concatenated `w` entries, each block stored row-major.
    w_data: CudaSlice<f64>,
    /// Number of frame blocks.
    g_blocks: usize,
    /// Largest `mi * ri` over all frame blocks; sizes the x-grid of the
    /// `arrow_sae_frame_g_matvec` launch.
    g_max_work: usize,
    /// Running offsets into `htb`, one step per row.
    htb_ptr: CudaSlice<i32>,
    /// Concatenated `row_htbeta` slabs of the rows whose effective dimension is non-zero.
    htb: CudaSlice<f64>,
    /// Per row: effective latent dimension; 0 when the row's slab is empty or
    /// its length is not `row_dims[i] * k`.
    q_of: CudaSlice<i32>,
    /// `n_rows * max_q * max_q` values: for row `i`, the inverse of
    /// `H_tt + ridge_t * I` at `i * max_q * max_q + r * max_q + col`;
    /// entries outside a row's `q x q` corner stay zero.
    ainv: CudaSlice<f64>,
    /// Zero-initialised scratch of `n_rows * max_q` values; passed mutably to
    /// the `arrow_sae_frame_apply_h` kernel and read by `arrow_sae_frame_apply_ainv`.
    hvec: CudaSlice<f64>,
    /// Zero-initialised scratch of `n_rows * max_q` values; passed mutably to
    /// the `arrow_sae_frame_apply_ainv` kernel and read by `arrow_sae_frame_scatter_h`.
    svec: CudaSlice<f64>,
    /// Number of rows (`sys.rows.len()`).
    n_rows: usize,
    /// Length of the beta vector (`data.beta_dim`).
    k: usize,
    /// Largest effective row dimension, clamped to at least 1.
    max_q: usize,
}
3897
3898 fn flatten_device_sae_frame_data(
3899 sys: &ArrowSchurSystem,
3900 data: &DeviceSaePcgData,
3901 frame: &DeviceSaeFrameData,
3902 ridge_t: f64,
3903 stream: &Arc<CudaStream>,
3904 ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3905 let n_rows = sys.rows.len();
3906 let k = data.beta_dim;
3907 if frame.row_htbeta.len() != n_rows
3908 || frame.ranks.len() != frame.basis_sizes.len()
3909 || frame.border_offsets.len() != frame.ranks.len()
3910 || data.smooth_blocks.len() != frame.smooth_ranks.len()
3911 {
3912 return Err(ArrowSchurGpuFailure::Unavailable);
3913 }
3914
3915 let mut s_off = Vec::new();
3917 let mut s_m = Vec::new();
3918 let mut s_r = Vec::new();
3919 let mut s_ptr = vec![0_i32];
3920 let mut s_data = Vec::<f64>::new();
3921 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3922 let (m, mc) = blk.factor_a.dim();
3923 if m != mc {
3924 return Err(ArrowSchurGpuFailure::Unavailable);
3925 }
3926 s_off.push(checked_i32(blk.global_offset)?);
3927 s_m.push(checked_i32(m)?);
3928 s_r.push(checked_i32(r)?);
3929 for ri in 0..m {
3930 for ci in 0..m {
3931 s_data.push(blk.factor_a[[ri, ci]]);
3932 }
3933 }
3934 s_ptr.push(checked_i32(s_data.len())?);
3935 }
3936
3937 let mut g_off_i = Vec::new();
3939 let mut g_off_j = Vec::new();
3940 let mut g_ri = Vec::new();
3941 let mut g_rj = Vec::new();
3942 let mut g_mi = Vec::new();
3943 let mut g_mj = Vec::new();
3944 let mut g_ptr = vec![0_i32];
3945 let mut g_data = Vec::<f64>::new();
3946 let mut w_ptr = vec![0_i32];
3947 let mut w_data = Vec::<f64>::new();
3948 let mut g_max_work = 0usize;
3949 for blk in &frame.frame_blocks {
3950 let ri = frame.ranks[blk.atom_i];
3951 let rj = frame.ranks[blk.atom_j];
3952 let (mi, mj) = blk.g.dim();
3953 if blk.w.dim() != (ri, rj) {
3954 return Err(ArrowSchurGpuFailure::Unavailable);
3955 }
3956 g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3957 g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3958 g_ri.push(checked_i32(ri)?);
3959 g_rj.push(checked_i32(rj)?);
3960 g_mi.push(checked_i32(mi)?);
3961 g_mj.push(checked_i32(mj)?);
3962 for r in 0..mi {
3963 for c in 0..mj {
3964 g_data.push(blk.g[[r, c]]);
3965 }
3966 }
3967 g_ptr.push(checked_i32(g_data.len())?);
3968 for a in 0..ri {
3969 for b in 0..rj {
3970 w_data.push(blk.w[[a, b]]);
3971 }
3972 }
3973 w_ptr.push(checked_i32(w_data.len())?);
3974 g_max_work = g_max_work.max(mi * ri);
3975 }
3976
3977 let mut htb_ptr = vec![0_i32];
3979 let mut htb = Vec::<f64>::new();
3980 let mut q_of = Vec::<i32>::with_capacity(n_rows);
3981 let mut max_q = 0usize;
3982 for (i, slab) in frame.row_htbeta.iter().enumerate() {
3983 let qi = sys.row_dims[i];
3984 let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3987 qi
3988 } else {
3989 0
3990 };
3991 q_of.push(checked_i32(q_eff)?);
3992 max_q = max_q.max(q_eff);
3993 if q_eff > 0 {
3994 htb.extend_from_slice(slab);
3995 }
3996 htb_ptr.push(checked_i32(htb.len())?);
3997 }
3998 if max_q == 0 {
3999 max_q = 1;
4002 }
4003
4004 let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
4005 for (i, row) in sys.rows.iter().enumerate() {
4006 let q = q_of[i] as usize;
4007 if q == 0 {
4008 continue;
4009 }
4010 if row.htt.dim() != (q, q) {
4011 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4012 reason: format!(
4013 "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
4014 row.htt.dim()
4015 ),
4016 });
4017 }
4018 let mut block = row.htt.clone();
4019 for d in 0..q {
4020 block[[d, d]] += ridge_t;
4021 }
4022 let factor = gam_linalg::triangular::cholesky_factor_in_place(
4023 block.view(),
4024 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
4025 )
4026 .ok_or_else(|| {
4027 ArrowSchurGpuFailure::RidgeBumpRequired {
4030 row: i,
4031 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
4032 }
4033 })?;
4034 for col in 0..q {
4035 let mut e = Array1::<f64>::zeros(q);
4036 e[col] = 1.0;
4037 let solved =
4038 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
4039 for r in 0..q {
4040 ainv[i * max_q * max_q + r * max_q + col] = solved[r];
4041 }
4042 }
4043 }
4044
4045 let htod_i = |v: &[i32]| {
4046 stream
4047 .clone_htod(v)
4048 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4049 };
4050 let htod_f = |v: &[f64]| {
4051 stream
4052 .clone_htod(v)
4053 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4054 };
4055 Ok(DeviceSaeFrameBuffers {
4056 s_off: htod_i(&s_off)?,
4057 s_m: htod_i(&s_m)?,
4058 s_r: htod_i(&s_r)?,
4059 s_ptr: htod_i(&s_ptr)?,
4060 s_data: htod_f(&s_data)?,
4061 s_blocks: data.smooth_blocks.len(),
4062 g_off_i: htod_i(&g_off_i)?,
4063 g_off_j: htod_i(&g_off_j)?,
4064 g_ri: htod_i(&g_ri)?,
4065 g_rj: htod_i(&g_rj)?,
4066 g_mi: htod_i(&g_mi)?,
4067 g_mj: htod_i(&g_mj)?,
4068 g_ptr: htod_i(&g_ptr)?,
4069 g_data: htod_f(&g_data)?,
4070 w_ptr: htod_i(&w_ptr)?,
4071 w_data: htod_f(&w_data)?,
4072 g_blocks: frame.frame_blocks.len(),
4073 g_max_work,
4074 htb_ptr: htod_i(&htb_ptr)?,
4075 htb: htod_f(&htb)?,
4076 q_of: htod_i(&q_of)?,
4077 ainv: htod_f(&ainv)?,
4078 hvec: stream
4079 .alloc_zeros::<f64>(n_rows * max_q)
4080 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4081 svec: stream
4082 .alloc_zeros::<f64>(n_rows * max_q)
4083 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4084 n_rows,
4085 k,
4086 max_q,
4087 })
4088 }
4089
4090 fn sae_frame_penalty_diag_host(
4091 data: &DeviceSaePcgData,
4092 frame: &DeviceSaeFrameData,
4093 ridge_beta: f64,
4094 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4095 let mut diag = vec![ridge_beta; data.beta_dim];
4096 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4098 let m = blk.factor_a.nrows();
4099 for ia in 0..m {
4100 let coeff = blk.factor_a[[ia, ia]];
4101 let base = blk.global_offset + ia * r;
4102 for ib in 0..r {
4103 if base + ib >= diag.len() {
4104 return Err(ArrowSchurGpuFailure::Unavailable);
4105 }
4106 diag[base + ib] += coeff;
4107 }
4108 }
4109 }
4110 for blk in &frame.frame_blocks {
4112 if blk.atom_i != blk.atom_j {
4113 continue;
4114 }
4115 let r = frame.ranks[blk.atom_i];
4116 let off = frame.border_offsets[blk.atom_i];
4117 let (mi, mj) = blk.g.dim();
4118 for li in 0..mi.min(mj) {
4119 let gii = blk.g[[li, li]];
4120 let base = off + li * r;
4121 for a in 0..r {
4122 if base + a >= diag.len() {
4123 return Err(ArrowSchurGpuFailure::Unavailable);
4124 }
4125 diag[base + a] += gii * blk.w[[a, a]];
4126 }
4127 }
4128 }
4129 Ok(diag)
4130 }
4131
4132 fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4133 Ok(LaunchConfig {
4134 grid_dim: (
4135 ((work as u32).saturating_add(255) / 256).max(1),
4136 checked_i32(n_rows)? as u32,
4137 1,
4138 ),
4139 block_dim: (256, 1, 1),
4140 shared_mem_bytes: 0,
4141 })
4142 }
4143
4144 fn launch_sae_frame_matvec(
4145 stream: &Arc<CudaStream>,
4146 module: &Arc<CudaModule>,
4147 buffers: &mut DeviceSaeFrameBuffers,
4148 x: &CudaSlice<f64>,
4149 out: &mut CudaSlice<f64>,
4150 ridge_beta: f64,
4151 ) -> Result<(), ArrowSchurGpuFailure> {
4152 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4153 if buffers.s_blocks > 0 {
4155 let kernel = module
4156 .load_function("arrow_sae_frame_smooth_matvec")
4157 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4158 let blocks_i32 = checked_i32(buffers.s_blocks)?;
4159 let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4160 let mut b = stream.launch_builder(&kernel);
4161 b.arg(x)
4162 .arg(&mut *out)
4163 .arg(&buffers.s_off)
4164 .arg(&buffers.s_m)
4165 .arg(&buffers.s_r)
4166 .arg(&buffers.s_ptr)
4167 .arg(&buffers.s_data)
4168 .arg(&blocks_i32);
4169 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4172 }
4173 if buffers.g_blocks > 0 {
4175 let kernel = module
4176 .load_function("arrow_sae_frame_g_matvec")
4177 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4178 let blocks_i32 = checked_i32(buffers.g_blocks)?;
4179 let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4180 let mut b = stream.launch_builder(&kernel);
4181 b.arg(x)
4182 .arg(&mut *out)
4183 .arg(&buffers.g_off_i)
4184 .arg(&buffers.g_off_j)
4185 .arg(&buffers.g_ri)
4186 .arg(&buffers.g_rj)
4187 .arg(&buffers.g_mi)
4188 .arg(&buffers.g_mj)
4189 .arg(&buffers.g_ptr)
4190 .arg(&buffers.g_data)
4191 .arg(&buffers.w_ptr)
4192 .arg(&buffers.w_data)
4193 .arg(&blocks_i32);
4194 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4197 }
4198 let k_i32 = checked_i32(buffers.k)?;
4200 let max_q_i32 = checked_i32(buffers.max_q)?;
4201 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4202 {
4203 let kernel = module
4204 .load_function("arrow_sae_frame_apply_h")
4205 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4206 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4207 let mut b = stream.launch_builder(&kernel);
4208 b.arg(x)
4209 .arg(&buffers.htb_ptr)
4210 .arg(&buffers.htb)
4211 .arg(&buffers.q_of)
4212 .arg(&mut buffers.hvec)
4213 .arg(&k_i32)
4214 .arg(&max_q_i32)
4215 .arg(&n_rows_i32);
4216 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4219 }
4220 {
4221 let kernel = module
4222 .load_function("arrow_sae_frame_apply_ainv")
4223 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4224 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4225 let mut b = stream.launch_builder(&kernel);
4226 b.arg(&buffers.ainv)
4227 .arg(&buffers.hvec)
4228 .arg(&buffers.q_of)
4229 .arg(&mut buffers.svec)
4230 .arg(&max_q_i32)
4231 .arg(&n_rows_i32);
4232 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4235 }
4236 {
4237 let kernel = module
4238 .load_function("arrow_sae_frame_scatter_h")
4239 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4240 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4241 let mut b = stream.launch_builder(&kernel);
4242 b.arg(&buffers.svec)
4243 .arg(&buffers.htb_ptr)
4244 .arg(&buffers.htb)
4245 .arg(&buffers.q_of)
4246 .arg(out)
4247 .arg(&k_i32)
4248 .arg(&max_q_i32)
4249 .arg(&n_rows_i32);
4250 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4253 }
4254 Ok(())
4255 }
4256
4257 fn launch_sae_frame_diag_sub(
4258 stream: &Arc<CudaStream>,
4259 module: &Arc<CudaModule>,
4260 buffers: &DeviceSaeFrameBuffers,
4261 diag: &mut CudaSlice<f64>,
4262 ) -> Result<(), ArrowSchurGpuFailure> {
4263 let kernel = module
4264 .load_function("arrow_sae_frame_diag_sub")
4265 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4266 let k_i32 = checked_i32(buffers.k)?;
4267 let max_q_i32 = checked_i32(buffers.max_q)?;
4268 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4269 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4270 let mut b = stream.launch_builder(&kernel);
4271 b.arg(diag)
4272 .arg(&buffers.ainv)
4273 .arg(&buffers.htb_ptr)
4274 .arg(&buffers.htb)
4275 .arg(&buffers.q_of)
4276 .arg(&k_i32)
4277 .arg(&max_q_i32)
4278 .arg(&n_rows_i32);
4279 unsafe { b.launch(cfg) }
4281 .map(drop)
4282 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4283 }
4284
4285 pub(super) fn framed_schur_matvec_once_on_device(
4296 sys: &ArrowSchurSystem,
4297 data: &DeviceSaePcgData,
4298 ridge_t: f64,
4299 ridge_beta: f64,
4300 x: &Array1<f64>,
4301 ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4302 let k = x.len();
4303 if k == 0 || data.beta_dim != k || sys.k != k {
4304 return Err(ArrowSchurGpuFailure::Unavailable);
4305 }
4306 let frame = data
4307 .frame
4308 .as_ref()
4309 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4310 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4313 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4314 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4315 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4316 let stream = ctx
4317 .new_stream()
4318 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4319 let vector_module = pcg_vector_module(&ctx)?;
4320 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4321 let x_dev = stream
4322 .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4323 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4324 let mut out_dev = stream
4325 .alloc_zeros::<f64>(k)
4326 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4327 launch_sae_frame_matvec(
4328 &stream,
4329 vector_module,
4330 &mut buffers,
4331 &x_dev,
4332 &mut out_dev,
4333 ridge_beta,
4334 )?;
4335 let out = stream
4336 .clone_dtoh(&out_dev)
4337 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4338 Ok(Array1::from_vec(out))
4339 }
4340
4341 pub(super) fn solve_sae_matrix_free_pcg_framed(
4342 sys: &ArrowSchurSystem,
4343 data: &DeviceSaePcgData,
4344 ridge_t: f64,
4345 ridge_beta: f64,
4346 rhs_beta: &Array1<f64>,
4347 max_iterations: usize,
4348 relative_tolerance: f64,
4349 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4350 let k = rhs_beta.len();
4351 if k == 0 || data.beta_dim != k || sys.k != k {
4352 return Err(ArrowSchurGpuFailure::Unavailable);
4353 }
4354 let frame = data
4355 .frame
4356 .as_ref()
4357 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4358 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4359 .filter(|rt| {
4360 rt.policy().reduced_schur_matvec_should_offload(
4361 sys.rows.len(),
4362 sys.k,
4363 sys.d,
4364 max_iterations,
4365 )
4366 })
4367 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4368 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4369 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4370 let stream = ctx
4371 .new_stream()
4372 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4373 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4374 let vector_module = pcg_vector_module(&ctx)?;
4375 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4376
4377 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4378 if rhs_norm == 0.0 {
4379 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4380 }
4381 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4382 let rhs_dev = stream
4383 .clone_htod(
4384 rhs_beta
4385 .as_slice()
4386 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4387 )
4388 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4389 let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4390 let mut diag_dev = stream
4391 .clone_htod(&diag_host)
4392 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4393 launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4394 let diag_host = stream
4395 .clone_dtoh(&diag_dev)
4396 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4397 let mut inv_diag = Vec::with_capacity(k);
4398 for (idx, &d) in diag_host.iter().enumerate() {
4399 if !d.is_finite() || d <= 1.0e-18 {
4400 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4401 reason: format!(
4402 "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4403 ),
4404 });
4405 }
4406 inv_diag.push(1.0 / d);
4407 }
4408 let inv_diag_dev = stream
4409 .clone_htod(&inv_diag)
4410 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4411
4412 let mut x_dev = stream
4413 .alloc_zeros::<f64>(k)
4414 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4415 let mut r_dev = stream
4416 .alloc_zeros::<f64>(k)
4417 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4418 device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4419 let mut z_dev = stream
4420 .alloc_zeros::<f64>(k)
4421 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4422 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4423 let mut p_dev = stream
4424 .alloc_zeros::<f64>(k)
4425 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4426 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4427 let mut ap_dev = stream
4428 .alloc_zeros::<f64>(k)
4429 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4430
4431 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4432 if rz <= 0.0 || !rz.is_finite() {
4433 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4434 reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4435 });
4436 }
4437 let mut diag = PcgDiagnostics {
4438 precond_apply_calls: 1,
4439 stopping_reason: PcgStopReason::MaxIter,
4440 ..PcgDiagnostics::default()
4441 };
4442 for _ in 0..max_iterations.max(1) {
4443 launch_sae_frame_matvec(
4444 &stream,
4445 vector_module,
4446 &mut buffers,
4447 &p_dev,
4448 &mut ap_dev,
4449 ridge_beta,
4450 )?;
4451 diag.matvec_calls += 1;
4452 diag.iterations += 1;
4453 let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4454 if pap <= 0.0 || !pap.is_finite() {
4455 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4456 reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4457 });
4458 }
4459 let alpha = rz / pap;
4460 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4461 device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4462 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4463 if r_norm <= tol {
4464 diag.final_relative_residual = r_norm / rhs_norm;
4465 diag.stopping_reason = PcgStopReason::Converged;
4466 break;
4467 }
4468 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4469 diag.precond_apply_calls += 1;
4470 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4471 if rz_new <= 0.0 || !rz_new.is_finite() {
4472 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4473 reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4474 });
4475 }
4476 let beta = rz_new / rz;
4477 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4478 rz = rz_new;
4479 }
4480 if diag.stopping_reason != PcgStopReason::Converged {
4481 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4482 diag.final_relative_residual = r_norm / rhs_norm;
4483 diag.stopping_reason = PcgStopReason::MaxIter;
4484 }
4485 let x = stream
4486 .clone_dtoh(&x_dev)
4487 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4488 Ok((Array1::from_vec(x), diag))
4489 }
4490
4491 pub(super) fn solve_sae_matrix_free_pcg(
4498 sys: &ArrowSchurSystem,
4499 data: &DeviceSaePcgData,
4500 ridge_t: f64,
4501 ridge_beta: f64,
4502 rhs_beta: &Array1<f64>,
4503 max_iterations: usize,
4504 relative_tolerance: f64,
4505 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4506 let k = rhs_beta.len();
4507 if k == 0 || data.beta_dim != k || sys.k != k {
4508 return Err(ArrowSchurGpuFailure::Unavailable);
4509 }
4510 if data.frame.is_some() {
4514 return Err(ArrowSchurGpuFailure::Unavailable);
4515 }
4516 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4530 .filter(|rt| {
4531 rt.policy().reduced_schur_matvec_should_offload(
4532 sys.rows.len(),
4533 sys.k,
4534 sys.d,
4535 max_iterations,
4536 )
4537 })
4538 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4539 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4540 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4541 let stream = ctx
4542 .new_stream()
4543 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4545 let vector_module = pcg_vector_module(&ctx)?;
4546 let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4547
4548 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4549 if rhs_norm == 0.0 {
4550 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4551 }
4552 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4553 let rhs_dev = stream
4554 .clone_htod(
4555 rhs_beta
4556 .as_slice()
4557 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4558 )
4559 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4560 let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4561 let mut diag_dev = stream
4562 .clone_htod(&diag_host)
4563 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4564 launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4565 let diag_host = stream
4566 .clone_dtoh(&diag_dev)
4567 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4568 let mut inv_diag = Vec::with_capacity(k);
4569 for (idx, &d) in diag_host.iter().enumerate() {
4570 if !d.is_finite() || d <= 1.0e-18 {
4571 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4572 reason: format!(
4573 "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4574 ),
4575 });
4576 }
4577 inv_diag.push(1.0 / d);
4578 }
4579 let inv_diag_dev = stream
4580 .clone_htod(&inv_diag)
4581 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4582
4583 let mut x_dev = stream
4584 .alloc_zeros::<f64>(k)
4585 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4586 let mut r_dev = stream
4587 .alloc_zeros::<f64>(k)
4588 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4589 device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4590 let mut z_dev = stream
4591 .alloc_zeros::<f64>(k)
4592 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4593 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4594 let mut p_dev = stream
4595 .alloc_zeros::<f64>(k)
4596 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4597 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4598 let mut ap_dev = stream
4599 .alloc_zeros::<f64>(k)
4600 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4601
4602 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4603 if rz <= 0.0 || !rz.is_finite() {
4604 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4605 reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4606 });
4607 }
4608 let mut diag = PcgDiagnostics {
4609 precond_apply_calls: 1,
4610 stopping_reason: PcgStopReason::MaxIter,
4611 ..PcgDiagnostics::default()
4612 };
4613
4614 for _ in 0..max_iterations.max(1) {
4615 launch_sae_matvec(
4616 &stream,
4617 vector_module,
4618 &mut buffers,
4619 &p_dev,
4620 &mut ap_dev,
4621 ridge_beta,
4622 )?;
4623 diag.matvec_calls += 1;
4624 diag.iterations += 1;
4625 let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4626 if pap <= 0.0 || !pap.is_finite() {
4627 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4628 reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4629 });
4630 }
4631 let alpha = rz / pap;
4632 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4633 device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4634 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4635 if r_norm <= tol {
4636 diag.final_relative_residual = r_norm / rhs_norm;
4637 diag.stopping_reason = PcgStopReason::Converged;
4638 break;
4639 }
4640 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4641 diag.precond_apply_calls += 1;
4642 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4643 if rz_new <= 0.0 || !rz_new.is_finite() {
4644 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4645 reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4646 });
4647 }
4648 let beta = rz_new / rz;
4649 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4650 rz = rz_new;
4651 }
4652 if diag.stopping_reason != PcgStopReason::Converged {
4653 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4654 diag.final_relative_residual = r_norm / rhs_norm;
4655 diag.stopping_reason = PcgStopReason::MaxIter;
4656 }
4657 let x = stream
4658 .clone_dtoh(&x_dev)
4659 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4660 Ok((Array1::from_vec(x), diag))
4661 }
4662
4663 pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4664 s_acc: &ndarray::Array2<f64>,
4665 rhs_beta: &Array1<f64>,
4666 max_iterations: usize,
4667 relative_tolerance: f64,
4668 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4669 let k = rhs_beta.len();
4670 let cg_iters = max_iterations.max(1);
4682 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4683 gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4684 m: k,
4685 n: k,
4686 k: cg_iters,
4687 },
4688 )
4689 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4690 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4691 .and_then(|ctx| ctx.new_stream().ok())
4692 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4693 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4694 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4695 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4696 let vector_module = pcg_vector_module(&ctx)?;
4697
4698 let mut inv_diag = vec![0.0_f64; k];
4700 for j in 0..k {
4701 let djj = s_acc[[j, j]];
4702 if !(djj > 0.0) {
4703 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4704 reason: format!(
4705 "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4706 ),
4707 });
4708 }
4709 inv_diag[j] = 1.0 / djj;
4710 }
4711
4712 let mut s_host = vec![0.0_f64; k * k];
4714 for col in 0..k {
4715 for row in 0..k {
4716 s_host[col * k + row] = s_acc[[row, col]];
4717 }
4718 }
4719 let s_dev = stream
4720 .clone_htod(&s_host)
4721 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4722
4723 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4727 if rhs_norm == 0.0 {
4728 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4729 }
4730 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4731
4732 let mut x_dev = stream
4735 .alloc_zeros::<f64>(k)
4736 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4737 let mut r_dev = stream
4738 .clone_htod(
4739 rhs_beta
4740 .as_slice()
4741 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4742 )
4743 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4744 let inv_diag_dev = stream
4745 .clone_htod(&inv_diag)
4746 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4747 let mut z_dev = stream
4748 .alloc_zeros::<f64>(k)
4749 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4750 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4751 let mut p_dev = stream
4752 .alloc_zeros::<f64>(k)
4753 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4754 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4755 let mut sp_dev = stream
4756 .alloc_zeros::<f64>(k)
4757 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4758 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4759 let mut diag = PcgDiagnostics {
4760 precond_apply_calls: 1,
4761 stopping_reason: PcgStopReason::MaxIter,
4762 ..PcgDiagnostics::default()
4763 };
4764 if rz <= 0.0 || !rz.is_finite() {
4765 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4766 reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4767 });
4768 }
4769
4770 let max_iters = max_iterations.max(1);
4771 for _ in 0..max_iters {
4772 let gemv_cfg = GemvConfig::<f64> {
4774 trans: cublasOperation_t::CUBLAS_OP_N,
4775 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4776 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4777 alpha: 1.0,
4778 lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4779 incx: 1,
4780 beta: 0.0,
4781 incy: 1,
4782 };
4783 unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4785 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4786 diag.matvec_calls += 1;
4787 diag.iterations += 1;
4788
4789 let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4790 if !(p_sp > 0.0) {
4791 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4794 reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4795 });
4796 }
4797 let alpha = rz / p_sp;
4798 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4799 device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4800 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4801 if r_norm <= tol {
4802 diag.final_relative_residual = r_norm / rhs_norm;
4803 diag.stopping_reason = PcgStopReason::Converged;
4804 break;
4805 }
4806 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4807 diag.precond_apply_calls += 1;
4808 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4809 if rz_new <= 0.0 || !rz_new.is_finite() {
4810 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4811 reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4812 });
4813 }
4814 let beta = rz_new / rz;
4815 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4816 rz = rz_new;
4817 }
4818 if diag.stopping_reason != PcgStopReason::Converged {
4819 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4820 diag.final_relative_residual = r_norm / rhs_norm;
4821 diag.stopping_reason = PcgStopReason::MaxIter;
4822 }
4823
4824 let x = stream
4825 .clone_dtoh(&x_dev)
4826 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4827 Ok((Array1::from_vec(x), diag))
4828 }
4829
4830 fn device_copy(
4831 blas: &CudaBlas,
4832 stream: &Arc<CudaStream>,
4833 n: usize,
4834 src: &CudaSlice<f64>,
4835 dst: &mut CudaSlice<f64>,
4836 ) -> Result<(), ArrowSchurGpuFailure> {
4837 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4838 let (src_ptr, _src_rec) = src.device_ptr(stream);
4839 let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4840 let status = unsafe {
4843 cudarc::cublas::sys::cublasDcopy_v2(
4844 *blas.handle(),
4845 n_i,
4846 src_ptr as *const f64,
4847 1,
4848 dst_ptr as *mut f64,
4849 1,
4850 )
4851 };
4852 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4853 Ok(())
4854 } else {
4855 Err(ArrowSchurGpuFailure::Unavailable)
4856 }
4857 }
4858
4859 fn device_axpy(
4860 blas: &CudaBlas,
4861 stream: &Arc<CudaStream>,
4862 n: usize,
4863 alpha: f64,
4864 x: &CudaSlice<f64>,
4865 y: &mut CudaSlice<f64>,
4866 ) -> Result<(), ArrowSchurGpuFailure> {
4867 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4868 let (x_ptr, _x_rec) = x.device_ptr(stream);
4869 let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4870 let status = unsafe {
4873 cudarc::cublas::sys::cublasDaxpy_v2(
4874 *blas.handle(),
4875 n_i,
4876 &alpha,
4877 x_ptr as *const f64,
4878 1,
4879 y_ptr as *mut f64,
4880 1,
4881 )
4882 };
4883 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4884 Ok(())
4885 } else {
4886 Err(ArrowSchurGpuFailure::Unavailable)
4887 }
4888 }
4889
4890 fn device_dot(
4891 blas: &CudaBlas,
4892 stream: &Arc<CudaStream>,
4893 n: usize,
4894 x: &CudaSlice<f64>,
4895 y: &CudaSlice<f64>,
4896 ) -> Result<f64, ArrowSchurGpuFailure> {
4897 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4898 let (x_ptr, _x_rec) = x.device_ptr(stream);
4899 let (y_ptr, _y_rec) = y.device_ptr(stream);
4900 let mut result = 0.0_f64;
4901 let status = unsafe {
4905 cudarc::cublas::sys::cublasDdot_v2(
4906 *blas.handle(),
4907 n_i,
4908 x_ptr as *const f64,
4909 1,
4910 y_ptr as *const f64,
4911 1,
4912 &mut result,
4913 )
4914 };
4915 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4916 Ok(result)
4917 } else {
4918 Err(ArrowSchurGpuFailure::Unavailable)
4919 }
4920 }
4921
4922 fn device_nrm2(
4923 blas: &CudaBlas,
4924 stream: &Arc<CudaStream>,
4925 n: usize,
4926 x: &CudaSlice<f64>,
4927 ) -> Result<f64, ArrowSchurGpuFailure> {
4928 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4929 let (x_ptr, _x_rec) = x.device_ptr(stream);
4930 let mut result = 0.0_f64;
4931 let status = unsafe {
4935 cudarc::cublas::sys::cublasDnrm2_v2(
4936 *blas.handle(),
4937 n_i,
4938 x_ptr as *const f64,
4939 1,
4940 &mut result,
4941 )
4942 };
4943 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4944 Ok(result)
4945 } else {
4946 Err(ArrowSchurGpuFailure::Unavailable)
4947 }
4948 }
4949
4950 #[cfg(test)]
4951 mod tests {
4952 use super::*;
4957 use crate::arrow_schur::{
4958 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4959 FactoredFrameGBlock,
4960 };
4961 use ndarray::Array2;
4962
4963 fn device_matvec_once(
4966 sys: &ArrowSchurSystem,
4967 data: &DeviceSaePcgData,
4968 ridge_t: f64,
4969 ridge_beta: f64,
4970 x_host: &[f64],
4971 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4972 let k = x_host.len();
4973 let frame = data
4974 .frame
4975 .as_ref()
4976 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4977 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4978 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4979 let ctx =
4980 gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4981 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4982 let stream = ctx
4983 .new_stream()
4984 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4985 let vector_module = pcg_vector_module(&ctx)?;
4986 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4987 let x_dev = stream
4988 .clone_htod(x_host)
4989 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4990 let mut out_dev = stream
4991 .alloc_zeros::<f64>(k)
4992 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4993 launch_sae_frame_matvec(
4994 &stream,
4995 vector_module,
4996 &mut buffers,
4997 &x_dev,
4998 &mut out_dev,
4999 ridge_beta,
5000 )?;
5001 stream
5002 .clone_dtoh(&out_dev)
5003 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
5004 }
5005
5006 #[test]
5012 fn framed_sae_device_matvec_stage_diff_tiny_1551() {
5013 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
5014 return;
5015 }
5016 let p = 3usize;
5017 let ranks = vec![2usize, 3usize];
5018 let basis_sizes = vec![2usize, 2usize];
5019 let mut border_offsets = Vec::new();
5020 let mut acc = 0usize;
5021 for k in 0..2 {
5022 border_offsets.push(acc);
5023 acc += basis_sizes[k] * ranks[k];
5024 }
5025 let border_dim = acc; let frame_of = |k: usize| -> Array2<f64> {
5027 Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
5028 0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
5029 })
5030 };
5031 let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
5032 let w_of = |i: usize, j: usize| -> Array2<f64> {
5033 let (ui, uj) = (&frames[i], &frames[j]);
5034 Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
5035 (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
5036 })
5037 };
5038 let mut frame_blocks = Vec::new();
5039 for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
5040 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5041 let mut g =
5042 Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
5043 if i == j {
5044 for r in 0..mi.min(mj) {
5045 g[[r, r]] += mi as f64 + 2.0;
5046 }
5047 }
5048 frame_blocks.push(FactoredFrameGBlock {
5049 atom_i: i,
5050 atom_j: j,
5051 g,
5052 w: w_of(i, j),
5053 });
5054 }
5055 let mut smooth_blocks = Vec::new();
5056 for k in 0..2 {
5057 let m = basis_sizes[k];
5058 let mut s =
5059 Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5060 for r in 0..m {
5061 s[[r, r]] += 1.0;
5062 }
5063 smooth_blocks.push(DeviceSaeSmoothBlock {
5064 global_offset: border_offsets[k],
5065 factor_a: s,
5066 });
5067 }
5068 let smooth_ranks = ranks.clone();
5069 let n = 2usize;
5070 let q = 2usize;
5071 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5072 let mut row_htbeta = Vec::new();
5073 for i in 0..n {
5074 let mut htt =
5075 Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5076 for r in 0..q {
5077 htt[[r, r]] += q as f64 + 2.0;
5078 }
5079 sys.rows[i].htt = htt;
5080 let mut slab = vec![0.0_f64; q * border_dim];
5081 for c in 0..q {
5082 for col in 0..border_dim {
5083 let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5084 slab[c * border_dim + col] = v;
5085 sys.rows[i].htbeta[[c, col]] = v;
5086 }
5087 }
5088 row_htbeta.push(slab);
5089 }
5090 let data = DeviceSaePcgData {
5091 p,
5092 beta_dim: border_dim,
5093 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5094 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5095 smooth_blocks,
5096 sparse_g_blocks: Vec::new(),
5097 frame: Some(DeviceSaeFrameData {
5098 ranks,
5099 basis_sizes,
5100 border_offsets,
5101 frame_blocks,
5102 smooth_ranks,
5103 row_htbeta,
5104 }),
5105 };
5106 let ridge_t = 1e-7;
5107 let ridge_beta = 1e-6;
5108 let mut first_bad: Option<usize> = None;
5109 let mut worst = 0.0_f64;
5110 let mut worst_at = 0usize;
5111 let mut worst_dev = 0.0_f64;
5112 let mut worst_cpu = 0.0_f64;
5113 for col in 0..border_dim {
5114 let mut x = vec![0.0_f64; border_dim];
5115 x[col] = 1.0;
5116 let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5117 Ok(v) => v,
5118 Err(_) => return,
5119 };
5120 let mut cpu = vec![0.0_f64; border_dim];
5121 super::super::sae_framed_schur_matvec_cpu(
5122 &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5123 )
5124 .expect("cpu matvec");
5125 for r in 0..border_dim {
5126 let d = (dev[r] - cpu[r]).abs();
5127 if d > 1e-9 && first_bad.is_none() {
5128 first_bad = Some(r * border_dim + col);
5129 }
5130 if d > worst {
5131 worst = d;
5132 worst_at = r * border_dim + col;
5133 worst_dev = dev[r];
5134 worst_cpu = cpu[r];
5135 }
5136 }
5137 }
5138 assert!(
5139 worst <= 1e-9,
5140 "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5141 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5142 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5143 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5144 G⊗W=cross, reduced-Schur=dense per-row)",
5145 );
5146 }
5147 }
5148}
5149
5150#[cfg(test)]
5151mod tests {
5152 use super::*;
5153 use crate::arrow_schur::ArrowSchurSystem;
5154 use ndarray::{Array2, ArrayView1};
5155
5156 fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5157 let mut sys = ArrowSchurSystem::new(n, d, k);
5158 let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5159 let mut sample = || -> f64 {
5160 state = state
5161 .wrapping_mul(6364136223846793005)
5162 .wrapping_add(1442695040888963407);
5163 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5164 };
5165 for row in &mut sys.rows {
5166 let mut a = Array2::<f64>::zeros((d, d));
5167 for r in 0..d {
5168 for c in 0..d {
5169 a[[r, c]] = sample();
5170 }
5171 }
5172 let mut htt = a.t().dot(&a);
5173 for r in 0..d {
5174 htt[[r, r]] += d as f64 + 1.0;
5175 }
5176 row.htt = htt;
5177 for r in 0..d {
5178 for c in 0..k {
5179 row.htbeta[[r, c]] = 0.1 * sample();
5180 }
5181 row.gt[r] = sample();
5182 }
5183 }
5184 let mut hbb_a = Array2::<f64>::zeros((k, k));
5185 for r in 0..k {
5186 for c in 0..k {
5187 hbb_a[[r, c]] = sample();
5188 }
5189 }
5190 let mut hbb = hbb_a.t().dot(&hbb_a);
5191 for r in 0..k {
5192 hbb[[r, r]] += k as f64 + 1.0;
5193 }
5194 sys.hbb = hbb;
5195 for r in 0..k {
5196 sys.gb[r] = sample();
5197 }
5198 sys
5199 }
5200
5201 #[test]
5206 fn ridge_bump_makes_known_indefinite_blocks_pd() {
5207 let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); let mut indef2 = Array2::<f64>::zeros((2, 2));
5215 indef2[[0, 0]] = 1.0;
5216 indef2[[1, 1]] = 1.0;
5217 indef2[[0, 1]] = 2.0;
5218 indef2[[1, 0]] = 2.0;
5219 let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5222
5223 for (label, block) in [
5224 ("-I (λ_min=-1)", neg_identity),
5225 ("-250·I (λ_min=-250)", scaled_neg),
5226 ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5227 ("5·I (PD)", pd),
5228 ] {
5229 let ridge_t = 0.0;
5230 let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5231 assert!(
5232 bump > 0.0 && bump.is_finite(),
5233 "[{label}] bump must be strictly positive and finite, got {bump:e}"
5234 );
5235 let d = block.nrows();
5236 let mut shifted = block.clone();
5237 for i in 0..d {
5238 shifted[[i, i]] += ridge_t + bump;
5239 }
5240 assert!(
5241 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5242 "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5243 Gershgorin bump, but the Cholesky still rejected it"
5244 );
5245 }
5246 }
5247
5248 #[cfg(target_os = "linux")]
5258 #[test]
5259 fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5260 let mut a = Array2::<f64>::zeros((3, 3));
5262 a[[0, 0]] = -2.0;
5263 a[[1, 1]] = 0.5;
5264 a[[2, 2]] = 1.0;
5265 a[[0, 1]] = 0.3;
5266 a[[1, 0]] = 0.3;
5267 a[[1, 2]] = -0.4;
5268 a[[2, 1]] = -0.4;
5269 a[[0, 2]] = 0.1;
5270 a[[2, 0]] = 0.1;
5271
5272 let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5273
5274 let d = 3;
5276 let mut col_major = vec![0.0_f64; d * d];
5277 for c in 0..d {
5278 for r in 0..d {
5279 col_major[c * d + r] = a[[r, c]];
5280 }
5281 }
5282 let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5283
5284 assert!(
5285 (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5286 "colmajor bump {col_major_bump:e} must match rowmajor bump \
5287 {row_major_bump:e} for a symmetric block"
5288 );
5289
5290 let mut shifted = a.clone();
5292 for i in 0..d {
5293 shifted[[i, i]] += col_major_bump;
5294 }
5295 assert!(
5296 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5297 "colmajor Gershgorin bump must make the symmetric block PD"
5298 );
5299 }
5300
5301 fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5302 let mut s = Array2::<f64>::zeros((k, k));
5303 for row in 0..k {
5304 s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5305 if row + 1 < k {
5306 s[[row, row + 1]] = -0.05;
5307 s[[row + 1, row]] = -0.05;
5308 }
5309 if row + 7 < k {
5310 s[[row, row + 7]] = 0.01;
5311 s[[row + 7, row]] = 0.01;
5312 }
5313 }
5314 let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5315 (s, rhs)
5316 }
5317
5318 fn dense_pcg_cpu_reference(
5319 s: &Array2<f64>,
5320 rhs: &Array1<f64>,
5321 max_iterations: usize,
5322 relative_tolerance: f64,
5323 ) -> Array1<f64> {
5324 let k = rhs.len();
5325 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5326 if rhs_norm == 0.0 {
5327 return Array1::<f64>::zeros(k);
5328 }
5329 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5330 let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5331 let mut x = Array1::<f64>::zeros(k);
5332 let mut r = rhs.clone();
5333 let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5334 let mut p = z.clone();
5335 let mut sp = Array1::<f64>::zeros(k);
5336 let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5337 for _ in 0..max_iterations.max(1) {
5338 for row in 0..k {
5339 let mut acc = 0.0;
5340 for col in 0..k {
5341 acc += s[[row, col]] * p[col];
5342 }
5343 sp[row] = acc;
5344 }
5345 let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5346 let alpha = rz / p_sp;
5347 for idx in 0..k {
5348 x[idx] += alpha * p[idx];
5349 r[idx] -= alpha * sp[idx];
5350 }
5351 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5352 if r_norm <= tol {
5353 break;
5354 }
5355 for idx in 0..k {
5356 z[idx] = inv_diag[idx] * r[idx];
5357 }
5358 let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5359 let beta = rz_next / rz;
5360 for idx in 0..k {
5361 p[idx] = z[idx] + beta * p[idx];
5362 }
5363 rz = rz_next;
5364 }
5365 x
5366 }
5367
5368 #[test]
5369 fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5370 let (s, rhs) = device_pcg_fixture(512);
5371 let max_iterations = 200usize;
5372 let relative_tolerance = 1.0e-12;
5373 let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5374 let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5375 &s,
5376 &rhs,
5377 max_iterations,
5378 relative_tolerance,
5379 ) {
5380 Ok(result) => result,
5381 Err(failure) => {
5388 assert!(
5389 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5390 "#1017: CUDA device present but the device reduced-beta PCG \
5391 declined/faulted instead of returning a result (tag: {failure:?}) — \
5392 the kernel does not run correctly on GPU"
5393 );
5394 return;
5395 }
5396 };
5397 let max_err = cpu
5398 .iter()
5399 .zip(device.iter())
5400 .map(|(a, b)| (a - b).abs())
5401 .fold(0.0_f64, f64::max);
5402 assert!(
5403 max_err <= 1.0e-10,
5404 "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5405 );
5406 assert!(diag.matvec_calls > 0);
5407 assert_eq!(diag.matvec_calls, diag.iterations);
5408 }
5409
5410 #[test]
5411 fn dense_reference_matches_independent_solve() {
5412 let sys = build_fixture(4, 5, 3, 7);
5413 let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5414 let n = sys.rows.len();
5418 let d = sys.d;
5419 let k = sys.k;
5420 let total = n * d + k;
5421 let mut h = Array2::<f64>::zeros((total, total));
5422 let mut g = ndarray::Array1::<f64>::zeros(total);
5423 for (i, row) in sys.rows.iter().enumerate() {
5424 let base = i * d;
5425 for c in 0..d {
5426 for r in 0..d {
5427 h[[base + r, base + c]] = row.htt[[r, c]];
5428 }
5429 }
5430 for c in 0..k {
5431 for r in 0..d {
5432 h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5433 h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5434 }
5435 }
5436 for r in 0..d {
5437 g[base + r] = row.gt[r];
5438 }
5439 }
5440 for c in 0..k {
5441 for r in 0..k {
5442 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5443 }
5444 g[n * d + c] = sys.gb[c];
5445 }
5446 let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5447 let rhs = g.mapv(|v| -v);
5448 let expected = cholesky_solve_vector(l.view(), rhs.view());
5449 for i in 0..n * d {
5450 assert!(
5451 (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5452 "delta_t[{i}] mismatch: got {} expected {}",
5453 solution.delta_t[i],
5454 expected[i]
5455 );
5456 }
5457 for a in 0..k {
5458 assert!(
5459 (solution.delta_beta[a] - expected[n * d + a]).abs()
5460 < 1e-10 * (1.0 + expected[n * d + a].abs()),
5461 "delta_beta[{a}] mismatch"
5462 );
5463 }
5464 }
5465
5466 #[test]
5480 fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5481 use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5482 let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; let d = 3usize;
5484 let k = 24usize;
5485 let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5486 let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5491 let forward_slabs = slabs.clone();
5492 let transpose_slabs = slabs;
5493 sys.set_row_htbeta_operator(
5494 move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5495 let h = &forward_slabs[row];
5496 for r in 0..h.nrows() {
5497 let mut acc = 0.0_f64;
5498 for c in 0..h.ncols() {
5499 acc += h[[r, c]] * x[c];
5500 }
5501 out[r] = acc;
5502 }
5503 },
5504 move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5505 let h = &transpose_slabs[row];
5506 for r in 0..h.nrows() {
5507 for c in 0..h.ncols() {
5508 out[c] += h[[r, c]] * v[r];
5509 }
5510 }
5511 },
5512 );
5513
5514 let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5515 .expect("row-procedural matvec backend builds for matrix-free system");
5516 let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5517
5518 let mut out_parallel_a = Array1::<f64>::zeros(k);
5522 matvec(&x, &mut out_parallel_a);
5523 let mut out_parallel_b = Array1::<f64>::zeros(k);
5524 matvec(&x, &mut out_parallel_b);
5525 for a in 0..k {
5526 assert_eq!(
5527 out_parallel_a[a].to_bits(),
5528 out_parallel_b[a].to_bits(),
5529 "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5530 );
5531 }
5532
5533 let mut out_serial = Array1::<f64>::zeros(k);
5538 rayon::ThreadPoolBuilder::new()
5539 .num_threads(2)
5540 .build()
5541 .expect("build rayon pool")
5542 .install(|| matvec(&x, &mut out_serial));
5543
5544 let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5545 for a in 0..k {
5546 let diff = (out_parallel_a[a] - out_serial[a]).abs();
5547 assert!(
5548 diff <= 1e-12 * (1.0 + max_abs),
5549 "row-procedural matvec parallel vs serial diverged beyond reassociation \
5550 at index {a}: {} vs {} (diff={diff:e})",
5551 out_parallel_a[a],
5552 out_serial[a]
5553 );
5554 }
5555 }
5556
    /// Compares `sae_framed_schur_matvec_cpu` against a densely assembled
    /// reduced Schur complement on a small three-atom system.
    ///
    /// The dense reference is
    /// `S = H_ββ + ridge_β·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`,
    /// where `H_ββ` is the densified frame operator plus the densified
    /// per-atom penalty operators. Four deterministic probe vectors are
    /// multiplied both ways and the largest error, scaled by
    /// `max(1, max|want|)`, must stay at or below 1e-10.
    ///
    /// All random values come from a single LCG closure, so the order of the
    /// `sample()` calls below defines the fixture.
    #[test]
    fn framed_sae_schur_matvec_matches_dense_reference() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 4usize;
        // Atom 1 has rank == p, which selects the identity-frame branch below.
        let ranks = vec![2usize, 4usize, 3usize];
        let basis_sizes = vec![2usize, 1usize, 2usize];
        let n_atoms = ranks.len();
        // Atom k owns basis_sizes[k] * ranks[k] consecutive border slots:
        // 2·2 + 1·4 + 2·3 = 14 in total.
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;
        // 64-bit LCG; each draw is a 31-bit integer divided by 2^31, minus 1,
        // i.e. a value in [-1, 0).
        let mut state = 0x1234_5678_9abc_def0u64;
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };

        // Per-atom frames U_k (p × rank): identity when rank == p, random
        // otherwise.
        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p && i == j {
                        1.0
                    } else if r == p {
                        0.0
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // W_ij = U_iᵀ U_j, the cross-Gram of two frames.
        let w_of = |i: usize, j: usize| -> Array2<f64> {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };

        // Frame blocks: every diagonal pair plus the (0, 2)/(2, 0) cross
        // pair, visited in sorted order.
        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
        pairs.sort();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.3 * sample();
                }
            }
            if i == j {
                // Lift the diagonal of the diagonal blocks by mi + 2.
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }

        // Per-atom smoothing factors: AᵀA + I on the basis dimension.
        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }

        // Row blocks: H_tt = AᵀA + (q + 1)·I, and a q × border_dim coupling
        // slab stored both flat (row-major) and as the dense `htbeta`.
        let n = 6usize;
        let q = 3usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.15 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }

        // Dense H_ββ: densified frame operator plus each atom's densified
        // penalty operator.
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;

        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;

        // Dense reference: start from H_ββ + ridge_β·I ...
        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
        for r in 0..border_dim {
            for c in 0..border_dim {
                s_dense[[r, c]] = sys.hbb[[r, c]];
            }
            s_dense[[r, r]] += ridge_beta;
        }
        // ... then subtract each row's H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ.
        for row in &sys.rows {
            let mut htt = row.htt.clone();
            for d in 0..q {
                htt[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
                .expect("htt PD");
            // y = (H_tt + ridge_t·I)⁻¹ H_tβ, solved one column at a time.
            let mut y = Array2::<f64>::zeros((q, border_dim));
            for col in 0..border_dim {
                let mut e = Array1::<f64>::zeros(q);
                for r in 0..q {
                    e[r] = row.htbeta[[r, col]];
                }
                let solved = cholesky_solve_vector(factor.view(), e.view());
                for r in 0..q {
                    y[[r, col]] = solved[r];
                }
            }
            for r in 0..border_dim {
                for c in 0..border_dim {
                    let mut acc = 0.0;
                    for d in 0..q {
                        acc += row.htbeta[[d, r]] * y[[d, c]];
                    }
                    s_dense[[r, c]] -= acc;
                }
            }
        }

        // Probe with four deterministic vectors and track the worst error,
        // scaled by max(1, max|want|) of the respective probe.
        let mut max_rel = 0.0_f64;
        for trial in 0..4 {
            let x: Vec<f64> = (0..border_dim)
                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
                .collect();
            let mut got = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
                .expect("framed matvec");
            let mut want = vec![0.0_f64; border_dim];
            for r in 0..border_dim {
                let mut acc = 0.0;
                for c in 0..border_dim {
                    acc += s_dense[[r, c]] * x[c];
                }
                want[r] = acc;
            }
            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (got[a] - want[a]).abs() / scale;
                max_rel = max_rel.max(rel);
            }
        }
        assert!(
            max_rel <= 1e-10,
            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
        );
    }
5813
    /// Larger variant of the framed-matvec-vs-dense check: 40 atoms with
    /// mixed ranks and a block-tridiagonal frame coupling pattern.
    ///
    /// As in the small test, the dense reference is
    /// `S = H_ββ + ridge_β·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`, and
    /// the worst error over four probe vectors, scaled by
    /// `max(1, max|want|)`, must stay at or below 1e-10.
    ///
    /// All random values come from a single LCG closure, so the order of the
    /// `sample()` calls below defines the fixture.
    #[test]
    fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 12usize;
        let n_atoms = 40usize;
        // Every fifth atom is full rank (rank == p); the others cycle
        // through ranks 2, 3, 4.
        let ranks: Vec<usize> = (0..n_atoms)
            .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
            .collect();
        // Basis sizes cycle through 1, 2, 3.
        let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();
        // Atom k owns basis_sizes[k] * ranks[k] consecutive border slots.
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;

        // 64-bit LCG; each draw is a 31-bit integer divided by 2^31, minus 1,
        // i.e. a value in [-1, 0).
        let mut state = 0x0bad_c0de_dead_beefu64;
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };

        // Per-atom frames U_k (p × rank): identity when rank == p, random
        // otherwise.
        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p {
                        if i == j { 1.0 } else { 0.0 }
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // W_ij = U_iᵀ U_j, the cross-Gram of two frames.
        let w_of = |i: usize, j: usize| -> Array2<f64> {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };

        // Coupling pattern: every diagonal pair plus both orientations of
        // each nearest-neighbour pair, visited in sorted order.
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        for k in 0..n_atoms {
            pairs.push((k, k));
        }
        for k in 0..n_atoms - 1 {
            pairs.push((k, k + 1));
            pairs.push((k + 1, k));
        }
        pairs.sort_unstable();
        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.3 * sample();
                }
            }
            if i == j {
                // Lift the diagonal of the diagonal blocks by mi + 2.
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }

        // Per-atom smoothing factors: AᵀA + I on the basis dimension.
        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }

        // Row blocks: H_tt = AᵀA + (q + 1)·I, and a q × border_dim coupling
        // slab stored both flat (row-major) and as the dense `htbeta`.
        let n = 8usize;
        let q = 3usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.15 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }

        // Dense H_ββ: densified frame operator plus each atom's densified
        // penalty operator.
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;

        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;

        // Dense reference: start from H_ββ + ridge_β·I ...
        let mut s_dense = sys.hbb.clone();
        for r in 0..border_dim {
            s_dense[[r, r]] += ridge_beta;
        }
        // ... then subtract each row's H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ.
        for row in &sys.rows {
            let mut htt = row.htt.clone();
            for d in 0..q {
                htt[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
                .expect("htt PD");
            // y = (H_tt + ridge_t·I)⁻¹ H_tβ, solved one column at a time.
            let mut y = Array2::<f64>::zeros((q, border_dim));
            for col in 0..border_dim {
                let mut e = Array1::<f64>::zeros(q);
                for r in 0..q {
                    e[r] = row.htbeta[[r, col]];
                }
                let solved = cholesky_solve_vector(factor.view(), e.view());
                for r in 0..q {
                    y[[r, col]] = solved[r];
                }
            }
            for r in 0..border_dim {
                for c in 0..border_dim {
                    let mut acc = 0.0;
                    for d in 0..q {
                        acc += row.htbeta[[d, r]] * y[[d, c]];
                    }
                    s_dense[[r, c]] -= acc;
                }
            }
        }

        // Probe with four deterministic vectors and track the worst error,
        // scaled by max(1, max|want|) of the respective probe.
        let mut max_rel = 0.0_f64;
        for trial in 0..4 {
            let x: Vec<f64> = (0..border_dim)
                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
                .collect();
            let mut got = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
                .expect("framed matvec");
            let mut want = vec![0.0_f64; border_dim];
            for r in 0..border_dim {
                let mut acc = 0.0;
                for c in 0..border_dim {
                    acc += s_dense[[r, c]] * x[c];
                }
                want[r] = acc;
            }
            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (got[a] - want[a]).abs() / scale;
                max_rel = max_rel.max(rel);
            }
        }
        assert!(
            max_rel <= 1e-10,
            "large-K framed SAE Schur matvec vs dense reference diverged: \
             max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
        );
    }
6073
/// End-to-end check of the framed SAE PCG solve against the CPU matvec oracle.
///
/// A deterministic synthetic fixture is assembled (8 atoms whose ranks alternate
/// between 3 and `p`, 400 rows of dimension 4) and handed to
/// `solve_sae_matrix_free_pcg`. The returned vector is then pushed through
/// `sae_framed_schur_matvec_cpu` and the residual against `rhs` must be small.
/// A host `pcg_core` solve of the same operator is run as a fixture sanity
/// check. When the solve returns `Err`, the test only passes if no GPU runtime
/// is registered.
#[test]
fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
    use crate::arrow_schur::{
        BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
        FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
    };

    // ---- Border layout: per-atom rank / basis size and running offsets. ----
    let p = 6usize;
    let n_atoms = 8usize;
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|atom| if atom % 2 == 0 { 3usize } else { p })
        .collect();
    let basis_sizes: Vec<usize> = vec![3usize; n_atoms];
    let mut running = 0usize;
    let border_offsets: Vec<usize> = basis_sizes
        .iter()
        .zip(ranks.iter())
        .map(|(&m, &r)| {
            let start = running;
            running += m * r;
            start
        })
        .collect();
    let border_dim = running;

    // ---- Deterministic LCG; every fixture value is drawn from this stream,
    // so the order of `draw()` calls below must not change. ----
    let mut lcg_state = 0xfeed_face_dead_beefu64;
    let mut draw = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((lcg_state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // Full-rank atoms get the identity frame (no draws); the others get a
    // sampled p x r frame filled in row-major order.
    let frames: Vec<Array2<f64>> = ranks
        .iter()
        .map(|&r| {
            if r == p {
                Array2::<f64>::eye(p)
            } else {
                let mut u = Array2::<f64>::zeros((p, r));
                for i in 0..p {
                    for j in 0..r {
                        u[[i, j]] = draw();
                    }
                }
                u
            }
        })
        .collect();

    // W_ij = U_i^T U_j, accumulated sequentially over the shared axis.
    let cross_gram = |i: usize, j: usize| -> Array2<f64> {
        let (left, right) = (&frames[i], &frames[j]);
        let mut w = Array2::<f64>::zeros((ranks[i], ranks[j]));
        for ((a, b), slot) in w.indexed_iter_mut() {
            *slot = (0..p).fold(0.0_f64, |s, c| s + left[[c, a]] * right[[c, b]]);
        }
        w
    };

    // Diagonal pairs first, then three symmetric off-diagonal couplings.
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    pairs.extend(
        [(0usize, 1usize), (2, 4), (3, 6)]
            .iter()
            .flat_map(|&(i, j)| [(i, j), (j, i)]),
    );

    let mut frame_blocks = Vec::with_capacity(pairs.len());
    for &(i, j) in &pairs {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g = Array2::<f64>::zeros((mi, mj));
        for r in 0..mi {
            for c in 0..mj {
                g[[r, c]] = 0.25 * draw();
            }
        }
        if i == j {
            // Shift the diagonal of the self-coupling block.
            g.diag_mut()
                .iter_mut()
                .for_each(|v| *v += mi as f64 + 2.0);
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: cross_gram(i, j),
        });
    }

    // Smooth penalty factors: A^T A + I per atom.
    let mut smooth_blocks = Vec::with_capacity(n_atoms);
    for atom in 0..n_atoms {
        let m = basis_sizes[atom];
        let mut a = Array2::<f64>::zeros((m, m));
        for r in 0..m {
            for c in 0..m {
                a[[r, c]] = 0.2 * draw();
            }
        }
        let mut s = a.t().dot(&a);
        s.diag_mut().iter_mut().for_each(|v| *v += 1.0);
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[atom],
            factor_a: s,
        });
    }
    let smooth_ranks = ranks.clone();

    // ---- Per-row blocks: htt = A^T A + (q + 1) I and a small htbeta slab,
    // stored both in `sys` and as a flat row-major slab. ----
    let n = 400usize;
    let q = 4usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::with_capacity(n);
    for i in 0..n {
        let mut a = Array2::<f64>::zeros((q, q));
        for r in 0..q {
            for c in 0..q {
                a[[r, c]] = draw();
            }
        }
        let mut htt = a.t().dot(&a);
        htt.diag_mut()
            .iter_mut()
            .for_each(|v| *v += q as f64 + 1.0);
        let row = &mut sys.rows[i];
        row.htt = htt;
        let mut slab = Vec::with_capacity(q * border_dim);
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.02 * draw();
                // Pushing in (c, col) order lands `v` at index c * border_dim + col.
                slab.push(v);
                row.htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    // ---- Dense hbb = frame operator + per-atom smooth penalties. ----
    let data_op =
        FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
            .expect("frame op");
    let mut hbb = data_op.to_dense();
    for (atom, smooth) in smooth_blocks.iter().enumerate() {
        let penalty = IdentityRightKroneckerPenaltyOp {
            factor_a: smooth.factor_a.clone(),
            p: ranks[atom],
            global_offset: border_offsets[atom],
            k: border_dim,
        };
        let dense = penalty.to_dense();
        for r in 0..border_dim {
            for c in 0..border_dim {
                hbb[[r, c]] += dense[[r, c]];
            }
        }
    }
    sys.hbb = hbb;

    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };
    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;
    let rhs: Array1<f64> =
        Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());

    // ---- Solve under test. An `Err` is only acceptable without a GPU runtime. ----
    let solve_outcome =
        solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12);
    let (device, device_diag) = match solve_outcome {
        Ok(solved) => solved,
        Err(failure) => {
            assert!(
                gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                "#1017: CUDA device present but the framed device SAE PCG \
                 declined/faulted instead of returning a result (tag: {failure:?}) — \
                 the kernel does not run correctly on GPU"
            );
            return;
        }
    };

    let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();

    // ‖S·x − rhs‖₂ with S applied through the CPU matvec.
    let oracle_resid = |x: &Array1<f64>| -> f64 {
        let mut sx = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(
            &sys,
            &data,
            ridge_t,
            ridge_beta,
            x.as_slice().unwrap(),
            &mut sx,
        )
        .expect("cpu oracle matvec");
        sx.iter()
            .zip(rhs.iter())
            .fold(0.0_f64, |acc, (s, b)| {
                let e = s - b;
                acc + e * e
            })
            .sqrt()
    };
    let s_dev_resid = oracle_resid(&device);
    let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);

    // ---- Host-side diagonal for the CPU PCG: penalty diagonal minus the
    // per-row quadratic form h_a^T (htt + ridge_t I)^{-1} h_a. ----
    let precond = {
        let mut schur_diag = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
        for (i, row) in sys.rows.iter().enumerate() {
            let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
            let qi = sys.row_dims[i];
            if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
                continue;
            }
            let mut block = row.htt.clone();
            for dd in 0..qi {
                block[[dd, dd]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
                .expect("row htt PD");
            // Explicit inverse, one unit-vector solve per column.
            let mut ainv = Array2::<f64>::zeros((qi, qi));
            for col in 0..qi {
                let mut unit = Array1::<f64>::zeros(qi);
                unit[col] = 1.0;
                let solved = cholesky_solve_vector(factor.view(), unit.view());
                for r in 0..qi {
                    ainv[[r, col]] = solved[r];
                }
            }
            for a in 0..border_dim {
                let mut quad = 0.0_f64;
                for c in 0..qi {
                    let hc = slab[c * border_dim + a];
                    for dd in 0..qi {
                        quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
                    }
                }
                schur_diag[a] -= quad;
            }
        }
        Array1::from_vec(schur_diag)
    };

    // ---- CPU PCG on the oracle operator. ----
    let mut cpu = Array1::<f64>::zeros(border_dim);
    let cpu_result = {
        let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
            let mut scratch = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(
                &sys,
                &data,
                ridge_t,
                ridge_beta,
                v.as_slice().unwrap(),
                &mut scratch,
            )
            .expect("cpu oracle matvec");
            out.assign(&Array1::from_vec(scratch));
        };
        gam_linalg::pcg::pcg_core(
            &mut apply,
            &rhs.view(),
            &precond.view(),
            1e-12,
            800,
            32,
            false,
            gam_linalg::pcg::DotReduction::Serial,
            &mut cpu.view_mut(),
        )
    };
    let s_cpu_resid = oracle_resid(&cpu);
    let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);

    assert!(
        dev_rel_resid <= 1e-7,
        "[#1551] device δβ does not solve the CPU-oracle system: \
         ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
         device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
         means the device matvec is a DIFFERENT operator (kernel bug)",
        device_diag.stopping_reason,
        device_diag.iterations,
        device_diag.final_relative_residual,
    );
    assert!(
        cpu_rel_resid <= 1e-6,
        "[#1551] CPU pcg_core failed to solve the oracle system: \
         ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
        cpu_result.stop,
        cpu_result.iterations,
    );
}
6411
6412 fn sae_frame_penalty_diag_host_for_test(
6416 data: &DeviceSaePcgData,
6417 ridge_beta: f64,
6418 ) -> Vec<f64> {
6419 let frame = data.frame.as_ref().expect("frame");
6420 let mut diag = vec![ridge_beta; data.beta_dim];
6421 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6422 let m = blk.factor_a.nrows();
6423 for ia in 0..m {
6424 let coeff = blk.factor_a[[ia, ia]];
6425 let base = blk.global_offset + ia * r;
6426 for ib in 0..r {
6427 diag[base + ib] += coeff;
6428 }
6429 }
6430 }
6431 for blk in &frame.frame_blocks {
6432 if blk.atom_i != blk.atom_j {
6433 continue;
6434 }
6435 let r = frame.ranks[blk.atom_i];
6436 let off = frame.border_offsets[blk.atom_i];
6437 let (mi, mj) = blk.g.dim();
6438 for li in 0..mi.min(mj) {
6439 let gii = blk.g[[li, li]];
6440 let base = off + li * r;
6441 for a in 0..r {
6442 diag[base + a] += gii * blk.w[[a, a]];
6443 }
6444 }
6445 }
6446 diag
6447 }
6448
/// Single-matvec parity check between the device framed Schur operator and
/// the CPU oracle.
///
/// Assembles a deterministic synthetic framed fixture (8 atoms whose ranks
/// alternate between 3 and `p`, 32 rows of dimension 4), then applies
/// `framed_schur_matvec_once_on_device` to several probe vectors and compares
/// every component with `sae_framed_schur_matvec_cpu`. When the device call
/// returns `Err`, the test only passes if no GPU runtime is registered.
#[test]
fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
    use crate::arrow_schur::{
        DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
    };

    // ---- Border layout: per-atom rank / basis size and running offsets. ----
    let p = 6usize;
    let n_atoms = 8usize;
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|atom| if atom % 2 == 0 { 3usize } else { p })
        .collect();
    let basis_sizes: Vec<usize> = vec![3usize; n_atoms];
    let mut running = 0usize;
    let border_offsets: Vec<usize> = basis_sizes
        .iter()
        .zip(ranks.iter())
        .map(|(&m, &r)| {
            let start = running;
            running += m * r;
            start
        })
        .collect();
    let border_dim = running;

    // ---- Deterministic LCG; the order of `draw()` calls below is part of
    // the fixture and must not change. ----
    let mut lcg_state = 0x1551_0017_1026_0922u64;
    let mut draw = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((lcg_state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // Full-rank atoms get the identity frame (no draws); the others get a
    // sampled p x r frame filled in row-major order.
    let frames: Vec<Array2<f64>> = ranks
        .iter()
        .map(|&r| {
            if r == p {
                Array2::<f64>::eye(p)
            } else {
                let mut u = Array2::<f64>::zeros((p, r));
                for i in 0..p {
                    for j in 0..r {
                        u[[i, j]] = draw();
                    }
                }
                u
            }
        })
        .collect();

    // W_ij = U_i^T U_j, accumulated sequentially over the shared axis.
    let cross_gram = |i: usize, j: usize| -> Array2<f64> {
        let (left, right) = (&frames[i], &frames[j]);
        let mut w = Array2::<f64>::zeros((ranks[i], ranks[j]));
        for ((a, b), slot) in w.indexed_iter_mut() {
            *slot = (0..p).fold(0.0_f64, |s, c| s + left[[c, a]] * right[[c, b]]);
        }
        w
    };

    // Diagonal pairs first, then three symmetric off-diagonal couplings.
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    pairs.extend(
        [(0usize, 1usize), (2, 4), (3, 6)]
            .iter()
            .flat_map(|&(i, j)| [(i, j), (j, i)]),
    );

    let mut frame_blocks = Vec::with_capacity(pairs.len());
    for &(i, j) in &pairs {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g = Array2::<f64>::zeros((mi, mj));
        for r in 0..mi {
            for c in 0..mj {
                g[[r, c]] = 0.25 * draw();
            }
        }
        if i == j {
            // Shift the diagonal of the self-coupling block.
            g.diag_mut()
                .iter_mut()
                .for_each(|v| *v += mi as f64 + 2.0);
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: cross_gram(i, j),
        });
    }

    // Smooth penalty factors: A^T A + I per atom.
    let mut smooth_blocks = Vec::with_capacity(n_atoms);
    for atom in 0..n_atoms {
        let m = basis_sizes[atom];
        let mut a = Array2::<f64>::zeros((m, m));
        for r in 0..m {
            for c in 0..m {
                a[[r, c]] = 0.2 * draw();
            }
        }
        let mut s = a.t().dot(&a);
        s.diag_mut().iter_mut().for_each(|v| *v += 1.0);
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[atom],
            factor_a: s,
        });
    }
    let smooth_ranks = ranks.clone();

    // ---- Per-row blocks: htt = A^T A + (q + 1) I and an htbeta slab,
    // stored both in `sys` and as a flat row-major slab. ----
    let n = 32usize;
    let q = 4usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::with_capacity(n);
    for i in 0..n {
        let mut a = Array2::<f64>::zeros((q, q));
        for r in 0..q {
            for c in 0..q {
                a[[r, c]] = draw();
            }
        }
        let mut htt = a.t().dot(&a);
        htt.diag_mut()
            .iter_mut()
            .for_each(|v| *v += q as f64 + 1.0);
        let row = &mut sys.rows[i];
        row.htt = htt;
        let mut slab = Vec::with_capacity(q * border_dim);
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.3 * draw();
                // Pushing in (c, col) order lands `v` at index c * border_dim + col.
                slab.push(v);
                row.htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;
    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };

    // ---- Probe vectors: a smooth sine, one sampled vector, three unit axes. ----
    let mut probes: Vec<Array1<f64>> = vec![
        Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.37).sin()),
        Array1::from_shape_fn(border_dim, |_| draw()),
    ];
    for axis in [0usize, border_dim / 3, border_dim - 1] {
        let mut unit = Array1::<f64>::zeros(border_dim);
        unit[axis] = 1.0;
        probes.push(unit);
    }

    let mut any_ran = false;
    let mut worst = 0.0_f64;
    for (pi, x) in probes.iter().enumerate() {
        let device_outcome =
            super::framed_schur_matvec_once_on_device(&sys, &data, ridge_t, ridge_beta, x);
        let device = match device_outcome {
            Ok(out) => out,
            Err(failure) => {
                // An `Err` is only acceptable when there is no GPU runtime.
                assert!(
                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                    "#1551: CUDA device present but the framed device matvec \
                     declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
                     does not run on GPU"
                );
                return;
            }
        };
        any_ran = true;

        let mut cpu = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(
            &sys,
            &data,
            ridge_t,
            ridge_beta,
            x.as_slice().unwrap(),
            &mut cpu,
        )
        .expect("cpu oracle matvec");

        // Errors are measured relative to max(‖cpu‖∞, 1).
        let scale = cpu
            .iter()
            .fold(0.0_f64, |largest, v| largest.max(v.abs()))
            .max(1.0);
        for a in 0..border_dim {
            let rel = (device[a] - cpu[a]).abs() / scale;
            worst = worst.max(rel);
            assert!(
                rel <= 1e-9,
                "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
                 rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
                device[a],
                cpu[a],
            );
        }
    }

    if any_ran {
        assert!(
            gam_gpu::device_runtime::GpuRuntime::global().is_some(),
            "#1551: matvec ran but no GPU runtime — unexpected"
        );
        assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
    }
}
6677}