gam 0.3.115

Generalized penalized likelihood engine
//! Constant-curvature (`M_κ`) smooth term: basis + penalty over the
//! κ-stereographic chart (#944, stage 3 step 1).
//!
//! The term is the κ-generic sibling of the intrinsic-S² Wahba smooth
//! (`sphere_spec.rs` / `build_spherical_spline_basis`): a reproducing-kernel
//! basis on a center set, with the kernel Gram on the centers as the RKHS
//! roughness penalty and a coefficient-space sum-to-zero constraint for
//! identifiability. Where the Wahba smooth hard-codes S² (lat/lon chart,
//! Legendre kernels), this term takes the geometry from
//! [`crate::geometry::constant_curvature::ConstantCurvature`] at an explicit
//! curvature κ, so one construction covers the whole interpolation
//! `S^d(1/√κ) → ℝ^d → H^d(1/√(−κ))` through κ = 0.
//!
//! # Kernel
//!
//! `K_κ(x, y) = exp(−d_κ(x, y) / ℓ)` — the geodesic-exponential kernel, where
//! `d_κ` is the exact constant-curvature geodesic distance in the
//! κ-stereographic chart. The geodesic distance is a kernel of conditionally
//! negative type on all three constant-curvature space forms (Schoenberg 1942
//! for `S^d`; classical CND of `‖·‖` on `ℝ^d`; Faraut–Harzallah 1974 for
//! `H^d`), so `exp(−c·d_κ)` is positive definite for every `c > 0` and every
//! κ — the Gram on distinct centers is strictly PD, which is exactly what the
//! RKHS penalty construction needs. At κ = 0 the chart carries the doubled
//! gauge (`metric 4δ`, `d_0(x, y) = 2‖x − y‖`), so the κ = 0 term is the
//! Euclidean exponential (Matérn-½) kernel smooth with effective Euclidean
//! range `ℓ/2`.
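/*
A standalone sketch of the distance behind the kernel section above. It assumes the chart is the usual κ-stereographic model with metric factor `2/(1 + κ‖x‖²)` (consistent with the doubled κ = 0 gauge stated there) and the standard Möbius-addition distance `d_κ(x, y) = 2·tan_κ⁻¹(‖(−x) ⊕_κ y‖)`. It is not this crate's `ConstantCurvature::distance`, and every function name in it is hypothetical. It only illustrates how the gauge and the three space forms fall out of one formula.

```rust
// Hypothetical helpers; not this crate's API.

/// Euclidean inner product <a, b>.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Chart membership: finite coordinates and 1 + κ‖x‖² > 0.
fn in_chart(x: &[f64], kappa: f64) -> bool {
    x.iter().all(|v| v.is_finite()) && 1.0 + kappa * dot(x, x) > 0.0
}

/// κ-Möbius addition x ⊕_κ y in the κ-stereographic chart.
fn mobius_add(x: &[f64], y: &[f64], kappa: f64) -> Vec<f64> {
    let (xy, x2, y2) = (dot(x, y), dot(x, x), dot(y, y));
    let a = 1.0 - 2.0 * kappa * xy - kappa * y2;
    let b = 1.0 + kappa * x2;
    let den = 1.0 - 2.0 * kappa * xy + kappa * kappa * x2 * y2;
    x.iter().zip(y).map(|(xi, yi)| (a * xi + b * yi) / den).collect()
}

/// tan_κ⁻¹: scaled arctan for κ > 0, scaled artanh for κ < 0, identity at κ = 0.
fn tan_kappa_inv(u: f64, kappa: f64) -> f64 {
    if kappa > 0.0 {
        let s = kappa.sqrt();
        (s * u).atan() / s
    } else if kappa < 0.0 {
        let s = (-kappa).sqrt();
        (s * u).atanh() / s
    } else {
        u
    }
}

/// d_κ(x, y) = 2 · tan_κ⁻¹(‖(−x) ⊕_κ y‖).
fn geodesic_distance(x: &[f64], y: &[f64], kappa: f64) -> f64 {
    let neg_x: Vec<f64> = x.iter().map(|v| -v).collect();
    let w = mobius_add(&neg_x, y, kappa);
    2.0 * tan_kappa_inv(dot(&w, &w).sqrt(), kappa)
}

/// Geodesic-exponential kernel exp(−d_κ(x, y)/ℓ).
fn geodesic_exp_kernel(x: &[f64], y: &[f64], kappa: f64, ell: f64) -> f64 {
    (-geodesic_distance(x, y, kappa) / ell).exp()
}
```

At κ = 0 the Möbius addition collapses to `y − x`, so the sketch reproduces `d_0 = 2‖x − y‖` and a kernel `exp(−‖x − y‖/(ℓ/2))`. From the chart origin it gives `2·artanh‖y‖` on the Poincaré ball (κ = −1) and `2·arctan‖y‖` on the unit sphere (κ = 1).
*/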
//!
//! # κ-differentiability contract (what the ψ-channel stage consumes)
//!
//! Every κ-moving piece of this construction is differentiable in κ via the
//! exact κ-jets landed in stage 2, and every κ-FIXED piece is documented as
//! such so the later ψ-channel wiring (`∂X/∂κ`, `∂S/∂κ` into the LAML outer
//! gradient, Matérn iso-κ optimizer as the template) needs no new calculus:
//!
//! - **Centers are κ-fixed.** Center selection runs in chart coordinates
//!   (farthest-point / k-means / user-provided) and deliberately does NOT
//!   consult κ, so `∂(centers)/∂κ ≡ 0` and the design moves with κ only
//!   through the kernel. A κ-dependent center rule would add an
//!   uncontrolled, non-smooth term to the design drift.
//! - **The length scale ℓ is κ-fixed.** The auto-initialized ℓ is derived
//!   from chart-coordinate (κ = 0 gauge) center spacing only, and an
//!   explicit user ℓ is a constant. `∂ℓ/∂κ ≡ 0`.
//! - **The constraint transform `z` is κ-fixed.** Uniform coefficient
//!   weights; at fit time the global identifiability pipeline composes the
//!   parametric orthogonalization onto it and the result is FROZEN
//!   (mirroring `SphericalSplineIdentifiability::FrozenTransform`, #532), so
//!   the predict/ψ-trial rebuild replays the same `z` verbatim.
//! - **The kernel has exact κ-jets.** `∂K/∂κ` and `∂²K/∂κ²` follow from
//!   `distance_kappa_jet` (Tower4-exact, FD-gated) by the chain rule — see
//!   [`constant_curvature_kernel_kappa_jets`]. Therefore:
//!   `∂X_raw/∂κ = ∂K(data, centers)/∂κ`, realized design drift
//!   `∂X/∂κ = (∂K/∂κ)·z`, and penalty drift `∂S_raw/∂κ = zᵀ(∂K(centers,
//!   centers)/∂κ)z` are all available in closed form from this module today.
//!   (The penalty handed to the optimizer is Frobenius-normalized; the
//!   ψ-channel must route its κ-derivative through the same normalization
//!   rule — `normalize_penaltywith_psi_derivatives` is the existing seam.)
//! - **Available but not yet consumed:** `log_map_kappa_jet` /
//!   `exp_map_kappa_jet` cover future geodesic/normal-coordinate basis
//!   variants (e.g. tangent-space designs); the distance jet is the only one
//!   this kernel construction needs.
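/*
The κ-contract above leaves one derivative to the ψ-channel: the penalty normalization. Assuming `normalize_penalty` divides the constrained penalty `S = zᵀ K_κ(centers, centers) z` by its Frobenius norm (the docs above say the penalty is Frobenius-normalized; the exact rule is not shown in this file), the quotient rule gives:

```latex
\begin{aligned}
\tilde S(\kappa) &= \frac{S(\kappa)}{c(\kappa)}, \qquad c = \|S\|_F = \sqrt{\operatorname{tr}(S^\top S)},\\
\partial_\kappa c &= \frac{\operatorname{tr}\!\left(S^\top \partial_\kappa S\right)}{c},\\
\partial_\kappa \tilde S &= \frac{1}{c}\left(\partial_\kappa S - \frac{\operatorname{tr}\!\left(S^\top \partial_\kappa S\right)}{c^2}\, S\right).
\end{aligned}
```

Here `∂_κ S = zᵀ (∂K_κ(centers, centers)/∂κ) z` with the κ-fixed `z`. The bracket removes the component of `∂_κ S` parallel to `S`, since a pure rescaling of `S` leaves `S/‖S‖_F` unchanged.
*/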

use ndarray::{Array1, Array2, ArrayView2, Axis};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

use crate::geometry::constant_curvature::{ConstantCurvature, distance_kappa_jet};

use super::{
    BasisBuildResult, BasisError, BasisMetadata, CenterStrategy, PenaltyCandidate, PenaltySource,
    filter_active_penalty_candidates_with_ops, normalize_penalty, select_centers_by_strategy,
    weighted_coefficient_sum_to_zero_transform,
};

/// Realized-design identifiability policy for the constant-curvature smooth.
/// Mirrors [`super::SphericalSplineIdentifiability`] (#532): the fit-time
/// center-space sum-to-zero `z` gets the parametric orthogonalization composed
/// onto it by the global identifiability pipeline, and the composed transform
/// is frozen here so predict-time (and future per-ψ-trial) rebuilds replay it
/// verbatim instead of recomputing `z` from the centers.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ConstantCurvatureIdentifiability {
    /// Fit-time default: uniform-weight coefficient sum-to-zero over the
    /// centers (`Σ_j α_j = 0`), then global parametric residualization.
    #[default]
    CenterSumToZero,
    /// Predict-time replay: the frozen composed transform captured at fit
    /// time. `transform.nrows()` equals the number of centers.
    FrozenTransform { transform: Array2<f64> },
}
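/*
The `CenterSumToZero` policy needs a full-column-rank `z` with `1ᵀz = 0`, so that every constrained coefficient vector `α = zβ` satisfies `Σ_j α_j = 0`. The sketch below builds one with a Householder reflection that maps `e₁` to `1/√m`; the remaining `m − 1` columns are then orthonormal and each sums to zero. This is one standard construction, not necessarily what `weighted_coefficient_sum_to_zero_transform` does (its body is not shown here), and the function name is hypothetical. Matrices are plain row-major `Vec<Vec<f64>>` to keep the sketch dependency-free.

```rust
/// Hypothetical helper: an m x (m-1) matrix whose columns are orthonormal
/// and each sum to zero (uniform weights), built from a Householder reflection.
fn sum_to_zero_householder(m: usize) -> Vec<Vec<f64>> {
    assert!(m >= 2, "need at least two centers");
    let s = 1.0 / (m as f64).sqrt();
    // v = u - e_1 with u = (1/√m) 1; the reflection H swaps u and e_1.
    let mut v = vec![s; m];
    v[0] -= 1.0;
    let vtv: f64 = v.iter().map(|x| x * x).sum();
    // H = I - 2 v vᵀ / (vᵀ v). Column 0 of H is u; keep columns 1..m,
    // which are orthogonal to u and therefore sum to zero.
    (0..m)
        .map(|i| {
            (1..m)
                .map(|j| {
                    let id = if i == j { 1.0 } else { 0.0 };
                    id - 2.0 * v[i] * v[j] / vtv
                })
                .collect()
        })
        .collect()
}
```

Any `z` with the same column space yields the same fitted function space. An orthonormal choice has the extra property that the eigenvalues of `zᵀKz` interlace those of `K`, so the constraint does not worsen the penalty's conditioning.
*/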

/// Constant-curvature smooth configuration (`curv(x, z, kappa = …)`).
///
/// The chart inputs are the raw feature columns interpreted as
/// κ-stereographic chart coordinates: any finite point for κ ≥ 0, the open
/// ball `‖x‖ < 1/√(−κ)` for κ < 0. The default κ = 0 reproduces a Euclidean
/// exponential-kernel smooth (in the doubled κ = 0 chart gauge), so the term
/// is safe to use as a drop-in flat smooth until κ becomes a fitted
/// ψ-coordinate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantCurvatureBasisSpec {
    /// Center/knot selection strategy in chart coordinates. Deliberately
    /// κ-independent (see the module-level κ-contract).
    pub center_strategy: CenterStrategy,
    /// Sectional curvature κ of the latent/feature geometry. Fixed at build
    /// time; the later ψ-channel stage promotes it to a fitted outer
    /// coordinate consuming this module's exact κ-jets.
    pub kappa: f64,
    /// Geodesic kernel range ℓ in `K_κ = exp(−d_κ/ℓ)`. The `0.0` sentinel
    /// requests the κ-independent auto initialization
    /// ([`realized_constant_curvature_length_scale`]); the realized value is
    /// persisted in [`BasisMetadata::ConstantCurvature`] and frozen back into
    /// the spec for predict-time replay.
    pub length_scale: f64,
    /// Also add an identity (ridge) shrinkage penalty on the constrained
    /// coefficients, alongside the RKHS Gram penalty.
    pub double_penalty: bool,
    /// Realized-design identifiability policy (see type docs).
    #[serde(default)]
    pub identifiability: ConstantCurvatureIdentifiability,
}

impl Default for ConstantCurvatureBasisSpec {
    fn default() -> Self {
        Self {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 50 },
            kappa: 0.0,
            length_scale: 0.0,
            double_penalty: true,
            identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
        }
    }
}

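/// Validate that `kappa` is finite and that every row of `points` is finite
/// and inside the κ-stereographic chart (`1 + κ‖x‖² > 0`; automatic for
/// κ ≥ 0, the open-ball constraint for κ < 0).
fn validate_chart_points(
    points: ArrayView2<'_, f64>,
    kappa: f64,
    what: &str,
) -> Result<(), BasisError> {
    // A NaN or infinite κ can pass the chart test below (comparisons with NaN
    // are false), so reject it explicitly; the public kernel entry points rely
    // on this check.
    if !kappa.is_finite() {
        crate::bail_invalid_basis!(
            "constant-curvature {what} validation needs a finite kappa; got {kappa}"
        );
    }
    for (i, row) in points.outer_iter().enumerate() {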
        let mut nx2 = 0.0_f64;
        for &v in row.iter() {
            if !v.is_finite() {
                crate::bail_invalid_basis!(
                    "constant-curvature {what} row {i} has a non-finite coordinate"
                );
            }
            nx2 += v * v;
        }
        if 1.0 + kappa * nx2 <= 0.0 {
            crate::bail_invalid_basis!(
                "constant-curvature {what} row {i} lies outside the κ-stereographic chart \
                 (need 1 + κ·‖x‖² > 0; got κ = {kappa}, ‖x‖² = {nx2}); for κ < 0 the chart is \
                 the open ball ‖x‖ < 1/√(−κ)"
            );
        }
    }
    Ok(())
}

/// `K_κ(data, centers)` — the geodesic-exponential kernel matrix
/// `exp(−d_κ(x_i, c_j)/ℓ)`.
pub fn constant_curvature_kernel_matrix(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    length_scale: f64,
) -> Result<Array2<f64>, BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "constant-curvature kernel dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    if !(length_scale.is_finite() && length_scale > 0.0) {
        crate::bail_invalid_basis!(
            "constant-curvature kernel needs a positive finite length_scale; got {length_scale}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;
    let manifold = ConstantCurvature::new(data.ncols(), kappa);
    let mut out = Array2::<f64>::zeros((data.nrows(), centers.nrows()));
    out.axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .try_for_each(|(i, mut row)| -> Result<(), BasisError> {
            for (j, c) in centers.outer_iter().enumerate() {
                let d = manifold.distance(data.row(i), c).map_err(|e| {
                    BasisError::InvalidInput(format!(
                        "constant-curvature distance failed at (row {i}, center {j}): {e}"
                    ))
                })?;
                row[j] = (-d / length_scale).exp();
            }
            Ok(())
        })?;
    Ok(out)
}

/// `(K, ∂K/∂κ, ∂²K/∂κ²)` of the raw (pre-constraint) kernel matrix — the
/// ψ-channel hook. Exact: rides `distance_kappa_jet` (Tower4, FD-gated in
/// `geometry::constant_curvature`) through the chain rule for
/// `K = exp(−d/ℓ)` at κ-FIXED ℓ and centers (see the module κ-contract):
///
/// ```text
///   ∂K/∂κ  = −(d′/ℓ) · K
///   ∂²K/∂κ² = ((d′/ℓ)² − d″/ℓ) · K
/// ```
///
/// The realized design/penalty drifts follow by the κ-fixed transforms:
/// `∂X/∂κ = (∂K/∂κ)·z` and `∂S_raw/∂κ = zᵀ(∂K/∂κ)z` (centers×centers), with
/// the Frobenius penalty normalization differentiated by the existing
/// `normalize_penaltywith_psi_derivatives` seam.
pub fn constant_curvature_kernel_kappa_jets(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    length_scale: f64,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "constant-curvature kernel-jet dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    if !(length_scale.is_finite() && length_scale > 0.0) {
        crate::bail_invalid_basis!(
            "constant-curvature kernel jets need a positive finite length_scale; got {length_scale}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;
    let manifold = ConstantCurvature::new(data.ncols(), kappa);
    let n = data.nrows();
    let m = centers.nrows();
    let mut value = Array2::<f64>::zeros((n, m));
    let mut dk = Array2::<f64>::zeros((n, m));
    let mut dkk = Array2::<f64>::zeros((n, m));
    let rows: Vec<(usize, Vec<(f64, f64, f64)>)> = (0..n)
        .into_par_iter()
        .map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
            let mut row = Vec::with_capacity(m);
            for (j, c) in centers.outer_iter().enumerate() {
                let (d, d1, d2) = distance_kappa_jet(&manifold, data.row(i), c).map_err(|e| {
                    BasisError::InvalidInput(format!(
                        "constant-curvature distance κ-jet failed at (row {i}, center {j}): {e}"
                    ))
                })?;
                let k = (-d / length_scale).exp();
                let g = d1 / length_scale;
                row.push((k, -g * k, (g * g - d2 / length_scale) * k));
            }
            Ok((i, row))
        })
        .collect::<Result<Vec<_>, BasisError>>()?;
    for (i, row) in rows {
        for (j, (k, k1, k2)) in row.into_iter().enumerate() {
            value[(i, j)] = k;
            dk[(i, j)] = k1;
            dkk[(i, j)] = k2;
        }
    }
    Ok((value, dk, dkk))
}
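/*
The chain-rule formulas in the doc comment above can be checked numerically without the crate: feed a distance jet `(d, d′, d″)` through them and compare against central finite differences of `K(κ) = exp(−d(κ)/ℓ)`. The quadratic `d(κ)` below is a hypothetical stand-in for `distance_kappa_jet`, chosen only because its jets are exact; the function names are also hypothetical.

```rust
/// Hypothetical smooth distance curve d(κ) = a + bκ + cκ², with exact jets.
fn toy_distance_jet(kappa: f64) -> (f64, f64, f64) {
    let (a, b, c) = (1.3, -0.4, 0.25);
    (a + b * kappa + c * kappa * kappa, b + 2.0 * c * kappa, 2.0 * c)
}

/// (K, ∂K/∂κ, ∂²K/∂κ²) for K = exp(−d/ℓ) from a distance jet (d, d′, d″):
/// ∂K/∂κ = −(d′/ℓ)K and ∂²K/∂κ² = ((d′/ℓ)² − d″/ℓ)K.
fn kernel_jet(d: f64, d1: f64, d2: f64, ell: f64) -> (f64, f64, f64) {
    let k = (-d / ell).exp();
    let g = d1 / ell;
    (k, -g * k, (g * g - d2 / ell) * k)
}

/// K(κ) evaluated directly from the toy distance.
fn kernel_at(kappa: f64, ell: f64) -> f64 {
    let (d, _, _) = toy_distance_jet(kappa);
    (-d / ell).exp()
}

/// Central finite differences of K in κ: (first, second) derivative.
fn kernel_fd(kappa: f64, ell: f64, h: f64) -> (f64, f64) {
    let kp = kernel_at(kappa + h, ell);
    let k0 = kernel_at(kappa, ell);
    let km = kernel_at(kappa - h, ell);
    ((kp - km) / (2.0 * h), (kp - 2.0 * k0 + km) / (h * h))
}
```

With `h = 1e-4` the central differences carry O(h²) truncation error, so the chain-rule jets and the finite differences should agree to roughly 1e-6 (first derivative) and 1e-5 (second); a sign slip in either formula would instead show up as a mismatch of the same order as the derivative itself.
*/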

/// Resolve the realized kernel range ℓ. An explicit positive `spec_length_scale`
/// is used verbatim; the `0.0` sentinel auto-initializes from the median
/// pairwise CHART distance among the centers (the upper median when the number
/// of pairs is even), doubled to match the κ = 0 chart gauge (`d_0 = 2‖Δ‖`).
///
/// κ-contract: the auto rule reads chart coordinates only — it never consults
/// κ — so the realized ℓ is a κ-CONSTANT and contributes no `∂ℓ/∂κ` term to
/// the design drift.
pub fn realized_constant_curvature_length_scale(
    centers: ArrayView2<'_, f64>,
    spec_length_scale: f64,
) -> Result<f64, BasisError> {
    if spec_length_scale.is_finite() && spec_length_scale > 0.0 {
        return Ok(spec_length_scale);
    }
    if spec_length_scale != 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
        );
    }
    let m = centers.nrows();
    if m < 2 {
        return Err(BasisError::InsufficientColumnsForConstraint { found: m });
    }
    let mut dists: Vec<f64> = Vec::with_capacity(m * (m - 1) / 2);
    for i in 0..m {
        for j in (i + 1)..m {
            let mut s = 0.0_f64;
            for k in 0..centers.ncols() {
                let dlt = centers[(i, k)] - centers[(j, k)];
                s += dlt * dlt;
            }
            dists.push(2.0 * s.sqrt());
        }
    }
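    // Report non-finite center coordinates as an input error instead of
    // panicking inside the sort comparator.
    if dists.iter().any(|d| !d.is_finite()) {
        crate::bail_invalid_basis!(
            "constant-curvature auto length_scale needs finite center coordinates"
        );
    }
    dists.sort_by(f64::total_cmp);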
    let median = dists[dists.len() / 2];
    if !(median.is_finite() && median > 0.0) {
        crate::bail_invalid_basis!(
            "constant-curvature auto length_scale failed: centers are degenerate \
             (median pairwise chart distance = {median})"
        );
    }
    Ok(median)
}
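/*
The auto-initialization rule for ℓ reads only chart coordinates. Restated as a dependency-free sketch (hypothetical name; `Option` stands in for the crate's `BasisError`), it is easy to see that κ never enters, which is what makes `∂ℓ/∂κ ≡ 0`. Like the function above, `dists[len / 2]` selects the upper median when the pair count is even.

```rust
/// Hypothetical restatement: doubled median pairwise Euclidean chart distance.
/// Returns None for fewer than two centers, non-finite input, or a zero median.
fn auto_length_scale(centers: &[Vec<f64>]) -> Option<f64> {
    let m = centers.len();
    if m < 2 {
        return None;
    }
    let mut dists = Vec::with_capacity(m * (m - 1) / 2);
    for i in 0..m {
        for j in (i + 1)..m {
            let s: f64 = centers[i]
                .iter()
                .zip(&centers[j])
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            // Doubled to match the κ = 0 chart gauge d_0 = 2‖Δ‖.
            dists.push(2.0 * s.sqrt());
        }
    }
    if dists.iter().any(|d| !d.is_finite()) {
        return None;
    }
    dists.sort_by(f64::total_cmp);
    let median = dists[dists.len() / 2];
    (median > 0.0).then_some(median)
}
```

For chart centers 0, 1, 3, 7 on a line, the doubled pairwise distances sort to 2, 4, 6, 8, 12, 14, and the rule returns 8 rather than the averaged median 7.
*/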

/// Build the constant-curvature reproducing-kernel smooth: realized design
/// `K_κ(data, centers)·z`, RKHS penalty `zᵀ K_κ(centers, centers) z`, and the
/// replayable [`BasisMetadata::ConstantCurvature`]. Structure mirrors the
/// Wahba S² builder (`build_spherical_spline_basis`); geometry comes from
/// `ConstantCurvature` at the spec's fixed κ.
pub fn build_constant_curvature_basis(
    data: ArrayView2<'_, f64>,
    spec: &ConstantCurvatureBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
    if data.ncols() == 0 {
        crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
    }
    if !spec.kappa.is_finite() {
        crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
    }
    validate_chart_points(data, spec.kappa, "data")?;
    let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
    if centers.nrows() < 2 {
        return Err(BasisError::InsufficientColumnsForConstraint {
            found: centers.nrows(),
        });
    }
    validate_chart_points(centers.view(), spec.kappa, "centers")?;
    let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
    let raw_penalty =
        constant_curvature_kernel_matrix(centers.view(), centers.view(), spec.kappa, length_scale)?;
    // Realized-design constraint transform: uniform coefficient sum-to-zero at
    // fit time; the frozen composed `z · z_parametric` at predict time (#532
    // pattern — see ConstantCurvatureIdentifiability).
    let z = match &spec.identifiability {
        ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
            if transform.nrows() != centers.nrows() {
                crate::bail_dim_basis!(
                    "frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
                    centers.nrows(),
                    transform.nrows()
                );
            }
            transform.clone()
        }
        ConstantCurvatureIdentifiability::CenterSumToZero => {
            let weights = Array1::<f64>::ones(centers.nrows());
            weighted_coefficient_sum_to_zero_transform(weights.view())?
        }
    };
    let penalty = z.t().dot(&raw_penalty).dot(&z);
    let raw_design =
        constant_curvature_kernel_matrix(data, centers.view(), spec.kappa, length_scale)?;
    let design = crate::matrix::DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        raw_design.dot(&z),
    ));
    let (penalty_norm, c_primary) = normalize_penalty(&((&penalty + &penalty.t()) * 0.5));
    let mut candidates = vec![PenaltyCandidate {
        matrix: penalty_norm,
        nullspace_dim_hint: 0,
        source: PenaltySource::Primary,
        normalization_scale: c_primary,
        kronecker_factors: None,
        op: None,
    }];
    if spec.double_penalty {
        let ridge = Array2::<f64>::eye(design.ncols());
        let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
        candidates.push(PenaltyCandidate {
            matrix: ridge_norm,
            nullspace_dim_hint: 0,
            source: PenaltySource::DoublePenaltyNullspace,
            normalization_scale: c_ridge,
            kronecker_factors: None,
            op: None,
        });
    }
    let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
        filter_active_penalty_candidates_with_ops(candidates)?;
    Ok(BasisBuildResult {
        design,
        penalties,
        nullspace_dims,
        penaltyinfo,
        metadata: BasisMetadata::ConstantCurvature {
            centers,
            kappa: spec.kappa,
            length_scale,
            constraint_transform: Some(z),
        },
        kronecker_factored: None,
        ops,
        null_eigenvectors,
        joint_null_rotation: None,
    })
}
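/*
The builder uses `zᵀ K_κ(centers, centers) z` as a proper roughness penalty, which relies on the module-level claim that the center Gram is strictly positive definite for distinct centers. A quick numerical check at κ = 0 (where the kernel is `exp(−2‖x − y‖/ℓ)` in the doubled gauge) attempts a Cholesky factorization. The helper names are hypothetical, and the check illustrates only the κ = 0 case.

```rust
/// Hypothetical helper: κ = 0 geodesic-exponential Gram, exp(−2‖x − y‖/ℓ).
fn gram_kappa0(pts: &[Vec<f64>], ell: f64) -> Vec<Vec<f64>> {
    pts.iter()
        .map(|p| {
            pts.iter()
                .map(|q| {
                    let r = p
                        .iter()
                        .zip(q)
                        .map(|(a, b)| (a - b) * (a - b))
                        .sum::<f64>()
                        .sqrt();
                    (-2.0 * r / ell).exp()
                })
                .collect()
        })
        .collect()
}

/// Returns true if the symmetric matrix admits a Cholesky factorization
/// with strictly positive pivots, i.e. it is positive definite.
fn is_positive_definite(a: &[Vec<f64>]) -> bool {
    let n = a.len();
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[i][j];
            for k in 0..j {
                s -= l[i][k] * l[j][k];
            }
            if i == j {
                if s <= 0.0 {
                    return false;
                }
                l[i][i] = s.sqrt();
            } else {
                l[i][j] = s / l[j][j];
            }
        }
    }
    true
}
```

Duplicated centers make two Gram rows identical at any κ, so the factorization hits a zero pivot; a center rule that can emit duplicates would leave the RKHS penalty singular regardless of curvature.
*/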