1use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
24pub struct ArrowSchurGpuSolution {
26 pub delta_t: Array1<f64>,
27 pub delta_beta: Array1<f64>,
28 pub log_det_hessian: f64,
31}
32
/// Reasons an arrow-Schur GPU entry point can decline or fail.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// No usable GPU path. Causes include a non-Linux build, no CUDA runtime or
    /// not enough devices, an empty or unsupported problem shape, and a failed
    /// device allocation, transfer or library call.
    Unavailable,
    /// The ridged local block of `row` failed its Cholesky factorization.
    /// `bump` is a suggested additional diagonal shift beyond the current
    /// `ridge_t`: a Gershgorin deficit estimate plus a small safety margin.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// Input was rejected, or the reduced Schur complement could not be
    /// factored. Rejected input includes a NaN ridge, dimension mismatches and
    /// invalid CG controls. `reason` is a human-readable explanation.
    SchurFactorFailed { reason: String },
    /// The system carries matrix-free operators, which the dense GPU solve
    /// cannot consume. The flags record which operators were present.
    GpuRequiresDenseSystem {
        had_hbb_matvec: bool,
        had_htbeta_matvec: bool,
    },
}
56
/// Safety multiplier on `scale * sqrt(f64::EPSILON)` used by the ridge-bump
/// helpers as a margin on top of the Gershgorin deficit. `scale` is at least 1
/// and grows with the largest absolute diagonal entry of the block.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
68
69#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106 let d = htt.nrows();
107 let mut scale = 1.0_f64;
110 let mut min_gershgorin_edge = f64::INFINITY;
111 for i in 0..d {
112 let diag = htt[[i, i]];
113 scale = scale.max(diag.abs());
114 let mut off_sum = 0.0_f64;
115 for j in 0..d {
116 if j != i {
117 off_sum += htt[[i, j]].abs();
118 }
119 }
120 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121 }
122 if !min_gershgorin_edge.is_finite() {
123 return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126 }
127 let deficit = -(min_gershgorin_edge + ridge_t);
130 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131 deficit.max(0.0) + margin
134}
135
136#[must_use]
143fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
144 if d == 0 || block.len() < d * d {
145 return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
146 }
147 let mut scale = 1.0_f64;
150 let mut min_gershgorin_edge = f64::INFINITY;
151 for i in 0..d {
152 let diag = block[i * d + i];
153 scale = scale.max(diag.abs());
154 let mut off_sum = 0.0_f64;
155 for j in 0..d {
156 if j != i {
157 off_sum += block[j * d + i].abs();
158 }
159 }
160 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
161 }
162 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
163 if !min_gershgorin_edge.is_finite() {
164 return margin;
165 }
166 (-min_gershgorin_edge).max(0.0) + margin
167}
168
169pub fn solve_arrow_newton_step(
173 sys: &ArrowSchurSystem,
174 ridge_t: f64,
175 ridge_beta: f64,
176) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
177 let n = sys.rows.len();
178 let d = sys.d;
179 let k = sys.k;
180
181 let had_hbb_matvec = sys.hbb_matvec.is_some();
186 let had_htbeta_matvec = sys.htbeta_matvec.is_some();
187 if had_hbb_matvec || had_htbeta_matvec {
188 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
189 had_hbb_matvec,
190 had_htbeta_matvec,
191 });
192 }
193
194 if sys.hbb.dim() != (k, k) {
195 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
196 reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
197 });
198 }
199 if n == 0 || d == 0 {
200 return Err(ArrowSchurGpuFailure::Unavailable);
201 }
202 if sys
203 .rows
204 .iter()
205 .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
206 {
207 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
208 reason: "row block dimension mismatch".to_string(),
209 });
210 }
211
212 #[cfg(not(target_os = "linux"))]
213 {
214 if ridge_t.is_nan() || ridge_beta.is_nan() {
215 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
216 reason: "ridge is NaN".to_string(),
217 });
218 }
219 Err(ArrowSchurGpuFailure::Unavailable)
220 }
221
222 #[cfg(target_os = "linux")]
223 {
224 if gam_gpu::device_runtime::GpuRuntime::global()
233 .map(gam_gpu::device_runtime::GpuRuntime::device_count)
234 .unwrap_or(0)
235 > 1
236 {
237 match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
238 Ok(sol) => return Ok(sol),
239 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
240 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
241 }
242 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
243 return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
244 }
245 Err(_) => {}
248 }
249 }
250 if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
256 match cuda::solve_fused(sys, ridge_t, ridge_beta) {
257 Ok(sol) => return Ok(sol),
258 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
262 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
263 }
264 Err(_) => {}
268 }
269 }
270 cuda::solve(sys, ridge_t, ridge_beta)
271 }
272}
273
274#[cfg(target_os = "linux")]
280fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
281 let n = sys.rows.len();
282 let d = sys.d;
283 let k = sys.k;
284 let mut d_buf = Vec::with_capacity(n * d * d);
285 let mut b_buf = Vec::with_capacity(n * d * k);
286 let mut g_buf = Vec::with_capacity(n * d);
287 for row in &sys.rows {
288 pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
289 }
290 (d_buf, b_buf, g_buf)
291}
292
293#[cfg(target_os = "linux")]
294#[inline]
295fn pack_block(
296 row: &crate::arrow_schur::ArrowRowBlock,
297 ridge_t: f64,
298 d: usize,
299 k: usize,
300 d_buf: &mut Vec<f64>,
301 b_buf: &mut Vec<f64>,
302 g_buf: &mut Vec<f64>,
303) {
304 for col in 0..d {
305 for r in 0..d {
306 let mut value = row.htt[[r, col]];
307 if r == col {
308 value += ridge_t;
309 }
310 d_buf.push(value);
311 }
312 }
313 for col in 0..k {
314 for r in 0..d {
315 b_buf.push(row.htbeta[[r, col]]);
316 }
317 }
318 for r in 0..d {
319 g_buf.push(row.gt[r]);
320 }
321}
322
323#[doc(hidden)]
328#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] pub fn solve_arrow_newton_step_fused_force(
330 sys: &ArrowSchurSystem,
331 ridge_t: f64,
332 ridge_beta: f64,
333) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
334 if ridge_t.is_nan() || ridge_beta.is_nan() {
335 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
336 reason: "ridge is NaN".to_string(),
337 });
338 }
339 #[cfg(not(target_os = "linux"))]
340 {
341 Err(ArrowSchurGpuFailure::Unavailable)
346 }
347 #[cfg(target_os = "linux")]
348 {
349 if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
350 .is_none()
351 {
352 return Err(ArrowSchurGpuFailure::Unavailable);
353 }
354 cuda::solve_fused(sys, ridge_t, ridge_beta)
355 }
356}
357
/// Wraps a `cuda::ResidentArrowFrame`. That frame is built once from a system
/// and its ridges, then reused by `solve_gradient` for new right-hand sides.
/// NOTE(review): presumably the factor state stays device-resident; see that
/// type.
///
/// Off Linux the only field is `std::convert::Infallible`, so the type cannot
/// be constructed there and `new` always returns an error.
pub struct ResidentArrowFrameHandle {
    #[cfg(target_os = "linux")]
    inner: cuda::ResidentArrowFrame,
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
373
374impl ResidentArrowFrameHandle {
375 pub fn new(
377 sys: &ArrowSchurSystem,
378 ridge_t: f64,
379 ridge_beta: f64,
380 ) -> Result<Self, ArrowSchurGpuFailure> {
381 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
384 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
385 had_hbb_matvec: sys.hbb_matvec.is_some(),
386 had_htbeta_matvec: sys.htbeta_matvec.is_some(),
387 });
388 }
389 #[cfg(not(target_os = "linux"))]
390 {
391 if ridge_t.is_nan() || ridge_beta.is_nan() {
392 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
393 reason: "ridge is NaN".to_string(),
394 });
395 }
396 Err(ArrowSchurGpuFailure::Unavailable)
397 }
398 #[cfg(target_os = "linux")]
399 {
400 Ok(Self {
401 inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
402 })
403 }
404 }
405
406 pub fn solve_gradient(
408 &self,
409 g_t: &[f64],
410 g_beta: &[f64],
411 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
412 #[cfg(not(target_os = "linux"))]
413 {
414 if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
415 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
416 reason: "non-finite gradient entry".to_string(),
417 });
418 }
419 Err(ArrowSchurGpuFailure::Unavailable)
420 }
421 #[cfg(target_os = "linux")]
422 {
423 self.inner.solve_gradient(g_t, g_beta)
424 }
425 }
426
427 #[must_use]
429 pub fn log_det_hessian(&self) -> f64 {
430 #[cfg(not(target_os = "linux"))]
431 {
432 panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
438 }
439 #[cfg(target_os = "linux")]
440 {
441 self.inner.log_det_hessian()
442 }
443 }
444}
445
446pub fn gpu_schur_matvec_backend(
489 sys: &ArrowSchurSystem,
490 ridge_t: f64,
491 ridge_beta: f64,
492) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
493 if sys.htbeta_matvec.is_some() {
496 return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
497 }
498
499 #[cfg(not(target_os = "linux"))]
500 {
501 if ridge_t.is_nan() || ridge_beta.is_nan() {
504 return Err(ArrowSchurGpuFailure::Unavailable);
505 }
506 Err(ArrowSchurGpuFailure::Unavailable)
507 }
508
509 #[cfg(target_os = "linux")]
510 {
511 cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
512 }
513}
514
/// Builds a host-side Schur-complement matvec for systems whose `H_tβ` coupling
/// is only available as row-procedural operators (`htbeta_matvec` and
/// `htbeta_transpose_matvec`).
///
/// Every ridged local block `H_tt,i + ridge_t * I` is Cholesky-factored once,
/// up front. The returned closure then computes, for each call:
/// 1. `out = ridge_beta * x`, after which `penalty_op.matvec(x, out)` is
///    applied. NOTE(review): presumably it accumulates the penalty term rather
///    than overwriting `out`; confirm.
/// 2. `out -= Σ_i transpose_i((H_tt,i + ridge_t I)^{-1} forward_i(x))`.
///
/// Local block sizes may differ between rows (`row_dims`).
///
/// # Errors
/// - `Unavailable` if `htbeta_matvec` is `None`.
/// - `SchurFactorFailed` if `htbeta_transpose_matvec` is missing or a row's
///   `H_tt`/`g_t` shapes disagree.
/// - `RidgeBumpRequired` when a ridged local block fails Cholesky.
///
/// # Panics
/// The closure panics if `x` or `out` does not have length `k`, or is not
/// contiguous.
fn build_row_procedural_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
    use std::sync::Arc;
    let n = sys.rows.len();
    let k = sys.k;
    let forward = sys
        .htbeta_matvec
        .clone()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
        ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
                     forward operator installed without its sparse adjoint"
                .to_string(),
        }
    })?;

    // One Cholesky factor per row, computed once and captured by the closure.
    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = row.htt.nrows();
        if row.htt.ncols() != di || row.gt.len() != di {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
            });
        }
        let mut block = row.htt.clone();
        for r in 0..di {
            block[[r, r]] += ridge_t;
        }
        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
            .ok_or_else(|| {
                ArrowSchurGpuFailure::RidgeBumpRequired {
                    row: i,
                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
                }
            })?;
        factors.push(factor);
    }

    let penalty_op = sys.effective_penalty_op();
    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();

    let closure: crate::arrow_schur::GpuSchurMatvec =
        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");

            // Diagonal ridge plus penalty contribution, on contiguous slices.
            {
                let x_slice = x.as_slice().expect("x must be contiguous");
                let out_slice = out.as_slice_mut().expect("out must be contiguous");
                for a in 0..k {
                    out_slice[a] = ridge_beta * x_slice[a];
                }
                penalty_op.matvec(x_slice, out_slice);
            }

            // Parallelise only for large systems, and only when not already on a
            // rayon worker (avoids nested fan-out).
            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
                && rayon::current_thread_index().is_none();
            if parallel {
                use rayon::prelude::*;
                const CHUNK: usize = 64;
                // Each 64-row chunk accumulates its own partial. `collect` keeps
                // chunk order, so the reduction below is deterministic.
                let partials: Vec<Array1<f64>> = (0..n)
                    .into_par_iter()
                    .chunks(CHUNK)
                    .map(|idxs| {
                        let mut neg = Array1::<f64>::zeros(k);
                        for i in idxs {
                            let di = row_dims[i];
                            let mut v_i = Array1::<f64>::zeros(di);
                            forward(i, x.view(), &mut v_i);
                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
                            // Assumes `transpose` adds into `neg` (it is shared
                            // across the chunk's rows and never reset).
                            transpose(i, w_i.view(), &mut neg);
                        }
                        neg
                    })
                    .collect();
                let mut neg = Array1::<f64>::zeros(k);
                for part in &partials {
                    for a in 0..k {
                        neg[a] += part[a];
                    }
                }
                for a in 0..k {
                    out[a] -= neg[a];
                }
            } else {
                let mut neg = Array1::<f64>::zeros(k);
                for i in 0..n {
                    let di = row_dims[i];
                    let mut v_i = Array1::<f64>::zeros(di);
                    forward(i, x.view(), &mut v_i);
                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
                    transpose(i, w_i.view(), &mut neg);
                }
                for a in 0..k {
                    out[a] -= neg[a];
                }
            }
        });

    Ok(closure)
}
699
700pub fn solve_reduced_beta_pcg(
722 s_acc: &Array2<f64>,
723 rhs_beta: &Array1<f64>,
724 max_iterations: usize,
725 relative_tolerance: f64,
726) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
727 solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
728 .map(|(x, _)| x)
729}
730
731#[doc(hidden)]
732pub fn solve_reduced_beta_pcg_with_diagnostics(
733 s_acc: &Array2<f64>,
734 rhs_beta: &Array1<f64>,
735 max_iterations: usize,
736 relative_tolerance: f64,
737) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
738 let k = rhs_beta.len();
739 if s_acc.dim() != (k, k) {
740 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
741 reason: format!(
742 "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
743 s_acc.dim()
744 ),
745 });
746 }
747 if k == 0 {
748 return Err(ArrowSchurGpuFailure::Unavailable);
749 }
750
751 #[cfg(not(target_os = "linux"))]
752 {
753 if relative_tolerance.is_nan() || max_iterations == 0 {
754 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
755 reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
756 });
757 }
758 Err(ArrowSchurGpuFailure::Unavailable)
759 }
760
761 #[cfg(target_os = "linux")]
762 {
763 cuda::solve_reduced_beta_pcg_with_diagnostics(
764 s_acc,
765 rhs_beta,
766 max_iterations,
767 relative_tolerance,
768 )
769 }
770}
771
772pub fn solve_sae_matrix_free_pcg(
773 sys: &ArrowSchurSystem,
774 data: &DeviceSaePcgData,
775 ridge_t: f64,
776 ridge_beta: f64,
777 rhs_beta: &Array1<f64>,
778 max_iterations: usize,
779 relative_tolerance: f64,
780) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
781 if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
782 return Err(ArrowSchurGpuFailure::Unavailable);
783 }
784 #[cfg(not(target_os = "linux"))]
785 {
786 if ridge_t.is_nan()
787 || ridge_beta.is_nan()
788 || relative_tolerance.is_nan()
789 || max_iterations == 0
790 {
791 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
792 reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
793 });
794 }
795 Err(ArrowSchurGpuFailure::Unavailable)
796 }
797 #[cfg(target_os = "linux")]
798 {
799 if data.frame.is_some() {
806 cuda::solve_sae_matrix_free_pcg_framed(
807 sys,
808 data,
809 ridge_t,
810 ridge_beta,
811 rhs_beta,
812 max_iterations,
813 relative_tolerance,
814 )
815 } else {
816 cuda::solve_sae_matrix_free_pcg(
817 sys,
818 data,
819 ridge_t,
820 ridge_beta,
821 rhs_beta,
822 max_iterations,
823 relative_tolerance,
824 )
825 }
826 }
827}
828
829#[doc(hidden)]
838pub fn framed_schur_matvec_once_on_device(
839 sys: &ArrowSchurSystem,
840 data: &DeviceSaePcgData,
841 ridge_t: f64,
842 ridge_beta: f64,
843 x: &Array1<f64>,
844) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
845 if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
846 return Err(ArrowSchurGpuFailure::Unavailable);
847 }
848 if data.frame.is_none() {
849 return Err(ArrowSchurGpuFailure::Unavailable);
850 }
851 #[cfg(not(target_os = "linux"))]
852 {
853 if ridge_t.is_finite() && ridge_beta.is_finite() {
858 return Err(ArrowSchurGpuFailure::Unavailable);
859 }
860 Err(ArrowSchurGpuFailure::Unavailable)
861 }
862 #[cfg(target_os = "linux")]
863 {
864 cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
865 }
866}
867
868#[doc(hidden)]
872pub fn solve_arrow_newton_step_dense_reference(
873 sys: &ArrowSchurSystem,
874 ridge_t: f64,
875 ridge_beta: f64,
876) -> Result<ArrowSchurGpuSolution, String> {
877 let n = sys.rows.len();
878 let d = sys.d;
879 let k = sys.k;
880 let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
881 let mut h = Array2::<f64>::zeros((total, total));
882 let mut rhs = Array1::<f64>::zeros(total);
883 for (i, row) in sys.rows.iter().enumerate() {
884 let base = i * d;
885 for c in 0..d {
886 for r in 0..d {
887 h[[base + r, base + c]] = row.htt[[r, c]];
888 }
889 h[[base + c, base + c]] += ridge_t;
890 }
891 for c in 0..k {
892 for r in 0..d {
893 let value = row.htbeta[[r, c]];
894 h[[base + r, n * d + c]] = value;
895 h[[n * d + c, base + r]] = value;
896 }
897 }
898 for r in 0..d {
899 rhs[base + r] = -row.gt[r];
900 }
901 }
902 for c in 0..k {
903 for r in 0..k {
904 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
905 }
906 h[[n * d + c, n * d + c]] += ridge_beta;
907 rhs[n * d + c] = -sys.gb[c];
908 }
909 let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
910 .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
911 let mut log_det = 0.0_f64;
912 for i in 0..total {
913 log_det += factor[[i, i]].ln();
914 }
915 log_det *= 2.0;
916 let solved = cholesky_solve_vector(factor.view(), rhs.view());
917 let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
918 let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
919 Ok(ArrowSchurGpuSolution {
920 delta_t,
921 delta_beta,
922 log_det_hessian: log_det,
923 })
924}
925
/// Host reference for the framed SAE penalty matvec:
/// `out = ridge_beta * x + (smooth-block terms) + (frame-block terms)`.
///
/// First, `out[..k]` is overwritten with `ridge_beta * x[..k]`, where
/// `k = data.beta_dim`. Then two families of terms are accumulated:
///
/// * Smooth blocks. Each `data.smooth_blocks[b]` is paired (via `zip`, so
///   surplus blocks or ranks are ignored) with rank `r = frame.smooth_ranks[b]`.
///   The `m × m` matrix `factor_a` is applied along the leading index of the
///   `m × r` tile stored row-major at `global_offset`:
///   `out[off + i_a*r + i_b] += Σ_j factor_a[i_a, j] * x[off + j*r + i_b]`.
/// * Frame blocks. Each block maps the tile of `atom_j` (input) into the tile of
///   `atom_i` (output). Tiles start at `frame.border_offsets[atom]` and have rank
///   `frame.ranks[atom]`:
///   `out[off_i + li*r_i + a] += g[li, lj] * Σ_b w[a, b] * x[off_j + lj*r_j + b]`.
///   Only this one direction is applied per stored block.
///   NOTE(review): presumably the mirrored `(j, i)` block is stored separately
///   when the operator must be symmetric; confirm.
///
/// Exact zeros in `factor_a` and `g` are skipped, so non-finite `x` entries
/// behind a zero coefficient do not propagate.
///
/// # Panics
/// Panics if `data.frame` is `None`. Also panics on out-of-range indexing when
/// `x`/`out` are shorter than `k` or than the offsets and ranks require.
#[doc(hidden)]
pub fn sae_framed_penalty_matvec_cpu(
    data: &DeviceSaePcgData,
    ridge_beta: f64,
    x: &[f64],
    out: &mut [f64],
) {
    let frame = data
        .frame
        .as_ref()
        .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
    let k = data.beta_dim;
    for a in 0..k {
        out[a] = ridge_beta * x[a];
    }
    for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
        let off = blk.global_offset;
        let m = blk.factor_a.nrows();
        for i_a in 0..m {
            for i_b in 0..r {
                let mut acc = 0.0_f64;
                for j_a in 0..m {
                    let s = blk.factor_a[[i_a, j_a]];
                    if s == 0.0 {
                        continue;
                    }
                    acc += s * x[off + j_a * r + i_b];
                }
                out[off + i_a * r + i_b] += acc;
            }
        }
    }
    for blk in &frame.frame_blocks {
        let r_i = frame.ranks[blk.atom_i];
        let r_j = frame.ranks[blk.atom_j];
        let off_i = frame.border_offsets[blk.atom_i];
        let off_j = frame.border_offsets[blk.atom_j];
        let (m_i, m_j) = blk.g.dim();
        for li in 0..m_i {
            let yi_base = off_i + li * r_i;
            for lj in 0..m_j {
                let g = blk.g[[li, lj]];
                if g == 0.0 {
                    continue;
                }
                let xj_base = off_j + lj * r_j;
                // (w · x_j-row) scaled by the frame coupling g[li, lj].
                for a in 0..r_i {
                    let mut acc = 0.0_f64;
                    for b in 0..r_j {
                        acc += blk.w[[a, b]] * x[xj_base + b];
                    }
                    out[yi_base + a] += g * acc;
                }
            }
        }
    }
}
995
996#[doc(hidden)]
1005pub fn sae_framed_schur_matvec_cpu(
1006 sys: &ArrowSchurSystem,
1007 data: &DeviceSaePcgData,
1008 ridge_t: f64,
1009 ridge_beta: f64,
1010 x: &[f64],
1011 out: &mut [f64],
1012) -> Result<(), String> {
1013 let frame = data
1014 .frame
1015 .as_ref()
1016 .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1017 let k = data.beta_dim;
1018 sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1019 if frame.row_htbeta.len() != sys.rows.len() {
1020 return Err(format!(
1021 "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1022 frame.row_htbeta.len(),
1023 sys.rows.len()
1024 ));
1025 }
1026 for (i, row) in sys.rows.iter().enumerate() {
1027 let slab = &frame.row_htbeta[i];
1028 if slab.is_empty() {
1029 continue;
1030 }
1031 let qi = sys.row_dims[i];
1032 if qi == 0 || slab.len() != qi * k {
1033 continue;
1034 }
1035 let mut h = vec![0.0_f64; qi];
1037 for c in 0..qi {
1038 let base = c * k;
1039 let mut acc = 0.0_f64;
1040 for a in 0..k {
1041 acc += slab[base + a] * x[a];
1042 }
1043 h[c] = acc;
1044 }
1045 let mut block = row.htt.clone();
1047 for d in 0..qi {
1048 block[[d, d]] += ridge_t;
1049 }
1050 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1051 .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1052 let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1053 for c in 0..qi {
1055 let sc = s[c];
1056 if sc == 0.0 {
1057 continue;
1058 }
1059 let base = c * k;
1060 for a in 0..k {
1061 out[a] -= slab[base + a] * sc;
1062 }
1063 }
1064 }
1065 Ok(())
1066}
1067
1068#[cfg(target_os = "linux")]
1069mod cuda {
1070 use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1071 use gam_gpu::driver::to_i32;
1072 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1073 use crate::arrow_schur::{
1074 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1075 };
1076 use cudarc::cublas::sys::{
1077 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1078 };
1079 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1080 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1081 use cudarc::driver::{
1082 CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1083 PushKernelArg,
1084 };
1085 use ndarray::Array1;
1086 use std::sync::{Arc, OnceLock};
1087
    /// Per-row host state carried between the multi-GPU forward pass
    /// (`forward_tile`) and the back-substitution pass (`back_sub_tile`).
    struct RowSlot {
        /// Ridged local block `H_tt + ridge_t * I`, column-major `d × d` (from `pack_block`).
        d_block: Vec<f64>,
        /// Coupling block `H_tβ`, column-major `d × k` (from `pack_block`).
        b_block: Vec<f64>,
        /// Local gradient `g_t` (length `d`).
        g_vec: Vec<f64>,
        /// Batched-Cholesky output for `d_block`, copied back by `forward_tile`.
        l_block: Vec<f64>,
        /// `g_vec` after the lower-triangular solve in `forward_tile`.
        u_vec: Vec<f64>,
        /// `b_block` after the lower-triangular solve in `forward_tile`.
        y_block: Vec<f64>,
        /// Sum of `ln` of the diagonal of `l_block` (not yet doubled).
        log_det_local: f64,
        /// Suggested ridge bump, set when this row's local factorization failed.
        bump: Option<f64>,
        /// Tile-level Schur partial sum. Only populated on the first slot of each tile.
        tile_partial_schur: Option<Vec<f64>>,
        /// Tile-level reduced-rhs partial sum. Only populated on the first slot of each tile.
        tile_partial_rhs: Option<Vec<f64>>,
        /// This row's local Newton update (length `d`), written by `back_sub_tile`.
        delta_t_block: Vec<f64>,
    }
1111
    /// Multi-GPU variant of the dense arrow-Schur Newton solve.
    ///
    /// 1. Rows are packed on the host into `RowSlot`s and scattered across devices
    ///    with `gam_gpu::pool::scatter_batched`.
    /// 2. Each tile's `forward_tile` factors its local blocks and produces partial
    ///    Schur/rhs sums.
    /// 3. Those partials are added on the host to column-major
    ///    `H_bb + ridge_beta * I` and `-g_b`.
    /// 4. The reduced `k × k` system is factored and solved on the runtime's
    ///    selected device.
    /// 5. A second scatter (`back_sub_tile`) recovers the local updates.
    ///
    /// `log_det_hessian` is `2 * (Σ_rows log_det_local + Σ ln diag(L_schur))`.
    ///
    /// # Errors
    /// - `Unavailable` in these cases; the caller `solve_arrow_newton_step` then
    ///   falls back to single-device paths:
    ///   - empty or mismatched shapes,
    ///   - matrix-free or non-dense `H_bb` systems,
    ///   - fewer than two devices,
    ///   - any CUDA failure,
    ///   - a tile whose partials are missing.
    /// - `RidgeBumpRequired` with the lowest global row index flagged by any tile.
    /// - `SchurFactorFailed` if the reduced Schur Cholesky reports a bad pivot.
    pub(super) fn solve_multi_gpu(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        if n == 0 || d == 0 || k == 0 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        if runtime.device_count() < 2 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Host packing, one slot per row, validated row by row.
        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
        for row in &sys.rows {
            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
            let mut d_block = Vec::with_capacity(d * d);
            let mut b_block = Vec::with_capacity(d * k);
            let mut g_vec = Vec::with_capacity(d);
            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
            slots.push(RowSlot {
                d_block,
                b_block,
                g_vec,
                l_block: Vec::new(),
                u_vec: Vec::new(),
                y_block: Vec::new(),
                log_det_local: 0.0,
                bump: None,
                tile_partial_schur: None,
                tile_partial_rhs: None,
                delta_t_block: vec![0.0; d],
            });
        }

        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            forward_tile(ordinal, d, k, tile)
        });
        if forward_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // NOTE(review): assumes `balanced_partition` yields the same tiles that
        // `scatter_batched` used, so each tile's partials sit on its first slot.
        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
        if let Some((row, bump)) = slots
            .iter()
            .enumerate()
            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
        {
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
        }

        // Column-major `H_bb + ridge_beta * I`, then add every tile's partials.
        let mut schur_host = vec![0.0_f64; k * k];
        for col in 0..k {
            for row in 0..k {
                let mut v = sys.hbb[[row, col]];
                if row == col {
                    v += ridge_beta;
                }
                schur_host[col * k + row] = v;
            }
        }
        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
        let mut log_det = 0.0_f64;
        for start in tile_starts(&row_base_of_tile) {
            let slot = &slots[start];
            let partial_schur = slot
                .tile_partial_schur
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let partial_rhs = slot
                .tile_partial_rhs
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            for idx in 0..k * k {
                schur_host[idx] += partial_schur[idx];
            }
            for a in 0..k {
                rhs_host[a] += partial_rhs[a];
            }
        }
        for slot in &slots {
            log_det += slot.log_det_local;
        }

        // Reduced solve on the primary device.
        let primary = runtime.selected_device().ordinal;
        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut schur_dev = stream
            .clone_htod(&schur_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_dev = stream
            .clone_htod(&rhs_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
            });
        }
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
        let delta_beta_host = stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let delta_beta = Array1::from_vec(delta_beta_host.clone());
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;

        // Back-substitution: every tile needs the same host copy of δβ.
        let delta_beta_ref = &delta_beta_host;
        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
        });
        if back_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let mut delta_t = Array1::<f64>::zeros(n * d);
        for (i, slot) in slots.iter().enumerate() {
            let base = i * d;
            for r in 0..d {
                delta_t[base + r] = slot.delta_t_block[r];
            }
        }

        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: log_det,
        })
    }
