1use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_linalg::triangular::{CholeskyGuard, cholesky_factor_in_place, cholesky_solve_vector};
22use crate::arrow_schur::{ArrowSchurSystem, DeviceSaePcgData, PcgDiagnostics};
23
/// Result of one ridged Newton step on an arrow-structured system
/// (block-diagonal row blocks coupled through a shared border block).
pub struct ArrowSchurGpuSolution {
    /// Per-row local steps, concatenated: entries `i * d .. (i + 1) * d`
    /// belong to row block `i`.
    pub delta_t: Array1<f64>,
    /// Step for the shared (border) coefficients, length `k`.
    pub delta_beta: Array1<f64>,
    /// `log det` of the ridged Hessian, computed as twice the sum of the log
    /// Cholesky pivots of the row blocks and of the Schur complement.
    pub log_det_hessian: f64,
}
32
/// Why a GPU arrow-Schur operation did not produce a result.
#[derive(Debug, Clone)]
pub enum ArrowSchurGpuFailure {
    /// The GPU path cannot run here: non-Linux build, no CUDA runtime or
    /// device, a handle/stream/allocation failure, or an input the path does
    /// not handle (e.g. an empty system). Presumably callers fall back to a
    /// CPU solver on this variant — TODO confirm at call sites.
    Unavailable,
    /// Cholesky of row block `row` failed; `bump` is a Gershgorin-based
    /// estimate of the extra diagonal ridge that block needs.
    RidgeBumpRequired { row: usize, bump: f64 },
    /// Validation or factorization failed for a reason other than a row-block
    /// pivot failure (e.g. NaN ridge, dimension mismatch, Schur-complement
    /// Cholesky failure); `reason` is a human-readable description.
    SchurFactorFailed { reason: String },
    /// The system carries matrix-free operators but this path needs dense
    /// blocks; the flags record which operators were present.
    GpuRequiresDenseSystem {
        had_hbb_matvec: bool,
        had_htbeta_matvec: bool,
    },
}
56
/// Multiplier on `sqrt(f64::EPSILON) * scale` (scale = largest |diagonal|,
/// floored at 1) giving the safety margin added on top of a Gershgorin
/// ridge-bump estimate.
const RIDGE_BUMP_EPS_MARGIN: f64 = 1024.0;
68
69#[must_use]
105fn ridge_bump_to_make_pd(htt: ArrayView2<'_, f64>, ridge_t: f64) -> f64 {
106 let d = htt.nrows();
107 let mut scale = 1.0_f64;
110 let mut min_gershgorin_edge = f64::INFINITY;
111 for i in 0..d {
112 let diag = htt[[i, i]];
113 scale = scale.max(diag.abs());
114 let mut off_sum = 0.0_f64;
115 for j in 0..d {
116 if j != i {
117 off_sum += htt[[i, j]].abs();
118 }
119 }
120 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
121 }
122 if !min_gershgorin_edge.is_finite() {
123 return scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
126 }
127 let deficit = -(min_gershgorin_edge + ridge_t);
130 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
131 deficit.max(0.0) + margin
134}
135
136#[cfg(target_os = "linux")]
149#[must_use]
150fn ridge_bump_to_make_pd_colmajor(block: &[f64], d: usize) -> f64 {
151 if d == 0 || block.len() < d * d {
152 return f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
153 }
154 let mut scale = 1.0_f64;
157 let mut min_gershgorin_edge = f64::INFINITY;
158 for i in 0..d {
159 let diag = block[i * d + i];
160 scale = scale.max(diag.abs());
161 let mut off_sum = 0.0_f64;
162 for j in 0..d {
163 if j != i {
164 off_sum += block[j * d + i].abs();
165 }
166 }
167 min_gershgorin_edge = min_gershgorin_edge.min(diag - off_sum);
168 }
169 let margin = scale * f64::EPSILON.sqrt() * RIDGE_BUMP_EPS_MARGIN;
170 if !min_gershgorin_edge.is_finite() {
171 return margin;
172 }
173 (-min_gershgorin_edge).max(0.0) + margin
174}
175
176pub fn solve_arrow_newton_step(
180 sys: &ArrowSchurSystem,
181 ridge_t: f64,
182 ridge_beta: f64,
183) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
184 let n = sys.rows.len();
185 let d = sys.d;
186 let k = sys.k;
187
188 let had_hbb_matvec = sys.hbb_matvec.is_some();
193 let had_htbeta_matvec = sys.htbeta_matvec.is_some();
194 if had_hbb_matvec || had_htbeta_matvec {
195 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
196 had_hbb_matvec,
197 had_htbeta_matvec,
198 });
199 }
200
201 if sys.hbb.dim() != (k, k) {
202 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
203 reason: "CUDA arrow-Schur requires a dense shared beta block".to_string(),
204 });
205 }
206 if n == 0 || d == 0 {
207 return Err(ArrowSchurGpuFailure::Unavailable);
208 }
209 if sys
210 .rows
211 .iter()
212 .any(|row| row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d)
213 {
214 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
215 reason: "row block dimension mismatch".to_string(),
216 });
217 }
218
219 #[cfg(not(target_os = "linux"))]
220 {
221 if ridge_t.is_nan() || ridge_beta.is_nan() {
222 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
223 reason: "ridge is NaN".to_string(),
224 });
225 }
226 Err(ArrowSchurGpuFailure::Unavailable)
227 }
228
229 #[cfg(target_os = "linux")]
230 {
231 if gam_gpu::device_runtime::GpuRuntime::global()
240 .map(gam_gpu::device_runtime::GpuRuntime::device_count)
241 .unwrap_or(0)
242 > 1
243 {
244 match cuda::solve_multi_gpu(sys, ridge_t, ridge_beta) {
245 Ok(sol) => return Ok(sol),
246 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
247 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
248 }
249 Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
250 return Err(ArrowSchurGpuFailure::SchurFactorFailed { reason });
251 }
252 Err(_) => {}
255 }
256 }
257 if crate::gpu_kernels::arrow_schur_nvrtc::system_admits_fused_path(sys) {
263 match cuda::solve_fused(sys, ridge_t, ridge_beta) {
264 Ok(sol) => return Ok(sol),
265 Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
269 return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
270 }
271 Err(_) => {}
275 }
276 }
277 cuda::solve(sys, ridge_t, ridge_beta)
278 }
279}
280
281#[cfg(target_os = "linux")]
287fn pack_host(sys: &ArrowSchurSystem, ridge_t: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
288 let n = sys.rows.len();
289 let d = sys.d;
290 let k = sys.k;
291 let mut d_buf = Vec::with_capacity(n * d * d);
292 let mut b_buf = Vec::with_capacity(n * d * k);
293 let mut g_buf = Vec::with_capacity(n * d);
294 for row in &sys.rows {
295 pack_block(row, ridge_t, d, k, &mut d_buf, &mut b_buf, &mut g_buf);
296 }
297 (d_buf, b_buf, g_buf)
298}
299
300#[cfg(target_os = "linux")]
301#[inline]
302fn pack_block(
303 row: &crate::arrow_schur::ArrowRowBlock,
304 ridge_t: f64,
305 d: usize,
306 k: usize,
307 d_buf: &mut Vec<f64>,
308 b_buf: &mut Vec<f64>,
309 g_buf: &mut Vec<f64>,
310) {
311 for col in 0..d {
312 for r in 0..d {
313 let mut value = row.htt[[r, col]];
314 if r == col {
315 value += ridge_t;
316 }
317 d_buf.push(value);
318 }
319 }
320 for col in 0..k {
321 for r in 0..d {
322 b_buf.push(row.htbeta[[r, col]]);
323 }
324 }
325 for r in 0..d {
326 g_buf.push(row.gt[r]);
327 }
328}
329
330#[doc(hidden)]
335#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] pub fn solve_arrow_newton_step_fused_force(
337 sys: &ArrowSchurSystem,
338 ridge_t: f64,
339 ridge_beta: f64,
340) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
341 if ridge_t.is_nan() || ridge_beta.is_nan() {
342 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
343 reason: "ridge is NaN".to_string(),
344 });
345 }
346 #[cfg(not(target_os = "linux"))]
347 {
348 Err(ArrowSchurGpuFailure::Unavailable)
353 }
354 #[cfg(target_os = "linux")]
355 {
356 if crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(sys.rows.len(), sys.d, sys.k)
357 .is_none()
358 {
359 return Err(ArrowSchurGpuFailure::Unavailable);
360 }
361 cuda::solve_fused(sys, ridge_t, ridge_beta)
362 }
363}
364
/// Handle to an arrow-Schur frame built once by `new` and then queried via
/// `solve_gradient` / `log_det_hessian`. Presumably the factorization stays
/// resident on the device between calls — TODO confirm in
/// `cuda::ResidentArrowFrame`.
///
/// Off Linux the only field is `Infallible`, so the type is uninhabited and
/// can never be constructed there.
pub struct ResidentArrowFrameHandle {
    #[cfg(target_os = "linux")]
    inner: cuda::ResidentArrowFrame,
    #[cfg(not(target_os = "linux"))]
    _never: std::convert::Infallible,
}
380
381impl ResidentArrowFrameHandle {
382 pub fn new(
384 sys: &ArrowSchurSystem,
385 ridge_t: f64,
386 ridge_beta: f64,
387 ) -> Result<Self, ArrowSchurGpuFailure> {
388 if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() {
391 return Err(ArrowSchurGpuFailure::GpuRequiresDenseSystem {
392 had_hbb_matvec: sys.hbb_matvec.is_some(),
393 had_htbeta_matvec: sys.htbeta_matvec.is_some(),
394 });
395 }
396 #[cfg(not(target_os = "linux"))]
397 {
398 if ridge_t.is_nan() || ridge_beta.is_nan() {
399 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
400 reason: "ridge is NaN".to_string(),
401 });
402 }
403 Err(ArrowSchurGpuFailure::Unavailable)
404 }
405 #[cfg(target_os = "linux")]
406 {
407 Ok(Self {
408 inner: cuda::ResidentArrowFrame::new(sys, ridge_t, ridge_beta)?,
409 })
410 }
411 }
412
413 pub fn solve_gradient(
415 &self,
416 g_t: &[f64],
417 g_beta: &[f64],
418 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
419 #[cfg(not(target_os = "linux"))]
420 {
421 if g_t.iter().chain(g_beta).any(|v| !v.is_finite()) {
422 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
423 reason: "non-finite gradient entry".to_string(),
424 });
425 }
426 Err(ArrowSchurGpuFailure::Unavailable)
427 }
428 #[cfg(target_os = "linux")]
429 {
430 self.inner.solve_gradient(g_t, g_beta)
431 }
432 }
433
434 #[must_use]
436 pub fn log_det_hessian(&self) -> f64 {
437 #[cfg(not(target_os = "linux"))]
438 {
439 panic!("ResidentArrowFrameHandle cannot be constructed off CUDA")
445 }
446 #[cfg(target_os = "linux")]
447 {
448 self.inner.log_det_hessian()
449 }
450 }
451}
452
453pub fn gpu_schur_matvec_backend(
496 sys: &ArrowSchurSystem,
497 ridge_t: f64,
498 ridge_beta: f64,
499) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
500 if sys.htbeta_matvec.is_some() {
503 return build_row_procedural_matvec(sys, ridge_t, ridge_beta);
504 }
505
506 #[cfg(not(target_os = "linux"))]
507 {
508 if ridge_t.is_nan() || ridge_beta.is_nan() {
511 return Err(ArrowSchurGpuFailure::Unavailable);
512 }
513 Err(ArrowSchurGpuFailure::Unavailable)
514 }
515
516 #[cfg(target_os = "linux")]
517 {
518 cuda::build_schur_matvec_backend(sys, ridge_t, ridge_beta)
519 }
520}
521
/// Builds a host-side Schur-complement matvec for systems whose `H_tbeta`
/// coupling is supplied per row through `htbeta_matvec` (forward, `B_i x`) and
/// `htbeta_transpose_matvec` (adjoint, `B_i^T w`) instead of dense blocks.
///
/// Each shifted row block `H_tt,i + ridge_t * I` is Cholesky-factored once,
/// up front. The returned closure then computes
/// `out = ridge_beta * x + P(x) - sum_i B_i^T (H_tt,i + ridge_t I)^{-1} B_i x`,
/// where `P` is applied by `sys.effective_penalty_op().matvec(x, out)`. That
/// call is assumed to accumulate into `out` rather than overwrite it; TODO
/// confirm against the penalty operator.
///
/// # Errors
/// - `Unavailable` if `htbeta_matvec` is absent.
/// - `SchurFactorFailed` if the transpose operator is missing or a row's
///   `H_tt` is non-square / inconsistent with its gradient length.
/// - `RidgeBumpRequired` if a shifted row block fails to factor.
///
/// # Panics (inside the closure)
/// If `x` or `out` is not length `k` or not contiguous.
fn build_row_procedural_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, ArrowSchurGpuFailure> {
    use std::sync::Arc;
    let n = sys.rows.len();
    let k = sys.k;
    let forward = sys
        .htbeta_matvec
        .clone()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let transpose = sys.htbeta_transpose_matvec.clone().ok_or_else(|| {
        ArrowSchurGpuFailure::SchurFactorFailed {
            reason: "row-procedural Schur matvec requires htbeta_transpose_matvec; \
                     forward operator installed without its sparse adjoint"
                .to_string(),
        }
    })?;

    // Factor every shifted row block once; the closure reuses the factors.
    let mut factors: Vec<Array2<f64>> = Vec::with_capacity(n);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = row.htt.nrows();
        if row.htt.ncols() != di || row.gt.len() != di {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("row {i}: malformed H_tt block {:?}", row.htt.dim()),
            });
        }
        let mut block = row.htt.clone();
        for r in 0..di {
            block[[r, r]] += ridge_t;
        }
        let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
            .ok_or_else(|| {
                ArrowSchurGpuFailure::RidgeBumpRequired {
                    row: i,
                    bump: ridge_bump_to_make_pd(row.htt.view(), ridge_t),
                }
            })?;
        factors.push(factor);
    }

    let penalty_op = sys.effective_penalty_op();
    // Per-row local dimension; rows may differ in size.
    let row_dims: Vec<usize> = sys.rows.iter().map(|row| row.htt.nrows()).collect();

    let closure: crate::arrow_schur::GpuSchurMatvec =
        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
            assert_eq!(x.len(), k, "row-procedural matvec: x.len() != k");
            assert_eq!(out.len(), k, "row-procedural matvec: out.len() != k");

            // out = ridge_beta * x, then the penalty contribution.
            {
                let x_slice = x.as_slice().expect("x must be contiguous");
                let out_slice = out.as_slice_mut().expect("out must be contiguous");
                for a in 0..k {
                    out_slice[a] = ridge_beta * x_slice[a];
                }
                penalty_op.matvec(x_slice, out_slice);
            }

            // Go parallel only for large systems and only when not already
            // running on a rayon worker (avoids nested parallelism).
            let parallel = n >= crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN
                && rayon::current_thread_index().is_none();
            if parallel {
                use rayon::prelude::*;
                const CHUNK: usize = 64;
                // One partial sum of B_i^T (H_i)^{-1} B_i x per chunk of rows;
                // `collect` keeps chunk order, so the reduction below is
                // deterministic.
                let partials: Vec<Array1<f64>> = (0..n)
                    .into_par_iter()
                    .chunks(CHUNK)
                    .map(|idxs| {
                        let mut neg = Array1::<f64>::zeros(k);
                        for i in idxs {
                            let di = row_dims[i];
                            let mut v_i = Array1::<f64>::zeros(di);
                            forward(i, x.view(), &mut v_i);
                            let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
                            // Presumably accumulates B_i^T w_i into `neg`;
                            // it is called repeatedly on the same buffer.
                            transpose(i, w_i.view(), &mut neg);
                        }
                        neg
                    })
                    .collect();
                let mut neg = Array1::<f64>::zeros(k);
                for part in &partials {
                    for a in 0..k {
                        neg[a] += part[a];
                    }
                }
                for a in 0..k {
                    out[a] -= neg[a];
                }
            } else {
                let mut neg = Array1::<f64>::zeros(k);
                for i in 0..n {
                    let di = row_dims[i];
                    let mut v_i = Array1::<f64>::zeros(di);
                    forward(i, x.view(), &mut v_i);
                    let w_i = cholesky_solve_vector(factors[i].view(), v_i.view());
                    transpose(i, w_i.view(), &mut neg);
                }
                for a in 0..k {
                    out[a] -= neg[a];
                }
            }
        });

    Ok(closure)
}
709
710pub fn solve_reduced_beta_pcg(
732 s_acc: &Array2<f64>,
733 rhs_beta: &Array1<f64>,
734 max_iterations: usize,
735 relative_tolerance: f64,
736) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
737 solve_reduced_beta_pcg_with_diagnostics(s_acc, rhs_beta, max_iterations, relative_tolerance)
738 .map(|(x, _)| x)
739}
740
741#[doc(hidden)]
742pub fn solve_reduced_beta_pcg_with_diagnostics(
743 s_acc: &Array2<f64>,
744 rhs_beta: &Array1<f64>,
745 max_iterations: usize,
746 relative_tolerance: f64,
747) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
748 let k = rhs_beta.len();
749 if s_acc.dim() != (k, k) {
750 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
751 reason: format!(
752 "reduced-β GPU PCG requires a square (k×k) Schur block; got {:?} for k={k}",
753 s_acc.dim()
754 ),
755 });
756 }
757 if k == 0 {
758 return Err(ArrowSchurGpuFailure::Unavailable);
759 }
760
761 #[cfg(not(target_os = "linux"))]
762 {
763 if relative_tolerance.is_nan() || max_iterations == 0 {
764 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
765 reason: "reduced-β GPU PCG: invalid CG controls".to_string(),
766 });
767 }
768 Err(ArrowSchurGpuFailure::Unavailable)
769 }
770
771 #[cfg(target_os = "linux")]
772 {
773 cuda::solve_reduced_beta_pcg_with_diagnostics(
774 s_acc,
775 rhs_beta,
776 max_iterations,
777 relative_tolerance,
778 )
779 }
780}
781
782pub fn solve_sae_matrix_free_pcg(
783 sys: &ArrowSchurSystem,
784 data: &DeviceSaePcgData,
785 ridge_t: f64,
786 ridge_beta: f64,
787 rhs_beta: &Array1<f64>,
788 max_iterations: usize,
789 relative_tolerance: f64,
790) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
791 if sys.k != data.beta_dim || rhs_beta.len() != data.beta_dim || data.p == 0 {
792 return Err(ArrowSchurGpuFailure::Unavailable);
793 }
794 #[cfg(not(target_os = "linux"))]
795 {
796 if ridge_t.is_nan()
797 || ridge_beta.is_nan()
798 || relative_tolerance.is_nan()
799 || max_iterations == 0
800 {
801 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
802 reason: "SAE matrix-free GPU PCG: invalid controls".to_string(),
803 });
804 }
805 Err(ArrowSchurGpuFailure::Unavailable)
806 }
807 #[cfg(target_os = "linux")]
808 {
809 if data.frame.is_some() {
816 cuda::solve_sae_matrix_free_pcg_framed(
817 sys,
818 data,
819 ridge_t,
820 ridge_beta,
821 rhs_beta,
822 max_iterations,
823 relative_tolerance,
824 )
825 } else {
826 cuda::solve_sae_matrix_free_pcg(
827 sys,
828 data,
829 ridge_t,
830 ridge_beta,
831 rhs_beta,
832 max_iterations,
833 relative_tolerance,
834 )
835 }
836 }
837}
838
839#[doc(hidden)]
848pub fn framed_schur_matvec_once_on_device(
849 sys: &ArrowSchurSystem,
850 data: &DeviceSaePcgData,
851 ridge_t: f64,
852 ridge_beta: f64,
853 x: &Array1<f64>,
854) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
855 if sys.k != data.beta_dim || x.len() != data.beta_dim || data.p == 0 {
856 return Err(ArrowSchurGpuFailure::Unavailable);
857 }
858 if data.frame.is_none() {
859 return Err(ArrowSchurGpuFailure::Unavailable);
860 }
861 #[cfg(not(target_os = "linux"))]
862 {
863 if ridge_t.is_finite() && ridge_beta.is_finite() {
868 return Err(ArrowSchurGpuFailure::Unavailable);
869 }
870 Err(ArrowSchurGpuFailure::Unavailable)
871 }
872 #[cfg(target_os = "linux")]
873 {
874 cuda::framed_schur_matvec_once_on_device(sys, data, ridge_t, ridge_beta, x)
875 }
876}
877
878#[doc(hidden)]
882pub fn solve_arrow_newton_step_dense_reference(
883 sys: &ArrowSchurSystem,
884 ridge_t: f64,
885 ridge_beta: f64,
886) -> Result<ArrowSchurGpuSolution, String> {
887 let n = sys.rows.len();
888 let d = sys.d;
889 let k = sys.k;
890 let total = n.checked_mul(d).ok_or("dimension overflow")? + k;
891 let mut h = Array2::<f64>::zeros((total, total));
892 let mut rhs = Array1::<f64>::zeros(total);
893 for (i, row) in sys.rows.iter().enumerate() {
894 let base = i * d;
895 for c in 0..d {
896 for r in 0..d {
897 h[[base + r, base + c]] = row.htt[[r, c]];
898 }
899 h[[base + c, base + c]] += ridge_t;
900 }
901 for c in 0..k {
902 for r in 0..d {
903 let value = row.htbeta[[r, c]];
904 h[[base + r, n * d + c]] = value;
905 h[[n * d + c, base + r]] = value;
906 }
907 }
908 for r in 0..d {
909 rhs[base + r] = -row.gt[r];
910 }
911 }
912 for c in 0..k {
913 for r in 0..k {
914 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
915 }
916 h[[n * d + c, n * d + c]] += ridge_beta;
917 rhs[n * d + c] = -sys.gb[c];
918 }
919 let factor = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot)
920 .ok_or_else(|| "dense reference Cholesky failed".to_string())?;
921 let mut log_det = 0.0_f64;
922 for i in 0..total {
923 log_det += factor[[i, i]].ln();
924 }
925 log_det *= 2.0;
926 let solved = cholesky_solve_vector(factor.view(), rhs.view());
927 let delta_t = solved.slice(ndarray::s![..n * d]).to_owned();
928 let delta_beta = solved.slice(ndarray::s![n * d..]).to_owned();
929 Ok(ArrowSchurGpuSolution {
930 delta_t,
931 delta_beta,
932 log_det_hessian: log_det,
933 })
934}
935
936#[doc(hidden)]
947pub fn sae_framed_penalty_matvec_cpu(
948 data: &DeviceSaePcgData,
949 ridge_beta: f64,
950 x: &[f64],
951 out: &mut [f64],
952) {
953 let frame = data
954 .frame
955 .as_ref()
956 .expect("sae_framed_penalty_matvec_cpu requires frame metadata");
957 let k = data.beta_dim;
958 for a in 0..k {
959 out[a] = ridge_beta * x[a];
960 }
961 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
963 let off = blk.global_offset;
964 let m = blk.factor_a.nrows();
965 for i_a in 0..m {
966 for i_b in 0..r {
967 let mut acc = 0.0_f64;
968 for j_a in 0..m {
969 let s = blk.factor_a[[i_a, j_a]];
970 if s == 0.0 {
971 continue;
972 }
973 acc += s * x[off + j_a * r + i_b];
974 }
975 out[off + i_a * r + i_b] += acc;
976 }
977 }
978 }
979 for blk in &frame.frame_blocks {
981 let r_i = frame.ranks[blk.atom_i];
982 let r_j = frame.ranks[blk.atom_j];
983 let off_i = frame.border_offsets[blk.atom_i];
984 let off_j = frame.border_offsets[blk.atom_j];
985 let (m_i, m_j) = blk.g.dim();
986 for li in 0..m_i {
987 let yi_base = off_i + li * r_i;
988 for lj in 0..m_j {
989 let g = blk.g[[li, lj]];
990 if g == 0.0 {
991 continue;
992 }
993 let xj_base = off_j + lj * r_j;
994 for a in 0..r_i {
995 let mut acc = 0.0_f64;
996 for b in 0..r_j {
997 acc += blk.w[[a, b]] * x[xj_base + b];
998 }
999 out[yi_base + a] += g * acc;
1000 }
1001 }
1002 }
1003 }
1004}
1005
1006#[doc(hidden)]
1015pub fn sae_framed_schur_matvec_cpu(
1016 sys: &ArrowSchurSystem,
1017 data: &DeviceSaePcgData,
1018 ridge_t: f64,
1019 ridge_beta: f64,
1020 x: &[f64],
1021 out: &mut [f64],
1022) -> Result<(), String> {
1023 let frame = data
1024 .frame
1025 .as_ref()
1026 .ok_or("sae_framed_schur_matvec_cpu requires frame metadata")?;
1027 let k = data.beta_dim;
1028 sae_framed_penalty_matvec_cpu(data, ridge_beta, x, out);
1029 if frame.row_htbeta.len() != sys.rows.len() {
1030 return Err(format!(
1031 "sae_framed_schur_matvec_cpu: {} row_htbeta slabs but {} rows",
1032 frame.row_htbeta.len(),
1033 sys.rows.len()
1034 ));
1035 }
1036 for (i, row) in sys.rows.iter().enumerate() {
1037 let slab = &frame.row_htbeta[i];
1038 if slab.is_empty() {
1039 continue;
1040 }
1041 let qi = sys.row_dims[i];
1042 if qi == 0 || slab.len() != qi * k {
1043 continue;
1044 }
1045 let mut h = vec![0.0_f64; qi];
1047 for c in 0..qi {
1048 let base = c * k;
1049 let mut acc = 0.0_f64;
1050 for a in 0..k {
1051 acc += slab[base + a] * x[a];
1052 }
1053 h[c] = acc;
1054 }
1055 let mut block = row.htt.clone();
1057 for d in 0..qi {
1058 block[[d, d]] += ridge_t;
1059 }
1060 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
1061 .ok_or_else(|| format!("sae_framed_schur_matvec_cpu: row {i} H_tt not PD"))?;
1062 let s = cholesky_solve_vector(factor.view(), Array1::from_vec(h).view());
1063 for c in 0..qi {
1065 let sc = s[c];
1066 if sc == 0.0 {
1067 continue;
1068 }
1069 let base = c * k;
1070 for a in 0..k {
1071 out[a] -= slab[base + a] * sc;
1072 }
1073 }
1074 }
1075 Ok(())
1076}
1077
1078#[cfg(target_os = "linux")]
1079mod cuda {
1080 use super::{ArrowSchurGpuFailure, ArrowSchurGpuSolution, pack_block, pack_host};
1081 use gam_gpu::driver::to_i32;
1082 use gam_gpu::linalg_dispatch::{DispatchOp, route_through_gpu};
1083 use crate::arrow_schur::{
1084 ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, PcgDiagnostics, PcgStopReason,
1085 };
1086 use cudarc::cublas::sys::{
1087 cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
1088 };
1089 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
1090 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1091 use cudarc::driver::{
1092 CudaContext, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig,
1093 PushKernelArg,
1094 };
1095 use ndarray::Array1;
1096 use std::sync::{Arc, OnceLock};
1097
    /// Per-row host staging for the multi-GPU arrow-Schur solve; filled in by
    /// `solve_multi_gpu`, `forward_tile` and `back_sub_tile`.
    struct RowSlot {
        // Packed column-major `H_tt + ridge_t I` (d x d), from `pack_block`.
        d_block: Vec<f64>,
        // Packed column-major `H_tbeta` (d x k).
        b_block: Vec<f64>,
        // Local gradient (length d).
        g_vec: Vec<f64>,
        // Device copy of `d_block` after the batched Cholesky (holds the
        // factor `L`; assumed lower-triangular storage — TODO confirm in potrf_batched).
        l_block: Vec<f64>,
        // `g_vec` after the batched lower-triangular solve (presumably `L^{-1} g`).
        u_vec: Vec<f64>,
        // `b_block` after the batched lower-triangular solve (presumably `L^{-1} B`).
        y_block: Vec<f64>,
        // Sum of `ln L_jj` over this row's pivots (not yet doubled).
        log_det_local: f64,
        // Set when this row's Cholesky failed: suggested extra ridge.
        bump: Option<f64>,
        // Set only on the first slot of each tile: that tile's k x k
        // column-major Schur contribution and k-length RHS contribution.
        tile_partial_schur: Option<Vec<f64>>,
        tile_partial_rhs: Option<Vec<f64>>,
        // This row's step after back substitution (length d).
        delta_t_block: Vec<f64>,
    }
1121
    /// Multi-GPU arrow-Schur Newton step.
    ///
    /// Stages each row in a `RowSlot` and scatters tiles of rows across
    /// devices with `forward_tile`, which factors the row blocks and returns
    /// per-tile Schur/RHS contributions. The contributions are reduced on the
    /// host in tile order. The `k x k` Schur complement is then factored and
    /// solved on the selected (primary) device. Finally `back_sub_tile` is
    /// scattered to recover the per-row steps.
    ///
    /// # Errors
    /// - `Unavailable` for an empty system (`n`, `d` or `k` zero), matrix-free
    ///   operators, a non-`k x k` shared block, a mis-shaped row, fewer than
    ///   two devices, or any CUDA handle/transfer failure.
    /// - `RidgeBumpRequired` for the first row whose Cholesky failed.
    /// - `SchurFactorFailed` if the Schur-complement Cholesky fails.
    pub(super) fn solve_multi_gpu(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        if n == 0 || d == 0 || k == 0 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }
        if sys.hbb_matvec.is_some() || sys.htbeta_matvec.is_some() || sys.hbb.dim() != (k, k) {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let runtime = gam_gpu::device_runtime::GpuRuntime::global()
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        if runtime.device_count() < 2 {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // Stage every row on the host; the shape checks here return
        // `Unavailable` so the caller can fall back to another path.
        let mut slots: Vec<RowSlot> = Vec::with_capacity(n);
        for row in &sys.rows {
            if row.htt.dim() != (d, d) || row.htbeta.dim() != (d, k) || row.gt.len() != d {
                return Err(ArrowSchurGpuFailure::Unavailable);
            }
            let mut d_block = Vec::with_capacity(d * d);
            let mut b_block = Vec::with_capacity(d * k);
            let mut g_vec = Vec::with_capacity(d);
            pack_block(row, ridge_t, d, k, &mut d_block, &mut b_block, &mut g_vec);
            slots.push(RowSlot {
                d_block,
                b_block,
                g_vec,
                l_block: Vec::new(),
                u_vec: Vec::new(),
                y_block: Vec::new(),
                log_det_local: 0.0,
                bump: None,
                tile_partial_schur: None,
                tile_partial_rhs: None,
                delta_t_block: vec![0.0; d],
            });
        }

        // Forward pass: per-tile Cholesky, triangular solves, Schur partials.
        let forward_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            forward_tile(ordinal, d, k, tile)
        });
        if forward_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        // NOTE(review): assumes `scatter_batched` tiles `slots` with the same
        // partition `balanced_partition` returns — TODO confirm in gam_gpu::pool.
        let row_base_of_tile = gam_gpu::pool::balanced_partition(runtime, n);
        if let Some((row, bump)) = slots
            .iter()
            .enumerate()
            .find_map(|(i, slot)| slot.bump.map(|b| (i, b)))
        {
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump });
        }

        // Schur complement starts as column-major `H_bb + ridge_beta I`; the
        // RHS starts as `-g_beta`. Tile partials are added on top.
        let mut schur_host = vec![0.0_f64; k * k];
        for col in 0..k {
            for row in 0..k {
                let mut v = sys.hbb[[row, col]];
                if row == col {
                    v += ridge_beta;
                }
                schur_host[col * k + row] = v;
            }
        }
        let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
        let mut log_det = 0.0_f64;
        // Partials live on the first slot of each tile (see `forward_tile`).
        for start in tile_starts(&row_base_of_tile) {
            let slot = &slots[start];
            let partial_schur = slot
                .tile_partial_schur
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let partial_rhs = slot
                .tile_partial_rhs
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            for idx in 0..k * k {
                schur_host[idx] += partial_schur[idx];
            }
            for a in 0..k {
                rhs_host[a] += partial_rhs[a];
            }
        }
        for slot in &slots {
            log_det += slot.log_det_local;
        }

        // Factor and solve the Schur complement on the primary device.
        let primary = runtime.selected_device().ordinal;
        let stream = gam_gpu::device_runtime::cuda_context_for(primary)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut schur_dev = stream
            .clone_htod(&schur_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut rhs_dev = stream
            .clone_htod(&rhs_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("multi-GPU Schur Cholesky failed at pivot {info}"),
            });
        }
        // Two triangular solves (presumably L then L^T) give delta_beta.
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
        let delta_beta_host = stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let delta_beta = Array1::from_vec(delta_beta_host.clone());
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;

        // Back substitution for the per-row steps, scattered across devices.
        let delta_beta_ref = &delta_beta_host;
        let back_ok = gam_gpu::pool::scatter_batched(runtime, &mut slots, |ordinal, tile| {
            back_sub_tile(ordinal, d, k, delta_beta_ref, tile)
        });
        if back_ok.is_none() {
            return Err(ArrowSchurGpuFailure::Unavailable);
        }

        let mut delta_t = Array1::<f64>::zeros(n * d);
        for (i, slot) in slots.iter().enumerate() {
            let base = i * d;
            for r in 0..d {
                delta_t[base + r] = slot.delta_t_block[r];
            }
        }

        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: log_det,
        })
    }
