use ndarray::{Array1, Array2, ArrayView2, Axis};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::geometry::constant_curvature::{ConstantCurvature, distance_kappa_jet};
use super::{
BasisBuildResult, BasisError, BasisMetadata, CenterStrategy, PenaltyCandidate, PenaltySource,
filter_active_penalty_candidates_with_ops, normalize_penalty, select_centers_by_strategy,
weighted_coefficient_sum_to_zero_transform,
};
/// Identifiability reparameterization applied to the center-indexed kernel
/// coefficients of a constant-curvature smooth.
///
/// The resolved matrix `Z` (one row per selected center) is applied to both the
/// raw design, `K(data, centers) · Z`, and the raw penalty,
/// `Zᵀ · K(centers, centers) · Z`, when the basis is built.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ConstantCurvatureIdentifiability {
/// Default: use the transform returned by `weighted_coefficient_sum_to_zero_transform`
/// with unit weights (one weight per selected center).
#[default]
CenterSumToZero,
/// Reuse a previously computed transform verbatim. Its row count must equal the
/// number of selected centers, otherwise the build fails with a dimension error.
FrozenTransform { transform: Array2<f64> },
}
/// User-facing specification of a constant-curvature radial-kernel smooth with
/// kernel `exp(-d / length_scale)` on the κ-stereographic chart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantCurvatureBasisSpec {
/// How kernel centers are chosen from the data (passed to `select_centers_by_strategy`).
pub center_strategy: CenterStrategy,
/// Curvature κ of the chart. Must be finite; for κ < 0 every point must lie in
/// the open ball ‖x‖ < 1/√(−κ) (i.e. satisfy 1 + κ·‖x‖² > 0).
pub kappa: f64,
/// Kernel length scale. `0.0` requests an automatic choice from the pairwise
/// center distances; any other value must be positive and finite.
pub length_scale: f64,
/// When true, an identity (ridge) penalty is added next to the kernel penalty.
pub double_penalty: bool,
/// Identifiability reparameterization. `#[serde(default)]` makes it fall back to
/// `CenterSumToZero` (the enum's `#[default]` variant) when absent from input.
#[serde(default)]
pub identifiability: ConstantCurvatureIdentifiability,
}
impl Default for ConstantCurvatureBasisSpec {
    /// Flat spec (`kappa = 0.0`) with 50 farthest-point centers, an automatically
    /// chosen length scale (`0.0`), the ridge double penalty switched on, and the
    /// default (`CenterSumToZero`) identifiability constraint.
    fn default() -> Self {
        let center_strategy = CenterStrategy::FarthestPoint { num_centers: 50 };
        Self {
            kappa: 0.0,
            length_scale: 0.0,
            double_penalty: true,
            identifiability: ConstantCurvatureIdentifiability::default(),
            center_strategy,
        }
    }
}
/// Checks that every row of `points` is a usable point of the κ-stereographic chart.
///
/// A row is accepted when all of its coordinates are finite and
/// `1 + κ·‖x‖² > 0`. For κ < 0 this is the open ball ‖x‖ < 1/√(−κ); for κ ≥ 0
/// every finite row passes. `what` names the point set (e.g. `"data"`,
/// `"centers"`) in error messages.
///
/// `kappa` itself must be finite. With κ = NaN (or κ = ∞ at the origin) the chart
/// inequality evaluates to `false` for every row and would silently accept all
/// points. The public kernel functions rely on this check and do not validate κ
/// themselves, so it is rejected here.
///
/// # Errors
/// Returns an invalid-basis error for a non-finite `kappa`, a non-finite
/// coordinate, or a row outside the chart.
fn validate_chart_points(
    points: ArrayView2<'_, f64>,
    kappa: f64,
    what: &str,
) -> Result<(), BasisError> {
    if !kappa.is_finite() {
        crate::bail_invalid_basis!(
            "constant-curvature {what} validation needs a finite kappa; got {kappa}"
        );
    }
    for (i, row) in points.outer_iter().enumerate() {
        let mut nx2 = 0.0_f64;
        for &v in row.iter() {
            if !v.is_finite() {
                crate::bail_invalid_basis!(
                    "constant-curvature {what} row {i} has a non-finite coordinate"
                );
            }
            nx2 += v * v;
        }
        if 1.0 + kappa * nx2 <= 0.0 {
            crate::bail_invalid_basis!(
                "constant-curvature {what} row {i} lies outside the κ-stereographic chart \
                 (need 1 + κ·‖x‖² > 0; got κ = {kappa}, ‖x‖² = {nx2}); for κ < 0 the chart is \
                 the open ball ‖x‖ < 1/√(−κ)"
            );
        }
    }
    Ok(())
}
/// Dense exponential kernel between the rows of `data` and the rows of `centers`
/// on the constant-curvature manifold with curvature `kappa`.
///
/// Entry `(i, j)` equals `exp(-d(data_i, center_j) / length_scale)`, where `d`
/// is the distance returned by `ConstantCurvature::distance`. Rows of the output
/// are filled in parallel.
///
/// # Errors
/// * `data` and `centers` have different column counts;
/// * `length_scale` is not positive and finite;
/// * a point of either set fails chart validation;
/// * a distance evaluation fails (reported with its row / center index).
pub fn constant_curvature_kernel_matrix(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kappa: f64,
    length_scale: f64,
) -> Result<Array2<f64>, BasisError> {
    let (dim_data, dim_centers) = (data.ncols(), centers.ncols());
    if dim_data != dim_centers {
        crate::bail_dim_basis!(
            "constant-curvature kernel dimension mismatch: data d={} centers d={}",
            dim_data,
            dim_centers
        );
    }
    let scale_ok = length_scale.is_finite() && length_scale > 0.0;
    if !scale_ok {
        crate::bail_invalid_basis!(
            "constant-curvature kernel needs a positive finite length_scale; got {length_scale}"
        );
    }
    validate_chart_points(data, kappa, "data")?;
    validate_chart_points(centers, kappa, "centers")?;
    let manifold = ConstantCurvature::new(dim_data, kappa);
    let mut kernel = Array2::<f64>::zeros((data.nrows(), centers.nrows()));
    kernel
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .try_for_each(|(i, mut out_row)| -> Result<(), BasisError> {
            // Pair each output slot of row `i` with its center, in column order.
            for (j, (slot, c)) in out_row.iter_mut().zip(centers.outer_iter()).enumerate() {
                let dist = manifold.distance(data.row(i), c).map_err(|e| {
                    BasisError::InvalidInput(format!(
                        "constant-curvature distance failed at (row {i}, center {j}): {e}"
                    ))
                })?;
                *slot = (-dist / length_scale).exp();
            }
            Ok(())
        })?;
    Ok(kernel)
}
pub fn constant_curvature_kernel_kappa_jets(
data: ArrayView2<'_, f64>,
centers: ArrayView2<'_, f64>,
kappa: f64,
length_scale: f64,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
if data.ncols() != centers.ncols() {
crate::bail_dim_basis!(
"constant-curvature kernel-jet dimension mismatch: data d={} centers d={}",
data.ncols(),
centers.ncols()
);
}
if !(length_scale.is_finite() && length_scale > 0.0) {
crate::bail_invalid_basis!(
"constant-curvature kernel jets need a positive finite length_scale; got {length_scale}"
);
}
validate_chart_points(data, kappa, "data")?;
validate_chart_points(centers, kappa, "centers")?;
let manifold = ConstantCurvature::new(data.ncols(), kappa);
let n = data.nrows();
let m = centers.nrows();
let mut value = Array2::<f64>::zeros((n, m));
let mut dk = Array2::<f64>::zeros((n, m));
let mut dkk = Array2::<f64>::zeros((n, m));
let rows: Vec<(usize, Vec<(f64, f64, f64)>)> = (0..n)
.into_par_iter()
.map(|i| -> Result<(usize, Vec<(f64, f64, f64)>), BasisError> {
let mut row = Vec::with_capacity(m);
for (j, c) in centers.outer_iter().enumerate() {
let (d, d1, d2) = distance_kappa_jet(&manifold, data.row(i), c).map_err(|e| {
BasisError::InvalidInput(format!(
"constant-curvature distance κ-jet failed at (row {i}, center {j}): {e}"
))
})?;
let k = (-d / length_scale).exp();
let g = d1 / length_scale;
row.push((k, -g * k, (g * g - d2 / length_scale) * k));
}
Ok((i, row))
})
.collect::<Result<Vec<_>, BasisError>>()?;
for (i, row) in rows {
for (j, (k, k1, k2)) in row.into_iter().enumerate() {
value[(i, j)] = k;
dk[(i, j)] = k1;
dkk[(i, j)] = k2;
}
}
Ok((value, dk, dkk))
}
pub fn realized_constant_curvature_length_scale(
centers: ArrayView2<'_, f64>,
spec_length_scale: f64,
) -> Result<f64, BasisError> {
if spec_length_scale.is_finite() && spec_length_scale > 0.0 {
return Ok(spec_length_scale);
}
if spec_length_scale != 0.0 {
crate::bail_invalid_basis!(
"constant-curvature length_scale must be positive (or 0.0 for auto); got {spec_length_scale}"
);
}
let m = centers.nrows();
if m < 2 {
return Err(BasisError::InsufficientColumnsForConstraint { found: m });
}
let mut dists: Vec<f64> = Vec::with_capacity(m * (m - 1) / 2);
for i in 0..m {
for j in (i + 1)..m {
let mut s = 0.0_f64;
for k in 0..centers.ncols() {
let dlt = centers[(i, k)] - centers[(j, k)];
s += dlt * dlt;
}
dists.push(2.0 * s.sqrt());
}
}
dists.sort_by(|a, b| a.partial_cmp(b).expect("finite chart distances"));
let median = dists[dists.len() / 2];
if !(median.is_finite() && median > 0.0) {
crate::bail_invalid_basis!(
"constant-curvature auto length_scale failed: centers are degenerate \
(median pairwise chart distance = {median})"
);
}
Ok(median)
}
pub fn build_constant_curvature_basis(
data: ArrayView2<'_, f64>,
spec: &ConstantCurvatureBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
if data.ncols() == 0 {
crate::bail_invalid_basis!("constant-curvature smooth needs at least one feature column");
}
if !spec.kappa.is_finite() {
crate::bail_invalid_basis!("constant-curvature smooth needs a finite kappa");
}
validate_chart_points(data, spec.kappa, "data")?;
let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
if centers.nrows() < 2 {
return Err(BasisError::InsufficientColumnsForConstraint {
found: centers.nrows(),
});
}
validate_chart_points(centers.view(), spec.kappa, "centers")?;
let length_scale = realized_constant_curvature_length_scale(centers.view(), spec.length_scale)?;
let raw_penalty =
constant_curvature_kernel_matrix(centers.view(), centers.view(), spec.kappa, length_scale)?;
let z = match &spec.identifiability {
ConstantCurvatureIdentifiability::FrozenTransform { transform } => {
if transform.nrows() != centers.nrows() {
crate::bail_dim_basis!(
"frozen constant-curvature identifiability transform mismatch: {} centers but transform has {} rows",
centers.nrows(),
transform.nrows()
);
}
transform.clone()
}
ConstantCurvatureIdentifiability::CenterSumToZero => {
let weights = Array1::<f64>::ones(centers.nrows());
weighted_coefficient_sum_to_zero_transform(weights.view())?
}
};
let penalty = z.t().dot(&raw_penalty).dot(&z);
let raw_design =
constant_curvature_kernel_matrix(data, centers.view(), spec.kappa, length_scale)?;
let design = crate::matrix::DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
raw_design.dot(&z),
));
let (penalty_norm, c_primary) = normalize_penalty(&((&penalty + &penalty.t()) * 0.5));
let mut candidates = vec![PenaltyCandidate {
matrix: penalty_norm,
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: c_primary,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
let ridge = Array2::<f64>::eye(design.ncols());
let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
candidates.push(PenaltyCandidate {
matrix: ridge_norm,
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: c_ridge,
kronecker_factors: None,
op: None,
});
}
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(candidates)?;
Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::ConstantCurvature {
centers,
kappa: spec.kappa,
length_scale,
constraint_transform: Some(z),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}