use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::geometry::constant_curvature::ConstantCurvature;
use crate::geometry::manifold::RiemannianManifold;
use crate::geometry::{GeometryResult, GrassmannManifold, SpdManifold, StiefelManifold};
/// Split the inside of a `name(k1=v1, k2=v2)` label into `(key, value)` pairs.
///
/// Keys are trimmed and ASCII-lowercased; values are trimmed but otherwise kept verbatim
/// (only the first `=` separates key from value). Empty comma-separated pieces — for example
/// from a trailing comma or an all-whitespace input — are skipped.
///
/// # Errors
/// Returns a message naming the first non-empty piece that has no `=` separator.
fn parse_kv(inner: &str) -> Result<Vec<(String, String)>, String> {
    inner
        .split(',')
        .map(str::trim)
        .filter(|piece| !piece.is_empty())
        .map(|piece| {
            let (key, value) = piece.split_once('=').ok_or_else(|| {
                format!("response_geometry parameter {piece:?} must be key=value")
            })?;
            Ok((key.trim().to_ascii_lowercase(), value.trim().to_string()))
        })
        .collect()
}
/// A response-space geometry selected by a `response_geometry` label.
///
/// Each variant carries the shape parameters needed to read one response row as a point on
/// the manifold; [`ResponseManifold::ambient_dim`] gives the row length that layout implies.
/// The range constraints noted on each variant are enforced by [`ResponseManifold::resolve`]
/// (and therefore by [`ResponseManifold::parse`]), not by the type itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ResponseManifold {
    /// Symmetric positive-definite `n×n` matrices (`n >= 1`), flattened to `n * n` columns.
    Spd { n: usize },
    /// Grassmann manifold with `1 <= k <= n`, flattened to `n * k` columns.
    Grassmann { k: usize, n: usize },
    /// Stiefel manifold with `1 <= k <= n`, flattened to `n * k` columns.
    Stiefel { k: usize, n: usize },
    /// Poincaré ball in `dim >= 1` coordinates with finite `curvature < 0`.
    Poincare { dim: usize, curvature: f64 },
    /// Constant-curvature chart in `dim >= 1` coordinates with finite curvature `kappa`.
    ConstantCurvature { dim: usize, kappa: f64 },
}

impl ResponseManifold {
    /// Build and validate a manifold from a kind name and optional shape/curvature parameters.
    ///
    /// Requirements per `kind`:
    /// * `"spd"` — `n >= 1`;
    /// * `"grassmann"` / `"stiefel"` — both `k` and `n`, with `1 <= k <= n`;
    /// * `"poincare"` — `dim >= 1` and a finite `curvature < 0`;
    /// * `"constant_curvature"` — `dim >= 1`; `curvature` defaults to `0.0` and must be finite.
    ///
    /// Parameters that a kind does not use are ignored.
    ///
    /// # Errors
    /// Returns a user-facing message when `kind` is unknown or a parameter is missing or
    /// out of range.
    pub fn resolve(
        kind: &str,
        n: Option<usize>,
        k: Option<usize>,
        dim: Option<usize>,
        curvature: Option<f64>,
    ) -> Result<Self, String> {
        match kind {
            "spd" => {
                let n = n.ok_or_else(|| "response_geometry='spd' requires n".to_string())?;
                if n == 0 {
                    return Err("response_geometry='spd' requires n >= 1".to_string());
                }
                Ok(Self::Spd { n })
            }
            "grassmann" => {
                let k = k.ok_or_else(|| "response_geometry='grassmann' requires k".to_string())?;
                let n = n.ok_or_else(|| "response_geometry='grassmann' requires n".to_string())?;
                if k == 0 || n == 0 || k > n {
                    return Err("response_geometry='grassmann' requires 1 <= k <= n".to_string());
                }
                Ok(Self::Grassmann { k, n })
            }
            "stiefel" => {
                let k = k.ok_or_else(|| "response_geometry='stiefel' requires k".to_string())?;
                let n = n.ok_or_else(|| "response_geometry='stiefel' requires n".to_string())?;
                if k == 0 || n == 0 || k > n {
                    return Err("response_geometry='stiefel' requires 1 <= k <= n".to_string());
                }
                Ok(Self::Stiefel { k, n })
            }
            "poincare" => {
                let dim =
                    dim.ok_or_else(|| "response_geometry='poincare' requires dim".to_string())?;
                if dim == 0 {
                    return Err("response_geometry='poincare' requires dim >= 1".to_string());
                }
                let curvature = curvature
                    .ok_or_else(|| "response_geometry='poincare' requires curvature".to_string())?;
                // Negated conjunction so that NaN is rejected as well.
                if !(curvature.is_finite() && curvature < 0.0) {
                    return Err(
                        "response_geometry='poincare' requires finite curvature < 0".to_string()
                    );
                }
                Ok(Self::Poincare { dim, curvature })
            }
            "constant_curvature" => {
                let dim = dim.ok_or_else(|| {
                    "response_geometry='constant_curvature' requires dim".to_string()
                })?;
                if dim == 0 {
                    return Err(
                        "response_geometry='constant_curvature' requires dim >= 1".to_string()
                    );
                }
                let kappa = curvature.unwrap_or(0.0);
                if !kappa.is_finite() {
                    return Err(
                        "response_geometry='constant_curvature' requires finite curvature"
                            .to_string(),
                    );
                }
                Ok(Self::ConstantCurvature { dim, kappa })
            }
            other => Err(format!(
                "response_geometry must be one of 'spd', 'grassmann', 'stiefel', 'poincare', \
                 'constant_curvature', 'spherical', or 'simplex'; got {other:?}"
            )),
        }
    }

