1use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Family of truncated spectral kernels on the sphere handled by this module.
///
/// The variant selects which coefficient generator `coefficients` calls and
/// which short label `tag` reports.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
    /// Coefficients come from `crate::basis::sobolev_s2_truncated_coefficients`.
    Sobolev,
    /// Coefficients come from `crate::basis::pseudo_s2_truncated_coefficients`.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48 pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52 match self {
53 SphereSpectralKernelKind::Sobolev => {
54 crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55 }
56 SphereSpectralKernelKind::Pseudo => {
57 crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58 }
59 }
60 }
61
62 pub const fn tag(self) -> &'static str {
64 match self {
65 SphereSpectralKernelKind::Sobolev => "sobolev",
66 SphereSpectralKernelKind::Pseudo => "pseudo",
67 }
68 }
69}
70
/// Memory layout of a matrix held on the device.
///
/// Only one layout exists today; the enum is part of `S2ModuleCacheKey` and
/// `S2KernelBuildInputs`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Column-major storage: element `(i, j)` lives at offset `j * ld + i`.
    ColumnMajor,
}
78
79pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86 if latlon.ncols() != 2 {
87 return Err(format!(
88 "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89 latlon.shape()
90 ));
91 }
92 let deg = if radians {
93 1.0
94 } else {
95 std::f64::consts::PI / 180.0
96 };
97 let n = latlon.nrows();
98 let mut out = Vec::with_capacity(3 * n);
99 for row in latlon.outer_iter() {
100 let lat = row[0] * deg;
101 let lon = row[1] * deg;
102 let (s_lat, c_lat) = lat.sin_cos();
103 let (s_lon, c_lon) = lon.sin_cos();
104 out.push(c_lat * c_lon);
106 out.push(c_lat * s_lon);
107 out.push(s_lat);
108 }
109 Ok(out)
110}
111
/// Column-major kernel matrix resident in device memory (Linux build).
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension: column `j` starts at offset `j * ld`; `ld >= rows`
    /// is asserted when the matrix is converted to row-major on the host.
    pub ld: usize,
    /// Device buffer holding the column-major data; downloads copy
    /// `ld * cols` elements from it.
    pub col_major_dev: CudaSlice<f64>,
    /// Stream used for the device-to-host copies and synchronisation.
    pub stream: Arc<CudaStream>,
}
126
/// Host-memory stand-in for the device matrix on non-Linux builds.
///
/// Keeps the same field names so the accessor methods share one interface.
#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension: column `j` starts at offset `j * ld`.
    pub ld: usize,
    /// Column-major data; accessors expect `ld * cols` elements.
    pub col_major_dev: Vec<f64>,
}
135
136impl DeviceS2KernelMatrix {
137 #[cfg(target_os = "linux")]
153 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
154 let needed = self.ld * self.cols;
155 let mut staging = PinnedLease::acquire(self.stream.context(), needed)?;
156 self.stream
157 .memcpy_dtoh(&self.col_major_dev, staging.as_mut_slice())
158 .gpu_ctx("DeviceS2KernelMatrix dtoh (pinned)")?;
159 self.stream
160 .synchronize()
161 .gpu_ctx("DeviceS2KernelMatrix synchronize (pinned)")?;
162 Ok(col_major_to_row_major_parallel(
163 staging.as_slice(),
164 self.rows,
165 self.cols,
166 self.ld,
167 ))
168 }
169
170 #[cfg(not(target_os = "linux"))]
171 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
172 let mut col_major = vec![0.0_f64; self.ld * self.cols];
176 self.copy_to_host_col_major(&mut col_major)?;
177 Ok(col_major_to_row_major_parallel(
178 &col_major, self.rows, self.cols, self.ld,
179 ))
180 }
181
182 #[cfg(target_os = "linux")]
187 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
188 let needed = self.ld * self.cols;
189 if dst.len() != needed {
190 gam_gpu::gpu_bail!(
191 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
192 dst.len(),
193 needed
194 );
195 }
196 self.stream
197 .memcpy_dtoh(&self.col_major_dev, dst)
198 .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
199 self.stream
200 .synchronize()
201 .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
202 Ok(())
203 }
204
205 #[cfg(not(target_os = "linux"))]
206 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
207 let needed = self.ld * self.cols;
208 if dst.len() != needed {
209 gam_gpu::gpu_bail!(
210 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
211 dst.len(),
212 needed
213 );
214 }
215 dst.copy_from_slice(&self.col_major_dev);
216 Ok(())
217 }
218}
219
220fn col_major_to_row_major_parallel(
236 col_major: &[f64],
237 rows: usize,
238 cols: usize,
239 ld: usize,
240) -> Array2<f64> {
241 use rayon::prelude::*;
242
243 assert!(ld >= rows, "ld {ld} must be >= rows {rows}");
244 assert!(
245 col_major.len() >= ld * cols,
246 "col_major len {} < ld*cols {}",
247 col_major.len(),
248 ld * cols
249 );
250
251 const BLOCK_ROWS: usize = 128;
254
255 let mut out_flat = vec![0.0_f64; rows * cols];
256 out_flat
257 .par_chunks_mut(BLOCK_ROWS * cols)
258 .enumerate()
259 .for_each(|(block_idx, out_block)| {
260 let r0 = block_idx * BLOCK_ROWS;
261 let block_rows = out_block.len() / cols;
262 for j in 0..cols {
263 let base = j * ld + r0;
264 let src_col = &col_major[base..base + block_rows];
265 for (local_i, &v) in src_col.iter().enumerate() {
267 out_block[local_i * cols + j] = v;
268 }
269 }
270 });
271
272 Array2::from_shape_vec((rows, cols), out_flat)
273 .expect("row-major buffer has rows*cols elements")
274}
275
/// Owning handle to a host buffer of `f64` obtained from
/// `cudarc::driver::result::malloc_host`.
#[cfg(target_os = "linux")]
struct PinnedF64 {
    /// Start of the allocation; checked non-null when the handle is built.
    ptr: *mut f64,
    /// Capacity in `f64` elements (not bytes).
    len: usize,
    /// Set by `Drop` so the allocation is released at most once.
    freed: bool,
}
292
#[cfg(target_os = "linux")]
impl PinnedF64 {
    /// Allocates room for `len` `f64` values via `malloc_host`.
    ///
    /// `ctx` is bound to the calling thread before the allocation call.
    ///
    /// # Errors
    ///
    /// Fails when binding the context fails, when `len * size_of::<f64>()`
    /// overflows `usize`, when the driver call fails, or when it hands back
    /// a null pointer.
    fn alloc(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
        ctx.bind_to_thread()
            .gpu_ctx("PinnedF64 bind_to_thread")?;
        let bytes = len
            .checked_mul(std::mem::size_of::<f64>())
            .ok_or_else(|| gam_gpu::gpu_err!("PinnedF64: len={len} byte size overflows usize"))?;
        // SAFETY (review): the context was bound to this thread just above and
        // `bytes` is overflow-checked. The second argument is the flags word —
        // 0 presumably selects the driver defaults; confirm against cudarc.
        let raw = unsafe { cudarc::driver::result::malloc_host(bytes, 0) }
            .gpu_ctx("PinnedF64 cuMemHostAlloc")?;
        let ptr = raw as *mut f64;
        if ptr.is_null() {
            gam_gpu::gpu_bail!("PinnedF64: cuMemHostAlloc returned null for {bytes} bytes");
        }
        Ok(Self {
            ptr,
            len,
            freed: false,
        })
    }