1300
1301 fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1304 tiles.iter().map(|(_, range)| range.start)
1305 }
1306
1307 fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1315 if tile.is_empty() {
1316 return Some(());
1317 }
1318 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1321 .and_then(|ctx| ctx.new_stream().ok())?;
1322 let solver = DnHandle::new(stream.clone()).ok()?;
1323 let blas = CudaBlas::new(stream.clone()).ok()?;
1324 let m = tile.len();
1325
1326 let mut d_host = Vec::with_capacity(m * d * d);
1329 let mut b_host = Vec::with_capacity(m * d * k);
1330 let mut g_host = Vec::with_capacity(m * d);
1331 for slot in tile.iter() {
1332 d_host.extend_from_slice(&slot.d_block);
1333 b_host.extend_from_slice(&slot.b_block);
1334 g_host.extend_from_slice(&slot.g_vec);
1335 }
1336 let mut d_dev = stream.clone_htod(&d_host).ok()?;
1337 let mut b_dev = stream.clone_htod(&b_host).ok()?;
1338 let mut g_dev = stream.clone_htod(&g_host).ok()?;
1339
1340 let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1346 if let Some(local) = info_host.iter().position(|info| *info != 0) {
1347 tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
1348 &tile[local].d_block,
1349 d,
1350 ));
1351 return Some(());
1352 }
1353
1354 trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1356 trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1357
1358 let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1360 let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1361 accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1362
1363 let l_host = stream.clone_dtoh(&d_dev).ok()?;
1365 let u_host = stream.clone_dtoh(&g_dev).ok()?;
1366 let y_host = stream.clone_dtoh(&b_dev).ok()?;
1367 let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1368 let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1369
1370 for (local, slot) in tile.iter_mut().enumerate() {
1371 let l_base = local * d * d;
1372 let u_base = local * d;
1373 let y_base = local * d * k;
1374 slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1375 slot.u_vec = u_host[u_base..u_base + d].to_vec();
1376 slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1377 let mut log_det_local = 0.0_f64;
1378 for j in 0..d {
1379 log_det_local += l_host[l_base + j * d + j].ln();
1380 }
1381 slot.log_det_local = log_det_local;
1382 }
1383 tile[0].tile_partial_schur = Some(partial_schur);
1384 tile[0].tile_partial_rhs = Some(partial_rhs);
1385 Some(())
1386 }
1387
1388 fn back_sub_tile(
1392 ordinal: usize,
1393 d: usize,
1394 k: usize,
1395 delta_beta: &[f64],
1396 tile: &mut [RowSlot],
1397 ) -> Option<()> {
1398 if tile.is_empty() {
1399 return Some(());
1400 }
1401 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1404 .and_then(|ctx| ctx.new_stream().ok())?;
1405 let blas = CudaBlas::new(stream.clone()).ok()?;
1406 let m = tile.len();
1407
1408 let mut l_host = Vec::with_capacity(m * d * d);
1409 let mut u_host = Vec::with_capacity(m * d);
1410 let mut y_host = Vec::with_capacity(m * d * k);
1411 for slot in tile.iter() {
1412 l_host.extend_from_slice(&slot.l_block);
1413 u_host.extend_from_slice(&slot.u_vec);
1414 y_host.extend_from_slice(&slot.y_block);
1415 }
1416 let d_dev = stream.clone_htod(&l_host).ok()?;
1417 let mut g_dev = stream.clone_htod(&u_host).ok()?;
1418 let b_dev = stream.clone_htod(&y_host).ok()?;
1419 let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1420
1421 accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1423 trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1424 let x_host = stream.clone_dtoh(&g_dev).ok()?;
1425 for (local, slot) in tile.iter_mut().enumerate() {
1426 let base = local * d;
1427 for r in 0..d {
1428 slot.delta_t_block[r] = -x_host[base + r];
1429 }
1430 }
1431 Some(())
1432 }
1433
1434 pub(super) fn solve(
1435 sys: &ArrowSchurSystem,
1436 ridge_t: f64,
1437 ridge_beta: f64,
1438 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1439 let n = sys.rows.len();
1440 let d = sys.d;
1441 let k = sys.k;
1442 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1443 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1444
1445 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1446 .and_then(|ctx| ctx.new_stream().ok())
1447 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1448 let solver =
1449 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1450 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1451
1452 let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1454 let mut d_dev = stream
1455 .clone_htod(&d_host)
1456 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1457 let mut b_dev = stream
1458 .clone_htod(&b_host)
1459 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1460 let mut g_dev = stream
1461 .clone_htod(&g_host)
1462 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1463
1464 let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1472 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1473 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1477 row: idx,
1478 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1479 });
1480 }
1481
1482 trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1485 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1488
1489 let schur_init: Vec<f64> = {
1508 let mut tmp = Vec::with_capacity(k * k);
1509 for col in 0..k {
1510 for row in 0..k {
1511 let mut v = sys.hbb[[row, col]];
1512 if row == col {
1513 v += ridge_beta;
1514 }
1515 tmp.push(v);
1516 }
1517 }
1518 tmp
1519 };
1520 let mut schur_dev = stream
1521 .clone_htod(&schur_init)
1522 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1523 let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1524 let mut rhs_dev = stream
1525 .clone_htod(&rhs_init)
1526 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1527
1528 accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1529
1530 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1532 if info != 0 {
1533 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1534 reason: format!("Schur Cholesky failed at pivot {info}"),
1535 });
1536 }
1537 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1539 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1540 let delta_beta_host = stream
1541 .clone_dtoh(&rhs_dev)
1542 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1543 let delta_beta = Array1::from_vec(delta_beta_host.clone());
1544
1545 accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1553 trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1554
1555 let x_host = stream
1556 .clone_dtoh(&g_dev)
1557 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1558 let mut delta_t = Array1::<f64>::zeros(n * d);
1559 for (i, v) in x_host.iter().enumerate() {
1560 delta_t[i] = -*v;
1561 }
1562
1563 let l_local_host = stream
1565 .clone_dtoh(&d_dev)
1566 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1567 let l_schur_host = stream
1568 .clone_dtoh(&schur_dev)
1569 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1570 let mut log_det = 0.0_f64;
1571 for i in 0..n {
1572 let base = i * d * d;
1573 for j in 0..d {
1574 log_det += l_local_host[base + j * d + j].ln();
1575 }
1576 }
1577 for j in 0..k {
1578 log_det += l_schur_host[j * k + j].ln();
1579 }
1580 log_det *= 2.0;
1581
1582 Ok(ArrowSchurGpuSolution {
1583 delta_t,
1584 delta_beta,
1585 log_det_hessian: log_det,
1586 })
1587 }
1588
1589 fn potrf_batched(
1590 solver: &DnHandle,
1591 stream: &Arc<CudaStream>,
1592 p: usize,
1593 batch: usize,
1594 matrices: &mut CudaSlice<f64>,
1595 ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1596 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1597 let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1598 let matrix_len = p * p;
1599 let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1600 let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1601 let mut ptrs = Vec::with_capacity(batch);
1602 for idx in 0..batch {
1603 ptrs.push(base_ptr + (idx as u64) * bytes_per);
1604 }
1605 let mut ptrs_dev = stream
1606 .clone_htod(&ptrs)
1607 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1608 let mut info_dev = stream
1609 .alloc_zeros::<i32>(batch)
1610 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1611 let status = {
1612 let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1613 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1614 unsafe {
1617 cusolver_sys::cusolverDnDpotrfBatched(
1618 solver.cu(),
1619 cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1620 p_i,
1621 ptrs_ptr as *mut *mut f64,
1622 p_i,
1623 info_ptr as *mut i32,
1624 batch_i,
1625 )
1626 }
1627 };
1628 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1629 return Err(ArrowSchurGpuFailure::Unavailable);
1630 }
1631 stream
1632 .clone_dtoh(&info_dev)
1633 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1634 }
1635
1636 fn potrf_single(
1637 solver: &DnHandle,
1638 stream: &Arc<CudaStream>,
1639 p: usize,
1640 matrix: &mut CudaSlice<f64>,
1641 ) -> Result<i32, ArrowSchurGpuFailure> {
1642 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1643 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1644 let mut lwork = 0_i32;
1645 {
1646 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1647 let status = unsafe {
1649 cusolver_sys::cusolverDnDpotrf_bufferSize(
1650 solver.cu(),
1651 uplo,
1652 p_i,
1653 mat_ptr as *mut f64,
1654 p_i,
1655 &mut lwork,
1656 )
1657 };
1658 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1659 return Err(ArrowSchurGpuFailure::Unavailable);
1660 }
1661 }
1662 let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1663 let mut workspace = stream
1664 .alloc_zeros::<f64>(lwork_usize.max(1))
1665 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1666 let mut info_dev = stream
1667 .alloc_zeros::<i32>(1)
1668 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1669 {
1670 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1671 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1672 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1673 let status = unsafe {
1675 cusolver_sys::cusolverDnDpotrf(
1676 solver.cu(),
1677 uplo,
1678 p_i,
1679 mat_ptr as *mut f64,
1680 p_i,
1681 work_ptr as *mut f64,
1682 lwork,
1683 info_ptr as *mut i32,
1684 )
1685 };
1686 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1687 return Err(ArrowSchurGpuFailure::Unavailable);
1688 }
1689 }
1690 let info_host = stream
1691 .clone_dtoh(&info_dev)
1692 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1693 Ok(info_host[0])
1694 }
1695
1696 fn trsm_batched_lower_inplace(
1700 blas: &CudaBlas,
1701 stream: &Arc<CudaStream>,
1702 d: usize,
1703 n: usize,
1704 nrhs: usize,
1705 l_stack: &CudaSlice<f64>,
1706 rhs_stack: &mut CudaSlice<f64>,
1707 ) -> Result<(), ArrowSchurGpuFailure> {
1708 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1709 }
1710
1711 fn trsm_batched_lower_inplace_transposed(
1713 blas: &CudaBlas,
1714 stream: &Arc<CudaStream>,
1715 d: usize,
1716 n: usize,
1717 nrhs: usize,
1718 l_stack: &CudaSlice<f64>,
1719 rhs_stack: &mut CudaSlice<f64>,
1720 ) -> Result<(), ArrowSchurGpuFailure> {
1721 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1722 }
1723
1724 fn trsm_batched_inplace_inner(
1725 blas: &CudaBlas,
1726 stream: &Arc<CudaStream>,
1727 d: usize,
1728 n: usize,
1729 nrhs: usize,
1730 l_stack: &CudaSlice<f64>,
1731 rhs_stack: &mut CudaSlice<f64>,
1732 transposed: bool,
1733 ) -> Result<(), ArrowSchurGpuFailure> {
1734 let alpha = 1.0_f64;
1735 let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1736 let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1737 let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1738 let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1739 let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1740 let (l_base, _l_record) = l_stack.device_ptr(stream);
1741 let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1742 let mut l_ptrs = Vec::with_capacity(n);
1743 let mut rhs_ptrs = Vec::with_capacity(n);
1744 for i in 0..n {
1745 l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1746 rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1747 }
1748 let mut l_ptrs_dev = stream
1749 .clone_htod(&l_ptrs)
1750 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1751 let mut rhs_ptrs_dev = stream
1752 .clone_htod(&rhs_ptrs)
1753 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1754 let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1755 let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1756 let op = if transposed {
1757 cublasOperation_t::CUBLAS_OP_T
1758 } else {
1759 cublasOperation_t::CUBLAS_OP_N
1760 };
1761 let handle = *blas.handle();
1762 let status = unsafe {
1765 cudarc::cublas::sys::cublasDtrsmBatched(
1766 handle,
1767 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1768 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1769 op,
1770 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1771 d_i,
1772 nrhs_i,
1773 &alpha,
1774 l_ptrs_ptr as *const *const f64,
1775 d_i,
1776 rhs_ptrs_ptr as *const *mut f64,
1777 d_i,
1778 batch_i,
1779 )
1780 };
1781 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1782 return Err(ArrowSchurGpuFailure::Unavailable);
1783 }
1784 Ok(())
1785 }
1786
1787 fn trsm_single(
1790 blas: &CudaBlas,
1791 stream: &Arc<CudaStream>,
1792 n: usize,
1793 l: &CudaSlice<f64>,
1794 rhs: &mut CudaSlice<f64>,
1795 upper: bool,
1796 transposed: bool,
1797 ) -> Result<(), ArrowSchurGpuFailure> {
1798 let alpha = 1.0_f64;
1799 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1800 let handle = *blas.handle();
1801 let (l_ptr, _l_rec) = l.device_ptr(stream);
1802 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1803 let status = unsafe {
1805 cudarc::cublas::sys::cublasDtrsm_v2(
1806 handle,
1807 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1808 if upper {
1809 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1810 } else {
1811 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1812 },
1813 if transposed {
1814 cublasOperation_t::CUBLAS_OP_T
1815 } else {
1816 cublasOperation_t::CUBLAS_OP_N
1817 },
1818 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1819 n_i,
1820 1,
1821 &alpha,
1822 l_ptr as *const f64,
1823 n_i,
1824 rhs_ptr as *mut f64,
1825 n_i,
1826 )
1827 };
1828 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1829 return Err(ArrowSchurGpuFailure::Unavailable);
1830 }
1831 Ok(())
1832 }
1833
1834 fn accumulate_schur(
1838 blas: &CudaBlas,
1839 d: usize,
1840 k: usize,
1841 n: usize,
1842 y_stack: &CudaSlice<f64>,
1843 u_stack: &CudaSlice<f64>,
1844 schur: &mut CudaSlice<f64>,
1845 rhs: &mut CudaSlice<f64>,
1846 ) -> Result<(), ArrowSchurGpuFailure> {
1847 let y_block_elems = d * k;
1848 let u_block_elems = d;
1849 for i in 0..n {
1850 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1851 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1852 let gemm_cfg = GemmConfig::<f64> {
1854 transa: cublasOperation_t::CUBLAS_OP_T,
1855 transb: cublasOperation_t::CUBLAS_OP_N,
1856 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1857 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1858 k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1859 alpha: -1.0,
1860 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1861 ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1862 beta: 1.0,
1863 ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1864 };
1865 unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1867 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1868 let gemv_cfg = GemvConfig::<f64> {
1870 trans: cublasOperation_t::CUBLAS_OP_T,
1871 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1872 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1873 alpha: 1.0,
1874 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1875 incx: 1,
1876 beta: 1.0,
1877 incy: 1,
1878 };
1879 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1882 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1883 }
1884 Ok(())
1885 }
1886
1887 fn accumulate_schur_rhs_only(
1895 blas: &CudaBlas,
1896 d: usize,
1897 k: usize,
1898 n: usize,
1899 y_stack: &CudaSlice<f64>,
1900 u_stack: &CudaSlice<f64>,
1901 rhs: &mut CudaSlice<f64>,
1902 ) -> Result<(), ArrowSchurGpuFailure> {
1903 let y_block_elems = d * k;
1904 let u_block_elems = d;
1905 for i in 0..n {
1906 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1907 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1908 let gemv_cfg = GemvConfig::<f64> {
1909 trans: cublasOperation_t::CUBLAS_OP_T,
1910 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1911 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1912 alpha: 1.0,
1913 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1914 incx: 1,
1915 beta: 1.0,
1916 incy: 1,
1917 };
1918 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1921 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1922 }
1923 Ok(())
1924 }
1925
1926 fn accumulate_back_sub_rhs(
1929 blas: &CudaBlas,
1930 d: usize,
1931 k: usize,
1932 n: usize,
1933 y_stack: &CudaSlice<f64>,
1934 delta_beta: &CudaSlice<f64>,
1935 u_stack: &mut CudaSlice<f64>,
1936 ) -> Result<(), ArrowSchurGpuFailure> {
1937 let y_block_elems = d * k;
1938 let u_block_elems = d;
1939 for i in 0..n {
1940 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1941 let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1942 let gemv_cfg = GemvConfig::<f64> {
1943 trans: cublasOperation_t::CUBLAS_OP_N,
1944 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1945 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1946 alpha: 1.0,
1947 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1948 incx: 1,
1949 beta: 1.0,
1950 incy: 1,
1951 };
1952 unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1955 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1956 }
1957 Ok(())
1958 }
1959
1960 use std::collections::HashMap;
1976 use std::sync::Mutex;
1977
1978 struct FusedModuleCache {
1983 modules: Mutex<
1984 HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
1985 >,
1986 }
1987
1988 fn fused_module_cache() -> &'static FusedModuleCache {
1989 static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1990 CACHE.get_or_init(|| FusedModuleCache {
1991 modules: Mutex::new(HashMap::new()),
1992 })
1993 }
1994
1995 fn fused_module_for(
1996 ctx: &Arc<CudaContext>,
1997 key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
1998 ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
1999 let cache = fused_module_cache();
2000 if let Ok(guard) = cache.modules.lock() {
2001 if let Some(existing) = guard.get(&key) {
2002 return Ok(existing.clone());
2003 }
2004 }
2005 let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2006 key.p_max as usize,
2007 key.r_template as usize,
2008 );
2009 let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
2010 ArrowSchurGpuFailure::SchurFactorFailed {
2011 reason: format!(
2012 "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2013 key.p_max, key.r_template
2014 ),
2015 }
2016 })?;
2017 let module = ctx
2018 .load_module(ptx)
2019 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2020 if let Ok(mut guard) = cache.modules.lock() {
2021 guard.entry(key).or_insert_with(|| module.clone());
2022 }
2023 Ok(module)
2024 }
2025
/// CUDA C source for the device-side PCG helpers, compiled through NVRTC by
/// `pcg_vector_module` under the module name `"arrow_pcg_vector"`.
///
/// It contains:
/// - the Jacobi preconditioner multiply (`arrow_pcg_jacobi_mul`);
/// - the search-direction update (`arrow_pcg_update_p`);
/// - the matrix-free SAE Schur matvec pieces (`arrow_sae_*`), including the
///   frames-engaged `arrow_sae_frame_*` variants and their Jacobi-diagonal
///   counterparts.
///
/// The per-row `ainv` blocks are addressed as `row * max_q * max_q + ...`. That
/// base offset is computed in 64-bit (`long long`) because `n_rows * max_q²` can
/// exceed `INT_MAX` on large problems. A 32-bit product would then overflow,
/// which is undefined behaviour in CUDA C++ and in practice an out-of-bounds read.
const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    long long base = (long long)row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    long long abase = (long long)row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri; // basis row in atom i
    int a = linear - li * ri; // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 * h_i = H_tβ^(i) · x (length q_i)
 * s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i (apply cached ainv, length q_i)
 * out -= (H_tβ^(i))ᵀ · s_i (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    long long abase = (long long)row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    long long abase = (long long)row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2458
2459 fn pcg_vector_module(
2460 ctx: &Arc<CudaContext>,
2461 ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2462 static CACHE: gam_gpu::device_cache::PtxModuleCache =
2463 gam_gpu::device_cache::PtxModuleCache::new();
2464 CACHE
2465 .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2466 .map_err(|err| {
2467 log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2473 ArrowSchurGpuFailure::Unavailable
2474 })
2475 }
2476
2477 fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2478 let threads = 256u32;
2479 let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2480 Ok(LaunchConfig {
2481 grid_dim: (blocks, 1, 1),
2482 block_dim: (threads, 1, 1),
2483 shared_mem_bytes: 0,
2484 })
2485 }
2486
2487 fn launch_jacobi_mul(
2488 stream: &Arc<CudaStream>,
2489 module: &Arc<CudaModule>,
2490 inv_diag: &CudaSlice<f64>,
2491 r: &CudaSlice<f64>,
2492 z: &mut CudaSlice<f64>,
2493 n: usize,
2494 ) -> Result<(), ArrowSchurGpuFailure> {
2495 let kernel = module
2496 .load_function("arrow_pcg_jacobi_mul")
2497 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2498 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2499 let mut builder = stream.launch_builder(&kernel);
2500 builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2501 unsafe { builder.launch(pcg_launch_config(n)?) }
2504 .map(drop)
2505 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2506 }
2507
2508 fn launch_update_p(
2509 stream: &Arc<CudaStream>,
2510 module: &Arc<CudaModule>,
2511 z: &CudaSlice<f64>,
2512 beta: f64,
2513 p: &mut CudaSlice<f64>,
2514 n: usize,
2515 ) -> Result<(), ArrowSchurGpuFailure> {
2516 let kernel = module
2517 .load_function("arrow_pcg_update_p")
2518 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2519 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2520 let mut builder = stream.launch_builder(&kernel);
2521 builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2522 unsafe { builder.launch(pcg_launch_config(n)?) }
2525 .map(drop)
2526 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2527 }
2528
2529 struct DeviceSaePcgBuffers {
2530 row_ptr: CudaSlice<i32>,
2531 beta_base: CudaSlice<i32>,
2532 phi: CudaSlice<f64>,
2533 jac_ptr: CudaSlice<i32>,
2534 jac: CudaSlice<f64>,
2535 smooth_offsets: CudaSlice<i32>,
2536 smooth_m: CudaSlice<i32>,
2537 smooth_ptr: CudaSlice<i32>,
2538 smooth_data: CudaSlice<f64>,
2539 g_row_off: CudaSlice<i32>,
2540 g_col_off: CudaSlice<i32>,
2541 g_rows: CudaSlice<i32>,
2542 g_cols: CudaSlice<i32>,
2543 g_ptr: CudaSlice<i32>,
2544 g_data: CudaSlice<f64>,
2545 ainv: CudaSlice<f64>,
2546 u: CudaSlice<f64>,
2547 w: CudaSlice<f64>,
2548 v: CudaSlice<f64>,
2549 n_rows: usize,
2550 p: usize,
2551 k: usize,
2552 max_q: usize,
2553 smooth_blocks: usize,
2554 g_blocks: usize,
2555 }
2556
2557 fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2558 to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2559 }
2560
2561 fn sae_penalty_diag_host(
2562 data: &DeviceSaePcgData,
2563 ridge_beta: f64,
2564 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2565 let mut diag = vec![ridge_beta; data.beta_dim];
2566 for block in &data.smooth_blocks {
2567 let (rows, cols) = block.factor_a.dim();
2568 if rows != cols {
2569 return Err(ArrowSchurGpuFailure::Unavailable);
2570 }
2571 for row in 0..rows {
2572 let coeff = block.factor_a[[row, row]];
2573 let base = block
2574 .global_offset
2575 .checked_add(
2576 row.checked_mul(data.p)
2577 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2578 )
2579 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2580 let end = base
2581 .checked_add(data.p)
2582 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2583 if end > diag.len() {
2584 return Err(ArrowSchurGpuFailure::Unavailable);
2585 }
2586 for channel in 0..data.p {
2587 diag[base + channel] += coeff;
2588 }
2589 }
2590 }
2591 for block in &data.sparse_g_blocks {
2592 if block.row_off != block.col_off {
2593 continue;
2594 }
2595 let (rows, cols) = block.data.dim();
2596 for row in 0..rows.min(cols) {
2597 let coeff = block.data[[row, row]];
2598 let beta_row = block
2599 .row_off
2600 .checked_add(row)
2601 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2602 let base = beta_row
2603 .checked_mul(data.p)
2604 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2605 let end = base
2606 .checked_add(data.p)
2607 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2608 if end > diag.len() {
2609 return Err(ArrowSchurGpuFailure::Unavailable);
2610 }
2611 for channel in 0..data.p {
2612 diag[base + channel] += coeff;
2613 }
2614 }
2615 }
2616 Ok(diag)
2617 }
2618
2619 fn flatten_device_sae_data(
2620 sys: &ArrowSchurSystem,
2621 data: &DeviceSaePcgData,
2622 ridge_t: f64,
2623 stream: &Arc<CudaStream>,
2624 ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2625 let n_rows = sys.rows.len();
2626 let p = data.p;
2627 let k = data.beta_dim;
2628 if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2629 return Err(ArrowSchurGpuFailure::Unavailable);
2630 }
2631
2632 let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2633 let mut beta_base_host = Vec::<i32>::new();
2634 let mut phi_host = Vec::<f64>::new();
2635 row_ptr_host.push(0_i32);
2636 for row in data.a_phi.iter() {
2637 for &(base, phi) in row {
2638 beta_base_host.push(checked_i32(base)?);
2639 phi_host.push(phi);
2640 }
2641 row_ptr_host.push(checked_i32(beta_base_host.len())?);
2642 }
2643
2644 let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2645 let mut jac_host = Vec::<f64>::new();
2646 let mut max_q = 0usize;
2647 jac_ptr_host.push(0_i32);
2648 for row_jac in data.local_jac.iter() {
2649 if row_jac.len() % p != 0 {
2650 return Err(ArrowSchurGpuFailure::Unavailable);
2651 }
2652 max_q = max_q.max(row_jac.len() / p);
2653 jac_host.extend_from_slice(row_jac);
2654 jac_ptr_host.push(checked_i32(jac_host.len())?);
2655 }
2656 if max_q == 0 {
2657 return Err(ArrowSchurGpuFailure::Unavailable);
2658 }
2659
2660 let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2661 let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2662 let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2663 let mut smooth_data_host = Vec::<f64>::new();
2664 smooth_ptr_host.push(0_i32);
2665 for block in &data.smooth_blocks {
2666 let (rows, cols) = block.factor_a.dim();
2667 if rows != cols {
2668 return Err(ArrowSchurGpuFailure::Unavailable);
2669 }
2670 smooth_offsets_host.push(checked_i32(block.global_offset)?);
2671 smooth_m_host.push(checked_i32(rows)?);
2672 for r in 0..rows {
2673 for c in 0..cols {
2674 smooth_data_host.push(block.factor_a[[r, c]]);
2675 }
2676 }
2677 smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2678 }
2679
2680 let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2681 let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2682 let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2683 let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2684 let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2685 let mut g_data_host = Vec::<f64>::new();
2686 g_ptr_host.push(0_i32);
2687 for block in &data.sparse_g_blocks {
2688 let (rows, cols) = block.data.dim();
2689 g_row_off_host.push(checked_i32(block.row_off)?);
2690 g_col_off_host.push(checked_i32(block.col_off)?);
2691 g_rows_host.push(checked_i32(rows)?);
2692 g_cols_host.push(checked_i32(cols)?);
2693 for r in 0..rows {
2694 for c in 0..cols {
2695 g_data_host.push(block.data[[r, c]]);
2696 }
2697 }
2698 g_ptr_host.push(checked_i32(g_data_host.len())?);
2699 }
2700
2701 let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2702 for (row_idx, row) in sys.rows.iter().enumerate() {
2703 let q = data.local_jac[row_idx].len() / p;
2704 if row.htt.dim() != (q, q) {
2705 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2706 reason: format!(
2707 "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2708 row.htt.dim()
2709 ),
2710 });
2711 }
2712 let mut block = row.htt.clone();
2713 for d in 0..q {
2714 block[[d, d]] += ridge_t;
2715 }
2716 let factor = gam_linalg::triangular::cholesky_factor_in_place(
2717 block.view(),
2718 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2719 )
2720 .ok_or_else(|| {
2721 ArrowSchurGpuFailure::RidgeBumpRequired {
2724 row: row_idx,
2725 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2726 }
2727 })?;
2728 for col in 0..q {
2729 let mut e = Array1::<f64>::zeros(q);
2730 e[col] = 1.0;
2731 let solved =
2732 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2733 for r in 0..q {
2734 ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2735 }
2736 }
2737 }
2738
2739 Ok(DeviceSaePcgBuffers {
2740 row_ptr: stream
2741 .clone_htod(&row_ptr_host)
2742 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2743 beta_base: stream
2744 .clone_htod(&beta_base_host)
2745 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2746 phi: stream
2747 .clone_htod(&phi_host)
2748 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2749 jac_ptr: stream
2750 .clone_htod(&jac_ptr_host)
2751 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2752 jac: stream
2753 .clone_htod(&jac_host)
2754 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2755 smooth_offsets: stream
2756 .clone_htod(&smooth_offsets_host)
2757 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2758 smooth_m: stream
2759 .clone_htod(&smooth_m_host)
2760 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2761 smooth_ptr: stream
2762 .clone_htod(&smooth_ptr_host)
2763 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2764 smooth_data: stream
2765 .clone_htod(&smooth_data_host)
2766 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2767 g_row_off: stream
2768 .clone_htod(&g_row_off_host)
2769 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2770 g_col_off: stream
2771 .clone_htod(&g_col_off_host)
2772 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2773 g_rows: stream
2774 .clone_htod(&g_rows_host)
2775 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2776 g_cols: stream
2777 .clone_htod(&g_cols_host)
2778 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2779 g_ptr: stream
2780 .clone_htod(&g_ptr_host)
2781 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2782 g_data: stream
2783 .clone_htod(&g_data_host)
2784 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2785 ainv: stream
2786 .clone_htod(&ainv_host)
2787 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2788 u: stream
2789 .alloc_zeros::<f64>(n_rows * p)
2790 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2791 w: stream
2792 .alloc_zeros::<f64>(n_rows * max_q)
2793 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2794 v: stream
2795 .alloc_zeros::<f64>(n_rows * max_q)
2796 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2797 n_rows,
2798 p,
2799 k,
2800 max_q,
2801 smooth_blocks: data.smooth_blocks.len(),
2802 g_blocks: data.sparse_g_blocks.len(),
2803 })
2804 }
2805
2806 fn launch_sae_init(
2807 stream: &Arc<CudaStream>,
2808 module: &Arc<CudaModule>,
2809 out: &mut CudaSlice<f64>,
2810 x: &CudaSlice<f64>,
2811 ridge: f64,
2812 n: usize,
2813 ) -> Result<(), ArrowSchurGpuFailure> {
2814 let kernel = module
2815 .load_function("arrow_sae_init")
2816 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2817 let n_i32 = checked_i32(n)?;
2818 let mut builder = stream.launch_builder(&kernel);
2819 builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2820 unsafe { builder.launch(pcg_launch_config(n)?) }
2824 .map(drop)
2825 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2826 }
2827
2828 fn launch_sae_penalty_matvec(
2829 stream: &Arc<CudaStream>,
2830 module: &Arc<CudaModule>,
2831 buffers: &mut DeviceSaePcgBuffers,
2832 x: &CudaSlice<f64>,
2833 out: &mut CudaSlice<f64>,
2834 ridge_beta: f64,
2835 ) -> Result<(), ArrowSchurGpuFailure> {
2836 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2837 if buffers.smooth_blocks > 0 {
2838 let kernel = module
2839 .load_function("arrow_sae_smooth_matvec")
2840 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2841 let max_m = buffers.k;
2842 let p_i32 = checked_i32(buffers.p)?;
2843 let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2844 let cfg = LaunchConfig {
2845 grid_dim: (
2846 ((max_m as u32).saturating_add(255) / 256).max(1),
2847 checked_i32(buffers.smooth_blocks)? as u32,
2848 1,
2849 ),
2850 block_dim: (256, 1, 1),
2851 shared_mem_bytes: 0,
2852 };
2853 let mut builder = stream.launch_builder(&kernel);
2854 builder
2855 .arg(x)
2856 .arg(&mut *out)
2857 .arg(&buffers.smooth_offsets)
2858 .arg(&buffers.smooth_m)
2859 .arg(&buffers.smooth_ptr)
2860 .arg(&buffers.smooth_data)
2861 .arg(&p_i32)
2862 .arg(&blocks_i32);
2863 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2868 }
2869 if buffers.g_blocks > 0 {
2870 let kernel = module
2871 .load_function("arrow_sae_sparse_g_matvec")
2872 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2873 let max_work = buffers
2874 .k
2875 .checked_div(buffers.p)
2876 .unwrap_or(0)
2877 .saturating_mul(buffers.p);
2878 let p_i32 = checked_i32(buffers.p)?;
2879 let blocks_i32 = checked_i32(buffers.g_blocks)?;
2880 let cfg = LaunchConfig {
2881 grid_dim: (
2882 ((max_work as u32).saturating_add(255) / 256).max(1),
2883 checked_i32(buffers.g_blocks)? as u32,
2884 1,
2885 ),
2886 block_dim: (256, 1, 1),
2887 shared_mem_bytes: 0,
2888 };
2889 let mut builder = stream.launch_builder(&kernel);
2890 builder
2891 .arg(x)
2892 .arg(&mut *out)
2893 .arg(&buffers.g_row_off)
2894 .arg(&buffers.g_col_off)
2895 .arg(&buffers.g_rows)
2896 .arg(&buffers.g_cols)
2897 .arg(&buffers.g_ptr)
2898 .arg(&buffers.g_data)
2899 .arg(&p_i32)
2900 .arg(&blocks_i32);
2901 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2906 }
2907 Ok(())
2908 }
2909
2910 fn launch_sae_row_schur_sub(
2911 stream: &Arc<CudaStream>,
2912 module: &Arc<CudaModule>,
2913 buffers: &mut DeviceSaePcgBuffers,
2914 x: &CudaSlice<f64>,
2915 out: &mut CudaSlice<f64>,
2916 ) -> Result<(), ArrowSchurGpuFailure> {
2917 let p_i32 = checked_i32(buffers.p)?;
2918 let max_q_i32 = checked_i32(buffers.max_q)?;
2919 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2920 let cfg_p_rows = LaunchConfig {
2921 grid_dim: (
2922 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2923 checked_i32(buffers.n_rows)? as u32,
2924 1,
2925 ),
2926 block_dim: (256, 1, 1),
2927 shared_mem_bytes: 0,
2928 };
2929 let gather = module
2930 .load_function("arrow_sae_gather_u")
2931 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2932 {
2933 let mut builder = stream.launch_builder(&gather);
2934 builder
2935 .arg(x)
2936 .arg(&buffers.row_ptr)
2937 .arg(&buffers.beta_base)
2938 .arg(&buffers.phi)
2939 .arg(&mut buffers.u)
2940 .arg(&p_i32)
2941 .arg(&n_rows_i32);
2942 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2946 }
2947
2948 let cfg_q_rows = LaunchConfig {
2949 grid_dim: (
2950 ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2951 checked_i32(buffers.n_rows)? as u32,
2952 1,
2953 ),
2954 block_dim: (256, 1, 1),
2955 shared_mem_bytes: 0,
2956 };
2957 let apply_l = module
2958 .load_function("arrow_sae_apply_l")
2959 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2960 {
2961 let mut builder = stream.launch_builder(&apply_l);
2962 builder
2963 .arg(&buffers.u)
2964 .arg(&buffers.jac_ptr)
2965 .arg(&buffers.jac)
2966 .arg(&mut buffers.w)
2967 .arg(&p_i32)
2968 .arg(&max_q_i32)
2969 .arg(&n_rows_i32);
2970 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2974 }
2975
2976 let apply_ainv = module
2977 .load_function("arrow_sae_apply_ainv")
2978 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2979 {
2980 let mut builder = stream.launch_builder(&apply_ainv);
2981 builder
2982 .arg(&buffers.ainv)
2983 .arg(&buffers.w)
2984 .arg(&mut buffers.v)
2985 .arg(&max_q_i32)
2986 .arg(&n_rows_i32);
2987 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2991 }
2992
2993 let scatter = module
2994 .load_function("arrow_sae_scatter_sub")
2995 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2996 {
2997 let mut builder = stream.launch_builder(&scatter);
2998 builder
2999 .arg(&buffers.v)
3000 .arg(&buffers.jac_ptr)
3001 .arg(&buffers.jac)
3002 .arg(&buffers.row_ptr)
3003 .arg(&buffers.beta_base)
3004 .arg(&buffers.phi)
3005 .arg(out)
3006 .arg(&p_i32)
3007 .arg(&max_q_i32)
3008 .arg(&n_rows_i32);
3009 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3013 }
3014 Ok(())
3015 }
3016
3017 fn launch_sae_diag_sub(
3018 stream: &Arc<CudaStream>,
3019 module: &Arc<CudaModule>,
3020 buffers: &DeviceSaePcgBuffers,
3021 diag: &mut CudaSlice<f64>,
3022 ) -> Result<(), ArrowSchurGpuFailure> {
3023 let kernel = module
3024 .load_function("arrow_sae_diag_sub")
3025 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3026 let p_i32 = checked_i32(buffers.p)?;
3027 let max_q_i32 = checked_i32(buffers.max_q)?;
3028 let n_rows_i32 = checked_i32(buffers.n_rows)?;
3029 let cfg = LaunchConfig {
3030 grid_dim: (
3031 ((buffers.p as u32).saturating_add(255) / 256).max(1),
3032 checked_i32(buffers.n_rows)? as u32,
3033 1,
3034 ),
3035 block_dim: (256, 1, 1),
3036 shared_mem_bytes: 0,
3037 };
3038 let mut builder = stream.launch_builder(&kernel);
3039 builder
3040 .arg(diag)
3041 .arg(&buffers.ainv)
3042 .arg(&buffers.jac_ptr)
3043 .arg(&buffers.jac)
3044 .arg(&buffers.row_ptr)
3045 .arg(&buffers.beta_base)
3046 .arg(&buffers.phi)
3047 .arg(&p_i32)
3048 .arg(&max_q_i32)
3049 .arg(&n_rows_i32);
3050 unsafe { builder.launch(cfg) }
3054 .map(drop)
3055 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3056 }
3057
3058 fn launch_sae_matvec(
3059 stream: &Arc<CudaStream>,
3060 module: &Arc<CudaModule>,
3061 buffers: &mut DeviceSaePcgBuffers,
3062 x: &CudaSlice<f64>,
3063 out: &mut CudaSlice<f64>,
3064 ridge_beta: f64,
3065 ) -> Result<(), ArrowSchurGpuFailure> {
3066 launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3067 launch_sae_row_schur_sub(stream, module, buffers, x, out)
3068 }
3069
3070 fn pack_fused_host(
3075 sys: &ArrowSchurSystem,
3076 ridge_t: f64,
3077 p_max: usize,
3078 r_template: usize,
3079 ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3080 let n = sys.rows.len();
3081 let d = sys.d;
3082 let k = sys.k;
3083 let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3084 let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3085 let mut g_buf = vec![0.0_f64; n * p_max];
3086 for (i, row) in sys.rows.iter().enumerate() {
3087 for col in 0..d {
3089 let base = (i * p_max + col) * p_max;
3090 for r in 0..d {
3091 let mut value = row.htt[[r, col]];
3092 if r == col {
3093 value += ridge_t;
3094 }
3095 d_buf[base + r] = value;
3096 }
3097 }
3098 for col in 0..k {
3106 let base = (i * r_template + col) * p_max;
3107 for r in 0..d {
3108 b_buf[base + r] = row.htbeta[[r, col]];
3109 }
3110 }
3111 let g_base = i * p_max;
3113 for r in 0..d {
3114 g_buf[g_base + r] = row.gt[r];
3115 }
3116 }
3117 (d_buf, b_buf, g_buf)
3118 }
3119
/// GPU-resident factorization of a ridged arrow system. It is built once by
/// `ResidentArrowFrame::new` and reused for many right-hand sides through
/// `ResidentArrowFrame::solve_gradient`.
pub(super) struct ResidentArrowFrame {
    /// Number of per-row blocks (`sys.rows.len()`).
    n: usize,
    /// Size of each per-row block (`sys.d`).
    d: usize,
    /// Size of the shared β block (`sys.k`).
    k: usize,
    /// Stream on which every device operation of this frame is issued.
    stream: Arc<CudaStream>,
    /// cuBLAS handle created on `stream`. It is passed to the `trsm_*` /
    /// `accumulate_*` helpers.
    blas: CudaBlas,
    /// Batched lower Cholesky factors from `potrf_batched`, applied to the
    /// first `pack_host` buffer (presumably `H_tt + ridge_t·I`, `d × d` per
    /// row).
    l_dev: CudaSlice<f64>,
    /// The second `pack_host` buffer after `trsm_batched_lower_inplace` with
    /// `l_dev`, i.e. `L⁻¹ · B` per row (presumably `B = H_tβ`, `d × k`).
    y_dev: CudaSlice<f64>,
    /// `k × k` column-major Cholesky factor of the Schur matrix. That matrix
    /// is seeded with `H_ββ + ridge_beta·I`, and `accumulate_schur` then folds
    /// in the `y_dev` term (presumably `− Yᵀ Y`).
    schur_dev: CudaSlice<f64>,
    /// Log-determinant of the ridged Hessian: twice the sum of the
    /// log-diagonals of all local factors and of the Schur factor.
    log_det_hessian: f64,
}
3165
3166 impl ResidentArrowFrame {
3167 pub(super) fn new(
3171 sys: &ArrowSchurSystem,
3172 ridge_t: f64,
3173 ridge_beta: f64,
3174 ) -> Result<Self, ArrowSchurGpuFailure> {
3175 if ridge_t.is_nan() || ridge_beta.is_nan() {
3176 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3177 reason: "ridge is NaN".to_string(),
3178 });
3179 }
3180 let n = sys.rows.len();
3181 let d = sys.d;
3182 let k = sys.k;
3183 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3184 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3185 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3186 .and_then(|ctx| ctx.new_stream().ok())
3187 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3188 let solver =
3189 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3190 let blas =
3191 CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3192
3193 let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3195 let mut l_dev = stream
3196 .clone_htod(&d_host)
3197 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3198 let mut y_dev = stream
3199 .clone_htod(&b_host)
3200 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3201
3202 let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3204 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3205 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3209 row: idx,
3210 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3211 });
3212 }
3213
3214 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3216
3217 let schur_init: Vec<f64> = {
3222 let mut tmp = Vec::with_capacity(k * k);
3223 for col in 0..k {
3224 for row in 0..k {
3225 let mut v = sys.hbb[[row, col]];
3226 if row == col {
3227 v += ridge_beta;
3228 }
3229 tmp.push(v);
3230 }
3231 }
3232 tmp
3233 };
3234 let mut schur_dev = stream
3235 .clone_htod(&schur_init)
3236 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3237 let zero_u = stream
3240 .clone_htod(&vec![0.0_f64; n * d])
3241 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3242 let mut throwaway_rhs = stream
3243 .clone_htod(&vec![0.0_f64; k])
3244 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3245 accumulate_schur(
3246 &blas,
3247 d,
3248 k,
3249 n,
3250 &y_dev,
3251 &zero_u,
3252 &mut schur_dev,
3253 &mut throwaway_rhs,
3254 )?;
3255 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3256 if info != 0 {
3257 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3258 reason: format!("Schur Cholesky failed at pivot {info}"),
3259 });
3260 }
3261
3262 let l_local_host = stream
3264 .clone_dtoh(&l_dev)
3265 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3266 let l_schur_host = stream
3267 .clone_dtoh(&schur_dev)
3268 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3269 let mut log_det = 0.0_f64;
3270 for i in 0..n {
3271 let base = i * d * d;
3272 for j in 0..d {
3273 log_det += l_local_host[base + j * d + j].ln();
3274 }
3275 }
3276 for j in 0..k {
3277 log_det += l_schur_host[j * k + j].ln();
3278 }
3279 log_det *= 2.0;
3280
3281 Ok(Self {
3282 n,
3283 d,
3284 k,
3285 stream,
3286 blas,
3287 l_dev,
3288 y_dev,
3289 schur_dev,
3290 log_det_hessian: log_det,
3291 })
3292 }
3293
3294 #[inline]
3295 pub(super) fn log_det_hessian(&self) -> f64 {
3296 self.log_det_hessian
3297 }
3298
3299 pub(super) fn solve_gradient(
3303 &self,
3304 g_t: &[f64],
3305 g_beta: &[f64],
3306 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3307 let n = self.n;
3308 let d = self.d;
3309 let k = self.k;
3310 if g_t.len() != n * d || g_beta.len() != k {
3311 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3312 reason: format!(
3313 "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3314 g_t.len(),
3315 n * d,
3316 g_beta.len(),
3317 k
3318 ),
3319 });
3320 }
3321 let mut u_dev = self
3323 .stream
3324 .clone_htod(&g_t.to_vec())
3325 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3326 trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3327
3328 let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3331 let mut rhs_dev = self
3332 .stream
3333 .clone_htod(&rhs_init)
3334 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3335 accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3336
3337 trsm_single(
3339 &self.blas,
3340 &self.stream,
3341 k,
3342 &self.schur_dev,
3343 &mut rhs_dev,
3344 false,
3345 false,
3346 )?;
3347 trsm_single(
3348 &self.blas,
3349 &self.stream,
3350 k,
3351 &self.schur_dev,
3352 &mut rhs_dev,
3353 false,
3354 true,
3355 )?;
3356 let delta_beta_host = self
3357 .stream
3358 .clone_dtoh(&rhs_dev)
3359 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3360 let delta_beta = Array1::from_vec(delta_beta_host);
3361
3362 accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3364 trsm_batched_lower_inplace_transposed(
3365 &self.blas,
3366 &self.stream,
3367 d,
3368 n,
3369 1,
3370 &self.l_dev,
3371 &mut u_dev,
3372 )?;
3373 let x_host = self
3374 .stream
3375 .clone_dtoh(&u_dev)
3376 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3377 let mut delta_t = Array1::<f64>::zeros(n * d);
3378 for (i, v) in x_host.iter().enumerate() {
3379 delta_t[i] = -*v;
3380 }
3381
3382 Ok(ArrowSchurGpuSolution {
3383 delta_t,
3384 delta_beta,
3385 log_det_hessian: self.log_det_hessian,
3386 })
3387 }
3388 }
3389
3390 pub(super) fn solve_fused(
3391 sys: &ArrowSchurSystem,
3392 ridge_t: f64,
3393 ridge_beta: f64,
3394 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3395 let n = sys.rows.len();
3396 let d = sys.d;
3397 let k = sys.k;
3398 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3399 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3400 let p_max = plan.p_max;
3401 let r_template = plan.r_template;
3402
3403 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3404 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3405 )
3406 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3407 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3408 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3409 let stream = ctx
3410 .new_stream()
3411 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3412 let cap = &runtime.device.capability;
3413 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3414 cc_major: cap.compute_major,
3415 cc_minor: cap.compute_minor,
3416 p_max: p_max as u32,
3417 r_template: r_template as u32,
3418 };
3419 let module = fused_module_for(&ctx, key)?;
3420 let forward = module
3421 .load_function("arrow_schur_forward_pgroup")
3422 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3423 let back_sub = module
3424 .load_function("arrow_schur_back_sub_pgroup")
3425 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3426
3427 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3429 let d_dev = stream
3430 .clone_htod(&d_host)
3431 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3432 let b_dev = stream
3433 .clone_htod(&b_host)
3434 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3435 let g_dev = stream
3436 .clone_htod(&g_host)
3437 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3438 let mut l_out = stream
3439 .alloc_zeros::<f64>(n * p_max * p_max)
3440 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3441 let mut u_out = stream
3442 .alloc_zeros::<f64>(n * p_max)
3443 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3444 let mut y_out = stream
3445 .alloc_zeros::<f64>(n * p_max * r_template)
3446 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3447 let mut partial_s = stream
3448 .alloc_zeros::<f64>(plan.partial_s_doubles)
3449 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3450 let mut partial_r = stream
3451 .alloc_zeros::<f64>(plan.partial_r_doubles)
3452 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3453 let mut status_dev = stream
3454 .alloc_zeros::<i32>(n)
3455 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3456
3457 let cfg = LaunchConfig {
3459 grid_dim: (plan.blocks, 1, 1),
3460 block_dim: (plan.threads_per_block, 1, 1),
3461 shared_mem_bytes: 0,
3462 };
3463 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3464 let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3465 let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3466 let ridge_arg = ridge_t;
3467 {
3468 let mut builder = stream.launch_builder(&forward);
3469 builder
3470 .arg(&d_dev)
3471 .arg(&b_dev)
3472 .arg(&g_dev)
3473 .arg(&n_i32)
3474 .arg(&p_i32)
3475 .arg(&r_i32)
3476 .arg(&ridge_arg)
3477 .arg(&mut l_out)
3478 .arg(&mut u_out)
3479 .arg(&mut y_out)
3480 .arg(&mut partial_s)
3481 .arg(&mut partial_r)
3482 .arg(&mut status_dev);
3483 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3487 }
3488 stream
3489 .synchronize()
3490 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3491
3492 let status_host = stream
3494 .clone_dtoh(&status_dev)
3495 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3496 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3497 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3501 row,
3502 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3503 });
3504 }
3505
3506 let partial_s_host = stream
3508 .clone_dtoh(&partial_s)
3509 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3510 let partial_r_host = stream
3511 .clone_dtoh(&partial_r)
3512 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3513 let mut schur_host = vec![0.0_f64; k * k];
3514 for col in 0..k {
3515 for row in 0..k {
3516 let mut v = sys.hbb[[row, col]];
3517 if row == col {
3518 v += ridge_beta;
3519 }
3520 schur_host[col * k + row] = v;
3521 }
3522 }
3523 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3524 for i in 0..n {
3525 let s_base = i * r_template * r_template;
3528 for col in 0..k {
3529 let col_base = s_base + col * r_template;
3530 let dst_col_base = col * k;
3531 for row in 0..k {
3532 schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3533 }
3534 }
3535 let r_base = i * r_template;
3536 for a in 0..k {
3537 rhs_host[a] += partial_r_host[r_base + a];
3538 }
3539 }
3540
3541 let mut schur_dev = stream
3543 .clone_htod(&schur_host)
3544 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3545 let mut rhs_dev = stream
3546 .clone_htod(&rhs_host)
3547 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3548 let solver =
3549 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3550 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3551 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3552 if info != 0 {
3553 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3554 reason: format!("fused Schur Cholesky failed at pivot {info}"),
3555 });
3556 }
3557 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3558 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3559 let delta_beta_host = stream
3560 .clone_dtoh(&rhs_dev)
3561 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3562 let delta_beta = Array1::from_vec(delta_beta_host.clone());
3563
3564 let mut delta_t_dev = stream
3566 .alloc_zeros::<f64>(n * p_max)
3567 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3568 let back_cfg = LaunchConfig {
3569 grid_dim: (plan.blocks, 1, 1),
3570 block_dim: (plan.threads_per_block, 1, 1),
3571 shared_mem_bytes: 0,
3572 };
3573 {
3574 let mut builder = stream.launch_builder(&back_sub);
3575 builder
3576 .arg(&l_out)
3577 .arg(&u_out)
3578 .arg(&y_out)
3579 .arg(&rhs_dev)
3580 .arg(&n_i32)
3581 .arg(&p_i32)
3582 .arg(&r_i32)
3583 .arg(&mut delta_t_dev);
3584 unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3588 }
3589 stream
3590 .synchronize()
3591 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3592
3593 let delta_t_host = stream
3594 .clone_dtoh(&delta_t_dev)
3595 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3596 let mut delta_t = Array1::<f64>::zeros(n * d);
3597 for i in 0..n {
3598 let src_base = i * p_max;
3599 let dst_base = i * d;
3600 for r in 0..d {
3601 delta_t[dst_base + r] = delta_t_host[src_base + r];
3602 }
3603 }
3604
3605 let l_local_host = stream
3607 .clone_dtoh(&l_out)
3608 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3609 let l_schur_host = stream
3610 .clone_dtoh(&schur_dev)
3611 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3612 let mut log_det = 0.0_f64;
3613 for i in 0..n {
3614 let base = i * p_max * p_max;
3615 for j in 0..d {
3616 log_det += l_local_host[base + j * p_max + j].ln();
3617 }
3618 }
3619 for j in 0..k {
3620 log_det += l_schur_host[j * k + j].ln();
3621 }
3622 log_det *= 2.0;
3623
3624 Ok(ArrowSchurGpuSolution {
3625 delta_t,
3626 delta_beta,
3627 log_det_hessian: log_det,
3628 })
3629 }
3630
3631 pub(super) fn build_schur_matvec_backend(
3641 sys: &ArrowSchurSystem,
3642 ridge_t: f64,
3643 ridge_beta: f64,
3644 ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
3645 let n = sys.rows.len();
3646 let d = sys.d;
3647 let k = sys.k;
3648 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3649 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3650 let p_max = plan.p_max;
3651 let r_template = plan.r_template;
3652
3653 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3654 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3655 )
3656 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3657 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3658 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3659 let stream = ctx
3660 .new_stream()
3661 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3662 let cap = &runtime.device.capability;
3663 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3664 cc_major: cap.compute_major,
3665 cc_minor: cap.compute_minor,
3666 p_max: p_max as u32,
3667 r_template: r_template as u32,
3668 };
3669 let module = fused_module_for(&ctx, key)?;
3670 let forward = module
3671 .load_function("arrow_schur_forward_pgroup")
3672 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3673
3674 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3675 let d_dev = stream
3676 .clone_htod(&d_host)
3677 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3678 let b_dev = stream
3679 .clone_htod(&b_host)
3680 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3681 let g_dev = stream
3682 .clone_htod(&g_host)
3683 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3684 let mut l_out = stream
3685 .alloc_zeros::<f64>(n * p_max * p_max)
3686 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3687 let mut u_out = stream
3688 .alloc_zeros::<f64>(n * p_max)
3689 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3690 let mut y_out = stream
3691 .alloc_zeros::<f64>(n * p_max * r_template)
3692 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3693 let mut partial_s = stream
3694 .alloc_zeros::<f64>(plan.partial_s_doubles)
3695 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3696 let mut partial_r = stream
3697 .alloc_zeros::<f64>(plan.partial_r_doubles)
3698 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3699 let mut status_dev = stream
3700 .alloc_zeros::<i32>(n)
3701 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3702
3703 let cfg = LaunchConfig {
3704 grid_dim: (plan.blocks, 1, 1),
3705 block_dim: (plan.threads_per_block, 1, 1),
3706 shared_mem_bytes: 0,
3707 };
3708 let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3709 let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3710 let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3711 let ridge_arg = ridge_t;
3712 {
3713 let mut builder = stream.launch_builder(&forward);
3714 builder
3715 .arg(&d_dev)
3716 .arg(&b_dev)
3717 .arg(&g_dev)
3718 .arg(&n_i32)
3719 .arg(&p_i32)
3720 .arg(&r_i32)
3721 .arg(&ridge_arg)
3722 .arg(&mut l_out)
3723 .arg(&mut u_out)
3724 .arg(&mut y_out)
3725 .arg(&mut partial_s)
3726 .arg(&mut partial_r)
3727 .arg(&mut status_dev);
3728 unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3731 }
3732 stream
3733 .synchronize()
3734 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3735
3736 let status_host = stream
3737 .clone_dtoh(&status_dev)
3738 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3739 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3740 return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
3744 row,
3745 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3746 });
3747 }
3748
3749 let y_host = stream
3751 .clone_dtoh(&y_out)
3752 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3753
3754 let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
3757 let hbb_is_kk = sys.hbb.dim() == (k, k);
3758 let hbb_matvec_opt = sys.hbb_matvec.clone();
3759
3760 let closure: crate::arrow_schur::GpuSchurMatvec =
3761 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
3762 assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
3763 assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
3764
3765 if let Some(ref mv) = hbb_matvec_opt {
3767 mv(x.view(), out);
3768 for a in 0..k {
3769 out[a] += ridge_beta * x[a];
3770 }
3771 } else if hbb_is_kk {
3772 for a in 0..k {
3774 let mut acc = ridge_beta * x[a];
3775 for b in 0..k {
3776 acc += hbb_host[a * k + b] * x[b];
3777 }
3778 out[a] = acc;
3779 }
3780 } else {
3781 for a in 0..k {
3782 out[a] = ridge_beta * x[a];
3783 }
3784 }
3785
3786 let mut z = vec![0.0_f64; d];
3789 for i in 0..n {
3790 let y_base = i * p_max * r_template;
3791 for r in 0..d {
3792 let mut acc = 0.0;
3793 for c in 0..k {
3794 acc += y_host[y_base + c * p_max + r] * x[c];
3795 }
3796 z[r] = acc;
3797 }
3798 for c in 0..k {
3799 let mut acc = 0.0;
3800 for r in 0..d {
3801 acc += y_host[y_base + c * p_max + r] * z[r];
3802 }
3803 out[c] -= acc;
3804 }
3805 }
3806 });
3807
3808 Ok(closure)
3809 }
3810
/// Device-resident, flattened copy of a framed SAE operator.
///
/// Built by `flatten_device_sae_frame_data` and read by the
/// `arrow_sae_frame_*` kernel launchers. Variable-sized per-block payloads
/// are stored as one concatenated data array plus a `*_ptr` array of prefix
/// offsets: `ptr[0] == 0` and `ptr[b]..ptr[b + 1]` is block `b`'s slice.
struct DeviceSaeFrameBuffers {
    // ---- smooth penalty blocks ----
    /// Global β offset (`global_offset`) of each smooth block.
    s_off: CudaSlice<i32>,
    /// Size `m` of each smooth block's square `factor_a`.
    s_m: CudaSlice<i32>,
    /// Rank paired with each smooth block (from `frame.smooth_ranks`).
    s_r: CudaSlice<i32>,
    /// Prefix offsets into `s_data`, one entry per block plus the leading 0.
    s_ptr: CudaSlice<i32>,
    /// Row-major `factor_a` entries of all smooth blocks, concatenated.
    s_data: CudaSlice<f64>,
    /// Number of smooth blocks (0 means the smooth kernel is skipped).
    s_blocks: usize,
    // ---- frame (atom i, atom j) coupling blocks ----
    /// β offset of atom `i`'s border segment (`border_offsets[atom_i]`).
    g_off_i: CudaSlice<i32>,
    /// β offset of atom `j`'s border segment (`border_offsets[atom_j]`).
    g_off_j: CudaSlice<i32>,
    /// Rank of atom `i` (`ranks[atom_i]`).
    g_ri: CudaSlice<i32>,
    /// Rank of atom `j` (`ranks[atom_j]`).
    g_rj: CudaSlice<i32>,
    /// Row count of each block's `g` matrix.
    g_mi: CudaSlice<i32>,
    /// Column count of each block's `g` matrix.
    g_mj: CudaSlice<i32>,
    /// Prefix offsets into `g_data`.
    g_ptr: CudaSlice<i32>,
    /// Row-major `g` entries of all frame blocks, concatenated.
    g_data: CudaSlice<f64>,
    /// Prefix offsets into `w_data`.
    w_ptr: CudaSlice<i32>,
    /// Row-major `w` entries (each `ri × rj`) of all frame blocks.
    w_data: CudaSlice<f64>,
    /// Number of frame blocks (0 means the G kernel is skipped).
    g_blocks: usize,
    /// Maximum of `g_mi * g_ri` over blocks; sizes the x grid dimension of
    /// the G kernel launch.
    g_max_work: usize,
    // ---- per-row H_tβ coupling ----
    /// Prefix offsets into `htb`, one entry per row plus the leading 0.
    htb_ptr: CudaSlice<i32>,
    /// Concatenated `H_tβ` slabs of the rows with `q_of[i] > 0`.
    htb: CudaSlice<f64>,
    /// Effective dimension of each row; 0 means no slab was uploaded for it.
    q_of: CudaSlice<i32>,
    /// Per-row inverse of `H_tt + ridge_t·I`, row-major, zero-padded to
    /// `max_q × max_q` (length `n_rows * max_q * max_q`).
    ainv: CudaSlice<f64>,
    /// Scratch of length `n_rows * max_q`, written by `arrow_sae_frame_apply_h`.
    hvec: CudaSlice<f64>,
    /// Scratch of length `n_rows * max_q`, written by `arrow_sae_frame_apply_ainv`.
    svec: CudaSlice<f64>,
    /// Number of rows in the arrow system.
    n_rows: usize,
    /// Length of the β vector.
    k: usize,
    /// Largest effective row dimension, clamped to at least 1.
    max_q: usize,
}
3845
3846 fn flatten_device_sae_frame_data(
3847 sys: &ArrowSchurSystem,
3848 data: &DeviceSaePcgData,
3849 frame: &DeviceSaeFrameData,
3850 ridge_t: f64,
3851 stream: &Arc<CudaStream>,
3852 ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3853 let n_rows = sys.rows.len();
3854 let k = data.beta_dim;
3855 if frame.row_htbeta.len() != n_rows
3856 || frame.ranks.len() != frame.basis_sizes.len()
3857 || frame.border_offsets.len() != frame.ranks.len()
3858 || data.smooth_blocks.len() != frame.smooth_ranks.len()
3859 {
3860 return Err(ArrowSchurGpuFailure::Unavailable);
3861 }
3862
3863 let mut s_off = Vec::new();
3865 let mut s_m = Vec::new();
3866 let mut s_r = Vec::new();
3867 let mut s_ptr = vec![0_i32];
3868 let mut s_data = Vec::<f64>::new();
3869 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3870 let (m, mc) = blk.factor_a.dim();
3871 if m != mc {
3872 return Err(ArrowSchurGpuFailure::Unavailable);
3873 }
3874 s_off.push(checked_i32(blk.global_offset)?);
3875 s_m.push(checked_i32(m)?);
3876 s_r.push(checked_i32(r)?);
3877 for ri in 0..m {
3878 for ci in 0..m {
3879 s_data.push(blk.factor_a[[ri, ci]]);
3880 }
3881 }
3882 s_ptr.push(checked_i32(s_data.len())?);
3883 }
3884
3885 let mut g_off_i = Vec::new();
3887 let mut g_off_j = Vec::new();
3888 let mut g_ri = Vec::new();
3889 let mut g_rj = Vec::new();
3890 let mut g_mi = Vec::new();
3891 let mut g_mj = Vec::new();
3892 let mut g_ptr = vec![0_i32];
3893 let mut g_data = Vec::<f64>::new();
3894 let mut w_ptr = vec![0_i32];
3895 let mut w_data = Vec::<f64>::new();
3896 let mut g_max_work = 0usize;
3897 for blk in &frame.frame_blocks {
3898 let ri = frame.ranks[blk.atom_i];
3899 let rj = frame.ranks[blk.atom_j];
3900 let (mi, mj) = blk.g.dim();
3901 if blk.w.dim() != (ri, rj) {
3902 return Err(ArrowSchurGpuFailure::Unavailable);
3903 }
3904 g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3905 g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3906 g_ri.push(checked_i32(ri)?);
3907 g_rj.push(checked_i32(rj)?);
3908 g_mi.push(checked_i32(mi)?);
3909 g_mj.push(checked_i32(mj)?);
3910 for r in 0..mi {
3911 for c in 0..mj {
3912 g_data.push(blk.g[[r, c]]);
3913 }
3914 }
3915 g_ptr.push(checked_i32(g_data.len())?);
3916 for a in 0..ri {
3917 for b in 0..rj {
3918 w_data.push(blk.w[[a, b]]);
3919 }
3920 }
3921 w_ptr.push(checked_i32(w_data.len())?);
3922 g_max_work = g_max_work.max(mi * ri);
3923 }
3924
3925 let mut htb_ptr = vec![0_i32];
3927 let mut htb = Vec::<f64>::new();
3928 let mut q_of = Vec::<i32>::with_capacity(n_rows);
3929 let mut max_q = 0usize;
3930 for (i, slab) in frame.row_htbeta.iter().enumerate() {
3931 let qi = sys.row_dims[i];
3932 let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3935 qi
3936 } else {
3937 0
3938 };
3939 q_of.push(checked_i32(q_eff)?);
3940 max_q = max_q.max(q_eff);
3941 if q_eff > 0 {
3942 htb.extend_from_slice(slab);
3943 }
3944 htb_ptr.push(checked_i32(htb.len())?);
3945 }
3946 if max_q == 0 {
3947 max_q = 1;
3950 }
3951
3952 let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3953 for (i, row) in sys.rows.iter().enumerate() {
3954 let q = q_of[i] as usize;
3955 if q == 0 {
3956 continue;
3957 }
3958 if row.htt.dim() != (q, q) {
3959 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3960 reason: format!(
3961 "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3962 row.htt.dim()
3963 ),
3964 });
3965 }
3966 let mut block = row.htt.clone();
3967 for d in 0..q {
3968 block[[d, d]] += ridge_t;
3969 }
3970 let factor = gam_linalg::triangular::cholesky_factor_in_place(
3971 block.view(),
3972 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3973 )
3974 .ok_or_else(|| {
3975 ArrowSchurGpuFailure::RidgeBumpRequired {
3978 row: i,
3979 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3980 }
3981 })?;
3982 for col in 0..q {
3983 let mut e = Array1::<f64>::zeros(q);
3984 e[col] = 1.0;
3985 let solved =
3986 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3987 for r in 0..q {
3988 ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3989 }
3990 }
3991 }
3992
3993 let htod_i = |v: &[i32]| {
3994 stream
3995 .clone_htod(v)
3996 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3997 };
3998 let htod_f = |v: &[f64]| {
3999 stream
4000 .clone_htod(v)
4001 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4002 };
4003 Ok(DeviceSaeFrameBuffers {
4004 s_off: htod_i(&s_off)?,
4005 s_m: htod_i(&s_m)?,
4006 s_r: htod_i(&s_r)?,
4007 s_ptr: htod_i(&s_ptr)?,
4008 s_data: htod_f(&s_data)?,
4009 s_blocks: data.smooth_blocks.len(),
4010 g_off_i: htod_i(&g_off_i)?,
4011 g_off_j: htod_i(&g_off_j)?,
4012 g_ri: htod_i(&g_ri)?,
4013 g_rj: htod_i(&g_rj)?,
4014 g_mi: htod_i(&g_mi)?,
4015 g_mj: htod_i(&g_mj)?,
4016 g_ptr: htod_i(&g_ptr)?,
4017 g_data: htod_f(&g_data)?,
4018 w_ptr: htod_i(&w_ptr)?,
4019 w_data: htod_f(&w_data)?,
4020 g_blocks: frame.frame_blocks.len(),
4021 g_max_work,
4022 htb_ptr: htod_i(&htb_ptr)?,
4023 htb: htod_f(&htb)?,
4024 q_of: htod_i(&q_of)?,
4025 ainv: htod_f(&ainv)?,
4026 hvec: stream
4027 .alloc_zeros::<f64>(n_rows * max_q)
4028 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4029 svec: stream
4030 .alloc_zeros::<f64>(n_rows * max_q)
4031 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4032 n_rows,
4033 k,
4034 max_q,
4035 })
4036 }
4037
4038 fn sae_frame_penalty_diag_host(
4039 data: &DeviceSaePcgData,
4040 frame: &DeviceSaeFrameData,
4041 ridge_beta: f64,
4042 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4043 let mut diag = vec![ridge_beta; data.beta_dim];
4044 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4046 let m = blk.factor_a.nrows();
4047 for ia in 0..m {
4048 let coeff = blk.factor_a[[ia, ia]];
4049 let base = blk.global_offset + ia * r;
4050 for ib in 0..r {
4051 if base + ib >= diag.len() {
4052 return Err(ArrowSchurGpuFailure::Unavailable);
4053 }
4054 diag[base + ib] += coeff;
4055 }
4056 }
4057 }
4058 for blk in &frame.frame_blocks {
4060 if blk.atom_i != blk.atom_j {
4061 continue;
4062 }
4063 let r = frame.ranks[blk.atom_i];
4064 let off = frame.border_offsets[blk.atom_i];
4065 let (mi, mj) = blk.g.dim();
4066 for li in 0..mi.min(mj) {
4067 let gii = blk.g[[li, li]];
4068 let base = off + li * r;
4069 for a in 0..r {
4070 if base + a >= diag.len() {
4071 return Err(ArrowSchurGpuFailure::Unavailable);
4072 }
4073 diag[base + a] += gii * blk.w[[a, a]];
4074 }
4075 }
4076 }
4077 Ok(diag)
4078 }
4079
4080 fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4081 Ok(LaunchConfig {
4082 grid_dim: (
4083 ((work as u32).saturating_add(255) / 256).max(1),
4084 checked_i32(n_rows)? as u32,
4085 1,
4086 ),
4087 block_dim: (256, 1, 1),
4088 shared_mem_bytes: 0,
4089 })
4090 }
4091
4092 fn launch_sae_frame_matvec(
4093 stream: &Arc<CudaStream>,
4094 module: &Arc<CudaModule>,
4095 buffers: &mut DeviceSaeFrameBuffers,
4096 x: &CudaSlice<f64>,
4097 out: &mut CudaSlice<f64>,
4098 ridge_beta: f64,
4099 ) -> Result<(), ArrowSchurGpuFailure> {
4100 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4101 if buffers.s_blocks > 0 {
4103 let kernel = module
4104 .load_function("arrow_sae_frame_smooth_matvec")
4105 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4106 let blocks_i32 = checked_i32(buffers.s_blocks)?;
4107 let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4108 let mut b = stream.launch_builder(&kernel);
4109 b.arg(x)
4110 .arg(&mut *out)
4111 .arg(&buffers.s_off)
4112 .arg(&buffers.s_m)
4113 .arg(&buffers.s_r)
4114 .arg(&buffers.s_ptr)
4115 .arg(&buffers.s_data)
4116 .arg(&blocks_i32);
4117 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4120 }
4121 if buffers.g_blocks > 0 {
4123 let kernel = module
4124 .load_function("arrow_sae_frame_g_matvec")
4125 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4126 let blocks_i32 = checked_i32(buffers.g_blocks)?;
4127 let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4128 let mut b = stream.launch_builder(&kernel);
4129 b.arg(x)
4130 .arg(&mut *out)
4131 .arg(&buffers.g_off_i)
4132 .arg(&buffers.g_off_j)
4133 .arg(&buffers.g_ri)
4134 .arg(&buffers.g_rj)
4135 .arg(&buffers.g_mi)
4136 .arg(&buffers.g_mj)
4137 .arg(&buffers.g_ptr)
4138 .arg(&buffers.g_data)
4139 .arg(&buffers.w_ptr)
4140 .arg(&buffers.w_data)
4141 .arg(&blocks_i32);
4142 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4145 }
4146 let k_i32 = checked_i32(buffers.k)?;
4148 let max_q_i32 = checked_i32(buffers.max_q)?;
4149 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4150 {
4151 let kernel = module
4152 .load_function("arrow_sae_frame_apply_h")
4153 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4154 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4155 let mut b = stream.launch_builder(&kernel);
4156 b.arg(x)
4157 .arg(&buffers.htb_ptr)
4158 .arg(&buffers.htb)
4159 .arg(&buffers.q_of)
4160 .arg(&mut buffers.hvec)
4161 .arg(&k_i32)
4162 .arg(&max_q_i32)
4163 .arg(&n_rows_i32);
4164 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4167 }
4168 {
4169 let kernel = module
4170 .load_function("arrow_sae_frame_apply_ainv")
4171 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4172 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4173 let mut b = stream.launch_builder(&kernel);
4174 b.arg(&buffers.ainv)
4175 .arg(&buffers.hvec)
4176 .arg(&buffers.q_of)
4177 .arg(&mut buffers.svec)
4178 .arg(&max_q_i32)
4179 .arg(&n_rows_i32);
4180 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4183 }
4184 {
4185 let kernel = module
4186 .load_function("arrow_sae_frame_scatter_h")
4187 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4188 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4189 let mut b = stream.launch_builder(&kernel);
4190 b.arg(&buffers.svec)
4191 .arg(&buffers.htb_ptr)
4192 .arg(&buffers.htb)
4193 .arg(&buffers.q_of)
4194 .arg(out)
4195 .arg(&k_i32)
4196 .arg(&max_q_i32)
4197 .arg(&n_rows_i32);
4198 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4201 }
4202 Ok(())
4203 }
4204
4205 fn launch_sae_frame_diag_sub(
4206 stream: &Arc<CudaStream>,
4207 module: &Arc<CudaModule>,
4208 buffers: &DeviceSaeFrameBuffers,
4209 diag: &mut CudaSlice<f64>,
4210 ) -> Result<(), ArrowSchurGpuFailure> {
4211 let kernel = module
4212 .load_function("arrow_sae_frame_diag_sub")
4213 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4214 let k_i32 = checked_i32(buffers.k)?;
4215 let max_q_i32 = checked_i32(buffers.max_q)?;
4216 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4217 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4218 let mut b = stream.launch_builder(&kernel);
4219 b.arg(diag)
4220 .arg(&buffers.ainv)
4221 .arg(&buffers.htb_ptr)
4222 .arg(&buffers.htb)
4223 .arg(&buffers.q_of)
4224 .arg(&k_i32)
4225 .arg(&max_q_i32)
4226 .arg(&n_rows_i32);
4227 unsafe { b.launch(cfg) }
4229 .map(drop)
4230 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4231 }
4232
4233 pub(super) fn framed_schur_matvec_once_on_device(
4244 sys: &ArrowSchurSystem,
4245 data: &DeviceSaePcgData,
4246 ridge_t: f64,
4247 ridge_beta: f64,
4248 x: &Array1<f64>,
4249 ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4250 let k = x.len();
4251 if k == 0 || data.beta_dim != k || sys.k != k {
4252 return Err(ArrowSchurGpuFailure::Unavailable);
4253 }
4254 let frame = data
4255 .frame
4256 .as_ref()
4257 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4258 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4261 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4262 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4263 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4264 let stream = ctx
4265 .new_stream()
4266 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4267 let vector_module = pcg_vector_module(&ctx)?;
4268 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4269 let x_dev = stream
4270 .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4271 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4272 let mut out_dev = stream
4273 .alloc_zeros::<f64>(k)
4274 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4275 launch_sae_frame_matvec(
4276 &stream,
4277 vector_module,
4278 &mut buffers,
4279 &x_dev,
4280 &mut out_dev,
4281 ridge_beta,
4282 )?;
4283 let out = stream
4284 .clone_dtoh(&out_dev)
4285 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4286 Ok(Array1::from_vec(out))
4287 }
4288
/// Solve the framed SAE reduced system `S·δβ = rhs_beta` with
/// Jacobi-preconditioned conjugate gradients, entirely on the GPU.
///
/// * `S` is applied matrix-free by `launch_sae_frame_matvec`.
/// * The Jacobi diagonal starts from `sae_frame_penalty_diag_host` and is
///   adjusted on the device by `launch_sae_frame_diag_sub`.
/// * The iteration stops once the absolute residual norm is at most
///   `max(relative_tolerance · ‖rhs‖, 1e-12)`, or after
///   `max(max_iterations, 1)` iterations. `final_relative_residual` is always
///   `‖r‖ / ‖rhs‖`.
/// * A zero right-hand side returns zeros with default diagnostics. This
///   only happens after the GPU has been acquired and the operator
///   flattened.
///
/// # Errors
/// * `Unavailable` for any of:
///   - shape disagreement or missing `data.frame`;
///   - the runtime policy declines to offload;
///   - any device or cuBLAS operation fails.
/// * `SchurFactorFailed` for a non-finite or non-positive Jacobi entry,
///   initial or updated `rᵀM⁻¹r`, or curvature `pᵀAp`.
/// * Any error from `flatten_device_sae_frame_data`, such as
///   `RidgeBumpRequired`.
pub(super) fn solve_sae_matrix_free_pcg_framed(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    let frame = data
        .frame
        .as_ref()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    // Offload only if the runtime policy accepts this problem size.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute stopping threshold, floored so a tiny rhs still terminates.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi preconditioner: host penalty diagonal, adjusted on the device,
    // then read back, validated and inverted on the host.
    let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // PCG state: x₀ = 0, r₀ = rhs, z₀ = M⁻¹r₀, p₀ = z₀.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };
    for _ in 0..max_iterations.max(1) {
        launch_sae_frame_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        // Non-positive curvature means S is not SPD along p; CG cannot continue.
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        // Fletcher–Reeves-style ratio for the next search direction.
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    // Iteration budget exhausted: report the final recurrence residual.
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4438
/// Solve the (unframed) SAE reduced system `S·δβ = rhs_beta` with
/// Jacobi-preconditioned conjugate gradients, entirely on the GPU.
///
/// * `S` is applied matrix-free by `launch_sae_matvec` on buffers from
///   `flatten_device_sae_data`.
/// * The Jacobi diagonal starts from `sae_penalty_diag_host` and is adjusted
///   on the device by `launch_sae_diag_sub`.
/// * Data carrying a `frame` is rejected. Presumably such data belongs to
///   `solve_sae_matrix_free_pcg_framed`.
/// * The stopping rule and diagnostics match the framed solver: absolute
///   residual `≤ max(relative_tolerance · ‖rhs‖, 1e-12)`, at most
///   `max(max_iterations, 1)` iterations.
///
/// # Errors
/// * `Unavailable` for any of:
///   - shape disagreement or framed data;
///   - the runtime policy declines to offload;
///   - any device or cuBLAS operation fails.
/// * `SchurFactorFailed` for a non-finite or non-positive Jacobi entry,
///   `rᵀM⁻¹r`, or curvature `pᵀAp`.
/// * Any error from `flatten_device_sae_data`.
pub(super) fn solve_sae_matrix_free_pcg(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // This path only understands unframed SAE data.
    if data.frame.is_some() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Offload only if the runtime policy accepts this problem size.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute stopping threshold, floored so a tiny rhs still terminates.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi preconditioner: host penalty diagonal, adjusted on the device,
    // then read back, validated and inverted on the host.
    let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // PCG state: x₀ = 0, r₀ = rhs, z₀ = M⁻¹r₀, p₀ = z₀.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };

    for _ in 0..max_iterations.max(1) {
        launch_sae_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        // Non-positive curvature means S is not SPD along p; CG cannot continue.
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        // Fletcher–Reeves-style ratio for the next search direction.
        let beta = rz_new / rz;
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    // Iteration budget exhausted: report the final recurrence residual.
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4610
4611 pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4612 s_acc: &ndarray::Array2<f64>,
4613 rhs_beta: &Array1<f64>,
4614 max_iterations: usize,
4615 relative_tolerance: f64,
4616 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4617 let k = rhs_beta.len();
4618 let cg_iters = max_iterations.max(1);
4630 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4631 gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4632 m: k,
4633 n: k,
4634 k: cg_iters,
4635 },
4636 )
4637 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4638 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4639 .and_then(|ctx| ctx.new_stream().ok())
4640 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4641 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4642 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4643 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4644 let vector_module = pcg_vector_module(&ctx)?;
4645
4646 let mut inv_diag = vec![0.0_f64; k];
4648 for j in 0..k {
4649 let djj = s_acc[[j, j]];
4650 if !(djj > 0.0) {
4651 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4652 reason: format!(
4653 "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4654 ),
4655 });
4656 }
4657 inv_diag[j] = 1.0 / djj;
4658 }
4659
4660 let mut s_host = vec![0.0_f64; k * k];
4662 for col in 0..k {
4663 for row in 0..k {
4664 s_host[col * k + row] = s_acc[[row, col]];
4665 }
4666 }
4667 let s_dev = stream
4668 .clone_htod(&s_host)
4669 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4670
4671 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4675 if rhs_norm == 0.0 {
4676 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4677 }
4678 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4679
4680 let mut x_dev = stream
4683 .alloc_zeros::<f64>(k)
4684 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4685 let mut r_dev = stream
4686 .clone_htod(
4687 rhs_beta
4688 .as_slice()
4689 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4690 )
4691 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4692 let inv_diag_dev = stream
4693 .clone_htod(&inv_diag)
4694 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4695 let mut z_dev = stream
4696 .alloc_zeros::<f64>(k)
4697 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4698 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4699 let mut p_dev = stream
4700 .alloc_zeros::<f64>(k)
4701 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4703 let mut sp_dev = stream
4704 .alloc_zeros::<f64>(k)
4705 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4706 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4707 let mut diag = PcgDiagnostics {
4708 precond_apply_calls: 1,
4709 stopping_reason: PcgStopReason::MaxIter,
4710 ..PcgDiagnostics::default()
4711 };
4712 if rz <= 0.0 || !rz.is_finite() {
4713 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4714 reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4715 });
4716 }
4717
4718 let max_iters = max_iterations.max(1);
4719 for _ in 0..max_iters {
4720 let gemv_cfg = GemvConfig::<f64> {
4722 trans: cublasOperation_t::CUBLAS_OP_N,
4723 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4724 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4725 alpha: 1.0,
4726 lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4727 incx: 1,
4728 beta: 0.0,
4729 incy: 1,
4730 };
4731 unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4733 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4734 diag.matvec_calls += 1;
4735 diag.iterations += 1;
4736
4737 let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4738 if !(p_sp > 0.0) {
4739 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4742 reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4743 });
4744 }
4745 let alpha = rz / p_sp;
4746 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4747 device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4748 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4749 if r_norm <= tol {
4750 diag.final_relative_residual = r_norm / rhs_norm;
4751 diag.stopping_reason = PcgStopReason::Converged;
4752 break;
4753 }
4754 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4755 diag.precond_apply_calls += 1;
4756 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4757 if rz_new <= 0.0 || !rz_new.is_finite() {
4758 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4759 reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4760 });
4761 }
4762 let beta = rz_new / rz;
4763 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4764 rz = rz_new;
4765 }
4766 if diag.stopping_reason != PcgStopReason::Converged {
4767 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4768 diag.final_relative_residual = r_norm / rhs_norm;
4769 diag.stopping_reason = PcgStopReason::MaxIter;
4770 }
4771
4772 let x = stream
4773 .clone_dtoh(&x_dev)
4774 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4775 Ok((Array1::from_vec(x), diag))
4776 }
4777
4778 fn device_copy(
4779 blas: &CudaBlas,
4780 stream: &Arc<CudaStream>,
4781 n: usize,
4782 src: &CudaSlice<f64>,
4783 dst: &mut CudaSlice<f64>,
4784 ) -> Result<(), ArrowSchurGpuFailure> {
4785 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4786 let (src_ptr, _src_rec) = src.device_ptr(stream);
4787 let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4788 let status = unsafe {
4791 cudarc::cublas::sys::cublasDcopy_v2(
4792 *blas.handle(),
4793 n_i,
4794 src_ptr as *const f64,
4795 1,
4796 dst_ptr as *mut f64,
4797 1,
4798 )
4799 };
4800 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4801 Ok(())
4802 } else {
4803 Err(ArrowSchurGpuFailure::Unavailable)
4804 }
4805 }
4806
4807 fn device_axpy(
4808 blas: &CudaBlas,
4809 stream: &Arc<CudaStream>,
4810 n: usize,
4811 alpha: f64,
4812 x: &CudaSlice<f64>,
4813 y: &mut CudaSlice<f64>,
4814 ) -> Result<(), ArrowSchurGpuFailure> {
4815 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4816 let (x_ptr, _x_rec) = x.device_ptr(stream);
4817 let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4818 let status = unsafe {
4821 cudarc::cublas::sys::cublasDaxpy_v2(
4822 *blas.handle(),
4823 n_i,
4824 &alpha,
4825 x_ptr as *const f64,
4826 1,
4827 y_ptr as *mut f64,
4828 1,
4829 )
4830 };
4831 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4832 Ok(())
4833 } else {
4834 Err(ArrowSchurGpuFailure::Unavailable)
4835 }
4836 }
4837
4838 fn device_dot(
4839 blas: &CudaBlas,
4840 stream: &Arc<CudaStream>,
4841 n: usize,
4842 x: &CudaSlice<f64>,
4843 y: &CudaSlice<f64>,
4844 ) -> Result<f64, ArrowSchurGpuFailure> {
4845 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4846 let (x_ptr, _x_rec) = x.device_ptr(stream);
4847 let (y_ptr, _y_rec) = y.device_ptr(stream);
4848 let mut result = 0.0_f64;
4849 let status = unsafe {
4853 cudarc::cublas::sys::cublasDdot_v2(
4854 *blas.handle(),
4855 n_i,
4856 x_ptr as *const f64,
4857 1,
4858 y_ptr as *const f64,
4859 1,
4860 &mut result,
4861 )
4862 };
4863 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4864 Ok(result)
4865 } else {
4866 Err(ArrowSchurGpuFailure::Unavailable)
4867 }
4868 }
4869
4870 fn device_nrm2(
4871 blas: &CudaBlas,
4872 stream: &Arc<CudaStream>,
4873 n: usize,
4874 x: &CudaSlice<f64>,
4875 ) -> Result<f64, ArrowSchurGpuFailure> {
4876 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4877 let (x_ptr, _x_rec) = x.device_ptr(stream);
4878 let mut result = 0.0_f64;
4879 let status = unsafe {
4883 cudarc::cublas::sys::cublasDnrm2_v2(
4884 *blas.handle(),
4885 n_i,
4886 x_ptr as *const f64,
4887 1,
4888 &mut result,
4889 )
4890 };
4891 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4892 Ok(result)
4893 } else {
4894 Err(ArrowSchurGpuFailure::Unavailable)
4895 }
4896 }
4897
4898 #[cfg(test)]
4899 mod tests {
4900 use super::*;
4905 use crate::arrow_schur::{
4906 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4907 FactoredFrameGBlock,
4908 };
4909 use ndarray::Array2;
4910
4911 fn device_matvec_once(
4914 sys: &ArrowSchurSystem,
4915 data: &DeviceSaePcgData,
4916 ridge_t: f64,
4917 ridge_beta: f64,
4918 x_host: &[f64],
4919 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4920 let k = x_host.len();
4921 let frame = data
4922 .frame
4923 .as_ref()
4924 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4925 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4926 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4927 let ctx =
4928 gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4929 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4930 let stream = ctx
4931 .new_stream()
4932 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4933 let vector_module = pcg_vector_module(&ctx)?;
4934 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4935 let x_dev = stream
4936 .clone_htod(x_host)
4937 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4938 let mut out_dev = stream
4939 .alloc_zeros::<f64>(k)
4940 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4941 launch_sae_frame_matvec(
4942 &stream,
4943 vector_module,
4944 &mut buffers,
4945 &x_dev,
4946 &mut out_dev,
4947 ridge_beta,
4948 )?;
4949 stream
4950 .clone_dtoh(&out_dev)
4951 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4952 }
4953
4954 #[test]
4960 fn framed_sae_device_matvec_stage_diff_tiny_1551() {
4961 if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
4962 return;
4963 }
4964 let p = 3usize;
4965 let ranks = vec![2usize, 3usize];
4966 let basis_sizes = vec![2usize, 2usize];
4967 let mut border_offsets = Vec::new();
4968 let mut acc = 0usize;
4969 for k in 0..2 {
4970 border_offsets.push(acc);
4971 acc += basis_sizes[k] * ranks[k];
4972 }
4973 let border_dim = acc; let frame_of = |k: usize| -> Array2<f64> {
4975 Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
4976 0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
4977 })
4978 };
4979 let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
4980 let w_of = |i: usize, j: usize| -> Array2<f64> {
4981 let (ui, uj) = (&frames[i], &frames[j]);
4982 Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
4983 (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
4984 })
4985 };
4986 let mut frame_blocks = Vec::new();
4987 for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
4988 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
4989 let mut g =
4990 Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
4991 if i == j {
4992 for r in 0..mi.min(mj) {
4993 g[[r, r]] += mi as f64 + 2.0;
4994 }
4995 }
4996 frame_blocks.push(FactoredFrameGBlock {
4997 atom_i: i,
4998 atom_j: j,
4999 g,
5000 w: w_of(i, j),
5001 });
5002 }
5003 let mut smooth_blocks = Vec::new();
5004 for k in 0..2 {
5005 let m = basis_sizes[k];
5006 let mut s =
5007 Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
5008 for r in 0..m {
5009 s[[r, r]] += 1.0;
5010 }
5011 smooth_blocks.push(DeviceSaeSmoothBlock {
5012 global_offset: border_offsets[k],
5013 factor_a: s,
5014 });
5015 }
5016 let smooth_ranks = ranks.clone();
5017 let n = 2usize;
5018 let q = 2usize;
5019 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5020 let mut row_htbeta = Vec::new();
5021 for i in 0..n {
5022 let mut htt =
5023 Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
5024 for r in 0..q {
5025 htt[[r, r]] += q as f64 + 2.0;
5026 }
5027 sys.rows[i].htt = htt;
5028 let mut slab = vec![0.0_f64; q * border_dim];
5029 for c in 0..q {
5030 for col in 0..border_dim {
5031 let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
5032 slab[c * border_dim + col] = v;
5033 sys.rows[i].htbeta[[c, col]] = v;
5034 }
5035 }
5036 row_htbeta.push(slab);
5037 }
5038 let data = DeviceSaePcgData {
5039 p,
5040 beta_dim: border_dim,
5041 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5042 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5043 smooth_blocks,
5044 sparse_g_blocks: Vec::new(),
5045 frame: Some(DeviceSaeFrameData {
5046 ranks,
5047 basis_sizes,
5048 border_offsets,
5049 frame_blocks,
5050 smooth_ranks,
5051 row_htbeta,
5052 }),
5053 };
5054 let ridge_t = 1e-7;
5055 let ridge_beta = 1e-6;
5056 let mut first_bad: Option<usize> = None;
5057 let mut worst = 0.0_f64;
5058 let mut worst_at = 0usize;
5059 let mut worst_dev = 0.0_f64;
5060 let mut worst_cpu = 0.0_f64;
5061 for col in 0..border_dim {
5062 let mut x = vec![0.0_f64; border_dim];
5063 x[col] = 1.0;
5064 let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
5065 Ok(v) => v,
5066 Err(_) => return,
5067 };
5068 let mut cpu = vec![0.0_f64; border_dim];
5069 super::super::sae_framed_schur_matvec_cpu(
5070 &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
5071 )
5072 .expect("cpu matvec");
5073 for r in 0..border_dim {
5074 let d = (dev[r] - cpu[r]).abs();
5075 if d > 1e-9 && first_bad.is_none() {
5076 first_bad = Some(r * border_dim + col);
5077 }
5078 if d > worst {
5079 worst = d;
5080 worst_at = r * border_dim + col;
5081 worst_dev = dev[r];
5082 worst_cpu = cpu[r];
5083 }
5084 }
5085 }
5086 assert!(
5087 worst <= 1e-9,
5088 "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
5089 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
5090 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
5091 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
5092 G⊗W=cross, reduced-Schur=dense per-row)",
5093 );
5094 }
5095 }
5096}
5097
5098#[cfg(test)]
5099mod tests {
5100 use super::*;
5101 use crate::arrow_schur::ArrowSchurSystem;
5102 use ndarray::{Array2, ArrayView1};
5103
5104 fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5105 let mut sys = ArrowSchurSystem::new(n, d, k);
5106 let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5107 let mut sample = || -> f64 {
5108 state = state
5109 .wrapping_mul(6364136223846793005)
5110 .wrapping_add(1442695040888963407);
5111 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5112 };
5113 for row in &mut sys.rows {
5114 let mut a = Array2::<f64>::zeros((d, d));
5115 for r in 0..d {
5116 for c in 0..d {
5117 a[[r, c]] = sample();
5118 }
5119 }
5120 let mut htt = a.t().dot(&a);
5121 for r in 0..d {
5122 htt[[r, r]] += d as f64 + 1.0;
5123 }
5124 row.htt = htt;
5125 for r in 0..d {
5126 for c in 0..k {
5127 row.htbeta[[r, c]] = 0.1 * sample();
5128 }
5129 row.gt[r] = sample();
5130 }
5131 }
5132 let mut hbb_a = Array2::<f64>::zeros((k, k));
5133 for r in 0..k {
5134 for c in 0..k {
5135 hbb_a[[r, c]] = sample();
5136 }
5137 }
5138 let mut hbb = hbb_a.t().dot(&hbb_a);
5139 for r in 0..k {
5140 hbb[[r, r]] += k as f64 + 1.0;
5141 }
5142 sys.hbb = hbb;
5143 for r in 0..k {
5144 sys.gb[r] = sample();
5145 }
5146 sys
5147 }
5148
5149 #[test]
5154 fn ridge_bump_makes_known_indefinite_blocks_pd() {
5155 let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); let mut indef2 = Array2::<f64>::zeros((2, 2));
5163 indef2[[0, 0]] = 1.0;
5164 indef2[[1, 1]] = 1.0;
5165 indef2[[0, 1]] = 2.0;
5166 indef2[[1, 0]] = 2.0;
5167 let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5170
5171 for (label, block) in [
5172 ("-I (λ_min=-1)", neg_identity),
5173 ("-250·I (λ_min=-250)", scaled_neg),
5174 ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5175 ("5·I (PD)", pd),
5176 ] {
5177 let ridge_t = 0.0;
5178 let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5179 assert!(
5180 bump > 0.0 && bump.is_finite(),
5181 "[{label}] bump must be strictly positive and finite, got {bump:e}"
5182 );
5183 let d = block.nrows();
5184 let mut shifted = block.clone();
5185 for i in 0..d {
5186 shifted[[i, i]] += ridge_t + bump;
5187 }
5188 assert!(
5189 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5190 "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5191 Gershgorin bump, but the Cholesky still rejected it"
5192 );
5193 }
5194 }
5195
5196 #[test]
5203 fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5204 let mut a = Array2::<f64>::zeros((3, 3));
5206 a[[0, 0]] = -2.0;
5207 a[[1, 1]] = 0.5;
5208 a[[2, 2]] = 1.0;
5209 a[[0, 1]] = 0.3;
5210 a[[1, 0]] = 0.3;
5211 a[[1, 2]] = -0.4;
5212 a[[2, 1]] = -0.4;
5213 a[[0, 2]] = 0.1;
5214 a[[2, 0]] = 0.1;
5215
5216 let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5217
5218 let d = 3;
5220 let mut col_major = vec![0.0_f64; d * d];
5221 for c in 0..d {
5222 for r in 0..d {
5223 col_major[c * d + r] = a[[r, c]];
5224 }
5225 }
5226 let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5227
5228 assert!(
5229 (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5230 "colmajor bump {col_major_bump:e} must match rowmajor bump \
5231 {row_major_bump:e} for a symmetric block"
5232 );
5233
5234 let mut shifted = a.clone();
5236 for i in 0..d {
5237 shifted[[i, i]] += col_major_bump;
5238 }
5239 assert!(
5240 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5241 "colmajor Gershgorin bump must make the symmetric block PD"
5242 );
5243 }
5244
5245 fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5246 let mut s = Array2::<f64>::zeros((k, k));
5247 for row in 0..k {
5248 s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5249 if row + 1 < k {
5250 s[[row, row + 1]] = -0.05;
5251 s[[row + 1, row]] = -0.05;
5252 }
5253 if row + 7 < k {
5254 s[[row, row + 7]] = 0.01;
5255 s[[row + 7, row]] = 0.01;
5256 }
5257 }
5258 let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5259 (s, rhs)
5260 }
5261
5262 fn dense_pcg_cpu_reference(
5263 s: &Array2<f64>,
5264 rhs: &Array1<f64>,
5265 max_iterations: usize,
5266 relative_tolerance: f64,
5267 ) -> Array1<f64> {
5268 let k = rhs.len();
5269 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5270 if rhs_norm == 0.0 {
5271 return Array1::<f64>::zeros(k);
5272 }
5273 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5274 let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5275 let mut x = Array1::<f64>::zeros(k);
5276 let mut r = rhs.clone();
5277 let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5278 let mut p = z.clone();
5279 let mut sp = Array1::<f64>::zeros(k);
5280 let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5281 for _ in 0..max_iterations.max(1) {
5282 for row in 0..k {
5283 let mut acc = 0.0;
5284 for col in 0..k {
5285 acc += s[[row, col]] * p[col];
5286 }
5287 sp[row] = acc;
5288 }
5289 let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5290 let alpha = rz / p_sp;
5291 for idx in 0..k {
5292 x[idx] += alpha * p[idx];
5293 r[idx] -= alpha * sp[idx];
5294 }
5295 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5296 if r_norm <= tol {
5297 break;
5298 }
5299 for idx in 0..k {
5300 z[idx] = inv_diag[idx] * r[idx];
5301 }
5302 let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5303 let beta = rz_next / rz;
5304 for idx in 0..k {
5305 p[idx] = z[idx] + beta * p[idx];
5306 }
5307 rz = rz_next;
5308 }
5309 x
5310 }
5311
5312 #[test]
5313 fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5314 let (s, rhs) = device_pcg_fixture(512);
5315 let max_iterations = 200usize;
5316 let relative_tolerance = 1.0e-12;
5317 let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5318 let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5319 &s,
5320 &rhs,
5321 max_iterations,
5322 relative_tolerance,
5323 ) {
5324 Ok(result) => result,
5325 Err(failure) => {
5332 assert!(
5333 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5334 "#1017: CUDA device present but the device reduced-beta PCG \
5335 declined/faulted instead of returning a result (tag: {failure:?}) — \
5336 the kernel does not run correctly on GPU"
5337 );
5338 return;
5339 }
5340 };
5341 let max_err = cpu
5342 .iter()
5343 .zip(device.iter())
5344 .map(|(a, b)| (a - b).abs())
5345 .fold(0.0_f64, f64::max);
5346 assert!(
5347 max_err <= 1.0e-10,
5348 "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5349 );
5350 assert!(diag.matvec_calls > 0);
5351 assert_eq!(diag.matvec_calls, diag.iterations);
5352 }
5353
5354 #[test]
5355 fn dense_reference_matches_independent_solve() {
5356 let sys = build_fixture(4, 5, 3, 7);
5357 let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5358 let n = sys.rows.len();
5362 let d = sys.d;
5363 let k = sys.k;
5364 let total = n * d + k;
5365 let mut h = Array2::<f64>::zeros((total, total));
5366 let mut g = ndarray::Array1::<f64>::zeros(total);
5367 for (i, row) in sys.rows.iter().enumerate() {
5368 let base = i * d;
5369 for c in 0..d {
5370 for r in 0..d {
5371 h[[base + r, base + c]] = row.htt[[r, c]];
5372 }
5373 }
5374 for c in 0..k {
5375 for r in 0..d {
5376 h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5377 h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5378 }
5379 }
5380 for r in 0..d {
5381 g[base + r] = row.gt[r];
5382 }
5383 }
5384 for c in 0..k {
5385 for r in 0..k {
5386 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5387 }
5388 g[n * d + c] = sys.gb[c];
5389 }
5390 let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5391 let rhs = g.mapv(|v| -v);
5392 let expected = cholesky_solve_vector(l.view(), rhs.view());
5393 for i in 0..n * d {
5394 assert!(
5395 (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5396 "delta_t[{i}] mismatch: got {} expected {}",
5397 solution.delta_t[i],
5398 expected[i]
5399 );
5400 }
5401 for a in 0..k {
5402 assert!(
5403 (solution.delta_beta[a] - expected[n * d + a]).abs()
5404 < 1e-10 * (1.0 + expected[n * d + a].abs()),
5405 "delta_beta[{a}] mismatch"
5406 );
5407 }
5408 }
5409
5410 #[test]
5424 fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5425 use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5426 let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; let d = 3usize;
5428 let k = 24usize;
5429 let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5430 let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5435 let forward_slabs = slabs.clone();
5436 let transpose_slabs = slabs;
5437 sys.set_row_htbeta_operator(
5438 move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5439 let h = &forward_slabs[row];
5440 for r in 0..h.nrows() {
5441 let mut acc = 0.0_f64;
5442 for c in 0..h.ncols() {
5443 acc += h[[r, c]] * x[c];
5444 }
5445 out[r] = acc;
5446 }
5447 },
5448 move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5449 let h = &transpose_slabs[row];
5450 for r in 0..h.nrows() {
5451 for c in 0..h.ncols() {
5452 out[c] += h[[r, c]] * v[r];
5453 }
5454 }
5455 },
5456 );
5457
5458 let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5459 .expect("row-procedural matvec backend builds for matrix-free system");
5460 let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5461
5462 let mut out_parallel_a = Array1::<f64>::zeros(k);
5466 matvec(&x, &mut out_parallel_a);
5467 let mut out_parallel_b = Array1::<f64>::zeros(k);
5468 matvec(&x, &mut out_parallel_b);
5469 for a in 0..k {
5470 assert_eq!(
5471 out_parallel_a[a].to_bits(),
5472 out_parallel_b[a].to_bits(),
5473 "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5474 );
5475 }
5476
5477 let mut out_serial = Array1::<f64>::zeros(k);
5482 rayon::ThreadPoolBuilder::new()
5483 .num_threads(2)
5484 .build()
5485 .expect("build rayon pool")
5486 .install(|| matvec(&x, &mut out_serial));
5487
5488 let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5489 for a in 0..k {
5490 let diff = (out_parallel_a[a] - out_serial[a]).abs();
5491 assert!(
5492 diff <= 1e-12 * (1.0 + max_abs),
5493 "row-procedural matvec parallel vs serial diverged beyond reassociation \
5494 at index {a}: {} vs {} (diff={diff:e})",
5495 out_parallel_a[a],
5496 out_serial[a]
5497 );
5498 }
5499 }
5500
5501 #[test]
5508 fn framed_sae_schur_matvec_matches_dense_reference() {
5509 use crate::arrow_schur::{
5510 BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5511 FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5512 };
5513
5514 let p = 4usize;
5515 let ranks = vec![2usize, 4usize, 3usize];
5517 let basis_sizes = vec![2usize, 1usize, 2usize];
5518 let n_atoms = ranks.len();
5519 let mut border_offsets = Vec::with_capacity(n_atoms);
5520 let mut acc = 0usize;
5521 for k in 0..n_atoms {
5522 border_offsets.push(acc);
5523 acc += basis_sizes[k] * ranks[k];
5524 }
5525 let border_dim = acc; let mut state = 0x1234_5678_9abc_def0u64;
5528 let mut sample = || -> f64 {
5529 state = state
5530 .wrapping_mul(6364136223846793005)
5531 .wrapping_add(1442695040888963407);
5532 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5533 };
5534
5535 let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
5538 for k in 0..n_atoms {
5539 let r = ranks[k];
5540 let mut u = Array2::<f64>::zeros((p, r));
5541 for i in 0..p {
5542 for j in 0..r {
5543 u[[i, j]] = if r == p && i == j {
5544 1.0
5545 } else if r == p {
5546 0.0
5547 } else {
5548 sample()
5549 };
5550 }
5551 }
5552 frames.push(u);
5553 }
5554 let w_of = |i: usize, j: usize| -> Array2<f64> {
5555 let (ui, uj) = (&frames[i], &frames[j]);
5556 let (ri, rj) = (ranks[i], ranks[j]);
5557 let mut w = Array2::<f64>::zeros((ri, rj));
5558 for a in 0..ri {
5559 for b in 0..rj {
5560 let mut s = 0.0;
5561 for c in 0..p {
5562 s += ui[[c, a]] * uj[[c, b]];
5563 }
5564 w[[a, b]] = s;
5565 }
5566 }
5567 w
5568 };
5569
5570 let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
5572 let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
5573 pairs.sort();
5574 for &(i, j) in &pairs {
5575 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5576 let mut g = Array2::<f64>::zeros((mi, mj));
5577 for r in 0..mi {
5578 for c in 0..mj {
5579 g[[r, c]] = 0.3 * sample();
5580 }
5581 }
5582 if i == j {
5584 for r in 0..mi.min(mj) {
5585 g[[r, r]] += mi as f64 + 2.0;
5586 }
5587 }
5588 frame_blocks.push(FactoredFrameGBlock {
5589 atom_i: i,
5590 atom_j: j,
5591 g,
5592 w: w_of(i, j),
5593 });
5594 }
5595
5596 let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
5598 let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
5599 for k in 0..n_atoms {
5600 let m = basis_sizes[k];
5601 let mut a = Array2::<f64>::zeros((m, m));
5602 for r in 0..m {
5603 for c in 0..m {
5604 a[[r, c]] = 0.2 * sample();
5605 }
5606 }
5607 let mut s = a.t().dot(&a);
5608 for r in 0..m {
5609 s[[r, r]] += 1.0;
5610 }
5611 smooth_blocks.push(DeviceSaeSmoothBlock {
5612 global_offset: border_offsets[k],
5613 factor_a: s,
5614 });
5615 smooth_ranks.push(ranks[k]);
5616 }
5617
5618 let n = 6usize;
5620 let q = 3usize;
5621 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5622 let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
5623 for i in 0..n {
5624 let mut a = Array2::<f64>::zeros((q, q));
5626 for r in 0..q {
5627 for c in 0..q {
5628 a[[r, c]] = sample();
5629 }
5630 }
5631 let mut htt = a.t().dot(&a);
5632 for r in 0..q {
5633 htt[[r, r]] += q as f64 + 1.0;
5634 }
5635 sys.rows[i].htt = htt;
5636 let mut slab = vec![0.0_f64; q * border_dim];
5637 for c in 0..q {
5638 for col in 0..border_dim {
5639 let v = 0.15 * sample();
5640 slab[c * border_dim + col] = v;
5641 sys.rows[i].htbeta[[c, col]] = v;
5642 }
5643 }
5644 row_htbeta.push(slab);
5645 }
5646
5647 let data_op =
5650 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5651 .expect("frame op");
5652 let mut hbb = data_op.to_dense();
5653 for k in 0..n_atoms {
5654 let op = IdentityRightKroneckerPenaltyOp {
5655 factor_a: smooth_blocks[k].factor_a.clone(),
5656 p: ranks[k],
5657 global_offset: border_offsets[k],
5658 k: border_dim,
5659 };
5660 let d = op.to_dense();
5661 for r in 0..border_dim {
5662 for c in 0..border_dim {
5663 hbb[[r, c]] += d[[r, c]];
5664 }
5665 }
5666 }
5667 sys.hbb = hbb;
5668
5669 let data = DeviceSaePcgData {
5670 p,
5671 beta_dim: border_dim,
5672 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5673 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5674 smooth_blocks,
5675 sparse_g_blocks: Vec::new(),
5676 frame: Some(DeviceSaeFrameData {
5677 ranks: ranks.clone(),
5678 basis_sizes: basis_sizes.clone(),
5679 border_offsets: border_offsets.clone(),
5680 frame_blocks,
5681 smooth_ranks,
5682 row_htbeta,
5683 }),
5684 };
5685
5686 let ridge_t = 1e-7;
5687 let ridge_beta = 1e-6;
5688
5689 let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
5693 for r in 0..border_dim {
5694 for c in 0..border_dim {
5695 s_dense[[r, c]] = sys.hbb[[r, c]];
5696 }
5697 s_dense[[r, r]] += ridge_beta;
5698 }
5699 for row in &sys.rows {
5700 let mut htt = row.htt.clone();
5701 for d in 0..q {
5702 htt[[d, d]] += ridge_t;
5703 }
5704 let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
5705 .expect("htt PD");
5706 let mut y = Array2::<f64>::zeros((q, border_dim));
5708 for col in 0..border_dim {
5709 let mut e = Array1::<f64>::zeros(q);
5710 for r in 0..q {
5711 e[r] = row.htbeta[[r, col]];
5712 }
5713 let solved = cholesky_solve_vector(factor.view(), e.view());
5714 for r in 0..q {
5715 y[[r, col]] = solved[r];
5716 }
5717 }
5718 for r in 0..border_dim {
5719 for c in 0..border_dim {
5720 let mut acc = 0.0;
5721 for d in 0..q {
5722 acc += row.htbeta[[d, r]] * y[[d, c]];
5723 }
5724 s_dense[[r, c]] -= acc;
5725 }
5726 }
5727 }
5728
5729 let mut max_rel = 0.0_f64;
5731 for trial in 0..4 {
5732 let x: Vec<f64> = (0..border_dim)
5733 .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
5734 .collect();
5735 let mut got = vec![0.0_f64; border_dim];
5736 sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
5737 .expect("framed matvec");
5738 let mut want = vec![0.0_f64; border_dim];
5739 for r in 0..border_dim {
5740 let mut acc = 0.0;
5741 for c in 0..border_dim {
5742 acc += s_dense[[r, c]] * x[c];
5743 }
5744 want[r] = acc;
5745 }
5746 let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
5747 for a in 0..border_dim {
5748 let rel = (got[a] - want[a]).abs() / scale;
5749 max_rel = max_rel.max(rel);
5750 }
5751 }
5752 assert!(
5753 max_rel <= 1e-10,
5754 "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
5755 );
5756 }
5757
5758 #[test]
5764 fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
5765 use crate::arrow_schur::{
5766 BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
5767 FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
5768 };
5769
5770 let p = 6usize;
5774 let n_atoms = 8usize;
5775 let ranks: Vec<usize> = (0..n_atoms)
5776 .map(|k| if k % 2 == 0 { 3usize } else { p })
5777 .collect();
5778 let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
5779 let mut border_offsets = Vec::with_capacity(n_atoms);
5780 let mut acc = 0usize;
5781 for k in 0..n_atoms {
5782 border_offsets.push(acc);
5783 acc += basis_sizes[k] * ranks[k];
5784 }
5785 let border_dim = acc; let mut state = 0xfeed_face_dead_beefu64;
5788 let mut sample = || -> f64 {
5789 state = state
5790 .wrapping_mul(6364136223846793005)
5791 .wrapping_add(1442695040888963407);
5792 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5793 };
5794 let mut frames: Vec<Array2<f64>> = Vec::new();
5795 for k in 0..n_atoms {
5796 let r = ranks[k];
5797 let mut u = Array2::<f64>::zeros((p, r));
5798 for i in 0..p {
5799 for j in 0..r {
5800 u[[i, j]] = if r == p && i == j {
5801 1.0
5802 } else if r == p {
5803 0.0
5804 } else {
5805 sample()
5806 };
5807 }
5808 }
5809 frames.push(u);
5810 }
5811 let w_of = |i: usize, j: usize| {
5812 let (ui, uj) = (&frames[i], &frames[j]);
5813 let (ri, rj) = (ranks[i], ranks[j]);
5814 let mut w = Array2::<f64>::zeros((ri, rj));
5815 for a in 0..ri {
5816 for b in 0..rj {
5817 let mut s = 0.0;
5818 for c in 0..p {
5819 s += ui[[c, a]] * uj[[c, b]];
5820 }
5821 w[[a, b]] = s;
5822 }
5823 }
5824 w
5825 };
5826 let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
5827 for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
5829 pairs.push((i, j));
5830 pairs.push((j, i));
5831 }
5832 let mut frame_blocks = Vec::new();
5833 for &(i, j) in &pairs {
5834 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
5835 let mut g = Array2::<f64>::zeros((mi, mj));
5836 for r in 0..mi {
5837 for c in 0..mj {
5838 g[[r, c]] = 0.25 * sample();
5839 }
5840 }
5841 if i == j {
5842 for r in 0..mi.min(mj) {
5843 g[[r, r]] += mi as f64 + 2.0;
5844 }
5845 }
5846 frame_blocks.push(FactoredFrameGBlock {
5847 atom_i: i,
5848 atom_j: j,
5849 g,
5850 w: w_of(i, j),
5851 });
5852 }
5853 let mut smooth_blocks = Vec::new();
5854 let mut smooth_ranks = Vec::new();
5855 for k in 0..n_atoms {
5856 let m = basis_sizes[k];
5857 let mut a = Array2::<f64>::zeros((m, m));
5858 for r in 0..m {
5859 for c in 0..m {
5860 a[[r, c]] = 0.2 * sample();
5861 }
5862 }
5863 let mut s = a.t().dot(&a);
5864 for r in 0..m {
5865 s[[r, r]] += 1.0;
5866 }
5867 smooth_blocks.push(DeviceSaeSmoothBlock {
5868 global_offset: border_offsets[k],
5869 factor_a: s,
5870 });
5871 smooth_ranks.push(ranks[k]);
5872 }
5873 let n = 400usize;
5874 let q = 4usize;
5875 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
5876 let mut row_htbeta = Vec::new();
5877 for i in 0..n {
5878 let mut a = Array2::<f64>::zeros((q, q));
5879 for r in 0..q {
5880 for c in 0..q {
5881 a[[r, c]] = sample();
5882 }
5883 }
5884 let mut htt = a.t().dot(&a);
5885 for r in 0..q {
5886 htt[[r, r]] += q as f64 + 1.0;
5887 }
5888 sys.rows[i].htt = htt;
5889 let mut slab = vec![0.0_f64; q * border_dim];
5890 for c in 0..q {
5891 for col in 0..border_dim {
5892 let v = 0.02 * sample();
5895 slab[c * border_dim + col] = v;
5896 sys.rows[i].htbeta[[c, col]] = v;
5897 }
5898 }
5899 row_htbeta.push(slab);
5900 }
5901 let data_op =
5902 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
5903 .expect("frame op");
5904 let mut hbb = data_op.to_dense();
5905 for k in 0..n_atoms {
5906 let op = IdentityRightKroneckerPenaltyOp {
5907 factor_a: smooth_blocks[k].factor_a.clone(),
5908 p: ranks[k],
5909 global_offset: border_offsets[k],
5910 k: border_dim,
5911 };
5912 let d = op.to_dense();
5913 for r in 0..border_dim {
5914 for c in 0..border_dim {
5915 hbb[[r, c]] += d[[r, c]];
5916 }
5917 }
5918 }
5919 sys.hbb = hbb;
5920 let data = DeviceSaePcgData {
5921 p,
5922 beta_dim: border_dim,
5923 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5924 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
5925 smooth_blocks,
5926 sparse_g_blocks: Vec::new(),
5927 frame: Some(DeviceSaeFrameData {
5928 ranks: ranks.clone(),
5929 basis_sizes: basis_sizes.clone(),
5930 border_offsets: border_offsets.clone(),
5931 frame_blocks,
5932 smooth_ranks,
5933 row_htbeta,
5934 }),
5935 };
5936 let ridge_t = 1e-7;
5937 let ridge_beta = 1e-6;
5938 let rhs: Array1<f64> =
5939 Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
5940
5941 let (device, diag) =
5942 match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
5943 Ok(result) => result,
5944 Err(failure) => {
5950 assert!(
5951 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5952 "#1017: CUDA device present but the framed device SAE PCG \
5953 declined/faulted instead of returning a result (tag: {failure:?}) — \
5954 the kernel does not run correctly on GPU"
5955 );
5956 return;
5957 }
5958 };
5959
5960 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5978 let oracle_resid = |x: &Array1<f64>| -> f64 {
5979 let mut sx = vec![0.0_f64; border_dim];
5980 sae_framed_schur_matvec_cpu(
5981 &sys,
5982 &data,
5983 ridge_t,
5984 ridge_beta,
5985 x.as_slice().unwrap(),
5986 &mut sx,
5987 )
5988 .expect("cpu oracle matvec");
5989 let mut acc = 0.0_f64;
5990 for a in 0..border_dim {
5991 let e = sx[a] - rhs[a];
5992 acc += e * e;
5993 }
5994 acc.sqrt()
5995 };
5996 let s_dev_resid = oracle_resid(&device);
5997 let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
5998
5999 let precond = {
6004 let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6005 let mut diag = d;
6008 for (i, row) in sys.rows.iter().enumerate() {
6009 let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6010 let qi = sys.row_dims[i];
6011 if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6012 continue;
6013 }
6014 let mut block = row.htt.clone();
6015 for dd in 0..qi {
6016 block[[dd, dd]] += ridge_t;
6017 }
6018 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6019 .expect("row htt PD");
6020 let mut ainv = Array2::<f64>::zeros((qi, qi));
6022 for col in 0..qi {
6023 let mut e = Array1::<f64>::zeros(qi);
6024 e[col] = 1.0;
6025 let s = cholesky_solve_vector(factor.view(), e.view());
6026 for r in 0..qi {
6027 ainv[[r, col]] = s[r];
6028 }
6029 }
6030 for a in 0..border_dim {
6031 let mut quad = 0.0_f64;
6032 for c in 0..qi {
6033 let hc = slab[c * border_dim + a];
6034 for dd in 0..qi {
6035 quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6036 }
6037 }
6038 diag[a] -= quad;
6039 }
6040 }
6041 Array1::from_vec(diag)
6042 };
6043 let mut cpu = Array1::<f64>::zeros(border_dim);
6044 let cpu_result = {
6045 let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6046 let mut tmp = vec![0.0_f64; border_dim];
6047 sae_framed_schur_matvec_cpu(
6048 &sys,
6049 &data,
6050 ridge_t,
6051 ridge_beta,
6052 v.as_slice().unwrap(),
6053 &mut tmp,
6054 )
6055 .expect("cpu oracle matvec");
6056 out.assign(&Array1::from_vec(tmp));
6057 };
6058 gam_linalg::pcg::pcg_core(
6059 &mut apply,
6060 &rhs.view(),
6061 &precond.view(),
6062 1e-12,
6063 800,
6064 32,
6065 false,
6066 gam_linalg::pcg::DotReduction::Serial,
6067 &mut cpu.view_mut(),
6068 )
6069 };
6070 let s_cpu_resid = oracle_resid(&cpu);
6071 let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6072
6073 assert!(
6076 dev_rel_resid <= 1e-7,
6077 "[#1551] device δβ does not solve the CPU-oracle system: \
6078 ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6079 device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6080 means the device matvec is a DIFFERENT operator (kernel bug)",
6081 diag.stopping_reason,
6082 diag.iterations,
6083 diag.final_relative_residual,
6084 );
6085 assert!(
6088 cpu_rel_resid <= 1e-6,
6089 "[#1551] CPU pcg_core failed to solve the oracle system: \
6090 ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6091 cpu_result.stop,
6092 cpu_result.iterations,
6093 );
6094 }
6095
6096 fn sae_frame_penalty_diag_host_for_test(
6100 data: &DeviceSaePcgData,
6101 ridge_beta: f64,
6102 ) -> Vec<f64> {
6103 let frame = data.frame.as_ref().expect("frame");
6104 let mut diag = vec![ridge_beta; data.beta_dim];
6105 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6106 let m = blk.factor_a.nrows();
6107 for ia in 0..m {
6108 let coeff = blk.factor_a[[ia, ia]];
6109 let base = blk.global_offset + ia * r;
6110 for ib in 0..r {
6111 diag[base + ib] += coeff;
6112 }
6113 }
6114 }
6115 for blk in &frame.frame_blocks {
6116 if blk.atom_i != blk.atom_j {
6117 continue;
6118 }
6119 let r = frame.ranks[blk.atom_i];
6120 let off = frame.border_offsets[blk.atom_i];
6121 let (mi, mj) = blk.g.dim();
6122 for li in 0..mi.min(mj) {
6123 let gii = blk.g[[li, li]];
6124 let base = off + li * r;
6125 for a in 0..r {
6126 diag[base + a] += gii * blk.w[[a, a]];
6127 }
6128 }
6129 }
6130 diag
6131 }
6132
/// #1551 matvec parity: the framed SAE Schur operator `S·x` computed by
/// `framed_schur_matvec_once_on_device` must match the CPU oracle
/// `sae_framed_schur_matvec_cpu` component-wise, relative to
/// `max(1, max_a |cpu[a]|)`, to within 1e-9 on every probe.
///
/// Fixture:
/// * 8 atoms with 3 basis functions each; even atoms have rank 3 with a
///   random `p × 3` frame, odd atoms have full rank `p` with an identity frame;
/// * diagonal frame G-blocks (diagonally shifted by `mi + 2`) plus the
///   symmetric cross-atom couplings (0,1), (2,4), (3,6), with `W = U_iᵀ U_j`;
/// * `SᵀS + I` smooth blocks per atom;
/// * 32 rows with `htt = AᵀA + (q+1)·I` and dense per-row `H_tβ` slabs.
///
/// If the device call declines, the test only passes (by returning early)
/// when no CUDA runtime is registered. A decline with a runtime present is a
/// failure.
#[test]
fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
    use crate::arrow_schur::{
        DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
    };

    // Border layout: atom k owns `basis_sizes[k] * ranks[k]` consecutive β
    // coordinates starting at `border_offsets[k]`.
    let p = 6usize;
    let n_atoms = 8usize;
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|k| if k % 2 == 0 { 3usize } else { p })
        .collect();
    let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
    let mut border_offsets = Vec::with_capacity(n_atoms);
    let mut acc = 0usize;
    for k in 0..n_atoms {
        border_offsets.push(acc);
        acc += basis_sizes[k] * ranks[k];
    }
    let border_dim = acc;

    // Deterministic 64-bit LCG (Knuth's MMIX multiplier/increment). The top 31
    // bits of the state lie in [0, 2^31); dividing by 2^30 maps them onto
    // [0, 2), and subtracting 1 gives a symmetric sample in [-1, 1). The
    // fixture therefore mixes signs and exercises cancellation in the matvec.
    let mut state = 0x1551_0017_1026_0922u64;
    let mut sample = || -> f64 {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 33) as f64) / ((1u64 << 30) as f64) - 1.0
    };

    // Per-atom frames U_k (p × r_k): identity for full-rank atoms, random
    // otherwise.
    let mut frames: Vec<Array2<f64>> = Vec::new();
    for k in 0..n_atoms {
        let r = ranks[k];
        let mut u = Array2::<f64>::zeros((p, r));
        for i in 0..p {
            for j in 0..r {
                u[[i, j]] = if r == p && i == j {
                    1.0
                } else if r == p {
                    0.0
                } else {
                    sample()
                };
            }
        }
        frames.push(u);
    }

    // W_ij = U_iᵀ U_j (r_i × r_j).
    let w_of = |i: usize, j: usize| {
        let (ui, uj) = (&frames[i], &frames[j]);
        let (ri, rj) = (ranks[i], ranks[j]);
        let mut w = Array2::<f64>::zeros((ri, rj));
        for a in 0..ri {
            for b in 0..rj {
                let mut s = 0.0;
                for c in 0..p {
                    s += ui[[c, a]] * uj[[c, b]];
                }
                w[[a, b]] = s;
            }
        }
        w
    };

    // Every atom's diagonal block plus both orientations of a few couplings.
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
        pairs.push((i, j));
        pairs.push((j, i));
    }
    let mut frame_blocks = Vec::new();
    for &(i, j) in &pairs {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g = Array2::<f64>::zeros((mi, mj));
        for r in 0..mi {
            for c in 0..mj {
                g[[r, c]] = 0.25 * sample();
            }
        }
        if i == j {
            // Diagonal shift keeps same-atom G blocks diagonally dominant.
            for r in 0..mi.min(mj) {
                g[[r, r]] += mi as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: w_of(i, j),
        });
    }

    // Smooth penalty per atom: AᵀA + I (SPD).
    let mut smooth_blocks = Vec::new();
    let mut smooth_ranks = Vec::new();
    for k in 0..n_atoms {
        let m = basis_sizes[k];
        let mut a = Array2::<f64>::zeros((m, m));
        for r in 0..m {
            for c in 0..m {
                a[[r, c]] = 0.2 * sample();
            }
        }
        let mut s = a.t().dot(&a);
        for r in 0..m {
            s[[r, r]] += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: s,
        });
        smooth_ranks.push(ranks[k]);
    }

    // Arrow rows: SPD htt blocks and dense H_tβ slabs. Each slab is mirrored
    // into `sys.rows[i].htbeta` and the row-major `row_htbeta` buffer.
    let n = 32usize;
    let q = 4usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::new();
    for i in 0..n {
        let mut a = Array2::<f64>::zeros((q, q));
        for r in 0..q {
            for c in 0..q {
                a[[r, c]] = sample();
            }
        }
        let mut htt = a.t().dot(&a);
        for r in 0..q {
            htt[[r, r]] += q as f64 + 1.0;
        }
        sys.rows[i].htt = htt;
        let mut slab = vec![0.0_f64; q * border_dim];
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.3 * sample();
                slab[c * border_dim + col] = v;
                sys.rows[i].htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }
    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;
    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };

    // Probes: a smooth oscillation, a random vector, and three unit axes
    // (first, roughly one third in, last).
    let mut probes: Vec<Array1<f64>> = Vec::new();
    probes.push(Array1::from_shape_fn(border_dim, |a| {
        ((a as f64 + 1.0) * 0.37).sin()
    }));
    probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
    for axis in [0usize, border_dim / 3, border_dim - 1] {
        let mut e = Array1::<f64>::zeros(border_dim);
        e[axis] = 1.0;
        probes.push(e);
    }

    let mut any_ran = false;
    let mut worst = 0.0_f64;
    for (pi, x) in probes.iter().enumerate() {
        let device = match super::framed_schur_matvec_once_on_device(
            &sys, &data, ridge_t, ridge_beta, x,
        ) {
            Ok(out) => out,
            Err(failure) => {
                // A decline is only acceptable when no CUDA runtime exists.
                assert!(
                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                    "#1551: CUDA device present but the framed device matvec \
                     declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
                     does not run on GPU"
                );
                return;
            }
        };
        any_ran = true;
        let mut cpu = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(
            &sys,
            &data,
            ridge_t,
            ridge_beta,
            x.as_slice().unwrap(),
            &mut cpu,
        )
        .expect("cpu oracle matvec");
        // Compare relative to the oracle's largest component (at least 1) so
        // near-zero components are not judged by a vanishing denominator.
        let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
        for a in 0..border_dim {
            let rel = (device[a] - cpu[a]).abs() / scale;
            worst = worst.max(rel);
            assert!(
                rel <= 1e-9,
                "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
                 rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
                device[a],
                cpu[a],
            );
        }
    }
    if any_ran {
        assert!(
            gam_gpu::device_runtime::GpuRuntime::global().is_some(),
            "#1551: matvec ran but no GPU runtime — unexpected"
        );
        assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
    }
}
6361}