Skip to main content

gam_terms/basis/
sphere_gpu.rs

1//! GPU NVRTC Wahba intrinsic-S2 kernel matrix construction.
2//!
3//! This module owns the device-side construction of the Wahba reproducing
4//! kernel basis matrix on the 2-sphere using the **finite truncated
5//! spectral Legendre series**
6//!
7//! `K_L(γ) = Σ_{ℓ=1..L} c_ℓ · P_ℓ(cos γ)`,
8//!
9//! evaluated entry-by-entry against the 3-term Legendre recurrence kept
10//! in registers. The host CPU parity target is the matching
11//! `SphereWahbaKernel::SobolevTruncated { lmax }` /
12//! `SphereWahbaKernel::PseudoTruncated { lmax }` variant added to
13//! `src/terms/basis.rs` (single source: same recurrence, same c_ℓ).
14//!
15//! The device path evaluates the raw column-major kernel matrix with `f64`
16//! Legendre recurrence math. Host code owns centering, constraints, and solver
17//! assembly in `basis.rs`.
18
19use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Which truncated-spectral Wahba kernel to evaluate on device. Matches
/// the CPU `SphereWahbaKernel::{SobolevTruncated, PseudoTruncated}` so
/// parity tests are well-defined.
///
/// The variant selects which coefficient generator
/// [`SphereSpectralKernelKind::coefficients`] delegates to and which tag
/// [`SphereSpectralKernelKind::tag`] reports. In the formulas below `m` is
/// the order argument passed to `coefficients`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
    /// `c_ℓ = (2ℓ+1) / (4π · [ℓ(ℓ+1)]^m)` — true `H^m(S²)` Sobolev RKHS.
    Sobolev,
    /// `c_ℓ = 2 / (4π · Π_{k=1..m+1}(ℓ + k))` — Wahba 1981 pseudo-spline.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48    /// `c_0 = 0`, `c_ℓ = c_ℓ(m)` for `ℓ = 1..=lmax`. Returned vector has
49    /// length `lmax + 1` and is uploaded verbatim to constant/global
50    /// memory before kernel launch.
51    pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52        match self {
53            SphereSpectralKernelKind::Sobolev => {
54                crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55            }
56            SphereSpectralKernelKind::Pseudo => {
57                crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58            }
59        }
60    }
61
62    /// Stable string tag used in the NVRTC module cache key + logs.
63    pub const fn tag(self) -> &'static str {
64        match self {
65            SphereSpectralKernelKind::Sobolev => "sobolev",
66            SphereSpectralKernelKind::Pseudo => "pseudo",
67        }
68    }
69}
70
/// Layout of the (n,m) kernel design matrix on device. The Wahba
/// pipeline downstream of this kernel (cuBLAS GEMM, cuSOLVER GEQRF)
/// requires column-major.
///
/// Only one layout exists today; the enum is carried in
/// `S2KernelBuildInputs` and `S2ModuleCacheKey` so the choice stays
/// explicit at every call site.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Entry `(i, j)` is stored at offset `j * ld + i`.
    ColumnMajor,
}
78
79/// Lat/lon (degrees or radians) → unit vector `(x, y, z)` on S² ⊂ ℝ³.
80/// Returns a flat `Vec<f64>` of length `3 * n` in the row-major layout
81/// `[x_0, y_0, z_0, x_1, y_1, z_1, …]`, ready for one `htod` upload.
82///
83/// `radians = false` interprets inputs as degrees (the codebase default
84/// for `SphericalSplineBasisSpec`).
85pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86    if latlon.ncols() != 2 {
87        return Err(format!(
88            "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89            latlon.shape()
90        ));
91    }
92    let deg = if radians {
93        1.0
94    } else {
95        std::f64::consts::PI / 180.0
96    };
97    let n = latlon.nrows();
98    let mut out = Vec::with_capacity(3 * n);
99    for row in latlon.outer_iter() {
100        let lat = row[0] * deg;
101        let lon = row[1] * deg;
102        let (s_lat, c_lat) = lat.sin_cos();
103        let (s_lon, c_lon) = lon.sin_cos();
104        // Standard geographic→cartesian: pole on +z.
105        out.push(c_lat * c_lon);
106        out.push(c_lat * s_lon);
107        out.push(s_lat);
108    }
109    Ok(out)
110}
111
/// Device-resident `(rows × cols)` matrix in column-major layout with
/// leading dimension `ld ≥ rows`. The slice holds `ld * cols` `f64`
/// elements; entry `(i, j)` lives at `col_major_dev[j * ld + i]`.
///
/// On non-Linux builds the type is intentionally a host shadow so the
/// surrounding orchestration compiles without cudarc.
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
    /// Logical row count (the builders set this to the number of data points `n`).
    pub rows: usize,
    /// Logical column count.
    pub cols: usize,
    /// Leading dimension: element stride between consecutive columns.
    pub ld: usize,
    /// Column-major payload in device memory.
    pub col_major_dev: CudaSlice<f64>,
    /// Stream used for the device-to-host copy in `copy_to_host_col_major`.
    pub stream: Arc<CudaStream>,
}

/// Non-Linux stand-in for the device matrix: same shape fields, but the
/// payload is an ordinary host vector and there is no stream.
#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
    /// Logical row count.
    pub rows: usize,
    /// Logical column count.
    pub cols: usize,
    /// Leading dimension: element stride between consecutive columns.
    pub ld: usize,
    /// Host shadow for CPU-only builds.
    pub col_major_dev: Vec<f64>,
}
135
136impl DeviceS2KernelMatrix {
137    /// Copy the device matrix back to the host as a regular ndarray
138    /// `(rows × cols)` row-major view. Convenience for tests + parity
139    /// comparisons; production paths should keep the matrix resident.
140    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
141        let mut col_major = vec![0.0_f64; self.ld * self.cols];
142        self.copy_to_host_col_major(&mut col_major)?;
143        let mut out = Array2::<f64>::zeros((self.rows, self.cols));
144        for j in 0..self.cols {
145            for i in 0..self.rows {
146                out[(i, j)] = col_major[j * self.ld + i];
147            }
148        }
149        Ok(out)
150    }
151
152    /// Copy the underlying `(ld × cols)` column-major payload to a
153    /// caller-provided buffer. Used by `to_host_array` and by the
154    /// device-resident cuSOLVER consumer when it needs to extract the
155    /// coefficient vector.
156    #[cfg(target_os = "linux")]
157    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
158        let needed = self.ld * self.cols;
159        if dst.len() != needed {
160            gam_gpu::gpu_bail!(
161                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
162                dst.len(),
163                needed
164            );
165        }
166        self.stream
167            .memcpy_dtoh(&self.col_major_dev, dst)
168            .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
169        self.stream
170            .synchronize()
171            .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
172        Ok(())
173    }
174
175    #[cfg(not(target_os = "linux"))]
176    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
177        let needed = self.ld * self.cols;
178        if dst.len() != needed {
179            gam_gpu::gpu_bail!(
180                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
181                dst.len(),
182                needed
183            );
184        }
185        dst.copy_from_slice(&self.col_major_dev);
186        Ok(())
187    }
188}
189
190// ────────────────────────────────────────────────────────────────────────
191// Inputs
192// ────────────────────────────────────────────────────────────────────────
193
/// Host-side inputs needed to launch `s2_wahba_legendre_colmajor`.
///
/// `data_xyz` and `centers_xyz` are flat row-major
/// `[x_0, y_0, z_0, …]` length `3 * n` and `3 * m` respectively, pre-
/// computed via [`latlon_to_xyz_host`]. `coeffs` has length `lmax + 1`,
/// indexed as `coeffs[ℓ] = c_ℓ` with `c_0 = 0`.
///
/// These length and `c_0` requirements, plus `lmax >= 1`, are checked by
/// `validate` before any device work starts.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
    /// Number of data points (rows of the design matrix).
    pub n: usize,
    /// Number of centers (columns of the raw design matrix).
    pub m: usize,
    /// Truncation degree of the Legendre series; must be at least 1.
    pub lmax: usize,
    /// Data points as flat `xyz` triples, length `3 * n`.
    pub data_xyz: &'a [f64],
    /// Centers as flat `xyz` triples, length `3 * m`.
    pub centers_xyz: &'a [f64],
    /// Series coefficients `c_0..=c_lmax`, length `lmax + 1`, `c_0 = 0`.
    pub coeffs: &'a [f64],
    /// Kernel family; becomes part of the module cache key.
    pub kind: SphereSpectralKernelKind,
    /// Output layout; becomes part of the module cache key.
    pub layout: DeviceMatrixLayout,
}
211
212impl<'a> S2KernelBuildInputs<'a> {
213    fn validate(&self) -> Result<(), GpuError> {
214        if self.lmax == 0 {
215            return Err(GpuError::DriverCallFailed {
216                reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
217            });
218        }
219        if self.data_xyz.len() != 3 * self.n {
220            gam_gpu::gpu_bail!(
221                "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
222                self.data_xyz.len(),
223                3 * self.n
224            );
225        }
226        if self.centers_xyz.len() != 3 * self.m {
227            gam_gpu::gpu_bail!(
228                "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
229                self.centers_xyz.len(),
230                3 * self.m
231            );
232        }
233        if self.coeffs.len() != self.lmax + 1 {
234            gam_gpu::gpu_bail!(
235                "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
236                self.coeffs.len(),
237                self.lmax + 1
238            );
239        }
240        if self.coeffs[0] != 0.0 {
241            return Err(GpuError::DriverCallFailed {
242                reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
243            });
244        }
245        Ok(())
246    }
247}
248
// ────────────────────────────────────────────────────────────────────────
// NVRTC kernel source — raw and Householder-fused variants.
//
// Both kernels live in one source string that is compiled with a plain
// `cudarc::nvrtc::compile_ptx` call (no extra NVRTC options). LMAX is a
// compile-time `#define` that `SphereGpuBackend::module_for` prepends to
// the source. The raw kernel is launched with block (32, 8, 1) and the
// Householder kernel with block (128, 1, 1). Neither kernel uses shared
// memory: each thread reads its data point and center coordinates
// directly from global memory into local variables.
// ────────────────────────────────────────────────────────────────────────
257
/// CUDA C++ source of the two S² Wahba kernels, compiled at run time.
///
/// The string defines two `extern "C"` entry points:
///
/// * `s2_wahba_legendre_colmajor` — one thread per matrix entry. The
///   thread's y index selects the data row `i`, its x index selects the
///   center `j`, and the result is stored at `out[j * ld + i]`.
/// * `s2_wahba_householder_constrained_colmajor` — one thread per data
///   row. Pass 1 accumulates `d_i = Σ_j v[j] · B[i, j]`; pass 2 re-evaluates
///   the kernel for centers `1..m` and stores
///   `B[i, j] - beta · d_i · v[j]` at `out[(j - 1) * ld_out + i]`.
///
/// `LMAX` is not defined inside this string; `SphereGpuBackend::module_for`
/// prepends `#define LMAX <lmax>` before compilation. Both kernels read
/// `coeffs[1]` unconditionally, which is why the host rejects `lmax == 0`
/// and requires `coeffs.len() == lmax + 1`.
///
/// NOTE(review): offsets into `data_xyz` / `centers_xyz` are computed as
/// `3 * i` / `3 * j` in `int` — confirm the host keeps `n` and `m` small
/// enough that this cannot overflow.
/// NOTE(review): the raw kernel indexes rows through `blockIdx.y`; CUDA
/// documents a much smaller limit for `gridDim.y` than for `gridDim.x` —
/// verify the host launch stays inside it for large `n`.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,    // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
    const double* __restrict__ coeffs,      // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];
    const double cxj = centers_xyz[3 * j + 0];
    const double cyj = centers_xyz[3 * j + 1];
    const double czj = centers_xyz[3 * j + 2];

    // t = clamp(x_i · z_j, -1, +1).
    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t >  1.0) t =  1.0;
    if (t < -1.0) t = -1.0;

    // Legendre 3-term recurrence in registers.
    // P_0(t) = 1, P_1(t) = t.
    double p_prev = 1.0;
    double p_curr = t;
    double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf  = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }

    out[(long long) j * ld + (long long) i] = acc;
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,    // n × 3
    const double* __restrict__ centers_xyz, // m × 3
    const double* __restrict__ coeffs,      // length LMAX + 1
    const double* __restrict__ v,           // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const double xi = data_xyz[3 * i + 0];
    const double yi = data_xyz[3 * i + 1];
    const double zi = data_xyz[3 * i + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        d_i = fma(v[j], acc, d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double cxj = centers_xyz[3 * j + 0];
        const double cyj = centers_xyz[3 * j + 1];
        const double czj = centers_xyz[3 * j + 2];
        double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
        if (t >  1.0) t =  1.0;
        if (t < -1.0) t = -1.0;

        double p_prev = 1.0;
        double p_curr = t;
        double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;
        #pragma unroll 8
        for (int ell = 1; ell < LMAX; ++ell) {
            const double lf  = (double) ell;
            const double inv = 1.0 / (lf + 1.0);
            const double p_next =
                fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
            acc = fma(coeffs[ell + 1], p_next, acc);
            p_prev = p_curr;
            p_curr = p_next;
        }
        const double xs = acc - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
396
397// ────────────────────────────────────────────────────────────────────────
398// Module cache key + per-process backend.
399// ────────────────────────────────────────────────────────────────────────
400
/// Key of the per-process NVRTC module cache held by `SphereGpuBackend`.
///
/// A module is compiled once per distinct `(compute capability, LMAX,
/// kind, layout)` tuple. The raw and the Householder-fused kernels are
/// compiled into the same module, so there is no per-kernel component in
/// the key. `precision = f64` and the (32, 8, 1) raw-kernel block /
/// (128, 1, 1) Householder-kernel block shapes are fixed by the kernel
/// source and its launch sites, so they do not appear here either.
///
/// NOTE(review): of these fields only `lmax` is substituted into the
/// source text before compilation (via the prepended `#define LMAX`);
/// `kind` and `layout` still separate cache entries — confirm that is
/// intended.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Major compute capability reported by the backend probe.
    pub cc_major: i32,
    /// Minor compute capability reported by the backend probe.
    pub cc_minor: i32,
    /// Truncation degree baked into the module as `LMAX`.
    pub lmax: u32,
    /// Kernel family the module was requested for.
    pub kind: SphereSpectralKernelKind,
    /// Output layout the module was requested for.
    pub layout: DeviceMatrixLayout,
}
414
/// Returns `true` if this build was compiled with the Linux + cudarc GPU
/// backend that runs the S² Wahba kernels.
///
/// This is a compile-time fact (`target_os = "linux"`); it says nothing
/// about whether a usable device is present at run time.
pub const fn sphere_gpu_compiled() -> bool {
    cfg!(target_os = "linux")
}
420
421/// Decide whether the GPU sphere kernel matrix path is eligible for
422/// `(n, m, lmax)`. Heuristic per the math spec:
423///   * `n * m >= 1_000_000`
424///   * `lmax <= 200`
425///   * device memory budget admits at least one `(ld × m)` design at
426///     `ld = ((n + 31) / 32) * 32`.
427#[must_use]
428pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
429    let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
430        let ld = ((n + 31) / 32) * 32;
431        let needed_bytes = ld
432            .saturating_mul(m)
433            .saturating_mul(std::mem::size_of::<f64>());
434        let budget = runtime.memory_budget_bytes;
435        n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
436    } else {
437        false
438    };
439    decide(
440        GpuKernel::SpatialKernelOperator,
441        gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
442    )
443}
444
445/// Map a truncated `SphereWahbaKernel` variant onto the device kernel kind +
446/// truncation degree. Only the two *truncated* spectral variants have an exact
447/// device counterpart (the closed-form `Sobolev`/`Pseudo` variants use
448/// polylogarithms / deep-`L` series the device kernel does not evaluate), so
449/// `Sobolev`/`Pseudo` return `None` and stay on the CPU closed-form path.
450#[must_use]
451pub fn truncated_device_kind(
452    kernel: crate::basis::SphereWahbaKernel,
453) -> Option<(SphereSpectralKernelKind, u16)> {
454    use crate::basis::SphereWahbaKernel;
455    match kernel {
456        SphereWahbaKernel::SobolevTruncated { lmax } => {
457            Some((SphereSpectralKernelKind::Sobolev, lmax))
458        }
459        SphereWahbaKernel::PseudoTruncated { lmax } => {
460            Some((SphereSpectralKernelKind::Pseudo, lmax))
461        }
462        SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => None,
463    }
464}
465
466/// Production entry: build the raw `(n × m)` truncated-spectral Wahba kernel
467/// design matrix on the GPU when [`sphere_kernel_decision`] admits the device,
468/// returning `None` to signal the caller to use its CPU oracle.
469///
470/// Contract:
471///   * Returns `None` when the kernel is a non-truncated closed-form variant
472///     (no exact device counterpart), or when the dispatch decision keeps the
473///     work on the CPU (`!use_gpu`). The caller then runs the bit-defining CPU
474///     path. This is the **only** quiet-CPU route and it is taken *before* any
475///     device call — never as a silent fallback after a device failure.
476///   * Returns `Some(Ok(matrix))` with the device-computed host array when the
477///     device path ran and matches the CPU truncated recurrence to roundoff
478///     (proven by the parity tests). `gam_gpu::policy` keeps the same `c_ℓ`
479///     array and the same Legendre 3-term recurrence on both sides.
480///   * Returns `Some(Err(_))` when the device was *admitted* but the launch /
481///     NVRTC compile / copy-back failed — a hard error the caller must surface,
482///     NOT degrade to CPU. Fail-loud once admitted (the recurring silent-CPU
483///     fallback is the bug this path exists to kill).
484///
485/// `data` / `centers` are `(_, 2)` lat/lon matrices (degrees unless
486/// `radians`), matching `spherical_wahba_kernel_matrix_with_kind`.
487pub fn try_build_truncated_kernel_matrix_gpu(
488    data: ArrayView2<'_, f64>,
489    centers: ArrayView2<'_, f64>,
490    penalty_order: usize,
491    radians: bool,
492    kernel: crate::basis::SphereWahbaKernel,
493) -> Option<Result<Array2<f64>, GpuError>> {
494    let (kind, lmax) = truncated_device_kind(kernel)?;
495    let n = data.nrows();
496    let m = centers.nrows();
497    if n == 0 || m == 0 || lmax == 0 {
498        return None;
499    }
500    let decision = sphere_kernel_decision(n, m, lmax as usize);
501    if !decision.use_gpu {
502        // Either backend-not-compiled, runtime-unavailable, or below the
503        // device-work threshold. Quiet CPU route, taken before any device call.
504        return None;
505    }
506    // Admitted: from here a failure is a hard error, never a silent CPU degrade.
507    Some(build_truncated_kernel_matrix_gpu_admitted(
508        data,
509        centers,
510        penalty_order,
511        radians,
512        kind,
513        lmax,
514    ))
515}
516
517/// Run the admitted device build for `try_build_truncated_kernel_matrix_gpu`.
518/// Separated so the admission decision (which returns `None` for the CPU route)
519/// stays distinct from the fail-loud device execution (which returns `Err`).
520fn build_truncated_kernel_matrix_gpu_admitted(
521    data: ArrayView2<'_, f64>,
522    centers: ArrayView2<'_, f64>,
523    penalty_order: usize,
524    radians: bool,
525    kind: SphereSpectralKernelKind,
526    lmax: u16,
527) -> Result<Array2<f64>, GpuError> {
528    let n = data.nrows();
529    let m = centers.nrows();
530    let data_xyz = latlon_to_xyz_host(data, radians)
531        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
532    let centers_xyz = latlon_to_xyz_host(centers, radians)
533        .map_err(|reason| GpuError::DriverCallFailed { reason })?;
534    // Single-source the coefficients: the same `c_ℓ` array the CPU truncated
535    // recurrence consumes (`wahba_sphere_kernel_from_cos_kind`) is uploaded to
536    // the device, so CPU and GPU evaluate an identical zonal series.
537    let coeffs = kind.coefficients(lmax as usize, penalty_order);
538    let inputs = S2KernelBuildInputs {
539        n,
540        m,
541        lmax: lmax as usize,
542        data_xyz: &data_xyz,
543        centers_xyz: &centers_xyz,
544        coeffs: &coeffs,
545        kind,
546        layout: DeviceMatrixLayout::ColumnMajor,
547    };
548    let device_matrix = build_kernel_matrix_device(inputs)?;
549    let out = device_matrix.to_host_array()?;
550    if out.iter().any(|v| !v.is_finite()) {
551        return Err(GpuError::DriverCallFailed {
552            reason: "sphere GPU truncated kernel produced a non-finite value".to_string(),
553        });
554    }
555    Ok(out)
556}
557
/// CUDA state owned by the process-wide sphere backend (Linux only).
#[cfg(target_os = "linux")]
struct SphereGpuContext {
    /// Context used to load NVRTC-compiled modules.
    ctx: Arc<CudaContext>,
    /// Stream on which the builders upload inputs, launch kernels and
    /// synchronize.
    stream: Arc<CudaStream>,
    /// Compiled modules, keyed by `S2ModuleCacheKey`.
    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
    /// Major compute capability reported by the backend probe.
    cc_major: i32,
    /// Minor compute capability reported by the backend probe.
    cc_minor: i32,
}
566
/// Process-wide sphere GPU backend. Lazy-initialised on first call to
/// [`SphereGpuBackend::probe`].
///
/// On non-Linux builds the struct has no fields and `probe` always
/// reports the backend as unavailable.
pub struct SphereGpuBackend {
    /// CUDA context, stream, module cache and compute capability.
    #[cfg(target_os = "linux")]
    inner: SphereGpuContext,
}
573
574impl SphereGpuBackend {
575    /// Lazily initialise the process-wide sphere backend.
576    pub fn probe() -> Result<&'static Self, GpuError> {
577        static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
578        BACKEND
579            .get_or_init(|| {
580                #[cfg(target_os = "linux")]
581                {
582                    Self::probe_linux()
583                }
584                #[cfg(not(target_os = "linux"))]
585                {
586                    Err(GpuError::DriverLibraryUnavailable {
587                        reason: "sphere GPU backend is Linux-only".to_string(),
588                    })
589                }
590            })
591            .as_ref()
592            .map_err(GpuError::clone)
593    }
594
595    #[cfg(target_os = "linux")]
596    fn probe_linux() -> Result<Self, GpuError> {
597        let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
598        Ok(SphereGpuBackend {
599            inner: SphereGpuContext {
600                ctx: parts.ctx,
601                stream: parts.stream,
602                modules: Mutex::new(HashMap::new()),
603                cc_major: parts.capability.compute_major,
604                cc_minor: parts.capability.compute_minor,
605            },
606        })
607    }
608
609    /// NVRTC-compile (or fetch from cache) the module for `key`. The
610    /// returned module exposes both raw and Householder-fused kernels.
611    #[cfg(target_os = "linux")]
612    fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
613        if let Ok(guard) = self.inner.modules.lock() {
614            if let Some(existing) = guard.get(&key) {
615                return Ok(existing.clone());
616            }
617        }
618        // CompileOptions in cudarc 0.19 takes `arch: Option<&'static str>`
619        // which we cannot satisfy with a runtime-built string. Prepend the
620        // `LMAX` macro directly to the source so the NVRTC compile is a
621        // pure `compile_ptx`, matching the sibling kernels' invocation
622        // pattern. The kernel itself targets the device the driver
623        // reports (Volta+).
624        let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
625        let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
626            format!(
627                "sphere NVRTC compile (kind={}, lmax={}): {err}",
628                key.kind.tag(),
629                key.lmax
630            )
631        })?;
632        let module = self
633            .inner
634            .ctx
635            .load_module(ptx)
636            .gpu_ctx("sphere module load")?;
637        if let Ok(mut guard) = self.inner.modules.lock() {
638            guard.entry(key).or_insert_with(|| module.clone());
639        }
640        Ok(module)
641    }
642
643    #[cfg(target_os = "linux")]
644    fn cc(&self) -> (i32, i32) {
645        (self.inner.cc_major, self.inner.cc_minor)
646    }
647}
648
649// ────────────────────────────────────────────────────────────────────────
650// Entry points
651// ────────────────────────────────────────────────────────────────────────
652
653/// Build the raw `(n × m)` Wahba kernel matrix on device using
654/// `s2_wahba_legendre_colmajor`. Phase 1 entry point.
655pub fn build_kernel_matrix_device(
656    inputs: S2KernelBuildInputs<'_>,
657) -> Result<DeviceS2KernelMatrix, GpuError> {
658    inputs.validate()?;
659
660    #[cfg(target_os = "linux")]
661    {
662        use cudarc::driver::{LaunchConfig, PushKernelArg};
663        let backend = SphereGpuBackend::probe()?;
664        let (cc_major, cc_minor) = backend.cc();
665        let key = S2ModuleCacheKey {
666            cc_major,
667            cc_minor,
668            lmax: inputs.lmax as u32,
669            kind: inputs.kind,
670            layout: inputs.layout,
671        };
672        let module = backend.module_for(key)?;
673        let func = module
674            .load_function("s2_wahba_legendre_colmajor")
675            .gpu_ctx("sphere load_function raw")?;
676        let stream = backend.inner.stream.clone();
677
678        let data_dev = stream
679            .clone_htod(inputs.data_xyz)
680            .gpu_ctx("sphere htod data_xyz")?;
681        let centers_dev = stream
682            .clone_htod(inputs.centers_xyz)
683            .gpu_ctx("sphere htod centers_xyz")?;
684        let coeffs_dev = stream
685            .clone_htod(inputs.coeffs)
686            .gpu_ctx("sphere htod coeffs")?;
687
688        let n = inputs.n;
689        let m = inputs.m;
690        let ld = ((n + 31) / 32) * 32;
691        let mut out_dev = stream
692            .alloc_zeros::<f64>(ld * m)
693            .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
694
695        // Block (32, 8, 1) — x over centers, y over rows.
696        let block_x: u32 = 32;
697        let block_y: u32 = 8;
698        let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
699        let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
700        let cfg = LaunchConfig {
701            grid_dim: (grid_x, grid_y, 1),
702            block_dim: (block_x, block_y, 1),
703            shared_mem_bytes: 0,
704        };
705        let n_i32: i32 =
706            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
707        let m_i32: i32 =
708            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
709        let ld_i64: i64 = ld as i64;
710
711        let mut builder = stream.launch_builder(&func);
712        builder
713            .arg(&data_dev)
714            .arg(&centers_dev)
715            .arg(&coeffs_dev)
716            .arg(&n_i32)
717            .arg(&m_i32)
718            .arg(&ld_i64)
719            .arg(&mut out_dev);
720        // SAFETY: launch parameters are validated above; all device
721        // pointers come from cudarc-checked allocations on the same
722        // stream; the kernel only reads inputs and writes within
723        // out[0 .. ld*m].
724        unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
725        stream
726            .synchronize()
727            .gpu_ctx("sphere raw kernel synchronize")?;
728
729        Ok(DeviceS2KernelMatrix {
730            rows: n,
731            cols: m,
732            ld,
733            col_major_dev: out_dev,
734            stream,
735        })
736    }
737
738    #[cfg(not(target_os = "linux"))]
739    {
740        Err(GpuError::DriverLibraryUnavailable {
741            reason: "sphere GPU backend is Linux-only".to_string(),
742        })
743    }
744}
745
746/// Phase-3 fused Householder-constrained kernel. `v` is the Householder
747/// vector (length m), `beta` the reflector scalar, and the output is
748/// the `(n × (m-1))` constrained design X_s on device.
749pub fn build_householder_constrained_design_device(
750    inputs: S2KernelBuildInputs<'_>,
751    v: &[f64],
752    beta: f64,
753) -> Result<DeviceS2KernelMatrix, GpuError> {
754    inputs.validate()?;
755    if v.len() != inputs.m {
756        gam_gpu::gpu_bail!(
757            "build_householder_constrained_design_device: v.len()={} != m={}",
758            v.len(),
759            inputs.m
760        );
761    }
762    if inputs.m < 2 {
763        gam_gpu::gpu_bail!(
764            "build_householder_constrained_design_device: m must be >= 2 (got {})",
765            inputs.m
766        );
767    }
768    if !beta.is_finite() {
769        gam_gpu::gpu_bail!(
770            "build_householder_constrained_design_device: beta must be finite (got {beta})"
771        );
772    }
773
774    #[cfg(target_os = "linux")]
775    {
776        use cudarc::driver::{LaunchConfig, PushKernelArg};
777        let backend = SphereGpuBackend::probe()?;
778        let (cc_major, cc_minor) = backend.cc();
779        let key = S2ModuleCacheKey {
780            cc_major,
781            cc_minor,
782            lmax: inputs.lmax as u32,
783            kind: inputs.kind,
784            layout: inputs.layout,
785        };
786        let module = backend.module_for(key)?;
787        let func = module
788            .load_function("s2_wahba_householder_constrained_colmajor")
789            .gpu_ctx("sphere load_function householder")?;
790        let stream = backend.inner.stream.clone();
791
792        let data_dev = stream
793            .clone_htod(inputs.data_xyz)
794            .gpu_ctx("sphere-hh htod data_xyz")?;
795        let centers_dev = stream
796            .clone_htod(inputs.centers_xyz)
797            .gpu_ctx("sphere-hh htod centers_xyz")?;
798        let coeffs_dev = stream
799            .clone_htod(inputs.coeffs)
800            .gpu_ctx("sphere-hh htod coeffs")?;
801        let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
802
803        let n = inputs.n;
804        let m = inputs.m;
805        let cols_out = m - 1;
806        let ld_out = ((n + 31) / 32) * 32;
807        let mut out_dev = stream
808            .alloc_zeros::<f64>(ld_out * cols_out)
809            .gpu_ctx_with(|err| {
810                format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
811            })?;
812
813        let block_x: u32 = 128;
814        let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
815        let cfg = LaunchConfig {
816            grid_dim: (grid_x, 1, 1),
817            block_dim: (block_x, 1, 1),
818            shared_mem_bytes: 0,
819        };
820        let n_i32: i32 =
821            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
822        let m_i32: i32 =
823            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
824        let ld_out_i64: i64 = ld_out as i64;
825
826        let mut builder = stream.launch_builder(&func);
827        builder
828            .arg(&data_dev)
829            .arg(&centers_dev)
830            .arg(&coeffs_dev)
831            .arg(&v_dev)
832            .arg(&beta)
833            .arg(&n_i32)
834            .arg(&m_i32)
835            .arg(&ld_out_i64)
836            .arg(&mut out_dev);
837        // SAFETY: validated shapes above; the kernel writes exactly
838        // (n × (m-1)) entries within `out[0 .. ld_out * (m-1)]`.
839        unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
840        stream
841            .synchronize()
842            .gpu_ctx("sphere-hh kernel synchronize")?;
843
844        Ok(DeviceS2KernelMatrix {
845            rows: n,
846            cols: cols_out,
847            ld: ld_out,
848            col_major_dev: out_dev,
849            stream,
850        })
851    }
852
853    #[cfg(not(target_os = "linux"))]
854    {
855        Err(GpuError::DriverLibraryUnavailable {
856            reason: "sphere GPU backend is Linux-only".to_string(),
857        })
858    }
859}
860
861// ────────────────────────────────────────────────────────────────────────
862// Householder reflector helpers (host-side; Phase 3 prep).
863//
864// Given a non-zero weight vector w ∈ ℝ^m, construct (v, beta) such that
865// H = I − beta · v · v^T satisfies H · w = ±‖w‖ · e_1 and drops the
866// weighted-sum constraint into the first column.
867// ────────────────────────────────────────────────────────────────────────
868
/// Build the Householder reflector that zeroes `w` against `e_1`.
///
/// Returns `(v, beta)` with the LAPACK / Golub-Van Loan convention
/// `v[0] = 1`, such that `H = I − beta · v · vᵀ` maps `w` onto a multiple
/// of `e_1`. If `w` is empty or every entry is exactly zero, returns
/// `(0-vector, 0.0)` and the caller should treat the reflector as a no-op
/// (no constraint).
///
/// The reflector is invariant under positive rescaling of `w`, so the
/// weights are first divided by their largest magnitude. This keeps the
/// norm computation away from overflow and underflow: very small non-zero
/// weights no longer collapse to the no-op reflector, and very large ones
/// no longer produce `inf`/`NaN`.
///
/// Non-finite entries in `w` are not rejected; they yield non-finite output.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let m = w.len();
    if m == 0 {
        return (Vec::new(), 0.0);
    }

    // Largest magnitude in `w`. A NaN entry is made sticky so it still
    // propagates to the result instead of being silently ignored.
    let mut scale = 0.0_f64;
    for &x in w {
        let a = x.abs();
        if a > scale || a.is_nan() {
            scale = a;
        }
    }
    if scale == 0.0 {
        return (vec![0.0; m], 0.0);
    }

    // After scaling, the largest entry has magnitude 1, so 1 <= norm <= sqrt(m).
    let mut v: Vec<f64> = w.iter().map(|x| x / scale).collect();
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();

    // Add sigma with the sign of the leading entry to avoid cancellation;
    // this guarantees |v[0]| >= norm > 0, so the normalisation below never
    // divides by zero.
    let sigma = if v[0] >= 0.0 { norm } else { -norm };
    v[0] += sigma;
    let v0 = v[0];

    // Normalize so v[0] = 1 (LAPACK convention).
    for entry in v.iter_mut() {
        *entry /= v0;
    }

    // beta = 2 / (v · v).
    let vv: f64 = v.iter().map(|x| x * x).sum();
    (v, 2.0 / vv)
}
898
899// ────────────────────────────────────────────────────────────────────────
900// Phase 2 — center-center penalty C + constraint S = Zᵀ C Z.
901//
902// `C` is the (m × m) Wahba kernel of centers against themselves and is
903// computed by reusing the raw GPU kernel with `n = m`. The constraint
904// transform is the same Householder reflector used by the Phase-3 fused
905// kernel: Z = (I − β · v · vᵀ) with the first column dropped, so the
906// constrained penalty is the trailing (m−1)×(m−1) block of HᵀCH.
907//
908// At m ≤ 200 the Householder product is cheap on host and the result is
909// returned as an `ndarray::Array2`. Future calls into cuSOLVER QR can
910// upload it (or its Cholesky factor) once and keep it device-resident.
911// ────────────────────────────────────────────────────────────────────────
912
913/// Build the (m × m) center-center kernel matrix `C` using the same GPU
914/// kernel that builds the design. `centers_xyz` is the unit-vector
915/// representation of the centers, length `3 * m`. `coeffs` and `kind`
916/// match the design build.
917pub fn build_center_kernel_device(
918    centers_xyz: &[f64],
919    lmax: usize,
920    coeffs: &[f64],
921    kind: SphereSpectralKernelKind,
922) -> Result<DeviceS2KernelMatrix, GpuError> {
923    let m = centers_xyz.len() / 3;
924    if centers_xyz.len() != 3 * m {
925        return Err(GpuError::DriverCallFailed {
926            reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
927        });
928    }
929    let inputs = S2KernelBuildInputs {
930        n: m,
931        m,
932        lmax,
933        data_xyz: centers_xyz,
934        centers_xyz,
935        coeffs,
936        kind,
937        layout: DeviceMatrixLayout::ColumnMajor,
938    };
939    build_kernel_matrix_device(inputs)
940}
941
942/// Constrained penalty matrix `S = Zᵀ C Z` for the
943/// weighted-sum-to-zero Householder constraint built from `w`.
944/// Returned shape is `((m−1) × (m−1))`. `C` is taken as a host
945/// (m × m) array (typically the dtoh of `build_center_kernel_device`).
946pub fn constrained_penalty_host(
947    c: ArrayView2<'_, f64>,
948    w: &[f64],
949) -> Result<Array2<f64>, GpuError> {
950    let (m1, m2) = c.dim();
951    if m1 != m2 {
952        gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
953    }
954    let m = m1;
955    if w.len() != m {
956        gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
957    }
958    if m < 2 {
959        gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
960    }
961    let (v, beta) = householder_reflector_from_weights(w);
962
963    // Form HCH = (I - β v vᵀ) C (I - β v vᵀ) = C - β (v · uᵀ + u · vᵀ) + β² (vᵀ C v) v vᵀ,
964    // where u = C v. This is O(m²) — fine for m ≤ 200.
965    let mut u = vec![0.0_f64; m];
966    for i in 0..m {
967        let mut acc = 0.0_f64;
968        for j in 0..m {
969            acc += c[(i, j)] * v[j];
970        }
971        u[i] = acc;
972    }
973    let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
974    let mut hch = Array2::<f64>::zeros((m, m));
975    for i in 0..m {
976        for j in 0..m {
977            hch[(i, j)] =
978                c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
979        }
980    }
981    // Drop the first row and column (the Householder-constrained nullspace).
982    let mut s = Array2::<f64>::zeros((m - 1, m - 1));
983    for i in 0..(m - 1) {
984        for j in 0..(m - 1) {
985            s[(i, j)] = hch[(i + 1, j + 1)];
986        }
987    }
988    Ok(s)
989}
990
991// ────────────────────────────────────────────────────────────────────────
992// Phase 4 — device-resident cuSOLVER QR penalised solve.
993//
994// Solve  min_β  ‖ [√W · X_s] β − [√W · y] ‖² + λ ‖R_S · β‖²
995//
996// by stacking the augmented matrix
997//
998//     A_aug = [ √W · X_s ;   √λ · R_S ]    shape (n + p) × p,
999//     b_aug = [ √W · y    ;   0       ]    length n + p,
1000//
1001// where p = m − 1, R_S is the upper-triangular Cholesky factor of the
1002// constrained penalty S = Zᵀ C Z, and (√W·X_s) is the design built by
1003// the fused Householder kernel scaled by sqrt-weights row-by-row on
1004// device. The pipeline is:
1005//
1006//     1. cusolverDnDgeqrf_bufferSize → workspace size.
1007//     2. cusolverDnDgeqrf(A_aug)     → A := [R upper-tri / V Householder]
1008//                                        plus tau vector.
1009//     3. cusolverDnDormqr(side=L, trans=T)
1010//                                  → applies Qᵀ to b_aug.
//     4. cublasDtrsm(side=L, uplo=U) → β := R⁻¹ · (Qᵀ b_aug)[0..p].
1012//
1013// Coefficients (β) come back to host; log|H| can be returned via Σ
1014// log(R_ii²) from the diagonal of the in-place factored R.
1015//
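// tau and the cuSOLVER workspace never leave the device. In the current
// implementation A_aug and b_aug are assembled on the host (A_aug from a
// device-to-host copy of X_s) and the factored A is copied back to read
// diag(R); callers receive only (β, log|H|, residual ssq).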
1018// ────────────────────────────────────────────────────────────────────────
1019
/// Result returned by [`solve_penalised_ls_device`].
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Coefficient vector, length `p = m − 1` (after Householder drop).
    pub beta: Vec<f64>,
    /// Residual sum of squares.
    ///
    /// NOTE: despite the field name, the Linux implementation of
    /// [`solve_penalised_ls_device`] stores the residual of the *augmented*
    /// system here, i.e. the squared norm of `(Qᵀ b_aug)[p..]`, which
    /// includes the penalty rows. Callers that need ‖√W (Xβ − y)‖² alone
    /// must subtract the penalty contribution themselves. When `p == 0`
    /// the value is simply `Σ wy²`.
    pub weighted_residual_ssq: f64,
    /// log|H| = 2 · Σ log |R_ii| of the QR-factored augmented design.
    pub log_det_hessian: f64,
}
1030
1031/// Augmented penalised least-squares solve via on-device cuSOLVER QR.
1032///
1033/// Inputs:
1034///   * `x_s_device` — already-constrained, weighted-sqrt-scaled design
1035///     `√W · X_s` produced by the Phase-3 fused kernel + a row-scaling
1036///     kernel. Shape `(n × p)` column-major.
1037///   * `wy` — `√W · y` (length n), already host-multiplied (cheap).
1038///   * `r_s` — upper-triangular Cholesky factor of `√λ · S`, shape
1039///     `(p × p)` row-major host array.
1040#[cfg(target_os = "linux")]
1041pub fn solve_penalised_ls_device(
1042    x_s_device: &DeviceS2KernelMatrix,
1043    wy: &[f64],
1044    r_s: ArrayView2<'_, f64>,
1045) -> Result<PenalisedLsSolution, GpuError> {
1046    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
1047    use cudarc::driver::DevicePtrMut;
1048
1049    let n = x_s_device.rows;
1050    let p = x_s_device.cols;
1051    if wy.len() != n {
1052        gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
1053    }
1054    if r_s.dim() != (p, p) {
1055        gam_gpu::gpu_bail!(
1056            "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
1057            r_s.dim()
1058        );
1059    }
1060    if p == 0 {
1061        return Ok(PenalisedLsSolution {
1062            beta: Vec::new(),
1063            weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
1064            log_det_hessian: 0.0,
1065        });
1066    }
1067
1068    let stream = x_s_device.stream.clone();
1069    let n_aug = n + p;
1070
1071    // 1) Materialise A_aug column-major on device. We don't need the
1072    //    upstream X_s after QR, but the kernel matrix builder hands us
1073    //    its own storage; we copy into a fresh (n_aug × p) slab so the
1074    //    in-place geqrf doesn't clobber a buffer the caller still owns.
1075    let mut a_aug_host = vec![0.0_f64; n_aug * p];
1076    // Copy device-side X_s back column-by-column into the upper block.
1077    let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
1078    x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
1079    for j in 0..p {
1080        let src_off = j * x_s_device.ld;
1081        let dst_off = j * n_aug;
1082        a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
1083        for i in 0..p {
1084            // R_S is row-major host; insert into column j of the lower
1085            // block (rows n..n+p) as r_s[i, j].
1086            a_aug_host[dst_off + n + i] = r_s[(i, j)];
1087        }
1088    }
1089    let mut a_dev = stream
1090        .clone_htod(&a_aug_host)
1091        .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
1092
1093    // b_aug = [√W·y ; 0]
1094    let mut b_host = vec![0.0_f64; n_aug];
1095    b_host[..n].copy_from_slice(wy);
1096    let mut b_dev = stream
1097        .clone_htod(&b_host)
1098        .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
1099
1100    let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
1101    let n_aug_i: i32 = i32::try_from(n_aug)
1102        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
1103    let p_i: i32 = i32::try_from(p)
1104        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
1105
1106    // 2) Workspace size for geqrf.
1107    let mut lwork: i32 = 0;
1108    {
1109        let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
1110        // SAFETY: a_dev holds n_aug*p f64 elements column-major;
1111        // pointer is live on `stream`; lwork is a valid host out-param.
1112        let status = unsafe {
1113            cusolver_sys::cusolverDnDgeqrf_bufferSize(
1114                solver.cu(),
1115                n_aug_i,
1116                p_i,
1117                a_ptr as *mut f64,
1118                n_aug_i,
1119                &mut lwork,
1120            )
1121        };
1122        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1123            gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1124        }
1125    }
1126    let lwork_us = usize::try_from(lwork)
1127        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1128    let mut workspace = stream
1129        .alloc_zeros::<f64>(lwork_us.max(1))
1130        .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1131    let mut tau = stream
1132        .alloc_zeros::<f64>(p)
1133        .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1134    let mut info = stream
1135        .alloc_zeros::<i32>(1)
1136        .gpu_ctx("solve_penalised_ls_device alloc info")?;
1137
1138    // 3) cusolverDnDgeqrf — A := QR in place.
1139    {
1140        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1141        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1142        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1143        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1144        // SAFETY: all pointers reference live device allocations on
1145        // this stream; lwork matches the bufferSize query above.
1146        let status = unsafe {
1147            cusolver_sys::cusolverDnDgeqrf(
1148                solver.cu(),
1149                n_aug_i,
1150                p_i,
1151                a_ptr as *mut f64,
1152                n_aug_i,
1153                tau_ptr as *mut f64,
1154                work_ptr as *mut f64,
1155                lwork,
1156                info_ptr as *mut i32,
1157            )
1158        };
1159        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1160            gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1161        }
1162    }
1163
1164    // 4) cusolverDnDormqr — b_aug := Qᵀ · b_aug.
1165    let mut ormqr_lwork: i32 = 0;
1166    {
1167        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1168        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1169        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1170        // SAFETY: A/tau/b are live device buffers on this stream;
1171        // ormqr_lwork is a host out-param.
1172        let status = unsafe {
1173            cusolver_sys::cusolverDnDormqr_bufferSize(
1174                solver.cu(),
1175                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1176                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1177                n_aug_i,
1178                1,
1179                p_i,
1180                a_ptr as *const f64,
1181                n_aug_i,
1182                tau_ptr as *const f64,
1183                b_ptr as *mut f64,
1184                n_aug_i,
1185                &mut ormqr_lwork,
1186            )
1187        };
1188        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1189            gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1190        }
1191    }
1192    if ormqr_lwork > lwork {
1193        workspace = stream
1194            .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1195            .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1196    }
1197    {
1198        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1199        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1200        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1201        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1202        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1203        // SAFETY: all pointers reference live, mutually-non-aliasing
1204        // device buffers on this stream; lwork matches the bufferSize
1205        // query above; A and tau are the geqrf output.
1206        let status = unsafe {
1207            cusolver_sys::cusolverDnDormqr(
1208                solver.cu(),
1209                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1210                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1211                n_aug_i,
1212                1,
1213                p_i,
1214                a_ptr as *const f64,
1215                n_aug_i,
1216                tau_ptr as *const f64,
1217                b_ptr as *mut f64,
1218                n_aug_i,
1219                work_ptr as *mut f64,
1220                ormqr_lwork.max(lwork),
1221                info_ptr as *mut i32,
1222            )
1223        };
1224        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1225            gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1226        }
1227    }
1228
1229    // 5) cublasDtrsm — solve R · β = (Qᵀ b)[0..p] in place on the top
1230    //    of b_dev. We use a single-RHS upper-triangular non-unit solve.
1231    {
1232        use cudarc::cublas::CudaBlas;
1233        let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1234        let alpha = 1.0_f64;
1235        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1236        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1237        // SAFETY: A is the geqrf-output upper-triangular factor R in
1238        // its top-p × p block (col-major, ld = n_aug); b is the
1239        // ormqr-output Qᵀb in the top p slots (ld = n_aug as well so
1240        // pretend it is column-major with 1 column of leading dim n_aug).
1241        let handle = *blas.handle();
1242        let status = unsafe {
1243            cudarc::cublas::sys::cublasDtrsm_v2(
1244                handle,
1245                cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1246                cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1247                cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1248                cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1249                p_i,
1250                1,
1251                &alpha,
1252                a_ptr as *const f64,
1253                n_aug_i,
1254                b_ptr as *mut f64,
1255                n_aug_i,
1256            )
1257        };
1258        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1259            gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1260        }
1261    }
1262
1263    // 6) Copy results back to host.
1264    let mut b_out = vec![0.0_f64; n_aug];
1265    stream
1266        .memcpy_dtoh(&b_dev, &mut b_out)
1267        .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1268    let mut a_back = vec![0.0_f64; n_aug * p];
1269    stream
1270        .memcpy_dtoh(&a_dev, &mut a_back)
1271        .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1272    stream
1273        .synchronize()
1274        .gpu_ctx("solve_penalised_ls_device synchronize")?;
1275
1276    let beta: Vec<f64> = b_out[..p].to_vec();
1277    // (Qᵀb)[p..n_aug] holds the residual in the rotated coordinates;
1278    // ‖(Qᵀb)[p..]‖² = ‖√W (Xβ − y)‖² + λ ‖R_S β‖² for the augmented
1279    // system. To recover ‖√W (Xβ − y)‖² alone, subtract the penalty
1280    // residual ‖R_S β‖² (penalty rotates to itself in the augmented
1281    // bottom block, but only when the bottom block ROWS map exactly
1282    // into the rotated residual — which is not guaranteed, so the
1283    // simpler accurate path is to return the **augmented** residual
1284    // squared and let the caller subtract.)
1285    let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1286
1287    // log|R| diagonal.
1288    let mut log_abs_r = 0.0_f64;
1289    for k in 0..p {
1290        let r_kk = a_back[k * n_aug + k];
1291        log_abs_r += r_kk.abs().ln();
1292    }
1293    let log_det_hessian = 2.0 * log_abs_r;
1294
1295    Ok(PenalisedLsSolution {
1296        beta,
1297        weighted_residual_ssq: augmented_residual_ssq,
1298        log_det_hessian,
1299    })
1300}
1301
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    // Non-Linux stub: never solves anything. The request's shape is echoed
    // in the error text so the failing call is identifiable.
    let (n, p) = (x_s_device.rows, x_s_device.cols);
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        n,
        p,
        wy.len(),
        r_s.dim()
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1318
1319// ────────────────────────────────────────────────────────────────────────
1320// Tests
1321// ────────────────────────────────────────────────────────────────────────
1322
1323#[cfg(test)]
1324mod sphere_gpu_tests {
1325    use super::*;
1326    use crate::basis::{
1327        SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1328        spherical_wahba_kernel_matrix_with_kind,
1329    };
1330    use ndarray::Array2;
1331
1332    fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1333        // Latitude in (-85, 85), longitude in [-180, 180), degrees.
1334        let mut rows = Vec::with_capacity(n_lat * n_lon);
1335        for i in 0..n_lat {
1336            let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1337            for j in 0..n_lon {
1338                let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1339                rows.push(lat);
1340                rows.push(lon);
1341            }
1342        }
1343        Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1344    }
1345
1346    #[test]
1347    fn xyz_preprocessing_matches_unit_sphere() {
1348        let latlon = ndarray::array![
1349            [0.0, 0.0],
1350            [90.0, 0.0],
1351            [0.0, 90.0],
1352            [-90.0, 17.5],
1353            [45.0, -120.0],
1354        ];
1355        let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1356        assert_eq!(xyz.len(), 3 * 5);
1357        for i in 0..5 {
1358            let nrm2 = xyz[3 * i] * xyz[3 * i]
1359                + xyz[3 * i + 1] * xyz[3 * i + 1]
1360                + xyz[3 * i + 2] * xyz[3 * i + 2];
1361            assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1362        }
1363        // Row 0 = equator @ lon=0 → (1, 0, 0).
1364        assert!((xyz[0] - 1.0).abs() < 1e-15);
1365        assert!(xyz[1].abs() < 1e-15);
1366        assert!(xyz[2].abs() < 1e-15);
1367        // Row 1 = north pole (lat=90, lon=0) → (0, 0, 1).
1368        assert!(xyz[3].abs() < 1e-15);
1369        assert!(xyz[4].abs() < 1e-15);
1370        assert!((xyz[5] - 1.0).abs() < 1e-15);
1371        // Row 2 = equator @ lon=90 → (0, 1, 0).
1372        assert!(xyz[6].abs() < 1e-15);
1373        assert!((xyz[7] - 1.0).abs() < 1e-15);
1374        assert!(xyz[8].abs() < 1e-15);
1375    }
1376
1377    #[test]
1378    fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1379        // P_ℓ(1) = 1 for all ℓ, so K(x, x) = Σ_{ℓ=0..L} c_ℓ. The Legendre
1380        // recurrence in `sphere_truncated_spectral_eval` must reproduce
1381        // this exact identity to roundoff.
1382        for m_penalty in 1..=4 {
1383            for &lmax in &[5_usize, 20, 50] {
1384                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1385                let expected: f64 = coeffs.iter().sum();
1386                let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1387                assert!(
1388                    (got - expected).abs() < 1e-13,
1389                    "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1390                );
1391            }
1392        }
1393    }
1394
1395    #[test]
1396    fn truncated_spectral_at_antipode_matches_alternating_sum() {
1397        // P_ℓ(-1) = (-1)^ℓ, so K(x, -x) = Σ_{ℓ=0..L} c_ℓ · (-1)^ℓ. Same
1398        // exact identity for the recurrence at t = -1.
1399        for m_penalty in 1..=4 {
1400            for &lmax in &[5_usize, 20, 50] {
1401                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1402                let expected: f64 = coeffs
1403                    .iter()
1404                    .enumerate()
1405                    .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1406                    .sum();
1407                let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1408                assert!(
1409                    (got - expected).abs() < 1e-13,
1410                    "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1411                );
1412            }
1413        }
1414    }
1415
1416    #[test]
1417    fn truncated_spectral_matrix_is_symmetric() {
1418        // K(γ) depends only on cos γ = x · y = y · x, so the Gram
1419        // matrix B B^T-style kernel evaluation on the same point set
1420        // must be symmetric to roundoff.
1421        let centers = ndarray::array![
1422            [10.0_f64, 20.0],
1423            [-30.0, 100.0],
1424            [45.0, -60.0],
1425            [-89.0, 0.0],
1426            [0.0, 180.0],
1427            [60.0, -179.9],
1428        ];
1429        for m_penalty in [1usize, 2, 4] {
1430            for &lmax in &[10_usize, 30] {
1431                let mat = spherical_wahba_kernel_matrix_with_kind(
1432                    centers.view(),
1433                    centers.view(),
1434                    m_penalty,
1435                    false,
1436                    SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1437                )
1438                .expect("kernel matrix");
1439                let n = centers.nrows();
1440                let mut max_asym = 0.0_f64;
1441                for i in 0..n {
1442                    for j in 0..n {
1443                        let d = (mat[(i, j)] - mat[(j, i)]).abs();
1444                        if d > max_asym {
1445                            max_asym = d;
1446                        }
1447                    }
1448                }
1449                assert!(
1450                    max_asym < 1e-13,
1451                    "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1452                );
1453            }
1454        }
1455    }
1456
1457    #[test]
1458    fn truncated_coefficients_have_zero_constant_mode() {
1459        for m in 1..=4 {
1460            let c = sobolev_s2_truncated_coefficients(50, m);
1461            assert_eq!(c.len(), 51);
1462            assert_eq!(c[0], 0.0);
1463            assert!(c[1] > 0.0);
1464            // Spectral decay c_ℓ ~ 1/ℓ^{2m-1}: monotone for ℓ ≥ 1.
1465            for ell in 2..=50 {
1466                assert!(
1467                    c[ell] < c[ell - 1] + 1e-15,
1468                    "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1469                    c[ell],
1470                    c[ell - 1]
1471                );
1472            }
1473        }
1474    }
1475
1476    #[test]
1477    fn truncated_spectral_matches_matrix_helper() {
1478        // The Wahba kernel matrix helper, invoked with the truncated
1479        // variant, must produce the same value as the bare scalar
1480        // evaluator.
1481        let m_penalty = 2;
1482        let lmax = 20;
1483        let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1484        let data = ndarray::array![[12.5, -34.0]];
1485        let centers = ndarray::array![[40.0, 10.0]];
1486        let mat = spherical_wahba_kernel_matrix_with_kind(
1487            data.view(),
1488            centers.view(),
1489            m_penalty,
1490            false,
1491            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1492        )
1493        .expect("kernel matrix");
1494        // Recompute cos γ on the unit sphere.
1495        let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1496        let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1497        let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1498        let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1499        assert!(
1500            (mat[(0, 0)] - expected).abs() < 1e-13,
1501            "matrix helper differs from scalar evaluator: {} vs {}",
1502            mat[(0, 0)],
1503            expected
1504        );
1505    }
1506
1507    #[test]
1508    fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1509        // Build a small symmetric PD matrix as a stand-in for C, then
1510        // verify that constrained_penalty_host returns a symmetric
1511        // (m-1)×(m-1) matrix whose action against Z·x matches the
1512        // expected Zᵀ C Z mapping.
1513        let m = 6;
1514        let mut c = Array2::<f64>::zeros((m, m));
1515        for i in 0..m {
1516            for j in 0..m {
1517                let d = (i as f64 - j as f64).abs();
1518                c[(i, j)] = (-0.5 * d).exp();
1519            }
1520        }
1521        let w = vec![1.0_f64; m];
1522        let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1523        assert_eq!(s.dim(), (m - 1, m - 1));
1524        // Symmetry within roundoff.
1525        let mut max_asym = 0.0_f64;
1526        for i in 0..(m - 1) {
1527            for j in 0..(m - 1) {
1528                let d = (s[(i, j)] - s[(j, i)]).abs();
1529                if d > max_asym {
1530                    max_asym = d;
1531                }
1532            }
1533        }
1534        assert!(
1535            max_asym < 1e-13,
1536            "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1537        );
1538
1539        // The kernel-of-Zᵀ direction: Zᵀ · w = 0 ⇒ x = (something) such
1540        // that Z · x stays in span(w)^⊥, so x can be any (m-1) vector;
1541        // we just verify that picking the all-ones constraint direction
1542        // collapses to zero through Z when applied to constant fields.
1543        // i.e. constant-field penalty norm must be zero in the
1544        // un-constrained Cv direction, and the trailing block here is
1545        // never used against the constraint.
1546        let ones = ndarray::Array1::<f64>::ones(m - 1);
1547        let sx = s.dot(&ones);
1548        assert!(sx.iter().all(|v| v.is_finite()));
1549    }
1550
1551    #[test]
1552    fn householder_reflector_zeroes_target_vector() {
1553        let w = vec![3.0, 4.0, 0.0, -1.0];
1554        let (v, beta) = householder_reflector_from_weights(&w);
1555        // Apply H = I - beta * v * v^T to w; the result should be a
1556        // multiple of e_1 (only first entry non-zero).
1557        let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1558        let hw: Vec<f64> = w
1559            .iter()
1560            .zip(&v)
1561            .map(|(wj, vj)| wj - beta * dot * vj)
1562            .collect();
1563        for entry in hw.iter().skip(1) {
1564            assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1565        }
1566        assert!(hw[0].abs() > 0.0);
1567    }
1568
1569    /// V100-only: probe + raw kernel parity vs CPU truncated-spectral on
1570    /// a small grid. Skips cleanly on hosts with no CUDA runtime.
1571    #[test]
1572    fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1573        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1574            eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1575            return;
1576        };
1577        // Past the runtime Some-gate: a probe failure is a real device fault on a
1578        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1579        SphereGpuBackend::probe()
1580            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1581
1582        let data_ll = small_latlon_grid(7, 9);
1583        let centers_ll = small_latlon_grid(5, 7);
1584        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1585        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1586        let n = data_ll.nrows();
1587        let m = centers_ll.nrows();
1588        let penalty = 2usize;
1589        let lmax = 20usize;
1590        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1591
1592        let inputs = S2KernelBuildInputs {
1593            n,
1594            m,
1595            lmax,
1596            data_xyz: &data_xyz,
1597            centers_xyz: &centers_xyz,
1598            coeffs: &coeffs,
1599            kind: SphereSpectralKernelKind::Sobolev,
1600            layout: DeviceMatrixLayout::ColumnMajor,
1601        };
1602        let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1603        let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1604
1605        let cpu = spherical_wahba_kernel_matrix_with_kind(
1606            data_ll.view(),
1607            centers_ll.view(),
1608            penalty,
1609            false,
1610            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1611        )
1612        .expect("cpu kernel matrix");
1613
1614        let mut max_abs = 0.0_f64;
1615        for i in 0..n {
1616            for j in 0..m {
1617                let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1618                if d > max_abs {
1619                    max_abs = d;
1620                }
1621            }
1622        }
1623        assert!(
1624            max_abs < 1e-11,
1625            "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1626        );
1627    }
1628
1629    /// V100-only end-to-end DISPATCH parity: prove the *production* kernel
1630    /// builder (`spherical_wahba_kernel_matrix_with_kind`) actually engages the
1631    /// device on a GPU-eligible truncated-spectral shape, and that the device
1632    /// result matches the CPU oracle (`spherical_wahba_kernel_matrix_cpu`) to
1633    /// roundoff. This is the engagement + parity gate the prior version of this
1634    /// test never exercised: it called `build_spherical_spline_basis` (which did
1635    /// not route to the GPU at all) and then compared the *decomposed* design
1636    /// against the *raw* kernel matrix, so it diverged by construction
1637    /// (rel |Δ| = 2.0) regardless of any device behaviour.
1638    ///
1639    /// Downstream PIRLS/REML consumes the kernel design through the same
1640    /// deterministic low-degree decomposition for both backends, so element-wise
1641    /// raw-kernel parity at ≤ 1e-9 implies full-design + fit parity.
1642    #[test]
1643    fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1644        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1645            eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1646            return;
1647        };
1648        // Past the runtime Some-gate: a backend probe failure is a real device
1649        // fault on a CUDA host, not a no-CUDA skip — fail loud (device-PCG
1650        // skip-pass class, eee12f6b2) instead of masking it as a pass.
1651        SphereGpuBackend::probe()
1652            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1653        use crate::basis::{
1654            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1655            build_spherical_spline_basis, spherical_wahba_kernel_matrix_cpu,
1656            spherical_wahba_kernel_matrix_with_kind,
1657        };
1658
1659        // (n=10_000, m=200) → n·m = 2_000_000 ≥ 1_000_000 → GPU eligible.
1660        let data = small_latlon_grid(100, 100);
1661        let lmax: u16 = 30;
1662        let penalty_order = 2usize;
1663        let centers =
1664            crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1665                .expect("centers");
1666        let n = data.nrows();
1667        let m = centers.nrows();
1668
1669        // The device MUST be admitted for this shape, otherwise this test would
1670        // silently exercise the CPU path on both sides and prove nothing about
1671        // engagement. Fail loud if the dispatch decision declines the GPU.
1672        let decision = sphere_kernel_decision(n, m, lmax as usize);
1673        assert!(
1674            decision.use_gpu,
1675            "expected GPU dispatch for (n={n}, m={m}, lmax={lmax}); decision said CPU \
1676             (reason={}); the engagement gate regressed",
1677            decision.reason
1678        );
1679
1680        // Production dispatcher: engages the device for this admitted shape.
1681        let gpu_kernel = spherical_wahba_kernel_matrix_with_kind(
1682            data.view(),
1683            centers.view(),
1684            penalty_order,
1685            false,
1686            SphereWahbaKernel::SobolevTruncated { lmax },
1687        )
1688        .expect("GPU-eligible production kernel build succeeds");
1689
1690        // CPU oracle: forced host evaluation regardless of dispatch decision.
1691        let cpu_kernel = spherical_wahba_kernel_matrix_cpu(
1692            data.view(),
1693            centers.view(),
1694            penalty_order,
1695            false,
1696            SphereWahbaKernel::SobolevTruncated { lmax },
1697        )
1698        .expect("cpu oracle kernel build succeeds");
1699
1700        assert_eq!(gpu_kernel.dim(), cpu_kernel.dim());
1701        let mut max_abs = 0.0_f64;
1702        let mut max_rel = 0.0_f64;
1703        for (g, c) in gpu_kernel.iter().zip(cpu_kernel.iter()) {
1704            let d = (g - c).abs();
1705            if d > max_abs {
1706                max_abs = d;
1707            }
1708            let denom = g.abs().max(c.abs()).max(1e-300);
1709            let r = d / denom;
1710            if r > max_rel {
1711                max_rel = r;
1712            }
1713        }
1714        assert!(
1715            max_rel < 1e-9,
1716            "GPU-dispatch vs CPU-oracle kernel parity max relative |Δ| = {max_rel:.3e} \
1717             >= 1e-9 (abs {max_abs:.3e})"
1718        );
1719
1720        // End-to-end smoke: the full design build (which routes its large
1721        // data×centers kernel through the engaged device) produces a finite,
1722        // correctly-shaped design with the expected number of rows.
1723        let spec_gpu = SphericalSplineBasisSpec {
1724            center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
1725            penalty_order,
1726            double_penalty: false,
1727            radians: false,
1728            method: SphereMethod::Wahba,
1729            max_degree: None,
1730            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1731            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1732        };
1733        let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
1734            .expect("GPU-eligible build_spherical_spline_basis succeeds");
1735        let design = result_gpu.design.as_dense().expect("dense design");
1736        assert_eq!(design.nrows(), n, "design row count must match data rows");
1737        assert!(
1738            design.iter().all(|v| v.is_finite()),
1739            "engaged-device spherical design must be finite"
1740        );
1741    }
1742
1743    /// V100-only: parity of Householder-constrained kernel against
1744    /// (raw kernel) · Z evaluated on host.
1745    #[test]
1746    fn sphere_gpu_householder_parity_vs_raw_dot_z() {
1747        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1748            eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
1749            return;
1750        };
1751        // Past the runtime Some-gate: a probe failure is a real device fault on a
1752        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1753        SphereGpuBackend::probe()
1754            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1755        let data_ll = small_latlon_grid(6, 8);
1756        let centers_ll = small_latlon_grid(4, 5);
1757        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1758        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1759        let n = data_ll.nrows();
1760        let m = centers_ll.nrows();
1761        let penalty = 2usize;
1762        let lmax = 15usize;
1763        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1764
1765        // Build raw B on device, then form (n × m-1) X_s = B · Z on host.
1766        let inputs_raw = S2KernelBuildInputs {
1767            n,
1768            m,
1769            lmax,
1770            data_xyz: &data_xyz,
1771            centers_xyz: &centers_xyz,
1772            coeffs: &coeffs,
1773            kind: SphereSpectralKernelKind::Sobolev,
1774            layout: DeviceMatrixLayout::ColumnMajor,
1775        };
1776        let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
1777        let b = b_dev.to_host_array().expect("dtoh raw");
1778
1779        // Construct a Householder reflector from a uniform weight vector
1780        // (the "weighted sum-to-zero" constraint when weights are all 1).
1781        let w = vec![1.0_f64; m];
1782        let (v, beta) = householder_reflector_from_weights(&w);
1783
1784        // Apply on host: X_s_host[i, j_out] = B[i, j_out+1] - beta * (B[i,:] · v) * v[j_out+1]
1785        let mut xs_host = Array2::<f64>::zeros((n, m - 1));
1786        for i in 0..n {
1787            let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
1788            for j_out in 0..(m - 1) {
1789                xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
1790            }
1791        }
1792
1793        let xs_dev =
1794            build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
1795        let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
1796
1797        let mut max_abs = 0.0_f64;
1798        for i in 0..n {
1799            for j in 0..(m - 1) {
1800                let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
1801                if d > max_abs {
1802                    max_abs = d;
1803                }
1804            }
1805        }
1806        assert!(
1807            max_abs < 1e-12,
1808            "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
1809        );
1810    }
1811
1812    /// V100 hill-climb: GPU truncated-spectral kernel matrix build at
1813    /// (n=200_000, m=200, L=50) must beat CPU by ≥ 20× wall-clock.
1814    /// Skips silently when no CUDA runtime is available.
1815    #[test]
1816    fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
1817        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1818            eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
1819            return;
1820        };
1821        if SphereGpuBackend::probe().is_err() {
1822            eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
1823            return;
1824        }
1825
1826        // (n=200_000, m=200, lmax=50). n·m = 4·10^7 ≫ 1e6 → GPU eligible.
1827        // Build a 200_000-row deterministic lat/lon grid.
1828        let n_lat = 500usize;
1829        let n_lon = 400usize;
1830        assert_eq!(n_lat * n_lon, 200_000);
1831        let data_ll = small_latlon_grid(n_lat, n_lon);
1832        let m = 200usize;
1833        let centers_ll =
1834            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1835                .expect("centers");
1836        let n = data_ll.nrows();
1837        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1838        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1839        let penalty_order = 2usize;
1840        let lmax = 50usize;
1841        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
1842
1843        // Warm up GPU (NVRTC compile + first-touch alloc).
1844        let inputs_warm = S2KernelBuildInputs {
1845            n,
1846            m,
1847            lmax,
1848            data_xyz: &data_xyz,
1849            centers_xyz: &centers_xyz,
1850            coeffs: &coeffs,
1851            kind: SphereSpectralKernelKind::Sobolev,
1852            layout: DeviceMatrixLayout::ColumnMajor,
1853        };
1854        drop(build_kernel_matrix_device(inputs_warm.clone()).expect("warmup"));
1855
1856        // Measure GPU.
1857        let t0 = std::time::Instant::now();
1858        let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
1859        let _host_gpu = dev.to_host_array().expect("dtoh");
1860        let gpu_secs = t0.elapsed().as_secs_f64();
1861
1862        // Measure CPU (truncated-spectral via the public matrix helper).
1863        let t1 = std::time::Instant::now();
1864        let _cpu = spherical_wahba_kernel_matrix_with_kind(
1865            data_ll.view(),
1866            centers_ll.view(),
1867            penalty_order,
1868            false,
1869            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1870        )
1871        .expect("cpu kernel matrix");
1872        let cpu_secs = t1.elapsed().as_secs_f64();
1873
1874        let ratio = cpu_secs / gpu_secs.max(1e-9);
1875        eprintln!(
1876            "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
1877        );
1878        assert!(
1879            ratio >= 20.0,
1880            "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
1881             n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1882        );
1883    }
1884
1885    /// V100 hill-climb: end-to-end Gaussian fit through
1886    /// `build_spherical_spline_basis` (GPU-dispatched) must beat the
1887    /// CPU-only fit by ≥ 10× wall-clock at a workload where the GPU
1888    /// kernel build dominates PIRLS.
1889    #[test]
1890    fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
1891        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1892            eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
1893            return;
1894        };
1895        if SphereGpuBackend::probe().is_err() {
1896            eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
1897            return;
1898        }
1899        use crate::basis::{
1900            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1901            build_spherical_spline_basis,
1902        };
1903
1904        let n_lat = 500usize;
1905        let n_lon = 400usize;
1906        let data_ll = small_latlon_grid(n_lat, n_lon);
1907        let m: usize = 200;
1908        let lmax: u16 = 50;
1909        let spec_gpu = SphericalSplineBasisSpec {
1910            center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
1911            penalty_order: 2,
1912            double_penalty: false,
1913            radians: false,
1914            method: SphereMethod::Wahba,
1915            max_degree: None,
1916            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1917            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1918        };
1919
1920        // Warm-up GPU build.
1921        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
1922
1923        let t0 = std::time::Instant::now();
1924        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
1925        let gpu_secs = t0.elapsed().as_secs_f64();
1926
1927        // CPU comparison: directly invoke the CPU helper and apply the
1928        // same constraint transform (matches what build_*_basis would do
1929        // when GPU dispatch declines). Going through the public matrix
1930        // helper isolates the GPU-vs-CPU kernel cost without re-doing
1931        // farthest-point center selection (which is identical for both
1932        // paths).
1933        let centers =
1934            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1935                .expect("centers");
1936        let z = Array2::<f64>::eye(centers.nrows());
1937        let t1 = std::time::Instant::now();
1938        let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
1939            data_ll.view(),
1940            centers.view(),
1941            2,
1942            false,
1943            SphereWahbaKernel::SobolevTruncated { lmax },
1944        )
1945        .expect("cpu raw");
1946        let _design_cpu = raw_cpu.dot(&z);
1947        let cpu_secs = t1.elapsed().as_secs_f64();
1948
1949        let ratio = cpu_secs / gpu_secs.max(1e-9);
1950        eprintln!(
1951            "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
1952            data_ll.nrows()
1953        );
1954        assert!(
1955            ratio >= 10.0,
1956            "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
1957             cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1958        );
1959    }
1960
1961    /// Task #25: end-to-end fit parity between the GPU truncated-spectral
1962    /// path and the CPU truncated-spectral path on a small synthetic
1963    /// intrinsic-S² fixture.
1964    ///
1965    /// Setup: deterministic lat/lon grid (n = 1000 = 25 × 40), 80 centers
1966    /// chosen by farthest-point selection, lmax = 15, penalty order 2,
1967    /// Wahba weighted-sum-to-zero constraint applied via `Z`. We fit a
1968    /// fixed-λ penalised LS problem
1969    ///   β = argmin ‖X_s β − y‖² + λ · βᵀ S β
1970    /// where `X_s = K(data, centers) · Z` and `S = Zᵀ · K(centers, centers) · Z`,
1971    /// solving `(X_sᵀ X_s + λ S) β = X_sᵀ y` via faer LLT for both paths.
1972    /// The only path-dependent quantity is `K(data, centers)`: built on
1973    /// GPU via `build_kernel_matrix_device` for one β, and on CPU via
1974    /// `spherical_wahba_kernel_matrix_with_kind` for the other. The
1975    /// penalty kernel `K(centers, centers)` is m × m and tiny, so we
1976    /// build it once on CPU and share it across paths (it is not the
1977    /// surface under test).
1978    ///
1979    /// Asserts max-absolute coefficient delta ≤ 1e-9 and max-absolute
1980    /// fitted-value delta ≤ 1e-9. `#[ignore = "requires CUDA"]` so the
1981    /// V100 bench runner unignores in their harness.
1982    #[test]
1983    fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
1984        use crate::basis::{
1985            select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
1986        };
1987        use faer::Side;
1988        use gam_linalg::faer_ndarray::FaerCholesky;
1989
1990        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1991            eprintln!(
1992                "[sphere gpu parity] no CUDA runtime — skipping device parity \
1993                 (CPU oracle exercised by sibling tests)"
1994            );
1995            return;
1996        };
1997        // Past the runtime Some-gate: a probe failure is a real device fault on a
1998        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1999        SphereGpuBackend::probe()
2000            .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
2001
2002        // Fixture: 25 × 40 lat/lon grid → n = 1000.
2003        let data_ll = small_latlon_grid(25, 40);
2004        assert_eq!(data_ll.nrows(), 1000);
2005        let n = data_ll.nrows();
2006        let m: usize = 80;
2007        let lmax_u16: u16 = 15;
2008        let lmax: usize = lmax_u16 as usize;
2009        let penalty_order: usize = 2;
2010        let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
2011        let lambda: f64 = 1.0e-3;
2012
2013        // Deterministic centers via farthest-point selection.
2014        let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
2015            .expect("farthest-point centers");
2016        assert_eq!(centers_ll.nrows(), m);
2017
2018        // The Wahba sphere basis no longer imposes a finite-center coefficient
2019        // gauge; parity compares the raw center coefficient chart.
2020        let z = Array2::<f64>::eye(centers_ll.nrows());
2021        let p = z.ncols();
2022        assert_eq!(p, m);
2023
2024        // Penalty K(centers, centers), built once on CPU. The penalty
2025        // kernel evaluation is m × m (= 6400 entries), well outside the
2026        // GPU dispatch threshold, and identical for both paths under
2027        // test by construction.
2028        let k_cc = spherical_wahba_kernel_matrix_with_kind(
2029            centers_ll.view(),
2030            centers_ll.view(),
2031            penalty_order,
2032            false,
2033            kernel,
2034        )
2035        .expect("centers×centers kernel");
2036        let s_full = z.t().dot(&k_cc).dot(&z);
2037
2038        // CPU path: K(data, centers) via the public CPU helper.
2039        let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
2040            data_ll.view(),
2041            centers_ll.view(),
2042            penalty_order,
2043            false,
2044            kernel,
2045        )
2046        .expect("CPU raw design");
2047        let x_s_cpu = raw_design_cpu.dot(&z);
2048
2049        // GPU path: K(data, centers) via `build_kernel_matrix_device`.
2050        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
2051        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
2052        let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
2053        let inputs = S2KernelBuildInputs {
2054            n,
2055            m,
2056            lmax,
2057            data_xyz: &data_xyz,
2058            centers_xyz: &centers_xyz,
2059            coeffs: &coeffs,
2060            kind: SphereSpectralKernelKind::Sobolev,
2061            layout: DeviceMatrixLayout::ColumnMajor,
2062        };
2063        let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
2064        let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
2065        let x_s_gpu = raw_design_gpu.dot(&z);
2066
2067        assert_eq!(x_s_cpu.dim(), (n, p));
2068        assert_eq!(x_s_gpu.dim(), (n, p));
2069
2070        // Deterministic synthetic response. The intent is to give the
2071        // penalised LS solve a non-trivial right-hand side; any smooth
2072        // function of the lat/lon is fine. Use a fixed-seed pseudo-
2073        // random walk derived from coordinates so the fixture has no
2074        // RNG dependency.
2075        let mut y = ndarray::Array1::<f64>::zeros(n);
2076        for i in 0..n {
2077            let lat_rad = data_ll[(i, 0)].to_radians();
2078            let lon_rad = data_ll[(i, 1)].to_radians();
2079            // Smooth ground truth + a tiny deterministic high-freq jitter.
2080            y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
2081                + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
2082        }
2083
2084        // Penalised normal-equation solve via faer LLT for each path:
2085        //   (X_sᵀ X_s + λ S) β = X_sᵀ y
2086        // S is symmetric positive semi-definite; λ S makes the system
2087        // strictly positive definite once added to X_sᵀ X_s.
2088        let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
2089            let xtx = x_s.t().dot(x_s);
2090            let mut a = xtx;
2091            for i in 0..p {
2092                for j in 0..p {
2093                    a[(i, j)] += lambda * s_full[(i, j)];
2094                }
2095            }
2096            let rhs = x_s.t().dot(&y);
2097            let factor = a
2098                .cholesky(Side::Lower)
2099                .expect("penalised normal equations are SPD under λ > 0");
2100            factor.solvevec(&rhs)
2101        };
2102
2103        let beta_cpu = solve_penalised(&x_s_cpu);
2104        let beta_gpu = solve_penalised(&x_s_gpu);
2105        assert_eq!(beta_cpu.len(), p);
2106        assert_eq!(beta_gpu.len(), p);
2107
2108        // Fitted values for both paths use their own design matrices —
2109        // this is the customer-visible quantity (prediction at training
2110        // points).
2111        let yhat_cpu = x_s_cpu.dot(&beta_cpu);
2112        let yhat_gpu = x_s_gpu.dot(&beta_gpu);
2113
2114        let mut max_beta_delta = 0.0_f64;
2115        for k in 0..p {
2116            let d = (beta_cpu[k] - beta_gpu[k]).abs();
2117            if d > max_beta_delta {
2118                max_beta_delta = d;
2119            }
2120        }
2121        let mut max_fit_delta = 0.0_f64;
2122        for i in 0..n {
2123            let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
2124            if d > max_fit_delta {
2125                max_fit_delta = d;
2126            }
2127        }
2128
2129        eprintln!(
2130            "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
2131             max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
2132        );
2133
2134        assert!(
2135            max_beta_delta <= 1.0e-9,
2136            "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > 1e-9"
2137        );
2138        assert!(
2139            max_fit_delta <= 1.0e-9,
2140            "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2141        );
2142    }
2143}