1use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
/// Result of one arrow-Schur Newton solve.
pub struct ArrowSchurGpuSolution {
    /// Step for the per-row blocks, stored row block after row block
    /// (the solvers in this file fill `n * d` entries).
    pub delta_t: Array1<f64>,
    /// Step for the shared block (`k` entries in the solvers in this file).
    pub delta_beta: Array1<f64>,
    /// Log-determinant of the ridged system, accumulated by the solvers in
    /// this file as twice the sum of the logs of the Cholesky pivots.
    pub log_det_hessian: f64,
}
32
/// Reasons the GPU arrow-Schur entry points in this file decline or fail.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// The GPU path cannot be used here (non-Linux build, empty system,
    /// or a device/runtime call that did not succeed).
    Unavailable,
    /// The Cholesky factorisation of per-row block `row` failed; `bump` is
    /// the ridge increment estimated by the Gershgorin-based helpers.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// The reduced (Schur) system could not be factored, or the inputs were
    /// rejected before factorisation; `reason` carries the message.
    SchurFactorFailed { reason: String },
    /// The system carries matrix-free operators, but the entry point that
    /// returned this error needs dense blocks.
    GpuRequiresDenseSystem {
        /// Whether `hbb_matvec` was installed on the system.
        had_hbb_matvec: bool,
        /// Whether `htbeta_matvec` was installed on the system.
        had_htbeta_matvec: bool,
    },
}
56
/// Multiplier applied to `scale * sqrt(f64::EPSILON)` to form the safety
/// margin added on top of the Gershgorin deficit in the ridge-bump helpers.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
68
69#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106 let d = htt.nrows();
107 let mut scale = 1.0_f64;
110 let mut min_gershgorin_edge = f64::INFINITY;
111 for i in 0..d {
112 let diag = htt[[i, i]];
113 scale = scale.max(diag.abs());
114 let mut off_sum = 0.0_f64;
115 for j in 0..d {
116 if j != i {
117 off_sum += htt[[i, j]].abs();
118 }
119 }
120 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121 }
122 if !min_gershgorin_edge.is_finite() {
123 return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126 }
127 let deficit = -(min_gershgorin_edge + ridge_t);
130 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131 deficit.max(0.0) + margin
134}
135
136#[cfg(target_os = "linux")]
149#[must_use]
150fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
151 if d == 0 || block.len() < d * d {
152 return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
153 }
154 let mut scale = 1.0_f64;
157 let mut min_gershgorin_edge = f64::INFINITY;
158 for i in 0..d {
159 let diag = block[i * d + i];
160 scale = scale.max(diag.abs());
161 let mut off_sum = 0.0_f64;
162 for j in 0..d {
163 if j != i {
164 off_sum += block[j * d + i].abs();
165 }
166 }
167 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
168 }
169 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
170 if !min_gershgorin_edge.is_finite() {
171 return margin;
172 }
173 (-min_gershgorin_edge).max(0.0) + margin
174}
175
176pub fn solve_arrow_newton_step(
180 sys: &ArrowSchurSystem,
181 ridge_t: f64,
182 ridge_beta: f64,
183) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
184 let n = sys.rows.len();
185 let d = sys.d;
186 let k = sys.k;
187
188 let had_hbb_matvec = sys.hbb_matvec.is_some();
193 let had_htbeta_matvec = sys.htbeta_matvec.is_some();
194 if had_hbb_matvec || had_htbeta_matvec {
195 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
196 had_hbb_matvec,
197 had_htbeta_matvec,
198 });
199 }
200
201 if sys.hbb.dim() != (k, k) {
202 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
203 reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
204 });
205 }
206 if n == 0 || d == 0 {
207 return Err(ArrowSchurGpuFailure::Unavailable);
208 }
209 if sys
210 .rows
211 .iter()
212 .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
213 {
214 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
215 reason: "row block dimension mismatch".to_string(),
216 });
217 }
218
219 #[cfg(not(target_os = "linux"))]
220 {
221 if ridge_t.is_nan() || ridge_beta.is_nan() {
222 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
223 reason: "ridge is NaN".to_string(),
224 });
225 }
226 Err(ArrowSchurGpuFailure::Unavailable)
227 }
228
229 #[cfg(target_os = "linux")]
230 {
231 if gam_gpu::device_runtime::GpuRuntime::global()
240 .map(gam_gpu::device_runtime::GpuRuntime::device_count)
241 .unwrap_or(0)
242 > 1
243 {
244 match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
245 Ok(sol) => return Ok(sol),
246 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
247 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
248 }
249 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
250 return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
251 }
252 Err(_) => {}
255 }
256 }
257 if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
263 match cuda::solve_fused(sys, ridge_t, ridge_beta) {
264 Ok(sol) => return Ok(sol),
265 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
269 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
270 }
271 Err(_) => {}
275 }
276 }
277 cuda::solve(sys, ridge_t, ridge_beta)
278 }
279}
280
281#[cfg(target_os = "linux")]
287fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
288 let n = sys.rows.len();
289 let d = sys.d;
290 let k = sys.k;
291 let mut d_buf = Vec::with_capacity(n * d * d);
292 let mut b_buf = Vec::with_capacity(n * d * k);
293 let mut g_buf = Vec::with_capacity(n * d);
294 for row in &sys.rows {
295 pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
296 }
297 (d_buf, b_buf, g_buf)
298}
299
300#[cfg(target_os = "linux")]
301#[inline]
302fn pack_block(
303 row: &crate::arrow_schur::ArrowRowBlock,
304 ridge_t: f64,
305 d: usize,
306 k: usize,
307 d_buf: &mut Vec<f64>,
308 b_buf: &mut Vec<f64>,
309 g_buf: &mut Vec<f64>,
310) {
311 for col in 0..d {
312 for r in 0..d {
313 let mut value = row.htt[[r, col]];
314 if r == col {
315 value += ridge_t;
316 }
317 d_buf.push(value);
318 }
319 }
320 for col in 0..k {
321 for r in 0..d {
322 b_buf.push(row.htbeta[[r, col]]);
323 }
324 }
325 for r in 0..d {
326 g_buf.push(row.gt[r]);
327 }
328}
329
330#[doc(hidden)]
335#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] pub fn solve_arrow_newton_step_fused_force(
337 sys: &ArrowSchurSystem,
338 ridge_t: f64,
339 ridge_beta: f64,
340) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
341 if ridge_t.is_nan() || ridge_beta.is_nan() {
342 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
343 reason: "ridge is NaN".to_string(),
344 });
345 }
346 #[cfg(not(target_os = "linux"))]
347 {
348 Err(ArrowSchurGpuFailure::Unavailable)
353 }
354 #[cfg(target_os = "linux")]
355 {
356 if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
357 .is_none()
358 {
359 return Err(ArrowSchurGpuFailure::Unavailable);
360 }
361 cuda::solve_fused(sys, ridge_t, ridge_beta)
362 }
363}
364
/// Handle to an arrow-Schur frame built by `cuda::ResidentArrowFrame`.
///
/// On Linux it wraps the CUDA-side frame. On every other target the only
/// field is `Infallible`, so no value of this type can be constructed.
pub struct ResidentArrowFrameHandle {
    /// The CUDA-side frame (Linux builds only).
    #[cfg(target_os = "linux")]
    inner: cuda::ResidentArrowFrame,
    /// Uninhabited placeholder that makes the handle unconstructible off Linux.
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
380
381impl ResidentArrowFrameHandle {
382 pub fn new(
384 sys: &ArrowSchurSystem,
385 ridge_t: f64,
386 ridge_beta: f64,
387 ) -> Result<Self, ArrowSchurGpuFailure> {
388 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
391 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
392 had_hbb_matvec: sys.hbb_matvec.is_some(),
393 had_htbeta_matvec: sys.htbeta_matvec.is_some(),
394 });
395 }
396 #[cfg(not(target_os = "linux"))]
397 {
398 if ridge_t.is_nan() || ridge_beta.is_nan() {
399 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
400 reason: "ridge is NaN".to_string(),
401 });
402 }
403 Err(ArrowSchurGpuFailure::Unavailable)
404 }
405 #[cfg(target_os = "linux")]
406 {
407 Ok(Self {
408 inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
409 })
410 }
411 }
412
413 pub fn solve_gradient(
415 &self,
416 g_t: &[f64],
417 g_beta: &[f64],
418 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
419 #[cfg(not(target_os = "linux"))]
420 {
421 if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
422 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
423 reason: "non-finite gradient entry".to_string(),
424 });
425 }
426 Err(ArrowSchurGpuFailure::Unavailable)
427 }
428 #[cfg(target_os = "linux")]
429 {
430 self.inner.solve_gradient(g_t, g_beta)
431 }
432 }
433
434 #[must_use]
436 pub fn log_det_hessian(&self) -> f64 {
437 #[cfg(not(target_os = "linux"))]
438 {
439 panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
445 }
446 #[cfg(target_os = "linux")]
447 {
448 self.inner.log_det_hessian()
449 }
450 }
451}
452
453pub fn gpu_schur_matvec_backend(
496 sys: &ArrowSchurSystem,
497 ridge_t: f64,
498 ridge_beta: f64,
499) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
500 if sys.htbeta_matvec.is_some() {
503 return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
504 }
505
506 #[cfg(not(target_os = "linux"))]
507 {
508 if ridge_t.is_nan() || ridge_beta.is_nan() {
511 return Err(ArrowSchurGpuFailure::Unavailable);
512 }
513 Err(ArrowSchurGpuFailure::Unavailable)
514 }
515
516 #[cfg(target_os = "linux")]
517 {
518 cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
519 }
520}
521
522fn build_row_procedural_matvec(
539 sys: &ArrowSchurSystem,
540 ridge_t: f64,
541 ridge_beta: f64,
542) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
543 use std::sync::Arc;
544 let n = sys.rows.len();
545 let k = sys.k;
546 let forward = sys
547 .htbeta_matvec
548 .clone()
549 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
550 let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
551 ArrowSchurGpuFailure::SchurFactorFailed {
556 reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
557 forward operator installed without its sparse adjoint"
558 .to_string(),
559 }
560 })?;
561
562 let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
567 for (i, row) in sys.rows.iter().enumerate() {
568 let di = row.htt.nrows();
569 if row.htt.ncols() != di || row.gt.len() != di {
570 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
571 reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
572 });
573 }
574 let mut block = row.htt.clone();
575 for r in 0..di {
576 block[[r, r]] += ridge_t;
577 }
578 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
579 .ok_or_else(|| {
580 ArrowSchurGpuFailure::RidgeBumpRequired {
584 row: i,
585 bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
586 }
587 })?;
588 factors.push(factor);
589 }
590
591 let penalty_op = sys.effective_penalty_op();
598 let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();
599
600 let closure: crate::arrow_schur::GpuSchurMatvec =
601 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
602 assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
603 assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");
604
605 {
608 let x_slice = x.as_slice().expect("x must be contiguous");
609 let out_slice = out.as_slice_mut().expect("out must be contiguous");
610 for a in 0..k {
611 out_slice[a] = ridge_beta * x_slice[a];
612 }
613 penalty_op.matvec(x_slice, out_slice);
614 }
615
616 let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
643 && rayon::current_thread_index().is_none();
644 if parallel {
645 use rayon::prelude::*;
646 const CHUNK: usize = 64;
647 let partials: Vec<Array1<f64>> = (0..n)
648 .into_par_iter()
649 .chunks(CHUNK)
650 .map(|idxs| {
651 let mut neg = Array1::<f64>::zeros(k);
655 for i in idxs {
656 let di = row_dims[i];
657 let mut v_i = Array1::<f64>::zeros(di);
659 forward(i, x.view(), &mut v_i);
660 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
662 transpose(i, w_i.view(), &mut neg);
664 }
665 neg
666 })
667 .collect();
668 let mut neg = Array1::<f64>::zeros(k);
677 for part in &partials {
678 for a in 0..k {
679 neg[a] += part[a];
680 }
681 }
682 for a in 0..k {
683 out[a] -= neg[a];
684 }
685 } else {
686 let mut neg = Array1::<f64>::zeros(k);
688 for i in 0..n {
689 let di = row_dims[i];
690 let mut v_i = Array1::<f64>::zeros(di);
692 forward(i, x.view(), &mut v_i);
693 let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
695 transpose(i, w_i.view(), &mut neg);
697 }
698 for a in 0..k {
699 out[a] -= neg[a];
700 }
701 }
702 });
703
704 Ok(closure)
705}
706
707pub fn solve_reduced_beta_pcg(
729 s_acc: &Array2<f64>,
730 rhs_beta: &Array1<f64>,
731 max_iterations: usize,
732 relative_tolerance: f64,
733) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
734 solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
735 .map(|(x, _)| x)
736}
737
738#[doc(hidden)]
739pub fn solve_reduced_beta_pcg_with_diagnostics(
740 s_acc: &Array2<f64>,
741 rhs_beta: &Array1<f64>,
742 max_iterations: usize,
743 relative_tolerance: f64,
744) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
745 let k = rhs_beta.len();
746 if s_acc.dim() != (k, k) {
747 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
748 reason: format!(
749 "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
750 s_acc.dim()
751 ),
752 });
753 }
754 if k == 0 {
755 return Err(ArrowSchurGpuFailure::Unavailable);
756 }
757
758 #[cfg(not(target_os = "linux"))]
759 {
760 if relative_tolerance.is_nan() || max_iterations == 0 {
761 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
762 reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
763 });
764 }
765 Err(ArrowSchurGpuFailure::Unavailable)
766 }
767
768 #[cfg(target_os = "linux")]
769 {
770 cuda::solve_reduced_beta_pcg_with_diagnostics(
771 s_acc,
772 rhs_beta,
773 max_iterations,
774 relative_tolerance,
775 )
776 }
777}
778
779pub fn solve_sae_matrix_free_pcg(
780 sys: &ArrowSchurSystem,
781 data: &DeviceSaePcgData,
782 ridge_t: f64,
783 ridge_beta: f64,
784 rhs_beta: &Array1<f64>,
785 max_iterations: usize,
786 relative_tolerance: f64,
787) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
788 if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
789 return Err(ArrowSchurGpuFailure::Unavailable);
790 }
791 #[cfg(not(target_os = "linux"))]
792 {
793 if ridge_t.is_nan()
794 || ridge_beta.is_nan()
795 || relative_tolerance.is_nan()
796 || max_iterations == 0
797 {
798 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
799 reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
800 });
801 }
802 Err(ArrowSchurGpuFailure::Unavailable)
803 }
804 #[cfg(target_os = "linux")]
805 {
806 if data.frame.is_some() {
813 cuda::solve_sae_matrix_free_pcg_framed(
814 sys,
815 data,
816 ridge_t,
817 ridge_beta,
818 rhs_beta,
819 max_iterations,
820 relative_tolerance,
821 )
822 } else {
823 cuda::solve_sae_matrix_free_pcg(
824 sys,
825 data,
826 ridge_t,
827 ridge_beta,
828 rhs_beta,
829 max_iterations,
830 relative_tolerance,
831 )
832 }
833 }
834}
835
836#[doc(hidden)]
845pub fn framed_schur_matvec_once_on_device(
846 sys: &ArrowSchurSystem,
847 data: &DeviceSaePcgData,
848 ridge_t: f64,
849 ridge_beta: f64,
850 x: &Array1<f64>,
851) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
852 if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
853 return Err(ArrowSchurGpuFailure::Unavailable);
854 }
855 if data.frame.is_none() {
856 return Err(ArrowSchurGpuFailure::Unavailable);
857 }
858 #[cfg(not(target_os = "linux"))]
859 {
860 if ridge_t.is_finite() && ridge_beta.is_finite() {
865 return Err(ArrowSchurGpuFailure::Unavailable);
866 }
867 Err(ArrowSchurGpuFailure::Unavailable)
868 }
869 #[cfg(target_os = "linux")]
870 {
871 cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
872 }
873}
874
875#[doc(hidden)]
879pub fn solve_arrow_newton_step_dense_reference(
880 sys: &ArrowSchurSystem,
881 ridge_t: f64,
882 ridge_beta: f64,
883) -> Result<ArrowSchurGpuSolution, String> {
884 let n = sys.rows.len();
885 let d = sys.d;
886 let k = sys.k;
887 let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
888 let mut h = Array2::<f64>::zeros((total, total));
889 let mut rhs = Array1::<f64>::zeros(total);
890 for (i, row) in sys.rows.iter().enumerate() {
891 let base = i * d;
892 for c in 0..d {
893 for r in 0..d {
894 h[[base + r, base + c]] = row.htt[[r, c]];
895 }
896 h[[base + c, base + c]] += ridge_t;
897 }
898 for c in 0..k {
899 for r in 0..d {
900 let value = row.htbeta[[r, c]];
901 h[[base + r, n * d + c]] = value;
902 h[[n * d + c, base + r]] = value;
903 }
904 }
905 for r in 0..d {
906 rhs[base + r] = -row.gt[r];
907 }
908 }
909 for c in 0..k {
910 for r in 0..k {
911 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
912 }
913 h[[n * d + c, n * d + c]] += ridge_beta;
914 rhs[n * d + c] = -sys.gb[c];
915 }
916 let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
917 .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
918 let mut log_det = 0.0_f64;
919 for i in 0..total {
920 log_det += factor[[i, i]].ln();
921 }
922 log_det *= 2.0;
923 let solved = cholesky_solve_vector(factor.view(), rhs.view());
924 let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
925 let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
926 Ok(ArrowSchurGpuSolution {
927 delta_t,
928 delta_beta,
929 log_det_hessian: log_det,
930 })
931}
932
933#[doc(hidden)]
944pub fn sae_framed_penalty_matvec_cpu(
945 data: &DeviceSaePcgData,
946 ridge_beta: f64,
947 x: &[f64],
948 out: &mut [f64],
949) {
950 let frame = data
951 .frame
952 .as_ref()
953 .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
954 let k = data.beta_dim;
955 for a in 0..k {
956 out[a] = ridge_beta * x[a];
957 }
958 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
960 let off = blk.global_offset;
961 let m = blk.factor_a.nrows();
962 for i_a in 0..m {
963 for i_b in 0..r {
964 let mut acc = 0.0_f64;
965 for j_a in 0..m {
966 let s = blk.factor_a[[i_a, j_a]];
967 if s == 0.0 {
968 continue;
969 }
970 acc += s * x[off + j_a * r + i_b];
971 }
972 out[off + i_a * r + i_b] += acc;
973 }
974 }
975 }
976 for blk in &frame.frame_blocks {
978 let r_i = frame.ranks[blk.atom_i];
979 let r_j = frame.ranks[blk.atom_j];
980 let off_i = frame.border_offsets[blk.atom_i];
981 let off_j = frame.border_offsets[blk.atom_j];
982 let (m_i, m_j) = blk.g.dim();
983 for li in 0..m_i {
984 let yi_base = off_i + li * r_i;
985 for lj in 0..m_j {
986 let g = blk.g[[li, lj]];
987 if g == 0.0 {
988 continue;
989 }
990 let xj_base = off_j + lj * r_j;
991 for a in 0..r_i {
992 let mut acc = 0.0_f64;
993 for b in 0..r_j {
994 acc += blk.w[[a, b]] * x[xj_base + b];
995 }
996 out[yi_base + a] += g * acc;
997 }
998 }
999 }
1000 }
1001}
1002
1003#[doc(hidden)]
1012pub fn sae_framed_schur_matvec_cpu(
1013 sys: &ArrowSchurSystem,
1014 data: &DeviceSaePcgData,
1015 ridge_t: f64,
1016 ridge_beta: f64,
1017 x: &[f64],
1018 out: &mut [f64],
1019) -> Result<(), String> {
1020 let frame = data
1021 .frame
1022 .as_ref()
1023 .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1024 let k = data.beta_dim;
1025 sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1026 if frame.row_htbeta.len() != sys.rows.len() {
1027 return Err(format!(
1028 "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1029 frame.row_htbeta.len(),
1030 sys.rows.len()
1031 ));
1032 }
1033 for (i, row) in sys.rows.iter().enumerate() {
1034 let slab = &frame.row_htbeta[i];
1035 if slab.is_empty() {
1036 continue;
1037 }
1038 let qi = sys.row_dims[i];
1039 if qi == 0 || slab.len() != qi * k {
1040 continue;
1041 }
1042 let mut h = vec![0.0_f64; qi];
1044 for c in 0..qi {
1045 let base = c * k;
1046 let mut acc = 0.0_f64;
1047 for a in 0..k {
1048 acc += slab[base + a] * x[a];
1049 }
1050 h[c] = acc;
1051 }
1052 let mut block = row.htt.clone();
1054 for d in 0..qi {
1055 block[[d, d]] += ridge_t;
1056 }
1057 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1058 .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1059 let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1060 for c in 0..qi {
1062 let sc = s[c];
1063 if sc == 0.0 {
1064 continue;
1065 }
1066 let base = c * k;
1067 for a in 0..k {
1068 out[a] -= slab[base + a] * sc;
1069 }
1070 }
1071 }
1072 Ok(())
1073}
1074
1075#[cfg(target_os = "linux")]
1076mod cuda {
1077 use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1078 use gam_gpu::driver::to_i32;
1079 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1080 use crate::arrow_schur::{
1081 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1082 };
1083 use cudarc::cublas::sys::{
1084 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1085 };
1086 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1087 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1088 use cudarc::driver::{
1089 CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1090 PushKernelArg,
1091 };
1092 use ndarray::Array1;
1093 use std::sync::{Arc, OnceLock};
1094
/// Per-row host staging record used by the multi-GPU solver.
struct RowSlot {
    /// Ridged `H_tt` block as packed by `pack_block` (`d * d` values).
    d_block: Vec<f64>,
    /// `H_tβ` block as packed by `pack_block` (`d * k` values).
    b_block: Vec<f64>,
    /// Row gradient (`d` values).
    g_vec: Vec<f64>,
    /// Factored block downloaded after the batched Cholesky in `forward_tile`.
    l_block: Vec<f64>,
    /// Gradient after the forward triangular solve in `forward_tile`.
    u_vec: Vec<f64>,
    /// Coupling block after the forward triangular solve in `forward_tile`.
    y_block: Vec<f64>,
    /// Sum of `ln` of this row's factor diagonal.
    log_det_local: f64,
    /// Ridge bump estimate, set when this row's factorisation failed.
    bump: Option<f64>,
    /// Tile-wide partial Schur block; `forward_tile` stores it on the first
    /// slot of each tile only.
    tile_partial_schur: Option<Vec<f64>>,
    /// Tile-wide partial right-hand side; stored on the first slot of each
    /// tile only.
    tile_partial_rhs: Option<Vec<f64>>,
    /// This row's `delta_t` segment, filled by `back_sub_tile`.
    delta_t_block: Vec<f64>,
}
1118
/// Multi-device arrow-Schur solve.
///
/// Outline, as visible in this function:
/// 1. pack every row into a `RowSlot` on the host;
/// 2. scatter the slots across devices and run `forward_tile` on each tile;
/// 3. reduce the per-tile partial Schur blocks / right-hand sides on the
///    host, starting from the ridged `hbb` and `-gb`;
/// 4. factor and solve the reduced `k x k` system on the selected device;
/// 5. scatter again and run `back_sub_tile` to recover each row's step.
///
/// # Errors
/// * `Unavailable` for empty or non-dense systems, fewer than two devices,
///   mismatched row shapes, or any failed device/runtime call.
/// * `RidgeBumpRequired` for the first slot that recorded a bump.
/// * `SchurFactorFailed` when the reduced Cholesky reports a non-zero info.
pub(super) fn solve_multi_gpu(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    if n == 0 || d == 0 || k == 0 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Matrix-free operators and a non-dense shared block are not handled here.
    if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }

    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    if runtime.device_count() < 2 {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }

    // Stage one slot per row, with the ridge already applied to the
    // diagonal block by `pack_block`.
    let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
    for row in &sys.rows {
        if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        let mut d_block = Vec::with_capacity(d * d);
        let mut b_block = Vec::with_capacity(d * k);
        let mut g_vec = Vec::with_capacity(d);
        pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
        slots.push(RowSlot {
            d_block,
            b_block,
            g_vec,
            l_block: Vec::new(),
            u_vec: Vec::new(),
            y_block: Vec::new(),
            log_det_local: 0.0,
            bump: None,
            tile_partial_schur: None,
            tile_partial_rhs: None,
            delta_t_block: vec![0.0; d],
        });
    }

    // Forward pass on every tile.
    let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
        forward_tile(ordinal, d, k, tile)
    });
    if forward_ok.is_none() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }

    // NOTE(review): presumably this reproduces the partition that
    // `scatter_batched` used above; confirm against `gam_gpu::pool`.
    let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
    // Report the lowest-index row whose factorisation failed.
    if let Some((row, bump)) = slots
        .iter()
        .enumerate()
        .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
    {
        return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
    }

    // Seed the reduced system with the ridged shared block, stored
    // column-major, and the negated shared gradient.
    let mut schur_host = vec![0.0_f64; k * k];
    for col in 0..k {
        for row in 0..k {
            let mut v = sys.hbb[[row, col]];
            if row == col {
                v += ridge_beta;
            }
            schur_host[col * k + row] = v;
        }
    }
    let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
    let mut log_det = 0.0_f64;
    // Add each tile's partials; they live on the slot at the tile's start.
    for start in tile_starts(&row_base_of_tile) {
        let slot = &slots[start];
        let partial_schur = slot
            .tile_partial_schur
            .as_ref()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let partial_rhs = slot
            .tile_partial_rhs
            .as_ref()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        for idx in 0..k * k {
            schur_host[idx] += partial_schur[idx];
        }
        for a in 0..k {
            rhs_host[a] += partial_rhs[a];
        }
    }
    for slot in &slots {
        log_det += slot.log_det_local;
    }

    // Factor and solve the reduced system on the selected device.
    let primary = runtime.selected_device().ordinal;
    let stream = gam_gpu::device_runtime::cuda_context_for(primary)
        .and_then(|ctx| ctx.new_stream().ok())
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let solver =
        DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut schur_dev = stream
        .clone_htod(&schur_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut rhs_dev = stream
        .clone_htod(&rhs_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
    if info != 0 {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
        });
    }
    // Two triangular solves against the factor; the last flag presumably
    // selects the transposed solve (confirm against `trsm_single`).
    trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
    trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
    let delta_beta_host = stream
        .clone_dtoh(&rhs_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // The host copy is kept: the back-substitution below borrows it.
    let delta_beta = Array1::from_vec(delta_beta_host.clone());
    let l_schur_host = stream
        .clone_dtoh(&schur_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    for j in 0..k {
        log_det += l_schur_host[j * k + j].ln();
    }
    // log det = 2 * sum of ln(pivot) over all factors.
    log_det *= 2.0;

    // Back-substitution on every tile with the shared step.
    let delta_beta_ref = &delta_beta_host;
    let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
        back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
    });
    if back_ok.is_none() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }

    // Gather the per-row segments into the flat `delta_t`.
    let mut delta_t = Array1::<f64>::zeros(n * d);
    for (i, slot) in slots.iter().enumerate() {
        let base = i * d;
        for r in 0..d {
            delta_t[base + r] = slot.delta_t_block[r];
        }
    }

    Ok(ArrowSchurGpuSolution {
        delta_t,
        delta_beta,
        log_det_hessian: log_det,
    })
}
1307
/// Yields the first row index of every non-empty tile.
///
/// Each tile is an `(ordinal, row_range)` pair. Empty ranges are skipped:
/// an empty tile owns no row, so its `start` either equals the start of the
/// following tile (which would be visited twice) or lies one past the last
/// row (which would be out of bounds when used as an index).
fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
    tiles
        .iter()
        .filter(|(_, range)| !range.is_empty())
        .map(|(_, range)| range.start)
}
1313
1314 fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
1322 if tile.is_empty() {
1323 return Some(());
1324 }
1325 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1328 .and_then(|ctx| ctx.new_stream().ok())?;
1329 let solver = DnHandle::new(stream.clone()).ok()?;
1330 let blas = CudaBlas::new(stream.clone()).ok()?;
1331 let m = tile.len();
1332
1333 let mut d_host = Vec::with_capacity(m * d * d);
1336 let mut b_host = Vec::with_capacity(m * d * k);
1337 let mut g_host = Vec::with_capacity(m * d);
1338 for slot in tile.iter() {
1339 d_host.extend_from_slice(&slot.d_block);
1340 b_host.extend_from_slice(&slot.b_block);
1341 g_host.extend_from_slice(&slot.g_vec);
1342 }
1343 let mut d_dev = stream.clone_htod(&d_host).ok()?;
1344 let mut b_dev = stream.clone_htod(&b_host).ok()?;
1345 let mut g_dev = stream.clone_htod(&g_host).ok()?;
1346
1347 let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
1353 if let Some(local) = info_host.iter().position(|info| *info != 0) {
1354 tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
1355 &tile[local].d_block,
1356 d,
1357 ));
1358 return Some(());
1359 }
1360
1361 trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1363 trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;
1364
1365 let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
1367 let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
1368 accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;
1369
1370 let l_host = stream.clone_dtoh(&d_dev).ok()?;
1372 let u_host = stream.clone_dtoh(&g_dev).ok()?;
1373 let y_host = stream.clone_dtoh(&b_dev).ok()?;
1374 let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
1375 let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;
1376
1377 for (local, slot) in tile.iter_mut().enumerate() {
1378 let l_base = local * d * d;
1379 let u_base = local * d;
1380 let y_base = local * d * k;
1381 slot.l_block = l_host[l_base..l_base + d * d].to_vec();
1382 slot.u_vec = u_host[u_base..u_base + d].to_vec();
1383 slot.y_block = y_host[y_base..y_base + d * k].to_vec();
1384 let mut log_det_local = 0.0_f64;
1385 for j in 0..d {
1386 log_det_local += l_host[l_base + j * d + j].ln();
1387 }
1388 slot.log_det_local = log_det_local;
1389 }
1390 tile[0].tile_partial_schur = Some(partial_schur);
1391 tile[0].tile_partial_rhs = Some(partial_rhs);
1392 Some(())
1393 }
1394
1395 fn back_sub_tile(
1399 ordinal: usize,
1400 d: usize,
1401 k: usize,
1402 delta_beta: &[f64],
1403 tile: &mut [RowSlot],
1404 ) -> Option<()> {
1405 if tile.is_empty() {
1406 return Some(());
1407 }
1408 let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
1411 .and_then(|ctx| ctx.new_stream().ok())?;
1412 let blas = CudaBlas::new(stream.clone()).ok()?;
1413 let m = tile.len();
1414
1415 let mut l_host = Vec::with_capacity(m * d * d);
1416 let mut u_host = Vec::with_capacity(m * d);
1417 let mut y_host = Vec::with_capacity(m * d * k);
1418 for slot in tile.iter() {
1419 l_host.extend_from_slice(&slot.l_block);
1420 u_host.extend_from_slice(&slot.u_vec);
1421 y_host.extend_from_slice(&slot.y_block);
1422 }
1423 let d_dev = stream.clone_htod(&l_host).ok()?;
1424 let mut g_dev = stream.clone_htod(&u_host).ok()?;
1425 let b_dev = stream.clone_htod(&y_host).ok()?;
1426 let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;
1427
1428 accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
1430 trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
1431 let x_host = stream.clone_dtoh(&g_dev).ok()?;
1432 for (local, slot) in tile.iter_mut().enumerate() {
1433 let base = local * d;
1434 for r in 0..d {
1435 slot.delta_t_block[r] = -x_host[base + r];
1436 }
1437 }
1438 Some(())
1439 }
1440
1441 pub(super) fn solve(
1442 sys: &ArrowSchurSystem,
1443 ridge_t: f64,
1444 ridge_beta: f64,
1445 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
1446 let n = sys.rows.len();
1447 let d = sys.d;
1448 let k = sys.k;
1449 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
1450 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1451
1452 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
1453 .and_then(|ctx| ctx.new_stream().ok())
1454 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
1455 let solver =
1456 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1457 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1458
1459 let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
1461 let mut d_dev = stream
1462 .clone_htod(&d_host)
1463 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1464 let mut b_dev = stream
1465 .clone_htod(&b_host)
1466 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1467 let mut g_dev = stream
1468 .clone_htod(&g_host)
1469 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1470
1471 let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
1479 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
1480 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
1484 row: idx,
1485 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
1486 });
1487 }
1488
1489 trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1492 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;
1495
1496 let schur_init: Vec<f64> = {
1515 let mut tmp = Vec::with_capacity(k * k);
1516 for col in 0..k {
1517 for row in 0..k {
1518 let mut v = sys.hbb[[row, col]];
1519 if row == col {
1520 v += ridge_beta;
1521 }
1522 tmp.push(v);
1523 }
1524 }
1525 tmp
1526 };
1527 let mut schur_dev = stream
1528 .clone_htod(&schur_init)
1529 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1530 let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
1531 let mut rhs_dev = stream
1532 .clone_htod(&rhs_init)
1533 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1534
1535 accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;
1536
1537 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
1539 if info != 0 {
1540 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
1541 reason: format!("Schur Cholesky failed at pivot {info}"),
1542 });
1543 }
1544 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
1546 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
1547 let delta_beta_host = stream
1548 .clone_dtoh(&rhs_dev)
1549 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1550 let delta_beta = Array1::from_vec(delta_beta_host.clone());
1551
1552 accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
1560 trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
1561
1562 let x_host = stream
1563 .clone_dtoh(&g_dev)
1564 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1565 let mut delta_t = Array1::<f64>::zeros(n * d);
1566 for (i, v) in x_host.iter().enumerate() {
1567 delta_t[i] = -*v;
1568 }
1569
1570 let l_local_host = stream
1572 .clone_dtoh(&d_dev)
1573 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1574 let l_schur_host = stream
1575 .clone_dtoh(&schur_dev)
1576 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1577 let mut log_det = 0.0_f64;
1578 for i in 0..n {
1579 let base = i * d * d;
1580 for j in 0..d {
1581 log_det += l_local_host[base + j * d + j].ln();
1582 }
1583 }
1584 for j in 0..k {
1585 log_det += l_schur_host[j * k + j].ln();
1586 }
1587 log_det *= 2.0;
1588
1589 Ok(ArrowSchurGpuSolution {
1590 delta_t,
1591 delta_beta,
1592 log_det_hessian: log_det,
1593 })
1594 }
1595
1596 fn potrf_batched(
1597 solver: &DnHandle,
1598 stream: &Arc<CudaStream>,
1599 p: usize,
1600 batch: usize,
1601 matrices: &mut CudaSlice<f64>,
1602 ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1603 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1604 let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1605 let matrix_len = p * p;
1606 let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1607 let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1608 let mut ptrs = Vec::with_capacity(batch);
1609 for idx in 0..batch {
1610 ptrs.push(base_ptr + (idx as u64) * bytes_per);
1611 }
1612 let mut ptrs_dev = stream
1613 .clone_htod(&ptrs)
1614 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1615 let mut info_dev = stream
1616 .alloc_zeros::<i32>(batch)
1617 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1618 let status = {
1619 let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1620 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1621 unsafe {
1624 cusolver_sys::cusolverDnDpotrfBatched(
1625 solver.cu(),
1626 cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1627 p_i,
1628 ptrs_ptr as *mut *mut f64,
1629 p_i,
1630 info_ptr as *mut i32,
1631 batch_i,
1632 )
1633 }
1634 };
1635 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1636 return Err(ArrowSchurGpuFailure::Unavailable);
1637 }
1638 stream
1639 .clone_dtoh(&info_dev)
1640 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1641 }
1642
1643 fn potrf_single(
1644 solver: &DnHandle,
1645 stream: &Arc<CudaStream>,
1646 p: usize,
1647 matrix: &mut CudaSlice<f64>,
1648 ) -> Result<i32, ArrowSchurGpuFailure> {
1649 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1650 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1651 let mut lwork = 0_i32;
1652 {
1653 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1654 let status = unsafe {
1656 cusolver_sys::cusolverDnDpotrf_bufferSize(
1657 solver.cu(),
1658 uplo,
1659 p_i,
1660 mat_ptr as *mut f64,
1661 p_i,
1662 &mut lwork,
1663 )
1664 };
1665 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1666 return Err(ArrowSchurGpuFailure::Unavailable);
1667 }
1668 }
1669 let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1670 let mut workspace = stream
1671 .alloc_zeros::<f64>(lwork_usize.max(1))
1672 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1673 let mut info_dev = stream
1674 .alloc_zeros::<i32>(1)
1675 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1676 {
1677 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1678 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1679 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1680 let status = unsafe {
1682 cusolver_sys::cusolverDnDpotrf(
1683 solver.cu(),
1684 uplo,
1685 p_i,
1686 mat_ptr as *mut f64,
1687 p_i,
1688 work_ptr as *mut f64,
1689 lwork,
1690 info_ptr as *mut i32,
1691 )
1692 };
1693 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1694 return Err(ArrowSchurGpuFailure::Unavailable);
1695 }
1696 }
1697 let info_host = stream
1698 .clone_dtoh(&info_dev)
1699 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1700 Ok(info_host[0])
1701 }
1702
1703 fn trsm_batched_lower_inplace(
1707 blas: &CudaBlas,
1708 stream: &Arc<CudaStream>,
1709 d: usize,
1710 n: usize,
1711 nrhs: usize,
1712 l_stack: &CudaSlice<f64>,
1713 rhs_stack: &mut CudaSlice<f64>,
1714 ) -> Result<(), ArrowSchurGpuFailure> {
1715 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1716 }
1717
1718 fn trsm_batched_lower_inplace_transposed(
1720 blas: &CudaBlas,
1721 stream: &Arc<CudaStream>,
1722 d: usize,
1723 n: usize,
1724 nrhs: usize,
1725 l_stack: &CudaSlice<f64>,
1726 rhs_stack: &mut CudaSlice<f64>,
1727 ) -> Result<(), ArrowSchurGpuFailure> {
1728 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1729 }
1730
1731 fn trsm_batched_inplace_inner(
1732 blas: &CudaBlas,
1733 stream: &Arc<CudaStream>,
1734 d: usize,
1735 n: usize,
1736 nrhs: usize,
1737 l_stack: &CudaSlice<f64>,
1738 rhs_stack: &mut CudaSlice<f64>,
1739 transposed: bool,
1740 ) -> Result<(), ArrowSchurGpuFailure> {
1741 let alpha = 1.0_f64;
1742 let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1743 let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1744 let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1745 let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1746 let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1747 let (l_base, _l_record) = l_stack.device_ptr(stream);
1748 let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1749 let mut l_ptrs = Vec::with_capacity(n);
1750 let mut rhs_ptrs = Vec::with_capacity(n);
1751 for i in 0..n {
1752 l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1753 rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1754 }
1755 let mut l_ptrs_dev = stream
1756 .clone_htod(&l_ptrs)
1757 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1758 let mut rhs_ptrs_dev = stream
1759 .clone_htod(&rhs_ptrs)
1760 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1761 let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1762 let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1763 let op = if transposed {
1764 cublasOperation_t::CUBLAS_OP_T
1765 } else {
1766 cublasOperation_t::CUBLAS_OP_N
1767 };
1768 let handle = *blas.handle();
1769 let status = unsafe {
1772 cudarc::cublas::sys::cublasDtrsmBatched(
1773 handle,
1774 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1775 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1776 op,
1777 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1778 d_i,
1779 nrhs_i,
1780 &alpha,
1781 l_ptrs_ptr as *const *const f64,
1782 d_i,
1783 rhs_ptrs_ptr as *const *mut f64,
1784 d_i,
1785 batch_i,
1786 )
1787 };
1788 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1789 return Err(ArrowSchurGpuFailure::Unavailable);
1790 }
1791 Ok(())
1792 }
1793
1794 fn trsm_single(
1797 blas: &CudaBlas,
1798 stream: &Arc<CudaStream>,
1799 n: usize,
1800 l: &CudaSlice<f64>,
1801 rhs: &mut CudaSlice<f64>,
1802 upper: bool,
1803 transposed: bool,
1804 ) -> Result<(), ArrowSchurGpuFailure> {
1805 let alpha = 1.0_f64;
1806 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1807 let handle = *blas.handle();
1808 let (l_ptr, _l_rec) = l.device_ptr(stream);
1809 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1810 let status = unsafe {
1812 cudarc::cublas::sys::cublasDtrsm_v2(
1813 handle,
1814 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1815 if upper {
1816 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1817 } else {
1818 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1819 },
1820 if transposed {
1821 cublasOperation_t::CUBLAS_OP_T
1822 } else {
1823 cublasOperation_t::CUBLAS_OP_N
1824 },
1825 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1826 n_i,
1827 1,
1828 &alpha,
1829 l_ptr as *const f64,
1830 n_i,
1831 rhs_ptr as *mut f64,
1832 n_i,
1833 )
1834 };
1835 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1836 return Err(ArrowSchurGpuFailure::Unavailable);
1837 }
1838 Ok(())
1839 }
1840
1841 fn accumulate_schur(
1845 blas: &CudaBlas,
1846 d: usize,
1847 k: usize,
1848 n: usize,
1849 y_stack: &CudaSlice<f64>,
1850 u_stack: &CudaSlice<f64>,
1851 schur: &mut CudaSlice<f64>,
1852 rhs: &mut CudaSlice<f64>,
1853 ) -> Result<(), ArrowSchurGpuFailure> {
1854 let y_block_elems = d * k;
1855 let u_block_elems = d;
1856 for i in 0..n {
1857 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1858 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1859 let gemm_cfg = GemmConfig::<f64> {
1861 transa: cublasOperation_t::CUBLAS_OP_T,
1862 transb: cublasOperation_t::CUBLAS_OP_N,
1863 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1864 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1865 k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1866 alpha: -1.0,
1867 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1868 ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1869 beta: 1.0,
1870 ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1871 };
1872 unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1874 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1875 let gemv_cfg = GemvConfig::<f64> {
1877 trans: cublasOperation_t::CUBLAS_OP_T,
1878 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1879 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1880 alpha: 1.0,
1881 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1882 incx: 1,
1883 beta: 1.0,
1884 incy: 1,
1885 };
1886 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1889 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1890 }
1891 Ok(())
1892 }
1893
1894 fn accumulate_schur_rhs_only(
1902 blas: &CudaBlas,
1903 d: usize,
1904 k: usize,
1905 n: usize,
1906 y_stack: &CudaSlice<f64>,
1907 u_stack: &CudaSlice<f64>,
1908 rhs: &mut CudaSlice<f64>,
1909 ) -> Result<(), ArrowSchurGpuFailure> {
1910 let y_block_elems = d * k;
1911 let u_block_elems = d;
1912 for i in 0..n {
1913 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1914 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1915 let gemv_cfg = GemvConfig::<f64> {
1916 trans: cublasOperation_t::CUBLAS_OP_T,
1917 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1918 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1919 alpha: 1.0,
1920 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1921 incx: 1,
1922 beta: 1.0,
1923 incy: 1,
1924 };
1925 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1928 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1929 }
1930 Ok(())
1931 }
1932
1933 fn accumulate_back_sub_rhs(
1936 blas: &CudaBlas,
1937 d: usize,
1938 k: usize,
1939 n: usize,
1940 y_stack: &CudaSlice<f64>,
1941 delta_beta: &CudaSlice<f64>,
1942 u_stack: &mut CudaSlice<f64>,
1943 ) -> Result<(), ArrowSchurGpuFailure> {
1944 let y_block_elems = d * k;
1945 let u_block_elems = d;
1946 for i in 0..n {
1947 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1948 let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1949 let gemv_cfg = GemvConfig::<f64> {
1950 trans: cublasOperation_t::CUBLAS_OP_N,
1951 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1952 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1953 alpha: 1.0,
1954 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1955 incx: 1,
1956 beta: 1.0,
1957 incy: 1,
1958 };
1959 unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1962 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1963 }
1964 Ok(())
1965 }
1966
1967 use std::collections::HashMap;
1983 use std::sync::Mutex;
1984
    /// Cache of loaded fused arrow-Schur CUDA modules.
    ///
    /// A single instance is created lazily by `fused_module_cache` and shared
    /// for the lifetime of the process; `fused_module_for` reads and fills it.
    struct FusedModuleCache {
        /// Loaded modules keyed by `FusedModuleCacheKey`, guarded by a mutex so
        /// lookups and insertions from several threads are serialised.
        modules: Mutex<
            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
        >,
    }
