1use std::sync::OnceLock;
20
21use ndarray::{Array2, ArrayView2};
22
23use gam_gpu::gpu_error::GpuError;
24#[cfg(target_os = "linux")]
25use gam_gpu::gpu_error::GpuResultExt;
26use gam_gpu::{GpuDecision, GpuKernel, decide};
27
28#[cfg(target_os = "linux")]
29use std::collections::HashMap;
30#[cfg(target_os = "linux")]
31use std::sync::{Arc, Mutex};
32
33#[cfg(target_os = "linux")]
34use cudarc::driver::{CudaContext, CudaModule, CudaSlice, CudaStream};
35
36#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
40pub enum SphereSpectralKernelKind {
41 Sobolev,
43 Pseudo,
45}
46
47impl SphereSpectralKernelKind {
48 pub fn coefficients(self, lmax: usize, m: usize) -> Vec<f64> {
52 match self {
53 SphereSpectralKernelKind::Sobolev => {
54 crate::basis::sobolev_s2_truncated_coefficients(lmax, m)
55 }
56 SphereSpectralKernelKind::Pseudo => {
57 crate::basis::pseudo_s2_truncated_coefficients(lmax, m)
58 }
59 }
60 }
61
62 pub const fn tag(self) -> &'static str {
64 match self {
65 SphereSpectralKernelKind::Sobolev => "sobolev",
66 SphereSpectralKernelKind::Pseudo => "pseudo",
67 }
68 }
69}
70
/// Memory layout of a matrix produced on the device.
///
/// Only one layout exists at present; the enum is carried in build inputs
/// and in the module cache key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceMatrixLayout {
    /// Elements of one column are contiguous.
    ColumnMajor,
}
78
79pub fn latlon_to_xyz_host(latlon: ArrayView2<'_, f64>, radians: bool) -> Result<Vec<f64>, String> {
86 if latlon.ncols() != 2 {
87 return Err(format!(
88 "latlon_to_xyz_host: expected (_, 2) lat/lon matrix, got shape {:?}",
89 latlon.shape()
90 ));
91 }
92 let deg = if radians {
93 1.0
94 } else {
95 std::f64::consts::PI / 180.0
96 };
97 let n = latlon.nrows();
98 let mut out = Vec::with_capacity(3 * n);
99 for row in latlon.outer_iter() {
100 let lat = row[0] * deg;
101 let lon = row[1] * deg;
102 let (s_lat, c_lat) = lat.sin_cos();
103 let (s_lon, c_lon) = lon.sin_cos();
104 out.push(c_lat * c_lon);
106 out.push(c_lat * s_lon);
107 out.push(s_lat);
108 }
109 Ok(out)
110}
111
112#[cfg(target_os = "linux")]
119pub struct DeviceS2KernelMatrix {
120 pub rows: usize,
121 pub cols: usize,
122 pub ld: usize,
123 pub col_major_dev: CudaSlice<f64>,
124 pub stream: Arc<CudaStream>,
125}
126
127#[cfg(not(target_os = "linux"))]
128pub struct DeviceS2KernelMatrix {
129 pub rows: usize,
130 pub cols: usize,
131 pub ld: usize,
132 pub col_major_dev: Vec<f64>,
134}
135
136impl DeviceS2KernelMatrix {
137 pub fn to_host_array(&self) -> Result<Array2<f64>, GpuError> {
141 let mut col_major = vec![0.0_f64; self.ld * self.cols];
142 self.copy_to_host_col_major(&mut col_major)?;
143 let mut out = Array2::<f64>::zeros((self.rows, self.cols));
144 for j in 0..self.cols {
145 for i in 0..self.rows {
146 out[(i, j)] = col_major[j * self.ld + i];
147 }
148 }
149 Ok(out)
150 }
151
152 #[cfg(target_os = "linux")]
157 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
158 let needed = self.ld * self.cols;
159 if dst.len() != needed {
160 gam_gpu::gpu_bail!(
161 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
162 dst.len(),
163 needed
164 );
165 }
166 self.stream
167 .memcpy_dtoh(&self.col_major_dev, dst)
168 .gpu_ctx("DeviceS2KernelMatrix dtoh")?;
169 self.stream
170 .synchronize()
171 .gpu_ctx("DeviceS2KernelMatrix synchronize")?;
172 Ok(())
173 }
174
175 #[cfg(not(target_os = "linux"))]
176 pub fn copy_to_host_col_major(&self, dst: &mut [f64]) -> Result<(), GpuError> {
177 let needed = self.ld * self.cols;
178 if dst.len() != needed {
179 gam_gpu::gpu_bail!(
180 "DeviceS2KernelMatrix::copy_to_host_col_major: dst.len()={} expected {}",
181 dst.len(),
182 needed
183 );
184 }
185 dst.copy_from_slice(&self.col_major_dev);
186 Ok(())
187 }
188}
189
190#[derive(Clone, Debug)]
201pub struct S2KernelBuildInputs<'a> {
202 pub n: usize,
203 pub m: usize,
204 pub lmax: usize,
205 pub data_xyz: &'a [f64],
206 pub centers_xyz: &'a [f64],
207 pub coeffs: &'a [f64],
208 pub kind: SphereSpectralKernelKind,
209 pub layout: DeviceMatrixLayout,
210}
211
212impl<'a> S2KernelBuildInputs<'a> {
213 fn validate(&self) -> Result<(), GpuError> {
214 if self.lmax == 0 {
215 return Err(GpuError::DriverCallFailed {
216 reason: "S2KernelBuildInputs: lmax must be >= 1".into(),
217 });
218 }
219 if self.data_xyz.len() != 3 * self.n {
220 gam_gpu::gpu_bail!(
221 "S2KernelBuildInputs: data_xyz.len()={} != 3*n={}",
222 self.data_xyz.len(),
223 3 * self.n
224 );
225 }
226 if self.centers_xyz.len() != 3 * self.m {
227 gam_gpu::gpu_bail!(
228 "S2KernelBuildInputs: centers_xyz.len()={} != 3*m={}",
229 self.centers_xyz.len(),
230 3 * self.m
231 );
232 }
233 if self.coeffs.len() != self.lmax + 1 {
234 gam_gpu::gpu_bail!(
235 "S2KernelBuildInputs: coeffs.len()={} != lmax+1={}",
236 self.coeffs.len(),
237 self.lmax + 1
238 );
239 }
240 if self.coeffs[0] != 0.0 {
241 return Err(GpuError::DriverCallFailed {
242 reason: "S2KernelBuildInputs: coeffs[0] must be 0 (mean-zero kernel)".into(),
243 });
244 }
245 Ok(())
246 }
247}
248
/// CUDA C source for the sphere kernels.
///
/// The text is not compiled as-is: `SphereGpuBackend::module_for` prepends a
/// `#define LMAX <lmax>` line and hands the result to NVRTC. It exports two
/// entry points, `s2_wahba_legendre_colmajor` (plain kernel matrix) and
/// `s2_wahba_householder_constrained_colmajor` (kernel matrix with a
/// Householder reflector applied and the first column dropped).
///
/// Both entry points share two device helpers so the clamped cosine and the
/// Legendre recurrence exist once. Coordinate offsets (`3 * index`) are formed
/// in 64-bit arithmetic because the host admits `n` and `m` up to `i32::MAX`,
/// for which a 32-bit `3 * index` would wrap.
#[cfg(target_os = "linux")]
const KERNEL_TEMPLATE: &str = r#"
// LMAX is supplied by the host via a `#define LMAX ...` prepended to
// this source before NVRTC compilation (see `SphereGpuBackend::module_for`).

// t = clamp(x_i . c_j, -1, +1) for the data point (xi, yi, zi) and center j.
// The center offset is computed in 64-bit so 3 * j cannot overflow int.
static __device__ __forceinline__
double s2_clamped_cosine(
    double xi,
    double yi,
    double zi,
    const double* __restrict__ centers_xyz,
    int j
) {
    const long long cj = 3LL * (long long) j;
    const double cxj = centers_xyz[cj + 0];
    const double cyj = centers_xyz[cj + 1];
    const double czj = centers_xyz[cj + 2];

    double t = fma(xi, cxj, fma(yi, cyj, zi * czj));
    if (t > 1.0) t = 1.0;
    if (t < -1.0) t = -1.0;
    return t;
}

