1use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Family of truncated spectral kernels on the sphere.
///
/// The variant selects which coefficient generator is used by
/// `SphereSpectralKernelKind::coefficients` and which label is returned by
/// `SphereSpectralKernelKind::tag`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
    /// Coefficients come from `crate::basis::sobolev_s2_truncated_coefficients`.
    Sobolev,
    /// Coefficients come from `crate::basis::pseudo_s2_truncated_coefficients`.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48 pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52 match self {
53 SphereSpectralKernelKind::Sobolev => {
54 crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55 }
56 SphereSpectralKernelKind::Pseudo => {
57 crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58 }
59 }
60 }
61
62 pub const fn tag(self) -> &'static str {
64 match self {
65 SphereSpectralKernelKind::Sobolev => "sobolev",
66 SphereSpectralKernelKind::Pseudo => "pseudo",
67 }
68 }
69}
70
/// Memory layout of a matrix held in device memory.
///
/// Only one layout exists at present; the enum is part of the module cache
/// key (`S2ModuleCacheKey`) and of `S2KernelBuildInputs`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Column-major storage: element `(i, j)` lives at `j * ld + i`.
    ColumnMajor,
}
78
79pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86 if latlon.ncols() != 2 {
87 return Err(format!(
88 "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89 latlon.shape()
90 ));
91 }
92 let deg = if radians {
93 1.0
94 } else {
95 std::f64::consts::PI / 180.0
96 };
97 let n = latlon.nrows();
98 let mut out = Vec::with_capacity(3 * n);
99 for row in latlon.outer_iter() {
100 let lat = row[0] * deg;
101 let lon = row[1] * deg;
102 let (s_lat, c_lat) = lat.sin_cos();
103 let (s_lon, c_lon) = lon.sin_cos();
104 out.push(c_lat * c_lon);
106 out.push(c_lat * s_lon);
107 out.push(s_lat);
108 }
109 Ok(out)
110}
111
/// A column-major kernel matrix resident in device memory (Linux build).
///
/// The buffer is addressed as `col_major_dev[j * ld + i]` for row `i` and
/// column `j`; the builders in this file allocate `ld * cols` elements with
/// `ld` equal to `rows` rounded up to a multiple of 32.
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension (column stride in elements); expected to be `>= rows`.
    pub ld: usize,
    /// Device buffer holding the matrix in column-major order.
    pub col_major_dev: CudaSlice<f64>,
    /// Stream used for copies back to the host and for synchronisation.
    pub stream: Arc<CudaStream>,
}
126
/// Host-backed stand-in for the device matrix on non-Linux builds.
///
/// It has the same shape fields as the Linux variant, but the column-major
/// data lives in an ordinary `Vec<f64>` and there is no stream.
#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension (column stride in elements); expected to be `>= rows`.
    pub ld: usize,
    /// Column-major data; `copy_to_host_col_major` expects `ld * cols` elements.
    pub col_major_dev: Vec<f64>,
}
135
136impl DeviceS2KernelMatrix {
137 #[cfg(target_os = "linux")]
153 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
154 let needed = self.ld * self.cols;
155 let mut staging = PinnedLease::acquire(self.stream.context(), needed)?;
156 self.stream
157 .memcpy_dtoh(&self.col_major_dev, staging.as_mut_slice())
158 .gpu_ctx("DeviceS2KernelMatrix dtoh (pinned)")?;
159 self.stream
160 .synchronize()
161 .gpu_ctx("DeviceS2KernelMatrix synchronize (pinned)")?;
162 Ok(col_major_to_row_major_parallel(
163 staging.as_slice(),
164 self.rows,
165 self.cols,
166 self.ld,
167 ))
168 }
169
170 #[cfg(not(target_os = "linux"))]
171 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
172 let mut col_major = vec![0.0_f64; self.ld * self.cols];
176 self.copy_to_host_col_major(&mut col_major)?;
177 Ok(col_major_to_row_major_parallel(
178 &col_major, self.rows, self.cols, self.ld,
179 ))
180 }
181
182 #[cfg(target_os = "linux")]
187 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
188 let needed = self.ld * self.cols;
189 if dst.len() != needed {
190 gam_gpu::gpu_bail!(
191 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
192 dst.len(),
193 needed
194 );
195 }
196 self.stream
197 .memcpy_dtoh(&self.col_major_dev, dst)
198 .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
199 self.stream
200 .synchronize()
201 .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
202 Ok(())
203 }
204
205 #[cfg(not(target_os = "linux"))]
206 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
207 let needed = self.ld * self.cols;
208 if dst.len() != needed {
209 gam_gpu::gpu_bail!(
210 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
211 dst.len(),
212 needed
213 );
214 }
215 dst.copy_from_slice(&self.col_major_dev);
216 Ok(())
217 }
218}
219
220fn col_major_to_row_major_parallel(
236 col_major: &[f64],
237 rows: usize,
238 cols: usize,
239 ld: usize,
240) -> Array2<f64> {
241 use rayon::prelude::*;
242
243 assert!(ld >= rows, "ld {ld} must be >= rows {rows}");
244 assert!(
245 col_major.len() >= ld * cols,
246 "col_major len {} < ld*cols {}",
247 col_major.len(),
248 ld * cols
249 );
250
251 const BLOCK_ROWS: usize = 128;
254
255 let mut out_flat = vec![0.0_f64; rows * cols];
256 out_flat
257 .par_chunks_mut(BLOCK_ROWS * cols)
258 .enumerate()
259 .for_each(|(block_idx, out_block)| {
260 let r0 = block_idx * BLOCK_ROWS;
261 let block_rows = out_block.len() / cols;
262 for j in 0..cols {
263 let base = j * ld + r0;
264 let src_col = &col_major[base..base + block_rows];
265 for (local_i, &v) in src_col.iter().enumerate() {
267 out_block[local_i * cols + j] = v;
268 }
269 }
270 });
271
272 Array2::from_shape_vec((rows, cols), out_flat).expect("row-major buffer has rows*cols elements")
273}
274
/// Owning handle to a host buffer of `f64` obtained from
/// `cudarc::driver::result::malloc_host`.
///
/// The buffer is released with `free_host` when the handle is dropped.
#[cfg(target_os = "linux")]
struct PinnedF64 {
    /// Start of the allocation; checked to be non-null in `alloc`.
    ptr: *mut f64,
    /// Length of the allocation in `f64` elements.
    len: usize,
    /// Set once the allocation has been released, so `drop` frees only once.
    freed: bool,
}
291
292#[cfg(target_os = "linux")]
293impl PinnedF64 {
294 fn alloc(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
297 ctx.bind_to_thread().gpu_ctx("PinnedF64 bind_to_thread")?;
298 let bytes = len
299 .checked_mul(std::mem::size_of::<f64>())
300 .ok_or_else(|| gam_gpu::gpu_err!("PinnedF64: len={len} byte size overflows usize"))?;
301 let raw = unsafe { cudarc::driver::result::malloc_host(bytes, 0) }
306 .gpu_ctx("PinnedF64 cuMemHostAlloc")?;
307 let ptr = raw as *mut f64;
308 if ptr.is_null() {
309 gam_gpu::gpu_bail!("PinnedF64: cuMemHostAlloc returned null for {bytes} bytes");
310 }
311 Ok(Self {
312 ptr,
313 len,
314 freed: false,
315 })
316 }
317
318 fn as_mut_slice(&mut self) -> &mut [f64] {
319 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
322 }
323
324 fn as_slice(&self) -> &[f64] {
325 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
327 }
328}
329
330#[cfg(target_os = "linux")]
331impl Drop for PinnedF64 {
332 fn drop(&mut self) {
333 if self.freed {
334 return;
335 }
336 self.freed = true;
337 unsafe { cudarc::driver::result::free_host(self.ptr as *mut std::ffi::c_void) }.ok();
342 }
343}
344
// `PinnedF64` holds a raw pointer, so `Send` is not derived automatically and
// is asserted here; this is what lets buffers sit in the global `PINNED_POOL`.
// NOTE(review): assumes the host allocation may be used and freed from a
// thread other than the one that created it — confirm against the CUDA docs.
#[cfg(target_os = "linux")]
unsafe impl Send for PinnedF64 {}
352
/// Upper bound on the number of idle staging buffers kept in `PINNED_POOL`;
/// enforced when a `PinnedLease` is dropped.
#[cfg(target_os = "linux")]
const PINNED_POOL_MAX_BUFFERS: usize = 4;

