Skip to main content

gam_terms/basis/
sphere_gpu.rs

1//! GPU NVRTC Wahba intrinsic-S2 kernel matrix construction.
2//!
3//! This module owns the device-side construction of the Wahba reproducing
4//! kernel basis matrix on the 2-sphere using the **finite truncated
5//! spectral Legendre series**
6//!
7//! `K_L(γ) = Σ_{ℓ=1..L} c_ℓ · P_ℓ(cos γ)`,
8//!
9//! evaluated entry-by-entry against the 3-term Legendre recurrence kept
10//! in registers. The host CPU parity target is the matching
11//! `SphereWahbaKernel::SobolevTruncated { lmax }` /
12//! `SphereWahbaKernel::PseudoTruncated { lmax }` variant added to
13//! `src/terms/basis.rs` (single source: same recurrence, same c_ℓ).
14//!
15//! The device path evaluates the raw column-major kernel matrix with `f64`
16//! Legendre recurrence math. Host code owns centering, constraints, and solver
17//! assembly in `basis.rs`.
18
19use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
/// Truncated-spectral Wahba kernel family evaluated on the device.
///
/// Mirrors the CPU `SphereWahbaKernel::{SobolevTruncated, PseudoTruncated}`
/// variants one-to-one so CPU/GPU parity tests compare like with like.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SphereSpectralKernelKind {
    /// Sobolev `H^m(S²)` reproducing kernel:
    /// `c_ℓ = (2ℓ+1) / (4π · [ℓ(ℓ+1)]^m)`.
    Sobolev,
    /// Wahba (1981) pseudo-spline kernel:
    /// `c_ℓ = 2 / (4π · Π_{k=1..m+1}(ℓ + k))`.
    Pseudo,
}
46
47impl SphereSpectralKernelKind {
48    /// `c_0 = 0`, `c_ℓ = c_ℓ(m)` for `ℓ = 1..=lmax`. Returned vector has
49    /// length `lmax + 1` and is uploaded verbatim to constant/global
50    /// memory before kernel launch.
51    pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52        match self {
53            SphereSpectralKernelKind::Sobolev => {
54                crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55            }
56            SphereSpectralKernelKind::Pseudo => {
57                crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58            }
59        }
60    }
61
62    /// Stable string tag used in the NVRTC module cache key + logs.
63    pub const fn tag(self) -> &'static str {
64        match self {
65            SphereSpectralKernelKind::Sobolev => "sobolev",
66            SphereSpectralKernelKind::Pseudo => "pseudo",
67        }
68    }
69}
70
/// Storage order of the kernel design matrix on device.
///
/// Only column-major is provided: the downstream Wahba pipeline
/// (cuBLAS GEMM, cuSOLVER GEQRF) consumes column-major matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceMatrixLayout {
    /// Entry `(i, j)` is stored at offset `j * ld + i`.
    ColumnMajor,
}
78
79/// Lat/lon (degrees or radians) → unit vector `(x, y, z)` on S² ⊂ ℝ³.
80/// Returns a flat `Vec<f64>` of length `3 * n` in the row-major layout
81/// `[x_0, y_0, z_0, x_1, y_1, z_1, …]`, ready for one `htod` upload.
82///
83/// `radians = false` interprets inputs as degrees (the codebase default
84/// for `SphericalSplineBasisSpec`).
85pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86    if latlon.ncols() != 2 {
87        return Err(format!(
88            "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89            latlon.shape()
90        ));
91    }
92    let deg = if radians {
93        1.0
94    } else {
95        std::f64::consts::PI / 180.0
96    };
97    let n = latlon.nrows();
98    let mut out = Vec::with_capacity(3 * n);
99    for row in latlon.outer_iter() {
100        let lat = row[0] * deg;
101        let lon = row[1] * deg;
102        let (s_lat, c_lat) = lat.sin_cos();
103        let (s_lon, c_lon) = lon.sin_cos();
104        // Standard geographic→cartesian: pole on +z.
105        out.push(c_lat * c_lon);
106        out.push(c_lat * s_lon);
107        out.push(s_lat);
108    }
109    Ok(out)
110}
111
112/// Device-resident `(rows × cols)` matrix in column-major layout with
113/// leading dimension `ld ≥ rows`. The slice holds `ld * cols` `f64`
114/// elements; entry `(i, j)` lives at `col_major_dev[j * ld + i]`.
115///
116/// On non-Linux builds the type is intentionally a host shadow so the
117/// surrounding orchestration compiles without cudarc.
118#[cfg(target_os = "linux")]
119pub struct DeviceS2KernelMatrix {
120    pub rows: usize,
121    pub cols: usize,
122    pub ld: usize,
123    pub col_major_dev: CudaSlice<f64>,
124    pub stream: Arc<CudaStream>,
125}
126
127#[cfg(not(target_os = "linux"))]
128pub struct DeviceS2KernelMatrix {
129    pub rows: usize,
130    pub cols: usize,
131    pub ld: usize,
132    /// Host shadow for CPU-only builds.
133    pub col_major_dev: Vec<f64>,
134}
135
136impl DeviceS2KernelMatrix {
137    /// Copy the device matrix back to the host as a regular ndarray
138    /// `(rows × cols)` row-major view. Convenience for tests + parity
139    /// comparisons; production paths should keep the matrix resident.
140    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
141        let mut col_major = vec![0.0_f64; self.ld * self.cols];
142        self.copy_to_host_col_major(&mut col_major)?;
143        let mut out = Array2::<f64>::zeros((self.rows, self.cols));
144        for j in 0..self.cols {
145            for i in 0..self.rows {
146                out[(i, j)] = col_major[j * self.ld + i];
147            }
148        }
149        Ok(out)
150    }
151
152    /// Copy the underlying `(ld × cols)` column-major payload to a
153    /// caller-provided buffer. Used by `to_host_array` and by the
154    /// device-resident cuSOLVER consumer when it needs to extract the
155    /// coefficient vector.
156    #[cfg(target_os = "linux")]
157    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
158        let needed = self.ld * self.cols;
159        if dst.len() != needed {
160            gam_gpu::gpu_bail!(
161                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
162                dst.len(),
163                needed
164            );
165        }
166        self.stream
167            .memcpy_dtoh(&self.col_major_dev, dst)
168            .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
169        self.stream
170            .synchronize()
171            .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
172        Ok(())
173    }
174
175    #[cfg(not(target_os = "linux"))]
176    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
177        let needed = self.ld * self.cols;
178        if dst.len() != needed {
179            gam_gpu::gpu_bail!(
180                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
181                dst.len(),
182                needed
183            );
184        }
185        dst.copy_from_slice(&self.col_major_dev);
186        Ok(())
187    }
188}
189
190// ────────────────────────────────────────────────────────────────────────
191// Inputs
192// ────────────────────────────────────────────────────────────────────────
193
194/// Host-side inputs needed to launch `s2_wahba_legendre_colmajor`.
195///
196/// `data_xyz` and `centers_xyz` are flat row-major
197/// `[x_0, y_0, z_0, …]` length `3 * n` and `3 * m` respectively, pre-
198/// computed via [`latlon_to_xyz_host`]. `coeffs` has length `lmax + 1`,
199/// indexed as `coeffs[ℓ] = c_ℓ` with `c_0 = 0`.
200#[derive(Clone, Debug)]
201pub struct S2KernelBuildInputs<'a> {
202    pub n: usize,
203    pub m: usize,
204    pub lmax: usize,
205    pub data_xyz: &'a [f64],
206    pub centers_xyz: &'a [f64],
207    pub coeffs: &'a [f64],
208    pub kind: SphereSpectralKernelKind,
209    pub layout: DeviceMatrixLayout,
210}
211
212impl<'a> S2KernelBuildInputs<'a> {
213    fn validate(&self) -> Result<(), GpuError> {
214        if self.lmax == 0 {
215            return Err(GpuError::DriverCallFailed {
216                reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
217            });
218        }
219        if self.data_xyz.len() != 3 * self.n {
220            gam_gpu::gpu_bail!(
221                "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
222                self.data_xyz.len(),
223                3 * self.n
224            );
225        }
226        if self.centers_xyz.len() != 3 * self.m {
227            gam_gpu::gpu_bail!(
228                "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
229                self.centers_xyz.len(),
230                3 * self.m
231            );
232        }
233        if self.coeffs.len() != self.lmax + 1 {
234            gam_gpu::gpu_bail!(
235                "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
236                self.coeffs.len(),
237                self.lmax + 1
238            );
239        }
240        if self.coeffs[0] != 0.0 {
241            return Err(GpuError::DriverCallFailed {
242                reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
243            });
244        }
245        Ok(())
246    }
247}
248
249// ────────────────────────────────────────────────────────────────────────
// NVRTC kernel source — raw and Householder-fused variants.
//
// Both kernels live in a single module built by `cudarc::nvrtc::compile_ptx`
// with default options. LMAX is fixed at compile time by a `#define` that
// the host prepends to this source. Neither kernel uses shared memory:
// - the raw kernel is launched with (32, 8, 1) blocks, one thread per
//   (row, center) entry;
// - the Householder-fused kernel is launched with (128, 1, 1) blocks, one
//   thread per data row.
// ────────────────────────────────────────────────────────────────────────
257
/// CUDA C++ source for both sphere kernels. The host prepends
/// `#define LMAX <lmax>` before NVRTC compilation (see
/// `SphereGpuBackend::module_for`).
///
/// All per-point offsets are computed in 64-bit (`3LL * i`), because the
/// 32-bit product `3 * i` overflows once `n > INT_MAX / 3`. The Legendre
/// series is evaluated by a single inlined device helper shared by both
/// kernels.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).