1994
1995 fn fused_module_cache() -> &'static FusedModuleCache {
1996 static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
1997 CACHE.get_or_init(|| FusedModuleCache {
1998 modules: Mutex::new(HashMap::new()),
1999 })
2000 }
2001
2002 fn fused_module_for(
2003 ctx: &Arc<CudaContext>,
2004 key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2005 ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2006 let cache = fused_module_cache();
2007 if let Ok(guard) = cache.modules.lock() {
2008 if let Some(existing) = guard.get(&key) {
2009 return Ok(existing.clone());
2010 }
2011 }
2012 let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2013 key.p_max as usize,
2014 key.r_template as usize,
2015 );
2016 let ptx = cudarc::nvrtc::compile_ptx(&src).map_err(|err| {
2017 ArrowSchurGpuFailure::SchurFactorFailed {
2018 reason: format!(
2019 "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2020 key.p_max, key.r_template
2021 ),
2022 }
2023 })?;
2024 let module = ctx
2025 .load_module(ptx)
2026 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2027 if let Ok(mut guard) = cache.modules.lock() {
2028 guard.entry(key).or_insert_with(|| module.clone());
2029 }
2030 Ok(module)
2031 }
2032
    /// CUDA C source for the element-wise PCG helpers and the SAE matvec
    /// kernels, compiled at run time by `pcg_vector_module`.
    ///
    /// Kernels defined here (all `extern "C"`, all guarded against
    /// out-of-range thread indices):
    ///
    /// * `arrow_pcg_jacobi_mul` — `z[i] = inv_diag[i] * r[i]`.
    /// * `arrow_pcg_update_p` — `p[i] = z[i] + beta * p[i]`.
    /// * `arrow_sae_init` — `out[i] = ridge * x[i]`.
    /// * `arrow_sae_smooth_matvec` — per smooth block (`blockIdx.y`), adds the
    ///   row-major `m × m` factor applied to the block's slice of `x` into
    ///   `out`, channel by channel.
    /// * `arrow_sae_sparse_g_matvec` — per (row, column) data block, atomically
    ///   accumulates the dense block applied to `x` into `out`.
    /// * `arrow_sae_gather_u`, `arrow_sae_apply_l`, `arrow_sae_apply_ainv`,
    ///   `arrow_sae_scatter_sub` — the four stages of the per-row correction:
    ///   gather `u` through `phi`, project with the row Jacobian, apply the
    ///   cached `ainv`, then atomically subtract the scattered result.
    /// * `arrow_sae_diag_sub` — atomically subtracts the matching quadratic
    ///   form from the Jacobi diagonal.
    /// * `arrow_sae_frame_*` — frame variants of the above: per-block
    ///   right-width `r`, and a dense row-major `htb` with explicit
    ///   per-row `q_of`.
    ///
    /// The string is handed to the module cache verbatim, so its text
    /// (including the CUDA comments) is part of the program's behaviour.
    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri; // basis row in atom i
    int a = linear - li * ri; // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 * h_i = H_tβ^(i) · x (length q_i)
 * s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i (apply cached ainv, length q_i)
 * out -= (H_tβ^(i))ᵀ · s_i (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2465
2466 fn pcg_vector_module(
2467 ctx: &Arc<CudaContext>,
2468 ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2469 static CACHE: gam_gpu::device_cache::PtxModuleCache =
2470 gam_gpu::device_cache::PtxModuleCache::new();
2471 CACHE
2472 .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2473 .map_err(|err| {
2474 log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2480 ArrowSchurGpuFailure::Unavailable
2481 })
2482 }
2483
2484 fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2485 let threads = 256u32;
2486 let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2487 Ok(LaunchConfig {
2488 grid_dim: (blocks, 1, 1),
2489 block_dim: (threads, 1, 1),
2490 shared_mem_bytes: 0,
2491 })
2492 }
2493
2494 fn launch_jacobi_mul(
2495 stream: &Arc<CudaStream>,
2496 module: &Arc<CudaModule>,
2497 inv_diag: &CudaSlice<f64>,
2498 r: &CudaSlice<f64>,
2499 z: &mut CudaSlice<f64>,
2500 n: usize,
2501 ) -> Result<(), ArrowSchurGpuFailure> {
2502 let kernel = module
2503 .load_function("arrow_pcg_jacobi_mul")
2504 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2505 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2506 let mut builder = stream.launch_builder(&kernel);
2507 builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2508 unsafe { builder.launch(pcg_launch_config(n)?) }
2511 .map(drop)
2512 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2513 }
2514
2515 fn launch_update_p(
2516 stream: &Arc<CudaStream>,
2517 module: &Arc<CudaModule>,
2518 z: &CudaSlice<f64>,
2519 beta: f64,
2520 p: &mut CudaSlice<f64>,
2521 n: usize,
2522 ) -> Result<(), ArrowSchurGpuFailure> {
2523 let kernel = module
2524 .load_function("arrow_pcg_update_p")
2525 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2526 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2527 let mut builder = stream.launch_builder(&kernel);
2528 builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2529 unsafe { builder.launch(pcg_launch_config(n)?) }
2532 .map(drop)
2533 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2534 }
2535
    /// Device-resident buffers and dimensions used by the SAE PCG kernels.
    ///
    /// NOTE(review): the field notes below are inferred from the host-side
    /// packing in `flatten_device_sae_data` and from the parameter names of the
    /// `arrow_sae_*` kernels; the upload code itself is not visible from here,
    /// so confirm against the constructor.
    struct DeviceSaePcgBuffers {
        // Presumably `n_rows + 1` offsets into `beta_base` / `phi`, one span per row.
        row_ptr: CudaSlice<i32>,
        // Presumably the base index paired with each `phi` entry.
        beta_base: CudaSlice<i32>,
        // Presumably the coefficient paired with each `beta_base` entry.
        phi: CudaSlice<f64>,
        // Presumably `n_rows + 1` offsets into `jac`, one span per row.
        jac_ptr: CudaSlice<i32>,
        // Presumably the per-row Jacobians concatenated; each span is a multiple of `p` long.
        jac: CudaSlice<f64>,
        // Presumably each smooth block's global offset.
        smooth_offsets: CudaSlice<i32>,
        // Presumably each smooth block's (square) factor order.
        smooth_m: CudaSlice<i32>,
        // Presumably offsets into `smooth_data`, one span per smooth block.
        smooth_ptr: CudaSlice<i32>,
        // Presumably the smooth factors flattened row by row.
        smooth_data: CudaSlice<f64>,
        // The `g_*` fields appear to mirror the `row_off`, `col_off`, `rows`,
        // `cols`, `data_ptr`, `data` arguments of `arrow_sae_sparse_g_matvec`
        // — TODO confirm.
        g_row_off: CudaSlice<i32>,
        g_col_off: CudaSlice<i32>,
        g_rows: CudaSlice<i32>,
        g_cols: CudaSlice<i32>,
        g_ptr: CudaSlice<i32>,
        g_data: CudaSlice<f64>,
        // Presumably the per-row `max_q × max_q` blocks read by `arrow_sae_apply_ainv`.
        ainv: CudaSlice<f64>,
        // Presumably scratch for the gather / apply-L / apply-ainv stages
        // (`u`, `w`, `v` in the kernels) — TODO confirm sizes.
        u: CudaSlice<f64>,
        w: CudaSlice<f64>,
        v: CudaSlice<f64>,
        // Number of rows in the system.
        n_rows: usize,
        // Channel width used by the kernels (`p`).
        p: usize,
        // Border dimension (`beta_dim` of the packed data).
        k: usize,
        // Largest per-row Jacobian row count seen while packing.
        max_q: usize,
        // Presumably the number of smooth blocks uploaded.
        smooth_blocks: usize,
        // Presumably the number of sparse G blocks uploaded.
        g_blocks: usize,
    }
