use std::sync::OnceLock;
use ndarray::{Array2, ArrayView2};
use super::error::GpuError;
use super::{GpuDecision, GpuKernel, decide};
#[cfg(target_os = "linux")]
use crate::gpu::error::GpuResultExt;
#[cfg(target_os = "linux")]
use std::collections::HashMap;
#[cfg(target_os = "linux")]
use std::sync::{Arc, Mutex};
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
/// Selects which family of truncated spectral coefficients a sphere kernel uses.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SphereSpectralKernelKind {
/// Coefficients produced by `crate::basis::sobolev_s2_truncated_coefficients`.
Sobolev,
/// Coefficients produced by `crate::basis::pseudo_s2_truncated_coefficients`.
Pseudo,
}
impl SphereSpectralKernelKind {
    /// Returns the truncated coefficient vector for this kernel family.
    ///
    /// `lmax` and `m` are forwarded unchanged to the matching
    /// `crate::basis::*_s2_truncated_coefficients` generator.
    pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
        match self {
            Self::Sobolev => crate::basis::sobolev_s2_truncated_coefficients(lmax, m),
            Self::Pseudo => crate::basis::pseudo_s2_truncated_coefficients(lmax, m),
        }
    }

    /// Short lowercase label for this kernel family (used in diagnostics).
    pub const fn tag(self) -> &'static str {
        match self {
            Self::Pseudo => "pseudo",
            Self::Sobolev => "sobolev",
        }
    }
}
/// Memory layout of a device-resident matrix. Only one layout exists today.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
/// Columns are stored contiguously, one after another.
ColumnMajor,
}
/// Converts an `(n, 2)` latitude/longitude table into flat unit-sphere
/// Cartesian coordinates `[x0, y0, z0, x1, y1, z1, ...]`.
///
/// Column 0 is latitude and column 1 is longitude. When `radians` is `false`
/// both angles are treated as degrees and scaled by `PI / 180` first.
///
/// # Errors
/// Returns a message describing the actual shape when `latlon` does not have
/// exactly two columns.
pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
    if latlon.ncols() != 2 {
        return Err(format!(
            "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
            latlon.shape()
        ));
    }
    // Multiplier that brings the input angles to radians.
    let to_radians = if radians {
        1.0
    } else {
        std::f64::consts::PI / 180.0
    };
    let mut xyz = Vec::with_capacity(3 * latlon.nrows());
    for point in latlon.outer_iter() {
        let (sin_lat, cos_lat) = (point[0] * to_radians).sin_cos();
        let (sin_lon, cos_lon) = (point[1] * to_radians).sin_cos();
        xyz.extend_from_slice(&[cos_lat * cos_lon, cos_lat * sin_lon, sin_lat]);
    }
    Ok(xyz)
}
/// Column-major kernel matrix whose storage is a CUDA device buffer (Linux build).
#[cfg(target_os = "linux")]
pub struct DeviceS2KernelMatrix {
/// Number of logical rows.
pub rows: usize,
/// Number of logical columns.
pub cols: usize,
/// Element stride between consecutive columns; entry `(i, j)` is read at
/// `j * ld + i`, and the buffer is expected to hold `ld * cols` elements.
pub ld: usize,
/// Device buffer holding the matrix entries.
pub col_major_dev: CudaSlice<f64>,
/// Stream used for device-to-host copies of this matrix.
pub stream: Arc<CudaStream>,
}
/// Non-Linux stand-in for the device matrix: same shape fields, but the
/// column-major storage is an ordinary host `Vec`.
#[cfg(not(target_os = "linux"))]
pub struct DeviceS2KernelMatrix {
/// Number of logical rows.
pub rows: usize,
/// Number of logical columns.
pub cols: usize,
/// Element stride between consecutive columns; entry `(i, j)` is read at
/// `j * ld + i`.
pub ld: usize,
/// Host buffer holding the column-major entries.
pub col_major_dev: Vec<f64>,
}
impl DeviceS2KernelMatrix {
    /// Copies the matrix to the host and returns it as a `rows x cols` array,
    /// dropping any padding rows implied by `ld > rows`.
    ///
    /// # Errors
    /// Propagates any failure from [`Self::copy_to_host_col_major`].
    pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
        let mut staged = vec![0.0_f64; self.ld * self.cols];
        self.copy_to_host_col_major(&mut staged)?;
        // Entry (i, j) sits at offset j * ld + i in the column-major staging buffer.
        Ok(Array2::from_shape_fn((self.rows, self.cols), |(i, j)| {
            staged[j * self.ld + i]
        }))
    }

    /// Copies the raw column-major buffer (padding included) into `dst` and
    /// waits for the stream to finish.
    ///
    /// # Errors
    /// Fails when `dst.len()` differs from `ld * cols`, or when the copy or the
    /// stream synchronisation reports an error.
    #[cfg(target_os = "linux")]
    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
        let expected_len = self.ld * self.cols;
        if dst.len() != expected_len {
            crate::gpu_bail!(
                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
                dst.len(),
                expected_len
            );
        }
        let stream = &self.stream;
        stream
            .memcpy_dtoh(&self.col_major_dev, dst)
            .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
        stream
            .synchronize()
            .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
        Ok(())
    }

    /// Copies the raw column-major buffer (padding included) into `dst`.
    ///
    /// # Errors
    /// Fails when `dst.len()` differs from `ld * cols`.
    #[cfg(not(target_os = "linux"))]
    pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
        let expected_len = self.ld * self.cols;
        if dst.len() != expected_len {
            crate::gpu_bail!(
                "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
                dst.len(),
                expected_len
            );
        }
        dst.copy_from_slice(&self.col_major_dev);
        Ok(())
    }
}
/// Borrowed inputs for building a sphere kernel matrix on the device.
///
/// The length relationships noted on the fields are the ones enforced by
/// `validate`.
#[derive(Clone, Debug)]
pub struct S2KernelBuildInputs<'a> {
/// Number of data points (rows of the kernel matrix).
pub n: usize,
/// Number of centers (columns of the kernel matrix).
pub m: usize,
/// Truncation degree; must be at least 1.
pub lmax: usize,
/// Flat data coordinates; length must equal `3 * n`.
pub data_xyz: &'a [f64],
/// Flat center coordinates; length must equal `3 * m`.
pub centers_xyz: &'a [f64],
/// Spectral coefficients; length must equal `lmax + 1` and `coeffs[0]` must be 0.
pub coeffs: &'a [f64],
/// Kernel family; becomes part of the module cache key.
pub kind: SphereSpectralKernelKind,
/// Output layout; becomes part of the module cache key.
pub layout: DeviceMatrixLayout,
}
impl<'a> S2KernelBuildInputs<'a> {
    /// Checks the shape invariants the device kernels rely on.
    ///
    /// # Errors
    /// Fails, in this order, when `lmax == 0`, when `data_xyz` or
    /// `centers_xyz` do not hold three values per point, when `coeffs` does not
    /// hold `lmax + 1` entries, or when `coeffs[0]` is non-zero.
    fn validate(&self) -> Result<(), GpuError> {
        if self.lmax == 0 {
            return Err(GpuError::DriverCallFailed {
                reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
            });
        }

        let data_len = self.data_xyz.len();
        let expected_data_len = 3 * self.n;
        if data_len != expected_data_len {
            crate::gpu_bail!(
                "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
                data_len,
                expected_data_len
            );
        }

        let centers_len = self.centers_xyz.len();
        let expected_centers_len = 3 * self.m;
        if centers_len != expected_centers_len {
            crate::gpu_bail!(
                "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
                centers_len,
                expected_centers_len
            );
        }

        let coeff_len = self.coeffs.len();
        let expected_coeff_len = self.lmax + 1;
        if coeff_len != expected_coeff_len {
            crate::gpu_bail!(
                "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
                coeff_len,
                expected_coeff_len
            );
        }

        // The length check above guarantees at least two coefficients here.
        if self.coeffs[0] != 0.0 {
            return Err(GpuError::DriverCallFailed {
                reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
            });
        }
        Ok(())
    }
}
/// CUDA C source for the sphere kernels, compiled at runtime.
///
/// The text is not compilable on its own: `SphereGpuBackend::module_for`
/// prepends `#define LMAX <lmax>` before handing it to NVRTC. It defines two
/// `extern "C"` entry points that the host looks up by name:
///
/// * `s2_wahba_legendre_colmajor` — one thread per `(i, j)` output entry,
///   writing `out[j * ld + i]`.
/// * `s2_wahba_householder_constrained_colmajor` — one thread per row `i`,
///   writing `m - 1` entries `out[j_out * ld_out + i]`.
///
/// Both kernels index `coeffs[1]` unconditionally, so `LMAX` must be at least 1.
/// Editing this string changes the program that runs on the device.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
const double* __restrict__ data_xyz, // n × 3 (row-major flat)
const double* __restrict__ centers_xyz, // m × 3 (row-major flat)
const double* __restrict__ coeffs, // length LMAX + 1, coeffs[0] = 0
int n,
int m,
long long ld,
double* __restrict__ out // ld × m column-major
) {
const int i = blockIdx.y * blockDim.y + threadIdx.y;
const int j = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n || j >= m) return;
// Load (x_i, y_i, z_i) and (cx_j, cy_j, cz_j) into registers.
const double xi = data_xyz[3 * i + 0];
const double yi = data_xyz[3 * i + 1];
const double zi = data_xyz[3 * i + 2];
const double cxj = centers_xyz[3 * j + 0];
const double cyj = centers_xyz[3 * j + 1];
const double czj = centers_xyz[3 * j + 2];
// t = clamp(x_i · z_j, -1, +1).
double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
if (t > 1.0) t = 1.0;
if (t < -1.0) t = -1.0;
// Legendre 3-term recurrence in registers.
// P_0(t) = 1, P_1(t) = t.
double p_prev = 1.0;
double p_curr = t;
double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
#pragma unroll 8
for (int ell = 1; ell < LMAX; ++ell) {
const double lf = (double) ell;
const double inv = 1.0 / (lf + 1.0);
// p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
const double p_next =
fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
acc = fma(coeffs[ell + 1], p_next, acc);
p_prev = p_curr;
p_curr = p_next;
}
out[(long long) j * ld + (long long) i] = acc;
}
// Fused Householder-constrained kernel (Phase 3). Z = I - beta · v · v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B · v) · v[1..m]^T,
// i.e. drop the first column after applying Z. Each thread computes one
// row of B in registers (m kernel evaluations), forms d_i = B_row · v,
// then emits X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1]
// for j_out in 0..m-1.
//
// Grid: 1D over rows (block_dim.x rows per block). Each thread iterates
// over centers in an inner loop — register-bound by the per-row state
// (xyz_i, p_prev, p_curr, acc, and a small per-center scratch).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
const double* __restrict__ data_xyz, // n × 3
const double* __restrict__ centers_xyz, // m × 3
const double* __restrict__ coeffs, // length LMAX + 1
const double* __restrict__ v, // length m, Householder vector
double beta,
int n,
int m,
long long ld_out,
double* __restrict__ out // ld_out × (m-1) column-major
) {
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n) return;
const double xi = data_xyz[3 * i + 0];
const double yi = data_xyz[3 * i + 1];
const double zi = data_xyz[3 * i + 2];
// Pass 1: compute d_i = sum_j v[j] * B[i, j].
double d_i = 0.0;
for (int j = 0; j < m; ++j) {
const double cxj = centers_xyz[3 * j + 0];
const double cyj = centers_xyz[3 * j + 1];
const double czj = centers_xyz[3 * j + 2];
double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
if (t > 1.0) t = 1.0;
if (t < -1.0) t = -1.0;
double p_prev = 1.0;
double p_curr = t;
double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
#pragma unroll 8
for (int ell = 1; ell < LMAX; ++ell) {
const double lf = (double) ell;
const double inv = 1.0 / (lf + 1.0);
const double p_next =
fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
acc = fma(coeffs[ell + 1], p_next, acc);
p_prev = p_curr;
p_curr = p_next;
}
d_i = fma(v[j], acc, d_i);
}
// Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
const double bd = beta * d_i;
for (int j_out = 0; j_out < m - 1; ++j_out) {
const int j = j_out + 1;
const double cxj = centers_xyz[3 * j + 0];
const double cyj = centers_xyz[3 * j + 1];
const double czj = centers_xyz[3 * j + 2];
double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
if (t > 1.0) t = 1.0;
if (t < -1.0) t = -1.0;
double p_prev = 1.0;
double p_curr = t;
double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;
#pragma unroll 8
for (int ell = 1; ell < LMAX; ++ell) {
const double lf = (double) ell;
const double inv = 1.0 / (lf + 1.0);
const double p_next =
fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
acc = fma(coeffs[ell + 1], p_next, acc);
p_prev = p_curr;
p_curr = p_next;
}
const double xs = acc - bd * v[j];
out[(long long) j_out * ld_out + (long long) i] = xs;
}
}
"#;
/// Key under which compiled CUDA modules are cached by the sphere backend.
///
/// NOTE(review): `module_for` only substitutes `lmax` into the generated
/// source; `kind` appears in its error message and `layout` is not read there,
/// so keys differing only in those fields compile identical source.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
/// Major compute-capability number of the selected device.
pub cc_major: i32,
/// Minor compute-capability number of the selected device.
pub cc_minor: i32,
/// Truncation degree baked into the module as `#define LMAX`.
pub lmax: u32,
/// Kernel family the module was requested for.
pub kind: SphereSpectralKernelKind,
/// Output layout the module was requested for.
pub layout: DeviceMatrixLayout,
}
/// Reports whether this build contains the sphere GPU backend, which is
/// compiled only when the target OS is Linux.
pub const fn sphere_gpu_compiled() -> bool {
    #[cfg(target_os = "linux")]
    {
        true
    }
    #[cfg(not(target_os = "linux"))]
    {
        false
    }
}
/// Decides whether an `n x m` sphere kernel build of degree `lmax` should be
/// routed to the GPU.
///
/// The workload counts as large enough only when a GPU runtime is available,
/// `n * m` is at least one million, `lmax` is at most 200, and the padded
/// column-major output (rows rounded up to a multiple of 32) fits within the
/// runtime's memory budget. The outcome is combined with
/// [`sphere_gpu_compiled`] and passed to `decide`.
#[must_use]
pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
    let large_enough = match super::runtime::GpuRuntime::global() {
        None => false,
        Some(runtime) => {
            let padded_rows = ((n + 31) / 32) * 32;
            let output_bytes = padded_rows
                .saturating_mul(m)
                .saturating_mul(std::mem::size_of::<f64>());
            let fits_budget = output_bytes <= runtime.memory_budget_bytes;
            let enough_work = n.saturating_mul(m) >= 1_000_000;
            enough_work && lmax <= 200 && fits_budget
        }
    };
    let eligibility = super::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough);
    decide(GpuKernel::SpatialKernelOperator, eligibility)
}
/// Linux-only state owned by the sphere backend.
#[cfg(target_os = "linux")]
struct SphereGpuContext {
/// CUDA context used to load compiled modules.
ctx: Arc<CudaContext>,
/// The context's default stream; kernel launches and copies are issued on it.
stream: Arc<CudaStream>,
/// Compiled modules, keyed by `S2ModuleCacheKey`.
modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
/// Major compute-capability number of the selected device.
cc_major: i32,
/// Minor compute-capability number of the selected device.
cc_minor: i32,
}
/// Handle to the sphere GPU backend; obtain it through `SphereGpuBackend::probe`.
///
/// On non-Linux targets the struct has no fields.
pub struct SphereGpuBackend {
/// CUDA context, stream and module cache (Linux builds only).
#[cfg(target_os = "linux")]
inner: SphereGpuContext,
}
impl SphereGpuBackend {
pub fn probe() -> Result<&'static Self, GpuError> {
static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
BACKEND
.get_or_init(|| {
#[cfg(target_os = "linux")]
{
Self::probe_linux()
}
#[cfg(not(target_os = "linux"))]
{
Err(GpuError::DriverLibraryUnavailable {
reason: "sphere GPU backend is Linux-only".to_string(),
})
}
})
.as_ref()
.map_err(GpuError::clone)
}
#[cfg(target_os = "linux")]
fn probe_linux() -> Result<Self, GpuError> {
let runtime = super::runtime::GpuRuntime::global().ok_or_else(|| {
GpuError::DriverLibraryUnavailable {
reason: "sphere backend: no CUDA runtime available".to_string(),
}
})?;
let ordinal = runtime.selected_device().ordinal;
let ctx = super::runtime::cuda_context_for(ordinal).ok_or_else(|| {
gpu_err!("sphere backend: failed to create CUDA context for device {ordinal}")
})?;
let stream = ctx.default_stream();
let cap = &runtime.selected_device().capability;
let cc_major = cap.compute_major;
let cc_minor = cap.compute_minor;
Ok(SphereGpuBackend {
inner: SphereGpuContext {
ctx,
stream,
modules: Mutex::new(HashMap::new()),
cc_major,
cc_minor,
},
})
}
#[cfg(target_os = "linux")]
fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
if let Ok(guard) = self.inner.modules.lock() {
if let Some(existing) = guard.get(&key) {
return Ok(existing.clone());
}
}
let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
format!(
"sphere NVRTC compile (kind={}, lmax={}): {err}",
key.kind.tag(),
key.lmax
)
})?;
let module = self
.inner
.ctx
.load_module(ptx)
.gpu_ctx("sphere module load")?;
if let Ok(mut guard) = self.inner.modules.lock() {
guard.entry(key).or_insert_with(|| module.clone());
}
Ok(module)
}
#[cfg(target_os = "linux")]
fn cc(&self) -> (i32, i32) {
(self.inner.cc_major, self.inner.cc_minor)
}
}
pub fn build_kernel_matrix_device(
inputs: S2KernelBuildInputs<'_>,
) -> Result<DeviceS2KernelMatrix, GpuError> {
inputs.validate()?;
#[cfg(target_os = "linux")]
{
use cudarc::driver::{LaunchConfig, PushKernelArg};
let backend = SphereGpuBackend::probe()?;
let (cc_major, cc_minor) = backend.cc();
let key = S2ModuleCacheKey {
cc_major,
cc_minor,
lmax: inputs.lmax as u32,
kind: inputs.kind,
layout: inputs.layout,
};
let module = backend.module_for(key)?;
let func = module
.load_function("s2_wahba_legendre_colmajor")
.gpu_ctx("sphere load_function raw")?;
let stream = backend.inner.stream.clone();
let data_dev = stream
.clone_htod(inputs.data_xyz)
.gpu_ctx("sphere htod data_xyz")?;
let centers_dev = stream
.clone_htod(inputs.centers_xyz)
.gpu_ctx("sphere htod centers_xyz")?;
let coeffs_dev = stream
.clone_htod(inputs.coeffs)
.gpu_ctx("sphere htod coeffs")?;
let n = inputs.n;
let m = inputs.m;
let ld = ((n + 31) / 32) * 32;
let mut out_dev = stream
.alloc_zeros::<f64>(ld * m)
.gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
let block_x: u32 = 32;
let block_y: u32 = 8;
let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
let cfg = LaunchConfig {
grid_dim: (grid_x, grid_y, 1),
block_dim: (block_x, block_y, 1),
shared_mem_bytes: 0,
};
let n_i32: i32 = i32::try_from(n).map_err(|_| gpu_err!("sphere n={n} overflows i32"))?;
let m_i32: i32 = i32::try_from(m).map_err(|_| gpu_err!("sphere m={m} overflows i32"))?;
let ld_i64: i64 = ld as i64;
let mut builder = stream.launch_builder(&func);
builder
.arg(&data_dev)
.arg(¢ers_dev)
.arg(&coeffs_dev)
.arg(&n_i32)
.arg(&m_i32)
.arg(&ld_i64)
.arg(&mut out_dev);
unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
stream
.synchronize()
.gpu_ctx("sphere raw kernel synchronize")?;
Ok(DeviceS2KernelMatrix {
rows: n,
cols: m,
ld,
col_major_dev: out_dev,
stream,
})
}
#[cfg(not(target_os = "linux"))]
{
Err(GpuError::DriverLibraryUnavailable {
reason: "sphere GPU backend is Linux-only".to_string(),
})
}
}
/// Builds the Householder-constrained design matrix on the GPU.
///
/// With `B` the raw `n x m` kernel matrix, the kernel computes
/// `B[:, 1..] - beta * (B v) v[1..]^T`, i.e. `B (I - beta v v^T)` with its
/// first column dropped. The result has `n` rows and `m - 1` columns, stored
/// column-major with the leading dimension rounded up to a multiple of 32.
///
/// # Errors
/// Fails when `inputs` do not validate, when `v.len() != m`, when `m < 2`,
/// when `beta` is not finite, when `n`, `m` or `lmax` do not fit the integer
/// types the kernel uses, or when any CUDA call fails. On non-Linux targets it
/// always fails with `DriverLibraryUnavailable` once the argument checks pass.
pub fn build_householder_constrained_design_device(
    inputs: S2KernelBuildInputs<'_>,
    v: &[f64],
    beta: f64,
) -> Result<DeviceS2KernelMatrix, GpuError> {
    inputs.validate()?;
    if v.len() != inputs.m {
        crate::gpu_bail!(
            "build_householder_constrained_design_device: v.len()={} != m={}",
            v.len(),
            inputs.m
        );
    }
    if inputs.m < 2 {
        crate::gpu_bail!(
            "build_householder_constrained_design_device: m must be >= 2 (got {})",
            inputs.m
        );
    }
    if !beta.is_finite() {
        crate::gpu_bail!(
            "build_householder_constrained_design_device: beta must be finite (got {beta})"
        );
    }
    #[cfg(target_os = "linux")]
    {
        use cudarc::driver::{LaunchConfig, PushKernelArg};

        // One thread per data row, laid out along grid/block x.
        const BLOCK_X: u32 = 128;

        let n = inputs.n;
        let m = inputs.m;
        let lmax = inputs.lmax;

        // Range-check every size before any narrowing cast so that an
        // oversized problem is rejected instead of silently truncated.
        let n_i32: i32 = i32::try_from(n).map_err(|_| gpu_err!("sphere-hh n={n} overflows i32"))?;
        let m_i32: i32 = i32::try_from(m).map_err(|_| gpu_err!("sphere-hh m={m} overflows i32"))?;
        let lmax_u32: u32 =
            u32::try_from(lmax).map_err(|_| gpu_err!("sphere-hh lmax={lmax} overflows u32"))?;
        let cols_out = m - 1;
        let ld_out = ((n + 31) / 32) * 32;
        let ld_out_i64: i64 =
            i64::try_from(ld_out).map_err(|_| gpu_err!("sphere-hh ld={ld_out} overflows i64"))?;
        // Lossless: `n` was just shown to fit in i32.
        let grid_x: u32 = ((n as u32) + BLOCK_X - 1) / BLOCK_X;

        let backend = SphereGpuBackend::probe()?;
        let (cc_major, cc_minor) = backend.cc();
        let key = S2ModuleCacheKey {
            cc_major,
            cc_minor,
            lmax: lmax_u32,
            kind: inputs.kind,
            layout: inputs.layout,
        };
        let module = backend.module_for(key)?;
        let func = module
            .load_function("s2_wahba_householder_constrained_colmajor")
            .gpu_ctx("sphere load_function householder")?;

        let stream = backend.inner.stream.clone();
        let data_dev = stream
            .clone_htod(inputs.data_xyz)
            .gpu_ctx("sphere-hh htod data_xyz")?;
        let centers_dev = stream
            .clone_htod(inputs.centers_xyz)
            .gpu_ctx("sphere-hh htod centers_xyz")?;
        let coeffs_dev = stream
            .clone_htod(inputs.coeffs)
            .gpu_ctx("sphere-hh htod coeffs")?;
        let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
        let mut out_dev = stream
            .alloc_zeros::<f64>(ld_out * cols_out)
            .gpu_ctx_with(|err| {
                format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
            })?;

        let cfg = LaunchConfig {
            grid_dim: (grid_x, 1, 1),
            block_dim: (BLOCK_X, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&data_dev)
            .arg(&centers_dev)
            .arg(&coeffs_dev)
            .arg(&v_dev)
            .arg(&beta)
            .arg(&n_i32)
            .arg(&m_i32)
            .arg(&ld_out_i64)
            .arg(&mut out_dev);
        // SAFETY: the arguments match the kernel signature in order and type
        // (four `const double*`, `double beta`, `int n`, `int m`,
        // `long long ld_out`, `double* out`). `validate` and the checks above
        // guarantee 3*n / 3*m / lmax+1 / m elements in the input buffers, and
        // `out_dev` holds `ld_out * (m - 1)` elements with `ld_out >= n`.
        unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
        stream
            .synchronize()
            .gpu_ctx("sphere-hh kernel synchronize")?;

        Ok(DeviceS2KernelMatrix {
            rows: n,
            cols: cols_out,
            ld: ld_out,
            col_major_dev: out_dev,
            stream,
        })
    }
    #[cfg(not(target_os = "linux"))]
    {
        Err(GpuError::DriverLibraryUnavailable {
            reason: "sphere GPU backend is Linux-only".to_string(),
        })
    }
}
/// Builds a Householder reflector `H = I - beta * v * v^T` from the vector `w`.
///
/// `v` is `w` with `sign(w[0]) * ||w||` added to its first entry, then scaled
/// so that `v[0] == 1`; `beta` is `2 / (v . v)`.
///
/// Degenerate inputs yield a zero `beta`: an empty `w` returns an empty `v`,
/// and an all-zero `w` (or a zero pivot) returns a zero vector of the same
/// length.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    let Some(&w0) = w.first() else {
        return (Vec::new(), 0.0);
    };
    let norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm == 0.0 {
        return (vec![0.0; w.len()], 0.0);
    }
    // Push the first entry away from zero, in the direction of its own sign,
    // so the addition never cancels.
    let pivot = if w0 >= 0.0 { w0 + norm } else { w0 - norm };
    if pivot == 0.0 {
        return (vec![0.0; w.len()], 0.0);
    }
    let v: Vec<f64> = std::iter::once(pivot)
        .chain(w[1..].iter().copied())
        .map(|entry| entry / pivot)
        .collect();
    let vv: f64 = v.iter().map(|x| x * x).sum();
    (v, 2.0 / vv)
}
/// Builds the square center-by-center kernel matrix on the GPU by using the
/// centers as both the data points and the centers.
///
/// # Errors
/// Fails when `centers_xyz.len()` is not a multiple of 3, and otherwise
/// propagates any error from [`build_kernel_matrix_device`].
pub fn build_center_kernel_device(
    centers_xyz: &[f64],
    lmax: usize,
    coeffs: &[f64],
    kind: SphereSpectralKernelKind,
) -> Result<DeviceS2KernelMatrix, GpuError> {
    if centers_xyz.len() % 3 != 0 {
        return Err(GpuError::DriverCallFailed {
            reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
        });
    }
    let count = centers_xyz.len() / 3;
    build_kernel_matrix_device(S2KernelBuildInputs {
        n: count,
        m: count,
        lmax,
        data_xyz: centers_xyz,
        centers_xyz,
        coeffs,
        kind,
        layout: DeviceMatrixLayout::ColumnMajor,
    })
}
/// Computes the constrained penalty on the host: forms `H C H` for the
/// Householder reflector `H = I - beta v v^T` built from `w`, and returns it
/// with the first row and column removed.
///
/// # Errors
/// Fails when `c` is not square, when `w.len()` differs from its size, or when
/// that size is below 2.
pub fn constrained_penalty_host(
    c: ArrayView2<'_, f64>,
    w: &[f64],
) -> Result<Array2<f64>, GpuError> {
    let (m1, m2) = c.dim();
    if m1 != m2 {
        crate::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
    }
    let m = m1;
    if w.len() != m {
        crate::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
    }
    if m < 2 {
        crate::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
    }

    let (v, beta) = householder_reflector_from_weights(w);
    // u = C v, accumulated left to right along each row.
    let u: Vec<f64> = (0..m)
        .map(|i| (0..m).fold(0.0_f64, |acc, j| acc + c[(i, j)] * v[j]))
        .collect();
    let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
    // Shared factor of the rank-one correction term.
    let quad = beta * beta * vtcv;

    // H C H = C - beta (v u^T + u v^T) + beta^2 (v^T C v) v v^T; only the block
    // that skips index 0 is kept, so it is written out directly.
    Ok(Array2::from_shape_fn((m - 1, m - 1), |(row, col)| {
        let (i, j) = (row + 1, col + 1);
        c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + quad * v[i] * v[j]
    }))
}
/// Result of a penalised least-squares solve.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
/// Solved coefficient vector (length `p`, the number of design columns).
pub beta: Vec<f64>,
/// Residual sum of squares. NOTE(review): the Linux solver fills this with
/// the sum of squares of the transformed right-hand side below the first `p`
/// entries, which is the residual of the *augmented* system — confirm that is
/// what callers expect from this name.
pub weighted_residual_ssq: f64,
/// `2 * sum(ln |R_kk|)` over the diagonal of the triangular QR factor.
pub log_det_hessian: f64,
}
/// Solves the penalised least-squares problem for a device-resident design
/// using a QR factorisation of the stacked matrix `[X_s; R_s]`.
///
/// Steps, in order:
/// 1. Copy `X_s` to the host and assemble the `(n + p) x p` column-major
///    matrix with `r_s` below it; the right-hand side is `wy` followed by `p`
///    zeros.
/// 2. `cusolverDnDgeqrf` factorises the stacked matrix in place.
/// 3. `cusolverDnDormqr` (left side, transposed) applies `Q^T` to the
///    right-hand side.
/// 4. `cublasDtrsm_v2` solves the upper-triangular `p x p` system in place.
/// 5. The right-hand side and the factor are copied back; the first `p`
///    entries are the solution, the rest give the residual sum of squares, and
///    the factor's diagonal gives the log-determinant.
///
/// # Errors
/// Fails when `wy.len() != n`, when `r_s` is not `p x p`, when `n + p` or `p`
/// overflow `i32`, or when any CUDA, cuSOLVER or cuBLAS call reports a
/// non-success status.
///
/// NOTE(review): the device `info` word written by `geqrf`/`ormqr` is never
/// copied back or inspected here; only the returned status codes are checked.
/// NOTE(review): a zero on the factor's diagonal makes `log_det_hessian`
/// negative infinity (from `ln(0)`) rather than producing an error.
#[cfg(target_os = "linux")]
pub fn solve_penalised_ls_device(
x_s_device: &DeviceS2KernelMatrix,
wy: &[f64],
r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
use cudarc::driver::DevicePtrMut;
let n = x_s_device.rows;
let p = x_s_device.cols;
if wy.len() != n {
crate::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
}
if r_s.dim() != (p, p) {
crate::gpu_bail!(
"solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
r_s.dim()
);
}
// No columns: nothing to solve, the residual is the right-hand side itself.
if p == 0 {
return Ok(PenalisedLsSolution {
beta: Vec::new(),
weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
log_det_hessian: 0.0,
});
}
let stream = x_s_device.stream.clone();
// Stack X_s (n rows) on top of r_s (p rows), column-major with leading
// dimension n_aug. The source columns are `ld` apart, so the padding rows of
// the device matrix are skipped.
let n_aug = n + p;
let mut a_aug_host = vec![0.0_f64; n_aug * p];
let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
for j in 0..p {
let src_off = j * x_s_device.ld;
let dst_off = j * n_aug;
a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
for i in 0..p {
a_aug_host[dst_off + n + i] = r_s[(i, j)];
}
}
let mut a_dev = stream
.clone_htod(&a_aug_host)
.gpu_ctx("solve_penalised_ls_device htod A_aug")?;
// Right-hand side: wy followed by p zeros.
let mut b_host = vec![0.0_f64; n_aug];
b_host[..n].copy_from_slice(wy);
let mut b_dev = stream
.clone_htod(&b_host)
.gpu_ctx("solve_penalised_ls_device htod b_aug")?;
let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
// cuSOLVER takes 32-bit dimensions.
let n_aug_i: i32 = i32::try_from(n_aug)
.map_err(|_| gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
let p_i: i32 =
i32::try_from(p).map_err(|_| gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
// Query the workspace size (in elements) needed by geqrf.
let mut lwork: i32 = 0;
{
let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
let status = unsafe {
cusolver_sys::cusolverDnDgeqrf_bufferSize(
solver.cu(),
n_aug_i,
p_i,
a_ptr as *mut f64,
n_aug_i,
&mut lwork,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
crate::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
}
}
let lwork_us = usize::try_from(lwork)
.map_err(|_| gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
// Allocate at least one element even when the reported size is zero.
let mut workspace = stream
.alloc_zeros::<f64>(lwork_us.max(1))
.gpu_ctx("solve_penalised_ls_device alloc workspace")?;
let mut tau = stream
.alloc_zeros::<f64>(p)
.gpu_ctx("solve_penalised_ls_device alloc tau")?;
let mut info = stream
.alloc_zeros::<i32>(1)
.gpu_ctx("solve_penalised_ls_device alloc info")?;
// In-place QR factorisation of the stacked matrix; the Householder scalars
// go to `tau`.
{
let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
let status = unsafe {
cusolver_sys::cusolverDnDgeqrf(
solver.cu(),
n_aug_i,
p_i,
a_ptr as *mut f64,
n_aug_i,
tau_ptr as *mut f64,
work_ptr as *mut f64,
lwork,
info_ptr as *mut i32,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
crate::gpu_bail!("cusolverDnDgeqrf status={status:?}");
}
}
// Query the workspace size needed to apply Q^T to the single right-hand side.
let mut ormqr_lwork: i32 = 0;
{
let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
let status = unsafe {
cusolver_sys::cusolverDnDormqr_bufferSize(
solver.cu(),
cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
n_aug_i,
1,
p_i,
a_ptr as *const f64,
n_aug_i,
tau_ptr as *const f64,
b_ptr as *mut f64,
n_aug_i,
&mut ormqr_lwork,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
crate::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
}
}
// Reuse the geqrf workspace unless ormqr asks for a larger one.
if ormqr_lwork > lwork {
workspace = stream
.alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
.gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
}
// b <- Q^T b, in place.
{
let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
let status = unsafe {
cusolver_sys::cusolverDnDormqr(
solver.cu(),
cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
n_aug_i,
1,
p_i,
a_ptr as *const f64,
n_aug_i,
tau_ptr as *const f64,
b_ptr as *mut f64,
n_aug_i,
work_ptr as *mut f64,
ormqr_lwork.max(lwork),
info_ptr as *mut i32,
)
};
if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
crate::gpu_bail!("cusolverDnDormqr status={status:?}");
}
}
// Solve R * beta = (Q^T b)[..p] in place using the upper triangle of the
// factorised matrix (leading dimension n_aug).
{
use cudarc::cublas::CudaBlas;
let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
let alpha = 1.0_f64;
let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
let handle = *blas.handle();
let status = unsafe {
cudarc::cublas::sys::cublasDtrsm_v2(
handle,
cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
p_i,
1,
&alpha,
a_ptr as *const f64,
n_aug_i,
b_ptr as *mut f64,
n_aug_i,
)
};
if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
crate::gpu_bail!("cublasDtrsm_v2 status={status:?}");
}
}
// Bring back the solved right-hand side and the factor, then wait for both
// copies to finish before reading them.
let mut b_out = vec![0.0_f64; n_aug];
stream
.memcpy_dtoh(&b_dev, &mut b_out)
.gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
let mut a_back = vec![0.0_f64; n_aug * p];
stream
.memcpy_dtoh(&a_dev, &mut a_back)
.gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
stream
.synchronize()
.gpu_ctx("solve_penalised_ls_device synchronize")?;
// First p entries: the solution. Remaining entries were untouched by the
// triangular solve and hold the residual part of Q^T b.
let beta: Vec<f64> = b_out[..p].to_vec();
let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
// Sum ln|R_kk| over the diagonal of the factor; doubled below.
let mut log_abs_r = 0.0_f64;
for k in 0..p {
let r_kk = a_back[k * n_aug + k];
log_abs_r += r_kk.abs().ln();
}
let log_det_hessian = 2.0 * log_abs_r;
Ok(PenalisedLsSolution {
beta,
weighted_residual_ssq: augmented_residual_ssq,
log_det_hessian,
})
}
/// Non-Linux stub for the GPU penalised least-squares solver.
///
/// # Errors
/// Always fails with `DriverLibraryUnavailable`; the message records the
/// problem dimensions that were requested.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let (n, p) = (x_s_device.rows, x_s_device.cols);
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={}, p={}, wy.len()={}, r_s={:?})",
        n,
        p,
        wy.len(),
        r_s.dim()
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
#[cfg(test)]
mod sphere_gpu_tests {
use super::*;
use crate::basis::{
SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
spherical_wahba_kernel_matrix_with_kind,
};
use ndarray::Array2;
/// Builds an `(n_lat * n_lon, 2)` table of `(lat, lon)` pairs in degrees,
/// with latitudes spread over [-85, 85] and longitudes over [-180, 180].
/// A single-point axis stays at its lower bound.
fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
    // Number of intervals per axis, never less than one to avoid dividing by zero.
    let lat_steps = n_lat.saturating_sub(1).max(1) as f64;
    let lon_steps = n_lon.saturating_sub(1).max(1) as f64;
    let total = n_lat * n_lon;
    let mut flat = Vec::with_capacity(total);
    for i in 0..n_lat {
        let lat = -85.0 + (170.0 * i as f64) / lat_steps;
        for j in 0..n_lon {
            let lon = -180.0 + (360.0 * j as f64) / lon_steps;
            flat.extend_from_slice(&[lat, lon]);
        }
    }
    Array2::from_shape_vec((total, 2), flat).unwrap()
}
/// Checks that the lat/lon conversion yields unit vectors and that the axis
/// points land on the expected coordinates.
#[test]
fn xyz_preprocessing_matches_unit_sphere() {
    let latlon = ndarray::array![
        [0.0, 0.0],
        [90.0, 0.0],
        [0.0, 90.0],
        [-90.0, 17.5],
        [45.0, -120.0],
    ];
    let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
    assert_eq!(xyz.len(), 3 * 5);
    for i in 0..5 {
        let nrm2 = xyz[3 * i] * xyz[3 * i]
            + xyz[3 * i + 1] * xyz[3 * i + 1]
            + xyz[3 * i + 2] * xyz[3 * i + 2];
        assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
    }
    // Row 0, (lat, lon) = (0, 0): the +x axis.
    assert!((xyz[0] - 1.0).abs() < 1e-15);
    // Row 1, lat = 90: the north pole. Its z-component is flat index 5 and
    // must be 1, while x and y (indices 3 and 4) vanish.
    assert!(xyz[3].abs() < 1e-15);
    assert!(xyz[4].abs() < 1e-15);
    assert!((xyz[5] - 1.0).abs() < 1e-15);
    // Row 2, (lat, lon) = (0, 90): the +y axis.
    assert!((xyz[7] - 1.0).abs() < 1e-15);
}
/// Evaluating the truncated kernel at cosine 1 must give the plain sum of the
/// coefficients, for several penalty orders and truncation degrees.
#[test]
fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
    const LMAX_CASES: [usize; 3] = [5, 20, 50];
    for m_penalty in 1..=4 {
        for lmax in LMAX_CASES {
            let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
            let got = sphere_truncated_spectral_eval(1.0, &coeffs);
            let expected: f64 = coeffs.iter().sum();
            let gap = (got - expected).abs();
            assert!(
                gap < 1e-13,
                "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
            );
        }
    }
}
/// Evaluating the truncated kernel at cosine -1 must give the alternating sum
/// of the coefficients (even degrees positive, odd degrees negative).
#[test]
fn truncated_spectral_at_antipode_matches_alternating_sum() {
    const LMAX_CASES: [usize; 3] = [5, 20, 50];
    for m_penalty in 1..=4 {
        for lmax in LMAX_CASES {
            let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
            let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
            let expected: f64 = coeffs
                .iter()
                .enumerate()
                .map(|(ell, c)| match ell % 2 {
                    0 => *c,
                    _ => -*c,
                })
                .sum();
            let gap = (got - expected).abs();
            assert!(
                gap < 1e-13,
                "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
            );
        }
    }
}
/// The CPU kernel matrix of a point set against itself must be symmetric to
/// within 1e-13.
#[test]
fn truncated_spectral_matrix_is_symmetric() {
    let centers = ndarray::array![
        [10.0_f64, 20.0],
        [-30.0, 100.0],
        [45.0, -60.0],
        [-89.0, 0.0],
        [0.0, 180.0],
        [60.0, -179.9],
    ];
    let n = centers.nrows();
    for m_penalty in [1usize, 2, 4] {
        for lmax in [10_usize, 30] {
            let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 };
            let mat = spherical_wahba_kernel_matrix_with_kind(
                centers.view(),
                centers.view(),
                m_penalty,
                false,
                kernel,
            )
            .expect("kernel matrix");
            // Largest |K[i, j] - K[j, i]| over all index pairs.
            let max_asym = (0..n)
                .flat_map(|i| (0..n).map(move |j| (i, j)))
                .map(|(i, j)| (mat[(i, j)] - mat[(j, i)]).abs())
                .fold(0.0_f64, f64::max);
            assert!(
                max_asym < 1e-13,
                "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
            );
        }
    }
}
/// Sobolev coefficients must have a zero constant mode, a positive first
/// mode, and be non-increasing from degree 1 onwards.
#[test]
fn truncated_coefficients_have_zero_constant_mode() {
    for m in 1..=4 {
        let c = sobolev_s2_truncated_coefficients(50, m);
        assert_eq!(c.len(), 51);
        assert_eq!(c[0], 0.0);
        assert!(c[1] > 0.0);
        for ell in 2..=50 {
            let (previous, current) = (c[ell - 1], c[ell]);
            assert!(
                current < previous + 1e-15,
                "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
                current,
                previous
            );
        }
    }
}
/// A 1x1 kernel matrix from the matrix helper must agree with the scalar
/// evaluator applied to the dot product of the two unit vectors.
#[test]
fn truncated_spectral_matches_matrix_helper() {
    let m_penalty = 2;
    let lmax = 20;
    let data = ndarray::array![[12.5, -34.0]];
    let centers = ndarray::array![[40.0, 10.0]];

    let mat = spherical_wahba_kernel_matrix_with_kind(
        data.view(),
        centers.view(),
        m_penalty,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
    )
    .expect("kernel matrix");
    let from_matrix = mat[(0, 0)];

    let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
    let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
    let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
    let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
    let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);

    assert!(
        (from_matrix - expected).abs() < 1e-13,
        "matrix helper differs from scalar evaluator: {} vs {}",
        from_matrix,
        expected
    );
}
/// The constrained penalty of a symmetric 6x6 matrix must be 5x5, symmetric
/// to within 1e-13, and finite when applied to a vector of ones.
#[test]
fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
    let m = 6;
    // Symmetric test matrix: exp(-|i - j| / 2).
    let c = Array2::<f64>::from_shape_fn((m, m), |(i, j)| {
        let d = (i as f64 - j as f64).abs();
        (-0.5 * d).exp()
    });
    let w = vec![1.0_f64; m];
    let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
    assert_eq!(s.dim(), (m - 1, m - 1));

    let reduced = m - 1;
    let max_asym = (0..reduced)
        .flat_map(|i| (0..reduced).map(move |j| (i, j)))
        .map(|(i, j)| (s[(i, j)] - s[(j, i)]).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_asym < 1e-13,
        "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
    );

    let ones = ndarray::Array1::<f64>::ones(reduced);
    let sx = s.dot(&ones);
    assert!(sx.iter().all(|v| v.is_finite()));
}
/// Applying the reflector built from `w` to `w` itself must leave only the
/// first component non-zero.
#[test]
fn householder_reflector_zeroes_target_vector() {
    let w = vec![3.0, 4.0, 0.0, -1.0];
    let (v, beta) = householder_reflector_from_weights(&w);
    // H w = w - beta (v . w) v
    let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
    let scale = beta * dot;
    let hw: Vec<f64> = w.iter().zip(&v).map(|(wj, vj)| wj - scale * vj).collect();
    for entry in &hw[1..] {
        assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
    }
    assert!(hw[0].abs() > 0.0);
}
/// Compares the GPU raw kernel matrix with the CPU implementation on a small
/// grid. The test returns early (without failing) when no CUDA runtime or
/// backend is available.
#[test]
fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
    let Some(_runtime) = super::super::runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
        return;
    };
    if let Err(err) = SphereGpuBackend::probe() {
        eprintln!("[sphere_gpu test] backend probe failed: {err}");
        return;
    }

    let data_ll = small_latlon_grid(7, 9);
    let centers_ll = small_latlon_grid(5, 7);
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
    let n = data_ll.nrows();
    let m = centers_ll.nrows();
    let penalty = 2usize;
    let lmax = 20usize;
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);

    let inputs = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
    let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");

    let cpu = spherical_wahba_kernel_matrix_with_kind(
        data_ll.view(),
        centers_ll.view(),
        penalty,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
    )
    .expect("cpu kernel matrix");

    // Largest element-wise absolute difference between the two matrices.
    let mut max_abs = 0.0_f64;
    for i in 0..n {
        for j in 0..m {
            let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
            if d > max_abs {
                max_abs = d;
            }
        }
    }
    assert!(
        max_abs < 1e-11,
        "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
    );
}
/// Builds a spherical spline basis large enough to be GPU-eligible and checks
/// its dense design against a CPU reference (raw kernel matrix times the
/// sum-to-zero transform). Returns early when no CUDA runtime or backend is
/// available.
#[test]
fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
    use crate::basis::{
        CenterStrategy, SphereMethod, SphericalSplineBasisSpec, build_spherical_spline_basis,
        sobolev_s2_truncated_coefficients,
    };

    let Some(_runtime) = super::super::runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu test] backend probe failed — skipping");
        return;
    }
    let _ = sobolev_s2_truncated_coefficients(1, 1);

    let data = small_latlon_grid(100, 100);
    let lmax: u16 = 30;
    let penalty_order = 2usize;

    // Path under test: the basis builder.
    let spec_gpu = SphericalSplineBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
        penalty_order,
        double_penalty: false,
        radians: false,
        method: SphereMethod::Wahba,
        max_degree: None,
        wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
    };
    let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
        .expect("GPU-eligible build_spherical_spline_basis succeeds");

    // Reference path: CPU kernel matrix followed by the constraint transform.
    let centers = crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
        .expect("centers");
    let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
        data.view(),
        centers.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax },
    )
    .expect("cpu raw design");
    let weights = crate::basis::sphere_area_weights(centers.view(), false);
    let z = crate::basis::weighted_coefficient_sum_to_zero_transform(weights.view()).expect("z");
    let cpu_design = raw_cpu.dot(&z);

    let gpu_design = result_gpu.design.as_dense().expect("dense design").clone();
    assert_eq!(gpu_design.dim(), cpu_design.dim());

    // Track the worst absolute and relative element-wise differences.
    let (max_abs, max_rel) = gpu_design.iter().zip(cpu_design.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(worst_abs, worst_rel), (g, c)| {
            let d = (g - c).abs();
            let denom = g.abs().max(c.abs()).max(1e-300);
            (worst_abs.max(d), worst_rel.max(d / denom))
        },
    );
    assert!(
        max_rel < 1e-9,
        "end-to-end design parity max relative |Δ| = {max_rel:.3e} >= 1e-9 (abs {max_abs:.3e})"
    );
}
/// Householder parity: the design produced by
/// `build_householder_constrained_design_device` must match a host-side
/// reference computed from the raw device kernel matrix `b` as
/// `xs[i, j] = b[i, j + 1] - beta * (Σ_k v[k] * b[i, k]) * v[j + 1]`,
/// with max |Δ| < 1e-12. The reflector `(v, beta)` comes from
/// `householder_reflector_from_weights` applied to unit weights.
///
/// The test returns early when no CUDA runtime is available or the backend
/// probe fails.
#[test]
fn sphere_gpu_householder_parity_vs_raw_dot_z() {
    let Some(_runtime) = super::super::runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu test] backend probe failed — skipping");
        return;
    }

    let data_ll = small_latlon_grid(6, 8);
    let centers_ll = small_latlon_grid(4, 5);
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let n = data_ll.nrows();
    let m = centers_ll.nrows();
    let penalty = 2usize;
    let lmax = 15usize;
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);

    let inputs_raw = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };

    // Raw (unconstrained) kernel matrix, pulled back to the host.
    let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
    let b = b_dev.to_host_array().expect("dtoh raw");

    let w = vec![1.0_f64; m];
    let (v, beta) = householder_reflector_from_weights(&w);

    // Host reference: apply the reflector row by row and drop the first column.
    let mut xs_host = Array2::<f64>::zeros((n, m - 1));
    for i in 0..n {
        let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
        for j_out in 0..(m - 1) {
            xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
        }
    }

    let xs_dev =
        build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
    let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");

    let mut max_abs = 0.0_f64;
    for ((i, j), &expected) in xs_host.indexed_iter() {
        let actual = xs_gpu[(i, j)];
        let d = (expected - actual).abs();
        // `NaN > max_abs` is false, so NaN output must be rejected explicitly
        // or it would slip through the tolerance check below.
        assert!(
            d.is_finite(),
            "non-finite Householder difference at ({i}, {j}): host={expected:e} gpu={actual:e}"
        );
        max_abs = max_abs.max(d);
    }
    assert!(
        max_abs < 1e-12,
        "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
    );
}
/// Performance target: building the raw kernel matrix on the device
/// (including the device-to-host copy) must be at least 20× faster than
/// `spherical_wahba_kernel_matrix_with_kind` on the CPU for a 500 × 400
/// lat/lon grid (200 000 points), 200 farthest-point centres and `lmax = 50`.
///
/// One untimed warm-up device build runs before the timed one. The test
/// returns early when no CUDA runtime is available or the backend probe fails.
#[test]
fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
    let Some(_runtime) = super::super::runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
        return;
    }

    let n_lat = 500usize;
    let n_lon = 400usize;
    assert_eq!(n_lat * n_lon, 200_000);
    let data_ll = small_latlon_grid(n_lat, n_lon);
    let m = 200usize;
    let centers_ll =
        crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
            .expect("centers");
    let n = data_ll.nrows();
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let penalty_order = 2usize;
    // Keep the degree as `u16` for the CPU kernel enum and widen losslessly
    // for the GPU inputs rather than narrowing with `as`.
    let lmax_u16: u16 = 50;
    let lmax = usize::from(lmax_u16);
    let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);

    let inputs_warm = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };

    // Untimed warm-up build; its result is discarded immediately.
    let _ = build_kernel_matrix_device(inputs_warm.clone()).expect("warmup");

    // Timed GPU path: device build plus copy back to the host.
    let t0 = std::time::Instant::now();
    let dev = build_kernel_matrix_device(inputs_warm).expect("gpu kernel matrix");
    let host_gpu = dev.to_host_array().expect("dtoh");
    // Mark the result as observed so the timed work cannot be optimised away.
    let _ = std::hint::black_box(&host_gpu);
    let gpu_secs = t0.elapsed().as_secs_f64();

    // Timed CPU path.
    let t1 = std::time::Instant::now();
    let cpu = spherical_wahba_kernel_matrix_with_kind(
        data_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 },
    )
    .expect("cpu kernel matrix");
    let _ = std::hint::black_box(&cpu);
    let cpu_secs = t1.elapsed().as_secs_f64();

    // Floor the GPU time so a sub-nanosecond reading cannot divide by zero.
    let ratio = cpu_secs / gpu_secs.max(1e-9);
    eprintln!(
        "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
    );
    assert!(
        ratio >= 20.0,
        "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
         n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
    );
}
/// Performance target: a full `build_spherical_spline_basis` call (Wahba
/// method, Sobolev-truncated kernel, 200 farthest-point centres, `lmax = 50`)
/// on a 500 × 400 lat/lon grid must be at least 10× faster than the CPU
/// reference `raw kernel · z`.
///
/// The timed CPU region covers only the raw kernel matrix and the product
/// with `z`; centre selection, area weights and the transform `z` are computed
/// before the CPU timer starts. One untimed warm-up build runs first. The test
/// returns early when no CUDA runtime is available or the backend probe fails.
#[test]
fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
    use crate::basis::{
        CenterStrategy, SphereMethod, SphericalSplineBasisSpec, build_spherical_spline_basis,
    };

    let Some(_runtime) = super::super::runtime::GpuRuntime::global() else {
        eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
        return;
    };
    if SphereGpuBackend::probe().is_err() {
        eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
        return;
    }

    let n_lat = 500usize;
    let n_lon = 400usize;
    let data_ll = small_latlon_grid(n_lat, n_lon);
    let m: usize = 200;
    let lmax: u16 = 50;
    // Shared by the spec and the CPU reference so the two cannot drift apart.
    let penalty_order: usize = 2;

    let spec_gpu = SphericalSplineBasisSpec {
        center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
        penalty_order,
        double_penalty: false,
        radians: false,
        method: SphereMethod::Wahba,
        max_degree: None,
        wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
    };

    // Untimed warm-up build; its result is discarded immediately.
    let _ = build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build");

    // Timed build. The result is passed through `black_box` so the work counts
    // as observed, then dropped inside the timed region as before.
    let t0 = std::time::Instant::now();
    drop(std::hint::black_box(
        build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"),
    ));
    let gpu_secs = t0.elapsed().as_secs_f64();

    // Untimed CPU-side setup.
    let centers =
        crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
            .expect("centers");
    let weights = crate::basis::sphere_area_weights(centers.view(), false);
    let z =
        crate::basis::weighted_coefficient_sum_to_zero_transform(weights.view()).expect("z");

    // Timed CPU reference: raw kernel matrix followed by the product with `z`.
    let t1 = std::time::Instant::now();
    let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
        data_ll.view(),
        centers.view(),
        penalty_order,
        false,
        SphereWahbaKernel::SobolevTruncated { lmax },
    )
    .expect("cpu raw");
    let design_cpu = raw_cpu.dot(&z);
    let _ = std::hint::black_box(&design_cpu);
    let cpu_secs = t1.elapsed().as_secs_f64();

    // Floor the GPU time so a sub-nanosecond reading cannot divide by zero.
    let ratio = cpu_secs / gpu_secs.max(1e-9);
    eprintln!(
        "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
        data_ll.nrows()
    );
    assert!(
        ratio >= 10.0,
        "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
         cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
    );
}
/// Fit parity: solving the same penalised least-squares problem with the
/// CPU-built and the GPU-built raw design must give coefficients and fitted
/// values that agree to within 1e-9 (max absolute difference).
///
/// Setup: 25 × 40 lat/lon grid (1000 points), 80 farthest-point centres,
/// Sobolev-truncated kernel with `lmax = 15`, penalty order 2, `λ = 1e-3`.
/// Both designs are projected with the same weighted sum-to-zero transform `z`
/// and solved via Cholesky of `XᵀX + λ·zᵀ K_cc z`.
///
/// Unlike the sibling tests, this one panics rather than skips when the CUDA
/// runtime or the backend probe is unavailable.
/// NOTE(review): the `expect` message says the test is `#[ignore]`d off CUDA,
/// but no `#[ignore]` attribute is visible on this function — confirm that the
/// gating exists elsewhere or add it.
#[test]
fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
    use crate::basis::{
        select_spherical_farthest_point_centers, sphere_area_weights,
        spherical_wahba_kernel_matrix_with_kind, weighted_coefficient_sum_to_zero_transform,
    };
    use crate::linalg::faer_ndarray::FaerCholesky;
    use faer::Side;

    let _runtime = super::super::runtime::GpuRuntime::global()
        .expect("task #25 parity requires CUDA runtime (test is #[ignore]d off CUDA)");
    SphereGpuBackend::probe()
        .expect("task #25 parity requires sphere GPU backend probe to succeed");

    let data_ll = small_latlon_grid(25, 40);
    assert_eq!(data_ll.nrows(), 1000);
    let n = data_ll.nrows();
    let m: usize = 80;
    let lmax_u16: u16 = 15;
    // Lossless widening; avoids an `as` cast.
    let lmax = usize::from(lmax_u16);
    let penalty_order: usize = 2;
    let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
    let lambda: f64 = 1.0e-3;

    let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
        .expect("farthest-point centers");
    assert_eq!(centers_ll.nrows(), m);
    let weights = sphere_area_weights(centers_ll.view(), false);
    let z = weighted_coefficient_sum_to_zero_transform(weights.view())
        .expect("weighted sum-to-zero transform");
    let p = z.ncols();
    assert_eq!(p, m - 1);

    // Penalty matrix in the constrained coordinates: zᵀ · K(centres, centres) · z.
    let k_cc = spherical_wahba_kernel_matrix_with_kind(
        centers_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        kernel,
    )
    .expect("centers×centers kernel");
    let s_full = z.t().dot(&k_cc).dot(&z);

    // CPU design.
    let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
        data_ll.view(),
        centers_ll.view(),
        penalty_order,
        false,
        kernel,
    )
    .expect("CPU raw design");
    let x_s_cpu = raw_design_cpu.dot(&z);

    // GPU design: raw kernel built on the device, projected with the same `z`
    // on the host.
    let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
    let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
    let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
    let inputs = S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind: SphereSpectralKernelKind::Sobolev,
        layout: DeviceMatrixLayout::ColumnMajor,
    };
    let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
    let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
    let x_s_gpu = raw_design_gpu.dot(&z);
    assert_eq!(x_s_cpu.dim(), (n, p));
    assert_eq!(x_s_gpu.dim(), (n, p));

    // Smooth synthetic response evaluated at the data points (grid is in degrees).
    let mut y = ndarray::Array1::<f64>::zeros(n);
    for i in 0..n {
        let lat_rad = data_ll[(i, 0)].to_radians();
        let lon_rad = data_ll[(i, 1)].to_radians();
        y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
            + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
    }

    // Solve (XᵀX + λ S) β = Xᵀ y by Cholesky.
    let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
        let mut a = x_s.t().dot(x_s);
        a.scaled_add(lambda, &s_full);
        let rhs = x_s.t().dot(&y);
        let factor = a
            .cholesky(Side::Lower)
            .expect("penalised normal equations are SPD under λ > 0");
        factor.solvevec(&rhs)
    };
    let beta_cpu = solve_penalised(&x_s_cpu);
    let beta_gpu = solve_penalised(&x_s_gpu);
    assert_eq!(beta_cpu.len(), p);
    assert_eq!(beta_gpu.len(), p);
    let yhat_cpu = x_s_cpu.dot(&beta_cpu);
    let yhat_gpu = x_s_gpu.dot(&beta_gpu);
    assert_eq!(yhat_cpu.len(), n);
    assert_eq!(yhat_gpu.len(), n);

    // Max absolute difference that refuses NaN/∞: a NaN difference compares
    // false against every threshold and would otherwise be silently ignored.
    let max_abs_delta =
        |cpu: &ndarray::Array1<f64>, gpu: &ndarray::Array1<f64>, what: &str| -> f64 {
            cpu.iter().zip(gpu.iter()).fold(0.0_f64, |acc, (c, g)| {
                let d = (c - g).abs();
                assert!(
                    d.is_finite(),
                    "non-finite {what} difference: cpu={c:e} gpu={g:e}"
                );
                acc.max(d)
            })
        };
    let max_beta_delta = max_abs_delta(&beta_cpu, &beta_gpu, "coefficient");
    let max_fit_delta = max_abs_delta(&yhat_cpu, &yhat_gpu, "fitted-value");

    eprintln!(
        "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
         max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
    );
    assert!(
        max_beta_delta <= 1.0e-9,
        "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > 1e-9"
    );
    assert!(
        max_fit_delta <= 1.0e-9,
        "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
    );
}
}