// sum_{ell = 0}^{LMAX} coeffs[ell] * P_ell(t), with the Legendre 3-term
// recurrence held in registers. P_0(t) = 1, P_1(t) = t. Requires LMAX >= 1.
static __device__ __forceinline__
double s2_legendre_series(double t, const double* __restrict__ coeffs) {
    double p_prev = 1.0;
    double p_curr = t;
    double acc = coeffs[0] * p_prev + coeffs[1] * p_curr;

    #pragma unroll 8
    for (int ell = 1; ell < LMAX; ++ell) {
        const double lf = (double) ell;
        const double inv = 1.0 / (lf + 1.0);
        // p_{ell+1} = ((2ell+1) * t * p_curr - ell * p_prev) / (ell+1)
        const double p_next =
            fma((2.0 * lf + 1.0) * t, p_curr, -lf * p_prev) * inv;
        acc = fma(coeffs[ell + 1], p_next, acc);
        p_prev = p_curr;
        p_curr = p_next;
    }
    return acc;
}

// Plain kernel matrix: out[i, j] = K(x_i, c_j), one thread per entry.
// Rows are indexed by the grid's y dimension, centers by its x dimension.
extern "C" __global__
__launch_bounds__(256)
void s2_wahba_legendre_colmajor(
    const double* __restrict__ data_xyz,     // n x 3 (row-major flat)
    const double* __restrict__ centers_xyz,  // m x 3 (row-major flat)
    const double* __restrict__ coeffs,       // length LMAX + 1, coeffs[0] = 0
    int n,
    int m,
    long long ld,
    double* __restrict__ out                 // ld x m column-major
) {
    const int i = blockIdx.y * blockDim.y + threadIdx.y;
    const int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n || j >= m) return;

    // Load (x_i, y_i, z_i) into registers; 64-bit offset, see file header.
    const long long ri = 3LL * (long long) i;
    const double xi = data_xyz[ri + 0];
    const double yi = data_xyz[ri + 1];
    const double zi = data_xyz[ri + 2];

    const double t = s2_clamped_cosine(xi, yi, zi, centers_xyz, j);
    out[(long long) j * ld + (long long) i] = s2_legendre_series(t, coeffs);
}