/// Process-wide pool of idle pinned staging buffers, created lazily on the
/// first `PinnedLease::acquire`.
#[cfg(target_os = "linux")]
static PINNED_POOL: OnceLock<Mutex<Vec<PinnedF64>>> = OnceLock::new();
366
/// Temporary ownership of one pinned staging buffer.
///
/// The buffer is offered back to `PINNED_POOL` when the lease is dropped.
#[cfg(target_os = "linux")]
struct PinnedLease {
    /// The leased buffer; `Some` until `drop` takes it out.
    buf: Option<PinnedF64>,
}
374
375#[cfg(target_os = "linux")]
376impl PinnedLease {
377 fn acquire(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
380 let pool = PINNED_POOL.get_or_init(|| Mutex::new(Vec::new()));
381 if let Ok(mut guard) = pool.lock() {
382 if let Some(pos) = guard.iter().position(|b| b.len == len) {
383 return Ok(Self {
384 buf: Some(guard.swap_remove(pos)),
385 });
386 }
387 }
388 Ok(Self {
389 buf: Some(PinnedF64::alloc(ctx, len)?),
390 })
391 }
392
393 fn as_mut_slice(&mut self) -> &mut [f64] {
394 self.buf
395 .as_mut()
396 .expect("PinnedLease buffer present until drop")
397 .as_mut_slice()
398 }
399
400 fn as_slice(&self) -> &[f64] {
401 self.buf
402 .as_ref()
403 .expect("PinnedLease buffer present until drop")
404 .as_slice()
405 }
406}
407
408#[cfg(target_os = "linux")]
409impl Drop for PinnedLease {
410 fn drop(&mut self) {
411 let Some(buf) = self.buf.take() else {
412 return;
413 };
414 if let Some(pool) = PINNED_POOL.get() {
415 if let Ok(mut guard) = pool.lock() {
416 if guard.len() < PINNED_POOL_MAX_BUFFERS {
417 guard.push(buf);
418 return;
419 }
420 guard.remove(0);
424 guard.push(buf);
425 return;
426 }
427 }
428 drop(buf);
430 }
431}
432
/// Borrowed inputs for building a sphere kernel matrix on the device.
///
/// The length relations noted on the fields are the ones enforced by
/// `S2KernelBuildInputs::validate`.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
    /// Number of data points (rows of the kernel matrix).
    pub n: usize,
    /// Number of centers (columns of the kernel matrix).
    pub m: usize,
    /// Truncation degree of the Legendre expansion; must be `>= 1`.
    pub lmax: usize,
    /// Flat data coordinates, three values per point (`3 * n` entries).
    pub data_xyz: &'a [f64],
    /// Flat center coordinates, three values per center (`3 * m` entries).
    pub centers_xyz: &'a [f64],
    /// Expansion coefficients, `lmax + 1` entries with `coeffs[0] == 0`.
    pub coeffs: &'a [f64],
    /// Kernel family; part of the compiled-module cache key.
    pub kind: SphereSpectralKernelKind,
    /// Output layout; part of the compiled-module cache key.
    pub layout: DeviceMatrixLayout,
}
454
455impl<'a> S2KernelBuildInputs<'a> {
456 fn validate(&self) -> Result<(), GpuError> {
457 if self.lmax == 0 {
458 return Err(GpuError::DriverCallFailed {
459 reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
460 });
461 }
462 if self.data_xyz.len() != 3 * self.n {
463 gam_gpu::gpu_bail!(
464 "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
465 self.data_xyz.len(),
466 3 * self.n
467 );
468 }
469 if self.centers_xyz.len() != 3 * self.m {
470 gam_gpu::gpu_bail!(
471 "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
472 self.centers_xyz.len(),
473 3 * self.m
474 );
475 }
476 if self.coeffs.len() != self.lmax + 1 {
477 gam_gpu::gpu_bail!(
478 "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
479 self.coeffs.len(),
480 self.lmax + 1
481 );
482 }
483 if self.coeffs[0] != 0.0 {
484 return Err(GpuError::DriverCallFailed {
485 reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
486 });
487 }
488 Ok(())
489 }
490}
491
/// CUDA C source for the two sphere kernels, compiled at run time.
///
/// The host prepends `#define LMAX <lmax>` before compilation. The source
/// defines:
///
/// * `s2_wahba_legendre_colmajor` — one thread per `(i, j)` entry of the
///   `n x m` kernel matrix, written column-major with leading dimension `ld`.
/// * `s2_wahba_householder_constrained_colmajor` — one thread per row, which
///   writes the `m - 1` columns of the Householder-constrained design.
///
/// Coordinate offsets are computed as `3LL * i` / `3LL * j` so the product is
/// formed in 64-bit arithmetic; a plain `int` product would overflow for
/// indices above roughly 715 million, which the host-side `i32` range checks
/// still admit.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,     // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz,  // m × 3 (row-major flat)
    const double* __restrict__ coeffs,       // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                 // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
    // Offsets use 64-bit arithmetic so 3 * index cannot overflow `int`.
    const double xi = data_xyz[3LL * i + 0];
    const double yi = data_xyz[3LL * i + 1];
    const double zi = data_xyz[3LL * i + 2];
    const double cxj = centers_xyz[3LL * j + 0];
    const double cyj = centers_xyz[3LL * j + 1];
    const double czj = centers_xyz[3LL * j + 2];

    // t = clamp(x_i · z_j, -1, +1).
    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t > 1.0) t = 1.0;
    if (t < -1.0) t = -1.0;

    // Legendre 3-term recurrence in registers.
    // P_0(t) = 1, P_1(t) = t.
    double p_prev = 1.0;
    double p_curr = t;
    double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }

    out[(long long) j * ld + (long long) i] = acc;
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,     // n × 3
    const double* __restrict__ centers_xyz,  // m × 3
    const double* __restrict__ coeffs,       // length LMAX + 1
    const double* __restrict__ v,            // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                 // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    // Offsets use 64-bit arithmetic so 3 * index cannot overflow `int`.
    const double xi = data_xyz[3LL * i + 0];
    const double yi = data_xyz[3LL * i + 1];
    const double zi = data_xyz[3LL * i + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double cxj = centers_xyz[3LL * j + 0];
        const double cyj = centers_xyz[3LL * j + 1];
        const double czj = centers_xyz[3LL * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t > 1.0) t = 1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        d_i = fma(v[j], acc, d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double cxj = centers_xyz[3LL * j + 0];
        const double cyj = centers_xyz[3LL * j + 1];
        const double czj = centers_xyz[3LL * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t > 1.0) t = 1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        const double xs = acc - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
639
/// Key under which compiled sphere modules are cached by
/// `SphereGpuBackend::module_for`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Major compute capability of the device.
    pub cc_major: i32,
    /// Minor compute capability of the device.
    pub cc_minor: i32,
    /// Truncation degree baked into the source as `#define LMAX`.
    pub lmax: u32,
    /// Kernel family the module is requested for.
    pub kind: SphereSpectralKernelKind,
    /// Output layout the module is requested for.
    pub layout: DeviceMatrixLayout,
}
657
/// Reports whether this build targets Linux, the only platform on which the
/// sphere GPU backend is compiled in.
pub const fn sphere_gpu_compiled() -> bool {
    #[cfg(target_os = "linux")]
    {
        true
    }
    #[cfg(not(target_os = "linux"))]
    {
        false
    }
}
663
664#[must_use]
671pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
672 let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
673 let ld = ((n + 31) / 32) * 32;
674 let needed_bytes = ld
675 .saturating_mul(m)
676 .saturating_mul(std::mem::size_of::<f64>());
677 let budget = runtime.memory_budget_bytes;
678 n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
679 } else {
680 false
681 };
682 decide(
683 GpuKernel::SpatialKernelOperator,
684 gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
685 )
686}
687
688#[must_use]
694pub fn truncated_device_kind(
695 kernel: crate::basis::SphereWahbaKernel,
696) -> Option<(SphereSpectralKernelKind, u16)> {
697 use crate::basis::SphereWahbaKernel;
698 match kernel {
699 SphereWahbaKernel::SobolevTruncated { lmax } => {
700 Some((SphereSpectralKernelKind::Sobolev, lmax))
701 }
702 SphereWahbaKernel::PseudoTruncated { lmax } => {
703 Some((SphereSpectralKernelKind::Pseudo, lmax))
704 }
705 SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
706 }
707}
708
709pub fn try_build_truncated_kernel_matrix_gpu(
731 data: ArrayView2<'_, f64>,
732 centers: ArrayView2<'_, f64>,
733 penalty_order: usize,
734 radians: bool,
735 kernel: crate::basis::SphereWahbaKernel,
736) -> Option<Result<Array2<f64>, GpuError>> {
737 let (kind, lmax) = truncated_device_kind(kernel)?;
738 let n = data.nrows();
739 let m = centers.nrows();
740 if n == 0 || m == 0 || lmax == 0 {
741 return None;
742 }
743 let decision = sphere_kernel_decision(n, m, lmax as usize);
744 if !decision.use_gpu {
745 return None;
748 }
749 Some(build_truncated_kernel_matrix_gpu_admitted(
751 data,
752 centers,
753 penalty_order,
754 radians,
755 kind,
756 lmax,
757 ))
758}
759
760fn build_truncated_kernel_matrix_gpu_admitted(
764 data: ArrayView2<'_, f64>,
765 centers: ArrayView2<'_, f64>,
766 penalty_order: usize,
767 radians: bool,
768 kind: SphereSpectralKernelKind,
769 lmax: u16,
770) -> Result<Array2<f64>, GpuError> {
771 let n = data.nrows();
772 let m = centers.nrows();
773 let data_xyz = latlon_to_xyz_host(data, radians)
774 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
775 let centers_xyz = latlon_to_xyz_host(centers, radians)
776 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
777 let coeffs = kind.coefficients(lmax as usize, penalty_order);
781 let inputs = S2KernelBuildInputs {
782 n,
783 m,
784 lmax: lmax as usize,
785 data_xyz: &data_xyz,
786 centers_xyz: ¢ers_xyz,
787 coeffs: &coeffs,
788 kind,
789 layout: DeviceMatrixLayout::ColumnMajor,
790 };
791 let device_matrix = build_kernel_matrix_device(inputs)?;
792 let out = device_matrix.to_host_array()?;
793 if !out.sum().is_finite() {
804 return Err(GpuError::DriverCallFailed {
805 reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
806 });
807 }
808 Ok(out)
809}
810
/// CUDA state owned by the sphere backend (Linux build only).
#[cfg(target_os = "linux")]
struct SphereGpuContext {
    /// Context used to load compiled modules.
    ctx: Arc<CudaContext>,
    /// Stream on which uploads, launches and synchronisation are issued.
    stream: Arc<CudaStream>,
    /// Compiled modules, keyed by `S2ModuleCacheKey`.
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Major compute capability reported by the backend probe.
    cc_major: i32,
    /// Minor compute capability reported by the backend probe.
    cc_minor: i32,
}
819
/// Handle to the sphere GPU backend.
///
/// Obtained through `SphereGpuBackend::probe`. On non-Linux builds the struct
/// has no fields and `probe` always fails.
pub struct SphereGpuBackend {
    /// CUDA context, stream and module cache (Linux only).
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
826
827impl SphereGpuBackend {
828 pub fn probe() -> Result<&'static Self, GpuError> {
830 static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
831 BACKEND
832 .get_or_init(|| {
833 #[cfg(target_os = "linux")]
834 {
835 Self::probe_linux()
836 }
837 #[cfg(not(target_os = "linux"))]
838 {
839 Err(GpuError::DriverLibraryUnavailable {
840 reason: "sphere GPU backend is Linux-only".to_string(),
841 })
842 }
843 })
844 .as_ref()
845 .map_err(GpuError::clone)
846 }
847
848 #[cfg(target_os = "linux")]
849 fn probe_linux() -> Result<Self, GpuError> {
850 let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
851 Ok(SphereGpuBackend {
852 inner: SphereGpuContext {
853 ctx: parts.ctx,
854 stream: parts.stream,
855 modules: Mutex::new(HashMap::new()),
856 cc_major: parts.capability.compute_major,
857 cc_minor: parts.capability.compute_minor,
858 },
859 })
860 }
861
862 #[cfg(target_os = "linux")]
865 fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
866 if let Ok(guard) = self.inner.modules.lock() {
867 if let Some(existing) = guard.get(&key) {
868 return Ok(existing.clone());
869 }
870 }
871 let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
881 let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).gpu_ctx_with(|err| {
882 format!(
883 "sphere NVRTC compile (kind={}, lmax={}): {err}",
884 key.kind.tag(),
885 key.lmax
886 )
887 })?;
888 let module = self
889 .inner
890 .ctx
891 .load_module(ptx)
892 .gpu_ctx("sphere module load")?;
893 if let Ok(mut guard) = self.inner.modules.lock() {
894 guard.entry(key).or_insert_with(|| module.clone());
895 }
896 Ok(module)
897 }
898
899 #[cfg(target_os = "linux")]
900 fn cc(&self) -> (i32, i32) {
901 (self.inner.cc_major, self.inner.cc_minor)
902 }
903}
904
905pub fn build_kernel_matrix_device(
912 inputs: S2KernelBuildInputs<'_>,
913) -> Result<DeviceS2KernelMatrix, GpuError> {
914 inputs.validate()?;
915
916 #[cfg(target_os = "linux")]
917 {
918 use cudarc::driver::{LaunchConfig, PushKernelArg};
919 let backend = SphereGpuBackend::probe()?;
920 let (cc_major, cc_minor) = backend.cc();
921 let key = S2ModuleCacheKey {
922 cc_major,
923 cc_minor,
924 lmax: inputs.lmax as u32,
925 kind: inputs.kind,
926 layout: inputs.layout,
927 };
928 let module = backend.module_for(key)?;
929 let func = module
930 .load_function("s2_wahba_legendre_colmajor")
931 .gpu_ctx("sphere load_function raw")?;
932 let stream = backend.inner.stream.clone();
933
934 let data_dev = stream
935 .clone_htod(inputs.data_xyz)
936 .gpu_ctx("sphere htod data_xyz")?;
937 let centers_dev = stream
938 .clone_htod(inputs.centers_xyz)
939 .gpu_ctx("sphere htod centers_xyz")?;
940 let coeffs_dev = stream
941 .clone_htod(inputs.coeffs)
942 .gpu_ctx("sphere htod coeffs")?;
943
944 let n = inputs.n;
945 let m = inputs.m;
946 let ld = ((n + 31) / 32) * 32;
947 let mut out_dev = stream
948 .alloc_zeros::<f64>(ld * m)
949 .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
950
951 let block_x: u32 = 32;
953 let block_y: u32 = 8;
954 let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
955 let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
956 let cfg = LaunchConfig {
957 grid_dim: (grid_x, grid_y, 1),
958 block_dim: (block_x, block_y, 1),
959 shared_mem_bytes: 0,
960 };
961 let n_i32: i32 =
962 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
963 let m_i32: i32 =
964 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
965 let ld_i64: i64 = ld as i64;
966
967 let mut builder = stream.launch_builder(&func);
968 builder
969 .arg(&data_dev)
970 .arg(¢ers_dev)
971 .arg(&coeffs_dev)
972 .arg(&n_i32)
973 .arg(&m_i32)
974 .arg(&ld_i64)
975 .arg(&mut out_dev);
976 unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
981 stream
982 .synchronize()
983 .gpu_ctx("sphere raw kernel synchronize")?;
984
985 Ok(DeviceS2KernelMatrix {
986 rows: n,
987 cols: m,
988 ld,
989 col_major_dev: out_dev,
990 stream,
991 })
992 }
993
994 #[cfg(not(target_os = "linux"))]
995 {
996 Err(GpuError::DriverLibraryUnavailable {
997 reason: "sphere GPU backend is Linux-only".to_string(),
998 })
999 }
1000}
1001
1002pub fn build_householder_constrained_design_device(
1006 inputs: S2KernelBuildInputs<'_>,
1007 v: &[f64],
1008 beta: f64,
1009) -> Result<DeviceS2KernelMatrix, GpuError> {
1010 inputs.validate()?;
1011 if v.len() != inputs.m {
1012 gam_gpu::gpu_bail!(
1013 "build_householder_constrained_design_device: v.len()={} != m={}",
1014 v.len(),
1015 inputs.m
1016 );
1017 }
1018 if inputs.m < 2 {
1019 gam_gpu::gpu_bail!(
1020 "build_householder_constrained_design_device: m must be >= 2 (got {})",
1021 inputs.m
1022 );
1023 }
1024 if !beta.is_finite() {
1025 gam_gpu::gpu_bail!(
1026 "build_householder_constrained_design_device: beta must be finite (got {beta})"
1027 );
1028 }
1029
1030 #[cfg(target_os = "linux")]
1031 {
1032 use cudarc::driver::{LaunchConfig, PushKernelArg};
1033 let backend = SphereGpuBackend::probe()?;
1034 let (cc_major, cc_minor) = backend.cc();
1035 let key = S2ModuleCacheKey {
1036 cc_major,
1037 cc_minor,
1038 lmax: inputs.lmax as u32,
1039 kind: inputs.kind,
1040 layout: inputs.layout,
1041 };
1042 let module = backend.module_for(key)?;
1043 let func = module
1044 .load_function("s2_wahba_householder_constrained_colmajor")
1045 .gpu_ctx("sphere load_function householder")?;
1046 let stream = backend.inner.stream.clone();
1047
1048 let data_dev = stream
1049 .clone_htod(inputs.data_xyz)
1050 .gpu_ctx("sphere-hh htod data_xyz")?;
1051 let centers_dev = stream
1052 .clone_htod(inputs.centers_xyz)
1053 .gpu_ctx("sphere-hh htod centers_xyz")?;
1054 let coeffs_dev = stream
1055 .clone_htod(inputs.coeffs)
1056 .gpu_ctx("sphere-hh htod coeffs")?;
1057 let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
1058
1059 let n = inputs.n;
1060 let m = inputs.m;
1061 let cols_out = m - 1;
1062 let ld_out = ((n + 31) / 32) * 32;
1063 let mut out_dev = stream
1064 .alloc_zeros::<f64>(ld_out * cols_out)
1065 .gpu_ctx_with(|err| {
1066 format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
1067 })?;
1068
1069 let block_x: u32 = 128;
1070 let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
1071 let cfg = LaunchConfig {
1072 grid_dim: (grid_x, 1, 1),
1073 block_dim: (block_x, 1, 1),
1074 shared_mem_bytes: 0,
1075 };
1076 let n_i32: i32 =
1077 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
1078 let m_i32: i32 =
1079 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
1080 let ld_out_i64: i64 = ld_out as i64;
1081
1082 let mut builder = stream.launch_builder(&func);
1083 builder
1084 .arg(&data_dev)
1085 .arg(¢ers_dev)
1086 .arg(&coeffs_dev)
1087 .arg(&v_dev)
1088 .arg(&beta)
1089 .arg(&n_i32)
1090 .arg(&m_i32)
1091 .arg(&ld_out_i64)
1092 .arg(&mut out_dev);
1093 unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
1096 stream
1097 .synchronize()
1098 .gpu_ctx("sphere-hh kernel synchronize")?;
1099
1100 Ok(DeviceS2KernelMatrix {
1101 rows: n,
1102 cols: cols_out,
1103 ld: ld_out,
1104 col_major_dev: out_dev,
1105 stream,
1106 })
1107 }
1108
1109 #[cfg(not(target_os = "linux"))]
1110 {
1111 Err(GpuError::DriverLibraryUnavailable {
1112 reason: "sphere GPU backend is Linux-only".to_string(),
1113 })
1114 }
1115}
1116
/// Builds the Householder reflector `H = I - beta * v * v^T` for the weight
/// vector `w`, normalised so that `v[0] == 1`.
///
/// The sign of the shift matches the sign of `w[0]`, so `v[0]` is formed
/// without cancellation.
///
/// The Euclidean norm is first computed with a plain sum of squares. If that
/// sum underflows to zero or overflows to infinity while the entries are
/// finite, the norm is recomputed after scaling by the largest magnitude, so
/// very small or very large weights do not collapse into a degenerate
/// reflector. Inputs whose plain norm is an ordinary finite number give
/// bit-for-bit the same result as the unscaled computation.
///
/// Returns `(v, beta)`:
///
/// * `(vec![], 0.0)` for empty input.
/// * `(vec![0.0; m], 0.0)` when every entry of `w` is zero.
///
/// NaN or infinite entries in `w` propagate into the outputs; callers that
/// need a finite `beta` must check it.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let m = w.len();
    if m == 0 {
        return (Vec::new(), 0.0);
    }

    let mut norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm == 0.0 || norm.is_infinite() {
        // The squares may have underflowed or overflowed. Rescale by the
        // largest magnitude (f64::max skips NaN) and try again.
        let scale = w.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
        if scale > 0.0 && scale.is_finite() {
            let scaled_ssq: f64 = w
                .iter()
                .map(|x| {
                    let r = x / scale;
                    r * r
                })
                .sum();
            norm = scale * scaled_ssq.sqrt();
        }
    }
    if norm == 0.0 {
        return (vec![0.0; m], 0.0);
    }

    let sigma = if w[0] >= 0.0 { norm } else { -norm };
    let mut v = w.to_vec();
    v[0] += sigma;
    let v0 = v[0];
    // Defensive: with sigma sharing the sign of w[0] this cannot be zero for
    // finite input, but a zero divisor must never reach the loop below.
    if v0 == 0.0 {
        return (vec![0.0; m], 0.0);
    }
    for entry in v.iter_mut() {
        *entry /= v0;
    }
    let vv: f64 = v.iter().map(|x| x * x).sum();
    let beta = 2.0 / vv;
    (v, beta)
}
1154
1155pub fn build_center_kernel_device(
1174 centers_xyz: &[f64],
1175 lmax: usize,
1176 coeffs: &[f64],
1177 kind: SphereSpectralKernelKind,
1178) -> Result<DeviceS2KernelMatrix, GpuError> {
1179 let m = centers_xyz.len() / 3;
1180 if centers_xyz.len() != 3 * m {
1181 return Err(GpuError::DriverCallFailed {
1182 reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
1183 });
1184 }
1185 let inputs = S2KernelBuildInputs {
1186 n: m,
1187 m,
1188 lmax,
1189 data_xyz: centers_xyz,
1190 centers_xyz,
1191 coeffs,
1192 kind,
1193 layout: DeviceMatrixLayout::ColumnMajor,
1194 };
1195 build_kernel_matrix_device(inputs)
1196}
1197
1198pub fn constrained_penalty_host(
1203 c: ArrayView2<'_, f64>,
1204 w: &[f64],
1205) -> Result<Array2<f64>, GpuError> {
1206 let (m1, m2) = c.dim();
1207 if m1 != m2 {
1208 gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
1209 }
1210 let m = m1;
1211 if w.len() != m {
1212 gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
1213 }
1214 if m < 2 {
1215 gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
1216 }
1217 let (v, beta) = householder_reflector_from_weights(w);
1218
1219 let mut u = vec![0.0_f64; m];
1222 for i in 0..m {
1223 let mut acc = 0.0_f64;
1224 for j in 0..m {
1225 acc += c[(i, j)] * v[j];
1226 }
1227 u[i] = acc;
1228 }
1229 let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
1230 let mut hch = Array2::<f64>::zeros((m, m));
1231 for i in 0..m {
1232 for j in 0..m {
1233 hch[(i, j)] =
1234 c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
1235 }
1236 }
1237 let mut s = Array2::<f64>::zeros((m - 1, m - 1));
1239 for i in 0..(m - 1) {
1240 for j in 0..(m - 1) {
1241 s[(i, j)] = hch[(i + 1, j + 1)];
1242 }
1243 }
1244 Ok(s)
1245}
1246
/// Result returned by `solve_penalised_ls_device`.
///
/// NOTE(review): the field descriptions below are inferred from the names and
/// from the `p == 0` early return of the solver — confirm against its full
/// implementation.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Presumably the solved coefficient vector; empty when there are no columns.
    pub beta: Vec<f64>,
    /// Presumably the weighted residual sum of squares; equals the sum of
    /// `wy[i]^2` when there are no columns.
    pub weighted_residual_ssq: f64,
    /// Presumably the log-determinant of the penalised Hessian; `0.0` when
    /// there are no columns.
    pub log_det_hessian: f64,
}
1286
1287#[cfg(target_os = "linux")]
1297pub fn solve_penalised_ls_device(
1298 x_s_device: &DeviceS2KernelMatrix,
1299 wy: &[f64],
1300 r_s: ArrayView2<'_, f64>,
1301) -> Result<PenalisedLsSolution, GpuError> {
1302 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1303 use cudarc::driver::DevicePtrMut;
1304
1305 let n = x_s_device.rows;
1306 let p = x_s_device.cols;
1307 if wy.len() != n {
1308 gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
1309 }
1310 if r_s.dim() != (p, p) {
1311 gam_gpu::gpu_bail!(
1312 "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
1313 r_s.dim()
1314 );
1315 }
1316 if p == 0 {
1317 return Ok(PenalisedLsSolution {
1318 beta: Vec::new(),
1319 weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
1320 log_det_hessian: 0.0,
1321 });
1322 }
1323
1324 let stream = x_s_device.stream.clone();
1325 let n_aug = n + p;
1326
1327 let mut a_aug_host = vec![0.0_f64; n_aug * p];
1332 let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
1334 x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
1335 for j in 0..p {
1336 let src_off = j * x_s_device.ld;
1337 let dst_off = j * n_aug;
1338 a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
1339 for i in 0..p {
1340 a_aug_host[dst_off + n + i] = r_s[(i, j)];
1343 }
1344 }
1345 let mut a_dev = stream
1346 .clone_htod(&a_aug_host)
1347 .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
1348
1349 let mut b_host = vec![0.0_f64; n_aug];
1351 b_host[..n].copy_from_slice(wy);
1352 let mut b_dev = stream
1353 .clone_htod(&b_host)
1354 .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
1355
1356 let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
1357 let n_aug_i: i32 = i32::try_from(n_aug)
1358 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
1359 let p_i: i32 = i32::try_from(p)
1360 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
1361
1362 let mut lwork: i32 = 0;
1364 {
1365 let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
1366 let status = unsafe {
1369 cusolver_sys::cusolverDnDgeqrf_bufferSize(
1370 solver.cu(),
1371 n_aug_i,
1372 p_i,
1373 a_ptr as *mut f64,
1374 n_aug_i,
1375 &mut lwork,
1376 )
1377 };
1378 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1379 gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1380 }
1381 }
1382 let lwork_us = usize::try_from(lwork)
1383 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1384 let mut workspace = stream
1385 .alloc_zeros::<f64>(lwork_us.max(1))
1386 .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1387 let mut tau = stream
1388 .alloc_zeros::<f64>(p)
1389 .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1390 let mut info = stream
1391 .alloc_zeros::<i32>(1)
1392 .gpu_ctx("solve_penalised_ls_device alloc info")?;
1393
1394 {
1396 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1397 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1398 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1399 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1400 let status = unsafe {
1403 cusolver_sys::cusolverDnDgeqrf(
1404 solver.cu(),
1405 n_aug_i,
1406 p_i,
1407 a_ptr as *mut f64,
1408 n_aug_i,
1409 tau_ptr as *mut f64,
1410 work_ptr as *mut f64,
1411 lwork,
1412 info_ptr as *mut i32,
1413 )
1414 };
1415 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1416 gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1417 }
1418 }
1419
1420 let mut ormqr_lwork: i32 = 0;
1422 {
1423 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1424 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1425 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1426 let status = unsafe {
1429 cusolver_sys::cusolverDnDormqr_bufferSize(
1430 solver.cu(),
1431 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1432 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1433 n_aug_i,
1434 1,
1435 p_i,
1436 a_ptr as *const f64,
1437 n_aug_i,
1438 tau_ptr as *const f64,
1439 b_ptr as *mut f64,
1440 n_aug_i,
1441 &mut ormqr_lwork,
1442 )
1443 };
1444 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1445 gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1446 }
1447 }
1448 if ormqr_lwork > lwork {
1449 workspace = stream
1450 .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1451 .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1452 }
1453 {
1454 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1455 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1456 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1457 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1458 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1459 let status = unsafe {
1463 cusolver_sys::cusolverDnDormqr(
1464 solver.cu(),
1465 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1466 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1467 n_aug_i,
1468 1,
1469 p_i,
1470 a_ptr as *const f64,
1471 n_aug_i,
1472 tau_ptr as *const f64,
1473 b_ptr as *mut f64,
1474 n_aug_i,
1475 work_ptr as *mut f64,
1476 ormqr_lwork.max(lwork),
1477 info_ptr as *mut i32,
1478 )
1479 };
1480 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1481 gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1482 }
1483 }
1484
1485 {
1488 use cudarc::cublas::CudaBlas;
1489 let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1490 let alpha = 1.0_f64;
1491 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1492 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1493 let handle = *blas.handle();
1498 let status = unsafe {
1499 cudarc::cublas::sys::cublasDtrsm_v2(
1500 handle,
1501 cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1502 cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1503 cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1504 cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1505 p_i,
1506 1,
1507 &alpha,
1508 a_ptr as *const f64,
1509 n_aug_i,
1510 b_ptr as *mut f64,
1511 n_aug_i,
1512 )
1513 };
1514 if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1515 gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1516 }
1517 }
1518
1519 let mut b_out = vec![0.0_f64; n_aug];
1521 stream
1522 .memcpy_dtoh(&b_dev, &mut b_out)
1523 .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1524 let mut a_back = vec![0.0_f64; n_aug * p];
1525 stream
1526 .memcpy_dtoh(&a_dev, &mut a_back)
1527 .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1528 stream
1529 .synchronize()
1530 .gpu_ctx("solve_penalised_ls_device synchronize")?;
1531
1532 let beta: Vec<f64> = b_out[..p].to_vec();
1533 let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1542
1543 let mut log_abs_r = 0.0_f64;
1545 for k in 0..p {
1546 let r_kk = a_back[k * n_aug + k];
1547 log_abs_r += r_kk.abs().ln();
1548 }
1549 let log_det_hessian = 2.0 * log_abs_r;
1550
1551 Ok(PenalisedLsSolution {
1552 beta,
1553 weighted_residual_ssq: augmented_residual_ssq,
1554 log_det_hessian,
1555 })
1556}
1557
/// Non-Linux fallback for the device penalised least-squares solve.
///
/// This variant never touches a device: it always returns
/// [`GpuError::DriverLibraryUnavailable`]. The reason string records the
/// design's row/column counts, the length of `wy`, and the dimensions of
/// `r_s` so the refused request can be identified from the error alone.
///
/// # Errors
///
/// Always returns `Err(GpuError::DriverLibraryUnavailable { .. })`.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let rhs_len = wy.len();
    let penalty_dim = r_s.dim();
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        x_s_device.rows, x_s_device.cols, rhs_len, penalty_dim
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1574
1575#[cfg(test)]
1580mod sphere_gpu_tests {
1581 use super::*;
1582 use crate::basis::{
1583 SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1584 spherical_wahba_kernel_matrix_with_kind,
1585 };
1586 use ndarray::Array2;
1587
1588 fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1589 let mut rows = Vec::with_capacity(n_lat * n_lon);
1591 for i in 0..n_lat {
1592 let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1593 for j in 0..n_lon {
1594 let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1595 rows.push(lat);
1596 rows.push(lon);
1597 }
1598 }
1599 Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1600 }
1601
1602 #[test]
1603 fn sum_finite_guard_accepts_finite_rejects_nonfinite() {
1604 let finite = Array2::<f64>::from_shape_fn((5, 7), |(i, j)| (i as f64 - 2.0) * (j as f64));
1609 assert!(finite.sum().is_finite());
1610
1611 let mut with_nan = finite.clone();
1612 with_nan[[3, 4]] = f64::NAN;
1613 assert!(!with_nan.sum().is_finite());
1614
1615 let mut with_pos_inf = finite.clone();
1616 with_pos_inf[[0, 0]] = f64::INFINITY;
1617 assert!(!with_pos_inf.sum().is_finite());
1618
1619 let mut with_neg_inf = finite.clone();
1620 with_neg_inf[[4, 6]] = f64::NEG_INFINITY;
1621 assert!(!with_neg_inf.sum().is_finite());
1622 }
1623
1624 #[test]
1625 fn xyz_preprocessing_matches_unit_sphere() {
1626 let latlon = ndarray::array![
1627 [0.0, 0.0],
1628 [90.0, 0.0],
1629 [0.0, 90.0],
1630 [-90.0, 17.5],
1631 [45.0, -120.0],
1632 ];
1633 let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1634 assert_eq!(xyz.len(), 3 * 5);
1635 for i in 0..5 {
1636 let nrm2 = xyz[3 * i] * xyz[3 * i]
1637 + xyz[3 * i + 1] * xyz[3 * i + 1]
1638 + xyz[3 * i + 2] * xyz[3 * i + 2];
1639 assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1640 }
1641 assert!((xyz[0] - 1.0).abs() < 1e-15);
1643 assert!(xyz[1].abs() < 1e-15);
1644 assert!(xyz[2].abs() < 1e-15);
1645 assert!(xyz[3].abs() < 1e-15);
1647 assert!(xyz[4].abs() < 1e-15);
1648 assert!((xyz[5] - 1.0).abs() < 1e-15);
1649 assert!(xyz[6].abs() < 1e-15);
1651 assert!((xyz[7] - 1.0).abs() < 1e-15);
1652 assert!(xyz[8].abs() < 1e-15);
1653 }
1654
1655 #[test]
1656 fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1657 for m_penalty in 1..=4 {
1661 for &lmax in &[5_usize, 20, 50] {
1662 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1663 let expected: f64 = coeffs.iter().sum();
1664 let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1665 assert!(
1666 (got - expected).abs() < 1e-13,
1667 "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1668 );
1669 }
1670 }
1671 }
1672
1673 #[test]
1674 fn truncated_spectral_at_antipode_matches_alternating_sum() {
1675 for m_penalty in 1..=4 {
1678 for &lmax in &[5_usize, 20, 50] {
1679 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1680 let expected: f64 = coeffs
1681 .iter()
1682 .enumerate()
1683 .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1684 .sum();
1685 let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1686 assert!(
1687 (got - expected).abs() < 1e-13,
1688 "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1689 );
1690 }
1691 }
1692 }
1693
1694 #[test]
1695 fn truncated_spectral_matrix_is_symmetric() {
1696 let centers = ndarray::array![
1700 [10.0_f64, 20.0],
1701 [-30.0, 100.0],
1702 [45.0, -60.0],
1703 [-89.0, 0.0],
1704 [0.0, 180.0],
1705 [60.0, -179.9],
1706 ];
1707 for m_penalty in [1usize, 2, 4] {
1708 for &lmax in &[10_usize, 30] {
1709 let mat = spherical_wahba_kernel_matrix_with_kind(
1710 centers.view(),
1711 centers.view(),
1712 m_penalty,
1713 false,
1714 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1715 )
1716 .expect("kernel matrix");
1717 let n = centers.nrows();
1718 let mut max_asym = 0.0_f64;
1719 for i in 0..n {
1720 for j in 0..n {
1721 let d = (mat[(i, j)] - mat[(j, i)]).abs();
1722 if d > max_asym {
1723 max_asym = d;
1724 }
1725 }
1726 }
1727 assert!(
1728 max_asym < 1e-13,
1729 "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1730 );
1731 }
1732 }
1733 }
1734
1735 #[test]
1736 fn truncated_coefficients_have_zero_constant_mode() {
1737 for m in 1..=4 {
1738 let c = sobolev_s2_truncated_coefficients(50, m);
1739 assert_eq!(c.len(), 51);
1740 assert_eq!(c[0], 0.0);
1741 assert!(c[1] > 0.0);
1742 for ell in 2..=50 {
1744 assert!(
1745 c[ell] < c[ell - 1] + 1e-15,
1746 "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1747 c[ell],
1748 c[ell - 1]
1749 );
1750 }
1751 }
1752 }
1753
1754 #[test]
1755 fn truncated_spectral_matches_matrix_helper() {
1756 let m_penalty = 2;
1760 let lmax = 20;
1761 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1762 let data = ndarray::array![[12.5, -34.0]];
1763 let centers = ndarray::array![[40.0, 10.0]];
1764 let mat = spherical_wahba_kernel_matrix_with_kind(
1765 data.view(),
1766 centers.view(),
1767 m_penalty,
1768 false,
1769 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1770 )
1771 .expect("kernel matrix");
1772 let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1774 let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1775 let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1776 let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1777 assert!(
1778 (mat[(0, 0)] - expected).abs() < 1e-13,
1779 "matrix helper differs from scalar evaluator: {} vs {}",
1780 mat[(0, 0)],
1781 expected
1782 );
1783 }
1784
1785 #[test]
1786 fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1787 let m = 6;
1792 let mut c = Array2::<f64>::zeros((m, m));
1793 for i in 0..m {
1794 for j in 0..m {
1795 let d = (i as f64 - j as f64).abs();
1796 c[(i, j)] = (-0.5 * d).exp();
1797 }
1798 }
1799 let w = vec![1.0_f64; m];
1800 let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1801 assert_eq!(s.dim(), (m - 1, m - 1));
1802 let mut max_asym = 0.0_f64;
1804 for i in 0..(m - 1) {
1805 for j in 0..(m - 1) {
1806 let d = (s[(i, j)] - s[(j, i)]).abs();
1807 if d > max_asym {
1808 max_asym = d;
1809 }
1810 }
1811 }
1812 assert!(
1813 max_asym < 1e-13,
1814 "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1815 );
1816
1817 let ones = ndarray::Array1::<f64>::ones(m - 1);
1825 let sx = s.dot(&ones);
1826 assert!(sx.iter().all(|v| v.is_finite()));
1827 }
1828
1829 #[test]
1830 fn householder_reflector_zeroes_target_vector() {
1831 let w = vec![3.0, 4.0, 0.0, -1.0];
1832 let (v, beta) = householder_reflector_from_weights(&w);
1833 let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1836 let hw: Vec<f64> = w
1837 .iter()
1838 .zip(&v)
1839 .map(|(wj, vj)| wj - beta * dot * vj)
1840 .collect();
1841 for entry in hw.iter().skip(1) {
1842 assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1843 }
1844 assert!(hw[0].abs() > 0.0);
1845 }
1846
1847 #[test]
1850 fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1851 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1852 eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1853 return;
1854 };
1855 SphereGpuBackend::probe()
1858 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1859
1860 let data_ll = small_latlon_grid(7, 9);
1861 let centers_ll = small_latlon_grid(5, 7);
1862 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1863 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1864 let n = data_ll.nrows();
1865 let m = centers_ll.nrows();
1866 let penalty = 2usize;
1867 let lmax = 20usize;
1868 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1869
1870 let inputs = S2KernelBuildInputs {
1871 n,
1872 m,
1873 lmax,
1874 data_xyz: &data_xyz,
1875 centers_xyz: ¢ers_xyz,
1876 coeffs: &coeffs,
1877 kind: SphereSpectralKernelKind::Sobolev,
1878 layout: DeviceMatrixLayout::ColumnMajor,
1879 };
1880 let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1881 let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1882
1883 let cpu = spherical_wahba_kernel_matrix_with_kind(
1884 data_ll.view(),
1885 centers_ll.view(),
1886 penalty,
1887 false,
1888 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1889 )
1890 .expect("cpu kernel matrix");
1891
1892 let mut max_abs = 0.0_f64;
1893 for i in 0..n {
1894 for j in 0..m {
1895 let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1896 if d > max_abs {
1897 max_abs = d;
1898 }
1899 }
1900 }
1901 assert!(
1902 max_abs < 1e-11,
1903 "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1904 );
1905 }
1906
1907 #[test]
1921 fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1922 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1923 eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1924 return;
1925 };
1926 SphereGpuBackend::probe()
1930 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1931 use crate::basis::{
1932 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1933 build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
1934 spherical_wahba_kernel_matrix_with_kind,
1935 };
1936
1937 let data = small_latlon_grid(100, 100);
1939 let lmax: u16 = 30;
1940 let penalty_order = 2usize;
1941 let centers =
1942 crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1943 .expect("centers");
1944 let n = data.nrows();
1945 let m = centers.nrows();
1946
1947 let decision = sphere_kernel_decision(n, m, lmax as usize);
1951 assert!(
1952 decision.use_gpu,
1953 "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
1954 (reason={}); the engagement gate regressed",
1955 decision.reason
1956 );
1957
1958 let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
1960 data.view(),
1961 centers.view(),
1962 penalty_order,
1963 false,
1964 SphereWahbaKernel::SobolevTruncated { lmax },
1965 )
1966 .expect("GPU-eligible production kernel build succeeds");
1967
1968 let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
1970 data.view(),
1971 centers.view(),
1972 penalty_order,
1973 false,
1974 SphereWahbaKernel::SobolevTruncated { lmax },
1975 )
1976 .expect("cpu oracle kernel build succeeds");
1977
1978 assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
1979 let mut max_abs = 0.0_f64;
1980 let mut max_rel = 0.0_f64;
1981 for (g, c) in gpu_kernel.iter().zip(cpu_kernel.iter()) {
1982 let d = (g - c).abs();
1983 if d > max_abs {
1984 max_abs = d;
1985 }
1986 let denom = g.abs().max(c.abs()).max(1e-300);
1987 let r = d / denom;
1988 if r > max_rel {
1989 max_rel = r;
1990 }
1991 }
1992 assert!(
1993 max_rel < 1e-9,
1994 "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
1995 >= 1e-9 (abs {max_abs:.3e})"
1996 );
1997
1998 let spec_gpu = SphericalSplineBasisSpec {
2002 center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
2003 penalty_order,
2004 double_penalty: false,
2005 radians: false,
2006 method: SphereMethod::Wahba,
2007 max_degree: None,
2008 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2009 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2010 };
2011 let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
2012 .expect("GPU-eligible build_spherical_spline_basis succeeds");
2013 let design = result_gpu.design.as_dense().expect("dense design");
2014 assert_eq!(design.nrows(), n, "design row count must match data rows");
2015 assert!(
2016 design.iter().all(|v| v.is_finite()),
2017 "engaged-device spherical design must be finite"
2018 );
2019 }
2020
2021 #[test]
2024 fn sphere_gpu_householder_parity_vs_raw_dot_z() {
2025 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2026 eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
2027 return;
2028 };
2029 SphereGpuBackend::probe()
2032 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
2033 let data_ll = small_latlon_grid(6, 8);
2034 let centers_ll = small_latlon_grid(4, 5);
2035 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2036 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2037 let n = data_ll.nrows();
2038 let m = centers_ll.nrows();
2039 let penalty = 2usize;
2040 let lmax = 15usize;
2041 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
2042
2043 let inputs_raw = S2KernelBuildInputs {
2045 n,
2046 m,
2047 lmax,
2048 data_xyz: &data_xyz,
2049 centers_xyz: ¢ers_xyz,
2050 coeffs: &coeffs,
2051 kind: SphereSpectralKernelKind::Sobolev,
2052 layout: DeviceMatrixLayout::ColumnMajor,
2053 };
2054 let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
2055 let b = b_dev.to_host_array().expect("dtoh raw");
2056
2057 let w = vec![1.0_f64; m];
2060 let (v, beta) = householder_reflector_from_weights(&w);
2061
2062 let mut xs_host = Array2::<f64>::zeros((n, m - 1));
2064 for i in 0..n {
2065 let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
2066 for j_out in 0..(m - 1) {
2067 xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
2068 }
2069 }
2070
2071 let xs_dev =
2072 build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
2073 let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
2074
2075 let mut max_abs = 0.0_f64;
2076 for i in 0..n {
2077 for j in 0..(m - 1) {
2078 let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
2079 if d > max_abs {
2080 max_abs = d;
2081 }
2082 }
2083 }
2084 assert!(
2085 max_abs < 1e-12,
2086 "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
2087 );
2088 }
2089
2090 #[test]
2094 fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
2095 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2096 eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
2097 return;
2098 };
2099 if SphereGpuBackend::probe().is_err() {
2100 eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
2101 return;
2102 }
2103
2104 let n_lat = 500usize;
2107 let n_lon = 400usize;
2108 assert_eq!(n_lat * n_lon, 200_000);
2109 let data_ll = small_latlon_grid(n_lat, n_lon);
2110 let m = 200usize;
2111 let centers_ll =
2112 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2113 .expect("centers");
2114 let n = data_ll.nrows();
2115 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2116 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2117 let penalty_order = 2usize;
2118 let lmax = 50usize;
2119 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
2120
2121 let inputs_warm = S2KernelBuildInputs {
2123 n,
2124 m,
2125 lmax,
2126 data_xyz: &data_xyz,
2127 centers_xyz: ¢ers_xyz,
2128 coeffs: &coeffs,
2129 kind: SphereSpectralKernelKind::Sobolev,
2130 layout: DeviceMatrixLayout::ColumnMajor,
2131 };
2132 {
2137 let warm = build_kernel_matrix_device(inputs_warm.clone()).expect("warmup");
2138 drop(warm.to_host_array().expect("warmup to_host"));
2139 }
2140
2141 let t0 = std::time::Instant::now();
2143 let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
2144 dev.to_host_array().expect("dtoh");
2145 let gpu_secs = t0.elapsed().as_secs_f64();
2146
2147 let t1 = std::time::Instant::now();
2154 crate::basis::spherical_wahba_kernel_matrix_cpu(
2155 data_ll.view(),
2156 centers_ll.view(),
2157 penalty_order,
2158 false,
2159 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
2160 )
2161 .expect("cpu kernel matrix");
2162 let cpu_secs = t1.elapsed().as_secs_f64();
2163
2164 let ratio = cpu_secs / gpu_secs.max(1e-9);
2165 eprintln!(
2166 "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
2167 );
2168 assert!(
2169 ratio >= 20.0,
2170 "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
2171 n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2172 );
2173 }
2174
2175 #[test]
2180 fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
2181 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2182 eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
2183 return;
2184 };
2185 if SphereGpuBackend::probe().is_err() {
2186 eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
2187 return;
2188 }
2189 use crate::basis::{
2190 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
2191 build_spherical_spline_basis,
2192 };
2193
2194 let n_lat = 500usize;
2195 let n_lon = 400usize;
2196 let data_ll = small_latlon_grid(n_lat, n_lon);
2197 let m: usize = 200;
2198 let lmax: u16 = 50;
2199 let spec_gpu = SphericalSplineBasisSpec {
2200 center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
2201 penalty_order: 2,
2202 double_penalty: false,
2203 radians: false,
2204 method: SphereMethod::Wahba,
2205 max_degree: None,
2206 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2207 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2208 };
2209
2210 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
2212
2213 let t0 = std::time::Instant::now();
2214 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
2215 let gpu_secs = t0.elapsed().as_secs_f64();
2216
2217 let centers =
2224 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2225 .expect("centers");
2226 let z = Array2::<f64>::eye(centers.nrows());
2227 let t1 = std::time::Instant::now();
2228 let raw_cpu = crate::basis::spherical_wahba_kernel_matrix_cpu(
2233 data_ll.view(),
2234 centers.view(),
2235 2,
2236 false,
2237 SphereWahbaKernel::SobolevTruncated { lmax },
2238 )
2239 .expect("cpu raw");
2240 raw_cpu.dot(&z);
2241 let cpu_secs = t1.elapsed().as_secs_f64();
2242
2243 let ratio = cpu_secs / gpu_secs.max(1e-9);
2244 eprintln!(
2245 "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
2246 data_ll.nrows()
2247 );
2248 assert!(
2249 ratio >= 10.0,
2250 "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
2251 cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2252 );
2253 }
2254
2255 #[test]
2277 fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
2278 use crate::basis::{
2279 select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
2280 };
2281 use faer::Side;
2282 use gam_linalg::faer_ndarray::FaerCholesky;
2283
2284 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2285 eprintln!(
2286 "[sphere gpu parity] no CUDA runtime — skipping device parity \
2287 (CPU oracle exercised by sibling tests)"
2288 );
2289 return;
2290 };
2291 SphereGpuBackend::probe()
2294 .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
2295
2296 let data_ll = small_latlon_grid(25, 40);
2298 assert_eq!(data_ll.nrows(), 1000);
2299 let n = data_ll.nrows();
2300 let m: usize = 80;
2301 let lmax_u16: u16 = 15;
2302 let lmax: usize = lmax_u16 as usize;
2303 let penalty_order: usize = 2;
2304 let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
2305 let lambda: f64 = 1.0e-3;
2306
2307 let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
2309 .expect("farthest-point centers");
2310 assert_eq!(centers_ll.nrows(), m);
2311
2312 let z = Array2::<f64>::eye(centers_ll.nrows());
2315 let p = z.ncols();
2316 assert_eq!(p, m);
2317
2318 let k_cc = spherical_wahba_kernel_matrix_with_kind(
2323 centers_ll.view(),
2324 centers_ll.view(),
2325 penalty_order,
2326 false,
2327 kernel,
2328 )
2329 .expect("centers×centers kernel");
2330 let s_full = z.t().dot(&k_cc).dot(&z);
2331
2332 let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
2334 data_ll.view(),
2335 centers_ll.view(),
2336 penalty_order,
2337 false,
2338 kernel,
2339 )
2340 .expect("CPU raw design");
2341 let x_s_cpu = raw_design_cpu.dot(&z);
2342
2343 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
2345 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
2346 let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
2347 let inputs = S2KernelBuildInputs {
2348 n,
2349 m,
2350 lmax,
2351 data_xyz: &data_xyz,
2352 centers_xyz: ¢ers_xyz,
2353 coeffs: &coeffs,
2354 kind: SphereSpectralKernelKind::Sobolev,
2355 layout: DeviceMatrixLayout::ColumnMajor,
2356 };
2357 let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
2358 let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
2359 let x_s_gpu = raw_design_gpu.dot(&z);
2360
2361 assert_eq!(x_s_cpu.dim(), (n, p));
2362 assert_eq!(x_s_gpu.dim(), (n, p));
2363
2364 let mut raw_xs_delta = 0.0_f64;
2374 let mut xs_scale = 0.0_f64;
2375 for (a, b) in x_s_cpu.iter().zip(x_s_gpu.iter()) {
2376 raw_xs_delta = raw_xs_delta.max((a - b).abs());
2377 xs_scale = xs_scale.max(a.abs());
2378 }
2379 let cond = {
2382 use gam_linalg::faer_ndarray::FaerEigh;
2383 let xtx = x_s_cpu.t().dot(&x_s_cpu);
2384 let mut a = xtx;
2385 for i in 0..p {
2386 for j in 0..p {
2387 a[(i, j)] += lambda * s_full[(i, j)];
2388 }
2389 }
2390 let (mut lo, mut hi) = (f64::INFINITY, 0.0_f64);
2391 if let Ok((vals, _)) = a.eigh(faer::Side::Lower) {
2392 for &v in vals.iter() {
2393 lo = lo.min(v);
2394 hi = hi.max(v);
2395 }
2396 }
2397 hi / lo.max(1e-300)
2398 };
2399 assert!(
2404 raw_xs_delta <= 1e-12 * xs_scale.max(1.0),
2405 "GPU vs CPU sphere design matrix max |Δ| = {raw_xs_delta:.3e} > {:.3e} \
2406 (scale {xs_scale:.3e}) — the kernel itself drifted (this is the genuine \
2407 GPU output, NOT a conditioning artifact)",
2408 1e-12 * xs_scale.max(1.0)
2409 );
2410
2411 let mut y = ndarray::Array1::<f64>::zeros(n);
2417 for i in 0..n {
2418 let lat_rad = data_ll[(i, 0)].to_radians();
2419 let lon_rad = data_ll[(i, 1)].to_radians();
2420 y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
2422 + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
2423 }
2424
2425 let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
2430 let xtx = x_s.t().dot(x_s);
2431 let mut a = xtx;
2432 for i in 0..p {
2433 for j in 0..p {
2434 a[(i, j)] += lambda * s_full[(i, j)];
2435 }
2436 }
2437 let rhs = x_s.t().dot(&y);
2438 let factor = a
2439 .cholesky(Side::Lower)
2440 .expect("penalised normal equations are SPD under λ > 0");
2441 factor.solvevec(&rhs)
2442 };
2443
2444 let beta_cpu = solve_penalised(&x_s_cpu);
2445 let beta_gpu = solve_penalised(&x_s_gpu);
2446 assert_eq!(beta_cpu.len(), p);
2447 assert_eq!(beta_gpu.len(), p);
2448
2449 let yhat_cpu = x_s_cpu.dot(&beta_cpu);
2453 let yhat_gpu = x_s_gpu.dot(&beta_gpu);
2454
2455 let mut max_beta_delta = 0.0_f64;
2456 for k in 0..p {
2457 let d = (beta_cpu[k] - beta_gpu[k]).abs();
2458 if d > max_beta_delta {
2459 max_beta_delta = d;
2460 }
2461 }
2462 let mut max_fit_delta = 0.0_f64;
2463 for i in 0..n {
2464 let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
2465 if d > max_fit_delta {
2466 max_fit_delta = d;
2467 }
2468 }
2469
2470 eprintln!(
2471 "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
2472 raw_xs|Δ|={raw_xs_delta:.3e} cond={cond:.3e} \
2473 max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
2474 );
2475
2476 assert!(
2483 max_fit_delta <= 1.0e-9,
2484 "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2485 );
2486
2487 let beta_tol = (1e-15 * cond * (1.0 + xs_scale)).max(1e-9) * 16.0;
2499 assert!(
2500 max_beta_delta <= beta_tol,
2501 "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > \
2502 condition-aware tol {beta_tol:.3e} (cond={cond:.3e}). Raw design parity is \
2503 {raw_xs_delta:.3e}; a drift THIS much larger than cond·ULP is a real solve/kernel \
2504 mismatch, not conditioning."
2505 );
2506 }
2507}