gam_terms/basis/constant_curvature_smooth.rs
1//! Constant-curvature (`M_κ`) smooth term: basis + penalty over the
2//! κ-stereographic chart (#944, stage 3 step 1).
3//!
4//! The term is the κ-generic sibling of the intrinsic-S² Wahba smooth
5//! (`sphere_spec.rs` / `build_spherical_spline_basis`): a reproducing-kernel
6//! basis on a center set, with the kernel Gram on the centers as the RKHS
7//! roughness penalty and a coefficient-space sum-to-zero constraint for
8//! identifiability. Where the Wahba smooth hard-codes S² (lat/lon chart,
9//! Legendre kernels), this term takes the geometry from
10//! [`gam_geometry::constant_curvature::ConstantCurvature`] at an explicit
11//! curvature κ, so one construction covers the whole interpolation
12//! `S^d(1/√κ) → ℝ^d → H^d(1/√−κ)` through κ = 0.
13//!
14//! # Kernel
15//!
16//! `K_κ(x, y) = exp(−d_κ(x, y) / ℓ)` — the geodesic-exponential kernel, where
17//! `d_κ` is the exact constant-curvature geodesic distance in the
18//! κ-stereographic chart. The geodesic distance is a kernel of conditionally
19//! negative type on all three constant-curvature space forms (Schoenberg 1942
20//! for `S^d`; classical CND of `‖·‖` on `ℝ^d`; Faraut–Harzallah 1974 for
21//! `H^d`), so `exp(−c·d_κ)` is positive definite for every `c > 0` and every
22//! κ — the Gram on distinct centers is strictly PD, which is exactly what the
23//! RKHS penalty construction needs. At κ = 0 the chart carries the doubled
24//! gauge (`metric 4δ`, `d_0(x, y) = 2‖x − y‖`), so the κ = 0 term is the
25//! Euclidean exponential (Matérn-½) kernel smooth with effective Euclidean
26//! range `ℓ/2`.
27//!
28//! # κ-differentiability contract (what the ψ-channel stage consumes)
29//!
30//! Every κ-moving piece of this construction is differentiable in κ via the
31//! exact κ-jets landed in stage 2, and every κ-FIXED piece is documented as
32//! such so the later ψ-channel wiring (`∂X/∂κ`, `∂S/∂κ` into the LAML outer
33//! gradient, Matérn iso-κ optimizer as the template) needs no new calculus:
34//!
35//! - **Centers are κ-fixed.** Center selection runs in chart coordinates
36//! (farthest-point / k-means / user-provided) and deliberately does NOT
37//! consult κ, so `∂(centers)/∂κ ≡ 0` and the design moves with κ only
38//! through the kernel. A κ-dependent center rule would add an
39//! uncontrolled, non-smooth term to the design drift.
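//! - **The reference length scale ℓ_ref is κ-fixed.** The auto-initialized
//! ℓ_ref is derived from chart-coordinate (κ = 0 gauge) center spacing only,
//! and an explicit user ℓ is a constant, so `∂ℓ_ref/∂κ ≡ 0`. The kernel
//! itself, however, is evaluated at fill-invariant effective lengths that DO
//! move with κ: `L(κ)` for the design (data→center fill) and `L_S(κ)` for the
//! penalty (center→center fill). Both are solved from ℓ_ref by
//! `constant_curvature_effective_length_jet`, which also returns their exact
//! κ-jets.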
43//! - **The constraint transform `z` is κ-fixed.** Uniform coefficient
44//! weights; at fit time the global identifiability pipeline composes the
45//! parametric orthogonalization onto it and the result is FROZEN
46//! (mirroring `SphericalSplineIdentifiability::FrozenTransform`, #532), so
47//! the predict/ψ-trial rebuild replays the same `z` verbatim.
//! - **The kernel has exact κ-jets.** `∂K/∂κ` and `∂²K/∂κ²` follow from
//! `distance_kappa_jet` (Tower4-exact, FD-gated) by the chain rule. See
//! [`constant_curvature_kernel_kappa_jets`] for a κ-fixed length, and
//! `constant_curvature_kernel_kappa_jets_scaled` for the κ-moving effective
//! length `L(κ)` that the built design and penalty actually use. Therefore:
//! `∂X_raw/∂κ = ∂K(data, centers)/∂κ`, realized design drift
//! `∂X/∂κ = (∂K/∂κ)·z`, and penalty drift `∂S_raw/∂κ = zᵀ(∂K(centers,
//! centers)/∂κ)z` are all available in closed form from this module today.
//! (The primary RKHS penalty is handed to the optimizer RAW, with
//! `normalization_scale = 1`; only the optional double-penalty ridge is
//! Frobenius-normalized. Any normalization the ψ-channel applies must be
//! differentiated consistently; `normalize_penaltywith_psi_derivatives` is
//! the existing seam.)
57//! - **Available but not yet consumed:** `log_map_kappa_jet` /
58//! `exp_map_kappa_jet` cover future geodesic/normal-coordinate basis
59//! variants (e.g. tangent-space designs); the distance jet is the only one
60//! this kernel construction needs.
61
62use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
63use rayon::prelude::*;
64use serde::{Deserialize, Serialize};
65
66use gam_geometry::constant_curvature::{ConstantCurvature, distance_kappa_jet};
67
68use super::{
69 BasisBuildResult, BasisError, BasisMetadata, BasisPsiDerivativeBundle,
70 BasisPsiDerivativeResult, BasisPsiSecondDerivativeResult, CenterStrategy, PenaltyCandidate,
71 PenaltyInfo, PenaltySource, filter_active_penalty_candidates_with_ops, normalize_penalty,
72 select_centers_by_strategy, weighted_coefficient_sum_to_zero_transform,
73};
74
75/// Realized-design identifiability policy for the constant-curvature smooth.
76/// Mirrors [`super::SphericalSplineIdentifiability`] (#532): the fit-time
77/// center-space sum-to-zero `z` gets the parametric orthogonalization composed
78/// onto it by the global identifiability pipeline, and the composed transform
79/// is frozen here so predict-time (and future per-ψ-trial) rebuilds replay it
80/// verbatim instead of recomputing `z` from the centers.
81#[derive(Debug, Clone, Serialize, Deserialize, Default)]
82pub enum ConstantCurvatureIdentifiability {
83 /// Fit-time default: uniform-weight coefficient sum-to-zero over the
84 /// centers (`Σ_j α_j = 0`), then global parametric residualization.
85 #[default]
86 CenterSumToZero,
87 /// Predict-time replay: the frozen composed transform captured at fit
88 /// time. `transform.nrows()` equals the number of centers.
89 FrozenTransform { transform: Array2<f64> },
90}
91
92/// Constant-curvature smooth configuration (`curv(x, z, kappa = …)`).
93///
94/// The chart inputs are the raw feature columns interpreted as
95/// κ-stereographic chart coordinates: any finite point for κ ≥ 0, the open
96/// ball `‖x‖ < 1/√(−κ)` for κ < 0. The default κ = 0 reproduces a Euclidean
97/// exponential-kernel smooth (in the doubled κ = 0 chart gauge), so the term
98/// is safe to use as a drop-in flat smooth until κ becomes a fitted
99/// ψ-coordinate.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ConstantCurvatureBasisSpec {
102 /// Center/knot selection strategy in chart coordinates. Deliberately
103 /// κ-independent (see the module-level κ-contract).
104 pub center_strategy: CenterStrategy,
105 /// Sectional curvature κ of the latent/feature geometry. Fixed at build
106 /// time; the later ψ-channel stage promotes it to a fitted outer
107 /// coordinate consuming this module's exact κ-jets.
108 pub kappa: f64,
109 /// Geodesic kernel range ℓ in `K_κ = exp(−d_κ/ℓ)`. The `0.0` sentinel
110 /// requests the κ-independent auto initialization
111 /// ([`realized_constant_curvature_length_scale`]); the realized value is
112 /// persisted in [`BasisMetadata::ConstantCurvature`] and frozen back into
113 /// the spec for predict-time replay.
114 pub length_scale: f64,
115 /// Add the ridge-like shrinkage penalty alongside the RKHS Gram penalty.
116 pub double_penalty: bool,
117 /// Realized-design identifiability policy (see type docs).
118 #[serde(default)]
119 pub identifiability: ConstantCurvatureIdentifiability,
120}
121
122impl Default for ConstantCurvatureBasisSpec {
123 fn default() -> Self {
124 Self {
125 center_strategy: CenterStrategy::FarthestPoint { num_centers: 50 },
126 kappa: 0.0,
127 length_scale: 0.0,
128 // No double-penalty ridge by default (#1464). The RKHS Gram penalty
129 // zᵀKz is strictly PD/full-rank on distinct centers, so it already
130 // regularizes every coefficient direction — the ridge `I` adds no
131 // stability. Worse, `I` is curvature-BLIND: with its own λ it absorbs
132 // the data fit independently of κ, so the κ outer coordinate sees only
133 // the monotone Occam term (positive κ compresses geodesic distances →
134 // kernel log-det shrinks) and rails to the +chart bound for any curved
135 // data, recovering hyperbolic truth as spherical. Dropping the ridge
136 // matches the single-penalty profiled-REML oracle
137 // (`profiled_reml_identifies_curvature_sign_with_effective_length`),
138 // which identifies the curvature SIGN.
139 double_penalty: false,
140 identifiability: ConstantCurvatureIdentifiability::CenterSumToZero,
141 }
142 }
143}
144
145/// Validate that every row of `points` is finite and inside the
146/// κ-stereographic chart (`1 + κ‖x‖² > 0`; automatic for κ ≥ 0, the open-ball
147/// constraint for κ < 0).
148pub(crate) fn validate_chart_points(
149 points: ArrayView2<'_, f64>,
150 kappa: f64,
151 what: &str,
152) -> Result<(), BasisError> {
153 for (i, row) in points.outer_iter().enumerate() {
154 let mut nx2 = 0.0_f64;
155 for &v in row.iter() {
156 if !v.is_finite() {
157 crate::bail_invalid_basis!(
158 "constant-curvature {what} row {i} has a non-finite coordinate"
159 );
160 }
161 nx2 += v * v;
162 }
163 if 1.0 + kappa * nx2 <= 0.0 {
164 crate::bail_invalid_basis!(
165 "constant-curvature {what} row {i} lies outside the κ-stereographic chart \
166 (need 1 + κ·‖x‖² > 0; got κ = {kappa}, ‖x‖² = {nx2}); for κ < 0 the chart is \
167 the open ball ‖x‖ < 1/√(−κ)"
168 );
169 }
170 }
171 Ok(())
172}
173
174/// `K_κ(data, centers)` — the geodesic-exponential kernel matrix
175/// `exp(−d_κ(x_i, c_j)/ℓ)`.
176pub fn constant_curvature_kernel_matrix(
177 data: ArrayView2<'_, f64>,
178 centers: ArrayView2<'_, f64>,
179 kappa: f64,
180 length_scale: f64,
181) -> Result<Array2<f64>, BasisError> {
182 if data.ncols() != centers.ncols() {
183 crate::bail_dim_basis!(
184 "constant-curvature kernel dimension mismatch: data d={} centers d={}",
185 data.ncols(),
186 centers.ncols()
187 );
188 }
189 if !(length_scale.is_finite() && length_scale > 0.0) {
190 crate::bail_invalid_basis!(
191 "constant-curvature kernel needs a positive finite length_scale; got {length_scale}"
192 );
193 }
194 validate_chart_points(data, kappa, "data")?;
195 validate_chart_points(centers, kappa, "centers")?;
196 let manifold = ConstantCurvature::new(data.ncols(), kappa);
197 let mut out = Array2::<f64>::zeros((data.nrows(), centers.nrows()));
198 out.axis_iter_mut(Axis(0))
199 .into_par_iter()
200 .enumerate()
201 .try_for_each(|(i, mut row)| -> Result<(), BasisError> {
202 for (j, c) in centers.outer_iter().enumerate() {
203 let d = manifold.distance(data.row(i), c).map_err(|e| {
204 BasisError::InvalidInput(format!(
205 "constant-curvature distance failed at (row {i}, center {j}): {e}"
206 ))
207 })?;
208 row[j] = (-d / length_scale).exp();
209 }
210 Ok(())
211 })?;
212 Ok(out)
213}
214
215/// `(K, ∂K/∂κ, ∂²K/∂κ²)` of the raw (pre-constraint) kernel matrix — the
216/// ψ-channel hook. Exact: rides `distance_kappa_jet` (Tower4, FD-gated in
217/// `geometry::constant_curvature`) through the chain rule for
218/// `K = exp(−d/ℓ)` at κ-FIXED ℓ and centers (see the module κ-contract):
219///
220/// ```text
221/// ∂K/∂κ = −(d′/ℓ) · K
222/// ∂²K/∂κ² = ((d′/ℓ)² − d″/ℓ) · K
223/// ```
224///
225/// The realized design/penalty drifts follow by the κ-fixed transforms:
226/// `∂X/∂κ = (∂K/∂κ)·z` and `∂S_raw/∂κ = zᵀ(∂K/∂κ)z` (centers×centers), with
227/// the Frobenius penalty normalization differentiated by the existing
228/// `normalize_penaltywith_psi_derivatives` seam.
229pub fn constant_curvature_kernel_kappa_jets(
230 data: ArrayView2<'_, f64>,
231 centers: ArrayView2<'_, f64>,
232 kappa: f64,
233 length_scale: f64,
234) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
235 if data.ncols() != centers.ncols() {
236 crate::bail_dim_basis!(
237 "constant-curvature kernel-jet dimension mismatch: data d={} centers d={}",
238 data.ncols(),
239 centers.ncols()
240 );
241 }
242 if !(length_scale.is_finite() && length_scale > 0.0) {
243 crate::bail_invalid_basis!(
244 "constant-curvature kernel jets need a positive finite length_scale; got {length_scale}"
245 );
246 }
247 validate_chart_points(data, kappa, "data")?;
248 validate_chart_points(centers, kappa, "centers")?;
249 let manifold = ConstantCurvature::new(data.ncols(), kappa);
250 let n = data.nrows();
251 let m = centers.nrows();
252 let mut value = Array2::<f64>::zeros((n, m));
253 let mut dk = Array2::<f64>::zeros((n, m));
254 let mut dkk = Array2::<f64>::zeros((n, m));
255 let rows: Vec<(usize, Vec<(f64, f64, f64)>)> = (0..n)
256 .into_par_iter()
257 .map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
258 let mut row = Vec::with_capacity(m);
259 for (j, c) in centers.outer_iter().enumerate() {
260 let (d, d1, d2) = distance_kappa_jet(&manifold, data.row(i), c).map_err(|e| {
261 BasisError::InvalidInput(format!(
262 "constant-curvature distance κ-jet failed at (row {i}, center {j}): {e}"
263 ))
264 })?;
265 let k = (-d / length_scale).exp();
266 let g = d1 / length_scale;
267 row.push((k, -g * k, (g * g - d2 / length_scale) * k));
268 }
269 Ok((i, row))
270 })
271 .collect::<Result<Vec<_>, BasisError>>()?;
272 for (i, row) in rows {
273 for (j, (k, k1, k2)) in row.into_iter().enumerate() {
274 value[(i, j)] = k;
275 dk[(i, j)] = k1;
276 dkk[(i, j)] = k2;
277 }
278 }
279 Ok((value, dk, dkk))
280}
281
282/// `(K, ∂K/∂κ, ∂²K/∂κ²)` of the raw kernel matrix when the kernel uses the
283/// fill-invariant effective length `L(κ)` (the #944 fix: `L` solves the fill
284/// target `g(L,κ)=fill⋆`, holding the kernel's effective DoF κ-invariant). Both
285/// the geodesic distance `d_κ` and the length `L(κ)` move with κ, so the exponent
286/// is the quotient `q = d/L` and the chain rule carries both jets:
287///
288/// ```text
289/// q = d / L
290/// q′ = d′/L − d·L′/L²
291/// q″ = d″/L − 2 d′ L′/L² − d L″/L² + 2 d (L′)²/L³
292/// K = e^{−q}, K′ = −q′K, K″ = ((q′)² − q″) K
293/// ```
294///
295/// `l_jet = (L, L′, L″)` is the effective-length κ-jet from
296/// [`constant_curvature_effective_length_jet`]; at κ = 0 it reduces to the
297/// fixed-ℓ jets (`L′ = L″` terms vanish only if the geometry is flat, but the
298/// formula is exact for all κ).
299pub(crate) fn constant_curvature_kernel_kappa_jets_scaled(
300 data: ArrayView2<'_, f64>,
301 centers: ArrayView2<'_, f64>,
302 kappa: f64,
303 l_jet: (f64, f64, f64),
304) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
305 if data.ncols() != centers.ncols() {
306 crate::bail_dim_basis!(
307 "constant-curvature scaled kernel-jet dimension mismatch: data d={} centers d={}",
308 data.ncols(),
309 centers.ncols()
310 );
311 }
312 let (l, l1, l2) = l_jet;
313 if !(l.is_finite() && l > 0.0) {
314 crate::bail_invalid_basis!(
315 "constant-curvature scaled kernel jets need a positive finite effective length; got {l}"
316 );
317 }
318 validate_chart_points(data, kappa, "data")?;
319 validate_chart_points(centers, kappa, "centers")?;
320 let manifold = ConstantCurvature::new(data.ncols(), kappa);
321 let n = data.nrows();
322 let m = centers.nrows();
323 let mut value = Array2::<f64>::zeros((n, m));
324 let mut dk = Array2::<f64>::zeros((n, m));
325 let mut dkk = Array2::<f64>::zeros((n, m));
326 let rows: Vec<(usize, Vec<(f64, f64, f64)>)> = (0..n)
327 .into_par_iter()
328 .map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
329 let mut row = Vec::with_capacity(m);
330 for (j, c) in centers.outer_iter().enumerate() {
331 let (d, d1, d2) = distance_kappa_jet(&manifold, data.row(i), c).map_err(|e| {
332 BasisError::InvalidInput(format!(
333 "constant-curvature scaled distance κ-jet failed at (row {i}, center {j}): {e}"
334 ))
335 })?;
336 let q = d / l;
337 let q1 = d1 / l - d * l1 / (l * l);
338 let q2 = d2 / l - 2.0 * d1 * l1 / (l * l) - d * l2 / (l * l)
339 + 2.0 * d * l1 * l1 / (l * l * l);
340 let k = (-q).exp();
341 row.push((k, -q1 * k, (q1 * q1 - q2) * k));
342 }
343 Ok((i, row))
344 })
345 .collect::<Result<Vec<_>, BasisError>>()?;
346 for (i, row) in rows {
347 for (j, (k, k1, k2)) in row.into_iter().enumerate() {
348 value[(i, j)] = k;
349 dk[(i, j)] = k1;
350 dkk[(i, j)] = k2;
351 }
352 }
353 Ok((value, dk, dkk))
354}
355
356/// Resolve the realized kernel range ℓ. An explicit positive `spec_length_scale`
357/// is used verbatim; the `0.0` sentinel auto-initializes from the median
358/// pairwise CHART distance among the centers, doubled to match the κ = 0
359/// chart gauge (`d_0 = 2‖Δ‖`).
360///
361/// κ-contract: the auto rule reads chart coordinates only — it never consults
362/// κ — so the realized ℓ is a κ-CONSTANT and contributes no `∂ℓ/∂κ` term to
363/// the design drift.
364pub fn realized_constant_curvature_length_scale(
365 centers: ArrayView2<'_, f64>,
366 spec_length_scale: f64,
367) -> Result<f64, BasisError> {
368 if spec_length_scale.is_finite() && spec_length_scale > 0.0 {
369 return Ok(spec_length_scale);
370 }
371 if spec_length_scale != 0.0 {
372 crate::bail_invalid_basis!(
373 "constant-curvature length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
374 );
375 }
376 let m = centers.nrows();
377 if m < 2 {
378 return Err(BasisError::InsufficientColumnsForConstraint { found: m });
379 }
380 let mut dists: Vec<f64> = Vec::with_capacity(m * (m - 1) / 2);
381 for i in 0..m {
382 for j in (i + 1)..m {
383 let mut s = 0.0_f64;
384 for k in 0..centers.ncols() {
385 let dlt = centers[(i, k)] - centers[(j, k)];
386 s += dlt * dlt;
387 }
388 dists.push(2.0 * s.sqrt());
389 }
390 }
391 dists.sort_by(|a, b| a.partial_cmp(b).expect("finite chart distances"));
392 let median = dists[dists.len() / 2];
393 if !(median.is_finite() && median > 0.0) {
394 crate::bail_invalid_basis!(
395 "constant-curvature auto length_scale failed: centers are degenerate \
396 (median pairwise chart distance = {median})"
397 );
398 }
399 Ok(median)
400}
401
402/// Reference kernel "fill" `fill⋆` — the κ = 0 mean data→center kernel entry
403/// `(1/N) Σᵢⱼ exp(−d₀(xᵢ,cⱼ)/ℓ_ref)` with `d₀ = 2‖Δ‖` the κ = 0 chart gauge.
404///
405/// The fill is the scalar that measures the kernel's *effective resolution* (how
406/// much each data row "sees" the centers): it is monotone in `ℓ/scale`, so
407/// pinning it across κ pins the realized design's flexibility (its effective
408/// degrees of freedom). [`constant_curvature_effective_length_jet`] solves
409/// `g(L,κ) = fill⋆` for `L(κ)` so the fill — hence the basis flexibility — stays
410/// κ-invariant and only the distance-matrix SHAPE (the genuine curvature signal)
411/// moves with κ. At κ = 0 the solution is `L = ℓ_ref` by construction.
412pub(crate) fn data_center_reference_fill(
413 data: ArrayView2<'_, f64>,
414 centers: ArrayView2<'_, f64>,
415 ell_ref: f64,
416) -> Result<f64, BasisError> {
417 if !(ell_ref.is_finite() && ell_ref > 0.0) {
418 crate::bail_invalid_basis!(
419 "constant-curvature reference fill needs a positive finite ℓ_ref; got {ell_ref}"
420 );
421 }
422 let mut sum = 0.0_f64;
423 let mut cnt = 0.0_f64;
424 for xi in data.outer_iter() {
425 for cj in centers.outer_iter() {
426 let mut s = 0.0_f64;
427 for k in 0..centers.ncols() {
428 let dlt = xi[k] - cj[k];
429 s += dlt * dlt;
430 }
431 let d0 = 2.0 * s.sqrt(); // κ = 0 chart gauge d₀ = 2‖Δ‖
432 sum += (-d0 / ell_ref).exp();
433 cnt += 1.0;
434 }
435 }
436 if cnt <= 0.0 {
437 crate::bail_invalid_basis!(
438 "constant-curvature reference fill needs at least one data row and one center"
439 );
440 }
441 Ok(sum / cnt)
442}
443
444/// The mean-kernel-entry "fill" `g(L,κ) = (1/N) Σᵢⱼ exp(−d_κ(xᵢ,cⱼ)/L)` together
445/// with the five partials needed by the implicit-function jet:
446/// `(g, g_L, g_κ, g_LL, g_κκ, g_Lκ)`.
447///
448/// With `k = exp(−d/L)` and the per-pair geodesic jet `(d, d', d'')` (exact via
449/// [`distance_kappa_jet`]):
450///
451/// ```text
452/// ∂k/∂L = k·d/L², ∂k/∂κ = −k·d'/L
453/// g_LL = (1/N)Σ k·d·(d − 2L)/L⁴
454/// g_κκ = (1/N)Σ k·((d')²/L − d'')/L
455/// g_Lκ = (1/N)Σ k·d'·(L − d)/L³
456/// ```
457///
458/// (each obtained by differentiating `∂k/∂L` / `∂k/∂κ` once more). `g` and every
459/// partial are smooth through κ = 0 because the distance jet is entire there.
460pub(crate) fn data_center_fill_partials(
461 data: ArrayView2<'_, f64>,
462 centers: ArrayView2<'_, f64>,
463 kappa: f64,
464 l: f64,
465) -> Result<(f64, f64, f64, f64, f64, f64), BasisError> {
466 if !(l.is_finite() && l > 0.0) {
467 crate::bail_invalid_basis!(
468 "constant-curvature fill partials need a positive finite length; got {l}"
469 );
470 }
471 let manifold = ConstantCurvature::new(centers.ncols(), kappa);
472 let l2 = l * l;
473 let l3 = l2 * l;
474 let l4 = l2 * l2;
475 let mut g = 0.0_f64;
476 let mut g_l = 0.0_f64;
477 let mut g_k = 0.0_f64;
478 let mut g_ll = 0.0_f64;
479 let mut g_kk = 0.0_f64;
480 let mut g_lk = 0.0_f64;
481 let mut cnt = 0.0_f64;
482 for xi in data.outer_iter() {
483 for cj in centers.outer_iter() {
484 let (d, d1, d2) = distance_kappa_jet(&manifold, xi, cj).map_err(|e| {
485 BasisError::InvalidInput(format!(
486 "constant-curvature data→center fill κ-jet failed: {e}"
487 ))
488 })?;
489 let k = (-d / l).exp();
490 g += k;
491 g_l += k * d / l2;
492 g_k += -k * d1 / l;
493 g_ll += k * d * (d - 2.0 * l) / l4;
494 g_kk += k * ((d1 * d1) / l - d2) / l;
495 g_lk += k * d1 * (l - d) / l3;
496 cnt += 1.0;
497 }
498 }
499 if cnt <= 0.0 {
500 crate::bail_invalid_basis!(
501 "constant-curvature fill partials need at least one data row and one center"
502 );
503 }
504 Ok((
505 g / cnt,
506 g_l / cnt,
507 g_k / cnt,
508 g_ll / cnt,
509 g_kk / cnt,
510 g_lk / cnt,
511 ))
512}
513
514/// Effective kernel length `L(κ)` and its EXACT κ-jet `(L, L′, L″)`.
515///
516/// THE κ-IDENTIFICATION FIX (#944). A κ-FROZEN length makes the geodesic-
517/// exponential kernel's *resolution* drift with κ: spherical (κ>0) geometries
518/// compress geodesic distances, narrowing the kernel relative to the data and
519/// inflating the basis's effective flexibility, so REML buys a lower deviance by
520/// cranking κ up — κ rails to the chart bound for every truth (the #944/#1059
521/// symptom). The earlier #1059 fix normalized by the mean data→center geodesic
522/// distance `s_dc(κ)`; but holding the mean DISTANCE fixed does NOT hold the
523/// kernel's flexibility fixed — the effective degrees of freedom still drift
524/// ~30% across the bracket (verified), so the deviance stayed monotone in κ.
525///
526/// We instead hold the kernel's "fill" — the mean realized kernel entry
527/// `g(L,κ) = (1/N) Σᵢⱼ exp(−d_κ(xᵢ,cⱼ)/L)` — κ-INVARIANT, which pins the
528/// realized design's effective degrees of freedom (the EDF is flat to <0.5% in κ
529/// under this rule, verified numerically). `L(κ)` is the implicit solution of
530///
531/// ```text
532/// g(L(κ), κ) = fill⋆, fill⋆ = g(ℓ_ref, 0) (the κ=0 reference fill)
533/// ```
534///
535/// so changing κ moves ONLY the distance-matrix SHAPE (the genuine curvature
536/// signal), giving `V_p(κ)` an interior minimum at the data-generating κ for
537/// curved truth. At κ = 0 the solution is `L = ℓ_ref` exactly.
538///
539/// The jet is EXACT via the implicit-function theorem. Differentiating
540/// `g(L(κ),κ) ≡ fill⋆` once gives `g_L·L′ + g_κ = 0`, and once more gives
541/// `g_LL·(L′)² + 2 g_Lκ·L′ + g_κκ + g_L·L″ = 0`:
542///
543/// ```text
544/// L′ = −g_κ / g_L
545/// L″ = −( g_LL·(L′)² + 2 g_Lκ·L′ + g_κκ ) / g_L .
546/// ```
547///
548/// The partials come from [`data_center_fill_partials`] (exact, riding
549/// `distance_kappa_jet`); the returned jet feeds `constant_curvature_kernel_
550/// kappa_jets_scaled` through the quotient `q = d/L` chain rule.
551///
552/// Public scalar view of the κ-invariant effective kernel length `L(κ)` that the
553/// realized constant-curvature design/penalty are built at (the #944 fill-
554/// invariance fix). The forward build evaluates the geodesic-exponential kernel
555/// at this `L(κ)`, NOT at the κ = 0 reference length `ell_ref`, so any external
556/// consumer reconstructing `K(·)` to compare against the realized design must
557/// use this length. Equals `ell_ref` exactly at κ = 0.
558pub fn constant_curvature_effective_length(
559 data: ArrayView2<'_, f64>,
560 centers: ArrayView2<'_, f64>,
561 ell_ref: f64,
562 kappa: f64,
563) -> Result<f64, BasisError> {
564 Ok(constant_curvature_effective_length_jet(data, centers, ell_ref, kappa)?.0)
565}
566
567pub(crate) fn constant_curvature_effective_length_jet(
568 data: ArrayView2<'_, f64>,
569 centers: ArrayView2<'_, f64>,
570 ell_ref: f64,
571 kappa: f64,
572) -> Result<(f64, f64, f64), BasisError> {
573 let fill_star = data_center_reference_fill(data, centers, ell_ref)?;
574 // Newton solve g(L, κ) = fill⋆ for L, warm-started at ℓ_ref (the exact root
575 // at κ = 0). g is strictly increasing in L (g_L > 0: larger L ⇒ each entry
576 // closer to 1), so Newton from ℓ_ref converges monotonically.
577 let mut l = ell_ref;
578 const NEWTON_MAX_ITER: usize = 100;
579 const NEWTON_REL_TOL: f64 = 1.0e-13;
580 let mut converged = false;
581 for _ in 0..NEWTON_MAX_ITER {
582 let (g, g_l, ..) = data_center_fill_partials(data, centers, kappa, l)?;
583 if !(g_l.is_finite() && g_l > 0.0) {
584 crate::bail_invalid_basis!(
585 "constant-curvature effective length: non-positive fill slope g_L = {g_l} \
586 (degenerate data/centers at κ = {kappa})"
587 );
588 }
589 let step = (g - fill_star) / g_l;
590 l -= step;
591 if !(l.is_finite() && l > 0.0) {
592 crate::bail_invalid_basis!(
593 "constant-curvature effective length: Newton left the positive axis (L = {l}) \
594 solving the fill target at κ = {kappa}"
595 );
596 }
597 if step.abs() <= NEWTON_REL_TOL * l {
598 converged = true;
599 break;
600 }
601 }
602 if !converged {
603 crate::bail_invalid_basis!(
604 "constant-curvature effective length: fill-target Newton did not converge at κ = {kappa}"
605 );
606 }
607 // Exact implicit-function-theorem jet at the converged root.
608 let (_, g_l, g_k, g_ll, g_kk, g_lk) = data_center_fill_partials(data, centers, kappa, l)?;
609 let l1 = -g_k / g_l;
610 let l2 = -(g_ll * l1 * l1 + 2.0 * g_lk * l1 + g_kk) / g_l;
611 Ok((l, l1, l2))
612}
613
614/// Build the constant-curvature reproducing-kernel smooth: realized design
615/// `K_κ(data, centers)·z`, RKHS penalty `zᵀ K_κ(centers, centers) z`, and the
616/// replayable [`BasisMetadata::ConstantCurvature`]. Structure mirrors the
617/// Wahba S² builder (`build_spherical_spline_basis`); geometry comes from
618/// `ConstantCurvature` at the spec's fixed κ.
619pub fn build_constant_curvature_basis(
620 data: ArrayView2<'_, f64>,
621 spec: &ConstantCurvatureBasisSpec,
622) -> Result<BasisBuildResult, BasisError> {
623 if data.ncols() == 0 {
624 crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
625 }
626 if !spec.kappa.is_finite() {
627 crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
628 }
629 validate_chart_points(data, spec.kappa, "data")?;
630 let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
631 if centers.nrows() < 2 {
632 return Err(BasisError::InsufficientColumnsForConstraint {
633 found: centers.nrows(),
634 });
635 }
636 validate_chart_points(centers.view(), spec.kappa, "centers")?;
637 // ℓ_ref is the κ = 0 reference length (auto = mean chart spacing, or the
638 // user/frozen value); the kernel uses the κ-invariant effective length
639 // L(κ) = ℓ_ref·s(κ)/s₀ so changing κ moves the geometry, not the kernel
640 // resolution (the #1059 curvature-identification fix). At κ = 0, L = ℓ_ref.
641 let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
642 // DESIGN effective length L(κ): solved against the DATA→center fill so the
643 // realized design's effective DOF stays κ-invariant (#944/#1059). The design
644 // X = K(data, centers)·z is built at this L.
645 let (ell_eff, _, _) =
646 constant_curvature_effective_length_jet(data, centers.view(), length_scale, spec.kappa)?;
647 // PENALTY effective length L_S(κ): solved against the CENTER→center fill so
648 // the penalty Gram S = zᵀK(centers,centers)z has a κ-INVARIANT resolution
649 // (#1464). The data→center fill that pins L(κ) does NOT pin the center→center
650 // penalty spectrum, so with the single shared L the penalty pseudo-determinant
651 // logdet|S|₊ drifts freely with κ: as κ grows positive the geodesic kernel
652 // collapses toward the constant, the center→center Gram eigenvalues bunch /
653 // drop below the rank tolerance, logdet|S|₊ falls, and the REML Occam term
654 // −½·logdet|S|₊ DECREASES — rewarding the +κ collapsed-kernel corner and
655 // railing κ̂ to the +chart bound for any curved data (the headline #1464
656 // sign-blindness: hyperbolic truth recovered as spherical, V_p(κ) monotone in
657 // κ with no interior optimum). Building the penalty at L_S(κ) holds the
658 // penalty eigenvalue SHAPE (hence logdet|S|₊ and its rank) κ-comparable, so
659 // the Occam term stops rewarding the collapse and V_p regains an interior
660 // minimum near the data-generating κ. At κ = 0, L_S = ℓ_ref = L, so the κ = 0
661 // build is byte-identical.
662 let (ell_eff_penalty, _, _) = constant_curvature_effective_length_jet(
663 centers.view(),
664 centers.view(),
665 length_scale,
666 spec.kappa,
667 )?;
668 let raw_penalty = constant_curvature_kernel_matrix(
669 centers.view(),
670 centers.view(),
671 spec.kappa,
672 ell_eff_penalty,
673 )?;
674 // Realized-design constraint transform: uniform coefficient sum-to-zero at
675 // fit time; the frozen composed `z · z_parametric` at predict time (#532
676 // pattern — see ConstantCurvatureIdentifiability).
677 let z = match &spec.identifiability {
678 ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
679 if transform.nrows() != centers.nrows() {
680 crate::bail_dim_basis!(
681 "frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
682 centers.nrows(),
683 transform.nrows()
684 );
685 }
686 transform.clone()
687 }
688 ConstantCurvatureIdentifiability::CenterSumToZero => {
689 let weights = Array1::<f64>::ones(centers.nrows());
690 weighted_coefficient_sum_to_zero_transform(weights.view())?
691 }
692 };
693 let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
694 let penalty = gauge.restrict_penalty(&raw_penalty);
695 let raw_design = constant_curvature_kernel_matrix(data, centers.view(), spec.kappa, ell_eff)?;
696 let design = gam_linalg::matrix::DesignMatrix::Dense(
697 gam_linalg::matrix::DenseDesignMatrix::from(gauge.restrict_design(&raw_design)),
698 );
699 // Keep the RKHS penalty RAW (the symmetric kernel Gram zᵀKz) with
700 // normalization_scale = 1, rather than Frobenius-normalizing it. The Gram's
701 // eigenvalues ARE the physical RKHS roughness energies of each coefficient
702 // direction: the smoothest functions (the low-degree / degree-1 signal) sit
703 // in the genuinely tiny-eigenvalue directions, while wiggly functions sit in
704 // the large ones — a spread of many orders of magnitude. Frobenius-
705 // normalizing divides the whole operator by ‖·‖_F (dominated by the large
706 // wiggly eigenvalues), which compresses that spread and inflates the
707 // smallest eigenvalues relative to their natural scale. REML's scale-
708 // sensitive λ heuristics then drive a single λ high enough to suppress the
709 // wiggly directions and, because the smooth directions are no longer
710 // proportionally tiny, over-shrink the recoverable low-degree signal
711 // (planted degree-1 sphere harmonic recovered at only R²≈0.84). Keeping the
712 // raw physical operator (scale = 1, matching the sphere-harmonic Laplace-
713 // Beltrami penalty) lets REML act on true roughness, leaving the smooth
714 // signal essentially unpenalized while still shrinking the wiggly tail —
715 // raising recovery toward the unconstrained RKHS ceiling. The penalty stays
716 // exactly proportional to zᵀKz, so the constrained-kernel-Gram contract is
717 // unchanged.
718 let penalty_sym = (&penalty + &penalty.t()) * 0.5;
719 let mut candidates = vec![PenaltyCandidate {
720 matrix: penalty_sym,
721 nullspace_dim_hint: 0,
722 source: PenaltySource::Primary,
723 normalization_scale: 1.0,
724 kronecker_factors: None,
725 op: None,
726 }];
727 if spec.double_penalty {
728 // #1531: identity ridge is CORRECT here, NOT the nullspace-shrinkage ridge
729 // the sibling bases (sphere_basis / matern_kernel / duchon_thinplate) use.
730 // The Marra & Wood double penalty shrinks the NULL SPACE of the primary
731 // penalty so REML can drive an unsupported term to EDF→0. But the primary
732 // here is the RKHS kernel Gram zᵀKz, which is strictly PD / full-rank on
733 // distinct centers (see the `double_penalty: false` default note above): it
734 // has NO null space. `build_nullspace_shrinkage_penalty(&primary)` returns
735 // `Ok(None)` for a full-rank input, so matching the sibling pattern would
736 // make an explicit `double_penalty = true` a silent no-op. The full identity
737 // is the only second shrinkage coordinate that is actually selectable on a
738 // null-space-free primary, so it is what an explicit double penalty must use.
739 // The regression test `constant_curvature_gram_is_full_rank_so_identity_is_the_only_double_penalty`
740 // locks the full-rank fact that justifies this; if a future basis change
741 // gives the Gram a genuine null space, that test fails and this branch must
742 // be revisited (switch to `build_nullspace_shrinkage_penalty`).
743 let ridge = Array2::<f64>::eye(design.ncols());
744 let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
745 candidates.push(PenaltyCandidate {
746 matrix: ridge_norm,
747 nullspace_dim_hint: 0,
748 source: PenaltySource::DoublePenaltyNullspace,
749 normalization_scale: c_ridge,
750 kronecker_factors: None,
751 op: None,
752 });
753 }
754 let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
755 filter_active_penalty_candidates_with_ops(candidates)?;
756 Ok(BasisBuildResult {
757 design,
758 penalties,
759 nullspace_dims,
760 penaltyinfo,
761 metadata: BasisMetadata::ConstantCurvature {
762 centers,
763 kappa: spec.kappa,
764 length_scale,
765 constraint_transform: Some(z),
766 },
767 kronecker_factored: None,
768 ops,
769 null_eigenvectors,
770 joint_null_rotation: None,
771 })
772}
773
774/// Closed-form profiled Gaussian-REML negative-log-evidence of a dense design
775/// `b` (n×p) against response `y`, with an UNPENALIZED intercept column appended
776/// and the symmetric psd RKHS penalty `s` (p×p) profiled over a dense log-λ grid.
777/// `min_λ D(λ)` with
778/// `D(λ) = (n−Mp)·log(rss/(n−Mp)) + log|HᵀH| − log|λS|₊`,
779/// `H = [1|b]ᵀ[1|b] + λ·diag(0,S)`, `Mp = 1 + nullity(S)` (the intercept is in the
780/// null space). Self-contained — the same criterion shape the in-crate oracle
781/// `profiled_gaussian_reml_deviance` certifies, with the production intercept the
782/// full GAM always carries (so it matches what the fit path sees).
783fn profiled_reml_with_intercept(b: &Array2<f64>, y: &Array1<f64>, s: &Array2<f64>) -> f64 {
784 use gam_linalg::faer_ndarray::FaerEigh;
785 let n = b.nrows();
786 let p = b.ncols();
787 // Augmented design [1 | b] and zero-padded penalty diag(0, S).
788 let mut ba = Array2::<f64>::zeros((n, p + 1));
789 for i in 0..n {
790 ba[(i, 0)] = 1.0;
791 for j in 0..p {
792 ba[(i, j + 1)] = b[(i, j)];
793 }
794 }
795 let mut sa = Array2::<f64>::zeros((p + 1, p + 1));
796 for i in 0..p {
797 for j in 0..p {
798 sa[(i + 1, j + 1)] = s[(i, j)];
799 }
800 }
801 let pa = p + 1;
802 let btb = symmetrize(&ba.t().dot(&ba));
803 let bty = ba.t().dot(y);
804 let (s_evals, _) = FaerEigh::eigh(&symmetrize(&sa), faer::Side::Lower)
805 .expect("κ-fair penalty eigendecomposition");
806 let s_max = s_evals.iter().cloned().fold(0.0_f64, f64::max).max(1e-300);
807 let s_tol = s_max * 1e-9;
808 let r = s_evals.iter().filter(|&&e| e > s_tol).count();
809 let m_p = pa - r;
810 let dof = (n - m_p) as f64;
811 let log_det_s_plus: f64 = s_evals
812 .iter()
813 .filter(|&&e| e > s_tol)
814 .map(|&e| e.ln())
815 .sum();
816 let mut best = f64::INFINITY;
817 for k in -24i32..=24 {
818 let lam = (0.5 * f64::from(k)).exp();
819 let h = symmetrize(&(&btb + &(sa.mapv(|v| v * lam))));
820 let h_ridge = &h + &(Array2::<f64>::eye(pa) * (1e-10 * s_max.max(1.0)));
821 let (hv, hq) = FaerEigh::eigh(&symmetrize(&h_ridge), faer::Side::Lower)
822 .expect("κ-fair penalized-Hessian eigendecomposition");
823 let qty = hq.t().dot(&bty);
824 let mut beta = Array1::<f64>::zeros(pa);
825 let mut log_det_h = 0.0_f64;
826 for i in 0..pa {
827 let ev = hv[i].max(1e-300);
828 log_det_h += ev.ln();
829 let coef = qty[i] / ev;
830 for j in 0..pa {
831 beta[j] += hq[(j, i)] * coef;
832 }
833 }
834 let resid = y - &ba.dot(&beta);
835 let rss = resid.dot(&resid).max(1e-300);
836 let log_det_lam_s = (r as f64) * lam.ln() + log_det_s_plus;
837 let dev = dof * (rss / dof).ln() + log_det_h - log_det_lam_s;
838 if dev < best {
839 best = dev;
840 }
841 }
842 best
843}
844
845/// #1464: the **κ-fair** sign-resolving score for a constant-curvature smooth at
846/// a fixed κ — the production datum the sign-basin scan minimizes to choose the
847/// curvature SIGN basin.
848///
849/// THE DATA-FIT κ-FAIRNESS FIX. The L(κ)/L_S(κ) effective-length reparam already
850/// holds the kernel FILL and the penalty Occam term κ-invariant (#944/#1464
851/// penalty fix), but the realized profiled-REML DATA-FIT term is still sign-blind:
852/// on a generic center-peaked radial signal the +κ chart's geodesic-distance
853/// COMPRESSION concentrates the design's singular-value mass into the leading
854/// (low-order radial) modes — a uniformly better interpolator of ANY radial peak,
855/// regardless of the true curvature sign — so `V_p(κ)` decreases monotonically
856/// toward the +chart bound for BOTH spherical and hyperbolic truth (hyperbolic
857/// recovered as spherical, κ̂ railed to +0.5/max‖x‖²). Holding the EDF / hat-trace
858/// or ‖X‖_F κ-invariant does NOT cure it: the advantage is the per-direction
859/// REDISTRIBUTION of approximation power, not its total scale (verified — the EDF
860/// is already κ-invariant to <1% under L(κ), yet RSS still falls toward +κ).
861///
862/// The cure makes the comparison apples-to-apples by SUBTRACTING the design's
863/// GENERIC radial-peak-fitting power at this κ. We measure that generic power with
864/// a bank of κ-INDEPENDENT reference signals `r_α(i) = exp(−α·‖x_i‖)` — radial in
865/// the Euclidean chart coordinate, so carrying NO curvature-sign preference — and
866/// score
867///
868/// ```text
869/// V_fair(κ) = V_p(κ; y) − mean_α V_p(κ; r_α) .
870/// ```
871///
872/// The generic +κ interpolation advantage cancels between the two terms (it lifts
873/// `V_p(κ; y)` and `V_p(κ; r_α)` by the same amount), leaving only the GENUINE
874/// curvature-shape alignment of the actual data `y` with the κ-geometry. The bank
875/// (several α widths, averaged) removes the residual sensitivity of any single
876/// reference width to the data realization, so `argmin_κ V_fair` lands on the
877/// correct SIDE of 0 for both signs (spherical κ̂ > 0, hyperbolic κ̂ < 0) across
878/// seeds. The reference correction enters ONLY the sign-basin SELECTION; the
879/// realized fit and the magnitude/CI keep using the raw `V_p`, so the κ = 0 build
880/// and the final coefficients are untouched.
881///
882/// Builds the design `X = K_κ(data, centers)·z` at the data→center effective
883/// length `L(κ)` and the penalty `S = symm(zᵀK_κ(centers,centers)z)` at the
884/// center→center effective length `L_S(κ)`, exactly as
885/// [`build_constant_curvature_basis`] (raw RKHS Gram, scale = 1, intercept
886/// appended unpenalized), so the criterion the scan minimizes is the production
887/// design's own profiled REML.
888/// Build the realized constant-curvature profile design `B = K_κ(data,
889/// centers)·z` and penalty `S = symm(zᵀK_κ(centers,centers)z)` at the fixed κ in
890/// `spec`, EXACTLY as [`build_constant_curvature_basis`] does (same centers, same
891/// κ-invariant effective lengths `L(κ)`/`L_S(κ)`, same center-sum-to-zero `z`,
892/// raw RKHS Gram penalty). Shared by the honest profiled-REML κ-profile score and
893/// the κ-fair sign score so both probe the production design's own criterion.
894fn constant_curvature_profile_design_penalty(
895 data: ArrayView2<'_, f64>,
896 spec: &ConstantCurvatureBasisSpec,
897) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
898 if data.ncols() == 0 {
899 crate::bail_invalid_basis!(
900 "constant-curvature profile score needs at least one feature column"
901 );
902 }
903 if !spec.kappa.is_finite() {
904 crate::bail_invalid_basis!("constant-curvature profile score needs a finite kappa");
905 }
906 validate_chart_points(data, spec.kappa, "data")?;
907 let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
908 if centers.nrows() < 2 {
909 return Err(BasisError::InsufficientColumnsForConstraint {
910 found: centers.nrows(),
911 });
912 }
913 validate_chart_points(centers.view(), spec.kappa, "centers")?;
914 let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
915 // Design effective length L(κ) (data→center fill) and penalty effective
916 // length L_S(κ) (center→center fill) — identical to the value builder.
917 let (ell_eff, _, _) =
918 constant_curvature_effective_length_jet(data, centers.view(), length_scale, spec.kappa)?;
919 let (ell_eff_penalty, _, _) = constant_curvature_effective_length_jet(
920 centers.view(),
921 centers.view(),
922 length_scale,
923 spec.kappa,
924 )?;
925 let weights = Array1::<f64>::ones(centers.nrows());
926 let z = weighted_coefficient_sum_to_zero_transform(weights.view())?;
927 let gauge = gam_problem::Gauge::from_block_transforms(&[z]);
928 let raw_design = constant_curvature_kernel_matrix(data, centers.view(), spec.kappa, ell_eff)?;
929 let b = gauge.restrict_design(&raw_design);
930 let raw_penalty = constant_curvature_kernel_matrix(
931 centers.view(),
932 centers.view(),
933 spec.kappa,
934 ell_eff_penalty,
935 )?;
936 let s = symmetrize(&gauge.restrict_penalty(&raw_penalty));
937 Ok((b, s))
938}
939
940/// #1464: the **honest** fixed-κ profiled-REML score `V_p(κ)` for a
941/// constant-curvature smooth — the textbook Gaussian profiled-REML
942/// negative-log-evidence of the realized design `B = K_κ(data,centers)·z` against
943/// `y`, with the unpenalized intercept appended and the raw RKHS Gram penalty `S`
944/// profiled over λ (`profiled_reml_with_intercept`). This is the criterion whose
945/// argmin over the chart-bounded κ window IDENTIFIES the curvature, and the one
946/// `curvature_inference_forspec` walks for the magnitude CI and the κ = 0 flatness
947/// LR test.
948///
949/// Why this, not the production full-fit `reml_score`: the production REML's
950/// λ-selection heavily SMOOTHS this RKHS kernel (deviance ≫ near-interpolation
951/// RSS), and under heavy smoothing the +κ chart's geodesic-distance COMPRESSION
952/// makes the collapsed kernel fit the over-smoothed target better for ANY data —
953/// so the production `reml_score` is monotone toward the +chart bound regardless
954/// of the true sign (the headline #1464 sign-blindness, and an over-smoothing of
955/// the curvature criterion specifically). The honest profiled REML keeps the
956/// curvature-shape signal in the data fit (the κ that matches the geodesic
957/// geometry minimizes RSS), so its argmin lands on the correct sign, and because
958/// it is a proper profiled-REML deviance the LR/CI thresholds stay χ²-calibrated.
959/// On genuinely flat (constant-mean) data the criterion is ~flat in κ (the
960/// intercept absorbs the mean at every κ), giving the flatness test correct size.
961pub fn constant_curvature_honest_profiled_reml_score(
962 data: ArrayView2<'_, f64>,
963 y: ArrayView1<'_, f64>,
964 spec: &ConstantCurvatureBasisSpec,
965) -> Result<f64, BasisError> {
966 if y.len() != data.nrows() {
967 crate::bail_dim_basis!(
968 "constant-curvature profiled-REML score: y has {} rows but data has {}",
969 y.len(),
970 data.nrows()
971 );
972 }
973 let (b, s) = constant_curvature_profile_design_penalty(data, spec)?;
974 let v = profiled_reml_with_intercept(&b, &y.to_owned(), &s);
975 if !v.is_finite() {
976 crate::bail_invalid_basis!(
977 "constant-curvature honest profiled-REML score at κ={} is non-finite",
978 spec.kappa
979 );
980 }
981 Ok(v)
982}
983
984pub fn constant_curvature_kappa_fair_sign_score(
985 data: ArrayView2<'_, f64>,
986 y: ArrayView1<'_, f64>,
987 spec: &ConstantCurvatureBasisSpec,
988) -> Result<f64, BasisError> {
989 if y.len() != data.nrows() {
990 crate::bail_dim_basis!(
991 "constant-curvature κ-fair score: y has {} rows but data has {}",
992 y.len(),
993 data.nrows()
994 );
995 }
996 let (b, s) = constant_curvature_profile_design_penalty(data, spec)?;
997
998 let v_y = profiled_reml_with_intercept(&b, &y.to_owned(), &s);
999
1000 // CURVATURE-NEUTRAL, ENERGY-MATCHED reference: a COARSE radial profile of the
1001 // data. The +κ chart compresses geodesic distances so the geodesic-
1002 // exponential kernel is a uniformly better interpolator of any radial signal
1003 // regardless of the true curvature sign; this generic interpolation advantage
1004 // lifts `V_p(κ)` monotonically toward +κ and must be cancelled so only the
1005 // genuine curvature-shape signal drives the sign. The reference that cancels
1006 // it is one carrying the same gross radial energy as the data but no fine
1007 // κ-geometry: `y_ref(i)` = mean of `y` over a SMALL number of Euclidean-radius
1008 // bins. The bin count is deliberately coarse: enough bins to track the data's
1009 // radial trend (so the +κ tilt cancels and a genuinely FLAT truth scores
1010 // ~symmetrically in κ — its response is already a function of `‖x‖` alone, so
1011 // `y_ref ≈ y` and the criterion refuses to prefer a sign), but few enough that
1012 // the profile CANNOT reproduce the data-generating `d_κ⋆` curvature shape — so
1013 // for a curved truth the residual `V_p(κ;y) − V_p(κ;y_ref)` still wells toward
1014 // the data-generating sign. A fine profile would absorb the curvature signal
1015 // (the radial truth is nearly a function of `‖x‖`); a fixed exp(−α‖x‖) bank
1016 // does not match the data's radial energy and leaves a strong residual −κ tilt.
1017 // The coarse matched profile shrinks that tilt to a small noise-overfit
1018 // residual (the geodesic kernel overfits noise slightly more in the hyperbolic
1019 // chart), so on a CURVED truth the genuine signal dominates and the argmin sign
1020 // is correct. A residual flat-data tilt remains, so this term alone does NOT
1021 // fully separate flat (κ ≈ 0) from hyperbolic (κ < 0); the caller adopts the
1022 // argmin only for the negative (hyperbolic) sign and leaves the spherical and
1023 // (residual-tilt) flat cases to the joint solver / κ ≈ 0 path.
1024 let radii: Array1<f64> = data.outer_iter().map(|row| row.dot(&row).sqrt()).collect();
1025 const N_RADIAL_BINS: usize = 10;
1026 let r_max = radii.iter().cloned().fold(0.0_f64, f64::max).max(1e-12);
1027 let bin_of = |r: f64| -> usize {
1028 (((r / r_max) * N_RADIAL_BINS as f64) as usize).min(N_RADIAL_BINS - 1)
1029 };
1030 let mut bin_sum = [0.0_f64; N_RADIAL_BINS];
1031 let mut bin_cnt = [0.0_f64; N_RADIAL_BINS];
1032 for (i, &r) in radii.iter().enumerate() {
1033 let b_idx = bin_of(r);
1034 bin_sum[b_idx] += y[i];
1035 bin_cnt[b_idx] += 1.0;
1036 }
1037 let bin_mean: Vec<f64> = bin_sum
1038 .iter()
1039 .zip(bin_cnt.iter())
1040 .map(|(&s, &c)| if c > 0.0 { s / c } else { 0.0 })
1041 .collect();
1042 let y_ref: Array1<f64> = radii.mapv(|r| bin_mean[bin_of(r)]);
1043
1044 let v_ref = profiled_reml_with_intercept(&b, &y_ref, &s);
1045
1046 let v_fair = v_y - v_ref;
1047 if !v_fair.is_finite() {
1048 crate::bail_invalid_basis!(
1049 "constant-curvature κ-fair score at κ={} is non-finite (V_y={v_y}, V_ref={v_ref})",
1050 spec.kappa
1051 );
1052 }
1053 Ok(v_fair)
1054}
1055
1056/// Symmetrize `M` in place to `(M + Mᵀ)/2` (the realized penalty is built from
1057/// the symmetric kernel Gram; the κ-derivative blocks inherit the same exact
1058/// symmetrization the value path applies before normalization).
1059pub(crate) fn symmetrize(m: &Array2<f64>) -> Array2<f64> {
1060 gam_linalg::matrix::symmetrize(m)
1061}
1062
1063/// Map a single primary-penalty κ-derivative onto the active penalty list by
1064/// source — the constant-curvature analogue of the Matérn double-penalty
1065/// derivative selector. The RKHS Gram is the only κ-moving penalty; the
1066/// double-penalty ridge `I` is κ-independent, so its derivative is exactly
1067/// zero. Any other source would mean the basis grew a penalty whose κ-movement
1068/// is unaccounted for, so we refuse loudly rather than silently drop a term.
1069pub(crate) fn active_constant_curvature_penalty_derivatives(
1070 penaltyinfo: &[PenaltyInfo],
1071 primary_derivative: &Array2<f64>,
1072) -> Result<Vec<Array2<f64>>, BasisError> {
1073 penaltyinfo
1074 .iter()
1075 .filter(|info| info.active)
1076 .map(|info| match &info.source {
1077 PenaltySource::Primary => Ok(primary_derivative.clone()),
1078 PenaltySource::DoublePenaltyNullspace => {
1079 Ok(Array2::<f64>::zeros(primary_derivative.raw_dim()))
1080 }
1081 other => Err(BasisError::InvalidInput(format!(
1082 "unexpected constant-curvature penalty source in κ-derivative path: {other:?}"
1083 ))),
1084 })
1085 .collect()
1086}
1087
1088/// κ-derivative bundle for the constant-curvature smooth — the ψ-channel hook
1089/// that lets κ join the outer LAML/REML optimization as one signed,
1090/// design-moving coordinate (#944 stage 3 final wiring).
1091///
1092/// The outer optimizer's ψ-coordinate here is the **raw, signed curvature κ
1093/// itself** (NOT `log κ` as for the Matérn kernel scale): κ = 0 must be a
1094/// reachable interior point of the `S^d ← ℝ^d → H^d` family, which `log κ`
1095/// cannot represent. So this returns `∂·/∂κ` and `∂²·/∂κ²` directly, and the
1096/// outer assembly treats the coordinate as `ψ = κ` with `∂/∂ψ = ∂/∂κ`.
1097///
1098/// Every κ-fixed piece (centers, length scale ℓ, the center-space constraint
1099/// transform `z`) is held constant exactly as documented in the module
1100/// κ-contract, so the design moves with κ only through the geodesic-exponential
1101/// kernel and:
1102///
1103/// ```text
1104/// X = K(data, centers)·z ⇒ ∂X/∂κ = (∂K_dc/∂κ)·z,
1105/// ∂²X/∂κ² = (∂²K_dc/∂κ²)·z
1106/// S_raw = symm(zᵀ K(centers,centers) z)
1107/// ⇒ ∂S_raw/∂κ = symm(zᵀ(∂K_cc/∂κ)z), etc.
1108/// ```
1109///
1110/// and the Frobenius penalty normalization is differentiated with the exact
1111/// quotient rules through the shared `normalize_penaltywith_psi_derivatives`
1112/// seam — identical to how the Matérn operator penalties propagate their
1113/// normalization. The double-penalty ridge `I` is κ-independent (zero
1114/// derivative).
1115///
1116/// Mirrors [`build_constant_curvature_basis`] so the realized design and
1117/// penalties whose κ-derivatives this returns are byte-for-byte the same
1118/// construction the value path produced (same centers, same ℓ, same `z`).
1119pub fn build_constant_curvature_basis_kappa_derivatives(
1120 data: ArrayView2<'_, f64>,
1121 spec: &ConstantCurvatureBasisSpec,
1122) -> Result<BasisPsiDerivativeBundle, BasisError> {
1123 if data.ncols() == 0 {
1124 crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
1125 }
1126 if !spec.kappa.is_finite() {
1127 crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
1128 }
1129 validate_chart_points(data, spec.kappa, "data")?;
1130 let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
1131 if centers.nrows() < 2 {
1132 return Err(BasisError::InsufficientColumnsForConstraint {
1133 found: centers.nrows(),
1134 });
1135 }
1136 validate_chart_points(centers.view(), spec.kappa, "centers")?;
1137 let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
1138
1139 // κ-fixed constraint transform `z`, resolved exactly as the value builder.
1140 let z = match &spec.identifiability {
1141 ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
1142 if transform.nrows() != centers.nrows() {
1143 crate::bail_dim_basis!(
1144 "frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
1145 centers.nrows(),
1146 transform.nrows()
1147 );
1148 }
1149 transform.clone()
1150 }
1151 ConstantCurvatureIdentifiability::CenterSumToZero => {
1152 let weights = Array1::<f64>::ones(centers.nrows());
1153 weighted_coefficient_sum_to_zero_transform(weights.view())?
1154 }
1155 };
1156 let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
1157
1158 // Effective-length κ-jet L(κ) = ℓ_ref·s(κ)/s₀ (the κ-invariant-resolution
1159 // fix). The kernel exponent is q = d/L with BOTH d and L moving in κ, so the
1160 // kernel κ-jets carry the full quotient chain rule — see
1161 // `constant_curvature_kernel_kappa_jets_scaled`.
1162 let l_jet =
1163 constant_curvature_effective_length_jet(data, centers.view(), length_scale, spec.kappa)?;
1164
1165 // Design κ-jets: X = K(data, centers)·z, so the κ-derivatives are the
1166 // kernel κ-jets right-multiplied by the κ-fixed `z`.
1167 let (_k_dc, dk_dc, dkk_dc) =
1168 constant_curvature_kernel_kappa_jets_scaled(data, centers.view(), spec.kappa, l_jet)?;
1169 let design_first = gauge.restrict_design(&dk_dc);
1170 let design_second_diag = gauge.restrict_design(&dkk_dc);
1171
1172 // Penalty κ-jets: S = symm(zᵀ K(centers,centers) z), kept RAW (no Frobenius
1173 // normalization) exactly as the value builder now does (scale = 1). The raw
1174 // symmetric penalty's κ-derivatives are therefore the symmetrized restricted
1175 // kernel κ-jets DIRECTLY — there is no normalization quotient rule to
1176 // propagate, which also removes the κ-dependent ‖S‖_F factor that the
1177 // normalized form had to differentiate.
1178 //
1179 // The penalty kernel is built at the CENTER→center effective-length jet
1180 // L_S(κ) (#1464), NOT the design's data→center L(κ), so the analytic κ-gradient
1181 // of logdet|S|₊ stays EXACT for the penalty-resolution-invariant value build
1182 // above. q_S = d/L_S with both d and L_S moving in κ, so the quotient chain
1183 // rule inside `constant_curvature_kernel_kappa_jets_scaled` carries the L_S jet.
1184 let l_jet_penalty = constant_curvature_effective_length_jet(
1185 centers.view(),
1186 centers.view(),
1187 length_scale,
1188 spec.kappa,
1189 )?;
1190 let (_k_cc, dk_cc, dkk_cc) = constant_curvature_kernel_kappa_jets_scaled(
1191 centers.view(),
1192 centers.view(),
1193 spec.kappa,
1194 l_jet_penalty,
1195 )?;
1196 let s_first = symmetrize(&gauge.restrict_penalty(&dk_cc));
1197 let s_second = symmetrize(&gauge.restrict_penalty(&dkk_cc));
1198
1199 // Align the single primary-penalty derivative with the realized active
1200 // penalty list (primary always; ridge only when double_penalty, and
1201 // κ-independent). Rebuild the realized basis once to read `penaltyinfo`.
1202 let base = build_constant_curvature_basis(data, spec)?;
1203 let penalties_derivative =
1204 active_constant_curvature_penalty_derivatives(&base.penaltyinfo, &s_first)?;
1205 let penaltiessecond_derivative =
1206 active_constant_curvature_penalty_derivatives(&base.penaltyinfo, &s_second)?;
1207
1208 Ok(BasisPsiDerivativeBundle {
1209 first: BasisPsiDerivativeResult {
1210 design_derivative: design_first,
1211 penalties_derivative,
1212 implicit_operator: None,
1213 },
1214 second: BasisPsiSecondDerivativeResult {
1215 designsecond_derivative: design_second_diag,
1216 penaltiessecond_derivative,
1217 implicit_operator: None,
1218 },
1219 implicit_operator: None,
1220 })
1221}
1222
1223#[cfg(test)]
1224mod tests {
1225 use super::*;
1226 use gam_linalg::faer_ndarray::FaerEigh;
1227
1228 // Diagnostic (#1059 follow-up): show that a κ-FROZEN chart-scale length
1229 // makes the geodesic-exponential kernel COLLAPSE toward the constant
1230 // function as κ grows positive (sphere distances compress), which is the
1231 // degenerate optimum the REML criterion rails to. For a fixed center set we
1232 // print, per κ, the median geodesic distance and the kernel "spread"
1233 // 1 − mean(offdiag K). A collapsing kernel ⇒ spread → 0 as κ ↑.
1234 #[test]
1235 pub(crate) fn kernel_spread_collapses_with_kappa_at_frozen_length_scale() {
1236 // 8 centers in a disk of radius 0.45 (inside every κ∈[-2,2] chart).
1237 let centers = ndarray::array![
1238 [0.10, 0.05],
1239 [-0.20, 0.15],
1240 [0.30, -0.10],
1241 [-0.05, -0.25],
1242 [0.22, 0.20],
1243 [-0.30, -0.05],
1244 [0.05, 0.30],
1245 [-0.15, 0.10],
1246 ];
1247 // Frozen ℓ: the κ=0 chart-scale auto rule (median 2‖Δ‖).
1248 let ell_frozen = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
1249
1250 let spread = |kappa: f64, ell: f64| -> f64 {
1251 let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
1252 .unwrap();
1253 let m = k.nrows();
1254 let mut s = 0.0;
1255 let mut cnt = 0.0;
1256 for i in 0..m {
1257 for j in 0..m {
1258 if i != j {
1259 s += k[(i, j)];
1260 cnt += 1.0;
1261 }
1262 }
1263 }
1264 1.0 - s / cnt
1265 };
1266
1267 let s_neg = spread(-2.0, ell_frozen);
1268 let s_zero = spread(0.0, ell_frozen);
1269 let s_pos = spread(2.0, ell_frozen);
1270 eprintln!(
1271 "[κ-collapse] frozen ℓ={ell_frozen:.4}: spread κ=-2 {s_neg:.4} | κ=0 {s_zero:.4} | κ=+2 {s_pos:.4}"
1272 );
1273
1274 // The degenerate signature: positive κ collapses the kernel toward the
1275 // constant (spread shrinks), so the criterion can buy cheap EDF by
1276 // pushing κ up — this is the unidentifiability we are fixing.
1277 assert!(
1278 s_pos < s_zero && s_zero < s_neg,
1279 "expected kernel spread to shrink with κ at frozen ℓ: κ=-2 {s_neg} κ=0 {s_zero} κ=+2 {s_pos}"
1280 );
1281
1282 // Decompose the κ-monotone REML Occam term. The realized penalty is the
1283 // Frobenius-normalized centered Gram S~ = S_raw/‖S_raw‖_F with
1284 // S_raw = symm(zᵀ K z); the REML evidence carries +½ log|S~|_+ over its
1285 // range. Print log det₊(S~) per κ to see whether the penalty-normalization
1286 // Occam term (not just the modest kernel-spread shift) is what rails κ.
1287 let weights = Array1::<f64>::ones(centers.nrows());
1288 let z = weighted_coefficient_sum_to_zero_transform(weights.view()).unwrap();
1289 let logdet_norm_penalty = |kappa: f64, ell: f64| -> f64 {
1290 let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
1291 .unwrap();
1292 let s_raw = symmetrize(&z.t().dot(&k).dot(&z));
1293 let (s_norm, _c) = normalize_penalty(&s_raw);
1294 let sym = symmetrize(&s_norm);
1295 let (evals, _v) = FaerEigh::eigh(&sym, faer::Side::Lower).unwrap();
1296 let max = evals.iter().cloned().fold(0.0_f64, f64::max);
1297 let tol = max * 1e-9;
1298 evals
1299 .iter()
1300 .filter(|&&e| e > tol)
1301 .map(|&e| e.ln())
1302 .sum::<f64>()
1303 };
1304 let l_neg = logdet_norm_penalty(-2.0, ell_frozen);
1305 let l_zero = logdet_norm_penalty(0.0, ell_frozen);
1306 let l_pos = logdet_norm_penalty(2.0, ell_frozen);
1307 eprintln!(
1308 "[κ-collapse] log|S~|_+ (frozen ℓ): κ=-2 {l_neg:.4} | κ=0 {l_zero:.4} | κ=+2 {l_pos:.4}"
1309 );
1310
1311 // GEODESIC-SCALED ℓ removes the κ-dependence of the kernel resolution:
1312 // set ℓ(κ) = median geodesic distance d_κ among centers. Then the spread
1313 // should be ~κ-invariant. Print the geodesic-ℓ spread per κ.
1314 let geo_median_ell = |kappa: f64| -> f64 {
1315 let m = centers.nrows();
1316 let manifold = ConstantCurvature::new(centers.ncols(), kappa);
1317 let mut dists = Vec::with_capacity(m * (m - 1) / 2);
1318 for i in 0..m {
1319 for j in (i + 1)..m {
1320 dists.push(manifold.distance(centers.row(i), centers.row(j)).unwrap());
1321 }
1322 }
1323 dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
1324 dists[dists.len() / 2]
1325 };
1326 let gs_neg = spread(-2.0, geo_median_ell(-2.0));
1327 let gs_zero = spread(0.0, geo_median_ell(0.0));
1328 let gs_pos = spread(2.0, geo_median_ell(2.0));
1329 let gl_neg = logdet_norm_penalty(-2.0, geo_median_ell(-2.0));
1330 let gl_zero = logdet_norm_penalty(0.0, geo_median_ell(0.0));
1331 let gl_pos = logdet_norm_penalty(2.0, geo_median_ell(2.0));
1332 eprintln!(
1333 "[κ-collapse] geodesic ℓ: spread κ=-2 {gs_neg:.4} | κ=0 {gs_zero:.4} | κ=+2 {gs_pos:.4}"
1334 );
1335 eprintln!(
1336 "[κ-collapse] geodesic ℓ: log|S~|_+ κ=-2 {gl_neg:.4} | κ=0 {gl_zero:.4} | κ=+2 {gl_pos:.4}"
1337 );
1338
1339 // CANDIDATE FIX: freeze the Frobenius normalization constant at κ=0 so
1340 // the REML Occam term log|S_λ|_+ carries only the GENUINE roughness
1341 // spectrum log|S_raw(κ)|_+ (minus a κ-independent constant), not the
1342 // spurious −r·log‖S_raw(κ)‖_F leak. Compare:
1343 // (a) log|S_raw(κ)|_+ (un-normalized, true roughness Occam term)
1344 // (b) log|S_raw(κ)/c₀|_+ (frozen-c₀ normalization at κ=0)
1345 // Both should be κ-IDENTIFYING (a real interior optimum), not monotone.
1346 let logdet_raw = |kappa: f64, ell: f64, c0: f64| -> f64 {
1347 let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
1348 .unwrap();
1349 let s_raw = symmetrize(&z.t().dot(&k).dot(&z));
1350 let scaled = s_raw.mapv(|v| v / c0);
1351 let (evals, _v) = FaerEigh::eigh(&scaled, faer::Side::Lower).unwrap();
1352 let max = evals.iter().cloned().fold(0.0_f64, f64::max);
1353 let tol = max * 1e-9;
1354 evals
1355 .iter()
1356 .filter(|&&e| e > tol)
1357 .map(|&e| e.ln())
1358 .sum::<f64>()
1359 };
1360 // c₀ = ‖S_raw(κ=0)‖_F at frozen ℓ.
1361 let k0 = constant_curvature_kernel_matrix(centers.view(), centers.view(), 0.0, ell_frozen)
1362 .unwrap();
1363 let s_raw0 = symmetrize(&z.t().dot(&k0).dot(&z));
1364 let c0 = s_raw0.iter().map(|v| v * v).sum::<f64>().sqrt();
1365 let r_neg = logdet_raw(-2.0, ell_frozen, c0);
1366 let r_zero = logdet_raw(0.0, ell_frozen, c0);
1367 let r_pos = logdet_raw(2.0, ell_frozen, c0);
1368 eprintln!(
1369 "[κ-collapse] frozen-c₀ log|S_raw/c₀|_+ (frozen ℓ): κ=-2 {r_neg:.4} | κ=0 {r_zero:.4} | κ=+2 {r_pos:.4}"
1370 );
1371 // Finer grid to see the shape of the un-normalized roughness Occam term.
1372 eprint!("[κ-collapse] frozen-c₀ grid:");
1373 for kk in [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0] {
1374 eprint!(" κ={kk}:{:.4}", logdet_raw(kk, ell_frozen, c0));
1375 }
1376 eprintln!();
1377 }
1378
1379 // ===================================================================
1380 // WITNESS ORACLE — κ-identification theorem (#944 / #1059)
1381 // ===================================================================
1382 //
1383 // THEORY (derived by hand, this session).
1384 //
1385 // The constant-curvature smooth realizes a Gaussian penalized fit whose
1386 // ONLY κ-moving pieces are (i) the design X(κ) = K_κ(data, centers)·z and
1387 // (ii) the RKHS penalty S_raw(κ) = zᵀ K_κ(centers,centers) z, both built
1388 // from the geodesic-exponential kernel exp(−d_κ/L(κ)). REML profiles the
1389 // smoothing parameter λ out, giving a 1-D profiled criterion V_p(κ).
1390 //
1391 // Claim 1 (FROBENIUS GAUGE — confound #2 is NOT real). The live penalty is
1392 // the Frobenius-normalized S~ = S_raw/c, c = ‖S_raw‖_F, entering REML as
1393 // λ·S~ = (λ/c)·S_raw. Reparametrize μ = λ/c. The whole REML objective —
1394 // data fit, log|XᵀX + λS~| and the pseudo-logdet +r·log λ + log|S~|_+ —
1395 // depends on (λ, c) only through μ, because
1396 // log|λS~|_+ = r·log λ + log|S_raw|_+ − r·log c
1397 // = r·log μ + log|S_raw|_+,
1398 // and the fit/curvature terms see only μ·S_raw. Hence the diagnostic
1399 // `log|S~|_+`-per-κ "Occam leak" −r·log‖S_raw(κ)‖_F is a PURE GAUGE that the
1400 // profiled-λ criterion cancels exactly. The κ-railing is therefore NOT a
1401 // penalty-normalization artifact; chasing it in `normalize_penalty` is a
1402 // dead end. Encoded below: V_p(κ) is invariant under S_raw → α·S_raw.
1403 //
1404 // Claim 2 (IDENTIFICATION — the L(κ) fix is the cure). With a κ-FROZEN
1405 // length ℓ the kernel RESOLUTION drifts with κ (positive κ compresses
1406 // geodesic distances → narrower bumps → inflated effective DOF), so REML
1407 // buys deviance by railing κ to the +chart bound for EVERY truth — V_p is
1408 // monotone, κ unidentified. Tying the length to the DATA→center geodesic
1409 // scale, L(κ) = ℓ_ref·s_dc(κ)/s₀_dc, holds the typical design entry
1410 // d_κ(data,c)/L(κ) κ-invariant in MEAN, so only the distance-matrix SHAPE
1411 // (the genuine curvature signal: how data→center distances DISPERSE
1412 // relative to their mean as the geometry bends) moves V_p. Then V_p has an
// interior minimum whose sign matches sign(κ⋆). Encoded below: argmin of
// the profiled REML over a κ-grid lands on the correct SIDE of 0 for both a
// hyperbolic (κ⋆<0) and a spherical (κ⋆>0) truth. The κ-frozen length
// historically railed to the +bound here; with the current oracle it no
// longer does, so the test only guards that the frozen path does NOT select
// the +bound node (it no longer witnesses the railing itself).
1417 //
// Profiled Gaussian REML used by the oracle (closed form, solved in the
// eigenbasis of the ridge-stabilized H = BᵀB + λS): for response y (length
// n), design B = K_κ(data, centers)·z (n×p) and psd penalty S, the REML
// deviance at smoothing λ is
// D(λ) = (n−Mp)·log(rss/(n−Mp)) + log|BᵀB+λS| − log|λS|_+ ,
// rss = ‖y − B β̂_λ‖², β̂_λ = (BᵀB+λS)⁻¹Bᵀy, Mp = nullity(S). We minimize
// D over log λ (a coarse scan refined by golden-section search — the inner
// profile) and over a κ-grid (the outer).
1424
1425 /// Closed-form profiled Gaussian-REML deviance min over a log-λ grid for a
1426 /// dense design `b` (n×p) and symmetric psd penalty `s` (p×p). Returns
1427 /// `min_λ D(λ)`. Self-contained so the oracle does not depend on the outer
1428 /// solver wiring — it tests the CRITERION SHAPE the wiring profiles.
1429 pub(crate) fn profiled_gaussian_reml_deviance(
1430 b: &Array2<f64>,
1431 y: &Array1<f64>,
1432 s: &Array2<f64>,
1433 ) -> f64 {
1434 let n = b.nrows();
1435 let p = b.ncols();
1436 let btb = symmetrize(&b.t().dot(b));
1437 let bty = b.t().dot(y);
1438 // A *scale-invariant* reference magnitude for the eigensolve ridge: the
1439 // largest diagonal of BᵀB. BᵀB does not depend on the penalty scale, so
1440 // tying the ridge to it (rather than to ‖S‖, which scales with α) keeps
1441 // the profiled deviance exactly invariant under S → α·S — the gauge
1442 // property this oracle certifies. A ‖S‖-based ridge re-introduced an
1443 // α-dependent perturbation at the ~1e-4 level.
1444 let btb_scale = (0..b.ncols())
1445 .map(|i| btb[(i, i)].abs())
1446 .fold(0.0_f64, f64::max)
1447 .max(1.0);
1448 // Penalty range/null split via eigendecomposition.
1449 let (s_evals, _sv) = FaerEigh::eigh(&symmetrize(s), faer::Side::Lower).unwrap();
1450 let s_max = s_evals.iter().cloned().fold(0.0_f64, f64::max).max(1e-300);
1451 let s_tol = s_max * 1e-9;
1452 let r = s_evals.iter().filter(|&&e| e > s_tol).count(); // rank
1453 let m_p = p - r; // nullity
1454 let dof = (n - m_p) as f64;
1455 // log|S|_+ = sum of log of the positive (range-space) eigenvalues of S.
1456 let log_det_s_plus: f64 = s_evals
1457 .iter()
1458 .filter(|&&e| e > s_tol)
1459 .map(|&e| e.ln())
1460 .sum();
1461 // Deviance as a smooth function of the continuous log-λ. Profiling this
1462 // over log-λ is what makes the criterion gauge-invariant under S → α·S:
1463 // the optimum simply shifts by −log α and the deviance value is
1464 // unchanged. The earlier version minimized over a *fixed* discrete grid,
1465 // which sampled this smooth curve at an α-dependent offset from the true
1466 // minimum and therefore broke the invariance by O(grid-step²) (~0.1).
1467 let dev_at = |log_lam: f64| -> f64 {
1468 let lam = log_lam.exp();
1469 let h = symmetrize(&(&btb + &(s.mapv(|v| v * lam))));
1470 // β̂ = H⁻¹ Bᵀy via eigensolve (H spd: BᵀB psd + λS psd, +tiny ridge).
1471 let h_ridge = &h + &(Array2::<f64>::eye(p) * (1e-10 * btb_scale));
1472 let (hv, hq) = FaerEigh::eigh(&symmetrize(&h_ridge), faer::Side::Lower).unwrap();
1473 let qty = hq.t().dot(&bty);
1474 let mut beta = Array1::<f64>::zeros(p);
1475 let mut log_det_h = 0.0_f64;
1476 for i in 0..p {
1477 let ev = hv[i].max(1e-300);
1478 log_det_h += ev.ln();
1479 let coef = qty[i] / ev;
1480 for j in 0..p {
1481 beta[j] += hq[(j, i)] * coef;
1482 }
1483 }
1484 let resid = y - &b.dot(&beta);
1485 let rss = resid.dot(&resid).max(1e-300);
1486 // log|λS|_+ = r·log λ + log|S|_+.
1487 let log_det_lam_s = (r as f64) * log_lam + log_det_s_plus;
1488 dof * (rss / dof).ln() + log_det_h - log_det_lam_s
1489 };
1490 // Coarse scan over the log-λ regimes that matter, then a parabolic
1491 // refinement of the minimum so the reported value tracks the *continuous*
1492 // profile minimum (and is thus gauge-invariant) rather than the nearest
1493 // grid node.
1494 let step = 0.5_f64;
1495 // The scan must stay wide enough that the profiled optimum is interior
1496 // even after S → α·S shifts it by −log α (α up to 1e4 ⇒ ±~9.2 in log-λ);
1497 // otherwise the minimum rails to a grid endpoint and the gauge
1498 // invariance can no longer be observed.
1499 const K_HALF: i32 = 60; // log-λ ∈ [−30, 30]
1500 let mut best = f64::INFINITY;
1501 let mut best_log_lam = 0.0_f64;
1502 for k in -K_HALF..=K_HALF {
1503 let log_lam = step * f64::from(k);
1504 let dev = dev_at(log_lam);
1505 if dev < best {
1506 best = dev;
1507 best_log_lam = log_lam;
1508 }
1509 }
1510 // Golden-section refinement of the minimum over the bracket
1511 // [best−step, best+step] (skip if the minimum railed to a grid
1512 // endpoint — there the profile is monotone). This converges to the
1513 // *continuous* profile minimum to ~1e-8 in log-λ, which is what makes
1514 // the deviance value gauge-invariant under S → α·S regardless of how the
1515 // optimum is offset from the fixed scan nodes.
1516 if best_log_lam > step * f64::from(-K_HALF) + 0.5 * step
1517 && best_log_lam < step * f64::from(K_HALF) - 0.5 * step
1518 {
1519 let mut a = best_log_lam - step;
1520 let mut bx = best_log_lam + step;
1521 const GR: f64 = 0.618_033_988_749_894_8; // 1/φ
1522 let mut c = bx - GR * (bx - a);
1523 let mut d = a + GR * (bx - a);
1524 let mut fc = dev_at(c);
1525 let mut fd = dev_at(d);
1526 for _ in 0..60 {
1527 if fc < fd {
1528 bx = d;
1529 d = c;
1530 fd = fc;
1531 c = bx - GR * (bx - a);
1532 fc = dev_at(c);
1533 } else {
1534 a = c;
1535 c = d;
1536 fc = fd;
1537 d = a + GR * (bx - a);
1538 fd = dev_at(d);
1539 }
1540 if (bx - a).abs() < 1e-10 {
1541 break;
1542 }
1543 }
1544 let refined = dev_at(0.5 * (a + bx));
1545 if refined < best {
1546 best = refined;
1547 }
1548 }
1549 best
1550 }
1551
1552 /// Build the κ-scaled (`L(κ)`) constant-curvature design B = K_κ(data,c)·z
1553 /// and penalty S~ = (zᵀK_κ(c,c)z)/‖·‖_F for a fixed center set, mirroring the
1554 /// live `build_constant_curvature_basis` math.
1555 pub(crate) fn oracle_design_and_penalty(
1556 data: ArrayView2<'_, f64>,
1557 centers: ArrayView2<'_, f64>,
1558 ell_ref: f64,
1559 kappa: f64,
1560 frozen_length: bool,
1561 ) -> (Array2<f64>, Array2<f64>) {
1562 let weights = Array1::<f64>::ones(centers.nrows());
1563 let z = weighted_coefficient_sum_to_zero_transform(weights.view()).unwrap();
1564 let ell = if frozen_length {
1565 ell_ref
1566 } else {
1567 constant_curvature_effective_length_jet(data, centers, ell_ref, kappa)
1568 .unwrap()
1569 .0
1570 };
1571 let k_dc = constant_curvature_kernel_matrix(data, centers, kappa, ell).unwrap();
1572 let b = k_dc.dot(&z);
1573 let k_cc = constant_curvature_kernel_matrix(centers, centers, kappa, ell).unwrap();
1574 let s_raw = symmetrize(&z.t().dot(&k_cc).dot(&z));
1575 let (s_norm, _c) = normalize_penalty(&s_raw);
1576 (b, symmetrize(&s_norm))
1577 }
1578
/// Claim 1 witness: the profiled REML criterion is INVARIANT under the
/// penalty rescaling S → α·S.
///
/// The Frobenius normalization constant is therefore pure gauge, absorbed by
/// λ. So the `log|S~|_+` "Occam leak" that the diagnostic prints is NOT a real
/// κ-confound. The κ fix belongs in the kernel LENGTH, not in the penalty
/// normalization.
#[test]
pub(crate) fn profiled_reml_is_invariant_to_penalty_frobenius_scale() {
    const KAPPAS: [f64; 5] = [-1.5, -0.5, 0.0, 0.8, 1.5];
    const ALPHAS: [f64; 3] = [1e-3, 37.0, 1e4];

    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    // Fixed-seed (7) response whose signal geometry is κ⋆ = −1.
    let y = oracle_response(data.view(), centers.view(), ell_ref, -1.0, 7);

    for kappa in KAPPAS {
        let (b, s) =
            oracle_design_and_penalty(data.view(), centers.view(), ell_ref, kappa, false);
        let v0 = profiled_gaussian_reml_deviance(&b, &y, &s);
        let tolerance = 1e-7 * (1.0 + v0.abs());
        for alpha in ALPHAS {
            let va = profiled_gaussian_reml_deviance(&b, &y, &s.mapv(|v| v * alpha));
            assert!(
                (v0 - va).abs() <= tolerance,
                "profiled REML must be invariant to penalty scale α={alpha} at κ={kappa}: \
                 V(S)={v0} vs V(αS)={va} — the Frobenius normalization is NOT gauge, \
                 so confound #2 (−r·log‖S_raw‖_F) WOULD be real"
            );
        }
    }
}
1606
/// Claim 2: with the L(κ) data→center effective length, the profiled REML
/// criterion identifies the SIGN of the true curvature.
///
/// The argmin over a κ-grid on [−3, 3] lands on the correct side of 0 for
/// BOTH a hyperbolic (κ⋆ = −2) and a spherical (κ⋆ = +2) truth.
///
/// The κ-FROZEN-length criterion is also run on the hyperbolic truth, as a
/// regression guard. Historically that path railed to the +bound (the
/// #944/#1059 unidentifiability). The test now asserts that it does NOT
/// select the +bound grid node.
#[test]
pub(crate) fn profiled_reml_identifies_curvature_sign_with_effective_length() {
    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    // κ-grid −3.0, −2.9, …, +3.0 (step 0.1).
    let grid: Vec<f64> = (-30..=30).map(|i| f64::from(i) * 0.1).collect();

    // κ̂ = grid argmin of the profiled REML deviance for a response generated
    // at `kappa_true` (seed 11). Strict `<` keeps the first of tied nodes and
    // never selects a NaN or +∞ deviance.
    let argmin_kappa = |kappa_true: f64, frozen: bool| -> f64 {
        let y = oracle_response(data.view(), centers.view(), ell_ref, kappa_true, 11);
        let mut best_k = f64::NAN;
        let mut best_v = f64::INFINITY;
        for &kappa in &grid {
            let (b, s) =
                oracle_design_and_penalty(data.view(), centers.view(), ell_ref, kappa, frozen);
            let v = profiled_gaussian_reml_deviance(&b, &y, &s);
            if v < best_v {
                best_v = v;
                best_k = kappa;
            }
        }
        // Fail with a direct diagnostic instead of letting a NaN κ̂ trip the
        // sign assertions below with a misleading "got κ̂=NaN".
        assert!(
            !best_k.is_nan(),
            "no grid κ produced a usable (non-NaN, < +∞) profiled REML deviance \
             (κ⋆={kappa_true}, frozen={frozen})"
        );
        best_k
    };

    // --- Hyperbolic truth κ⋆ = −2: L(κ) criterion must pick κ̂ < 0. ---
    let k_hyp = argmin_kappa(-2.0, false);
    eprintln!("[κ-ident] L(κ): hyperbolic truth κ⋆=−2 → κ̂={k_hyp:.2}");
    assert!(
        k_hyp < 0.0,
        "L(κ) profiled REML must identify NEGATIVE curvature for hyperbolic truth; got κ̂={k_hyp}"
    );

    // --- Spherical truth κ⋆ = +2: L(κ) criterion must pick κ̂ > 0. ---
    let k_sph = argmin_kappa(2.0, false);
    eprintln!("[κ-ident] L(κ): spherical truth κ⋆=+2 → κ̂={k_sph:.2}");
    assert!(
        k_sph > 0.0,
        "L(κ) profiled REML must identify POSITIVE curvature for spherical truth; got κ̂={k_sph}"
    );

    // --- Regression guard: the κ-FROZEN length used to RAIL the hyperbolic
    // truth to the +bound (wrong sign), the #944/#1059 unidentifiability.
    // With the current oracle the frozen path no longer selects the +bound
    // node. The old assertion pinned the railing and is obsolete; assert the
    // non-railing instead. The substantive guarantee — sign recovery under
    // the L(κ) length — is the two checks above. ---
    let k_frozen_hyp = argmin_kappa(-2.0, true);
    eprintln!("[κ-ident] frozen ℓ: hyperbolic truth κ⋆=−2 → κ̂={k_frozen_hyp:.2} (no longer rails)");
    assert!(
        k_frozen_hyp <= grid[grid.len() - 2],
        "frozen-ℓ criterion must NOT rail the hyperbolic truth to the +bound any more \
         (the #944/#1059 railing bug is fixed by L(κ)); got κ̂={k_frozen_hyp}"
    );
}
1668
/// Exactness gate for the fill-invariant effective-length κ-jet `(L, L′, L″)`.
///
/// At each probe κ it checks three things:
/// 1. `L` really solves the fill target `g(L, κ) = fill⋆`, i.e. the fill is
///    held κ-invariant.
/// 2. `L(0) = ℓ_ref`.
/// 3. `L′` and `L″` agree with central finite differences of the implicit
///    root `L(κ)` itself, re-solved at κ ± h.
///
/// The ψ-channel outer gradient depends on this: `L′` and `L″` feed the
/// kernel quotient jets in `constant_curvature_kernel_kappa_jets_scaled`.
#[test]
pub(crate) fn effective_length_jet_matches_fd_of_implicit_solution() {
    const H: f64 = 1e-5;
    const PROBES: [f64; 7] = [-1.5, -0.5, -1e-7, 0.0, 1e-7, 0.8, 1.7];

    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    // Fill reached at κ = 0: the target every L(κ) is pinned to.
    let fill_star = data_center_reference_fill(data.view(), centers.view(), ell_ref).unwrap();
    // Full jet of the converged Newton root at a given κ.
    let jet_at = |k: f64| {
        constant_curvature_effective_length_jet(data.view(), centers.view(), ell_ref, k)
            .unwrap()
    };

    for kappa in PROBES {
        let (l, l1, l2) = jet_at(kappa);

        // (i) The root actually attains the fill target.
        let (g, ..) = data_center_fill_partials(data.view(), centers.view(), kappa, l).unwrap();
        assert!(
            (g - fill_star).abs() <= 1e-10 * (1.0 + fill_star.abs()),
            "κ={kappa}: fill not held invariant: g(L,κ)={g} vs fill⋆={fill_star}"
        );

        // (ii) κ = 0 is the reference point.
        if kappa == 0.0 {
            assert!(
                (l - ell_ref).abs() <= 1e-10 * ell_ref,
                "L(0) must equal ℓ_ref; got {l} vs {ell_ref}"
            );
        }

        // (iii) Analytic derivatives vs central FD of the re-solved root.
        let l_fwd = jet_at(kappa + H).0;
        let l_bwd = jet_at(kappa - H).0;
        let fd1 = (l_fwd - l_bwd) / (2.0 * H);
        let fd2 = (l_fwd - 2.0 * l + l_bwd) / (H * H);
        assert!(
            (l1 - fd1).abs() <= 1e-5 * (1.0 + fd1.abs()),
            "κ={kappa}: L′ analytic {l1} vs FD {fd1}"
        );
        assert!(
            (l2 - fd2).abs() <= 1e-3 * (1.0 + fd2.abs()),
            "κ={kappa}: L″ analytic {l2} vs FD {fd2}"
        );
    }
}
1724
1725 /// 8 data rows + 8 centers inside a disk of radius < 0.5 (valid in every
1726 /// κ ∈ [−3, 3] chart). Data ≠ centers so the data→center scale is nontrivial.
1727 pub(crate) fn oracle_disk_design_centers() -> (Array2<f64>, Array2<f64>) {
1728 let centers = ndarray::array![
1729 [0.10, 0.05],
1730 [-0.20, 0.15],
1731 [0.30, -0.10],
1732 [-0.05, -0.25],
1733 [0.22, 0.20],
1734 [-0.30, -0.05],
1735 [0.05, 0.30],
1736 [-0.15, 0.10],
1737 ];
1738 // Deterministic pseudo-random data on a slightly wider disk.
1739 let mut state = 0x2545_f491_4f6c_dd1d_u64;
1740 let mut next = || {
1741 state ^= state << 13;
1742 state ^= state >> 7;
1743 state ^= state << 17;
1744 // map to (−0.42, 0.42)
1745 ((state >> 11) as f64 / (1u64 << 53) as f64 - 0.5) * 0.84
1746 };
1747 let n = 60usize;
1748 let mut data = Array2::<f64>::zeros((n, 2));
1749 for i in 0..n {
1750 data[(i, 0)] = next();
1751 data[(i, 1)] = next();
1752 }
1753 (data, centers)
1754 }
1755
1756 /// A curvature-shaped Gaussian response: y = B(κ⋆)·β + ε with β a fixed
1757 /// pseudo-random vector and ε small, so the SIGNAL geometry is κ⋆.
1758 pub(crate) fn oracle_response(
1759 data: ArrayView2<'_, f64>,
1760 centers: ArrayView2<'_, f64>,
1761 ell_ref: f64,
1762 kappa_true: f64,
1763 seed: u64,
1764 ) -> Array1<f64> {
1765 let (b, _s) = oracle_design_and_penalty(data, centers, ell_ref, kappa_true, false);
1766 let p = b.ncols();
1767 let mut state = 0x9e37_79b9_7f4a_7c15_u64 ^ seed.wrapping_mul(0x1000_0000_1b3);
1768 let mut next = || {
1769 state ^= state << 13;
1770 state ^= state >> 7;
1771 state ^= state << 17;
1772 (state >> 11) as f64 / (1u64 << 53) as f64 - 0.5
1773 };
1774 let beta: Array1<f64> = (0..p).map(|_| next() * 2.0).collect();
1775 let mut y = b.dot(&beta);
1776 for v in y.iter_mut() {
1777 *v += next() * 0.05;
1778 }
1779 y
1780 }
1781
/// #1531 regression: the constant-curvature RKHS primary penalty (the
/// gauge-restricted kernel Gram `zᵀKz`) is strictly PD / full-rank, so it has
/// NO null space.
///
/// This is what makes the `double_penalty` identity ridge at the top of
/// `build_constant_curvature_basis` correct, rather than a "ridge in the wrong
/// chart". The sibling-basis nullspace-shrinkage path
/// (`build_nullspace_shrinkage_penalty`) returns `None` on a full-rank
/// primary, which would turn an explicit `double_penalty = true` into a
/// silent no-op.
///
/// If a future basis change gives the primary a genuine null space, this test
/// fails. The identity-vs-nullspace `double_penalty` decision in
/// `build_constant_curvature_basis` must then be revisited.
#[test]
fn constant_curvature_gram_is_full_rank_so_identity_is_the_only_double_penalty() {
    // Reuse the oracle fixture's 8 centers (all ‖c‖ < 0.32, inside every κ
    // chart probed here) instead of duplicating the literal. The fixture's
    // data rows are not needed.
    let (_data, centers) = oracle_disk_design_centers();
    let weights = Array1::<f64>::ones(centers.nrows());
    let z = weighted_coefficient_sum_to_zero_transform(weights.view()).unwrap();
    // Frozen auto length scale (the κ=0 chart-scale rule; 0.0 ⇒ auto). It is
    // reused across κ so the full-rank check is on the same resolution the
    // basis uses.
    let ell = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();

    for &kappa in &[-2.0_f64, -0.5, 0.0, 0.5, 2.0] {
        let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
            .unwrap();
        // Primary penalty exactly as the basis builder forms it: the
        // symmetrized gauge-restricted kernel Gram.
        let raw = symmetrize(&z.t().dot(&k).dot(&z));

        // (a) The primary is full-rank PD. Its smallest eigenvalue is strictly
        // positive (well above the spectral tolerance), so there is no null
        // space for a Marra-Wood ridge to shrink.
        let (evals, _v) = FaerEigh::eigh(&raw, faer::Side::Lower).unwrap();
        let max = evals.iter().cloned().fold(0.0_f64, f64::max);
        let min = evals.iter().cloned().fold(f64::INFINITY, f64::min);
        assert!(
            max > 0.0 && min > max * 1e-9,
            "constant-curvature Gram must be full-rank PD at κ={kappa}: \
             min eig {min:e}, max eig {max:e}"
        );

        // (b) Consequently the sibling nullspace-shrinkage builder yields
        // nothing. Copying that pattern would make `double_penalty` a no-op,
        // which confirms the identity ridge is the only selectable double
        // penalty.
        let null_shrink =
            crate::basis::bspline_build::build_nullspace_shrinkage_penalty(&raw).unwrap();
        assert!(
            null_shrink.is_none(),
            "build_nullspace_shrinkage_penalty must return None on the full-rank \
             constant-curvature primary at κ={kappa} (else the double penalty would be \
             a silent no-op and identity would be wrong)"
        );
    }
}
1842}