// Fused Householder-constrained kernel (Phase 3). Z = I - beta * v * v^T,
// the constrained design is X_s = B[:, 1..m] - beta * (B * v) * v[1..m]^T,
// i.e. drop the first column after applying Z.
//
// Each thread owns one row i and makes two passes over the centers, so no
// per-row buffer of B is needed:
//   pass 1 evaluates all m kernel values to form d_i = B_row . v;
//   pass 2 re-evaluates columns 1..m-1 and emits
//          X_s[i, j_out] = B_row[j_out + 1] - beta * d_i * v[j_out + 1].
//
// Grid: 1D over rows (blockDim.x rows per block).
extern "C" __global__
__launch_bounds__(128)
void s2_wahba_householder_constrained_colmajor(
    const double* __restrict__ data_xyz,     // n x 3
    const double* __restrict__ centers_xyz,  // m x 3
    const double* __restrict__ coeffs,       // length LMAX + 1
    const double* __restrict__ v,            // length m, Householder vector
    double beta,
    int n,
    int m,
    long long ld_out,
    double* __restrict__ out                 // ld_out x (m-1) column-major
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const long long ri = 3LL * (long long) i;
    const double xi = data_xyz[ri + 0];
    const double yi = data_xyz[ri + 1];
    const double zi = data_xyz[ri + 2];

    // Pass 1: compute d_i = sum_j v[j] * B[i, j].
    double d_i = 0.0;
    for (int j = 0; j < m; ++j) {
        const double t = s2_clamped_cosine(xi, yi, zi, centers_xyz, j);
        d_i = fma(v[j], s2_legendre_series(t, coeffs), d_i);
    }

    // Pass 2: emit X_s[i, j_out] = B[i, j_out+1] - beta * d_i * v[j_out+1].
    const double bd = beta * d_i;
    for (int j_out = 0; j_out < m - 1; ++j_out) {
        const int j = j_out + 1;
        const double t = s2_clamped_cosine(xi, yi, zi, centers_xyz, j);
        const double xs = s2_legendre_series(t, coeffs) - bd * v[j];
        out[(long long) j_out * ld_out + (long long) i] = xs;
    }
}
"#;
396
/// Key under which a compiled sphere kernel module is cached.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct S2ModuleCacheKey {
    /// Major compute capability of the device the module is built for.
    pub cc_major: i32,
    /// Minor compute capability of the device the module is built for.
    pub cc_minor: i32,
    /// Truncation degree substituted for `LMAX` in the kernel source.
    pub lmax: u32,
    /// Kernel family the module was requested for.
    pub kind: SphereSpectralKernelKind,
    /// Output layout the module was requested for.
    pub layout: DeviceMatrixLayout,
}
414
/// Reports whether this build contains the sphere GPU backend, which is
/// compiled only when the target operating system is Linux.
pub const fn sphere_gpu_compiled() -> bool {
    #[cfg(target_os = "linux")]
    {
        true
    }
    #[cfg(not(target_os = "linux"))]
    {
        false
    }
}
420
421#[must_use]
428pub fn sphere_kernel_decision(n: usize, m: usize, lmax: usize) -> GpuDecision {
429 let large_enough = if let Some(runtime) = gam_gpu::device_runtime::GpuRuntime::global() {
430 let ld = ((n + 31) / 32) * 32;
431 let needed_bytes = ld
432 .saturating_mul(m)
433 .saturating_mul(std::mem::size_of::<f64>());
434 let budget = runtime.memory_budget_bytes;
435 n.saturating_mul(m) >= 1_000_000 && lmax <= 200 && needed_bytes <= budget
436 } else {
437 false
438 };
439 decide(
440 GpuKernel::SpatialKernelOperator,
441 gam_gpu::GpuEligibility::from_flags(sphere_gpu_compiled(), large_enough),
442 )
443}
444
445#[cfg(target_os = "linux")]
446struct SphereGpuContext {
447 ctx: Arc<CudaContext>,
448 stream: Arc<CudaStream>,
449 modules: Mutex<HashMap<S2ModuleCacheKey, Arc<CudaModule>>>,
450 cc_major: i32,
451 cc_minor: i32,
452}
453
454pub struct SphereGpuBackend {
457 #[cfg(target_os = "linux")]
458 inner: SphereGpuContext,
459}
460
461impl SphereGpuBackend {
462 pub fn probe() -> Result<&'static Self, GpuError> {
464 static BACKEND: OnceLock<Result<SphereGpuBackend, GpuError>> = OnceLock::new();
465 BACKEND
466 .get_or_init(|| {
467 #[cfg(target_os = "linux")]
468 {
469 Self::probe_linux()
470 }
471 #[cfg(not(target_os = "linux"))]
472 {
473 Err(GpuError::DriverLibraryUnavailable {
474 reason: "sphere GPU backend is Linux-only".to_string(),
475 })
476 }
477 })
478 .as_ref()
479 .map_err(GpuError::clone)
480 }
481
482 #[cfg(target_os = "linux")]
483 fn probe_linux() -> Result<Self, GpuError> {
484 let parts = gam_gpu::backend_probe::probe_cuda_backend("sphere")?;
485 Ok(SphereGpuBackend {
486 inner: SphereGpuContext {
487 ctx: parts.ctx,
488 stream: parts.stream,
489 modules: Mutex::new(HashMap::new()),
490 cc_major: parts.capability.compute_major,
491 cc_minor: parts.capability.compute_minor,
492 },
493 })
494 }
495
496 #[cfg(target_os = "linux")]
499 fn module_for(&self, key: S2ModuleCacheKey) -> Result<Arc<CudaModule>, GpuError> {
500 if let Ok(guard) = self.inner.modules.lock() {
501 if let Some(existing) = guard.get(&key) {
502 return Ok(existing.clone());
503 }
504 }
505 let src = format!("#define LMAX {}\n{}", key.lmax, KERNEL_TEMPLATE);
512 let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
513 format!(
514 "sphere NVRTC compile (kind={}, lmax={}): {err}",
515 key.kind.tag(),
516 key.lmax
517 )
518 })?;
519 let module = self
520 .inner
521 .ctx
522 .load_module(ptx)
523 .gpu_ctx("sphere module load")?;
524 if let Ok(mut guard) = self.inner.modules.lock() {
525 guard.entry(key).or_insert_with(|| module.clone());
526 }
527 Ok(module)
528 }
529
530 #[cfg(target_os = "linux")]
531 fn cc(&self) -> (i32, i32) {
532 (self.inner.cc_major, self.inner.cc_minor)
533 }
534}
535
536pub fn build_kernel_matrix_device(
543 inputs: S2KernelBuildInputs<'_>,
544) -> Result<DeviceS2KernelMatrix, GpuError> {
545 inputs.validate()?;
546
547 #[cfg(target_os = "linux")]
548 {
549 use cudarc::driver::{LaunchConfig, PushKernelArg};
550 let backend = SphereGpuBackend::probe()?;
551 let (cc_major, cc_minor) = backend.cc();
552 let key = S2ModuleCacheKey {
553 cc_major,
554 cc_minor,
555 lmax: inputs.lmax as u32,
556 kind: inputs.kind,
557 layout: inputs.layout,
558 };
559 let module = backend.module_for(key)?;
560 let func = module
561 .load_function("s2_wahba_legendre_colmajor")
562 .gpu_ctx("sphere load_function raw")?;
563 let stream = backend.inner.stream.clone();
564
565 let data_dev = stream
566 .clone_htod(inputs.data_xyz)
567 .gpu_ctx("sphere htod data_xyz")?;
568 let centers_dev = stream
569 .clone_htod(inputs.centers_xyz)
570 .gpu_ctx("sphere htod centers_xyz")?;
571 let coeffs_dev = stream
572 .clone_htod(inputs.coeffs)
573 .gpu_ctx("sphere htod coeffs")?;
574
575 let n = inputs.n;
576 let m = inputs.m;
577 let ld = ((n + 31) / 32) * 32;
578 let mut out_dev = stream
579 .alloc_zeros::<f64>(ld * m)
580 .gpu_ctx_with(|err| format!("sphere alloc out (ld={ld}, m={m}): {err}"))?;
581
582 let block_x: u32 = 32;
584 let block_y: u32 = 8;
585 let grid_x: u32 = ((m as u32) + block_x - 1) / block_x;
586 let grid_y: u32 = ((n as u32) + block_y - 1) / block_y;
587 let cfg = LaunchConfig {
588 grid_dim: (grid_x, grid_y, 1),
589 block_dim: (block_x, block_y, 1),
590 shared_mem_bytes: 0,
591 };
592 let n_i32: i32 =
593 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere n={n} overflows i32"))?;
594 let m_i32: i32 =
595 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere m={m} overflows i32"))?;
596 let ld_i64: i64 = ld as i64;
597
598 let mut builder = stream.launch_builder(&func);
599 builder
600 .arg(&data_dev)
601 .arg(¢ers_dev)
602 .arg(&coeffs_dev)
603 .arg(&n_i32)
604 .arg(&m_i32)
605 .arg(&ld_i64)
606 .arg(&mut out_dev);
607 unsafe { builder.launch(cfg) }.gpu_ctx("sphere raw kernel launch")?;
612 stream
613 .synchronize()
614 .gpu_ctx("sphere raw kernel synchronize")?;
615
616 Ok(DeviceS2KernelMatrix {
617 rows: n,
618 cols: m,
619 ld,
620 col_major_dev: out_dev,
621 stream,
622 })
623 }
624
625 #[cfg(not(target_os = "linux"))]
626 {
627 Err(GpuError::DriverLibraryUnavailable {
628 reason: "sphere GPU backend is Linux-only".to_string(),
629 })
630 }
631}
632
633pub fn build_householder_constrained_design_device(
637 inputs: S2KernelBuildInputs<'_>,
638 v: &[f64],
639 beta: f64,
640) -> Result<DeviceS2KernelMatrix, GpuError> {
641 inputs.validate()?;
642 if v.len() != inputs.m {
643 gam_gpu::gpu_bail!(
644 "build_householder_constrained_design_device: v.len()={} != m={}",
645 v.len(),
646 inputs.m
647 );
648 }
649 if inputs.m < 2 {
650 gam_gpu::gpu_bail!(
651 "build_householder_constrained_design_device: m must be >= 2 (got {})",
652 inputs.m
653 );
654 }
655 if !beta.is_finite() {
656 gam_gpu::gpu_bail!(
657 "build_householder_constrained_design_device: beta must be finite (got {beta})"
658 );
659 }
660
661 #[cfg(target_os = "linux")]
662 {
663 use cudarc::driver::{LaunchConfig, PushKernelArg};
664 let backend = SphereGpuBackend::probe()?;
665 let (cc_major, cc_minor) = backend.cc();
666 let key = S2ModuleCacheKey {
667 cc_major,
668 cc_minor,
669 lmax: inputs.lmax as u32,
670 kind: inputs.kind,
671 layout: inputs.layout,
672 };
673 let module = backend.module_for(key)?;
674 let func = module
675 .load_function("s2_wahba_householder_constrained_colmajor")
676 .gpu_ctx("sphere load_function householder")?;
677 let stream = backend.inner.stream.clone();
678
679 let data_dev = stream
680 .clone_htod(inputs.data_xyz)
681 .gpu_ctx("sphere-hh htod data_xyz")?;
682 let centers_dev = stream
683 .clone_htod(inputs.centers_xyz)
684 .gpu_ctx("sphere-hh htod centers_xyz")?;
685 let coeffs_dev = stream
686 .clone_htod(inputs.coeffs)
687 .gpu_ctx("sphere-hh htod coeffs")?;
688 let v_dev = stream.clone_htod(v).gpu_ctx("sphere-hh htod v")?;
689
690 let n = inputs.n;
691 let m = inputs.m;
692 let cols_out = m - 1;
693 let ld_out = ((n + 31) / 32) * 32;
694 let mut out_dev = stream
695 .alloc_zeros::<f64>(ld_out * cols_out)
696 .gpu_ctx_with(|err| {
697 format!("sphere-hh alloc out (ld={ld_out}, cols={cols_out}): {err}")
698 })?;
699
700 let block_x: u32 = 128;
701 let grid_x: u32 = ((n as u32) + block_x - 1) / block_x;
702 let cfg = LaunchConfig {
703 grid_dim: (grid_x, 1, 1),
704 block_dim: (block_x, 1, 1),
705 shared_mem_bytes: 0,
706 };
707 let n_i32: i32 =
708 i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sphere-hh n={n} overflows i32"))?;
709 let m_i32: i32 =
710 i32::try_from(m).map_err(|_| gam_gpu::gpu_err!("sphere-hh m={m} overflows i32"))?;
711 let ld_out_i64: i64 = ld_out as i64;
712
713 let mut builder = stream.launch_builder(&func);
714 builder
715 .arg(&data_dev)
716 .arg(¢ers_dev)
717 .arg(&coeffs_dev)
718 .arg(&v_dev)
719 .arg(&beta)
720 .arg(&n_i32)
721 .arg(&m_i32)
722 .arg(&ld_out_i64)
723 .arg(&mut out_dev);
724 unsafe { builder.launch(cfg) }.gpu_ctx("sphere-hh kernel launch")?;
727 stream
728 .synchronize()
729 .gpu_ctx("sphere-hh kernel synchronize")?;
730
731 Ok(DeviceS2KernelMatrix {
732 rows: n,
733 cols: cols_out,
734 ld: ld_out,
735 col_major_dev: out_dev,
736 stream,
737 })
738 }
739
740 #[cfg(not(target_os = "linux"))]
741 {
742 Err(GpuError::DriverLibraryUnavailable {
743 reason: "sphere GPU backend is Linux-only".to_string(),
744 })
745 }
746}
747
/// Builds the Householder reflector `H = I − beta · v · vᵀ` that maps the
/// weight vector `w` onto a multiple of the first coordinate axis.
///
/// The returned `v` is normalised so that `v[0] == 1`, and `beta = 2 / (vᵀv)`.
/// The sign of the reflection follows `w[0]`, so forming `v` never subtracts
/// two nearly equal numbers.
///
/// Degenerate inputs yield the identity reflector: an empty `w` returns
/// `(vec![], 0.0)` and an all-zero `w` returns `(vec![0.0; m], 0.0)`.
///
/// The Euclidean norm is accumulated with `f64::hypot` and `w` is scaled by
/// it before the reflector is formed. A plain `sqrt(Σ wᵢ²)` underflows to
/// zero for very small weights, which would wrongly give the identity, and
/// overflows to infinity for very large ones, which would give NaNs.
///
/// Non-finite weights are not handled specially and produce non-finite
/// output.
pub fn householder_reflector_from_weights(w: &[f64]) -> (Vec<f64>, f64) {
    if w.is_empty() {
        return (Vec::new(), 0.0);
    }
    // Overflow- and underflow-safe Euclidean norm.
    let norm = w.iter().fold(0.0_f64, |acc, &x| acc.hypot(x));
    if norm == 0.0 {
        return (vec![0.0; w.len()], 0.0);
    }

    // The reflector is invariant under scaling of `w`, so work with the unit
    // vector `u = w / norm`, whose entries all lie in [-1, 1].
    let sign = if w[0] >= 0.0 { 1.0 } else { -1.0 };
    // `w[0] / norm` has the same sign as `sign`, hence |pivot| >= 1 and the
    // divisions below are safe.
    let pivot = w[0] / norm + sign;
    let mut v: Vec<f64> = w.iter().map(|&x| (x / norm) / pivot).collect();
    v[0] = 1.0;

    let vv: f64 = v.iter().map(|x| x * x).sum();
    (v, 2.0 / vv)
}
785
786pub fn build_center_kernel_device(
805 centers_xyz: &[f64],
806 lmax: usize,
807 coeffs: &[f64],
808 kind: SphereSpectralKernelKind,
809) -> Result<DeviceS2KernelMatrix, GpuError> {
810 let m = centers_xyz.len() / 3;
811 if centers_xyz.len() != 3 * m {
812 return Err(GpuError::DriverCallFailed {
813 reason: "build_center_kernel_device: centers_xyz length not divisible by 3".into(),
814 });
815 }
816 let inputs = S2KernelBuildInputs {
817 n: m,
818 m,
819 lmax,
820 data_xyz: centers_xyz,
821 centers_xyz,
822 coeffs,
823 kind,
824 layout: DeviceMatrixLayout::ColumnMajor,
825 };
826 build_kernel_matrix_device(inputs)
827}
828
829pub fn constrained_penalty_host(
834 c: ArrayView2<'_, f64>,
835 w: &[f64],
836) -> Result<Array2<f64>, GpuError> {
837 let (m1, m2) = c.dim();
838 if m1 != m2 {
839 gam_gpu::gpu_bail!("constrained_penalty_host: C must be square, got {m1}x{m2}");
840 }
841 let m = m1;
842 if w.len() != m {
843 gam_gpu::gpu_bail!("constrained_penalty_host: w.len()={} != m={}", w.len(), m);
844 }
845 if m < 2 {
846 gam_gpu::gpu_bail!("constrained_penalty_host: m must be >= 2 (got {m})");
847 }
848 let (v, beta) = householder_reflector_from_weights(w);
849
850 let mut u = vec![0.0_f64; m];
853 for i in 0..m {
854 let mut acc = 0.0_f64;
855 for j in 0..m {
856 acc += c[(i, j)] * v[j];
857 }
858 u[i] = acc;
859 }
860 let vtcv: f64 = v.iter().zip(&u).map(|(vi, ui)| vi * ui).sum();
861 let mut hch = Array2::<f64>::zeros((m, m));
862 for i in 0..m {
863 for j in 0..m {
864 hch[(i, j)] =
865 c[(i, j)] - beta * (v[i] * u[j] + u[i] * v[j]) + beta * beta * vtcv * v[i] * v[j];
866 }
867 }
868 let mut s = Array2::<f64>::zeros((m - 1, m - 1));
870 for i in 0..(m - 1) {
871 for j in 0..(m - 1) {
872 s[(i, j)] = hch[(i + 1, j + 1)];
873 }
874 }
875 Ok(s)
876}
877
/// Result of the penalised least-squares solve.
#[derive(Clone, Debug)]
pub struct PenalisedLsSolution {
    /// Solved coefficient vector, one entry per design column.
    pub beta: Vec<f64>,
    /// Residual sum of squares reported by the solver.
    // NOTE(review): the Linux solver fills this from the tail of the
    // transformed augmented right-hand side, so it appears to include the
    // penalty rows as well as the data rows — confirm against callers.
    pub weighted_residual_ssq: f64,
    /// Twice the sum of `ln|R_kk|` over the diagonal of the triangular
    /// factor; `0.0` when there are no columns.
    pub log_det_hessian: f64,
}
917
918#[cfg(target_os = "linux")]
928pub fn solve_penalised_ls_device(
929 x_s_device: &DeviceS2KernelMatrix,
930 wy: &[f64],
931 r_s: ArrayView2<'_, f64>,
932) -> Result<PenalisedLsSolution, GpuError> {
933 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
934 use cudarc::driver::DevicePtrMut;
935
936 let n = x_s_device.rows;
937 let p = x_s_device.cols;
938 if wy.len() != n {
939 gam_gpu::gpu_bail!("solve_penalised_ls_device: wy.len()={} != n={n}", wy.len());
940 }
941 if r_s.dim() != (p, p) {
942 gam_gpu::gpu_bail!(
943 "solve_penalised_ls_device: r_s.dim()={:?} != ({p}, {p})",
944 r_s.dim()
945 );
946 }
947 if p == 0 {
948 return Ok(PenalisedLsSolution {
949 beta: Vec::new(),
950 weighted_residual_ssq: wy.iter().map(|v| v * v).sum(),
951 log_det_hessian: 0.0,
952 });
953 }
954
955 let stream = x_s_device.stream.clone();
956 let n_aug = n + p;
957
958 let mut a_aug_host = vec![0.0_f64; n_aug * p];
963 let mut x_host_colmajor = vec![0.0_f64; x_s_device.ld * p];
965 x_s_device.copy_to_host_col_major(&mut x_host_colmajor)?;
966 for j in 0..p {
967 let src_off = j * x_s_device.ld;
968 let dst_off = j * n_aug;
969 a_aug_host[dst_off..dst_off + n].copy_from_slice(&x_host_colmajor[src_off..src_off + n]);
970 for i in 0..p {
971 a_aug_host[dst_off + n + i] = r_s[(i, j)];
974 }
975 }
976 let mut a_dev = stream
977 .clone_htod(&a_aug_host)
978 .gpu_ctx("solve_penalised_ls_device htod A_aug")?;
979
980 let mut b_host = vec![0.0_f64; n_aug];
982 b_host[..n].copy_from_slice(wy);
983 let mut b_dev = stream
984 .clone_htod(&b_host)
985 .gpu_ctx("solve_penalised_ls_device htod b_aug")?;
986
987 let solver = DnHandle::new(stream.clone()).gpu_ctx("solve_penalised_ls_device DnHandle")?;
988 let n_aug_i: i32 = i32::try_from(n_aug)
989 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: n_aug={n_aug} overflows i32"))?;
990 let p_i: i32 = i32::try_from(p)
991 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: p={p} overflows i32"))?;
992
993 let mut lwork: i32 = 0;
995 {
996 let (a_ptr, _rec) = a_dev.device_ptr_mut(&stream);
997 let status = unsafe {
1000 cusolver_sys::cusolverDnDgeqrf_bufferSize(
1001 solver.cu(),
1002 n_aug_i,
1003 p_i,
1004 a_ptr as *mut f64,
1005 n_aug_i,
1006 &mut lwork,
1007 )
1008 };
1009 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1010 gam_gpu::gpu_bail!("cusolverDnDgeqrf_bufferSize status={status:?}");
1011 }
1012 }
1013 let lwork_us = usize::try_from(lwork)
1014 .map_err(|_| gam_gpu::gpu_err!("solve_penalised_ls_device: negative lwork={lwork}"))?;
1015 let mut workspace = stream
1016 .alloc_zeros::<f64>(lwork_us.max(1))
1017 .gpu_ctx("solve_penalised_ls_device alloc workspace")?;
1018 let mut tau = stream
1019 .alloc_zeros::<f64>(p)
1020 .gpu_ctx("solve_penalised_ls_device alloc tau")?;
1021 let mut info = stream
1022 .alloc_zeros::<i32>(1)
1023 .gpu_ctx("solve_penalised_ls_device alloc info")?;
1024
1025 {
1027 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1028 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1029 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1030 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1031 let status = unsafe {
1034 cusolver_sys::cusolverDnDgeqrf(
1035 solver.cu(),
1036 n_aug_i,
1037 p_i,
1038 a_ptr as *mut f64,
1039 n_aug_i,
1040 tau_ptr as *mut f64,
1041 work_ptr as *mut f64,
1042 lwork,
1043 info_ptr as *mut i32,
1044 )
1045 };
1046 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1047 gam_gpu::gpu_bail!("cusolverDnDgeqrf status={status:?}");
1048 }
1049 }
1050
1051 let mut ormqr_lwork: i32 = 0;
1053 {
1054 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1055 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1056 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1057 let status = unsafe {
1060 cusolver_sys::cusolverDnDormqr_bufferSize(
1061 solver.cu(),
1062 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1063 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1064 n_aug_i,
1065 1,
1066 p_i,
1067 a_ptr as *const f64,
1068 n_aug_i,
1069 tau_ptr as *const f64,
1070 b_ptr as *mut f64,
1071 n_aug_i,
1072 &mut ormqr_lwork,
1073 )
1074 };
1075 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1076 gam_gpu::gpu_bail!("cusolverDnDormqr_bufferSize status={status:?}");
1077 }
1078 }
1079 if ormqr_lwork > lwork {
1080 workspace = stream
1081 .alloc_zeros::<f64>(usize::try_from(ormqr_lwork).unwrap_or(1))
1082 .gpu_ctx("solve_penalised_ls_device realloc workspace ormqr")?;
1083 }
1084 {
1085 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1086 let (tau_ptr, _rec_t) = tau.device_ptr_mut(&stream);
1087 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1088 let (work_ptr, _rec_w) = workspace.device_ptr_mut(&stream);
1089 let (info_ptr, _rec_i) = info.device_ptr_mut(&stream);
1090 let status = unsafe {
1094 cusolver_sys::cusolverDnDormqr(
1095 solver.cu(),
1096 cusolver_sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1097 cusolver_sys::cublasOperation_t::CUBLAS_OP_T,
1098 n_aug_i,
1099 1,
1100 p_i,
1101 a_ptr as *const f64,
1102 n_aug_i,
1103 tau_ptr as *const f64,
1104 b_ptr as *mut f64,
1105 n_aug_i,
1106 work_ptr as *mut f64,
1107 ormqr_lwork.max(lwork),
1108 info_ptr as *mut i32,
1109 )
1110 };
1111 if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
1112 gam_gpu::gpu_bail!("cusolverDnDormqr status={status:?}");
1113 }
1114 }
1115
1116 {
1119 use cudarc::cublas::CudaBlas;
1120 let blas = CudaBlas::new(stream.clone()).gpu_ctx("solve_penalised_ls_device CudaBlas")?;
1121 let alpha = 1.0_f64;
1122 let (a_ptr, _rec_a) = a_dev.device_ptr_mut(&stream);
1123 let (b_ptr, _rec_b) = b_dev.device_ptr_mut(&stream);
1124 let handle = *blas.handle();
1129 let status = unsafe {
1130 cudarc::cublas::sys::cublasDtrsm_v2(
1131 handle,
1132 cudarc::cublas::sys::cublasSideMode_t::CUBLAS_SIDE_LEFT,
1133 cudarc::cublas::sys::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER,
1134 cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
1135 cudarc::cublas::sys::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
1136 p_i,
1137 1,
1138 &alpha,
1139 a_ptr as *const f64,
1140 n_aug_i,
1141 b_ptr as *mut f64,
1142 n_aug_i,
1143 )
1144 };
1145 if status != cudarc::cublas::sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1146 gam_gpu::gpu_bail!("cublasDtrsm_v2 status={status:?}");
1147 }
1148 }
1149
1150 let mut b_out = vec![0.0_f64; n_aug];
1152 stream
1153 .memcpy_dtoh(&b_dev, &mut b_out)
1154 .gpu_ctx("solve_penalised_ls_device dtoh b_out")?;
1155 let mut a_back = vec![0.0_f64; n_aug * p];
1156 stream
1157 .memcpy_dtoh(&a_dev, &mut a_back)
1158 .gpu_ctx("solve_penalised_ls_device dtoh A_back")?;
1159 stream
1160 .synchronize()
1161 .gpu_ctx("solve_penalised_ls_device synchronize")?;
1162
1163 let beta: Vec<f64> = b_out[..p].to_vec();
1164 let augmented_residual_ssq: f64 = b_out[p..].iter().map(|v| v * v).sum();
1173
1174 let mut log_abs_r = 0.0_f64;
1176 for k in 0..p {
1177 let r_kk = a_back[k * n_aug + k];
1178 log_abs_r += r_kk.abs().ln();
1179 }
1180 let log_det_hessian = 2.0 * log_abs_r;
1181
1182 Ok(PenalisedLsSolution {
1183 beta,
1184 weighted_residual_ssq: augmented_residual_ssq,
1185 log_det_hessian,
1186 })
1187}
1188
/// Non-Linux stand-in for the GPU penalised least-squares solve.
///
/// Always fails with `GpuError::DriverLibraryUnavailable`; the message
/// records the problem dimensions that were requested.
#[cfg(not(target_os = "linux"))]
pub fn solve_penalised_ls_device(
    x_s_device: &DeviceS2KernelMatrix,
    wy: &[f64],
    r_s: ArrayView2<'_, f64>,
) -> Result<PenalisedLsSolution, GpuError> {
    let n = x_s_device.rows;
    let p = x_s_device.cols;
    let wy_len = wy.len();
    let penalty_dim = r_s.dim();
    let reason = format!(
        "sphere GPU cuSOLVER QR path is Linux-only (n={n}, p={p}, wy.len()={wy_len}, r_s={penalty_dim:?})"
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
1205
1206#[cfg(test)]
1211mod sphere_gpu_tests {
1212 use super::*;
1213 use crate::basis::{
1214 SphereWahbaKernel, sobolev_s2_truncated_coefficients, sphere_truncated_spectral_eval,
1215 spherical_wahba_kernel_matrix_with_kind,
1216 };
1217 use ndarray::Array2;
1218
1219 fn small_latlon_grid(n_lat: usize, n_lon: usize) -> Array2<f64> {
1220 let mut rows = Vec::with_capacity(n_lat * n_lon);
1222 for i in 0..n_lat {
1223 let lat = -85.0 + (170.0 * i as f64) / (n_lat.saturating_sub(1).max(1) as f64);
1224 for j in 0..n_lon {
1225 let lon = -180.0 + (360.0 * j as f64) / (n_lon.saturating_sub(1).max(1) as f64);
1226 rows.push(lat);
1227 rows.push(lon);
1228 }
1229 }
1230 Array2::from_shape_vec((n_lat * n_lon, 2), rows).unwrap()
1231 }
1232
1233 #[test]
1234 fn xyz_preprocessing_matches_unit_sphere() {
1235 let latlon = ndarray::array![
1236 [0.0, 0.0],
1237 [90.0, 0.0],
1238 [0.0, 90.0],
1239 [-90.0, 17.5],
1240 [45.0, -120.0],
1241 ];
1242 let xyz = latlon_to_xyz_host(latlon.view(), false).expect("xyz");
1243 assert_eq!(xyz.len(), 3 * 5);
1244 for i in 0..5 {
1245 let nrm2 = xyz[3 * i] * xyz[3 * i]
1246 + xyz[3 * i + 1] * xyz[3 * i + 1]
1247 + xyz[3 * i + 2] * xyz[3 * i + 2];
1248 assert!((nrm2 - 1.0).abs() < 1e-15, "row {i} not unit norm: {nrm2}");
1249 }
1250 assert!((xyz[0] - 1.0).abs() < 1e-15);
1252 assert!(xyz[1].abs() < 1e-15);
1253 assert!(xyz[2].abs() < 1e-15);
1254 assert!(xyz[3].abs() < 1e-15);
1256 assert!(xyz[4].abs() < 1e-15);
1257 assert!((xyz[5] - 1.0).abs() < 1e-15);
1258 assert!(xyz[6].abs() < 1e-15);
1260 assert!((xyz[7] - 1.0).abs() < 1e-15);
1261 assert!(xyz[8].abs() < 1e-15);
1262 }
1263
1264 #[test]
1265 fn truncated_spectral_at_same_point_matches_sum_of_coefficients() {
1266 for m_penalty in 1..=4 {
1270 for &lmax in &[5_usize, 20, 50] {
1271 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1272 let expected: f64 = coeffs.iter().sum();
1273 let got = sphere_truncated_spectral_eval(1.0, &coeffs);
1274 assert!(
1275 (got - expected).abs() < 1e-13,
1276 "K(x,x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1277 );
1278 }
1279 }
1280 }
1281
1282 #[test]
1283 fn truncated_spectral_at_antipode_matches_alternating_sum() {
1284 for m_penalty in 1..=4 {
1287 for &lmax in &[5_usize, 20, 50] {
1288 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1289 let expected: f64 = coeffs
1290 .iter()
1291 .enumerate()
1292 .map(|(ell, c)| if ell % 2 == 0 { *c } else { -*c })
1293 .sum();
1294 let got = sphere_truncated_spectral_eval(-1.0, &coeffs);
1295 assert!(
1296 (got - expected).abs() < 1e-13,
1297 "K(x,-x) identity broken at m={m_penalty}, L={lmax}: got {got:.6e}, expected {expected:.6e}"
1298 );
1299 }
1300 }
1301 }
1302
1303 #[test]
1304 fn truncated_spectral_matrix_is_symmetric() {
1305 let centers = ndarray::array![
1309 [10.0_f64, 20.0],
1310 [-30.0, 100.0],
1311 [45.0, -60.0],
1312 [-89.0, 0.0],
1313 [0.0, 180.0],
1314 [60.0, -179.9],
1315 ];
1316 for m_penalty in [1usize, 2, 4] {
1317 for &lmax in &[10_usize, 30] {
1318 let mat = spherical_wahba_kernel_matrix_with_kind(
1319 centers.view(),
1320 centers.view(),
1321 m_penalty,
1322 false,
1323 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1324 )
1325 .expect("kernel matrix");
1326 let n = centers.nrows();
1327 let mut max_asym = 0.0_f64;
1328 for i in 0..n {
1329 for j in 0..n {
1330 let d = (mat[(i, j)] - mat[(j, i)]).abs();
1331 if d > max_asym {
1332 max_asym = d;
1333 }
1334 }
1335 }
1336 assert!(
1337 max_asym < 1e-13,
1338 "K not symmetric at m={m_penalty}, L={lmax}: max |K - Kᵀ| = {max_asym:.3e}"
1339 );
1340 }
1341 }
1342 }
1343
1344 #[test]
1345 fn truncated_coefficients_have_zero_constant_mode() {
1346 for m in 1..=4 {
1347 let c = sobolev_s2_truncated_coefficients(50, m);
1348 assert_eq!(c.len(), 51);
1349 assert_eq!(c[0], 0.0);
1350 assert!(c[1] > 0.0);
1351 for ell in 2..=50 {
1353 assert!(
1354 c[ell] < c[ell - 1] + 1e-15,
1355 "Sobolev coefficient not non-increasing at m={m}, ell={ell}: {} vs {}",
1356 c[ell],
1357 c[ell - 1]
1358 );
1359 }
1360 }
1361 }
1362
1363 #[test]
1364 fn truncated_spectral_matches_matrix_helper() {
1365 let m_penalty = 2;
1369 let lmax = 20;
1370 let coeffs = sobolev_s2_truncated_coefficients(lmax, m_penalty);
1371 let data = ndarray::array![[12.5, -34.0]];
1372 let centers = ndarray::array![[40.0, 10.0]];
1373 let mat = spherical_wahba_kernel_matrix_with_kind(
1374 data.view(),
1375 centers.view(),
1376 m_penalty,
1377 false,
1378 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1379 )
1380 .expect("kernel matrix");
1381 let xyz_d = latlon_to_xyz_host(data.view(), false).unwrap();
1383 let xyz_c = latlon_to_xyz_host(centers.view(), false).unwrap();
1384 let cos_g = xyz_d[0] * xyz_c[0] + xyz_d[1] * xyz_c[1] + xyz_d[2] * xyz_c[2];
1385 let expected = sphere_truncated_spectral_eval(cos_g, &coeffs);
1386 assert!(
1387 (mat[(0, 0)] - expected).abs() < 1e-13,
1388 "matrix helper differs from scalar evaluator: {} vs {}",
1389 mat[(0, 0)],
1390 expected
1391 );
1392 }
1393
1394 #[test]
1395 fn constrained_penalty_is_symmetric_and_drops_constraint_direction() {
1396 let m = 6;
1401 let mut c = Array2::<f64>::zeros((m, m));
1402 for i in 0..m {
1403 for j in 0..m {
1404 let d = (i as f64 - j as f64).abs();
1405 c[(i, j)] = (-0.5 * d).exp();
1406 }
1407 }
1408 let w = vec![1.0_f64; m];
1409 let s = constrained_penalty_host(c.view(), &w).expect("constrained S");
1410 assert_eq!(s.dim(), (m - 1, m - 1));
1411 let mut max_asym = 0.0_f64;
1413 for i in 0..(m - 1) {
1414 for j in 0..(m - 1) {
1415 let d = (s[(i, j)] - s[(j, i)]).abs();
1416 if d > max_asym {
1417 max_asym = d;
1418 }
1419 }
1420 }
1421 assert!(
1422 max_asym < 1e-13,
1423 "S not symmetric: max |S - Sᵀ| = {max_asym:.3e}"
1424 );
1425
1426 let ones = ndarray::Array1::<f64>::ones(m - 1);
1434 let sx = s.dot(&ones);
1435 assert!(sx.iter().all(|v| v.is_finite()));
1436 }
1437
1438 #[test]
1439 fn householder_reflector_zeroes_target_vector() {
1440 let w = vec![3.0, 4.0, 0.0, -1.0];
1441 let (v, beta) = householder_reflector_from_weights(&w);
1442 let dot: f64 = v.iter().zip(&w).map(|(a, b)| a * b).sum();
1445 let hw: Vec<f64> = w
1446 .iter()
1447 .zip(&v)
1448 .map(|(wj, vj)| wj - beta * dot * vj)
1449 .collect();
1450 for entry in hw.iter().skip(1) {
1451 assert!(entry.abs() < 1e-12, "H · w not e_1 multiple: {hw:?}");
1452 }
1453 assert!(hw[0].abs() > 0.0);
1454 }
1455
1456 #[test]
1459 fn sphere_gpu_raw_kernel_parity_vs_cpu_truncated() {
1460 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1461 eprintln!("[sphere_gpu test] no CUDA runtime — skipping raw-kernel parity");
1462 return;
1463 };
1464 SphereGpuBackend::probe()
1467 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1468
1469 let data_ll = small_latlon_grid(7, 9);
1470 let centers_ll = small_latlon_grid(5, 7);
1471 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1472 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1473 let n = data_ll.nrows();
1474 let m = centers_ll.nrows();
1475 let penalty = 2usize;
1476 let lmax = 20usize;
1477 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1478
1479 let inputs = S2KernelBuildInputs {
1480 n,
1481 m,
1482 lmax,
1483 data_xyz: &data_xyz,
1484 centers_xyz: ¢ers_xyz,
1485 coeffs: &coeffs,
1486 kind: SphereSpectralKernelKind::Sobolev,
1487 layout: DeviceMatrixLayout::ColumnMajor,
1488 };
1489 let dev_mat = build_kernel_matrix_device(inputs).expect("device kernel matrix");
1490 let gpu = dev_mat.to_host_array().expect("dtoh kernel matrix");
1491
1492 let cpu = spherical_wahba_kernel_matrix_with_kind(
1493 data_ll.view(),
1494 centers_ll.view(),
1495 penalty,
1496 false,
1497 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1498 )
1499 .expect("cpu kernel matrix");
1500
1501 let mut max_abs = 0.0_f64;
1502 for i in 0..n {
1503 for j in 0..m {
1504 let d = (gpu[(i, j)] - cpu[(i, j)]).abs();
1505 if d > max_abs {
1506 max_abs = d;
1507 }
1508 }
1509 }
1510 assert!(
1511 max_abs < 1e-11,
1512 "GPU vs CPU truncated parity max |Δ| = {max_abs:.3e} >= 1e-11"
1513 );
1514 }
1515
1516 #[test]
1526 fn sphere_gpu_end_to_end_dispatch_parity_vs_cpu_truncated() {
1527 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1528 eprintln!("[sphere_gpu test] no CUDA runtime — skipping end-to-end dispatch parity");
1529 return;
1530 };
1531 SphereGpuBackend::probe()
1535 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1536 use crate::basis::{
1537 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1538 build_spherical_spline_basis, sobolev_s2_truncated_coefficients,
1539 };
1540 drop(sobolev_s2_truncated_coefficients(1, 1));
1541
1542 let data = small_latlon_grid(100, 100);
1544 let lmax: u16 = 30;
1545 let penalty_order = 2usize;
1546 let spec_gpu = SphericalSplineBasisSpec {
1547 center_strategy: CenterStrategy::FarthestPoint { num_centers: 200 },
1548 penalty_order,
1549 double_penalty: false,
1550 radians: false,
1551 method: SphereMethod::Wahba,
1552 max_degree: None,
1553 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1554 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1555 };
1556 let result_gpu = build_spherical_spline_basis(data.view(), &spec_gpu)
1557 .expect("GPU-eligible build_spherical_spline_basis succeeds");
1558
1559 let centers =
1567 crate::basis::select_spherical_farthest_point_centers(data.view(), 200, false)
1568 .expect("centers");
1569 let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
1570 data.view(),
1571 centers.view(),
1572 penalty_order,
1573 false,
1574 SphereWahbaKernel::SobolevTruncated { lmax },
1575 )
1576 .expect("cpu raw design");
1577
1578 let z = Array2::<f64>::eye(centers.nrows());
1581 let cpu_design = raw_cpu.dot(&z);
1582
1583 let gpu_design = result_gpu.design.as_dense().expect("dense design").clone();
1584
1585 assert_eq!(gpu_design.dim(), cpu_design.dim());
1586 let mut max_abs = 0.0_f64;
1587 let mut max_rel = 0.0_f64;
1588 for ((g, c), _) in gpu_design.iter().zip(cpu_design.iter()).zip(0..) {
1589 let d = (g - c).abs();
1590 if d > max_abs {
1591 max_abs = d;
1592 }
1593 let denom = g.abs().max(c.abs()).max(1e-300);
1594 let r = d / denom;
1595 if r > max_rel {
1596 max_rel = r;
1597 }
1598 }
1599 assert!(
1600 max_rel < 1e-9,
1601 "end-to-end design parity max relative |Δ| = {max_rel:.3e} >= 1e-9 (abs {max_abs:.3e})"
1602 );
1603 }
1604
1605 #[test]
1608 fn sphere_gpu_householder_parity_vs_raw_dot_z() {
1609 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1610 eprintln!("[sphere_gpu test] no CUDA runtime — skipping householder parity");
1611 return;
1612 };
1613 SphereGpuBackend::probe()
1616 .expect("[sphere_gpu test] backend probe must succeed on a CUDA host");
1617 let data_ll = small_latlon_grid(6, 8);
1618 let centers_ll = small_latlon_grid(4, 5);
1619 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1620 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1621 let n = data_ll.nrows();
1622 let m = centers_ll.nrows();
1623 let penalty = 2usize;
1624 let lmax = 15usize;
1625 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty);
1626
1627 let inputs_raw = S2KernelBuildInputs {
1629 n,
1630 m,
1631 lmax,
1632 data_xyz: &data_xyz,
1633 centers_xyz: ¢ers_xyz,
1634 coeffs: &coeffs,
1635 kind: SphereSpectralKernelKind::Sobolev,
1636 layout: DeviceMatrixLayout::ColumnMajor,
1637 };
1638 let b_dev = build_kernel_matrix_device(inputs_raw.clone()).expect("raw kernel");
1639 let b = b_dev.to_host_array().expect("dtoh raw");
1640
1641 let w = vec![1.0_f64; m];
1644 let (v, beta) = householder_reflector_from_weights(&w);
1645
1646 let mut xs_host = Array2::<f64>::zeros((n, m - 1));
1648 for i in 0..n {
1649 let d_i: f64 = (0..m).map(|j| v[j] * b[(i, j)]).sum();
1650 for j_out in 0..(m - 1) {
1651 xs_host[(i, j_out)] = b[(i, j_out + 1)] - beta * d_i * v[j_out + 1];
1652 }
1653 }
1654
1655 let xs_dev =
1656 build_householder_constrained_design_device(inputs_raw, &v, beta).expect("hh design");
1657 let xs_gpu = xs_dev.to_host_array().expect("dtoh hh");
1658
1659 let mut max_abs = 0.0_f64;
1660 for i in 0..n {
1661 for j in 0..(m - 1) {
1662 let d = (xs_host[(i, j)] - xs_gpu[(i, j)]).abs();
1663 if d > max_abs {
1664 max_abs = d;
1665 }
1666 }
1667 }
1668 assert!(
1669 max_abs < 1e-12,
1670 "Householder fused parity max |Δ| = {max_abs:.3e} >= 1e-12"
1671 );
1672 }
1673
1674 #[test]
1678 fn sphere_gpu_kernel_matrix_hill_climb_20x_vs_cpu() {
1679 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1680 eprintln!("[sphere_gpu hill-climb] no CUDA runtime — skipping");
1681 return;
1682 };
1683 if SphereGpuBackend::probe().is_err() {
1684 eprintln!("[sphere_gpu hill-climb] backend probe failed — skipping");
1685 return;
1686 }
1687
1688 let n_lat = 500usize;
1691 let n_lon = 400usize;
1692 assert_eq!(n_lat * n_lon, 200_000);
1693 let data_ll = small_latlon_grid(n_lat, n_lon);
1694 let m = 200usize;
1695 let centers_ll =
1696 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1697 .expect("centers");
1698 let n = data_ll.nrows();
1699 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).unwrap();
1700 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).unwrap();
1701 let penalty_order = 2usize;
1702 let lmax = 50usize;
1703 let coeffs = sobolev_s2_truncated_coefficients(lmax, penalty_order);
1704
1705 let inputs_warm = S2KernelBuildInputs {
1707 n,
1708 m,
1709 lmax,
1710 data_xyz: &data_xyz,
1711 centers_xyz: ¢ers_xyz,
1712 coeffs: &coeffs,
1713 kind: SphereSpectralKernelKind::Sobolev,
1714 layout: DeviceMatrixLayout::ColumnMajor,
1715 };
1716 drop(build_kernel_matrix_device(inputs_warm.clone()).expect("warmup"));
1717
1718 let t0 = std::time::Instant::now();
1720 let dev = build_kernel_matrix_device(inputs_warm.clone()).expect("gpu kernel matrix");
1721 let _host_gpu = dev.to_host_array().expect("dtoh");
1722 let gpu_secs = t0.elapsed().as_secs_f64();
1723
1724 let t1 = std::time::Instant::now();
1726 let _cpu = spherical_wahba_kernel_matrix_with_kind(
1727 data_ll.view(),
1728 centers_ll.view(),
1729 penalty_order,
1730 false,
1731 SphereWahbaKernel::SobolevTruncated { lmax: lmax as u16 },
1732 )
1733 .expect("cpu kernel matrix");
1734 let cpu_secs = t1.elapsed().as_secs_f64();
1735
1736 let ratio = cpu_secs / gpu_secs.max(1e-9);
1737 eprintln!(
1738 "[sphere_gpu hill-climb] n={n} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x"
1739 );
1740 assert!(
1741 ratio >= 20.0,
1742 "GPU kernel matrix only {ratio:.2}× faster than CPU (target ≥ 20×) at \
1743 n={n} m={m} L={lmax}: cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1744 );
1745 }
1746
1747 #[test]
1752 fn sphere_gpu_end_to_end_fit_hill_climb_10x_vs_cpu() {
1753 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1754 eprintln!("[sphere_gpu hill-climb fit] no CUDA runtime — skipping");
1755 return;
1756 };
1757 if SphereGpuBackend::probe().is_err() {
1758 eprintln!("[sphere_gpu hill-climb fit] backend probe failed — skipping");
1759 return;
1760 }
1761 use crate::basis::{
1762 CenterStrategy, SphereMethod, SphericalSplineBasisSpec, SphericalSplineIdentifiability,
1763 build_spherical_spline_basis,
1764 };
1765
1766 let n_lat = 500usize;
1767 let n_lon = 400usize;
1768 let data_ll = small_latlon_grid(n_lat, n_lon);
1769 let m: usize = 200;
1770 let lmax: u16 = 50;
1771 let spec_gpu = SphericalSplineBasisSpec {
1772 center_strategy: CenterStrategy::FarthestPoint { num_centers: m },
1773 penalty_order: 2,
1774 double_penalty: false,
1775 radians: false,
1776 method: SphereMethod::Wahba,
1777 max_degree: None,
1778 wahba_kernel: SphereWahbaKernel::SobolevTruncated { lmax },
1779 identifiability: SphericalSplineIdentifiability::CenterSumToZero,
1780 };
1781
1782 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("warmup build"));
1784
1785 let t0 = std::time::Instant::now();
1786 drop(build_spherical_spline_basis(data_ll.view(), &spec_gpu).expect("gpu build"));
1787 let gpu_secs = t0.elapsed().as_secs_f64();
1788
1789 let centers =
1796 crate::basis::select_spherical_farthest_point_centers(data_ll.view(), m, false)
1797 .expect("centers");
1798 let z = Array2::<f64>::eye(centers.nrows());
1799 let t1 = std::time::Instant::now();
1800 let raw_cpu = spherical_wahba_kernel_matrix_with_kind(
1801 data_ll.view(),
1802 centers.view(),
1803 2,
1804 false,
1805 SphereWahbaKernel::SobolevTruncated { lmax },
1806 )
1807 .expect("cpu raw");
1808 let _design_cpu = raw_cpu.dot(&z);
1809 let cpu_secs = t1.elapsed().as_secs_f64();
1810
1811 let ratio = cpu_secs / gpu_secs.max(1e-9);
1812 eprintln!(
1813 "[sphere_gpu hill-climb fit] n={} m={m} L={lmax} cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s ratio={ratio:.2}x",
1814 data_ll.nrows()
1815 );
1816 assert!(
1817 ratio >= 10.0,
1818 "End-to-end sphere fit only {ratio:.2}× faster on GPU (target ≥ 10×): \
1819 cpu={cpu_secs:.3}s gpu={gpu_secs:.3}s"
1820 );
1821 }
1822
1823 #[test]
1845 fn sphere_gpu_end_to_end_fit_parity_vs_cpu_truncated() {
1846 use crate::basis::{
1847 select_spherical_farthest_point_centers, spherical_wahba_kernel_matrix_with_kind,
1848 };
1849 use faer::Side;
1850 use gam_linalg::faer_ndarray::FaerCholesky;
1851
1852 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
1853 eprintln!(
1854 "[sphere gpu parity] no CUDA runtime — skipping device parity \
1855 (CPU oracle exercised by sibling tests)"
1856 );
1857 return;
1858 };
1859 SphereGpuBackend::probe()
1862 .expect("[sphere gpu parity] sphere GPU backend probe must succeed on a CUDA host");
1863
1864 let data_ll = small_latlon_grid(25, 40);
1866 assert_eq!(data_ll.nrows(), 1000);
1867 let n = data_ll.nrows();
1868 let m: usize = 80;
1869 let lmax_u16: u16 = 15;
1870 let lmax: usize = lmax_u16 as usize;
1871 let penalty_order: usize = 2;
1872 let kernel = SphereWahbaKernel::SobolevTruncated { lmax: lmax_u16 };
1873 let lambda: f64 = 1.0e-3;
1874
1875 let centers_ll = select_spherical_farthest_point_centers(data_ll.view(), m, false)
1877 .expect("farthest-point centers");
1878 assert_eq!(centers_ll.nrows(), m);
1879
1880 let z = Array2::<f64>::eye(centers_ll.nrows());
1883 let p = z.ncols();
1884 assert_eq!(p, m);
1885
1886 let k_cc = spherical_wahba_kernel_matrix_with_kind(
1891 centers_ll.view(),
1892 centers_ll.view(),
1893 penalty_order,
1894 false,
1895 kernel,
1896 )
1897 .expect("centers×centers kernel");
1898 let s_full = z.t().dot(&k_cc).dot(&z);
1899
1900 let raw_design_cpu = spherical_wahba_kernel_matrix_with_kind(
1902 data_ll.view(),
1903 centers_ll.view(),
1904 penalty_order,
1905 false,
1906 kernel,
1907 )
1908 .expect("CPU raw design");
1909 let x_s_cpu = raw_design_cpu.dot(&z);
1910
1911 let data_xyz = latlon_to_xyz_host(data_ll.view(), false).expect("data xyz");
1913 let centers_xyz = latlon_to_xyz_host(centers_ll.view(), false).expect("centers xyz");
1914 let coeffs = crate::basis::sobolev_s2_truncated_coefficients(lmax, penalty_order);
1915 let inputs = S2KernelBuildInputs {
1916 n,
1917 m,
1918 lmax,
1919 data_xyz: &data_xyz,
1920 centers_xyz: ¢ers_xyz,
1921 coeffs: &coeffs,
1922 kind: SphereSpectralKernelKind::Sobolev,
1923 layout: DeviceMatrixLayout::ColumnMajor,
1924 };
1925 let raw_dev = build_kernel_matrix_device(inputs).expect("GPU raw design");
1926 let raw_design_gpu = raw_dev.to_host_array().expect("dtoh GPU raw design");
1927 let x_s_gpu = raw_design_gpu.dot(&z);
1928
1929 assert_eq!(x_s_cpu.dim(), (n, p));
1930 assert_eq!(x_s_gpu.dim(), (n, p));
1931
1932 let mut y = ndarray::Array1::<f64>::zeros(n);
1938 for i in 0..n {
1939 let lat_rad = data_ll[(i, 0)].to_radians();
1940 let lon_rad = data_ll[(i, 1)].to_radians();
1941 y[i] = (2.0 * lat_rad).sin() * (3.0 * lon_rad).cos()
1943 + 0.25 * lat_rad.cos() * (5.0 * lon_rad).sin();
1944 }
1945
1946 let solve_penalised = |x_s: &ndarray::Array2<f64>| -> ndarray::Array1<f64> {
1951 let xtx = x_s.t().dot(x_s);
1952 let mut a = xtx;
1953 for i in 0..p {
1954 for j in 0..p {
1955 a[(i, j)] += lambda * s_full[(i, j)];
1956 }
1957 }
1958 let rhs = x_s.t().dot(&y);
1959 let factor = a
1960 .cholesky(Side::Lower)
1961 .expect("penalised normal equations are SPD under λ > 0");
1962 factor.solvevec(&rhs)
1963 };
1964
1965 let beta_cpu = solve_penalised(&x_s_cpu);
1966 let beta_gpu = solve_penalised(&x_s_gpu);
1967 assert_eq!(beta_cpu.len(), p);
1968 assert_eq!(beta_gpu.len(), p);
1969
1970 let yhat_cpu = x_s_cpu.dot(&beta_cpu);
1974 let yhat_gpu = x_s_gpu.dot(&beta_gpu);
1975
1976 let mut max_beta_delta = 0.0_f64;
1977 for k in 0..p {
1978 let d = (beta_cpu[k] - beta_gpu[k]).abs();
1979 if d > max_beta_delta {
1980 max_beta_delta = d;
1981 }
1982 }
1983 let mut max_fit_delta = 0.0_f64;
1984 for i in 0..n {
1985 let d = (yhat_cpu[i] - yhat_gpu[i]).abs();
1986 if d > max_fit_delta {
1987 max_fit_delta = d;
1988 }
1989 }
1990
1991 eprintln!(
1992 "[sphere_gpu fit parity] n={n} m={m} p={p} lmax={lmax} λ={lambda:.1e} \
1993 max|Δβ|={max_beta_delta:.3e} max|Δŷ|={max_fit_delta:.3e}"
1994 );
1995
1996 assert!(
1997 max_beta_delta <= 1.0e-9,
1998 "GPU vs CPU truncated-spectral coefficient max |Δ| = {max_beta_delta:.3e} > 1e-9"
1999 );
2000 assert!(
2001 max_fit_delta <= 1.0e-9,
2002 "GPU vs CPU truncated-spectral fitted-value max |Δ| = {max_fit_delta:.3e} > 1e-9"
2003 );
2004 }
2005}