1use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Selects which family of truncated spectral coefficients drives the
/// sphere kernel (see [`SphereSpectralKernelKind::coefficients`]).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
    /// Coefficients produced by `crate::basis::sobolev_s2_truncated_coefficients`.
    Sobolev,
    /// Coefficients produced by `crate::basis::pseudo_s2_truncated_coefficients`.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48 pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52 match self {
53 SphereSpectralKernelKind::Sobolev => {
54 crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55 }
56 SphereSpectralKernelKind::Pseudo => {
57 crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58 }
59 }
60 }
61
62 pub const fn tag(self) -> &'static str {
64 match self {
65 SphereSpectralKernelKind::Sobolev => "sobolev",
66 SphereSpectralKernelKind::Pseudo => "pseudo",
67 }
68 }
69}
70
/// Memory layout of a matrix stored in a device buffer.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Columns are contiguous; currently the only layout available.
    ColumnMajor,
}
78
79pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86 if latlon.ncols() != 2 {
87 return Err(format!(
88 "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89 latlon.shape()
90 ));
91 }
92 let deg = if radians {
93 1.0
94 } else {
95 std::f64::consts::PI / 180.0
96 };
97 let n = latlon.nrows();
98 let mut out = Vec::with_capacity(3 * n);
99 for row in latlon.outer_iter() {
100 let lat = row[0] * deg;
101 let lon = row[1] * deg;
102 let (s_lat, c_lat) = lat.sin_cos();
103 let (s_lon, c_lon) = lon.sin_cos();
104 out.push(c_lat * c_lon);
106 out.push(c_lat * s_lon);
107 out.push(s_lat);
108 }
109 Ok(out)
110}
111
/// Column-major kernel matrix held in CUDA device memory (Linux builds).
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension: element `(i, j)` is stored at index `j * ld + i`.
    pub ld: usize,
    /// Device buffer read back as `ld * cols` elements in column-major order.
    pub col_major_dev: CudaSlice<f64>,
    /// Stream used when copying this buffer back to the host.
    pub stream: Arc<CudaStream>,
}
126
/// Host-memory stand-in for the device matrix on non-Linux builds, exposing
/// the same shape fields as the CUDA-backed variant.
#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
    /// Leading dimension: element `(i, j)` is stored at index `j * ld + i`.
    pub ld: usize,
    /// Column-major storage; copied out as `ld * cols` elements.
    pub col_major_dev: Vec<f64>,
}
135
136impl DeviceS2KernelMatrix {
137 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
141 let mut col_major = vec![0.0_f64; self.ld * self.cols];
142 self.copy_to_host_col_major(&mut col_major)?;
143 let mut out = Array2::<f64>::zeros((self.rows, self.cols));
144 for j in 0..self.cols {
145 for i in 0..self.rows {
146 out[(i, j)] = col_major[j * self.ld + i];
147 }
148 }
149 Ok(out)
150 }
151
152 #[cfg(target_os = "linux")]
157 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
158 let needed = self.ld * self.cols;
159 if dst.len() != needed {
160 gam_gpu::gpu_bail!(
161 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
162 dst.len(),
163 needed
164 );
165 }
166 self.stream
167 .memcpy_dtoh(&self.col_major_dev, dst)
168 .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
169 self.stream
170 .synchronize()
171 .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
172 Ok(())
173 }
174
175 #[cfg(not(target_os = "linux"))]
176 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
177 let needed = self.ld * self.cols;
178 if dst.len() != needed {
179 gam_gpu::gpu_bail!(
180 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
181 dst.len(),
182 needed
183 );
184 }
185 dst.copy_from_slice(&self.col_major_dev);
186 Ok(())
187 }
188}
189
/// Borrowed host-side inputs for building a sphere kernel matrix on the
/// device. Consistency between the fields is checked by `validate`.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
    /// Number of data points (rows of the kernel matrix).
    pub n: usize,
    /// Number of centers (columns of the kernel matrix).
    pub m: usize,
    /// Truncation degree of the spectral expansion; must be at least 1.
    pub lmax: usize,
    /// Flat data coordinates; must hold `3 * n` values.
    pub data_xyz: &'a [f64],
    /// Flat center coordinates; must hold `3 * m` values.
    pub centers_xyz: &'a [f64],
    /// Spectral coefficients; must hold `lmax + 1` values with `coeffs[0] == 0`.
    pub coeffs: &'a [f64],
    /// Coefficient family the inputs were generated for.
    pub kind: SphereSpectralKernelKind,
    /// Requested layout of the device output.
    pub layout: DeviceMatrixLayout,
}
211
212impl<'a> S2KernelBuildInputs<'a> {
213 fn validate(&self) -> Result<(), GpuError> {
214 if self.lmax == 0 {
215 return Err(GpuError::DriverCallFailed {
216 reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
217 });
218 }
219 if self.data_xyz.len() != 3 * self.n {
220 gam_gpu::gpu_bail!(
221 "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
222 self.data_xyz.len(),
223 3 * self.n
224 );
225 }
226 if self.centers_xyz.len() != 3 * self.m {
227 gam_gpu::gpu_bail!(
228 "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
229 self.centers_xyz.len(),
230 3 * self.m
231 );
232 }
233 if self.coeffs.len() != self.lmax + 1 {
234 gam_gpu::gpu_bail!(
235 "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
236 self.coeffs.len(),
237 self.lmax + 1
238 );
239 }
240 if self.coeffs[0] != 0.0 {
241 return Err(GpuError::DriverCallFailed {
242 reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
243 });
244 }
245 Ok(())
246 }
247}
248
/// CUDA C source for the sphere kernels, compiled at runtime through NVRTC.
///
/// The host prepends `#define LMAX <lmax>` to this text before compiling
/// (see `SphereGpuBackend::module_for`). Two entry points are defined:
///
/// * `s2_wahba_legendre_colmajor` — one thread per `(i, j)` entry; evaluates
///   the truncated Legendre series at the clamped dot product of data point
///   `i` and center `j` and stores it at `out[j * ld + i]`.
/// * `s2_wahba_householder_constrained_colmajor` — one thread per data row;
///   a first pass accumulates `d_i = Σ_j v[j]·B[i, j]`, a second pass writes
///   `B[i, j+1] − beta·d_i·v[j+1]` for the `m − 1` output columns.
///
/// NOTE(review): row/center indices are 32-bit `int`s and the coordinate
/// loads compute `3 * i` / `3 * j` in `int` arithmetic — confirm callers keep
/// `n` and `m` small enough for that product not to overflow.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz, // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
    const double* __restrict__ coeffs, // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];
    const double cxj = centers_xyz[3 * j + 0];
    const double cyj = centers_xyz[3 * j + 1];
    const double czj = centers_xyz[3 * j + 2];

    // t = clamp(x_i · z_j, -1, +1).
    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t > 1.0) t = 1.0;
    if (t < -1.0) t = -1.0;

    // Legendre 3-term recurrence in registers.
    // P_0(t) = 1, P_1(t) = t.
    double p_prev = 1.0;
    double p_curr = t;
    double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }

    out[(long long) j * ld + (long long) i] = acc;
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz, // n × 3
    const double* __restrict__ centers_xyz, // m × 3
    const double* __restrict__ coeffs, // length LMAX + 1
    const double* __restrict__ v, // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t > 1.0) t = 1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        d_i = fma(v[j], acc, d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t > 1.0) t = 1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        const double xs = acc - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
