use ndarray::{Array1, Array2, ArrayView2, Axis};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::geometry::constant_curvature::{ConstantCurvature, distance_kappa_jet};
use super::{
BasisBuildResult, BasisError, BasisMetadata, BasisPsiDerivativeBundle,
BasisPsiDerivativeResult, BasisPsiSecondDerivativeResult, CenterStrategy, PenaltyCandidate,
PenaltyInfo, PenaltySource, filter_active_penalty_candidates_with_ops, normalize_penalty,
normalize_penaltywith_psi_derivatives, select_centers_by_strategy,
weighted_coefficient_sum_to_zero_transform,
};
/// Strategy for obtaining the identifiability (constraint) transform that is applied to the
/// raw center coefficients of a constant-curvature basis.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ConstantCurvatureIdentifiability {
/// Default variant: the builders in this file derive the transform by calling
/// `weighted_coefficient_sum_to_zero_transform` with unit weights, one per center.
#[default]
CenterSumToZero,
/// Reuse a caller-supplied transform verbatim; the builders in this file require its row
/// count to equal the number of selected centers.
FrozenTransform { transform: Array2<f64> },
}
/// User-facing specification of a constant-curvature kernel smooth.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantCurvatureBasisSpec {
/// How kernel centers are chosen from the data (forwarded to `select_centers_by_strategy`).
pub center_strategy: CenterStrategy,
/// Curvature parameter κ of the chart; the builders reject non-finite values.
pub kappa: f64,
/// Kernel length scale: a positive finite value is used as given, `0.0` requests the
/// automatic choice made by `realized_constant_curvature_length_scale`.
pub length_scale: f64,
/// When `true`, the builder adds an identity ridge penalty candidate next to the primary
/// kernel penalty.
pub double_penalty: bool,
/// Source of the identifiability transform; falls back to the enum default when the field
/// is absent during deserialization.
#[serde(default)]
pub identifiability: ConstantCurvatureIdentifiability,
}
impl Default for ConstantCurvatureBasisSpec {
    /// Flat (κ = 0) spec with 50 farthest-point centers, automatic length scale
    /// (`length_scale == 0.0`), double penalty enabled and the default identifiability.
    fn default() -> Self {
        let center_strategy = CenterStrategy::FarthestPoint { num_centers: 50 };
        Self {
            center_strategy,
            kappa: 0.0,
            length_scale: 0.0,
            double_penalty: true,
            // The enum's `#[default]` variant is `CenterSumToZero`.
            identifiability: ConstantCurvatureIdentifiability::default(),
        }
    }
}
/// Checks that every row of `points` is a valid point of the κ-stereographic chart.
///
/// A row is accepted when all coordinates are finite and `1 + κ·‖x‖² > 0`.
/// `what` names the point set ("data", "centers", …) in error messages.
///
/// # Errors
/// Returns an invalid-basis error when `kappa` is not finite, when a coordinate is not
/// finite, or when a row lies outside the chart.
pub(crate) fn validate_chart_points(
    points: ArrayView2<'_, f64>,
    kappa: f64,
    what: &str,
) -> Result<(), BasisError> {
    // A NaN (or infinite) κ would make the chart inequality below compare against NaN and
    // silently accept every row, so reject it before looking at any point.
    if !kappa.is_finite() {
        crate::bail_invalid_basis!(
            "constant-curvature {what} validation needs a finite kappa; got {kappa}"
        );
    }
    for (i, row) in points.outer_iter().enumerate() {
        // Squared Euclidean norm of the row, accumulated while checking finiteness.
        let mut nx2 = 0.0_f64;
        for &v in row.iter() {
            if !v.is_finite() {
                crate::bail_invalid_basis!(
                    "constant-curvature {what} row {i} has a non-finite coordinate"
                );
            }
            nx2 += v * v;
        }
        if 1.0 + kappa * nx2 <= 0.0 {
            crate::bail_invalid_basis!(
                "constant-curvature {what} row {i} lies outside the κ-stereographic chart \
                 (need 1 + κ·‖x‖² > 0; got κ = {kappa}, ‖x‖² = {nx2}); for κ < 0 the chart is \
                 the open ball ‖x‖ < 1/√(−κ)"
            );
        }
    }
    Ok(())
}
/// Builds the `data.nrows() × centers.nrows()` kernel matrix with entries
/// `exp(-d(x_i, c_j) / length_scale)`, where `d` is the distance reported by
/// `ConstantCurvature::distance` for curvature `kappa`.
///
/// Rows are filled in parallel.
///
/// # Errors
/// Fails on a column-count mismatch, a non-positive or non-finite `length_scale`, points
/// that fail `validate_chart_points`, or a failing distance evaluation.
pub fn constant_curvature_kernel_matrix(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    length_scale: f64,
) -> Result<Array2<f64>, BasisError> {
    let dim = data.ncols();
    if dim != centers.ncols() {
        crate::bail_dim_basis!(
            "constant-curvature kernel dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    if !length_scale.is_finite() || length_scale <= 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature kernel needs a positive finite length_scale; got {length_scale}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;

    let geometry = ConstantCurvature::new(dim, kappa);
    let mut kernel = Array2::<f64>::zeros((data.nrows(), centers.nrows()));
    kernel
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .try_for_each(|(i, mut kernel_row)| -> Result<(), BasisError> {
            let point = data.row(i);
            // Each kernel row has exactly one slot per center, so the zip covers all of them.
            for (j, (slot, c)) in kernel_row
                .iter_mut()
                .zip(centers.outer_iter())
                .enumerate()
            {
                let d = geometry.distance(point, c).map_err(|e| {
                    BasisError::InvalidInput(format!(
                        "constant-curvature distance failed at (row {i}, center {j}): {e}"
                    ))
                })?;
                *slot = (-d / length_scale).exp();
            }
            Ok(())
        })?;
    Ok(kernel)
}
/// Kernel values together with their first and second κ-derivatives at a fixed length scale.
///
/// With `(d, d1, d2)` returned by `distance_kappa_jet` (presumably the distance and its first
/// two κ-derivatives) and `g = d1 / length_scale`, each entry is
/// `k = exp(-d / length_scale)`, `k1 = -g·k`, `k2 = (g² - d2 / length_scale)·k`.
///
/// Returns `(value, dk, dkk)`, each of shape `data.nrows() × centers.nrows()`.
///
/// # Errors
/// Fails on a column-count mismatch, a non-positive or non-finite `length_scale`, points
/// that fail `validate_chart_points`, or a failing κ-jet evaluation.
pub fn constant_curvature_kernel_kappa_jets(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    length_scale: f64,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "constant-curvature kernel-jet dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    if !length_scale.is_finite() || length_scale <= 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature kernel jets need a positive finite length_scale; got {length_scale}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;

    let geometry = ConstantCurvature::new(data.ncols(), kappa);
    let (n, m) = (data.nrows(), centers.nrows());
    let mut value = Array2::<f64>::zeros((n, m));
    let mut dk = Array2::<f64>::zeros((n, m));
    let mut dkk = Array2::<f64>::zeros((n, m));

    // Compute every row's jets in parallel, tagged with the row index, then scatter serially.
    let jets_per_row = (0..n)
        .into_par_iter()
        .map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
            let point = data.row(i);
            let jets = centers
                .outer_iter()
                .enumerate()
                .map(|(j, c)| -> Result<(f64, f64, f64), BasisError> {
                    let (d, d1, d2) = distance_kappa_jet(&geometry, point, c).map_err(|e| {
                        BasisError::InvalidInput(format!(
                            "constant-curvature distance κ-jet failed at (row {i}, center {j}): {e}"
                        ))
                    })?;
                    let k = (-d / length_scale).exp();
                    let g = d1 / length_scale;
                    Ok((k, -g * k, (g * g - d2 / length_scale) * k))
                })
                .collect::<Result<Vec<_>, BasisError>>()?;
            Ok((i, jets))
        })
        .collect::<Result<Vec<_>, BasisError>>()?;

    for (i, jets) in jets_per_row {
        for (j, jet) in jets.into_iter().enumerate() {
            value[[i, j]] = jet.0;
            dk[[i, j]] = jet.1;
            dkk[[i, j]] = jet.2;
        }
    }
    Ok((value, dk, dkk))
}
/// Kernel values and κ-derivatives when the length itself depends on κ.
///
/// `l_jet = (l, l1, l2)` is the length together with the two derivative values used in the
/// chain rule below (the builder in this file passes the output of
/// `constant_curvature_effective_length_jet`). With `(d, d1, d2)` from `distance_kappa_jet`
/// and `q = d / l`, the code forms
/// `q1 = d1/l − d·l1/l²` and `q2 = d2/l − 2·d1·l1/l² − d·l2/l² + 2·d·l1²/l³`, then returns
/// `k = exp(−q)`, `k1 = −q1·k`, `k2 = (q1² − q2)·k` per entry.
///
/// # Errors
/// Fails on a column-count mismatch, a non-positive or non-finite `l`, points that fail
/// `validate_chart_points`, or a failing κ-jet evaluation.
pub(crate) fn constant_curvature_kernel_kappa_jets_scaled(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    l_jet: (f64, f64, f64),
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
    if data.ncols() != centers.ncols() {
        crate::bail_dim_basis!(
            "constant-curvature scaled kernel-jet dimension mismatch: data d={} centers d={}",
            data.ncols(),
            centers.ncols()
        );
    }
    let (l, l1, l2) = l_jet;
    if !l.is_finite() || l <= 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature scaled kernel jets need a positive finite effective length; got {l}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;

    let geometry = ConstantCurvature::new(data.ncols(), kappa);
    let (n, m) = (data.nrows(), centers.nrows());
    let mut value = Array2::<f64>::zeros((n, m));
    let mut dk = Array2::<f64>::zeros((n, m));
    let mut dkk = Array2::<f64>::zeros((n, m));

    // Powers of the length shared by every entry.
    let l_sq = l * l;
    let l_cu = l_sq * l;

    let jets_per_row = (0..n)
        .into_par_iter()
        .map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
            let point = data.row(i);
            let jets = centers
                .outer_iter()
                .enumerate()
                .map(|(j, c)| -> Result<(f64, f64, f64), BasisError> {
                    let (d, d1, d2) = distance_kappa_jet(&geometry, point, c).map_err(|e| {
                        BasisError::InvalidInput(format!(
                            "constant-curvature scaled distance κ-jet failed at (row {i}, center {j}): {e}"
                        ))
                    })?;
                    // Exponent q = d / l and its first two κ-derivatives.
                    let q = d / l;
                    let q1 = d1 / l - d * l1 / l_sq;
                    let q2 = d2 / l - 2.0 * d1 * l1 / l_sq - d * l2 / l_sq
                        + 2.0 * d * l1 * l1 / l_cu;
                    let k = (-q).exp();
                    Ok((k, -q1 * k, (q1 * q1 - q2) * k))
                })
                .collect::<Result<Vec<_>, BasisError>>()?;
            Ok((i, jets))
        })
        .collect::<Result<Vec<_>, BasisError>>()?;

    for (i, jets) in jets_per_row {
        for (j, jet) in jets.into_iter().enumerate() {
            value[[i, j]] = jet.0;
            dk[[i, j]] = jet.1;
            dkk[[i, j]] = jet.2;
        }
    }
    Ok((value, dk, dkk))
}
pub fn realized_constant_curvature_length_scale(
centers: ArrayView2<'_, f64>,
spec_length_scale: f64,
) -> Result<f64, BasisError> {
if spec_length_scale.is_finite() && spec_length_scale > 0.0 {
return Ok(spec_length_scale);
}
if spec_length_scale != 0.0 {
crate::bail_invalid_basis!(
"constant-curvature length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
);
}
let m = centers.nrows();
if m < 2 {
return Err(BasisError::InsufficientColumnsForConstraint { found: m });
}
let mut dists: Vec<f64> = Vec::with_capacity(m * (m - 1) / 2);
for i in 0..m {
for j in (i + 1)..m {
let mut s = 0.0_f64;
for k in 0..centers.ncols() {
let dlt = centers[(i, k)] - centers[(j, k)];
s += dlt * dlt;
}
dists.push(2.0 * s.sqrt());
}
}
dists.sort_by(|a, b| a.partial_cmp(b).expect("finite chart distances"));
let median = dists[dists.len() / 2];
if !(median.is_finite() && median > 0.0) {
crate::bail_invalid_basis!(
"constant-curvature auto length_scale failed: centers are degenerate \
(median pairwise chart distance = {median})"
);
}
Ok(median)
}
pub(crate) fn data_center_reference_fill(
data: ArrayView2<'_, f64>,
centers: ArrayView2<'_, f64>,
ell_ref: f64,
) -> Result<f64, BasisError> {
if !(ell_ref.is_finite() && ell_ref > 0.0) {
crate::bail_invalid_basis!(
"constant-curvature reference fill needs a positive finite ℓ_ref; got {ell_ref}"
);
}
let mut sum = 0.0_f64;
let mut cnt = 0.0_f64;
for xi in data.outer_iter() {
for cj in centers.outer_iter() {
let mut s = 0.0_f64;
for k in 0..centers.ncols() {
let dlt = xi[k] - cj[k];
s += dlt * dlt;
}
let d0 = 2.0 * s.sqrt(); sum += (-d0 / ell_ref).exp();
cnt += 1.0;
}
}
if cnt <= 0.0 {
crate::bail_invalid_basis!(
"constant-curvature reference fill needs at least one data row and one center"
);
}
Ok(sum / cnt)
}
/// Mean data→center kernel fill `g = mean exp(−d / l)` together with the pair-averaged
/// partial terms accumulated below.
///
/// With `(d, d1, d2)` from `distance_kappa_jet` and `k = exp(−d / l)`, the returned tuple
/// `(g, g_l, g_k, g_ll, g_kk, g_lk)` averages, in order:
/// `k`, `k·d/l²`, `−k·d1/l`, `k·d·(d − 2l)/l⁴`, `k·(d1²/l − d2)/l`, `k·d1·(l − d)/l³`.
///
/// # Errors
/// Fails when `l` is not positive and finite, when a κ-jet evaluation fails, or when there
/// are no data/center pairs to average over.
pub(crate) fn data_center_fill_partials(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    l: f64,
) -> Result<(f64, f64, f64, f64, f64, f64), BasisError> {
    if !l.is_finite() || l <= 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature fill partials need a positive finite length; got {l}"
        );
    }
    let geometry = ConstantCurvature::new(centers.ncols(), kappa);
    let l_sq = l * l;
    let l_cu = l_sq * l;
    let l_qu = l_sq * l_sq;

    // Running sums in return order: g, g_l, g_k, g_ll, g_kk, g_lk.
    let mut sums = [0.0_f64; 6];
    let mut pairs = 0.0_f64;
    for xi in data.outer_iter() {
        for cj in centers.outer_iter() {
            let (d, d1, d2) = distance_kappa_jet(&geometry, xi, cj).map_err(|e| {
                BasisError::InvalidInput(format!(
                    "constant-curvature data→center fill κ-jet failed: {e}"
                ))
            })?;
            let k = (-d / l).exp();
            sums[0] += k;
            sums[1] += k * d / l_sq;
            sums[2] += -k * d1 / l;
            sums[3] += k * d * (d - 2.0 * l) / l_qu;
            sums[4] += k * ((d1 * d1) / l - d2) / l;
            sums[5] += k * d1 * (l - d) / l_cu;
            pairs += 1.0;
        }
    }
    if pairs <= 0.0 {
        crate::bail_invalid_basis!(
            "constant-curvature fill partials need at least one data row and one center"
        );
    }
    let [g, g_l, g_k, g_ll, g_kk, g_lk] = sums.map(|total| total / pairs);
    Ok((g, g_l, g_k, g_ll, g_kk, g_lk))
}
/// Solves for the length `L` at which the data→center fill at curvature `kappa` equals the
/// reference fill computed from `ell_ref`, and returns `(L, l1, l2)`.
///
/// `L` is found by an undamped Newton iteration started at `ell_ref` (at most 100 steps,
/// relative step tolerance 1e-13). The two extra values come from the fill partials at the
/// solution: `l1 = −g_k / g_l` and `l2 = −(g_ll·l1² + 2·g_lk·l1 + g_kk) / g_l`, i.e. the
/// implicit-function derivatives of `L` with respect to κ.
///
/// # Errors
/// Fails when the reference fill or the partials fail, when the fill slope `g_l` is not
/// positive and finite, when an iterate leaves the positive axis, or when the iteration
/// does not converge.
pub(crate) fn constant_curvature_effective_length_jet(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    ell_ref: f64,
    kappa: f64,
) -> Result<(f64, f64, f64), BasisError> {
    const NEWTON_MAX_ITER: usize = 100;
    const NEWTON_REL_TOL: f64 = 1.0e-13;

    let fill_star = data_center_reference_fill(data, centers, ell_ref)?;
    let mut l = ell_ref;
    let mut iterations_left = NEWTON_MAX_ITER;
    loop {
        if iterations_left == 0 {
            crate::bail_invalid_basis!(
                "constant-curvature effective length: fill-target Newton did not converge at κ = {kappa}"
            );
        }
        iterations_left -= 1;

        let (g, g_l, ..) = data_center_fill_partials(data, centers, kappa, l)?;
        if !g_l.is_finite() || g_l <= 0.0 {
            crate::bail_invalid_basis!(
                "constant-curvature effective length: non-positive fill slope g_L = {g_l} \
                 (degenerate data/centers at κ = {kappa})"
            );
        }
        let step = (g - fill_star) / g_l;
        l -= step;
        if !l.is_finite() || l <= 0.0 {
            crate::bail_invalid_basis!(
                "constant-curvature effective length: Newton left the positive axis (L = {l}) \
                 solving the fill target at κ = {kappa}"
            );
        }
        // The tolerance is relative to the updated iterate.
        if step.abs() <= NEWTON_REL_TOL * l {
            break;
        }
    }

    // Differentiate the fill constraint implicitly at the converged length.
    let (_, g_l, g_k, g_ll, g_kk, g_lk) = data_center_fill_partials(data, centers, kappa, l)?;
    let l1 = -g_k / g_l;
    let second_order = g_ll * l1 * l1 + 2.0 * g_lk * l1 + g_kk;
    let l2 = -second_order / g_l;
    Ok((l, l1, l2))
}
pub fn build_constant_curvature_basis(
data: ArrayView2<'_, f64>,
spec: &ConstantCurvatureBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
if data.ncols() == 0 {
crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
}
if !spec.kappa.is_finite() {
crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
}
validate_chart_points(data, spec.kappa, "data")?;
let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
if centers.nrows() < 2 {
return Err(BasisError::InsufficientColumnsForConstraint {
found: centers.nrows(),
});
}
validate_chart_points(centers.view(), spec.kappa, "centers")?;
let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
let (ell_eff, _, _) =
constant_curvature_effective_length_jet(data, centers.view(), length_scale, spec.kappa)?;
let raw_penalty =
constant_curvature_kernel_matrix(centers.view(), centers.view(), spec.kappa, ell_eff)?;
let z = match &spec.identifiability {
ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
if transform.nrows() != centers.nrows() {
crate::bail_dim_basis!(
"frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
centers.nrows(),
transform.nrows()
);
}
transform.clone()
}
ConstantCurvatureIdentifiability::CenterSumToZero => {
let weights = Array1::<f64>::ones(centers.nrows());
weighted_coefficient_sum_to_zero_transform(weights.view())?
}
};
let gauge = crate::solver::gauge::Gauge::from_block_transforms(&[z.clone()]);
let penalty = gauge.restrict_penalty(&raw_penalty);
let raw_design = constant_curvature_kernel_matrix(data, centers.view(), spec.kappa, ell_eff)?;
let design = crate::matrix::DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
gauge.restrict_design(&raw_design),
));
let (penalty_norm, c_primary) = normalize_penalty(&((&penalty + &penalty.t()) * 0.5));
let mut candidates = vec![PenaltyCandidate {
matrix: penalty_norm,
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: c_primary,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
let ridge = Array2::<f64>::eye(design.ncols());
let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
candidates.push(PenaltyCandidate {
matrix: ridge_norm,
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: c_ridge,
kronecker_factors: None,
op: None,
});
}
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(candidates)?;
Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::ConstantCurvature {
centers,
kappa: spec.kappa,
length_scale,
constraint_transform: Some(z),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}
/// Returns the symmetric part `(m + mᵀ) / 2` of a square matrix.
pub(crate) fn symmetrize(m: &Array2<f64>) -> Array2<f64> {
    let mut averaged = m + &m.t();
    averaged *= 0.5;
    averaged
}
/// Produces one κ-derivative matrix per *active* penalty, in `penaltyinfo` order.
///
/// The primary penalty receives a copy of `primary_derivative`; the double-penalty ridge
/// receives a zero matrix of the same shape. Any other penalty source is an error.
pub(crate) fn active_constant_curvature_penalty_derivatives(
    penaltyinfo: &[PenaltyInfo],
    primary_derivative: &Array2<f64>,
) -> Result<Vec<Array2<f64>>, BasisError> {
    let mut derivatives = Vec::new();
    for info in penaltyinfo {
        if !info.active {
            continue;
        }
        let derivative = match &info.source {
            PenaltySource::Primary => primary_derivative.clone(),
            PenaltySource::DoublePenaltyNullspace => {
                Array2::<f64>::zeros(primary_derivative.raw_dim())
            }
            other => {
                return Err(BasisError::InvalidInput(format!(
                    "unexpected constant-curvature penalty source in κ-derivative path: {other:?}"
                )));
            }
        };
        derivatives.push(derivative);
    }
    Ok(derivatives)
}
pub fn build_constant_curvature_basis_kappa_derivatives(
data: ArrayView2<'_, f64>,
spec: &ConstantCurvatureBasisSpec,
) -> Result<BasisPsiDerivativeBundle, BasisError> {
if data.ncols() == 0 {
crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
}
if !spec.kappa.is_finite() {
crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
}
validate_chart_points(data, spec.kappa, "data")?;
let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
if centers.nrows() < 2 {
return Err(BasisError::InsufficientColumnsForConstraint {
found: centers.nrows(),
});
}
validate_chart_points(centers.view(), spec.kappa, "centers")?;
let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
let z = match &spec.identifiability {
ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
if transform.nrows() != centers.nrows() {
crate::bail_dim_basis!(
"frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
centers.nrows(),
transform.nrows()
);
}
transform.clone()
}
ConstantCurvatureIdentifiability::CenterSumToZero => {
let weights = Array1::<f64>::ones(centers.nrows());
weighted_coefficient_sum_to_zero_transform(weights.view())?
}
};
let gauge = crate::solver::gauge::Gauge::from_block_transforms(&[z.clone()]);
let l_jet =
constant_curvature_effective_length_jet(data, centers.view(), length_scale, spec.kappa)?;
let (_k_dc, dk_dc, dkk_dc) =
constant_curvature_kernel_kappa_jets_scaled(data, centers.view(), spec.kappa, l_jet)?;
let design_first = gauge.restrict_design(&dk_dc);
let design_second_diag = gauge.restrict_design(&dkk_dc);
let (k_cc, dk_cc, dkk_cc) = constant_curvature_kernel_kappa_jets_scaled(
centers.view(),
centers.view(),
spec.kappa,
l_jet,
)?;
let s_raw = symmetrize(&gauge.restrict_penalty(&k_cc));
let s_raw_first = symmetrize(&gauge.restrict_penalty(&dk_cc));
let s_raw_second = symmetrize(&gauge.restrict_penalty(&dkk_cc));
let (_s_norm, s_norm_first, s_norm_second, _c) =
normalize_penaltywith_psi_derivatives(&s_raw, &s_raw_first, &s_raw_second);
let base = build_constant_curvature_basis(data, spec)?;
let penalties_derivative =
active_constant_curvature_penalty_derivatives(&base.penaltyinfo, &s_norm_first)?;
let penaltiessecond_derivative =
active_constant_curvature_penalty_derivatives(&base.penaltyinfo, &s_norm_second)?;
Ok(BasisPsiDerivativeBundle {
first: BasisPsiDerivativeResult {
design_derivative: design_first,
penalties_derivative,
implicit_operator: None,
},
second: BasisPsiSecondDerivativeResult {
designsecond_derivative: design_second_diag,
penaltiessecond_derivative,
implicit_operator: None,
},
implicit_operator: None,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::faer_ndarray::FaerEigh;
/// Diagnostic test on eight fixed centers: with the length scale frozen at the automatic
/// (κ-independent) value, the kernel "spread" must decrease as κ goes from −2 through 0 to +2.
/// Only that ordering is asserted; every other quantity below is printed to stderr.
#[test]
pub(crate) fn kernel_spread_collapses_with_kappa_at_frozen_length_scale() {
let centers = ndarray::array![
[0.10, 0.05],
[-0.20, 0.15],
[0.30, -0.10],
[-0.05, -0.25],
[0.22, 0.20],
[-0.30, -0.05],
[0.05, 0.30],
[-0.15, 0.10],
];
// Automatic length scale (spec value 0.0), computed once and reused for every κ.
let ell_frozen = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
// spread(κ, ℓ) = 1 − mean of the off-diagonal center–center kernel entries.
let spread = |kappa: f64, ell: f64| -> f64 {
let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
.unwrap();
let m = k.nrows();
let mut s = 0.0;
let mut cnt = 0.0;
for i in 0..m {
for j in 0..m {
if i != j {
s += k[(i, j)];
cnt += 1.0;
}
}
}
1.0 - s / cnt
};
let s_neg = spread(-2.0, ell_frozen);
let s_zero = spread(0.0, ell_frozen);
let s_pos = spread(2.0, ell_frozen);
eprintln!(
"[κ-collapse] frozen ℓ={ell_frozen:.4}: spread κ=-2 {s_neg:.4} | κ=0 {s_zero:.4} | κ=+2 {s_pos:.4}"
);
// The only assertion of this test: spread is strictly decreasing in κ at frozen ℓ.
assert!(
s_pos < s_zero && s_zero < s_neg,
"expected kernel spread to shrink with κ at frozen ℓ: κ=-2 {s_neg} κ=0 {s_zero} κ=+2 {s_pos}"
);
let weights = Array1::<f64>::ones(centers.nrows());
let z = weighted_coefficient_sum_to_zero_transform(weights.view()).unwrap();
// Sum of log-eigenvalues (above a relative 1e-9 cutoff) of the normalized, constrained
// center–center penalty.
let logdet_norm_penalty = |kappa: f64, ell: f64| -> f64 {
let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
.unwrap();
let s_raw = symmetrize(&z.t().dot(&k).dot(&z));
let (s_norm, _c) = normalize_penalty(&s_raw);
let sym = symmetrize(&s_norm);
let (evals, _v) = FaerEigh::eigh(&sym, faer::Side::Lower).unwrap();
let max = evals.iter().cloned().fold(0.0_f64, f64::max);
let tol = max * 1e-9;
evals
.iter()
.filter(|&&e| e > tol)
.map(|&e| e.ln())
.sum::<f64>()
};
let l_neg = logdet_norm_penalty(-2.0, ell_frozen);
let l_zero = logdet_norm_penalty(0.0, ell_frozen);
let l_pos = logdet_norm_penalty(2.0, ell_frozen);
eprintln!(
"[κ-collapse] log|S~|_+ (frozen ℓ): κ=-2 {l_neg:.4} | κ=0 {l_zero:.4} | κ=+2 {l_pos:.4}"
);
// Alternative, κ-dependent length scale: element at index len/2 of the sorted pairwise
// distances reported by `ConstantCurvature::distance`.
let geo_median_ell = |kappa: f64| -> f64 {
let m = centers.nrows();
let manifold = ConstantCurvature::new(centers.ncols(), kappa);
let mut dists = Vec::with_capacity(m * (m - 1) / 2);
for i in 0..m {
for j in (i + 1)..m {
dists.push(manifold.distance(centers.row(i), centers.row(j)).unwrap());
}
}
dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
dists[dists.len() / 2]
};
let gs_neg = spread(-2.0, geo_median_ell(-2.0));
let gs_zero = spread(0.0, geo_median_ell(0.0));
let gs_pos = spread(2.0, geo_median_ell(2.0));
let gl_neg = logdet_norm_penalty(-2.0, geo_median_ell(-2.0));
let gl_zero = logdet_norm_penalty(0.0, geo_median_ell(0.0));
let gl_pos = logdet_norm_penalty(2.0, geo_median_ell(2.0));
eprintln!(
"[κ-collapse] geodesic ℓ: spread κ=-2 {gs_neg:.4} | κ=0 {gs_zero:.4} | κ=+2 {gs_pos:.4}"
);
eprintln!(
"[κ-collapse] geodesic ℓ: log|S~|_+ κ=-2 {gl_neg:.4} | κ=0 {gl_zero:.4} | κ=+2 {gl_pos:.4}"
);
// Same log-pseudo-determinant, but the raw constrained penalty is divided by one fixed
// constant c0 instead of being re-normalized at every κ.
let logdet_raw = |kappa: f64, ell: f64, c0: f64| -> f64 {
let k = constant_curvature_kernel_matrix(centers.view(), centers.view(), kappa, ell)
.unwrap();
let s_raw = symmetrize(&z.t().dot(&k).dot(&z));
let scaled = s_raw.mapv(|v| v / c0);
let (evals, _v) = FaerEigh::eigh(&scaled, faer::Side::Lower).unwrap();
let max = evals.iter().cloned().fold(0.0_f64, f64::max);
let tol = max * 1e-9;
evals
.iter()
.filter(|&&e| e > tol)
.map(|&e| e.ln())
.sum::<f64>()
};
// c0 is the Frobenius norm of the raw constrained penalty at κ = 0 and the frozen ℓ.
let k0 = constant_curvature_kernel_matrix(centers.view(), centers.view(), 0.0, ell_frozen)
.unwrap();
let s_raw0 = symmetrize(&z.t().dot(&k0).dot(&z));
let c0 = s_raw0.iter().map(|v| v * v).sum::<f64>().sqrt();
let r_neg = logdet_raw(-2.0, ell_frozen, c0);
let r_zero = logdet_raw(0.0, ell_frozen, c0);
let r_pos = logdet_raw(2.0, ell_frozen, c0);
eprintln!(
"[κ-collapse] frozen-c₀ log|S_raw/c₀|_+ (frozen ℓ): κ=-2 {r_neg:.4} | κ=0 {r_zero:.4} | κ=+2 {r_pos:.4}"
);
// Print the frozen-c0 quantity on a finer κ grid, all on one stderr line.
eprint!("[κ-collapse] frozen-c₀ grid:");
for kk in [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0] {
eprint!(" κ={kk}:{:.4}", logdet_raw(kk, ell_frozen, c0));
}
eprintln!();
}
/// Test helper: REML-style Gaussian deviance for design `b`, response `y` and penalty `s`,
/// minimized over the grid `λ = exp(k / 2)` for `k = −24..=24`.
///
/// For each λ the penalized normal equations `(bᵀb + λ·s + ridge)·β = bᵀy` are solved via an
/// eigendecomposition, and the value
/// `dof·ln(rss / dof) + log|H| − (r·ln λ + log|s|₊)` is formed, where `r` is the number of
/// eigenvalues of `s` above a relative 1e-9 cutoff, `log|s|₊` sums their logs, and
/// `dof = n − (p − r)`. The smallest value over the grid is returned.
///
/// Panics if an eigendecomposition fails.
pub(crate) fn profiled_gaussian_reml_deviance(
    b: &Array2<f64>,
    y: &Array1<f64>,
    s: &Array2<f64>,
) -> f64 {
    let n = b.nrows();
    let p = b.ncols();
    let btb = symmetrize(&b.t().dot(b));
    let bty = b.t().dot(y);

    // Spectrum of the penalty: rank, log-pseudo-determinant and scale do not depend on λ,
    // so they are computed once outside the grid loop.
    let (s_evals, _sv) = FaerEigh::eigh(&symmetrize(s), faer::Side::Lower).unwrap();
    let s_max = s_evals.iter().cloned().fold(0.0_f64, f64::max).max(1e-300);
    let s_tol = s_max * 1e-9;
    let r = s_evals.iter().filter(|&&e| e > s_tol).count();
    let m_p = p - r;
    let dof = (n - m_p) as f64;
    let log_det_s_plus: f64 = s_evals
        .iter()
        .filter(|&&e| e > s_tol)
        .map(|&e| e.ln())
        .sum();
    // Tiny ridge that keeps the penalized Hessian numerically positive.
    let ridge = Array2::<f64>::eye(p) * (1e-10 * s_max.max(1.0));

    let mut best = f64::INFINITY;
    for k in -24i32..=24 {
        let lam = (0.5 * f64::from(k)).exp();
        let h = &btb + &(s.mapv(|v| v * lam));
        let h = symmetrize(&h);
        let h_ridge = &h + &ridge;
        let (hv, hq) = FaerEigh::eigh(&symmetrize(&h_ridge), faer::Side::Lower).unwrap();

        // Solve H·β = bᵀy in the eigenbasis while accumulating log|H|.
        let qty = hq.t().dot(&bty);
        let mut beta = Array1::<f64>::zeros(p);
        let mut log_det_h = 0.0_f64;
        for i in 0..p {
            let ev = hv[i].max(1e-300);
            log_det_h += ev.ln();
            let coef = qty[i] / ev;
            for j in 0..p {
                beta[j] += hq[(j, i)] * coef;
            }
        }

        let resid = y - &b.dot(&beta);
        let rss = resid.dot(&resid).max(1e-300);
        let log_det_lam_s = (r as f64) * lam.ln() + log_det_s_plus;
        let dev = dof * (rss / dof).ln() + log_det_h - log_det_lam_s;
        if dev < best {
            best = dev;
        }
    }
    best
}
/// Test helper: constrained design `K(data, centers)·Z` and symmetrized, normalized
/// constrained penalty `Zᵀ·K(centers, centers)·Z` at curvature `kappa`.
///
/// `Z` is the unit-weight sum-to-zero transform. The kernel length is `ell_ref` itself when
/// `frozen_length` is set, otherwise the effective length solved for this `kappa`.
pub(crate) fn oracle_design_and_penalty(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    ell_ref: f64,
    kappa: f64,
    frozen_length: bool,
) -> (Array2<f64>, Array2<f64>) {
    let unit_weights = Array1::<f64>::ones(centers.nrows());
    let z = weighted_coefficient_sum_to_zero_transform(unit_weights.view()).unwrap();

    let ell = match frozen_length {
        true => ell_ref,
        false => {
            let (effective, _, _) =
                constant_curvature_effective_length_jet(data, centers, ell_ref, kappa).unwrap();
            effective
        }
    };

    let design = constant_curvature_kernel_matrix(data, centers, kappa, ell)
        .unwrap()
        .dot(&z);
    let center_kernel = constant_curvature_kernel_matrix(centers, centers, kappa, ell).unwrap();
    let constrained = symmetrize(&z.t().dot(&center_kernel).dot(&z));
    let (normalized, _scale) = normalize_penalty(&constrained);
    (design, symmetrize(&normalized))
}
/// The λ-profiled deviance must not change when the penalty is multiplied by a constant,
/// checked at several curvatures and scale factors.
#[test]
pub(crate) fn profiled_reml_is_invariant_to_penalty_frobenius_scale() {
    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    let y = oracle_response(data.view(), centers.view(), ell_ref, -1.0, 7);

    for kappa in [-1.5_f64, -0.5, 0.0, 0.8, 1.5] {
        let (b, s) = oracle_design_and_penalty(data.view(), centers.view(), ell_ref, kappa, false);
        let v0 = profiled_gaussian_reml_deviance(&b, &y, &s);
        for alpha in [1e-3_f64, 37.0, 1e4] {
            let rescaled = s.mapv(|entry| entry * alpha);
            let va = profiled_gaussian_reml_deviance(&b, &y, &rescaled);
            let gap = (v0 - va).abs();
            assert!(
                gap <= 1e-7 * (1.0 + v0.abs()),
                "profiled REML must be invariant to penalty scale α={alpha} at κ={kappa}: \
                 V(S)={v0} vs V(αS)={va} — the Frobenius normalization is NOT gauge, \
                 so confound #2 (−r·log‖S_raw‖_F) WOULD be real"
            );
        }
    }
}
/// Grid-searches κ over [−3, 3] (step 0.1) for the smallest profiled deviance.
///
/// With the κ-dependent effective length the minimizer must have the sign of the curvature
/// that generated the response; with the length frozen at `ell_ref` the hyperbolic case is
/// expected to land on the last grid point.
#[test]
pub(crate) fn profiled_reml_identifies_curvature_sign_with_effective_length() {
    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    let grid: Vec<f64> = (-30_i32..=30).map(|i| f64::from(i) * 0.1).collect();

    // Returns (argmin κ, min deviance); ties keep the earliest grid point.
    let argmin_sign = |kappa_true: f64, frozen: bool| -> (f64, f64) {
        let y = oracle_response(data.view(), centers.view(), ell_ref, kappa_true, 11);
        let mut best = (f64::NAN, f64::INFINITY);
        for &kappa in grid.iter() {
            let (b, s) =
                oracle_design_and_penalty(data.view(), centers.view(), ell_ref, kappa, frozen);
            let v = profiled_gaussian_reml_deviance(&b, &y, &s);
            if v < best.1 {
                best = (kappa, v);
            }
        }
        best
    };

    let k_hyp = argmin_sign(-2.0, false).0;
    eprintln!("[κ-ident] L(κ): hyperbolic truth κ⋆=−2 → κ̂={k_hyp:.2}");
    assert!(
        k_hyp < 0.0,
        "L(κ) profiled REML must identify NEGATIVE curvature for hyperbolic truth; got κ̂={k_hyp}"
    );

    let k_sph = argmin_sign(2.0, false).0;
    eprintln!("[κ-ident] L(κ): spherical truth κ⋆=+2 → κ̂={k_sph:.2}");
    assert!(
        k_sph > 0.0,
        "L(κ) profiled REML must identify POSITIVE curvature for spherical truth; got κ̂={k_sph}"
    );

    // Contrast case: frozen length is expected to push the estimate past the second-to-last
    // grid point, i.e. onto the upper bound.
    let k_frozen_hyp = argmin_sign(-2.0, true).0;
    eprintln!("[κ-ident] frozen ℓ: hyperbolic truth κ⋆=−2 → κ̂={k_frozen_hyp:.2} (railing bug)");
    let second_to_last = grid[grid.len() - 2];
    assert!(
        k_frozen_hyp > second_to_last,
        "frozen-ℓ criterion is expected to RAIL hyperbolic truth to the +bound (the bug \
         L(κ) fixes); if it no longer rails, the frozen-vs-scaled contrast is stale: κ̂={k_frozen_hyp}"
    );
}
/// Checks the effective-length jet at several curvatures: the solved length reproduces the
/// reference fill, `L(0)` equals `ell_ref`, and the analytic first/second κ-derivatives
/// agree with central finite differences of the solved length.
#[test]
pub(crate) fn effective_length_jet_matches_fd_of_implicit_solution() {
    let (data, centers) = oracle_disk_design_centers();
    let ell_ref = realized_constant_curvature_length_scale(centers.view(), 0.0).unwrap();
    let fill_star = data_center_reference_fill(data.view(), centers.view(), ell_ref).unwrap();

    let jet_at = |kappa: f64| -> (f64, f64, f64) {
        constant_curvature_effective_length_jet(data.view(), centers.view(), ell_ref, kappa)
            .unwrap()
    };
    let h = 1e-5_f64;

    for kappa in [-1.5_f64, -0.5, -1e-7, 0.0, 1e-7, 0.8, 1.7] {
        let (l, l1, l2) = jet_at(kappa);

        // The solved length must hold the fill at its reference value.
        let (g, ..) = data_center_fill_partials(data.view(), centers.view(), kappa, l).unwrap();
        let fill_gap = (g - fill_star).abs();
        assert!(
            fill_gap <= 1e-10 * (1.0 + fill_star.abs()),
            "κ={kappa}: fill not held invariant: g(L,κ)={g} vs fill⋆={fill_star}"
        );
        if kappa == 0.0 {
            assert!(
                (l - ell_ref).abs() <= 1e-10 * ell_ref,
                "L(0) must equal ℓ_ref; got {l} vs {ell_ref}"
            );
        }

        // Central differences of the solved length around κ.
        let lp = jet_at(kappa + h).0;
        let lm = jet_at(kappa - h).0;
        let fd1 = (lp - lm) / (2.0 * h);
        let fd2 = (lp - 2.0 * l + lm) / (h * h);
        assert!(
            (l1 - fd1).abs() <= 1e-5 * (1.0 + fd1.abs()),
            "κ={kappa}: L′ analytic {l1} vs FD {fd1}"
        );
        assert!(
            (l2 - fd2).abs() <= 1e-3 * (1.0 + fd2.abs()),
            "κ={kappa}: L″ analytic {l2} vs FD {fd2}"
        );
    }
}
/// Test fixture: 60 deterministic pseudo-random 2-D data points with coordinates in
/// `[-0.42, 0.42)` and eight fixed centers. Returns `(data, centers)`.
pub(crate) fn oracle_disk_design_centers() -> (Array2<f64>, Array2<f64>) {
    const NUM_POINTS: usize = 60;

    // Xorshift-style generator with a fixed seed; each draw maps the top 53 bits to
    // [-0.5, 0.5) and scales by 0.84.
    let mut state = 0x2545_f491_4f6c_dd1d_u64;
    let mut draw = || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        ((state >> 11) as f64 / (1u64 << 53) as f64 - 0.5) * 0.84
    };

    // Fill row by row, first coordinate before second, so the draw order is fixed.
    let mut data = Array2::<f64>::zeros((NUM_POINTS, 2));
    for mut point in data.outer_iter_mut() {
        point[0] = draw();
        point[1] = draw();
    }

    let centers = ndarray::array![
        [0.10, 0.05],
        [-0.20, 0.15],
        [0.30, -0.10],
        [-0.05, -0.25],
        [0.22, 0.20],
        [-0.30, -0.05],
        [0.05, 0.30],
        [-0.15, 0.10],
    ];
    (data, centers)
}
pub(crate) fn oracle_response(
data: ArrayView2<'_, f64>,
centers: ArrayView2<'_, f64>,
ell_ref: f64,
kappa_true: f64,
seed: u64,
) -> Array1<f64> {
let (b, _s) = oracle_design_and_penalty(data, centers, ell_ref, kappa_true, false);
let p = b.ncols();
let mut state = 0x9e37_79b9_7f4a_7c15_u64 ^ seed.wrapping_mul(0x1000_0000_1b3);
let mut next = || {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
(state >> 11) as f64 / (1u64 << 53) as f64 - 0.5
};
let beta: Array1<f64> = (0..p).map(|_| next() * 2.0).collect();
let mut y = b.dot(&beta);
for v in y.iter_mut() {
*v += next() * 0.05;
}
y
}
}