1310
1311 fn tile_starts(tiles: &[(usize, std::ops::Range<usize>)]) -> impl Iterator<Item = usize> + '_ {
1314 tiles.iter().map(|(_, range)| range.start)
1315 }
1316
    /// Forward pass for one tile of rows on device `ordinal`.
    ///
    /// Uploads the tile's packed blocks, runs a batched Cholesky, and on
    /// success applies batched lower-triangular solves to the gradients and
    /// coupling blocks. It then accumulates the tile's Schur/RHS contributions
    /// and copies all results back into the slots; the partials are stored on
    /// `tile[0]`. If some row fails to factor, only that row's `bump` is set
    /// and `Some(())` is returned without partials.
    ///
    /// Returns `None` on any CUDA handle, transfer or kernel failure.
    fn forward_tile(ordinal: usize, d: usize, k: usize, tile: &mut [RowSlot]) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let solver = DnHandle::new(stream.clone()).ok()?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        // Concatenate the tile's packed blocks into contiguous batches.
        let mut d_host = Vec::with_capacity(m * d * d);
        let mut b_host = Vec::with_capacity(m * d * k);
        let mut g_host = Vec::with_capacity(m * d);
        for slot in tile.iter() {
            d_host.extend_from_slice(&slot.d_block);
            b_host.extend_from_slice(&slot.b_block);
            g_host.extend_from_slice(&slot.g_vec);
        }
        let mut d_dev = stream.clone_htod(&d_host).ok()?;
        let mut b_dev = stream.clone_htod(&b_host).ok()?;
        let mut g_dev = stream.clone_htod(&g_host).ok()?;

        let info_host = potrf_batched(&solver, &stream, d, m, &mut d_dev).ok()?;
        // Report only the first failing row; the caller turns it into
        // `RidgeBumpRequired`.
        if let Some(local) = info_host.iter().position(|info| *info != 0) {
            tile[local].bump = Some(super::ridge_bump_to_make_pd_colmajor(
                &tile[local].d_block,
                d,
            ));
            return Some(());
        }

        // Presumably u = L^{-1} g and Y = L^{-1} B, in place.
        trsm_batched_lower_inplace(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace(&blas, &stream, d, m, k, &d_dev, &mut b_dev).ok()?;

        // Tile-local Schur/RHS contributions, starting from zero.
        let mut schur_dev = stream.alloc_zeros::<f64>(k * k).ok()?;
        let mut rhs_dev = stream.alloc_zeros::<f64>(k).ok()?;
        accumulate_schur(&blas, d, k, m, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev).ok()?;

        let l_host = stream.clone_dtoh(&d_dev).ok()?;
        let u_host = stream.clone_dtoh(&g_dev).ok()?;
        let y_host = stream.clone_dtoh(&b_dev).ok()?;
        let partial_schur = stream.clone_dtoh(&schur_dev).ok()?;
        let partial_rhs = stream.clone_dtoh(&rhs_dev).ok()?;

        // Scatter the batched results back into per-row slots.
        for (local, slot) in tile.iter_mut().enumerate() {
            let l_base = local * d * d;
            let u_base = local * d;
            let y_base = local * d * k;
            slot.l_block = l_host[l_base..l_base + d * d].to_vec();
            slot.u_vec = u_host[u_base..u_base + d].to_vec();
            slot.y_block = y_host[y_base..y_base + d * k].to_vec();
            let mut log_det_local = 0.0_f64;
            for j in 0..d {
                log_det_local += l_host[l_base + j * d + j].ln();
            }
            slot.log_det_local = log_det_local;
        }
        tile[0].tile_partial_schur = Some(partial_schur);
        tile[0].tile_partial_rhs = Some(partial_rhs);
        Some(())
    }
1397
    /// Back substitution for one tile of rows on device `ordinal`.
    ///
    /// Re-uploads the factors and intermediate vectors stored by
    /// `forward_tile` and combines them with the solved `delta_beta` via
    /// `accumulate_back_sub_rhs`. A batched transposed triangular solve
    /// follows, and the negated result is written into each slot's
    /// `delta_t_block`.
    ///
    /// Returns `None` on any CUDA handle, transfer or kernel failure.
    fn back_sub_tile(
        ordinal: usize,
        d: usize,
        k: usize,
        delta_beta: &[f64],
        tile: &mut [RowSlot],
    ) -> Option<()> {
        if tile.is_empty() {
            return Some(());
        }
        let stream = gam_gpu::device_runtime::cuda_context_for(ordinal)
            .and_then(|ctx| ctx.new_stream().ok())?;
        let blas = CudaBlas::new(stream.clone()).ok()?;
        let m = tile.len();

        let mut l_host = Vec::with_capacity(m * d * d);
        let mut u_host = Vec::with_capacity(m * d);
        let mut y_host = Vec::with_capacity(m * d * k);
        for slot in tile.iter() {
            l_host.extend_from_slice(&slot.l_block);
            u_host.extend_from_slice(&slot.u_vec);
            y_host.extend_from_slice(&slot.y_block);
        }
        let d_dev = stream.clone_htod(&l_host).ok()?;
        let mut g_dev = stream.clone_htod(&u_host).ok()?;
        let b_dev = stream.clone_htod(&y_host).ok()?;
        // NOTE(review): `to_vec` copies `delta_beta` only to upload it.
        let rhs_dev = stream.clone_htod(&delta_beta.to_vec()).ok()?;

        // Presumably g <- u + Y * delta_beta, then g <- L^{-T} g.
        accumulate_back_sub_rhs(&blas, d, k, m, &b_dev, &rhs_dev, &mut g_dev).ok()?;
        trsm_batched_lower_inplace_transposed(&blas, &stream, d, m, 1, &d_dev, &mut g_dev).ok()?;
        let x_host = stream.clone_dtoh(&g_dev).ok()?;
        // The Newton step is the negated solve result.
        for (local, slot) in tile.iter_mut().enumerate() {
            let base = local * d;
            for r in 0..d {
                slot.delta_t_block[r] = -x_host[base + r];
            }
        }
        Some(())
    }
1443
    /// Single-device cuSOLVER/cuBLAS arrow-Schur Newton step.
    ///
    /// Packs all rows (`pack_host`) and runs a batched Cholesky of the shifted
    /// row blocks, followed by batched lower solves of the gradients and
    /// coupling blocks. It then accumulates the Schur complement onto
    /// `H_bb + ridge_beta I` and the RHS onto `-g_beta`, factors and solves
    /// for `delta_beta`, and back-substitutes for `delta_t`. `log det` is
    /// twice the sum of log pivots of every row factor and the Schur factor.
    ///
    /// Assumes the caller already validated the row shapes, which
    /// `solve_arrow_newton_step` does.
    ///
    /// # Errors
    /// - `Unavailable` if dispatch declines the GPU or any CUDA call fails.
    /// - `RidgeBumpRequired` for the first row whose Cholesky failed.
    /// - `SchurFactorFailed` if the Schur-complement Cholesky fails.
    pub(super) fn solve(
        sys: &ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
        let n = sys.rows.len();
        let d = sys.d;
        let k = sys.k;
        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;

        let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
            .and_then(|ctx| ctx.new_stream().ok())
            .ok_or(ArrowSchurGpuFailure::Unavailable)?;
        let solver =
            DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

        let (d_host, b_host, g_host) = pack_host(sys, ridge_t);
        let mut d_dev = stream
            .clone_htod(&d_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut b_dev = stream
            .clone_htod(&b_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut g_dev = stream
            .clone_htod(&g_host)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

        let info_host = potrf_batched(&solver, &stream, d, n, &mut d_dev)?;
        if let Some(idx) = info_host.iter().position(|info| *info != 0) {
            return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
                row: idx,
                bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
            });
        }

        // Presumably u = L^{-1} g and Y = L^{-1} B, in place.
        trsm_batched_lower_inplace(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;
        trsm_batched_lower_inplace(&blas, &stream, d, n, k, &d_dev, &mut b_dev)?;

        // Column-major `H_bb + ridge_beta I`.
        let schur_init: Vec<f64> = {
            let mut tmp = Vec::with_capacity(k * k);
            for col in 0..k {
                for row in 0..k {
                    let mut v = sys.hbb[[row, col]];
                    if row == col {
                        v += ridge_beta;
                    }
                    tmp.push(v);
                }
            }
            tmp
        };
        let mut schur_dev = stream
            .clone_htod(&schur_init)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let rhs_init: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
        let mut rhs_dev = stream
            .clone_htod(&rhs_init)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

        accumulate_schur(&blas, d, k, n, &b_dev, &g_dev, &mut schur_dev, &mut rhs_dev)?;

        let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
        if info != 0 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("Schur Cholesky failed at pivot {info}"),
            });
        }
        // Two triangular solves (presumably L then L^T) give delta_beta.
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
        trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
        let delta_beta_host = stream
            .clone_dtoh(&rhs_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // NOTE(review): `delta_beta_host` is not used after this; the clone
        // could be a move.
        let delta_beta = Array1::from_vec(delta_beta_host.clone());

        // Back substitution: combine with delta_beta, then transposed solve.
        accumulate_back_sub_rhs(&blas, d, k, n, &b_dev, &rhs_dev, &mut g_dev)?;
        trsm_batched_lower_inplace_transposed(&blas, &stream, d, n, 1, &d_dev, &mut g_dev)?;

        let x_host = stream
            .clone_dtoh(&g_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let mut delta_t = Array1::<f64>::zeros(n * d);
        for (i, v) in x_host.iter().enumerate() {
            delta_t[i] = -*v;
        }

        let l_local_host = stream
            .clone_dtoh(&d_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        let l_schur_host = stream
            .clone_dtoh(&schur_dev)
            .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
        // log det = 2 * (sum of log row-factor pivots + sum of log Schur pivots).
        let mut log_det = 0.0_f64;
        for i in 0..n {
            let base = i * d * d;
            for j in 0..d {
                log_det += l_local_host[base + j * d + j].ln();
            }
        }
        for j in 0..k {
            log_det += l_schur_host[j * k + j].ln();
        }
        log_det *= 2.0;

        Ok(ArrowSchurGpuSolution {
            delta_t,
            delta_beta,
            log_det_hessian: log_det,
        })
    }
1598
1599 fn potrf_batched(
1600 solver: &DnHandle,
1601 stream: &Arc<CudaStream>,
1602 p: usize,
1603 batch: usize,
1604 matrices: &mut CudaSlice<f64>,
1605 ) -> Result<Vec<i32>, ArrowSchurGpuFailure> {
1606 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1607 let batch_i = to_i32(batch).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1608 let matrix_len = p * p;
1609 let bytes_per = (matrix_len * std::mem::size_of::<f64>()) as u64;
1610 let (base_ptr, _record) = matrices.device_ptr_mut(stream);
1611 let mut ptrs = Vec::with_capacity(batch);
1612 for idx in 0..batch {
1613 ptrs.push(base_ptr + (idx as u64) * bytes_per);
1614 }
1615 let mut ptrs_dev = stream
1616 .clone_htod(&ptrs)
1617 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1618 let mut info_dev = stream
1619 .alloc_zeros::<i32>(batch)
1620 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1621 let status = {
1622 let (ptrs_ptr, _ptrs_record) = ptrs_dev.device_ptr_mut(stream);
1623 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
1624 unsafe {
1627 cusolver_sys::cusolverDnDpotrfBatched(
1628 solver.cu(),
1629 cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1630 p_i,
1631 ptrs_ptr as *mut *mut f64,
1632 p_i,
1633 info_ptr as *mut i32,
1634 batch_i,
1635 )
1636 }
1637 };
1638 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1639 return Err(ArrowSchurGpuFailure::Unavailable);
1640 }
1641 stream
1642 .clone_dtoh(&info_dev)
1643 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
1644 }
1645
1646 fn potrf_single(
1647 solver: &DnHandle,
1648 stream: &Arc<CudaStream>,
1649 p: usize,
1650 matrix: &mut CudaSlice<f64>,
1651 ) -> Result<i32, ArrowSchurGpuFailure> {
1652 let p_i = to_i32(p).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1653 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
1654 let mut lwork = 0_i32;
1655 {
1656 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1657 let status = unsafe {
1659 cusolver_sys::cusolverDnDpotrf_bufferSize(
1660 solver.cu(),
1661 uplo,
1662 p_i,
1663 mat_ptr as *mut f64,
1664 p_i,
1665 &mut lwork,
1666 )
1667 };
1668 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1669 return Err(ArrowSchurGpuFailure::Unavailable);
1670 }
1671 }
1672 let lwork_usize = usize::try_from(lwork).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1673 let mut workspace = stream
1674 .alloc_zeros::<f64>(lwork_usize.max(1))
1675 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1676 let mut info_dev = stream
1677 .alloc_zeros::<i32>(1)
1678 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1679 {
1680 let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
1681 let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
1682 let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
1683 let status = unsafe {
1685 cusolver_sys::cusolverDnDpotrf(
1686 solver.cu(),
1687 uplo,
1688 p_i,
1689 mat_ptr as *mut f64,
1690 p_i,
1691 work_ptr as *mut f64,
1692 lwork,
1693 info_ptr as *mut i32,
1694 )
1695 };
1696 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1697 return Err(ArrowSchurGpuFailure::Unavailable);
1698 }
1699 }
1700 let info_host = stream
1701 .clone_dtoh(&info_dev)
1702 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1703 Ok(info_host[0])
1704 }
1705
1706 fn trsm_batched_lower_inplace(
1710 blas: &CudaBlas,
1711 stream: &Arc<CudaStream>,
1712 d: usize,
1713 n: usize,
1714 nrhs: usize,
1715 l_stack: &CudaSlice<f64>,
1716 rhs_stack: &mut CudaSlice<f64>,
1717 ) -> Result<(), ArrowSchurGpuFailure> {
1718 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, false)
1719 }
1720
1721 fn trsm_batched_lower_inplace_transposed(
1723 blas: &CudaBlas,
1724 stream: &Arc<CudaStream>,
1725 d: usize,
1726 n: usize,
1727 nrhs: usize,
1728 l_stack: &CudaSlice<f64>,
1729 rhs_stack: &mut CudaSlice<f64>,
1730 ) -> Result<(), ArrowSchurGpuFailure> {
1731 trsm_batched_inplace_inner(blas, stream, d, n, nrhs, l_stack, rhs_stack, true)
1732 }
1733
1734 fn trsm_batched_inplace_inner(
1735 blas: &CudaBlas,
1736 stream: &Arc<CudaStream>,
1737 d: usize,
1738 n: usize,
1739 nrhs: usize,
1740 l_stack: &CudaSlice<f64>,
1741 rhs_stack: &mut CudaSlice<f64>,
1742 transposed: bool,
1743 ) -> Result<(), ArrowSchurGpuFailure> {
1744 let alpha = 1.0_f64;
1745 let d_i = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1746 let nrhs_i = to_i32(nrhs).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1747 let batch_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1748 let l_bytes_per = (d * d * std::mem::size_of::<f64>()) as u64;
1749 let rhs_bytes_per = (d * nrhs * std::mem::size_of::<f64>()) as u64;
1750 let (l_base, _l_record) = l_stack.device_ptr(stream);
1751 let (rhs_base, _rhs_record) = rhs_stack.device_ptr_mut(stream);
1752 let mut l_ptrs = Vec::with_capacity(n);
1753 let mut rhs_ptrs = Vec::with_capacity(n);
1754 for i in 0..n {
1755 l_ptrs.push(l_base + (i as u64) * l_bytes_per);
1756 rhs_ptrs.push(rhs_base + (i as u64) * rhs_bytes_per);
1757 }
1758 let mut l_ptrs_dev = stream
1759 .clone_htod(&l_ptrs)
1760 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1761 let mut rhs_ptrs_dev = stream
1762 .clone_htod(&rhs_ptrs)
1763 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1764 let (l_ptrs_ptr, _l_ptrs_rec) = l_ptrs_dev.device_ptr_mut(stream);
1765 let (rhs_ptrs_ptr, _rhs_ptrs_rec) = rhs_ptrs_dev.device_ptr_mut(stream);
1766 let op = if transposed {
1767 cublasOperation_t::CUBLAS_OP_T
1768 } else {
1769 cublasOperation_t::CUBLAS_OP_N
1770 };
1771 let handle = *blas.handle();
1772 let status = unsafe {
1775 cudarc::cublas::sys::cublasDtrsmBatched(
1776 handle,
1777 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1778 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
1779 op,
1780 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1781 d_i,
1782 nrhs_i,
1783 &alpha,
1784 l_ptrs_ptr as *const *const f64,
1785 d_i,
1786 rhs_ptrs_ptr as *const *mut f64,
1787 d_i,
1788 batch_i,
1789 )
1790 };
1791 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1792 return Err(ArrowSchurGpuFailure::Unavailable);
1793 }
1794 Ok(())
1795 }
1796
1797 fn trsm_single(
1800 blas: &CudaBlas,
1801 stream: &Arc<CudaStream>,
1802 n: usize,
1803 l: &CudaSlice<f64>,
1804 rhs: &mut CudaSlice<f64>,
1805 upper: bool,
1806 transposed: bool,
1807 ) -> Result<(), ArrowSchurGpuFailure> {
1808 let alpha = 1.0_f64;
1809 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
1810 let handle = *blas.handle();
1811 let (l_ptr, _l_rec) = l.device_ptr(stream);
1812 let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
1813 let status = unsafe {
1815 cudarc::cublas::sys::cublasDtrsm_v2(
1816 handle,
1817 cublasSideMode_t::CUBLAS_SIDE_LEFT,
1818 if upper {
1819 cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
1820 } else {
1821 cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
1822 },
1823 if transposed {
1824 cublasOperation_t::CUBLAS_OP_T
1825 } else {
1826 cublasOperation_t::CUBLAS_OP_N
1827 },
1828 cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1829 n_i,
1830 1,
1831 &alpha,
1832 l_ptr as *const f64,
1833 n_i,
1834 rhs_ptr as *mut f64,
1835 n_i,
1836 )
1837 };
1838 if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1839 return Err(ArrowSchurGpuFailure::Unavailable);
1840 }
1841 Ok(())
1842 }
1843
1844 fn accumulate_schur(
1848 blas: &CudaBlas,
1849 d: usize,
1850 k: usize,
1851 n: usize,
1852 y_stack: &CudaSlice<f64>,
1853 u_stack: &CudaSlice<f64>,
1854 schur: &mut CudaSlice<f64>,
1855 rhs: &mut CudaSlice<f64>,
1856 ) -> Result<(), ArrowSchurGpuFailure> {
1857 let y_block_elems = d * k;
1858 let u_block_elems = d;
1859 for i in 0..n {
1860 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1861 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1862 let gemm_cfg = GemmConfig::<f64> {
1864 transa: cublasOperation_t::CUBLAS_OP_T,
1865 transb: cublasOperation_t::CUBLAS_OP_N,
1866 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1867 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1868 k: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1869 alpha: -1.0,
1870 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1871 ldb: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1872 beta: 1.0,
1873 ldc: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1874 };
1875 unsafe { blas.gemm(gemm_cfg, &y_slice, &y_slice, schur) }
1877 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1878 let gemv_cfg = GemvConfig::<f64> {
1880 trans: cublasOperation_t::CUBLAS_OP_T,
1881 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1882 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1883 alpha: 1.0,
1884 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1885 incx: 1,
1886 beta: 1.0,
1887 incy: 1,
1888 };
1889 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1892 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1893 }
1894 Ok(())
1895 }
1896
1897 fn accumulate_schur_rhs_only(
1905 blas: &CudaBlas,
1906 d: usize,
1907 k: usize,
1908 n: usize,
1909 y_stack: &CudaSlice<f64>,
1910 u_stack: &CudaSlice<f64>,
1911 rhs: &mut CudaSlice<f64>,
1912 ) -> Result<(), ArrowSchurGpuFailure> {
1913 let y_block_elems = d * k;
1914 let u_block_elems = d;
1915 for i in 0..n {
1916 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1917 let u_slice = u_stack.slice(i * u_block_elems..(i + 1) * u_block_elems);
1918 let gemv_cfg = GemvConfig::<f64> {
1919 trans: cublasOperation_t::CUBLAS_OP_T,
1920 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1921 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1922 alpha: 1.0,
1923 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1924 incx: 1,
1925 beta: 1.0,
1926 incy: 1,
1927 };
1928 unsafe { blas.gemv(gemv_cfg, &y_slice, &u_slice, rhs) }
1931 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1932 }
1933 Ok(())
1934 }
1935
1936 fn accumulate_back_sub_rhs(
1939 blas: &CudaBlas,
1940 d: usize,
1941 k: usize,
1942 n: usize,
1943 y_stack: &CudaSlice<f64>,
1944 delta_beta: &CudaSlice<f64>,
1945 u_stack: &mut CudaSlice<f64>,
1946 ) -> Result<(), ArrowSchurGpuFailure> {
1947 let y_block_elems = d * k;
1948 let u_block_elems = d;
1949 for i in 0..n {
1950 let y_slice = y_stack.slice(i * y_block_elems..(i + 1) * y_block_elems);
1951 let mut u_slice = u_stack.slice_mut(i * u_block_elems..(i + 1) * u_block_elems);
1952 let gemv_cfg = GemvConfig::<f64> {
1953 trans: cublasOperation_t::CUBLAS_OP_N,
1954 m: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1955 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1956 alpha: 1.0,
1957 lda: to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?,
1958 incx: 1,
1959 beta: 1.0,
1960 incy: 1,
1961 };
1962 unsafe { blas.gemv(gemv_cfg, &y_slice, delta_beta, &mut u_slice) }
1965 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
1966 }
1967 Ok(())
1968 }
1969
1970 use std::collections::HashMap;
1986 use std::sync::Mutex;
1987
    /// Cache of NVRTC-compiled, loaded fused arrow-Schur kernel modules.
    ///
    /// A single instance lives in the `OnceLock` static inside `fused_module_cache`,
    /// and `fused_module_for` consults it before compiling. NOTE(review): entries are
    /// keyed only by `FusedModuleCacheKey`; confirm the key distinguishes CUDA
    /// contexts/devices if more than one is ever used.
    struct FusedModuleCache {
        /// Template key → loaded module. Values are `Arc`s so callers keep using a
        /// module after the lock is released.
        modules: Mutex<
            HashMap<crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey, Arc<CudaModule>>,
        >,
    }