    /// Parse a user label such as `"spd"`, `"grassmann(k=2)"` or `"poincare(curvature=-0.5)"`
    /// for a response matrix with `cols` columns.
    ///
    /// The label is trimmed and ASCII-lowercased. Shape parameters that are not given are
    /// inferred from `cols`:
    /// * `spd` requires `cols` to be a perfect square unless `n` is given;
    /// * `grassmann` / `stiefel` require `k` and infer `n = cols / k`;
    /// * `poincare` / `constant_curvature` default `dim` to `cols`.
    ///
    /// Curvature defaults to `-1.0` for `poincare` and `0.0` for `constant_curvature`. The
    /// latter reads `kappa` first and accepts `curvature` as an alias. Unrecognised keys are
    /// ignored, and when a key repeats its first occurrence wins.
    ///
    /// # Errors
    /// Fails on malformed syntax, values that do not parse, shapes that cannot be inferred,
    /// or any check performed by [`ResponseManifold::resolve`].
    pub fn parse(label: &str, cols: usize) -> Result<Self, String> {
        let lowered = label.trim().to_ascii_lowercase();
        let (head, params) = match lowered.split_once('(') {
            Some((h, rest)) => {
                let rest = rest.trim_end();
                let inner = rest
                    .strip_suffix(')')
                    .ok_or_else(|| format!("response_geometry {label:?}: missing closing ')'"))?;
                (h.trim().to_string(), parse_kv(inner)?)
            }
            None => (lowered.clone(), Vec::new()),
        };
        // Look up the first occurrence of `key` and parse it as an unsigned integer.
        let get_usize = |key: &str| -> Result<Option<usize>, String> {
            for (k, v) in &params {
                if k == key {
                    let parsed: usize = v.parse().map_err(|_| {
                        format!("response_geometry {label:?}: {key} must be a non-negative integer")
                    })?;
                    return Ok(Some(parsed));
                }
            }
            Ok(None)
        };
        // Look up the first occurrence of `key` and parse it as a float.
        let get_f64 = |key: &str| -> Result<Option<f64>, String> {
            for (k, v) in &params {
                if k == key {
                    let parsed: f64 = v.parse().map_err(|_| {
                        format!("response_geometry {label:?}: {key} must be a real number")
                    })?;
                    return Ok(Some(parsed));
                }
            }
            Ok(None)
        };
        match head.as_str() {
            "spd" => {
                let n = match get_usize("n")? {
                    Some(n) => n,
                    None => {
                        let r = (cols as f64).sqrt().round() as usize;
                        if r * r != cols {
                            return Err(format!(
                                "response_geometry='spd': {cols} response columns is not a perfect \
                                 square; pass spd(n=...) explicitly"
                            ));
                        }
                        r
                    }
                };
                Self::resolve("spd", Some(n), None, None, None)
            }
            "grassmann" | "stiefel" => {
                let k = get_usize("k")?.ok_or_else(|| {
                    format!("response_geometry='{head}' requires k, e.g. {head}(k=2)")
                })?;
                let n = match get_usize("n")? {
                    Some(n) => n,
                    None => {
                        if k == 0 || cols % k != 0 {
                            return Err(format!(
                                "response_geometry='{head}': {cols} response columns is not \
                                 divisible by k={k}; pass {head}(k=..,n=..) explicitly"
                            ));
                        }
                        cols / k
                    }
                };
                Self::resolve(&head, Some(n), Some(k), None, None)
            }
            "poincare" => {
                let dim = get_usize("dim")?.unwrap_or(cols);
                let curvature = get_f64("curvature")?.unwrap_or(-1.0);
                Self::resolve("poincare", None, None, Some(dim), Some(curvature))
            }
            "constant_curvature" => {
                let dim = get_usize("dim")?.unwrap_or(cols);
                // `kappa` takes precedence; `curvature` is an accepted alias. A malformed
                // alias value is an error, not a silent fall-back to the flat case.
                let kappa = match get_f64("kappa")? {
                    Some(kappa) => kappa,
                    None => get_f64("curvature")?.unwrap_or(0.0),
                };
                Self::resolve("constant_curvature", None, None, Some(dim), Some(kappa))
            }
            other => Err(format!(
                "response_geometry must be one of 'spd', 'grassmann(k=..)', 'stiefel(k=..)', \
                 'poincare', 'constant_curvature', 'spherical', or 'simplex'; got {other:?}"
            )),
        }
    }

    /// Fully explicit label (every shape/curvature parameter spelled out) in the syntax
    /// accepted by [`ResponseManifold::parse`].
    pub fn canonical_label(&self) -> String {
        match self {
            Self::Spd { n } => format!("spd(n={n})"),
            Self::Grassmann { k, n } => format!("grassmann(k={k},n={n})"),
            Self::Stiefel { k, n } => format!("stiefel(k={k},n={n})"),
            Self::Poincare { dim, curvature } => {
                format!("poincare(dim={dim},curvature={curvature})")
            }
            Self::ConstantCurvature { dim, kappa } => {
                format!("constant_curvature(dim={dim},kappa={kappa})")
            }
        }
    }

    /// Number of response columns one point occupies: `n*n` for SPD, `n*k` for
    /// Grassmann/Stiefel, and `dim` for the curvature charts.
    pub fn ambient_dim(&self) -> usize {
        match self {
            Self::Spd { n } => n * n,
            Self::Grassmann { k, n } | Self::Stiefel { k, n } => n * k,
            Self::Poincare { dim, .. } | Self::ConstantCurvature { dim, .. } => *dim,
        }
    }

    /// Trait-object view for the variants backed by a [`RiemannianManifold`] implementation.
    ///
    /// Returns `None` for Poincaré, which is handled by `crate::geometry::poincare` directly.
    /// It also returns `None` when the Grassmann/Stiefel constructor rejects the stored shape.
    fn riemannian(&self) -> Option<Box<dyn RiemannianManifold>> {
        match self {
            Self::Spd { n } => Some(Box::new(SpdManifold::new(*n))),
            Self::Grassmann { k, n } => GrassmannManifold::new(*k, *n)
                .ok()
                .map(|m| Box::new(m) as _),
            Self::Stiefel { k, n } => StiefelManifold::new(*k, *n).ok().map(|m| Box::new(m) as _),
            Self::ConstantCurvature { dim, kappa } => {
                Some(Box::new(ConstantCurvature::new(*dim, *kappa)))
            }
            Self::Poincare { .. } => None,
        }
    }

    /// Log map `log_base(value)`, in ambient coordinates.
    ///
    /// # Panics
    /// Panics if [`Self::riemannian`] returns `None` for a non-Poincaré variant. That happens
    /// when the Grassmann/Stiefel constructor rejects the stored shape.
    fn log_point(
        &self,
        base: ArrayView1<'_, f64>,
        value: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        match self {
            Self::Poincare { curvature, .. } => {
                crate::geometry::poincare::log_map(base, value, *curvature)
            }
            Self::ConstantCurvature { .. }
            | Self::Spd { .. }
            | Self::Grassmann { .. }
            | Self::Stiefel { .. } => self
                .riemannian()
                .expect("riemannian response manifold")
                .log_map(base, value),
        }
    }

    /// Exp map `exp_base(tangent)`, in ambient coordinates.
    ///
    /// # Panics
    /// Same condition as [`Self::log_point`].
    fn exp_point(
        &self,
        base: ArrayView1<'_, f64>,
        tangent: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        match self {
            Self::Poincare { curvature, .. } => {
                crate::geometry::poincare::exp_map(base, tangent, *curvature)
            }
            Self::ConstantCurvature { .. }
            | Self::Spd { .. }
            | Self::Grassmann { .. }
            | Self::Stiefel { .. } => self
                .riemannian()
                .expect("riemannian response manifold")
                .exp_map(base, tangent),
        }
    }