2563
2564 fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2565 to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2566 }
2567
2568 fn sae_penalty_diag_host(
2569 data: &DeviceSaePcgData,
2570 ridge_beta: f64,
2571 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2572 let mut diag = vec![ridge_beta; data.beta_dim];
2573 for block in &data.smooth_blocks {
2574 let (rows, cols) = block.factor_a.dim();
2575 if rows != cols {
2576 return Err(ArrowSchurGpuFailure::Unavailable);
2577 }
2578 for row in 0..rows {
2579 let coeff = block.factor_a[[row, row]];
2580 let base = block
2581 .global_offset
2582 .checked_add(
2583 row.checked_mul(data.p)
2584 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2585 )
2586 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2587 let end = base
2588 .checked_add(data.p)
2589 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2590 if end > diag.len() {
2591 return Err(ArrowSchurGpuFailure::Unavailable);
2592 }
2593 for channel in 0..data.p {
2594 diag[base + channel] += coeff;
2595 }
2596 }
2597 }
2598 for block in &data.sparse_g_blocks {
2599 if block.row_off != block.col_off {
2600 continue;
2601 }
2602 let (rows, cols) = block.data.dim();
2603 for row in 0..rows.min(cols) {
2604 let coeff = block.data[[row, row]];
2605 let beta_row = block
2606 .row_off
2607 .checked_add(row)
2608 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2609 let base = beta_row
2610 .checked_mul(data.p)
2611 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2612 let end = base
2613 .checked_add(data.p)
2614 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2615 if end > diag.len() {
2616 return Err(ArrowSchurGpuFailure::Unavailable);
2617 }
2618 for channel in 0..data.p {
2619 diag[base + channel] += coeff;
2620 }
2621 }
2622 }
2623 Ok(diag)
2624 }
2625
2626 fn flatten_device_sae_data(
2627 sys: &ArrowSchurSystem,
2628 data: &DeviceSaePcgData,
2629 ridge_t: f64,
2630 stream: &Arc<CudaStream>,
2631 ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2632 let n_rows = sys.rows.len();
2633 let p = data.p;
2634 let k = data.beta_dim;
2635 if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2636 return Err(ArrowSchurGpuFailure::Unavailable);
2637 }
2638
2639 let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2640 let mut beta_base_host = Vec::<i32>::new();
2641 let mut phi_host = Vec::<f64>::new();
2642 row_ptr_host.push(0_i32);
2643 for row in data.a_phi.iter() {
2644 for &(base, phi) in row {
2645 beta_base_host.push(checked_i32(base)?);
2646 phi_host.push(phi);
2647 }
2648 row_ptr_host.push(checked_i32(beta_base_host.len())?);
2649 }
2650
2651 let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2652 let mut jac_host = Vec::<f64>::new();
2653 let mut max_q = 0usize;
2654 jac_ptr_host.push(0_i32);
2655 for row_jac in data.local_jac.iter() {
2656 if row_jac.len() % p != 0 {
2657 return Err(ArrowSchurGpuFailure::Unavailable);
2658 }
2659 max_q = max_q.max(row_jac.len() / p);
2660 jac_host.extend_from_slice(row_jac);
2661 jac_ptr_host.push(checked_i32(jac_host.len())?);
2662 }
2663 if max_q == 0 {
2664 return Err(ArrowSchurGpuFailure::Unavailable);
2665 }
2666
2667 let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2668 let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2669 let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2670 let mut smooth_data_host = Vec::<f64>::new();
2671 smooth_ptr_host.push(0_i32);
2672 for block in &data.smooth_blocks {
2673 let (rows, cols) = block.factor_a.dim();
2674 if rows != cols {
2675 return Err(ArrowSchurGpuFailure::Unavailable);
2676 }
2677 smooth_offsets_host.push(checked_i32(block.global_offset)?);
2678 smooth_m_host.push(checked_i32(rows)?);
2679 for r in 0..rows {
2680 for c in 0..cols {
2681 smooth_data_host.push(block.factor_a[[r, c]]);
2682 }
2683 }
2684 smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2685 }
2686
2687 let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2688 let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2689 let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2690 let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2691 let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2692 let mut g_data_host = Vec::<f64>::new();
2693 g_ptr_host.push(0_i32);
2694 for block in &data.sparse_g_blocks {
2695 let (rows, cols) = block.data.dim();
2696 g_row_off_host.push(checked_i32(block.row_off)?);
2697 g_col_off_host.push(checked_i32(block.col_off)?);
2698 g_rows_host.push(checked_i32(rows)?);
2699 g_cols_host.push(checked_i32(cols)?);
2700 for r in 0..rows {
2701 for c in 0..cols {
2702 g_data_host.push(block.data[[r, c]]);
2703 }
2704 }
2705 g_ptr_host.push(checked_i32(g_data_host.len())?);
2706 }
2707
2708 let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2709 for (row_idx, row) in sys.rows.iter().enumerate() {
2710 let q = data.local_jac[row_idx].len() / p;
2711 if row.htt.dim() != (q, q) {
2712 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2713 reason: format!(
2714 "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2715 row.htt.dim()
2716 ),
2717 });
2718 }
2719 let mut block = row.htt.clone();
2720 for d in 0..q {
2721 block[[d, d]] += ridge_t;
2722 }
2723 let factor = gam_linalg::triangular::cholesky_factor_in_place(
2724 block.view(),
2725 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2726 )
2727 .ok_or_else(|| {
2728 ArrowSchurGpuFailure::RidgeBumpRequired {
2731 row: row_idx,
2732 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2733 }
2734 })?;
2735 for col in 0..q {
2736 let mut e = Array1::<f64>::zeros(q);
2737 e[col] = 1.0;
2738 let solved =
2739 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2740 for r in 0..q {
2741 ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2742 }
2743 }
2744 }
2745
2746 Ok(DeviceSaePcgBuffers {
2747 row_ptr: stream
2748 .clone_htod(&row_ptr_host)
2749 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2750 beta_base: stream
2751 .clone_htod(&beta_base_host)
2752 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2753 phi: stream
2754 .clone_htod(&phi_host)
2755 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2756 jac_ptr: stream
2757 .clone_htod(&jac_ptr_host)
2758 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2759 jac: stream
2760 .clone_htod(&jac_host)
2761 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2762 smooth_offsets: stream
2763 .clone_htod(&smooth_offsets_host)
2764 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2765 smooth_m: stream
2766 .clone_htod(&smooth_m_host)
2767 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2768 smooth_ptr: stream
2769 .clone_htod(&smooth_ptr_host)
2770 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2771 smooth_data: stream
2772 .clone_htod(&smooth_data_host)
2773 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2774 g_row_off: stream
2775 .clone_htod(&g_row_off_host)
2776 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2777 g_col_off: stream
2778 .clone_htod(&g_col_off_host)
2779 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2780 g_rows: stream
2781 .clone_htod(&g_rows_host)
2782 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2783 g_cols: stream
2784 .clone_htod(&g_cols_host)
2785 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2786 g_ptr: stream
2787 .clone_htod(&g_ptr_host)
2788 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2789 g_data: stream
2790 .clone_htod(&g_data_host)
2791 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2792 ainv: stream
2793 .clone_htod(&ainv_host)
2794 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795 u: stream
2796 .alloc_zeros::<f64>(n_rows * p)
2797 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798 w: stream
2799 .alloc_zeros::<f64>(n_rows * max_q)
2800 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801 v: stream
2802 .alloc_zeros::<f64>(n_rows * max_q)
2803 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804 n_rows,
2805 p,
2806 k,
2807 max_q,
2808 smooth_blocks: data.smooth_blocks.len(),
2809 g_blocks: data.sparse_g_blocks.len(),
2810 })
2811 }
2812
2813 fn launch_sae_init(
2814 stream: &Arc<CudaStream>,
2815 module: &Arc<CudaModule>,
2816 out: &mut CudaSlice<f64>,
2817 x: &CudaSlice<f64>,
2818 ridge: f64,
2819 n: usize,
2820 ) -> Result<(), ArrowSchurGpuFailure> {
2821 let kernel = module
2822 .load_function("arrow_sae_init")
2823 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2824 let n_i32 = checked_i32(n)?;
2825 let mut builder = stream.launch_builder(&kernel);
2826 builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2827 unsafe { builder.launch(pcg_launch_config(n)?) }
2831 .map(drop)
2832 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2833 }
2834
2835 fn launch_sae_penalty_matvec(
2836 stream: &Arc<CudaStream>,
2837 module: &Arc<CudaModule>,
2838 buffers: &mut DeviceSaePcgBuffers,
2839 x: &CudaSlice<f64>,
2840 out: &mut CudaSlice<f64>,
2841 ridge_beta: f64,
2842 ) -> Result<(), ArrowSchurGpuFailure> {
2843 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2844 if buffers.smooth_blocks > 0 {
2845 let kernel = module
2846 .load_function("arrow_sae_smooth_matvec")
2847 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2848 let max_m = buffers.k;
2849 let p_i32 = checked_i32(buffers.p)?;
2850 let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2851 let cfg = LaunchConfig {
2852 grid_dim: (
2853 ((max_m as u32).saturating_add(255) / 256).max(1),
2854 checked_i32(buffers.smooth_blocks)? as u32,
2855 1,
2856 ),
2857 block_dim: (256, 1, 1),
2858 shared_mem_bytes: 0,
2859 };
2860 let mut builder = stream.launch_builder(&kernel);
2861 builder
2862 .arg(x)
2863 .arg(&mut *out)
2864 .arg(&buffers.smooth_offsets)
2865 .arg(&buffers.smooth_m)
2866 .arg(&buffers.smooth_ptr)
2867 .arg(&buffers.smooth_data)
2868 .arg(&p_i32)
2869 .arg(&blocks_i32);
2870 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2875 }
2876 if buffers.g_blocks > 0 {
2877 let kernel = module
2878 .load_function("arrow_sae_sparse_g_matvec")
2879 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2880 let max_work = buffers
2881 .k
2882 .checked_div(buffers.p)
2883 .unwrap_or(0)
2884 .saturating_mul(buffers.p);
2885 let p_i32 = checked_i32(buffers.p)?;
2886 let blocks_i32 = checked_i32(buffers.g_blocks)?;
2887 let cfg = LaunchConfig {
2888 grid_dim: (
2889 ((max_work as u32).saturating_add(255) / 256).max(1),
2890 checked_i32(buffers.g_blocks)? as u32,
2891 1,
2892 ),
2893 block_dim: (256, 1, 1),
2894 shared_mem_bytes: 0,
2895 };
2896 let mut builder = stream.launch_builder(&kernel);
2897 builder
2898 .arg(x)
2899 .arg(&mut *out)
2900 .arg(&buffers.g_row_off)
2901 .arg(&buffers.g_col_off)
2902 .arg(&buffers.g_rows)
2903 .arg(&buffers.g_cols)
2904 .arg(&buffers.g_ptr)
2905 .arg(&buffers.g_data)
2906 .arg(&p_i32)
2907 .arg(&blocks_i32);
2908 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2913 }
2914 Ok(())
2915 }
2916
2917 fn launch_sae_row_schur_sub(
2918 stream: &Arc<CudaStream>,
2919 module: &Arc<CudaModule>,
2920 buffers: &mut DeviceSaePcgBuffers,
2921 x: &CudaSlice<f64>,
2922 out: &mut CudaSlice<f64>,
2923 ) -> Result<(), ArrowSchurGpuFailure> {
2924 let p_i32 = checked_i32(buffers.p)?;
2925 let max_q_i32 = checked_i32(buffers.max_q)?;
2926 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2927 let cfg_p_rows = LaunchConfig {
2928 grid_dim: (
2929 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2930 checked_i32(buffers.n_rows)? as u32,
2931 1,
2932 ),
2933 block_dim: (256, 1, 1),
2934 shared_mem_bytes: 0,
2935 };
2936 let gather = module
2937 .load_function("arrow_sae_gather_u")
2938 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2939 {
2940 let mut builder = stream.launch_builder(&gather);
2941 builder
2942 .arg(x)
2943 .arg(&buffers.row_ptr)
2944 .arg(&buffers.beta_base)
2945 .arg(&buffers.phi)
2946 .arg(&mut buffers.u)
2947 .arg(&p_i32)
2948 .arg(&n_rows_i32);
2949 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2953 }
2954
2955 let cfg_q_rows = LaunchConfig {
2956 grid_dim: (
2957 ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2958 checked_i32(buffers.n_rows)? as u32,
2959 1,
2960 ),
2961 block_dim: (256, 1, 1),
2962 shared_mem_bytes: 0,
2963 };
2964 let apply_l = module
2965 .load_function("arrow_sae_apply_l")
2966 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2967 {
2968 let mut builder = stream.launch_builder(&apply_l);
2969 builder
2970 .arg(&buffers.u)
2971 .arg(&buffers.jac_ptr)
2972 .arg(&buffers.jac)
2973 .arg(&mut buffers.w)
2974 .arg(&p_i32)
2975 .arg(&max_q_i32)
2976 .arg(&n_rows_i32);
2977 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2981 }
2982
2983 let apply_ainv = module
2984 .load_function("arrow_sae_apply_ainv")
2985 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2986 {
2987 let mut builder = stream.launch_builder(&apply_ainv);
2988 builder
2989 .arg(&buffers.ainv)
2990 .arg(&buffers.w)
2991 .arg(&mut buffers.v)
2992 .arg(&max_q_i32)
2993 .arg(&n_rows_i32);
2994 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2998 }
2999
3000 let scatter = module
3001 .load_function("arrow_sae_scatter_sub")
3002 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3003 {
3004 let mut builder = stream.launch_builder(&scatter);
3005 builder
3006 .arg(&buffers.v)
3007 .arg(&buffers.jac_ptr)
3008 .arg(&buffers.jac)
3009 .arg(&buffers.row_ptr)
3010 .arg(&buffers.beta_base)
3011 .arg(&buffers.phi)
3012 .arg(out)
3013 .arg(&p_i32)
3014 .arg(&max_q_i32)
3015 .arg(&n_rows_i32);
3016 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3020 }
3021 Ok(())
3022 }
3023
3024 fn launch_sae_diag_sub(
3025 stream: &Arc<CudaStream>,
3026 module: &Arc<CudaModule>,
3027 buffers: &DeviceSaePcgBuffers,
3028 diag: &mut CudaSlice<f64>,
3029 ) -> Result<(), ArrowSchurGpuFailure> {
3030 let kernel = module
3031 .load_function("arrow_sae_diag_sub")
3032 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3033 let p_i32 = checked_i32(buffers.p)?;
3034 let max_q_i32 = checked_i32(buffers.max_q)?;
3035 let n_rows_i32 = checked_i32(buffers.n_rows)?;
3036 let cfg = LaunchConfig {
3037 grid_dim: (
3038 ((buffers.p as u32).saturating_add(255) / 256).max(1),
3039 checked_i32(buffers.n_rows)? as u32,
3040 1,
3041 ),
3042 block_dim: (256, 1, 1),
3043 shared_mem_bytes: 0,
3044 };
3045 let mut builder = stream.launch_builder(&kernel);
3046 builder
3047 .arg(diag)
3048 .arg(&buffers.ainv)
3049 .arg(&buffers.jac_ptr)
3050 .arg(&buffers.jac)
3051 .arg(&buffers.row_ptr)
3052 .arg(&buffers.beta_base)
3053 .arg(&buffers.phi)
3054 .arg(&p_i32)
3055 .arg(&max_q_i32)
3056 .arg(&n_rows_i32);
3057 unsafe { builder.launch(cfg) }
3061 .map(drop)
3062 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3063 }
3064
3065 fn launch_sae_matvec(
3066 stream: &Arc<CudaStream>,
3067 module: &Arc<CudaModule>,
3068 buffers: &mut DeviceSaePcgBuffers,
3069 x: &CudaSlice<f64>,
3070 out: &mut CudaSlice<f64>,
3071 ridge_beta: f64,
3072 ) -> Result<(), ArrowSchurGpuFailure> {
3073 launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3074 launch_sae_row_schur_sub(stream, module, buffers, x, out)
3075 }
3076
3077 fn pack_fused_host(
3082 sys: &ArrowSchurSystem,
3083 ridge_t: f64,
3084 p_max: usize,
3085 r_template: usize,
3086 ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3087 let n = sys.rows.len();
3088 let d = sys.d;
3089 let k = sys.k;
3090 let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3091 let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3092 let mut g_buf = vec![0.0_f64; n * p_max];
3093 for (i, row) in sys.rows.iter().enumerate() {
3094 for col in 0..d {
3096 let base = (i * p_max + col) * p_max;
3097 for r in 0..d {
3098 let mut value = row.htt[[r, col]];
3099 if r == col {
3100 value += ridge_t;
3101 }
3102 d_buf[base + r] = value;
3103 }
3104 }
3105 for col in 0..k {
3113 let base = (i * r_template + col) * p_max;
3114 for r in 0..d {
3115 b_buf[base + r] = row.htbeta[[r, col]];
3116 }
3117 }
3118 let g_base = i * p_max;
3120 for r in 0..d {
3121 g_buf[g_base + r] = row.gt[r];
3122 }
3123 }
3124 (d_buf, b_buf, g_buf)
3125 }
3126
    /// Device-resident factorisation of one arrow system, built once by
    /// [`ResidentArrowFrame::new`] and reused by
    /// [`ResidentArrowFrame::solve_gradient`] for different gradients.
    pub(super) struct ResidentArrowFrame {
        /// Number of rows in the system (`sys.rows.len()` at construction).
        n: usize,
        /// Per-row local dimension (`sys.d` at construction).
        d: usize,
        /// Shared-coefficient dimension (`sys.k` at construction).
        k: usize,
        /// CUDA stream on which every resident buffer was created and is used.
        stream: Arc<CudaStream>,
        /// cuBLAS handle created on `stream`.
        blas: CudaBlas,
        /// First `pack_host` output after in-place batched Cholesky (`potrf_batched`);
        /// one `d x d` block per row.
        l_dev: CudaSlice<f64>,
        /// Second `pack_host` output after `trsm_batched_lower_inplace` against `l_dev`.
        y_dev: CudaSlice<f64>,
        /// `H_bb + ridge_beta * I` (column-major, `k x k`) after `accumulate_schur`
        /// and an in-place `potrf_single`.
        schur_dev: CudaSlice<f64>,
        /// Twice the sum of the logs of the diagonal entries of the factors held in
        /// `l_dev` and `schur_dev`, computed once at construction.
        log_det_hessian: f64,
    }