// t = clamp(<a, b>, -1, +1): cosine of the angle between two unit vectors.
// The clamp keeps round-off from pushing t outside the Legendre domain.
static __device__ __forceinline__ double s2_cos_angle(
    double ax, double ay, double az,
    double bx, double by, double bz
) {
    double t = fma(ax, bx, fma(ay, by, az * bz));
    if (t >  1.0) t =  1.0;
    if (t < -1.0) t = -1.0;
    return t;
}

// K(t) = sum_{ell=0..LMAX} coeffs[ell] * P_ell(t), via the 3-term Legendre
// recurrence kept in registers:
//   P_0 = 1, P_1 = t,
//   P_{ell+1} = ((2 ell + 1) t P_ell - ell P_{ell-1}) / (ell + 1).
// Requires LMAX >= 1 (coeffs[1] is read unconditionally).
static __device__ __forceinline__ double s2_truncated_kernel(
    double t,
    const double* __restrict__ coeffs
) {
    double p_prev = 1.0;
    double p_curr = t;
    double acc    = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf  = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }
    return acc;
}

// Raw kernel: one thread per (row i, center j); y over rows, x over centers.
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,    // n × 3 (row-major flat)
    const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
    const double* __restrict__ coeffs,      // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                // ld × m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // 64-bit element offsets: 3 * i overflows a 32-bit int once n > INT_MAX / 3.
    const double* xyz_i = data_xyz + 3LL * i;
    const double* xyz_j = centers_xyz + 3LL * j;
    const double t = s2_cos_angle(xyz_i[0], xyz_i[1], xyz_i[2],
                                  xyz_j[0], xyz_j[1], xyz_j[2]);

    out[(long long) j * ld + (long long) i] = s2_truncated_kernel(t, coeffs);
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread owns one row i:
// pass 1 accumulates d_i = B_row · v, and pass 2 emits
// X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1. B is recomputed in pass 2 rather than stored, so
// per-thread state does not grow with m.
//
// Grid: 1D over rows (blockDim.x rows per block).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,    // n × 3
    const double* __restrict__ centers_xyz, // m × 3
    const double* __restrict__ coeffs,      // length LMAX + 1
    const double* __restrict__ v,           // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                // ld_out × (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const double* xyz_i = data_xyz + 3LL * i;
    const double xi = xyz_i[0];
    const double yi = xyz_i[1];
    const double zi = xyz_i[2];

    // Pass 1: d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double* xyz_j = centers_xyz + 3LL * j;
        const double t = s2_cos_angle(xi, yi, zi, xyz_j[0], xyz_j[1], xyz_j[2]);
        d_i = fma(v[j], s2_truncated_kernel(t, coeffs), d_i);
    }

    // Pass 2: X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double* xyz_j = centers_xyz + 3LL * j;
        const double t = s2_cos_angle(xi, yi, zi, xyz_j[0], xyz_j[1], xyz_j[2]);
        const double xs = s2_truncated_kernel(t, coeffs) - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