1997
1998 fn fused_module_cache() -> &'static FusedModuleCache {
1999 static CACHE: OnceLock<FusedModuleCache> = OnceLock::new();
2000 CACHE.get_or_init(|| FusedModuleCache {
2001 modules: Mutex::new(HashMap::new()),
2002 })
2003 }
2004
2005 fn fused_module_for(
2006 ctx: &Arc<CudaContext>,
2007 key: crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey,
2008 ) -> Result<Arc<CudaModule>, ArrowSchurGpuFailure> {
2009 let cache = fused_module_cache();
2010 if let Ok(guard) = cache.modules.lock() {
2011 if let Some(existing) = guard.get(&key) {
2012 return Ok(existing.clone());
2013 }
2014 }
2015 let src = crate::gpu_kernels::arrow_schur_nvrtc::forward_kernel_source(
2016 key.p_max as usize,
2017 key.r_template as usize,
2018 );
2019 let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).map_err(|err| {
2020 ArrowSchurGpuFailure::SchurFactorFailed {
2021 reason: format!(
2022 "arrow-schur fused NVRTC compile (p_max={}, r={}): {err}",
2023 key.p_max, key.r_template
2024 ),
2025 }
2026 })?;
2027 let module = ctx
2028 .load_module(ptx)
2029 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2030 if let Ok(mut guard) = cache.modules.lock() {
2031 guard.entry(key).or_insert_with(|| module.clone());
2032 }
2033 Ok(module)
2034 }
2035
    /// CUDA C source, compiled at runtime by `pcg_vector_module`, for the device-side
    /// PCG vector helpers and the matrix-free SAE reduced-Schur operator.
    ///
    /// Kernels (each returns early for threads outside its logical extent):
    /// * `arrow_pcg_jacobi_mul`: `z = inv_diag ⊙ r` (elementwise).
    /// * `arrow_pcg_update_p`: `p = z + beta · p`.
    /// * `arrow_sae_init`: `out = ridge · x`.
    /// * `arrow_sae_smooth_matvec` / `arrow_sae_frame_smooth_matvec`: one grid row
    ///   (`blockIdx.y`) per smooth block, adding the block product into `out` with a
    ///   plain (non-atomic) `+=`. NOTE(review): race-free only if distinct smooth
    ///   blocks cover disjoint `out` ranges — confirm against the packer.
    /// * `arrow_sae_sparse_g_matvec` / `arrow_sae_frame_g_matvec`: add the data-fit
    ///   `G` (resp. `G ⊗ W`) blocks using `atomicAdd`, because blocks sharing a row
    ///   atom target the same output entries.
    /// * `arrow_sae_gather_u` → `arrow_sae_apply_l` → `arrow_sae_apply_ainv` →
    ///   `arrow_sae_scatter_sub` (buffers `u` → `w` → `v` → `out`): per-row
    ///   reduced-Schur subtraction using the CSR `row_ptr`/`beta_base`/`phi` map, the
    ///   per-row Jacobian `jac`, and the cached `ainv`. `arrow_sae_diag_sub` applies
    ///   the matching correction to a Jacobi diagonal.
    ///   NOTE(review): `arrow_sae_apply_l` writes only the `c < q_i` entries of `w`,
    ///   but `arrow_sae_apply_ainv` sums over all `max_q` columns. The padded tail of
    ///   `w` (or of `ainv`) must therefore be zero — confirm.
    /// * `arrow_sae_frame_apply_h` → `arrow_sae_frame_apply_ainv` →
    ///   `arrow_sae_frame_scatter_h`, plus `arrow_sae_frame_diag_sub`: the same
    ///   reduction with a dense row-major per-row `H_tβ` (`htb`, `htb_ptr`, `q_of`).
    ///
    /// All index arithmetic is 32-bit `int` (e.g. `row * max_q * max_q`), so host
    /// code must keep every flattened offset below `i32::MAX`. Double-precision
    /// `atomicAdd` needs sm_60+.
    const PCG_VECTOR_KERNEL_SOURCE: &str = r#"
extern "C" __global__ void arrow_pcg_jacobi_mul(
    const double* __restrict__ inv_diag,
    const double* __restrict__ r,
    double* __restrict__ z,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        z[idx] = inv_diag[idx] * r[idx];
    }
}

extern "C" __global__ void arrow_pcg_update_p(
    const double* __restrict__ z,
    double beta,
    double* __restrict__ p,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        p[idx] = z[idx] + beta * p[idx];
    }
}

extern "C" __global__ void arrow_sae_init(
    double* __restrict__ out,
    const double* __restrict__ x,
    double ridge,
    int n
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        out[idx] = ridge * x[idx];
    }
}

extern "C" __global__ void arrow_sae_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int total = m * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * p + oc];
    }
    out[off + li * p + oc] += acc;
}

extern "C" __global__ void arrow_sae_sparse_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ row_off,
    const int* __restrict__ col_off,
    const int* __restrict__ rows,
    const int* __restrict__ cols,
    const int* __restrict__ data_ptr,
    const double* __restrict__ data,
    int p,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m_i = rows[block_id];
    int m_j = cols[block_id];
    int total = m_i * p;
    if (linear >= total) {
        return;
    }
    int li = linear / p;
    int oc = linear - li * p;
    int rbase = row_off[block_id];
    int cbase = col_off[block_id];
    int dbase = data_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m_j; ++lj) {
        acc += data[dbase + li * m_j + lj] * x[(cbase + lj) * p + oc];
    }
    // #1017 — a row atom co-occurs with multiple column atoms, so several
    // concurrent (atom_i, atom_j) blocks (blockIdx.y) write the SAME output
    // element `out[(rbase+li)*p+oc]`. A plain `+=` races and loses updates
    // (silently-wrong Schur matvec); accumulate atomically. `double` atomicAdd
    // needs sm_60+, guaranteed by the NVRTC arch pin (#1551).
    atomicAdd(&out[(rbase + li) * p + oc], acc);
}

extern "C" __global__ void arrow_sae_gather_u(
    const double* __restrict__ x,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ u,
    int p,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    double acc = 0.0;
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        acc += phi[e] * x[beta_base[e] + oc];
    }
    u[row * p + oc] = acc;
}

extern "C" __global__ void arrow_sae_apply_l(
    const double* __restrict__ u,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    double* __restrict__ w,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    if (c >= q) {
        return;
    }
    double acc = 0.0;
    for (int oc = 0; oc < p; ++oc) {
        acc += jac[jstart + c * p + oc] * u[row * p + oc];
    }
    w[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ w,
    double* __restrict__ v,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) {
        return;
    }
    double acc = 0.0;
    int base = row * max_q * max_q;
    for (int j = 0; j < max_q; ++j) {
        acc += ainv[base + c * max_q + j] * w[row * max_q + j];
    }
    v[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_scatter_sub(
    const double* __restrict__ v,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    double* __restrict__ out,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    double lt_v = 0.0;
    for (int c = 0; c < q; ++c) {
        lt_v += jac[jstart + c * p + oc] * v[row * max_q + c];
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        atomicAdd(&out[beta_base[e] + oc], -phi[e] * lt_v);
    }
}

extern "C" __global__ void arrow_sae_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ jac_ptr,
    const double* __restrict__ jac,
    const int* __restrict__ row_ptr,
    const int* __restrict__ beta_base,
    const double* __restrict__ phi,
    int p,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || oc >= p) {
        return;
    }
    int jstart = jac_ptr[row];
    int q = (jac_ptr[row + 1] - jstart) / p;
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double lc = jac[jstart + c * p + oc];
        for (int d = 0; d < q; ++d) {
            quad += lc * ainv[abase + c * max_q + d] * jac[jstart + d * p + oc];
        }
    }
    int start = row_ptr[row];
    int end = row_ptr[row + 1];
    for (int e = start; e < end; ++e) {
        double pe = phi[e];
        atomicAdd(&diag[beta_base[e] + oc], -(pe * pe) * quad);
    }
}

/* ── #1017/#1026 frames-engaged device kernels ─────────────────────────────
 * The factored β border is C-space (width Σ M_k·r_k). The penalty side is the
 * smooth `λ S_k ⊗ I_{r_k}` (per-block right-width r_k) plus the data-fit
 * `G_{ij} ⊗ W_{ij}` (W = U_iᵀU_j, dense r_i×r_j). The reduced-Schur term uses
 * the per-row DENSE cross-block H_tβ^(i) (q_i × border_dim, row-major). */

extern "C" __global__ void arrow_sae_frame_smooth_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ block_offsets,
    const int* __restrict__ block_m,
    const int* __restrict__ block_r,
    const int* __restrict__ factor_ptr,
    const double* __restrict__ factors,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int m = block_m[block_id];
    int r = block_r[block_id];
    int total = m * r;
    if (linear >= total) {
        return;
    }
    int li = linear / r;
    int ib = linear - li * r;
    int off = block_offsets[block_id];
    int fbase = factor_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < m; ++lj) {
        double a = factors[fbase + li * m + lj];
        acc += a * x[off + lj * r + ib];
    }
    out[off + li * r + ib] += acc;
}

extern "C" __global__ void arrow_sae_frame_g_matvec(
    const double* __restrict__ x,
    double* __restrict__ out,
    const int* __restrict__ off_i,
    const int* __restrict__ off_j,
    const int* __restrict__ r_i,
    const int* __restrict__ r_j,
    const int* __restrict__ m_i,
    const int* __restrict__ m_j,
    const int* __restrict__ g_ptr,
    const double* __restrict__ g_data,
    const int* __restrict__ w_ptr,
    const double* __restrict__ w_data,
    int n_blocks
) {
    int block_id = blockIdx.y;
    int linear = blockIdx.x * blockDim.x + threadIdx.x;
    if (block_id >= n_blocks) {
        return;
    }
    int ri = r_i[block_id];
    int rj = r_j[block_id];
    int mi = m_i[block_id];
    int mj = m_j[block_id];
    int total = mi * ri;
    if (linear >= total) {
        return;
    }
    int li = linear / ri; // basis row in atom i
    int a = linear - li * ri; // frame coord in atom i
    int oi = off_i[block_id];
    int oj = off_j[block_id];
    int gbase = g_ptr[block_id];
    int wbase = w_ptr[block_id];
    double acc = 0.0;
    for (int lj = 0; lj < mj; ++lj) {
        double g = g_data[gbase + li * mj + lj];
        if (g == 0.0) { continue; }
        int xj_base = oj + lj * rj;
        double inner = 0.0;
        for (int b = 0; b < rj; ++b) {
            inner += w_data[wbase + a * rj + b] * x[xj_base + b];
        }
        acc += g * inner;
    }
    // #1017 — same race as `arrow_sae_sparse_g_matvec`: atom i is the row atom of
    // multiple co-occurring (i,j) frame blocks running concurrently on
    // blockIdx.y, all targeting `out[oi+li*ri+a]`. Accumulate atomically so the
    // framed G⊗W matvec is correct (the CPU oracle sums these sequentially).
    atomicAdd(&out[oi + li * ri + a], acc);
}

/* Per-row reduced-Schur subtraction with a DENSE cross-block H_tβ^(i).
 * h_i = H_tβ^(i) · x (length q_i)
 * s_i = (H_tt^(i)+ρ_t I)⁻¹ h_i (apply cached ainv, length q_i)
 * out -= (H_tβ^(i))ᵀ · s_i (scatter into border_dim)
 * `htb` is row-major (q_i × k) flattened, `htb_ptr` gives each row's base and
 * (htb_ptr[row+1]-htb_ptr[row])/k == q_i. `q_of` carries q_i directly. */
extern "C" __global__ void arrow_sae_frame_apply_h(
    const double* __restrict__ x,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ hvec,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows) { return; }
    int q = q_of[row];
    if (c >= q) { return; }
    int base = htb_ptr[row] + c * k;
    double acc = 0.0;
    for (int a = 0; a < k; ++a) {
        acc += htb[base + a] * x[a];
    }
    hvec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_apply_ainv(
    const double* __restrict__ ainv,
    const double* __restrict__ hvec,
    const int* __restrict__ q_of,
    double* __restrict__ svec,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || c >= max_q) { return; }
    int q = q_of[row];
    double acc = 0.0;
    int abase = row * max_q * max_q;
    for (int j = 0; j < q; ++j) {
        acc += ainv[abase + c * max_q + j] * hvec[row * max_q + j];
    }
    svec[row * max_q + c] = acc;
}

extern "C" __global__ void arrow_sae_frame_scatter_h(
    const double* __restrict__ svec,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    double* __restrict__ out,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    double acc = 0.0;
    for (int c = 0; c < q; ++c) {
        acc += htb[hbase + c * k + a] * svec[row * max_q + c];
    }
    atomicAdd(&out[a], -acc);
}

/* Frame Jacobi diagonal subtraction: diag[a] -= Σ_c Σ_d H_tβ[c,a]·ainv[c,d]·H_tβ[d,a]. */
extern "C" __global__ void arrow_sae_frame_diag_sub(
    double* __restrict__ diag,
    const double* __restrict__ ainv,
    const int* __restrict__ htb_ptr,
    const double* __restrict__ htb,
    const int* __restrict__ q_of,
    int k,
    int max_q,
    int n_rows
) {
    int row = blockIdx.y;
    int a = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= n_rows || a >= k) { return; }
    int q = q_of[row];
    int hbase = htb_ptr[row];
    int abase = row * max_q * max_q;
    double quad = 0.0;
    for (int c = 0; c < q; ++c) {
        double hc = htb[hbase + c * k + a];
        for (int d = 0; d < q; ++d) {
            quad += hc * ainv[abase + c * max_q + d] * htb[hbase + d * k + a];
        }
    }
    atomicAdd(&diag[a], -quad);
}
"#;
2468
2469 fn pcg_vector_module(
2470 ctx: &Arc<CudaContext>,
2471 ) -> Result<&'static Arc<CudaModule>, ArrowSchurGpuFailure> {
2472 static CACHE: gam_gpu::device_cache::PtxModuleCache =
2473 gam_gpu::device_cache::PtxModuleCache::new();
2474 CACHE
2475 .get_or_compile(ctx, "arrow_pcg_vector", PCG_VECTOR_KERNEL_SOURCE)
2476 .map_err(|err| {
2477 log::warn!("[#1551] pcg_vector_module get_or_compile failed: {err}");
2483 ArrowSchurGpuFailure::Unavailable
2484 })
2485 }
2486
2487 fn pcg_launch_config(n: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
2488 let threads = 256u32;
2489 let blocks = ((n as u32).saturating_add(threads - 1) / threads).max(1);
2490 Ok(LaunchConfig {
2491 grid_dim: (blocks, 1, 1),
2492 block_dim: (threads, 1, 1),
2493 shared_mem_bytes: 0,
2494 })
2495 }
2496
2497 fn launch_jacobi_mul(
2498 stream: &Arc<CudaStream>,
2499 module: &Arc<CudaModule>,
2500 inv_diag: &CudaSlice<f64>,
2501 r: &CudaSlice<f64>,
2502 z: &mut CudaSlice<f64>,
2503 n: usize,
2504 ) -> Result<(), ArrowSchurGpuFailure> {
2505 let kernel = module
2506 .load_function("arrow_pcg_jacobi_mul")
2507 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2508 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2509 let mut builder = stream.launch_builder(&kernel);
2510 builder.arg(inv_diag).arg(r).arg(z).arg(&n_i32);
2511 unsafe { builder.launch(pcg_launch_config(n)?) }
2514 .map(drop)
2515 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2516 }
2517
2518 fn launch_update_p(
2519 stream: &Arc<CudaStream>,
2520 module: &Arc<CudaModule>,
2521 z: &CudaSlice<f64>,
2522 beta: f64,
2523 p: &mut CudaSlice<f64>,
2524 n: usize,
2525 ) -> Result<(), ArrowSchurGpuFailure> {
2526 let kernel = module
2527 .load_function("arrow_pcg_update_p")
2528 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2529 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
2530 let mut builder = stream.launch_builder(&kernel);
2531 builder.arg(z).arg(&beta).arg(p).arg(&n_i32);
2532 unsafe { builder.launch(pcg_launch_config(n)?) }
2535 .map(drop)
2536 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2537 }
2538
    /// Device-resident, flattened copy of the SAE PCG operator data plus per-row
    /// scratch vectors, as packed by `flatten_device_sae_data` for the `arrow_sae_*`
    /// kernels in `PCG_VECTOR_KERNEL_SOURCE`.
    struct DeviceSaePcgBuffers {
        /// CSR row pointer (`n_rows + 1` entries) into `beta_base` / `phi`.
        row_ptr: CudaSlice<i32>,
        /// Per-entry β element offset (the `base` of each `(base, phi)` pair in `a_phi`).
        beta_base: CudaSlice<i32>,
        /// Per-entry weight paired with `beta_base`.
        phi: CudaSlice<f64>,
        /// Prefix offsets (`n_rows + 1` entries) into `jac`; row `i` spans `q_i · p` values.
        jac_ptr: CudaSlice<i32>,
        /// Concatenated per-row local Jacobians, read by the kernels as `jac[start + c*p + oc]`.
        jac: CudaSlice<f64>,
        /// β element offset of each smooth block (`global_offset`).
        smooth_offsets: CudaSlice<i32>,
        /// Side length `m` of each (square) smooth factor.
        smooth_m: CudaSlice<i32>,
        /// Prefix offsets into `smooth_data` (one entry per block plus a leading 0).
        smooth_ptr: CudaSlice<i32>,
        /// Row-major concatenation of every smooth `factor_a`.
        smooth_data: CudaSlice<f64>,
        /// Presumably the row-atom offset of each sparse `G` block (scaled by `p` in
        /// `arrow_sae_sparse_g_matvec`) — TODO confirm against the packer.
        g_row_off: CudaSlice<i32>,
        /// Presumably the column-atom offset of each sparse `G` block — TODO confirm.
        g_col_off: CudaSlice<i32>,
        /// Presumably the row count of each sparse `G` block — TODO confirm.
        g_rows: CudaSlice<i32>,
        /// Presumably the column count of each sparse `G` block — TODO confirm.
        g_cols: CudaSlice<i32>,
        /// Presumably the offset of each `G` block's row-major data in `g_data`.
        g_ptr: CudaSlice<i32>,
        /// Concatenated sparse `G` block data.
        g_data: CudaSlice<f64>,
        /// Per-row `max_q × max_q` blocks (row stride `max_q²` in the kernels);
        /// presumably the cached local inverses — TODO confirm against the packer.
        ainv: CudaSlice<f64>,
        /// Scratch: per-row length-`p` buffer (`u[row * p + oc]`).
        u: CudaSlice<f64>,
        /// Scratch: per-row length-`max_q` buffer (`w[row * max_q + c]`).
        w: CudaSlice<f64>,
        /// Scratch: per-row length-`max_q` buffer (`v[row * max_q + c]`).
        v: CudaSlice<f64>,
        /// Number of per-row blocks (`sys.rows.len()`).
        n_rows: usize,
        /// Channel count (`data.p`).
        p: usize,
        /// β-border dimension (`data.beta_dim`).
        k: usize,
        /// Largest per-row `q_i` (rejected by the packer when 0).
        max_q: usize,
        /// Number of smooth blocks.
        smooth_blocks: usize,
        /// Number of sparse `G` blocks.
        g_blocks: usize,
    }