3172
3173 impl ResidentArrowFrame {
3174 pub(super) fn new(
3178 sys: &ArrowSchurSystem,
3179 ridge_t: f64,
3180 ridge_beta: f64,
3181 ) -> Result<Self, ArrowSchurGpuFailure> {
3182 if ridge_t.is_nan() || ridge_beta.is_nan() {
3183 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3184 reason: "ridge is NaN".to_string(),
3185 });
3186 }
3187 let n = sys.rows.len();
3188 let d = sys.d;
3189 let k = sys.k;
3190 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3191 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3192 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3193 .and_then(|ctx| ctx.new_stream().ok())
3194 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3195 let solver =
3196 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3197 let blas =
3198 CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3199
3200 let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3202 let mut l_dev = stream
3203 .clone_htod(&d_host)
3204 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3205 let mut y_dev = stream
3206 .clone_htod(&b_host)
3207 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3208
3209 let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3211 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3212 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3216 row: idx,
3217 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3218 });
3219 }
3220
3221 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3223
3224 let schur_init: Vec<f64> = {
3229 let mut tmp = Vec::with_capacity(k * k);
3230 for col in 0..k {
3231 for row in 0..k {
3232 let mut v = sys.hbb[[row, col]];
3233 if row == col {
3234 v += ridge_beta;
3235 }
3236 tmp.push(v);
3237 }
3238 }
3239 tmp
3240 };
3241 let mut schur_dev = stream
3242 .clone_htod(&schur_init)
3243 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3244 let zero_u = stream
3247 .clone_htod(&vec![0.0_f64; n * d])
3248 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3249 let mut throwaway_rhs = stream
3250 .clone_htod(&vec![0.0_f64; k])
3251 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3252 accumulate_schur(
3253 &blas,
3254 d,
3255 k,
3256 n,
3257 &y_dev,
3258 &zero_u,
3259 &mut schur_dev,
3260 &mut throwaway_rhs,
3261 )?;
3262 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3263 if info != 0 {
3264 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3265 reason: format!("Schur Cholesky failed at pivot {info}"),
3266 });
3267 }
3268
3269 let l_local_host = stream
3271 .clone_dtoh(&l_dev)
3272 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3273 let l_schur_host = stream
3274 .clone_dtoh(&schur_dev)
3275 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3276 let mut log_det = 0.0_f64;
3277 for i in 0..n {
3278 let base = i * d * d;
3279 for j in 0..d {
3280 log_det += l_local_host[base + j * d + j].ln();
3281 }
3282 }
3283 for j in 0..k {
3284 log_det += l_schur_host[j * k + j].ln();
3285 }
3286 log_det *= 2.0;
3287
3288 Ok(Self {
3289 n,
3290 d,
3291 k,
3292 stream,
3293 blas,
3294 l_dev,
3295 y_dev,
3296 schur_dev,
3297 log_det_hessian: log_det,
3298 })
3299 }
3300
3301 #[inline]
3302 pub(super) fn log_det_hessian(&self) -> f64 {
3303 self.log_det_hessian
3304 }
3305
3306 pub(super) fn solve_gradient(
3310 &self,
3311 g_t: &[f64],
3312 g_beta: &[f64],
3313 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3314 let n = self.n;
3315 let d = self.d;
3316 let k = self.k;
3317 if g_t.len() != n * d || g_beta.len() != k {
3318 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3319 reason: format!(
3320 "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3321 g_t.len(),
3322 n * d,
3323 g_beta.len(),
3324 k
3325 ),
3326 });
3327 }
3328 let mut u_dev = self
3330 .stream
3331 .clone_htod(&g_t.to_vec())
3332 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3333 trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3334
3335 let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3338 let mut rhs_dev = self
3339 .stream
3340 .clone_htod(&rhs_init)
3341 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3342 accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3343
3344 trsm_single(
3346 &self.blas,
3347 &self.stream,
3348 k,
3349 &self.schur_dev,
3350 &mut rhs_dev,
3351 false,
3352 false,
3353 )?;
3354 trsm_single(
3355 &self.blas,
3356 &self.stream,
3357 k,
3358 &self.schur_dev,
3359 &mut rhs_dev,
3360 false,
3361 true,
3362 )?;
3363 let delta_beta_host = self
3364 .stream
3365 .clone_dtoh(&rhs_dev)
3366 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3367 let delta_beta = Array1::from_vec(delta_beta_host);
3368
3369 accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3371 trsm_batched_lower_inplace_transposed(
3372 &self.blas,
3373 &self.stream,
3374 d,
3375 n,
3376 1,
3377 &self.l_dev,
3378 &mut u_dev,
3379 )?;
3380 let x_host = self
3381 .stream
3382 .clone_dtoh(&u_dev)
3383 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3384 let mut delta_t = Array1::<f64>::zeros(n * d);
3385 for (i, v) in x_host.iter().enumerate() {
3386 delta_t[i] = -*v;
3387 }
3388
3389 Ok(ArrowSchurGpuSolution {
3390 delta_t,
3391 delta_beta,
3392 log_det_hessian: self.log_det_hessian,
3393 })
3394 }
3395 }
3396
3397 pub(super) fn solve_fused(
3398 sys: &ArrowSchurSystem,
3399 ridge_t: f64,
3400 ridge_beta: f64,
3401 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3402 let n = sys.rows.len();
3403 let d = sys.d;
3404 let k = sys.k;
3405 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3406 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3407 let p_max = plan.p_max;
3408 let r_template = plan.r_template;
3409
3410 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3411 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3412 )
3413 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3414 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3415 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3416 let stream = ctx
3417 .new_stream()
3418 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3419 let cap = &runtime.device.capability;
3420 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3421 cc_major: cap.compute_major,
3422 cc_minor: cap.compute_minor,
3423 p_max: p_max as u32,
3424 r_template: r_template as u32,
3425 };
3426 let module = fused_module_for(&ctx, key)?;
3427 let forward = module
3428 .load_function("arrow_schur_forward_pgroup")
3429 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3430 let back_sub = module
3431 .load_function("arrow_schur_back_sub_pgroup")
3432 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3433
3434 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3436 let d_dev = stream
3437 .clone_htod(&d_host)
3438 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3439 let b_dev = stream
3440 .clone_htod(&b_host)
3441 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3442 let g_dev = stream
3443 .clone_htod(&g_host)
3444 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445 let mut l_out = stream
3446 .alloc_zeros::<f64>(n * p_max * p_max)
3447 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3448 let mut u_out = stream
3449 .alloc_zeros::<f64>(n * p_max)
3450 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3451 let mut y_out = stream
3452 .alloc_zeros::<f64>(n * p_max * r_template)
3453 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3454 let mut partial_s = stream
3455 .alloc_zeros::<f64>(plan.partial_s_doubles)
3456 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457 let mut partial_r = stream
3458 .alloc_zeros::<f64>(plan.partial_r_doubles)
3459 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3460 let mut status_dev = stream
3461 .alloc_zeros::<i32>(n)
3462 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463
3464 let cfg = LaunchConfig {
3466 grid_dim: (plan.blocks, 1, 1),
3467 block_dim: (plan.threads_per_block, 1, 1),
3468 shared_mem_bytes: 0,
3469 };
3470 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3471 let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3472 let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3473 let ridge_arg = ridge_t;
3474 {
3475 let mut builder = stream.launch_builder(&forward);
3476 builder
3477 .arg(&d_dev)
3478 .arg(&b_dev)
3479 .arg(&g_dev)
3480 .arg(&n_i32)
3481 .arg(&p_i32)
3482 .arg(&r_i32)
3483 .arg(&ridge_arg)
3484 .arg(&mut l_out)
3485 .arg(&mut u_out)
3486 .arg(&mut y_out)
3487 .arg(&mut partial_s)
3488 .arg(&mut partial_r)
3489 .arg(&mut status_dev);
3490 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3494 }
3495 stream
3496 .synchronize()
3497 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3498
3499 let status_host = stream
3501 .clone_dtoh(&status_dev)
3502 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3503 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3504 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3508 row,
3509 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3510 });
3511 }
3512
3513 let partial_s_host = stream
3515 .clone_dtoh(&partial_s)
3516 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3517 let partial_r_host = stream
3518 .clone_dtoh(&partial_r)
3519 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3520 let mut schur_host = vec![0.0_f64; k * k];
3521 for col in 0..k {
3522 for row in 0..k {
3523 let mut v = sys.hbb[[row, col]];
3524 if row == col {
3525 v += ridge_beta;
3526 }
3527 schur_host[col * k + row] = v;
3528 }
3529 }
3530 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3531 for i in 0..n {
3532 let s_base = i * r_template * r_template;
3535 for col in 0..k {
3536 let col_base = s_base + col * r_template;
3537 let dst_col_base = col * k;
3538 for row in 0..k {
3539 schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3540 }
3541 }
3542 let r_base = i * r_template;
3543 for a in 0..k {
3544 rhs_host[a] += partial_r_host[r_base + a];
3545 }
3546 }
3547
3548 let mut schur_dev = stream
3550 .clone_htod(&schur_host)
3551 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3552 let mut rhs_dev = stream
3553 .clone_htod(&rhs_host)
3554 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3555 let solver =
3556 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3557 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3558 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3559 if info != 0 {
3560 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3561 reason: format!("fused Schur Cholesky failed at pivot {info}"),
3562 });
3563 }
3564 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3565 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3566 let delta_beta_host = stream
3567 .clone_dtoh(&rhs_dev)
3568 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3569 let delta_beta = Array1::from_vec(delta_beta_host.clone());
3570
3571 let mut delta_t_dev = stream
3573 .alloc_zeros::<f64>(n * p_max)
3574 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3575 let back_cfg = LaunchConfig {
3576 grid_dim: (plan.blocks, 1, 1),
3577 block_dim: (plan.threads_per_block, 1, 1),
3578 shared_mem_bytes: 0,
3579 };
3580 {
3581 let mut builder = stream.launch_builder(&back_sub);
3582 builder
3583 .arg(&l_out)
3584 .arg(&u_out)
3585 .arg(&y_out)
3586 .arg(&rhs_dev)
3587 .arg(&n_i32)
3588 .arg(&p_i32)
3589 .arg(&r_i32)
3590 .arg(&mut delta_t_dev);
3591 unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3595 }
3596 stream
3597 .synchronize()
3598 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3599
3600 let delta_t_host = stream
3601 .clone_dtoh(&delta_t_dev)
3602 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3603 let mut delta_t = Array1::<f64>::zeros(n * d);
3604 for i in 0..n {
3605 let src_base = i * p_max;
3606 let dst_base = i * d;
3607 for r in 0..d {
3608 delta_t[dst_base + r] = delta_t_host[src_base + r];
3609 }
3610 }
3611
3612 let l_local_host = stream
3614 .clone_dtoh(&l_out)
3615 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3616 let l_schur_host = stream
3617 .clone_dtoh(&schur_dev)
3618 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3619 let mut log_det = 0.0_f64;
3620 for i in 0..n {
3621 let base = i * p_max * p_max;
3622 for j in 0..d {
3623 log_det += l_local_host[base + j * p_max + j].ln();
3624 }
3625 }
3626 for j in 0..k {
3627 log_det += l_schur_host[j * k + j].ln();
3628 }
3629 log_det *= 2.0;
3630
3631 Ok(ArrowSchurGpuSolution {
3632 delta_t,
3633 delta_beta,
3634 log_det_hessian: log_det,
3635 })
3636 }
3637
3638 pub(super) fn build_schur_matvec_backend(
3648 sys: &ArrowSchurSystem,
3649 ridge_t: f64,
3650 ridge_beta: f64,
3651 ) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
3652 let n = sys.rows.len();
3653 let d = sys.d;
3654 let k = sys.k;
3655 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3656 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3657 let p_max = plan.p_max;
3658 let r_template = plan.r_template;
3659
3660 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3661 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3662 )
3663 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3664 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3665 .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3666 let stream = ctx
3667 .new_stream()
3668 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3669 let cap = &runtime.device.capability;
3670 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3671 cc_major: cap.compute_major,
3672 cc_minor: cap.compute_minor,
3673 p_max: p_max as u32,
3674 r_template: r_template as u32,
3675 };
3676 let module = fused_module_for(&ctx, key)?;
3677 let forward = module
3678 .load_function("arrow_schur_forward_pgroup")
3679 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3680
3681 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3682 let d_dev = stream
3683 .clone_htod(&d_host)
3684 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3685 let b_dev = stream
3686 .clone_htod(&b_host)
3687 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3688 let g_dev = stream
3689 .clone_htod(&g_host)
3690 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3691 let mut l_out = stream
3692 .alloc_zeros::<f64>(n * p_max * p_max)
3693 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3694 let mut u_out = stream
3695 .alloc_zeros::<f64>(n * p_max)
3696 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3697 let mut y_out = stream
3698 .alloc_zeros::<f64>(n * p_max * r_template)
3699 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3700 let mut partial_s = stream
3701 .alloc_zeros::<f64>(plan.partial_s_doubles)
3702 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3703 let mut partial_r = stream
3704 .alloc_zeros::<f64>(plan.partial_r_doubles)
3705 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3706 let mut status_dev = stream
3707 .alloc_zeros::<i32>(n)
3708 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3709
3710 let cfg = LaunchConfig {
3711 grid_dim: (plan.blocks, 1, 1),
3712 block_dim: (plan.threads_per_block, 1, 1),
3713 shared_mem_bytes: 0,
3714 };
3715 let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3716 let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3717 let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
3718 let ridge_arg = ridge_t;
3719 {
3720 let mut builder = stream.launch_builder(&forward);
3721 builder
3722 .arg(&d_dev)
3723 .arg(&b_dev)
3724 .arg(&g_dev)
3725 .arg(&n_i32)
3726 .arg(&p_i32)
3727 .arg(&r_i32)
3728 .arg(&ridge_arg)
3729 .arg(&mut l_out)
3730 .arg(&mut u_out)
3731 .arg(&mut y_out)
3732 .arg(&mut partial_s)
3733 .arg(&mut partial_r)
3734 .arg(&mut status_dev);
3735 unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3738 }
3739 stream
3740 .synchronize()
3741 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3742
3743 let status_host = stream
3744 .clone_dtoh(&status_dev)
3745 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3746 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3747 return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
3751 row,
3752 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3753 });
3754 }
3755
3756 let y_host = stream
3758 .clone_dtoh(&y_out)
3759 .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
3760
3761 let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
3764 let hbb_is_kk = sys.hbb.dim() == (k, k);
3765 let hbb_matvec_opt = sys.hbb_matvec.clone();
3766
3767 let closure: crate::arrow_schur::GpuSchurMatvec =
3768 Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
3769 assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
3770 assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");
3771
3772 if let Some(ref mv) = hbb_matvec_opt {
3774 mv(x.view(), out);
3775 for a in 0..k {
3776 out[a] += ridge_beta * x[a];
3777 }
3778 } else if hbb_is_kk {
3779 for a in 0..k {
3781 let mut acc = ridge_beta * x[a];
3782 for b in 0..k {
3783 acc += hbb_host[a * k + b] * x[b];
3784 }
3785 out[a] = acc;
3786 }
3787 } else {
3788 for a in 0..k {
3789 out[a] = ridge_beta * x[a];
3790 }
3791 }
3792
3793 let mut z = vec![0.0_f64; d];
3796 for i in 0..n {
3797 let y_base = i * p_max * r_template;
3798 for r in 0..d {
3799 let mut acc = 0.0;
3800 for c in 0..k {
3801 acc += y_host[y_base + c * p_max + r] * x[c];
3802 }
3803 z[r] = acc;
3804 }
3805 for c in 0..k {
3806 let mut acc = 0.0;
3807 for r in 0..d {
3808 acc += y_host[y_base + c * p_max + r] * z[r];
3809 }
3810 out[c] -= acc;
3811 }
3812 }
3813 });
3814
3815 Ok(closure)
3816 }
3817
3818 struct DeviceSaeFrameBuffers {
3821 s_off: CudaSlice<i32>,
3823 s_m: CudaSlice<i32>,
3824 s_r: CudaSlice<i32>,
3825 s_ptr: CudaSlice<i32>,
3826 s_data: CudaSlice<f64>,
3827 s_blocks: usize,
3828 g_off_i: CudaSlice<i32>,
3830 g_off_j: CudaSlice<i32>,
3831 g_ri: CudaSlice<i32>,
3832 g_rj: CudaSlice<i32>,
3833 g_mi: CudaSlice<i32>,
3834 g_mj: CudaSlice<i32>,
3835 g_ptr: CudaSlice<i32>,
3836 g_data: CudaSlice<f64>,
3837 w_ptr: CudaSlice<i32>,
3838 w_data: CudaSlice<f64>,
3839 g_blocks: usize,
3840 g_max_work: usize,
3841 htb_ptr: CudaSlice<i32>,
3843 htb: CudaSlice<f64>,
3844 q_of: CudaSlice<i32>,
3845 ainv: CudaSlice<f64>,
3846 hvec: CudaSlice<f64>,
3847 svec: CudaSlice<f64>,
3848 n_rows: usize,
3849 k: usize,
3850 max_q: usize,
3851 }
3852
3853 fn flatten_device_sae_frame_data(
3854 sys: &ArrowSchurSystem,
3855 data: &DeviceSaePcgData,
3856 frame: &DeviceSaeFrameData,
3857 ridge_t: f64,
3858 stream: &Arc<CudaStream>,
3859 ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3860 let n_rows = sys.rows.len();
3861 let k = data.beta_dim;
3862 if frame.row_htbeta.len() != n_rows
3863 || frame.ranks.len() != frame.basis_sizes.len()
3864 || frame.border_offsets.len() != frame.ranks.len()
3865 || data.smooth_blocks.len() != frame.smooth_ranks.len()
3866 {
3867 return Err(ArrowSchurGpuFailure::Unavailable);
3868 }
3869
3870 let mut s_off = Vec::new();
3872 let mut s_m = Vec::new();
3873 let mut s_r = Vec::new();
3874 let mut s_ptr = vec![0_i32];
3875 let mut s_data = Vec::<f64>::new();
3876 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3877 let (m, mc) = blk.factor_a.dim();
3878 if m != mc {
3879 return Err(ArrowSchurGpuFailure::Unavailable);
3880 }
3881 s_off.push(checked_i32(blk.global_offset)?);
3882 s_m.push(checked_i32(m)?);
3883 s_r.push(checked_i32(r)?);
3884 for ri in 0..m {
3885 for ci in 0..m {
3886 s_data.push(blk.factor_a[[ri, ci]]);
3887 }
3888 }
3889 s_ptr.push(checked_i32(s_data.len())?);
3890 }
3891
3892 let mut g_off_i = Vec::new();
3894 let mut g_off_j = Vec::new();
3895 let mut g_ri = Vec::new();
3896 let mut g_rj = Vec::new();
3897 let mut g_mi = Vec::new();
3898 let mut g_mj = Vec::new();
3899 let mut g_ptr = vec![0_i32];
3900 let mut g_data = Vec::<f64>::new();
3901 let mut w_ptr = vec![0_i32];
3902 let mut w_data = Vec::<f64>::new();
3903 let mut g_max_work = 0usize;
3904 for blk in &frame.frame_blocks {
3905 let ri = frame.ranks[blk.atom_i];
3906 let rj = frame.ranks[blk.atom_j];
3907 let (mi, mj) = blk.g.dim();
3908 if blk.w.dim() != (ri, rj) {
3909 return Err(ArrowSchurGpuFailure::Unavailable);
3910 }
3911 g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3912 g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3913 g_ri.push(checked_i32(ri)?);
3914 g_rj.push(checked_i32(rj)?);
3915 g_mi.push(checked_i32(mi)?);
3916 g_mj.push(checked_i32(mj)?);
3917 for r in 0..mi {
3918 for c in 0..mj {
3919 g_data.push(blk.g[[r, c]]);
3920 }
3921 }
3922 g_ptr.push(checked_i32(g_data.len())?);
3923 for a in 0..ri {
3924 for b in 0..rj {
3925 w_data.push(blk.w[[a, b]]);
3926 }
3927 }
3928 w_ptr.push(checked_i32(w_data.len())?);
3929 g_max_work = g_max_work.max(mi * ri);
3930 }
3931
3932 let mut htb_ptr = vec![0_i32];
3934 let mut htb = Vec::<f64>::new();
3935 let mut q_of = Vec::<i32>::with_capacity(n_rows);
3936 let mut max_q = 0usize;
3937 for (i, slab) in frame.row_htbeta.iter().enumerate() {
3938 let qi = sys.row_dims[i];
3939 let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3942 qi
3943 } else {
3944 0
3945 };
3946 q_of.push(checked_i32(q_eff)?);
3947 max_q = max_q.max(q_eff);
3948 if q_eff > 0 {
3949 htb.extend_from_slice(slab);
3950 }
3951 htb_ptr.push(checked_i32(htb.len())?);
3952 }
3953 if max_q == 0 {
3954 max_q = 1;
3957 }
3958
3959 let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3960 for (i, row) in sys.rows.iter().enumerate() {
3961 let q = q_of[i] as usize;
3962 if q == 0 {
3963 continue;
3964 }
3965 if row.htt.dim() != (q, q) {
3966 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3967 reason: format!(
3968 "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3969 row.htt.dim()
3970 ),
3971 });
3972 }
3973 let mut block = row.htt.clone();
3974 for d in 0..q {
3975 block[[d, d]] += ridge_t;
3976 }
3977 let factor = gam_linalg::triangular::cholesky_factor_in_place(
3978 block.view(),
3979 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3980 )
3981 .ok_or_else(|| {
3982 ArrowSchurGpuFailure::RidgeBumpRequired {
3985 row: i,
3986 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3987 }
3988 })?;
3989 for col in 0..q {
3990 let mut e = Array1::<f64>::zeros(q);
3991 e[col] = 1.0;
3992 let solved =
3993 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3994 for r in 0..q {
3995 ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3996 }
3997 }
3998 }
3999
4000 let htod_i = |v: &[i32]| {
4001 stream
4002 .clone_htod(v)
4003 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4004 };
4005 let htod_f = |v: &[f64]| {
4006 stream
4007 .clone_htod(v)
4008 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4009 };
4010 Ok(DeviceSaeFrameBuffers {
4011 s_off: htod_i(&s_off)?,
4012 s_m: htod_i(&s_m)?,
4013 s_r: htod_i(&s_r)?,
4014 s_ptr: htod_i(&s_ptr)?,
4015 s_data: htod_f(&s_data)?,
4016 s_blocks: data.smooth_blocks.len(),
4017 g_off_i: htod_i(&g_off_i)?,
4018 g_off_j: htod_i(&g_off_j)?,
4019 g_ri: htod_i(&g_ri)?,
4020 g_rj: htod_i(&g_rj)?,
4021 g_mi: htod_i(&g_mi)?,
4022 g_mj: htod_i(&g_mj)?,
4023 g_ptr: htod_i(&g_ptr)?,
4024 g_data: htod_f(&g_data)?,
4025 w_ptr: htod_i(&w_ptr)?,
4026 w_data: htod_f(&w_data)?,
4027 g_blocks: frame.frame_blocks.len(),
4028 g_max_work,
4029 htb_ptr: htod_i(&htb_ptr)?,
4030 htb: htod_f(&htb)?,
4031 q_of: htod_i(&q_of)?,
4032 ainv: htod_f(&ainv)?,
4033 hvec: stream
4034 .alloc_zeros::<f64>(n_rows * max_q)
4035 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4036 svec: stream
4037 .alloc_zeros::<f64>(n_rows * max_q)
4038 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4039 n_rows,
4040 k,
4041 max_q,
4042 })
4043 }
4044
4045 fn sae_frame_penalty_diag_host(
4046 data: &DeviceSaePcgData,
4047 frame: &DeviceSaeFrameData,
4048 ridge_beta: f64,
4049 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4050 let mut diag = vec![ridge_beta; data.beta_dim];
4051 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4053 let m = blk.factor_a.nrows();
4054 for ia in 0..m {
4055 let coeff = blk.factor_a[[ia, ia]];
4056 let base = blk.global_offset + ia * r;
4057 for ib in 0..r {
4058 if base + ib >= diag.len() {
4059 return Err(ArrowSchurGpuFailure::Unavailable);
4060 }
4061 diag[base + ib] += coeff;
4062 }
4063 }
4064 }
4065 for blk in &frame.frame_blocks {
4067 if blk.atom_i != blk.atom_j {
4068 continue;
4069 }
4070 let r = frame.ranks[blk.atom_i];
4071 let off = frame.border_offsets[blk.atom_i];
4072 let (mi, mj) = blk.g.dim();
4073 for li in 0..mi.min(mj) {
4074 let gii = blk.g[[li, li]];
4075 let base = off + li * r;
4076 for a in 0..r {
4077 if base + a >= diag.len() {
4078 return Err(ArrowSchurGpuFailure::Unavailable);
4079 }
4080 diag[base + a] += gii * blk.w[[a, a]];
4081 }
4082 }
4083 }
4084 Ok(diag)
4085 }
4086
4087 fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4088 Ok(LaunchConfig {
4089 grid_dim: (
4090 ((work as u32).saturating_add(255) / 256).max(1),
4091 checked_i32(n_rows)? as u32,
4092 1,
4093 ),
4094 block_dim: (256, 1, 1),
4095 shared_mem_bytes: 0,
4096 })
4097 }
4098
4099 fn launch_sae_frame_matvec(
4100 stream: &Arc<CudaStream>,
4101 module: &Arc<CudaModule>,
4102 buffers: &mut DeviceSaeFrameBuffers,
4103 x: &CudaSlice<f64>,
4104 out: &mut CudaSlice<f64>,
4105 ridge_beta: f64,
4106 ) -> Result<(), ArrowSchurGpuFailure> {
4107 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4108 if buffers.s_blocks > 0 {
4110 let kernel = module
4111 .load_function("arrow_sae_frame_smooth_matvec")
4112 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4113 let blocks_i32 = checked_i32(buffers.s_blocks)?;
4114 let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4115 let mut b = stream.launch_builder(&kernel);
4116 b.arg(x)
4117 .arg(&mut *out)
4118 .arg(&buffers.s_off)
4119 .arg(&buffers.s_m)
4120 .arg(&buffers.s_r)
4121 .arg(&buffers.s_ptr)
4122 .arg(&buffers.s_data)
4123 .arg(&blocks_i32);
4124 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4127 }
4128 if buffers.g_blocks > 0 {
4130 let kernel = module
4131 .load_function("arrow_sae_frame_g_matvec")
4132 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4133 let blocks_i32 = checked_i32(buffers.g_blocks)?;
4134 let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4135 let mut b = stream.launch_builder(&kernel);
4136 b.arg(x)
4137 .arg(&mut *out)
4138 .arg(&buffers.g_off_i)
4139 .arg(&buffers.g_off_j)
4140 .arg(&buffers.g_ri)
4141 .arg(&buffers.g_rj)
4142 .arg(&buffers.g_mi)
4143 .arg(&buffers.g_mj)
4144 .arg(&buffers.g_ptr)
4145 .arg(&buffers.g_data)
4146 .arg(&buffers.w_ptr)
4147 .arg(&buffers.w_data)
4148 .arg(&blocks_i32);
4149 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4152 }
4153 let k_i32 = checked_i32(buffers.k)?;
4155 let max_q_i32 = checked_i32(buffers.max_q)?;
4156 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4157 {
4158 let kernel = module
4159 .load_function("arrow_sae_frame_apply_h")
4160 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4161 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4162 let mut b = stream.launch_builder(&kernel);
4163 b.arg(x)
4164 .arg(&buffers.htb_ptr)
4165 .arg(&buffers.htb)
4166 .arg(&buffers.q_of)
4167 .arg(&mut buffers.hvec)
4168 .arg(&k_i32)
4169 .arg(&max_q_i32)
4170 .arg(&n_rows_i32);
4171 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4174 }
4175 {
4176 let kernel = module
4177 .load_function("arrow_sae_frame_apply_ainv")
4178 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4179 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4180 let mut b = stream.launch_builder(&kernel);
4181 b.arg(&buffers.ainv)
4182 .arg(&buffers.hvec)
4183 .arg(&buffers.q_of)
4184 .arg(&mut buffers.svec)
4185 .arg(&max_q_i32)
4186 .arg(&n_rows_i32);
4187 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4190 }
4191 {
4192 let kernel = module
4193 .load_function("arrow_sae_frame_scatter_h")
4194 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4195 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4196 let mut b = stream.launch_builder(&kernel);
4197 b.arg(&buffers.svec)
4198 .arg(&buffers.htb_ptr)
4199 .arg(&buffers.htb)
4200 .arg(&buffers.q_of)
4201 .arg(out)
4202 .arg(&k_i32)
4203 .arg(&max_q_i32)
4204 .arg(&n_rows_i32);
4205 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4208 }
4209 Ok(())
4210 }
4211
4212 fn launch_sae_frame_diag_sub(
4213 stream: &Arc<CudaStream>,
4214 module: &Arc<CudaModule>,
4215 buffers: &DeviceSaeFrameBuffers,
4216 diag: &mut CudaSlice<f64>,
4217 ) -> Result<(), ArrowSchurGpuFailure> {
4218 let kernel = module
4219 .load_function("arrow_sae_frame_diag_sub")
4220 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4221 let k_i32 = checked_i32(buffers.k)?;
4222 let max_q_i32 = checked_i32(buffers.max_q)?;
4223 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4224 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4225 let mut b = stream.launch_builder(&kernel);
4226 b.arg(diag)
4227 .arg(&buffers.ainv)
4228 .arg(&buffers.htb_ptr)
4229 .arg(&buffers.htb)
4230 .arg(&buffers.q_of)
4231 .arg(&k_i32)
4232 .arg(&max_q_i32)
4233 .arg(&n_rows_i32);
4234 unsafe { b.launch(cfg) }
4236 .map(drop)
4237 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4238 }
4239
4240 pub(super) fn framed_schur_matvec_once_on_device(
4251 sys: &ArrowSchurSystem,
4252 data: &DeviceSaePcgData,
4253 ridge_t: f64,
4254 ridge_beta: f64,
4255 x: &Array1<f64>,
4256 ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4257 let k = x.len();
4258 if k == 0 || data.beta_dim != k || sys.k != k {
4259 return Err(ArrowSchurGpuFailure::Unavailable);
4260 }
4261 let frame = data
4262 .frame
4263 .as_ref()
4264 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4265 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4268 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4269 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4270 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4271 let stream = ctx
4272 .new_stream()
4273 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4274 let vector_module = pcg_vector_module(&ctx)?;
4275 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4276 let x_dev = stream
4277 .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4278 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4279 let mut out_dev = stream
4280 .alloc_zeros::<f64>(k)
4281 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4282 launch_sae_frame_matvec(
4283 &stream,
4284 vector_module,
4285 &mut buffers,
4286 &x_dev,
4287 &mut out_dev,
4288 ridge_beta,
4289 )?;
4290 let out = stream
4291 .clone_dtoh(&out_dev)
4292 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4293 Ok(Array1::from_vec(out))
4294 }
4295
4296 pub(super) fn solve_sae_matrix_free_pcg_framed(
4297 sys: &ArrowSchurSystem,
4298 data: &DeviceSaePcgData,
4299 ridge_t: f64,
4300 ridge_beta: f64,
4301 rhs_beta: &Array1<f64>,
4302 max_iterations: usize,
4303 relative_tolerance: f64,
4304 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4305 let k = rhs_beta.len();
4306 if k == 0 || data.beta_dim != k || sys.k != k {
4307 return Err(ArrowSchurGpuFailure::Unavailable);
4308 }
4309 let frame = data
4310 .frame
4311 .as_ref()
4312 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4313 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4314 .filter(|rt| {
4315 rt.policy().reduced_schur_matvec_should_offload(
4316 sys.rows.len(),
4317 sys.k,
4318 sys.d,
4319 max_iterations,
4320 )
4321 })
4322 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4323 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4324 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4325 let stream = ctx
4326 .new_stream()
4327 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4328 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4329 let vector_module = pcg_vector_module(&ctx)?;
4330 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4331
4332 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4333 if rhs_norm == 0.0 {
4334 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4335 }
4336 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4337 let rhs_dev = stream
4338 .clone_htod(
4339 rhs_beta
4340 .as_slice()
4341 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4342 )
4343 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4344 let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
4345 let mut diag_dev = stream
4346 .clone_htod(&diag_host)
4347 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4348 launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4349 let diag_host = stream
4350 .clone_dtoh(&diag_dev)
4351 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4352 let mut inv_diag = Vec::with_capacity(k);
4353 for (idx, &d) in diag_host.iter().enumerate() {
4354 if !d.is_finite() || d <= 1.0e-18 {
4355 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4356 reason: format!(
4357 "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
4358 ),
4359 });
4360 }
4361 inv_diag.push(1.0 / d);
4362 }
4363 let inv_diag_dev = stream
4364 .clone_htod(&inv_diag)
4365 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4366
4367 let mut x_dev = stream
4368 .alloc_zeros::<f64>(k)
4369 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4370 let mut r_dev = stream
4371 .alloc_zeros::<f64>(k)
4372 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4373 device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4374 let mut z_dev = stream
4375 .alloc_zeros::<f64>(k)
4376 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4377 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4378 let mut p_dev = stream
4379 .alloc_zeros::<f64>(k)
4380 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4381 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4382 let mut ap_dev = stream
4383 .alloc_zeros::<f64>(k)
4384 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4385
4386 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4387 if rz <= 0.0 || !rz.is_finite() {
4388 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4389 reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4390 });
4391 }
4392 let mut diag = PcgDiagnostics {
4393 precond_apply_calls: 1,
4394 stopping_reason: PcgStopReason::MaxIter,
4395 ..PcgDiagnostics::default()
4396 };
4397 for _ in 0..max_iterations.max(1) {
4398 launch_sae_frame_matvec(
4399 &stream,
4400 vector_module,
4401 &mut buffers,
4402 &p_dev,
4403 &mut ap_dev,
4404 ridge_beta,
4405 )?;
4406 diag.matvec_calls += 1;
4407 diag.iterations += 1;
4408 let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4409 if pap <= 0.0 || !pap.is_finite() {
4410 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4411 reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4412 });
4413 }
4414 let alpha = rz / pap;
4415 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4416 device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4417 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4418 if r_norm <= tol {
4419 diag.final_relative_residual = r_norm / rhs_norm;
4420 diag.stopping_reason = PcgStopReason::Converged;
4421 break;
4422 }
4423 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4424 diag.precond_apply_calls += 1;
4425 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4426 if rz_new <= 0.0 || !rz_new.is_finite() {
4427 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4428 reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4429 });
4430 }
4431 let beta = rz_new / rz;
4432 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4433 rz = rz_new;
4434 }
4435 if diag.stopping_reason != PcgStopReason::Converged {
4436 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4437 diag.final_relative_residual = r_norm / rhs_norm;
4438 diag.stopping_reason = PcgStopReason::MaxIter;
4439 }
4440 let x = stream
4441 .clone_dtoh(&x_dev)
4442 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4443 Ok((Array1::from_vec(x), diag))
4444 }
4445
4446 pub(super) fn solve_sae_matrix_free_pcg(
4453 sys: &ArrowSchurSystem,
4454 data: &DeviceSaePcgData,
4455 ridge_t: f64,
4456 ridge_beta: f64,
4457 rhs_beta: &Array1<f64>,
4458 max_iterations: usize,
4459 relative_tolerance: f64,
4460 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4461 let k = rhs_beta.len();
4462 if k == 0 || data.beta_dim != k || sys.k != k {
4463 return Err(ArrowSchurGpuFailure::Unavailable);
4464 }
4465 if data.frame.is_some() {
4469 return Err(ArrowSchurGpuFailure::Unavailable);
4470 }
4471 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4485 .filter(|rt| {
4486 rt.policy().reduced_schur_matvec_should_offload(
4487 sys.rows.len(),
4488 sys.k,
4489 sys.d,
4490 max_iterations,
4491 )
4492 })
4493 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4494 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4495 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4496 let stream = ctx
4497 .new_stream()
4498 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4499 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4500 let vector_module = pcg_vector_module(&ctx)?;
4501 let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;
4502
4503 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4504 if rhs_norm == 0.0 {
4505 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4506 }
4507 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4508 let rhs_dev = stream
4509 .clone_htod(
4510 rhs_beta
4511 .as_slice()
4512 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4513 )
4514 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4515 let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
4516 let mut diag_dev = stream
4517 .clone_htod(&diag_host)
4518 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4519 launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
4520 let diag_host = stream
4521 .clone_dtoh(&diag_dev)
4522 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4523 let mut inv_diag = Vec::with_capacity(k);
4524 for (idx, &d) in diag_host.iter().enumerate() {
4525 if !d.is_finite() || d <= 1.0e-18 {
4526 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4527 reason: format!(
4528 "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
4529 ),
4530 });
4531 }
4532 inv_diag.push(1.0 / d);
4533 }
4534 let inv_diag_dev = stream
4535 .clone_htod(&inv_diag)
4536 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4537
4538 let mut x_dev = stream
4539 .alloc_zeros::<f64>(k)
4540 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4541 let mut r_dev = stream
4542 .alloc_zeros::<f64>(k)
4543 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4544 device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
4545 let mut z_dev = stream
4546 .alloc_zeros::<f64>(k)
4547 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4548 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4549 let mut p_dev = stream
4550 .alloc_zeros::<f64>(k)
4551 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4552 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4553 let mut ap_dev = stream
4554 .alloc_zeros::<f64>(k)
4555 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4556
4557 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4558 if rz <= 0.0 || !rz.is_finite() {
4559 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4560 reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4561 });
4562 }
4563 let mut diag = PcgDiagnostics {
4564 precond_apply_calls: 1,
4565 stopping_reason: PcgStopReason::MaxIter,
4566 ..PcgDiagnostics::default()
4567 };
4568
4569 for _ in 0..max_iterations.max(1) {
4570 launch_sae_matvec(
4571 &stream,
4572 vector_module,
4573 &mut buffers,
4574 &p_dev,
4575 &mut ap_dev,
4576 ridge_beta,
4577 )?;
4578 diag.matvec_calls += 1;
4579 diag.iterations += 1;
4580 let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
4581 if pap <= 0.0 || !pap.is_finite() {
4582 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4583 reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
4584 });
4585 }
4586 let alpha = rz / pap;
4587 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4588 device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
4589 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4590 if r_norm <= tol {
4591 diag.final_relative_residual = r_norm / rhs_norm;
4592 diag.stopping_reason = PcgStopReason::Converged;
4593 break;
4594 }
4595 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4596 diag.precond_apply_calls += 1;
4597 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4598 if rz_new <= 0.0 || !rz_new.is_finite() {
4599 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4600 reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4601 });
4602 }
4603 let beta = rz_new / rz;
4604 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4605 rz = rz_new;
4606 }
4607 if diag.stopping_reason != PcgStopReason::Converged {
4608 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4609 diag.final_relative_residual = r_norm / rhs_norm;
4610 diag.stopping_reason = PcgStopReason::MaxIter;
4611 }
4612 let x = stream
4613 .clone_dtoh(&x_dev)
4614 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4615 Ok((Array1::from_vec(x), diag))
4616 }
4617
4618 pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4619 s_acc: &ndarray::Array2<f64>,
4620 rhs_beta: &Array1<f64>,
4621 max_iterations: usize,
4622 relative_tolerance: f64,
4623 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4624 let k = rhs_beta.len();
4625 let cg_iters = max_iterations.max(1);
4637 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4638 gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4639 m: k,
4640 n: k,
4641 k: cg_iters,
4642 },
4643 )
4644 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4645 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4646 .and_then(|ctx| ctx.new_stream().ok())
4647 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4648 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4649 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4650 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4651 let vector_module = pcg_vector_module(&ctx)?;
4652
4653 let mut inv_diag = vec![0.0_f64; k];
4655 for j in 0..k {
4656 let djj = s_acc[[j, j]];
4657 if !(djj > 0.0) {
4658 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4659 reason: format!(
4660 "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4661 ),
4662 });
4663 }
4664 inv_diag[j] = 1.0 / djj;
4665 }
4666
4667 let mut s_host = vec![0.0_f64; k * k];
4669 for col in 0..k {
4670 for row in 0..k {
4671 s_host[col * k + row] = s_acc[[row, col]];
4672 }
4673 }
4674 let s_dev = stream
4675 .clone_htod(&s_host)
4676 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4677
4678 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4682 if rhs_norm == 0.0 {
4683 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4684 }
4685 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4686
4687 let mut x_dev = stream
4690 .alloc_zeros::<f64>(k)
4691 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4692 let mut r_dev = stream
4693 .clone_htod(
4694 rhs_beta
4695 .as_slice()
4696 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4697 )
4698 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4699 let inv_diag_dev = stream
4700 .clone_htod(&inv_diag)
4701 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702 let mut z_dev = stream
4703 .alloc_zeros::<f64>(k)
4704 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4705 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4706 let mut p_dev = stream
4707 .alloc_zeros::<f64>(k)
4708 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4709 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4710 let mut sp_dev = stream
4711 .alloc_zeros::<f64>(k)
4712 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4713 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4714 let mut diag = PcgDiagnostics {
4715 precond_apply_calls: 1,
4716 stopping_reason: PcgStopReason::MaxIter,
4717 ..PcgDiagnostics::default()
4718 };
4719 if rz <= 0.0 || !rz.is_finite() {
4720 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4721 reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4722 });
4723 }
4724
4725 let max_iters = max_iterations.max(1);
4726 for _ in 0..max_iters {
4727 let gemv_cfg = GemvConfig::<f64> {
4729 trans: cublasOperation_t::CUBLAS_OP_N,
4730 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4731 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4732 alpha: 1.0,
4733 lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4734 incx: 1,
4735 beta: 0.0,
4736 incy: 1,
4737 };
4738 unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4740 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4741 diag.matvec_calls += 1;
4742 diag.iterations += 1;
4743
4744 let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4745 if !(p_sp > 0.0) {
4746 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4749 reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4750 });
4751 }
4752 let alpha = rz / p_sp;
4753 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4754 device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4755 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4756 if r_norm <= tol {
4757 diag.final_relative_residual = r_norm / rhs_norm;
4758 diag.stopping_reason = PcgStopReason::Converged;
4759 break;
4760 }
4761 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4762 diag.precond_apply_calls += 1;
4763 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4764 if rz_new <= 0.0 || !rz_new.is_finite() {
4765 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4766 reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4767 });
4768 }
4769 let beta = rz_new / rz;
4770 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4771 rz = rz_new;
4772 }
4773 if diag.stopping_reason != PcgStopReason::Converged {
4774 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4775 diag.final_relative_residual = r_norm / rhs_norm;
4776 diag.stopping_reason = PcgStopReason::MaxIter;
4777 }
4778
4779 let x = stream
4780 .clone_dtoh(&x_dev)
4781 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4782 Ok((Array1::from_vec(x), diag))
4783 }
4784
4785 fn device_copy(
4786 blas: &CudaBlas,
4787 stream: &Arc<CudaStream>,
4788 n: usize,
4789 src: &CudaSlice<f64>,
4790 dst: &mut CudaSlice<f64>,
4791 ) -> Result<(), ArrowSchurGpuFailure> {
4792 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4793 let (src_ptr, _src_rec) = src.device_ptr(stream);
4794 let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4795 let status = unsafe {
4798 cudarc::cublas::sys::cublasDcopy_v2(
4799 *blas.handle(),
4800 n_i,
4801 src_ptr as *const f64,
4802 1,
4803 dst_ptr as *mut f64,
4804 1,
4805 )
4806 };
4807 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4808 Ok(())
4809 } else {
4810 Err(ArrowSchurGpuFailure::Unavailable)
4811 }
4812 }
4813
4814 fn device_axpy(
4815 blas: &CudaBlas,
4816 stream: &Arc<CudaStream>,
4817 n: usize,
4818 alpha: f64,
4819 x: &CudaSlice<f64>,
4820 y: &mut CudaSlice<f64>,
4821 ) -> Result<(), ArrowSchurGpuFailure> {
4822 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4823 let (x_ptr, _x_rec) = x.device_ptr(stream);
4824 let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4825 let status = unsafe {
4828 cudarc::cublas::sys::cublasDaxpy_v2(
4829 *blas.handle(),
4830 n_i,
4831 &alpha,
4832 x_ptr as *const f64,
4833 1,
4834 y_ptr as *mut f64,
4835 1,
4836 )
4837 };
4838 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4839 Ok(())
4840 } else {
4841 Err(ArrowSchurGpuFailure::Unavailable)
4842 }
4843 }
4844
4845 fn device_dot(
4846 blas: &CudaBlas,
4847 stream: &Arc<CudaStream>,
4848 n: usize,
4849 x: &CudaSlice<f64>,
4850 y: &CudaSlice<f64>,
4851 ) -> Result<f64, ArrowSchurGpuFailure> {
4852 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4853 let (x_ptr, _x_rec) = x.device_ptr(stream);
4854 let (y_ptr, _y_rec) = y.device_ptr(stream);
4855 let mut result = 0.0_f64;
4856 let status = unsafe {
4860 cudarc::cublas::sys::cublasDdot_v2(
4861 *blas.handle(),
4862 n_i,
4863 x_ptr as *const f64,
4864 1,
4865 y_ptr as *const f64,
4866 1,
4867 &mut result,
4868 )
4869 };
4870 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4871 Ok(result)
4872 } else {
4873 Err(ArrowSchurGpuFailure::Unavailable)
4874 }
4875 }
4876
4877 fn device_nrm2(
4878 blas: &CudaBlas,
4879 stream: &Arc<CudaStream>,
4880 n: usize,
4881 x: &CudaSlice<f64>,
4882 ) -> Result<f64, ArrowSchurGpuFailure> {
4883 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4884 let (x_ptr, _x_rec) = x.device_ptr(stream);
4885 let mut result = 0.0_f64;
4886 let status = unsafe {
4890 cudarc::cublas::sys::cublasDnrm2_v2(
4891 *blas.handle(),
4892 n_i,
4893 x_ptr as *const f64,
4894 1,
4895 &mut result,
4896 )
4897 };
4898 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4899 Ok(result)
4900 } else {
4901 Err(ArrowSchurGpuFailure::Unavailable)
4902 }
4903 }
4904
4905 #[cfg(test)]
4906 mod tests {
4907 use super::*;
4912 use crate::arrow_schur::{
4913 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
4914 FactoredFrameGBlock,
4915 };
4916 use ndarray::Array2;
4917
4918 fn device_matvec_once(
4921 sys: &ArrowSchurSystem,
4922 data: &DeviceSaePcgData,
4923 ridge_t: f64,
4924 ridge_beta: f64,
4925 x_host: &[f64],
4926 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4927 let k = x_host.len();
4928 let frame = data
4929 .frame
4930 .as_ref()
4931 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4932 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4933 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4934 let ctx =
4935 gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4936 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4937 let stream = ctx
4938 .new_stream()
4939 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4940 let vector_module = pcg_vector_module(&ctx)?;
4941 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4942 let x_dev = stream
4943 .clone_htod(x_host)
4944 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4945 let mut out_dev = stream
4946 .alloc_zeros::<f64>(k)
4947 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4948 launch_sae_frame_matvec(
4949 &stream,
4950 vector_module,
4951 &mut buffers,
4952 &x_dev,
4953 &mut out_dev,
4954 ridge_beta,
4955 )?;
4956 stream
4957 .clone_dtoh(&out_dev)
4958 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4959 }
4960
/// Stage-diff check for #1551: applies the device framed matvec and the CPU
/// framed matvec to every unit vector of a tiny two-atom fixture and requires
/// the results to agree entry by entry to within `1e-9`.
///
/// The test skips when no GPU runtime is available. Once a runtime is present,
/// a failing device matvec is a test failure rather than a silent skip, so a
/// broken device path cannot pass unnoticed.
#[test]
fn framed_sae_device_matvec_stage_diff_tiny_1551() {
    if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
        return;
    }
    let p = 3usize;
    let ranks = vec![2usize, 3usize];
    let basis_sizes = vec![2usize, 2usize];
    // Exclusive prefix sums of basis_size * rank: atom0 starts at 0, atom1 at 4.
    let mut border_offsets = Vec::new();
    let mut acc = 0usize;
    for k in 0..2 {
        border_offsets.push(acc);
        acc += basis_sizes[k] * ranks[k];
    }
    let border_dim = acc;

    // Deterministic p x rank frame per atom; W_ij = U_i^T U_j.
    let frame_of = |k: usize| -> Array2<f64> {
        Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
            0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
        })
    };
    let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
    let w_of = |i: usize, j: usize| -> Array2<f64> {
        let (ui, uj) = (&frames[i], &frames[j]);
        Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
            (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
        })
    };

    // Both diagonal blocks and both cross blocks; diagonal G gets a dominant diagonal.
    let mut frame_blocks = Vec::new();
    for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g =
            Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
        if i == j {
            for r in 0..mi.min(mj) {
                g[[r, r]] += mi as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: w_of(i, j),
        });
    }

    let mut smooth_blocks = Vec::new();
    for k in 0..2 {
        let m = basis_sizes[k];
        let mut s =
            Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
        for r in 0..m {
            s[[r, r]] += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: s,
        });
    }
    let smooth_ranks = ranks.clone();

    // Two rows with q = 2 local parameters each; the same H_tβ values are
    // written both into the system and into the flat per-row slab.
    let n = 2usize;
    let q = 2usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::new();
    for i in 0..n {
        let mut htt =
            Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
        for r in 0..q {
            htt[[r, r]] += q as f64 + 2.0;
        }
        sys.rows[i].htt = htt;
        let mut slab = vec![0.0_f64; q * border_dim];
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
                slab[c * border_dim + col] = v;
                sys.rows[i].htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks,
            basis_sizes,
            border_offsets,
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };
    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;

    // Track the first entry beyond tolerance and the worst entry overall.
    let mut first_bad: Option<usize> = None;
    let mut worst = 0.0_f64;
    let mut worst_at = 0usize;
    let mut worst_dev = 0.0_f64;
    let mut worst_cpu = 0.0_f64;
    for col in 0..border_dim {
        let mut x = vec![0.0_f64; border_dim];
        x[col] = 1.0;
        // A runtime was found above, so a device failure here is a real defect.
        let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
            Ok(v) => v,
            Err(failure) => panic!(
                "[#1551 stage-diff] CUDA device present but the framed device matvec \
                 failed on unit column {col} (tag: {failure:?})"
            ),
        };
        let mut cpu = vec![0.0_f64; border_dim];
        super::super::sae_framed_schur_matvec_cpu(
            &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
        )
        .expect("cpu matvec");
        for r in 0..border_dim {
            let d = (dev[r] - cpu[r]).abs();
            if d > 1e-9 && first_bad.is_none() {
                first_bad = Some(r * border_dim + col);
            }
            if d > worst {
                worst = d;
                worst_at = r * border_dim + col;
                worst_dev = dev[r];
                worst_cpu = cpu[r];
            }
        }
    }
    assert!(
        worst <= 1e-9,
        "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
         (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
         first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
         rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
         G⊗W=cross, reduced-Schur=dense per-row)",
    );
}
5102 }
5103}
5104
5105#[cfg(test)]
5106mod tests {
5107 use super::*;
5108 use crate::arrow_schur::ArrowSchurSystem;
5109 use ndarray::{Array2, ArrayView1};
5110
5111 fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5112 let mut sys = ArrowSchurSystem::new(n, d, k);
5113 let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5114 let mut sample = || -> f64 {
5115 state = state
5116 .wrapping_mul(6364136223846793005)
5117 .wrapping_add(1442695040888963407);
5118 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5119 };
5120 for row in &mut sys.rows {
5121 let mut a = Array2::<f64>::zeros((d, d));
5122 for r in 0..d {
5123 for c in 0..d {
5124 a[[r, c]] = sample();
5125 }
5126 }
5127 let mut htt = a.t().dot(&a);
5128 for r in 0..d {
5129 htt[[r, r]] += d as f64 + 1.0;
5130 }
5131 row.htt = htt;
5132 for r in 0..d {
5133 for c in 0..k {
5134 row.htbeta[[r, c]] = 0.1 * sample();
5135 }
5136 row.gt[r] = sample();
5137 }
5138 }
5139 let mut hbb_a = Array2::<f64>::zeros((k, k));
5140 for r in 0..k {
5141 for c in 0..k {
5142 hbb_a[[r, c]] = sample();
5143 }
5144 }
5145 let mut hbb = hbb_a.t().dot(&hbb_a);
5146 for r in 0..k {
5147 hbb[[r, r]] += k as f64 + 1.0;
5148 }
5149 sys.hbb = hbb;
5150 for r in 0..k {
5151 sys.gb[r] = sample();
5152 }
5153 sys
5154 }
5155
/// For three blocks that are not positive definite and one that is, the bump
/// from `ridge_bump_to_make_pd` must be strictly positive and finite, and
/// adding it to the diagonal must let the Cholesky factorization succeed.
#[test]
fn ridge_bump_makes_known_indefinite_blocks_pd() {
    let scaled_identity =
        |dim: usize, value: f64| Array2::<f64>::from_diag(&Array1::from_elem(dim, value));

    // [[1, 2], [2, 1]]: unit diagonal, off-diagonal 2.
    let mut indef2 = Array2::<f64>::from_elem((2, 2), 2.0);
    indef2[[0, 0]] = 1.0;
    indef2[[1, 1]] = 1.0;

    let cases = [
        ("-I (λ_min=-1)", scaled_identity(8, -1.0)),
        ("-250·I (λ_min=-250)", scaled_identity(4, -250.0)),
        ("[[1,2],[2,1]] (λ_min=-1)", indef2),
        ("5·I (PD)", scaled_identity(3, 5.0)),
    ];

    for (label, block) in cases {
        let ridge_t = 0.0;
        let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
        assert!(
            bump > 0.0 && bump.is_finite(),
            "[{label}] bump must be strictly positive and finite, got {bump:e}"
        );

        // Apply the total diagonal shift and confirm the factorization accepts it.
        let mut shifted = block.clone();
        for diag in shifted.diag_mut().iter_mut() {
            *diag += ridge_t + bump;
        }
        let factored =
            cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot);
        assert!(
            factored.is_some(),
            "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
             Gershgorin bump, but the Cholesky still rejected it"
        );
    }
}
5202
/// On a symmetric 3x3 block the column-major slice variant of the ridge bump
/// must return the same value as the `ArrayView2` variant, and that bump must
/// make the block factorizable.
#[cfg(target_os = "linux")]
#[test]
fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
    // Upper-triangle entries, mirrored across the diagonal.
    let entries: [(usize, usize, f64); 6] = [
        (0, 0, -2.0),
        (1, 1, 0.5),
        (2, 2, 1.0),
        (0, 1, 0.3),
        (1, 2, -0.4),
        (0, 2, 0.1),
    ];
    let mut a = Array2::<f64>::zeros((3, 3));
    for &(r, c, value) in &entries {
        a[[r, c]] = value;
        a[[c, r]] = value;
    }

    let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);

    // Walking the transpose in logical order yields column c, row r at c * d + r.
    let d = a.nrows();
    let col_major: Vec<f64> = a.t().iter().copied().collect();
    let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);

    let gap = (row_major_bump - col_major_bump).abs();
    assert!(
        gap <= 1e-12 * row_major_bump.max(1.0),
        "colmajor bump {col_major_bump:e} must match rowmajor bump \
         {row_major_bump:e} for a symmetric block"
    );

    let mut shifted = a.clone();
    for diag in shifted.diag_mut().iter_mut() {
        *diag += col_major_bump;
    }
    let factored = cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot);
    assert!(
        factored.is_some(),
        "colmajor Gershgorin bump must make the symmetric block PD"
    );
}
5255
5256 fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5257 let mut s = Array2::<f64>::zeros((k, k));
5258 for row in 0..k {
5259 s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5260 if row + 1 < k {
5261 s[[row, row + 1]] = -0.05;
5262 s[[row + 1, row]] = -0.05;
5263 }
5264 if row + 7 < k {
5265 s[[row, row + 7]] = 0.01;
5266 s[[row + 7, row]] = 0.01;
5267 }
5268 }
5269 let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5270 (s, rhs)
5271 }
5272
5273 fn dense_pcg_cpu_reference(
5274 s: &Array2<f64>,
5275 rhs: &Array1<f64>,
5276 max_iterations: usize,
5277 relative_tolerance: f64,
5278 ) -> Array1<f64> {
5279 let k = rhs.len();
5280 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5281 if rhs_norm == 0.0 {
5282 return Array1::<f64>::zeros(k);
5283 }
5284 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5285 let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5286 let mut x = Array1::<f64>::zeros(k);
5287 let mut r = rhs.clone();
5288 let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5289 let mut p = z.clone();
5290 let mut sp = Array1::<f64>::zeros(k);
5291 let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5292 for _ in 0..max_iterations.max(1) {
5293 for row in 0..k {
5294 let mut acc = 0.0;
5295 for col in 0..k {
5296 acc += s[[row, col]] * p[col];
5297 }
5298 sp[row] = acc;
5299 }
5300 let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5301 let alpha = rz / p_sp;
5302 for idx in 0..k {
5303 x[idx] += alpha * p[idx];
5304 r[idx] -= alpha * sp[idx];
5305 }
5306 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5307 if r_norm <= tol {
5308 break;
5309 }
5310 for idx in 0..k {
5311 z[idx] = inv_diag[idx] * r[idx];
5312 }
5313 let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5314 let beta = rz_next / rz;
5315 for idx in 0..k {
5316 p[idx] = z[idx] + beta * p[idx];
5317 }
5318 rz = rz_next;
5319 }
5320 x
5321 }
5322
/// The device reduced-beta PCG must reproduce the host reference to `1e-10` on
/// a 512-dimensional banded system.
///
/// A solver error is tolerated only when no GPU runtime is present; with a
/// runtime available it is a test failure.
#[test]
fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
    let max_iterations = 200usize;
    let relative_tolerance = 1.0e-12;
    let (s, rhs) = device_pcg_fixture(512);
    let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);

    let outcome =
        solve_reduced_beta_pcg_with_diagnostics(&s, &rhs, max_iterations, relative_tolerance);
    let (device, diag) = match outcome {
        Ok(result) => result,
        Err(failure) => {
            // Declining is acceptable only when there is no device to run on.
            assert!(
                gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                "#1017: CUDA device present but the device reduced-beta PCG \
                 declined/faulted instead of returning a result (tag: {failure:?}) — \
                 the kernel does not run correctly on GPU"
            );
            return;
        }
    };

    let mut max_err = 0.0_f64;
    for (host, dev) in cpu.iter().zip(device.iter()) {
        max_err = max_err.max((host - dev).abs());
    }
    assert!(
        max_err <= 1.0e-10,
        "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
    );
    assert!(diag.matvec_calls > 0);
    assert_eq!(diag.matvec_calls, diag.iterations);
}
5364
/// The dense-reference Newton step must match a solve of the explicitly
/// assembled full arrow Hessian `H · δ = -g`, with both ridges set to zero.
#[test]
fn dense_reference_matches_independent_solve() {
    let sys = build_fixture(4, 5, 3, 7);
    let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();

    let (n, d, k) = (sys.rows.len(), sys.d, sys.k);
    // Local blocks occupy the first n * d indices; the border follows.
    let border = n * d;
    let total = border + k;

    let mut h = Array2::<f64>::zeros((total, total));
    let mut g = Array1::<f64>::zeros(total);
    for (i, row) in sys.rows.iter().enumerate() {
        let base = i * d;
        for r in 0..d {
            for c in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            // The coupling block is mirrored into both off-diagonal positions.
            for c in 0..k {
                let coupling = row.htbeta[[r, c]];
                h[[base + r, border + c]] = coupling;
                h[[border + c, base + r]] = coupling;
            }
            g[base + r] = row.gt[r];
        }
    }
    for r in 0..k {
        for c in 0..k {
            h[[border + r, border + c]] += sys.hbb[[r, c]];
        }
        g[border + r] = sys.gb[r];
    }

    let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
    let rhs = g.mapv(|v| -v);
    let expected = cholesky_solve_vector(l.view(), rhs.view());

    for i in 0..border {
        let (got, want) = (solution.delta_t[i], expected[i]);
        assert!(
            (got - want).abs() < 1e-10 * (1.0 + want.abs()),
            "delta_t[{i}] mismatch: got {} expected {}",
            got,
            want
        );
    }
    for a in 0..k {
        let want = expected[border + a];
        assert!(
            (solution.delta_beta[a] - want).abs() < 1e-10 * (1.0 + want.abs()),
            "delta_beta[{a}] mismatch"
        );
    }
}
5420
/// The row-procedural Schur matvec must be bitwise reproducible across two
/// calls, and must agree to `1e-12` (relative to the largest output entry)
/// with a run installed in a dedicated rayon pool.
#[test]
fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
    use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;

    // 96 rows above the imported row-count constant.
    let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96;
    let d = 3usize;
    let k = 24usize;
    let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);

    // Each closure owns its own copy of the per-row coupling slabs.
    let transpose_slabs: Vec<Array2<f64>> =
        sys.rows.iter().map(|row| row.htbeta.clone()).collect();
    let forward_slabs = transpose_slabs.clone();
    sys.set_row_htbeta_operator(
        // Forward: out = H_row * x (overwrites `out`).
        move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
            let h = &forward_slabs[row];
            for r in 0..h.nrows() {
                out[r] = (0..h.ncols()).fold(0.0_f64, |acc, c| acc + h[[r, c]] * x[c]);
            }
        },
        // Transpose: out += H_rowᵀ * v (accumulates into `out`).
        move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
            let h = &transpose_slabs[row];
            for r in 0..h.nrows() {
                let weight = v[r];
                for c in 0..h.ncols() {
                    out[c] += h[[r, c]] * weight;
                }
            }
        },
    );

    let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
        .expect("row-procedural matvec backend builds for matrix-free system");
    let x: Array1<f64> = (0..k).map(|i| ((i as f64 + 1.0) * 0.37).sin()).collect();

    // Two identical calls must agree bit for bit.
    let mut out_parallel_a = Array1::<f64>::zeros(k);
    let mut out_parallel_b = Array1::<f64>::zeros(k);
    matvec(&x, &mut out_parallel_a);
    matvec(&x, &mut out_parallel_b);
    for (a, (first, second)) in out_parallel_a.iter().zip(out_parallel_b.iter()).enumerate() {
        assert_eq!(
            first.to_bits(),
            second.to_bits(),
            "row-procedural matvec parallel reduction is non-deterministic at index {a}"
        );
    }

    // NOTE(review): the run labelled "serial" is installed in a 2-thread pool —
    // confirm that this is the intended comparison baseline.
    let mut out_serial = Array1::<f64>::zeros(k);
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(2)
        .build()
        .expect("build rayon pool");
    pool.install(|| matvec(&x, &mut out_serial));

    let mut max_abs = 0.0_f64;
    for value in out_serial.iter() {
        max_abs = max_abs.max(value.abs());
    }
    for a in 0..k {
        let diff = (out_parallel_a[a] - out_serial[a]).abs();
        assert!(
            diff <= 1e-12 * (1.0 + max_abs),
            "row-procedural matvec parallel vs serial diverged beyond reassociation \
             at index {a}: {} vs {} (diff={diff:e})",
            out_parallel_a[a],
            out_serial[a]
        );
    }
}
5511
/// CPU parity check: `sae_framed_schur_matvec_cpu` must agree (relative error
/// at most `1e-10`) with an explicitly assembled dense Schur complement
/// `H_ββ + ridge_β·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`
/// on a three-atom fixture that mixes low-rank and full-rank frames.
///
/// All random inputs come from one LCG, so the draw order below is part of
/// the fixture.
#[test]
fn framed_sae_schur_matvec_matches_dense_reference() {
    use crate::arrow_schur::{
        BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
        FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
    };

    let p = 4usize;
    let ranks = vec![2usize, 4usize, 3usize];
    let basis_sizes = vec![2usize, 1usize, 2usize];
    let n_atoms = ranks.len();

    // Exclusive prefix sums of basis_size * rank give each atom's border offset.
    let mut border_offsets = Vec::with_capacity(n_atoms);
    let mut running = 0usize;
    for (&m, &rank) in basis_sizes.iter().zip(ranks.iter()) {
        border_offsets.push(running);
        running += m * rank;
    }
    let border_dim = running;

    let mut lcg_state = 0x1234_5678_9abc_def0u64;
    let mut draw = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((lcg_state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // Frames: a full-rank atom (rank == p) gets the identity and consumes no
    // draws; every other atom gets a random p x rank matrix.
    let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
    for &rank in &ranks {
        let full_rank = rank == p;
        let mut u = Array2::<f64>::zeros((p, rank));
        for i in 0..p {
            for j in 0..rank {
                u[[i, j]] = if full_rank {
                    if i == j { 1.0 } else { 0.0 }
                } else {
                    draw()
                };
            }
        }
        frames.push(u);
    }

    // W_ij = U_iᵀ U_j, accumulated over the shared dimension in index order.
    let w_of = |i: usize, j: usize| -> Array2<f64> {
        let (ui, uj) = (&frames[i], &frames[j]);
        Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
            (0..p).fold(0.0_f64, |s, c| s + ui[[c, a]] * uj[[c, b]])
        })
    };

    // Diagonal blocks plus the (0, 2) / (2, 0) cross pair, in sorted order.
    let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
    pairs.sort_unstable();
    let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
    for &(i, j) in &pairs {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g = Array2::<f64>::zeros((mi, mj));
        for entry in g.iter_mut() {
            *entry = 0.3 * draw();
        }
        if i == j {
            // Diagonal blocks get a dominant diagonal.
            for diag in g.diag_mut().iter_mut() {
                *diag += mi as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: w_of(i, j),
        });
    }

    // Smooth penalty factor per atom: AᵀA + I.
    let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
    let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
    for k in 0..n_atoms {
        let m = basis_sizes[k];
        let mut a = Array2::<f64>::zeros((m, m));
        for entry in a.iter_mut() {
            *entry = 0.2 * draw();
        }
        let mut s = a.t().dot(&a);
        for diag in s.diag_mut().iter_mut() {
            *diag += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: s,
        });
        smooth_ranks.push(ranks[k]);
    }

    // Rows: H_tt = AᵀA + (q + 1)·I, and a random H_tβ stored both in the system
    // and as a flat row-major slab.
    let n = 6usize;
    let q = 3usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
    for i in 0..n {
        let mut a = Array2::<f64>::zeros((q, q));
        for entry in a.iter_mut() {
            *entry = draw();
        }
        let mut htt = a.t().dot(&a);
        for diag in htt.diag_mut().iter_mut() {
            *diag += q as f64 + 1.0;
        }
        sys.rows[i].htt = htt;

        let mut slab = Vec::with_capacity(q * border_dim);
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.15 * draw();
                slab.push(v);
                sys.rows[i].htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    // Dense H_ββ = frame operator + sum of the per-atom smooth penalties.
    let data_op =
        FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
            .expect("frame op");
    let mut hbb = data_op.to_dense();
    for (k, block) in smooth_blocks.iter().enumerate() {
        let op = IdentityRightKroneckerPenaltyOp {
            factor_a: block.factor_a.clone(),
            p: ranks[k],
            global_offset: border_offsets[k],
            k: border_dim,
        };
        let d = op.to_dense();
        for r in 0..border_dim {
            for c in 0..border_dim {
                hbb[[r, c]] += d[[r, c]];
            }
        }
    }
    sys.hbb = hbb;

    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };

    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;

    // Dense reference, starting from H_ββ + ridge_β·I.
    let mut s_dense = Array2::<f64>::from_shape_fn((border_dim, border_dim), |(r, c)| {
        if r == c {
            sys.hbb[[r, c]] + ridge_beta
        } else {
            sys.hbb[[r, c]]
        }
    });
    for row in &sys.rows {
        let mut htt = row.htt.clone();
        for diag in htt.diag_mut().iter_mut() {
            *diag += ridge_t;
        }
        let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
            .expect("htt PD");

        // y = (H_tt + ridge_t·I)⁻¹ H_tβ, one border column at a time.
        let mut y = Array2::<f64>::zeros((q, border_dim));
        for col in 0..border_dim {
            let e = Array1::<f64>::from_shape_fn(q, |r| row.htbeta[[r, col]]);
            let solved = cholesky_solve_vector(factor.view(), e.view());
            for r in 0..q {
                y[[r, col]] = solved[r];
            }
        }
        // Subtract this row's H_tβᵀ y contribution.
        for r in 0..border_dim {
            for c in 0..border_dim {
                let contribution =
                    (0..q).fold(0.0, |acc, d| acc + row.htbeta[[d, r]] * y[[d, c]]);
                s_dense[[r, c]] -= contribution;
            }
        }
    }

    // Compare on four smooth probe vectors, scaled by max(|want|, 1).
    let mut max_rel = 0.0_f64;
    for trial in 0..4 {
        let x: Vec<f64> = (0..border_dim)
            .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
            .collect();
        let mut got = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
            .expect("framed matvec");

        let want: Vec<f64> = (0..border_dim)
            .map(|r| (0..border_dim).fold(0.0, |acc, c| acc + s_dense[[r, c]] * x[c]))
            .collect();
        let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
        for (g_val, w_val) in got.iter().zip(want.iter()) {
            max_rel = max_rel.max((g_val - w_val).abs() / scale);
        }
    }
    assert!(
        max_rel <= 1e-10,
        "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
    );
}
5768
/// Large-K variant of the framed-SAE CPU parity check (#1026): 40 atoms with
/// every fifth atom full rank, nearest-neighbour cross blocks, and the same
/// dense Schur reference
/// `H_ββ + ridge_β·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`.
///
/// All random inputs come from one LCG, so the draw order below is part of
/// the fixture.
#[test]
fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
    use crate::arrow_schur::{
        BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
        FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
    };

    let p = 12usize;
    let n_atoms = 40usize;
    // Every fifth atom is full rank; the rest cycle through ranks 2, 3, 4.
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
        .collect();
    let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();

    // Exclusive prefix sums of basis_size * rank give each atom's border offset.
    let mut border_offsets = Vec::with_capacity(n_atoms);
    let mut running = 0usize;
    for (&m, &rank) in basis_sizes.iter().zip(ranks.iter()) {
        border_offsets.push(running);
        running += m * rank;
    }
    let border_dim = running;

    let mut lcg_state = 0x0bad_c0de_dead_beefu64;
    let mut draw = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((lcg_state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // Frames: a full-rank atom gets the identity and consumes no draws.
    let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
    for &rank in &ranks {
        let full_rank = rank == p;
        let mut u = Array2::<f64>::zeros((p, rank));
        for i in 0..p {
            for j in 0..rank {
                u[[i, j]] = match (full_rank, i == j) {
                    (true, true) => 1.0,
                    (true, false) => 0.0,
                    (false, _) => draw(),
                };
            }
        }
        frames.push(u);
    }

    // W_ij = U_iᵀ U_j, accumulated over the shared dimension in index order.
    let w_of = |i: usize, j: usize| -> Array2<f64> {
        let (ui, uj) = (&frames[i], &frames[j]);
        Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
            (0..p).fold(0.0_f64, |s, c| s + ui[[c, a]] * uj[[c, b]])
        })
    };

    // Diagonal blocks plus both orientations of each nearest-neighbour pair,
    // visited in sorted order.
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    pairs.extend((0..n_atoms - 1).flat_map(|k| [(k, k + 1), (k + 1, k)]));
    pairs.sort_unstable();

    let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
    for &(i, j) in &pairs {
        let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
        let mut g = Array2::<f64>::zeros((mi, mj));
        for entry in g.iter_mut() {
            *entry = 0.3 * draw();
        }
        if i == j {
            // Diagonal blocks get a dominant diagonal.
            for diag in g.diag_mut().iter_mut() {
                *diag += mi as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i: i,
            atom_j: j,
            g,
            w: w_of(i, j),
        });
    }

    // Smooth penalty factor per atom: AᵀA + I.
    let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
    let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
    for k in 0..n_atoms {
        let m = basis_sizes[k];
        let mut a = Array2::<f64>::zeros((m, m));
        for entry in a.iter_mut() {
            *entry = 0.2 * draw();
        }
        let mut s = a.t().dot(&a);
        for diag in s.diag_mut().iter_mut() {
            *diag += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: s,
        });
        smooth_ranks.push(ranks[k]);
    }

    // Rows: H_tt = AᵀA + (q + 1)·I, and a random H_tβ stored both in the system
    // and as a flat row-major slab.
    let n = 8usize;
    let q = 3usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
    for i in 0..n {
        let mut a = Array2::<f64>::zeros((q, q));
        for entry in a.iter_mut() {
            *entry = draw();
        }
        let mut htt = a.t().dot(&a);
        for diag in htt.diag_mut().iter_mut() {
            *diag += q as f64 + 1.0;
        }
        sys.rows[i].htt = htt;

        let mut slab = Vec::with_capacity(q * border_dim);
        for c in 0..q {
            for col in 0..border_dim {
                let v = 0.15 * draw();
                slab.push(v);
                sys.rows[i].htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    // Dense H_ββ = frame operator + sum of the per-atom smooth penalties.
    let data_op =
        FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
            .expect("frame op");
    let mut hbb = data_op.to_dense();
    for (k, block) in smooth_blocks.iter().enumerate() {
        let op = IdentityRightKroneckerPenaltyOp {
            factor_a: block.factor_a.clone(),
            p: ranks[k],
            global_offset: border_offsets[k],
            k: border_dim,
        };
        let d = op.to_dense();
        for r in 0..border_dim {
            for c in 0..border_dim {
                hbb[[r, c]] += d[[r, c]];
            }
        }
    }
    sys.hbb = hbb;

    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };

    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;

    // Dense reference, starting from H_ββ + ridge_β·I.
    let mut s_dense = sys.hbb.clone();
    for idx in 0..border_dim {
        s_dense[[idx, idx]] += ridge_beta;
    }
    for row in &sys.rows {
        let mut htt = row.htt.clone();
        for diag in htt.diag_mut().iter_mut() {
            *diag += ridge_t;
        }
        let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
            .expect("htt PD");

        // y = (H_tt + ridge_t·I)⁻¹ H_tβ, one border column at a time.
        let mut y = Array2::<f64>::zeros((q, border_dim));
        for col in 0..border_dim {
            let e = Array1::<f64>::from_shape_fn(q, |r| row.htbeta[[r, col]]);
            let solved = cholesky_solve_vector(factor.view(), e.view());
            for r in 0..q {
                y[[r, col]] = solved[r];
            }
        }
        // Subtract this row's H_tβᵀ y contribution.
        for r in 0..border_dim {
            for c in 0..border_dim {
                let contribution =
                    (0..q).fold(0.0, |acc, d| acc + row.htbeta[[d, r]] * y[[d, c]]);
                s_dense[[r, c]] -= contribution;
            }
        }
    }

    // Compare on four smooth probe vectors, scaled by max(|want|, 1).
    let mut max_rel = 0.0_f64;
    for trial in 0..4 {
        let x: Vec<f64> = (0..border_dim)
            .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
            .collect();
        let mut got = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
            .expect("framed matvec");

        let want: Vec<f64> = (0..border_dim)
            .map(|r| (0..border_dim).fold(0.0, |acc, c| acc + s_dense[[r, c]] * x[c]))
            .collect();
        let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
        for (g_val, w_val) in got.iter().zip(want.iter()) {
            max_rel = max_rel.max((g_val - w_val).abs() / scale);
        }
    }
    assert!(
        max_rel <= 1e-10,
        "large-K framed SAE Schur matvec vs dense reference diverged: \
         max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
    );
}
6028
/// End-to-end check of `solve_sae_matrix_free_pcg` on a synthetic framed SAE
/// arrow-Schur system.
///
/// The fixture has 8 atoms with alternating frame ranks (3 and `p = 6`), 400
/// rows of dimension 4, and a dense `hbb` assembled from the frame operator plus
/// the per-atom smoothness penalties. The solver's answer is judged by its
/// relative residual under the CPU matvec `sae_framed_schur_matvec_cpu`; the same
/// system is also solved with `gam_linalg::pcg::pcg_core` as a fixture sanity
/// check. If the solver returns `Err`, the test only requires that no GPU runtime
/// is present and then returns early.
#[test]
fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
    use crate::arrow_schur::{
        BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
        FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
    };

    // ---- Fixture geometry -------------------------------------------------
    let p = 6usize;
    let n_atoms = 8usize;
    // Even-indexed atoms get a rank-3 frame, odd-indexed atoms the full rank `p`.
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|k| if k % 2 == 0 { 3usize } else { p })
        .collect();
    let basis_sizes: Vec<usize> = vec![3usize; n_atoms];
    // Exclusive prefix sum of `basis_size * rank`; the running total is the border width.
    let mut border_dim = 0usize;
    let border_offsets: Vec<usize> = basis_sizes
        .iter()
        .zip(&ranks)
        .map(|(&m, &r)| {
            let start = border_dim;
            border_dim += m * r;
            start
        })
        .collect();

    // ---- Deterministic pseudo-random source -------------------------------
    // 64-bit LCG; the top 31 bits are scaled by 2^-31 and shifted by -1, so every
    // draw lies in [-1, 0). The order of `sample()` calls below fixes the fixture.
    let mut lcg_state = 0xfeed_face_dead_beefu64;
    let mut sample = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let top_bits = lcg_state >> 33;
        (top_bits as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // ---- Per-atom frames --------------------------------------------------
    // Full-rank atoms get the identity; reduced-rank atoms get sampled entries.
    let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
    for &rank in &ranks {
        let full_rank = rank == p;
        let mut u = Array2::<f64>::zeros((p, rank));
        for row in 0..p {
            for col in 0..rank {
                u[[row, col]] = match (full_rank, row == col) {
                    (true, true) => 1.0,
                    (true, false) => 0.0,
                    (false, _) => sample(),
                };
            }
        }
        frames.push(u);
    }
    // W_ij = U_iᵀ U_j, accumulated over the shared dimension in index order.
    let w_of = |i: usize, j: usize| {
        let (ui, uj) = (&frames[i], &frames[j]);
        let (ri, rj) = (ranks[i], ranks[j]);
        let mut w = Array2::<f64>::zeros((ri, rj));
        for a in 0..ri {
            for b in 0..rj {
                w[[a, b]] = (0..p).fold(0.0_f64, |acc, c| acc + ui[[c, a]] * uj[[c, b]]);
            }
        }
        w
    };

    // ---- Frame G blocks: all diagonal pairs plus three symmetric couplings --
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    pairs.extend(
        [(0usize, 1usize), (2, 4), (3, 6)]
            .iter()
            .flat_map(|&(i, j)| [(i, j), (j, i)]),
    );
    let mut frame_blocks = Vec::with_capacity(pairs.len());
    for &(atom_i, atom_j) in &pairs {
        let (rows, cols) = (basis_sizes[atom_i], basis_sizes[atom_j]);
        let mut g = Array2::<f64>::zeros((rows, cols));
        for r in 0..rows {
            for c in 0..cols {
                g[[r, c]] = 0.25 * sample();
            }
        }
        if atom_i == atom_j {
            // Shift the diagonal of self-blocks by (basis size + 2).
            for d in 0..rows.min(cols) {
                g[[d, d]] += rows as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i,
            atom_j,
            g,
            w: w_of(atom_i, atom_j),
        });
    }

    // ---- Smoothness blocks: AᵀA + I per atom ------------------------------
    let mut smooth_blocks = Vec::with_capacity(n_atoms);
    for (k, &m) in basis_sizes.iter().enumerate() {
        let mut a = Array2::<f64>::zeros((m, m));
        for r in 0..m {
            for c in 0..m {
                a[[r, c]] = 0.2 * sample();
            }
        }
        let mut gram = a.t().dot(&a);
        for d in 0..m {
            gram[[d, d]] += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: gram,
        });
    }
    let smooth_ranks = ranks.clone();

    // ---- Arrow system rows ------------------------------------------------
    let n = 400usize;
    let q = 4usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::with_capacity(n);
    for i in 0..n {
        let row = &mut sys.rows[i];
        let mut a = Array2::<f64>::zeros((q, q));
        for r in 0..q {
            for c in 0..q {
                a[[r, c]] = sample();
            }
        }
        // H_tt = AᵀA + (q + 1)·I.
        let mut htt = a.t().dot(&a);
        for d in 0..q {
            htt[[d, d]] += q as f64 + 1.0;
        }
        row.htt = htt;
        // The same cross-block values are written to the flat slab (row-major,
        // `q` rows of `border_dim`) and to the system's own `htbeta`.
        let mut slab = vec![0.0_f64; q * border_dim];
        for (c, slab_row) in slab.chunks_exact_mut(border_dim).enumerate() {
            for (col, slot) in slab_row.iter_mut().enumerate() {
                let v = 0.02 * sample();
                *slot = v;
                row.htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    // ---- Dense H_ββ = frame operator + per-atom smoothness penalties -------
    let data_op =
        FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
            .expect("frame op");
    let mut hbb = data_op.to_dense();
    for (blk, &rank) in smooth_blocks.iter().zip(&ranks) {
        let penalty = IdentityRightKroneckerPenaltyOp {
            factor_a: blk.factor_a.clone(),
            p: rank,
            global_offset: blk.global_offset,
            k: border_dim,
        };
        let dense_penalty = penalty.to_dense();
        for r in 0..border_dim {
            for c in 0..border_dim {
                hbb[[r, c]] += dense_penalty[[r, c]];
            }
        }
    }
    sys.hbb = hbb;

    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };
    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;
    let rhs: Array1<f64> =
        Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());

    // ---- Solver under test ------------------------------------------------
    let solved = solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12);
    let (device, diag) = match solved {
        Ok(result) => result,
        Err(failure) => {
            // An error is only acceptable when there is no GPU runtime at all.
            assert!(
                gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                "#1017: CUDA device present but the framed device SAE PCG \
                 declined/faulted instead of returning a result (tag: {failure:?}) — \
                 the kernel does not run correctly on GPU"
            );
            return;
        }
    };

    // ---- Residuals measured with the CPU matvec ---------------------------
    let rhs_norm = rhs.iter().map(|&v| v * v).sum::<f64>().sqrt();
    // ‖S·x − rhs‖₂ with S applied by the CPU reference matvec.
    let oracle_resid = |x: &Array1<f64>| -> f64 {
        let mut sx = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(
            &sys,
            &data,
            ridge_t,
            ridge_beta,
            x.as_slice().unwrap(),
            &mut sx,
        )
        .expect("cpu oracle matvec");
        sx.iter()
            .zip(rhs.iter())
            .fold(0.0_f64, |acc, (&lhs, &target)| {
                let e = lhs - target;
                acc + e * e
            })
            .sqrt()
    };
    let s_dev_resid = oracle_resid(&device);
    let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);

    // ---- Diagonal preconditioner for the CPU solve ------------------------
    // Penalty diagonal minus, per row, the quadratic form of each border column
    // of the slab with `ainv` (the Cholesky solves of the ridged `htt` block
    // against unit vectors).
    let precond = {
        let frame = data.frame.as_ref().unwrap();
        let mut schur_diag = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
        for (i, row) in sys.rows.iter().enumerate() {
            let slab = &frame.row_htbeta[i];
            let qi = sys.row_dims[i];
            let usable = !slab.is_empty() && qi != 0 && slab.len() == qi * border_dim;
            if !usable {
                continue;
            }
            let mut block = row.htt.clone();
            for d in 0..qi {
                block[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
                .expect("row htt PD");
            let mut ainv = Array2::<f64>::zeros((qi, qi));
            for col in 0..qi {
                let mut unit = Array1::<f64>::zeros(qi);
                unit[col] = 1.0;
                let solved_col = cholesky_solve_vector(factor.view(), unit.view());
                for r in 0..qi {
                    ainv[[r, col]] = solved_col[r];
                }
            }
            for (a, entry) in schur_diag.iter_mut().enumerate() {
                let mut quad = 0.0_f64;
                for c in 0..qi {
                    let h_c = slab[c * border_dim + a];
                    for d in 0..qi {
                        quad += h_c * ainv[[c, d]] * slab[d * border_dim + a];
                    }
                }
                *entry -= quad;
            }
        }
        Array1::from_vec(schur_diag)
    };

    // ---- CPU reference solve with the same operator -----------------------
    let mut cpu = Array1::<f64>::zeros(border_dim);
    let cpu_result = {
        let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
            let mut scratch = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(
                &sys,
                &data,
                ridge_t,
                ridge_beta,
                v.as_slice().unwrap(),
                &mut scratch,
            )
            .expect("cpu oracle matvec");
            out.assign(&Array1::from_vec(scratch));
        };
        gam_linalg::pcg::pcg_core(
            &mut apply,
            &rhs.view(),
            &precond.view(),
            1e-12,
            800,
            32,
            false,
            gam_linalg::pcg::DotReduction::Serial,
            &mut cpu.view_mut(),
        )
    };
    let s_cpu_resid = oracle_resid(&cpu);
    let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);

    // The solver's answer must satisfy the CPU-defined system.
    assert!(
        dev_rel_resid <= 1e-7,
        "[#1551] device δβ does not solve the CPU-oracle system: \
         ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
         device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
         means the device matvec is a DIFFERENT operator (kernel bug)",
        diag.stopping_reason,
        diag.iterations,
        diag.final_relative_residual,
    );
    // The CPU solve must also converge, otherwise the fixture itself is suspect.
    assert!(
        cpu_rel_resid <= 1e-6,
        "[#1551] CPU pcg_core failed to solve the oracle system: \
         ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
        cpu_result.stop,
        cpu_result.iterations,
    );
}
6366
6367 fn sae_frame_penalty_diag_host_for_test(
6371 data: &DeviceSaePcgData,
6372 ridge_beta: f64,
6373 ) -> Vec<f64> {
6374 let frame = data.frame.as_ref().expect("frame");
6375 let mut diag = vec![ridge_beta; data.beta_dim];
6376 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6377 let m = blk.factor_a.nrows();
6378 for ia in 0..m {
6379 let coeff = blk.factor_a[[ia, ia]];
6380 let base = blk.global_offset + ia * r;
6381 for ib in 0..r {
6382 diag[base + ib] += coeff;
6383 }
6384 }
6385 }
6386 for blk in &frame.frame_blocks {
6387 if blk.atom_i != blk.atom_j {
6388 continue;
6389 }
6390 let r = frame.ranks[blk.atom_i];
6391 let off = frame.border_offsets[blk.atom_i];
6392 let (mi, mj) = blk.g.dim();
6393 for li in 0..mi.min(mj) {
6394 let gii = blk.g[[li, li]];
6395 let base = off + li * r;
6396 for a in 0..r {
6397 diag[base + a] += gii * blk.w[[a, a]];
6398 }
6399 }
6400 }
6401 diag
6402 }
6403
/// Single-matvec parity check: `super::framed_schur_matvec_once_on_device`
/// against `sae_framed_schur_matvec_cpu` on a synthetic framed SAE system.
///
/// The fixture has 8 atoms with alternating ranks, 32 rows of dimension 4, and
/// five probe vectors (a sine ramp, a sampled vector, and three unit vectors).
/// Every component must agree to a relative error of 1e-9, scaled by the larger
/// of 1 and the biggest CPU output magnitude. If the device call returns `Err`,
/// the test only requires that no GPU runtime is present and returns early.
#[test]
fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
    use crate::arrow_schur::{
        DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
    };

    // ---- Fixture geometry -------------------------------------------------
    let p = 6usize;
    let n_atoms = 8usize;
    // Even-indexed atoms get a rank-3 frame, odd-indexed atoms the full rank `p`.
    let ranks: Vec<usize> = (0..n_atoms)
        .map(|k| if k % 2 == 0 { 3usize } else { p })
        .collect();
    let basis_sizes: Vec<usize> = vec![3usize; n_atoms];
    // Exclusive prefix sum of `basis_size * rank`; the running total is the border width.
    let mut border_dim = 0usize;
    let border_offsets: Vec<usize> = basis_sizes
        .iter()
        .zip(&ranks)
        .map(|(&m, &r)| {
            let start = border_dim;
            border_dim += m * r;
            start
        })
        .collect();

    // ---- Deterministic pseudo-random source -------------------------------
    // 64-bit LCG; the top 31 bits are scaled by 2^-31 and shifted by -1, so every
    // draw lies in [-1, 0). The order of `sample()` calls below fixes the fixture.
    let mut lcg_state = 0x1551_0017_1026_0922u64;
    let mut sample = || -> f64 {
        lcg_state = lcg_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let top_bits = lcg_state >> 33;
        (top_bits as f64) / ((1u64 << 31) as f64) - 1.0
    };

    // ---- Per-atom frames --------------------------------------------------
    // Full-rank atoms get the identity; reduced-rank atoms get sampled entries.
    let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
    for &rank in &ranks {
        let full_rank = rank == p;
        let mut u = Array2::<f64>::zeros((p, rank));
        for row in 0..p {
            for col in 0..rank {
                u[[row, col]] = match (full_rank, row == col) {
                    (true, true) => 1.0,
                    (true, false) => 0.0,
                    (false, _) => sample(),
                };
            }
        }
        frames.push(u);
    }
    // W_ij = U_iᵀ U_j, accumulated over the shared dimension in index order.
    let w_of = |i: usize, j: usize| {
        let (ui, uj) = (&frames[i], &frames[j]);
        let (ri, rj) = (ranks[i], ranks[j]);
        let mut w = Array2::<f64>::zeros((ri, rj));
        for a in 0..ri {
            for b in 0..rj {
                w[[a, b]] = (0..p).fold(0.0_f64, |acc, c| acc + ui[[c, a]] * uj[[c, b]]);
            }
        }
        w
    };

    // ---- Frame G blocks: all diagonal pairs plus three symmetric couplings --
    let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
    pairs.extend(
        [(0usize, 1usize), (2, 4), (3, 6)]
            .iter()
            .flat_map(|&(i, j)| [(i, j), (j, i)]),
    );
    let mut frame_blocks = Vec::with_capacity(pairs.len());
    for &(atom_i, atom_j) in &pairs {
        let (rows, cols) = (basis_sizes[atom_i], basis_sizes[atom_j]);
        let mut g = Array2::<f64>::zeros((rows, cols));
        for r in 0..rows {
            for c in 0..cols {
                g[[r, c]] = 0.25 * sample();
            }
        }
        if atom_i == atom_j {
            // Shift the diagonal of self-blocks by (basis size + 2).
            for d in 0..rows.min(cols) {
                g[[d, d]] += rows as f64 + 2.0;
            }
        }
        frame_blocks.push(FactoredFrameGBlock {
            atom_i,
            atom_j,
            g,
            w: w_of(atom_i, atom_j),
        });
    }

    // ---- Smoothness blocks: AᵀA + I per atom ------------------------------
    let mut smooth_blocks = Vec::with_capacity(n_atoms);
    for (k, &m) in basis_sizes.iter().enumerate() {
        let mut a = Array2::<f64>::zeros((m, m));
        for r in 0..m {
            for c in 0..m {
                a[[r, c]] = 0.2 * sample();
            }
        }
        let mut gram = a.t().dot(&a);
        for d in 0..m {
            gram[[d, d]] += 1.0;
        }
        smooth_blocks.push(DeviceSaeSmoothBlock {
            global_offset: border_offsets[k],
            factor_a: gram,
        });
    }
    let smooth_ranks = ranks.clone();

    // ---- Arrow system rows ------------------------------------------------
    let n = 32usize;
    let q = 4usize;
    let mut sys = ArrowSchurSystem::new(n, q, border_dim);
    let mut row_htbeta = Vec::with_capacity(n);
    for i in 0..n {
        let row = &mut sys.rows[i];
        let mut a = Array2::<f64>::zeros((q, q));
        for r in 0..q {
            for c in 0..q {
                a[[r, c]] = sample();
            }
        }
        // H_tt = AᵀA + (q + 1)·I.
        let mut htt = a.t().dot(&a);
        for d in 0..q {
            htt[[d, d]] += q as f64 + 1.0;
        }
        row.htt = htt;
        // The same cross-block values are written to the flat slab (row-major,
        // `q` rows of `border_dim`) and to the system's own `htbeta`.
        let mut slab = vec![0.0_f64; q * border_dim];
        for (c, slab_row) in slab.chunks_exact_mut(border_dim).enumerate() {
            for (col, slot) in slab_row.iter_mut().enumerate() {
                let v = 0.3 * sample();
                *slot = v;
                row.htbeta[[c, col]] = v;
            }
        }
        row_htbeta.push(slab);
    }

    let ridge_t = 1e-7;
    let ridge_beta = 1e-6;
    let data = DeviceSaePcgData {
        p,
        beta_dim: border_dim,
        a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
        smooth_blocks,
        sparse_g_blocks: Vec::new(),
        frame: Some(DeviceSaeFrameData {
            ranks: ranks.clone(),
            basis_sizes: basis_sizes.clone(),
            border_offsets: border_offsets.clone(),
            frame_blocks,
            smooth_ranks,
            row_htbeta,
        }),
    };

    // ---- Probe vectors ----------------------------------------------------
    // A sine ramp, one sampled vector, then unit vectors at the first index, one
    // third of the way in, and the last index.
    let sine_probe = Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.37).sin());
    let sampled_probe = Array1::from_shape_fn(border_dim, |_| sample());
    let mut probes: Vec<Array1<f64>> = vec![sine_probe, sampled_probe];
    probes.extend(
        [0usize, border_dim / 3, border_dim - 1]
            .iter()
            .map(|&axis| {
                let mut unit = Array1::<f64>::zeros(border_dim);
                unit[axis] = 1.0;
                unit
            }),
    );

    // ---- Device vs CPU, probe by probe ------------------------------------
    let mut any_ran = false;
    let mut worst = 0.0_f64;
    for (pi, x) in probes.iter().enumerate() {
        let attempt =
            super::framed_schur_matvec_once_on_device(&sys, &data, ridge_t, ridge_beta, x);
        let device = match attempt {
            Ok(out) => out,
            Err(failure) => {
                // An error is only acceptable when there is no GPU runtime at all.
                assert!(
                    gam_gpu::device_runtime::GpuRuntime::global().is_none(),
                    "#1551: CUDA device present but the framed device matvec \
                     declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
                     does not run on GPU"
                );
                return;
            }
        };
        any_ran = true;

        let mut cpu = vec![0.0_f64; border_dim];
        sae_framed_schur_matvec_cpu(
            &sys,
            &data,
            ridge_t,
            ridge_beta,
            x.as_slice().unwrap(),
            &mut cpu,
        )
        .expect("cpu oracle matvec");
        // Error scale: largest CPU output magnitude, floored at 1.
        let scale = cpu.iter().map(|v| v.abs()).fold(0.0_f64, f64::max).max(1.0);
        for a in 0..border_dim {
            let rel = (device[a] - cpu[a]).abs() / scale;
            worst = worst.max(rel);
            assert!(
                rel <= 1e-9,
                "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
                 rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
                device[a],
                cpu[a],
            );
        }
    }

    if any_ran {
        // A successful device matvec implies a GPU runtime must exist.
        assert!(
            gam_gpu::device_runtime::GpuRuntime::global().is_some(),
            "#1551: matvec ran but no GPU runtime — unexpected"
        );
        assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
    }
}
6632}