Skip to main content

gam_terms/basis/
workspace_cache.rs

1use super::*;
2
/// Lookup key for [`ConstraintNullspaceCache`].
///
/// Combines the shape of the center matrix, a hash of its entries (built with
/// `hash_arrayview2` at the lookup sites in this file) and the polynomial
/// side-condition variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct ConstraintNullspaceCacheKey {
    /// Row count of the center matrix the key was built from.
    pub(crate) centersrows: usize,
    /// Column count of the center matrix the key was built from.
    pub(crate) centers_cols: usize,
    /// Hash of the center matrix contents.
    pub(crate) centers_hash: u64,
    /// Which null-space construction (and order, for Duchon) the entry is for.
    pub(crate) order: ConstraintNullspaceOrderKey,
}
10
/// Discriminates cached null-space bases by the construction that produced
/// them, so Duchon and thin-plate entries for the same centers do not collide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ConstraintNullspaceOrderKey {
    /// Duchon side conditions at the given null-space order.
    Duchon(DuchonNullspaceOrder),
    /// Thin-plate side conditions.
    ThinPlate,
}
16
/// Small cache of constraint null-space bases.
///
/// The insert sites in this file push each new key onto `order` and evict from
/// the front once `map` grows past `CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES`.
#[derive(Default, Clone, Debug)]
pub(crate) struct ConstraintNullspaceCache {
    /// Cached bases, shared via `Arc`.
    pub(crate) map: HashMap<ConstraintNullspaceCacheKey, Arc<Array2<f64>>>,
    /// Keys in insertion order; the front entry is the next eviction victim.
    pub(crate) order: Vec<ConstraintNullspaceCacheKey>,
}
22
/// Maximum number of entries kept in a `ConstraintNullspaceCache` before the
/// oldest inserted entries are evicted.
pub(crate) const CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES: usize = 32;
24
/// Identity key for a borrowed 2-D data view: its shape, base address and
/// strides (see `shared_owned_data_matrix`, which builds it).
///
/// NOTE(review): the key does not look at the matrix contents. If the memory
/// at `ptr` is mutated, or freed and reused for a same-shaped matrix while the
/// cache is alive, a lookup would presumably return the stale copy — confirm
/// that callers keep the viewed data alive and unchanged for the cache's
/// lifetime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct OwnedDataCacheKey {
    /// Number of rows in the view.
    pub(crate) rows: usize,
    /// Number of columns in the view.
    pub(crate) cols: usize,
    /// Address of the view's first element.
    pub(crate) ptr: usize,
    /// Stride along axis 0, as reported by the view.
    pub(crate) stride0: isize,
    /// Stride along axis 1, as reported by the view.
    pub(crate) stride1: isize,
}
33
/// The caches a basis build can draw on: constraint null-space bases and
/// owned copies of borrowed data matrices.
#[derive(Debug)]
pub(crate) struct BasisCacheContext {
    /// Cached null-space bases keyed by center contents and order.
    pub(crate) constraint_nullspace: ConstraintNullspaceCache,
    /// Owned copies of data views, keyed by view identity; sized from the
    /// resource policy in `BasisCacheContext::with_policy`.
    pub(crate) owned_data: gam_runtime::resource::ByteLruCache<OwnedDataCacheKey, Arc<Array2<f64>>>,
}
39
40impl BasisCacheContext {
41    pub(crate) fn with_policy(policy: &gam_runtime::resource::ResourcePolicy) -> Self {
42        Self {
43            constraint_nullspace: ConstraintNullspaceCache::default(),
44            owned_data: gam_runtime::resource::ByteLruCache::with_max_entries(
45                policy.max_owned_data_cache_bytes,
46                gam_runtime::resource::OWNED_DATA_CACHE_MAX_ENTRIES,
47            ),
48        }
49    }
50}
51
52impl Default for BasisCacheContext {
53    fn default() -> Self {
54        Self::with_policy(&gam_runtime::resource::ResourcePolicy::default_library())
55    }
56}
57
/// Explicit per-run workspace for reusable basis-construction caches.
///
/// Pass one workspace through repeated basis builds to avoid global mutable state
/// and to keep caching scoped to a caller-controlled lifecycle.
///
/// Owned-data cache entries are byte-limited via the
/// [`gam_runtime::resource::ResourcePolicy`] provided at construction; use
/// [`BasisWorkspace::with_policy`] for large-scale workloads where a single
/// entry can be multiple gigabytes.
#[derive(Debug)]
pub struct BasisWorkspace {
    /// Caches built from `policy` when the workspace was constructed.
    pub(crate) cache: BasisCacheContext,
    /// The resource policy the workspace was constructed with.
    pub(crate) policy: gam_runtime::resource::ResourcePolicy,
}
72
73impl BasisWorkspace {
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    pub fn with_policy(policy: gam_runtime::resource::ResourcePolicy) -> Self {
79        Self {
80            cache: BasisCacheContext::with_policy(&policy),
81            policy,
82        }
83    }
84
85    pub fn default_library() -> Self {
86        Self::with_policy(gam_runtime::resource::ResourcePolicy::default_library())
87    }
88
89    /// Returns the resource policy this workspace was configured with.
90    pub fn policy(&self) -> &gam_runtime::resource::ResourcePolicy {
91        &self.policy
92    }
93}
94
95impl Default for BasisWorkspace {
96    fn default() -> Self {
97        Self::default_library()
98    }
99}
100
101pub(crate) fn hash_arrayview2(values: ArrayView2<'_, f64>) -> u64 {
102    let mut hasher = DefaultHasher::new();
103    values.nrows().hash(&mut hasher);
104    values.ncols().hash(&mut hasher);
105    for v in values {
106        v.to_bits().hash(&mut hasher);
107    }
108    hasher.finish()
109}
110
111pub(crate) fn shared_owned_data_matrix(
112    data: ArrayView2<'_, f64>,
113    cache: &BasisCacheContext,
114) -> Arc<Array2<f64>> {
115    let key = OwnedDataCacheKey {
116        rows: data.nrows(),
117        cols: data.ncols(),
118        ptr: data.as_ptr() as usize,
119        stride0: data.strides()[0],
120        stride1: data.strides()[1],
121    };
122    if let Some(hit) = cache.owned_data.get(&key) {
123        return hit;
124    }
125
126    let owned = Arc::new(data.to_owned());
127    if let Some(hit) = cache.owned_data.get(&key) {
128        return hit;
129    }
130
131    cache.owned_data.insert(key, owned.clone());
132    owned
133}
134
135/// Minimal cache-less intern: wraps an `ArrayView2` into an `Arc<Array2<f64>>`.
136///
137/// Used by derivative-operator builders that don't have a `BasisCacheContext`
138/// in scope (e.g. `build_aniso_design_psi_derivatives_shared`). The goal is the
139/// same as `shared_owned_data_matrix`: move the owned payload into an `Arc`
140/// once so that downstream `StreamingRadialState` copies share it via
141/// `Arc::clone` instead of materializing a fresh n×d `Array2<f64>` per axis.
142#[inline]
143pub(crate) fn shared_owned_data_matrix_from_view(data: ArrayView2<'_, f64>) -> Arc<Array2<f64>> {
144    Arc::new(data.to_owned())
145}
146
147/// Minimal cache-less intern for knot centers; mirrors
148/// `shared_owned_data_matrix_from_view`. Centers are typically k×d with k
149/// much smaller than n, but the `Arc::clone` pattern still avoids a k×d
150/// copy per axis when the same operator feeds multiple derivative paths.
151#[inline]
152pub(crate) fn shared_owned_centers_matrix_from_view(
153    centers: ArrayView2<'_, f64>,
154) -> Arc<Array2<f64>> {
155    Arc::new(centers.to_owned())
156}
157
158/// Compute the kernel reparameterisation transform `Z = null(P_centers^T)`.
159///
160/// `Z` is a `(k, k − C(d+r, r))` orthonormal matrix whose columns span the
161/// null space of the polynomial side-condition system.  Reparameterising the
162/// radial kernel coefficients as `α = Z γ` enforces `P_centers^T α = 0` and
163/// reduces the kernel column count from `k` to `k − C(d+r, r)`.
164///
165/// After this projection the polynomial block `P_data` is appended as separate
166/// explicit unpenalized columns (see `build_duchon_basis_designwithworkspace`),
167/// so the pre-identifiability total width is always `k` (equal to the center
168/// count), regardless of the polynomial null-space dimension.
169///
170/// This is the step that absorbs the full `C(d+r, r)`-dimensional polynomial
171/// null space.  The subsequent `spatial_parametric_constraint_block` step only
172/// removes the intercept.
173pub(crate) fn kernel_constraint_nullspace(
174    centers: ArrayView2<'_, f64>,
175    order: DuchonNullspaceOrder,
176    cache: &mut BasisCacheContext,
177) -> Result<Array2<f64>, BasisError> {
178    let effective_order = duchon_effective_nullspace_order(centers, order);
179    let degraded = effective_order != order;
180    // Translation-invariant side-condition frame (#1375, mirroring the #1269 tp
181    // fix). `Z = null(P(centers)ᵀ)` is mathematically invariant to subtracting a
182    // per-axis constant from `centers` (the polynomial columns `{1, x, …}` and
183    // `{1, x − x̄, …}` span the same space, so `P` has the same column space and
184    // `P^T` the same null space), but the RRQR pivoting that materialises `Z`
185    // drifts under a large coordinate mean — landing on a different orthonormal
186    // basis of the SAME null space, which would desync the design `K·Z` from the
187    // penalty `ZᵀK_CC Z` across a covariate translation. Subtract the center-cloud
188    // per-axis mean so the factorisation is location-standardized; both a raw and
189    // an already-centered caller then produce bit-identical `Z`. The mean is a
190    // fixed property of the (frozen `UserProvided`) centers, replayed identically
191    // at predict.
192    let k = centers.nrows();
193    let d = centers.ncols();
194    let center_mean: Vec<f64> = (0..d)
195        .map(|c| centers.column(c).sum() / (k.max(1) as f64))
196        .collect();
197    let mut centers_centered = centers.to_owned();
198    for c in 0..d {
199        let mu = center_mean[c];
200        centers_centered.column_mut(c).mapv_inplace(|v| v - mu);
201    }
202    let centers = centers_centered.view();
203    let key = ConstraintNullspaceCacheKey {
204        centersrows: centers.nrows(),
205        centers_cols: centers.ncols(),
206        centers_hash: hash_arrayview2(centers),
207        order: ConstraintNullspaceOrderKey::Duchon(effective_order),
208    };
209
210    if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
211        return Ok((**hit).clone());
212    }
213
214    let p_k = polynomial_block_from_order(centers, effective_order);
215    let z = Arc::new(kernel_constraint_nullspace_from_matrix(p_k.view()).map_err(|err| {
216        if degraded {
217            BasisError::InvalidInput(format!(
218                "Duchon degraded from order={:?} to order={:?} due to insufficient centers ({} in dim={}); order={:?} construction then failed: {err}",
219                order,
220                effective_order,
221                centers.nrows(),
222                centers.ncols(),
223                effective_order,
224            ))
225        } else {
226            err
227        }
228    })?);
229
230    if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
231        return Ok((**hit).clone());
232    }
233    cache.constraint_nullspace.map.insert(key, z.clone());
234    cache.constraint_nullspace.order.push(key);
235    while cache.constraint_nullspace.map.len() > CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES {
236        if cache.constraint_nullspace.order.is_empty() {
237            break;
238        }
239        let oldkey = cache.constraint_nullspace.order.remove(0);
240        cache.constraint_nullspace.map.remove(&oldkey);
241    }
242
243    Ok((*z).clone())
244}
245
246pub(crate) fn thin_plate_kernel_constraint_nullspace(
247    centers: ArrayView2<'_, f64>,
248    cache: &mut BasisCacheContext,
249) -> Result<Array2<f64>, BasisError> {
250    let key = ConstraintNullspaceCacheKey {
251        centersrows: centers.nrows(),
252        centers_cols: centers.ncols(),
253        centers_hash: hash_arrayview2(centers),
254        order: ConstraintNullspaceOrderKey::ThinPlate,
255    };
256
257    if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
258        return Ok((**hit).clone());
259    }
260
261    let p_k = thin_plate_polynomial_block(centers);
262    if centers.nrows() < p_k.ncols() {
263        crate::bail_invalid_basis!(
264            "thin-plate spline requires at least {} centers to span the degree-{} polynomial null space in dimension {}; got {}",
265            p_k.ncols(),
266            thin_plate_polynomial_degree(centers.ncols()),
267            centers.ncols(),
268            centers.nrows()
269        );
270    }
271    let (z, rank) =
272        rrqr_nullspace_basis(&p_k, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
273    if rank != p_k.ncols() {
274        crate::bail_invalid_basis!(
275            "thin-plate spline polynomial block is rank deficient at the selected centers: expected rank {}, got {}; choose geometrically independent centers for dimension {}",
276            p_k.ncols(),
277            rank,
278            centers.ncols()
279        );
280    }
281    let z = Arc::new(z);
282
283    if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
284        return Ok((**hit).clone());
285    }
286    cache.constraint_nullspace.map.insert(key, z.clone());
287    cache.constraint_nullspace.order.push(key);
288    while cache.constraint_nullspace.map.len() > CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES {
289        if cache.constraint_nullspace.order.is_empty() {
290            break;
291        }
292        let oldkey = cache.constraint_nullspace.order.remove(0);
293        cache.constraint_nullspace.map.remove(&oldkey);
294    }
295
296    Ok((*z).clone())
297}
298
299pub(crate) fn matern_identifiability_transform(
300    centers: ArrayView2<'_, f64>,
301    identifiability: &MaternIdentifiability,
302) -> Result<Option<Array2<f64>>, BasisError> {
303    let k = centers.nrows();
304    match identifiability {
305        MaternIdentifiability::None => Ok(None),
306        MaternIdentifiability::CenterSumToZero => {
307            let q = Array2::<f64>::ones((k, 1));
308            Ok(Some(kernel_constraint_nullspace_from_matrix(q.view())?))
309        }
310        MaternIdentifiability::CenterLinearOrthogonal => {
311            // Mirror the Duchon path: auto-degrade to Zero (constant-only) when
312            // there aren't enough centers to affinely span [1, x_1, ..., x_d].
313            // kernel_constraint_nullspace_from_matrix would otherwise hard-error
314            // via rrqr_nullspace_basis when centers.nrows() < d + 1.
315            let effective_order =
316                duchon_effective_nullspace_order(centers, DuchonNullspaceOrder::Linear);
317            let q = polynomial_block_from_order(centers, effective_order);
318            Ok(Some(kernel_constraint_nullspace_from_matrix(q.view())?))
319        }
320        MaternIdentifiability::FrozenTransform { transform, .. } => {
321            if transform.nrows() != k {
322                crate::bail_dim_basis!(
323                    "frozen Matérn identifiability transform mismatch: centers={k}, transform rows={}",
324                    transform.nrows()
325                );
326            }
327            Ok(Some(transform.clone()))
328        }
329    }
330}
331
332pub(crate) fn build_matern_operator_penalty_candidates(
333    centers: ArrayView2<'_, f64>,
334    length_scale: f64,
335    nu: MaternNu,
336    include_intercept: bool,
337    z_opt: Option<&Array2<f64>>,
338    aniso_log_scales: Option<&[f64]>,
339) -> Result<Vec<PenaltyCandidate>, BasisError> {
340    let ops = build_matern_collocation_operator_matrices(
341        centers,
342        None,
343        length_scale,
344        nu,
345        include_intercept,
346        z_opt.map(|z| z.view()),
347        aniso_log_scales,
348    )?;
349    // Gate the operator dials on the Matérn-ν RKHS smoothness so a rough kernel
350    // (e.g. ν=1/2) is not over-smoothed by a higher-order roughness penalty its
351    // own RKHS norm does not control (#707).
352    let matern_spec = DuchonOperatorPenaltySpec::matern_for_smoothness(nu, centers.ncols());
353    Ok(operator_penalty_candidates_from_collocation(
354        &ops.d0,
355        &ops.d1,
356        &ops.d2,
357        &matern_spec,
358    ))
359}
360
361/// Decide whether the matern double-penalty path emits the
362/// `DoublePenaltyNullspace` shrinkage candidate, honoring a FROZEN bootstrap-κ
363/// decision when one is present (gam#787/#860). `frozen` is
364/// `MaternIdentifiability::FrozenTransform`'s `nullspace_shrinkage_survived`:
365/// `Some(b)` forces the answer (so the learned-penalty count stays invariant as
366/// the κ-optimizer rebuilds the design), `None` falls back to the κ-dependent
367/// spectral test (the cold-build / non-frozen behavior). Returns the emitted
368/// candidate list together with the realized decision so the caller can record
369/// it into the basis metadata for the freeze step.
370/// True when every entry of `m` is finite. A non-finite projected kernel Gram
371/// or shrinkage projector must never be turned into a penalty: its root feeds
372/// the λ-weighted range block whose eigensolve hard-rejects non-finite input
373/// ("range penalty block contains non-finite entries", gam#1379). On certain
374/// 1-D `matern(x)` / `bs="gp"` data geometries the projected kernel Gram is
375/// numerically degenerate enough that the eigensolver returns non-finite
376/// near-null eigenvectors, so the spectral-projector shrinkage block comes back
377/// non-finite; we drop that block rather than poison the whole penalty.
378fn matrix_all_finite(m: &Array2<f64>) -> bool {
379    m.iter().all(|v| v.is_finite())
380}
381
382pub(crate) fn matern_double_penalty_candidates_with_decision(
383    primary: &Array2<f64>,
384    frozen: Option<bool>,
385) -> Result<(Vec<PenaltyCandidate>, bool), BasisError> {
386    // gam#1379 — guard the Primary projected kernel Gram itself. It is `Zᵀ K Z`
387    // with a finite Matérn kernel `K`, so it is finite in exact arithmetic; if a
388    // degenerate trial geometry made it non-finite we cannot ship it as a
389    // penalty (the range-block eigensolve would abort the fit). Surface a clear
390    // basis error instead of an opaque downstream "non-finite range penalty".
391    if !matrix_all_finite(primary) {
392        crate::bail_invalid_basis!(
393            "Matérn double-penalty primary kernel Gram is non-finite; the projected \
394             kernel `Zᵀ K Z` could not be formed at this length scale (degenerate \
395             geometry). Widen the data spread, change the length scale, or drop the term."
396        );
397    }
398    let mut candidates = vec![normalize_penalty_candidate(
399        primary.clone(),
400        0,
401        PenaltySource::Primary,
402    )];
403    let survived = match frozen {
404        Some(forced) => {
405            if forced
406                && let Some(shrinkage) = build_nullspace_shrinkage_penalty(primary)?
407                && matrix_all_finite(&shrinkage.sym_penalty)
408            {
409                candidates.push(normalize_penalty_candidate(
410                    shrinkage.sym_penalty,
411                    0,
412                    PenaltySource::DoublePenaltyNullspace,
413                ));
414                true
415            } else {
416                // Forced ON but the projected kernel has no near-zero direction
417                // at this κ (so there is literally no shrinkage subspace to
418                // build), OR forced OFF: emit only the primary kernel penalty.
419                // Forced-ON-without-a-subspace cannot manufacture a 7th penalty,
420                // but the frozen path only sets `Some(true)` when the bootstrap κ
421                // DID find a subspace, and the projected-kernel null space is a
422                // geometric property of the centers/transform (κ rescales every
423                // eigenvalue together), so the subspace persists across rebuilds.
424                false
425            }
426        }
427        None => {
428            if let Some(shrinkage) = build_nullspace_shrinkage_penalty(primary)?
429                && matrix_all_finite(&shrinkage.sym_penalty)
430            {
431                candidates.push(normalize_penalty_candidate(
432                    shrinkage.sym_penalty,
433                    0,
434                    PenaltySource::DoublePenaltyNullspace,
435                ));
436                true
437            } else {
438                false
439            }
440        }
441    };
442    Ok((candidates, survived))
443}
444
445pub(crate) fn build_matern_double_penalty_candidates(
446    spline: &MaternSplineBasis,
447    full_transform: Option<&Array2<f64>>,
448    frozen_nullspace_shrinkage_survived: Option<bool>,
449) -> Result<(Vec<PenaltyCandidate>, bool), BasisError> {
450    let primary = project_penalty_matrix(&spline.penalty_kernel, full_transform);
451    matern_double_penalty_candidates_with_decision(&primary, frozen_nullspace_shrinkage_survived)
452}
453
454/// Creates a Matérn spline basis from data and centers.
455///
456/// The design is `[K | 1]` when `include_intercept=true` and `[K]` otherwise, where:
457/// - `K_ij = k(||x_i - c_j||; length_scale, nu)` is the Matérn kernel block.
458///
459/// The default kernel penalty is `alpha' S alpha` with `S_jl = k(||c_j - c_l||)`, embedded
460/// in the full coefficient space. With intercept included, that column is unpenalized by
461/// `penalty_kernel`; optional `penalty_ridge` is a nullspace projector used for
462/// double-penalty shrinkage of previously unpenalized directions.
463///
464/// NOTE: This follows the RKHS Gram construction S = K_CC (not K_CC^{-1}) in
465/// coefficient space, with global scaling absorbed by the smoothing parameter λ.
466pub fn create_matern_spline_basiswithworkspace(
467    data: ArrayView2<'_, f64>,
468    centers: ArrayView2<'_, f64>,
469    length_scale: f64,
470    nu: MaternNu,
471    include_intercept: bool,
472    aniso_log_scales: Option<&[f64]>,
473    workspace: &mut BasisWorkspace,
474) -> Result<MaternSplineBasis, BasisError> {
475    let n = data.nrows();
476    let d = data.ncols();
477    let k = centers.nrows();
478    let total_cols = k + usize::from(include_intercept);
479    let dense_bytes = dense_design_bytes(n, total_cols);
480    if dense_bytes > workspace.policy().max_single_materialization_bytes {
481        crate::bail_invalid_basis!(
482            "Matérn basis dense design exceeds resource policy: n={n}, p={total_cols}, dense={:.1} MiB, cap={:.1} MiB",
483            dense_bytes as f64 / (1024.0 * 1024.0),
484            workspace.policy().max_single_materialization_bytes as f64 / (1024.0 * 1024.0),
485        );
486    }
487
488    if d == 0 {
489        crate::bail_invalid_basis!("Matérn basis requires at least one covariate dimension");
490    }
491    if k == 0 {
492        crate::bail_invalid_basis!("Matérn basis requires at least one center");
493    }
494    if centers.ncols() != d {
495        crate::bail_dim_basis!(
496            "Matérn basis dimension mismatch: data has {d} columns, centers have {}",
497            centers.ncols()
498        );
499    }
500    if data.iter().any(|v| !v.is_finite()) || centers.iter().any(|v| !v.is_finite()) {
501        crate::bail_invalid_basis!("Matérn basis requires finite data and center values");
502    }
503    validate_matern_length_scale(length_scale)?;
504    if let Some(eta) = aniso_log_scales {
505        if eta.len() != d {
506            crate::bail_dim_basis!(
507                "aniso_log_scales length {} does not match data dimension {d}",
508                eta.len()
509            );
510        }
511        if eta.iter().any(|v| !v.is_finite()) {
512            crate::bail_invalid_basis!("aniso_log_scales must contain finite values");
513        }
514    }
515
516    // Practical safe operating range for κ from center geometry (document Eq. D.2):
517    //   κ in [1e-2 / r_max, 1e2 / r_min], with κ = 1/length_scale.
518    // Warn rather than silently clamp so callers keep explicit control.
519    // Under anisotropy the kernel metric is y-space (y_a = exp(η_a) x_a), so
520    // the relevant r_min/r_max are y-space pairwise distances, not raw.
521    let warn_bounds = if let Some(eta) = aniso_log_scales {
522        let y_centers = points_in_aniso_y_space(centers, eta);
523        pairwise_distance_bounds(y_centers.view())
524    } else {
525        pairwise_distance_bounds(centers)
526    };
527    if let Some((r_min, r_max)) = warn_bounds {
528        let kappa = 1.0 / length_scale.max(1e-300);
529        let kappa_lo = 1e-2 / r_max;
530        let kappa_hi = 1e2 / r_min;
531        if kappa < kappa_lo || kappa > kappa_hi {
532            log::debug!(
533                "Matérn κ={} is outside recommended range [{}, {}] derived from centers (r_min={}, r_max={}); kernel conditioning may degrade",
534                kappa,
535                kappa_lo,
536                kappa_hi,
537                r_min,
538                r_max
539            );
540        }
541    }
542
543    // Distance computation: anisotropic when eta is present, isotropic otherwise.
544    // Under anisotropy we work in y-space (y = Ax), so r = |Ah| replaces |h|.
545    let mut kernel_block = Array2::<f64>::zeros((n, k));
546    let mut center_kernel = Array2::<f64>::zeros((k, k));
547    let axis_scales = aniso_log_scales.map(aniso_axis_scales);
548    let kernel_result: Result<(), BasisError> = kernel_block
549        .axis_iter_mut(Axis(0))
550        .into_par_iter()
551        .enumerate()
552        .try_for_each(|(i, mut row)| {
553            for j in 0..k {
554                let r = if let Some(scales) = axis_scales.as_deref() {
555                    aniso_distance_rows_with_scales(data, i, centers, j, scales)
556                } else {
557                    euclidean_distance_rows(data, i, centers, j)
558                };
559                row[j] = matern_kernel_from_distance(r, length_scale, nu)?;
560            }
561            Ok(())
562        });
563    kernel_result?;
564    // Center-center Gram matrix K_CC. In RKHS form, the kernel penalty on
565    // radial coefficients is alpha^T K_CC alpha.
566    fill_symmetric_from_row_kernel(&mut center_kernel, |i, j| {
567        let r = if let Some(scales) = axis_scales.as_deref() {
568            aniso_distance_rows_with_scales(centers, i, centers, j, scales)
569        } else {
570            euclidean_distance_rows(centers, i, centers, j)
571        };
572        matern_kernel_from_distance(r, length_scale, nu)
573    })?;
574
575    let mut basis = Array2::<f64>::zeros((n, total_cols));
576    basis.slice_mut(s![.., 0..k]).assign(&kernel_block);
577    if include_intercept {
578        basis.column_mut(k).fill(1.0);
579    }
580
581    let mut penalty_kernel = Array2::<f64>::zeros((total_cols, total_cols));
582    // RKHS coefficient penalty uses the center Gram matrix directly:
583    //   S = K_CC  (not K_CC^{-1}).
584    // This matches Duchon/Matérn spline theory where alpha^T K_CC alpha is the
585    // native-space quadratic form up to a global scaling absorbed by lambda.
586    penalty_kernel
587        .slice_mut(s![0..k, 0..k])
588        .assign(&center_kernel);
589    let penalty_ridge = build_nullspace_shrinkage_penalty(&penalty_kernel)?
590        .map(|block| block.sym_penalty)
591        .unwrap_or_else(|| Array2::<f64>::zeros((total_cols, total_cols)));
592
593    Ok(MaternSplineBasis {
594        basis,
595        penalty_kernel,
596        penalty_ridge,
597        num_kernel_basis: k,
598        num_polynomial_basis: usize::from(include_intercept),
599        dimension: d,
600    })
601}
602
603#[inline]
604pub(crate) fn validate_lat_lon_matrix(
605    data: ArrayView2<'_, f64>,
606    context: &str,
607    radians: bool,
608) -> Result<(), BasisError> {
609    if data.ncols() != 2 {
610        crate::bail_dim_basis!(
611            "{context} requires exactly two columns: latitude and longitude; got {}",
612            data.ncols()
613        );
614    }
615    if data.nrows() == 0 {
616        crate::bail_invalid_basis!("{context} requires at least one row");
617    }
618    let (lat_lo, lat_hi, unit) = if radians {
619        (
620            -std::f64::consts::FRAC_PI_2,
621            std::f64::consts::FRAC_PI_2,
622            "radians",
623        )
624    } else {
625        (-90.0, 90.0, "degrees")
626    };
627    for (i, row) in data.outer_iter().enumerate() {
628        let lat = row[0];
629        let lon = row[1];
630        if !lat.is_finite() || !lon.is_finite() {
631            crate::bail_invalid_basis!(
632                "{context} requires finite latitude/longitude; row {i} has ({lat}, {lon})"
633            );
634        }
635        if !(lat_lo..=lat_hi).contains(&lat) {
636            crate::bail_invalid_basis!(
637                "{context} latitude must be in [{lat_lo}, {lat_hi}] {unit}; row {i} has {lat}"
638            );
639        }
640    }
641    Ok(())
642}
643
644pub fn spherical_wahba_kernel_matrix(
645    data: ArrayView2<'_, f64>,
646    centers: ArrayView2<'_, f64>,
647    penalty_order: usize,
648    radians: bool,
649) -> Result<Array2<f64>, BasisError> {
650    spherical_wahba_kernel_matrix_with_kind(
651        data,
652        centers,
653        penalty_order,
654        radians,
655        SphereWahbaKernel::Sobolev,
656    )
657}
658
/// Evaluates the chosen spherical Wahba kernel between every data row and
/// every center, returning an `n × k` matrix (`n` data rows, `k` centers).
///
/// Both inputs are validated with `validate_lat_lon_matrix`: two columns,
/// latitude in column 0 and longitude in column 1, in radians when `radians`
/// is true and in degrees otherwise.
///
/// Rows are processed in parallel in chunks of 256; within a row the centers
/// are handled four at a time with SIMD, with a scalar loop for the remainder.
///
/// # Errors
///
/// Returns the validation error for malformed input, or an invalid-basis
/// error when any kernel evaluation is non-finite or fails.
pub fn spherical_wahba_kernel_matrix_with_kind(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    penalty_order: usize,
    radians: bool,
    kernel: SphereWahbaKernel,
) -> Result<Array2<f64>, BasisError> {
    validate_lat_lon_matrix(data, "spherical spline data", radians)?;
    validate_lat_lon_matrix(centers, "spherical spline centers", radians)?;
    let n = data.nrows();
    let k = centers.nrows();
    // Factor that converts the input angles to radians.
    let deg = if radians {
        1.0
    } else {
        std::f64::consts::PI / 180.0
    };
    // Precompute (sin_lat, cos_lat, sin_lon, cos_lon) for each center once.
    // Using cos(lon - lon_c) = cos(lon)·cos(lon_c) + sin(lon)·sin(lon_c)
    // collapses the inner-loop trig from one `.cos()` per (i, j) down to
    // four multiplies and an add — a ~10x speedup on the inner body at
    // large-scale N·K.
    let mut sin_lat_c = Vec::<f64>::with_capacity(k);
    let mut cos_lat_c = Vec::<f64>::with_capacity(k);
    let mut sin_lon_c = Vec::<f64>::with_capacity(k);
    let mut cos_lon_c = Vec::<f64>::with_capacity(k);
    for c in centers.outer_iter() {
        let lat = c[0] * deg;
        let lon = c[1] * deg;
        let (s_lat, c_lat) = lat.sin_cos();
        let (s_lon, c_lon) = lon.sin_cos();
        sin_lat_c.push(s_lat);
        cos_lat_c.push(c_lat);
        sin_lon_c.push(s_lon);
        cos_lon_c.push(c_lon);
    }
    let mut out = Array2::<f64>::zeros((n, k));
    // Set by any worker that hits a failed or non-finite kernel value; read
    // once after the parallel loop has finished.
    let err_flag = std::sync::atomic::AtomicBool::new(false);
    out.axis_chunks_iter_mut(ndarray::Axis(0), 256)
        .into_par_iter()
        .enumerate()
        .for_each(|(chunk_idx, mut block)| {
            use wide::f64x4;
            // Must match the chunk size passed to `axis_chunks_iter_mut` above.
            let row_offset = chunk_idx * 256;
            let chunks = k / 4;
            let tail = k % 4;
            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
                let i = row_offset + local_i;
                let lat = data[(i, 0)] * deg;
                let lon = data[(i, 1)] * deg;
                let (sin_lat, cos_lat) = lat.sin_cos();
                let (sin_lon, cos_lon) = lon.sin_cos();
                // Broadcast this row's trig values across all four lanes.
                let sin_lat_v = f64x4::from(sin_lat);
                let cos_lat_v = f64x4::from(cos_lat);
                let sin_lon_v = f64x4::from(sin_lon);
                let cos_lon_v = f64x4::from(cos_lon);
                // SIMD over 4 centers at a time.
                for cidx in 0..chunks {
                    let base = cidx * 4;
                    let sl_c = f64x4::from([
                        sin_lat_c[base],
                        sin_lat_c[base + 1],
                        sin_lat_c[base + 2],
                        sin_lat_c[base + 3],
                    ]);
                    let cl_c = f64x4::from([
                        cos_lat_c[base],
                        cos_lat_c[base + 1],
                        cos_lat_c[base + 2],
                        cos_lat_c[base + 3],
                    ]);
                    let sn_c = f64x4::from([
                        sin_lon_c[base],
                        sin_lon_c[base + 1],
                        sin_lon_c[base + 2],
                        sin_lon_c[base + 3],
                    ]);
                    let cn_c = f64x4::from([
                        cos_lon_c[base],
                        cos_lon_c[base + 1],
                        cos_lon_c[base + 2],
                        cos_lon_c[base + 3],
                    ]);
                    // cos(lon - lon_c) via the angle-difference identity, then
                    // the spherical law of cosines for the angle between the
                    // data point and each center.
                    let dlon_cos = cos_lon_v * cn_c + sin_lon_v * sn_c;
                    let cos_gamma = sin_lat_v * sl_c + cos_lat_v * cl_c * dlon_cos;
                    let vals =
                        wahba_sphere_kernel_from_cos_simd_kind(cos_gamma, penalty_order, kernel);
                    let arr = vals.to_array();
                    for lane in 0..4 {
                        if !arr[lane].is_finite() {
                            // Abandons the rest of this chunk only; other
                            // chunks keep running and the flag is checked
                            // after the loop.
                            err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
                            return;
                        }
                        out_row[base + lane] = arr[lane];
                    }
                }
                // Scalar tail (0..3 elements).
                let tail_start = chunks * 4;
                for t in 0..tail {
                    let j = tail_start + t;
                    let dlon_cos = cos_lon * cos_lon_c[j] + sin_lon * sin_lon_c[j];
                    let cos_gamma = sin_lat * sin_lat_c[j] + cos_lat * cos_lat_c[j] * dlon_cos;
                    match wahba_sphere_kernel_from_cos_kind(cos_gamma, penalty_order, kernel) {
                        Ok(v) => out_row[j] = v,
                        // The underlying error is discarded; only the flag is
                        // recorded and a generic message is reported below.
                        Err(_) => {
                            err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
                            return;
                        }
                    }
                }
            }
        });
    if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
        crate::bail_invalid_basis!("spherical spline kernel produced a non-finite value");
    }
    Ok(out)
}
775
776pub(crate) fn weighted_coefficient_sum_to_zero_transform(
777    weights: ArrayView1<'_, f64>,
778) -> Result<Array2<f64>, BasisError> {
779    let k = weights.len();
780    if k < 2 {
781        return Err(BasisError::InsufficientColumnsForConstraint { found: k });
782    }
783    if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
784        crate::bail_invalid_basis!(
785            "sphere coefficient constraint weights must be finite and non-negative"
786        );
787    }
788    let norm = weights.iter().map(|w| w * w).sum::<f64>().sqrt();
789    if norm <= 0.0 {
790        crate::bail_invalid_basis!("sphere coefficient constraint weights cannot all be zero");
791    }
792    let c = Array2::from_shape_vec((k, 1), weights.iter().map(|w| *w / norm).collect())
793        .map_err(|e| BasisError::InvalidInput(format!("invalid sphere constraint weights: {e}")))?;
794    let (z, rank) =
795        rrqr_nullspace_basis(&c, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
796    if rank >= k {
797        return Err(BasisError::ConstraintNullspaceCollapsed {
798            site: "weighted_coefficient_sum_to_zero_transform",
799            cross_rank: rank,
800            coeff_dim: k,
801            cross_frobenius: 1.0,
802            gram_spectrum: "not computed (structural rank collapse before Gram eigendecomposition)"
803                .to_string(),
804        });
805    }
806    Ok(z)
807}
808
809pub fn select_spherical_farthest_point_centers(
810    data: ArrayView2<'_, f64>,
811    num_centers: usize,
812    radians: bool,
813) -> Result<Array2<f64>, BasisError> {
814    validate_lat_lon_matrix(data, "spherical farthest-point centers", radians)?;
815    if num_centers == 0 {
816        crate::bail_invalid_basis!("spherical farthest-point center count must be positive");
817    }
818
819    let to_units = if radians {
820        1.0
821    } else {
822        180.0 / std::f64::consts::PI
823    };
824    let golden_angle = std::f64::consts::PI * (3.0 - 5.0_f64.sqrt());
825    let mut centers = Array2::<f64>::zeros((num_centers, 2));
826    for i in 0..num_centers {
827        let z = (2.0 * i as f64 + 1.0) / num_centers as f64 - 1.0;
828        let lon = ((i as f64) * golden_angle + std::f64::consts::PI)
829            .rem_euclid(std::f64::consts::TAU)
830            - std::f64::consts::PI;
831        centers[[i, 0]] = z.asin() * to_units;
832        centers[[i, 1]] = lon * to_units;
833    }
834    Ok(centers)
835}
836
/// Decide whether a dense basis evaluation should be streamed in row chunks
/// and, if so, how many rows each chunk should hold.
///
/// There is no public opt-in knob for this: streaming switches on by itself
/// once the dense buffer, `n_rows * n_basis_cols * 8` bytes, is strictly
/// larger than 1 GiB. The chunk is then sized so that one resident chunk
/// holds about 256 MiB of `f64` values, i.e.
/// `(256 MiB) / (n_basis_cols * 8)` rows, raised to at least 1024 rows and
/// finally capped at `n_rows`.
///
/// Returns `None` when either dimension is zero or the dense buffer fits in
/// the 1 GiB budget, meaning the caller should materialize densely. All size
/// arithmetic saturates instead of overflowing.
pub fn auto_streaming_chunk_size_for_dense(n_rows: usize, n_basis_cols: usize) -> Option<usize> {
    const BYTES_PER_VALUE: usize = 8;
    const STREAMING_TRIGGER_BYTES: usize = 1 << 30;
    const RESIDENT_CHUNK_BYTES: usize = 256 << 20;
    const CHUNK_ROWS_FLOOR: usize = 1024;

    let has_no_cells = n_rows == 0 || n_basis_cols == 0;
    let total_bytes = n_rows
        .saturating_mul(n_basis_cols)
        .saturating_mul(BYTES_PER_VALUE);
    if has_no_cells || total_bytes <= STREAMING_TRIGGER_BYTES {
        return None;
    }

    // Bytes in one row; kept at least 1 so the division below is always defined.
    let bytes_per_row = std::cmp::max(n_basis_cols.saturating_mul(BYTES_PER_VALUE), 1);
    let rows_in_budget = RESIDENT_CHUNK_BYTES / bytes_per_row;

    // Apply the floor first and the row-count cap last, so the cap wins
    // whenever there are fewer rows than the floor.
    let floored = std::cmp::max(rows_in_budget, CHUNK_ROWS_FLOOR);
    Some(std::cmp::min(floored, n_rows))
}