2566
2567 fn checked_i32(value: usize) -> Result<i32, ArrowSchurGpuFailure> {
2568 to_i32(value).ok_or(ArrowSchurGpuFailure::Unavailable)
2569 }
2570
2571 fn sae_penalty_diag_host(
2572 data: &DeviceSaePcgData,
2573 ridge_beta: f64,
2574 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
2575 let mut diag = vec![ridge_beta; data.beta_dim];
2576 for block in &data.smooth_blocks {
2577 let (rows, cols) = block.factor_a.dim();
2578 if rows != cols {
2579 return Err(ArrowSchurGpuFailure::Unavailable);
2580 }
2581 for row in 0..rows {
2582 let coeff = block.factor_a[[row, row]];
2583 let base = block
2584 .global_offset
2585 .checked_add(
2586 row.checked_mul(data.p)
2587 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
2588 )
2589 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2590 let end = base
2591 .checked_add(data.p)
2592 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2593 if end > diag.len() {
2594 return Err(ArrowSchurGpuFailure::Unavailable);
2595 }
2596 for channel in 0..data.p {
2597 diag[base + channel] += coeff;
2598 }
2599 }
2600 }
2601 for block in &data.sparse_g_blocks {
2602 if block.row_off != block.col_off {
2603 continue;
2604 }
2605 let (rows, cols) = block.data.dim();
2606 for row in 0..rows.min(cols) {
2607 let coeff = block.data[[row, row]];
2608 let beta_row = block
2609 .row_off
2610 .checked_add(row)
2611 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2612 let base = beta_row
2613 .checked_mul(data.p)
2614 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2615 let end = base
2616 .checked_add(data.p)
2617 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
2618 if end > diag.len() {
2619 return Err(ArrowSchurGpuFailure::Unavailable);
2620 }
2621 for channel in 0..data.p {
2622 diag[base + channel] += coeff;
2623 }
2624 }
2625 }
2626 Ok(diag)
2627 }
2628
2629 fn flatten_device_sae_data(
2630 sys: &ArrowSchurSystem,
2631 data: &DeviceSaePcgData,
2632 ridge_t: f64,
2633 stream: &Arc<CudaStream>,
2634 ) -> Result<DeviceSaePcgBuffers, ArrowSchurGpuFailure> {
2635 let n_rows = sys.rows.len();
2636 let p = data.p;
2637 let k = data.beta_dim;
2638 if data.a_phi.len() != n_rows || data.local_jac.len() != n_rows {
2639 return Err(ArrowSchurGpuFailure::Unavailable);
2640 }
2641
2642 let mut row_ptr_host = Vec::with_capacity(n_rows + 1);
2643 let mut beta_base_host = Vec::<i32>::new();
2644 let mut phi_host = Vec::<f64>::new();
2645 row_ptr_host.push(0_i32);
2646 for row in data.a_phi.iter() {
2647 for &(base, phi) in row {
2648 beta_base_host.push(checked_i32(base)?);
2649 phi_host.push(phi);
2650 }
2651 row_ptr_host.push(checked_i32(beta_base_host.len())?);
2652 }
2653
2654 let mut jac_ptr_host = Vec::with_capacity(n_rows + 1);
2655 let mut jac_host = Vec::<f64>::new();
2656 let mut max_q = 0usize;
2657 jac_ptr_host.push(0_i32);
2658 for row_jac in data.local_jac.iter() {
2659 if row_jac.len() % p != 0 {
2660 return Err(ArrowSchurGpuFailure::Unavailable);
2661 }
2662 max_q = max_q.max(row_jac.len() / p);
2663 jac_host.extend_from_slice(row_jac);
2664 jac_ptr_host.push(checked_i32(jac_host.len())?);
2665 }
2666 if max_q == 0 {
2667 return Err(ArrowSchurGpuFailure::Unavailable);
2668 }
2669
2670 let mut smooth_offsets_host = Vec::with_capacity(data.smooth_blocks.len());
2671 let mut smooth_m_host = Vec::with_capacity(data.smooth_blocks.len());
2672 let mut smooth_ptr_host = Vec::with_capacity(data.smooth_blocks.len() + 1);
2673 let mut smooth_data_host = Vec::<f64>::new();
2674 smooth_ptr_host.push(0_i32);
2675 for block in &data.smooth_blocks {
2676 let (rows, cols) = block.factor_a.dim();
2677 if rows != cols {
2678 return Err(ArrowSchurGpuFailure::Unavailable);
2679 }
2680 smooth_offsets_host.push(checked_i32(block.global_offset)?);
2681 smooth_m_host.push(checked_i32(rows)?);
2682 for r in 0..rows {
2683 for c in 0..cols {
2684 smooth_data_host.push(block.factor_a[[r, c]]);
2685 }
2686 }
2687 smooth_ptr_host.push(checked_i32(smooth_data_host.len())?);
2688 }
2689
2690 let mut g_row_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2691 let mut g_col_off_host = Vec::with_capacity(data.sparse_g_blocks.len());
2692 let mut g_rows_host = Vec::with_capacity(data.sparse_g_blocks.len());
2693 let mut g_cols_host = Vec::with_capacity(data.sparse_g_blocks.len());
2694 let mut g_ptr_host = Vec::with_capacity(data.sparse_g_blocks.len() + 1);
2695 let mut g_data_host = Vec::<f64>::new();
2696 g_ptr_host.push(0_i32);
2697 for block in &data.sparse_g_blocks {
2698 let (rows, cols) = block.data.dim();
2699 g_row_off_host.push(checked_i32(block.row_off)?);
2700 g_col_off_host.push(checked_i32(block.col_off)?);
2701 g_rows_host.push(checked_i32(rows)?);
2702 g_cols_host.push(checked_i32(cols)?);
2703 for r in 0..rows {
2704 for c in 0..cols {
2705 g_data_host.push(block.data[[r, c]]);
2706 }
2707 }
2708 g_ptr_host.push(checked_i32(g_data_host.len())?);
2709 }
2710
2711 let mut ainv_host = vec![0.0_f64; n_rows * max_q * max_q];
2712 for (row_idx, row) in sys.rows.iter().enumerate() {
2713 let q = data.local_jac[row_idx].len() / p;
2714 if row.htt.dim() != (q, q) {
2715 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
2716 reason: format!(
2717 "SAE device PCG row {row_idx}: H_tt shape {:?} != ({q}, {q})",
2718 row.htt.dim()
2719 ),
2720 });
2721 }
2722 let mut block = row.htt.clone();
2723 for d in 0..q {
2724 block[[d, d]] += ridge_t;
2725 }
2726 let factor = gam_linalg::triangular::cholesky_factor_in_place(
2727 block.view(),
2728 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
2729 )
2730 .ok_or_else(|| {
2731 ArrowSchurGpuFailure::RidgeBumpRequired {
2734 row: row_idx,
2735 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
2736 }
2737 })?;
2738 for col in 0..q {
2739 let mut e = Array1::<f64>::zeros(q);
2740 e[col] = 1.0;
2741 let solved =
2742 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
2743 for r in 0..q {
2744 ainv_host[row_idx * max_q * max_q + r * max_q + col] = solved[r];
2745 }
2746 }
2747 }
2748
2749 Ok(DeviceSaePcgBuffers {
2750 row_ptr: stream
2751 .clone_htod(&row_ptr_host)
2752 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2753 beta_base: stream
2754 .clone_htod(&beta_base_host)
2755 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2756 phi: stream
2757 .clone_htod(&phi_host)
2758 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2759 jac_ptr: stream
2760 .clone_htod(&jac_ptr_host)
2761 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2762 jac: stream
2763 .clone_htod(&jac_host)
2764 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2765 smooth_offsets: stream
2766 .clone_htod(&smooth_offsets_host)
2767 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2768 smooth_m: stream
2769 .clone_htod(&smooth_m_host)
2770 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2771 smooth_ptr: stream
2772 .clone_htod(&smooth_ptr_host)
2773 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2774 smooth_data: stream
2775 .clone_htod(&smooth_data_host)
2776 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2777 g_row_off: stream
2778 .clone_htod(&g_row_off_host)
2779 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2780 g_col_off: stream
2781 .clone_htod(&g_col_off_host)
2782 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2783 g_rows: stream
2784 .clone_htod(&g_rows_host)
2785 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2786 g_cols: stream
2787 .clone_htod(&g_cols_host)
2788 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2789 g_ptr: stream
2790 .clone_htod(&g_ptr_host)
2791 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2792 g_data: stream
2793 .clone_htod(&g_data_host)
2794 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2795 ainv: stream
2796 .clone_htod(&ainv_host)
2797 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2798 u: stream
2799 .alloc_zeros::<f64>(n_rows * p)
2800 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2801 w: stream
2802 .alloc_zeros::<f64>(n_rows * max_q)
2803 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2804 v: stream
2805 .alloc_zeros::<f64>(n_rows * max_q)
2806 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
2807 n_rows,
2808 p,
2809 k,
2810 max_q,
2811 smooth_blocks: data.smooth_blocks.len(),
2812 g_blocks: data.sparse_g_blocks.len(),
2813 })
2814 }
2815
2816 fn launch_sae_init(
2817 stream: &Arc<CudaStream>,
2818 module: &Arc<CudaModule>,
2819 out: &mut CudaSlice<f64>,
2820 x: &CudaSlice<f64>,
2821 ridge: f64,
2822 n: usize,
2823 ) -> Result<(), ArrowSchurGpuFailure> {
2824 let kernel = module
2825 .load_function("arrow_sae_init")
2826 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2827 let n_i32 = checked_i32(n)?;
2828 let mut builder = stream.launch_builder(&kernel);
2829 builder.arg(out).arg(x).arg(&ridge).arg(&n_i32);
2830 unsafe { builder.launch(pcg_launch_config(n)?) }
2834 .map(drop)
2835 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
2836 }
2837
2838 fn launch_sae_penalty_matvec(
2839 stream: &Arc<CudaStream>,
2840 module: &Arc<CudaModule>,
2841 buffers: &mut DeviceSaePcgBuffers,
2842 x: &CudaSlice<f64>,
2843 out: &mut CudaSlice<f64>,
2844 ridge_beta: f64,
2845 ) -> Result<(), ArrowSchurGpuFailure> {
2846 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
2847 if buffers.smooth_blocks > 0 {
2848 let kernel = module
2849 .load_function("arrow_sae_smooth_matvec")
2850 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2851 let max_m = buffers.k;
2852 let p_i32 = checked_i32(buffers.p)?;
2853 let blocks_i32 = checked_i32(buffers.smooth_blocks)?;
2854 let cfg = LaunchConfig {
2855 grid_dim: (
2856 ((max_m as u32).saturating_add(255) / 256).max(1),
2857 checked_i32(buffers.smooth_blocks)? as u32,
2858 1,
2859 ),
2860 block_dim: (256, 1, 1),
2861 shared_mem_bytes: 0,
2862 };
2863 let mut builder = stream.launch_builder(&kernel);
2864 builder
2865 .arg(x)
2866 .arg(&mut *out)
2867 .arg(&buffers.smooth_offsets)
2868 .arg(&buffers.smooth_m)
2869 .arg(&buffers.smooth_ptr)
2870 .arg(&buffers.smooth_data)
2871 .arg(&p_i32)
2872 .arg(&blocks_i32);
2873 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2878 }
2879 if buffers.g_blocks > 0 {
2880 let kernel = module
2881 .load_function("arrow_sae_sparse_g_matvec")
2882 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2883 let max_work = buffers
2884 .k
2885 .checked_div(buffers.p)
2886 .unwrap_or(0)
2887 .saturating_mul(buffers.p);
2888 let p_i32 = checked_i32(buffers.p)?;
2889 let blocks_i32 = checked_i32(buffers.g_blocks)?;
2890 let cfg = LaunchConfig {
2891 grid_dim: (
2892 ((max_work as u32).saturating_add(255) / 256).max(1),
2893 checked_i32(buffers.g_blocks)? as u32,
2894 1,
2895 ),
2896 block_dim: (256, 1, 1),
2897 shared_mem_bytes: 0,
2898 };
2899 let mut builder = stream.launch_builder(&kernel);
2900 builder
2901 .arg(x)
2902 .arg(&mut *out)
2903 .arg(&buffers.g_row_off)
2904 .arg(&buffers.g_col_off)
2905 .arg(&buffers.g_rows)
2906 .arg(&buffers.g_cols)
2907 .arg(&buffers.g_ptr)
2908 .arg(&buffers.g_data)
2909 .arg(&p_i32)
2910 .arg(&blocks_i32);
2911 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2916 }
2917 Ok(())
2918 }
2919
2920 fn launch_sae_row_schur_sub(
2921 stream: &Arc<CudaStream>,
2922 module: &Arc<CudaModule>,
2923 buffers: &mut DeviceSaePcgBuffers,
2924 x: &CudaSlice<f64>,
2925 out: &mut CudaSlice<f64>,
2926 ) -> Result<(), ArrowSchurGpuFailure> {
2927 let p_i32 = checked_i32(buffers.p)?;
2928 let max_q_i32 = checked_i32(buffers.max_q)?;
2929 let n_rows_i32 = checked_i32(buffers.n_rows)?;
2930 let cfg_p_rows = LaunchConfig {
2931 grid_dim: (
2932 ((buffers.p as u32).saturating_add(255) / 256).max(1),
2933 checked_i32(buffers.n_rows)? as u32,
2934 1,
2935 ),
2936 block_dim: (256, 1, 1),
2937 shared_mem_bytes: 0,
2938 };
2939 let gather = module
2940 .load_function("arrow_sae_gather_u")
2941 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2942 {
2943 let mut builder = stream.launch_builder(&gather);
2944 builder
2945 .arg(x)
2946 .arg(&buffers.row_ptr)
2947 .arg(&buffers.beta_base)
2948 .arg(&buffers.phi)
2949 .arg(&mut buffers.u)
2950 .arg(&p_i32)
2951 .arg(&n_rows_i32);
2952 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2956 }
2957
2958 let cfg_q_rows = LaunchConfig {
2959 grid_dim: (
2960 ((buffers.max_q as u32).saturating_add(255) / 256).max(1),
2961 checked_i32(buffers.n_rows)? as u32,
2962 1,
2963 ),
2964 block_dim: (256, 1, 1),
2965 shared_mem_bytes: 0,
2966 };
2967 let apply_l = module
2968 .load_function("arrow_sae_apply_l")
2969 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2970 {
2971 let mut builder = stream.launch_builder(&apply_l);
2972 builder
2973 .arg(&buffers.u)
2974 .arg(&buffers.jac_ptr)
2975 .arg(&buffers.jac)
2976 .arg(&mut buffers.w)
2977 .arg(&p_i32)
2978 .arg(&max_q_i32)
2979 .arg(&n_rows_i32);
2980 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2984 }
2985
2986 let apply_ainv = module
2987 .load_function("arrow_sae_apply_ainv")
2988 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
2989 {
2990 let mut builder = stream.launch_builder(&apply_ainv);
2991 builder
2992 .arg(&buffers.ainv)
2993 .arg(&buffers.w)
2994 .arg(&mut buffers.v)
2995 .arg(&max_q_i32)
2996 .arg(&n_rows_i32);
2997 unsafe { builder.launch(cfg_q_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3001 }
3002
3003 let scatter = module
3004 .load_function("arrow_sae_scatter_sub")
3005 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3006 {
3007 let mut builder = stream.launch_builder(&scatter);
3008 builder
3009 .arg(&buffers.v)
3010 .arg(&buffers.jac_ptr)
3011 .arg(&buffers.jac)
3012 .arg(&buffers.row_ptr)
3013 .arg(&buffers.beta_base)
3014 .arg(&buffers.phi)
3015 .arg(out)
3016 .arg(&p_i32)
3017 .arg(&max_q_i32)
3018 .arg(&n_rows_i32);
3019 unsafe { builder.launch(cfg_p_rows) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3023 }
3024 Ok(())
3025 }
3026
3027 fn launch_sae_diag_sub(
3028 stream: &Arc<CudaStream>,
3029 module: &Arc<CudaModule>,
3030 buffers: &DeviceSaePcgBuffers,
3031 diag: &mut CudaSlice<f64>,
3032 ) -> Result<(), ArrowSchurGpuFailure> {
3033 let kernel = module
3034 .load_function("arrow_sae_diag_sub")
3035 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3036 let p_i32 = checked_i32(buffers.p)?;
3037 let max_q_i32 = checked_i32(buffers.max_q)?;
3038 let n_rows_i32 = checked_i32(buffers.n_rows)?;
3039 let cfg = LaunchConfig {
3040 grid_dim: (
3041 ((buffers.p as u32).saturating_add(255) / 256).max(1),
3042 checked_i32(buffers.n_rows)? as u32,
3043 1,
3044 ),
3045 block_dim: (256, 1, 1),
3046 shared_mem_bytes: 0,
3047 };
3048 let mut builder = stream.launch_builder(&kernel);
3049 builder
3050 .arg(diag)
3051 .arg(&buffers.ainv)
3052 .arg(&buffers.jac_ptr)
3053 .arg(&buffers.jac)
3054 .arg(&buffers.row_ptr)
3055 .arg(&buffers.beta_base)
3056 .arg(&buffers.phi)
3057 .arg(&p_i32)
3058 .arg(&max_q_i32)
3059 .arg(&n_rows_i32);
3060 unsafe { builder.launch(cfg) }
3064 .map(drop)
3065 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
3066 }
3067
3068 fn launch_sae_matvec(
3069 stream: &Arc<CudaStream>,
3070 module: &Arc<CudaModule>,
3071 buffers: &mut DeviceSaePcgBuffers,
3072 x: &CudaSlice<f64>,
3073 out: &mut CudaSlice<f64>,
3074 ridge_beta: f64,
3075 ) -> Result<(), ArrowSchurGpuFailure> {
3076 launch_sae_penalty_matvec(stream, module, buffers, x, out, ridge_beta)?;
3077 launch_sae_row_schur_sub(stream, module, buffers, x, out)
3078 }
3079
3080 fn pack_fused_host(
3085 sys: &ArrowSchurSystem,
3086 ridge_t: f64,
3087 p_max: usize,
3088 r_template: usize,
3089 ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
3090 let n = sys.rows.len();
3091 let d = sys.d;
3092 let k = sys.k;
3093 let mut d_buf = vec![0.0_f64; n * p_max * p_max];
3094 let mut b_buf = vec![0.0_f64; n * p_max * r_template];
3095 let mut g_buf = vec![0.0_f64; n * p_max];
3096 for (i, row) in sys.rows.iter().enumerate() {
3097 for col in 0..d {
3099 let base = (i * p_max + col) * p_max;
3100 for r in 0..d {
3101 let mut value = row.htt[[r, col]];
3102 if r == col {
3103 value += ridge_t;
3104 }
3105 d_buf[base + r] = value;
3106 }
3107 }
3108 for col in 0..k {
3116 let base = (i * r_template + col) * p_max;
3117 for r in 0..d {
3118 b_buf[base + r] = row.htbeta[[r, col]];
3119 }
3120 }
3121 let g_base = i * p_max;
3123 for r in 0..d {
3124 g_buf[g_base + r] = row.gt[r];
3125 }
3126 }
3127 (d_buf, b_buf, g_buf)
3128 }
3129
/// GPU-resident factorisation of an arrow system. `ResidentArrowFrame::new` builds it once,
/// and `solve_gradient` reuses it for repeated right-hand sides.
pub(super) struct ResidentArrowFrame {
    /// Number of arrow rows (`sys.rows.len()`).
    n: usize,
    /// Per-row local block dimension (`sys.d`).
    d: usize,
    /// Shared β dimension (`sys.k`).
    k: usize,
    /// Stream on which all of this frame's device work is issued.
    stream: Arc<CudaStream>,
    /// cuBLAS handle created on `stream`.
    blas: CudaBlas,
    /// Batched per-row Cholesky factors of `H_tt + ridge_t·I`, overwritten in place by the
    /// batched `potrf` in `new`.
    l_dev: CudaSlice<f64>,
    /// Per-row `H_tβ` blocks after the batched lower-triangular solve against `l_dev`
    /// (presumably `L_i⁻¹ H_tβ,i`).
    y_dev: CudaSlice<f64>,
    /// Cholesky factor of the `k × k` Schur complement (seeded column-major in `new`).
    schur_dev: CudaSlice<f64>,
    /// Log-determinant of the ridged Hessian, computed once in `new`.
    log_det_hessian: f64,
}
3175
3176 impl ResidentArrowFrame {
3177 pub(super) fn new(
3181 sys: &ArrowSchurSystem,
3182 ridge_t: f64,
3183 ridge_beta: f64,
3184 ) -> Result<Self, ArrowSchurGpuFailure> {
3185 if ridge_t.is_nan() || ridge_beta.is_nan() {
3186 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3187 reason: "ridge is NaN".to_string(),
3188 });
3189 }
3190 let n = sys.rows.len();
3191 let d = sys.d;
3192 let k = sys.k;
3193 let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n })
3194 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3195 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3196 .and_then(|ctx| ctx.new_stream().ok())
3197 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3198 let solver =
3199 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3200 let blas =
3201 CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3202
3203 let (d_host, b_host, _g_host) = pack_host(sys, ridge_t);
3205 let mut l_dev = stream
3206 .clone_htod(&d_host)
3207 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3208 let mut y_dev = stream
3209 .clone_htod(&b_host)
3210 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3211
3212 let info_host = potrf_batched(&solver, &stream, d, n, &mut l_dev)?;
3214 if let Some(idx) = info_host.iter().position(|info| *info != 0) {
3215 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3219 row: idx,
3220 bump: super::ridge_bump_to_make_pd(sys.rows[idx].htt.view(), ridge_t),
3221 });
3222 }
3223
3224 trsm_batched_lower_inplace(&blas, &stream, d, n, k, &l_dev, &mut y_dev)?;
3226
3227 let schur_init: Vec<f64> = {
3232 let mut tmp = Vec::with_capacity(k * k);
3233 for col in 0..k {
3234 for row in 0..k {
3235 let mut v = sys.hbb[[row, col]];
3236 if row == col {
3237 v += ridge_beta;
3238 }
3239 tmp.push(v);
3240 }
3241 }
3242 tmp
3243 };
3244 let mut schur_dev = stream
3245 .clone_htod(&schur_init)
3246 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3247 let zero_u = stream
3250 .clone_htod(&vec![0.0_f64; n * d])
3251 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3252 let mut throwaway_rhs = stream
3253 .clone_htod(&vec![0.0_f64; k])
3254 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3255 accumulate_schur(
3256 &blas,
3257 d,
3258 k,
3259 n,
3260 &y_dev,
3261 &zero_u,
3262 &mut schur_dev,
3263 &mut throwaway_rhs,
3264 )?;
3265 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3266 if info != 0 {
3267 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3268 reason: format!("Schur Cholesky failed at pivot {info}"),
3269 });
3270 }
3271
3272 let l_local_host = stream
3274 .clone_dtoh(&l_dev)
3275 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3276 let l_schur_host = stream
3277 .clone_dtoh(&schur_dev)
3278 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3279 let mut log_det = 0.0_f64;
3280 for i in 0..n {
3281 let base = i * d * d;
3282 for j in 0..d {
3283 log_det += l_local_host[base + j * d + j].ln();
3284 }
3285 }
3286 for j in 0..k {
3287 log_det += l_schur_host[j * k + j].ln();
3288 }
3289 log_det *= 2.0;
3290
3291 Ok(Self {
3292 n,
3293 d,
3294 k,
3295 stream,
3296 blas,
3297 l_dev,
3298 y_dev,
3299 schur_dev,
3300 log_det_hessian: log_det,
3301 })
3302 }
3303
3304 #[inline]
3305 pub(super) fn log_det_hessian(&self) -> f64 {
3306 self.log_det_hessian
3307 }
3308
3309 pub(super) fn solve_gradient(
3313 &self,
3314 g_t: &[f64],
3315 g_beta: &[f64],
3316 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3317 let n = self.n;
3318 let d = self.d;
3319 let k = self.k;
3320 if g_t.len() != n * d || g_beta.len() != k {
3321 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3322 reason: format!(
3323 "resident gradient shape mismatch: g_t={} (want {}), g_beta={} (want {})",
3324 g_t.len(),
3325 n * d,
3326 g_beta.len(),
3327 k
3328 ),
3329 });
3330 }
3331 let mut u_dev = self
3333 .stream
3334 .clone_htod(&g_t.to_vec())
3335 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3336 trsm_batched_lower_inplace(&self.blas, &self.stream, d, n, 1, &self.l_dev, &mut u_dev)?;
3337
3338 let rhs_init: Vec<f64> = g_beta.iter().map(|v| -v).collect();
3341 let mut rhs_dev = self
3342 .stream
3343 .clone_htod(&rhs_init)
3344 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3345 accumulate_schur_rhs_only(&self.blas, d, k, n, &self.y_dev, &u_dev, &mut rhs_dev)?;
3346
3347 trsm_single(
3349 &self.blas,
3350 &self.stream,
3351 k,
3352 &self.schur_dev,
3353 &mut rhs_dev,
3354 false,
3355 false,
3356 )?;
3357 trsm_single(
3358 &self.blas,
3359 &self.stream,
3360 k,
3361 &self.schur_dev,
3362 &mut rhs_dev,
3363 false,
3364 true,
3365 )?;
3366 let delta_beta_host = self
3367 .stream
3368 .clone_dtoh(&rhs_dev)
3369 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3370 let delta_beta = Array1::from_vec(delta_beta_host);
3371
3372 accumulate_back_sub_rhs(&self.blas, d, k, n, &self.y_dev, &rhs_dev, &mut u_dev)?;
3374 trsm_batched_lower_inplace_transposed(
3375 &self.blas,
3376 &self.stream,
3377 d,
3378 n,
3379 1,
3380 &self.l_dev,
3381 &mut u_dev,
3382 )?;
3383 let x_host = self
3384 .stream
3385 .clone_dtoh(&u_dev)
3386 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3387 let mut delta_t = Array1::<f64>::zeros(n * d);
3388 for (i, v) in x_host.iter().enumerate() {
3389 delta_t[i] = -*v;
3390 }
3391
3392 Ok(ArrowSchurGpuSolution {
3393 delta_t,
3394 delta_beta,
3395 log_det_hessian: self.log_det_hessian,
3396 })
3397 }
3398 }
3399
3400 pub(super) fn solve_fused(
3401 sys: &ArrowSchurSystem,
3402 ridge_t: f64,
3403 ridge_beta: f64,
3404 ) -> Result<ArrowSchurGpuSolution, ArrowSchurGpuFailure> {
3405 let n = sys.rows.len();
3406 let d = sys.d;
3407 let k = sys.k;
3408 let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
3409 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3410 let p_max = plan.p_max;
3411 let r_template = plan.r_template;
3412
3413 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
3414 gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
3415 )
3416 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3417 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
3418 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
3419 let stream = ctx
3420 .new_stream()
3421 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3422 let cap = &runtime.device.capability;
3423 let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
3424 cc_major: cap.compute_major,
3425 cc_minor: cap.compute_minor,
3426 p_max: p_max as u32,
3427 r_template: r_template as u32,
3428 };
3429 let module = fused_module_for(&ctx, key)?;
3430 let forward = module
3431 .load_function("arrow_schur_forward_pgroup")
3432 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3433 let back_sub = module
3434 .load_function("arrow_schur_back_sub_pgroup")
3435 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3436
3437 let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
3439 let d_dev = stream
3440 .clone_htod(&d_host)
3441 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3442 let b_dev = stream
3443 .clone_htod(&b_host)
3444 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3445 let g_dev = stream
3446 .clone_htod(&g_host)
3447 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3448 let mut l_out = stream
3449 .alloc_zeros::<f64>(n * p_max * p_max)
3450 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3451 let mut u_out = stream
3452 .alloc_zeros::<f64>(n * p_max)
3453 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3454 let mut y_out = stream
3455 .alloc_zeros::<f64>(n * p_max * r_template)
3456 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3457 let mut partial_s = stream
3458 .alloc_zeros::<f64>(plan.partial_s_doubles)
3459 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3460 let mut partial_r = stream
3461 .alloc_zeros::<f64>(plan.partial_r_doubles)
3462 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3463 let mut status_dev = stream
3464 .alloc_zeros::<i32>(n)
3465 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3466
3467 let cfg = LaunchConfig {
3469 grid_dim: (plan.blocks, 1, 1),
3470 block_dim: (plan.threads_per_block, 1, 1),
3471 shared_mem_bytes: 0,
3472 };
3473 let n_i32 = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3474 let p_i32 = to_i32(d).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3475 let r_i32 = to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?;
3476 let ridge_arg = ridge_t;
3477 {
3478 let mut builder = stream.launch_builder(&forward);
3479 builder
3480 .arg(&d_dev)
3481 .arg(&b_dev)
3482 .arg(&g_dev)
3483 .arg(&n_i32)
3484 .arg(&p_i32)
3485 .arg(&r_i32)
3486 .arg(&ridge_arg)
3487 .arg(&mut l_out)
3488 .arg(&mut u_out)
3489 .arg(&mut y_out)
3490 .arg(&mut partial_s)
3491 .arg(&mut partial_r)
3492 .arg(&mut status_dev);
3493 unsafe { builder.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3497 }
3498 stream
3499 .synchronize()
3500 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3501
3502 let status_host = stream
3504 .clone_dtoh(&status_dev)
3505 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3506 if let Some(row) = status_host.iter().position(|s| *s != 0) {
3507 return Err(ArrowSchurGpuFailure::RidgeBumpRequired {
3511 row,
3512 bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
3513 });
3514 }
3515
3516 let partial_s_host = stream
3518 .clone_dtoh(&partial_s)
3519 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3520 let partial_r_host = stream
3521 .clone_dtoh(&partial_r)
3522 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3523 let mut schur_host = vec![0.0_f64; k * k];
3524 for col in 0..k {
3525 for row in 0..k {
3526 let mut v = sys.hbb[[row, col]];
3527 if row == col {
3528 v += ridge_beta;
3529 }
3530 schur_host[col * k + row] = v;
3531 }
3532 }
3533 let mut rhs_host: Vec<f64> = sys.gb.iter().map(|v| -v).collect();
3534 for i in 0..n {
3535 let s_base = i * r_template * r_template;
3538 for col in 0..k {
3539 let col_base = s_base + col * r_template;
3540 let dst_col_base = col * k;
3541 for row in 0..k {
3542 schur_host[dst_col_base + row] -= partial_s_host[col_base + row];
3543 }
3544 }
3545 let r_base = i * r_template;
3546 for a in 0..k {
3547 rhs_host[a] += partial_r_host[r_base + a];
3548 }
3549 }
3550
3551 let mut schur_dev = stream
3553 .clone_htod(&schur_host)
3554 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3555 let mut rhs_dev = stream
3556 .clone_htod(&rhs_host)
3557 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3558 let solver =
3559 DnHandle::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3560 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3561 let info = potrf_single(&solver, &stream, k, &mut schur_dev)?;
3562 if info != 0 {
3563 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3564 reason: format!("fused Schur Cholesky failed at pivot {info}"),
3565 });
3566 }
3567 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, false)?;
3568 trsm_single(&blas, &stream, k, &schur_dev, &mut rhs_dev, false, true)?;
3569 let delta_beta_host = stream
3570 .clone_dtoh(&rhs_dev)
3571 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3572 let delta_beta = Array1::from_vec(delta_beta_host.clone());
3573
3574 let mut delta_t_dev = stream
3576 .alloc_zeros::<f64>(n * p_max)
3577 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3578 let back_cfg = LaunchConfig {
3579 grid_dim: (plan.blocks, 1, 1),
3580 block_dim: (plan.threads_per_block, 1, 1),
3581 shared_mem_bytes: 0,
3582 };
3583 {
3584 let mut builder = stream.launch_builder(&back_sub);
3585 builder
3586 .arg(&l_out)
3587 .arg(&u_out)
3588 .arg(&y_out)
3589 .arg(&rhs_dev)
3590 .arg(&n_i32)
3591 .arg(&p_i32)
3592 .arg(&r_i32)
3593 .arg(&mut delta_t_dev);
3594 unsafe { builder.launch(back_cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3598 }
3599 stream
3600 .synchronize()
3601 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3602
3603 let delta_t_host = stream
3604 .clone_dtoh(&delta_t_dev)
3605 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3606 let mut delta_t = Array1::<f64>::zeros(n * d);
3607 for i in 0..n {
3608 let src_base = i * p_max;
3609 let dst_base = i * d;
3610 for r in 0..d {
3611 delta_t[dst_base + r] = delta_t_host[src_base + r];
3612 }
3613 }
3614
3615 let l_local_host = stream
3617 .clone_dtoh(&l_out)
3618 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3619 let l_schur_host = stream
3620 .clone_dtoh(&schur_dev)
3621 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
3622 let mut log_det = 0.0_f64;
3623 for i in 0..n {
3624 let base = i * p_max * p_max;
3625 for j in 0..d {
3626 log_det += l_local_host[base + j * p_max + j].ln();
3627 }
3628 }
3629 for j in 0..k {
3630 log_det += l_schur_host[j * k + j].ln();
3631 }
3632 log_det *= 2.0;
3633
3634 Ok(ArrowSchurGpuSolution {
3635 delta_t,
3636 delta_beta,
3637 log_det_hessian: log_det,
3638 })
3639 }
3640
/// Builds a host-side operator `x ↦ S·x` for the β Schur complement, presumably for a
/// matrix-free iterative solve on the caller's side.
///
/// The fused forward kernel runs once to obtain each row's coupling block `Y_i` (read back
/// from `y_out`). The returned closure then evaluates, entirely on the host:
///
///     S·x = H_ββ·x + ridge_beta·x − Σ_i Y_iᵀ (Y_i x)
///
/// `H_ββ·x` comes from `sys.hbb_matvec` when present. Otherwise it comes from a dense copy
/// of `sys.hbb` when that is `k × k`.
///
/// NOTE(review): with no `hbb_matvec` and a `sys.hbb` that is not `k × k`, the closure
/// applies only the ridge term, so H_ββ is silently dropped. Confirm this is intended.
///
/// # Errors
/// * `Unavailable`: no launch plan, GPU route or module, or a device call failed.
/// * `RidgeBumpRequired`: the first row whose local block failed to factor in the forward
///   kernel, with the bump estimate from `ridge_bump_to_make_pd`.
///
/// # Panics
/// The returned closure asserts that `x` and `out` both have length `k`.
pub(super) fn build_schur_matvec_backend(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<crate::arrow_schur::GpuSchurMatvec, super::ArrowSchurGpuFailure> {
    let n = sys.rows.len();
    let d = sys.d;
    let k = sys.k;
    let plan = crate::gpu_kernels::arrow_schur_nvrtc::plan_fused_launch(n, d, k)
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let p_max = plan.p_max;
    let r_template = plan.r_template;

    let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
        gam_gpu::linalg_dispatch::DispatchOp::SmallDenseBatchedPotrf { p: d, batch: n },
    )
    .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
        .ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let cap = &runtime.device.capability;
    let key = crate::gpu_kernels::arrow_schur_nvrtc::FusedModuleCacheKey {
        cc_major: cap.compute_major,
        cc_minor: cap.compute_minor,
        p_max: p_max as u32,
        r_template: r_template as u32,
    };
    let module = fused_module_for(&ctx, key)?;
    let forward = module
        .load_function("arrow_schur_forward_pgroup")
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

    // NOTE(review): `pack_fused_host` already adds `ridge_t` to the packed diagonal, and
    // `ridge_t` is also passed to the kernel below as `ridge_arg`. Whether the kernel
    // applies it again is not visible here; confirm it is not double-counted.
    let (d_host, b_host, g_host) = pack_fused_host(sys, ridge_t, p_max, r_template);
    let d_dev = stream
        .clone_htod(&d_host)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let b_dev = stream
        .clone_htod(&b_host)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let g_dev = stream
        .clone_htod(&g_host)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    // All forward-kernel outputs must be allocated to satisfy its argument list. Only
    // `y_out` and `status_dev` are read back in this function.
    let mut l_out = stream
        .alloc_zeros::<f64>(n * p_max * p_max)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let mut u_out = stream
        .alloc_zeros::<f64>(n * p_max)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let mut y_out = stream
        .alloc_zeros::<f64>(n * p_max * r_template)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let mut partial_s = stream
        .alloc_zeros::<f64>(plan.partial_s_doubles)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let mut partial_r = stream
        .alloc_zeros::<f64>(plan.partial_r_doubles)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    let mut status_dev = stream
        .alloc_zeros::<i32>(n)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

    let cfg = LaunchConfig {
        grid_dim: (plan.blocks, 1, 1),
        block_dim: (plan.threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_i32 = to_i32(n).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let p_i32 = to_i32(d).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let r_i32 = to_i32(k).ok_or(super::ArrowSchurGpuFailure::Unavailable)?;
    let ridge_arg = ridge_t;
    {
        let mut builder = stream.launch_builder(&forward);
        builder
            .arg(&d_dev)
            .arg(&b_dev)
            .arg(&g_dev)
            .arg(&n_i32)
            .arg(&p_i32)
            .arg(&r_i32)
            .arg(&ridge_arg)
            .arg(&mut l_out)
            .arg(&mut u_out)
            .arg(&mut y_out)
            .arg(&mut partial_s)
            .arg(&mut partial_r)
            .arg(&mut status_dev);
        // SAFETY: argument order/types must match `arrow_schur_forward_pgroup`.
        unsafe { builder.launch(cfg) }.map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    }
    stream
        .synchronize()
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

    // A non-zero status marks a row whose local block failed to factor.
    let status_host = stream
        .clone_dtoh(&status_dev)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;
    if let Some(row) = status_host.iter().position(|s| *s != 0) {
        return Err(super::ArrowSchurGpuFailure::RidgeBumpRequired {
            row,
            bump: super::ridge_bump_to_make_pd(sys.rows[row].htt.view(), ridge_t),
        });
    }

    // `y_host` layout: row i's block starts at `i * p_max * r_template`. Column `c` occupies
    // `p_max` slots, and entry (r, c) sits at `c * p_max + r`.
    let y_host = stream
        .clone_dtoh(&y_out)
        .map_err(|_| super::ArrowSchurGpuFailure::Unavailable)?;

    // ndarray iterates in logical (row-major) order, so `hbb_host[a * k + b] == hbb[[a, b]]`
    // when `hbb` is k × k.
    let hbb_host: Vec<f64> = sys.hbb.iter().copied().collect();
    let hbb_is_kk = sys.hbb.dim() == (k, k);
    let hbb_matvec_opt = sys.hbb_matvec.clone();

    let closure: crate::arrow_schur::GpuSchurMatvec =
        Arc::new(move |x: &Array1<f64>, out: &mut Array1<f64>| {
            assert_eq!(x.len(), k, "gpu_schur_matvec: x.len() != k");
            assert_eq!(out.len(), k, "gpu_schur_matvec: out.len() != k");

            // out = (H_ββ + ridge_beta·I) x, using whichever H_ββ representation exists.
            if let Some(ref mv) = hbb_matvec_opt {
                // Assumes `mv` overwrites `out` (it is not zeroed first); TODO confirm.
                mv(x.view(), out);
                for a in 0..k {
                    out[a] += ridge_beta * x[a];
                }
            } else if hbb_is_kk {
                for a in 0..k {
                    let mut acc = ridge_beta * x[a];
                    for b in 0..k {
                        acc += hbb_host[a * k + b] * x[b];
                    }
                    out[a] = acc;
                }
            } else {
                for a in 0..k {
                    out[a] = ridge_beta * x[a];
                }
            }

            // out -= Σ_i Y_iᵀ (Y_i x), with `z = Y_i x` as per-row scratch.
            let mut z = vec![0.0_f64; d];
            for i in 0..n {
                let y_base = i * p_max * r_template;
                for r in 0..d {
                    let mut acc = 0.0;
                    for c in 0..k {
                        acc += y_host[y_base + c * p_max + r] * x[c];
                    }
                    z[r] = acc;
                }
                for c in 0..k {
                    let mut acc = 0.0;
                    for r in 0..d {
                        acc += y_host[y_base + c * p_max + r] * z[r];
                    }
                    out[c] -= acc;
                }
            }
        });

    Ok(closure)
}
3820
/// Device-resident, flattened form of the framed SAE operator built by
/// `flatten_device_sae_frame_data`, plus per-row scratch vectors that every
/// `launch_sae_frame_matvec` call reuses.
///
/// Every `*_ptr` array holds prefix offsets (leading 0, one entry more than
/// items) into the matching data array; all matrices are stored row-major.
struct DeviceSaeFrameBuffers {
    // Smooth blocks: one entry per element of `data.smooth_blocks`.
    /// `global_offset` of each smooth block within β.
    s_off: CudaSlice<i32>,
    /// Side length `m` of each (square) `factor_a`.
    s_m: CudaSlice<i32>,
    /// Rank `r` paired with each smooth block (from `frame.smooth_ranks`).
    s_r: CudaSlice<i32>,
    /// Prefix offsets into `s_data`.
    s_ptr: CudaSlice<i32>,
    /// Concatenated row-major `factor_a` entries.
    s_data: CudaSlice<f64>,
    /// Number of smooth blocks.
    s_blocks: usize,
    // Frame blocks: one entry per element of `frame.frame_blocks`.
    /// `border_offsets[atom_i]` of each block.
    g_off_i: CudaSlice<i32>,
    /// `border_offsets[atom_j]` of each block.
    g_off_j: CudaSlice<i32>,
    /// `ranks[atom_i]` of each block.
    g_ri: CudaSlice<i32>,
    /// `ranks[atom_j]` of each block.
    g_rj: CudaSlice<i32>,
    /// Row count of each block's `G`.
    g_mi: CudaSlice<i32>,
    /// Column count of each block's `G`.
    g_mj: CudaSlice<i32>,
    /// Prefix offsets into `g_data`.
    g_ptr: CudaSlice<i32>,
    /// Concatenated row-major `G` entries.
    g_data: CudaSlice<f64>,
    /// Prefix offsets into `w_data`.
    w_ptr: CudaSlice<i32>,
    /// Concatenated row-major `W` entries (each `ranks[atom_i] × ranks[atom_j]`).
    w_data: CudaSlice<f64>,
    /// Number of frame blocks.
    g_blocks: usize,
    /// Largest `g_mi * g_ri` over all frame blocks; sizes the G-kernel grid.
    g_max_work: usize,
    // Per-row coupling data.
    /// Prefix offsets into `htb` (leading 0 plus one entry per row).
    htb_ptr: CudaSlice<i32>,
    /// Concatenated `H_tβ` slabs of the rows whose `q_of` entry is non-zero.
    htb: CudaSlice<f64>,
    /// Effective row dimension per row; 0 marks a row whose slab was not used.
    q_of: CudaSlice<i32>,
    /// Per row, `(H_tt + ridge_t I)^{-1}` in a zero-padded row-major
    /// `max_q × max_q` tile (all zeros for rows with `q_of == 0`).
    ainv: CudaSlice<f64>,
    /// Scratch of length `n_rows * max_q`, written by `arrow_sae_frame_apply_h`.
    hvec: CudaSlice<f64>,
    /// Scratch of length `n_rows * max_q`, written by `arrow_sae_frame_apply_ainv`.
    svec: CudaSlice<f64>,
    /// Number of rows in the system (`sys.rows.len()`).
    n_rows: usize,
    /// Length of β (`data.beta_dim`).
    k: usize,
    /// Largest effective row dimension, clamped to at least 1.
    max_q: usize,
}
3855
3856 fn flatten_device_sae_frame_data(
3857 sys: &ArrowSchurSystem,
3858 data: &DeviceSaePcgData,
3859 frame: &DeviceSaeFrameData,
3860 ridge_t: f64,
3861 stream: &Arc<CudaStream>,
3862 ) -> Result<DeviceSaeFrameBuffers, ArrowSchurGpuFailure> {
3863 let n_rows = sys.rows.len();
3864 let k = data.beta_dim;
3865 if frame.row_htbeta.len() != n_rows
3866 || frame.ranks.len() != frame.basis_sizes.len()
3867 || frame.border_offsets.len() != frame.ranks.len()
3868 || data.smooth_blocks.len() != frame.smooth_ranks.len()
3869 {
3870 return Err(ArrowSchurGpuFailure::Unavailable);
3871 }
3872
3873 let mut s_off = Vec::new();
3875 let mut s_m = Vec::new();
3876 let mut s_r = Vec::new();
3877 let mut s_ptr = vec![0_i32];
3878 let mut s_data = Vec::<f64>::new();
3879 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
3880 let (m, mc) = blk.factor_a.dim();
3881 if m != mc {
3882 return Err(ArrowSchurGpuFailure::Unavailable);
3883 }
3884 s_off.push(checked_i32(blk.global_offset)?);
3885 s_m.push(checked_i32(m)?);
3886 s_r.push(checked_i32(r)?);
3887 for ri in 0..m {
3888 for ci in 0..m {
3889 s_data.push(blk.factor_a[[ri, ci]]);
3890 }
3891 }
3892 s_ptr.push(checked_i32(s_data.len())?);
3893 }
3894
3895 let mut g_off_i = Vec::new();
3897 let mut g_off_j = Vec::new();
3898 let mut g_ri = Vec::new();
3899 let mut g_rj = Vec::new();
3900 let mut g_mi = Vec::new();
3901 let mut g_mj = Vec::new();
3902 let mut g_ptr = vec![0_i32];
3903 let mut g_data = Vec::<f64>::new();
3904 let mut w_ptr = vec![0_i32];
3905 let mut w_data = Vec::<f64>::new();
3906 let mut g_max_work = 0usize;
3907 for blk in &frame.frame_blocks {
3908 let ri = frame.ranks[blk.atom_i];
3909 let rj = frame.ranks[blk.atom_j];
3910 let (mi, mj) = blk.g.dim();
3911 if blk.w.dim() != (ri, rj) {
3912 return Err(ArrowSchurGpuFailure::Unavailable);
3913 }
3914 g_off_i.push(checked_i32(frame.border_offsets[blk.atom_i])?);
3915 g_off_j.push(checked_i32(frame.border_offsets[blk.atom_j])?);
3916 g_ri.push(checked_i32(ri)?);
3917 g_rj.push(checked_i32(rj)?);
3918 g_mi.push(checked_i32(mi)?);
3919 g_mj.push(checked_i32(mj)?);
3920 for r in 0..mi {
3921 for c in 0..mj {
3922 g_data.push(blk.g[[r, c]]);
3923 }
3924 }
3925 g_ptr.push(checked_i32(g_data.len())?);
3926 for a in 0..ri {
3927 for b in 0..rj {
3928 w_data.push(blk.w[[a, b]]);
3929 }
3930 }
3931 w_ptr.push(checked_i32(w_data.len())?);
3932 g_max_work = g_max_work.max(mi * ri);
3933 }
3934
3935 let mut htb_ptr = vec![0_i32];
3937 let mut htb = Vec::<f64>::new();
3938 let mut q_of = Vec::<i32>::with_capacity(n_rows);
3939 let mut max_q = 0usize;
3940 for (i, slab) in frame.row_htbeta.iter().enumerate() {
3941 let qi = sys.row_dims[i];
3942 let q_eff = if !slab.is_empty() && slab.len() == qi * k {
3945 qi
3946 } else {
3947 0
3948 };
3949 q_of.push(checked_i32(q_eff)?);
3950 max_q = max_q.max(q_eff);
3951 if q_eff > 0 {
3952 htb.extend_from_slice(slab);
3953 }
3954 htb_ptr.push(checked_i32(htb.len())?);
3955 }
3956 if max_q == 0 {
3957 max_q = 1;
3960 }
3961
3962 let mut ainv = vec![0.0_f64; n_rows * max_q * max_q];
3963 for (i, row) in sys.rows.iter().enumerate() {
3964 let q = q_of[i] as usize;
3965 if q == 0 {
3966 continue;
3967 }
3968 if row.htt.dim() != (q, q) {
3969 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
3970 reason: format!(
3971 "framed SAE device PCG row {i}: H_tt shape {:?} != ({q}, {q})",
3972 row.htt.dim()
3973 ),
3974 });
3975 }
3976 let mut block = row.htt.clone();
3977 for d in 0..q {
3978 block[[d, d]] += ridge_t;
3979 }
3980 let factor = gam_linalg::triangular::cholesky_factor_in_place(
3981 block.view(),
3982 gam_linalg::triangular::CholeskyGuard::NonnegativePivot,
3983 )
3984 .ok_or_else(|| {
3985 ArrowSchurGpuFailure::RidgeBumpRequired {
3988 row: i,
3989 bump: super::ridge_bump_to_make_pd(row.htt.view(), ridge_t),
3990 }
3991 })?;
3992 for col in 0..q {
3993 let mut e = Array1::<f64>::zeros(q);
3994 e[col] = 1.0;
3995 let solved =
3996 gam_linalg::triangular::cholesky_solve_vector(factor.view(), e.view());
3997 for r in 0..q {
3998 ainv[i * max_q * max_q + r * max_q + col] = solved[r];
3999 }
4000 }
4001 }
4002
4003 let htod_i = |v: &[i32]| {
4004 stream
4005 .clone_htod(v)
4006 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4007 };
4008 let htod_f = |v: &[f64]| {
4009 stream
4010 .clone_htod(v)
4011 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4012 };
4013 Ok(DeviceSaeFrameBuffers {
4014 s_off: htod_i(&s_off)?,
4015 s_m: htod_i(&s_m)?,
4016 s_r: htod_i(&s_r)?,
4017 s_ptr: htod_i(&s_ptr)?,
4018 s_data: htod_f(&s_data)?,
4019 s_blocks: data.smooth_blocks.len(),
4020 g_off_i: htod_i(&g_off_i)?,
4021 g_off_j: htod_i(&g_off_j)?,
4022 g_ri: htod_i(&g_ri)?,
4023 g_rj: htod_i(&g_rj)?,
4024 g_mi: htod_i(&g_mi)?,
4025 g_mj: htod_i(&g_mj)?,
4026 g_ptr: htod_i(&g_ptr)?,
4027 g_data: htod_f(&g_data)?,
4028 w_ptr: htod_i(&w_ptr)?,
4029 w_data: htod_f(&w_data)?,
4030 g_blocks: frame.frame_blocks.len(),
4031 g_max_work,
4032 htb_ptr: htod_i(&htb_ptr)?,
4033 htb: htod_f(&htb)?,
4034 q_of: htod_i(&q_of)?,
4035 ainv: htod_f(&ainv)?,
4036 hvec: stream
4037 .alloc_zeros::<f64>(n_rows * max_q)
4038 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4039 svec: stream
4040 .alloc_zeros::<f64>(n_rows * max_q)
4041 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?,
4042 n_rows,
4043 k,
4044 max_q,
4045 })
4046 }
4047
4048 fn sae_frame_penalty_diag_host(
4049 data: &DeviceSaePcgData,
4050 frame: &DeviceSaeFrameData,
4051 ridge_beta: f64,
4052 ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
4053 let mut diag = vec![ridge_beta; data.beta_dim];
4054 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
4056 let m = blk.factor_a.nrows();
4057 for ia in 0..m {
4058 let coeff = blk.factor_a[[ia, ia]];
4059 let base = blk.global_offset + ia * r;
4060 for ib in 0..r {
4061 if base + ib >= diag.len() {
4062 return Err(ArrowSchurGpuFailure::Unavailable);
4063 }
4064 diag[base + ib] += coeff;
4065 }
4066 }
4067 }
4068 for blk in &frame.frame_blocks {
4070 if blk.atom_i != blk.atom_j {
4071 continue;
4072 }
4073 let r = frame.ranks[blk.atom_i];
4074 let off = frame.border_offsets[blk.atom_i];
4075 let (mi, mj) = blk.g.dim();
4076 for li in 0..mi.min(mj) {
4077 let gii = blk.g[[li, li]];
4078 let base = off + li * r;
4079 for a in 0..r {
4080 if base + a >= diag.len() {
4081 return Err(ArrowSchurGpuFailure::Unavailable);
4082 }
4083 diag[base + a] += gii * blk.w[[a, a]];
4084 }
4085 }
4086 }
4087 Ok(diag)
4088 }
4089
4090 fn frame_grid(work: usize, n_rows: usize) -> Result<LaunchConfig, ArrowSchurGpuFailure> {
4091 Ok(LaunchConfig {
4092 grid_dim: (
4093 ((work as u32).saturating_add(255) / 256).max(1),
4094 checked_i32(n_rows)? as u32,
4095 1,
4096 ),
4097 block_dim: (256, 1, 1),
4098 shared_mem_bytes: 0,
4099 })
4100 }
4101
4102 fn launch_sae_frame_matvec(
4103 stream: &Arc<CudaStream>,
4104 module: &Arc<CudaModule>,
4105 buffers: &mut DeviceSaeFrameBuffers,
4106 x: &CudaSlice<f64>,
4107 out: &mut CudaSlice<f64>,
4108 ridge_beta: f64,
4109 ) -> Result<(), ArrowSchurGpuFailure> {
4110 launch_sae_init(stream, module, out, x, ridge_beta, buffers.k)?;
4111 if buffers.s_blocks > 0 {
4113 let kernel = module
4114 .load_function("arrow_sae_frame_smooth_matvec")
4115 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4116 let blocks_i32 = checked_i32(buffers.s_blocks)?;
4117 let cfg = frame_grid(buffers.k, buffers.s_blocks)?;
4118 let mut b = stream.launch_builder(&kernel);
4119 b.arg(x)
4120 .arg(&mut *out)
4121 .arg(&buffers.s_off)
4122 .arg(&buffers.s_m)
4123 .arg(&buffers.s_r)
4124 .arg(&buffers.s_ptr)
4125 .arg(&buffers.s_data)
4126 .arg(&blocks_i32);
4127 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4130 }
4131 if buffers.g_blocks > 0 {
4133 let kernel = module
4134 .load_function("arrow_sae_frame_g_matvec")
4135 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4136 let blocks_i32 = checked_i32(buffers.g_blocks)?;
4137 let cfg = frame_grid(buffers.g_max_work.max(1), buffers.g_blocks)?;
4138 let mut b = stream.launch_builder(&kernel);
4139 b.arg(x)
4140 .arg(&mut *out)
4141 .arg(&buffers.g_off_i)
4142 .arg(&buffers.g_off_j)
4143 .arg(&buffers.g_ri)
4144 .arg(&buffers.g_rj)
4145 .arg(&buffers.g_mi)
4146 .arg(&buffers.g_mj)
4147 .arg(&buffers.g_ptr)
4148 .arg(&buffers.g_data)
4149 .arg(&buffers.w_ptr)
4150 .arg(&buffers.w_data)
4151 .arg(&blocks_i32);
4152 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4155 }
4156 let k_i32 = checked_i32(buffers.k)?;
4158 let max_q_i32 = checked_i32(buffers.max_q)?;
4159 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4160 {
4161 let kernel = module
4162 .load_function("arrow_sae_frame_apply_h")
4163 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4164 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4165 let mut b = stream.launch_builder(&kernel);
4166 b.arg(x)
4167 .arg(&buffers.htb_ptr)
4168 .arg(&buffers.htb)
4169 .arg(&buffers.q_of)
4170 .arg(&mut buffers.hvec)
4171 .arg(&k_i32)
4172 .arg(&max_q_i32)
4173 .arg(&n_rows_i32);
4174 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4177 }
4178 {
4179 let kernel = module
4180 .load_function("arrow_sae_frame_apply_ainv")
4181 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4182 let cfg = frame_grid(buffers.max_q, buffers.n_rows)?;
4183 let mut b = stream.launch_builder(&kernel);
4184 b.arg(&buffers.ainv)
4185 .arg(&buffers.hvec)
4186 .arg(&buffers.q_of)
4187 .arg(&mut buffers.svec)
4188 .arg(&max_q_i32)
4189 .arg(&n_rows_i32);
4190 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4193 }
4194 {
4195 let kernel = module
4196 .load_function("arrow_sae_frame_scatter_h")
4197 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4198 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4199 let mut b = stream.launch_builder(&kernel);
4200 b.arg(&buffers.svec)
4201 .arg(&buffers.htb_ptr)
4202 .arg(&buffers.htb)
4203 .arg(&buffers.q_of)
4204 .arg(out)
4205 .arg(&k_i32)
4206 .arg(&max_q_i32)
4207 .arg(&n_rows_i32);
4208 unsafe { b.launch(cfg) }.map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4211 }
4212 Ok(())
4213 }
4214
4215 fn launch_sae_frame_diag_sub(
4216 stream: &Arc<CudaStream>,
4217 module: &Arc<CudaModule>,
4218 buffers: &DeviceSaeFrameBuffers,
4219 diag: &mut CudaSlice<f64>,
4220 ) -> Result<(), ArrowSchurGpuFailure> {
4221 let kernel = module
4222 .load_function("arrow_sae_frame_diag_sub")
4223 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4224 let k_i32 = checked_i32(buffers.k)?;
4225 let max_q_i32 = checked_i32(buffers.max_q)?;
4226 let n_rows_i32 = checked_i32(buffers.n_rows)?;
4227 let cfg = frame_grid(buffers.k, buffers.n_rows)?;
4228 let mut b = stream.launch_builder(&kernel);
4229 b.arg(diag)
4230 .arg(&buffers.ainv)
4231 .arg(&buffers.htb_ptr)
4232 .arg(&buffers.htb)
4233 .arg(&buffers.q_of)
4234 .arg(&k_i32)
4235 .arg(&max_q_i32)
4236 .arg(&n_rows_i32);
4237 unsafe { b.launch(cfg) }
4239 .map(drop)
4240 .map_err(|_| ArrowSchurGpuFailure::Unavailable)
4241 }
4242
4243 pub(super) fn framed_schur_matvec_once_on_device(
4254 sys: &ArrowSchurSystem,
4255 data: &DeviceSaePcgData,
4256 ridge_t: f64,
4257 ridge_beta: f64,
4258 x: &Array1<f64>,
4259 ) -> Result<Array1<f64>, ArrowSchurGpuFailure> {
4260 let k = x.len();
4261 if k == 0 || data.beta_dim != k || sys.k != k {
4262 return Err(ArrowSchurGpuFailure::Unavailable);
4263 }
4264 let frame = data
4265 .frame
4266 .as_ref()
4267 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4268 let runtime = gam_gpu::device_runtime::GpuRuntime::global()
4271 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4272 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
4273 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4274 let stream = ctx
4275 .new_stream()
4276 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4277 let vector_module = pcg_vector_module(&ctx)?;
4278 let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
4279 let x_dev = stream
4280 .clone_htod(x.as_slice().ok_or(ArrowSchurGpuFailure::Unavailable)?)
4281 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4282 let mut out_dev = stream
4283 .alloc_zeros::<f64>(k)
4284 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4285 launch_sae_frame_matvec(
4286 &stream,
4287 vector_module,
4288 &mut buffers,
4289 &x_dev,
4290 &mut out_dev,
4291 ridge_beta,
4292 )?;
4293 let out = stream
4294 .clone_dtoh(&out_dev)
4295 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4296 Ok(Array1::from_vec(out))
4297 }
4298
/// Solve the framed SAE reduced system `S δβ = rhs_beta` with Jacobi-preconditioned
/// conjugate gradients. `S` is applied matrix-free on the GPU via `launch_sae_frame_matvec`.
///
/// Returns the solution together with `PcgDiagnostics`:
/// - A zero right-hand side returns zeros immediately, after the device buffers have been
///   built, so flattening errors still take precedence.
/// - Iteration stops when the residual 2-norm falls to `max(relative_tolerance * ||rhs||, 1e-12)`
///   (`Converged`), or after `max(max_iterations, 1)` iterations (`MaxIter`).
///
/// # Errors
/// - `Unavailable` in any of these cases:
///   - `k == 0`, or `beta_dim` / `sys.k` differ from `rhs_beta.len()`;
///   - `data.frame` is `None`;
///   - no runtime exists, or its policy declines to offload;
///   - a CUDA/cuBLAS step fails.
/// - `SchurFactorFailed` in either of these cases:
///   - the Jacobi diagonal is non-finite or `<= 1e-18`;
///   - CG breaks down: `rᵀM⁻¹r` or `pᵀAp` is non-positive or non-finite.
/// - Any error from `flatten_device_sae_frame_data`, such as `RidgeBumpRequired`.
pub(super) fn solve_sae_matrix_free_pcg_framed(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    let frame = data
        .frame
        .as_ref()
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    // Offload only when the runtime policy approves this problem size and iteration budget.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // The cuBLAS handle shares `stream`, so BLAS calls and custom kernels are ordered.
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute residual target, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi preconditioner: the penalty diagonal is built on the host, then adjusted on
    // the device by `arrow_sae_frame_diag_sub` (presumably removing the row-coupling
    // diagonal), and copied back for validation.
    let diag_host = sae_frame_penalty_diag_host(data, frame, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_frame_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "framed SAE GPU PCG: non-positive Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // CG state: x = 0, r = rhs, z = M⁻¹r, p = z, ap = scratch for S·p.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("framed SAE GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };
    for _ in 0..max_iterations.max(1) {
        launch_sae_frame_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        // A non-positive curvature means S is not SPD along p: CG cannot continue.
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("framed SAE GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        let beta = rz_new / rz;
        // Presumably p ← z + beta·p (kernel body not shown here).
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    // Iteration budget exhausted: report the last recursive residual.
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4448
/// Solve the unframed SAE reduced system `S δβ = rhs_beta` with Jacobi-preconditioned
/// conjugate gradients. `S` is applied matrix-free on the GPU via `launch_sae_matvec`.
///
/// Returns the solution together with `PcgDiagnostics`:
/// - A zero right-hand side returns zeros immediately, after the device buffers have been
///   built.
/// - Iteration stops when the residual 2-norm falls to `max(relative_tolerance * ||rhs||, 1e-12)`
///   (`Converged`), or after `max(max_iterations, 1)` iterations (`MaxIter`).
///
/// # Errors
/// - `Unavailable` in any of these cases:
///   - `k == 0`, or `beta_dim` / `sys.k` differ from `rhs_beta.len()`;
///   - `data.frame` is `Some` (framed data is not accepted here;
///     `solve_sae_matrix_free_pcg_framed` is the framed counterpart);
///   - no runtime exists, or its policy declines to offload;
///   - a CUDA/cuBLAS step fails.
/// - `SchurFactorFailed` in either of these cases:
///   - the Jacobi diagonal is non-finite or `<= 1e-18`;
///   - CG breaks down: `rᵀM⁻¹r` or `pᵀAp` is non-positive or non-finite.
/// - Any error from `flatten_device_sae_data`.
pub(super) fn solve_sae_matrix_free_pcg(
    sys: &ArrowSchurSystem,
    data: &DeviceSaePcgData,
    ridge_t: f64,
    ridge_beta: f64,
    rhs_beta: &Array1<f64>,
    max_iterations: usize,
    relative_tolerance: f64,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
    let k = rhs_beta.len();
    if k == 0 || data.beta_dim != k || sys.k != k {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // This path only handles unframed data.
    if data.frame.is_some() {
        return Err(ArrowSchurGpuFailure::Unavailable);
    }
    // Offload only when the runtime policy approves this problem size and iteration budget.
    let runtime = gam_gpu::device_runtime::GpuRuntime::global()
        .filter(|rt| {
            rt.policy().reduced_schur_matvec_should_offload(
                sys.rows.len(),
                sys.k,
                sys.d,
                max_iterations,
            )
        })
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
        .ok_or(ArrowSchurGpuFailure::Unavailable)?;
    let stream = ctx
        .new_stream()
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // The cuBLAS handle shares `stream`, so BLAS calls and custom kernels are ordered.
    let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let vector_module = pcg_vector_module(&ctx)?;
    let mut buffers = flatten_device_sae_data(sys, data, ridge_t, &stream)?;

    let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
    }
    // Absolute residual target, floored at 1e-12.
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
    let rhs_dev = stream
        .clone_htod(
            rhs_beta
                .as_slice()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?,
        )
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    // Jacobi preconditioner: the penalty diagonal is built on the host, adjusted on the
    // device by `launch_sae_diag_sub`, and copied back for validation.
    let diag_host = sae_penalty_diag_host(data, ridge_beta)?;
    let mut diag_dev = stream
        .clone_htod(&diag_host)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_sae_diag_sub(&stream, vector_module, &buffers, &mut diag_dev)?;
    let diag_host = stream
        .clone_dtoh(&diag_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut inv_diag = Vec::with_capacity(k);
    for (idx, &d) in diag_host.iter().enumerate() {
        if !d.is_finite() || d <= 1.0e-18 {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!(
                    "SAE matrix-free GPU PCG: non-positive Schur Jacobi diagonal at {idx}: {d:e}"
                ),
            });
        }
        inv_diag.push(1.0 / d);
    }
    let inv_diag_dev = stream
        .clone_htod(&inv_diag)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    // CG state: x = 0, r = rhs, z = M⁻¹r, p = z, ap = scratch for S·p.
    let mut x_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    let mut r_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &rhs_dev, &mut r_dev)?;
    let mut z_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
    let mut p_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
    let mut ap_dev = stream
        .alloc_zeros::<f64>(k)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;

    let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
    if rz <= 0.0 || !rz.is_finite() {
        return Err(ArrowSchurGpuFailure::SchurFactorFailed {
            reason: format!("SAE matrix-free GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
        });
    }
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        stopping_reason: PcgStopReason::MaxIter,
        ..PcgDiagnostics::default()
    };

    for _ in 0..max_iterations.max(1) {
        launch_sae_matvec(
            &stream,
            vector_module,
            &mut buffers,
            &p_dev,
            &mut ap_dev,
            ridge_beta,
        )?;
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = device_dot(&blas, &stream, k, &p_dev, &ap_dev)?;
        // A non-positive curvature means S is not SPD along p: CG cannot continue.
        if pap <= 0.0 || !pap.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive curvature pᵀAp={pap:e}"),
            });
        }
        let alpha = rz / pap;
        device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
        device_axpy(&blas, &stream, k, -alpha, &ap_dev, &mut r_dev)?;
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        if r_norm <= tol {
            diag.final_relative_residual = r_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            break;
        }
        launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
        diag.precond_apply_calls += 1;
        let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
        if rz_new <= 0.0 || !rz_new.is_finite() {
            return Err(ArrowSchurGpuFailure::SchurFactorFailed {
                reason: format!("SAE matrix-free GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
            });
        }
        let beta = rz_new / rz;
        // Presumably p ← z + beta·p (kernel body not shown here).
        launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
        rz = rz_new;
    }
    // Iteration budget exhausted: report the last recursive residual.
    if diag.stopping_reason != PcgStopReason::Converged {
        let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
        diag.final_relative_residual = r_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::MaxIter;
    }
    let x = stream
        .clone_dtoh(&x_dev)
        .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
    Ok((Array1::from_vec(x), diag))
}
4620
4621 pub(super) fn solve_reduced_beta_pcg_with_diagnostics(
4622 s_acc: &ndarray::Array2<f64>,
4623 rhs_beta: &Array1<f64>,
4624 max_iterations: usize,
4625 relative_tolerance: f64,
4626 ) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurGpuFailure> {
4627 let k = rhs_beta.len();
4628 let cg_iters = max_iterations.max(1);
4640 let runtime = gam_gpu::linalg_dispatch::route_through_gpu(
4641 gam_gpu::linalg_dispatch::DispatchOp::Gemm {
4642 m: k,
4643 n: k,
4644 k: cg_iters,
4645 },
4646 )
4647 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4648 let stream = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4649 .and_then(|ctx| ctx.new_stream().ok())
4650 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4651 let blas = CudaBlas::new(stream.clone()).map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4652 let ctx = gam_gpu::device_runtime::cuda_context_for(runtime.device.ordinal)
4653 .ok_or(ArrowSchurGpuFailure::Unavailable)?;
4654 let vector_module = pcg_vector_module(&ctx)?;
4655
4656 let mut inv_diag = vec![0.0_f64; k];
4658 for j in 0..k {
4659 let djj = s_acc[[j, j]];
4660 if !(djj > 0.0) {
4661 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4662 reason: format!(
4663 "reduced-β GPU PCG: Jacobi diagonal S[{j},{j}]={djj:e} not positive"
4664 ),
4665 });
4666 }
4667 inv_diag[j] = 1.0 / djj;
4668 }
4669
4670 let mut s_host = vec![0.0_f64; k * k];
4672 for col in 0..k {
4673 for row in 0..k {
4674 s_host[col * k + row] = s_acc[[row, col]];
4675 }
4676 }
4677 let s_dev = stream
4678 .clone_htod(&s_host)
4679 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4680
4681 let rhs_norm = rhs_beta.iter().map(|v| v * v).sum::<f64>().sqrt();
4685 if rhs_norm == 0.0 {
4686 return Ok((Array1::<f64>::zeros(k), PcgDiagnostics::default()));
4687 }
4688 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
4689
4690 let mut x_dev = stream
4693 .alloc_zeros::<f64>(k)
4694 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4695 let mut r_dev = stream
4696 .clone_htod(
4697 rhs_beta
4698 .as_slice()
4699 .ok_or(ArrowSchurGpuFailure::Unavailable)?,
4700 )
4701 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4702 let inv_diag_dev = stream
4703 .clone_htod(&inv_diag)
4704 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4705 let mut z_dev = stream
4706 .alloc_zeros::<f64>(k)
4707 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4708 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4709 let mut p_dev = stream
4710 .alloc_zeros::<f64>(k)
4711 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4712 device_copy(&blas, &stream, k, &z_dev, &mut p_dev)?;
4713 let mut sp_dev = stream
4714 .alloc_zeros::<f64>(k)
4715 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4716 let mut rz = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4717 let mut diag = PcgDiagnostics {
4718 precond_apply_calls: 1,
4719 stopping_reason: PcgStopReason::MaxIter,
4720 ..PcgDiagnostics::default()
4721 };
4722 if rz <= 0.0 || !rz.is_finite() {
4723 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4724 reason: format!("reduced-β GPU PCG: non-positive initial rᵀM⁻¹r={rz:e}"),
4725 });
4726 }
4727
4728 let max_iters = max_iterations.max(1);
4729 for _ in 0..max_iters {
4730 let gemv_cfg = GemvConfig::<f64> {
4732 trans: cublasOperation_t::CUBLAS_OP_N,
4733 m: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4734 n: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4735 alpha: 1.0,
4736 lda: to_i32(k).ok_or(ArrowSchurGpuFailure::Unavailable)?,
4737 incx: 1,
4738 beta: 0.0,
4739 incy: 1,
4740 };
4741 unsafe { blas.gemv(gemv_cfg, &s_dev, &p_dev, &mut sp_dev) }
4743 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4744 diag.matvec_calls += 1;
4745 diag.iterations += 1;
4746
4747 let p_sp = device_dot(&blas, &stream, k, &p_dev, &sp_dev)?;
4748 if !(p_sp > 0.0) {
4749 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4752 reason: format!("reduced-β GPU PCG: non-positive curvature pᵀSp={p_sp:e}"),
4753 });
4754 }
4755 let alpha = rz / p_sp;
4756 device_axpy(&blas, &stream, k, alpha, &p_dev, &mut x_dev)?;
4757 device_axpy(&blas, &stream, k, -alpha, &sp_dev, &mut r_dev)?;
4758 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4759 if r_norm <= tol {
4760 diag.final_relative_residual = r_norm / rhs_norm;
4761 diag.stopping_reason = PcgStopReason::Converged;
4762 break;
4763 }
4764 launch_jacobi_mul(&stream, vector_module, &inv_diag_dev, &r_dev, &mut z_dev, k)?;
4765 diag.precond_apply_calls += 1;
4766 let rz_new = device_dot(&blas, &stream, k, &r_dev, &z_dev)?;
4767 if rz_new <= 0.0 || !rz_new.is_finite() {
4768 return Err(ArrowSchurGpuFailure::SchurFactorFailed {
4769 reason: format!("reduced-β GPU PCG: non-positive rᵀM⁻¹r={rz_new:e}"),
4770 });
4771 }
4772 let beta = rz_new / rz;
4773 launch_update_p(&stream, vector_module, &z_dev, beta, &mut p_dev, k)?;
4774 rz = rz_new;
4775 }
4776 if diag.stopping_reason != PcgStopReason::Converged {
4777 let r_norm = device_nrm2(&blas, &stream, k, &r_dev)?;
4778 diag.final_relative_residual = r_norm / rhs_norm;
4779 diag.stopping_reason = PcgStopReason::MaxIter;
4780 }
4781
4782 let x = stream
4783 .clone_dtoh(&x_dev)
4784 .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
4785 Ok((Array1::from_vec(x), diag))
4786 }
4787
4788 fn device_copy(
4789 blas: &CudaBlas,
4790 stream: &Arc<CudaStream>,
4791 n: usize,
4792 src: &CudaSlice<f64>,
4793 dst: &mut CudaSlice<f64>,
4794 ) -> Result<(), ArrowSchurGpuFailure> {
4795 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4796 let (src_ptr, _src_rec) = src.device_ptr(stream);
4797 let (dst_ptr, _dst_rec) = dst.device_ptr_mut(stream);
4798 let status = unsafe {
4801 cudarc::cublas::sys::cublasDcopy_v2(
4802 *blas.handle(),
4803 n_i,
4804 src_ptr as *const f64,
4805 1,
4806 dst_ptr as *mut f64,
4807 1,
4808 )
4809 };
4810 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4811 Ok(())
4812 } else {
4813 Err(ArrowSchurGpuFailure::Unavailable)
4814 }
4815 }
4816
4817 fn device_axpy(
4818 blas: &CudaBlas,
4819 stream: &Arc<CudaStream>,
4820 n: usize,
4821 alpha: f64,
4822 x: &CudaSlice<f64>,
4823 y: &mut CudaSlice<f64>,
4824 ) -> Result<(), ArrowSchurGpuFailure> {
4825 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4826 let (x_ptr, _x_rec) = x.device_ptr(stream);
4827 let (y_ptr, _y_rec) = y.device_ptr_mut(stream);
4828 let status = unsafe {
4831 cudarc::cublas::sys::cublasDaxpy_v2(
4832 *blas.handle(),
4833 n_i,
4834 &alpha,
4835 x_ptr as *const f64,
4836 1,
4837 y_ptr as *mut f64,
4838 1,
4839 )
4840 };
4841 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4842 Ok(())
4843 } else {
4844 Err(ArrowSchurGpuFailure::Unavailable)
4845 }
4846 }
4847
4848 fn device_dot(
4849 blas: &CudaBlas,
4850 stream: &Arc<CudaStream>,
4851 n: usize,
4852 x: &CudaSlice<f64>,
4853 y: &CudaSlice<f64>,
4854 ) -> Result<f64, ArrowSchurGpuFailure> {
4855 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4856 let (x_ptr, _x_rec) = x.device_ptr(stream);
4857 let (y_ptr, _y_rec) = y.device_ptr(stream);
4858 let mut result = 0.0_f64;
4859 let status = unsafe {
4863 cudarc::cublas::sys::cublasDdot_v2(
4864 *blas.handle(),
4865 n_i,
4866 x_ptr as *const f64,
4867 1,
4868 y_ptr as *const f64,
4869 1,
4870 &mut result,
4871 )
4872 };
4873 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4874 Ok(result)
4875 } else {
4876 Err(ArrowSchurGpuFailure::Unavailable)
4877 }
4878 }
4879
4880 fn device_nrm2(
4881 blas: &CudaBlas,
4882 stream: &Arc<CudaStream>,
4883 n: usize,
4884 x: &CudaSlice<f64>,
4885 ) -> Result<f64, ArrowSchurGpuFailure> {
4886 let n_i = to_i32(n).ok_or(ArrowSchurGpuFailure::Unavailable)?;
4887 let (x_ptr, _x_rec) = x.device_ptr(stream);
4888 let mut result = 0.0_f64;
4889 let status = unsafe {
4893 cudarc::cublas::sys::cublasDnrm2_v2(
4894 *blas.handle(),
4895 n_i,
4896 x_ptr as *const f64,
4897 1,
4898 &mut result,
4899 )
4900 };
4901 if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
4902 Ok(result)
4903 } else {
4904 Err(ArrowSchurGpuFailure::Unavailable)
4905 }
4906 }
4907
    // Device-side stage-diff tests for the framed SAE Schur matvec. The test
    // here returns early when `GpuRuntime::global()` reports no CUDA runtime,
    // so it is a no-op on CPU-only machines.
    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::arrow_schur::{
            ArrowSchurSystem, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock,
        };
        use ndarray::Array2;

        /// Runs the device framed-SAE Schur matvec exactly once on `x_host`
        /// and copies the result back to the host.
        ///
        /// Creates a fresh stream on the runtime's selected CUDA device, loads
        /// the PCG vector module, flattens/uploads the frame data via
        /// `flatten_device_sae_frame_data` (which receives `ridge_t`), uploads
        /// `x_host`, launches `launch_sae_frame_matvec` with `ridge_beta`, and
        /// downloads an output vector of length `x_host.len()`.
        ///
        /// # Errors
        /// `ArrowSchurGpuFailure::Unavailable` when `data.frame` is `None`, no
        /// GPU runtime or CUDA context is available, or a stream
        /// creation/allocation/copy fails; failures from the module, flatten
        /// and launch helpers are propagated unchanged via `?`.
        fn device_matvec_once(
            sys: &ArrowSchurSystem,
            data: &DeviceSaePcgData,
            ridge_t: f64,
            ridge_beta: f64,
            x_host: &[f64],
        ) -> Result<Vec<f64>, ArrowSchurGpuFailure> {
            // The output has the same length as the input (square operator).
            let k = x_host.len();
            let frame = data
                .frame
                .as_ref()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let runtime = gam_gpu::device_runtime::GpuRuntime::global()
                .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let ctx =
                gam_gpu::device_runtime::cuda_context_for(runtime.selected_device().ordinal)
                    .ok_or(ArrowSchurGpuFailure::Unavailable)?;
            let stream = ctx
                .new_stream()
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let vector_module = pcg_vector_module(&ctx)?;
            let mut buffers = flatten_device_sae_frame_data(sys, data, frame, ridge_t, &stream)?;
            let x_dev = stream
                .clone_htod(x_host)
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            let mut out_dev = stream
                .alloc_zeros::<f64>(k)
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)?;
            launch_sae_frame_matvec(
                &stream,
                vector_module,
                &mut buffers,
                &x_dev,
                &mut out_dev,
                ridge_beta,
            )?;
            stream
                .clone_dtoh(&out_dev)
                .map_err(|_| ArrowSchurGpuFailure::Unavailable)
        }

        /// #1551 stage-diff: probes the device framed matvec with every unit
        /// vector `e_col` of a tiny two-atom system and compares each result
        /// column entry-by-entry against `sae_framed_schur_matvec_cpu`.
        ///
        /// Layout: atom 0 has rank 2, atom 1 rank 3, both with 2 basis
        /// functions, so the border is `2*2 + 2*3 = 10` wide (atom 0 at
        /// `[0..4)`, atom 1 at `[4..10)`). Two rows of a `q = 2` arrow system
        /// provide the reduced-Schur part. The assertion reports the worst
        /// absolute difference (tolerance 1e-9) and the first index above it,
        /// flattened as `row * border_dim + col`.
        #[test]
        fn framed_sae_device_matvec_stage_diff_tiny_1551() {
            if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
                return;
            }
            let p = 3usize;
            let ranks = vec![2usize, 3usize];
            let basis_sizes = vec![2usize, 2usize];
            // Each atom occupies `basis_size * rank` consecutive border slots.
            let mut border_offsets = Vec::new();
            let mut acc = 0usize;
            for k in 0..2 {
                border_offsets.push(acc);
                acc += basis_sizes[k] * ranks[k];
            }
            let border_dim = acc;
            // Deterministic dense `p x rank` frame for atom `k`.
            let frame_of = |k: usize| -> Array2<f64> {
                Array2::from_shape_fn((p, ranks[k]), |(i, j)| {
                    0.1 + 0.2 * ((i + 1) as f64) * ((j + 1 + 2 * k) as f64)
                })
            };
            let frames: Vec<Array2<f64>> = (0..2).map(frame_of).collect();
            // Frame cross-Gram `W_ij = U_iᵀ U_j`.
            let w_of = |i: usize, j: usize| -> Array2<f64> {
                let (ui, uj) = (&frames[i], &frames[j]);
                Array2::from_shape_fn((ranks[i], ranks[j]), |(a, b)| {
                    (0..p).map(|c| ui[[c, a]] * uj[[c, b]]).sum()
                })
            };
            // Both diagonal blocks plus both cross blocks; diagonal G blocks get
            // a diagonal boost.
            let mut frame_blocks = Vec::new();
            for &(i, j) in &[(0usize, 0usize), (1usize, 1usize), (0, 1), (1, 0)] {
                let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
                let mut g =
                    Array2::<f64>::from_shape_fn((mi, mj), |(r, c)| 0.1 * (r + 2 * c + 1) as f64);
                if i == j {
                    for r in 0..mi.min(mj) {
                        g[[r, r]] += mi as f64 + 2.0;
                    }
                }
                frame_blocks.push(FactoredFrameGBlock {
                    atom_i: i,
                    atom_j: j,
                    g,
                    w: w_of(i, j),
                });
            }
            // Per-atom smoothing factors: symmetric fill plus identity.
            let mut smooth_blocks = Vec::new();
            for k in 0..2 {
                let m = basis_sizes[k];
                let mut s =
                    Array2::<f64>::from_shape_fn((m, m), |(r, c)| 0.05 * (r + c + 1) as f64);
                for r in 0..m {
                    s[[r, r]] += 1.0;
                }
                smooth_blocks.push(DeviceSaeSmoothBlock {
                    global_offset: border_offsets[k],
                    factor_a: s,
                });
            }
            let smooth_ranks = ranks.clone();
            let n = 2usize;
            let q = 2usize;
            let mut sys = ArrowSchurSystem::new(n, q, border_dim);
            // `row_htbeta[i]` is the same `q x border_dim` coupling slab that is
            // written into `sys.rows[i].htbeta`, flattened row-major.
            let mut row_htbeta = Vec::new();
            for i in 0..n {
                let mut htt =
                    Array2::<f64>::from_shape_fn((q, q), |(r, c)| 0.3 * (r + c + 1) as f64);
                for r in 0..q {
                    htt[[r, r]] += q as f64 + 2.0;
                }
                sys.rows[i].htt = htt;
                let mut slab = vec![0.0_f64; q * border_dim];
                for c in 0..q {
                    for col in 0..border_dim {
                        let v = 0.01 * ((c + 1) * (col + 1) + i) as f64;
                        slab[c * border_dim + col] = v;
                        sys.rows[i].htbeta[[c, col]] = v;
                    }
                }
                row_htbeta.push(slab);
            }
            let data = DeviceSaePcgData {
                p,
                beta_dim: border_dim,
                a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
                local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
                smooth_blocks,
                sparse_g_blocks: Vec::new(),
                frame: Some(DeviceSaeFrameData {
                    ranks,
                    basis_sizes,
                    border_offsets,
                    frame_blocks,
                    smooth_ranks,
                    row_htbeta,
                }),
            };
            let ridge_t = 1e-7;
            let ridge_beta = 1e-6;
            let mut first_bad: Option<usize> = None;
            let mut worst = 0.0_f64;
            let mut worst_at = 0usize;
            let mut worst_dev = 0.0_f64;
            let mut worst_cpu = 0.0_f64;
            // Column-by-column probe: S · e_col yields column `col` of S.
            for col in 0..border_dim {
                let mut x = vec![0.0_f64; border_dim];
                x[col] = 1.0;
                // NOTE(review): a device failure here silently ends the test even
                // though a GPU runtime was detected above; confirm this
                // best-effort skip is intended.
                let dev = match device_matvec_once(&sys, &data, ridge_t, ridge_beta, &x) {
                    Ok(v) => v,
                    Err(_) => return,
                };
                let mut cpu = vec![0.0_f64; border_dim];
                super::super::sae_framed_schur_matvec_cpu(
                    &sys, &data, ridge_t, ridge_beta, &x, &mut cpu,
                )
                .expect("cpu matvec");
                for r in 0..border_dim {
                    let d = (dev[r] - cpu[r]).abs();
                    if d > 1e-9 && first_bad.is_none() {
                        first_bad = Some(r * border_dim + col);
                    }
                    if d > worst {
                        worst = d;
                        worst_at = r * border_dim + col;
                        worst_dev = dev[r];
                        worst_cpu = cpu[r];
                    }
                }
            }
            assert!(
                worst <= 1e-9,
                "[#1551 stage-diff] device framed matvec != CPU oracle: worst abs={worst:e} at \
                 (row*K+col)={worst_at} (dev={worst_dev:e} cpu={worst_cpu:e}), \
                 first_bad_idx={first_bad:?}; border layout: atom0 [0..4) rank2, atom1 [4..10) \
                 rank3 — which atom-range the bad row/col falls in pins the stage (smooth=diag, \
                 G⊗W=cross, reduced-Schur=dense per-row)",
            );
        }
    }