    /// Squared metric norm of tangent vector `v` at `base`.
    ///
    /// For Poincaré this is `λ(base)² · |v|²`, with `λ` the conformal factor. For the other
    /// variants it is `vᵀ G(base) v`, using the manifold's metric tensor and clamped below at
    /// zero to absorb round-off.
    ///
    /// # Panics
    /// Same condition as [`Self::log_point`].
    fn sq_metric_norm(
        &self,
        base: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> GeometryResult<f64> {
        match self {
            Self::Poincare { curvature, .. } => {
                let lam = crate::geometry::poincare::conformal_factor(base, *curvature)?;
                Ok(lam * lam * v.iter().map(|x| x * x).sum::<f64>())
            }
            Self::ConstantCurvature { .. }
            | Self::Spd { .. }
            | Self::Grassmann { .. }
            | Self::Stiefel { .. } => {
                let g = self
                    .riemannian()
                    .expect("riemannian response manifold")
                    .metric_tensor(base)?;
                let gv = g.dot(&v);
                Ok(v.dot(&gv).max(0.0))
            }
        }
    }
}
/// Map every row of `values` into the tangent space at `base` using the manifold log map.
///
/// Row `i` of the result is `log_base(values[i])`, expressed in ambient coordinates.
///
/// # Errors
/// Fails when `base` or the rows of `values` do not have `manifold.ambient_dim()` entries,
/// or when the log map fails for some row. In that case the message names the row.
pub fn response_log_map(
    manifold: ResponseManifold,
    values: ArrayView2<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let expected = manifold.ambient_dim();
    let base_len = base.len();
    if base_len != expected {
        return Err(format!(
            "response geometry base point has length {base_len}; expected {expected}"
        ));
    }
    let (n_rows, cols) = values.dim();
    if cols != expected {
        return Err(format!(
            "response geometry values have {cols} columns; expected {expected}"
        ));
    }
    let mut tangents = Array2::<f64>::zeros((n_rows, expected));
    for (row, (point, mut slot)) in values
        .outer_iter()
        .zip(tangents.outer_iter_mut())
        .enumerate()
    {
        let mapped = manifold
            .log_point(base, point)
            .map_err(|e| format!("response geometry log map (row {row}): {e}"))?;
        slot.assign(&mapped);
    }
    Ok(tangents)
}
/// Map every row of `tangent` (tangent vectors at `base`) back onto the manifold using the
/// exp map.
///
/// Row `i` of the result is `exp_base(tangent[i])`, expressed in ambient coordinates.
///
/// # Errors
/// Fails in any of these cases:
/// * `base` or the rows of `tangent` do not have `manifold.ambient_dim()` entries;
/// * `tangent` holds a NaN or infinite entry;
/// * the exp map fails for some row (the message names the row).
pub fn response_exp_map(
    manifold: ResponseManifold,
    tangent: ArrayView2<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let expected = manifold.ambient_dim();
    let base_len = base.len();
    if base_len != expected {
        return Err(format!(
            "response geometry base point has length {base_len}; expected {expected}"
        ));
    }
    let (n_rows, cols) = tangent.dim();
    if cols != expected {
        return Err(format!(
            "response geometry tangent has {cols} columns; expected {expected}"
        ));
    }
    if tangent.iter().any(|x| !x.is_finite()) {
        return Err("response geometry tangent must contain only finite values".to_string());
    }
    let mut points = Array2::<f64>::zeros((n_rows, expected));
    for (row, (step, mut slot)) in tangent
        .outer_iter()
        .zip(points.outer_iter_mut())
        .enumerate()
    {
        let mapped = manifold
            .exp_point(base, step)
            .map_err(|e| format!("response geometry exp map (row {row}): {e}"))?;
        slot.assign(&mapped);
    }
    Ok(points)
}
/// Label-driven entry point: parse `label`, choose a base point, and log-map `values`.
///
/// When `base` is `None`, the weighted-uniform Fréchet mean of `values` is used as the base
/// point. It is computed with tolerance `1e-12` and at most 256 iterations.
///
/// # Returns
/// `(tangent vectors, base point, canonical label)`.
///
/// # Errors
/// Propagates label parsing, Fréchet-mean and log-map failures.
pub fn dispatch_log_map(
    values: ArrayView2<'_, f64>,
    label: &str,
    base: Option<ArrayView1<'_, f64>>,
) -> Result<(Array2<f64>, Array1<f64>, String), String> {
    const MEAN_TOL: f64 = 1.0e-12;
    const MEAN_MAX_ITER: usize = 256;
    let manifold = ResponseManifold::parse(label, values.ncols())?;
    let base_point = if let Some(supplied) = base {
        supplied.to_owned()
    } else {
        response_frechet_mean(manifold, values, None, MEAN_TOL, MEAN_MAX_ITER)?
    };
    let tangent = response_log_map(manifold, values, base_point.view())?;
    Ok((tangent, base_point, manifold.canonical_label()))
}
/// Label-driven entry point: parse `label` and exp-map each row of `tangent` from `base`.
///
/// # Errors
/// Propagates label parsing and exp-map failures.
pub fn dispatch_exp_map(
    tangent: ArrayView2<'_, f64>,
    label: &str,
    base: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    ResponseManifold::parse(label, tangent.ncols())
        .and_then(|manifold| response_exp_map(manifold, tangent, base))
}
/// Weighted Fréchet mean of the rows of `values` on `manifold`.
///
/// It minimises the dispersion `f(p) = Σ_i w_i ‖log_p(x_i)‖²_p` by descending along
/// `ξ = Σ_i w_i log_p(x_i)` with Armijo backtracking. The Armijo test below treats `−2ξ` as
/// the gradient of `f`.
///
/// # Arguments
/// * `values` — `M × ambient_dim` matrix with one manifold point per row (`M >= 1`).
/// * `weights` — optional per-row weights, normalised by `crate::geometry::normalize_weights`.
///   NOTE(review): this is presumably uniform when `None`; confirm against that helper.
/// * `tol` — stop once the metric norm of `ξ` is at most `tol`. Must be finite and positive.
/// * `max_iter` — maximum number of descent iterations after the initial step.
///
/// # Returns
/// Normally the first iterate whose `‖ξ‖` is at most `tol`. Two early exits return the
/// iterate with the smallest `‖ξ‖` seen so far instead, so the result may not meet `tol`:
/// * progress stalls (`‖ξ‖` fails to beat its best value by 0.5% for 10 consecutive
///   iterations);
/// * no backtracking step passes the Armijo test.
///
/// # Errors
/// Any of the following:
/// * shape, tolerance or weight validation failures;
/// * log/exp/metric failures outside the line search;
/// * running `max_iter` iterations without reaching one of the returns above (including
///   `max_iter == 0`).
pub fn response_frechet_mean(
    manifold: ResponseManifold,
    values: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    tol: f64,
    max_iter: usize,
) -> Result<Array1<f64>, String> {
    let ambient = manifold.ambient_dim();
    let (m, cols) = values.dim();
    if m == 0 || cols != ambient {
        return Err(format!(
            "response geometry Fréchet mean: values must be M×{ambient} with M >= 1"
        ));
    }
    // Negated conjunction so that NaN tolerances are rejected too.
    if !(tol.is_finite() && tol > 0.0) {
        return Err("response geometry Fréchet mean tolerance must be finite and positive".into());
    }
    let w = crate::geometry::normalize_weights(m, weights)
        .map_err(|_| "response geometry Fréchet mean: invalid weights".to_string())?;
    // Owned copies of the rows, shared by every dispersion and direction evaluation below.
    let samples: Vec<Array1<f64>> = (0..m).map(|i| values.row(i).to_owned()).collect();
    // f(p) = Σ w_i ‖log_p(x_i)‖²_p; only the Armijo line search (and the initial f_cur)
    // evaluate it.
    let dispersion = |p: ArrayView1<'_, f64>| -> Result<f64, String> {
        let mut acc = 0.0_f64;
        for (i, x) in samples.iter().enumerate() {
            let lg = manifold
                .log_point(p, x.view())
                .map_err(|e| format!("response geometry Fréchet mean log map: {e}"))?;
            let sq = manifold
                .sq_metric_norm(p, lg.view())
                .map_err(|e| format!("response geometry Fréchet mean metric: {e}"))?;
            acc += w[i] * sq;
        }
        Ok(acc)
    };
    // Initial point: the first sample passed through exp with a zero tangent.
    // NOTE(review): presumably this returns x_0 itself (or its projection); confirm per manifold.
    let mut p = manifold
        .exp_point(samples[0].view(), Array1::<f64>::zeros(ambient).view())
        .map_err(|e| format!("response geometry Fréchet mean init: {e}"))?;
    // One full, un-line-searched step along the weighted mean of the log maps.
    {
        let mut xi = Array1::<f64>::zeros(ambient);
        for (i, x) in samples.iter().enumerate() {
            let lg = manifold
                .log_point(p.view(), x.view())
                .map_err(|e| format!("response geometry Fréchet mean init log: {e}"))?;
            xi.scaled_add(w[i], &lg);
        }
        p = manifold
            .exp_point(p.view(), xi.view())
            .map_err(|e| format!("response geometry Fréchet mean init step: {e}"))?;
    }
    let mut f_cur = dispersion(p.view())?;
    // Iterate with the smallest ‖ξ‖ seen so far, returned on stall or line-search failure.
    let mut best_p = p.clone();
    let mut best_grad = f64::INFINITY;
    // Stall detection: an iteration counts as improving only if it lowers the best ‖ξ‖ by at
    // least STALL_REL (relative). STALL_PATIENCE non-improving iterations in a row end the
    // search.
    const STALL_REL: f64 = 5.0e-3;
    const STALL_PATIENCE: usize = 10;
    let mut stall = 0_usize;
    // Armijo sufficient-decrease constant, the maximum number of step halvings, and the
    // multiple of machine epsilon used as relative round-off slack on f.
    const ARMIJO_C1: f64 = 1.0e-4;
    const MAX_BACKTRACK_HALVINGS: usize = 60;
    const ARMIJO_ROUNDOFF_EPS_MULTIPLE: f64 = 8.0;
    for _ in 0..max_iter {
        // Descent direction ξ: weighted mean of the log maps at the current iterate.
        let mut xi = Array1::<f64>::zeros(ambient);
        for (i, x) in samples.iter().enumerate() {
            let lg = manifold
                .log_point(p.view(), x.view())
                .map_err(|e| format!("response geometry Fréchet mean log map: {e}"))?;
            xi.scaled_add(w[i], &lg);
        }
        let grad_norm = manifold
            .sq_metric_norm(p.view(), xi.view())
            .map_err(|e| format!("response geometry Fréchet mean metric: {e}"))?
            .sqrt();
        if grad_norm <= tol {
            return Ok(p);
        }
        let improved = grad_norm < best_grad * (1.0 - STALL_REL);
        if grad_norm < best_grad {
            best_grad = grad_norm;
            best_p.assign(&p);
        }
        if improved {
            stall = 0;
        } else {
            stall += 1;
            if stall >= STALL_PATIENCE {
                return Ok(best_p);
            }
        }
        // pred = ‖ξ‖²; the directional derivative of f along t·ξ is −2·t·pred.
        let pred = grad_norm * grad_norm;
        let f_tol = ARMIJO_ROUNDOFF_EPS_MULTIPLE * f64::EPSILON * (1.0 + f_cur.abs());
        let mut t = 1.0_f64;
        let mut accepted = false;
        for _ in 0..MAX_BACKTRACK_HALVINGS {
            let step = &xi * t;
            // A failing exp map or dispersion at the candidate counts as a rejected step.
            let cand = match manifold.exp_point(p.view(), step.view()) {
                Ok(c) => c,
                Err(_) => {
                    t *= 0.5;
                    continue;
                }
            };
            let f_cand = match dispersion(cand.view()) {
                Ok(f) => f,
                Err(_) => {
                    t *= 0.5;
                    continue;
                }
            };
            if f_cand <= f_cur - 2.0 * ARMIJO_C1 * t * pred + f_tol {
                p = cand;
                f_cur = f_cand;
                accepted = true;
                break;
            }
            t *= 0.5;
        }
        if !accepted {
            return Ok(best_p);
        }
    }
    Err("response geometry Fréchet mean did not reach stationarity within max_iter".into())
}
/// Result of `fit_response_curvature`: the fitted curvature plus diagnostics.
#[derive(Clone, Debug)]
pub struct ResponseCurvatureFit {
    /// Number of response columns (chart dimension).
    pub dim: usize,
    /// Curvature estimate κ̂: the midpoint of the final golden-section bracket.
    pub kappa_hat: f64,
    /// `kappa_hat * characteristic_radius²`.
    pub kappa_r2: f64,
    /// Twice the largest distance of a response row from the column-mean centroid.
    pub characteristic_radius: f64,
    /// Column-mean centroid of the responses, used as the chart base point.
    pub base: Array1<f64>,
    /// Criterion value `V_p(kappa_hat)`.
    pub v_p_hat: f64,
    /// True when `kappa_hat` lies within `max(2% of the bracket span, search tolerance)` of
    /// the upper bracket end.
    pub railed_at_resolution_limit: bool,
    /// Profile confidence interval for κ, produced by `curvature_estimand::profile_ci_walk`.
    pub profile_ci: crate::geometry::curvature_estimand::KappaProfileCi,
    /// Flatness (κ = 0) test at `kappa_hat`, produced by `curvature_estimand::flatness_lr_test`.
    pub flatness: crate::geometry::curvature_estimand::FlatnessTest,
}
/// Search bracket `(kappa_min, kappa_max)` for the curvature fit, plus the characteristic
/// radius `rho_max`.
///
/// * `kappa_min = -0.999 / max_i |y_i|²`, which lies just inside `-1 / max_i |y_i|²`. It
///   becomes `-1e6` when every row is at the origin.
/// * `rho_max = 2 · max_i |y_i − ȳ|`, with `ȳ` the column-mean centroid.
/// * `kappa_max = (0.9 π / rho_max)²`. It becomes `1e6` when all rows coincide.
///
/// An empty or all-zero cloud yields `(-1e6, 1e6, 0.0)`.
fn response_kappa_bounds(values: ArrayView2<'_, f64>) -> (f64, f64, f64) {
    // Running maximum seeded at 0.0; a NaN candidate never replaces the current best.
    let running_max = |best: f64, candidate: f64| if candidate > best { candidate } else { best };
    let (n_rows, dim) = values.dim();

    let max_sq_norm = values
        .outer_iter()
        .map(|row| row.dot(&row))
        .fold(0.0_f64, running_max);

    let mut centroid = Array1::<f64>::zeros(dim.max(1));
    if n_rows > 0 && dim > 0 {
        for row in values.outer_iter() {
            centroid += &row;
        }
        let count = n_rows as f64;
        centroid.mapv_inplace(|c| c / count);
    }

    let max_sq_spread = if dim > 0 {
        values
            .outer_iter()
            .map(|row| {
                let offset = &row - &centroid;
                offset.dot(&offset)
            })
            .fold(0.0_f64, running_max)
    } else {
        0.0
    };

    if max_sq_norm <= 0.0 && max_sq_spread <= 0.0 {
        return (-1.0e6, 1.0e6, 0.0);
    }

    let kappa_min = if max_sq_norm > 0.0 {
        -0.999 / max_sq_norm
    } else {
        -1.0e6
    };
    let rho_max = 2.0 * max_sq_spread.sqrt();
    let kappa_max = if max_sq_spread > 0.0 {
        let edge = 0.9 * std::f64::consts::PI / rho_max;
        edge * edge
    } else {
        1.0e6
    };
    (kappa_min, kappa_max, rho_max)
}
/// Curvature-selection criterion `V_p(κ)` for a response cloud whose rows are read as
/// coordinates in the constant-curvature chart with curvature `kappa`.
///
/// Notation:
/// * `base` is the column-mean centroid of `values`;
/// * `s_i` is the chart distance from `base` to row `y_i`, and `D = Σ s_i²`;
/// * `N·d` is the number of scalar observations;
/// * `J` is the chart's `jacobian_radial` (floored at `1e-300` before the log);
/// * `λ` is the chart's conformal factor.
///
/// The returned value is
///
/// `V_p = (N·d / 2) · ln(D / (N·d)) + Σ_i ln J(s_i) − d · Σ_i ln λ(y_i)`.
///
/// `fit_response_curvature` minimises this over κ. NOTE(review): this looks like a profiled
/// negative log-likelihood for isotropic Gaussian normal coordinates; confirm before relying
/// on that reading.
///
/// # Returns
/// `(V_p(kappa), base)`.
///
/// # Errors
/// Any of the following:
/// * `kappa` is not finite;
/// * `values` is not `N × dim` with `N >= 1` and `dim >= 1`;
/// * a chart evaluation fails: the centroid's or a row's conformal factor, or a distance.
pub fn response_curvature_criterion(
    values: ArrayView2<'_, f64>,
    dim: usize,
    kappa: f64,
) -> Result<(f64, Array1<f64>), String> {
    if !kappa.is_finite() {
        return Err("response curvature criterion: kappa must be finite".into());
    }
    let (n_rows, cols) = values.dim();
    if n_rows == 0 || cols != dim || dim == 0 {
        return Err(format!(
            "response curvature criterion: values must be N×{dim} with N >= 1"
        ));
    }
    // Column-mean centroid, used as the chart base point.
    let mut base = Array1::<f64>::zeros(dim);
    for row in values.outer_iter() {
        base += &row;
    }
    base.mapv_inplace(|v| v / n_rows as f64);
    let chart = ConstantCurvature::new(dim, kappa);
    // Validity probe only: errors out if the centroid is off the chart; the value is unused.
    chart
        .conformal_factor(base.view())
        .map_err(|e| format!("response curvature criterion: base off chart: {e}"))?;
    let d = dim as f64;
    // Per-row accumulators: squared distance from base, log radial Jacobian, log conformal factor.
    let mut dispersion = 0.0_f64;
    let mut ln_jac = 0.0_f64;
    let mut ln_lambda = 0.0_f64;
    for row in values.outer_iter() {
        let s = chart
            .distance(base.view(), row)
            .map_err(|e| format!("response curvature criterion distance: {e}"))?;
        dispersion += s * s;
        ln_jac += chart.jacobian_radial(s).max(1.0e-300).ln();
        let lam = chart
            .conformal_factor(row)
            .map_err(|e| format!("response curvature criterion conformal factor: {e}"))?;
        ln_lambda += lam.ln();
    }
    let nobs = (n_rows * dim) as f64;
    // Floor keeps ln(disp / nobs) finite for a zero-dispersion cloud.
    let disp = dispersion.max(1.0e-300 * nobs.max(1.0));
    let v_p = 0.5 * nobs * (disp / nobs).ln() + ln_jac - d * ln_lambda;
    Ok((v_p, base))
}
/// Fit a single constant curvature κ̂ to a response cloud by minimising
/// `response_curvature_criterion` over the bracket from `response_kappa_bounds`.
///
/// # Arguments
/// * `values` — `N × dim` response matrix (`N >= 1`, `dim >= 1`).
/// * `level` — confidence level for the profile interval, in `(0, 1)`.
/// * `tol` — relative bracket tolerance. The golden-section search stops once the bracket
///   is no wider than `max(tol · span, tol, 1e-12)`. It is not otherwise validated here.
/// * `max_iter` — cap on golden-section iterations.
///
/// # Returns
/// A `ResponseCurvatureFit` containing:
/// * κ̂, the midpoint of the final bracket;
/// * the criterion value and centroid at κ̂;
/// * `κ̂ · rho_max²`;
/// * a flag for κ̂ near the upper bracket end;
/// * the profile CI and the flatness test.
///
/// # Errors
/// Input validation failures, any criterion evaluation failure (propagated with `?`), and
/// errors from `profile_ci_walk` / `flatness_lr_test`.
pub fn fit_response_curvature(
    values: ArrayView2<'_, f64>,
    dim: usize,
    level: f64,
    tol: f64,
    max_iter: usize,
) -> Result<ResponseCurvatureFit, String> {
    if dim == 0 {
        return Err("constant-curvature response geometry requires dim >= 1".into());
    }
    let (n_rows, cols) = values.dim();
    if n_rows == 0 || cols != dim {
        return Err(format!(
            "constant-curvature response geometry: values must be N×{dim} with N >= 1"
        ));
    }
    // Negated conjunction so that a NaN level is rejected too.
    if !(level > 0.0 && level < 1.0) {
        return Err("response curvature CI level must lie in (0, 1)".into());
    }
    let (kappa_min, kappa_max, rho_max) = response_kappa_bounds(values);
    // Criterion as a function of κ alone. It is declared `mut` because it is later lent out
    // as `&mut v_p`.
    let mut v_p = |kappa: f64| -> Result<f64, String> {
        response_curvature_criterion(values, dim, kappa).map(|(v, _)| v)
    };
    // Golden-section search on [a, b]. `c < d_pt` are the interior probes and `fc`/`fd` their
    // criterion values. GOLDEN_INV is 1/φ.
    const GOLDEN_INV: f64 = 0.618_033_988_749_894_8;
    let mut a = kappa_min;
    let mut b = kappa_max;
    let mut c = b - GOLDEN_INV * (b - a);
    let mut d_pt = a + GOLDEN_INV * (b - a);
    let mut fc = v_p(c)?;
    let mut fd = v_p(d_pt)?;
    let ktol = (tol * (kappa_max - kappa_min)).max(tol).max(1.0e-12);
    for _ in 0..max_iter {
        if (b - a).abs() <= ktol {
            break;
        }
        if fc < fd {
            // Minimum lies in [a, d_pt]: the old left probe becomes the new right probe.
            b = d_pt;
            d_pt = c;
            fd = fc;
            c = b - GOLDEN_INV * (b - a);
            fc = v_p(c)?;
        } else {
            // Minimum lies in [c, b]: the old right probe becomes the new left probe.
            a = c;
            c = d_pt;
            fc = fd;
            d_pt = a + GOLDEN_INV * (b - a);
            fd = v_p(d_pt)?;
        }
    }
    let kappa_hat = 0.5 * (a + b);
    let (v_p_hat, base) = response_curvature_criterion(values, dim, kappa_hat)?;
    let span = kappa_max - kappa_min;
    // Flag estimates within max(2% of span, ktol) of the upper bracket end only.
    let rail_margin = (0.02 * span).max(ktol);
    let railed_at_resolution_limit = kappa_hat >= kappa_max - rail_margin;
    let kappa_r2 = kappa_hat * rho_max * rho_max;
    // Central second difference of V_p at κ̂. It is NaN when κ̂ ± h would leave the open
    // bracket.
    let h = (1.0e-3 * (kappa_max - kappa_min)).max(1.0e-6);
    let v_pp = if (kappa_hat - h) > kappa_min && (kappa_hat + h) < kappa_max {
        let vp = v_p(kappa_hat + h)?;
        let vm = v_p(kappa_hat - h)?;
        (vp - 2.0 * v_p_hat + vm) / (h * h)
    } else {
        f64::NAN
    };
    let profile_ci = crate::geometry::curvature_estimand::profile_ci_walk(
        &mut v_p, kappa_hat, v_pp, kappa_min, kappa_max, level, ktol,
    )?;
    let flatness = crate::geometry::curvature_estimand::flatness_lr_test(&mut v_p, kappa_hat)?;
    Ok(ResponseCurvatureFit {
        dim,
        kappa_hat,
        kappa_r2,
        characteristic_radius: rho_max,
        railed_at_resolution_limit,
        base,
        v_p_hat,
        profile_ci,
        flatness,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    /// Fréchet-mean base → log map → exp map must reproduce every input entry to 1e-6.
    fn round_trip(manifold: ResponseManifold, values: Array2<f64>) {
        let base =
            response_frechet_mean(manifold, values.view(), None, 1e-12, 500).expect("frechet mean");
        let tangent = response_log_map(manifold, values.view(), base.view()).expect("log map");
        let back = response_exp_map(manifold, tangent.view(), base.view()).expect("exp map");
        for row in 0..values.nrows() {
            for col in 0..values.ncols() {
                assert!(
                    (back[[row, col]] - values[[row, col]]).abs() < 1e-6,
                    "{manifold:?} exp∘log mismatch at ({row},{col}): {} vs {}",
                    back[[row, col]],
                    values[[row, col]]
                );
            }
        }
    }

    #[test]
    fn spd_round_trip_and_mean() {
        // Rows are flattened symmetric 2×2 positive-definite matrices.
        let values = array![
            [2.0, 0.0, 0.0, 1.0],
            [1.0, 0.3, 0.3, 2.0],
            [3.0, -0.5, -0.5, 1.5],
        ];
        round_trip(ResponseManifold::Spd { n: 2 }, values);
    }

    #[test]
    fn grassmann_round_trip_and_mean() {
        // Unit vectors in R³ spanning 1-dimensional subspaces.
        let values = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0],];
        round_trip(ResponseManifold::Grassmann { k: 1, n: 3 }, values);
    }

    #[test]
    fn stiefel_round_trip_and_mean() {
        // Unit vectors in R³ (1-frames).
        let values = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.6, 0.8],];
        round_trip(ResponseManifold::Stiefel { k: 1, n: 3 }, values);
    }

    #[test]
    fn poincare_round_trip_and_mean() {
        // Points well inside the unit ball (curvature −1).
        let values = array![[0.1, 0.2], [-0.3, 0.1], [0.2, -0.25],];
        round_trip(
            ResponseManifold::Poincare {
                dim: 2,
                curvature: -1.0,
            },
            values,
        );
    }

    #[test]
    fn resolver_rejects_bad_shapes() {
        assert!(ResponseManifold::resolve("grassmann", Some(2), Some(3), None, None).is_err());
        assert!(ResponseManifold::resolve("spd", None, None, None, None).is_err());
        assert!(ResponseManifold::resolve("poincare", None, None, Some(2), Some(1.0)).is_err());
        assert!(ResponseManifold::resolve("nonsense", None, None, None, None).is_err());
        assert_eq!(
            ResponseManifold::resolve("spd", Some(3), None, None, None).unwrap(),
            ResponseManifold::Spd { n: 3 }
        );
    }

    #[test]
    fn parse_infers_shapes_from_columns() {
        assert_eq!(
            ResponseManifold::parse("spd", 9).unwrap(),
            ResponseManifold::Spd { n: 3 }
        );
        assert!(ResponseManifold::parse("spd", 8).is_err());
        assert_eq!(
            ResponseManifold::parse("grassmann(k=2)", 10).unwrap(),
            ResponseManifold::Grassmann { k: 2, n: 5 }
        );
        assert_eq!(
            ResponseManifold::parse("Stiefel( k = 2 , n = 4 )", 8).unwrap(),
            ResponseManifold::Stiefel { k: 2, n: 4 }
        );
        assert!(ResponseManifold::parse("grassmann", 10).is_err());
        assert!(ResponseManifold::parse("grassmann(k=3)", 10).is_err());
        assert_eq!(
            ResponseManifold::parse("poincare", 3).unwrap(),
            ResponseManifold::Poincare {
                dim: 3,
                curvature: -1.0
            }
        );
        assert_eq!(
            ResponseManifold::parse("poincare(curvature=-0.5)", 3).unwrap(),
            ResponseManifold::Poincare {
                dim: 3,
                curvature: -0.5
            }
        );
        assert!(ResponseManifold::parse("hyperbolic", 3).is_err());
    }

    /// `constant_curvature` defaults and the `curvature` alias parse as documented, and every
    /// canonical label parses back to the manifold that produced it.
    #[test]
    fn parse_constant_curvature_and_canonical_labels_round_trip() {
        assert_eq!(
            ResponseManifold::parse("constant_curvature", 4).unwrap(),
            ResponseManifold::ConstantCurvature { dim: 4, kappa: 0.0 }
        );
        assert_eq!(
            ResponseManifold::parse("constant_curvature(kappa=0.5)", 3).unwrap(),
            ResponseManifold::ConstantCurvature { dim: 3, kappa: 0.5 }
        );
        assert_eq!(
            ResponseManifold::parse("constant_curvature(curvature=-2)", 3).unwrap(),
            ResponseManifold::ConstantCurvature {
                dim: 3,
                kappa: -2.0
            }
        );
        assert!(ResponseManifold::parse("constant_curvature(kappa=inf)", 3).is_err());
        let manifolds = [
            ResponseManifold::Spd { n: 3 },
            ResponseManifold::Grassmann { k: 2, n: 5 },
            ResponseManifold::Stiefel { k: 2, n: 4 },
            ResponseManifold::Poincare {
                dim: 3,
                curvature: -0.5,
            },
            ResponseManifold::ConstantCurvature {
                dim: 2,
                kappa: 1.25,
            },
        ];
        for &m in &manifolds {
            let label = m.canonical_label();
            assert_eq!(
                ResponseManifold::parse(&label, m.ambient_dim()).unwrap(),
                m,
                "canonical label {label:?} did not parse back"
            );
        }
    }

    #[test]
    fn dispatch_round_trips_through_user_label() {
        let cases: Vec<(&str, Array2<f64>)> = vec![
            (
                "spd",
                array![
                    [2.0, 0.0, 0.0, 1.0],
                    [1.0, 0.3, 0.3, 2.0],
                    [3.0, -0.5, -0.5, 1.5],
                ],
            ),
            (
                "grassmann(k=1)",
                array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0]],
            ),
            (
                "stiefel(k=1)",
                array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.6, 0.8]],
            ),
            ("poincare", array![[0.1, 0.2], [-0.3, 0.1], [0.2, -0.25]]),
        ];
        for (label, values) in cases {
            let (tangent, base, canonical) =
                dispatch_log_map(values.view(), label, None).expect("dispatch log");
            // The canonical label keeps the user's manifold name.
            assert!(canonical.starts_with(label.split('(').next().unwrap()));
            let back = dispatch_exp_map(tangent.view(), label, base.view()).expect("dispatch exp");
            for row in 0..values.nrows() {
                for col in 0..values.ncols() {
                    assert!(
                        (back[[row, col]] - values[[row, col]]).abs() < 1e-6,
                        "{label} exp∘log mismatch at ({row},{col}): {} vs {}",
                        back[[row, col]],
                        values[[row, col]]
                    );
                }
            }
        }
    }

    #[test]
    fn ambient_dim_matches_layout() {
        assert_eq!(ResponseManifold::Spd { n: 3 }.ambient_dim(), 9);
        assert_eq!(ResponseManifold::Grassmann { k: 2, n: 5 }.ambient_dim(), 10);
        assert_eq!(ResponseManifold::Stiefel { k: 2, n: 4 }.ambient_dim(), 8);
        assert_eq!(
            ResponseManifold::Poincare {
                dim: 4,
                curvature: -1.0
            }
            .ambient_dim(),
            4
        );
    }

    /// Deterministic xorshift64* uniform generator with Box–Muller normals, so the
    /// synthetic-data tests need no RNG crate.
    struct DetNormal {
        state: u64,
        // Second Box–Muller variate, returned by the next `normal()` call.
        spare: Option<f64>,
    }

    impl DetNormal {
        fn new(seed: u64) -> Self {
            Self {
                // Force the state odd so it is never zero (xorshift's fixed point).
                state: seed | 1,
                spare: None,
            }
        }

        /// Uniform draw in the open interval (0, 1).
        fn u01(&mut self) -> f64 {
            let mut x = self.state;
            x ^= x >> 12;
            x ^= x << 25;
            x ^= x >> 27;
            self.state = x;
            let v = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
            ((v >> 11) as f64 + 0.5) / (1u64 << 53) as f64
        }

        /// Standard normal draw (Box–Muller, caching the second variate).
        fn normal(&mut self) -> f64 {
            if let Some(z) = self.spare.take() {
                return z;
            }
            let u1 = self.u01().max(1e-12);
            let u2 = self.u01();
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * std::f64::consts::PI * u2;
            self.spare = Some(r * theta.sin());
            r * theta.cos()
        }
    }

    /// `n` points obtained by exp-mapping isotropic `N(0, sigma²)` tangent vectors from the
    /// origin of a `dim`-dimensional chart with curvature `k_star`, then column-centred.
    fn synth_cloud(dim: usize, k_star: f64, n: usize, sigma: f64, seed: u64) -> Array2<f64> {
        let manifold = ResponseManifold::ConstantCurvature { dim, kappa: k_star };
        let center = Array1::<f64>::zeros(dim);
        let mut rng = DetNormal::new(seed);
        let mut values = Array2::<f64>::zeros((n, dim));
        for i in 0..n {
            let t: Array1<f64> = (0..dim).map(|_| sigma * rng.normal()).collect();
            let y = manifold
                .exp_point(center.view(), t.view())
                .expect("exp tangent to response");
            values.row_mut(i).assign(&y);
        }
        let mut mean = Array1::<f64>::zeros(dim);
        for row in values.outer_iter() {
            mean += &row;
        }
        mean.mapv_inplace(|v| v / n as f64);
        for mut row in values.outer_iter_mut() {
            row -= &mean;
        }
        values
    }

    /// Recovers κ⋆ roughly, stays off the bracket rails, keeps κ̂ inside its CI, is monotone
    /// in κ⋆, and scales as κ̂(αy) ≈ κ̂(y)/α².
    #[test]
    fn fit_response_curvature_is_reparameterization_invariant() {
        let dim = 3usize;
        let sigma = 0.15;
        let n = 300usize;
        let k_stars = [-1.5_f64, -0.5, 0.0, 0.6, 1.2];
        let mut k_hats = Vec::new();
        for (idx, &k_star) in k_stars.iter().enumerate() {
            let values = synth_cloud(dim, k_star, n, sigma, 0xC0FFEE ^ (idx as u64 + 1));
            let (kmin, kmax, _rho) = response_kappa_bounds(values.view());
            let fit = fit_response_curvature(values.view(), dim, 0.95, 1e-12, 256)
                .expect("response curvature fit");
            k_hats.push(fit.kappa_hat);
            let span = kmax - kmin;
            assert!(
                fit.kappa_hat > kmin + 0.02 * span && fit.kappa_hat < kmax - 0.02 * span,
                "κ⋆={k_star}: κ̂={} railed to bracket [{kmin}, {kmax}]",
                fit.kappa_hat
            );
            assert!(
                (fit.kappa_hat - k_star).abs() <= 0.6 + 0.3 * k_star.abs(),
                "κ⋆={k_star}: κ̂={} too far",
                fit.kappa_hat
            );
            assert!(
                fit.profile_ci.ci_lo <= fit.kappa_hat && fit.kappa_hat <= fit.profile_ci.ci_hi,
                "κ⋆={k_star}: CI [{}, {}] excludes κ̂={}",
                fit.profile_ci.ci_lo,
                fit.profile_ci.ci_hi,
                fit.kappa_hat
            );
            assert!(fit.flatness.lr_stat >= 0.0);
            assert!(
                fit.flatness.p_value > 0.0 && fit.flatness.p_value < 1.0,
                "κ⋆={k_star}: degenerate flatness p={}",
                fit.flatness.p_value
            );
            if k_star == 0.0 {
                // 3.84 ≈ χ²₁ 95% critical value.
                assert!(
                    fit.flatness.lr_stat < 3.84,
                    "flat truth wrongly rejected: lr={}",
                    fit.flatness.lr_stat
                );
            }
            let alpha = 1.5_f64;
            let scaled = values.mapv(|v| alpha * v);
            let fit_scaled = fit_response_curvature(scaled.view(), dim, 0.95, 1e-12, 256)
                .expect("scaled response curvature fit");
            let expected = fit.kappa_hat / (alpha * alpha);
            assert!(
                (fit_scaled.kappa_hat - expected).abs() <= 0.05 + 0.05 * expected.abs(),
                "κ⋆={k_star}: rescale covariance broken: κ̂(αy)={} vs κ̂(y)/α²={}",
                fit_scaled.kappa_hat,
                expected
            );
        }
        for w in k_hats.windows(2) {
            assert!(w[1] > w[0] - 0.05, "κ̂ not monotone in κ⋆: {:?}", k_hats);
        }
    }

    #[test]
    fn fit_response_curvature_d1_uses_conformal_term_only() {
        let sigma = 0.12;
        let n = 400usize;
        for &k_star in &[-1.0_f64, 0.0, 0.8] {
            let values = synth_cloud(1, k_star, n, sigma, 0xD1 ^ (k_star.to_bits()));
            let (kmin, kmax, _rho) = response_kappa_bounds(values.view());
            let fit = fit_response_curvature(values.view(), 1, 0.95, 1e-12, 256)
                .expect("d=1 curvature fit");
            let span = kmax - kmin;
            assert!(
                fit.kappa_hat > kmin + 0.01 * span && fit.kappa_hat < kmax - 0.01 * span,
                "d=1 κ⋆={k_star}: κ̂={} railed to [{kmin},{kmax}]",
                fit.kappa_hat
            );
            assert!(
                fit.profile_ci.ci_lo <= fit.kappa_hat && fit.kappa_hat <= fit.profile_ci.ci_hi,
                "d=1 κ⋆={k_star}: CI excludes κ̂"
            );
            assert!(fit.kappa_hat.is_finite() && fit.v_p_hat.is_finite());
        }
    }

    #[test]
    fn response_curvature_criterion_rejects_boundary_probes() {
        let values = array![[0.5_f64, 0.0], [-0.4, 0.3], [0.1, -0.5]];
        let r2_max = values
            .outer_iter()
            .map(|r| r.dot(&r))
            .fold(0.0_f64, f64::max);
        let kappa_edge = -1.0 / r2_max;
        assert!(
            response_curvature_criterion(values.view(), 2, kappa_edge).is_err(),
            "criterion must reject the hyperbolic chart edge κ=−1/R²"
        );
        assert!(
            response_curvature_criterion(values.view(), 2, 1.5 * kappa_edge).is_err(),
            "criterion must reject past the hyperbolic chart edge"
        );
        let (v, _) = response_curvature_criterion(values.view(), 2, 0.9 * kappa_edge)
            .expect("interior κ valid");
        assert!(v.is_finite());
        assert!(response_curvature_criterion(values.view(), 2, f64::NAN).is_err());
        assert!(response_curvature_criterion(values.view(), 2, f64::INFINITY).is_err());
    }
}