Skip to main content

gam_terms/basis/
sphere_gpu.rs

1//! GPU NVRTC Wahba intrinsic-S2 kernel matrix construction.
2//!
3//! This module owns the device-side construction of the Wahba reproducing
4//! kernel basis matrix on the 2-sphere using the **finite truncated
5//! spectral Legendre series**
6//!
7//! `K_L(γ) = Σ_{ℓ=1..L} c_ℓ · P_ℓ(cos γ)`,
8//!
9//! evaluated entry-by-entry against the 3-term Legendre recurrence kept
10//! in registers. The host CPU parity target is the matching
11//! `SphereWahbaKernel::SobolevTruncated { lmax }` /
12//! `SphereWahbaKernel::PseudoTruncated { lmax }` variant added to
13//! `src/terms/basis.rs` (single source: same recurrence, same c_ℓ).
14//!
15//! The device path evaluates the raw column-major kernel matrix with `f64`
16//! Legendre recurrence math. Host code owns centering, constraints, and solver
17//! assembly in `basis.rs`.
18
19use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
36/// Which truncated-spectral Wahba kernel to evaluate on device. Matches
37/// the CPU `SphereWahbaKernel::{SobolevTruncated, PseudoTruncated}` so
38/// parity tests are well-defined.
39#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
40pub enum SphereSpectralKernelKind {
41    /// `c_ℓ = (2ℓ+1) / (4π · [ℓ(ℓ+1)]^m)` — true `H^m(S²)` Sobolev RKHS.
42    Sobolev,
43    /// `c_ℓ = 2 / (4π · Π_{k=1..m+1}(ℓ + k))` — Wahba 1981 pseudo-spline.
44    Pseudo,
45}
46
47impl SphereSpectralKernelKind {
48    /// `c_0 = 0`, `c_ℓ = c_ℓ(m)` for `ℓ = 1..=lmax`. Returned vector has
49    /// length `lmax + 1` and is uploaded verbatim to constant/global
50    /// memory before kernel launch.
51    pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52        match self {
53            SphereSpectralKernelKind::Sobolev => {
54                crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55            }
56            SphereSpectralKernelKind::Pseudo => {
57                crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58            }
59        }
60    }
61
62    /// Stable string tag used in the NVRTC module cache key + logs.
63    pub const fn tag(self) -> &'static str {
64        match self {
65            SphereSpectralKernelKind::Sobolev => "sobolev",
66            SphereSpectralKernelKind::Pseudo => "pseudo",
67        }
68    }
69}
70
/// Layout of the (n,m) kernel design matrix on device. The Wahba
/// pipeline downstream of this kernel (cuBLAS GEMM, cuSOLVER GEQRF)
/// requires column-major.
///
/// Only one layout exists today. It is still carried explicitly in
/// `S2KernelBuildInputs` and in the NVRTC module cache key, so adding a
/// layout later cannot silently reuse a module compiled for another one.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Entry `(i, j)` is stored at offset `j * ld + i`, with `ld ≥ rows`.
    ColumnMajor,
}
78
79/// Lat/lon (degrees or radians) → unit vector `(x, y, z)` on S² ⊂ ℝ³.
80/// Returns a flat `Vec<f64>` of length `3 * n` in the row-major layout
81/// `[x_0, y_0, z_0, x_1, y_1, z_1, …]`, ready for one `htod` upload.
82///
83/// `radians = false` interprets inputs as degrees (the codebase default
84/// for `SphericalSplineBasisSpec`).
85pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86    if latlon.ncols() != 2 {
87        return Err(format!(
88            "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89            latlon.shape()
90        ));
91    }
92    let deg = if radians {
93        1.0
94    } else {
95        std::f64::consts::PI / 180.0
96    };
97    let n = latlon.nrows();
98    let mut out = Vec::with_capacity(3 * n);
99    for row in latlon.outer_iter() {
100        let lat = row[0] * deg;
101        let lon = row[1] * deg;
102        let (s_lat, c_lat) = lat.sin_cos();
103        let (s_lon, c_lon) = lon.sin_cos();
104        // Standard geographic→cartesian: pole on +z.
105        out.push(c_lat * c_lon);
106        out.push(c_lat * s_lon);
107        out.push(s_lat);
108    }
109    Ok(out)
110}
111
/// Device-resident `(rows × cols)` matrix in column-major layout with
/// leading dimension `ld ≥ rows`. The slice holds `ld * cols` `f64`
/// elements; entry `(i, j)` lives at `col_major_dev[j * ld + i]`.
///
/// On non-Linux builds the type is intentionally a host shadow so the
/// surrounding orchestration compiles without cudarc.
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
    /// Logical row count. Only the first `rows` entries of each column are
    /// read back by `to_host_array`; rows `rows..ld` are padding.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Leading dimension (column stride, in elements); must be `>= rows`.
    pub ld: usize,
    /// Column-major payload of `ld * cols` elements.
    pub col_major_dev: CudaSlice<f64>,
    /// Stream used for the device→host copy and synchronisation. Its context
    /// also allocates the pinned staging buffer.
    pub stream: Arc<CudaStream>,
}

#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
    /// Logical row count (see the Linux variant).
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Leading dimension (column stride, in elements); must be `>= rows`.
    pub ld: usize,
    /// Host shadow for CPU-only builds.
    pub col_major_dev: Vec<f64>,
}
135
136impl DeviceS2KernelMatrix {
137    /// Copy the device matrix back to the host as a regular ndarray
138    /// `(rows × cols)` row-major view. Convenience for tests + parity
139    /// comparisons; production paths should keep the matrix resident.
140    ///
141    /// The device matrix is `(ld × cols)` column-major; the host wants
142    /// `(rows × cols)` row-major. Two costs dominate this round-trip on the
143    /// real V100:
144    ///   1. the device→host copy of the full `ld·cols·8 B` payload, and
145    ///   2. the column-major→row-major transpose.
146    /// On Linux the dtoh is staged through a *cacheable* pinned host buffer
147    /// (see [`PinnedF64`]) so the DMA runs at full PCIe bandwidth (~10 GB/s)
148    /// instead of the ~1.3 GB/s the driver achieves staging a pageable
149    /// destination, and the subsequent host reads during the transpose hit
150    /// L1/L2 normally (unlike write-combined pinned memory). The transpose
151    /// itself is the parallel cache-blocked [`col_major_to_row_major_parallel`].
152    #[cfg(target_os = "linux")]
153    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
154        let needed = self.ld * self.cols;
155        let mut staging = PinnedLease::acquire(self.stream.context(), needed)?;
156        self.stream
157            .memcpy_dtoh(&self.col_major_dev, staging.as_mut_slice())
158            .gpu_ctx("DeviceS2KernelMatrix dtoh (pinned)")?;
159        self.stream
160            .synchronize()
161            .gpu_ctx("DeviceS2KernelMatrix synchronize (pinned)")?;
162        Ok(col_major_to_row_major_parallel(
163            staging.as_slice(),
164            self.rows,
165            self.cols,
166            self.ld,
167        ))
168    }
169
170    #[cfg(not(target_os = "linux"))]
171    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
172        // Mirror the linux `to_host_array` exactly so both platforms return the
173        // identical row-major layout: pull the padded `(ld × cols)` column-major
174        // payload, then run the cache-blocked parallel transpose.
175        let mut col_major = vec![0.0_f64; self.ld * self.cols];
176        self.copy_to_host_col_major(&mut col_major)?;
177        Ok(col_major_to_row_major_parallel(
178            &col_major, self.rows, self.cols, self.ld,
179        ))
180    }
181
182    /// Copy the underlying `(ld × cols)` column-major payload to a
183    /// caller-provided buffer. Used by `to_host_array` and by the
184    /// device-resident cuSOLVER consumer when it needs to extract the
185    /// coefficient vector.
186    #[cfg(target_os = "linux")]
187    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
188        let needed = self.ld * self.cols;
189        if dst.len() != needed {
190            gam_gpu::gpu_bail!(
191                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
192                dst.len(),
193                needed
194            );
195        }
196        self.stream
197            .memcpy_dtoh(&self.col_major_dev, dst)
198            .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
199        self.stream
200            .synchronize()
201            .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
202        Ok(())
203    }
204
205    #[cfg(not(target_os = "linux"))]
206    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
207        let needed = self.ld * self.cols;
208        if dst.len() != needed {
209            gam_gpu::gpu_bail!(
210                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
211                dst.len(),
212                needed
213            );
214        }
215        dst.copy_from_slice(&self.col_major_dev);
216        Ok(())
217    }
218}
219
220/// Convert a `(ld × cols)` column-major device payload into a row-major
221/// `(rows × cols)` host `Array2`, in parallel with a cache-blocked tiled
222/// transpose.
223///
224/// Entry `(i, j)` lives at `col_major[j * ld + i]` and must land at
225/// `out[i * cols + j]`. A naive scalar `out[(i, j)] = col_major[j*ld+i]`
226/// loop over an `n·m` design (e.g. 200_000 × 200 ⇒ 320 MB) is utterly
227/// cache-hostile — the read stride is `ld` doubles — and measured at ~9 s,
228/// which alone made the GPU path lose to CPU. Here we:
229///   * tile the output rows into blocks small enough that one block's
230///     output stays L2-resident (`BLOCK_ROWS` rows × `cols` doubles),
231///   * read each source column slice contiguously (`col_major[j*ld+r0..]`),
232///   * run the row-blocks across the rayon pool.
233/// Reads are fully sequential per column; writes are bounded to the hot
234/// block. This drops the transpose from seconds to tens of milliseconds.
235fn col_major_to_row_major_parallel(
236    col_major: &[f64],
237    rows: usize,
238    cols: usize,
239    ld: usize,
240) -> Array2<f64> {
241    use rayon::prelude::*;
242
243    assert!(ld >= rows, "ld {ld} must be >= rows {rows}");
244    assert!(
245        col_major.len() >= ld * cols,
246        "col_major len {} < ld*cols {}",
247        col_major.len(),
248        ld * cols
249    );
250
251    // Block size chosen so one output block (BLOCK_ROWS × cols × 8 B) plus the
252    // source column slices stay roughly within L2 for the common `cols ≲ 200`.
253    const BLOCK_ROWS: usize = 128;
254
255    let mut out_flat = vec![0.0_f64; rows * cols];
256    out_flat
257        .par_chunks_mut(BLOCK_ROWS * cols)
258        .enumerate()
259        .for_each(|(block_idx, out_block)| {
260            let r0 = block_idx * BLOCK_ROWS;
261            let block_rows = out_block.len() / cols;
262            for j in 0..cols {
263                let base = j * ld + r0;
264                let src_col = &col_major[base..base + block_rows];
265                // Strided write within the hot block; contiguous column read.
266                for (local_i, &v) in src_col.iter().enumerate() {
267                    out_block[local_i * cols + j] = v;
268                }
269            }
270        });
271
272    Array2::from_shape_vec((rows, cols), out_flat)
273        .expect("row-major buffer has rows*cols elements")
274}
275
276/// RAII handle for a *cacheable* page-locked (pinned) host `f64` buffer.
277///
278/// cudarc's `CudaContext::alloc_pinned` always passes
279/// `CU_MEMHOSTALLOC_WRITECOMBINED`, which is excellent for host→device
280/// uploads but pathological for the host *reads* the transpose performs
281/// (write-combined memory is uncached on the CPU side). For the device→host
282/// return path we instead allocate plain pinned memory (`flags = 0`) directly
283/// via the driver: pinned so the dtoh DMA runs at full PCIe bandwidth, and
284/// cacheable so the parallel transpose can read it through the normal cache
285/// hierarchy. The buffer is freed with `cuMemFreeHost` on drop.
286#[cfg(target_os = "linux")]
287struct PinnedF64 {
288    ptr: *mut f64,
289    len: usize,
290    freed: bool,
291}
292
293#[cfg(target_os = "linux")]
294impl PinnedF64 {
295    /// Allocate `len` cacheable pinned `f64`s. Binds the context to the
296    /// calling thread first (required before any driver allocation call).
297    fn alloc(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
298        ctx.bind_to_thread()
299            .gpu_ctx("PinnedF64 bind_to_thread")?;
300        let bytes = len
301            .checked_mul(std::mem::size_of::<f64>())
302            .ok_or_else(|| gam_gpu::gpu_err!("PinnedF64: len={len} byte size overflows usize"))?;
303        // flags = 0 ⇒ cacheable pinned (NOT write-combined): fast DMA *and*
304        // fast host reads for the subsequent transpose.
305        // SAFETY: `bytes` is a valid non-overflowing size; the returned host
306        // pointer is owned by this struct and freed exactly once in `drop`.
307        let raw = unsafe { cudarc::driver::result::malloc_host(bytes, 0) }
308            .gpu_ctx("PinnedF64 cuMemHostAlloc")?;
309        let ptr = raw as *mut f64;
310        if ptr.is_null() {
311            gam_gpu::gpu_bail!("PinnedF64: cuMemHostAlloc returned null for {bytes} bytes");
312        }
313        Ok(Self {
314            ptr,
315            len,
316            freed: false,
317        })
318    }
319
320    fn as_mut_slice(&mut self) -> &mut [f64] {
321        // SAFETY: `ptr` points to `len` f64s of live pinned memory owned by
322        // self; the borrow is bounded by `&mut self`.
323        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
324    }
325
326    fn as_slice(&self) -> &[f64] {
327        // SAFETY: as above; shared borrow bounded by `&self`.
328        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
329    }
330}
331
332#[cfg(target_os = "linux")]
333impl Drop for PinnedF64 {
334    fn drop(&mut self) {
335        if self.freed {
336            return;
337        }
338        self.freed = true;
339        // SAFETY: `ptr` was returned by `cuMemHostAlloc` in `alloc` and is
340        // freed exactly once (guarded by `freed`). A free failure during Drop
341        // is unrecoverable here; absorb it (the host process is tearing the
342        // allocation down regardless) without unwinding out of Drop.
343        unsafe { cudarc::driver::result::free_host(self.ptr as *mut std::ffi::c_void) }.ok();
344    }
345}
346
347// SAFETY: `PinnedF64` owns a single raw host allocation. The pointer is only
348// dereferenced by the thread holding the (mutable or shared) borrow; the pool
349// below moves the *handle* between threads while no borrow is outstanding, and
350// the rayon transpose only ever sees a `&[f64]` (already `Send + Sync`). The
351// raw pointer itself is never shared concurrently.
352#[cfg(target_os = "linux")]
353unsafe impl Send for PinnedF64 {}
354
/// Bounded free-list of cacheable pinned host buffers, keyed by length.
///
/// Page-locking 320 MB via `cuMemHostAlloc` costs ~140 ms on the V100 — far
/// more than the dtoh (~25 ms) it accelerates. During a REML fit the sphere
/// design matrix is rebuilt and copied back at the *same* `(ld·cols)` size on
/// every outer iteration, so caching the page-locked buffer turns that 140 ms
/// into a one-time cost. The pool keeps at most [`PINNED_POOL_MAX_BUFFERS`]
/// buffers (LRU-ish: oldest dropped first) to bound resident pinned memory.
#[cfg(target_os = "linux")]
const PINNED_POOL_MAX_BUFFERS: usize = 4;

/// Process-wide cache of idle pinned staging buffers.
///
/// Created lazily by the first `PinnedLease::acquire`, which reuses an entry
/// only when its length matches the request exactly. A returned lease pushes
/// its buffer to the back of the `Vec`. When the pool already holds
/// `PINNED_POOL_MAX_BUFFERS` entries, the entry at index 0 (roughly the
/// oldest) is evicted and freed.
#[cfg(target_os = "linux")]
static PINNED_POOL: OnceLock<Mutex<Vec<PinnedF64>>> = OnceLock::new();
368
369/// RAII lease of a pooled pinned buffer. Returns the buffer to [`PINNED_POOL`]
370/// on drop instead of freeing it, so the next same-size request reuses the
371/// page-locked allocation.
372#[cfg(target_os = "linux")]
373struct PinnedLease {
374    buf: Option<PinnedF64>,
375}
376
377#[cfg(target_os = "linux")]
378impl PinnedLease {
379    /// Acquire a pinned buffer of at least `len` f64s, reusing a pooled one of
380    /// exactly `len` when available, else allocating fresh.
381    fn acquire(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
382        let pool = PINNED_POOL.get_or_init(|| Mutex::new(Vec::new()));
383        if let Ok(mut guard) = pool.lock() {
384            if let Some(pos) = guard.iter().position(|b| b.len == len) {
385                return Ok(Self {
386                    buf: Some(guard.swap_remove(pos)),
387                });
388            }
389        }
390        Ok(Self {
391            buf: Some(PinnedF64::alloc(ctx, len)?),
392        })
393    }
394
395    fn as_mut_slice(&mut self) -> &mut [f64] {
396        self.buf
397            .as_mut()
398            .expect("PinnedLease buffer present until drop")
399            .as_mut_slice()
400    }
401
402    fn as_slice(&self) -> &[f64] {
403        self.buf
404            .as_ref()
405            .expect("PinnedLease buffer present until drop")
406            .as_slice()
407    }
408}
409
410#[cfg(target_os = "linux")]
411impl Drop for PinnedLease {
412    fn drop(&mut self) {
413        let Some(buf) = self.buf.take() else {
414            return;
415        };
416        if let Some(pool) = PINNED_POOL.get() {
417            if let Ok(mut guard) = pool.lock() {
418                if guard.len() < PINNED_POOL_MAX_BUFFERS {
419                    guard.push(buf);
420                    return;
421                }
422                // Pool full: evict the oldest cached buffer to make room for
423                // this (most-recently-used) one, keeping resident pinned memory
424                // bounded while favouring the hot size.
425                guard.remove(0);
426                guard.push(buf);
427                return;
428            }
429        }
430        // No pool / poisoned lock: fall back to freeing via PinnedF64::drop.
431        drop(buf);
432    }
433}
434
435// ────────────────────────────────────────────────────────────────────────
436// Inputs
437// ────────────────────────────────────────────────────────────────────────
438
/// Host-side inputs needed to launch `s2_wahba_legendre_colmajor`.
///
/// `data_xyz` and `centers_xyz` are flat row-major
/// `[x_0, y_0, z_0, …]` length `3 * n` and `3 * m` respectively, pre-
/// computed via [`latlon_to_xyz_host`]. `coeffs` has length `lmax + 1`,
/// indexed as `coeffs[ℓ] = c_ℓ` with `c_0 = 0`.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
    /// Number of data points (rows of the output design matrix).
    pub n: usize,
    /// Number of centers (columns of the output design matrix). Not to be
    /// confused with the penalty order `m` taken by
    /// `SphereSpectralKernelKind::coefficients`.
    pub m: usize,
    /// Truncation degree `L` of the Legendre series; must be `>= 1`.
    pub lmax: usize,
    /// Data points as unit vectors, flat `[x, y, z]` triples (length `3 * n`).
    pub data_xyz: &'a [f64],
    /// Centers as unit vectors, flat `[x, y, z]` triples (length `3 * m`).
    pub centers_xyz: &'a [f64],
    /// Spectral coefficients `c_0..=c_lmax`, with `c_0 == 0`.
    pub coeffs: &'a [f64],
    /// Kernel family. Presumably the one `coeffs` was generated for —
    /// TODO confirm how the launcher uses it.
    pub kind: SphereSpectralKernelKind,
    /// Device storage layout of the output matrix.
    pub layout: DeviceMatrixLayout,
}
456
457impl<'a> S2KernelBuildInputs<'a> {
458    fn validate(&self) -> Result<(), GpuError> {
459        if self.lmax == 0 {
460            return Err(GpuError::DriverCallFailed {
461                reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
462            });
463        }
464        if self.data_xyz.len() != 3 * self.n {
465            gam_gpu::gpu_bail!(
466                "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
467                self.data_xyz.len(),
468                3 * self.n
469            );
470        }
471        if self.centers_xyz.len() != 3 * self.m {
472            gam_gpu::gpu_bail!(
473                "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
474                self.centers_xyz.len(),
475                3 * self.m
476            );
477        }
478        if self.coeffs.len() != self.lmax + 1 {
479            gam_gpu::gpu_bail!(
480                "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
481                self.coeffs.len(),
482                self.lmax + 1
483            );
484        }
485        if self.coeffs[0] != 0.0 {
486            return Err(GpuError::DriverCallFailed {
487                reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
488            });
489        }
490        Ok(())
491    }
492}
493
494// ────────────────────────────────────────────────────────────────────────
495// NVRTC kernel source — raw and Householder-fused variants.
496//
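// Both compile with `--std=c++17 --gpu-architecture=compute_${cc}` and
// take LMAX as a compile-time `#define`. The raw kernel runs one thread per
// `(i, j)` entry on a 2-D grid (block (32, 8, 1)). The Householder-fused
// kernel runs one thread per data row on a 1-D grid (block (128, 1, 1)).
// Neither kernel uses shared memory: each thread loads its coordinates
// from global memory straight into registers.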
501// ────────────────────────────────────────────────────────────────────────
502
/// CUDA C++ source for the two S² Wahba kernels, compiled with NVRTC.
///
/// The host prepends a `#define LMAX <lmax>` line before compilation (see
/// the comment at the top of the string). Each truncation degree therefore
/// compiles to its own module.
///
/// * `s2_wahba_legendre_colmajor` — one thread per `(i, j)` entry. The row
///   comes from the grid's y dimension and the center from x. It writes the
///   truncated series at `t = clamp(x_i · c_j, -1, 1)` to `out[j * ld + i]`.
/// * `s2_wahba_householder_constrained_colmajor` — one thread per data row,
///   in two passes:
///     1. form `d_i = Σ_j v[j] · B[i, j]`;
///     2. recompute each `B[i, j]` (for `j >= 1`) and emit
///        `B[i, j] - beta * d_i * v[j]` into column `j - 1`, instead of
///        storing the row.
///
/// Both kernels read `coeffs[1]` unconditionally, so the host must guarantee
/// `LMAX >= 1` (`S2KernelBuildInputs::validate` rejects `lmax == 0`).
///
/// NOTE(review): two launch limits apply.
///   * `n`, `m` and the row/center indices are 32-bit `int`, and the flat
///     xyz buffers are indexed as `3 * i` in `int` arithmetic, so `3 * n`
///     and `3 * m` must fit in `i32`.
///   * The raw kernel maps rows to `blockIdx.y`. CUDA caps `gridDim.y` at
///     65535, so with 8 rows per block a single launch covers at most
///     524_280 rows unless the launcher tiles the row range — confirm
///     against the launch site.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,    // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
    const double* __restrict__ coeffs,      // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];
    const double cxj = centers_xyz[3 * j + 0];
    const double cyj = centers_xyz[3 * j + 1];
    const double czj = centers_xyz[3 * j + 2];

    // t = clamp(x_i · z_j, -1, +1).
    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t >  1.0) t =  1.0;
    if (t < -1.0) t = -1.0;

    // Legendre 3-term recurrence in registers.
    // P_0(t) = 1, P_1(t) = t.
    double p_prev = 1.0;
    double p_curr = t;
    double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf  = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }

    out[(long long) j * ld + (long long) i] = acc;
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,    // n × 3
    const double* __restrict__ centers_xyz, // m × 3
    const double* __restrict__ coeffs,      // length LMAX + 1
    const double* __restrict__ v,           // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        d_i = fma(v[j], acc, d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        const double xs = acc - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
641
642// ────────────────────────────────────────────────────────────────────────
643// Module cache key + per-process backend.
644// ────────────────────────────────────────────────────────────────────────
645
/// Module cache key: every distinct `(CC, LMAX, kind, layout, kernel
/// flavor)` compiles to a different PTX. `precision = f64` and the
/// (32, 8, 1) raw-kernel block / (128, 1, 1) Householder-kernel block
/// shapes are baked into the kernel source so they are implicit in the
/// flavor tag and don't appear here.
///
/// NOTE(review): the two kernel flavors are in the same source template, so
/// presumably one compiled module serves both. Also, the template itself has
/// no kind-specific code: the spectral coefficients arrive at run time
/// through the `coeffs` pointer. `kind` therefore only splits the cache
/// (identical PTX compiled once per kind), unless the module assembly adds
/// kind-specific defines — confirm at the `module_for` call site.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Target compute capability (major), i.e. the `compute_${cc}` arch.
    pub cc_major: i32,
    /// Target compute capability (minor).
    pub cc_minor: i32,
    /// Value substituted for the `LMAX` `#define` in the kernel source.
    pub lmax: u32,
    /// Spectral kernel family.
    pub kind: SphereSpectralKernelKind,
    /// Output matrix layout.
    pub layout: DeviceMatrixLayout,
}
659
/// Reports whether this build carries the Linux + cudarc backend that runs
/// the S² Wahba kernels. The answer is fixed at compile time by `target_os`.
pub const fn sphere_gpu_compiled() -> bool {
    const BACKEND_COMPILED: bool = cfg!(target_os = "linux");
    BACKEND_COMPILED
}
665
666/// Decide whether the GPU sphere kernel matrix path is eligible for
667/// `(n, m, lmax)`. Heuristic per the math spec:
668///   * `n * m >= 1_000_000`
669///   * `lmax <= 200`
670///   * device memory budget admits at least one `(ld × m)` design at
671///     `ld = ((n + 31) / 32) * 32`.
672#[must_use]
673pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
674    let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
675        let ld = ((n + 31) / 32) * 32;
676        let needed_bytes = ld
677            .saturating_mul(m)
678            .saturating_mul(std::mem::size_of::<f64>());
679        let budget = runtime.memory_budget_bytes;
680        n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
681    } else {
682        false
683    };
684    decide(
685        GpuKernel::SpatialKernelOperator,
686        gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
687    )
688}
689
690/// Map a truncated `SphereWahbaKernel` variant onto the device kernel kind +
691/// truncation degree. Only the two *truncated* spectral variants have an exact
692/// device counterpart (the closed-form `Sobolev`/`Pseudo` variants use
693/// polylogarithms / deep-`L` series the device kernel does not evaluate), so
694/// `Sobolev`/`Pseudo` return `None` and stay on the CPU closed-form path.
695#[must_use]
696pub fn truncated_device_kind(
697    kernel: crate::basis::SphereWahbaKernel,
698) -> Option<(SphereSpectralKernelKind, u16)> {
699    use crate::basis::SphereWahbaKernel;
700    match kernel {
701        SphereWahbaKernel::SobolevTruncated { lmax } => {
702            Some((SphereSpectralKernelKind::Sobolev, lmax))
703        }
704        SphereWahbaKernel::PseudoTruncated { lmax } => {
705            Some((SphereSpectralKernelKind::Pseudo, lmax))
706        }
707        SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
708    }
709}
710
711/// Production entry: build the raw `(n × m)` truncated-spectral Wahba kernel
712/// design matrix on the GPU when [`sphere_kernel_decision`] admits the device,
713/// returning `None` to signal the caller to use its CPU oracle.
714///
715/// Contract:
716///   * Returns `None` when the kernel is a non-truncated closed-form variant
717///     (no exact device counterpart), or when the dispatch decision keeps the
718///     work on the CPU (`!use_gpu`). The caller then runs the bit-defining CPU
719///     path. This is the **only** quiet-CPU route and it is taken *before* any
720///     device call — never as a silent fallback after a device failure.
721///   * Returns `Some(Ok(matrix))` with the device-computed host array when the
722///     device path ran and matches the CPU truncated recurrence to roundoff
723///     (proven by the parity tests). `gam_gpu::policy` keeps the same `c_ℓ`
724///     array and the same Legendre 3-term recurrence on both sides.
725///   * Returns `Some(Err(_))` when the device was *admitted* but the launch /
726///     NVRTC compile / copy-back failed — a hard error the caller must surface,
727///     NOT degrade to CPU. Fail-loud once admitted (the recurring silent-CPU
728///     fallback is the bug this path exists to kill).
729///
730/// `data` / `centers` are `(_, 2)` lat/lon matrices (degrees unless
731/// `radians`), matching `spherical_wahba_kernel_matrix_with_kind`.
732pub fn try_build_truncated_kernel_matrix_gpu(
733    data: ArrayView2<'_, f64>,
734    centers: ArrayView2<'_, f64>,
735    penalty_order: usize,
736    radians: bool,
737    kernel: crate::basis::SphereWahbaKernel,
738) -> Option<Result<Array2<f64>, GpuError>> {
739    let (kind, lmax) = truncated_device_kind(kernel)?;
740    let n = data.nrows();
741    let m = centers.nrows();
742    if n == 0 || m == 0 || lmax == 0 {
743        return None;
744    }
745    let decision = sphere_kernel_decision(n, m, lmax as usize);
746    if !decision.use_gpu {
747        // Either backend-not-compiled, runtime-unavailable, or below the
748        // device-work threshold. Quiet CPU route, taken before any device call.
749        return None;
750    }
751    // Admitted: from here a failure is a hard error, never a silent CPU degrade.
752    Some(build_truncated_kernel_matrix_gpu_admitted(
753        data,
754        centers,
755        penalty_order,
756        radians,
757        kind,
758        lmax,
759    ))
760}
761
762/// Run the admitted device build for `try_build_truncated_kernel_matrix_gpu`.
763/// Separated so the admission decision (which returns `None` for the CPU route)
764/// stays distinct from the fail-loud device execution (which returns `Err`).
765fn build_truncated_kernel_matrix_gpu_admitted(
766    data: ArrayView2<'_, f64>,
767    centers: ArrayView2<'_, f64>,
768    penalty_order: usize,
769    radians: bool,
770    kind: SphereSpectralKernelKind,
771    lmax: u16,
772) -> Result<Array2<f64>, GpuError> {
773    let n = data.nrows();
774    let m = centers.nrows();
775    let data_xyz = latlon_to_xyz_host(data, radians)
776        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
777    let centers_xyz = latlon_to_xyz_host(centers, radians)
778        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
779    // Single-source the coefficients: the same `c_ℓ` array the CPU truncated
780    // recurrence consumes (`wahba_sphere_kernel_from_cos_kind`) is uploaded to
781    // the device, so CPU and GPU evaluate an identical zonal series.
782    let coeffs = kind.coefficients(lmax as usize, penalty_order);
783    let inputs = S2KernelBuildInputs {
784        n,
785        m,
786        lmax: lmax as usize,
787        data_xyz: &data_xyz,
788        centers_xyz: &centers_xyz,
789        coeffs: &coeffs,
790        kind,
791        layout: DeviceMatrixLayout::ColumnMajor,
792    };
793    let device_matrix = build_kernel_matrix_device(inputs)?;
794    let out = device_matrix.to_host_array()?;
795    // Guard against a device kernel that emitted NaN/Inf. A whole-matrix sum is
796    // poisoned by any non-finite element (`NaN + x = NaN`, `±Inf + finite =
797    // ±Inf`) and folds the `(n × m)` matrix in a single auto-vectorisable pass,
798    // ~7× faster than a per-element `any(!is_finite)` in the unoptimised
799    // profile (at n=200000, m=200 that scan alone was ~1.8 s — far more than
800    // the entire on-device build). The Wahba zonal kernel is a truncated
801    // Legendre series `Σ c_ℓ P_ℓ(t)` with `|P_ℓ| ≤ 1` and absolutely-summable
802    // coefficients, so every entry is O(1) and the sum of `n·m ≲ 10^8` of them
803    // cannot overflow f64 — a non-finite sum therefore means a genuinely
804    // non-finite entry, never a spurious overflow.
805    if !out.sum().is_finite() {
806        return Err(GpuError::DriverCallFailed {
807            reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
808        });
809    }
810    Ok(out)
811}
812
/// Linux-only device state held by the process-wide `SphereGpuBackend`.
/// It is built once in `probe_linux` from the shared CUDA backend probe.
#[cfg(target_os = "linux")]
struct SphereGpuContext {
    /// CUDA context that compiled NVRTC modules are loaded into (`module_for`).
    ctx: Arc<CudaContext>,
    /// Stream on which the sphere builders issue their uploads, launches and
    /// synchronisation.
    stream: Arc<CudaStream>,
    /// Cache of loaded modules keyed by `S2ModuleCacheKey`. It is filled by
    /// `module_for`, which compiles outside the lock and inserts afterwards.
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Device compute capability (major), recorded at probe time and used as
    /// part of the module cache key.
    cc_major: i32,
    /// Device compute capability (minor), recorded at probe time and used as
    /// part of the module cache key.
    cc_minor: i32,
}
821
/// Process-wide sphere GPU backend. Lazy-initialised on first call to
/// [`SphereGpuBackend::probe`].
///
/// On non-Linux targets the struct carries no state, and `probe` always
/// returns `GpuError::DriverLibraryUnavailable`.
pub struct SphereGpuBackend {
    /// Device context, stream, module cache and compute capability (Linux only).
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
828
829impl SphereGpuBackend {
830    /// Lazily initialise the process-wide sphere backend.
831    pub fn probe() -> Result<&'static Self, GpuError> {
832        static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
833        BACKEND
834            .get_or_init(|| {
835                #[cfg(target_os = "linux")]
836                {
837                    Self::probe_linux()
838                }
839                #[cfg(not(target_os = "linux"))]
840                {
841                    Err(GpuError::DriverLibraryUnavailable {
842                        reason: "sphere GPU backend is Linux-only".to_string(),
843                    })
844                }
845            })
846            .as_ref()
847            .map_err(GpuError::clone)
848    }
849
850    #[cfg(target_os = "linux")]
851    fn probe_linux() -> Result<Self, GpuError> {
852        let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
853        Ok(SphereGpuBackend {
854            inner: SphereGpuContext {
855                ctx: parts.ctx,
856                stream: parts.stream,
857                modules: Mutex::new(HashMap::new()),
858                cc_major: parts.capability.compute_major,
859                cc_minor: parts.capability.compute_minor,
860            },
861        })
862    }
863
864    /// NVRTC-compile (or fetch from cache) the module for `key`. The
865    /// returned module exposes both raw and Householder-fused kernels.
866    #[cfg(target_os = "linux")]
867    fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
868        if let Ok(guard) = self.inner.modules.lock() {
869            if let Some(existing) = guard.get(&key) {
870                return Ok(existing.clone());
871            }
872        }
873        // Prepend the `LMAX` macro directly to the source, then compile through
874        // the shared arch+fmad options (`compile_ptx_arch`). #1686's
875        // `--fmad=false` keeps the spherical-harmonic evaluation bit-comparable
876        // to the separately-rounded CPU reference; the #1551 arch pin keys the
877        // kernel to the device's real compute capability. (The arch is resolved
878        // internally via `nvrtc_arch()` from a `&'static str` table, so the old
879        // "cannot satisfy arch with a runtime string" limitation no longer
880        // applies — the LMAX specialization rides in the source, the arch in
881        // the options.)
882        let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
883        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).gpu_ctx_with(|err| {
884            format!(
885                "sphere NVRTC compile (kind={}, lmax={}): {err}",
886                key.kind.tag(),
887                key.lmax
888            )
889        })?;
890        let module = self
891            .inner
892            .ctx
893            .load_module(ptx)
894            .gpu_ctx("sphere module load")?;
895        if let Ok(mut guard) = self.inner.modules.lock() {
896            guard.entry(key).or_insert_with(|| module.clone());
897        }
898        Ok(module)
899    }
900
901    #[cfg(target_os = "linux")]
902    fn cc(&self) -> (i32, i32) {
903        (self.inner.cc_major, self.inner.cc_minor)
904    }
905}
906
907// ────────────────────────────────────────────────────────────────────────
908// Entry points
909// ────────────────────────────────────────────────────────────────────────
910
911/// Build the raw `(n × m)` Wahba kernel matrix on device using
912/// `s2_wahba_legendre_colmajor`. Phase 1 entry point.
913pub fn build_kernel_matrix_device(
914    inputs: S2KernelBuildInputs<'_>,
915) -> Result<DeviceS2KernelMatrix, GpuError> {
916    inputs.validate()?;
917
918    #[cfg(target_os = "linux")]
919    {
920        use cudarc::driver::{LaunchConfig, PushKernelArg};
921        let backend = SphereGpuBackend::probe()?;
922        let (cc_major, cc_minor) = backend.cc();
923        let key = S2ModuleCacheKey {
924            cc_major,
925            cc_minor,
926            lmax: inputs.lmax as u32,
927            kind: inputs.kind,
928            layout: inputs.layout,
929        };
930        let module = backend.module_for(key)?;
931        let func = module
932            .load_function("s2_wahba_legendre_colmajor")
933            .gpu_ctx("sphere load_function raw")?;
934        let stream = backend.inner.stream.clone();
935
936        let data_dev = stream
937            .clone_htod(inputs.data_xyz)
938            .gpu_ctx("sphere htod data_xyz")?;
939        let centers_dev = stream
940            .clone_htod(inputs.centers_xyz)
941            .gpu_ctx("sphere htod centers_xyz")?;
942        let coeffs_dev = stream
943            .clone_htod(inputs.coeffs)
944            .gpu_ctx("sphere htod coeffs")?;
945
946        let n = inputs.n;
947        let m = inputs.m;
948        let ld = ((n + 31) / 32) * 32;
949        let mut out_dev = stream
950            .alloc_zeros::<f64>(ld * m)
951            .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
952
953        // Block (32, 8, 1) — x over centers, y over rows.
954        let block_x: u32 = 32;
955        let block_y: u32 = 8;
956        let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
957        let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
958        let cfg = LaunchConfig {
959            grid_dim: (grid_x, grid_y, 1),
960            block_dim: (block_x, block_y, 1),
961            shared_mem_bytes: 0,
962        };
963        let n_i32: i32 =
964            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
965        let m_i32: i32 =
966            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
967        let ld_i64: i64 = ld as i64;
968
969        let mut builder = stream.launch_builder(&func);
970        builder
971            .arg(&data_dev)
972            .arg(&centers_dev)
973            .arg(&coeffs_dev)
974            .arg(&n_i32)
975            .arg(&m_i32)
976            .arg(&ld_i64)
977            .arg(&mut out_dev);
978        // SAFETY: launch parameters are validated above; all device
979        // pointers come from cudarc-checked allocations on the same
980        // stream; the kernel only reads inputs and writes within
981        // out[0 .. ld*m].
982        unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
983        stream
984            .synchronize()
985            .gpu_ctx("sphere raw kernel synchronize")?;
986
987        Ok(DeviceS2KernelMatrix {
988            rows: n,
989            cols: m,
990            ld,
991            col_major_dev: out_dev,
992            stream,
993        })
994    }
995
996    #[cfg(not(target_os = "linux"))]
997    {
998        Err(GpuError::DriverLibraryUnavailable {
999            reason: "sphere GPU backend is Linux-only".to_string(),
1000        })
1001    }
1002}
1003
1004/// Phase-3 fused Householder-constrained kernel. `v` is the Householder
1005/// vector (length m), `beta` the reflector scalar, and the output is
1006/// the `(n × (m-1))` constrained design X_s on device.
1007pub fn build_householder_constrained_design_device(
1008    inputs: S2KernelBuildInputs<'_>,
1009    v: &[f64],
1010    beta: f64,
1011) -> Result<DeviceS2KernelMatrix, GpuError> {
1012    inputs.validate()?;
1013    if v.len() != inputs.m {
1014        gam_gpu::gpu_bail!(
1015            "build_householder_constrained_design_device: v.len()={} != m={}",
1016            v.len(),
1017            inputs.m
1018        );
1019    }
1020    if inputs.m < 2 {
1021        gam_gpu::gpu_bail!(
1022            "build_householder_constrained_design_device: m must be >= 2 (got {})",
1023            inputs.m
1024        );
1025    }
1026    if !beta.is_finite() {
1027        gam_gpu::gpu_bail!(
1028            "build_householder_constrained_design_device: beta must be finite (got {beta})"
1029        );
1030    }
1031
1032    #[cfg(target_os = "linux")]
1033    {
1034        use cudarc::driver::{LaunchConfig, PushKernelArg};
1035        let backend = SphereGpuBackend::probe()?;
1036        let (cc_major, cc_minor) = backend.cc();
1037        let key = S2ModuleCacheKey {
1038            cc_major,
1039            cc_minor,
1040            lmax: inputs.lmax as u32,
1041            kind: inputs.kind,
1042            layout: inputs.layout,
1043        };
1044        let module = backend.module_for(key)?;
1045        let func = module
1046            .load_function("s2_wahba_householder_constrained_colmajor")
1047            .gpu_ctx("sphere load_function householder")?;
1048        let stream = backend.inner.stream.clone();
1049
1050        let data_dev = stream
1051            .clone_htod(inputs.data_xyz)
1052            .gpu_ctx("sphere-hh htod data_xyz")?;
1053        let centers_dev = stream
1054            .clone_htod(inputs.centers_xyz)
1055            .gpu_ctx("sphere-hh htod centers_xyz")?;
1056        let coeffs_dev = stream
1057            .clone_htod(inputs.coeffs)
1058            .gpu_ctx("sphere-hh htod coeffs")?;
1059        let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
1060
1061        let n = inputs.n;
1062        let m = inputs.m;
1063        let cols_out = m - 1;
1064        let ld_out = ((n + 31) / 32) * 32;
1065        let mut out_dev = stream
1066            .alloc_zeros::<f64>(ld_out * cols_out)
1067            .gpu_ctx_with(|err| {
1068                format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
1069            })?;
1070
1071        let block_x: u32 = 128;
1072        let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
1073        let cfg = LaunchConfig {
1074            grid_dim: (grid_x, 1, 1),
1075            block_dim: (block_x, 1, 1),
1076            shared_mem_bytes: 0,
1077        };
1078        let n_i32: i32 =
1079            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
1080        let m_i32: i32 =
1081            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
1082        let ld_out_i64: i64 = ld_out as i64;
1083
1084        let mut builder = stream.launch_builder(&func);
1085        builder
1086            .arg(&data_dev)
1087            .arg(&centers_dev)
1088            .arg(&coeffs_dev)
1089            .arg(&v_dev)
1090            .arg(&beta)
1091            .arg(&n_i32)
1092            .arg(&m_i32)
1093            .arg(&ld_out_i64)
1094            .arg(&mut out_dev);
1095        // SAFETY: validated shapes above; the kernel writes exactly
1096        // (n × (m-1)) entries within `out[0 .. ld_out * (m-1)]`.
1097        unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
1098        stream
1099            .synchronize()
1100            .gpu_ctx("sphere-hh kernel synchronize")?;
1101
1102        Ok(DeviceS2KernelMatrix {
1103            rows: n,
1104            cols: cols_out,
1105            ld: ld_out,
1106            col_major_dev: out_dev,
1107            stream,
1108        })
1109    }
1110
1111    #[cfg(not(target_os = "linux"))]
1112    {
1113        Err(GpuError::DriverLibraryUnavailable {
1114            reason: "sphere GPU backend is Linux-only".to_string(),
1115        })
1116    }
1117}
1118
1119// ────────────────────────────────────────────────────────────────────────
1120// Householder reflector helpers (host-side; Phase 3 prep).
1121//
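// Given a non-zero weight vector w ∈ ℝ^m, construct (v, beta) such that
// H = I − beta · v · v^T satisfies H · w = −σ · e_1, where σ = ±‖w‖ carries
// the sign of w[0]. Writing β = H · γ, the weighted-sum constraint wᵀβ = 0
// becomes (H·w)ᵀγ = −σ · γ_0 = 0. Only the first rotated coefficient is
// constrained, so dropping the first column of H enforces the constraint.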
1125// ────────────────────────────────────────────────────────────────────────
1126
/// Build the Householder reflector that maps `w` onto a multiple of `e_1`.
///
/// Returns `(v, beta)` in the LAPACK / Golub–Van Loan normalisation
/// (`v[0] == 1`). The resulting `H = I − beta · v · vᵀ` satisfies
/// `H · w = −σ · e_1`, where `σ = ±‖w‖` takes the sign of `w[0]` (with
/// `w[0] = 0` counted as positive). Adding `σ` to `w[0]` therefore never
/// cancels.
///
/// Degenerate inputs:
/// - an empty `w` gives `(vec![], 0.0)`;
/// - a zero-norm `w` gives `(0-vector, 0.0)`. Callers should treat this as a
///   no-op reflector (no constraint).
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let Some(&lead) = w.first() else {
        return (Vec::new(), 0.0);
    };
    let identity_reflector = || (vec![0.0; w.len()], 0.0);

    let norm_w = w.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm_w == 0.0 {
        return identity_reflector();
    }

    // v = w + σ·e_1 with σ sharing the sign of w[0].
    let pivot = lead + if lead >= 0.0 { norm_w } else { -norm_w };
    if pivot == 0.0 {
        return identity_reflector();
    }

    // Rescale by the pivot so the leading entry becomes exactly 1.
    let v: Vec<f64> = std::iter::once(pivot)
        .chain(w[1..].iter().copied())
        .map(|x| x / pivot)
        .collect();
    // beta = 2 / (vᵀv).
    let beta = 2.0 / v.iter().map(|x| x * x).sum::<f64>();
    (v, beta)
}
1156
1157// ────────────────────────────────────────────────────────────────────────
1158// Phase 2 — center-center penalty C + constraint S = Zᵀ C Z.
1159//
1160// `C` is the (m × m) Wahba kernel of centers against themselves and is
1161// computed by reusing the raw GPU kernel with `n = m`. The constraint
1162// transform is the same Householder reflector used by the Phase-3 fused
1163// kernel: Z = (I − β · v · vᵀ) with the first column dropped, so the
1164// constrained penalty is the trailing (m−1)×(m−1) block of HᵀCH.
1165//
1166// At m ≤ 200 the Householder product is cheap on host and the result is
1167// returned as an `ndarray::Array2`. Future calls into cuSOLVER QR can
1168// upload it (or its Cholesky factor) once and keep it device-resident.
1169// ────────────────────────────────────────────────────────────────────────
1170
1171/// Build the (m × m) center-center kernel matrix `C` using the same GPU
1172/// kernel that builds the design. `centers_xyz` is the unit-vector
1173/// representation of the centers, length `3 * m`. `coeffs` and `kind`
1174/// match the design build.
1175pub fn build_center_kernel_device(
1176    centers_xyz: &[f64],
1177    lmax: usize,
1178    coeffs: &[f64],
1179    kind: SphereSpectralKernelKind,
1180) -> Result<DeviceS2KernelMatrix, GpuError> {
1181    let m = centers_xyz.len() / 3;
1182    if centers_xyz.len() != 3 * m {
1183        return Err(GpuError::DriverCallFailed {
1184            reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
1185        });
1186    }
1187    let inputs = S2KernelBuildInputs {
1188        n: m,
1189        m,
1190        lmax,
1191        data_xyz: centers_xyz,
1192        centers_xyz,
1193        coeffs,
1194        kind,
1195        layout: DeviceMatrixLayout::ColumnMajor,
1196    };
1197    build_kernel_matrix_device(inputs)
1198}
1199
1200/// Constrained penalty matrix `S = Zᵀ C Z` for the
1201/// weighted-sum-to-zero Householder constraint built from `w`.
1202/// Returned shape is `((m−1) × (m−1))`. `C` is taken as a host
1203/// (m × m) array (typically the dtoh of `build_center_kernel_device`).
1204pub fn constrained_penalty_host(
1205    c: ArrayView2<'_, f64>,
1206    w: &[f64],
1207) -> Result<Array2<f64>, GpuError> {
1208    let (m1, m2) = c.dim();
1209    if m1 != m2 {
1210        gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
1211    }
1212    let m = m1;
1213    if w.len() != m {
1214        gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
1215    }
1216    if m < 2 {
1217        gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
1218    }
1219    let (v, beta) = householder_reflector_from_weights(w);
1220
1221    // Form HCH = (I - β v vᵀ) C (I - β v vᵀ) = C - β (v · uᵀ + u · vᵀ) + β² (vᵀ C v) v vᵀ,
1222    // where u = C v. This is O(m²) — fine for m ≤ 200.
1223    let mut u = vec![0.0_f64; m];
1224    for i in 0..m {
1225        let mut acc = 0.0_f64;
1226        for j in 0..m {
1227            acc += c[(i, j)] * v[j];
1228        }
1229        u[i] = acc;
1230    }
1231    let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
1232    let mut hch = Array2::<f64>::zeros((m, m));
1233    for i in 0..m {
1234        for j in 0..m {
1235            hch[(i, j)] =
1236                c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
1237        }
1238    }
1239    // Drop the first row and column (the Householder-constrained nullspace).
1240    let mut s = Array2::<f64>::zeros((m - 1, m - 1));
1241    for i in 0..(m - 1) {
1242        for j in 0..(m - 1) {
1243            s[(i, j)] = hch[(i + 1, j + 1)];
1244        }
1245    }
1246    Ok(s)
1247}
1248
1249// ────────────────────────────────────────────────────────────────────────
1250// Phase 4 — device-resident cuSOLVER QR penalised solve.
1251//
1252// Solve  min_β  ‖ [√W · X_s] β − [√W · y] ‖² + λ ‖R_S · β‖²
1253//
1254// by stacking the augmented matrix
1255//
1256//     A_aug = [ √W · X_s ;   √λ · R_S ]    shape (n + p) × p,
1257//     b_aug = [ √W · y    ;   0       ]    length n + p,
1258//
1259// where p = m − 1, R_S is the upper-triangular Cholesky factor of the
1260// constrained penalty S = Zᵀ C Z, and (√W·X_s) is the design built by
1261// the fused Householder kernel scaled by sqrt-weights row-by-row on
1262// device. The pipeline is:
1263//
1264//     1. cusolverDnDgeqrf_bufferSize → workspace size.
1265//     2. cusolverDnDgeqrf(A_aug)     → A := [R upper-tri / V Householder]
1266//                                        plus tau vector.
1267//     3. cusolverDnDormqr(side=L, trans=T)
1268//                                  → applies Qᵀ to b_aug.
1269//     4. cublasDtrsm(L = upper) → β := R⁻¹ · (Qᵀ b_aug)[0..p].
1270//
1271// Coefficients (β) come back to host; log|H| can be returned via Σ
1272// log(R_ii²) from the diagonal of the in-place factored R.
1273//
// A_aug is assembled on host: X_s is copied back from the device and stacked
// above R_S, then the result is uploaded. tau, the workspace and info stay
// device-resident. The factored A_aug and Qᵀ b_aug are then copied back to host.
1276// ────────────────────────────────────────────────────────────────────────
1277
/// Result returned by [`solve_penalised_ls_device`]: the coefficient vector
/// plus fit and log-determinant summaries of the QR-based penalised solve.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Coefficient vector, length `p = m − 1` (after Householder drop).
    pub beta: Vec<f64>,
    /// Sum of squared residuals on the unaugmented rows: ‖√W (Xβ − y)‖².
    pub weighted_residual_ssq: f64,
    /// log|H| = 2 · Σ log |R_ii| of the QR-factored augmented design
    /// (`R` is the triangular factor from geqrf, so `H = Rᵀ R`).
    pub log_det_hessian: f64,
}
1288
1289/// Augmented penalised least-squares solve via on-device cuSOLVER QR.
1290///
1291/// Inputs:
1292///   * `x_s_device` — already-constrained, weighted-sqrt-scaled design
1293///     `√W · X_s` produced by the Phase-3 fused kernel + a row-scaling
1294///     kernel. Shape `(n × p)` column-major.
1295///   * `wy` — `√W · y` (length n), already host-multiplied (cheap).
1296///   * `r_s` — upper-triangular Cholesky factor of `√λ · S`, shape
1297///     `(p × p)` row-major host array.
1298#[cfg(target_os = "linux")]
1299pub fn solve_penalised_ls_device(
1300    x_s_device: &DeviceS2KernelMatrix,
1301    wy: &[f64],
1302    r_s: ArrayView2<'_, f64>,
1303) -> Result<PenalisedLsSolution, GpuError> {
1304    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1305    use cudarc::driver::DevicePtrMut;
1306
1307    let n = x_s_device.rows;
1308    let p = x_s_device.cols;
1309    if wy.len() != n {
1310        gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
1311    }
1312    if r_s.dim() != (p, p) {
1313        gam_gpu::gpu_bail!(
1314            "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
1315            r_s.dim()
1316        );
1317    }
1318    if p == 0 {
1319        return Ok(PenalisedLsSolution {
1320            beta: Vec::new(),
1321            weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
1322            log_det_hessian: 0.0,
1323        });
1324    }
1325
1326    let stream = x_s_device.stream.clone();
1327    let n_aug = n + p;
1328
1329    // 1) Materialise A_aug column-major on device. We don't need the
1330    //    upstream X_s after QR, but the kernel matrix builder hands us
1331    //    its own storage; we copy into a fresh (n_aug × p) slab so the
1332    //    in-place geqrf doesn't clobber a buffer the caller still owns.
1333    let mut a_aug_host = vec![0.0_f64; n_aug * p];
1334    // Copy device-side X_s back column-by-column into the upper block.
1335    let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
1336    x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
1337    for j in 0..p {
1338        let src_off = j * x_s_device.ld;
1339        let dst_off = j * n_aug;
1340        a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
1341        for i in 0..p {
1342            // R_S is row-major host; insert into column j of the lower
1343            // block (rows n..n+p) as r_s[i, j].
1344            a_aug_host[dst_off + n + i] = r_s[(i, j)];
1345        }
1346    }
1347    let mut a_dev = stream
1348        .clone_htod(&a_aug_host)
1349        .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
1350
1351    // b_aug = [√W·y ; 0]
1352    let mut b_host = vec![0.0_f64; n_aug];
1353    b_host[..n].copy_from_slice(wy);
1354    let mut b_dev = stream
1355        .clone_htod(&b_host)
1356        .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
1357
1358    let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
1359    let n_aug_i: i32 = i32::try_from(n_aug)
1360        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
1361    let p_i: i32 = i32::try_from(p)
1362        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
1363
1364    // 2) Workspace size for geqrf.
1365    let mut lwork: i32 = 0;
1366    {
1367        let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
1368        // SAFETY: a_dev holds n_aug*p f64 elements column-major;
1369        // pointer is live on `stream`; lwork is a valid host out-param.
1370        let status = unsafe {
1371            cusolver_sys::cusolverDnDgeqrf_bufferSize(
1372                solver.cu(),
1373                n_aug_i,
1374                p_i,
1375                a_ptr as *mut f64,
1376                n_aug_i,
1377                &mut lwork,
1378            )
1379        };
1380        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1381            gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1382        }
1383    }
1384    let lwork_us = usize::try_from(lwork)
1385        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1386    let mut workspace = stream
1387        .alloc_zeros::<f64>(lwork_us.max(1))
1388        .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1389    let mut tau = stream
1390        .alloc_zeros::<f64>(p)
1391        .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1392    let mut info = stream
1393        .alloc_zeros::<i32>(1)
1394        .gpu_ctx("solve_penalised_ls_device alloc info")?;
1395
1396    // 3) cusolverDnDgeqrf — A := QR in place.
1397    {
1398        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1399        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1400        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1401        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1402        // SAFETY: all pointers reference live device allocations on
1403        // this stream; lwork matches the bufferSize query above.
1404        let status = unsafe {
1405            cusolver_sys::cusolverDnDgeqrf(
1406                solver.cu(),
1407                n_aug_i,
1408                p_i,
1409                a_ptr as *mut f64,
1410                n_aug_i,
1411                tau_ptr as *mut f64,
1412                work_ptr as *mut f64,
1413                lwork,
1414                info_ptr as *mut i32,
1415            )
1416        };
1417        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1418            gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1419        }
1420    }
1421
1422    // 4) cusolverDnDormqr — b_aug := Qᵀ · b_aug.
1423    let mut ormqr_lwork: i32 = 0;
1424    {
1425        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1426        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1427        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1428        // SAFETY: A/tau/b are live device buffers on this stream;
1429        // ormqr_lwork is a host out-param.
1430        let status = unsafe {
1431            cusolver_sys::cusolverDnDormqr_bufferSize(
1432                solver.cu(),
1433                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1434                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1435                n_aug_i,
1436                1,
1437                p_i,
1438                a_ptr as *const f64,
1439                n_aug_i,
1440                tau_ptr as *const f64,
1441                b_ptr as *mut f64,
1442                n_aug_i,
1443                &mut ormqr_lwork,
1444            )
1445        };
1446        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1447            gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1448        }
1449    }
1450    if ormqr_lwork > lwork {
1451        workspace = stream
1452            .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1453            .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1454    }
1455    {
1456        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1457        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1458        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1459        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1460        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1461        // SAFETY: all pointers reference live, mutually-non-aliasing
1462        // device buffers on this stream; lwork matches the bufferSize
1463        // query above; A and tau are the geqrf output.
1464        let status = unsafe {
1465            cusolver_sys::cusolverDnDormqr(
1466                solver.cu(),
1467                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1468                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1469                n_aug_i,
1470                1,
1471                p_i,
1472                a_ptr as *const f64,
1473                n_aug_i,
1474                tau_ptr as *const f64,
1475                b_ptr as *mut f64,
1476                n_aug_i,
1477                work_ptr as *mut f64,
1478                ormqr_lwork.max(lwork),
1479                info_ptr as *mut i32,
1480            )
1481        };
1482        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1483            gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1484        }
1485    }
1486
1487    // 5) cublasDtrsm — solve R · β = (Qᵀ b)[0..p] in place on the top
1488    //    of b_dev. We use a single-RHS upper-triangular non-unit solve.
1489    {
1490        use cudarc::cublas::CudaBlas;
1491        let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1492        let alpha = 1.0_f64;
1493        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1494        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1495        // SAFETY: A is the geqrf-output upper-triangular factor R in
1496        // its top-p × p block (col-major, ld = n_aug); b is the
1497        // ormqr-output Qᵀb in the top p slots (ld = n_aug as well so
1498        // pretend it is column-major with 1 column of leading dim n_aug).
1499        let handle = *blas.handle();
1500        let status = unsafe {
1501            cudarc::cublas::sys::cublasDtrsm_v2(
1502                handle,
1503                cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1504                cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1505                cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1506                cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1507                p_i,
1508                1,
1509                &alpha,
1510                a_ptr as *const f64,
1511                n_aug_i,
1512                b_ptr as *mut f64,
1513                n_aug_i,
1514            )
1515        };
1516        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1517            gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1518        }
1519    }
1520
1521    // 6) Copy results back to host.
1522    let mut b_out = vec![0.0_f64; n_aug];
1523    stream
1524        .memcpy_dtoh(&b_dev, &mut b_out)
1525        .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1526    let mut a_back = vec![0.0_f64; n_aug * p];
1527    stream
1528        .memcpy_dtoh(&a_dev, &mut a_back)
1529        .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1530    stream
1531        .synchronize()
1532        .gpu_ctx("solve_penalised_ls_device synchronize")?;
1533
1534    let beta: Vec<f64> = b_out[..p].to_vec();
1535    // (Qᵀb)[p..n_aug] holds the residual in the rotated coordinates;
1536    // ‖(Qᵀb)[p..]‖² = ‖√W (Xβ − y)‖² + λ ‖R_S β‖² for the augmented
1537    // system. To recover ‖√W (Xβ − y)‖² alone, subtract the penalty
1538    // residual ‖R_S β‖² (penalty rotates to itself in the augmented
1539    // bottom block, but only when the bottom block ROWS map exactly
1540    // into the rotated residual — which is not guaranteed, so the
1541    // simpler accurate path is to return the **augmented** residual
1542    // squared and let the caller subtract.)
1543    let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1544
1545    // log|R| diagonal.
1546    let mut log_abs_r = 0.0_f64;
1547    for k in 0..p {
1548        let r_kk = a_back[k * n_aug + k];
1549        log_abs_r += r_kk.abs().ln();
1550    }
1551    let log_det_hessian = 2.0 * log_abs_r;
1552
1553    Ok(PenalisedLsSolution {
1554        beta,
1555        weighted_residual_ssq: augmented_residual_ssq,
1556        log_det_hessian,
1557    })
1558}
1559
/// Fallback for targets other than Linux: the cuSOLVER QR solve is never
/// available there.
///
/// Always returns `GpuError::DriverLibraryUnavailable`. The reason string
/// records the design shape (`n` rows, `p` columns), the length of `wy`, and
/// the shape of `r_s`, so a log line shows which solve was declined.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let (n_rows, n_cols) = (&x_s_device.rows, &x_s_device.cols);
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        n_rows,
        n_cols,
        wy.len(),
        r_s.dim()
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1576
1577// ────────────────────────────────────────────────────────────────────────
1578// Tests
1579// ────────────────────────────────────────────────────────────────────────
1580
1581#[cfg(test)]
1582mod sphere_gpu_tests {
1583    use super::*;
1584    use crate::basis::{
1585        SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1586        spherical_wahba_kernel_matrix_with_kind,
1587    };
1588    use ndarray::Array2;
1589
1590    fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1591        // Latitude in (-85, 85), longitude in [-180, 180), degrees.
1592        let mut rows = Vec::with_capacity(n_lat * n_lon);
1593        for i in 0..n_lat {
1594            let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1595            for j in 0..n_lon {
1596                let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1597                rows.push(lat);
1598                rows.push(lon);
1599            }
1600        }
1601        Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1602    }
1603
1604    #[test]
1605    fn sum_finite_guard_accepts_finite_rejects_nonfinite() {
1606        // The admitted device path guards its output with `!out.sum().is_finite()`
1607        // instead of a per-element `any(!is_finite)`. This pins the equivalence
1608        // that justifies the swap: a finite matrix has a finite sum, and a single
1609        // NaN or ±Inf entry poisons the sum.
1610        let finite = Array2::<f64>::from_shape_fn((5, 7), |(i, j)| (i as f64 - 2.0) * (j as f64));
1611        assert!(finite.sum().is_finite());
1612
1613        let mut with_nan = finite.clone();
1614        with_nan[[3, 4]] = f64::NAN;
1615        assert!(!with_nan.sum().is_finite());
1616
1617        let mut with_pos_inf = finite.clone();
1618        with_pos_inf[[0, 0]] = f64::INFINITY;
1619        assert!(!with_pos_inf.sum().is_finite());
1620
1621        let mut with_neg_inf = finite.clone();
1622        with_neg_inf[[4, 6]] = f64::NEG_INFINITY;
1623        assert!(!with_neg_inf.sum().is_finite());
1624    }
1625
1626    #[test]
1627    fn xyz_preprocessing_matches_unit_sphere() {
1628        let latlon = ndarray::array![
1629            [0.0, 0.0],
1630            [90.0, 0.0],
1631            [0.0, 90.0],
1632            [-90.0, 17.5],
1633            [45.0, -120.0],
1634        ];
1635        let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1636        assert_eq!(xyz.len(), 3 * 5);
1637        for i in 0..5 {
1638            let nrm2 = xyz[3 * i] * xyz[3 * i]
1639                + xyz[3 * i + 1] * xyz[3 * i + 1]
1640                + xyz[3 * i + 2] * xyz[3 * i + 2];
1641            assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1642        }
1643        // Row 0 = equator @ lon=0 → (1, 0, 0).
1644        assert!((xyz[0] - 1.0).abs() < 1e-15);
1645        assert!(xyz[1].abs() < 1e-15);
1646        assert!(xyz[2].abs() < 1e-15);
1647        // Row 1 = north pole (lat=90, lon=0) → (0, 0, 1).
1648        assert!(xyz[3].abs() < 1e-15);
1649        assert!(xyz[4].abs() < 1e-15);
1650        assert!((xyz[5] - 1.0).abs() < 1e-15);
1651        // Row 2 = equator @ lon=90 → (0, 1, 0).
1652        assert!(xyz[6].abs() < 1e-15);
1653        assert!((xyz[7] - 1.0).abs() < 1e-15);
1654        assert!(xyz[8].abs() < 1e-15);
1655    }
1656
1657    #[test]
1658    fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1659        // P_ℓ(1) = 1 for all ℓ, so K(x, x) = Σ_{ℓ=0..L} c_ℓ. The Legendre
1660        // recurrence in `sphere_truncated_spectral_eval` must reproduce
1661        // this exact identity to roundoff.
1662        for m_penalty in 1..=4 {
1663            for &lmax in &[5_usize, 20, 50] {
1664                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1665                let expected: f64 = coeffs.iter().sum();
1666                let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1667                assert!(
1668                    (got - expected).abs() < 1e-13,
1669                    "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1670                );
1671            }
1672        }
1673    }
1674
1675    #[test]
1676    fn truncated_spectral_at_antipode_matches_alternating_sum() {
1677        // P_ℓ(-1) = (-1)^ℓ, so K(x, -x) = Σ_{ℓ=0..L} c_ℓ · (-1)^ℓ. Same
1678        // exact identity for the recurrence at t = -1.
1679        for m_penalty in 1..=4 {
1680            for &lmax in &[5_usize, 20, 50] {
1681                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1682                let expected: f64 = coeffs
1683                    .iter()
1684                    .enumerate()
1685                    .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1686                    .sum();
1687                let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1688                assert!(
1689                    (got - expected).abs() < 1e-13,
1690                    "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1691                );
1692            }
1693        }
1694    }
1695
1696    #[test]
1697    fn truncated_spectral_matrix_is_symmetric() {
1698        // K(γ) depends only on cos γ = x · y = y · x, so the Gram
1699        // matrix B B^T-style kernel evaluation on the same point set
1700        // must be symmetric to roundoff.
1701        let centers = ndarray::array![
1702            [10.0_f64, 20.0],
1703            [-30.0, 100.0],
1704            [45.0, -60.0],
1705            [-89.0, 0.0],
1706            [0.0, 180.0],
1707            [60.0, -179.9],
1708        ];
1709        for m_penalty in [1usize, 2, 4] {
1710            for &lmax in &[10_usize, 30] {
1711                let mat = spherical_wahba_kernel_matrix_with_kind(
1712                    centers.view(),
1713                    centers.view(),
1714                    m_penalty,
1715                    false,
1716                    SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1717                )
1718                .expect("kernel matrix");
1719                let n = centers.nrows();
1720                let mut max_asym = 0.0_f64;
1721                for i in 0..n {
1722                    for j in 0..n {
1723                        let d = (mat[(i, j)] - mat[(j, i)]).abs();
1724                        if d > max_asym {
1725                            max_asym = d;
1726                        }
1727                    }
1728                }
1729                assert!(
1730                    max_asym < 1e-13,
1731                    "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1732                );
1733            }
1734        }
1735    }
1736
1737    #[test]
1738    fn truncated_coefficients_have_zero_constant_mode() {
1739        for m in 1..=4 {
1740            let c = sobolev_s2_truncated_coefficients(50, m);
1741            assert_eq!(c.len(), 51);
1742            assert_eq!(c[0], 0.0);
1743            assert!(c[1] > 0.0);
1744            // Spectral decay c_ℓ ~ 1/ℓ^{2m-1}: monotone for ℓ ≥ 1.
1745            for ell in 2..=50 {
1746                assert!(
1747                    c[ell] < c[ell - 1] + 1e-15,
1748                    "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1749                    c[ell],
1750                    c[ell - 1]
1751                );
1752            }
1753        }
1754    }
1755
1756    #[test]
1757    fn truncated_spectral_matches_matrix_helper() {
1758        // The Wahba kernel matrix helper, invoked with the truncated
1759        // variant, must produce the same value as the bare scalar
1760        // evaluator.
1761        let m_penalty = 2;
1762        let lmax = 20;
1763        let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1764        let data = ndarray::array![[12.5, -34.0]];
1765        let centers = ndarray::array![[40.0, 10.0]];
1766        let mat = spherical_wahba_kernel_matrix_with_kind(
1767            data.view(),
1768            centers.view(),
1769            m_penalty,
1770            false,
1771            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1772        )
1773        .expect("kernel matrix");
1774        // Recompute cos γ on the unit sphere.
1775        let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1776        let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1777        let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1778        let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1779        assert!(
1780            (mat[(0, 0)] - expected).abs() < 1e-13,
1781            "matrix helper differs from scalar evaluator: {} vs {}",
1782            mat[(0, 0)],
1783            expected
1784        );
1785    }
1786
1787    #[test]
1788    fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1789        // Build a small symmetric PD matrix as a stand-in for C, then
1790        // verify that constrained_penalty_host returns a symmetric
1791        // (m-1)×(m-1) matrix whose action against Z·x matches the
1792        // expected Zᵀ C Z mapping.
1793        let m = 6;
1794        let mut c = Array2::<f64>::zeros((m, m));
1795        for i in 0..m {
1796            for j in 0..m {
1797                let d = (i as f64 - j as f64).abs();
1798                c[(i, j)] = (-0.5 * d).exp();
1799            }
1800        }
1801        let w = vec![1.0_f64; m];
1802        let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1803        assert_eq!(s.dim(), (m - 1, m - 1));
1804        // Symmetry within roundoff.
1805        let mut max_asym = 0.0_f64;
1806        for i in 0..(m - 1) {
1807            for j in 0..(m - 1) {
1808                let d = (s[(i, j)] - s[(j, i)]).abs();
1809                if d > max_asym {
1810                    max_asym = d;
1811                }
1812            }
1813        }
1814        assert!(
1815            max_asym < 1e-13,
1816            "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1817        );
1818
1819        // The kernel-of-Zᵀ direction: Zᵀ · w = 0 ⇒ x = (something) such
1820        // that Z · x stays in span(w)^⊥, so x can be any (m-1) vector;
1821        // we just verify that picking the all-ones constraint direction
1822        // collapses to zero through Z when applied to constant fields.
1823        // i.e. constant-field penalty norm must be zero in the
1824        // un-constrained Cv direction, and the trailing block here is
1825        // never used against the constraint.
1826        let ones = ndarray::Array1::<f64>::ones(m - 1);
1827        let sx = s.dot(&ones);
1828        assert!(sx.iter().all(|v| v.is_finite()));
1829    }
1830
1831    #[test]
1832    fn householder_reflector_zeroes_target_vector() {
1833        let w = vec![3.0, 4.0, 0.0, -1.0];
1834        let (v, beta) = householder_reflector_from_weights(&w);
1835        // Apply H = I - beta * v * v^T to w; the result should be a
1836        // multiple of e_1 (only first entry non-zero).
1837        let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1838        let hw: Vec<f64> = w
1839            .iter()
1840            .zip(&v)
1841            .map(|(wj, vj)| wj - beta * dot * vj)
1842            .collect();
1843        for entry in hw.iter().skip(1) {
1844            assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1845        }
1846        assert!(hw[0].abs() > 0.0);
1847    }
1848
1849    /// V100-only: probe + raw kernel parity vs CPU truncated-spectral on
1850    /// a small grid. Skips cleanly on hosts with no CUDA runtime.
1851    #[test]
1852    fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1853        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1854            eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1855            return;
1856        };
1857        // Past the runtime Some-gate: a probe failure is a real device fault on a
1858        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1859        SphereGpuBackend::probe()
1860            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1861
1862        let data_ll = small_latlon_grid(7, 9);
1863        let centers_ll = small_latlon_grid(5, 7);
1864        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1865        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1866        let n = data_ll.nrows();
1867        let m = centers_ll.nrows();
1868        let penalty = 2usize;
1869        let lmax = 20usize;
1870        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1871
1872        let inputs = S2KernelBuildInputs {
1873            n,
1874            m,
1875            lmax,
1876            data_xyz: &data_xyz,
1877            centers_xyz: &centers_xyz,
1878            coeffs: &coeffs,
1879            kind: SphereSpectralKernelKind::Sobolev,
1880            layout: DeviceMatrixLayout::ColumnMajor,
1881        };
1882        let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1883        let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1884
1885        let cpu = spherical_wahba_kernel_matrix_with_kind(
1886            data_ll.view(),
1887            centers_ll.view(),
1888            penalty,
1889            false,
1890            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1891        )
1892        .expect("cpu kernel matrix");
1893
1894        let mut max_abs = 0.0_f64;
1895        for i in 0..n {
1896            for j in 0..m {
1897                let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1898                if d > max_abs {
1899                    max_abs = d;
1900                }
1901            }
1902        }
1903        assert!(
1904            max_abs < 1e-11,
1905            "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1906        );
1907    }
1908
1909    /// V100-only end-to-end DISPATCH parity: prove the *production* kernel
1910    /// builder (`spherical_wahba_kernel_matrix_with_kind`) actually engages the
1911    /// device on a GPU-eligible truncated-spectral shape, and that the device
1912    /// result matches the CPU oracle (`spherical_wahba_kernel_matrix_cpu`) to
1913    /// roundoff. This is the engagement + parity gate the prior version of this
1914    /// test never exercised: it called `build_spherical_spline_basis` (which did
1915    /// not route to the GPU at all) and then compared the *decomposed* design
1916    /// against the *raw* kernel matrix, so it diverged by construction
1917    /// (rel |Δ| = 2.0) regardless of any device behaviour.
1918    ///
1919    /// Downstream PIRLS/REML consumes the kernel design through the same
1920    /// deterministic low-degree decomposition for both backends, so element-wise
1921    /// raw-kernel parity at ≤ 1e-9 implies full-design + fit parity.
1922    #[test]
1923    fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1924        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1925            eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1926            return;
1927        };
1928        // Past the runtime Some-gate: a backend probe failure is a real device
1929        // fault on a CUDA host, not a no-CUDA skip — fail loud (device-PCG
1930        // skip-pass class, eee12f6b2) instead of masking it as a pass.
1931        SphereGpuBackend::probe()
1932            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1933        use crate::basis::{
1934            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1935            build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
1936            spherical_wahba_kernel_matrix_with_kind,
1937        };
1938
1939        // (n=10_000, m=200) → n·m = 2_000_000 ≥ 1_000_000 → GPU eligible.
1940        let data = small_latlon_grid(100, 100);
1941        let lmax: u16 = 30;
1942        let penalty_order = 2usize;
1943        let centers =
1944            crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1945                .expect("centers");
1946        let n = data.nrows();
1947        let m = centers.nrows();
1948
1949        // The device MUST be admitted for this shape, otherwise this test would
1950        // silently exercise the CPU path on both sides and prove nothing about
1951        // engagement. Fail loud if the dispatch decision declines the GPU.
1952        let decision = sphere_kernel_decision(n, m, lmax as usize);
1953        assert!(
1954            decision.use_gpu,
1955            "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
1956             (reason={}); the engagement gate regressed",
1957            decision.reason
1958        );
1959
1960        // Production dispatcher: engages the device for this admitted shape.
1961        let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
1962            data.view(),
1963            centers.view(),
1964            penalty_order,
1965            false,
1966            SphereWahbaKernel::SobolevTruncated { lmax },
1967        )
1968        .expect("GPU-eligible production kernel build succeeds");
1969
1970        // CPU oracle: forced host evaluation regardless of dispatch decision.
1971        let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
1972            data.view(),
1973            centers.view(),
1974            penalty_order,
1975            false,
1976            SphereWahbaKernel::SobolevTruncated { lmax },
1977        )
1978        .expect("cpu oracle kernel build succeeds");
1979
1980        assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
1981        let mut max_abs = 0.0_f64;
1982        let mut max_rel = 0.0_f64;
1983        for (g, c) in gpu_kernel.iter().zip(cpu_kernel.iter()) {
1984            let d = (g - c).abs();
1985            if d > max_abs {
1986                max_abs = d;
1987            }
1988            let denom = g.abs().max(c.abs()).max(1e-300);
1989            let r = d / denom;
1990            if r > max_rel {
1991                max_rel = r;
1992            }
1993        }
1994        assert!(
1995            max_rel < 1e-9,
1996            "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
1997             >= 1e-9 (abs {max_abs:.3e})"
1998        );
1999
2000        // End-to-end smoke: the full design build (which routes its large
2001        // data×centers kernel through the engaged device) produces a finite,
2002        // correctly-shaped design with the expected number of rows.
2003        let spec_gpu = SphericalSplineBasisSpec {
2004            center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
2005            penalty_order,
2006            double_penalty: false,
2007            radians: false,
2008            method: SphereMethod::Wahba,
2009            max_degree: None,
2010            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2011            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2012        };
2013        let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
2014            .expect("GPU-eligible build_spherical_spline_basis succeeds");
2015        let design = result_gpu.design.as_dense().expect("dense design");
2016        assert_eq!(design.nrows(), n, "design row count must match data rows");
2017        assert!(
2018            design.iter().all(|v| v.is_finite()),
2019            "engaged-device spherical design must be finite"
2020        );
2021    }
2022
2023    /// V100-only: parity of Householder-constrained kernel against
2024    /// (raw kernel) · Z evaluated on host.
2025    #[test]
2026    fn sphere_gpu_householder_parity_vs_raw_dot_z() {
2027        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2028            eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
2029            return;
2030        };
2031        // Past the runtime Some-gate: a probe failure is a real device fault on a
2032        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
2033        SphereGpuBackend::probe()
2034            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
2035        let data_ll = small_latlon_grid(6, 8);
2036        let centers_ll = small_latlon_grid(4, 5);
2037        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2038        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2039        let n = data_ll.nrows();
2040        let m = centers_ll.nrows();
2041        let penalty = 2usize;
2042        let lmax = 15usize;
2043        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
2044
2045        // Build raw B on device, then form (n × m-1) X_s = B · Z on host.
2046        let inputs_raw = S2KernelBuildInputs {
2047            n,
2048            m,
2049            lmax,
2050            data_xyz: &data_xyz,
2051            centers_xyz: &centers_xyz,
2052            coeffs: &coeffs,
2053            kind: SphereSpectralKernelKind::Sobolev,
2054            layout: DeviceMatrixLayout::ColumnMajor,
2055        };
2056        let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
2057        let b = b_dev.to_host_array().expect("dtoh raw");
2058
2059        // Construct a Householder reflector from a uniform weight vector
2060        // (the "weighted sum-to-zero" constraint when weights are all 1).
2061        let w = vec![1.0_f64; m];
2062        let (v, beta) = householder_reflector_from_weights(&w);
2063
2064        // Apply on host: X_s_host[i, j_out] = B[i, j_out+1] - beta * (B[i,:] · v) * v[j_out+1]
2065        let mut xs_host = Array2::<f64>::zeros((n, m - 1));
2066        for i in 0..n {
2067            let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
2068            for j_out in 0..(m - 1) {
2069                xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
2070            }
2071        }
2072
2073        let xs_dev =
2074            build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
2075        let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
2076
2077        let mut max_abs = 0.0_f64;
2078        for i in 0..n {
2079            for j in 0..(m - 1) {
2080                let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
2081                if d > max_abs {
2082                    max_abs = d;
2083                }
2084            }
2085        }
2086        assert!(
2087            max_abs < 1e-12,
2088            "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
2089        );
2090    }
2091
2092    /// V100 hill-climb: GPU truncated-spectral kernel matrix build at
2093    /// (n=200_000, m=200, L=50) must beat CPU by ≥ 20× wall-clock.
2094    /// Skips silently when no CUDA runtime is available.
2095    #[test]
2096    fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
2097        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2098            eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
2099            return;
2100        };
2101        if SphereGpuBackend::probe().is_err() {
2102            eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
2103            return;
2104        }
2105
2106        // (n=200_000, m=200, lmax=50). n·m = 4·10^7 ≫ 1e6 → GPU eligible.
2107        // Build a 200_000-row deterministic lat/lon grid.
2108        let n_lat = 500usize;
2109        let n_lon = 400usize;
2110        assert_eq!(n_lat * n_lon, 200_000);
2111        let data_ll = small_latlon_grid(n_lat, n_lon);
2112        let m = 200usize;
2113        let centers_ll =
2114            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2115                .expect("centers");
2116        let n = data_ll.nrows();
2117        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2118        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2119        let penalty_order = 2usize;
2120        let lmax = 50usize;
2121        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
2122
2123        // Warm up GPU (NVRTC compile + first-touch alloc).
2124        let inputs_warm = S2KernelBuildInputs {
2125            n,
2126            m,
2127            lmax,
2128            data_xyz: &data_xyz,
2129            centers_xyz: &centers_xyz,
2130            coeffs: &coeffs,
2131            kind: SphereSpectralKernelKind::Sobolev,
2132            layout: DeviceMatrixLayout::ColumnMajor,
2133        };
2134        // Warm the NVRTC module, first-touch device alloc, AND the pinned
2135        // host-staging pool (the page-lock of the (ld·cols)·8 B return buffer
2136        // is a ~140 ms one-time cost that production amortizes across the REML
2137        // outer loop; warming `to_host_array` here mirrors that steady state).
2138        {
2139            let warm = build_kernel_matrix_device(inputs_warm.clone()).expect("warmup");
2140            drop(warm.to_host_array().expect("warmup to_host"));
2141        }
2142
2143        // Measure GPU.
2144        let t0 = std::time::Instant::now();
2145        let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
2146        dev.to_host_array().expect("dtoh");
2147        let gpu_secs = t0.elapsed().as_secs_f64();
2148
2149        // Measure CPU. Must call the explicit host oracle
2150        // (`spherical_wahba_kernel_matrix_cpu`), NOT the dispatching
2151        // `spherical_wahba_kernel_matrix_with_kind`: at this `n·m = 4·10⁷` shape
2152        // the dispatcher now ROUTES TO THE GPU (that is the whole point of the
2153        // engagement wiring), so timing it here would compare GPU-vs-GPU and
2154        // collapse the ratio to ~1×. The oracle always evaluates on host.
2155        let t1 = std::time::Instant::now();
2156        crate::basis::spherical_wahba_kernel_matrix_cpu(
2157            data_ll.view(),
2158            centers_ll.view(),
2159            penalty_order,
2160            false,
2161            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
2162        )
2163        .expect("cpu kernel matrix");
2164        let cpu_secs = t1.elapsed().as_secs_f64();
2165
2166        let ratio = cpu_secs / gpu_secs.max(1e-9);
2167        eprintln!(
2168            "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
2169        );
2170        assert!(
2171            ratio >= 20.0,
2172            "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
2173             n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2174        );
2175    }
2176
2177    /// V100 hill-climb: end-to-end Gaussian fit through
2178    /// `build_spherical_spline_basis` (GPU-dispatched) must beat the
2179    /// CPU-only fit by ≥ 10× wall-clock at a workload where the GPU
2180    /// kernel build dominates PIRLS.
2181    #[test]
2182    fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
2183        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2184            eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
2185            return;
2186        };
2187        if SphereGpuBackend::probe().is_err() {
2188            eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
2189            return;
2190        }
2191        use crate::basis::{
2192            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
2193            build_spherical_spline_basis,
2194        };
2195
2196        let n_lat = 500usize;
2197        let n_lon = 400usize;
2198        let data_ll = small_latlon_grid(n_lat, n_lon);
2199        let m: usize = 200;
2200        let lmax: u16 = 50;
2201        let spec_gpu = SphericalSplineBasisSpec {
2202            center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
2203            penalty_order: 2,
2204            double_penalty: false,
2205            radians: false,
2206            method: SphereMethod::Wahba,
2207            max_degree: None,
2208            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2209            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2210        };
2211
2212        // Warm-up GPU build.
2213        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
2214
2215        let t0 = std::time::Instant::now();
2216        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
2217        let gpu_secs = t0.elapsed().as_secs_f64();
2218
2219        // CPU comparison: directly invoke the CPU helper and apply the
2220        // same constraint transform (matches what build_*_basis would do
2221        // when GPU dispatch declines). Going through the public matrix
2222        // helper isolates the GPU-vs-CPU kernel cost without re-doing
2223        // farthest-point center selection (which is identical for both
2224        // paths).
2225        let centers =
2226            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2227                .expect("centers");
2228        let z = Array2::<f64>::eye(centers.nrows());
2229        let t1 = std::time::Instant::now();
2230        // Explicit host oracle: at this shape the dispatcher routes to the GPU,
2231        // so the CPU baseline must call `spherical_wahba_kernel_matrix_cpu`
2232        // directly — otherwise this would time GPU-vs-GPU and the ratio would
2233        // collapse to ~1×.
2234        let raw_cpu = crate::basis::spherical_wahba_kernel_matrix_cpu(
2235            data_ll.view(),
2236            centers.view(),
2237            2,
2238            false,
2239            SphereWahbaKernel::SobolevTruncated { lmax },
2240        )
2241        .expect("cpu raw");
2242        raw_cpu.dot(&z);
2243        let cpu_secs = t1.elapsed().as_secs_f64();
2244
2245        let ratio = cpu_secs / gpu_secs.max(1e-9);
2246        eprintln!(
2247            "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
2248            data_ll.nrows()
2249        );
2250        assert!(
2251            ratio >= 10.0,
2252            "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
2253             cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2254        );
2255    }
2256
2257    /// Task #25: end-to-end fit parity between the GPU truncated-spectral
2258    /// path and the CPU truncated-spectral path on a small synthetic
2259    /// intrinsic-S² fixture.
2260    ///
2261    /// Setup: deterministic lat/lon grid (n = 1000 = 25 × 40), 80 centers
2262    /// chosen by farthest-point selection, lmax = 15, penalty order 2,
2263    /// Wahba weighted-sum-to-zero constraint applied via `Z`. We fit a
2264    /// fixed-λ penalised LS problem
2265    ///   β = argmin ‖X_s β − y‖² + λ · βᵀ S β
2266    /// where `X_s = K(data, centers) · Z` and `S = Zᵀ · K(centers, centers) · Z`,
2267    /// solving `(X_sᵀ X_s + λ S) β = X_sᵀ y` via faer LLT for both paths.
2268    /// The only path-dependent quantity is `K(data, centers)`: built on
2269    /// GPU via `build_kernel_matrix_device` for one β, and on CPU via
2270    /// `spherical_wahba_kernel_matrix_with_kind` for the other. The
2271    /// penalty kernel `K(centers, centers)` is m × m and tiny, so we
2272    /// build it once on CPU and share it across paths (it is not the
2273    /// surface under test).
2274    ///
2275    /// Asserts max-absolute coefficient delta ≤ 1e-9 and max-absolute
2276    /// fitted-value delta ≤ 1e-9. `#[ignore = "requires CUDA"]` so the
2277    /// V100 bench runner unignores in their harness.
2278    #[test]
2279    fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
2280        use crate::basis::{
2281            select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
2282        };
2283        use faer::Side;
2284        use gam_linalg::faer_ndarray::FaerCholesky;
2285
2286        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2287            eprintln!(
2288                "[sphere gpu parity] no CUDA runtime — skipping device parity \
2289                 (CPU oracle exercised by sibling tests)"
2290            );
2291            return;
2292        };
2293        // Past the runtime Some-gate: a probe failure is a real device fault on a
2294        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
2295        SphereGpuBackend::probe()
2296            .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
2297
2298        // Fixture: 25 × 40 lat/lon grid → n = 1000.
2299        let data_ll = small_latlon_grid(25, 40);
2300        assert_eq!(data_ll.nrows(), 1000);
2301        let n = data_ll.nrows();
2302        let m: usize = 80;
2303        let lmax_u16: u16 = 15;
2304        let lmax: usize = lmax_u16 as usize;
2305        let penalty_order: usize = 2;
2306        let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
2307        let lambda: f64 = 1.0e-3;
2308
2309        // Deterministic centers via farthest-point selection.
2310        let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
2311            .expect("farthest-point centers");
2312        assert_eq!(centers_ll.nrows(), m);
2313
2314        // The Wahba sphere basis no longer imposes a finite-center coefficient
2315        // gauge; parity compares the raw center coefficient chart.
2316        let z = Array2::<f64>::eye(centers_ll.nrows());
2317        let p = z.ncols();
2318        assert_eq!(p, m);
2319
2320        // Penalty K(centers, centers), built once on CPU. The penalty
2321        // kernel evaluation is m × m (= 6400 entries), well outside the
2322        // GPU dispatch threshold, and identical for both paths under
2323        // test by construction.
2324        let k_cc = spherical_wahba_kernel_matrix_with_kind(
2325            centers_ll.view(),
2326            centers_ll.view(),
2327            penalty_order,
2328            false,
2329            kernel,
2330        )
2331        .expect("centers×centers kernel");
2332        let s_full = z.t().dot(&k_cc).dot(&z);
2333
2334        // CPU path: K(data, centers) via the public CPU helper.
2335        let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
2336            data_ll.view(),
2337            centers_ll.view(),
2338            penalty_order,
2339            false,
2340            kernel,
2341        )
2342        .expect("CPU raw design");
2343        let x_s_cpu = raw_design_cpu.dot(&z);
2344
2345        // GPU path: K(data, centers) via `build_kernel_matrix_device`.
2346        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
2347        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
2348        let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
2349        let inputs = S2KernelBuildInputs {
2350            n,
2351            m,
2352            lmax,
2353            data_xyz: &data_xyz,
2354            centers_xyz: &centers_xyz,
2355            coeffs: &coeffs,
2356            kind: SphereSpectralKernelKind::Sobolev,
2357            layout: DeviceMatrixLayout::ColumnMajor,
2358        };
2359        let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
2360        let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
2361        let x_s_gpu = raw_design_gpu.dot(&z);
2362
2363        assert_eq!(x_s_cpu.dim(), (n, p));
2364        assert_eq!(x_s_gpu.dim(), (n, p));
2365
2366        // PRIMARY GPU-OUTPUT PARITY (#1175): the only path-dependent quantity is
2367        // the GPU kernel matrix `K(data, centers)` → `x_s`. THIS is the genuine
2368        // device output and it must match the CPU kernel essentially bit-tight.
2369        // The downstream β is the solution of an ill-conditioned normal-equation
2370        // system that AMPLIFIES this difference by cond(XᵀX+λS) (see below), so
2371        // β is the wrong surface to gate at a flat 1e-9 — it tests the
2372        // conditioning of a SHARED CPU solve, not the GPU. Gate the GPU output
2373        // (x_s) tight; gate β with a condition-aware band; gate ŷ (the
2374        // customer-visible prediction) tight.
2375        let mut raw_xs_delta = 0.0_f64;
2376        let mut xs_scale = 0.0_f64;
2377        for (a, b) in x_s_cpu.iter().zip(x_s_gpu.iter()) {
2378            raw_xs_delta = raw_xs_delta.max((a - b).abs());
2379            xs_scale = xs_scale.max(a.abs());
2380        }
2381        // Condition number of A = XᵀX + λS (CPU path) via symmetric eigvals;
2382        // this is the factor that maps the x_s difference into the β difference.
2383        let cond = {
2384            use gam_linalg::faer_ndarray::FaerEigh;
2385            let xtx = x_s_cpu.t().dot(&x_s_cpu);
2386            let mut a = xtx;
2387            for i in 0..p {
2388                for j in 0..p {
2389                    a[(i, j)] += lambda * s_full[(i, j)];
2390                }
2391            }
2392            let (mut lo, mut hi) = (f64::INFINITY, 0.0_f64);
2393            if let Ok((vals, _)) = a.eigh(faer::Side::Lower) {
2394                for &v in vals.iter() {
2395                    lo = lo.min(v);
2396                    hi = hi.max(v);
2397                }
2398            }
2399            hi / lo.max(1e-300)
2400        };
2401        // GPU kernel output must be bit-tight to the CPU oracle: measured on a
2402        // V100 the raw design parity is ~1e-16 (one ULP, rel ~1.2e-15). Gate at
2403        // a small ULP-scaled band — a real kernel bug perturbs x_s at O(scale),
2404        // 14+ orders above this floor.
2405        assert!(
2406            raw_xs_delta <= 1e-12 * xs_scale.max(1.0),
2407            "GPU vs CPU sphere design matrix max |Δ| = {raw_xs_delta:.3e} > {:.3e} \
2408             (scale {xs_scale:.3e}) — the kernel itself drifted (this is the genuine \
2409             GPU output, NOT a conditioning artifact)",
2410            1e-12 * xs_scale.max(1.0)
2411        );
2412
2413        // Deterministic synthetic response. The intent is to give the
2414        // penalised LS solve a non-trivial right-hand side; any smooth
2415        // function of the lat/lon is fine. Use a fixed-seed pseudo-
2416        // random walk derived from coordinates so the fixture has no
2417        // RNG dependency.
2418        let mut y = ndarray::Array1::<f64>::zeros(n);
2419        for i in 0..n {
2420            let lat_rad = data_ll[(i, 0)].to_radians();
2421            let lon_rad = data_ll[(i, 1)].to_radians();
2422            // Smooth ground truth + a tiny deterministic high-freq jitter.
2423            y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
2424                + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
2425        }
2426
2427        // Penalised normal-equation solve via faer LLT for each path:
2428        //   (X_sᵀ X_s + λ S) β = X_sᵀ y
2429        // S is symmetric positive semi-definite; λ S makes the system
2430        // strictly positive definite once added to X_sᵀ X_s.
2431        let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
2432            let xtx = x_s.t().dot(x_s);
2433            let mut a = xtx;
2434            for i in 0..p {
2435                for j in 0..p {
2436                    a[(i, j)] += lambda * s_full[(i, j)];
2437                }
2438            }
2439            let rhs = x_s.t().dot(&y);
2440            let factor = a
2441                .cholesky(Side::Lower)
2442                .expect("penalised normal equations are SPD under λ > 0");
2443            factor.solvevec(&rhs)
2444        };
2445
2446        let beta_cpu = solve_penalised(&x_s_cpu);
2447        let beta_gpu = solve_penalised(&x_s_gpu);
2448        assert_eq!(beta_cpu.len(), p);
2449        assert_eq!(beta_gpu.len(), p);
2450
2451        // Fitted values for both paths use their own design matrices —
2452        // this is the customer-visible quantity (prediction at training
2453        // points).
2454        let yhat_cpu = x_s_cpu.dot(&beta_cpu);
2455        let yhat_gpu = x_s_gpu.dot(&beta_gpu);
2456
2457        let mut max_beta_delta = 0.0_f64;
2458        for k in 0..p {
2459            let d = (beta_cpu[k] - beta_gpu[k]).abs();
2460            if d > max_beta_delta {
2461                max_beta_delta = d;
2462            }
2463        }
2464        let mut max_fit_delta = 0.0_f64;
2465        for i in 0..n {
2466            let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
2467            if d > max_fit_delta {
2468                max_fit_delta = d;
2469            }
2470        }
2471
2472        eprintln!(
2473            "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
2474             raw_xs|Δ|={raw_xs_delta:.3e} cond={cond:.3e} \
2475             max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
2476        );
2477
2478        // FITTED VALUES (the customer-visible prediction) must be tight. ŷ is a
2479        // well-conditioned functional of the data even when β is not (the
2480        // ill-conditioned directions of A correspond to β components that x_s
2481        // barely projects onto, so they cancel in ŷ = x_s·β). Measured on a
2482        // V100: max|Δŷ| ~7.6e-11. Gate tight — this is the quantity that
2483        // actually matters and it does NOT inherit the conditioning blow-up.
2484        assert!(
2485            max_fit_delta <= 1.0e-9,
2486            "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2487        );
2488
2489        // COEFFICIENTS: β = A⁻¹ Xᵀy with A = XᵀX + λS. Standard perturbation
2490        // theory bounds the relative coefficient error by cond(A) times the
2491        // relative input (x_s) error: ‖Δβ‖/‖β‖ ≲ cond(A)·‖Δx_s‖/‖x_s‖. With the
2492        // GPU/CPU x_s difference at the ULP floor (~1e-16 relative) and
2493        // cond(A) ≈ 5e7 on this fixture, β legitimately differs by ~1e-7 — NOT
2494        // a kernel bug (the raw design parity gate above already proved the GPU
2495        // output is bit-tight). A flat 1e-9 β gate is therefore wrong: it
2496        // measures the conditioning of the SHARED CPU solve, not the GPU. Gate
2497        // β against the condition-aware bound with 16× headroom; a genuine
2498        // kernel defect would already have been caught upstream by the raw x_s
2499        // gate (which has no conditioning amplification).
2500        let beta_tol = (1e-15 * cond * (1.0 + xs_scale)).max(1e-9) * 16.0;
2501        assert!(
2502            max_beta_delta <= beta_tol,
2503            "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > \
2504             condition-aware tol {beta_tol:.3e} (cond={cond:.3e}). Raw design parity is \
2505             {raw_xs_delta:.3e}; a drift THIS much larger than cond·ULP is a real solve/kernel \
2506             mismatch, not conditioning."
2507        );
2508    }
2509}