    /// Mutable view over all `len` elements of the buffer.
    fn as_mut_slice(&mut self) -> &mut [f64] {
        // SAFETY: `ptr` is non-null and was allocated with room for `len` f64
        // values in `alloc`; `&mut self` makes this the only live view.
        // NOTE(review): the memory is not initialised by `alloc`, so callers
        // are expected to write before reading — confirm.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Shared view over all `len` elements of the buffer.
    fn as_slice(&self) -> &[f64] {
        // SAFETY: same allocation invariant as `as_mut_slice`; `&self` rules
        // out a concurrent mutable view.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
}
331
#[cfg(target_os = "linux")]
impl Drop for PinnedF64 {
    /// Releases the host allocation exactly once.
    fn drop(&mut self) {
        // Guard against a second release of the same pointer.
        if self.freed {
            return;
        }
        self.freed = true;
        // SAFETY: `ptr` came from `malloc_host` in `alloc` and has not been
        // freed yet (`freed` was false). A failure to free is deliberately
        // discarded with `.ok()` because `drop` cannot report an error.
        unsafe { cudarc::driver::result::free_host(self.ptr as *mut std::ffi::c_void) }.ok();
    }
}
346
// SAFETY (review): `PinnedF64` is the sole owner of its allocation — the raw
// pointer is private and only exposed through slices borrowed from `self` —
// so moving the handle to another thread does not create aliasing. This
// presumes the driver allows the buffer to be used and freed from a thread
// other than the allocating one; confirm against the CUDA driver docs.
#[cfg(target_os = "linux")]
unsafe impl Send for PinnedF64 {}
354
/// Upper bound on the number of idle buffers kept in `PINNED_POOL`; when a
/// lease is returned to a full pool the oldest entry is evicted first.
#[cfg(target_os = "linux")]
const PINNED_POOL_MAX_BUFFERS: usize = 4;
365
/// Process-wide pool of idle pinned buffers, created lazily on the first
/// `PinnedLease::acquire` and reused by later acquisitions of the same length.
#[cfg(target_os = "linux")]
static PINNED_POOL: OnceLock<Mutex<Vec<PinnedF64>>> = OnceLock::new();
368
/// Temporary loan of a pinned buffer; dropping the lease hands the buffer
/// back to `PINNED_POOL`.
#[cfg(target_os = "linux")]
struct PinnedLease {
    /// The leased buffer; `Some` for the whole lifetime of the lease and
    /// taken out only inside `Drop`.
    buf: Option<PinnedF64>,
}
376
377#[cfg(target_os = "linux")]
378impl PinnedLease {
379 fn acquire(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
382 let pool = PINNED_POOL.get_or_init(|| Mutex::new(Vec::new()));
383 if let Ok(mut guard) = pool.lock() {
384 if let Some(pos) = guard.iter().position(|b| b.len == len) {
385 return Ok(Self {
386 buf: Some(guard.swap_remove(pos)),
387 });
388 }
389 }
390 Ok(Self {
391 buf: Some(PinnedF64::alloc(ctx, len)?),
392 })
393 }
394
395 fn as_mut_slice(&mut self) -> &mut [f64] {
396 self.buf
397 .as_mut()
398 .expect("PinnedLease buffer present until drop")
399 .as_mut_slice()
400 }
401
402 fn as_slice(&self) -> &[f64] {
403 self.buf
404 .as_ref()
405 .expect("PinnedLease buffer present until drop")
406 .as_slice()
407 }
408}
409
410#[cfg(target_os = "linux")]
411impl Drop for PinnedLease {
412 fn drop(&mut self) {
413 let Some(buf) = self.buf.take() else {
414 return;
415 };
416 if let Some(pool) = PINNED_POOL.get() {
417 if let Ok(mut guard) = pool.lock() {
418 if guard.len() < PINNED_POOL_MAX_BUFFERS {
419 guard.push(buf);
420 return;
421 }
422 guard.remove(0);
426 guard.push(buf);
427 return;
428 }
429 }
430 drop(buf);
432 }
433}
434
/// Borrowed inputs describing one kernel-matrix build.
///
/// `validate` checks the length relations documented on the fields.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
    /// Number of data points (rows of the kernel matrix).
    pub n: usize,
    /// Number of centers (columns of the kernel matrix).
    pub m: usize,
    /// Truncation degree; must be at least 1.
    pub lmax: usize,
    /// Flat data coordinates, three values per point (`len == 3 * n`).
    pub data_xyz: &'a [f64],
    /// Flat center coordinates, three values per center (`len == 3 * m`).
    pub centers_xyz: &'a [f64],
    /// Spectral coefficients, `len == lmax + 1`, with `coeffs[0] == 0`.
    pub coeffs: &'a [f64],
    /// Kernel family; part of the module cache key.
    pub kind: SphereSpectralKernelKind,
    /// Requested device layout; part of the module cache key.
    pub layout: DeviceMatrixLayout,
}
456
457impl<'a> S2KernelBuildInputs<'a> {
458 fn validate(&self) -> Result<(), GpuError> {
459 if self.lmax == 0 {
460 return Err(GpuError::DriverCallFailed {
461 reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
462 });
463 }
464 if self.data_xyz.len() != 3 * self.n {
465 gam_gpu::gpu_bail!(
466 "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
467 self.data_xyz.len(),
468 3 * self.n
469 );
470 }
471 if self.centers_xyz.len() != 3 * self.m {
472 gam_gpu::gpu_bail!(
473 "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
474 self.centers_xyz.len(),
475 3 * self.m
476 );
477 }
478 if self.coeffs.len() != self.lmax + 1 {
479 gam_gpu::gpu_bail!(
480 "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
481 self.coeffs.len(),
482 self.lmax + 1
483 );
484 }
485 if self.coeffs[0] != 0.0 {
486 return Err(GpuError::DriverCallFailed {
487 reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
488 });
489 }
490 Ok(())
491 }
492}
493
494#[cfg(target_os = "linux")]
504const KERNEL_TEMPLATE: &str = r#"
505// LMAX is supplied by the host via a `#define LMAX ...` prepended to
506// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
507extern "C" __global__
508__launch_bounds__(256)
509void s2_wahba_legendre_colmajor(
510 const double* __restrict__ data_xyz, // n × 3 (row-major flat)
511 const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
512 const double* __restrict__ coeffs, // length LMAX + 1, coeffs[0] = 0
513 int n,
514 int m,
515 long long ld,
516 double* __restrict__ out // ld × m column-major
517) {
518 const int i = blockIdx.y * blockDim.y + threadIdx.y;
519 const int j = blockIdx.x * blockDim.x + threadIdx.x;
520 if (i >= n || j >= m) return;
521
522 // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
523 const double xi = data_xyz[3 * i + 0];
524 const double yi = data_xyz[3 * i + 1];
525 const double zi = data_xyz[3 * i + 2];
526 const double cxj = centers_xyz[3 * j + 0];
527 const double cyj = centers_xyz[3 * j + 1];
528 const double czj = centers_xyz[3 * j + 2];
529
530 // t = clamp(x_i · z_j, -1, +1).
531 double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
532 if (t > 1.0) t = 1.0;
533 if (t < -1.0) t = -1.0;
534
535 // Legendre 3-term recurrence in registers.
536 // P_0(t) = 1, P_1(t) = t.
537 double p_prev = 1.0;
538 double p_curr = t;
539 double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
540
541 #pragma unroll 8
542 for (int ell = 1; ell < LMAX; ++ell) {
543 const double lf = (double) ell;
544 const double inv = 1.0 / (lf + 1.0);
545 // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
546 const double p_next =
547 fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
548 acc = fma(coeffs[ell + 1], p_next, acc);
549 p_prev = p_curr;
550 p_curr = p_next;
551 }
552
553 out[(long long) j * ld + (long long) i] = acc;
554}
555
556// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
557// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
558// i.e. drop the first column after applying Z. Each thread computes one
559// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
560// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
561// for j_out in 0..m-1.
562//
563// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
564// over centers in an inner loop — register-bound by the per-row state
565// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
566extern "C" __global__
567__launch_bounds__(128)
568void s2_wahba_householder_constrained_colmajor(
569 const double* __restrict__ data_xyz, // n × 3
570 const double* __restrict__ centers_xyz, // m × 3
571 const double* __restrict__ coeffs, // length LMAX + 1
572 const double* __restrict__ v, // length m, Householder vector
573 double beta,
574 int n,
575 int m,
576 long long ld_out,
577 double* __restrict__ out // ld_out × (m-1) column-major
578) {
579 const int i = blockIdx.x * blockDim.x + threadIdx.x;
580 if (i >= n) return;
581
582 const double xi = data_xyz[3 * i + 0];
583 const double yi = data_xyz[3 * i + 1];
584 const double zi = data_xyz[3 * i + 2];
585
586 // Pass 1: compute d_i = sum_j v[j] * B[i, j].
587 double d_i = 0.0;
588 for (int j = 0; j < m; ++j) {
589 const double cxj = centers_xyz[3 * j + 0];
590 const double cyj = centers_xyz[3 * j + 1];
591 const double czj = centers_xyz[3 * j + 2];
592 double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
593 if (t > 1.0) t = 1.0;
594 if (t < -1.0) t = -1.0;
595
596 double p_prev = 1.0;
597 double p_curr = t;
598 double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
599 #pragma unroll 8
600 for (int ell = 1; ell < LMAX; ++ell) {
601 const double lf = (double) ell;
602 const double inv = 1.0 / (lf + 1.0);
603 const double p_next =
604 fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
605 acc = fma(coeffs[ell + 1], p_next, acc);
606 p_prev = p_curr;
607 p_curr = p_next;
608 }
609 d_i = fma(v[j], acc, d_i);
610 }
611
612 // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
613 const double bd = beta * d_i;
614 for (int j_out = 0; j_out < m - 1; ++j_out) {
615 const int j = j_out + 1;
616 const double cxj = centers_xyz[3 * j + 0];
617 const double cyj = centers_xyz[3 * j + 1];
618 const double czj = centers_xyz[3 * j + 2];
619 double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
620 if (t > 1.0) t = 1.0;
621 if (t < -1.0) t = -1.0;
622
623 double p_prev = 1.0;
624 double p_curr = t;
625 double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
626 #pragma unroll 8
627 for (int ell = 1; ell < LMAX; ++ell) {
628 const double lf = (double) ell;
629 const double inv = 1.0 / (lf + 1.0);
630 const double p_next =
631 fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
632 acc = fma(coeffs[ell + 1], p_next, acc);
633 p_prev = p_curr;
634 p_curr = p_next;
635 }
636 const double xs = acc - bd * v[j];
637 out[(long long) j_out * ld_out + (long long) i] = xs;
638 }
639}
640"#;
641
/// Key of the compiled-module cache held by `SphereGpuBackend`.
///
/// Two builds share a compiled module only when every field matches.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Major compute capability of the probed device.
    pub cc_major: i32,
    /// Minor compute capability of the probed device.
    pub cc_minor: i32,
    /// Truncation degree baked into the kernel source as `LMAX`.
    pub lmax: u32,
    /// Kernel family of the build request.
    pub kind: SphereSpectralKernelKind,
    /// Output layout of the build request.
    pub layout: DeviceMatrixLayout,
}
659
/// Reports whether this build contains the GPU code path, i.e. whether it
/// was compiled for `target_os = "linux"`. This says nothing about whether a
/// usable device is present at runtime.
pub const fn sphere_gpu_compiled() -> bool {
    cfg!(target_os = "linux")
}
665
666#[must_use]
673pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
674 let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
675 let ld = ((n + 31) / 32) * 32;
676 let needed_bytes = ld
677 .saturating_mul(m)
678 .saturating_mul(std::mem::size_of::<f64>());
679 let budget = runtime.memory_budget_bytes;
680 n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
681 } else {
682 false
683 };
684 decide(
685 GpuKernel::SpatialKernelOperator,
686 gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
687 )
688}
689
690#[must_use]
696pub fn truncated_device_kind(
697 kernel: crate::basis::SphereWahbaKernel,
698) -> Option<(SphereSpectralKernelKind, u16)> {
699 use crate::basis::SphereWahbaKernel;
700 match kernel {
701 SphereWahbaKernel::SobolevTruncated { lmax } => {
702 Some((SphereSpectralKernelKind::Sobolev, lmax))
703 }
704 SphereWahbaKernel::PseudoTruncated { lmax } => {
705 Some((SphereSpectralKernelKind::Pseudo, lmax))
706 }
707 SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
708 }
709}
710
711pub fn try_build_truncated_kernel_matrix_gpu(
733 data: ArrayView2<'_, f64>,
734 centers: ArrayView2<'_, f64>,
735 penalty_order: usize,
736 radians: bool,
737 kernel: crate::basis::SphereWahbaKernel,
738) -> Option<Result<Array2<f64>, GpuError>> {
739 let (kind, lmax) = truncated_device_kind(kernel)?;
740 let n = data.nrows();
741 let m = centers.nrows();
742 if n == 0 || m == 0 || lmax == 0 {
743 return None;
744 }
745 let decision = sphere_kernel_decision(n, m, lmax as usize);
746 if !decision.use_gpu {
747 return None;
750 }
751 Some(build_truncated_kernel_matrix_gpu_admitted(
753 data,
754 centers,
755 penalty_order,
756 radians,
757 kind,
758 lmax,
759 ))
760}
761
762fn build_truncated_kernel_matrix_gpu_admitted(
766 data: ArrayView2<'_, f64>,
767 centers: ArrayView2<'_, f64>,
768 penalty_order: usize,
769 radians: bool,
770 kind: SphereSpectralKernelKind,
771 lmax: u16,
772) -> Result<Array2<f64>, GpuError> {
773 let n = data.nrows();
774 let m = centers.nrows();
775 let data_xyz = latlon_to_xyz_host(data, radians)
776 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
777 let centers_xyz = latlon_to_xyz_host(centers, radians)
778 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
779 let coeffs = kind.coefficients(lmax as usize, penalty_order);
783 let inputs = S2KernelBuildInputs {
784 n,
785 m,
786 lmax: lmax as usize,
787 data_xyz: &data_xyz,
788 centers_xyz: ¢ers_xyz,
789 coeffs: &coeffs,
790 kind,
791 layout: DeviceMatrixLayout::ColumnMajor,
792 };
793 let device_matrix = build_kernel_matrix_device(inputs)?;
794 let out = device_matrix.to_host_array()?;
795 if !out.sum().is_finite() {
806 return Err(GpuError::DriverCallFailed {
807 reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
808 });
809 }
810 Ok(out)
811}
812
/// Device state owned by the sphere backend (Linux build).
#[cfg(target_os = "linux")]
struct SphereGpuContext {
    /// CUDA context returned by the backend probe; modules are loaded into it.
    ctx: Arc<CudaContext>,
    /// Stream returned by the backend probe; used for transfers and launches.
    stream: Arc<CudaStream>,
    /// Compiled modules, keyed by `S2ModuleCacheKey`.
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Major compute capability reported by the probe.
    cc_major: i32,
    /// Minor compute capability reported by the probe.
    cc_minor: i32,
}
821
/// Handle to the sphere GPU backend, obtained through `SphereGpuBackend::probe`.
///
/// On non-Linux builds the struct has no fields and `probe` always fails.
pub struct SphereGpuBackend {
    /// Device context, stream and module cache (Linux only).
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
828
829impl SphereGpuBackend {
830 pub fn probe() -> Result<&'static Self, GpuError> {
832 static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
833 BACKEND
834 .get_or_init(|| {
835 #[cfg(target_os = "linux")]
836 {
837 Self::probe_linux()
838 }
839 #[cfg(not(target_os = "linux"))]
840 {
841 Err(GpuError::DriverLibraryUnavailable {
842 reason: "sphere GPU backend is Linux-only".to_string(),
843 })
844 }
845 })
846 .as_ref()
847 .map_err(GpuError::clone)
848 }
849
850 #[cfg(target_os = "linux")]
851 fn probe_linux() -> Result<Self, GpuError> {
852 let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
853 Ok(SphereGpuBackend {
854 inner: SphereGpuContext {
855 ctx: parts.ctx,
856 stream: parts.stream,
857 modules: Mutex::new(HashMap::new()),
858 cc_major: parts.capability.compute_major,
859 cc_minor: parts.capability.compute_minor,
860 },
861 })
862 }
863
864 #[cfg(target_os = "linux")]
867 fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
868 if let Ok(guard) = self.inner.modules.lock() {
869 if let Some(existing) = guard.get(&key) {
870 return Ok(existing.clone());
871 }
872 }
873 let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
883 let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).gpu_ctx_with(|err| {
884 format!(
885 "sphere NVRTC compile (kind={}, lmax={}): {err}",
886 key.kind.tag(),
887 key.lmax
888 )
889 })?;
890 let module = self
891 .inner
892 .ctx
893 .load_module(ptx)
894 .gpu_ctx("sphere module load")?;
895 if let Ok(mut guard) = self.inner.modules.lock() {
896 guard.entry(key).or_insert_with(|| module.clone());
897 }
898 Ok(module)
899 }
900
901 #[cfg(target_os = "linux")]
902 fn cc(&self) -> (i32, i32) {
903 (self.inner.cc_major, self.inner.cc_minor)
904 }
905}
906
907pub fn build_kernel_matrix_device(
914 inputs: S2KernelBuildInputs<'_>,
915) -> Result<DeviceS2KernelMatrix, GpuError> {
916 inputs.validate()?;
917
918 #[cfg(target_os = "linux")]
919 {
920 use cudarc::driver::{LaunchConfig, PushKernelArg};
921 let backend = SphereGpuBackend::probe()?;
922 let (cc_major, cc_minor) = backend.cc();
923 let key = S2ModuleCacheKey {
924 cc_major,
925 cc_minor,
926 lmax: inputs.lmax as u32,
927 kind: inputs.kind,
928 layout: inputs.layout,
929 };
930 let module = backend.module_for(key)?;
931 let func = module
932 .load_function("s2_wahba_legendre_colmajor")
933 .gpu_ctx("sphere load_function raw")?;
934 let stream = backend.inner.stream.clone();
935
936 let data_dev = stream
937 .clone_htod(inputs.data_xyz)
938 .gpu_ctx("sphere htod data_xyz")?;
939 let centers_dev = stream
940 .clone_htod(inputs.centers_xyz)
941 .gpu_ctx("sphere htod centers_xyz")?;
942 let coeffs_dev = stream
943 .clone_htod(inputs.coeffs)
944 .gpu_ctx("sphere htod coeffs")?;
945
946 let n = inputs.n;
947 let m = inputs.m;
948 let ld = ((n + 31) / 32) * 32;
949 let mut out_dev = stream
950 .alloc_zeros::<f64>(ld * m)
951 .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
952
953 let block_x: u32 = 32;
955 let block_y: u32 = 8;
956 let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
957 let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
958 let cfg = LaunchConfig {
959 grid_dim: (grid_x, grid_y, 1),
960 block_dim: (block_x, block_y, 1),
961 shared_mem_bytes: 0,
962 };
963 let n_i32: i32 =
964 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
965 let m_i32: i32 =
966 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
967 let ld_i64: i64 = ld as i64;
968
969 let mut builder = stream.launch_builder(&func);
970 builder
971 .arg(&data_dev)
972 .arg(¢ers_dev)
973 .arg(&coeffs_dev)
974 .arg(&n_i32)
975 .arg(&m_i32)
976 .arg(&ld_i64)
977 .arg(&mut out_dev);
978 unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
983 stream
984 .synchronize()
985 .gpu_ctx("sphere raw kernel synchronize")?;
986
987 Ok(DeviceS2KernelMatrix {
988 rows: n,
989 cols: m,
990 ld,
991 col_major_dev: out_dev,
992 stream,
993 })
994 }
995
996 #[cfg(not(target_os = "linux"))]
997 {
998 Err(GpuError::DriverLibraryUnavailable {
999 reason: "sphere GPU backend is Linux-only".to_string(),
1000 })
1001 }
1002}
1003
1004pub fn build_householder_constrained_design_device(
1008 inputs: S2KernelBuildInputs<'_>,
1009 v: &[f64],
1010 beta: f64,
1011) -> Result<DeviceS2KernelMatrix, GpuError> {
1012 inputs.validate()?;
1013 if v.len() != inputs.m {
1014 gam_gpu::gpu_bail!(
1015 "build_householder_constrained_design_device: v.len()={} != m={}",
1016 v.len(),
1017 inputs.m
1018 );
1019 }
1020 if inputs.m < 2 {
1021 gam_gpu::gpu_bail!(
1022 "build_householder_constrained_design_device: m must be >= 2 (got {})",
1023 inputs.m
1024 );
1025 }
1026 if !beta.is_finite() {
1027 gam_gpu::gpu_bail!(
1028 "build_householder_constrained_design_device: beta must be finite (got {beta})"
1029 );
1030 }
1031
1032 #[cfg(target_os = "linux")]
1033 {
1034 use cudarc::driver::{LaunchConfig, PushKernelArg};
1035 let backend = SphereGpuBackend::probe()?;
1036 let (cc_major, cc_minor) = backend.cc();
1037 let key = S2ModuleCacheKey {
1038 cc_major,
1039 cc_minor,
1040 lmax: inputs.lmax as u32,
1041 kind: inputs.kind,
1042 layout: inputs.layout,
1043 };
1044 let module = backend.module_for(key)?;
1045 let func = module
1046 .load_function("s2_wahba_householder_constrained_colmajor")
1047 .gpu_ctx("sphere load_function householder")?;
1048 let stream = backend.inner.stream.clone();
1049
1050 let data_dev = stream
1051 .clone_htod(inputs.data_xyz)
1052 .gpu_ctx("sphere-hh htod data_xyz")?;
1053 let centers_dev = stream
1054 .clone_htod(inputs.centers_xyz)
1055 .gpu_ctx("sphere-hh htod centers_xyz")?;
1056 let coeffs_dev = stream
1057 .clone_htod(inputs.coeffs)
1058 .gpu_ctx("sphere-hh htod coeffs")?;
1059 let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
1060
1061 let n = inputs.n;
1062 let m = inputs.m;
1063 let cols_out = m - 1;
1064 let ld_out = ((n + 31) / 32) * 32;
1065 let mut out_dev = stream
1066 .alloc_zeros::<f64>(ld_out * cols_out)
1067 .gpu_ctx_with(|err| {
1068 format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
1069 })?;
1070
1071 let block_x: u32 = 128;
1072 let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
1073 let cfg = LaunchConfig {
1074 grid_dim: (grid_x, 1, 1),
1075 block_dim: (block_x, 1, 1),
1076 shared_mem_bytes: 0,
1077 };
1078 let n_i32: i32 =
1079 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
1080 let m_i32: i32 =
1081 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
1082 let ld_out_i64: i64 = ld_out as i64;
1083
1084 let mut builder = stream.launch_builder(&func);
1085 builder
1086 .arg(&data_dev)
1087 .arg(¢ers_dev)
1088 .arg(&coeffs_dev)
1089 .arg(&v_dev)
1090 .arg(&beta)
1091 .arg(&n_i32)
1092 .arg(&m_i32)
1093 .arg(&ld_out_i64)
1094 .arg(&mut out_dev);
1095 unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
1098 stream
1099 .synchronize()
1100 .gpu_ctx("sphere-hh kernel synchronize")?;
1101
1102 Ok(DeviceS2KernelMatrix {
1103 rows: n,
1104 cols: cols_out,
1105 ld: ld_out,
1106 col_major_dev: out_dev,
1107 stream,
1108 })
1109 }
1110
1111 #[cfg(not(target_os = "linux"))]
1112 {
1113 Err(GpuError::DriverLibraryUnavailable {
1114 reason: "sphere GPU backend is Linux-only".to_string(),
1115 })
1116 }
1117}
1118
/// Builds the Householder reflector `H = I - beta * v * v^T` that maps `w`
/// onto a multiple of the first unit vector.
///
/// Returns `(v, beta)` with `v` normalised so that `v[0] == 1`. The sign of
/// the shift matches the sign of `w[0]`, which avoids cancellation when
/// forming `v[0]`.
///
/// Degenerate inputs yield the identity reflector: an empty `w` returns
/// `(vec![], 0.0)` and an all-zero `w` returns `(vec![0.0; m], 0.0)`.
///
/// The norm is first computed as a plain sum of squares. When that sum
/// underflows to zero or overflows to infinity although every entry is
/// finite, `w` is rescaled by its largest magnitude and the norm is
/// recomputed; `v` and `beta` do not depend on the scale of `w`, so the
/// result is unaffected. Inputs containing NaN or infinite entries are not
/// rescaled and propagate as non-finite outputs.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    fn euclidean_norm(x: &[f64]) -> f64 {
        x.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    let m = w.len();
    if m == 0 {
        return (Vec::new(), 0.0);
    }
    let mut v = w.to_vec();
    let mut norm = euclidean_norm(&v);
    if norm == 0.0 || norm == f64::INFINITY {
        // Either `w` is exactly zero, or the squares left the f64 range.
        // `f64::max` ignores NaN, but a NaN entry makes `norm` NaN and never
        // reaches this branch.
        let max_abs = v.iter().fold(0.0_f64, |acc, e| acc.max(e.abs()));
        if max_abs == 0.0 {
            return (vec![0.0; m], 0.0);
        }
        if max_abs.is_finite() {
            for entry in v.iter_mut() {
                *entry /= max_abs;
            }
            norm = euclidean_norm(&v);
        }
    }
    let sigma = if v[0] >= 0.0 { norm } else { -norm };
    v[0] += sigma;
    let v0 = v[0];
    if v0 == 0.0 {
        return (vec![0.0; m], 0.0);
    }
    for entry in v.iter_mut() {
        *entry /= v0;
    }
    let vv: f64 = v.iter().map(|x| x * x).sum();
    let beta = 2.0 / vv;
    (v, beta)
}
1156
1157pub fn build_center_kernel_device(
1176 centers_xyz: &[f64],
1177 lmax: usize,
1178 coeffs: &[f64],
1179 kind: SphereSpectralKernelKind,
1180) -> Result<DeviceS2KernelMatrix, GpuError> {
1181 let m = centers_xyz.len() / 3;
1182 if centers_xyz.len() != 3 * m {
1183 return Err(GpuError::DriverCallFailed {
1184 reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
1185 });
1186 }
1187 let inputs = S2KernelBuildInputs {
1188 n: m,
1189 m,
1190 lmax,
1191 data_xyz: centers_xyz,
1192 centers_xyz,
1193 coeffs,
1194 kind,
1195 layout: DeviceMatrixLayout::ColumnMajor,
1196 };
1197 build_kernel_matrix_device(inputs)
1198}
1199
1200pub fn constrained_penalty_host(
1205 c: ArrayView2<'_, f64>,
1206 w: &[f64],
1207) -> Result<Array2<f64>, GpuError> {
1208 let (m1, m2) = c.dim();
1209 if m1 != m2 {
1210 gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
1211 }
1212 let m = m1;
1213 if w.len() != m {
1214 gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
1215 }
1216 if m < 2 {
1217 gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
1218 }
1219 let (v, beta) = householder_reflector_from_weights(w);
1220
1221 let mut u = vec![0.0_f64; m];
1224 for i in 0..m {
1225 let mut acc = 0.0_f64;
1226 for j in 0..m {
1227 acc += c[(i, j)] * v[j];
1228 }
1229 u[i] = acc;
1230 }
1231 let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
1232 let mut hch = Array2::<f64>::zeros((m, m));
1233 for i in 0..m {
1234 for j in 0..m {
1235 hch[(i, j)] =
1236 c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
1237 }
1238 }
1239 let mut s = Array2::<f64>::zeros((m - 1, m - 1));
1241 for i in 0..(m - 1) {
1242 for j in 0..(m - 1) {
1243 s[(i, j)] = hch[(i + 1, j + 1)];
1244 }
1245 }
1246 Ok(s)
1247}
1248
/// Result of the penalised least-squares solve.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Solved coefficient vector; empty when the design has no columns.
    pub beta: Vec<f64>,
    /// Weighted residual sum of squares; for a design with no columns this
    /// is the sum of squares of the weighted response.
    pub weighted_residual_ssq: f64,
    /// Log-determinant term reported by the solver; `0.0` for a design with
    /// no columns. NOTE(review): presumably the log-determinant of the
    /// penalised Hessian — confirm against the solver.
    pub log_det_hessian: f64,
}
1288
1289#[cfg(target_os = "linux")]
1299pub fn solve_penalised_ls_device(
1300 x_s_device: &DeviceS2KernelMatrix,
1301 wy: &[f64],
1302 r_s: ArrayView2<'_, f64>,
1303) -> Result<PenalisedLsSolution, GpuError> {
1304 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1305 use cudarc::driver::DevicePtrMut;
1306
1307 let n = x_s_device.rows;
1308 let p = x_s_device.cols;
1309 if wy.len() != n {
1310 gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
1311 }
1312 if r_s.dim() != (p, p) {
1313 gam_gpu::gpu_bail!(
1314 "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
1315 r_s.dim()
1316 );
1317 }
1318 if p == 0 {
1319 return Ok(PenalisedLsSolution {
1320 beta: Vec::new(),
1321 weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
1322 log_det_hessian: 0.0,
1323 });
1324 }
1325
1326 let stream = x_s_device.stream.clone();
1327 let n_aug = n + p;
1328
1329 let mut a_aug_host = vec![0.0_f64; n_aug * p];
1334 let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
1336 x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
1337 for j in 0..p {
1338 let src_off = j * x_s_device.ld;
1339 let dst_off = j * n_aug;
1340 a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
1341 for i in 0..p {
1342 a_aug_host[dst_off + n + i] = r_s[(i, j)];
1345 }
1346 }
1347 let mut a_dev = stream
1348 .clone_htod(&a_aug_host)
1349 .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
1350
1351 let mut b_host = vec![0.0_f64; n_aug];
1353 b_host[..n].copy_from_slice(wy);
1354 let mut b_dev = stream
1355 .clone_htod(&b_host)
1356 .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
1357
1358 let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
1359 let n_aug_i: i32 = i32::try_from(n_aug)
1360 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
1361 let p_i: i32 = i32::try_from(p)
1362 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
1363
1364 let mut lwork: i32 = 0;
1366 {
1367 let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
1368 let status = unsafe {
1371 cusolver_sys::cusolverDnDgeqrf_bufferSize(
1372 solver.cu(),
1373 n_aug_i,
1374 p_i,
1375 a_ptr as *mut f64,
1376 n_aug_i,
1377 &mut lwork,
1378 )
1379 };
1380 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1381 gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1382 }
1383 }
1384 let lwork_us = usize::try_from(lwork)
1385 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1386 let mut workspace = stream
1387 .alloc_zeros::<f64>(lwork_us.max(1))
1388 .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1389 let mut tau = stream
1390 .alloc_zeros::<f64>(p)
1391 .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1392 let mut info = stream
1393 .alloc_zeros::<i32>(1)
1394 .gpu_ctx("solve_penalised_ls_device alloc info")?;
1395
1396 {
1398 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1399 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1400 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1401 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1402 let status = unsafe {
1405 cusolver_sys::cusolverDnDgeqrf(
1406 solver.cu(),
1407 n_aug_i,
1408 p_i,
1409 a_ptr as *mut f64,
1410 n_aug_i,
1411 tau_ptr as *mut f64,
1412 work_ptr as *mut f64,
1413 lwork,
1414 info_ptr as *mut i32,
1415 )
1416 };
1417 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1418 gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1419 }
1420 }
1421
1422 let mut ormqr_lwork: i32 = 0;
1424 {
1425 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1426 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1427 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1428 let status = unsafe {
1431 cusolver_sys::cusolverDnDormqr_bufferSize(
1432 solver.cu(),
1433 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1434 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1435 n_aug_i,
1436 1,
1437 p_i,
1438 a_ptr as *const f64,
1439 n_aug_i,
1440 tau_ptr as *const f64,
1441 b_ptr as *mut f64,
1442 n_aug_i,
1443 &mut ormqr_lwork,
1444 )
1445 };
1446 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1447 gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1448 }
1449 }
1450 if ormqr_lwork > lwork {
1451 workspace = stream
1452 .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1453 .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1454 }
1455 {
1456 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1457 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1458 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1459 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1460 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1461 let status = unsafe {
1465 cusolver_sys::cusolverDnDormqr(
1466 solver.cu(),
1467 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1468 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1469 n_aug_i,
1470 1,
1471 p_i,
1472 a_ptr as *const f64,
1473 n_aug_i,
1474 tau_ptr as *const f64,
1475 b_ptr as *mut f64,
1476 n_aug_i,
1477 work_ptr as *mut f64,
1478 ormqr_lwork.max(lwork),
1479 info_ptr as *mut i32,
1480 )
1481 };
1482 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1483 gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1484 }
1485 }
1486
1487 {
1490 use cudarc::cublas::CudaBlas;
1491 let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1492 let alpha = 1.0_f64;
1493 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1494 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1495 let handle = *blas.handle();
1500 let status = unsafe {
1501 cudarc::cublas::sys::cublasDtrsm_v2(
1502 handle,
1503 cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1504 cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1505 cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1506 cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1507 p_i,
1508 1,
1509 &alpha,
1510 a_ptr as *const f64,
1511 n_aug_i,
1512 b_ptr as *mut f64,
1513 n_aug_i,
1514 )
1515 };
1516 if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1517 gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1518 }
1519 }
1520
1521 let mut b_out = vec![0.0_f64; n_aug];
1523 stream
1524 .memcpy_dtoh(&b_dev, &mut b_out)
1525 .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1526 let mut a_back = vec![0.0_f64; n_aug * p];
1527 stream
1528 .memcpy_dtoh(&a_dev, &mut a_back)
1529 .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1530 stream
1531 .synchronize()
1532 .gpu_ctx("solve_penalised_ls_device synchronize")?;
1533
1534 let beta: Vec<f64> = b_out[..p].to_vec();
1535 let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1544
1545 let mut log_abs_r = 0.0_f64;
1547 for k in 0..p {
1548 let r_kk = a_back[k * n_aug + k];
1549 log_abs_r += r_kk.abs().ln();
1550 }
1551 let log_det_hessian = 2.0 * log_abs_r;
1552
1553 Ok(PenalisedLsSolution {
1554 beta,
1555 weighted_residual_ssq: augmented_residual_ssq,
1556 log_det_hessian,
1557 })
1558}
1559
/// Fallback used on every target other than Linux: no device solve is attempted.
///
/// The arguments are only read to describe the rejected request in the error
/// text (design `rows`/`cols`, the length of `wy`, and the dimensions of `r_s`).
///
/// # Errors
///
/// Always returns [`GpuError::DriverLibraryUnavailable`].
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let DeviceS2KernelMatrix { rows, cols, .. } = x_s_device;
    let rhs_len = wy.len();
    let penalty_dim = r_s.dim();
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        rows, cols, rhs_len, penalty_dim
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1576
1577#[cfg(test)]
1582mod sphere_gpu_tests {
1583 use super::*;
1584 use crate::basis::{
1585 SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1586 spherical_wahba_kernel_matrix_with_kind,
1587 };
1588 use ndarray::Array2;
1589
1590 fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1591 let mut rows = Vec::with_capacity(n_lat * n_lon);
1593 for i in 0..n_lat {
1594 let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1595 for j in 0..n_lon {
1596 let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1597 rows.push(lat);
1598 rows.push(lon);
1599 }
1600 }
1601 Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1602 }
1603
1604 #[test]
1605 fn sum_finite_guard_accepts_finite_rejects_nonfinite() {
1606 let finite = Array2::<f64>::from_shape_fn((5, 7), |(i, j)| (i as f64 - 2.0) * (j as f64));
1611 assert!(finite.sum().is_finite());
1612
1613 let mut with_nan = finite.clone();
1614 with_nan[[3, 4]] = f64::NAN;
1615 assert!(!with_nan.sum().is_finite());
1616
1617 let mut with_pos_inf = finite.clone();
1618 with_pos_inf[[0, 0]] = f64::INFINITY;
1619 assert!(!with_pos_inf.sum().is_finite());
1620
1621 let mut with_neg_inf = finite.clone();
1622 with_neg_inf[[4, 6]] = f64::NEG_INFINITY;
1623 assert!(!with_neg_inf.sum().is_finite());
1624 }
1625
1626 #[test]
1627 fn xyz_preprocessing_matches_unit_sphere() {
1628 let latlon = ndarray::array![
1629 [0.0, 0.0],
1630 [90.0, 0.0],
1631 [0.0, 90.0],
1632 [-90.0, 17.5],
1633 [45.0, -120.0],
1634 ];
1635 let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1636 assert_eq!(xyz.len(), 3 * 5);
1637 for i in 0..5 {
1638 let nrm2 = xyz[3 * i] * xyz[3 * i]
1639 + xyz[3 * i + 1] * xyz[3 * i + 1]
1640 + xyz[3 * i + 2] * xyz[3 * i + 2];
1641 assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1642 }
1643 assert!((xyz[0] - 1.0).abs() < 1e-15);
1645 assert!(xyz[1].abs() < 1e-15);
1646 assert!(xyz[2].abs() < 1e-15);
1647 assert!(xyz[3].abs() < 1e-15);
1649 assert!(xyz[4].abs() < 1e-15);
1650 assert!((xyz[5] - 1.0).abs() < 1e-15);
1651 assert!(xyz[6].abs() < 1e-15);
1653 assert!((xyz[7] - 1.0).abs() < 1e-15);
1654 assert!(xyz[8].abs() < 1e-15);
1655 }
1656
1657 #[test]
1658 fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1659 for m_penalty in 1..=4 {
1663 for &lmax in &[5_usize, 20, 50] {
1664 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1665 let expected: f64 = coeffs.iter().sum();
1666 let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1667 assert!(
1668 (got - expected).abs() < 1e-13,
1669 "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1670 );
1671 }
1672 }
1673 }
1674
1675 #[test]
1676 fn truncated_spectral_at_antipode_matches_alternating_sum() {
1677 for m_penalty in 1..=4 {
1680 for &lmax in &[5_usize, 20, 50] {
1681 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1682 let expected: f64 = coeffs
1683 .iter()
1684 .enumerate()
1685 .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1686 .sum();
1687 let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1688 assert!(
1689 (got - expected).abs() < 1e-13,
1690 "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1691 );
1692 }
1693 }
1694 }
1695
1696 #[test]
1697 fn truncated_spectral_matrix_is_symmetric() {
1698 let centers = ndarray::array![
1702 [10.0_f64, 20.0],
1703 [-30.0, 100.0],
1704 [45.0, -60.0],
1705 [-89.0, 0.0],
1706 [0.0, 180.0],
1707 [60.0, -179.9],
1708 ];
1709 for m_penalty in [1usize, 2, 4] {
1710 for &lmax in &[10_usize, 30] {
1711 let mat = spherical_wahba_kernel_matrix_with_kind(
1712 centers.view(),
1713 centers.view(),
1714 m_penalty,
1715 false,
1716 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1717 )
1718 .expect("kernel matrix");
1719 let n = centers.nrows();
1720 let mut max_asym = 0.0_f64;
1721 for i in 0..n {
1722 for j in 0..n {
1723 let d = (mat[(i, j)] - mat[(j, i)]).abs();
1724 if d > max_asym {
1725 max_asym = d;
1726 }
1727 }
1728 }
1729 assert!(
1730 max_asym < 1e-13,
1731 "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1732 );
1733 }
1734 }
1735 }
1736
1737 #[test]
1738 fn truncated_coefficients_have_zero_constant_mode() {
1739 for m in 1..=4 {
1740 let c = sobolev_s2_truncated_coefficients(50, m);
1741 assert_eq!(c.len(), 51);
1742 assert_eq!(c[0], 0.0);
1743 assert!(c[1] > 0.0);
1744 for ell in 2..=50 {
1746 assert!(
1747 c[ell] < c[ell - 1] + 1e-15,
1748 "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1749 c[ell],
1750 c[ell - 1]
1751 );
1752 }
1753 }
1754 }
1755
1756 #[test]
1757 fn truncated_spectral_matches_matrix_helper() {
1758 let m_penalty = 2;
1762 let lmax = 20;
1763 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1764 let data = ndarray::array![[12.5, -34.0]];
1765 let centers = ndarray::array![[40.0, 10.0]];
1766 let mat = spherical_wahba_kernel_matrix_with_kind(
1767 data.view(),
1768 centers.view(),
1769 m_penalty,
1770 false,
1771 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1772 )
1773 .expect("kernel matrix");
1774 let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1776 let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1777 let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1778 let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1779 assert!(
1780 (mat[(0, 0)] - expected).abs() < 1e-13,
1781 "matrix helper differs from scalar evaluator: {} vs {}",
1782 mat[(0, 0)],
1783 expected
1784 );
1785 }
1786
1787 #[test]
1788 fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1789 let m = 6;
1794 let mut c = Array2::<f64>::zeros((m, m));
1795 for i in 0..m {
1796 for j in 0..m {
1797 let d = (i as f64 - j as f64).abs();
1798 c[(i, j)] = (-0.5 * d).exp();
1799 }
1800 }
1801 let w = vec![1.0_f64; m];
1802 let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1803 assert_eq!(s.dim(), (m - 1, m - 1));
1804 let mut max_asym = 0.0_f64;
1806 for i in 0..(m - 1) {
1807 for j in 0..(m - 1) {
1808 let d = (s[(i, j)] - s[(j, i)]).abs();
1809 if d > max_asym {
1810 max_asym = d;
1811 }
1812 }
1813 }
1814 assert!(
1815 max_asym < 1e-13,
1816 "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1817 );
1818
1819 let ones = ndarray::Array1::<f64>::ones(m - 1);
1827 let sx = s.dot(&ones);
1828 assert!(sx.iter().all(|v| v.is_finite()));
1829 }
1830
1831 #[test]
1832 fn householder_reflector_zeroes_target_vector() {
1833 let w = vec![3.0, 4.0, 0.0, -1.0];
1834 let (v, beta) = householder_reflector_from_weights(&w);
1835 let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1838 let hw: Vec<f64> = w
1839 .iter()
1840 .zip(&v)
1841 .map(|(wj, vj)| wj - beta * dot * vj)
1842 .collect();
1843 for entry in hw.iter().skip(1) {
1844 assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1845 }
1846 assert!(hw[0].abs() > 0.0);
1847 }
1848
1849 #[test]
1852 fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1853 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1854 eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1855 return;
1856 };
1857 SphereGpuBackend::probe()
1860 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1861
1862 let data_ll = small_latlon_grid(7, 9);
1863 let centers_ll = small_latlon_grid(5, 7);
1864 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1865 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1866 let n = data_ll.nrows();
1867 let m = centers_ll.nrows();
1868 let penalty = 2usize;
1869 let lmax = 20usize;
1870 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1871
1872 let inputs = S2KernelBuildInputs {
1873 n,
1874 m,
1875 lmax,
1876 data_xyz: &data_xyz,
1877 centers_xyz: ¢ers_xyz,
1878 coeffs: &coeffs,
1879 kind: SphereSpectralKernelKind::Sobolev,
1880 layout: DeviceMatrixLayout::ColumnMajor,
1881 };
1882 let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1883 let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1884
1885 let cpu = spherical_wahba_kernel_matrix_with_kind(
1886 data_ll.view(),
1887 centers_ll.view(),
1888 penalty,
1889 false,
1890 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1891 )
1892 .expect("cpu kernel matrix");
1893
1894 let mut max_abs = 0.0_f64;
1895 for i in 0..n {
1896 for j in 0..m {
1897 let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1898 if d > max_abs {
1899 max_abs = d;
1900 }
1901 }
1902 }
1903 assert!(
1904 max_abs < 1e-11,
1905 "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1906 );
1907 }
1908
1909 #[test]
1923 fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1924 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1925 eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1926 return;
1927 };
1928 SphereGpuBackend::probe()
1932 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1933 use crate::basis::{
1934 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1935 build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
1936 spherical_wahba_kernel_matrix_with_kind,
1937 };
1938
1939 let data = small_latlon_grid(100, 100);
1941 let lmax: u16 = 30;
1942 let penalty_order = 2usize;
1943 let centers =
1944 crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1945 .expect("centers");
1946 let n = data.nrows();
1947 let m = centers.nrows();
1948
1949 let decision = sphere_kernel_decision(n, m, lmax as usize);
1953 assert!(
1954 decision.use_gpu,
1955 "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
1956 (reason={}); the engagement gate regressed",
1957 decision.reason
1958 );
1959
1960 let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
1962 data.view(),
1963 centers.view(),
1964 penalty_order,
1965 false,
1966 SphereWahbaKernel::SobolevTruncated { lmax },
1967 )
1968 .expect("GPU-eligible production kernel build succeeds");
1969
1970 let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
1972 data.view(),
1973 centers.view(),
1974 penalty_order,
1975 false,
1976 SphereWahbaKernel::SobolevTruncated { lmax },
1977 )
1978 .expect("cpu oracle kernel build succeeds");
1979
1980 assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
1981 let mut max_abs = 0.0_f64;
1982 let mut max_rel = 0.0_f64;
1983 for (g, c) in gpu_kernel.iter().zip(cpu_kernel.iter()) {
1984 let d = (g - c).abs();
1985 if d > max_abs {
1986 max_abs = d;
1987 }
1988 let denom = g.abs().max(c.abs()).max(1e-300);
1989 let r = d / denom;
1990 if r > max_rel {
1991 max_rel = r;
1992 }
1993 }
1994 assert!(
1995 max_rel < 1e-9,
1996 "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
1997 >= 1e-9 (abs {max_abs:.3e})"
1998 );
1999
2000 let spec_gpu = SphericalSplineBasisSpec {
2004 center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
2005 penalty_order,
2006 double_penalty: false,
2007 radians: false,
2008 method: SphereMethod::Wahba,
2009 max_degree: None,
2010 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2011 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2012 };
2013 let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
2014 .expect("GPU-eligible build_spherical_spline_basis succeeds");
2015 let design = result_gpu.design.as_dense().expect("dense design");
2016 assert_eq!(design.nrows(), n, "design row count must match data rows");
2017 assert!(
2018 design.iter().all(|v| v.is_finite()),
2019 "engaged-device spherical design must be finite"
2020 );
2021 }
2022
2023 #[test]
2026 fn sphere_gpu_householder_parity_vs_raw_dot_z() {
2027 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2028 eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
2029 return;
2030 };
2031 SphereGpuBackend::probe()
2034 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
2035 let data_ll = small_latlon_grid(6, 8);
2036 let centers_ll = small_latlon_grid(4, 5);
2037 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2038 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2039 let n = data_ll.nrows();
2040 let m = centers_ll.nrows();
2041 let penalty = 2usize;
2042 let lmax = 15usize;
2043 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
2044
2045 let inputs_raw = S2KernelBuildInputs {
2047 n,
2048 m,
2049 lmax,
2050 data_xyz: &data_xyz,
2051 centers_xyz: ¢ers_xyz,
2052 coeffs: &coeffs,
2053 kind: SphereSpectralKernelKind::Sobolev,
2054 layout: DeviceMatrixLayout::ColumnMajor,
2055 };
2056 let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
2057 let b = b_dev.to_host_array().expect("dtoh raw");
2058
2059 let w = vec![1.0_f64; m];
2062 let (v, beta) = householder_reflector_from_weights(&w);
2063
2064 let mut xs_host = Array2::<f64>::zeros((n, m - 1));
2066 for i in 0..n {
2067 let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
2068 for j_out in 0..(m - 1) {
2069 xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
2070 }
2071 }
2072
2073 let xs_dev =
2074 build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
2075 let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
2076
2077 let mut max_abs = 0.0_f64;
2078 for i in 0..n {
2079 for j in 0..(m - 1) {
2080 let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
2081 if d > max_abs {
2082 max_abs = d;
2083 }
2084 }
2085 }
2086 assert!(
2087 max_abs < 1e-12,
2088 "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
2089 );
2090 }
2091
2092 #[test]
2096 fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
2097 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2098 eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
2099 return;
2100 };
2101 if SphereGpuBackend::probe().is_err() {
2102 eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
2103 return;
2104 }
2105
2106 let n_lat = 500usize;
2109 let n_lon = 400usize;
2110 assert_eq!(n_lat * n_lon, 200_000);
2111 let data_ll = small_latlon_grid(n_lat, n_lon);
2112 let m = 200usize;
2113 let centers_ll =
2114 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2115 .expect("centers");
2116 let n = data_ll.nrows();
2117 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2118 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2119 let penalty_order = 2usize;
2120 let lmax = 50usize;
2121 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
2122
2123 let inputs_warm = S2KernelBuildInputs {
2125 n,
2126 m,
2127 lmax,
2128 data_xyz: &data_xyz,
2129 centers_xyz: ¢ers_xyz,
2130 coeffs: &coeffs,
2131 kind: SphereSpectralKernelKind::Sobolev,
2132 layout: DeviceMatrixLayout::ColumnMajor,
2133 };
2134 {
2139 let warm = build_kernel_matrix_device(inputs_warm.clone()).expect("warmup");
2140 drop(warm.to_host_array().expect("warmup to_host"));
2141 }
2142
2143 let t0 = std::time::Instant::now();
2145 let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
2146 let _host_gpu = dev.to_host_array().expect("dtoh");
2147 let gpu_secs = t0.elapsed().as_secs_f64();
2148
2149 let t1 = std::time::Instant::now();
2156 let _cpu = crate::basis::spherical_wahba_kernel_matrix_cpu(
2157 data_ll.view(),
2158 centers_ll.view(),
2159 penalty_order,
2160 false,
2161 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
2162 )
2163 .expect("cpu kernel matrix");
2164 let cpu_secs = t1.elapsed().as_secs_f64();
2165
2166 let ratio = cpu_secs / gpu_secs.max(1e-9);
2167 eprintln!(
2168 "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
2169 );
2170 assert!(
2171 ratio >= 20.0,
2172 "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
2173 n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2174 );
2175 }
2176
2177 #[test]
2182 fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
2183 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2184 eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
2185 return;
2186 };
2187 if SphereGpuBackend::probe().is_err() {
2188 eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
2189 return;
2190 }
2191 use crate::basis::{
2192 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
2193 build_spherical_spline_basis,
2194 };
2195
2196 let n_lat = 500usize;
2197 let n_lon = 400usize;
2198 let data_ll = small_latlon_grid(n_lat, n_lon);
2199 let m: usize = 200;
2200 let lmax: u16 = 50;
2201 let spec_gpu = SphericalSplineBasisSpec {
2202 center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
2203 penalty_order: 2,
2204 double_penalty: false,
2205 radians: false,
2206 method: SphereMethod::Wahba,
2207 max_degree: None,
2208 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2209 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2210 };
2211
2212 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
2214
2215 let t0 = std::time::Instant::now();
2216 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
2217 let gpu_secs = t0.elapsed().as_secs_f64();
2218
2219 let centers =
2226 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2227 .expect("centers");
2228 let z = Array2::<f64>::eye(centers.nrows());
2229 let t1 = std::time::Instant::now();
2230 let raw_cpu = crate::basis::spherical_wahba_kernel_matrix_cpu(
2235 data_ll.view(),
2236 centers.view(),
2237 2,
2238 false,
2239 SphereWahbaKernel::SobolevTruncated { lmax },
2240 )
2241 .expect("cpu raw");
2242 let _design_cpu = raw_cpu.dot(&z);
2243 let cpu_secs = t1.elapsed().as_secs_f64();
2244
2245 let ratio = cpu_secs / gpu_secs.max(1e-9);
2246 eprintln!(
2247 "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
2248 data_ll.nrows()
2249 );
2250 assert!(
2251 ratio >= 10.0,
2252 "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
2253 cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2254 );
2255 }
2256
2257 #[test]
2279 fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
2280 use crate::basis::{
2281 select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
2282 };
2283 use faer::Side;
2284 use gam_linalg::faer_ndarray::FaerCholesky;
2285
2286 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2287 eprintln!(
2288 "[sphere gpu parity] no CUDA runtime — skipping device parity \
2289 (CPU oracle exercised by sibling tests)"
2290 );
2291 return;
2292 };
2293 SphereGpuBackend::probe()
2296 .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
2297
2298 let data_ll = small_latlon_grid(25, 40);
2300 assert_eq!(data_ll.nrows(), 1000);
2301 let n = data_ll.nrows();
2302 let m: usize = 80;
2303 let lmax_u16: u16 = 15;
2304 let lmax: usize = lmax_u16 as usize;
2305 let penalty_order: usize = 2;
2306 let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
2307 let lambda: f64 = 1.0e-3;
2308
2309 let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
2311 .expect("farthest-point centers");
2312 assert_eq!(centers_ll.nrows(), m);
2313
2314 let z = Array2::<f64>::eye(centers_ll.nrows());
2317 let p = z.ncols();
2318 assert_eq!(p, m);
2319
2320 let k_cc = spherical_wahba_kernel_matrix_with_kind(
2325 centers_ll.view(),
2326 centers_ll.view(),
2327 penalty_order,
2328 false,
2329 kernel,
2330 )
2331 .expect("centers×centers kernel");
2332 let s_full = z.t().dot(&k_cc).dot(&z);
2333
2334 let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
2336 data_ll.view(),
2337 centers_ll.view(),
2338 penalty_order,
2339 false,
2340 kernel,
2341 )
2342 .expect("CPU raw design");
2343 let x_s_cpu = raw_design_cpu.dot(&z);
2344
2345 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
2347 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
2348 let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
2349 let inputs = S2KernelBuildInputs {
2350 n,
2351 m,
2352 lmax,
2353 data_xyz: &data_xyz,
2354 centers_xyz: ¢ers_xyz,
2355 coeffs: &coeffs,
2356 kind: SphereSpectralKernelKind::Sobolev,
2357 layout: DeviceMatrixLayout::ColumnMajor,
2358 };
2359 let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
2360 let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
2361 let x_s_gpu = raw_design_gpu.dot(&z);
2362
2363 assert_eq!(x_s_cpu.dim(), (n, p));
2364 assert_eq!(x_s_gpu.dim(), (n, p));
2365
2366 let mut raw_xs_delta = 0.0_f64;
2376 let mut xs_scale = 0.0_f64;
2377 for (a, b) in x_s_cpu.iter().zip(x_s_gpu.iter()) {
2378 raw_xs_delta = raw_xs_delta.max((a - b).abs());
2379 xs_scale = xs_scale.max(a.abs());
2380 }
2381 let cond = {
2384 use gam_linalg::faer_ndarray::FaerEigh;
2385 let xtx = x_s_cpu.t().dot(&x_s_cpu);
2386 let mut a = xtx;
2387 for i in 0..p {
2388 for j in 0..p {
2389 a[(i, j)] += lambda * s_full[(i, j)];
2390 }
2391 }
2392 let (mut lo, mut hi) = (f64::INFINITY, 0.0_f64);
2393 if let Ok((vals, _)) = a.eigh(faer::Side::Lower) {
2394 for &v in vals.iter() {
2395 lo = lo.min(v);
2396 hi = hi.max(v);
2397 }
2398 }
2399 hi / lo.max(1e-300)
2400 };
2401 assert!(
2406 raw_xs_delta <= 1e-12 * xs_scale.max(1.0),
2407 "GPU vs CPU sphere design matrix max |Δ| = {raw_xs_delta:.3e} > {:.3e} \
2408 (scale {xs_scale:.3e}) — the kernel itself drifted (this is the genuine \
2409 GPU output, NOT a conditioning artifact)",
2410 1e-12 * xs_scale.max(1.0)
2411 );
2412
2413 let mut y = ndarray::Array1::<f64>::zeros(n);
2419 for i in 0..n {
2420 let lat_rad = data_ll[(i, 0)].to_radians();
2421 let lon_rad = data_ll[(i, 1)].to_radians();
2422 y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
2424 + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
2425 }
2426
2427 let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
2432 let xtx = x_s.t().dot(x_s);
2433 let mut a = xtx;
2434 for i in 0..p {
2435 for j in 0..p {
2436 a[(i, j)] += lambda * s_full[(i, j)];
2437 }
2438 }
2439 let rhs = x_s.t().dot(&y);
2440 let factor = a
2441 .cholesky(Side::Lower)
2442 .expect("penalised normal equations are SPD under λ > 0");
2443 factor.solvevec(&rhs)
2444 };
2445
2446 let beta_cpu = solve_penalised(&x_s_cpu);
2447 let beta_gpu = solve_penalised(&x_s_gpu);
2448 assert_eq!(beta_cpu.len(), p);
2449 assert_eq!(beta_gpu.len(), p);
2450
2451 let yhat_cpu = x_s_cpu.dot(&beta_cpu);
2455 let yhat_gpu = x_s_gpu.dot(&beta_gpu);
2456
2457 let mut max_beta_delta = 0.0_f64;
2458 for k in 0..p {
2459 let d = (beta_cpu[k] - beta_gpu[k]).abs();
2460 if d > max_beta_delta {
2461 max_beta_delta = d;
2462 }
2463 }
2464 let mut max_fit_delta = 0.0_f64;
2465 for i in 0..n {
2466 let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
2467 if d > max_fit_delta {
2468 max_fit_delta = d;
2469 }
2470 }
2471
2472 eprintln!(
2473 "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
2474 raw_xs|Δ|={raw_xs_delta:.3e} cond={cond:.3e} \
2475 max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
2476 );
2477
2478 assert!(
2485 max_fit_delta <= 1.0e-9,
2486 "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2487 );
2488
2489 let beta_tol = (1e-15 * cond * (1.0 + xs_scale)).max(1e-9) * 16.0;
2501 assert!(
2502 max_beta_delta <= beta_tol,
2503 "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > \
2504 condition-aware tol {beta_tol:.3e} (cond={cond:.3e}). Raw design parity is \
2505 {raw_xs_delta:.3e}; a drift THIS much larger than cond·ULP is a real solve/kernel \
2506 mismatch, not conditioning."
2507 );
2508 }
2509}