5106}
5107
5108#[cfg(test)]
5109mod tests {
5110 use super::*;
5111 use crate::arrow_schur::ArrowSchurSystem;
5112 use ndarray::{Array2, ArrayView1};
5113
5114 fn build_fixture(n: usize, d: usize, k: usize, seed: u64) -> ArrowSchurSystem {
5115 let mut sys = ArrowSchurSystem::new(n, d, k);
5116 let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15);
5117 let mut sample = || -> f64 {
5118 state = state
5119 .wrapping_mul(6364136223846793005)
5120 .wrapping_add(1442695040888963407);
5121 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
5122 };
5123 for row in &mut sys.rows {
5124 let mut a = Array2::<f64>::zeros((d, d));
5125 for r in 0..d {
5126 for c in 0..d {
5127 a[[r, c]] = sample();
5128 }
5129 }
5130 let mut htt = a.t().dot(&a);
5131 for r in 0..d {
5132 htt[[r, r]] += d as f64 + 1.0;
5133 }
5134 row.htt = htt;
5135 for r in 0..d {
5136 for c in 0..k {
5137 row.htbeta[[r, c]] = 0.1 * sample();
5138 }
5139 row.gt[r] = sample();
5140 }
5141 }
5142 let mut hbb_a = Array2::<f64>::zeros((k, k));
5143 for r in 0..k {
5144 for c in 0..k {
5145 hbb_a[[r, c]] = sample();
5146 }
5147 }
5148 let mut hbb = hbb_a.t().dot(&hbb_a);
5149 for r in 0..k {
5150 hbb[[r, r]] += k as f64 + 1.0;
5151 }
5152 sys.hbb = hbb;
5153 for r in 0..k {
5154 sys.gb[r] = sample();
5155 }
5156 sys
5157 }
5158
5159 #[test]
5164 fn ridge_bump_makes_known_indefinite_blocks_pd() {
5165 let neg_identity = Array2::<f64>::from_diag(&Array1::from_elem(8, -1.0)); let scaled_neg = Array2::<f64>::from_diag(&Array1::from_elem(4, -250.0)); let mut indef2 = Array2::<f64>::zeros((2, 2));
5173 indef2[[0, 0]] = 1.0;
5174 indef2[[1, 1]] = 1.0;
5175 indef2[[0, 1]] = 2.0;
5176 indef2[[1, 0]] = 2.0;
5177 let pd = Array2::<f64>::from_diag(&Array1::from_elem(3, 5.0));
5180
5181 for (label, block) in [
5182 ("-I (λ_min=-1)", neg_identity),
5183 ("-250·I (λ_min=-250)", scaled_neg),
5184 ("[[1,2],[2,1]] (λ_min=-1)", indef2),
5185 ("5·I (PD)", pd),
5186 ] {
5187 let ridge_t = 0.0;
5188 let bump = ridge_bump_to_make_pd(block.view(), ridge_t);
5189 assert!(
5190 bump > 0.0 && bump.is_finite(),
5191 "[{label}] bump must be strictly positive and finite, got {bump:e}"
5192 );
5193 let d = block.nrows();
5194 let mut shifted = block.clone();
5195 for i in 0..d {
5196 shifted[[i, i]] += ridge_t + bump;
5197 }
5198 assert!(
5199 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5200 "[{label}] H_tt + (ridge_t + bump={bump:e})·I must be PD after the \
5201 Gershgorin bump, but the Cholesky still rejected it"
5202 );
5203 }
5204 }
5205
5206 #[cfg(target_os = "linux")]
5216 #[test]
5217 fn ridge_bump_colmajor_matches_rowmajor_for_symmetric_block() {
5218 let mut a = Array2::<f64>::zeros((3, 3));
5220 a[[0, 0]] = -2.0;
5221 a[[1, 1]] = 0.5;
5222 a[[2, 2]] = 1.0;
5223 a[[0, 1]] = 0.3;
5224 a[[1, 0]] = 0.3;
5225 a[[1, 2]] = -0.4;
5226 a[[2, 1]] = -0.4;
5227 a[[0, 2]] = 0.1;
5228 a[[2, 0]] = 0.1;
5229
5230 let row_major_bump = ridge_bump_to_make_pd(a.view(), 0.0);
5231
5232 let d = 3;
5234 let mut col_major = vec![0.0_f64; d * d];
5235 for c in 0..d {
5236 for r in 0..d {
5237 col_major[c * d + r] = a[[r, c]];
5238 }
5239 }
5240 let col_major_bump = ridge_bump_to_make_pd_colmajor(&col_major, d);
5241
5242 assert!(
5243 (row_major_bump - col_major_bump).abs() <= 1e-12 * row_major_bump.max(1.0),
5244 "colmajor bump {col_major_bump:e} must match rowmajor bump \
5245 {row_major_bump:e} for a symmetric block"
5246 );
5247
5248 let mut shifted = a.clone();
5250 for i in 0..d {
5251 shifted[[i, i]] += col_major_bump;
5252 }
5253 assert!(
5254 cholesky_factor_in_place(shifted.view(), CholeskyGuard::NonnegativePivot).is_some(),
5255 "colmajor Gershgorin bump must make the symmetric block PD"
5256 );
5257 }
5258
5259 fn device_pcg_fixture(k: usize) -> (Array2<f64>, Array1<f64>) {
5260 let mut s = Array2::<f64>::zeros((k, k));
5261 for row in 0..k {
5262 s[[row, row]] = 2.5 + 0.001 * ((row % 17) as f64);
5263 if row + 1 < k {
5264 s[[row, row + 1]] = -0.05;
5265 s[[row + 1, row]] = -0.05;
5266 }
5267 if row + 7 < k {
5268 s[[row, row + 7]] = 0.01;
5269 s[[row + 7, row]] = 0.01;
5270 }
5271 }
5272 let rhs = Array1::from_shape_fn(k, |idx| ((idx as f64 + 1.0) * 0.013).sin());
5273 (s, rhs)
5274 }
5275
5276 fn dense_pcg_cpu_reference(
5277 s: &Array2<f64>,
5278 rhs: &Array1<f64>,
5279 max_iterations: usize,
5280 relative_tolerance: f64,
5281 ) -> Array1<f64> {
5282 let k = rhs.len();
5283 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
5284 if rhs_norm == 0.0 {
5285 return Array1::<f64>::zeros(k);
5286 }
5287 let tol = (relative_tolerance.max(0.0) * rhs_norm).max(1e-12);
5288 let inv_diag: Vec<f64> = (0..k).map(|idx| 1.0 / s[[idx, idx]]).collect();
5289 let mut x = Array1::<f64>::zeros(k);
5290 let mut r = rhs.clone();
5291 let mut z = Array1::from_shape_fn(k, |idx| inv_diag[idx] * r[idx]);
5292 let mut p = z.clone();
5293 let mut sp = Array1::<f64>::zeros(k);
5294 let mut rz = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5295 for _ in 0..max_iterations.max(1) {
5296 for row in 0..k {
5297 let mut acc = 0.0;
5298 for col in 0..k {
5299 acc += s[[row, col]] * p[col];
5300 }
5301 sp[row] = acc;
5302 }
5303 let p_sp = p.iter().zip(sp.iter()).map(|(a, b)| a * b).sum::<f64>();
5304 let alpha = rz / p_sp;
5305 for idx in 0..k {
5306 x[idx] += alpha * p[idx];
5307 r[idx] -= alpha * sp[idx];
5308 }
5309 let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
5310 if r_norm <= tol {
5311 break;
5312 }
5313 for idx in 0..k {
5314 z[idx] = inv_diag[idx] * r[idx];
5315 }
5316 let rz_next = r.iter().zip(z.iter()).map(|(a, b)| a * b).sum::<f64>();
5317 let beta = rz_next / rz;
5318 for idx in 0..k {
5319 p[idx] = z[idx] + beta * p[idx];
5320 }
5321 rz = rz_next;
5322 }
5323 x
5324 }
5325
5326 #[test]
5327 fn device_resident_pcg_matches_cpu_reference_when_cuda_admits() {
5328 let (s, rhs) = device_pcg_fixture(512);
5329 let max_iterations = 200usize;
5330 let relative_tolerance = 1.0e-12;
5331 let cpu = dense_pcg_cpu_reference(&s, &rhs, max_iterations, relative_tolerance);
5332 let (device, diag) = match solve_reduced_beta_pcg_with_diagnostics(
5333 &s,
5334 &rhs,
5335 max_iterations,
5336 relative_tolerance,
5337 ) {
5338 Ok(result) => result,
5339 Err(failure) => {
5346 assert!(
5347 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
5348 "#1017: CUDA device present but the device reduced-beta PCG \
5349 declined/faulted instead of returning a result (tag: {failure:?}) — \
5350 the kernel does not run correctly on GPU"
5351 );
5352 return;
5353 }
5354 };
5355 let max_err = cpu
5356 .iter()
5357 .zip(device.iter())
5358 .map(|(a, b)| (a - b).abs())
5359 .fold(0.0_f64, f64::max);
5360 assert!(
5361 max_err <= 1.0e-10,
5362 "device resident PCG parity failed: max_err={max_err:e}, diag={diag:?}"
5363 );
5364 assert!(diag.matvec_calls > 0);
5365 assert_eq!(diag.matvec_calls, diag.iterations);
5366 }
5367
5368 #[test]
5369 fn dense_reference_matches_independent_solve() {
5370 let sys = build_fixture(4, 5, 3, 7);
5371 let solution = solve_arrow_newton_step_dense_reference(&sys, 0.0, 0.0).unwrap();
5372 let n = sys.rows.len();
5376 let d = sys.d;
5377 let k = sys.k;
5378 let total = n * d + k;
5379 let mut h = Array2::<f64>::zeros((total, total));
5380 let mut g = ndarray::Array1::<f64>::zeros(total);
5381 for (i, row) in sys.rows.iter().enumerate() {
5382 let base = i * d;
5383 for c in 0..d {
5384 for r in 0..d {
5385 h[[base + r, base + c]] = row.htt[[r, c]];
5386 }
5387 }
5388 for c in 0..k {
5389 for r in 0..d {
5390 h[[base + r, n * d + c]] = row.htbeta[[r, c]];
5391 h[[n * d + c, base + r]] = row.htbeta[[r, c]];
5392 }
5393 }
5394 for r in 0..d {
5395 g[base + r] = row.gt[r];
5396 }
5397 }
5398 for c in 0..k {
5399 for r in 0..k {
5400 h[[n * d + r, n * d + c]] += sys.hbb[[r, c]];
5401 }
5402 g[n * d + c] = sys.gb[c];
5403 }
5404 let l = cholesky_factor_in_place(h.view(), CholeskyGuard::NonnegativePivot).unwrap();
5405 let rhs = g.mapv(|v| -v);
5406 let expected = cholesky_solve_vector(l.view(), rhs.view());
5407 for i in 0..n * d {
5408 assert!(
5409 (solution.delta_t[i] - expected[i]).abs() < 1e-10 * (1.0 + expected[i].abs()),
5410 "delta_t[{i}] mismatch: got {} expected {}",
5411 solution.delta_t[i],
5412 expected[i]
5413 );
5414 }
5415 for a in 0..k {
5416 assert!(
5417 (solution.delta_beta[a] - expected[n * d + a]).abs()
5418 < 1e-10 * (1.0 + expected[n * d + a].abs()),
5419 "delta_beta[{a}] mismatch"
5420 );
5421 }
5422 }
5423
5424 #[test]
5438 fn row_procedural_matvec_parallel_deterministic_and_matches_serial() {
5439 use crate::arrow_schur::SCHUR_MATVEC_PARALLEL_ROW_MIN;
5440 let n = SCHUR_MATVEC_PARALLEL_ROW_MIN + 96; let d = 3usize;
5442 let k = 24usize;
5443 let mut sys = build_fixture(n, d, k, 0xA17C_0FFE);
5444 let slabs: Vec<Array2<f64>> = sys.rows.iter().map(|row| row.htbeta.clone()).collect();
5449 let forward_slabs = slabs.clone();
5450 let transpose_slabs = slabs;
5451 sys.set_row_htbeta_operator(
5452 move |row: usize, x: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5453 let h = &forward_slabs[row];
5454 for r in 0..h.nrows() {
5455 let mut acc = 0.0_f64;
5456 for c in 0..h.ncols() {
5457 acc += h[[r, c]] * x[c];
5458 }
5459 out[r] = acc;
5460 }
5461 },
5462 move |row: usize, v: ArrayView1<'_, f64>, out: &mut Array1<f64>| {
5463 let h = &transpose_slabs[row];
5464 for r in 0..h.nrows() {
5465 for c in 0..h.ncols() {
5466 out[c] += h[[r, c]] * v[r];
5467 }
5468 }
5469 },
5470 );
5471
5472 let matvec = gpu_schur_matvec_backend(&sys, 0.0, 0.0)
5473 .expect("row-procedural matvec backend builds for matrix-free system");
5474 let x = Array1::from_shape_fn(k, |i| ((i as f64 + 1.0) * 0.37).sin());
5475
5476 let mut out_parallel_a = Array1::<f64>::zeros(k);
5480 matvec(&x, &mut out_parallel_a);
5481 let mut out_parallel_b = Array1::<f64>::zeros(k);
5482 matvec(&x, &mut out_parallel_b);
5483 for a in 0..k {
5484 assert_eq!(
5485 out_parallel_a[a].to_bits(),
5486 out_parallel_b[a].to_bits(),
5487 "row-procedural matvec parallel reduction is non-deterministic at index {a}"
5488 );
5489 }
5490
5491 let mut out_serial = Array1::<f64>::zeros(k);
5496 rayon::ThreadPoolBuilder::new()
5497 .num_threads(2)
5498 .build()
5499 .expect("build rayon pool")
5500 .install(|| matvec(&x, &mut out_serial));
5501
5502 let max_abs = out_serial.iter().fold(0.0_f64, |m, v| m.max(v.abs()));
5503 for a in 0..k {
5504 let diff = (out_parallel_a[a] - out_serial[a]).abs();
5505 assert!(
5506 diff <= 1e-12 * (1.0 + max_abs),
5507 "row-procedural matvec parallel vs serial diverged beyond reassociation \
5508 at index {a}: {} vs {} (diff={diff:e})",
5509 out_parallel_a[a],
5510 out_serial[a]
5511 );
5512 }
5513 }
5514
    /// The host framed-SAE Schur matvec (`sae_framed_schur_matvec_cpu`) must
    /// match an explicitly assembled dense Schur complement
    /// `S = H_ββ + ridge_beta·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`
    /// to a relative tolerance of 1e-10 (scaled by max(|want|, 1)) on four
    /// probe vectors.
    ///
    /// Fixture: three atoms with ranks (2, 4, 3) in ambient dimension p = 4
    /// and basis sizes (2, 1, 2), giving a 14-wide border. Atom 1 has full rank
    /// (r == p) and gets the identity frame. G blocks exist for every diagonal
    /// pair plus the (0, 2) / (2, 0) cross pair. All random entries come from a
    /// fixed LCG, so the order of `sample()` calls below defines the fixture.
    #[test]
    fn framed_sae_schur_matvec_matches_dense_reference() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 4usize;
        let ranks = vec![2usize, 4usize, 3usize];
        let basis_sizes = vec![2usize, 1usize, 2usize];
        let n_atoms = ranks.len();
        // Each atom occupies `basis_size * rank` consecutive border slots.
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;
        let mut state = 0x1234_5678_9abc_def0u64;
        // LCG draw in [-1, 1) from the top 31 bits of the state.
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };

        // `p x rank` frames: identity for full-rank atoms (no draws consumed),
        // random otherwise.
        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p && i == j {
                        1.0
                    } else if r == p {
                        0.0
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // Frame cross-Gram `W_ij = U_iᵀ U_j`.
        let w_of = |i: usize, j: usize| -> Array2<f64> {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };

        // G blocks, visited in sorted (atom_i, atom_j) order; diagonal blocks
        // get a diagonal boost.
        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
        let mut pairs = vec![(0usize, 0usize), (1, 1), (2, 2), (0, 2), (2, 0)];
        pairs.sort();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.3 * sample();
                }
            }
            if i == j {
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }

        // Per-atom smoothing factors `AᵀA + I` (SPD).
        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }

        // Arrow rows: SPD `H_tt = AᵀA + (q + 1)·I`, plus a random coupling slab
        // mirrored into `row_htbeta` (row-major `q x border_dim`).
        let n = 6usize;
        let q = 3usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.15 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }

        // Dense `H_ββ` = frame Kronecker operator + every per-atom smoothing
        // penalty, each densified over the full border.
        // (`to_dense` presumably comes from the `BetaPenaltyOp` trait — verify.)
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;

        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;

        // Dense reference Schur complement:
        // S = H_ββ + ridge_beta·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ.
        let mut s_dense = Array2::<f64>::zeros((border_dim, border_dim));
        for r in 0..border_dim {
            for c in 0..border_dim {
                s_dense[[r, c]] = sys.hbb[[r, c]];
            }
            s_dense[[r, r]] += ridge_beta;
        }
        for row in &sys.rows {
            let mut htt = row.htt.clone();
            for d in 0..q {
                htt[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
                .expect("htt PD");
            // y = (H_tt + ridge_t·I)⁻¹ H_tβ, solved one border column at a time.
            let mut y = Array2::<f64>::zeros((q, border_dim));
            for col in 0..border_dim {
                let mut e = Array1::<f64>::zeros(q);
                for r in 0..q {
                    e[r] = row.htbeta[[r, col]];
                }
                let solved = cholesky_solve_vector(factor.view(), e.view());
                for r in 0..q {
                    y[[r, col]] = solved[r];
                }
            }
            for r in 0..border_dim {
                for c in 0..border_dim {
                    let mut acc = 0.0;
                    for d in 0..q {
                        acc += row.htbeta[[d, r]] * y[[d, c]];
                    }
                    s_dense[[r, c]] -= acc;
                }
            }
        }

        // Compare on four deterministic probe vectors.
        let mut max_rel = 0.0_f64;
        for trial in 0..4 {
            let x: Vec<f64> = (0..border_dim)
                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
                .collect();
            let mut got = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
                .expect("framed matvec");
            let mut want = vec![0.0_f64; border_dim];
            for r in 0..border_dim {
                let mut acc = 0.0;
                for c in 0..border_dim {
                    acc += s_dense[[r, c]] * x[c];
                }
                want[r] = acc;
            }
            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (got[a] - want[a]).abs() / scale;
                max_rel = max_rel.max(rel);
            }
        }
        assert!(
            max_rel <= 1e-10,
            "framed SAE Schur matvec vs dense reference diverged: max_rel={max_rel:e}"
        );
    }
