Skip to main content

gam_terms/basis/
sphere_gpu.rs

1//! GPU NVRTC Wahba intrinsic-S2 kernel matrix construction.
2//!
3//! This module owns the device-side construction of the Wahba reproducing
4//! kernel basis matrix on the 2-sphere using the **finite truncated
5//! spectral Legendre series**
6//!
7//! `K_L(γ) = Σ_{ℓ=1..L} c_ℓ · P_ℓ(cos γ)`,
8//!
9//! evaluated entry-by-entry against the 3-term Legendre recurrence kept
10//! in registers. The host CPU parity target is the matching
11//! `SphereWahbaKernel::SobolevTruncated { lmax }` /
12//! `SphereWahbaKernel::PseudoTruncated { lmax }` variant added to
13//! `src/terms/basis.rs` (single source: same recurrence, same c_ℓ).
14//!
15//! The device path evaluates the raw column-major kernel matrix with `f64`
16//! Legendre recurrence math. Host code owns centering, constraints, and solver
17//! assembly in `basis.rs`.
18
19use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Truncated-spectral Wahba kernel family evaluated by the device code.
///
/// The variants correspond one-to-one with the CPU
/// `SphereWahbaKernel::{SobolevTruncated, PseudoTruncated}` variants, which is
/// what makes CPU/GPU parity tests well-defined.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
    /// True `H^m(S²)` Sobolev RKHS: `c_ℓ = (2ℓ+1) / (4π · [ℓ(ℓ+1)]^m)`.
    Sobolev,
    /// Wahba (1981) pseudo-spline: `c_ℓ = 2 / (4π · Π_{k=1..m+1}(ℓ + k))`.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48    /// `c_0 = 0`, `c_ℓ = c_ℓ(m)` for `ℓ = 1..=lmax`. Returned vector has
49    /// length `lmax + 1` and is uploaded verbatim to constant/global
50    /// memory before kernel launch.
51    pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52        match self {
53            SphereSpectralKernelKind::Sobolev => {
54                crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55            }
56            SphereSpectralKernelKind::Pseudo => {
57                crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58            }
59        }
60    }
61
62    /// Stable string tag used in the NVRTC module cache key + logs.
63    pub const fn tag(self) -> &'static str {
64        match self {
65            SphereSpectralKernelKind::Sobolev => "sobolev",
66            SphereSpectralKernelKind::Pseudo => "pseudo",
67        }
68    }
69}
70
/// Storage order of the `(n × m)` kernel design matrix on the device.
///
/// Column-major is the only option: the downstream Wahba pipeline
/// (cuBLAS GEMM, cuSOLVER GEQRF) consumes column-major operands.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Entry `(i, j)` is stored at offset `j * ld + i`.
    ColumnMajor,
}
78
79/// Lat/lon (degrees or radians) → unit vector `(x, y, z)` on S² ⊂ ℝ³.
80/// Returns a flat `Vec<f64>` of length `3 * n` in the row-major layout
81/// `[x_0, y_0, z_0, x_1, y_1, z_1, …]`, ready for one `htod` upload.
82///
83/// `radians = false` interprets inputs as degrees (the codebase default
84/// for `SphericalSplineBasisSpec`).
85pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86    if latlon.ncols() != 2 {
87        return Err(format!(
88            "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89            latlon.shape()
90        ));
91    }
92    let deg = if radians {
93        1.0
94    } else {
95        std::f64::consts::PI / 180.0
96    };
97    let n = latlon.nrows();
98    let mut out = Vec::with_capacity(3 * n);
99    for row in latlon.outer_iter() {
100        let lat = row[0] * deg;
101        let lon = row[1] * deg;
102        let (s_lat, c_lat) = lat.sin_cos();
103        let (s_lon, c_lon) = lon.sin_cos();
104        // Standard geographic→cartesian: pole on +z.
105        out.push(c_lat * c_lon);
106        out.push(c_lat * s_lon);
107        out.push(s_lat);
108    }
109    Ok(out)
110}
111
112/// Device-resident `(rows × cols)` matrix in column-major layout with
113/// leading dimension `ld ≥ rows`. The slice holds `ld * cols` `f64`
114/// elements; entry `(i, j)` lives at `col_major_dev[j * ld + i]`.
115///
116/// On non-Linux builds the type is intentionally a host shadow so the
117/// surrounding orchestration compiles without cudarc.
118#[cfg(target_os = "linux")]
119pub struct DeviceS2KernelMatrix {
120    pub rows: usize,
121    pub cols: usize,
122    pub ld: usize,
123    pub col_major_dev: CudaSlice<f64>,
124    pub stream: Arc<CudaStream>,
125}
126
127#[cfg(not(target_os = "linux"))]
128pub struct DeviceS2KernelMatrix {
129    pub rows: usize,
130    pub cols: usize,
131    pub ld: usize,
132    /// Host shadow for CPU-only builds.
133    pub col_major_dev: Vec<f64>,
134}
135
136impl DeviceS2KernelMatrix {
137    /// Copy the device matrix back to the host as a regular ndarray
138    /// `(rows × cols)` row-major view. Convenience for tests + parity
139    /// comparisons; production paths should keep the matrix resident.
140    ///
141    /// The device matrix is `(ld × cols)` column-major; the host wants
142    /// `(rows × cols)` row-major. Two costs dominate this round-trip on the
143    /// real V100:
144    ///   1. the device→host copy of the full `ld·cols·8 B` payload, and
145    ///   2. the column-major→row-major transpose.
146    /// On Linux the dtoh is staged through a *cacheable* pinned host buffer
147    /// (see [`PinnedF64`]) so the DMA runs at full PCIe bandwidth (~10 GB/s)
148    /// instead of the ~1.3 GB/s the driver achieves staging a pageable
149    /// destination, and the subsequent host reads during the transpose hit
150    /// L1/L2 normally (unlike write-combined pinned memory). The transpose
151    /// itself is the parallel cache-blocked [`col_major_to_row_major_parallel`].
152    #[cfg(target_os = "linux")]
153    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
154        let needed = self.ld * self.cols;
155        let mut staging = PinnedLease::acquire(self.stream.context(), needed)?;
156        self.stream
157            .memcpy_dtoh(&self.col_major_dev, staging.as_mut_slice())
158            .gpu_ctx("DeviceS2KernelMatrix dtoh (pinned)")?;
159        self.stream
160            .synchronize()
161            .gpu_ctx("DeviceS2KernelMatrix synchronize (pinned)")?;
162        Ok(col_major_to_row_major_parallel(
163            staging.as_slice(),
164            self.rows,
165            self.cols,
166            self.ld,
167        ))
168    }
169
170    #[cfg(not(target_os = "linux"))]
171    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
172        // Mirror the linux `to_host_array` exactly so both platforms return the
173        // identical row-major layout: pull the padded `(ld × cols)` column-major
174        // payload, then run the cache-blocked parallel transpose.
175        let mut col_major = vec![0.0_f64; self.ld * self.cols];
176        self.copy_to_host_col_major(&mut col_major)?;
177        Ok(col_major_to_row_major_parallel(
178            &col_major, self.rows, self.cols, self.ld,
179        ))
180    }
181
182    /// Copy the underlying `(ld × cols)` column-major payload to a
183    /// caller-provided buffer. Used by `to_host_array` and by the
184    /// device-resident cuSOLVER consumer when it needs to extract the
185    /// coefficient vector.
186    #[cfg(target_os = "linux")]
187    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
188        let needed = self.ld * self.cols;
189        if dst.len() != needed {
190            gam_gpu::gpu_bail!(
191                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
192                dst.len(),
193                needed
194            );
195        }
196        self.stream
197            .memcpy_dtoh(&self.col_major_dev, dst)
198            .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
199        self.stream
200            .synchronize()
201            .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
202        Ok(())
203    }
204
205    #[cfg(not(target_os = "linux"))]
206    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
207        let needed = self.ld * self.cols;
208        if dst.len() != needed {
209            gam_gpu::gpu_bail!(
210                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
211                dst.len(),
212                needed
213            );
214        }
215        dst.copy_from_slice(&self.col_major_dev);
216        Ok(())
217    }
218}
219
220/// Convert a `(ld × cols)` column-major device payload into a row-major
221/// `(rows × cols)` host `Array2`, in parallel with a cache-blocked tiled
222/// transpose.
223///
224/// Entry `(i, j)` lives at `col_major[j * ld + i]` and must land at
225/// `out[i * cols + j]`. A naive scalar `out[(i, j)] = col_major[j*ld+i]`
226/// loop over an `n·m` design (e.g. 200_000 × 200 ⇒ 320 MB) is utterly
227/// cache-hostile — the read stride is `ld` doubles — and measured at ~9 s,
228/// which alone made the GPU path lose to CPU. Here we:
229///   * tile the output rows into blocks small enough that one block's
230///     output stays L2-resident (`BLOCK_ROWS` rows × `cols` doubles),
231///   * read each source column slice contiguously (`col_major[j*ld+r0..]`),
232///   * run the row-blocks across the rayon pool.
233/// Reads are fully sequential per column; writes are bounded to the hot
234/// block. This drops the transpose from seconds to tens of milliseconds.
235fn col_major_to_row_major_parallel(
236    col_major: &[f64],
237    rows: usize,
238    cols: usize,
239    ld: usize,
240) -> Array2<f64> {
241    use rayon::prelude::*;
242
243    assert!(ld >= rows, "ld {ld} must be >= rows {rows}");
244    assert!(
245        col_major.len() >= ld * cols,
246        "col_major len {} < ld*cols {}",
247        col_major.len(),
248        ld * cols
249    );
250
251    // Block size chosen so one output block (BLOCK_ROWS × cols × 8 B) plus the
252    // source column slices stay roughly within L2 for the common `cols ≲ 200`.
253    const BLOCK_ROWS: usize = 128;
254
255    let mut out_flat = vec![0.0_f64; rows * cols];
256    out_flat
257        .par_chunks_mut(BLOCK_ROWS * cols)
258        .enumerate()
259        .for_each(|(block_idx, out_block)| {
260            let r0 = block_idx * BLOCK_ROWS;
261            let block_rows = out_block.len() / cols;
262            for j in 0..cols {
263                let base = j * ld + r0;
264                let src_col = &col_major[base..base + block_rows];
265                // Strided write within the hot block; contiguous column read.
266                for (local_i, &v) in src_col.iter().enumerate() {
267                    out_block[local_i * cols + j] = v;
268                }
269            }
270        });
271
272    Array2::from_shape_vec((rows, cols), out_flat).expect("row-major buffer has rows*cols elements")
273}
274
/// Owning handle to a page-locked, *cacheable* host buffer of `f64`s.
///
/// cudarc's `CudaContext::alloc_pinned` always requests
/// `CU_MEMHOSTALLOC_WRITECOMBINED`. Write-combined memory suits host→device
/// uploads, but it is uncached on the CPU side, which cripples the host reads
/// the transpose performs. For the device→host return path this type instead
/// asks the driver directly for plain pinned memory (`flags = 0`). It is still
/// page-locked, so the dtoh DMA runs at full PCIe bandwidth, and it is
/// cacheable, so the parallel transpose reads it through the normal cache
/// hierarchy. `Drop` releases it with `cuMemFreeHost`.
#[cfg(target_os = "linux")]
struct PinnedF64 {
    /// Start of the allocation returned by `cuMemHostAlloc` (checked non-null).
    ptr: *mut f64,
    /// Capacity in `f64` elements.
    len: usize,
    /// Set once the allocation has been released, so `Drop` frees it only once.
    freed: bool,
}
291
#[cfg(target_os = "linux")]
impl PinnedF64 {
    /// Allocate `len` cacheable pinned `f64`s. Binds the context to the
    /// calling thread first (required before any driver allocation call).
    ///
    /// The contents are not initialised here; callers are expected to fill
    /// the buffer (e.g. via a dtoh copy) before reading it.
    ///
    /// # Errors
    /// Fails if binding the context fails, if `len * size_of::<f64>()`
    /// overflows `usize`, if `cuMemHostAlloc` reports an error, or if it
    /// returns a null pointer.
    ///
    /// NOTE(review): `flags = 0` omits `CU_MEMHOSTALLOC_PORTABLE`, so per the
    /// CUDA driver docs the memory counts as pinned only for the context that
    /// allocated it. The pool reuses buffers by length alone; confirm that
    /// only one context ever leases them.
    fn alloc(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
        ctx.bind_to_thread().gpu_ctx("PinnedF64 bind_to_thread")?;
        let bytes = len
            .checked_mul(std::mem::size_of::<f64>())
            .ok_or_else(|| gam_gpu::gpu_err!("PinnedF64: len={len} byte size overflows usize"))?;
        // flags = 0 ⇒ cacheable pinned (NOT write-combined): fast DMA *and*
        // fast host reads for the subsequent transpose.
        // SAFETY: `bytes` is a valid non-overflowing size; the returned host
        // pointer is owned by this struct and freed exactly once in `drop`.
        let raw = unsafe { cudarc::driver::result::malloc_host(bytes, 0) }
            .gpu_ctx("PinnedF64 cuMemHostAlloc")?;
        let ptr = raw as *mut f64;
        // A null pointer would make the `from_raw_parts` views below unsound,
        // so treat it as an allocation failure.
        if ptr.is_null() {
            gam_gpu::gpu_bail!("PinnedF64: cuMemHostAlloc returned null for {bytes} bytes");
        }
        Ok(Self {
            ptr,
            len,
            freed: false,
        })
    }

    /// Mutable view of all `len` elements (the dtoh destination).
    fn as_mut_slice(&mut self) -> &mut [f64] {
        // SAFETY: `ptr` points to `len` f64s of live pinned memory owned by
        // self; the borrow is bounded by `&mut self`.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Shared view of all `len` elements (the transpose source).
    fn as_slice(&self) -> &[f64] {
        // SAFETY: as above; shared borrow bounded by `&self`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
}
329
#[cfg(target_os = "linux")]
impl Drop for PinnedF64 {
    /// Release the page-locked allocation with `cuMemFreeHost`, at most once.
    fn drop(&mut self) {
        // `freed` guards against releasing the same pointer twice.
        if self.freed {
            return;
        }
        self.freed = true;
        // SAFETY: `ptr` was returned by `cuMemHostAlloc` in `alloc` and is
        // freed exactly once (guarded by `freed`). A free failure during Drop
        // is unrecoverable here; absorb it (the host process is tearing the
        // allocation down regardless) without unwinding out of Drop.
        // NOTE(review): unlike `alloc`, no context is bound to the current
        // thread before this driver call; confirm `cuMemFreeHost` does not
        // need one in this setup.
        unsafe { cudarc::driver::result::free_host(self.ptr as *mut std::ffi::c_void) }.ok();
    }
}
344
// SAFETY: `PinnedF64` owns a single raw host allocation. The pointer is only
// dereferenced by the thread holding the (mutable or shared) borrow; the pool
// below moves the *handle* between threads while no borrow is outstanding, and
// the rayon transpose only ever sees a `&[f64]` (already `Send + Sync`). The
// raw pointer itself is never shared concurrently.
// `Sync` is deliberately not implemented: storing handles in a
// `Mutex<Vec<PinnedF64>>` static only requires the element type to be `Send`.
#[cfg(target_os = "linux")]
unsafe impl Send for PinnedF64 {}
352
353/// Bounded free-list of cacheable pinned host buffers, keyed by length.
354///
355/// Page-locking 320 MB via `cuMemHostAlloc` costs ~140 ms on the V100 — far
356/// more than the dtoh (~25 ms) it accelerates. During a REML fit the sphere
357/// design matrix is rebuilt and copied back at the *same* `(ld·cols)` size on
358/// every outer iteration, so caching the page-locked buffer turns that 140 ms
359/// into a one-time cost. The pool keeps at most [`PINNED_POOL_MAX_BUFFERS`]
360/// buffers (LRU-ish: oldest dropped first) to bound resident pinned memory.
361#[cfg(target_os = "linux")]
362const PINNED_POOL_MAX_BUFFERS: usize = 4;
363
364#[cfg(target_os = "linux")]
365static PINNED_POOL: OnceLock<Mutex<Vec<PinnedF64>>> = OnceLock::new();
366
367/// RAII lease of a pooled pinned buffer. Returns the buffer to [`PINNED_POOL`]
368/// on drop instead of freeing it, so the next same-size request reuses the
369/// page-locked allocation.
370#[cfg(target_os = "linux")]
371struct PinnedLease {
372    buf: Option<PinnedF64>,
373}
374
375#[cfg(target_os = "linux")]
376impl PinnedLease {
377    /// Acquire a pinned buffer of at least `len` f64s, reusing a pooled one of
378    /// exactly `len` when available, else allocating fresh.
379    fn acquire(ctx: &Arc<CudaContext>, len: usize) -> Result<Self, GpuError> {
380        let pool = PINNED_POOL.get_or_init(|| Mutex::new(Vec::new()));
381        if let Ok(mut guard) = pool.lock() {
382            if let Some(pos) = guard.iter().position(|b| b.len == len) {
383                return Ok(Self {
384                    buf: Some(guard.swap_remove(pos)),
385                });
386            }
387        }
388        Ok(Self {
389            buf: Some(PinnedF64::alloc(ctx, len)?),
390        })
391    }
392
393    fn as_mut_slice(&mut self) -> &mut [f64] {
394        self.buf
395            .as_mut()
396            .expect("PinnedLease buffer present until drop")
397            .as_mut_slice()
398    }
399
400    fn as_slice(&self) -> &[f64] {
401        self.buf
402            .as_ref()
403            .expect("PinnedLease buffer present until drop")
404            .as_slice()
405    }
406}
407
408#[cfg(target_os = "linux")]
409impl Drop for PinnedLease {
410    fn drop(&mut self) {
411        let Some(buf) = self.buf.take() else {
412            return;
413        };
414        if let Some(pool) = PINNED_POOL.get() {
415            if let Ok(mut guard) = pool.lock() {
416                if guard.len() < PINNED_POOL_MAX_BUFFERS {
417                    guard.push(buf);
418                    return;
419                }
420                // Pool full: evict the oldest cached buffer to make room for
421                // this (most-recently-used) one, keeping resident pinned memory
422                // bounded while favouring the hot size.
423                guard.remove(0);
424                guard.push(buf);
425                return;
426            }
427        }
428        // No pool / poisoned lock: fall back to freeing via PinnedF64::drop.
429        drop(buf);
430    }
431}
432
433// ────────────────────────────────────────────────────────────────────────
434// Inputs
435// ────────────────────────────────────────────────────────────────────────
436
437/// Host-side inputs needed to launch `s2_wahba_legendre_colmajor`.
438///
439/// `data_xyz` and `centers_xyz` are flat row-major
440/// `[x_0, y_0, z_0, …]` length `3 * n` and `3 * m` respectively, pre-
441/// computed via [`latlon_to_xyz_host`]. `coeffs` has length `lmax + 1`,
442/// indexed as `coeffs[ℓ] = c_ℓ` with `c_0 = 0`.
443#[derive(Clone, Debug)]
444pub struct S2KernelBuildInputs<'a> {
445    pub n: usize,
446    pub m: usize,
447    pub lmax: usize,
448    pub data_xyz: &'a [f64],
449    pub centers_xyz: &'a [f64],
450    pub coeffs: &'a [f64],
451    pub kind: SphereSpectralKernelKind,
452    pub layout: DeviceMatrixLayout,
453}
454
455impl<'a> S2KernelBuildInputs<'a> {
456    fn validate(&self) -> Result<(), GpuError> {
457        if self.lmax == 0 {
458            return Err(GpuError::DriverCallFailed {
459                reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
460            });
461        }
462        if self.data_xyz.len() != 3 * self.n {
463            gam_gpu::gpu_bail!(
464                "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
465                self.data_xyz.len(),
466                3 * self.n
467            );
468        }
469        if self.centers_xyz.len() != 3 * self.m {
470            gam_gpu::gpu_bail!(
471                "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
472                self.centers_xyz.len(),
473                3 * self.m
474            );
475        }
476        if self.coeffs.len() != self.lmax + 1 {
477            gam_gpu::gpu_bail!(
478                "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
479                self.coeffs.len(),
480                self.lmax + 1
481            );
482        }
483        if self.coeffs[0] != 0.0 {
484            return Err(GpuError::DriverCallFailed {
485                reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
486            });
487        }
488        Ok(())
489    }
490}
491
492// ────────────────────────────────────────────────────────────────────────
493// NVRTC kernel source — raw and Householder-fused variants.
494//
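// Both compile with `--std=c++17 --gpu-architecture=compute_${cc}` and
// take LMAX as a compile-time `#define`. The raw kernel runs one thread per
// matrix entry on a 2-D (32, 8, 1) block; the Householder-fused kernel runs
// one thread per data row on a 1-D (128, 1, 1) block. Neither kernel uses
// shared memory: each thread reads its data and center coordinates straight
// from global memory into registers.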
499// ────────────────────────────────────────────────────────────────────────
500
/// CUDA C++ source for both S² Wahba kernels, compiled at runtime by NVRTC.
///
/// The template does not define `LMAX` itself; per its leading comment, the
/// host prepends `#define LMAX ...` before compilation. It exposes two
/// `extern "C"` entry points:
///
/// * `s2_wahba_legendre_colmajor`: one thread per `(i, j)` entry (`i` from
///   the y grid dimension, `j` from x). It writes
///   `Σ_{ℓ=0..LMAX} coeffs[ℓ] · P_ℓ(t)` to `out[j * ld + i]`.
/// * `s2_wahba_householder_constrained_colmajor`: one thread per data row
///   `i`. It evaluates the row twice, once to form `d_i = Σ_j v[j] · B[i, j]`
///   and once to emit `B[i, j] - beta · d_i · v[j]` for `j = 1..m`. The output
///   is `ld_out × (m - 1)` column-major.
///
/// In both kernels, `t` is the dot product of the data and center unit
/// vectors, clamped to `[-1, 1]`. The in-source comment's `x_i · z_j` means
/// `x_i · c_j`.
///
/// NOTE(review): rows and centers are indexed as `3 * i + k` / `3 * j + k` in
/// 32-bit `int`. Inputs with `n` or `m` above `i32::MAX / 3` would overflow
/// on device, so the host must bound them before launch.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,    // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
    const double* __restrict__ coeffs,      // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];
    const double cxj = centers_xyz[3 * j + 0];
    const double cyj = centers_xyz[3 * j + 1];
    const double czj = centers_xyz[3 * j + 2];

    // t = clamp(x_i · z_j, -1, +1).
    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t >  1.0) t =  1.0;
    if (t < -1.0) t = -1.0;

    // Legendre 3-term recurrence in registers.
    // P_0(t) = 1, P_1(t) = t.
    double p_prev = 1.0;
    double p_curr = t;
    double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf  = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }

    out[(long long) j * ld + (long long) i] = acc;
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,    // n × 3
    const double* __restrict__ centers_xyz, // m × 3
    const double* __restrict__ coeffs,      // length LMAX + 1
    const double* __restrict__ v,           // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        d_i = fma(v[j], acc, d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        const double xs = acc - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
639
640// ────────────────────────────────────────────────────────────────────────
641// Module cache key + per-process backend.
642// ────────────────────────────────────────────────────────────────────────
643
644/// Module cache key: every distinct `(CC, LMAX, kind, layout, kernel
645/// flavor)` compiles to a different PTX. `precision = f64` and the
646/// (32, 8, 1) raw-kernel block / (128, 1, 1) Householder-kernel block
647/// shapes are baked into the kernel source so they are implicit in the
648/// flavor tag and don't appear here.
649#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
650pub struct S2ModuleCacheKey {
651    pub cc_major: i32,
652    pub cc_minor: i32,
653    pub lmax: u32,
654    pub kind: SphereSpectralKernelKind,
655    pub layout: DeviceMatrixLayout,
656}
657
/// Reports, at compile time, whether this build contains the Linux + cudarc
/// backend that executes the S² Wahba kernels.
pub const fn sphere_gpu_compiled() -> bool {
    const HAS_CUDA_BACKEND: bool = cfg!(target_os = "linux");
    HAS_CUDA_BACKEND
}
663
664/// Decide whether the GPU sphere kernel matrix path is eligible for
665/// `(n, m, lmax)`. Heuristic per the math spec:
666///   * `n * m >= 1_000_000`
667///   * `lmax <= 200`
668///   * device memory budget admits at least one `(ld × m)` design at
669///     `ld = ((n + 31) / 32) * 32`.
670#[must_use]
671pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
672    let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
673        let ld = ((n + 31) / 32) * 32;
674        let needed_bytes = ld
675            .saturating_mul(m)
676            .saturating_mul(std::mem::size_of::<f64>());
677        let budget = runtime.memory_budget_bytes;
678        n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
679    } else {
680        false
681    };
682    decide(
683        GpuKernel::SpatialKernelOperator,
684        gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
685    )
686}
687
688/// Map a truncated `SphereWahbaKernel` variant onto the device kernel kind +
689/// truncation degree. Only the two *truncated* spectral variants have an exact
690/// device counterpart (the closed-form `Sobolev`/`Pseudo` variants use
691/// polylogarithms / deep-`L` series the device kernel does not evaluate), so
692/// `Sobolev`/`Pseudo` return `None` and stay on the CPU closed-form path.
693#[must_use]
694pub fn truncated_device_kind(
695    kernel: crate::basis::SphereWahbaKernel,
696) -> Option<(SphereSpectralKernelKind, u16)> {
697    use crate::basis::SphereWahbaKernel;
698    match kernel {
699        SphereWahbaKernel::SobolevTruncated { lmax } => {
700            Some((SphereSpectralKernelKind::Sobolev, lmax))
701        }
702        SphereWahbaKernel::PseudoTruncated { lmax } => {
703            Some((SphereSpectralKernelKind::Pseudo, lmax))
704        }
705        SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
706    }
707}
708
709/// Production entry: build the raw `(n × m)` truncated-spectral Wahba kernel
710/// design matrix on the GPU when [`sphere_kernel_decision`] admits the device,
711/// returning `None` to signal the caller to use its CPU oracle.
712///
713/// Contract:
714///   * Returns `None` when the kernel is a non-truncated closed-form variant
715///     (no exact device counterpart), or when the dispatch decision keeps the
716///     work on the CPU (`!use_gpu`). The caller then runs the bit-defining CPU
717///     path. This is the **only** quiet-CPU route and it is taken *before* any
718///     device call — never as a silent fallback after a device failure.
719///   * Returns `Some(Ok(matrix))` with the device-computed host array when the
720///     device path ran and matches the CPU truncated recurrence to roundoff
721///     (proven by the parity tests). `gam_gpu::policy` keeps the same `c_ℓ`
722///     array and the same Legendre 3-term recurrence on both sides.
723///   * Returns `Some(Err(_))` when the device was *admitted* but the launch /
724///     NVRTC compile / copy-back failed — a hard error the caller must surface,
725///     NOT degrade to CPU. Fail-loud once admitted (the recurring silent-CPU
726///     fallback is the bug this path exists to kill).
727///
728/// `data` / `centers` are `(_, 2)` lat/lon matrices (degrees unless
729/// `radians`), matching `spherical_wahba_kernel_matrix_with_kind`.
730pub fn try_build_truncated_kernel_matrix_gpu(
731    data: ArrayView2<'_, f64>,
732    centers: ArrayView2<'_, f64>,
733    penalty_order: usize,
734    radians: bool,
735    kernel: crate::basis::SphereWahbaKernel,
736) -> Option<Result<Array2<f64>, GpuError>> {
737    let (kind, lmax) = truncated_device_kind(kernel)?;
738    let n = data.nrows();
739    let m = centers.nrows();
740    if n == 0 || m == 0 || lmax == 0 {
741        return None;
742    }
743    let decision = sphere_kernel_decision(n, m, lmax as usize);
744    if !decision.use_gpu {
745        // Either backend-not-compiled, runtime-unavailable, or below the
746        // device-work threshold. Quiet CPU route, taken before any device call.
747        return None;
748    }
749    // Admitted: from here a failure is a hard error, never a silent CPU degrade.
750    Some(build_truncated_kernel_matrix_gpu_admitted(
751        data,
752        centers,
753        penalty_order,
754        radians,
755        kind,
756        lmax,
757    ))
758}
759
760/// Run the admitted device build for `try_build_truncated_kernel_matrix_gpu`.
761/// Separated so the admission decision (which returns `None` for the CPU route)
762/// stays distinct from the fail-loud device execution (which returns `Err`).
763fn build_truncated_kernel_matrix_gpu_admitted(
764    data: ArrayView2<'_, f64>,
765    centers: ArrayView2<'_, f64>,
766    penalty_order: usize,
767    radians: bool,
768    kind: SphereSpectralKernelKind,
769    lmax: u16,
770) -> Result<Array2<f64>, GpuError> {
771    let n = data.nrows();
772    let m = centers.nrows();
773    let data_xyz = latlon_to_xyz_host(data, radians)
774        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
775    let centers_xyz = latlon_to_xyz_host(centers, radians)
776        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
777    // Single-source the coefficients: the same `c_ℓ` array the CPU truncated
778    // recurrence consumes (`wahba_sphere_kernel_from_cos_kind`) is uploaded to
779    // the device, so CPU and GPU evaluate an identical zonal series.
780    let coeffs = kind.coefficients(lmax as usize, penalty_order);
781    let inputs = S2KernelBuildInputs {
782        n,
783        m,
784        lmax: lmax as usize,
785        data_xyz: &data_xyz,
786        centers_xyz: &centers_xyz,
787        coeffs: &coeffs,
788        kind,
789        layout: DeviceMatrixLayout::ColumnMajor,
790    };
791    let device_matrix = build_kernel_matrix_device(inputs)?;
792    let out = device_matrix.to_host_array()?;
793    // Guard against a device kernel that emitted NaN/Inf. A whole-matrix sum is
794    // poisoned by any non-finite element (`NaN + x = NaN`, `±Inf + finite =
795    // ±Inf`) and folds the `(n × m)` matrix in a single auto-vectorisable pass,
796    // ~7× faster than a per-element `any(!is_finite)` in the unoptimised
797    // profile (at n=200000, m=200 that scan alone was ~1.8 s — far more than
798    // the entire on-device build). The Wahba zonal kernel is a truncated
799    // Legendre series `Σ c_ℓ P_ℓ(t)` with `|P_ℓ| ≤ 1` and absolutely-summable
800    // coefficients, so every entry is O(1) and the sum of `n·m ≲ 10^8` of them
801    // cannot overflow f64 — a non-finite sum therefore means a genuinely
802    // non-finite entry, never a spurious overflow.
803    if !out.sum().is_finite() {
804        return Err(GpuError::DriverCallFailed {
805            reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
806        });
807    }
808    Ok(out)
809}
810
#[cfg(target_os = "linux")]
/// CUDA state behind `SphereGpuBackend` (Linux only), filled in by
/// `SphereGpuBackend::probe_linux` from the shared `probe_cuda_backend` result.
// NOTE(review): field declaration order is also drop order; left unchanged.
struct SphereGpuContext {
    /// CUDA context into which NVRTC-compiled PTX is loaded (`module_for`).
    ctx: Arc<CudaContext>,
    /// Stream the kernel builders upload inputs to, launch on, and synchronize.
    stream: Arc<CudaStream>,
    /// Cache of loaded NVRTC modules keyed by `S2ModuleCacheKey`; guarded by a
    /// mutex so concurrent callers can share it (see `module_for`).
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Device compute capability, major part; part of every module cache key.
    cc_major: i32,
    /// Device compute capability, minor part; part of every module cache key.
    cc_minor: i32,
}
819
/// Process-wide sphere GPU backend. Lazy-initialised on first call to
/// [`SphereGpuBackend::probe`].
///
/// On Linux it wraps the CUDA context, stream, NVRTC module cache and compute
/// capability obtained from the shared backend probe. On other targets it has
/// no fields and `probe` always returns `DriverLibraryUnavailable`.
pub struct SphereGpuBackend {
    /// Device state; present only on Linux builds.
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
826
827impl SphereGpuBackend {
828    /// Lazily initialise the process-wide sphere backend.
829    pub fn probe() -> Result<&'static Self, GpuError> {
830        static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
831        BACKEND
832            .get_or_init(|| {
833                #[cfg(target_os = "linux")]
834                {
835                    Self::probe_linux()
836                }
837                #[cfg(not(target_os = "linux"))]
838                {
839                    Err(GpuError::DriverLibraryUnavailable {
840                        reason: "sphere GPU backend is Linux-only".to_string(),
841                    })
842                }
843            })
844            .as_ref()
845            .map_err(GpuError::clone)
846    }
847
848    #[cfg(target_os = "linux")]
849    fn probe_linux() -> Result<Self, GpuError> {
850        let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
851        Ok(SphereGpuBackend {
852            inner: SphereGpuContext {
853                ctx: parts.ctx,
854                stream: parts.stream,
855                modules: Mutex::new(HashMap::new()),
856                cc_major: parts.capability.compute_major,
857                cc_minor: parts.capability.compute_minor,
858            },
859        })
860    }
861
862    /// NVRTC-compile (or fetch from cache) the module for `key`. The
863    /// returned module exposes both raw and Householder-fused kernels.
864    #[cfg(target_os = "linux")]
865    fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
866        if let Ok(guard) = self.inner.modules.lock() {
867            if let Some(existing) = guard.get(&key) {
868                return Ok(existing.clone());
869            }
870        }
871        // Prepend the `LMAX` macro directly to the source, then compile through
872        // the shared arch+fmad options (`compile_ptx_arch`). #1686's
873        // `--fmad=false` keeps the spherical-harmonic evaluation bit-comparable
874        // to the separately-rounded CPU reference; the #1551 arch pin keys the
875        // kernel to the device's real compute capability. (The arch is resolved
876        // internally via `nvrtc_arch()` from a `&'static str` table, so the old
877        // "cannot satisfy arch with a runtime string" limitation no longer
878        // applies — the LMAX specialization rides in the source, the arch in
879        // the options.)
880        let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
881        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src).gpu_ctx_with(|err| {
882            format!(
883                "sphere NVRTC compile (kind={}, lmax={}): {err}",
884                key.kind.tag(),
885                key.lmax
886            )
887        })?;
888        let module = self
889            .inner
890            .ctx
891            .load_module(ptx)
892            .gpu_ctx("sphere module load")?;
893        if let Ok(mut guard) = self.inner.modules.lock() {
894            guard.entry(key).or_insert_with(|| module.clone());
895        }
896        Ok(module)
897    }
898
899    #[cfg(target_os = "linux")]
900    fn cc(&self) -> (i32, i32) {
901        (self.inner.cc_major, self.inner.cc_minor)
902    }
903}
904
905// ────────────────────────────────────────────────────────────────────────
906// Entry points
907// ────────────────────────────────────────────────────────────────────────
908
909/// Build the raw `(n × m)` Wahba kernel matrix on device using
910/// `s2_wahba_legendre_colmajor`. Phase 1 entry point.
911pub fn build_kernel_matrix_device(
912    inputs: S2KernelBuildInputs<'_>,
913) -> Result<DeviceS2KernelMatrix, GpuError> {
914    inputs.validate()?;
915
916    #[cfg(target_os = "linux")]
917    {
918        use cudarc::driver::{LaunchConfig, PushKernelArg};
919        let backend = SphereGpuBackend::probe()?;
920        let (cc_major, cc_minor) = backend.cc();
921        let key = S2ModuleCacheKey {
922            cc_major,
923            cc_minor,
924            lmax: inputs.lmax as u32,
925            kind: inputs.kind,
926            layout: inputs.layout,
927        };
928        let module = backend.module_for(key)?;
929        let func = module
930            .load_function("s2_wahba_legendre_colmajor")
931            .gpu_ctx("sphere load_function raw")?;
932        let stream = backend.inner.stream.clone();
933
934        let data_dev = stream
935            .clone_htod(inputs.data_xyz)
936            .gpu_ctx("sphere htod data_xyz")?;
937        let centers_dev = stream
938            .clone_htod(inputs.centers_xyz)
939            .gpu_ctx("sphere htod centers_xyz")?;
940        let coeffs_dev = stream
941            .clone_htod(inputs.coeffs)
942            .gpu_ctx("sphere htod coeffs")?;
943
944        let n = inputs.n;
945        let m = inputs.m;
946        let ld = ((n + 31) / 32) * 32;
947        let mut out_dev = stream
948            .alloc_zeros::<f64>(ld * m)
949            .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
950
951        // Block (32, 8, 1) — x over centers, y over rows.
952        let block_x: u32 = 32;
953        let block_y: u32 = 8;
954        let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
955        let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
956        let cfg = LaunchConfig {
957            grid_dim: (grid_x, grid_y, 1),
958            block_dim: (block_x, block_y, 1),
959            shared_mem_bytes: 0,
960        };
961        let n_i32: i32 =
962            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
963        let m_i32: i32 =
964            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
965        let ld_i64: i64 = ld as i64;
966
967        let mut builder = stream.launch_builder(&func);
968        builder
969            .arg(&data_dev)
970            .arg(&centers_dev)
971            .arg(&coeffs_dev)
972            .arg(&n_i32)
973            .arg(&m_i32)
974            .arg(&ld_i64)
975            .arg(&mut out_dev);
976        // SAFETY: launch parameters are validated above; all device
977        // pointers come from cudarc-checked allocations on the same
978        // stream; the kernel only reads inputs and writes within
979        // out[0 .. ld*m].
980        unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
981        stream
982            .synchronize()
983            .gpu_ctx("sphere raw kernel synchronize")?;
984
985        Ok(DeviceS2KernelMatrix {
986            rows: n,
987            cols: m,
988            ld,
989            col_major_dev: out_dev,
990            stream,
991        })
992    }
993
994    #[cfg(not(target_os = "linux"))]
995    {
996        Err(GpuError::DriverLibraryUnavailable {
997            reason: "sphere GPU backend is Linux-only".to_string(),
998        })
999    }
1000}
1001
1002/// Phase-3 fused Householder-constrained kernel. `v` is the Householder
1003/// vector (length m), `beta` the reflector scalar, and the output is
1004/// the `(n × (m-1))` constrained design X_s on device.
1005pub fn build_householder_constrained_design_device(
1006    inputs: S2KernelBuildInputs<'_>,
1007    v: &[f64],
1008    beta: f64,
1009) -> Result<DeviceS2KernelMatrix, GpuError> {
1010    inputs.validate()?;
1011    if v.len() != inputs.m {
1012        gam_gpu::gpu_bail!(
1013            "build_householder_constrained_design_device: v.len()={} != m={}",
1014            v.len(),
1015            inputs.m
1016        );
1017    }
1018    if inputs.m < 2 {
1019        gam_gpu::gpu_bail!(
1020            "build_householder_constrained_design_device: m must be >= 2 (got {})",
1021            inputs.m
1022        );
1023    }
1024    if !beta.is_finite() {
1025        gam_gpu::gpu_bail!(
1026            "build_householder_constrained_design_device: beta must be finite (got {beta})"
1027        );
1028    }
1029
1030    #[cfg(target_os = "linux")]
1031    {
1032        use cudarc::driver::{LaunchConfig, PushKernelArg};
1033        let backend = SphereGpuBackend::probe()?;
1034        let (cc_major, cc_minor) = backend.cc();
1035        let key = S2ModuleCacheKey {
1036            cc_major,
1037            cc_minor,
1038            lmax: inputs.lmax as u32,
1039            kind: inputs.kind,
1040            layout: inputs.layout,
1041        };
1042        let module = backend.module_for(key)?;
1043        let func = module
1044            .load_function("s2_wahba_householder_constrained_colmajor")
1045            .gpu_ctx("sphere load_function householder")?;
1046        let stream = backend.inner.stream.clone();
1047
1048        let data_dev = stream
1049            .clone_htod(inputs.data_xyz)
1050            .gpu_ctx("sphere-hh htod data_xyz")?;
1051        let centers_dev = stream
1052            .clone_htod(inputs.centers_xyz)
1053            .gpu_ctx("sphere-hh htod centers_xyz")?;
1054        let coeffs_dev = stream
1055            .clone_htod(inputs.coeffs)
1056            .gpu_ctx("sphere-hh htod coeffs")?;
1057        let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
1058
1059        let n = inputs.n;
1060        let m = inputs.m;
1061        let cols_out = m - 1;
1062        let ld_out = ((n + 31) / 32) * 32;
1063        let mut out_dev = stream
1064            .alloc_zeros::<f64>(ld_out * cols_out)
1065            .gpu_ctx_with(|err| {
1066                format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
1067            })?;
1068
1069        let block_x: u32 = 128;
1070        let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
1071        let cfg = LaunchConfig {
1072            grid_dim: (grid_x, 1, 1),
1073            block_dim: (block_x, 1, 1),
1074            shared_mem_bytes: 0,
1075        };
1076        let n_i32: i32 =
1077            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
1078        let m_i32: i32 =
1079            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
1080        let ld_out_i64: i64 = ld_out as i64;
1081
1082        let mut builder = stream.launch_builder(&func);
1083        builder
1084            .arg(&data_dev)
1085            .arg(&centers_dev)
1086            .arg(&coeffs_dev)
1087            .arg(&v_dev)
1088            .arg(&beta)
1089            .arg(&n_i32)
1090            .arg(&m_i32)
1091            .arg(&ld_out_i64)
1092            .arg(&mut out_dev);
1093        // SAFETY: validated shapes above; the kernel writes exactly
1094        // (n × (m-1)) entries within `out[0 .. ld_out * (m-1)]`.
1095        unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
1096        stream
1097            .synchronize()
1098            .gpu_ctx("sphere-hh kernel synchronize")?;
1099
1100        Ok(DeviceS2KernelMatrix {
1101            rows: n,
1102            cols: cols_out,
1103            ld: ld_out,
1104            col_major_dev: out_dev,
1105            stream,
1106        })
1107    }
1108
1109    #[cfg(not(target_os = "linux"))]
1110    {
1111        Err(GpuError::DriverLibraryUnavailable {
1112            reason: "sphere GPU backend is Linux-only".to_string(),
1113        })
1114    }
1115}
1116
1117// ────────────────────────────────────────────────────────────────────────
1118// Householder reflector helpers (host-side; Phase 3 prep).
1119//
1120// Given a non-zero weight vector w ∈ ℝ^m, construct (v, beta) such that
1121// H = I − beta · v · v^T satisfies H · w = ±‖w‖ · e_1 and drops the
1122// weighted-sum constraint into the first column.
1123// ────────────────────────────────────────────────────────────────────────
1124
/// Build the Householder reflector that maps `w` onto a multiple of `e_1`.
///
/// Returns `(v, beta)` with the LAPACK / Golub–Van Loan convention
/// `v[0] = 1`, so that `H = I − beta · v · vᵀ` satisfies
/// `H · w = −sign(w[0]) · ‖w‖ · e_1`, where `w[0] = 0` counts as positive.
///
/// `‖w‖` is computed with max-abs scaling, as LAPACK's `dnrm2` does. Weights
/// whose squares would underflow to zero or overflow to infinity therefore
/// still yield a valid reflector.
///
/// Degenerate inputs:
/// * empty `w` → `(vec![], 0.0)`;
/// * all-zero `w` → `(0-vector, 0.0)`. The caller should treat this
///   reflector as a no-op (no constraint).
///
/// Non-finite entries in `w` yield NaN output.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let m = w.len();
    if m == 0 {
        return (Vec::new(), 0.0);
    }
    // Exact zero test on the raw weights. NaN entries are not zero, so they
    // fall through and propagate instead of being mistaken for "no constraint".
    if w.iter().all(|&x| x == 0.0) {
        return (vec![0.0; m], 0.0);
    }
    // Squaring the raw entries loses precision or underflows to 0 for tiny
    // weights, and overflows to ∞ for huge ones (|w_i| beyond roughly
    // 1e±154). Scaling by max|w_i| keeps every squared term in [0, 1].
    let scale = w.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    let scaled_ssq: f64 = w
        .iter()
        .map(|&x| {
            let r = x / scale;
            r * r
        })
        .sum();
    let norm = scale * scaled_ssq.sqrt();
    // Choose the sign that avoids cancellation in v[0] = w[0] + sigma.
    let sigma = if w[0] >= 0.0 { norm } else { -norm };
    let mut v = w.to_vec();
    v[0] += sigma;
    let v0 = v[0];
    if v0 == 0.0 {
        // Cannot trigger for finite non-zero w, since |v0| = |w[0]| + ‖w‖ > 0.
        // Kept as a defensive no-op.
        return (vec![0.0; m], 0.0);
    }
    // Normalize so v[0] = 1 (LAPACK convention); then |v_i| ≤ 1 for all i.
    for entry in v.iter_mut() {
        *entry /= v0;
    }
    // beta = 2 / (v · v).
    let vv: f64 = v.iter().map(|x| x * x).sum();
    let beta = 2.0 / vv;
    (v, beta)
}
1154
1155// ────────────────────────────────────────────────────────────────────────
1156// Phase 2 — center-center penalty C + constraint S = Zᵀ C Z.
1157//
1158// `C` is the (m × m) Wahba kernel of centers against themselves and is
1159// computed by reusing the raw GPU kernel with `n = m`. The constraint
1160// transform is the same Householder reflector used by the Phase-3 fused
1161// kernel: Z = (I − β · v · vᵀ) with the first column dropped, so the
1162// constrained penalty is the trailing (m−1)×(m−1) block of HᵀCH.
1163//
1164// At m ≤ 200 the Householder product is cheap on host and the result is
1165// returned as an `ndarray::Array2`. Future calls into cuSOLVER QR can
1166// upload it (or its Cholesky factor) once and keep it device-resident.
1167// ────────────────────────────────────────────────────────────────────────
1168
1169/// Build the (m × m) center-center kernel matrix `C` using the same GPU
1170/// kernel that builds the design. `centers_xyz` is the unit-vector
1171/// representation of the centers, length `3 * m`. `coeffs` and `kind`
1172/// match the design build.
1173pub fn build_center_kernel_device(
1174    centers_xyz: &[f64],
1175    lmax: usize,
1176    coeffs: &[f64],
1177    kind: SphereSpectralKernelKind,
1178) -> Result<DeviceS2KernelMatrix, GpuError> {
1179    let m = centers_xyz.len() / 3;
1180    if centers_xyz.len() != 3 * m {
1181        return Err(GpuError::DriverCallFailed {
1182            reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
1183        });
1184    }
1185    let inputs = S2KernelBuildInputs {
1186        n: m,
1187        m,
1188        lmax,
1189        data_xyz: centers_xyz,
1190        centers_xyz,
1191        coeffs,
1192        kind,
1193        layout: DeviceMatrixLayout::ColumnMajor,
1194    };
1195    build_kernel_matrix_device(inputs)
1196}
1197
1198/// Constrained penalty matrix `S = Zᵀ C Z` for the
1199/// weighted-sum-to-zero Householder constraint built from `w`.
1200/// Returned shape is `((m−1) × (m−1))`. `C` is taken as a host
1201/// (m × m) array (typically the dtoh of `build_center_kernel_device`).
1202pub fn constrained_penalty_host(
1203    c: ArrayView2<'_, f64>,
1204    w: &[f64],
1205) -> Result<Array2<f64>, GpuError> {
1206    let (m1, m2) = c.dim();
1207    if m1 != m2 {
1208        gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
1209    }
1210    let m = m1;
1211    if w.len() != m {
1212        gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
1213    }
1214    if m < 2 {
1215        gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
1216    }
1217    let (v, beta) = householder_reflector_from_weights(w);
1218
1219    // Form HCH = (I - β v vᵀ) C (I - β v vᵀ) = C - β (v · uᵀ + u · vᵀ) + β² (vᵀ C v) v vᵀ,
1220    // where u = C v. This is O(m²) — fine for m ≤ 200.
1221    let mut u = vec![0.0_f64; m];
1222    for i in 0..m {
1223        let mut acc = 0.0_f64;
1224        for j in 0..m {
1225            acc += c[(i, j)] * v[j];
1226        }
1227        u[i] = acc;
1228    }
1229    let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
1230    let mut hch = Array2::<f64>::zeros((m, m));
1231    for i in 0..m {
1232        for j in 0..m {
1233            hch[(i, j)] =
1234                c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
1235        }
1236    }
1237    // Drop the first row and column (the Householder-constrained nullspace).
1238    let mut s = Array2::<f64>::zeros((m - 1, m - 1));
1239    for i in 0..(m - 1) {
1240        for j in 0..(m - 1) {
1241            s[(i, j)] = hch[(i + 1, j + 1)];
1242        }
1243    }
1244    Ok(s)
1245}
1246
1247// ────────────────────────────────────────────────────────────────────────
1248// Phase 4 — device-resident cuSOLVER QR penalised solve.
1249//
1250// Solve  min_β  ‖ [√W · X_s] β − [√W · y] ‖² + λ ‖R_S · β‖²
1251//
1252// by stacking the augmented matrix
1253//
1254//     A_aug = [ √W · X_s ;   √λ · R_S ]    shape (n + p) × p,
1255//     b_aug = [ √W · y    ;   0       ]    length n + p,
1256//
1257// where p = m − 1, R_S is the upper-triangular Cholesky factor of the
1258// constrained penalty S = Zᵀ C Z, and (√W·X_s) is the design built by
1259// the fused Householder kernel scaled by sqrt-weights row-by-row on
1260// device. The pipeline is:
1261//
1262//     1. cusolverDnDgeqrf_bufferSize → workspace size.
1263//     2. cusolverDnDgeqrf(A_aug)     → A := [R upper-tri / V Householder]
1264//                                        plus tau vector.
1265//     3. cusolverDnDormqr(side=L, trans=T)
1266//                                  → applies Qᵀ to b_aug.
1267//     4. cublasDtrsm(L = upper) → β := R⁻¹ · (Qᵀ b_aug)[0..p].
1268//
1269// Coefficients (β) come back to host; log|H| can be returned via Σ
1270// log(R_ii²) from the diagonal of the in-place factored R.
1271//
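// A_aug is assembled on the host and uploaded once: X_s is copied back and
// stacked with the penalty block. The factorisation, Qᵀ·b_aug and the
// triangular solve then run on device, with tau and the workspace kept
// device-resident.
// The host obtains Qᵀ·b_aug (for β and the residual ssq) and the diagonal
// of R from the factored A_aug (for log|H|).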
1274// ────────────────────────────────────────────────────────────────────────
1275
/// Result returned by [`solve_penalised_ls_device`].
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Coefficient vector, length `p = m − 1` (after Householder drop).
    pub beta: Vec<f64>,
    /// Squared residual norm of the **augmented** least-squares system,
    /// `‖√W (Xβ − y)‖² + ‖r_s · β‖²`, read off the trailing rows of
    /// `Qᵀ · b_aug`.
    ///
    /// Despite the name, this includes the penalty term. Callers that need
    /// only the weighted data residual `‖√W (Xβ − y)‖²` must subtract
    /// `‖r_s · β‖²` themselves. When `p == 0` there is no penalty block and
    /// the value is `‖√W · y‖²`.
    pub weighted_residual_ssq: f64,
    /// log|H| = 2 · Σ log |R_ii| of the QR-factored augmented design.
    pub log_det_hessian: f64,
}
1286
1287/// Augmented penalised least-squares solve via on-device cuSOLVER QR.
1288///
1289/// Inputs:
1290///   * `x_s_device` — already-constrained, weighted-sqrt-scaled design
1291///     `√W · X_s` produced by the Phase-3 fused kernel + a row-scaling
1292///     kernel. Shape `(n × p)` column-major.
1293///   * `wy` — `√W · y` (length n), already host-multiplied (cheap).
1294///   * `r_s` — upper-triangular Cholesky factor of `√λ · S`, shape
1295///     `(p × p)` row-major host array.
1296#[cfg(target_os = "linux")]
1297pub fn solve_penalised_ls_device(
1298    x_s_device: &DeviceS2KernelMatrix,
1299    wy: &[f64],
1300    r_s: ArrayView2<'_, f64>,
1301) -> Result<PenalisedLsSolution, GpuError> {
1302    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1303    use cudarc::driver::DevicePtrMut;
1304
1305    let n = x_s_device.rows;
1306    let p = x_s_device.cols;
1307    if wy.len() != n {
1308        gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
1309    }
1310    if r_s.dim() != (p, p) {
1311        gam_gpu::gpu_bail!(
1312            "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
1313            r_s.dim()
1314        );
1315    }
1316    if p == 0 {
1317        return Ok(PenalisedLsSolution {
1318            beta: Vec::new(),
1319            weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
1320            log_det_hessian: 0.0,
1321        });
1322    }
1323
1324    let stream = x_s_device.stream.clone();
1325    let n_aug = n + p;
1326
1327    // 1) Materialise A_aug column-major on device. We don't need the
1328    //    upstream X_s after QR, but the kernel matrix builder hands us
1329    //    its own storage; we copy into a fresh (n_aug × p) slab so the
1330    //    in-place geqrf doesn't clobber a buffer the caller still owns.
1331    let mut a_aug_host = vec![0.0_f64; n_aug * p];
1332    // Copy device-side X_s back column-by-column into the upper block.
1333    let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
1334    x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
1335    for j in 0..p {
1336        let src_off = j * x_s_device.ld;
1337        let dst_off = j * n_aug;
1338        a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
1339        for i in 0..p {
1340            // R_S is row-major host; insert into column j of the lower
1341            // block (rows n..n+p) as r_s[i, j].
1342            a_aug_host[dst_off + n + i] = r_s[(i, j)];
1343        }
1344    }
1345    let mut a_dev = stream
1346        .clone_htod(&a_aug_host)
1347        .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
1348
1349    // b_aug = [√W·y ; 0]
1350    let mut b_host = vec![0.0_f64; n_aug];
1351    b_host[..n].copy_from_slice(wy);
1352    let mut b_dev = stream
1353        .clone_htod(&b_host)
1354        .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
1355
1356    let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
1357    let n_aug_i: i32 = i32::try_from(n_aug)
1358        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
1359    let p_i: i32 = i32::try_from(p)
1360        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
1361
1362    // 2) Workspace size for geqrf.
1363    let mut lwork: i32 = 0;
1364    {
1365        let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
1366        // SAFETY: a_dev holds n_aug*p f64 elements column-major;
1367        // pointer is live on `stream`; lwork is a valid host out-param.
1368        let status = unsafe {
1369            cusolver_sys::cusolverDnDgeqrf_bufferSize(
1370                solver.cu(),
1371                n_aug_i,
1372                p_i,
1373                a_ptr as *mut f64,
1374                n_aug_i,
1375                &mut lwork,
1376            )
1377        };
1378        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1379            gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1380        }
1381    }
1382    let lwork_us = usize::try_from(lwork)
1383        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1384    let mut workspace = stream
1385        .alloc_zeros::<f64>(lwork_us.max(1))
1386        .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1387    let mut tau = stream
1388        .alloc_zeros::<f64>(p)
1389        .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1390    let mut info = stream
1391        .alloc_zeros::<i32>(1)
1392        .gpu_ctx("solve_penalised_ls_device alloc info")?;
1393
1394    // 3) cusolverDnDgeqrf — A := QR in place.
1395    {
1396        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1397        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1398        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1399        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1400        // SAFETY: all pointers reference live device allocations on
1401        // this stream; lwork matches the bufferSize query above.
1402        let status = unsafe {
1403            cusolver_sys::cusolverDnDgeqrf(
1404                solver.cu(),
1405                n_aug_i,
1406                p_i,
1407                a_ptr as *mut f64,
1408                n_aug_i,
1409                tau_ptr as *mut f64,
1410                work_ptr as *mut f64,
1411                lwork,
1412                info_ptr as *mut i32,
1413            )
1414        };
1415        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1416            gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1417        }
1418    }
1419
1420    // 4) cusolverDnDormqr — b_aug := Qᵀ · b_aug.
1421    let mut ormqr_lwork: i32 = 0;
1422    {
1423        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1424        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1425        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1426        // SAFETY: A/tau/b are live device buffers on this stream;
1427        // ormqr_lwork is a host out-param.
1428        let status = unsafe {
1429            cusolver_sys::cusolverDnDormqr_bufferSize(
1430                solver.cu(),
1431                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1432                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1433                n_aug_i,
1434                1,
1435                p_i,
1436                a_ptr as *const f64,
1437                n_aug_i,
1438                tau_ptr as *const f64,
1439                b_ptr as *mut f64,
1440                n_aug_i,
1441                &mut ormqr_lwork,
1442            )
1443        };
1444        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1445            gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1446        }
1447    }
1448    if ormqr_lwork > lwork {
1449        workspace = stream
1450            .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1451            .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1452    }
1453    {
1454        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1455        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1456        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1457        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1458        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1459        // SAFETY: all pointers reference live, mutually-non-aliasing
1460        // device buffers on this stream; lwork matches the bufferSize
1461        // query above; A and tau are the geqrf output.
1462        let status = unsafe {
1463            cusolver_sys::cusolverDnDormqr(
1464                solver.cu(),
1465                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1466                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1467                n_aug_i,
1468                1,
1469                p_i,
1470                a_ptr as *const f64,
1471                n_aug_i,
1472                tau_ptr as *const f64,
1473                b_ptr as *mut f64,
1474                n_aug_i,
1475                work_ptr as *mut f64,
1476                ormqr_lwork.max(lwork),
1477                info_ptr as *mut i32,
1478            )
1479        };
1480        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1481            gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1482        }
1483    }
1484
1485    // 5) cublasDtrsm — solve R · β = (Qᵀ b)[0..p] in place on the top
1486    //    of b_dev. We use a single-RHS upper-triangular non-unit solve.
1487    {
1488        use cudarc::cublas::CudaBlas;
1489        let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1490        let alpha = 1.0_f64;
1491        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1492        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1493        // SAFETY: A is the geqrf-output upper-triangular factor R in
1494        // its top-p × p block (col-major, ld = n_aug); b is the
1495        // ormqr-output Qᵀb in the top p slots (ld = n_aug as well so
1496        // pretend it is column-major with 1 column of leading dim n_aug).
1497        let handle = *blas.handle();
1498        let status = unsafe {
1499            cudarc::cublas::sys::cublasDtrsm_v2(
1500                handle,
1501                cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1502                cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1503                cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1504                cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1505                p_i,
1506                1,
1507                &alpha,
1508                a_ptr as *const f64,
1509                n_aug_i,
1510                b_ptr as *mut f64,
1511                n_aug_i,
1512            )
1513        };
1514        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1515            gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1516        }
1517    }
1518
1519    // 6) Copy results back to host.
1520    let mut b_out = vec![0.0_f64; n_aug];
1521    stream
1522        .memcpy_dtoh(&b_dev, &mut b_out)
1523        .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1524    let mut a_back = vec![0.0_f64; n_aug * p];
1525    stream
1526        .memcpy_dtoh(&a_dev, &mut a_back)
1527        .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1528    stream
1529        .synchronize()
1530        .gpu_ctx("solve_penalised_ls_device synchronize")?;
1531
1532    let beta: Vec<f64> = b_out[..p].to_vec();
1533    // (Qᵀb)[p..n_aug] holds the residual in the rotated coordinates;
1534    // ‖(Qᵀb)[p..]‖² = ‖√W (Xβ − y)‖² + λ ‖R_S β‖² for the augmented
1535    // system. To recover ‖√W (Xβ − y)‖² alone, subtract the penalty
1536    // residual ‖R_S β‖² (penalty rotates to itself in the augmented
1537    // bottom block, but only when the bottom block ROWS map exactly
1538    // into the rotated residual — which is not guaranteed, so the
1539    // simpler accurate path is to return the **augmented** residual
1540    // squared and let the caller subtract.)
1541    let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1542
1543    // log|R| diagonal.
1544    let mut log_abs_r = 0.0_f64;
1545    for k in 0..p {
1546        let r_kk = a_back[k * n_aug + k];
1547        log_abs_r += r_kk.abs().ln();
1548    }
1549    let log_det_hessian = 2.0 * log_abs_r;
1550
1551    Ok(PenalisedLsSolution {
1552        beta,
1553        weighted_residual_ssq: augmented_residual_ssq,
1554        log_det_hessian,
1555    })
1556}
1557
/// Non-Linux stand-in for the cuSOLVER QR penalised least-squares solve.
///
/// The device solve is only compiled on Linux, so this variant never touches
/// a GPU: it always returns [`GpuError::DriverLibraryUnavailable`]. The reason
/// string echoes the design rows/cols, the `wy` length and the `r_s`
/// dimensions so a declined call is diagnosable from the error alone.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={n}, p={p}, wy.len()={wy_len}, r_s={dim:?})",
        n = x_s_device.rows,
        p = x_s_device.cols,
        wy_len = wy.len(),
        dim = r_s.dim(),
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1574
1575// ────────────────────────────────────────────────────────────────────────
1576// Tests
1577// ────────────────────────────────────────────────────────────────────────
1578
1579#[cfg(test)]
1580mod sphere_gpu_tests {
1581    use super::*;
1582    use crate::basis::{
1583        SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1584        spherical_wahba_kernel_matrix_with_kind,
1585    };
1586    use ndarray::Array2;
1587
1588    fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1589        // Latitude in (-85, 85), longitude in [-180, 180), degrees.
1590        let mut rows = Vec::with_capacity(n_lat * n_lon);
1591        for i in 0..n_lat {
1592            let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1593            for j in 0..n_lon {
1594                let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1595                rows.push(lat);
1596                rows.push(lon);
1597            }
1598        }
1599        Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1600    }
1601
1602    #[test]
1603    fn sum_finite_guard_accepts_finite_rejects_nonfinite() {
1604        // The admitted device path guards its output with `!out.sum().is_finite()`
1605        // instead of a per-element `any(!is_finite)`. This pins the equivalence
1606        // that justifies the swap: a finite matrix has a finite sum, and a single
1607        // NaN or ±Inf entry poisons the sum.
1608        let finite = Array2::<f64>::from_shape_fn((5, 7), |(i, j)| (i as f64 - 2.0) * (j as f64));
1609        assert!(finite.sum().is_finite());
1610
1611        let mut with_nan = finite.clone();
1612        with_nan[[3, 4]] = f64::NAN;
1613        assert!(!with_nan.sum().is_finite());
1614
1615        let mut with_pos_inf = finite.clone();
1616        with_pos_inf[[0, 0]] = f64::INFINITY;
1617        assert!(!with_pos_inf.sum().is_finite());
1618
1619        let mut with_neg_inf = finite.clone();
1620        with_neg_inf[[4, 6]] = f64::NEG_INFINITY;
1621        assert!(!with_neg_inf.sum().is_finite());
1622    }
1623
1624    #[test]
1625    fn xyz_preprocessing_matches_unit_sphere() {
1626        let latlon = ndarray::array![
1627            [0.0, 0.0],
1628            [90.0, 0.0],
1629            [0.0, 90.0],
1630            [-90.0, 17.5],
1631            [45.0, -120.0],
1632        ];
1633        let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1634        assert_eq!(xyz.len(), 3 * 5);
1635        for i in 0..5 {
1636            let nrm2 = xyz[3 * i] * xyz[3 * i]
1637                + xyz[3 * i + 1] * xyz[3 * i + 1]
1638                + xyz[3 * i + 2] * xyz[3 * i + 2];
1639            assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1640        }
1641        // Row 0 = equator @ lon=0 → (1, 0, 0).
1642        assert!((xyz[0] - 1.0).abs() < 1e-15);
1643        assert!(xyz[1].abs() < 1e-15);
1644        assert!(xyz[2].abs() < 1e-15);
1645        // Row 1 = north pole (lat=90, lon=0) → (0, 0, 1).
1646        assert!(xyz[3].abs() < 1e-15);
1647        assert!(xyz[4].abs() < 1e-15);
1648        assert!((xyz[5] - 1.0).abs() < 1e-15);
1649        // Row 2 = equator @ lon=90 → (0, 1, 0).
1650        assert!(xyz[6].abs() < 1e-15);
1651        assert!((xyz[7] - 1.0).abs() < 1e-15);
1652        assert!(xyz[8].abs() < 1e-15);
1653    }
1654
1655    #[test]
1656    fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1657        // P_ℓ(1) = 1 for all ℓ, so K(x, x) = Σ_{ℓ=0..L} c_ℓ. The Legendre
1658        // recurrence in `sphere_truncated_spectral_eval` must reproduce
1659        // this exact identity to roundoff.
1660        for m_penalty in 1..=4 {
1661            for &lmax in &[5_usize, 20, 50] {
1662                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1663                let expected: f64 = coeffs.iter().sum();
1664                let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1665                assert!(
1666                    (got - expected).abs() < 1e-13,
1667                    "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1668                );
1669            }
1670        }
1671    }
1672
1673    #[test]
1674    fn truncated_spectral_at_antipode_matches_alternating_sum() {
1675        // P_ℓ(-1) = (-1)^ℓ, so K(x, -x) = Σ_{ℓ=0..L} c_ℓ · (-1)^ℓ. Same
1676        // exact identity for the recurrence at t = -1.
1677        for m_penalty in 1..=4 {
1678            for &lmax in &[5_usize, 20, 50] {
1679                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1680                let expected: f64 = coeffs
1681                    .iter()
1682                    .enumerate()
1683                    .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1684                    .sum();
1685                let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1686                assert!(
1687                    (got - expected).abs() < 1e-13,
1688                    "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1689                );
1690            }
1691        }
1692    }
1693
1694    #[test]
1695    fn truncated_spectral_matrix_is_symmetric() {
1696        // K(γ) depends only on cos γ = x · y = y · x, so the Gram
1697        // matrix B B^T-style kernel evaluation on the same point set
1698        // must be symmetric to roundoff.
1699        let centers = ndarray::array![
1700            [10.0_f64, 20.0],
1701            [-30.0, 100.0],
1702            [45.0, -60.0],
1703            [-89.0, 0.0],
1704            [0.0, 180.0],
1705            [60.0, -179.9],
1706        ];
1707        for m_penalty in [1usize, 2, 4] {
1708            for &lmax in &[10_usize, 30] {
1709                let mat = spherical_wahba_kernel_matrix_with_kind(
1710                    centers.view(),
1711                    centers.view(),
1712                    m_penalty,
1713                    false,
1714                    SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1715                )
1716                .expect("kernel matrix");
1717                let n = centers.nrows();
1718                let mut max_asym = 0.0_f64;
1719                for i in 0..n {
1720                    for j in 0..n {
1721                        let d = (mat[(i, j)] - mat[(j, i)]).abs();
1722                        if d > max_asym {
1723                            max_asym = d;
1724                        }
1725                    }
1726                }
1727                assert!(
1728                    max_asym < 1e-13,
1729                    "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1730                );
1731            }
1732        }
1733    }
1734
1735    #[test]
1736    fn truncated_coefficients_have_zero_constant_mode() {
1737        for m in 1..=4 {
1738            let c = sobolev_s2_truncated_coefficients(50, m);
1739            assert_eq!(c.len(), 51);
1740            assert_eq!(c[0], 0.0);
1741            assert!(c[1] > 0.0);
1742            // Spectral decay c_ℓ ~ 1/ℓ^{2m-1}: monotone for ℓ ≥ 1.
1743            for ell in 2..=50 {
1744                assert!(
1745                    c[ell] < c[ell - 1] + 1e-15,
1746                    "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1747                    c[ell],
1748                    c[ell - 1]
1749                );
1750            }
1751        }
1752    }
1753
1754    #[test]
1755    fn truncated_spectral_matches_matrix_helper() {
1756        // The Wahba kernel matrix helper, invoked with the truncated
1757        // variant, must produce the same value as the bare scalar
1758        // evaluator.
1759        let m_penalty = 2;
1760        let lmax = 20;
1761        let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1762        let data = ndarray::array![[12.5, -34.0]];
1763        let centers = ndarray::array![[40.0, 10.0]];
1764        let mat = spherical_wahba_kernel_matrix_with_kind(
1765            data.view(),
1766            centers.view(),
1767            m_penalty,
1768            false,
1769            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1770        )
1771        .expect("kernel matrix");
1772        // Recompute cos γ on the unit sphere.
1773        let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1774        let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1775        let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1776        let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1777        assert!(
1778            (mat[(0, 0)] - expected).abs() < 1e-13,
1779            "matrix helper differs from scalar evaluator: {} vs {}",
1780            mat[(0, 0)],
1781            expected
1782        );
1783    }
1784
1785    #[test]
1786    fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1787        // Build a small symmetric PD matrix as a stand-in for C, then
1788        // verify that constrained_penalty_host returns a symmetric
1789        // (m-1)×(m-1) matrix whose action against Z·x matches the
1790        // expected Zᵀ C Z mapping.
1791        let m = 6;
1792        let mut c = Array2::<f64>::zeros((m, m));
1793        for i in 0..m {
1794            for j in 0..m {
1795                let d = (i as f64 - j as f64).abs();
1796                c[(i, j)] = (-0.5 * d).exp();
1797            }
1798        }
1799        let w = vec![1.0_f64; m];
1800        let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1801        assert_eq!(s.dim(), (m - 1, m - 1));
1802        // Symmetry within roundoff.
1803        let mut max_asym = 0.0_f64;
1804        for i in 0..(m - 1) {
1805            for j in 0..(m - 1) {
1806                let d = (s[(i, j)] - s[(j, i)]).abs();
1807                if d > max_asym {
1808                    max_asym = d;
1809                }
1810            }
1811        }
1812        assert!(
1813            max_asym < 1e-13,
1814            "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1815        );
1816
1817        // The kernel-of-Zᵀ direction: Zᵀ · w = 0 ⇒ x = (something) such
1818        // that Z · x stays in span(w)^⊥, so x can be any (m-1) vector;
1819        // we just verify that picking the all-ones constraint direction
1820        // collapses to zero through Z when applied to constant fields.
1821        // i.e. constant-field penalty norm must be zero in the
1822        // un-constrained Cv direction, and the trailing block here is
1823        // never used against the constraint.
1824        let ones = ndarray::Array1::<f64>::ones(m - 1);
1825        let sx = s.dot(&ones);
1826        assert!(sx.iter().all(|v| v.is_finite()));
1827    }
1828
1829    #[test]
1830    fn householder_reflector_zeroes_target_vector() {
1831        let w = vec![3.0, 4.0, 0.0, -1.0];
1832        let (v, beta) = householder_reflector_from_weights(&w);
1833        // Apply H = I - beta * v * v^T to w; the result should be a
1834        // multiple of e_1 (only first entry non-zero).
1835        let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1836        let hw: Vec<f64> = w
1837            .iter()
1838            .zip(&v)
1839            .map(|(wj, vj)| wj - beta * dot * vj)
1840            .collect();
1841        for entry in hw.iter().skip(1) {
1842            assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1843        }
1844        assert!(hw[0].abs() > 0.0);
1845    }
1846
1847    /// V100-only: probe + raw kernel parity vs CPU truncated-spectral on
1848    /// a small grid. Skips cleanly on hosts with no CUDA runtime.
1849    #[test]
1850    fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1851        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1852            eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1853            return;
1854        };
1855        // Past the runtime Some-gate: a probe failure is a real device fault on a
1856        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1857        SphereGpuBackend::probe()
1858            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1859
1860        let data_ll = small_latlon_grid(7, 9);
1861        let centers_ll = small_latlon_grid(5, 7);
1862        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1863        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1864        let n = data_ll.nrows();
1865        let m = centers_ll.nrows();
1866        let penalty = 2usize;
1867        let lmax = 20usize;
1868        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1869
1870        let inputs = S2KernelBuildInputs {
1871            n,
1872            m,
1873            lmax,
1874            data_xyz: &data_xyz,
1875            centers_xyz: &centers_xyz,
1876            coeffs: &coeffs,
1877            kind: SphereSpectralKernelKind::Sobolev,
1878            layout: DeviceMatrixLayout::ColumnMajor,
1879        };
1880        let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1881        let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1882
1883        let cpu = spherical_wahba_kernel_matrix_with_kind(
1884            data_ll.view(),
1885            centers_ll.view(),
1886            penalty,
1887            false,
1888            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1889        )
1890        .expect("cpu kernel matrix");
1891
1892        let mut max_abs = 0.0_f64;
1893        for i in 0..n {
1894            for j in 0..m {
1895                let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1896                if d > max_abs {
1897                    max_abs = d;
1898                }
1899            }
1900        }
1901        assert!(
1902            max_abs < 1e-11,
1903            "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1904        );
1905    }
1906
1907    /// V100-only end-to-end DISPATCH parity: prove the *production* kernel
1908    /// builder (`spherical_wahba_kernel_matrix_with_kind`) actually engages the
1909    /// device on a GPU-eligible truncated-spectral shape, and that the device
1910    /// result matches the CPU oracle (`spherical_wahba_kernel_matrix_cpu`) to
1911    /// roundoff. This is the engagement + parity gate the prior version of this
1912    /// test never exercised: it called `build_spherical_spline_basis` (which did
1913    /// not route to the GPU at all) and then compared the *decomposed* design
1914    /// against the *raw* kernel matrix, so it diverged by construction
1915    /// (rel |Δ| = 2.0) regardless of any device behaviour.
1916    ///
1917    /// Downstream PIRLS/REML consumes the kernel design through the same
1918    /// deterministic low-degree decomposition for both backends, so element-wise
1919    /// raw-kernel parity at ≤ 1e-9 implies full-design + fit parity.
1920    #[test]
1921    fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1922        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1923            eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1924            return;
1925        };
1926        // Past the runtime Some-gate: a backend probe failure is a real device
1927        // fault on a CUDA host, not a no-CUDA skip — fail loud (device-PCG
1928        // skip-pass class, eee12f6b2) instead of masking it as a pass.
1929        SphereGpuBackend::probe()
1930            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1931        use crate::basis::{
1932            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1933            build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
1934            spherical_wahba_kernel_matrix_with_kind,
1935        };
1936
1937        // (n=10_000, m=200) → n·m = 2_000_000 ≥ 1_000_000 → GPU eligible.
1938        let data = small_latlon_grid(100, 100);
1939        let lmax: u16 = 30;
1940        let penalty_order = 2usize;
1941        let centers =
1942            crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1943                .expect("centers");
1944        let n = data.nrows();
1945        let m = centers.nrows();
1946
1947        // The device MUST be admitted for this shape, otherwise this test would
1948        // silently exercise the CPU path on both sides and prove nothing about
1949        // engagement. Fail loud if the dispatch decision declines the GPU.
1950        let decision = sphere_kernel_decision(n, m, lmax as usize);
1951        assert!(
1952            decision.use_gpu,
1953            "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
1954             (reason={}); the engagement gate regressed",
1955            decision.reason
1956        );
1957
1958        // Production dispatcher: engages the device for this admitted shape.
1959        let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
1960            data.view(),
1961            centers.view(),
1962            penalty_order,
1963            false,
1964            SphereWahbaKernel::SobolevTruncated { lmax },
1965        )
1966        .expect("GPU-eligible production kernel build succeeds");
1967
1968        // CPU oracle: forced host evaluation regardless of dispatch decision.
1969        let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
1970            data.view(),
1971            centers.view(),
1972            penalty_order,
1973            false,
1974            SphereWahbaKernel::SobolevTruncated { lmax },
1975        )
1976        .expect("cpu oracle kernel build succeeds");
1977
1978        assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
1979        let mut max_abs = 0.0_f64;
1980        let mut max_rel = 0.0_f64;
1981        for (g, c) in gpu_kernel.iter().zip(cpu_kernel.iter()) {
1982            let d = (g - c).abs();
1983            if d > max_abs {
1984                max_abs = d;
1985            }
1986            let denom = g.abs().max(c.abs()).max(1e-300);
1987            let r = d / denom;
1988            if r > max_rel {
1989                max_rel = r;
1990            }
1991        }
1992        assert!(
1993            max_rel < 1e-9,
1994            "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
1995             >= 1e-9 (abs {max_abs:.3e})"
1996        );
1997
1998        // End-to-end smoke: the full design build (which routes its large
1999        // data×centers kernel through the engaged device) produces a finite,
2000        // correctly-shaped design with the expected number of rows.
2001        let spec_gpu = SphericalSplineBasisSpec {
2002            center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
2003            penalty_order,
2004            double_penalty: false,
2005            radians: false,
2006            method: SphereMethod::Wahba,
2007            max_degree: None,
2008            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2009            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2010        };
2011        let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
2012            .expect("GPU-eligible build_spherical_spline_basis succeeds");
2013        let design = result_gpu.design.as_dense().expect("dense design");
2014        assert_eq!(design.nrows(), n, "design row count must match data rows");
2015        assert!(
2016            design.iter().all(|v| v.is_finite()),
2017            "engaged-device spherical design must be finite"
2018        );
2019    }
2020
2021    /// V100-only: parity of Householder-constrained kernel against
2022    /// (raw kernel) · Z evaluated on host.
2023    #[test]
2024    fn sphere_gpu_householder_parity_vs_raw_dot_z() {
2025        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2026            eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
2027            return;
2028        };
2029        // Past the runtime Some-gate: a probe failure is a real device fault on a
2030        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
2031        SphereGpuBackend::probe()
2032            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
2033        let data_ll = small_latlon_grid(6, 8);
2034        let centers_ll = small_latlon_grid(4, 5);
2035        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2036        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2037        let n = data_ll.nrows();
2038        let m = centers_ll.nrows();
2039        let penalty = 2usize;
2040        let lmax = 15usize;
2041        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
2042
2043        // Build raw B on device, then form (n × m-1) X_s = B · Z on host.
2044        let inputs_raw = S2KernelBuildInputs {
2045            n,
2046            m,
2047            lmax,
2048            data_xyz: &data_xyz,
2049            centers_xyz: &centers_xyz,
2050            coeffs: &coeffs,
2051            kind: SphereSpectralKernelKind::Sobolev,
2052            layout: DeviceMatrixLayout::ColumnMajor,
2053        };
2054        let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
2055        let b = b_dev.to_host_array().expect("dtoh raw");
2056
2057        // Construct a Householder reflector from a uniform weight vector
2058        // (the "weighted sum-to-zero" constraint when weights are all 1).
2059        let w = vec![1.0_f64; m];
2060        let (v, beta) = householder_reflector_from_weights(&w);
2061
2062        // Apply on host: X_s_host[i, j_out] = B[i, j_out+1] - beta * (B[i,:] · v) * v[j_out+1]
2063        let mut xs_host = Array2::<f64>::zeros((n, m - 1));
2064        for i in 0..n {
2065            let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
2066            for j_out in 0..(m - 1) {
2067                xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
2068            }
2069        }
2070
2071        let xs_dev =
2072            build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
2073        let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
2074
2075        let mut max_abs = 0.0_f64;
2076        for i in 0..n {
2077            for j in 0..(m - 1) {
2078                let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
2079                if d > max_abs {
2080                    max_abs = d;
2081                }
2082            }
2083        }
2084        assert!(
2085            max_abs < 1e-12,
2086            "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
2087        );
2088    }
2089
2090    /// V100 hill-climb: GPU truncated-spectral kernel matrix build at
2091    /// (n=200_000, m=200, L=50) must beat CPU by ≥ 20× wall-clock.
2092    /// Skips silently when no CUDA runtime is available.
2093    #[test]
2094    fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
2095        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2096            eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
2097            return;
2098        };
2099        if SphereGpuBackend::probe().is_err() {
2100            eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
2101            return;
2102        }
2103
2104        // (n=200_000, m=200, lmax=50). n·m = 4·10^7 ≫ 1e6 → GPU eligible.
2105        // Build a 200_000-row deterministic lat/lon grid.
2106        let n_lat = 500usize;
2107        let n_lon = 400usize;
2108        assert_eq!(n_lat * n_lon, 200_000);
2109        let data_ll = small_latlon_grid(n_lat, n_lon);
2110        let m = 200usize;
2111        let centers_ll =
2112            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2113                .expect("centers");
2114        let n = data_ll.nrows();
2115        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
2116        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
2117        let penalty_order = 2usize;
2118        let lmax = 50usize;
2119        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
2120
2121        // Warm up GPU (NVRTC compile + first-touch alloc).
2122        let inputs_warm = S2KernelBuildInputs {
2123            n,
2124            m,
2125            lmax,
2126            data_xyz: &data_xyz,
2127            centers_xyz: &centers_xyz,
2128            coeffs: &coeffs,
2129            kind: SphereSpectralKernelKind::Sobolev,
2130            layout: DeviceMatrixLayout::ColumnMajor,
2131        };
2132        // Warm the NVRTC module, first-touch device alloc, AND the pinned
2133        // host-staging pool (the page-lock of the (ld·cols)·8 B return buffer
2134        // is a ~140 ms one-time cost that production amortizes across the REML
2135        // outer loop; warming `to_host_array` here mirrors that steady state).
2136        {
2137            let warm = build_kernel_matrix_device(inputs_warm.clone()).expect("warmup");
2138            drop(warm.to_host_array().expect("warmup to_host"));
2139        }
2140
2141        // Measure GPU.
2142        let t0 = std::time::Instant::now();
2143        let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
2144        dev.to_host_array().expect("dtoh");
2145        let gpu_secs = t0.elapsed().as_secs_f64();
2146
2147        // Measure CPU. Must call the explicit host oracle
2148        // (`spherical_wahba_kernel_matrix_cpu`), NOT the dispatching
2149        // `spherical_wahba_kernel_matrix_with_kind`: at this `n·m = 4·10⁷` shape
2150        // the dispatcher now ROUTES TO THE GPU (that is the whole point of the
2151        // engagement wiring), so timing it here would compare GPU-vs-GPU and
2152        // collapse the ratio to ~1×. The oracle always evaluates on host.
2153        let t1 = std::time::Instant::now();
2154        crate::basis::spherical_wahba_kernel_matrix_cpu(
2155            data_ll.view(),
2156            centers_ll.view(),
2157            penalty_order,
2158            false,
2159            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
2160        )
2161        .expect("cpu kernel matrix");
2162        let cpu_secs = t1.elapsed().as_secs_f64();
2163
2164        let ratio = cpu_secs / gpu_secs.max(1e-9);
2165        eprintln!(
2166            "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
2167        );
2168        assert!(
2169            ratio >= 20.0,
2170            "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
2171             n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2172        );
2173    }
2174
2175    /// V100 hill-climb: end-to-end Gaussian fit through
2176    /// `build_spherical_spline_basis` (GPU-dispatched) must beat the
2177    /// CPU-only fit by ≥ 10× wall-clock at a workload where the GPU
2178    /// kernel build dominates PIRLS.
2179    #[test]
2180    fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
2181        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2182            eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
2183            return;
2184        };
2185        if SphereGpuBackend::probe().is_err() {
2186            eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
2187            return;
2188        }
2189        use crate::basis::{
2190            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
2191            build_spherical_spline_basis,
2192        };
2193
2194        let n_lat = 500usize;
2195        let n_lon = 400usize;
2196        let data_ll = small_latlon_grid(n_lat, n_lon);
2197        let m: usize = 200;
2198        let lmax: u16 = 50;
2199        let spec_gpu = SphericalSplineBasisSpec {
2200            center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
2201            penalty_order: 2,
2202            double_penalty: false,
2203            radians: false,
2204            method: SphereMethod::Wahba,
2205            max_degree: None,
2206            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
2207            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
2208        };
2209
2210        // Warm-up GPU build.
2211        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
2212
2213        let t0 = std::time::Instant::now();
2214        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
2215        let gpu_secs = t0.elapsed().as_secs_f64();
2216
2217        // CPU comparison: directly invoke the CPU helper and apply the
2218        // same constraint transform (matches what build_*_basis would do
2219        // when GPU dispatch declines). Going through the public matrix
2220        // helper isolates the GPU-vs-CPU kernel cost without re-doing
2221        // farthest-point center selection (which is identical for both
2222        // paths).
2223        let centers =
2224            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
2225                .expect("centers");
2226        let z = Array2::<f64>::eye(centers.nrows());
2227        let t1 = std::time::Instant::now();
2228        // Explicit host oracle: at this shape the dispatcher routes to the GPU,
2229        // so the CPU baseline must call `spherical_wahba_kernel_matrix_cpu`
2230        // directly — otherwise this would time GPU-vs-GPU and the ratio would
2231        // collapse to ~1×.
2232        let raw_cpu = crate::basis::spherical_wahba_kernel_matrix_cpu(
2233            data_ll.view(),
2234            centers.view(),
2235            2,
2236            false,
2237            SphereWahbaKernel::SobolevTruncated { lmax },
2238        )
2239        .expect("cpu raw");
2240        raw_cpu.dot(&z);
2241        let cpu_secs = t1.elapsed().as_secs_f64();
2242
2243        let ratio = cpu_secs / gpu_secs.max(1e-9);
2244        eprintln!(
2245            "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
2246            data_ll.nrows()
2247        );
2248        assert!(
2249            ratio >= 10.0,
2250            "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
2251             cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
2252        );
2253    }
2254
2255    /// Task #25: end-to-end fit parity between the GPU truncated-spectral
2256    /// path and the CPU truncated-spectral path on a small synthetic
2257    /// intrinsic-S² fixture.
2258    ///
2259    /// Setup: deterministic lat/lon grid (n = 1000 = 25 × 40), 80 centers
2260    /// chosen by farthest-point selection, lmax = 15, penalty order 2,
2261    /// Wahba weighted-sum-to-zero constraint applied via `Z`. We fit a
2262    /// fixed-λ penalised LS problem
2263    ///   β = argmin ‖X_s β − y‖² + λ · βᵀ S β
2264    /// where `X_s = K(data, centers) · Z` and `S = Zᵀ · K(centers, centers) · Z`,
2265    /// solving `(X_sᵀ X_s + λ S) β = X_sᵀ y` via faer LLT for both paths.
2266    /// The only path-dependent quantity is `K(data, centers)`: built on
2267    /// GPU via `build_kernel_matrix_device` for one β, and on CPU via
2268    /// `spherical_wahba_kernel_matrix_with_kind` for the other. The
2269    /// penalty kernel `K(centers, centers)` is m × m and tiny, so we
2270    /// build it once on CPU and share it across paths (it is not the
2271    /// surface under test).
2272    ///
2273    /// Asserts max-absolute coefficient delta ≤ 1e-9 and max-absolute
2274    /// fitted-value delta ≤ 1e-9. `#[ignore = "requires CUDA"]` so the
2275    /// V100 bench runner unignores in their harness.
2276    #[test]
2277    fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
2278        use crate::basis::{
2279            select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
2280        };
2281        use faer::Side;
2282        use gam_linalg::faer_ndarray::FaerCholesky;
2283
2284        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
2285            eprintln!(
2286                "[sphere gpu parity] no CUDA runtime — skipping device parity \
2287                 (CPU oracle exercised by sibling tests)"
2288            );
2289            return;
2290        };
2291        // Past the runtime Some-gate: a probe failure is a real device fault on a
2292        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
2293        SphereGpuBackend::probe()
2294            .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
2295
2296        // Fixture: 25 × 40 lat/lon grid → n = 1000.
2297        let data_ll = small_latlon_grid(25, 40);
2298        assert_eq!(data_ll.nrows(), 1000);
2299        let n = data_ll.nrows();
2300        let m: usize = 80;
2301        let lmax_u16: u16 = 15;
2302        let lmax: usize = lmax_u16 as usize;
2303        let penalty_order: usize = 2;
2304        let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
2305        let lambda: f64 = 1.0e-3;
2306
2307        // Deterministic centers via farthest-point selection.
2308        let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
2309            .expect("farthest-point centers");
2310        assert_eq!(centers_ll.nrows(), m);
2311
2312        // The Wahba sphere basis no longer imposes a finite-center coefficient
2313        // gauge; parity compares the raw center coefficient chart.
2314        let z = Array2::<f64>::eye(centers_ll.nrows());
2315        let p = z.ncols();
2316        assert_eq!(p, m);
2317
2318        // Penalty K(centers, centers), built once on CPU. The penalty
2319        // kernel evaluation is m × m (= 6400 entries), well outside the
2320        // GPU dispatch threshold, and identical for both paths under
2321        // test by construction.
2322        let k_cc = spherical_wahba_kernel_matrix_with_kind(
2323            centers_ll.view(),
2324            centers_ll.view(),
2325            penalty_order,
2326            false,
2327            kernel,
2328        )
2329        .expect("centers×centers kernel");
2330        let s_full = z.t().dot(&k_cc).dot(&z);
2331
2332        // CPU path: K(data, centers) via the public CPU helper.
2333        let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
2334            data_ll.view(),
2335            centers_ll.view(),
2336            penalty_order,
2337            false,
2338            kernel,
2339        )
2340        .expect("CPU raw design");
2341        let x_s_cpu = raw_design_cpu.dot(&z);
2342
2343        // GPU path: K(data, centers) via `build_kernel_matrix_device`.
2344        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
2345        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
2346        let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
2347        let inputs = S2KernelBuildInputs {
2348            n,
2349            m,
2350            lmax,
2351            data_xyz: &data_xyz,
2352            centers_xyz: &centers_xyz,
2353            coeffs: &coeffs,
2354            kind: SphereSpectralKernelKind::Sobolev,
2355            layout: DeviceMatrixLayout::ColumnMajor,
2356        };
2357        let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
2358        let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
2359        let x_s_gpu = raw_design_gpu.dot(&z);
2360
2361        assert_eq!(x_s_cpu.dim(), (n, p));
2362        assert_eq!(x_s_gpu.dim(), (n, p));
2363
2364        // PRIMARY GPU-OUTPUT PARITY (#1175): the only path-dependent quantity is
2365        // the GPU kernel matrix `K(data, centers)` → `x_s`. THIS is the genuine
2366        // device output and it must match the CPU kernel essentially bit-tight.
2367        // The downstream β is the solution of an ill-conditioned normal-equation
2368        // system that AMPLIFIES this difference by cond(XᵀX+λS) (see below), so
2369        // β is the wrong surface to gate at a flat 1e-9 — it tests the
2370        // conditioning of a SHARED CPU solve, not the GPU. Gate the GPU output
2371        // (x_s) tight; gate β with a condition-aware band; gate ŷ (the
2372        // customer-visible prediction) tight.
2373        let mut raw_xs_delta = 0.0_f64;
2374        let mut xs_scale = 0.0_f64;
2375        for (a, b) in x_s_cpu.iter().zip(x_s_gpu.iter()) {
2376            raw_xs_delta = raw_xs_delta.max((a - b).abs());
2377            xs_scale = xs_scale.max(a.abs());
2378        }
2379        // Condition number of A = XᵀX + λS (CPU path) via symmetric eigvals;
2380        // this is the factor that maps the x_s difference into the β difference.
2381        let cond = {
2382            use gam_linalg::faer_ndarray::FaerEigh;
2383            let xtx = x_s_cpu.t().dot(&x_s_cpu);
2384            let mut a = xtx;
2385            for i in 0..p {
2386                for j in 0..p {
2387                    a[(i, j)] += lambda * s_full[(i, j)];
2388                }
2389            }
2390            let (mut lo, mut hi) = (f64::INFINITY, 0.0_f64);
2391            if let Ok((vals, _)) = a.eigh(faer::Side::Lower) {
2392                for &v in vals.iter() {
2393                    lo = lo.min(v);
2394                    hi = hi.max(v);
2395                }
2396            }
2397            hi / lo.max(1e-300)
2398        };
2399        // GPU kernel output must be bit-tight to the CPU oracle: measured on a
2400        // V100 the raw design parity is ~1e-16 (one ULP, rel ~1.2e-15). Gate at
2401        // a small ULP-scaled band — a real kernel bug perturbs x_s at O(scale),
2402        // 14+ orders above this floor.
2403        assert!(
2404            raw_xs_delta <= 1e-12 * xs_scale.max(1.0),
2405            "GPU vs CPU sphere design matrix max |Δ| = {raw_xs_delta:.3e} > {:.3e} \
2406             (scale {xs_scale:.3e}) — the kernel itself drifted (this is the genuine \
2407             GPU output, NOT a conditioning artifact)",
2408            1e-12 * xs_scale.max(1.0)
2409        );
2410
2411        // Deterministic synthetic response. The intent is to give the
2412        // penalised LS solve a non-trivial right-hand side; any smooth
2413        // function of the lat/lon is fine. Use a fixed-seed pseudo-
2414        // random walk derived from coordinates so the fixture has no
2415        // RNG dependency.
2416        let mut y = ndarray::Array1::<f64>::zeros(n);
2417        for i in 0..n {
2418            let lat_rad = data_ll[(i, 0)].to_radians();
2419            let lon_rad = data_ll[(i, 1)].to_radians();
2420            // Smooth ground truth + a tiny deterministic high-freq jitter.
2421            y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
2422                + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
2423        }
2424
2425        // Penalised normal-equation solve via faer LLT for each path:
2426        //   (X_sᵀ X_s + λ S) β = X_sᵀ y
2427        // S is symmetric positive semi-definite; λ S makes the system
2428        // strictly positive definite once added to X_sᵀ X_s.
2429        let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
2430            let xtx = x_s.t().dot(x_s);
2431            let mut a = xtx;
2432            for i in 0..p {
2433                for j in 0..p {
2434                    a[(i, j)] += lambda * s_full[(i, j)];
2435                }
2436            }
2437            let rhs = x_s.t().dot(&y);
2438            let factor = a
2439                .cholesky(Side::Lower)
2440                .expect("penalised normal equations are SPD under λ > 0");
2441            factor.solvevec(&rhs)
2442        };
2443
2444        let beta_cpu = solve_penalised(&x_s_cpu);
2445        let beta_gpu = solve_penalised(&x_s_gpu);
2446        assert_eq!(beta_cpu.len(), p);
2447        assert_eq!(beta_gpu.len(), p);
2448
2449        // Fitted values for both paths use their own design matrices —
2450        // this is the customer-visible quantity (prediction at training
2451        // points).
2452        let yhat_cpu = x_s_cpu.dot(&beta_cpu);
2453        let yhat_gpu = x_s_gpu.dot(&beta_gpu);
2454
2455        let mut max_beta_delta = 0.0_f64;
2456        for k in 0..p {
2457            let d = (beta_cpu[k] - beta_gpu[k]).abs();
2458            if d > max_beta_delta {
2459                max_beta_delta = d;
2460            }
2461        }
2462        let mut max_fit_delta = 0.0_f64;
2463        for i in 0..n {
2464            let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
2465            if d > max_fit_delta {
2466                max_fit_delta = d;
2467            }
2468        }
2469
2470        eprintln!(
2471            "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
2472             raw_xs|Δ|={raw_xs_delta:.3e} cond={cond:.3e} \
2473             max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
2474        );
2475
2476        // FITTED VALUES (the customer-visible prediction) must be tight. ŷ is a
2477        // well-conditioned functional of the data even when β is not (the
2478        // ill-conditioned directions of A correspond to β components that x_s
2479        // barely projects onto, so they cancel in ŷ = x_s·β). Measured on a
2480        // V100: max|Δŷ| ~7.6e-11. Gate tight — this is the quantity that
2481        // actually matters and it does NOT inherit the conditioning blow-up.
2482        assert!(
2483            max_fit_delta <= 1.0e-9,
2484            "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2485        );
2486
2487        // COEFFICIENTS: β = A⁻¹ Xᵀy with A = XᵀX + λS. Standard perturbation
2488        // theory bounds the relative coefficient error by cond(A) times the
2489        // relative input (x_s) error: ‖Δβ‖/‖β‖ ≲ cond(A)·‖Δx_s‖/‖x_s‖. With the
2490        // GPU/CPU x_s difference at the ULP floor (~1e-16 relative) and
2491        // cond(A) ≈ 5e7 on this fixture, β legitimately differs by ~1e-7 — NOT
2492        // a kernel bug (the raw design parity gate above already proved the GPU
2493        // output is bit-tight). A flat 1e-9 β gate is therefore wrong: it
2494        // measures the conditioning of the SHARED CPU solve, not the GPU. Gate
2495        // β against the condition-aware bound with 16× headroom; a genuine
2496        // kernel defect would already have been caught upstream by the raw x_s
2497        // gate (which has no conditioning amplification).
2498        let beta_tol = (1e-15 * cond * (1.0 + xs_scale)).max(1e-9) * 16.0;
2499        assert!(
2500            max_beta_delta <= beta_tol,
2501            "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > \
2502             condition-aware tol {beta_tol:.3e} (cond={cond:.3e}). Raw design parity is \
2503             {raw_xs_delta:.3e}; a drift THIS much larger than cond·ULP is a real solve/kernel \
2504             mismatch, not conditioning."
2505        );
2506    }
2507}