396
397// ────────────────────────────────────────────────────────────────────────
398// Module cache key + per-process backend.
399// ────────────────────────────────────────────────────────────────────────
400
401/// Module cache key: every distinct `(CC, LMAX, kind, layout, kernel
402/// flavor)` compiles to a different PTX. `precision = f64` and the
403/// (32, 8, 1) raw-kernel block / (128, 1, 1) Householder-kernel block
404/// shapes are baked into the kernel source so they are implicit in the
405/// flavor tag and don't appear here.
406#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
407pub struct S2ModuleCacheKey {
408    pub cc_major: i32,
409    pub cc_minor: i32,
410    pub lmax: u32,
411    pub kind: SphereSpectralKernelKind,
412    pub layout: DeviceMatrixLayout,
413}
414
/// Whether this build contains the Linux + cudarc backend that runs the
/// S² Wahba kernels. The answer is fixed at compile time.
pub const fn sphere_gpu_compiled() -> bool {
    #[cfg(target_os = "linux")]
    {
        true
    }
    #[cfg(not(target_os = "linux"))]
    {
        false
    }
}
420
421/// Decide whether the GPU sphere kernel matrix path is eligible for
422/// `(n, m, lmax)`. Heuristic per the math spec:
423///   * `n * m >= 1_000_000`
424///   * `lmax <= 200`
425///   * device memory budget admits at least one `(ld × m)` design at
426///     `ld = ((n + 31) / 32) * 32`.
427#[must_use]
428pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
429    let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
430        let ld = ((n + 31) / 32) * 32;
431        let needed_bytes = ld
432            .saturating_mul(m)
433            .saturating_mul(std::mem::size_of::<f64>());
434        let budget = runtime.memory_budget_bytes;
435        n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
436    } else {
437        false
438    };
439    decide(
440        GpuKernel::SpatialKernelOperator,
441        gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
442    )
443}
444
445#[cfg(target_os = "linux")]
446struct SphereGpuContext {
447    ctx: Arc<CudaContext>,
448    stream: Arc<CudaStream>,
449    modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
450    cc_major: i32,
451    cc_minor: i32,
452}
453
454/// Process-wide sphere GPU backend. Lazy-initialised on first call to
455/// [`SphereGpuBackend::probe`].
456pub struct SphereGpuBackend {
457    #[cfg(target_os = "linux")]
458    inner: SphereGpuContext,
459}
460
461impl SphereGpuBackend {
462    /// Lazily initialise the process-wide sphere backend.
463    pub fn probe() -> Result<&'static Self, GpuError> {
464        static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
465        BACKEND
466            .get_or_init(|| {
467                #[cfg(target_os = "linux")]
468                {
469                    Self::probe_linux()
470                }
471                #[cfg(not(target_os = "linux"))]
472                {
473                    Err(GpuError::DriverLibraryUnavailable {
474                        reason: "sphere GPU backend is Linux-only".to_string(),
475                    })
476                }
477            })
478            .as_ref()
479            .map_err(GpuError::clone)
480    }
481
482    #[cfg(target_os = "linux")]
483    fn probe_linux() -> Result<Self, GpuError> {
484        let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
485        Ok(SphereGpuBackend {
486            inner: SphereGpuContext {
487                ctx: parts.ctx,
488                stream: parts.stream,
489                modules: Mutex::new(HashMap::new()),
490                cc_major: parts.capability.compute_major,
491                cc_minor: parts.capability.compute_minor,
492            },
493        })
494    }
495
496    /// NVRTC-compile (or fetch from cache) the module for `key`. The
497    /// returned module exposes both raw and Householder-fused kernels.
498    #[cfg(target_os = "linux")]
499    fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
500        if let Ok(guard) = self.inner.modules.lock() {
501            if let Some(existing) = guard.get(&key) {
502                return Ok(existing.clone());
503            }
504        }
505        // CompileOptions in cudarc 0.19 takes `arch: Option<&'static str>`
506        // which we cannot satisfy with a runtime-built string. Prepend the
507        // `LMAX` macro directly to the source so the NVRTC compile is a
508        // pure `compile_ptx`, matching the sibling kernels' invocation
509        // pattern. The kernel itself targets the device the driver
510        // reports (Volta+).
511        let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
512        let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
513            format!(
514                "sphere NVRTC compile (kind={}, lmax={}): {err}",
515                key.kind.tag(),
516                key.lmax
517            )
518        })?;
519        let module = self
520            .inner
521            .ctx
522            .load_module(ptx)
523            .gpu_ctx("sphere module load")?;
524        if let Ok(mut guard) = self.inner.modules.lock() {
525            guard.entry(key).or_insert_with(|| module.clone());
526        }
527        Ok(module)
528    }
529
530    #[cfg(target_os = "linux")]
531    fn cc(&self) -> (i32, i32) {
532        (self.inner.cc_major, self.inner.cc_minor)
533    }
534}
535
536// ────────────────────────────────────────────────────────────────────────
537// Entry points
538// ────────────────────────────────────────────────────────────────────────
539
540/// Build the raw `(n × m)` Wahba kernel matrix on device using
541/// `s2_wahba_legendre_colmajor`. Phase 1 entry point.
542pub fn build_kernel_matrix_device(
543    inputs: S2KernelBuildInputs<'_>,
544) -> Result<DeviceS2KernelMatrix, GpuError> {
545    inputs.validate()?;
546
547    #[cfg(target_os = "linux")]
548    {
549        use cudarc::driver::{LaunchConfig, PushKernelArg};
550        let backend = SphereGpuBackend::probe()?;
551        let (cc_major, cc_minor) = backend.cc();
552        let key = S2ModuleCacheKey {
553            cc_major,
554            cc_minor,
555            lmax: inputs.lmax as u32,
556            kind: inputs.kind,
557            layout: inputs.layout,
558        };
559        let module = backend.module_for(key)?;
560        let func = module
561            .load_function("s2_wahba_legendre_colmajor")
562            .gpu_ctx("sphere load_function raw")?;
563        let stream = backend.inner.stream.clone();
564
565        let data_dev = stream
566            .clone_htod(inputs.data_xyz)
567            .gpu_ctx("sphere htod data_xyz")?;
568        let centers_dev = stream
569            .clone_htod(inputs.centers_xyz)
570            .gpu_ctx("sphere htod centers_xyz")?;
571        let coeffs_dev = stream
572            .clone_htod(inputs.coeffs)
573            .gpu_ctx("sphere htod coeffs")?;
574
575        let n = inputs.n;
576        let m = inputs.m;
577        let ld = ((n + 31) / 32) * 32;
578        let mut out_dev = stream
579            .alloc_zeros::<f64>(ld * m)
580            .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
581
582        // Block (32, 8, 1) — x over centers, y over rows.
583        let block_x: u32 = 32;
584        let block_y: u32 = 8;
585        let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
586        let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
587        let cfg = LaunchConfig {
588            grid_dim: (grid_x, grid_y, 1),
589            block_dim: (block_x, block_y, 1),
590            shared_mem_bytes: 0,
591        };
592        let n_i32: i32 =
593            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
594        let m_i32: i32 =
595            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
596        let ld_i64: i64 = ld as i64;
597
598        let mut builder = stream.launch_builder(&func);
599        builder
600            .arg(&data_dev)
601            .arg(&centers_dev)
602            .arg(&coeffs_dev)
603            .arg(&n_i32)
604            .arg(&m_i32)
605            .arg(&ld_i64)
606            .arg(&mut out_dev);
607        // SAFETY: launch parameters are validated above; all device
608        // pointers come from cudarc-checked allocations on the same
609        // stream; the kernel only reads inputs and writes within
610        // out[0 .. ld*m].
611        unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
612        stream
613            .synchronize()
614            .gpu_ctx("sphere raw kernel synchronize")?;
615
616        Ok(DeviceS2KernelMatrix {
617            rows: n,
618            cols: m,
619            ld,
620            col_major_dev: out_dev,
621            stream,
622        })
623    }
624
625    #[cfg(not(target_os = "linux"))]
626    {
627        Err(GpuError::DriverLibraryUnavailable {
628            reason: "sphere GPU backend is Linux-only".to_string(),
629        })
630    }
631}
632
633/// Phase-3 fused Householder-constrained kernel. `v` is the Householder
634/// vector (length m), `beta` the reflector scalar, and the output is
635/// the `(n × (m-1))` constrained design X_s on device.
636pub fn build_householder_constrained_design_device(
637    inputs: S2KernelBuildInputs<'_>,
638    v: &[f64],
639    beta: f64,
640) -> Result<DeviceS2KernelMatrix, GpuError> {
641    inputs.validate()?;
642    if v.len() != inputs.m {
643        gam_gpu::gpu_bail!(
644            "build_householder_constrained_design_device: v.len()={} != m={}",
645            v.len(),
646            inputs.m
647        );
648    }
649    if inputs.m < 2 {
650        gam_gpu::gpu_bail!(
651            "build_householder_constrained_design_device: m must be >= 2 (got {})",
652            inputs.m
653        );
654    }
655    if !beta.is_finite() {
656        gam_gpu::gpu_bail!(
657            "build_householder_constrained_design_device: beta must be finite (got {beta})"
658        );
659    }
660
661    #[cfg(target_os = "linux")]
662    {
663        use cudarc::driver::{LaunchConfig, PushKernelArg};
664        let backend = SphereGpuBackend::probe()?;
665        let (cc_major, cc_minor) = backend.cc();
666        let key = S2ModuleCacheKey {
667            cc_major,
668            cc_minor,
669            lmax: inputs.lmax as u32,
670            kind: inputs.kind,
671            layout: inputs.layout,
672        };
673        let module = backend.module_for(key)?;
674        let func = module
675            .load_function("s2_wahba_householder_constrained_colmajor")
676            .gpu_ctx("sphere load_function householder")?;
677        let stream = backend.inner.stream.clone();
678
679        let data_dev = stream
680            .clone_htod(inputs.data_xyz)
681            .gpu_ctx("sphere-hh htod data_xyz")?;
682        let centers_dev = stream
683            .clone_htod(inputs.centers_xyz)
684            .gpu_ctx("sphere-hh htod centers_xyz")?;
685        let coeffs_dev = stream
686            .clone_htod(inputs.coeffs)
687            .gpu_ctx("sphere-hh htod coeffs")?;
688        let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
689
690        let n = inputs.n;
691        let m = inputs.m;
692        let cols_out = m - 1;
693        let ld_out = ((n + 31) / 32) * 32;
694        let mut out_dev = stream
695            .alloc_zeros::<f64>(ld_out * cols_out)
696            .gpu_ctx_with(|err| {
697                format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
698            })?;
699
700        let block_x: u32 = 128;
701        let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
702        let cfg = LaunchConfig {
703            grid_dim: (grid_x, 1, 1),
704            block_dim: (block_x, 1, 1),
705            shared_mem_bytes: 0,
706        };
707        let n_i32: i32 =
708            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
709        let m_i32: i32 =
710            i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
711        let ld_out_i64: i64 = ld_out as i64;
712
713        let mut builder = stream.launch_builder(&func);
714        builder
715            .arg(&data_dev)
716            .arg(&centers_dev)
717            .arg(&coeffs_dev)
718            .arg(&v_dev)
719            .arg(&beta)
720            .arg(&n_i32)
721            .arg(&m_i32)
722            .arg(&ld_out_i64)
723            .arg(&mut out_dev);
724        // SAFETY: validated shapes above; the kernel writes exactly
725        // (n × (m-1)) entries within `out[0 .. ld_out * (m-1)]`.
726        unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
727        stream
728            .synchronize()
729            .gpu_ctx("sphere-hh kernel synchronize")?;
730
731        Ok(DeviceS2KernelMatrix {
732            rows: n,
733            cols: cols_out,
734            ld: ld_out,
735            col_major_dev: out_dev,
736            stream,
737        })
738    }
739
740    #[cfg(not(target_os = "linux"))]
741    {
742        Err(GpuError::DriverLibraryUnavailable {
743            reason: "sphere GPU backend is Linux-only".to_string(),
744        })
745    }
746}
747
748// ────────────────────────────────────────────────────────────────────────
749// Householder reflector helpers (host-side; Phase 3 prep).
750//
751// Given a non-zero weight vector w ∈ ℝ^m, construct (v, beta) such that
752// H = I − beta · v · v^T satisfies H · w = ±‖w‖ · e_1 and drops the
753// weighted-sum constraint into the first column.
754// ────────────────────────────────────────────────────────────────────────
755
/// Build the Householder reflector `H = I − beta · v · vᵀ` that maps `w`
/// onto the `e_1` axis.
///
/// Returns `(v, beta)` in the LAPACK / Golub–Van Loan convention `v[0] = 1`.
/// The sign is chosen to avoid cancellation, giving
/// `H · w = −sign(w[0]) · ‖w‖ · e_1`. The remaining columns of `H` span the
/// orthogonal complement of `w`.
///
/// `w` is first rescaled by `max |w_i|`. The reflector is invariant under
/// positive rescaling of `w`, and the rescaling keeps `‖w‖` from
/// underflowing to 0 (tiny nonzero weights) or overflowing to ∞ (huge
/// weights).
///
/// An empty `w` returns `(vec![], 0.0)`. An all-zero `w` returns
/// `(vec![0.0; m], 0.0)`; the caller should treat that as a no-op (no
/// constraint). Non-finite entries propagate as NaN.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let m = w.len();
    if m == 0 {
        return (Vec::new(), 0.0);
    }
    if w.iter().all(|&x| x == 0.0) {
        return (vec![0.0; m], 0.0);
    }

    let scale = w.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    let mut v: Vec<f64> = w.iter().map(|&x| x / scale).collect();

    // After scaling, max |v_i| = 1, so the squares neither underflow to zero
    // nor overflow.
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let sigma = if v[0] >= 0.0 { norm } else { -norm };
    v[0] += sigma;
    let v0 = v[0];
    // |v0| = |w0|/scale + norm >= norm > 0 for finite nonzero w; kept as a
    // defensive guard.
    if v0 == 0.0 {
        return (vec![0.0; m], 0.0);
    }
    // Normalize so v[0] = 1 (LAPACK convention).
    for entry in v.iter_mut() {
        *entry /= v0;
    }
    // beta = 2 / (v · v).
    let vv: f64 = v.iter().map(|x| x * x).sum();
    (v, 2.0 / vv)
}
785
786// ────────────────────────────────────────────────────────────────────────
787// Phase 2 — center-center penalty C + constraint S = Zᵀ C Z.
788//
789// `C` is the (m × m) Wahba kernel of centers against themselves and is
790// computed by reusing the raw GPU kernel with `n = m`. The constraint
791// transform is the same Householder reflector used by the Phase-3 fused
792// kernel: Z = (I − β · v · vᵀ) with the first column dropped, so the
793// constrained penalty is the trailing (m−1)×(m−1) block of HᵀCH.
794//
795// At m ≤ 200 the Householder product is cheap on host and the result is
796// returned as an `ndarray::Array2`. Future calls into cuSOLVER QR can
797// upload it (or its Cholesky factor) once and keep it device-resident.
798// ────────────────────────────────────────────────────────────────────────
799
800/// Build the (m × m) center-center kernel matrix `C` using the same GPU
801/// kernel that builds the design. `centers_xyz` is the unit-vector
802/// representation of the centers, length `3 * m`. `coeffs` and `kind`
803/// match the design build.
804pub fn build_center_kernel_device(
805    centers_xyz: &[f64],
806    lmax: usize,
807    coeffs: &[f64],
808    kind: SphereSpectralKernelKind,
809) -> Result<DeviceS2KernelMatrix, GpuError> {
810    let m = centers_xyz.len() / 3;
811    if centers_xyz.len() != 3 * m {
812        return Err(GpuError::DriverCallFailed {
813            reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
814        });
815    }
816    let inputs = S2KernelBuildInputs {
817        n: m,
818        m,
819        lmax,
820        data_xyz: centers_xyz,
821        centers_xyz,
822        coeffs,
823        kind,
824        layout: DeviceMatrixLayout::ColumnMajor,
825    };
826    build_kernel_matrix_device(inputs)
827}
828
829/// Constrained penalty matrix `S = Zᵀ C Z` for the
830/// weighted-sum-to-zero Householder constraint built from `w`.
831/// Returned shape is `((m−1) × (m−1))`. `C` is taken as a host
832/// (m × m) array (typically the dtoh of `build_center_kernel_device`).
833pub fn constrained_penalty_host(
834    c: ArrayView2<'_, f64>,
835    w: &[f64],
836) -> Result<Array2<f64>, GpuError> {
837    let (m1, m2) = c.dim();
838    if m1 != m2 {
839        gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
840    }
841    let m = m1;
842    if w.len() != m {
843        gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
844    }
845    if m < 2 {
846        gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
847    }
848    let (v, beta) = householder_reflector_from_weights(w);
849
850    // Form HCH = (I - β v vᵀ) C (I - β v vᵀ) = C - β (v · uᵀ + u · vᵀ) + β² (vᵀ C v) v vᵀ,
851    // where u = C v. This is O(m²) — fine for m ≤ 200.
852    let mut u = vec![0.0_f64; m];
853    for i in 0..m {
854        let mut acc = 0.0_f64;
855        for j in 0..m {
856            acc += c[(i, j)] * v[j];
857        }
858        u[i] = acc;
859    }
860    let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
861    let mut hch = Array2::<f64>::zeros((m, m));
862    for i in 0..m {
863        for j in 0..m {
864            hch[(i, j)] =
865                c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
866        }
867    }
868    // Drop the first row and column (the Householder-constrained nullspace).
869    let mut s = Array2::<f64>::zeros((m - 1, m - 1));
870    for i in 0..(m - 1) {
871        for j in 0..(m - 1) {
872            s[(i, j)] = hch[(i + 1, j + 1)];
873        }
874    }
875    Ok(s)
876}
877
878// ────────────────────────────────────────────────────────────────────────
879// Phase 4 — device-resident cuSOLVER QR penalised solve.
880//
881// Solve  min_β  ‖ [√W · X_s] β − [√W · y] ‖² + λ ‖R_S · β‖²
882//
883// by stacking the augmented matrix
884//
885//     A_aug = [ √W · X_s ;   √λ · R_S ]    shape (n + p) × p,
886//     b_aug = [ √W · y    ;   0       ]    length n + p,
887//
888// where p = m − 1, R_S is the upper-triangular Cholesky factor of the
889// constrained penalty S = Zᵀ C Z, and (√W·X_s) is the design built by
890// the fused Householder kernel scaled by sqrt-weights row-by-row on
891// device. The pipeline is:
892//
893//     1. cusolverDnDgeqrf_bufferSize → workspace size.
894//     2. cusolverDnDgeqrf(A_aug)     → A := [R upper-tri / V Householder]
895//                                        plus tau vector.
896//     3. cusolverDnDormqr(side=L, trans=T)
897//                                  → applies Qᵀ to b_aug.
898//     4. cublasDtrsm(L = upper) → β := R⁻¹ · (Qᵀ b_aug)[0..p].
899//
900// Coefficients (β) come back to host; log|H| can be returned via Σ
901// log(R_ii²) from the diagonal of the in-place factored R.
902//
903// All intermediate state — A_aug, b_aug, tau, workspace, info — stays
904// device-resident. The host learns only (β, log|H|, residual ssq).
905// ────────────────────────────────────────────────────────────────────────
906
907/// Result returned by [`solve_penalised_ls_device`].
908#[derive(Clone, Debug)]
909pub struct PenalisedLsSolution {
910    /// Coefficient vector, length `p = m − 1` (after Householder drop).
911    pub beta: Vec<f64>,
912    /// Sum of squared residuals on the unaugmented rows: ‖√W (Xβ − y)‖².
913    pub weighted_residual_ssq: f64,
914    /// log|H| = 2 · Σ log |R_ii| of the QR-factored augmented design.
915    pub log_det_hessian: f64,
916}
917
918/// Augmented penalised least-squares solve via on-device cuSOLVER QR.
919///
920/// Inputs:
921///   * `x_s_device` — already-constrained, weighted-sqrt-scaled design
922///     `√W · X_s` produced by the Phase-3 fused kernel + a row-scaling
923///     kernel. Shape `(n × p)` column-major.
924///   * `wy` — `√W · y` (length n), already host-multiplied (cheap).
925///   * `r_s` — upper-triangular Cholesky factor of `√λ · S`, shape
926///     `(p × p)` row-major host array.
927#[cfg(target_os = "linux")]
928pub fn solve_penalised_ls_device(
929    x_s_device: &DeviceS2KernelMatrix,
930    wy: &[f64],
931    r_s: ArrayView2<'_, f64>,
932) -> Result<PenalisedLsSolution, GpuError> {
933    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
934    use cudarc::driver::DevicePtrMut;
935
936    let n = x_s_device.rows;
937    let p = x_s_device.cols;
938    if wy.len() != n {
939        gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
940    }
941    if r_s.dim() != (p, p) {
942        gam_gpu::gpu_bail!(
943            "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
944            r_s.dim()
945        );
946    }
947    if p == 0 {
948        return Ok(PenalisedLsSolution {
949            beta: Vec::new(),
950            weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
951            log_det_hessian: 0.0,
952        });
953    }
954
955    let stream = x_s_device.stream.clone();
956    let n_aug = n + p;
957
958    // 1) Materialise A_aug column-major on device. We don't need the
959    //    upstream X_s after QR, but the kernel matrix builder hands us
960    //    its own storage; we copy into a fresh (n_aug × p) slab so the
961    //    in-place geqrf doesn't clobber a buffer the caller still owns.
962    let mut a_aug_host = vec![0.0_f64; n_aug * p];
963    // Copy device-side X_s back column-by-column into the upper block.
964    let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
965    x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
966    for j in 0..p {
967        let src_off = j * x_s_device.ld;
968        let dst_off = j * n_aug;
969        a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
970        for i in 0..p {
971            // R_S is row-major host; insert into column j of the lower
972            // block (rows n..n+p) as r_s[i, j].
973            a_aug_host[dst_off + n + i] = r_s[(i, j)];
974        }
975    }
976    let mut a_dev = stream
977        .clone_htod(&a_aug_host)
978        .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
979
980    // b_aug = [√W·y ; 0]
981    let mut b_host = vec![0.0_f64; n_aug];
982    b_host[..n].copy_from_slice(wy);
983    let mut b_dev = stream
984        .clone_htod(&b_host)
985        .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
986
987    let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
988    let n_aug_i: i32 = i32::try_from(n_aug)
989        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
990    let p_i: i32 = i32::try_from(p)
991        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
992
993    // 2) Workspace size for geqrf.
994    let mut lwork: i32 = 0;
995    {
996        let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
997        // SAFETY: a_dev holds n_aug*p f64 elements column-major;
998        // pointer is live on `stream`; lwork is a valid host out-param.
999        let status = unsafe {
1000            cusolver_sys::cusolverDnDgeqrf_bufferSize(
1001                solver.cu(),
1002                n_aug_i,
1003                p_i,
1004                a_ptr as *mut f64,
1005                n_aug_i,
1006                &mut lwork,
1007            )
1008        };
1009        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1010            gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1011        }
1012    }
1013    let lwork_us = usize::try_from(lwork)
1014        .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1015    let mut workspace = stream
1016        .alloc_zeros::<f64>(lwork_us.max(1))
1017        .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1018    let mut tau = stream
1019        .alloc_zeros::<f64>(p)
1020        .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1021    let mut info = stream
1022        .alloc_zeros::<i32>(1)
1023        .gpu_ctx("solve_penalised_ls_device alloc info")?;
1024
1025    // 3) cusolverDnDgeqrf — A := QR in place.
1026    {
1027        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1028        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1029        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1030        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1031        // SAFETY: all pointers reference live device allocations on
1032        // this stream; lwork matches the bufferSize query above.
1033        let status = unsafe {
1034            cusolver_sys::cusolverDnDgeqrf(
1035                solver.cu(),
1036                n_aug_i,
1037                p_i,
1038                a_ptr as *mut f64,
1039                n_aug_i,
1040                tau_ptr as *mut f64,
1041                work_ptr as *mut f64,
1042                lwork,
1043                info_ptr as *mut i32,
1044            )
1045        };
1046        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1047            gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1048        }
1049    }
1050
1051    // 4) cusolverDnDormqr — b_aug := Qᵀ · b_aug.
1052    let mut ormqr_lwork: i32 = 0;
1053    {
1054        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1055        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1056        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1057        // SAFETY: A/tau/b are live device buffers on this stream;
1058        // ormqr_lwork is a host out-param.
1059        let status = unsafe {
1060            cusolver_sys::cusolverDnDormqr_bufferSize(
1061                solver.cu(),
1062                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1063                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1064                n_aug_i,
1065                1,
1066                p_i,
1067                a_ptr as *const f64,
1068                n_aug_i,
1069                tau_ptr as *const f64,
1070                b_ptr as *mut f64,
1071                n_aug_i,
1072                &mut ormqr_lwork,
1073            )
1074        };
1075        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1076            gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1077        }
1078    }
1079    if ormqr_lwork > lwork {
1080        workspace = stream
1081            .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1082            .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1083    }
1084    {
1085        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1086        let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1087        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1088        let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1089        let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1090        // SAFETY: all pointers reference live, mutually-non-aliasing
1091        // device buffers on this stream; lwork matches the bufferSize
1092        // query above; A and tau are the geqrf output.
1093        let status = unsafe {
1094            cusolver_sys::cusolverDnDormqr(
1095                solver.cu(),
1096                cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1097                cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1098                n_aug_i,
1099                1,
1100                p_i,
1101                a_ptr as *const f64,
1102                n_aug_i,
1103                tau_ptr as *const f64,
1104                b_ptr as *mut f64,
1105                n_aug_i,
1106                work_ptr as *mut f64,
1107                ormqr_lwork.max(lwork),
1108                info_ptr as *mut i32,
1109            )
1110        };
1111        if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1112            gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1113        }
1114    }
1115
1116    // 5) cublasDtrsm — solve R · β = (Qᵀ b)[0..p] in place on the top
1117    //    of b_dev. We use a single-RHS upper-triangular non-unit solve.
1118    {
1119        use cudarc::cublas::CudaBlas;
1120        let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1121        let alpha = 1.0_f64;
1122        let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1123        let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1124        // SAFETY: A is the geqrf-output upper-triangular factor R in
1125        // its top-p × p block (col-major, ld = n_aug); b is the
1126        // ormqr-output Qᵀb in the top p slots (ld = n_aug as well so
1127        // pretend it is column-major with 1 column of leading dim n_aug).
1128        let handle = *blas.handle();
1129        let status = unsafe {
1130            cudarc::cublas::sys::cublasDtrsm_v2(
1131                handle,
1132                cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1133                cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1134                cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1135                cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1136                p_i,
1137                1,
1138                &alpha,
1139                a_ptr as *const f64,
1140                n_aug_i,
1141                b_ptr as *mut f64,
1142                n_aug_i,
1143            )
1144        };
1145        if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1146            gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1147        }
1148    }
1149
1150    // 6) Copy results back to host.
1151    let mut b_out = vec![0.0_f64; n_aug];
1152    stream
1153        .memcpy_dtoh(&b_dev, &mut b_out)
1154        .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1155    let mut a_back = vec![0.0_f64; n_aug * p];
1156    stream
1157        .memcpy_dtoh(&a_dev, &mut a_back)
1158        .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1159    stream
1160        .synchronize()
1161        .gpu_ctx("solve_penalised_ls_device synchronize")?;
1162
1163    let beta: Vec<f64> = b_out[..p].to_vec();
1164    // (Qᵀb)[p..n_aug] holds the residual in the rotated coordinates;
1165    // ‖(Qᵀb)[p..]‖² = ‖√W (Xβ − y)‖² + λ ‖R_S β‖² for the augmented
1166    // system. To recover ‖√W (Xβ − y)‖² alone, subtract the penalty
1167    // residual ‖R_S β‖² (penalty rotates to itself in the augmented
1168    // bottom block, but only when the bottom block ROWS map exactly
1169    // into the rotated residual — which is not guaranteed, so the
1170    // simpler accurate path is to return the **augmented** residual
1171    // squared and let the caller subtract.)
1172    let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1173
1174    // log|R| diagonal.
1175    let mut log_abs_r = 0.0_f64;
1176    for k in 0..p {
1177        let r_kk = a_back[k * n_aug + k];
1178        log_abs_r += r_kk.abs().ln();
1179    }
1180    let log_det_hessian = 2.0 * log_abs_r;
1181
1182    Ok(PenalisedLsSolution {
1183        beta,
1184        weighted_residual_ssq: augmented_residual_ssq,
1185        log_det_hessian,
1186    })
1187}
1188
1189#[cfg(not(target_os = "linux"))]
1190pub fn solve_penalised_ls_device(
1191    x_s_device: &DeviceS2KernelMatrix,
1192    wy: &[f64],
1193    r_s: ArrayView2<'_, f64>,
1194) -> Result<PenalisedLsSolution, GpuError> {
1195    Err(GpuError::DriverLibraryUnavailable {
1196        reason: format!(
1197            "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
1198            x_s_device.rows,
1199            x_s_device.cols,
1200            wy.len(),
1201            r_s.dim()
1202        ),
1203    })
1204}
1205
1206// ────────────────────────────────────────────────────────────────────────
1207// Tests
1208// ────────────────────────────────────────────────────────────────────────
1209
1210#[cfg(test)]
1211mod sphere_gpu_tests {
1212    use super::*;
1213    use crate::basis::{
1214        SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1215        spherical_wahba_kernel_matrix_with_kind,
1216    };
1217    use ndarray::Array2;
1218
1219    fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1220        // Latitude in (-85, 85), longitude in [-180, 180), degrees.
1221        let mut rows = Vec::with_capacity(n_lat * n_lon);
1222        for i in 0..n_lat {
1223            let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1224            for j in 0..n_lon {
1225                let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1226                rows.push(lat);
1227                rows.push(lon);
1228            }
1229        }
1230        Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1231    }
1232
1233    #[test]
1234    fn xyz_preprocessing_matches_unit_sphere() {
1235        let latlon = ndarray::array![
1236            [0.0, 0.0],
1237            [90.0, 0.0],
1238            [0.0, 90.0],
1239            [-90.0, 17.5],
1240            [45.0, -120.0],
1241        ];
1242        let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1243        assert_eq!(xyz.len(), 3 * 5);
1244        for i in 0..5 {
1245            let nrm2 = xyz[3 * i] * xyz[3 * i]
1246                + xyz[3 * i + 1] * xyz[3 * i + 1]
1247                + xyz[3 * i + 2] * xyz[3 * i + 2];
1248            assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1249        }
1250        // Row 0 = equator @ lon=0 → (1, 0, 0).
1251        assert!((xyz[0] - 1.0).abs() < 1e-15);
1252        assert!(xyz[1].abs() < 1e-15);
1253        assert!(xyz[2].abs() < 1e-15);
1254        // Row 1 = north pole (lat=90, lon=0) → (0, 0, 1).
1255        assert!(xyz[3].abs() < 1e-15);
1256        assert!(xyz[4].abs() < 1e-15);
1257        assert!((xyz[5] - 1.0).abs() < 1e-15);
1258        // Row 2 = equator @ lon=90 → (0, 1, 0).
1259        assert!(xyz[6].abs() < 1e-15);
1260        assert!((xyz[7] - 1.0).abs() < 1e-15);
1261        assert!(xyz[8].abs() < 1e-15);
1262    }
1263
1264    #[test]
1265    fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1266        // P_ℓ(1) = 1 for all ℓ, so K(x, x) = Σ_{ℓ=0..L} c_ℓ. The Legendre
1267        // recurrence in `sphere_truncated_spectral_eval` must reproduce
1268        // this exact identity to roundoff.
1269        for m_penalty in 1..=4 {
1270            for &lmax in &[5_usize, 20, 50] {
1271                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1272                let expected: f64 = coeffs.iter().sum();
1273                let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1274                assert!(
1275                    (got - expected).abs() < 1e-13,
1276                    "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1277                );
1278            }
1279        }
1280    }
1281
1282    #[test]
1283    fn truncated_spectral_at_antipode_matches_alternating_sum() {
1284        // P_ℓ(-1) = (-1)^ℓ, so K(x, -x) = Σ_{ℓ=0..L} c_ℓ · (-1)^ℓ. Same
1285        // exact identity for the recurrence at t = -1.
1286        for m_penalty in 1..=4 {
1287            for &lmax in &[5_usize, 20, 50] {
1288                let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1289                let expected: f64 = coeffs
1290                    .iter()
1291                    .enumerate()
1292                    .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1293                    .sum();
1294                let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1295                assert!(
1296                    (got - expected).abs() < 1e-13,
1297                    "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1298                );
1299            }
1300        }
1301    }
1302
1303    #[test]
1304    fn truncated_spectral_matrix_is_symmetric() {
1305        // K(γ) depends only on cos γ = x · y = y · x, so the Gram
1306        // matrix B B^T-style kernel evaluation on the same point set
1307        // must be symmetric to roundoff.
1308        let centers = ndarray::array![
1309            [10.0_f64, 20.0],
1310            [-30.0, 100.0],
1311            [45.0, -60.0],
1312            [-89.0, 0.0],
1313            [0.0, 180.0],
1314            [60.0, -179.9],
1315        ];
1316        for m_penalty in [1usize, 2, 4] {
1317            for &lmax in &[10_usize, 30] {
1318                let mat = spherical_wahba_kernel_matrix_with_kind(
1319                    centers.view(),
1320                    centers.view(),
1321                    m_penalty,
1322                    false,
1323                    SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1324                )
1325                .expect("kernel matrix");
1326                let n = centers.nrows();
1327                let mut max_asym = 0.0_f64;
1328                for i in 0..n {
1329                    for j in 0..n {
1330                        let d = (mat[(i, j)] - mat[(j, i)]).abs();
1331                        if d > max_asym {
1332                            max_asym = d;
1333                        }
1334                    }
1335                }
1336                assert!(
1337                    max_asym < 1e-13,
1338                    "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1339                );
1340            }
1341        }
1342    }
1343
1344    #[test]
1345    fn truncated_coefficients_have_zero_constant_mode() {
1346        for m in 1..=4 {
1347            let c = sobolev_s2_truncated_coefficients(50, m);
1348            assert_eq!(c.len(), 51);
1349            assert_eq!(c[0], 0.0);
1350            assert!(c[1] > 0.0);
1351            // Spectral decay c_ℓ ~ 1/ℓ^{2m-1}: monotone for ℓ ≥ 1.
1352            for ell in 2..=50 {
1353                assert!(
1354                    c[ell] < c[ell - 1] + 1e-15,
1355                    "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1356                    c[ell],
1357                    c[ell - 1]
1358                );
1359            }
1360        }
1361    }
1362
1363    #[test]
1364    fn truncated_spectral_matches_matrix_helper() {
1365        // The Wahba kernel matrix helper, invoked with the truncated
1366        // variant, must produce the same value as the bare scalar
1367        // evaluator.
1368        let m_penalty = 2;
1369        let lmax = 20;
1370        let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1371        let data = ndarray::array![[12.5, -34.0]];
1372        let centers = ndarray::array![[40.0, 10.0]];
1373        let mat = spherical_wahba_kernel_matrix_with_kind(
1374            data.view(),
1375            centers.view(),
1376            m_penalty,
1377            false,
1378            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1379        )
1380        .expect("kernel matrix");
1381        // Recompute cos γ on the unit sphere.
1382        let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1383        let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1384        let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1385        let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1386        assert!(
1387            (mat[(0, 0)] - expected).abs() < 1e-13,
1388            "matrix helper differs from scalar evaluator: {} vs {}",
1389            mat[(0, 0)],
1390            expected
1391        );
1392    }
1393
1394    #[test]
1395    fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1396        // Build a small symmetric PD matrix as a stand-in for C, then
1397        // verify that constrained_penalty_host returns a symmetric
1398        // (m-1)×(m-1) matrix whose action against Z·x matches the
1399        // expected Zᵀ C Z mapping.
1400        let m = 6;
1401        let mut c = Array2::<f64>::zeros((m, m));
1402        for i in 0..m {
1403            for j in 0..m {
1404                let d = (i as f64 - j as f64).abs();
1405                c[(i, j)] = (-0.5 * d).exp();
1406            }
1407        }
1408        let w = vec![1.0_f64; m];
1409        let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1410        assert_eq!(s.dim(), (m - 1, m - 1));
1411        // Symmetry within roundoff.
1412        let mut max_asym = 0.0_f64;
1413        for i in 0..(m - 1) {
1414            for j in 0..(m - 1) {
1415                let d = (s[(i, j)] - s[(j, i)]).abs();
1416                if d > max_asym {
1417                    max_asym = d;
1418                }
1419            }
1420        }
1421        assert!(
1422            max_asym < 1e-13,
1423            "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1424        );
1425
1426        // The kernel-of-Zᵀ direction: Zᵀ · w = 0 ⇒ x = (something) such
1427        // that Z · x stays in span(w)^⊥, so x can be any (m-1) vector;
1428        // we just verify that picking the all-ones constraint direction
1429        // collapses to zero through Z when applied to constant fields.
1430        // i.e. constant-field penalty norm must be zero in the
1431        // un-constrained Cv direction, and the trailing block here is
1432        // never used against the constraint.
1433        let ones = ndarray::Array1::<f64>::ones(m - 1);
1434        let sx = s.dot(&ones);
1435        assert!(sx.iter().all(|v| v.is_finite()));
1436    }
1437
1438    #[test]
1439    fn householder_reflector_zeroes_target_vector() {
1440        let w = vec![3.0, 4.0, 0.0, -1.0];
1441        let (v, beta) = householder_reflector_from_weights(&w);
1442        // Apply H = I - beta * v * v^T to w; the result should be a
1443        // multiple of e_1 (only first entry non-zero).
1444        let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1445        let hw: Vec<f64> = w
1446            .iter()
1447            .zip(&v)
1448            .map(|(wj, vj)| wj - beta * dot * vj)
1449            .collect();
1450        for entry in hw.iter().skip(1) {
1451            assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1452        }
1453        assert!(hw[0].abs() > 0.0);
1454    }
1455
1456    /// V100-only: probe + raw kernel parity vs CPU truncated-spectral on
1457    /// a small grid. Skips cleanly on hosts with no CUDA runtime.
1458    #[test]
1459    fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1460        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1461            eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1462            return;
1463        };
1464        // Past the runtime Some-gate: a probe failure is a real device fault on a
1465        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1466        SphereGpuBackend::probe()
1467            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1468
1469        let data_ll = small_latlon_grid(7, 9);
1470        let centers_ll = small_latlon_grid(5, 7);
1471        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1472        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1473        let n = data_ll.nrows();
1474        let m = centers_ll.nrows();
1475        let penalty = 2usize;
1476        let lmax = 20usize;
1477        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1478
1479        let inputs = S2KernelBuildInputs {
1480            n,
1481            m,
1482            lmax,
1483            data_xyz: &data_xyz,
1484            centers_xyz: &centers_xyz,
1485            coeffs: &coeffs,
1486            kind: SphereSpectralKernelKind::Sobolev,
1487            layout: DeviceMatrixLayout::ColumnMajor,
1488        };
1489        let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1490        let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1491
1492        let cpu = spherical_wahba_kernel_matrix_with_kind(
1493            data_ll.view(),
1494            centers_ll.view(),
1495            penalty,
1496            false,
1497            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1498        )
1499        .expect("cpu kernel matrix");
1500
1501        let mut max_abs = 0.0_f64;
1502        for i in 0..n {
1503            for j in 0..m {
1504                let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1505                if d > max_abs {
1506                    max_abs = d;
1507                }
1508            }
1509        }
1510        assert!(
1511            max_abs < 1e-11,
1512            "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1513        );
1514    }
1515
1516    /// V100-only end-to-end dispatch parity: drive
1517    /// `build_spherical_spline_basis` with a `SobolevTruncated` spec on
1518    /// a workload large enough to trigger `sphere_kernel_decision().use_gpu`,
1519    /// then re-build with the CPU-only `Sobolev` (deep-spectral) kernel
1520    /// is **not** what we compare to — the GPU exactly matches the
1521    /// `SobolevTruncated` CPU path, so the comparison is
1522    /// truncated-on-GPU vs truncated-on-CPU. Down-stream PIRLS/REML
1523    /// consumes the design verbatim so element-wise design parity at
1524    /// ≤ 1e-9 implies fit parity at the same tolerance.
1525    #[test]
1526    fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1527        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1528            eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1529            return;
1530        };
1531        // Past the runtime Some-gate: a backend probe failure is a real device
1532        // fault on a CUDA host, not a no-CUDA skip — fail loud (device-PCG
1533        // skip-pass class, eee12f6b2) instead of masking it as a pass.
1534        SphereGpuBackend::probe()
1535            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1536        use crate::basis::{
1537            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1538            build_spherical_spline_basis, sobolev_s2_truncated_coefficients,
1539        };
1540        drop(sobolev_s2_truncated_coefficients(1, 1));
1541
1542        // (n=10_000, m=200) → n·m = 2_000_000 ≥ 1_000_000 → GPU eligible.
1543        let data = small_latlon_grid(100, 100);
1544        let lmax: u16 = 30;
1545        let penalty_order = 2usize;
1546        let spec_gpu = SphericalSplineBasisSpec {
1547            center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
1548            penalty_order,
1549            double_penalty: false,
1550            radians: false,
1551            method: SphereMethod::Wahba,
1552            max_degree: None,
1553            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1554            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1555        };
1556        let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
1557            .expect("GPU-eligible build_spherical_spline_basis succeeds");
1558
1559        // Re-run with the same spec but bypass the GPU by shrinking
1560        // `n·m` below the 1e6 gate is not possible without changing data
1561        // shape, so instead we materialise the CPU-truncated reference
1562        // design by calling the public CPU helper directly with the same
1563        // centers that the GPU build chose. The centers are deterministic
1564        // (farthest-point with the same seed = leftmost-lowest lat/lon),
1565        // so we can rebuild them.
1566        let centers =
1567            crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1568                .expect("centers");
1569        let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
1570            data.view(),
1571            centers.view(),
1572            penalty_order,
1573            false,
1574            SphereWahbaKernel::SobolevTruncated { lmax },
1575        )
1576        .expect("cpu raw design");
1577
1578        // The build keeps raw Wahba center coefficients unless a frozen
1579        // realized-design transform is supplied.
1580        let z = Array2::<f64>::eye(centers.nrows());
1581        let cpu_design = raw_cpu.dot(&z);
1582
1583        let gpu_design = result_gpu.design.as_dense().expect("dense design").clone();
1584
1585        assert_eq!(gpu_design.dim(), cpu_design.dim());
1586        let mut max_abs = 0.0_f64;
1587        let mut max_rel = 0.0_f64;
1588        for ((g, c), _) in gpu_design.iter().zip(cpu_design.iter()).zip(0..) {
1589            let d = (g - c).abs();
1590            if d > max_abs {
1591                max_abs = d;
1592            }
1593            let denom = g.abs().max(c.abs()).max(1e-300);
1594            let r = d / denom;
1595            if r > max_rel {
1596                max_rel = r;
1597            }
1598        }
1599        assert!(
1600            max_rel < 1e-9,
1601            "end-to-end design parity max relative |Δ| = {max_rel:.3e} >= 1e-9 (abs {max_abs:.3e})"
1602        );
1603    }
1604
1605    /// V100-only: parity of Householder-constrained kernel against
1606    /// (raw kernel) · Z evaluated on host.
1607    #[test]
1608    fn sphere_gpu_householder_parity_vs_raw_dot_z() {
1609        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1610            eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
1611            return;
1612        };
1613        // Past the runtime Some-gate: a probe failure is a real device fault on a
1614        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1615        SphereGpuBackend::probe()
1616            .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1617        let data_ll = small_latlon_grid(6, 8);
1618        let centers_ll = small_latlon_grid(4, 5);
1619        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1620        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1621        let n = data_ll.nrows();
1622        let m = centers_ll.nrows();
1623        let penalty = 2usize;
1624        let lmax = 15usize;
1625        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1626
1627        // Build raw B on device, then form (n × m-1) X_s = B · Z on host.
1628        let inputs_raw = S2KernelBuildInputs {
1629            n,
1630            m,
1631            lmax,
1632            data_xyz: &data_xyz,
1633            centers_xyz: &centers_xyz,
1634            coeffs: &coeffs,
1635            kind: SphereSpectralKernelKind::Sobolev,
1636            layout: DeviceMatrixLayout::ColumnMajor,
1637        };
1638        let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
1639        let b = b_dev.to_host_array().expect("dtoh raw");
1640
1641        // Construct a Householder reflector from a uniform weight vector
1642        // (the "weighted sum-to-zero" constraint when weights are all 1).
1643        let w = vec![1.0_f64; m];
1644        let (v, beta) = householder_reflector_from_weights(&w);
1645
1646        // Apply on host: X_s_host[i, j_out] = B[i, j_out+1] - beta * (B[i,:] · v) * v[j_out+1]
1647        let mut xs_host = Array2::<f64>::zeros((n, m - 1));
1648        for i in 0..n {
1649            let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
1650            for j_out in 0..(m - 1) {
1651                xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
1652            }
1653        }
1654
1655        let xs_dev =
1656            build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
1657        let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
1658
1659        let mut max_abs = 0.0_f64;
1660        for i in 0..n {
1661            for j in 0..(m - 1) {
1662                let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
1663                if d > max_abs {
1664                    max_abs = d;
1665                }
1666            }
1667        }
1668        assert!(
1669            max_abs < 1e-12,
1670            "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
1671        );
1672    }
1673
1674    /// V100 hill-climb: GPU truncated-spectral kernel matrix build at
1675    /// (n=200_000, m=200, L=50) must beat CPU by ≥ 20× wall-clock.
1676    /// Skips silently when no CUDA runtime is available.
1677    #[test]
1678    fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
1679        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1680            eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
1681            return;
1682        };
1683        if SphereGpuBackend::probe().is_err() {
1684            eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
1685            return;
1686        }
1687
1688        // (n=200_000, m=200, lmax=50). n·m = 4·10^7 ≫ 1e6 → GPU eligible.
1689        // Build a 200_000-row deterministic lat/lon grid.
1690        let n_lat = 500usize;
1691        let n_lon = 400usize;
1692        assert_eq!(n_lat * n_lon, 200_000);
1693        let data_ll = small_latlon_grid(n_lat, n_lon);
1694        let m = 200usize;
1695        let centers_ll =
1696            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1697                .expect("centers");
1698        let n = data_ll.nrows();
1699        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1700        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1701        let penalty_order = 2usize;
1702        let lmax = 50usize;
1703        let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
1704
1705        // Warm up GPU (NVRTC compile + first-touch alloc).
1706        let inputs_warm = S2KernelBuildInputs {
1707            n,
1708            m,
1709            lmax,
1710            data_xyz: &data_xyz,
1711            centers_xyz: &centers_xyz,
1712            coeffs: &coeffs,
1713            kind: SphereSpectralKernelKind::Sobolev,
1714            layout: DeviceMatrixLayout::ColumnMajor,
1715        };
1716        drop(build_kernel_matrix_device(inputs_warm.clone()).expect("warmup"));
1717
1718        // Measure GPU.
1719        let t0 = std::time::Instant::now();
1720        let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
1721        let _host_gpu = dev.to_host_array().expect("dtoh");
1722        let gpu_secs = t0.elapsed().as_secs_f64();
1723
1724        // Measure CPU (truncated-spectral via the public matrix helper).
1725        let t1 = std::time::Instant::now();
1726        let _cpu = spherical_wahba_kernel_matrix_with_kind(
1727            data_ll.view(),
1728            centers_ll.view(),
1729            penalty_order,
1730            false,
1731            SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1732        )
1733        .expect("cpu kernel matrix");
1734        let cpu_secs = t1.elapsed().as_secs_f64();
1735
1736        let ratio = cpu_secs / gpu_secs.max(1e-9);
1737        eprintln!(
1738            "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
1739        );
1740        assert!(
1741            ratio >= 20.0,
1742            "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
1743             n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1744        );
1745    }
1746
1747    /// V100 hill-climb: end-to-end Gaussian fit through
1748    /// `build_spherical_spline_basis` (GPU-dispatched) must beat the
1749    /// CPU-only fit by ≥ 10× wall-clock at a workload where the GPU
1750    /// kernel build dominates PIRLS.
1751    #[test]
1752    fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
1753        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1754            eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
1755            return;
1756        };
1757        if SphereGpuBackend::probe().is_err() {
1758            eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
1759            return;
1760        }
1761        use crate::basis::{
1762            CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1763            build_spherical_spline_basis,
1764        };
1765
1766        let n_lat = 500usize;
1767        let n_lon = 400usize;
1768        let data_ll = small_latlon_grid(n_lat, n_lon);
1769        let m: usize = 200;
1770        let lmax: u16 = 50;
1771        let spec_gpu = SphericalSplineBasisSpec {
1772            center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
1773            penalty_order: 2,
1774            double_penalty: false,
1775            radians: false,
1776            method: SphereMethod::Wahba,
1777            max_degree: None,
1778            wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1779            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1780        };
1781
1782        // Warm-up GPU build.
1783        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
1784
1785        let t0 = std::time::Instant::now();
1786        drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
1787        let gpu_secs = t0.elapsed().as_secs_f64();
1788
1789        // CPU comparison: directly invoke the CPU helper and apply the
1790        // same constraint transform (matches what build_*_basis would do
1791        // when GPU dispatch declines). Going through the public matrix
1792        // helper isolates the GPU-vs-CPU kernel cost without re-doing
1793        // farthest-point center selection (which is identical for both
1794        // paths).
1795        let centers =
1796            crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1797                .expect("centers");
1798        let z = Array2::<f64>::eye(centers.nrows());
1799        let t1 = std::time::Instant::now();
1800        let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
1801            data_ll.view(),
1802            centers.view(),
1803            2,
1804            false,
1805            SphereWahbaKernel::SobolevTruncated { lmax },
1806        )
1807        .expect("cpu raw");
1808        let _design_cpu = raw_cpu.dot(&z);
1809        let cpu_secs = t1.elapsed().as_secs_f64();
1810
1811        let ratio = cpu_secs / gpu_secs.max(1e-9);
1812        eprintln!(
1813            "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
1814            data_ll.nrows()
1815        );
1816        assert!(
1817            ratio >= 10.0,
1818            "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
1819             cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1820        );
1821    }
1822
1823    /// Task #25: end-to-end fit parity between the GPU truncated-spectral
1824    /// path and the CPU truncated-spectral path on a small synthetic
1825    /// intrinsic-S² fixture.
1826    ///
1827    /// Setup: deterministic lat/lon grid (n = 1000 = 25 × 40), 80 centers
1828    /// chosen by farthest-point selection, lmax = 15, penalty order 2,
1829    /// Wahba weighted-sum-to-zero constraint applied via `Z`. We fit a
1830    /// fixed-λ penalised LS problem
1831    ///   β = argmin ‖X_s β − y‖² + λ · βᵀ S β
1832    /// where `X_s = K(data, centers) · Z` and `S = Zᵀ · K(centers, centers) · Z`,
1833    /// solving `(X_sᵀ X_s + λ S) β = X_sᵀ y` via faer LLT for both paths.
1834    /// The only path-dependent quantity is `K(data, centers)`: built on
1835    /// GPU via `build_kernel_matrix_device` for one β, and on CPU via
1836    /// `spherical_wahba_kernel_matrix_with_kind` for the other. The
1837    /// penalty kernel `K(centers, centers)` is m × m and tiny, so we
1838    /// build it once on CPU and share it across paths (it is not the
1839    /// surface under test).
1840    ///
1841    /// Asserts max-absolute coefficient delta ≤ 1e-9 and max-absolute
1842    /// fitted-value delta ≤ 1e-9. `#[ignore = "requires CUDA"]` so the
1843    /// V100 bench runner unignores in their harness.
1844    #[test]
1845    fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
1846        use crate::basis::{
1847            select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
1848        };
1849        use faer::Side;
1850        use gam_linalg::faer_ndarray::FaerCholesky;
1851
1852        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1853            eprintln!(
1854                "[sphere gpu parity] no CUDA runtime — skipping device parity \
1855                 (CPU oracle exercised by sibling tests)"
1856            );
1857            return;
1858        };
1859        // Past the runtime Some-gate: a probe failure is a real device fault on a
1860        // CUDA host — fail loud (device-PCG skip-pass class, eee12f6b2).
1861        SphereGpuBackend::probe()
1862            .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
1863
1864        // Fixture: 25 × 40 lat/lon grid → n = 1000.
1865        let data_ll = small_latlon_grid(25, 40);
1866        assert_eq!(data_ll.nrows(), 1000);
1867        let n = data_ll.nrows();
1868        let m: usize = 80;
1869        let lmax_u16: u16 = 15;
1870        let lmax: usize = lmax_u16 as usize;
1871        let penalty_order: usize = 2;
1872        let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
1873        let lambda: f64 = 1.0e-3;
1874
1875        // Deterministic centers via farthest-point selection.
1876        let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
1877            .expect("farthest-point centers");
1878        assert_eq!(centers_ll.nrows(), m);
1879
1880        // The Wahba sphere basis no longer imposes a finite-center coefficient
1881        // gauge; parity compares the raw center coefficient chart.
1882        let z = Array2::<f64>::eye(centers_ll.nrows());
1883        let p = z.ncols();
1884        assert_eq!(p, m);
1885
1886        // Penalty K(centers, centers), built once on CPU. The penalty
1887        // kernel evaluation is m × m (= 6400 entries), well outside the
1888        // GPU dispatch threshold, and identical for both paths under
1889        // test by construction.
1890        let k_cc = spherical_wahba_kernel_matrix_with_kind(
1891            centers_ll.view(),
1892            centers_ll.view(),
1893            penalty_order,
1894            false,
1895            kernel,
1896        )
1897        .expect("centers×centers kernel");
1898        let s_full = z.t().dot(&k_cc).dot(&z);
1899
1900        // CPU path: K(data, centers) via the public CPU helper.
1901        let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
1902            data_ll.view(),
1903            centers_ll.view(),
1904            penalty_order,
1905            false,
1906            kernel,
1907        )
1908        .expect("CPU raw design");
1909        let x_s_cpu = raw_design_cpu.dot(&z);
1910
1911        // GPU path: K(data, centers) via `build_kernel_matrix_device`.
1912        let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
1913        let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
1914        let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
1915        let inputs = S2KernelBuildInputs {
1916            n,
1917            m,
1918            lmax,
1919            data_xyz: &data_xyz,
1920            centers_xyz: &centers_xyz,
1921            coeffs: &coeffs,
1922            kind: SphereSpectralKernelKind::Sobolev,
1923            layout: DeviceMatrixLayout::ColumnMajor,
1924        };
1925        let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
1926        let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
1927        let x_s_gpu = raw_design_gpu.dot(&z);
1928
1929        assert_eq!(x_s_cpu.dim(), (n, p));
1930        assert_eq!(x_s_gpu.dim(), (n, p));
1931
1932        // Deterministic synthetic response. The intent is to give the
1933        // penalised LS solve a non-trivial right-hand side; any smooth
1934        // function of the lat/lon is fine. Use a fixed-seed pseudo-
1935        // random walk derived from coordinates so the fixture has no
1936        // RNG dependency.
1937        let mut y = ndarray::Array1::<f64>::zeros(n);
1938        for i in 0..n {
1939            let lat_rad = data_ll[(i, 0)].to_radians();
1940            let lon_rad = data_ll[(i, 1)].to_radians();
1941            // Smooth ground truth + a tiny deterministic high-freq jitter.
1942            y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
1943                + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
1944        }
1945
1946        // Penalised normal-equation solve via faer LLT for each path:
1947        //   (X_sᵀ X_s + λ S) β = X_sᵀ y
1948        // S is symmetric positive semi-definite; λ S makes the system
1949        // strictly positive definite once added to X_sᵀ X_s.
1950        let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
1951            let xtx = x_s.t().dot(x_s);
1952            let mut a = xtx;
1953            for i in 0..p {
1954                for j in 0..p {
1955                    a[(i, j)] += lambda * s_full[(i, j)];
1956                }
1957            }
1958            let rhs = x_s.t().dot(&y);
1959            let factor = a
1960                .cholesky(Side::Lower)
1961                .expect("penalised normal equations are SPD under λ > 0");
1962            factor.solvevec(&rhs)
1963        };
1964
1965        let beta_cpu = solve_penalised(&x_s_cpu);
1966        let beta_gpu = solve_penalised(&x_s_gpu);
1967        assert_eq!(beta_cpu.len(), p);
1968        assert_eq!(beta_gpu.len(), p);
1969
1970        // Fitted values for both paths use their own design matrices —
1971        // this is the customer-visible quantity (prediction at training
1972        // points).
1973        let yhat_cpu = x_s_cpu.dot(&beta_cpu);
1974        let yhat_gpu = x_s_gpu.dot(&beta_gpu);
1975
1976        let mut max_beta_delta = 0.0_f64;
1977        for k in 0..p {
1978            let d = (beta_cpu[k] - beta_gpu[k]).abs();
1979            if d > max_beta_delta {
1980                max_beta_delta = d;
1981            }
1982        }
1983        let mut max_fit_delta = 0.0_f64;
1984        for i in 0..n {
1985            let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
1986            if d > max_fit_delta {
1987                max_fit_delta = d;
1988            }
1989        }
1990
1991        eprintln!(
1992            "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
1993             max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
1994        );
1995
1996        assert!(
1997            max_beta_delta <= 1.0e-9,
1998            "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > 1e-9"
1999        );
2000        assert!(
2001            max_fit_delta <= 1.0e-9,
2002            "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2003        );
2004    }
2005}