5771
    /// Larger-scale variant of the framed-SAE Schur matvec check: the host
    /// `sae_framed_schur_matvec_cpu` must match the dense Schur complement
    /// `S = H_ββ + ridge_beta·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ`
    /// to a relative tolerance of 1e-10 on four probe vectors.
    ///
    /// Fixture: 40 atoms in ambient dimension p = 12. Every fifth atom is full
    /// rank with an identity frame; the others have rank 2..=4 with random
    /// frames. Basis sizes cycle through 1..=3. G blocks couple each atom to
    /// itself and to its immediate neighbours (tridiagonal in atom index).
    /// Random entries come from a fixed LCG, so the order of `sample()` calls
    /// below defines the fixture.
    #[test]
    fn framed_sae_schur_matvec_matches_dense_reference_large_k_1026() {
        use crate::arrow_schur::{
            BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
            FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
        };

        let p = 12usize;
        let n_atoms = 40usize;
        // Full rank (p) every fifth atom, otherwise rank 2, 3 or 4.
        let ranks: Vec<usize> = (0..n_atoms)
            .map(|k| if k % 5 == 0 { p } else { 2 + (k % 3) })
            .collect();
        let basis_sizes: Vec<usize> = (0..n_atoms).map(|k| 1 + (k % 3)).collect();
        // Each atom occupies `basis_size * rank` consecutive border slots.
        let mut border_offsets = Vec::with_capacity(n_atoms);
        let mut acc = 0usize;
        for k in 0..n_atoms {
            border_offsets.push(acc);
            acc += basis_sizes[k] * ranks[k];
        }
        let border_dim = acc;

        let mut state = 0x0bad_c0de_dead_beefu64;
        // LCG draw in [-1, 1) from the top 31 bits of the state.
        let mut sample = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
        };

        // `p x rank` frames: identity for full-rank atoms (no draws consumed),
        // random otherwise.
        let mut frames: Vec<Array2<f64>> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let r = ranks[k];
            let mut u = Array2::<f64>::zeros((p, r));
            for i in 0..p {
                for j in 0..r {
                    u[[i, j]] = if r == p {
                        if i == j { 1.0 } else { 0.0 }
                    } else {
                        sample()
                    };
                }
            }
            frames.push(u);
        }
        // Frame cross-Gram `W_ij = U_iᵀ U_j`.
        let w_of = |i: usize, j: usize| -> Array2<f64> {
            let (ui, uj) = (&frames[i], &frames[j]);
            let (ri, rj) = (ranks[i], ranks[j]);
            let mut w = Array2::<f64>::zeros((ri, rj));
            for a in 0..ri {
                for b in 0..rj {
                    let mut s = 0.0;
                    for c in 0..p {
                        s += ui[[c, a]] * uj[[c, b]];
                    }
                    w[[a, b]] = s;
                }
            }
            w
        };

        // Diagonal pairs plus both orientations of each neighbour pair,
        // visited in sorted (atom_i, atom_j) order.
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        for k in 0..n_atoms {
            pairs.push((k, k));
        }
        for k in 0..n_atoms - 1 {
            pairs.push((k, k + 1));
            pairs.push((k + 1, k));
        }
        pairs.sort_unstable();
        let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::new();
        for &(i, j) in &pairs {
            let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
            let mut g = Array2::<f64>::zeros((mi, mj));
            for r in 0..mi {
                for c in 0..mj {
                    g[[r, c]] = 0.3 * sample();
                }
            }
            if i == j {
                for r in 0..mi.min(mj) {
                    g[[r, r]] += mi as f64 + 2.0;
                }
            }
            frame_blocks.push(FactoredFrameGBlock {
                atom_i: i,
                atom_j: j,
                g,
                w: w_of(i, j),
            });
        }

        // Per-atom smoothing factors `AᵀA + I` (SPD).
        let mut smooth_blocks: Vec<DeviceSaeSmoothBlock> = Vec::with_capacity(n_atoms);
        let mut smooth_ranks: Vec<usize> = Vec::with_capacity(n_atoms);
        for k in 0..n_atoms {
            let m = basis_sizes[k];
            let mut a = Array2::<f64>::zeros((m, m));
            for r in 0..m {
                for c in 0..m {
                    a[[r, c]] = 0.2 * sample();
                }
            }
            let mut s = a.t().dot(&a);
            for r in 0..m {
                s[[r, r]] += 1.0;
            }
            smooth_blocks.push(DeviceSaeSmoothBlock {
                global_offset: border_offsets[k],
                factor_a: s,
            });
            smooth_ranks.push(ranks[k]);
        }

        // Arrow rows: SPD `H_tt = AᵀA + (q + 1)·I`, plus a random coupling slab
        // mirrored into `row_htbeta` (row-major `q x border_dim`).
        let n = 8usize;
        let q = 3usize;
        let mut sys = ArrowSchurSystem::new(n, q, border_dim);
        let mut row_htbeta: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut a = Array2::<f64>::zeros((q, q));
            for r in 0..q {
                for c in 0..q {
                    a[[r, c]] = sample();
                }
            }
            let mut htt = a.t().dot(&a);
            for r in 0..q {
                htt[[r, r]] += q as f64 + 1.0;
            }
            sys.rows[i].htt = htt;
            let mut slab = vec![0.0_f64; q * border_dim];
            for c in 0..q {
                for col in 0..border_dim {
                    let v = 0.15 * sample();
                    slab[c * border_dim + col] = v;
                    sys.rows[i].htbeta[[c, col]] = v;
                }
            }
            row_htbeta.push(slab);
        }

        // Dense `H_ββ` = frame Kronecker operator + every per-atom smoothing
        // penalty, each densified over the full border.
        // (`to_dense` presumably comes from the `BetaPenaltyOp` trait — verify.)
        let data_op =
            FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
                .expect("frame op");
        let mut hbb = data_op.to_dense();
        for k in 0..n_atoms {
            let op = IdentityRightKroneckerPenaltyOp {
                factor_a: smooth_blocks[k].factor_a.clone(),
                p: ranks[k],
                global_offset: border_offsets[k],
                k: border_dim,
            };
            let d = op.to_dense();
            for r in 0..border_dim {
                for c in 0..border_dim {
                    hbb[[r, c]] += d[[r, c]];
                }
            }
        }
        sys.hbb = hbb;

        let data = DeviceSaePcgData {
            p,
            beta_dim: border_dim,
            a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
            smooth_blocks,
            sparse_g_blocks: Vec::new(),
            frame: Some(DeviceSaeFrameData {
                ranks: ranks.clone(),
                basis_sizes: basis_sizes.clone(),
                border_offsets: border_offsets.clone(),
                frame_blocks,
                smooth_ranks,
                row_htbeta,
            }),
        };

        let ridge_t = 1e-7;
        let ridge_beta = 1e-6;

        // Dense reference Schur complement:
        // S = H_ββ + ridge_beta·I − Σ_rows H_tβᵀ (H_tt + ridge_t·I)⁻¹ H_tβ.
        let mut s_dense = sys.hbb.clone();
        for r in 0..border_dim {
            s_dense[[r, r]] += ridge_beta;
        }
        for row in &sys.rows {
            let mut htt = row.htt.clone();
            for d in 0..q {
                htt[[d, d]] += ridge_t;
            }
            let factor = cholesky_factor_in_place(htt.view(), CholeskyGuard::NonnegativePivot)
                .expect("htt PD");
            // y = (H_tt + ridge_t·I)⁻¹ H_tβ, solved one border column at a time.
            let mut y = Array2::<f64>::zeros((q, border_dim));
            for col in 0..border_dim {
                let mut e = Array1::<f64>::zeros(q);
                for r in 0..q {
                    e[r] = row.htbeta[[r, col]];
                }
                let solved = cholesky_solve_vector(factor.view(), e.view());
                for r in 0..q {
                    y[[r, col]] = solved[r];
                }
            }
            for r in 0..border_dim {
                for c in 0..border_dim {
                    let mut acc = 0.0;
                    for d in 0..q {
                        acc += row.htbeta[[d, r]] * y[[d, c]];
                    }
                    s_dense[[r, c]] -= acc;
                }
            }
        }

        // Compare on four deterministic probe vectors.
        let mut max_rel = 0.0_f64;
        for trial in 0..4 {
            let x: Vec<f64> = (0..border_dim)
                .map(|a| 0.3 * ((a as f64 + trial as f64) * 0.21).cos() - 0.1)
                .collect();
            let mut got = vec![0.0_f64; border_dim];
            sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, &x, &mut got)
                .expect("framed matvec");
            let mut want = vec![0.0_f64; border_dim];
            for r in 0..border_dim {
                let mut acc = 0.0;
                for c in 0..border_dim {
                    acc += s_dense[[r, c]] * x[c];
                }
                want[r] = acc;
            }
            let scale = want.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
            for a in 0..border_dim {
                let rel = (got[a] - want[a]).abs() / scale;
                max_rel = max_rel.max(rel);
            }
        }
        assert!(
            max_rel <= 1e-10,
            "large-K framed SAE Schur matvec vs dense reference diverged: \
             max_rel={max_rel:e} (n_atoms={n_atoms}, border_dim={border_dim})"
        );
    }