396
/// Key under which a compiled sphere CUDA module is cached by the backend.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Major compute capability of the device.
    pub cc_major: i32,
    /// Minor compute capability of the device.
    pub cc_minor: i32,
    /// Truncation degree baked into the module as `#define LMAX`.
    pub lmax: u32,
    /// Coefficient family the module was requested for.
    pub kind: SphereSpectralKernelKind,
    /// Output layout the module was requested for.
    pub layout: DeviceMatrixLayout,
}
414
/// Reports whether this build contains the sphere GPU code path, which is
/// only compiled for Linux targets.
pub const fn sphere_gpu_compiled() -> bool {
    const BUILT_FOR_LINUX: bool = cfg!(target_os = "linux");
    BUILT_FOR_LINUX
}
420
421#[must_use]
428pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
429 let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
430 let ld = ((n + 31) / 32) * 32;
431 let needed_bytes = ld
432 .saturating_mul(m)
433 .saturating_mul(std::mem::size_of::<f64>());
434 let budget = runtime.memory_budget_bytes;
435 n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
436 } else {
437 false
438 };
439 decide(
440 GpuKernel::SpatialKernelOperator,
441 gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
442 )
443}
444
445#[must_use]
451pub fn truncated_device_kind(
452 kernel: crate::basis::SphereWahbaKernel,
453) -> Option<(SphereSpectralKernelKind, u16)> {
454 use crate::basis::SphereWahbaKernel;
455 match kernel {
456 SphereWahbaKernel::SobolevTruncated { lmax } => {
457 Some((SphereSpectralKernelKind::Sobolev, lmax))
458 }
459 SphereWahbaKernel::PseudoTruncated { lmax } => {
460 Some((SphereSpectralKernelKind::Pseudo, lmax))
461 }
462 SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
463 }
464}
465
466pub fn try_build_truncated_kernel_matrix_gpu(
488 data: ArrayView2<'_, f64>,
489 centers: ArrayView2<'_, f64>,
490 penalty_order: usize,
491 radians: bool,
492 kernel: crate::basis::SphereWahbaKernel,
493) -> Option<Result<Array2<f64>, GpuError>> {
494 let (kind, lmax) = truncated_device_kind(kernel)?;
495 let n = data.nrows();
496 let m = centers.nrows();
497 if n == 0 || m == 0 || lmax == 0 {
498 return None;
499 }
500 let decision = sphere_kernel_decision(n, m, lmax as usize);
501 if !decision.use_gpu {
502 return None;
505 }
506 Some(build_truncated_kernel_matrix_gpu_admitted(
508 data,
509 centers,
510 penalty_order,
511 radians,
512 kind,
513 lmax,
514 ))
515}
516
517fn build_truncated_kernel_matrix_gpu_admitted(
521 data: ArrayView2<'_, f64>,
522 centers: ArrayView2<'_, f64>,
523 penalty_order: usize,
524 radians: bool,
525 kind: SphereSpectralKernelKind,
526 lmax: u16,
527) -> Result<Array2<f64>, GpuError> {
528 let n = data.nrows();
529 let m = centers.nrows();
530 let data_xyz = latlon_to_xyz_host(data, radians)
531 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
532 let centers_xyz = latlon_to_xyz_host(centers, radians)
533 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
534 let coeffs = kind.coefficients(lmax as usize, penalty_order);
538 let inputs = S2KernelBuildInputs {
539 n,
540 m,
541 lmax: lmax as usize,
542 data_xyz: &data_xyz,
543 centers_xyz: ¢ers_xyz,
544 coeffs: &coeffs,
545 kind,
546 layout: DeviceMatrixLayout::ColumnMajor,
547 };
548 let device_matrix = build_kernel_matrix_device(inputs)?;
549 let out = device_matrix.to_host_array()?;
550 if out.iter().any(|v| !v.is_finite()) {
551 return Err(GpuError::DriverCallFailed {
552 reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
553 });
554 }
555 Ok(out)
556}
557
/// CUDA state owned by the sphere backend on Linux builds.
#[cfg(target_os = "linux")]
struct SphereGpuContext {
    /// Context used to load compiled modules.
    ctx: Arc<CudaContext>,
    /// Stream on which transfers and kernel launches are issued.
    stream: Arc<CudaStream>,
    /// Compiled modules, keyed by [`S2ModuleCacheKey`].
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Major compute capability reported by the backend probe.
    cc_major: i32,
    /// Minor compute capability reported by the backend probe.
    cc_minor: i32,
}
566
/// Process-wide sphere GPU backend; obtain it through
/// [`SphereGpuBackend::probe`].
pub struct SphereGpuBackend {
    /// CUDA context, stream and module cache; only present on Linux builds.
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
573
574impl SphereGpuBackend {
575 pub fn probe() -> Result<&'static Self, GpuError> {
577 static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
578 BACKEND
579 .get_or_init(|| {
580 #[cfg(target_os = "linux")]
581 {
582 Self::probe_linux()
583 }
584 #[cfg(not(target_os = "linux"))]
585 {
586 Err(GpuError::DriverLibraryUnavailable {
587 reason: "sphere GPU backend is Linux-only".to_string(),
588 })
589 }
590 })
591 .as_ref()
592 .map_err(GpuError::clone)
593 }
594
595 #[cfg(target_os = "linux")]
596 fn probe_linux() -> Result<Self, GpuError> {
597 let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
598 Ok(SphereGpuBackend {
599 inner: SphereGpuContext {
600 ctx: parts.ctx,
601 stream: parts.stream,
602 modules: Mutex::new(HashMap::new()),
603 cc_major: parts.capability.compute_major,
604 cc_minor: parts.capability.compute_minor,
605 },
606 })
607 }
608
609 #[cfg(target_os = "linux")]
612 fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
613 if let Ok(guard) = self.inner.modules.lock() {
614 if let Some(existing) = guard.get(&key) {
615 return Ok(existing.clone());
616 }
617 }
618 let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
625 let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
626 format!(
627 "sphere NVRTC compile (kind={}, lmax={}): {err}",
628 key.kind.tag(),
629 key.lmax
630 )
631 })?;
632 let module = self
633 .inner
634 .ctx
635 .load_module(ptx)
636 .gpu_ctx("sphere module load")?;
637 if let Ok(mut guard) = self.inner.modules.lock() {
638 guard.entry(key).or_insert_with(|| module.clone());
639 }
640 Ok(module)
641 }
642
643 #[cfg(target_os = "linux")]
644 fn cc(&self) -> (i32, i32) {
645 (self.inner.cc_major, self.inner.cc_minor)
646 }
647}
648
649pub fn build_kernel_matrix_device(
656 inputs: S2KernelBuildInputs<'_>,
657) -> Result<DeviceS2KernelMatrix, GpuError> {
658 inputs.validate()?;
659
660 #[cfg(target_os = "linux")]
661 {
662 use cudarc::driver::{LaunchConfig, PushKernelArg};
663 let backend = SphereGpuBackend::probe()?;
664 let (cc_major, cc_minor) = backend.cc();
665 let key = S2ModuleCacheKey {
666 cc_major,
667 cc_minor,
668 lmax: inputs.lmax as u32,
669 kind: inputs.kind,
670 layout: inputs.layout,
671 };
672 let module = backend.module_for(key)?;
673 let func = module
674 .load_function("s2_wahba_legendre_colmajor")
675 .gpu_ctx("sphere load_function raw")?;
676 let stream = backend.inner.stream.clone();
677
678 let data_dev = stream
679 .clone_htod(inputs.data_xyz)
680 .gpu_ctx("sphere htod data_xyz")?;
681 let centers_dev = stream
682 .clone_htod(inputs.centers_xyz)
683 .gpu_ctx("sphere htod centers_xyz")?;
684 let coeffs_dev = stream
685 .clone_htod(inputs.coeffs)
686 .gpu_ctx("sphere htod coeffs")?;
687
688 let n = inputs.n;
689 let m = inputs.m;
690 let ld = ((n + 31) / 32) * 32;
691 let mut out_dev = stream
692 .alloc_zeros::<f64>(ld * m)
693 .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
694
695 let block_x: u32 = 32;
697 let block_y: u32 = 8;
698 let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
699 let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
700 let cfg = LaunchConfig {
701 grid_dim: (grid_x, grid_y, 1),
702 block_dim: (block_x, block_y, 1),
703 shared_mem_bytes: 0,
704 };
705 let n_i32: i32 =
706 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
707 let m_i32: i32 =
708 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
709 let ld_i64: i64 = ld as i64;
710
711 let mut builder = stream.launch_builder(&func);
712 builder
713 .arg(&data_dev)
714 .arg(¢ers_dev)
715 .arg(&coeffs_dev)
716 .arg(&n_i32)
717 .arg(&m_i32)
718 .arg(&ld_i64)
719 .arg(&mut out_dev);
720 unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
725 stream
726 .synchronize()
727 .gpu_ctx("sphere raw kernel synchronize")?;
728
729 Ok(DeviceS2KernelMatrix {
730 rows: n,
731 cols: m,
732 ld,
733 col_major_dev: out_dev,
734 stream,
735 })
736 }
737
738 #[cfg(not(target_os = "linux"))]
739 {
740 Err(GpuError::DriverLibraryUnavailable {
741 reason: "sphere GPU backend is Linux-only".to_string(),
742 })
743 }
744}
745
746pub fn build_householder_constrained_design_device(
750 inputs: S2KernelBuildInputs<'_>,
751 v: &[f64],
752 beta: f64,
753) -> Result<DeviceS2KernelMatrix, GpuError> {
754 inputs.validate()?;
755 if v.len() != inputs.m {
756 gam_gpu::gpu_bail!(
757 "build_householder_constrained_design_device: v.len()={} != m={}",
758 v.len(),
759 inputs.m
760 );
761 }
762 if inputs.m < 2 {
763 gam_gpu::gpu_bail!(
764 "build_householder_constrained_design_device: m must be >= 2 (got {})",
765 inputs.m
766 );
767 }
768 if !beta.is_finite() {
769 gam_gpu::gpu_bail!(
770 "build_householder_constrained_design_device: beta must be finite (got {beta})"
771 );
772 }
773
774 #[cfg(target_os = "linux")]
775 {
776 use cudarc::driver::{LaunchConfig, PushKernelArg};
777 let backend = SphereGpuBackend::probe()?;
778 let (cc_major, cc_minor) = backend.cc();
779 let key = S2ModuleCacheKey {
780 cc_major,
781 cc_minor,
782 lmax: inputs.lmax as u32,
783 kind: inputs.kind,
784 layout: inputs.layout,
785 };
786 let module = backend.module_for(key)?;
787 let func = module
788 .load_function("s2_wahba_householder_constrained_colmajor")
789 .gpu_ctx("sphere load_function householder")?;
790 let stream = backend.inner.stream.clone();
791
792 let data_dev = stream
793 .clone_htod(inputs.data_xyz)
794 .gpu_ctx("sphere-hh htod data_xyz")?;
795 let centers_dev = stream
796 .clone_htod(inputs.centers_xyz)
797 .gpu_ctx("sphere-hh htod centers_xyz")?;
798 let coeffs_dev = stream
799 .clone_htod(inputs.coeffs)
800 .gpu_ctx("sphere-hh htod coeffs")?;
801 let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
802
803 let n = inputs.n;
804 let m = inputs.m;
805 let cols_out = m - 1;
806 let ld_out = ((n + 31) / 32) * 32;
807 let mut out_dev = stream
808 .alloc_zeros::<f64>(ld_out * cols_out)
809 .gpu_ctx_with(|err| {
810 format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
811 })?;
812
813 let block_x: u32 = 128;
814 let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
815 let cfg = LaunchConfig {
816 grid_dim: (grid_x, 1, 1),
817 block_dim: (block_x, 1, 1),
818 shared_mem_bytes: 0,
819 };
820 let n_i32: i32 =
821 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
822 let m_i32: i32 =
823 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
824 let ld_out_i64: i64 = ld_out as i64;
825
826 let mut builder = stream.launch_builder(&func);
827 builder
828 .arg(&data_dev)
829 .arg(¢ers_dev)
830 .arg(&coeffs_dev)
831 .arg(&v_dev)
832 .arg(&beta)
833 .arg(&n_i32)
834 .arg(&m_i32)
835 .arg(&ld_out_i64)
836 .arg(&mut out_dev);
837 unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
840 stream
841 .synchronize()
842 .gpu_ctx("sphere-hh kernel synchronize")?;
843
844 Ok(DeviceS2KernelMatrix {
845 rows: n,
846 cols: cols_out,
847 ld: ld_out,
848 col_major_dev: out_dev,
849 stream,
850 })
851 }
852
853 #[cfg(not(target_os = "linux"))]
854 {
855 Err(GpuError::DriverLibraryUnavailable {
856 reason: "sphere GPU backend is Linux-only".to_string(),
857 })
858 }
859}
860
/// Builds the Householder reflector `H = I − beta · v · vᵀ` that maps `w`
/// onto a multiple of the first unit vector.
///
/// Returns `(v, beta)` with `v` normalised so that `v[0] == 1`. For an empty
/// or all-zero `w` the identity reflector is returned: a zero vector of the
/// same length and `beta == 0`.
///
/// The weights are divided by their largest magnitude before the norm is
/// taken. The reflector does not depend on the scale of `w`, and rescaling
/// keeps the sum of squares from underflowing to zero for tiny weights or
/// overflowing to infinity for huge ones. Non-finite weights yield
/// non-finite output.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let m = w.len();
    // Covers both the empty slice and the exact zero vector.
    if w.iter().all(|&x| x == 0.0) {
        return (vec![0.0; m], 0.0);
    }

    let scale = w.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    let mut v: Vec<f64> = w.iter().map(|&x| x / scale).collect();

    // After scaling, the largest entry has magnitude 1, so `norm >= 1`.
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    // Give sigma the sign of the leading entry so the update below adds
    // magnitudes and never cancels.
    let sigma = if v[0] >= 0.0 { norm } else { -norm };
    v[0] += sigma;

    // `|v[0]| >= norm >= 1` here, so the pivot is never zero for finite input.
    let pivot = v[0];
    for entry in v.iter_mut() {
        *entry /= pivot;
    }

    let vv: f64 = v.iter().map(|x| x * x).sum();
    (v, 2.0 / vv)
}
898
899pub fn build_center_kernel_device(
918 centers_xyz: &[f64],
919 lmax: usize,
920 coeffs: &[f64],
921 kind: SphereSpectralKernelKind,
922) -> Result<DeviceS2KernelMatrix, GpuError> {
923 let m = centers_xyz.len() / 3;
924 if centers_xyz.len() != 3 * m {
925 return Err(GpuError::DriverCallFailed {
926 reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
927 });
928 }
929 let inputs = S2KernelBuildInputs {
930 n: m,
931 m,
932 lmax,
933 data_xyz: centers_xyz,
934 centers_xyz,
935 coeffs,
936 kind,
937 layout: DeviceMatrixLayout::ColumnMajor,
938 };
939 build_kernel_matrix_device(inputs)
940}
941
942pub fn constrained_penalty_host(
947 c: ArrayView2<'_, f64>,
948 w: &[f64],
949) -> Result<Array2<f64>, GpuError> {
950 let (m1, m2) = c.dim();
951 if m1 != m2 {
952 gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
953 }
954 let m = m1;
955 if w.len() != m {
956 gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
957 }
958 if m < 2 {
959 gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
960 }
961 let (v, beta) = householder_reflector_from_weights(w);
962
963 let mut u = vec![0.0_f64; m];
966 for i in 0..m {
967 let mut acc = 0.0_f64;
968 for j in 0..m {
969 acc += c[(i, j)] * v[j];
970 }
971 u[i] = acc;
972 }
973 let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
974 let mut hch = Array2::<f64>::zeros((m, m));
975 for i in 0..m {
976 for j in 0..m {
977 hch[(i, j)] =
978 c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
979 }
980 }
981 let mut s = Array2::<f64>::zeros((m - 1, m - 1));
983 for i in 0..(m - 1) {
984 for j in 0..(m - 1) {
985 s[(i, j)] = hch[(i + 1, j + 1)];
986 }
987 }
988 Ok(s)
989}
990
/// Result of the penalised least-squares solve.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Solved coefficient vector (length `p`).
    pub beta: Vec<f64>,
    /// Sum of squares of the trailing entries of the transformed augmented
    /// right-hand side.
    pub weighted_residual_ssq: f64,
    /// Twice the sum of `ln|R_kk|` over the diagonal of the triangular
    /// factor.
    pub log_det_hessian: f64,
}
1030
/// Solves the penalised least-squares problem for a device-resident design
/// by QR-factorising the augmented system `[X_s; R_s] · beta ≈ [wy; 0]`
/// with cuSOLVER/cuBLAS.
///
/// * `x_s_device` — `n × p` design, column-major with leading dimension `ld`.
/// * `wy` — right-hand side of length `n`.
/// * `r_s` — `p × p` block stacked underneath the design.
///
/// Steps: assemble the augmented matrix on the host, upload it, call
/// `geqrf`, apply the transposed factor to the right-hand side with `ormqr`,
/// solve the upper-triangular system with `trsm`, then read back the
/// solution and the factor's diagonal.
///
/// # Errors
/// Fails on a shape mismatch, when a dimension does not fit `i32`, or when
/// any CUDA / cuSOLVER / cuBLAS call reports a non-success status.
///
/// NOTE(review): the device `info` buffer is handed to `geqrf` and `ormqr`
/// but is never copied back or inspected here — confirm that is intended.
#[cfg(target_os = "linux")]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
    use cudarc::driver::DevicePtrMut;

    let n = x_s_device.rows;
    let p = x_s_device.cols;
    if wy.len() != n {
        gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
    }
    if r_s.dim() != (p, p) {
        gam_gpu::gpu_bail!(
            "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
            r_s.dim()
        );
    }
    // No columns: nothing to solve, the residual is the right-hand side.
    if p == 0 {
        return Ok(PenalisedLsSolution {
            beta: Vec::new(),
            weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
            log_det_hessian: 0.0,
        });
    }

    let stream = x_s_device.stream.clone();
    // Rows of the augmented system: n design rows followed by p rows of r_s.
    let n_aug = n + p;

    // Assemble A_aug on the host, column-major with leading dimension n_aug.
    let mut a_aug_host = vec![0.0_f64; n_aug * p];
    let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
    x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
    for j in 0..p {
        let src_off = j * x_s_device.ld;
        let dst_off = j * n_aug;
        // Top n rows of column j: the design column without its padding.
        a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
        // Bottom p rows of column j: column j of r_s.
        for i in 0..p {
            a_aug_host[dst_off + n + i] = r_s[(i, j)];
        }
    }
    let mut a_dev = stream
        .clone_htod(&a_aug_host)
        .gpu_ctx("solve_penalised_ls_device htod A_aug")?;

    // Augmented right-hand side: wy followed by p zeros.
    let mut b_host = vec![0.0_f64; n_aug];
    b_host[..n].copy_from_slice(wy);
    let mut b_dev = stream
        .clone_htod(&b_host)
        .gpu_ctx("solve_penalised_ls_device htod b_aug")?;

    let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
    let n_aug_i: i32 = i32::try_from(n_aug)
        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
    let p_i: i32 = i32::try_from(p)
        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;

    // Ask cuSOLVER how much workspace geqrf needs.
    let mut lwork: i32 = 0;
    {
        let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
        let status = unsafe {
            cusolver_sys::cusolverDnDgeqrf_bufferSize(
                solver.cu(),
                n_aug_i,
                p_i,
                a_ptr as *mut f64,
                n_aug_i,
                &mut lwork,
            )
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
        }
    }
    let lwork_us = usize::try_from(lwork)
        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
    // Allocate at least one element even when the reported size is zero.
    let mut workspace = stream
        .alloc_zeros::<f64>(lwork_us.max(1))
        .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
    let mut tau = stream
        .alloc_zeros::<f64>(p)
        .gpu_ctx("solve_penalised_ls_device alloc tau")?;
    let mut info = stream
        .alloc_zeros::<i32>(1)
        .gpu_ctx("solve_penalised_ls_device alloc info")?;

    // QR factorisation of A_aug; a_dev and tau are overwritten with the
    // factor data.
    {
        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
        let status = unsafe {
            cusolver_sys::cusolverDnDgeqrf(
                solver.cu(),
                n_aug_i,
                p_i,
                a_ptr as *mut f64,
                n_aug_i,
                tau_ptr as *mut f64,
                work_ptr as *mut f64,
                lwork,
                info_ptr as *mut i32,
            )
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
        }
    }

    // Workspace query for ormqr (left side, transposed, one right-hand side).
    let mut ormqr_lwork: i32 = 0;
    {
        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
        let status = unsafe {
            cusolver_sys::cusolverDnDormqr_bufferSize(
                solver.cu(),
                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
                n_aug_i,
                1,
                p_i,
                a_ptr as *const f64,
                n_aug_i,
                tau_ptr as *const f64,
                b_ptr as *mut f64,
                n_aug_i,
                &mut ormqr_lwork,
            )
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
        }
    }
    // Grow the workspace only when ormqr needs more than geqrf did.
    if ormqr_lwork > lwork {
        workspace = stream
            .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
            .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
    }
    // Apply the transposed orthogonal factor to the right-hand side in place.
    {
        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
        let status = unsafe {
            cusolver_sys::cusolverDnDormqr(
                solver.cu(),
                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
                n_aug_i,
                1,
                p_i,
                a_ptr as *const f64,
                n_aug_i,
                tau_ptr as *const f64,
                b_ptr as *mut f64,
                n_aug_i,
                work_ptr as *mut f64,
                // Matches the size of whichever workspace is live.
                ormqr_lwork.max(lwork),
                info_ptr as *mut i32,
            )
        };
        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
        }
    }

    // Solve the upper-triangular p × p system against the first p entries of
    // the transformed right-hand side.
    {
        use cudarc::cublas::CudaBlas;
        let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
        let alpha = 1.0_f64;
        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
        let handle = *blas.handle();
        let status = unsafe {
            cudarc::cublas::sys::cublasDtrsm_v2(
                handle,
                cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
                cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
                cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
                cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
                p_i,
                1,
                &alpha,
                a_ptr as *const f64,
                n_aug_i,
                b_ptr as *mut f64,
                n_aug_i,
            )
        };
        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
            gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
        }
    }

    // Bring the solved right-hand side and the factored matrix back.
    let mut b_out = vec![0.0_f64; n_aug];
    stream
        .memcpy_dtoh(&b_dev, &mut b_out)
        .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
    let mut a_back = vec![0.0_f64; n_aug * p];
    stream
        .memcpy_dtoh(&a_dev, &mut a_back)
        .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
    stream
        .synchronize()
        .gpu_ctx("solve_penalised_ls_device synchronize")?;

    // The first p entries hold the solution; the remaining entries were not
    // touched by the triangular solve.
    let beta: Vec<f64> = b_out[..p].to_vec();
    // NOTE(review): presumably the squared residual of the augmented system
    // (data misfit plus penalty term) — confirm against the caller.
    let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();

    // Sum ln|R_kk| over the diagonal of the column-major factor.
    // NOTE(review): a zero diagonal entry makes this -inf; no rank check is
    // performed here.
    let mut log_abs_r = 0.0_f64;
    for k in 0..p {
        let r_kk = a_back[k * n_aug + k];
        log_abs_r += r_kk.abs().ln();
    }
    let log_det_hessian = 2.0 * log_abs_r;

    Ok(PenalisedLsSolution {
        beta,
        weighted_residual_ssq: augmented_residual_ssq,
        log_det_hessian,
    })
}
1301
/// Non-Linux stand-in for the cuSOLVER solve.
///
/// # Errors
/// Always returns `GpuError::DriverLibraryUnavailable`, with the problem
/// dimensions included in the message.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let (n, p) = (x_s_device.rows, x_s_device.cols);
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        n,
        p,
        wy.len(),
        r_s.dim()
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1318
1319#[cfg(test)]
1324mod sphere_gpu_tests {
1325 use super::*;
1326 use crate::basis::{
1327 SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1328 spherical_wahba_kernel_matrix_with_kind,
1329 };
1330 use ndarray::Array2;
1331
1332 fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1333 let mut rows = Vec::with_capacity(n_lat * n_lon);
1335 for i in 0..n_lat {
1336 let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1337 for j in 0..n_lon {
1338 let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1339 rows.push(lat);
1340 rows.push(lon);
1341 }
1342 }
1343 Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1344 }
1345
/// Degree-valued lat/lon rows must map to unit-norm xyz triples, and the
/// first three rows must land on the +x, +z and +y axes respectively.
#[test]
fn xyz_preprocessing_matches_unit_sphere() {
    let latlon = ndarray::array![
        [0.0, 0.0],
        [90.0, 0.0],
        [0.0, 90.0],
        [-90.0, 17.5],
        [45.0, -120.0],
    ];
    let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
    assert_eq!(xyz.len(), 3 * 5);

    // Every output triple has unit Euclidean norm.
    for (i, p) in xyz.chunks_exact(3).enumerate() {
        let nrm2 = p[0] * p[0] + p[1] * p[1] + p[2] * p[2];
        assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
    }

    // Expected coordinates for the three axis-aligned inputs, row by row.
    let axis_rows: [[f64; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]];
    for (got, want) in xyz.chunks_exact(3).zip(axis_rows.iter()) {
        for (g, w) in got.iter().zip(want) {
            assert!((g - w).abs() < 1e-15);
        }
    }
}
1376
/// Evaluating the truncated series at argument `1.0` must reproduce the plain
/// sum of the coefficients (the failure message calls this the `K(x,x)`
/// identity). Checked for penalty orders 1..=4 and three truncation levels.
#[test]
fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
    const LMAX_GRID: [usize; 3] = [5, 20, 50];
    for m_penalty in 1..=4 {
        for lmax in LMAX_GRID {
            let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
            let got = sphere_truncated_spectral_eval(1.0, &coeffs);
            let expected = coeffs.iter().sum::<f64>();
            assert!(
                (got - expected).abs() < 1e-13,
                "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
            );
        }
    }
}
1394
/// Evaluating the truncated series at argument `-1.0` must reproduce the
/// alternating sum of the coefficients: even-indexed terms enter with `+`,
/// odd-indexed terms with `-`. Checked for penalty orders 1..=4 and three
/// truncation levels.
#[test]
fn truncated_spectral_at_antipode_matches_alternating_sum() {
    const LMAX_GRID: [usize; 3] = [5, 20, 50];
    for m_penalty in 1..=4 {
        for lmax in LMAX_GRID {
            let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
            // Signs cycle +1, -1, +1, ... starting at index 0.
            let expected = coeffs
                .iter()
                .zip([1.0_f64, -1.0].iter().cycle())
                .map(|(c, sign)| sign * c)
                .sum::<f64>();
            let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
            assert!(
                (got - expected).abs() < 1e-13,
                "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
            );
        }
    }
}
1415
/// The centers-by-centers kernel matrix must be symmetric to 1e-13 for
/// several penalty orders and truncation levels.
///
/// Every compared pair is also required to be finite: a NaN difference
/// compares false against any threshold, so without that guard a NaN-filled
/// matrix would pass the symmetry bound.
#[test]
fn truncated_spectral_matrix_is_symmetric() {
    let centers = ndarray::array![
        [10.0_f64, 20.0],
        [-30.0, 100.0],
        [45.0, -60.0],
        [-89.0, 0.0],
        [0.0, 180.0],
        [60.0, -179.9],
    ];
    let n = centers.nrows();
    for m_penalty in [1usize, 2, 4] {
        // Typed as `u16` up front so no narrowing cast is needed below.
        for lmax in [10_u16, 30] {
            let mat = spherical_wahba_kernel_matrix_with_kind(
                centers.view(),
                centers.view(),
                m_penalty,
                false,
                SphereWahbaKernel::SobolevTruncated { lmax },
            )
            .expect("kernel matrix");

            let mut max_asym = 0.0_f64;
            for i in 0..n {
                // |K[i,j] - K[j,i]| is itself symmetric in (i, j); the
                // diagonal and upper triangle cover every entry.
                for j in i..n {
                    let d = (mat[(i, j)] - mat[(j, i)]).abs();
                    assert!(
                        d.is_finite(),
                        "K has a non-finite entry at m={m_penalty}, L={lmax}, (i, j)=({i}, {j})"
                    );
                    max_asym = max_asym.max(d);
                }
            }
            assert!(
                max_asym < 1e-13,
                "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
            );
        }
    }
}
1456
/// Structural checks on the Sobolev coefficient vector at `lmax = 50`:
/// length 51, an exactly-zero entry at index 0, a positive entry at index 1,
/// and every later entry below its predecessor plus a 1e-15 slack.
#[test]
fn truncated_coefficients_have_zero_constant_mode() {
    for m in 1..=4 {
        let c = sobolev_s2_truncated_coefficients(50, m);
        assert_eq!(c.len(), 51);
        assert_eq!(c[0], 0.0);
        assert!(c[1] > 0.0);
        // Adjacent pairs starting at (c[1], c[2]); `ell` is the index of the
        // second element of each pair.
        for (offset, pair) in c[1..].windows(2).enumerate() {
            let ell = offset + 2;
            let (prev, cur) = (pair[0], pair[1]);
            assert!(
                cur < prev + 1e-15,
                "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
                cur,
                prev
            );
        }
    }
}
1475
/// A 1×1 kernel matrix from the matrix helper must equal the scalar evaluator
/// applied to the dot product of the two points' unit vectors.
#[test]
fn truncated_spectral_matches_matrix_helper() {
    let m_penalty = 2;
    let lmax = 20;
    let data = ndarray::array![[12.5, -34.0]];
    let centers = ndarray::array![[40.0, 10.0]];

    // Scalar route: unit vectors -> cosine of the separation -> series value.
    let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
    let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
    let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
    let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
    let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);

    // Matrix route through the helper under test.
    let kind = SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 };
    let mat = spherical_wahba_kernel_matrix_with_kind(
        data.view(),
        centers.view(),
        m_penalty,
        false,
        kind,
    )
    .expect("kernel matrix");

    let got = mat[(0, 0)];
    assert!(
        (got - expected).abs() < 1e-13,
        "matrix helper differs from scalar evaluator: {} vs {}",
        got,
        expected
    );
}
1506
/// The constrained penalty built from a 6×6 exponential-decay matrix and unit
/// weights must be 5×5 and symmetric to 1e-13.
///
/// NOTE(review): despite the test name, the only check beyond shape and
/// symmetry is that `S · 1` is finite.
#[test]
fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
    let m = 6_usize;
    // c[i, j] = exp(-|i - j| / 2)
    let c = Array2::<f64>::from_shape_fn((m, m), |(i, j)| {
        let d = (i as f64 - j as f64).abs();
        (-0.5 * d).exp()
    });
    let w = vec![1.0_f64; m];
    let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
    let reduced = m - 1;
    assert_eq!(s.dim(), (reduced, reduced));

    let max_asym = (0..reduced)
        .flat_map(|i| (0..reduced).map(move |j| (i, j)))
        .map(|(i, j)| (s[(i, j)] - s[(j, i)]).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_asym < 1e-13,
        "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
    );

    let ones = ndarray::Array1::<f64>::ones(reduced);
    let sx = s.dot(&ones);
    assert!(sx.iter().all(|v| v.is_finite()));
}
1550
/// Applying `H = I - beta · v · vᵀ` to the weight vector it was built from
/// must leave only the first component non-zero.
#[test]
fn householder_reflector_zeroes_target_vector() {
    let w = vec![3.0, 4.0, 0.0, -1.0];
    let (v, beta) = householder_reflector_from_weights(&w);

    // H · w = w - (beta · vᵀw) · v
    let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
    let scale = beta * dot;
    let hw: Vec<f64> = w.iter().zip(&v).map(|(wj, vj)| wj - scale * vj).collect();

    assert!(
        hw[1..].iter().all(|entry| entry.abs() < 1e-12),
        "H · w not e_1 multiple: {hw:?}"
    );
    assert!(hw[0].abs() > 0.0);
}
1568
/// The device-built raw kernel matrix must agree entry-wise (to 1e-11) with
/// the CPU implementation on a small grid.
///
/// Returns early when no CUDA runtime is available; once a runtime is
/// present the backend probe is required to succeed.
#[test]
fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
    // `spherical_wahba_kernel_matrix_with_kind` is the entry point the
    // end-to-end dispatch test in this module treats as GPU-eligible, so the
    // reference here comes from the explicit CPU implementation instead.
    use crate::basis::spherical_wahba_kernel_matrix_cpu;

    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
        return;
    };
    SphereGpuBackend::probe()
        .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");

    let data_ll = small_latlon_grid(7, 9);
    let centers_ll = small_latlon_grid(5, 7);
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let n = data_ll.nrows();
    let m = centers_ll.nrows();
    let penalty = 2usize;
    let lmax_u16: u16 = 20;
    let lmax = usize::from(lmax_u16);
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);

    let inputs = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
    let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");

    let cpu = spherical_wahba_kernel_matrix_cpu(
        data_ll.view(),
        centers_ll.view(),
        penalty,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 },
    )
    .expect("cpu kernel matrix");

    let mut max_abs = 0.0_f64;
    for i in 0..n {
        for j in 0..m {
            let (g, c) = (gpu[(i, j)], cpu[(i, j)]);
            let d = (g - c).abs();
            // A NaN difference compares false against every threshold and
            // would otherwise leave `max_abs` untouched.
            assert!(
                d.is_finite(),
                "non-finite kernel entry at ({i}, {j}): gpu={g}, cpu={c}"
            );
            max_abs = max_abs.max(d);
        }
    }
    assert!(
        max_abs < 1e-11,
        "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
    );
}
1628
/// End-to-end check of the production dispatch path at a size where the
/// dispatch decision is asserted to choose the GPU.
///
/// Steps:
/// 1. assert that `sphere_kernel_decision` selects the GPU for this shape;
/// 2. compare the production kernel build against the explicit CPU build,
///    requiring a maximum relative difference below 1e-9;
/// 3. build the full spherical spline basis and require a finite dense design
///    with one row per data point.
///
/// Returns early when no CUDA runtime is available; once a runtime is
/// present the backend probe is required to succeed.
#[test]
fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
    use crate::basis::{
        CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
        build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
    };

    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
        return;
    };
    SphereGpuBackend::probe()
        .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");

    // Single source for the center count used by both the direct kernel
    // comparison and the basis spec below.
    let num_centers = 200usize;
    let data = small_latlon_grid(100, 100);
    let lmax: u16 = 30;
    let penalty_order = 2usize;
    let centers =
        crate::basis::select_spherical_farthest_point_centers(data.view(), num_centers, false)
            .expect("centers");
    let n = data.nrows();
    let m = centers.nrows();

    let decision = sphere_kernel_decision(n, m, usize::from(lmax));
    assert!(
        decision.use_gpu,
        "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
         (reason={}); the engagement gate regressed",
        decision.reason
    );

    let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
        data.view(),
        centers.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax },
    )
    .expect("GPU-eligible production kernel build succeeds");

    let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
        data.view(),
        centers.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax },
    )
    .expect("cpu oracle kernel build succeeds");

    assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
    let mut max_abs = 0.0_f64;
    let mut max_rel = 0.0_f64;
    for (idx, (g, c)) in gpu_kernel.iter().zip(cpu_kernel.iter()).enumerate() {
        let d = (g - c).abs();
        // NaN compares false against every threshold; reject it explicitly so
        // a NaN-filled result cannot pass the parity bound.
        assert!(
            d.is_finite(),
            "non-finite kernel entry at flat index {idx}: gpu={g}, cpu={c}"
        );
        max_abs = max_abs.max(d);
        // Floor the denominator so exact zeros on both sides give ratio 0.
        let denom = g.abs().max(c.abs()).max(1e-300);
        max_rel = max_rel.max(d / denom);
    }
    assert!(
        max_rel < 1e-9,
        "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
         >= 1e-9 (abs {max_abs:.3e})"
    );

    let spec_gpu = SphericalSplineBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers },
        penalty_order,
        double_penalty: false,
        radians: false,
        method: SphereMethod::Wahba,
        max_degree: None,
        wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
        identifiability: SphericalSplineIdentifiability::CenterSumToZero,
    };
    let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
        .expect("GPU-eligible build_spherical_spline_basis succeeds");
    let design = result_gpu.design.as_dense().expect("dense design");
    assert_eq!(design.nrows(), n, "design row count must match data rows");
    assert!(
        design.iter().all(|v| v.is_finite()),
        "engaged-device spherical design must be finite"
    );
}
1742
/// The fused device Householder-constrained design must match a host
/// reconstruction built from the raw device kernel matrix, to 1e-12.
///
/// Host reference: with `d_i = Σ_j v[j] · B[i, j]`, output column `j_out`
/// is `B[i, j_out + 1] - beta · d_i · v[j_out + 1]`, i.e. columns `1..m` of
/// `B · (I - beta · v · vᵀ)`.
///
/// Returns early when no CUDA runtime is available; once a runtime is
/// present the backend probe is required to succeed.
#[test]
fn sphere_gpu_householder_parity_vs_raw_dot_z() {
    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
        return;
    };
    SphereGpuBackend::probe()
        .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");

    let data_ll = small_latlon_grid(6, 8);
    let centers_ll = small_latlon_grid(4, 5);
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let n = data_ll.nrows();
    let m = centers_ll.nrows();
    let penalty = 2usize;
    let lmax = 15usize;
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);

    let inputs_raw = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    // The inputs are needed again for the fused build below, hence the clone.
    let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
    let b = b_dev.to_host_array().expect("dtoh raw");

    let w = vec![1.0_f64; m];
    let (v, beta) = householder_reflector_from_weights(&w);

    let mut xs_host = Array2::<f64>::zeros((n, m - 1));
    for i in 0..n {
        let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
        for j_out in 0..(m - 1) {
            xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
        }
    }

    let xs_dev =
        build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
    let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");

    let mut max_abs = 0.0_f64;
    for i in 0..n {
        for j in 0..(m - 1) {
            let (h, g) = (xs_host[(i, j)], xs_gpu[(i, j)]);
            let d = (h - g).abs();
            // A NaN difference compares false against every threshold and
            // would otherwise leave `max_abs` untouched.
            assert!(
                d.is_finite(),
                "non-finite constrained-design entry at ({i}, {j}): host={h}, gpu={g}"
            );
            max_abs = max_abs.max(d);
        }
    }
    assert!(
        max_abs < 1e-12,
        "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
    );
}
1811
/// Wall-clock comparison of the device kernel-matrix build (including the
/// copy back to the host) against the CPU build at n = 200 000, m = 200,
/// L = 50. Requires the CPU time to be at least 20× the GPU time.
///
/// One untimed device build runs first as a warm-up. Returns early when no
/// CUDA runtime is available or the backend probe fails.
#[test]
fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
    // The baseline must be the explicit CPU implementation:
    // `spherical_wahba_kernel_matrix_with_kind` is the entry point the
    // end-to-end dispatch test in this module treats as GPU-eligible, and
    // this problem is far larger than the one that test asserts goes to the
    // GPU, so timing it would not measure the CPU.
    use crate::basis::spherical_wahba_kernel_matrix_cpu;

    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
        return;
    }

    let n_lat = 500usize;
    let n_lon = 400usize;
    assert_eq!(n_lat * n_lon, 200_000);
    let data_ll = small_latlon_grid(n_lat, n_lon);
    let m = 200usize;
    let centers_ll =
        crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
            .expect("centers");
    let n = data_ll.nrows();
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let penalty_order = 2usize;
    let lmax_u16: u16 = 50;
    let lmax = usize::from(lmax_u16);
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);

    let inputs_warm = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    // Untimed warm-up build; the result is discarded.
    drop(build_kernel_matrix_device(inputs_warm.clone()).expect("warmup"));

    // Timed GPU section: device build plus device-to-host copy.
    let t0 = std::time::Instant::now();
    let dev = build_kernel_matrix_device(inputs_warm).expect("gpu kernel matrix");
    let _host_gpu = dev.to_host_array().expect("dtoh");
    let gpu_secs = t0.elapsed().as_secs_f64();

    // Timed CPU section.
    let t1 = std::time::Instant::now();
    let _cpu = spherical_wahba_kernel_matrix_cpu(
        data_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 },
    )
    .expect("cpu kernel matrix");
    let cpu_secs = t1.elapsed().as_secs_f64();

    // Floor the denominator so a sub-nanosecond reading cannot divide by zero.
    let ratio = cpu_secs / gpu_secs.max(1e-9);
    eprintln!(
        "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
    );
    assert!(
        ratio >= 20.0,
        "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
         n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
    );
}
1884
/// Wall-clock comparison of the full spherical spline basis build against a
/// CPU baseline (raw kernel matrix times an identity transform) at
/// n = 200 000, m = 200, L = 50. Requires the CPU time to be at least 10× the
/// time of the basis build.
///
/// One untimed basis build runs first as a warm-up. Center selection for the
/// CPU baseline happens before its timer starts. Returns early when no CUDA
/// runtime is available or the backend probe fails.
#[test]
fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
        return;
    }
    // The baseline must be the explicit CPU implementation:
    // `spherical_wahba_kernel_matrix_with_kind` is the entry point the
    // end-to-end dispatch test in this module treats as GPU-eligible, so
    // timing it at this size would not measure the CPU.
    use crate::basis::{
        CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
        build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
    };

    let n_lat = 500usize;
    let n_lon = 400usize;
    let data_ll = small_latlon_grid(n_lat, n_lon);
    let m: usize = 200;
    let lmax: u16 = 50;
    let penalty_order = 2usize;
    let spec_gpu = SphericalSplineBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
        penalty_order,
        double_penalty: false,
        radians: false,
        method: SphereMethod::Wahba,
        max_degree: None,
        wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
        identifiability: SphericalSplineIdentifiability::CenterSumToZero,
    };

    // Untimed warm-up build; the result is discarded.
    drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));

    // Timed production build.
    let t0 = std::time::Instant::now();
    drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
    let gpu_secs = t0.elapsed().as_secs_f64();

    // CPU baseline inputs are prepared outside the timed region.
    let centers =
        crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
            .expect("centers");
    let z = Array2::<f64>::eye(centers.nrows());

    // Timed CPU section: raw kernel matrix followed by the dense transform.
    let t1 = std::time::Instant::now();
    let raw_cpu = spherical_wahba_kernel_matrix_cpu(
        data_ll.view(),
        centers.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax },
    )
    .expect("cpu raw");
    let _design_cpu = raw_cpu.dot(&z);
    let cpu_secs = t1.elapsed().as_secs_f64();

    // Floor the denominator so a sub-nanosecond reading cannot divide by zero.
    let ratio = cpu_secs / gpu_secs.max(1e-9);
    eprintln!(
        "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
        data_ll.nrows()
    );
    assert!(
        ratio >= 10.0,
        "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
         cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
    );
}
1960
/// Penalised least-squares fits built from the CPU design and from the
/// device-built design must agree: coefficients and fitted values to 1e-9.
///
/// Setup: a 25×40 lat/lon grid (n = 1000), m = 80 farthest-point centers,
/// L = 15, penalty order 2, λ = 1e-3. The transform `z` is the identity, so
/// the penalty is the centers-by-centers kernel matrix and each design is the
/// raw kernel matrix. Both fits solve `(XᵀX + λS) β = Xᵀy` by Cholesky with
/// the same `S` and the same synthetic response `y`.
///
/// Returns early when no CUDA runtime is available; once a runtime is
/// present the backend probe is required to succeed.
#[test]
fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
    // The CPU design comes from the explicit CPU implementation:
    // `spherical_wahba_kernel_matrix_with_kind` is the entry point the
    // end-to-end dispatch test in this module treats as GPU-eligible.
    use crate::basis::{
        select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_cpu,
        spherical_wahba_kernel_matrix_with_kind,
    };
    use faer::Side;
    use gam_linalg::faer_ndarray::FaerCholesky;

    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!(
            "[sphere gpu parity] no CUDA runtime — skipping device parity \
             (CPU oracle exercised by sibling tests)"
        );
        return;
    };
    SphereGpuBackend::probe()
        .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");

    let data_ll = small_latlon_grid(25, 40);
    assert_eq!(data_ll.nrows(), 1000);
    let n = data_ll.nrows();
    let m: usize = 80;
    let lmax_u16: u16 = 15;
    let lmax = usize::from(lmax_u16);
    let penalty_order: usize = 2;
    let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
    let lambda: f64 = 1.0e-3;

    let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
        .expect("farthest-point centers");
    assert_eq!(centers_ll.nrows(), m);

    // Identity transform: p == m and no columns are removed.
    let z = Array2::<f64>::eye(centers_ll.nrows());
    let p = z.ncols();
    assert_eq!(p, m);

    // Penalty shared by both fits.
    let k_cc = spherical_wahba_kernel_matrix_with_kind(
        centers_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        kernel,
    )
    .expect("centers×centers kernel");
    let s_full = z.t().dot(&k_cc).dot(&z);

    // CPU design.
    let raw_design_cpu = spherical_wahba_kernel_matrix_cpu(
        data_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        kernel,
    )
    .expect("CPU raw design");
    let x_s_cpu = raw_design_cpu.dot(&z);

    // Device design, copied back to the host.
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
    let inputs = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
    let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
    let x_s_gpu = raw_design_gpu.dot(&z);

    assert_eq!(x_s_cpu.dim(), (n, p));
    assert_eq!(x_s_gpu.dim(), (n, p));

    // Smooth synthetic response evaluated at the grid points (radians).
    let mut y = ndarray::Array1::<f64>::zeros(n);
    for i in 0..n {
        let lat_rad = data_ll[(i, 0)].to_radians();
        let lon_rad = data_ll[(i, 1)].to_radians();
        y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
            + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
    }

    // Solves (XᵀX + λS) β = Xᵀy via a lower Cholesky factorisation.
    let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
        let mut a = x_s.t().dot(x_s);
        for i in 0..p {
            for j in 0..p {
                a[(i, j)] += lambda * s_full[(i, j)];
            }
        }
        let rhs = x_s.t().dot(&y);
        let factor = a
            .cholesky(Side::Lower)
            .expect("penalised normal equations are SPD under λ > 0");
        factor.solvevec(&rhs)
    };

    let beta_cpu = solve_penalised(&x_s_cpu);
    let beta_gpu = solve_penalised(&x_s_gpu);
    assert_eq!(beta_cpu.len(), p);
    assert_eq!(beta_gpu.len(), p);

    let yhat_cpu = x_s_cpu.dot(&beta_cpu);
    let yhat_gpu = x_s_gpu.dot(&beta_gpu);

    // NaN compares false against every threshold, so each difference is
    // required to be finite before it enters the running maximum.
    let mut max_beta_delta = 0.0_f64;
    for k in 0..p {
        let d = (beta_cpu[k] - beta_gpu[k]).abs();
        assert!(
            d.is_finite(),
            "non-finite coefficient difference at k={k}: cpu={}, gpu={}",
            beta_cpu[k],
            beta_gpu[k]
        );
        max_beta_delta = max_beta_delta.max(d);
    }
    let mut max_fit_delta = 0.0_f64;
    for i in 0..n {
        let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
        assert!(
            d.is_finite(),
            "non-finite fitted-value difference at i={i}: cpu={}, gpu={}",
            yhat_cpu[i],
            yhat_gpu[i]
        );
        max_fit_delta = max_fit_delta.max(d);
    }

    eprintln!(
        "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
         max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
    );

    assert!(
        max_beta_delta <= 1.0e-9,
        "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > 1e-9"
    );
    assert!(
        max_fit_delta <= 1.0e-9,
        "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
    );
}
2143}