6031
6032 #[test]
6038 fn framed_sae_device_pcg_matches_cpu_when_cuda_admits() {
6039 use crate::arrow_schur::{
6040 BetaPenaltyOp, DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock,
6041 FactoredFrameGBlock, FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp,
6042 };
6043
6044 let p = 6usize;
6048 let n_atoms = 8usize;
6049 let ranks: Vec<usize> = (0..n_atoms)
6050 .map(|k| if k % 2 == 0 { 3usize } else { p })
6051 .collect();
6052 let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6053 let mut border_offsets = Vec::with_capacity(n_atoms);
6054 let mut acc = 0usize;
6055 for k in 0..n_atoms {
6056 border_offsets.push(acc);
6057 acc += basis_sizes[k] * ranks[k];
6058 }
6059 let border_dim = acc; let mut state = 0xfeed_face_dead_beefu64;
6062 let mut sample = || -> f64 {
6063 state = state
6064 .wrapping_mul(6364136223846793005)
6065 .wrapping_add(1442695040888963407);
6066 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6067 };
6068 let mut frames: Vec<Array2<f64>> = Vec::new();
6069 for k in 0..n_atoms {
6070 let r = ranks[k];
6071 let mut u = Array2::<f64>::zeros((p, r));
6072 for i in 0..p {
6073 for j in 0..r {
6074 u[[i, j]] = if r == p && i == j {
6075 1.0
6076 } else if r == p {
6077 0.0
6078 } else {
6079 sample()
6080 };
6081 }
6082 }
6083 frames.push(u);
6084 }
6085 let w_of = |i: usize, j: usize| {
6086 let (ui, uj) = (&frames[i], &frames[j]);
6087 let (ri, rj) = (ranks[i], ranks[j]);
6088 let mut w = Array2::<f64>::zeros((ri, rj));
6089 for a in 0..ri {
6090 for b in 0..rj {
6091 let mut s = 0.0;
6092 for c in 0..p {
6093 s += ui[[c, a]] * uj[[c, b]];
6094 }
6095 w[[a, b]] = s;
6096 }
6097 }
6098 w
6099 };
6100 let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6101 for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6103 pairs.push((i, j));
6104 pairs.push((j, i));
6105 }
6106 let mut frame_blocks = Vec::new();
6107 for &(i, j) in &pairs {
6108 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6109 let mut g = Array2::<f64>::zeros((mi, mj));
6110 for r in 0..mi {
6111 for c in 0..mj {
6112 g[[r, c]] = 0.25 * sample();
6113 }
6114 }
6115 if i == j {
6116 for r in 0..mi.min(mj) {
6117 g[[r, r]] += mi as f64 + 2.0;
6118 }
6119 }
6120 frame_blocks.push(FactoredFrameGBlock {
6121 atom_i: i,
6122 atom_j: j,
6123 g,
6124 w: w_of(i, j),
6125 });
6126 }
6127 let mut smooth_blocks = Vec::new();
6128 let mut smooth_ranks = Vec::new();
6129 for k in 0..n_atoms {
6130 let m = basis_sizes[k];
6131 let mut a = Array2::<f64>::zeros((m, m));
6132 for r in 0..m {
6133 for c in 0..m {
6134 a[[r, c]] = 0.2 * sample();
6135 }
6136 }
6137 let mut s = a.t().dot(&a);
6138 for r in 0..m {
6139 s[[r, r]] += 1.0;
6140 }
6141 smooth_blocks.push(DeviceSaeSmoothBlock {
6142 global_offset: border_offsets[k],
6143 factor_a: s,
6144 });
6145 smooth_ranks.push(ranks[k]);
6146 }
6147 let n = 400usize;
6148 let q = 4usize;
6149 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6150 let mut row_htbeta = Vec::new();
6151 for i in 0..n {
6152 let mut a = Array2::<f64>::zeros((q, q));
6153 for r in 0..q {
6154 for c in 0..q {
6155 a[[r, c]] = sample();
6156 }
6157 }
6158 let mut htt = a.t().dot(&a);
6159 for r in 0..q {
6160 htt[[r, r]] += q as f64 + 1.0;
6161 }
6162 sys.rows[i].htt = htt;
6163 let mut slab = vec![0.0_f64; q * border_dim];
6164 for c in 0..q {
6165 for col in 0..border_dim {
6166 let v = 0.02 * sample();
6169 slab[c * border_dim + col] = v;
6170 sys.rows[i].htbeta[[c, col]] = v;
6171 }
6172 }
6173 row_htbeta.push(slab);
6174 }
6175 let data_op =
6176 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks.clone())
6177 .expect("frame op");
6178 let mut hbb = data_op.to_dense();
6179 for k in 0..n_atoms {
6180 let op = IdentityRightKroneckerPenaltyOp {
6181 factor_a: smooth_blocks[k].factor_a.clone(),
6182 p: ranks[k],
6183 global_offset: border_offsets[k],
6184 k: border_dim,
6185 };
6186 let d = op.to_dense();
6187 for r in 0..border_dim {
6188 for c in 0..border_dim {
6189 hbb[[r, c]] += d[[r, c]];
6190 }
6191 }
6192 }
6193 sys.hbb = hbb;
6194 let data = DeviceSaePcgData {
6195 p,
6196 beta_dim: border_dim,
6197 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6198 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6199 smooth_blocks,
6200 sparse_g_blocks: Vec::new(),
6201 frame: Some(DeviceSaeFrameData {
6202 ranks: ranks.clone(),
6203 basis_sizes: basis_sizes.clone(),
6204 border_offsets: border_offsets.clone(),
6205 frame_blocks,
6206 smooth_ranks,
6207 row_htbeta,
6208 }),
6209 };
6210 let ridge_t = 1e-7;
6211 let ridge_beta = 1e-6;
6212 let rhs: Array1<f64> =
6213 Array1::from_shape_fn(border_dim, |a| ((a as f64 + 1.0) * 0.17).sin());
6214
6215 let (device, diag) =
6216 match solve_sae_matrix_free_pcg(&sys, &data, ridge_t, ridge_beta, &rhs, 400, 1e-12) {
6217 Ok(result) => result,
6218 Err(failure) => {
6224 assert!(
6225 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6226 "#1017: CUDA device present but the framed device SAE PCG \
6227 declined/faulted instead of returning a result (tag: {failure:?}) — \
6228 the kernel does not run correctly on GPU"
6229 );
6230 return;
6231 }
6232 };
6233
6234 let rhs_norm = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
6252 let oracle_resid = |x: &Array1<f64>| -> f64 {
6253 let mut sx = vec![0.0_f64; border_dim];
6254 sae_framed_schur_matvec_cpu(
6255 &sys,
6256 &data,
6257 ridge_t,
6258 ridge_beta,
6259 x.as_slice().unwrap(),
6260 &mut sx,
6261 )
6262 .expect("cpu oracle matvec");
6263 let mut acc = 0.0_f64;
6264 for a in 0..border_dim {
6265 let e = sx[a] - rhs[a];
6266 acc += e * e;
6267 }
6268 acc.sqrt()
6269 };
6270 let s_dev_resid = oracle_resid(&device);
6271 let dev_rel_resid = s_dev_resid / rhs_norm.max(1e-300);
6272
6273 let precond = {
6278 let d = sae_frame_penalty_diag_host_for_test(&data, ridge_beta);
6279 let mut diag = d;
6282 for (i, row) in sys.rows.iter().enumerate() {
6283 let slab = &data.frame.as_ref().unwrap().row_htbeta[i];
6284 let qi = sys.row_dims[i];
6285 if slab.is_empty() || qi == 0 || slab.len() != qi * border_dim {
6286 continue;
6287 }
6288 let mut block = row.htt.clone();
6289 for dd in 0..qi {
6290 block[[dd, dd]] += ridge_t;
6291 }
6292 let factor = cholesky_factor_in_place(block.view(), CholeskyGuard::NonnegativePivot)
6293 .expect("row htt PD");
6294 let mut ainv = Array2::<f64>::zeros((qi, qi));
6296 for col in 0..qi {
6297 let mut e = Array1::<f64>::zeros(qi);
6298 e[col] = 1.0;
6299 let s = cholesky_solve_vector(factor.view(), e.view());
6300 for r in 0..qi {
6301 ainv[[r, col]] = s[r];
6302 }
6303 }
6304 for a in 0..border_dim {
6305 let mut quad = 0.0_f64;
6306 for c in 0..qi {
6307 let hc = slab[c * border_dim + a];
6308 for dd in 0..qi {
6309 quad += hc * ainv[[c, dd]] * slab[dd * border_dim + a];
6310 }
6311 }
6312 diag[a] -= quad;
6313 }
6314 }
6315 Array1::from_vec(diag)
6316 };
6317 let mut cpu = Array1::<f64>::zeros(border_dim);
6318 let cpu_result = {
6319 let mut apply = |v: &Array1<f64>, out: &mut Array1<f64>| {
6320 let mut tmp = vec![0.0_f64; border_dim];
6321 sae_framed_schur_matvec_cpu(
6322 &sys,
6323 &data,
6324 ridge_t,
6325 ridge_beta,
6326 v.as_slice().unwrap(),
6327 &mut tmp,
6328 )
6329 .expect("cpu oracle matvec");
6330 out.assign(&Array1::from_vec(tmp));
6331 };
6332 gam_linalg::pcg::pcg_core(
6333 &mut apply,
6334 &rhs.view(),
6335 &precond.view(),
6336 1e-12,
6337 800,
6338 32,
6339 false,
6340 gam_linalg::pcg::DotReduction::Serial,
6341 &mut cpu.view_mut(),
6342 )
6343 };
6344 let s_cpu_resid = oracle_resid(&cpu);
6345 let cpu_rel_resid = s_cpu_resid / rhs_norm.max(1e-300);
6346
6347 assert!(
6350 dev_rel_resid <= 1e-7,
6351 "[#1551] device δβ does not solve the CPU-oracle system: \
6352 ‖S_cpu·device−rhs‖/‖rhs‖={dev_rel_resid:e} (>1e-7) | abs={s_dev_resid:e} | \
6353 device PCG stop={:?} iters={} final_rel_resid={:e} — a large operator residual \
6354 means the device matvec is a DIFFERENT operator (kernel bug)",
6355 diag.stopping_reason,
6356 diag.iterations,
6357 diag.final_relative_residual,
6358 );
6359 assert!(
6362 cpu_rel_resid <= 1e-6,
6363 "[#1551] CPU pcg_core failed to solve the oracle system: \
6364 ‖S_cpu·cpu−rhs‖/‖rhs‖={cpu_rel_resid:e} (stop={:?}, iters={}) — fixture/oracle issue",
6365 cpu_result.stop,
6366 cpu_result.iterations,
6367 );
6368 }
6369
6370 fn sae_frame_penalty_diag_host_for_test(
6374 data: &DeviceSaePcgData,
6375 ridge_beta: f64,
6376 ) -> Vec<f64> {
6377 let frame = data.frame.as_ref().expect("frame");
6378 let mut diag = vec![ridge_beta; data.beta_dim];
6379 for (blk, &r) in data.smooth_blocks.iter().zip(frame.smooth_ranks.iter()) {
6380 let m = blk.factor_a.nrows();
6381 for ia in 0..m {
6382 let coeff = blk.factor_a[[ia, ia]];
6383 let base = blk.global_offset + ia * r;
6384 for ib in 0..r {
6385 diag[base + ib] += coeff;
6386 }
6387 }
6388 }
6389 for blk in &frame.frame_blocks {
6390 if blk.atom_i != blk.atom_j {
6391 continue;
6392 }
6393 let r = frame.ranks[blk.atom_i];
6394 let off = frame.border_offsets[blk.atom_i];
6395 let (mi, mj) = blk.g.dim();
6396 for li in 0..mi.min(mj) {
6397 let gii = blk.g[[li, li]];
6398 let base = off + li * r;
6399 for a in 0..r {
6400 diag[base + a] += gii * blk.w[[a, a]];
6401 }
6402 }
6403 }
6404 diag
6405 }
6406
6407 #[test]
6418 fn framed_sae_device_matvec_matches_cpu_oracle_when_cuda_admits() {
6419 use crate::arrow_schur::{
6420 DeviceSaeFrameData, DeviceSaePcgData, DeviceSaeSmoothBlock, FactoredFrameGBlock,
6421 };
6422
6423 let p = 6usize;
6426 let n_atoms = 8usize;
6427 let ranks: Vec<usize> = (0..n_atoms)
6428 .map(|k| if k % 2 == 0 { 3usize } else { p })
6429 .collect();
6430 let basis_sizes: Vec<usize> = (0..n_atoms).map(|_| 3usize).collect();
6431 let mut border_offsets = Vec::with_capacity(n_atoms);
6432 let mut acc = 0usize;
6433 for k in 0..n_atoms {
6434 border_offsets.push(acc);
6435 acc += basis_sizes[k] * ranks[k];
6436 }
6437 let border_dim = acc;
6438
6439 let mut state = 0x1551_0017_1026_0922u64;
6440 let mut sample = || -> f64 {
6441 state = state
6442 .wrapping_mul(6364136223846793005)
6443 .wrapping_add(1442695040888963407);
6444 ((state >> 33) as f64) / ((1u64 << 31) as f64) - 1.0
6445 };
6446 let mut frames: Vec<Array2<f64>> = Vec::new();
6447 for k in 0..n_atoms {
6448 let r = ranks[k];
6449 let mut u = Array2::<f64>::zeros((p, r));
6450 for i in 0..p {
6451 for j in 0..r {
6452 u[[i, j]] = if r == p && i == j {
6453 1.0
6454 } else if r == p {
6455 0.0
6456 } else {
6457 sample()
6458 };
6459 }
6460 }
6461 frames.push(u);
6462 }
6463 let w_of = |i: usize, j: usize| {
6464 let (ui, uj) = (&frames[i], &frames[j]);
6465 let (ri, rj) = (ranks[i], ranks[j]);
6466 let mut w = Array2::<f64>::zeros((ri, rj));
6467 for a in 0..ri {
6468 for b in 0..rj {
6469 let mut s = 0.0;
6470 for c in 0..p {
6471 s += ui[[c, a]] * uj[[c, b]];
6472 }
6473 w[[a, b]] = s;
6474 }
6475 }
6476 w
6477 };
6478 let mut pairs: Vec<(usize, usize)> = (0..n_atoms).map(|k| (k, k)).collect();
6479 for &(i, j) in &[(0usize, 1usize), (2, 4), (3, 6)] {
6480 pairs.push((i, j));
6481 pairs.push((j, i));
6482 }
6483 let mut frame_blocks = Vec::new();
6484 for &(i, j) in &pairs {
6485 let (mi, mj) = (basis_sizes[i], basis_sizes[j]);
6486 let mut g = Array2::<f64>::zeros((mi, mj));
6487 for r in 0..mi {
6488 for c in 0..mj {
6489 g[[r, c]] = 0.25 * sample();
6490 }
6491 }
6492 if i == j {
6493 for r in 0..mi.min(mj) {
6494 g[[r, r]] += mi as f64 + 2.0;
6495 }
6496 }
6497 frame_blocks.push(FactoredFrameGBlock {
6498 atom_i: i,
6499 atom_j: j,
6500 g,
6501 w: w_of(i, j),
6502 });
6503 }
6504 let mut smooth_blocks = Vec::new();
6505 let mut smooth_ranks = Vec::new();
6506 for k in 0..n_atoms {
6507 let m = basis_sizes[k];
6508 let mut a = Array2::<f64>::zeros((m, m));
6509 for r in 0..m {
6510 for c in 0..m {
6511 a[[r, c]] = 0.2 * sample();
6512 }
6513 }
6514 let mut s = a.t().dot(&a);
6515 for r in 0..m {
6516 s[[r, r]] += 1.0;
6517 }
6518 smooth_blocks.push(DeviceSaeSmoothBlock {
6519 global_offset: border_offsets[k],
6520 factor_a: s,
6521 });
6522 smooth_ranks.push(ranks[k]);
6523 }
6524 let n = 32usize;
6528 let q = 4usize;
6529 let mut sys = ArrowSchurSystem::new(n, q, border_dim);
6530 let mut row_htbeta = Vec::new();
6531 for i in 0..n {
6532 let mut a = Array2::<f64>::zeros((q, q));
6533 for r in 0..q {
6534 for c in 0..q {
6535 a[[r, c]] = sample();
6536 }
6537 }
6538 let mut htt = a.t().dot(&a);
6539 for r in 0..q {
6540 htt[[r, r]] += q as f64 + 1.0;
6541 }
6542 sys.rows[i].htt = htt;
6543 let mut slab = vec![0.0_f64; q * border_dim];
6544 for c in 0..q {
6545 for col in 0..border_dim {
6546 let v = 0.3 * sample();
6547 slab[c * border_dim + col] = v;
6548 sys.rows[i].htbeta[[c, col]] = v;
6549 }
6550 }
6551 row_htbeta.push(slab);
6552 }
6553 let ridge_t = 1e-7;
6554 let ridge_beta = 1e-6;
6555 let data = DeviceSaePcgData {
6556 p,
6557 beta_dim: border_dim,
6558 a_phi: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6559 local_jac: std::sync::Arc::from(Vec::new().into_boxed_slice()),
6560 smooth_blocks,
6561 sparse_g_blocks: Vec::new(),
6562 frame: Some(DeviceSaeFrameData {
6563 ranks: ranks.clone(),
6564 basis_sizes: basis_sizes.clone(),
6565 border_offsets: border_offsets.clone(),
6566 frame_blocks,
6567 smooth_ranks,
6568 row_htbeta,
6569 }),
6570 };
6571
6572 let mut probes: Vec<Array1<f64>> = Vec::new();
6576 probes.push(Array1::from_shape_fn(border_dim, |a| {
6577 ((a as f64 + 1.0) * 0.37).sin()
6578 }));
6579 probes.push(Array1::from_shape_fn(border_dim, |_| sample()));
6580 for axis in [0usize, border_dim / 3, border_dim - 1] {
6581 let mut e = Array1::<f64>::zeros(border_dim);
6582 e[axis] = 1.0;
6583 probes.push(e);
6584 }
6585
6586 let mut any_ran = false;
6587 let mut worst = 0.0_f64;
6588 for (pi, x) in probes.iter().enumerate() {
6589 let device = match super::framed_schur_matvec_once_on_device(
6590 &sys, &data, ridge_t, ridge_beta, x,
6591 ) {
6592 Ok(out) => out,
6593 Err(failure) => {
6594 assert!(
6598 gam_gpu::device_runtime::GpuRuntime::global().is_none(),
6599 "#1551: CUDA device present but the framed device matvec \
6600 declined/faulted (probe {pi}, tag: {failure:?}) — the kernel \
6601 does not run on GPU"
6602 );
6603 return;
6604 }
6605 };
6606 any_ran = true;
6607 let mut cpu = vec![0.0_f64; border_dim];
6608 sae_framed_schur_matvec_cpu(&sys, &data, ridge_t, ridge_beta, x.as_slice().unwrap(), &mut cpu)
6609 .expect("cpu oracle matvec");
6610 let scale = cpu.iter().fold(0.0_f64, |m, v| m.max(v.abs())).max(1.0);
6611 for a in 0..border_dim {
6612 let rel = (device[a] - cpu[a]).abs() / scale;
6613 worst = worst.max(rel);
6614 assert!(
6615 rel <= 1e-9,
6616 "[#1551 matvec-parity] probe {pi} component {a}: device={:e} cpu={:e} \
6617 rel={rel:e} (>1e-9) — framed S·x kernel diverges from the CPU oracle",
6618 device[a],
6619 cpu[a],
6620 );
6621 }
6622 }
6623 if any_ran {
6624 assert!(
6629 gam_gpu::device_runtime::GpuRuntime::global().is_some(),
6630 "#1551: matvec ran but no GPU runtime — unexpected"
6631 );
6632 assert!(worst <= 1e-9, "framed matvec parity worst rel = {worst:e}");
6633 }
6634 }
6635}